In [None]:
# Install required packages
!pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118
!pip install opencv-python scikit-learn pandas matplotlib numpy imbalanced-learn
!pip install albumentations tqdm timm

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, f1_score
from sklearn.preprocessing import StandardScaler
from imblearn.over_sampling import RandomOverSampler
from google.colab import drive
from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models
from torch.cuda.amp import GradScaler, autocast
import timm
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingWarmRestarts

# Set random seeds for reproducibility.
# Python's own `random` module is seeded as well because third-party code
# (e.g. augmentation libraries) may draw from it rather than from NumPy/PyTorch.
import random

random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)
    torch.backends.cudnn.deterministic = True
    # benchmark mode lets cuDNN auto-tune (and potentially vary) the chosen
    # convolution algorithms between runs, which contradicts the deterministic
    # setting above; it must be off for reproducible results.
    torch.backends.cudnn.benchmark = False

# Mount Google Drive at /content/drive (Colab only); Config.BASE_PATH below
# points inside this mount, so this must run before any dataset is built.
drive.mount('/content/drive')

# Enhanced Configuration
class Config:
    """Central place for every path and hyper-parameter used by the script."""

    # --- data ---
    BASE_PATH = "/content/drive/MyDrive/osteoporosis/Fixed image for training CNN models"
    CLASSES = ['Normal', 'osteophenia', 'osteoporosis']  # sub-folder name == class name
    NUM_CLASSES = len(CLASSES)
    IMAGE_SIZE = (384, 384)  # passed to the resize transform

    # --- optimisation ---
    BATCH_SIZE = 16
    NUM_EPOCHS = 100
    LEARNING_RATE = 0.001
    WEIGHT_DECAY = 1e-5
    DROPOUT_RATE = 0.3
    EARLY_STOPPING_PATIENCE = 15  # epochs without improvement before stopping

    # --- evaluation / loading ---
    TTA_NUM = 10  # augmented views generated per image in TTA mode
    NUM_WORKERS = 2  # DataLoader worker processes


config = Config()

# Advanced Data Augmentation
# Per-channel normalisation constants shared by both pipelines (these are the
# values commonly used for ImageNet-pretrained backbones).
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)

# Random augmentations applied only in the training pipeline, in this order.
_train_augmentations = [
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.Rotate(limit=45, p=0.7),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=30, p=0.7),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
    A.GaussianBlur(blur_limit=(3, 7), p=0.3),
    A.GaussNoise(var_limit=(10.0, 30.0), p=0.3),
    A.CoarseDropout(max_holes=8, max_height=32, max_width=32, p=0.3),
]

# Training: resize -> random augmentations -> normalise -> tensor.
train_transform = A.Compose(
    [A.Resize(*config.IMAGE_SIZE)]
    + _train_augmentations
    + [A.Normalize(mean=_NORM_MEAN, std=_NORM_STD), ToTensorV2()]
)

# Validation: deterministic resize -> normalise -> tensor.
val_transform = A.Compose(
    [
        A.Resize(*config.IMAGE_SIZE),
        A.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
        ToTensorV2(),
    ]
)

