
\begin{align*}
fraction \, tumor \, cells &=1- \frac{ \nu_0 \pi_0 w_0}{\sum_k \nu_k \pi_k w_k}\\
&= 1- \frac{   w_0}{\sum_k \frac{\nu_+}{\nu_0} \frac{\pi_k}{\pi_0} w_k}
\end{align*}

\begin{align*}
\frac{1}{fraction \, healthy \, cells} -1 &=\sum_{k\neq 0} \frac{\nu_+}{\nu_0} \frac{\pi_k}{\pi_0} \frac{w_k}{w_0}
\end{align*}

We model

$$\log(\frac{\pi_k}{\pi_0}) = b_d + \textbf q_d^{\top} \textbf W \textbf x_k $$
and 
$$ \frac{\nu_+}{\nu_0}$$ is obtained by taking the average over control wells using the RNA as ground truth.

In [None]:
Bayesian = False
use_scatrex = True

import numpy as np
import pandas as pd
import math
import os
import matplotlib as mlt
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import pickle

import torch
import pyro
from torch.distributions import constraints
from pyro import poutine
from pyro.contrib.autoguide import AutoDelta
from pyro.optim import Adam
import pyro.distributions as dist
import torch.distributions as tdist
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate, infer_discrete, Predictive
%matplotlib inline
pyro.enable_validation(True)
import sys
sys.path.append('/data/users/quentin/Paper_nu_per_patient/')

import generative_all_drugs
from generative_all_drugs.data_preprocessing.get_real_data import *
from generative_all_drugs.visualization.vis_real_data import *

def plot_CredInt(x, mean, bottom, top, color='#2187bb', horizontal_line_width=0.25, alpha=1, linestyle='-'):
    left = x - horizontal_line_width / 2
    right = x + horizontal_line_width / 2
    plt.plot([x, x], [top, bottom], color=color, alpha=alpha, linestyle=linestyle)
    plt.plot([left, right], [top, top], color=color, alpha=alpha, linestyle=linestyle)
    plt.plot([left, right], [bottom, bottom], color=color, alpha=alpha, linestyle=linestyle)

In [None]:
data_train, data_test, selected_drugs, sample_names, sample2labels_subclones = get_real_data(total_genes=600, seed=13, use_pathways=True, test_size=0.2, use_scatrex_clusters=use_scatrex, compute_data=False, use_median_of_means=False)

In [None]:
Kmax, N, dim_pathways = data_train['X'].shape
D = data_train['D']
scores = torch.exp(data_train['log_scores'])

In [None]:
if True:
    ### 1) Using RNA proportions
    proportions = (data_train['n_rna'] / torch.sum(data_train['n_rna'], dim=0).unsqueeze(0)).T
else:
    ### 2) Correcting RNA proportions using control wells
    proportions_rna = (data_train['n_rna'] / torch.sum(data_train['n_rna'], dim=0).unsqueeze(0)).T
    proportions = torch.zeros((N,Kmax))
    for i in range(N):
        lsk =  [k for k,el in enumerate(data_train['masks']['C'][:,i]) if el]
        pi0fd = 1.-torch.mean(data_train['frac_c'][lsk,i])
        proportions[i,0] = pi0fd
        proportions[i,1:] = (1-pi0fd) * proportions_rna[i,1:] / torch.sum(proportions_rna[i,1:])

In [None]:
# Vectorize
vec_proportions = []
subclone_features = []
indexes_subclones = []
for i in range(N):
    for k in range(Kmax):
        vec_proportions.append(proportions[i,k])
        subclone_features.append(list(data_train['X'][k,i,:]))
        indexes_subclones.append(i)
subclone_features = torch.tensor(subclone_features).T
vec_proportions = torch.tensor(vec_proportions)
indexes_subclones = torch.tensor(indexes_subclones)

In [None]:
import torch.nn as nn
import torch.optim as optim

class FM(nn.Module):
    def __init__(self, D=None, dim_drug_feat=100, dim_pathways=39):
        super().__init__()
        self.bias = nn.Parameter(torch.ones(D))
        self.drug_feat = nn.Parameter(torch.ones(D,dim_drug_feat))
        self.W = nn.Parameter(torch.ones(dim_drug_feat, dim_pathways))
        self.normalize = dim_drug_feat * dim_pathways
 
    def forward(self, subclone_features):
        # return the ratio pi_k/pi_0 of shape:  D x (total subclones)
        # subclone_features: dim feat x (Kmax x Nsamples)
        return torch.exp(self.bias.unsqueeze(1) + self.drug_feat @ self.W @ subclone_features / self.normalize)


model = FM(D, dim_pathways=dim_pathways)
loss_fn = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.1, betas=(0.90, 0.999))

for step in range(100):
    # compute loss
    sub_scores = model(subclone_features) * vec_proportions
    hat_scores = torch.zeros((D,N))
    input = torch.zeros(N)
    for d in range(D):
        hat_scores[d,:] = input.scatter_reduce(0, indexes_subclones, sub_scores[d,:], reduce="sum")
    # hat_scores = torch.Tensor.scatter_reduce_(1, indexes_subclones, sub_scores, reduce="sum")
    loss = loss_fn(hat_scores, scores)        
    loss.backward()
    # take a step and zero the parameter gradients
    optimizer.step()
    optimizer.zero_grad()

    if step % 10 == 0:
        print('[iter {}]  loss: {:.4f}'.format(step, loss))

