In [1]:
import secrets

import easydict
import matplotlib.pyplot as plt
import torch
from torch.cuda.amp import autocast, GradScaler
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from tqdm import tqdm

2024-06-23 15:26:45.093793: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-06-23 15:26:45.093903: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-06-23 15:26:45.210478: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [None]:


import torch
from torch import nn
from torch.nn import functional as F


class Encoder(nn.Module):
    """LSTM encoder that reduces an input sequence to its final recurrent state.

    Args:
        input_size: number of features per time step.
        hidden_size: width of the LSTM hidden/cell state.
        num_layers: number of stacked LSTM layers.
    """

    def __init__(self, input_size=4096, hidden_size=1024, num_layers=2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # Unidirectional and batch-first: inputs are laid out (batch, seq, feature).
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=False,
        )

    def forward(self, x):
        """Run the LSTM over ``x`` and return only its final ``(hidden, cell)`` pair.

        ``x`` is batch-first, i.e. (batch_size, seq_length, input_size). The
        per-step outputs of the LSTM are discarded.
        """
        _, final_state = self.lstm(x)
        hidden, cell = final_state
        return (hidden, cell)


class Decoder(nn.Module):
    """LSTM decoder followed by a linear projection back to the output space.

    Args:
        input_size: number of features per time step fed to the LSTM.
        hidden_size: width of the LSTM hidden/cell state.
        output_size: number of features produced per time step.
        num_layers: number of stacked LSTM layers.
    """

    def __init__(
        self, input_size=4096, hidden_size=1024, output_size=4096, num_layers=2
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        # Unidirectional and batch-first: inputs are laid out (batch, seq, feature).
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=False,
        )
        # Maps every LSTM output step from hidden_size to output_size.
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden):
        """Decode ``x`` starting from the recurrent state ``hidden``.

        ``x`` is batch-first, i.e. (batch_size, seq_length, input_size), and
        ``hidden`` is the ``(h_0, c_0)`` pair handed to the LSTM.

        Returns:
            ``(prediction, (hidden, cell))`` where ``prediction`` is the linear
            projection of every LSTM output step and ``(hidden, cell)`` is the
            LSTM's final state.
        """
        lstm_out, final_state = self.lstm(x, hidden)
        new_hidden, new_cell = final_state
        return self.fc(lstm_out), (new_hidden, new_cell)


