In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Fine-tuning accuracy log: one (train_acc, test_acc) pair per epoch, epochs 1-5.
# zip(*rows) transposes the pairs so each metric ends up in its own list.

# ViT
vit_train_acc, vit_test_acc = map(list, zip(
    (0.8359, 0.8797),  # epoch 1
    (0.9333, 0.8843),  # epoch 2
    (0.9571, 0.8836),  # epoch 3
    (0.9665, 0.8799),  # epoch 4
    (0.9733, 0.8790),  # epoch 5
))

# MambaVision
mv_train_acc, mv_test_acc = map(list, zip(
    (0.7408, 0.8444),  # epoch 1
    (0.8754, 0.8682),  # epoch 2
    (0.9134, 0.8655),  # epoch 3
    (0.9323, 0.8681),  # epoch 4
    (0.9464, 0.8618),  # epoch 5
))


In [None]:
# Fine-tuning curves for both models on one 6x4-inch figure:
# dashed = train accuracy, solid = test accuracy; colour identifies the model.
epochs = np.arange(1, len(vit_train_acc) + 1)  # derived from the data instead of a hard-coded 1..5

curves = [
    # (per-epoch values, legend label, colour, line style)
    (vit_train_acc, 'ViT Train Acc', 'lightsalmon', '--'),
    (vit_test_acc, 'ViT Test Acc', 'lightsalmon', '-'),
    (mv_train_acc, 'MambaVision Train Acc', 'royalblue', '--'),
    (mv_test_acc, 'MambaVision Test Acc', 'royalblue', '-'),
]
# Every series must cover the same epochs, otherwise plot() would fail with a shape mismatch.
assert all(len(values) == len(epochs) for values, *_ in curves), 'accuracy lists differ in length'

fig, ax = plt.subplots(figsize=(6, 4))
for values, label, color, linestyle in curves:
    ax.plot(epochs, values, label=label, color=color, linestyle=linestyle, marker='o')
ax.set_xticks(epochs)
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
# Both models are plotted, so the title names both.
ax.set_title('ViT vs MambaVision Fine-tuning')
ax.legend()
# Output filename left unchanged so existing references in the report still resolve.
fig.savefig('../report/figures/vit_fine_tuning.png')

In [None]:
# Inference benchmark records for the two fine-tuned models (raw log kept below for provenance).

"""
ViT Inference data
Accuracy: 0.8790
Total inference time: 215.16 seconds
Average time per image: 0.0215 seconds


MambaVision Inference data
Accuracy: 0.8617
Total inference time: 146.37 seconds
Average time per image: 0.0146 seconds

"""

vit_acc = 0.8790
vit_total_time = 215.16  # seconds
vit_avg_time = 0.0215  # seconds per image
# NOTE(review): derived from the rounded per-image time, so it is approximate
# (e.g. 1/0.0215 = 46.51 img/s); total_time / n_images would be more precise.
vit_throughput = 1 / vit_avg_time  # images per second


mv_acc = 0.8617  # NOTE(review): final-epoch test acc in the fine-tuning log is 0.8618 — confirm which is correct
mv_total_time = 146.37  # seconds
mv_avg_time = 0.0146  # seconds per image
mv_throughput = 1 / mv_avg_time  # images per second (approximate, see note above)


def format_inference_row(model, acc, total_time, avg_time, throughput):
    """Format one fixed-width (72-character) row of the inference table.

    Args:
        model: model name, left-aligned in a 12-character column.
        acc: accuracy as a fraction, shown with 4 decimals.
        total_time: total inference time in seconds, shown with 2 decimals.
        avg_time: average time per image in seconds, shown with 4 decimals.
        throughput: images per second, shown with 2 decimals.

    Returns:
        The formatted row; column widths match the header printed below.
    """
    return f'{model:<12}{acc:>10.4f}{total_time:>16.2f}{avg_time:>14.4f}{throughput:>20.2f}'


# print table for inference records (fixed-width columns instead of tabs so rows align)
print('Inference Records')
print(f"{'Model':<12}{'Accuracy':>10}{'Total Time (s)':>16}{'Avg Time (s)':>14}{'Throughput (img/s)':>20}")
print(format_inference_row('ViT', vit_acc, vit_total_time, vit_avg_time, vit_throughput))
print(format_inference_row('MambaVision', mv_acc, mv_total_time, mv_avg_time, mv_throughput))


In [None]:
# Accuracy (%, y axis) against inference throughput (img/s, x axis): a single large
# marker per model, so the speed/accuracy trade-off is visible at a glance.
fig, ax = plt.subplots(figsize=(6, 4))
for model_name, throughput, accuracy, colour in (
    ('ViT', vit_throughput, vit_acc, 'lightsalmon'),
    ('MambaVision', mv_throughput, mv_acc, 'royalblue'),
):
    ax.plot(throughput, accuracy * 100, label=model_name, color=colour, marker='o', markersize=10)

# Fixed tick grids; setting ticks also widens the view to span them.
ax.set_xticks(np.arange(20, 100, 10))
ax.set_yticks(np.arange(80, 100, 5))
ax.set_xlabel('Throughput (img/s)')
ax.set_ylabel('Accuracy (%)')
ax.legend()
ax.set_title('Accuracy vs Throughput')
fig.savefig('../report/figures/accuracy_vs_throughput.png')


In [None]:
# Load the CIFAR-100 training split from ./data, downloading it on first run.
# This cell only loads the data; the 10x10 sample grid is drawn in the next cell.
from torchvision import datasets

cifar100 = datasets.CIFAR100(download=True, train=True, root='./data')


In [None]:
# 10x10 grid of CIFAR-100 training samples (first 10 classes, 10 images each) for the report.
# Reuses `cifar100` loaded by the previous cell rather than loading the dataset a second time.
N_PLOT_CLASSES = 10  # grid rows: the first N class ids
N_PER_CLASS = 10  # grid columns: samples shown per class

classes = cifar100.classes

# Select the first N_PER_CLASS sample indices of every class by scanning only the integer
# labels; iterating `for img, label in cifar100` would decode every training image into a
# PIL image just to keep a handful of them.
sample_indices = {cls: [] for cls in range(len(classes))}
for idx, label in enumerate(cifar100.targets):
    if len(sample_indices[label]) < N_PER_CLASS:
        sample_indices[label].append(idx)

# class id -> list of up to N_PER_CLASS images; only the selected samples are decoded.
class_to_images = {
    cls: [cifar100[idx][0] for idx in indices]
    for cls, indices in sample_indices.items()
}

# squeeze=False keeps `axes` 2-D even if either grid dimension is set to 1.
fig, axes = plt.subplots(N_PLOT_CLASSES, N_PER_CLASS, figsize=(10, 10), squeeze=False)

for row in range(N_PLOT_CLASSES):
    for col in range(N_PER_CLASS):
        axes[row, col].imshow(class_to_images[row][col])
        axes[row, col].axis('off')
    # axis('off') also hides y-labels, so place the class name as text left of the row.
    axes[row, 0].text(-0.1, 0.5, classes[row], transform=axes[row, 0].transAxes,
                      ha='right', va='center', fontsize=9)

plt.tight_layout()
# bbox_inches='tight' keeps the row labels (drawn outside the first column) in the saved image.
plt.savefig('../report/figures/cifar100_sample.png', bbox_inches='tight')