In [None]:
import json
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
from transformers import AutoTokenizer
from x_transformers import TransformerWrapper, Decoder

In [None]:
# --- I/O paths ---------------------------------------------------------------
FEWREL_JSON = "fewrel_train.json"        # FewRel-style JSON: relation -> instances
MODEL_PATH = "xtransformer_fewrel.pt"    # best-on-validation state_dict goes here

# --- Tokenisation / training hyper-parameters -------------------------------
MAX_LEN = 256       # truncation length in sub-word tokens (also model max_seq_len)
BATCH_SIZE = 16
EPOCHS = 50
LR = 2e-4           # AdamW learning rate
SEED = 42           # torch/numpy seeding and train/val/test split random_state

In [None]:
# Seed torch's global generator (weight init, DataLoader shuffling) first,
# then numpy's legacy global RNG. A loop keeps the cell free of displayed
# output (torch.manual_seed returns a Generator object).
for _seed_fn in (torch.manual_seed, np.random.seed):
    _seed_fn(SEED)

In [None]:
# Use the GPU when torch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

In [None]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# Register the entity markers as special tokens so they are kept whole
# instead of being split into WordPieces.
tokenizer.add_special_tokens(
    {"additional_special_tokens": ["[E1]", "[/E1]", "[E2]", "[/E2]"]}
)

PAD_ID = tokenizer.pad_token_id
# `tokenizer.vocab_size` counts only the base vocabulary. len(tokenizer) also
# counts every added token, so the embedding table covers every id the
# tokenizer can emit. This holds even if some markers already existed, where
# a hard-coded "+4" would be wrong.
VOCAB_SIZE = len(tokenizer)

In [None]:
def _insert_entity_markers(tokens, h_idx, t_idx):
    """Return a copy of ``tokens`` with the head and tail mentions wrapped.

    The head is wrapped in ``[E1] ... [/E1]`` and the tail in ``[E2] ... [/E2]``.

    ``h_idx`` / ``t_idx`` are the token indices of one mention occurrence. The
    first index is the span start and the last index is the inclusive span end,
    so single-token (``[16]``) and multi-token (``[9, 10, 11]``) mentions are
    both handled. Assumes the two spans do not overlap.
    """
    out = list(tokens)
    if h_idx[0] < t_idx[0]:
        out.insert(h_idx[0], "[E1]")
        out.insert(h_idx[-1] + 2, "[/E1]")  # +1 for [E1], +1 to land after the span
        out.insert(t_idx[0] + 2, "[E2]")    # both head markers precede the tail
        out.insert(t_idx[-1] + 4, "[/E2]")  # three markers precede, +1 after span
    else:
        out.insert(t_idx[0], "[E2]")
        out.insert(t_idx[-1] + 2, "[/E2]")
        out.insert(h_idx[0] + 2, "[E1]")
        out.insert(h_idx[-1] + 4, "[/E1]")
    return out


def load_fewrel(path):
    """Load a FewRel-style JSON file and mark the entity mentions inline.

    The JSON maps each relation name to a list of instances. Each instance has
    a ``"tokens"`` list plus ``"h"`` / ``"t"`` entries whose third element is a
    list of mention occurrences. Each occurrence is the list of token indices
    of that mention (e.g. ``[[9, 10, 11]]``). Only the first occurrence of each
    entity is marked.

    Args:
        path: path to the JSON file.

    Returns:
        (texts, labels, rel2id):
            texts: space-joined marked sentences.
            labels: integer relation ids aligned with ``texts``.
            rel2id: relation-name -> id mapping, with ids assigned in
                sorted relation-name order.
    """
    texts, labels = [], []
    rel2id = {}

    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    for rid, rel in enumerate(sorted(data.keys())):
        rel2id[rel] = rid
        for item in data[rel]:
            tokens = _insert_entity_markers(
                item["tokens"], item["h"][2][0], item["t"][2][0]
            )
            texts.append(" ".join(tokens))
            labels.append(rid)

    return texts, labels, rel2id

texts, labels, rel2id = load_fewrel(FEWREL_JSON)
# Reverse lookup, used by the prediction helpers to turn class ids into names.
id2rel = {rid: rel for rel, rid in rel2id.items()}
NUM_CLASSES = len(id2rel)

