In [None]:
!pip -q install transformers==4.42.4 accelerate datasets==2.21.0

In [None]:
!git clone https://github.com/jefferyYu/UMT.git

In [None]:

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from transformers import BertTokenizer, BertModel
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_curve, auc
from tqdm import tqdm
import matplotlib.pyplot as plt
import ast
import gdown
import numpy as np
import torch
from tqdm import tqdm
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    classification_report,
    confusion_matrix
)
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
def parse_twitter_conll(path: str):
    """Read a Twitter2015-style CoNLL file and return one dict per sample.

    The file is a sequence of blocks separated by blank lines. Inside a block,
    a line beginning with ``IMGID:`` carries the image id and every other line
    is expected to be ``<token> <tag>`` separated by whitespace (lines with
    fewer than two fields are skipped).

    Returns:
        list of ``{'img_id': str, 'tokens': [str], 'labels': [str]}``; blocks
        lacking an image id or lacking tokens are dropped.
    """
    with open(path, 'r', encoding='utf-8') as f:
        raw_lines = [row.rstrip('\n') for row in f]
    raw_lines.append('')  # trailing blank guarantees the final block is emitted

    parsed = []
    cur_id, cur_tokens, cur_tags = None, [], []
    for row in raw_lines:
        if row.strip() == '':
            # End of a block: keep it only when it has both an id and tokens.
            if cur_id is not None and cur_tokens:
                parsed.append({'img_id': cur_id, 'tokens': cur_tokens, 'labels': cur_tags})
            cur_id, cur_tokens, cur_tags = None, [], []
        elif row.startswith('IMGID:'):
            cur_id = row.split(':', 1)[1].strip()
        else:
            # "<token> <tag>" line; anything after the second field is ignored
            fields = row.split()
            if len(fields) >= 2:
                cur_tokens.append(fields[0])
                cur_tags.append(fields[1])
    return parsed

def build_label_vocab(*lists_of_samples):
    """Build label<->id mappings from every ``'labels'`` list in the given sample lists.

    Labels are sorted so the numbering is independent of the order in which
    samples are seen.

    Returns:
        ``(label2id, id2label)`` — two dicts that are inverses of each other.
    """
    seen = {
        tag
        for group in lists_of_samples
        for sample in group
        for tag in sample['labels']
    }
    ordered = sorted(seen)  # stable order
    label2id = dict(zip(ordered, range(len(ordered))))
    id2label = dict(enumerate(ordered))
    return label2id, id2label

def find_image_path(img_dir: str, img_id: str):
    """Return the first existing ``<img_dir>/<img_id><ext>`` path, or ``None``.

    Extensions are tried in the order .jpg, .jpeg, .png, .bmp.
    """
    candidates = (
        os.path.join(img_dir, img_id + suffix)
        for suffix in ('.jpg', '.jpeg', '.png', '.bmp')
    )
    return next((c for c in candidates if os.path.exists(c)), None)

In [None]:
# ================== Setup & Imports ==================
# !pip -q install transformers==4.42.4 accelerate datasets==2.21.0

import os, random, math, json, numpy as np
from collections import Counter
from typing import List, Tuple

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from PIL import Image
from torchvision import models, transforms

from transformers import RobertaTokenizerFast, RobertaModel, get_cosine_schedule_with_warmup
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

# ------------------- Repro -------------------
# Seed every RNG in use (Python, NumPy, PyTorch CPU + all CUDA devices) with one value.
SEED = 1337
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
# Ask cuDNN for deterministic kernels and switch off its auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("Device:", device)

# ================== Paths (your config) ==================
# CoNLL-style text file for each split.
TXT_PATHS = dict(
    train='/kaggle/working/UMT/data/twitter2015/train.txt',
    val='/kaggle/working/UMT/data/twitter2015/valid.txt',
    test='/kaggle/working/UMT/data/twitter2015/test.txt',
)
# All three splits read their images from the same directory.
IMG_DIRS = dict.fromkeys(
    ('train', 'val', 'test'),
    '/kaggle/input/twitter2015/twitter2015/twitter2015_images',
)

