In [14]:
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm
# Use GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')




In [15]:
# ConvLSTMCell as defined earlier
class ConvLSTMCell(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding, activation, frame_size):
        super(ConvLSTMCell, self).__init__()
        if activation == "tanh":
            self.activation = torch.tanh 
        elif activation == "relu":
            self.activation = torch.relu
        self.conv = nn.Conv2d(in_channels=in_channels + out_channels, out_channels=4 * out_channels, kernel_size=kernel_size, padding=padding)           
        self.W_ci = nn.Parameter(torch.Tensor(out_channels, *frame_size))
        self.W_co = nn.Parameter(torch.Tensor(out_channels, *frame_size))
        self.W_cf = nn.Parameter(torch.Tensor(out_channels, *frame_size))

    def forward(self, X, H_prev, C_prev):
        conv_output = self.conv(torch.cat([X, H_prev], dim=1))
        i_conv, f_conv, C_conv, o_conv = torch.chunk(conv_output, chunks=4, dim=1)
        input_gate = torch.sigmoid(i_conv + self.W_ci * C_prev )
        forget_gate = torch.sigmoid(f_conv + self.W_cf * C_prev )
        C = forget_gate * C_prev + input_gate * self.activation(C_conv)
        output_gate = torch.sigmoid(o_conv + self.W_co * C )
        H = output_gate * self.activation(C)
        return H, C

In [16]:
class ConvLSTMCell(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding, activation, frame_size):
        super(ConvLSTMCell, self).__init__()
        if activation == "tanh":
            self.activation = torch.tanh 
        elif activation == "relu":
            self.activation = torch.relu
        self.conv = nn.Conv2d(in_channels=in_channels + out_channels, out_channels=4 * out_channels, kernel_size=kernel_size, padding=padding)           
        self.W_ci = nn.Parameter(torch.Tensor(out_channels, *frame_size))
        self.W_co = nn.Parameter(torch.Tensor(out_channels, *frame_size))
        self.W_cf = nn.Parameter(torch.Tensor(out_channels, *frame_size))

    def forward(self, X, H_prev, C_prev):
        conv_output = self.conv(torch.cat([X, H_prev], dim=1))
        i_conv, f_conv, C_conv, o_conv = torch.chunk(conv_output, chunks=4, dim=1)
        input_gate = torch.sigmoid(i_conv + self.W_ci * C_prev )
        forget_gate = torch.sigmoid(f_conv + self.W_cf * C_prev )
        C = forget_gate * C_prev + input_gate * self.activation(C_conv)
        output_gate = torch.sigmoid(o_conv + self.W_co * C )
        H = output_gate * self.activation(C)
        return H, C

class ConvLSTM(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding, activation, frame_size):
        super(ConvLSTM, self).__init__()
        self.out_channels = out_channels
        self.convLSTMcell = ConvLSTMCell(in_channels, out_channels, kernel_size, padding, activation, frame_size)

    def forward(self, X):
        batch_size, _, seq_len, height, width = X.size()
        output = torch.zeros(batch_size, self.out_channels, seq_len, height, width, device=device)
        H = torch.zeros(batch_size, self.out_channels, height, width, device=device)
        C = torch.zeros(batch_size, self.out_channels, height, width, device=device)
        for time_step in range(seq_len):
            H, C = self.convLSTMcell(X[:, :, time_step], H, C)
            output[:, :, time_step] = H
        return output

In [17]:
class Sampling(nn.Module):
    def forward(self, z_mean, z_log_var):
        std = torch.exp(0.5 * z_log_var)
        eps = torch.randn_like(std)
        return z_mean + eps * std