# LOO

In [1]:
Bayesian = False
theta_rna_per_sample = True
use_scatrex = True

import numpy as np
import pandas as pd
import math
import os
import matplotlib as mlt
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import pickle

import torch
import pyro
from torch.distributions import constraints
from pyro import poutine
from pyro.contrib.autoguide import AutoDelta
from pyro.optim import Adam
import pyro.distributions as dist
import torch.distributions as tdist
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate, infer_discrete, Predictive


import sys
sys.path.append("/data/users/quentin/Paper_nu_per_patient/save_LOO/")
sys.path.append('/data/users/quentin/Paper_nu_per_patient/')
import generative_all_drugs
from generative_all_drugs.naive_baseline.sampling import *
from generative_all_drugs.naive_baseline.sampling_test import *
from generative_all_drugs.naive_baseline.prior import *
from generative_all_drugs.data_preprocessing.get_real_data import *
from generative_all_drugs.svi.svi_optim import *
from generative_all_drugs.visualization.vis_real_data import *

def plot_CredInt(x, mean, bottom, top, color='#2187bb', horizontal_line_width=0.25, alpha=1, linestyle='-'):
    left = x - horizontal_line_width / 2
    right = x + horizontal_line_width / 2
    plt.plot([x, x], [top, bottom], color=color, alpha=alpha, linestyle=linestyle)
    plt.plot([left, right], [top, top], color=color, alpha=alpha, linestyle=linestyle)
    plt.plot([left, right], [bottom, bottom], color=color, alpha=alpha, linestyle=linestyle)