In [None]:
# Unpadded sub-word id lists; padding happens per batch in collate_fn.
# NOTE(review): truncation at MAX_LEN could drop entity markers in unusually
# long sentences. FewRel sentences are typically short, but confirm if the
# data source changes.
encoded = tokenizer(
    texts, padding=False, truncation=True, max_length=MAX_LEN
)["input_ids"]

# 80 / 10 / 10 split: carve off 20% as a hold-out, then halve it into val and test.
X_train, X_tmp, y_train, y_tmp = train_test_split(
    encoded, labels, random_state=SEED, test_size=0.2
)
X_val, X_test, y_val, y_test = train_test_split(
    X_tmp, y_tmp, random_state=SEED, test_size=0.5
)

In [None]:
class FewRelDataset(Dataset):
    """Map-style dataset yielding (token-id tensor, label tensor) pairs.

    Args:
        X: sequence of token-id lists (variable length, unpadded).
        y: sequence of integer labels aligned with ``X``.
    """

    def __init__(self, X, y):
        self.X = X
        # Labels are converted to a tensor once, then indexed per item.
        self.y = torch.tensor(y)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        token_ids = torch.tensor(self.X[idx])
        label = self.y[idx]
        return token_ids, label

def collate_fn(batch):
    """Right-pad a list of (ids, label) pairs and move them to ``device``.

    Returns ``(x, y)``:
        x: ``(batch, longest_seq)`` token ids, padded with ``PAD_ID``.
        y: ``(batch,)`` stacked label tensor.
    """
    seq_list = [ids for ids, _ in batch]
    label_list = [lab for _, lab in batch]
    padded = pad_sequence(seq_list, batch_first=True, padding_value=PAD_ID)
    targets = torch.stack(label_list)
    return padded.to(device), targets.to(device)

def _make_loader(X, y, shuffle=False):
    """Wrap (X, y) in a FewRelDataset and batch it with the shared collate_fn."""
    return DataLoader(
        FewRelDataset(X, y),
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        collate_fn=collate_fn,
    )


# Only the training split is shuffled; val and test keep a fixed order.
train_loader = _make_loader(X_train, y_train, shuffle=True)
val_loader = _make_loader(X_val, y_val)
test_loader = _make_loader(X_test, y_test)

In [None]:
class XTransformerFewRel(nn.Module):
    """Relation classifier: x-transformers stack -> masked mean pool -> linear head.

    Args:
        dim: model width; also the classifier's input size.
        depth: number of transformer layers.
        heads: attention heads per layer.

    With the default arguments the parameter names and shapes are the same as
    the previous fixed 256/4/4 configuration, so saved state dicts still load.
    """

    def __init__(self, dim=256, depth=4, heads=4):
        super().__init__()
        self.transformer = TransformerWrapper(
            num_tokens=VOCAB_SIZE,
            max_seq_len=MAX_LEN,
            attn_layers=Decoder(dim=dim, depth=depth, heads=heads),
        )
        # Head width is tied to `dim`, so changing the width cannot desync it.
        self.fc = nn.Linear(dim, NUM_CLASSES)

    def forward(self, x):
        """Map token ids ``x`` of shape (batch, seq) to logits of shape (batch, NUM_CLASSES)."""
        keep = x != PAD_ID  # True on real tokens, False on padding
        # Passing the padding mask keeps pad positions out of attention.
        # With the causal Decoder and right padding this does not change the
        # outputs at real positions. It matters if inputs are ever left-padded
        # or the Decoder is swapped for a bidirectional Encoder.
        h = self.transformer(x, mask=keep, return_embeddings=True)
        weights = keep.unsqueeze(-1).to(h.dtype)
        # Mean over real tokens only; clamp avoids 0/0 for an all-pad row.
        pooled = (h * weights).sum(1) / weights.sum(1).clamp(min=1)
        return self.fc(pooled)

In [None]:
model = XTransformerFewRel().to(device)

