In [1]:
import time
from typing import List

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# if not torch.cuda.is_available():
#     import torch_xla.core.xla_model as xm  # For use of a TPU (CUDATimer class can't be used on TPUs)
#     import time

In [2]:
print(torch.cuda.is_available())
print(torch.cuda.get_device_name(torch.cuda.current_device()))

True
NVIDIA H100 80GB HBM3


In [3]:
class CUDATimer:

    def __init__(self):
        self._starter = torch.cuda.Event(enable_timing=True)
        self._ender = torch.cuda.Event(enable_timing=True)

    def reset(self):
        self._starter.record()

    def time(self):
        self._ender.record()
        torch.cuda.synchronize()
        forward_time = self._starter.elapsed_time(self._ender)
        return forward_time

In [4]:
class ConvLSTMCell(nn.Module):
    """
    A ConvLSTM implementation using Conv2d operations.
    """

    def __init__(
        self,
        batch_size,
        input_size,
        hidden_size,
        height,
        width,
        device,
        bias=True,
        padding_mode: str = "zeros",
        **kwargs
    ):
        super(ConvLSTMCell, self).__init__()

        # Parameters
        self.batch_size = batch_size
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.height = height
        self.width = width
        self.bias = bias
        self.device = device


        # Hidden (h) and cell (c) states
        self.h = torch.zeros(size=(batch_size, hidden_size, height, width), device=device)
        self.c = torch.zeros(size=(batch_size, hidden_size, height, width), device=device)

        # Convolution weights
        conv = []
        conv.append(nn.Conv2d(
            in_channels=input_size + hidden_size,
            out_channels=hidden_size*4,
            kernel_size=3,
            stride=1,
            padding=1,
            padding_mode=padding_mode,
            bias=bias
        ))
        self.conv = nn.Sequential(*conv)

    def reset_states(self, batch_size, height: int = 64, width: int = 64):
        if self.batch_size == batch_size and self.height == height and self.width == width:
            self.h = torch.zeros_like(self.h)
            self.c = torch.zeros_like(self.c)
        else:
            self.batch_size = batch_size
            self.height = height
            self.width = width
            self.h = torch.zeros(size=(batch_size, self.hidden_size, height, width), device=self.device)
            self.c = torch.zeros(size=(batch_size, self.hidden_size, height, width), device=self.device)

    def reset_parameters(self):
        # Uniform distribution initialization of lstm weights with respect to
        # the number of lstm cells in the layer
        std = 1.0/np.sqrt(self.hidden_size)
        for w in self.parameters():
            w.data.uniform_(-std, std)

    def forward(
        self,
        x: torch.Tensor,
        h_prev: torch.Tensor = None,
        c_prev: torch.Tensor = None
    ):
        """
        ...
        """

        # Set the previous hidden and cell states if not provided
        h_prev = self.h if h_prev is None else h_prev
        c_prev = self.c if c_prev is None else c_prev

        # Perform input and recurrent convolutions
        conv_res = self.conv(torch.cat((x, h_prev), dim=1))

        # Split result into input and gate activations
        netin, igate, fgate, ogate = torch.split(conv_res, self.hidden_size, dim=1)

        # Compute input and gate activations
        act_input = torch.tanh(netin)
        act_igate = torch.sigmoid(igate)
        act_fgate = torch.sigmoid(fgate)
        act_ogate = torch.sigmoid(ogate)

        # Compute the new cell and hidden states
        c_curr = act_fgate*c_prev + act_igate*act_input
        h_curr = act_ogate*torch.tanh(c_curr)

        # Update the hidden and cells states
        self.h = h_curr
        self.c = c_curr

        return h_curr, c_curr


