In [1]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
        


In [2]:
import os, random, math, time
from pathlib import Path
from tqdm import tqdm
import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms, models
from torchvision.transforms import RandAugment
import timm 
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from collections import Counter
from torch.autograd import Function

# --- 1. SETUP AND HYPERPARAMETERS (From Notebook Cell 4 & 5) ---
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if device == "cuda": torch.cuda.manual_seed_all(SEED)

IMG_SIZE = 224
BATCH_SIZE = 8         
EPOCHS = 15          
NUM_WORKERS = 4         
LR = 3e-4              
LABEL_SMOOTH = 0.1
SAVE_PATH = "best_model.pth"
USE_SEGMENTATION = True 

# Loss weights 
ALPHA_DOM = 0.5
BETA_SUPCON = 0.3
ETA_CONS = 0.1
GAMMA_SEG = 0.5

# Mixup/CutMix probabilities and alphas
PROB_MIXUP = 0.5
PROB_CUTMIX = 0.5
MIXUP_ALPHA = 0.2
CUTMIX_ALPHA = 1.0

# Warmup epochs and accumulation steps
WARMUP_EPOCHS = 5
EARLY_STOPPING_PATIENCE = 7
FREEZE_EPOCHS = 5
ACCUMULATION_STEPS = 4

# --- 2. TRANSFORMS (From Notebook Cell 6) ---
class AddGaussianNoise(object):
    def __init__(self, mean=0., std=0.05):
        self.mean = mean
        self.std = std
    def __call__(self, tensor):
        noise = torch.randn(tensor.size()) * self.std + self.mean
        noisy_tensor = tensor + noise
        return torch.clamp(noisy_tensor, 0., 1.)
    def __repr__(self):
        return self.__class__.__name__ + f'(mean={self.mean}, std={self.std})'

weak_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    AddGaussianNoise(0., 0.02),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])
strong_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ColorJitter(0.4,0.4,0.4,0.1),
    RandAugment(num_ops=2, magnitude=9),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(30),
    transforms.ToTensor(),
    AddGaussianNoise(0., 0.05),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])
val_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])

# --- 3. DATASET HELPER FUNCTIONS (From Notebook Cell 7) ---
def read_file_with_encoding(file_path, encodings=['utf-8', 'utf-8-sig', 'ISO-8859-1']):
    for encoding in encodings:
        try:
            with open(file_path, 'r', encoding=encoding) as f:
                return f.readlines()
        except UnicodeDecodeError:
            pass
        except Exception as e:
            raise 
    raise RuntimeError(f"Unable to read {file_path} with any of the provided encodings.")

def load_testing_dataset_info(info_file, image_dir):
    image_paths = []
    labels = []
    encodings = ['utf-8-sig', 'utf-8', 'ISO-8859-1', 'latin-1']
    lines = []
    
    for encoding in encodings:
        try:
            with open(info_file, 'r', encoding=encoding) as f:
                lines = f.readlines()
            break
        except UnicodeDecodeError:
            pass
        except Exception as e:
            raise 

    for line in lines:
        parts = line.strip().split()
        if len(parts) == 2:
            image_filename = parts[0]
            try:
                label = int(parts[1])
            except ValueError:
                continue
                
            label = 1 if label == 1 else 0
            image_full_path = os.path.join(image_dir, image_filename)
            image_paths.append(image_full_path)
            labels.append(label)
    
    return image_paths, labels

# --- 4. MULTIDATASET CLASS (From Notebook Cell 7) ---
class MultiDataset(Dataset):
    def __init__(self, root_dirs, txt_files, testing_image_paths=None, testing_labels=None, weak_transform=None, strong_transform=None, use_masks=True):
        self.root_dirs = root_dirs
        self.weak_transform = weak_transform
        self.strong_transform = strong_transform
        self.use_masks = use_masks
        self.samples = []

        if testing_image_paths is not None and testing_labels is not None:
            for img_path, label in zip(testing_image_paths, testing_labels):
                self.samples.append((img_path, label)) 
        
        if isinstance(txt_files, str):
            txt_files = [txt_files]

        all_lines = []
        for t in txt_files:
            if not os.path.exists(t):
                raise RuntimeError(f"TXT file not found: {t}. Please ensure all Kaggle input datasets are mounted.")

            lines = read_file_with_encoding(t)
            all_lines.extend([(line.strip(), t) for line in lines if line.strip()])

        for line, src_txt in all_lines:
            parts = line.split()
            if len(parts) == 0:
                continue

            fname = parts[0]
            if len(parts) >= 2:
                try:
                    lbl = int(parts[1])
                except:
                    lbl = 1 if "CAM" in fname or "cam" in fname else 0
            else:
                lbl = 1 if "CAM" in fname or "cam" in fname else 0
            
            lbl = 1 if lbl == 1 else 0
            base_fname = os.path.basename(fname)  

            found = False
            search_subs = [
                "", "Image", "Imgs", "images", "JPEGImages", "img", 
                "Images/Train", "Images/Test",
            ]
            
            for rdir in self.root_dirs:
                for sub in search_subs:
                    img_path = os.path.join(rdir, sub, base_fname)
                    if os.path.exists(img_path):
                        self.samples.append((img_path, lbl, rdir))
                        found = True
                        break
                if found:
                    break

        if len(self.samples) == 0:
            raise RuntimeError(f"No valid samples found from {txt_files}")

        print(f"✅ Loaded {len(self.samples)} samples from {len(self.root_dirs)} root directories.")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        global IMG_SIZE 
        
        sample = self.samples[idx]
        img_path = sample[0]
        lbl = sample[1]
        
        if len(sample) == 3:
            rdir = sample[2]
        else:
            rdir = os.path.dirname(os.path.dirname(img_path))
            
        try:
            img = Image.open(img_path).convert("RGB")
        except:
            img = Image.fromarray(np.zeros((IMG_SIZE, IMG_SIZE, 3), dtype=np.uint8))
            
        if self.weak_transform:
            weak = self.weak_transform(img)
        else:
            weak = transforms.ToTensor()(img)
            
        if self.strong_transform:
            strong = self.strong_transform(img)
        else:
            strong = weak.clone()

        mask = None
        if self.use_masks:
            mask_name = os.path.splitext(os.path.basename(img_path))[0] + ".png"
            found_mask = False
            for mask_dir in ["GT_Object", "GT", "masks", "Mask"]:
                mask_path = os.path.join(rdir, mask_dir, mask_name)
                
                if os.path.exists(mask_path):
                    m = Image.open(mask_path).convert("L").resize((IMG_SIZE, IMG_SIZE))
                    m = np.array(m).astype(np.float32) / 255.0
                    mask = torch.from_numpy((m > 0.5).astype(np.float32)).unsqueeze(0)
                    found_mask = True
                    break

            if mask is None:
                mask = torch.zeros((1, IMG_SIZE, IMG_SIZE), dtype=torch.float32)
                
        return weak, strong, lbl, mask

# --- 5. BACKBONE EXTRACTORS (From Notebook Cell 9) ---

class DenseNetExtractor(nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()
        self.features = models.densenet201(weights='IMAGENET1K_V1' if pretrained else None).features
    def forward(self, x):
        feats = []
        for name, layer in self.features._modules.items():
            x = layer(x)
            if name in ["denseblock1","denseblock2","denseblock3","denseblock4"]:
                feats.append(x)
        return feats

class MobileNetExtractor(nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()
        self.features = models.mobilenet_v3_large(weights='IMAGENET1K_V1' if pretrained else None).features
    def forward(self, x):
        feats = []
        out = x
        for i, layer in enumerate(self.features):
            out = layer(out)
            if i in (2,5,9,12):
                feats.append(out)
        if len(feats) < 4:
            feats.append(out)
        return feats

# --- 6. USER'S Keras-Style Fusion Model (Corrected) ---

DENSE_CHANNELS = 1920
# CORRECTION: The actual output features are 112, resulting in 2032 total features (1920+112). 
# We MUST use 112 here to resolve the mat1 and mat2 dimension mismatch error (2032 vs 3200).
MOBILE_CHANNELS = 112 
TOTAL_FEATURES = DENSE_CHANNELS + MOBILE_CHANNELS # 1920 + 112 = 2032
NUM_CLASSES = 2 

class KerasStyleFusion(nn.Module):
    def __init__(self, num_classes=NUM_CLASSES, dense_out_channels=DENSE_CHANNELS, mobile_out_channels=MOBILE_CHANNELS):
        super().__init__()
        
        self.densenet_base = DenseNetExtractor(pretrained=True)
        self.mobilenet_base = MobileNetExtractor(pretrained=True)
        
        for param in self.densenet_base.parameters():
            param.requires_grad = False
        for param in self.mobilenet_base.parameters():
            param.requires_grad = False
            
        self.global_pool = nn.AdaptiveAvgPool2d(1)

        in_features = dense_out_channels + mobile_out_channels
        
        # CORRECTED: The first linear layer now correctly expects 2032 input features.
        self.classifier = nn.Sequential(
            nn.Linear(in_features, 512), # nn.Linear(2032, 512)
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes) 
        )
        
        self.use_seg = False 
        self.domain_head = nn.Identity()
        
    def forward(self, x, grl_lambda=0.0):
        densenet_feats = self.densenet_base(x)[-1]
        x1 = self.global_pool(densenet_feats).view(x.size(0), -1)

        mobilenet_feats = self.mobilenet_base(x)[-1]
        x2 = self.global_pool(mobilenet_feats).view(x.size(0), -1)

        concatenated_features = torch.cat([x1, x2], dim=1)

        logits = self.classifier(concatenated_features)
        
        out = {"logits": logits, "feat": concatenated_features}
        out["domain_logits"] = torch.randn(x.size(0), 2).to(x.device) 
        out["seg"] = torch.randn(x.size(0), 1, IMG_SIZE, IMG_SIZE).to(x.device)
        
        return out

# --- 7. LOSS AND ADVERSARIAL HELPERS (From Notebook Cells 17-19) ---

class LabelSmoothingCE(nn.Module):
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.s = smoothing
    def forward(self, logits, target):
        c = logits.size(-1)
        logp = F.log_softmax(logits, dim=-1)
        with torch.no_grad():
            true_dist = torch.zeros_like(logp)
            true_dist.fill_(self.s / (c - 1))
            true_dist.scatter_(1, target.unsqueeze(1), 1.0 - self.s)
        return (-true_dist * logp).sum(dim=-1).mean()

class FocalLoss(nn.Module):
    def __init__(self, gamma=1.5):
        super().__init__()
        self.gamma = gamma
    def forward(self, logits, target):
        prob = F.softmax(logits, dim=1)
        pt = prob.gather(1, target.unsqueeze(1)).squeeze(1)
        ce = F.cross_entropy(logits, target, reduction='none')
        loss = ((1 - pt) ** self.gamma) * ce
        return loss.mean()

def dice_loss(pred, target, smooth=1.0):
    pred = torch.sigmoid(pred)
    num = 2 * (pred * target).sum() + smooth
    den = pred.sum() + target.sum() + smooth
    return 1 - (num / den)

def seg_loss_fn(pred, mask):
    if pred.shape[-2:] != mask.shape[-2:]:
        pred = F.interpolate(pred, size=mask.shape[-2:], mode="bilinear", align_corners=False)
    return F.binary_cross_entropy_with_logits(pred, mask) + dice_loss(pred, mask)

class SupConLoss(nn.Module):
    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature
        self.cos = nn.CosineSimilarity(dim=-1)
    def forward(self, features, labels):
        device = features.device
        f = F.normalize(features, dim=1)
        sim = torch.matmul(f, f.T) / self.temperature
        labels = labels.contiguous().view(-1,1)
        mask = torch.eq(labels, labels.T).float().to(device)
        logits_max, _ = torch.max(sim, dim=1, keepdim=True)
        logits = sim - logits_max.detach()
        exp_logits = torch.exp(logits) * (1 - torch.eye(len(features), device=device))
        denom = exp_logits.sum(1, keepdim=True)
        pos_mask = mask - torch.eye(len(features), device=device)
        pos_exp = (exp_logits * pos_mask).sum(1)
        loss = -torch.log((pos_exp + 1e-8) / (denom + 1e-8) + 1e-12)
        valid = (pos_mask.sum(1) > 0).float()
        loss = (loss * valid).sum() / (valid.sum() + 1e-8)
        return loss

clf_loss_ce = LabelSmoothingCE(LABEL_SMOOTH)
clf_loss_focal = FocalLoss(gamma=1.5)
supcon_loss_fn = SupConLoss(temperature=0.07)

class GradReverse(Function):
    @staticmethod
    def forward(ctx, x, l):
        ctx.l = l
        return x.view_as(x)
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.neg() * ctx.l, None

def grad_reverse(x, l=1.0):
    return GradReverse.apply(x, l)


# --- 8. DATA AUGMENTATION HELPERS (From Notebook Cell 20) ---

def rand_bbox(size, lam):
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

def apply_mixup(x, y, alpha=MIXUP_ALPHA):
    lam = np.random.beta(alpha, alpha)
    idx = torch.randperm(x.size(0))
    mixed_x = lam * x + (1 - lam) * x[idx]
    y_a, y_b = y, y[idx]
    return mixed_x, y_a, y_b, lam

def apply_cutmix(x, y, alpha=CUTMIX_ALPHA):
    lam = np.random.beta(alpha, alpha)
    idx = torch.randperm(x.size(0))
    bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam)
    new_x = x.clone()
    new_x[:, :, bby1:bby2, bbx1:bbx2] = x[idx, :, bby1:bby2, bbx1:bbx2]
    lam_adjusted = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size(-1) * x.size(-2)))
    return new_x, y, y[idx], lam_adjusted


# --- 9. DATA LOADING SETUP (From Notebook Cell 8) ---
info_dir = "/kaggle/input/cod10k/COD10K-v3/Info"
train_dir_cod = "/kaggle/input/cod10k/COD10K-v3/Train" 
test_dir_cod = "/kaggle/input/cod10k/COD10K-v3/Test"  
    
train_cam_txt = os.path.join(info_dir, "CAM_train.txt")
train_noncam_txt = os.path.join(info_dir, "NonCAM_train.txt")
test_cam_txt = os.path.join(info_dir, "CAM_test.txt")
test_noncam_txt = os.path.join(info_dir, "NonCAM_test.txt")

info_dir2 = "/kaggle/input/camo-coco/CAMO_COCO/Info"
train_cam_txt2 = os.path.join(info_dir2, "camo_train.txt")
train_noncam_txt2 = os.path.join(info_dir2, "non_camo_train.txt")
test_cam_txt2 = os.path.join(info_dir2, "camo_test.txt")
test_noncam_txt2 = os.path.join(info_dir2, "non_camo_test.txt")

train_dir_camo_cam = "/kaggle/input/camo-coco/CAMO_COCO/Camouflage"
train_dir_camo_noncam = "/kaggle/input/camo-coco/CAMO_COCO/Non_Camouflage"

testing_info_file = "/kaggle/input/testing-dataset/Info/image_labels.txt"
testing_images_dir = "/kaggle/input/testing-dataset/Images"

# --- MOCKING DATASET LOADING FOR RUNNABILITY ---
try:
# --- Scenario 1 Changes ---
# 1. 80/20 Split of testing-dataset remains the same
    testing_image_paths, testing_labels = load_testing_dataset_info(testing_info_file, testing_images_dir)
    train_paths, val_paths, train_labels, val_labels = train_test_split(
        testing_image_paths, testing_labels, test_size=0.2, random_state=SEED
    )
    ALL_ROOT_DIRS = [
        train_dir_cod,       
        test_dir_cod,       
        train_dir_camo_cam,  
        train_dir_camo_noncam
    ]

# 2. Training includes basic sets + 80% of testing-dataset
    ALL_TRAIN_TXTS = [
        train_cam_txt2, train_noncam_txt2,
    ]

