In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import pandas as pd
import numpy as np
import os
from pathlib import Path
from tqdm import tqdm
import random
from transformers import CLIPProcessor, CLIPModel
from sklearn.metrics import f1_score
from sklearn.utils.class_weight import compute_class_weight
from torchvision import transforms
import glob
import sys
import argparse
import warnings

# Suppress every Python warning for the whole process. Note this also hides
# deprecation and data-quality warnings raised by the libraries used below.
warnings.filterwarnings('ignore')


# ---- Model / label configuration ----
CLIP_MODEL_ID = "openai/clip-vit-large-patch14"
D_CLIP_HIDDEN = 1024            # width of the projection / adapter / pre-output head
RESIDUAL_ALPHA = 0.2            # adapter blend: alpha * adapter(x) + (1 - alpha) * x
COSINE_CLASSIFIER_SCALE = 20.0  # multiplier applied to cosine-similarity logits
OUTPUT_SIZE_D = 2
LABEL_TO_ID_D = {'No Humor': 0, 'Humor': 1}
ID_TO_LABEL_D = {label_id: label_name for label_name, label_id in LABEL_TO_ID_D.items()}


# ---- Runtime ----
SEED = 42
BATCH_SIZE = 32
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_WORKERS = 2

# ---- Two-stage schedule ----
HEAD_ONLY_EPOCHS = 5    # stage 1: CLIP frozen, only the head is trained
FULL_MODEL_EPOCHS = 20  # stage 2: head plus the last LAYERS_TO_UNFREEZE vision-encoder layers
LAYERS_TO_UNFREEZE = 2  # number of trailing CLIP vision-encoder layers unfrozen in stage 2


# ---- Optimisation ----
MAX_LR_HEAD = 1e-4
MAX_LR_BACKBONE = 1e-6
WEIGHT_DECAY = 1e-2
LABEL_SMOOTHING = 0.1
GRADIENT_CLIP_VAL = 1.0        # max total gradient norm passed to clip_grad_norm_
EARLY_STOPPING_PATIENCE = 5    # validation epochs without F1 improvement before stopping

def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs for reproducible runs.

    Also exports ``PYTHONHASHSEED``. When CUDA is available, it seeds every GPU
    and switches cuDNN to deterministic, non-benchmarking mode. Prints a short
    confirmation line.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        cudnn = torch.backends.cudnn
        cudnn.deterministic, cudnn.benchmark = True, False

    print(f"Random seed set to {seed}")


def load_data_D(split, text_csv, img_dir, label_csv=None, label_map=None):
    """Load one Subtask D split as a DataFrame.

    Args:
        split: ``'train'`` scans ``img_dir/<category>/*.*`` for each category in
            ``label_map``. Any other value reads labels from ``label_csv`` and
            resolves images directly inside ``img_dir``.
        text_csv: CSV with an ``index`` column (image file name) and a ``text`` column.
        img_dir: Root image directory.
        label_csv: CSV with ``index`` and ``label`` columns (non-train splits only).
        label_map: Mapping of category folder name -> integer label (train split only).

    Returns:
        DataFrame that includes ``name``, ``text``, ``label`` and ``img_path`` columns.
        Images without a caption (or with an empty/NaN caption) get ``""`` as text.
        Train rows are ordered by ``label_map`` iteration order, then by sorted file path.

    Raises:
        ValueError: if ``split == 'train'`` and ``label_map`` is None.
    """
    text_df = (
        pd.read_csv(text_csv)
        .rename(columns={'index': 'name'})
        # If a name appears more than once, the first caption wins.
        .drop_duplicates(subset='name', keep='first')
    )
    # NaN captions would otherwise reach the tokenizer as the string "nan".
    text_df = text_df.assign(text=text_df['text'].fillna(""))

    if split == 'train':
        if label_map is None:
            raise ValueError("label_map is required when split == 'train'")
        # One O(1) dict lookup per image instead of re-scanning text_df each time.
        name_to_text = text_df.set_index('name')['text'].to_dict()
        rows = []
        for category, label_id in label_map.items():
            folder = os.path.join(img_dir, category)
            # glob order is filesystem-dependent; sort for reproducible datasets.
            for img_path in sorted(glob.glob(os.path.join(folder, '*.*'))):
                name = os.path.basename(img_path)
                rows.append({
                    'name': name,
                    'text': name_to_text.get(name, ""),
                    'label': label_id,
                    'img_path': img_path,
                })
        # Explicit columns keep the schema even when no images are found.
        return pd.DataFrame(rows, columns=['name', 'text', 'label', 'img_path'])

    # Validation / evaluation split.
    labels_df = pd.read_csv(label_csv).rename(columns={'index': 'name'})
    # Left join: keep every labelled image even if its caption is missing,
    # consistent with the "" fallback used for the train split.
    df = labels_df.merge(text_df, on='name', how='left')
    df['text'] = df['text'].fillna("")
    df['img_path'] = [os.path.join(img_dir, name) for name in df['name']]
    return df

