### NOTE: must set up GEARS via this repo to run this notebook: https://github.com/yhr91/GEARS_misc

In [None]:
import argparse
import sys
sys.path.append('../../GEARS_misc/') # add GEARS_misc to path

import pandas as pd
import matplotlib.pyplot as plt 
import seaborn as sns
from scipy.stats import pearsonr

from gears import PertData, GEARS

results_folder = './results/'
dataset = 'jurkat'
data_path = './'
model = 'gears'
device = 'cuda'
seed = 10

%load_ext autoreload
%autoreload 2

def load_data(seed):
    pert_data = PertData(data_path, gi_go=True) # specific saved folder
    # pert_data = PertData(data_path, gene_path='../../data/gears_reproduce/essential_norman.pkl')
    pert_data.load(data_path= data_path+dataset) # load the processed data, the path is saved folder + dataset_name
    # pert_data.load(data_name = 'norman') # load the processed data, the path is saved folder + dataset_name
    pert_data.prepare_split(split = 'simulation', seed = seed)
    pert_data.get_dataloader(batch_size = 256, test_batch_size = 256)
    
    adata_df = pert_data.adata.to_df()
    adata_df['condition'] = pert_data.adata.obs['condition']
    mean_df = adata_df.groupby('condition').mean()
    ctrl_mean = mean_df.loc['ctrl']
    
    return pert_data, ctrl_mean, mean_df

def load_model(pert_data):
    gears_model = GEARS(pert_data, device = 'cuda', 
                weight_bias_track = False, 
                proj_name = 'gears', 
                exp_name = 'gears',
                gi_predict = True)
    gears_model.model_initialize(hidden_size = 256)
    
    return gears_model

    

In [2]:
pert_data, _, _ = load_data(seed)
gears_model = load_model(pert_data)

Found local copy...


These perturbations are not in the GO graph and is thus not able to make prediction for...
['AARS2+ADPRM' 'AARS2+ASNA1' 'AARS2+ATP5F1' ... 'USMG5+ZNRD1'
 'USMG5+ZWINT' 'USMG5+ctrl']
Local copy of pyg dataset is detected. Loading...
Done!
Local copy of split is detected. Loading...
Simulation split test composition:
combo_seen0:5460
combo_seen1:32240
combo_seen2:12052
unseen_single:104
Done!
Creating dataloaders....
Done!


In [3]:
pert_data.adata.obs

Unnamed: 0,condition,dose_val,control,condition_name
0,AARS2+AARS2,1+1,1,Jurkat_AARS2+AARS2_1+1
1,AARS2+AATF,1+1,1,Jurkat_AARS2+AATF_1+1
2,AARS2+ABCB7,1+1,1,Jurkat_AARS2+ABCB7_1+1
3,AARS2+ACTL6A,1+1,1,Jurkat_AARS2+ACTL6A_1+1
4,AARS2+ACTR10,1+1,1,Jurkat_AARS2+ACTR10_1+1
...,...,...,...,...
96236,ctrl,1+1,1,Jurkat_ctrl_1+1
96237,ctrl,1+1,1,Jurkat_ctrl_1+1
96238,ctrl,1+1,1,Jurkat_ctrl_1+1
96239,ctrl,1+1,1,Jurkat_ctrl_1+1


In [4]:
import pandas as pd
import pickle
import numpy as np

In [5]:
pert_data, ctrl_mean, mean_df = load_data(seed)
gears_model = load_model(pert_data)

Found local copy...


These perturbations are not in the GO graph and is thus not able to make prediction for...
['AARS2+ADPRM' 'AARS2+ASNA1' 'AARS2+ATP5F1' ... 'USMG5+ZNRD1'
 'USMG5+ZWINT' 'USMG5+ctrl']
Local copy of pyg dataset is detected. Loading...
Done!
Local copy of split is detected. Loading...
Simulation split test composition:
combo_seen0:5460
combo_seen1:32240
combo_seen2:12052
unseen_single:104
Done!
Creating dataloaders....
Done!


In [6]:
### turn the model on 
model = gears_model.model
model.cell_fitness_pred = True

In [7]:
adata = gears_model.adata
ctrl_adata = adata[adata.obs['condition'] == 'ctrl']

In [8]:
len(gears_model.pert_list)

9853

In [9]:
all_pert_list_options = gears_model.pert_list

