In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.datasets.utils import download_url

import io
import imageio
from ipywidgets import widgets, HBox
import wandb

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset
import numpy as np
import io
import imageio
import matplotlib.pyplot as plt
from IPython.display import Image, display
from ipywidgets import widgets, Layout, HBox
from torch.utils.data import TensorDataset, DataLoader
from torchvision.datasets.utils import download_url
import random

# Run on the current CUDA device when one is visible, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# Authenticate with Weights & Biases; the training cell below calls wandb.init / wandb.log.
# Keep API keys out of the notebook source (never pass a key literal here).
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: ryukijano (hack-the-thong). Use `wandb login --relogin` to force relogin


True

In [3]:
class ConvLSTMCell(nn.Module):
    """A single ConvLSTM cell with peephole (Hadamard-product) connections.

    Args:
        in_channels: channels of the input frame ``X``.
        out_channels: channels of the hidden state ``H`` and cell state ``C``.
        kernel_size, padding: passed to the shared gate convolution. The padding
            must preserve the spatial size so the gate maps line up with the
            ``frame_size``-shaped peephole weights.
        activation: ``"tanh"`` or ``"relu"``; applied to the candidate cell
            update and to ``C`` when producing ``H``.
        frame_size: ``(height, width)`` of the state; sizes the peephole weights.

    Raises:
        ValueError: if ``activation`` is not one of the supported names.
    """

    def __init__(self, in_channels, out_channels,
                 kernel_size, padding, activation, frame_size):

        super(ConvLSTMCell, self).__init__()

        if activation == "tanh":
            self.activation = torch.tanh
        elif activation == "relu":
            self.activation = torch.relu
        else:
            # Fail fast instead of an AttributeError on the first forward pass.
            raise ValueError(
                f"activation must be 'tanh' or 'relu', got {activation!r}")

        # Idea adapted from https://github.com/ndrplz/ConvLSTM_pytorch
        # One convolution produces all four gate pre-activations at once.
        self.conv = nn.Conv2d(
            in_channels=in_channels + out_channels,
            out_channels=4 * out_channels,
            kernel_size=kernel_size,
            padding=padding)

        # Peephole weights for the Hadamard products with the cell state.
        # torch.Tensor(...) would return *uninitialized* memory (arbitrary values,
        # possibly NaN/inf). Start from zero so the cell initially behaves like a
        # plain ConvLSTM and the peepholes are learned from there.
        self.W_ci = nn.Parameter(torch.zeros(out_channels, *frame_size))
        self.W_co = nn.Parameter(torch.zeros(out_channels, *frame_size))
        self.W_cf = nn.Parameter(torch.zeros(out_channels, *frame_size))

    def forward(self, X, H_prev, C_prev):
        """Advance the cell by one time step.

        Args:
            X: input frame, (batch, in_channels, height, width).
            H_prev: previous hidden state, (batch, out_channels, height, width).
            C_prev: previous cell state, same shape as ``H_prev``.

        Returns:
            ``(H, C)``: the new hidden and cell states, shaped like ``H_prev``.
        """
        # Idea adapted from https://github.com/ndrplz/ConvLSTM_pytorch
        conv_output = self.conv(torch.cat([X, H_prev], dim=1))

        # Idea adapted from https://github.com/ndrplz/ConvLSTM_pytorch
        i_conv, f_conv, C_conv, o_conv = torch.chunk(conv_output, chunks=4, dim=1)

        # Input and forget gates peek at the *previous* cell state.
        input_gate = torch.sigmoid(i_conv + self.W_ci * C_prev)
        forget_gate = torch.sigmoid(f_conv + self.W_cf * C_prev)

        # Current Cell output
        C = forget_gate * C_prev + input_gate * self.activation(C_conv)

        # The output gate peeks at the *current* cell state.
        output_gate = torch.sigmoid(o_conv + self.W_co * C)

        # Current Hidden State
        H = output_gate * self.activation(C)

        return H, C

