# In [None]:
import os, math, random, warnings, cv2
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
import copy
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import timm
from tqdm import tqdm
import torch.nn.functional as F

# Silence every Python warning (including library deprecation notices) to keep notebook output short.
warnings.filterwarnings("ignore")
# Intended to lift OpenCV's limit on decoded image size so very large X-rays are not rejected.
# NOTE(review): cv2 is imported at the top of the file, before this line runs; if OpenCV reads
# OPENCV_IO_MAX_IMAGE_PIXELS only when the library is loaded, this assignment has no effect and
# must move above `import cv2` — confirm. Also 2**64 is one more than the largest 64-bit size_t
# (2**64 - 1); confirm OpenCV accepts this value rather than rejecting it.
os.environ["OPENCV_IO_MAX_IMAGE_PIXELS"] = str(2**64) # Handle large images if needed

# -------------------------
# Config
# -------------------------
SEED = 42  # seed for set_seed() and for the train/validation split
IMG_SIZE = 512  # images are resized to IMG_SIZE x IMG_SIZE
BATCH_SIZE = 8  # training batch size; the validation loader uses BATCH_SIZE*2
EPOCHS = 4
# --- New Training Config ---
# NOTE(review): the early stopper fires once its counter reaches this patience, which needs 5
# non-improving epochs; with EPOCHS = 4 that cannot happen, so early stopping is effectively off.
EARLY_STOPPING_PATIENCE = 5 # Stop after this many epochs without AUC improvement
# ---
# NOTE(review): the LR lambda uses (e + 1) / WARMUP_EPOCHS during warmup and the scheduler is
# stepped once per epoch starting at e = 0, so WARMUP_EPOCHS = 1 yields a factor of 1.0 in the
# very first epoch, i.e. no actual warmup.
WARMUP_EPOCHS = 1
BASE_LR = 2e-5  # learning rate of the backbone parameter group
HEAD_LR = 8e-5  # learning rate of the GeM + attention-head parameter group
WEIGHT_DECAY = 1e-2  # AdamW weight decay, used by both parameter groups
GRAD_CLIP_NORM = 1.0  # max total gradient norm passed to clip_grad_norm_
EMA_DECAY = 0.999  # per-step decay of the parameter exponential moving average
FOCAL_GAMMA = 2.0  # focusing exponent of the focal loss
NUM_WORKERS = 4  # DataLoader worker processes
# Training labels: must provide an 'Image_name' column plus every column in LABEL_COLS.
TRAIN_CSV = "/kaggle/input/grand-xray-slam-division-a/train1.csv"
# Training images; the validation split is also drawn from this directory.
TRAIN_DIR = "/kaggle/input/grand-xray-slam-division-a/train1"
# Every entry listed in this directory is treated as a test image.
TEST_DIR  = "/kaggle/input/grand-xray-slam-division-a/test1"
# Target columns; this order defines the model's output order and the submission column order.
LABEL_COLS = [
    'Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Enlarged Cardiomediastinum',
    'Fracture', 'Lung Lesion', 'Lung Opacity', 'No Finding', 'Pleural Effusion',
    'Pleural Other', 'Pneumonia', 'Pneumothorax', 'Support Devices']
SAVE_PATH = "convnextv2_best_auc_checkpoint.pth" # state_dict with the best validation AUC is saved here

# -------------------------
# Repro
# -------------------------
def set_seed(seed=SEED):
    """Seed Python's, NumPy's and PyTorch's RNGs (CPU and every CUDA device) with *seed*."""
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)


set_seed()

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

