In [1]:
# Install into the *kernel's* environment (%pip) rather than whatever `pip`
# happens to be first on the shell PATH (!pip).
# zetascale is pinned to the version this notebook was last run with.
%pip install zetascale==2.8.6
# NOTE(review): `swarms` is never imported directly in this notebook; presumably
# it is needed by `zeta` at import time — confirm, then pin or drop it.
%pip install swarms

Collecting zetascale
  Downloading zetascale-2.8.6-py3-none-any.whl.metadata (23 kB)
Collecting argparse<2.0.0,>=1.4.0 (from zetascale)
  Downloading argparse-1.4.0-py2.py3-none-any.whl.metadata (2.8 kB)
Collecting bitsandbytes (from zetascale)
  Downloading bitsandbytes-0.46.0-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Collecting colt5-attention (from zetascale)
  Downloading CoLT5_attention-0.11.1-py3-none-any.whl.metadata (737 bytes)
Collecting einops-exts==0.0.4 (from zetascale)
  Downloading einops_exts-0.0.4-py3-none-any.whl.metadata (621 bytes)
Collecting joblib<1.4.0,>=1.3.0 (from zetascale)
  Downloading joblib-1.3.2-py3-none-any.whl.metadata (5.4 kB)
Collecting local-attention (from zetascale)
  Downloading local_attention-1.11.1-py3-none-any.whl.metadata (907 bytes)
Collecting loguru (from zetascale)
  Downloading loguru-0.7.3-py3-none-any.whl.metadata (22 kB)
Collecting scikit-learn<1.6.0,>=1.5.0 (from zetascale)
  Downloading scikit_learn-1.5.2-cp311-cp311-manylin

In [2]:
import torch
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from torch import nn, Tensor
from zeta.nn import SSM
from einops.layers.torch import Reduce

E0000 00:00:1748813298.550892      35 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1748813298.668791      35 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


# Modules

