# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import pandas as pd
import numpy as np
import os
from pathlib import Path
from tqdm import tqdm
import random
from transformers import CLIPProcessor, CLIPModel
from sklearn.metrics import f1_score
from sklearn.utils.class_weight import compute_class_weight
from torchvision import transforms
import glob
import sys
import argparse
import warnings

warnings.filterwarnings('ignore')


# --- Model / label configuration -------------------------------------------
CLIP_MODEL_ID = "openai/clip-vit-large-patch14"
# Class name -> integer id. Insertion order fixes the ids (0..3) and the order in
# which training folders are scanned.
LABEL_TO_ID_B = {
    name: idx
    for idx, name in enumerate(('Undirected', 'Individual', 'Community', 'Organization'))
}
OUTPUT_SIZE_B = len(LABEL_TO_ID_B)

# --- Training hyper-parameters ---------------------------------------------
SEED = 42
BATCH_SIZE = 8
# Batches accumulated per optimizer step (8 x 8 = 64 samples per update).
GRAD_ACCUMULATION_STEPS = 8
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
NUM_WORKERS = 2
NUM_EPOCHS = 15
MAX_LR = 2e-5
WEIGHT_DECAY = 0.1
LABEL_SMOOTHING = 0.1
EARLY_STOPPING_PATIENCE = 5

def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs so training runs are reproducible.

    Args:
        seed: Integer seed applied to ``random``, NumPy, PyTorch's CPU generator
            and, when CUDA is available, every CUDA device.

    Note:
        ``PYTHONHASHSEED`` is exported to the environment. It only affects Python
        interpreters started after this call, not hashing in the current process.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        # `deterministic` alone is not sufficient: with benchmark mode on, cuDNN
        # autotunes and may select different (non-deterministic) kernels per run.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    print(f"Random seed set to {seed} for reproducibility.")


def _normalize_name_column(df):
    """Rename an ``index`` column to ``name`` (in place) when ``name`` is absent; return ``df``."""
    if 'name' not in df.columns and 'index' in df.columns:
        df.rename(columns={'index': 'name'}, inplace=True)
    return df


def load_and_combine_data_B(base_dir):
    """Load Subtask-B training and validation data into one DataFrame.

    Training rows come from one image folder per class under
    ``<base_dir>/Subtask_B/Train/Subtask_B_Train/<category>/``. The folder names
    are the keys of ``LABEL_TO_ID_B``. Each image's text is looked up by filename
    in ``<base_dir>/Train_Text/STask_B_train.csv``, and images without a caption
    get ``""``. Validation rows are the inner join, on ``name``, of the
    validation label and text CSVs. Their image paths point into
    ``<base_dir>/Subtask_B/Evaluation/STask_B_val_img``.

    In every CSV, a column called ``index`` is treated as ``name`` when no
    ``name`` column exists.

    Args:
        base_dir: Root directory of the shared-task data.

    Returns:
        DataFrame with columns ``name``, ``text``, ``label`` (int), ``img_path``
        and ``split`` (``'train'`` or ``'val'``). Validation rows also carry any
        extra columns present in their CSVs. Missing captions are ``""``, never
        the string ``"nan"``.
    """
    train_images_root_dir = os.path.join(base_dir, "Subtask_B", "Train", "Subtask_B_Train")
    train_text_csv_file = os.path.join(base_dir, "Train_Text", "STask_B_train.csv")
    val_images_dir = os.path.join(base_dir, "Subtask_B", "Evaluation", "STask_B_val_img")
    val_text_csv_file = os.path.join(base_dir, "Eval_Data_Text", "STask-B(index,text)val.csv")
    val_labels_csv_file = os.path.join(base_dir, "Eval_Data_Labels", "STask-B(index,label)val.csv")

    # --- Training split -----------------------------------------------------
    train_text_df = _normalize_name_column(pd.read_csv(train_text_csv_file))
    # Keep one row per filename. With duplicates, a `.loc` lookup returns a
    # Series, and str() of that is a multi-line repr rather than the caption.
    train_text_df = train_text_df.drop_duplicates(subset='name', keep='first')
    # NaN captions would otherwise be stringified to the literal "nan".
    text_by_name = dict(zip(train_text_df['name'], train_text_df['text'].fillna('').astype(str)))

    train_data = []
    for category, label_id in LABEL_TO_ID_B.items():
        folder_path = os.path.join(train_images_root_dir, category)
        # glob order is filesystem-dependent; sort so row order is reproducible.
        for img_path in sorted(glob.glob(os.path.join(folder_path, '*.*'))):
            img_filename = os.path.basename(img_path)
            train_data.append({'name': img_filename, 'text': text_by_name.get(img_filename, ""),
                               'label': label_id, 'img_path': img_path, 'split': 'train'})
    # Explicit columns keep the schema even when no training images are found.
    train_df = pd.DataFrame(train_data, columns=['name', 'text', 'label', 'img_path', 'split'])

    # --- Validation split ---------------------------------------------------
    val_labels_df = _normalize_name_column(pd.read_csv(val_labels_csv_file))
    val_text_df = _normalize_name_column(pd.read_csv(val_text_csv_file))
    val_df = pd.merge(val_labels_df, val_text_df, on='name', how='inner')
    val_df = val_df.dropna(subset=['label'])
    val_df['label'] = val_df['label'].astype(int)
    if 'text' in val_df.columns:
        val_df['text'] = val_df['text'].fillna('')
    val_df['img_path'] = val_df['name'].apply(lambda x: os.path.join(val_images_dir, x))
    val_df['split'] = 'val'
    return pd.concat([train_df, val_df], ignore_index=True)

