In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for root, _subdirs, names in os.walk('/kaggle/input'):
    # Print the full path of every file available under the Kaggle input mount.
    for name in names:
        print(os.path.join(root, name))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
# This cell must be the VERY FIRST cell in your notebook
!pip install segmentation-models-pytorch albumentations timm --quiet

In [None]:
import os
import glob
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import cv2 # Albumentations can use this backend

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split, ConcatDataset, Subset
from PIL import Image
import segmentation_models_pytorch as smp

import albumentations as A
from albumentations.pytorch import ToTensorV2

# ===============================
# 1. Data Preparation at High Resolution
# ===============================

# ImageNet statistics expected by the pretrained encoder.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Training pipeline: resize, flips, elastic/grid warps and intensity jitter, then normalize.
train_transform = A.Compose(
    [
        A.Resize(512, 512),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.ElasticTransform(p=0.5, alpha=120, sigma=120 * 0.05),
        A.GridDistortion(p=0.5),
        A.RandomBrightnessContrast(p=0.8),
        A.CLAHE(p=0.8),
        A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ToTensorV2(),
    ]
)

# Validation pipeline: deterministic resize + normalize only.
val_transform = A.Compose(
    [
        A.Resize(512, 512),
        A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ToTensorV2(),
    ]
)

class RetinaNumpyDataset(Dataset):
    """Vessel-segmentation pairs loaded as raw numpy arrays (no transforms applied).

    Images and masks are paired purely by position after sorting each directory's
    file list. Both directories must therefore hold the same number of files, and
    their sorted orders must correspond.

    Args:
        image_dir: directory of fundus images (any format PIL can open).
        mask_dir: directory of vessel masks, one per image.

    Raises:
        ValueError: if the two directories contain different numbers of files.
            Positional pairing would otherwise silently mis-pair images and masks.
    """

    def __init__(self, image_dir, mask_dir):
        self.image_paths = sorted(glob.glob(os.path.join(image_dir, "*")))
        self.mask_paths = sorted(glob.glob(os.path.join(mask_dir, "*")))
        if len(self.image_paths) != len(self.mask_paths):
            raise ValueError(
                f"Found {len(self.image_paths)} images in {image_dir!r} but "
                f"{len(self.mask_paths)} masks in {mask_dir!r}; "
                "images and masks are paired by sorted order, so the counts must match."
            )

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, mask)``: the image converted to RGB and the mask to grayscale ("L")."""
        image = np.array(Image.open(self.image_paths[idx]).convert("RGB"))
        mask = np.array(Image.open(self.mask_paths[idx]).convert("L"))
        return image, mask

class TransformedRetinaDataset(Dataset):
    """Wrap an ``(image, mask)`` dataset with a joint transform and binarize the mask.

    Args:
        original_dataset: indexable returning ``(image_np, mask_np)`` pairs.
        transform: callable invoked as ``transform(image=..., mask=...)`` that returns a
            dict with ``'image'`` and ``'mask'`` entries (the mask must support ``> 0``,
            ``.float()`` and ``.unsqueeze``, e.g. a torch tensor).
    """

    def __init__(self, original_dataset, transform):
        self.original_dataset = original_dataset
        self.transform = transform

    def __len__(self):
        return len(self.original_dataset)

    def __getitem__(self, idx):
        image_np, mask_np = self.original_dataset[idx]
        out = self.transform(image=image_np, mask=mask_np)
        # Any non-zero mask pixel counts as vessel; add a leading channel dimension.
        binary_mask = (out['mask'] > 0).float().unsqueeze(0)
        return out['image'], binary_mask

# Source vessel datasets: (image directory, mask directory) pairs.
datasets_np = [
    RetinaNumpyDataset(image_dir, mask_dir)
    for image_dir, mask_dir in (
        ("/kaggle/input/drive-digital-retinal-images-for-vessel-extraction/DRIVE/training/images", "/kaggle/input/drive-digital-retinal-images-for-vessel-extraction/DRIVE/training/1st_manual"),
        ("/kaggle/input/retina-blood-vessel/Data/train/image", "/kaggle/input/retina-blood-vessel/Data/train/mask"),
    )
]

# Split on the untransformed data first so train and val can get different transforms.
full_np_dataset = ConcatDataset(datasets_np)

torch.manual_seed(42)  # makes the 80/20 split below reproducible
n_total = len(full_np_dataset)
train_size = int(0.8 * n_total)
val_size = n_total - train_size
train_indices, val_indices = random_split(range(n_total), [train_size, val_size])

train_dataset = TransformedRetinaDataset(Subset(full_np_dataset, train_indices.indices), transform=train_transform)
val_dataset = TransformedRetinaDataset(Subset(full_np_dataset, val_indices.indices), transform=val_transform)

BATCH_SIZE = 1
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=0, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

# ===============================
# 2. Model, Loss, and Optimizer
# ===============================
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# UNet++ with an ImageNet-pretrained EfficientNet-B5 encoder; one output channel (vessel logit).
model = smp.UnetPlusPlus(
    encoder_name="efficientnet-b5",
    encoder_weights="imagenet",
    in_channels=3,
    classes=1,
)
model = model.to(device)

class SuperLoss(nn.Module):
    """Weighted sum of a binary focal loss and a binary Tversky loss.

    Both terms receive the model's raw outputs (logits) and the binary target mask.

    Args:
        focal_weight: multiplier for the focal term (default 0.5).
        tversky_weight: multiplier for the Tversky term (default 0.5).
    """

    def __init__(self, focal_weight=0.5, tversky_weight=0.5):
        super().__init__()
        # Previously these arguments were accepted but ignored (0.5/0.5 was hard-coded).
        self.focal_weight = focal_weight
        self.tversky_weight = tversky_weight
        self.focal = smp.losses.FocalLoss(mode='binary')
        # In smp's TverskyLoss, alpha weights false positives and beta false negatives;
        # beta > alpha biases the loss toward recall of thin vessels.
        self.tversky = smp.losses.TverskyLoss(mode='binary', alpha=0.3, beta=0.7)

    def forward(self, pred, true):
        return self.focal_weight * self.focal(pred, true) + self.tversky_weight * self.tversky(pred, true)

EPOCHS = 50

loss_fn = SuperLoss()
optimizer = optim.AdamW(model.parameters(), weight_decay=1e-5, lr=1e-4)
# Cosine-anneal the learning rate over the full epoch budget.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

# ===============================
# 3. Metrics
# ===============================
# Helper function to get TP, FP, FN, TN
def get_stats(y_pred, y_true):
    """Return confusion counts ``(tp, fp, fn, tn)`` for logits ``y_pred`` against ``y_true``.

    Logits are passed through a sigmoid and binarized at a 0.5 probability threshold.
    """
    hard_pred = torch.sigmoid(y_pred).gt(0.5).long()
    tp, fp, fn, tn = smp.metrics.get_stats(hard_pred, y_true.long(), mode='binary')
    return tp, fp, fn, tn

def dice_coef(y_pred, y_true):
    """Micro-averaged Dice (F1) of thresholded logits against ``y_true``."""
    counts = get_stats(y_pred, y_true)
    return smp.metrics.f1_score(*counts, reduction='micro')

def iou_coef(y_pred, y_true):
    """Micro-averaged IoU (Jaccard) of thresholded logits against ``y_true``."""
    counts = get_stats(y_pred, y_true)
    return smp.metrics.iou_score(*counts, reduction='micro')

def pixel_accuracy(y_pred, y_true):
    """Micro-averaged pixel accuracy of thresholded logits against ``y_true``."""
    counts = get_stats(y_pred, y_true)
    return smp.metrics.accuracy(*counts, reduction='micro')

def precision(y_pred, y_true):
    """Micro-averaged precision of thresholded logits against ``y_true``."""
    counts = get_stats(y_pred, y_true)
    return smp.metrics.precision(*counts, reduction='micro')

def recall(y_pred, y_true):
    """Micro-averaged recall of thresholded logits against ``y_true``."""
    counts = get_stats(y_pred, y_true)
    return smp.metrics.recall(*counts, reduction='micro')

# ===============================
# 4. Training Loop
# ===============================
best_val_iou = 0.0
patience = 5  # epochs without a new best validation IoU before stopping early
epochs_no_improve = 0

for epoch in range(EPOCHS):
    # ---- training pass ----
    model.train()
    train_loss, train_dice, train_iou, train_acc, train_prec, train_rec = 0, 0, 0, 0, 0, 0

    for images, masks in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Train]"):
        images, masks = images.to(device), masks.to(device)
        outputs = model(images)
        loss = loss_fn(outputs, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Metrics are bookkeeping only. Computing them under no_grad keeps the
        # sigmoid/threshold ops on `outputs` (which requires grad) out of autograd.
        with torch.no_grad():
            train_loss += loss.item()
            train_dice += dice_coef(outputs, masks).item()
            train_iou += iou_coef(outputs, masks).item()
            train_acc += pixel_accuracy(outputs, masks).item()
            train_prec += precision(outputs, masks).item()
            train_rec += recall(outputs, masks).item()

    # ---- validation pass ----
    model.eval()
    val_loss, val_dice, val_iou, val_acc, val_prec, val_rec = 0, 0, 0, 0, 0, 0

    with torch.no_grad():
        for images, masks in tqdm(val_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Val]"):
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            loss = loss_fn(outputs, masks)

            val_loss += loss.item()
            val_dice += dice_coef(outputs, masks).item()
            val_iou += iou_coef(outputs, masks).item()
            val_acc += pixel_accuracy(outputs, masks).item()
            val_prec += precision(outputs, masks).item()
            val_rec += recall(outputs, masks).item()

    scheduler.step()  # one cosine-annealing step per epoch

    # Per-batch metrics averaged over batches (per-image means when BATCH_SIZE == 1).
    avg_train_loss = train_loss/len(train_loader)
    avg_train_dice = train_dice/len(train_loader)
    avg_train_iou = train_iou/len(train_loader)
    avg_train_acc = train_acc/len(train_loader)
    avg_train_prec = train_prec/len(train_loader)
    avg_train_rec = train_rec/len(train_loader)

    avg_val_loss = val_loss/len(val_loader)
    avg_val_dice = val_dice/len(val_loader)
    avg_val_iou = val_iou/len(val_loader)
    avg_val_acc = val_acc/len(val_loader)
    avg_val_prec = val_prec/len(val_loader)
    avg_val_rec = val_rec/len(val_loader)

    print(f"--- Epoch {epoch+1}/{EPOCHS} ---")
    print(f"Train -> Loss: {avg_train_loss:.4f} | Dice: {avg_train_dice:.4f} | IoU: {avg_train_iou:.4f} | Acc: {avg_train_acc:.4f} | Prec: {avg_train_prec:.4f} | Rec: {avg_train_rec:.4f}")
    print(f"Valid -> Loss: {avg_val_loss:.4f} | Dice: {avg_val_dice:.4f} | IoU: {avg_val_iou:.4f} | Acc: {avg_val_acc:.4f} | Prec: {avg_val_prec:.4f} | Rec: {avg_val_rec:.4f}")

    # Checkpoint on best validation IoU; count non-improving epochs for early stopping.
    if avg_val_iou > best_val_iou:
        best_val_iou = avg_val_iou
        torch.save(model.state_dict(), "/kaggle/working/unetplusplus_retina_final.pth")
        print(f"✅ Model Saved! New best Validation IoU: {best_val_iou:.4f}")
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        print(f"Validation IoU did not improve for {epochs_no_improve} epoch(s).")

    if epochs_no_improve >= patience:
        print(f"Early stopping triggered after {patience} epochs with no improvement.")
        break

# ===============================
# 5. Prediction Visualization
# ===============================
# We will only run this if the training finishes successfully and saves a model
if os.path.exists("/kaggle/working/unetplusplus_retina_final.pth"):
    print("\nLoading best model for prediction...")
    # map_location keeps the load working if the current device differs from the one that saved it.
    model.load_state_dict(torch.load("/kaggle/working/unetplusplus_retina_final.pth", map_location=device))
    # Make inference mode explicit instead of relying on the training loop's last model.eval().
    model.eval()
    print("Best model loaded successfully.")

    def denormalize(tensor, mean, std):
        """Undo per-channel ``(x - mean) / std`` normalization on a 3-channel CHW tensor; returns a CPU tensor."""
        mean = torch.tensor(mean).view(3, 1, 1)
        std = torch.tensor(std).view(3, 1, 1)
        return tensor.cpu() * std + mean

    images, masks = next(iter(val_loader))
    images, masks = images.to(device), masks.to(device)

    with torch.no_grad():
        outputs = model(images)
        preds = (torch.sigmoid(outputs) > 0.5).float()

    i = 0  # index of the sample within the batch to display
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 3, 1)
    plt.title("Original")
    img_denorm = denormalize(images[i], mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    plt.imshow(img_denorm.permute(1, 2, 0).clamp(0, 1)) # Clamp values to [0,1] for proper display
    plt.subplot(1, 3, 2)
    plt.title("True Mask")
    plt.imshow(masks[i][0].cpu(), cmap="gray")
    plt.subplot(1, 3, 3)
    plt.title("Predicted Mask")
    plt.imshow(preds[i][0].cpu(), cmap="gray")
    plt.show()
else:
    print("\nNo model was saved, skipping prediction visualization.")

In [None]:
!pip install segmentation-models-pytorch timm

In [None]:
import os
import glob
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image
import torch
import segmentation_models_pytorch as smp
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader
import random

# ==========================================================
# 1. Configuration 
# ==========================================================
MODEL_PATH = "/kaggle/input/unetplusplus-retina-final/unetplusplus_retina_final.pth"

BASE_INPUT_DIR, BASE_OUTPUT_DIR = "/kaggle/input/aptos2019/", "/kaggle/working/"

# Source image folders for each split, and the folders that receive predicted vessel masks.
TRAIN_INPUT_FOLDER, VAL_INPUT_FOLDER, TEST_INPUT_FOLDER = (
    os.path.join(BASE_INPUT_DIR, sub)
    for sub in ("train_images/train_images", "val_images/val_images", "test_images/test_images")
)
TRAIN_OUTPUT_FOLDER, VAL_OUTPUT_FOLDER, TEST_OUTPUT_FOLDER = (
    os.path.join(BASE_OUTPUT_DIR, sub)
    for sub in ("train_images_predicted", "val_images_predicted", "test_images_predicted")
)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 8

# ==========================================================
# 2. Definitions 
# ==========================================================
# Same resize + ImageNet normalization as the segmentation validation pipeline; no augmentation.
predict_transform = A.Compose(
    [
        A.Resize(512, 512),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
)

class PredictionDataset(Dataset):
    """Images to run inference on, returned as ``(tensor, path)`` pairs.

    Args:
        image_paths: list of image file paths.
        transform: callable invoked as ``transform(image=rgb_array)`` returning a dict
            with an ``'image'`` entry.

    Any failure to read or transform an image is printed and yields ``(None, None)``
    so the caller can skip that sample.
    """

    def __init__(self, image_paths, transform):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        path = self.image_paths[idx]
        try:
            rgb = np.array(Image.open(path).convert("RGB"))
            tensor = self.transform(image=rgb)['image']
        except Exception as e:
            print(f"Could not read image {path}, skipping. Error: {e}")
            return None, None
        return tensor, path

# ==========================================================
# 3. Model Loading 
# ==========================================================
if not os.path.exists(MODEL_PATH):
    print(f"ERROR: Model file not found at {MODEL_PATH}")
    model = None
else:
    print("Loading the champion model...")
    # No pretrained encoder weights needed: load_state_dict below overwrites every parameter.
    model = smp.UnetPlusPlus(
        encoder_name="efficientnet-b5",
        encoder_weights=None,
        in_channels=3,
        classes=1,
    ).to(DEVICE)
    # map_location lets a checkpoint saved on GPU load on a CPU-only session (and vice versa).
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints
    # (consider weights_only=True on PyTorch >= 1.13).
    model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
    model.eval()
    print("Model loaded successfully!")

# ==========================================================
# 4. Prediction Function 
# ==========================================================
def predict_on_folder(input_folder, output_folder, model, transform):
    """Segment every image in ``input_folder`` and save binary vessel masks as PNG files.

    Masks are written to ``output_folder`` (created if needed) as ``<stem>.png`` with
    pixel values 0/255, at the spatial size produced by ``transform``. Images that fail
    to load are reported by ``PredictionDataset`` and skipped.

    Args:
        input_folder: directory searched (non-recursively) for .png/.jpeg/.jpg/.tif files.
        output_folder: destination directory for the predicted masks.
        model: segmentation model returning one logit channel per image.
        transform: pipeline producing the model's input tensor.

    Returns:
        None.
    """
    print(f"\n--- Starting prediction for folder: {input_folder} ---")
    os.makedirs(output_folder, exist_ok=True)

    print("Step 1 of 2: Searching for image files...")
    image_paths = []
    for ext in ["*.png", "*.jpeg", "*.jpg", "*.tif"]:
        image_paths.extend(glob.glob(os.path.join(input_folder, ext)))
    image_paths.sort()  # deterministic processing order

    print(f"Found {len(image_paths)} images.")

    if not image_paths:
        print("No images found in this folder, skipping.")
        return

    from torch.utils.data.dataloader import default_collate

    def _collate_skip_failed(samples):
        # PredictionDataset returns (None, None) for unreadable files and default_collate
        # cannot batch None. Drop those samples; an all-failed batch becomes None.
        loaded = [sample for sample in samples if sample[0] is not None]
        return default_collate(loaded) if loaded else None

    dataset = PredictionDataset(image_paths, transform=transform)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0,
                        collate_fn=_collate_skip_failed)

    model.eval()
    print("Step 2 of 2: Starting the prediction loop...")
    with torch.no_grad():
        for batch in tqdm(loader, desc="Predicting"):
            if batch is None:
                continue
            image_tensors, image_paths_batch = batch

            outputs = model(image_tensors.to(DEVICE))
            preds = (torch.sigmoid(outputs) > 0.5).float()

            for pred_tensor, src_path in zip(preds, image_paths_batch):
                pred_mask_np = pred_tensor.squeeze().cpu().numpy()
                pred_mask_img = Image.fromarray((pred_mask_np * 255).astype(np.uint8))

                # NOTE(review): inputs sharing a stem (e.g. a.png and a.jpg) map to the same output file.
                stem = os.path.splitext(os.path.basename(src_path))[0]
                pred_mask_img.save(os.path.join(output_folder, stem + ".png"))

    print(f"--- Predictions for {input_folder} are complete and saved to {output_folder} ---")

# ==========================================================
# 5. Main Execution
# ==========================================================
if model is None:
    print("\nSkipping prediction because the model file was not loaded.")
else:
    # Predict vessel masks for each APTOS split, in train -> val -> test order.
    for _src_dir, _dst_dir in (
        (TRAIN_INPUT_FOLDER, TRAIN_OUTPUT_FOLDER),
        (VAL_INPUT_FOLDER, VAL_OUTPUT_FOLDER),
        (TEST_INPUT_FOLDER, TEST_OUTPUT_FOLDER),
    ):
        predict_on_folder(_src_dir, _dst_dir, model, predict_transform)
    print("\nAll predictions are complete!")

# ==========================================================
# 6. Visualization 
# ==========================================================
ORIGINAL_DIR_TO_SHOW = TEST_INPUT_FOLDER
PREDICTED_DIR_TO_SHOW = TEST_OUTPUT_FOLDER
NUM_EXAMPLES = 5

print(f"\nDisplaying {NUM_EXAMPLES} random examples from the test set predictions...")
# Sorted so a seeded `random` picks the same examples on every run (glob order is unspecified).
predicted_mask_paths = sorted(glob.glob(os.path.join(PREDICTED_DIR_TO_SHOW, "*.png")))

if not predicted_mask_paths:
    print("No predicted masks found to display!")
else:
    random_samples = random.sample(predicted_mask_paths, min(NUM_EXAMPLES, len(predicted_mask_paths)))

    # Size the grid to the samples actually drawn, so fewer masks don't leave empty rows.
    n_rows = len(random_samples)
    fig, axes = plt.subplots(n_rows, 2, figsize=(10, 5 * n_rows), squeeze=False)

    for (ax_orig, ax_mask), mask_path in zip(axes, random_samples):
        ax_orig.axis('off')
        ax_mask.axis('off')
        base_filename = os.path.splitext(os.path.basename(mask_path))[0]

        # Predictions are saved as .png regardless of the source extension; find the source.
        original_path = None
        for ext in ['.png', '.jpeg', '.jpg', '.tif']:
            path_to_check = os.path.join(ORIGINAL_DIR_TO_SHOW, base_filename + ext)
            if os.path.exists(path_to_check):
                original_path = path_to_check
                break

        if original_path is None:
            print(f"Could not find original image for mask: {mask_path}")
            continue

        # Read into arrays inside `with` so the file handles are closed promptly.
        with Image.open(original_path) as im:
            original_image = np.array(im)
        with Image.open(mask_path) as im:
            predicted_mask = np.array(im)

        ax_orig.set_title(f"Original: {base_filename}")
        ax_orig.imshow(original_image)
        ax_mask.set_title("Predicted Mask")
        ax_mask.imshow(predicted_mask, cmap='gray')

    plt.tight_layout()
    plt.show()

In [None]:
import os
import gc
import cv2
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

import albumentations as A
from albumentations.pytorch import ToTensorV2

from sklearn.metrics import accuracy_score, cohen_kappa_score
from sklearn.utils.class_weight import compute_class_weight

import timm

# =============================================================================
# CONFIGURATION FOR 4-CHANNEL (IMAGE + VESSEL MASK) INPUT
# =============================================================================
class CFG:
    """Settings for the 4-channel (RGB fundus image + predicted vessel mask) DR classifier."""

    # Backbone and input geometry.
    MODEL_NAME = 'efficientnet_b3'
    IMG_SIZE = 384
    BATCH_SIZE = 8

    # APTOS images and label CSVs.
    BASE_PATH = "/kaggle/input/aptos2019"
    TRAIN_CSV = os.path.join(BASE_PATH, "train_1.csv")
    VAL_CSV = os.path.join(BASE_PATH, "valid.csv")
    TRAIN_DIR = os.path.join(BASE_PATH, "train_images", "train_images")
    VAL_DIR = os.path.join(BASE_PATH, "val_images", "val_images")

    # Vessel masks produced by the segmentation inference step.
    SEG_BASE_PATH = "/kaggle/working"
    SEG_TRAIN_DIR = os.path.join(SEG_BASE_PATH, "train_images_predicted")
    SEG_VAL_DIR = os.path.join(SEG_BASE_PATH, "val_images_predicted")

    # Two-stage schedule: stage 1 with mixup, stage 2 at a lower LR without it.
    S1_EPOCHS = 15
    S1_LR = 1e-4
    S1_USE_MIXUP = True
    S2_EPOCHS = 15
    S2_LR = 3e-5
    S2_USE_MIXUP = False

    # Runtime, regularization and checkpoint locations.
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    NUM_WORKERS = 2
    PATIENCE = 5
    SEED = 42
    LABEL_SMOOTHING = 0.05
    SAVE_PATH_S1 = "best_model_effnet_b3_seg_stage1.pth"
    SAVE_PATH_FINAL = "best_model_effnet_b3_seg_final.pth"

def seed_everything(seed):
    """Seed Python, NumPy and PyTorch (CPU and all GPUs) RNGs and make cuDNN deterministic.

    Python's ``random`` is seeded too, since augmentation libraries may draw from it.
    ``PYTHONHASHSEED`` only affects Python processes started after this call; it does
    not change string hashing in the running interpreter.

    Args:
        seed: integer seed applied to every RNG.
    """
    import random

    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True picks conv algorithms by timing, which undermines deterministic=True.
    # Disabling it may cost some speed in exchange for reproducibility.
    torch.backends.cudnn.benchmark = False
# Seed all RNGs once, before any data loading or model construction in this cell.
seed_everything(CFG.SEED)

# =============================================================================
# PREPROCESSING & AUGMENTATIONS (CORRECTED LOGIC)
# =============================================================================
def preprocess_ben_graham(image, output_size):
    """Crop an RGB fundus image to the retina, resize it square, and CLAHE-equalize the green channel.

    The crop is the bounding box of the largest external contour of the grayscale image
    thresholded at 15. No crop happens when the mean gray level is below 15, when no
    contour is found, or when any OpenCV step raises; the image is then only resized.

    Args:
        image: RGB uint8 array (H, W, 3).
        output_size: side length of the square output.

    Returns:
        RGB uint8 array of shape (output_size, output_size, 3).
    """
    target = (output_size, output_size)
    try:
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        too_dark = gray.mean() < 15
        if not too_dark:
            _, binary = cv2.threshold(gray, 15, 255, cv2.THRESH_BINARY)
            contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            if contours:
                x, y, w, h = cv2.boundingRect(max(contours, key=cv2.contourArea))
                image = image[y:y+h, x:x+w]
        image = cv2.resize(image, target, interpolation=cv2.INTER_AREA)
    except Exception:
        image = cv2.resize(image, target, interpolation=cv2.INTER_AREA)

    # The image is RGB, so channel order is r, g, b; only the green channel is equalized.
    r, g, b = cv2.split(image)
    g = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(g)
    return cv2.merge((r, g, b))

def get_transforms(is_train=True):
    """Return the training augmentation pipeline, or ``None`` when ``is_train`` is False.

    Only augmentations live here; cropping/resizing and normalization are applied by
    the dataset itself.
    """
    if not is_train:
        return None  # validation/test data is used un-augmented
    return A.Compose([
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.1, rotate_limit=15, p=0.7),
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.7),
        A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=20, val_shift_limit=10, p=0.5),
    ])

# =============================================================================
# UPGRADED DATASET (CORRECTED LOGIC)
# =============================================================================
class Dataset4Channel(Dataset):
    """Yields a 4-channel tensor (RGB fundus image + predicted vessel mask) and its grade.

    The RGB image goes through ``preprocess_ben_graham``, which crops to the retina
    bounding box and resizes to ``CFG.IMG_SIZE``. The vessel mask is predicted on the
    full, uncropped frame. So it is first mapped back to the original image size, cut
    with the same bounding box, and then resized. This keeps both channels pixel-aligned.

    Args:
        df: DataFrame with ``id_code`` and ``diagnosis`` columns.
        img_dir: directory containing ``<id_code>.png`` fundus images.
        seg_dir: directory containing ``<id_code>.png`` predicted vessel masks.
        transform: optional pipeline called as ``transform(image=..., mask=...)``.

    Raises:
        FileNotFoundError: if an image or mask cannot be read.
    """

    # Must match the intensity threshold used inside preprocess_ben_graham.
    _CROP_THRESHOLD = 15

    def __init__(self, df, img_dir, seg_dir, transform=None):
        self.df = df.reset_index(drop=True)
        self.img_dir = img_dir
        self.seg_dir = seg_dir
        self.transform = transform
        # Always applied last: 4-channel normalization (mask channel mapped via mean/std 0.5) + tensor.
        self.post_transform = A.Compose([
            A.Normalize(mean=[0.485, 0.456, 0.406, 0.5], std=[0.229, 0.224, 0.225, 0.5]),
            ToTensorV2()
        ])

    def __len__(self):
        return len(self.df)

    @classmethod
    def _retina_bbox(cls, image):
        """Return ``(x, y, w, h)`` of the retina crop preprocess_ben_graham applies, or None if it does not crop.

        Mirrors preprocess_ben_graham's crop logic; keep the two in sync.
        """
        try:
            gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
            if gray.mean() < cls._CROP_THRESHOLD:
                return None
            _, thresh = cv2.threshold(gray, cls._CROP_THRESHOLD, 255, cv2.THRESH_BINARY)
            contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            if not contours:
                return None
            return cv2.boundingRect(max(contours, key=cv2.contourArea))
        except Exception:
            return None

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = os.path.join(self.img_dir, row['id_code'] + '.png')
        seg_path = os.path.join(self.seg_dir, row['id_code'] + '.png')

        img = cv2.imread(img_path)
        if img is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(seg_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read vessel mask: {seg_path}")

        # Step 1: map the mask back into the original image frame and apply the same retina crop.
        orig_h, orig_w = img.shape[:2]
        mask = cv2.resize(mask, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)
        bbox = self._retina_bbox(img)
        if bbox is not None:
            x, y, w, h = bbox
            mask = mask[y:y+h, x:x+w]

        # Step 2: preprocess the RGB image (crop + resize + CLAHE) and resize the mask to match.
        img = preprocess_ben_graham(img, CFG.IMG_SIZE)
        mask = cv2.resize(mask, (CFG.IMG_SIZE, CFG.IMG_SIZE), interpolation=cv2.INTER_NEAREST)

        # Step 3: augment image and mask together so they stay aligned.
        if self.transform:
            augmented = self.transform(image=img, mask=mask)
            img = augmented['image']
            mask = augmented['mask']

        # Step 4: stack the mask as the 4th channel.
        img_4_channel = np.dstack((img, mask))

        # Step 5: normalize and convert to a CHW tensor.
        img_4_channel = self.post_transform(image=img_4_channel)['image']

        label = torch.tensor(row['diagnosis'], dtype=torch.long)
        return img_4_channel, label

# =============================================================================
# UPGRADED MODEL TO ACCEPT 4 CHANNELS (Unchanged, was already correct)
# =============================================================================
class EfficientNet4ChannelOrdinal(nn.Module):
    """timm backbone with a 4-input-channel stem and an ordinal classification head.

    The stem keeps its pretrained RGB filters. The extra (vessel-mask) input channel is
    initialised with the mean of the RGB filters. The head outputs ``num_classes - 1``
    logits, one per "grade > k" threshold.

    Args:
        model_name: timm model name; the model must expose a ``conv_stem`` convolution.
        num_classes: number of ordinal grades.
        pretrained: load timm pretrained weights for the backbone.
    """

    def __init__(self, model_name, num_classes=5, pretrained=True):
        super().__init__()
        import copy

        self.backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0, global_pool='avg')

        stem = self.backbone.conv_stem
        rgb_weight = stem.weight.detach().clone()  # expected (out_channels, 3, kH, kW)

        # Copy the stem instead of building a fresh nn.Conv2d. This preserves its class
        # (e.g. a same-padding conv), stride, padding, dilation, padding_mode and trained
        # bias. Only the input width changes.
        new_conv = copy.deepcopy(stem)
        new_conv.in_channels = 4
        new_conv.weight = nn.Parameter(
            torch.cat([rgb_weight, rgb_weight.mean(dim=1, keepdim=True)], dim=1)
        )
        self.backbone.conv_stem = new_conv

        feature_dim = self.backbone.num_features
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(feature_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes - 1),
        )

    def forward(self, x):
        return self.classifier(self.backbone(x))

# --- Loss functions, training loops, and other utilities are unchanged ---
class WeightedOrdinalFocalLoss(nn.Module):
    """Focal loss over cumulative ordinal binary targets, with optional per-class weights.

    Grade ``t`` is encoded as ``outputs.shape[1]`` binary targets, where threshold ``k``
    is 1 iff ``k < t``. The focal factor ``(1 - p_t) ** gamma`` is derived from the
    unweighted BCE, so class weights scale the loss without distorting ``p_t``.

    Args:
        num_classes: number of grades (stored for reference).
        gamma: focusing parameter.
        class_weights: optional tensor indexed by class id.
        label_smoothing: pull binary targets toward 0.5 by this fraction.
    """

    def __init__(self, num_classes=5, gamma=2.0, class_weights=None, label_smoothing=0.0):
        super().__init__()
        self.num_classes = num_classes
        self.gamma = gamma
        self.class_weights = class_weights
        self.label_smoothing = label_smoothing
        self.bce = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, outputs, targets):
        n_thresholds = outputs.shape[1]
        thresholds = torch.arange(n_thresholds, device=outputs.device)
        # Vectorized cumulative encoding (replaces the per-sample Python loop).
        ordinal_targets = (thresholds.unsqueeze(0) < targets.unsqueeze(1)).to(outputs.dtype)
        if self.label_smoothing > 0.0:
            ordinal_targets = ordinal_targets * (1.0 - self.label_smoothing) + 0.5 * self.label_smoothing
        bce = self.bce(outputs, ordinal_targets)
        # p_t must come from the raw BCE; using weighted BCE here would distort the focal term.
        pt = torch.exp(-bce)
        focal = (1 - pt) ** self.gamma * bce
        if self.class_weights is not None:
            weights = self.class_weights[targets].view(-1, 1).expand(-1, n_thresholds)
            focal = focal * weights
        return focal.mean()

class SmoothKappaLoss(nn.Module):
    """Differentiable surrogate for ``1 - quadratic weighted kappa`` on ordinal logits.

    Sigmoids of the cumulative logits give P(grade > k). Adjacent differences turn these
    into soft per-class probabilities, which fill a soft confusion matrix that is scored
    with quadratic disagreement weights.

    Args:
        num_classes: number of grades K (``outputs`` carries K - 1 logits).
        eps: probability floor and guard against division by zero.
    """

    def __init__(self, num_classes=5, eps=1e-7):
        super().__init__()
        self.num_classes = num_classes
        self.eps = eps
        # Quadratic penalty (i - j)^2 / (K - 1)^2, computed in Python floats.
        denom = (num_classes - 1) ** 2
        penalty = torch.tensor(
            [[((row - col) ** 2) / denom for col in range(num_classes)] for row in range(num_classes)],
            dtype=torch.get_default_dtype(),
        ).reshape(num_classes, num_classes)
        self.register_buffer("W", penalty)

    def forward(self, outputs, targets):
        device = outputs.device
        batch_size = outputs.size(0)
        cum_probs = torch.sigmoid(outputs)
        k = self.num_classes
        # P(y=0) = 1 - P(y>0); P(y=j) = P(y>j-1) - P(y>j); P(y=K-1) = P(y>K-2)
        per_class = torch.cat(
            [
                1 - cum_probs[:, :1],
                cum_probs[:, :k - 2] - cum_probs[:, 1:k - 1],
                cum_probs[:, -1:],
            ],
            dim=1,
        ).to(torch.get_default_dtype())
        per_class = per_class.clamp(min=self.eps, max=1.0)

        one_hot = F.one_hot(targets, num_classes=self.num_classes).float().to(device)
        observed = torch.matmul(one_hot.T, per_class)
        expected = torch.outer(one_hot.sum(dim=0), per_class.sum(dim=0))

        W = self.W.to(device)
        obs_cost = (W * observed).sum()
        exp_cost = (W * expected).sum()
        # `expected` sums to ~B^2 while `observed` sums to ~B, hence the B factor.
        kappa = 1.0 - (batch_size * obs_cost) / (exp_cost + self.eps)
        return 1.0 - kappa

def mixup_data(x, y, alpha=0.4):
    """Mix each sample in the batch with a randomly permuted partner.

    Args:
        x: batch tensor whose first dimension is the batch.
        y: targets for ``x``.
        alpha: Beta(alpha, alpha) concentration; ``alpha <= 0`` disables mixing (lam = 1).

    Returns:
        ``(mixed_x, y_a, y_b, lam)``, where ``y_a`` is ``y``, ``y_b`` holds the partners'
        targets, and ``lam`` is the mixing coefficient.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1
    perm = torch.randperm(x.size(0)).to(x.device)
    mixed_x = lam * x + (1 - lam) * x[perm, :]
    return mixed_x, y, y[perm], lam

