## Library imports

In [19]:
import numpy as np
import pandas as pd
from chromatography import ExperimentAnalytes
# Star import kept in its original position: moving it could change which
# names win if separation_utility exports any of the names imported below.
from separation_utility import *
import torch
from torch import optim, tensor
import matplotlib.pyplot as plt
import time
%matplotlib qt

## Dataset Import

In [20]:
# Name of the analyte dataset; it selects the CSV file read from ../data/.
sample = 'Peterpeptides'
alist = pd.read_csv(f'../data/{sample}.csv')
# Dataset names (presumably the CSV files available in ../data/ — verify):
# GilarSample
# Peterpeptides
# Roca
# Peter32
# Eosin
# Alizarin
# Controlmix1
# Controlmix2
# Disabled: derive the k0 column from an lnk0 column (presumably for CSVs that
# store ln(k0) instead of k0 — confirm before re-enabling).
#alist['k0'] = np.exp(alist.lnk0)

## Useful functions definition

In [21]:
def loss_field(exp, taus, N = 200):
    """Evaluate the experiment loss on an N x N grid of (phi1, phi2) pairs.

    Both phi values sweep ``np.linspace(0, 1, N)``. For every pair the
    experiment is reset, run with ``exp.run_all([phi1, phi2], taus)`` and the
    value of ``exp.loss()`` is stored.

    Parameters
    ----------
    exp : object providing ``reset()``, ``run_all(phis, taus)`` and ``loss()``.
    taus : forwarded unchanged to ``exp.run_all`` for every grid point.
    N : number of grid points along each axis.

    Returns
    -------
    X, Y, losses
        ``X, Y = np.meshgrid(phis, phis)`` and the loss array laid out to
        match them: ``losses[r, c]`` is the loss for ``phi1 = X[r, c]`` and
        ``phi2 = Y[r, c]``.
    """
    grid = np.linspace(0, 1, N)
    field = np.zeros((N, N))
    # Columns follow phi1 and rows follow phi2, which is meshgrid's layout.
    for col, phi1 in enumerate(grid):
        for row, phi2 in enumerate(grid):
            exp.reset()
            exp.run_all([phi1, phi2], taus)
            field[row, col] = exp.loss()
    X, Y = np.meshgrid(grid, grid)
    return X, Y, field
def average_over_equal_intervals(arr, interval):
    """Average consecutive, non-overlapping chunks of ``interval`` values.

    The input is flattened into rows of ``interval`` elements and the mean of
    each row is returned.

    Parameters
    ----------
    arr : array-like
        Values to average. Any array-like is accepted (NumPy array, list,
        tuple); it is converted with ``np.asarray`` first, so plain Python
        lists work as well as arrays.
    interval : int
        Number of consecutive values per chunk.

    Returns
    -------
    numpy.ndarray
        One mean per chunk, ``arr.size // interval`` entries in total.

    Raises
    ------
    ValueError
        Raised by ``reshape`` when the number of values is not a multiple of
        ``interval``.
    """
    values = np.asarray(arr)
    return np.mean(values.reshape(-1, interval), axis=1)

In [32]:
# Repeated-training benchmark on the Peterpeptides dataset: a new PolicySingle
# is trained `n_runs` times with reinforce_one_set, and after each run the
# experiment is evaluated at the trained policy's `mu`; the resulting loss is
# appended to `losses_dist`.

# Parameters
sample = 'Peterpeptides'
alist = pd.read_csv(f'../data/{sample}.csv')
n_runs = 50
run_time_lim = 10.0
sigma_max = 0.3
delta_taus = [.25, .25, 10]
num_episodes = 5000
sample_size = 10
lr = .01
# Optimizer factory handed to reinforce_one_set as `optim`. It has its own
# name so that the `optim` module imported from torch is not overwritten.
# Presumably called as (parameters, learning rate) — verify in separation_utility.
make_optimizer = lambda a, b: torch.optim.SGD(a, b)#, momentum=0.65)
lr_decay_factor = .5
lr_milestones = 1000
print_every= 100
baseline = 0.55
max_norm = 2.
# `run_time_lim` is passed through instead of repeating the literal 10.0.
exp = ExperimentAnalytes(k0 = alist.k0.values, S = alist.S.values, h=0.001, run_time=run_time_lim)
losses_dist = []

start_process_peterp = time.perf_counter()

for run_idx in range(n_runs):
    # A new PolicySingle instance is created for every repetition.
    pol = PolicySingle(len(delta_taus), sigma_max = sigma_max)
    reinforce_one_set(
            exp,
            pol,
            delta_taus= delta_taus,
            num_episodes=num_episodes,
            sample_size=sample_size,
            optim=make_optimizer,
            lr=lr,
            print_every=print_every,
            lr_decay_factor=lr_decay_factor,
            lr_milestones=lr_milestones,
            baseline=baseline,
            max_norm=max_norm,
        )
    # Evaluate the experiment at the trained policy mean and keep its loss.
    exp.reset()
    exp.run_all(pol.mu.detach().numpy(), delta_taus)
    losses_dist.append(exp.loss())
end_process_peterp = time.perf_counter()
# Recorded execution: 104.34018015212001 s per run on average
# (total elapsed time divided by the number of runs).

Loss: 1.6380038139188333, epoch: 100/5000
Loss: 1.2211607928921329, epoch: 200/5000
Loss: 1.0522902627203907, epoch: 300/5000
Loss: 0.952018517947665, epoch: 400/5000
Loss: 0.9246127765196384, epoch: 500/5000
Loss: 1.046780273813264, epoch: 600/5000
Loss: 0.9427216873809353, epoch: 700/5000
Loss: 0.8003280160285067, epoch: 800/5000
Loss: 0.7687648861073912, epoch: 900/5000
Loss: 0.8331133116345356, epoch: 1000/5000
Loss: 0.8371346356633456, epoch: 1100/5000
Loss: 0.7768087577063533, epoch: 1200/5000
Loss: 0.9956787898091894, epoch: 1300/5000
Loss: 0.7400051412402013, epoch: 1400/5000
Loss: 0.7709807625833145, epoch: 1500/5000
Loss: 0.8152594364773215, epoch: 1600/5000
Loss: 0.7255885957636378, epoch: 1700/5000
Loss: 0.769068063762485, epoch: 1800/5000
Loss: 0.8560800070257581, epoch: 1900/5000
Loss: 0.7691545178591361, epoch: 2000/5000
Loss: 0.7550358980113537, epoch: 2100/5000
Loss: 0.7728580088284421, epoch: 2200/5000
Loss: 0.7331194503561166, epoch: 2300/5000
Loss: 0.773917148947530

Loss: 0.9049174082412765, epoch: 4300/5000
Loss: 0.7486066687331261, epoch: 4400/5000
Loss: 1.0248325707067418, epoch: 4500/5000
Loss: 0.7518658829130281, epoch: 4600/5000
Loss: 0.8049082066384713, epoch: 4700/5000
Loss: 0.7447070030881454, epoch: 4800/5000
Loss: 0.8581141306430151, epoch: 4900/5000
Loss: 0.7742668634623393, epoch: 5000/5000
Loss: 1.5375397445986523, epoch: 100/5000
Loss: 1.3197034934856022, epoch: 200/5000
Loss: 1.449704616339422, epoch: 300/5000
Loss: 1.092069106801763, epoch: 400/5000
Loss: 0.9302447354434509, epoch: 500/5000
Loss: 0.7148321338067863, epoch: 600/5000
Loss: 0.9130053945356599, epoch: 700/5000
Loss: 0.9158695863876127, epoch: 800/5000
Loss: 0.8248528912123423, epoch: 900/5000
Loss: 0.8544395128741332, epoch: 1000/5000
Loss: 0.7905548128418807, epoch: 1100/5000
Loss: 0.8236046427047763, epoch: 1200/5000
Loss: 0.8216953325745833, epoch: 1300/5000
Loss: 0.7890165562283818, epoch: 1400/5000
Loss: 0.7588939338924537, epoch: 1500/5000
Loss: 0.81870699268638