# -------------------------
# Dataset without CLAHE (Used for Train and Validation)
# -------------------------
class XRayDataset(Dataset):
    """Labelled grayscale X-ray images for training and validation.

    Each row of *df* must have an ``Image_name`` column (a file name inside
    *image_dir*) and one column per entry of ``LABEL_COLS``. An item is
    ``(image, target)``: the image is read as grayscale (an all-black
    ``img_size`` x ``img_size`` image if it cannot be read), resized to
    ``img_size`` x ``img_size`` with bicubic interpolation, copied into three
    channels and normalised with the ImageNet mean/std. ``target`` is a float32
    tensor of the row's ``LABEL_COLS`` values. With ``is_train`` a random
    horizontal flip (p=0.5) is added as augmentation.

    NOTE(review): mirroring chest X-rays moves the heart to the other side;
    confirm this augmentation is acceptable for laterality-sensitive labels.
    """

    def __init__(self, df, image_dir, img_size=IMG_SIZE, is_train=True):
        self.df = df.reset_index(drop=True)
        self.image_dir = image_dir
        self.img_size = img_size
        pipeline = [T.ToTensor()]
        if is_train:
            pipeline.append(T.RandomHorizontalFlip(p=0.5))
        # Validation gets no augmentation: only tensor conversion + normalisation.
        pipeline.append(T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))
        self.tf = T.Compose(pipeline)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        side = self.img_size
        gray = cv2.imread(os.path.join(self.image_dir, record['Image_name']), cv2.IMREAD_GRAYSCALE)
        if gray is None:
            # Missing/unreadable file: substitute a black image so batching still works.
            gray = np.zeros((side, side), dtype=np.uint8)
        gray = cv2.resize(gray, (side, side), interpolation=cv2.INTER_CUBIC)
        rgb = cv2.merge([gray] * 3)
        target = torch.tensor(record[LABEL_COLS].values.astype(np.float32), dtype=torch.float32)
        return self.tf(rgb), target

# -------------------------
# Focal Loss
# -------------------------
class FocalLoss(nn.Module):
    """Binary focal loss computed from logits, for multi-label classification.

    Element-wise ``loss = (1 - p_t) ** gamma * BCE(logits, targets)`` with
    ``p_t = exp(-BCE)`` (the predicted probability of the target class for 0/1
    targets), optionally multiplied by *alpha*, then reduced.

    Args:
        alpha: optional tensor broadcast against the element-wise loss, e.g. one
            weight per class for ``(batch, num_classes)`` inputs. It multiplies
            the loss of every element — positive and negative targets alike — so
            it acts as a per-class weight, not as the positive/negative balancing
            alpha of the original focal-loss formulation.
        gamma: focusing exponent; ``gamma=0`` reduces to plain BCE.
        reduction: ``"mean"`` (default), ``"sum"``, or ``"none"`` to return the
            unreduced element-wise loss.

    Raises:
        ValueError: if *reduction* is not one of the supported values.
    """

    _REDUCTIONS = ("mean", "sum", "none")

    def __init__(self, alpha=None, gamma=2.0, reduction="mean"):
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}")
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, targets):
        bce = nn.functional.binary_cross_entropy_with_logits(logits, targets, reduction="none")
        # For 0/1 targets, exp(-BCE) is the probability the model assigns to the true class.
        pt = torch.exp(-bce)
        loss = (1 - pt) ** self.gamma * bce
        if self.alpha is not None:
            # alpha may be built on the CPU; move it next to the loss before weighting.
            loss = loss * self.alpha.to(loss.device)
        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        # "none": previously this fell through to sum(); now the element-wise loss is returned.
        return loss

# -------------------------
# EMA
# -------------------------
class EMA:
    """Exponential moving average of a model's trainable parameters.

    A shadow copy of every parameter with ``requires_grad=True`` is taken at
    construction (``nn.DataParallel`` wrappers are unwrapped, so names carry no
    ``module.`` prefix). ``update`` applies
    ``shadow = decay * shadow + (1 - decay) * param`` to each tracked parameter,
    and ``apply_to`` copies the shadow values into a model's parameters in place.
    Buffers are not tracked.
    """

    def __init__(self, model, decay=EMA_DECAY):
        self.decay = decay
        self.shadow = {}
        for name, param in self._unwrap(model).named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.detach().clone()

    @staticmethod
    def _unwrap(model):
        """Return the inner module of an ``nn.DataParallel`` wrapper, else *model* itself."""
        return model.module if isinstance(model, nn.DataParallel) else model

    @torch.no_grad()
    def update(self, model):
        for name, param in self._unwrap(model).named_parameters():
            avg = self.shadow.get(name)
            if avg is None:
                continue
            avg.mul_(self.decay).add_(param.detach(), alpha=1-self.decay)

    @torch.no_grad()
    def apply_to(self, model):
        for name, param in self._unwrap(model).named_parameters():
            if name not in self.shadow:
                continue
            param.copy_(self.shadow[name])

