In [None]:
import torch
import torch.nn as nn
from torch import fft
import torch
import numpy as np
from scipy.signal import cont2discrete

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class LMUFFTCell(nn.Module):
    def __init__(self, input_size, hidden_size, memory_size, seq_len, theta):
        super(LMUFFTCell, self).__init__()

        self.hidden_size = hidden_size
        self.memory_size = memory_size
        self.seq_len = seq_len
        self.theta = theta

        self.W_u = nn.Linear(in_features = input_size, out_features = 1)
        self.f_u = nn.ReLU()
        self.W_h = nn.Linear(in_features = memory_size + input_size, out_features = hidden_size)
        self.f_h = nn.ReLU()

        A, B = self.stateSpaceMatrices()
        self.register_buffer("A", A) # [memory_size, memory_size]
        self.register_buffer("B", B) # [memory_size, 1]

        H, fft_H = self.impulse()
        self.register_buffer("H", H) # [memory_size, seq_len]
        self.register_buffer("fft_H", fft_H) # [memory_size, seq_len + 1]

    def stateSpaceMatrices(self):
        """ Returns the discretized state space matrices A and B """

        Q = np.arange(self.memory_size, dtype = np.float64).reshape(-1, 1)
        R = (2*Q + 1) / self.theta
        i, j = np.meshgrid(Q, Q, indexing = "ij")

        # Continuous
        A = R * np.where(i < j, -1, (-1.0)**(i - j + 1))
        B = R * ((-1.0)**Q)
        C = np.ones((1, self.memory_size))
        D = np.zeros((1,))

        # Convert to discrete
        A, B, C, D, dt = cont2discrete(
            system = (A, B, C, D), 
            dt = 1.0, 
            method = "zoh"
        )

        # To torch.tensor
        A = torch.from_numpy(A).float() # [memory_size, memory_size]
        B = torch.from_numpy(B).float() # [memory_size, 1]
        
        return A, B

    def impulse(self):
        """ Returns the matrices H and the 1D Fourier transform of H (Equations 23, 26 of the paper) """

        H = []
        A_i = torch.eye(self.memory_size).to(self.A.device) 
        for t in range(self.seq_len):
            H.append(A_i @ self.B)
            A_i = self.A @ A_i

        H = torch.cat(H, dim = -1) # [memory_size, seq_len]
        fft_H = fft.rfft(H, n = 2*self.seq_len, dim = -1) # [memory_size, seq_len + 1]

        return H, fft_H

    def forward(self, x):
        """
        Parameters:
            x (torch.tensor): 
                Input of size [batch_size, seq_len, input_size]
        """
        batch_size, seq_len, input_size = x.shape
        # print("batch_size, seq_len, input_size", batch_size, seq_len, input_size)

        # Equation 18 of the paper
        u = self.f_u(self.W_u(x)) # [batch_size, seq_len, 1]

        # Equation 26 of the paper
        fft_input = u.permute(0, 2, 1) # [batch_size, 1, seq_len]
        fft_u = fft.rfft(fft_input, n = 2*seq_len, dim = -1) # [batch_size, seq_len, seq_len+1]

        # Element-wise multiplication (uses broadcasting)
        # [batch_size, 1, seq_len+1] * [1, memory_size, seq_len+1]
        temp = fft_u * self.fft_H.unsqueeze(0) # [batch_size, memory_size, seq_len+1]

        m = fft.irfft(temp, n = 2*seq_len, dim = -1) # [batch_size, memory_size, seq_len+1]
        m = m[:, :, :seq_len] # [batch_size, memory_size, seq_len]
        m = m.permute(0, 2, 1) # [batch_size, seq_len, memory_size]

        # Equation 20 of the paper (W_m@m + W_x@x  W@[m;x])
        input_h = torch.cat((m, x), dim = -1) # [batch_size, seq_len, memory_size + input_size]
        h = self.f_h(self.W_h(input_h)) # [batch_size, seq_len, hidden_size]

        h_n = h[:, -1, :] # [batch_size*T, hidden_size]

        return h, h_n
    
    def forward_recurrent(self, x, m_last):
        u = self.f_u(self.W_u(x)) # [batch_size, seq_len, 1]
        # A: torch.Size([512, 512]), m_last: torch.Size([256, 512]), B: torch.Size([512, 1]), u: torch.Size([256, 1])
        m = m_last @ self.A.T + u @ self.B.T  # [batch_size, memory_size]
        input_h = torch.cat((m, x), dim = -1) # [batch_size, seq_len, memory_size + input_size]
        h = self.f_h(self.W_h(input_h)) # [batch_size, seq_len, hidden_size]

        return h, m

