In [1]:
!pip install easydict
!pip install wandb



In [13]:
import secrets

import easydict
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
from torch.cuda.amp import autocast, GradScaler
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from tqdm import tqdm
import wandb

In [None]:
# import os
# from kaggle_secrets import UserSecretsClient

In [7]:
# # Get Wandb API key from Kaggle Secrets
# user_secrets = UserSecretsClient()
# wandb_api_key = user_secrets.get_secret("wandb_api_key")


In [None]:
# Initialize wandb
wandb.login()  # This will now use the API key from the environment variable
wandb.init(project="convlstmvae-moving-mnist", entity="ryukijano")

In [14]:
import torch
from torchvision import models

#loading a pre-trained ResNet and modifying it
resnet = models.resnet18(pretrained=True)
resnet.fc = nn.Identity() #removing the final fully connected layer.



In [15]:
from __future__ import print_function

import codecs
import errno
import os
import os.path

import numpy as np
import torch
import torch.utils.data as data
from PIL import Image


class MovingMNIST(data.Dataset):
    """`MovingMNIST <http://www.cs.toronto.edu/~nitish/unsupervised_video/>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``processed/training.pt``
            and  ``processed/test.pt`` exist.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        split (int, optional): Train/test split size. Number defines how many samples
            belong to test set.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in an PIL
            image and returns a transformed version. E.g, ``transforms.RandomCrop``
    """

    urls = ["https://github.com/tychovdo/MovingMNIST/raw/master/mnist_test_seq.npy.gz"]
    raw_folder = "raw"
    processed_folder = "processed"
    training_file = "moving_mnist_train.pt"
    test_file = "moving_mnist_test.pt"

    def __init__(
        self,
        root,
        train=True,
        split=1000,
        transform=None,
        target_transform=None,
        download=False,
    ):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.split = split
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError(
                "Dataset not found." + " You can use download=True to download it"
            )

        if self.train:
            self.train_data = torch.load(
                os.path.join(self.root, self.processed_folder, self.training_file)
            )
        else:
            self.test_data = torch.load(
                os.path.join(self.root, self.processed_folder, self.test_file)
            )

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (seq, target) where sampled sequences are splitted into a seq
                    and target part
        """

        # need to iterate over time
        def _transform_time(data):
            new_data = None
            for i in range(data.size(0)):
                img = Image.fromarray(data[i].numpy(), mode="L")
                new_data = (
                    self.transform(img)
                    if new_data is None
                    else torch.cat([self.transform(img), new_data], dim=0)
                )
            return new_data

        if self.train:
            seq, target = self.train_data[index, :10], self.train_data[index, 10:]
        else:
            seq, target = self.test_data[index, :10], self.test_data[index, 10:]

        if self.transform is not None:
            seq = _transform_time(seq)
        if self.target_transform is not None:
            target = _transform_time(target)

        return seq, target

    def __len__(self):
        if self.train:
            return len(self.train_data)
        else:
            return len(self.test_data)

    def _check_exists(self):
        return os.path.exists(
            os.path.join(self.root, self.processed_folder, self.training_file)
        ) and os.path.exists(
            os.path.join(self.root, self.processed_folder, self.test_file)
        )

    def download(self):
        """Download the Moving MNIST data if it doesn't exist in processed_folder already."""
        import gzip

        from six.moves import urllib

        if self._check_exists():
            return

        # download files
        try:
            os.makedirs(os.path.join(self.root, self.raw_folder))
            os.makedirs(os.path.join(self.root, self.processed_folder))
        except OSError as e:
            if e.errno == errno.EEXIST:
                pass
            else:
                raise

        for url in self.urls:
            print("Downloading " + url)
            data = urllib.request.urlopen(url)
            filename = url.rpartition("/")[2]
            file_path = os.path.join(self.root, self.raw_folder, filename)
            with open(file_path, "wb") as f:
                f.write(data.read())
            with open(file_path.replace(".gz", ""), "wb") as out_f, gzip.GzipFile(
                file_path
            ) as zip_f:
                out_f.write(zip_f.read())
            os.unlink(file_path)

        # process and save as torch files
        print("Processing...")

        training_set = torch.from_numpy(
            np.load(
                os.path.join(self.root, self.raw_folder, "mnist_test_seq.npy")
            ).swapaxes(0, 1)[: -self.split]
        )
        test_set = torch.from_numpy(
            np.load(
                os.path.join(self.root, self.raw_folder, "mnist_test_seq.npy")
            ).swapaxes(0, 1)[-self.split :]
        )

        with open(
            os.path.join(self.root, self.processed_folder, self.training_file), "wb"
        ) as f:
            torch.save(training_set, f)
        with open(
            os.path.join(self.root, self.processed_folder, self.test_file), "wb"
        ) as f:
            torch.save(test_set, f)

        print("Done!")

    def __repr__(self):
        fmt_str = "Dataset " + self.__class__.__name__ + "\n"
        fmt_str += "    Number of datapoints: {}\n".format(self.__len__())
        tmp = "train" if self.train is True else "test"
        fmt_str += "    Train/test: {}\n".format(tmp)
        fmt_str += "    Root Location: {}\n".format(self.root)
        tmp = "    Transforms (if any): "
        fmt_str += "{0}{1}\n".format(
            tmp, self.transform.__repr__().replace("\n", "\n" + " " * len(tmp))
        )
        tmp = "    Target Transforms (if any): "
        fmt_str += "{0}{1}".format(
            tmp, self.target_transform.__repr__().replace("\n", "\n" + " " * len(tmp))
        )
        return fmt_str