In [10]:
import pickle
with open(f'../../data/horlbeck/{dataset}/GI_data_horlbeck_{dataset}.pkl', 'rb') as f:
    fitness_data = pickle.load(f)

In [11]:
fitness_data.keys()

dict_keys(['good_phen', 'fitness_mapper_gene', 'good_genes'])

In [12]:
unique_pert_genes = fitness_data['good_phen'].index.values

In [13]:
from itertools import combinations
pert_list = [i[0] + '+' + i[1] for i in list(combinations(fitness_data['good_phen'].index.values, 2))]

In [14]:
def convert(i):
    if i.split('_')[0] == i.split('_')[1]:
        return i.split('_')[0]
    #elif 'neg' in i:
    #    print(i)
    else:
        return i

In [15]:
import pandas as pd
import numpy as np
import pickle

from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline

class GI_expt():
    # TODO get rid of hold out genes
    def __init__(self, raw=False, pick_genes = None, set_train = None):
        # sample: matrix sampling percentage
        # seed: sampling seed,
        # sampled_data: sampled indices from upper traingle matrix
        # test: matrix values to be predicted
        # train: observed matrix values used for training, 
        # delta: deviation from expectation as inferred using 
        #        a regression model trained on sampled_data (corresponds to sampled_data)
        # delta_tot: delta values computed for all matrix values
        # y_tot: Full true matrix (raw fitness values)
        # raws: Sampled raw fitness values (corresponds to sampled_data)
        # transformer: regression model used to determine expected fitness values
        

        # with open('/dfs/project/perturb-gnn/datasets/Horlbeck2018/GI_data_horlbeck_jurkat.pkl', 'rb') as f:
        # with open('../../data/horlbeck/jurkat/GI_data_horlbeck_jurkat.pkl', 'rb') as f:
        with open(f'../../data/horlbeck/{dataset}/GI_data_horlbeck_{dataset}.pkl', 'rb') as f:
            self.data = pickle.load(f)
            
        self.data['fitness_mapper_gene'] = {convert(i): j for i,j in self.data['fitness_mapper_gene'].items()}
        self.raw = raw
        self.pick_genes = pick_genes
        self.set_train = set_train

    def pre_process(self, sample, seed):
        itr_data = {}
        itr_data['sample'] = sample
        itr_data['seed'] = seed

        # Randomly sample from the fitness matrix
        if self.set_train is not None:
            itr_data['sampled_data'] = self.set_train
        elif self.pick_genes is None:
            itr_data['sampled_data'] = \
               upper_triangle(self.data['good_phen'], k=0).sample(
               frac=float(sample) / 100., replace=False, random_state=seed).index
        else:
            itr_data['sampled_data'] = \
                upper_triangle(self.data['good_phen'].iloc[self.pick_genes,
                                                    self.pick_genes]).index

        # Set up masked delta matrix for performing matrix completion
        itr_data['delta'], itr_data['test'], \
        itr_data['train'], itr_data['delta_tot'], \
        itr_data['X_tot'], itr_data['y_tot'], itr_data['transformer'], \
        itr_data['raws'] = \
            get_masked_delta_matrix(self.data['good_phen'],
                                    itr_data['sampled_data'],
                                    self.data['fitness_mapper_gene'],
                                    self.data['good_genes'])

        return itr_data


def get_masked_delta_matrix(data, sampling, phen_mapper, good_genes):
    masked_data, mask = get_masked_data(data, sampling)

    upper_masked_data = upper_triangle(masked_data)
    predicted = upper_masked_data[upper_masked_data == 0].index
    given = upper_masked_data[upper_masked_data != 0].index

    delta_tot, X, X_tot, y, y_tot, transformer, raws = get_deltas(data, masked_data, phen_mapper, good_genes)
    delta, _ = get_masked_data(delta_tot, sampling)
    raws, _ = get_masked_data(raws, sampling)

    return delta, predicted, given, delta_tot, X_tot, y_tot, transformer, raws