class LMU(nn.Module):
    def __init__(self, dim, T, use_all_h=True):
        super().__init__()
        self.dim = dim
        self.hidden_size = dim
        self.memory_size = dim
        self.use_all_h = use_all_h
        self.lmu = LMUFFTCell(input_size=dim, hidden_size=self.hidden_size, memory_size=self.memory_size, seq_len=T, theta=T)
        # self.lmu = LMUFFTCell(input_size=dim, hidden_size=self.hidden_size, memory_size=self.memory_size, seq_len=64, theta=64)

        self.proj_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1)
        self.proj_bn = nn.BatchNorm1d(dim)

    def forward(self, x):
        x = x.transpose(-1,-2).contiguous() # B, C, N -> B, N, C
        h, _ = self.lmu(x) # B, N, C; B, C
        
        x = h.transpose(-1,-2).contiguous() #if self.use_all_h else h_n.unsqueeze(-1) # h or h_n

        x = self.proj_conv(x)
        x = self.proj_bn(x)

        return x

class LinearFFN(nn.Module):
    def __init__(self, in_features, pre_norm=False, hidden_features=None, out_features=None, drop=0., act_type='spike'):
        super().__init__()
        self.pre_norm = pre_norm
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.fc1_linear  = nn.Linear(in_features, hidden_features)
        self.fc1_ln = nn.LayerNorm(hidden_features)
        self.fc1_lif = get_act(act_type if act_type == 'spike' else 'gelu', tau=2.0, detach_reset=True)

        self.fc2_linear = nn.Linear(hidden_features, out_features)
        self.fc2_ln = nn.LayerNorm(out_features)
        self.fc2_lif = get_act(act_type, tau=2.0, detach_reset=True)
 
        self.c_hidden = hidden_features
        self.c_output = out_features

    def forward(self, x):
        B,C,N = x.shape
        # 
        x = x.permute(0,2,1) # B, N, C
        # x = x.reshape(B*N, C)
        if self.pre_norm:
            x = self.fc1_ln(x)
            x = self.fc1_lif(x)
            x = self.fc1_linear(x)
            
            x = self.fc2_ln(x)
            x = self.fc2_lif(x)
            x = self.fc2_linear(x)

        else:
            x = self.fc1_linear(x)
            x = self.fc1_ln(x)
            x = self.fc1_lif(x)

            x = self.fc2_linear(x)
            x = self.fc2_ln(x)
            x = self.fc2_lif(x)

        # x = x.reshape(B, N, self.c_output)
        x = x.permute(0,2,1) # B, C, N
        return x
    
class Block(nn.Module):
    def __init__(self, dim, T, mlp_ratio=4., act_type='spike'):
        super().__init__()

        self.attn = LMU(dim=dim, T=T)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = LinearFFN(in_features=dim, hidden_features=mlp_hidden_dim, act_type=act_type)

    def forward(self, x):
        x = x + self.attn(x)
        x = x + self.mlp(x)
        return x

class perm(nn.Module):
    def __init__(self, a, b, c) -> None:
        super().__init__()
        self.a = a
        self.b = b
        self.c = c

    def forward(self, x):
        return x.permute(self.a,self.b,self.c).contiguous()

def get_act(act_type = 'spike', **act_params):
    '''
    act_type :- spike, gelu, relu, identity

    output :- class <act_type>
    '''
    act_type = act_type.lower()
    # if act_type == 'spike':
    #     return MultiStepLIFNode(**act_params, backend='cupy')
    #     # act_params['init_tau'] = act_params.pop('tau')
    #     # return MultiStepParametricLIFNode(**act_params, backend="cupy")
    if act_type == 'relu':
        return nn.ReLU()
    elif act_type == 'gelu':
        return nn.GELU()
    elif act_type == 'identity':
        return nn.Identity()
    
