# In [None]:
# import libraries
import os
import matplotlib.pyplot as plt
import torch
import torchvision
from torch.nn import functional as F
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from multiprocessing import Process



import argparse

def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``yes``/``0``.

    ``argparse``'s ``type=bool`` calls ``bool(str)``, which is ``True`` for every
    non-empty string (including ``"False"``), so it can never switch a flag off.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def parse_args(argv=None):
    """Return the training configuration as an ``argparse.Namespace``.

    Parameters
    ----------
    argv: iterable of str, optional
        Command-line style arguments to parse. ``None`` (the default) parses an
        empty list, so only the defaults below are used. ``sys.argv`` is never
        read because inside a notebook it holds the kernel's own arguments.

    Returns
    -------
    argparse.Namespace
        Fields: lr, beta_1, beta_2, batch_size, epochs, use_amp, n_gpus,
        n_hidden_dim.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
    parser.add_argument('--beta_1', type=float, default=0.9, help='decay rate 1')
    parser.add_argument('--beta_2', type=float, default=0.98, help='decay rate 2')
    # The values previously forced after parsing (batch_size=128, use_amp=True,
    # n_gpus=1) are now the defaults, so explicit arguments are no longer clobbered.
    parser.add_argument('--batch_size', default=128, type=int, help='batch size')
    parser.add_argument('--epochs', type=int, default=300, help='number of epochs to train for')
    parser.add_argument('--use_amp', default=True, type=_str2bool, help='mixed-precision training')
    parser.add_argument('--n_gpus', type=int, default=1, help='number of GPUs')
    parser.add_argument('--n_hidden_dim', type=int, default=64, help='number of hidden dim for ConvLSTM layers')

    return parser.parse_args([] if argv is None else list(argv))


opt = parse_args()

import torch.nn as nn
import torch

import socket
import numpy as np
from torchvision import datasets, transforms

# from: https://github.com/edenton/svg/blob/master/data/moving_mnist.py

class MovingMNIST(object):
    """Data handler that creates a bouncing-MNIST dataset on the fly.

    Every item is a float32 NumPy array of shape
    ``(seq_len, image_size, image_size, 1)`` in which ``num_digits`` randomly
    chosen MNIST digits move with an integer velocity and bounce off the frame
    borders. Overlapping digits are summed and then clipped to 1.

    Randomness comes from NumPy's *global* RNG, which ``set_seed`` seeds once
    per instance with the index of the first item requested.
    """

    def __init__(self, train, data_root, seq_len=20, num_digits=2, image_size=64,
                 deterministic=True, digit_size=32):
        """
        Parameters
        ----------
        train: bool
            Use the MNIST training split (True) or the test split (False).
        data_root: str
            Directory MNIST is downloaded to / loaded from.
        seq_len: int
            Number of frames per sequence.
        num_digits: int
            Number of digits drawn into each sequence.
        image_size: int
            Height and width of each square frame.
        deterministic: bool
            If True a digit's velocity is reflected when it hits a wall;
            otherwise a new random velocity is drawn on every bounce.
        digit_size: int
            Side length the MNIST digits are resized to (default 32). Must be
            smaller than ``image_size``.

        Raises
        ------
        ValueError
            If ``digit_size`` is not smaller than ``image_size`` (no valid
            starting position would exist).
        """
        if digit_size >= image_size:
            raise ValueError(
                f'digit_size ({digit_size}) must be smaller than image_size ({image_size})')
        path = data_root
        self.seq_len = seq_len
        self.num_digits = num_digits
        self.image_size = image_size
        self.step_length = 0.1
        self.digit_size = digit_size
        self.deterministic = deterministic
        self.seed_is_set = False  # multi threaded loading
        self.channels = 1

        self.data = datasets.MNIST(
            path,
            train=train,
            download=True,
            transform=transforms.Compose(
                [transforms.Resize(self.digit_size),
                 transforms.ToTensor()]))

        self.N = len(self.data)

    def set_seed(self, seed):
        """Seed NumPy's global RNG the first time this instance is used."""
        if not self.seed_is_set:
            self.seed_is_set = True
            np.random.seed(seed)

    def __len__(self):
        return self.N

    def __getitem__(self, index):
        """Generate one sequence; ``index`` only matters for the first seeding."""
        self.set_seed(index)
        image_size = self.image_size
        digit_size = self.digit_size
        # Exclusive upper bound for a digit's top-left corner so it stays fully in frame.
        limit = image_size - digit_size
        x = np.zeros((self.seq_len,
                      image_size,
                      image_size,
                      self.channels),
                     dtype=np.float32)
        for _ in range(self.num_digits):
            idx = np.random.randint(self.N)
            digit, _label = self.data[idx]
            # Convert to a 2-D array once per digit rather than once per frame.
            digit = digit.numpy().squeeze()

            sx = np.random.randint(limit)
            sy = np.random.randint(limit)
            dx = np.random.randint(-4, 5)
            dy = np.random.randint(-4, 5)
            for t in range(self.seq_len):
                # Clamp into the frame and bounce before drawing this step.
                if sy < 0:
                    sy = 0
                    if self.deterministic:
                        dy = -dy
                    else:
                        dy = np.random.randint(1, 5)
                        dx = np.random.randint(-4, 5)
                elif sy >= limit:
                    sy = limit - 1
                    if self.deterministic:
                        dy = -dy
                    else:
                        dy = np.random.randint(-4, 0)
                        dx = np.random.randint(-4, 5)

                if sx < 0:
                    sx = 0
                    if self.deterministic:
                        dx = -dx
                    else:
                        dx = np.random.randint(1, 5)
                        dy = np.random.randint(-4, 5)
                elif sx >= limit:
                    sx = limit - 1
                    if self.deterministic:
                        dx = -dx
                    else:
                        dx = np.random.randint(-4, 0)
                        dy = np.random.randint(-4, 5)

                x[t, sy:sy + digit_size, sx:sx + digit_size, 0] += digit
                sy += dy
                sx += dx

        # Overlapping digits can exceed 1; keep pixel values in [0, 1].
        x[x > 1] = 1.
        return x



