In [4]:
import os
import gc
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.models import efficientnet_b4, EfficientNet_B4_Weights
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc, classification_report, confusion_matrix
from PIL import Image
from typing import Tuple, List, Dict
import albumentations as A
from albumentations.pytorch import ToTensorV2

os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

In [5]:
# Configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMG_SIZE = (380, 380)
BATCH_SIZE = 32
EPOCHS_PHASE1 = 15
EPOCHS_PHASE2 = 10
CLASS_NAMES = ['DR', 'MH', 'ODC', 'TSLN', 'DN', 'MYA', 'ARMD', 'BRVO', 'ODP', 'ODE', 'LS', 'RS', 'CSR', 'CRS']

# Set random seeds
torch.manual_seed(42)
np.random.seed(42)

def load_and_preprocess_data():
    train_labels = pd.read_csv("/kaggle/input/retinal-disease-classification/Training_Set/Training_Set/RFMiD_Training_Labels.csv")
    val_labels = pd.read_csv("/kaggle/input/retinal-disease-classification/Evaluation_Set/Evaluation_Set/RFMiD_Validation_Labels.csv")
    test_labels = pd.read_csv("/kaggle/input/retinal-disease-classification/Test_Set/Test_Set/RFMiD_Testing_Labels.csv")
    
    def process_df(df):
        selected_diseases = list(set(CLASS_NAMES) & set(df.columns))
        df = df[['ID', 'Disease_Risk'] + selected_diseases].copy()
        df['Disease_Risk'] = (df[selected_diseases].sum(axis=1) > 0).astype(int)
        return df
    
    return process_df(train_labels), process_df(val_labels), process_df(test_labels)

