In [1]:
import matplotlib.pyplot as plt
import math
import random
from collections import namedtuple, deque
from itertools import count
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import pickle
import copy
from scipy.interpolate import make_interp_spline, BSpline
from io import StringIO 

CPBP_MODEL_PATH = 'MiniCPBP/cpbpnewnew.jar'

#### Helper classes: CPBP marginals manager, LSTM network and DQN agent

In [2]:
#@title Marginals Manager
from subprocess import Popen, PIPE, STDOUT
#from signal import signal, SIGPIPE, SIG_DFL
import json
import time
import psutil

class Marginals():
    """Worker that drives one MiniCPBP ``Counterpoint_Soft`` Java process.

    The process is launched with the already-fixed notes, the rule switches
    and the target length on its command line. It then writes one JSON object
    per stdout line; the fields read here are ``sat_exception``, ``done``,
    ``expected``, ``E_marginals``, ``violations`` and ``read``. Whenever a
    message carries a truthy ``read`` flag the worker may write the next note
    (one per line) to the process's stdin.

    Args:
        length: target composition length (passed to the solver and used as
            the upper bound in ``step``).
        note_max: number of note values; sizes ``self.marginals``.
        composition: notes fixed at start-up (deep-copied).
        jar_path: Java class path containing MiniCPBP.
        rules: one switch per rule, forwarded to the solver; defaults to
            twelve ``1`` values.
    """

    def __init__(self, length, note_max, composition, jar_path, rules=None):
        if rules is None:
            # Fresh list per instance rather than a shared mutable default.
            rules = [1 for i in range(12)]
        self.composition = copy.deepcopy(composition)
        self.jar_path    = jar_path
        self.length      = length
        self.note_max    = note_max
        self.done = False
        self.marginals = {str(j): 0 for j in range(0, note_max)}
        self.expected = {i: 0 for i in range(30)}
        self.violations = [0 for i in range(12)]
        self.unsat = False
        self.write_ready = False
        self.rules = rules
        self.deleted = False
        jar_array = ['java', '-cp', jar_path, 'minicpbp.examples.Counterpoint_Soft']
        jar_array.extend([str(i) for i in composition])
        jar_array.extend([str(i) for i in rules])
        jar_array += [str(length)]
        # stderr is merged into stdout, so Java errors arrive as non-JSON lines.
        self.p = Popen(jar_array, stdin=PIPE, stdout=PIPE, stderr=STDOUT)
        process_msg = self.get_process_msg()
        try:
            self.unsat = bool(process_msg['sat_exception'])
        except (KeyError, TypeError):
            # Show the unexpected message before the lookups below fail.
            print('Exception occured...')
            print(process_msg)
        self.done = bool(process_msg['done'])
        self.expected = process_msg['expected']
        self.expected_marginals = process_msg['E_marginals']
        self.violations = process_msg['violations']
        self._drain_until_ready()

    def _drain_until_ready(self):
        """Read messages until the solver asks for input, is done, or is unsat."""
        while (not self.write_ready) and (not self.unsat) and (not self.done):
            garbage_msg = self.get_process_msg()
            self.write_ready = bool(garbage_msg['read'])
            self.done = self.done or bool(garbage_msg['done'])
            self.unsat = self.unsat or bool(garbage_msg['sat_exception'])

    def step(self, note):
        """Send ``note`` to the solver and update the tracked statistics.

        Does nothing when the worker is done, unsat, not ready for input, or
        the composition already has ``length`` notes.
        """
        if (self.done + self.unsat == 0) and self.write_ready and (len(self.composition) < self.length):
            self.send_msg(note)
            process_msg = self.get_process_msg()
            # Negative / empty fields mean "no update"; keep previous values.
            if process_msg['expected'] >= 0:
                self.expected = process_msg['expected']
            if len(process_msg['violations'].keys()) > 0:
                self.violations = process_msg['violations']
            if len(process_msg['E_marginals'].keys()) > 0:
                self.expected_marginals = process_msg['E_marginals']
            # Record done/unsat from the reply *before* draining, as __init__
            # does; otherwise a final reply triggers an extra stdout read that
            # can block or hit EOF.
            self.done = self.done or bool(process_msg['done'])
            self.unsat = self.unsat or bool(process_msg['sat_exception'])
            self._drain_until_ready()
            self.composition += [note]

    def get_marginals(self):
        """Return a dict with keys '0'..'299' filled from the latest E_marginals.

        Keys absent from E_marginals are 0; an all-zero dict is returned when
        the instance is unsat.
        """
        expected_marginals = {str(j): 0 for j in range(0, 300)}
        if self.unsat:
            print('Unsat...')
            return expected_marginals
        expected_marginals.update(self.expected_marginals)
        return expected_marginals

    def get_expected(self):
        """Return the latest expected value, or the sentinel 500 when unsat."""
        if self.unsat:
            return 500
        return self.expected

    def get_violations(self):
        """Return the latest per-rule violations reported by the solver."""
        return self.violations

    def send_msg(self, note):
        """Write ``note`` to the solver's stdin if it is waiting for input."""
        if self.done or self.unsat or (not self.write_ready):
            return
        p = self.p
        p.stdin.write(bytes(str(note) + '\n', encoding='utf-8'))
        p.stdin.flush()
        self.write_ready = False

    def get_process_msg(self):
        """Read and decode one JSON line from the solver.

        Returns ``{}`` without reading when done, ready for input or unsat,
        and ``{}`` when the line is not valid JSON.
        """
        if self.done or self.write_ready or self.unsat:
            return {}
        msg = (self.p.stdout.readline()).decode('utf-8')
        try:
            msg_json = json.loads(msg)
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            print("Failed to convert to json...")
            return {}
        return msg_json

    def key(self):
        """Return the lookup key for the current composition and rules."""
        key_ = '_'.join([str(n) for n in self.composition])
        rules_key = 'r'.join([str(r) for r in self.rules])
        return key_ + '=' + rules_key

    def delete(self):
        """Kill the solver process (and any children), close its pipes and reap it.

        Safe to call more than once.
        """
        if self.deleted:
            return
        p = self.p
        try:
            parent = psutil.Process(p.pid)
            for child in parent.children(recursive=True):
                try:
                    child.kill()
                except psutil.NoSuchProcess:
                    pass
            # Kill the parent unconditionally: a JVM normally has no children.
            parent.kill()
        except psutil.NoSuchProcess:
            # Already exited; nothing left to kill.
            pass
        for stream in (p.stdin, p.stdout):
            if stream is not None:
                try:
                    stream.close()
                except OSError:
                    pass
        p.wait()
        del self.p
        self.deleted = True

