# Setup

In [None]:
# Re-import edited project modules (e.g. under src/) automatically before each cell executes.
%reload_ext autoreload
%autoreload 2

In [None]:
# Guard for the one-time setup cell below; reset to False to force that cell to run again.
setup_already_done = False

In [None]:
import sys
import os

In [None]:
# The presence of the `google.colab` package among loaded modules is used as the Colab indicator.
IN_COLAB = 'google.colab' in sys.modules.keys()
IN_COLAB

In [None]:
if not setup_already_done:
    if(IN_COLAB):
        !git clone https://github.com/Silver0x10/VideoPrediction_MovingMNIST.git
        %cd VideoPrediction_MovingMNIST
    else:
        %cd ..
    if(not os.path.exists("data/mnist_test_seq.npy")):
        if(not os.path.exists("data")):
            %mkdir data
        %cd data
        !wget -q https://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy
        %cd ..

    %pip install -qr requirements.txt  --quiet
    setup_already_done = True

In [None]:
# Report the visible NVIDIA GPU(s) and driver status.
!nvidia-smi

In [None]:
# Confirm the working directory after setup: relative paths such as data/ and out/ depend on it.
!pwd
# %cd VideoPrediction_MovingMNIST/

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader,random_split

import lightning.pytorch as pl
# from lightning.pytorch.loggers.wandb import WandbLogger
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger

import numpy as np

# import pylab as plt
import matplotlib.pyplot as plt

import wandb

In [None]:
# Seed every RNG the notebook may touch (Python, NumPy, PyTorch on all devices)
# so runs are repeatable; then pick the GPU when one is available.
import random

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device.type

In [None]:
# Authenticate with Weights & Biases (interactive if no API key is already configured).
wandb.login()

# Dataset

In [None]:
from src.MovingMNIST import MovingMNIST
from src.parameters import shared_params

In [None]:
# Moving MNIST test sequences fetched by the setup cell.
mnist_path = "data/mnist_test_seq.npy"
dataset = MovingMNIST(data_path=mnist_path)

In [None]:
# Fixed-seed 8000/1000/1000 split so every model sees the same train/val/test sequences.
split_generator = torch.Generator().manual_seed(42)
train_set, validation_set, test_set = random_split(
    dataset, [8000, 1000, 1000], generator=split_generator
)

## Reference Sample

In [None]:
# Reference test sequence used for every model's qualitative comparison.
# `test_set` is a Subset of `dataset` (random_split permutes indices), so map the subset
# position back to the underlying dataset index; otherwise the GIF shows a different sequence.
reference_idx = 42
dataset.visualize_as_gif(test_set.indices[reference_idx])
sequence_test = test_set[reference_idx]

In [None]:
# Conditioning frames of the reference test sequence.
input_frames = sequence_test['frames']
print("Input:")
dataset.visualize_given_frames_as_gif(input_frames)

In [None]:
# Target frames the models should predict for the reference test sequence.
gt_frames = sequence_test['y']
print("Ground Truth:")
dataset.visualize_given_frames_as_gif(gt_frames)

# 1) SimpleLSTM

In [None]:
from src.simpleLSTM import SimpleLSTM
from src.parameters import ParamsSimpleLSTM

# Hyper-parameters come from the project's parameter class; the model is built from them.
params_simpleLSTM = ParamsSimpleLSTM()
model_simpleLSTM = SimpleLSTM(params_simpleLSTM)

## Training + Testing

In [133]:
# Training batches are reshuffled every epoch; validation/test keep a fixed order
# so their metrics are comparable across runs.
batch_size = params_simpleLSTM.batch_size
training_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
validation_dataloader = DataLoader(validation_set, batch_size=batch_size)
test_dataloader = DataLoader(test_set, batch_size=batch_size)

In [None]:
# W&B logger for this model; log_model=True uploads checkpoints only if the logger is passed to the Trainer.
wandb_logger = WandbLogger(project='DeepLearning', name='SimpleLSTM', log_model=True)

In [None]:
# Record the architecture and every hyper-parameter of this run in the W&B config.
run_config = wandb_logger.experiment.config
run_config['layers'] = str(model_simpleLSTM)
for name, value in vars(params_simpleLSTM).items():
    run_config[name] = value