# AdamW over every model parameter; cross-entropy on the raw (unnormalised) logits.
optimizer = optim.AdamW(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()

In [None]:
# Best validation macro-F1 seen so far. Start below any achievable score so the
# first epoch always writes a checkpoint. Starting at 0.0 with a strict ">"
# means a run that never beats 0.0 saves nothing. The test cell would then
# load a stale MODEL_PATH from an earlier run, or fail on a missing file.
best_f1 = float("-inf")

for epoch in range(EPOCHS):
    # ---- train -------------------------------------------------------------
    model.train()
    total_loss = 0.0

    for Xb, yb in train_loader:
        optimizer.zero_grad()
        logits = model(Xb)
        loss = criterion(logits, yb)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Mean per-batch loss: comparable across dataset sizes, unlike a raw sum.
    mean_loss = total_loss / max(len(train_loader), 1)

    # ---- validate ----------------------------------------------------------
    model.eval()
    yt, yp = [], []

    with torch.no_grad():
        for Xb, yb in val_loader:
            preds = model(Xb).argmax(1)
            yt.extend(yb.cpu().numpy())
            yp.extend(preds.cpu().numpy())

    # zero_division=0 gives classes with no predictions the same score as the
    # sklearn default, but without an UndefinedMetricWarning every epoch.
    val_f1 = f1_score(yt, yp, average="macro", zero_division=0)

    print(f"Epoch {epoch+1:03d} | Loss {mean_loss:.4f} | Val F1 {val_f1:.4f}")

    if val_f1 > best_f1:
        best_f1 = val_f1
        torch.save(model.state_dict(), MODEL_PATH)

In [None]:
# Restore the best-on-validation weights. map_location places the tensors on
# the current `device`, so a checkpoint written on a GPU also loads on a
# CPU-only kernel.
# NOTE(review): torch.load unpickles the file. MODEL_PATH is written by this
# notebook, but on torch >= 1.13 consider weights_only=True as well.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model.eval()

yt, yp = [], []
with torch.no_grad():
    for Xb, yb in test_loader:
        preds = model(Xb).argmax(1)
        yt.extend(yb.cpu().numpy())
        yp.extend(preds.cpu().numpy())

# zero_division=0 matches the default score for never-predicted classes, silently.
print("\nFINAL TEST MACRO F1:", f1_score(yt, yp, average="macro", zero_division=0))

In [None]:
def _encode_sentence(sentence):
    """Tokenize one sentence into a (1, seq) tensor of input ids on `device`."""
    enc = tokenizer(
        sentence,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_LEN
    )
    return enc["input_ids"].to(device)


def _relation_probs(sentence):
    """Return the softmax distribution over the NUM_CLASSES relations for one sentence (1-D)."""
    model.eval()
    x = _encode_sentence(sentence)
    with torch.no_grad():
        return torch.softmax(model(x), dim=-1)[0]


def predict(sentence):
    """
    sentence MUST contain [E1] [/E1] and [E2] [/E2]

    Every training example carries these markers. Without them the input
    differs from what the model was trained on, and predictions are likely
    unreliable (nothing here enforces their presence).

    Returns:
        dict with the most likely ``"relation"`` name and its softmax
        ``"confidence"`` as a Python float.
    """
    probs = _relation_probs(sentence)
    pred_id = probs.argmax().item()

    return {
        "relation": id2rel[pred_id],
        "confidence": float(probs[pred_id])
    }


def predict_topk(sentence, k=5):
    """Return the ``k`` most likely relations for ``sentence``, most likely first.

    The sentence should carry the same [E1]/[/E1] and [E2]/[/E2] markers as the
    training data. ``k`` is capped at the number of classes, because
    ``torch.topk`` raises when k exceeds the input size.

    Returns:
        list of ``{"relation": name, "confidence": float}`` dicts.
    """
    probs = _relation_probs(sentence)
    k = min(k, probs.numel())
    topk = torch.topk(probs, k)

    return [
        {"relation": id2rel[i.item()], "confidence": float(s)}
        for s, i in zip(topk.values, topk.indices)
    ]

# Smoke test of both prediction helpers on one hand-marked sentence.
example = "[E1] Steve Jobs [/E1] was born in [E2] San Francisco [/E2]"
for _header, _predict_fn in (("\nPREDICT:", predict), ("\nTOP-5:", predict_topk)):
    print(_header)
    print(_predict_fn(example))