In [1]:
!pip install easydict
!pip install wandb



In [6]:
import secrets

import easydict
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
from torch.cuda.amp import autocast, GradScaler
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from tqdm import tqdm
import wandb

In [7]:
# # Get Wandb API key from Kaggle Secrets
# user_secrets = UserSecretsClient()
# wandb_api_key = user_secrets.get_secret("wandb_api_key")


In [None]:
# Initialize wandb. Guarded so that re-running this cell reuses the active run
# instead of starting a new, empty one.
wandb.login()  # authenticate with saved credentials / WANDB_API_KEY, or prompt
if wandb.run is None:
    wandb.init(project="convlstmvae-moving-mnist", entity="ryukijano")

In [None]:
# Initialize wandb. This cell repeats the previous one; the guard makes running
# both harmless (only one run is started). Consider deleting one of them.
wandb.login()  # authenticate with saved credentials / WANDB_API_KEY, or prompt
if wandb.run is None:
    wandb.init(project="convlstmvae-moving-mnist", entity="ryukijano")

In [8]:
import torch
from torchvision import models

# Load an ImageNet-pretrained ResNet-18 and replace its classification head with
# an identity, so forward() returns the pooled features instead of class logits.
# `pretrained=True` is deprecated since torchvision 0.13; IMAGENET1K_V1 is the
# weight set it selected, so both branches load the same weights.
try:
    _resnet_kwargs = {"weights": models.ResNet18_Weights.IMAGENET1K_V1}
except AttributeError:  # torchvision < 0.13 has no weight enums
    _resnet_kwargs = {"pretrained": True}
resnet = models.resnet18(**_resnet_kwargs)
resnet.fc = nn.Identity()  # drop the final fully connected layer



In [None]:

import torch
from torch import nn
from torch.nn import functional as F


class Encoder(nn.Module):
    """Uni-directional, batch-first LSTM that summarises a sequence.

    Only the final hidden and cell states are returned; per-step outputs are
    discarded.
    """

    def __init__(self, input_size=4096, hidden_size=1024, num_layers=2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=False,
        )

    def forward(self, x):
        """Run the LSTM over ``x`` and return its final state.

        x: tensor of shape (batch_size, seq_length, input_size).
        Returns (h_n, c_n), each (num_layers, batch_size, hidden_size).
        """
        _, final_state = self.lstm(x)
        return final_state


class Decoder(nn.Module):
    """Batch-first LSTM followed by a per-time-step linear projection."""

    def __init__(
        self, input_size=4096, hidden_size=1024, output_size=4096, num_layers=2
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=False,
        )
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden):
        """Unroll the LSTM from ``hidden`` and project every step.

        x: tensor of shape (batch_size, seq_length, input_size).
        hidden: initial (h_0, c_0) state passed straight to the LSTM.
        Returns (prediction, (h_n, c_n)); prediction is
        (batch_size, seq_length, output_size).
        """
        lstm_out, final_state = self.lstm(x, hidden)
        return self.fc(lstm_out), final_state