def split(a, n):
    k, m = divmod(len(a), n)
    return (a[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(n))





def run_model(data_train, data_test, mode_prop, nsteps_svi=2000):
    Kmax, N, dim_pathways = data_train['X'].shape
    D = data_train['D']
    scores = torch.exp(data_train['log_scores'])
    if mode_prop=='RNA':
        ### 1) Using RNA proportions
        proportions = (data_train['n_rna'] / torch.sum(data_train['n_rna'], dim=0).unsqueeze(0)).T
    else:
        ### 2) Correcting RNA proportions using control wells
        proportions_rna = (data_train['n_rna'] / torch.sum(data_train['n_rna'], dim=0).unsqueeze(0)).T
        proportions = torch.zeros((N,Kmax))
        for i in range(N):
            lsk =  [k for k,el in enumerate(data_train['masks']['C'][:,i]) if el]
            pi0fd = 1.-torch.mean(data_train['frac_c'][lsk,i])
            proportions[i,0] = pi0fd
            proportions[i,1:] = (1-pi0fd) * proportions_rna[i,1:] / torch.sum(proportions_rna[i,1:])
    
    # Vectorize
    vec_proportions = []
    subclone_features = []
    indexes_subclones = []
    for i in range(N):
        for k in range(Kmax):
            vec_proportions.append(proportions[i,k])
            subclone_features.append(list(data_train['X'][k,i,:]))
            indexes_subclones.append(i)
    subclone_features = torch.tensor(subclone_features).T
    vec_proportions = torch.tensor(vec_proportions)
    indexes_subclones = torch.tensor(indexes_subclones)

    import torch.nn as nn
    import torch.optim as optim

    class FM(nn.Module):
        def __init__(self, D=None, dim_drug_feat=100, dim_pathways=39):
            super().__init__()
            self.bias = nn.Parameter(torch.ones(D))
            self.drug_feat = nn.Parameter(torch.ones(D,dim_drug_feat))
            self.W = nn.Parameter(torch.ones(dim_drug_feat, dim_pathways))
            self.normalize = dim_drug_feat * dim_pathways

        def forward(self, subclone_features):
            # return the ratio pi_k/pi_0 of shape:  D x (total subclones)
            # subclone_features: dim feat x (Kmax x Nsamples)
            return torch.exp(self.bias.unsqueeze(1) + self.drug_feat @ self.W @ subclone_features / self.normalize)


    model = FM(D, dim_pathways=dim_pathways)
    loss_fn = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.1, betas=(0.90, 0.999))

    for step in range(400):
        # compute loss
        sub_scores = model(subclone_features) * vec_proportions
        hat_scores = torch.zeros((D,N))
        input = torch.zeros(N)
        for d in range(D):
            hat_scores[d,:] = input.scatter_reduce(0, indexes_subclones, sub_scores[d,:], reduce="sum")
        # hat_scores = torch.Tensor.scatter_reduce_(1, indexes_subclones, sub_scores, reduce="sum")
        loss = loss_fn(hat_scores, scores)        
        loss.backward()
        # take a step and zero the parameter gradients
        optimizer.step()
        optimizer.zero_grad()

        if step % 10 == 0:
            print('[iter {}]  loss: {:.4f}'.format(step, loss))

    # Saving the learned parameters
    params_svi = {}
    params_svi['bias'] = model.bias
    params_svi['drug_feat'] = model.drug_feat
    params_svi['W'] = model.W


    with open('/data/users/quentin/Paper_nu_per_patient/save_LOO/data_LOO_FM/LOO_prop_{0}_train_sample_{1}.pkl'.format(mode_prop, idsample), 'wb') as f:
        pickle.dump(params_svi, f)    

    sub_scores = model(subclone_features) * vec_proportions
    hat_scores_train = torch.zeros((D,N))
    input = torch.zeros(N)
    for d in range(D):
        hat_scores_train[d,:] = input.scatter_reduce(0, indexes_subclones, sub_scores[d,:], reduce="sum")
    log_hat_scores_train = torch.log(hat_scores_train)

    train_log_scores = data_train['log_scores'].numpy().reshape(-1)
    thresholds = np.linspace(0,np.max(np.sort(train_log_scores)[:-1]),100)
    lsy_train = []
    for th in thresholds:
        idxs = np.where(np.abs(train_log_scores)>=th)[0]
        lsy_train.append(np.sum( (log_hat_scores_train.detach().numpy().reshape(-1)*data_train['log_scores'].numpy().reshape(-1)>=0)[idxs] ) / len(idxs) )


        
        
        
    ## TEST
    Kmax, N, dim_pathways = data_test['X'].shape
    D = data_test['D']
    scores = torch.exp(data_test['log_scores'])
    
    if mode_prop=='RNA':
        ### 1) Using RNA proportions
        proportions = (data_test['n_rna'] / torch.sum(data_test['n_rna'], dim=0).unsqueeze(0)).T
    else:
        ### 2) Correcting RNA proportions using control wells
        proportions_rna = (data_test['n_rna'] / torch.sum(data_test['n_rna'], dim=0).unsqueeze(0)).T
        proportions = torch.zeros((N,Kmax))
        for i in range(N):
            lsk =  [k for k,el in enumerate(data_test['masks']['C'][:,i]) if el]
            pi0fd = 1.-torch.mean(data_test['frac_c'][lsk,i])
            proportions[i,0] = pi0fd
            proportions[i,1:] = (1-pi0fd) * proportions_rna[i,1:] / torch.sum(proportions_rna[i,1:])
    
    # Vectorize
    vec_proportions = []
    subclone_features = []
    indexes_subclones = []
    for i in range(N):
        for k in range(Kmax):
            vec_proportions.append(proportions[i,k])
            subclone_features.append(list(data_test['X'][k,i,:]))
            indexes_subclones.append(i)
    subclone_features = torch.tensor(subclone_features).T
    vec_proportions = torch.tensor(vec_proportions)
    indexes_subclones = torch.tensor(indexes_subclones)


    
    sub_scores = model(subclone_features) * vec_proportions
    hat_scores_test = torch.zeros((D,N))
    input = torch.zeros(N)
    for d in range(D):
        hat_scores_test[d,:] = input.scatter_reduce(0, indexes_subclones, sub_scores[d,:], reduce="sum")

    log_hat_scores_test = torch.log(hat_scores_test)
    

    lsy = []
    lenidxs = []
    for th in thresholds:
        idxs = np.where(np.abs(data_test['log_scores'].numpy().reshape(-1))>=th)[0]
        lenidxs.append(len(idxs))
        lsy.append(np.sum( (log_hat_scores_test.detach().numpy().reshape(-1)*data_test['log_scores'].numpy().reshape(-1)>=0)[idxs] ) / len(idxs) )

        
        
    dic = {}
    dic['thresholds'] = thresholds
    dic['sign_pred_train'] = lsy_train
    dic['sign_pred_test'] = lsy
    dic['score_pred_train'] = log_hat_scores_train.detach().numpy()
    dic['score_pred_test'] = log_hat_scores_test.detach().numpy()
    with open('/data/users/quentin/Paper_nu_per_patient/save_LOO/data_LOO_FM/LOO_scores_prop_{0}_train_sample_{1}.pkl'.format(mode_prop, idsample), 'wb') as f:
        pickle.dump(dic, f)    

        
data, sample2nb_cells_per_well, sample2nb_cells_per_well_control, S, Kmax, Nsample, R, C, D, sampleID2K, sample_names, selected_drugs, S2KS = get_real_data(total_genes=600, seed=13, use_pathways=True, use_scatrex_clusters=use_scatrex, compute_data=False, use_median_of_means=False, get_split=False)




for mode_prop in ['RNA','corr_RNA']:
    
    print('MODE  :', mode_prop)
    
    for idsample in range(Nsample):
        idxs_train = [i for i in range(Nsample) if i!=idsample]
        idxs_test = [idsample]
            
        data_train, data_test = get_real_data_split(data, sample2nb_cells_per_well, sample2nb_cells_per_well_control, S, Kmax, Nsample, R, C, D, idxs_train, idxs_test, sampleID2K, sample_names, selected_drugs, S2KS)
        
        run_model(data_train, data_test, mode_prop)