class Encoder(nn.Module):
    def __init__(self, input_shape, latent_dim):
        super(Encoder, self).__init__()
        self.convLSTM = ConvLSTM(input_shape[1], 64, (3, 3), (1, 1), 'relu', (input_shape[2], input_shape[3]))
        self.flatten = nn.Flatten()
        self.fc_mu = nn.Linear(64 * input_shape[2] * input_shape[3], latent_dim)
        self.fc_logvar = nn.Linear(64 * input_shape[2] * input_shape[3], latent_dim)

    def forward(self, x):
        h = self.convLSTM(x)[:, :, -1]
        h = self.flatten(h)
        z_mean = self.fc_mu(h)
        z_log_var = self.fc_logvar(h)
        return z_mean, z_log_var

class Decoder(nn.Module):
    def __init__(self, latent_dim, output_shape):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(latent_dim, 64 * output_shape[2] * output_shape[3])
        self.unflatten = nn.Unflatten(1, (64, output_shape[2], output_shape[3]))
        self.convLSTM = ConvLSTM(64, output_shape[1], (3, 3), (1, 1), 'relu', (output_shape[2], output_shape[3]))

    def forward(self, z, seq_len):
        h = self.fc(z)
        h = self.unflatten(h)
        h = h.unsqueeze(2).expand(-1, -1, seq_len, -1, -1)
        recon_x = self.convLSTM(h)
        return recon_x

class VAE(nn.Module):
    def __init__(self, input_shape, latent_dim):
        super(VAE, self).__init__()
        self.encoder = Encoder(input_shape, latent_dim)
        self.decoder = Decoder(latent_dim, input_shape)

    def forward(self, x):
        z_mean, z_log_var = self.encoder(x)
        z = Sampling()(z_mean, z_log_var)
        recon_x = self.decoder(z, x.size(2))
        return z_mean, z_log_var, recon_x


In [20]:
class MovingMNISTDataset(Dataset):
    def __init__(self, data_file):
        self.data = np.load(data_file)
        self.data = self.data.transpose(1, 0, 2, 3)  # (num_samples, seq_len, height, width)
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        sample = self.data[idx] / 255.0
        sample = np.expand_dims(sample, axis=1)  # Add channel dimension
        return torch.tensor(sample, dtype=torch.float32)

data_file = 'mnist_test_seq.npy'  # Update this path
train_data = MovingMNISTDataset(data_file)
train_loader = DataLoader(train_data, batch_size=16, shuffle=True)


In [21]:
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm

def vae_loss(x, recon_x, z_mean, z_log_var):
    recon_loss = nn.functional.mse_loss(recon_x, x, reduction='sum')
    kl_div = -0.5 * torch.sum(1 + z_log_var - z_mean**2 - torch.exp(z_log_var))
    return recon_loss + kl_div

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
input_shape = (10, 1, 64, 64)  # (seq_len, num_channels, height, width)
latent_dim = 20
model = VAE(input_shape, latent_dim).to(device)
optimizer = Adam(model.parameters(), lr=1e-4)
scaler = GradScaler()

num_epochs = 50
for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for batch in tqdm(train_loader):
        batch = batch.to(device)
        optimizer.zero_grad()
        with autocast():
            z_mean, z_log_var, recon_x = model(batch)
            loss = vae_loss(batch, recon_x, z_mean, z_log_var)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        train_loss += loss.item()
    train_loss /= len(train_loader.dataset)
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {train_loss:.4f}')


  0%|          | 0/625 [00:00<?, ?it/s]


RuntimeError: Given groups=1, weight of size [256, 65, 3, 3], expected input[16, 84, 64, 64] to have 65 channels, but got 84 channels instead

In [None]:
import matplotlib.pyplot as plt

model.eval()
with torch.no_grad():
    for batch in train_loader:
        batch = batch.to(device)
        z_mean, z_log_var, recon_batch = model(batch)
        break

batch = batch.cpu().numpy()
recon_batch = recon_batch.cpu().numpy()

fig, axes = plt.subplots(2, 10, figsize=(20, 4))
for i in range(10):
    axes[0, i].imshow(batch[0, 0, i], cmap='gray')
    axes[0, i].axis('off')
    axes[1, i].imshow(recon_batch[0, 0, i], cmap='gray')
    axes[1, i].axis('off')
plt.show()