class LSTMVAE(nn.Module):
    """LSTM-based Variational Auto Encoder.

    The encoder's last-layer hidden state is mapped to the mean and
    log-variance of a Gaussian latent ``z``. A sample of ``z`` both
    initialises the decoder's hidden state (through ``fc3``) and is fed to the
    decoder as its input at every time step.
    """

    def __init__(
        self, input_size, hidden_size, latent_size, device=torch.device("cuda")
    ):
        """
        input_size: int, number of features per time step
        hidden_size: int, width of the encoder/decoder LSTM state
        latent_size: int, latent z-layer size
        device: stored on the instance for backward compatibility only;
            the forward pass runs on whatever device its input lives on
        """
        super().__init__()
        self.device = device

        # dimensions
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.latent_size = latent_size
        # forward() reshapes fc3's (batch, hidden) output into
        # (num_layers, batch, hidden), which only works for a single layer.
        self.num_layers = 1

        # lstm ae
        self.lstm_enc = Encoder(
            input_size=input_size, hidden_size=hidden_size, num_layers=self.num_layers
        )
        self.lstm_dec = Decoder(
            input_size=latent_size,
            output_size=input_size,
            hidden_size=hidden_size,
            num_layers=self.num_layers,
        )

        self.fc21 = nn.Linear(self.hidden_size, self.latent_size)  # hidden -> mean
        self.fc22 = nn.Linear(self.hidden_size, self.latent_size)  # hidden -> log-variance
        self.fc3 = nn.Linear(self.latent_size, self.hidden_size)  # z -> decoder h_0

    def reparametize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        # randn_like already matches std's device and dtype; moving the noise
        # to self.device would break whenever the model runs elsewhere.
        noise = torch.randn_like(std)
        return mu + noise * std

    def forward(self, x):
        """Encode ``x``, sample a latent vector and reconstruct the sequence.

        Args:
            x: batch-first tensor (batch_size, seq_len, feature_dim).

        Returns:
            ``(loss, x_hat, (recon_loss, kld))`` where ``loss`` is the total VAE
            loss, ``x_hat`` the reconstruction, and the last pair holds the
            detached components reported by ``loss_function``.
        """
        batch_size, seq_len, _ = x.shape

        # encode input space to hidden space (the encoder's cell state is unused)
        enc_h, _ = self.lstm_enc(x)
        enc_h = enc_h.view(self.num_layers, batch_size, self.hidden_size)
        last_hidden = enc_h[-1]

        # extract latent variable z (hidden space to latent space)
        mean = self.fc21(last_hidden)
        logvar = self.fc22(last_hidden)
        z = self.reparametize(mean, logvar)  # batch_size x latent_size

        # decoder starts from a state derived from z, with a zero cell state
        h_ = self.fc3(z).view(self.num_layers, batch_size, self.hidden_size)
        c_ = torch.zeros_like(h_)

        # the same z is presented to the decoder at every time step
        decoder_input = z.unsqueeze(1).repeat(1, seq_len, 1)

        # decode latent space to input space
        x_hat, _ = self.lstm_dec(decoder_input, (h_.contiguous(), c_.contiguous()))

        # calculate vae loss
        losses = self.loss_function(x_hat, x, mean, logvar)
        return (
            losses["loss"],
            x_hat,
            (losses["Reconstruction_Loss"], losses["KLD"]),
        )

    def loss_function(self, *args, **kwargs) -> dict:
        r"""
        Computes the VAE loss function.
        KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}

        :param args: ``(recons, input, mu, log_var)`` in that order
        :param kwargs: optional ``kld_weight`` (default 0.00025) scaling the KL term
        :return: dict with ``"loss"`` (MSE reconstruction loss plus the weighted
            KL term), ``"Reconstruction_Loss"`` (detached) and ``"KLD"``
            (detached). NOTE: ``"KLD"`` holds the *negated* KL term; that sign
            is kept so existing consumers of this value see no change.
        """
        recons, target, mu, log_var = args[0], args[1], args[2], args[3]

        # Account for the minibatch samples from the dataset
        kld_weight = kwargs.get("kld_weight", 0.00025)
        recons_loss = F.mse_loss(recons, target)

        # KL summed over the latent dimension, then averaged over the batch
        kld_loss = torch.mean(
            -0.5 * torch.sum(1 + log_var - mu**2 - log_var.exp(), dim=1), dim=0
        )

        loss = recons_loss + kld_weight * kld_loss
        return {
            "loss": loss,
            "Reconstruction_Loss": recons_loss.detach(),
            "KLD": -kld_loss.detach(),
        }


class LSTMAE(nn.Module):
    """LSTM-based Auto Encoder.

    The encoder's final ``(hidden, cell)`` state seeds the decoder, which is
    driven by an all-zero input sequence and trained to reproduce the input
    under an MSE loss.
    """

    def __init__(self, input_size, hidden_size, latent_size, device=torch.device("cuda")):
        """
        input_size: int, number of features per time step
        hidden_size: int, width of the encoder/decoder LSTM state
        latent_size: int, stored on the instance but not used by this model
        device: stored on the instance for backward compatibility only;
            the forward pass runs on whatever device its input lives on
        """
        super().__init__()
        self.device = device

        # dimensions
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.latent_size = latent_size

        # lstm ae (both halves use their default number of layers)
        self.lstm_enc = Encoder(
            input_size=input_size,
            hidden_size=hidden_size,
        )
        self.lstm_dec = Decoder(
            input_size=input_size,
            output_size=input_size,
            hidden_size=hidden_size,
        )

        self.criterion = nn.MSELoss()

    def forward(self, x):
        """Reconstruct ``x`` from the encoder's final state.

        Args:
            x: batch-first tensor (batch_size, seq_len, feature_dim).

        Returns:
            ``(reconstruct_loss, reconstruct_output, (0, 0))``. The trailing
            ``(0, 0)`` is a placeholder occupying the slot where ``LSTMVAE``
            returns its loss components.
        """
        batch_size, seq_len, feature_dim = x.shape

        enc_state = self.lstm_enc(x)

        # Build the zero decoder input next to x instead of on self.device, so
        # the model keeps working when it is moved to a different device.
        decoder_input = torch.zeros(
            (batch_size, seq_len, feature_dim), dtype=torch.float, device=x.device
        )
        reconstruct_output, _ = self.lstm_dec(decoder_input, enc_state)
        reconstruct_loss = self.criterion(reconstruct_output, x)

        return reconstruct_loss, reconstruct_output, (0, 0)


