In [1]:
import numpy as np
import pandas as pd
from chromatography import *
from torch import optim, tensor
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from mayavi import mlab
%matplotlib qt

In [16]:
isinstance(optim.SGD, optim.optimi)

AttributeError: module 'torch.optim' has no attribute 'optimizer'

In [2]:
alists = []
alists.append(pd.read_csv(f'../data/GilarSample.csv'))
alists.append(pd.read_csv(f'../data/Peterpeptides.csv'))
alists.append(pd.read_csv(f'../data/Roca.csv'))
alists.append(pd.read_csv(f'../data/Peter32.csv'))
alists.append(pd.read_csv(f'../data/Eosin.csv'))
alists.append(pd.read_csv(f'../data/Alizarin.csv'))
alists.append(pd.read_csv(f'../data/Controlmix2.csv'))


In [3]:
def step_decay(lr, iteration, num_episodes, steps=10, decay_factor=0.8):
    if iteration % (num_episodes // steps) == 0:
        return lr * decay_factor
    
    return lr

def loss_field(exp, taus, N = 200):
    phis = np.linspace(0, 1, N)
    losses = np.zeros((N, N))
    j = 0
    for phi1 in phis:
        i = 0
        for phi2 in phis:
            exp.reset()
            exp.run_all([phi1, phi2], taus)
            losses[i, j] = exp.loss()
            i += 1
        j += 1
    X, Y = np.meshgrid(phis, phis)
    
    return X, Y, losses

In [4]:
class Rho(nn.Module):
    def __init__(self, n_par, width, in_dim = 2, sigma_max = .3, sigma_min = .1):
        super().__init__()
        
        self.n_par = n_par
        self.width = width
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

        self.sig = nn.Sigmoid()
        self.fc_mu_1 = nn.Linear(in_dim, width)
        self.fc_mu_2 = nn.Linear(width, n_par)
        self.fc_sig_1 = nn.Linear(in_dim, width)
        self.fc_sig_2 = nn.Linear(width, n_par)
        
        
        
    def forward(self, x):
        
        x = torch.mean(x, dim=0, keepdim=True)
        mu = F.relu(self.fc_mu_1(x))
        sigma = F.relu(self.fc_sig_1(x))
        
        mu = self.sig(self.fc_mu_2(mu)).squeeze()
        # limit sigma to be in range (sigma_min; sigma_max)
        sigma = self.sig(self.fc_sig_2(sigma)).squeeze() * (self.sigma_max - self.sigma_min) + self.sigma_min
        return mu, sigma
    
    
class Perm_max(nn.Module):
    def __init__(self, in_dim, out_dim):
        super(Perm_max, self).__init__()
        self.Gamma = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        xm, _ = x.max(0, keepdim=True)
        x = self.Gamma(x-xm)
        return x

class Perm_max2(nn.Module):
    def __init__(self, in_dim, out_dim):
        super(Perm_max2, self).__init__()
        self.Gamma = nn.Linear(in_dim, out_dim)
        self.Lambda = nn.Linear(in_dim, out_dim, bias=False)

    def forward(self, x):
        xm, _ = x.max(0, keepdim=True)
        xm = self.Lambda(xm) 
        x = self.Gamma(x)
        x = x - xm
        return x

In [5]:
class PolicyGeneral(nn.Module):
    def __init__(self, n_par, width, sigma_min = .0, sigma_max = .1):
        super().__init__()
        
        self.n_par = n_par
        self.width = width
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

        self.phi = nn.Sequential(
            Perm_max(2, self.width),
            nn.ELU(inplace=True),
            Perm_max(self.width, self.width),
            nn.ELU(inplace=True),
            Perm_max(self.width, self.width),
            nn.ELU(inplace=True),
        )
        
        self.rho = Rho(n_par, width, width, sigma_max, sigma_min)
        
    def forward(self, x):
        phi_output = self.phi(x)
        sum_output = phi_output.mean(0, keepdim=True)
        mu, sigma = self.rho(sum_output)
        return mu, sigma

In [46]:
pol = PolicyGeneral(2, 4, sigma_max=.2, sigma_min=0.01)
losses = reinforce_gen(
    alists = alists, 
    policy = pol, 
    delta_taus = [.25, 10], 
    num_episodes = 100000, 
    batch_size = 10, 
    lr = .01, 
    optim = lambda a, b: torch.optim.SGD(a, b),
    print_every = 100,
    lr_decay = lambda a, b, c: step_decay(a, b, c, steps=5, decay_factor=0.5),
    weights = [1., 1.],
    baseline = .55,
    max_norm = None,
    beta = .0,
    max_rand_analytes = 30,
    min_rand_analytes = 10,
    rand_prob = 0.7
)

Loss: 0.9517968919291919, epoch: 100/100000
Loss: 0.8286087333226642, epoch: 200/100000
Loss: 0.7962744894328113, epoch: 300/100000
Loss: 0.7794659340417865, epoch: 400/100000
Loss: 0.8034790205576013, epoch: 500/100000
Loss: 0.7884183389961075, epoch: 600/100000
Loss: 0.8183371856131324, epoch: 700/100000
Loss: 0.815788352712508, epoch: 800/100000
Loss: 0.7654202813186363, epoch: 900/100000
Loss: 0.796601441774525, epoch: 1000/100000
Loss: 0.7846272875250858, epoch: 1100/100000
Loss: 0.7806360058094388, epoch: 1200/100000
Loss: 0.7864323380499451, epoch: 1300/100000
Loss: 0.8235284284464865, epoch: 1400/100000
Loss: 0.7949469918834671, epoch: 1500/100000
Loss: 0.7719088515149468, epoch: 1600/100000
Loss: 0.7885155902522067, epoch: 1700/100000
Loss: 0.7779781399957828, epoch: 1800/100000
Loss: 0.8126499979691035, epoch: 1900/100000
Loss: 0.8043515313807184, epoch: 2000/100000
Loss: 0.7880529595701143, epoch: 2100/100000
Loss: 0.7682002394598152, epoch: 2200/100000
Loss: 0.7935503627342

Loss: 1.824058130537591, epoch: 18300/100000
Loss: 1.8346887178449514, epoch: 18400/100000
Loss: 1.851647949192961, epoch: 18500/100000
Loss: 1.8439050272369297, epoch: 18600/100000
Loss: 1.8427559333809904, epoch: 18700/100000
Loss: 1.8275822877117667, epoch: 18800/100000
Loss: 1.8563109710963959, epoch: 18900/100000
Loss: 1.8313199446215296, epoch: 19000/100000
Loss: 1.8278554958245936, epoch: 19100/100000
Loss: 1.847866485006396, epoch: 19200/100000
Loss: 1.8248696647625036, epoch: 19300/100000
Loss: 1.8382038067031479, epoch: 19400/100000
Loss: 1.8484605454140821, epoch: 19500/100000
Loss: 1.830063433660245, epoch: 19600/100000
Loss: 1.8193169482948879, epoch: 19700/100000
Loss: 1.8512207368945588, epoch: 19800/100000
Loss: 1.8369087144298908, epoch: 19900/100000
Loss: 1.8390032738114142, epoch: 20000/100000
Loss: 1.8368189239778119, epoch: 20100/100000
Loss: 1.827690937168917, epoch: 20200/100000
Loss: 1.8450980858287211, epoch: 20300/100000
Loss: 1.8284427353716406, epoch: 20400/

Loss: 1.8393924387452159, epoch: 36300/100000
Loss: 1.8193829427470058, epoch: 36400/100000
Loss: 1.8214807060517602, epoch: 36500/100000
Loss: 1.8174374034984213, epoch: 36600/100000
Loss: 1.8529873321426948, epoch: 36700/100000
Loss: 1.8390445512781515, epoch: 36800/100000
Loss: 1.8402838508151869, epoch: 36900/100000
Loss: 1.8209437520633736, epoch: 37000/100000
Loss: 1.8304900836526519, epoch: 37100/100000
Loss: 1.8246442101461426, epoch: 37200/100000
Loss: 1.8369530468317448, epoch: 37300/100000
Loss: 1.8473084586551174, epoch: 37400/100000
Loss: 1.8186775465974627, epoch: 37500/100000
Loss: 1.8081709037810618, epoch: 37600/100000
Loss: 1.8331037272588149, epoch: 37700/100000
Loss: 1.8409601799964244, epoch: 37800/100000
Loss: 1.8262291341229175, epoch: 37900/100000
Loss: 1.8432004245097127, epoch: 38000/100000
Loss: 1.8417884230587822, epoch: 38100/100000
Loss: 1.8441378078268647, epoch: 38200/100000
Loss: 1.825837402814907, epoch: 38300/100000
Loss: 1.8419266523246345, epoch: 38

Loss: 1.8481836891807983, epoch: 54300/100000
Loss: 1.8529781060770973, epoch: 54400/100000
Loss: 1.8429221870176886, epoch: 54500/100000
Loss: 1.8390823503573939, epoch: 54600/100000
Loss: 1.8462997266564851, epoch: 54700/100000
Loss: 1.868120276357577, epoch: 54800/100000
Loss: 1.8213109199626336, epoch: 54900/100000
Loss: 1.8381295295342803, epoch: 55000/100000
Loss: 1.8260553726003192, epoch: 55100/100000
Loss: 1.818054586933688, epoch: 55200/100000
Loss: 1.841855507918359, epoch: 55300/100000
Loss: 1.828694719827766, epoch: 55400/100000
Loss: 1.8554658666536883, epoch: 55500/100000
Loss: 1.8369704985342814, epoch: 55600/100000
Loss: 1.8407709964710581, epoch: 55700/100000
Loss: 1.8336266302782982, epoch: 55800/100000
Loss: 1.813866064575018, epoch: 55900/100000
Loss: 1.829034286612139, epoch: 56000/100000
Loss: 1.8316903443650772, epoch: 56100/100000
Loss: 1.8452301212834668, epoch: 56200/100000
Loss: 1.8375925471178631, epoch: 56300/100000
Loss: 1.82882325338936, epoch: 56400/100

Loss: 1.8400752408366399, epoch: 72200/100000
Loss: 1.856341487382651, epoch: 72300/100000
Loss: 1.8528273199941128, epoch: 72400/100000
Loss: 1.8318295650805336, epoch: 72500/100000
Loss: 1.8265877268114927, epoch: 72600/100000
Loss: 1.8411159990550192, epoch: 72700/100000
Loss: 1.852376281044834, epoch: 72800/100000
Loss: 1.8381645589474926, epoch: 72900/100000
Loss: 1.8317964322535822, epoch: 73000/100000
Loss: 1.8388512101293164, epoch: 73100/100000
Loss: 1.8445310412592586, epoch: 73200/100000
Loss: 1.8146251928638975, epoch: 73300/100000
Loss: 1.839869785476546, epoch: 73400/100000
Loss: 1.8377842665148205, epoch: 73500/100000
Loss: 1.8225598825172953, epoch: 73600/100000
Loss: 1.8169196068507785, epoch: 73700/100000
Loss: 1.8320029011862964, epoch: 73800/100000
Loss: 1.8411005659008004, epoch: 73900/100000
Loss: 1.8366444722214497, epoch: 74000/100000
Loss: 1.8546235518337886, epoch: 74100/100000
Loss: 1.8388037974604958, epoch: 74200/100000
Loss: 1.8411873451105396, epoch: 7430

Loss: 1.8500338600050739, epoch: 90200/100000
Loss: 1.8112610361715098, epoch: 90300/100000
Loss: 1.8110238016183013, epoch: 90400/100000
Loss: 1.8350174014766363, epoch: 90500/100000
Loss: 1.8660333764936166, epoch: 90600/100000
Loss: 1.827662116938037, epoch: 90700/100000
Loss: 1.8461412530425327, epoch: 90800/100000
Loss: 1.8618883436580325, epoch: 90900/100000
Loss: 1.8271243530678263, epoch: 91000/100000
Loss: 1.8148145827630986, epoch: 91100/100000
Loss: 1.832729144695172, epoch: 91200/100000
Loss: 1.8427981224897647, epoch: 91300/100000
Loss: 1.8325107315963307, epoch: 91400/100000
Loss: 1.8592076928734125, epoch: 91500/100000
Loss: 1.8266366838215111, epoch: 91600/100000
Loss: 1.8378556040620484, epoch: 91700/100000
Loss: 1.8188866778194683, epoch: 91800/100000
Loss: 1.8372290415147847, epoch: 91900/100000
Loss: 1.8434397759971173, epoch: 92000/100000
Loss: 1.8404406738424526, epoch: 92100/100000
Loss: 1.8423103278832285, epoch: 92200/100000
Loss: 1.8230337702588406, epoch: 923

In [18]:
plt.plot(np.arange(0, 1000, 10), losses)

[<matplotlib.lines.Line2D at 0x7efe15956950>]

In [8]:
plt.plot(mus[:, 0], label='Mu: phi1')
plt.plot(mus[:, 1], label='Mu: phi2')
#plt.plot(mus[:, 2], label='Mu: phi3')
#plt.plot(mus[:, 3], label='Mu: phi4')
plt.ylim((-0.1,1.1))
plt.legend()

NameError: name 'mus' is not defined

In [None]:
plt.plot(sigmas[:, 0], label='Sigma: phi1')
plt.plot(sigmas[:, 1], label='Sigma: phi2')
#plt.plot(sigmas[:, 2], label='Sigma: phi3')
#plt.plot(sigmas[:, 3], label='Sigma: phi4')
plt.ylim((-0.01,0.3))
plt.legend()

In [44]:
i = 8
exp = ExperimentAnalytes(k0 = alists[i].k0.values, S = alists[i].S.values, h=0.001,run_time=10.0)

In [45]:
import torch
mu, sig = pol(torch.Tensor(alists[i][['S', 'lnk0']].values))
exp.run_all(mu.detach().numpy(), [.25, 10])

exp.print_analytes(title=f"Solvent Strength Program\nLoss:{round(exp.loss(), 4)}", rc=(10,10), angle=40)

In [43]:
alists.append( pd.concat(alists, sort=True).sample(30))

In [None]:
L = loss_field(exp, [.25, 10], 300)

In [None]:
grads_mu_2

In [None]:
grads_sig_2

In [None]:
grads_3

In [None]:
grads_1

In [None]:
t = torch.randn((2, 3, 5))
a, b = t.max(0, keepdim=True)
t, a, b.shape