class ConvLSTM(nn.Module):
    """
    A ConvLSTM implementation using Conv1d instead of Conv2d operations.
    """

    def __init__(
        self,
        batch_size: int = 8,
        input_size: int = 1,
        hidden_sizes: List = [4, 4],
        height: int = 16,
        width: int = 16,
        device: torch.device = torch.device("cpu"),
        bias: bool = True,
        padding_mode: str = "zeros",
        tanh_encoder: bool = False,
        norm = nn.LayerNorm,  # Must be overridden with a OmegaConfig
        **kwargs
    ):
        super(ConvLSTM, self).__init__()

        self.device = device

        self.encoder = []
        self.encoder.append(torch.nn.Conv2d(
            in_channels=input_size,
            out_channels=hidden_sizes[0],
            kernel_size=1  # When using padding, pass the padding_mode here and use CylinderPad with geopotential
        ))
        if tanh_encoder: self.encoder.append(torch.nn.Tanh())
        self.encoder = torch.nn.Sequential(*self.encoder)

        self.clstm = torch.nn.ModuleList()
        self.norms = torch.nn.ModuleList()
        for h in hidden_sizes:
            self.clstm.append(ConvLSTMCell(
                batch_size=batch_size,
                input_size=h,
                hidden_size=h,
                height=height,
                width=width,
                device=device,
                bias=bias,
                padding_mode=padding_mode
            ))
            self.norms.append(norm(normalized_shape=(height, width)))
        self.clstm = torch.nn.ModuleList(self.clstm)

        self.decoder = torch.nn.Conv2d(
            in_channels=hidden_sizes[-1],
            out_channels=input_size,
            kernel_size=1
        )

    def forward(
        self,
        x: torch.Tensor,
        tf_steps: int = 10,  # teacher forcing steps
    ) -> torch.Tensor:
        """
        ...
        """

        # Initialize hidden and cell states of LSTM to zero
        b, t, c, h, w = x.shape
        self.reset(batch_size=b, height=h, width=w)
        outs = []

        # Iterate over sequence
        for t in range(x.shape[1]):
            # During teacher forcing, take the ground truth as input. In closed loop, take the last model output.
            x_t = x[:, t] if t < tf_steps else x_t
            # Forward the current time step's input through the model
            x_t = self.encoder(x_t)  # [b, hidden, h, w]
            for clstm_cell, norm in zip(self.clstm, self.norms):
                z_t = x_t
                z_t, _ = clstm_cell(z_t)  # [b, hidden, h, w]
                x_t = x_t + z_t  # residual connection
                x_t = norm(x_t)
            x_t = self.decoder(x_t)
            outs.append(x_t)

        return torch.stack(outs, dim=1)

    def reset(self, batch_size: int = 8, height: int = 64, width: int = 64):
        for clstm_cell in self.clstm:
            clstm_cell.reset_states(batch_size=batch_size, height=height, width=width)
        self.zeros = torch.zeros(size=(batch_size, 1, height, width), device=self.device)

