In [None]:
!pip install gensim x-transformers scikit-learn

In [None]:

import json

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from gensim.models import Word2Vec
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from x_transformers import Decoder, Encoder, TransformerWrapper

In [None]:
# --- Paths ------------------------------------------------------------------
FEWREL_JSON = "fewrel_train.json"            # FewRel training file (relation -> list of instances)
MODEL_PATH = "xtransformer_w2v_fewrel.pt"    # best-on-validation checkpoint

# --- Sizes ------------------------------------------------------------------
MAX_LEN, EMB_DIM = 128, 300                  # max token ids per sentence; Word2Vec / model width

# --- Optimisation -----------------------------------------------------------
BATCH_SIZE, EPOCHS, LR = 16, 40, 2e-4

# --- Reproducibility --------------------------------------------------------
SEED = 42

In [None]:
# Seed the global NumPy and PyTorch RNGs used later in the notebook.
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the GPU when one is visible to PyTorch.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

In [None]:
SPECIAL_TOKENS = ["<pad>", "<unk>", "[E1]", "[/E1]", "[E2]", "[/E2]"]


def _mark_entities(tokens, h_idx, t_idx):
    """Return a copy of `tokens` with [E1]…[/E1] around the head span and [E2]…[/E2] around the tail span.

    `h_idx` / `t_idx` are the token indices of one entity mention. The span runs from the
    smallest to the largest index, so single-token mentions (one index) and mentions longer
    than two tokens are both handled.
    """
    spans = [
        (min(h_idx), max(h_idx), "[E1]", "[/E1]"),
        (min(t_idx), max(t_idx), "[E2]", "[/E2]"),
    ]
    marked = list(tokens)
    # Insert right-to-left so earlier insertions never shift the indices still to be used;
    # within a span, insert the closing tag first for the same reason.
    for start, end, open_tag, close_tag in sorted(spans, key=lambda s: s[0], reverse=True):
        marked.insert(end + 1, close_tag)
        marked.insert(start, open_tag)
    return marked


def load_fewrel(path):
    """Load a FewRel-style JSON file and return entity-marked sentences with integer labels.

    The file maps relation name -> list of instances. Each instance has "tokens" plus "h"/"t"
    entries whose third element is a list of mentions, each a list of token indices; only
    the first mention of each entity is marked.

    Returns (texts, labels, rel2id):
      texts  -- list of token lists with [E1]/[/E1] and [E2]/[/E2] inserted
      labels -- relation id of each text
      rel2id -- relation name -> id, assigned in sorted name order (deterministic)
    The input token lists are not modified.
    """
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    rel2id = {rel: rid for rid, rel in enumerate(sorted(data))}

    texts, labels = [], []
    for rel, rid in rel2id.items():
        for item in data[rel]:
            h_idx = item["h"][2][0]
            t_idx = item["t"][2][0]
            texts.append(_mark_entities(item["tokens"], h_idx, t_idx))
            labels.append(rid)

    return texts, labels, rel2id

# Load every sentence once; the relation-id lookup is inverted for decoding predictions.
texts, labels, rel2id = load_fewrel(FEWREL_JSON)
id2rel = dict(zip(rel2id.values(), rel2id.keys()))
NUM_CLASSES = len(id2rel)

In [None]:
# Train Word2Vec on the entity-marked sentences themselves. The markers ([E1] … [/E2]) already
# occur inside every sentence, so they get vectors from their real contexts; <pad>/<unk> never
# occur and are initialised separately when the embedding matrix is built.
# Special tokens are deliberately NOT prepended to each sentence: that would put the same six
# artificial tokens into the context window of every sentence's opening words.
# NOTE: with workers > 1, gensim training is not bit-for-bit reproducible even with a fixed
# seed (gensim recommends workers=1 plus a fixed PYTHONHASHSEED for that).
# NOTE(review): this uses all sentences, including those later split off for val/test.
sentences = texts

w2v = Word2Vec(
    sentences=sentences,
    vector_size=EMB_DIM,
    window=5,
    min_count=1,
    workers=4,
    seed=SEED
)

In [None]:
# Vocabulary: SPECIAL_TOKENS first (so their ids are fixed and small), then the Word2Vec vocabulary.
word2id = {w: i for i, w in enumerate(SPECIAL_TOKENS)}
for w in w2v.wv.index_to_key:
    if w not in word2id:
        word2id[w] = len(word2id)