In [4]:
class ConvLSTM(nn.Module):
    """Unroll a single ConvLSTMCell over the time axis of a frame sequence.

    Args:
        in_channels, out_channels, kernel_size, padding, activation, frame_size:
            forwarded unchanged to ``ConvLSTMCell``.
    """

    def __init__(self, in_channels, out_channels,
                 kernel_size, padding, activation, frame_size):

        super(ConvLSTM, self).__init__()

        self.out_channels = out_channels

        # We will unroll this over time steps
        self.convLSTMcell = ConvLSTMCell(in_channels, out_channels,
                                         kernel_size, padding, activation, frame_size)

    def forward(self, X):
        """Run the cell over every time step of ``X``.

        Args:
            X: frame sequence of shape (batch_size, num_channels, seq_len, height, width).

        Returns:
            Hidden state at every step, shape
            (batch_size, out_channels, seq_len, height, width), allocated on
            ``X``'s device with ``X``'s dtype.
        """
        batch_size, _, seq_len, height, width = X.size()

        # Allocate the output buffer and the zero initial states from X itself so
        # they follow the input's device/dtype. The notebook-global `device` may
        # name a different device than the one the model and input actually live on.
        output = X.new_zeros(batch_size, self.out_channels, seq_len, height, width)
        H = X.new_zeros(batch_size, self.out_channels, height, width)
        C = X.new_zeros(batch_size, self.out_channels, height, width)

        # Unroll over time steps, recording the hidden state of each one.
        for time_step in range(seq_len):
            H, C = self.convLSTMcell(X[:, :, time_step], H, C)
            output[:, :, time_step] = H

        return output