# ================== Utils ==================
def parse_twitter_conll(path: str):
    """Parse a Twitter2015-style file into ``{'img_id', 'tokens', 'labels'}`` dicts.

    Blocks are separated by blank lines. A line starting with ``IMGID:`` sets
    the block's image id; any other line contributes its first two
    whitespace-separated fields as (token, tag). A block is kept only when it
    ends up with both an image id and at least one token.
    """
    records = []
    block_id, block_tokens, block_tags = None, [], []

    with open(path, 'r', encoding='utf-8') as f:
        for raw in f:
            text = raw.rstrip('\n')

            if text.strip():
                if text.startswith('IMGID:'):
                    block_id = text.split(':', 1)[1].strip()
                    continue
                # token and BIO tag separated by whitespace
                cols = text.split()
                if len(cols) >= 2:
                    block_tokens.append(cols[0])
                    block_tags.append(cols[1])
                continue

            # Blank line: close the current block and start a fresh one.
            if block_id is not None and block_tokens:
                records.append({'img_id': block_id, 'tokens': block_tokens, 'labels': block_tags})
            block_id, block_tokens, block_tags = None, [], []

    # The file may not end with a blank line; emit whatever block is still open.
    if block_id is not None and block_tokens:
        records.append({'img_id': block_id, 'tokens': block_tokens, 'labels': block_tags})
    return records

def build_label_vocab(*lists_of_samples):
    """Collect every tag found under ``'labels'`` and number them in sorted order.

    Returns:
        ``(label2id, id2label)``: tag -> index and index -> tag dictionaries.
    """
    unique_tags = set()
    for group in lists_of_samples:
        unique_tags = unique_tags.union(*(sample['labels'] for sample in group))

    label2id, id2label = {}, {}
    # sorted() makes the id assignment stable across runs
    for index, tag in enumerate(sorted(unique_tags)):
        label2id[tag] = index
        id2label[index] = tag
    return label2id, id2label

def find_image_path(img_dir: str, img_id: str):
    """Locate the image file for ``img_id`` inside ``img_dir``.

    Tries the extensions .jpg, .jpeg, .png and .bmp in that order and returns
    the first path that exists on disk, or ``None`` when none does.
    """
    extensions = ('.jpg', '.jpeg', '.png', '.bmp')
    found = None
    pos = 0
    while found is None and pos < len(extensions):
        candidate = os.path.join(img_dir, img_id + extensions[pos])
        if os.path.exists(candidate):
            found = candidate
        pos += 1
    return found

