In [1]:
import numpy as np
import pandas as pd
from chromatography import *
from separation_utility import *
from torch import optim, tensor
import torch.nn as nn
import matplotlib.pyplot as plt
import time

%matplotlib qt

In [2]:
# Load every analyte table into `alists`. The order matters: later cells
# address the datasets by position (alists[0], alists[1:], alists[6], ...).
csv_paths = (
    '../data/GilarSample.csv',
    '../data/Peterpeptides.csv',
    '../data/Roca.csv',
    '../data/Peter32.csv',
    '../data/Eosin.csv',
    '../data/Alizarin.csv',
    '../data/Controlmix2.csv',
)
alists = [pd.read_csv(csv_path) for csv_path in csv_paths]


In [3]:
def loss_field(exp, taus, N = 200):
    """
    Evaluate the experiment loss on an N x N grid of two-step programs.

    Parameters
    ----------
    exp
        Experiment object; it is reset, run with ``run_all([phi1, phi2], taus)``
        and queried with ``loss()`` once per grid point.
    taus
        Passed unchanged to ``exp.run_all`` for every grid point.
    N : int
        Number of equally spaced phi values in [0, 1] along each axis.

    Returns
    -------
    (X, Y, losses)
        ``X`` and ``Y`` come from ``np.meshgrid(phis, phis)``; ``losses[r, c]``
        is the loss for ``phi1 = X[r, c]`` and ``phi2 = Y[r, c]``.
    """
    phis = np.linspace(0, 1, N)
    losses = np.zeros((N, N))
    # phi1 varies along columns and phi2 along rows, matching meshgrid's layout.
    for col, phi1 in enumerate(phis):
        for row, phi2 in enumerate(phis):
            exp.reset()
            exp.run_all([phi1, phi2], taus)
            losses[row, col] = exp.loss()

    X, Y = np.meshgrid(phis, phis)
    return X, Y, losses

def average_over_equal_intervals(arr, interval):
    """
    Average a 1-D sequence over consecutive, non-overlapping blocks.

    Parameters
    ----------
    arr : array-like
        Values to average. Lists and tuples are accepted as well as
        numpy arrays.
    interval : int
        Number of consecutive elements per block. The total number of
        elements must be a multiple of ``interval``.

    Returns
    -------
    np.ndarray
        One mean per block, in order.

    Raises
    ------
    ValueError
        If the data cannot be split into blocks of ``interval`` elements
        (raised by the reshape).
    """
    # np.asarray lets plain Python sequences through; ndarrays are not copied.
    values = np.asarray(arr)
    return np.mean(values.reshape(-1, interval), axis=1)

In [None]:
# Gilar hold-out experiment.
# Repeats the following 100 times with a freshly initialised policy:
#   1. train a general policy with reinforce_gen on alists[1:] (every dataset
#      except alists[0]),
#   2. evaluate the policy's mean program on alists[0] -> losses_gen,
#   3. fine-tune the same policy on alists[0] with reinforce_single_from_gen
#      and evaluate the last row of the returned `mus` -> losses_from_gen.
# With the load order of the data cell above, alists[0] is GilarSample.
taus = [.25, .25, 10]
losses_gen = []
losses_from_gen = []
time_start = time.perf_counter()
# One experiment object for the held-out set; it is reset before every evaluation.
exp = ExperimentAnalytes(k0 = alists[0].k0.values, S = alists[0].S.values, h=0.001,run_time=10.0)
for i in range(100):
    # New policy each repetition so the repetitions are independent runs.
    pol = PolicyGeneral(
                phi = nn.Sequential(
                    PermEqui2_max(2, 5),
                    nn.ELU(inplace=True),
                    PermEqui2_max(5, 5),
                    nn.ELU(inplace=True),
                    PermEqui2_max(5, 5),
                    nn.ELU(inplace=True),
                ),
                rho = nn.Sequential(
                    nn.Linear(5, 5),
                    nn.ELU(inplace=True),
                    nn.Linear(5, 5),
                    nn.ELU(inplace=True),
                    Rho(n_steps=len(taus), hidden=5, in_dim=5, sigma_max=.3, sigma_min=.01),
                )
            )
    # print_every equals num_episodes, so only the final training loss is printed.
    reinforce_gen(
        alists = alists[1:], 
        policy = pol, 
        delta_taus = taus, 
        num_episodes = 20_000, 
        sample_size = 10,
        batch_size = 1, 
        lr = .05, 
        optim = lambda a, b: torch.optim.SGD(a, b),
        lr_decay_factor = 0.75,
        lr_milestones = 2500,
        print_every = 20_000,
        baseline = .55,
        max_norm = 1.5,
        max_rand_analytes = 30,
        min_rand_analytes = 15,
        rand_prob = 0.7
    )
    # Loss of the general policy's mean program on the held-out set.
    mu, _ = pol(torch.Tensor(alists[0][['S', 'lnk0']].values))
    exp.reset()
    exp.run_all(mu.detach().numpy(), taus)
    losses_gen.append(exp.loss())
    # Fine-tune the same policy object on the held-out set only.
    _, _, mus, _ = reinforce_single_from_gen(
        alist=alists[0], 
        policy=pol, 
        delta_taus=taus, 
        num_episodes=2000, 
        sample_size=10, 
        lr=.05, 
        optim=torch.optim.SGD,
        lr_decay_factor=.5,
        lr_milestones=1000,
        print_every=100,
        baseline=0.55,
        max_norm=2.
    )
    # mus[-1, :] is presumably the mean program of the last fine-tuning episode -- TODO confirm.
    exp.reset()
    exp.run_all(mus[-1,:], taus)
    losses_from_gen.append(exp.loss())
    
time_end = time.perf_counter()
# NOTE(review): the elapsed time (time_end - time_start) is never computed in this cell.
# Recorded timing from a previous run (presumably average seconds per repetition): 507.8653401895899

  return delta_tau_phi * (1 + self.k(phi)) / self.k(phi)