In [10]:
class ConvLSTMVAE(nn.Module):
    """Variational autoencoder over frame sequences built from stacked ConvLSTMs.

    The encoder stacks ``num_layers`` ConvLSTM + BatchNorm3d pairs (at least one
    pair is always built), flattens the result and projects it to ``latent_dim``
    means and log-variances. The decoder projects a latent vector back to a
    ``(num_kernels, seq_length, H, W)`` volume, runs ``num_layers`` ConvLSTM +
    BatchNorm3d pairs, and maps to ``num_channels`` with a per-frame Conv3d
    followed by a sigmoid, so reconstructions are probabilities in (0, 1).

    Note: the bridges between conv features and the latent space are dense
    layers over the full num_kernels * H * W * seq_length volume (e.g. with
    num_kernels=64, 64x64 frames and seq_length=10 each bridge has ~335M weights).
    """

    def __init__(self, num_channels, num_kernels, kernel_size, padding,
                 activation, frame_size, num_layers, latent_dim, seq_length):
        super(ConvLSTMVAE, self).__init__()

        self.latent_dim = latent_dim
        self.frame_size = frame_size
        self.seq_length = seq_length

        def _make_convlstm(in_ch):
            # Every ConvLSTM here shares kernel, padding, activation and frame size.
            return ConvLSTM(
                in_channels=in_ch, out_channels=num_kernels,
                kernel_size=kernel_size, padding=padding,
                activation=activation, frame_size=frame_size)

        # Encoder: layer 1 reads the raw channels, deeper layers read num_kernels.
        # max(..., 1) keeps the first pair even when num_layers < 1.
        self.encoder = nn.Sequential()
        for idx in range(1, max(num_layers, 1) + 1):
            layer_in = num_channels if idx == 1 else num_kernels
            self.encoder.add_module(f"convlstm{idx}", _make_convlstm(layer_in))
            self.encoder.add_module(f"batchnorm{idx}", nn.BatchNorm3d(num_features=num_kernels))

        # Latent space: dense projections of the flattened encoder volume.
        flat_dim = num_kernels * frame_size[0] * frame_size[1] * seq_length
        self.fc_mu = nn.Linear(flat_dim, latent_dim)
        self.fc_logvar = nn.Linear(flat_dim, latent_dim)

        # Decoder: latent -> flattened volume -> ConvLSTM stack -> output channels.
        self.decoder_input = nn.Linear(latent_dim, flat_dim)

        self.decoder = nn.Sequential()
        for idx in range(1, num_layers + 1):
            self.decoder.add_module(f"convlstm{idx}", _make_convlstm(num_kernels))
            self.decoder.add_module(f"batchnorm{idx}", nn.BatchNorm3d(num_features=num_kernels))

        # Kernel depth 1 along time: a per-frame spatial conv back to num_channels.
        self.final_conv = nn.Conv3d(
            in_channels=num_kernels, out_channels=num_channels,
            kernel_size=(1, kernel_size, kernel_size), padding=(0, padding, padding))

    def encode(self, x):
        """Map a batch of sequences to the posterior parameters ``(mu, logvar)``."""
        features = self.encoder(x)
        flat = features.reshape(features.size(0), -1)
        return self.fc_mu(flat), self.fc_logvar(flat)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * sigma`` with ``eps ~ N(0, I)`` (reparameterization trick)."""
        std = logvar.mul(0.5).exp()
        noise = torch.randn_like(std)
        return mu + noise * std

    def decode(self, z):
        """Map latent vectors to reconstructed sequences with values in (0, 1)."""
        projected = self.decoder_input(z)
        volume = projected.view(projected.size(0), -1, self.seq_length,
                                self.frame_size[0], self.frame_size[1])
        return torch.sigmoid(self.final_conv(self.decoder(volume)))

    def forward(self, x):
        """Encode, sample a latent code, decode; returns ``(recon_x, mu, logvar)``."""
        mu, logvar = self.encode(x)
        return self.decode(self.reparameterize(mu, logvar)), mu, logvar


Data loaders


In [11]:
# Seed for the sequence shuffle, so the train/val/test split is reproducible across runs.
SPLIT_SEED = 42

# Load Data as Numpy Array. The raw file stores time on axis 0; transpose so that
# axis 0 indexes sequences and each MovingMNIST[i] is one (time, height, width) clip.
url = "http://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy"
fpath = "moving_mnist.npy"
download_url(url, root=".", filename=fpath)
MovingMNIST = np.load(fpath).transpose(1, 0, 2, 3)

# Shuffle whole sequences (first axis) in place with a dedicated, seeded generator
# instead of the unseeded global NumPy RNG.
np.random.default_rng(SPLIT_SEED).shuffle(MovingMNIST)

# Train, Test, Validation splits
train_data = MovingMNIST[:8000]
val_data = MovingMNIST[8000:9000]
test_data = MovingMNIST[9000:10000]

def collate(batch, target_device=None, num_input_frames=10):
    """Turn a list of raw sequences into an ``(inputs, target)`` training pair.

    Args:
        batch: list of uint8 arrays, each one sequence indexed (time, height, width).
        target_device: device to move the tensors to; ``None`` means the
            notebook-global ``device``.
        num_input_frames: number of consecutive input frames; the frame right
            after them is the prediction target.

    Returns:
        ``(inputs, target)`` float tensors scaled to [0, 1], shaped
        (batch, 1, num_input_frames, height, width) and (batch, 1, height, width).

    Raises:
        ValueError: if the sequences are too short to hold the inputs plus a target.
    """
    # Stack into one ndarray first: torch.tensor() on a list of ndarrays is very
    # slow (and warns). Then add the channel dim and scale pixels to [0, 1].
    frames = torch.from_numpy(np.stack(batch)).unsqueeze(1)   # (B, 1, T, H, W)
    frames = frames / 255.0
    frames = frames.to(device if target_device is None else target_device)

    seq_len = frames.shape[2]
    if seq_len <= num_input_frames:
        raise ValueError(
            f"sequences have {seq_len} frames; need more than {num_input_frames}")

    # One random window per batch: `num_input_frames` consecutive inputs, and the
    # frame immediately after them is the target.
    end = np.random.randint(num_input_frames, seq_len)
    return frames[:, :, end - num_input_frames:end], frames[:, :, end]


# Build the training and validation loaders with identical settings: shuffled
# batches of 8 sequences, each turned into an (inputs, target) pair by `collate`.
train_loader, val_loader = (
    DataLoader(split, shuffle=True, batch_size=8, collate_fn=collate)
    for split in (train_data, val_data)
)

Using downloaded and verified file: .\moving_mnist.npy


In [12]:
# Pull just the first training batch to sanity-check the tensor shapes.
input_batch, target_batch = next(iter(train_loader))
print("Input shape:", input_batch.shape)
print("Target shape:", target_batch.shape)

Input shape: torch.Size([8, 1, 10, 64, 64])
Target shape: torch.Size([8, 1, 64, 64])


Instantiating the model, optimizer and loss function

In [13]:

# Moving MNIST frames are grayscale, so the model sees a single input channel.
# frame_size and seq_length must match what `collate` yields (10 input frames of
# 64x64, per the shape check above), because the dense latent layers are sized from them.
model = ConvLSTMVAE(
    num_channels=1,
    num_kernels=64,
    kernel_size=3,
    padding=1,
    activation='relu',
    num_layers=3,
    frame_size=(64, 64),
    seq_length=10,
    latent_dim=128,
).to(device)

# NOTE(review): this rebinds `optim`, shadowing the `torch.optim as optim` module
# alias from the imports cell; the training cell uses `optim` as this optimizer.
optim = Adam(model.parameters(), lr=1e-4)

# Summed BCE-with-logits criterion. NOTE: pixels are scaled to [0, 1], not binarized;
# ConvLSTMVAE.decode already applies a sigmoid; and the training cell defines its
# own `vae_loss` instead of using this object.
criterion = nn.BCEWithLogitsLoss(reduction='sum')

In [14]:
import torch
from tqdm import tqdm
import wandb
from torchvision.utils import make_grid

num_epochs = 50

# Loss scaling is only meaningful for CUDA mixed precision. Enable it explicitly
# on CUDA and disable it otherwise, rather than relying on the scaler's
# warn-and-disable fallback on CPU-only machines.
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

# TODO: revisit the loss function (vae_loss below).
def vae_loss(recon_x, x, mu, logvar):
    """VAE objective: summed reconstruction BCE plus KL divergence to N(0, I).

    Args:
        recon_x: reconstruction *probabilities* in [0, 1]. ConvLSTMVAE.decode
            already applies a sigmoid, so these must not be treated as logits;
            doing so would apply the sigmoid twice.
        x: target tensor with the same shape as ``recon_x``, values in [0, 1].
        mu, logvar: posterior mean and log-variance (as returned by
            ConvLSTMVAE.encode).

    Returns:
        Scalar float32 tensor: BCE + KLD, both summed over all elements.
    """
    # F.binary_cross_entropy raises inside autocast regions, and exp(logvar) can
    # overflow in float16, so compute the whole loss in float32 with autocast off.
    with torch.autocast(device_type=recon_x.device.type, enabled=False):
        recon_x, x = recon_x.float(), x.float()
        mu, logvar = mu.float(), logvar.float()
        BCE = torch.nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

# Initialize wandb
wandb.init(project="conv_vae_lstm_mnist", name="conv_vae_lstm_mnist_default")

# Log model architecture
wandb.watch(model)

# Mixed precision only on CUDA; on other devices autocast is disabled and the
# same code path runs in plain float32.
use_amp = device.type == 'cuda'

try:
    for epoch in range(1, num_epochs + 1):
        # ---- training ----
        train_loss = 0
        model.train()
        for inputs, _ in tqdm(train_loader):
            inputs = inputs.to(device)
            optim.zero_grad()
            with torch.autocast(device_type=device.type, enabled=use_amp):
                recon_batch, mu, logvar = model(inputs)
                loss = vae_loss(recon_batch, inputs, mu, logvar)
            scaler.scale(loss).backward()
            scaler.step(optim)
            scaler.update()
            train_loss += loss.item()
        train_loss /= len(train_loader.dataset)

        # ---- validation ----
        val_loss = 0
        model.eval()
        with torch.no_grad():
            for inputs, _ in tqdm(val_loader):
                inputs = inputs.to(device)
                recon_batch, mu, logvar = model(inputs)
                loss = vae_loss(recon_batch, inputs, mu, logvar)
                val_loss += loss.item()
        val_loss /= len(val_loader.dataset)

        print(f"Epoch:{epoch} Training Loss:{train_loss:.2f} Validation Loss:{val_loss:.2f}\n")

        # Log losses to wandb
        wandb.log({
            "epoch": epoch,
            "train_loss": train_loss,
            "val_loss": val_loss
        })

        # Generate and log sample images
        with torch.no_grad():
            # Decode random latent vectors into sequences (64, C, T, H, W).
            samples = model.decode(torch.randn(64, model.latent_dim, device=device)).cpu()

            # Reconstruct sequences from the validation set
            val_inputs, _ = next(iter(val_loader))
            val_inputs = val_inputs.to(device)
            recon_batch, _, _ = model(val_inputs)

            # make_grid only accepts (B, C, H, W), but these tensors are
            # (B, C, T, H, W) sequences, so show the last frame of each one.
            sample_grid = make_grid(samples[:, :, -1], nrow=8, normalize=True)
            recon_grid = make_grid(
                torch.cat([val_inputs[:, :, -1].cpu(), recon_batch[:, :, -1].cpu()]),
                nrow=val_inputs.size(0), normalize=True)

            # Log images to wandb
            wandb.log({
                "generated_samples": wandb.Image(sample_grid, caption="Generated Samples"),
                "reconstructions": wandb.Image(recon_grid, caption="Reconstructions (Top: Input, Bottom: Reconstructed)")
            })
finally:
    # Close the wandb run even if training is interrupted (e.g. KeyboardInterrupt).
    wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

  1%|          | 10/1000 [02:05<3:27:52, 12.60s/it]


KeyboardInterrupt: 