PAD_ID = word2id["<pad>"]
UNK_ID = word2id["<unk>"]
VOCAB_SIZE = len(word2id)

# Row i holds the initial vector of token id i. The padding row stays all-zero:
# nn.Embedding.from_pretrained(padding_idx=...) only stops its gradient, it does not zero it.
embedding_matrix = np.zeros((VOCAB_SIZE, EMB_DIM), dtype=np.float32)
for w, i in word2id.items():
    if i == PAD_ID:
        continue
    if w in w2v.wv:
        embedding_matrix[i] = w2v.wv[w]
    else:
        # Tokens with no Word2Vec vector (e.g. <unk> when it never occurs in the corpus).
        embedding_matrix[i] = np.random.normal(scale=0.6, size=(EMB_DIM,))

In [None]:
def encode(tokens):
    """Convert a token sequence into vocabulary ids.

    Out-of-vocabulary tokens map to UNK_ID; the result keeps at most the first MAX_LEN ids.
    """
    token_ids = []
    for token in tokens:
        token_ids.append(word2id.get(token, UNK_ID))
    return token_ids[:MAX_LEN]

encoded = [encode(t) for t in texts]

# 80 / 10 / 10 split, stratified on the relation label so every class keeps the same share in
# each split (macro-F1 weights all classes equally, so per-class counts matter).
X_train, X_tmp, y_train, y_tmp = train_test_split(
    encoded, labels, test_size=0.2, random_state=SEED, stratify=labels
)
X_val, X_test, y_val, y_test = train_test_split(
    X_tmp, y_tmp, test_size=0.5, random_state=SEED, stratify=y_tmp
)

In [None]:
class FewRelDataset(Dataset):
    """Map-style dataset over pre-encoded sentences.

    Item ``idx`` is ``(token_ids, label)``: ``token_ids`` is a 1-D tensor built from ``X[idx]``
    and ``label`` is the matching entry of ``y`` as a 0-d tensor.
    """

    def __init__(self, X, y):
        self.X, self.y = X, torch.tensor(y)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        token_ids = self.X[idx]
        return torch.tensor(token_ids), self.y[idx]

def collate_fn(batch):
    """Right-pad a list of (token_ids, label) pairs into one batch on `device`.

    Returns (x, y): x has shape (batch, longest sequence) padded with PAD_ID, y is the 1-D
    label tensor. Tensors are moved to `device` here, which works because the DataLoaders in
    this notebook use the default num_workers=0 (main process).
    """
    sequences = [seq for seq, _ in batch]
    targets = torch.stack([label for _, label in batch])
    padded = pad_sequence(sequences, batch_first=True, padding_value=PAD_ID)
    return padded.to(device), targets.to(device)

def _make_loader(X, y, shuffle=False):
    """Wrap (X, y) in a FewRelDataset DataLoader using BATCH_SIZE and collate_fn."""
    return DataLoader(
        FewRelDataset(X, y),
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        collate_fn=collate_fn,
    )


train_loader = _make_loader(X_train, y_train, shuffle=True)  # reshuffled every epoch
val_loader = _make_loader(X_val, y_val)
test_loader = _make_loader(X_test, y_test)