Loss: 0.947933238725513, epoch: 3600/5000
Loss: 0.7297641666662449, epoch: 3700/5000
Loss: 0.9116000517843099, epoch: 3800/5000
Loss: 0.8085210450013246, epoch: 3900/5000
Loss: 0.8552287623439998, epoch: 4000/5000
Loss: 0.7885586707587258, epoch: 4100/5000
Loss: 0.7522925408527261, epoch: 4200/5000
Loss: 0.8955967016927231, epoch: 4300/5000
Loss: 0.801703501021503, epoch: 4400/5000
Loss: 0.8282134354976989, epoch: 4500/5000
Loss: 0.8400359686208482, epoch: 4600/5000
Loss: 0.9603694843254406, epoch: 4700/5000
Loss: 0.8110922379359197, epoch: 4800/5000
Loss: 0.7863123334787191, epoch: 4900/5000
Loss: 0.8393853073022799, epoch: 5000/5000
Loss: 1.8637149692057737, epoch: 100/5000
Loss: 1.8326497308979826, epoch: 200/5000
Loss: 1.8386125541956244, epoch: 300/5000
Loss: 1.831359997523731, epoch: 400/5000
Loss: 1.8225264965172947, epoch: 500/5000
Loss: 1.5909461945302197, epoch: 600/5000
Loss: 1.860034614688153, epoch: 700/5000
Loss: 1.809194854269056, epoch: 800/5000
Loss: 1.795321194902862,

Loss: 1.10367009169646, epoch: 2900/5000
Loss: 0.8346982714825819, epoch: 3000/5000
Loss: 0.9253201979725076, epoch: 3100/5000
Loss: 0.8827250666407449, epoch: 3200/5000
Loss: 0.9736911138976565, epoch: 3300/5000
Loss: 0.9505516433382801, epoch: 3400/5000
Loss: 0.9183821028301938, epoch: 3500/5000
Loss: 0.9255314384405816, epoch: 3600/5000
Loss: 0.8988593112685331, epoch: 3700/5000
Loss: 0.9068148200913322, epoch: 3800/5000
Loss: 0.8612118051349, epoch: 3900/5000
Loss: 0.8957208678499065, epoch: 4000/5000
Loss: 0.811135796482008, epoch: 4100/5000
Loss: 0.91009331473381, epoch: 4200/5000
Loss: 0.9387633409718072, epoch: 4300/5000
Loss: 0.8536006842158954, epoch: 4400/5000
Loss: 0.9220505843762071, epoch: 4500/5000
Loss: 0.9030378009996811, epoch: 4600/5000
Loss: 0.8775427091761234, epoch: 4700/5000
Loss: 0.9128216931228608, epoch: 4800/5000
Loss: 0.9379833434420496, epoch: 4900/5000
Loss: 0.910710449954683, epoch: 5000/5000
Loss: 1.9018566447534024, epoch: 100/5000
Loss: 1.4794102710526

Loss: 0.8149814836454278, epoch: 2200/5000
Loss: 0.8545227751066218, epoch: 2300/5000
Loss: 0.9040616499172313, epoch: 2400/5000
Loss: 0.7579146246508707, epoch: 2500/5000
Loss: 0.8979964031098279, epoch: 2600/5000
Loss: 1.007433518621756, epoch: 2700/5000
Loss: 0.9766635873175549, epoch: 2800/5000
Loss: 0.8524252444519101, epoch: 2900/5000
Loss: 0.783074102366925, epoch: 3000/5000
Loss: 0.7745866827344636, epoch: 3100/5000
Loss: 0.9230140957531298, epoch: 3200/5000
Loss: 0.8688265137628564, epoch: 3300/5000
Loss: 0.8928045700858018, epoch: 3400/5000
Loss: 0.9404454294840285, epoch: 3500/5000
Loss: 1.0188518179213117, epoch: 3600/5000
Loss: 0.9739117143777264, epoch: 3700/5000
Loss: 0.9950058403520213, epoch: 3800/5000
Loss: 0.9007863858572005, epoch: 3900/5000
Loss: 0.8019029331969323, epoch: 4000/5000
Loss: 0.9630438723934096, epoch: 4100/5000
Loss: 0.8321116409506683, epoch: 4200/5000
Loss: 0.8989373381810231, epoch: 4300/5000
Loss: 0.800999098952744, epoch: 4400/5000
Loss: 1.014114

Loss: 0.7517254269381243, epoch: 1500/5000
Loss: 0.8155635181711469, epoch: 1600/5000
Loss: 0.7228741562004373, epoch: 1700/5000
Loss: 0.7842373564034198, epoch: 1800/5000
Loss: 0.7688953645493224, epoch: 1900/5000
Loss: 0.8159322532678285, epoch: 2000/5000
Loss: 0.7463955207461126, epoch: 2100/5000
Loss: 0.8181914464513241, epoch: 2200/5000
Loss: 0.7448269925702024, epoch: 2300/5000
Loss: 0.8222018918472067, epoch: 2400/5000
Loss: 0.8056858074351634, epoch: 2500/5000
Loss: 0.7354085805709222, epoch: 2600/5000
Loss: 0.7472625066050768, epoch: 2700/5000
Loss: 0.7896221927646992, epoch: 2800/5000
Loss: 0.8031573148369974, epoch: 2900/5000
Loss: 0.8511802971741336, epoch: 3000/5000
Loss: 0.7865879770294473, epoch: 3100/5000
Loss: 0.8475549067376791, epoch: 3200/5000
Loss: 0.821079142635269, epoch: 3300/5000
Loss: 0.7506197954167196, epoch: 3400/5000
Loss: 0.7158531735130634, epoch: 3500/5000
Loss: 0.8419289535717944, epoch: 3600/5000
Loss: 0.738215194083175, epoch: 3700/5000
Loss: 0.72230

Loss: 1.3157082093219898, epoch: 700/5000
Loss: 0.8694898773116988, epoch: 800/5000
Loss: 1.0261330665024215, epoch: 900/5000
Loss: 0.9892560111249583, epoch: 1000/5000
Loss: 1.1105904421001758, epoch: 1100/5000
Loss: 0.839083330133543, epoch: 1200/5000
Loss: 0.9281425504364265, epoch: 1300/5000
Loss: 0.8551502221106526, epoch: 1400/5000
Loss: 0.9386773325246061, epoch: 1500/5000
Loss: 0.8138457080778367, epoch: 1600/5000
Loss: 0.9399720759524512, epoch: 1700/5000
Loss: 0.8154668664240967, epoch: 1800/5000
Loss: 0.8625575940814251, epoch: 1900/5000
Loss: 0.8159101070780418, epoch: 2000/5000
Loss: 0.9231300494405229, epoch: 2100/5000
Loss: 0.8185789707253127, epoch: 2200/5000
Loss: 0.7758837538962848, epoch: 2300/5000
Loss: 0.7838818176131899, epoch: 2400/5000
Loss: 0.7797048765284373, epoch: 2500/5000
Loss: 0.7379156456471364, epoch: 2600/5000
Loss: 0.8452375882126484, epoch: 2700/5000
Loss: 0.8024409730545923, epoch: 2800/5000
Loss: 0.8685288782176286, epoch: 2900/5000
Loss: 0.8087460

