In [1]:
!wget https://huggingface.co/MapleF/eva_x/resolve/main/eva_x_base_patch16_merged520k_mim.pt

--2025-10-04 10:39:44--  https://huggingface.co/MapleF/eva_x/resolve/main/eva_x_base_patch16_merged520k_mim.pt
Resolving huggingface.co (huggingface.co)... 3.169.137.5, 3.169.137.119, 3.169.137.19, ...
Connecting to huggingface.co (huggingface.co)|3.169.137.5|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cas-bridge.xethub.hf.co/xet-bridge-us/659e6a763232931c9a49ee20/f6e7bd72aa9ec12b574f8bbc89a4f2006ffa2cd5bc8dcdd4866ad4f0aa395a2a?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Content-Sha256=UNSIGNED-PAYLOAD&X-Amz-Credential=cas%2F20251004%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20251004T103945Z&X-Amz-Expires=3600&X-Amz-Signature=4479b3aac7b7bbe9aba24754a151c79f02933fc5dfe7bdf865f50cd240f75e06&X-Amz-SignedHeaders=host&X-Xet-Cas-Uid=public&response-content-disposition=inline%3B+filename*%3DUTF-8%27%27eva_x_base_patch16_merged520k_mim.pt%3B+filename%3D%22eva_x_base_patch16_merged520k_mim.pt%22%3B&x-id=GetObject&Expires=1759577985&Policy=eyJTdGF0ZW1l

In [2]:
import torch
from timm.models.eva import Eva
from timm.layers import resample_abs_pos_embed, resample_patch_embed
import pandas as pd
from torch.utils.data import Dataset
from PIL import Image
import os
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from sklearn.metrics import roc_auc_score 
import numpy as np
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import torchvision.io as io
from sklearn.model_selection import StratifiedShuffleSplit, KFold
from torch.amp import autocast, GradScaler
from torchvision.transforms import autoaugment, transforms
from torchvision.io import read_image
from torchvision import transforms as T
# Use the recommended v2 transforms
from torchvision.transforms import v2
from transformers import get_cosine_schedule_with_warmup
import copy

In [3]:
torch.manual_seed(22)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


In [4]:
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader

cv2.setNumThreads(0)

class ChestXrayDataset(Dataset):
    def __init__(self, df, img_dir, transform=None):
        self.df = df
        self.img_dir = img_dir
        self.transform = transform
        self.label_cols = [
            "Atelectasis","Cardiomegaly","Consolidation","Edema",
            "Enlarged Cardiomediastinum","Fracture","Lung Lesion",
            "Lung Opacity","No Finding","Pleural Effusion","Pleural Other",
            "Pneumonia","Pneumothorax","Support Devices",
        ]

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = os.path.join(self.img_dir, row["Image_name"])

        # Fast read with OpenCV. It may return None if file missing -> handle.
        img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
        if img is None:
            raise FileNotFoundError(f"Image not found: {img_path}")

        # If grayscale (H, W), convert to 3-channel
        if img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        elif img.shape[2] == 4:
            # if RGBA, convert to RGB
            img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)

        # OpenCV loads BGR -> convert to RGB
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if self.transform:
            augmented = self.transform(image=img)
            img = augmented["image"]

        labels = torch.tensor(row[self.label_cols].values.astype("float32"))
        return img, labels

# choose img_size once
img_size = 448
size_tuple = (img_size, img_size)   # IMPORTANT: albumentations expects a tuple

train_tfms = A.Compose([
    A.RandomResizedCrop(size=size_tuple, scale=(0.85, 1.0), ratio=(0.9, 1.1)),
    A.HorizontalFlip(p=0.5),
    A.Rotate(limit=12, p=0.4),
    A.CLAHE(p=0.3),
    A.RandomBrightnessContrast(p=0.3),
    A.CoarseDropout(max_holes=1, max_height=20, max_width=20, p=0.2),
    A.ShiftScaleRotate(shift_limit=0.03, scale_limit=0.05, rotate_limit=0, p=0.25),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2(),
])

val_tfms = A.Compose([
    A.Resize(height=img_size, width=img_size),   # pass height & width
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2(),
])

  A.CoarseDropout(max_holes=1, max_height=20, max_width=20, p=0.2),
  original_init(self, **validated_kwargs)


