# MobileViT Training with Optuna Optimization

This notebook trains a MobileViT model on the Kvasir-V2 dataset with Optuna TPE hyperparameter optimization.

## Features:
- Custom MobileViT architecture implementation
- Optuna TPE hyperparameter optimization
- Comprehensive metrics tracking
- ROC curve analysis

## Setup Instructions:
1. **Enable GPU**: Runtime → Change runtime type → GPU (T4 or better)
2. **Run all cells sequentially**

## 1. Check GPU Availability

In [None]:
# Check GPU
!nvidia-smi

import torch

print("\nPyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("CUDA version:", torch.version.cuda)
    print("GPU:", torch.cuda.get_device_name(0))
    print("GPU count:", torch.cuda.device_count())

## 2. Install Required Packages

In [None]:
# PyTorch stack, pulled from the PyTorch wheel index for CUDA 11.8 (cu118); -q keeps pip output short.
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# Optuna drives the hyperparameter search in the later cells.
!pip install -q optuna
# Metrics, plotting, data-frame and progress-bar dependencies.
!pip install -q scikit-learn matplotlib seaborn pandas numpy tqdm

# NOTE(review): presumably printed even when a `!pip` line fails, since shell lines do not
# raise in the notebook — check the pip output above before trusting this message.
print("‚úÖ All packages installed successfully!")

## 3. Download Kvasir-V2 Dataset

In [None]:
import os
import zipfile
import urllib.request

# Download Kvasir-V2 dataset (skipped when the extracted folder is already present)
dataset_url = "https://datasets.simula.no/downloads/kvasir/kvasir-dataset-v2.zip"
dataset_zip = "kvasir-dataset-v2.zip"

if not os.path.exists("kvasir-dataset-v2"):
    print("Downloading Kvasir dataset...")
    urllib.request.urlretrieve(dataset_url, dataset_zip)

    print("Extracting dataset...")
    with zipfile.ZipFile(dataset_zip, 'r') as zip_ref:
        zip_ref.extractall(".")

    # The archive is no longer needed once its contents are on disk.
    os.remove(dataset_zip)
    print("‚úÖ Dataset ready!")
else:
    print("‚úÖ Dataset already exists!")

# Verify dataset
data_dir = "kvasir-dataset-v2/kvasir-dataset-v2"
# Only sub-directories are classes; stray files (e.g. OS metadata) must not be
# counted, otherwise the reported class count is wrong.
classes = sorted(
    entry for entry in os.listdir(data_dir)
    if os.path.isdir(os.path.join(data_dir, entry))
)
print(f"\nFound {len(classes)} classes: {classes}")

## 4. Define MobileViT Architecture

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Helper functions
def conv_1x1_bn(inp, oup):
    """Build a pointwise (1x1) convolution block.

    Args:
        inp: Number of input channels.
        oup: Number of output channels.

    Returns:
        ``nn.Sequential`` of bias-free 1x1 ``Conv2d``, ``BatchNorm2d`` and ``Hardswish``.
    """
    pointwise = nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False)
    layers = [pointwise, nn.BatchNorm2d(oup), nn.Hardswish()]
    return nn.Sequential(*layers)

