In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/pcamv1/camelyonpatch_level_2_split_train_y.h5
/kaggle/input/pcamv1/camelyonpatch_level_2_split_valid_y.h5
/kaggle/input/pcamv1/camelyonpatch_level_2_split_valid_meta.csv
/kaggle/input/pcamv1/camelyonpatch_level_2_split_valid_x.h5
/kaggle/input/pcamv1/camelyonpatch_level_2_split_train_mask.h5
/kaggle/input/pcamv1/camelyonpatch_level_2_split_train_meta.csv
/kaggle/input/pcamv1/camelyonpatch_level_2_split_test_y.h5
/kaggle/input/pcamv1/camelyonpatch_level_2_split_test_meta.csv
/kaggle/input/pcamv1/camelyonpatch_level_2_split_test_x.h5
/kaggle/input/pcamv1/camelyonpatch_level_2_split_train_x.h5-001/camelyonpatch_level_2_split_train_x.h5


In [2]:
# ===== 0) Imports & Setup =====
import os
import math
import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, precision_recall_fscore_support
import matplotlib.pyplot as plt
import seaborn as sns
import random
from tqdm import tqdm
import torch.cuda.amp as amp  # Modified: Use cuda.amp for mixed precision

# Reproducibility
SEED = 1131
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Paths from Kaggle environment (updated for standard PCam dataset)
BASE_DIR = "/kaggle/input/pcamv1/"
TRAIN_X_PATH = os.path.join(BASE_DIR, "camelyonpatch_level_2_split_train_x.h5-001/camelyonpatch_level_2_split_train_x.h5")
TRAIN_Y_PATH = os.path.join(BASE_DIR, "camelyonpatch_level_2_split_train_y.h5")
VALID_X_PATH = os.path.join(BASE_DIR, "camelyonpatch_level_2_split_valid_x.h5")
VALID_Y_PATH = os.path.join(BASE_DIR, "camelyonpatch_level_2_split_valid_y.h5")
TEST_X_PATH = os.path.join(BASE_DIR, "camelyonpatch_level_2_split_test_x.h5")
TEST_Y_PATH = os.path.join(BASE_DIR, "camelyonpatch_level_2_split_test_y.h5")

# Verify paths
for path in [TRAIN_X_PATH, TRAIN_Y_PATH, VALID_X_PATH, VALID_Y_PATH, TEST_X_PATH, TEST_Y_PATH]:
    if not os.path.exists(path):
        print(f"File not found: {path}")
    else:
        print(f"File found: {path}")

# Training hyperparameters
INPUT_SIZE = 128  # Paper uses 128x128
BATCH_SIZE = 64  # Modified: Increased back to 128 assuming GPU allows
WARMUP_EPOCHS = 3
FINETUNE_EPOCHS = 10  # Modified: Reduced from 12 to avoid overfitting
LR_WARMUP = 1e-3
LR_FINETUNE = 1e-4
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# Mixed precision
MIXED_PRECISION = True
print("Mixed precision enabled")

File found: /kaggle/input/pcamv1/camelyonpatch_level_2_split_train_x.h5-001/camelyonpatch_level_2_split_train_x.h5
File found: /kaggle/input/pcamv1/camelyonpatch_level_2_split_train_y.h5
File found: /kaggle/input/pcamv1/camelyonpatch_level_2_split_valid_x.h5
File found: /kaggle/input/pcamv1/camelyonpatch_level_2_split_valid_y.h5
File found: /kaggle/input/pcamv1/camelyonpatch_level_2_split_test_x.h5
File found: /kaggle/input/pcamv1/camelyonpatch_level_2_split_test_y.h5
Using device: cuda
Mixed precision enabled


In [6]:
# ===== 1) HDF5 Dataset =====
class PCamH5Dataset(Dataset):
    def __init__(self, x_path, y_path, transform=None):
        self.x_file = h5py.File(x_path, "r")
        self.y_file = h5py.File(y_path, "r")
        self.X = self.x_file["x"]  # (N, 96, 96, 3) uint8
        self.Y = self.y_file["y"]  # (N, 1, 1, 1)
        self.transform = transform

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        img = self.X[idx].astype(np.float32) / 255.0  # Normalize to [0,1]
        label = float(self.Y[idx].reshape(-1)[0])
        img = torch.from_numpy(img).permute(2, 0, 1)  # (3, 96, 96)
        label = torch.tensor(label, dtype=torch.float32)
        if self.transform:
            img = self.transform(img)
        return img, label

    def close(self):
        self.x_file.close()
        self.y_file.close()