class MultimodalDataset(Dataset):
    """PyTorch Dataset yielding ``{"image", "text", "label"}`` dicts from a DataFrame.

    ``data_df`` must provide ``'text'``, ``'label'`` and ``'img_path'`` columns.
    ``processor.tokenizer`` turns the caption into padded/truncated token ids, and
    ``transform`` maps an RGB PIL image to a tensor.
    """
    def __init__(self, data_df, processor, transform):
        self.data = data_df
        self.processor = processor
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data.iloc[idx]
        label = torch.tensor(item['label'], dtype=torch.long)

        raw_text = item['text']
        # Missing captions (NaN/None) are tokenized as "" rather than the literal "nan".
        caption = "" if pd.isna(raw_text) else str(raw_text)
        text = self.processor.tokenizer(
            caption, return_tensors="pt", padding='max_length', truncation=True
        ).input_ids.squeeze(0)

        try:
            # Context manager closes the file handle even if decoding/convert fails.
            with Image.open(item['img_path']) as img:
                image = self.transform(img.convert("RGB"))
        except Exception:
            # Deliberate best-effort: a missing/unreadable image becomes an
            # all-zero 3x224x224 tensor so one bad file does not stop training.
            image = torch.zeros((3, 224, 224), dtype=torch.float)

        return {"image": image, "text": text, "label": label}


class FeatureAdapter(nn.Module):
    """Bottleneck MLP (dim -> dim // 2 -> dim, GELU in between) blended with its input.

    ``forward(x, alpha)`` returns ``alpha * adapter(x) + (1 - alpha) * x``, so
    ``alpha = 0`` is the identity and ``alpha = 1`` is the pure MLP output.
    """
    def __init__(self, dim):
        super().__init__()
        bottleneck = dim // 2
        layers = [nn.Linear(dim, bottleneck), nn.GELU(), nn.Linear(bottleneck, dim)]
        # Kept as an nn.Sequential named ``adapter`` so state_dict keys are stable.
        self.adapter = nn.Sequential(*layers)

    def forward(self, x, alpha):
        adapted = self.adapter(x)
        keep = 1.0 - alpha
        return alpha * adapted + keep * x

class CosineClassifier(nn.Module):
    """Scaled cosine-similarity classifier with prompt-initialised class weights.

    Each prompt is encoded with ``clip_model.get_text_features``. The embedding
    is then mapped to ``in_dim`` by a freshly (randomly) initialised ``nn.Linear``
    that is not kept as a submodule, and L2-normalised. The result is the initial
    value of the learnable ``weight`` matrix, one row per class.

    Args:
        in_dim: Feature width of the inputs to ``forward``.
        out_dim: Number of classes; must equal ``len(prompts)``.
        scale: Multiplier applied to the cosine similarities.
        prompts: One text prompt per class.
        clip_model: Model exposing ``config.projection_dim``, ``device`` and
            ``get_text_features``.
        processor: Object whose ``tokenizer`` encodes ``prompts``.

    Raises:
        ValueError: if the number of prompts differs from ``out_dim``.
    """
    def __init__(self, in_dim, out_dim, scale, prompts, clip_model, processor):
        super().__init__()
        if len(prompts) != out_dim:
            # Previously the weight was silently resized to len(prompts) rows.
            raise ValueError(f"Expected {out_dim} prompts (one per class), got {len(prompts)}")
        self.scale = scale
        with torch.no_grad():
            device = clip_model.device
            sai_proj = nn.Linear(clip_model.config.projection_dim, in_dim).to(device)
            tokenized = processor.tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).to(device)
            raw_embeds = clip_model.get_text_features(**tokenized)
            init_weight = F.normalize(sai_proj(raw_embeds.float()), p=2, dim=1)
        # Build the Parameter directly from the computed init (no uninitialised
        # torch.Tensor(...) placeholder, no .data overwrite).
        self.weight = nn.Parameter(init_weight.clone())

    def forward(self, x):
        # logits[b, c] = scale * cos(x[b], weight[c])
        return F.linear(F.normalize(x, p=2, dim=1), F.normalize(self.weight, p=2, dim=1)) * self.scale