def get_conv_block(T, dim, act_type, kernel_size=3, padding=1, groups=1):
    return [
        perm(0,2,1),
        nn.Conv1d(dim, dim, kernel_size=kernel_size, stride=1, padding=padding, groups=groups, bias=False),
        nn.BatchNorm1d(dim),
        perm(1,2,0),
        get_act(act_type, tau=2.0, detach_reset=True),
        perm(2,1,0)
]

class Conv1d4EB(nn.Module):
    def __init__(self, T=128, vw_dim=256, act_type='spike'):
        super().__init__()

        kernel_size = 3
        padding = 1
        groups = 1
        self.proj_conv = nn.ModuleList(
            [perm(0,2,1)]+\
            get_conv_block(T, vw_dim, act_type)+\
            get_conv_block(T, vw_dim, act_type, kernel_size=kernel_size, padding=padding, groups=groups)+\
            get_conv_block(T, vw_dim, act_type, kernel_size=kernel_size, padding=padding, groups=groups)+\
            get_conv_block(T, vw_dim, act_type, kernel_size=kernel_size, padding=padding, groups=groups)+\
            [perm(0,2,1)]
        )
        self.rpe_conv = nn.ModuleList(
            [perm(0,2,1)]+\
            get_conv_block(T, vw_dim, act_type, kernel_size=kernel_size, padding=padding, groups=groups)+\
            [perm(0,2,1)]
        )
        self.act_loss = 0.0
        
    def forward(self, x):

        for ele in self.proj_conv:
            x = ele(x)

        x_rpe = x.clone()
        for ele in self.rpe_conv:
            x_rpe = ele(x_rpe)

        x = x + x_rpe
        
        return x 

class LMU_RNN(nn.Module):
    def __init__(self, input_size, num_layers, hidden_size, act_type='relu', T=784, test_mode='all_seq',with_head_lif=False):
        super().__init__()
        self.with_head_lif = with_head_lif
        self.test_mode = test_mode

        self.in_layer = nn.Linear(input_size, hidden_size)

        self.patch_embed = Conv1d4EB(T=T, vw_dim=hidden_size, act_type=act_type)

        self.block = nn.ModuleList([
            Block(dim=hidden_size, T=T, act_type=act_type)
            for j in range(num_layers)
        ])

        # classification head
        if self.with_head_lif:
            self.head_bn = nn.BatchNorm1d(hidden_size)
            self.head_lif = get_act(act_type, tau=2.0, detach_reset=True)

    def forward_features(self, x):
        x = self.patch_embed(x)
        for blk in self.block:
            x = blk(x)
        return x

    def forward(self, x):
        self.act_loss = 0.0
        x = self.in_layer(x)
        x = x.permute(0, 2, 1).contiguous()
        x = self.forward_features(x)    # b, d, t -> b, d, t

        if self.with_head_lif:
            x = self.head_bn(x)         # b, d, t 
            x = self.head_lif(x)        # b, d, t

        x = x.permute(0, 2, 1).contiguous()
        
        return x