# ================== Dataset ==================
class Twitter2015MNER(Dataset):
    """Dataset pairing each parsed tweet with its image for token classification.

    Each item is a dict with:
      - ``'input_ids'`` / ``'attention_mask'``: tokenizer output, padded/truncated
        to ``max_len``
      - ``'pixel_values'``: the image after ``self.img_tfm``
      - ``'labels'``: one id per sub-token; ``-100`` on positions that belong to no
        word (special / padding tokens) and on every non-first sub-token of a word

    Args:
        samples: list of dicts with keys ``'img_id'``, ``'tokens'``, ``'labels'``.
        img_dir: directory searched for ``<img_id>.<ext>`` image files.
        tokenizer: fast tokenizer (``word_ids()`` is required for label alignment).
        label2id: mapping from tag string to integer id.
        max_len: padded/truncated sequence length.
        aug: when True, use the augmenting image pipeline (random crop, colour
            jitter, horizontal flip); otherwise a plain resize.
    """

    def __init__(self, samples, img_dir, tokenizer: RobertaTokenizerFast, label2id, max_len=128, aug=False):
        self.samples = samples
        self.img_dir = img_dir
        self.tokenizer = tokenizer
        self.label2id = label2id
        self.max_len = max_len

        if aug:
            self.img_tfm = transforms.Compose([
                transforms.Resize((256,256)),
                transforms.RandomResizedCrop((224,224), scale=(0.9,1.0)),
                transforms.ColorJitter(0.15,0.15,0.15,0.05),
                transforms.RandomHorizontalFlip(0.5),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
            ])
        else:
            self.img_tfm = transforms.Compose([
                transforms.Resize((224,224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
            ])

    def __len__(self): return len(self.samples)

    def _load_image(self, img_id):
        """Return the RGB image for ``img_id``, or a black 224x224 placeholder.

        The placeholder is used when no file is found and also when the file
        exists but cannot be opened/decoded, so a single bad image does not
        abort a whole epoch.
        """
        img_path = find_image_path(self.img_dir, img_id)
        if img_path is not None:
            try:
                # Context manager closes the underlying file as soon as the
                # pixels have been decoded by convert().
                with Image.open(img_path) as raw:
                    return raw.convert('RGB')
            except (OSError, ValueError) as err:
                # PIL signals unidentified / truncated files with OSError subclasses.
                import warnings
                warnings.warn(f"Could not read image {img_path!r} ({err}); using a blank image instead.")
        # fallback: blank image if missing or unreadable
        return Image.new('RGB', (224,224), color=(0,0,0))

    def __getitem__(self, idx):
        ex = self.samples[idx]
        tokens: List[str] = ex['tokens']
        labels: List[str] = ex['labels']

        # Tokenize the pre-split words; the result carries a batch dimension of 1.
        encodings = self.tokenizer(
            tokens,
            is_split_into_words=True,
            padding='max_length',
            truncation=True,
            max_length=self.max_len,
            return_tensors='pt'
        )

        # word_ids() lives on the encoding object, so read it before the
        # tensors are extracted into a plain dict.
        word_ids = encodings.word_ids(batch_index=0)

        # Drop the batch dimension so the DataLoader can stack items itself.
        enc = {k: v.squeeze(0) for k, v in encodings.items()}

        # --- Align labels with subwords ---
        # Only the first sub-token of each word receives the word's label; all
        # other positions get -100 (the ignore_index used by the loss).
        label_ids = []
        prev_word = None
        for w_id in word_ids:
            if w_id is None or w_id == prev_word:
                label_ids.append(-100)
            else:
                label_ids.append(self.label2id[labels[w_id]])
            prev_word = w_id
        label_ids = torch.tensor(label_ids, dtype=torch.long)

        # Image
        img = self.img_tfm(self._load_image(ex['img_id']))

        return {
            'input_ids': enc['input_ids'],
            'attention_mask': enc['attention_mask'],
            'pixel_values': img,
            'labels': label_ids
        }

In [None]:
class RobertaResNet50MNER(nn.Module):
    """
    Token classification with image-conditioned modulation (FiLM-like):
      - Text encoder: RoBERTa (``text_model``; hidden=1024 for roberta-large)
      - Image encoder: resnet50 with its fc layer removed -> Linear(2048 -> hidden)
      - gamma = Wg(img), beta = Wb(img)
      - h' = (1 + gamma) * h + beta  (applied to every token)
      - [h' ; img] -> Linear -> self-attention over tokens -> dropout
      - Token classifier to BIO tag space

    Args:
        num_labels: size of the tag vocabulary (classifier output width).
        text_model: name/path passed to ``RobertaModel.from_pretrained``.

    Note:
        ``ff_text`` and ``ff_image`` are created here but are not used by
        ``forward``. They are kept so that existing checkpoints (whose
        state_dict contains their weights) still load.
    """
    def __init__(self, num_labels, text_model='roberta-large'):
        super().__init__()

        # Text Model (RoBERTa)
        self.text = RobertaModel.from_pretrained(text_model)
        hidden = self.text.config.hidden_size  # 1024 for roberta-large

        # Visual Model (ResNet50); fc replaced by Identity so it returns pooled features
        resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        resnet.fc = nn.Identity()
        self.visual = resnet

        # Image Projection to match RoBERTa's hidden size
        self.img_proj = nn.Linear(2048, hidden)

        # FiLM modulation: per-channel scale (gamma) and shift (beta) predicted from the image
        self.gamma = nn.Linear(hidden, hidden)
        self.beta  = nn.Linear(hidden, hidden)

        # Projects the concatenated [text ; image] features (2 * hidden wide) back to
        # hidden. Using 2 * hidden instead of a literal 2048 keeps this correct for
        # encoders whose hidden size is not 1024; for roberta-large it is identical.
        self.bimodal_projection = nn.Linear(2 * hidden, hidden)

        # Feedforward blocks (currently unused in forward; see class docstring)
        self.ff_text = nn.Sequential(
            nn.Linear(hidden, hidden * 2),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden * 2, hidden)
        )

        self.ff_image = nn.Sequential(
            nn.Linear(hidden, hidden * 2),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden * 2, hidden)
        )

        # Self-attention over the fused token features
        self.attn = nn.MultiheadAttention(embed_dim=hidden, num_heads=8, batch_first=True)

        # Dropout for regularization
        self.dropout = nn.Dropout(0.1)

        # Classifier to predict the labels
        self.classifier = nn.Linear(hidden, num_labels)

    def forward(self, input_ids, attention_mask, pixel_values, labels=None):
        """Compute per-token logits and, when ``labels`` is given, the loss.

        Args:
            input_ids: token ids, [B, L].
            attention_mask: 1 for real tokens, 0 for padding, [B, L] (may be None).
            pixel_values: image batch fed to the ResNet.
            labels: optional target ids, [B, L]; positions equal to -100 are ignored.

        Returns:
            ``(logits, loss)`` where logits is [B, L, num_labels] and loss is
            ``None`` when ``labels`` is None.
        """
        # Text
        out = self.text(input_ids=input_ids, attention_mask=attention_mask)
        seq = out.last_hidden_state  # [B, L, H]

        # Image
        img_feat = self.visual(pixel_values)       # [B, 2048]
        img_feat = self.img_proj(img_feat)         # [B, H]

        # FiLM modulation per token
        g = self.gamma(img_feat).unsqueeze(1)      # [B, 1, H]
        b = self.beta(img_feat).unsqueeze(1)       # [B, 1, H]
        seq = (1 + g) * seq + b                    # [B, L, H]

        # Bimodal fusion: append the image vector to every token along the feature dim
        bimodal_feats = torch.cat((seq, img_feat.unsqueeze(1).repeat(1, seq.size(1), 1)), dim=-1)  # [B, L, 2H]

        # Project the concatenated features back to H
        bimodal_feats = self.bimodal_projection(bimodal_feats)  # [B, L, H]

        # Self-attention over the fused features. Padding positions must be masked
        # out as keys, otherwise real tokens attend to padding vectors.
        # key_padding_mask uses True for "ignore this key".
        pad_mask = None if attention_mask is None else (attention_mask == 0)
        attn_output, _ = self.attn(
            bimodal_feats, bimodal_feats, bimodal_feats,
            key_padding_mask=pad_mask,
        )

        # Dropout regularization
        attn_output = self.dropout(attn_output)

        # Classifier to predict the labels (BIO tags)
        logits = self.classifier(attn_output)  # [B, L, C]

        loss = None
        if labels is not None:
            # -100 marks sub-word continuation / special / padding positions
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
        return logits, loss