Loss: 0.7337608589389552, epoch: 4900/5000
Loss: 0.7533061280051447, epoch: 5000/5000
Loss: 1.7742320065617956, epoch: 100/5000
Loss: 1.286007064439636, epoch: 200/5000
Loss: 1.0623726991786382, epoch: 300/5000
Loss: 1.2215674282439675, epoch: 400/5000
Loss: 1.126629832324903, epoch: 500/5000
Loss: 1.111891732190421, epoch: 600/5000
Loss: 1.3372425707819418, epoch: 700/5000
Loss: 1.0659082112167302, epoch: 800/5000
Loss: 1.0099311019166872, epoch: 900/5000
Loss: 0.9903969227428945, epoch: 1000/5000
Loss: 0.9016439636283764, epoch: 1100/5000
Loss: 0.9902467109041153, epoch: 1200/5000
Loss: 1.0635875685475533, epoch: 1300/5000
Loss: 0.9341477959509069, epoch: 1400/5000
Loss: 1.0295266869033421, epoch: 1500/5000
Loss: 1.1991037267687383, epoch: 1600/5000
Loss: 1.0085709454515266, epoch: 1700/5000
Loss: 0.918221501182242, epoch: 1800/5000
Loss: 0.9894726609415787, epoch: 1900/5000
Loss: 1.0402960607734184, epoch: 2000/5000
Loss: 0.9164311268483465, epoch: 2100/5000
Loss: 0.9725950381748778

Loss: 1.7317509083910512, epoch: 4200/5000
Loss: 1.698729089439836, epoch: 4300/5000
Loss: 1.7888667860468803, epoch: 4400/5000
Loss: 1.74155216436093, epoch: 4500/5000
Loss: 1.6960676683224563, epoch: 4600/5000
Loss: 1.7696520616226543, epoch: 4700/5000
Loss: 1.5836112488167855, epoch: 4800/5000
Loss: 1.6665381208123933, epoch: 4900/5000
Loss: 1.7150652421350323, epoch: 5000/5000
Loss: 1.8788371969750337, epoch: 100/5000
Loss: 1.7435380567026866, epoch: 200/5000
Loss: 1.8623691986588082, epoch: 300/5000
Loss: 1.5989732506563623, epoch: 400/5000
Loss: 1.777182566095976, epoch: 500/5000
Loss: 1.8416092335663303, epoch: 600/5000
Loss: 1.675234577376575, epoch: 700/5000
Loss: 1.8190489341165335, epoch: 800/5000
Loss: 1.658429812666632, epoch: 900/5000
Loss: 1.7504512052073227, epoch: 1000/5000
Loss: 1.7492402824885396, epoch: 1100/5000
Loss: 1.7361520598671731, epoch: 1200/5000
Loss: 1.757801083313021, epoch: 1300/5000
Loss: 1.6928806455456509, epoch: 1400/5000
Loss: 1.6851651649387651, e

Loss: 0.7903978674425118, epoch: 3500/5000
Loss: 0.7548506146782905, epoch: 3600/5000
Loss: 0.8820200047248219, epoch: 3700/5000
Loss: 0.7437981179270805, epoch: 3800/5000
Loss: 0.8970414248869772, epoch: 3900/5000
Loss: 0.8497732242190634, epoch: 4000/5000
Loss: 0.980877381528711, epoch: 4100/5000
Loss: 0.8329282970389611, epoch: 4200/5000
Loss: 0.7562916419614074, epoch: 4300/5000
Loss: 0.84918657553044, epoch: 4400/5000
Loss: 0.842198347816993, epoch: 4500/5000
Loss: 0.7032661199973471, epoch: 4600/5000
Loss: 0.7444903913129758, epoch: 4700/5000
Loss: 0.9061431649953434, epoch: 4800/5000
Loss: 0.8539675051653175, epoch: 4900/5000
Loss: 0.8799001639500632, epoch: 5000/5000
Loss: 1.7078557087093997, epoch: 100/5000
Loss: 1.8008088093353876, epoch: 200/5000
Loss: 1.7716088839975348, epoch: 300/5000
Loss: 1.5659602095262934, epoch: 400/5000
Loss: 1.7127548142805882, epoch: 500/5000
Loss: 1.7251088426876937, epoch: 600/5000
Loss: 1.412688631552711, epoch: 700/5000
Loss: 1.3405795327477, 

Loss: 0.8248572451041485, epoch: 2800/5000
Loss: 0.7277509949890965, epoch: 2900/5000
Loss: 0.8337858301093977, epoch: 3000/5000
Loss: 0.7897452551604671, epoch: 3100/5000
Loss: 0.7594994937983348, epoch: 3200/5000
Loss: 0.7043915145470835, epoch: 3300/5000
Loss: 0.7511481800717754, epoch: 3400/5000
Loss: 0.7604935448418444, epoch: 3500/5000
Loss: 0.7215915955487062, epoch: 3600/5000
Loss: 0.768686328453871, epoch: 3700/5000
Loss: 0.775845149728341, epoch: 3800/5000
Loss: 0.8218830611928517, epoch: 3900/5000
Loss: 0.7883252114496564, epoch: 4000/5000
Loss: 0.8522767706657574, epoch: 4100/5000
Loss: 0.8097827656309777, epoch: 4200/5000
Loss: 0.7569727229394243, epoch: 4300/5000
Loss: 0.7280287941812771, epoch: 4400/5000
Loss: 0.7267353210956139, epoch: 4500/5000
Loss: 0.7563020759163557, epoch: 4600/5000
Loss: 0.717480075443255, epoch: 4700/5000
Loss: 0.7746659043361154, epoch: 4800/5000
Loss: 0.8338692745010826, epoch: 4900/5000
Loss: 0.7368546069630117, epoch: 5000/5000
Loss: 1.633162

Loss: 0.9926140786120927, epoch: 2100/5000
Loss: 1.0123935135219038, epoch: 2200/5000
Loss: 0.8900828695916398, epoch: 2300/5000
Loss: 0.9670165500600012, epoch: 2400/5000
Loss: 1.079508349533323, epoch: 2500/5000
Loss: 1.056768941076918, epoch: 2600/5000
Loss: 0.9972152242802894, epoch: 2700/5000
Loss: 1.071762973958401, epoch: 2800/5000
Loss: 0.8448359096912237, epoch: 2900/5000
Loss: 1.0102266777255218, epoch: 3000/5000
Loss: 1.116843706206798, epoch: 3100/5000
Loss: 0.8943150842872811, epoch: 3200/5000
Loss: 0.8414532230300205, epoch: 3300/5000
Loss: 1.0347914166882473, epoch: 3400/5000
Loss: 0.9380095641825859, epoch: 3500/5000
Loss: 1.07281842033807, epoch: 3600/5000
Loss: 0.9904761004300585, epoch: 3700/5000
Loss: 1.1919557354650816, epoch: 3800/5000
Loss: 0.9590538241906424, epoch: 3900/5000
Loss: 0.9427431391840845, epoch: 4000/5000
Loss: 0.8798969422740408, epoch: 4100/5000
Loss: 1.0215279803280501, epoch: 4200/5000
Loss: 0.9729521123071736, epoch: 4300/5000
Loss: 1.012073807