class SotaMemeCLIP(nn.Module):
    """CLIP image/text features -> per-modality projection + adapter -> product fusion -> cosine head.

    ``clip_model`` is held as a submodule, so whether its weights train is
    controlled by the caller through ``requires_grad``. Class prompts for the
    cosine classifier are ``"This is a meme about <label>"`` for each class id
    ``0..num_classes-1`` looked up in ``id_to_label_map``.
    """
    def __init__(self, clip_model, processor, num_classes, id_to_label_map):
        super().__init__()
        self.clip_model = clip_model
        proj_dim = clip_model.config.projection_dim
        hidden = D_CLIP_HIDDEN

        # Registration order below is kept fixed: it determines parameter
        # initialisation order (RNG draws) and state_dict layout.
        self.image_proj = nn.Linear(proj_dim, hidden)
        self.text_proj = nn.Linear(proj_dim, hidden)
        self.image_adapter = FeatureAdapter(hidden)
        self.text_adapter = FeatureAdapter(hidden)
        self.pre_output = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(hidden, hidden),
            nn.GELU(),
            nn.Dropout(0.1),
        )

        class_prompts = []
        for class_idx in range(num_classes):
            class_prompts.append(f"This is a meme about {id_to_label_map[class_idx]}")
        self.cosine_classifier = CosineClassifier(
            hidden, num_classes, COSINE_CLASSIFIER_SCALE, class_prompts, clip_model, processor
        )

    def forward(self, images, texts):
        clip = self.clip_model
        # CLIP forward builds a graph only in training mode.
        with torch.set_grad_enabled(self.training):
            image_emb = clip.get_image_features(pixel_values=images)
            text_emb = clip.get_text_features(input_ids=texts)

        image_h = self.image_adapter(self.image_proj(image_emb), RESIDUAL_ALPHA)
        text_h = self.text_adapter(self.text_proj(text_emb), RESIDUAL_ALPHA)
        # Element-wise product fuses the two modalities.
        return self.cosine_classifier(self.pre_output(image_h * text_h))


def compute_loss(logits, targets, smoothing=0.0):
    """Mean cross-entropy of ``logits`` (batch, classes) against integer ``targets``.

    When ``smoothing > 0``, the targets are soft labels
    ``(1 - smoothing) * one_hot + smoothing / num_classes``. Otherwise this is
    plain ``F.cross_entropy``.
    """
    if not smoothing > 0:
        return F.cross_entropy(logits, targets)

    num_classes = logits.size(1)
    one_hot = F.one_hot(targets, num_classes=num_classes)
    soft_targets = one_hot * (1 - smoothing) + smoothing / num_classes
    log_p = F.log_softmax(logits, dim=1)
    return -(soft_targets * log_p).sum(dim=1).mean()

def calculate_metrics(preds, labels):
    """Return the support-weighted F1 score of predicted vs. true class ids.

    Both tensors are moved to CPU and converted to NumPy, then passed to
    ``f1_score(..., average='weighted')``.
    """
    y_true = labels.cpu().numpy()
    y_pred = preds.cpu().numpy()
    return f1_score(y_true, y_pred, average='weighted')

def save_checkpoint(model, optimizer, epoch, path):
    """Serialize the epoch index and the model/optimizer state dicts to ``path`` with ``torch.save``."""
    payload = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
    )
    torch.save(payload, path)

def _set_stage_trainability(model, stage):
    """Set ``requires_grad`` flags for a training stage.

    Every stage first freezes all of ``model.clip_model``.

    - Stage 1 then makes every other parameter of ``model`` trainable.
    - Later stages unfreeze only the last ``LAYERS_TO_UNFREEZE`` CLIP
      vision-encoder layers. Parameters outside ``clip_model`` keep whatever
      flag they already have (trainable, after stage 1).
    """
    clip_params = list(model.clip_model.parameters())
    for param in clip_params:
        param.requires_grad = False

    if stage == 1:
        # Identity set: O(1) membership instead of rescanning CLIP params per parameter.
        clip_param_ids = {id(p) for p in clip_params}
        for param in model.parameters():
            if id(param) not in clip_param_ids:
                param.requires_grad = True
    else:
        encoder_layers = model.clip_model.vision_model.encoder.layers
        # range(-N, 0) rather than a [-N:] slice, so N == 0 unfreezes nothing.
        for i in range(-LAYERS_TO_UNFREEZE, 0):
            for param in encoder_layers[i].parameters():
                param.requires_grad = True


