In [3]:
import pickle
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torchvision import transforms
from sklearn.metrics import roc_auc_score, accuracy_score
from sklearn.utils.class_weight import compute_class_weight
import matplotlib.pyplot as plt
from pathlib import Path
import json
from datetime import datetime
import logging
import cv2
import random
from tensorflow.keras.preprocessing.image import ImageDataGenerator

2024-12-06 07:05:42.186529: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1733468742.244076    1414 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1733468742.257987    1414 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-12-06 07:05:42.382812: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [4]:
# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('training.log'),
        logging.StreamHandler()
    ]
)

# Create directories for saving results
Path("models").mkdir(exist_ok=True)
Path("results").mkdir(exist_ok=True)
Path("plots").mkdir(exist_ok=True)

def load_data():
    """Load the pickle files containing the dataset."""
    
    with open('train_data.pkl', 'rb') as f:
        train_data = pickle.load(f)
    
    
    with open('test_data.pkl', 'rb') as f:
        test_data = pickle.load(f)
    
    X_train = np.array(train_data['images'])
    y_train = np.array(train_data['labels'])
    X_test = np.array(test_data['images'])

    class_weights = compute_class_weight('balanced', classes=np.unique(y_train), y=y_train)
    class_weights = torch.tensor(class_weights, dtype=torch.float)
    
    logging.info(f"Training data shape: {X_train.shape}")
    logging.info(f"Test data shape: {X_test.shape}")
    logging.info(f"Label distribution: {np.bincount(y_train)}")
    
    return X_train, y_train, X_test, class_weights


In [8]:
# Early stopping implementation
class EarlyStopping:
    def __init__(self, patience=3, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
        
    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0

X_train, y_train, X_test, class_weights = load_data()

# Temperature Scaled Cross Entropy Loss
class TemperatureScaledCrossEntropyLoss(nn.Module):
    def __init__(self, temperature=2.0):
        super(TemperatureScaledCrossEntropyLoss, self).__init__()
        self.temperature = temperature
        self.criterion = nn.CrossEntropyLoss(weight=class_weights)

    def forward(self, logits, targets):
        scaled_logits = logits / self.temperature
        return self.criterion(scaled_logits, targets)

class Trainer:
    def __init__(self, model, device, config):
        self.model = model.to(device)
        self.device = device
        self.config = config
        self.class_weights = class_weights
        self.optimizer = optim.Adam(
            model.parameters(),
            lr=config['learning_rate'],
            weight_decay=1e-5
        )

        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='max',
            factor=0.5,
            patience=2,
        )

        
        self.criterion = TemperatureScaledCrossEntropyLoss(temperature=2.0)
        
        self.train_losses = []
        self.val_losses = []
        self.train_accs = []
        self.val_accs = []

        # Initialization of early stopping
        self.early_stopping = EarlyStopping(patience=3)
    
    def train_epoch(self, train_loader):
        self.model.train()
        total_loss = 0
        correct = 0
        total = 0
        
        for batch_idx, (data, target) in enumerate(train_loader):
            data = data.to(self.device)
            target = target.to(self.device)
            
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)
            
            loss.backward()
            self.optimizer.step()
            
            total_loss += loss.item()
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
            
        return total_loss / len(train_loader), 100. * correct / total

    def validate(self, val_loader):
        self.model.eval()
        val_loss = 0
        correct = 0
        total = 0
        all_preds = []
        all_targets = []
        
        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                val_loss += self.criterion(output, target).item()
                
                _, predicted = output.max(1)
                total += target.size(0)
                correct += predicted.eq(target).sum().item()
                
                all_preds.extend(F.softmax(output, dim=1).cpu().numpy())
                all_targets.extend(target.cpu().numpy())
        
        val_loss /= len(val_loader)
        accuracy = 100. * correct / total
        
        # Calculate AUC for each class
        all_preds = np.array(all_preds)
        all_targets = np.array(all_targets)
        aucs = []
        for i in range(4):
            y_true = (all_targets == i).astype(int)
            y_pred = all_preds[:, i]
            try:
                auc = roc_auc_score(y_true, y_pred)
                aucs.append(auc)
            except:
                aucs.append(0.0)
        
        return val_loss, accuracy, aucs

