In [1]:
import torch 
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import math

from sklearn.metrics import (accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, classification_report, confusion_matrix)
import seaborn as sns
import timm
from tqdm import tqdm

In [None]:
class MVTecMultiClassDataset(Dataset):
    def __init__(self, root_dir, categories, split='train', transform=None):
        self.root_dir = root_dir 
        self.categories = categories
        self.split = split
        self.transform = transform
        
        self.category_to_idx = {cat: idx for idx, cat in enumerate(categories)}
        self.num_classes = len(categories)
        
        self.image_paths = []
        self.labels = []
        
        for category in categories:
            category_path = os.path.join(root_dir, category, split)
            
            if split == 'train':
                good_path = os.path.join(category_path, 'good')
                if os.path.exists(good_path):
                    for img_name in os.listdir(good_path):
                        if img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                            self.image_paths.append(os.path.join(good_path, img_name))
                            self.labels.append(self.category_to_idx[category])
            else:
                if os.path.exists(category_path):
                    for defect_type in os.listdir(category_path):
                        defect_path = os.path.join(category_path, defect_type)
                        if os.path.isdir(defect_path):
                            for img_name in os.listdir(defect_path):
                                if img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                                    self.image_paths.append(os.path.join(defect_path, img_name))
                                    self.labels.append(self.category_to_idx[category])
        
        print(f"Loaded {len(self.image_paths)} images for {split} split")
        print(f"Categories: {self.categories}")
        
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        try:
            image = Image.open(self.image_paths[idx]).convert('RGB')
            label = self.labels[idx]
            
            if self.transform:
                image = self.transform(image)
                
            return image, label
        except Exception as e:
            print(f"Error loading image {self.image_paths[idx]}: {e}")
            dummy_image = Image.new('RGB', (224, 224), (0, 0, 0))
            if self.transform:
                dummy_image = self.transform(dummy_image)
            return dummy_image, self.labels[idx]

In [3]:
transform_train = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomCrop((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                        std=[0.229, 0.224, 0.225])
])

transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                        std=[0.229, 0.224, 0.225])
])

In [4]:
categories = ['bottle', 'cable', 'capsule', 'carpet', 'grid', 
              'hazelnut', 'leather', 'metal_nut', 'pill', 'screw',
              'tile', 'toothbrush', 'transistor', 'wood', 'zipper']

In [5]:
train_dataset = MVTecMultiClassDataset(
    root_dir='mvtec_anomaly_detection',
    categories=categories,
    split='train',
    transform=transform_train
)

Loaded 3838 images for train split
Categories: ['bottle', 'cable', 'capsule', 'carpet', 'grid', 'hazelnut', 'leather', 'metal_nut', 'pill', 'screw', 'tile', 'toothbrush', 'transistor', 'wood', 'zipper']


In [6]:
test_dataset = MVTecMultiClassDataset(
    root_dir='mvtec_anomaly_detection',
    categories=categories,
    split='test',
    transform=transform_test
)

Loaded 1725 images for test split
Categories: ['bottle', 'cable', 'capsule', 'carpet', 'grid', 'hazelnut', 'leather', 'metal_nut', 'pill', 'screw', 'tile', 'toothbrush', 'transistor', 'wood', 'zipper']


In [7]:
train_loader = DataLoader(
    train_dataset, 
    batch_size=32, 
    shuffle=True,
    num_workers=0
)

In [8]:
test_loader = DataLoader(
    test_dataset, 
    batch_size=32, 
    shuffle=False,
    num_workers=0
)

In [9]:
print(f"Training samples: {len(train_dataset)}")
print(f"Testing samples: {len(test_dataset)}")
print(f"Number of classes: {len(categories)}")

Training samples: 3838
Testing samples: 1725
Number of classes: 15


In [10]:
def analyze_dataset(dataset, dataset_name):
    print(f"\n=== {dataset_name} Analysis ===")
    print(f"Total samples: {len(dataset)}")
    
    label_counts = {}
    for _, label in dataset:
        if label in label_counts:
            label_counts[label] += 1
        else:
            label_counts[label] = 1
    
    print("Samples per category:")
    for cat_idx, cat_name in enumerate(dataset.categories):
        count = label_counts.get(cat_idx, 0)
        print(f"  {cat_name} (class {cat_idx}): {count} samples")
    
    if len(dataset) > 0:
        sample_image, sample_label = dataset[0]
        print(f"Sample image shape: {sample_image.shape}")
        print(f"Sample label: {sample_label}")

