In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from transformers import BertTokenizer, BertModel
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_curve, auc
from tqdm import tqdm
import matplotlib.pyplot as plt
import ast
import gdown

In [None]:
!git clone https://github.com/thecharm/MNRE.git
!pip install -q gdown
file_id = "1FYiJFtRayWY32nRH0rdycYzIdDcMmDFR"
gdown.download(f"https://drive.google.com/uc?id={file_id}", quiet=False)
!unzip mnre_img.zip

In [None]:
# Per-split MNRE annotation files and image folders in the Kaggle working dir.
# NOTE(review): main() further down defines its own default paths rather than
# reading these globals.
txt_paths = dict(
    train='/kaggle/working/MNRE/mnre_txt/mnre_train.txt',
    val='/kaggle/working/MNRE/mnre_txt/mnre_val.txt',
    test='/kaggle/working/MNRE/mnre_txt/mnre_test.txt',
)
img_dirs = dict(
    train='/kaggle/working/img_org/train',
    val='/kaggle/working/img_org/val',
    test='/kaggle/working/img_org/test',
)

In [None]:
import os
import ast
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as transforms
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

from transformers import BlipProcessor, BlipModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ========================
# 1. Load BLIP (pretrained vision-language model)
# ========================
MODEL_NAME = "Salesforce/blip-image-captioning-base"
processor = BlipProcessor.from_pretrained(MODEL_NAME)  # tokenizer + image preprocessing
blip_model = BlipModel.from_pretrained(MODEL_NAME).to(device)

# Freeze the whole backbone so that only the classifier head defined below
# receives gradient updates.
# NOTE(review): this is a captioning checkpoint loaded into BlipModel; check the
# transformers load log for newly initialised weights (e.g. projection layers),
# since frozen random projections would then feed the head.
blip_model.requires_grad_(False)

# Width of BLIP's text encoder; kept for reference (the head reads projection_dim instead).
HIDDEN_SIZE = blip_model.config.text_config.hidden_size

# ========================
# 2. Dataset
# ========================
class MNREDataset(Dataset):
    """MNRE split as a torch Dataset of BLIP-ready (text, image, label) samples.

    Each non-blank line of ``txt_file`` is a Python-literal dict with at least
    ``token`` (list of words), ``h`` / ``t`` (dicts whose ``pos`` is a
    ``[start, end)`` token span), ``img_id`` and ``relation``. The head entity
    is wrapped in ``[E1] ... [/E1]`` and the tail in ``[E2] ... [/E2]``.

    Args:
        txt_file: annotation file for one split.
        img_dir: directory holding the images named by ``img_id``.
        relation2id: relation string -> integer label.
        transform: stored for API compatibility but NOT applied; all image
            preprocessing is done by the BLIP processor.
        max_len: length the text is padded/truncated to.
        processor: BLIP processor to use; defaults to the module-level ``processor``.

    Raises:
        KeyError: if a line's relation is missing from ``relation2id``.
    """

    def __init__(self, txt_file, img_dir, relation2id, transform=None, max_len=128, processor=None):
        self.samples = []
        self.img_dir = img_dir
        self.transform = transform
        self.max_len = max_len
        self.processor = processor

        with open(txt_file, "r", encoding="utf-8") as f:
            for line_no, line in enumerate(f, start=1):
                line = line.strip()
                if not line:  # tolerate blank lines instead of crashing in literal_eval
                    continue
                obj = ast.literal_eval(line)  # parses literals only; never executes code
                relation = obj['relation']
                if relation not in relation2id:
                    raise KeyError(f"{txt_file}:{line_no}: relation {relation!r} is missing from relation2id")
                text = self._mark_entities(obj['token'], obj['h']['pos'], obj['t']['pos'])
                self.samples.append({"text": text, "img_id": obj['img_id'], "label": relation2id[relation]})

    @staticmethod
    def _mark_entities(tokens, h_pos, t_pos):
        """Return ``tokens`` space-joined with [E1]/[E2] markers around the spans (input not mutated)."""
        h_start, h_end = h_pos
        t_start, t_end = t_pos
        marked = list(tokens)
        # Insert right-to-left (larger start first) so the indices of the
        # earlier span stay valid; assumes the two spans do not overlap.
        spans = sorted([('h', h_start, h_end), ('t', t_start, t_end)], key=lambda x: x[1], reverse=True)
        for etype, start, end in spans:
            open_tag, close_tag = ('[E1]', '[/E1]') if etype == 'h' else ('[E2]', '[/E2]')
            marked.insert(end, close_tag)
            marked.insert(start, open_tag)
        # NOTE(review): the markers are not registered as special tokens, so the
        # BLIP tokenizer presumably splits them into sub-pieces — confirm.
        return " ".join(marked)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        image_path = os.path.join(self.img_dir, sample['img_id'])
        # `with` releases the file handle as soon as the RGB copy exists.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        proc = self.processor if self.processor is not None else processor
        inputs = proc(images=image, text=sample['text'], padding="max_length", truncation=True,
                      max_length=self.max_len, return_tensors="pt")

        # The processor returns a batch of one; drop that dim so DataLoader can re-batch.
        return {
            "input_ids": inputs["input_ids"].squeeze(0),
            "attention_mask": inputs["attention_mask"].squeeze(0),
            "pixel_values": inputs["pixel_values"].squeeze(0),
            "label": torch.tensor(sample["label"], dtype=torch.long)
        }

