In [1]:
import pandas as pd
import numpy as np

## mRNA data

In [2]:
mrna_raw=pd.read_csv('TCGA_inter_SNUH_clinical_standardized_combat_quantile_data.csv')

In [3]:
mrna_raw=mrna_raw.iloc[:-70,:]

## Non-coding RNA

In [4]:
lnc_raw=pd.read_csv("TCGA_inter_SNUH_clinical_lnc_standardized_combat_quantile_data.csv")

In [5]:
lnc_raw=lnc_raw.iloc[:-70,:]

## Survival patients

In [6]:
survival_raw=pd.read_csv("TCGA_OV_survival_reference.csv")

## Vital status

In [7]:
vital_raw=pd.read_csv('balanced_index_230.csv')

## Handle N/A

In [8]:
mrna_raw=mrna_raw.dropna()
lnc_raw=lnc_raw.dropna()
survival_raw=survival_raw.dropna()

In [9]:
mrna_raw=mrna_raw.drop_duplicates()
lnc_raw=lnc_raw.drop_duplicates()

In [10]:
mrna_patient_idx=mrna_raw.patient.values
lnc_patient_idx=lnc_raw.patient.values
print(len(mrna_patient_idx))
print(len(lnc_patient_idx))

426
426


In [11]:
print(mrna_raw.shape)
print(lnc_raw.shape)
print(survival_raw.shape)

(426, 22205)
(426, 1886)
(485, 2)


## Check an existance of lnc data in mrna

In [12]:
mrna_raw.loc[mrna_raw.patient.isin(lnc_patient_idx)].shape

(426, 22205)

## Extract existed idx from survival raw data

In [13]:
existed_survival=survival_raw.loc[survival_raw.patient.isin(lnc_patient_idx),:]

In [14]:
len(existed_survival)

408

In [15]:
mrna_col=mrna_raw.columns.values
lnc_col=lnc_raw.columns.values


In [16]:
#Remove 'Patients' columns in lnc_col
lnc_col=lnc_col[1:]

In [17]:
intersection_col=[]
for i in mrna_col:
    if i in lnc_col:
        intersection_col.append(i)

In [18]:
#Delete overlapped columns in mrna data
del_inter_mrna_raw=mrna_raw.drop(columns=intersection_col)

In [19]:
test_mrna=del_inter_mrna_raw.loc[del_inter_mrna_raw.patient.isin(existed_survival.patient.values),:]
test_lnc=lnc_raw.loc[lnc_raw.patient.isin(existed_survival.patient.values),:]

In [20]:
print(len(del_inter_mrna_raw))
print(len(lnc_raw))
print(len(test_mrna))
print(len(test_lnc))

426
426
413
413


In [21]:
## Delete duplicated patients id

## Delete duplicated patients id

In [22]:
del_inter_mrna_raw=del_inter_mrna_raw.drop_duplicates(subset='patient')
lnc_raw=lnc_raw.drop_duplicates(subset='patient')
test_lnc=test_lnc.drop_duplicates(subset='patient')
test_mrna=test_mrna.drop_duplicates(subset='patient')

In [23]:
lnc_data=pd.merge(test_lnc,existed_survival,on=['patient'])

In [24]:
mrna_data=pd.merge(test_mrna,existed_survival,on=['patient'])

# Model Class

In [25]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch import optim

from itertools import *
from tqdm import tqdm

## Highway net

In [26]:
class Highway(nn.Module):
    def __init__(self,size,num_layers,f):
        super().__init__()
        self.num_layers=num_layers
        self.nonlinear=nn.ModuleList([nn.Linear(size,size) for _ in range(num_layers)])
        self.linear=nn.ModuleList([nn.Linear(size,size) for _ in range(num_layers)])
        self.gate=nn.ModuleList([nn.Linear(size,size) for _ in range(num_layers)])
        self.f=f
    def forward(self,x):
        for layer in range(self.num_layers):
            gate=F.sigmoid(self.gate[layer](x))
            nonlinear=self.f(self.nonlinear[layer](x))
            linear=self.linear[layer](x)
            x=gate*nonlinear+(1-gate)*linear
        return x
                       

## Similarity loss