class MultimodalDataset(Dataset):
    """Map-style dataset that yields image, tokenized text and label for each DataFrame row.

    Args:
        df: DataFrame with at least ``text``, ``label`` and ``img_path`` columns.
            Rows are addressed positionally via ``iloc``.
        tokenizer: Callable invoked as
            ``tokenizer(text, padding='max_length', max_length=77, truncation=True,
            return_tensors="pt")`` that returns a mapping with ``input_ids`` and
            ``attention_mask`` tensors, presumably of shape (1, 77).
        transform: Callable that maps an RGB PIL image to a tensor.
    """

    def __init__(self, df, tokenizer, transform):
        self.df, self.tokenizer, self.transform = df, tokenizer, transform

    def __len__(self):
        return len(self.df)

    def _load_image(self, img_path):
        """Return the transformed RGB image, or a (3, 224, 224) float32 zero tensor on any failure."""
        try:
            # The context manager closes the file handle as soon as the pixels are
            # converted, instead of leaving it open until garbage collection. That
            # matters when several DataLoader workers open thousands of files.
            with Image.open(img_path) as img:
                rgb = img.convert("RGB")
            return self.transform(rgb)
        except Exception:
            # Deliberate best-effort: a missing or corrupt image becomes a blank
            # image, so one bad file does not abort the epoch.
            return torch.zeros((3, 224, 224), dtype=torch.float32)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        text = self.tokenizer(str(row['text']), padding='max_length', max_length=77,
                              truncation=True, return_tensors="pt")
        label = torch.tensor(row['label'], dtype=torch.long)
        image = self._load_image(row['img_path'])
        # squeeze(0) drops only the batch dimension added by return_tensors="pt".
        return {'image': image,
                'text_input_ids': text['input_ids'].squeeze(0),
                'text_attention_mask': text['attention_mask'].squeeze(0),
                'label': label}