In [5]:
class minConvLSTMCell(nn.Module):
    """
    A ConvLSTM implementation using Conv2d operations.
    """

    def __init__(
        self,
        batch_size: int = 1,
        input_size: int = 1,
        hidden_size: int = 16,
        height: int = 16,
        width: int = 16,
        device: torch.device = torch.device("cpu"),
        bias: bool = True,
        padding_mode: str = "zeros",
        exponentiate: bool = False,
        **kwargs
    ):
        super(minConvLSTMCell, self).__init__()

        # Parameters
        self.batch_size = batch_size
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.height = height
        self.width = width
        self.bias = bias
        self.device = device
        self.exponentiate = exponentiate

        self.h = torch.zeros(size=(batch_size, hidden_size, height, width), device=device)

        # Convolution weights
        conv = []
        conv.append(nn.Conv2d(
            in_channels=hidden_size,
            out_channels=hidden_size*3,
            kernel_size=3,
            stride=1,
            padding=1,
            padding_mode=padding_mode,
            bias=bias
        ))
        self.conv = nn.Sequential(*conv)

    def reset_states(self, batch_size, height: int = 16, width: int = 16):
        if self.batch_size == batch_size and self.height == height and self.width == width:
            self.h = torch.zeros_like(self.h)
        else:
            self.batch_size = batch_size
            self.height = height
            self.width = width
            self.h = torch.zeros(size=(batch_size, self.hidden_size, height, width), device=self.device)

    @staticmethod
    def g(x):
        return torch.where(x >= 0, x + 0.5, torch.sigmoid(x))

    @staticmethod
    def log_g(x):
        return torch.where(x >= 0, (F.relu(x) + 0.5).log(), -F.softplus(-x))

    @staticmethod
    def parallel_scan_log(log_coeffs: torch.Tensor, log_values: torch.Tensor) -> torch.Tensor:
        # log_coeffs: [B, T, hid, h, w]
        # log_values: [B, T+1, hid, h, w]

        a_star = F.pad(torch.cumsum(log_coeffs, dim=1), (0, 0, 0, 0, 0, 0, 1, 0))
        log_h0_plus_b_star = torch.logcumsumexp(log_values - a_star, dim=1)
        log_h = a_star + log_h0_plus_b_star

        return torch.exp(log_h)[:, 1:].contiguous()  # [B, T, hid, h, w]

    def forward(
        self,
        x: torch.Tensor,
    ):
        """
        ...
        """

        bt, c, h, w = x.shape[0], x.shape[1], x.shape[2], x.shape[3]
        b = self.batch_size
        t = int(bt/b)

        h0 = self.h.unsqueeze(1)

        f_gate, i_gate, h_tilde = torch.chunk(
            self.conv(x).view(b, t, self.hidden_size*3, h, w).contiguous(),
            chunks=3,
            dim=2
        )

        diff = i_gate - f_gate if self.exponentiate else F.softplus(-f_gate) - F.softplus(-i_gate)
        log_f = -F.softplus(diff)
        log_i = -F.softplus(-diff)
        log_h0 = self.log_g(h0)
        log_h_tilde = self.log_g(h_tilde)
        out = self.parallel_scan_log(log_f, torch.cat([log_h0, log_i + log_h_tilde], dim=1))
        self.h = out[:, -1]
        out = out.view(bt, self.hidden_size, h, w)

        return out

    def step(self, x_t: torch.Tensor, h_prev: torch.Tensor = None) -> torch.Tensor:
        # sequential mode of minLSTM trained in log-space
        # x_t:
        # h_prev:

        h_prev = self.h if h_prev is None else h_prev

        # I get nan exactly in this block until h_curr
        f_t, i_t, h_tilde_t = torch.chunk(self.conv(x_t), chunks=3, dim=1)

        if self.exponentiate:
            f_t, i_t = torch.exp(f_t), torch.exp(i_t)
        else:
            f_t, i_t = torch.sigmoid(f_t), torch.sigmoid(i_t)
        h_tilde_t = self.g(h_tilde_t)
        f_prime_t = f_t / (f_t + i_t)
        i_prime_t = i_t / (f_t + i_t)
        h_curr = f_prime_t * h_prev + i_prime_t * h_tilde_t

        self.h = h_curr

        return h_curr