In [None]:
# import torch
# from torch import nn
# from torch.nn import functional as F

# class ConvLSTMCell(nn.Module):
#     def __init__(self, input_channels, hidden_channels, kernel_size):
#         super(ConvLSTMCell, self).__init__()
        
#         self.input_channels = input_channels
#         self.hidden_channels = hidden_channels
#         self.kernel_size = kernel_size
#         self.padding = kernel_size // 2
        
#         self.conv = nn.Conv2d(
#             in_channels=self.input_channels + self.hidden_channels,
#             out_channels=4 * self.hidden_channels,
#             kernel_size=self.kernel_size,
#             padding=self.padding,
#             bias=True
#         )

#     def forward(self, x, h, c):
#         combined = torch.cat([x, h], dim=1)
#         combined_conv = self.conv(combined)
#         cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_channels, dim=1)
#         i = torch.sigmoid(cc_i)
#         f = torch.sigmoid(cc_f)
#         o = torch.sigmoid(cc_o)
#         g = torch.tanh(cc_g)
        
#         c_next = f * c + i * g
#         h_next = o * torch.tanh(c_next)
        
#         return h_next, c_next

# class ConvLSTMEncoder(nn.Module):
#     def __init__(self, input_channels, hidden_channels, kernel_size):
#         super(ConvLSTMEncoder, self).__init__()
#         self.convlstm = ConvLSTMCell(input_channels, hidden_channels, kernel_size)
    
#     def forward(self, x):
#         batch_size, seq_len, channels, height, width = x.size()
#         h = torch.zeros(batch_size, self.convlstm.hidden_channels, height, width).to(x.device)
#         c = torch.zeros(batch_size, self.convlstm.hidden_channels, height, width).to(x.device)
        
#         for t in range(seq_len):
#             h, c = self.convlstm(x[:, t, :, :, :], h, c)
        
#         return h, c

# class ConvLSTMDecoder(nn.Module):
#     def __init__(self, input_channels, hidden_channels, output_channels, kernel_size):
#         super(ConvLSTMDecoder, self).__init__()
#         self.convlstm = ConvLSTMCell(input_channels, hidden_channels, kernel_size)
#         self.conv_out = nn.Conv2d(hidden_channels, output_channels, kernel_size=3, padding=1)
    
#     def forward(self, x, h, c, seq_len):
#         outputs = []
        
#         for _ in range(seq_len):
#             h, c = self.convlstm(x, h, c)
#             output = self.conv_out(h)
#             outputs.append(output)
#             x = output
        
#         return torch.stack(outputs, dim=1)

# class CONVLSTMVAE(nn.Module):
#     def __init__(self, input_channels, hidden_channels, latent_size, kernel_size=3):
#         super(CONVLSTMVAE, self).__init__()
        
#         self.encoder = ConvLSTMEncoder(input_channels, hidden_channels, kernel_size)
#         self.decoder = ConvLSTMDecoder(input_channels, hidden_channels, input_channels, kernel_size)
        
#         self.fc_mu = nn.Linear(hidden_channels * 64 * 64, latent_size)  # Assuming 64x64 spatial dimensions
#         self.fc_logvar = nn.Linear(hidden_channels * 64 * 64, latent_size)
#         self.fc_decode = nn.Linear(latent_size, hidden_channels * 64 * 64)
        
#     def reparameterize(self, mu, logvar):
#         std = torch.exp(0.5 * logvar)
#         eps = torch.randn_like(std)
#         return mu + eps * std
    