In [5]:
def checkpoint_filter_fn(
        state_dict,
        model,
        interpolation='bicubic',
        antialias=True,
):
    """ convert patch embedding weight from manual patchify + linear proj to conv"""
    out_dict = {}
    state_dict = state_dict.get('model_ema', state_dict)
    state_dict = state_dict.get('model', state_dict)
    state_dict = state_dict.get('module', state_dict)
    state_dict = state_dict.get('state_dict', state_dict)
    # prefix for loading OpenCLIP compatible weights
    if 'visual.trunk.pos_embed' in state_dict:
        prefix = 'visual.trunk.'
    elif 'visual.pos_embed' in state_dict:
        prefix = 'visual.'
    else:
        prefix = ''
    mim_weights = prefix + 'mask_token' in state_dict
    no_qkv = prefix + 'blocks.0.attn.q_proj.weight' in state_dict

    len_prefix = len(prefix)
    for k, v in state_dict.items():
        if prefix:
            if k.startswith(prefix):
                k = k[len_prefix:]
            else:
                continue

        if 'rope' in k:
            # fixed embedding no need to load buffer from checkpoint
            continue

        if 'patch_embed.proj.weight' in k:
            _, _, H, W = model.patch_embed.proj.weight.shape
            if v.shape[-1] != W or v.shape[-2] != H:
                v = resample_patch_embed(
                    v,
                    (H, W),
                    interpolation=interpolation,
                    antialias=antialias,
                    verbose=True,
                )
        elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
            # To resize pos embedding when using model at different size from pretrained weights
            num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) else getattr(model, 'num_prefix_tokens', 1)
            v = resample_abs_pos_embed(
                v,
                new_size=model.patch_embed.grid_size,
                num_prefix_tokens=num_prefix_tokens,
                interpolation=interpolation,
                antialias=antialias,
                verbose=True,
            )

        k = k.replace('mlp.ffn_ln', 'mlp.norm')
        k = k.replace('attn.inner_attn_ln', 'attn.norm')
        k = k.replace('mlp.w12', 'mlp.fc1')
        k = k.replace('mlp.w1', 'mlp.fc1_g')
        k = k.replace('mlp.w2', 'mlp.fc1_x')
        k = k.replace('mlp.w3', 'mlp.fc2')
        if no_qkv:
            k = k.replace('q_bias', 'q_proj.bias')
            k = k.replace('v_bias', 'v_proj.bias')

        if mim_weights and k in ('mask_token', 'lm_head.weight', 'lm_head.bias', 'norm.weight', 'norm.bias'):
            if k == 'norm.weight' or k == 'norm.bias':
                # try moving norm -> fc norm on fine-tune, probably a better starting point than new init
                k = k.replace('norm', 'fc_norm')
            else:
                # skip pretrain mask token & head weights
                continue

        out_dict[k] = v

    return out_dict

