In [8]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, 
    roc_auc_score, 
    roc_curve, 
    confusion_matrix,
    classification_report
)
from sklearn.preprocessing import label_binarize
import timm

In [9]:
class DeepLensConfig:
    """Centralized configuration management"""
    def __init__(self):
        # Model Hyperparameters
        self.lr = 1e-4
        self.batch_size = 64
        self.num_classes = 3
        self.epochs = 10
        self.weight_decay = 1e-2
        
        # Model Architecture
        self.model_name = "resnet18"
        self.pretrained = True
        
        # Training Settings
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.num_workers = 4
        self.seed = 42


In [10]:
class LensDataset(Dataset):
    """Custom Dataset for Lens Classification"""
    def __init__(self, dataframe, transform=None):
        self.dataframe = dataframe
        self.transform = transform

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        image_path = self.dataframe['data_path'].iloc[idx]
        image = np.load(image_path).astype(np.float32)
        
        # Ensure the image is 2D (single channel)
        if image.ndim == 3:
            image = image.squeeze()
        
        label = self.dataframe['target'].iloc[idx]
        if self.transform:
            image = self.transform(image=image)['image']
        
        # Ensure the tensor is 3D: [1, height, width]
        return torch.tensor(image).unsqueeze(0), torch.tensor(label)


In [11]:
class LensClassificationModel(nn.Module):
    """Lens Classification Model using Pretrained Architecture"""
    def __init__(self, config):
        super().__init__()
        self.model = timm.create_model(
            config.model_name, 
            pretrained=config.pretrained, 
            in_chans=1
        )
        
        # Modify classifier
        classifier_name = self.model.default_cfg['classifier']
        n_features = self.model._modules[classifier_name].in_features
        self.model._modules[classifier_name] = nn.Identity()
        
        self.classifier = nn.Sequential(
            nn.Linear(n_features, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, config.num_classes)
        )

    def forward(self, x):
        features = self.model(x)
        return self.classifier(features)

In [12]:
class ModelTrainer:
    """Comprehensive Training and Evaluation Pipeline"""
    def __init__(self, config):
        self.config = config
        self.device = config.device
        
    def prepare_data(self, data_dir):
        """Prepare training and validation datasets"""
        def _get_image_paths(base_path, classes):
            paths = []
            labels = []
            for label, cls in enumerate(classes):
                class_path = os.path.join(base_path, cls)
                for img_file in os.listdir(class_path):
                    paths.append(os.path.join(class_path, img_file))
                    labels.append(label)
            return paths, labels

        # Assuming directory structure: data_dir/train/(no/sphere/vort)
        train_path = os.path.join(data_dir, 'train')
        classes = ['no', 'sphere', 'vort']
        
        paths, labels = _get_image_paths(train_path, classes)
        
        df = pd.DataFrame({
            'data_path': paths,
            'target': labels
        })
        
        # Split into train and validation
        train_df, val_df = train_test_split(
            df,
            # 90 : 10 split
            test_size=0.1, 
            stratify=df['target'], 
            random_state=self.config.seed
        )
        
        return train_df, val_df

    def create_dataloaders(self, train_df, val_df):
        """Create PyTorch DataLoaders"""
        train_dataset = LensDataset(train_df)
        val_dataset = LensDataset(val_df)
        
        train_loader = DataLoader(
            train_dataset, 
            batch_size=self.config.batch_size, 
            shuffle=True, 
            num_workers=self.config.num_workers
        )
        
        val_loader = DataLoader(
            val_dataset, 
            batch_size=self.config.batch_size, 
            shuffle=False, 
            num_workers=self.config.num_workers
        )
        
        return train_loader, val_loader

    def train(self, train_loader, val_loader):
        """Main training loop with comprehensive logging"""
        model = LensClassificationModel(self.config).to(self.device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.AdamW(
            model.parameters(), 
            lr=self.config.lr, 
            weight_decay=self.config.weight_decay
        )
        
        best_val_auc = 0
        training_history = {
            'train_loss': [],
            'val_loss': [],
            'val_auc': []
        }
        
        for epoch in range(self.config.epochs):
            # Training Phase
            model.train()
            train_losses = []
            
            for images, labels in train_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                
                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                
                train_losses.append(loss.item())
            
            # Validation Phase
            model.eval()
            val_losses = []
            all_preds = []
            all_labels = []
            
            with torch.no_grad():
                for images, labels in val_loader:
                    images, labels = images.to(self.device), labels.to(self.device)
                    outputs = model(images)
                    loss = criterion(outputs, labels)
                    
                    val_losses.append(loss.item())
                    preds = torch.softmax(outputs, dim=1)
                    all_preds.append(preds.cpu().numpy())
                    all_labels.append(labels.cpu().numpy())
            
            # Aggregate predictions and labels
            all_preds = np.concatenate(all_preds)
            all_labels = np.concatenate(all_labels)
            
            # Calculate metrics
            val_auc = roc_auc_score(all_labels, all_preds, multi_class='ovr')
            
            # Log metrics
            avg_train_loss = np.mean(train_losses)
            avg_val_loss = np.mean(val_losses)
            
            training_history['train_loss'].append(avg_train_loss)
            training_history['val_loss'].append(avg_val_loss)
            training_history['val_auc'].append(val_auc)
            
            # Print epoch summary
            print(f"Epoch {epoch+1}/{self.config.epochs}")
            print(f"Train Loss: {avg_train_loss:.4f}")
            print(f"Val Loss: {avg_val_loss:.4f}")
            print(f"Val AUC: {val_auc:.4f}")
            
            # Model checkpoint
            if val_auc > best_val_auc:
                best_val_auc = val_auc
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'best_val_auc': best_val_auc
                }, 'best_model.pth')
        
        return model, all_preds, all_labels, training_history

    