def ordinal_to_class(outputs):
    """Predicted grade = number of cumulative logits whose sigmoid exceeds 0.5 (int64)."""
    exceeded = torch.sigmoid(outputs) > 0.5
    return exceeded.sum(dim=1).long()

def calculate_metrics(outputs, targets):
    """Return ``(accuracy, quadratic_weighted_kappa)`` of decoded ordinal predictions vs. targets."""
    y_pred = ordinal_to_class(outputs).cpu().numpy()
    y_true = targets.cpu().numpy()
    acc = accuracy_score(y_true, y_pred)
    qwk = cohen_kappa_score(y_true, y_pred, weights='quadratic')
    return acc, qwk

def clear_memory():
    """Run Python's garbage collector, then release cached CUDA allocator memory."""
    for release in (gc.collect, torch.cuda.empty_cache):
        release()

def train_epoch(model, loader, optimizer, criterion, scaler, device, use_mixup):
    """Run one optimization pass over ``loader``.

    Args:
        model: network producing ordinal logits.
        loader: iterable of ``(images, targets)`` batches.
        optimizer: optimizer stepped once per batch through ``scaler``.
        criterion: ``criterion(outputs, targets) -> scalar loss``.
        scaler: gradient scaler used for the backward pass and optimizer step.
        device: device (or device string) batches are moved to. Mixed precision is
            enabled only for CUDA devices.
        use_mixup: if True, mix images within each batch and interpolate the loss.

    Returns:
        ``(mean_loss, accuracy, quadratic_kappa)`` for the epoch. Metrics compare
        predictions with the original (un-mixed) targets.
    """
    # torch.autocast replaces the deprecated torch.cuda.amp.autocast; keep AMP CUDA-only.
    use_amp = torch.device(device).type == "cuda"
    model.train()
    running_loss = 0.0
    all_out, all_t = [], []
    pbar = tqdm(loader, desc="Training", leave=False)
    for images, targets in pbar:
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        if use_mixup:
            images, targets_a, targets_b, lam = mixup_data(images, targets)
        with torch.autocast(device_type="cuda", enabled=use_amp):
            outputs = model(images)
            if use_mixup:
                loss = lam * criterion(outputs, targets_a) + (1 - lam) * criterion(outputs, targets_b)
            else:
                loss = criterion(outputs, targets)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        running_loss += loss.item()
        # Collect logits on the CPU so device memory does not grow over the epoch.
        all_out.append(outputs.detach().float().cpu())
        all_t.append(targets.detach().cpu())
        pbar.set_postfix(loss=loss.item())
    all_out, all_t = torch.cat(all_out), torch.cat(all_t)
    return running_loss / len(loader), *calculate_metrics(all_out, all_t)

