In [None]:
from transformers import AutoModelForImageClassification, ViTImageProcessor, ViTForImageClassification
from timm.data.transforms_factory import create_transform
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import torch
import torch.nn as nn
import torch.optim as optim

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Report which device was selected. The GPU name is only queried when CUDA was
# chosen, because torch.cuda.get_device_name() raises on machines without CUDA.
if device.type == "cuda":
    print(torch.cuda.get_device_name(device))
else:
    print("CUDA is not available; running on CPU")


In [None]:
# Load the pretrained MambaVision-B-1K checkpoint and request a 100-class label
# space; class indices 0..99 are mapped to their string form ("0".."99") and back.
# trust_remote_code=True is passed, which lets the checkpoint's own model code run.
model = AutoModelForImageClassification.from_pretrained(
    "nvidia/MambaVision-B-1K",
    num_labels=100,
    id2label=dict(enumerate(map(str, range(100)))),
    label2id=dict(zip(map(str, range(100)), range(100))),
    ignore_mismatched_sizes=True,
    trust_remote_code=True,
)

# Replace the classification head on the wrapped backbone with a freshly
# initialised 100-way linear layer that keeps the original input width.
model.model.head = nn.Linear(model.model.head.in_features, 100, bias=True)

# Move all parameters to the selected device (left as the last expression so
# the notebook still displays the model summary).
model.to(device)

In [None]:
# (channels, height, width) fed to the network.
# Original author's note: MambaVision supports any input resolutions.
input_resolution = (3, 224, 224)

# Evaluation pipeline: resize to the target height/width, convert to a tensor,
# then normalise with the mean/std stored on the model config.
test_transform = transforms.Compose(
    [
        transforms.Resize(input_resolution[1:]),
        transforms.ToTensor(),
        transforms.Normalize(model.config.mean, model.config.std),
    ]
)

# Training pipeline: identical to evaluation plus a random horizontal flip
# (probability 0.5) applied before tensor conversion.
train_transform = transforms.Compose(
    [
        transforms.Resize(input_resolution[1:]),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(model.config.mean, model.config.std),
    ]
)

In [None]:
# CIFAR-100 train/test splits, downloaded into ./data when not already present.
# Each split is paired with its own preprocessing pipeline.
train_dataset = datasets.CIFAR100(
    "./data",
    train=True,
    transform=train_transform,
    download=True,
)
test_dataset = datasets.CIFAR100(
    "./data",
    train=False,
    transform=test_transform,
    download=True,
)

# Batches of 32 images; only the training split is shuffled.
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=32)
test_loader = DataLoader(dataset=test_dataset, shuffle=False, batch_size=32)


In [None]:
def train_and_evaluate_model(model, train_loader, test_loader, optimizer, criterion, device, num_epochs):
    """Train ``model`` for ``num_epochs`` epochs, evaluating after each one.

    Every epoch runs one optimisation pass over ``train_loader`` followed by a
    gradient-free pass over ``test_loader``. Running accuracies are shown on a
    single tqdm progress bar that spans both phases.

    Args:
        model: Callable module whose output supports ``outputs['logits']``;
            the logits are indexed along dimension 1 to pick the predicted class.
        train_loader: Sized iterable of ``(images, labels)`` batches used for
            optimisation.
        test_loader: Sized iterable of ``(images, labels)`` batches used for
            evaluation.
        optimizer: Optimizer already bound to the model's parameters.
        criterion: Loss called as ``criterion(logits, labels)``.
        device: Device every batch is moved to before the forward pass.
        num_epochs: Number of train + evaluate cycles to run.

    Returns:
        Tuple ``(train_acc, test_acc)`` of lists holding one accuracy value in
        ``[0, 1]`` per epoch. A loader that yields no batches contributes
        ``0.0`` for that epoch instead of raising.
    """
    train_acc = []
    test_acc = []
    for epoch in range(num_epochs):
        epoch_label = f"Epoch {epoch + 1}/{num_epochs}"

        # Bind both accuracies up front so an empty loader still produces a
        # defined value rather than an UnboundLocalError further down.
        train_accuracy = 0.0
        val_accuracy = 0.0

        # The context manager guarantees the bar is closed even when a batch
        # raises part-way through the epoch.
        with tqdm(
            total=len(train_loader) + len(test_loader),
            desc=f"{epoch_label} - Train Acc: 0.000 - Val Acc: 0.000",
            unit="batch",
        ) as progress_bar:
            # ---- Training phase ----
            train_correct = 0
            train_total = 0
            model.train()
            for images, labels in train_loader:
                images, labels = images.to(device), labels.to(device)
                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs['logits'], labels)
                loss.backward()
                optimizer.step()

                # Predicted class = index of the largest logit in each row.
                _, predicted = outputs['logits'].max(1)
                train_correct += predicted.eq(labels).sum().item()
                train_total += labels.size(0)

                # Running accuracy over the batches seen so far this epoch.
                train_accuracy = train_correct / train_total
                progress_bar.set_description(f"{epoch_label} - Train Acc: {train_accuracy:.4f}")
                progress_bar.update(1)

            # ---- Validation phase ----
            val_correct = 0
            val_total = 0
            model.eval()
            with torch.no_grad():
                for images, labels in test_loader:
                    images, labels = images.to(device), labels.to(device)
                    outputs = model(images)

                    _, predicted = outputs['logits'].max(1)
                    val_correct += predicted.eq(labels).sum().item()
                    val_total += labels.size(0)

                    val_accuracy = val_correct / val_total
                    progress_bar.set_description(
                        f"{epoch_label} - Train Acc: {train_accuracy:.4f} - Val Acc: {val_accuracy:.4f}"
                    )
                    progress_bar.update(1)

            # Leave the final figures for this epoch on the finished bar.
            progress_bar.set_description(
                f"{epoch_label} - Train Acc: {train_accuracy:.4f} - Val Acc: {val_accuracy:.4f}"
            )

        train_acc.append(train_accuracy)
        test_acc.append(val_accuracy)

    return train_acc, test_acc

# Training configuration: cross-entropy loss, AdamW over every model parameter,
# and five passes over the training set.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
num_epochs = 5

# Run the fine-tuning loop; one accuracy value per epoch comes back for each split.
train_acc, test_acc = train_and_evaluate_model(
    model=model,
    train_loader=train_loader,
    test_loader=test_loader,
    optimizer=optimizer,
    criterion=criterion,
    device=device,
    num_epochs=num_epochs,
)

# Persist the fine-tuned model to a local directory.
model.save_pretrained("./mambavision-finetuned-cifar100")