class minConvLSTM(nn.Module):

    def __init__(
        self,
        batch_size: int = 8,
        input_size: int = 1,
        hidden_sizes: List = [4, 4],
        height: int = 16,
        width: int = 16,
        device: torch.device = torch.device("cpu"),
        bias: bool = True,
        padding_mode: str = "zeros",
        tanh_encoder: bool = False,
        exponentiate: bool = False,
        norm = nn.LayerNorm,
        **kwargs
    ):
        super(minConvLSTM, self).__init__()

        self.device = device

        self.encoder = []
        self.encoder.append(torch.nn.Conv2d(
            in_channels=input_size,
            out_channels=hidden_sizes[0],
            kernel_size=1,  # When using padding, pass the padding_mode here and use CylinderPad with geopotential
        ))
        if tanh_encoder: self.encoder.append(torch.nn.Tanh())
        self.encoder = torch.nn.Sequential(*self.encoder)

        self.clstm = torch.nn.ModuleList()
        self.norms = torch.nn.ModuleList()
        for h in hidden_sizes:
            self.clstm.append(minConvLSTMCell(
                batch_size=batch_size,
                input_size=h,
                hidden_size=h,
                height=height,
                width=width,
                device=device,
                bias=bias,
                padding_mode=padding_mode,
                exponentiate=exponentiate
            ))
            self.norms.append(norm(normalized_shape=(height, width)))

        self.decoder = torch.nn.Conv2d(
            in_channels=hidden_sizes[-1],
            out_channels=input_size,
            kernel_size=1
        )

    def forward(self, x: torch.Tensor, tf_steps: int = 500, test: bool = True) -> torch.Tensor:
        b, t, c, h, w = x.shape
        tf_steps = min(tf_steps, t)
        self.reset(batch_size=b, height=h, width=w)

        x_tf = x[:, :tf_steps].reshape(b*tf_steps, c, h, w)
        outs = []

        #
        # Parallel mode for teacher forcing
        x_tf = self.encoder(x_tf)
        for clstm_cell, norm in zip(self.clstm, self.norms):
            # apply skip connection + post layer group norm
            z = x_tf
            z = clstm_cell(z)
            x_tf = z + x_tf
            x_tf = norm(x_tf)
        out = self.decoder(x_tf).view(b, tf_steps, c, h, w)
        outs.append(out)

        #
        # Sequential mode for closed loop prediction
        x_t = out[:, -1] # no need for .clone() because changing x_t does not change out -- tested in notebook
        # Iterate over sequence
        for t in range(t-tf_steps):
            # Forward the current time step's input through the model
            x_t = self.encoder(x_t)  # [B, C, H, W] C -> num_channels (hidden state)
            for clstm_cell, norm in zip(self.clstm, self.norms):
                z_t = x_t
                z_t = clstm_cell.step(z_t)  # [B, C, H, W]
                x_t = x_t + z_t  # residual connection
                x_t = norm(x_t)
            x_t = self.decoder(x_t) # [B, 1, I, H, W]
            outs.append(x_t.unsqueeze(1))
        outs = torch.cat(outs, dim=1)

        return outs

    def reset(self, batch_size: int = 8, height: int = 16, width: int = 16):
        for clstm_cell in self.clstm:
            clstm_cell.reset_states(batch_size=batch_size, height=height, width=width)

In [6]:
def sequential(x, n_reps, model):

    if torch.cuda.is_available():
        timer = CUDATimer()
        timer.reset()
    else:
        a = time.time()

    with torch.no_grad():
        for i in range(n_reps):
            model(x)
            if not torch.cuda.is_available(): xm.mark_step()  # Crucial for synchronizing TPU execution

    if torch.cuda.is_available():
        print("Sequential:", timer.time(), "ms")
    else:
        print("Sequential:", time.time()-a, "seconds")

In [7]:
def parallel(x, n_reps, model):

    if torch.cuda.is_available():
        timer = CUDATimer()
        timer.reset()
    else:
        a = time.time()

    with torch.no_grad():
        #b, t, c, h, w = x.shape
        #data = x.view(b*t, c, h, w)
        #for i in range(warmup):
        #    model(data)
        #
        #timer.reset()
        #
        for i in range(n_reps):
            model(x)  # convolution on full input
            if not torch.cuda.is_available(): xm.mark_step()  # Crucial for synchronizing TPU execution

    if torch.cuda.is_available():
        print("Parallel:  ", timer.time(), "ms")
    else:
        print("Parallel:  ", time.time()-a, "seconds")