In [None]:
# Debug configuration: fast_dev_run=1 only runs a single batch as a smoke test, and the
# W&B logger is commented out. Switch to the commented options for a full, logged run.
# Epoch count comes from this model's own params (params_convTAU does not exist yet here).
trainer_simpleLSTM = pl.Trainer(max_epochs=params_simpleLSTM.training_epochs,
                                accelerator=device.type,
                                # logger=wandb_logger,
                                # deterministic=True,
                                fast_dev_run=1,
                                # overfit_batches=1, # to test if the model can learn
                                detect_anomaly=True)

In [None]:
# Train on the training split, validating on the validation split.
trainer_simpleLSTM.fit(model=model_simpleLSTM, train_dataloaders=training_dataloader, val_dataloaders=validation_dataloader)

In [None]:
# Evaluate on the held-out test split; the returned metrics are kept for later comparison.
results_simpleLSTM = trainer_simpleLSTM.test(model_simpleLSTM, dataloaders=test_dataloader)

## Visualization Example

In [None]:
# Optionally restore trained SimpleLSTM weights from a Lightning checkpoint.
# Example: "lightning_logs/vo3td60s/checkpoints/epoch=9-step=80000.ckpt"; None skips loading.
checkpoint_path_simpleLSTM = None
if checkpoint_path_simpleLSTM is not None:
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    # map_location='cpu' lets a GPU-trained checkpoint load on a CPU-only runtime.
    checkpoint = torch.load(checkpoint_path_simpleLSTM, map_location='cpu')
    model_simpleLSTM.load_state_dict(checkpoint['state_dict'])
    model_simpleLSTM.eval()

In [None]:
# Predict the reference sequence with SimpleLSTM in inference mode.
print("Prediction:")
model_simpleLSTM.eval()  # switch off train-only behaviour (e.g. dropout) for inference
with torch.no_grad():
    # Add a batch dimension of 1; the first returned value holds the predicted frames.
    pred_simpleLSTM, _ = model_simpleLSTM(input_frames.unsqueeze(0))
pred_simpleLSTM = pred_simpleLSTM.detach().squeeze(0)
dataset.visualize_given_frames_as_gif(pred_simpleLSTM)

In [None]:
# Ground truth / prediction / difference, one row per target frame.
n_frames = gt_frames.shape[0]
# squeeze=False keeps `ax` 2-D even for a single frame; 2.8in per row matches the original 28in for 10 rows.
fig, ax = plt.subplots(n_frames, 3, figsize=(14, 2.8 * n_frames), squeeze=False)
ax[0, 0].set_title('Ground Truth')
ax[0, 1].set_title('Prediction')
ax[0, 2].set_title('Difference')

for i in range(n_frames):
    gt_i = gt_frames[i, :, :]
    pred_i = pred_simpleLSTM[i, :, :]
    ax[i, 0].imshow(gt_i, cmap='gray')
    ax[i, 1].imshow(pred_i, cmap='gray')
    ax[i, 2].imshow(gt_i - pred_i, cmap='gray')

# savefig does not create missing directories.
os.makedirs('out', exist_ok=True)
# A single path variable keeps the saved and the logged image in sync.
figure_path = 'out/simpleLSTM_comparison.png'
fig.savefig(figure_path)
wandb.log({"example": wandb.Image(figure_path)})

plt.show()

## SimpleLSTM Finish

In [None]:
# Close the current W&B run so the next model's logger starts a run of its own.
wandb.finish()

In [None]:
# !zip -r ../logs lightning_logs/ # Remember to save weights (if needed)

# 2) ConvLSTM

In [None]:
# Explicit import instead of `*`: only PlEncoderDecoder is used from this module.
from src.convLSTM import PlEncoderDecoder
from src.parameters import ParamsConvLSTM

params_convLSTM = ParamsConvLSTM()
# Unlike the other models, PlEncoderDecoder takes keyword settings instead of the params object.
# NOTE(review): k_s is presumably the convolution kernel size — confirm in src/convLSTM.
model_convLSTM = PlEncoderDecoder(k_s=5, Batch_size=params_convLSTM.batch_size)

## Training + Testing

In [None]:
# Training batches are reshuffled every epoch; validation/test keep a fixed order
# so their metrics are comparable across runs.
batch_size = params_convLSTM.batch_size
training_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
validation_dataloader = DataLoader(validation_set, batch_size=batch_size)
test_dataloader = DataLoader(test_set, batch_size=batch_size)