Loss: 0.4030468970902817, epoch: 20000/20000
Loss: 0.5169691473367679, epoch: 100/2000
Loss: 0.507738990448791, epoch: 200/2000
Loss: 0.5007827904090174, epoch: 300/2000
Loss: 0.5240758422062577, epoch: 400/2000
Loss: 0.4994911578994783, epoch: 500/2000
Loss: 0.5017388527503364, epoch: 600/2000
Loss: 0.2603477173950704, epoch: 700/2000
Loss: 0.3098511163328784, epoch: 800/2000
Loss: 0.29809759696839844, epoch: 900/2000
Loss: 0.22547046457459122, epoch: 1000/2000
Loss: 0.23987654567750138, epoch: 1100/2000
Loss: 0.26525608822673785, epoch: 1200/2000
Loss: 0.22628237415482072, epoch: 1300/2000
Loss: 0.2445938398011717, epoch: 1400/2000
Loss: 0.23825156886402796, epoch: 1500/2000
Loss: 0.2234275821037061, epoch: 1600/2000
Loss: 0.20286648643034066, epoch: 1700/2000
Loss: 0.20623322537282252, epoch: 1800/2000
Loss: 0.30389252627392466, epoch: 1900/2000
Loss: 0.1961650241528139, epoch: 2000/2000
Loss: 1.7813370965294202, epoch: 20000/20000
Loss: 0.288586346959544, epoch: 100/2000
Loss: 0.25

Loss: 0.18809147810397878, epoch: 100/2000
Loss: 0.18225412583476147, epoch: 200/2000
Loss: 0.19244668605813395, epoch: 300/2000
Loss: 0.18093442836283305, epoch: 400/2000
Loss: 0.15754673227879185, epoch: 500/2000
Loss: 0.17805622083646008, epoch: 600/2000
Loss: 0.156967013571856, epoch: 700/2000
Loss: 0.14041596772103082, epoch: 800/2000
Loss: 0.12691275409753142, epoch: 900/2000
Loss: 0.19456009340069982, epoch: 1000/2000
Loss: 0.126934404230362, epoch: 1100/2000
Loss: 0.12751456716969067, epoch: 1200/2000
Loss: 0.13102280006370357, epoch: 1300/2000
Loss: 0.11620051967623435, epoch: 1400/2000
Loss: 0.11356217611979325, epoch: 1500/2000
Loss: 0.12704196817942676, epoch: 1600/2000
Loss: 0.16366109652157862, epoch: 1700/2000
Loss: 0.12849349746878663, epoch: 1800/2000
Loss: 0.12754505645282485, epoch: 1900/2000
Loss: 0.18540940688774632, epoch: 2000/2000
Loss: 0.4150051381934273, epoch: 20000/20000
Loss: 0.5051699138481506, epoch: 100/2000
Loss: 0.25259014681313213, epoch: 200/2000
Los

Loss: 0.18829540031749184, epoch: 200/2000
Loss: 0.18414021104412406, epoch: 300/2000
Loss: 0.17311253989106634, epoch: 400/2000
Loss: 0.15522537088228322, epoch: 500/2000
Loss: 0.17145762244396662, epoch: 600/2000
Loss: 0.16594431430788428, epoch: 700/2000
Loss: 0.17285655851571735, epoch: 800/2000
Loss: 0.19291842596279274, epoch: 900/2000
Loss: 0.19746557381503785, epoch: 1000/2000
Loss: 0.14997475442208602, epoch: 1100/2000
Loss: 0.14929254550055768, epoch: 1200/2000
Loss: 0.13758844227646108, epoch: 1300/2000
Loss: 0.13182675300648725, epoch: 1400/2000
Loss: 0.14408632133528684, epoch: 1500/2000
Loss: 0.1521652308653678, epoch: 1600/2000
Loss: 0.131682654114789, epoch: 1700/2000
Loss: 0.13578934594218375, epoch: 1800/2000
Loss: 0.14262675081288365, epoch: 1900/2000
Loss: 0.13780034052910992, epoch: 2000/2000
Loss: 1.8733367619303074, epoch: 20000/20000
Loss: 0.34107503314471754, epoch: 100/2000
Loss: 0.3063082399517322, epoch: 200/2000
Loss: 0.26534403936884005, epoch: 300/2000
Lo

  return delta_tau_phi * (1 + self.k(phi)) / self.k(phi)


Loss: 0.496698711957057, epoch: 20000/20000
Loss: 0.2675341176284668, epoch: 100/2000
Loss: 0.19999011754705998, epoch: 200/2000
Loss: 0.19460545572514718, epoch: 300/2000
Loss: 0.2176222391960152, epoch: 400/2000
Loss: 0.21267556331027043, epoch: 500/2000
Loss: 0.2661110043998379, epoch: 600/2000
Loss: 0.2950393543370105, epoch: 700/2000
Loss: 0.4620367072779089, epoch: 800/2000
Loss: 0.35648190844289235, epoch: 900/2000
Loss: 0.3846020421826666, epoch: 1000/2000
Loss: 0.22579677950998972, epoch: 1100/2000
Loss: 0.19030089200149344, epoch: 1200/2000
Loss: 0.18601993417882956, epoch: 1300/2000
Loss: 0.2312089769252989, epoch: 1400/2000
Loss: 0.21168497998862615, epoch: 1500/2000
Loss: 0.19539468638629598, epoch: 1600/2000
Loss: 0.21172795728427113, epoch: 1700/2000
Loss: 0.22022611565777583, epoch: 1800/2000
Loss: 0.2500915171401647, epoch: 1900/2000
Loss: 0.18779753356665585, epoch: 2000/2000
Loss: 0.9338551220050105, epoch: 20000/20000
Loss: 0.5126584653884111, epoch: 100/2000
Loss: 

