# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.optim.swa_utils import AveragedModel, SWALR
from PIL import Image, ImageFile
import pandas as pd
import numpy as np
import os
from pathlib import Path
from tqdm.auto import tqdm
import random
from transformers import CLIPProcessor, CLIPModel
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
from sklearn.utils.class_weight import compute_class_weight
from torchvision import transforms
import glob
import sys
import argparse
import warnings

# Suppress every Python warning raised in this process.
warnings.filterwarnings('ignore')
# Let PIL decode image files that are truncated instead of raising on them.
ImageFile.LOAD_TRUNCATED_IMAGES = True



# --- Model configuration ---
# Hugging Face checkpoint id passed to CLIPProcessor / CLIPModel.from_pretrained.
CLIP_MODEL_ID = "openai/clip-vit-large-patch14"
# Width of the projection, adapter and pre-output layers in FinalMemeCLIP.
D_CLIP_HIDDEN = 1024
# Blend factor given to FeatureAdapter.forward: alpha * adapter(x) + (1 - alpha) * x.
RESIDUAL_ALPHA = 0.2
# Multiplier applied to the cosine-similarity logits in CosineClassifier.
COSINE_CLASSIFIER_SCALE = 20.0
# Dropout probability used in FinalMemeCLIP.pre_output_layer.
DROPOUT_RATE = 0.5
# NOTE(review): not referenced anywhere in the visible code — presumably the label column name.
TASK_C = 'label'
# Number of classifier outputs (one per stance class).
OUTPUT_SIZE_C = 3
# Class name -> integer id; the keys are also the training image sub-folder names.
LABEL_TO_ID_C = {'Neutral': 0, 'Support': 1, 'Oppose': 2}
# Inverse mapping, used to build the text prompts that initialise the classifier.
ID_TO_LABEL_C = {v: k for k, v in LABEL_TO_ID_C.items()}


# --- Runtime configuration ---
SEED = 42
BATCH_SIZE = 16
# Tensors and models are moved to this device; falls back to CPU without CUDA.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Worker processes for both DataLoaders.
NUM_WORKERS = 2


# --- Training schedule ---
# Stage 1: epochs with the CLIP backbone fully frozen.
HEAD_ONLY_EPOCHS = 8
# Stage 2: maximum epochs after the last encoder layers are unfrozen.
FULL_MODEL_EPOCHS = 20
# Number of trailing encoder layers unfrozen in each CLIP tower for stage 2.
LAYERS_TO_UNFREEZE = 2


# --- Optimisation ---
# Peak OneCycleLR learning rates: head in stage 1, head in stage 2, backbone in stage 2.
MAX_LR_HEAD_STAGE1 = 1e-4
MAX_LR_HEAD_STAGE2 = 2e-5
MAX_LR_BACKBONE = 1e-8
# AdamW weight decay used in both stages.
WEIGHT_DECAY = 1e-2
# Maximum gradient norm passed to clip_grad_norm_ in train_epoch.
GRADIENT_CLIP_VAL = 1.0


# --- Regularisation / SWA / early stopping ---
# label_smoothing argument of the CrossEntropyLoss criterion.
LABEL_SMOOTHING = 0.15
# Compared against the 0-based stage-2 epoch index: averaging runs when epoch >= this value.
SWA_START_EPOCH = 6
# Target learning rate handed to the SWALR scheduler.
SWA_LR = 5e-5
# Stage-2 epochs without a validation-F1 improvement before training stops.
EARLY_STOPPING_PATIENCE = 7