def get_deltas(source_data, sampled_data, phen_mapper, good_genes):
    # this is the routine that fits a quadratic model to observed fitness measurements and returns
    # the deviations from the expectation given by that model (the "deltas")
    y = sampled_data.stack()
    y = y[y != 0]

    X = np.concatenate([y.index.get_level_values(0).map(lambda x: phen_mapper[x]).values[:, np.newaxis],
                        y.index.get_level_values(1).map(lambda x: phen_mapper[x]).values[:, np.newaxis]], axis=1)

    y_tot = source_data.loc[good_genes, good_genes].stack()
    X_tot = np.concatenate([y_tot.index.get_level_values(0).map(lambda x: phen_mapper[x]).values[:, np.newaxis],
                            y_tot.index.get_level_values(1).map(lambda x: phen_mapper[x]).values[:, np.newaxis]],
                           axis=1)

    transformer = PolynomialFeatures()
    model = make_pipeline(PolynomialFeatures(), LinearRegression())
    model.fit(X, y)

    delta = pd.Series(y_tot - model.predict(X_tot), index=y_tot.index)
    raws = pd.Series(y_tot, index=y_tot.index).unstack()
    delta = delta.unstack()

    return delta, X, X_tot, y, y_tot, model, raws

def upper_triangle(M, k=1):
    """ Copyright (C) 2019  Thomas Norman
    Return the upper triangular part of a matrix in stacked format (i.e. as a vector)
    """
    keep = np.triu(np.ones(M.shape), k=k).astype('bool').reshape(M.size)
    return M.stack(dropna=False).loc[keep]

def get_masked_data(df, ind, mean_normalize=False):
    masked_data = df.copy().values
    mask = pd.DataFrame(0, index=df.index, columns=df.columns)

    for gene1, gene2 in ind:
        mask.loc[gene1, gene2] = 1

        # assume DataFrame is symmetric
    mask = mask + mask.T
    mask = (mask != 0).values
    masked_data[~mask] = 0
    masked_data_df = pd.DataFrame(masked_data, index=df.index, columns=df.columns)

    # whether to center the observed entries such that the overall mean is 0
    if mean_normalize:
        masked_data_df = masked_data_df.stack()
        offset = masked_data_df[masked_data_df != 0]
        offset = offset.mean()
        masked_data_df[masked_data_df != 0] = masked_data_df[masked_data_df != 0] - offset
        print(offset)
        masked_data_df = masked_data_df.unstack()

    return masked_data_df, mask

In [17]:
# Run prediction
from torch_geometric.data import DataLoader

import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR
from copy import deepcopy

def evaluate(model, loader):
    model.eval()
    y_all = []
    loss_all = []
    pert_all = []
    y_true_all = []
    for step, batch in enumerate(tqdm(loader)):
        batch = batch.to(device)
        out, y_pred = model(batch)
        loss = loss_fct(batch.y.reshape(-1), y_pred.reshape(-1))
        loss_all.append(loss.item())
        y_all.append(y_pred.detach().cpu().numpy())
        pert_all.append(batch.pert)
        y_true_all.append(batch.y)
    return y_all, loss_all, pert_all, y_true_all

def get_dataloader(set2cond, batch_size, test_batch_size = None):
    if test_batch_size is None:
        test_batch_size = batch_size

    cell_graphs = {}
    

    splits = ['train','val','test']
    for i in splits:
        cell_graphs[i] = []
        for p in set2cond[i]:
            if p in all_cell_graphs:
                cell_graphs[i].extend(all_cell_graphs[p])

    # Set up dataloaders
    train_loader = DataLoader(cell_graphs['train'],
                        batch_size=batch_size, shuffle=True, drop_last = True)
    val_loader = DataLoader(cell_graphs['val'],
                        batch_size=batch_size, shuffle=True)

    test_loader = DataLoader(cell_graphs['test'],
                    batch_size=batch_size, shuffle=False)
    return {'train_loader': train_loader,
                        'val_loader': val_loader,
                        'test_loader': test_loader}