Loss: 0.20520034048422203, epoch: 100/2000
Loss: 0.18217505862065636, epoch: 200/2000
Loss: 0.23541919389304597, epoch: 300/2000
Loss: 0.2501897251601691, epoch: 400/2000
Loss: 0.27410779387520595, epoch: 500/2000
Loss: 0.2696065721153847, epoch: 600/2000
Loss: 0.2742242427114591, epoch: 700/2000
Loss: 0.19697203549464984, epoch: 800/2000
Loss: 0.30700348981869385, epoch: 900/2000
Loss: 0.29952146595573437, epoch: 1000/2000
Loss: 0.26506182527764205, epoch: 1100/2000
Loss: 0.28713328192390264, epoch: 1200/2000
Loss: 0.29027370910047007, epoch: 1300/2000
Loss: 0.15928754524238425, epoch: 1400/2000
Loss: 0.1631191896428231, epoch: 1500/2000
Loss: 0.15566159581503672, epoch: 1600/2000
Loss: 0.18418223190753588, epoch: 1700/2000
Loss: 0.1405815962972026, epoch: 1800/2000
Loss: 0.14824155987643783, epoch: 1900/2000
Loss: 0.1587848985573554, epoch: 2000/2000
Loss: 0.5096229896221696, epoch: 20000/20000
Loss: 0.22345246510001943, epoch: 100/2000
Loss: 0.16017754543774712, epoch: 200/2000
Loss

Loss: 0.23522102574054463, epoch: 200/2000
Loss: 0.24665250385335297, epoch: 300/2000
Loss: 0.19849107980915226, epoch: 400/2000
Loss: 0.23920353477494544, epoch: 500/2000
Loss: 0.203590065983904, epoch: 600/2000
Loss: 0.2773278021970444, epoch: 700/2000
Loss: 0.32617663056765595, epoch: 800/2000
Loss: 0.2632875742899804, epoch: 900/2000
Loss: 0.2182682278562731, epoch: 1000/2000
Loss: 0.2252726635849187, epoch: 1100/2000
Loss: 0.24862327285278943, epoch: 1200/2000
Loss: 0.24254760437749195, epoch: 1300/2000
Loss: 0.3062679594504557, epoch: 1400/2000
Loss: 0.2495810280791987, epoch: 1500/2000
Loss: 0.21644062059480457, epoch: 1600/2000
Loss: 0.2245730149128345, epoch: 1700/2000
Loss: 0.22122190058241933, epoch: 1800/2000
Loss: 0.21074829901175982, epoch: 1900/2000
Loss: 0.2248999558862403, epoch: 2000/2000
Loss: 0.6537083973260203, epoch: 20000/20000
Loss: 0.3488064782790869, epoch: 100/2000
Loss: 0.3099653602421818, epoch: 200/2000
Loss: 0.20816161070418854, epoch: 300/2000
Loss: 0.18

Loss: 0.13156992950014987, epoch: 300/2000
Loss: 0.14092098938505132, epoch: 400/2000
Loss: 0.1919726931801478, epoch: 500/2000
Loss: 0.23054112204555594, epoch: 600/2000
Loss: 0.16626174152363796, epoch: 700/2000
Loss: 0.15137437751174576, epoch: 800/2000
Loss: 0.18440768630914686, epoch: 900/2000
Loss: 0.12740980422048998, epoch: 1000/2000
Loss: 0.10167242664286855, epoch: 1100/2000
Loss: 0.1504731668041996, epoch: 1200/2000
Loss: 0.11921892539710377, epoch: 1300/2000
Loss: 0.14774978416057247, epoch: 1400/2000
Loss: 0.1311694460788037, epoch: 1500/2000
Loss: 0.1592828516462128, epoch: 1600/2000
Loss: 0.16440131547503345, epoch: 1700/2000
Loss: 0.14174810367030308, epoch: 1800/2000
Loss: 0.1432426969742558, epoch: 1900/2000
Loss: 0.13266068392934846, epoch: 2000/2000
Loss: 0.43874691730365784, epoch: 20000/20000
Loss: 0.14416931882778922, epoch: 100/2000
Loss: 0.1714526604961442, epoch: 200/2000
Loss: 0.17704330772670876, epoch: 300/2000
Loss: 0.13875837789212958, epoch: 400/2000
Los

Loss: 0.18282700623131415, epoch: 300/2000
Loss: 0.13029484862409066, epoch: 400/2000
Loss: 0.2074860267661367, epoch: 500/2000
Loss: 0.1583964038956625, epoch: 600/2000
Loss: 0.15143726127970297, epoch: 700/2000
Loss: 0.17415325166668713, epoch: 800/2000
Loss: 0.15491667523065217, epoch: 900/2000
Loss: 0.2397127616838683, epoch: 1000/2000
Loss: 0.18349260707858342, epoch: 1100/2000
Loss: 0.13842688348593307, epoch: 1200/2000
Loss: 0.12179281309512802, epoch: 1300/2000
Loss: 0.22252087830283376, epoch: 1400/2000
Loss: 0.15840019854741624, epoch: 1500/2000
Loss: 0.19512944494289514, epoch: 1600/2000
Loss: 0.2263850815880967, epoch: 1700/2000
Loss: 0.24080184640636912, epoch: 1800/2000
Loss: 0.2547519572926033, epoch: 1900/2000
Loss: 0.24617485107390946, epoch: 2000/2000
Loss: 0.569923524436482, epoch: 20000/20000
Loss: 0.16964159001317491, epoch: 100/2000
Loss: 0.2210728292082325, epoch: 200/2000
Loss: 0.18969343333793426, epoch: 300/2000
Loss: 0.20303969134194708, epoch: 400/2000
Loss:

Loss: 0.27849550983112326, epoch: 400/2000
Loss: 0.24647075974194896, epoch: 500/2000
Loss: 0.31234966967299127, epoch: 600/2000
Loss: 0.22315613858155828, epoch: 700/2000
Loss: 0.2331520236686157, epoch: 800/2000
Loss: 0.23378536907173858, epoch: 900/2000
Loss: 0.36700929157978357, epoch: 1000/2000
Loss: 0.27872614377494537, epoch: 1100/2000
Loss: 0.24323632851001253, epoch: 1200/2000
Loss: 0.21028316447009487, epoch: 1300/2000
Loss: 0.2598333128761773, epoch: 1400/2000
Loss: 0.21330964507477068, epoch: 1500/2000
Loss: 0.22404994224189506, epoch: 1600/2000
Loss: 0.23039462719667814, epoch: 1700/2000
Loss: 0.22756432874640747, epoch: 1800/2000
Loss: 0.22514095868259854, epoch: 1900/2000
Loss: 0.20367754249109865, epoch: 2000/2000
Loss: 1.8608081126178875, epoch: 20000/20000
Loss: 0.2485665749718325, epoch: 100/2000
Loss: 0.20865985781770452, epoch: 200/2000
Loss: 0.15959774820551167, epoch: 300/2000
Loss: 0.14791509858410917, epoch: 400/2000
Loss: 0.1914439482838449, epoch: 500/2000
Lo