#     def forward(self, x):
#         batch_size, seq_len, channels, height, width = x.size()
        
#         # Encode
#         h, c = self.encoder(x)
#         h_flat = h.view(batch_size, -1)
        
#         # VAE bottleneck
#         mu = self.fc_mu(h_flat)
#         logvar = self.fc_logvar(h_flat)
#         z = self.reparameterize(mu, logvar)
        
#         # Decode
#         h_decoded = self.fc_decode(z).view(batch_size, -1, height, width)
#         c_decoded = torch.zeros_like(h_decoded)
#         x_decoded = torch.zeros(batch_size, channels, height, width).to(x.device)
        
#         output = self.decoder(x_decoded, h_decoded, c_decoded, seq_len)
        
#         return output, mu, logvar
    
#     def loss_function(self, recon_x, x, mu, logvar):
#         BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
#         KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
#         return BCE + KLD

# # Example usage
# input_channels = 3  # For RGB videos
# hidden_channels = 64
# latent_size = 128
# model = CONVLSTMVAE(input_channels, hidden_channels, latent_size)

# # Assuming input shape: (batch_size, sequence_length, channels, height, width)
# sample_input = torch.randn(16, 10, 3, 64, 64)
# output, mu, logvar = model(sample_input)

# print(f"Input shape: {sample_input.shape}")
# print(f"Output shape: {output.shape}")
# print(f"Mu shape: {mu.shape}")
# print(f"Logvar shape: {logvar.shape}")

In [None]:
from __future__ import print_function

import codecs
import errno
import os
import os.path

import numpy as np
import torch
import torch.utils.data as data
from PIL import Image