# for seed in [123, 847, 618, 748, 808]:
n_train = 96141
n_samples = [100, 500, 1000, 2000, 3000, 4000, 5000, 6000]
for seed in [123, 847, 618]:
    for n_sample in n_samples:
        for label in ['raw_y']:
            sample_frac = float(n_sample) / float(n_train) * 100

            GI = GI_expt(pick_genes=None)
            itr_data = GI.pre_process(sample=sample_frac, seed=seed)

            itr_data['raw_y'] = fitness_data['good_phen']
            y = itr_data[label].values

            gene2idx = dict(zip(itr_data[label].index.values, range(len(itr_data[label].index.values))))

            pert_list = [i[0] + '+' + i[1] for i in list(combinations(itr_data[label].index.values, 2))]
            pert_list += [i + '+ctrl' for i in unique_pert_genes]

            y_list = []
            for i in pert_list:
                if i.split('+')[1]!= 'ctrl':
                    y_list.append(y[gene2idx[i.split('+')[0]], gene2idx[i.split('+')[1]]])
                else:
                    y_list.append(y[gene2idx[i.split('+')[0]], gene2idx[i.split('+')[0]]])

            pert_list2y_list = dict(zip(pert_list, y_list))

            train = [i[0] + '+' + i[1] for i in itr_data['train']]
            test = [i[0] + '+' + i[1] for i in itr_data['test']]

            from sklearn.model_selection import train_test_split
            train, val, _, _ = train_test_split(train, [-1] * len(train), test_size=0.1, random_state=seed)

            ## add single back
            train = train + [i + '+ctrl' for i in unique_pert_genes]

            set2cond = {
                'train': train,
                'val': val,
                'test': test
            }

            from torch_geometric.data import Data
            import torch
            from tqdm import tqdm
            num_samples = 20
            all_cell_graphs = {}
            pert_na = []
            for pert in tqdm(pert_list):
                pert_name = pert
                pert = pert.split('+')
                pert = [i for i in pert if i!='ctrl']

                # Get the indices (and signs) of applied perturbation
                try:
                    pert_idx = [np.where(p == np.array(all_pert_list_options))[0][0] for p in pert]
                    Xs = ctrl_adata[np.random.randint(0, len(ctrl_adata), num_samples), :].X.toarray()
                    # Create cell graphs
                    all_cell_graphs[pert_name] = [Data(x=torch.Tensor(X).T, pert_idx = pert_idx, pert=pert, y=pert_list2y_list[pert_name]) for X in Xs]
                except:
                    pert_na.append(pert_name)

            print(len(pert_na))
            loaders = get_dataloader(set2cond, batch_size = 256)
            train_loader = loaders['train_loader']
            val_loader = loaders['val_loader']
            test_loader = loaders['test_loader']

            optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay = 5e-4)
            scheduler = StepLR(optimizer, step_size=1, gamma=0.5)

            loss_fct = F.mse_loss
            earlystop_validation_metric = 'mse'
            binary_output = False

            earlystop_direction = 'descend'
            min_val = np.inf

            best_model = deepcopy(model).to(device)

            loss_history = {
                'loss': []
            }

            print('Start Training...')
            for epoch in range(1):
                model.train()

                for step, batch in enumerate(tqdm(train_loader)):
                    optimizer.zero_grad()
                    batch = batch.to(device)
                    out, y_pred = model(batch)
                    loss = loss_fct(batch.y.float().reshape(-1), y_pred.reshape(-1))
                    loss.backward()
                    optimizer.step()
                    loss_history['loss'].append(loss.item())

                    if (step % 200 == 0) and (step >= 200):
                        log = "Epoch {} Step {} Train Loss: {:.4f}" 
                        print(log.format(epoch + 1, step + 1, loss.item()))
                #_, val_loss
                # val_y_all, val_loss_all, val_pert_all, val_y_true_all = evaluate(model, val_loader)
                # if np.mean(val_loss_all) < min_val:
                #     best_model = deepcopy(model)
                #     min_val = np.mean(val_loss_all)
                best_model = deepcopy(model)
            test_y_all, test_loss_all, test_pert_all, test_y_true_all = evaluate(best_model, test_loader)
            print(np.mean(test_loss_all))


            pert_test_list = [i[0] + '+' +  i[1] if len(i) == 2 else i[0] +'+ctrl' for i in [j for i in test_pert_all for j in i]]
            test_y_list = np.concatenate([i for i in [j for i in test_y_all for j in i]])
            test_y_true_list = np.concatenate([i.detach().cpu().numpy() for i in test_y_true_all])
            df_pred = pd.DataFrame((pert_test_list, test_y_true_list, test_y_list)).T
            pert2y_true = dict(df_pred.groupby([0])[1].agg(np.mean))
            pert2y_pred = dict(df_pred.groupby([0])[2].agg(np.mean))

            pred = pd.DataFrame(pert2y_true, index = [0]).T.reset_index().rename(columns = {0: 'truth', 'index': 'test_pert'})
            pred['pred'] = pred.test_pert.apply(lambda x: pert2y_pred[x])
            pred.to_csv(results_folder + label + '_pred_no_pretrain_seed' + str(seed) + '_nsample' + str(n_sample) + dataset + '_gi.csv', index = False)

100%|██████████| 96141/96141 [07:54<00:00, 202.64it/s]


10236
Start Training...