class ConvLSTMCell(nn.Module):
    """One convolutional LSTM cell.

    A single 2-D convolution over the channel-wise concatenation of the input
    and the previous hidden state produces all four gate pre-activations.
    Padding is ``kernel // 2`` per spatial axis, so odd kernels keep the input's
    height and width.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, bias):
        """
        Parameters
        ----------
        input_dim: int
            Channels of the input tensor.
        hidden_dim: int
            Channels of the hidden and cell states.
        kernel_size: (int, int)
            Spatial size of the gate convolution.
        bias: bool
            Whether the gate convolution has a bias term.
        """
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.padding = (kernel_size[0] // 2, kernel_size[1] // 2)
        self.bias = bias

        gate_channels = 4 * hidden_dim
        self.conv = nn.Conv2d(
            in_channels=input_dim + hidden_dim,
            out_channels=gate_channels,
            kernel_size=kernel_size,
            padding=self.padding,
            bias=bias,
        )

    def forward(self, input_tensor, cur_state):
        """Advance the cell one step; returns ``(h_next, c_next)``."""
        h_prev, c_prev = cur_state

        # Input and hidden state are stacked on the channel axis (dim 1).
        gate_logits = self.conv(torch.cat((input_tensor, h_prev), dim=1))
        in_logit, forget_logit, out_logit, cand_logit = gate_logits.chunk(4, dim=1)

        c_next = (torch.sigmoid(forget_logit) * c_prev
                  + torch.sigmoid(in_logit) * torch.tanh(cand_logit))
        h_next = torch.sigmoid(out_logit) * torch.tanh(c_next)
        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        """Return zeroed ``(h, c)`` on the same device as the conv weights."""
        height, width = image_size
        state_shape = (batch_size, self.hidden_dim, height, width)
        target_device = self.conv.weight.device
        return tuple(torch.zeros(state_shape, device=target_device) for _ in range(2))
    

import torch
import torch.nn as nn



class VAEConvLSTM(nn.Module):
    """Sequence-to-sequence ConvLSTM variational autoencoder.

    Architecture:
      * Encoder: two stacked ConvLSTM cells read the input frames.
      * Latent: 1x1 convolutions map the encoder's final hidden state to the
        mean and log-variance of a per-pixel diagonal Gaussian. A sample ``z``
        is drawn with the reparameterization trick in training mode; the mean
        is used in eval mode.
      * Decoder: ``z`` is projected back to ``nf`` channels and unrolled by two
        stacked ConvLSTM cells for ``future_seq`` steps. Each step's output is
        fed back as the next step's input.
      * Head: a ``(1, 3, 3)`` 3D convolution maps the decoder hidden states to
        one channel of per-pixel **logits**. No sigmoid is applied, so the
        output pairs with ``binary_cross_entropy_with_logits``; apply
        ``torch.sigmoid`` to obtain images.
    """

    def __init__(self, nf, in_chan, latent_dim=None):
        """
        Parameters
        ----------
        nf: int
            Hidden channels of every ConvLSTM cell.
        in_chan: int
            Channels of the input frames.
        latent_dim: int, optional
            Channels of the per-pixel latent Gaussian. Defaults to ``nf`` so the
            two-argument constructor call keeps working.
        """
        super(VAEConvLSTM, self).__init__()
        if latent_dim is None:
            latent_dim = nf
        self.latent_dim = latent_dim

        # Encoder (ConvLSTM)
        self.encoder_1_convlstm = ConvLSTMCell(input_dim=in_chan,
                                               hidden_dim=nf,
                                               kernel_size=(3, 3),
                                               bias=True)

        self.encoder_2_convlstm = ConvLSTMCell(input_dim=nf,
                                               hidden_dim=nf,
                                               kernel_size=(3, 3),
                                               bias=True)

        # VAE components: mean and log-variance of q(z | x).
        self.fc_mu = nn.Conv2d(nf, latent_dim, kernel_size=1)
        self.fc_logvar = nn.Conv2d(nf, latent_dim, kernel_size=1)
        # After the first step the decoder's input is its own nf-channel output,
        # so its input width must be nf. Project z to that width.
        self.latent_to_hidden = nn.Conv2d(latent_dim, nf, kernel_size=1)

        # Decoder (ConvLSTM)
        self.decoder_1_convlstm = ConvLSTMCell(input_dim=nf,
                                               hidden_dim=nf,
                                               kernel_size=(3, 3),
                                               bias=True)

        self.decoder_2_convlstm = ConvLSTMCell(input_dim=nf,
                                               hidden_dim=nf,
                                               kernel_size=(3, 3),
                                               bias=True)

        # Decoder head (3D CNN) producing one channel of logits per frame.
        self.decoder_CNN = nn.Conv3d(in_channels=nf,
                                     out_channels=1,
                                     kernel_size=(1, 3, 3),
                                     padding=(0, 1, 1))

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` in training mode; return ``mu`` in eval mode."""
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        return mu + torch.randn_like(std) * std

    def autoencoder(self, x, seq_len, future_step, h_t, c_t, h_t2, c_t2, h_t3, c_t3, h_t4, c_t4):
        """Run encoder, latent sampling and decoder.

        ``x`` is indexed as ``x[:, t]`` to get one frame per step, i.e. it is
        expected in ``(b, t, c, h, w)`` layout. The remaining arguments are the
        initial ``(h, c)`` states of the four ConvLSTM cells.

        Returns
        -------
        (logits, mu, logvar)
            ``logits`` has shape ``(b, 1, future_step, h, w)``. ``mu`` and
            ``logvar`` have shape ``(b, latent_dim, h, w)``.
        """
        # encoder
        for t in range(seq_len):
            h_t, c_t = self.encoder_1_convlstm(input_tensor=x[:, t, :, :],
                                               cur_state=[h_t, c_t])
            h_t2, c_t2 = self.encoder_2_convlstm(input_tensor=h_t,
                                                 cur_state=[h_t2, c_t2])

        # latent distribution from the final encoder hidden state
        mu = self.fc_mu(h_t2)
        logvar = self.fc_logvar(h_t2)
        z = self.reparameterize(mu, logvar)
        decoder_input = self.latent_to_hidden(z)

        # decoder
        outputs = []
        for _ in range(future_step):
            h_t3, c_t3 = self.decoder_1_convlstm(input_tensor=decoder_input,
                                                 cur_state=[h_t3, c_t3])
            h_t4, c_t4 = self.decoder_2_convlstm(input_tensor=h_t3,
                                                 cur_state=[h_t4, c_t4])
            decoder_input = h_t4
            outputs.append(h_t4)

        # (b, t, nf, h, w) -> (b, nf, t, h, w) as Conv3d expects channels before time.
        outputs = torch.stack(outputs, 1).permute(0, 2, 1, 3, 4)
        logits = self.decoder_CNN(outputs)
        return logits, mu, logvar

    def forward(self, x, future_seq=0, hidden_state=None):
        """
        Parameters
        ----------
        x:
            5-D tensor of shape (b, t, c, h, w)   # batch, time, channel, height, width
        future_seq: int
            Number of future frames to predict. Must be >= 1.
        hidden_state:
            Unused. All ConvLSTM states start at zero.

        Returns
        -------
        (logits, mu, logvar), as documented on ``autoencoder``.

        Raises
        ------
        ValueError
            If ``future_seq`` is smaller than 1.
        """
        if future_seq < 1:
            raise ValueError(f'future_seq must be >= 1, got {future_seq}')

        b, seq_len, _, h, w = x.size()

        # initialize hidden states
        h_t, c_t = self.encoder_1_convlstm.init_hidden(batch_size=b, image_size=(h, w))
        h_t2, c_t2 = self.encoder_2_convlstm.init_hidden(batch_size=b, image_size=(h, w))
        h_t3, c_t3 = self.decoder_1_convlstm.init_hidden(batch_size=b, image_size=(h, w))
        h_t4, c_t4 = self.decoder_2_convlstm.init_hidden(batch_size=b, image_size=(h, w))

        return self.autoencoder(x, seq_len, future_seq, h_t, c_t, h_t2, c_t2, h_t3, c_t3, h_t4, c_t4)
    