# 2. Validation uses CAMO-COCO test text files ONLY
    ALL_VAL_TXTS = [test_cam_txt2, test_noncam_txt2]

    train_ds = MultiDataset(
        root_dirs=ALL_ROOT_DIRS, txt_files=ALL_TRAIN_TXTS,
        testing_image_paths=train_paths, testing_labels=train_labels,
        weak_transform=weak_tf, strong_transform=strong_tf, use_masks=USE_SEGMENTATION
    )

    val_ds = MultiDataset(
        root_dirs=ALL_ROOT_DIRS,  
        txt_files=ALL_VAL_TXTS,                 
        testing_image_paths=val_paths,          
        testing_labels=val_labels,              
        weak_transform=val_tf, 
        strong_transform=None, 
        use_masks=USE_SEGMENTATION
    )

    def build_weighted_sampler(dataset):
        labels = [sample[1] for sample in dataset.samples]  
        counts = Counter(labels)
        total = len(labels)
        if len(counts) <= 1:
            weights = [1.0] * total
        else:
            class_weights = {c: total / (counts[c] * len(counts)) for c in counts}
            weights = [class_weights[lbl] for lbl in labels]
        return WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)

    train_sampler = build_weighted_sampler(train_ds)
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, sampler=train_sampler, num_workers=NUM_WORKERS)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

except RuntimeError as e:
    print(f"⚠️ Warning: Cannot access Kaggle paths ({e}). Using Mock DataLoaders for demonstration.")
    
    class MockDataset(Dataset):
        def __init__(self, num_samples, num_classes=2, img_size=IMG_SIZE):
            self.num_samples = num_samples
            self.data = torch.randn(num_samples, 3, img_size, img_size)
            self.labels = torch.randint(0, num_classes, (num_samples,))
            self.masks = torch.randint(0, 2, (num_samples, 1, img_size, img_size)).float()

        def __len__(self):
            return self.num_samples

        def __getitem__(self, idx):
            return self.data[idx], self.data[idx].clone(), self.labels[idx], self.masks[idx]

    train_ds = MockDataset(num_samples=14150) 
    val_ds = MockDataset(num_samples=6606)   
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    
    print(f"Mock DataLoaders initialized with {len(train_ds)} train and {len(val_ds)} val samples.")


# --- 10. MODEL INSTANTIATION AND OPTIMIZER SETUP (From Notebook Cell 20) ---

model = KerasStyleFusion().to(device)

backbone_params = []
head_params = []
for name, param in model.named_parameters():
    if any(k in name for k in ['densenet_base', 'mobilenet_base']):
        backbone_params.append(param)
    else:
        head_params.append(param)

opt = torch.optim.AdamW([
    {'params': backbone_params, 'lr': LR * 0.2}, 
    {'params': head_params, 'lr': LR}
], lr=LR, weight_decay=1e-4)

def get_cosine_with_warmup_scheduler(optimizer, warmup_epochs, total_epochs, last_epoch=-1):
    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            return float(epoch) / float(max(1.0, warmup_epochs))
        t = (epoch - warmup_epochs) / float(max(1, total_epochs - warmup_epochs))
        return 0.5 * (1.0 + math.cos(math.pi * t))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)

scheduler = get_cosine_with_warmup_scheduler(opt, WARMUP_EPOCHS, EPOCHS)
scaler = torch.cuda.amp.GradScaler(enabled=(device=="cuda"))

print(f"\nModel instantiated. LR: {LR}, Epochs: {EPOCHS}.")
print(f"Backbones are initially frozen for {FREEZE_EPOCHS} epochs.")

# --- 11. TRAINING LOOP (Adapted from Notebook Cell 21) ---

def compute_combined_clf_loss(logits, targets, mix_info=None, use_focal=False):
    if mix_info is None:
        if use_focal:
            return clf_loss_focal(logits, targets)
        else:
            return clf_loss_ce(logits, targets)
    else:
        y_a, y_b, lam = mix_info
        if use_focal:
            loss = lam * F.cross_entropy(logits, y_a) + (1 - lam) * F.cross_entropy(logits, y_b)
        else:
            loss = lam * clf_loss_ce(logits, y_a) + (1 - lam) * clf_loss_ce(logits, y_b)
        return loss

best_vf1 = 0.0
best_epoch = 0
patience_count = 0

for epoch in range(1, EPOCHS+1):
    # --- Freeze/Unfreeze Logic ---
    if epoch <= FREEZE_EPOCHS:
        for name, p in model.named_parameters():
            if any(k in name for k in ['densenet_base', 'mobilenet_base']):
                p.requires_grad = False
    elif epoch == FREEZE_EPOCHS + 1:
        print(f"--- Unfreezing all backbone layers at epoch {epoch} ---")
        for p in model.parameters():
            p.requires_grad = True

    model.train()
    running_loss = 0.0
    y_true, y_pred = [], []
    n_batches = 0

    opt.zero_grad() 
    
    for i, (weak_imgs, strong_imgs, labels, masks) in enumerate(tqdm(train_loader, desc=f"Train {epoch}/{EPOCHS}")):
        weak_imgs = weak_imgs.to(device); strong_imgs = strong_imgs.to(device)
        labels = labels.to(device)
        if masks is not None:
            masks = masks.to(device)

        imgs = weak_imgs

        mix_info = None
        rand = random.random()
        if rand < PROB_MIXUP:
            imgs, y_a, y_b, lam = apply_mixup(imgs, labels)
            mix_info = (y_a.to(device), y_b.to(device), lam)
        elif rand < PROB_MIXUP + PROB_CUTMIX:
            imgs, y_a, y_b, lam = apply_cutmix(imgs, labels)
            mix_info = (y_a.to(device), y_b.to(device), lam)

        with torch.autocast(device_type=device, enabled=(device=="cuda")):
            out = model(imgs) 
            logits = out["logits"]
            
            clf_loss = compute_combined_clf_loss(logits, labels, mix_info=mix_info, use_focal=False)

            # Auxiliary losses are set to 0 as the user's simple fusion model excludes these complex heads.
            seg_loss = 0.0
            supcon_loss = 0.0 
            cons_loss = 0.0   
            dom_loss = 0.0

            total_loss = clf_loss + GAMMA_SEG * seg_loss + BETA_SUPCON * supcon_loss + ETA_CONS * cons_loss + ALPHA_DOM * dom_loss
            
            total_loss = total_loss / ACCUMULATION_STEPS 

        scaler.scale(total_loss).backward()

        if (i + 1) % ACCUMULATION_STEPS == 0:
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(opt)
            scaler.update()
            opt.zero_grad() 

        running_loss += total_loss.item() * ACCUMULATION_STEPS
        y_true.extend(labels.cpu().numpy())
        y_pred.extend(logits.argmax(1).cpu().numpy())
        n_batches += 1

    if n_batches % ACCUMULATION_STEPS != 0:
        scaler.unscale_(opt)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(opt)
        scaler.update()
        opt.zero_grad()

    scheduler.step()

    acc = accuracy_score(y_true, y_pred)
    prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="macro", zero_division=0)
    print(f"[Epoch {epoch}] Train Loss: {running_loss/max(1,n_batches):.4f} Acc: {acc:.4f} Prec: {prec:.4f} Rec: {rec:.4f} F1: {f1:.4f}")

    # -------------------
    # VALIDATION
    # -------------------
    model.eval()
    val_y_true, val_y_pred = [], []
    val_loss = 0.0
    with torch.no_grad():
        for weak_imgs, _, labels, masks in val_loader:
            imgs = weak_imgs.to(device)
            labels = labels.to(device)
            if masks is not None:
                masks = masks.to(device)

            out = model(imgs)
            logits = out["logits"]
            
            loss = compute_combined_clf_loss(logits, labels, mix_info=None, use_focal=False)
            if USE_SEGMENTATION and (masks is not None):
                loss += GAMMA_SEG * seg_loss_fn(out["seg"], masks) 
            val_loss += loss.item()

            val_y_true.extend(labels.cpu().numpy())
            val_y_pred.extend(logits.argmax(1).cpu().numpy())

    vacc = accuracy_score(val_y_true, val_y_pred)
    vprec, vrec, vf1, _ = precision_recall_fscore_support(val_y_true, val_y_pred, average="macro", zero_division=0)
    print(f"[Epoch {epoch}] Val Loss: {val_loss/max(1,len(val_loader)):.4f} Acc: {vacc:.4f} Prec: {vprec:.4f} Rec: {vrec:.4f} F1: {vf1:.4f}")

    if vf1 > best_vf1:
        best_vf1 = vf1
        best_epoch = epoch
        torch.save({
            "epoch": epoch,
            "model_state": model.state_dict(),
            "opt_state": opt.state_dict(),
            "best_vf1": best_vf1
        }, SAVE_PATH)
        patience_count = 0
        print(f"Saved best model at epoch {epoch} (F1 {best_vf1:.4f})")
    else:
        patience_count += 1
        if patience_count >= EARLY_STOPPING_PATIENCE:
            print("Early stopping triggered.")
            break

print("\nTraining finished. Best val F1:", best_vf1, "at epoch", best_epoch)



Device: cuda
✅ Loaded 8605 samples from 4 root directories.
✅ Loaded 2152 samples from 4 root directories.
Downloading: "https://download.pytorch.org/models/densenet201-c1103571.pth" to /root/.cache/torch/hub/checkpoints/densenet201-c1103571.pth


100%|██████████| 77.4M/77.4M [00:00<00:00, 174MB/s]


Downloading: "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v3_large-8738ca79.pth


100%|██████████| 21.1M/21.1M [00:00<00:00, 154MB/s]
  scaler = torch.cuda.amp.GradScaler(enabled=(device=="cuda"))



Model instantiated. LR: 0.0003, Epochs: 15.
Backbones are initially frozen for 5 epochs.


Train 1/15: 100%|██████████| 1076/1076 [01:41<00:00, 10.62it/s]

[Epoch 1] Train Loss: 0.6930 Acc: 0.5069 Prec: 0.5094 Rec: 0.5041 F1: 0.4239





[Epoch 1] Val Loss: 1.5834 Acc: 0.3917 Prec: 0.1959 Rec: 0.5000 F1: 0.2815
Saved best model at epoch 1 (F1 0.2815)


Train 2/15: 100%|██████████| 1076/1076 [01:34<00:00, 11.40it/s]


[Epoch 2] Train Loss: 0.5399 Acc: 0.7626 Prec: 0.7659 Rec: 0.7616 F1: 0.7614
[Epoch 2] Val Loss: 1.3412 Acc: 0.9572 Prec: 0.9512 Rec: 0.9625 F1: 0.9558
Saved best model at epoch 2 (F1 0.9558)


Train 3/15: 100%|██████████| 1076/1076 [01:33<00:00, 11.52it/s]


[Epoch 3] Train Loss: 0.4726 Acc: 0.8099 Prec: 0.8099 Rec: 0.8099 F1: 0.8099
[Epoch 3] Val Loss: 1.3550 Acc: 0.9586 Prec: 0.9536 Rec: 0.9614 F1: 0.9570
Saved best model at epoch 3 (F1 0.9570)


Train 4/15: 100%|██████████| 1076/1076 [01:33<00:00, 11.54it/s]


[Epoch 4] Train Loss: 0.4661 Acc: 0.8092 Prec: 0.8092 Rec: 0.8091 F1: 0.8091
[Epoch 4] Val Loss: 1.3739 Acc: 0.9633 Prec: 0.9575 Rec: 0.9686 F1: 0.9620
Saved best model at epoch 4 (F1 0.9620)


Train 5/15: 100%|██████████| 1076/1076 [01:32<00:00, 11.60it/s]


[Epoch 5] Train Loss: 0.4670 Acc: 0.8123 Prec: 0.8123 Rec: 0.8123 F1: 0.8123
[Epoch 5] Val Loss: 1.3442 Acc: 0.9689 Prec: 0.9634 Rec: 0.9738 F1: 0.9677
Saved best model at epoch 5 (F1 0.9677)
--- Unfreezing all backbone layers at epoch 6 ---


Train 6/15: 100%|██████████| 1076/1076 [03:16<00:00,  5.47it/s]

[Epoch 6] Train Loss: 0.4449 Acc: 0.8212 Prec: 0.8211 Rec: 0.8211 F1: 0.8211





[Epoch 6] Val Loss: 1.3147 Acc: 0.9814 Prec: 0.9780 Rec: 0.9837 F1: 0.9806
Saved best model at epoch 6 (F1 0.9806)


Train 7/15: 100%|██████████| 1076/1076 [03:01<00:00,  5.94it/s]

[Epoch 7] Train Loss: 0.4329 Acc: 0.8306 Prec: 0.8305 Rec: 0.8305 F1: 0.8305





[Epoch 7] Val Loss: 1.2976 Acc: 0.9814 Prec: 0.9776 Rec: 0.9843 F1: 0.9806
Saved best model at epoch 7 (F1 0.9806)


Train 8/15: 100%|██████████| 1076/1076 [03:00<00:00,  5.97it/s]

[Epoch 8] Train Loss: 0.4275 Acc: 0.8460 Prec: 0.8460 Rec: 0.8460 F1: 0.8460





[Epoch 8] Val Loss: 1.2932 Acc: 0.9842 Prec: 0.9808 Rec: 0.9868 F1: 0.9835
Saved best model at epoch 8 (F1 0.9835)


Train 9/15: 100%|██████████| 1076/1076 [03:01<00:00,  5.92it/s]

[Epoch 9] Train Loss: 0.4242 Acc: 0.8492 Prec: 0.8491 Rec: 0.8491 F1: 0.8491





[Epoch 9] Val Loss: 1.3021 Acc: 0.9782 Prec: 0.9736 Rec: 0.9820 F1: 0.9773


Train 10/15: 100%|██████████| 1076/1076 [03:00<00:00,  5.98it/s]

[Epoch 10] Train Loss: 0.4250 Acc: 0.8429 Prec: 0.8429 Rec: 0.8429 F1: 0.8429





[Epoch 10] Val Loss: 1.3025 Acc: 0.9796 Prec: 0.9753 Rec: 0.9830 F1: 0.9787


Train 11/15: 100%|██████████| 1076/1076 [02:59<00:00,  5.99it/s]

[Epoch 11] Train Loss: 0.4181 Acc: 0.8515 Prec: 0.8514 Rec: 0.8514 F1: 0.8514





[Epoch 11] Val Loss: 1.2674 Acc: 0.9875 Prec: 0.9845 Rec: 0.9897 F1: 0.9869
Saved best model at epoch 11 (F1 0.9869)


Train 12/15: 100%|██████████| 1076/1076 [02:59<00:00,  6.01it/s]

[Epoch 12] Train Loss: 0.4233 Acc: 0.8482 Prec: 0.8482 Rec: 0.8482 F1: 0.8482





[Epoch 12] Val Loss: 1.2537 Acc: 0.9870 Prec: 0.9841 Rec: 0.9891 F1: 0.9864


Train 13/15: 100%|██████████| 1076/1076 [03:00<00:00,  5.95it/s]

[Epoch 13] Train Loss: 0.4148 Acc: 0.8558 Prec: 0.8558 Rec: 0.8558 F1: 0.8558





[Epoch 13] Val Loss: 1.2731 Acc: 0.9870 Prec: 0.9839 Rec: 0.9893 F1: 0.9864


Train 14/15: 100%|██████████| 1076/1076 [02:59<00:00,  5.98it/s]

[Epoch 14] Train Loss: 0.4195 Acc: 0.8628 Prec: 0.8628 Rec: 0.8628 F1: 0.8628





[Epoch 14] Val Loss: 1.2593 Acc: 0.9865 Prec: 0.9834 Rec: 0.9889 F1: 0.9859


Train 15/15: 100%|██████████| 1076/1076 [03:00<00:00,  5.96it/s]

[Epoch 15] Train Loss: 0.4246 Acc: 0.8454 Prec: 0.8454 Rec: 0.8454 F1: 0.8454





[Epoch 15] Val Loss: 1.2599 Acc: 0.9865 Prec: 0.9834 Rec: 0.9889 F1: 0.9859