In [27]:
def Similarity_loss(modalities):
    mode_rna=modalities['rna'].detach()
    #print(f'mode_rna:{mode_rna.shape}')
    #print(mode_rna[0,:])
    mode_lnc=modalities['lnc'].detach()
    #print(f'mode_lnc:{mode_lnc.shape}')
    cos=nn.CosineSimilarity(dim=1,eps=1e-6)
    M=0.1
    N=mode_rna.shape[0]
    loss=[]
    #sim(x,x) output=distance of num of samples
    sim_same=cos(mode_rna,mode_lnc)
    avg_sim_same=torch.sum(sim_same)/N
    tmp=torch.clone(mode_lnc[0,:])
    mode_lnc[0:-1,:]=mode_lnc[1:,:]
    mode_lnc[-1,:]=tmp
    #print(mode_rna.shape[0]-1)
    for i in range(mode_rna.shape[0]-1):
        #sim(x,y) output=distance of num of samples
        sim_diff=cos(mode_rna,mode_lnc)
        #print(f'sim diff :{sim_diff}')
        avg_sim_diff=torch.sum(sim_diff)/N
        #L_theta(x,y)=max(M-sim(x,y)+sim(x,x))
        L_theta_x_y=max(0,M-avg_sim_diff+avg_sim_same)
        loss.append(L_theta_x_y)
        #print(L_theta_x_y)
        #shift y data
        tmp=torch.clone(mode_lnc[0,:])
        mode_lnc[0:-1,:]=mode_lnc[1:,:]
        mode_lnc[-1,:]=tmp
    total_loss=sum(loss)/len(loss)
    return total_loss
        
        
    

## Train Network

In [677]:
class MultiNet(nn.Module):
    
    def __init__(self):
        super().__init__()
        self.fcm=nn.Linear(20319,256)
        self.fcl=nn.Linear(1885,256)
        self.highway=Highway(256,10,f=F.relu)
        self.fc2=nn.Linear(512,2)
        self.fcd=nn.Linear(512,1)
        self.bn1=nn.BatchNorm1d(256)
        self.bn2=nn.BatchNorm1d(256)
        self.bn3=nn.BatchNorm1d(1)
    
    def forward(self,data):
        rna=data['mRNA']
        x=rna.cuda()
        #print(x)
        x=x.view(x.shape[0],-1)
        x=F.dropout(x,0.4)
        x=F.tanh(self.fcm(x))
        x=self.bn1(x)
        x=F.dropout(x,0.5,training=self.training)
        x=self.highway(x)
        x=self.bn2(x)
        
        lnc=data['lnc']
        y=lnc.cuda()
        #print(y)
        y=y.view(y.shape[0],-1)        
        y=F.dropout(y,0.4)
        y=F.tanh(self.fcl(y))
        y=self.bn1(y)
        y=F.dropout(y,0.5,training=self.training)
        y=self.highway(y)
        y=self.bn2(y)
        
        modal_x=x.clone()
        modal_y=y.clone()
        modal={'rna':modal_x,'lnc':modal_y}
        #similarity loss
        #sim_loss=Similarity_loss(modal)
        #print(sim_loss.requires_grad)
        #concatenates x and y
        concat_x_y=torch.cat((x,y),1)
        
        #vital status
        #score=F.log_softmax(self.fc2(concat_x_y),dim=1)
        #survival
        hazard=self.fcd(concat_x_y)
        
        #print(f'hazard:{hazard}type:{hazard.requires_grad}')
        return {'hazard':hazard,'modal':modal}
    def loss(self,pred,target):
        modal=pred['modal']
        loss1=Similarity_loss(modal)
        
        days_to_death=target.cuda()
        hazard=pred['hazard'].squeeze()
        
        _,idx=torch.sort(days_to_death)
        hazard_probs=F.softmax(hazard[idx].squeeze())
        hazard_cum=torch.stack([torch.tensor(0.0).cuda()]+list(accumulate(hazard_probs)))
        N=hazard_probs.shape[0]
        weights_cum=torch.range(1,N)
        p,q=hazard_cum[1:],1-hazard_cum[:-1]        
        w1,w2=weights_cum,N-weights_cum
        probs=torch.stack([p,q],dim=1)
        
        logits=torch.log(probs)
        
        #print(zeros)
        #print(logits)
        w1=w1.cuda()
        w2=w2.cuda()
        w1.requires_grad=False
        w2.requires_grad=False
        
        #print(w1)
        #print(N)
        #ll1 = (F.nll_loss(logits, torch.zeros(N,dtype=torch.long,device=device), reduce=False) * w1)/N
        ll1 = F.nll_loss(logits, torch.zeros(N,dtype=torch.long,device=device), reduce=False)
        ll1=(ll1*w1)/N
        ll2 = F.nll_loss(logits, torch.ones(N,dtype=torch.long,device=device), reduce=False)
        ll2=(ll2*w2)/N
        #ll2 = (F.nll_loss(logits, torch.ones(N).long(), reduce=False) * w2)/N
        loss2 = torch.mean(ll1 + ll2)
        
        
        #loss1=pred['sim_loss']
        #print(f'sim loss:{loss1} cox loss: {loss2} tot loss:{loss1+loss2}')
        return loss1+loss2

        
        