100%|██████████| 38/38 [00:07<00:00,  5.10it/s]
100%|██████████| 6673/6673 [04:21<00:00, 25.56it/s]


0.08886251469980647


100%|██████████| 96141/96141 [07:46<00:00, 206.30it/s] 


10236
Start Training...


100%|██████████| 63/63 [00:11<00:00,  5.30it/s]
100%|██████████| 6645/6645 [04:29<00:00, 24.65it/s]


0.17974118719429372


100%|██████████| 96141/96141 [08:00<00:00, 199.96it/s] 


10236
Start Training...


100%|██████████| 94/94 [00:18<00:00,  5.19it/s]
100%|██████████| 6610/6610 [04:33<00:00, 24.19it/s]


0.04854963337285911


100%|██████████| 96141/96141 [07:48<00:00, 205.42it/s] 


10236
Start Training...


100%|██████████| 157/157 [00:30<00:00,  5.09it/s]
100%|██████████| 6540/6540 [04:36<00:00, 23.64it/s]


0.15253699653033126


100%|██████████| 96141/96141 [08:04<00:00, 198.51it/s] 


10236
Start Training...


 92%|█████████▏| 202/220 [00:40<00:03,  5.05it/s]

Epoch 1 Step 201 Train Loss: 0.1325


100%|██████████| 220/220 [00:43<00:00,  5.03it/s]
100%|██████████| 6470/6470 [04:31<00:00, 23.85it/s]


0.0422433259099981


100%|██████████| 96141/96141 [07:51<00:00, 203.75it/s] 


10236
Start Training...


 71%|███████   | 202/284 [00:40<00:16,  5.06it/s]

Epoch 1 Step 201 Train Loss: 0.0244


100%|██████████| 284/284 [00:56<00:00,  5.03it/s]
100%|██████████| 6400/6400 [04:29<00:00, 23.77it/s]


0.10960124078263292


100%|██████████| 96141/96141 [08:17<00:00, 193.10it/s] 


10236
Start Training...


 59%|█████▊    | 202/345 [00:40<00:28,  5.02it/s]

Epoch 1 Step 201 Train Loss: 0.0013


100%|██████████| 345/345 [01:09<00:00,  4.99it/s]
100%|██████████| 6333/6333 [04:23<00:00, 24.03it/s]


0.06464982658337151


100%|██████████| 96141/96141 [07:53<00:00, 202.94it/s] 


10236
Start Training...


 49%|████▉     | 202/409 [00:40<00:41,  5.01it/s]

Epoch 1 Step 201 Train Loss: 0.0014


 98%|█████████▊| 402/409 [01:20<00:01,  5.01it/s]

Epoch 1 Step 401 Train Loss: 0.0010


100%|██████████| 409/409 [01:21<00:00,  5.01it/s]
100%|██████████| 6261/6261 [04:15<00:00, 24.53it/s]


0.0856231907167668


100%|██████████| 96141/96141 [08:02<00:00, 199.38it/s] 


10236
Start Training...


100%|██████████| 38/38 [00:06<00:00,  5.45it/s]
100%|██████████| 6673/6673 [04:49<00:00, 23.07it/s]


0.0039924451936962264


100%|██████████| 96141/96141 [07:55<00:00, 202.01it/s] 


10236
Start Training...


100%|██████████| 63/63 [00:12<00:00,  5.19it/s]
100%|██████████| 6646/6646 [04:46<00:00, 23.18it/s] 


0.007531483741503366


100%|██████████| 96141/96141 [08:11<00:00, 195.70it/s] 


10236
Start Training...


100%|██████████| 94/94 [00:18<00:00,  5.17it/s]
100%|██████████| 6610/6610 [04:29<00:00, 24.51it/s]


0.002746798034427577


100%|██████████| 96141/96141 [08:09<00:00, 196.49it/s] 


10236
Start Training...


100%|██████████| 157/157 [00:31<00:00,  5.05it/s]
100%|██████████| 6540/6540 [04:27<00:00, 24.43it/s]


0.0032517762077622476


100%|██████████| 96141/96141 [08:20<00:00, 192.13it/s] 


10236
Start Training...


 91%|█████████▏| 202/221 [00:40<00:03,  5.01it/s]

Epoch 1 Step 201 Train Loss: 0.0013


100%|██████████| 221/221 [00:43<00:00,  5.02it/s]
100%|██████████| 6470/6470 [04:28<00:00, 24.07it/s]


0.0067860224141509965