In [None]:
# ================== Build Data ==================
# add_prefix_space=True is needed by the RoBERTa tokenizer for pre-split words.
tokenizer = RobertaTokenizerFast.from_pretrained('roberta-large', add_prefix_space=True)

# Parse the three text splits (train / val / test, in that order).
train_samples, val_samples, test_samples = (
    parse_twitter_conll(TXT_PATHS[split]) for split in ('train', 'val', 'test')
)

# One shared tag vocabulary built from all splits.
label2id, id2label = build_label_vocab(train_samples, val_samples, test_samples)
print("Labels:", label2id)

# Only the training split gets image augmentation.
train_ds = Twitter2015MNER(
    samples=train_samples, img_dir=IMG_DIRS['train'],
    tokenizer=tokenizer, label2id=label2id, max_len=128, aug=True,
)
val_ds = Twitter2015MNER(
    samples=val_samples, img_dir=IMG_DIRS['val'],
    tokenizer=tokenizer, label2id=label2id, max_len=128, aug=False,
)
test_ds = Twitter2015MNER(
    samples=test_samples, img_dir=IMG_DIRS['test'],
    tokenizer=tokenizer, label2id=label2id, max_len=128, aug=False,
)

# Training batches are shuffled; validation / test keep file order.
train_loader = DataLoader(dataset=train_ds, batch_size=8, shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(dataset=val_ds, batch_size=12, shuffle=False, num_workers=2, pin_memory=True)
test_loader = DataLoader(dataset=test_ds, batch_size=12, shuffle=False, num_workers=2, pin_memory=True)

# ================== Train ==================
model = RobertaResNet50MNER(num_labels=len(label2id))
model = model.to(device)

EPOCHS = 5
optim = torch.optim.AdamW(params=model.parameters(), lr=3e-5, weight_decay=0.01)

# Cosine decay with a linear warm-up over the first 10% of all optimizer steps.
total_steps = EPOCHS * len(train_loader)
warmup = int(0.1 * total_steps)
sched = get_cosine_schedule_with_warmup(
    optim,
    num_warmup_steps=warmup,
    num_training_steps=total_steps,
)

# Loss scaling for mixed precision; constructed disabled when not running on CUDA.
scaler = torch.cuda.amp.GradScaler(enabled=device.type == 'cuda')

def run_epoch(loader, train=True):
    """Run one full pass over ``loader`` and return loss and token-level metrics.

    Uses the module-level ``model``, ``optim``, ``sched``, ``scaler`` and ``device``.

    Args:
        loader: DataLoader yielding dicts with 'input_ids', 'attention_mask',
            'pixel_values' and 'labels'.
        train: when True the model is updated (backward, gradient clipping,
            optimizer and scheduler step per batch); when False the model is
            put in eval mode and no gradients are computed.

    Returns:
        ``(avg_loss, accuracy, precision, recall, f1)``. Metrics are computed over
        positions whose gold label is not -100, micro-averaged over all classes
        (for single-label data with every class included, micro P/R/F1 coincide
        with accuracy).
    """
    model.train(train)

    # Mixed precision is active only when a scaler exists and is enabled (CUDA).
    use_amp = scaler is not None and scaler.is_enabled()

    all_preds, all_labels = [], []
    losses = []

    it = tqdm(loader, leave=False, desc="Train" if train else "Eval")
    for batch in it:
        input_ids = batch['input_ids'].to(device)
        attn      = batch['attention_mask'].to(device)
        pixels    = batch['pixel_values'].to(device)
        labels    = batch['labels'].to(device)

        # Skip building the autograd graph during evaluation (saves memory/time).
        with torch.set_grad_enabled(train):
            with torch.cuda.amp.autocast(enabled=use_amp):
                logits, loss = model(input_ids, attn, pixels, labels)

        if train:
            optim.zero_grad(set_to_none=True)
            if scaler is not None:
                scaler.scale(loss).backward()
                # Gradients are still multiplied by the loss scale here; unscale
                # them first so the clipping threshold applies to the true norm.
                scaler.unscale_(optim)
                nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optim)
                scaler.update()
            else:
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optim.step()
            sched.step()

        losses.append(loss.item())

        # collect predictions & labels for metrics (ignore -100: subwords / pads)
        preds = torch.argmax(logits, dim=-1)
        keep = labels != -100
        all_preds.extend(preds[keep].tolist())
        all_labels.extend(labels[keep].tolist())

    avg_loss = float(np.mean(losses))
    acc = accuracy_score(all_labels, all_preds)
    prec, rec, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='micro', zero_division=0)
    return avg_loss, acc, prec, rec, f1