# ========================
# 3. Model (Relation Classifier)
# ========================
class MultimodalREModel(nn.Module):
    """Relation classifier: a small MLP on top of a BLIP projected embedding.

    Args:
        num_relations: number of output classes.
        blip: BlipModel-like backbone exposing ``config.projection_dim`` and a
            forward that returns ``text_embeds`` / ``image_embeds``; defaults
            to the module-level ``blip_model``.
    """

    def __init__(self, num_relations, blip=None):
        super().__init__()
        self.blip = blip if blip is not None else blip_model

        # text_embeds and image_embeds are both projected to projection_dim,
        # so the head's input width is read from the config, not hard-coded.
        hidden_size = self.blip.config.projection_dim

        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_relations)
        )

    def forward(self, input_ids, attention_mask, pixel_values, use_text_embeds=True):
        """Return relation logits of shape (batch, num_relations).

        use_text_embeds=True (default): classify from BLIP's projected text
            embedding. This is the branch that sees the sentence and its
            [E1]/[E2] markers.
        use_text_embeds=False: classify from the projected image embedding.
            BlipModel computes this from pixel_values alone, so it cannot
            distinguish different entity pairs attached to the same image.
        """
        outputs = self.blip(input_ids=input_ids,
                            attention_mask=attention_mask,
                            pixel_values=pixel_values)
        pooled = outputs.text_embeds if use_text_embeds else outputs.image_embeds
        return self.classifier(pooled)