# -------------------------
# Early Stopper Class
# -------------------------
class EarlyStopper:
    """Report when a monitored score has stopped improving.

    Call the instance once per epoch with the latest score. A score counts as an
    improvement when it beats the best score seen so far by more than
    *min_delta* (larger for ``mode='max'``, smaller for any other mode); that
    resets the counter. After *patience* consecutive non-improving calls,
    ``early_stop`` becomes True and stays True. Each call returns ``early_stop``.
    """

    def __init__(self, patience=5, min_delta=0, mode='max'):
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.counter = 0
        self.best_score = -np.inf if mode == 'max' else np.inf
        self.early_stop = False

    def _improved(self, score):
        """Return True when *score* beats ``best_score`` by more than ``min_delta``."""
        if self.mode == 'max':
            return score > self.best_score + self.min_delta
        return score < self.best_score - self.min_delta  # any non-'max' mode minimises

    def __call__(self, score):
        if self._improved(score):
            self.best_score = score
            self.counter = 0
            return self.early_stop
        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True
        return self.early_stop

# -------------------------
# Validation Function (for AUC)
# -------------------------
@torch.no_grad()
def validate_model(model, loader, target_labels):
    """Return the macro ROC-AUC of *model*'s predictions on *loader*.

    The model is put in eval mode (and left there) and run over every batch.
    Labels yielded by the loader are ignored; *target_labels* is used instead,
    so its rows must follow the loader's sample order (e.g. an unshuffled
    loader). Sigmoid probabilities are scored column-wise against
    *target_labels*. Columns whose true labels are constant are skipped because
    ROC-AUC is undefined for them. Returns 0.5 when no column is usable or when
    the AUC computation raises ``ValueError``.
    """
    model.eval()
    amp_on = device.type == "cuda"
    prob_chunks = []
    for batch, _unused_labels in tqdm(loader, desc="Validating"):
        batch = batch.to(device)
        with torch.cuda.amp.autocast(enabled=amp_on):
            logits = model(batch)
        prob_chunks.append(torch.sigmoid(logits).cpu().numpy())
    probs = np.concatenate(prob_chunks, axis=0)

    try:
        # roc_auc_score cannot score a column containing only one class.
        informative = target_labels.min(axis=0) != target_labels.max(axis=0)
        if not np.any(informative):
            return 0.5  # Default if no valid columns exist
        return roc_auc_score(target_labels[:, informative], probs[:, informative], average='macro')
    except ValueError as e:
        print(f"AUC Error: {e}. Returning 0.5.")
        return 0.5

# -------------------------
# Data prep (Split Train/Val)
# -------------------------
df = pd.read_csv(TRAIN_CSV)
# Coerce label columns to numbers; unparseable or missing entries become 0 (negative).
# NOTE(review): values other than 0/1 (e.g. -1 "uncertain" labels, if the CSV uses them)
# pass through unchanged into the loss and the 'No Finding' sum below — confirm the encoding.
df[LABEL_COLS] = df[LABEL_COLS].apply(pd.to_numeric, errors="coerce").fillna(0)
if 'No Finding' in LABEL_COLS:
    # Re-derive 'No Finding' as "no other label is set", overriding the CSV's value.
    others = [c for c in LABEL_COLS if c != 'No Finding']
    df['No Finding'] = (df[others].sum(axis=1) == 0).astype(int)

# --- SPLIT DATA ---
# Hold out 20% for validation, stratified on 'No Finding' so both splits keep a similar
# normal/abnormal ratio.
train_df, val_df = train_test_split(
    df, test_size=0.2, random_state=SEED, shuffle=True, stratify=df['No Finding']
)

# Per-class Focal Loss weights = negatives / positives, computed on the TRAINING split only.
# Positive counts are clipped at 1 so a label with no positives in the split gets a finite
# weight (negatives / 1) instead of negatives / 1e-6, which would swamp the whole loss.
pos_counts = train_df[LABEL_COLS].sum()
neg_counts = len(train_df) - pos_counts
alpha = torch.tensor((neg_counts / pos_counts.clip(lower=1)).values, dtype=torch.float32)

# --- Dataloaders ---
# Training: shuffled and augmented; drop_last=True discards the final incomplete batch.
train_ds = XRayDataset(train_df, TRAIN_DIR, is_train=True)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=True, drop_last=True)

# Validation: no augmentation and no shuffling, so predictions stay row-aligned with val_labels.
val_ds = XRayDataset(val_df, TRAIN_DIR, is_train=False)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE*2, shuffle=False,
                          num_workers=NUM_WORKERS, pin_memory=True)
val_labels = val_df[LABEL_COLS].values # True labels for AUC calculation