In [None]:
class MnistLMU(nn.Module):
    r"""
    Very Simple 2 layer LSTM with an fc layer on last steps hidden dimension
    """
    def __init__(self, input_size, hidden_size, codebook_size):
        super(MnistLMU, self).__init__()
        self.rnn = LMU_RNN(input_size=2, hidden_size=128, num_layers=2, T=32)

        self.fc = nn.Sequential(nn.Linear(hidden_size, hidden_size//4),
                                nn.ReLU(),
                                nn.Linear(hidden_size//4, codebook_size))
        # Add pad and start token to embedding size
        self.word_embedding = nn.Embedding(codebook_size+2, input_size)
    
    def forward(self, x):
        x = self.word_embedding(x)
        output = self.rnn(x)
        output = output[:, -1, :]
        return self.fc(output)

class MnistLSTM(nn.Module):
    r"""
    Very Simple 2 layer LSTM with an fc layer on last steps hidden dimension
    """
    def __init__(self, input_size, hidden_size, codebook_size):
        super(MnistLSTM, self).__init__()
        self.rnn = nn.LSTM(input_size=2, hidden_size=128, num_layers=2, batch_first=True)
        self.fc = nn.Sequential(nn.Linear(hidden_size, hidden_size // 4),
                                nn.ReLU(),
                                nn.Linear(hidden_size // 4, codebook_size))
        # Add pad and start token to embedding size
        self.word_embedding = nn.Embedding(codebook_size+2, input_size)
    
    def forward(self, x):
        x = self.word_embedding(x)
        output, _ = self.rnn(x)
        output = output[:, -1, :]
        return self.fc(output)

In [2]:
from dataset.mnist_dataset import MnistDataset
from utils.inception_score import inceptionScore
import torch
import torch.nn as nn
from torch import fft
import torch
import numpy as np
from scipy.signal import cont2discrete

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

mnist = MnistDataset('train', 'data/train/images', im_channels=3)

FID = inceptionScore(mnist, device, 500)

100%|██████████| 10/10 [00:00<00:00, 136.76it/s]


Found 60000 images for split train
getting statistics of data...


100%|██████████| 2/2 [00:29<00:00, 14.68s/it]

complete !





In [None]:
import os
import torch
import pickle
from tqdm import tqdm
from model.vqvae import get_model
from tools.train_lstm import MnistLSTM, MnistLMU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')



def generate(config):
    r"""
    Method for generating images after training vqvae and lstm
    1. Create config
    2. Create and load vqvae model
    3. Create and load LSTM model
    4. Generate 100 encoder outputs from trained LSTM
    5. Pass them to the trained vqvae decoder
    6. Save the generated image
    :param args:
    :return:
    """
    
    ########## Load VQVAE Model ##############
    vqvae_model = get_model(config).to(device)
    vqvae_model.to(device)
    assert os.path.exists(os.path.join(config['train_params']['task_name'],
                                                  config['train_params']['ckpt_name'])), "Train the vqvae model first"
    vqvae_model.load_state_dict(torch.load(os.path.join(config['train_params']['task_name'],
                                                  config['train_params']['ckpt_name']), map_location=device))
        
    vqvae_model.eval()
    #########################################
    
    ################ Generate Samples #############
    generated_quantized_indices = []
    mnist_encodings = pickle.load(open(os.path.join(config['train_params']['task_name'],
                                                    config['train_params']['output_train_dir'],
                                                    'mnist_encodings.pkl'), 'rb'))
    mnist_encodings_length = mnist_encodings.reshape(mnist_encodings.size(0), -1).shape[-1]
    #########################################

    ########## Load LSTM ##############
    default_lstm_config = {
        'input_size': 2,
        'hidden_size': 128,
        'codebook_size': config['model_params']['codebook_size']
    }
    
    if config['model_params']['rnn_type']=='lstm':
        model = MnistLSTM(input_size=default_lstm_config['input_size'],
                        hidden_size=default_lstm_config['hidden_size'],
                        codebook_size=default_lstm_config['codebook_size']).to(device)
    elif config['model_params']['rnn_type']=='lmu':
        model = MnistLMU(input_size=default_lstm_config['input_size'],
                        hidden_size=default_lstm_config['hidden_size'],
                        codebook_size=default_lstm_config['codebook_size'],
                        T=64).to(device)
    model.to(device)
    assert os.path.exists(os.path.join(config['train_params']['task_name'],
                                                    'best_mnist_lstm.pth')), "Train the lstm first"
    model.load_state_dict(torch.load(os.path.join(config['train_params']['task_name'],
                                                    f'best_mnist_{config['model_params']['rnn_type']}.pth'), map_location=device))
    model.eval()
    
    # Assume fixed contex size
    context_size = 64
    num_samples = 1000
    print('Generating Samples', mnist_encodings_length)
    for _ in tqdm(range(num_samples)):
        # Initialize with start token
        ctx = torch.ones((1)).to(device) * (config['model_params']['codebook_size'])
        
        for i in range(mnist_encodings_length):
            padded_ctx = ctx
            if len(ctx) < context_size:
                # Pad context with pad token
                padded_ctx = torch.nn.functional.pad(padded_ctx, (0, context_size - len(ctx)), "constant",
                                                  config['model_params']['codebook_size']+1)
            # print(padded_ctx[None, :].shape)
            out = model(padded_ctx[None, :][:,-64:].long().to(device))
            probs = torch.nn.functional.softmax(out, dim=-1)
            pred = torch.multinomial(probs[0], num_samples=1)
            # Update the context with the new prediction
            ctx = torch.cat([ctx, pred])
            # print(padded_ctx.shape, pred.shape, ctx.shape)
        generated_quantized_indices.append(ctx[1:][None, :])
    
    ######## Decode the Generated Indices ##########
    generated_quantized_indices = torch.cat(generated_quantized_indices, dim=0)
    h = int(generated_quantized_indices[0].size(-1)**0.5)
    quantized_indices = generated_quantized_indices.reshape((generated_quantized_indices.size(0), h, h)).long()
    quantized_indices = torch.nn.functional.one_hot(quantized_indices, config['model_params']['codebook_size'])
    quantized_indices = quantized_indices.permute((0, 3, 1, 2))
    output = vqvae_model.decode_from_codebook_indices(quantized_indices.float())
    
    # Transform from -1, 1 range to 0,1
    output = (output + 1) / 2
    
    if config['model_params']['in_channels'] == 3:
        # Just because we took input as cv2.imread which is BGR so make it RGB
        output = output[:, [2, 1, 0], :, :]

    return output

**LMU model**

In [7]:
config = {'model_params': {'in_channels': 3, 
                           'convbn_blocks': 4, 
                           'conv_kernel_size': [3, 3, 3, 2], 
                           'conv_kernel_strides': [2, 2, 1, 1], 
                           'convbn_channels': [3, 16, 32, 8, 8], 
                           'conv_activation_fn': 'leaky', 
                           'transpose_bn_blocks': 4, 
                           'transposebn_channels': [8, 8, 32, 16, 3], 
                           'transpose_kernel_size': [3, 4, 4, 4], 
                           'transpose_kernel_strides': [1, 2, 1, 1], 
                           'transpose_activation_fn': 'leaky', 
                           'latent_dim': 8, 'codebook_size': 20, 
                           'rnn_type': 'lmu'}, 
          'train_params': {'task_name': 'vqvae_latent_8_colored_codebook_20', 
                           'batch_size': 64, 'epochs': 20, 'lr': 0.005, 
                           'crit': 'l2', 'reconstruction_loss_weight': 5, 
                           'codebook_loss_weight': 1, 
                           'commitment_loss_weight': 0.2, 
                           'ckpt_name': 'best_vqvae_latent_8_colored_codebook_20.pth', 
                           'seed': 111, 'save_training_image': True, 
                           'train_path': 'data/train/images', 
                           'test_path': 'data/test/images', 'output_train_dir': 'output'}}

out = generate(config)

{'model_params': {'in_channels': 3, 'convbn_blocks': 4, 'conv_kernel_size': [3, 3, 3, 2], 'conv_kernel_strides': [2, 2, 1, 1], 'convbn_channels': [3, 16, 32, 8, 8], 'conv_activation_fn': 'leaky', 'transpose_bn_blocks': 4, 'transposebn_channels': [8, 8, 32, 16, 3], 'transpose_kernel_size': [3, 4, 4, 4], 'transpose_kernel_strides': [1, 2, 1, 1], 'transpose_activation_fn': 'leaky', 'latent_dim': 8, 'codebook_size': 20, 'rnn_type': 'lmu'}, 'train_params': {'task_name': 'vqvae_latent_8_colored_codebook_20', 'batch_size': 64, 'epochs': 20, 'lr': 0.005, 'crit': 'l2', 'reconstruction_loss_weight': 5, 'codebook_loss_weight': 1, 'commitment_loss_weight': 0.2, 'ckpt_name': 'best_vqvae_latent_8_colored_codebook_20.pth', 'seed': 111, 'save_training_image': True, 'train_path': 'data/train/images', 'test_path': 'data/test/images', 'output_train_dir': 'output'}}


  vqvae_model.load_state_dict(torch.load(os.path.join(config['train_params']['task_name'],
  model.load_state_dict(torch.load(os.path.join(config['train_params']['task_name'],


Generating Samples 64


100%|██████████| 1000/1000 [02:39<00:00,  6.25it/s]


torch.Size([1000, 20, 8, 8])
torch.Size([1000, 8, 8, 8])


In [9]:
fid = FID.calculate_fid_for_generatorSamples(out)

calculating statistics of generated data...


100%|██████████| 2/2 [00:29<00:00, 14.79s/it]


In [13]:
fid.item()

96.29911003394864

**LSTM model**

In [11]:
config = {'model_params': {'in_channels': 3, 
                           'convbn_blocks': 4, 
                           'conv_kernel_size': [3, 3, 3, 2], 
                           'conv_kernel_strides': [2, 2, 1, 1], 
                           'convbn_channels': [3, 16, 32, 8, 8], 
                           'conv_activation_fn': 'leaky', 
                           'transpose_bn_blocks': 4, 
                           'transposebn_channels': [8, 8, 32, 16, 3], 
                           'transpose_kernel_size': [3, 4, 4, 4], 
                           'transpose_kernel_strides': [1, 2, 1, 1], 
                           'transpose_activation_fn': 'leaky', 
                           'latent_dim': 8, 'codebook_size': 20, 
                           'rnn_type': 'lstm'}, 
          'train_params': {'task_name': 'vqvae_latent_8_colored_codebook_20', 
                           'batch_size': 64, 'epochs': 20, 'lr': 0.005, 
                           'crit': 'l2', 'reconstruction_loss_weight': 5, 
                           'codebook_loss_weight': 1, 
                           'commitment_loss_weight': 0.2, 
                           'ckpt_name': 'best_vqvae_latent_8_colored_codebook_20.pth', 
                           'seed': 111, 'save_training_image': True, 
                           'train_path': 'data/train/images', 
                           'test_path': 'data/test/images', 'output_train_dir': 'output'}}

out = generate(config)

  vqvae_model.load_state_dict(torch.load(os.path.join(config['train_params']['task_name'],
  model.load_state_dict(torch.load(os.path.join(config['train_params']['task_name'],


{'model_params': {'in_channels': 3, 'convbn_blocks': 4, 'conv_kernel_size': [3, 3, 3, 2], 'conv_kernel_strides': [2, 2, 1, 1], 'convbn_channels': [3, 16, 32, 8, 8], 'conv_activation_fn': 'leaky', 'transpose_bn_blocks': 4, 'transposebn_channels': [8, 8, 32, 16, 3], 'transpose_kernel_size': [3, 4, 4, 4], 'transpose_kernel_strides': [1, 2, 1, 1], 'transpose_activation_fn': 'leaky', 'latent_dim': 8, 'codebook_size': 20, 'rnn_type': 'lstm'}, 'train_params': {'task_name': 'vqvae_latent_8_colored_codebook_20', 'batch_size': 64, 'epochs': 20, 'lr': 0.005, 'crit': 'l2', 'reconstruction_loss_weight': 5, 'codebook_loss_weight': 1, 'commitment_loss_weight': 0.2, 'ckpt_name': 'best_vqvae_latent_8_colored_codebook_20.pth', 'seed': 111, 'save_training_image': True, 'train_path': 'data/train/images', 'test_path': 'data/test/images', 'output_train_dir': 'output'}}
Generating Samples 64


100%|██████████| 1000/1000 [00:31<00:00, 31.89it/s]

torch.Size([1000, 20, 8, 8])
torch.Size([1000, 8, 8, 8])





In [14]:
fid = FID.calculate_fid_for_generatorSamples(out)

calculating statistics of generated data...


100%|██████████| 2/2 [00:30<00:00, 15.40s/it]


In [15]:
fid.item()

155.01308510968738