## Effieciecny and accuracy improvement
This code is run on GPU. This code is build to make the testing process faster

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
%cd drive/MyDrive/ece_project2
!pwd

In [3]:
# #!/bin/bash
# !kaggle datasets download abhinavnayak/catsvdogs-transformed -p "/content/drive/MyDrive/ece_project2"

In [4]:
# import zipfile
# with zipfile.ZipFile('/content/drive/MyDrive/ece_project2/catsvdogs-transformed.zip', 'r') as zip_ref:
#     zip_ref.extractall('/content/drive/MyDrive/ece_project2/cats-dogs')

### Data Preprocessing

In [5]:
import os
import shutil

base_dir = "cats-dogs/train_transformed"
output_dir = "dataset"

cats_dir = os.path.join(output_dir, "cats")
dogs_dir = os.path.join(output_dir, "dogs")

os.makedirs(cats_dir, exist_ok=True)
os.makedirs(dogs_dir, exist_ok=True)


for filename in os.listdir(base_dir):
    if filename.lower().startswith("cat"):
        shutil.move(os.path.join(base_dir, filename), os.path.join(cats_dir, filename))
    elif filename.lower().startswith("dog"):
        shutil.move(os.path.join(base_dir, filename), os.path.join(dogs_dir, filename))

print("Files have been successfully sorted into 'cats' and 'dogs' folders.")

In [6]:
from torchvision import transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder

dest_dataset_dir = "dataset"

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

dataset = ImageFolder(root=dest_dataset_dir, transform=transform)

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

### Get Vision Mamba

In [7]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange

class SimplifiedMamba(nn.Module):
    def __init__(self, d_model, d_state=16, d_conv=4, expand=2, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(expand * d_model)

        # Input projection
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, **factory_kwargs)

        # Convolution for local mixing
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,
            **factory_kwargs,
        )
        self.activation = nn.SiLU()

        # Output projection
        self.out_proj = nn.Linear(self.d_inner, d_model, **factory_kwargs)

    def forward(self, hidden_states):
        """
        hidden_states: (batch, seq_len, d_model)
        """
        batch, seqlen, dim = hidden_states.shape

        xz = self.in_proj(hidden_states)
        x, z = xz.chunk(2, dim=-1)

        x = self.conv1d(x.transpose(1, 2)).transpose(1, 2)
        x = self.activation(x)

        if x.shape[1] != z.shape[1]:
            x = x[:, :z.shape[1], :]

        out = self.out_proj(x * z)
        return out


In [None]:
import torch
import torch.nn as nn
from functools import partial
from timm.models.layers import DropPath, trunc_normal_


class SimplifiedVisionMamba(nn.Module):
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        stride=16,
        depth=12,
        embed_dim=192,
        num_classes=1000,
        d_state=16,
        drop_rate=0.1,
        drop_path_rate=0.1,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.embed_dim = embed_dim

        # Patch Embedding
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            stride=stride,
            in_chans=3,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches

        # Positional Embeddings
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        # Transformer Blocks
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.layers = nn.ModuleList(
            [
                Block(
                    dim=embed_dim,
                    mixer_cls=partial(SimplifiedMamba, d_state=d_state),
                    norm_cls=nn.LayerNorm,
                    drop_path=dpr[i],
                )
                for i in range(depth)
            ]
        )


        # Final Classifier Head
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        # Initialize Weights
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.ones_(m.weight)

    def forward(self, x):
        # Patch Embedding
        x = self.patch_embed(x)

        # Add Positional Embeddings
        x = x + self.pos_embed
        x = self.pos_drop(x)

        # Transformer Layers
        for layer in self.layers:
            x, _ = layer(x)

        # Classification Head
        x = self.norm(x[:, 0])  # CLS token output
        x = self.head(x)
        return x


class PatchEmbed(nn.Module):
    """Simple 2D Patch Embedding"""
    def __init__(self, img_size=224, patch_size=16, stride=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (
            (img_size - patch_size) // stride + 1,
            (img_size - patch_size) // stride + 1,
        )
        self.num_patches = self.grid_size[0] * self.grid_size[1]

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=stride
        )

    def forward(self, x):
        x = self.proj(x)  # Convert to patches
        x = x.flatten(2).transpose(1, 2)  # Flatten into patch tokens
        return x