def augment_minority_classes(X_train, y_train, transform):
    """Augment minority classes with transformations"""
    augmented_images = []
    augmented_labels = []
    
    augmentation_ratios = {
        0: 3,  # Double class 0 samples
        1: 9,   # sixtuple class 1 samples
        2: 12,   # octuple class 2 samples
        3: 2 # slightly increase class 3 samples
    }
    
    for class_label, ratio in augmentation_ratios.items():
        mask = y_train == class_label
        class_images = X_train[mask]
        
        for img in class_images:
            for _ in range(ratio):
                img_tensor = np.expand_dims(img, -1)  # Add channel dimension
                img_tensor = np.expand_dims(img_tensor, 0) # Add batch dimension
                augmented = next(transform.flow(img_tensor, batch_size=1))[0]
                augmented_images.append(augmented)
                augmented_labels.append(class_label)
    
    augmented_images = np.array(augmented_images)
    augmented_labels = np.array(augmented_labels)

    num_augmented_images = len(augmented_images)
    print(f"Number of augmented images: {num_augmented_images}")
    logging.info(f"Number of augmented images: {num_augmented_images}")
    
    X_train_aug = np.concatenate((np.expand_dims(X_train, -1), augmented_images), axis=0)
    y_train_aug = np.concatenate((y_train, augmented_labels), axis=0)
    X_train_aug = X_train_aug.squeeze()
    
    return X_train_aug, y_train_aug        

2024-12-06 07:14:12,840 - INFO - Training data shape: (97477, 28, 28)
2024-12-06 07:14:12,841 - INFO - Test data shape: (1000, 28, 28)
2024-12-06 07:14:12,842 - INFO - Label distribution: [33484 10213  7754 46026]


In [9]:
class VerticalAttentionBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.vertical_conv = nn.Conv2d(
            in_channels, in_channels, 
            kernel_size=(3, 1),  # Vertical kernel
            padding=(1, 0)
        )
        self.bn = nn.BatchNorm2d(in_channels)
        
    def forward(self, x):
        vertical_features = self.vertical_conv(x)
        return x + self.bn(vertical_features)

class ModifiedResBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # 1x1 conv for dimension reduction
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, 
                              stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        
        # 3x3 conv with vertical attention
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                              stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.vertical_attn = VerticalAttentionBlock(out_channels)
        
        # 1x1 conv for dimension expansion
        self.conv3 = nn.Conv2d(out_channels, out_channels * 4, kernel_size=1,
                              stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels * 4)
        
        self.relu = nn.ReLU(inplace=True)
        
        # Adjust shortcut for dimension matching
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels * 4:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * 4, kernel_size=1,
                         stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * 4)
            )
            
    def forward(self, x):
        identity = x
        
        # First bottleneck transformation
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        
        # Second transformation with vertical attention
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.vertical_attn(out)
        out = self.relu(out)
        
        # Final expansion
        out = self.conv3(out)
        out = self.bn3(out)
        
        # Add identity
        out += self.shortcut(identity)
        out = self.relu(out)
        
        return out
        