In [16]:
import torch
from torch import nn
from torch.nn import functional as F

class ConvLSTMCell(nn.Module):
    def __init__(self, input_channels, hidden_channels, kernel_size):
        super(ConvLSTMCell, self).__init__()
        
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.padding = kernel_size // 2
        
        self.conv = nn.Conv2d(
            in_channels=self.input_channels + self.hidden_channels,
            out_channels=4 * self.hidden_channels,
            kernel_size=self.kernel_size,
            padding=self.padding,
            bias=True
        )

    def forward(self, x, h, c):
        combined = torch.cat([x, h], dim=1)
        combined_conv = self.conv(combined)
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_channels, dim=1)
        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)
        
        c_next = f * c + i * g
        h_next = o * torch.tanh(c_next)
        
        return h_next, c_next

class ConvLSTMEncoder(nn.Module):
    def __init__(self, input_channels, hidden_channels, kernel_size):
        super(ConvLSTMEncoder, self).__init__()
        self.convlstm = ConvLSTMCell(input_channels, hidden_channels, kernel_size)
    
    def forward(self, x):
        batch_size, seq_len, channels, height, width = x.size()
        h = torch.zeros(batch_size, self.convlstm.hidden_channels, height, width).to(x.device)
        c = torch.zeros(batch_size, self.convlstm.hidden_channe7ls, height, width).to(x.device)
        7
        for t in range(seq_len):
            h, c = self.convlstm(x[:, t, :, :, :], h, c)
        
        return h, c

class ConvLSTMDecoder(nn.Module):
    def __init__(self, input_channels, hidden_channels, output_channels, kernel_size):
        super(ConvLSTMDecoder, self).__init__()
        self.convlstm = ConvLSTMCell(input_channels, hidden_channels, kernel_size)
        self.conv_out = nn.Conv2d(hidden_channels, output_channels, kernel_size=3, padding=1)
    
    def forward(self, x, h, c, seq_len):
        outputs = []
        
        for _ in range(seq_len):
            h, c = self.convlstm(x, h, c)
            output = self.conv_out(h)
            outputs.append(output)
            x = output
        
        return torch.stack(outputs, dim=1)

#Defining CONVLSTMVAE with a Pre-trained Encoder
class CONVLSTMVAE(nn.Module):
    def __init__(self, encoder, input_channels, hidden_channels, latent_size, kernel_size=3):
        super(CONVLSTMVAE, self).__init__()

        self.encoder = encoder
        self.encoder = ConvLSTMEncoder(input_channels, hidden_channels, kernel_size)
        self.decoder = ConvLSTMDecoder(input_channels, hidden_channels, 1, kernel_size)
        
        self.fc_mu = nn.Linear(hidden_channels * 64 * 64, latent_size)  # Assuming 64x64 spatial dimensions
        self.fc_logvar = nn.Linear(hidden_channels * 64 * 64, latent_size)
        self.fc_decode = nn.Linear(latent_size, hidden_channels * 64 * 64)
        
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def forward(self, x):
        batch_size, seq_len, channels, height, width = x.size()
        
        # Encode
        cnn_features = []
        for t in range(seq_len):
            cnn_output = self.encoder(x[:, t, :, :, :])
            cnn_features.append(cnn_output)
        cnn_features = torch.stack(cnn_features, dim=1)

        h,c = self.convlstm_encoder(cnn_features)
        h_flat = h.reshape(batch_size, -1)
        
        # VAE bottleneck
        mu = self.fc_mu(h_flat)
        logvar = self.fc_logvar(h_flat)
        z = self.reparameterize(mu, logvar)
        
        # Decode
        h_decoded = self.fc_decode(z).view(batch_size, -1, height, width)
        c_decoded = torch.zeros_like(h_decoded)
        x_decoded = torch.zeros(batch_size, channels, height, width).to(x.device)
        
        output = self.decoder(x_decoded, h_decoded, c_decoded, seq_len)
        
        return output, mu, logvar
    
    def loss_function(self, recon_x, x, mu, logvar):
        BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return BCE + KLD