class EVA_X(Eva):
    def __init__(self, **kwargs):
        super(EVA_X, self).__init__(**kwargs)

    def forward_features(self, x):
        x = self.patch_embed(x)
        x, rot_pos_embed = self._pos_embed(x)
        for blk in self.blocks:
            x = blk(x, rope=rot_pos_embed)
        x = self.norm(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        if self.global_pool:
            x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
        x = self.fc_norm(x)
        x = self.head_drop(x)
        return x if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x



def eva_x_base_patch16(pretrained=False):
    model = EVA_X(
        img_size=img_size,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        qkv_fused=False,
        mlp_ratio=4 * 2 / 3,
        swiglu_mlp=True,
        scale_mlp=True,
        use_rot_pos_emb=True,
        ref_feat_shape=(14, 14),  # 224/16
    )
    eva_ckpt = checkpoint_filter_fn(torch.load(pretrained, map_location='cpu', weights_only=False), 
                        model)
    msg = model.load_state_dict(eva_ckpt, strict=False)
    print(msg)
    return model

In [6]:
#evax base
class EVAX_Model(nn.Module):
    def __init__(self, unfreeze_last_n_blocks=4 ):
        super().__init__()
        # Load Phikon backbone
        eva_x_b_pt = '/kaggle/working/eva_x_base_patch16_merged520k_mim.pt'  # Replace with your checkpoint path
        self.backbone = eva_x_base_patch16(pretrained=eva_x_b_pt)
        
        self.backbone.head=nn.Linear(in_features=768, out_features=14, bias=True)

        # # Also unfreeze head and normalisation layers around it
        for p in self.backbone.head.parameters():
            p.requires_grad = True

    
    
    def forward(self, x):
        backbone_output = self.backbone(x)
    
        #return self.head(backbone_output)
        return backbone_output

In [7]:
from torch.amp import autocast, GradScaler
from sklearn.metrics import roc_auc_score

scaler = GradScaler(device="cuda")

# -----------------------------
# Train One Epoch (with AMP)
# -----------------------------
# def train_one_epoch(model, loader, optimizer, criterion, ema=None, grad_clip=1.0):
#     model.train()
#     running_loss = 0.0
    
#     progress_bar = tqdm(loader, desc="[Train]")
#     for imgs, labels in progress_bar:
#         imgs, labels = imgs.to(DEVICE), labels.to(DEVICE).float()

#         optimizer.zero_grad()
#         with autocast(device_type="cuda" ):  # mixed precision forward
#             outputs = model(imgs)
#             loss = criterion(outputs, labels)

#         scaler.scale(loss).backward()
#         # scaler.unscale_(optimizer)
#         # torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
#         scaler.step(optimizer)
#         scaler.update()
        
#         if ema is not None:
#             ema.update(model)

#         running_loss += loss.item() * imgs.size(0)
#     return running_loss / len(loader.dataset)

def train_one_epoch(model, loader, optimizer, criterion, ema=None, grad_clip=1.0, accum_steps=2, scheduler=None):
    """
    accum_steps: number of mini-batches to accumulate before an optimizer step.
    scheduler: optional LR scheduler. If provided, step it once per optimizer.step().
    """
    model.train()
    running_loss = 0.0
    optimizer.zero_grad()

    progress_bar = tqdm(enumerate(loader), total=len(loader), desc="[Train]")
    for batch_idx, (imgs, labels) in progress_bar:
        imgs = imgs.to(DEVICE, non_blocking=True)
        labels = labels.to(DEVICE, non_blocking=True).float()

        with autocast(device_type="cuda"):
            outputs = model(imgs)
            loss = criterion(outputs, labels)

        # scale the loss down by accum_steps so gradients match larger batch
        loss = loss / accum_steps
        scaler.scale(loss).backward()

        # Do optimizer step once every accum_steps
        if (batch_idx + 1) % accum_steps == 0:
            # # IMPORTANT: unscale before clipping
            # scaler.unscale_(optimizer)
            # # clip grads on underlying model (unwrap possible DataParallel)
            # torch.nn.utils.clip_grad_norm_(get_model(model).parameters(), max_norm=grad_clip)

            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

            # scheduler should step per optimizer step (if you want)
            if scheduler is not None:
                scheduler.step()

            # EMA update should be done after the actual optimizer step
            if ema is not None:
                ema.update(model)

        # accumulate running loss (multiply back by accum_steps to report per-sample loss)
        running_loss += loss.item() * imgs.size(0) * accum_steps

    # If total batches is not divisible by accum_steps, we should have stepped for the remainder.
    # The above code zeroes grads after each step, but if final batch didn't trigger a step, we need to step now.
    remainder = len(loader) % accum_steps
    if remainder != 0:
        # # final step for remaining gradients
        # scaler.unscale_(optimizer)
        # torch.nn.utils.clip_grad_norm_(get_model(model).parameters(), max_norm=grad_clip)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
        if scheduler is not None:
            scheduler.step()
        if ema is not None:
            ema.update(model)

    avg_loss = running_loss / (len(loader.dataset))
    return avg_loss



# -----------------------------
# Validation (macro + per-label AUROC)
# -----------------------------
def validate(model, loader, criterion):
    model.eval()
    running_loss = 0.0
    all_labels, all_outputs = [], []

    progress_bar = tqdm(loader, desc="[Val]")
    with torch.no_grad():
        for imgs, labels in progress_bar:
            imgs, labels = imgs.to(DEVICE), labels.to(DEVICE).float()
            with autocast(device_type="cuda" ):
                outputs = model(imgs)
                loss = criterion(outputs, labels)

            running_loss += loss.item() * imgs.size(0)
            all_labels.append(labels.cpu().numpy())
            all_outputs.append(outputs.sigmoid().cpu().numpy())

    all_labels = np.vstack(all_labels)
    all_outputs = np.vstack(all_outputs)

    per_label_auc = {}
    for i, label in enumerate(loader.dataset.label_cols):
        try:
            per_label_auc[label] = roc_auc_score(all_labels[:, i], all_outputs[:, i])
        except ValueError:
            per_label_auc[label] = np.nan

    try:
        #macro_auc = roc_auc_score(all_labels, all_outputs, average="macro", multi_class="ovo")
        macro_auc = roc_auc_score(all_labels, all_outputs, average="macro")
    except ValueError:
        macro_auc = np.nan

    return running_loss / len(loader.dataset), macro_auc, per_label_auc



In [8]:
class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        BCE_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss)  # pt = p if target=1, 1-p otherwise
        focal_loss = self.alpha * (1-pt)**self.gamma * BCE_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

class WeightedFocalBCELoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0, pos_weights=None):
        super().__init__()
        self.focal = FocalLoss(alpha, gamma)
        self.bce = nn.BCEWithLogitsLoss(pos_weight=pos_weights)
        self.alpha = 0.7  # Weight between focal and BCE
        
    def forward(self, inputs, targets):
        focal_loss = self.focal(inputs, targets)
        bce_loss = self.bce(inputs, targets)
        return self.alpha * focal_loss + (1 - self.alpha) * bce_loss

# Calculate better class weights
def calculate_effective_weights(df, label_cols, beta=0.999):
    """Calculate more robust class weights using effective number of samples"""
    n = len(df)
    weights = []
    for col in label_cols:
        pos_count = df[col].sum()
        neg_count = n - pos_count
        
        # Effective number of samples (from Class-Balanced Loss paper)
        eff_pos = (1 - beta**pos_count) / (1 - beta) if pos_count > 0 else 0
        eff_neg = (1 - beta**neg_count) / (1 - beta) if neg_count > 0 else 0
        
        weight = eff_neg / (eff_pos + 1e-8)  # Avoid division by zero
        weights.append(min(weight, 10.0))  # Cap extreme weights
    
    return torch.tensor(weights).float()

In [9]:
class ModelEMA:
    def __init__(self, model, decay=0.9997, device=None):
        # copy *unwrapped* module (so EMA keys match saved/loaded model keys)
        self.decay = decay
        self.device = device
        self.ema = copy.deepcopy(get_model(model)).eval()
        if device is not None:
            self.ema.to(device)
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        src_state = get_model(model).state_dict()
        with torch.no_grad():
            for k, v in self.ema.state_dict().items():
                model_v = src_state[k].detach().to(v.device)
                v.copy_(v * self.decay + (1.0 - self.decay) * model_v)

    def state_dict(self):
        return self.ema.state_dict()


def get_model(module):
    """Return the underlying model (unwrap DataParallel / DDP)."""
    return module.module if hasattr(module, "module") else module

In [10]:
EPOCHS=6

# -----------------------------
# Run Training
# -----------------------------



fold_results = []

df_all = pd.read_csv('/kaggle/input/grand-xray-slam-division-b/train2.csv')
df_all = df_all[~df_all['Image_name'].isin([
    '00043046_001_001.jpg',
    '00052495_001_001.jpg',
    '00056890_001_001.jpg'
])]

label_cols = [
            "Atelectasis",
            "Cardiomegaly",
            "Consolidation",
            "Edema",
            "Enlarged Cardiomediastinum",
            "Fracture",
            "Lung Lesion",
            "Lung Opacity",
            "No Finding",
            "Pleural Effusion",
            "Pleural Other",
            "Pneumonia",
            "Pneumothorax",
            "Support Devices",
        ]


    

# Create datasets
train_df = df_all.copy()

# -----------------------------
# Model: EVAX base
# -----------------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = EVAX_Model(unfreeze_last_n_blocks=6)
#model = model.to(DEVICE)

# -----------------------------
# Loss & Optimizer
# -----------------------------
#criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
#criterion = FocalLoss(alpha=0.25, gamma=2.0)
pos_weights = calculate_effective_weights(train_df, label_cols, beta=0.9999)
criterion = WeightedFocalBCELoss(alpha=0.25, gamma=2.0, pos_weights=pos_weights.to(DEVICE))