In [None]:
class W2V_XTransformer(nn.Module):
    """Relation classifier over token ids.

    Pipeline: Word2Vec-initialised embeddings + learned positions -> bidirectional
    x-transformers Encoder (padding masked) -> LayerNorm -> mean over non-pad tokens -> linear.
    Reads the notebook globals embedding_matrix, PAD_ID, MAX_LEN, EMB_DIM and NUM_CLASSES.
    """

    def __init__(self):
        super().__init__()

        # Trainable embeddings initialised from Word2Vec; the PAD_ID row gets no gradient.
        self.embedding = nn.Embedding.from_pretrained(
            torch.tensor(embedding_matrix, dtype=torch.float32),
            freeze=False,
            padding_idx=PAD_ID
        )
        # TransformerWrapper normally supplies absolute positions; since it is bypassed
        # (see below), a learned positional embedding is added here instead.
        self.pos_embedding = nn.Embedding(MAX_LEN, EMB_DIM)

        # The attention stack is fed the embeddings directly. TransformerWrapper cannot be
        # used for this: its forward() expects integer token ids and applies its own
        # (randomly initialised) token embedding. An Encoder (not a causal Decoder) lets
        # every token attend to both entities regardless of their order.
        self.transformer = Encoder(
            dim=EMB_DIM,
            depth=4,
            heads=4
        )
        self.norm = nn.LayerNorm(EMB_DIM)

        self.fc = nn.Linear(EMB_DIM, NUM_CLASSES)

    def forward(self, x):
        """Map a (batch, seq_len) LongTensor of token ids (seq_len <= MAX_LEN) to (batch, NUM_CLASSES) logits."""
        pad_mask = x != PAD_ID  # True on real tokens; x-transformers attends only to True positions
        positions = torch.arange(x.size(1), device=x.device)
        h = self.embedding(x) + self.pos_embedding(positions)
        h = self.norm(self.transformer(h, mask=pad_mask))

        # Mean-pool over real tokens only; clamp avoids division by zero on an all-pad row.
        weights = pad_mask.unsqueeze(-1).to(h.dtype)
        pooled = (h * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1.0)
        return self.fc(pooled)

In [None]:
def evaluate(model, loader):
    """Return (macro_f1, micro_f1) of `model` over every batch yielded by `loader`.

    Puts the model in eval mode (the training loop calls model.train() again each epoch)
    and runs without gradients. For single-label multiclass data, micro-F1 equals accuracy.
    """
    model.eval()
    model_device = next(model.parameters()).device
    y_true, y_pred = [], []
    with torch.no_grad():
        for Xb, yb in loader:
            logits = model(Xb.to(model_device))
            y_pred.extend(logits.argmax(dim=-1).tolist())
            y_true.extend(yb.tolist())
    macro_f1 = f1_score(y_true, y_pred, average="macro", zero_division=0)
    micro_f1 = f1_score(y_true, y_pred, average="micro", zero_division=0)
    return macro_f1, micro_f1


model = W2V_XTransformer().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=LR)

In [None]:
# -inf (not 0.0) guarantees at least one checkpoint is written even if validation macro-F1
# never rises above 0, so the reload in the next cell always finds MODEL_PATH.
best_f1 = float("-inf")

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0

    for Xb, yb in train_loader:
        optimizer.zero_grad()
        logits = model(Xb)
        loss = criterion(logits, yb)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Report the mean per-batch loss so the number does not scale with the dataset size.
    mean_loss = total_loss / max(len(train_loader), 1)
    val_macro_f1, val_micro_f1 = evaluate(model, val_loader)

    print(
        f"Epoch {epoch+1:03d} | "
        f"Loss {mean_loss:.4f} | "
        f"Val Macro-F1 {val_macro_f1:.4f} | "
        f"Val Micro-F1 {val_micro_f1:.4f}"
    )

    # Keep only the checkpoint with the best validation macro-F1 seen so far.
    if val_macro_f1 > best_f1:
        best_f1 = val_macro_f1
        torch.save(model.state_dict(), MODEL_PATH)

In [None]:
# Restore the checkpoint with the best validation macro-F1, then score the held-out test split.
# map_location loads the weights onto the current device even if they were saved from another.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
test_macro_f1, test_micro_f1 = evaluate(model, test_loader)

print("\nFINAL TEST RESULTS")
print("Macro F1:", test_macro_f1)
print("Micro F1:", test_micro_f1)

In [None]:
def predict(tokens):
    """Classify one pre-tokenised sentence and return its most likely relation.

    `tokens` must already contain the [E1]/[/E1] and [E2]/[/E2] markers, placed the way
    `load_fewrel` inserts them; this function does not add them.

    Returns {"relation": <relation name from id2rel>, "confidence": <softmax probability>}.
    Raises ValueError if `tokens` is empty.
    """
    ids = encode(tokens)
    if not ids:
        # An empty list would otherwise become a float tensor and fail inside the embedding.
        raise ValueError("predict() needs at least one token")

    model.eval()
    batch = torch.tensor([ids], dtype=torch.long, device=device)

    with torch.no_grad():
        probs = torch.softmax(model(batch), dim=-1)[0]
    pid = int(probs.argmax())

    return {
        "relation": id2rel[pid],
        "confidence": float(probs[pid])
    }