class CoFTModel(nn.Module):
    """Co-attentional fusion classifier on top of a CLIP backbone.

    The CLIP image and text embeddings from ``get_image_features`` and
    ``get_text_features`` are each turned into a length-1 sequence and refined
    by a small Transformer encoder. They are then fused with text-to-image
    cross-attention. The concatenation ``[refined text, attended image]`` goes
    through an MLP head that produces logits.

    Args:
        clip_model: CLIP-like module exposing ``config.projection_dim``,
            ``get_image_features(pixel_values=...)`` and
            ``get_text_features(input_ids=..., attention_mask=...)``.
        num_classes: Number of output logits.
        dropout: Dropout inside the encoder layers and the cross-attention.
        nhead: Attention heads for the encoders and the cross-attention. Must
            divide ``projection_dim``.
        num_layers: Number of layers in each of the text and image encoders.
        head_dropout: Dropout inside the classification head.
    """

    def __init__(self, clip_model, num_classes, dropout=0.3, nhead=8, num_layers=2, head_dropout=0.5):
        super().__init__()
        self.clip = clip_model
        embed_dim = self.clip.config.projection_dim
        # nn.TransformerEncoder deep-copies this template layer, so the text and
        # image encoders do NOT share weights. Construction order is kept fixed
        # so seeded initialisation matches earlier checkpoints.
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=nhead, dim_feedforward=embed_dim * 4,
                                                   dropout=dropout, activation='gelu', batch_first=True)
        self.text_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.image_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.cross_attention = nn.MultiheadAttention(embed_dim, num_heads=nhead, dropout=dropout, batch_first=True)
        # Indices 0..4 of this Sequential are part of the saved state_dict keys.
        self.head = nn.Sequential(
            nn.LayerNorm(embed_dim * 2),
            nn.Linear(embed_dim * 2, embed_dim),
            nn.GELU(),
            nn.Dropout(head_dropout),
            nn.Linear(embed_dim, num_classes)
        )

    def forward(self, image, text_input_ids, text_attention_mask):
        """Return class logits of shape (batch, num_classes)."""
        # unsqueeze(1): each modality becomes a sequence of length 1, shape (B, 1, D).
        image_feat = self.clip.get_image_features(pixel_values=image).unsqueeze(1)
        text_feat = self.clip.get_text_features(input_ids=text_input_ids, attention_mask=text_attention_mask).unsqueeze(1)
        refined_image_feat = self.image_encoder(image_feat)
        refined_text_feat = self.text_encoder(text_feat)
        # With a single key the attention weights are trivially 1, so this reduces
        # to a learned projection of the refined image feature.
        text_q_img_context, _ = self.cross_attention(query=refined_text_feat, key=refined_image_feat, value=refined_image_feat)
        fused_feat = torch.cat([refined_text_feat, text_q_img_context], dim=-1).squeeze(1)
        logits = self.head(fused_feat)
        return logits