#Controller/worker manager 
class Marginals_Manager():
    """Pool of ``Marginals`` workers keyed by (composition, rules).

    ``compute_counterpoint(notes)`` answers from a worker whose composition
    equals ``notes``. Failing that, it advances a worker holding ``notes[:-1]``
    by the last note. Failing that, it starts a new worker. The whole pool is
    killed once it holds more than ``max_workers`` workers.

    Args:
        length: composition length passed to each worker.
        note_max: number of note values passed to each worker.
        composition: accepted for call compatibility; not used.
        jar_path: Java class path passed to each worker.
        max_workers: pool size above which all workers are cleared (default 5).
    """
    def __init__(self, length, note_max, composition, jar_path, max_workers=5):
        self.length = length
        self.Ms = []  # workers of the class Marginals()
        self.jar_path = jar_path
        self.note_max = note_max
        self.max_workers = max_workers

    def compute_counterpoint(self, notes, rules=None, mode='expected'):
        """Return the solver's answer for ``notes``.

        Args:
            notes: current composition.
            rules: rule switches (default: twelve ``1`` values).
            mode: 'expected', 'violations' or 'marginals'; any other value
                returns None.
        """
        if rules is None:
            rules = [1 for i in range(12)]
        # Bound memory / live JVMs by wiping the pool when it grows too big.
        if len(self.Ms) > self.max_workers:
            self.clear()
            print('Cleared!')
        notes_ = [str(note) for note in notes]
        rules_key = 'r'.join([str(r) for r in rules])
        key = '_'.join(notes_) + '=' + rules_key
        key_prev = '_'.join(notes_[:-1]) + '=' + rules_key
        M = get_M(self.Ms, key)
        if M is None:
            M = get_M(self.Ms, key_prev)
            if M is not None:
                # Reuse the worker one note behind by feeding it the last note.
                M.step(notes[-1])
            else:
                M = Marginals(self.length, self.note_max, notes, self.jar_path, rules=rules)
                self.Ms += [M]
        return self._query(M, mode)

    @staticmethod
    def _query(M, mode):
        """Dispatch ``mode`` to the matching getter of worker ``M``."""
        if mode == 'expected':
            return M.get_expected()
        if mode == 'violations':
            return M.get_violations()
        if mode == 'marginals':
            return M.get_marginals()
        return None

    def clear(self):
        """Delete every worker (killing its process) and empty the pool."""
        for M in self.Ms:
            M.delete()
        self.Ms.clear()
        
# Lookup helper for the worker pool.
def get_M(Ms, key):
    """Return the first worker in ``Ms`` whose ``key()`` equals ``key``, else None."""
    return next((M for M in Ms if M.key() == key), None)
# Removal helper for the worker pool.
def remove_M(Ms, key):
    """Remove the first worker in ``Ms`` whose ``key()`` equals ``key``.

    Returns True when a worker was removed, False otherwise. The worker is
    only dropped from the list; its process is not stopped here.
    """
    for M in Ms:
        if M.key() != key:
            continue
        Ms.remove(M)
        return True
    return False