Training finished. Best val F1: 0.9869087458202683 at epoch 11


In [3]:
import os, random, math, time
from pathlib import Path
from tqdm import tqdm
import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms, models
from torchvision.transforms import RandAugment
import timm 
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from collections import Counter
from torch.autograd import Function

# --- 1. SETUP AND HYPERPARAMETERS (From Notebook Cell 4 & 5) ---
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if device == "cuda": torch.cuda.manual_seed_all(SEED)

BATCH_SIZE = 8         
EPOCHS = 15          
NUM_WORKERS = 4         
LR = 3e-4              
LABEL_SMOOTH = 0.1
SAVE_PATH = "best_model.pth"
USE_SEGMENTATION = True 

# Loss weights 
ALPHA_DOM = 0.5
BETA_SUPCON = 0.3
ETA_CONS = 0.1
GAMMA_SEG = 0.5

# Mixup/CutMix probabilities and alphas
PROB_MIXUP = 0.5
PROB_CUTMIX = 0.5
MIXUP_ALPHA = 0.2
CUTMIX_ALPHA = 1.0

# Warmup epochs and accumulation steps
WARMUP_EPOCHS = 5
EARLY_STOPPING_PATIENCE = 7
FREEZE_EPOCHS = 5
ACCUMULATION_STEPS = 4

# --- 2. TRANSFORMS (From Notebook Cell 6) ---
class AddGaussianNoise(object):
    def __init__(self, mean=0., std=0.05):
        self.mean = mean
        self.std = std
    def __call__(self, tensor):
        noise = torch.randn(tensor.size()) * self.std + self.mean
        noisy_tensor = tensor + noise
        return torch.clamp(noisy_tensor, 0., 1.)
    def __repr__(self):
        return self.__class__.__name__ + f'(mean={self.mean}, std={self.std})'

weak_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    AddGaussianNoise(0., 0.02),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])
strong_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ColorJitter(0.4,0.4,0.4,0.1),
    RandAugment(num_ops=2, magnitude=9),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(30),
    transforms.ToTensor(),
    AddGaussianNoise(0., 0.05),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])
val_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])

# --- 3. DATASET HELPER FUNCTIONS (From Notebook Cell 7) ---
def read_file_with_encoding(file_path, encodings=['utf-8', 'utf-8-sig', 'ISO-8859-1']):
    for encoding in encodings:
        try:
            with open(file_path, 'r', encoding=encoding) as f:
                return f.readlines()
        except UnicodeDecodeError:
            pass
        except Exception as e:
            raise 
    raise RuntimeError(f"Unable to read {file_path} with any of the provided encodings.")

def load_testing_dataset_info(info_file, image_dir):
    image_paths = []
    labels = []
    encodings = ['utf-8-sig', 'utf-8', 'ISO-8859-1', 'latin-1']
    lines = []
    
    for encoding in encodings:
        try:
            with open(info_file, 'r', encoding=encoding) as f:
                lines = f.readlines()
            break
        except UnicodeDecodeError:
            pass
        except Exception as e:
            raise 

    for line in lines:
        parts = line.strip().split()
        if len(parts) == 2:
            image_filename = parts[0]
            try:
                label = int(parts[1])
            except ValueError:
                continue
                
            label = 1 if label == 1 else 0
            image_full_path = os.path.join(image_dir, image_filename)
            image_paths.append(image_full_path)
            labels.append(label)
    
    return image_paths, labels

# --- 4. MULTIDATASET CLASS (From Notebook Cell 7) ---
class MultiDataset(Dataset):
    def __init__(self, root_dirs, txt_files, testing_image_paths=None, testing_labels=None, weak_transform=None, strong_transform=None, use_masks=True):
        self.root_dirs = root_dirs
        self.weak_transform = weak_transform
        self.strong_transform = strong_transform
        self.use_masks = use_masks
        self.samples = []

        if testing_image_paths is not None and testing_labels is not None:
            for img_path, label in zip(testing_image_paths, testing_labels):
                self.samples.append((img_path, label)) 
        
        if isinstance(txt_files, str):
            txt_files = [txt_files]

        all_lines = []
        for t in txt_files:
            if not os.path.exists(t):
                raise RuntimeError(f"TXT file not found: {t}. Please ensure all Kaggle input datasets are mounted.")

            lines = read_file_with_encoding(t)
            all_lines.extend([(line.strip(), t) for line in lines if line.strip()])

        for line, src_txt in all_lines:
            parts = line.split()
            if len(parts) == 0:
                continue

            fname = parts[0]
            if len(parts) >= 2:
                try:
                    lbl = int(parts[1])
                except:
                    lbl = 1 if "CAM" in fname or "cam" in fname else 0
            else:
                lbl = 1 if "CAM" in fname or "cam" in fname else 0
            
            lbl = 1 if lbl == 1 else 0
            base_fname = os.path.basename(fname)  

            found = False
            search_subs = [
                "", "Image", "Imgs", "images", "JPEGImages", "img", 
                "Images/Train", "Images/Test",
            ]
            
            for rdir in self.root_dirs:
                for sub in search_subs:
                    img_path = os.path.join(rdir, sub, base_fname)
                    if os.path.exists(img_path):
                        self.samples.append((img_path, lbl, rdir))
                        found = True
                        break
                if found:
                    break

        if len(self.samples) == 0:
            raise RuntimeError(f"No valid samples found from {txt_files}")

        print(f"✅ Loaded {len(self.samples)} samples from {len(self.root_dirs)} root directories.")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        global IMG_SIZE 
        
        sample = self.samples[idx]
        img_path = sample[0]
        lbl = sample[1]
        
        if len(sample) == 3:
            rdir = sample[2]
        else:
            rdir = os.path.dirname(os.path.dirname(img_path))
            
        try:
            img = Image.open(img_path).convert("RGB")
        except:
            img = Image.fromarray(np.zeros((IMG_SIZE, IMG_SIZE, 3), dtype=np.uint8))
            
        if self.weak_transform:
            weak = self.weak_transform(img)
        else:
            weak = transforms.ToTensor()(img)
            
        if self.strong_transform:
            strong = self.strong_transform(img)
        else:
            strong = weak.clone()

        mask = None
        if self.use_masks:
            mask_name = os.path.splitext(os.path.basename(img_path))[0] + ".png"
            found_mask = False
            for mask_dir in ["GT_Object", "GT", "masks", "Mask"]:
                mask_path = os.path.join(rdir, mask_dir, mask_name)
                
                if os.path.exists(mask_path):
                    m = Image.open(mask_path).convert("L").resize((IMG_SIZE, IMG_SIZE))
                    m = np.array(m).astype(np.float32) / 255.0
                    mask = torch.from_numpy((m > 0.5).astype(np.float32)).unsqueeze(0)
                    found_mask = True
                    break

            if mask is None:
                mask = torch.zeros((1, IMG_SIZE, IMG_SIZE), dtype=torch.float32)
                
        return weak, strong, lbl, mask

# --- 5. BACKBONE EXTRACTORS (From Notebook Cell 9) ---

class DenseNetExtractor(nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()
        self.features = models.densenet201(weights='IMAGENET1K_V1' if pretrained else None).features
    def forward(self, x):
        feats = []
        for name, layer in self.features._modules.items():
            x = layer(x)
            if name in ["denseblock1","denseblock2","denseblock3","denseblock4"]:
                feats.append(x)
        return feats

class InceptionExtractor(nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()
        # We load the full model but only use the features
        inception = models.inception_v3(weights='IMAGENET1K_V1' if pretrained else None)
        inception.transform_input = True # Normalizes internally for Inception
        
        # Extract features (everything except the final pooling and FC)
        self.features = nn.Sequential(
            inception.Conv2d_1a_3x3,
            inception.Conv2d_2a_3x3,
            inception.Conv2d_2b_3x3,
            nn.MaxPool2d(kernel_size=3, stride=2),
            inception.Conv2d_3b_1x1,
            inception.Conv2d_4a_3x3,
            nn.MaxPool2d(kernel_size=3, stride=2),
            inception.Mixed_5b, inception.Mixed_5c, inception.Mixed_5d,
            inception.Mixed_6a, inception.Mixed_6b, inception.Mixed_6c,
            inception.Mixed_6d, inception.Mixed_6e,
            inception.Mixed_7a, inception.Mixed_7b, inception.Mixed_7c
        )

    def forward(self, x):
        # InceptionV3 expects 299x299. If input is different, we interpolate.
        if x.shape[-1] != 299:
            x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False)
        
        x = self.features(x)
        return [x] # Returning as list to maintain compatibility with your loop
# --- 6. USER'S Keras-Style Fusion Model (Corrected) ---

IMG_SIZE = 299  # Updated for InceptionV3
# ... other hyperparameters remain same ...
INCEPTION_CHANNELS = 2048 # InceptionV3 final feature count
DENSE_CHANNELS = 1920
TOTAL_FEATURES = DENSE_CHANNELS + INCEPTION_CHANNELS

NUM_CLASSES = 2 

class KerasStyleFusion(nn.Module):
    def __init__(self, num_classes=NUM_CLASSES):
        super().__init__()
        
        self.densenet_base = DenseNetExtractor(pretrained=True)
        self.inception_base = InceptionExtractor(pretrained=True)
        
        # Freeze backbones initially
        for param in self.densenet_base.parameters():
            param.requires_grad = False
        for param in self.inception_base.parameters():
            param.requires_grad = False
            
        self.global_pool = nn.AdaptiveAvgPool2d(1)

        # 1920 (DenseNet) + 2048 (Inception) = 3968
        in_features = DENSE_CHANNELS + INCEPTION_CHANNELS
        
        self.classifier = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes) 
        )
        
    def forward(self, x):
        # Get DenseNet features (expects 224, but works at 299)
        d_feats = self.densenet_base(x)[-1]
        x1 = self.global_pool(d_feats).view(x.size(0), -1)

        # Get Inception features (expects 299)
        i_feats = self.inception_base(x)[-1]
        x2 = self.global_pool(i_feats).view(x.size(0), -1)

        concatenated_features = torch.cat([x1, x2], dim=1)
        logits = self.classifier(concatenated_features)
        
        return {
            "logits": logits, 
            "feat": concatenated_features,
            "seg": torch.randn(x.size(0), 1, x.shape[-2], x.shape[-1]).to(x.device) # Placeholder
        }

# --- 7. LOSS AND ADVERSARIAL HELPERS (From Notebook Cells 17-19) ---

class LabelSmoothingCE(nn.Module):
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.s = smoothing
    def forward(self, logits, target):
        c = logits.size(-1)
        logp = F.log_softmax(logits, dim=-1)
        with torch.no_grad():
            true_dist = torch.zeros_like(logp)
            true_dist.fill_(self.s / (c - 1))
            true_dist.scatter_(1, target.unsqueeze(1), 1.0 - self.s)
        return (-true_dist * logp).sum(dim=-1).mean()

class FocalLoss(nn.Module):
    def __init__(self, gamma=1.5):
        super().__init__()
        self.gamma = gamma
    def forward(self, logits, target):
        prob = F.softmax(logits, dim=1)
        pt = prob.gather(1, target.unsqueeze(1)).squeeze(1)
        ce = F.cross_entropy(logits, target, reduction='none')
        loss = ((1 - pt) ** self.gamma) * ce
        return loss.mean()

def dice_loss(pred, target, smooth=1.0):
    pred = torch.sigmoid(pred)
    num = 2 * (pred * target).sum() + smooth
    den = pred.sum() + target.sum() + smooth
    return 1 - (num / den)

def seg_loss_fn(pred, mask):
    if pred.shape[-2:] != mask.shape[-2:]:
        pred = F.interpolate(pred, size=mask.shape[-2:], mode="bilinear", align_corners=False)
    return F.binary_cross_entropy_with_logits(pred, mask) + dice_loss(pred, mask)

class SupConLoss(nn.Module):
    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature
        self.cos = nn.CosineSimilarity(dim=-1)
    def forward(self, features, labels):
        device = features.device
        f = F.normalize(features, dim=1)
        sim = torch.matmul(f, f.T) / self.temperature
        labels = labels.contiguous().view(-1,1)
        mask = torch.eq(labels, labels.T).float().to(device)
        logits_max, _ = torch.max(sim, dim=1, keepdim=True)
        logits = sim - logits_max.detach()
        exp_logits = torch.exp(logits) * (1 - torch.eye(len(features), device=device))
        denom = exp_logits.sum(1, keepdim=True)
        pos_mask = mask - torch.eye(len(features), device=device)
        pos_exp = (exp_logits * pos_mask).sum(1)
        loss = -torch.log((pos_exp + 1e-8) / (denom + 1e-8) + 1e-12)
        valid = (pos_mask.sum(1) > 0).float()
        loss = (loss * valid).sum() / (valid.sum() + 1e-8)
        return loss

clf_loss_ce = LabelSmoothingCE(LABEL_SMOOTH)
clf_loss_focal = FocalLoss(gamma=1.5)
supcon_loss_fn = SupConLoss(temperature=0.07)

class GradReverse(Function):
    @staticmethod
    def forward(ctx, x, l):
        ctx.l = l
        return x.view_as(x)
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.neg() * ctx.l, None

def grad_reverse(x, l=1.0):
    return GradReverse.apply(x, l)


# --- 8. DATA AUGMENTATION HELPERS (From Notebook Cell 20) ---

def rand_bbox(size, lam):
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

def apply_mixup(x, y, alpha=MIXUP_ALPHA):
    lam = np.random.beta(alpha, alpha)
    idx = torch.randperm(x.size(0))
    mixed_x = lam * x + (1 - lam) * x[idx]
    y_a, y_b = y, y[idx]
    return mixed_x, y_a, y_b, lam

def apply_cutmix(x, y, alpha=CUTMIX_ALPHA):
    lam = np.random.beta(alpha, alpha)
    idx = torch.randperm(x.size(0))
    bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam)
    new_x = x.clone()
    new_x[:, :, bby1:bby2, bbx1:bbx2] = x[idx, :, bby1:bby2, bbx1:bbx2]
    lam_adjusted = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size(-1) * x.size(-2)))
    return new_x, y, y[idx], lam_adjusted


# --- 9. DATA LOADING SETUP (From Notebook Cell 8) ---
info_dir = "/kaggle/input/cod10k/COD10K-v3/Info"
train_dir_cod = "/kaggle/input/cod10k/COD10K-v3/Train" 
test_dir_cod = "/kaggle/input/cod10k/COD10K-v3/Test"  
    
train_cam_txt = os.path.join(info_dir, "CAM_train.txt")
train_noncam_txt = os.path.join(info_dir, "NonCAM_train.txt")
test_cam_txt = os.path.join(info_dir, "CAM_test.txt")
test_noncam_txt = os.path.join(info_dir, "NonCAM_test.txt")

info_dir2 = "/kaggle/input/camo-coco/CAMO_COCO/Info"
train_cam_txt2 = os.path.join(info_dir2, "camo_train.txt")
train_noncam_txt2 = os.path.join(info_dir2, "non_camo_train.txt")
test_cam_txt2 = os.path.join(info_dir2, "camo_test.txt")
test_noncam_txt2 = os.path.join(info_dir2, "non_camo_test.txt")

train_dir_camo_cam = "/kaggle/input/camo-coco/CAMO_COCO/Camouflage"
train_dir_camo_noncam = "/kaggle/input/camo-coco/CAMO_COCO/Non_Camouflage"

testing_info_file = "/kaggle/input/testing-dataset/Info/image_labels.txt"
testing_images_dir = "/kaggle/input/testing-dataset/Images"

# --- MOCKING DATASET LOADING FOR RUNNABILITY ---
try:
# --- Scenario 1 Changes ---
# 1. 80/20 Split of testing-dataset remains the same
    testing_image_paths, testing_labels = load_testing_dataset_info(testing_info_file, testing_images_dir)
    train_paths, val_paths, train_labels, val_labels = train_test_split(
        testing_image_paths, testing_labels, test_size=0.2, random_state=SEED
    )
    ALL_ROOT_DIRS = [
        train_dir_cod,       
        test_dir_cod,       
        train_dir_camo_cam,  
        train_dir_camo_noncam
    ]