Loss: 0.2157185688824182, epoch: 400/2000
Loss: 0.2506934189398796, epoch: 500/2000
Loss: 0.21586786537679262, epoch: 600/2000
Loss: 0.15935867732101375, epoch: 700/2000
Loss: 0.18190957682380948, epoch: 800/2000
Loss: 0.2146523096829689, epoch: 900/2000
Loss: 0.1879617823044434, epoch: 1000/2000
Loss: 0.19168829504855966, epoch: 1100/2000
Loss: 0.17865714313170042, epoch: 1200/2000
Loss: 0.1986868959906313, epoch: 1300/2000
Loss: 0.17737015705574494, epoch: 1400/2000
Loss: 0.17048136830780597, epoch: 1500/2000
Loss: 0.18701670530800604, epoch: 1600/2000
Loss: 0.14495544462517038, epoch: 1700/2000
Loss: 0.1794288943920821, epoch: 1800/2000
Loss: 0.2155886181584458, epoch: 1900/2000
Loss: 0.21818583691833143, epoch: 2000/2000
Loss: 1.654492653557402, epoch: 20000/20000
Loss: 0.30726548392711067, epoch: 100/2000
Loss: 0.28276857624580953, epoch: 200/2000
Loss: 0.47410277755413555, epoch: 300/2000
Loss: 0.2309204041713032, epoch: 400/2000
Loss: 0.19662280012608246, epoch: 500/2000
Loss: 0

Loss: 0.15677818622375622, epoch: 500/2000
Loss: 0.23372419644959858, epoch: 600/2000
Loss: 0.1732529906440526, epoch: 700/2000
Loss: 0.22959982724624464, epoch: 800/2000
Loss: 0.18974646404429985, epoch: 900/2000
Loss: 0.13927186566558333, epoch: 1000/2000
Loss: 0.25290386545468124, epoch: 1100/2000
Loss: 0.29018674973493275, epoch: 1200/2000
Loss: 0.2949228368577172, epoch: 1300/2000
Loss: 0.22400831496983864, epoch: 1400/2000
Loss: 0.12582101400179915, epoch: 1500/2000
Loss: 0.22439460677475664, epoch: 1600/2000
Loss: 0.1468320715064812, epoch: 1700/2000
Loss: 0.18098437592858055, epoch: 1800/2000
Loss: 0.10570261509520706, epoch: 1900/2000
Loss: 0.15756055586164802, epoch: 2000/2000
Loss: 0.8618364331290268, epoch: 20000/20000
Loss: 0.2579929000538169, epoch: 100/2000
Loss: 0.2325393980472923, epoch: 200/2000
Loss: 0.23508794678342979, epoch: 300/2000
Loss: 0.30125340467015826, epoch: 400/2000
Loss: 0.1405731657462596, epoch: 500/2000
Loss: 0.1691563155804413, epoch: 600/2000
Loss:

In [None]:
# Peterpeptides hold-out experiment (same protocol as the Gilar cell above,
# results stored in the *_pp variables).
# NOTE(review): this cell rebinds the global `alists` with Peterpeptides first
# and GilarSample second. Any cell executed afterwards that indexes alists[0]
# or alists[1] sees this new order.
alists = []
alists.append(pd.read_csv(f'../data/Peterpeptides.csv'))
alists.append(pd.read_csv(f'../data/GilarSample.csv'))

alists.append(pd.read_csv(f'../data/Roca.csv'))
alists.append(pd.read_csv(f'../data/Peter32.csv'))
alists.append(pd.read_csv(f'../data/Eosin.csv'))
alists.append(pd.read_csv(f'../data/Alizarin.csv'))
alists.append(pd.read_csv(f'../data/Controlmix2.csv'))

taus = [.25, .25, 10]
losses_gen_pp = []
losses_from_gen_pp = []
time_start_pp = time.perf_counter()
# One experiment object for the held-out set (alists[0]); reset before every evaluation.
exp = ExperimentAnalytes(k0 = alists[0].k0.values, S = alists[0].S.values, h=0.001,run_time=10.0)
for i in range(100):
    # New policy each repetition so the repetitions are independent runs.
    pol = PolicyGeneral(
                phi = nn.Sequential(
                    PermEqui2_max(2, 5),
                    nn.ELU(inplace=True),
                    PermEqui2_max(5, 5),
                    nn.ELU(inplace=True),
                    PermEqui2_max(5, 5),
                    nn.ELU(inplace=True),
                ),
                rho = nn.Sequential(
                    nn.Linear(5, 5),
                    nn.ELU(inplace=True),
                    nn.Linear(5, 5),
                    nn.ELU(inplace=True),
                    Rho(n_steps=len(taus), hidden=5, in_dim=5, sigma_max=.3, sigma_min=.01),
                )
            )
    # Train on every dataset except the held-out one; only the final loss is printed.
    reinforce_gen(
        alists = alists[1:], 
        policy = pol, 
        delta_taus = taus, 
        num_episodes = 20_000, 
        sample_size = 10,
        batch_size = 1, 
        lr = .05, 
        optim = lambda a, b: torch.optim.SGD(a, b),
        lr_decay_factor = 0.75,
        lr_milestones = 2500,
        print_every = 20_000,
        baseline = .55,
        max_norm = 1.5,
        max_rand_analytes = 30,
        min_rand_analytes = 15,
        rand_prob = 0.7
    )
    # Loss of the general policy's mean program on the held-out set.
    mu, _ = pol(torch.Tensor(alists[0][['S', 'lnk0']].values))
    exp.reset()
    exp.run_all(mu.detach().numpy(), taus)
    losses_gen_pp.append(exp.loss())
    # Fine-tune the same policy object on the held-out set only.
    _, _, mus, _ = reinforce_single_from_gen(
        alist=alists[0], 
        policy=pol, 
        delta_taus=taus, 
        num_episodes=2000, 
        sample_size=10, 
        lr=.05, 
        optim=torch.optim.SGD,
        lr_decay_factor=.5,
        lr_milestones=1000,
        print_every=100,
        baseline=0.55,
        max_norm=2.
    )
    # mus[-1, :] is presumably the mean program of the last fine-tuning episode -- TODO confirm.
    exp.reset()
    exp.run_all(mus[-1,:], taus)
    losses_from_gen_pp.append(exp.loss())
    