class Block(nn.Module):
    """Transformer Block with Mamba Mixer"""
    def __init__(self, dim, mixer_cls, norm_cls=nn.LayerNorm, drop_path=0.):
        super().__init__()
        self.mixer = mixer_cls(dim)
        self.norm1 = norm_cls(dim)
        self.norm2 = norm_cls(dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x, residual=None):
        # Self-attention with residual
        residual = x
        x = self.norm1(x)
        x = self.mixer(x)
        x = residual + self.drop_path(x)  # Add & Norm

        # MLP with residual
        residual = x
        x = self.norm2(x)
        return x, residual


In [9]:
# !pip install fvcore

### Training and evaluation

#### Full patches train and test

In [10]:
import torch
import torch.optim as optim
import torch.nn as nn
from fvcore.nn import FlopCountAnalysis

# Training function
def train(model, train_loader, criterion, optimizer, device, num_epochs=20):
    total_flops = 0
    for epoch in range(num_epochs):
        model.train()  # Set the model to training mode
        train_loss = 0
        train_correct = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            if total_flops == 0:
                flop_analyzer = FlopCountAnalysis(model, images)
                total_flops = flop_analyzer.total()
                print(f"FLOPs per forward pass: {total_flops} FLOPs")

            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, preds = torch.max(outputs, 1)
            train_correct += (preds == labels).sum().item()

        train_accuracy = 100 * train_correct / len(train_loader.dataset)
        avg_train_loss = train_loss / len(train_loader)
        print(f"Epoch {epoch + 1}/{num_epochs}: Train Loss: {avg_train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%")

    print("Training complete!")

def test(model, test_loader, criterion, device, num_iterations=10):
    total_test_loss = 0
    total_test_accuracy = 0

    for iteration in range(num_iterations):
        model.eval()
        test_loss = 0
        test_correct = 0

        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)

                outputs = model(images)
                loss = criterion(outputs, labels)

                test_loss += loss.item()
                _, preds = torch.max(outputs, 1)
                test_correct += (preds == labels).sum().item()

        test_accuracy = 100 * test_correct / len(test_loader.dataset)
        avg_test_loss = test_loss / len(test_loader)

        print(f"Iteration {iteration + 1}: Test Loss: {avg_test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")
        total_test_loss += avg_test_loss
        total_test_accuracy += test_accuracy

    avg_loss = total_test_loss / num_iterations
    avg_accuracy = total_test_accuracy / num_iterations
    print(f"\nAverage Test Loss over {num_iterations} runs: {avg_loss:.4f}")
    print(f"Average Test Accuracy over {num_iterations} runs: {avg_accuracy:.2f}%")


In [11]:
model = SimplifiedVisionMamba(num_classes=2)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)


train(model, train_loader, criterion, optimizer, device, num_epochs=35)

test(model, test_loader, criterion, device)




FLOPs per forward pass: 17833961472 FLOPs
Epoch 1/35: Train Loss: 0.6925, Train Accuracy: 53.14%
Epoch 2/35: Train Loss: 0.6923, Train Accuracy: 54.23%
Epoch 3/35: Train Loss: 0.6756, Train Accuracy: 55.92%
Epoch 4/35: Train Loss: 0.6738, Train Accuracy: 57.74%
Epoch 5/35: Train Loss: 0.6666, Train Accuracy: 57.54%
Epoch 6/35: Train Loss: 0.6625, Train Accuracy: 59.97%
Epoch 7/35: Train Loss: 0.6616, Train Accuracy: 57.20%
Epoch 8/35: Train Loss: 0.6547, Train Accuracy: 58.55%
Epoch 9/35: Train Loss: 0.6680, Train Accuracy: 55.92%
Epoch 10/35: Train Loss: 0.6433, Train Accuracy: 61.46%
Epoch 11/35: Train Loss: 0.6407, Train Accuracy: 60.45%
Epoch 12/35: Train Loss: 0.6364, Train Accuracy: 60.11%
Epoch 13/35: Train Loss: 0.6282, Train Accuracy: 62.34%
Epoch 14/35: Train Loss: 0.6347, Train Accuracy: 61.60%
Epoch 15/35: Train Loss: 0.6227, Train Accuracy: 63.08%
Epoch 16/35: Train Loss: 0.6121, Train Accuracy: 64.77%
Epoch 17/35: Train Loss: 0.5998, Train Accuracy: 63.35%
Epoch 18/35: Tr