def seed_everything(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy and PyTorch, and exports
    ``PYTHONHASHSEED``. When CUDA is available, all CUDA generators are seeded
    too and cuDNN is switched to deterministic, non-benchmark mode.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.manual_seed_all(seed)
        # Trade cuDNN autotuning speed for run-to-run repeatability.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    print(f"Random seed set to {seed} for reproducibility.")


def load_and_combine_data_C(base_dir):
    """Load the Subtask C train and validation splits into one dataframe.

    Training samples are discovered by scanning one image folder per class
    name in ``LABEL_TO_ID_C``; their text is looked up by file name in the
    training CSV (empty string when missing or NaN). Validation samples come
    from an inner merge of the validation label and text CSVs on ``name``.

    Args:
        base_dir: Base project directory containing the Subtask data folders.

    Returns:
        A dataframe with the columns ``name``, ``text``, ``label``,
        ``img_path`` and ``split`` (``'train'`` or ``'val'``), training rows
        first.

    Raises:
        SystemExit: If either split ends up empty.
    """
    train_images_root_dir = os.path.join(base_dir, "Subtask_C", "Train", "Subtask_C_Train")
    train_text_csv_file = os.path.join(base_dir, "Train_Text", "STask_C_train.csv")
    val_images_dir = os.path.join(base_dir, "Subtask_C", "Evaluation", "STask_C_val_img")
    val_text_csv_file = os.path.join(base_dir, "Eval_Data_Text", "STask-C(index,text)val.csv")
    val_labels_csv_file = os.path.join(base_dir, "Eval_Data_Labels", "STask-C(index,label)val.csv")

    print("\n--- Loading Training Data (Subtask C) ---")
    train_text_df = pd.read_csv(train_text_csv_file).rename(columns={'index': 'name'})

    # Build the name -> text lookup once instead of scanning the whole
    # dataframe for every image. ``setdefault`` keeps the first row for a
    # duplicated name, matching a "first match wins" lookup.
    text_by_name = {}
    for csv_name, csv_text in zip(train_text_df['name'], train_text_df['text']):
        text_by_name.setdefault(csv_name, csv_text)

    train_data = []
    for category, label_id in LABEL_TO_ID_C.items():
        folder = os.path.join(train_images_root_dir, category)
        if not os.path.isdir(folder):
            print(f"Warning: Training directory not found, skipping: {folder}")
            continue
        # glob returns files in filesystem order; sort so the sample order
        # (and therefore seeded shuffling) is reproducible across machines.
        for img_path in sorted(glob.glob(os.path.join(folder, '*.*'))):
            name = os.path.basename(img_path)
            raw_text = text_by_name.get(name)
            text = str(raw_text) if pd.notna(raw_text) else ""
            train_data.append({'name': name, 'text': text, 'label': label_id, 'img_path': img_path, 'split': 'train'})
    train_df = pd.DataFrame(train_data)

    print("\n--- Loading Validation Data (Subtask C) ---")
    val_labels_df = pd.read_csv(val_labels_csv_file).rename(columns={'index': 'name'})
    val_text_df = pd.read_csv(val_text_csv_file).rename(columns={'index': 'name'})
    val_df = pd.merge(val_labels_df, val_text_df, on='name', how='inner')
    val_df = val_df.dropna(subset=['label'])
    val_df['label'] = val_df['label'].astype(int)
    val_df['img_path'] = val_df['name'].apply(lambda x: os.path.join(val_images_dir, x))
    val_df['split'] = 'val'

    print("\n--- Data Loading Summary ---")
    all_df = pd.concat([train_df, val_df], ignore_index=True)
    print(f"Loaded {len(train_df)} training samples and {len(val_df)} validation samples.")
    if len(train_df) == 0 or len(val_df) == 0:
        sys.exit("FATAL: Training or validation data is empty. Check paths and files.")
    return all_df

class MultimodalDataset(Dataset):
    """Image/text/label dataset for one split of the combined dataframe.

    Args:
        data_df: Dataframe with at least the columns ``split``, ``label``,
            ``text`` and ``img_path``.
        processor: Object exposing a ``tokenizer`` callable (a CLIP processor
            in this script).
        split: Value of the ``split`` column to keep.
        transform: Callable applied to the RGB PIL image.
    """

    def __init__(self, data_df, processor, split, transform):
        self.data = data_df[data_df['split'] == split].reset_index(drop=True)
        self.processor = processor
        self.transform = transform
        # Paths already reported as unreadable, so each one is logged only
        # once per dataset copy rather than on every access.
        self._reported_failures = set()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return a dict with the ``image``, ``text`` token ids and ``label``.

        An image that cannot be opened or transformed is replaced by an
        all-zero ``(3, 224, 224)`` float tensor so one bad file does not
        abort training; the failure is reported on stderr.
        """
        item = self.data.iloc[idx]
        label = torch.tensor(item['label'], dtype=torch.long)
        text = self.processor.tokenizer(
            str(item['text']),
            return_tensors="pt",
            padding='max_length',
            truncation=True,
            max_length=77,
        ).input_ids.squeeze(0)

        img_path = item['img_path']
        try:
            # The context manager closes the underlying file handle even for
            # formats PIL keeps open after decoding.
            with Image.open(img_path) as raw_image:
                image = self.transform(raw_image.convert("RGB"))
        except Exception as exc:
            # Deliberate best-effort fallback, but make it visible: a wrong
            # image directory would otherwise train on blank inputs silently.
            if img_path not in self._reported_failures:
                self._reported_failures.add(img_path)
                print(
                    f"Warning: could not load image '{img_path}' "
                    f"({type(exc).__name__}: {exc}); using a blank image.",
                    file=sys.stderr,
                )
            image = torch.zeros((3, 224, 224), dtype=torch.float)
        return {"image": image, "text": text, "label": label}


class FeatureAdapter(nn.Module):
    """Bottleneck MLP whose output is blended with its own input.

    The MLP maps ``dim -> dim // 2 -> dim`` with a GELU in between.
    """

    def __init__(self, dim):
        super().__init__()
        bottleneck = dim // 2
        stages = [
            nn.Linear(dim, bottleneck),
            nn.GELU(),
            nn.Linear(bottleneck, dim),
        ]
        self.adapter = nn.Sequential(*stages)

    def forward(self, x, alpha):
        """Return ``alpha * adapter(x) + (1 - alpha) * x``."""
        adapted = self.adapter(x)
        keep = 1.0 - alpha
        return alpha * adapted + keep * x

class CosineClassifier(nn.Module):
    """Scaled cosine-similarity classifier with prompt-derived initial weights.

    The class weight matrix is initialised from the CLIP text features of
    ``prompts``, pushed through a freshly created linear layer and
    L2-normalised per row. That linear layer is a local that is not stored on
    the module.

    Args:
        in_dim: Feature dimension of the inputs to ``forward``.
        out_dim: Number of classes (rows of the weight matrix).
        scale: Multiplier applied to the cosine similarities.
        prompts: One text prompt per class.
        clip_model_ref: Model providing ``config.projection_dim``, ``device``
            and ``get_text_features``.
        processor_ref: Object exposing the ``tokenizer`` used for the prompts.
    """

    def __init__(self, in_dim, out_dim, scale, prompts, clip_model_ref, processor_ref):
        super().__init__()
        # Uninitialised storage; overwritten with the prompt embeddings below.
        self.weight = nn.Parameter(torch.Tensor(out_dim, in_dim))
        self.scale = scale

        with torch.no_grad():
            target_device = clip_model_ref.device
            sai_proj = nn.Linear(clip_model_ref.config.projection_dim, in_dim).to(target_device)
            tokenized = processor_ref.tokenizer(
                prompts, return_tensors="pt", padding=True, truncation=True
            ).to(target_device)
            raw_embeds = clip_model_ref.get_text_features(**tokenized)
            projected = sai_proj(raw_embeds.float())
            self.weight.data = F.normalize(projected.data, p=2, dim=1)

    def forward(self, x):
        """Return ``scale`` times the cosine similarity of ``x`` to each class row."""
        unit_inputs = F.normalize(x, p=2, dim=1)
        unit_weights = F.normalize(self.weight, p=2, dim=1)
        return F.linear(unit_inputs, unit_weights) * self.scale

class FinalMemeCLIP(nn.Module):
    """CLIP backbone with projection, adapter and cosine-classifier head.

    Image and text features from CLIP are each projected to
    ``D_CLIP_HIDDEN``, passed through a ``FeatureAdapter``, fused by
    element-wise multiplication, sent through a Linear/GELU/Dropout block and
    classified by a ``CosineClassifier`` initialised from stance prompts.

    Args:
        clip_model_ref: CLIP model; stored as ``self.clip_model`` (shared with
            the caller, not copied).
        processor_ref: Processor whose tokenizer encodes the class prompts.
        clip_output_dim: Dimension of CLIP's projected image/text features.
    """

    def __init__(self, clip_model_ref, processor_ref, clip_output_dim):
        super().__init__()
        self.clip_model = clip_model_ref
        self.image_proj = nn.Linear(clip_output_dim, D_CLIP_HIDDEN)
        self.text_proj = nn.Linear(clip_output_dim, D_CLIP_HIDDEN)
        self.image_adapter = FeatureAdapter(D_CLIP_HIDDEN)
        self.text_adapter = FeatureAdapter(D_CLIP_HIDDEN)
        self.pre_output_layer = nn.Sequential(
            nn.Linear(D_CLIP_HIDDEN, D_CLIP_HIDDEN),
            nn.GELU(),
            nn.Dropout(DROPOUT_RATE),
        )
        # One prompt per class, ordered by class id so row i matches label i.
        prompts = [f"a meme expressing a '{ID_TO_LABEL_C[i].lower()}' stance" for i in sorted(ID_TO_LABEL_C.keys())]
        self.classifier = CosineClassifier(D_CLIP_HIDDEN, OUTPUT_SIZE_C, COSINE_CLASSIFIER_SCALE, prompts, self.clip_model, processor_ref)

    def forward(self, image, text):
        """Return class logits for a batch of images and token-id tensors.

        Args:
            image: Batch passed to ``clip_model.get_image_features``.
            text: Batch passed to ``clip_model.get_text_features``.

        Returns:
            The classifier logits for the fused features.
        """
        # An unconditional no_grad here would cut the graph at the backbone,
        # so layers unfrozen for fine-tuning would never receive gradients.
        # Track gradients through CLIP whenever any of its parameters is
        # trainable (and the caller has not disabled grad, e.g. validation);
        # a fully frozen backbone still skips graph construction.
        backbone_trainable = any(p.requires_grad for p in self.clip_model.parameters())
        with torch.set_grad_enabled(torch.is_grad_enabled() and backbone_trainable):
            image_features = self.clip_model.get_image_features(image)
            text_features = self.clip_model.get_text_features(text)

        img_proj = self.image_proj(image_features.float())
        txt_proj = self.text_proj(text_features.float())
        img_adapted = self.image_adapter(img_proj, RESIDUAL_ALPHA)
        txt_adapted = self.text_adapter(txt_proj, RESIDUAL_ALPHA)
        fused = self.pre_output_layer(img_adapted * txt_adapted)
        return self.classifier(fused)

def set_requires_grad(model, requires_grad):
    """Set the ``requires_grad`` flag on every parameter of ``model``.

    Args:
        model: Object exposing ``parameters()``.
        requires_grad: Flag value assigned to each parameter.
    """
    parameters = model.parameters()
    for tensor in parameters:
        tensor.requires_grad = requires_grad

def unfreeze_last_n_layers(clip_model, n):
    """Make the last ``n`` encoder layers of both CLIP towers trainable.

    Args:
        clip_model: Model exposing ``vision_model.encoder.layers`` and
            ``text_model.encoder.layers``.
        n: Number of trailing layers to unfreeze in each encoder. Values of
            zero or below unfreeze nothing; values larger than the encoder
            depth unfreeze every layer.
    """
    count = max(n, 0)
    for tower_name in ("vision_model", "text_model"):
        layers = list(getattr(clip_model, tower_name).encoder.layers)
        # ``layers[-n:]`` selects *every* layer when n == 0, so slice from an
        # explicit start index instead.
        start = max(len(layers) - count, 0)
        for layer in layers[start:]:
            set_requires_grad(layer, True)
    print(f"Unfroze last {n} layers of CLIP vision and text encoders for fine-tuning.")


def train_epoch(model, data_loader, criterion, optimizer, scheduler, scaler):
    """Run one mixed-precision training pass over ``data_loader``.

    Args:
        model: Module called as ``model(image, text)``.
        data_loader: Iterable of dict batches with ``image``, ``text`` and
            ``label`` tensors.
        criterion: Loss called as ``criterion(logits, label)``.
        optimizer: Optimizer stepped once per batch through ``scaler``.
        scheduler: Optional scheduler stepped once per batch; skipped when
            falsy.
        scaler: Gradient scaler used for the backward pass and optimizer step.

    Returns:
        ``(mean batch loss, weighted F1 over all training predictions)``.
    """
    model.train()
    running_loss = 0
    predictions, targets = [], []

    progress = tqdm(data_loader, desc="Training", leave=False)
    for batch in progress:
        images = batch['image'].to(DEVICE)
        tokens = batch['text'].to(DEVICE)
        labels = batch['label'].to(DEVICE)

        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast():
            logits = model(images, tokens)
            loss = criterion(logits, labels)

        scaler.scale(loss).backward()
        # Unscale before clipping so the norm is measured on real gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIP_VAL)
        scaler.step(optimizer)
        scaler.update()
        if scheduler:
            scheduler.step()

        running_loss += loss.item()
        predictions.extend(torch.argmax(logits, dim=1).cpu().numpy())
        targets.extend(labels.cpu().numpy())

    mean_loss = running_loss / len(data_loader)
    weighted_f1 = f1_score(targets, predictions, average='weighted', zero_division=0)
    return mean_loss, weighted_f1

def validate_epoch(model, data_loader, criterion):
    """Evaluate ``model`` on ``data_loader`` without tracking gradients.

    Args:
        model: Module called as ``model(image, text)``.
        data_loader: Iterable of dict batches with ``image``, ``text`` and
            ``label`` tensors.
        criterion: Loss called as ``criterion(logits, label)``.

    Returns:
        ``(mean batch loss, weighted F1, weighted precision, weighted recall,
        accuracy)``.
    """
    model.eval()
    running_loss = 0
    predictions, targets = [], []

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Validating", leave=False):
            images = batch['image'].to(DEVICE)
            tokens = batch['text'].to(DEVICE)
            labels = batch['label'].to(DEVICE)
            with torch.cuda.amp.autocast():
                logits = model(images, tokens)
                loss = criterion(logits, labels)
            running_loss += loss.item()
            predictions.extend(torch.argmax(logits, dim=1).cpu().numpy())
            targets.extend(labels.cpu().numpy())

    # All three averaged metrics share the same weighting options.
    weighted = {'average': 'weighted', 'zero_division': 0}
    f1 = f1_score(targets, predictions, **weighted)
    prec = precision_score(targets, predictions, **weighted)
    rec = recall_score(targets, predictions, **weighted)
    acc = accuracy_score(targets, predictions)
    return running_loss / len(data_loader), f1, prec, rec, acc


if __name__ == "__main__":
    # Entry point: two-stage training for Subtask C.
    #   Stage 1 trains only the head on top of a frozen CLIP backbone.
    #   Stage 2 reloads the best stage-1 weights, unfreezes the last encoder
    #   layers, and fine-tunes with early stopping and stochastic weight
    #   averaging (SWA).
    parser = argparse.ArgumentParser(description="Train MemeCLIP model for Subtask C: Stance Detection.")
    parser.add_argument(
        '--base_dir',
        type=str,
        default="./SharedTaskProject",
        help="Path to the base project directory containing the Subtask data folders."
    )
    args = parser.parse_args()

    seed_everything(SEED)

    # --- Output locations ---
    BASE_PROJECT_DIR = args.base_dir
    MODELS_DIR_C = os.path.join(BASE_PROJECT_DIR, "models", "Subtask_C")
    EXP_NAME_C = "MemeCLIP_Stance_F1_Breakthrough_vFinal"
    MODEL_SAVE_PATH_C = os.path.join(MODELS_DIR_C, EXP_NAME_C)
    Path(MODEL_SAVE_PATH_C).mkdir(parents=True, exist_ok=True)
    BEST_MODEL_SAVE_PATH = os.path.join(MODEL_SAVE_PATH_C, "best_model_subtask_C.pt")
    SWA_MODEL_SAVE_PATH = os.path.join(MODEL_SAVE_PATH_C, "swa_model_subtask_C.pt")
    print(f"This run's models will be saved in: {MODEL_SAVE_PATH_C}")

    # --- Backbone and data ---
    print(f"Loading CLIP: {CLIP_MODEL_ID}...")
    processor = CLIPProcessor.from_pretrained(CLIP_MODEL_ID)
    clip_model = CLIPModel.from_pretrained(CLIP_MODEL_ID).to(DEVICE)
    CLIP_OUTPUT_DIM = clip_model.config.projection_dim

    all_data_df = load_and_combine_data_C(BASE_PROJECT_DIR)

    # Normalise with the statistics the CLIP image processor reports.
    norm_mean, norm_std = processor.image_processor.image_mean, processor.image_processor.image_std
    bicubic = transforms.InterpolationMode.BICUBIC
    val_transform = transforms.Compose([
        transforms.Resize(224, interpolation=bicubic),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ])
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.75, 1.0), interpolation=bicubic),
        transforms.RandomHorizontalFlip(),
        transforms.TrivialAugmentWide(interpolation=bicubic),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ])

    train_dataset = MultimodalDataset(all_data_df, processor, 'train', train_transform)
    val_dataset = MultimodalDataset(all_data_df, processor, 'val', val_transform)
    dataloader_train = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
    dataloader_val = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

    # --- Loss: class-balanced cross entropy with label smoothing ---
    train_labels = all_data_df[all_data_df['split'] == 'train']['label']
    class_weights = torch.tensor(
        compute_class_weight('balanced', classes=np.unique(train_labels), y=train_labels.to_numpy()),
        dtype=torch.float,
    ).to(DEVICE)
    criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=LABEL_SMOOTHING)
    print(f"Using CrossEntropyLoss with Label Smoothing: {LABEL_SMOOTHING} and Weights: {class_weights.cpu().numpy()}")

    full_model = FinalMemeCLIP(clip_model, processor, CLIP_OUTPUT_DIM).to(DEVICE)
    scaler = torch.cuda.amp.GradScaler()
    best_val_f1 = -1.0
    epochs_no_improve = 0

    # ------------------------------------------------------------------
    # Stage 1: frozen backbone, train the head only.
    # ------------------------------------------------------------------
    print(f"\n--- STAGE 1: WARMING UP CLASSIFIER HEAD FOR {HEAD_ONLY_EPOCHS} EPOCHS ---")
    set_requires_grad(full_model.clip_model, False)
    set_requires_grad(full_model.classifier, True)
    set_requires_grad(full_model.pre_output_layer, True)
    set_requires_grad(full_model.image_adapter, True)
    set_requires_grad(full_model.text_adapter, True)
    set_requires_grad(full_model.image_proj, True)
    set_requires_grad(full_model.text_proj, True)

    optimizer = torch.optim.AdamW(
        [p for p in full_model.parameters() if p.requires_grad],
        lr=MAX_LR_HEAD_STAGE1,
        weight_decay=WEIGHT_DECAY,
    )
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=MAX_LR_HEAD_STAGE1,
        epochs=HEAD_ONLY_EPOCHS,
        steps_per_epoch=len(dataloader_train),
    )

    for epoch in range(HEAD_ONLY_EPOCHS):
        train_loss, train_f1 = train_epoch(full_model, dataloader_train, criterion, optimizer, scheduler, scaler)
        val_loss, val_f1, val_p, val_r, val_acc = validate_epoch(full_model, dataloader_val, criterion)
        print(f"[Head Only] Epoch {epoch+1}/{HEAD_ONLY_EPOCHS} -> Train F1: {train_f1:.4f} | Val F1: {val_f1:.4f} | Val P: {val_p:.4f} | Val R: {val_r:.4f}")
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            torch.save(full_model.state_dict(), BEST_MODEL_SAVE_PATH)
            print(f"** New best model saved in Stage 1! Val F1: {best_val_f1:.4f} **")

    # ------------------------------------------------------------------
    # Stage 2: unfreeze the last encoder layers, fine-tune, average with SWA.
    # ------------------------------------------------------------------
    print(f"\n--- STAGE 2: ULTRA-GENTLE FINE-TUNING & SWA ---")
    # map_location keeps the reload working if the checkpoint's device differs.
    full_model.load_state_dict(torch.load(BEST_MODEL_SAVE_PATH, map_location=DEVICE))
    unfreeze_last_n_layers(full_model.clip_model, LAYERS_TO_UNFREEZE)

    # Group order matters below: index 0 is the backbone, index 1 the head.
    backbone_params = [p for p in full_model.clip_model.parameters() if p.requires_grad]
    head_params = [p for name, p in full_model.named_parameters() if 'clip_model' not in name and p.requires_grad]
    param_groups = [
        {'params': backbone_params, 'lr': MAX_LR_BACKBONE},
        {'params': head_params, 'lr': MAX_LR_HEAD_STAGE2},
    ]
    optimizer = torch.optim.AdamW(param_groups, weight_decay=WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=[MAX_LR_BACKBONE, MAX_LR_HEAD_STAGE2],
        total_steps=FULL_MODEL_EPOCHS * len(dataloader_train),
    )

    swa_model = AveragedModel(full_model)
    # Per-group SWA learning rates (same order as param_groups): a single
    # scalar would pull the backbone from MAX_LR_BACKBONE up to SWA_LR.
    swa_scheduler = SWALR(optimizer, swa_lr=[MAX_LR_BACKBONE, SWA_LR])
    print(f"SWA initialized. Will start averaging after epoch {SWA_START_EPOCH} of full-tuning.")

    swa_active = False
    for epoch in range(FULL_MODEL_EPOCHS):
        current_epoch_num = HEAD_ONLY_EPOCHS + epoch + 1
        # Both schedulers write the optimizer's learning rates. Once SWA has
        # started, stop stepping the per-batch OneCycleLR so it no longer
        # overwrites the value set by SWALR.
        batch_scheduler = None if swa_active else scheduler
        train_loss, train_f1 = train_epoch(full_model, dataloader_train, criterion, optimizer, batch_scheduler, scaler)

        val_loss, val_f1, val_p, val_r, val_acc = validate_epoch(full_model, dataloader_val, criterion)
        print(f"[Full Tune] Epoch {current_epoch_num} -> Train F1: {train_f1:.4f} | Val F1: {val_f1:.4f} | Val P: {val_p:.4f} | Val R: {val_r:.4f}")

        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            epochs_no_improve = 0
            torch.save(full_model.state_dict(), BEST_MODEL_SAVE_PATH)
            print(f"** New best standard model saved! Val F1: {best_val_f1:.4f} **")
        else:
            epochs_no_improve += 1
            print(f"Validation F1 did not improve. Patience: {epochs_no_improve}/{EARLY_STOPPING_PATIENCE}")

        if epoch >= SWA_START_EPOCH:
            swa_model.update_parameters(full_model)
            swa_scheduler.step()
            swa_active = True
            print("SWA model updated.")

        if epochs_no_improve >= EARLY_STOPPING_PATIENCE:
            print(f"\nEarly stopping triggered after {EARLY_STOPPING_PATIENCE} epochs with no improvement.")
            break

    # ------------------------------------------------------------------
    # Final evaluation of the averaged model.
    # ------------------------------------------------------------------
    print("\n--- FINAL EVALUATION USING SWA MODEL FOR MAXIMUM GENERALIZATION ---")
    # NOTE(review): the loader yields dict batches, which update_bn cannot
    # feed to the model; this presumably only works because the model has no
    # BatchNorm layers, making the call a no-op — confirm.
    torch.optim.swa_utils.update_bn(dataloader_train, swa_model, device=DEVICE)
    swa_val_loss, swa_val_f1, swa_val_p, swa_val_r, swa_val_acc = validate_epoch(swa_model, dataloader_val, criterion)

    print("\n" + "="*50 + "\n               TRAINING COMPLETE\n" + "="*50)
    print(f"Best Standard Model (val F1): {best_val_f1:.5f}\nFinal SWA Model (val F1):     {swa_val_f1:.5f}")
    print(f"SWA Metrics -> P: {swa_val_p:.4f}, R: {swa_val_r:.4f}, Acc: {swa_val_acc:.4f}")

    torch.save(swa_model.state_dict(), SWA_MODEL_SAVE_PATH)
    print(f" SWA model saved to {SWA_MODEL_SAVE_PATH}")