# -------------------------
# GeM Pooling + Heavy Attention Head
# -------------------------
class GeM(nn.Module):
    """Generalized-mean (GeM) pooling to a 1x1 spatial map.

    Computes ``avg_pool(clamp(x, min=eps) ** p) ** (1 / p)`` with a learnable
    exponent ``p`` (initialised to *p*): ``p = 1`` is average pooling and a large
    ``p`` approaches max pooling. For ``(N, C, H, W)`` input the output is
    ``(N, C, 1, 1)``. Clamping to *eps* keeps the fractional power well defined.
    """

    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.p = nn.Parameter(p * torch.ones(1))
        self.eps = eps

    def forward(self, x):
        powered = torch.clamp(x, min=self.eps) ** self.p
        pooled = F.adaptive_avg_pool2d(powered, output_size=1)
        return pooled ** (1.0 / self.p)

class HeavyAttentionHead(nn.Module):
    """Classification head: residual self-attention, residual feed-forward, LayerNorm, linear.

    Maps pooled features ``(batch, in_ch)`` to logits ``(batch, num_classes)``.
    The feature vector is treated as a sequence of length 1, so the attention
    softmax runs over a single key and always equals 1. The attention block
    therefore acts as a learned linear map (value projection, then output
    projection) added residually. ``in_ch`` must be divisible by ``num_heads``,
    as ``nn.MultiheadAttention`` requires.
    """

    def __init__(self, in_ch, num_classes, num_heads=8, ff_mult=4):
        super().__init__()
        hidden = in_ch * ff_mult
        # Submodule names and creation order define the checkpoint (state_dict) keys.
        self.attn = nn.MultiheadAttention(embed_dim=in_ch, num_heads=num_heads, batch_first=True)
        self.ff = nn.Sequential(
            nn.LayerNorm(in_ch),
            nn.Linear(in_ch, hidden),
            nn.GELU(),
            nn.Linear(hidden, in_ch),
        )
        self.ln = nn.LayerNorm(in_ch)
        self.fc = nn.Linear(in_ch, num_classes)

    def forward(self, x):
        tokens = torch.unsqueeze(x, 1)  # (batch, 1, in_ch)
        attended, _weights = self.attn(tokens, tokens, tokens)
        tokens = tokens + attended
        tokens = tokens + self.ff(tokens)
        features = torch.squeeze(self.ln(tokens), 1)
        return self.fc(features)

# -------------------------
# Full Model Module
# -------------------------
class ConvNeXtV2_GeM_HeavyAttn(nn.Module):
    """ConvNeXt-V2 Base backbone -> GeM pooling -> dropout(0.3) -> HeavyAttentionHead.

    The backbone is timm's ``convnextv2_base.fcmae_ft_in22k_in1k`` created with
    ``pretrained=True`` and ``num_classes=0`` (no classifier). Only
    ``backbone.forward_features`` is called, so the spatial feature map is pooled
    by GeM here rather than by timm's own head. Returns ``(batch, num_classes)``
    logits.
    """

    def __init__(self, num_classes=len(LABEL_COLS)):
        super().__init__()
        # Attribute names/order (backbone, gem, dropout, head) define the checkpoint keys.
        self.backbone = timm.create_model(
            "convnextv2_base.fcmae_ft_in22k_in1k", pretrained=True, num_classes=0
        )
        channels = self.backbone.num_features
        self.gem = GeM()
        self.dropout = nn.Dropout(0.3)
        self.head = HeavyAttentionHead(channels, num_classes, num_heads=8, ff_mult=4)

    def forward(self, x):
        feature_map = self.backbone.forward_features(x)
        pooled = self.gem(feature_map).flatten(start_dim=1)
        return self.head(self.dropout(pooled))

def make_model():
    """Build a fresh ConvNeXtV2_GeM_HeavyAttn with one output logit per entry of LABEL_COLS."""
    n_outputs = len(LABEL_COLS)
    return ConvNeXtV2_GeM_HeavyAttn(num_classes=n_outputs)