#### Drop 50% patches train and test

In [12]:
import torch
import torch.nn as nn
from functools import partial
from timm.models.layers import DropPath, trunc_normal_


class SimplifiedVisionMambaWithDrop(nn.Module):
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        stride=16,
        depth=12,
        embed_dim=192,
        num_classes=1000,
        d_state=16,
        drop_rate=0.1,
        drop_path_rate=0.1,
        drop_patch_prob=0.5,  # Probability to drop patches
    ):
        super().__init__()
        self.num_classes = num_classes
        self.embed_dim = embed_dim

        # Patch Embedding with patch dropping
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            stride=stride,
            in_chans=3,
            embed_dim=embed_dim,
            drop_patch_prob=drop_patch_prob,
        )
        num_patches = self.patch_embed.num_patches

        # Positional Embeddings
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        # Transformer Blocks
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.layers = nn.ModuleList(
            [
                Block(
                    dim=embed_dim,
                    mixer_cls=partial(SimplifiedMamba, d_state=d_state),
                    norm_cls=nn.LayerNorm,
                    drop_path=dpr[i],
                )
                for i in range(depth)
            ]
        )

        # Final Classifier Head
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        # Initialize Weights
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.ones_(m.weight)

    def forward(self, x):
        # Patch Embedding
        x = self.patch_embed(x)

        # Add Positional Embeddings
        x = x + self.pos_embed[:, :x.size(1), :]  # Adjust for dropped patches
        x = self.pos_drop(x)

        # Transformer Layers
        for layer in self.layers:
            x, _ = layer(x)

        # Classification Head
        x = self.norm(x[:, 0])  # CLS token output
        x = self.head(x)
        return x


class PatchEmbed(nn.Module):
    """2D Patch Embedding with Patch Dropping"""
    def __init__(self, img_size=224, patch_size=16, stride=16, in_chans=3, embed_dim=768, drop_patch_prob=0.5):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (
            (img_size - patch_size) // stride + 1,
            (img_size - patch_size) // stride + 1,
        )
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.drop_patch_prob = drop_patch_prob

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=stride
        )

    def forward(self, x):
        x = self.proj(x)  # Convert to patches
        x = x.flatten(2).transpose(1, 2)  # Flatten into patch tokens [B, num_patches, embed_dim]

        # Drop patches probabilistically
        if self.drop_patch_prob > 0.0 and self.training:
            batch_size, num_patches, _ = x.size()
            patch_mask = torch.rand(num_patches, device=x.device) > self.drop_patch_prob
            x = x[:, patch_mask, :]

        return x


class Block(nn.Module):
    """Transformer Block with Mamba Mixer"""
    def __init__(self, dim, mixer_cls, norm_cls=nn.LayerNorm, drop_path=0.):
        super().__init__()
        self.mixer = mixer_cls(dim)
        self.norm1 = norm_cls(dim)
        self.norm2 = norm_cls(dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x, residual=None):
        # Self-attention with residual
        residual = x
        x = self.norm1(x)
        x = self.mixer(x)
        x = residual + self.drop_path(x)  # Add & Norm

        # MLP with residual
        residual = x
        x = self.norm2(x)
        return x, residual


class SimplifiedMamba(nn.Module):
    """A simplified version of Mamba Mixer."""
    def __init__(self, dim, d_state=16):
        super().__init__()
        self.fc1 = nn.Linear(dim, d_state)
        self.act = nn.ReLU()
        self.fc2 = nn.Linear(d_state, dim)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))


In [13]:
model = SimplifiedVisionMambaWithDrop(num_classes=2)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