time_end_pp = time.perf_counter()
# Recorded result of (time_end_pp - time_start_pp)/100 from a previous run: 523.31177600672

In [19]:
# Average wall-clock seconds per repetition of the 100-repetition Peterpeptides loop.
(time_end_pp - time_start_pp)/100

523.31177600672

In [16]:
# Start a fresh figure: with the `%matplotlib qt` backend the current figure
# stays alive between cells, so without this call the histogram is drawn on
# top of whatever was plotted last.
plt.figure()
# Distribution of the general policy's held-out loss over the repetitions.
plt.hist(losses_gen_pp, bins=50)
plt.title("Final Result Distribution for GenModel (PeterPeptides)")
plt.xlabel("Loss")
plt.ylabel("Occurrences")

Text(0, 0.5, 'Occurrences')

In [17]:
# Start a fresh figure: with the `%matplotlib qt` backend the current figure
# stays alive between cells, so without this call this histogram is overlaid
# on the previous one and only the last title survives.
plt.figure()
# Distribution of the held-out loss after fine-tuning, over the repetitions.
plt.hist(losses_from_gen_pp, bins=50)
plt.title("Final Result Distribution for GenModel + FineTune(PeterPeptides)")
plt.xlabel("Loss")
plt.ylabel("Occurrences")

Text(0, 0.5, 'Occurrences')

In [None]:
#Gillar
taus = [.25, .25, 10]
pol = PolicyGeneral(
            phi = nn.Sequential(
                PermEqui2_max(2, 5),
                nn.ELU(inplace=True),
                PermEqui2_max(5, 5),
                nn.ELU(inplace=True),
                PermEqui2_max(5, 5),
                nn.ELU(inplace=True),
            ),
            rho = nn.Sequential(
                nn.Linear(5, 5),
                nn.ELU(inplace=True),
                nn.Linear(5, 5),
                nn.ELU(inplace=True),
                Rho(n_steps=len(taus), hidden=5, in_dim=5, sigma_max=.3, sigma_min=.01),
            )
        )
reinforce_gen(
    alists = alists, 
    policy = pol, 
    delta_taus = taus, 
    num_episodes = 40_000, 
    sample_size = 10,
    batch_size = 1, 
    lr = .05, 
    optim = lambda a, b: torch.optim.SGD(a, b),
    lr_decay_factor = 0.75,
    lr_milestones = 5000,
    print_every = 20_000,
    baseline = .55,
    max_norm = 1.5,
    max_rand_analytes = 30,
    min_rand_analytes = 15,
    rand_prob = 0.7
)

In [None]:
# NOTE(review): `L8` is not assigned in any cell visible in this notebook; it
# has to exist in the kernel already (presumably a list of per-run loss arrays
# from an earlier training run -- confirm) or this cell raises NameError.
# Start a fresh figure so the curve is not drawn on top of the previous plot
# (the `%matplotlib qt` backend keeps the current figure alive between cells).
plt.figure()
# Mean over runs (axis 0), then averaged over consecutive blocks of 500 episodes.
plt.plot(average_over_equal_intervals(np.array(L8).mean(0), 500), label="L8")
plt.title("Loss (average of 500 random sets)")
plt.xlabel("Episode")
plt.ylabel("Loss")
plt.legend()
plt.show()

In [None]:
# Run the general policy's mean program on dataset `i` and show the result.
i = 0
exp = ExperimentAnalytes(
    k0=alists[i].k0.values,
    S=alists[i].S.values,
    h=0.001,
    run_time=10.0,
)
mu, sig = pol(torch.Tensor(alists[i][['S', 'lnk0']].values))
exp.run_all(mu.detach().numpy(), taus)

# The loss is read after the run so the title reports this program's loss.
plot_title = f"Solvent Strength Program(Gen)\nLoss:{round(exp.loss(), 4)}"
exp.print_analytes(title=plot_title, rc=(10,10), angle=40)

In [None]:
# Fine-tune the general policy `pol` on the first dataset only.
finetune_kwargs = dict(
    alist=alists[0],
    policy=pol,
    delta_taus=taus,
    num_episodes=2000,
    sample_size=10,
    lr=.05,
    optim=torch.optim.SGD,
    lr_decay_factor=.5,
    lr_milestones=1000,
    print_every=100,
    baseline=0.55,
    max_norm=2.,
)
# L, B, mus and sigmas are used by the plotting / evaluation cells below.
L, B, mus, sigmas = reinforce_single_from_gen(**finetune_kwargs)

In [None]:
# Loss of the program that `exp` ran last. NOTE(review): `exp` is whatever an
# earlier cell left in the kernel, so the value depends on execution order.
exp.loss()

In [None]:
# Run the last mean program from fine-tuning (`mus[-1, :]`) on dataset `i`.
i = 0
exp = ExperimentAnalytes(
    k0=alists[i].k0.values,
    S=alists[i].S.values,
    h=0.001,
    run_time=10.0,
)
exp.run_all(mus[-1,:], taus)

plot_title = f"Solvent Strength Program(Iso)\nLoss:{round(exp.loss(), 4)}"
exp.print_analytes(title=plot_title, rc=(10,10), angle=40)

In [None]:
# Run the program `B` returned by fine-tuning on dataset `i`.
# `i` is not set here; it is taken from whichever earlier cell assigned it.
exp = ExperimentAnalytes(
    k0=alists[i].k0.values,
    S=alists[i].S.values,
    h=0.001,
    run_time=10.0,
)
exp.run_all(B, taus)