# 2. Training includes basic sets + 80% of testing-dataset
    ALL_TRAIN_TXTS = [
        train_cam_txt2, train_noncam_txt2,
    ]

# 2. Validation uses CAMO-COCO test text files ONLY
    ALL_VAL_TXTS = [test_cam_txt2, test_noncam_txt2]

    train_ds = MultiDataset(
        root_dirs=ALL_ROOT_DIRS, txt_files=ALL_TRAIN_TXTS,
        testing_image_paths=train_paths, testing_labels=train_labels,
        weak_transform=weak_tf, strong_transform=strong_tf, use_masks=USE_SEGMENTATION
    )

    val_ds = MultiDataset(
        root_dirs=ALL_ROOT_DIRS,  
        txt_files=ALL_VAL_TXTS,                 
        testing_image_paths=val_paths,          
        testing_labels=val_labels,              
        weak_transform=val_tf, 
        strong_transform=None, 
        use_masks=USE_SEGMENTATION
    )

    def build_weighted_sampler(dataset):
        labels = [sample[1] for sample in dataset.samples]  
        counts = Counter(labels)
        total = len(labels)
        if len(counts) <= 1:
            weights = [1.0] * total
        else:
            class_weights = {c: total / (counts[c] * len(counts)) for c in counts}
            weights = [class_weights[lbl] for lbl in labels]
        return WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)

    train_sampler = build_weighted_sampler(train_ds)
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, sampler=train_sampler, num_workers=NUM_WORKERS)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

except RuntimeError as e:
    print(f"⚠️ Warning: Cannot access Kaggle paths ({e}). Using Mock DataLoaders for demonstration.")
    
    class MockDataset(Dataset):
        def __init__(self, num_samples, num_classes=2, img_size=IMG_SIZE):
            self.num_samples = num_samples
            self.data = torch.randn(num_samples, 3, img_size, img_size)
            self.labels = torch.randint(0, num_classes, (num_samples,))
            self.masks = torch.randint(0, 2, (num_samples, 1, img_size, img_size)).float()

        def __len__(self):
            return self.num_samples

        def __getitem__(self, idx):
            return self.data[idx], self.data[idx].clone(), self.labels[idx], self.masks[idx]

    train_ds = MockDataset(num_samples=14150) 
    val_ds = MockDataset(num_samples=6606)   
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    
    print(f"Mock DataLoaders initialized with {len(train_ds)} train and {len(val_ds)} val samples.")


# --- 10. MODEL INSTANTIATION AND OPTIMIZER SETUP (From Notebook Cell 20) ---

model = KerasStyleFusion().to(device)

backbone_params = []
head_params = []
for name, param in model.named_parameters():
    # Updated key check for Inception
    if any(k in name for k in ['densenet_base', 'inception_base']):
        backbone_params.append(param)
    else:
        head_params.append(param)

opt = torch.optim.AdamW([
    {'params': backbone_params, 'lr': LR * 0.1}, # Lower LR for heavy backbones
    {'params': head_params, 'lr': LR}
], lr=LR, weight_decay=1e-4)

def get_cosine_with_warmup_scheduler(optimizer, warmup_epochs, total_epochs, last_epoch=-1):
    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            return float(epoch) / float(max(1.0, warmup_epochs))
        t = (epoch - warmup_epochs) / float(max(1, total_epochs - warmup_epochs))
        return 0.5 * (1.0 + math.cos(math.pi * t))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)

scheduler = get_cosine_with_warmup_scheduler(opt, WARMUP_EPOCHS, EPOCHS)
scaler = torch.cuda.amp.GradScaler(enabled=(device=="cuda"))

print(f"\nModel instantiated. LR: {LR}, Epochs: {EPOCHS}.")
print(f"Backbones are initially frozen for {FREEZE_EPOCHS} epochs.")

# --- 11. TRAINING LOOP (Adapted from Notebook Cell 21) ---

def compute_combined_clf_loss(logits, targets, mix_info=None, use_focal=False):
    if mix_info is None:
        if use_focal:
            return clf_loss_focal(logits, targets)
        else:
            return clf_loss_ce(logits, targets)
    else:
        y_a, y_b, lam = mix_info
        if use_focal:
            loss = lam * F.cross_entropy(logits, y_a) + (1 - lam) * F.cross_entropy(logits, y_b)
        else:
            loss = lam * clf_loss_ce(logits, y_a) + (1 - lam) * clf_loss_ce(logits, y_b)
        return loss

best_vf1 = 0.0
best_epoch = 0
patience_count = 0

for epoch in range(1, EPOCHS+1):
    # --- Freeze/Unfreeze Logic ---
    if epoch <= FREEZE_EPOCHS:
        for name, p in model.named_parameters():
            if any(k in name for k in ['densenet_base', 'mobilenet_base']):
                p.requires_grad = False
    elif epoch == FREEZE_EPOCHS + 1:
        print(f"--- Unfreezing all backbone layers at epoch {epoch} ---")
        for p in model.parameters():
            p.requires_grad = True

    model.train()
    running_loss = 0.0
    y_true, y_pred = [], []
    n_batches = 0

    opt.zero_grad() 
    
    for i, (weak_imgs, strong_imgs, labels, masks) in enumerate(tqdm(train_loader, desc=f"Train {epoch}/{EPOCHS}")):
        weak_imgs = weak_imgs.to(device); strong_imgs = strong_imgs.to(device)
        labels = labels.to(device)
        if masks is not None:
            masks = masks.to(device)

        imgs = weak_imgs

        mix_info = None
        rand = random.random()
        if rand < PROB_MIXUP:
            imgs, y_a, y_b, lam = apply_mixup(imgs, labels)
            mix_info = (y_a.to(device), y_b.to(device), lam)
        elif rand < PROB_MIXUP + PROB_CUTMIX:
            imgs, y_a, y_b, lam = apply_cutmix(imgs, labels)
            mix_info = (y_a.to(device), y_b.to(device), lam)

        with torch.autocast(device_type=device, enabled=(device=="cuda")):
            out = model(imgs) 
            logits = out["logits"]
            
            clf_loss = compute_combined_clf_loss(logits, labels, mix_info=mix_info, use_focal=False)

            # Auxiliary losses are set to 0 as the user's simple fusion model excludes these complex heads.
            seg_loss = 0.0
            supcon_loss = 0.0 
            cons_loss = 0.0   
            dom_loss = 0.0

            total_loss = clf_loss + GAMMA_SEG * seg_loss + BETA_SUPCON * supcon_loss + ETA_CONS * cons_loss + ALPHA_DOM * dom_loss
            
            total_loss = total_loss / ACCUMULATION_STEPS 

        scaler.scale(total_loss).backward()

        if (i + 1) % ACCUMULATION_STEPS == 0:
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(opt)
            scaler.update()
            opt.zero_grad() 

        running_loss += total_loss.item() * ACCUMULATION_STEPS
        y_true.extend(labels.cpu().numpy())
        y_pred.extend(logits.argmax(1).cpu().numpy())
        n_batches += 1

    if n_batches % ACCUMULATION_STEPS != 0:
        scaler.unscale_(opt)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(opt)
        scaler.update()
        opt.zero_grad()

    scheduler.step()

    acc = accuracy_score(y_true, y_pred)
    prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="macro", zero_division=0)
    print(f"[Epoch {epoch}] Train Loss: {running_loss/max(1,n_batches):.4f} Acc: {acc:.4f} Prec: {prec:.4f} Rec: {rec:.4f} F1: {f1:.4f}")

    # -------------------
    # VALIDATION
    # -------------------
    model.eval()
    val_y_true, val_y_pred = [], []
    val_loss = 0.0
    with torch.no_grad():
        for weak_imgs, _, labels, masks in val_loader:
            imgs = weak_imgs.to(device)
            labels = labels.to(device)
            if masks is not None:
                masks = masks.to(device)

            out = model(imgs)
            logits = out["logits"]
            
            loss = compute_combined_clf_loss(logits, labels, mix_info=None, use_focal=False)
            if USE_SEGMENTATION and (masks is not None):
                loss += GAMMA_SEG * seg_loss_fn(out["seg"], masks) 
            val_loss += loss.item()

            val_y_true.extend(labels.cpu().numpy())
            val_y_pred.extend(logits.argmax(1).cpu().numpy())

    vacc = accuracy_score(val_y_true, val_y_pred)
    vprec, vrec, vf1, _ = precision_recall_fscore_support(val_y_true, val_y_pred, average="macro", zero_division=0)
    print(f"[Epoch {epoch}] Val Loss: {val_loss/max(1,len(val_loader)):.4f} Acc: {vacc:.4f} Prec: {vprec:.4f} Rec: {vrec:.4f} F1: {vf1:.4f}")

    if vf1 > best_vf1:
        best_vf1 = vf1
        best_epoch = epoch
        torch.save({
            "epoch": epoch,
            "model_state": model.state_dict(),
            "opt_state": opt.state_dict(),
            "best_vf1": best_vf1
        }, SAVE_PATH)
        patience_count = 0
        print(f"Saved best model at epoch {epoch} (F1 {best_vf1:.4f})")
    else:
        patience_count += 1
        if patience_count >= EARLY_STOPPING_PATIENCE:
            print("Early stopping triggered.")
            break

print("\nTraining finished. Best val F1:", best_vf1, "at epoch", best_epoch)

Device: cuda
✅ Loaded 8605 samples from 4 root directories.
✅ Loaded 2152 samples from 4 root directories.
Downloading: "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth" to /root/.cache/torch/hub/checkpoints/inception_v3_google-0cc3c7bd.pth


100%|██████████| 104M/104M [00:00<00:00, 183MB/s] 
  scaler = torch.cuda.amp.GradScaler(enabled=(device=="cuda"))



Model instantiated. LR: 0.0003, Epochs: 15.
Backbones are initially frozen for 5 epochs.


Train 1/15: 100%|██████████| 1076/1076 [01:42<00:00, 10.54it/s]

[Epoch 1] Train Loss: 0.6917 Acc: 0.5117 Prec: 0.5417 Rec: 0.5134 F1: 0.4133





[Epoch 1] Val Loss: 1.5427 Acc: 0.6083 Prec: 0.3041 Rec: 0.5000 F1: 0.3782
Saved best model at epoch 1 (F1 0.3782)


Train 2/15: 100%|██████████| 1076/1076 [01:42<00:00, 10.52it/s]

[Epoch 2] Train Loss: 0.5270 Acc: 0.7568 Prec: 0.7568 Rec: 0.7567 F1: 0.7567





[Epoch 2] Val Loss: 1.2601 Acc: 0.9480 Prec: 0.9580 Rec: 0.9350 F1: 0.9441
Saved best model at epoch 2 (F1 0.9441)


Train 3/15: 100%|██████████| 1076/1076 [01:43<00:00, 10.37it/s]

[Epoch 3] Train Loss: 0.4999 Acc: 0.7890 Prec: 0.7890 Rec: 0.7890 F1: 0.7890





[Epoch 3] Val Loss: 1.2406 Acc: 0.9642 Prec: 0.9706 Rec: 0.9554 F1: 0.9619
Saved best model at epoch 3 (F1 0.9619)


Train 4/15: 100%|██████████| 1076/1076 [01:42<00:00, 10.46it/s]

[Epoch 4] Train Loss: 0.4886 Acc: 0.7873 Prec: 0.7872 Rec: 0.7872 F1: 0.7872





[Epoch 4] Val Loss: 1.2246 Acc: 0.9828 Prec: 0.9847 Rec: 0.9793 F1: 0.9819
Saved best model at epoch 4 (F1 0.9819)


Train 5/15: 100%|██████████| 1076/1076 [01:41<00:00, 10.57it/s]

[Epoch 5] Train Loss: 0.4877 Acc: 0.7872 Prec: 0.7872 Rec: 0.7872 F1: 0.7872





[Epoch 5] Val Loss: 1.2113 Acc: 0.9893 Prec: 0.9876 Rec: 0.9902 F1: 0.9888
Saved best model at epoch 5 (F1 0.9888)
--- Unfreezing all backbone layers at epoch 6 ---


Train 6/15: 100%|██████████| 1076/1076 [03:27<00:00,  5.19it/s]

[Epoch 6] Train Loss: 0.4555 Acc: 0.8189 Prec: 0.8189 Rec: 0.8189 F1: 0.8189





[Epoch 6] Val Loss: 1.2054 Acc: 0.9851 Prec: 0.9817 Rec: 0.9878 F1: 0.9845


Train 7/15: 100%|██████████| 1076/1076 [03:27<00:00,  5.18it/s]

[Epoch 7] Train Loss: 0.4372 Acc: 0.8307 Prec: 0.8307 Rec: 0.8307 F1: 0.8307





[Epoch 7] Val Loss: 1.1997 Acc: 0.9926 Prec: 0.9907 Rec: 0.9939 F1: 0.9922
Saved best model at epoch 7 (F1 0.9922)


Train 8/15: 100%|██████████| 1076/1076 [03:28<00:00,  5.17it/s]

[Epoch 8] Train Loss: 0.4343 Acc: 0.8450 Prec: 0.8450 Rec: 0.8449 F1: 0.8450





[Epoch 8] Val Loss: 1.1961 Acc: 0.9944 Prec: 0.9930 Rec: 0.9954 F1: 0.9942
Saved best model at epoch 8 (F1 0.9942)


Train 9/15: 100%|██████████| 1076/1076 [03:26<00:00,  5.21it/s]

[Epoch 9] Train Loss: 0.4272 Acc: 0.8402 Prec: 0.8402 Rec: 0.8402 F1: 0.8402





[Epoch 9] Val Loss: 1.1964 Acc: 0.9935 Prec: 0.9920 Rec: 0.9944 F1: 0.9932


Train 10/15: 100%|██████████| 1076/1076 [03:27<00:00,  5.18it/s]

[Epoch 10] Train Loss: 0.4277 Acc: 0.8418 Prec: 0.8418 Rec: 0.8418 F1: 0.8418





[Epoch 10] Val Loss: 1.1982 Acc: 0.9907 Prec: 0.9886 Rec: 0.9921 F1: 0.9903


Train 11/15: 100%|██████████| 1076/1076 [03:27<00:00,  5.19it/s]

[Epoch 11] Train Loss: 0.4243 Acc: 0.8404 Prec: 0.8404 Rec: 0.8404 F1: 0.8404





[Epoch 11] Val Loss: 1.1977 Acc: 0.9916 Prec: 0.9897 Rec: 0.9929 F1: 0.9913


Train 12/15: 100%|██████████| 1076/1076 [03:28<00:00,  5.16it/s]

[Epoch 12] Train Loss: 0.4245 Acc: 0.8537 Prec: 0.8536 Rec: 0.8536 F1: 0.8536





[Epoch 12] Val Loss: 1.1969 Acc: 0.9921 Prec: 0.9903 Rec: 0.9933 F1: 0.9917


Train 13/15: 100%|██████████| 1076/1076 [03:28<00:00,  5.15it/s]

[Epoch 13] Train Loss: 0.4211 Acc: 0.8526 Prec: 0.8527 Rec: 0.8526 F1: 0.8526





[Epoch 13] Val Loss: 1.1961 Acc: 0.9921 Prec: 0.9903 Rec: 0.9933 F1: 0.9917


Train 14/15: 100%|██████████| 1076/1076 [03:27<00:00,  5.19it/s]

[Epoch 14] Train Loss: 0.4245 Acc: 0.8501 Prec: 0.8501 Rec: 0.8501 F1: 0.8501





[Epoch 14] Val Loss: 1.1944 Acc: 0.9940 Prec: 0.9926 Rec: 0.9948 F1: 0.9937


Train 15/15: 100%|██████████| 1076/1076 [03:28<00:00,  5.15it/s]

[Epoch 15] Train Loss: 0.4273 Acc: 0.8436 Prec: 0.8436 Rec: 0.8436 F1: 0.8436