def _train_one_epoch(model, train_loader, optimizer, scheduler, desc):
    """Run one optimisation pass over ``train_loader``; return the mean batch loss.

    Gradients are clipped to ``GRADIENT_CLIP_VAL`` total norm. The scheduler (if
    any) is stepped after every optimizer step. Returns 0.0 when the loader yields
    no batches, instead of dividing by zero.
    """
    model.train()
    running_loss = 0.0
    num_batches = 0
    for batch in tqdm(train_loader, desc=desc, leave=False):
        images = batch['image'].to(DEVICE)
        texts = batch['text'].to(DEVICE)
        labels = batch['label'].to(DEVICE)
        optimizer.zero_grad()
        outputs = model(images, texts)
        loss = compute_loss(outputs, labels, LABEL_SMOOTHING)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIP_VAL)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        running_loss += loss.item()
        num_batches += 1
    return running_loss / num_batches if num_batches else 0.0


def _evaluate_f1(model, val_loader):
    """Return the weighted F1 of argmax predictions over ``val_loader`` (0.0 if it is empty)."""
    model.eval()
    pred_chunks, label_chunks = [], []
    with torch.no_grad():
        for batch in val_loader:
            images = batch['image'].to(DEVICE)
            texts = batch['text'].to(DEVICE)
            labels = batch['label'].to(DEVICE)
            outputs = model(images, texts)
            pred_chunks.append(torch.argmax(outputs, dim=1))
            label_chunks.append(labels)
    if not pred_chunks:
        return 0.0
    return calculate_metrics(torch.cat(pred_chunks), torch.cat(label_chunks))