100%|██████████| 96141/96141 [08:01<00:00, 199.72it/s] 


10236
Start Training...


 71%|███████   | 202/285 [00:40<00:16,  5.00it/s]

Epoch 1 Step 201 Train Loss: 0.0009


100%|██████████| 285/285 [00:57<00:00,  4.98it/s]
100%|██████████| 6399/6399 [04:31<00:00, 23.56it/s]


0.008035637553658881


100%|██████████| 96141/96141 [08:09<00:00, 196.34it/s] 


10236
Start Training...


 58%|█████▊    | 201/349 [00:40<00:29,  5.01it/s]

Epoch 1 Step 201 Train Loss: 0.0012


100%|██████████| 349/349 [01:12<00:00,  4.82it/s]
100%|██████████| 6328/6328 [04:39<00:00, 22.65it/s]


0.0031634432853144987


100%|██████████| 96141/96141 [08:01<00:00, 199.66it/s] 


10236
Start Training...


 49%|████▉     | 202/411 [00:40<00:41,  4.99it/s]

Epoch 1 Step 201 Train Loss: 0.0011


 98%|█████████▊| 402/411 [01:20<00:01,  5.00it/s]

Epoch 1 Step 401 Train Loss: 0.0012


100%|██████████| 411/411 [01:22<00:00,  4.99it/s]
100%|██████████| 6258/6258 [04:27<00:00, 23.37it/s]


0.0026965347888413425


100%|██████████| 96141/96141 [08:07<00:00, 197.41it/s] 


10236
Start Training...


100%|██████████| 38/38 [00:06<00:00,  5.46it/s]
100%|██████████| 6673/6673 [04:36<00:00, 24.15it/s]


0.0028402375630020675


100%|██████████| 96141/96141 [08:02<00:00, 199.07it/s] 


10236
Start Training...


100%|██████████| 63/63 [00:12<00:00,  5.24it/s]
100%|██████████| 6644/6644 [04:31<00:00, 24.43it/s]


0.003987722754647589


100%|██████████| 96141/96141 [08:09<00:00, 196.31it/s] 


10236
Start Training...


100%|██████████| 94/94 [00:18<00:00,  5.15it/s]
100%|██████████| 6610/6610 [04:32<00:00, 24.28it/s]


0.004202736588699826


100%|██████████| 96141/96141 [08:06<00:00, 197.61it/s] 


10236
Start Training...


100%|██████████| 157/157 [00:43<00:00,  3.57it/s]
100%|██████████| 6540/6540 [05:42<00:00, 19.12it/s]


0.020108390375955354


100%|██████████| 96141/96141 [08:27<00:00, 189.36it/s] 


10236
Start Training...


 92%|█████████▏| 201/219 [00:57<00:05,  3.45it/s]

Epoch 1 Step 201 Train Loss: 0.0011


100%|██████████| 219/219 [01:02<00:00,  3.52it/s]
100%|██████████| 6471/6471 [05:25<00:00, 19.85it/s]


0.0480812792906808


100%|██████████| 96141/96141 [08:08<00:00, 196.86it/s] 


10236
Start Training...


 72%|███████▏  | 201/280 [00:56<00:22,  3.51it/s]

Epoch 1 Step 201 Train Loss: 0.0012


100%|██████████| 280/280 [01:19<00:00,  3.54it/s]
100%|██████████| 6404/6404 [05:26<00:00, 19.64it/s]


0.011432818195660499


100%|██████████| 96141/96141 [08:18<00:00, 192.89it/s] 


10236
Start Training...


 59%|█████▉    | 201/341 [00:56<00:39,  3.52it/s]

Epoch 1 Step 201 Train Loss: 0.0012


100%|██████████| 341/341 [01:36<00:00,  3.54it/s]
100%|██████████| 6336/6336 [05:16<00:00, 19.99it/s]


0.0050094542671651155


100%|██████████| 96141/96141 [08:11<00:00, 195.81it/s] 


10236
Start Training...


 50%|████▉     | 201/403 [00:57<00:56,  3.59it/s]

Epoch 1 Step 201 Train Loss: 0.0008


100%|█████████▉| 401/403 [01:53<00:00,  3.61it/s]

Epoch 1 Step 401 Train Loss: 0.0008


100%|██████████| 403/403 [01:53<00:00,  3.55it/s]
100%|██████████| 6265/6265 [05:29<00:00, 19.00it/s]


0.015234327622951469