In [11]:
# analyze_dataset(train_dataset, "Training Dataset")
# analyze_dataset(test_dataset, "Test Dataset")

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleCNN(nn.Module):
    def __init__(self, num_classes=15):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool = nn.MaxPool2d(2, 2)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)

        self.fc1 = nn.Linear(64 * 56 * 56, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.bn1(self.conv1(x))))  
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = x.view(x.size(0), -1)  # flatten
        x = F.dropout(F.relu(self.fc1(x)), p=0.3, training=self.training)  
        x = self.fc2(x)
        return x

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
model = SimpleCNN(num_classes=15).to(device)

dummy = torch.randn(1, 3, 224, 224).to(device)
print("Output shape:", model(dummy).shape)  


Using device: cuda
Output shape: torch.Size([1, 15])


In [None]:
def train_model(model, train_loader, test_loader, device, epochs=10, lr=0.001):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        batch_count = len(train_loader)
        
        for batch_idx, (images, labels) in enumerate(train_loader):
            images, labels = images.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
            if (batch_idx + 1) % 20 == 0:
                print(f"Epoch [{epoch+1}/{epochs}], Batch [{batch_idx+1}/{batch_count}], "
                      f"Loss: {loss.item():.4f}")
        
        train_acc = 100 * correct / total
        avg_loss = running_loss / len(train_loader)
        print(f"Epoch [{epoch+1}/{epochs}] - Loss: {avg_loss:.4f} | Train Acc: {train_acc:.2f}%")
        
        model.eval()
        test_correct = 0
        test_total = 0
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, predicted = torch.max(outputs, 1)
                test_total += labels.size(0)
                test_correct += (predicted == labels).sum().item()
        
        test_acc = 100 * test_correct / test_total
        print(f"Test Acc: {test_acc:.2f}%\n")


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import (accuracy_score, precision_score, recall_score, 
                           f1_score, roc_auc_score, classification_report, 
                           confusion_matrix)
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

def calculate_all_metrics(y_true, y_pred, y_probs, num_classes):
    """
    Calculate all metrics mentioned in the abstract
    """
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, average='weighted', zero_division=0)
    recall = recall_score(y_true, y_pred, average='weighted', zero_division=0)
    f1 = f1_score(y_true, y_pred, average='weighted', zero_division=0)
    
    try:
        if len(np.unique(y_true)) > 2: 
            auc_roc = roc_auc_score(y_true, y_probs, multi_class='ovr', average='weighted')
        else: 
            auc_roc = roc_auc_score(y_true, y_probs[:, 1])
    except:
        auc_roc = 0.0
    
    precision_per_class = precision_score(y_true, y_pred, average=None, zero_division=0)
    recall_per_class = recall_score(y_true, y_pred, average=None, zero_division=0)
    f1_per_class = f1_score(y_true, y_pred, average=None, zero_division=0)
    
    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'auc_roc': auc_roc,
        'precision_per_class': precision_per_class,
        'recall_per_class': recall_per_class,
        'f1_per_class': f1_per_class
    }

def evaluate_model(model, data_loader, device, categories):
    """
    Comprehensive model evaluation with all metrics
    """
    model.eval()
    all_labels = []
    all_preds = []
    all_probs = []
    
    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            
            probs = torch.softmax(outputs, dim=1)
            _, preds = torch.max(outputs, 1)
            
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())
    
    all_labels = np.array(all_labels)
    all_preds = np.array(all_preds)
    all_probs = np.array(all_probs)
    
    metrics = calculate_all_metrics(all_labels, all_preds, all_probs, len(categories))
    
    return metrics, all_labels, all_preds, all_probs

def print_detailed_metrics(metrics, categories):
    """
    Print detailed metrics report
    """
    print("\n" + "="*50)
    print("COMPREHENSIVE EVALUATION METRICS")
    print("="*50)
    
    print(f"\n📊 OVERALL PERFORMANCE:")
    print(f"   Accuracy:  {metrics['accuracy']:.4f} ({metrics['accuracy']*100:.2f}%)")
    print(f"   Precision: {metrics['precision']:.4f}")
    print(f"   Recall:    {metrics['recall']:.4f}")
    print(f"   F1-Score:  {metrics['f1_score']:.4f}")
    print(f"   AUC-ROC:   {metrics['auc_roc']:.4f}")
    
    print(f"\n📈 PER-CLASS PERFORMANCE:")
    print("-" * 60)
    print(f"{'Category':<15} {'Precision':<10} {'Recall':<10} {'F1-Score':<10}")
    print("-" * 60)
    
    for i, category in enumerate(categories):
        if i < len(metrics['precision_per_class']):
            print(f"{category:<15} {metrics['precision_per_class'][i]:.4f}     "
                  f"{metrics['recall_per_class'][i]:.4f}     "
                  f"{metrics['f1_per_class'][i]:.4f}")