## Train/Test split 98/2 ratio

In [666]:
from sklearn.model_selection import train_test_split

In [667]:
mrna_train,mrna_test=train_test_split(mrna_data,test_size=0.02,random_state=777)
lnc_train,lnc_test=train_test_split(lnc_data,test_size=0.02,random_state=777)

In [668]:
print(type(mrna_train))
print(mrna_train.shape)
print(lnc_test.shape)

<class 'pandas.core.frame.DataFrame'>
(399, 20321)
(9, 1887)


In [669]:
whole_data={'mRNA':mrna_train.iloc[:,1:].to_numpy(),'lnc':lnc_train.iloc[:,1:].to_numpy()}

## Data Generate

In [670]:
from torch.autograd import Variable
from torch.utils.data import Dataset,DataLoader

In [671]:
class GenerateData(Dataset):
    def __init__(self,dataset):
        mrna_data=dataset['mRNA']
        lnc_data=dataset['lnc']
        self.len=mrna_data.shape[0]
        self.rna_x=torch.from_numpy(mrna_data[:,0:-1]).float()
        self.rna_y=torch.from_numpy(mrna_data[:,-1]).float()
        self.lnc_x=torch.from_numpy(lnc_data[:,0:-1]).float()
        self.lnc_y=torch.from_numpy(lnc_data[:,-1]).float()
        
    def __getitem__(self,index):
        data={'mRNA':self.rna_x[index],'lnc':self.lnc_x[index],'mRNA_y':self.rna_y[index],'lnc_y':self.lnc_y[index]}
        return data
    
    def __len__(self):
        return self.len
    

In [672]:
dataset=GenerateData(whole_data)
train_loader=DataLoader(dataset=dataset,batch_size=16,shuffle=False,num_workers=2)

## Main

In [678]:
from tqdm import tqdm

In [673]:
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

In [674]:
use_cuda
device

device(type='cuda')

In [681]:
learning_rate=1e-5
optimizer=optim.Adam(model.parameters(),lr=learning_rate)

In [682]:
#model=MultiNet()
#model=model.to(device)
model=MultiNet()
model=model.cuda()    
            