# Dataset Class with Error Handling
class BoneDensityDataset(Dataset):
    """Image-folder dataset: one sub-directory of `root_dir` per class.

    Args:
        root_dir: directory that contains one folder per entry of `classes`.
        classes: class (folder) names; the position in this list is the label.
        transform: optional albumentations-style callable invoked as
            `transform(image=img)['image']`.
        tta: when True (and a transform is set) each item is a list of
            `config.TTA_NUM` independently transformed views of the image
            instead of a single tensor.

    Only files ending in .png/.jpg/.jpeg (case-insensitive) are indexed.
    Missing class folders are reported and skipped.
    """

    def __init__(self, root_dir, classes, transform=None, tta=False):
        self.root_dir = root_dir
        self.classes = classes
        self.transform = transform
        self.tta = tta
        self.image_paths = []
        self.labels = []

        for class_idx, class_name in enumerate(self.classes):
            class_dir = os.path.join(root_dir, class_name)
            if not os.path.isdir(class_dir):
                print(f"Warning: Directory not found - {class_dir}")
                continue

            # os.listdir() returns entries in arbitrary order. Sorting makes the
            # index -> file mapping deterministic, which matters because index
            # lists computed on one instance are reused on other instances.
            for img_name in sorted(os.listdir(class_dir)):
                if img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                    self.image_paths.append(os.path.join(class_dir, img_name))
                    self.labels.append(class_idx)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return `(image, label)`; `image` is a list of views in TTA mode.

        If the file cannot be read or transformed, the error is printed and an
        all-zero float32 image of shape (3, *config.IMAGE_SIZE) is returned in
        its place. The sample keeps its real label so a broken file is not
        silently relabeled as class 0.
        """
        img_path = self.image_paths[idx]
        label = torch.tensor(self.labels[idx], dtype=torch.long)  # Ensure label is long type

        try:
            image = cv2.imread(img_path)
            if image is None:
                # cv2.imread signals failure by returning None instead of raising
                raise ValueError(f"Could not read image {img_path}")

            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            if self.transform is None:
                # No transform configured: hand back the raw RGB array
                return image, label
            if self.tta:
                views = [self.transform(image=image)['image'] for _ in range(config.TTA_NUM)]
                return views, label
            return self.transform(image=image)['image'], label
        except Exception as e:
            # Deliberate best-effort: one bad file must not abort a whole epoch.
            print(f"Error loading {img_path}: {str(e)}")
            dummy = torch.zeros((3, *config.IMAGE_SIZE), dtype=torch.float32)
            if self.tta:
                return [dummy] * config.TTA_NUM, label
            return dummy, label

# Initialize datasets
print("Loading datasets...")
full_dataset = BoneDensityDataset(config.BASE_PATH, config.CLASSES, transform=train_transform)

# Print class distribution
# class_counts is the (unique_labels, counts) pair returned by np.unique and is
# reused further down when the class weights are computed.
print("\nClass Distribution:")
class_counts = np.unique(full_dataset.labels, return_counts=True)
present_labels, images_per_label = class_counts
for class_idx, count in zip(present_labels, images_per_label):
    print(f"{config.CLASSES[class_idx]}: {count} images")

# Split first, then handle class imbalance on the training part only.
#
# Oversampling before splitting puts copies of the same image on both sides of
# the split (validation leakage). It also produces split indices that refer to
# positions in the oversampled array, which are not valid indices into a
# dataset that has not been oversampled. So: stratified split on the original
# sample indices, then oversample the training indices only.
from sklearn.model_selection import StratifiedShuffleSplit

print("\nBalancing dataset...")
all_labels = np.asarray(full_dataset.labels)
all_indices = np.arange(len(full_dataset))

# Stratified split (indices refer to positions in full_dataset)
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
train_base_indices, val_indices = next(sss.split(all_indices.reshape(-1, 1), all_labels))

# Oversample minority classes among the training indices only
ros = RandomOverSampler(random_state=42)
resampled_indices, _ = ros.fit_resample(
    train_base_indices.reshape(-1, 1),
    all_labels[train_base_indices]
)
train_indices = resampled_indices.ravel()

# Create balanced training dataset (duplicates are re-augmented on each access
# because full_dataset applies the random training transform)
balanced_dataset = torch.utils.data.Subset(full_dataset, train_indices)
train_dataset = balanced_dataset

# Validation dataset: same files, deterministic transform.
# NOTE: val_indices are only meaningful if this instance lists the files in the
# same order as full_dataset.
val_dataset = torch.utils.data.Subset(
    BoneDensityDataset(
        root_dir=config.BASE_PATH,
        classes=config.CLASSES,
        transform=val_transform
    ),
    val_indices
)

# TTA validation dataset: same validation files, several randomly augmented
# views per image (hence the training transform)
tta_val_dataset = torch.utils.data.Subset(
    BoneDensityDataset(
        root_dir=config.BASE_PATH,
        classes=config.CLASSES,
        transform=train_transform,
        tta=True
    ),
    val_indices
)

# Data loaders
train_loader = DataLoader(
    train_dataset,
    batch_size=config.BATCH_SIZE,
    shuffle=True,
    num_workers=config.NUM_WORKERS,
    pin_memory=True,
    # The classifier head contains BatchNorm1d, which raises in training mode
    # when it receives a batch with a single sample. Dropping the incomplete
    # last batch avoids that crash when the dataset size is 1 modulo the batch
    # size; with shuffling, different samples are dropped each epoch.
    drop_last=True
)

# Evaluation loaders keep every sample and a fixed order.
val_loader = DataLoader(
    val_dataset,
    batch_size=config.BATCH_SIZE,
    shuffle=False,
    num_workers=config.NUM_WORKERS,
    pin_memory=True
)

# One image per batch: each item already expands into config.TTA_NUM views.
tta_val_loader = DataLoader(
    tta_val_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=config.NUM_WORKERS,
    pin_memory=True
)

# Enhanced Model Architecture with EfficientNet-B3
class BoneDensityModel(nn.Module):
    """timm `efficientnet_b3` backbone (pretrained) with a custom MLP classifier.

    The first half of the backbone's parameter tensors (in `parameters()`
    order) are frozen. The original classifier is then swapped for
    Dropout -> Linear(512) -> SiLU -> BatchNorm1d -> Dropout -> Linear(num_classes),
    whose parameters are all trainable.
    """

    def __init__(self, num_classes=3):
        super().__init__()
        self.backbone = timm.create_model('efficientnet_b3', pretrained=True)

        # Freeze by position in the parameter list: everything before the midpoint
        backbone_params = list(self.backbone.parameters())
        freeze_boundary = len(backbone_params) * 0.5
        for position, param in enumerate(backbone_params):
            if position < freeze_boundary:
                param.requires_grad = False

        # Width of the feature vector that feeds the stock classifier
        head_in_features = self.backbone.classifier.in_features

        # Custom head replaces the stock classifier in place
        head_layers = [
            nn.Dropout(p=config.DROPOUT_RATE),
            nn.Linear(head_in_features, 512),
            nn.SiLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(p=config.DROPOUT_RATE/2),
            nn.Linear(512, num_classes),
        ]
        self.backbone.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the class logits produced by the backbone and its custom head."""
        logits = self.backbone(x)
        return logits