def validate_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: network producing ordinal logits.
        loader: iterable of ``(images, targets)`` batches.
        criterion: ``criterion(outputs, targets) -> scalar loss``.
        device: device (or device string) batches are moved to. Mixed precision is
            enabled only for CUDA devices.

    Returns:
        ``(mean_loss, accuracy, quadratic_kappa)`` over the loader.
    """
    # torch.autocast replaces the deprecated torch.cuda.amp.autocast; keep AMP CUDA-only.
    use_amp = torch.device(device).type == "cuda"
    model.eval()
    running_loss = 0.0
    all_out, all_t = [], []
    with torch.no_grad():
        pbar = tqdm(loader, desc="Validating", leave=False)
        for images, targets in pbar:
            images = images.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            with torch.autocast(device_type="cuda", enabled=use_amp):
                outputs = model(images)
                loss = criterion(outputs, targets)
            running_loss += loss.item()
            # Collect logits on the CPU so device memory does not grow with the dataset.
            all_out.append(outputs.float().cpu())
            all_t.append(targets.cpu())
    all_out, all_t = torch.cat(all_out), torch.cat(all_t)
    return running_loss / len(loader), *calculate_metrics(all_out, all_t)

def _run_stage(stage_label, model, train_loader, val_loader, optimizer, scheduler, criterion,
               scaler, epochs, use_mixup, save_path, save_message, best_val_qwk):
    """Train/validate for up to ``epochs`` epochs, checkpointing whenever val QWK improves.

    Args:
        stage_label: text used in the early-stopping message (e.g. "Stage 1").
        model, train_loader, val_loader, optimizer, criterion, scaler: training objects.
        scheduler: stepped once per epoch, after validation.
        epochs: maximum number of epochs for this stage.
        use_mixup: forwarded to ``train_epoch``.
        save_path: file that receives ``model.state_dict()`` on every improvement.
        save_message: suffix printed when a checkpoint is written.
        best_val_qwk: QWK this stage must beat before it saves anything.

    Returns:
        ``(best_val_qwk, saved)``: the best QWK seen (the incoming value if nothing beat
        it) and whether this stage wrote ``save_path`` during this run.
    """
    patience_counter, saved = 0, False
    for epoch in range(epochs):
        clear_memory()
        print(f"\nEpoch {epoch+1}/{epochs}")
        train_loss, train_acc, train_qwk = train_epoch(model, train_loader, optimizer, criterion, scaler, CFG.DEVICE, use_mixup)
        val_loss, val_acc, val_qwk = validate_epoch(model, val_loader, criterion, CFG.DEVICE)
        scheduler.step()
        print(f"Train -> Loss:{train_loss:.4f} Acc:{train_acc:.4f} QWK:{train_qwk:.4f}")
        print(f"Valid -> Loss:{val_loss:.4f} Acc:{val_acc:.4f} QWK:{val_qwk:.4f}")
        if val_qwk > best_val_qwk:
            print(f"Val QWK improved from {best_val_qwk:.4f} to {val_qwk:.4f}. {save_message}")
            best_val_qwk, patience_counter, saved = val_qwk, 0, True
            torch.save(model.state_dict(), save_path)
        else:
            patience_counter += 1
            if patience_counter >= CFG.PATIENCE:
                print(f"Early stopping in {stage_label}.")
                break
    return best_val_qwk, saved


def main():
    """Two-stage training of the 4-channel ordinal model.

    Stage 1 trains with the ordinal focal loss (optionally with mixup). Stage 2 resumes
    from the best Stage-1 weights with a 0.7 kappa / 0.3 focal hybrid loss at a lower LR.
    ``CFG.SAVE_PATH_FINAL`` always ends up holding the best weights seen across both
    stages, provided at least one checkpoint was saved.
    """
    print(f"Device: {CFG.DEVICE}, Model: {CFG.MODEL_NAME} (4-Channel), Image Size: {CFG.IMG_SIZE}")
    train_df = pd.read_csv(CFG.TRAIN_CSV)
    val_df = pd.read_csv(CFG.VAL_CSV)

    train_tf = get_transforms(is_train=True)
    val_tf = get_transforms(is_train=False)

    train_ds = Dataset4Channel(train_df, CFG.TRAIN_DIR, CFG.SEG_TRAIN_DIR, transform=train_tf)
    val_ds   = Dataset4Channel(val_df, CFG.VAL_DIR, CFG.SEG_VAL_DIR, transform=val_tf)

    # NOTE(review): classes are re-balanced twice, by the sampler here and by the loss
    # weights below, so minority classes are effectively over-weighted. Confirm this is intended.
    class_weights_sampler = compute_class_weight('balanced', classes=np.unique(train_df['diagnosis']), y=train_df['diagnosis'])
    # Indexing by int(label) assumes labels are exactly 0..K-1 with every class present.
    sample_weights = np.array([class_weights_sampler[int(l)] for l in train_df['diagnosis']])
    sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)
    train_loader = DataLoader(train_ds, batch_size=CFG.BATCH_SIZE, sampler=sampler, num_workers=CFG.NUM_WORKERS, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=CFG.BATCH_SIZE*2, shuffle=False, num_workers=CFG.NUM_WORKERS, pin_memory=True)

    model = EfficientNet4ChannelOrdinal(CFG.MODEL_NAME).to(CFG.DEVICE)
    class_weights_loss = torch.tensor(class_weights_sampler, dtype=torch.float).to(CFG.DEVICE)
    focal_loss = WeightedOrdinalFocalLoss(num_classes=5, gamma=2.0, class_weights=class_weights_loss, label_smoothing=CFG.LABEL_SMOOTHING)
    kappa_loss = SmoothKappaLoss(num_classes=5)

    def hybrid_loss(outputs, targets):
        """Stage-2 objective: 0.7 * soft-kappa loss + 0.3 * ordinal focal loss."""
        return 0.7 * kappa_loss(outputs, targets) + 0.3 * focal_loss(outputs, targets)

    scaler = torch.cuda.amp.GradScaler()

    # --- STAGE 1 ---
    print("\n" + "="*50 + "\n     STARTING STAGE 1 (4-Channel)\n" + "="*50)
    opt = optim.AdamW(model.parameters(), lr=CFG.S1_LR, weight_decay=1e-4)
    sched = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=CFG.S1_EPOCHS)
    best_val_qwk, saved_s1 = _run_stage(
        "Stage 1", model, train_loader, val_loader, opt, sched, focal_loss, scaler,
        CFG.S1_EPOCHS, CFG.S1_USE_MIXUP, CFG.SAVE_PATH_S1, "Saving model...", -1)

    # --- STAGE 2 ---
    print("\n" + "="*50 + "\n     STARTING STAGE 2 (4-Channel)\n" + "="*50)
    # Trust this run's flag, not os.path.exists: a stale checkpoint from an earlier run
    # (possibly another architecture) must never be loaded.
    if saved_s1:
        model.load_state_dict(torch.load(CFG.SAVE_PATH_S1, map_location=CFG.DEVICE))
    else:
        print("No Stage 1 model was saved. Continuing with the current model.")

    opt = optim.AdamW(model.parameters(), lr=CFG.S2_LR, weight_decay=1e-5)
    sched = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=CFG.S2_EPOCHS)
    best_val_qwk_stage2, saved_final = _run_stage(
        "Stage 2", model, train_loader, val_loader, opt, sched, hybrid_loss, scaler,
        CFG.S2_EPOCHS, CFG.S2_USE_MIXUP, CFG.SAVE_PATH_FINAL, "Saving final model...", best_val_qwk)

    # Stage 2 only saves when it beats Stage 1's QWK. Without this fallback the final
    # checkpoint would be missing even though a best QWK is reported.
    if not saved_final:
        if saved_s1:
            print("Stage 2 did not improve on Stage 1. Saving the Stage 1 weights as the final model...")
            model.load_state_dict(torch.load(CFG.SAVE_PATH_S1, map_location=CFG.DEVICE))
            torch.save(model.state_dict(), CFG.SAVE_PATH_FINAL)
        else:
            print("Warning: no checkpoint was saved in either stage.")

    print(f"\nTraining Finished!\nFinal Best QWK: {best_val_qwk_stage2:.4f}")

# Start the two-stage training when this code runs as the main module
# (a notebook kernel executes cells with __name__ == "__main__").
if __name__ == "__main__":
    main()



In [None]:
import os
import gc
import cv2
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

import albumentations as A
from albumentations.pytorch import ToTensorV2

from sklearn.metrics import accuracy_score, cohen_kappa_score
from sklearn.utils.class_weight import compute_class_weight

import timm

# =============================================================================
# CONFIGURATION FOR 4-CHANNEL (IMAGE + VESSEL MASK) INPUT
# =============================================================================
class CFG:
    """Settings for the 4-channel (RGB fundus + vessel mask) EfficientNet-B4 run.

    Used purely as a namespace: every setting is a class attribute read as ``CFG.NAME``.
    """

    # Model and input size (backbone upgraded to B4)
    MODEL_NAME = 'efficientnet_b4'
    IMG_SIZE = 384
    BATCH_SIZE = 8

    # APTOS 2019 label CSVs and image folders (read-only Kaggle input)
    BASE_PATH = "/kaggle/input/aptos2019"
    TRAIN_CSV = os.path.join(BASE_PATH, "train_1.csv")
    VAL_CSV = os.path.join(BASE_PATH, "valid.csv")
    TRAIN_DIR = os.path.join(BASE_PATH, "train_images", "train_images")
    VAL_DIR = os.path.join(BASE_PATH, "val_images", "val_images")

    # Predicted vessel masks, one <id_code>.png per image
    SEG_BASE_PATH = "/kaggle/working"
    SEG_TRAIN_DIR = os.path.join(SEG_BASE_PATH, "train_images_predicted")
    SEG_VAL_DIR = os.path.join(SEG_BASE_PATH, "val_images_predicted")

    # Stage 1 (focal loss, mixup on)
    S1_EPOCHS = 15
    S1_LR = 1e-4
    S1_USE_MIXUP = True
    # Stage 2 (kappa/focal hybrid, lower LR, mixup off)
    S2_EPOCHS = 15
    S2_LR = 3e-5
    S2_USE_MIXUP = False

    # Runtime, regularisation and early stopping
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    NUM_WORKERS = 2
    PATIENCE = 5  # epochs without val-QWK improvement before stopping a stage
    SEED = 42
    LABEL_SMOOTHING = 0.05

    # Checkpoint files for this model
    SAVE_PATH_S1 = "best_model_effnet_b4_seg_stage1.pth"
    SAVE_PATH_FINAL = "best_model_effnet_b4_seg_final.pth"

def seed_everything(seed):
    """Seed every RNG the pipeline draws from and make cuDNN deterministic.

    Seeds Python's ``random`` (augmentation libraries may draw from it), NumPy's global
    RNG, and torch on the CPU and all CUDA devices. ``cudnn.benchmark`` is switched OFF:
    benchmark autotuning may pick different kernels per run, which defeats
    ``cudnn.deterministic``. Setting PYTHONHASHSEED here only affects interpreters
    started afterwards, not the current one.
    """
    import random

    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
# Seed all RNGs once, as soon as this cell's definitions are executed.
seed_everything(CFG.SEED)

# =============================================================================
# PREPROCESSING & AUGMENTATIONS
# =============================================================================
def preprocess_ben_graham(image, output_size):
    """Crop to the fundus, resize to a square, and CLAHE-enhance the middle (green) channel.

    Despite the name, no Gaussian-blur subtraction is applied. The only enhancement is
    CLAHE on channel 1, which is green for both RGB and BGR layouts.

    Args:
        image: HxWx3 uint8 array (the dataset passes RGB).
        output_size: side length of the returned square image.

    Returns:
        ``output_size x output_size x 3`` array with the same channel order as the input.
    """
    target_hw = (output_size, output_size)
    try:
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        # Very dark images are only resized. Otherwise crop to the bounding box of the
        # largest region brighter than 15 (the visible fundus disc).
        if gray.mean() >= 15:
            _, fundus_mask = cv2.threshold(gray, 15, 255, cv2.THRESH_BINARY)
            contours, _ = cv2.findContours(fundus_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            if contours:
                x, y, w, h = cv2.boundingRect(max(contours, key=cv2.contourArea))
                image = image[y:y + h, x:x + w]
        image = cv2.resize(image, target_hw, interpolation=cv2.INTER_AREA)
    except Exception:
        # Best-effort: fall back to a plain resize of whatever we have so far.
        image = cv2.resize(image, target_hw, interpolation=cv2.INTER_AREA)
    first, green, third = cv2.split(image)
    green = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(green)
    return cv2.merge((first, green, third))

def get_transforms(is_train=True):
    """Return the training augmentation pipeline, or ``None`` when ``is_train`` is False.

    The dataset calls the pipeline as ``transform(image=..., mask=...)``. Normalisation
    and tensor conversion are done separately by ``Dataset4Channel.post_transform``.
    """
    if not is_train:
        return None  # validation: no augmentation
    train_augs = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.1, rotate_limit=15, p=0.7),
        A.GridDistortion(p=0.3),
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.7),
        A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=20, val_shift_limit=10, p=0.5),
    ]
    return A.Compose(train_augs)

# =============================================================================
# DATASET - ENHANCED FOR ROBUSTNESS
# =============================================================================
class Dataset4Channel(Dataset):
    """Fundus image plus predicted vessel mask, stacked into one 4-channel tensor.

    ``__getitem__`` returns ``(tensor, label)``. ``tensor`` is the ToTensorV2 output of
    ``dstack((rgb, mask))`` (4 channels, CFG.IMG_SIZE square). RGB is normalised with
    ImageNet stats and the mask with mean=std=0.5 (a 0/255 mask maps to [-1, 1],
    assuming albumentations' default max_pixel_value=255). Missing or unreadable files
    are replaced by black images/masks with a printed warning, so one bad file does not
    abort an epoch.

    Args:
        df: DataFrame with 'id_code' (str) and 'diagnosis' (int label) columns.
        img_dir: folder holding ``<id_code>.png|.jpeg|.jpg`` images.
        seg_dir: folder holding ``<id_code>.png`` masks.
        transform: optional albumentations pipeline applied jointly to image and mask.
    """

    # Image extensions probed, in this order, for each id_code.
    IMG_EXTENSIONS = ('.png', '.jpeg', '.jpg')

    def __init__(self, df, img_dir, seg_dir, transform=None):
        self.df = df.reset_index(drop=True)
        self.img_dir = img_dir
        self.seg_dir = seg_dir
        self.transform = transform
        # Per-channel normalisation: ImageNet stats for RGB, 0.5/0.5 for the mask channel.
        self.post_transform = A.Compose([
            A.Normalize(mean=[0.485, 0.456, 0.406, 0.5], std=[0.229, 0.224, 0.225, 0.5]),
            ToTensorV2()
        ])

    def __len__(self):
        return len(self.df)

    def _find_image_path(self, id_code):
        """Return the first existing ``img_dir/id_code + ext``, or None if none exists."""
        base_path = os.path.join(self.img_dir, id_code)
        for ext in self.IMG_EXTENSIONS:
            candidate = base_path + ext
            if os.path.exists(candidate):
                return candidate
        return None

    def _load_image(self, id_code):
        """Load the image as RGB uint8, or a black IMG_SIZE square if missing/unreadable."""
        img_path = self._find_image_path(id_code)
        if img_path is None:
            print(f"Warning: Image not found for {id_code} in {self.img_dir}, using black image.")
            return np.zeros((CFG.IMG_SIZE, CFG.IMG_SIZE, 3), dtype=np.uint8)
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports corrupt/unreadable files by returning None, not raising.
            print(f"Warning: Could not read image at {img_path}, using black image.")
            return np.zeros((CFG.IMG_SIZE, CFG.IMG_SIZE, 3), dtype=np.uint8)
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def _load_mask(self, id_code):
        """Load the grayscale mask, or a black IMG_SIZE square if missing/unreadable."""
        seg_path = os.path.join(self.seg_dir, id_code + '.png')
        if not os.path.exists(seg_path):
            print(f"Warning: Mask not found at {seg_path}, using black mask.")
            return np.zeros((CFG.IMG_SIZE, CFG.IMG_SIZE), dtype=np.uint8)
        mask = cv2.imread(seg_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            print(f"Warning: Could not read mask at {seg_path}, using black mask.")
            return np.zeros((CFG.IMG_SIZE, CFG.IMG_SIZE), dtype=np.uint8)
        return mask

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img = self._load_image(row['id_code'])
        mask = self._load_mask(row['id_code'])

        img = preprocess_ben_graham(img, CFG.IMG_SIZE)
        # Nearest-neighbour keeps the mask's original discrete values.
        mask = cv2.resize(mask, (CFG.IMG_SIZE, CFG.IMG_SIZE), interpolation=cv2.INTER_NEAREST)

        if self.transform:
            augmented = self.transform(image=img, mask=mask)
            img = augmented['image']
            mask = augmented['mask']

        img_4_channel = np.dstack((img, mask))
        img_4_channel = self.post_transform(image=img_4_channel)['image']
        label = torch.tensor(row['diagnosis'], dtype=torch.long)
        return img_4_channel, label

# =============================================================================
# MODEL DEFINITION
# =============================================================================
class EfficientNet4ChannelOrdinal(nn.Module):
    """timm backbone with a 4-input-channel stem and an ordinal (cumulative-threshold) head.

    The pretrained RGB stem is widened to take a 4th (vessel-mask) channel. The RGB
    kernels (and bias, if any) are copied unchanged, and the new channel's kernel is
    initialised with their mean across input channels. The head emits
    ``num_classes - 1`` logits, one per threshold.

    Assumes the backbone exposes ``conv_stem`` whose behaviour is fully described by
    kernel_size/stride/padding/dilation (a plain ``nn.Conv2d``). TODO confirm before
    switching to a variant that uses dynamic "same" padding.
    """

    def __init__(self, model_name, num_classes=5, pretrained=True):
        super().__init__()
        self.backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0, global_pool='avg')
        self.backbone.conv_stem = self._make_4ch_stem(self.backbone.conv_stem)
        feature_dim = self.backbone.num_features
        self.classifier = nn.Sequential(nn.Dropout(0.5), nn.Linear(feature_dim, 256), nn.ReLU(), nn.Dropout(0.3), nn.Linear(256, num_classes - 1))

    @staticmethod
    def _make_4ch_stem(original_conv):
        """Return a 4-input-channel conv carrying over ``original_conv``'s pretrained parameters."""
        new_conv = nn.Conv2d(4, original_conv.out_channels,
                             kernel_size=original_conv.kernel_size, stride=original_conv.stride,
                             padding=original_conv.padding, dilation=original_conv.dilation,
                             bias=(original_conv.bias is not None))
        with torch.no_grad():
            new_conv.weight[:, :3] = original_conv.weight
            new_conv.weight[:, 3] = original_conv.weight.mean(dim=1)
            if original_conv.bias is not None:
                # Keep the pretrained bias instead of leaving the fresh random init.
                new_conv.bias.copy_(original_conv.bias)
        return new_conv

    def forward(self, x):
        """Map a 4-channel image batch to ``num_classes - 1`` threshold logits per sample."""
        return self.classifier(self.backbone(x))

# =============================================================================
# UTILITY FUNCTIONS (Losses, Metrics, etc.)
# =============================================================================
# Loss functions, mixup and metric helpers shared by both training stages.
class WeightedOrdinalFocalLoss(nn.Module):
    """Focal loss over cumulative ("ordinal") binary targets, with optional per-class weights.

    A label ``t`` becomes one binary target per output column, where column ``k`` is 1
    iff ``t > k`` (e.g. t=2 with 4 columns -> [1, 1, 0, 0]).

    Args:
        num_classes: number of ordinal classes (stored; the encoding follows
            ``outputs.shape[1]``, expected to be ``num_classes - 1``).
        gamma: focal exponent; 0 reduces the loss to (weighted) mean BCE.
        class_weights: optional 1-D tensor indexed by the integer label; it scales every
            threshold term of that sample. Must be on the same device as ``targets``.
        label_smoothing: pulls the binary targets towards 0.5 by this fraction.
    """

    def __init__(self, num_classes=5, gamma=2.0, class_weights=None, label_smoothing=0.0):
        super().__init__()
        self.num_classes = num_classes
        self.gamma = gamma
        self.class_weights = class_weights
        self.label_smoothing = label_smoothing
        self.bce = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, outputs, targets):
        """Return the mean focal loss for logits ``outputs`` (B, T) and integer ``targets`` (B,)."""
        # Vectorised cumulative encoding: column k is 1 where target > k.
        thresholds = torch.arange(outputs.shape[1], device=outputs.device)
        ordinal_targets = (targets.reshape(-1, 1) > thresholds).to(outputs.dtype)
        if self.label_smoothing > 0.0:
            ordinal_targets = ordinal_targets * (1.0 - self.label_smoothing) + 0.5 * self.label_smoothing
        bce = self.bce(outputs, ordinal_targets)
        # The focal factor must come from the *unweighted* BCE. Weighting first would turn
        # exp(-bce) into p**w and distort the modulation for every non-unit class weight.
        pt = torch.exp(-bce)
        focal = (1 - pt) ** self.gamma * bce
        if self.class_weights is not None:
            weights = self.class_weights[targets].view(-1, 1).expand(-1, outputs.shape[1])
            focal = focal * weights
        return focal.mean()

