In [14]:
!pip install -q transformers datasets accelerate evaluate scikit-learn

!pip install --upgrade sympy



In [15]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torchvision.transforms as transforms
from datasets import load_dataset
from transformers import (
    ViTImageProcessor,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback
)
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    confusion_matrix,
    classification_report
)
import evaluate
import json
import time
import warnings
import pandas as pd
import os
import random
import gc
warnings.filterwarnings('ignore')

In [16]:
# Create output directory
CSV_OUTPUT_DIR = './csv_results'
os.makedirs(CSV_OUTPUT_DIR, exist_ok=True)

# Set random seeds
def set_seed(seed=42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)

set_seed(42)

# Check device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\nDevice: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    !nvidia-smi


def clear_memory():
    """Clear GPU/CPU memory cache"""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()


Device: cuda
GPU: Tesla T4
Sun Dec 14 15:26:39 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   52C    P0             29W /   70W |    8480MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                    

In [17]:
print("Loading CIFAR-10 Dataset")

# Load CIFAR-10
dataset = load_dataset('cifar10')
labels = dataset['train'].features['label'].names
num_labels = len(labels)

print(f"\nDataset: CIFAR-10")
print(f"Classes: {labels}")
print(f"Training samples: {len(dataset['train'])}")
print(f"Test samples: {len(dataset['test'])}")

Loading CIFAR-10 Dataset

Dataset: CIFAR-10
Classes: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
Training samples: 50000
Test samples: 10000


In [18]:
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')

# Define augmentation transforms
train_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.RandomPerspective(distortion_scale=0.2, p=0.5),
])

def preprocess_images_train(examples):
    """Training preprocessing with augmentation"""
    images = []
    for img in examples['img']:
        img = img.convert('RGB')
        img = train_transforms(img)
        images.append(img)

    inputs = processor(images, return_tensors='pt')
    inputs['labels'] = examples['label']
    return inputs

def preprocess_images_val(examples):
    """Validation preprocessing without augmentation"""
    images = [img.convert('RGB') for img in examples['img']]
    inputs = processor(images, return_tensors='pt')
    inputs['labels'] = examples['label']
    return inputs

print("Preprocessing with augmentation")
train_ds = dataset['train'].with_transform(preprocess_images_train)
val_ds = dataset['test'].with_transform(preprocess_images_val)
print("Done")

Preprocessing with augmentation
Done


In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import ViTConfig, ViTModel
import math

class LayerScale(nn.Module):
    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        return x.mul_(self.gamma) if self.inplace else x * self.gamma

class SoftTargetCrossEntropy(nn.Module):
    def __init__(self):
        super(SoftTargetCrossEntropy, self).__init__()

    def forward(self, x, target):
        loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1)
        return loss.mean()