# Initialize model on the GPU when one is available, otherwise on the CPU
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)
model = BoneDensityModel(num_classes=config.NUM_CLASSES)
model = model.to(device)

# Label Smoothing Cross Entropy Loss
class LabelSmoothingCrossEntropy(nn.Module):
    """Cross entropy blended with a uniform-target term.

    Per sample: (1 - smoothing) * NLL(target) + smoothing * mean over classes
    of the negative log-probabilities. The batch mean is returned.
    """

    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, x, target):
        """Compute the loss for logits `x` and integer class indices `target`."""
        log_probs = torch.log_softmax(x, dim=-1)

        # Log-probability assigned to the true class of each sample
        target_log_prob = log_probs.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
        nll_loss = -target_log_prob

        # Average negative log-probability over all classes (uniform target)
        smooth_loss = -log_probs.mean(dim=-1)

        confidence = 1 - self.smoothing
        per_sample = confidence * nll_loss + self.smoothing * smooth_loss
        return per_sample.mean()

# Calculate class weights (inverse frequency, normalised to sum to 1).
# Counts are taken per class index with a fixed length of config.NUM_CLASSES,
# so the weight vector always lines up with the model's outputs, even when a
# class folder is empty or missing (np.unique would then return fewer entries
# and positional indexing would fail or shift the weights to the wrong class).
_images_per_class = np.bincount(
    np.asarray(full_dataset.labels, dtype=np.int64),
    minlength=config.NUM_CLASSES
)
class_weights = torch.tensor(
    1.0 / (_images_per_class + 1e-6),  # epsilon guards against division by zero
    device=device,
    dtype=torch.float32  # Ensure weights are float32
)
class_weights = class_weights / class_weights.sum()

# Optimizer with weight decay (receives every model parameter, frozen ones included)
optimizer = optim.AdamW(model.parameters(), lr=config.LEARNING_RATE, weight_decay=config.WEIGHT_DECAY)

# Learning rate scheduler: cosine annealing restarted every 10 scheduler steps
# (T_mult=1 keeps the cycle length constant), never going below 1e-6
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=1, eta_min=1e-6)

# Loss functions
# `criterion` is the label-smoothed cross entropy; `focal_criterion` is, despite
# its name, a class-weighted nn.CrossEntropyLoss.
criterion = LabelSmoothingCrossEntropy(smoothing=0.1)
focal_criterion = nn.CrossEntropyLoss(weight=class_weights)

# Training function
def train_epoch(model, loader, optimizer, criterion, focal_criterion, scaler):
    """Train `model` for one pass over `loader` using mixed precision.

    The loss is 0.7 * criterion + 0.3 * focal_criterion. Gradients are unscaled
    and clipped to a global norm of 1.0 before each optimizer step. After the
    last batch the module-level `scheduler` is stepped once.

    Returns:
        (mean loss per batch, fraction of correctly classified samples)
    """
    model.train()
    loss_sum = 0.0
    hits = 0
    seen = 0

    for batch_images, batch_labels in tqdm(loader, desc="Training"):
        batch_images = batch_images.to(device, dtype=torch.float32)  # Ensure input is float32
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()

        with autocast():
            logits = model(batch_images)
            smoothed_term = criterion(logits, batch_labels)
            weighted_term = focal_criterion(logits, batch_labels)
            loss = 0.7 * smoothed_term + 0.3 * weighted_term

        # Scaled backward pass; unscale before clipping so the threshold
        # applies to the true gradient magnitudes
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()

        loss_sum += loss.item()
        predicted = logits.max(1)[1]
        seen += batch_labels.size(0)
        hits += predicted.eq(batch_labels).sum().item()

    # `scheduler` is not a parameter; this uses the module-level scheduler
    scheduler.step()
    mean_loss = loss_sum / len(loader)
    accuracy = hits / seen
    return mean_loss, accuracy