[Epoch 15] Val Loss: 1.1949 Acc: 0.9926 Prec: 0.9909 Rec: 0.9937 F1: 0.9922
Early stopping triggered.

Training finished. Best val F1: 0.9941640519703103 at epoch 8


In [4]:
import os, random, math, time
from pathlib import Path
from tqdm import tqdm
import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms, models
from torchvision.transforms import RandAugment
import timm 
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from collections import Counter
from torch.autograd import Function

# --- 1. SETUP AND HYPERPARAMETERS (From Notebook Cell 4 & 5) ---
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if device == "cuda": torch.cuda.manual_seed_all(SEED)

IMG_SIZE = 224
BATCH_SIZE = 8         
EPOCHS = 15          
NUM_WORKERS = 4         
LR = 3e-4              
LABEL_SMOOTH = 0.1
SAVE_PATH = "best_model.pth"
USE_SEGMENTATION = True 

# Loss weights 
ALPHA_DOM = 0.5
BETA_SUPCON = 0.3
ETA_CONS = 0.1
GAMMA_SEG = 0.5

# Mixup/CutMix probabilities and alphas
PROB_MIXUP = 0.5
PROB_CUTMIX = 0.5
MIXUP_ALPHA = 0.2
CUTMIX_ALPHA = 1.0

# Warmup epochs and accumulation steps
WARMUP_EPOCHS = 5
EARLY_STOPPING_PATIENCE = 7
FREEZE_EPOCHS = 5
ACCUMULATION_STEPS = 4

# --- 2. TRANSFORMS (From Notebook Cell 6) ---
class AddGaussianNoise(object):
    def __init__(self, mean=0., std=0.05):
        self.mean = mean
        self.std = std
    def __call__(self, tensor):
        noise = torch.randn(tensor.size()) * self.std + self.mean
        noisy_tensor = tensor + noise
        return torch.clamp(noisy_tensor, 0., 1.)
    def __repr__(self):
        return self.__class__.__name__ + f'(mean={self.mean}, std={self.std})'

weak_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    AddGaussianNoise(0., 0.02),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])
strong_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ColorJitter(0.4,0.4,0.4,0.1),
    RandAugment(num_ops=2, magnitude=9),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(30),
    transforms.ToTensor(),
    AddGaussianNoise(0., 0.05),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])
val_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])

# --- 3. DATASET HELPER FUNCTIONS (From Notebook Cell 7) ---
def read_file_with_encoding(file_path, encodings=['utf-8', 'utf-8-sig', 'ISO-8859-1']):
    for encoding in encodings:
        try:
            with open(file_path, 'r', encoding=encoding) as f:
                return f.readlines()
        except UnicodeDecodeError:
            pass
        except Exception as e:
            raise 
    raise RuntimeError(f"Unable to read {file_path} with any of the provided encodings.")

def load_testing_dataset_info(info_file, image_dir):
    image_paths = []
    labels = []
    encodings = ['utf-8-sig', 'utf-8', 'ISO-8859-1', 'latin-1']
    lines = []
    
    for encoding in encodings:
        try:
            with open(info_file, 'r', encoding=encoding) as f:
                lines = f.readlines()
            break
        except UnicodeDecodeError:
            pass
        except Exception as e:
            raise 

    for line in lines:
        parts = line.strip().split()
        if len(parts) == 2:
            image_filename = parts[0]
            try:
                label = int(parts[1])
            except ValueError:
                continue
                
            label = 1 if label == 1 else 0
            image_full_path = os.path.join(image_dir, image_filename)
            image_paths.append(image_full_path)
            labels.append(label)
    
    return image_paths, labels

# --- 4. MULTIDATASET CLASS (From Notebook Cell 7) ---
class MultiDataset(Dataset):
    def __init__(self, root_dirs, txt_files, testing_image_paths=None, testing_labels=None, weak_transform=None, strong_transform=None, use_masks=True):
        self.root_dirs = root_dirs
        self.weak_transform = weak_transform
        self.strong_transform = strong_transform
        self.use_masks = use_masks
        self.samples = []

        if testing_image_paths is not None and testing_labels is not None:
            for img_path, label in zip(testing_image_paths, testing_labels):
                self.samples.append((img_path, label)) 
        
        if isinstance(txt_files, str):
            txt_files = [txt_files]

        all_lines = []
        for t in txt_files:
            if not os.path.exists(t):
                raise RuntimeError(f"TXT file not found: {t}. Please ensure all Kaggle input datasets are mounted.")

            lines = read_file_with_encoding(t)
            all_lines.extend([(line.strip(), t) for line in lines if line.strip()])

        for line, src_txt in all_lines:
            parts = line.split()
            if len(parts) == 0:
                continue

            fname = parts[0]
            if len(parts) >= 2:
                try:
                    lbl = int(parts[1])
                except:
                    lbl = 1 if "CAM" in fname or "cam" in fname else 0
            else:
                lbl = 1 if "CAM" in fname or "cam" in fname else 0
            
            lbl = 1 if lbl == 1 else 0
            base_fname = os.path.basename(fname)  

            found = False
            search_subs = [
                "", "Image", "Imgs", "images", "JPEGImages", "img", 
                "Images/Train", "Images/Test",
            ]
            
            for rdir in self.root_dirs:
                for sub in search_subs:
                    img_path = os.path.join(rdir, sub, base_fname)
                    if os.path.exists(img_path):
                        self.samples.append((img_path, lbl, rdir))
                        found = True
                        break
                if found:
                    break

        if len(self.samples) == 0:
            raise RuntimeError(f"No valid samples found from {txt_files}")

        print(f"✅ Loaded {len(self.samples)} samples from {len(self.root_dirs)} root directories.")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        global IMG_SIZE 
        
        sample = self.samples[idx]
        img_path = sample[0]
        lbl = sample[1]
        
        if len(sample) == 3:
            rdir = sample[2]
        else:
            rdir = os.path.dirname(os.path.dirname(img_path))
            
        try:
            img = Image.open(img_path).convert("RGB")
        except:
            img = Image.fromarray(np.zeros((IMG_SIZE, IMG_SIZE, 3), dtype=np.uint8))
            
        if self.weak_transform:
            weak = self.weak_transform(img)
        else:
            weak = transforms.ToTensor()(img)
            
        if self.strong_transform:
            strong = self.strong_transform(img)
        else:
            strong = weak.clone()

        mask = None
        if self.use_masks:
            mask_name = os.path.splitext(os.path.basename(img_path))[0] + ".png"
            found_mask = False
            for mask_dir in ["GT_Object", "GT", "masks", "Mask"]:
                mask_path = os.path.join(rdir, mask_dir, mask_name)
                
                if os.path.exists(mask_path):
                    m = Image.open(mask_path).convert("L").resize((IMG_SIZE, IMG_SIZE))
                    m = np.array(m).astype(np.float32) / 255.0
                    mask = torch.from_numpy((m > 0.5).astype(np.float32)).unsqueeze(0)
                    found_mask = True
                    break

            if mask is None:
                mask = torch.zeros((1, IMG_SIZE, IMG_SIZE), dtype=torch.float32)
                
        return weak, strong, lbl, mask

# --- 5. BACKBONE EXTRACTORS (From Notebook Cell 9) ---

# --- UPDATED CONSTANTS ---
# DenseNet169 final features: 1664
# MobileNetV3 Large final features before global pool: 112 (or 160/960 depending on layer, 
# but based on your previous code 112 is the expected mid-layer output)
DENSE_CHANNELS = 1664  
MOBILE_CHANNELS = 112  
TOTAL_FEATURES = DENSE_CHANNELS + MOBILE_CHANNELS # 1664 + 112 = 1776

class DenseNetExtractor(nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()
        # Changed from densenet201 to densenet169
        self.features = models.densenet169(weights='IMAGENET1K_V1' if pretrained else None).features
        
    def forward(self, x):
        feats = []
        for name, layer in self.features._modules.items():
            x = layer(x)
            # The block names remain the same across DenseNet variants
            if name in ["denseblock1", "denseblock2", "denseblock3", "denseblock4"]:
                feats.append(x)
        return feats

class MobileNetExtractor(nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()
        self.features = models.mobilenet_v3_large(weights='IMAGENET1K_V1' if pretrained else None).features
    def forward(self, x):
        feats = []
        out = x
        for i, layer in enumerate(self.features):
            out = layer(out)
            if i in (2,5,9,12):
                feats.append(out)
        if len(feats) < 4:
            feats.append(out)
        return feats

# --- 6. USER'S Keras-Style Fusion Model (Corrected) ---


NUM_CLASSES = 2 

class KerasStyleFusion(nn.Module):
    def __init__(self, num_classes=NUM_CLASSES, dense_out_channels=DENSE_CHANNELS, mobile_out_channels=MOBILE_CHANNELS):
        super().__init__()
        
        self.densenet_base = DenseNetExtractor(pretrained=True)
        self.mobilenet_base = MobileNetExtractor(pretrained=True)
        
        # Freezing logic
        for param in self.densenet_base.parameters():
            param.requires_grad = False
        for param in self.mobilenet_base.parameters():
            param.requires_grad = False
            
        self.global_pool = nn.AdaptiveAvgPool2d(1)

        in_features = dense_out_channels + mobile_out_channels
        
        # Input features are now 1776 (1664 + 112)
        self.classifier = nn.Sequential(
            nn.Linear(in_features, 512), 
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes) 
        )
        
    def forward(self, x, grl_lambda=0.0):
        # Extract features from DenseNet169
        densenet_feats = self.densenet_base(x)[-1]
        x1 = self.global_pool(densenet_feats).view(x.size(0), -1)

        # Extract features from MobileNetV3
        mobilenet_feats = self.mobilenet_base(x)[-1]
        x2 = self.global_pool(mobilenet_feats).view(x.size(0), -1)

        # Fusion
        concatenated_features = torch.cat([x1, x2], dim=1)

        logits = self.classifier(concatenated_features)
        
        # Maintaining your dictionary output structure
        out = {"logits": logits, "feat": concatenated_features}
        out["domain_logits"] = torch.zeros(x.size(0), 2).to(x.device) 
        out["seg"] = torch.zeros(x.size(0), 1, IMG_SIZE, IMG_SIZE).to(x.device)
        
        return out

# --- 7. LOSS AND ADVERSARIAL HELPERS (From Notebook Cells 17-19) ---

class LabelSmoothingCE(nn.Module):
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.s = smoothing
    def forward(self, logits, target):
        c = logits.size(-1)
        logp = F.log_softmax(logits, dim=-1)
        with torch.no_grad():
            true_dist = torch.zeros_like(logp)
            true_dist.fill_(self.s / (c - 1))
            true_dist.scatter_(1, target.unsqueeze(1), 1.0 - self.s)
        return (-true_dist * logp).sum(dim=-1).mean()

class FocalLoss(nn.Module):
    def __init__(self, gamma=1.5):
        super().__init__()
        self.gamma = gamma
    def forward(self, logits, target):
        prob = F.softmax(logits, dim=1)
        pt = prob.gather(1, target.unsqueeze(1)).squeeze(1)
        ce = F.cross_entropy(logits, target, reduction='none')
        loss = ((1 - pt) ** self.gamma) * ce
        return loss.mean()

def dice_loss(pred, target, smooth=1.0):
    pred = torch.sigmoid(pred)
    num = 2 * (pred * target).sum() + smooth
    den = pred.sum() + target.sum() + smooth
    return 1 - (num / den)

def seg_loss_fn(pred, mask):
    if pred.shape[-2:] != mask.shape[-2:]:
        pred = F.interpolate(pred, size=mask.shape[-2:], mode="bilinear", align_corners=False)
    return F.binary_cross_entropy_with_logits(pred, mask) + dice_loss(pred, mask)

class SupConLoss(nn.Module):
    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature
        self.cos = nn.CosineSimilarity(dim=-1)
    def forward(self, features, labels):
        device = features.device
        f = F.normalize(features, dim=1)
        sim = torch.matmul(f, f.T) / self.temperature
        labels = labels.contiguous().view(-1,1)
        mask = torch.eq(labels, labels.T).float().to(device)
        logits_max, _ = torch.max(sim, dim=1, keepdim=True)
        logits = sim - logits_max.detach()
        exp_logits = torch.exp(logits) * (1 - torch.eye(len(features), device=device))
        denom = exp_logits.sum(1, keepdim=True)
        pos_mask = mask - torch.eye(len(features), device=device)
        pos_exp = (exp_logits * pos_mask).sum(1)
        loss = -torch.log((pos_exp + 1e-8) / (denom + 1e-8) + 1e-12)
        valid = (pos_mask.sum(1) > 0).float()
        loss = (loss * valid).sum() / (valid.sum() + 1e-8)
        return loss

clf_loss_ce = LabelSmoothingCE(LABEL_SMOOTH)
clf_loss_focal = FocalLoss(gamma=1.5)
supcon_loss_fn = SupConLoss(temperature=0.07)

class GradReverse(Function):
    @staticmethod
    def forward(ctx, x, l):
        ctx.l = l
        return x.view_as(x)
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.neg() * ctx.l, None

def grad_reverse(x, l=1.0):
    return GradReverse.apply(x, l)


# --- 8. DATA AUGMENTATION HELPERS (From Notebook Cell 20) ---

def rand_bbox(size, lam):
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

def apply_mixup(x, y, alpha=MIXUP_ALPHA):
    lam = np.random.beta(alpha, alpha)
    idx = torch.randperm(x.size(0))
    mixed_x = lam * x + (1 - lam) * x[idx]
    y_a, y_b = y, y[idx]
    return mixed_x, y_a, y_b, lam

def apply_cutmix(x, y, alpha=CUTMIX_ALPHA):
    lam = np.random.beta(alpha, alpha)
    idx = torch.randperm(x.size(0))
    bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam)
    new_x = x.clone()
    new_x[:, :, bby1:bby2, bbx1:bbx2] = x[idx, :, bby1:bby2, bbx1:bbx2]
    lam_adjusted = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size(-1) * x.size(-2)))
    return new_x, y, y[idx], lam_adjusted


# --- 9. DATA LOADING SETUP (From Notebook Cell 8) ---
info_dir = "/kaggle/input/cod10k/COD10K-v3/Info"
train_dir_cod = "/kaggle/input/cod10k/COD10K-v3/Train" 
test_dir_cod = "/kaggle/input/cod10k/COD10K-v3/Test"  
    
train_cam_txt = os.path.join(info_dir, "CAM_train.txt")
train_noncam_txt = os.path.join(info_dir, "NonCAM_train.txt")
test_cam_txt = os.path.join(info_dir, "CAM_test.txt")
test_noncam_txt = os.path.join(info_dir, "NonCAM_test.txt")

info_dir2 = "/kaggle/input/camo-coco/CAMO_COCO/Info"
train_cam_txt2 = os.path.join(info_dir2, "camo_train.txt")
train_noncam_txt2 = os.path.join(info_dir2, "non_camo_train.txt")
test_cam_txt2 = os.path.join(info_dir2, "camo_test.txt")
test_noncam_txt2 = os.path.join(info_dir2, "non_camo_test.txt")

train_dir_camo_cam = "/kaggle/input/camo-coco/CAMO_COCO/Camouflage"
train_dir_camo_noncam = "/kaggle/input/camo-coco/CAMO_COCO/Non_Camouflage"

testing_info_file = "/kaggle/input/testing-dataset/Info/image_labels.txt"
testing_images_dir = "/kaggle/input/testing-dataset/Images"