Loss: 0.8540370339569954, epoch: 1400/5000
Loss: 1.1024953285829462, epoch: 1500/5000
Loss: 0.8408673207991075, epoch: 1600/5000
Loss: 0.9147993898694402, epoch: 1700/5000
Loss: 0.9587668155564508, epoch: 1800/5000
Loss: 0.9733218299984603, epoch: 1900/5000
Loss: 0.9017402306423309, epoch: 2000/5000
Loss: 0.8509404771125002, epoch: 2100/5000
Loss: 0.7510147396735936, epoch: 2200/5000
Loss: 0.864676990799123, epoch: 2300/5000
Loss: 0.8137868983773009, epoch: 2400/5000
Loss: 0.8465399459064514, epoch: 2500/5000
Loss: 0.7758592585616917, epoch: 2600/5000
Loss: 0.8067078366529319, epoch: 2700/5000
Loss: 0.737524967687777, epoch: 2800/5000
Loss: 0.7848984787716549, epoch: 2900/5000
Loss: 0.8259145665618168, epoch: 3000/5000
Loss: 0.7824734217038627, epoch: 3100/5000
Loss: 0.765245658974767, epoch: 3200/5000
Loss: 0.7326521013688302, epoch: 3300/5000
Loss: 0.7574561594237071, epoch: 3400/5000
Loss: 0.8516781485600236, epoch: 3500/5000
Loss: 0.9001182401202804, epoch: 3600/5000
Loss: 0.835237

In [35]:
# Mean wall-clock time per training run, in seconds (time.perf_counter returns
# seconds). The run count comes from the collected results rather than a
# hard-coded 50, so it cannot drift from the loop that produced them.
(end_process_peterp - start_process_peterp) / len(losses_dist)

104.34018015212001

In [46]:
# Histogram of the final losses collected over the repeated Peterpeptides runs.
# A new figure is opened first: with the `%matplotlib qt` backend a figure that
# is still open stays current, and pyplot calls would draw on top of it.
plt.figure()
plt.hist(losses_dist, bins=50)
# Plain string: the title has no placeholders, so no f-string is needed.
plt.title("Final Result Distribution for SingleSetModel (PeterPeptides)")
plt.xlabel("Loss")
# Trailing semicolon keeps the Text(...) repr out of the cell output.
plt.ylabel("Occurrences");

Text(0, 0.5, 'Occurrences')

In [33]:
# Repeated-training benchmark on the GilarSample dataset: a new PolicySingle
# is trained `n_runs` times with reinforce_one_set, and after each run the
# experiment is evaluated at the trained policy's `mu`; the resulting loss is
# appended to `losses_dist_gil`.

# Parameters
sample = 'GilarSample'
alist = pd.read_csv(f'../data/{sample}.csv')
n_runs = 50
run_time_lim = 10.0
sigma_max = 0.3
delta_taus = [.25, .25, 10]
num_episodes = 5000
sample_size = 10
lr = .01
# Optimizer factory handed to reinforce_one_set as `optim`. It has its own
# name so that the `optim` module imported from torch is not overwritten.
# Presumably called as (parameters, learning rate) — verify in separation_utility.
make_optimizer = lambda a, b: torch.optim.SGD(a, b)#, momentum=0.65)
lr_decay_factor = .5
lr_milestones = 1000
print_every= 100
baseline = 0.55
max_norm = 2.
# `run_time_lim` is passed through instead of repeating the literal 10.0.
exp = ExperimentAnalytes(k0 = alist.k0.values, S = alist.S.values, h=0.001, run_time=run_time_lim)
losses_dist_gil = []

start_process_gil = time.perf_counter()

for run_idx in range(n_runs):
    # A new PolicySingle instance is created for every repetition.
    pol = PolicySingle(len(delta_taus), sigma_max = sigma_max)
    reinforce_one_set(
            exp,
            pol,
            delta_taus= delta_taus,
            num_episodes=num_episodes,
            sample_size=sample_size,
            optim=make_optimizer,
            lr=lr,
            print_every=print_every,
            lr_decay_factor=lr_decay_factor,
            lr_milestones=lr_milestones,
            baseline=baseline,
            max_norm=max_norm,
        )
    # Evaluate the experiment at the trained policy mean and keep its loss.
    exp.reset()
    exp.run_all(pol.mu.detach().numpy(), delta_taus)
    losses_dist_gil.append(exp.loss())
end_process_gil = time.perf_counter()

Loss: 0.3837382332293865, epoch: 100/5000
Loss: 0.3420753674926548, epoch: 200/5000
Loss: 0.30895629974658123, epoch: 300/5000
Loss: 0.3128230615612189, epoch: 400/5000
Loss: 0.3217646248690046, epoch: 500/5000
Loss: 0.24406403380075908, epoch: 600/5000
Loss: 0.29445692814914187, epoch: 700/5000
Loss: 0.34421701946666594, epoch: 800/5000
Loss: 0.29123627694388615, epoch: 900/5000
Loss: 0.22314111102440864, epoch: 1000/5000
Loss: 0.29405003167933363, epoch: 1100/5000
Loss: 0.390656503581118, epoch: 1200/5000
Loss: 0.2445287457395994, epoch: 1300/5000
Loss: 0.3329568584054795, epoch: 1400/5000
Loss: 0.329513515726685, epoch: 1500/5000
Loss: 0.25217158373698456, epoch: 1600/5000
Loss: 0.2543345487608631, epoch: 1700/5000
Loss: 0.2839556734727638, epoch: 1800/5000
Loss: 0.31755711442959195, epoch: 1900/5000
Loss: 0.2803525822206795, epoch: 2000/5000
Loss: 0.3173572295592877, epoch: 2100/5000
Loss: 0.2640285411280815, epoch: 2200/5000
Loss: 0.3134394907132161, epoch: 2300/5000
Loss: 0.24193

Loss: 0.23179932081924193, epoch: 4000/5000
Loss: 0.37163078702709174, epoch: 4100/5000
Loss: 0.2892151737585139, epoch: 4200/5000
Loss: 0.24555940981696472, epoch: 4300/5000
Loss: 0.31438274677699063, epoch: 4400/5000
Loss: 0.24916740897384043, epoch: 4500/5000
Loss: 0.2602629183946118, epoch: 4600/5000
Loss: 0.2227230495488211, epoch: 4700/5000
Loss: 0.2081953730252021, epoch: 4800/5000
Loss: 0.26224487528111684, epoch: 4900/5000
Loss: 0.23682541242949032, epoch: 5000/5000
Loss: 0.2539566361657369, epoch: 100/5000
Loss: 0.20218796966970887, epoch: 200/5000
Loss: 0.19366633214869972, epoch: 300/5000
Loss: 0.22406149736816694, epoch: 400/5000
Loss: 0.20991524522378385, epoch: 500/5000
Loss: 0.22708164790787122, epoch: 600/5000
Loss: 0.2242285807821859, epoch: 700/5000
Loss: 0.2244224562728269, epoch: 800/5000
Loss: 0.2249529264121867, epoch: 900/5000
Loss: 0.23442437855909803, epoch: 1000/5000
Loss: 0.25769096764580124, epoch: 1100/5000
Loss: 0.2532242481464014, epoch: 1200/5000
Loss: 

