# In [None]:
from vrae.vrae import VRAE
from vrae.utils import *
import numpy as np
import torch
import pandas as pd
import plotly
from torch.utils.data import DataLoader, TensorDataset
plotly.offline.init_notebook_mode()
from tslearn.preprocessing import TimeSeriesScalerMeanVariance

# Download directory (handed to VRAE as `dload` further down).
dload = './model_dir_mm'

# Load the compiled, labelled dataset.
# NOTE(review): the path ends in .csv but is read with read_pickle — confirm
# the file really is a pickle.
df = pd.read_pickle(
    r'C:\Users\achfr\OneDrive - University of Edinburgh\Compiled dataset\df_compiled_mothers_labelled_cyclefate.csv'
)

# Reshape to wide format: one row per combination of the index columns and
# one column per (value column, Time) pair.
df_tracks = df.pivot(
    index=[
        'Position', 'ParentTrackHeadIndices', 'Medium', 'Treatment',
        'RepeatID', 'RepeatDate', 'fate', 'DeathSubtype',
    ],
    columns='Time',
    values=[
        'GrowthRateSize', 'GrowthRateLength', 'GrowthRateFeretMax', 'GrowthRateFeretMaxSliding',
        'InterdivisionTimes', 'DivisionRate', 'DivisionRate_filtered', 'TrackLength', 'TrackLength_filtered',
        'Size', 'SizeAtBirthSize', 'FeretMax', 'SizeAtBirthFeretMax', 'MaxLength', 'SpineLength', 'SizeAtBirthLength', 'SpineWidth',
        'MeanIntensity_mch', 'MeanIntensity_gfp', 'Maxgfp',
        'BacteriaLineage', 'NextDivisionFrame', 'PreviousDivisionFrame',
        'TrackHeadIndices', 'Prev', 'Next', 'Idx', 'Frame', 'Indices', 'PositionIdx', 'cellcycle_fate',
    ],
)

# Experiment identifier, written as "<medium>_<treatment>_<replicate>".
exp_name = 'glu_cip_1'
medium, treatment, replicate = exp_name.split('_')

# In glycerol, data is missing every other timepoint because it was collected
# every 10 minutes instead of 5, so every other timepoint has to be skipped.
skip_timepoints = 2 if medium == 'gly' else 1

# Keep only the rows whose 3rd, 4th and 5th index levels (Medium, Treatment,
# RepeatID in the pivot above) match the chosen experiment.
frame = df_tracks.loc(axis=0)[:, :, medium, treatment, replicate]

# One row per track, one column per timepoint. The values go through nested
# Python lists so that NumPy infers the dtype from the elements themselves.
size_array = np.array(frame['FeretMax'].to_numpy().tolist())
# Rescale with tslearn's TimeSeriesScalerMeanVariance (default settings).
size_array = TimeSeriesScalerMeanVariance().fit_transform(size_array)
cyclefate_array = np.array(frame['cellcycle_fate'].to_numpy().tolist())

# Per-track labels pulled out of the row index.
fates = np.array(list(frame.reset_index('fate')['fate']))
death_subtypes = np.array(list(frame.reset_index('DeathSubtype')['DeathSubtype']))


# Cut every track into overlapping windows of `seq_len` consecutive points
# (stride 1) and stack the shuffled windows into `Data`.
#
# For each track only the slice [24, min(t_death, 168)) is used, where
# t_death is the first position at which the fate label is not 'alive'.
# A non-overlapping split (consecutive chunks of seq_len points) was the
# earlier alternative to the sliding window used here.
seq_len = 70

liste = []
for i in range(size_array.shape[0]):
    serie = size_array[i]

    # Positions (in the subsampled fate row) where the cell is no longer alive.
    # NOTE(review): t_death indexes the fate row subsampled by skip_timepoints,
    # whereas `serie` is not subsampled; the two only line up when
    # skip_timepoints == 1 — confirm the intent for 'gly' data.
    not_alive = np.flatnonzero(cyclefate_array[i, ::skip_timepoints] != 'alive')
    # A track that stays 'alive' throughout has no death index; fall back to
    # the full track length so the 168 cap below decides the end of the slice.
    t_death = not_alive[0] if not_alive.size else serie.shape[0]

    usable_data = serie[24:min(t_death, 168)]

    # "+ 1" keeps the last full window (the one ending on the final usable
    # point); tracks shorter than seq_len give an empty range and are skipped.
    for start in range(len(usable_data) - seq_len + 1):
        liste.append(usable_data[start:start + seq_len])