# --- MOCKING DATASET LOADING FOR RUNNABILITY ---
try:
    testing_image_paths, testing_labels = load_testing_dataset_info(testing_info_file, testing_images_dir)
    train_paths, val_paths, train_labels, val_labels = train_test_split(
        testing_image_paths, testing_labels, test_size=0.8, random_state=SEED
    )
    ALL_ROOT_DIRS = [
        train_dir_cod,       
        test_dir_cod,       
        train_dir_camo_cam,  
        train_dir_camo_noncam
    ]
    ALL_TRAIN_TXTS = [
        train_cam_txt2, train_noncam_txt2,
    ]
    ALL_VAL_TXTS = [test_cam_txt2, test_noncam_txt2]
    
    train_ds = MultiDataset(
        root_dirs=ALL_ROOT_DIRS, 
        txt_files=ALL_TRAIN_TXTS,               
        testing_image_paths=train_paths,        
        testing_labels=train_labels,            
        weak_transform=weak_tf, 
        strong_transform=strong_tf, 
        use_masks=USE_SEGMENTATION
    )
    val_ds = MultiDataset(
        root_dirs=ALL_ROOT_DIRS,  
        txt_files=ALL_VAL_TXTS,                 
        testing_image_paths=val_paths,          
        testing_labels=val_labels,              
        weak_transform=val_tf, 
        strong_transform=None, 
        use_masks=USE_SEGMENTATION
    )

    def build_weighted_sampler(dataset):
        labels = [sample[1] for sample in dataset.samples]  
        counts = Counter(labels)
        total = len(labels)
        if len(counts) <= 1:
            weights = [1.0] * total
        else:
            class_weights = {c: total / (counts[c] * len(counts)) for c in counts}
            weights = [class_weights[lbl] for lbl in labels]
        return WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)

    train_sampler = build_weighted_sampler(train_ds)
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, sampler=train_sampler, num_workers=NUM_WORKERS)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

except RuntimeError as e:
    print(f"⚠️ Warning: Cannot access Kaggle paths ({e}). Using Mock DataLoaders for demonstration.")
    
    class MockDataset(Dataset):
        def __init__(self, num_samples, num_classes=2, img_size=IMG_SIZE):
            self.num_samples = num_samples
            self.data = torch.randn(num_samples, 3, img_size, img_size)
            self.labels = torch.randint(0, num_classes, (num_samples,))
            self.masks = torch.randint(0, 2, (num_samples, 1, img_size, img_size)).float()

        def __len__(self):
            return self.num_samples

        def __getitem__(self, idx):
            return self.data[idx], self.data[idx].clone(), self.labels[idx], self.masks[idx]

    train_ds = MockDataset(num_samples=14150) 
    val_ds = MockDataset(num_samples=6606)   
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    
    print(f"Mock DataLoaders initialized with {len(train_ds)} train and {len(val_ds)} val samples.")


# --- 10. MODEL INSTANTIATION AND OPTIMIZER SETUP (From Notebook Cell 20) ---

model = KerasStyleFusion().to(device)

backbone_params = []
head_params = []
for name, param in model.named_parameters():
    if any(k in name for k in ['densenet_base', 'mobilenet_base']):
        backbone_params.append(param)
    else:
        head_params.append(param)

opt = torch.optim.AdamW([
    {'params': backbone_params, 'lr': LR * 0.2}, 
    {'params': head_params, 'lr': LR}
], lr=LR, weight_decay=1e-4)

def get_cosine_with_warmup_scheduler(optimizer, warmup_epochs, total_epochs, last_epoch=-1):
    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            return float(epoch) / float(max(1.0, warmup_epochs))
        t = (epoch - warmup_epochs) / float(max(1, total_epochs - warmup_epochs))
        return 0.5 * (1.0 + math.cos(math.pi * t))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)

scheduler = get_cosine_with_warmup_scheduler(opt, WARMUP_EPOCHS, EPOCHS)
scaler = torch.cuda.amp.GradScaler(enabled=(device=="cuda"))

print(f"\nModel instantiated. LR: {LR}, Epochs: {EPOCHS}.")
print(f"Backbones are initially frozen for {FREEZE_EPOCHS} epochs.")

# --- 11. TRAINING LOOP (Adapted from Notebook Cell 21) ---

def compute_combined_clf_loss(logits, targets, mix_info=None, use_focal=False):
    if mix_info is None:
        if use_focal:
            return clf_loss_focal(logits, targets)
        else:
            return clf_loss_ce(logits, targets)
    else:
        y_a, y_b, lam = mix_info
        if use_focal:
            loss = lam * F.cross_entropy(logits, y_a) + (1 - lam) * F.cross_entropy(logits, y_b)
        else:
            loss = lam * clf_loss_ce(logits, y_a) + (1 - lam) * clf_loss_ce(logits, y_b)
        return loss

best_vf1 = 0.0
best_epoch = 0
patience_count = 0

for epoch in range(1, EPOCHS+1):
    # --- Freeze/Unfreeze Logic ---
    if epoch <= FREEZE_EPOCHS:
        for name, p in model.named_parameters():
            if any(k in name for k in ['densenet_base', 'mobilenet_base']):
                p.requires_grad = False
    elif epoch == FREEZE_EPOCHS + 1:
        print(f"--- Unfreezing all backbone layers at epoch {epoch} ---")
        for p in model.parameters():
            p.requires_grad = True

    model.train()
    running_loss = 0.0
    y_true, y_pred = [], []
    n_batches = 0

    opt.zero_grad() 
    
    for i, (weak_imgs, strong_imgs, labels, masks) in enumerate(tqdm(train_loader, desc=f"Train {epoch}/{EPOCHS}")):
        weak_imgs = weak_imgs.to(device); strong_imgs = strong_imgs.to(device)
        labels = labels.to(device)
        if masks is not None:
            masks = masks.to(device)

        imgs = weak_imgs

        mix_info = None
        rand = random.random()
        if rand < PROB_MIXUP:
            imgs, y_a, y_b, lam = apply_mixup(imgs, labels)
            mix_info = (y_a.to(device), y_b.to(device), lam)
        elif rand < PROB_MIXUP + PROB_CUTMIX:
            imgs, y_a, y_b, lam = apply_cutmix(imgs, labels)
            mix_info = (y_a.to(device), y_b.to(device), lam)

        with torch.autocast(device_type=device, enabled=(device=="cuda")):
            out = model(imgs) 
            logits = out["logits"]
            
            clf_loss = compute_combined_clf_loss(logits, labels, mix_info=mix_info, use_focal=False)

            # Auxiliary losses are set to 0 as the user's simple fusion model excludes these complex heads.
            seg_loss = 0.0
            supcon_loss = 0.0 
            cons_loss = 0.0   
            dom_loss = 0.0

            total_loss = clf_loss + GAMMA_SEG * seg_loss + BETA_SUPCON * supcon_loss + ETA_CONS * cons_loss + ALPHA_DOM * dom_loss
            
            total_loss = total_loss / ACCUMULATION_STEPS 

        scaler.scale(total_loss).backward()

        if (i + 1) % ACCUMULATION_STEPS == 0:
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(opt)
            scaler.update()
            opt.zero_grad() 

        running_loss += total_loss.item() * ACCUMULATION_STEPS
        y_true.extend(labels.cpu().numpy())
        y_pred.extend(logits.argmax(1).cpu().numpy())
        n_batches += 1

    if n_batches % ACCUMULATION_STEPS != 0:
        scaler.unscale_(opt)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(opt)
        scaler.update()
        opt.zero_grad()

    scheduler.step()

    acc = accuracy_score(y_true, y_pred)
    prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="macro", zero_division=0)
    print(f"[Epoch {epoch}] Train Loss: {running_loss/max(1,n_batches):.4f} Acc: {acc:.4f} Prec: {prec:.4f} Rec: {rec:.4f} F1: {f1:.4f}")

    # -------------------
    # VALIDATION
    # -------------------
    model.eval()
    val_y_true, val_y_pred = [], []
    val_loss = 0.0
    with torch.no_grad():
        for weak_imgs, _, labels, masks in val_loader:
            imgs = weak_imgs.to(device)
            labels = labels.to(device)
            if masks is not None:
                masks = masks.to(device)

            out = model(imgs)
            logits = out["logits"]
            
            loss = compute_combined_clf_loss(logits, labels, mix_info=None, use_focal=False)
            if USE_SEGMENTATION and (masks is not None):
                loss += GAMMA_SEG * seg_loss_fn(out["seg"], masks) 
            val_loss += loss.item()

            val_y_true.extend(labels.cpu().numpy())
            val_y_pred.extend(logits.argmax(1).cpu().numpy())

    vacc = accuracy_score(val_y_true, val_y_pred)
    vprec, vrec, vf1, _ = precision_recall_fscore_support(val_y_true, val_y_pred, average="macro", zero_division=0)
    print(f"[Epoch {epoch}] Val Loss: {val_loss/max(1,len(val_loader)):.4f} Acc: {vacc:.4f} Prec: {vprec:.4f} Rec: {vrec:.4f} F1: {vf1:.4f}")

    if vf1 > best_vf1:
        best_vf1 = vf1
        best_epoch = epoch
        torch.save({
            "epoch": epoch,
            "model_state": model.state_dict(),
            "opt_state": opt.state_dict(),
            "best_vf1": best_vf1
        }, SAVE_PATH)
        patience_count = 0
        print(f"Saved best model at epoch {epoch} (F1 {best_vf1:.4f})")
    else:
        patience_count += 1
        if patience_count >= EARLY_STOPPING_PATIENCE:
            print("Early stopping triggered.")
            break

print("\nTraining finished. Best val F1:", best_vf1, "at epoch", best_epoch)

Device: cuda
✅ Loaded 3651 samples from 4 root directories.
✅ Loaded 7106 samples from 4 root directories.
Downloading: "https://download.pytorch.org/models/densenet169-b2777c0a.pth" to /root/.cache/torch/hub/checkpoints/densenet169-b2777c0a.pth


100%|██████████| 54.7M/54.7M [00:00<00:00, 191MB/s]
  scaler = torch.cuda.amp.GradScaler(enabled=(device=="cuda"))



Model instantiated. LR: 0.0003, Epochs: 15.
Backbones are initially frozen for 5 epochs.


Train 1/15: 100%|██████████| 457/457 [00:41<00:00, 11.05it/s]

[Epoch 1] Train Loss: 0.6929 Acc: 0.5004 Prec: 0.5075 Rec: 0.5019 F1: 0.3880





[Epoch 1] Val Loss: 1.5345 Acc: 0.5317 Prec: 0.3766 Rec: 0.4956 F1: 0.3509
Saved best model at epoch 1 (F1 0.3509)


Train 2/15: 100%|██████████| 457/457 [00:38<00:00, 11.81it/s]

[Epoch 2] Train Loss: 0.6189 Acc: 0.7425 Prec: 0.7462 Rec: 0.7420 F1: 0.7413





[Epoch 2] Val Loss: 1.2672 Acc: 0.9648 Prec: 0.9641 Rec: 0.9659 F1: 0.9647
Saved best model at epoch 2 (F1 0.9647)


Train 3/15: 100%|██████████| 457/457 [00:38<00:00, 11.73it/s]

[Epoch 3] Train Loss: 0.4852 Acc: 0.7990 Prec: 0.7989 Rec: 0.7989 F1: 0.7989





[Epoch 3] Val Loss: 1.2966 Acc: 0.9454 Prec: 0.9457 Rec: 0.9445 F1: 0.9450


Train 4/15: 100%|██████████| 457/457 [00:38<00:00, 11.75it/s]

[Epoch 4] Train Loss: 0.4687 Acc: 0.7962 Prec: 0.7962 Rec: 0.7962 F1: 0.7962





[Epoch 4] Val Loss: 1.2953 Acc: 0.9671 Prec: 0.9666 Rec: 0.9689 F1: 0.9670
Saved best model at epoch 4 (F1 0.9670)


Train 5/15: 100%|██████████| 457/457 [00:39<00:00, 11.68it/s]

[Epoch 5] Train Loss: 0.4721 Acc: 0.8124 Prec: 0.8119 Rec: 0.8121 F1: 0.8120





[Epoch 5] Val Loss: 1.3059 Acc: 0.9706 Prec: 0.9699 Rec: 0.9720 F1: 0.9705
Saved best model at epoch 5 (F1 0.9705)
--- Unfreezing all backbone layers at epoch 6 ---


Train 6/15: 100%|██████████| 457/457 [01:14<00:00,  6.13it/s]

[Epoch 6] Train Loss: 0.4652 Acc: 0.8209 Prec: 0.8208 Rec: 0.8209 F1: 0.8208





[Epoch 6] Val Loss: 1.2901 Acc: 0.9740 Prec: 0.9733 Rec: 0.9748 F1: 0.9739
Saved best model at epoch 6 (F1 0.9739)


Train 7/15: 100%|██████████| 457/457 [01:09<00:00,  6.58it/s]


[Epoch 7] Train Loss: 0.4419 Acc: 0.8329 Prec: 0.8329 Rec: 0.8329 F1: 0.8329
[Epoch 7] Val Loss: 1.3211 Acc: 0.9724 Prec: 0.9717 Rec: 0.9736 F1: 0.9723


Train 8/15: 100%|██████████| 457/457 [01:10<00:00,  6.51it/s]


[Epoch 8] Train Loss: 0.4402 Acc: 0.8192 Prec: 0.8192 Rec: 0.8192 F1: 0.8192
[Epoch 8] Val Loss: 1.3494 Acc: 0.9695 Prec: 0.9688 Rec: 0.9710 F1: 0.9694


Train 9/15: 100%|██████████| 457/457 [01:09<00:00,  6.54it/s]


[Epoch 9] Train Loss: 0.4292 Acc: 0.8581 Prec: 0.8580 Rec: 0.8580 F1: 0.8580
[Epoch 9] Val Loss: 1.3355 Acc: 0.9697 Prec: 0.9690 Rec: 0.9708 F1: 0.9696


Train 10/15: 100%|██████████| 457/457 [01:09<00:00,  6.54it/s]


[Epoch 10] Train Loss: 0.4268 Acc: 0.8376 Prec: 0.8375 Rec: 0.8375 F1: 0.8375
[Epoch 10] Val Loss: 1.3268 Acc: 0.9724 Prec: 0.9718 Rec: 0.9740 F1: 0.9723


Train 11/15: 100%|██████████| 457/457 [01:10<00:00,  6.50it/s]


[Epoch 11] Train Loss: 0.4332 Acc: 0.8244 Prec: 0.8243 Rec: 0.8244 F1: 0.8243
[Epoch 11] Val Loss: 1.3400 Acc: 0.9712 Prec: 0.9705 Rec: 0.9728 F1: 0.9711


Train 12/15: 100%|██████████| 457/457 [01:10<00:00,  6.53it/s]

[Epoch 12] Train Loss: 0.4237 Acc: 0.8505 Prec: 0.8503 Rec: 0.8503 F1: 0.8503





[Epoch 12] Val Loss: 1.3401 Acc: 0.9723 Prec: 0.9717 Rec: 0.9740 F1: 0.9722


Train 13/15: 100%|██████████| 457/457 [01:10<00:00,  6.53it/s]


[Epoch 13] Train Loss: 0.4202 Acc: 0.8403 Prec: 0.8400 Rec: 0.8400 F1: 0.8400
[Epoch 13] Val Loss: 1.3212 Acc: 0.9734 Prec: 0.9728 Rec: 0.9751 F1: 0.9733
Early stopping triggered.

Training finished. Best val F1: 0.9738704278441408 at epoch 6


In [5]:
import os, random, math, time
from pathlib import Path
from tqdm import tqdm
import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms, models
from torchvision.transforms import RandAugment
import timm 
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from collections import Counter
from torch.autograd import Function

# --- 1. SETUP AND HYPERPARAMETERS (From Notebook Cell 4 & 5) ---
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if device == "cuda": torch.cuda.manual_seed_all(SEED)

IMG_SIZE = 224
BATCH_SIZE = 8         
EPOCHS = 15
NUM_WORKERS = 4         
LR = 3e-4              
LABEL_SMOOTH = 0.1
SAVE_PATH = "best_model.pth"
USE_SEGMENTATION = True 

# Loss weights 
ALPHA_DOM = 0.5
BETA_SUPCON = 0.3
ETA_CONS = 0.1
GAMMA_SEG = 0.5