class MovingMNIST(data.Dataset):
    """`MovingMNIST <http://www.cs.toronto.edu/~nitish/unsupervised_video/>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``processed/moving_mnist_train.pt``
            and  ``processed/moving_mnist_test.pt`` exist.
        train (bool, optional): If True, creates dataset from the training file,
            otherwise from the test file.
        split (int, optional): Train/test split size. Number defines how many samples
            belong to test set. Only takes effect when the processed files are
            (re)generated by ``download``; existing processed files are reused as-is.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``.
            Applied to every frame of the input sequence.
        target_transform (callable, optional): A function/transform that takes in an PIL
            image and returns a transformed version. E.g, ``transforms.RandomCrop``.
            Applied to every frame of the target sequence.
    """

    urls = ["https://github.com/tychovdo/MovingMNIST/raw/master/mnist_test_seq.npy.gz"]
    raw_folder = "raw"
    processed_folder = "processed"
    training_file = "moving_mnist_train.pt"
    test_file = "moving_mnist_test.pt"

    def __init__(
        self,
        root,
        train=True,
        split=1000,
        transform=None,
        target_transform=None,
        download=False,
    ):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.split = split
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError(
                "Dataset not found." + " You can use download=True to download it"
            )

        # Only the subset selected by `train` is loaded into memory.
        processed_dir = os.path.join(self.root, self.processed_folder)
        if self.train:
            self.train_data = torch.load(os.path.join(processed_dir, self.training_file))
        else:
            self.test_data = torch.load(os.path.join(processed_dir, self.test_file))

    @staticmethod
    def _split_sequences(sequences, split):
        """Return ``(train, test)`` where the last ``split`` sequences form the test part.

        Uses an explicit boundary rather than negative slicing so that
        ``split=0`` keeps everything in the training part instead of emptying it.
        """
        n_train = max(len(sequences) - split, 0)
        return sequences[:n_train], sequences[n_train:]

    def _apply_per_frame(self, frames, transform):
        """Apply ``transform`` to each frame and concatenate the results along dim 0.

        Frames stay in chronological order. Returns ``None`` for an empty
        sequence.
        """
        per_frame = [
            transform(Image.fromarray(frames[i].numpy(), mode="L"))
            for i in range(frames.size(0))
        ]
        if not per_frame:
            return None
        return torch.cat(per_frame, dim=0)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (seq, target) where the sampled sequence is split into a seq
                    part (first 10 frames) and a target part (remaining frames)
        """
        source = self.train_data if self.train else self.test_data
        seq, target = source[index, :10], source[index, 10:]

        if self.transform is not None:
            seq = self._apply_per_frame(seq, self.transform)
        if self.target_transform is not None:
            # The target gets its own transform, not the input one.
            target = self._apply_per_frame(target, self.target_transform)

        return seq, target

    def __len__(self):
        return len(self.train_data) if self.train else len(self.test_data)

    def _check_exists(self):
        """True when both processed files (train and test) are present under root."""
        processed_dir = os.path.join(self.root, self.processed_folder)
        return os.path.exists(
            os.path.join(processed_dir, self.training_file)
        ) and os.path.exists(os.path.join(processed_dir, self.test_file))

    def download(self):
        """Download the Moving MNIST data if it doesn't exist in processed_folder already."""
        import gzip
        import shutil
        import urllib.request

        if self._check_exists():
            return

        raw_dir = os.path.join(self.root, self.raw_folder)
        processed_dir = os.path.join(self.root, self.processed_folder)
        # exist_ok makes each directory independent: an already existing raw/
        # must not prevent processed/ from being created.
        os.makedirs(raw_dir, exist_ok=True)
        os.makedirs(processed_dir, exist_ok=True)

        # download files
        for url in self.urls:
            print("Downloading " + url)
            filename = url.rpartition("/")[2]
            archive_path = os.path.join(raw_dir, filename)
            with urllib.request.urlopen(url) as response, open(archive_path, "wb") as f:
                shutil.copyfileobj(response, f)
            # Strip only the trailing ".gz" to get the decompressed file name.
            extracted_path = os.path.splitext(archive_path)[0]
            with gzip.open(archive_path, "rb") as zip_f, open(
                extracted_path, "wb"
            ) as out_f:
                shutil.copyfileobj(zip_f, out_f)
            os.unlink(archive_path)

        # process and save as torch files
        print("Processing...")

        # swapaxes(0, 1) puts the axis that is split into train/test first;
        # presumably the raw file is laid out (frame, sequence, H, W) -- TODO confirm.
        sequences = np.load(os.path.join(raw_dir, "mnist_test_seq.npy")).swapaxes(0, 1)
        train_part, test_part = self._split_sequences(sequences, self.split)
        training_set = torch.from_numpy(train_part)
        test_set = torch.from_numpy(test_part)

        with open(os.path.join(processed_dir, self.training_file), "wb") as f:
            torch.save(training_set, f)
        with open(os.path.join(processed_dir, self.test_file), "wb") as f:
            torch.save(test_set, f)

        print("Done!")

    def __repr__(self):
        mode = "train" if self.train is True else "test"
        lines = [
            "Dataset " + self.__class__.__name__,
            "    Number of datapoints: {}".format(self.__len__()),
            "    Train/test: {}".format(mode),
            "    Root Location: {}".format(self.root),
        ]
        for label, transform in (
            ("    Transforms (if any): ", self.transform),
            ("    Target Transforms (if any): ", self.target_transform),
        ):
            # Continuation lines of a multi-line repr are aligned under the label.
            lines.append(
                label + transform.__repr__().replace("\n", "\n" + " " * len(label))
            )
        return "\n".join(lines)


In [None]:


import secrets

import easydict
import matplotlib.pyplot as plt
import torch
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from tqdm import tqdm

from models_moving_mnist import LSTMAE, LSTMVAE

# Module-level TensorBoard writer; train() logs its "train_loss" and
# "eval_loss" scalars through it.
writer = SummaryWriter()

## visualization
def imshow(past_data, title="MovingMNIST"):
    """Save a one-row strip of images to disk.

    Args:
        past_data: sequence of 2-D images (anything ``Axes.imshow`` accepts);
            one panel is drawn per element.
        title: used both as the figure's super-title and as the file name
            passed to ``savefig``.

    The figure is closed after saving, so nothing is left open to display.
    """
    num_img = len(past_data)
    # Exactly one column per image (the grid used to have a spare, empty
    # column). squeeze=False keeps `axes` 2-D even for a single image.
    fig, axes = plt.subplots(1, num_img, figsize=(4 * num_img, 4), squeeze=False)

    for ax, img in zip(axes[0], past_data):
        ax.imshow(img)
    fig.suptitle(title, fontsize=30)
    fig.savefig(f"{title}")
    plt.close(fig)

