In [1]:
from transformers import AutoModelForImageClassification, ViTImageProcessor, ViTForImageClassification
from timm.data.transforms_factory import create_transform
from torchvision import datasets
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import torch

In [2]:
# Load MambaVision Model
mamba_model = AutoModelForImageClassification.from_pretrained("nvidia/MambaVision-B-1K", trust_remote_code=True)
mamba_model.cuda().eval()

  def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
  def backward(ctx, dout):
  def forward(
  def backward(ctx, dout, *args):
  def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
  def backward(ctx, grad_output):
  def forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), return_final_states=False, activation="silu",
  def backward(ctx, dout, *args):


MambaVisionModelForImageClassification(
  (model): MambaVision(
    (patch_embed): PatchEmbed(
      (proj): Identity()
      (conv_down): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=0.0001, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (4): BatchNorm2d(128, eps=0.0001, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU()
      )
    )
    (levels): ModuleList(
      (0): MambaVisionLayer(
        (blocks): ModuleList(
          (0): ConvBlock(
            (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (act1): GELU(approximate='tanh')
            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padd

In [3]:
# prepare image for the model
url = 'http://images.cocodataset.org/val2017/000000020247.jpg'
url = 'https://i.ibb.co/JQLxsgX/aircraft-jet-landing-cloud-46148.jpg'
image = Image.open(requests.get(url, stream=True).raw)
input_resolution = (3, 34, 34)  # MambaVision supports any input resolutions

# display image
plt.imshow(image)
plt.axis('off')  # Hide axes
plt.show()

transform = create_transform(input_size=input_resolution,
                             is_training=False,
                             mean=mamba_model.config.mean,
                             std=mamba_model.config.std,
                             crop_mode=mamba_model.config.crop_mode,
                             crop_pct=mamba_model.config.crop_pct)

# output flops
input = torch.randn(1, 3, 224, 224).cuda()
flops, params = profile(mamba_model, inputs=(input, ))
print(flops)


inputs = transform(image).unsqueeze(0).cuda()
# model inference
outputs = mamba_model(inputs)
logits = outputs['logits']
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", mamba_model.config.id2label[predicted_class_idx])

#  Feature Extraction
#out_avg_pool, features = model_feature_extraction(inputs)
#print("Size of the averaged pool features:", out_avg_pool.size())  # torch.Size([1, 640])
#print("Number of stages in extracted features:", len(features)) # 4 stages
#print("Size of extracted features in stage 1:", features[0].size()) # torch.Size([1, 80, 56, 56])
#print("Size of extracted features in stage 4:", features[3].size()) # torch.Size([1, 640, 7, 7])

NameError: name 'Image' is not defined

In [None]:
# Load ViT Model and Processor
vit_processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
vit_model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
vit_model.cuda().eval()

In [None]:
def get_flops(model, input_shape):
    input = torch.randn(1, *input_shape).cuda()
    return torch.cuda.get_device_properties(0).multi_processor_count * torch.cuda.get_device_properties(0).core_count * model(input).shape.numel()

mamba_flops = get_flops(mamba_model, (3, 224, 224))
vit_flops = get_flops(vit_model, (3, 224, 224))

print(f"MambaVision-B-1K: {mamba_flops / 1e9:.2f} GFLOPs")
print(f"ViT-Base: {vit_flops / 1e9:.2f} GFLOPs")

In [None]:
# Define CIFAR-100 Dataset Loader
batch_size = 32

# MambaVision Transform
input_resolution = (3, 224, 224)
mamba_transform = create_transform(input_size=input_resolution,
                                   is_training=False,
                                   mean=mamba_model.config.mean,
                                   std=mamba_model.config.std,
                                   crop_mode=mamba_model.config.crop_mode,
                                   crop_pct=mamba_model.config.crop_pct)

# CIFAR-100 Dataset for MambaVision
cifar100_test_mamba = datasets.CIFAR100(root='./data', train=False, download=True, transform=mamba_transform)
test_loader_mamba = DataLoader(cifar100_test_mamba, batch_size=batch_size, shuffle=False)

# CIFAR-100 Dataset for ViT
vit_transform = vit_processor  # Use ViT's processor directly
cifar100_test_vit = datasets.CIFAR100(root='./data', train=False, download=True, transform=lambda img: vit_transform(images=img, return_tensors="pt")["pixel_values"].squeeze())
test_loader_vit = DataLoader(cifar100_test_vit, batch_size=batch_size, shuffle=False)



In [None]:
# Evaluate MambaVision
def evaluate_mamba(model, data_loader):
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(data_loader, desc="Evaluating MambaVision"):
            images, labels = images.cuda(), labels.cuda()
            outputs = model(images)
            logits = outputs['logits']
            predicted_class_idxs = logits.argmax(dim=-1)
            correct += (predicted_class_idxs == labels).sum().item()
            total += labels.size(0)
    return correct / total


In [None]:
# Evaluate ViT
def evaluate_vit(model, data_loader):
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(data_loader, desc="Evaluating ViT"):
            images, labels = images.cuda(), labels.cuda()
            outputs = model(pixel_values=images)
            logits = outputs.logits
            predicted_class_idxs = logits.argmax(dim=-1)
            correct += (predicted_class_idxs == labels).sum().item()
            total += labels.size(0)
    return correct / total

In [None]:
# Run Evaluation
accuracy_mamba = evaluate_mamba(mamba_model, test_loader_mamba)
print(f"MambaVision Accuracy on CIFAR-100: {accuracy_mamba:.4f}")


In [None]:
accuracy_vit = evaluate_vit(vit_model, test_loader_vit)
print(f"ViT Accuracy on CIFAR-100: {accuracy_vit:.4f}")

In [None]:
from transformers import ViTForImageClassification, ViTImageProcessor
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load Pre-Trained ViT Model
vit_model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=100,  # Adjust for CIFAR-100
    id2label={i: str(i) for i in range(100)},
    label2id={str(i): i for i in range(100)},
    ignore_mismatched_sizes=True  # Ignore size mismatch for the classification head
)
vit_model.to(device)

processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')

# Data Preprocessing
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=processor.image_mean, std=processor.image_std),
])

test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=processor.image_mean, std=processor.image_std),
])

# Load CIFAR-100 Dataset
train_dataset = datasets.CIFAR100(root="./data", train=True, download=True, transform=train_transform)
test_dataset = datasets.CIFAR100(root="./data", train=False, download=True, transform=test_transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Loss Function and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(vit_model.parameters(), lr=5e-5, weight_decay=0.01)

# Training Function
def train_model(model, data_loader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    for images, labels in tqdm(data_loader, desc="Training"):
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(pixel_values=images)
        loss = criterion(outputs.logits, labels)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        _, predicted = outputs.logits.max(1)
        correct += predicted.eq(labels).sum().item()
        total += labels.size(0)
    
    accuracy = correct / total
    return total_loss / len(data_loader), accuracy

# Evaluation Function
def evaluate_model(model, data_loader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(data_loader, desc="Evaluating"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(pixel_values=images)
            loss = criterion(outputs.logits, labels)
            
            total_loss += loss.item()
            _, predicted = outputs.logits.max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)
    
    accuracy = correct / total
    return total_loss / len(data_loader), accuracy

# Training Loop
num_epochs = 5
for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    
    train_loss, train_acc = train_model(vit_model, train_loader, optimizer, criterion, device)
    print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.4f}")
    
    val_loss, val_acc = evaluate_model(vit_model, test_loader, criterion, device)
    print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}")

# Save the Fine-Tuned Model
vit_model.save_pretrained("./vit-finetuned-cifar100")