class RetinalDataset(Dataset):
    def __init__(self, img_dir: str, df: pd.DataFrame, transform=None, augment=False, mixup_alpha=0.4):
        self.img_dir = img_dir
        self.df = df
        self.transform = transform
        self.augment = augment
        self.mixup_alpha = mixup_alpha
        self.image_paths = [os.path.join(img_dir, f"{row['ID']}.png") for _, row in df.iterrows()]

    def __len__(self):
        return len(self.df)
        
    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        # Use PIL's resize during loading to reduce memory usage
        image = Image.open(img_path).convert('RGB').resize((320, 320))  # Reduced size
        
        if self.transform:
            image = self.transform(image=np.array(image))['image']
        
        labels = torch.tensor(self.df.iloc[idx][CLASS_NAMES].values, dtype=torch.float32)
        
        return image, labels

    def mixup_data(self, x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.mixup_alpha > 0:
            lam = np.clip(np.random.beta(self.mixup_alpha, self.mixup_alpha), 0.2, 0.8)
            batch_size = x.size()[0]
            index = torch.randperm(batch_size).to(x.device)
            
            mixed_x = lam * x + (1 - lam) * x[index]
            mixed_y = lam * y + (1 - lam) * y[index]
            return mixed_x, mixed_y
        return x, y

class RetinalModel(nn.Module):
    def __init__(self, num_classes: int):
        super().__init__()
        weights = EfficientNet_B4_Weights.DEFAULT
        self.backbone = efficientnet_b4(weights=weights)
        num_features = self.backbone.classifier[1].in_features
        self.backbone.classifier = nn.Identity()
        
        self.classifier = nn.Sequential(
            nn.BatchNorm1d(num_features),
            nn.Dropout(0.5),
            nn.Linear(num_features, 512),
            nn.SiLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.SiLU(),
            nn.Linear(256, num_classes),
        )
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.backbone(x)
        return self.classifier(x)

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        
    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # bce_loss = nn.BCELoss(reduction='none')(inputs, targets)
        bce_loss = nn.BCEWithLogitsLoss(reduction='none')(inputs, targets)
        pt = torch.exp(-bce_loss)
        focal_loss = self.alpha * (1-pt)**self.gamma * bce_loss
        return focal_loss.mean()

def train_epoch(model: nn.Module, dataloader: DataLoader, criterion, optimizer, scaler, device: torch.device) -> Tuple[float, float]:
    model.train()
    total_loss = 0
    total_auc = 0
    
    for batch_idx, (images, targets) in enumerate(dataloader):
        # Clear cache periodically
        if batch_idx % 10 == 0:
            torch.cuda.empty_cache()
            gc.collect()
            
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        
        if dataloader.dataset.augment:
            images, targets = dataloader.dataset.mixup_data(images, targets)
        
        optimizer.zero_grad(set_to_none=True)  # More efficient than zero_grad()
        
        with torch.amp.autocast('cuda'):  # Use mixed precision
            outputs = model(images)
            loss = criterion(outputs, targets)
       
        # Use scaler for mixed precision training
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        # loss.backward()
        # optimizer.step()

         
        # For AUC calculation, apply sigmoid to get probabilities
        with torch.no_grad():
            probs = torch.sigmoid(outputs)
        
        try:
            auc = roc_auc_score(targets.cpu().numpy(), probs.cpu().numpy(), average='macro')
            total_auc += auc
        except:
            pass


        total_loss += loss.item()
        del images, targets, outputs, probs
        torch.cuda.empty_cache()
        
        # # Explicitly clear variables
        # del images, targets, outputs
        # torch.cuda.empty_cache()
        
        # total_loss += loss.item()
        # try:
        #     auc = roc_auc_score(targets.cpu().numpy(), outputs.detach().cpu().numpy(), average='macro')
        #     total_auc += auc
        # except:
        #     pass
            
    return total_loss / len(dataloader), total_auc / len(dataloader)


def validate(model: nn.Module, dataloader: DataLoader, criterion, device: torch.device) -> Tuple[float, float]:
    model.eval()
    total_loss = 0
    total_auc = 0
    
    with torch.no_grad():
        for images, targets in dataloader:
            images = images.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            
            outputs = model(images)
            loss = criterion(outputs, targets)

            # Apply sigmoid for predictions
            probs = torch.sigmoid(outputs)
            
            total_loss += loss.item()
            try:
                auc = roc_auc_score(targets.cpu().numpy(), probs.cpu().numpy(), average='macro')
                total_auc += auc
            except:
                pass
            
    return total_loss / len(dataloader), total_auc / len(dataloader)

In [6]:
def main():

    # Configure for memory efficiency
    torch.backends.cudnn.benchmark = True

    # Load data
    train_df, val_df, test_df = load_and_preprocess_data()
    
    # Define transforms
    train_transform = A.Compose([
        A.Resize(IMG_SIZE[0], IMG_SIZE[1]),
        A.RandomRotate90(),
        A.HorizontalFlip(),
        A.VerticalFlip(),
        # A.OneOf([
        #     A.RandomBrightness(),
        #     A.RandomContrast(),
        # ], p=0.3),
        A.Normalize(),
        ToTensorV2()
    ])
    
    val_transform = A.Compose([
        A.Resize(IMG_SIZE[0], IMG_SIZE[1]),
        A.Normalize(),
        ToTensorV2()
    ])
    
    # Create datasets and dataloaders
    train_dataset = RetinalDataset(
        "/kaggle/input/retinal-disease-classification/Training_Set/Training_Set/Training",
        train_df,
        transform=train_transform,
        augment=True,
        mixup_alpha=0.4
    )
    
    val_dataset = RetinalDataset(
        "/kaggle/input/retinal-disease-classification/Evaluation_Set/Evaluation_Set/Validation",
        val_df,
        transform=val_transform
    )
    
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True, persistent_workers=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True, persistent_workers=True)

    # Initialize mixed precision training
    scaler = torch.amp.GradScaler(device='cuda')
    
    # Initialize model, criterion, and optimizer
    model = RetinalModel(len(CLASS_NAMES)).to(DEVICE)
    criterion = FocalLoss(alpha=0.25, gamma=2.0)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2)
    scaler = torch.amp.GradScaler(device='cuda')
    
    # Training Phase 1
    best_val_auc = 0
    for epoch in range(EPOCHS_PHASE1):
        # Clear memory before each epoch
        torch.cuda.empty_cache()
        gc.collect()
        
        train_loss, train_auc = train_epoch(model, train_loader, criterion, optimizer, scaler, DEVICE)
        
        with torch.no_grad():  # Ensure validation doesn't accumulate gradients
            val_loss, val_auc = validate(model, val_loader, criterion, DEVICE)
        
        scheduler.step(val_auc)
        
        if val_auc > best_val_auc or epoch == 0:
            best_val_auc = val_auc
            print("Our model: \n\n", model, '\n')
            print("The state dict keys: \n\n", model.state_dict().keys())
            try:
                torch.save(model.state_dict(), '/kaggle/working/phase1_best.pth')
                print("Checkpoint saved successfully!")
            except Exception as e:
                print(f"Error while saving model: {e}")

        print(f'Epoch {epoch+1}/{EPOCHS_PHASE1}:')
        print(f'Train Loss: {train_loss:.4f}, Train AUC: {train_auc:.4f}')
        print(f'Val Loss: {val_loss:.4f}, Val AUC: {val_auc:.4f}')
        
        # Clear memory after each epoch
        torch.cuda.empty_cache()
        gc.collect()
    
    # Phase 2: Fine-tuning
    model.load_state_dict(torch.load('/kaggle/working/phase1_best.keras'))
    for param in model.backbone.parameters():
        param.requires_grad = True
        
    optimizer = optim.Adam(model.parameters(), lr=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.2, patience=1)
    
    for epoch in range(EPOCHS_PHASE2):
        train_loss, train_auc = train_epoch(model, train_loader, criterion, optimizer, DEVICE)
        val_loss, val_auc = validate(model, val_loader, criterion, DEVICE)
        
        scheduler.step(val_auc)
        
        if val_auc > best_val_auc or epoch == 0:
            best_val_auc = val_auc
            print("Our model: \n\n", model, '\n')
            print("The state dict keys: \n\n", model.state_dict().keys())
            try:
                torch.save(model.state_dict(), '/kaggle/working/final_model.keras')
                print("2nd Checkpoint saved successfully!")
            except Exception as e:
                print(f"Error while saving model: {e}")
            
        print(f'Fine-tuning Epoch {epoch+1}/{EPOCHS_PHASE2}:')
        print(f'Train Loss: {train_loss:.4f}, Train AUC: {train_auc:.4f}')
        print(f'Val Loss: {val_loss:.4f}, Val AUC: {val_auc:.4f}')
        
    
    # Compute final AUC across full validation set
    all_labels, all_probs = [], []
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            probs = torch.sigmoid(outputs)  # Convert logits to probabilities
            
            all_labels.append(labels.cpu().numpy())
            all_probs.append(probs.cpu().numpy())
    
    all_labels = np.concatenate(all_labels, axis=0)
    all_probs = np.concatenate(all_probs, axis=0)
    
    if len(np.unique(all_labels)) > 1:  # Ensure at least two classes exist
        auc_score = roc_auc_score(all_labels, all_probs)
        print(f'Final Validation AUC Score: {auc_score:.4f}')
    else:
        print("Warning: Only one class present in validation set. AUC cannot be computed.")


