# Fracture Detection Model Training with PyTorch

This notebook fine-tunes three ImageNet-pretrained convolutional networks (ResNet50, DenseNet121, and EfficientNetB0) to classify X-ray images as fractured or non-fractured. All three are trained on the Kaggle Fracture Multi-Region X-ray dataset.

**Dataset:** [Fracture Multi-Region X-ray Data](https://www.kaggle.com/datasets/bmadushanirodrigo/fracture-multi-region-x-ray-data/data)

**Models:**
- ResNet50
- DenseNet121  
- EfficientNetB0

**Environment:** Google Colab with free GPU

## 1. Environment Setup and Dependencies

In [None]:
# Check if running on Colab
import sys
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    print("Running on Google Colab")
    # Install required packages
    !pip install kaggle
    !pip install efficientnet-pytorch
    !pip install albumentations
else:
    print("Not running on Colab")

In [None]:
# Import required libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision
from torchvision import transforms, models
import torch.nn.functional as F

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import os
import cv2
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
import albumentations as A
from albumentations.pytorch import ToTensorV2

# For EfficientNet
try:
    from efficientnet_pytorch import EfficientNet
except ImportError:
    print("EfficientNet not installed, installing now...")
    !pip install efficientnet-pytorch
    from efficientnet_pytorch import EfficientNet

import warnings
warnings.filterwarnings('ignore')

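# Set random seeds for reproducibility. Python's random module may be used by augmentation
# libraries, and torch.manual_seed also seeds all CUDA devices. Exact GPU determinism would
# additionally require deterministic cuDNN settings.
import random
random.seed(42)
torch.manual_seed(42)
np.random.seed(42)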

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
    print("Using CPU")

## 2. Kaggle API Setup and Dataset Download

In [None]:
# Mount Google Drive and setup Kaggle API
if IN_COLAB:
    # Mount Google Drive
    from google.colab import drive
    drive.mount('/content/drive')

    # Create necessary directories in Drive
    import os
    drive_base = '/content/drive/MyDrive/Capstone'
    models_dir = f'{drive_base}/Models'
    dataset_dir = f'{drive_base}/Dataset'

    os.makedirs(drive_base, exist_ok=True)
    os.makedirs(models_dir, exist_ok=True)
    os.makedirs(dataset_dir, exist_ok=True)

    print(f"Created directories:")
    print(f"- Base: {drive_base}")
    print(f"- Models: {models_dir}")
    print(f"- Dataset: {dataset_dir}")

    # Setup Kaggle API
    from google.colab import files
    print("\nPlease upload your kaggle.json file:")
    uploaded = files.upload()

    # Move the uploaded credentials to ~/.kaggle, where the Kaggle CLI looks for them
    !mkdir -p ~/.kaggle
    !mv kaggle.json ~/.kaggle/
    !chmod 600 ~/.kaggle/kaggle.json

    print("Kaggle API setup complete!")
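
The `!mkdir` / `!mv` / `!chmod` shell commands above can also be written in plain Python, which makes the permission step explicit and also works outside IPython. A minimal sketch; `install_kaggle_credentials` is a hypothetical helper, and it assumes the Kaggle CLI reads `~/.kaggle/kaggle.json`, the same location the shell commands use.

```python
import os
import shutil
from pathlib import Path


def install_kaggle_credentials(src="kaggle.json", kaggle_dir=None):
    """Hypothetical helper: move kaggle.json into kaggle_dir and restrict its permissions."""
    kaggle_dir = Path(kaggle_dir) if kaggle_dir is not None else Path.home() / ".kaggle"
    kaggle_dir.mkdir(parents=True, exist_ok=True)  # same as `mkdir -p ~/.kaggle`
    dest = kaggle_dir / "kaggle.json"
    shutil.move(str(src), str(dest))  # same as `mv kaggle.json ~/.kaggle/`
    os.chmod(dest, 0o600)  # owner read/write only, like `chmod 600`
    return dest
```

In Colab this would be called right after `files.upload()`, e.g. `install_kaggle_credentials("kaggle.json")`.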

In [None]:
# Download the fracture dataset to Google Drive
if IN_COLAB:
    # Check if dataset already exists in Drive
    dataset_zip_path = f'{dataset_dir}/fracture-multi-region-x-ray-data.zip'

    if os.path.exists(dataset_zip_path):
        print("Dataset already exists in Google Drive. Skipping download.")
    else:
        print("Downloading dataset to Google Drive...")
        # Change to dataset directory
        os.chdir(dataset_dir)

        !kaggle datasets download -d bmadushanirodrigo/fracture-multi-region-x-ray-data
        print("Dataset downloaded to Google Drive!")

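    # Extract if not already extracted. Check for the folder that Section 3 expects
    # (Bone_Fracture_Binary_Classification); otherwise the archive is re-extracted on every run.
    extracted_check = os.path.join(dataset_dir, 'Bone_Fracture_Binary_Classification')
    if not os.path.exists(extracted_check):
        print("Extracting dataset...")
        os.chdir(dataset_dir)
        # -o overwrites existing files instead of stopping at an interactive prompt
        !unzip -q -o fracture-multi-region-x-ray-data.zip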
        print("Dataset extracted!")
    else:
        print("Dataset already extracted in Google Drive.")

    # Change back to content directory
    os.chdir('/content')

    # List the contents to understand the structure
    print("\nDataset structure in Google Drive:")
    # Use `filenames` so the loop does not shadow the google.colab `files` module imported above
    for root, dirs, filenames in os.walk(dataset_dir):
        level = root.replace(dataset_dir, '').count(os.sep)
        indent = ' ' * 2 * level
        print(f"{indent}{os.path.basename(root)}/")
        subindent = ' ' * 2 * (level + 1)
        for filename in filenames[:5]:  # Show only the first 5 files
            print(f"{subindent}{filename}")
        if len(filenames) > 5:
            print(f"{subindent}... and {len(filenames)-5} more files")
        if level >= 3:  # Limit depth: stop os.walk from descending below this directory
            dirs[:] = []
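
The extract-once logic can also be expressed with Python's standard `zipfile` module, which avoids `unzip`'s overwrite prompt and makes the "already extracted?" check explicit. A sketch; `extract_once` and its `marker` argument are hypothetical, and the marker should be a top-level folder you expect the archive to create (for this notebook, presumably `Bone_Fracture_Binary_Classification`).

```python
import zipfile
from pathlib import Path


def extract_once(zip_path, dest_dir, marker):
    """Extract zip_path into dest_dir unless dest_dir/marker already exists.

    Returns True if extraction happened, False if it was skipped.
    """
    dest_dir = Path(dest_dir)
    if (dest_dir / marker).exists():
        return False
    with zipfile.ZipFile(zip_path) as zf:
        # extractall creates missing directories and overwrites existing files
        zf.extractall(dest_dir)
    return True
```

Example: `extract_once(f'{dataset_dir}/fracture-multi-region-x-ray-data.zip', dataset_dir, 'Bone_Fracture_Binary_Classification')`.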

## 3. Data Exploration and Preprocessing

In [None]:
# Define dataset paths (now using Google Drive)
if IN_COLAB:
    # Use Google Drive paths - check for all three folders
    base_dataset_path = f'{dataset_dir}/Bone_Fracture_Binary_Classification/Bone_Fracture_Binary_Classification'
    train_dir = f'{base_dataset_path}/train'
    val_dir = f'{base_dataset_path}/val'  # Use existing validation folder
    test_dir = f'{base_dataset_path}/test'

    # If the structure is different, explore and update
    if not os.path.exists(train_dir):
        # Find the actual dataset directory in Google Drive
        print("Exploring dataset structure in Google Drive...")
        for root, dirs, files in os.walk(dataset_dir):
            if 'train' in dirs or 'Fractured' in dirs or 'Non-Fractured' in dirs:
                print(f"Found potential dataset directory: {root}")
                print(f"Subdirectories: {dirs}")
                # Update paths based on actual structure
                if 'train' in dirs:
                    base_dataset_path = root
                    train_dir = os.path.join(root, 'train')
                if 'val' in dirs:
                    val_dir = os.path.join(root, 'val')
                if 'test' in dirs:
                    test_dir = os.path.join(root, 'test')
                break
else:
    train_dir = './dataset/train'
    val_dir = './dataset/val'
    test_dir = './dataset/test'

print(f"Train directory: {train_dir}")
print(f"Validation directory: {val_dir}")
print(f"Test directory: {test_dir}")
print(f"Models will be saved to: {models_dir if IN_COLAB else './models'}")

# Check which folders actually exist
existing_folders = []
if os.path.exists(train_dir):
    existing_folders.append('train')
if os.path.exists(val_dir):
    existing_folders.append('val')
if os.path.exists(test_dir):
    existing_folders.append('test')

print(f"Available folders: {existing_folders}")

In [None]:
# Create a function to explore dataset structure and create file lists
def explore_dataset_structure(base_path):
    """Explore and understand the dataset structure"""
    if not os.path.exists(base_path):
        print(f"Path {base_path} does not exist. Exploring available paths...")

        # Look for common fracture dataset patterns
        search_path = dataset_dir if IN_COLAB else '.'
        for root, dirs, files in os.walk(search_path):
            if any(keyword in root.lower() for keyword in ['fracture', 'train', 'test']):
                print(f"Found: {root}")
                if dirs:
                    print(f"  Subdirs: {dirs}")
                if files[:3]:  # Show first 3 files
                    print(f"  Files: {files[:3]}...")
        return None, None

    # If path exists, explore structure
    # Sort so the fallback assignment below really is alphabetical (os.listdir order is arbitrary)
    class_dirs = sorted(d for d in os.listdir(base_path) if os.path.isdir(os.path.join(base_path, d)))
    print(f"Classes found in {base_path}: {class_dirs}")

    file_list = []
    labels = []

    # Assign labels from folder names rather than alphabetical order:
    # Label 0 = Non-Fractured, Label 1 = Fractured (matching the Flask app's class order)
    for class_name in class_dirs:
        class_path = os.path.join(base_path, class_name)
        images = [f for f in os.listdir(class_path) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]

        print(f"Class '{class_name}': {len(images)} images")

        class_name_lower = class_name.lower()
        
        # Determine correct label based on folder name
        # Check for NON-fractured keywords FIRST (to handle "not fractured" correctly)
        if any(keyword in class_name_lower for keyword in ['not fractured', 'non-fractured', 'normal', 'negative', 'no', '0']):
            # This is the non-fractured class  
            label = 0
            print(f"  -> Assigned as NON-FRACTURED (label=0)")
        elif any(keyword in class_name_lower for keyword in ['fractured', 'fracture', 'positive', 'yes', '1']):
            # This is the fractured class
            label = 1
            print(f"  -> Assigned as FRACTURED (label=1)")
        else:
            # Fallback: ask user to verify or make educated guess
            print(f"  -> WARNING: Could not determine class type from name '{class_name}'")
            print(f"     Please verify: Is '{class_name}' fractured (1) or non-fractured (0)?")
            # For now, assume alphabetical assignment but warn user
            class_idx = class_dirs.index(class_name)
            label = class_idx
            print(f"  -> Using alphabetical assignment: label={label}")

        for img in images:
            file_list.append(os.path.join(class_path, img))
            labels.append(label)

    # Print summary
    unique_labels = set(labels)
    print(f"\nLabel assignment summary:")
    for lbl in sorted(unique_labels):
        count = labels.count(lbl)
        label_name = "NON-FRACTURED" if lbl == 0 else "FRACTURED"
        print(f"  Label {lbl} ({label_name}): {count} images")

    return file_list, labels

# Explore the dataset from Google Drive
search_location = train_dir if os.path.exists(train_dir) else (dataset_dir if IN_COLAB else '.')
train_files, train_labels = explore_dataset_structure(search_location)
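
The substring checks in `explore_dataset_structure` are order-sensitive: `'fractured'` is a substring of `'not fractured'`, which is why the non-fractured keywords must be tested first. Short keywords such as `'no'`, `'0'`, or `'1'` can also match unrelated folder names (`'unknown'` contains `'no'`). A stricter alternative is to normalize the folder name and look it up in an explicit table that fails loudly on anything unexpected. A sketch; `LABEL_BY_FOLDER` and `folder_to_label` are hypothetical names, and the table entries are assumptions to adapt to the folders actually present.

```python
import re

# Hypothetical mapping: normalized folder name -> label (0 = non-fractured, 1 = fractured)
LABEL_BY_FOLDER = {
    "not fractured": 0,
    "non fractured": 0,
    "normal": 0,
    "fractured": 1,
    "fracture": 1,
}


def folder_to_label(folder_name):
    """Return the label for a class folder, or raise if the name is not recognized."""
    # Lowercase and collapse '-', '_' and runs of whitespace into single spaces
    key = re.sub(r"[\s_\-]+", " ", folder_name.strip().lower())
    if key not in LABEL_BY_FOLDER:
        raise ValueError(f"Unrecognized class folder {folder_name!r}; add it to LABEL_BY_FOLDER")
    return LABEL_BY_FOLDER[key]
```

Raising on an unknown name trades convenience for safety: a silently wrong label mapping trains a model that predicts the opposite class, which is exactly the problem described in the label verification note.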

In [None]:
# Load data from all available folders (train, val, test)
def load_dataset_from_folders():
    """Load dataset from existing train/val/test folders"""
    all_files = []
    all_labels = []
    dataset_info = {}
    
    folders_to_check = [
        ('train', train_dir),
        ('val', val_dir), 
        ('test', test_dir)
    ]
    
    for folder_name, folder_path in folders_to_check:
        if os.path.exists(folder_path):
            print(f"\n--- Loading {folder_name.upper()} data ---")
            files, labels = explore_dataset_structure(folder_path)
            
            if files:
                dataset_info[folder_name] = {
                    'files': files,
                    'labels': labels,
                    'count': len(files)
                }
                print(f"Loaded {len(files)} images from {folder_name} folder")
                
                # Add to combined dataset for overall statistics
                all_files.extend(files)
                all_labels.extend(labels)
            else:
                print(f"No files found in {folder_name} folder")
        else:
            print(f"{folder_name.upper()} folder not found: {folder_path}")
    
    return dataset_info, all_files, all_labels

# Load from existing folder structure
dataset_info, combined_files, combined_labels = load_dataset_from_folders()

if combined_files:
    print(f"\n=== DATASET SUMMARY ===")
    print(f"Total images across all folders: {len(combined_files)}")
    total_label_counts = pd.Series(combined_labels).value_counts().sort_index()
    print(f"Overall label distribution:")
    for label, count in total_label_counts.items():
        label_name = "NON-FRACTURED" if label == 0 else "FRACTURED"  
        print(f"  {label_name} (label={label}): {count} images")
    
    # Print breakdown by folder
    print(f"\nBreakdown by folder:")
    for folder_name, info in dataset_info.items():
        folder_label_counts = pd.Series(info['labels']).value_counts().sort_index()
        print(f"  {folder_name.upper()}: {info['count']} images")
        for label, count in folder_label_counts.items():
            label_name = "NON-FRACTURED" if label == 0 else "FRACTURED"
            print(f"    {label_name}: {count}")
else:
    print("No images found in any folder. Please check the dataset structure.")
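
If the summary shows a noticeable imbalance between the two classes, one common option (not used by this notebook's training code) is to weight the loss by inverse class frequency. A sketch of the "balanced" heuristic, `n_samples / (n_classes * count_c)`, the same formula scikit-learn uses for `class_weight='balanced'`; `balanced_class_weights` is a hypothetical helper.

```python
from collections import Counter


def balanced_class_weights(labels, num_classes=2):
    """Inverse-frequency weights: n_samples / (num_classes * count_per_class)."""
    counts = Counter(labels)
    n_samples = len(labels)
    weights = []
    for c in range(num_classes):
        if counts[c] == 0:
            raise ValueError(f"Class {c} has no samples")
        weights.append(n_samples / (num_classes * counts[c]))
    return weights
```

Compute the weights on the training labels only (e.g. `balanced_class_weights(dataset_info['train']['labels'])`); the list could then be passed as `nn.CrossEntropyLoss(weight=torch.tensor(weights, dtype=torch.float, device=device))`.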

In [None]:
# Visualize sample images with CORRECTED labels
if train_files:
    # Sample some images to display
    fig, axes = plt.subplots(2, 4, figsize=(15, 8))
    fig.suptitle('Dataset Label Verification\n(Top row: Non-Fractured, Bottom row: Fractured)', fontsize=14)

    # Show 4 fractured and 4 non-fractured images
    fractured_indices = [i for i, label in enumerate(train_labels) if label == 1]
    non_fractured_indices = [i for i, label in enumerate(train_labels) if label == 0]

    print(f"Found {len(non_fractured_indices)} non-fractured images (label=0)")
    print(f"Found {len(fractured_indices)} fractured images (label=1)")

    # Display NON-fractured images (label=0) in TOP row
    for i in range(4):
        if i < len(non_fractured_indices):
            img_path = train_files[non_fractured_indices[i]]
            img = Image.open(img_path)
            axes[0, i].imshow(img, cmap='gray')
            axes[0, i].set_title(f'NON-FRACTURED\n(Label=0)', color='green')
            axes[0, i].axis('off')
        else:
            axes[0, i].set_title('No Image')
            axes[0, i].axis('off')

    # Display FRACTURED images (label=1) in BOTTOM row  
    for i in range(4):
        if i < len(fractured_indices):
            img_path = train_files[fractured_indices[i]]
            img = Image.open(img_path)
            axes[1, i].imshow(img, cmap='gray')
            axes[1, i].set_title(f'FRACTURED\n(Label=1)', color='red')
            axes[1, i].axis('off')
        else:
            axes[1, i].set_title('No Image')
            axes[1, i].axis('off')

    plt.tight_layout()
    plt.show()

    # Check image sizes
    sample_sizes = []
    for i in range(min(10, len(train_files))):
        img = Image.open(train_files[i])
        sample_sizes.append(img.size)

    print(f"\nSample image sizes: {sample_sizes[:5]}")
    print(f"Unique sizes: {list(set(sample_sizes))}")
    
    # VERIFICATION: Check if labels match folder names
    print(f"\nLabel verification:")
    print(f"Total images: {len(train_files)}")
    label_counts = pd.Series(train_labels).value_counts().sort_index()
    print(f"Label distribution:")
    for label, count in label_counts.items():
        label_name = "NON-FRACTURED" if label == 0 else "FRACTURED"  
        print(f"  {label_name} (label={label}): {count} images")
    
    # Show sample file paths to verify folder assignment
    print(f"\nSample file paths with labels:")
    for i in range(min(5, len(train_files))):
        label_name = "NON-FRACTURED" if train_labels[i] == 0 else "FRACTURED"
        print(f"  {os.path.basename(os.path.dirname(train_files[i]))} -> {label_name} (label={train_labels[i]})")
else:
    print("No images found for visualization")

### Important: Label Assignment Verification

If you've already trained models with potentially swapped labels, you have two options:

1. **Retrain models** with corrected labels (recommended for accuracy)
2. **Adjust prediction interpretation** in your Flask app (quick fix)

The corrected labeling is:
- **Label 0**: Non-Fractured  
- **Label 1**: Fractured

This matches the class order the Flask app expects: `fracture_classes = ['NON_FRACTURED', 'FRACTURED']`
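
For option 2, the quick fix amounts to reversing the class order of the model's output before looking up the class name. A minimal sketch, assuming a model trained with swapped labels (index 0 = fractured) and the Flask app's `fracture_classes` list; `interpret_prediction` and the `labels_swapped` flag are hypothetical.

```python
fracture_classes = ['NON_FRACTURED', 'FRACTURED']


def interpret_prediction(probabilities, labels_swapped=False):
    """Map a 2-element probability vector to (class_name, confidence).

    If the model was trained with swapped labels (index 0 = fractured),
    reverse the vector so that index 0 means NON_FRACTURED again.
    """
    probs = list(probabilities)
    if labels_swapped:
        probs = probs[::-1]
    best_idx = max(range(len(probs)), key=lambda i: probs[i])
    return fracture_classes[best_idx], probs[best_idx]
```

This only patches the interpretation; retraining with corrected labels (option 1) keeps the saved weights and the app consistent without a special flag.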

## 4. Custom Dataset Class and Data Loaders

In [None]:
class FractureDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        try:
            # cv2.imread returns a 3-channel BGR array, or None if it cannot decode the file
            image = cv2.imread(image_path)

            if image is not None:
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            else:
                # Fall back to PIL. convert('RGB') already yields RGB channel order and also
                # handles grayscale, palette, and RGBA files, so no BGR->RGB swap is applied here
                image = np.array(Image.open(image_path).convert('RGB'))

            # Apply transformations
            if self.transform:
                if isinstance(self.transform, A.Compose):
                    augmented = self.transform(image=image)
                    image = augmented['image']
                else:
                    # Convert to PIL for torchvision transforms
                    image = Image.fromarray(image)
                    image = self.transform(image)

            label = torch.tensor(self.labels[idx], dtype=torch.long)

            return image, label

        except Exception as e:
            print(f"Error loading image {image_path}: {e}")
            # Return a blank 224x224 image so the batch can still be collated.
            # Caution: it is labeled 0 (NON-FRACTURED), so frequent load errors will bias training.
            dummy_image = torch.zeros(3, 224, 224)
            dummy_label = torch.tensor(0, dtype=torch.long)
            return dummy_image, dummy_label

In [None]:
# Define data augmentation and preprocessing
IMG_SIZE = 224
BATCH_SIZE = 32

# Training transforms with augmentation
train_transform = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE),
    A.HorizontalFlip(p=0.5),
    A.Rotate(limit=10, p=0.3),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.3),
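    # var_limit is the older Albumentations argument; newer releases replaced it with std_range,
    # so check the installed version if this line warns or errors
    A.GaussNoise(var_limit=(10.0, 50.0), p=0.2),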
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2()
])

# Validation transforms (no augmentation)
val_transform = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2()
])

print("Data transforms defined successfully!")
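
`A.Normalize` with its default `max_pixel_value=255.0` computes, per channel, `(pixel / 255 - mean) / std` using the ImageNet statistics above, so 8-bit inputs end up roughly in the range [-2.1, 2.6]. A small worked example of that arithmetic for a single pixel; `normalize_pixel` is a hypothetical helper for illustration only.

```python
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD, max_pixel_value=255.0):
    """Normalize one RGB pixel the way A.Normalize does: (x / max - mean) / std per channel."""
    return tuple((x / max_pixel_value - m) / s for x, m, s in zip(rgb, mean, std))


# A grayscale X-ray pixel is replicated into all three channels, but each channel
# is normalized with its own mean/std, so the three outputs differ slightly.
print(normalize_pixel((128, 128, 128)))
```

The same statistics must be used at inference time (for example in the Flask app); otherwise the model sees inputs on a different scale than during training.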

In [None]:
# Create datasets and data loaders using existing train/val/test splits
if combined_files and 'train' in dataset_info:
    print("\n=== CREATING DATA LOADERS ===")
    
    # Use existing train folder
    if 'train' in dataset_info:
        train_files = dataset_info['train']['files']
        train_labels = dataset_info['train']['labels']
        print(f"Training data: {len(train_files)} images")
    
    # Use existing val folder if available, otherwise create validation from train
    if 'val' in dataset_info:
        val_files = dataset_info['val']['files']
        val_labels = dataset_info['val']['labels']
        print(f"Validation data: {len(val_files)} images (from existing val folder)")
    else:
        # Fallback: split train data if no val folder exists
        print("No validation folder found. Splitting training data...")
        train_files, val_files, train_labels, val_labels = train_test_split(
            train_files, train_labels, test_size=0.2, random_state=42, stratify=train_labels
        )
        print(f"Training data (after split): {len(train_files)} images")
        print(f"Validation data (from train split): {len(val_files)} images")
    
    # Test data (optional - for final evaluation)
    if 'test' in dataset_info:
        test_files = dataset_info['test']['files']
        test_labels = dataset_info['test']['labels']
        print(f"Test data: {len(test_files)} images (for final evaluation)")
    
    # Print label distributions for each split
    print(f"\nLabel distributions:")
    train_dist = pd.Series(train_labels).value_counts().sort_index()
    val_dist = pd.Series(val_labels).value_counts().sort_index()
    
    print(f"Training:")
    for label, count in train_dist.items():
        label_name = "NON-FRACTURED" if label == 0 else "FRACTURED"
        print(f"  {label_name}: {count} ({count/len(train_labels)*100:.1f}%)")
    
    print(f"Validation:")
    for label, count in val_dist.items():
        label_name = "NON-FRACTURED" if label == 0 else "FRACTURED"
        print(f"  {label_name}: {count} ({count/len(val_labels)*100:.1f}%)")
    
    if 'test' in dataset_info:
        test_dist = pd.Series(test_labels).value_counts().sort_index()
        print(f"Test:")
        for label, count in test_dist.items():
            label_name = "NON-FRACTURED" if label == 0 else "FRACTURED"
            print(f"  {label_name}: {count} ({count/len(test_labels)*100:.1f}%)")

    # Create datasets
    train_dataset = FractureDataset(train_files, train_labels, transform=train_transform)
    val_dataset = FractureDataset(val_files, val_labels, transform=val_transform)
    
    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    print(f"\nData loaders created:")
    print(f"  Train batches: {len(train_loader)} (batch size: {BATCH_SIZE})")
    print(f"  Validation batches: {len(val_loader)} (batch size: {BATCH_SIZE})")
    
    # Optional: Create test loader if test data exists
    if 'test' in dataset_info:
        test_dataset = FractureDataset(test_files, test_labels, transform=val_transform)
        test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
        print(f"  Test batches: {len(test_loader)} (batch size: {BATCH_SIZE})")
    
    # Test loading a batch
    try:
        sample_batch = next(iter(train_loader))
        print(f"\nSample batch verification:")
        print(f"  Batch shape: {sample_batch[0].shape}")
        print(f"  Labels shape: {sample_batch[1].shape}")
        print(f"  Sample labels: {sample_batch[1][:8].tolist()}")  # Show first 8 labels
        print("✅ Data loaders created successfully!")
    except Exception as e:
        print(f"❌ Error in data loading: {e}")
else:
    print("❌ No training files found. Please check the dataset paths.")
    train_loader = None
    val_loader = None
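
Because the train, val, and test splits come from separate folders, it is worth confirming that no image appears in more than one split, since duplicates would inflate validation and test accuracy. A sketch that hashes file contents with `hashlib`; `find_cross_split_duplicates` is a hypothetical helper, and exact-byte hashing will not catch resized or re-encoded copies.

```python
import hashlib
from collections import defaultdict


def file_digest(path, chunk_size=1 << 20):
    """SHA-256 of a file's bytes, read in chunks to bound memory use."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def find_cross_split_duplicates(splits):
    """splits: dict like {'train': [paths], 'val': [paths]}.

    Returns {digest: set_of_split_names} for content found in more than one split.
    """
    seen = defaultdict(set)
    for split_name, paths in splits.items():
        for path in paths:
            seen[file_digest(path)].add(split_name)
    return {d: names for d, names in seen.items() if len(names) > 1}
```

Usage: `find_cross_split_duplicates({name: info['files'] for name, info in dataset_info.items()})`. Reading every file from Google Drive can take a while, so this is a one-off check rather than something to run on every session.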

## 5. Model Definitions

In [None]:
class FractureResNet50(nn.Module):
    def __init__(self, num_classes=2):
        super(FractureResNet50, self).__init__()
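        # weights= is the torchvision >= 0.13 replacement for the deprecated pretrained=True
        self.backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

        # Freeze all but the last 20 parameter tensors. The count is taken before the head
        # is replaced, so the old fc weight/bias are among those 20; the new head created
        # below is trainable by default.
        for param in list(self.backbone.parameters())[:-20]:
            param.requires_grad = False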

        # Replace the final layer
        self.backbone.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(self.backbone.fc.in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        return self.backbone(x)

class FractureDenseNet121(nn.Module):
    def __init__(self, num_classes=2):
        super(FractureDenseNet121, self).__init__()
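        # weights= is the torchvision >= 0.13 replacement for the deprecated pretrained=True
        self.backbone = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)

        # Freeze all but the last 20 parameter tensors. The count is taken before the head
        # is replaced, so the old classifier weight/bias are among those 20; the new head
        # created below is trainable by default.
        for param in list(self.backbone.parameters())[:-20]:
            param.requires_grad = False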

        # Replace the final layer
        self.backbone.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(self.backbone.classifier.in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        return self.backbone(x)

class FractureEfficientNetB0(nn.Module):
    def __init__(self, num_classes=2):
        super(FractureEfficientNetB0, self).__init__()
        self.backbone = EfficientNet.from_pretrained('efficientnet-b0')

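        # Freeze all but the last 20 parameter tensors. The count is taken before the head
        # is replaced, so the old _fc weight/bias are among those 20; the new head created
        # below is trainable by default.
        for param in list(self.backbone.parameters())[:-20]:
            param.requires_grad = False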

        # Replace the final layer
        self.backbone._fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(self.backbone._fc.in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        return self.backbone(x)

print("Model classes defined successfully!")

## 6. Training Functions

In [None]:
def train_model(model, train_loader, val_loader, num_epochs=10, learning_rate=0.001):
    """
    Train a model and return training history
    """
    model = model.to(device)

    # Define loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5)

    # Training history
    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': []
    }

    best_val_acc = 0.0
    best_model_state = None

    for epoch in range(num_epochs):
        print(f'\nEpoch {epoch+1}/{num_epochs}')
        print('-' * 50)

        # Training phase
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        for batch_idx, (images, labels) in enumerate(train_loader):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = outputs.max(1)
            train_total += labels.size(0)
            train_correct += predicted.eq(labels).sum().item()

            if batch_idx % 10 == 0:
                print(f'Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}')

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)

                val_loss += loss.item()
                _, predicted = outputs.max(1)
                val_total += labels.size(0)
                val_correct += predicted.eq(labels).sum().item()

        # Calculate metrics
        train_loss /= len(train_loader)
        train_acc = 100. * train_correct / train_total
        val_loss /= len(val_loader)
        val_acc = 100. * val_correct / val_total

        # Update history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
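            # state_dict() returns references to the live tensors and dict.copy() is shallow,
            # so clone each tensor; otherwise later optimizer steps overwrite the "best" snapshot
            best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}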

        # Update learning rate
        scheduler.step(val_loss)

        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
        print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
        print(f'Best Val Acc: {best_val_acc:.2f}%')

    # Restore the weights from the epoch with the best validation accuracy
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    return model, history

def evaluate_model(model, val_loader):
    """
    Evaluate model and return predictions for detailed analysis
    """
    model.eval()
    all_predictions = []
    all_labels = []
    all_probabilities = []

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            probabilities = F.softmax(outputs, dim=1)
            _, predicted = outputs.max(1)

            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probabilities.extend(probabilities.cpu().numpy())

    return all_predictions, all_labels, all_probabilities

print("Training functions defined successfully!")
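
`train_model` calls `ReduceLROnPlateau(mode='min', patience=3, factor=0.5)` once per epoch with the validation loss. The pure-Python model below is a simplified approximation of PyTorch's default behavior (relative improvement threshold of `1e-4`, no cooldown, no `min_lr`). It shows that the learning rate is halved only after four consecutive epochs without improvement, because the plateau must last *more than* `patience` epochs. `simulate_plateau` is a hypothetical helper, not part of PyTorch.

```python
def simulate_plateau(val_losses, lr=1e-3, patience=3, factor=0.5, threshold=1e-4):
    """Return the learning rate in effect after each epoch's scheduler step."""
    best = float("inf")
    bad_epochs = 0
    lrs = []
    for loss in val_losses:
        # 'min' mode with a relative threshold: must beat best by more than 0.01%
        if loss < best * (1 - threshold):
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:
            lr *= factor
            bad_epochs = 0
        lrs.append(lr)
    return lrs


print(simulate_plateau([1.0, 0.9, 0.9, 0.9, 0.9, 0.9]))
```

With only 15 epochs, a model whose validation loss stalls early will see at most a few halvings; this is one reason the best-checkpoint logic, rather than the final epoch, determines the saved weights.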

## 7. Model Training

In [None]:
# Training parameters
NUM_EPOCHS = 15
LEARNING_RATE = 0.001

# Initialize models
models_dict = {
    'ResNet50': FractureResNet50(),
    'DenseNet121': FractureDenseNet121(),
    'EfficientNetB0': FractureEfficientNetB0()
}

trained_models = {}
training_histories = {}

print("Starting model training...")
print(f"Training on device: {device}")

In [None]:
# Train ResNet50
if train_loader is not None:
    print("\n" + "="*60)
    print("TRAINING RESNET50")
    print("="*60)

    resnet_model, resnet_history = train_model(
        models_dict['ResNet50'],
        train_loader,
        val_loader,
        num_epochs=NUM_EPOCHS,
        learning_rate=LEARNING_RATE
    )

    trained_models['ResNet50'] = resnet_model
    training_histories['ResNet50'] = resnet_history

    # Save model to Google Drive
    if IN_COLAB:
        model_path = f'{models_dir}/fracture_resnet50.pth'
    else:
        os.makedirs('models', exist_ok=True)
        model_path = 'models/fracture_resnet50.pth'

    torch.save(resnet_model.state_dict(), model_path)
    print(f"ResNet50 model saved to: {model_path}")
else:
    print("Skipping training - no data available")

In [None]:
# Train DenseNet121
if train_loader is not None:
    print("\n" + "="*60)
    print("TRAINING DENSENET121")
    print("="*60)

    densenet_model, densenet_history = train_model(
        models_dict['DenseNet121'],
        train_loader,
        val_loader,
        num_epochs=NUM_EPOCHS,
        learning_rate=LEARNING_RATE
    )

    trained_models['DenseNet121'] = densenet_model
    training_histories['DenseNet121'] = densenet_history

    # Save model to Google Drive
    if IN_COLAB:
        model_path = f'{models_dir}/fracture_densenet121.pth'
    else:
        os.makedirs('models', exist_ok=True)
        model_path = 'models/fracture_densenet121.pth'

    torch.save(densenet_model.state_dict(), model_path)
    print(f"DenseNet121 model saved to: {model_path}")
else:
    print("Skipping training - no data available")

In [None]:
# Train EfficientNetB0
if train_loader is not None:
    print("\n" + "="*60)
    print("TRAINING EFFICIENTNETB0")
    print("="*60)

    efficientnet_model, efficientnet_history = train_model(
        models_dict['EfficientNetB0'],
        train_loader,
        val_loader,
        num_epochs=NUM_EPOCHS,
        learning_rate=LEARNING_RATE
    )

    trained_models['EfficientNetB0'] = efficientnet_model
    training_histories['EfficientNetB0'] = efficientnet_history

    # Save model to Google Drive
    if IN_COLAB:
        model_path = f'{models_dir}/fracture_efficientnetb0.pth'
    else:
        os.makedirs('models', exist_ok=True)
        model_path = 'models/fracture_efficientnetb0.pth'

    torch.save(efficientnet_model.state_dict(), model_path)
    print(f"EfficientNetB0 model saved to: {model_path}")
else:
    print("Skipping training - no data available")

## 8. Results Visualization and Analysis

In [None]:
# Plot training histories
if training_histories:
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # Training Loss
    for model_name, history in training_histories.items():
        axes[0, 0].plot(history['train_loss'], label=f'{model_name} Train')
    axes[0, 0].set_title('Training Loss')
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].legend()
    axes[0, 0].grid(True)

    # Validation Loss
    for model_name, history in training_histories.items():
        axes[0, 1].plot(history['val_loss'], label=f'{model_name} Val')
    axes[0, 1].set_title('Validation Loss')
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('Loss')
    axes[0, 1].legend()
    axes[0, 1].grid(True)

    # Training Accuracy
    for model_name, history in training_histories.items():
        axes[1, 0].plot(history['train_acc'], label=f'{model_name} Train')
    axes[1, 0].set_title('Training Accuracy')
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('Accuracy (%)')
    axes[1, 0].legend()
    axes[1, 0].grid(True)

    # Validation Accuracy
    for model_name, history in training_histories.items():
        axes[1, 1].plot(history['val_acc'], label=f'{model_name} Val')
    axes[1, 1].set_title('Validation Accuracy')
    axes[1, 1].set_xlabel('Epoch')
    axes[1, 1].set_ylabel('Accuracy (%)')
    axes[1, 1].legend()
    axes[1, 1].grid(True)

    plt.tight_layout()
    plt.show()
else:
    print("No training histories to plot")
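
The curves are easier to compare when the key numbers sit next to them. The helper below is a minimal sketch that reports, for one model, the epoch with the highest validation accuracy and how many epochs ran after it. It assumes each history dict holds the per-epoch `'val_acc'` and `'val_loss'` lists plotted above, with accuracy in percent. `best_epoch_summary` is a hypothetical helper, not part of the notebook.

```python
from typing import Dict, List


def best_epoch_summary(history: Dict[str, List[float]]) -> Dict[str, float]:
    """Return the 1-based epoch with the highest validation accuracy.

    Assumes `history` has per-epoch lists under 'val_acc' and 'val_loss',
    as recorded by the notebook's training loop (hypothetical helper).
    """
    val_acc = history['val_acc']
    if not val_acc:
        raise ValueError("history['val_acc'] is empty")
    # max() returns the first index that reaches the highest accuracy
    best_idx = max(range(len(val_acc)), key=lambda i: val_acc[i])
    return {
        'best_epoch': best_idx + 1,
        'best_val_acc': val_acc[best_idx],
        'val_loss_at_best': history['val_loss'][best_idx],
        # Epochs trained after the best one; a large value can hint at overfitting
        'epochs_after_best': len(val_acc) - 1 - best_idx,
    }


# Usage with the notebook's histories:
# for name, hist in training_histories.items():
#     print(name, best_epoch_summary(hist))
```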

In [None]:
# Evaluate all models and compare performance
if trained_models and val_loader:
    model_results = {}

    for model_name, model in trained_models.items():
        print(f"\nEvaluating {model_name}...")
        predictions, true_labels, probabilities = evaluate_model(model, val_loader)

        # Calculate metrics
        accuracy = accuracy_score(true_labels, predictions)

        model_results[model_name] = {
            'accuracy': accuracy,
            'predictions': predictions,
            'true_labels': true_labels,
            'probabilities': probabilities
        }

        print(f"{model_name} - Accuracy: {accuracy:.4f}")
        print(f"{model_name} - Classification Report:")
        print(classification_report(true_labels, predictions,
                                  target_names=['Non-Fractured', 'Fractured']))

    # Summary comparison
    print("\n" + "="*50)
    print("MODEL COMPARISON SUMMARY")
    print("="*50)

    for model_name, results in model_results.items():
        print(f"{model_name}: {results['accuracy']:.4f}")

    # Find best model
    best_model_name = max(model_results.keys(), key=lambda x: model_results[x]['accuracy'])
    best_accuracy = model_results[best_model_name]['accuracy']

    print(f"\nBest Model: {best_model_name} (Accuracy: {best_accuracy:.4f})")
else:
    print("No trained models to evaluate")

In [None]:
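# Plot confusion matrices (one panel per evaluated model)
if 'model_results' in globals() and model_results:
    n_models = len(model_results)
    # squeeze=False keeps `axes` 2-D even when only one model was evaluated
    fig, axes = plt.subplots(1, n_models, figsize=(6 * n_models, 5), squeeze=False)

    for idx, (model_name, results) in enumerate(model_results.items()):
        ax = axes[0, idx]
        # labels=[0, 1] guarantees a 2x2 matrix even if a class is never predicted
        cm = confusion_matrix(results['true_labels'], results['predictions'], labels=[0, 1])

        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=['Non-Fractured', 'Fractured'],
                    yticklabels=['Non-Fractured', 'Fractured'],
                    ax=ax)

        ax.set_title(f'{model_name}\nAccuracy: {results["accuracy"]:.3f}')
        ax.set_xlabel('Predicted')
        ax.set_ylabel('Actual')

    plt.tight_layout()
    plt.show()
else:
    print("No evaluation results to plot")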
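
Accuracy alone can hide an imbalance between missed fractures and false alarms. The confusion matrices above already contain what is needed to compute sensitivity (recall on the Fractured class) and specificity (recall on the Non-Fractured class). This is a minimal sketch that assumes scikit-learn's layout, where rows are actual classes and columns are predicted classes, with index 0 = Non-Fractured and index 1 = Fractured. `binary_rates` is a hypothetical helper.

```python
from typing import Dict, Sequence


def binary_rates(cm: Sequence[Sequence[int]]) -> Dict[str, float]:
    """Derive per-class rates from a 2x2 confusion matrix.

    Assumes scikit-learn's layout cm[actual][predicted] with class 0 =
    Non-Fractured and class 1 = Fractured, i.e. [[TN, FP], [FN, TP]].
    """
    (tn, fp), (fn, tp) = cm
    sensitivity = tp / (tp + fn) if (tp + fn) else 0.0  # share of fractures detected
    specificity = tn / (tn + fp) if (tn + fp) else 0.0  # share of normal images cleared
    precision = tp / (tp + fp) if (tp + fp) else 0.0    # share of "Fractured" calls that are right
    return {
        'sensitivity': sensitivity,
        'specificity': specificity,
        'precision': precision,
        # Mean of the two per-class recalls; less affected by class imbalance than accuracy
        'balanced_accuracy': (sensitivity + specificity) / 2,
    }


# Usage with the notebook's results:
# for name, res in model_results.items():
#     cm = confusion_matrix(res['true_labels'], res['predictions'], labels=[0, 1])
#     print(name, binary_rates(cm.tolist()))
```

For fracture screening, a missed fracture (false negative) is usually costlier than a false alarm, so sensitivity is worth reading alongside accuracy when choosing between models.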

## 9. Model Export and Download (For Colab)

In [None]:
# Save all trained models and create a summary in Google Drive
if trained_models:
    # Create a summary of results
    results_summary = {
        'training_parameters': {
            'epochs': NUM_EPOCHS,
            'learning_rate': LEARNING_RATE,
            'batch_size': BATCH_SIZE,
            'image_size': IMG_SIZE
        },
        'model_performance': {},
        'storage_info': {
            'dataset_location': dataset_dir if IN_COLAB else './dataset',
            'models_location': models_dir if IN_COLAB else './models'
        }
    }

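    if 'model_results' in globals():
        for model_name, results in model_results.items():
            results_summary['model_performance'][model_name] = {
                # Fraction in [0, 1], as returned by sklearn's accuracy_score
                'accuracy': float(results['accuracy']),
                # Best epoch's validation accuracy, in the training history's units
                # (percent, matching the 'Accuracy (%)' axis in the plots above)
                'best_val_accuracy': float(max(training_histories[model_name]['val_acc']))
            }

    # Save the summary (Google Drive on Colab, ./models otherwise)
    import json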
    if IN_COLAB:
        summary_path = f'{models_dir}/training_summary.json'
    else:
        os.makedirs('models', exist_ok=True)
        summary_path = 'models/training_summary.json'

    with open(summary_path, 'w') as f:
        json.dump(results_summary, f, indent=2)

    print(f"Training summary saved to: {summary_path}")

    # List all saved files
    print("\nSaved files in Google Drive:" if IN_COLAB else "\nSaved files:")

    if IN_COLAB:
        saved_files = [
            f'{models_dir}/fracture_resnet50.pth',
            f'{models_dir}/fracture_densenet121.pth',
            f'{models_dir}/fracture_efficientnetb0.pth',
            f'{models_dir}/training_summary.json'
        ]
    else:
        saved_files = [
            'models/fracture_resnet50.pth',
            'models/fracture_densenet121.pth',
            'models/fracture_efficientnetb0.pth',
            'models/training_summary.json'
        ]

    for file in saved_files:
        if os.path.exists(file):
            size = os.path.getsize(file) / (1024*1024)  # Size in MB
            print(f"- {os.path.basename(file)} ({size:.1f} MB)")
            print(f"  Full path: {file}")

    # Optional: also copy the files to /content so they can be downloaded from Colab
    if IN_COLAB:
        import shutil
        from google.colab import files

        print("\nCreating local copies for download...")
        local_files = []
        for file in saved_files:
            if os.path.exists(file):
                local_name = os.path.basename(file)
                # Copy to the Colab working directory (/content)
                shutil.copy2(file, f'/content/{local_name}')
                local_files.append(local_name)
                print(f"Local copy created: /content/{local_name}")

        # Ask user if they want to download
        download_choice = input("\nDo you want to download model files locally? (y/n): ").lower()
        if download_choice == 'y':
            for local_file in local_files:
                try:
                    files.download(local_file)
                    print(f"Downloaded: {local_file}")
                except Exception as e:
                    print(f"Error downloading {local_file}: {e}")

    if IN_COLAB:
        print("\n✅ All models and the dataset are saved in Google Drive!")
        print(f"📁 Dataset: {dataset_dir}")
        print(f"🤖 Models: {models_dir}")
        print("\nYour data will persist across Colab sessions!")
    else:
        print("\n✅ All models are saved in ./models")

else:
    print("No trained models to save")
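
In a later session, the JSON summary written above can be read back to pick a model without re-running evaluation. This is a minimal sketch that assumes the structure produced by this cell: a top-level `'model_performance'` mapping with an `'accuracy'` fraction per model. `pick_best_model` is a hypothetical helper.

```python
import json
from pathlib import Path
from typing import Tuple


def pick_best_model(summary_path: str, metric: str = 'accuracy') -> Tuple[str, float]:
    """Return (model_name, score) with the highest `metric` in training_summary.json.

    Assumes the structure written by the export cell:
    {"model_performance": {"<model name>": {"accuracy": <float>, ...}, ...}, ...}
    """
    summary = json.loads(Path(summary_path).read_text())
    performance = summary.get('model_performance', {})
    if not performance:
        raise ValueError(f"No model_performance entries in {summary_path}")
    # Pick the model name whose chosen metric is largest
    best_name = max(performance, key=lambda name: performance[name][metric])
    return best_name, performance[best_name][metric]


# Usage on Colab (after mounting Drive):
# name, acc = pick_best_model(f'{models_dir}/training_summary.json')
# print(f"Best model: {name} (accuracy {acc:.4f})")
```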

## 10. Model Loading Template (For Future Use)

In [None]:
# Template code for loading trained models from Google Drive in your Flask app
template_code = """
# Template code to load the trained fracture detection models
# Copy this to your Flask application

import torch
import torch.nn as nn
from torchvision import models
from efficientnet_pytorch import EfficientNet

# Model definitions (copy from above)
class FractureResNet50(nn.Module):
    def __init__(self, num_classes=2):
        super(FractureResNet50, self).__init__()
        # Architecture only (weights=None); the trained weights come from the .pth file
        self.backbone = models.resnet50(weights=None)
        self.backbone.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        return self.backbone(x)

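# FractureDenseNet121 and FractureEfficientNetB0: paste their class definitions from
# the training section here. The loading code below raises a NameError without them,
# and their layer names must match the saved state_dict exactly.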

# Loading the models from Google Drive (when running in Colab)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# For Google Colab - mount drive first:
# from google.colab import drive
# drive.mount('/content/drive')

# Model paths (update these paths based on your Google Drive structure)
MODEL_BASE_PATH = '/content/drive/MyDrive/fracture_detection/models'
# For local use: MODEL_BASE_PATH = './models'

# Load ResNet50 (moved to `device` so the weights sit on the same device as the inputs)
resnet_model = FractureResNet50()
resnet_path = f'{MODEL_BASE_PATH}/fracture_resnet50.pth'
resnet_model.load_state_dict(torch.load(resnet_path, map_location=device))
resnet_model.to(device).eval()

# Load DenseNet121
densenet_model = FractureDenseNet121()
densenet_path = f'{MODEL_BASE_PATH}/fracture_densenet121.pth'
densenet_model.load_state_dict(torch.load(densenet_path, map_location=device))
densenet_model.to(device).eval()

# Load EfficientNetB0
efficientnet_model = FractureEfficientNetB0()
efficientnet_path = f'{MODEL_BASE_PATH}/fracture_efficientnetb0.pth'
efficientnet_model.load_state_dict(torch.load(efficientnet_path, map_location=device))
efficientnet_model.to(device).eval()

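# Prediction function
def predict_fracture(image, model):
    '''
    Predict fracture from a preprocessed image tensor of shape (1, 3, H, W),
    i.e. a batch holding a single image normalized the same way as in training.
    Returns: (predicted_class, confidence_score)
    - predicted_class: 0 = Non-Fractured, 1 = Fractured
    - confidence_score: softmax probability of the predicted class, between 0 and 1
    '''
    model.eval()
    # Put the input on the same device as the model's weights
    image = image.to(next(model.parameters()).device)
    with torch.no_grad():
        outputs = model(image)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        # torch.max over the class dimension returns (values, indices)
        confidence, predicted = torch.max(probabilities, dim=1)

    return predicted.item(), confidence.item()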

# Example usage in Flask app:
# prediction, confidence = predict_fracture(preprocessed_image, resnet_model)
# result = "Fractured" if prediction == 1 else "Non-Fractured"
"""

print("Model loading template for Google Drive:")
print(template_code)

# Also save this template next to the model weights (Google Drive on Colab, ./models otherwise)
if IN_COLAB:
    template_path = f'{models_dir}/flask_integration_template.py'
else:
    os.makedirs('models', exist_ok=True)
    template_path = 'models/flask_integration_template.py'

with open(template_path, 'w') as f:
    f.write(template_code)

print(f"\n📄 Template saved to: {template_path}")
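
`predict_fracture` reports the softmax probability of the winning class as its confidence. The arithmetic is small enough to show without PyTorch, so this pure-Python sketch mirrors the softmax, argmax, and max steps for one image's two output logits. The logit values are made up for illustration. Keep in mind that this score measures how peaked the model's output is; it is not necessarily a calibrated probability of a fracture.

```python
import math
from typing import List, Tuple


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_from_logits(logits: List[float]) -> Tuple[int, float]:
    """Mirror predict_fracture for a single image: (class index, confidence)."""
    probs = softmax(logits)
    predicted = max(range(len(probs)), key=lambda i: probs[i])  # argmax
    return predicted, probs[predicted]


# Hypothetical logits [Non-Fractured, Fractured] for one image
label, conf = predict_from_logits([0.5, 2.5])
print("Fractured" if label == 1 else "Non-Fractured", round(conf, 4))
# prints: Fractured 0.8808
```

Only the gap between the two logits matters: a gap of 2.0 always yields a confidence of about 0.88, whatever the absolute logit values are.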

## Summary

This notebook has successfully:

1. **Setup Environment**: Configured Google Colab with required dependencies and Google Drive integration
2. **Data Management**: Downloaded and stored the Kaggle fracture dataset in Google Drive for persistence
3. **Model Architecture**: Implemented three widely used CNN architectures:
   - ResNet50 with custom classifier
   - DenseNet121 with custom classifier  
   - EfficientNetB0 with custom classifier
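4. **Training**: Trained all models with data augmentation, monitoring loss and accuracy on a validation set
5. **Evaluation**: Compared the models on the validation set using accuracy, per-class classification reports, and confusion matrices
6. **Persistent Storage**: Saved the trained weights to the `models/` folder in Google Drive (`./models` when running locally)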

## 📁 Google Drive Structure Created:

```
/content/drive/MyDrive/fracture_detection/
├── dataset/                          # Kaggle fracture dataset
│   ├── fracture-multi-region-x-ray-data.zip
│   └── [extracted dataset files]
└── models/                           # Trained models and outputs
    ├── fracture_resnet50.pth         # ResNet50 model weights
    ├── fracture_densenet121.pth      # DenseNet121 model weights  
    ├── fracture_efficientnetb0.pth   # EfficientNetB0 model weights
    ├── training_summary.json         # Training metrics and info
    └── flask_integration_template.py # Code template for Flask app
```
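
Before loading models in a new session, a quick check that the expected files are present can save a confusing `FileNotFoundError` later. This is a minimal sketch that assumes the `models/` layout shown above. `missing_artifacts` is a hypothetical helper, and `EXPECTED_FILES` simply repeats the file names from the tree.

```python
from pathlib import Path
from typing import List

# File names taken from the models/ folder in the tree above
EXPECTED_FILES = [
    'fracture_resnet50.pth',
    'fracture_densenet121.pth',
    'fracture_efficientnetb0.pth',
    'training_summary.json',
    'flask_integration_template.py',
]


def missing_artifacts(models_dir: str) -> List[str]:
    """Return the expected files that are absent or empty in models_dir."""
    base = Path(models_dir)
    missing = []
    for name in EXPECTED_FILES:
        path = base / name
        # A zero-byte file often points to an interrupted save or copy
        if not path.is_file() or path.stat().st_size == 0:
            missing.append(name)
    return missing


# Usage after mounting Drive:
# print(missing_artifacts('/content/drive/MyDrive/fracture_detection/models') or "All files present")
```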

## 🚀 Key Benefits:

- **Persistent Storage**: Dataset and models survive Colab session restarts
- **Easy Access**: Models accessible from any Colab notebook via Drive mount
- **Organized Structure**: Clean folder organization for easy management
- **Integration Ready**: Template code provided for Flask app integration

## 📋 Next Steps:

1. **Access Models**: Mount Google Drive in any future Colab session to access trained models
2. **Flask Integration**: Use the provided template to load models in your web application  
3. **Model Deployment**: Download models from Drive for production deployment
4. **Continuous Training**: Easily retrain models using the persistent dataset

The trained models can now be integrated into your medical imaging application for fracture detection. Each model performs binary classification (Fractured vs Non-Fractured) and returns a confidence score, the softmax probability of the predicted class. On Colab, all models and data are stored safely in your Google Drive!