class FastWindowAttention(nn.Module):
    """
    Window based local attention
    Computes attention only within local windows for efficiency and locality bias.
    """
    def __init__(self, hidden_size, num_heads=8, window_size=7, shift_size=0):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.window_size = window_size

        self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True)
        self.proj = nn.Linear(hidden_size, hidden_size, bias=True)

    def window_partition(self, x, window_size):
        B, H, W, C = x.shape
        pad_h = (window_size - H % window_size) % window_size
        pad_w = (window_size - W % window_size) % window_size
        if pad_h > 0 or pad_w > 0:
            x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))

        Hp, Wp = x.shape[1], x.shape[2]
        x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
        windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
        return windows, Hp, Wp

    def window_reverse(self, windows, window_size, H, W, Hp, Wp):
        B = int(windows.shape[0] / (Hp * Wp / window_size / window_size))
        x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)

        if Hp > H or Wp > W:
            x = x[:, :H, :W, :]
        return x

    def forward(self, x, H, W):
        B, L, C = x.shape
        x_2d = x.view(B, H, W, C)

        windows, Hp, Wp = self.window_partition(x_2d, self.window_size)
        windows = windows.view(-1, self.window_size * self.window_size, C)

        qkv = self.qkv(windows).reshape(windows.shape[0], windows.shape[1], 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        attn_out = F.scaled_dot_product_attention(q, k, v)

        attn_out = attn_out.transpose(1, 2).reshape(windows.shape[0], windows.shape[1], C)
        x_out = self.window_reverse(attn_out, self.window_size, H, W, Hp, Wp)

        x_out = x_out.reshape(B, L, C)

        return self.proj(x_out)

class HybridLayer(nn.Module):
    def __init__(self, hidden_size, num_heads, window_size=7, dropout=0.1):
        super().__init__()
        self.local_attn = FastWindowAttention(hidden_size, num_heads, window_size=window_size)
        self.global_attn = nn.MultiheadAttention(hidden_size, num_heads, dropout=dropout, batch_first=True)

        self.gate = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.Sigmoid()
        )

        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.ls1 = LayerScale(hidden_size)
        self.ls2 = LayerScale(hidden_size)

        self.ffn = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size * 4, hidden_size),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        B, N, C = x.shape
        cls_token = x[:, 0:1]
        patches = x[:, 1:]

        H = W = int(math.sqrt(patches.shape[1]))

        x_norm = self.norm1(x)
        patches_norm = x_norm[:, 1:]

        # local window attention
        local_out_patches = self.local_attn(patches_norm, H, W)
        local_out = torch.cat([torch.zeros_like(cls_token), local_out_patches], dim=1)

        # global attention
        global_out, _ = self.global_attn(x_norm, x_norm, x_norm, need_weights=False)

        # gating mechanism
        combined = torch.cat([local_out, global_out], dim=-1)
        gate_score = self.gate(combined)
        attn_out = gate_score * local_out + (1 - gate_score) * global_out

        x = x + self.ls1(attn_out)
        x = x + self.ls2(self.ffn(self.norm2(x)))

        return x

class ViTWithHybridAttention(nn.Module):
    def __init__(self, base_model_name='google/vit-base-patch16-224', num_labels=10, num_hybrid_layers=2):
        super().__init__()
        self.vit = ViTModel.from_pretrained(base_model_name)
        config = self.vit.config

        self.hybrid_layers = nn.ModuleList([
            HybridLayer(config.hidden_size, config.num_attention_heads, window_size=3, dropout=0.1)
            for _ in range(num_hybrid_layers)
        ])

        self.classifier = nn.Linear(config.hidden_size, num_labels)

    def forward(self, pixel_values, labels=None):
        out = self.vit(pixel_values)
        x = out.last_hidden_state

        for layer in self.hybrid_layers:
            x = layer(x)

        # global average pooling
        patch_embeddings = x[:, 1:, :]
        global_pool = patch_embeddings.mean(dim=1)

        logits = self.classifier(global_pool)

        loss = None
        if labels is not None:
            if labels.dim() > 1:
                loss_fct = SoftTargetCrossEntropy()
            else:
                loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)

        return {'loss': loss, 'logits': logits}

def create_model(num_labels=10, num_hybrid_layers=2):
    """Create ViT with Hybrid Attention model"""
    return ViTWithHybridAttention(num_labels=num_labels, num_hybrid_layers=num_hybrid_layers)

def count_parameters(model):
    """Count trainable and total parameters"""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total

## 4. CutMix & MixUp Augmentation

In [29]:
class AugmentationConfig:
    """Config for augmentation settings"""
    def __init__(self, use_cutmix=True, use_mixup=True,
                 cutmix_prob=0.4, mixup_prob=0.2,
                 cutmix_alpha=0.8, mixup_alpha=0.6):
        self.use_cutmix = use_cutmix
        self.use_mixup = use_mixup
        self.cutmix_prob = cutmix_prob
        self.mixup_prob = mixup_prob
        self.cutmix_alpha = cutmix_alpha
        self.mixup_alpha = mixup_alpha

def rand_bbox(size, lam):
    """Generate random bounding box for CutMix"""
    W, H = size[2], size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