if __name__ == "__main__":
    main()

Epoch 1/15:
Train Loss: 0.0176, Train AUC: 0.0000
Val Loss: 0.0120, Val AUC: 0.0000
Epoch 2/15:
Train Loss: 0.0114, Train AUC: 0.0000
Val Loss: 0.0118, Val AUC: 0.0000
Epoch 3/15:
Train Loss: 0.0108, Train AUC: 0.0000
Val Loss: 0.0104, Val AUC: 0.0000
Epoch 4/15:
Train Loss: 0.0102, Train AUC: 0.0000
Val Loss: 0.0098, Val AUC: 0.0000
Epoch 5/15:
Train Loss: 0.0092, Train AUC: 0.0000
Val Loss: 0.0085, Val AUC: 0.0000
Epoch 6/15:
Train Loss: 0.0092, Train AUC: 0.0000
Val Loss: 0.0087, Val AUC: 0.0000
Epoch 7/15:
Train Loss: 0.0087, Train AUC: 0.0000
Val Loss: 0.0082, Val AUC: 0.0000
Epoch 8/15:
Train Loss: 0.0083, Train AUC: 0.0000
Val Loss: 0.0078, Val AUC: 0.0000
Epoch 9/15:
Train Loss: 0.0082, Train AUC: 0.0000
Val Loss: 0.0077, Val AUC: 0.0000
Epoch 10/15:
Train Loss: 0.0080, Train AUC: 0.0000
Val Loss: 0.0075, Val AUC: 0.0000
Epoch 11/15:
Train Loss: 0.0079, Train AUC: 0.0000
Val Loss: 0.0076, Val AUC: 0.0000
Epoch 12/15:
Train Loss: 0.0077, Train AUC: 0.0000
Val Loss: 0.0076, Val A

  model.load_state_dict(torch.load('phase1_best.pth'))