In [15]:
def benchmark(shape, n_reps):
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = xm.xla_device()
    #device = torch.device("cpu")
    B, T, C, H, W = shape
    print("\nBenchmarking shape", shape, "on device", device)

    x = torch.randn(B, T, C, H, W).to(device=device)
    clstm = ConvLSTM(
        batch_size=B,
        input_size=C,
        hidden_sizes=[12, 12, 12, 12, 12, 12, 12, 12, 12, 12],
        height=H,
        width=W,
        device=device
    ).to(device=device)

    mclstm = minConvLSTM(
        batch_size=B,
        input_size=C,
        hidden_sizes=[20, 20, 20, 20, 20, 20, 20, 20, 20, 20],
        height=H,
        width=W,
        device=device
    ).to(device=device)
    
    sequential(x=x, n_reps=n_reps, model=clstm)
    parallel(x=x, n_reps=n_reps, model=mclstm)


In [16]:
if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True

n_reps = 10

seqlen = 10

B, T, C, H, W = 4, seqlen, 1, 4, 4
benchmark(shape=(B, T, C, H, W), n_reps=n_reps)

B, T, C, H, W = 4, seqlen, 1, 4, 4
benchmark(shape=(B, T, C, H, W), n_reps=n_reps)

#B, T, C, H, W = 4, seqlen, 1, 8, 32
#benchmark(shape=(B, T, C, H, W), n_reps=n_reps)

B, T, C, H, W = 4, seqlen, 1, 8, 8
benchmark(shape=(B, T, C, H, W), n_reps=n_reps)

B, T, C, H, W = 4, seqlen, 1, 16, 16
benchmark(shape=(B, T, C, H, W), n_reps=n_reps)

B, T, C, H, W = 4, seqlen, 1, 32, 32
benchmark(shape=(B, T, C, H, W), n_reps=n_reps)

B, T, C, H, W = 4, seqlen, 1, 64, 64
benchmark(shape=(B, T, C, H, W), n_reps=n_reps)

B, T, C, H, W = 4, seqlen, 1, 128, 128
benchmark(shape=(B, T, C, H, W), n_reps=n_reps)

B, T, C, H, W = 4, seqlen, 1, 256, 256
benchmark(shape=(B, T, C, H, W), n_reps=n_reps)

B, T, C, H, W = 4, seqlen, 1, 512, 512
benchmark(shape=(B, T, C, H, W), n_reps=n_reps)

B, T, C, H, W = 4, seqlen, 1, 4, 4
benchmark(shape=(B, T, C, H, W), n_reps=n_reps)

B, T, C, H, W = 4, seqlen, 1, 32, 32
benchmark(shape=(B, T, C, H, W), n_reps=n_reps)




Benchmarking shape (4, 10, 1, 4, 4) on device cuda
Sequential: 103.60486602783203 ms
Parallel:   28.965824127197266 ms

Benchmarking shape (4, 10, 1, 4, 4) on device cuda
Sequential: 103.73737335205078 ms
Parallel:   28.57004737854004 ms

Benchmarking shape (4, 10, 1, 8, 8) on device cuda
Sequential: 104.56626892089844 ms
Parallel:   96.80924987792969 ms

Benchmarking shape (4, 10, 1, 16, 16) on device cuda
Sequential: 102.74240112304688 ms
Parallel:   35.618560791015625 ms

Benchmarking shape (4, 10, 1, 32, 32) on device cuda
Sequential: 103.731201171875 ms
Parallel:   28.887008666992188 ms

Benchmarking shape (4, 10, 1, 64, 64) on device cuda
Sequential: 138.51504516601562 ms
Parallel:   52.65385437011719 ms

Benchmarking shape (4, 10, 1, 128, 128) on device cuda
Sequential: 136.5806427001953 ms


OutOfMemoryError: CUDA out of memory. Tried to allocate 56.00 MiB. GPU 0 has a total capacity of 79.10 GiB of which 19.88 MiB is free. Including non-PyTorch memory, this process has 79.07 GiB memory in use. Of the allocated memory 78.15 GiB is allocated by PyTorch, and 320.19 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)