for epoch in tqdm(range(100)):    
    for i,data in enumerate(train_loader):        
        with torch.autograd.set_detect_anomaly(True):
            optimizer.zero_grad()
            output=model(data)
            loss=model.loss(output,data['mRNA_y'])
            loss.backward()
            optimizer.step()
    print(f'# of epoch :{epoch} tot loss:{loss}')







  1%|          | 1/100 [00:05<09:30,  5.76s/it][A[A[A

# of epoch :0 tot loss:0.5536885857582092





  2%|▏         | 2/100 [00:11<09:26,  5.78s/it][A[A[A

# of epoch :1 tot loss:0.4668825566768646





  3%|▎         | 3/100 [00:17<09:17,  5.74s/it][A[A[A

# of epoch :2 tot loss:0.47857627272605896





  4%|▍         | 4/100 [00:22<09:10,  5.73s/it][A[A[A

# of epoch :3 tot loss:0.5851966738700867





  5%|▌         | 5/100 [00:28<09:02,  5.71s/it][A[A[A

# of epoch :4 tot loss:0.5900180339813232





  6%|▌         | 6/100 [00:34<08:56,  5.70s/it][A[A[A

# of epoch :5 tot loss:0.5164112448692322





  7%|▋         | 7/100 [00:39<08:48,  5.68s/it][A[A[A

# of epoch :6 tot loss:0.5235821604728699





  8%|▊         | 8/100 [00:45<08:43,  5.70s/it][A[A[A

# of epoch :7 tot loss:0.573197066783905





  9%|▉         | 9/100 [00:51<08:36,  5.68s/it][A[A[A

# of epoch :8 tot loss:0.6178929805755615





 10%|█         | 10/100 [00:57<08:32,  5.70s/it][A[A[A

# of epoch :9 tot loss:0.6058146953582764





 11%|█         | 11/100 [01:02<08:27,  5.70s/it][A[A[A

# of epoch :10 tot loss:0.6170494556427002





 12%|█▏        | 12/100 [01:08<08:19,  5.68s/it][A[A[A

# of epoch :11 tot loss:0.6494596600532532





 13%|█▎        | 13/100 [01:14<08:19,  5.74s/it][A[A[A

# of epoch :12 tot loss:0.5179411768913269





 14%|█▍        | 14/100 [01:20<08:19,  5.80s/it][A[A[A

# of epoch :13 tot loss:0.5945671796798706





 15%|█▌        | 15/100 [01:26<08:15,  5.83s/it][A[A[A

# of epoch :14 tot loss:0.5620136260986328





 16%|█▌        | 16/100 [01:32<08:12,  5.86s/it][A[A[A

# of epoch :15 tot loss:0.5874331593513489





 17%|█▋        | 17/100 [01:37<08:08,  5.89s/it][A[A[A

# of epoch :16 tot loss:0.5308243632316589





 18%|█▊        | 18/100 [01:43<08:03,  5.90s/it][A[A[A

# of epoch :17 tot loss:0.6510264873504639





 19%|█▉        | 19/100 [01:49<07:56,  5.88s/it][A[A[A

# of epoch :18 tot loss:0.6349881291389465





 20%|██        | 20/100 [01:55<07:44,  5.81s/it][A[A[A

# of epoch :19 tot loss:0.5162748098373413





 21%|██        | 21/100 [02:01<07:38,  5.81s/it][A[A[A

# of epoch :20 tot loss:0.49405187368392944





 22%|██▏       | 22/100 [02:06<07:32,  5.80s/it][A[A[A

# of epoch :21 tot loss:0.6489897966384888





 23%|██▎       | 23/100 [02:12<07:25,  5.78s/it][A[A[A

# of epoch :22 tot loss:0.5417945981025696





 24%|██▍       | 24/100 [02:18<07:17,  5.76s/it][A[A[A

# of epoch :23 tot loss:0.5763205885887146





 25%|██▌       | 25/100 [02:24<07:12,  5.77s/it][A[A[A

# of epoch :24 tot loss:0.50163733959198





 26%|██▌       | 26/100 [02:29<07:02,  5.71s/it][A[A[A

# of epoch :25 tot loss:0.6051726937294006





 27%|██▋       | 27/100 [02:35<06:57,  5.72s/it][A[A[A

# of epoch :26 tot loss:0.570557713508606





 28%|██▊       | 28/100 [02:41<06:52,  5.72s/it][A[A[A

# of epoch :27 tot loss:0.5881659388542175





 29%|██▉       | 29/100 [02:46<06:45,  5.72s/it][A[A[A

# of epoch :28 tot loss:0.5646957159042358





 30%|███       | 30/100 [02:52<06:41,  5.74s/it][A[A[A

# of epoch :29 tot loss:0.45984911918640137





 31%|███       | 31/100 [02:58<06:36,  5.74s/it][A[A[A

# of epoch :30 tot loss:0.5468692779541016





 32%|███▏      | 32/100 [03:04<06:29,  5.73s/it][A[A[A

# of epoch :31 tot loss:0.5301008820533752





 33%|███▎      | 33/100 [03:09<06:22,  5.71s/it][A[A[A

# of epoch :32 tot loss:0.5083196759223938





 34%|███▍      | 34/100 [03:15<06:18,  5.74s/it][A[A[A

# of epoch :33 tot loss:0.6315391063690186





 35%|███▌      | 35/100 [03:21<06:15,  5.77s/it][A[A[A

# of epoch :34 tot loss:0.5553745031356812





 36%|███▌      | 36/100 [03:27<06:08,  5.76s/it][A[A[A

# of epoch :35 tot loss:0.6491348743438721





 37%|███▋      | 37/100 [03:32<06:00,  5.72s/it][A[A[A

# of epoch :36 tot loss:0.622761607170105





 38%|███▊      | 38/100 [03:38<05:53,  5.71s/it][A[A[A

# of epoch :37 tot loss:0.4823102653026581





 39%|███▉      | 39/100 [03:44<05:47,  5.70s/it][A[A[A

# of epoch :38 tot loss:0.5347152948379517





 40%|████      | 40/100 [03:50<05:45,  5.76s/it][A[A[A

# of epoch :39 tot loss:0.580610454082489





 41%|████      | 41/100 [03:56<05:41,  5.79s/it][A[A[A

# of epoch :40 tot loss:0.598051905632019





 42%|████▏     | 42/100 [04:01<05:38,  5.83s/it][A[A[A

# of epoch :41 tot loss:0.5448808073997498





 43%|████▎     | 43/100 [04:07<05:32,  5.84s/it][A[A[A

# of epoch :42 tot loss:0.5388686060905457





 44%|████▍     | 44/100 [04:13<05:26,  5.83s/it][A[A[A

# of epoch :43 tot loss:0.508815586566925





 45%|████▌     | 45/100 [04:19<05:23,  5.88s/it][A[A[A

# of epoch :44 tot loss:0.5954064130783081





 46%|████▌     | 46/100 [04:25<05:18,  5.89s/it][A[A[A

# of epoch :45 tot loss:0.5975371599197388





 47%|████▋     | 47/100 [04:31<05:11,  5.88s/it][A[A[A

# of epoch :46 tot loss:0.5711615085601807





 48%|████▊     | 48/100 [04:37<05:05,  5.88s/it][A[A[A

# of epoch :47 tot loss:0.6123687624931335





 49%|████▉     | 49/100 [04:43<05:01,  5.92s/it][A[A[A

# of epoch :48 tot loss:0.5323269367218018





 50%|█████     | 50/100 [04:49<04:55,  5.91s/it][A[A[A

# of epoch :49 tot loss:0.5359553694725037





 51%|█████     | 51/100 [04:55<04:49,  5.91s/it][A[A[A

# of epoch :50 tot loss:0.4751361906528473





 52%|█████▏    | 52/100 [05:00<04:42,  5.88s/it][A[A[A

# of epoch :51 tot loss:0.5684080123901367





 53%|█████▎    | 53/100 [05:06<04:36,  5.88s/it][A[A[A

# of epoch :52 tot loss:0.6107295751571655





 54%|█████▍    | 54/100 [05:12<04:31,  5.90s/it][A[A[A

# of epoch :53 tot loss:0.5865159630775452





 55%|█████▌    | 55/100 [05:18<04:25,  5.89s/it][A[A[A

# of epoch :54 tot loss:0.587846040725708





 56%|█████▌    | 56/100 [05:24<04:18,  5.89s/it][A[A[A

# of epoch :55 tot loss:0.5313772559165955





 57%|█████▋    | 57/100 [05:30<04:10,  5.83s/it][A[A[A

# of epoch :56 tot loss:0.5038008093833923





 58%|█████▊    | 58/100 [05:35<04:04,  5.81s/it][A[A[A

# of epoch :57 tot loss:0.48450782895088196





 59%|█████▉    | 59/100 [05:41<03:55,  5.76s/it][A[A[A

# of epoch :58 tot loss:0.5780981183052063





 60%|██████    | 60/100 [05:47<03:49,  5.74s/it][A[A[A

# of epoch :59 tot loss:0.5219416618347168





 61%|██████    | 61/100 [05:53<03:45,  5.77s/it][A[A[A

# of epoch :60 tot loss:0.5829092860221863





 62%|██████▏   | 62/100 [05:58<03:37,  5.73s/it][A[A[A

# of epoch :61 tot loss:0.6335541605949402





 63%|██████▎   | 63/100 [06:04<03:33,  5.77s/it][A[A[A

# of epoch :62 tot loss:0.5025615096092224





 64%|██████▍   | 64/100 [06:10<03:27,  5.76s/it][A[A[A

# of epoch :63 tot loss:0.5553693175315857





 65%|██████▌   | 65/100 [06:15<03:20,  5.74s/it][A[A[A

# of epoch :64 tot loss:0.5337725877761841





 66%|██████▌   | 66/100 [06:21<03:14,  5.73s/it][A[A[A

# of epoch :65 tot loss:0.5457082986831665





 67%|██████▋   | 67/100 [06:27<03:08,  5.72s/it][A[A[A

# of epoch :66 tot loss:0.6117599010467529





 68%|██████▊   | 68/100 [06:33<03:03,  5.72s/it][A[A[A

# of epoch :67 tot loss:0.4937077462673187





 69%|██████▉   | 69/100 [06:38<02:57,  5.73s/it][A[A[A

# of epoch :68 tot loss:0.5805814862251282





 70%|███████   | 70/100 [06:44<02:51,  5.73s/it][A[A[A

# of epoch :69 tot loss:0.581520676612854





 71%|███████   | 71/100 [06:50<02:47,  5.78s/it][A[A[A

# of epoch :70 tot loss:0.6106309294700623





 72%|███████▏  | 72/100 [06:56<02:41,  5.77s/it][A[A[A

# of epoch :71 tot loss:0.4508413076400757





 73%|███████▎  | 73/100 [07:01<02:35,  5.74s/it][A[A[A

# of epoch :72 tot loss:0.5553771257400513





 74%|███████▍  | 74/100 [07:07<02:28,  5.73s/it][A[A[A

# of epoch :73 tot loss:0.4602898955345154





 75%|███████▌  | 75/100 [07:13<02:23,  5.76s/it][A[A[A

# of epoch :74 tot loss:0.5904955863952637





 76%|███████▌  | 76/100 [07:19<02:17,  5.74s/it][A[A[A

# of epoch :75 tot loss:0.48172053694725037





 77%|███████▋  | 77/100 [07:24<02:11,  5.74s/it][A[A[A

# of epoch :76 tot loss:0.5500175952911377





 78%|███████▊  | 78/100 [07:30<02:06,  5.73s/it][A[A[A

# of epoch :77 tot loss:0.5170906186103821





 79%|███████▉  | 79/100 [07:36<02:01,  5.77s/it][A[A[A

# of epoch :78 tot loss:0.542073130607605





 80%|████████  | 80/100 [07:42<01:54,  5.73s/it][A[A[A

# of epoch :79 tot loss:0.5766952037811279





 81%|████████  | 81/100 [07:47<01:48,  5.73s/it][A[A[A

# of epoch :80 tot loss:0.552909255027771





 82%|████████▏ | 82/100 [07:53<01:43,  5.73s/it][A[A[A

# of epoch :81 tot loss:0.5085840225219727





 83%|████████▎ | 83/100 [07:59<01:37,  5.73s/it][A[A[A

# of epoch :82 tot loss:0.5569682717323303





 84%|████████▍ | 84/100 [08:05<01:32,  5.77s/it][A[A[A

# of epoch :83 tot loss:0.6466910243034363





 85%|████████▌ | 85/100 [08:10<01:26,  5.76s/it][A[A[A

# of epoch :84 tot loss:0.549347996711731





 86%|████████▌ | 86/100 [08:16<01:20,  5.73s/it][A[A[A

# of epoch :85 tot loss:0.5483113527297974





 87%|████████▋ | 87/100 [08:22<01:14,  5.72s/it][A[A[A

# of epoch :86 tot loss:0.48076942563056946





 88%|████████▊ | 88/100 [08:28<01:08,  5.74s/it][A[A[A

# of epoch :87 tot loss:0.6208505034446716





 89%|████████▉ | 89/100 [08:33<01:03,  5.73s/it][A[A[A

# of epoch :88 tot loss:0.5201811790466309





 90%|█████████ | 90/100 [08:39<00:56,  5.70s/it][A[A[A

# of epoch :89 tot loss:0.5223301649093628





 91%|█████████ | 91/100 [08:45<00:51,  5.71s/it][A[A[A

# of epoch :90 tot loss:0.5512722134590149





 92%|█████████▏| 92/100 [08:50<00:45,  5.74s/it][A[A[A

# of epoch :91 tot loss:0.5704483985900879





 93%|█████████▎| 93/100 [08:56<00:40,  5.76s/it][A[A[A

# of epoch :92 tot loss:0.5493901968002319





 94%|█████████▍| 94/100 [09:02<00:34,  5.78s/it][A[A[A

# of epoch :93 tot loss:0.6084209680557251





 95%|█████████▌| 95/100 [09:08<00:28,  5.73s/it][A[A[A

# of epoch :94 tot loss:0.49885281920433044





 96%|█████████▌| 96/100 [09:13<00:22,  5.71s/it][A[A[A

# of epoch :95 tot loss:0.5320302248001099





 97%|█████████▋| 97/100 [09:19<00:17,  5.71s/it][A[A[A

# of epoch :96 tot loss:0.5643812417984009





 98%|█████████▊| 98/100 [09:25<00:11,  5.73s/it][A[A[A

# of epoch :97 tot loss:0.564497709274292





 99%|█████████▉| 99/100 [09:31<00:05,  5.72s/it][A[A[A

# of epoch :98 tot loss:0.47092753648757935





100%|██████████| 100/100 [09:36<00:00,  5.77s/it][A[A[A

# of epoch :99 tot loss:0.5940861701965332





In [None]:
#Train,Test split 99/1
#Generate Instance
#Train
#Optim_zero_grad
#Loss
#Backward
#Optimizer.step()