In [3]:
class LSTM(nn.Module):
    """Dense stack -> LSTM -> dense stack producing one score vector per sequence.

    ``forward`` reads, from the last row along dim 1 of ``X`` (column 0), how
    many leading steps of each sequence to use. The output is computed from
    the LSTM output after that many steps; a zero state is used when it is 0.

    Args:
        input_size: features per time step of ``X``.
        h1_in_size: width of the pre-LSTM dense layers (LSTM input size).
        h1_state_size: LSTM hidden size and width of the post-LSTM layers.
        length: stored; not used by this class.
        n_labels: size of the output vector.
        n_layers: number of stacked LSTM layers.
        loss_mode: not used by this class.
        time_skip: stored; not used by this class.
        SAVE_PATH, MODEL_NAME: stored; not used by this class.
    """
    def __init__(self, input_size, h1_in_size, h1_state_size, length, n_labels, n_layers,
                 loss_mode='NLL', time_skip=True,
                 SAVE_PATH='saved_models/', MODEL_NAME='LSTM_fast'):
        super(LSTM, self).__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        dev, f64 = self.device, torch.float64
        # Basic attributes
        self.time_skip     = time_skip
        self.input_size    = input_size
        self.h1_in_size    = h1_in_size
        self.h1_state_size = h1_state_size
        self.n_labels      = n_labels
        self.length        = length
        self.n_layers      = n_layers
        # Plain attributes (not buffers), so state_dict keys stay unchanged.
        self.h_0 = torch.zeros((n_layers, 1, self.h1_state_size), device=dev, dtype=f64)
        self.c_0 = torch.zeros((n_layers, 1, self.h1_state_size), device=dev, dtype=f64)
        # Layers are created in the original order so seeded init is unchanged.
        self.lstm = nn.LSTM(self.h1_in_size, self.h1_state_size, self.n_layers,
                            dtype=f64, device=dev, dropout=0.0, batch_first=True)

        self.dense_in = nn.Linear(self.input_size, self.h1_in_size, device=dev, dtype=f64)
        self.dense_1a = nn.Linear(self.h1_in_size, self.h1_in_size, device=dev, dtype=f64)
        self.dense_1b = nn.Linear(self.h1_in_size, self.h1_in_size, device=dev, dtype=f64)
        self.dense_1c = nn.Linear(self.h1_in_size, self.h1_in_size, device=dev, dtype=f64)

        self.dense_2a  = nn.Linear(self.h1_state_size, self.h1_state_size, device=dev, dtype=f64)
        self.dense_2b  = nn.Linear(self.h1_state_size, self.h1_state_size, device=dev, dtype=f64)
        self.dense_2c  = nn.Linear(self.h1_state_size, self.h1_state_size, device=dev, dtype=f64)
        self.dense_out = nn.Linear(self.h1_state_size, self.n_labels, device=dev, dtype=f64)

        self.bn_in = nn.BatchNorm1d(self.h1_in_size, device=dev, dtype=f64)
        self.bn_1a = nn.BatchNorm1d(self.h1_in_size, device=dev, dtype=f64)
        self.bn_1b = nn.BatchNorm1d(self.h1_in_size, device=dev, dtype=f64)
        self.bn_1c = nn.BatchNorm1d(self.h1_in_size, device=dev, dtype=f64)

        self.bn_2a  = nn.BatchNorm1d(self.h1_state_size, device=dev, dtype=f64)
        self.bn_2b  = nn.BatchNorm1d(self.h1_state_size, device=dev, dtype=f64)
        self.bn_2c  = nn.BatchNorm1d(self.h1_state_size, device=dev, dtype=f64)
        self.bn_out = nn.BatchNorm1d(self.n_labels, device=dev, dtype=f64)  # unused in forward

        self.relu = nn.ReLU()
        self.activation = nn.LogSoftmax(dim=2)
        self.activation_generate = nn.LogSoftmax(dim=1)
        self.SAVE_PATH = SAVE_PATH
        self.MODEL_NAME = MODEL_NAME
        self.accuracy = 0

    def initialize_hiddens(self, batch_size):
        """Return zero (h, c) initial states tiled to ``batch_size``."""
        h1_0 = torch.tile(self.h_0, (1, batch_size, 1))
        c1_0 = torch.tile(self.c_0, (1, batch_size, 1))
        h1_0.requires_grad = False
        c1_0.requires_grad = False
        return h1_0.to(self.device), c1_0.to(self.device)

    def _dense_bn_relu_seq(self, Y, dense, bn, batch_size, length):
        """Apply dense -> BatchNorm1d -> ReLU to a (batch, length, features) tensor."""
        Y = dense(Y)
        # NOTE(review): ``view`` reinterprets memory instead of transposing, so
        # BatchNorm1d's channels are not the feature axis here;
        # ``Y.transpose(1, 2)`` would be the per-feature form. Kept as-is
        # because changing it would alter any checkpoint trained this way.
        Y = bn(Y.view(batch_size, -1, length))
        return self.relu(Y.view(batch_size, length, -1))

    def forward(self, X, eval_mode=False):
        """Score each sequence in ``X``.

        Args:
            X: tensor indexed as ``X[batch, step, feature]``; its last step row
                holds, in column 0, the number of real steps per sequence.
            eval_mode: switch the module to eval mode (else train mode).

        Returns:
            Tensor of shape ``(batch, n_labels)``.
        """
        bt_ends = X[:, -1, 0].type(torch.int64)
        X = X[:, :-1, :].type(torch.float64)
        if eval_mode:
            self.eval()
        else:
            self.train()
        batch_size = X.shape[0]
        length = X.shape[1]
        h1_0, c1_0 = self.initialize_hiddens(batch_size)
        Y = self._dense_bn_relu_seq(X, self.dense_in, self.bn_in, batch_size, length)
        Y = self._dense_bn_relu_seq(Y, self.dense_1a, self.bn_1a, batch_size, length)
        Y = self._dense_bn_relu_seq(Y, self.dense_1b, self.bn_1b, batch_size, length)
        Y = self._dense_bn_relu_seq(Y, self.dense_1c, self.bn_1c, batch_size, length)
        Y, _ = self.lstm(Y, (h1_0, c1_0))
        # Prepend a zero state so index k picks the output after k steps.
        zero_state = torch.zeros((Y.shape[0], 1, Y.shape[2]), device=Y.device, dtype=torch.float64)
        Y = torch.cat((zero_state, Y), 1)
        Y = Y[torch.arange(batch_size, device=Y.device), bt_ends]
        for dense, bn in ((self.dense_2a, self.bn_2a),
                          (self.dense_2b, self.bn_2b),
                          (self.dense_2c, self.bn_2c)):
            Y = dense(Y)
            Y = bn(Y.view(batch_size, -1))
            Y = self.relu(Y.view(batch_size, -1))
        return self.dense_out(Y)

In [4]:
#@title DQN (simple E)
class Tuner_Env:
    """RL environment that builds a composition note by note, rewarded via CPBP.

    Args:
        length: number of notes per episode.
        max_length: observations are zero-padded up to this many note rows
            (no padding once the composition is longer). Defaults to 40.
    """
    def __init__(self, length, max_length=40):
        self.CPBP_MODEL_PATH = CPBP_MODEL_PATH
        self.cpbp = Marginals_Manager(length, 30, [], CPBP_MODEL_PATH)
        self.composition = []
        self.length = length
        self.violations = {str(i): 0 for i in range(12)}
        self.last_E = self.cpbp.compute_counterpoint([], mode='expected')
        self.last_vs = 0
        self.max_length = max_length

    def get_reward(self):
        """Return ``(0, r)`` for the note just appended.

        ``r`` combines the drop in total violations, the drop in the solver's
        expected value, and ``log(mass <= t) / 500`` for t in 5, 15, 25, 35, 45.
        Here "mass" is the sum of the marginals whose integer key is <= t
        (presumably a distribution over violation counts — TODO confirm).
        """
        vs = self.get_violations_sum()
        r_vs = -(vs - self.last_vs)
        self.last_vs = vs
        cur_E = self.cpbp.compute_counterpoint(self.composition, mode='expected')
        r_e = -(cur_E - self.last_E)
        self.last_E = cur_E

        # eps keeps every log finite when no mass falls under a threshold.
        eps = 1e-45
        thresholds = (5, 15, 25, 35, 45)
        mass = {t: eps for t in thresholds}
        expected_marginals = self.cpbp.compute_counterpoint(self.composition, mode='marginals')
        for k, p in expected_marginals.items():
            for t in thresholds:
                if int(k) <= t:
                    mass[t] += p

        # Same left-to-right summation order as the explicit formula.
        reward = r_vs + r_e
        for t in thresholds:
            reward = reward + np.log(mass[t]) / 500
        return 0, float(reward)

    def step(self, action):
        """Append ``action`` (a note index) and return (observation, reward, terminated)."""
        if len(self.composition) < self.length:
            assert action is not None
            self.composition += [action]
            r_lstm, r_cpbp = self.get_reward()
        else:
            r_lstm, r_cpbp = 0, 0
        terminated = len(self.composition) == self.length
        return self.get_observation(), r_lstm + r_cpbp, terminated

    def get_observation(self):
        """Encode the composition as rows of 31 values.

        Each note gives a one-hot row over indices 0-29; index 30 flags that the
        composition has reached ``length - 1`` notes. Zero rows pad up to
        ``max_length`` notes. The final row repeats the note count 31 times.
        """
        n_notes = len(self.composition)
        near_end = n_notes >= self.length - 1
        obs = []
        for note in self.composition:
            onehot = np.zeros(31)
            onehot[note] = 1
            onehot[30] = near_end
            obs.append(list(onehot))
        for _ in range(self.max_length - n_notes):
            obs.append(list(np.zeros(31)))
        return obs + [[n_notes] * 31]

    def get_violations(self):
        """Return the per-rule violations for the current composition."""
        return self.cpbp.compute_counterpoint(self.composition, mode='violations')

    def get_violations_sum(self, selected_vk=-1):
        """Sum violations over all rules, or only rule ``selected_vk`` when given."""
        violations = self.get_violations()
        v_sum = 0
        for v_k in violations.keys():
            if (selected_vk == -1) or v_k == str(selected_vk):
                v_sum += violations[v_k]
        return v_sum

    def reset(self):
        """Start a new episode and return the empty-composition observation.

        Baselines are computed on the current pool, which is then cleared
        (killing its workers) and replaced by a fresh manager.
        """
        self.composition = []
        self.violations = {str(i): 0 for i in range(12)}
        self.last_E = self.cpbp.compute_counterpoint([], mode='expected')
        self.last_vs = self.get_violations_sum()
        self.cpbp.clear()
        self.cpbp = Marginals_Manager(self.length, 30, [], CPBP_MODEL_PATH)
        return self.get_observation()
    