# Track the best validation F1 and stop after `patience` epochs without improvement.
best_val_f1, best_state = -1.0, None
patience, bad = 2, 0

for epoch in range(1, EPOCHS+1):
    tr_loss, tr_acc, tr_p, tr_r, tr_f1 = run_epoch(train_loader, train=True)
    vl_loss, vl_acc, vl_p, vl_r, vl_f1 = run_epoch(val_loader,   train=False)

    print(f"Epoch {epoch:02d} | "
          f"Train loss {tr_loss:.4f} acc {tr_acc:.3f} P {tr_p:.3f} R {tr_r:.3f} F1 {tr_f1:.3f} || "
          f"Val loss {vl_loss:.4f} acc {vl_acc:.3f} P {vl_p:.3f} R {vl_r:.3f} F1 {vl_f1:.3f}")

    if vl_f1 > best_val_f1:
        best_val_f1 = vl_f1; bad = 0
        best_state = {
            # state_dict() hands back references to the live parameters, which keep
            # changing in later epochs. Copy them to CPU so this really is a frozen
            # snapshot of the best epoch (and does not occupy extra GPU memory).
            'model': {k: v.detach().to('cpu', copy=True) for k, v in model.state_dict().items()},
            'label2id': label2id,
            'id2label': id2label,
            'config': {'text':'roberta-large','img':'resnet50','fusion':'FiLM'}
        }
        torch.save(best_state, 'roberta_resnet50_mner_best.pth')
    else:
        bad += 1
        if bad >= patience:
            print("Early stopping.")
            break