def param_groups(m, base_lr=None, head_lr=None, weight_decay=None):
    """Split *m*'s parameters into a backbone group and a head group for an optimizer.

    Parameters of ``m.backbone`` form the first group; every other parameter of
    the model (GeM exponent, attention head, and any submodule added later) forms
    the second. An ``nn.DataParallel`` wrapper is unwrapped first.

    Args:
        m: the model, optionally wrapped in ``nn.DataParallel``; must have a
            ``backbone`` submodule.
        base_lr: backbone learning rate; ``None`` means ``BASE_LR``.
        head_lr: head learning rate; ``None`` means ``HEAD_LR``.
        weight_decay: weight decay for both groups; ``None`` means ``WEIGHT_DECAY``.

    Returns:
        A two-element list of param-group dicts for a ``torch.optim`` optimizer.
    """
    # Resolve defaults at call time (equal to the config constants) so explicit values,
    # including 0.0, are honoured.
    base_lr = BASE_LR if base_lr is None else base_lr
    head_lr = HEAD_LR if head_lr is None else head_lr
    weight_decay = WEIGHT_DECAY if weight_decay is None else weight_decay

    module = m.module if isinstance(m, nn.DataParallel) else m
    base_params = list(module.backbone.parameters())
    backbone_ids = {id(p) for p in base_params}
    # "Everything that is not backbone" instead of a hand-written module list, so a new
    # submodule can never be silently left out of the optimizer.
    head_params = [p for p in module.parameters() if id(p) not in backbone_ids]
    return [
        {"params": base_params, "lr": base_lr, "weight_decay": weight_decay},
        {"params": head_params, "lr": head_lr, "weight_decay": weight_decay},
    ]

criterion = FocalLoss(alpha=alpha, gamma=FOCAL_GAMMA)

# -------------------------
# Train Loop with Validation and Early Stopping
# -------------------------
model = make_model().to(device)
if torch.cuda.device_count() > 1:
    print(f"✅ Using {torch.cuda.device_count()} GPUs for DataParallel")
    model = nn.DataParallel(model)
# Unwrapped module: its state_dict keys carry no DataParallel "module." prefix.
model_to_save = model.module if isinstance(model, nn.DataParallel) else model

optimizer = optim.AdamW(param_groups(model))
# Per-epoch LR multiplier: linear warmup for WARMUP_EPOCHS epochs, then cosine decay.
# The lambda receives the epoch index starting at 0, so with WARMUP_EPOCHS = 1 the first
# epoch already runs at the full learning rate.
scheduler = optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda e: (e+1)/WARMUP_EPOCHS if e<WARMUP_EPOCHS
    else 0.5*(1+math.cos(math.pi*(e-WARMUP_EPOCHS)/max(1,EPOCHS-WARMUP_EPOCHS))))

scaler = torch.cuda.amp.GradScaler(enabled=device.type=="cuda")
# EMA of the trainable weights, updated after every optimizer step. These averaged weights
# are the ones validated and checkpointed below.
ema = EMA(model)

# Initialize Early Stopping
early_stopper = EarlyStopper(patience=EARLY_STOPPING_PATIENCE, mode='max')
best_val_auc = 0.0
best_model_state = None

print(f"Starting training with Validation Split. Patience: {EARLY_STOPPING_PATIENCE}")

for epoch in range(EPOCHS):
    # --- TRAINING STEP ---
    model.train()
    running = 0
    pbar = tqdm(train_loader, desc=f"ConvNeXtV2+HeavyAttn Epoch {epoch+1}/{EPOCHS} (Train)")
    for step, (imgs, y) in enumerate(pbar, start=1):
        imgs, y = imgs.to(device), y.to(device)
        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=device.type=="cuda"):
            out = model(imgs)
            loss = criterion(out, y)
        scaler.scale(loss).backward()
        # Unscale first so clipping compares the real (not loss-scaled) gradient norm.
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_NORM)
        scaler.step(optimizer); scaler.update()
        ema.update(model)
        running += loss.item()
        # Mean over the batches processed so far. tqdm's pbar.n is only synced on display
        # refreshes and lags the loop, so dividing by it inflated the displayed loss.
        pbar.set_postfix(loss=running/step)

    scheduler.step()

    # --- VALIDATION STEP (EMA weights) ---
    # Swap the EMA weights in, validate/checkpoint them, then restore the raw weights so
    # training resumes from the optimizer's own parameters. load_state_dict copies in place,
    # so the optimizer keeps referencing the same parameter tensors.
    raw_state = copy.deepcopy(model_to_save.state_dict())
    ema.apply_to(model)
    val_auc = validate_model(model, val_loader, val_labels)
    print(f"\n[Epoch {epoch+1}/{EPOCHS}] Validation AUC: {val_auc:.4f}")

    # --- CHECKPOINTING AND EARLY STOPPING ---
    if val_auc > best_val_auc:
        print(f"🏆 AUC improved ({best_val_auc:.4f} -> {val_auc:.4f}). Saving best checkpoint.")
        best_val_auc = val_auc
        # Save the best (EMA) state to disk immediately
        best_model_state = copy.deepcopy(model_to_save.state_dict())
        torch.save(best_model_state, SAVE_PATH)
    model_to_save.load_state_dict(raw_state)
    del raw_state

    if early_stopper(val_auc):
        print(f"\n🛑 Early stopping triggered after {early_stopper.counter} epochs without improvement.")
        break