plot_title = f"Best Solvent Strength Program\nLoss:{round(exp.loss(), 4)}"
exp.print_analytes(title=plot_title, rc=(10,10), angle=40)

In [None]:
# Open a new figure first: the `%matplotlib qt` backend keeps the current
# figure alive between cells, so a bare plt.plot would draw onto the last plot.
plt.figure()
# `mus` is the third value returned by the fine-tuning call.
plt.plot(mus)

In [None]:
# Open a new figure first: the `%matplotlib qt` backend keeps the current
# figure alive between cells, so a bare plt.plot would draw onto the last plot.
plt.figure()
# `sigmas` is the fourth value returned by the fine-tuning call.
plt.plot(sigmas)

In [None]:
# Train a general policy whose head is a RhoTime module (variable delta tau).
# Encoder: permutation-equivariant layers producing a 3-dimensional embedding.
set_encoder = nn.Sequential(
    PermEqui2_max(2, 3),
    nn.ELU(inplace=True),
    PermEqui2_max(3, 3),
    nn.ELU(inplace=True),
    PermEqui2_max(3, 3),
    nn.ELU(inplace=True),
)
time_head = RhoTime(n_steps=3, hidden=5, in_dim=3, sigma_max=.3, sigma_min=.02)
pol = PolicyGeneral(phi=set_encoder, rho=time_head)

losses = reinforce_delta_tau_gen(
    # data and model
    alists=alists,
    policy=pol,
    # sampling of random analyte subsets
    rand_prob=0.7,
    min_rand_analytes=18,
    max_rand_analytes=35,
    # optimisation
    num_episodes=10000,
    batch_size=10,
    lr=.1,
    optim=lambda a, b: torch.optim.SGD(a, b),
    lr_decay_factor=0.75,
    lr_milestones=1000,
    baseline=.55,
    max_norm=1.2,
    # logging
    print_every=100,
)

In [None]:
# Open a new figure so the curve is not drawn onto the previous plot (the
# `%matplotlib qt` backend keeps the current figure alive between cells).
plt.figure()
# Average the loss trace over consecutive blocks of 1000 entries.
averaged_losses = average_over_equal_intervals(losses[0], 1000)
# Derive the x-axis from the data instead of hard-coding 200000 / 200, so the
# cell keeps working when the training length changes. For a 200000-entry
# trace this yields exactly the former np.linspace(0, 200000, 200).
episodes = np.linspace(0, len(losses[0]), len(averaged_losses))
plt.plot(episodes, averaged_losses)
plt.title("Loss [variable delta tau] (average of 1000 random sets)")
plt.xlabel("Episode")
plt.ylabel("Loss")
plt.show()

In [None]:
# Run the delta-tau policy's mean output on dataset `i`.
i = 6
exp = ExperimentAnalytes(
    k0=alists[i].k0.values,
    S=alists[i].S.values,
    h=0.001,
    run_time=10.0,
)
mu, _ = pol(torch.Tensor(alists[i][['S', 'lnk0']].values))
# The first three entries are passed as the phi values and the rest as the
# delta taus; 10. is appended as the final delta tau.
mu = mu.tolist() + [10.]
exp.run_all(mu[:3], mu[3:])

plot_title = f"Solvent Strength Program\nLoss:{round(exp.loss(), 4)}"
exp.print_analytes(title=plot_title, rc=(10,10), angle=40)

In [None]:
# Fine-tune the delta-tau policy on dataset `i` (set by an earlier cell).
finetune_kwargs = dict(
    alist=alists[i],
    policy=pol,
    num_episodes=5000,
    batch_size=10,
    lr=.1,
    optim=torch.optim.SGD,
    lr_decay_factor=.5,
    lr_milestones=500,
    print_every=500,
    baseline=0.65,
    max_norm=1.2,
)
# L, B, mus and sigmas are used by the plotting / evaluation cells below.
L, B, mus, sigmas = reinforce_single_from_delta_tau_gen(**finetune_kwargs)

In [None]:
# Open a new figure first: the `%matplotlib qt` backend keeps the current
# figure alive between cells, so a bare plt.plot would draw onto the last plot.
plt.figure()
# `mus` is the third value returned by the delta-tau fine-tuning call.
plt.plot(mus)

In [None]:
# Open a new figure first: otherwise, under `%matplotlib qt`, this series is
# drawn on the same axes as the `mus` plot from the previous cell.
plt.figure()
# `L` is the first value returned by the delta-tau fine-tuning call.
plt.plot(L)

In [None]:
# Run the last fine-tuned mean program (`mus[-1]`) on dataset `i`.
exp = ExperimentAnalytes(
    k0=alists[i].k0.values,
    S=alists[i].S.values,
    h=0.001,
    run_time=10.0,
)
# First three entries are the phi values, the rest the delta taus; 10. is
# appended as the final delta tau.
mu = mus[-1].tolist() + [10.]
exp.run_all(mu[:3], mu[3:])

plot_title = f"Solvent Strength Program\nLoss:{round(exp.loss(), 4)}"
exp.print_analytes(title=plot_title, rc=(10,10), angle=40)

# NEW

