In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the path of every file under the Kaggle input mount, one per line
input_files = (
    os.path.join(root, name)
    for root, _, names in os.walk('/kaggle/input')
    for name in names
)
for path in input_files:
    print(path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms, models
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import os
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')

# GPU Configuration
def setup_device():
    """Detect CUDA devices, print a short summary, and choose the compute device.

    Returns:
        (device, num_gpus): ``torch.device("cuda")`` and the visible GPU count
        when CUDA is available, otherwise ``torch.device("cpu")`` and 0.
    """
    # Guard clause: no CUDA means CPU with zero GPUs.
    if not torch.cuda.is_available():
        print("No GPU available, using CPU")
        return torch.device("cpu"), 0

    gpu_count = torch.cuda.device_count()
    print(f"Number of GPUs available: {gpu_count}")

    for idx in range(gpu_count):
        name = torch.cuda.get_device_name(idx)
        mem_gb = torch.cuda.get_device_properties(idx).total_memory / 1e9
        print(f"GPU {idx}: {name}, Memory: {mem_gb:.2f} GB")

    # Informational only; the actual DataParallel wrapping happens when the
    # model is created.
    if gpu_count > 1:
        print(f"\n✓ Using DataParallel with {gpu_count} GPUs for faster training!")
    else:
        print("\n✓ Using single GPU")

    return torch.device("cuda"), gpu_count

device, num_gpus = setup_device()

# Reproducibility: seed the NumPy and PyTorch RNGs with a fixed value.
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    # Force deterministic cuDNN kernels and disable per-run algorithm
    # autotuning, then seed every visible GPU.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.cuda.manual_seed_all(42)

# Configuration
class Config:
    """Paths and hyperparameters for the whole notebook.

    Every setting is a class attribute evaluated once, when the class body
    runs. The batch size, learning rate and worker count are derived from the
    module-level ``num_gpus`` returned by ``setup_device()``.
    """
    # Paths - Update this path based on your Kaggle dataset location
    data_dir = '/kaggle/input/labeled-mri-brain-tumor-dataset/Brain Tumor labeled dataset'

    # Hyperparameters (adjusted for multi-GPU)
    base_batch_size = 32  # Batch size per GPU
    batch_size = base_batch_size * max(1, num_gpus)  # Total batch size
    num_epochs = 25
    base_learning_rate = 0.001
    learning_rate = base_learning_rate * max(1, num_gpus)  # Linear scaling rule
    # 4 loader workers per GPU, but never more than the CPUs actually
    # available: extra worker processes only contend for the same cores.
    num_workers = min(4 * max(1, num_gpus), os.cpu_count() or 1)

    # Model
    model_name = 'resnet101'  # Can be 'resnet18', 'resnet34', 'resnet50', 'resnet101'
    pretrained = True
    # True freezes every pretrained parameter; only the newly created
    # classification head is trained.
    freeze_backbone = False

    # Data split (test gets whatever remains after train and val)
    train_split = 0.8
    val_split = 0.1
    test_split = 0.1

    # Image settings
    img_size = 224  # square resize edge, in pixels

    # Training settings
    gradient_accumulation_steps = 1
    mixed_precision = True  # Use mixed precision training for faster computation
    
config = Config()

# Log the effective (GPU-scaled) settings, one per line.
print(
    "\nTraining Configuration:",
    f"Total Batch Size: {config.batch_size}",
    f"Learning Rate: {config.learning_rate}",
    f"Number of Workers: {config.num_workers}",
    f"Mixed Precision Training: {config.mixed_precision}",
    sep="\n",
)

# Data Transforms
def get_transforms():
    """Build the training and evaluation image pipelines.

    Both pipelines resize to a ``config.img_size`` square, convert to a tensor
    and normalize with the ImageNet mean/std that torchvision's pretrained
    models expect. The training pipeline additionally applies random flips,
    rotation, translation and colour jitter.

    Returns:
        (train_transform, val_test_transform)
    """
    size = (config.img_size, config.img_size)
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    augmentations = [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    ]

    train_transform = transforms.Compose(
        [transforms.Resize(size), *augmentations, transforms.ToTensor(), normalize]
    )
    val_test_transform = transforms.Compose(
        [transforms.Resize(size), transforms.ToTensor(), normalize]
    )
    return train_transform, val_test_transform

# Load Dataset
def load_data():
    """Load the ImageFolder dataset, split it, and build the three DataLoaders.

    The split sizes come from ``config.train_split`` and ``config.val_split``;
    the test split receives the remainder. The split is seeded, so the
    partition is identical on every run. Training samples use the augmented
    transform; validation and test samples use the deterministic one.

    Returns:
        (train_loader, val_loader, test_loader, class_names, num_classes)
    """
    train_transform, val_test_transform = get_transforms()

    # Two ImageFolder views over the same files: augmented for training,
    # deterministic for evaluation. The Subsets returned by random_split all
    # share ONE parent dataset, so assigning ``subset.dataset.transform``
    # would switch the transform for every split -- including training, which
    # would silently lose its augmentation.
    train_base = datasets.ImageFolder(root=config.data_dir, transform=train_transform)
    eval_base = datasets.ImageFolder(root=config.data_dir, transform=val_test_transform)

    # Get class names
    class_names = train_base.classes
    num_classes = len(class_names)
    print(f"\nDataset Information:")
    print(f"Classes: {class_names}")
    print(f"Number of classes: {num_classes}")

    # Calculate split sizes
    total_size = len(train_base)
    train_size = int(config.train_split * total_size)
    val_size = int(config.val_split * total_size)
    test_size = total_size - train_size - val_size

    # Seeded split -> the same index partition on every run.
    train_dataset, val_part, test_part = random_split(
        train_base, [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(42)
    )
    # Re-point the held-out indices at the non-augmented view.
    val_dataset = torch.utils.data.Subset(eval_base, val_part.indices)
    test_dataset = torch.utils.data.Subset(eval_base, test_part.indices)

    train_loader = _make_loader(train_dataset, shuffle=True)
    val_loader = _make_loader(val_dataset, shuffle=False)
    test_loader = _make_loader(test_dataset, shuffle=False)

    print(f"Dataset sizes - Train: {train_size}, Val: {val_size}, Test: {test_size}")

    return train_loader, val_loader, test_loader, class_names, num_classes


def _make_loader(dataset, shuffle):
    """Build a DataLoader with the shared batch/worker settings from ``config``."""
    kwargs = {
        'batch_size': config.batch_size,
        'shuffle': shuffle,
        'num_workers': config.num_workers,
        # Pinned host memory only helps host->GPU copies.
        'pin_memory': torch.cuda.is_available(),
    }
    # persistent_workers / prefetch_factor only make sense with worker
    # processes; recent DataLoader versions reject prefetch_factor when
    # num_workers == 0.
    if config.num_workers > 0:
        kwargs['persistent_workers'] = True
        kwargs['prefetch_factor'] = 2
    return DataLoader(dataset, **kwargs)

# Create Model
def create_model(num_classes):
    """Build a torchvision ResNet with a new classification head.

    The architecture is chosen by ``config.model_name`` and optionally loaded
    with pretrained weights (``config.pretrained``). The final ``fc`` layer is
    replaced by a dropout/MLP head ending in ``num_classes`` outputs. The model
    is moved to ``device`` and wrapped in ``nn.DataParallel`` when more than
    one GPU is visible.

    Args:
        num_classes: number of output classes for the new head.

    Returns:
        The configured model (possibly DataParallel-wrapped).

    Raises:
        ValueError: if ``config.model_name`` is not a supported ResNet.
    """
    model_dict = {
        'resnet18': models.resnet18,
        'resnet34': models.resnet34,
        'resnet50': models.resnet50,
        'resnet101': models.resnet101,
        'resnet152': models.resnet152,
    }
    if config.model_name not in model_dict:
        raise ValueError(
            f"Unsupported model_name {config.model_name!r}; "
            f"expected one of {sorted(model_dict)}"
        )

    # Load (optionally pretrained) backbone
    model = model_dict[config.model_name](pretrained=config.pretrained)

    # Freezes EVERY existing parameter. The fc head created below is new, so
    # it remains trainable and is the only part that learns.
    if config.freeze_backbone:
        for param in model.parameters():
            param.requires_grad = False

    # Replace the final fully connected layer with a small MLP head
    num_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(0.5),
        nn.Linear(num_features, 512),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(512, num_classes)
    )

    model = model.to(device)

    # Wrap model with DataParallel for multi-GPU training
    if num_gpus > 1:
        model = nn.DataParallel(model)
        print(f"\n✓ Model wrapped with DataParallel for {num_gpus} GPUs")

    return model

# Mixed Precision Training Setup
# Loss scaler for mixed-precision training. It stays None when AMP is disabled
# or no CUDA device exists, and the training loop then runs in full precision.
if config.mixed_precision and torch.cuda.is_available():
    scaler = torch.cuda.amp.GradScaler()
else:
    scaler = None

# Training function with multi-GPU optimization
def train_epoch(model, dataloader, criterion, optimizer):
    """Train ``model`` for one pass over ``dataloader``.

    Mixed precision is used when ``config.mixed_precision`` is set and the
    module-level ``scaler`` exists. Gradients are accumulated over
    ``config.gradient_accumulation_steps`` batches before each optimizer step.

    Returns:
        (epoch_loss, epoch_acc): sample-weighted mean loss and accuracy in
        percent (both 0.0 for an empty loader).
    """
    model.train()
    accum_steps = config.gradient_accumulation_steps
    use_amp = config.mixed_precision and scaler is not None
    num_batches = len(dataloader)
    running_loss = 0.0
    correct = 0
    total = 0

    # Start from clean gradients so nothing carries over from earlier calls.
    optimizer.zero_grad()

    progress_bar = tqdm(dataloader, desc='Training')
    for batch_idx, (images, labels) in enumerate(progress_bar):
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        if use_amp:
            with torch.cuda.amp.autocast():
                outputs = model(images)
                loss = criterion(outputs, labels) / accum_steps
            scaler.scale(loss).backward()
        else:
            outputs = model(images)
            loss = criterion(outputs, labels) / accum_steps
            loss.backward()

        # Step at the end of each accumulation window AND on the final batch,
        # so a trailing partial window is applied instead of leaking its
        # gradients into the next epoch.
        if (batch_idx + 1) % accum_steps == 0 or (batch_idx + 1) == num_batches:
            if use_amp:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            optimizer.zero_grad()

        # Statistics (weighted by batch size so a short last batch does not
        # skew the epoch mean)
        batch_loss = loss.item() * accum_steps
        batch_size = labels.size(0)
        running_loss += batch_loss * batch_size
        predicted = outputs.detach().argmax(dim=1)
        total += batch_size
        correct += (predicted == labels).sum().item()

        progress_bar.set_postfix({
            'loss': batch_loss,
            'acc': 100. * correct / total,
            'gpu_mem': f'{torch.cuda.memory_allocated()/1e9:.2f}GB' if torch.cuda.is_available() else 'N/A'
        })

    if total == 0:
        return 0.0, 0.0
    epoch_loss = running_loss / total
    epoch_acc = 100. * correct / total

    return epoch_loss, epoch_acc

# Validation function with multi-GPU support
def validate(model, dataloader, criterion):
    """Evaluate ``model`` on ``dataloader`` without updating its weights.

    Mixed precision is used when ``config.mixed_precision`` is set and CUDA
    is available.

    Returns:
        (epoch_loss, epoch_acc, all_predictions, all_labels): sample-weighted
        mean loss, accuracy in percent, and the predicted / true class indices
        as Python int lists in loader order. Loss and accuracy are 0.0 for an
        empty loader.
    """
    model.eval()
    use_amp = config.mixed_precision and torch.cuda.is_available()
    running_loss = 0.0
    correct = 0
    total = 0
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc='Validation'):
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            if use_amp:
                with torch.cuda.amp.autocast():
                    outputs = model(images)
                    loss = criterion(outputs, labels)
            else:
                outputs = model(images)
                loss = criterion(outputs, labels)

            # Weight by batch size so a short final batch doesn't skew the
            # mean (assumes criterion averages within a batch, as the default
            # nn.CrossEntropyLoss() used by the callers does).
            batch_size = labels.size(0)
            running_loss += loss.item() * batch_size
            predicted = outputs.argmax(dim=1)
            total += batch_size
            correct += (predicted == labels).sum().item()

            all_predictions.extend(predicted.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    # Guard against an empty loader (e.g. a split that rounded down to 0).
    if total == 0:
        return 0.0, 0.0, all_predictions, all_labels

    epoch_loss = running_loss / total
    epoch_acc = 100. * correct / total

    return epoch_loss, epoch_acc, all_predictions, all_labels

# Training loop with learning rate warmup for multi-GPU
def train_model(model, train_loader, val_loader, num_epochs):
    """Train ``model`` and restore the weights with the best validation accuracy.

    Uses AdamW (weight decay 0.01) with a per-batch OneCycleLR schedule: a
    ``warmup_epochs`` ramp up to ``config.learning_rate``, then cosine
    annealing. After each epoch the model is validated. A detached CPU copy of
    the best-performing weights is kept and loaded back before returning.

    Args:
        model: the network, possibly wrapped in ``nn.DataParallel``.
        train_loader: DataLoader for the training split.
        val_loader: DataLoader for the validation split.
        num_epochs: number of passes over ``train_loader``.

    Returns:
        (model, history): history maps 'train_loss', 'train_acc', 'val_loss'
        and 'val_acc' to per-epoch lists.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=0.01)

    warmup_epochs = 3
    # An epoch performs at most ceil(batches / accumulation_steps) optimizer
    # steps, which is the schedule's per-epoch step budget.
    steps_per_epoch = -(-len(train_loader) // config.gradient_accumulation_steps)
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=config.learning_rate,
        epochs=num_epochs,
        steps_per_epoch=steps_per_epoch,
        pct_start=warmup_epochs/num_epochs,
        anneal_strategy='cos'
    )
    # train_epoch() never calls scheduler.step(), so advance the per-batch
    # schedule from an optimizer post-step hook (torch >= 2.0). The hook fires
    # only when optimizer.step() really runs, so steps that GradScaler skips
    # (inf/NaN grads) do not advance the LR.
    step_hook = optimizer.register_step_post_hook(lambda *_: scheduler.step())

    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': []
    }

    best_val_acc = 0.0
    best_model_state = None
    # DataParallel keeps the real network under .module.
    base_model = model.module if isinstance(model, nn.DataParallel) else model

    print(f"\nStarting training on {num_gpus if num_gpus > 0 else 1} device(s)...")
    print("="*60)

    for epoch in range(num_epochs):
        print(f'\nEpoch [{epoch+1}/{num_epochs}]')
        print('-' * 60)

        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer)
        val_loss, val_acc, _, _ = validate(model, val_loader, criterion)

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        current_lr = optimizer.param_groups[0]['lr']
        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
        print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
        print(f'Learning Rate: {current_lr:.6f}')

        if torch.cuda.is_available():
            for i in range(num_gpus):
                memory_allocated = torch.cuda.memory_allocated(i) / 1e9
                memory_reserved = torch.cuda.memory_reserved(i) / 1e9
                print(f'GPU {i} - Allocated: {memory_allocated:.2f}GB, Reserved: {memory_reserved:.2f}GB')

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # state_dict() returns references to the live tensors and
            # dict.copy() is shallow, so take real copies (on CPU to spare GPU
            # memory). Otherwise later epochs would overwrite the "best" weights.
            best_model_state = {
                key: tensor.detach().to('cpu', copy=True)
                for key, tensor in base_model.state_dict().items()
            }
            print(f'✓ Best model saved with validation accuracy: {best_val_acc:.2f}%')

        # Clear cache to free memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    step_hook.remove()

    # Load best model
    if best_model_state is not None:
        base_model.load_state_dict(best_model_state)

    return model, history

# Visualization functions
def plot_training_history(history):
    """Plot per-epoch loss and accuracy curves side by side.

    Args:
        history: dict with 'train_loss', 'val_loss', 'train_acc' and
            'val_acc' lists, one entry per epoch.
    """
    # 1-based epoch numbers so the x-axis matches the "Epoch [k/N]" console log.
    epochs = range(1, len(history['train_loss']) + 1)
    _, axes = plt.subplots(1, 2, figsize=(15, 5))

    # (history key suffix, legend name, y-axis label, panel title)
    panels = (
        ('loss', 'Loss', 'Loss', 'Training and Validation Loss'),
        ('acc', 'Accuracy', 'Accuracy (%)', 'Training and Validation Accuracy'),
    )
    for ax, (key, name, ylabel, title) in zip(axes, panels):
        ax.plot(epochs, history[f'train_{key}'], label=f'Train {name}', linewidth=2)
        ax.plot(epochs, history[f'val_{key}'], label=f'Validation {name}', linewidth=2)
        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)
        ax.set_title(title, fontsize=14)
        ax.legend(fontsize=11)
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

def plot_confusion_matrix(y_true, y_pred, class_names):
    """Plot a count-annotated confusion matrix heatmap.

    Args:
        y_true: true class indices.
        y_pred: predicted class indices.
        class_names: display names, indexed by class index.
    """
    # Pin rows/columns to every class index so the matrix is always
    # len(class_names) x len(class_names) and lines up with the tick labels,
    # even when some class never appears in y_true or y_pred.
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))

    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names,
                cbar_kws={'label': 'Count'})
    plt.title('Confusion Matrix', fontsize=16)
    plt.xlabel('Predicted Label', fontsize=12)
    plt.ylabel('True Label', fontsize=12)
    plt.tight_layout()
    plt.show()

# Test function
def test_model(model, test_loader, class_names):
    """Evaluate on the test set, print a classification report, and plot the confusion matrix.

    Args:
        model: trained network (possibly DataParallel-wrapped).
        test_loader: DataLoader for the test split.
        class_names: display names, indexed by class index.

    Returns:
        Test accuracy in percent.
    """
    criterion = nn.CrossEntropyLoss()

    print("\nEvaluating on test set...")
    test_loss, test_acc, predictions, labels = validate(model, test_loader, criterion)

    print(f'\nTest Loss: {test_loss:.4f}')
    print(f'Test Accuracy: {test_acc:.2f}%')

    # Pass the full label set explicitly. Otherwise a class missing from both
    # labels and predictions makes classification_report reject target_names
    # (length mismatch). Classes with no predictions report 0 precision.
    class_ids = list(range(len(class_names)))
    print("\nClassification Report:")
    print("="*60)
    print(classification_report(labels, predictions, labels=class_ids,
                                target_names=class_names, zero_division=0))

    plot_confusion_matrix(labels, predictions, class_names)

    return test_acc

# Display sample predictions
def display_sample_predictions(model, test_loader, class_names, num_samples=8):
    """Show images from the first test batch with their true and predicted labels.

    Up to ``num_samples`` images are drawn in a 4-column grid with as many
    rows as needed. If the first batch holds fewer images, only those are
    shown. Titles are green for correct predictions and red otherwise. Nothing
    is drawn when the loader yields no batches.
    """
    model.eval()
    first_batch = next(iter(test_loader), None)
    if first_batch is None:
        return
    images, labels = first_batch

    # The first batch may be smaller than num_samples.
    n = min(num_samples, images.size(0))
    if n <= 0:
        return
    images, labels = images[:n].to(device), labels[:n]

    with torch.no_grad():
        if config.mixed_precision and torch.cuda.is_available():
            with torch.cuda.amp.autocast():
                outputs = model(images)
        else:
            outputs = model(images)
    # Move predictions to the CPU so they compare directly with the CPU labels.
    predicted = outputs.argmax(dim=1).cpu()

    # Undo Normalize for display (same mean/std as get_transforms).
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    cols = 4
    rows = -(-n // cols)  # ceil(n / cols)
    _, axes = plt.subplots(rows, cols, figsize=(15, 4 * rows), squeeze=False)
    axes = axes.ravel()

    for idx in range(n):
        img = images[idx].cpu() * std + mean
        img = np.clip(np.transpose(img.numpy(), (1, 2, 0)), 0, 1)
        true_idx = int(labels[idx])
        pred_idx = int(predicted[idx])

        axes[idx].imshow(img)
        axes[idx].set_title(f'True: {class_names[true_idx]}\nPred: {class_names[pred_idx]}',
                            color='green' if true_idx == pred_idx else 'red')
        axes[idx].axis('off')

    # Hide any unused cells in the last row.
    for ax in axes[n:]:
        ax.axis('off')

    plt.suptitle('Sample Predictions', fontsize=16)
    plt.tight_layout()
    plt.show()

# Main execution
def main():
    """Run the full pipeline: load data, build, train, evaluate, and save the model.

    Returns:
        (model, history, test_accuracy)
    """
    train_loader, val_loader, test_loader, class_names, num_classes = load_data()

    model = create_model(num_classes)

    # DataParallel keeps the real network under .module.
    base_model = model.module if isinstance(model, nn.DataParallel) else model
    total_params = sum(p.numel() for p in base_model.parameters())
    trainable_params = sum(p.numel() for p in base_model.parameters() if p.requires_grad)

    print(f"\nModel: {config.model_name}")
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    model, history = train_model(model, train_loader, val_loader, config.num_epochs)

    plot_training_history(history)

    test_accuracy = test_model(model, test_loader, class_names)

    display_sample_predictions(model, test_loader, class_names)

    model_path = f'{config.model_name}_brain_tumor_classifier_multigpu.pth'
    model_to_save = model.module if isinstance(model, nn.DataParallel) else model

    # All Config settings are class attributes, so the instance __dict__ is
    # empty and saving config.__dict__ would store {}. Snapshot the public
    # class-level settings, plus any instance-level overrides.
    config_snapshot = {
        name: value
        for name, value in vars(type(config)).items()
        if not name.startswith('_') and not callable(value)
    }
    config_snapshot.update(vars(config))

    torch.save({
        'model_state_dict': model_to_save.state_dict(),
        'class_names': class_names,
        'test_accuracy': test_accuracy,
        'config': config_snapshot,
        'num_gpus_trained': num_gpus
    }, model_path)

    print(f"\nModel saved to {model_path}")
    print(f"Training completed using {num_gpus} GPU(s)!")

    return model, history, test_accuracy

# Run the training
if __name__ == "__main__":
    model, history, test_accuracy = main()