# ================== Evaluate on Test ==================
# Restore the best-epoch weights; fall back to the checkpoint on disk when no
# epoch was run in this session.
if best_state is None:
    best_state = torch.load('roberta_resnet50_mner_best.pth', map_location=device)
model.load_state_dict(best_state['model'])

def evaluate(loader):
    """Score the module-level ``model`` on ``loader`` without computing gradients.

    Returns:
        ``(avg_loss, accuracy, precision, recall, f1)`` computed over positions
        whose gold label is not -100; precision/recall/F1 are micro-averaged.
    """
    model.eval()
    pred_ids, gold_ids, batch_losses = [], [], []

    with torch.no_grad():
        for batch in tqdm(loader, leave=False, desc="Test"):
            # Move the four model inputs to the target device.
            on_device = {
                key: batch[key].to(device)
                for key in ('input_ids', 'attention_mask', 'pixel_values', 'labels')
            }
            logits, loss = model(
                on_device['input_ids'],
                on_device['attention_mask'],
                on_device['pixel_values'],
                on_device['labels'],
            )
            batch_losses.append(loss.item())

            pred_arr = torch.argmax(logits, dim=-1).detach().cpu().numpy()
            gold_arr = on_device['labels'].detach().cpu().numpy()

            # Boolean mask drops ignored positions; selection keeps row-major order.
            scored = gold_arr != -100
            pred_ids.extend(pred_arr[scored])
            gold_ids.extend(gold_arr[scored])

    avg_loss = float(np.mean(batch_losses))
    acc = accuracy_score(gold_ids, pred_ids)
    prec, rec, f1, _ = precision_recall_fscore_support(
        gold_ids, pred_ids, average='micro', zero_division=0
    )
    return avg_loss, acc, prec, rec, f1

# Evaluate on the held-out test split and print each metric with four decimals.
te_loss, te_acc, te_p, te_r, te_f1 = evaluate(test_loader)
print("\n===== TEST RESULTS (Token-level, ignore subwords) =====")
print("\n".join(
    f"{caption}{value:.4f}"
    for caption, value in (
        ("Loss: ", te_loss),
        ("Accuracy:  ", te_acc),
        ("Precision: ", te_p),
        ("Recall:    ", te_r),
        ("F1:        ", te_f1),
    )
))

# (Optional) per-label report
def per_label_report(loader):
    """Print a per-label precision/recall/F1 table for ``loader``.

    Runs the module-level ``model`` in eval mode without gradients, keeps only
    positions whose gold label is not -100, and prints scikit-learn's
    ``classification_report`` with the tag names from ``id2label``.
    """
    from sklearn.metrics import classification_report

    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for batch in loader:
            input_ids = batch['input_ids'].to(device)
            attn      = batch['attention_mask'].to(device)
            pixels    = batch['pixel_values'].to(device)
            labels    = batch['labels'].to(device)
            logits, _ = model(input_ids, attn, pixels, labels)

            # Keep only scored positions (gold != -100), preserving order.
            preds = torch.argmax(logits, dim=-1)
            keep = labels != -100
            all_preds.extend(preds[keep].tolist())
            all_labels.extend(labels[keep].tolist())

    # Pass the label ids explicitly: without `labels=`, classification_report
    # infers the classes from the data and raises if a tag never occurs in the
    # golds or predictions (the count would not match target_names).
    label_ids = sorted(id2label)
    target_names = [id2label[i] for i in label_ids]
    print("\nPer-label report:")
    print(classification_report(
        all_labels, all_preds,
        labels=label_ids, target_names=target_names, zero_division=0,
    ))

# Per-label metrics are printed by calling per_label_report(test_loader);
# the following cell runs it.


In [None]:
# Print per-label precision / recall / F1 on the test split.
per_label_report(test_loader)