def cutmix_data(images, labels, alpha=1.0):
    """Apply CutMix augmentation"""
    lam = np.random.beta(alpha, alpha)
    batch_size = images.size(0)
    index = torch.randperm(batch_size).to(images.device)

    bbx1, bby1, bbx2, bby2 = rand_bbox(images.size(), lam)
    images[:, :, bbx1:bbx2, bby1:bby2] = images[index, :, bbx1:bbx2, bby1:bby2]

    # Adjust lambda based on actual box size
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (images.size()[-1] * images.size()[-2]))

    return images, labels, labels[index], lam

def mixup_data(images, labels, alpha=1.0):
    """Apply MixUp augmentation"""
    lam = np.random.beta(alpha, alpha)
    batch_size = images.size(0)
    index = torch.randperm(batch_size).to(images.device)

    mixed_images = lam * images + (1 - lam) * images[index]

    return mixed_images, labels, labels[index], lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Loss function for MixUp/CutMix"""
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

class AugmentedTrainer(Trainer):
    """Custom trainer with CutMix and MixUp support"""

    def __init__(self, *args, aug_config=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.aug_config = aug_config or AugmentationConfig()

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Custom loss with MixUp/CutMix - memory optimized"""
        inputs = dict(inputs)

        labels = inputs.pop("labels")


        if self.model.training:
            r = np.random.rand()

            if self.aug_config.use_cutmix and r < self.aug_config.cutmix_prob:
                # CutMix
                inputs['pixel_values'], labels_a, labels_b, lam = cutmix_data(
                    inputs['pixel_values'], labels, self.aug_config.cutmix_alpha
                )
                outputs = model(**inputs)

                logits = outputs['logits'] if isinstance(outputs, dict) else outputs.logits

                loss = lam * nn.functional.cross_entropy(logits, labels_a) + \
                       (1 - lam) * nn.functional.cross_entropy(logits, labels_b)

                # free memory
                del labels_a, labels_b

            elif self.aug_config.use_mixup and r < (self.aug_config.cutmix_prob + self.aug_config.mixup_prob):
                # MixUp
                inputs['pixel_values'], labels_a, labels_b, lam = mixup_data(
                    inputs['pixel_values'], labels, self.aug_config.mixup_alpha
                )
                outputs = model(**inputs)

                logits = outputs['logits'] if isinstance(outputs, dict) else outputs.logits

                loss = lam * nn.functional.cross_entropy(logits, labels_a) + \
                       (1 - lam) * nn.functional.cross_entropy(logits, labels_b)

                # Free memory
                del labels_a, labels_b

            else:
                # Normal training
                outputs = model(**inputs, labels=labels)

                loss = outputs['loss'] if isinstance(outputs, dict) else outputs.loss
        else:
            # Validation
            outputs = model(**inputs, labels=labels)
            loss = outputs['loss'] if isinstance(outputs, dict) else outputs.loss

        return (loss, outputs) if return_outputs else loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
          inputs = self._prepare_inputs(inputs)

          labels = None
          if self.label_names:
              labels = tuple(inputs.get(name).detach() for name in self.label_names if inputs.get(name) is not None)
              if len(labels) == 1:
                  labels = labels[0]
              elif len(labels) == 0:
                  labels = None

          has_labels = labels is not None

          with torch.no_grad():
              if has_labels:
                  loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
                  loss = loss.detach()
                  if isinstance(outputs, dict):
                      logits = outputs.get('logits')
                      if logits is not None:
                          logits = logits.detach()
                  else:
                      logits = outputs.logits.detach()
              else:
                  loss = None
                  outputs = model(**inputs)
                  if isinstance(outputs, dict):
                      logits = outputs.get('logits')
                      if logits is not None:
                          logits = logits.detach()
                  else:
                      logits = outputs.logits.detach()

          if prediction_loss_only:
              return (loss, None, None)

          return (loss, logits, labels)