In [13]:
class ModelVisualizer:
    """Visualization utilities for model performance"""
    @staticmethod
    def plot_training_history(history, save_path='training_history.png'):
        """Plot training and validation loss/AUC"""
        plt.figure(figsize=(12, 4))
        
        # Loss Plot
        plt.subplot(1, 2, 1)
        plt.plot(history['train_loss'], label='Train Loss')
        plt.plot(history['val_loss'], label='Validation Loss')
        plt.title('Training and Validation Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        
        # AUC Plot
        plt.subplot(1, 2, 2)
        plt.plot(history['val_auc'], label='Validation AUC', color='green')
        plt.title('Validation AUC')
        plt.xlabel('Epoch')
        plt.ylabel('AUC Score')
        plt.legend()
        
        plt.tight_layout()
        plt.savefig(save_path)
        plt.close()

    @staticmethod
    def plot_confusion_matrix(labels, predictions, class_names, save_path='confusion_matrix.png'):
        """Create and save confusion matrix visualization"""
        # Get predicted class labels
        pred_labels = np.argmax(predictions, axis=1)
        
        # Compute confusion matrix
        cm = confusion_matrix(labels, pred_labels)
        
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                    xticklabels=class_names, 
                    yticklabels=class_names)
        plt.title('Confusion Matrix')
        plt.xlabel('Predicted Label')
        plt.ylabel('True Label')
        plt.tight_layout()
        plt.savefig(save_path)
        plt.close()

    @staticmethod
    def plot_roc_curve(labels, predictions, class_names, save_path='roc_curve.png'):
        """Plot ROC curves for multiclass classification"""
        plt.figure(figsize=(10, 8))
        
        # One-vs-Rest ROC Curves
        for i in range(len(class_names)):
            # Create binary labels for current class
            binary_labels = (labels == i).astype(int)
            class_preds = predictions[:, i]
            
            # Compute ROC curve
            fpr, tpr, _ = roc_curve(binary_labels, class_preds)
            
            # Calculate AUC
            roc_auc = roc_auc_score(binary_labels, class_preds)
            
            # Plot ROC curve
            plt.plot(fpr, tpr, 
                     label=f'ROC curve (class: {class_names[i]}, AUC = {roc_auc:.2f})')
        
        plt.plot([0, 1], [0, 1], 'k--')  # Diagonal line
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('Receiver Operating Characteristic (ROC) Curve')
        plt.legend(loc="lower right")
        plt.tight_layout()
        plt.savefig(save_path)
        plt.close()

In [15]:
def main():
    # Configuration
    config_class = DeepLensConfig
    data_dir = '/kaggle/input/commontest/dataset'
    
    # Define classes explicitly
    CLASS_NAMES = ['no', 'sphere', 'vort']

    

    # MANUAL PARAMETER INPUT

    BEST_PARAMS = {
        'lr': 0.0007166,          # Learning rate
        'batch_size': 32,          # Batch size
        'weight_decay': 0.0002875, # Weight decay
        'model_name': 'efficientnet_b0', # Model architecture
        'dropout_rate': 0.286      # Dropout rate
    }

    # Create configuration with manual parameters
    config = config_class()
    config.lr = BEST_PARAMS['lr']
    config.batch_size = BEST_PARAMS['batch_size']
    config.weight_decay = BEST_PARAMS['weight_decay']
    config.model_name = BEST_PARAMS['model_name']
    config.dropout_rate = BEST_PARAMS['dropout_rate']

    # Prepare data
    trainer = ModelTrainer(config)
    train_df, val_df = trainer.prepare_data(data_dir)
    
    # Create data loaders
    train_loader, val_loader = trainer.create_dataloaders(train_df, val_df)
    
    # Train model with best configuration
    model, predictions, labels, history = trainer.train(train_loader, val_loader)
    
    # Visualization
    visualizer = ModelVisualizer()
    
    # 1. Training History Plot
    visualizer.plot_training_history(history)
    
    # 2. Confusion Matrix (pass class names explicitly)
    visualizer.plot_confusion_matrix(labels, predictions, CLASS_NAMES)
    
    # 3. ROC Curves (pass class names explicitly)
    visualizer.plot_roc_curve(labels, predictions, CLASS_NAMES)
    
    # 4. Save Best Model with Additional Metadata
    model_save_path = 'best_model_full.pth'
    torch.save({
        'model_state_dict': model.state_dict(),
        'best_config': BEST_PARAMS,
        'class_names': CLASS_NAMES,
        'training_history': history
    }, model_save_path)
    
    

if __name__ == "__main__":
    main()

Epoch 1/10
Train Loss: 0.9165
Val Loss: 0.7633
Val AUC: 0.8667
Epoch 2/10
Train Loss: 0.5824
Val Loss: 0.5417
Val AUC: 0.9323
Epoch 3/10
Train Loss: 0.4247
Val Loss: 0.4203
Val AUC: 0.9517
Epoch 4/10
Train Loss: 0.3409
Val Loss: 0.3521
Val AUC: 0.9646
Epoch 5/10
Train Loss: 0.2768
Val Loss: 0.2951
Val AUC: 0.9742
Epoch 6/10
Train Loss: 0.2369
Val Loss: 0.3264
Val AUC: 0.9748
Epoch 7/10
Train Loss: 0.2068
Val Loss: 0.3187
Val AUC: 0.9757
Epoch 8/10
Train Loss: 0.1837
Val Loss: 0.2549
Val AUC: 0.9807
Epoch 9/10
Train Loss: 0.1591
Val Loss: 0.2748
Val AUC: 0.9799
Epoch 10/10
Train Loss: 0.1387
Val Loss: 0.2762
Val AUC: 0.9825