In [None]:
class PolicyGeneralISO(nn.Module):
    """
    General policy that is additionally conditioned on an extra feature row.

    The analyte set is encoded by ``phi``, sum-pooled over the analytes, and
    the extra features ``y`` are appended to the pooled embedding before it is
    handed to ``rho``.
    """

    def __init__(self, 
            phi: nn.Module,
            rho: nn.Module
        ) -> None:
        """
        Constructor for the PolicyGeneralISO torch Module.

        Parameters
        ----------
        phi: nn.Module
            The network that encodes the analytes; its output is summed over
            dimension 0 to obtain one embedding row for the whole set.
        rho: nn.Module
            The network that outputs the program for separation. It receives
            the pooled embedding concatenated with ``y`` and must return the
            mean and standard deviation of the action space. Its input width
            therefore has to be the embedding width plus the width of ``y``.

        Ex:
        A 4 step solvent gradient program with a 3 element embedding and a
        single extra feature (rho input width 3 + 1 = 4):
        policy = PolicyGeneralISO(
            phi = nn.Sequential(
                PermEqui2_max(2, 5),
                nn.ELU(inplace=True),
                PermEqui2_max(5, 5),
                nn.ELU(inplace=True),
                PermEqui2_max(5, 3),
                nn.ELU(inplace=True),
            ),
            rho = Rho(n_steps=4, hidden=5, in_dim=4, sigma_max=.3, sigma_min=.05)
        )
        """
        super().__init__()

        # phi is registered before rho; keep this order so parameters()
        # enumerates the encoder first.
        self.phi = phi
        self.rho = rho

    def forward(self, x, y):
        """
        Compute the action distribution parameters.

        Parameters
        ----------
        x
            Input for ``phi``; the result is summed over dimension 0.
        y
            Tensor concatenated to the pooled embedding along dimension 1,
            so it needs exactly one row.

        Returns
        -------
        (mu, sigma)
            The two values returned by ``rho``.
        """
        pooled = self.phi(x).sum(0, keepdim=True)
        rho_input = torch.cat((pooled, y), dim=1)
        mu, sigma = self.rho(rho_input)
        return mu, sigma

#############################################################################
def reinforce_gen_iso(
        alists: Iterable[pd.DataFrame],
        policy: PolicyGeneral, 
        delta_taus: Iterable[float], 
        num_episodes: int = 1000, 
        sample_size: int = 10,
        batch_size : int = 10,
        lr: float = 1., 
        optim = torch.optim.SGD,
        lr_decay_factor: float = 1.,
        lr_milestones: Union[int, Iterable[int]] = 1000,
        rand_prob: float = .2,
        max_rand_analytes: int = 30,
        min_rand_analytes: int = 10,
        print_every: int = 100,
        baseline: float = 0.,
        max_norm: float = None,
        beta: float = .0,
        weights: list = [1., 1.]
    ):
    """
    Train a generalized policy with REINFORCE, conditioning it on the best
    isocratic solvent strength of each analyte set.

    In every episode an analyte set is chosen, a grid search over 100 phi
    values in [0, 1] finds the single-step program with the lowest loss, and
    that phi is given to the policy together with the analyte features.

    alists: Iterable[pd.DataFrame]
        A list with pd.DataFrames for each dataset used to train on. It is
        indexed and measured with len(), so it must be a sequence. Each frame
        needs the columns 'k0', 'S' and 'lnk0'.
    policy: PolicyGeneral
        The policy that learns the optimal values for the solvent strength
        program. It is called as policy.forward(features, phi_iso), i.e. it
        must accept the extra isocratic input (as PolicyGeneralISO does).
    delta_taus: Iterable[float]
        Iterable list with the points of solvent strength change.
        MUST be the same length as policy.n_steps
    num_episodes = 1000
        Number of episodes (one analyte set per episode).
    sample_size = 10
        Number of programs sampled from the action distribution per episode
        to estimate the expected loss.
    batch_size = 10
        Number of episodes whose objective is averaged before one optimizer
        step. Episodes left over after the last full batch are not applied.
    lr = 1.
        Learning rate.
    optim = torch.optim.SGD
        Callable invoked as optim(policy.parameters(), lr) that returns the
        optimizer. By default Stochastic Gradient Descent.
    lr_decay_factor: float
        Learning rate decay factor used for the LR scheduler:
        lr = lr * lr_decay_factor at every milestone.
    lr_milestones: Union[int, Iterable[int]]
        If it is an int, StepLR is used and lr changes every lr_milestones
        scheduler steps. If it is a list or np.ndarray, MultiStepLR changes lr
        at those steps. The scheduler is stepped once per optimizer update,
        so milestones count updates, not episodes.
    rand_prob: float = .2
        The probability to draw a random subset from all the analytes.
        1 - rand_prob is the probability to use a "real" set (provided in
        alists).
    max_rand_analytes: int = 30
        The maximum number of analytes in the randomly drawn set.
    min_rand_analytes: int = 10
        The minimum number of analytes in the randomly drawn set.
    print_every = 100
        Number of episodes between printed loss reports.
    baseline = 0.
        Baseline value for the REINFORCE algorithm.
    max_norm = None
        If set, gradients are clipped to this norm before each update.
    beta = .0
        Entropy regularization term, is used for more exploration.
        By default it is disabled.
    weights = [1., 1.]
        Weights of the errors to consider, first one is for the Placement
        Error, second one is for the Overlap Error. By default both have the
        same weight.

    Returns
    -------
    (losses, perfect_loss)
    losses: np.ndarray
        Mean loss of the sampled programs, one entry per episode.
    perfect_loss: np.ndarray
        Value of exp.perfect_loss(weights) for the set used in each episode.
    """

    losses = []
    perfect_loss = []

    # Build the ExperimentAnalytes objects for the given sets once, up front.
    exps = [
        ExperimentAnalytes(k0 = alist.k0.values, S = alist.S.values, h=0.001, run_time=10.0)
        for alist in alists
    ]
    num_exps = len(alists)

    # Pool of all analytes, used to draw random subsets.
    all_analytes = pd.concat(alists, sort=True)[['k0', 'S', 'lnk0']]

    # Optimizer
    optimizer = optim(policy.parameters(), lr)

    # LR scheduler
    if isinstance(lr_milestones, (list, np.ndarray)):
        scheduler = MultiStepLR(optimizer, lr_milestones, gamma=lr_decay_factor)
    else:
        scheduler = StepLR(optimizer, lr_milestones, gamma=lr_decay_factor)

    J_batch = 0

    for n in range(num_episodes):
        # Pick the set to use for this episode.
        if random() < rand_prob:
            dataframe = all_analytes.sample(randint(min_rand_analytes, max_rand_analytes))
            input_data = torch.tensor(dataframe[['S', 'lnk0']].values, dtype=torch.float32)
            exp = ExperimentAnalytes(k0 = dataframe.k0.values, S = dataframe.S.values, h=0.001, run_time=10.0)
        else:
            # Choose one of the provided sets at random.
            set_index = randint(0, num_exps - 1) 
            exp = exps[set_index]
            input_data = torch.tensor(alists[set_index][['S', 'lnk0']].values, dtype=torch.float32)

        # Grid search for the best isocratic (single-step) solvent strength.
        # Starting from +inf guarantees phi_iso is set from THIS set whenever
        # any grid point has a finite loss, instead of silently reusing the
        # value of a previous episode. 0.0 is only a fallback for the case
        # that no grid point yields a comparable loss.
        phi_iso = 0.
        best_iso_loss = float('inf')
        for phi in np.linspace(0, 1, 100):
            exp.reset()
            exp.step(phi, 1.)
            iso_loss = exp.loss()
            if iso_loss < best_iso_loss:
                phi_iso = phi
                best_iso_loss = iso_loss

        # Compute distribution parameters (Normal). The isocratic input uses
        # the same explicit dtype as input_data so both can be concatenated.
        iso_input = torch.tensor([[phi_iso]], dtype=torch.float32)
        mu, sigma = policy.forward(input_data, iso_input)

        # Sample some values from the actions distributions
        programs = sample(mu, sigma, sample_size)

        # Fit the sampled data to the constraint [0, 1]; the unclipped
        # samples are kept for the log-probabilities below.
        constr_programs = programs.detach().clamp(0, 1)

        J = 0
        expected_loss = 0
        for i in range(sample_size):
            exp.reset()            
            exp.run_all(constr_programs[i].numpy(), delta_taus)

            error = exp.loss(weights)
            expected_loss += error
            log_prob_ = log_prob(programs[i], mu, sigma)
            J += (error - baseline) * log_prob_ - beta * torch.exp(log_prob_) * log_prob_

        losses.append(expected_loss/sample_size)
        perfect_loss.append(exp.perfect_loss(weights))
        if (n + 1) % print_every == 0:
            print(f"Loss: {losses[-1]}, epoch: {n+1}/{num_episodes}")

        J_batch += J/sample_size
        # Batches are counted in episodes, so the test must use the episode
        # index n (the inner sample index i is always sample_size - 1 here).
        if (n + 1) % batch_size == 0:
            J_batch /= batch_size
            optimizer.zero_grad()
            # Calculate gradients
            J_batch.backward()

            if max_norm:
                torch.nn.utils.clip_grad_norm_(policy.parameters(), max_norm)

            # Apply gradients
            optimizer.step()

            # learning rate decay
            scheduler.step()

            J_batch = 0

    return np.array(losses), np.array(perfect_loss)