## 5. Experiment Configurations

In [21]:
EXPERIMENT_CONFIG = {
    'name': 'hybrid_attention',
    'num_hybrid_layers': 3,
    'description': 'ViT with Local + Global Attention (3 layers, window=3)',
    'learning_rate': 5e-5,
    'batch_size': 64,
    'num_epochs': 50,
    'weight_decay': 0.05,
    'warmup_ratio': 0.1,
    'gradient_accumulation_steps': 1,
    'aug_config': AugmentationConfig(
        use_cutmix=True,
        use_mixup=True,
        cutmix_prob=0.4,
        mixup_prob=0.2,
        cutmix_alpha=0.8,
        mixup_alpha=0.6
    )
}

print(f"Configured experiment: {EXPERIMENT_CONFIG['name']}")
print(f"Description: {EXPERIMENT_CONFIG['description']}")

Configured experiment: hybrid_attention
Description: ViT with Local + Global Attention (3 layers, window=3)


## 6. Metrics & Evaluation Functions

In [22]:
accuracy_metric = evaluate.load('accuracy')
f1_metric = evaluate.load('f1')

def compute_metrics(eval_pred):
    """Compute metrics during training/evaluation - memory optimized"""
    predictions, labels = eval_pred
    if isinstance(predictions, torch.Tensor):
        predictions = predictions.detach().cpu().numpy()
    if isinstance(labels, torch.Tensor):
        labels = labels.detach().cpu().numpy()


    predictions = np.argmax(predictions, axis=1)

    accuracy = accuracy_metric.compute(predictions=predictions, references=labels)
    f1_macro = f1_metric.compute(predictions=predictions, references=labels, average='macro')
    f1_weighted = f1_metric.compute(predictions=predictions, references=labels, average='weighted')

    precision_macro, recall_macro, _, _ = precision_recall_fscore_support(
        labels, predictions, average='macro', zero_division=0
    )

    # Clear memory
    del predictions, labels

    return {
        'accuracy': accuracy['accuracy'],
        'f1_macro': f1_macro['f1'],
        'f1_weighted': f1_weighted['f1'],
        'precision_macro': precision_macro,
        'recall_macro': recall_macro,
    }

In [23]:
def evaluate_model_detailed(trainer, dataset, config_name):
    """Generate comprehensive evaluation metrics and visualizations"""
    print(f"Evaluation: {config_name.upper()}")

    predictions = trainer.predict(dataset)
    preds = np.argmax(predictions.predictions, axis=1)
    true_labels = predictions.label_ids

    # Classification report
    print("Classification report")
    report = classification_report(true_labels, preds, target_names=labels, digits=4, output_dict=True)
    print(classification_report(true_labels, preds, target_names=labels, digits=4))

    report_df = pd.DataFrame(report).transpose()
    report_csv_path = os.path.join(CSV_OUTPUT_DIR, f'{config_name}_classification_report.csv')
    report_df.to_csv(report_csv_path)
    print(f"Saved to: {report_csv_path}")

    # Confusion matrix
    cm = confusion_matrix(true_labels, preds)
    cm_df = pd.DataFrame(cm, index=labels, columns=labels)
    cm_csv_path = os.path.join(CSV_OUTPUT_DIR, f'{config_name}_confusion_matrix.csv')
    cm_df.to_csv(cm_csv_path)

    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=labels, yticklabels=labels, cbar_kws={'label': 'Count'})
    plt.title(f'Confusion Matrix - {config_name.replace("_", " ").title()}',
              fontsize=16, fontweight='bold')
    plt.ylabel('True Label', fontsize=12)
    plt.xlabel('Predicted Label', fontsize=12)
    plt.tight_layout()
    plt.savefig(f'confusion_matrix_{config_name}.png', dpi=300, bbox_inches='tight')
    plt.show()

    # Per-class metrics
    precision, recall, f1, support = precision_recall_fscore_support(
        true_labels, preds, average=None, zero_division=0
    )

    per_class_df = pd.DataFrame({
        'Class': labels,
        'Precision': precision,
        'Recall': recall,
        'F1-Score': f1,
        'Support': support
    })
    per_class_csv_path = os.path.join(CSV_OUTPUT_DIR, f'{config_name}_per_class_metrics.csv')
    per_class_df.to_csv(per_class_csv_path, index=False)

    # Overall metrics
    accuracy = accuracy_score(true_labels, preds)
    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
        true_labels, preds, average='macro', zero_division=0
    )
    precision_weighted, recall_weighted, f1_weighted, _ = precision_recall_fscore_support(
        true_labels, preds, average='weighted', zero_division=0
    )

    print("All Metrics:")
    print(f"Accuracy:            {accuracy:.4f}")
    print(f"Precision (Macro):   {precision_macro:.4f}")
    print(f"Recall (Macro):      {recall_macro:.4f}")
    print(f"F1-Score (Macro):    {f1_macro:.4f}")
    print(f"F1-Score (Weighted): {f1_weighted:.4f}")

    return {
        'config_name': config_name,
        'accuracy': accuracy,
        'precision_macro': precision_macro,
        'recall_macro': recall_macro,
        'f1_macro': f1_macro,
        'precision_weighted': precision_weighted,
        'recall_weighted': recall_weighted,
        'f1_weighted': f1_weighted,
    }

