In [1]:
from typing import Callable, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Load CIFAR-10

In [2]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 128

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:01<00:00, 96406328.31it/s] 


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [3]:
def get_default_device():
    """Pick GPU if available, else CPU"""
    if torch.cuda.is_available():
        return torch.device('cuda')
    else:
        return torch.device('cpu')
    
def to_device(data, device):
    """Move tensor(s) to chosen device"""
    if isinstance(data, (list,tuple)):
        return [to_device(x, device) for x in data]
    return data.to(device=device, non_blocking=True)

class DeviceDataLoader():
    """Wrap a dataloader to move data to a device"""
    def __init__(self, dl, device):
        self.dl = dl
        self.device = device
        
    def __iter__(self):
        """Yield a batch of data after moving it to device"""
        for b in self.dl: 
            yield to_device(b, self.device)

    def __len__(self):
        """Number of batches"""
        return len(self.dl)

In [4]:
trainloader = DeviceDataLoader(trainloader, device)
testloader = DeviceDataLoader(testloader, device)

In [5]:
@torch.no_grad()
def evaluate(model, data):
    correct = 0
    total = 0
    # since we're not training, we don't need to calculate the gradients for our outputs
    with torch.no_grad():
        for images, labels in data:
            # calculate outputs by running images through the network
            outputs = model(images.to(device))
            # the class with the highest energy is what we choose as prediction
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels.to(device)).sum().item()

    return 100 * correct / total

# Toad convmodel

In [6]:
class SELayer(nn.Module):
    def __init__(self, n_channels: int, rescale_input: bool, reduction: int = 16):
        super(SELayer, self).__init__()
        self.rescale_input = rescale_input
        self.fc = nn.Sequential(
            nn.Linear(n_channels, n_channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(n_channels // reduction, n_channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, _, _ = x.shape
        # Average feature planes
        if self.rescale_input:
            y = torch.flatten(x, start_dim=-2, end_dim=-1).sum(dim=-1)
        else:
            y = torch.flatten(x, start_dim=-2, end_dim=-1).mean(dim=-1)
        y = self.fc(y.view(b, c)).view(b, c, 1, 1)
        return x * y.expand_as(x)


class ResidualBlock(nn.Module):
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            height: int,
            width: int,
            kernel_size: int = 3,
            normalize: bool = False,
            activation: Callable = nn.ReLU,
            squeeze_excitation: bool = True,
            rescale_se_input: bool = True,
            **conv2d_kwargs
    ):
        super(ResidualBlock, self).__init__()

        # Calculate "same" padding
        # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
        # https://www.wolframalpha.com/input/?i=i%3D%28i%2B2x-k-%28k-1%29%28d-1%29%2Fs%29+%2B+1&assumption=%22i%22+-%3E+%22Variable%22
        assert "padding" not in conv2d_kwargs.keys()
        k = kernel_size
        d = conv2d_kwargs.get("dilation", 1)
        s = conv2d_kwargs.get("stride", 1)
        padding = (k - 1) * (d + s - 1) / (2 * s)
        assert padding == int(padding), f"padding should be an integer, was {padding:.2f}"
        padding = int(padding)

        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(kernel_size, kernel_size),
            padding=(padding, padding),
            **conv2d_kwargs
        )
        # We use LayerNorm here since the size of the input "images" may vary based on the board size
        self.norm1 = nn.LayerNorm([in_channels, height, width]) if normalize else nn.Identity()
        self.act1 = activation()

        self.conv2 = nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=(kernel_size, kernel_size),
            padding=(padding, padding),
            **conv2d_kwargs
        )
        self.norm2 = nn.LayerNorm([in_channels, height, width]) if normalize else nn.Identity()
        self.final_act = activation()

        if in_channels != out_channels:
            self.change_n_channels = nn.Conv2d(in_channels, out_channels, (1, 1))
        else:
            self.change_n_channels = nn.Identity()

        if squeeze_excitation:
            self.squeeze_excitation = SELayer(out_channels, rescale_se_input)
        else:
            self.squeeze_excitation = nn.Identity()

    def forward(self, x):
        identity = x
        x = self.conv1(x)
        x = self.act1(self.norm1(x))
        x = self.conv2(x)
        x = self.squeeze_excitation(self.norm2(x))
        x = x + self.change_n_channels(identity)
        return self.final_act(x) 

In [8]:
# The orginal makes overfit
# toad = nn.Sequential(*[ResidualBlock(in_channels = 3, 
#                                     out_channels = 3, 
#                                     height = 32, 
#                                     width = 32,
#                                     kernel_size = 3,
#                                     normalize = False,
#                                     activation = nn.LeakyReLU) for _ in range(5)],
#                     nn.Flatten(),
#                     nn.Linear(3072,1000),
#                     nn.ReLU(),
#                     nn.Linear(1000,100),
#                     nn.ReLU(),
#                     nn.Linear(100,10)).to(device=device)