train(model, train_loader, criterion, optimizer, device, num_epochs=35)

test(model, test_loader, criterion, device)



FLOPs per forward pass: 1215965184 FLOPs
Epoch 1/35: Train Loss: 0.6999, Train Accuracy: 51.66%
Epoch 2/35: Train Loss: 0.6867, Train Accuracy: 54.16%
Epoch 3/35: Train Loss: 0.6882, Train Accuracy: 55.10%
Epoch 4/35: Train Loss: 0.6810, Train Accuracy: 54.16%
Epoch 5/35: Train Loss: 0.6850, Train Accuracy: 54.83%
Epoch 6/35: Train Loss: 0.6815, Train Accuracy: 55.58%
Epoch 7/35: Train Loss: 0.6799, Train Accuracy: 54.90%
Epoch 8/35: Train Loss: 0.6767, Train Accuracy: 55.17%
Epoch 9/35: Train Loss: 0.6766, Train Accuracy: 54.50%
Epoch 10/35: Train Loss: 0.6756, Train Accuracy: 56.19%
Epoch 11/35: Train Loss: 0.6762, Train Accuracy: 57.94%
Epoch 12/35: Train Loss: 0.6740, Train Accuracy: 56.05%
Epoch 13/35: Train Loss: 0.6671, Train Accuracy: 56.32%
Epoch 14/35: Train Loss: 0.6696, Train Accuracy: 57.40%
Epoch 15/35: Train Loss: 0.6737, Train Accuracy: 56.39%
Epoch 16/35: Train Loss: 0.6656, Train Accuracy: 58.08%
Epoch 17/35: Train Loss: 0.6723, Train Accuracy: 56.86%
Epoch 18/35: Tra

#### Drop 80% patches train and test

In [14]:
import torch
import torch.nn as nn
from functools import partial
from timm.models.layers import DropPath, trunc_normal_


class SimplifiedVisionMambaWith80Drop(nn.Module):
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        stride=16,
        depth=12,
        embed_dim=192,
        num_classes=1000,
        d_state=16,
        drop_rate=0.1,
        drop_path_rate=0.1,
        drop_patch_prob=0.8,  # Probability to drop patches
    ):
        super().__init__()
        self.num_classes = num_classes
        self.embed_dim = embed_dim

        # Patch Embedding with patch dropping
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            stride=stride,
            in_chans=3,
            embed_dim=embed_dim,
            drop_patch_prob=drop_patch_prob,
        )
        num_patches = self.patch_embed.num_patches

        # Positional Embeddings
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        # Transformer Blocks
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.layers = nn.ModuleList(
            [
                Block(
                    dim=embed_dim,
                    mixer_cls=partial(SimplifiedMamba, d_state=d_state),
                    norm_cls=nn.LayerNorm,
                    drop_path=dpr[i],
                )
                for i in range(depth)
            ]
        )

        # Final Classifier Head
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        # Initialize Weights
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.ones_(m.weight)

    def forward(self, x):
        # Patch Embedding
        x = self.patch_embed(x)

        # Add Positional Embeddings
        x = x + self.pos_embed[:, :x.size(1), :]  # Adjust for dropped patches
        x = self.pos_drop(x)

        # Transformer Layers
        for layer in self.layers:
            x, _ = layer(x)

        # Classification Head
        x = self.norm(x[:, 0])  # CLS token output
        x = self.head(x)
        return x


class PatchEmbed(nn.Module):
    """2D Patch Embedding with Patch Dropping"""
    def __init__(self, img_size=224, patch_size=16, stride=16, in_chans=3, embed_dim=768, drop_patch_prob=0.8):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (
            (img_size - patch_size) // stride + 1,
            (img_size - patch_size) // stride + 1,
        )
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.drop_patch_prob = drop_patch_prob

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=stride
        )

    def forward(self, x):
        x = self.proj(x)  # Convert to patches
        x = x.flatten(2).transpose(1, 2)  # Flatten into patch tokens [B, num_patches, embed_dim]

        # Drop patches probabilistically
        if self.drop_patch_prob > 0.0 and self.training:
            batch_size, num_patches, _ = x.size()
            patch_mask = torch.rand(num_patches, device=x.device) > self.drop_patch_prob
            x = x[:, patch_mask, :]

        return x