# ========================
# 4. Training & Evaluation
# ========================
def train_epoch(model, loader, criterion, optimizer, device=None):
    """Run one training pass over ``loader`` and return epoch metrics.

    Args:
        model: called as ``model(input_ids, attention_mask, pixel_values)``;
            its output is argmax-ed over dim 1 as class logits.
        loader: yields dicts with "input_ids", "attention_mask",
            "pixel_values" and "label" tensors (as produced by MNREDataset).
        criterion: loss called as ``criterion(logits, labels)``.
        optimizer: steps the model's trainable parameters.
        device: where each batch is moved. Defaults to the device of the
            model's first parameter, so batches always land next to the model
            instead of relying on a notebook-global ``device``.

    Returns:
        (mean per-batch loss, accuracy, macro-averaged F1) over the epoch.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    if device is None:
        device = next(model.parameters()).device
    # NOTE(review): train() also switches a frozen backbone into train mode,
    # so any dropout inside it is active during this pass.
    model.train()
    losses, preds, labels = [], [], []
    for batch in tqdm(loader, desc="Training", leave=False):
        optimizer.zero_grad()
        input_ids = batch["input_ids"].to(device)
        attn = batch["attention_mask"].to(device)
        pixels = batch["pixel_values"].to(device)
        labs = batch["label"].to(device)

        logits = model(input_ids, attn, pixels)
        loss = criterion(logits, labs)
        loss.backward()
        optimizer.step()

        losses.append(loss.item())
        # Metrics only: detach so nothing here touches the autograd graph.
        preds.extend(torch.argmax(logits.detach(), 1).cpu().numpy())
        labels.extend(labs.cpu().numpy())

    if not losses:
        raise ValueError("train_epoch: loader produced no batches")
    return sum(losses) / len(losses), accuracy_score(labels, preds), f1_score(labels, preds, average="macro")

def eval_epoch(model, loader, criterion, device=None):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: called as ``model(input_ids, attention_mask, pixel_values)``;
            its output is argmax-ed over dim 1 as class logits.
        loader: yields dicts with "input_ids", "attention_mask",
            "pixel_values" and "label" tensors.
        criterion: loss called as ``criterion(logits, labels)``.
        device: where each batch is moved. Defaults to the device of the
            model's first parameter.

    Returns:
        (mean per-batch loss, accuracy, macro-averaged F1).

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    if device is None:
        device = next(model.parameters()).device
    model.eval()  # disables dropout in the head (and backbone)
    losses, preds, labels = [], [], []
    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating", leave=False):
            input_ids = batch["input_ids"].to(device)
            attn = batch["attention_mask"].to(device)
            pixels = batch["pixel_values"].to(device)
            labs = batch["label"].to(device)

            logits = model(input_ids, attn, pixels)
            loss = criterion(logits, labs)

            losses.append(loss.item())
            preds.extend(torch.argmax(logits, 1).cpu().numpy())
            labels.extend(labs.cpu().numpy())

    if not losses:
        raise ValueError("eval_epoch: loader produced no batches")
    return sum(losses) / len(losses), accuracy_score(labels, preds), f1_score(labels, preds, average="macro")

# ========================
# 5. Main
# ========================
def main(txt_paths=None, img_dirs=None, epochs=5, batch_size=8, lr=1e-5,
         ckpt_path="mnre_blip_re.pth", seed=42):
    """Train the BLIP relation classifier on MNRE, report test metrics and save a checkpoint.

    Args:
        txt_paths: split ('train'/'val'/'test') -> annotation file. Defaults to
            the Kaggle layout; the val file is resolved with or without the
            stray space before ".txt".
        img_dirs: split -> image directory (defaults to /kaggle/working/img_org/<split>).
        epochs, batch_size, lr: training hyper-parameters (defaults match the
            previously hard-coded values).
        ckpt_path: where ``{"model_state", "relation2id"}`` is written after
            training. This is the format prepare_re_model in the inference cell
            reads. Pass None to skip saving.
        seed: torch RNG seed for reproducible head init, dropout and shuffling;
            None leaves the RNG untouched.

    Returns:
        (model, relation2id)
    """
    if seed is not None:
        torch.manual_seed(seed)

    if txt_paths is None:
        txt_dir = '/kaggle/working/MNRE/mnre_txt'
        # The notebook refers to the val file both as "mnre_val.txt" and
        # "mnre_val .txt"; use whichever one is actually on disk.
        val_path = f'{txt_dir}/mnre_val.txt'
        spaced_val_path = f'{txt_dir}/mnre_val .txt'
        if not os.path.exists(val_path) and os.path.exists(spaced_val_path):
            val_path = spaced_val_path
        txt_paths = {
            'train': f'{txt_dir}/mnre_train.txt',
            'val':   val_path,
            'test':  f'{txt_dir}/mnre_test.txt',
        }
    if img_dirs is None:
        img_dirs = {split: f'/kaggle/working/img_org/{split}' for split in ('train', 'val', 'test')}

    # Label ids come from the training split only, sorted for a stable order.
    rels = set()
    with open(txt_paths["train"], "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:  # skip blank lines
                rels.add(ast.literal_eval(line)["relation"])
    relation2id = {r: idx for idx, r in enumerate(sorted(rels))}
    num_relations = len(relation2id)

    # Datasets & Loaders
    splits = ("train", "val", "test")
    datasets = {split: MNREDataset(txt_paths[split], img_dirs[split], relation2id) for split in splits}
    loaders = {
        split: DataLoader(datasets[split], batch_size=batch_size, shuffle=(split == "train"))
        for split in splits
    }

    # Model, Loss, Optimizer
    model = MultimodalREModel(num_relations).to(device)
    criterion = nn.CrossEntropyLoss()
    # BLIP's weights were frozen when it was loaded; hand the optimizer only
    # the parameters that can actually change (the classifier head).
    optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=lr)

    for epoch in range(1, epochs + 1):
        tr_loss, tr_acc, tr_f1 = train_epoch(model, loaders["train"], criterion, optimizer)
        val_loss, val_acc, val_f1 = eval_epoch(model, loaders["val"], criterion)
        print(f"Epoch {epoch}/{epochs} | Train Loss: {tr_loss:.4f} Acc: {tr_acc:.4f} F1: {tr_f1:.4f} | "
              f"Val Loss: {val_loss:.4f} Acc: {val_acc:.4f} F1: {val_f1:.4f}")

    # Test Evaluation
    test_loss, test_acc, test_f1 = eval_epoch(model, loaders["test"], criterion)
    print(f"✅ Test Accuracy: {test_acc:.4f}, F1: {test_f1:.4f}")

    if ckpt_path:
        # Full state_dict (frozen BLIP included), so the keys match any module
        # with `blip` + `classifier` attributes; large on disk.
        torch.save({"model_state": model.state_dict(), "relation2id": relation2id}, ckpt_path)
        print(f"💾 Saved checkpoint: {ckpt_path}")

    return model, relation2id

# Entry point: in a notebook kernel __name__ is "__main__", so executing this
# cell starts training and binds `model` / `relation2id` in the kernel namespace.
if __name__ == "__main__":
    model, relation2id = main()


In [None]:
import os
import ast
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm

from transformers import BlipProcessor, BlipModel
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline

# ----------------------------
# Config
# ----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained BLIP checkpoint used as the multimodal backbone.
BLIP_NAME = "Salesforce/blip-image-captioning-base"

# MNRE annotation files and image folders, keyed by split (edit to match your layout).
TXT_PATHS = dict(
    train='/kaggle/working/MNRE/mnre_txt/mnre_train.txt',
    val='/kaggle/working/MNRE/mnre_txt/mnre_val.txt',  # no space before ".txt" (see safe_fix_val_path)
    test='/kaggle/working/MNRE/mnre_txt/mnre_test.txt',
)
IMG_DIRS = dict(
    train='/kaggle/working/img_org/train',
    val='/kaggle/working/img_org/val',
    test='/kaggle/working/img_org/test',
)

# Relation-classifier checkpoint. prepare_re_model reads its "model_state"
# (and optional "relation2id") entries and falls back to an untrained head
# when the file is missing.
CKPT_PATH = "mnre_blip_re.pth"

# Test-split images to run the demo on.
IMAGES_TO_PREDICT = [
    "/kaggle/working/img_org/test/twitter_19_31_0_13.jpg",
    "/kaggle/working/img_org/test/twitter_19_31_0_8.jpg",
]

# ----------------------------
# Utilities
# ----------------------------
def safe_fix_val_path(bad="/kaggle/working/MNRE/mnre_txt/mnre_val .txt",
                      good="/kaggle/working/MNRE/mnre_txt/mnre_val.txt"):
    """Rename the val annotation file from ``bad`` to ``good`` when only ``bad`` exists.

    Idempotent: does nothing if ``good`` already exists or ``bad`` is absent.
    The paths default to the previously hard-coded Kaggle locations.
    NOTE: any code that still opens the spaced filename will no longer find it
    after the rename.

    Returns:
        True if a rename happened, otherwise False.
    """
    if os.path.exists(bad) and not os.path.exists(good):
        os.rename(bad, good)
        return True
    return False

def load_all_mnre_entries(txt_paths):
    """Load every annotation entry from all splits, for lookup by img_id.

    Args:
        txt_paths: mapping split name -> file holding one Python-literal dict
            per line. Blank lines are skipped.

    Returns:
        List of the parsed dicts in file order, each with an added "_split" key.

    Raises:
        FileNotFoundError: when a listed file does not exist (checked per
            split, just before that split is read).
    """
    entries = []
    for split, p in txt_paths.items():
        if not os.path.exists(p):
            raise FileNotFoundError(f"Missing file for split '{split}': {p}")
        with open(p, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:  # literal_eval('') would raise SyntaxError
                    continue
                obj = ast.literal_eval(line)
                obj["_split"] = split
                entries.append(obj)
    return entries

def build_relation2id_from_train(train_path):
    """Build relation2id by scanning the train file, in the same way the training cell does.

    Relations are sorted before numbering so ids are stable across runs.
    Blank lines are skipped.

    Returns:
        dict mapping relation string -> id in ``0..n-1``.
    """
    with open(train_path, "r", encoding="utf-8") as f:
        rels = {ast.literal_eval(s)["relation"] for s in (line.strip() for line in f) if s}
    return {r: idx for idx, r in enumerate(sorted(rels))}

def insert_markers(tokens, h_pos, t_pos):
    """Return the space-joined tokens with [E1]..[/E1] around the head span and [E2]..[/E2] around the tail span.

    Spans are (start, end) token indices, with ``end`` exclusive. ``tokens``
    itself is left unchanged.
    """
    tag_pairs = {'h': ('[E1]', '[/E1]'), 't': ('[E2]', '[/E2]')}
    marked = tokens.copy()
    # Work from the right-most span leftwards so earlier indices stay valid.
    ordered = sorted(
        [('h', h_pos[0], h_pos[1]), ('t', t_pos[0], t_pos[1])],
        key=lambda span: span[1],
        reverse=True,
    )
    for kind, begin, finish in ordered:
        opener, closer = tag_pairs[kind]
        marked.insert(finish, closer)
        marked.insert(begin, opener)
    return " ".join(marked)

def extract_entities_from_marked_text(marked_text):
    """Return (head_text, tail_text) taken from the first [E1]..[/E1] and [E2]..[/E2] pairs.

    A side is None when its opening or closing tag is missing.
    """
    def _slice_between(text, opener, closer):
        if opener not in text or closer not in text:
            return None
        start = text.index(opener) + len(opener)
        stop = text.index(closer)
        return text[start:stop].strip()

    return (_slice_between(marked_text, "[E1]", "[/E1]"),
            _slice_between(marked_text, "[E2]", "[/E2]"))

# ----------------------------
# BLIP-based RE model definition (same `blip` + `classifier` layout as the training model, so saved state_dict keys match)
# ----------------------------
class BLIPRelationClassifier(nn.Module):
    """
    MLP relation head on top of BLIP's projected embeddings.

    The head's input width comes from ``blip_model.config.projection_dim``
    rather than a constant, so it follows whichever BLIP variant is passed in.
    The attribute names ``blip`` and ``classifier`` determine the state_dict
    keys a checkpoint must use.
    """
    def __init__(self, num_relations, blip_model):
        super().__init__()
        self.blip = blip_model
        in_features = blip_model.config.projection_dim
        layers = [
            nn.Linear(in_features, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_relations),
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, input_ids, attention_mask, pixel_values, use_text_embeds=True):
        """
        Return relation logits for a batch.

        use_text_embeds=True  -> classify BLIP's projected text embedding
        use_text_embeds=False -> classify BLIP's projected image embedding
        Concatenating both would be another possible design.
        """
        blip_out = self.blip(input_ids=input_ids,
                             attention_mask=attention_mask,
                             pixel_values=pixel_values)
        if use_text_embeds:
            features = blip_out.text_embeds
        else:
            features = blip_out.image_embeds
        return self.classifier(features)

# ----------------------------
# Load BLIP + Processor + NER pipeline
# ----------------------------
def load_backbones():
    """
    Load the BLIP processor and model plus a pretrained NER pipeline.

    Returns:
        (processor, blip_model, ner_pipe). BLIP is moved to ``device`` with
        gradients disabled. The NER pipeline runs on GPU 0 when ``device`` is
        CUDA, otherwise on CPU.
    """
    processor = BlipProcessor.from_pretrained(BLIP_NAME)
    blip_model = BlipModel.from_pretrained(BLIP_NAME).to(device)
    blip_model.requires_grad_(False)  # inference only; no gradients needed

    # Token-classification NER model, aggregated to whole-entity spans.
    ner_name = "dslim/bert-base-NER"
    ner_tokenizer = AutoTokenizer.from_pretrained(ner_name)
    ner_model = AutoModelForTokenClassification.from_pretrained(ner_name).to(device)
    pipe_device = 0 if device.type == "cuda" else -1
    ner_pipe = pipeline(
        "ner",
        model=ner_model,
        tokenizer=ner_tokenizer,
        aggregation_strategy="simple",
        device=pipe_device,
    )
    return processor, blip_model, ner_pipe

# ----------------------------
# Prepare model from checkpoint (or fall back)
# ----------------------------
def prepare_re_model(processor, blip_model, relation2id, ckpt_path):
    """
    Build a BLIPRelationClassifier and load trained weights when a checkpoint exists.

    Args:
        processor: unused; kept so existing call sites keep working.
        blip_model: BLIP backbone shared with the classifier.
        relation2id: relation -> id mapping. It sizes the output layer and is
            used for id -> relation decoding when the checkpoint carries none.
        ckpt_path: file written by ``torch.save``. Either a dict with
            "model_state" (and optionally "relation2id"), or a bare state_dict.

    Returns:
        (model, id2relation)
    """
    model = BLIPRelationClassifier(num_relations=len(relation2id), blip_model=blip_model).to(device)
    fallback_id2rel = {v: k for k, v in relation2id.items()}
    if not os.path.exists(ckpt_path):
        print(f"⚠️ RE checkpoint not found at {ckpt_path}. Using fresh classifier weights (accuracy will be poor).")
        return model, fallback_id2rel

    # NOTE: torch.load unpickles the file; only load checkpoints you produced yourself.
    ckpt = torch.load(ckpt_path, map_location=device)
    # Accept {"model_state": ..., "relation2id": ...} as well as a bare state_dict.
    state = ckpt["model_state"] if "model_state" in ckpt else ckpt
    missing = model.load_state_dict(state, strict=False)
    if missing.missing_keys or missing.unexpected_keys:
        print("⚠️ State dict mismatch:", missing)
    id2rel_from_ckpt = {v: k for k, v in ckpt.get("relation2id", {}).items()}
    print(f"✅ Loaded RE checkpoint: {ckpt_path}")
    return model, (id2rel_from_ckpt or fallback_id2rel)

# ----------------------------
# Build an index: img_id -> list of entries
# ----------------------------
def index_entries_by_img_id(entries):
    """Group entries by their "img_id" key, keeping input order inside each group."""
    grouped = {}
    for entry in entries:
        key = entry["img_id"]
        if key not in grouped:
            grouped[key] = []
        grouped[key].append(entry)
    return grouped

# ----------------------------
# Single-sample preprocessing for BLIP
# ----------------------------
def build_blip_inputs(processor, image_path, marked_text, max_len=128):
    """
    Turn one (image, marked sentence) pair into device-resident BLIP inputs.

    Returns:
        (image, inputs): the RGB PIL image (kept for display) and the processor
        output with every entry moved to ``device``.
    """
    image = Image.open(image_path).convert("RGB")
    encoded = processor(
        images=image,
        text=marked_text,
        padding="max_length",
        truncation=True,
        max_length=max_len,
        return_tensors="pt",
    )
    for key in list(encoded.keys()):
        encoded[key] = encoded[key].to(device)
    return image, encoded

# ----------------------------
# Pretty print + visualize
# ----------------------------
def show_result(image, pred_relation, confidence, sentence, h_text, h_type, t_text, t_type):
    """
    Show the image titled with the predicted relation, then print a text summary.

    ``confidence`` is formatted as a percentage (``.1%``) in the title and
    with four decimals in the printed summary.
    """
    title = f"Relation: {pred_relation} ({confidence:.1%})"
    plt.imshow(image)
    plt.axis('off')
    plt.title(title)
    plt.show()

    print("────────────────────────────────────────")
    print("📝 Sentence:")
    print(sentence)
    print("\n👤 Head Entity (E1):", h_text, "| NER type:", h_type)
    print("👥 Tail Entity (E2):", t_text, "| NER type:", t_type)
    print("🔗 Predicted Relation:", pred_relation, f"(conf={confidence:.4f})")
    print("────────────────────────────────────────\n")

# ----------------------------
# NER helper for entity types
# ----------------------------
def ner_type_for_span(ner_pipe, full_text, span_text):
    """
    Return the majority NER label among pipeline entities that match ``span_text``.

    ``ner_pipe(full_text)`` must return dicts with "word" and "entity_group".
    Matching uses two passes:
      1. entities whose "word" contains ``span_text``;
      2. if none, entities whose "word" is contained in ``span_text``.
    Returns 'MISC' when ``span_text`` is empty/None or nothing matches. Ties in
    the vote go to the label encountered first.
    """
    from collections import Counter

    if not span_text:
        # Nothing to match; skip running the NER model at all.
        return "MISC"
    ents = ner_pipe(full_text)
    # Pass 1: the span is (part of) a detected entity. Only entity *text* is
    # compared; the label itself ("PER", "LOC", ...) is never matched.
    candidates = [e for e in ents if span_text in e.get("word", "")]
    if not candidates:
        # Pass 2 (looser): a detected entity lies inside the span.
        candidates = [e for e in ents if e.get("word") and e["word"] in span_text]
    if not candidates:
        return "MISC"
    return Counter(e["entity_group"] for e in candidates).most_common(1)[0][0]

# ----------------------------
# Main: predict on a list of image paths
# ----------------------------
def predict_images(image_paths):
    """
    Predict and display relations for every MNRE sentence attached to each image.

    For each path, the function finds the MNRE entries (across train/val/test)
    whose ``img_id`` equals the file's basename and inserts entity markers. It
    then classifies the relation with the BLIP RE model, types the head and
    tail entities with the NER pipeline, and shows the image with a printed
    summary. Missing files and images without MNRE entries are reported and
    skipped; images without entries are still displayed.

    Side effects: may rename the val annotation file (safe_fix_val_path),
    loads BLIP/NER via from_pretrained on every call, draws figures and prints.

    Args:
        image_paths: iterable of image file paths.
    """
    # Fix any bad filename first
    safe_fix_val_path()

    # Load entries and index by img_id
    entries = load_all_mnre_entries(TXT_PATHS)
    by_img = index_entries_by_img_id(entries)

    # Build relation2id (from train) for consistent id order.
    # NOTE(review): the output layer is sized from this mapping even when the
    # checkpoint carries its own relation2id; the two must have the same size.
    relation2id_train = build_relation2id_from_train(TXT_PATHS["train"])

    # Backbones
    processor, blip_model, ner_pipe = load_backbones()

    # RE model
    re_model, id2relation = prepare_re_model(processor, blip_model, relation2id_train, CKPT_PATH)
    re_model.eval()

    # Go over requested images
    for ipath in image_paths:
        if not os.path.exists(ipath):
            print(f"❌ Image not found: {ipath}")
            continue
        # MNRE entries are keyed by bare filename, not full path.
        img_id = os.path.basename(ipath)

        if img_id not in by_img:
            print(f"⚠️ No MNRE entry found for image id: {img_id}")
            # Still show the image
            image = Image.open(ipath).convert("RGB")
            plt.imshow(image); plt.axis('off'); plt.title("Image (no matching text entry)"); plt.show()
            continue

        # MNRE can have multiple sentences per image; iterate all
        for sample in by_img[img_id]:
            tokens = sample["token"]
            h_pos = sample["h"]["pos"]
            t_pos = sample["t"]["pos"]

            # Insert entity markers so model knows E1/E2
            marked_text = insert_markers(tokens, h_pos, t_pos)
            h_text, t_text = extract_entities_from_marked_text(marked_text)

            # Build BLIP inputs (max_len matches MNREDataset's default of 128)
            image, inputs = build_blip_inputs(processor, ipath, marked_text, max_len=128)

            # RE prediction: inputs hold a single example, hence row 0 below.
            # Unknown ids decode to "rel_<i>".
            with torch.no_grad():
                logits = re_model(inputs["input_ids"], inputs["attention_mask"], inputs["pixel_values"], use_text_embeds=True)
                probs  = torch.softmax(logits, dim=1)
                pred_i = int(torch.argmax(probs, dim=1).item())
                conf   = float(probs[0, pred_i].item())
                pred_relation = id2relation.get(pred_i, f"rel_{pred_i}")

            # NER types for E1/E2 (using pretrained NER)
            # We run NER on the sentence WITHOUT tags for better tagging;
            # note NER runs once per entity, i.e. twice per sentence.
            clean_sentence = marked_text.replace("[E1]", "").replace("[/E1]", "").replace("[E2]", "").replace("[/E2]", "")
            h_type = ner_type_for_span(ner_pipe, clean_sentence, h_text or "")
            t_type = ner_type_for_span(ner_pipe, clean_sentence, t_text or "")

            # Display
            show_result(image, pred_relation, conf, marked_text, h_text, h_type, t_text, t_type)

# ----------------------------
# RUN
# ----------------------------
# Run the demo: loads all models and displays one result per MNRE sentence of each image.
predict_images(IMAGES_TO_PREDICT)