toad = nn.Sequential(ResidualBlock(in_channels = 3, 
                                    out_channels = 32, 
                                    height = 32, 
                                    width = 32,
                                    kernel_size = 5,
                                    normalize = False,
                                    activation = nn.LeakyReLU),
                     ResidualBlock(in_channels = 32, 
                                    out_channels = 64, 
                                    height = 32, 
                                    width = 32,
                                    kernel_size = 5,
                                    normalize = False,
                                    activation = nn.LeakyReLU),
                    nn.MaxPool2d(2, 2),
                    ResidualBlock(in_channels = 64, 
                                    out_channels = 128, 
                                    height = 16, 
                                    width = 16,
                                    kernel_size = 5,
                                    normalize = False,
                                    activation = nn.LeakyReLU),
                     ResidualBlock(in_channels = 128, 
                                    out_channels = 128, 
                                    height = 16, 
                                    width = 16,
                                    kernel_size = 5,
                                    normalize = False,
                                    activation = nn.LeakyReLU),
                    nn.MaxPool2d(2, 2), 
                    nn.Flatten(),
                    nn.Linear(8*8*128,128),
                    nn.ReLU(),
                    nn.Linear(128,32),
                    nn.ReLU(),
                    nn.Linear(32,10)).to(device=device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(toad.parameters(), lr=1e-3, eps=0.0003)

print(count_parameters(toad))

2684362


In [9]:
num_epochs = 21

for epoch in range(num_epochs):  # loop over the dataset multiple times
    running_loss = 0.0
    for inputs, labels in trainloader:
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = toad(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    val = evaluate(toad, testloader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}, Val_acc: {val:.2f}')

print('Finished Training')

Epoch [1/21], Loss: 1.5276, Val_acc: 48.08
Epoch [2/21], Loss: 1.1505, Val_acc: 54.64
Epoch [3/21], Loss: 1.0087, Val_acc: 65.65
Epoch [4/21], Loss: 0.9043, Val_acc: 70.13
Epoch [5/21], Loss: 0.6795, Val_acc: 73.12
Epoch [6/21], Loss: 0.4675, Val_acc: 73.91
Epoch [7/21], Loss: 0.5149, Val_acc: 75.03
Epoch [8/21], Loss: 0.3057, Val_acc: 73.92
Epoch [9/21], Loss: 0.2839, Val_acc: 75.25
Epoch [10/21], Loss: 0.1420, Val_acc: 73.83
Epoch [11/21], Loss: 0.1388, Val_acc: 73.42
Epoch [12/21], Loss: 0.1576, Val_acc: 73.61
Epoch [13/21], Loss: 0.1924, Val_acc: 74.39
Epoch [14/21], Loss: 0.0829, Val_acc: 73.42
Epoch [15/21], Loss: 0.1561, Val_acc: 74.42
Epoch [16/21], Loss: 0.0849, Val_acc: 74.62
Epoch [17/21], Loss: 0.1289, Val_acc: 73.66
Epoch [18/21], Loss: 0.0394, Val_acc: 73.89
Epoch [19/21], Loss: 0.0683, Val_acc: 74.26
Epoch [20/21], Loss: 0.0133, Val_acc: 74.65
Epoch [21/21], Loss: 0.0843, Val_acc: 73.88
Finished Training


In [10]:
correct = 0
total = 0
# since we're not training, we don't need to calculate the gradients for our outputs
with torch.no_grad():
    for data in trainloader:
        images, labels = data
        # calculate outputs by running images through the network
        outputs = toad(images.to(device))
        # the class with the highest energy is what we choose as prediction
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels.to(device)).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct / total:.2f} %')

Accuracy of the network on the 10000 test images: 98.30 %


# Vim Original implementation

In [12]:
import os
import shutil
import sys

destination = shutil.copytree('/kaggle/input/vim-implementation/Vim-main/Vim-main', '/kaggle/working/vim')
#destination = shutil.copytree('/kaggle/input/vim-implementation/vim_compiled/vim', '/kaggle/working/vim')

In [13]:
!pip install -e vim/mamba-1p1p1

Obtaining file:///kaggle/working/vim/mamba-1p1p1
  Preparing metadata (setup.py) ... [?25ldone
Collecting triton (from mamba_ssm==1.1.1)
  Downloading triton-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.3 kB)