class ReplayMemory:
    """Bounded FIFO buffer of ``Transition`` tuples for experience replay."""

    def __init__(self, capacity):
        # Once ``capacity`` is reached, appending drops the oldest entry.
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        transition = Transition(*args)
        self.memory.append(transition)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions, chosen at random."""
        return random.sample(self.memory, k=batch_size)

    def __len__(self):
        return len(self.memory)

class DQN(nn.Module):
    """Q-network: wraps an ``LSTM`` that maps an observation to one value per action."""

    def __init__(self, n_actions):
        super().__init__()
        # The attribute name ``LSTM`` appears in state_dict keys; keep it.
        self.LSTM = LSTM(
            n_actions + 1,  # input_size
            n_actions + 1,  # h1_in_size
            100,            # h1_state_size
            32,             # length
            n_actions,      # n_labels
            2,              # n_layers
            time_skip=False,
        )

    def forward(self, x, eval_mode=True):
        """Return Q-values for observation batch ``x`` (eval mode by default)."""
        return self.LSTM(x, eval_mode=eval_mode)
    
class Agent():
    """DQN agent (policy and target ``DQN`` nets) with checkpointing and evaluation.

    Args:
        length: base episode length; evaluation lengths are derived from it.
        batch_size, gamma, eps_start, eps_end, eps_decay, lr: DQN
            hyper-parameters used by this class.
        tau: stored; not used by this class.
        memory_size: replay-buffer capacity.
        name: prefix of checkpoint file names.
        SAVE_PATH: directory prefix for checkpoints.
        device: torch device; defaults to CUDA when available, else CPU.
    """
    def __init__(self, length=32, batch_size=128, gamma=0.99, eps_start=0.9, eps_end=0.05,
                 eps_decay=1000, tau=0.005, lr=0.0001, memory_size=10000, name="",
                 SAVE_PATH="DQN_models/", device=None):
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)
        self.BATCH_SIZE = batch_size
        self.GAMMA = gamma
        self.EPS_START = eps_start
        self.EPS_END = eps_end
        self.EPS_DECAY = eps_decay
        self.TAU = tau
        self.LR = lr
        self.n_actions = 30
        self.policy_net = DQN(self.n_actions).to(self.device)
        self.target_net = DQN(self.n_actions).to(self.device)
        self.optimizer = optim.AdamW(self.policy_net.parameters(), lr=self.LR, amsgrad=True)
        self.memory = ReplayMemory(memory_size)
        self.steps_done = 0
        self.rand_actions = list(np.arange(30))
        self.name = name
        self.save_path = SAVE_PATH
        self.eval_lengths = [length, length + 10, length + 30, length + 50]
        self.run_results = [{str(i): [] for i in list(range(12)) + ['total']},
                            {k: {str(i): [] for i in list(range(12)) + ['total']} for k in self.eval_lengths},
                            'agent: ' + self.name]
        try:
            self.load()
        except Exception:
            print("failed to load")

    def select_action(self, state, eval_mode=False):
        """Epsilon-greedy action as a (1, 1) long tensor; greedy when ``eval_mode``."""
        sample = random.random()
        eps_threshold = self.EPS_END + (self.EPS_START - self.EPS_END) * \
            math.exp(-1. * self.steps_done / self.EPS_DECAY)
        # NOTE(review): the counter also advances in eval_mode, so evaluation
        # runs speed up the exploration decay — confirm this is intended.
        self.steps_done += 1
        if (sample > eps_threshold) or eval_mode:
            with torch.no_grad():
                # Index of the largest Q-value per row.
                return self.policy_net(state).max(1)[1].view(1, 1)
        return torch.tensor([[random.choice(self.rand_actions)]], device=self.device, dtype=torch.long)

    def optimize_model(self):
        """Run one optimization step on a replay batch (no-op until the buffer is full enough)."""
        if len(self.memory) < self.BATCH_SIZE:
            return
        transitions = self.memory.sample(self.BATCH_SIZE)
        batch = Transition(*zip(*transitions))
        non_final_mask = torch.tensor(tuple(s is not None for s in batch.next_state),
                                      device=self.device, dtype=torch.bool)
        non_final_next_states = [s for s in batch.next_state if s is not None]
        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)

        state_action_values = self.policy_net(state_batch, eval_mode=False).gather(1, action_batch)
        next_state_values = torch.zeros(self.BATCH_SIZE, device=self.device, dtype=torch.float64)
        # torch.cat([]) raises, so skip when every sampled transition is terminal.
        if non_final_next_states:
            with torch.no_grad():
                next_state_values[non_final_mask] = self.target_net(torch.cat(non_final_next_states)).max(1)[0]
        # Expected Q values (TD targets).
        expected_state_action_values = (next_state_values * self.GAMMA) + reward_batch
        # Mean-squared TD error.
        criterion = nn.MSELoss()
        loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

        self.optimizer.zero_grad()
        loss.backward()
        # In-place gradient clipping
        torch.nn.utils.clip_grad_value_(self.policy_net.parameters(), 5)
        self.optimizer.step()

    def load(self):
        """Load network weights (best effort) and metadata from ``save_path + name``."""
        prefix = self.save_path + self.name
        try:
            # map_location lets checkpoints saved on GPU load on a CPU-only host.
            self.policy_net.load_state_dict(torch.load(prefix + "_policy", map_location=self.device))
            self.target_net.load_state_dict(torch.load(prefix + "_target", map_location=self.device))
        except Exception:
            print("failed to load DQN")
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(prefix + "_meta.pkl", 'rb') as file:
            meta = pickle.load(file)
        self.steps_done = meta["steps_done"]
        self.run_results = meta["run_results"]

    def save(self):
        """Save both networks' weights and the metadata pickle."""
        prefix = self.save_path + self.name
        torch.save(self.policy_net.state_dict(), prefix + "_policy")
        torch.save(self.target_net.state_dict(), prefix + "_target")
        meta = {"steps_done": self.steps_done, "run_results": self.run_results}
        with open(prefix + "_meta.pkl", 'wb') as file:
            pickle.dump(meta, file)

    def eval_model(self, eval_length=32, n_runs=10):
        """Average violations of ``n_runs`` greedy episodes of ``eval_length`` notes.

        Returns a dict with per-rule averages (keys '0'..'11') and 'total'.
        """
        eval_env = Tuner_Env(eval_length)
        viols = {str(i): 0 for i in range(12)}
        viols["total"] = 0
        for _ in range(n_runs):
            state = eval_env.reset()
            state = torch.tensor(state, dtype=torch.float64, device=self.device).unsqueeze(0)
            for _ in range(eval_length):
                action = self.select_action(state, eval_mode=True)
                observation, reward, terminated = eval_env.step(action.item())
                state = torch.tensor(observation, dtype=torch.float64, device=self.device).unsqueeze(0)
            viols["total"] += eval_env.get_violations_sum()
            for k in range(12):
                viols[str(k)] += eval_env.get_violations_sum(selected_vk=k)
        # reset() clears the CPBP worker pool left from the last run.
        eval_env.reset()
        viols["total"] /= n_runs
        for k in range(12):
            viols[str(k)] /= n_runs
        return viols
        