class CustomResNet(nn.Module):
    def __init__(self, num_classes=4):
        super().__init__()
        self.in_channels = 64
        # Modified initial conv to be more sensitive to vertical features
        self.conv1 = nn.Conv2d(1, 64, kernel_size=(5,3), stride=1, padding=(2,1), bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.vertical_attn1 = VerticalAttentionBlock(64)
        self.relu = nn.ReLU(inplace=True)
        
        # Progressive channel expansion matching ResNet28x28
        self.layer1 = self.make_layer(64, 2, stride=1)    # 64 -> 256
        self.layer2 = self.make_layer(128, 2, stride=2)   # 256 -> 512
        self.layer3 = self.make_layer(256, 2, stride=2)   # 512 -> 1024
        self.layer4 = self.make_layer(512, 2, stride=1)   # 1024 -> 2048
        
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(2048, num_classes)  # 512 * 4 = 2048

    def make_layer(self, out_channels, num_blocks, stride):
        layers = []
        # First block handles stride and channel change
        layers.append(ModifiedResBlock(self.in_channels, out_channels, stride))
        
        # Update input channels for subsequent blocks
        self.in_channels = out_channels * 4  # Account for bottleneck expansion
        
        # Remaining blocks maintain channel dimensions
        for _ in range(1, num_blocks):
            layers.append(ModifiedResBlock(self.in_channels, out_channels))
            
        return nn.Sequential(*layers)

    def forward(self, x):
        # Initial convolution with vertical attention
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.vertical_attn1(x)
        x = self.relu(x)
        
        # ResNet blocks
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        
        # Global pooling and classification
        x = self.avg_pool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        
        return x

# Modified preprocessing
def preprocess_image(image):
    # Print shape and type information
    #print("\nInput image shape: ", image.shape)
    #print("\nInput image type: ", image.dtype,)
    #print("\nInput image range:", torch.min(image).item(), torch.max(image).item())
    
    # Convert to numpy and ensure correct format
    image_np = image.numpy()
    
    # Ensure single channel
    if len(image_np.shape) > 2:
        image_np = image_np.squeeze()
        
    # Scale and convert to uint8
    image_np = np.clip(image_np * 255.0, 0, 255).astype(np.uint8)
    
    #print("\nBefore CLAHE shape:", image_np.shape)
    #print("\nBefore CLAHE type:", image_np.dtype)
    #print("\nBefore CLAHE range:", np.min(image_np), np.max(image_np))
    
    # Apply CLAHE
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(4,4))
    enhanced = clahe.apply(image_np)
    
    # Convert back to float32
    enhanced = enhanced.astype(np.float32) / 255.0
    return enhanced

class RetinalDataset(Dataset):
    def __init__(self, images, labels=None, transform=None, is_training=False):
        self.images = torch.FloatTensor(images).unsqueeze(1)
        self.labels = torch.LongTensor(labels) if labels is not None else None
        self.transform = transform
        self.is_training = is_training
        
        # Calculate augmentation probabilities based on class distribution
        if labels is not None:
            label_counts = np.bincount(labels)
            max_count = np.max(label_counts)
            self.aug_probs = {
                i: min(3.0, max_count / count) if count > 0 else 1.0 
                for i, count in enumerate(label_counts)
            }
    
    def __getitem__(self, idx):
        image = self.images[idx]
        
        if self.is_training and self.labels is not None:
            label = self.labels[idx]
            # Apply augmentation with probability based on class
            if random.random() < self.aug_probs[label.item()]:
                if self.transform:
                    image = preprocess_image(image)
                    image = torch.FloatTensor(image).unsqueeze(1)
                    image = self.transform(image)
        
        if self.labels is not None:
            return image, self.labels[idx]
        return image

    def __len__(self):
        return len(self.images)

In [12]:
def main():
    # Configuration
    config = {
        'seed': 42,
        'batch_size': 64,
        'learning_rate': 2e-4,
        'n_epochs': 50,
        'n_folds': 5
    }
    
    # Set random seeds
    torch.manual_seed(config['seed'])
    np.random.seed(config['seed'])
    
    # Load data
    X_train, y_train, X_test, weights = load_data()
    
    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f"Using device: {device}")
    
    # Initialize K-fold cross validation
    skf = StratifiedKFold(n_splits=config['n_folds'], shuffle=True, random_state=config['seed'])
    
    # Transform for augmentation
    #transform = transforms.Compose([
    #    transforms.RandomRotation(degrees=30, fill=0),
    #transforms.RandomAffine(degrees=0, translate=(0.2, 0.2), scale=None, shear=15, fill=0),
    #transforms.RandomHorizontalFlip(),
    #transforms.RandomResizedCrop(size=28, scale=(0.8, 1.2), ratio=(0.8, 1.2)),