Downloading triton-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (209.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m209.5/209.5 MB[0m [31m7.8 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hInstalling collected packages: triton, mamba_ssm
  Running setup.py develop for mamba_ssm
Successfully installed mamba_ssm-1.1.1 triton-3.1.0


In [14]:
!pip install -e vim/causal-conv1d-1.1.0

Obtaining file:///kaggle/working/vim/causal-conv1d-1.1.0
  Preparing metadata (setup.py) ... [?25ldone
Collecting buildtools (from causal_conv1d==1.1.0)
  Downloading buildtools-1.0.6.tar.gz (446 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m446.5/446.5 kB[0m [31m7.3 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25h  Preparing metadata (setup.py) ... [?25ldone
Collecting argparse (from buildtools->causal_conv1d==1.1.0)
  Downloading argparse-1.4.0-py2.py3-none-any.whl.metadata (2.8 kB)
Collecting twisted (from buildtools->causal_conv1d==1.1.0)
  Downloading twisted-24.10.0-py3-none-any.whl.metadata (20 kB)
Collecting simplejson (from buildtools->causal_conv1d==1.1.0)
  Downloading simplejson-3.19.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.2 kB)
Collecting furl (from buildtools->causal_conv1d==1.1.0)
  Downloading furl-2.1.3-py2.py3-none-any.whl.metadata (1.2 kB)
Collecting redo (from buil

In [15]:
# Add path to environment
#!export PATH="$PATH:/kaggle/working/vim/mamba-1p1p1"
sys.path.insert(0, "/kaggle/working/vim/mamba-1p1p1")
sys.path.insert(0, "/kaggle/working/vim/causal-conv1d-1.1.0")

In [16]:
from vim.vim.models_mamba import VisionMamba

#vimamba = VisionMamba(img_size=32, patch_size=4, stride=2, depth = 5, emed_dim = 192, d_state=16, channels = 3, num_classes=10).to(device=device)
vimamba = VisionMamba(img_size = 32, patch_size = 4, stride = 2, depth = 5, emed_dim = 192, d_state=16, channels = 3, num_classes=10, if_bidirectional = False, if_bimamba = False,  drop_path_rate=0.0,)
#vimamba = VisionMamba(img_size = 32, patch_size=16, stride = 8, embed_dim=192, depth=5, d_state=16, rms_norm=True, residual_in_fp32=True, fused_add_norm=True, final_pool_type='mean', if_abs_pos_embed=True, if_rope=True, if_rope_residual=True, bimamba_type="v2", if_cls_token=True)
vimamba = nn.DataParallel(vimamba)
vimamba.to('cuda')

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(vimamba.parameters(), lr=1e-3, eps=0.0003)

print(count_parameters(vimamba))

  def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
  def backward(ctx, dout):
  def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
  def backward(ctx, dout):
  def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
  def backward(ctx, dout):
  def forward(
  def backward(ctx, dout, *args):


1465354


In [17]:
num_epochs = 21
for epoch in range(num_epochs):  # loop over the dataset multiple times
    for inputs, labels in trainloader:
        # zero the parameter gradients
        optimizer.zero_grad()
        # print(inputs.dtype)
        # forward + backward + optimize
        outputs = vimamba(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    val = evaluate(vimamba, testloader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}, Val_acc: {val:.2f}')

print('Finished Training')

  with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):
  with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):


Epoch [1/21], Loss: 1.4626, Val_acc: 48.47
Epoch [2/21], Loss: 1.1322, Val_acc: 58.30
Epoch [3/21], Loss: 1.0304, Val_acc: 64.59
Epoch [4/21], Loss: 0.7788, Val_acc: 67.81
Epoch [5/21], Loss: 0.7452, Val_acc: 70.26
Epoch [6/21], Loss: 0.4161, Val_acc: 70.19
Epoch [7/21], Loss: 0.4232, Val_acc: 72.43
Epoch [8/21], Loss: 0.4424, Val_acc: 71.55
Epoch [9/21], Loss: 0.3366, Val_acc: 72.12
Epoch [10/21], Loss: 0.3010, Val_acc: 72.03
Epoch [11/21], Loss: 0.3560, Val_acc: 72.38
Epoch [12/21], Loss: 0.2432, Val_acc: 72.30
Epoch [13/21], Loss: 0.2122, Val_acc: 72.39
Epoch [14/21], Loss: 0.2409, Val_acc: 72.12
Epoch [15/21], Loss: 0.2068, Val_acc: 72.10
Epoch [16/21], Loss: 0.0801, Val_acc: 72.58
Epoch [17/21], Loss: 0.1370, Val_acc: 72.99
Epoch [18/21], Loss: 0.0842, Val_acc: 71.86
Epoch [19/21], Loss: 0.1323, Val_acc: 72.91
Epoch [20/21], Loss: 0.0703, Val_acc: 72.67
Epoch [21/21], Loss: 0.1188, Val_acc: 73.09
Finished Training


In [18]:
correct = 0
total = 0
# since we're not training, we don't need to calculate the gradients for our outputs
with torch.no_grad():
    for data in trainloader:
        images, labels = data
        # calculate outputs by running images through the network
        outputs = vimamba(images)
        # the class with the highest energy is what we choose as prediction
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels.to(device)).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct / total:.2f} %')

Accuracy of the network on the 10000 test images: 97.99 %


# SiMBA

In [6]:
!pip install einops
!pip install mamba-ssm

Collecting einops
  Downloading einops-0.8.0-py3-none-any.whl.metadata (12 kB)
Downloading einops-0.8.0-py3-none-any.whl (43 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.2/43.2 kB[0m [31m1.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.8.0
Collecting mamba-ssm
  Downloading mamba_ssm-2.2.2.tar.gz (85 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.4/85.4 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25h  Preparing metadata (setup.py) ... [?25ldone
Collecting triton (from mamba-ssm)
  Downloading triton-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.3 kB)
Downloading triton-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (209.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m209.5/209.5 MB[0m [31m7.6 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hBuilding wheels for collected package

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
import torch.fft

from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model
from timm.models.vision_transformer import _cfg
from mamba_ssm import Mamba
import math
import numpy as np
from mamba_ssm import Mamba
from einops import rearrange, repeat, einsum

  def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
  def backward(ctx, dout):
  def forward(
  def backward(ctx, dout, *args):
  def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
  def backward(ctx, grad_output):
  def forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), return_final_states=False, activation="silu",
  def backward(ctx, dout, *args):


In [3]:
class EinFFT(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.hidden_size = dim #768
        self.num_blocks = 4 
        self.block_size = self.hidden_size // self.num_blocks 
        assert self.hidden_size % self.num_blocks == 0
        self.sparsity_threshold = 0.01
        self.scale = 0.02

        self.complex_weight_1 = nn.Parameter(torch.randn(2, self.num_blocks, self.block_size, self.block_size, dtype=torch.float32) * self.scale)
        self.complex_weight_2 = nn.Parameter(torch.randn(2, self.num_blocks, self.block_size, self.block_size, dtype=torch.float32) * self.scale)
        self.complex_bias_1 = nn.Parameter(torch.randn(2, self.num_blocks, self.block_size,  dtype=torch.float32) * self.scale)
        self.complex_bias_2 = nn.Parameter(torch.randn(2, self.num_blocks, self.block_size,  dtype=torch.float32) * self.scale)

    def multiply(self, input, weights):
        return torch.einsum('...bd,bdk->...bk', input, weights)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.view(B, N, self.num_blocks, self.block_size )

        x = torch.fft.fft2(x, dim=(1,2), norm='ortho') # FFT on N dimension

        x_real_1 = F.relu(self.multiply(x.real, self.complex_weight_1[0]) - self.multiply(x.imag, self.complex_weight_1[1]) + self.complex_bias_1[0])
        x_imag_1 = F.relu(self.multiply(x.real, self.complex_weight_1[1]) + self.multiply(x.imag, self.complex_weight_1[0]) + self.complex_bias_1[1])
        x_real_2 = self.multiply(x_real_1, self.complex_weight_2[0]) - self.multiply(x_imag_1, self.complex_weight_2[1]) + self.complex_bias_2[0]
        x_imag_2 = self.multiply(x_real_1, self.complex_weight_2[1]) + self.multiply(x_imag_1, self.complex_weight_2[0]) + self.complex_bias_2[1]

        x = torch.stack([x_real_2, x_imag_2], dim=-1).float()
        x = F.softshrink(x, lambd=self.sparsity_threshold) if self.sparsity_threshold else x
        x = torch.view_as_complex(x)

        x = torch.fft.ifft2(x, dim=(1,2), norm="ortho")
        
        # RuntimeError: "fused_dropout" not implemented for 'ComplexFloat'
        x = x.to(torch.float32)
        x = x.reshape(B, N, C)
        return x

# For Fast Implementation use MambaLayer,# This implementation is slow, only for checking GFLOPS and other paramater,
# For more details please refer to https://github.com/johnma2006/mamba-minimal/blob/master/model.py     
class MambaBlock(nn.Module):
    def __init__(self, d_model, d_state = 64, expand = 2, d_conv = 4, conv_bias = True,  bias = False ):
        """A single Mamba block, as described in Figure 3 in Section 3.4 in the Mamba paper [1]."""
        super().__init__()
        self.d_model = d_model # Model dimension d_model
        self.d_state=d_state # SSM state expansion factor
        self.d_conv=d_conv  # Local convolution width
        self.expand=expand  # Block expansion factor
        self.conv_bias=conv_bias
        self.bias=bias
        self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = math.ceil(self.d_model / 16)
        

        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=self.bias)

        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            bias=self.conv_bias,
            kernel_size=self.d_conv,
            groups=self.d_inner,
            padding=self.d_conv - 1,
        )

        # x_proj takes in `x` and outputs the input-specific Δ, B, C
        self.x_proj = nn.Linear(self.d_inner, self.dt_rank + self.d_state * 2, bias=False)
        
        # dt_proj projects Δ from dt_rank to d_in
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)

        A = repeat(torch.arange(1, self.d_state + 1), 'n -> d n', d=self.d_inner)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.d_inner))
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=self.bias)
        

    def forward(self, x):
        """Mamba block forward. This looks the same as Figure 3 in Section 3.4 in the Mamba paper [1].
    
        Args:
            x: shape (b, l, d)    (See Glossary at top for definitions of b, l, d_in, n...)
    
        Returns:
            output: shape (b, l, d)
        
        Official Implementation:
            class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
            
        """
        (b, l, d) = x.shape
        
        x_and_res = self.in_proj(x)  # shape (b, l, 2 * d_in)
        (x, res) = x_and_res.split(split_size=[self.d_inner, self.d_inner], dim=-1)

        x = rearrange(x, 'b l d_in -> b d_in l')
        x = self.conv1d(x)[:, :, :l]
        x = rearrange(x, 'b d_in l -> b l d_in')
        
        x = F.silu(x)

        y = self.ssm(x)
        
        y = y * F.silu(res)
        
        output = self.out_proj(y)

        return output

    
    def ssm(self, x):
        """Runs the SSM. See:
            - Algorithm 2 in Section 3.2 in the Mamba paper [1]
            - run_SSM(A, B, C, u) in The Annotated S4 [2]

        Args:
            x: shape (b, l, d_in)    (See Glossary at top for definitions of b, l, d_in, n...)
    
        Returns:
            output: shape (b, l, d_in)

        Official Implementation:
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
            
        """
        (d_in, n) = self.A_log.shape

        # Compute ∆ A B C D, the state space parameters.
        #     A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
        #     ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
        #                                  and is why Mamba is called **selective** state spaces)
        
        A = -torch.exp(self.A_log.float())  # shape (d_in, n)
        D = self.D.float()

        x_dbl = self.x_proj(x)  # (b, l, dt_rank + 2*n)
        
        (delta, B, C) = x_dbl.split(split_size=[self.dt_rank, n, n], dim=-1)  # delta: (b, l, dt_rank). B, C: (b, l, n)
        delta = F.softplus(self.dt_proj(delta))  # (b, l, d_in)
        
        y = self.selective_scan(x, delta, A, B, C, D)  # This is similar to run_SSM(A, B, C, u) in The Annotated S4 [2]
        
        return y

    
    def selective_scan(self, u, delta, A, B, C, D):
        """Does selective scan algorithm. See:
            - Section 2 State Space Models in the Mamba paper [1]
            - Algorithm 2 in Section 3.2 in the Mamba paper [1]
            - run_SSM(A, B, C, u) in The Annotated S4 [2]

        This is the classic discrete state space formula:
            x(t + 1) = Ax(t) + Bu(t)
            y(t)     = Cx(t) + Du(t)
        except B and C (and the step size delta, which is used for discretization) are dependent on the input x(t).
    
        Args:
            u: shape (b, l, d_in)    (See Glossary at top for definitions of b, l, d_in, n...)
            delta: shape (b, l, d_in)
            A: shape (d_in, n)
            B: shape (b, l, n)
            C: shape (b, l, n)
            D: shape (d_in,)
    
        Returns:
            output: shape (b, l, d_in)
    
        Official Implementation:
            selective_scan_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86
            Note: I refactored some parts out of `selective_scan_ref` out, so the functionality doesn't match exactly.
            
        """
        (b, l, d_in) = u.shape
        n = A.shape[1]
        
        # Discretize continuous parameters (A, B)
        # - A is discretized using zero-order hold (ZOH) discretization (see Section 2 Equation 4 in the Mamba paper [1])
        # - B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with authors:
        #   "A is the more important term and the performance doesn't change much with the simplification on B"
        deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
        deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n')
        
        # Perform selective scan (see scan_SSM() in The Annotated S4 [2])
        # Note that the below is sequential, while the official implementation does a much faster parallel scan that
        # is additionally hardware-aware (like FlashAttention).
        x = torch.zeros((b, d_in, n), device=deltaA.device)
        ys = []    
        for i in range(l):
            x = deltaA[:, i] * x + deltaB_u[:, i]
            y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in')
            ys.append(y)
        y = torch.stack(ys, dim=1)  # shape (b, l, d_in)
        
        y = y + u * D
    
        return y


class MambaLayer(nn.Module):
    def __init__(self, dim, d_state=64, d_conv=4, expand=2):
        super().__init__()
        self.dim = dim
        self.norm = nn.LayerNorm(dim)
        self.mamba = Mamba(
            d_model=dim,  # Model dimension d_model
            d_state=d_state,  # SSM state expansion factor
            d_conv=d_conv,  # Local convolution width
            expand=expand  # Block expansion factor
        )
    def forward(self, x):
        # print('x',x.shape)
        B, L, C = x.shape
        x_norm = self.norm(x)
        x_mamba = self.mamba(x_norm)    
        return x_mamba

def rand_bbox(size, lam, scale=1):
    W = size[1] // scale
    H = size[2] // scale
    cut_rat = np.sqrt(1. - lam)
    cut_w = np.int_(W * cut_rat)
    cut_h = np.int_(H * cut_rat)

    # uniform
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

class PVT2FFN(nn.Module):
    def __init__(self, in_features, hidden_features):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        x = self.fc1(x)
        x = self.dwconv(x, H, W)
        x = self.act(x)
        x = self.fc2(x)
        return x

class FFN(nn.Module):
    def __init__(self, in_features, hidden_features):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x

class ClassBlock(nn.Module):
    def __init__(self, dim,  mlp_ratio, norm_layer=nn.LayerNorm, cm_type = 'mlp'):
        super().__init__()
        # self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        self.attn = MambaLayer(dim) #MambaBlock(d_model=dim)
        if cm_type == 'EinFFT':
            self.mlp = EinFFT(dim)
        else:
            self.mlp = FFN(dim, int(dim * mlp_ratio))  
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        cls_embed = x[:, :1]
        cls_embed = cls_embed + self.attn(x[:, :1])
        cls_embed = cls_embed + self.mlp(self.norm2(cls_embed), H, W)
        return torch.cat([cls_embed, x[:, 1:]], dim=1)



class Block_mamba(nn.Module):
    def __init__(self, 
        dim, 
        mlp_ratio,
        drop_path=0., 
        norm_layer=nn.LayerNorm, 
        sr_ratio=1, 
        cm_type = 'mlp'
    ):
        super().__init__()
        # self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        self.attn = MambaLayer(dim) #MambaBlock(d_model=dim)
        if cm_type == 'EinFFT':
            self.mlp = EinFFT(dim)
        else:
            self.mlp = PVT2FFN(in_features=dim, hidden_features=int(dim * mlp_ratio))       
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        x = x + self.drop_path(self.attn(x))
        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
        return x



class DownSamples(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
        self.norm = nn.LayerNorm(out_channels)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x, H, W

class Stem(nn.Module):
    def __init__(self, in_channels, stem_hidden_dim, out_channels):
        super().__init__()
        hidden_dim = stem_hidden_dim
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, hidden_dim, kernel_size=7, stride=2,
                      padding=3, bias=False),  # 112x112
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1,
                      padding=1, bias=False),  # 112x112
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1,
                      padding=1, bias=False),  # 112x112
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
        )
        self.proj = nn.Conv2d(hidden_dim,
                              out_channels,
                              kernel_size=3,
                              stride=2,
                              padding=1)
        self.norm = nn.LayerNorm(out_channels)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        x = self.conv(x)
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x, H, W

class SiMBA(nn.Module):
    def __init__(self, 
        in_chans=3, 
        num_classes=1000, 
        stem_hidden_dim = 32,
        embed_dims=[64, 128, 320, 448],
        mlp_ratios=[8, 8, 4, 4], 
        drop_path_rate=0., 
        norm_layer=nn.LayerNorm,
        depths=[3, 4, 6, 3], 
        sr_ratios=[4, 2, 1, 1], 
        num_stages=4,
        token_label=True,
        cm_type='mlp',
        **kwargs
    ):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths
        self.num_stages = num_stages

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        cur = 0
        alpha=5#
        for i in range(num_stages):
            if i == 0:
                patch_embed = Stem(in_chans, stem_hidden_dim, embed_dims[i])
            else:
                patch_embed = DownSamples(embed_dims[i - 1], embed_dims[i])

            block = nn.ModuleList([Block_mamba(
                dim = embed_dims[i], 
                mlp_ratio = mlp_ratios[i], 
                drop_path=dpr[cur + j], 
                norm_layer=norm_layer,
                sr_ratio = sr_ratios[i],
                cm_type=cm_type)   # Change this to run EinFFT based Channel Mixer, cm_type='EinFFT'
            for j in range(depths[i])])

            norm = norm_layer(embed_dims[i])
            cur += depths[i]

            setattr(self, f"patch_embed{i + 1}", patch_embed)
            setattr(self, f"block{i + 1}", block)
            setattr(self, f"norm{i + 1}", norm)

        post_layers = ['ca']
        self.post_network = nn.ModuleList([
            ClassBlock(
                dim = embed_dims[-1], 
                mlp_ratio = mlp_ratios[-1],
                norm_layer=norm_layer,
                cm_type=cm_type) # Change this to run EinFFT based Channel Mixer, cm_type='EinFFT'
            for _ in range(len(post_layers))
        ])

        # classification head
        self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
        ##################################### token_label #####################################
        self.return_dense = token_label
        self.mix_token = token_label
        self.beta = 1.0
        self.pooling_scale = 8
        if self.return_dense:
            self.aux_head = nn.Linear(
                embed_dims[-1],
                num_classes) if num_classes > 0 else nn.Identity()
        ##################################### token_label #####################################

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward_cls(self, x, H, W):
        B, N, C = x.shape
        cls_tokens = x.mean(dim=1, keepdim=True)
        x = torch.cat((cls_tokens, x), dim=1)
        for block in self.post_network:
            x = block(x, H, W)
        return x

    def forward_features(self, x):
        B = x.shape[0]
        for i in range(self.num_stages):
            patch_embed = getattr(self, f"patch_embed{i + 1}")
            block = getattr(self, f"block{i + 1}")
            x, H, W = patch_embed(x)
            for blk in block:
                x = blk(x, H, W)
            
            if i != self.num_stages - 1:
                norm = getattr(self, f"norm{i + 1}")
                x = norm(x)
                x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()

        x = self.forward_cls(x, H, W)[:, 0]
        norm = getattr(self, f"norm{self.num_stages}")
        x = norm(x)
        return x

    def forward(self, x):
        if not self.return_dense:
            x = self.forward_features(x)
            x = self.head(x)
            return x
        else:
            x, H, W = self.forward_embeddings(x)
            # mix token, see token labeling for details.
            if self.mix_token and self.training:
                lam = np.random.beta(self.beta, self.beta)
                patch_h, patch_w = x.shape[1] // self.pooling_scale, x.shape[
                    2] // self.pooling_scale
                bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam, scale=self.pooling_scale)
                temp_x = x.clone()
                sbbx1,sbby1,sbbx2,sbby2=self.pooling_scale*bbx1,self.pooling_scale*bby1,\
                                        self.pooling_scale*bbx2,self.pooling_scale*bby2
                temp_x[:, sbbx1:sbbx2, sbby1:sbby2, :] = x.flip(0)[:, sbbx1:sbbx2, sbby1:sbby2, :]
                x = temp_x
            else:
                bbx1, bby1, bbx2, bby2 = 0, 0, 0, 0
            x = self.forward_tokens(x, H, W)
            x_cls = self.head(x[:, 0])
            x_aux = self.aux_head(
                x[:, 1:]
            )  # generate classes in all feature tokens, see token labeling

            if not self.training:
                return x_cls + 0.5 * x_aux.max(1)[0]

            if self.mix_token and self.training:  # reverse "mix token", see token labeling for details.
                x_aux = x_aux.reshape(x_aux.shape[0], patch_h, patch_w, x_aux.shape[-1])

                temp_x = x_aux.clone()
                temp_x[:, bbx1:bbx2, bby1:bby2, :] = x_aux.flip(0)[:, bbx1:bbx2, bby1:bby2, :]
                x_aux = temp_x

                x_aux = x_aux.reshape(x_aux.shape[0], patch_h * patch_w, x_aux.shape[-1])

            return x_cls, x_aux, (bbx1, bby1, bbx2, bby2)

    def forward_tokens(self, x, H, W):
        B = x.shape[0]
        x = x.view(B, -1, x.size(-1))

        for i in range(self.num_stages):
            if i != 0:
                patch_embed = getattr(self, f"patch_embed{i + 1}")
                x, H, W = patch_embed(x)
            block = getattr(self, f"block{i + 1}")
            for blk in block:
                x = blk(x, H, W)
            if i != self.num_stages - 1:
                norm = getattr(self, f"norm{i + 1}")
                x = norm(x)
                x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        x = self.forward_cls(x, H, W)
        norm = getattr(self, f"norm{self.num_stages}")
        x = norm(x)    
        return x

    def forward_embeddings(self, x):
        patch_embed = getattr(self, f"patch_embed{0 + 1}")
        x, H, W = patch_embed(x)
        x = x.view(x.size(0), H, W, -1)
        return x, H, W


class DWConv(nn.Module):
    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.transpose(1, 2).view(B, C, H, W)
        x = self.dwconv(x)
        x = x.flatten(2).transpose(1, 2)
        return x

@register_model
def simba_s(pretrained=False, **kwargs):
    model = SiMBA(
        stem_hidden_dim = 32,
        embed_dims = [64, 128, 320, 448], 
        mlp_ratios = [8, 8, 4, 4],
        norm_layer = partial(nn.LayerNorm, eps=1e-6), 
        depths = [3, 4, 6, 3], 
        sr_ratios = [4, 2, 1, 1], 
        **kwargs)
    model.default_cfg = _cfg()
    return model

@register_model
def simba_b(pretrained=False, **kwargs):
    model = SiMBA(
        stem_hidden_dim = 64,
        embed_dims = [64, 128, 320, 512], 
        mlp_ratios = [8, 8, 4, 4], 
        norm_layer = partial(nn.LayerNorm, eps=1e-6), 
        depths = [3, 4, 12, 3], 
        sr_ratios = [4, 2, 1, 1], 
        **kwargs)
    model.default_cfg = _cfg()
    return model

@register_model
def simba_l(pretrained=False, **kwargs):
    model = SiMBA(
        stem_hidden_dim = 64,
        embed_dims = [96, 192, 384, 512],
        mlp_ratios = [8, 8, 4, 4],
        norm_layer = partial(nn.LayerNorm, eps=1e-6), 
        depths = [3, 6, 18, 3], 
        sr_ratios = [4, 2, 1, 1], 
        **kwargs)
    model.default_cfg = _cfg()
    return model

In [58]:
simba = SiMBA(
        stem_hidden_dim = 32,
        num_classes=10, 
        embed_dims = [64, 128], 
        mlp_ratios = [4, 2],
        norm_layer = partial(nn.LayerNorm, eps=1e-6), 
        depths = [2, 3],
        num_stages = 2,
        sr_ratios = [4, 2],
        cm_type = 'mlp',
        token_label = False)
simba = nn.DataParallel(simba)
simba.to('cuda')


criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(simba.parameters(), lr=1e-3, eps=0.0003)

print(count_parameters(simba))

1178474


In [52]:
num_epochs = 21
for epoch in range(num_epochs):  # loop over the dataset multiple times
    for inputs, labels in trainloader:
        # zero the parameter gradients
        optimizer.zero_grad()
        # print(inputs.dtype)
        # forward + backward + optimize
        outputs = simba(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    val = evaluate(simba, testloader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}, Val_acc: {val:.2f}')

print('Finished Training')

Epoch [1/21], Loss: 1.4239, Val_acc: 52.56
Epoch [2/21], Loss: 1.3047, Val_acc: 62.18
Epoch [3/21], Loss: 1.0611, Val_acc: 67.18
Epoch [4/21], Loss: 0.9034, Val_acc: 68.87
Epoch [5/21], Loss: 0.7551, Val_acc: 71.70
Epoch [6/21], Loss: 0.8602, Val_acc: 71.66
Epoch [7/21], Loss: 0.3571, Val_acc: 73.08
Epoch [8/21], Loss: 0.5105, Val_acc: 74.92
Epoch [9/21], Loss: 0.3820, Val_acc: 73.97
Epoch [10/21], Loss: 0.4261, Val_acc: 73.82
Epoch [11/21], Loss: 0.2061, Val_acc: 73.37
Epoch [12/21], Loss: 0.2671, Val_acc: 73.38
Epoch [13/21], Loss: 0.2747, Val_acc: 75.66
Epoch [14/21], Loss: 0.0868, Val_acc: 75.05
Epoch [15/21], Loss: 0.1246, Val_acc: 74.21
Epoch [16/21], Loss: 0.0618, Val_acc: 74.08
Epoch [17/21], Loss: 0.1509, Val_acc: 74.62
Epoch [18/21], Loss: 0.2445, Val_acc: 73.80
Epoch [19/21], Loss: 0.1824, Val_acc: 74.03
Epoch [20/21], Loss: 0.2704, Val_acc: 74.98
Epoch [21/21], Loss: 0.0336, Val_acc: 74.97
Finished Training


In [53]:
correct = 0
total = 0
# since we're not training, we don't need to calculate the gradients for our outputs
with torch.no_grad():
    for data in trainloader:
        images, labels = data
        # calculate outputs by running images through the network
        outputs = simba(images)
        # the class with the highest energy is what we choose as prediction
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels.to(device)).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct / total:.2f} %')

Accuracy of the network on the 10000 test images: 97.97 %


In [12]:
simba = SiMBA(
        stem_hidden_dim = 32,
        num_classes=10, 
        embed_dims = [32, 256], 
        mlp_ratios = [2, 4],
        norm_layer = partial(nn.LayerNorm, eps=1e-6), 
        depths = [3, 2],
        num_stages = 2,
        sr_ratios = [4, 2],
        cm_type = 'EinFFT',
        token_label = False)
simba = nn.DataParallel(simba)
simba.to('cuda')


criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(simba.parameters(), lr=1e-3, eps=0.0003)

print(count_parameters(simba))

1908746


In [13]:
num_epochs = 21
for epoch in range(num_epochs):  # loop over the dataset multiple times
    for inputs, labels in trainloader:
        # zero the parameter gradients
        optimizer.zero_grad()
        # print(inputs.dtype)
        # forward + backward + optimize
        outputs = simba(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    val = evaluate(simba, testloader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}, Val_acc: {val:.2f}')

print('Finished Training')

  with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):
  x = x.to(torch.float32)


Epoch [1/21], Loss: 1.2882, Val_acc: 51.75
Epoch [2/21], Loss: 0.9905, Val_acc: 59.78
Epoch [3/21], Loss: 0.8882, Val_acc: 65.60
Epoch [4/21], Loss: 0.8957, Val_acc: 68.15
Epoch [5/21], Loss: 0.7609, Val_acc: 69.62
Epoch [6/21], Loss: 0.6948, Val_acc: 69.75
Epoch [7/21], Loss: 0.4352, Val_acc: 71.32
Epoch [8/21], Loss: 0.4932, Val_acc: 71.01
Epoch [9/21], Loss: 0.4304, Val_acc: 70.30
Epoch [10/21], Loss: 0.3568, Val_acc: 71.08
Epoch [11/21], Loss: 0.2797, Val_acc: 70.32
Epoch [12/21], Loss: 0.3022, Val_acc: 70.65
Epoch [13/21], Loss: 0.1784, Val_acc: 70.76
Epoch [14/21], Loss: 0.3074, Val_acc: 70.86
Epoch [15/21], Loss: 0.2911, Val_acc: 70.17
Epoch [16/21], Loss: 0.1818, Val_acc: 70.92
Epoch [17/21], Loss: 0.0872, Val_acc: 70.96
Epoch [18/21], Loss: 0.1443, Val_acc: 70.31
Epoch [19/21], Loss: 0.0263, Val_acc: 71.55
Epoch [20/21], Loss: 0.2221, Val_acc: 70.91
Epoch [21/21], Loss: 0.0784, Val_acc: 71.65
Finished Training


In [14]:
correct = 0
total = 0
# since we're not training, we don't need to calculate the gradients for our outputs
with torch.no_grad():
    for data in trainloader:
        images, labels = data
        # calculate outputs by running images through the network
        outputs = simba(images)
        # the class with the highest energy is what we choose as prediction
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels.to(device)).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct / total:.2f} %')