Loss: 0.23635491616708118, epoch: 2900/5000
Loss: 0.2161567872558216, epoch: 3000/5000
Loss: 0.2723884723362354, epoch: 3100/5000
Loss: 0.23896541938797156, epoch: 3200/5000
Loss: 0.2391523107657839, epoch: 3300/5000
Loss: 0.21022789986119345, epoch: 3400/5000
Loss: 0.20563461300231226, epoch: 3500/5000
Loss: 0.22056062084900177, epoch: 3600/5000
Loss: 0.2257071692801551, epoch: 3700/5000
Loss: 0.27407561568479955, epoch: 3800/5000
Loss: 0.27785317486480665, epoch: 3900/5000
Loss: 0.24845765393706012, epoch: 4000/5000
Loss: 0.17805383746178008, epoch: 4100/5000
Loss: 0.23388769900852976, epoch: 4200/5000
Loss: 0.20569662152686302, epoch: 4300/5000
Loss: 0.22173742466906027, epoch: 4400/5000
Loss: 0.2763353458728894, epoch: 4500/5000
Loss: 0.21289477082559957, epoch: 4600/5000
Loss: 0.20774437956505346, epoch: 4700/5000
Loss: 0.22581863324283552, epoch: 4800/5000
Loss: 0.2856747042470172, epoch: 4900/5000
Loss: 0.24678944087039242, epoch: 5000/5000
Loss: 0.360552552785293, epoch: 100/50

Loss: 0.19071717681251396, epoch: 1900/5000
Loss: 0.19927516234412262, epoch: 2000/5000
Loss: 0.23489161977939, epoch: 2100/5000
Loss: 0.22802005580060322, epoch: 2200/5000
Loss: 0.20941645764419586, epoch: 2300/5000
Loss: 0.2200583151806304, epoch: 2400/5000
Loss: 0.19374995078673343, epoch: 2500/5000
Loss: 0.17778555372000387, epoch: 2600/5000
Loss: 0.17818204606875504, epoch: 2700/5000
Loss: 0.20359701353208895, epoch: 2800/5000
Loss: 0.21703565981356698, epoch: 2900/5000
Loss: 0.1969705549174775, epoch: 3000/5000
Loss: 0.2198315513274419, epoch: 3100/5000
Loss: 0.2147041771604783, epoch: 3200/5000
Loss: 0.1907147842239806, epoch: 3300/5000
Loss: 0.21167532404534564, epoch: 3400/5000
Loss: 0.18831213193581955, epoch: 3500/5000
Loss: 0.178477022129236, epoch: 3600/5000
Loss: 0.19454022304296587, epoch: 3700/5000
Loss: 0.2100513842336787, epoch: 3800/5000
Loss: 0.21323276014980147, epoch: 3900/5000
Loss: 0.1800955575772113, epoch: 4000/5000
Loss: 0.1715916792497188, epoch: 4100/5000
L

Loss: 0.24800060358804252, epoch: 800/5000
Loss: 0.3049223670241513, epoch: 900/5000
Loss: 0.2606952372524302, epoch: 1000/5000
Loss: 0.2620387104812337, epoch: 1100/5000
Loss: 0.3430547174802533, epoch: 1200/5000
Loss: 0.31850377507599587, epoch: 1300/5000
Loss: 0.22304435372073525, epoch: 1400/5000
Loss: 0.27209400975834785, epoch: 1500/5000
Loss: 0.2442145073998042, epoch: 1600/5000
Loss: 0.33222993772459075, epoch: 1700/5000
Loss: 0.1858708732031753, epoch: 1800/5000
Loss: 0.2846512349048775, epoch: 1900/5000
Loss: 0.1962105710598721, epoch: 2000/5000
Loss: 0.26701926079634347, epoch: 2100/5000
Loss: 0.31351945033527195, epoch: 2200/5000
Loss: 0.22521976554797227, epoch: 2300/5000
Loss: 0.250525892727612, epoch: 2400/5000
Loss: 0.190377076540312, epoch: 2500/5000
Loss: 0.2699190136980454, epoch: 2600/5000
Loss: 0.19495260781219298, epoch: 2700/5000
Loss: 0.24946443383321762, epoch: 2800/5000
Loss: 0.19617341034625407, epoch: 2900/5000
Loss: 0.26459906025186447, epoch: 3000/5000
Los

Loss: 0.18044752735106886, epoch: 4700/5000
Loss: 0.18412981727768574, epoch: 4800/5000
Loss: 0.186366789171231, epoch: 4900/5000
Loss: 0.20014731689859996, epoch: 5000/5000
Loss: 0.31582404731944747, epoch: 100/5000
Loss: 0.2772012171018387, epoch: 200/5000
Loss: 0.24279688706703323, epoch: 300/5000
Loss: 0.2696445389563572, epoch: 400/5000
Loss: 0.23531186048112232, epoch: 500/5000
Loss: 0.2289234380183919, epoch: 600/5000
Loss: 0.22538576583621758, epoch: 700/5000
Loss: 0.18130539307160182, epoch: 800/5000
Loss: 0.22720105384407865, epoch: 900/5000
Loss: 0.21664623246094536, epoch: 1000/5000
Loss: 0.2286931280156363, epoch: 1100/5000
Loss: 0.2250558939685364, epoch: 1200/5000
Loss: 0.22521375563995177, epoch: 1300/5000
Loss: 0.19859260708965704, epoch: 1400/5000
Loss: 0.20432822707552495, epoch: 1500/5000
Loss: 0.1714800476558045, epoch: 1600/5000
Loss: 0.18867558758191, epoch: 1700/5000
Loss: 0.20079198075125865, epoch: 1800/5000
Loss: 0.21441838104177036, epoch: 1900/5000
Loss: 0.

Loss: 0.2253984632599742, epoch: 3700/5000
Loss: 0.2315108492377219, epoch: 3800/5000
Loss: 0.22050993797054147, epoch: 3900/5000
Loss: 0.31653340959230725, epoch: 4000/5000
Loss: 0.23071698696560797, epoch: 4100/5000
Loss: 0.18289697934586874, epoch: 4200/5000
Loss: 0.24203222627045093, epoch: 4300/5000
Loss: 0.23202245368779373, epoch: 4400/5000
Loss: 0.22004445294360148, epoch: 4500/5000
Loss: 0.2245074115819794, epoch: 4600/5000
Loss: 0.22173354158693356, epoch: 4700/5000
Loss: 0.22891006176737375, epoch: 4800/5000
Loss: 0.22823820767075415, epoch: 4900/5000
Loss: 0.23210182617692832, epoch: 5000/5000
Loss: 0.3040764835934228, epoch: 100/5000
Loss: 0.3498716460576809, epoch: 200/5000
Loss: 0.24562521817846936, epoch: 300/5000
Loss: 0.3649133200870361, epoch: 400/5000
Loss: 0.30653652442449664, epoch: 500/5000
Loss: 0.2666951112951722, epoch: 600/5000
Loss: 0.26503374498722004, epoch: 700/5000
Loss: 0.2570038080135021, epoch: 800/5000
Loss: 0.24677238169347215, epoch: 900/5000
Loss:

