# testing mnist

In [None]:
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to the input size expected by Vim
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Download and load the MNIST dataset
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

In [None]:
import torch

# Assuming the model is compatible with MNIST and has been adjusted accordingly
model_path = '/path/to/vim_mnist_model.pth'
model = torch.load(model_path)
model.eval()

In [None]:
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        # Here you can compare predictions with the labels, etc.

# Vim 

installation: <br>
pip install vision-mamba \
conda install python=3.10

## simple Vim example with random array

In [1]:

import torch
from vision_mamba.model import Vim

# Create a random tensor
x = torch.randn(1, 3, 224, 224)

# Create an instance of the Vim model
model = Vim(
    dim=256,  # Dimension of the model
    heads=8,  # Number of attention heads
    dt_rank=32,  # Rank of the dynamic routing tensor
    dim_inner=256,  # Inner dimension of the model
    d_state=256,  # State dimension of the model
    num_classes=1000,  # Number of output classes
    image_size=224,  # Size of the input image
    patch_size=16,  # Size of the image patch
    channels=3,  # Number of input channels
    dropout=0.1,  # Dropout rate
    depth=12,  # Depth of the model
)

# Perform a forward pass through the model
out = model(x)

# Print the shape and output of the forward pass
print(out.shape)
print(out)


  from .autonotebook import tqdm as notebook_tqdm


Patch embedding: torch.Size([1, 196, 256])
Cls tokens: torch.Size([1, 1, 256])
torch.Size([1, 196, 256])
Conv1d: tensor([[[1.3771, 0.6676, 0.6763,  ..., 0.6272, 0.8216, 0.3146],
         [0.9814, 0.9791, 0.2337,  ..., 0.7276, 0.5884, 0.2602],
         [0.6092, 0.3642, 0.3866,  ..., 0.3664, 0.2654, 0.5497],
         ...,
         [1.1254, 0.3642, 0.7701,  ..., 0.7868, 1.0325, 0.5821],
         [0.3234, 0.9146, 0.6319,  ..., 0.6030, 0.5755, 0.6765],
         [0.5759, 0.9608, 0.5400,  ..., 0.4798, 1.0296, 0.4448]]],
       grad_fn=<SoftplusBackward0>)