In [24]:
def plot_training_history(log_history, config_name):
    """Plot training curves"""
    train_loss, eval_loss, eval_accuracy = [], [], []
    steps, eval_steps, epochs = [], [], []

    for entry in log_history:
        if 'loss' in entry and 'step' in entry:
            train_loss.append(entry['loss'])
            steps.append(entry['step'])
        if 'eval_loss' in entry and 'step' in entry:
            eval_loss.append(entry['eval_loss'])
            eval_accuracy.append(entry.get('eval_accuracy', 0))
            eval_steps.append(entry['step'])
            epochs.append(entry.get('epoch', 0))

    if not train_loss:
        print("Warning: No training history to plot")
        return

    training_history_df = pd.DataFrame({'step': steps, 'train_loss': train_loss})
    eval_history_df = pd.DataFrame({
        'step': eval_steps, 'epoch': epochs,
        'eval_loss': eval_loss, 'eval_accuracy': eval_accuracy
    })

    training_history_df.to_csv(os.path.join(CSV_OUTPUT_DIR, f'{config_name}_training_history.csv'), index=False)
    eval_history_df.to_csv(os.path.join(CSV_OUTPUT_DIR, f'{config_name}_eval_history.csv'), index=False)

    # Plot
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 5))

    ax1.plot(steps, train_loss, label='Training Loss', linewidth=2, color='#2E86DE')
    if eval_loss:
        ax1.plot(eval_steps, eval_loss, label='Validation Loss',
                linewidth=2.5, marker='o', markersize=6, color='#EE5A6F')
    ax1.set_xlabel('Training Steps', fontsize=11)
    ax1.set_ylabel('Loss', fontsize=11)
    ax1.set_title(f'Loss Curves - {config_name.replace("_", " ").title()}',
                  fontsize=13, fontweight='bold')
    ax1.legend(fontsize=10)
    ax1.grid(True, alpha=0.3)

    if eval_accuracy:
        ax2.plot(eval_steps, eval_accuracy, label='Validation Accuracy',
                linewidth=2.5, marker='o', markersize=6, color='#26DE81')
        ax2.set_xlabel('Training Steps', fontsize=11)
        ax2.set_ylabel('Accuracy', fontsize=11)
        ax2.set_title(f'Validation Accuracy - {config_name.replace("_", " ").title()}',
                      fontsize=13, fontweight='bold')
        ax2.legend(fontsize=10)
        ax2.grid(True, alpha=0.3)
        ax2.set_ylim([0, 1.0])

    plt.tight_layout()
    plt.savefig(f'training_history_{config_name}.png', dpi=300, bbox_inches='tight')
    plt.show()