In [None]:
# W&B logger for this model; log_model=True uploads checkpoints only if the logger is passed to the Trainer.
wandb_logger = WandbLogger(project='DeepLearning', name='ConvLSTM', log_model=True)

In [None]:
# Record the architecture and every hyper-parameter of this run in the W&B config.
run_config = wandb_logger.experiment.config
run_config['layers'] = str(model_convLSTM)
for name, value in vars(params_convLSTM).items():
    run_config[name] = value

In [None]:
# Sanity-check configuration: overfit a single batch to verify the ConvLSTM can learn.
# Re-enable `logger`/`deterministic` (and drop `overfit_batches`) for a real run.
trainer_convLSTM = pl.Trainer(
    max_epochs=params_convLSTM.training_epochs,
    accelerator=device.type,
    # logger=wandb_logger,
    # deterministic=True,
    # fast_dev_run=1,
    overfit_batches=1,  # to test if the model can learn
    detect_anomaly=True,
)

In [None]:
# Train on the training split, validating on the validation split.
trainer_convLSTM.fit(model=model_convLSTM, train_dataloaders=training_dataloader, val_dataloaders=validation_dataloader)

In [None]:
# Evaluate on the held-out test split; the returned metrics are kept for later comparison.
results_convLSTM = trainer_convLSTM.test(model_convLSTM, dataloaders=test_dataloader)

## Visualization Example

In [None]:
# Optionally restore trained ConvLSTM weights from a Lightning checkpoint.
# Example: "lightning_logs/vo3td60s/checkpoints/epoch=9-step=80000.ckpt"; None skips loading.
checkpoint_path_convLSTM = None
if checkpoint_path_convLSTM is not None:
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    # map_location='cpu' lets a GPU-trained checkpoint load on a CPU-only runtime.
    checkpoint = torch.load(checkpoint_path_convLSTM, map_location='cpu')
    model_convLSTM.load_state_dict(checkpoint['state_dict'])
    model_convLSTM.eval()

In [None]:
# Predict the reference sequence with ConvLSTM in inference mode.
print("Prediction:")
model_convLSTM.eval()  # switch off train-only behaviour (e.g. dropout) for inference
# Without no_grad the output tracks gradients, and converting it to numpy for plotting raises.
with torch.no_grad():
    # NOTE(review): unlike SimpleLSTM, no batch dimension is added here — confirm that
    # PlEncoderDecoder expects an unbatched sequence.
    pred_convLSTM = model_convLSTM(input_frames)
dataset.visualize_given_frames_as_gif(pred_convLSTM)

In [None]:
# Ground truth / prediction / difference, one row per target frame.
n_frames = gt_frames.shape[0]
# squeeze=False keeps `ax` 2-D even for a single frame; 2.8in per row matches the original 28in for 10 rows.
fig, ax = plt.subplots(n_frames, 3, figsize=(14, 2.8 * n_frames), squeeze=False)
ax[0, 0].set_title('Ground Truth')
ax[0, 1].set_title('Prediction')
ax[0, 2].set_title('Difference')

for i in range(n_frames):
    gt_i = gt_frames[i, :, :]
    pred_i = pred_convLSTM[i, :, :]
    ax[i, 0].imshow(gt_i, cmap='gray')
    ax[i, 1].imshow(pred_i, cmap='gray')
    ax[i, 2].imshow(gt_i - pred_i, cmap='gray')

# savefig does not create missing directories.
os.makedirs('out', exist_ok=True)
# ConvLSTM-specific file name so ConvTAU's comparison image is not overwritten.
figure_path = 'out/convLSTM_comparison.png'
fig.savefig(figure_path)
wandb.log({"example": wandb.Image(figure_path)})

plt.show()

## ConvLSTM Finish 

In [None]:
# Close the current W&B run so the next model's logger starts a run of its own.
wandb.finish()

In [None]:
# !zip -r ../logs lightning_logs/ # Remember to save weights (if needed)

# 3) ConvTAU

In [None]:
# Explicit import instead of `*`: only ConvTAU is used from this module.
from src.ConvTAU import ConvTAU
from src.parameters import ParamsConvTAU