class SmoothKappaLoss(nn.Module):
    """Differentiable ``1 - quadratic weighted kappa`` computed from ordinal threshold logits.

    Threshold probabilities ``sigmoid(outputs)`` are differenced into soft per-class
    probabilities (clamped to ``[eps, 1]``). Those fill a soft confusion matrix, which is
    compared against the product of marginals using quadratic penalties ``(i-j)^2/(K-1)^2``.
    """

    def __init__(self, num_classes=5, eps=1e-7):
        super().__init__()
        self.num_classes = num_classes
        self.eps = eps
        denom = (num_classes - 1) ** 2
        penalty = torch.tensor(
            [[((i - j) ** 2) / denom for j in range(num_classes)] for i in range(num_classes)]
        ).reshape(num_classes, num_classes)
        self.register_buffer("W", penalty)

    def forward(self, outputs, targets):
        """Return the scalar loss for logits ``outputs`` and integer ``targets``."""
        device = outputs.device
        batch_size = outputs.size(0)
        exceed = torch.sigmoid(outputs)
        columns = [1 - exceed[:, 0]]
        columns += [exceed[:, k - 1] - exceed[:, k] for k in range(1, self.num_classes - 1)]
        columns.append(exceed[:, -1])
        per_class = torch.stack(columns, dim=1).to(torch.get_default_dtype())
        per_class = per_class.clamp(min=self.eps, max=1.0)
        one_hot = F.one_hot(targets, num_classes=self.num_classes).float().to(device)
        observed = one_hot.T @ per_class
        expected = torch.outer(one_hot.sum(dim=0), per_class.sum(dim=0))
        penalty = self.W.to(device)
        weighted_obs = (penalty * observed).sum()
        weighted_exp = (penalty * expected).sum()
        kappa = 1.0 - (batch_size * weighted_obs) / (weighted_exp + self.eps)
        return 1.0 - kappa