## 7. Training Function

In [30]:
def train_hybrid_attention(config):
    """Train ViT with Hybrid Attention"""
    config_name = config['name']

    print(f"Experiments: {config_name.upper()}")
    print(f"\nDescription: {config['description']}")
    print(f"Learning Rate: {config['learning_rate']}")
    print(f"Batch Size: {config['batch_size']}")
    print(f"Gradient Accumulation: {config['gradient_accumulation_steps']}")
    print(f"Effective Batch: {config['batch_size'] * config['gradient_accumulation_steps']}")
    print(f"Epochs: {config['num_epochs']}")
    print(f"Weight Decay: {config['weight_decay']}")
    print(f"Warmup Ratio: {config['warmup_ratio']}")

    aug_cfg = config['aug_config']
    print(f"\nAugmentation Settings:")
    print(f"\tBasic Augmentation: Enabled (flip, rotation, color jitter)")
    print(f"\tCutMix: {aug_cfg.use_cutmix} (prob={aug_cfg.cutmix_prob}, alpha={aug_cfg.cutmix_alpha})")
    print(f"\tMixUp: {aug_cfg.use_mixup} (prob={aug_cfg.mixup_prob}, alpha={aug_cfg.mixup_alpha})")

    # Create model
    print("\nCreating model...")
    model = create_model(
        num_labels=num_labels,
        num_hybrid_layers=config['num_hybrid_layers']
    )
    model = model.cuda()

    # Compile for optimization
    model = torch.compile(model)

    # Enable gradient checkpointing for memory efficiency
    supports_gradient_checkpointing = hasattr(model, 'gradient_checkpointing_enable')
    if supports_gradient_checkpointing:
        model.gradient_checkpointing_enable()
        print("Gradient checkpointing enabled on model")
    else:
        print("Model does not support gradient_checkpointing_enable")

    trainable_params, total_params = count_parameters(model)
    print(f"Trainable Parameters: {trainable_params:,}")
    print(f"Total Parameters: {total_params:,}")

    output_dir = f'./vit-modified-{config_name}-augmented'

    # Training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=config['num_epochs'],
        per_device_train_batch_size=config['batch_size'],
        per_device_eval_batch_size=config['batch_size'],
        gradient_accumulation_steps=config['gradient_accumulation_steps'],
        learning_rate=config['learning_rate'],
        weight_decay=config['weight_decay'],
        warmup_ratio=config['warmup_ratio'],
        eval_strategy='epoch',
        save_strategy='epoch',
        logging_strategy='steps',
        logging_steps=50,
        load_best_model_at_end=True,
        metric_for_best_model='eval_loss',
        greater_is_better=False,
        save_total_limit=2,
        remove_unused_columns=False,
        label_names=["labels"],
        push_to_hub=False,
        report_to='none',
        seed=42,
        fp16=torch.cuda.is_available(),
        dataloader_num_workers=2,
        dataloader_pin_memory=True,
        max_grad_norm=1.0,
        gradient_checkpointing=supports_gradient_checkpointing,
        eval_accumulation_steps=1,
        optim="adamw_torch",
        lr_scheduler_type="cosine",
    )

    # Initialize trainer
    trainer = AugmentedTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=7)],
        aug_config=config['aug_config']
    )

    # Train
    print(f"\nStarting training...")
    total_steps = len(train_ds) // config['batch_size'] * config['num_epochs']
    print(f"Total training steps: {total_steps}")

    start_time = time.time()
    train_result = trainer.train()
    training_time = time.time() - start_time

    print(f"\nTraining completed!")
    print(f"\tTraining time: {training_time:.2f} seconds ({training_time/60:.2f} minutes)")
    print(f"\tTraining loss: {train_result.metrics.get('train_loss', 0):.4f}")

    clear_memory()

    # Evaluate
    print(f"\nEvaluating...")
    eval_metrics = trainer.evaluate()
    print(f"Evaluation completed!")
    print(f"\tTest Accuracy: {eval_metrics.get('eval_accuracy', 0):.4f}")
    print(f"\tTest Loss: {eval_metrics.get('eval_loss', 0):.4f}")

    clear_memory()

    # Detailed evaluation
    overall_metrics = evaluate_model_detailed(trainer, val_ds, config_name)
    overall_metrics['train_time'] = training_time
    overall_metrics['train_loss'] = train_result.metrics.get('train_loss', 0)
    overall_metrics['eval_loss'] = eval_metrics.get('eval_loss', 0)
    overall_metrics['trainable_params'] = trainable_params
    overall_metrics['total_params'] = total_params
    overall_metrics['description'] = config['description']

    # Plot history
    plot_training_history(trainer.state.log_history, config_name)

    # Save metrics
    overall_metrics_df = pd.DataFrame([overall_metrics])
    overall_csv_path = os.path.join(CSV_OUTPUT_DIR, f'{config_name}_overall_metrics.csv')
    overall_metrics_df.to_csv(overall_csv_path, index=False)
    print(f"Overall metrics saved to: {overall_csv_path}")

    # Save model
    trainer.save_model(output_dir)
    processor.save_pretrained(output_dir)
    print(f"\nModel saved to {output_dir}")

    return overall_metrics