#])
    train_transform = transforms.Compose([
        transforms.RandomRotation(30),
        transforms.RandomAffine(
            degrees=0, translate=(0.2, 0.2), 
            scale=(0.95, 1.05), shear=15),
        transforms.RandomResizedCrop(size=28, scale=(0.8, 1.2), ratio=(0.8, 1.2)),
        transforms.RandomHorizontalFlip()
])

    augment_gen = ImageDataGenerator(
    rotation_range=30,
    width_shift_range=0.2,
    height_shift_range=0.2,
    zoom_range=0.2,
    shear_range=15,
    horizontal_flip=True,
    fill_mode="nearest"
)

    # Augment data for classes
    X_train, y_train = augment_minority_classes(X_train, y_train, augment_gen)
    
    X_train = torch.FloatTensor(X_train).unsqueeze(1)
    y_train = y_train.astype(int)
    print(X_train.shape)
    # Store fold results
    fold_results = []
    
    # Transform for training
    transform = transforms.Compose([
    #transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])  # Normalize
])

    # K-fold Cross Validation
    for fold, (train_idx, val_idx) in enumerate(skf.split(X_train, y_train)):
        logging.info(f"Training Fold {fold + 1}/{config['n_folds']}")
        
        # Split data
        X_train_fold = X_train[train_idx]
        y_train_fold = y_train[train_idx]
        X_val_fold = X_train[val_idx]
        y_val_fold = y_train[val_idx]
        
        # Create datasets
        train_dataset = RetinalDataset(X_train_fold, y_train_fold, transform=transform, is_training=True)
        val_dataset = RetinalDataset(X_val_fold, y_val_fold, transform=transform, is_training=True)
        
        # Create data loaders
        train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, num_workers=4, pin_memory=True)
        val_loader = DataLoader(val_dataset, batch_size=config['batch_size'], num_workers=4, pin_memory=True)
        
        # Initialize model and trainer
        model = CustomResNet()
        trainer = Trainer(model, device, config)
        
        # Training loop
        best_val_acc = 0
        for epoch in range(config['n_epochs']):
            train_loss, train_acc = trainer.train_epoch(train_loader)
            val_loss, val_acc, aucs = trainer.validate(val_loader)

            trainer.scheduler.step(val_acc)
            
            logging.info(f"Fold {fold + 1}, Epoch {epoch + 1}/{config['n_epochs']}")
            logging.info(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
            logging.info(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
            logging.info(f"AUC per class: {[f'{auc:.4f}' for auc in aucs]}")
            logging.info(f"Learning Rate: {trainer.optimizer.param_groups[0]['lr']:.6f}")
            
            # Save best model
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save(model.state_dict(), f"models/best_model_fold_{fold + 1}.pth")

            # Early stopping check
            trainer.early_stopping(val_loss)
            if trainer.early_stopping.early_stop:
                logging.info(f"Early stopping triggered at epoch {epoch + 1}")
                break
        
        fold_results.append({
            'fold': fold + 1,
            'best_val_acc': best_val_acc,
            'final_aucs': aucs,
            'final epoch': epoch + 1
        })
    
    # Save fold results
    with open('results/fold_results.json', 'w') as f:
        json.dump(fold_results, f, indent=4)
    
    # Generate predictions for test set
    test_dataset = RetinalDataset(X_test, transform=None)
    test_loader = DataLoader(test_dataset, batch_size=config['batch_size'])
    
    # Ensemble predictions from all folds
    all_predictions = []
    for fold in range(config['n_folds']):
        model = CustomResNet()
        model.load_state_dict(torch.load(f"models/best_model_fold_{fold + 1}.pth"))
        model.to(device)
        model.eval()
        
        fold_predictions = []
        with torch.no_grad():
            for data in test_loader:
                data = data.to(device)
                output = model(data)
                probabilities = F.softmax(output, dim=1)
                fold_predictions.extend(probabilities.cpu().numpy())
        
        all_predictions.append(fold_predictions)
    
    # Average predictions from all folds
    final_predictions = np.mean(all_predictions, axis=0)
    predicted_labels = np.argmax(final_predictions, axis=1)
    
    # Save predictions
    pd.DataFrame({
        'ID': range(1, len(predicted_labels) + 1),
        'Class': predicted_labels
    }).to_csv('results/test_predictions.csv', index=False)

if __name__ == "__main__":
    main()

2024-12-06 07:25:54,449 - INFO - Training data shape: (97477, 28, 28)
2024-12-06 07:25:54,450 - INFO - Test data shape: (1000, 28, 28)
2024-12-06 07:25:54,451 - INFO - Label distribution: [33484 10213  7754 46026]
2024-12-06 07:25:54,464 - INFO - Using device: cuda
2024-12-06 07:27:07,328 - INFO - Number of augmented images: 377469


Number of augmented images: 377469
torch.Size([474946, 1, 28, 28])


2024-12-06 07:27:07,544 - INFO - Training Fold 1/5


RuntimeError: Given groups=1, weight of size [64, 1, 5, 3], expected input[64, 28, 1, 28] to have 1 channels, but got 28 channels instead