In [3]:
# Pair
def pair(t):
    """Return ``t`` unchanged when it is a tuple, otherwise the 2-tuple ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t, t)


def output_head(dim: int, num_classes: int):
    """
    Build the classification head applied to the encoder's token sequence.

    The head mean-pools over the token axis (``"b s d -> b d"``), applies
    layer normalisation and ends with a linear projection.

    Args:
        dim (int): Feature size of each token entering the head.
        num_classes (int): Number of outputs produced per sample.

    Returns:
        nn.Sequential: ``Reduce`` -> ``LayerNorm`` -> ``Linear``.
    """
    pool_tokens = Reduce("b s d -> b d", "mean")
    normalise = nn.LayerNorm(dim)
    classify = nn.Linear(dim, num_classes)
    return nn.Sequential(pool_tokens, normalise, classify)


class VisionEncoderMambaBlock(nn.Module):
    """
    Encoder block modelled on the Mamba block from the paper
    "Vision Mamba: Efficient Visual Representation Learning with
    Bidirectional State Space Model".

    The block layer-normalises its input, projects it with one linear layer,
    sends the projection through two ``conv1d -> softplus -> SSM`` branches,
    gates both branch outputs with a SiLU-activated copy of the projection and
    finally adds the residual.

    NOTE(review): both branches receive the tokens in the same order (the
    "backward" branch never reverses the sequence) and share a single ``SSM``
    instance, so they differ only in their 1x1 convolution weights. Likewise
    the gate and the branch input come from the same ``proj`` layer. Changing
    either would alter what existing checkpoints' weights mean, so the
    behaviour is kept as is.

    Args:
        dim (int): Feature size of the input tokens; also used for the
            LayerNorm, the projection and both 1x1 convolutions.
        dt_rank (int): Forwarded to ``SSM`` as its second argument.
        dim_inner (int): Forwarded to ``SSM`` as its third argument.
        d_state (int): Forwarded to ``SSM`` as its fourth argument.

    Example:
    >>> block = VisionEncoderMambaBlock(dim=192, dt_rank=8,
    ...                                 dim_inner=192, d_state=256)
    >>> x = torch.randn(1, 64, 192)
    >>> out = block(x)
    >>> out.shape
    torch.Size([1, 64, 192])
    """

    def __init__(
        self,
        dim: int,
        dt_rank: int,
        dim_inner: int,
        d_state: int,
    ):
        super().__init__()
        self.dim = dim
        self.dt_rank = dt_rank
        self.dim_inner = dim_inner
        self.d_state = d_state

        # kernel_size=1 means each convolution mixes channels per token only.
        self.forward_conv1d = nn.Conv1d(
            in_channels=dim, out_channels=dim, kernel_size=1
        )
        self.backward_conv1d = nn.Conv1d(
            in_channels=dim, out_channels=dim, kernel_size=1
        )
        self.norm = nn.LayerNorm(dim)
        self.silu = nn.SiLU()
        self.ssm = SSM(dim, dt_rank, dim_inner, d_state)

        # Single linear layer shared by the gate (z) and the branch input (x)
        self.proj = nn.Linear(dim, dim)

        # Softplus applied after each convolution
        self.softplus = nn.Softplus()

    def forward(self, x: torch.Tensor):
        """
        Apply the block to a ``(batch, seq, dim)`` tensor.

        Returns:
            torch.Tensor: ``x1 + x2 + skip`` where ``x1``/``x2`` are the gated
            branch outputs and ``skip`` is the untouched input.
        """
        # Skip connection
        skip = x

        # Normalization
        x = self.norm(x)

        # The same layer was previously evaluated twice on the same input to
        # obtain both the gate and the branch input; one evaluation yields the
        # identical values.
        x = self.proj(x)
        z1 = x

        # forward conv1d branch
        x1 = self.process_direction(
            x,
            self.forward_conv1d,
            self.ssm,
        )

        # backward conv1d branch (same token order, different conv weights)
        x2 = self.process_direction(
            x,
            self.backward_conv1d,
            self.ssm,
        )

        # Gate activation
        z = self.silu(z1)

        # Element-wise gating of both branches
        x1 *= z
        x2 *= z

        # Residual connection
        return x1 + x2 + skip

    def process_direction(
        self,
        x: Tensor,
        conv1d: nn.Conv1d,
        ssm: SSM,
    ):
        """
        Run one branch: ``conv1d`` over the channel axis, softplus, then ``ssm``.

        Args:
            x (Tensor): Input laid out as ``(batch, seq, dim)``.
            conv1d (nn.Conv1d): Convolution for this branch.
            ssm (SSM): State-space module applied after the convolution.

        Returns:
            Tensor: Whatever ``ssm`` returns for the ``(batch, seq, dim)`` input.
        """
        # Conv1d expects channels before the sequence axis.
        x = rearrange(x, "b s d -> b d s")
        x = self.softplus(conv1d(x))
        x = rearrange(x, "b d s -> b s d")
        x = ssm(x)
        return x

# Model

In [4]:
class Vim(nn.Module):
    """
    Vision Mamba (Vim) image classifier.

    Images are cut into non-overlapping patches, every patch is linearly
    embedded, the token sequence passes through ``depth``
    ``VisionEncoderMambaBlock`` layers and ``output_head`` mean-pools the
    tokens into per-class outputs.

    Args:
        dim (int): Token embedding size used throughout the model.
        dt_rank (int, optional): Forwarded to every encoder block. Defaults to 32.
        dim_inner (int, optional): Forwarded to every encoder block. Defaults to None.
        d_state (int, optional): Forwarded to every encoder block. Defaults to None.
        num_classes (int, optional): Number of output classes. Defaults to None.
        image_size (int or tuple, optional): Stored on the instance only; the
            patch grid is derived from the actual input in ``forward``.
            Defaults to 224.
        patch_size (int or tuple, optional): Patch size, either a single int
            or a ``(height, width)`` tuple. Defaults to 16.
        channels (int, optional): Number of image channels. Defaults to 3.
        dropout (float, optional): Dropout rate applied to the patch
            embeddings. Defaults to 0.1.
        depth (int, optional): Number of encoder layers. Defaults to 12.
        *args, **kwargs: Passed through to every ``VisionEncoderMambaBlock``.

    Attributes:
        dim, dt_rank, dim_inner, d_state, num_classes, image_size,
        patch_size, channels, depth: Constructor arguments, stored unchanged.
        to_patch_embedding (nn.Sequential): Patch rearrangement + linear embedding.
        dropout (nn.Dropout): Dropout module applied after patch embedding.
        cls_token (nn.Parameter): Class token of shape ``(1, 1, dim)``. It is
            registered (and therefore part of the state_dict) but not used in
            ``forward``.
        to_latent (nn.Identity): Identity module for latent representation.
        layers (nn.ModuleList): List of encoder layers.
        output_head (nn.Sequential): Output head module.
    """

    def __init__(
        self,
        dim: int,
        dt_rank: int = 32,
        dim_inner: int = None,
        d_state: int = None,
        num_classes: int = None,
        image_size: int = 224,
        patch_size: int = 16,
        channels: int = 3,
        dropout: float = 0.1,
        depth: int = 12,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.dim = dim
        self.dt_rank = dt_rank
        self.dim_inner = dim_inner
        self.d_state = d_state
        self.num_classes = num_classes
        self.image_size = image_size
        self.patch_size = patch_size
        self.channels = channels
        self.depth = depth

        patch_height, patch_width = pair(patch_size)
        patch_dim = channels * patch_height * patch_width

        self.to_patch_embedding = nn.Sequential(
            Rearrange(
                "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
                p1=patch_height,
                # Was `patch_height`, which mis-sized non-square patches.
                p2=patch_width,
            ),
            nn.Linear(patch_dim, dim),
        )

        # Dropout (the module; the raw rate is not kept separately)
        self.dropout = nn.Dropout(dropout)

        # Class token: unused in forward, kept so existing checkpoints still load
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

        # Latent
        self.to_latent = nn.Identity()

        # Encoder layers
        self.layers = nn.ModuleList()

        # Append the encoder layers
        for _ in range(depth):
            self.layers.append(
                VisionEncoderMambaBlock(
                    dim=dim,
                    dt_rank=dt_rank,
                    dim_inner=dim_inner,
                    d_state=d_state,
                    *args,
                    **kwargs,
                )
            )

        # Output head
        self.output_head = output_head(dim, num_classes)

    def forward(self, x: Tensor):
        """
        Classify a batch of images.

        Args:
            x (Tensor): Images laid out as ``(batch, channels, height, width)``;
                height and width must be divisible by the patch size.

        Returns:
            Tensor: Output of ``output_head`` for the encoded token sequence.
        """
        # Patch embedding: (b, c, H, W) -> (b, num_patches, dim)
        x = self.to_patch_embedding(x)

        # The class token is deliberately not prepended: the head mean-pools
        # over all patch tokens instead of reading a dedicated token.

        # Dropout
        x = self.dropout(x)

        # Forward pass with the layers
        for layer in self.layers:
            x = layer(x)

        # Latent
        x = self.to_latent(x)

        # Output head (mean over tokens -> LayerNorm -> Linear)
        return self.output_head(x)

# Save & Load checkpoint

In [5]:
import torch
import os

# Checkpoint save paths. `train_model` reads both names as globals:
# `checkpoint_path` is rewritten every epoch, `checkpoint_path_best` only when
# validation accuracy improves.
checkpoint_path = "/kaggle/working/model_checkpoint.pth"
checkpoint_path_best = "/kaggle/working/model_bestcheckpoint.pth"

# Save a training checkpoint
def save_checkpoint(model, optimizer, epoch, loss, scheduler, checkpoint_path):
    """
    Serialise the training state to ``checkpoint_path``.

    The file holds a dict with the keys ``epoch``, ``model_state_dict``,
    ``optimizer_state_dict``, ``scheduler_state_dict`` and ``loss``.

    When ``checkpoint_path`` is a filesystem path the data is first written to
    ``<checkpoint_path>.tmp`` and then moved into place with ``os.replace``,
    so an interrupted save leaves the previous checkpoint intact instead of a
    truncated file.

    Args:
        model: Module whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        epoch: Value stored under ``'epoch'``.
        loss: Value stored under ``'loss'``.
        scheduler: LR scheduler whose ``state_dict()`` is stored.
        checkpoint_path: Destination path (or any target ``torch.save`` accepts).
    """
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'loss': loss
    }
    if isinstance(checkpoint_path, (str, os.PathLike)):
        tmp_path = f"{os.fspath(checkpoint_path)}.tmp"
        try:
            torch.save(checkpoint, tmp_path)
            # Replaces the destination in a single step.
            os.replace(tmp_path, checkpoint_path)
        finally:
            # Only still present if saving or replacing failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
    else:
        # File-like targets cannot be renamed; write them directly.
        torch.save(checkpoint, checkpoint_path)
    print(f"Checkpoint saved at {checkpoint_path}")
    
# Load a training checkpoint
def _align_module_prefix(state_dict, model):
    """Add or strip the ``module.`` key prefix so ``state_dict`` matches ``model``."""
    prefix = "module."
    expected = set(model.state_dict().keys())
    if set(state_dict) == expected:
        return state_dict
    stripped = {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }
    if len(stripped) == len(state_dict) and set(stripped) == expected:
        return stripped
    prefixed = {prefix + key: value for key, value in state_dict.items()}
    if set(prefixed) == expected:
        return prefixed
    # No simple renaming fits; let load_state_dict report the mismatch.
    return state_dict


def load_checkpoint(model, optimizer, scheduler, checkpoint_path, map_location="cpu"):
    """
    Restore model, optimizer and scheduler state from ``checkpoint_path``.

    The checkpoint is deserialised onto ``map_location`` (CPU by default), so a
    file written on a GPU machine can also be read where that GPU is absent;
    ``load_state_dict`` then copies the values into the existing objects.
    Model keys are reconciled with respect to the ``module.`` prefix that
    ``nn.DataParallel`` adds, so a checkpoint saved from a wrapped model can be
    loaded into an unwrapped one and vice versa.

    Args:
        model: Module to load weights into.
        optimizer: Optimizer to restore.
        scheduler: LR scheduler to restore.
        checkpoint_path: Path of the checkpoint file.
        map_location: Passed to ``torch.load``. Defaults to ``"cpu"``.

    Returns:
        tuple: ``(model, optimizer, scheduler, next_epoch, loss)``. When the
        file exists ``next_epoch`` is the stored epoch plus one and ``loss`` is
        the stored loss; otherwise they are ``0`` and ``None`` and nothing is
        modified.
    """
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(
            checkpoint_path, map_location=map_location, weights_only=True
        )
        model.load_state_dict(
            _align_module_prefix(checkpoint['model_state_dict'], model)
        )
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        epoch = checkpoint['epoch']
        loss = checkpoint['loss']
        print(f"Checkpoint loaded! Resuming from Epoch {epoch + 1}, Loss: {loss}")
        return model, optimizer, scheduler, epoch + 1, loss
    else:
        print("No checkpoint found, starting from scratch.")
        return model, optimizer, scheduler, 0, None

# Data Setup

In [6]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import random_split, DataLoader

# CIFAR-10 preprocessing.
# Per-channel statistics used to normalise the images.
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2023, 0.1994, 0.2010)

# Training pipeline: padded random crop and horizontal flip on the PIL image,
# tensor conversion + normalisation, then random erasing on the tensor.
train_steps = [
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
    transforms.RandomErasing(p=0.2),
]
transform_train = transforms.Compose(train_steps)

# Evaluation pipeline: tensor conversion + normalisation only.
test_steps = [
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
]
transform_test = transforms.Compose(test_steps)

# For CIFAR-100 the same pipelines were previously sketched with
# mean=[0.5071, 0.4867, 0.4408] and std=[0.2675, 0.2565, 0.2761].

# Load CIFAR-10 (downloaded into ./data when missing).
dataset_options = dict(root='./data', download=True)
train_dataset = torchvision.datasets.CIFAR10(
    train=True, transform=transform_train, **dataset_options)
test_dataset = torchvision.datasets.CIFAR10(
    train=False, transform=transform_test, **dataset_options)

# Seeded generator handed to the training loader for its shuffling.
shuffle_gen = torch.Generator().manual_seed(42)

# Dataloaders
batch_size = 32
loader_options = dict(batch_size=batch_size, num_workers=2)
train_loader = DataLoader(
    train_dataset, shuffle=True, generator=shuffle_gen, **loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

100%|██████████| 170M/170M [00:02<00:00, 81.6MB/s] 


# Train Model

In [7]:
from tqdm import tqdm

def train_model(model, train_loader, test_loader, optimizer, scheduler, 
                criterion, start_epoch, num_epochs, device, model_name, best_acc=0.0):
    """
    Train ``model`` from ``start_epoch`` up to ``num_epochs``, validating and
    checkpointing after every epoch.

    Each epoch runs one pass over ``train_loader``, one pass over
    ``test_loader`` under ``torch.no_grad()``, steps ``scheduler`` once, then
    writes a checkpoint to the global ``checkpoint_path``. When validation
    accuracy exceeds ``best_acc`` a second copy goes to the global
    ``checkpoint_path_best``.

    Args:
        model: Network to optimise.
        train_loader: Loader yielding ``(inputs, labels)`` training batches.
        test_loader: Loader yielding ``(inputs, labels)`` validation batches.
        optimizer: Optimizer stepped once per training batch.
        scheduler: LR scheduler stepped once per epoch.
        criterion: Loss called as ``criterion(outputs, labels)``.
        start_epoch (int): Zero-based index of the first epoch to run.
        num_epochs (int): Exclusive upper bound of the epoch range.
        device: Device every batch is moved to.
        model_name (str): Label used in progress bars and log lines.
        best_acc (float, optional): Best validation accuracy seen so far. Pass
            the previous best when resuming so the best-checkpoint file is
            only replaced by a real improvement. Defaults to 0.0.
    """
    for epoch in range(start_epoch, num_epochs):
        # Training phase.
        model.train()
        running_loss = 0.0
        train_correct = 0
        train_total = 0
        train_loader_tqdm = tqdm(train_loader, desc=f"{model_name} Epoch {epoch+1}/{num_epochs} - Training")
        for inputs, labels in train_loader_tqdm:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            # Weight the batch loss by batch size so the epoch figure is a
            # per-sample mean even when the last batch is smaller.
            running_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            train_total += labels.size(0)
            train_correct += predicted.eq(labels).sum().item()
            train_loader_tqdm.set_postfix(loss=running_loss/train_total, acc=100.*train_correct/train_total)

        # Divide by the samples actually seen; this equals len(dataset) unless
        # the loader drops or subsamples data.
        epoch_train_loss = running_loss / train_total
        epoch_train_acc = 100. * train_correct / train_total
        print(f"\n{model_name} Epoch [{epoch+1}/{num_epochs}] Training Loss: {epoch_train_loss:.4f} | Accuracy: {epoch_train_acc:.2f}%")

        # Only consumed by the optional wandb logging below.
        current_lr = optimizer.param_groups[0]['lr']

        # Validation phase.
        model.eval()
        test_loss = 0.0
        correct = 0
        total = 0
        test_loader_tqdm = tqdm(test_loader, desc=f"{model_name} Epoch {epoch+1}/{num_epochs} - Validation")
        with torch.no_grad():
            for inputs, labels in test_loader_tqdm:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                test_loss += loss.item() * inputs.size(0)
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
                test_loader_tqdm.set_postfix(loss=test_loss/total, acc=100.*correct/total)

        val_loss = test_loss / total
        val_acc = 100. * correct / total
        print(f"{model_name} Epoch [{epoch+1}/{num_epochs}] Validation Loss: {val_loss:.4f} | Accuracy: {val_acc:.2f}%\n")

        # wandb.log({
        #     "epoch": epoch,
        #     "train_loss": epoch_train_loss,
        #     "train_acc": epoch_train_acc,
        #     "val_loss": val_loss,
        #     "val_acc": val_acc,
        #     "learning_rate": current_lr
        # })

        scheduler.step()

        # Store the epoch's mean validation loss (a plain float). The loop
        # variable `loss` would be the last validation batch's tensor.
        save_checkpoint(model, optimizer, epoch, val_loss, scheduler, checkpoint_path)
        print(f"Checkpoint saved at epoch {epoch+1}, accuracy: {val_acc:.2f}%")

        if val_acc > best_acc:
            best_acc = val_acc
            save_checkpoint(model, optimizer, epoch, val_loss, scheduler, checkpoint_path_best)
            print(f"✅ Best checkpoint saved at epoch {epoch+1}, accuracy: {val_acc:.2f}%")

In [8]:
import torch.optim as optim

# `is_available` has to be called: the bare function object is always truthy,
# which selected "cuda" even on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

VisionMamba_model = Vim(
                        image_size=32, 
                        patch_size=4, 
                        depth=6, 
                        dim=192, 
                        channels=3, 
                        num_classes=10,
                        dropout=0.1,
                        dt_rank=8,
                        d_state=256,
                        dim_inner=192
                    ).to(device)

# With more than one GPU the model is wrapped in DataParallel; note that this
# prefixes every state_dict key with "module.".
print("Available GPUs:", torch.cuda.device_count())
if torch.cuda.device_count() > 1:
    VisionMamba_model = nn.DataParallel(VisionMamba_model)

# Number of trainable parameters
trainable_params = sum(p.numel() for p in VisionMamba_model.parameters() if p.requires_grad)
print(f"Trainable parameters: {trainable_params}")

Available GPUs: 2
Trainable parameters: 1586698


In [9]:
# Optimisation hyper-parameters.
lr = 7e-4
weight_decay = 0.01
num_epochs = 200
start_epoch = 0  # replaced below when a checkpoint is found

# Loss, optimizer and a cosine schedule spanning the whole run.
criterion = nn.CrossEntropyLoss(label_smoothing=0.2)
optimizer = optim.AdamW(
    VisionMamba_model.parameters(),
    lr=lr,
    weight_decay=weight_decay,
)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=num_epochs,
    eta_min=5e-6,
)

# Resume from a previous session's checkpoint if that file is present.
prev_checkpoint = "/kaggle/input/visionmambademo13/pytorch/default/1/model_checkpoint.pth"
(
    VisionMamba_model,
    optimizer,
    scheduler,
    start_epoch,
    last_loss,
) = load_checkpoint(VisionMamba_model, optimizer, scheduler, prev_checkpoint)

Checkpoint loaded! Resuming from Epoch 117, Loss: 1.417494535446167


In [10]:
# Print the module tree to inspect the architecture (including the
# DataParallel wrapper when one was applied).
print(VisionMamba_model)

DataParallel(
  (module): Vim(
    (to_patch_embedding): Sequential(
      (0): Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=4, p2=4)
      (1): Linear(in_features=48, out_features=192, bias=True)
    )
    (dropout): Dropout(p=0.1, inplace=False)
    (to_latent): Identity()
    (layers): ModuleList(
      (0-5): 6 x VisionEncoderMambaBlock(
        (forward_conv1d): Conv1d(192, 192, kernel_size=(1,), stride=(1,))
        (backward_conv1d): Conv1d(192, 192, kernel_size=(1,), stride=(1,))
        (norm): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
        (silu): SiLU()
        (ssm): SSM(
          (deltaBC_layer): Linear(in_features=192, out_features=520, bias=False)
          (dt_proj_layer): Linear(in_features=8, out_features=192, bias=True)
        )
        (proj): Linear(in_features=192, out_features=192, bias=True)
        (softplus): Softplus(beta=1.0, threshold=20.0)
      )
    )
    (output_head): Sequential(
      (0): Reduce('b s d -> b d', 'mean')
    

In [None]:
# Train the model, starting at `start_epoch` (as returned by load_checkpoint)
# and running up to `num_epochs`.
train_model(VisionMamba_model, train_loader, test_loader, optimizer, scheduler, criterion,
        start_epoch, num_epochs, device, model_name="VisionMamba")

VisionMamba Epoch 118/200 - Training: 100%|██████████| 1563/1563 [22:44<00:00,  1.15it/s, acc=92.4, loss=1.01]



VisionMamba Epoch [118/200] Training Loss: 1.0148 | Accuracy: 92.42%


VisionMamba Epoch 118/200 - Validation: 100%|██████████| 313/313 [01:33<00:00,  3.35it/s, acc=81.5, loss=1.22]


VisionMamba Epoch [118/200] Validation Loss: 1.2161 | Accuracy: 81.45%

Checkpoint saved at /kaggle/working/model_checkpoint.pth
Checkpoint saved at epoch 118, accuracy: 81.45%
Checkpoint saved at /kaggle/working/model_bestcheckpoint.pth
✅ Best checkpoint saved at epoch 118, accuracy: 81.45%


VisionMamba Epoch 119/200 - Training: 100%|██████████| 1563/1563 [22:43<00:00,  1.15it/s, acc=92.5, loss=1.01]



VisionMamba Epoch [119/200] Training Loss: 1.0123 | Accuracy: 92.50%


VisionMamba Epoch 119/200 - Validation: 100%|██████████| 313/313 [01:32<00:00,  3.37it/s, acc=82.8, loss=1.19]


VisionMamba Epoch [119/200] Validation Loss: 1.1948 | Accuracy: 82.85%

Checkpoint saved at /kaggle/working/model_checkpoint.pth
Checkpoint saved at epoch 119, accuracy: 82.85%
Checkpoint saved at /kaggle/working/model_bestcheckpoint.pth
✅ Best checkpoint saved at epoch 119, accuracy: 82.85%


VisionMamba Epoch 120/200 - Training: 100%|██████████| 1563/1563 [22:42<00:00,  1.15it/s, acc=92.3, loss=1.01]



VisionMamba Epoch [120/200] Training Loss: 1.0125 | Accuracy: 92.32%


VisionMamba Epoch 120/200 - Validation: 100%|██████████| 313/313 [01:32<00:00,  3.37it/s, acc=82.7, loss=1.2] 


VisionMamba Epoch [120/200] Validation Loss: 1.1972 | Accuracy: 82.72%

Checkpoint saved at /kaggle/working/model_checkpoint.pth
Checkpoint saved at epoch 120, accuracy: 82.72%


VisionMamba Epoch 121/200 - Training: 100%|██████████| 1563/1563 [22:42<00:00,  1.15it/s, acc=92.8, loss=1.01]



VisionMamba Epoch [121/200] Training Loss: 1.0068 | Accuracy: 92.76%


VisionMamba Epoch 121/200 - Validation: 100%|██████████| 313/313 [01:32<00:00,  3.37it/s, acc=82.6, loss=1.21]


VisionMamba Epoch [121/200] Validation Loss: 1.2085 | Accuracy: 82.58%

Checkpoint saved at /kaggle/working/model_checkpoint.pth
Checkpoint saved at epoch 121, accuracy: 82.58%


VisionMamba Epoch 122/200 - Training: 100%|██████████| 1563/1563 [22:42<00:00,  1.15it/s, acc=92.7, loss=1.01]



VisionMamba Epoch [122/200] Training Loss: 1.0075 | Accuracy: 92.66%


VisionMamba Epoch 122/200 - Validation: 100%|██████████| 313/313 [01:32<00:00,  3.37it/s, acc=82.4, loss=1.21]


VisionMamba Epoch [122/200] Validation Loss: 1.2078 | Accuracy: 82.40%

Checkpoint saved at /kaggle/working/model_checkpoint.pth
Checkpoint saved at epoch 122, accuracy: 82.40%


VisionMamba Epoch 123/200 - Training: 100%|██████████| 1563/1563 [22:42<00:00,  1.15it/s, acc=92.5, loss=1.01]



VisionMamba Epoch [123/200] Training Loss: 1.0095 | Accuracy: 92.48%


VisionMamba Epoch 123/200 - Validation: 100%|██████████| 313/313 [01:32<00:00,  3.37it/s, acc=82.2, loss=1.21]


VisionMamba Epoch [123/200] Validation Loss: 1.2099 | Accuracy: 82.15%

Checkpoint saved at /kaggle/working/model_checkpoint.pth
Checkpoint saved at epoch 123, accuracy: 82.15%


VisionMamba Epoch 124/200 - Training:  16%|█▌        | 249/1563 [03:37<19:04,  1.15it/s, acc=92.8, loss=1]   

In [None]:
def evaluate_model(model, test_loader, criterion, device):
    """
    Measure loss and top-1 accuracy of ``model`` over ``test_loader``.

    The model is put in eval mode and gradients are disabled. A summary line
    is printed once the loader is exhausted.

    Args:
        model: Network to evaluate.
        test_loader: Loader yielding ``(inputs, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device every batch is moved to.

    Returns:
        float: Accuracy in percent.
    """
    model.eval()
    loss_sum, hits, seen = 0.0, 0, 0

    progress = tqdm(test_loader, desc="Evaluating")
    with torch.no_grad():
        for batch_inputs, batch_labels in progress:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_inputs)
            batch_loss = criterion(logits, batch_labels)

            # Accumulate a size-weighted loss and the number of correct hits.
            loss_sum += batch_loss.item() * batch_inputs.size(0)
            predicted = logits.max(1)[1]
            seen += batch_labels.size(0)
            hits += predicted.eq(batch_labels).sum().item()

            progress.set_postfix(loss=loss_sum / seen, acc=100. * hits / seen)

    avg_loss = loss_sum / len(test_loader.dataset)
    accuracy = 100. * hits / seen
    print(f"\n Test Loss: {avg_loss:.4f} | Test Accuracy: {accuracy:.2f}%")
    return accuracy

In [None]:
# Evaluate the model currently in memory (not the saved best checkpoint) on
# the test loader.
evaluate_model(VisionMamba_model, test_loader, criterion, device)

In [None]:
import matplotlib.pyplot as plt
import random

# Undo mean/std normalisation on an image tensor
def denormalize(img_tensor, mean, std):
    """
    Invert ``(x - mean) / std`` channel-wise.

    Args:
        img_tensor (Tensor): Image laid out as ``(channels, height, width)``.
        mean: Per-channel means (sequence or tensor, one entry per channel).
        std: Per-channel standard deviations, same layout as ``mean``.

    Returns:
        Tensor: ``img_tensor * std + mean`` on ``img_tensor``'s device.
    """
    # Build the statistics on the image's device so GPU tensors work too, and
    # reshape to (C, 1, 1) for any channel count rather than exactly three.
    mean = torch.as_tensor(mean, device=img_tensor.device).view(-1, 1, 1)
    std = torch.as_tensor(std, device=img_tensor.device).view(-1, 1, 1)
    return img_tensor * std + mean

def visualize_prediction(model, train_loader, device, class_names=None,
                         mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)):
    """
    Show one random image from the first batch of a loader together with the
    model's prediction and the true label.

    Args:
        model: Network used for the prediction (switched to eval mode).
        train_loader: Any loader yielding ``(images, labels)`` batches; despite
            the name it need not be the training loader.
        device: Device the selected image is moved to for inference.
        class_names (sequence, optional): Maps label indices to names. When
            omitted or empty the raw indices are shown.
        mean (sequence, optional): Per-channel mean the loader's images were
            normalised with. Defaults to the CIFAR-10 values used in this
            notebook.
        std (sequence, optional): Per-channel std matching ``mean``.
    """
    model.eval()  # evaluation mode

    # Take the first batch the loader yields
    data_iter = iter(train_loader)
    images, labels = next(data_iter)

    # Pick one image of the batch at random
    idx = random.randint(0, len(images) - 1)
    image = images[idx].unsqueeze(0).to(device)
    label = labels[idx].item()

    # Predict
    with torch.no_grad():
        outputs = model(image)
        _, predicted = outputs.max(1)
        predicted = predicted.item()

    # Denormalize so the colours are displayed correctly
    img_denorm = denormalize(images[idx], mean=mean, std=std)
    img_np = img_denorm.permute(1, 2, 0).cpu().numpy()  # CHW -> HWC

    # Clamp to [0, 1]; floating-point arithmetic can overshoot slightly
    img_np = img_np.clip(0, 1)

    # Display
    plt.imshow(img_np)
    plt.axis('off')

    true_label = class_names[label] if class_names else str(label)
    pred_label = class_names[predicted] if class_names else str(predicted)
    plt.title(f"Predicted: {pred_label} | True: {true_label}")
    plt.show()


In [None]:
# Human-readable class names, indexed by label id.
# NOTE(review): presumably in CIFAR-10 label order — verify against
# `train_dataset.classes`.
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# Show one random test image with its predicted and true label.
visualize_prediction(VisionMamba_model, test_loader, device, class_names)

In [None]:
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.metrics import confusion_matrix
import os

def show_confusion_matrix(model, dataloader, class_names, device):
    """
    Plot the confusion matrix of ``model`` over ``dataloader``.

    Args:
        model: Network used for the predictions (switched to eval mode).
        dataloader: Loader yielding ``(inputs, labels)`` batches.
        class_names (sequence): One name per class; entry ``i`` labels class
            index ``i`` on both axes.
        device: Device every batch is moved to.
    """
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc="Generating Confusion Matrix"):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Fix the label set explicitly: otherwise a class that never occurs in
    # either list is dropped and the tick labels no longer line up.
    label_ids = list(range(len(class_names)))
    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)

    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.title('Confusion Matrix')
    plt.tight_layout()
    plt.show()  # display inline instead of saving to disk

In [None]:
# Confusion matrix of the in-memory model on the test loader.
show_confusion_matrix(VisionMamba_model, test_loader, class_names, device)

In [None]:
from sklearn.metrics import classification_report

def print_classification_report(model, dataloader, class_names, device):
    """
    Print scikit-learn's classification report for ``model`` over ``dataloader``.

    Args:
        model: Network used for the predictions (switched to eval mode).
        dataloader: Loader yielding ``(inputs, labels)`` batches.
        class_names (sequence): One name per class; entry ``i`` names class
            index ``i``.
        device: Device every batch is moved to.
    """
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc="Evaluating for Classification Report"):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Pin the label set so `target_names` always matches, even when a class
    # is missing from both the labels and the predictions.
    label_ids = list(range(len(class_names)))

    print()
    print(classification_report(all_labels, all_preds, labels=label_ids,
                                target_names=class_names, digits=2))

In [None]:
# Per-class precision / recall / F1 of the in-memory model on the test loader.
print_classification_report(VisionMamba_model, test_loader, class_names, device)