FileNotFoundError: [Errno 2] No such file or directory: 'phase1_best.pth'

In [None]:
import os
import numpy as np
import pandas as pd
import tensorflow as tf
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc, classification_report, confusion_matrix
from sklearn.utils.class_weight import compute_class_weight
from tensorflow.keras.metrics import AUC
from tensorflow.keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, GlobalAveragePooling2D, BatchNormalization
from tensorflow.keras.applications import EfficientNetB4
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras import regularizers
from tensorflow.keras.losses import BinaryFocalCrossentropy
from tensorflow.keras.applications.efficientnet import preprocess_input
from tensorflow.keras.utils import to_categorical
from concurrent.futures import ThreadPoolExecutor

In [None]:
#### print("Available GPUs:", tf.config.list_physical_devices('GPU'))

In [None]:
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
            print("GPU Done!!")
    except RuntimeError as e:
        print(e)

In [None]:
policy = tf.keras.mixed_precision.Policy('mixed_float16')
tf.keras.mixed_precision.set_global_policy(policy)

In [None]:
!nvidia-smi

In [None]:
# Set random seed for reproducibility
tf.random.set_seed(42)
np.random.seed(42)

# Configuration
IMG_SIZE = (380, 380)  # EfficientNetB4 requires 380x380
BATCH_SIZE = 64
EPOCHS_PHASE1 = 15
EPOCHS_PHASE2 = 10
CLASS_NAMES = ['DR', 'MH', 'ODC', 'TSLN', 'DN', 'MYA', 'ARMD', 'BRVO', 'ODP', 'ODE', 'LS', 'RS', 'CSR', 'CRS']

# Paths
train_labels_path = "/kaggle/input/retinal-disease-classification/Training_Set/Training_Set/RFMiD_Training_Labels.csv"
val_labels_path = "/kaggle/input/retinal-disease-classification/Evaluation_Set/Evaluation_Set/RFMiD_Validation_Labels.csv"
test_labels_path = "/kaggle/input/retinal-disease-classification/Test_Set/Test_Set/RFMiD_Testing_Labels.csv"

In [None]:
# Load and preprocess data
def load_and_preprocess_data():
    train_labels = pd.read_csv(train_labels_path)
    val_labels = pd.read_csv(val_labels_path)
    test_labels = pd.read_csv(test_labels_path)

    selected_diseases = CLASS_NAMES
    
    def process_df(df):
        selected_diseases = list(set(CLASS_NAMES) & set(df.columns))
        df = df[['ID', 'Disease_Risk'] + selected_diseases].copy()
        df['Disease_Risk'] = (df[selected_diseases].sum(axis=1) > 0).astype(int)
        return df

    return (
        process_df(train_labels),
        process_df(val_labels),
        process_df(test_labels)
    )

train_df, val_df, test_df = load_and_preprocess_data()