# # Visualization function
# def visualize_reconstructions(model, val_loader, device, epoch):
#     model.eval()
#     with torch.no_grad():
#         for batch in val_loader:
#             x, _ = batch
#             x = x.to(device)
#             recon_x, _, _ = model(x)
#             break
#     x = x.cpu().numpy()
#     recon_x = recon_x.cpu().numpy()

#     fig, axes = plt.subplots(2, 10, figsize=(20, 4))
#     for i in range(10):
#         axes[0, i].imshow(x[i, 0, 0], cmap='gray')
#         axes[0, i].axis('off')
#         axes[1, i].imshow(recon_x[i, 0, 0], cmap='gray')
#         axes[1, i].axis('off')
#     plt.suptitle(f"Epoch {epoch}: Original (top) vs Reconstructed (bottom)")
#     plt.savefig(f"reconstruction_epoch_{epoch}.png")
#     plt.close()

# # Training function
# def train(model, train_loader, val_loader, optimizer, epochs, device):
#     scaler = GradScaler()
#     model.to(device)
    
#     for epoch in range(epochs):
#         model.train()
#         total_loss = 0
#         for batch in train_loader:
#             x, _ = batch
#             x = x.to(device)
#             optimizer.zero_grad()
#             with autocast():
#                 recon_x, mu, logvar = model(x)
#                 loss = model.loss_function(recon_x, x, mu, logvar)
#             scaler.scale(loss).backward()
#             scaler.step(optimizer)
#             scaler.update()
#             total_loss += loss.item()
        
#         avg_loss = total_loss / len(train_loader.dataset)
#         print(f"Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}")
        
#         if (epoch + 1) % 10 == 0:
#             visualize_reconstructions(model, val_loader, device, epoch)

# # Main function
# def main():
#     # Data preparation
#     transform = transforms.Compose([
#         transforms.ToTensor(),
#         transforms.Normalize((0.5,), (0.5,))  # Normalize to [-1, 1]
#     ])
#     train_dataset = MovingMNIST(root='./data', train=True, transform=transform, download=True)
#     val_dataset = MovingMNIST(root='./data', train=False, transform=transform, download=True)
    
#     def collate_fn(batch):
#         seqs, targets = zip(*batch)  # Separating sequences and targets
#         seqs = torch.stack(seqs).unsqueeze(2).permute(0, 1, 2, 4, 3)  # Adjust dimensions
#         targets = torch.stack(targets).unsqueeze(2).permute(0, 1, 2, 4, 3)  # Adjust dimensions
#         return seqs, targets
    
#     train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, collate_fn=collate_fn)
#     val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False, collate_fn=collate_fn)

# Example usage
input_channels = 1  # 3 For RGB videos
hidden_channels = 64
latent_size = 128
kernel_size = 3
model = CONVLSTMVAE(resnet, input_channels, hidden_channels, latent_size)



# # Assuming input shape: (batch_size, sequence_length, channels, height, width)
# sample_input = torch.randn(16, 10, 3, 64, 64)
# output, mu, logvar = model(sample_input)

# print(f"Input shape: {sample_input.shape}")
# print(f"Output shape: {output.shape}")
# print(f"Mu shape: {mu.shape}")
# print(f"Logvar shape: {logvar.shape}")

In [17]:


import secrets

import easydict
import matplotlib.pyplot as plt
import torch
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from tqdm import tqdm

writer = SummaryWriter()