class Block(nn.Module):
    """Transformer Block with Mamba Mixer"""
    def __init__(self, dim, mixer_cls, norm_cls=nn.LayerNorm, drop_path=0.):
        super().__init__()
        self.mixer = mixer_cls(dim)
        self.norm1 = norm_cls(dim)
        self.norm2 = norm_cls(dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x, residual=None):
        # Self-attention with residual
        residual = x
        x = self.norm1(x)
        x = self.mixer(x)
        x = residual + self.drop_path(x)  # Add & Norm

        # MLP with residual
        residual = x
        x = self.norm2(x)
        return x, residual


class SimplifiedMamba(nn.Module):
    """A simplified version of Mamba Mixer."""
    def __init__(self, dim, d_state=16):
        super().__init__()
        self.fc1 = nn.Linear(dim, d_state)
        self.act = nn.ReLU()
        self.fc2 = nn.Linear(d_state, dim)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))

In [15]:
model = SimplifiedVisionMambaWith80Drop(num_classes=2)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

train(model, train_loader, criterion, optimizer, device, num_epochs=35)

test(model, test_loader, criterion, device)



FLOPs per forward pass: 1042556928 FLOPs
Epoch 1/35: Train Loss: 0.7019, Train Accuracy: 49.63%
Epoch 2/35: Train Loss: 0.6894, Train Accuracy: 53.21%
Epoch 3/35: Train Loss: 0.6939, Train Accuracy: 51.52%
Epoch 4/35: Train Loss: 0.6914, Train Accuracy: 53.21%
Epoch 5/35: Train Loss: 0.6878, Train Accuracy: 52.60%
Epoch 6/35: Train Loss: 0.6784, Train Accuracy: 54.63%
Epoch 7/35: Train Loss: 0.6849, Train Accuracy: 53.89%
Epoch 8/35: Train Loss: 0.6897, Train Accuracy: 53.62%
Epoch 9/35: Train Loss: 0.6819, Train Accuracy: 53.21%
Epoch 10/35: Train Loss: 0.6868, Train Accuracy: 51.05%
Epoch 11/35: Train Loss: 0.6838, Train Accuracy: 54.02%
Epoch 12/35: Train Loss: 0.6888, Train Accuracy: 53.96%
Epoch 13/35: Train Loss: 0.6833, Train Accuracy: 52.20%
Epoch 14/35: Train Loss: 0.6814, Train Accuracy: 54.50%
Epoch 15/35: Train Loss: 0.6788, Train Accuracy: 56.32%
Epoch 16/35: Train Loss: 0.6835, Train Accuracy: 53.35%
Epoch 17/35: Train Loss: 0.6825, Train Accuracy: 54.36%
Epoch 18/35: Tra

#### Drop 20% patches train and test

In [16]:
import torch
import torch.nn as nn
from functools import partial
from timm.models.layers import DropPath, trunc_normal_


class SimplifiedVisionMambaWith20Drop(nn.Module):
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        stride=16,
        depth=12,
        embed_dim=192,
        num_classes=1000,
        d_state=16,
        drop_rate=0.1,
        drop_path_rate=0.1,
        drop_patch_prob=0.2,  # Probability to drop patches
    ):
        super().__init__()
        self.num_classes = num_classes
        self.embed_dim = embed_dim

        # Patch Embedding with patch dropping
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            stride=stride,
            in_chans=3,
            embed_dim=embed_dim,
            drop_patch_prob=drop_patch_prob,
        )
        num_patches = self.patch_embed.num_patches

        # Positional Embeddings
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        # Transformer Blocks
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.layers = nn.ModuleList(
            [
                Block(
                    dim=embed_dim,
                    mixer_cls=partial(SimplifiedMamba, d_state=d_state),
                    norm_cls=nn.LayerNorm,
                    drop_path=dpr[i],
                )
                for i in range(depth)
            ]
        )

        # Final Classifier Head
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        # Initialize Weights
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.ones_(m.weight)

    def forward(self, x):
        # Patch Embedding
        x = self.patch_embed(x)

        # Add Positional Embeddings
        x = x + self.pos_embed[:, :x.size(1), :]  # Adjust for dropped patches
        x = self.pos_drop(x)

        # Transformer Layers
        for layer in self.layers:
            x, _ = layer(x)

        # Classification Head
        x = self.norm(x[:, 0])  # CLS token output
        x = self.head(x)
        return x