In [None]:
# Train five independent ISO-conditioned policies and keep, for each run,
# the per-episode loss and the per-episode perfect loss.
taus = [.25, .25, 10]
L5_b1_iso = []
PL5_b1_iso = []
for i in range(5):
    # Encoder producing a 5-dimensional embedding of the analyte set.
    set_encoder = nn.Sequential(
        PermEqui2_max(2, 5),
        nn.ELU(inplace=True),
        PermEqui2_max(5, 5),
        nn.ELU(inplace=True),
        PermEqui2_max(5, 5),
        nn.ELU(inplace=True),
    )
    # Input width 6 = 5 embedding values + 1 appended isocratic phi.
    program_head = nn.Sequential(
        nn.Linear(6, 6),
        nn.ELU(inplace=True),
        nn.Linear(6, 6),
        nn.ELU(inplace=True),
        Rho(n_steps=len(taus), hidden=6, in_dim=6, sigma_max=.3, sigma_min=.01),
    )
    pol = PolicyGeneralISO(phi=set_encoder, rho=program_head)

    run_losses, run_perfect = reinforce_gen_iso(
        # data and model
        alists=alists,
        policy=pol,
        delta_taus=taus,
        # sampling of random analyte subsets
        rand_prob=0.7,
        min_rand_analytes=10,
        max_rand_analytes=30,
        # optimisation
        num_episodes=20_000,
        sample_size=10,
        batch_size=1,
        lr=.05,
        optim=lambda a, b: torch.optim.SGD(a, b),
        lr_decay_factor=0.75,
        lr_milestones=5000,
        baseline=.55,
        max_norm=1.7,
        # logging
        print_every=5000,
    )
    L5_b1_iso.append(run_losses)
    PL5_b1_iso.append(run_perfect)

In [None]:
# One curve per run: loss minus perfect loss, averaged over blocks of 100
# episodes. The reshape fixes the layout to 10 runs x 200 blocks x 100 episodes.
# `L5_b1` / `PL5_b1` are not assigned in the cells visible here and must
# already exist in the kernel.
plt.figure()
excess_loss = np.array(L5_b1) - np.array(PL5_b1)
per_run_curves = excess_loss.reshape((10, 200, 100)).mean(2).T
plt.plot(per_run_curves)
plt.ylim((0.3, 1))
plt.title("Simple batch one small")

In [None]:
# Loss minus perfect loss, averaged across runs (axis 0) and then over
# consecutive blocks of 100 episodes. Only the "Small NN" curve is drawn at
# the moment, although the title announces a comparison.
plt.figure()
small_nn_curve = (np.array(L5_b1) - np.array(PL5_b1)).mean(0).reshape(-1, 100).mean(1)
plt.plot(small_nn_curve, label = "Small NN")
# Disabled comparison curves, kept for reference:
#plt.plot((np.array(L5_b1_iso) - np.array(PL5_b1_iso)).mean(0).reshape(-1, 100).mean(1), label = "ISO")
#plt.plot((np.array(L5_b1_big) - np.array(PL5_b1_big)).mean(0).reshape(-1, 100).mean(1), label = "Big NN")
plt.title("Small rho NN vs Big rho NN")
plt.legend()

In [None]:
# Difference L5_b1 - PL5_b1 (presumably loss minus perfect loss), averaged
# across runs (axis 0) and then over consecutive blocks of 100 episodes.
# NOTE(review): L5_b1 / PL5_b1 are not assigned in the cells visible here.
(np.array(L5_b1) - np.array(PL5_b1)).mean(0).reshape(-1, 100).mean(1)