def visualize_reconstructions(model, test_loader, device, epoch):
    """Save a figure comparing inputs with reconstructions for the first test batch.

    Args:
        model: called as ``model(x)`` and expected to return a 3-tuple whose
            second element is the reconstruction.
        test_loader: iterable of 2-element batches; the second element is
            used and must be 4-D (batch, frames, height, width).
        device: device the batch is moved to before the forward pass.
        epoch: used in the figure title and in the output file name
            ``reconstruction_epoch_{epoch}.png``.

    Puts ``model`` in eval mode (and leaves it there). Does nothing else when
    the loader yields no batches. Only the first frame of up to 10 examples is
    drawn.
    """
    model.eval()

    # Only the first batch is ever visualised.
    first_batch = next(iter(test_loader), None)
    if first_batch is None:
        return

    _, frames = first_batch
    batch_size = frames.size(0)
    example_size = frames.size(1)
    image_height = frames.size(2)
    flat_frames = frames.view(batch_size, example_size, -1).float().to(device)

    with torch.no_grad():
        _, recon_x, _ = model(flat_frames)

    n_examples = min(10, batch_size)
    examples = flat_frames[:n_examples].cpu().view(
        n_examples, example_size, image_height, -1
    )
    recon_examples = recon_x[:n_examples].cpu().view(
        n_examples, example_size, image_height, -1
    )

    # squeeze=False keeps `axes` 2-D so axes[row, col] also works when the
    # batch holds a single example.
    fig, axes = plt.subplots(2, n_examples, figsize=(20, 4), squeeze=False)
    for j in range(n_examples):
        axes[0, j].imshow(examples[j, 0], cmap='gray')
        axes[0, j].axis('off')
        axes[1, j].imshow(recon_examples[j, 0], cmap='gray')
        axes[1, j].axis('off')
    fig.suptitle(f"Epoch {epoch}: Original (top) vs Reconstructed (bottom)")
    fig.savefig(f"reconstruction_epoch_{epoch}.png")
    plt.close(fig)