In [None]:
# Enhanced Data Generator with Mixup
class AdvancedDataGenerator(tf.keras.utils.Sequence):
    def __init__(self, img_dir, df, batch_size=32, img_size=IMG_SIZE, 
                 augment=False, shuffle=True, mixup_alpha=0.4, **kwargs):
        super().__init__(**kwargs) 
         
        self.img_dir = img_dir
        self.df = df
        self.batch_size = batch_size
        self.img_size = img_size
        self.augment = augment
        self.shuffle = shuffle
        self.mixup_alpha = mixup_alpha
        self.indices = np.arange(len(df))
        
        # Pre-load all image paths for faster access
        self.image_paths = [os.path.join(img_dir, f"{row['ID']}.png") for _, row in df.iterrows()]

        # Augmentation configurations
        self.augmenter = ImageDataGenerator(
            rotation_range=25,
            width_shift_range=0.2,
            height_shift_range=0.2,
            shear_range=0.2,
            zoom_range=0.2,
            horizontal_flip=True,
            vertical_flip=True,
            brightness_range=[0.8, 1.2]
        )

        # Set memory optimization flags
        self.on_epoch_end()

    def __len__(self):
        return int(np.ceil(len(self.df) / self.batch_size))

    def __getitem__(self, index):
        batch_indices = self.indices[index*self.batch_size : (index+1)*self.batch_size]
        batch_df = self.df.iloc[batch_indices]

        # Use parallel processing for data loading
        with ThreadPoolExecutor(max_workers=4) as executor:
            results = list(executor.map(self._load_single_item, batch_indices))

        X = np.stack([r[0] for r in results])
        y = np.stack([r[1] for r in results])

        # Apply Mixup using vectorized operations
        if self.augment and self.mixup_alpha > 0:
            X, y = self._apply_mixup(X, y)
            
        return X, y        
        # # Load images and labels
        # X, y = self._load_data(batch_df)
        
        # # Apply Mixup
        # if self.augment and self.mixup_alpha > 0:
        #     X, y = self._apply_mixup(X, y)
            
        # return X, y

    def _load_single_item(self, idx):
        row = self.df.iloc[idx]
        img = load_img(self.image_paths[idx], target_size=self.img_size)
        img_array = img_to_array(img)
        
        if self.augment and row['Disease_Risk'] == 1:
            img_array = self.augmenter.random_transform(img_array)
            
        return preprocess_input(img_array), row[CLASS_NAMES].values.astype(np.float32)
    
    def _apply_mixup(self, X, y):
        lam = np.clip(np.random.beta(self.mixup_alpha, self.mixup_alpha), 0.2, 0.8)
        rand_index = np.random.permutation(len(X))
        
        # Vectorized operations
        mixed_X = lam * X + (1 - lam) * X[rand_index]
        mixed_y = lam * y + (1 - lam) * y[rand_index]
        return mixed_X, mixed_y

    # def on_epoch_end(self):
    #     if self.shuffle:
    #         np.random.shuffle(self.indices)

    # def _load_data(self, batch_df):
    #     X = np.empty((len(batch_df), *self.img_size, 3))
    #     y = np.empty((len(batch_df), len(CLASS_NAMES)))   # Disease_Risk + diseases
        
    #     for i, (_, row) in enumerate(batch_df.iterrows()):
    #         img_path = os.path.join(self.img_dir, f"{row['ID']}.png")
    #         img = load_img(img_path, target_size=self.img_size)
    #         img_array = img_to_array(img)
            
    #         # Apply augmentation only to diseased samples
    #         if self.augment and row['Disease_Risk'] == 1:
    #             img_array = self.augmenter.random_transform(img_array)
                
    #         X[i] = preprocess_input(img_array)  # EfficientNet preprocessing
    #         y[i] = row[CLASS_NAMES].values
            
    #     return X, y.astype(np.float32)

    # def _apply_mixup(self, X, y):
    #     lam = np.clip(np.random.beta(self.mixup_alpha, self.mixup_alpha), 0.2, 0.8)
    #     rand_index = np.random.permutation(len(X))
        
    #     mixed_X = lam * X + (1 - lam) * X[rand_index]
    #     mixed_y = lam * y + (1 - lam) * y[rand_index]
    #     return mixed_X, mixed_y

In [None]:
# Create data generators
train_gen = AdvancedDataGenerator(
    "/kaggle/input/retinal-disease-classification/Training_Set/Training_Set/Training",
    train_df,
    batch_size=BATCH_SIZE,
    augment=True,
    mixup_alpha=0.4
)

val_gen = AdvancedDataGenerator(
    "/kaggle/input/retinal-disease-classification/Evaluation_Set/Evaluation_Set/Validation",
    val_df,
    batch_size=BATCH_SIZE
)

test_gen = AdvancedDataGenerator(
    "/kaggle/input/retinal-disease-classification/Test_Set/Test_Set/Test",
    test_df,
    batch_size=BATCH_SIZE
)