# Mixup/CutMix probabilities and alphas
PROB_MIXUP = 0.5
PROB_CUTMIX = 0.5
MIXUP_ALPHA = 0.2
CUTMIX_ALPHA = 1.0

# Warmup epochs and accumulation steps
WARMUP_EPOCHS = 5
EARLY_STOPPING_PATIENCE = 7
FREEZE_EPOCHS = 5
ACCUMULATION_STEPS = 4

# --- 2. TRANSFORMS (From Notebook Cell 6) ---
class AddGaussianNoise(object):
    def __init__(self, mean=0., std=0.05):
        self.mean = mean
        self.std = std
    def __call__(self, tensor):
        noise = torch.randn(tensor.size()) * self.std + self.mean
        noisy_tensor = tensor + noise
        return torch.clamp(noisy_tensor, 0., 1.)
    def __repr__(self):
        return self.__class__.__name__ + f'(mean={self.mean}, std={self.std})'

weak_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    AddGaussianNoise(0., 0.02),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])
strong_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ColorJitter(0.4,0.4,0.4,0.1),
    RandAugment(num_ops=2, magnitude=9),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(30),
    transforms.ToTensor(),
    AddGaussianNoise(0., 0.05),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])
val_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])

# --- 3. DATASET HELPER FUNCTIONS (From Notebook Cell 7) ---
def read_file_with_encoding(file_path, encodings=['utf-8', 'utf-8-sig', 'ISO-8859-1']):
    for encoding in encodings:
        try:
            with open(file_path, 'r', encoding=encoding) as f:
                return f.readlines()
        except UnicodeDecodeError:
            pass
        except Exception as e:
            raise 
    raise RuntimeError(f"Unable to read {file_path} with any of the provided encodings.")

def load_testing_dataset_info(info_file, image_dir):
    image_paths = []
    labels = []
    encodings = ['utf-8-sig', 'utf-8', 'ISO-8859-1', 'latin-1']
    lines = []
    
    for encoding in encodings:
        try:
            with open(info_file, 'r', encoding=encoding) as f:
                lines = f.readlines()
            break
        except UnicodeDecodeError:
            pass
        except Exception as e:
            raise 

    for line in lines:
        parts = line.strip().split()
        if len(parts) == 2:
            image_filename = parts[0]
            try:
                label = int(parts[1])
            except ValueError:
                continue
                
            label = 1 if label == 1 else 0
            image_full_path = os.path.join(image_dir, image_filename)
            image_paths.append(image_full_path)
            labels.append(label)
    
    return image_paths, labels

# --- 4. MULTIDATASET CLASS (From Notebook Cell 7) ---
class MultiDataset(Dataset):
    def __init__(self, root_dirs, txt_files, testing_image_paths=None, testing_labels=None, weak_transform=None, strong_transform=None, use_masks=True):
        self.root_dirs = root_dirs
        self.weak_transform = weak_transform
        self.strong_transform = strong_transform
        self.use_masks = use_masks
        self.samples = []

        if testing_image_paths is not None and testing_labels is not None:
            for img_path, label in zip(testing_image_paths, testing_labels):
                self.samples.append((img_path, label)) 
        
        if isinstance(txt_files, str):
            txt_files = [txt_files]

        all_lines = []
        for t in txt_files:
            if not os.path.exists(t):
                raise RuntimeError(f"TXT file not found: {t}. Please ensure all Kaggle input datasets are mounted.")

            lines = read_file_with_encoding(t)
            all_lines.extend([(line.strip(), t) for line in lines if line.strip()])

        for line, src_txt in all_lines:
            parts = line.split()
            if len(parts) == 0:
                continue

            fname = parts[0]
            if len(parts) >= 2:
                try:
                    lbl = int(parts[1])
                except:
                    lbl = 1 if "CAM" in fname or "cam" in fname else 0
            else:
                lbl = 1 if "CAM" in fname or "cam" in fname else 0
            
            lbl = 1 if lbl == 1 else 0
            base_fname = os.path.basename(fname)  

            found = False
            search_subs = [
                "", "Image", "Imgs", "images", "JPEGImages", "img", 
                "Images/Train", "Images/Test",
            ]
            
            for rdir in self.root_dirs:
                for sub in search_subs:
                    img_path = os.path.join(rdir, sub, base_fname)
                    if os.path.exists(img_path):
                        self.samples.append((img_path, lbl, rdir))
                        found = True
                        break
                if found:
                    break

        if len(self.samples) == 0:
            raise RuntimeError(f"No valid samples found from {txt_files}")

        print(f"✅ Loaded {len(self.samples)} samples from {len(self.root_dirs)} root directories.")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        global IMG_SIZE 
        
        sample = self.samples[idx]
        img_path = sample[0]
        lbl = sample[1]
        
        if len(sample) == 3:
            rdir = sample[2]
        else:
            rdir = os.path.dirname(os.path.dirname(img_path))
           
        try:
            img = Image.open(img_path).convert("RGB")
        except:
            img = Image.fromarray(np.zeros((IMG_SIZE, IMG_SIZE, 3), dtype=np.uint8))
            
        if self.weak_transform:
            weak = self.weak_transform(img)
        else:
            weak = transforms.ToTensor()(img)
            
        if self.strong_transform:
            strong = self.strong_transform(img)
        else:
            strong = weak.clone()

        mask = None
        if self.use_masks:
            mask_name = os.path.splitext(os.path.basename(img_path))[0] + ".png"
            found_mask = False
            for mask_dir in ["GT_Object", "GT", "masks", "Mask"]:
                mask_path = os.path.join(rdir, mask_dir, mask_name)
                
                if os.path.exists(mask_path):
                    m = Image.open(mask_path).convert("L").resize((IMG_SIZE, IMG_SIZE))
                    m = np.array(m).astype(np.float32) / 255.0
                    mask = torch.from_numpy((m > 0.5).astype(np.float32)).unsqueeze(0)
                    found_mask = True
                    break

            if mask is None:
                mask = torch.zeros((1, IMG_SIZE, IMG_SIZE), dtype=torch.float32)
                
        return weak, strong, lbl, mask

# --- 5. BACKBONE EXTRACTORS (From Notebook Cell 9) ---

# --- UPDATED CONSTANTS ---
# DenseNet169 final features: 1664
# MobileNetV3 Large final features before global pool: 112 (or 160/960 depending on layer, 
# but based on your previous code 112 is the expected mid-layer output)
# --- UPDATED CONSTANTS FOR TRIPLE FUSION ---
# Final feature maps after Global Average Pooling:
DENSE_CHANNELS = 1920      # DenseNet201
INCEPTION_CHANNELS = 2048  # InceptionV3
EFFICIENT_CHANNELS = 1280  # EfficientNetV2-S (timm default)

TOTAL_FEATURES = DENSE_CHANNELS + INCEPTION_CHANNELS + EFFICIENT_CHANNELS 

class DenseNet201Extractor(nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()
        self.features = models.densenet201(weights='IMAGENET1K_V1' if pretrained else None).features
    def forward(self, x):
        return self.features(x)

class InceptionExtractor(nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()
        # aux_logits=False is necessary to simplify the output to just features
        self.model = models.inception_v3(weights='IMAGENET1K_V1' if pretrained else None, aux_logits=True)
    def forward(self, x):
        # InceptionV3 requires 299x299 ideally, but will work with 224x224
        # We extract features before the final FC layer
        x = self.model.Conv2d_1a_3x3(x)
        x = self.model.Conv2d_2a_3x3(x)
        x = self.model.Conv2d_2b_3x3(x)
        x = self.model.maxpool1(x)
        x = self.model.Conv2d_3b_1x1(x)
        x = self.model.Conv2d_4a_3x3(x)
        x = self.model.maxpool2(x)
        x = self.model.Mixed_5b(x)
        x = self.model.Mixed_5c(x)
        x = self.model.Mixed_5d(x)
        x = self.model.Mixed_6a(x)
        x = self.model.Mixed_6b(x)
        x = self.model.Mixed_6c(x)
        x = self.model.Mixed_6d(x)
        x = self.model.Mixed_6e(x)
        x = self.model.Mixed_7a(x)
        x = self.model.Mixed_7b(x)
        x = self.model.Mixed_7c(x)
        return x

class EfficientNetExtractor(nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()
        # Using timm for a modern EfficientNetV2 implementation
        self.model = timm.create_model('tf_efficientnetv2_s', pretrained=pretrained, num_classes=0, global_pool='')
    def forward(self, x):
        return self.model(x)



# --- 6. USER'S Keras-Style Fusion Model (Corrected) ---


NUM_CLASSES = 2 

class TripleFusionModel(nn.Module):
    def __init__(self, num_classes=2):
        super().__init__()
        
        self.densenet_base = DenseNet201Extractor(pretrained=True)
        self.inception_base = InceptionExtractor(pretrained=True)
        self.efficient_base = EfficientNetExtractor(pretrained=True)
        
        # Freezing logic for the first phase
        for base in [self.densenet_base, self.inception_base, self.efficient_base]:
            for param in base.parameters():
                param.requires_grad = False
            
        self.global_pool = nn.AdaptiveAvgPool2d(1)

        # Classifier input: 1920 + 2048 + 1280 = 5248
        self.classifier = nn.Sequential(
            nn.Linear(TOTAL_FEATURES, 1024), 
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes) 
        )
        
    def forward(self, x):
        # 1. Feature Extraction
        feat_dense = self.densenet_base(x)
        feat_inc = self.inception_base(x)
        feat_eff = self.efficient_base(x)

        # 2. Global Average Pooling & Flattening
        x1 = self.global_pool(feat_dense).view(x.size(0), -1)
        x2 = self.global_pool(feat_inc).view(x.size(0), -1)
        x3 = self.global_pool(feat_eff).view(x.size(0), -1)

        # 3. Concatenation (Triple Fusion)
        merged = torch.cat([x1, x2, x3], dim=1)

        # 4. Classification
        logits = self.classifier(merged)
        
        return {
            "logits": logits, 
            "feat": merged,
            "domain_logits": torch.zeros(x.size(0), 2).to(x.device),
            "seg": torch.zeros(x.size(0), 1, IMG_SIZE, IMG_SIZE).to(x.device)
        }

# --- 7. LOSS AND ADVERSARIAL HELPERS (From Notebook Cells 17-19) ---

class LabelSmoothingCE(nn.Module):
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.s = smoothing
    def forward(self, logits, target):
        c = logits.size(-1)
        logp = F.log_softmax(logits, dim=-1)
        with torch.no_grad():
            true_dist = torch.zeros_like(logp)
            true_dist.fill_(self.s / (c - 1))
            true_dist.scatter_(1, target.unsqueeze(1), 1.0 - self.s)
        return (-true_dist * logp).sum(dim=-1).mean()

class FocalLoss(nn.Module):
    def __init__(self, gamma=1.5):
        super().__init__()
        self.gamma = gamma
    def forward(self, logits, target):
        prob = F.softmax(logits, dim=1)
        pt = prob.gather(1, target.unsqueeze(1)).squeeze(1)
        ce = F.cross_entropy(logits, target, reduction='none')
        loss = ((1 - pt) ** self.gamma) * ce
        return loss.mean()

def dice_loss(pred, target, smooth=1.0):
    pred = torch.sigmoid(pred)
    num = 2 * (pred * target).sum() + smooth
    den = pred.sum() + target.sum() + smooth
    return 1 - (num / den)

def seg_loss_fn(pred, mask):
    if pred.shape[-2:] != mask.shape[-2:]:
        pred = F.interpolate(pred, size=mask.shape[-2:], mode="bilinear", align_corners=False)
    return F.binary_cross_entropy_with_logits(pred, mask) + dice_loss(pred, mask)

class SupConLoss(nn.Module):
    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature
        self.cos = nn.CosineSimilarity(dim=-1)
    def forward(self, features, labels):
        device = features.device
        f = F.normalize(features, dim=1)
        sim = torch.matmul(f, f.T) / self.temperature
        labels = labels.contiguous().view(-1,1)
        mask = torch.eq(labels, labels.T).float().to(device)
        logits_max, _ = torch.max(sim, dim=1, keepdim=True)
        logits = sim - logits_max.detach()
        exp_logits = torch.exp(logits) * (1 - torch.eye(len(features), device=device))
        denom = exp_logits.sum(1, keepdim=True)
        pos_mask = mask - torch.eye(len(features), device=device)
        pos_exp = (exp_logits * pos_mask).sum(1)
        loss = -torch.log((pos_exp + 1e-8) / (denom + 1e-8) + 1e-12)
        valid = (pos_mask.sum(1) > 0).float()
        loss = (loss * valid).sum() / (valid.sum() + 1e-8)
        return loss

clf_loss_ce = LabelSmoothingCE(LABEL_SMOOTH)
clf_loss_focal = FocalLoss(gamma=1.5)
supcon_loss_fn = SupConLoss(temperature=0.07)

class GradReverse(Function):
    @staticmethod
    def forward(ctx, x, l):
        ctx.l = l
        return x.view_as(x)
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.neg() * ctx.l, None

def grad_reverse(x, l=1.0):
    return GradReverse.apply(x, l)


# --- 8. DATA AUGMENTATION HELPERS (From Notebook Cell 20) ---

def rand_bbox(size, lam):
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

def apply_mixup(x, y, alpha=MIXUP_ALPHA):
    lam = np.random.beta(alpha, alpha)
    idx = torch.randperm(x.size(0))
    mixed_x = lam * x + (1 - lam) * x[idx]
    y_a, y_b = y, y[idx]
    return mixed_x, y_a, y_b, lam

def apply_cutmix(x, y, alpha=CUTMIX_ALPHA):
    lam = np.random.beta(alpha, alpha)
    idx = torch.randperm(x.size(0))
    bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam)
    new_x = x.clone()
    new_x[:, :, bby1:bby2, bbx1:bbx2] = x[idx, :, bby1:bby2, bbx1:bbx2]
    lam_adjusted = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size(-1) * x.size(-2)))
    return new_x, y, y[idx], lam_adjusted


# --- 9. DATA LOADING SETUP (From Notebook Cell 8) ---
info_dir = "/kaggle/input/cod10k/COD10K-v3/Info"
train_dir_cod = "/kaggle/input/cod10k/COD10K-v3/Train" 
test_dir_cod = "/kaggle/input/cod10k/COD10K-v3/Test"  
    
train_cam_txt = os.path.join(info_dir, "CAM_train.txt")
train_noncam_txt = os.path.join(info_dir, "NonCAM_train.txt")
test_cam_txt = os.path.join(info_dir, "CAM_test.txt")
test_noncam_txt = os.path.join(info_dir, "NonCAM_test.txt")

info_dir2 = "/kaggle/input/camo-coco/CAMO_COCO/Info"
train_cam_txt2 = os.path.join(info_dir2, "camo_train.txt")
train_noncam_txt2 = os.path.join(info_dir2, "non_camo_train.txt")
test_cam_txt2 = os.path.join(info_dir2, "camo_test.txt")
test_noncam_txt2 = os.path.join(info_dir2, "non_camo_test.txt")

train_dir_camo_cam = "/kaggle/input/camo-coco/CAMO_COCO/Camouflage"
train_dir_camo_noncam = "/kaggle/input/camo-coco/CAMO_COCO/Non_Camouflage"

testing_info_file = "/kaggle/input/testing-dataset/Info/image_labels.txt"
testing_images_dir = "/kaggle/input/testing-dataset/Images"