## visualization
def imshow(past_data, title="MovingMNIST"):
    num_img = len(past_data)
    fig = plt.figure(figsize=(4 * num_img, 4))

    for idx in range(1, num_img + 1):
        ax = fig.add_subplot(1, num_img + 1, idx)
        ax.imshow(past_data[idx - 1])
    plt.suptitle(title, fontsize=30)
    plt.savefig(f"{title}")
    plt.close()

def visualize_reconstructions(model, test_loader, device, epoch):
    model.eval()
    with torch.no_grad():
        for i, batch_data in enumerate(test_loader):
            future_data, past_data = batch_data
            batch_size = past_data.size(0)
            example_size = past_data.size(1)
            image_size = past_data.size(2), past_data.size(3)
            #past_data = past_data.view(batch_size, example_size, -1).float().to(device)
            #reshape for convvaelstm
            past_data = past_data.reshape(batch_size, example_size, -1).float().to(device)

            
            _, recon_x, _ = model(past_data)
            
            if i == 0:
                n_examples = min(10, batch_size)
                #examples = past_data[:n_examples].cpu().view(n_examples, example_size, image_size[0], -1)
                #reshape for convvaelstm
                examples = past_data[:n_examples].cpu().reshape(n_examples, example_size, image_size[0], -1)
                #recon_examples = recon_x[:n_examples].cpu().view(n_examples, example_size, image_size[0], -1)
                recon_examples = recon_x[:n_examples].cpu().reshape(n_examples, example_size, image_size[0], -1)

                fig, axes = plt.subplots(2, n_examples, figsize=(20, 4))
                for j in range(n_examples):
                    axes[0, j].imshow(examples[j, 0], cmap='gray')
                    axes[0, j].axis('off')
                    axes[1, j].imshow(recon_examples[j, 0], cmap='gray')
                    axes[1, j].axis('off')
                plt.suptitle(f"Epoch {epoch}: Original (top) vs Reconstructed (bottom)")
                plt.savefig(f"reconstruction_epoch_{epoch}.png")
                wandb.log({"reconstructions": wandb.Image(plt)})
                plt.close()
                break