#### Plotting functions

In [5]:
def prune_num(name):
    """Drop the first standalone '1' token from a dash-separated label."""
    parts = name.split('-')
    if '1' not in parts:
        return name
    parts.remove('1')
    return '-'.join(parts)
def prune_LSTM(name):
    """Drop the first standalone 'LSTM' token from a dash-separated label."""
    parts = name.split('-')
    if 'LSTM' not in parts:
        return name
    parts.remove('LSTM')
    return '-'.join(parts)

In [None]:
def plot_violations(ep_lists,eval_period=15, k=1, agent_names=[]):
    an = copy.deepcopy(agent_names)
    for i in range(len(an)):
        if('1' in an[i].split('-')):
            an[i] = prune_num(an[i])
        if('LSTM' in an[i].split('-')):
            an[i] = prune_LSTM(an[i])
    agent_names = set()
    for i in range(len(an)):
        agent_names.add(an[i])
    agent_names = list(agent_names)
    cmap = get_cmap_string(domain=agent_names)
    length = 35
    eval_lengths = [length-20, length, length+55]
    plot_keys = ["total"]+[str(i) for i in range(12)]
    versions  = ['new', 'old']
    violation_names_new = {"total":"total",
                      "0":"naturalNotes",
                      "1":"bFlat",
                      "2":"noRepeat",
                      "3":"stepwiseDescentToFinal",
                      "4":"tonicEnds",
                      "5":"avoidSixths",
                      "6":"skipStepsSequence",
                      "7":"skipStepRatio",
                      "8":"coverModalRange",
                      "9":"characteristicModalSkips",
                      "10":"skipStepSequence",
                      "11":"tritonOutlines"}
    violation_names_old = {"total":"total",
                      "0":"naturalNotes",
                      "1":"intervals",
                      "2":"tritonOutlines",
                      "3":"tonicEnds",
                      "4":"stepwiseDescentToFinal",
                      "5":"repeats",
                      "6":"coverModalRange",
                      "7":"characteristicModalSkips",
                      "8":"skipStepsRatio",
                      "9":"sixths",
                      "10":"skipStepsSequence",
                      "11":"bFlat"}
    
    highlighted_labels = [
                        #"V.dE.m-C-100DQN-PR", "V.dE.m-C-100DQN-LSTM-PR",
                        "V.dE-C-100DQN-PR", "V.dE-C-100DQN-LSTM-PR",
                        #"V.m-C-100DQN-PR",  "V.m-C-100DQN-LSTM-PR",
                        "V-D(normed)-64DQN-shorts-PR","V-D(normed)-64DQN-shorts-LSTM-PR",
                        #"V-C-100DQN-PR","V-C-100DQN-LSTM-PR",
                        #"dE.m-C-100DQN-PR","dE.m-C-100DQN-LSTM-PR"
                         ]
    for ver in versions:
        for el in eval_lengths:
            for pk in plot_keys:
                if(ver=='new'):
                    violation_names = violation_names_new
                else:
                    violation_names = violation_names_old
                bar_names_color = []
                bar_names_label = []
                bar_vals  = []
                bar_stds  = []
                if(not pk == "total"):# and (not pk =="3") and  (not pk =="4")):
                    ...
                    #continue
                ymax=-1
                ymin=999
                for ep in ep_lists:
                    vs = []
                    evals = []
                    stds  = []
                    ep_viols = ep[0]
                    ep_evals = ep[1][ver][el]
                    ep_label = prune_num(ep[2][7:])
                    lastx = 0
                    for i in range(len(ep_evals[pk])):
                        if(ep_evals[pk][i][0]>9000 or ep_evals[pk][i][0]<200):
                            ...
                            #continue
                        evals += [[ep_evals[pk][i][0],ep_evals[pk][i][1]]]#-(ep_evals["3"][i][1]+ep_evals["4"][i][1])]]
                        lastx = ep_evals[pk][i][0]
                        lasty = ep_evals[pk][i][1]
                        laststd=ep_evals[pk][i][-1]
                        stds  += [ep_evals[pk][i][-1]]
                    evals = np.array(evals)
                    al = 1
                    dots = "-."
                    if("LSTM" in ep_label.split('-')):# and (not ep_label=="no-marginals")):
                        dots = "-"
                    if(ep_label in highlighted_labels):
                        ...
                        al = 1
                    ys_hi = evals[:,1] + stds
                    ys_lo = evals[:,1] - stds
                    if(ys_hi.max()>ymax):
                        ymax=ys_hi.max()
                    if(ys_lo.min()<ymin):
                        ymin=ys_lo.min()
                    spl_hi = make_interp_spline(evals[:,0], ys_hi, k=k)
                    spl_lo = make_interp_spline(evals[:,0], ys_lo, k=k)
                    spl    = make_interp_spline(evals[:,0], evals[:,1], k=k)
                    xsmooth = np.linspace(evals[:,0].min(), evals[:,0].max(), 200) 
                    y_smooth_hi = spl_hi(xsmooth)
                    y_smooth_lo = spl_lo(xsmooth)
                    y_smooth_   = spl(xsmooth)
                    symb = ""
                    lsymb = ""
                    if("LSTM" in ep_label.split("-")):
                        symb = "s"
                        lsymb="--"
                    else:
                        lsymb="-"
                        symb = "o"
                    ep_label  = ep_label.split('-')
                    ep_label_ = copy.deepcopy(ep_label)
                    real_label = []
                    if('dV' in ep_label):
                        real_label += ['ΔV']
                    elif('dV.m' in ep_label_):
                        real_label += ['ΔV.m']
                    elif('dV.dE' in ep_label_):
                        real_label += ['ΔV.ΔE']
                    elif('dV.dE(normalized)' in ep_label_):
                        real_label += ['ΔV.ΔE.normalized']
                    elif('dE.m' in ep_label_):
                        real_label += ['ΔE.m']
                    elif('dV.dE.m' in ep_label_):
                        real_label += ['ΔV.ΔE.m']
                    if('A' in ep_label_):
                        real_label += ['[A]']
                    elif('B' in ep_label_):
                        real_label += ['[B]']
                    elif('B(more.features)' in ep_label_):
                        real_label += ['[B]']
                    elif('B(noised)' in ep_label_):
                        real_label += ['[B.noised]']
                    elif('B.2' in ep_label_):
                        real_label += ['[B.lessFeatures]']
                    elif('C' in ep_label_):
                        real_label += ['[C]']
                    if('0.92m' in ep_label_):
                        real_label += ['0.92m']
                    if('0.75m' in ep_label_):
                        real_label += ['0.75m']
                    if('attention' in ep_label_):
                        real_label += ['attention']
                    if('transformer' in ep_label_):
                        real_label += ['transformer']
                    if('NEW' in ep_label_):
                        real_label += ['NEW']
                    if('OLD' in ep_label_):
                        real_label += ['OLD']
                    ep_label_= '-'.join(real_label[:])
                    if('LSTM' in ep_label):
                        ep_label.remove('LSTM')
                        ep_label = '-'.join(ep_label)
                        bar_names_color += [ep_label]
                        bar_vals  += [lasty]
                        bar_stds  += [laststd]
                        bar_names_label += [ep_label_[:]]
                        plt.fill_between(xsmooth,y_smooth_hi,y_smooth_lo,color=cmap(ep_label),alpha=al/10,linewidth=0.0)
                        plt.plot(xsmooth, y_smooth_, lsymb, color=cmap(ep_label),alpha=al, linewidth=1.2)
                    else:
                        ep_label = '-'.join(ep_label)
                        bar_names_color += [ep_label]
                        bar_vals  += [lasty]
                        bar_stds  += [laststd]
                        bar_names_label += [ep_label_]
                        plt.fill_between(xsmooth,y_smooth_hi,y_smooth_lo,color=cmap(ep_label),alpha=al/5,linewidth=0.0,label=ep_label_)
                        plt.plot(xsmooth, y_smooth_, lsymb, color=cmap(ep_label),alpha=al, linewidth=1)

                fz = 7
                plt.title('['+ver+']'+ ' length ' +str(el) +' rule: '+ violation_names[pk],fontsize=fz+2)
                plt.xlabel('n. episodes',fontsize=fz)
                plt.ylabel('n. violations',fontsize=fz)
                leg=plt.legend(prop={'size': 7}, markerscale=2.5)
                for lh in leg.legendHandles: 
                    lh.set_alpha(1)
                fig = plt.gcf()
                fig.set_size_inches(4,2.5)
                plt.xticks(np.arange(100,6001,650), fontsize=fz)
                ymin -= 5
                ymin = max(0,ymin)
                #ymin=0
                ymax += ymax*0.1
                ax = plt.gca()
                ax.set_facecolor((1, 1, 1))
                #ymin = 0
                #print((ymin,ymax))
                #print(np.arange(int(ymin/5)*5, int(ymax/5)*5+10,((ymax-ymin)/10)))
                plt.yticks(np.arange(int(ymin/5)*5, int(ymax/5)*5+5,((ymax+5-ymin)/15)), fontsize=fz)
                plt.grid(linestyle = '--', linewidth = 0.5)
                plt.show()
                """
                ymax = np.max(bar_vals) + 35
                ymin = 0#np.min(bar_vals)
                y_pos = 1*np.arange(len(bar_names_color))
                patterns = []
                for ipos in range(len(y_pos)):
                    if(ipos%2==0):
                        patterns += [""]
                    else:
                        patterns += ["O"]
                y_pos_ = np.arange(0,len(bar_names_color),2,dtype=float)
                y_pos_ += (y_pos[1]-y_pos[0])/2
                for ipos in range(len(y_pos)):
                    if(ipos%2==0):
                        plt.bar(y_pos[ipos], bar_vals[ipos],color=cmap(bar_names_color[ipos]),yerr=bar_stds[ipos], capsize=2,hatch=patterns[ipos], label=bar_names_label[ipos])
                    else:
                        plt.bar(y_pos[ipos], bar_vals[ipos],color=cmap(bar_names_color[ipos]),yerr=bar_stds[ipos], capsize=2,hatch=patterns[ipos])
                bar_names_label_2s = [bar_names_label[2*i] for i in range(int(len(bar_names_label)/2))]
                plt.xticks(y_pos_, ["" for b in bar_names_label_2s], fontsize=fz)#,rotation = 15, ha="right",rotation_mode='anchor')
                #leg=plt.legend(prop={'size': 7}, markerscale=2.5)
                plt.yticks(np.arange(int(ymin/5)*5, int(ymax/5)*5+10,((ymax-ymin)/8)), fontsize=fz)
                for i in range(len(bar_names_label)):
                    plt.text(x = y_pos[i]-0.65 , y = bar_vals[i]+bar_stds[i]+0.75, s = bar_vals[i], size = 8)
                #for i in range(len(y_pos_)):
                #    plt.text(x = y_pos_[i] - len(bar_names_label[2*i])/10, y = -7, s = bar_names_label[2*i], size = 8)
                fig = plt.gcf()
                fig.set_size_inches(6,5.2)
                #leg=plt.legend(prop={'size': 7}, markerscale=2.5)
                plt.title("Final number of violations for instances of length "+str(el), fontsize=fz+2)
                plt.ylabel('n. violations',fontsize=fz)
                plt.show()
                """
    rnn_plots = 0
    for ep in ep_lists:
        ep_label = ep[2][7:]
        if(not "LSTM" in ep_label.split('-')):
            continue
        #print(ep_label)
        rnn_plots += 1
        ys_avg = []
        ys_std = []
        ys_hi  = []
        ys_lo  = []
        xs     = []
        k=25
        for i in range(0,len(ep[0]['rnn_score']),50):
            #print([[i,k*np.mean(ep[0]['rnn_score'][max(0,i-5):min(len(ep[0]['rnn_score'])-1,i+5)])]])
            if(i<100):
                continue
            xs     += [i]
            ys_avg += [k*np.mean(ep[0]['rnn_score'][max(0,i-25):min(len(ep[0]['rnn_score'])-1,i+25)])]
            ys_std += [k*np.std(ep[0]['rnn_score'][max(0,i-25):min(len(ep[0]['rnn_score'])-1,i+25)])]
        ys_avg = np.array(ys_avg)
        ys_std = np.array(ys_std)
        xs     = np.array(xs)
        #plt.plot(np.arange(len(ep[0]['rnn_score'])),np.array(ep[0]['rnn_score'])*k, linewidth=0.5,alpha=0.25,color=cmap(ep_label))
        #plt.plot(ys_avg[:,0],ys_avg[:,1], linewidth=1.0,color=cmap(ep_label),alpha=0.5)
        #plt.plot(ys_avg[:,0],ys_avg[:,1], 'o', color=cmap(ep_label),alpha=1.0, markersize=1.0,label=ep_label)
        #plt.errorbar(ys_avg[:,0], ys_avg[:,1], yerr=ys_std, color=cmap(ep_label),alpha=1.0,linewidth=0,elinewidth=1.0,capsize=0)
        plt.legend()
        fz = 7
        
        ys_hi = ys_avg + ys_std
        ys_lo = ys_avg - ys_std
        #print(ys_avg)
        spl_hi = make_interp_spline(xs, ys_hi, k=1)
        
        spl_lo = make_interp_spline(xs, ys_lo, k=1)
        spl    = make_interp_spline(xs, ys_avg,k=1)
        xsmooth = np.linspace(xs.min(), xs.max(), 200) 
        y_smooth_hi = spl_hi(xsmooth)
        y_smooth_lo = spl_lo(xsmooth)
        y_smooth_   = spl(xsmooth)
        
        
        ep_label = ep_label.split('-')
        ep_label.remove('LSTM')
        ep_label_ = copy.deepcopy(ep_label)
        real_label = []
        if('dV' in ep_label_):
            real_label += ['ΔV']
        elif('dV.m' in ep_label_):
            real_label += ['ΔV.m']
        elif('dV.dE' in ep_label_):
            real_label += ['ΔV.ΔE']
        elif('dV.dE(normalized)' in ep_label_):
            real_label += ['ΔV.ΔE.normalized']
        elif('dE.m' in ep_label_):
            real_label += ['ΔE.m']
        elif('dV.dE.m' in ep_label_):
            real_label += ['ΔV.ΔE.m']
        if('A' in ep_label_):
            real_label += ['[A]']
        elif('B' in ep_label_):
            real_label += ['[B]']
        elif('B(normed)' in ep_label_):
            real_label += ['[B.normed]']
        elif('B(noised)' in ep_label_):
            real_label += ['[B.noised]']
        elif('B.2' in ep_label_):
            real_label += ['[B.lessFeatures]']
        elif('C' in ep_label_):
            real_label += ['[C]']
        ep_label_= '-'.join(real_label)
        ep_label = '-'.join(ep_label)
        al =0.95
        if(ep_label in highlighted_labels):
            al=0.95
        
        plt.fill_between(xsmooth,y_smooth_hi,y_smooth_lo,color=cmap(ep_label),alpha=al/5,linewidth=0.0, label=ep_label_)
        #plt.plot(xs,ys_avg, 'o', color=cmap(ep_label),alpha=al, markersize=1.2,label=ep_label)
        plt.plot(xsmooth,y_smooth_, linewidth=1.45,color=cmap(ep_label),alpha=al)
        plt.title('RNN based avg. reward per step',fontsize=fz+2)
        plt.xlabel('n. episodes',fontsize=fz)
        plt.ylabel('RNN based return',fontsize=fz)
        leg=plt.legend(prop={'size': 7}, markerscale=3)
        for lh in leg.legendHandles: 
            lh.set_alpha(1)
        fig = plt.gcf()
        fig.set_size_inches(4.5,3.0)
        plt.xticks(np.arange(0,6001,500), fontsize=fz)
        plt.yticks(np.arange(-90,0,15),fontsize=fz)
        plt.grid(linestyle = '--', linewidth = 0.5)
        ax = plt.gca()
        ax.set_facecolor((1, 1, 1))
        
    if(rnn_plots>0):
        #plt.legend()
        plt.show()
    rnn_plots = 0
    for ep in ep_lists:
        ep_label = ep[2][7:]
        if(not "LSTM" in ep_label.split('-')):
            continue
        #print(ep_label)
        rnn_plots += 1
        ys_avg = []
        ys_std = []
        ys_hi  = []
        ys_lo  = []
        xs     = []
        k=25
        y0 = 0
        for i in range(0,len(ep[0]['cpbp_score']),50):
            #print([[i,k*np.mean(ep[0]['rnn_score'][max(0,i-5):min(len(ep[0]['rnn_score'])-1,i+5)])]])
            if(i<100):
                continue
            xs     += [i]
            if(len(ys_avg)==0):
                y0 = k*np.mean(ep[0]['cpbp_score'][max(0,i-25):min(len(ep[0]['cpbp_score'])-1,i+25)])
            ys_avg += [k*np.mean(ep[0]['cpbp_score'][max(0,i-25):min(len(ep[0]['cpbp_score'])-1,i+25)])-y0]
            ys_std += [k*np.std(ep[0]['cpbp_score'][max(0,i-25):min(len(ep[0]['cpbp_score'])-1,i+25)])]
        ys_avg = np.array(ys_avg)
        ys_std = np.array(ys_std)
        xs     = np.array(xs)
        #plt.plot(np.arange(len(ep[0]['rnn_score'])),np.array(ep[0]['rnn_score'])*k, linewidth=0.5,alpha=0.25,color=cmap(ep_label))
        #plt.plot(ys_avg[:,0],ys_avg[:,1], linewidth=1.0,color=cmap(ep_label),alpha=0.5)
        #plt.plot(ys_avg[:,0],ys_avg[:,1], 'o', color=cmap(ep_label),alpha=1.0, markersize=1.0,label=ep_label)
        #plt.errorbar(ys_avg[:,0], ys_avg[:,1], yerr=ys_std, color=cmap(ep_label),alpha=1.0,linewidth=0,elinewidth=1.0,capsize=0)
        fz = 7
        
        ys_hi = ys_avg + ys_std
        ys_lo = ys_avg - ys_std
        #print(ys_avg)
        spl_hi = make_interp_spline(xs, ys_hi, k=1)
        
        spl_lo = make_interp_spline(xs, ys_lo, k=1)
        spl    = make_interp_spline(xs, ys_avg,k=1)
        xsmooth = np.linspace(xs.min(), xs.max(), 200) 
        y_smooth_hi = spl_hi(xsmooth)
        y_smooth_lo = spl_lo(xsmooth)
        y_smooth_   = spl(xsmooth)
        
        
        ep_label = ep_label.split('-')
        ep_label.remove('LSTM')
        ep_label_ = copy.deepcopy(ep_label)
        real_label = []
        if('dV' in ep_label_):
            real_label += ['ΔV']
        elif('dV.m' in ep_label_):
            real_label += ['ΔV.m']
        elif('dV.dE' in ep_label_):
            real_label += ['ΔV.ΔE']
        elif('dV.dE(normalized)' in ep_label_):
            real_label += ['ΔV.ΔE.normalized']
        elif('dE.m' in ep_label_):
            real_label += ['ΔE.m']
        elif('dV.dE.m' in ep_label_):
            real_label += ['ΔV.ΔE.m']
        if('A' in ep_label_):
            real_label += ['[A]']
        elif('B' in ep_label_):
            real_label += ['[B]']
        elif('B(normed)' in ep_label_):
            real_label += ['[B.normed]']
        elif('B(noised)' in ep_label_):
            real_label += ['[B.noised]']
        elif('C' in ep_label_):
            real_label += ['[C]']
        ep_label_= '-'.join(real_label)
        ep_label = '-'.join(ep_label)
        al =0.95
        if(ep_label in highlighted_labels):
            al=0.95
        
        plt.fill_between(xsmooth,y_smooth_hi,y_smooth_lo,color=cmap(ep_label),alpha=al/5,linewidth=0.0,label=ep_label_)
        #plt.plot(xs,ys_avg, 'o', color=cmap(ep_label),alpha=al, markersize=1.2,label=ep_label)
        plt.plot(xsmooth,y_smooth_, linewidth=1.45,color=cmap(ep_label),alpha=al)
        plt.title('CP based avg. reward per step',fontsize=fz+2)
        plt.xlabel('n. episodes',fontsize=fz)
        plt.ylabel('CP based return',fontsize=fz)
        leg=plt.legend(prop={'size': 7}, markerscale=3)
        for lh in leg.legendHandles: 
            lh.set_alpha(1)
        fig = plt.gcf()
        fig.set_size_inches(4.5,3.0)
        plt.xticks(np.arange(100,6001,500), fontsize=fz)
        plt.yticks(np.arange(0,60,15),fontsize=fz)
        plt.grid(linestyle = '--', linewidth = 0.5)
        ax = plt.gca()
        ax.set_facecolor((1, 1, 1))
    if(rnn_plots>0):
        plt.show()