def train(args, model, train_loader, test_loader):
    """Train ``model`` for at most ``args.max_iter`` optimizer steps.

    Args:
        args: object exposing ``learning_rate``, ``max_iter`` and ``device``.
        model: called as ``model(x)``; the first element of its return value
            is used as the loss (its mean is back-propagated).
        train_loader: iterable of 2-element batches; the second element is the
            tensor the model is trained on. Must support ``len()``.
        test_loader: same batch layout, used for the per-epoch evaluation.

    Returns:
        The same ``model`` object. When the step budget runs out in the
        middle of an epoch the function returns immediately, without
        evaluating or logging that partial epoch.

    Side effects: writes ``train_loss`` (epoch mean) and ``eval_loss`` (epoch
    mean) to the module-level ``writer`` and saves a reconstruction figure
    every 10 epochs.
    """
    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    ## iteration setup: enough epochs to cover max_iter steps
    epochs = tqdm(range(args.max_iter // len(train_loader) + 1))

    ## training
    count = 0  # optimizer steps taken so far
    for epoch in epochs:
        model.train()
        train_iterator = tqdm(train_loader, total=len(train_loader), desc="training")
        running_loss, num_batches = 0.0, 0

        for _, frames in train_iterator:

            # `>=` so that exactly max_iter steps are taken (with `>` one
            # extra step slipped through).
            if count >= args.max_iter:
                writer.flush()
                return model
            count += 1

            ## reshape: flatten each frame into a feature vector
            batch_size = frames.size(0)
            example_size = frames.size(1)
            frames = frames.view(batch_size, example_size, -1).float().to(args.device)

            mloss, _, _ = model(frames)
            loss = mloss.mean()

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            running_loss += loss_value
            num_batches += 1
            train_iterator.set_postfix({"train_loss": loss_value})

        # Log the epoch mean (consistent with eval_loss) rather than whatever
        # the last batch happened to produce; skip if no batch was seen.
        if num_batches:
            writer.add_scalar("train_loss", running_loss / num_batches, epoch)

        model.eval()
        eval_loss = 0.0
        test_iterator = tqdm(test_loader, total=len(test_loader), desc="testing")

        with torch.no_grad():
            for _, frames in test_iterator:
                ## reshape
                batch_size = frames.size(0)
                example_size = frames.size(1)
                frames = (
                    frames.view(batch_size, example_size, -1).float().to(args.device)
                )

                mloss, _, _ = model(frames)
                batch_loss = mloss.mean().item()

                eval_loss += batch_loss
                test_iterator.set_postfix({"eval_loss": batch_loss})

        # max(..., 1) avoids a ZeroDivisionError on an empty test loader
        eval_loss = eval_loss / max(len(test_loader), 1)
        writer.add_scalar("eval_loss", float(eval_loss), epoch)
        print("Evaluation Score : [{}]".format(eval_loss))

        # Visualize reconstructions every 10 epochs
        if epoch % 10 == 0:
            visualize_reconstructions(model, test_loader, args.device, epoch)

    writer.flush()
    return model


if __name__ == "__main__":

    # training dataset
    train_set = MovingMNIST(
        root=".data/mnist",
        train=True,
        download=True,
        transform=transforms.ToTensor(),
        target_transform=transforms.ToTensor(),
    )

    # test dataset
    test_set = MovingMNIST(
        root=".data/mnist",
        train=False,
        download=True,
        transform=transforms.ToTensor(),
        target_transform=transforms.ToTensor(),
    )

    # hyper-parameters; input_size is the flattened frame size fed to the model
    args = easydict.EasyDict(
        {
            "batch_size": 512,
            "device": torch.device("cuda")
            if torch.cuda.is_available()
            else torch.device("cpu"),
            "input_size": 4096,
            "hidden_size": 2048,
            "latent_size": 1024,
            "learning_rate": 0.001,
            "max_iter": 1000,
        }
    )

    batch_size = args.batch_size
    input_size = args.input_size
    hidden_size = args.hidden_size
    latent_size = args.latent_size

    # define LSTM-based VAE model
    model = LSTMVAE(input_size, hidden_size, latent_size, device=args.device)
    model.to(args.device)

    # convert to format of data loader
    train_loader = torch.utils.data.DataLoader(
        dataset=train_set, batch_size=args.batch_size, shuffle=True
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_set, batch_size=args.batch_size, shuffle=False
    )

    # training
    trained_model = train(args, model, train_loader, test_loader)

    # save model under a random suffix so earlier checkpoints are not overwritten
    id_ = secrets.token_hex(nbytes=4)
    checkpoint_path = f"lstmvae{id_}.model"
    torch.save(trained_model.state_dict(), checkpoint_path)

    # load model (round-trip check of the checkpoint just written)
    model_to_load = LSTMVAE(input_size, hidden_size, latent_size, device=args.device)
    model_to_load.to(args.device)
    # map_location puts the weights on the active device no matter where
    # they were saved from.
    model_to_load.load_state_dict(
        torch.load(checkpoint_path, map_location=args.device)
    )
    model_to_load.eval()

    # show results
    ## each element of the pair is a stack of frames: (frames, height, width)
    future_data, past_data = train_set[0]

    ## reshape
    example_size = past_data.size(0)
    image_size = past_data.size(1), past_data.size(2)
    past_data = past_data.view(example_size, -1).float().to(args.device)
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        _, recon_data, info = model_to_load(past_data.unsqueeze(0))

    nhw_orig = past_data.view(example_size, image_size[0], -1).cpu()
    nhw_recon = (
        recon_data.squeeze(0)
        .view(example_size, image_size[0], -1)
        .cpu()
        .numpy()
    )

    # imshow() saves each strip to a file and closes its figure, so show()
    # below only displays figures that are still open.
    imshow(nhw_orig, title=f"final_input{id_}")
    imshow(nhw_recon, title=f"final_output{id_}")
    plt.show()