Number of drugs 29
Number of samples:  59
MODE  : RNA
[iter 0]  loss: 3.1330
[iter 10]  loss: 0.0408
[iter 20]  loss: 0.0579
[iter 30]  loss: 0.0777
[iter 40]  loss: 0.0478
[iter 50]  loss: 0.0209
[iter 60]  loss: 0.0174
[iter 70]  loss: 0.0180
[iter 80]  loss: 0.0167
[iter 90]  loss: 0.0166
[iter 100]  loss: 0.0166
[iter 110]  loss: 0.0165
[iter 120]  loss: 0.0165
[iter 130]  loss: 0.0165
[iter 140]  loss: 0.0164
[iter 150]  loss: 0.0164
[iter 160]  loss: 0.0164
[iter 170]  loss: 0.0164
[iter 180]  loss: 0.0164
[iter 190]  loss: 0.0163
[iter 200]  loss: 0.0163
[iter 210]  loss: 0.0163
[iter 220]  loss: 0.0163
[iter 230]  loss: 0.0163
[iter 240]  loss: 0.0162
[iter 250]  loss: 0.0162
[iter 260]  loss: 0.0162
[iter 270]  loss: 0.0162
[iter 280]  loss: 0.0161
[iter 290]  loss: 0.0161
[iter 300]  loss: 0.0161
[iter 310]  loss: 0.0160
[iter 320]  loss: 0.0160
[iter 330]  loss: 0.0160
[iter 340]  loss: 0.0160
[iter 350]  loss: 0.0159
[iter 360]  loss: 0.0159
[iter 370]  loss: 0.0159
[iter 3

  lsy.append(np.sum( (log_hat_scores_test.detach().numpy().reshape(-1)*data_test['log_scores'].numpy().reshape(-1)>=0)[idxs] ) / len(idxs) )


[iter 0]  loss: 3.1339
[iter 10]  loss: 0.0408
[iter 20]  loss: 0.0578
[iter 30]  loss: 0.0777
[iter 40]  loss: 0.0478
[iter 50]  loss: 0.0209
[iter 60]  loss: 0.0174
[iter 70]  loss: 0.0180
[iter 80]  loss: 0.0168
[iter 90]  loss: 0.0166
[iter 100]  loss: 0.0166
[iter 110]  loss: 0.0165
[iter 120]  loss: 0.0165
[iter 130]  loss: 0.0165
[iter 140]  loss: 0.0165
[iter 150]  loss: 0.0165
[iter 160]  loss: 0.0164
[iter 170]  loss: 0.0164
[iter 180]  loss: 0.0164
[iter 190]  loss: 0.0164
[iter 200]  loss: 0.0164
[iter 210]  loss: 0.0164
[iter 220]  loss: 0.0163
[iter 230]  loss: 0.0163
[iter 240]  loss: 0.0163
[iter 250]  loss: 0.0163
[iter 260]  loss: 0.0163
[iter 270]  loss: 0.0162
[iter 280]  loss: 0.0162
[iter 290]  loss: 0.0162
[iter 300]  loss: 0.0162
[iter 310]  loss: 0.0161
[iter 320]  loss: 0.0161
[iter 330]  loss: 0.0161
[iter 340]  loss: 0.0160
[iter 350]  loss: 0.0160
[iter 360]  loss: 0.0160
[iter 370]  loss: 0.0159
[iter 380]  loss: 0.0159
[iter 390]  loss: 0.0159
[iter 0]  l

[iter 120]  loss: 0.0165
[iter 130]  loss: 0.0164
[iter 140]  loss: 0.0164
[iter 150]  loss: 0.0164
[iter 160]  loss: 0.0164
[iter 170]  loss: 0.0164
[iter 180]  loss: 0.0164
[iter 190]  loss: 0.0164
[iter 200]  loss: 0.0163
[iter 210]  loss: 0.0163
[iter 220]  loss: 0.0163
[iter 230]  loss: 0.0163
[iter 240]  loss: 0.0163
[iter 250]  loss: 0.0163
[iter 260]  loss: 0.0162
[iter 270]  loss: 0.0162
[iter 280]  loss: 0.0162
[iter 290]  loss: 0.0162
[iter 300]  loss: 0.0161
[iter 310]  loss: 0.0161
[iter 320]  loss: 0.0161
[iter 330]  loss: 0.0160
[iter 340]  loss: 0.0160
[iter 350]  loss: 0.0160
[iter 360]  loss: 0.0159
[iter 370]  loss: 0.0159
[iter 380]  loss: 0.0159
[iter 390]  loss: 0.0158
[iter 0]  loss: 3.1344
[iter 10]  loss: 0.0408
[iter 20]  loss: 0.0578
[iter 30]  loss: 0.0777
[iter 40]  loss: 0.0478
[iter 50]  loss: 0.0209
[iter 60]  loss: 0.0174
[iter 70]  loss: 0.0180
[iter 80]  loss: 0.0167
[iter 90]  loss: 0.0166
[iter 100]  loss: 0.0166
[iter 110]  loss: 0.0165
[iter 120] 

[iter 250]  loss: 0.0163
[iter 260]  loss: 0.0163
[iter 270]  loss: 0.0163
[iter 280]  loss: 0.0162
[iter 290]  loss: 0.0162
[iter 300]  loss: 0.0162
[iter 310]  loss: 0.0162
[iter 320]  loss: 0.0161
[iter 330]  loss: 0.0161
[iter 340]  loss: 0.0161
[iter 350]  loss: 0.0160
[iter 360]  loss: 0.0160
[iter 370]  loss: 0.0160
[iter 380]  loss: 0.0159
[iter 390]  loss: 0.0159
[iter 0]  loss: 3.1331
[iter 10]  loss: 0.0408
[iter 20]  loss: 0.0579
[iter 30]  loss: 0.0777
[iter 40]  loss: 0.0478
[iter 50]  loss: 0.0209
[iter 60]  loss: 0.0174
[iter 70]  loss: 0.0180
[iter 80]  loss: 0.0167
[iter 90]  loss: 0.0166
[iter 100]  loss: 0.0166
[iter 110]  loss: 0.0165
[iter 120]  loss: 0.0165
[iter 130]  loss: 0.0165
[iter 140]  loss: 0.0164
[iter 150]  loss: 0.0164
[iter 160]  loss: 0.0164
[iter 170]  loss: 0.0164
[iter 180]  loss: 0.0164
[iter 190]  loss: 0.0164
[iter 200]  loss: 0.0163
[iter 210]  loss: 0.0163
[iter 220]  loss: 0.0163
[iter 230]  loss: 0.0163
[iter 240]  loss: 0.0162
[iter 250] 

[iter 0]  loss: 3.1379
[iter 10]  loss: 0.0410
[iter 20]  loss: 0.0576
[iter 30]  loss: 0.0775
[iter 40]  loss: 0.0478
[iter 50]  loss: 0.0209
[iter 60]  loss: 0.0173
[iter 70]  loss: 0.0180
[iter 80]  loss: 0.0167
[iter 90]  loss: 0.0166
[iter 100]  loss: 0.0165
[iter 110]  loss: 0.0165
[iter 120]  loss: 0.0164
[iter 130]  loss: 0.0164
[iter 140]  loss: 0.0164
[iter 150]  loss: 0.0164
[iter 160]  loss: 0.0164
[iter 170]  loss: 0.0164
[iter 180]  loss: 0.0163
[iter 190]  loss: 0.0163
[iter 200]  loss: 0.0163
[iter 210]  loss: 0.0163
[iter 220]  loss: 0.0163
[iter 230]  loss: 0.0163
[iter 240]  loss: 0.0162
[iter 250]  loss: 0.0162
[iter 260]  loss: 0.0162
[iter 270]  loss: 0.0162
[iter 280]  loss: 0.0161
[iter 290]  loss: 0.0161
[iter 300]  loss: 0.0161
[iter 310]  loss: 0.0161
[iter 320]  loss: 0.0160
[iter 330]  loss: 0.0160
[iter 340]  loss: 0.0160
[iter 350]  loss: 0.0159
[iter 360]  loss: 0.0159
[iter 370]  loss: 0.0159
[iter 380]  loss: 0.0158
[iter 390]  loss: 0.0158
[iter 0]  l

[iter 120]  loss: 0.0164
[iter 130]  loss: 0.0164
[iter 140]  loss: 0.0164
[iter 150]  loss: 0.0164
[iter 160]  loss: 0.0164
[iter 170]  loss: 0.0163
[iter 180]  loss: 0.0163
[iter 190]  loss: 0.0163
[iter 200]  loss: 0.0163
[iter 210]  loss: 0.0163
[iter 220]  loss: 0.0163
[iter 230]  loss: 0.0162
[iter 240]  loss: 0.0162
[iter 250]  loss: 0.0162
[iter 260]  loss: 0.0162
[iter 270]  loss: 0.0161
[iter 280]  loss: 0.0161
[iter 290]  loss: 0.0161
[iter 300]  loss: 0.0160
[iter 310]  loss: 0.0160
[iter 320]  loss: 0.0160
[iter 330]  loss: 0.0159
[iter 340]  loss: 0.0159
[iter 350]  loss: 0.0159
[iter 360]  loss: 0.0158
[iter 370]  loss: 0.0158
[iter 380]  loss: 0.0158
[iter 390]  loss: 0.0158
[iter 0]  loss: 3.1328
[iter 10]  loss: 0.0407
[iter 20]  loss: 0.0579
[iter 30]  loss: 0.0777
[iter 40]  loss: 0.0478
[iter 50]  loss: 0.0209
[iter 60]  loss: 0.0174
[iter 70]  loss: 0.0180
[iter 80]  loss: 0.0167
[iter 90]  loss: 0.0166
[iter 100]  loss: 0.0166
[iter 110]  loss: 0.0165
[iter 120] 

[iter 240]  loss: 0.0163
[iter 250]  loss: 0.0162
[iter 260]  loss: 0.0162
[iter 270]  loss: 0.0162
[iter 280]  loss: 0.0162
[iter 290]  loss: 0.0161
[iter 300]  loss: 0.0161
[iter 310]  loss: 0.0161
[iter 320]  loss: 0.0160
[iter 330]  loss: 0.0160
[iter 340]  loss: 0.0160
[iter 350]  loss: 0.0159
[iter 360]  loss: 0.0159
[iter 370]  loss: 0.0159
[iter 380]  loss: 0.0159
[iter 390]  loss: 0.0158
[iter 0]  loss: 3.1339
[iter 10]  loss: 0.0408
[iter 20]  loss: 0.0578
[iter 30]  loss: 0.0777
[iter 40]  loss: 0.0479
[iter 50]  loss: 0.0209
[iter 60]  loss: 0.0174
[iter 70]  loss: 0.0180
[iter 80]  loss: 0.0168
[iter 90]  loss: 0.0166
[iter 100]  loss: 0.0166
[iter 110]  loss: 0.0165
[iter 120]  loss: 0.0165
[iter 130]  loss: 0.0165
[iter 140]  loss: 0.0165
[iter 150]  loss: 0.0165
[iter 160]  loss: 0.0164
[iter 170]  loss: 0.0164
[iter 180]  loss: 0.0164
[iter 190]  loss: 0.0164
[iter 200]  loss: 0.0164
[iter 210]  loss: 0.0164
[iter 220]  loss: 0.0163
[iter 230]  loss: 0.0163
[iter 240] 

[iter 360]  loss: 0.0157
[iter 370]  loss: 0.0157
[iter 380]  loss: 0.0157
[iter 390]  loss: 0.0157
[iter 0]  loss: 3.1384
[iter 10]  loss: 0.0403
[iter 20]  loss: 0.0567
[iter 30]  loss: 0.0769
[iter 40]  loss: 0.0472
[iter 50]  loss: 0.0202
[iter 60]  loss: 0.0166
[iter 70]  loss: 0.0173
[iter 80]  loss: 0.0160
[iter 90]  loss: 0.0159
[iter 100]  loss: 0.0159
[iter 110]  loss: 0.0158
[iter 120]  loss: 0.0158
[iter 130]  loss: 0.0158
[iter 140]  loss: 0.0158
[iter 150]  loss: 0.0157
[iter 160]  loss: 0.0157
[iter 170]  loss: 0.0157
[iter 180]  loss: 0.0157
[iter 190]  loss: 0.0157
[iter 200]  loss: 0.0157
[iter 210]  loss: 0.0157
[iter 220]  loss: 0.0156
[iter 230]  loss: 0.0156
[iter 240]  loss: 0.0156
[iter 250]  loss: 0.0156
[iter 260]  loss: 0.0156
[iter 270]  loss: 0.0155
[iter 280]  loss: 0.0155
[iter 290]  loss: 0.0155
[iter 300]  loss: 0.0155
[iter 310]  loss: 0.0154
[iter 320]  loss: 0.0154
[iter 330]  loss: 0.0154
[iter 340]  loss: 0.0154
[iter 350]  loss: 0.0153
[iter 360] 

[iter 90]  loss: 0.0166
[iter 100]  loss: 0.0166
[iter 110]  loss: 0.0165
[iter 120]  loss: 0.0165
[iter 130]  loss: 0.0164
[iter 140]  loss: 0.0164
[iter 150]  loss: 0.0164
[iter 160]  loss: 0.0164
[iter 170]  loss: 0.0163
[iter 180]  loss: 0.0163
[iter 190]  loss: 0.0163
[iter 200]  loss: 0.0162
[iter 210]  loss: 0.0162
[iter 220]  loss: 0.0162
[iter 230]  loss: 0.0161
[iter 240]  loss: 0.0160
[iter 250]  loss: 0.0160
[iter 260]  loss: 0.0159
[iter 270]  loss: 0.0158
[iter 280]  loss: 0.0158
[iter 290]  loss: 0.0157
[iter 300]  loss: 0.0156
[iter 310]  loss: 0.0155
[iter 320]  loss: 0.0155
[iter 330]  loss: 0.0154
[iter 340]  loss: 0.0154
[iter 350]  loss: 0.0153
[iter 360]  loss: 0.0153
[iter 370]  loss: 0.0152
[iter 380]  loss: 0.0152
[iter 390]  loss: 0.0151
[iter 0]  loss: 3.1198
[iter 10]  loss: 0.0407
[iter 20]  loss: 0.0580
[iter 30]  loss: 0.0779
[iter 40]  loss: 0.0480
[iter 50]  loss: 0.0210
[iter 60]  loss: 0.0173
[iter 70]  loss: 0.0180
[iter 80]  loss: 0.0167
[iter 90]  

[iter 240]  loss: 0.0161
[iter 250]  loss: 0.0161
[iter 260]  loss: 0.0160
[iter 270]  loss: 0.0159
[iter 280]  loss: 0.0159
[iter 290]  loss: 0.0158
[iter 300]  loss: 0.0157
[iter 310]  loss: 0.0157
[iter 320]  loss: 0.0156
[iter 330]  loss: 0.0155
[iter 340]  loss: 0.0155
[iter 350]  loss: 0.0154
[iter 360]  loss: 0.0154
[iter 370]  loss: 0.0153
[iter 380]  loss: 0.0153
[iter 390]  loss: 0.0152
[iter 0]  loss: 3.1216
[iter 10]  loss: 0.0408
[iter 20]  loss: 0.0578
[iter 30]  loss: 0.0778
[iter 40]  loss: 0.0480
[iter 50]  loss: 0.0209
[iter 60]  loss: 0.0172
[iter 70]  loss: 0.0179
[iter 80]  loss: 0.0166
[iter 90]  loss: 0.0165
[iter 100]  loss: 0.0165
[iter 110]  loss: 0.0164
[iter 120]  loss: 0.0164
[iter 130]  loss: 0.0164
[iter 140]  loss: 0.0163
[iter 150]  loss: 0.0163
[iter 160]  loss: 0.0163
[iter 170]  loss: 0.0162
[iter 180]  loss: 0.0162
[iter 190]  loss: 0.0162
[iter 200]  loss: 0.0161
[iter 210]  loss: 0.0161
[iter 220]  loss: 0.0160
[iter 230]  loss: 0.0159
[iter 240] 

[iter 360]  loss: 0.0153
[iter 370]  loss: 0.0153
[iter 380]  loss: 0.0152
[iter 390]  loss: 0.0152
[iter 0]  loss: 3.1231
[iter 10]  loss: 0.0409
[iter 20]  loss: 0.0578
[iter 30]  loss: 0.0778
[iter 40]  loss: 0.0480
[iter 50]  loss: 0.0210
[iter 60]  loss: 0.0173
[iter 70]  loss: 0.0180
[iter 80]  loss: 0.0167
[iter 90]  loss: 0.0166
[iter 100]  loss: 0.0166
[iter 110]  loss: 0.0165
[iter 120]  loss: 0.0165
[iter 130]  loss: 0.0164
[iter 140]  loss: 0.0164
[iter 150]  loss: 0.0164
[iter 160]  loss: 0.0164
[iter 170]  loss: 0.0164
[iter 180]  loss: 0.0163
[iter 190]  loss: 0.0163
[iter 200]  loss: 0.0163
[iter 210]  loss: 0.0162
[iter 220]  loss: 0.0162
[iter 230]  loss: 0.0162
[iter 240]  loss: 0.0161
[iter 250]  loss: 0.0160
[iter 260]  loss: 0.0160
[iter 270]  loss: 0.0159
[iter 280]  loss: 0.0158
[iter 290]  loss: 0.0158
[iter 300]  loss: 0.0157
[iter 310]  loss: 0.0156
[iter 320]  loss: 0.0156
[iter 330]  loss: 0.0155
[iter 340]  loss: 0.0155
[iter 350]  loss: 0.0154
[iter 360] 

[iter 100]  loss: 0.0152
[iter 110]  loss: 0.0151
[iter 120]  loss: 0.0151
[iter 130]  loss: 0.0151
[iter 140]  loss: 0.0151
[iter 150]  loss: 0.0150
[iter 160]  loss: 0.0150
[iter 170]  loss: 0.0150
[iter 180]  loss: 0.0150
[iter 190]  loss: 0.0149
[iter 200]  loss: 0.0149
[iter 210]  loss: 0.0148
[iter 220]  loss: 0.0148
[iter 230]  loss: 0.0147
[iter 240]  loss: 0.0147
[iter 250]  loss: 0.0146
[iter 260]  loss: 0.0145
[iter 270]  loss: 0.0145
[iter 280]  loss: 0.0144
[iter 290]  loss: 0.0143
[iter 300]  loss: 0.0142
[iter 310]  loss: 0.0142
[iter 320]  loss: 0.0141
[iter 330]  loss: 0.0141
[iter 340]  loss: 0.0140
[iter 350]  loss: 0.0140
[iter 360]  loss: 0.0139
[iter 370]  loss: 0.0139
[iter 380]  loss: 0.0139
[iter 390]  loss: 0.0138
[iter 0]  loss: 3.1245
[iter 10]  loss: 0.0409
[iter 20]  loss: 0.0577
[iter 30]  loss: 0.0777
[iter 40]  loss: 0.0479
[iter 50]  loss: 0.0209
[iter 60]  loss: 0.0172
[iter 70]  loss: 0.0179
[iter 80]  loss: 0.0166
[iter 90]  loss: 0.0165
[iter 100] 

[iter 230]  loss: 0.0161
[iter 240]  loss: 0.0160
[iter 250]  loss: 0.0160
[iter 260]  loss: 0.0159
[iter 270]  loss: 0.0158
[iter 280]  loss: 0.0158
[iter 290]  loss: 0.0157
[iter 300]  loss: 0.0156
[iter 310]  loss: 0.0156
[iter 320]  loss: 0.0155
[iter 330]  loss: 0.0154
[iter 340]  loss: 0.0154
[iter 350]  loss: 0.0153
[iter 360]  loss: 0.0153
[iter 370]  loss: 0.0153
[iter 380]  loss: 0.0152
[iter 390]  loss: 0.0152
[iter 0]  loss: 3.1236
[iter 10]  loss: 0.0409
[iter 20]  loss: 0.0577
[iter 30]  loss: 0.0777
[iter 40]  loss: 0.0480
[iter 50]  loss: 0.0210
[iter 60]  loss: 0.0172
[iter 70]  loss: 0.0179
[iter 80]  loss: 0.0166
[iter 90]  loss: 0.0165
[iter 100]  loss: 0.0165
[iter 110]  loss: 0.0164
[iter 120]  loss: 0.0164
[iter 130]  loss: 0.0164
[iter 140]  loss: 0.0163
[iter 150]  loss: 0.0163
[iter 160]  loss: 0.0163
[iter 170]  loss: 0.0163
[iter 180]  loss: 0.0163
[iter 190]  loss: 0.0162
[iter 200]  loss: 0.0162
[iter 210]  loss: 0.0162
[iter 220]  loss: 0.0161
[iter 230] 

[iter 350]  loss: 0.0153
[iter 360]  loss: 0.0153
[iter 370]  loss: 0.0152
[iter 380]  loss: 0.0152
[iter 390]  loss: 0.0151
[iter 0]  loss: 3.1214
[iter 10]  loss: 0.0408
[iter 20]  loss: 0.0580
[iter 30]  loss: 0.0779
[iter 40]  loss: 0.0480
[iter 50]  loss: 0.0210
[iter 60]  loss: 0.0173
[iter 70]  loss: 0.0180
[iter 80]  loss: 0.0167
[iter 90]  loss: 0.0166
[iter 100]  loss: 0.0166
[iter 110]  loss: 0.0165
[iter 120]  loss: 0.0165
[iter 130]  loss: 0.0165
[iter 140]  loss: 0.0164
[iter 150]  loss: 0.0164
[iter 160]  loss: 0.0164
[iter 170]  loss: 0.0164
[iter 180]  loss: 0.0164
[iter 190]  loss: 0.0163
[iter 200]  loss: 0.0163
[iter 210]  loss: 0.0163
[iter 220]  loss: 0.0162
[iter 230]  loss: 0.0162
[iter 240]  loss: 0.0161
[iter 250]  loss: 0.0161
[iter 260]  loss: 0.0160
[iter 270]  loss: 0.0160
[iter 280]  loss: 0.0159
[iter 290]  loss: 0.0158
[iter 300]  loss: 0.0158
[iter 310]  loss: 0.0157
[iter 320]  loss: 0.0156
[iter 330]  loss: 0.0156
[iter 340]  loss: 0.0155
[iter 350] 

[iter 80]  loss: 0.0165
[iter 90]  loss: 0.0164
[iter 100]  loss: 0.0163
[iter 110]  loss: 0.0162
[iter 120]  loss: 0.0162
[iter 130]  loss: 0.0162
[iter 140]  loss: 0.0162
[iter 150]  loss: 0.0162
[iter 160]  loss: 0.0161
[iter 170]  loss: 0.0161
[iter 180]  loss: 0.0161
[iter 190]  loss: 0.0161
[iter 200]  loss: 0.0160
[iter 210]  loss: 0.0160
[iter 220]  loss: 0.0159
[iter 230]  loss: 0.0159
[iter 240]  loss: 0.0158
[iter 250]  loss: 0.0158
[iter 260]  loss: 0.0157
[iter 270]  loss: 0.0156
[iter 280]  loss: 0.0156
[iter 290]  loss: 0.0155
[iter 300]  loss: 0.0154
[iter 310]  loss: 0.0154
[iter 320]  loss: 0.0153
[iter 330]  loss: 0.0153
[iter 340]  loss: 0.0152
[iter 350]  loss: 0.0152
[iter 360]  loss: 0.0151
[iter 370]  loss: 0.0151
[iter 380]  loss: 0.0150
[iter 390]  loss: 0.0150
[iter 0]  loss: 3.1247
[iter 10]  loss: 0.0403
[iter 20]  loss: 0.0569
[iter 30]  loss: 0.0771
[iter 40]  loss: 0.0474
[iter 50]  loss: 0.0203
[iter 60]  loss: 0.0166
[iter 70]  loss: 0.0173
[iter 80]  

[iter 200]  loss: 0.0163
[iter 210]  loss: 0.0162
[iter 220]  loss: 0.0162
[iter 230]  loss: 0.0162
[iter 240]  loss: 0.0161
[iter 250]  loss: 0.0160
[iter 260]  loss: 0.0160
[iter 270]  loss: 0.0159
[iter 280]  loss: 0.0158
[iter 290]  loss: 0.0158
[iter 300]  loss: 0.0157
[iter 310]  loss: 0.0156
[iter 320]  loss: 0.0156
[iter 330]  loss: 0.0155
[iter 340]  loss: 0.0154
[iter 350]  loss: 0.0154
[iter 360]  loss: 0.0153
[iter 370]  loss: 0.0153
[iter 380]  loss: 0.0152
[iter 390]  loss: 0.0152