class LSTMVAE(nn.Module):
    """LSTM-based Variational Auto Encoder.

    The encoder's final hidden state is mapped to the mean and log-variance of
    a diagonal Gaussian. A sample ``z`` both initialises the decoder's hidden
    state (through ``fc3``) and is repeated as the decoder input at every step.
    """

    def __init__(
        self, input_size, hidden_size, latent_size, device=torch.device("cuda")
    ):
        """
        input_size: int, feature dimension of each time step; inputs are
            (batch_size, sequence_length, input_size)
        hidden_size: int, hidden size of the encoder and decoder LSTMs
        latent_size: int, latent z-layer size
        device: stored as ``self.device`` for backward compatibility; tensors
            created internally follow the device of the input instead
        """
        super(LSTMVAE, self).__init__()
        self.device = device

        # dimensions
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.latent_size = latent_size
        self.num_layers = 1

        # lstm ae
        self.lstm_enc = Encoder(
            input_size=input_size, hidden_size=hidden_size, num_layers=self.num_layers
        )
        self.lstm_dec = Decoder(
            input_size=latent_size,
            output_size=input_size,
            hidden_size=hidden_size,
            num_layers=self.num_layers,
        )

        self.fc21 = nn.Linear(self.hidden_size, self.latent_size)  # -> mean
        self.fc22 = nn.Linear(self.hidden_size, self.latent_size)  # -> log-variance
        self.fc3 = nn.Linear(self.latent_size, self.hidden_size)  # z -> decoder h_0

    def reparametize(self, mu, logvar):
        """Sample z = mu + eps * exp(0.5 * logvar) with eps ~ N(0, I).

        ``randn_like`` already creates the noise on ``std``'s device and dtype,
        so no explicit transfer is needed. (Forcing it onto ``self.device``
        broke CPU use with the default CUDA device.)
        """
        std = torch.exp(0.5 * logvar)
        noise = torch.randn_like(std)
        return mu + noise * std

    def forward(self, x):
        """Encode, sample, decode and score ``x``.

        x: tensor of shape (batch_size, seq_len, input_size).
        Returns (loss, x_hat, (reconstruction_loss, reported_kld)), where
        x_hat has the same shape as x.
        """
        batch_size, seq_len, _ = x.shape

        # encode input space to hidden space; h_n is (num_layers, batch, hidden)
        enc_h, _ = self.lstm_enc(x)

        # extract latent variable z (hidden space to latent space)
        mean = self.fc21(enc_h[-1])
        logvar = self.fc22(enc_h[-1])
        z = self.reparametize(mean, logvar)  # batch_size x latent_size

        # initial decoder state derived from z
        h_ = self.fc3(z).view(self.num_layers, batch_size, self.hidden_size)
        c_ = torch.zeros_like(h_)

        # feed z at every time step: (batch, latent) -> (batch, seq_len, latent)
        z_seq = z.unsqueeze(1).repeat(1, seq_len, 1)

        x_hat, _ = self.lstm_dec(z_seq, (h_.contiguous(), c_.contiguous()))

        # calculate vae loss
        losses = self.loss_function(x_hat, x, mean, logvar)
        m_loss, recon_loss, kld_loss = losses["loss"], losses["Reconstruction_Loss"], losses["KLD"]

        return m_loss, x_hat, (recon_loss, kld_loss)

    def loss_function(self, *args, **kwargs) -> dict:
        r"""
        Computes the VAE loss function.
        KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
        :param args: (recons, input, mu, log_var)
        :param kwargs: unused
        :return: dict with "loss" (MSE reconstruction + kld_weight * KL,
            differentiable), "Reconstruction_Loss" (detached MSE) and "KLD"
            (detached, reported as the *negated* KL divergence)
        """
        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]

        kld_weight = 0.00025  # Account for the minibatch samples from the dataset
        recons_loss = F.mse_loss(recons, input)

        # KL summed over latent dims, averaged over the batch
        kld_loss = torch.mean(
            -0.5 * torch.sum(1 + log_var - mu**2 - log_var.exp(), dim=1), dim=0
        )

        loss = recons_loss + kld_weight * kld_loss
        return {
            "loss": loss,
            "Reconstruction_Loss": recons_loss.detach(),
            "KLD": -kld_loss.detach(),
        }


class LSTMAE(nn.Module):
    """LSTM-based Auto Encoder.

    The encoder's final state initialises the decoder, which is then driven by
    an all-zero input sequence to reconstruct ``x``.
    """

    def __init__(self, input_size, hidden_size, latent_size, device=torch.device("cuda")):
        """
        input_size: int, feature dimension of each time step; inputs are
            (batch_size, sequence_length, input_size)
        hidden_size: int, hidden size of the encoder and decoder LSTMs
        latent_size: int, stored for interface parity with LSTMVAE (unused here)
        device: stored as ``self.device`` for backward compatibility; the
            decoder input is created on the device of ``x`` instead
        """
        super(LSTMAE, self).__init__()
        self.device = device

        # dimensions
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.latent_size = latent_size

        # lstm ae
        self.lstm_enc = Encoder(
            input_size=input_size,
            hidden_size=hidden_size,
        )
        self.lstm_dec = Decoder(
            input_size=input_size,
            output_size=input_size,
            hidden_size=hidden_size,
        )

        self.criterion = nn.MSELoss()

    def forward(self, x):
        """Reconstruct ``x`` and return (mse_loss, reconstruction, (0, 0))."""
        enc_hidden = self.lstm_enc(x)

        # The decoder receives no information from its inputs (all zeros); the
        # reconstruction comes entirely from the encoder state. Allocate it on
        # x's device so the model also works on CPU with the default `device`.
        temp_input = torch.zeros(x.shape, dtype=torch.float, device=x.device)
        reconstruct_output, _ = self.lstm_dec(temp_input, enc_hidden)
        reconstruct_loss = self.criterion(reconstruct_output, x)

        return reconstruct_loss, reconstruct_output, (0, 0)