# Data augmentation (Modified: Added ColorJitter and Normalize)
train_transform = transforms.Compose([
    transforms.Resize((INPUT_SIZE, INPUT_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.15, contrast=0.15, saturation=0.15, hue=0.05),
    transforms.RandomAffine(degrees=0, translate=(0.2, 0.2), shear=15),
    transforms.RandomResizedCrop(INPUT_SIZE, scale=(0.85, 1.0)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

valid_transform = transforms.Compose([
    transforms.Resize((INPUT_SIZE, INPUT_SIZE)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Create datasets
train_dataset = PCamH5Dataset(TRAIN_X_PATH, TRAIN_Y_PATH, transform=train_transform)
valid_dataset = PCamH5Dataset(VALID_X_PATH, VALID_Y_PATH, transform=valid_transform)
test_dataset = PCamH5Dataset(TEST_X_PATH, TEST_Y_PATH, transform=valid_transform)

# DataLoaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True)

# Compute class weights (Modified: Slight emphasis on positive class)
def count_class_distribution(y_path, chunk=65536):
    f = h5py.File(y_path, "r")
    Y = f["y"]
    n = Y.shape[0]
    ones = 0
    for start in range(0, n, chunk):
        end = min(start + chunk, n)
        ones += Y[start:end].reshape(-1).sum()
    zeros = n - int(ones)
    f.close()
    return zeros, int(ones)

neg, pos = count_class_distribution(TRAIN_Y_PATH)
print(f"Train label counts → 0: {neg}, 1: {pos}")
classes = np.array([0, 1])
weights = [1.0, 1.2]  # Modified: Emphasize positive slightly
class_weight_dict = {0: weights[0], 1: weights[1]}
print("Class weights:", class_weight_dict)

class_weights_tensor = torch.tensor([class_weight_dict[0], class_weight_dict[1]], dtype=torch.float32).to(DEVICE)

Train label counts → 0: 131072, 1: 131072
Class weights: {0: 1.0, 1: 1.2}


In [7]:
# ===== 2) Model Definition (Modified: Enhanced classifier with extra layer) =====
class DNBCD(nn.Module):
    def __init__(self):
        super(DNBCD, self).__init__()
        self.backbone = models.densenet121(pretrained=True)
        self.backbone = nn.Sequential(*list(self.backbone.children())[:-1])  # Remove classifier
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.dropout1 = nn.Dropout(0.4)  # Modified: Increased dropout
        self.fc1 = nn.Linear(1024, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.relu1 = nn.ReLU()
        self.dropout2 = nn.Dropout(0.3)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.relu2 = nn.ReLU()
        self.dropout3 = nn.Dropout(0.3)
        self.fc3 = nn.Linear(256, 1)  # Output raw logits

    def forward(self, x):
        x = self.backbone(x)
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        x = self.dropout1(x)
        x = self.fc1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.dropout3(x)
        x = self.fc3(x)
        return x

# Initialize model
model = DNBCD().to(DEVICE)

# Freeze backbone for warmup
for param in model.backbone.parameters():
    param.requires_grad = False

Downloading: "https://download.pytorch.org/models/densenet121-a639ec97.pth" to /root/.cache/torch/hub/checkpoints/densenet121-a639ec97.pth
100%|██████████| 30.8M/30.8M [00:00<00:00, 156MB/s] 


In [8]:
# ===== 3) Training Setup (Modified: FocalLoss, weight_decay, Cosine scheduler) =====
class FocalLoss(nn.Module):
    def __init__(self, alpha=0.75, gamma=2.0, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        BCE_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss)  # prevents nans when probability 0
        F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
        if self.reduction == 'mean':
            return torch.mean(F_loss)
        elif self.reduction == 'sum':
            return torch.sum(F_loss)
        else:
            return F_loss

criterion = FocalLoss(alpha=0.75, gamma=2.0)  # Modified: Use FocalLoss

optimizer = optim.AdamW(model.parameters(), lr=LR_WARMUP, weight_decay=1e-4)  # Modified: Added weight_decay
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=FINETUNE_EPOCHS)  # Modified: Cosine scheduler for finetune
scaler = amp.GradScaler()  # For mixed precision
best_auc = 0.0
ckpt_path = "/kaggle/working/dnbcd_pcam_best.pth"

# Early stopping params (Modified: Added early stopping)
patience = 3
early_stop_counter = 0

def train_epoch(loader, model, criterion, optimizer, scaler):
    model.train()
    total_loss, total_correct, total_samples = 0, 0, 0
    all_probs, all_labels = [], []
    
    progress_bar = tqdm(loader, desc="Training")
    
    for images, labels in progress_bar:
        images, labels = images.to(DEVICE), labels.to(DEVICE).float()
        
        optimizer.zero_grad()
        
        with amp.autocast():
            outputs = model(images).squeeze()
            loss = criterion(outputs, labels)
        
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        total_loss += loss.item()
        predicted = (torch.sigmoid(outputs) > 0.5).float()
        total_correct += (predicted == labels).sum().item()
        total_samples += labels.size(0)
        
        all_probs.extend(torch.sigmoid(outputs).cpu().detach().numpy())
        all_labels.extend(labels.cpu().numpy())
        
        progress_bar.set_postfix({'loss': f'{loss.item():.4f}'})
    
    accuracy = total_correct / total_samples
    auc = roc_auc_score(all_labels, all_probs)
    avg_loss = total_loss / len(loader)
    
    return avg_loss, accuracy, auc

def validate_epoch(loader, model, criterion):
    model.eval()
    total_loss, total_correct, total_samples = 0, 0, 0
    all_probs, all_labels, all_preds = [], [], []
    
    with torch.no_grad():
        progress_bar = tqdm(loader, desc="Validation")
        
        for images, labels in progress_bar:
            images, labels = images.to(DEVICE), labels.to(DEVICE).float()
            
            with amp.autocast():
                outputs = model(images).squeeze()
                loss = criterion(outputs, labels)
            
            total_loss += loss.item()
            predicted = (torch.sigmoid(outputs) > 0.5).float()  # Default threshold
            total_correct += (predicted == labels).sum().item()
            total_samples += labels.size(0)
            
            all_probs.extend(torch.sigmoid(outputs).cpu().detach().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(predicted.cpu().numpy())
            
            progress_bar.set_postfix({'loss': f'{loss.item():.4f}'})
    
    accuracy = total_correct / total_samples
    auc = roc_auc_score(all_labels, all_probs)
    avg_loss = total_loss / len(loader)
    
    # Modified: Add per-class metrics
    prec, rec, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average=None)
    print(f"Metastasis Precision: {prec[1]:.4f}, Recall: {rec[1]:.4f}, F1: {f1[1]:.4f}")
    
    return avg_loss, accuracy, auc

  scaler = amp.GradScaler()  # For mixed precision


In [10]:
# ===== 4) Training Loop (Modified: Added early stopping, scheduler step) =====
print("Starting warm-up training phase...")
history = {'train_loss': [], 'train_acc': [], 'train_auc': [], 'val_loss': [], 'val_acc': [], 'val_auc': []}

for epoch in range(WARMUP_EPOCHS):
    train_loss, train_acc, train_auc = train_epoch(train_loader, model, criterion, optimizer, scaler)
    val_loss, val_acc, val_auc = validate_epoch(valid_loader, model, criterion)
    
    print(f"Warmup Epoch {epoch+1}/{WARMUP_EPOCHS}")
    print(f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f}, AUC: {train_auc:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}, AUC: {val_auc:.4f}")
    
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['train_auc'].append(train_auc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    history['val_auc'].append(val_auc)

# Unfreeze backbone for finetuning
for param in model.backbone.parameters():
    param.requires_grad = True

# Reset optimizer for finetuning with new LR and higher weight_decay
optimizer = optim.AdamW(model.parameters(), lr=LR_FINETUNE, weight_decay=1e-3)  # Modified: Higher weight_decay

print("Starting fine-tuning phase...")
for epoch in range(FINETUNE_EPOCHS):
    train_loss, train_acc, train_auc = train_epoch(train_loader, model, criterion, optimizer, scaler)
    val_loss, val_acc, val_auc = validate_epoch(valid_loader, model, criterion)
    
    scheduler.step()  # Modified: Step the cosine scheduler
    
    print(f"Finetune Epoch {epoch+1}/{FINETUNE_EPOCHS}")
    print(f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f}, AUC: {train_auc:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}, AUC: {val_auc:.4f}")
    
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['train_auc'].append(train_auc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    history['val_auc'].append(val_auc)
    
    # Save best model
    if val_auc > best_auc:
        best_auc = val_auc
        torch.save(model.state_dict(), ckpt_path)
        print("Saved best model")
        early_stop_counter = 0
    else:
        early_stop_counter += 1
        if early_stop_counter >= patience:
            print("Early stopping triggered")
            break

Starting warm-up training phase...


  with amp.autocast():
Training: 100%|██████████| 4096/4096 [34:35<00:00,  1.97it/s, loss=0.0986]
  with amp.autocast():
Validation: 100%|██████████| 512/512 [00:41<00:00, 12.41it/s, loss=0.0988]


Metastasis Precision: 0.8810, Recall: 0.7157, F1: 0.7898
Warmup Epoch 1/3
Train Loss: 0.0821, Acc: 0.8067, AUC: 0.8877
Val Loss: 0.0778, Acc: 0.8097, AUC: 0.9074


  with amp.autocast():
Training: 100%|██████████| 4096/4096 [33:03<00:00,  2.07it/s, loss=0.1147]
  with amp.autocast():
Validation: 100%|██████████| 512/512 [00:35<00:00, 14.60it/s, loss=0.0909]


Metastasis Precision: 0.8754, Recall: 0.7251, F1: 0.7932
Warmup Epoch 2/3
Train Loss: 0.0765, Acc: 0.8223, AUC: 0.9039
Val Loss: 0.0765, Acc: 0.8111, AUC: 0.9094


  with amp.autocast():
Training: 100%|██████████| 4096/4096 [30:53<00:00,  2.21it/s, loss=0.0790]
  with amp.autocast():
Validation: 100%|██████████| 512/512 [00:35<00:00, 14.63it/s, loss=0.0841]


Metastasis Precision: 0.8726, Recall: 0.7394, F1: 0.8005
Warmup Epoch 3/3
Train Loss: 0.0747, Acc: 0.8273, AUC: 0.9087
Val Loss: 0.0744, Acc: 0.8159, AUC: 0.9133
Starting fine-tuning phase...


  with amp.autocast():
Training: 100%|██████████| 4096/4096 [34:53<00:00,  1.96it/s, loss=0.0515]
  with amp.autocast():
Validation: 100%|██████████| 512/512 [00:34<00:00, 14.66it/s, loss=0.1074]


Metastasis Precision: 0.9584, Recall: 0.7705, F1: 0.8542
Finetune Epoch 1/10
Train Loss: 0.0438, Acc: 0.9135, AUC: 0.9702
Val Loss: 0.0680, Acc: 0.8687, AUC: 0.9592
Saved best model


  with amp.autocast():
Training: 100%|██████████| 4096/4096 [34:15<00:00,  1.99it/s, loss=0.0251]
  with amp.autocast():
Validation: 100%|██████████| 512/512 [00:36<00:00, 14.00it/s, loss=0.1145]


Metastasis Precision: 0.9632, Recall: 0.7572, F1: 0.8479
Finetune Epoch 2/10
Train Loss: 0.0334, Acc: 0.9382, AUC: 0.9824
Val Loss: 0.0848, Acc: 0.8643, AUC: 0.9534


  with amp.autocast():
Training: 100%|██████████| 4096/4096 [33:38<00:00,  2.03it/s, loss=0.0209]
  with amp.autocast():
Validation: 100%|██████████| 512/512 [00:35<00:00, 14.58it/s, loss=0.0856]


Metastasis Precision: 0.9616, Recall: 0.8302, F1: 0.8911
Finetune Epoch 3/10
Train Loss: 0.0295, Acc: 0.9468, AUC: 0.9862
Val Loss: 0.0639, Acc: 0.8986, AUC: 0.9630
Saved best model


  with amp.autocast():
Training: 100%|██████████| 4096/4096 [34:51<00:00,  1.96it/s, loss=0.0168]
  with amp.autocast():
Validation: 100%|██████████| 512/512 [00:35<00:00, 14.39it/s, loss=0.0968]


Metastasis Precision: 0.9568, Recall: 0.8303, F1: 0.8891
Finetune Epoch 4/10
Train Loss: 0.0269, Acc: 0.9522, AUC: 0.9884
Val Loss: 0.0668, Acc: 0.8965, AUC: 0.9613


  with amp.autocast():
Training: 100%|██████████| 4096/4096 [36:40<00:00,  1.86it/s, loss=0.0272]
  with amp.autocast():
Validation: 100%|██████████| 512/512 [00:36<00:00, 13.85it/s, loss=0.1107]


Metastasis Precision: 0.9530, Recall: 0.8184, F1: 0.8806
Finetune Epoch 5/10
Train Loss: 0.0250, Acc: 0.9562, AUC: 0.9899
Val Loss: 0.0672, Acc: 0.8891, AUC: 0.9587


  with amp.autocast():
Training: 100%|██████████| 4096/4096 [36:32<00:00,  1.87it/s, loss=0.0339]
  with amp.autocast():
Validation: 100%|██████████| 512/512 [00:36<00:00, 13.94it/s, loss=0.1478]

Metastasis Precision: 0.9544, Recall: 0.8027, F1: 0.8720
Finetune Epoch 6/10
Train Loss: 0.0238, Acc: 0.9580, AUC: 0.9909
Val Loss: 0.0729, Acc: 0.8823, AUC: 0.9577
Early stopping triggered





In [None]:
# ===== 5) Evaluation (Updated to match provided code) =====
print("Loading best model for evaluation...")
model.load_state_dict(torch.load(ckpt_path))
model.eval()

# Test evaluation
test_loss, test_acc, test_auc = validate_epoch(test_loader, model, criterion)
print(f"Test Loss: {test_loss:.4f}, Acc: {test_acc:.4f}, AUC: {test_auc:.4f}")

# Detailed metrics
print("Generating detailed predictions...")
preds, trues = [], []
with torch.no_grad():
    for images, labels in tqdm(test_loader, desc="Testing"):
        images = images.to(DEVICE)
        
        if MIXED_PRECISION:
            with torch.cuda.amp.autocast():
                outputs = model(images).squeeze()
        else:
            outputs = model(images).squeeze()
        
        preds.extend(torch.sigmoid(outputs).cpu().numpy())  # Apply sigmoid for metrics
        trues.extend(labels.cpu().numpy())

# Classification Report
print("\nClassification Report:")
print(classification_report(
    trues, 
    (np.array(preds) > 0.5).astype(int), 
    target_names=['No Metastasis', 'Metastasis']
))

# Confusion Matrix
plt.figure(figsize=(8, 6))
cm = confusion_matrix(trues, (np.array(preds) > 0.5).astype(int))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=['No Metastasis', 'Metastasis'], 
            yticklabels=['No Metastasis', 'Metastasis'])
plt.title('Confusion Matrix')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.show()

# Plot training history
plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(history['train_acc'], label='Train Acc', color='blue')
plt.plot(history['val_acc'], label='Val Acc', color='orange')
plt.title('Model Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.grid(True, alpha=0.3)

plt.subplot(1, 2, 2)
plt.plot(history['train_loss'], label='Train Loss', color='blue')
plt.plot(history['val_loss'], label='Val Loss', color='orange')
plt.title('Model Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Additional metrics
print(f"\nAdditional Metrics:")
print(f"AUC-ROC Score: {roc_auc_score(trues, preds):.4f}")

# Threshold analysis
thresholds = [0.3, 0.4, 0.5, 0.6, 0.7]
print(f"\nThreshold Analysis:")
for thresh in thresholds:
    pred_binary = (np.array(preds) > thresh).astype(int)
    acc = (pred_binary == np.array(trues)).mean()
    prec, rec, f1, _ = precision_recall_fscore_support(trues, pred_binary, average=None)
    print(f"Threshold {thresh}: Accuracy = {acc:.4f}, Metastasis Recall = {rec[1]:.4f}")

# Clean up datasets
print("\nCleaning up...")
if hasattr(train_dataset, 'close'):
    train_dataset.close()
if hasattr(valid_dataset, 'close'):
    valid_dataset.close()
if hasattr(test_dataset, 'close'):
    test_dataset.close()

print("Training and evaluation completed!")

# Optional: Save final results
results_summary = {
    'best_validation_auc': best_auc,
    'test_loss': test_loss,
    'test_accuracy': test_acc,
    'test_auc': test_auc,
    'training_history': history
}

# Save results to file
with open('/kaggle/working/training_results.json', 'w') as f:
    json.dump(results_summary, f, indent=2)

print("Results saved to training_results.json")