##########################
######### MODEL ##########
##########################

class MovingMNISTVAE(pl.LightningModule):
    """Lightning module that trains a ConvLSTM VAE on Moving MNIST.

    Batches from ``MovingMNIST`` hold ``n_steps_past + n_steps_ahead`` frames
    in ``(batch, time, height, width, channel)`` layout. The first
    ``n_steps_past`` frames are the input. The wrapped network predicts the
    remaining ``n_steps_ahead`` frames.

    Hyper-parameters are read from the module-level ``opt`` namespace.
    """

    def __init__(self, hparams=None, model=None):
        """
        Parameters
        ----------
        hparams:
            Unused; kept for backward compatibility.
        model: nn.Module, optional
            Network to train. When ``None``, a ``VAEConvLSTM`` with
            ``opt.n_hidden_dim`` hidden channels and one input channel is built.
        """
        super(MovingMNISTVAE, self).__init__()
        if model is None:
            model = VAEConvLSTM(nf=opt.n_hidden_dim, in_chan=1)
        self.vae_convlstm = model

        # default config
        self.path = os.path.join(os.getcwd(), 'data')

        # logging config
        self.log_images = True

        # Training config
        self.criterion = torch.nn.MSELoss()  # used as the test metric on probabilities
        self.batch_size = opt.batch_size
        self.n_steps_past = 10
        self.n_steps_ahead = 10

    def _split_batch(self, batch):
        """Split a raw batch into input ``x`` and target ``y``.

        ``x`` keeps the ``(b, t, h, w, c)`` layout. ``y`` has only its trailing
        (size-1) channel axis removed, giving ``(b, t, h, w)``. Squeezing just
        that axis keeps the batch dimension when the batch size is 1.
        """
        x = batch[:, :self.n_steps_past]
        y = batch[:, self.n_steps_past:].squeeze(-1)
        return x, y

    def create_video(self, x, y_hat, y):
        """Build an image grid comparing inputs, predictions and targets.

        Parameters
        ----------
        x: Tensor (b, t_past, c, h, w)
            Input frames, channel-first.
        y_hat: Tensor (b, t_ahead, h, w)
            Predicted frames (probabilities in [0, 1]).
        y: Tensor (b, t_ahead, h, w)
            Ground-truth frames.

        Returns
        -------
        A ``make_grid`` tensor for the first sample of the batch, with three
        rows: inputs then predictions, inputs then ground truth, and blank
        frames then the squared error.
        """
        # Detach and cast so plotting never touches the autograd graph or fp16 tensors.
        x = x.detach().float().cpu()
        y_hat = y_hat.detach().float().cpu()
        y = y.detach().float().cpu()

        # predictions with input for illustration purposes
        preds = torch.cat([x, y_hat.unsqueeze(2)], dim=1)[0]

        # entire input and ground truth
        y_plot = torch.cat([x, y.unsqueeze(2)], dim=1)[0]

        # squared-error plot, padded with one blank frame per input step so it
        # lines up with the prediction frames above it
        difference = torch.pow(y_hat[0] - y[0], 2)
        blank = torch.zeros(x.shape[1], *difference.shape[1:])
        difference_plot = torch.cat([blank, difference], dim=0).unsqueeze(1)

        final_image = torch.cat([preds, y_plot, difference_plot], dim=0)
        return torchvision.utils.make_grid(final_image, nrow=self.n_steps_past + self.n_steps_ahead)

    def forward(self, x):
        """Run the wrapped network on channel-last input ``x`` (b, t, h, w, c).

        Returns whatever the network returns for ``n_steps_ahead`` future
        frames. For ``VAEConvLSTM`` this is expected to be ``(logits, mu, logvar)``.
        Device placement is left to Lightning.
        """
        # ConvLSTM cells expect (b, t, c, h, w).
        x = x.permute(0, 1, 4, 2, 3)
        return self.vae_convlstm(x, self.n_steps_ahead)

    def loss_function(self, recon_x, x, mu, logvar):
        """VAE loss: reconstruction BCE plus KL divergence to a standard normal.

        ``recon_x`` are logits (no sigmoid applied). ``x`` are targets in [0, 1].
        Both terms are summed over all elements. ``binary_cross_entropy_with_logits``
        is used rather than ``binary_cross_entropy`` because it is numerically
        stable and safe under mixed-precision autocast. For 3-channel RGB data an
        MSE or L1 reconstruction term may be more appropriate.
        """
        bce = F.binary_cross_entropy_with_logits(recon_x, x, reduction='sum')
        # Compute the KL term in float32 so exp(logvar) cannot overflow in fp16.
        mu = mu.float()
        logvar = logvar.float()
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return bce + kld

    def training_step(self, batch, batch_idx):
        """Compute the VAE loss for one batch, then log the loss, learning rate and images."""
        x, y = self._split_batch(batch)
        y_hat, mu, logvar = self(x)
        # (b, 1, t, h, w) -> (b, t, h, w) to match y; squeeze only the channel axis.
        y_hat = y_hat.squeeze(1)
        loss = self.loss_function(y_hat, y, mu, logvar)

        lr_saved = self.trainer.optimizers[0].param_groups[-1]['lr']
        self.log('train_loss', loss, prog_bar=True)
        self.log('learning_rate', lr_saved)

        # save predicted images every 250 global_step
        if self.log_images and self.global_step % 250 == 0:
            self._log_predictions(x, y_hat, y)

        return loss

    def _log_predictions(self, x, y_hat, y):
        """Send a prediction grid to the logger when it supports ``add_image``."""
        logger = self.logger
        if logger is None:
            return
        experiment = logger.experiment
        if not hasattr(experiment, 'add_image'):
            return
        # x arrives channel-last; the model outputs logits, so convert to probabilities for display.
        grid = self.create_video(x.permute(0, 1, 4, 2, 3), torch.sigmoid(y_hat), y)
        experiment.add_image(
            'epoch_' + str(self.current_epoch) + '_step' + str(self.global_step) + '_generated_images',
            grid, 0)

    def test_step(self, batch, batch_idx):
        """Log the MSE between predicted probabilities and target frames."""
        x, y = self._split_batch(batch)
        y_hat, _, _ = self(x)
        y_hat = torch.sigmoid(y_hat.squeeze(1))
        test_loss = self.criterion(y_hat.float(), y.float())
        self.log('test_loss', test_loss)
        return {'test_loss': test_loss}

    def test_end(self, outputs):
        """Average ``test_step`` results (legacy hook; recent Lightning versions do not call it)."""
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        tensorboard_logs = {'test_loss': avg_loss}
        return {'avg_test_loss': avg_loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        """Adam with learning rate and betas taken from ``opt``."""
        return torch.optim.Adam(self.parameters(), lr=opt.lr, betas=(opt.beta_1, opt.beta_2))

    def _make_dataloader(self, train, shuffle):
        """Build a Moving MNIST loader for one split (shared by the train and test loaders)."""
        data = MovingMNIST(
            train=train,
            data_root=self.path,
            seq_len=self.n_steps_past + self.n_steps_ahead,
            image_size=64,
            deterministic=True,
            num_digits=2)

        return torch.utils.data.DataLoader(
            dataset=data,
            batch_size=self.batch_size,
            shuffle=shuffle)

    def train_dataloader(self):
        return self._make_dataloader(train=True, shuffle=True)

    def test_dataloader(self):
        # Order does not affect the averaged metric, and Lightning warns about shuffled test loaders.
        return self._make_dataloader(train=False, shuffle=False)



def run_trainer():
    """Build the ConvLSTM VAE and train it on Moving MNIST using the settings in ``opt``."""
    vae_model = VAEConvLSTM(nf=opt.n_hidden_dim, in_chan=1)
    model = MovingMNISTVAE(model=vae_model)

    trainer = Trainer(max_epochs=opt.epochs,
                      devices=opt.n_gpus,
                      accelerator="gpu",
                      # strategy='ddp' can be added for multi-GPU training
                      enable_checkpointing=False,
                      precision=16 if opt.use_amp else 32
                      )

    trainer.fit(model)


# In a notebook __name__ is '__main__', so this still runs there, but importing
# the module elsewhere no longer starts training.
if __name__ == '__main__':
    run_trainer()