if not liste:
    raise ValueError(
        f"no track has at least {seq_len} usable timepoints; cannot build training windows"
    )

# Shuffle the windows along the first axis.
# NOTE(review): neighbouring windows of one track overlap by seq_len - 1
# points, so a positional train/test split of this shuffled array puts
# near-duplicate windows on both sides — confirm that is acceptable.
Data = np.random.permutation(np.array(liste))
Data.shape

# Positional 80 % / 20 % split along the first axis of Data.
n_train = int(0.8 * Data.shape[0])
X_train = Data[:n_train]
X_test = Data[n_train:]

# Wrap each split in a single-tensor TensorDataset.
train_dataset, test_dataset = (
    TensorDataset(torch.from_numpy(split)) for split in (X_train, X_test)
)

# Model input dimensions are read off the second and third axes of the windows.
sequence_length, number_of_features = X_train.shape[1], X_train.shape[2]

print(X_train.shape, X_test.shape)

# Model settings — every value below is forwarded unchanged to the VRAE
# constructor under the same keyword name.
hidden_size = 90
hidden_layer_depth = 1
latent_length = 20
batch_size = 30
learning_rate = 0.0005
n_epochs = 100
dropout_rate = 0.2
# options: ADAM, SGD
# NOTE(review): the value is spelled 'Adam' while the options note says ADAM —
# confirm which spelling VRAE expects.
optimizer = 'Adam'
cuda = True           # options: True, False
print_every = 30
clip = True           # options: True, False
max_grad_norm = 5
# options: SmoothL1Loss, MSELoss
# NOTE(review): the name `loss` is assigned again further down the file, so
# re-running this construction after that point would pass a different object.
loss = 'MSELoss'
block = 'LSTM'        # options: LSTM, GRU

vrae = VRAE(
    sequence_length=sequence_length,
    number_of_features=number_of_features,
    hidden_size=hidden_size,
    hidden_layer_depth=hidden_layer_depth,
    latent_length=latent_length,
    batch_size=batch_size,
    learning_rate=learning_rate,
    n_epochs=n_epochs,
    dropout_rate=dropout_rate,
    optimizer=optimizer,
    cuda=cuda,
    print_every=print_every,
    clip=clip,
    max_grad_norm=max_grad_norm,
    loss=loss,
    block=block,
    dload=dload,
)

# Train on the training windows; the test windows are passed as second argument.
vrae.fit(train_dataset, test_dataset)

import torch.nn as nn

# Reconstruct one batch of held-out windows with the trained model, plot one
# randomly chosen input against its reconstruction, and report the summed
# squared reconstruction error for the batch.
vrae.eval()

# Samples 30 .. 30 + batch_size - 1 of the test set, as float32.
testseq = test_dataset[30:batch_size + 30][0].float()
print(testseq.shape)

# Swap the first two axes before calling forward(); below, axis 1 is indexed
# as the sample axis.
testseq2 = testseq.permute(1, 0, 2)
# Follow the `cuda` setting instead of moving to the GPU unconditionally, so
# the cell also runs on CPU-only machines.
# NOTE(review): assumes the model sits on the GPU exactly when `cuda` is set
# and CUDA is available — confirm against VRAE.
if cuda and torch.cuda.is_available():
    testseq2 = testseq2.cuda()
print(testseq2.shape)

# Inference only: no autograd graph is needed.
with torch.no_grad():
    outp = vrae.forward(testseq2)
print(outp[0].shape)

# Draw the sample index from the actual batch dimension, so a short final
# slice of the test set cannot produce an out-of-range index.
k = np.random.randint(0, testseq2.shape[1])
input_seq = testseq2[:, k, 0].cpu().detach().numpy()
output_seq = outp[0][:, k, 0].cpu().detach().numpy()

# NOTE(review): `plt` is not imported explicitly in this file; presumably it
# arrives through `from vrae.utils import *` — confirm.
plt.figure()
plt.plot(input_seq)
plt.plot(output_seq)

# Separate name so the `loss = 'MSELoss'` hyperparameter string is not
# overwritten by a module instance.
reconstruction_loss = nn.MSELoss(reduction='sum')
reconstruction_loss(testseq2, outp[0])

# 6min