# Validation function
def validate(model, loader, criterion, focal_criterion, tta_loader=None):
    """Evaluate `model` on `loader` and, optionally, with test-time augmentation.

    Args:
        model: network to evaluate (switched to eval mode here).
        loader: yields `(images, labels)` batches for the standard pass.
        criterion, focal_criterion: losses combined as 0.7 / 0.3, as in training.
        tta_loader: optional loader whose `images` element is a list of
            augmented views (one batched tensor per view) of the same samples.

    Returns:
        `(val_loss, val_acc, f1, tta_acc, tta_f1)`; the two TTA values are
        None when `tta_loader` is None or yields no samples.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Validating"):
            images = images.to(device, dtype=torch.float32)  # Ensure input is float32
            labels = labels.to(device)

            with autocast():
                outputs = model(images)
                loss = 0.7 * criterion(outputs, labels) + 0.3 * focal_criterion(outputs, labels)

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    val_loss = running_loss / len(loader)
    val_acc = correct / total
    f1 = f1_score(all_labels, all_preds, average='weighted')

    # TTA validation
    tta_acc = None
    tta_f1 = None
    if tta_loader is not None:
        tta_correct = 0
        tta_total = 0
        tta_preds = []
        tta_labels = []

        with torch.no_grad():
            for images, label in tqdm(tta_loader, desc="TTA Validating"):
                label = label.to(device)

                # With the default collate function a dataset item that is a
                # list of views arrives as a list of batched tensors, one per
                # view. Iterating `images[0]` would walk the batch dimension
                # of the FIRST view only, so every view is stacked explicitly:
                # (views, batch, C, H, W) -> one forward pass -> mean over views.
                views = torch.stack(list(images), dim=0)
                num_views, batch_size = views.shape[0], views.shape[1]
                views = views.flatten(0, 1).to(device, dtype=torch.float32)  # Ensure input is float32

                with autocast():
                    logits = model(views)

                # Average the logits of all views of each sample
                avg_output = logits.float().reshape(num_views, batch_size, -1).mean(dim=0)
                _, predicted = avg_output.max(1)

                tta_total += label.size(0)
                tta_correct += predicted.eq(label).sum().item()
                tta_preds.extend(predicted.cpu().numpy())
                tta_labels.extend(label.cpu().numpy())

        if tta_total:
            tta_acc = tta_correct / tta_total
            tta_f1 = f1_score(tta_labels, tta_preds, average='weighted')

    return val_loss, val_acc, f1, tta_acc, tta_f1

# Training loop
# Trains for up to config.NUM_EPOCHS epochs, checkpoints the model whenever the
# monitored accuracy (TTA accuracy when available, otherwise plain validation
# accuracy) improves, and stops early after EARLY_STOPPING_PATIENCE epochs
# without improvement.
best_val_acc = 0.0
patience_counter = 0
scaler = GradScaler()

history = {
    'train_loss': [],
    'train_acc': [],
    'val_loss': [],
    'val_acc': [],
    'val_f1': [],
    'tta_acc': [],
    'tta_f1': []
}

for epoch in range(config.NUM_EPOCHS):
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, focal_criterion, scaler)
    val_loss, val_acc, val_f1, tta_acc, tta_f1 = validate(model, val_loader, criterion, focal_criterion, tta_val_loader)

    # "No TTA result" is signalled by None. A truthiness test would also treat
    # a real accuracy of 0.0 as missing, which would skip the history entry and
    # silently switch the monitored metric for that epoch.
    has_tta = tta_acc is not None

    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    history['val_f1'].append(val_f1)
    if has_tta:
        history['tta_acc'].append(tta_acc)
        history['tta_f1'].append(tta_f1)

    print(f"\nEpoch {epoch+1}/{config.NUM_EPOCHS}:")
    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2%}")
    print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2%} | Val F1: {val_f1:.4f}")
    if has_tta:
        print(f"TTA Val Acc: {tta_acc:.2%} | TTA F1: {tta_f1:.4f}")

    current_acc = tta_acc if has_tta else val_acc
    if current_acc > best_val_acc:
        best_val_acc = current_acc
        patience_counter = 0
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': val_loss,
            'accuracy': best_val_acc
        }, '/content/best_model.pth')
        print("Saved new best model")
    else:
        patience_counter += 1
        if patience_counter >= config.EARLY_STOPPING_PATIENCE:
            print(f"Early stopping at epoch {epoch+1}")
            break

# Plot training history: loss curves on the left, accuracy curves on the right
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(history['train_loss'], label='Train Loss')
plt.plot(history['val_loss'], label='Val Loss')
plt.title('Loss over Epochs')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(history['train_acc'], label='Train Acc')
plt.plot(history['val_acc'], label='Val Acc')
# The 'tta_acc' key always exists, so test for recorded values instead; this
# avoids an empty series (and a stray legend entry) when TTA never ran.
if history.get('tta_acc'):
    plt.plot(history['tta_acc'], label='TTA Val Acc')
plt.title('Accuracy over Epochs')
plt.legend()
plt.show()

# Load best model
# map_location remaps the stored tensors onto the device in use, so a
# checkpoint written on a GPU can still be restored on a CPU-only runtime.
checkpoint = torch.load('/content/best_model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Evaluation function
def evaluate_model(model, loader, tta_loader=None, set_name="Validation"):
    """Print accuracy, F1, classification report and confusion matrix.

    Args:
        model: network to evaluate (switched to eval mode here).
        loader: yields `(images, labels)` batches for the standard evaluation.
        tta_loader: optional loader whose `images` element is a list of
            augmented views (one batched tensor per view) of the same samples;
            when given, a second report based on view-averaged logits is printed.
        set_name: label used in progress bars and printed headings.

    Returns:
        The TTA accuracy when `tta_loader` is given, otherwise the standard
        accuracy.
    """
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in tqdm(loader, desc=f"Evaluating {set_name}"):
            images = images.to(device, dtype=torch.float32)  # Ensure input is float32
            with autocast():
                outputs = model(images)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.numpy())

    acc = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average='weighted')

    print(f"\nStandard {set_name} Set Evaluation:")
    print(f"Accuracy: {acc:.2%}")
    print(f"F1 Score: {f1:.4f}")
    print("\nClassification Report:")
    print(classification_report(all_labels, all_preds, target_names=config.CLASSES, zero_division=0, digits=4))
    print("\nConfusion Matrix:")
    print(confusion_matrix(all_labels, all_preds))

    if tta_loader is None:
        return acc

    tta_preds = []
    tta_labels = []

    with torch.no_grad():
        for images, label in tqdm(tta_loader, desc=f"TTA {set_name}"):
            # With the default collate function a dataset item that is a list
            # of views arrives as a list of batched tensors, one per view.
            # Iterating `images[0]` would walk the batch dimension of the
            # FIRST view only, so every view is stacked explicitly:
            # (views, batch, C, H, W) -> one forward pass -> mean over views.
            views = torch.stack(list(images), dim=0)
            num_views, batch_size = views.shape[0], views.shape[1]
            views = views.flatten(0, 1).to(device, dtype=torch.float32)  # Ensure input is float32

            with autocast():
                logits = model(views)

            # Average the logits of all views of each sample
            avg_output = logits.float().reshape(num_views, batch_size, -1).mean(dim=0)
            _, predicted = avg_output.max(1)
            tta_preds.extend(predicted.cpu().numpy())
            tta_labels.extend(label.cpu().numpy())

    tta_acc = accuracy_score(tta_labels, tta_preds)
    tta_f1 = f1_score(tta_labels, tta_preds, average='weighted')

    print(f"\nTTA {set_name} Set Evaluation:")
    print(f"Accuracy: {tta_acc:.2%}")
    print(f"F1 Score: {tta_f1:.4f}")
    print("\nClassification Report:")
    print(classification_report(tta_labels, tta_preds, target_names=config.CLASSES, zero_division=0, digits=4))
    print("\nConfusion Matrix:")
    print(confusion_matrix(tta_labels, tta_preds))

    return tta_acc

# Final reports for both splits using the restored best checkpoint.
# The validation call also runs the TTA evaluation, so val_acc holds the TTA accuracy.
print("\n=== Final Evaluation ===")
train_acc = evaluate_model(model, train_loader, set_name="Training")
val_acc = evaluate_model(
    model,
    val_loader,
    tta_loader=tta_val_loader,
    set_name="Validation",
)

print("\n=== PROCESS COMPLETED SUCCESSFULLY ===")
print(f"Final Validation Accuracy: {val_acc:.2%}")