Loss: 0.18803953199694803, epoch: 2600/5000
Loss: 0.18781005307811194, epoch: 2700/5000
Loss: 0.19309437896149015, epoch: 2800/5000
Loss: 0.20028611193524437, epoch: 2900/5000
Loss: 0.1677919351648688, epoch: 3000/5000
Loss: 0.18945528285870553, epoch: 3100/5000
Loss: 0.16257628635568577, epoch: 3200/5000
Loss: 0.1573989682412378, epoch: 3300/5000
Loss: 0.1935140337789969, epoch: 3400/5000
Loss: 0.17377585215554597, epoch: 3500/5000
Loss: 0.18158736052996094, epoch: 3600/5000
Loss: 0.17760360833065153, epoch: 3700/5000
Loss: 0.15340667096260077, epoch: 3800/5000
Loss: 0.17961057421687668, epoch: 3900/5000
Loss: 0.16825202624562735, epoch: 4000/5000
Loss: 0.16463769723128474, epoch: 4100/5000
Loss: 0.17986775936417948, epoch: 4200/5000
Loss: 0.176821512717326, epoch: 4300/5000
Loss: 0.17229207451161965, epoch: 4400/5000
Loss: 0.18398367264367002, epoch: 4500/5000
Loss: 0.1820799897090713, epoch: 4600/5000
Loss: 0.175249691380309, epoch: 4700/5000
Loss: 0.16899117498578775, epoch: 4800/5

Loss: 0.364698406346604, epoch: 1500/5000
Loss: 0.3323746334735903, epoch: 1600/5000
Loss: 0.3379805557557217, epoch: 1700/5000
Loss: 0.33923123184567333, epoch: 1800/5000
Loss: 0.38116028749796005, epoch: 1900/5000
Loss: 0.3153419663640441, epoch: 2000/5000
Loss: 0.23574711677502247, epoch: 2100/5000
Loss: 0.2644634736124182, epoch: 2200/5000
Loss: 0.2812657784808378, epoch: 2300/5000
Loss: 0.2121524396264753, epoch: 2400/5000
Loss: 0.3270878629745825, epoch: 2500/5000
Loss: 0.29589717517675485, epoch: 2600/5000
Loss: 0.3056208102450494, epoch: 2700/5000
Loss: 0.3530912232321559, epoch: 2800/5000
Loss: 0.2607257067968208, epoch: 2900/5000
Loss: 0.3566069695124587, epoch: 3000/5000
Loss: 0.29379459630761456, epoch: 3100/5000
Loss: 0.3080839609828926, epoch: 3200/5000
Loss: 0.30543009523588116, epoch: 3300/5000
Loss: 0.2647044531074199, epoch: 3400/5000
Loss: 0.24699709615325965, epoch: 3500/5000
Loss: 0.28884311178825534, epoch: 3600/5000
Loss: 0.2676887650445304, epoch: 3700/5000
Loss

Loss: 0.22780729283435525, epoch: 400/5000
Loss: 0.27591057126363683, epoch: 500/5000
Loss: 0.21589731555633662, epoch: 600/5000
Loss: 0.2153806717445314, epoch: 700/5000
Loss: 0.24164147061110616, epoch: 800/5000
Loss: 0.1796845054111676, epoch: 900/5000
Loss: 0.22310236135958506, epoch: 1000/5000
Loss: 0.22491722551897536, epoch: 1100/5000
Loss: 0.1748275208844821, epoch: 1200/5000
Loss: 0.23080400719128846, epoch: 1300/5000
Loss: 0.20304190374056147, epoch: 1400/5000
Loss: 0.20216937478078792, epoch: 1500/5000
Loss: 0.21720190900343095, epoch: 1600/5000
Loss: 0.22322258466158146, epoch: 1700/5000
Loss: 0.1789073312482679, epoch: 1800/5000
Loss: 0.1724983405894983, epoch: 1900/5000
Loss: 0.18269928248977013, epoch: 2000/5000
Loss: 0.19653475478272953, epoch: 2100/5000
Loss: 0.18582204140349384, epoch: 2200/5000
Loss: 0.18777399940495226, epoch: 2300/5000
Loss: 0.16620669796775048, epoch: 2400/5000
Loss: 0.22268785283179932, epoch: 2500/5000
Loss: 0.19986642137619542, epoch: 2600/5000

Loss: 0.19748392308666304, epoch: 4300/5000
Loss: 0.28035787520820704, epoch: 4400/5000
Loss: 0.2555882528031816, epoch: 4500/5000
Loss: 0.23524924494046706, epoch: 4600/5000
Loss: 0.2120512262932889, epoch: 4700/5000
Loss: 0.22889611238148308, epoch: 4800/5000
Loss: 0.2117695150055398, epoch: 4900/5000
Loss: 0.2045233818955246, epoch: 5000/5000
Loss: 0.4200226610916955, epoch: 100/5000
Loss: 0.3705003931955552, epoch: 200/5000
Loss: 0.37924313013766475, epoch: 300/5000
Loss: 0.4042339373473318, epoch: 400/5000
Loss: 0.4304920225108768, epoch: 500/5000
Loss: 0.4460664037106582, epoch: 600/5000
Loss: 0.3680780178972481, epoch: 700/5000
Loss: 0.35465648915475856, epoch: 800/5000
Loss: 0.3532807097565668, epoch: 900/5000
Loss: 0.4051729849245156, epoch: 1000/5000
Loss: 0.36049292266970556, epoch: 1100/5000
Loss: 0.3021511705861303, epoch: 1200/5000
Loss: 0.37214964811609663, epoch: 1300/5000
Loss: 0.3984703402633036, epoch: 1400/5000
Loss: 0.33116263491126763, epoch: 1500/5000
Loss: 0.277

Loss: 0.23542510848415826, epoch: 3300/5000
Loss: 0.22983439999666336, epoch: 3400/5000
Loss: 0.26735199763712997, epoch: 3500/5000
Loss: 0.2674256904191469, epoch: 3600/5000
Loss: 0.21182132321474456, epoch: 3700/5000
Loss: 0.29846339987655346, epoch: 3800/5000
Loss: 0.2339281218627769, epoch: 3900/5000
Loss: 0.20763882100864822, epoch: 4000/5000
Loss: 0.20511144466302147, epoch: 4100/5000
Loss: 0.20101073448476586, epoch: 4200/5000
Loss: 0.2240410243429988, epoch: 4300/5000
Loss: 0.24378311123309865, epoch: 4400/5000
Loss: 0.2518605045202425, epoch: 4500/5000
Loss: 0.25392729928798397, epoch: 4600/5000
Loss: 0.22768123076922628, epoch: 4700/5000
Loss: 0.20755295455149456, epoch: 4800/5000
Loss: 0.20532249684394888, epoch: 4900/5000
Loss: 0.2588329683774084, epoch: 5000/5000
Loss: 0.3643546997771098, epoch: 100/5000
Loss: 0.2918924433764215, epoch: 200/5000
Loss: 0.31715889160761607, epoch: 300/5000
Loss: 0.26508299906833827, epoch: 400/5000
Loss: 0.25437642232161745, epoch: 500/5000