def train_model(model, train_loader, val_loader, model_save_path, optimizer, scheduler=None):
    """Two-stage training loop with per-stage early stopping on validation weighted F1.

    - Stage 1 runs ``HEAD_ONLY_EPOCHS`` epochs with CLIP frozen.
    - Stage 2 runs ``FULL_MODEL_EPOCHS`` epochs with the last
      ``LAYERS_TO_UNFREEZE`` CLIP vision-encoder layers also trainable.

    The best F1 is tracked across both stages. Each time it improves, a
    checkpoint is written to ``<model_save_path>/best_model_stage{stage}.pt``.
    The patience counter restarts at the beginning of each stage, and early
    stopping ends only the current stage, so stage 2 always gets a chance to run.

    Args:
        model: Model exposing ``clip_model`` (with ``vision_model.encoder.layers``)
            and ``forward(images, texts) -> logits``.
        train_loader / val_loader: Iterables of ``{"image", "text", "label"}`` batches.
        model_save_path: Existing directory for checkpoints.
        optimizer: Optimizer covering the model's parameters.
        scheduler: Optional LR scheduler, stepped once per batch.

    Returns:
        The best validation weighted F1 observed (0.0 if it never exceeded 0).
    """
    best_val_f1 = 0.0

    for stage, epochs in enumerate([HEAD_ONLY_EPOCHS, FULL_MODEL_EPOCHS], start=1):
        print(f"\n=== Training Stage {stage} for {epochs} epochs ===")
        _set_stage_trainability(model, stage)
        # Fresh patience window per stage. Previously the stage-1 counter leaked
        # into stage 2 and could stop it after a single epoch.
        patience_counter = 0

        for epoch in range(epochs):
            avg_train_loss = _train_one_epoch(
                model, train_loader, optimizer, scheduler, desc=f"Epoch {epoch + 1}/{epochs}"
            )
            val_f1 = _evaluate_f1(model, val_loader)
            print(f"Epoch {epoch + 1}: Train Loss={avg_train_loss:.4f} | Val Weighted F1={val_f1:.4f}")

            if val_f1 > best_val_f1:
                best_val_f1 = val_f1
                patience_counter = 0
                checkpoint_path = os.path.join(model_save_path, f"best_model_stage{stage}.pt")
                save_checkpoint(model, optimizer, epoch, checkpoint_path)
                print(f"Saved best model at epoch {epoch + 1} with F1: {val_f1:.4f}")
            else:
                patience_counter += 1
                if patience_counter >= EARLY_STOPPING_PATIENCE:
                    print(f"No improvement for {EARLY_STOPPING_PATIENCE} epochs, early stopping stage {stage}...")
                    # End this stage only; the next stage (if any) still runs.
                    break

    return best_val_f1


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train MemeCLIP model for Subtask D: Humor Detection.")
    parser.add_argument(
        '--base_dir',
        type=str,
        default="./SharedTaskProject",
        help="Path to the base project directory containing the Subtask data folders."
    )
    # parse_known_args() instead of parse_args(): inside a Jupyter/IPython kernel
    # (this script is a notebook cell) sys.argv carries kernel options such as
    # "-f <connection-file>", which would make parse_args() exit with an error.
    args, _unknown_args = parser.parse_known_args()

    seed_everything(SEED)

    BASE_PROJECT_DIR = args.base_dir
    MODELS_DIR_D = os.path.join(BASE_PROJECT_DIR, "models", "Subtask_D")
    EXP_NAME_D = "memeclip_humor_sota_v1"
    MODEL_SAVE_PATH_D = os.path.join(MODELS_DIR_D, EXP_NAME_D)
    Path(MODEL_SAVE_PATH_D).mkdir(parents=True, exist_ok=True)
    print(f"This run's models will be saved in: {MODEL_SAVE_PATH_D}")

    TRAIN_IMAGES_ROOT_DIR_D = os.path.join(BASE_PROJECT_DIR, "Subtask_D", "Train", "Subtask_D_Train")
    TRAIN_TEXT_CSV_FILE_D = os.path.join(BASE_PROJECT_DIR, "Train_Text", "STask_D_train.csv")
    VAL_IMAGES_DIR_D = os.path.join(BASE_PROJECT_DIR, "Subtask_D", "Evaluation", "STask_D_val_img")
    VAL_TEXT_CSV_FILE_D = os.path.join(BASE_PROJECT_DIR, "Eval_Data_Text", "STask-D(index,text)val.csv")
    VAL_LABELS_CSV_FILE_D = os.path.join(BASE_PROJECT_DIR, "Eval_Data_Labels", "STask-D(index,label)val.csv")

    print(f"Loading CLIP model and processor: {CLIP_MODEL_ID}...")
    try:
        processor = CLIPProcessor.from_pretrained(CLIP_MODEL_ID)
        clip_model = CLIPModel.from_pretrained(CLIP_MODEL_ID).to(DEVICE)
    except Exception as e:
        sys.exit(f"ERROR: Could not load CLIP model: {e}.")

    print("Loading and preparing datasets...")
    train_df = load_data_D('train', TRAIN_TEXT_CSV_FILE_D, TRAIN_IMAGES_ROOT_DIR_D, label_map=LABEL_TO_ID_D)
    val_df = load_data_D('val', VAL_TEXT_CSV_FILE_D, VAL_IMAGES_DIR_D, label_csv=VAL_LABELS_CSV_FILE_D)
    # Fail fast with a clear message. An empty train set would otherwise raise an
    # opaque sampler error, and an empty val set would fail only after a full epoch.
    if len(train_df) == 0:
        sys.exit(f"ERROR: No training images found under {TRAIN_IMAGES_ROOT_DIR_D}.")
    if len(val_df) == 0:
        sys.exit(f"ERROR: No validation samples loaded (check {VAL_LABELS_CSV_FILE_D}).")

    # Image normalisation mean/std used by OpenAI CLIP preprocessing, shared by both pipelines.
    clip_normalize = transforms.Normalize(
        (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)
    )
    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        clip_normalize,
    ])
    val_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        clip_normalize,
    ])

    train_dataset = MultimodalDataset(train_df, processor, train_transform)
    val_dataset = MultimodalDataset(val_df, processor, val_transform)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
    print(f"Train samples: {len(train_dataset)}, Validation samples: {len(val_dataset)}")

    model = SotaMemeCLIP(clip_model, processor, OUTPUT_SIZE_D, ID_TO_LABEL_D).to(DEVICE)
    # Backbone at a much smaller LR than the head. Frozen backbone params
    # (requires_grad=False) receive no grad, so the optimizer leaves them unchanged.
    params = [
        {'params': model.clip_model.parameters(), 'lr': MAX_LR_BACKBONE},
        {'params': model.image_proj.parameters(), 'lr': MAX_LR_HEAD},
        {'params': model.text_proj.parameters(), 'lr': MAX_LR_HEAD},
        {'params': model.image_adapter.parameters(), 'lr': MAX_LR_HEAD},
        {'params': model.text_adapter.parameters(), 'lr': MAX_LR_HEAD},
        {'params': model.pre_output.parameters(), 'lr': MAX_LR_HEAD},
        {'params': model.cosine_classifier.parameters(), 'lr': MAX_LR_HEAD},
    ]
    optimizer = torch.optim.AdamW(params, weight_decay=WEIGHT_DECAY)
    scheduler = None

    print("Starting training...")
    train_model(model, train_loader, val_loader, MODEL_SAVE_PATH_D, optimizer, scheduler)
    print("Training completed!")