Accuracy of the network on the 10000 test images: 98.11 %


# Refined SiMBA

In [53]:
import math
import torch.nn as nn
from mamba_ssm import Mamba
from einops.layers.torch import Rearrange
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

class EinFFT(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.hidden_size = dim #768
        self.num_blocks = 4 
        self.block_size = self.hidden_size // self.num_blocks 
        assert self.hidden_size % self.num_blocks == 0
        self.sparsity_threshold = 0.01
        self.scale = 0.02

        self.complex_weight_1 = nn.Parameter(torch.randn(2, self.num_blocks, self.block_size, self.block_size, dtype=torch.float32) * self.scale)
        self.complex_weight_2 = nn.Parameter(torch.randn(2, self.num_blocks, self.block_size, self.block_size, dtype=torch.float32) * self.scale)
        self.complex_bias_1 = nn.Parameter(torch.randn(2, self.num_blocks, self.block_size,  dtype=torch.float32) * self.scale)
        self.complex_bias_2 = nn.Parameter(torch.randn(2, self.num_blocks, self.block_size,  dtype=torch.float32) * self.scale)

    def multiply(self, input, weights):
        return torch.einsum('...bd,bdk->...bk', input, weights)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.view(B, N, self.num_blocks, self.block_size )

        x = torch.fft.fft2(x, dim=(1,2), norm='ortho') # FFT on N dimension

        x_real_1 = F.relu(self.multiply(x.real, self.complex_weight_1[0]) - self.multiply(x.imag, self.complex_weight_1[1]) + self.complex_bias_1[0])
        x_imag_1 = F.relu(self.multiply(x.real, self.complex_weight_1[1]) + self.multiply(x.imag, self.complex_weight_1[0]) + self.complex_bias_1[1])
        x_real_2 = self.multiply(x_real_1, self.complex_weight_2[0]) - self.multiply(x_imag_1, self.complex_weight_2[1]) + self.complex_bias_2[0]
        x_imag_2 = self.multiply(x_real_1, self.complex_weight_2[1]) + self.multiply(x_imag_1, self.complex_weight_2[0]) + self.complex_bias_2[1]

        x = torch.stack([x_real_2, x_imag_2], dim=-1).float()
        x = F.softshrink(x, lambd=self.sparsity_threshold) if self.sparsity_threshold else x
        x = torch.view_as_complex(x)

        x = torch.fft.ifft2(x, dim=(1,2), norm="ortho")
        
        # RuntimeError: "fused_dropout" not implemented for 'ComplexFloat'
        x = x.to(torch.float32)
        x = x.reshape(B, N, C)
        return x


class PVT2FFN(nn.Module):
    def __init__(self, in_features, hidden_features):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        x = self.fc1(x)
        x = self.dwconv(x, H, W)
        x = self.act(x)
        x = self.fc2(x)
        return x

class DWConv(nn.Module):
    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        #print(x.shape)
        B, N, C = x.shape
        x = x.transpose(1, 2).view(B, C, H, W)
        x = self.dwconv(x)
        x = x.flatten(2).transpose(1, 2)
        return x



class RSimbaBlock(nn.Module):
    def __init__(self, embed_dim, cm_type='mlp'):
        super().__init__()
        self.mamba = Mamba(d_model=embed_dim, d_state=16, d_conv=8, expand=2)
        if cm_type == 'mlp':
            self.mlp = PVT2FFN(in_features=embed_dim, hidden_features=int(embed_dim * 2))
        else:
            self.mlp = EinFFT(embed_dim)
        self.norm = nn.LayerNorm(embed_dim)
        

    def forward(self, x):
        #print(x.shape)
        x = self.mamba(x) + x + self.mlp(x,8,8)
        return self.norm(x)

class RSimbaBackBone(nn.Module):
    def __init__(self, embed_dim, n_layers, seq_len=None, global_pool=True, cm_type='mlp'):
        super().__init__()
        self.blocks = nn.Sequential(*[RSimbaBlock(embed_dim) for _ in range(n_layers)])
        self.global_pool = global_pool #for classification or other supervised learning.

    def forward(self, x):
        #for input (bs, n, d) it returns either (bs, n, d) or (bs, d) is global_pool
        out = self.blocks(x) if not self.global_pool else torch.mean(self.blocks(x),1)
        return out


class RSiMBA(nn.Module):
    def __init__(self, patch_size=4, img_size=32, n_channels=3, embed_dim=128, n_layers=6, cm_type='mlp'):
        super().__init__()

        patch_dim = n_channels*patch_size*patch_size
        self.rearrange = Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)',
                                   p1=patch_size, p2=patch_size)

        self.func = nn.Sequential(self.rearrange,
                                  nn.Linear(patch_dim, embed_dim),
                                  RSimbaBackBone(embed_dim,6),
                                  nn.Linear(embed_dim, 10))

    def forward(self, x):
        #print(x.shape)
        return self.func(x)