In [None]:
import torch
from torch import nn
from torch.nn import functional as F

class ConvLSTMCell(nn.Module):
    """Single convolutional LSTM cell.

    All four gate pre-activations (input, forget, output, candidate) come from
    one convolution over the channel-wise concatenation of ``x`` and ``h``.
    """

    def __init__(self, input_channels, hidden_channels, kernel_size):
        super().__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        # kernel_size // 2 keeps height/width unchanged for odd kernel sizes
        self.padding = kernel_size // 2
        self.conv = nn.Conv2d(
            self.input_channels + self.hidden_channels,
            4 * self.hidden_channels,
            self.kernel_size,
            padding=self.padding,
            bias=True,
        )

    def forward(self, x, h, c):
        """Advance one step; returns the new (h, c)."""
        gates = self.conv(torch.cat((x, h), dim=1))
        in_gate, forget_gate, out_gate, candidate = gates.chunk(4, dim=1)
        c_next = torch.sigmoid(forget_gate) * c + torch.sigmoid(in_gate) * torch.tanh(candidate)
        h_next = torch.sigmoid(out_gate) * torch.tanh(c_next)
        return h_next, c_next

class ConvLSTMEncoder(nn.Module):
    """Encode a frame sequence with a single ConvLSTM cell.

    The cell is run over every time step from an all-zero state, and the final
    ``(h, c)`` is returned.
    """

    def __init__(self, input_channels, hidden_channels, kernel_size):
        super(ConvLSTMEncoder, self).__init__()
        self.convlstm = ConvLSTMCell(input_channels, hidden_channels, kernel_size)

    def forward(self, x):
        """
        x: tensor indexed as (batch, seq_len, channels, height, width).
        Returns the final (h, c), each with ``hidden_channels`` channels; the
        spatial size is preserved only for odd kernel sizes.
        """
        batch_size, seq_len, _, height, width = x.size()
        state_shape = (batch_size, self.convlstm.hidden_channels, height, width)
        # Zero initial state (float32, like the original) on the input's device.
        h = torch.zeros(state_shape, device=x.device)
        c = torch.zeros(state_shape, device=x.device)

        for t in range(seq_len):
            h, c = self.convlstm(x[:, t, :, :, :], h, c)

        return h, c

class ConvLSTMDecoder(nn.Module):
    """Autoregressive ConvLSTM decoder.

    Each step's predicted frame (``conv_out`` applied to the hidden state) is
    fed back as the next step's input. This requires
    ``output_channels == input_channels`` whenever ``seq_len > 1``.
    """

    def __init__(self, input_channels, hidden_channels, output_channels, kernel_size):
        super().__init__()
        self.convlstm = ConvLSTMCell(input_channels, hidden_channels, kernel_size)
        self.conv_out = nn.Conv2d(hidden_channels, output_channels, kernel_size=3, padding=1)

    def forward(self, x, h, c, seq_len):
        """Unroll ``seq_len`` steps from input ``x`` and state ``(h, c)``.

        Returns the predicted frames stacked along dim 1.
        """
        frames = []
        frame = x
        for _ in range(seq_len):
            h, c = self.convlstm(frame, h, c)
            frame = self.conv_out(h)
            frames.append(frame)
        return torch.stack(frames, dim=1)