def generator_wrapper(generator):
    for X, y in generator:
        yield X, y

# train_dataset = tf.data.Dataset.from_generator(
#     lambda: generator_wrapper(train_gen),
#     output_signature=(
#         tf.TensorSpec(shape=(None, *IMG_SIZE, 3), dtype=tf.float32),
#         tf.TensorSpec(shape=(None, len(CLASS_NAMES)), dtype=tf.float32)
#     )
# ).prefetch(tf.data.AUTOTUNE)

train_dataset = tf.data.Dataset.from_generator(
    lambda: generator_wrapper(train_gen),
    output_signature=(
        tf.TensorSpec(shape=(None, *IMG_SIZE, 3), dtype=tf.float32),
        tf.TensorSpec(shape=(None, len(CLASS_NAMES)), dtype=tf.float32)
    )
).prefetch(tf.data.AUTOTUNE).cache().shuffle(buffer_size=1000)


In [None]:
X_batch, y_batch = train_gen[0]
print("Shape of X_batch:", X_batch.shape)  # Should be (batch_size, 380, 380, 3)
print("Shape of y_batch:", y_batch.shape)  # Should be (batch_size, 14)

In [None]:
# Calculate class weights
def calculate_class_weights(df):
    weights = {}
    for idx, disease in enumerate(['Disease_Risk'] + CLASS_NAMES):
        cls_weights = compute_class_weight(
            'balanced',
            classes=np.array([0, 1]),
            y=df[disease]
        )
        weights[idx] = {0: cls_weights[0], 1: cls_weights[1]}
    return weights

class_weights = calculate_class_weights(train_df)
class_weights = {idx: weights[1] for idx, weights in class_weights.items()}

In [None]:
# Build Model
def build_model():
    base_model = EfficientNetB4(
        weights='imagenet',
        include_top=False,
        input_shape=(*IMG_SIZE, 3)
    )
    base_model.trainable = False  # Freeze initially

    model = Sequential([
        base_model,
        GlobalAveragePooling2D(),
        BatchNormalization(),
        Dropout(0.5),
        Dense(512, activation='swish', kernel_regularizer=regularizers.l2(1e-4)),
        BatchNormalization(),
        Dropout(0.5),
        Dense(256, activation='swish'),
        Dense(len(CLASS_NAMES), activation='sigmoid')  # Disease_Risk + diseases
    ])
    
    return model

In [None]:
def focal_loss(alpha=0.25, gamma=2.0):
    def loss_fn(y_true, y_pred):
        # Cast to float32
        y_true = tf.cast(y_true, tf.float32)
        y_pred = tf.cast(y_pred, tf.float32)
        
        # Compute element-wise binary crossentropy using Keras backend.
        ce = tf.keras.backend.binary_crossentropy(y_true, y_pred)  
        # ce now has shape (batch_size, num_classes)
        
        # Compute p_t element-wise
        p_t = y_true * y_pred + (1 - y_true) * (1 - y_pred)
        
        # Compute alpha factor element-wise
        alpha_factor = y_true * alpha + (1 - y_true) * (1 - alpha)
        
        # Compute modulating factor element-wise
        modulating_factor = tf.pow(1.0 - p_t, gamma)
        
        # Final loss element-wise
        loss = alpha_factor * modulating_factor * ce
        return loss  # Optionally, you can reduce_mean over the batch or last axis.
    
    return loss_fn


In [None]:
sample_y_true = tf.convert_to_tensor(y_batch[:5], dtype=tf.float32)  # (5, 14)
sample_y_pred = tf.random.uniform(sample_y_true.shape, 0, 1)           # (5, 14)
loss_fn = focal_loss(alpha=0.25, gamma=2.0)
print(loss_fn(sample_y_true, sample_y_pred))


In [None]:
print("y_true shape:", sample_y_true.shape)
print("y_pred shape:", sample_y_pred.shape)

In [None]:
# Two-Phase Training
model = build_model()

# Phase 1: Train head
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    loss=focal_loss(gamma=2.0, alpha=0.25),
    metrics=[AUC(name='auc', multi_label=True)]
)