optimizer = torch.optim.AdamW(
    [
        {'params': model.backbone.patch_embed.parameters(), 'lr': 1e-6, 'weight_decay': 1e-6},  
        {'params': model.backbone.rope.parameters(), 'lr': 1e-6, 'weight_decay': 1e-6},  
        {'params': model.backbone.blocks[0:4].parameters(), 'lr': 1e-6, 'weight_decay': 1e-4},  
        {'params': model.backbone.blocks[4:6].parameters(), 'lr': 5e-6, 'weight_decay': 1e-4},  
        {'params': model.backbone.blocks[6:8].parameters(), 'lr': 1e-5, 'weight_decay': 1e-4},  
        {'params': model.backbone.blocks[8:10].parameters(), 'lr': 3e-5, 'weight_decay': 1e-4},  
        {'params': model.backbone.blocks[10:].parameters(), 'lr': 5e-5, 'weight_decay': 1e-4},  
        {'params': model.backbone.fc_norm.parameters(), 'lr': 5e-5, 'weight_decay': 1e-4},  
        {'params': model.backbone.head.parameters(), 'lr': 8e-5, 'weight_decay': 0},      
    ]
)

if torch.cuda.device_count() > 1:
    print(f"⚡ Using {torch.cuda.device_count()} GPUs")
    model = nn.DataParallel(model)  # wrap for multi-GPU

model = model.to(DEVICE)

#EMA
ema = ModelEMA(model, decay=0.9997)

# Dataset
train_dataset = ChestXrayDataset(
    df=train_df,
    img_dir="/kaggle/input/600p-div-b-data/train2_resized",
    transform=train_tfms
)



# DataLoaders
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,   
    #num_workers=4,   
    pin_memory=True
)


for epoch in range(EPOCHS):
    
    print(f"Epoch {epoch+1}/{EPOCHS}")
    
    train_loss = train_one_epoch(model, train_loader, optimizer, criterion, ema)
    #val_loss, val_auc, per_label_auc = validate(ema.ema, val_loader, criterion)

    print(f"[Epoch {epoch+1}] TrainLoss={train_loss:.4f} ")
    # print(f"[Epoch {epoch+1}] TrainLoss={train_loss:.4f} | EMA ValLoss={val_loss:.4f} | AUC={val_auc:.4f}")
    # print("Per-label AUROC:", {k: f"{v:.3f}" for k, v in per_label_auc.items()})

    if epoch > 3:
        #best_auc = val_auc
        save_path = f"best_model_f{epoch+1}.pth"
        torch.save({
            "epoch": epoch+1,
            "model_state_dict": get_model(model).state_dict(),   # unwrap
            "optimizer_state_dict": optimizer.state_dict(),
        }, save_path)
    
        # also save EMA
        if ema is not None:
            torch.save({
                "epoch": epoch+1,
                "ema_state_dict": ema.state_dict(),
            }, f"ema_model_f{epoch+1}.pth")
    
        print(f"✅ Saved new best model")

_IncompatibleKeys(missing_keys=['head.weight', 'head.bias'], unexpected_keys=[])
⚡ Using 2 GPUs
Epoch 1/6


[Train]: 100%|██████████| 3391/3391 [57:22<00:00,  1.02s/it]

[Epoch 1] TrainLoss=0.1086 
Epoch 2/6



[Train]: 100%|██████████| 3391/3391 [57:39<00:00,  1.02s/it]

[Epoch 2] TrainLoss=0.0982 
Epoch 3/6



[Train]: 100%|██████████| 3391/3391 [57:44<00:00,  1.02s/it]

[Epoch 3] TrainLoss=0.0947 
Epoch 4/6



[Train]: 100%|██████████| 3391/3391 [57:48<00:00,  1.02s/it]

[Epoch 4] TrainLoss=0.0915 
Epoch 5/6



[Train]: 100%|██████████| 3391/3391 [57:53<00:00,  1.02s/it]


[Epoch 5] TrainLoss=0.0879 
✅ Saved new best model
Epoch 6/6


[Train]: 100%|██████████| 3391/3391 [57:54<00:00,  1.02s/it]


[Epoch 6] TrainLoss=0.0831 
✅ Saved new best model