def get_cmap_string(palette='tab10', domain=()):
    """Build a colour lookup that maps each distinct label in `domain` to its own colour.

    The distinct values of `domain` are sorted (via ``np.unique``), and each
    label is assigned its position in that sorted order. The named matplotlib
    colormap `palette` is resampled to exactly that many lookup-table entries,
    so every label indexes its own slot.

    Args:
        palette: Name of a registered matplotlib colormap (default ``'tab10'``).
        domain: Iterable of hashable labels (e.g. run-label strings).
            Duplicates are ignored.

    Returns:
        A callable ``cmap_out(X, **kwargs)`` that returns the colormap's colour
        for label ``X`` (an RGBA tuple for matplotlib colormaps). Extra keyword
        arguments are forwarded to the colormap call. It raises ``KeyError``
        for a label that was not in `domain`.
    """
    domain_unique = np.unique(domain)
    hash_table = {key: i_str for i_str, key in enumerate(domain_unique)}
    # A lookup table needs at least one entry, even for an empty domain.
    lut_size = max(len(domain_unique), 1)
    try:
        # matplotlib >= 3.6: registry lookup + resample. The old
        # `cm.get_cmap(name, lut)` was deprecated in 3.7 and removed in 3.9.
        mpl_cmap = plt.colormaps[palette].resampled(lut_size)
    except (AttributeError, TypeError):
        # Older matplotlib: `plt.colormaps` is a plain function (TypeError on
        # subscripting), or `Colormap.resampled` does not exist yet.
        mpl_cmap = plt.cm.get_cmap(palette, lut=lut_size)

    def cmap_out(X, **kwargs):
        return mpl_cmap(hash_table[X], **kwargs)
    return cmap_out


#### Print results: run the selected agents and plot their violations

In [None]:
# NOTE: "no-marginals-explorer-2-LSTM" is the new LSTM version.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Agent run names to instantiate and compare in the violation plots.
agent_names = ["dV.dE-NEW", "dV.m-OLD"]

# Hyper-parameters shared by every agent; only 'name' differs between runs.
agent_params = {
    'batch_size': 128,
    'gamma': 0.99,
    'eps_start': 0.99,
    'eps_end': 0.10,
    'eps_decay': 6500,
    'tau': 0.005,
    'lr': 0.001,
    'memory_size': 40000,
    'name': "",
}

# Collect each agent's `run_results`, in the same order as `agent_names`.
results = []
for an in agent_names:
    agent_params['name'] = an
    agent_ = Agent(**agent_params)
    results.append(agent_.run_results)

plot_violations(results, k=1, agent_names=agent_names)