phase1_callbacks = [
    EarlyStopping(patience=3, monitor='val_auc', mode='max', verbose=1),
    ModelCheckpoint('phase1_best.keras', save_best_only=True),
    ReduceLROnPlateau(factor=0.5, patience=2)
]

history_phase1 = model.fit(
    train_dataset,
    validation_data=val_gen,
    epochs=EPOCHS_PHASE1,
    callbacks=phase1_callbacks,
    # workers=4,  # Increase as needed
    # use_multiprocessing=True
)

In [None]:
# Phase 2: Fine-tune
model = tf.keras.models.load_model('phase1_best.keras', custom_objects={'loss_fn': focal_loss()}) # Load best weights from phase 1
model.layers[0].trainable = True  # Unfreeze base model

# Set last 150 layers trainable
for layer in model.layers[0].layers[-150:]:
    layer.trainable = True

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
    loss=focal_loss(gamma=2.0, alpha=0.25),
    metrics=[AUC(name='auc', multi_label=True)]
)

phase2_callbacks = [
    EarlyStopping(patience=2, monitor='val_auc', mode='max', verbose=1),
    ModelCheckpoint('final_model.keras', save_best_only=True),
    ReduceLROnPlateau(factor=0.2, patience=1)
]

history_phase2 = model.fit(
    train_gen,
    validation_data=val_gen,
    epochs=EPOCHS_PHASE2,
    callbacks=phase2_callbacks
)

In [None]:
# Evaluation with Test-Time Augmentation
def evaluate_with_tta(model, generator, n_tta=5):
    y_true, y_pred = [], []
    # tta_preds = []
    
    for i in range(len(generator)):
        X, y = generator[i]
        batch_preds = np.zeros_like(y)
        
        for _ in range(n_tta):
            # Create augmented versions
            aug_X = np.array([generator.augmenter.random_transform(img) for img in X])
            batch_preds += model.predict(aug_X)
            
        # Average predictions
        batch_preds /= n_tta  # Average predictions
        y_true.append(y)
        y_pred.append(batch_preds)
        # avg_pred = np.mean(batch_preds, axis=0)
        # tta_preds.append(avg_pred)
        # y_true.append(y)
        
    return np.vstack(y_true), np.vstack(y_pred)

model = tf.keras.models.load_model('final_model.keras', custom_objects={'loss_fn': focal_loss()})
y_true, y_pred = evaluate_with_tta(model, test_gen)

In [None]:
# Generate reports
print("Classification Report:")
print(classification_report(
    y_true[:, 1:],  # Skip Disease_Risk
    (y_pred[:, 1:] > 0.5).astype(int),
    target_names=CLASS_NAMES
))

print("\nConfusion Matrices:")
for idx, disease in enumerate(CLASS_NAMES):
    cm = confusion_matrix(y_true[:, idx+1], (y_pred[:, idx+1] > 0.5).astype(int))
    plt.figure()
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.title(f"{disease} Confusion Matrix")
    plt.show()

In [None]:
# Plot training history
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(history_phase1.history['auc'], label='Phase1 Train')
plt.plot(history_phase1.history['val_auc'], label='Phase1 Val')
plt.plot(np.arange(EPOCHS_PHASE1, EPOCHS_PHASE1+len(history_phase2.history['auc'])), 
         history_phase2.history['auc'], label='Phase2 Train')
plt.plot(np.arange(EPOCHS_PHASE1, EPOCHS_PHASE1+len(history_phase2.history['val_auc'])), 
         history_phase2.history['val_auc'], label='Phase2 Val')
plt.title('AUC History')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(history_phase1.history['loss'], label='Phase1 Train')
plt.plot(history_phase1.history['val_loss'], label='Phase1 Val')
plt.plot(np.arange(EPOCHS_PHASE1, EPOCHS_PHASE1+len(history_phase2.history['loss'])), 
         history_phase2.history['loss'], label='Phase2 Train')
plt.plot(np.arange(EPOCHS_PHASE1, EPOCHS_PHASE1+len(history_phase2.history['val_loss'])), 
         history_phase2.history['val_loss'], label='Phase2 Val')
plt.title('Loss History')
plt.legend()
plt.show()