## 8. Run Experiment

In [None]:
print("Starting Experiment")


clear_memory()

try:
    results = train_hybrid_attention(EXPERIMENT_CONFIG)
    print("\nExperiment completed successfully!")
except Exception as e:
    print(f"\nError during training: {str(e)}")
    import traceback
    traceback.print_exc()
    clear_memory()

Starting Experiment
Experiments: HYBRID_ATTENTION

Description: ViT with Local + Global Attention (3 layers, window=3)
Learning Rate: 5e-05
Batch Size: 64
Gradient Accumulation: 1
Effective Batch: 64
Epochs: 50
Weight Decay: 0.05
Warmup Ratio: 0.1

Augmentation Settings:
	Basic Augmentation: Enabled (flip, rotation, color jitter)
	CutMix: True (prob=0.4, alpha=0.8)
	MixUp: True (prob=0.2, alpha=0.6)

Creating model...


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model does not support gradient_checkpointing_enable
Trainable Parameters: 118,293,514
Total Parameters: 118,293,514

Starting training...
Total training steps: 39050


Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro,F1 Weighted,Precision Macro,Recall Macro
1,0.8281,0.254338,0.9729,0.972859,0.972859,0.97354,0.9729


## 9. Final Results

In [None]:
if 'results' in dir() and results is not None:
    print("Final results")


    print(f"\n{'Metric':<25} {'Value':<15}")
    print("-"*40)
    print(f"{'Accuracy':<25} {results['accuracy']:.4f}")
    print(f"{'F1-Score (Macro)':<25} {results['f1_macro']:.4f}")
    print(f"{'F1-Score (Weighted)':<25} {results['f1_weighted']:.4f}")
    print(f"{'Precision (Macro)':<25} {results['precision_macro']:.4f}")
    print(f"{'Recall (Macro)':<25} {results['recall_macro']:.4f}")
    print(f"{'Training Loss':<25} {results['train_loss']:.4f}")
    print(f"{'Eval Loss':<25} {results['eval_loss']:.4f}")
    print(f"{'Training Time (min)':<25} {results['train_time']/60:.2f}")
    print(f"{'Trainable Params (M)':<25} {results['trainable_params']/1e6:.2f}")

    # Save results
    with open('hybrid_attention_results.json', 'w') as f:
        json.dump(results, f, indent=2)
    print("\nResults saved to hybrid_attention_results.json")
else:
    print("\nNo results available!")

## 10. Download Results

In [None]:
# Download all results as a zip file
!zip -r results.zip *.png *.json csv_results/ vit-modified-*/

from google.colab import files
files.download('results.zip')

print("\nResults downloaded!")