Loss: 0.19154433510380217, epoch: 2200/5000
Loss: 0.1773867127397604, epoch: 2300/5000
Loss: 0.15863474807979142, epoch: 2400/5000
Loss: 0.19786553636314969, epoch: 2500/5000
Loss: 0.17986911709975187, epoch: 2600/5000
Loss: 0.1750574058926737, epoch: 2700/5000
Loss: 0.205883320888882, epoch: 2800/5000
Loss: 0.19349177700172582, epoch: 2900/5000
Loss: 0.1976569434099677, epoch: 3000/5000
Loss: 0.17464923712765232, epoch: 3100/5000
Loss: 0.21850012428995882, epoch: 3200/5000
Loss: 0.17567961610993105, epoch: 3300/5000
Loss: 0.15803901937541356, epoch: 3400/5000
Loss: 0.16820372386820445, epoch: 3500/5000
Loss: 0.1632158737844151, epoch: 3600/5000
Loss: 0.1974216314646817, epoch: 3700/5000
Loss: 0.18165666523426927, epoch: 3800/5000
Loss: 0.19735886810373077, epoch: 3900/5000
Loss: 0.2022211246040385, epoch: 4000/5000
Loss: 0.1899251078536259, epoch: 4100/5000
Loss: 0.21021781885303778, epoch: 4200/5000
Loss: 0.15543623809784587, epoch: 4300/5000
Loss: 0.1599594984130684, epoch: 4400/500

Loss: 0.2475877281628946, epoch: 1100/5000
Loss: 0.23936513171741652, epoch: 1200/5000
Loss: 0.22469279084521698, epoch: 1300/5000
Loss: 0.257542309405224, epoch: 1400/5000
Loss: 0.21385592354841246, epoch: 1500/5000
Loss: 0.2557330882166646, epoch: 1600/5000
Loss: 0.2771770411017232, epoch: 1700/5000
Loss: 0.25729760902533083, epoch: 1800/5000
Loss: 0.21217226298487674, epoch: 1900/5000
Loss: 0.289522089330592, epoch: 2000/5000
Loss: 0.1848352364822377, epoch: 2100/5000
Loss: 0.19416885546665627, epoch: 2200/5000
Loss: 0.17695479439824227, epoch: 2300/5000
Loss: 0.20834199855501861, epoch: 2400/5000
Loss: 0.2353390878706068, epoch: 2500/5000
Loss: 0.2230664699226313, epoch: 2600/5000
Loss: 0.2120695677007141, epoch: 2700/5000
Loss: 0.26418144157642176, epoch: 2800/5000
Loss: 0.23883375763651707, epoch: 2900/5000
Loss: 0.21179069009530016, epoch: 3000/5000
Loss: 0.21156055687539355, epoch: 3100/5000
Loss: 0.17214316044119032, epoch: 3200/5000
Loss: 0.20710604180373932, epoch: 3300/5000

In [45]:
# Histogram of the final losses collected over the repeated GilarSample runs.
# A new figure is opened first: with the `%matplotlib qt` backend a figure that
# is still open stays current, and pyplot calls would draw on top of it.
plt.figure()
plt.hist(losses_dist_gil, bins=50)
# Title spelling fixed to match the dataset name 'GilarSample'; plain string
# because the title has no placeholders.
plt.title("Final Result Distribution for SingleSetModel (GilarSample)")
plt.xlabel("Loss")
# Trailing semicolon keeps the Text(...) repr out of the cell output.
plt.ylabel("Occurrences");

Text(0, 0.5, 'Occurrences')

In [28]:
# Single training run that keeps what reinforce_one_set returns (`losses`,
# `best_program`, `mus`, `sigmas`) for the plotting cells that follow.
# NOTE(review): `alist` is whatever dataset the most recently executed cell
# loaded — confirm which dataset this run is meant to use.

# Parameters
run_time_lim = 10.0
sigma_max = 0.3
delta_taus = [.25, .25, 10]
num_episodes = 5000
sample_size = 10
lr = .05
# Optimizer factory handed to reinforce_one_set as `optim`. It has its own
# name so that the `optim` module imported from torch is not overwritten.
# Presumably called as (parameters, learning rate) — verify in separation_utility.
make_optimizer = lambda a, b: torch.optim.SGD(a, b)#, momentum=0.65)
lr_decay_factor = .5
lr_milestones = 1000
print_every= 100
baseline = 0.55
max_norm = 2.

# Experiment and policy are built only after the parameters are assigned.
# Previously the policy was created before `delta_taus` and `sigma_max` were
# set in this cell, so it used values left over from earlier cells (or failed
# on a fresh kernel). `run_time_lim` replaces the repeated literal 10.0.
exp = ExperimentAnalytes(k0 = alist.k0.values, S = alist.S.values, h=0.001, run_time=run_time_lim)
pol = PolicySingle(len(delta_taus), sigma_max = sigma_max)

losses, best_program, mus, sigmas = reinforce_one_set(
        exp,
        pol,
        delta_taus= delta_taus,
        num_episodes=num_episodes,
        sample_size=sample_size,
        optim=make_optimizer,
        lr=lr,
        print_every=print_every,
        lr_decay_factor=lr_decay_factor,
        lr_milestones=lr_milestones,
        baseline=baseline,
        max_norm=max_norm,
    )

Loss: 0.9314450072745212, epoch: 100/5000
Loss: 0.7960921502013795, epoch: 200/5000
Loss: 0.7872799096277983, epoch: 300/5000
Loss: 0.8325011470639797, epoch: 400/5000
Loss: 0.7338813111961876, epoch: 500/5000
Loss: 0.7022732128134827, epoch: 600/5000
Loss: 0.7002333483001402, epoch: 700/5000
Loss: 0.7252364850520611, epoch: 800/5000
Loss: 0.6913454929450438, epoch: 900/5000
Loss: 0.6086716956008829, epoch: 1000/5000
Loss: 0.6655050577757822, epoch: 1100/5000
Loss: 0.6533033800482182, epoch: 1200/5000
Loss: 0.6424524889437833, epoch: 1300/5000
Loss: 0.6014987500228728, epoch: 1400/5000
Loss: 0.572126823418777, epoch: 1500/5000
Loss: 0.6277729726763039, epoch: 1600/5000
Loss: 0.5935932283541143, epoch: 1700/5000
Loss: 0.5566144634880233, epoch: 1800/5000
Loss: 0.5597704675729076, epoch: 1900/5000
Loss: 0.5437495740875014, epoch: 2000/5000
Loss: 0.5524055871679504, epoch: 2100/5000
Loss: 0.5532123540687122, epoch: 2200/5000
Loss: 0.5479945799359544, epoch: 2300/5000
Loss: 0.5556080010027

In [29]:
# Re-run the experiment at the trained policy's `mu` and plot the outcome,
# with the resulting loss (rounded to 4 decimals) in the title.
exp.reset()
exp.run_all(pol.mu.detach().numpy(), delta_taus)
exp.print_analytes(rc=(7, 5), angle=45,  title=f'Solvent Strength function\nLoss: {round(exp.loss(),4)}')
# The legend is added before the figure is shown. Calling plt.legend() after
# plt.show() only works while show() is non-blocking; with a blocking show()
# the legend would be applied after the figure has already been displayed.
plt.legend()
# NOTE(review): `run` is not assigned in the visible cells; define it before
# re-enabling the savefig line. It stays ahead of plt.show() so the saved
# figure is the one being displayed.
#plt.savefig(f'results/{sample}_result_{run}.png')
plt.show()