Conv1d: tensor([[[0.5625, 1.5989, 0.7373,  ..., 0.7363, 0.4514, 0.4862],
         [1.2964, 0.6825, 0.4543,  ..., 0.4795, 0.7936, 0.8328],
         [0.9315, 0.2457, 0.8427,  ..., 0.3826, 0.8089, 0.7580],
         ...,
         [0.8614, 0.7139, 0.5613,  ..., 0.7951, 1.3101, 0.7791],
         [0.7097, 0.4163, 1.1317,  ..., 1.0280, 1.2699, 1.1448],
         [0.9224, 0.8868, 1.0144,  ..., 0.4625, 0.7845, 0.7140]]],
       grad_fn=<SoftplusBackwar

In [2]:
x

tensor([[[[ 0.5783,  0.2006, -0.0710,  ..., -1.1893,  0.6128,  0.2046],
          [-1.6061,  0.9653,  0.3932,  ...,  0.6520,  0.1387,  0.3216],
          [-0.5303,  0.3957,  0.2776,  ..., -0.5863, -0.4728, -0.2662],
          ...,
          [-1.7907, -0.7359,  0.6876,  ...,  0.0656, -0.8618, -2.0615],
          [-0.7561, -0.8682,  2.4468,  ..., -0.3109,  1.7211, -1.1898],
          [-0.4262,  0.4547, -0.1449,  ..., -0.2588, -0.0426, -0.6860]],

         [[ 0.7730,  1.6578, -0.7547,  ...,  0.8086, -0.2623, -0.1340],
          [-1.2641, -1.5959, -0.9542,  ...,  0.2687, -0.5831, -0.0722],
          [ 1.5232,  1.3753, -1.2419,  ...,  0.2460,  0.1829,  0.3695],
          ...,
          [ 0.5844,  0.2033, -0.3266,  ..., -0.7598, -0.6962,  0.7340],
          [ 1.2142,  0.6437,  0.0954,  ..., -0.6485, -0.9494, -1.1728],
          [-0.8300, -1.9720, -1.1100,  ..., -0.3776, -0.2361, -1.4140]],

         [[ 0.0588, -0.5722, -0.7176,  ...,  0.8371, -0.1833,  0.7830],
          [ 0.2611,  0.0945,  

## Vim example with loss and optimizer

In [4]:
import torch
import torch.optim as optim
from torch.nn import MSELoss
from torch.utils.data import DataLoader, TensorDataset
from vision_mamba.model import Vim

# Create a random tensor for inputs and targets
inputs = torch.randn(1, 3, 224, 224)  # 100 images
targets = torch.randn(1, 1)  # 100 target values

# Create a TensorDataset and DataLoader
dataset = TensorDataset(inputs, targets)
train_loader = DataLoader(dataset, batch_size=10, shuffle=True)

# Initialize the Vim model
model = Vim(
    dim=256,
    heads=8,
    dt_rank=32,
    dim_inner=256,
    d_state=256,
    num_classes=1,  # For regression, typically the output is a single value per instance
    image_size=224,
    patch_size=16,
    channels=3,
    dropout=0.1,
    depth=12,
)

# Using Mean Squared Error Loss for a regression task
criterion = MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
model.train()  # Set the model to training mode
num_epochs = 5  # Define the number of epochs
# for epoch in range(num_epochs):
#     total_loss = 0.0
#     num_batches = 0

#     for batch_inputs, batch_targets in train_loader:
#         optimizer.zero_grad()  # Zero the parameter gradients

#         # Forward pass
#         outputs = model(batch_inputs)
#         loss = criterion(outputs, batch_targets)

#         # Backward pass and optimize
#         loss.backward()
#         optimizer.step()

#         # Accumulate loss
#         total_loss += loss.item()
#         num_batches += 1

#     # Calculate average loss for the epoch
#     average_loss = total_loss / num_batches
#     print(f'Epoch {epoch + 1}: Average Loss {average_loss:.4f}')

for epoch in range(num_epochs):
    running_loss = 0.0
    for i, (batch_inputs, batch_targets) in enumerate(train_loader):
        optimizer.zero_grad()  # Zero the parameter gradients

        # Forward pass
        outputs = model(batch_inputs)
        loss = criterion(outputs, batch_targets)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Print statistics
        running_loss += loss.item()
        if i % 1 == 0:  # Print every 10 mini-batches
            print(f'Epoch {epoch + 1}, Batch {i + 1}: Loss {running_loss / 10:.4f}')
            running_loss = 0.0

Patch embedding: torch.Size([1, 196, 256])
Cls tokens: torch.Size([1, 1, 256])
torch.Size([1, 196, 256])
Conv1d: tensor([[[0.9293, 1.0125, 0.6928,  ..., 0.7531, 0.6932, 0.5721],
         [0.7163, 0.6889, 0.8921,  ..., 0.5416, 0.4246, 1.3118],
         [0.4444, 0.7595, 0.7152,  ..., 0.4556, 0.4305, 0.9408],
         ...,
         [1.1944, 0.8200, 0.3902,  ..., 0.4207, 0.8082, 0.9928],
         [0.3697, 0.8940, 0.5094,  ..., 0.5713, 0.2747, 0.8667],
         [0.7288, 0.9279, 0.7717,  ..., 0.7507, 0.7188, 0.7698]]],
       grad_fn=<SoftplusBackward0>)
Conv1d: tensor([[[1.3588, 0.1901, 0.8427,  ..., 0.5604, 0.5004, 1.1977],
         [0.7636, 0.4937, 0.6161,  ..., 0.6588, 0.7040, 0.3382],
         [0.9844, 0.6580, 0.5420,  ..., 0.6902, 0.2148, 0.8575],
         ...,
         [0.8270, 0.4530, 0.6703,  ..., 0.5966, 0.2155, 0.5822],
         [0.7245, 1.0426, 1.2601,  ..., 0.8590, 1.2649, 0.3840],
         [1.0408, 1.2897, 0.9722,  ..., 0.7549, 1.0958, 0.7068]]],
       grad_fn=<SoftplusBackwar

## Vim example with correlation

In [None]:
import torch
import torch.optim as optim
from torch.nn import MSELoss
from torch.utils.data import DataLoader, TensorDataset
from vision_mamba.model import Vim
import numpy as np

# Create a random tensor for inputs and targets
inputs = torch.randn(100, 3, 224, 224)  # 100 images
targets = torch.randn(100, 1)  # 100 target values

# Create a TensorDataset and DataLoader
dataset = TensorDataset(inputs, targets)
train_loader = DataLoader(dataset, batch_size=10, shuffle=True)

# Initialize the Vim model
model = Vim(
    dim=256,
    heads=8,
    dt_rank=32,
    dim_inner=256,
    d_state=256,
    num_classes=1,  # For regression, typically the output is a single value per instance
    image_size=224,
    patch_size=16,
    channels=3,
    dropout=0.1,
    depth=12,
)

# Using Mean Squared Error Loss for a regression task
criterion = MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
model.train()  # Set the model to training mode
num_epochs = 5  # Define the number of epochs
verbose = True  # Set verbose to True to print correlation
for epoch in range(num_epochs):
    total_loss = 0.0
    num_batches = 0
    outputs_all = []
    targets_all = []

    for batch_inputs, batch_targets in train_loader:
        optimizer.zero_grad()  # Zero the parameter gradients

        # Forward pass
        outputs = model(batch_inputs)
        loss = criterion(outputs, batch_targets)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Accumulate loss
        total_loss += loss.item()
        num_batches += 1

        # Debugging shapes
        print("Output shape:", outputs.shape)
        print("Target shape:", batch_targets.shape)

        # Collect outputs and targets for correlation, ensure they are flattened
        outputs_all.append(outputs.view(-1).detach().numpy())
        targets_all.append(batch_targets.view(-1).detach().numpy())

    # Calculate average loss for the epoch
    average_loss = total_loss / num_batches
    print(f'Epoch {epoch + 1}: Average Loss {average_loss:.4f}')

    # Compute correlation
    outputs_flat = np.concatenate(outputs_all)
    targets_flat = np.concatenate(targets_all)
    corr = np.corrcoef(outputs_flat, targets_flat)[0, 1]
    if verbose:
        print('Epoch {}: Correlation: {:.4f}'.format(epoch + 1, corr))

  from .autonotebook import tqdm as notebook_tqdm


Patch embedding: torch.Size([10, 196, 256])
Cls tokens: torch.Size([10, 1, 256])
torch.Size([10, 196, 256])
Conv1d: tensor([[[0.3320, 0.6612, 0.2751,  ..., 0.5106, 0.3773, 1.1315],
         [0.6955, 0.4852, 0.7019,  ..., 0.4834, 0.3049, 0.7811],
         [0.6267, 0.8924, 0.6238,  ..., 1.0904, 1.0939, 0.7068],
         ...,
         [1.1881, 1.1885, 0.8275,  ..., 1.4902, 0.6884, 0.8910],
         [1.3025, 0.5386, 0.5016,  ..., 0.9290, 0.9307, 0.5725],
         [0.4529, 0.5871, 0.6947,  ..., 0.8271, 0.8326, 0.8179]],

        [[1.0223, 0.4483, 0.7652,  ..., 0.6915, 0.5049, 0.8500],
         [0.4183, 0.7102, 0.4091,  ..., 0.6213, 0.8058, 1.0482],
         [0.8836, 0.9647, 0.7765,  ..., 0.3293, 0.8521, 0.9572],
         ...,
         [0.2148, 0.5371, 1.1194,  ..., 1.0633, 0.5155, 0.7701],
         [1.0844, 0.3533, 0.8568,  ..., 0.6971, 1.0897, 0.7682],
         [0.7583, 0.6769, 1.2603,  ..., 0.7924, 0.9959, 0.9841]],

        [[0.8421, 0.7553, 0.4263,  ..., 0.2580, 0.4061, 0.9946],
       