# Defining CONVLSTMVAE: ConvLSTM encoder -> Gaussian latent -> ConvLSTM decoder
class CONVLSTMVAE(nn.Module):
    """Convolutional-LSTM variational auto-encoder for frame sequences.

    A ConvLSTM encoder summarises the input sequence (batch, T, C, H, W) into
    its final hidden state. That state is flattened and projected to the mean
    and log-variance of a diagonal Gaussian. A latent sample is projected back
    to the decoder's initial hidden state, and the decoder unrolls T frames.

    Decoder outputs are unbounded (the last layer is a plain Conv2d) and are
    treated as logits by ``loss_function``; apply ``torch.sigmoid`` to obtain
    pixel intensities in [0, 1].
    """

    def __init__(self, input_channels, hidden_channels, latent_size, kernel_size=3,
                 height=64, width=64):
        """
        input_channels: channels per frame (1 for grayscale, 3 for RGB)
        hidden_channels: ConvLSTM hidden channels
        latent_size: dimensionality of z
        kernel_size: ConvLSTM kernel size (use an odd size to preserve H and W)
        height, width: frame size the linear bottleneck is built for
        """
        super(CONVLSTMVAE, self).__init__()
        self.hidden_channels = hidden_channels
        self.height = height
        self.width = width

        self.encoder = ConvLSTMEncoder(input_channels, hidden_channels, kernel_size)
        self.decoder = ConvLSTMDecoder(input_channels, hidden_channels, input_channels, kernel_size)

        flat_size = hidden_channels * height * width
        self.fc_mu = nn.Linear(flat_size, latent_size)
        self.fc_logvar = nn.Linear(flat_size, latent_size)
        self.fc_decode = nn.Linear(latent_size, flat_size)

    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps * exp(0.5 * logvar) with eps ~ N(0, I)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """x: (batch, seq_len, channels, height, width) -> (logits, mu, logvar)."""
        batch_size, seq_len, channels, height, width = x.size()
        if (height, width) != (self.height, self.width):
            raise ValueError(
                f"expected {self.height}x{self.width} frames, got {height}x{width}"
            )

        # Encode the whole sequence; only the final hidden state is used.
        h, _ = self.encoder(x)
        h_flat = h.reshape(batch_size, -1)

        # VAE bottleneck
        mu = self.fc_mu(h_flat)
        logvar = self.fc_logvar(h_flat)
        z = self.reparameterize(mu, logvar)

        # Decode from a z-derived hidden state, a zero cell state and a zero first frame.
        h_decoded = self.fc_decode(z).view(batch_size, self.hidden_channels, height, width)
        c_decoded = torch.zeros_like(h_decoded)
        x_decoded = torch.zeros(batch_size, channels, height, width, device=x.device)

        output = self.decoder(x_decoded, h_decoded, c_decoded, seq_len)

        return output, mu, logvar

    def loss_function(self, recon_x, x, mu, logvar):
        """Summed Bernoulli reconstruction loss on logits plus the KL divergence.

        recon_x: decoder logits; x: targets, expected in [0, 1].
        ``binary_cross_entropy_with_logits`` accepts unbounded decoder output
        and is autocast-safe, unlike ``binary_cross_entropy`` on raw outputs.
        """
        bce = F.binary_cross_entropy_with_logits(recon_x, x, reduction='sum')
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return bce + kld