def train_epoch(model, data_loader, criterion, optimizer, scheduler, scaler):
    """Run one training epoch with mixed precision and gradient accumulation.

    Each batch loss is divided by ``GRAD_ACCUMULATION_STEPS``, and gradients from
    that many consecutive batches are accumulated. After each full window the
    gradients are unscaled and clipped to norm 1.0, the optimizer steps, and
    ``scheduler`` steps once.

    If the epoch ends in the middle of a window, the leftover gradients are
    applied with one final optimizer step instead of being discarded. The
    scheduler is not stepped for that partial window, so it still advances
    exactly once per full window. That keeps its step count equal to
    ``len(data_loader) // GRAD_ACCUMULATION_STEPS`` per epoch, the value the
    OneCycleLR schedule is sized for.

    Args:
        model: Module called as ``model(images, input_ids, attention_mask)``.
        data_loader: Yields dicts with ``image``, ``text_input_ids``,
            ``text_attention_mask`` and ``label`` tensors.
        criterion: Loss function ``criterion(logits, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: LR scheduler stepped once per full accumulation window.
        scaler: AMP gradient scaler.

    Returns:
        Tuple ``(mean_loss, weighted_f1)``. The mean is the un-divided per-batch
        loss averaged over ``len(data_loader)``. F1 is computed on predictions
        made in train mode, so dropout is active.
    """
    model.train()
    total_loss = 0.0
    all_preds, all_labels = [], []

    def _apply_accumulated_gradients():
        # Unscale first so clipping uses the true (unscaled) gradient norm.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

    optimizer.zero_grad()
    has_pending_grads = False
    for i, batch in enumerate(tqdm(data_loader, desc="Training", leave=False)):
        images = batch['image'].to(DEVICE)
        input_ids = batch['text_input_ids'].to(DEVICE)
        attention_mask = batch['text_attention_mask'].to(DEVICE)
        labels = batch['label'].to(DEVICE)

        with torch.cuda.amp.autocast():
            logits = model(images, input_ids, attention_mask)
            loss = criterion(logits, labels) / GRAD_ACCUMULATION_STEPS
        scaler.scale(loss).backward()
        has_pending_grads = True
        if (i + 1) % GRAD_ACCUMULATION_STEPS == 0:
            _apply_accumulated_gradients()
            scheduler.step()
            has_pending_grads = False
        total_loss += loss.item() * GRAD_ACCUMULATION_STEPS
        all_preds.extend(torch.argmax(logits, dim=1).cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

    if has_pending_grads:
        # Trailing partial window: apply it rather than silently dropping it.
        # No scheduler.step() here (see docstring).
        _apply_accumulated_gradients()
    return total_loss / len(data_loader), f1_score(all_labels, all_preds, average='weighted', zero_division=0)

@torch.no_grad()
def validate_epoch(model, data_loader, criterion):
    """Evaluate ``model`` on ``data_loader`` in eval mode without tracking gradients.

    Args:
        model: Module called as ``model(images, input_ids, attention_mask)``.
        data_loader: Yields dicts with ``image``, ``text_input_ids``,
            ``text_attention_mask`` and ``label`` tensors.
        criterion: Loss function ``criterion(logits, labels)``.

    Returns:
        Tuple ``(mean_loss, weighted_f1)``: the per-batch loss averaged over
        ``len(data_loader)``, and weighted F1 over every validation sample.
    """
    model.eval()
    running_loss = 0.0
    predictions, targets = [], []
    for batch in tqdm(data_loader, desc="Validating", leave=False):
        model_inputs = (
            batch['image'].to(DEVICE),
            batch['text_input_ids'].to(DEVICE),
            batch['text_attention_mask'].to(DEVICE),
        )
        gold = batch['label'].to(DEVICE)
        with torch.cuda.amp.autocast():
            scores = model(*model_inputs)
            batch_loss = criterion(scores, gold)
        running_loss += batch_loss.item()
        predictions.extend(scores.argmax(dim=1).cpu().numpy())
        targets.extend(gold.cpu().numpy())
    mean_loss = running_loss / len(data_loader)
    return mean_loss, f1_score(targets, predictions, average='weighted', zero_division=0)


if __name__ == "__main__":
    # Entry point: train CoFT end-to-end on Subtask B and keep the checkpoint
    # with the best validation weighted F1 (with early stopping).
    parser = argparse.ArgumentParser(description="Train CoFT model for Subtask B: Target Classification.")
    parser.add_argument(
        '--base_dir',
        type=str,
        default="./SharedTaskProject",
        help="Path to the base project directory containing the Subtask data folders."
    )
    args = parser.parse_args()

    seed_everything(SEED)

    # --- Output locations ---------------------------------------------------
    BASE_PROJECT_DIR = args.base_dir
    CURRENT_EXP_NAME = "CoFT_TargetClassification_EndToEnd_v1"
    MODEL_SAVE_PATH_B = os.path.join(BASE_PROJECT_DIR, "models", "Subtask_B", CURRENT_EXP_NAME)
    Path(MODEL_SAVE_PATH_B).mkdir(parents=True, exist_ok=True)
    BEST_MODEL_SAVE_PATH = os.path.join(MODEL_SAVE_PATH_B, "best_model_subtask_B.pt")
    print(f"This run's models will be saved in: {MODEL_SAVE_PATH_B}")

    print("Loading CLIP base model and processor...")
    clip_model = CLIPModel.from_pretrained(CLIP_MODEL_ID)
    processor = CLIPProcessor.from_pretrained(CLIP_MODEL_ID)

    print("Loading and preparing dataframes...")
    all_data_df = load_and_combine_data_B(BASE_PROJECT_DIR)
    train_df = all_data_df[all_data_df['split'] == 'train'].reset_index(drop=True)
    val_df = all_data_df[all_data_df['split'] == 'val'].reset_index(drop=True)

    # --- Image pipelines (normalised with CLIP's own mean/std) --------------
    norm_mean, norm_std = processor.image_processor.image_mean, processor.image_processor.image_std
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
        transforms.RandAugment(num_ops=2, magnitude=10),
        transforms.ToTensor(),
        transforms.Normalize(mean=norm_mean, std=norm_std)])
    val_transform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=norm_mean, std=norm_std)])

    train_dataset = MultimodalDataset(train_df, processor.tokenizer, train_transform)
    val_dataset = MultimodalDataset(val_df, processor.tokenizer, val_transform)
    dataloader_train = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
    dataloader_val = DataLoader(val_dataset, batch_size=BATCH_SIZE*2, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

    print("\n--- Training Co-Attentional Fusion Transformer (CoFT) End-to-End ---")
    model = CoFTModel(clip_model, OUTPUT_SIZE_B).to(DEVICE)

    # 'balanced' weights are inversely proportional to class frequency in the training split.
    class_weights = torch.tensor(compute_class_weight('balanced', classes=np.unique(train_df['label']), y=train_df['label']), dtype=torch.float).to(DEVICE)
    criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=LABEL_SMOOTHING)

    # CLIP weights live under the `clip.` prefix (CoFTModel.clip). Match the
    # prefix rather than any substring, so a new module whose name merely
    # contains "clip" is not misfiled as backbone.
    head_params = [p for n, p in model.named_parameters() if not n.startswith('clip.')]
    backbone_params = [p for n, p in model.named_parameters() if n.startswith('clip.')]
    optimizer = torch.optim.AdamW([
        {'params': head_params, 'lr': MAX_LR},
        {'params': backbone_params, 'lr': MAX_LR / 10}
    ], weight_decay=WEIGHT_DECAY)

    # The scheduler steps once per full accumulation window. OneCycleLR rejects
    # steps_per_epoch == 0, which happens when there are fewer batches than
    # GRAD_ACCUMULATION_STEPS, so clamp it to at least 1.
    steps_per_epoch = max(1, len(dataloader_train) // GRAD_ACCUMULATION_STEPS)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=[MAX_LR, MAX_LR/10],
                                                    epochs=NUM_EPOCHS,
                                                    steps_per_epoch=steps_per_epoch,
                                                    pct_start=0.2)

    # Loss scaling is only meaningful for CUDA fp16; disable it explicitly on CPU.
    scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE == 'cuda'))
    best_val_f1 = -1.0
    epochs_no_improve = 0

    for epoch in range(NUM_EPOCHS):
        train_loss, train_f1 = train_epoch(model, dataloader_train, criterion, optimizer, scheduler, scaler)
        val_loss, val_f1 = validate_epoch(model, dataloader_val, criterion)

        # param_groups[0] is the head group defined above.
        print(f"Epoch {epoch+1}/{NUM_EPOCHS} -> Train F1: {train_f1:.4f} | Val F1: {val_f1:.4f} | Head LR: {optimizer.param_groups[0]['lr']:.2e}")

        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            epochs_no_improve = 0
            torch.save(model.state_dict(), BEST_MODEL_SAVE_PATH)
            print(f"** New best model saved! Val F1: {val_f1:.4f} **")
        else:
            epochs_no_improve += 1
            print(f"Validation F1 did not improve. Patience: {epochs_no_improve}/{EARLY_STOPPING_PATIENCE}")

        if epochs_no_improve >= EARLY_STOPPING_PATIENCE:
            print(f"\nEarly stopping triggered at epoch {epoch+1}.")
            break

    print("\n" + "="*50 + "\n TRAINING COMPLETE\n" + "="*50)
    print(f"Best Model F1 Score: {best_val_f1:.5f}")
    print(f"Model saved to {BEST_MODEL_SAVE_PATH}")