def mixup_data(x, y, alpha=0.4):
    """Blend every sample with a random partner from the same batch (mixup).

    Args:
        x: batch tensor whose first dimension is the batch.
        y: labels aligned with ``x``.
        alpha: Beta(alpha, alpha) parameter; ``alpha <= 0`` disables mixing (lam = 1).

    Returns:
        ``(mixed_x, y, y[partner], lam)`` with ``mixed_x = lam * x + (1 - lam) * x[partner]``.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1
    partner = torch.randperm(x.size(0)).to(x.device)
    mixed_x = lam * x + (1 - lam) * x[partner, :]
    return mixed_x, y, y[partner], lam

def ordinal_to_class(outputs):
    """Predicted class = number of thresholds whose sigmoid probability exceeds 0.5."""
    exceeded = torch.sigmoid(outputs) > 0.5
    return exceeded.sum(dim=1).long()

def calculate_metrics(outputs, targets):
    """Accuracy and quadratic-weighted Cohen's kappa of ordinal logits vs. integer labels.

    Returns:
        ``(accuracy, qwk)`` computed by scikit-learn on CPU numpy copies.
    NOTE(review): kappa can be undefined (NaN) for degenerate label sets; confirm callers tolerate it.
    """
    y_true = targets.cpu().numpy()
    y_pred = ordinal_to_class(outputs).cpu().numpy()
    accuracy = accuracy_score(y_true, y_pred)
    qwk = cohen_kappa_score(y_true, y_pred, weights='quadratic')
    return accuracy, qwk

def clear_memory():
    """Collect Python garbage, then release PyTorch's cached CUDA memory."""
    gc.collect()
    torch.cuda.empty_cache()