<matplotlib.legend.Legend at 0x7f4a83916b10>

In [30]:
# Replay `best_program` (returned by `reinforce_one_set`) on a freshly reset
# experiment and draw it; the loss of that run goes into the figure title.
exp.reset()
exp.run_all(best_program, delta_taus)
best_run_title = f'Solvent Strength function\nLoss: {round(exp.loss(),4)}'
exp.print_analytes(
    title=best_run_title,
    rc=(7, 5),
    angle=45,
)
plt.legend()
#plt.savefig(f'results/{sample}_best_result_{run}.png')

<matplotlib.legend.Legend at 0x7f4a832aba90>

In [25]:
# One curve per plotted column of `mus` (returned by `reinforce_one_set`);
# the column index is the position of the label in the tuple.
for column, tag in enumerate(('Mu: phi1', 'Mu: phi2')):
    plt.plot(mus[:, column], label=tag)
#plt.plot(mus[:, 2], label='Mu: phi3')
#plt.plot(mus[:, 3], label='Mu: phi4')
plt.ylim(0, 1)
plt.legend()
#plt.savefig(f'results/{sample}_mus_{run}.png')

<matplotlib.legend.Legend at 0x7f4a83deb7d0>

In [26]:
# One curve per plotted column of `sigmas` (returned by `reinforce_one_set`);
# the y-axis is capped at `sigma_max`.
for column, tag in enumerate(('Sigma: phi1', 'Sigma: phi2')):
    plt.plot(sigmas[:, column], label=tag)
#plt.plot(sigmas[:, 2], label='Sigma: phi3')
#plt.plot(sigmas[:, 3], label='Sigma: phi4')
plt.ylim(0, sigma_max)
plt.legend()
#plt.savefig(f'results/{sample}_sigmas_{run}.png')

<matplotlib.legend.Legend at 0x7f4a83957c50>

In [None]:
# Learning curve: plot `losses`, then label both axes and title the figure.
plt.plot(losses)
for set_axis_label, text in ((plt.xlabel, "Episodes"), (plt.ylabel, "Loss")):
    set_axis_label(text)
plt.title("Loss during learning process")
#plt.savefig(f'results/{sample}_loss_{run}.png')

In [None]:
# Evaluate the loss on an N x N grid of (phi1, phi2) pairs and draw it as a
# filled contour plot. `loss_field` returns X (phi1), Y (phi2) and the losses.
# N = 500 means 250,000 experiment runs, so this cell is slow.
X, Y, Loss_field = loss_field(exp, [.25, 30000], N = 500)
plt.contourf(X, Y, Loss_field, levels=100)
plt.xlabel("Phi1")
plt.ylabel("Phi2")
# Grid position of the smallest loss, marked with a red dot.
index = np.unravel_index(np.argmin(Loss_field), Loss_field.shape)
plt.scatter(X[index], Y[index], c="r", s=5)
# Convert to built-in floats before formatting: a tuple of NumPy scalars is
# rendered with repr(), which under NumPy >= 2 prints `np.float64(...)` in the
# title instead of the plain numbers.
min_loss = round(float(np.min(Loss_field)), 3)
min_phi1 = round(float(X[index]), 4)
min_phi2 = round(float(Y[index]), 4)
# NOTE(review): the title says "change of phi at 0.5" while the first value
# passed to `loss_field` above is .25 - confirm which one is intended.
plt.title(f"Loss Field, change of phi at 0.5\nGlobal min ({min_loss}) at {(min_phi1, min_phi2)}, red dot")
#plt.savefig(f'results/{sample}_loss_filed_at_5.png')

In [None]:
# 3-D surface of the loss field with the last row and column dropped (N = -1).
# NOTE(review): `mlab` is not imported in the visible part of the notebook;
# presumably it is `mayavi.mlab` - confirm it is in scope before running.
N = -1
trimmed = slice(None, N)
mlab.surf(X.T[trimmed, trimmed], Y.T[trimmed, trimmed], Loss_field.T[trimmed, trimmed])

In [None]:

def _loss_for_single_phi(phi):
    """Reset `exp`, run it with the single value `phi` and `[10]`, return the loss."""
    exp.reset()
    exp.run_all([phi], [10])
    return exp.loss()


# Sweep 1000 evenly spaced phi values over [0, 1] and plot the resulting
# losses, with the y-axis restricted to (0, 2).
phis = np.linspace(0, 1, 1000)
losses = [_loss_for_single_phi(phi) for phi in phis]
plt.plot(list(losses))
plt.ylim(0, 2)

In [None]:
# Draw the loss field again as a filled contour plot with 100 levels.
# Requires `X`, `Y` and `Loss_field` from the earlier `loss_field(...)` cell.
plt.contourf(X, Y, Loss_field, levels=100)

In [None]:
# Separate experiment / policy pair for the `reinforce_delta_tau` run.
exp_time = ExperimentAnalytes(
    k0=alist.k0.values,
    S=alist.S.values,
    h=0.001,
    grad='iso',
    run_time=10.0,
)
pol_time = PolicyTime(2, sigma_max=.2)

# Keyword options forwarded unchanged to `reinforce_delta_tau`.
time_training_options = dict(
    num_episodes=5000,
    batch_size=10,
    optim=lambda a, b: torch.optim.SGD(a, b, momentum=0.65),
    lr=.1,
    print_every=100,
    lr_decay=lambda a, b, c: step_decay(a, b, c, steps=5, decay_factor=0.5),
    baseline=0.65,
    max_norm=2.,
)

# The five returned values are reused by the cells that follow.
(
    losses_time,
    best_program_time,
    mus_time,
    sigmas_time,
    n_par_time,
) = reinforce_delta_tau(exp_time, pol_time, **time_training_options)

In [None]:
# Run `exp_time` with the first two entries of `pol_time.mu` and fixed
# second argument [0.25, 3000000], draw it, and show the loss as cell output.
exp_time.reset()
leading_mu = pol_time.mu.detach().numpy()[:2]
exp_time.run_all(leading_mu, [0.25, 3000000])
exp_time.print_analytes(rc=(7, 5))
exp_time.loss()

In [None]:
# Replay the best program found by `reinforce_delta_tau`: its first two
# elements are passed, in order, as the two arguments of `run_all`.
exp_time.reset()
first_part = best_program_time[0]
second_part = best_program_time[1]
exp_time.run_all(first_part, second_part)
exp_time.print_analytes(rc=(7, 5))
exp_time.loss()

In [None]:
# Plot `losses_time`, the first value returned by `reinforce_delta_tau`.
plt.plot(losses_time)

In [None]:
# One curve for each of the first three columns of `mus_time`; the column
# index is the position of the label in the tuple.
for column, tag in enumerate(('Mu: 0', 'Mu: 1', 'Mu: 2')):
    plt.plot(mus_time[:, column], label=tag)
plt.legend()
plt.show()

In [None]:
# One curve for each of the first seven columns of `sigmas_time`; the column
# index is the position of the label in the tuple.
sigma_time_labels = (
    'Sigma: 0',
    'Sigma: 1',
    'Sigma: 2',
    'Sigma: 3',
    'Sigma: 4',
    'Sigma: 5',
    'Sigma: 6',
)
for column, tag in enumerate(sigma_time_labels):
    plt.plot(sigmas_time[:, column], label=tag)

plt.legend()
plt.show()