# Import Library

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from transformers import BertTokenizer, BertModel
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_curve, auc
from tqdm import tqdm
import matplotlib.pyplot as plt
import ast
import gdown

# Import Dataset

In [None]:
import shutil
import os

def clear_kaggle_cache(cache_dirs=None):
    """
    Delete every file, symlink and sub-directory inside each directory in
    ``cache_dirs``; the directories themselves are kept.

    With the default (``["/kaggle/working"]``) this empties Kaggle's whole
    working directory, i.e. anything previously written there, not only
    caches.

    Args:
        cache_dirs: iterable of directory paths to empty. ``None`` means
            ``["/kaggle/working"]``.

    Errors while listing a directory or deleting a single entry are printed
    and skipped so one bad entry does not stop the sweep.
    """
    if cache_dirs is None:
        cache_dirs = ["/kaggle/working"]  # Working directory for models

    for cache_dir in cache_dirs:
        if not os.path.exists(cache_dir):
            print(f"⚠️ Directory does not exist: {cache_dir}")
            continue
        try:
            entries = os.listdir(cache_dir)
        except OSError as e:
            print(f"⚠️ Could not clear {cache_dir}: {e}")
            continue
        for filename in entries:
            file_path = os.path.join(cache_dir, filename)
            try:
                # Symlinks (even to directories) are unlinked, never followed by rmtree.
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.unlink(file_path)
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)
            except OSError as e:
                print(f"⚠️ Could not delete {file_path}: {e}")
        print(f"🧹 Emptied: {cache_dir}")

# Call the function to clean up
clear_kaggle_cache()


In [None]:
import os, ast, math, json, random
from dataclasses import dataclass
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from transformers import AutoTokenizer, AutoModel, get_cosine_schedule_with_warmup

# -------------- Reproducibility --------------
# Seed Python, NumPy and PyTorch, and ask cuDNN for deterministic kernels.
SEED = 1337
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Use CUDA when available, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# -------------- Data locations (edit for a different layout) --------------
TXT_PATHS = dict(
    train='/kaggle/working/MNRE/mnre_txt/mnre_train.txt',
    # NB: this file name contains a space before ".txt".
    val='/kaggle/working/MNRE/mnre_txt/mnre_val .txt',
    test='/kaggle/working/MNRE/mnre_txt/mnre_test.txt',
)
IMG_DIRS = dict(
    train='/kaggle/working/img_org/train',
    val='/kaggle/working/img_org/val',
    test='/kaggle/working/img_org/test',
)


# Data Visualization and Analysis

In [None]:


import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
from PIL import Image

# --------------------
# 1. Load Data and Relation Distribution
# --------------------
# Helper function to load data from MNRE text files
def load_data(txt_path):
    """Parse an MNRE ``.txt`` split into a list of sample dicts.

    Each non-blank line is a Python-literal dict, parsed with
    ``ast.literal_eval`` (not ``eval``). Blank lines (for example an extra
    empty line at the end of the file) are skipped instead of raising
    ``SyntaxError``.
    """
    data = []
    with open(txt_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            data.append(ast.literal_eval(line))
    return data

# Load all three splits (open() raises FileNotFoundError if a path in TXT_PATHS is wrong).
train_data = load_data(TXT_PATHS['train'])
val_data = load_data(TXT_PATHS['val'])
test_data = load_data(TXT_PATHS['test'])

# --------------------
# 2. Relation Distribution
# --------------------
# Count samples per relation label in the training split; bars appear in
# first-seen order (dict(Counter(...)) keeps insertion order).
relations = [obj['relation'] for obj in train_data]
relation_counts = dict(Counter(relations))

# Plot distribution of relations
plt.figure(figsize=(12,6))
sns.barplot(x=list(relation_counts.keys()), y=list(relation_counts.values()))
plt.xticks(rotation=90)
plt.title("Distribution of Relations in the Training Set")
plt.xlabel("Relation Types")
plt.ylabel("Frequency")
plt.show()


# --------------------
# 3. Text Length Distribution
# --------------------
# Sentence length = number of word tokens in each training sample's 'token' list.
text_lengths = [len(obj['token']) for obj in train_data]

plt.figure(figsize=(10,6))
sns.histplot(text_lengths, kde=True, bins=30)
plt.title("Distribution of Text Lengths (Number of Tokens)")
plt.xlabel("Number of Tokens")
plt.ylabel("Frequency")
plt.show()

# --------------------
# 4. Visualize Sample Images
# --------------------
# Images of the first five training samples, read from IMG_DIRS['train'].
# These globals (sample_img_ids, fig, axes, ...) stay visible to later cells.
sample_img_ids = [obj['img_id'] for obj in train_data[:5]]  # First 5 sample images

fig, axes = plt.subplots(1, 5, figsize=(15,5))
for i, img_id in enumerate(sample_img_ids):
    img_path = os.path.join(IMG_DIRS['train'], img_id)
    img = Image.open(img_path).convert('RGB')
    axes[i].imshow(img)
    axes[i].axis('off')
    axes[i].set_title(f"Image {i+1}")
plt.tight_layout()
plt.show()


# RoBERTa-large ResNet-101

In [None]:
!pip install -q transformers==4.42.4 accelerate datasets==2.21.0

import os, ast, math, random
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from transformers import RobertaModel, RobertaTokenizer, get_cosine_schedule_with_warmup

# -------------- Repro --------------
# This cell's imports above do not include NumPy, but np is used here (and in
# compute_class_weights / train_one_epoch / evaluate). Import it so the cell
# does not depend on an earlier cell having defined `np`.
import numpy as np

SEED = 1337
random.seed(SEED); torch.manual_seed(SEED); np.random.seed(SEED)
torch.backends.cudnn.deterministic = True; torch.backends.cudnn.benchmark = False



# -------------- Utilities --------------
def fix_val_path(paths=None):
    """Point ``paths['val']`` at a validation file that actually exists.

    ``TXT_PATHS['val']`` is spelled ``mnre_val .txt`` (space before the
    extension). If that path is missing, look for ``mnre_val .txt`` and
    ``mnre_val.txt`` in the same folder and store whichever exists.

    Nothing is renamed on disk, because other cells hard-code the spaced
    file name and would break.

    Args:
        paths: dict with a ``'val'`` key, updated in place. Defaults to the
            global ``TXT_PATHS``.

    Returns:
        The resolved validation path (unchanged if no spelling exists).
    """
    if paths is None:
        paths = TXT_PATHS
    current = paths['val']
    if os.path.exists(current):
        return current
    folder = os.path.dirname(current)
    for name in ("mnre_val .txt", "mnre_val.txt"):
        candidate = os.path.join(folder, name)
        if os.path.exists(candidate):
            paths['val'] = candidate
            return candidate
    return current

fix_val_path()

def load_relation2id(train_txt):
    """Return ``{relation: id}`` for every relation label in ``train_txt``.

    Ids are assigned in sorted label order, so the mapping is stable for a
    given file.
    """
    with open(train_txt, 'r', encoding='utf-8') as fh:
        labels = {ast.literal_eval(raw.strip())['relation'] for raw in fh}
    return dict(zip(sorted(labels), range(len(labels))))

# -------------- Tokenizer --------------
# Module-level tokenizer read by MNREDataset.__getitem__; roberta-large has a
# 1024-d hidden size, which CrossModalRE relies on.
TEXT_MODEL = "roberta-large"
tokenizer = RobertaTokenizer.from_pretrained(pretrained_model_name_or_path=TEXT_MODEL)

# -------------- Dataset --------------
class MNREDataset(Dataset):
    """MNRE split for the RoBERTa-large model.

    Each line of ``txt_file`` is a Python-literal dict; the keys used are
    'token', 'h'/'t' (each with 'pos' = [start, end]), 'img_id' and
    'relation'. The head span is wrapped in ``[E1] ... [/E1]`` and the tail
    span in ``[E2] ... [/E2]``. The marked sentence is paired with its image
    from ``img_dir``.

    Text is tokenised in ``__getitem__`` with the module-level ``tokenizer``.
    NOTE(review): nothing here registers the marker strings with the
    tokenizer, so they are presumably split into ordinary sub-words. Confirm.
    """
    def __init__(self, txt_file, img_dir, relation2id, max_len=128, train_aug=False):
        self.samples = []
        self.img_dir = img_dir
        self.max_len = max_len
        self.relation2id = relation2id

        # Training adds random crop/flip/colour jitter; evaluation only resizes.
        # Both finish with tensor conversion + ImageNet mean/std normalisation.
        if train_aug:
            pipeline_steps = [
                transforms.Resize((256, 256)),
                transforms.RandomResizedCrop((224, 224), scale=(0.8, 1.0)),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.ColorJitter(0.2, 0.2, 0.2, 0.05),
            ]
        else:
            pipeline_steps = [transforms.Resize((224, 224))]
        pipeline_steps += [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
        self.tfm = transforms.Compose(pipeline_steps)

        with open(txt_file, 'r', encoding='utf-8') as f:
            for line in f:
                record = ast.literal_eval(line.strip())
                marked_text = self._insert_entity_markers(record)
                self.samples.append({
                    'text': marked_text,
                    'img_id': record['img_id'],
                    'label': self.relation2id[record['relation']],
                })

    @staticmethod
    def _insert_entity_markers(record):
        """Join record['token'] with spaces, wrapping the h/t spans in [E1]/[E2] markers."""
        words = record['token'].copy()
        h_start, h_end = record['h']['pos']
        t_start, t_end = record['t']['pos']
        # Handle the right-most span first so inserting markers does not shift
        # the indices of the span still to be processed.
        ordered = sorted(
            [('h', h_start, h_end), ('t', t_start, t_end)],
            key=lambda span: span[1],
            reverse=True,
        )
        for kind, start, end in ordered:
            is_head = kind == 'h'
            words.insert(end, '[/E1]' if is_head else '[/E2]')
            words.insert(start, '[E1]' if is_head else '[E2]')
        return " ".join(words)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, i):
        entry = self.samples[i]
        picture = Image.open(os.path.join(self.img_dir, entry['img_id'])).convert('RGB')
        pixel_values = self.tfm(picture)
        encoded = tokenizer(
            entry['text'],
            padding='max_length',
            truncation=True,
            max_length=self.max_len,
            return_tensors='pt',
        )
        return {
            'input_ids': encoded['input_ids'].squeeze(0),
            'attention_mask': encoded['attention_mask'].squeeze(0),
            'pixel_values': pixel_values,
            'label': torch.tensor(entry['label'], dtype=torch.long),
        }

# -------------- Model --------------
class CrossModalRE(nn.Module):
    """
    RoBERTa-large (text) + ResNet-101 (image) relation classifier.

    - The pooled 2048-d ResNet-101 feature is projected to 1024-d and prepended
      to the RoBERTa token sequence as a single "image token".
    - A 2-layer, 16-head Transformer encoder attends jointly over
      [image token; text tokens]; text padding positions are masked out.
    - A learned sigmoid gate mixes the fused image token with the text CLS
      vector before the MLP classifier.
    """
    def __init__(self, num_classes):
        super().__init__()
        self.text_enc = RobertaModel.from_pretrained(TEXT_MODEL)
        self.text_hidden = self.text_enc.config.hidden_size  # 1024

        # image encoder: ImageNet ResNet-101 with the classification head removed
        cnn = models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V2)
        cnn.fc = nn.Identity()
        self.img_enc = cnn
        self.img_proj = nn.Linear(2048, self.text_hidden)

        # fusion encoder: image token at index 0, then the text sequence
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.text_hidden, nhead=16,
            dim_feedforward=4096, dropout=0.1, activation='gelu',
            batch_first=True
        )
        self.xattn = nn.TransformerEncoder(encoder_layer, num_layers=2)

        # gating and classifier
        self.gate = nn.Linear(self.text_hidden*2, 1)
        self.classifier = nn.Sequential(
            nn.Linear(self.text_hidden, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, num_classes)
        )

        # learnable log-temperature for the auxiliary contrastive loss (init 1/0.07, as in CLIP)
        self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1/0.07))

    def forward(self, input_ids, attention_mask, pixel_values, return_feats=False):
        """Return logits [B, num_classes]; with ``return_feats`` also the fused vector and image token.

        Args:
            input_ids, attention_mask: tokenizer outputs, [B, L].
            pixel_values: normalised image batch for ResNet-101 (presumably
                [B, 3, 224, 224] from MNREDataset; confirm for other callers).
        """
        txt = self.text_enc(input_ids=input_ids, attention_mask=attention_mask)
        txt_seq = txt.last_hidden_state            # [B, L, 1024]
        txt_cls = txt_seq[:,0,:]                   # [B, 1024]

        img = self.img_enc(pixel_values)           # [B, 2048]
        img = self.img_proj(img).unsqueeze(1)      # [B, 1, 1024]

        mm = torch.cat([img, txt_seq], dim=1)      # [B, 1+L, 1024]

        # Key-padding mask for the fusion encoder (True = ignore). Without it
        # the image token also attends to every pad position of the
        # max_length-padded text. The image token itself is always visible.
        img_visible = torch.zeros(attention_mask.size(0), 1, dtype=torch.bool,
                                  device=attention_mask.device)
        pad_mask = torch.cat([img_visible, attention_mask == 0], dim=1)  # [B, 1+L]
        fused = self.xattn(mm, src_key_padding_mask=pad_mask)            # [B, 1+L, 1024]
        img_token = fused[:,0,:]                   # [B, 1024]

        # gating: g -> 1 favours text CLS, g -> 0 favours the fused image token
        g = torch.sigmoid(self.gate(torch.cat([txt_cls, img_token], dim=1)))  # [B,1]
        fused_vec = g*txt_cls + (1-g)*img_token                                # [B,1024]

        logits = self.classifier(fused_vec)
        if return_feats:
            return logits, fused_vec, img_token
        return logits

# -------------- Losses --------------
class LabelSmoothingCE(nn.Module):
    """Cross-entropy with uniform label smoothing, averaged over the batch.

    Per sample: ``(1 - eps) * (-log p[y]) + eps * mean_c(-log p[c])``,
    i.e. the target distribution puts ``1 - eps + eps/C`` on the true class
    and ``eps/C`` on every other class.
    """
    def __init__(self, eps=0.1):
        super().__init__()
        self.eps = eps
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, logits, target):
        neg_logp = -self.log_softmax(logits)                                   # [B, C]
        target_term = neg_logp.gather(dim=1, index=target.unsqueeze(1)).squeeze(1)  # [B]
        uniform_term = neg_logp.mean(dim=1)                                    # [B]
        per_sample = (1 - self.eps) * target_term + self.eps * uniform_term
        return per_sample.mean()

def contrastive_loss(z_txt, z_img, temperature):
    """Symmetric InfoNCE loss between two batches of embeddings.

    Rows are L2-normalised; the cosine-similarity matrix is scaled by
    ``exp(temperature)`` (so ``temperature`` is a log-scale tensor). Row i of
    ``z_txt`` and row i of ``z_img`` form the positive pair. Returns the mean of
    the row-wise and column-wise cross-entropies.
    """
    a = nn.functional.normalize(z_txt, dim=1)
    b = nn.functional.normalize(z_img, dim=1)
    sim = torch.matmul(a, b.t()) * torch.exp(temperature)
    targets = torch.arange(a.size(0), device=a.device)
    row_loss = nn.functional.cross_entropy(sim, targets)
    col_loss = nn.functional.cross_entropy(sim.t(), targets)
    return (row_loss + col_loss) / 2

# -------------- Train/Eval --------------
def compute_class_weights(train_txt, relation2id):
    """Return inverse-frequency class weights as a float32 tensor indexed by relation id.

    Raw weight for class r is ``total / (num_classes * count[r])``; the vector
    is then rescaled so its entries sum to ``num_classes``.
    """
    from collections import Counter
    with open(train_txt, 'r', encoding='utf-8') as fh:
        freq = Counter(ast.literal_eval(line.strip())['relation'] for line in fh)
    n_cls = len(relation2id)
    total = sum(freq.values())
    raw = np.zeros(n_cls, dtype=np.float32)
    for rel, idx in relation2id.items():
        raw[idx] = total / (n_cls * freq[rel])
    raw = raw * (n_cls / raw.sum())
    return torch.tensor(raw, dtype=torch.float32)

def train_one_epoch(model, loader, optim, sched, ce_loss, alpha_contrast=0.05, scaler=None):
    """Train ``model`` for one pass over ``loader``.

    Loss = ``ce_loss(logits, y) + alpha_contrast * contrastive_loss(fused_vec, img_token)``.
    Gradients are clipped to a global L2 norm of 1.0, and ``sched`` (if given)
    is stepped once per batch.

    Args:
        scaler: optional ``GradScaler``. Mixed-precision autocast is used only
            when the scaler is enabled.

    Returns:
        (mean batch loss, accuracy, macro-F1) over the epoch's training batches.
    """
    model.train()
    # A disabled scaler (e.g. GradScaler(enabled=False) on CPU) means full precision.
    use_amp = scaler is not None and scaler.is_enabled()
    losses, preds, labels = [], [], []
    pbar = tqdm(loader, desc="Train", leave=False)
    for batch in pbar:
        optim.zero_grad(set_to_none=True)
        input_ids = batch['input_ids'].to(device)
        attn = batch['attention_mask'].to(device)
        pixels = batch['pixel_values'].to(device)
        y = batch['label'].to(device)

        with torch.autocast(device_type=device.type, enabled=use_amp):
            logits, fused_vec, img_tok = model(input_ids, attn, pixels, return_feats=True)
            loss_cls = ce_loss(logits, y)
            # auxiliary contrastive term between fused vec and image token
            loss_ctr = contrastive_loss(fused_vec, img_tok, model.logit_scale) * alpha_contrast
            loss = loss_cls + loss_ctr

        if scaler is not None:
            scaler.scale(loss).backward()
            # The grads still carry the loss-scale factor here. Unscale them in
            # place first so clip_grad_norm_ clips the *real* norm to 1.0.
            # Clipping the scaled grads shrinks every update by up to the scale.
            scaler.unscale_(optim)
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optim)
            scaler.update()
        else:
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optim.step()

        if sched: sched.step()
        losses.append(loss.item())
        preds.extend(torch.argmax(logits,1).detach().cpu().numpy())
        labels.extend(y.detach().cpu().numpy())
    acc = accuracy_score(labels, preds)
    f1 = f1_score(labels, preds, average='macro')
    return np.mean(losses), acc, f1

@torch.no_grad()
def evaluate(model, loader, ce_loss):
    """Score ``model`` on ``loader`` without gradients.

    Returns:
        (mean batch loss, accuracy, macro-F1, macro-precision, macro-recall).
        Precision/recall count classes that are never predicted as 0.
    """
    model.eval()
    batch_losses, y_pred, y_true = [], [], []
    for batch in tqdm(loader, desc="Val", leave=False):
        inputs = [batch[key].to(device) for key in ('input_ids', 'attention_mask', 'pixel_values')]
        target = batch['label'].to(device)
        logits = model(*inputs)
        batch_losses.append(ce_loss(logits, target).item())
        y_pred.extend(logits.argmax(dim=1).cpu().numpy())
        y_true.extend(target.cpu().numpy())

    macro = {'average': 'macro', 'zero_division': 0}
    return (
        np.mean(batch_losses),
        accuracy_score(y_true, y_pred),
        f1_score(y_true, y_pred, average='macro'),
        precision_score(y_true, y_pred, **macro),
        recall_score(y_true, y_pred, **macro),
    )

# -------------- Main --------------
def main(num_epochs=5):
    """Train the RoBERTa-large + ResNet-101 model on MNRE and report test metrics.

    The checkpoint with the best validation macro-F1 is written to
    ``mnre_roberta_resnet101_xattn_best.pth``. Training stops early after 3
    epochs without improvement, and that best checkpoint is evaluated on the
    test split.

    Args:
        num_epochs: maximum number of training epochs (default 5). It also sets
            the length of the cosine LR schedule, so the schedule finishes its
            decay by the last epoch.
    """
    relation2id = load_relation2id(TXT_PATHS['train'])
    num_classes = len(relation2id)
    print("Relations:", num_classes, relation2id)

    train_ds = MNREDataset(TXT_PATHS['train'], IMG_DIRS['train'], relation2id, train_aug=True)
    val_ds   = MNREDataset(TXT_PATHS['val'],   IMG_DIRS['val'],   relation2id, train_aug=False)
    test_ds  = MNREDataset(TXT_PATHS['test'],  IMG_DIRS['test'],  relation2id, train_aug=False)

    train_loader = DataLoader(train_ds, batch_size=12, shuffle=True, num_workers=2, pin_memory=True)
    val_loader   = DataLoader(val_ds,   batch_size=16, shuffle=False, num_workers=2, pin_memory=True)
    test_loader  = DataLoader(test_ds,  batch_size=16, shuffle=False, num_workers=2, pin_memory=True)

    model = CrossModalRE(num_classes).to(device)

    # Loss with label smoothing + class weights
    class_w = compute_class_weights(TXT_PATHS['train'], relation2id).to(device)
    ce_ls = LabelSmoothingCE(eps=0.1)
    def weighted_loss(logits, target):
        base = ce_ls(logits, target)
        # Approximation: scales the batch-mean smoothed loss by the mean class
        # weight of the batch (not true per-sample class weighting).
        w = class_w[target].mean()
        return base * w

    # Optim & sched. Size the schedule to the epochs actually run. Sizing it
    # for more epochs than the loop runs leaves the LR mid-decay at the end.
    total_steps = len(train_loader) * num_epochs
    warmup = int(0.1 * total_steps)
    optim = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
    sched = get_cosine_schedule_with_warmup(optim, num_warmup_steps=warmup, num_training_steps=total_steps)

    scaler = torch.cuda.amp.GradScaler(enabled=(device.type=='cuda'))

    ckpt_path = 'mnre_roberta_resnet101_xattn_best.pth'
    best_f1, best_state = -1, None
    patience, bad = 3, 0

    for epoch in range(1, num_epochs + 1):
        tr_loss, tr_acc, tr_f1 = train_one_epoch(model, train_loader, optim, sched, weighted_loss, alpha_contrast=0.05, scaler=scaler)
        vl_loss, vl_acc, vl_f1, vl_p, vl_r = evaluate(model, val_loader, weighted_loss)
        print(f"Epoch {epoch:02d} | Train: loss {tr_loss:.4f} acc {tr_acc:.3f} f1 {tr_f1:.3f} || Val: loss {vl_loss:.4f} acc {vl_acc:.3f} f1 {vl_f1:.3f}")

        if vl_f1 > best_f1:
            best_f1 = vl_f1; bad = 0
            # state_dict() returns references to the live parameters, which
            # later epochs keep updating. Snapshot them on CPU so best_state
            # really holds this epoch's weights.
            best_state = {
                'model_state': {k: v.detach().cpu().clone() for k, v in model.state_dict().items()},
                'relation2id': relation2id,
                'config': {'text_model': TEXT_MODEL, 'img_backbone': 'resnet101', 'xattn_layers':2, 'heads':16}
            }
            torch.save(best_state, ckpt_path)
        else:
            bad += 1
            if bad >= patience:
                print("Early stopping.")
                break

    # Test with best checkpoint
    if best_state is None:
        best_state = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(best_state['model_state'])

    te_loss, te_acc, te_f1, te_p, te_r = evaluate(model, test_loader, weighted_loss)
    print(f"\nTEST  | acc {te_acc:.4f} f1 {te_f1:.4f}  P {te_p:.4f} R {te_r:.4f}")

if __name__ == "__main__":
    main()



# Prediction and NER Application

In [None]:
from transformers import pipeline
import torch
from transformers import RobertaTokenizer
from torchvision import transforms
from PIL import Image
import os
from transformers import pipeline, RobertaTokenizer, AutoModelForTokenClassification, BertTokenizer

# Load the RoBERTa + ResNet-101 checkpoint saved by the training cell.
checkpoint = torch.load('/kaggle/working/mnre_roberta_resnet101_xattn_best.pth', map_location=device)
model = CrossModalRE(num_classes=len(checkpoint['relation2id'])).to(device)
model.load_state_dict(checkpoint['model_state'])
model.eval()

# Load the pre-trained NER model and tokenizer
ner_model = AutoModelForTokenClassification.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")
ner_tokenizer = BertTokenizer.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")
# pipeline() takes a device index: 0 = first GPU, -1 = CPU. A hard-coded 0
# fails on CPU-only runtimes.
ner_pipe = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer,
                    device=0 if torch.cuda.is_available() else -1)

# --------------------------
# Sample images and predictions
# --------------------------

# Relies on `train_data`, `IMG_DIRS` and `plt` from the data-analysis cells;
# this repeats that cell's preview of the first five training images.
sample_img_ids = [obj['img_id'] for obj in train_data[:5]]  # First 5 sample images

# Load images and show them in a single 1x5 row
fig, axes = plt.subplots(1, 5, figsize=(15,5))
for i, img_id in enumerate(sample_img_ids):
    img_path = os.path.join(IMG_DIRS['train'], img_id)
    img = Image.open(img_path).convert('RGB')
    axes[i].imshow(img)
    axes[i].axis('off')
    axes[i].set_title(f"Image {i+1}")
plt.tight_layout()
plt.show()

# Predict the relation for one (image, sentence) pair and return the plain text for NER.
def predict_relation_for_image(image_path, tokens, model, tokenizer, relation2id, max_len=128,
                               h_pos=None, t_pos=None):
    """Predict the relation label for one image + token list.

    Args:
        image_path: path of the image file.
        tokens: list of word tokens (an MNRE sample's 'token' field).
        model: model called as ``model(input_ids, attention_mask, pixels)``.
        tokenizer: the tokenizer used at training time.
        relation2id: label -> id mapping stored in the checkpoint.
        max_len: padding/truncation length.
        h_pos, t_pos: optional [start, end] index pairs of the head and tail
            entity (the sample's 'h'/'t' 'pos' fields). When both are given,
            the model input gets the same [E1]/[/E1], [E2]/[/E2] markers the
            training dataset inserts. Otherwise the plain sentence is used.

    Returns:
        (relation_name, text): the predicted relation ("Unknown" if the id is
        not in relation2id) and the space-joined tokens *without* markers.
    """
    img = Image.open(image_path).convert('RGB')
    # Same resize + ImageNet normalisation as the evaluation-time dataset transform.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    img_tensor = transform(img).unsqueeze(0).to(device)

    text = " ".join(tokens)
    model_text = text
    if h_pos is not None and t_pos is not None:
        marked = list(tokens)
        spans = sorted([('h', h_pos[0], h_pos[1]), ('t', t_pos[0], t_pos[1])],
                       key=lambda s: s[1], reverse=True)
        # Right-most span first so earlier indices stay valid.
        for etype, start, end in spans:
            marked.insert(end, '[/E1]' if etype == 'h' else '[/E2]')
            marked.insert(start, '[E1]' if etype == 'h' else '[E2]')
        model_text = " ".join(marked)

    enc = tokenizer(model_text, padding='max_length', truncation=True, max_length=max_len, return_tensors='pt').to(device)

    with torch.no_grad():
        logits = model(enc['input_ids'], enc['attention_mask'], img_tensor)

    pred_id = int(torch.argmax(logits, dim=1)[0])
    id2relation = {v: k for k, v in relation2id.items()}
    return id2relation.get(pred_id, "Unknown"), text

# Predict relations for the same five training samples shown above. Iterate
# the samples themselves rather than looking them up by img_id: several
# samples can share one image, and a lookup would always return the first.
predictions = []
for sample in train_data[:5]:
    img_id = sample['img_id']
    img_path = os.path.join(IMG_DIRS['train'], img_id)

    relation, extracted_text = predict_relation_for_image(
        img_path, sample['token'], model, tokenizer, checkpoint['relation2id'],
        h_pos=sample['h']['pos'], t_pos=sample['t']['pos'],
    )

    # Apply NER to the plain (marker-free) text
    ner_results = ner_pipe(extracted_text)

    predictions.append((img_id, relation, extracted_text, ner_results))

# Display predicted relations, extracted text, and NER results for the images
for i, (img_id, rel, extracted_text, ner_results) in enumerate(predictions):
    print(f"Image {i+1} ({img_id}) Predicted Relation: {rel}")
    print(f"Extracted Text: {extracted_text}")
    print("NER Results:")
    for ner in ner_results:
        print(f"Entity: {ner['word']} | Type: {ner['entity']} | Score: {ner['score']:.4f}")
    print("-" * 50)


# resnet34 + bert-base-uncased models 

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from transformers import BertTokenizer, BertModel
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_curve, auc
from tqdm import tqdm
import matplotlib.pyplot as plt
import ast
import gdown

In [None]:
class MNREDataset(Dataset):
    """MNRE split for the BERT-base model; the tokenizer and image transform are injected.

    Each line of ``txt_file`` is a Python-literal dict with 'token',
    'h'/'t' ('pos' = [start, end]), 'img_id' and 'relation'. The head span is
    wrapped in ``[E1] ... [/E1]`` and the tail span in ``[E2] ... [/E2]``.
    """
    def __init__(self, txt_file, img_dir, tokenizer, relation2id, max_len=128, transform=None):
        self.txt_file = txt_file
        self.img_dir = img_dir
        self.transform = transform
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.relation2id = relation2id
        self.samples = []
        with open(self.txt_file, 'r', encoding='utf-8') as f:
            for raw in f:
                self.samples.append(self._to_sample(ast.literal_eval(raw.strip())))

    def _to_sample(self, rec):
        """Build the {'text', 'img_id', 'label'} entry for one parsed record."""
        words = list(rec['token'])
        h_start, h_end = rec['h']['pos']
        t_start, t_end = rec['t']['pos']
        # Right-most span first so inserted markers don't shift the other span.
        ordered = sorted([('h', h_start, h_end), ('t', t_start, t_end)],
                         key=lambda span: span[1], reverse=True)
        for etype, start, end in ordered:
            if etype == 'h':
                open_tag, close_tag = '[E1]', '[/E1]'
            else:
                open_tag, close_tag = '[E2]', '[/E2]'
            words.insert(end, close_tag)
            words.insert(start, open_tag)
        return {
            'text': " ".join(words),
            'img_id': rec['img_id'],
            'label': self.relation2id[rec['relation']],
        }

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        entry = self.samples[idx]
        enc = self.tokenizer(entry['text'], padding='max_length', truncation=True,
                             max_length=self.max_len, return_tensors='pt')
        image = Image.open(os.path.join(self.img_dir, entry['img_id'])).convert('RGB')
        image = self.transform(image) if self.transform else image
        return {
            'input_ids': enc['input_ids'].squeeze(0),
            'attention_mask': enc['attention_mask'].squeeze(0),
            'image': image,
            'label': torch.tensor(entry['label'], dtype=torch.long),
        }

In [None]:
import os, ast
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from transformers import BertModel, BertTokenizer
from PIL import Image
import matplotlib.pyplot as plt

# =====================
# Model Architecture
# =====================
class MultimodalREModel(nn.Module):
    """BERT-base (text) + ResNet-34 (image) relation classifier.

    The pooled 512-d ResNet-34 feature is projected to BERT's hidden size and
    prepended to the BERT token sequence as one "image token". A 1-layer,
    8-head Transformer encoder attends over [image token; text tokens], with
    text padding masked out. The image token's output is classified.
    """
    def __init__(self, num_classes):
        super().__init__()
        # ---- Text encoder (BERT) ----
        self.text_encoder = BertModel.from_pretrained('bert-base-uncased')

        # ---- Image encoder (ResNet-34) ----
        # `weights=` replaces the deprecated `pretrained=True`; IMAGENET1K_V1
        # is the weight set pretrained=True selected.
        cnn = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)
        cnn.fc = nn.Identity()  # remove classification head -> 512-d pooled features
        self.image_encoder = cnn

        # ---- Projection layer to match text hidden size ----
        self.img_proj = nn.Linear(512, self.text_encoder.config.hidden_size)

        # ---- Fusion (Transformer encoder over image token + text tokens) ----
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.text_encoder.config.hidden_size,
            nhead=8,
            dim_feedforward=2048,
            dropout=0.1,
            activation="relu",
            batch_first=True
        )
        self.cross_attention = nn.TransformerEncoder(encoder_layer, num_layers=1)

        # ---- Classifier ----
        hidden_dim = self.text_encoder.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes)
        )

    def forward(self, input_ids, attention_mask, image):
        """Return class logits [B, num_classes].

        Args:
            input_ids, attention_mask: tokenizer outputs, [B, seq_len].
            image: image batch for ResNet-34 (presumably [B, 3, 224, 224]; confirm).
        """
        # Encode text
        txt_out = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        txt_feats = txt_out.last_hidden_state   # [B, seq_len, H]

        # Encode image
        img_feats = self.image_encoder(image)   # [B, 512]
        img_feats = self.img_proj(img_feats).unsqueeze(1)  # [B, 1, H]

        # Concatenate image embedding as a token to text sequence
        multimodal_feats = torch.cat([img_feats, txt_feats], dim=1)  # [B, 1+seq_len, H]

        # Key-padding mask (True = ignore). The image token is always visible;
        # text positions follow attention_mask, so pad tokens no longer feed
        # into the fused representation.
        img_visible = torch.zeros(attention_mask.size(0), 1, dtype=torch.bool,
                                  device=attention_mask.device)
        pad_mask = torch.cat([img_visible, attention_mask == 0], dim=1)  # [B, 1+seq_len]

        # Fusion
        fused_feats = self.cross_attention(multimodal_feats, src_key_padding_mask=pad_mask)  # [B, 1+seq_len, H]

        # Use the image token ([IMG]) as global fused representation
        img_token = fused_feats[:, 0, :]  # [B, H]

        return self.classifier(img_token)

In [None]:
def train_epoch(model, loader, criterion, optimizer, device):
    """Run one training epoch.

    Returns:
        dict with 'loss' (mean of per-batch losses) and 'acc', 'prec', 'rec',
        'f1' computed over the epoch's training predictions (macro-averaged).
        Precision/recall use ``zero_division=0``, as ``evaluate`` does.
        This keeps never-predicted classes (common in early epochs) at 0
        without UndefinedMetricWarning noise.
    """
    model.train()
    losses, preds, labels = [], [], []
    for batch in tqdm(loader, desc="Training", leave=False):
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        images = batch['image'].to(device)
        labs = batch['label'].to(device)
        outs = model(input_ids, attention_mask, images)
        loss = criterion(outs, labs)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        pred = torch.argmax(outs, dim=1).cpu().numpy()
        preds.extend(pred)
        labels.extend(labs.cpu().numpy())
    return {
        'loss': sum(losses)/len(losses),
        'acc': accuracy_score(labels, preds),
        'prec': precision_score(labels, preds, average='macro', zero_division=0),
        'rec': recall_score(labels, preds, average='macro', zero_division=0),
        'f1': f1_score(labels, preds, average='macro')
    }

def eval_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradients.

    Returns:
        (metrics, roc_data): ``metrics`` has 'loss', 'acc', 'prec', 'rec', 'f1'
        (macro), plus 'roc_auc' when a ROC curve could be computed.
        ``roc_data`` is ``(fpr, tpr)`` in that case, else ``None``.
    """
    model.eval()
    losses, preds, labels, probs = [], [], [], []
    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating", leave=False):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            images = batch['image'].to(device)
            labs = batch['label'].to(device)
            outs = model(input_ids, attention_mask, images)
            loss = criterion(outs, labs)
            prob = torch.softmax(outs, dim=1)
            losses.append(loss.item())
            pred = torch.argmax(outs, dim=1).cpu().numpy()
            preds.extend(pred)
            labels.extend(labs.cpu().numpy())
            # Positive-class scores exist only for a 2-output model.
            if prob.shape[1] == 2:
                probs.extend(prob[:,1].cpu().numpy())
    metrics = {
        'loss': sum(losses)/len(losses),
        'acc': accuracy_score(labels, preds),
        'prec': precision_score(labels, preds, average='macro', zero_division=0),
        'rec': recall_score(labels, preds, average='macro', zero_division=0),
        'f1': f1_score(labels, preds, average='macro')
    }
    roc_data = None
    # ROC needs positive-class scores (collected only for a 2-class model) and
    # both classes among the labels. Checking `labels` alone crashed a
    # multi-class model whose split happened to contain two label values,
    # because `probs` was empty.
    if probs and len(set(labels)) == 2:
        fpr, tpr, _ = roc_curve(labels, probs)
        metrics['roc_auc'] = auc(fpr, tpr)
        roc_data = (fpr, tpr)
    return metrics, roc_data

# bert-base-uncased models 

In [None]:
def main():
    """Train and evaluate the BERT-base + ResNet-34 model on MNRE, then plot curves.

    Trains for a fixed number of epochs (no best-checkpoint selection) and
    saves the final weights plus ``relation2id`` to ``mnre_model.pth``. It then
    evaluates on the test split and plots train/val accuracy & F1, train
    precision & recall, and the test ROC curve when one is available.

    Returns:
        (tokenizer, transform, device, relation2id) for the prediction cells.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    epochs, batch_size, lr = 5, 16, 1e-5
    txt_paths = {
        'train': '/kaggle/working/MNRE/mnre_txt/mnre_train.txt',
        # Contains a space before ".txt", the same spelling as TXT_PATHS['val'];
        # keep the two in sync with the file on disk.
        'val':   '/kaggle/working/MNRE/mnre_txt/mnre_val .txt',
        'test':  '/kaggle/working/MNRE/mnre_txt/mnre_test.txt'
    }
    img_dirs = {
        'train': '/kaggle/working/img_org/train',
        'val':   '/kaggle/working/img_org/val',
        'test':  '/kaggle/working/img_org/test'
    }
    # Build relation2id from training data (sorted labels -> stable ids)
    rels = set()
    with open(txt_paths['train'], 'r', encoding='utf-8') as f:
        for line in f:
            obj = ast.literal_eval(line.strip())
            rels.add(obj['relation'])
    relation2id = {r: idx for idx, r in enumerate(sorted(rels))}
    num_classes = len(relation2id)

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    # NOTE(review): no ImageNet Normalize here although ResNet-34 is
    # ImageNet-pretrained. RelationPredictor builds the same un-normalised
    # transform, so change both together if you add it.
    transform = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor()])

    datasets = {split: MNREDataset(txt_paths[split], img_dirs[split], tokenizer, relation2id, transform=transform)
                for split in ('train', 'val', 'test')}
    loaders = {split: DataLoader(datasets[split], batch_size=batch_size, shuffle=(split == 'train'))
               for split in ('train', 'val', 'test')}

    model = MultimodalREModel(num_classes).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr)

    history = {'train': [], 'val': []}
    for epoch in range(1, epochs + 1):
        train_metrics = train_epoch(model, loaders['train'], criterion, optimizer, device)
        val_metrics, _ = eval_epoch(model, loaders['val'], criterion, device)
        history['train'].append(train_metrics)
        history['val'].append(val_metrics)
        print(f"Epoch {epoch}/{epochs} | Training Accuracy: {train_metrics['acc']:.4f} | Val Accuracy: {val_metrics['acc']:.4f} | "
              f"Train F1: {train_metrics['f1']:.4f} | Val F1: {val_metrics['f1']:.4f}")

    # Save the final model
    torch.save({'model_state': model.state_dict(), 'relation2id': relation2id}, 'mnre_model.pth')

    # Evaluate on test set (these are test metrics, not training metrics)
    test_metrics, roc_data = eval_epoch(model, loaders['test'], criterion, device)
    roc_auc = test_metrics.get('roc_auc')
    roc_text = f"{roc_auc:.4f}" if roc_auc is not None else 'N/A'
    print(f"Test Metrics: Accuracy={test_metrics['acc']:.4f}, F1={test_metrics['f1']:.4f}, "
          f"Precision={test_metrics['prec']:.4f}, Recall={test_metrics['rec']:.4f}, "
          f"ROC AUC={roc_text}")

    # Plot Training Accuracy and F1 Score
    epochs_range = range(1, epochs + 1)
    plt.figure(figsize=(10, 6))
    plt.plot(epochs_range, [m['acc'] for m in history['train']], label='Training Accuracy')
    plt.plot(epochs_range, [m['acc'] for m in history['val']], label='Val Accuracy')
    plt.plot(epochs_range, [m['f1'] for m in history['train']], '--', label='Train F1')
    plt.plot(epochs_range, [m['f1'] for m in history['val']], '--', label='Val F1')
    plt.title('Training Accuracy & F1 Score over Epochs')
    plt.xlabel('Epoch')
    plt.ylabel('Score')
    plt.legend()
    plt.show()

    # Plot Precision and Recall for Training
    plt.figure(figsize=(10, 6))
    plt.plot(epochs_range, [m['prec'] for m in history['train']], label='Train Precision')
    plt.plot(epochs_range, [m['rec'] for m in history['train']], label='Train Recall')
    plt.title('Training Precision & Recall over Epochs')
    plt.xlabel('Epoch')
    plt.ylabel('Score')
    plt.legend()
    plt.show()

    if roc_data:
        fpr, tpr = roc_data
        plt.figure(figsize=(6, 6))
        plt.plot(fpr, tpr, label=f"ROC (AUC={test_metrics['roc_auc']:.2f})")
        plt.plot([0, 1], [0, 1], linestyle='--')
        plt.title('ROC Curve (Test Set)')
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.legend()
        plt.show()

    return tokenizer, transform, device, relation2id  # Return for prediction

In [None]:
# Train/evaluate the BERT model; keep the objects later prediction cells need.
(tokenizer, transform, device, relation2id) = main()

In [None]:
# Reload the BERT + ResNet-34 checkpoint written by main(). map_location lets a
# checkpoint saved from GPU load on a CPU-only runtime (and vice versa).
checkpoint = torch.load('mnre_model.pth', map_location=device)
relation2id = checkpoint['relation2id']
id2relation = {v: k for k, v in relation2id.items()}
model = MultimodalREModel(num_classes=len(relation2id))
model.load_state_dict(checkpoint['model_state'])
model.to(device)
model.eval()

In [None]:
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel
from torchvision import models, transforms
from PIL import Image
import matplotlib.pyplot as plt
import ast
import os


# Relation Predictor class
class RelationPredictor:
    """Look up an MNRE sample by its exact token list and predict its relation.

    The trained ``MultimodalREModel`` checkpoint is loaded once. Each query finds
    the annotation whose ``token`` list equals the query, wraps the head/tail
    spans in [E1]/[E2] markers, and classifies the marked text together with the
    sample's image.
    """

    def __init__(self, model_path, dataset_path='/kaggle/working/MNRE/mnre_txt/mnre_train.txt', image_dir='/kaggle/working/img_org'):
        """Load the model, tokenizer, image transform and annotation file.

        Args:
            model_path: checkpoint dict holding 'relation2id' and 'model_state'.
            dataset_path: MNRE annotation file, one Python-literal dict per line.
                Blank lines are ignored.
            image_dir: directory joined with each sample's 'img_id'.
                NOTE(review): the default is img_org/ while the default dataset
                is the train split; the calls in this notebook pass
                img_org/train explicitly — confirm the intended layout.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Load the trained model
        checkpoint = torch.load(model_path, map_location=self.device)
        self.relation2id = checkpoint['relation2id']
        self.id2relation = {v: k for k, v in self.relation2id.items()}
        self.model = MultimodalREModel(num_classes=len(self.relation2id))
        self.model.load_state_dict(checkpoint['model_state'])
        self.model.to(self.device)
        self.model.eval()
        # Load tokenizer and image transform
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
        # Load dataset; skip blank lines (e.g. extra newlines at EOF) because
        # ast.literal_eval('') raises SyntaxError.
        with open(dataset_path, 'r', encoding='utf-8') as f:
            self.dataset = [ast.literal_eval(line) for line in (raw.strip() for raw in f) if line]
        self.image_dir = image_dir

    def find_sample(self, tokens):
        """Return the first annotation whose 'token' list equals ``tokens``.

        Raises:
            ValueError: if no annotation matches.
        """
        for sample in self.dataset:
            if sample['token'] == tokens:
                return sample
        raise ValueError("No matching sample found for the given token list")

    def preprocess_text(self, tokens, h_pos, t_pos):
        """Return ``tokens`` joined by spaces with [E1]..[/E1] around the head span
        and [E2]..[/E2] around the tail span (each ``*_pos`` is [start, end)).

        Spans are marked right-to-left by start index so inserting one span's
        markers does not shift the indices of the span still to be marked (this
        holds when the spans do not overlap). ``tokens`` itself is not modified.
        """
        tokens = tokens.copy()
        spans = [('h', h_pos[0], h_pos[1]), ('t', t_pos[0], t_pos[1])]
        spans.sort(key=lambda x: x[1], reverse=True)
        for etype, start, end in spans:
            tag_close = '[/E1]' if etype == 'h' else '[/E2]'
            tokens.insert(end, tag_close)
            tag_open = '[E1]' if etype == 'h' else '[E2]'
            tokens.insert(start, tag_open)
        return " ".join(tokens)

    def predict_from_token(self, tokens):
        """Predict the relation for the annotated sample matching ``tokens``.

        Shows the sample's image, prints the marked text, entity names and the
        prediction, and returns them as a dict with keys 'image',
        'processed_text', 'head_entity', 'tail_entity', 'predicted_relation'
        and 'confidence'.

        Raises:
            ValueError: if ``tokens`` matches no annotation.
        """
        sample = self.find_sample(tokens)
        h_pos = sample['h']['pos']
        t_pos = sample['t']['pos']
        h_name = sample['h']['name']
        t_name = sample['t']['name']
        img_id = sample['img_id']
        image_path = os.path.join(self.image_dir, img_id)

        processed_text = self.preprocess_text(tokens, h_pos, t_pos)
        enc = self.tokenizer(processed_text, padding='max_length', truncation=True,
                             max_length=128, return_tensors='pt')
        # Context manager releases the file handle; convert() returns a copy
        # that stays usable after the file is closed.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        image_tensor = self.transform(image).unsqueeze(0).to(self.device)
        input_ids = enc['input_ids'].to(self.device)
        attention_mask = enc['attention_mask'].to(self.device)

        with torch.no_grad():
            outputs = self.model(input_ids, attention_mask, image_tensor)
            probs = torch.softmax(outputs, dim=1)
            pred_idx = torch.argmax(probs, dim=1).item()
            pred_relation = self.id2relation[pred_idx]
            confidence = probs[0, pred_idx].item()

        plt.imshow(image)
        plt.axis('off')
        plt.show()
        print(f"Processed Text: {processed_text}")
        print(f"Head Entity: {h_name}")
        print(f"Tail Entity: {t_name}")
        print(f"Predicted Relation: {pred_relation} with confidence {confidence:.4f}")

        return {
            'image': image,
            'processed_text': processed_text,
            'head_entity': h_name,
            'tail_entity': t_name,
            'predicted_relation': pred_relation,
            'confidence': confidence
        }

# Example usage: predict the relation of a training sentence identified by its tokens.
tokens = [
    'The', 'latest', 'Arkham', 'Horror', 'LCG', 'deluxe', 'expansion',
    'the', 'Circle', 'Undone', 'has', 'been', 'released', ':',
]
predictor = RelationPredictor(
    model_path='/kaggle/working/mnre_model.pth',
    dataset_path='/kaggle/working/MNRE/mnre_txt/mnre_train.txt',
    image_dir='/kaggle/working/img_org/train',
)
result = predictor.predict_from_token(tokens)

In [None]:



# Example usage: a second lookup. This re-creates the predictor, which reloads
# the checkpoint, tokenizer and annotation file.
tokens = [
    'RT', '@PaulStrangwood', ':', '@ThePhotoHour', 'this', 'is', 'part', 'of',
    'the', 'Magat', 'Damms', 'near', 'to', 'Cauayan', 'in', 'the', 'Philippines',
]
predictor = RelationPredictor(
    model_path='/kaggle/working/mnre_model.pth',
    dataset_path='/kaggle/working/MNRE/mnre_txt/mnre_train.txt',
    image_dir='/kaggle/working/img_org/train',
)
result = predictor.predict_from_token(tokens)

In [None]:
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel
from torchvision import models, transforms
from PIL import Image
import matplotlib.pyplot as plt
import ast
import os

class RelationPredictor:
    """Predict the relation of an MNRE sample identified by its exact token list.

    Loads a trained ``MultimodalREModel`` checkpoint and an annotation file, finds
    the annotation whose ``token`` list equals the query, marks head/tail entities
    with [E1]/[E2] tags, and classifies the marked text with the sample's image.
    """

    def __init__(self, model_path, dataset_path, image_dir):
        """Load model, tokenizer, image transform and annotations.

        Args:
            model_path: checkpoint dict holding 'relation2id' and 'model_state'.
            dataset_path: MNRE annotation file, one Python-literal dict per line;
                blank lines are ignored.
            image_dir: directory joined with each sample's 'img_id'.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        checkpoint = torch.load(model_path, map_location=self.device)
        self.relation2id = checkpoint['relation2id']
        self.id2relation = {v: k for k, v in self.relation2id.items()}
        self.model = MultimodalREModel(num_classes=len(self.relation2id))
        self.model.load_state_dict(checkpoint['model_state'])
        self.model.to(self.device)
        self.model.eval()
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
        # ast.literal_eval('') raises SyntaxError, so blank lines are skipped.
        with open(dataset_path, 'r', encoding='utf-8') as f:
            self.dataset = [ast.literal_eval(line) for line in (raw.strip() for raw in f) if line]
        self.image_dir = image_dir

    def find_sample(self, tokens):
        """Return the first annotation whose 'token' list equals ``tokens``.

        Raises:
            ValueError: if no annotation matches.
        """
        for sample in self.dataset:
            if sample['token'] == tokens:
                return sample
        raise ValueError("❌ No matching sample found for the given tokens.")

    def preprocess_text(self, tokens, h_pos, t_pos):
        """Return ``tokens`` joined by spaces with [E1]..[/E1] around the head span
        and [E2]..[/E2] around the tail span (each ``*_pos`` is [start, end)).

        Spans are marked right-to-left by start index so one span's insertions do
        not shift the other's indices (valid when the spans do not overlap).
        ``tokens`` itself is not modified.
        """
        tokens = tokens.copy()
        spans = [('h', h_pos[0], h_pos[1]), ('t', t_pos[0], t_pos[1])]
        spans.sort(key=lambda x: x[1], reverse=True)
        for etype, start, end in spans:
            tag_close = '[/E1]' if etype == 'h' else '[/E2]'
            tokens.insert(end, tag_close)
            tag_open = '[E1]' if etype == 'h' else '[E2]'
            tokens.insert(start, tag_open)
        return " ".join(tokens)

    def predict_from_token(self, tokens):
        """Predict and display the relation for the sample matching ``tokens``.

        Returns a dict with keys 'image', 'processed_text', 'head_entity',
        'tail_entity', 'predicted_relation' and 'confidence'.

        Raises:
            ValueError: if ``tokens`` matches no annotation.
        """
        sample = self.find_sample(tokens)
        h_pos = sample['h']['pos']
        t_pos = sample['t']['pos']
        h_name = sample['h']['name']
        t_name = sample['t']['name']
        img_id = sample['img_id']
        image_path = os.path.join(self.image_dir, img_id)

        processed_text = self.preprocess_text(tokens, h_pos, t_pos)
        enc = self.tokenizer(processed_text, padding='max_length', truncation=True,
                             max_length=128, return_tensors='pt')
        # Close the file handle; convert() yields a copy usable after close.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        image_tensor = self.transform(image).unsqueeze(0).to(self.device)
        input_ids = enc['input_ids'].to(self.device)
        attention_mask = enc['attention_mask'].to(self.device)

        with torch.no_grad():
            outputs = self.model(input_ids, attention_mask, image_tensor)
            probs = torch.softmax(outputs, dim=1)
            pred_idx = torch.argmax(probs, dim=1).item()
            pred_relation = self.id2relation[pred_idx]
            confidence = probs[0, pred_idx].item()

        plt.imshow(image)
        plt.axis('off')
        plt.show()

        print(f"📝 Processed Text: {processed_text}")
        print(f"👤 Head Entity: {h_name}")
        print(f"👥 Tail Entity: {t_name}")
        print(f"🔗 Predicted Relation: {pred_relation}")
        print(f"📊 Confidence Score: {confidence:.4f}")

        return {
            'image': image,
            'processed_text': processed_text,
            'head_entity': h_name,
            'tail_entity': t_name,
            'predicted_relation': pred_relation,
            'confidence': confidence
        }

# === Usage Example ===
# Build the predictor, then look up a training sentence by its exact tokens.
predictor = RelationPredictor(
    model_path='/kaggle/working/mnre_model.pth',
    dataset_path='/kaggle/working/MNRE/mnre_txt/mnre_train.txt',
    image_dir='/kaggle/working/img_org/train',
)

tokens = [
    'RT', '@TedNesi', ':', 'Bishop', 'Tobin', 'says', 'he', 'is',
    'leaving', 'Twitter', 'after', 'just', 'a', 'few', 'months',
]

# Bare expression (not assigned) so the notebook shows the returned dict as cell output.
predictor.predict_from_token(tokens)


# Pretrained Model: BLIP

In [None]:
import os
import ast
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as transforms
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

from transformers import BlipProcessor, BlipModel

# Device used by the BLIP dataset/model/training helpers in this cell.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ========================
# 1. Load BLIP (pretrained vision-language model)
# ========================
# NOTE(review): this is the image-captioning checkpoint loaded into the generic
# BlipModel class. Check the from_pretrained load warnings to confirm that the
# text encoder and projection weights behind text_embeds / image_embeds were
# actually found in the checkpoint rather than freshly initialized.
MODEL_NAME = "Salesforce/blip-image-captioning-base"
processor = BlipProcessor.from_pretrained(MODEL_NAME)  # module-level; MNREDataset.__getitem__ uses it
blip_model = BlipModel.from_pretrained(MODEL_NAME).to(device)

# Freeze BLIP encoder (optional)
# Every backbone parameter is frozen, so only the classifier head's
# parameters can receive gradients during training.
for param in blip_model.parameters():
    param.requires_grad = False

# Text-encoder width. MultimodalREModel sizes its head from projection_dim instead.
HIDDEN_SIZE = blip_model.config.text_config.hidden_size  # 768

# ========================
# 2. Dataset
# ========================
class MNREDataset(Dataset):
    """One MNRE split as entity-marked (text, image) pairs for BLIP.

    Each non-blank line of ``txt_file`` is a Python-literal dict with keys
    'token', 'h', 't', 'img_id' and 'relation'. The head span is wrapped in
    [E1]..[/E1] and the tail span in [E2]..[/E2] before the tokens are joined
    with spaces. Items are encoded by the module-level BLIP ``processor``.
    """

    def __init__(self, txt_file, img_dir, relation2id, transform=None, max_len=128):
        """Parse ``txt_file`` into samples.

        Args:
            txt_file: annotation file for this split.
            img_dir: directory joined with each sample's 'img_id'.
            relation2id: relation name -> label id. An unseen relation raises KeyError.
            transform: stored only; ``__getitem__`` uses ``processor`` for images.
            max_len: text length passed to the processor (pad/truncate).
        """
        self.samples = []
        self.img_dir = img_dir
        self.transform = transform
        self.max_len = max_len

        with open(txt_file, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue  # ast.literal_eval('') would raise SyntaxError
                obj = ast.literal_eval(line)
                tokens = list(obj['token'])  # copy so the parsed record stays untouched
                h_start, h_end = obj['h']['pos']
                t_start, t_end = obj['t']['pos']

                # Insert entity markers right-to-left so the earlier span's
                # indices are not shifted by the later span's insertions.
                spans = sorted([('h', h_start, h_end), ('t', t_start, t_end)], key=lambda x: x[1], reverse=True)
                for etype, start, end in spans:
                    tag_close = '[/E1]' if etype == 'h' else '[/E2]'
                    tokens.insert(end, tag_close)
                    tag_open = '[E1]' if etype == 'h' else '[E2]'
                    tokens.insert(start, tag_open)

                self.samples.append({
                    "text": " ".join(tokens),
                    "img_id": obj['img_id'],
                    "label": relation2id[obj['relation']],
                })

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return a dict of 1-D input_ids/attention_mask, pixel_values and a long label."""
        sample = self.samples[idx]
        image_path = os.path.join(self.img_dir, sample['img_id'])
        # Close the file promptly; convert() returns an independent copy.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        inputs = processor(images=image, text=sample['text'], padding="max_length", truncation=True,
                           max_length=self.max_len, return_tensors="pt")

        # squeeze(0) drops the batch dimension the processor adds for a single
        # example, so the DataLoader can collate items itself.
        return {
            "input_ids": inputs["input_ids"].squeeze(0),
            "attention_mask": inputs["attention_mask"].squeeze(0),
            "pixel_values": inputs["pixel_values"].squeeze(0),
            "label": torch.tensor(sample["label"], dtype=torch.long)
        }

# ========================
# 3. Model (Relation Classifier)
# ========================
class MultimodalREModel(nn.Module):
    """Relation classifier: an MLP head on top of the module-level ``blip_model``.

    The BLIP backbone is shared, not copied. Its parameters are frozen at load
    time in this cell, so training updates only ``classifier``.
    """

    def __init__(self, num_relations):
        super().__init__()
        self.blip = blip_model

        # image_embeds and text_embeds are both projection_dim wide, which can
        # differ from the encoder hidden size, so size the head from it.
        hidden_size = self.blip.config.projection_dim

        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_relations)
        )

    def forward(self, input_ids, attention_mask, pixel_values, use_text_embeds=False):
        """Return relation logits (last dimension = num_relations).

        Args:
            use_text_embeds: classify on BLIP's projected text embedding instead
                of the projected image embedding. With the default (image), HF's
                BlipModel computes the embedding from pixel_values alone, so the
                entity-marked text cannot influence the prediction. The BLIP
                inference cell's classifier defaults to text embeddings; keep
                training and inference on the same choice.
        """
        outputs = self.blip(input_ids=input_ids,
                            attention_mask=attention_mask,
                            pixel_values=pixel_values)

        pooled = outputs.text_embeds if use_text_embeds else outputs.image_embeds
        return self.classifier(pooled)


# ========================
# 4. Training & Evaluation
# ========================
def _run_batch(model, batch):
    """Move one collated batch to ``device`` and return (logits, labels)."""
    input_ids = batch["input_ids"].to(device)
    attn = batch["attention_mask"].to(device)
    pixels = batch["pixel_values"].to(device)
    labs = batch["label"].to(device)
    return model(input_ids, attn, pixels), labs


def _summarize_epoch(losses, preds, labels, stage):
    """Return (mean loss, accuracy, macro-F1); raise if the loader produced no batches."""
    if not losses:
        raise ValueError(f"{stage}: the data loader yielded no batches")
    return (
        sum(losses) / len(losses),
        accuracy_score(labels, preds),
        # zero_division=0 yields the same value sklearn's default ("warn")
        # reports for ill-defined classes, without a warning every epoch.
        f1_score(labels, preds, average="macro", zero_division=0),
    )


def train_epoch(model, loader, criterion, optimizer):
    """Run one optimisation pass over ``loader``.

    Returns:
        (mean loss, accuracy, macro-F1) over the epoch.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.train()
    losses, preds, labels = [], [], []
    for batch in tqdm(loader, desc="Training", leave=False):
        optimizer.zero_grad()
        logits, labs = _run_batch(model, batch)
        loss = criterion(logits, labs)
        loss.backward()
        optimizer.step()

        losses.append(loss.item())
        preds.extend(torch.argmax(logits, 1).cpu().numpy())
        labels.extend(labs.cpu().numpy())

    return _summarize_epoch(losses, preds, labels, "train_epoch")


def eval_epoch(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Returns:
        (mean loss, accuracy, macro-F1).

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    losses, preds, labels = [], [], []
    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating", leave=False):
            logits, labs = _run_batch(model, batch)
            loss = criterion(logits, labs)

            losses.append(loss.item())
            preds.extend(torch.argmax(logits, 1).cpu().numpy())
            labels.extend(labs.cpu().numpy())

    return _summarize_epoch(losses, preds, labels, "eval_epoch")

# ========================
# 5. Main
# ========================
def main(save_path="mnre_blip_re.pth"):
    """Train the BLIP relation classifier on MNRE, evaluate it, and save it.

    Relation ids come from the train split. The model trains for ``EPOCHS``
    epochs with per-epoch validation accuracy / macro-F1, then reports test
    metrics.

    Args:
        save_path: file to write ``{"model_state": ..., "relation2id": ...}``
            to after training. This is the layout the BLIP inference cell's
            ``prepare_re_model`` reads, and that cell's CKPT_PATH defaults to this
            name. ``state_dict()`` includes the frozen BLIP weights, so the file
            is large. Pass ``None`` to skip saving.

    Returns:
        (model, relation2id)
    """
    # The val file may exist as "mnre_val .txt" (stray space) or, after the
    # inference cell's safe_fix_val_path() renames it, as "mnre_val.txt".
    # Use whichever is present.
    val_candidates = [
        '/kaggle/working/MNRE/mnre_txt/mnre_val.txt',
        '/kaggle/working/MNRE/mnre_txt/mnre_val .txt',
    ]
    val_path = next((p for p in val_candidates if os.path.exists(p)), val_candidates[0])

    txt_paths = {
        'train': '/kaggle/working/MNRE/mnre_txt/mnre_train.txt',
        'val':   val_path,
        'test':  '/kaggle/working/MNRE/mnre_txt/mnre_test.txt'
    }
    img_dirs = {
        'train': '/kaggle/working/img_org/train',
        'val':   '/kaggle/working/img_org/val',
        'test':  '/kaggle/working/img_org/test'
    }

    # Build relation mapping (sorted for a deterministic id order; blank lines skipped)
    rels = set()
    with open(txt_paths["train"], "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                rels.add(ast.literal_eval(line)["relation"])
    relation2id = {r: idx for idx, r in enumerate(sorted(rels))}
    num_relations = len(relation2id)

    # Datasets & Loaders
    datasets = {
        split: MNREDataset(txt_paths[split], img_dirs[split], relation2id)
        for split in ("train", "val", "test")
    }
    loaders = {
        split: DataLoader(datasets[split], batch_size=8, shuffle=(split == "train"))
        for split in ("train", "val", "test")
    }

    # Model, Loss, Optimizer
    model = MultimodalREModel(num_relations).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

    history = {"train": [], "val": []}

    EPOCHS = 5
    for epoch in range(1, EPOCHS + 1):
        tr_loss, tr_acc, tr_f1 = train_epoch(model, loaders["train"], criterion, optimizer)
        val_loss, val_acc, val_f1 = eval_epoch(model, loaders["val"], criterion)

        history["train"].append((tr_acc, tr_f1))
        history["val"].append((val_acc, val_f1))

        print(f"Epoch {epoch}/{EPOCHS} | Train Acc: {tr_acc:.4f} F1: {tr_f1:.4f} | Val Acc: {val_acc:.4f} F1: {val_f1:.4f}")

    # Test Evaluation
    test_loss, test_acc, test_f1 = eval_epoch(model, loaders["test"], criterion)
    print(f"✅ Test Accuracy: {test_acc:.4f}, F1: {test_f1:.4f}")

    # Persist the trained model so the inference cell can load it.
    if save_path:
        torch.save({"model_state": model.state_dict(), "relation2id": relation2id}, save_path)
        print(f"💾 Saved RE checkpoint to {save_path}")

    return model, relation2id


if __name__ == "__main__":
    model, relation2id = main()


In [None]:
import os
import ast
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm

from transformers import BlipProcessor, BlipModel
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline

# ----------------------------
# Config
# ----------------------------
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# BLIP backbone (pretrained multimodal)
BLIP_NAME = "Salesforce/blip-image-captioning-base"

# MNRE annotation file per split (adjust if needed). The val name must not have
# a space before ".txt"; safe_fix_val_path() renames a file saved that way.
TXT_PATHS = dict(
    train='/kaggle/working/MNRE/mnre_txt/mnre_train.txt',
    val='/kaggle/working/MNRE/mnre_txt/mnre_val.txt',
    test='/kaggle/working/MNRE/mnre_txt/mnre_test.txt',
)
# Image directory per split.
IMG_DIRS = dict(
    train='/kaggle/working/img_org/train',
    val='/kaggle/working/img_org/val',
    test='/kaggle/working/img_org/test',
)

# Trained RE checkpoint: a dict with "model_state" and optionally "relation2id".
CKPT_PATH = "mnre_blip_re.pth"  # change if your filename differs

# Example images to run prediction on.
IMAGES_TO_PREDICT = [
    "/kaggle/working/img_org/test/twitter_19_31_0_13.jpg",
    "/kaggle/working/img_org/test/twitter_19_31_0_8.jpg",
]

# ----------------------------
# Utilities
# ----------------------------
def safe_fix_val_path(bad="/kaggle/working/MNRE/mnre_txt/mnre_val .txt",
                      good="/kaggle/working/MNRE/mnre_txt/mnre_val.txt"):
    """If the val file was created with an accidental space, rename it once.

    Renames ``bad`` to ``good`` only when ``bad`` exists and ``good`` does not,
    so repeated calls are harmless no-ops.

    Args:
        bad: mis-named path (defaults to the Kaggle MNRE val file with a space).
        good: canonical path to rename to.
    """
    if os.path.exists(bad) and not os.path.exists(good):
        os.rename(bad, good)

def load_all_mnre_entries(txt_paths):
    """Load ALL entries across train/val/test for lookup by img_id.

    Args:
        txt_paths: mapping split name -> annotation file with one Python-literal
            dict per line. Blank lines are skipped.

    Returns:
        List of entry dicts in file order, each tagged with a "_split" key.

    Raises:
        FileNotFoundError: when a split's file is missing (checked just before
            that split is read).
    """
    entries = []
    for split, p in txt_paths.items():
        if not os.path.exists(p):
            raise FileNotFoundError(f"Missing file for split '{split}': {p}")
        with open(p, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue  # ast.literal_eval('') would raise SyntaxError
                obj = ast.literal_eval(line)
                obj["_split"] = split
                entries.append(obj)
    return entries

def build_relation2id_from_train(train_path):
    """Build relation2id by scanning the train file (consistent with training).

    Relations are sorted before numbering so the id order is deterministic.
    Blank lines in the file are skipped.
    """
    rels = set()
    with open(train_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # ast.literal_eval('') would raise SyntaxError
            rels.add(ast.literal_eval(line)["relation"])
    return {r: idx for idx, r in enumerate(sorted(rels))}

def insert_markers(tokens, h_pos, t_pos):
    """Insert [E1] [/E1], [E2] [/E2] into tokens given head/tail spans.

    Spans are [start, end) token indices. ``tokens`` is not modified; the result
    is the marked token sequence joined by single spaces.

    Markers are merged in position order instead of being inserted span by span,
    so nested, same-start or identical spans stay well-formed. At a shared
    position, closing tags come first (innermost span first), then opening tags
    (outermost span first). An empty span (start == end) keeps its opening tag
    before its closing tag. Positions past the end are appended, as list.insert
    would do.
    """
    marks = []  # (position, tie-break key, tag)
    spans = (("[E1]", "[/E1]", h_pos), ("[E2]", "[/E2]", t_pos))
    for rank, (open_tag, close_tag, pos) in enumerate(spans):
        start, end = pos[0], pos[1]
        marks.append((start, (1, -end, rank), open_tag))
        close_phase = 2 if end == start else 0
        marks.append((end, (close_phase, -start, -rank), close_tag))
    marks.sort(key=lambda m: (m[0], m[1]))

    out = []
    i = 0
    for position, _, tag in marks:
        # Emit every token that precedes this marker's position.
        while i < len(tokens) and i < position:
            out.append(tokens[i])
            i += 1
        out.append(tag)
    out.extend(tokens[i:])
    return " ".join(out)

def extract_entities_from_marked_text(marked_text):
    """Return (head_text, tail_text) by parsing [E1]..[/E1] and [E2]..[/E2].

    Each element is the stripped text between the first opening tag and the first
    closing tag that follows it, or None when that pair is absent.
    """
    def between(text, open_tag, close_tag):
        start = text.find(open_tag)
        if start < 0:
            return None
        start += len(open_tag)
        # Look for the closing tag only after the opening tag, so a stray
        # closing tag earlier in the text cannot yield an empty slice.
        end = text.find(close_tag, start)
        if end < 0:
            return None
        return text[start:end].strip()

    h_text = between(marked_text, "[E1]", "[/E1]")
    t_text = between(marked_text, "[E2]", "[/E2]")
    return h_text, t_text

# ----------------------------
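# BLIP-based RE model definition (same layer layout as the training cell's MultimodalREModel; note that cell classifies on image_embeds, while this class defaults to text_embeds)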
# ----------------------------
class BLIPRelationClassifier(nn.Module):
    """
    Relation-classification head on BLIP's projected embeddings.

    The head's input width is read from ``blip.config.projection_dim`` so the
    same code works across BLIP variants.
    """
    def __init__(self, num_relations, blip_model):
        super().__init__()
        self.blip = blip_model
        in_features = blip_model.config.projection_dim
        layers = [
            nn.Linear(in_features, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_relations),
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, input_ids, attention_mask, pixel_values, use_text_embeds=True):
        """
        Return relation logits for a batch.

        use_text_embeds=True classifies BLIP's projected text embedding
        (text_embeds); False classifies the projected image embedding
        (image_embeds). Concatenating both would need a wider head.
        NOTE(review): the BLIP training cell's MultimodalREModel classifies
        image_embeds; make sure the checkpoint and this flag agree.
        """
        blip_out = self.blip(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
        )
        if use_text_embeds:
            features = blip_out.text_embeds
        else:
            features = blip_out.image_embeds
        return self.classifier(features)

# ----------------------------
# Load BLIP + Processor + NER pipeline
# ----------------------------
def load_backbones():
    """Load the BLIP processor and backbone plus a pretrained NER pipeline.

    Returns:
        (processor, blip_model, ner_pipe), with every BLIP parameter frozen
        and all models placed on ``device``.
    """
    blip_processor = BlipProcessor.from_pretrained(BLIP_NAME)
    backbone = BlipModel.from_pretrained(BLIP_NAME).to(device)
    # Inference only: switch off gradients for the whole backbone.
    backbone.requires_grad_(False)

    # NER model (pretrained token classification), grouped into entity spans.
    ner_checkpoint = "dslim/bert-base-NER"
    ner_tokenizer = AutoTokenizer.from_pretrained(ner_checkpoint)
    ner_model = AutoModelForTokenClassification.from_pretrained(ner_checkpoint).to(device)
    pipeline_device = 0 if device.type == "cuda" else -1
    ner_pipe = pipeline(
        "ner",
        model=ner_model,
        tokenizer=ner_tokenizer,
        aggregation_strategy="simple",
        device=pipeline_device,
    )

    return blip_processor, backbone, ner_pipe

# ----------------------------
# Prepare model from checkpoint (or fall back)
# ----------------------------
def prepare_re_model(processor, blip_model, relation2id, ckpt_path):
    """Build the BLIP relation classifier and load trained weights when available.

    Args:
        processor: unused; kept for call compatibility.
        blip_model: BLIP backbone shared with the classifier.
        relation2id: relation -> id map from the train file. It is used when the
            checkpoint is missing or carries no map of its own.
        ckpt_path: checkpoint dict with "model_state" and optionally "relation2id".

    Returns:
        (model, id2relation).
    """
    fallback_id2rel = {v: k for k, v in relation2id.items()}
    if not os.path.exists(ckpt_path):
        model = BLIPRelationClassifier(num_relations=len(relation2id), blip_model=blip_model).to(device)
        print(f"⚠️ RE checkpoint not found at {ckpt_path}. Using fresh classifier weights (accuracy will be poor).")
        return model, fallback_id2rel

    ckpt = torch.load(ckpt_path, map_location=device)
    ckpt_rel2id = ckpt.get("relation2id", {})
    # Size the head from the checkpoint's own label map when present. If the
    # train-file count differed, load_state_dict would raise on the final
    # Linear's shape even with strict=False.
    num_relations = len(ckpt_rel2id) if ckpt_rel2id else len(relation2id)
    model = BLIPRelationClassifier(num_relations=num_relations, blip_model=blip_model).to(device)
    missing = model.load_state_dict(ckpt["model_state"], strict=False)
    if missing.missing_keys or missing.unexpected_keys:
        print("⚠️ State dict mismatch:", missing)
    id2rel_from_ckpt = {v: k for k, v in ckpt_rel2id.items()}
    print(f"✅ Loaded RE checkpoint: {ckpt_path}")
    return model, id2rel_from_ckpt if id2rel_from_ckpt else fallback_id2rel

# ----------------------------
# Build an index: img_id -> list of entries
# ----------------------------
def index_entries_by_img_id(entries):
    """Group MNRE entries by their "img_id", keeping input order within each group."""
    grouped = {}
    for entry in entries:
        key = entry["img_id"]
        if key not in grouped:
            grouped[key] = []
        grouped[key].append(entry)
    return grouped

# ----------------------------
# Single-sample preprocessing for BLIP
# ----------------------------
def build_blip_inputs(processor, image_path, marked_text, max_len=128):
    """Open one image and encode it together with ``marked_text`` for BLIP.

    Returns:
        (RGB PIL image, processor output with every entry moved to ``device``).
    """
    pil_image = Image.open(image_path).convert("RGB")
    encoded = processor(
        images=pil_image,
        text=marked_text,
        padding="max_length",
        truncation=True,
        max_length=max_len,
        return_tensors="pt",
    )
    for key in list(encoded.keys()):
        encoded[key] = encoded[key].to(device)
    return pil_image, encoded

# ----------------------------
# Pretty print + visualize
# ----------------------------
def show_result(image, pred_relation, confidence, sentence, h_text, h_type, t_text, t_type):
    """Show the image titled with the prediction, then print a text summary."""
    plt.imshow(image)
    plt.axis('off')
    plt.title(f"Relation: {pred_relation} ({confidence:.1%})")
    plt.show()

    # Each tuple is printed as one line, its items separated by spaces.
    summary_lines = (
        ("────────────────────────────────────────",),
        ("📝 Sentence:",),
        (sentence,),
        ("\n👤 Head Entity (E1):", h_text, "| NER type:", h_type),
        ("👥 Tail Entity (E2):", t_text, "| NER type:", t_type),
        ("🔗 Predicted Relation:", pred_relation, f"(conf={confidence:.4f})"),
        ("────────────────────────────────────────\n",),
    )
    for fields in summary_lines:
        print(*fields)

# ----------------------------
# NER helper for entity types
# ----------------------------
def ner_type_for_span(ner_pipe, full_text, span_text):
    """
    Run NER on the full sentence and return the majority entity_group among
    entities matching ``span_text``, or 'MISC' when none match.

    Matching has two stages. First, entities whose recognised word contains
    ``span_text``. If there are none, entities whose word appears inside
    ``span_text``. An empty ``span_text`` never matches. Ties go to the label
    seen first.
    """
    from collections import Counter

    if not span_text:
        return "MISC"
    ents = ner_pipe(full_text)
    # Match against the recognised text only, never the label string
    # (entity_group). Otherwise a short span such as "E" would match "PER".
    candidates = [e for e in ents if span_text in e["word"]]
    if not candidates:
        # Looser matching: entity words that appear inside the span text.
        candidates = [e for e in ents if e["word"] and e["word"] in span_text]
    if not candidates:
        return "MISC"
    return Counter(e["entity_group"] for e in candidates).most_common(1)[0][0]

# ----------------------------
# Main: predict on a list of image paths
# ----------------------------
def predict_images(image_paths):
    """Predict relations for every MNRE sentence attached to each given image.

    For each path:
    - If the file is missing, print a message and skip it.
    - If no MNRE entry references its basename, only show the image.
    - Otherwise, for every matching entry: mark the entities, run the RE model
      and the NER pipeline, and display the result.
    """
    # Fix any bad filename first
    safe_fix_val_path()

    # Load entries and index by img_id
    entries = load_all_mnre_entries(TXT_PATHS)
    by_img = index_entries_by_img_id(entries)

    # Build relation2id (from train) for consistent id order
    relation2id_train = build_relation2id_from_train(TXT_PATHS["train"])

    # Backbones
    processor, blip_model, ner_pipe = load_backbones()

    # RE model
    re_model, id2relation = prepare_re_model(processor, blip_model, relation2id_train, CKPT_PATH)
    re_model.eval()

    for ipath in image_paths:
        if not os.path.exists(ipath):
            print(f"❌ Image not found: {ipath}")
            continue
        img_id = os.path.basename(ipath)

        if img_id not in by_img:
            print(f"⚠️ No MNRE entry found for image id: {img_id}")
            # Still show the image
            image = Image.open(ipath).convert("RGB")
            plt.imshow(image); plt.axis('off'); plt.title("Image (no matching text entry)"); plt.show()
            continue

        # MNRE can have multiple sentences per image; iterate all
        for sample in by_img[img_id]:
            tokens = sample["token"]
            h_pos = sample["h"]["pos"]
            t_pos = sample["t"]["pos"]

            # Insert entity markers so model knows E1/E2
            marked_text = insert_markers(tokens, h_pos, t_pos)
            h_text, t_text = extract_entities_from_marked_text(marked_text)

            # Build BLIP inputs
            image, inputs = build_blip_inputs(processor, ipath, marked_text, max_len=128)

            # RE prediction
            # NOTE(review): the BLIP training cell's MultimodalREModel classifies
            # image_embeds. Confirm the checkpoint was trained on text_embeds
            # before relying on use_text_embeds=True here.
            with torch.no_grad():
                logits = re_model(inputs["input_ids"], inputs["attention_mask"], inputs["pixel_values"], use_text_embeds=True)
                probs = torch.softmax(logits, dim=1)
                pred_i = int(torch.argmax(probs, dim=1).item())
                conf = float(probs[0, pred_i].item())
                pred_relation = id2relation.get(pred_i, f"rel_{pred_i}")

            # NER runs on the untagged sentence. Rebuilding it from the original
            # tokens avoids the doubled spaces left by stripping marker strings,
            # and never removes "[E1]"-like text that was genuinely in the tweet.
            clean_sentence = " ".join(tokens)
            h_type = ner_type_for_span(ner_pipe, clean_sentence, h_text or "")
            t_type = ner_type_for_span(ner_pipe, clean_sentence, t_text or "")

            # Display
            show_result(image, pred_relation, conf, marked_text, h_text, h_type, t_text, t_type)

# ----------------------------
# RUN
# ----------------------------
# Run prediction on the configured example images.
predict_images(image_paths=IMAGES_TO_PREDICT)