class PatchEmbed(nn.Module):
    """2D Patch Embedding with Patch Dropping"""
    def __init__(self, img_size=224, patch_size=16, stride=16, in_chans=3, embed_dim=768, drop_patch_prob=0.2):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (
            (img_size - patch_size) // stride + 1,
            (img_size - patch_size) // stride + 1,
        )
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.drop_patch_prob = drop_patch_prob

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=stride
        )

    def forward(self, x):
        x = self.proj(x)  # Convert to patches
        x = x.flatten(2).transpose(1, 2)  # Flatten into patch tokens [B, num_patches, embed_dim]

        # Drop patches probabilistically
        if self.drop_patch_prob > 0.0 and self.training:
            batch_size, num_patches, _ = x.size()
            patch_mask = torch.rand(num_patches, device=x.device) > self.drop_patch_prob
            x = x[:, patch_mask, :]

        return x


class Block(nn.Module):
    """Transformer Block with Mamba Mixer"""
    def __init__(self, dim, mixer_cls, norm_cls=nn.LayerNorm, drop_path=0.):
        super().__init__()
        self.mixer = mixer_cls(dim)
        self.norm1 = norm_cls(dim)
        self.norm2 = norm_cls(dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x, residual=None):
        # Self-attention with residual
        residual = x
        x = self.norm1(x)
        x = self.mixer(x)
        x = residual + self.drop_path(x)  # Add & Norm

        # MLP with residual
        residual = x
        x = self.norm2(x)
        return x, residual


class SimplifiedMamba(nn.Module):
    """A simplified version of Mamba Mixer."""
    def __init__(self, dim, d_state=16):
        super().__init__()
        self.fc1 = nn.Linear(dim, d_state)
        self.act = nn.ReLU()
        self.fc2 = nn.Linear(d_state, dim)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))

In [18]:
model = SimplifiedVisionMambaWith20Drop(num_classes=2)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

train(model, train_loader, criterion, optimizer, device, num_epochs=35)

test(model, test_loader, criterion, device)



FLOPs per forward pass: 1442015232 FLOPs
Epoch 1/35: Train Loss: 0.7010, Train Accuracy: 51.59%
Epoch 2/35: Train Loss: 0.6862, Train Accuracy: 53.14%
Epoch 3/35: Train Loss: 0.6800, Train Accuracy: 55.71%
Epoch 4/35: Train Loss: 0.6847, Train Accuracy: 55.44%
Epoch 5/35: Train Loss: 0.6798, Train Accuracy: 55.31%
Epoch 6/35: Train Loss: 0.6709, Train Accuracy: 57.34%
Epoch 7/35: Train Loss: 0.6722, Train Accuracy: 56.80%
Epoch 8/35: Train Loss: 0.6747, Train Accuracy: 56.52%
Epoch 9/35: Train Loss: 0.6735, Train Accuracy: 56.46%
Epoch 10/35: Train Loss: 0.6629, Train Accuracy: 58.69%
Epoch 11/35: Train Loss: 0.6625, Train Accuracy: 58.82%
Epoch 12/35: Train Loss: 0.6639, Train Accuracy: 56.19%
Epoch 13/35: Train Loss: 0.6770, Train Accuracy: 54.16%
Epoch 14/35: Train Loss: 0.6683, Train Accuracy: 58.01%
Epoch 15/35: Train Loss: 0.6644, Train Accuracy: 57.40%
Epoch 16/35: Train Loss: 0.6637, Train Accuracy: 55.78%
Epoch 17/35: Train Loss: 0.6561, Train Accuracy: 58.49%
Epoch 18/35: Tra