def plot_confusion_matrix(y_true, y_pred, categories, save_path=None):
    """
    Plot confusion matrix
    """
    cm = confusion_matrix(y_true, y_pred)
    
    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=categories, yticklabels=categories)
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.xticks(rotation=45)
    plt.yticks(rotation=0)
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

def train_model_with_metrics(model, train_loader, test_loader, device, categories, 
                           epochs=15, lr=0.001):
    """
    Enhanced training function with comprehensive metrics tracking
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    train_metrics_history = []
    test_metrics_history = []
    
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        train_labels = []
        train_preds = []
        train_probs = []
        
        print(f"\n🚀 Epoch [{epoch+1}/{epochs}]")
        print("-" * 40)
        
        for batch_idx, (images, labels) in enumerate(train_loader):
            images, labels = images.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            
            probs = torch.softmax(outputs, dim=1)
            _, preds = torch.max(outputs, 1)
            
            train_labels.extend(labels.cpu().numpy())
            train_preds.extend(preds.cpu().numpy())
            train_probs.extend(probs.cpu().detach().numpy())
            
            if (batch_idx + 1) % 20 == 0:
                print(f"   Batch [{batch_idx+1}/{len(train_loader)}] - Loss: {loss.item():.4f}")
        
        train_metrics = calculate_all_metrics(
            np.array(train_labels), np.array(train_preds), 
            np.array(train_probs), len(categories)
        )
        train_metrics_history.append(train_metrics)
        
        test_metrics, test_labels, test_preds, test_probs = evaluate_model(
            model, test_loader, device, categories
        )
        test_metrics_history.append(test_metrics)
        
        print(f"\n📈 EPOCH {epoch+1} RESULTS:")
        print(f"   Train - Acc: {train_metrics['accuracy']*100:.2f}% | "
              f"F1: {train_metrics['f1_score']:.4f} | "
              f"AUC: {train_metrics['auc_roc']:.4f}")
        print(f"   Test  - Acc: {test_metrics['accuracy']*100:.2f}% | "
              f"F1: {test_metrics['f1_score']:.4f} | "
              f"AUC: {test_metrics['auc_roc']:.4f}")
        
        if epoch > 0 and test_metrics['f1_score'] < test_metrics_history[-2]['f1_score']:
            print("   ⚠️  F1-score decreased, consider early stopping")
    
    print("\n" + "="*60)
    print("FINAL MODEL EVALUATION")
    print("="*60)
    
    final_metrics, final_labels, final_preds, final_probs = evaluate_model(
        model, test_loader, device, categories
    )
    
    print_detailed_metrics(final_metrics, categories)
    
    print("\n📋 DETAILED CLASSIFICATION REPORT:")
    print("-" * 50)
    print(classification_report(final_labels, final_preds, target_names=categories))
    
    plot_confusion_matrix(final_labels, final_preds, categories)
    
    return {
        'train_history': train_metrics_history,
        'test_history': test_metrics_history,
        'final_metrics': final_metrics,
        'final_predictions': (final_labels, final_preds, final_probs)
    }


In [21]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
model = SimpleCNN(num_classes=15).to(device)
train_model(model, train_loader, test_loader, device, epochs=5, lr=0.001)


Using device: cuda
Epoch [1/5], Batch [20/120], Loss: 8.5908
Epoch [1/5], Batch [40/120], Loss: 3.5644
Epoch [1/5], Batch [60/120], Loss: 0.0678
Epoch [1/5], Batch [80/120], Loss: 0.3678
Epoch [1/5], Batch [100/120], Loss: 0.0107
Epoch [1/5], Batch [120/120], Loss: 0.0000
Epoch [1/5] - Loss: 4.9258 | Train Acc: 87.00%
Test Acc: 99.07%

Epoch [2/5], Batch [20/120], Loss: 0.0549
Epoch [2/5], Batch [40/120], Loss: 0.0025
Epoch [2/5], Batch [60/120], Loss: 0.0066
Epoch [2/5], Batch [80/120], Loss: 0.0000
Epoch [2/5], Batch [100/120], Loss: 0.0387
Epoch [2/5], Batch [120/120], Loss: 0.0000
Epoch [2/5] - Loss: 0.1339 | Train Acc: 97.71%
Test Acc: 99.88%



KeyboardInterrupt: 

In [None]:
results = train_model_with_metrics(
    model, train_loader, test_loader, device, categories, 
    epochs=3, lr=0.001
)

In [25]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import timm
import math
import numpy as np
from sklearn.metrics import (accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, classification_report, confusion_matrix)

In [None]:
class CNNFeatureExtractor(nn.Module):

    def __init__(self, output_dim=768):
        super(CNNFeatureExtractor, self).__init__()
        
        self.conv_blocks = nn.Sequential(
            # Block 1
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            
            # Block 2
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            # Block 3
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            # Block 4
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        
        self.adaptive_pool = nn.AdaptiveAvgPool2d((14, 14))  # 14x14 = 196 patches
        
        self.feature_projection = nn.Linear(512, output_dim)
        self.dropout = nn.Dropout(0.2)
        
    def forward(self, x):
        features = self.conv_blocks(x)  # [B, 512, H, W]
        
        features = self.adaptive_pool(features)  # [B, 512, 14, 14]
        
        B, C, H, W = features.shape
        features = features.view(B, C, H * W).transpose(1, 2)  # [B, 196, 512]
        
        features = self.feature_projection(features)  # [B, 196, 768]
        features = self.dropout(features)
        
        return features

In [None]:
class CNNDeiTModel(nn.Module):

    def __init__(self, num_classes=15, embed_dim=768, num_heads=12, num_layers=12):
        super(CNNDeiTModel, self).__init__()
        
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        
        self.cnn_extractor = CNNFeatureExtractor(output_dim=embed_dim)
        
        self.deit_model = timm.create_model(
            'deit_base_patch16_224',
            pretrained=True,
            num_classes=0,  
            img_size=224
        )
        
        self.transformer_blocks = self.deit_model.blocks
        self.norm = self.deit_model.norm
        
        self.pos_embed = nn.Parameter(torch.zeros(1, 197, embed_dim))
        
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(embed_dim // 2, num_classes)
        )
        
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        
        self._init_weights()
        
    def _init_weights(self):
        """Initialize model weights"""
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.dist_token, std=0.02)
        
    def forward(self, x):
        B = x.shape[0]
        
        cnn_features = self.cnn_extractor(x)  # [B, 196, 768]
        
        cls_tokens = self.cls_token.expand(B, -1, -1)  # [B, 1, 768]
        dist_tokens = self.dist_token.expand(B, -1, -1)  # [B, 1, 768]
        
        x = torch.cat([cls_tokens, dist_tokens, cnn_features], dim=1)  # [B, 198, 768]
        
        pos_embed = torch.cat([
            self.pos_embed[:, :1, :],  # CLS token position
            self.pos_embed[:, :1, :],  # Distillation token position (reuse CLS)
            self.pos_embed[:, 1:, :]   # Patch positions
        ], dim=1)
        
        x = x + pos_embed[:, :x.size(1), :]
        
        for block in self.transformer_blocks:
            x = block(x)
        
        x = self.norm(x)
        
        cls_output = x[:, 0]  # [B, 768]
        
        output = self.classifier(cls_output)
        
        return output

In [28]:
def create_cnn_deit_model(num_classes=15, pretrained=True):

    model = CNNDeiTModel(num_classes=num_classes)
    
    if pretrained:
        print("Using pre-trained DeiT weights")
    else:
        print("Training from scratch")
        
    return model

In [None]:
def train_cnn_deit_model(model, train_loader, test_loader, device, categories, 
                        epochs=20, lr=1e-4, weight_decay=1e-4):

    optimizer = optim.AdamW(model.parameters(), 
                           lr=lr, 
                           weight_decay=weight_decay,
                           betas=(0.9, 0.999))
    
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    
    train_metrics_history = []
    test_metrics_history = []
    best_f1 = 0.0
    
    print(f"Starting CNN+DeiT Training")
    print(f"Model Parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f"Trainable Parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
    
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        train_labels = []
        train_preds = []
        train_probs = []
        
        print(f"\nEpoch [{epoch+1}/{epochs}] - LR: {scheduler.get_last_lr()[0]:.6f}")
        
        for batch_idx, (images, labels) in enumerate(train_loader):
            images, labels = images.to(device), labels.to(device)
            
            optimizer.zero_grad()
            
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            loss.backward()
            
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            optimizer.step()
            
            running_loss += loss.item()
            
            probs = torch.softmax(outputs, dim=1)
            _, preds = torch.max(outputs, 1)
            
            train_labels.extend(labels.cpu().numpy())
            train_preds.extend(preds.cpu().numpy())
            train_probs.extend(probs.cpu().detach().numpy())
            
            if (batch_idx + 1) % 20 == 0:
                print(f"   Batch [{batch_idx+1}/{len(train_loader)}] - Loss: {loss.item():.4f}")
        
        scheduler.step()
        
        train_metrics = calculate_all_metrics(
            np.array(train_labels), np.array(train_preds), 
            np.array(train_probs), len(categories)
        )
        train_metrics_history.append(train_metrics)
        
        test_metrics, test_labels, test_preds, test_probs = evaluate_model(
            model, test_loader, device, categories
        )
        test_metrics_history.append(test_metrics)
        
        avg_loss = running_loss / len(train_loader)
        print(f"\n EPOCH {epoch+1} RESULTS:")
        print(f"   Loss: {avg_loss:.4f}")
        print(f"   Train - Acc: {train_metrics['accuracy']*100:.2f}% | "
              f"Precision: {train_metrics['precision']:.4f} | "
              f"Recall: {train_metrics['recall']:.4f} | "
              f"F1: {train_metrics['f1_score']:.4f} | "
              f"AUC: {train_metrics['auc_roc']:.4f}")
        print(f"   Test  - Acc: {test_metrics['accuracy']*100:.2f}% | "
              f"Precision: {test_metrics['precision']:.4f} | "
              f"Recall: {test_metrics['recall']:.4f} | "
              f"F1: {test_metrics['f1_score']:.4f} | "
              f"AUC: {test_metrics['auc_roc']:.4f}")
        
        if test_metrics['f1_score'] > best_f1:
            best_f1 = test_metrics['f1_score']
            torch.save(model.state_dict(), 'best_cnn_deit_model.pth')
            print(f"   New best model saved! F1-Score: {best_f1:.4f}")
    
    model.load_state_dict(torch.load('best_cnn_deit_model.pth'))
    
    print("\n" + "="*70)
    print("FINAL CNN+DeiT MODEL EVALUATION")
    print("="*70)
    
    final_metrics, final_labels, final_preds, final_probs = evaluate_model(
        model, test_loader, device, categories
    )
    
    print_detailed_metrics(final_metrics, categories)
    
    print("\n DETAILED CLASSIFICATION REPORT:")
    print("-" * 60)
    print(classification_report(final_labels, final_preds, target_names=categories))
    
    plot_confusion_matrix(final_labels, final_preds, categories, 
                         save_path='cnn_deit_confusion_matrix.png')
    
    return {
        'train_history': train_metrics_history,
        'test_history': test_metrics_history,
        'final_metrics': final_metrics,
        'final_predictions': (final_labels, final_preds, final_probs),
        'best_f1_score': best_f1
    }

In [33]:
def compare_models(cnn_results, cnn_deit_results):

    print("\n" + "="*80)
    print(" MODEL COMPARISON: CNN vs CNN+DeiT")
    print("="*80)
    
    cnn_metrics = cnn_results['final_metrics']
    deit_metrics = cnn_deit_results['final_metrics']
    
    metrics_names = ['accuracy', 'precision', 'recall', 'f1_score', 'auc_roc']
    
    print(f"{'Metric':<15} {'CNN':<12} {'CNN+DeiT':<12} {'Improvement':<12}")
    print("-" * 60)
    
    for metric in metrics_names:
        cnn_val = cnn_metrics[metric]
        deit_val = deit_metrics[metric]
        improvement = ((deit_val - cnn_val) / cnn_val) * 100 if cnn_val > 0 else 0
        
        print(f"{metric.upper():<15} {cnn_val:.4f}      {deit_val:.4f}      "
              f"{improvement:+.2f}%")
    
    print("\n KEY IMPROVEMENTS:")
    for metric in metrics_names:
        cnn_val = cnn_metrics[metric]
        deit_val = deit_metrics[metric]
        if deit_val > cnn_val:
            improvement = ((deit_val - cnn_val) / cnn_val) * 100
            print(f"   {metric.upper()}: +{improvement:.2f}% improvement")

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

cnn_deit_model = create_cnn_deit_model(num_classes=15, pretrained=True)
cnn_deit_model = cnn_deit_model.to(device)

total_params = sum(p.numel() for p in cnn_deit_model.parameters())
trainable_params = sum(p.numel() for p in cnn_deit_model.parameters() if p.requires_grad)

print(f"\n CNN+DeiT MODEL ARCHITECTURE:")
print(f"   Total Parameters: {total_params:,}")
print(f"   Trainable Parameters: {trainable_params:,}")
print(f"   Model Size: ~{total_params * 4 / (1024**2):.1f} MB")

print("\n Starting CNN+DeiT Training...")
cnn_deit_results = train_cnn_deit_model(
    cnn_deit_model, 
    train_loader, 
    test_loader, 
    device, 
    categories,
    epochs=20, 
    lr=1e-4,
    weight_decay=1e-4
)

print(f"\n CNN+DeiT FINAL PERFORMANCE:")
final_metrics = cnn_deit_results['final_metrics']
print(f"   Accuracy:  {final_metrics['accuracy']*100:.2f}%")
print(f"   Precision: {final_metrics['precision']:.4f}")
print(f"   Recall:    {final_metrics['recall']:.4f}")
print(f"   F1-Score:  {final_metrics['f1_score']:.4f}")
print(f"   AUC-ROC:   {final_metrics['auc_roc']:.4f}")


Using device: cuda
Using pre-trained DeiT weights

 CNN+DeiT MODEL ARCHITECTURE:
   Total Parameters: 91,306,383
   Trainable Parameters: 91,306,383
   Model Size: ~348.3 MB

 Starting CNN+DeiT Training...
Starting CNN+DeiT Training
Model Parameters: 91,306,383
Trainable Parameters: 91,306,383

Epoch [1/20] - LR: 0.000100


In [None]:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import math
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import timm


class LoRALayer(nn.Module):
    def __init__(self, in_features, out_features, rank=4, alpha=16, dropout=0.1):
        super().__init__()
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank
        
        self.lora_A = nn.Parameter(torch.randn(in_features, rank) / math.sqrt(rank))
        self.lora_B = nn.Parameter(torch.zeros(rank, out_features))
        self.dropout = nn.Dropout(dropout)
        
        assert not torch.isnan(self.lora_A).any(), "LoRA_A contains NaN"
        assert not torch.isnan(self.lora_B).any(), "LoRA_B contains NaN"


class LoRALinear(nn.Module):
    def __init__(self, original_layer, rank=4, alpha=16, dropout=0.1):
        super().__init__()
        self.original_layer = original_layer
        self.lora = LoRALayer(
            original_layer.in_features,
            original_layer.out_features,
            rank=rank, alpha=alpha, dropout=dropout
        )
        for param in self.original_layer.parameters():
            param.requires_grad = False

    def forward(self, x):
        return self.original_layer(x) + self.lora(x)

class LoRAConv2d(nn.Module):
    def __init__(self, original_conv, rank=4, alpha=16, dropout=0.1):
        super().__init__()
        self.original_conv = original_conv
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank

        in_channels = original_conv.in_channels
        out_channels = original_conv.out_channels
        kernel_size = original_conv.kernel_size

        # LoRA decomposition for conv: two 1x1 convs (A then B)
        self.lora_A = nn.Conv2d(in_channels, rank, kernel_size=1, bias=False)
        self.lora_B = nn.Conv2d(rank, out_channels, kernel_size=kernel_size,
                                stride=original_conv.stride,
                                padding=original_conv.padding,
                                bias=False)

        nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B.weight)

        self.dropout = nn.Dropout2d(dropout)

        for param in self.original_conv.parameters():
            param.requires_grad = False

    def forward(self, x):
        original_out = self.original_conv(x)
        lora_out = self.lora_B(self.dropout(self.lora_A(x))) * self.scaling
        return original_out + lora_out

def validate_and_fix_parameters(model):
    """Validate and fix any parameter issues"""
    print("Validating and fixing model parameters...")
    
    fixed_count = 0
    for name, param in model.named_parameters():
        if param.requires_grad:  
            if torch.isnan(param).any() or torch.isinf(param).any():
                print(f"Fixing invalid values in {name}")
                if 'weight' in name:
                    if len(param.shape) >= 2:
                        nn.init.xavier_uniform_(param)
                    else:
                        nn.init.uniform_(param, -0.1, 0.1)
                elif 'bias' in name:
                    nn.init.zeros_(param)
                else:
                    nn.init.normal_(param, 0, 0.01)
                fixed_count += 1
    
    print(f"Fixed {fixed_count} parameters")
    return model


class CNNDeiTLoRAModel(nn.Module):
    def __init__(self, num_classes, cnn_backbone='resnet50', deit_model='deit_small_patch16_224',
                 lora_rank=4, lora_alpha=16, lora_dropout=0.1, freeze_cnn=True):
        super().__init__()

        self.cnn_backbone = timm.create_model(cnn_backbone, pretrained=True, num_classes=0)
        cnn_features = self.cnn_backbone.num_features

        if not freeze_cnn:
            self._apply_lora_to_cnn(lora_rank, lora_alpha, lora_dropout)
        else:
            for param in self.cnn_backbone.parameters():
                param.requires_grad = False

        self.deit = timm.create_model(deit_model, pretrained=True, num_classes=0)
        deit_features = self.deit.num_features

        self._apply_lora_to_deit(lora_rank, lora_alpha, lora_dropout)

        self.fusion_layer = nn.Sequential(
            nn.Linear(cnn_features + deit_features, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3)
        )
        self.classifier = nn.Linear(256, num_classes)

    def _apply_lora_to_cnn(self, rank, alpha, dropout):
        for name, module in list(self.cnn_backbone.named_modules()):
            if isinstance(module, nn.Conv2d) and 'downsample' not in name:
                parent_name = '.'.join(name.split('.')[:-1])
                child_name = name.split('.')[-1]
                if parent_name:
                    parent = self.cnn_backbone.get_submodule(parent_name)
                    setattr(parent, child_name, LoRAConv2d(module, rank, alpha, dropout))

    def _apply_lora_to_deit(self, rank, alpha, dropout):
        for name, module in list(self.deit.named_modules()):
            if isinstance(module, nn.Linear) and ('qkv' in name or 'proj' in name or 'fc' in name):
                parent_name = '.'.join(name.split('.')[:-1])
                child_name = name.split('.')[-1]
                if parent_name:
                    parent = self.deit.get_submodule(parent_name)
                    setattr(parent, child_name, LoRALinear(module, rank, alpha, dropout))

    def forward(self, x):
        cnn_features = self.cnn_backbone(x)  
        deit_features = self.deit(x)          
        combined_features = torch.cat([cnn_features, deit_features], dim=1)
        fused_features = self.fusion_layer(combined_features)
        output = self.classifier(fused_features)
        return output

    def count_parameters(self):
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return total_params, trainable_params


def safe_move_model_to_device(model, device):
    """Enhanced model device transfer with parameter validation"""
    print(f"Attempting to move model to {device}...")
    
    print("Validating parameters before GPU transfer...")
    invalid_params = []
    
    for name, param in model.named_parameters():
        try:
            if torch.isnan(param).any():
                invalid_params.append(f"{name}: contains NaN")
            if torch.isinf(param).any():
                invalid_params.append(f"{name}: contains Inf")
            
            param_abs_max = param.abs().max().item()
            if param_abs_max > 1e6:
                invalid_params.append(f"{name}: unusually large values (max: {param_abs_max})")
                
        except Exception as e:
            invalid_params.append(f"{name}: validation error - {e}")
    
    if invalid_params:
        print("INVALID PARAMETERS DETECTED:")
        for issue in invalid_params[:10]: 
            print(f"  - {issue}")
        print("Fix these issues before moving to GPU!")
        return model
    
    try:
        print("Moving CNN backbone...")
        model.cnn_backbone.to(device)
        
        print("Moving DeiT transformer...")
        model.deit.to(device)
        
        print("Moving fusion layers...")
        model.fusion_layer.to(device)
        
        print("Moving classifier...")
        model.classifier.to(device)
        
        print(f"✓ Model successfully moved to {device}")
        return model
        
    except Exception as e:
        print(f"ERROR during model.to(device): {e}")
        
        print("Resetting model to CPU...")
        model.cpu()
        
        print("\nCUDA Memory Info:")
        if torch.cuda.is_available():
            print(f"  Allocated: {torch.cuda.memory_allocated()/1e9:.2f} GB")
            print(f"  Reserved: {torch.cuda.memory_reserved()/1e9:.2f} GB")
            print(f"  Max allocated: {torch.cuda.max_memory_allocated()/1e9:.2f} GB")
        
        raise


from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score

def calculate_all_metrics(y_true, y_pred, y_probs, num_classes):
    accuracy = accuracy_score(y_true, y_pred)
    precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='weighted')
    if num_classes > 2:
        auc = roc_auc_score(y_true, y_probs, multi_class='ovr', average='weighted')
    else:
        auc = roc_auc_score(y_true, y_probs[:, 1])
    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1_score': f1, 'auc_roc': auc}

def evaluate_model(model, test_loader, device, categories):
    model.eval()
    test_labels, test_preds, test_probs = [], [], []
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            probs = torch.softmax(outputs, dim=1)
            _, preds = torch.max(outputs, 1)
            test_labels.extend(labels.cpu().numpy())
            test_preds.extend(preds.cpu().numpy())
            test_probs.extend(probs.cpu().numpy())
    test_labels = np.array(test_labels)
    test_preds = np.array(test_preds)
    test_probs = np.array(test_probs)
    metrics = calculate_all_metrics(test_labels, test_preds, test_probs, len(categories))
    return metrics, test_labels, test_preds, test_probs


def train_cnn_deit_lora_model(
    model, train_loader, test_loader, device, categories,
    epochs=20, lr=1e-4, weight_decay=1e-4
):
    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

    best_f1 = 0.0
    train_metrics_history, test_metrics_history = [], []

    num_classes = len(categories)
    print("Before training, classifier shape:", getattr(model, "classifier").weight.shape, "num_classes:", num_classes)

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        train_labels, train_preds, train_probs = [], [], []

        print(f"\nEpoch [{epoch+1}/{epochs}] - LR: {scheduler.get_last_lr()[0]:.6f}")

        for batch_idx, (images, labels) in enumerate(train_loader):
            images = images.to(device)
            labels = labels.to(device)
            if labels.dtype != torch.long:
                labels = labels.long()

            try:
                optimizer.zero_grad()
                outputs = model(images)  # [B, C]

                if batch_idx < 3:
                    print(f"Batch {batch_idx} -> outputs.shape={outputs.shape}, labels.shape={labels.shape}")
                    print(f"   labels min={labels.min().item()} max={labels.max().item()} (num_classes={num_classes})")

                loss = criterion(outputs, labels)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

                running_loss += loss.item()
                probs = torch.softmax(outputs, dim=1)
                _, preds = torch.max(outputs, 1)
                train_labels.extend(labels.cpu().numpy())
                train_preds.extend(preds.cpu().numpy())
                train_probs.extend(probs.cpu().detach().numpy())

                if (batch_idx + 1) % 20 == 0:
                    print(f"   Batch [{batch_idx+1}/{len(train_loader)}] - Loss: {loss.item():.4f}")

            except Exception as e:
                print("\nError in batch", batch_idx, "-> Exception:", repr(e))
                if 'outputs' in locals():
                    print("outputs.shape:", outputs.shape)
                print("labels dtype/device/min/max/unique:")
                try:
                    print(labels.dtype, labels.device, labels.min().item(), labels.max().item(), torch.unique(labels))
                except Exception:
                    print("Could not print labels details.")
                raise

        scheduler.step()
        train_metrics = calculate_all_metrics(np.array(train_labels), np.array(train_preds), np.array(train_probs), num_classes)
        train_metrics_history.append(train_metrics)

        test_metrics, _, _, _ = evaluate_model(model, test_loader, device, categories)
        test_metrics_history.append(test_metrics)

        avg_loss = running_loss / len(train_loader)
        print(f"\n EPOCH {epoch+1} RESULTS:")
        print(f"   Loss: {avg_loss:.4f}")
        print(f"   Train - Acc: {train_metrics['accuracy']*100:.2f}% | F1: {train_metrics['f1_score']:.4f} | AUC: {train_metrics['auc_roc']:.4f}")
        print(f"   Test  - Acc: {test_metrics['accuracy']*100:.2f}% | F1: {test_metrics['f1_score']:.4f} | AUC: {test_metrics['auc_roc']:.4f}")

        if test_metrics['f1_score'] > best_f1:
            best_f1 = test_metrics['f1_score']
            torch.save(model.state_dict(), 'best_cnn_deit_lora_model_debug.pth')
            print("Saved new best model:", best_f1)

    return {
        'train_history': train_metrics_history,
        'test_history': test_metrics_history,
        'best_f1_score': best_f1
    }


def main():
    print("Creating model...")
    num_classes = 10
    categories = [f'class_{i}' for i in range(num_classes)]
    
    model = CNNDeiTLoRAModel(
        num_classes=num_classes,
        cnn_backbone='resnet50',
        deit_model='deit_small_patch16_224',
        lora_rank=8,
        lora_alpha=32,
        lora_dropout=0.1,
        freeze_cnn=True
    )
    
    model = validate_and_fix_parameters(model)
    
    total_params, trainable_params = model.count_parameters()
    print(f"Model created - Total: {total_params:,}, Trainable: {trainable_params:,}")
    print(f"Parameter efficiency: {(trainable_params/total_params)*100:.2f}%")
    
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print("Cleared CUDA cache")
    
    print("\n=== Testing on CPU ===")
    cpu_device = torch.device('cpu')
    model = safe_move_model_to_device(model, cpu_device)
    
    try:
        dummy_input = torch.randn(2, 3, 224, 224)
        with torch.no_grad():
            output = model(dummy_input)
        print(f"✓ CPU forward pass successful: {output.shape}")
    except Exception as e:
        print(f"✗ CPU forward pass failed: {e}")
        return
    
    if torch.cuda.is_available():
        print("\n=== Moving to GPU ===")
        gpu_device = torch.device('cuda')
        model = safe_move_model_to_device(model, gpu_device)
        
        try:
            dummy_input = torch.randn(2, 3, 224, 224).to(gpu_device)
            with torch.no_grad():
                output = model(dummy_input)
            print(f"✓ GPU forward pass successful: {output.shape}")
        except Exception as e:
            print(f"✗ GPU forward pass failed: {e}")
            return
    
    print("Model validation complete!")


if __name__ == "__main__":
    main()