In [54]:
simba = RSiMBA()
simba = nn.DataParallel(simba)
simba.to('cuda')

    
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(simba.parameters(), lr=1e-3, eps=0.0003)

print(count_parameters(simba))

1125002


In [62]:
xxx = Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)',
                                   p1=4, p2=4)
test = torch.rand(128,3,32,32)
print(test.shape)
test = xxx(test)
xxx = Rearrange(' b (h w) (c p1 p2)  -> b c (h p1) (w p2)',
                                   p1=4, p2=4, h=8,w=8, c = 3)

print(test.shape)
test = xxx(test)
print(test.shape)

torch.Size([128, 3, 32, 32])
torch.Size([128, 64, 48])
torch.Size([128, 3, 32, 32])


In [55]:
num_epochs = 21
for epoch in range(num_epochs):  # loop over the dataset multiple times
    for inputs, labels in trainloader:
        # zero the parameter gradients
        optimizer.zero_grad()
        #print(inputs.shape)
        # forward + backward + optimize
        outputs = simba(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    val = evaluate(simba, testloader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}, Val_acc: {val:.2f}')

print('Finished Training')

Epoch [1/21], Loss: 1.2045, Val_acc: 53.11
Epoch [2/21], Loss: 1.1145, Val_acc: 61.73
Epoch [3/21], Loss: 0.6917, Val_acc: 68.06
Epoch [4/21], Loss: 0.6500, Val_acc: 71.66
Epoch [5/21], Loss: 0.7271, Val_acc: 74.23
Epoch [6/21], Loss: 0.5736, Val_acc: 74.88
Epoch [7/21], Loss: 0.3483, Val_acc: 75.22
Epoch [8/21], Loss: 0.4633, Val_acc: 74.03
Epoch [9/21], Loss: 0.3855, Val_acc: 73.63
Epoch [10/21], Loss: 0.1300, Val_acc: 75.00
Epoch [11/21], Loss: 0.1753, Val_acc: 73.84
Epoch [12/21], Loss: 0.1425, Val_acc: 74.85
Epoch [13/21], Loss: 0.0887, Val_acc: 74.27
Epoch [14/21], Loss: 0.1194, Val_acc: 74.68
Epoch [15/21], Loss: 0.0904, Val_acc: 75.26
Epoch [16/21], Loss: 0.1376, Val_acc: 75.69
Epoch [17/21], Loss: 0.0581, Val_acc: 74.50
Epoch [18/21], Loss: 0.1201, Val_acc: 76.09
Epoch [19/21], Loss: 0.0202, Val_acc: 75.10
Epoch [20/21], Loss: 0.2076, Val_acc: 74.93
Epoch [21/21], Loss: 0.0319, Val_acc: 75.70
Finished Training


In [56]:
correct = 0
total = 0
# since we're not training, we don't need to calculate the gradients for our outputs
with torch.no_grad():
    for data in trainloader:
        images, labels = data
        # calculate outputs by running images through the network
        outputs = simba(images)
        # the class with the highest energy is what we choose as prediction
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels.to(device)).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct / total:.2f} %')

Accuracy of the network on the 10000 test images: 98.36 %