# =============================================================================
# TRAINING & VALIDATION FUNCTIONS
# =============================================================================
def train_epoch(model, loader, optimizer, criterion, scaler, device, use_mixup):
    """Run one optimisation pass over ``loader`` using AMP autocast and a GradScaler.

    Args:
        model: network in training mode for this pass.
        loader: iterable of ``(images, labels)`` batches.
        optimizer / scaler: optimiser and AMP gradient scaler stepped once per batch.
        criterion: ``criterion(logits, labels) -> scalar loss``.
        device: device the batches are moved to.
        use_mixup: if True, inputs are mixed with ``mixup_data`` and the loss is the
            lam-weighted sum over both label sets.

    Returns:
        ``(mean_batch_loss, accuracy, qwk)``. With mixup, metrics compare predictions on
        mixed images against the un-permuted labels, so they are only indicative.
    """
    model.train()
    loss_sum = 0.0
    logits_per_batch, labels_per_batch = [], []
    progress = tqdm(loader, desc="Training", leave=False)
    for batch_images, batch_labels in progress:
        batch_images = batch_images.to(device, non_blocking=True)
        batch_labels = batch_labels.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        if use_mixup:
            batch_images, labels_a, labels_b, lam = mixup_data(batch_images, batch_labels)
        with torch.cuda.amp.autocast():
            logits = model(batch_images)
            if use_mixup:
                loss = lam * criterion(logits, labels_a) + (1 - lam) * criterion(logits, labels_b)
            else:
                loss = criterion(logits, batch_labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        batch_loss = loss.item()
        loss_sum += batch_loss
        logits_per_batch.append(logits.detach())
        labels_per_batch.append(batch_labels.detach())
        progress.set_postfix(loss=batch_loss)
    epoch_logits = torch.cat(logits_per_batch)
    epoch_labels = torch.cat(labels_per_batch)
    return loss_sum / len(loader), *calculate_metrics(epoch_logits, epoch_labels)

def validate_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` in eval mode, without gradients, under AMP autocast.

    Returns:
        ``(mean_batch_loss, accuracy, qwk)`` over the whole loader.
    """
    model.eval()
    loss_sum = 0.0
    logits_per_batch, labels_per_batch = [], []
    with torch.no_grad():
        for batch_images, batch_labels in tqdm(loader, desc="Validating", leave=False):
            batch_images = batch_images.to(device, non_blocking=True)
            batch_labels = batch_labels.to(device, non_blocking=True)
            with torch.cuda.amp.autocast():
                logits = model(batch_images)
                loss = criterion(logits, batch_labels)
            loss_sum += loss.item()
            logits_per_batch.append(logits)
            labels_per_batch.append(batch_labels)
    epoch_logits = torch.cat(logits_per_batch)
    epoch_labels = torch.cat(labels_per_batch)
    return loss_sum / len(loader), *calculate_metrics(epoch_logits, epoch_labels)

# =============================================================================
# MAIN TRAINING PIPELINE
# =============================================================================
def _run_stage(stage_label, model, train_loader, val_loader, optimizer, scheduler, criterion,
               scaler, epochs, use_mixup, save_path, save_message, best_val_qwk):
    """Train/validate for up to ``epochs`` epochs, checkpointing whenever val QWK improves.

    Args:
        stage_label: text used in the early-stopping message (e.g. "Stage 1").
        model, train_loader, val_loader, optimizer, criterion, scaler: training objects.
        scheduler: stepped once per epoch, after validation.
        epochs: maximum number of epochs for this stage.
        use_mixup: forwarded to ``train_epoch``.
        save_path: file that receives ``model.state_dict()`` on every improvement.
        save_message: suffix printed when a checkpoint is written.
        best_val_qwk: QWK this stage must beat before it saves anything.

    Returns:
        ``(best_val_qwk, saved)``: the best QWK seen (the incoming value if nothing beat
        it) and whether this stage wrote ``save_path`` during this run.
    """
    patience_counter, saved = 0, False
    for epoch in range(epochs):
        clear_memory()
        print(f"\nEpoch {epoch+1}/{epochs}")
        train_loss, train_acc, train_qwk = train_epoch(model, train_loader, optimizer, criterion, scaler, CFG.DEVICE, use_mixup)
        val_loss, val_acc, val_qwk = validate_epoch(model, val_loader, criterion, CFG.DEVICE)
        scheduler.step()
        print(f"Train -> Loss:{train_loss:.4f} Acc:{train_acc:.4f} QWK:{train_qwk:.4f}")
        print(f"Valid -> Loss:{val_loss:.4f} Acc:{val_acc:.4f} QWK:{val_qwk:.4f}")
        if val_qwk > best_val_qwk:
            print(f"Val QWK improved from {best_val_qwk:.4f} to {val_qwk:.4f}. {save_message}")
            best_val_qwk, patience_counter, saved = val_qwk, 0, True
            torch.save(model.state_dict(), save_path)
        else:
            patience_counter += 1
            if patience_counter >= CFG.PATIENCE:
                print(f"Early stopping in {stage_label}.")
                break
    return best_val_qwk, saved


def main():
    """Two-stage training of the 4-channel ordinal model.

    Stage 1 trains with the ordinal focal loss (optionally with mixup). Stage 2 resumes
    from the best Stage-1 weights with a 0.7 kappa / 0.3 focal hybrid loss at a lower LR.
    ``CFG.SAVE_PATH_FINAL`` always ends up holding the best weights seen across both
    stages, provided at least one checkpoint was saved.
    """
    print(f"Device: {CFG.DEVICE}, Model: {CFG.MODEL_NAME} (4-Channel), Image Size: {CFG.IMG_SIZE}")
    train_df, val_df = pd.read_csv(CFG.TRAIN_CSV), pd.read_csv(CFG.VAL_CSV)

    train_ds = Dataset4Channel(train_df, CFG.TRAIN_DIR, CFG.SEG_TRAIN_DIR, transform=get_transforms(is_train=True))
    val_ds   = Dataset4Channel(val_df, CFG.VAL_DIR, CFG.SEG_VAL_DIR, transform=get_transforms(is_train=False))

    # NOTE(review): classes are re-balanced twice, by the sampler here and by the loss
    # weights below, so minority classes are effectively over-weighted. Confirm this is intended.
    class_weights = compute_class_weight('balanced', classes=np.unique(train_df['diagnosis']), y=train_df['diagnosis'])
    # Indexing by int(label) assumes labels are exactly 0..K-1 with every class present.
    sample_weights = np.array([class_weights[int(l)] for l in train_df['diagnosis']])
    sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)
    train_loader = DataLoader(train_ds, batch_size=CFG.BATCH_SIZE, sampler=sampler, num_workers=CFG.NUM_WORKERS, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=CFG.BATCH_SIZE*2, shuffle=False, num_workers=CFG.NUM_WORKERS, pin_memory=True)

    model = EfficientNet4ChannelOrdinal(CFG.MODEL_NAME).to(CFG.DEVICE)
    class_weights_loss = torch.tensor(class_weights, dtype=torch.float).to(CFG.DEVICE)
    focal_loss = WeightedOrdinalFocalLoss(num_classes=5, gamma=2.0, class_weights=class_weights_loss, label_smoothing=CFG.LABEL_SMOOTHING)
    kappa_loss = SmoothKappaLoss(num_classes=5)

    def hybrid_loss(outputs, targets):
        """Stage-2 objective: 0.7 * soft-kappa loss + 0.3 * ordinal focal loss."""
        return 0.7 * kappa_loss(outputs, targets) + 0.3 * focal_loss(outputs, targets)

    scaler = torch.cuda.amp.GradScaler()

    # --- STAGE 1 ---
    print("\n" + "="*50 + "\n     STARTING STAGE 1 (4-Channel)\n" + "="*50)
    opt = optim.AdamW(model.parameters(), lr=CFG.S1_LR, weight_decay=1e-4)
    # Warm restarts with T_0=10: the LR jumps back up every 10 epochs, i.e. mid-stage
    # whenever a stage runs longer than 10 epochs.
    sched = optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=10, T_mult=1, eta_min=1e-6)
    best_val_qwk, saved_s1 = _run_stage(
        "Stage 1", model, train_loader, val_loader, opt, sched, focal_loss, scaler,
        CFG.S1_EPOCHS, CFG.S1_USE_MIXUP, CFG.SAVE_PATH_S1, "Saving model...", -1)

    # --- STAGE 2 ---
    print("\n" + "="*50 + "\n     STARTING STAGE 2 (4-Channel)\n" + "="*50)
    # Trust this run's flag, not os.path.exists: a stale checkpoint from an earlier run
    # (possibly another architecture) must never be loaded.
    if saved_s1:
        model.load_state_dict(torch.load(CFG.SAVE_PATH_S1, map_location=CFG.DEVICE))
    else:
        print("No Stage 1 model was saved. Continuing with the current model.")

    opt = optim.AdamW(model.parameters(), lr=CFG.S2_LR, weight_decay=1e-5)
    sched = optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=10, T_mult=1, eta_min=3e-6)
    best_val_qwk_stage2, saved_final = _run_stage(
        "Stage 2", model, train_loader, val_loader, opt, sched, hybrid_loss, scaler,
        CFG.S2_EPOCHS, CFG.S2_USE_MIXUP, CFG.SAVE_PATH_FINAL, "Saving final model...", best_val_qwk)

    # Stage 2 only saves when it beats Stage 1's QWK. Without this fallback the final
    # checkpoint would be missing even though a best QWK is reported.
    if not saved_final:
        if saved_s1:
            print("Stage 2 did not improve on Stage 1. Saving the Stage 1 weights as the final model...")
            model.load_state_dict(torch.load(CFG.SAVE_PATH_S1, map_location=CFG.DEVICE))
            torch.save(model.state_dict(), CFG.SAVE_PATH_FINAL)
        else:
            print("Warning: no checkpoint was saved in either stage.")

    print(f"\nTraining Finished!\nFinal Best QWK: {best_val_qwk_stage2:.4f}")

# Start the two-stage training when this code runs as the main module
# (a notebook kernel executes cells with __name__ == "__main__").
if __name__ == "__main__":
    main()

In [None]:
import os
import gc
import cv2
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

import albumentations as A
from albumentations.pytorch import ToTensorV2

from sklearn.metrics import accuracy_score, cohen_kappa_score
from sklearn.utils.class_weight import compute_class_weight

import timm

# =============================================================================
# CONFIGURATION FOR 4-CHANNEL (IMAGE + VESSEL MASK) INPUT
# =============================================================================
class CFG:
    """Settings for the 4-channel (RGB fundus + vessel mask) EfficientNet-B5 run.

    Used purely as a namespace: every setting is a class attribute read as ``CFG.NAME``.
    """

    # Model and input size (backbone upgraded to B5)
    MODEL_NAME = 'efficientnet_b5'
    IMG_SIZE = 384
    BATCH_SIZE = 8  # lower to 4 if B5 runs out of GPU memory

    # APTOS 2019 label CSVs and image folders (read-only Kaggle input)
    BASE_PATH = "/kaggle/input/aptos2019"
    TRAIN_CSV = os.path.join(BASE_PATH, "train_1.csv")
    VAL_CSV = os.path.join(BASE_PATH, "valid.csv")
    TRAIN_DIR = os.path.join(BASE_PATH, "train_images", "train_images")
    VAL_DIR = os.path.join(BASE_PATH, "val_images", "val_images")

    # Predicted vessel masks, one <id_code>.png per image
    SEG_BASE_PATH = "/kaggle/working"
    SEG_TRAIN_DIR = os.path.join(SEG_BASE_PATH, "train_images_predicted")
    SEG_VAL_DIR = os.path.join(SEG_BASE_PATH, "val_images_predicted")

    # Longer schedule for the larger model.
    # Stage 1: mixup on
    S1_EPOCHS = 20
    S1_LR = 1e-4
    S1_USE_MIXUP = True
    # Stage 2: lower LR, mixup off
    S2_EPOCHS = 20
    S2_LR = 3e-5
    S2_USE_MIXUP = False

    # Runtime, regularisation and early stopping
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    NUM_WORKERS = 2
    PATIENCE = 5
    SEED = 42
    LABEL_SMOOTHING = 0.05

    # Checkpoint files for this model
    SAVE_PATH_S1 = "best_model_effnet_b5_seg_stage1.pth"
    SAVE_PATH_FINAL = "best_model_effnet_b5_seg_final.pth"

def seed_everything(seed):
    """Seed every RNG the pipeline draws from and make cuDNN deterministic.

    Seeds Python's ``random`` (augmentation libraries may draw from it), NumPy's global
    RNG, and torch on the CPU and all CUDA devices. ``cudnn.benchmark`` is switched OFF:
    benchmark autotuning may pick different kernels per run, which defeats
    ``cudnn.deterministic``. Setting PYTHONHASHSEED here only affects interpreters
    started afterwards, not the current one.
    """
    import random

    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
# Seed all RNGs once, as soon as this cell's definitions are executed.
seed_everything(CFG.SEED)

# =============================================================================
# PREPROCESSING & AUGMENTATIONS
# =============================================================================
def preprocess_ben_graham(image, output_size):
    """Crop to the fundus, resize to a square, and CLAHE-enhance the middle (green) channel.

    Despite the name, no Gaussian-blur subtraction is applied. The only enhancement is
    CLAHE on channel 1, which is green for both RGB and BGR layouts.

    Args:
        image: HxWx3 uint8 array (the dataset passes RGB).
        output_size: side length of the returned square image.

    Returns:
        ``output_size x output_size x 3`` array with the same channel order as the input.
    """
    target_hw = (output_size, output_size)
    try:
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        # Very dark images are only resized. Otherwise crop to the bounding box of the
        # largest region brighter than 15 (the visible fundus disc).
        if gray.mean() >= 15:
            _, fundus_mask = cv2.threshold(gray, 15, 255, cv2.THRESH_BINARY)
            contours, _ = cv2.findContours(fundus_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            if contours:
                x, y, w, h = cv2.boundingRect(max(contours, key=cv2.contourArea))
                image = image[y:y + h, x:x + w]
        image = cv2.resize(image, target_hw, interpolation=cv2.INTER_AREA)
    except Exception:
        # Best-effort: fall back to a plain resize of whatever we have so far.
        image = cv2.resize(image, target_hw, interpolation=cv2.INTER_AREA)
    first, green, third = cv2.split(image)
    green = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(green)
    return cv2.merge((first, green, third))

def get_transforms(is_train=True):
    """Return the training augmentation pipeline, or ``None`` when ``is_train`` is False.

    The dataset calls the pipeline as ``transform(image=..., mask=...)``. Normalisation
    and tensor conversion are done separately by ``Dataset4Channel.post_transform``.
    """
    if not is_train:
        return None  # validation: no augmentation
    train_augs = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.1, rotate_limit=15, p=0.7),
        A.GridDistortion(p=0.3),
        A.ElasticTransform(p=0.3, alpha=120, sigma=120 * 0.05),
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.7),
        A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=20, val_shift_limit=10, p=0.5),
    ]
    return A.Compose(train_augs)

# =============================================================================
# DATASET - ENHANCED FOR ROBUSTNESS
# =============================================================================
class Dataset4Channel(Dataset):
    """Fundus image plus predicted vessel mask, stacked into one 4-channel tensor.

    ``__getitem__`` returns ``(tensor, label)``. ``tensor`` is the ToTensorV2 output of
    ``dstack((rgb, mask))`` (4 channels, CFG.IMG_SIZE square). RGB is normalised with
    ImageNet stats and the mask with mean=std=0.5 (a 0/255 mask maps to [-1, 1],
    assuming albumentations' default max_pixel_value=255). Missing or unreadable files
    are replaced by black images/masks with a printed warning, so one bad file does not
    abort an epoch.

    Args:
        df: DataFrame with 'id_code' (str) and 'diagnosis' (int label) columns.
        img_dir: folder holding ``<id_code>.png|.jpeg|.jpg`` images.
        seg_dir: folder holding ``<id_code>.png`` masks.
        transform: optional albumentations pipeline applied jointly to image and mask.
    """

    # Image extensions probed, in this order, for each id_code.
    IMG_EXTENSIONS = ('.png', '.jpeg', '.jpg')

    def __init__(self, df, img_dir, seg_dir, transform=None):
        self.df = df.reset_index(drop=True)
        self.img_dir = img_dir
        self.seg_dir = seg_dir
        self.transform = transform
        # Per-channel normalisation: ImageNet stats for RGB, 0.5/0.5 for the mask channel.
        self.post_transform = A.Compose([
            A.Normalize(mean=[0.485, 0.456, 0.406, 0.5], std=[0.229, 0.224, 0.225, 0.5]),
            ToTensorV2()
        ])

    def __len__(self):
        return len(self.df)

    def _find_image_path(self, id_code):
        """Return the first existing ``img_dir/id_code + ext``, or None if none exists."""
        base_path = os.path.join(self.img_dir, id_code)
        for ext in self.IMG_EXTENSIONS:
            candidate = base_path + ext
            if os.path.exists(candidate):
                return candidate
        return None

    def _load_image(self, id_code):
        """Load the image as RGB uint8, or a black IMG_SIZE square if missing/unreadable."""
        img_path = self._find_image_path(id_code)
        if img_path is None:
            print(f"Warning: Image not found for {id_code} in {self.img_dir}, using black image.")
            return np.zeros((CFG.IMG_SIZE, CFG.IMG_SIZE, 3), dtype=np.uint8)
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports corrupt/unreadable files by returning None, not raising.
            print(f"Warning: Could not read image at {img_path}, using black image.")
            return np.zeros((CFG.IMG_SIZE, CFG.IMG_SIZE, 3), dtype=np.uint8)
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def _load_mask(self, id_code):
        """Load the grayscale mask, or a black IMG_SIZE square if missing/unreadable."""
        seg_path = os.path.join(self.seg_dir, id_code + '.png')
        if not os.path.exists(seg_path):
            print(f"Warning: Mask not found at {seg_path}, using black mask.")
            return np.zeros((CFG.IMG_SIZE, CFG.IMG_SIZE), dtype=np.uint8)
        mask = cv2.imread(seg_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            print(f"Warning: Could not read mask at {seg_path}, using black mask.")
            return np.zeros((CFG.IMG_SIZE, CFG.IMG_SIZE), dtype=np.uint8)
        return mask

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img = self._load_image(row['id_code'])
        mask = self._load_mask(row['id_code'])

        img = preprocess_ben_graham(img, CFG.IMG_SIZE)
        # Nearest-neighbour keeps the mask's original discrete values.
        mask = cv2.resize(mask, (CFG.IMG_SIZE, CFG.IMG_SIZE), interpolation=cv2.INTER_NEAREST)

        if self.transform:
            augmented = self.transform(image=img, mask=mask)
            img = augmented['image']
            mask = augmented['mask']

        img_4_channel = np.dstack((img, mask))
        img_4_channel = self.post_transform(image=img_4_channel)['image']
        label = torch.tensor(row['diagnosis'], dtype=torch.long)
        return img_4_channel, label

# =============================================================================
# MODEL DEFINITION
# =============================================================================
class EfficientNet4ChannelOrdinal(nn.Module):
    """timm backbone with a 4-input-channel stem and an ordinal (cumulative-threshold) head.

    The pretrained RGB stem is widened to take a 4th (vessel-mask) channel. The RGB
    kernels (and bias, if any) are copied unchanged, and the new channel's kernel is
    initialised with their mean across input channels. The head emits
    ``num_classes - 1`` logits, one per threshold.

    Assumes the backbone exposes ``conv_stem`` whose behaviour is fully described by
    kernel_size/stride/padding/dilation (a plain ``nn.Conv2d``). TODO confirm before
    switching to a variant that uses dynamic "same" padding.
    """

    def __init__(self, model_name, num_classes=5, pretrained=True):
        super().__init__()
        self.backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0, global_pool='avg')
        self.backbone.conv_stem = self._make_4ch_stem(self.backbone.conv_stem)
        feature_dim = self.backbone.num_features
        self.classifier = nn.Sequential(nn.Dropout(0.5), nn.Linear(feature_dim, 256), nn.ReLU(), nn.Dropout(0.3), nn.Linear(256, num_classes - 1))

    @staticmethod
    def _make_4ch_stem(original_conv):
        """Return a 4-input-channel conv carrying over ``original_conv``'s pretrained parameters."""
        new_conv = nn.Conv2d(4, original_conv.out_channels,
                             kernel_size=original_conv.kernel_size, stride=original_conv.stride,
                             padding=original_conv.padding, dilation=original_conv.dilation,
                             bias=(original_conv.bias is not None))
        with torch.no_grad():
            new_conv.weight[:, :3] = original_conv.weight
            new_conv.weight[:, 3] = original_conv.weight.mean(dim=1)
            if original_conv.bias is not None:
                # Keep the pretrained bias instead of leaving the fresh random init.
                new_conv.bias.copy_(original_conv.bias)
        return new_conv

    def forward(self, x):
        """Map a 4-channel image batch to ``num_classes - 1`` threshold logits per sample."""
        return self.classifier(self.backbone(x))

# =============================================================================
# UTILITY FUNCTIONS
# =============================================================================
class WeightedOrdinalFocalLoss(nn.Module):
    """Focal loss over cumulative ("ordinal") binary targets, with optional per-class weights.

    A label ``t`` becomes one binary target per output column, where column ``k`` is 1
    iff ``t > k`` (e.g. t=2 with 4 columns -> [1, 1, 0, 0]).

    Args:
        num_classes: number of ordinal classes (stored; the encoding follows
            ``outputs.shape[1]``, expected to be ``num_classes - 1``).
        gamma: focal exponent; 0 reduces the loss to (weighted) mean BCE.
        class_weights: optional 1-D tensor indexed by the integer label; it scales every
            threshold term of that sample. Must be on the same device as ``targets``.
        label_smoothing: pulls the binary targets towards 0.5 by this fraction.
    """

    def __init__(self, num_classes=5, gamma=2.0, class_weights=None, label_smoothing=0.0):
        super().__init__()
        self.num_classes = num_classes
        self.gamma = gamma
        self.class_weights = class_weights
        self.label_smoothing = label_smoothing
        self.bce = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, outputs, targets):
        """Return the mean focal loss for logits ``outputs`` (B, T) and integer ``targets`` (B,)."""
        # Vectorised cumulative encoding: column k is 1 where target > k.
        thresholds = torch.arange(outputs.shape[1], device=outputs.device)
        ordinal_targets = (targets.reshape(-1, 1) > thresholds).to(outputs.dtype)
        if self.label_smoothing > 0.0:
            ordinal_targets = ordinal_targets * (1.0 - self.label_smoothing) + 0.5 * self.label_smoothing
        bce = self.bce(outputs, ordinal_targets)
        # The focal factor must come from the *unweighted* BCE. Weighting first would turn
        # exp(-bce) into p**w and distort the modulation for every non-unit class weight.
        pt = torch.exp(-bce)
        focal = (1 - pt) ** self.gamma * bce
        if self.class_weights is not None:
            weights = self.class_weights[targets].view(-1, 1).expand(-1, outputs.shape[1])
            focal = focal * weights
        return focal.mean()

class SmoothKappaLoss(nn.Module):
    """Differentiable ``1 - quadratic weighted kappa`` computed from ordinal threshold logits.

    Threshold probabilities ``sigmoid(outputs)`` are differenced into soft per-class
    probabilities (clamped to ``[eps, 1]``). Those fill a soft confusion matrix, which is
    compared against the product of marginals using quadratic penalties ``(i-j)^2/(K-1)^2``.
    """

    def __init__(self, num_classes=5, eps=1e-7):
        super().__init__()
        self.num_classes = num_classes
        self.eps = eps
        denom = (num_classes - 1) ** 2
        penalty = torch.tensor(
            [[((i - j) ** 2) / denom for j in range(num_classes)] for i in range(num_classes)]
        ).reshape(num_classes, num_classes)
        self.register_buffer("W", penalty)

    def forward(self, outputs, targets):
        """Return the scalar loss for logits ``outputs`` and integer ``targets``."""
        device = outputs.device
        batch_size = outputs.size(0)
        exceed = torch.sigmoid(outputs)
        columns = [1 - exceed[:, 0]]
        columns += [exceed[:, k - 1] - exceed[:, k] for k in range(1, self.num_classes - 1)]
        columns.append(exceed[:, -1])
        per_class = torch.stack(columns, dim=1).to(torch.get_default_dtype())
        per_class = per_class.clamp(min=self.eps, max=1.0)
        one_hot = F.one_hot(targets, num_classes=self.num_classes).float().to(device)
        observed = one_hot.T @ per_class
        expected = torch.outer(one_hot.sum(dim=0), per_class.sum(dim=0))
        penalty = self.W.to(device)
        weighted_obs = (penalty * observed).sum()
        weighted_exp = (penalty * expected).sum()
        kappa = 1.0 - (batch_size * weighted_obs) / (weighted_exp + self.eps)
        return 1.0 - kappa

def mixup_data(x, y, alpha=0.4):
    """Blend every sample with a random partner from the same batch (mixup).

    Args:
        x: batch tensor whose first dimension is the batch.
        y: labels aligned with ``x``.
        alpha: Beta(alpha, alpha) parameter; ``alpha <= 0`` disables mixing (lam = 1).

    Returns:
        ``(mixed_x, y, y[partner], lam)`` with ``mixed_x = lam * x + (1 - lam) * x[partner]``.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1
    partner = torch.randperm(x.size(0)).to(x.device)
    mixed_x = lam * x + (1 - lam) * x[partner, :]
    return mixed_x, y, y[partner], lam

def ordinal_to_class(outputs):
    """Predicted class = number of thresholds whose sigmoid probability exceeds 0.5."""
    exceeded = torch.sigmoid(outputs) > 0.5
    return exceeded.sum(dim=1).long()

def calculate_metrics(outputs, targets):
    """Accuracy and quadratic-weighted Cohen's kappa of ordinal logits vs. integer labels.

    Returns:
        ``(accuracy, qwk)`` computed by scikit-learn on CPU numpy copies.
    NOTE(review): kappa can be undefined (NaN) for degenerate label sets; confirm callers tolerate it.
    """
    y_true = targets.cpu().numpy()
    y_pred = ordinal_to_class(outputs).cpu().numpy()
    accuracy = accuracy_score(y_true, y_pred)
    qwk = cohen_kappa_score(y_true, y_pred, weights='quadratic')
    return accuracy, qwk

def clear_memory():
    """Run Python's garbage collector, then ask PyTorch to release cached CUDA memory."""
    gc.collect()
    torch.cuda.empty_cache()

# =============================================================================
# TRAINING & VALIDATION FUNCTIONS
# =============================================================================
def train_epoch(model, loader, optimizer, criterion, scaler, device, use_mixup):
    """Run one mixed-precision training epoch, optionally with mixup.

    Args:
        model: network returning ordinal logits.
        loader: iterable of ``(images, targets)`` batches; must yield at least one batch.
        optimizer: optimizer over ``model``'s parameters.
        criterion: callable ``(outputs, targets) -> scalar loss``.
        scaler: AMP gradient scaler (``scale`` / ``step`` / ``update``).
        device: ``torch.device`` or device string to move batches to.
        use_mixup: if True, blend each batch via ``mixup_data`` and mix the two losses by ``lam``.

    Returns:
        ``(mean_batch_loss, accuracy, quadratic_weighted_kappa)``. With mixup enabled the
        metrics compare predictions on *mixed* images against the original (unmixed)
        targets, so they are only a rough training signal.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.train()
    # torch.cuda.amp.autocast() is deprecated; use the device-generic API and only enable
    # CUDA autocast on CUDA (on CPU the old call just warned and disabled itself).
    use_amp = torch.device(device).type == "cuda"
    running_loss = 0.0
    all_out, all_t = [], []
    pbar = tqdm(loader, desc="Training", leave=False)
    for images, targets in pbar:
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        if use_mixup:
            images, targets_a, targets_b, lam = mixup_data(images, targets)
        with torch.autocast(device_type="cuda", enabled=use_amp):
            outputs = model(images)
            if use_mixup:
                loss = lam * criterion(outputs, targets_a) + (1 - lam) * criterion(outputs, targets_b)
            else:
                loss = criterion(outputs, targets)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        batch_loss = loss.item()
        running_loss += batch_loss
        # Keep epoch predictions on the CPU in float32 instead of accumulating (fp16) GPU tensors.
        all_out.append(outputs.detach().float().cpu())
        all_t.append(targets.detach().cpu())
        pbar.set_postfix(loss=batch_loss)
    if not all_out:
        raise ValueError("train_epoch received an empty loader")
    return running_loss / len(all_out), *calculate_metrics(torch.cat(all_out), torch.cat(all_t))

def validate_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: network returning ordinal logits.
        loader: iterable of ``(images, targets)`` batches; must yield at least one batch.
        criterion: callable ``(outputs, targets) -> scalar loss``.
        device: ``torch.device`` or device string to move batches to.

    Returns:
        ``(mean_batch_loss, accuracy, quadratic_weighted_kappa)``.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    # Device-generic autocast, enabled only on CUDA (replaces deprecated torch.cuda.amp.autocast).
    use_amp = torch.device(device).type == "cuda"
    running_loss = 0.0
    all_out, all_t = [], []
    with torch.no_grad():
        for images, targets in tqdm(loader, desc="Validating", leave=False):
            images = images.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            with torch.autocast(device_type="cuda", enabled=use_amp):
                outputs = model(images)
                loss = criterion(outputs, targets)
            running_loss += loss.item()
            # Collect on the CPU in float32 so epoch-long buffers don't sit in GPU memory.
            all_out.append(outputs.float().cpu())
            all_t.append(targets.cpu())
    if not all_out:
        raise ValueError("validate_epoch received an empty loader")
    return running_loss / len(all_out), *calculate_metrics(torch.cat(all_out), torch.cat(all_t))

# =============================================================================
# MAIN TRAINING PIPELINE
# =============================================================================
def _run_stage(model, train_loader, val_loader, criterion, scaler, *, lr, weight_decay, eta_min,
               epochs, use_mixup, save_path, best_qwk, save_msg, stage_label):
    """Train one stage with AdamW + warm-restart cosine LR, checkpointing on val-QWK gains.

    Saves ``model.state_dict()`` to ``save_path`` whenever validation QWK exceeds
    ``best_qwk``. Stops early after ``CFG.PATIENCE`` epochs without improvement.
    Returns the best QWK seen, or ``best_qwk`` unchanged if no epoch beat it.
    """
    opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    # Cosine annealing that restarts every 10 epochs.
    sched = optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=10, T_mult=1, eta_min=eta_min)
    patience_counter = 0
    for epoch in range(epochs):
        clear_memory()
        print(f"\nEpoch {epoch+1}/{epochs}")
        train_loss, train_acc, train_qwk = train_epoch(model, train_loader, opt, criterion, scaler, CFG.DEVICE, use_mixup)
        val_loss, val_acc, val_qwk = validate_epoch(model, val_loader, criterion, CFG.DEVICE)
        sched.step()
        print(f"Train -> Loss:{train_loss:.4f} Acc:{train_acc:.4f} QWK:{train_qwk:.4f}")
        print(f"Valid -> Loss:{val_loss:.4f} Acc:{val_acc:.4f} QWK:{val_qwk:.4f}")
        if val_qwk > best_qwk:
            print(f"Val QWK improved from {best_qwk:.4f} to {val_qwk:.4f}. {save_msg}")
            best_qwk, patience_counter = val_qwk, 0
            torch.save(model.state_dict(), save_path)
        else:
            patience_counter += 1
            if patience_counter >= CFG.PATIENCE:
                print(f"Early stopping in {stage_label}.")
                break
    return best_qwk


def main():
    """Two-stage training of the 4-channel ordinal EfficientNet.

    Stage 1 trains with the weighted ordinal focal loss and checkpoints to
    ``CFG.SAVE_PATH_S1``. Stage 2 restarts from the best Stage-1 weights, trains with
    ``0.7 * kappa + 0.3 * focal``, and checkpoints to ``CFG.SAVE_PATH_FINAL`` whenever it
    beats the best QWK so far. Returns early (printing an error) if the CSVs are missing.
    """
    print(f"Device: {CFG.DEVICE}, Model: {CFG.MODEL_NAME} (4-Channel), Image Size: {CFG.IMG_SIZE}")
    try:
        train_df, val_df = pd.read_csv(CFG.TRAIN_CSV), pd.read_csv(CFG.VAL_CSV)
    except FileNotFoundError:
        print(f"CRITICAL ERROR: Could not find {CFG.TRAIN_CSV} or {CFG.VAL_CSV}.")
        print("Please ensure you have run the data splitting cell first, or that the paths are correct.")
        return

    train_ds = Dataset4Channel(train_df, CFG.TRAIN_DIR, CFG.SEG_TRAIN_DIR, transform=get_transforms(is_train=True))
    val_ds = Dataset4Channel(val_df, CFG.VAL_DIR, CFG.SEG_VAL_DIR, transform=get_transforms(is_train=False))

    # 'balanced' weights come back in np.unique(label) order. Indexing them by label value
    # below assumes the labels are exactly 0..K-1 with every class present in train_df.
    class_weights = compute_class_weight('balanced', classes=np.unique(train_df['diagnosis']), y=train_df['diagnosis'])
    sample_weights = np.array([class_weights[int(l)] for l in train_df['diagnosis']])
    sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)
    train_loader = DataLoader(train_ds, batch_size=CFG.BATCH_SIZE, sampler=sampler, num_workers=CFG.NUM_WORKERS, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=CFG.BATCH_SIZE*2, shuffle=False, num_workers=CFG.NUM_WORKERS, pin_memory=True)

    model = EfficientNet4ChannelOrdinal(CFG.MODEL_NAME).to(CFG.DEVICE)
    class_weights_loss = torch.tensor(class_weights, dtype=torch.float).to(CFG.DEVICE)
    focal_loss = WeightedOrdinalFocalLoss(num_classes=5, gamma=2.0, class_weights=class_weights_loss, label_smoothing=CFG.LABEL_SMOOTHING)
    kappa_loss = SmoothKappaLoss(num_classes=5)

    def hybrid_loss(outputs, targets):
        """Stage-2 objective: 0.7 * soft-kappa loss + 0.3 * focal loss."""
        return 0.7 * kappa_loss(outputs, targets) + 0.3 * focal_loss(outputs, targets)

    scaler = torch.cuda.amp.GradScaler()

    # --- STAGE 1 ---
    print("\n" + "="*50 + "\n     STARTING STAGE 1 (4-Channel)\n" + "="*50)
    best_val_qwk = _run_stage(
        model, train_loader, val_loader, focal_loss, scaler,
        lr=CFG.S1_LR, weight_decay=1e-4, eta_min=1e-6, epochs=CFG.S1_EPOCHS,
        use_mixup=CFG.S1_USE_MIXUP, save_path=CFG.SAVE_PATH_S1, best_qwk=-1,
        save_msg="Saving model...", stage_label="Stage 1",
    )

    # --- STAGE 2 ---
    print("\n" + "="*50 + "\n     STARTING STAGE 2 (4-Channel)\n" + "="*50)
    if os.path.exists(CFG.SAVE_PATH_S1):
        model.load_state_dict(torch.load(CFG.SAVE_PATH_S1, map_location=CFG.DEVICE))
    else:
        print("No Stage 1 model was saved. Continuing with the current model.")
    # Stage 2 only saves when it *beats* the Stage-1 QWK. Without this, SAVE_PATH_FINAL would
    # never be written when fine-tuning fails to improve, and anything loading it would fail.
    # Seed it with the weights Stage 2 starts from.
    torch.save(model.state_dict(), CFG.SAVE_PATH_FINAL)

    best_val_qwk_stage2 = _run_stage(
        model, train_loader, val_loader, hybrid_loss, scaler,
        lr=CFG.S2_LR, weight_decay=1e-5, eta_min=3e-6, epochs=CFG.S2_EPOCHS,
        use_mixup=CFG.S2_USE_MIXUP, save_path=CFG.SAVE_PATH_FINAL, best_qwk=best_val_qwk,
        save_msg="Saving final model...", stage_label="Stage 2",
    )

    print(f"\nTraining Finished!\nFinal Best QWK: {best_val_qwk_stage2:.4f}")

if __name__ == "__main__":
    main()

In [None]:
import os

import albumentations as A
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import timm
import torch
import torch.nn as nn
from albumentations.pytorch import ToTensorV2
from scipy.optimize import minimize
from sklearn.metrics import cohen_kappa_score, accuracy_score, classification_report, confusion_matrix
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

# =============================================================================
# CONFIGURATION FOR 4-CHANNEL EFFICIENTNET-B5
# =============================================================================
class CFG:
    """Settings for threshold optimisation and test evaluation of the 4-channel EfficientNet-B5.

    NOTE(review): this re-binds the global name ``CFG``. The training ``main`` reads
    attributes such as ``TRAIN_CSV`` / ``SEG_TRAIN_DIR`` that are not defined here, so the
    training configuration must be re-run before training again in the same kernel.
    """
    # Backbone (timm name) and the square side length images and masks are resized to.
    MODEL_NAME = 'efficientnet_b5'
    IMG_SIZE = 384

    # APTOS-2019 CSVs and image folders (each image folder is nested one level deep).
    BASE_PATH = "/kaggle/input/aptos2019"
    VAL_CSV = os.path.join(BASE_PATH, "valid.csv")
    VAL_DIR = os.path.join(BASE_PATH, "val_images", "val_images")
    TEST_CSV = os.path.join(BASE_PATH, "test.csv")
    TEST_DIR = os.path.join(BASE_PATH, "test_images", "test_images")

    # Pre-computed segmentation masks: one PNG per id_code, the 4th input channel.
    SEG_BASE_PATH = "/kaggle/input/segmentaion-dataset"
    SEG_VAL_DIR = os.path.join(SEG_BASE_PATH, "segmented_outputs_val/segmented_outputs_val/")
    # NOTE(review): test-mask location is assumed; confirm it matches the dataset layout.
    SEG_TEST_DIR = os.path.join(SEG_BASE_PATH, "segmented_outputs_test/segmented_outputs_test/")

    # Trained 4-channel checkpoint to evaluate.
    MODEL_PATH = "/kaggle/working/best_model_effnet_b5_seg_final.pth"

    # DataLoader / runtime settings.
    BATCH_SIZE = 8
    NUM_WORKERS = 2
    DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# =============================================================================
# REUSED CLASSES & PREPROCESSING
# =============================================================================
class EfficientNet4ChannelOrdinal(nn.Module):
    """timm backbone widened to 4 input channels (RGB + mask) with an ordinal head.

    The backbone's ``conv_stem`` is replaced by an equivalent convolution with 4 input
    channels. Input channels 0-2 reuse the original kernels, and channel 3 is initialised
    to their mean. The head emits ``num_classes - 1`` logits, which this notebook decodes by
    counting ``sigmoid(logit) > threshold``. Parameter names (``backbone.*``,
    ``classifier.*``) are kept so saved checkpoints still load.

    Args:
        model_name: timm model name, e.g. ``'efficientnet_b5'``. Assumes a ``conv_stem``
            attribute with 3 input channels and ``groups == 1``.
        num_classes: number of ordinal grades.
        pretrained: load timm's pretrained weights before widening the stem.
    """
    def __init__(self, model_name, num_classes=5, pretrained=False):
        super().__init__()
        self.backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0, global_pool='avg')
        original_conv = self.backbone.conv_stem
        original_weights = original_conv.weight.clone()
        # NOTE(review): a timm "same"-padding stem subclass would lose its dynamic padding
        # here; fine for plain Conv2d stems such as efficientnet_b5.
        new_conv = nn.Conv2d(4, original_conv.out_channels,
                             kernel_size=original_conv.kernel_size, stride=original_conv.stride,
                             padding=original_conv.padding, dilation=original_conv.dilation,
                             bias=(original_conv.bias is not None),
                             padding_mode=original_conv.padding_mode)
        with torch.no_grad():
            new_conv.weight[:, :3] = original_weights
            # Mask channel starts as the mean of the RGB kernels.
            new_conv.weight[:, 3] = original_weights.mean(dim=1)
            # Carry over the stem bias too; previously it was left randomly initialised.
            if original_conv.bias is not None:
                new_conv.bias.copy_(original_conv.bias)
        self.backbone.conv_stem = new_conv
        feature_dim = self.backbone.num_features
        self.classifier = nn.Sequential(nn.Dropout(0.5), nn.Linear(feature_dim, 256), nn.ReLU(), nn.Dropout(0.3), nn.Linear(256, num_classes - 1))

    def forward(self, x):
        """Map a 4-channel image batch to ``num_classes - 1`` ordinal logits."""
        return self.classifier(self.backbone(x))

def preprocess_ben_graham(image, output_size):
    """Crop a fundus image to its bright region, resize it square, and CLAHE its channel 1.

    Args:
        image: 3-channel uint8 image. The caller in this cell passes RGB, so channel 1 is green.
        output_size: side length of the square output.

    Returns:
        ``output_size x output_size`` 3-channel image.

    Crop: pixels with gray > 15 are thresholded, and the image is cut to the bounding box
    of the largest external contour. Images with mean gray < 15 are only resized. Any
    OpenCV failure during this step falls back to a plain resize of ``image`` as it stands
    at that moment. NOTE(review): despite the name, no Gaussian-blur subtraction is done.
    """
    target = (output_size, output_size)
    try:
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        if gray.mean() >= 15:
            _, binary = cv2.threshold(gray, 15, 255, cv2.THRESH_BINARY)
            contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            if contours:
                x, y, w, h = cv2.boundingRect(max(contours, key=cv2.contourArea))
                image = image[y:y+h, x:x+w]
        image = cv2.resize(image, target, interpolation=cv2.INTER_AREA)
    except Exception:
        image = cv2.resize(image, target, interpolation=cv2.INTER_AREA)
    enhancer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    # Channels 0 and 2 pass through; only the middle (green) channel is equalised.
    first, middle, last = cv2.split(image)
    return cv2.merge((first, enhancer.apply(middle), last))

class Dataset4Channel(Dataset):
    """Fundus image + segmentation mask dataset yielding normalised 4-channel tensors.

    For each row of ``df`` it does the following:
    - Loads ``<img_dir>/<id_code>.png`` (converted to RGB) and ``<seg_dir>/<id_code>.png``
      (grayscale).
    - Preprocesses the image with ``preprocess_ben_graham`` and resizes the mask with
      nearest-neighbour interpolation, both to ``CFG.IMG_SIZE``.
    - Optionally applies a joint albumentations ``transform`` and stacks image + mask to HxWx4.
    - Normalises with ImageNet stats for RGB and mean/std 0.5 for the mask, returning a tensor.

    Args:
        df: DataFrame with ``id_code`` and ``diagnosis`` columns.
        img_dir: directory of RGB images.
        seg_dir: directory of segmentation masks.
        transform: optional transform called as ``transform(image=..., mask=...)``.

    Raises:
        FileNotFoundError: from ``__getitem__`` if an image or mask cannot be read.
    """
    def __init__(self, df, img_dir, seg_dir, transform=None):
        self.df, self.img_dir, self.seg_dir, self.transform = df.reset_index(drop=True), img_dir, seg_dir, transform
        self.post_transform = A.Compose([
            A.Normalize(mean=[0.485, 0.456, 0.406, 0.5], std=[0.229, 0.224, 0.225, 0.5]), ToTensorV2()])

    def __len__(self): return len(self.df)

    @staticmethod
    def _read(path, flags):
        """``cv2.imread`` that raises instead of silently returning None for missing/unreadable files."""
        data = cv2.imread(path, flags)
        if data is None:
            raise FileNotFoundError(f"Could not read image file: {path}")
        return data

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = os.path.join(self.img_dir, row['id_code'] + '.png')
        seg_path = os.path.join(self.seg_dir, row['id_code'] + '.png')
        img = cv2.cvtColor(self._read(img_path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
        mask = self._read(seg_path, cv2.IMREAD_GRAYSCALE)
        img = preprocess_ben_graham(img, CFG.IMG_SIZE)
        mask = cv2.resize(mask, (CFG.IMG_SIZE, CFG.IMG_SIZE), interpolation=cv2.INTER_NEAREST)
        if self.transform:
            augmented = self.transform(image=img, mask=mask)
            img, mask = augmented['image'], augmented['mask']
        img_4_channel = np.dstack((img, mask))
        img_4_channel = self.post_transform(image=img_4_channel)['image']
        label = torch.tensor(row['diagnosis'], dtype=torch.long)
        return img_4_channel, label

# =============================================================================
# OPTIMIZATION FUNCTIONS
# =============================================================================
def ordinal_to_class_with_thresholds(outputs, thresholds):
    """Decode ordinal logits using a per-boundary sigmoid threshold.

    Args:
        outputs: (N, K-1) tensor of ordinal logits; may live on any device or require grad.
        thresholds: K-1 thresholds, broadcast against each row of probabilities.

    Returns:
        numpy int array of length N: per row, the number of sigmoid(logit) values above
        their threshold.
    """
    # detach() first: .numpy() raises on tensors that require grad.
    probs = torch.sigmoid(outputs.detach()).cpu().numpy()
    preds = np.sum(probs > thresholds, axis=1)
    return preds

def kappa_objective(thresholds, outputs, targets):
    """Objective for ``scipy.optimize.minimize``: negative quadratic-weighted kappa.

    Args:
        thresholds: candidate per-boundary sigmoid thresholds.
        outputs: ordinal logits tensor.
        targets: integer labels (numpy array).
    """
    predicted = ordinal_to_class_with_thresholds(outputs, thresholds)
    qwk = cohen_kappa_score(targets, predicted, weights="quadratic")
    return -qwk

def find_best_thresholds(outputs, targets):
    """Search per-boundary sigmoid thresholds that maximise quadratic-weighted kappa.

    Args:
        outputs: (N, K-1) tensor of ordinal logits.
        targets: (N,) tensor of integer labels.

    Returns:
        numpy array of K-1 thresholds, each bounded to [0.1, 0.9].
    """
    print("Finding optimal thresholds...")
    # Detach and move to the CPU once, rather than inside every objective evaluation.
    outputs = outputs.detach().cpu()
    targets = targets.cpu().numpy()
    # One threshold per ordinal logit (derived from the data, not hard-coded), starting at 0.5.
    init_thresh = np.full(outputs.shape[1], 0.5)
    bounds = [(0.1, 0.9)] * len(init_thresh)
    # Powell is derivative-free; QWK is piecewise constant in the thresholds.
    res = minimize(kappa_objective, init_thresh, args=(outputs, targets), method="Powell", bounds=bounds)
    best_thresholds = res.x
    print(f"Optimal thresholds found: {np.round(best_thresholds, 4)}")
    return best_thresholds

# =============================================================================
# MAIN SCRIPT
# =============================================================================
def _collect_outputs(model, loader, desc):
    """Run ``model`` over ``loader`` without gradients; return ``(logits, labels)`` as CPU tensors."""
    outputs_list, labels_list = [], []
    with torch.no_grad():
        for images, labels in tqdm(loader, desc=desc):
            outputs_list.append(model(images.to(CFG.DEVICE)).cpu())
            labels_list.append(labels)
    return torch.cat(outputs_list), torch.cat(labels_list)


def run_optimization_and_test():
    """Tune decision thresholds on the validation set, then report test metrics.

    Steps:
    1. Load ``CFG.MODEL_PATH`` and collect validation logits.
    2. Optimise the per-boundary thresholds for QWK.
    3. Collect test logits.
    4. Print QWK/accuracy with the default 0.5 thresholds and with the tuned thresholds,
       plus a classification report and a confusion-matrix heatmap for the tuned ones.
    """
    class_ids = list(range(5))

    # --- Step 0: Load Model ---
    model = EfficientNet4ChannelOrdinal(CFG.MODEL_NAME, pretrained=False).to(CFG.DEVICE)
    model.load_state_dict(torch.load(CFG.MODEL_PATH, map_location=CFG.DEVICE))
    model.eval()
    print(f"Model loaded successfully from {CFG.MODEL_PATH}")

    # --- Step 1: Raw outputs on the validation set ---
    print("\n--- Step 1: Evaluating on Validation Set to find thresholds ---")
    val_df = pd.read_csv(CFG.VAL_CSV)
    val_dataset = Dataset4Channel(val_df, CFG.VAL_DIR, CFG.SEG_VAL_DIR, transform=None)  # No augs for val
    val_loader = DataLoader(val_dataset, batch_size=CFG.BATCH_SIZE, shuffle=False, num_workers=CFG.NUM_WORKERS)
    val_outputs, val_labels = _collect_outputs(model, val_loader, "Getting Validation Outputs")

    # --- Step 2: Threshold search ---
    print("\n--- Step 2: Optimizing Thresholds ---")
    best_thresholds = find_best_thresholds(val_outputs, val_labels)

    # --- Step 3: Apply the thresholds to the test set ---
    print("\n--- Step 3: Evaluating on Test Set with new thresholds ---")
    test_df = pd.read_csv(CFG.TEST_CSV)
    test_dataset = Dataset4Channel(test_df, CFG.TEST_DIR, CFG.SEG_TEST_DIR, transform=None)  # No augs for test
    test_loader = DataLoader(test_dataset, batch_size=CFG.BATCH_SIZE, shuffle=False, num_workers=CFG.NUM_WORKERS)
    test_outputs, test_labels = _collect_outputs(model, test_loader, "Getting Test Outputs")
    test_labels = test_labels.numpy()

    # --- FINAL RESULTS ---
    print("\n" + "="*50)
    print("      FINAL RESULTS COMPARISON (EffNet-B5 + Segmentation)")
    print("="*50)

    preds_old = torch.sum(torch.sigmoid(test_outputs) > 0.5, dim=1).numpy()
    qwk_old = cohen_kappa_score(test_labels, preds_old, weights='quadratic')
    acc_old = accuracy_score(test_labels, preds_old)
    print(f"\nOriginal Score (Threshold = 0.5):"); print(f"  QWK: {qwk_old:.4f}"); print(f"  Accuracy: {acc_old*100:.2f}%")

    preds_new = ordinal_to_class_with_thresholds(test_outputs, best_thresholds)
    qwk_new = cohen_kappa_score(test_labels, preds_new, weights='quadratic')
    acc_new = accuracy_score(test_labels, preds_new)
    print(f"\nPolished Score (Optimized Thresholds):"); print(f"  QWK: {qwk_new:.4f}"); print(f"  Accuracy: {acc_new*100:.2f}%")
    print("="*50)

    # labels= pins both outputs to all 5 classes. Without it, classification_report raises
    # when a class is absent from labels and predictions (target_names length mismatch), and
    # the confusion matrix shrinks out of sync with the 0..4 tick labels.
    print("\n--- Polished Classification Report ---")
    print(classification_report(test_labels, preds_new, labels=class_ids, target_names=[f"Class {i}" for i in class_ids]))

    print("\n--- Polished Confusion Matrix ---")
    cm = confusion_matrix(test_labels, preds_new, labels=class_ids)
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_ids, yticklabels=class_ids, ax=ax)
    ax.set_xlabel("Predicted Label (Optimized)")
    ax.set_ylabel("True Label")
    ax.set_title("Confusion Matrix (Optimized Thresholds)")
    plt.show()

run_optimization_and_test()