# Visualization function
def visualize_reconstructions(model, val_loader, device, epoch, num_samples=10):
    """Save a plot of the first frame of up to ``num_samples`` validation
    sequences (top row) and their reconstructions (bottom row).

    Uses the first batch of ``val_loader``, which must yield ``(x, _)`` pairs.
    The figure is written to ``reconstruction_epoch_{epoch}.png``.

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    model.eval()
    batch = next(iter(val_loader), None)
    if batch is None:
        raise ValueError("val_loader is empty; nothing to visualize")
    x, _ = batch
    x = x.to(device)
    with torch.no_grad():
        recon_x, _, _ = model(x)

    x = x.cpu().numpy()
    recon_x = recon_x.float().cpu().numpy()

    # Never index past the batch: smaller final/val batches used to IndexError.
    n_cols = min(num_samples, x.shape[0])
    # squeeze=False keeps `axes` 2-D even when n_cols == 1.
    fig, axes = plt.subplots(2, n_cols, figsize=(2 * n_cols, 4), squeeze=False)
    for i in range(n_cols):
        axes[0, i].imshow(x[i, 0, 0], cmap='gray')
        axes[0, i].axis('off')
        axes[1, i].imshow(recon_x[i, 0, 0], cmap='gray')
        axes[1, i].axis('off')
    fig.suptitle(f"Epoch {epoch}: Original (top) vs Reconstructed (bottom)")
    fig.savefig(f"reconstruction_epoch_{epoch}.png")
    plt.close(fig)

# Training function
def train(model, train_loader, val_loader, optimizer, epochs, device):
    """Train ``model`` for ``epochs`` epochs, printing the per-sample loss.

    ``model`` must return ``(recon_x, mu, logvar)`` and expose
    ``loss_function(recon_x, x, mu, logvar)``. Both loaders yield ``(x, _)``
    batches. Mixed precision (autocast + GradScaler) is enabled only when
    ``device`` is a CUDA device. Every 10th epoch (1-based, matching the
    printed epoch) a reconstruction plot is saved.
    """
    device = torch.device(device)
    use_amp = device.type == "cuda"
    scaler = GradScaler(enabled=use_amp)
    model.to(device)
    num_samples = len(train_loader.dataset)

    for epoch in range(1, epochs + 1):
        model.train()  # visualization switches to eval mode; restore each epoch
        total_loss = 0.0
        for x, _ in train_loader:
            x = x.to(device)
            optimizer.zero_grad()
            with autocast(enabled=use_amp):
                recon_x, mu, logvar = model(x)
                loss = model.loss_function(recon_x, x, mu, logvar)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            total_loss += loss.item()

        # Average per sample (the loss is assumed to be summed over each batch).
        avg_loss = total_loss / max(num_samples, 1)
        print(f"Epoch {epoch}/{epochs}, Loss: {avg_loss:.4f}")

        if epoch % 10 == 0:
            visualize_reconstructions(model, val_loader, device, epoch)

# Main function
def main(epochs=10, batch_size=16, lr=1e-3, device=None):
    """Build Moving MNIST loaders, then train a CONVLSTMVAE on them.

    epochs: number of training epochs
    batch_size: sequences per mini-batch
    lr: Adam learning rate
    device: torch device or string; defaults to CUDA when available
    Returns the trained model.
    """
    from torch.utils.data import DataLoader

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Data preparation. ToTensor() maps pixels to [0, 1], the range the BCE
    # reconstruction loss needs for its targets; normalizing to [-1, 1] (as
    # before) made that loss meaningless.
    transform = transforms.ToTensor()
    # NOTE(review): `MovingMNIST` is not defined or imported in the visible
    # source. This assumes a dataset with a (root, train, transform, download)
    # signature yielding (input_seq, target_seq) pairs — confirm.
    train_dataset = MovingMNIST(root='./data', train=True, transform=transform, download=True)
    val_dataset = MovingMNIST(root='./data', train=False, transform=transform, download=True)

    def collate_fn(batch):
        seqs, targets = zip(*batch)  # Separating sequences and targets
        # Stack, add a channel axis at dim 2, then swap the last two (spatial)
        # axes. Assumes each sequence is (T, H, W); the reason for the H/W swap
        # is not visible here — verify.
        seqs = torch.stack(seqs).unsqueeze(2).permute(0, 1, 2, 4, 3)
        targets = torch.stack(targets).unsqueeze(2).permute(0, 1, 2, 4, 3)
        return seqs, targets

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    model = CONVLSTMVAE(input_channels=1, hidden_channels=64, latent_size=128)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    train(model, train_loader, val_loader, optimizer, epochs, device)
    return model

# Example usage: a grayscale model (set input_channels = 3 for RGB videos).
input_channels, hidden_channels, latent_size = 1, 64, 128
model = CONVLSTMVAE(
    input_channels=input_channels,
    hidden_channels=hidden_channels,
    latent_size=latent_size,
)

# # Assuming input shape: (batch_size, sequence_length, channels, height, width)
# sample_input = torch.randn(16, 10, 3, 64, 64)
# output, mu, logvar = model(sample_input)

# print(f"Input shape: {sample_input.shape}")
# print(f"Output shape: {output.shape}")
# print(f"Mu shape: {mu.shape}")
# print(f"Logvar shape: {logvar.shape}")