def conv_nxn_bn(inp, oup, kernal_size=3, stride=1):
    """Build an NxN convolution block: Conv2d -> BatchNorm2d -> Hardswish.

    The padding is derived from the kernel size (``kernal_size // 2``) so that odd
    kernels keep the spatial size at stride 1. For the default 3x3 kernel this is
    the same padding of 1 that was previously hard-coded.

    Args:
        inp: Number of input channels.
        oup: Number of output channels.
        kernal_size: Kernel size, an int or a ``(kh, kw)`` pair. The misspelled
            name is kept because callers pass it by keyword.
        stride: Convolution stride.

    Returns:
        ``nn.Sequential`` of bias-free ``Conv2d``, ``BatchNorm2d`` and ``Hardswish``.
    """
    if isinstance(kernal_size, int):
        padding = kernal_size // 2
    else:
        # Per-axis padding for rectangular kernels.
        padding = tuple(k // 2 for k in kernal_size)

    return nn.Sequential(
        nn.Conv2d(inp, oup, kernal_size, stride, padding, bias=False),
        nn.BatchNorm2d(oup),
        nn.Hardswish()
    )

class PreNorm(nn.Module):
    """Normalize the input with ``LayerNorm(dim)`` and hand it to a wrapped callable."""

    def __init__(self, dim, fn):
        """Store the wrapped callable ``fn`` and create a LayerNorm over ``dim`` features."""
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        """Return ``fn(norm(x), *args, **kwargs)``; extra arguments pass through untouched."""
        normalized = self.norm(x)
        return self.fn(normalized, *args, **kwargs)

class FeedForward(nn.Module):
    """Two-layer MLP: expand to ``hidden_dim``, Hardswish, project back to ``dim``.

    Dropout with probability ``dropout`` follows the activation and the output
    projection.
    """

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        expand = nn.Linear(dim, hidden_dim)
        project = nn.Linear(hidden_dim, dim)
        stages = [
            expand,
            nn.Hardswish(),
            nn.Dropout(dropout),
            project,
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        return self.net(x)

class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention over a ``(batch, tokens, dim)`` input.

    Args:
        dim: Feature size of each input token.
        heads: Number of attention heads.
        dim_head: Feature size of each head; the inner width is ``heads * dim_head``.
        dropout: Dropout probability applied to the attention weights and, when an
            output projection is used, to the projected output.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        self.inner_dim = dim_head * heads
        # With a single head whose width already equals `dim`, no projection is needed.
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        # One linear layer produces queries, keys and values side by side.
        self.to_qkv = nn.Linear(dim, self.inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(self.inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        """Attend over the token axis of ``x``.

        Args:
            x: Tensor of shape ``(batch, tokens, dim)``.

        Returns:
            Tensor of shape ``(batch, tokens, dim)``.
        """
        batch, tokens, _ = x.shape

        # (batch, tokens, inner_dim) -> (batch, heads, tokens, dim_head) for q, k and v.
        q, k, v = (
            t.reshape(batch, tokens, self.heads, -1).transpose(1, 2)
            for t in self.to_qkv(x).chunk(3, dim=-1)
        )

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        # Merge the heads back using the token count captured above. Reading the
        # size from `out` before the transpose would pick up the head count and
        # only work when tokens == heads.
        out = out.transpose(1, 2).reshape(batch, tokens, self.inner_dim)
        return self.to_out(out)

class TransformerBlock(nn.Module):
    """Transformer layer: pre-normalized attention and feed-forward, each added residually."""

    def __init__(self, dim, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        attention = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
        feed_forward = FeedForward(dim, mlp_dim, dropout=dropout)
        self.attn = PreNorm(dim, attention)
        self.ff = PreNorm(dim, feed_forward)

    def forward(self, x):
        """Return ``x`` after the attention residual followed by the feed-forward residual."""
        after_attention = x + self.attn(x)
        return after_attention + self.ff(after_attention)

class MobileNetV2Block(nn.Module):
    """Inverted-residual block: optional 1x1 expansion, 3x3 depthwise conv, 1x1 linear projection.

    Args:
        inp: Number of input channels.
        oup: Number of output channels.
        stride: Stride of the depthwise convolution.
        expansion: Channel multiplier for the hidden width; when it is 1 the
            expansion convolution is omitted.
    """

    def __init__(self, inp, oup, stride, expansion):
        super().__init__()
        self.stride = stride
        hidden_dim = int(round(inp * expansion))
        # The skip connection is only valid when input and output shapes match.
        self.use_res_connect = self.stride == 1 and inp == oup

        layers = []
        if expansion != 1:
            # Pointwise expansion to the hidden width.
            layers += [
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.Hardswish(),
            ]
        layers += [
            # Depthwise 3x3 convolution (one group per channel).
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.Hardswish(),
            # Linear pointwise projection: no activation after the final norm.
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the block, adding the input back when the residual path is enabled."""
        out = self.conv(x)
        if self.use_res_connect:
            return x + out
        return out

class MobileViTBlock(nn.Module):
    """Convolution + transformer block operating on a (B, channel, H, W) feature map.

    The input is convolved, projected to ``dim`` channels, cut into ``ph x pw``
    patches that are run through a stack of transformer layers, projected back to
    ``channel`` channels, concatenated with the untouched input and fused by a
    final convolution. The output has ``channel`` channels.

    Args:
        dim: Channel width used inside the transformer.
        depth: Number of stacked ``TransformerBlock`` layers.
        channel: Channel count of the block's input and output.
        kernal_size: Kernel size passed to the two ``conv_nxn_bn`` layers.
        patch_size: ``(ph, pw)`` pair giving the patch height and width.
        mlp_dim: Hidden width of each transformer feed-forward layer.
        dropout: Dropout probability forwarded to the transformer layers.
    """

    def __init__(self, dim, depth, channel, kernal_size, patch_size, mlp_dim, dropout=0.):
        super().__init__()
        self.ph, self.pw = patch_size

        self.conv1 = conv_nxn_bn(channel, channel, kernal_size)
        self.conv2 = conv_1x1_bn(channel, dim)
        # Every transformer layer is built with heads=4 and dim_head=8 (fixed here, not configurable).
        self.transformer = nn.Sequential(*[TransformerBlock(dim, 4, 8, mlp_dim, dropout) for _ in range(depth)])
        self.conv3 = conv_1x1_bn(dim, channel)
        # Takes 2 * channel inputs because forward() concatenates the transformer path with the block input.
        self.conv4 = conv_nxn_bn(2 * channel, channel, kernal_size)
    
    def forward(self, x):
        """Run the block on ``x`` and return the fused feature map."""
        # Keep a copy of the input for the concatenation at the end.
        y = x.clone()
        x = self.conv1(x)
        x = self.conv2(x)
        
        # Remember the unpadded shape so any padding can be cropped off again below.
        shape = x.shape
        B, D, H, W = shape

        # Amount needed to round H and W up to a multiple of the patch size (0 when already aligned).
        pad_h = (self.ph - (H % self.ph)) % self.ph
        pad_w = (self.pw - (W % self.pw)) % self.pw

        if pad_h > 0 or pad_w > 0:
            # Zero-pad on the right and bottom edges only.
            x = F.pad(x, (0, pad_w, 0, pad_h), mode='constant', value=0)
            H = H + pad_h
            W = W + pad_w

        # Number of patches along each spatial axis.
        h_split, w_split = H // self.ph, W // self.pw

        # (B, D, H, W) -> (B, D, h_split, ph, w_split, pw) -> (B, h_split, w_split, ph, pw, D)
        x = x.reshape(B, D, h_split, self.ph, w_split, self.pw)
        x = x.permute(0, 2, 4, 3, 5, 1).contiguous()
        # Each patch becomes one sequence whose ph * pw pixels are the tokens.
        x = x.view(B * h_split * w_split, self.ph * self.pw, D)

        x = self.transformer(x)

        # Undo the patch layout: back to (B, D, H, W).
        x = x.view(B, h_split, w_split, self.ph, self.pw, D)
        x = x.permute(0, 5, 1, 3, 2, 4).contiguous()
        x = x.view(B, D, H, W)

        if pad_h > 0 or pad_w > 0:
            # Crop back to the spatial size recorded before padding.
            x = x[:, :, :shape[2], :shape[3]]

        x = self.conv3(x)
        # Concatenate along the channel axis with the saved input, then fuse.
        x = torch.cat((x, y), 1)
        x = self.conv4(x)
        return x

class MobileViT(nn.Module):
    """Image classifier: strided conv, MobileNetV2 stem, MobileViT stages, pooled linear head.

    Args:
        image_size: Input image side length; must be divisible by ``patch_size``.
        num_classes: Number of output logits.
        dims: Transformer width of each MobileViT stage (one stage per entry).
        channels: Channel widths; index 0 feeds the first conv, 1-4 the stem, and
            index ``i + 5`` the down-sampling block that follows stage ``i``.
        depths: Number of transformer layers in each stage.
        patch_size: Side length of the square patches used by every stage.
        mlp_dim_ratio: Feed-forward hidden width as a multiple of the stage width.
    """

    def __init__(self, image_size, num_classes, dims, channels, depths, patch_size=2, mlp_dim_ratio=2):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'

        patch_hw = (patch_size, patch_size)

        self.conv1 = conv_nxn_bn(3, channels[0], kernal_size=3, stride=2)

        # Stem: four inverted-residual blocks with strides 1, 2, 1, 2.
        self.stem = nn.ModuleList([])
        in_channel = channels[0]
        for channel_idx, stride in enumerate((1, 2, 1, 2), start=1):
            out_channel = channels[channel_idx]
            self.stem.append(MobileNetV2Block(in_channel, out_channel, stride=stride, expansion=4))
            in_channel = out_channel

        # MobileViT stages, each followed by a stride-2 block except the last.
        self.mobilevit_blocks = nn.ModuleList([])
        last_stage = len(dims) - 1
        for stage, dim in enumerate(dims):
            self.mobilevit_blocks.append(
                MobileViTBlock(dim, depths[stage], in_channel, 3, patch_hw, dim * mlp_dim_ratio)
            )

            if stage < last_stage:
                out_channel = channels[stage + 5]
                self.mobilevit_blocks.append(
                    MobileNetV2Block(in_channel, out_channel, stride=2, expansion=4)
                )
                in_channel = out_channel

        self.conv2 = conv_1x1_bn(in_channel, dims[-1])
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(dims[-1], num_classes)

    def forward(self, x):
        """Return class logits for an image batch ``x``."""
        x = self.conv1(x)
        for layer in (*self.stem, *self.mobilevit_blocks):
            x = layer(x)

        features = self.conv2(x)
        pooled = torch.flatten(self.pool(features), 1)
        return self.fc(pooled)

print("‚úÖ MobileViT architecture defined!")

## 5. Define Dataset Class

In [None]:
from PIL import Image
from torch.utils.data import Dataset

class KvasirDataset(Dataset):
    """Image-folder dataset: every sub-directory of ``root_dir`` is one class.

    Attributes are plain Python containers filled at construction time:
    ``classes`` (sorted class names), ``class_to_idx`` (name -> label index),
    ``image_paths`` and ``labels`` (parallel lists, one entry per image).

    Args:
        root_dir: Directory whose sub-directories hold the images of each class.
        transform: Optional callable applied to the RGB PIL image in ``__getitem__``.
    """

    # File extensions (matched case-insensitively) that count as images.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # Only directories are classes. A stray file in root_dir would otherwise
        # take a label index and shift the labels of every class sorted after it.
        self.classes = sorted(
            entry for entry in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, entry))
        )
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}
        self.image_paths = []
        self.labels = []

        for class_name in self.classes:
            class_dir = os.path.join(root_dir, class_name)
            label = self.class_to_idx[class_name]
            # os.listdir order is arbitrary; sorting keeps sample indices stable
            # so that seeded splits are reproducible.
            for img_name in sorted(os.listdir(class_dir)):
                if img_name.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.image_paths.append(os.path.join(class_dir, img_name))
                    self.labels.append(label)

    def __len__(self):
        """Return the number of images found."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, apply the transform if any, and return ``(image, label)``."""
        img_path = self.image_paths[idx]
        # Close the file handle as soon as the pixels have been converted.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

print("‚úÖ Dataset class defined!")

## 6. Define Training and Validation Functions

In [None]:
from tqdm import tqdm
from sklearn.metrics import accuracy_score
import torch.optim as optim

def train(model, dataloader, criterion, optimizer, device):
    """Run one training pass over ``dataloader``.

    Args:
        model: Network to optimize; switched to training mode.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(mean of the per-batch losses, accuracy over all samples)``.
    """
    model.train()
    loss_total = 0.0
    targets, predictions = [], []

    for images, labels in tqdm(dataloader, desc="Training"):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(images)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        loss_total += batch_loss.item()
        # Index of the largest logit is the predicted class.
        predicted = torch.max(logits.data, 1)[1]
        targets.extend(labels.cpu().numpy())
        predictions.extend(predicted.cpu().numpy())

    mean_loss = loss_total / len(dataloader)
    return mean_loss, accuracy_score(targets, predictions)

def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without computing gradients.

    Args:
        model: Network to evaluate; switched to eval mode.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(mean of the per-batch losses, accuracy, labels, probabilities)``
        where the last two are lists with one entry per sample and each
        probability entry is the softmax over that sample's logits.
    """
    model.eval()
    loss_total = 0.0
    targets, predictions, class_probs = [], [], []

    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Validation"):
            images = images.to(device)
            labels = labels.to(device)
            logits = model(images)
            loss_total += criterion(logits, labels).item()

            batch_probs = F.softmax(logits, dim=1)
            predicted = torch.max(logits.data, 1)[1]

            targets.extend(labels.cpu().numpy())
            predictions.extend(predicted.cpu().numpy())
            class_probs.extend(batch_probs.cpu().numpy())

    mean_loss = loss_total / len(dataloader)
    return mean_loss, accuracy_score(targets, predictions), targets, class_probs

print("‚úÖ Training functions defined!")

## 7. Define Optuna Objective Function

In [None]:
import optuna
import numpy as np
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score, roc_curve
import time

def objective(trial):
    """Optuna objective: train one MobileViT configuration and score it on validation data.

    Sampled per trial: batch size, learning rate (log scale) and the transformer
    depth of each of the three MobileViT stages. Image size, class count, widths
    and epoch count are fixed.

    Side effects: prints progress, writes the best-epoch weights to
    ``mobilevit_kvasir_v2_trial_<n>.pth`` and stores the test-set metrics on the
    trial as user attributes (``test_loss``, ``test_accuracy``, ``test_precision``,
    ``test_recall``, ``test_f1``).

    Args:
        trial: Optuna trial used to sample hyperparameters, report per-epoch
            validation accuracy and query the pruner.

    Returns:
        The highest validation accuracy reached during the trial.

    Raises:
        optuna.exceptions.TrialPruned: If the pruner stops the trial early.
    """
    # Fixed settings
    image_size = 224
    num_classes = 8
    dataset_root = 'kvasir-dataset-v2/kvasir-dataset-v2'
    split_seed = 42  # same train/val/test partition for every trial

    # Hyperparameters to tune
    batch_size = trial.suggest_categorical('batch_size', [16, 32, 64])
    learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-2, log=True)
    num_epochs = 5  # Fixed for faster optimization

    # MobileViT architectural hyperparameters
    depths = [
        trial.suggest_int('depth_0', 1, 3),
        trial.suggest_int('depth_1', 2, 5),
        trial.suggest_int('depth_2', 2, 4)
    ]
    dims = [96, 192, 384]
    channels = [16, 32, 64, 128, 160, 192, 256, 320, 384, 512]

    # Device configuration
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    print(f"Using device: {device} for trial {trial.number}")

    # Data transformations
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Dataset and DataLoaders
    full_dataset = KvasirDataset(root_dir=dataset_root, transform=transform)

    # Split dataset: 70% train, 15% val, 15% test
    train_size = int(0.7 * len(full_dataset))
    val_size = int(0.15 * len(full_dataset))
    test_size = len(full_dataset) - train_size - val_size
    # A seeded generator keeps the split the same across trials. Without it every
    # trial is scored on a different validation set, so their accuracies are not
    # comparable. NOTE(review): this assumes the dataset lists its files in the
    # same order on every construction — confirm.
    train_dataset, val_dataset, test_dataset = random_split(
        full_dataset,
        [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(split_seed),
    )

    pin_memory = use_cuda
    num_workers = 2 if use_cuda else 0

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=pin_memory)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers, pin_memory=pin_memory)

    # Model initialization
    model = MobileViT(image_size=image_size, num_classes=num_classes,
                      dims=dims, channels=channels, depths=depths).to(device)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint; otherwise a trial stuck at 0.0 accuracy has nothing to reload.
    best_val_accuracy = -1.0
    model_save_path = f"mobilevit_kvasir_v2_trial_{trial.number}.pth"

    # Training loop
    for epoch in range(num_epochs):
        train_loss, train_accuracy = train(model, train_loader, criterion, optimizer, device)
        val_loss, val_accuracy, _, _ = validate(model, val_loader, criterion, device)

        print(f"  Trial {trial.number}, Epoch {epoch+1}/{num_epochs}: "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.4f} | "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.4f}")

        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            torch.save(model.state_dict(), model_save_path)

        trial.report(val_accuracy, epoch)

        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    # Load best model for final evaluation
    model.load_state_dict(torch.load(model_save_path, map_location=device))

    # Test evaluation
    test_loss, test_accuracy, test_labels, test_probs = validate(model, test_loader, criterion, device)

    test_preds = np.argmax(test_probs, axis=1)
    test_precision = precision_score(test_labels, test_preds, average='weighted')
    test_recall = recall_score(test_labels, test_preds, average='weighted')
    test_f1 = f1_score(test_labels, test_preds, average='weighted')

    # Keep the test metrics with the trial so they are still available after the
    # printed output is gone (plain floats so any storage can serialize them).
    trial.set_user_attr('test_loss', float(test_loss))
    trial.set_user_attr('test_accuracy', float(test_accuracy))
    trial.set_user_attr('test_precision', float(test_precision))
    trial.set_user_attr('test_recall', float(test_recall))
    trial.set_user_attr('test_f1', float(test_f1))

    print(f"\n--- Trial {trial.number} Results ---")
    print(f"  Test Accuracy: {test_accuracy:.4f}")
    print(f"  Test Precision: {test_precision:.4f}")
    print(f"  Test Recall: {test_recall:.4f}")
    print(f"  Test F1-score: {test_f1:.4f}")
    print("------------------------------------\n")

    return best_val_accuracy

print("‚úÖ Optuna objective function defined!")

## 8. Run Optuna Optimization

In [None]:
print("\nüöÄ Starting Optuna optimization...\n")

# Maximize validation accuracy; the median pruner stops trials that lag behind.
median_pruner = optuna.pruners.MedianPruner()
study = optuna.create_study(
    study_name="mobilevit_optimization",
    direction="maximize",
    pruner=median_pruner,
)

study.optimize(objective, n_trials=15)

banner = "=" * 60
print("\n" + banner)
print("OPTUNA OPTIMIZATION COMPLETE")
print(banner)
print(f"Number of finished trials: {len(study.trials)}")
print("\nBest trial:")
print(f"  Value (Val Accuracy): {study.best_value:.4f}")
print("\n  Best Hyperparameters:")
for param_name, param_value in study.best_params.items():
    print(f"    {param_name}: {param_value}")
print(banner)

## 9. Save Optimization Results

In [None]:
import pandas as pd

# Export every trial (one row each) to CSV
trials_df = study.trials_dataframe()
trials_df.to_csv('optuna_study_results.csv', index=False)

print("‚úÖ Optimization results saved to 'optuna_study_results.csv'")

# Show the five trials with the highest objective value
summary_columns = [
    'number',
    'value',
    'params_batch_size',
    'params_learning_rate',
    'params_depth_0',
    'params_depth_1',
    'params_depth_2',
]
print("\nTop 5 Trials:")
top_trials = trials_df.nlargest(5, 'value')
print(top_trials[summary_columns])

## 10. Plot Optimization History

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Plot optimization history
# Trials without an objective value (`value is None`) cannot be placed on the
# accuracy axis, so only scored trials are drawn.
scored_trials = [t for t in study.trials if t.value is not None]
axes[0].plot([t.number for t in scored_trials], [t.value for t in scored_trials], 'o-')
axes[0].axhline(y=study.best_value, color='r', linestyle='--', label=f'Best: {study.best_value:.4f}')
axes[0].set_xlabel('Trial Number')
axes[0].set_ylabel('Validation Accuracy')
axes[0].set_title('Optuna Optimization History')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Plot parameter importance (if available)
try:
    importance = optuna.importance.get_param_importances(study)
except Exception as exc:
    # Best effort: importance needs enough completed trials (and its own
    # dependencies). Catching Exception instead of a bare `except:` lets
    # KeyboardInterrupt through, and the reason is reported rather than hidden.
    print(f"Parameter importance not available: {exc}")
    axes[1].text(0.5, 0.5, 'Parameter importance\nnot available',
                 ha='center', va='center', fontsize=12)
    axes[1].axis('off')
else:
    params = list(importance.keys())
    values = list(importance.values())

    axes[1].barh(params, values)
    axes[1].set_xlabel('Importance')
    axes[1].set_title('Hyperparameter Importance')
    axes[1].grid(True, alpha=0.3, axis='x')

plt.tight_layout()
plt.savefig('optuna_history.png', dpi=300, bbox_inches='tight')
plt.show()

print("\n‚úÖ Optimization plots saved!")

## 11. Save Best Model

In [None]:
# Publish the checkpoint of the best trial under a fixed file name
best_trial = study.best_trial
best_model_path = "mobilevit_kvasir_v2_best_optuna.pth"

# Each trial wrote its own checkpoint; locate the one belonging to the best trial
import shutil
source_path = f"mobilevit_kvasir_v2_trial_{best_trial.number}.pth"
if not os.path.exists(source_path):
    print(f"‚ö†Ô∏è Best model file not found: {source_path}")
else:
    shutil.copy(source_path, best_model_path)
    print(f"‚úÖ Best model saved to {best_model_path}")
    print(f"   (from trial {best_trial.number})")

## 12. Download Results (Optional)

In [None]:
# Uncomment to download files to your local machine
# from google.colab import files

# files.download('mobilevit_kvasir_v2_best_optuna.pth')
# files.download('optuna_study_results.csv')
# files.download('optuna_history.png')

print("\n" + "="*60)
print("üéâ MOBILEVIT TRAINING COMPLETE!")
print("="*60)
# Artifacts written by the cells above, one line per file.
generated_files = (
    "  - mobilevit_kvasir_v2_best_optuna.pth (Best model)",
    "  - optuna_study_results.csv (All trials)",
    "  - optuna_history.png (Optimization plots)",
    "  - mobilevit_kvasir_v2_trial_*.pth (Individual trial models)",
)
print("\nGenerated files:")
for file_line in generated_files:
    print(file_line)
print("=" * 60)