def train(args, model, train_loader, test_loader):
    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    
    #creating a GradScaler for automatic mixed precision
    scaler = GradScaler()

    ## interation setup
    epochs = tqdm(range(args.max_iter // len(train_loader) + 1))

    ## training
    count = 0
    for epoch in epochs:
        model.train()
        optimizer.zero_grad()
        train_iterator = tqdm(
            enumerate(train_loader), total=len(train_loader), desc="training"
        )

        for i, batch_data in train_iterator:

            if count > args.max_iter:
                return model
            count += 1

            future_data, past_data = batch_data
            
            '''
            NO RESHAPE REQUIRED FOR INTO 2D for the CONVLSTM
            ## reshape
            batch_size = past_data.size(0)
            example_size = past_data.size(1)
            image_size = past_data.size(2), past_data.size(3)
            past_data = (
                past_data.view(batch_size, example_size, -1).float().to(args.device)
            )
            '''
            
            past_data = past_data.float().to(args.device).half()

            #casting operations to mixed precision
            with autocast():
                mloss, recon_x, info = model(past_data)
            
            #scaling the loss for better gradient flow
            scaler.scale(mloss.mean()).backward()
            
            #unscales the gradients before optimization
            scaler.step(optimizer)
            
            #Update the GradScaler state
            scaler.update()


            train_iterator.set_postfix({"train_loss": float(mloss.mean())})
        writer.add_scalar("train_loss", float(mloss.mean()), epoch)
        wandb.log({"train_loss": float(mloss.mean())})
        
        model.eval()
        eval_loss = 0
        test_iterator = tqdm(
            enumerate(test_loader), total=len(test_loader), desc="testing"
        )

        with torch.no_grad():
            for i, batch_data in test_iterator:
                future_data, past_data = batch_data
                '''No RESHAPE REQUIRED FOR CONVLSTMVAE
                ## reshape
                batch_size = past_data.size(0)
                example_size = past_data.size(1)
                past_data = (
                    past_data.view(batch_size, example_size, -1).float().to(args.device)
                )
                '''
                past_data = past_data.float().to(args.device).half()

                #using autocast for validation
                with autocast():
                    mloss, recon_x, info = model(past_data)

                eval_loss += mloss.mean().item()

                test_iterator.set_postfix({"eval_loss": float(mloss.mean())})

        eval_loss = eval_loss / len(test_loader)
        writer.add_scalar("eval_loss", float(eval_loss), epoch)
        wandb.log({"eval_loss": float(eval_loss)})
        print("Evaluation Score : [{}]".format(eval_loss))

        # Visualize reconstructions every 10 epochs
        if epoch % 1 == 0:
            visualize_reconstructions(model, test_loader, args.device, epoch)

    return model


if __name__ == "__main__":

    # training dataset
    train_set = MovingMNIST(
        root=".data/mnist",
        train=True,
        download=True,
        transform=transforms.ToTensor(),
        target_transform=transforms.ToTensor(),
    )

    # test dataset
    test_set = MovingMNIST(
        root=".data/mnist",
        train=False,
        download=True,
        transform=transforms.ToTensor(),
        target_transform=transforms.ToTensor(),
    )

    args = easydict.EasyDict(
        {
            "batch_size": 128,
            "device": torch.device("cuda")
            if torch.cuda.is_available()
            else torch.device("cpu"),
            "input_size": 4096,
            "hidden_size": 2048,
            "latent_size": 1024,
            "learning_rate": 0.001,
            "max_iter": 1000,
        }
    )

    batch_size = args.batch_size
    input_size = args.input_size
    hidden_size = args.hidden_size
    latent_size = args.latent_size

    # define LSTM-based VAE model
    #model = LSTMVAE(input_size, hidden_size, latent_size, device=args.device)
    #model.to(args.device)
    input_channels = 1 #since grayscale
    hidden_channels = 64
    latent_size = 128
    kernel_size = 3
    model = CONVLSTMVAE(input_channels, hidden_channels, latent_size, kernel_size)
    model.to(args.device)
    

    '''Change required for CONVLSTMVAE'''
    
    def collate_fn(batch):
        seqs, targets = zip(*batch) #Separating sequences and targets
        seqs = torch.stack(seqs).unsqueeze(2).permute(0,1,2,4,3) #Adjust dimensions for sequences. Unsqueeze(2) for CONVLSTM to add an extra dimension for the channel before reordering the dimensions with the permute function.
        targets = torch.stack(targets).unsqueeze(2).permute(0,1,2,4,3) #Adjust dimensions for sequences. Unsqueeze(2) for CONVLSTM to add an extra dimension for the channel before reordering the dimensions with the permute function.
        return seqs, targets
    
    #convert to format of data loader
    train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=args.batch_size,
        shuffle=True,
        collate_fn=collate_fn
    )
    
    test_loader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=collate_fn
    )
    
#     # convert to format of data loader
#     train_loader = torch.utils.data.DataLoader(
#         dataset=train_set, batch_size=args.batch_size, shuffle=True
#     )
#     test_loader = torch.utils.data.DataLoader(
#         dataset=test_set, batch_size=args.batch_size, shuffle=False
#     )

    # training
    trained_model = train(args, model, train_loader, test_loader)

    # save model
    id_ = secrets.token_hex(nbytes=4)
    model_path = f"lstmvae{id_}.model"
    torch.save(trained_model.state_dict(), f"lstmvae{id_}.model")
    wandb.save(model_path)

    # load model
    model_to_load = LSTMVAE(input_size, hidden_size, latent_size, device=args.device)
    model_to_load.to(args.device)
    model_to_load.load_state_dict(torch.load(f"lstmvae{id_}.model"))
    model_to_load.eval()

    # show results
    ## past_data, future_data -> shape: (10,10)
    future_data, past_data = train_set[0]

    ## reshape
    example_size = past_data.size(0)
    image_size = past_data.size(1), past_data.size(2)
    past_data = past_data.view(example_size, -1).float().to(args.device)
    _, recon_data, info = model_to_load(past_data.unsqueeze(0))

    nhw_orig = past_data.view(example_size, image_size[0], -1).cpu()
    nhw_recon = (
        recon_data.squeeze(0)
        .view(example_size, image_size[0], -1)
        .detach()
        .cpu()
        .numpy()
    )

    imshow(nhw_orig, title=f"final_input{id_}")
    imshow(nhw_recon, title=f"final_output{id_}")
    plt.show()
    wandb.finish()

Downloading https://github.com/tychovdo/MovingMNIST/raw/master/mnist_test_seq.npy.gz
Processing...
Done!


training:   0%|          | 0/71 [00:00<?, ?it/s]
  0%|          | 0/15 [00:00<?, ?it/s]


ValueError: not enough values to unpack (expected 5, got 4)