# Hyper-parameters come from the project's parameter class; the model is built from them.
params_convTAU = ParamsConvTAU()
model_convTAU = ConvTAU(params_convTAU)

## Training + Testing

In [None]:
# Training batches are reshuffled every epoch; validation/test keep a fixed order
# so their metrics are comparable across runs.
batch_size = params_convTAU.batch_size
training_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
validation_dataloader = DataLoader(validation_set, batch_size=batch_size)
test_dataloader = DataLoader(test_set, batch_size=batch_size)

In [None]:
# W&B logger for this model; log_model=True uploads checkpoints only if the logger is passed to the Trainer.
wandb_logger = WandbLogger(project='DeepLearning', name='ConvTAU', log_model=True)

In [None]:
# Record the architecture and every hyper-parameter of this run in the W&B config.
run_config = wandb_logger.experiment.config
run_config['layers'] = str(model_convTAU)
for name, value in vars(params_convTAU).items():
    run_config[name] = value

In [None]:
# Debug configuration: fast_dev_run=1 only runs a single batch as a smoke test, and the
# W&B logger is commented out. Switch to the commented options for a full, logged run.
trainer_convTAU = pl.Trainer(
    max_epochs=params_convTAU.training_epochs,
    accelerator=device.type,
    # logger=wandb_logger,
    # deterministic=True,
    fast_dev_run=1,
    # overfit_batches=1, # to test if the model can learn
    detect_anomaly=True,
)

In [None]:
# Train on the training split, validating on the validation split.
trainer_convTAU.fit(model=model_convTAU, train_dataloaders=training_dataloader, val_dataloaders=validation_dataloader)

In [None]:
# Evaluate on the held-out test split; the returned metrics are kept for later comparison.
results_convTAU = trainer_convTAU.test(model_convTAU, dataloaders=test_dataloader)

## Visualization Example

In [None]:
# Optionally restore trained ConvTAU weights from a Lightning checkpoint.
# Example: "lightning_logs/vo3td60s/checkpoints/epoch=9-step=80000.ckpt"; None skips loading.
checkpoint_path_convTAU = None
if checkpoint_path_convTAU is not None:
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    # map_location='cpu' lets a GPU-trained checkpoint load on a CPU-only runtime.
    checkpoint = torch.load(checkpoint_path_convTAU, map_location='cpu')
    model_convTAU.load_state_dict(checkpoint['state_dict'])
    model_convTAU.eval()

In [None]:
# Predict the reference sequence with ConvTAU in inference mode.
print("Prediction:")
model_convTAU.eval()  # switch off train-only behaviour (e.g. dropout) for inference
with torch.no_grad():
    # squeeze(1) drops axis 1 when it has size 1 (presumably a channel axis — confirm).
    pred_convTAU = model_convTAU.single_prediction(input_frames).detach().squeeze(1)
dataset.visualize_given_frames_as_gif(pred_convTAU)

In [None]:
# Ground truth / prediction / difference, one row per target frame.
n_frames = gt_frames.shape[0]
# squeeze=False keeps `ax` 2-D even for a single frame; 2.8in per row matches the original 28in for 10 rows.
fig, ax = plt.subplots(n_frames, 3, figsize=(14, 2.8 * n_frames), squeeze=False)
ax[0, 0].set_title('Ground Truth')
ax[0, 1].set_title('Prediction')
ax[0, 2].set_title('Difference')

for i in range(n_frames):
    gt_i = gt_frames[i, :, :]
    pred_i = pred_convTAU[i, :, :]
    ax[i, 0].imshow(gt_i, cmap='gray')
    ax[i, 1].imshow(pred_i, cmap='gray')
    ax[i, 2].imshow(gt_i - pred_i, cmap='gray')

# savefig does not create missing directories.
os.makedirs('out', exist_ok=True)
figure_path = 'out/convTAU_comparison.png'
fig.savefig(figure_path)
# wandb.log({"example": wandb.Image(figure_path)})

plt.show()

## ConvTAU Finish

In [None]:
# Close the current W&B run opened by the ConvTAU logger.
wandb.finish()

In [None]:
# !zip -r ../logs lightning_logs/ # Remember to save weights (if needed)

# Extra / Trash

In [None]:
# import gc
# print(torch.cuda.list_gpu_processes())
# gc.collect()
# torch.cuda.empty_cache()