# Leave `model` holding the best validated (EMA) weights. If no epoch ever improved on the
# initial AUC of 0.0, fall back to the final EMA weights.
model_to_load = model_to_save
if best_model_state is not None:
    print(f"Loading best model state (AUC: {best_val_auc:.4f}) for final inference.")
    model_to_load.load_state_dict(best_model_state)
else:
    ema.apply_to(model)

print(f"✅ Training completed. Best model saved to {SAVE_PATH} (Best AUC: {best_val_auc:.4f})")

# -------------------------
# Inference Dataset without CLAHE
# -------------------------
class TestDataset(Dataset):
    """Unlabelled test images, preprocessed the same way as XRayDataset's validation mode.

    Each item is ``(image_tensor, image_name)``. The file ``image_dir/Image_name``
    is read as grayscale (an all-black image if it cannot be read), resized to
    ``img_size`` x ``img_size`` with bicubic interpolation, copied into three
    channels and normalised with the ImageNet mean/std.

    Args:
        df: frame with an ``Image_name`` column.
        image_dir: directory containing the images.
        img_size: output side length in pixels (defaults to ``IMG_SIZE``).
    """

    def __init__(self, df, image_dir, img_size=IMG_SIZE):
        self.df = df.reset_index(drop=True)
        self.image_dir = image_dir
        self.img_size = img_size
        self.tf = T.Compose([
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        path = os.path.join(self.image_dir, row['Image_name'])
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            img = np.zeros((self.img_size, self.img_size), dtype=np.uint8)
        # Use the same resampling filter as training/validation (INTER_CUBIC). The previous
        # default (INTER_LINEAR) fed the model differently-filtered inputs at test time.
        img = cv2.resize(img, (self.img_size, self.img_size), interpolation=cv2.INTER_CUBIC)
        img = cv2.merge([img, img, img])
        img = self.tf(img)
        return img, row['Image_name']

# Load the best weights saved during training
model_infer = make_model().to(device)
if torch.cuda.device_count() > 1:
    model_infer = nn.DataParallel(model_infer)

try:
    state_dict = torch.load(SAVE_PATH,map_location=device)
    if isinstance(model_infer, nn.DataParallel):
        model_infer.module.load_state_dict(state_dict)
    else:
        model_infer.load_state_dict(state_dict)
    print(f"Loaded best checkpoint from {SAVE_PATH} for inference.")
except FileNotFoundError:
    print(f"Checkpoint file {SAVE_PATH} not found. Using final trained weights.")
    # If the script stopped before saving the first checkpoint, use the current model state
    pass

model_infer.eval()
test_names=sorted(os.listdir(TEST_DIR))
test_df=pd.DataFrame({'Image_name':test_names})
test_ds=TestDataset(test_df,TEST_DIR)
test_loader=DataLoader(test_ds,batch_size=BATCH_SIZE,shuffle=False,num_workers=NUM_WORKERS,pin_memory=True)

@torch.no_grad()
def predict_tta(model,loader):
    """Predict sigmoid probabilities with horizontal-flip test-time augmentation.

    For each ``(images, names)`` batch from *loader*, the model runs on the
    images and on their mirror image (flipped along dim 3, the width of an
    ``(N, C, H, W)`` batch). The two logit tensors are averaged before the
    sigmoid. Returns ``(probs, names)``: a NumPy array with one row per image
    (batches concatenated in loader order) and the matching list of names. The
    model's train/eval mode is not changed here.
    """
    prob_batches = []
    all_names = []
    for batch, batch_names in tqdm(loader, desc="Predicting ConvNeXtV2+HeavyAttn"):
        batch = batch.to(device)
        logits_plain = model(batch)
        logits_mirror = model(torch.flip(batch, dims=[3]))
        mean_logits = 0.5 * (logits_plain + logits_mirror)
        prob_batches.append(torch.sigmoid(mean_logits).cpu().numpy())
        all_names.extend(batch_names)
    return np.concatenate(prob_batches, axis=0), all_names

preds, names = predict_tta(model_infer, test_loader)
# One row per test image: 'Image_name' first, then one probability column per label in
# LABEL_COLS order.
sub = pd.DataFrame(data=preds, columns=LABEL_COLS)
sub.insert(loc=0, column="Image_name", value=names)
sub.to_csv("submission.csv", index=False)
print("✅ Created submission.csv")
print(sub.head())