# --- MOCKING DATASET LOADING FOR RUNNABILITY ---
try:
    testing_image_paths, testing_labels = load_testing_dataset_info(testing_info_file, testing_images_dir)
    train_paths, val_paths, train_labels, val_labels = train_test_split(
        testing_image_paths, testing_labels, test_size=0.8, random_state=SEED
    )
    ALL_ROOT_DIRS = [
        train_dir_cod,       
        test_dir_cod,       
        train_dir_camo_cam,  
        train_dir_camo_noncam
    ]
    ALL_TRAIN_TXTS = [
        train_cam_txt2, train_noncam_txt2,
    ]
    ALL_VAL_TXTS = [test_cam_txt2, test_noncam_txt2]
    
    train_ds = MultiDataset(
        root_dirs=ALL_ROOT_DIRS, 
        txt_files=ALL_TRAIN_TXTS,               
        testing_image_paths=train_paths,        
        testing_labels=train_labels,            
        weak_transform=weak_tf, 
        strong_transform=strong_tf, 
        use_masks=USE_SEGMENTATION
    )
    val_ds = MultiDataset(
        root_dirs=ALL_ROOT_DIRS,  
        txt_files=ALL_VAL_TXTS,                 
        testing_image_paths=val_paths,          
        testing_labels=val_labels,              
        weak_transform=val_tf, 
        strong_transform=None, 
        use_masks=USE_SEGMENTATION
    )

    def build_weighted_sampler(dataset):
        labels = [sample[1] for sample in dataset.samples]  
        counts = Counter(labels)
        total = len(labels)
        if len(counts) <= 1:
            weights = [1.0] * total
        else:
            class_weights = {c: total / (counts[c] * len(counts)) for c in counts}
            weights = [class_weights[lbl] for lbl in labels]
        return WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)

    train_sampler = build_weighted_sampler(train_ds)
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, sampler=train_sampler, num_workers=NUM_WORKERS)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

except RuntimeError as e:
    print(f"⚠️ Warning: Cannot access Kaggle paths ({e}). Using Mock DataLoaders for demonstration.")
    
    class MockDataset(Dataset):
        def __init__(self, num_samples, num_classes=2, img_size=IMG_SIZE):
            self.num_samples = num_samples
            self.data = torch.randn(num_samples, 3, img_size, img_size)
            self.labels = torch.randint(0, num_classes, (num_samples,))
            self.masks = torch.randint(0, 2, (num_samples, 1, img_size, img_size)).float()

        def __len__(self):
            return self.num_samples

        def __getitem__(self, idx):
            return self.data[idx], self.data[idx].clone(), self.labels[idx], self.masks[idx]

    train_ds = MockDataset(num_samples=14150) 
    val_ds = MockDataset(num_samples=6606)   
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    
    print(f"Mock DataLoaders initialized with {len(train_ds)} train and {len(val_ds)} val samples.")


# --- 10. MODEL INSTANTIATION AND OPTIMIZER SETUP (From Notebook Cell 20) ---

model = TripleFusionModel(num_classes=2).to(device)

# Update parameter groups for Optimizer
backbone_names = ['densenet_base', 'inception_base', 'efficient_base']
backbone_params = []
head_params = []
for name, param in model.named_parameters():
    if any(k in name for k in backbone_names):
        backbone_params.append(param)
    else:
        head_params.append(param)

def get_cosine_with_warmup_scheduler(optimizer, warmup_epochs, total_epochs, last_epoch=-1):
    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            return float(epoch) / float(max(1.0, warmup_epochs))
        t = (epoch - warmup_epochs) / float(max(1, total_epochs - warmup_epochs))
        return 0.5 * (1.0 + math.cos(math.pi * t))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
opt = torch.optim.AdamW([
    {'params': backbone_params, 'lr': LR / 10}, # Backbones usually need a smaller LR
    {'params': head_params, 'lr': LR}
], weight_decay=1e-4)
scheduler = get_cosine_with_warmup_scheduler(opt, WARMUP_EPOCHS, EPOCHS)
scaler = torch.cuda.amp.GradScaler(enabled=(device=="cuda"))

print(f"\nModel instantiated. LR: {LR}, Epochs: {EPOCHS}.")
print(f"Backbones are initially frozen for {FREEZE_EPOCHS} epochs.")

# --- 11. TRAINING LOOP (Adapted from Notebook Cell 21) ---

def compute_combined_clf_loss(logits, targets, mix_info=None, use_focal=False):
    if mix_info is None:
        if use_focal:
            return clf_loss_focal(logits, targets)
        else:
            return clf_loss_ce(logits, targets)
    else:
        y_a, y_b, lam = mix_info
        if use_focal:
            loss = lam * F.cross_entropy(logits, y_a) + (1 - lam) * F.cross_entropy(logits, y_b)
        else:
            loss = lam * clf_loss_ce(logits, y_a) + (1 - lam) * clf_loss_ce(logits, y_b)
        return loss

best_vf1 = 0.0
best_epoch = 0
patience_count = 0

for epoch in range(1, EPOCHS+1):
    # --- Freeze/Unfreeze Logic ---
    if epoch <= FREEZE_EPOCHS:
        for name, p in model.named_parameters():
            if any(k in name for k in ['densenet_base', 'mobilenet_base']):
                p.requires_grad = False
    if epoch == FREEZE_EPOCHS + 1:
        print(f"--- Unfreezing all 3 backbones at epoch {epoch} ---")
        for name, p in model.named_parameters():
            if any(k in name for k in backbone_names):
                p.requires_grad = True

    model.train()
    running_loss = 0.0
    y_true, y_pred = [], []
    n_batches = 0
    opt.zero_grad() 
    
    for i, (weak_imgs, strong_imgs, labels, masks) in enumerate(tqdm(train_loader, desc=f"Train {epoch}/{EPOCHS}")):
        weak_imgs = weak_imgs.to(device); strong_imgs = strong_imgs.to(device)
        labels = labels.to(device)
        if masks is not None:
            masks = masks.to(device)

        imgs = weak_imgs

        mix_info = None
        rand = random.random()
        if rand < PROB_MIXUP:
            imgs, y_a, y_b, lam = apply_mixup(imgs, labels)
            mix_info = (y_a.to(device), y_b.to(device), lam)
        elif rand < PROB_MIXUP + PROB_CUTMIX:
            imgs, y_a, y_b, lam = apply_cutmix(imgs, labels)
            mix_info = (y_a.to(device), y_b.to(device), lam)

        with torch.autocast(device_type=device, enabled=(device=="cuda")):
            out = model(imgs) 
            logits = out["logits"]
            
            clf_loss = compute_combined_clf_loss(logits, labels, mix_info=mix_info, use_focal=False)

            # Auxiliary losses are set to 0 as the user's simple fusion model excludes these complex heads.
            seg_loss = 0.0
            supcon_loss = 0.0 
            cons_loss = 0.0   
            dom_loss = 0.0

            total_loss = clf_loss + GAMMA_SEG * seg_loss + BETA_SUPCON * supcon_loss + ETA_CONS * cons_loss + ALPHA_DOM * dom_loss
            
            total_loss = total_loss / ACCUMULATION_STEPS 

        scaler.scale(total_loss).backward()

        if (i + 1) % ACCUMULATION_STEPS == 0:
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(opt)
            scaler.update()
            opt.zero_grad() 

        running_loss += total_loss.item() * ACCUMULATION_STEPS
        y_true.extend(labels.cpu().numpy())
        y_pred.extend(logits.argmax(1).cpu().numpy())
        n_batches += 1

    if n_batches % ACCUMULATION_STEPS != 0:
        scaler.unscale_(opt)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(opt)
        scaler.update()
        opt.zero_grad()

    scheduler.step()

    acc = accuracy_score(y_true, y_pred)
    prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="macro", zero_division=0)
    print(f"[Epoch {epoch}] Train Loss: {running_loss/max(1,n_batches):.4f} Acc: {acc:.4f} Prec: {prec:.4f} Rec: {rec:.4f} F1: {f1:.4f}")

    # -------------------
    # VALIDATION
    # -------------------
    model.eval()
    val_y_true, val_y_pred = [], []
    val_loss = 0.0
    with torch.no_grad():
        for weak_imgs, _, labels, masks in val_loader:
            imgs = weak_imgs.to(device)
            labels = labels.to(device)
            if masks is not None:
                masks = masks.to(device)

            out = model(imgs)
            logits = out["logits"]
            
            loss = compute_combined_clf_loss(logits, labels, mix_info=None, use_focal=False)
            if USE_SEGMENTATION and (masks is not None):
                loss += GAMMA_SEG * seg_loss_fn(out["seg"], masks) 
            val_loss += loss.item()

            val_y_true.extend(labels.cpu().numpy())
            val_y_pred.extend(logits.argmax(1).cpu().numpy())

    vacc = accuracy_score(val_y_true, val_y_pred)
    vprec, vrec, vf1, _ = precision_recall_fscore_support(val_y_true, val_y_pred, average="macro", zero_division=0)
    print(f"[Epoch {epoch}] Val Loss: {val_loss/max(1,len(val_loader)):.4f} Acc: {vacc:.4f} Prec: {vprec:.4f} Rec: {vrec:.4f} F1: {vf1:.4f}")

    if vf1 > best_vf1:
        best_vf1 = vf1
        best_epoch = epoch
        torch.save({
            "epoch": epoch,
            "model_state": model.state_dict(),
            "opt_state": opt.state_dict(),
            "best_vf1": best_vf1
        }, SAVE_PATH)
        patience_count = 0
        print(f"Saved best model at epoch {epoch} (F1 {best_vf1:.4f})")
    else:
        patience_count += 1
        if patience_count >= EARLY_STOPPING_PATIENCE:
            print("Early stopping triggered.")
            break

print("\nTraining finished. Best val F1:", best_vf1, "at epoch", best_epoch)

Device: cuda
✅ Loaded 3651 samples from 4 root directories.
✅ Loaded 7106 samples from 4 root directories.


model.safetensors:   0%|          | 0.00/86.5M [00:00<?, ?B/s]

  scaler = torch.cuda.amp.GradScaler(enabled=(device=="cuda"))



Model instantiated. LR: 0.0003, Epochs: 15.
Backbones are initially frozen for 5 epochs.


Train 1/15: 100%|██████████| 457/457 [00:58<00:00,  7.76it/s]

[Epoch 1] Train Loss: 0.7162 Acc: 0.5021 Prec: 0.5168 Rec: 0.5122 F1: 0.4716





[Epoch 1] Val Loss: 1.5395 Acc: 0.5180 Prec: 0.4286 Rec: 0.4852 F1: 0.3719
Saved best model at epoch 1 (F1 0.3719)


Train 2/15: 100%|██████████| 457/457 [00:56<00:00,  8.11it/s]


[Epoch 2] Train Loss: 0.5255 Acc: 0.7623 Prec: 0.7622 Rec: 0.7623 F1: 0.7622
[Epoch 2] Val Loss: 1.2113 Acc: 0.9949 Prec: 0.9947 Rec: 0.9951 F1: 0.9949
Saved best model at epoch 2 (F1 0.9949)


Train 3/15: 100%|██████████| 457/457 [00:57<00:00,  7.99it/s]


[Epoch 3] Train Loss: 0.4908 Acc: 0.7869 Prec: 0.7869 Rec: 0.7869 F1: 0.7869
[Epoch 3] Val Loss: 1.1871 Acc: 0.9954 Prec: 0.9951 Rec: 0.9955 F1: 0.9953
Saved best model at epoch 3 (F1 0.9953)


Train 4/15: 100%|██████████| 457/457 [00:56<00:00,  8.10it/s]


[Epoch 4] Train Loss: 0.4851 Acc: 0.7946 Prec: 0.7946 Rec: 0.7946 F1: 0.7946
[Epoch 4] Val Loss: 1.1951 Acc: 0.9916 Prec: 0.9920 Rec: 0.9911 F1: 0.9915


Train 5/15: 100%|██████████| 457/457 [00:57<00:00,  7.88it/s]


[Epoch 5] Train Loss: 0.4774 Acc: 0.8042 Prec: 0.8042 Rec: 0.8041 F1: 0.8041
[Epoch 5] Val Loss: 1.1897 Acc: 0.9949 Prec: 0.9951 Rec: 0.9947 F1: 0.9949
--- Unfreezing all 3 backbones at epoch 6 ---


Train 6/15: 100%|██████████| 457/457 [02:15<00:00,  3.36it/s]

[Epoch 6] Train Loss: 0.4749 Acc: 0.8066 Prec: 0.8066 Rec: 0.8066 F1: 0.8066





[Epoch 6] Val Loss: 1.1851 Acc: 0.9942 Prec: 0.9945 Rec: 0.9940 F1: 0.9942


Train 7/15: 100%|██████████| 457/457 [02:10<00:00,  3.51it/s]

[Epoch 7] Train Loss: 0.4634 Acc: 0.8192 Prec: 0.8191 Rec: 0.8192 F1: 0.8192





[Epoch 7] Val Loss: 1.1983 Acc: 0.9961 Prec: 0.9959 Rec: 0.9962 F1: 0.9960
Saved best model at epoch 7 (F1 0.9960)


Train 8/15: 100%|██████████| 457/457 [02:11<00:00,  3.49it/s]

[Epoch 8] Train Loss: 0.4590 Acc: 0.8195 Prec: 0.8195 Rec: 0.8195 F1: 0.8195





[Epoch 8] Val Loss: 1.1875 Acc: 0.9962 Prec: 0.9960 Rec: 0.9964 F1: 0.9962
Saved best model at epoch 8 (F1 0.9962)


Train 9/15: 100%|██████████| 457/457 [02:10<00:00,  3.51it/s]

[Epoch 9] Train Loss: 0.4460 Acc: 0.8461 Prec: 0.8459 Rec: 0.8459 F1: 0.8459





[Epoch 9] Val Loss: 1.1879 Acc: 0.9963 Prec: 0.9964 Rec: 0.9962 F1: 0.9963
Saved best model at epoch 9 (F1 0.9963)


Train 10/15: 100%|██████████| 457/457 [02:09<00:00,  3.52it/s]

[Epoch 10] Train Loss: 0.4385 Acc: 0.8359 Prec: 0.8359 Rec: 0.8359 F1: 0.8359





[Epoch 10] Val Loss: 1.1884 Acc: 0.9973 Prec: 0.9973 Rec: 0.9973 F1: 0.9973
Saved best model at epoch 10 (F1 0.9973)


Train 11/15: 100%|██████████| 457/457 [02:09<00:00,  3.54it/s]

[Epoch 11] Train Loss: 0.4455 Acc: 0.8190 Prec: 0.8192 Rec: 0.8191 F1: 0.8189





[Epoch 11] Val Loss: 1.1917 Acc: 0.9973 Prec: 0.9974 Rec: 0.9973 F1: 0.9973


Train 12/15: 100%|██████████| 457/457 [02:08<00:00,  3.55it/s]

[Epoch 12] Train Loss: 0.4385 Acc: 0.8376 Prec: 0.8376 Rec: 0.8375 F1: 0.8375





[Epoch 12] Val Loss: 1.1879 Acc: 0.9976 Prec: 0.9976 Rec: 0.9975 F1: 0.9976
Saved best model at epoch 12 (F1 0.9976)


Train 13/15: 100%|██████████| 457/457 [02:07<00:00,  3.57it/s]

[Epoch 13] Train Loss: 0.4316 Acc: 0.8261 Prec: 0.8262 Rec: 0.8261 F1: 0.8261





[Epoch 13] Val Loss: 1.1884 Acc: 0.9977 Prec: 0.9978 Rec: 0.9977 F1: 0.9977
Saved best model at epoch 13 (F1 0.9977)


Train 14/15: 100%|██████████| 457/457 [02:08<00:00,  3.55it/s]

[Epoch 14] Train Loss: 0.4304 Acc: 0.8332 Prec: 0.8333 Rec: 0.8331 F1: 0.8331





[Epoch 14] Val Loss: 1.1903 Acc: 0.9969 Prec: 0.9968 Rec: 0.9970 F1: 0.9969


Train 15/15: 100%|██████████| 457/457 [02:08<00:00,  3.55it/s]

[Epoch 15] Train Loss: 0.4314 Acc: 0.8381 Prec: 0.8384 Rec: 0.8383 F1: 0.8381





[Epoch 15] Val Loss: 1.1883 Acc: 0.9977 Prec: 0.9977 Rec: 0.9977 F1: 0.9977
Saved best model at epoch 15 (F1 0.9977)

Training finished. Best val F1: 0.9977361654733695 at epoch 15
