# Bilingual Hypergraph QA — Colab cell-by-cell


In [None]:
# Installing dependencies

In [2]:
# %% [markdown]
!pip install -q transformers datasets sentencepiece accelerate

## Imports

In [None]:
import os
import math
import json
import random
from typing import List, Tuple, Dict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from transformers import AutoTokenizer, AutoModel, AutoConfig, AdamW, get_linear_schedule_with_warmup

# Seed Python's `random`, NumPy and PyTorch with one shared constant.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    # Only touch the CUDA generators when a GPU is actually present.
    torch.cuda.manual_seed_all(SEED)

# Pick the GPU when available, otherwise the CPU; later cells move the model
# and the batches to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)

## Utilities

In [None]:
## 3) Utilities: set_seed, normalize, EM/F1 (simple)

# %%

def set_seed(seed=42):
    """Seed Python's `random`, NumPy and PyTorch with the same value.

    Args:
        seed: Integer seed applied to every generator (default 42).

    The CUDA generators are seeded only when CUDA is available.
    """
    # Same order as before: stdlib, NumPy, then the PyTorch CPU generator.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# Re-seed all generators with the notebook-wide SEED constant.
set_seed(SEED)

# simple string normalize for EM/F1 (not exhaustive)
import re

def normalize_answer(s: str) -> str:
    """Normalize an answer string before EM/F1 comparison.

    The text is lower-cased, every character that is neither a letter/digit
    nor whitespace is replaced by a space, and runs of whitespace are
    collapsed to a single space with the ends stripped.

    Args:
        s: Raw answer text.

    Returns:
        The normalized string (possibly empty).
    """
    s = s.lower()
    # `\w` is Unicode-aware for str patterns, so accented letters (e.g. French
    # "é") are kept instead of being erased as with an ASCII-only class.
    # `_` counts as a word character, so it is removed explicitly to keep the
    # behavior for ASCII input unchanged.
    s = re.sub(r"[^\w\s]|_", " ", s)
    s = re.sub(r"\s+", " ", s).strip()
    return s

# simple f1 and exact match for demonstration

def compute_em_f1(pred: str, gold: str) -> Tuple[int, float]:
    """Score a predicted answer against the gold answer.

    Both strings go through `normalize_answer` first.

    Args:
        pred: Predicted answer text.
        gold: Reference answer text.

    Returns:
        `(em, f1)` where `em` is 1 when the normalized strings are equal
        (else 0) and `f1` is the token-level F1 over whitespace tokens.
    """
    norm_pred = normalize_answer(pred)
    norm_gold = normalize_answer(gold)
    exact = int(norm_pred == norm_gold)

    p_toks = norm_pred.split()
    g_toks = norm_gold.split()

    # With an empty side the F1 is all-or-nothing.
    if not p_toks or not g_toks:
        return exact, (1.0 if p_toks == g_toks else 0.0)

    # Multiset overlap: each shared token counts as often as it appears on
    # both sides.
    overlap = 0
    for tok in set(p_toks).intersection(g_toks):
        overlap += min(p_toks.count(tok), g_toks.count(tok))
    if overlap == 0:
        return exact, 0.0

    precision = overlap / len(p_toks)
    recall = overlap / len(g_toks)
    return exact, 2 * precision * recall / (precision + recall)

## Tokenizer

In [None]:
# Checkpoint name shared by the tokenizer and the encoder loaded below.
MODEL_NAME = 'xlm-roberta-base'  # multilingual encoder

# Load tokenizer and encoder for MODEL_NAME; the encoder is moved to `device`.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
base_model = AutoModel.from_pretrained(MODEL_NAME).to(device)
# NOTE(review): a later `.train()` call on a module wrapping `base_model`
# switches it back to training mode, so this eval() only holds until then.
base_model.eval()  # we'll use it as an encoder; optionally fine-tune

print('Tokenizer vocab size:', tokenizer.vocab_size)

## Tiny synthetic dataset

In [None]:
## 5) Example data format and a tiny synthetic dataset

# We'll create a tiny synthetic example showing English + French contexts and a span answer.
# In practice you'd load a bilingual dataset (e.g., translated SQuAD) where each example has English context, French context (aligned), question, and answer span(s).

# %%
# One hand-written bilingual example: an English question, aligned English
# and French contexts, and the answer span in the English context.
examples = [
    {
        'id': 'example-1',
        'question_en': 'What color is the cat?',
        'context_en': 'The cat sat on the mat. The cat is black and white.',
        'context_fr': "Le chat est assis sur le tapis. Le chat est noir et blanc.",
        'answer_text': 'black and white',
        # Character offset of `answer_text` inside `context_en`:
        # context_en[35:50] == 'black and white' (34 is the preceding space).
        'answer_start': 35
    }
]

# A helper that tokenizes the question and both contexts separately (it does not compute token-level answer positions)

def encode_example(example, tokenizer, max_length=384):
    """Tokenize the question and both contexts of one example separately.

    Args:
        example: Dict with the keys 'id', 'question_en', 'context_en',
            'context_fr' and 'answer_text'.
        tokenizer: Callable that accepts a string plus the keyword arguments
            `add_special_tokens`, `truncation` and `max_length`.
        max_length: Maximum number of tokens kept per text. Previously this
            argument was accepted but never forwarded to the tokenizer.

    Returns:
        Dict with 'id', 'question_tokens', 'en_tokens', 'fr_tokens' (the raw
        tokenizer outputs) and 'answer_text'.
    """
    # Shared tokenizer settings: no special tokens, and truncate to
    # max_length so the limit in the signature is actually enforced.
    tok_kwargs = {
        'add_special_tokens': False,
        'truncation': True,
        'max_length': max_length,
    }

    # The three texts are tokenized independently so their token-level
    # mappings stay separate.
    return {
        'id': example['id'],
        'question_tokens': tokenizer(example['question_en'], **tok_kwargs),
        'en_tokens': tokenizer(example['context_en'], **tok_kwargs),
        'fr_tokens': tokenizer(example['context_fr'], **tok_kwargs),
        'answer_text': example['answer_text']
    }

# Tokenize every demo example up front with the shared tokenizer.
encoded = [encode_example(e, tokenizer) for e in examples]

## Hypergraph construction

In [None]:

# We'll construct simple hyperedges:
# - sentence-level hyperedges (all tokens in a sentence)
# - question-centered hyperedge (all tokens in the question)
# - sliding-window phrase hyperedges (n-grams)

# We'll represent hypergraph by incidence matrix H (num_nodes x num_hyperedges)

# %%

def build_simple_hypergraph(num_nodes: int, sentence_boundaries: List[Tuple[int,int]],
                            question_node_indices: List[int], window_size=4) -> torch.Tensor:
    """Build the incidence matrix H (num_nodes x num_hyperedges) as a float tensor.

    Hyperedges are added in this column order:
      1. one per entry of `sentence_boundaries`,
      2. one for the question nodes (only when the list is non-empty),
      3. sliding token windows of `window_size`, stepped by half a window
         (at least 1).

    Args:
        num_nodes: Number of nodes (rows of H).
        sentence_boundaries: (start_idx, end_idx) pairs, both inclusive.
        question_node_indices: Node indices belonging to the question (if any).
        window_size: Width of each sliding-window hyperedge.

    Returns:
        H with H[i, j] == 1.0 when node i belongs to hyperedge j. Member
        indices outside [0, num_nodes) are ignored.
    """
    # 1) sentence hyperedges (inclusive end index)
    edges = [list(range(lo, hi + 1)) for lo, hi in sentence_boundaries]

    # 2) question hyperedge
    if question_node_indices:
        edges.append(list(question_node_indices))

    # 3) overlapping windows over the token positions
    stride = max(1, window_size // 2)
    for left in range(0, num_nodes, stride):
        window = list(range(left, min(left + window_size, num_nodes)))
        if window:
            edges.append(window)

    H = torch.zeros((num_nodes, len(edges)), dtype=torch.float)
    for col, edge in enumerate(edges):
        rows = [r for r in edge if 0 <= r < num_nodes]
        if rows:
            H[rows, col] = 1.0
    return H

# normalize incidence for hypergraph convolution

def hypergraph_normalization(H: torch.Tensor):
    """Return the degree normalizers for an incidence matrix.

    Args:
        H: Incidence matrix of shape N x E.

    Returns:
        `(Dv_inv_sqrt, De_inv)`: an N x N diagonal matrix holding
        node_degree ** -0.5 and an E x E diagonal matrix holding
        1 / hyperedge_degree. Degrees are clamped to at least 1 so that
        isolated nodes and empty hyperedges do not divide by zero.
    """
    # Unpacking doubles as a rank check: anything but a 2-D tensor fails here.
    _num_nodes, _num_edges = H.shape
    node_degree = H.sum(dim=1).clamp(min=1.0)
    edge_degree = H.sum(dim=0).clamp(min=1.0)
    return torch.diag(node_degree.pow(-0.5)), torch.diag(1.0 / edge_degree)

## Hypergraph Conv layer

In [None]:
## 7) Hypergraph convolution layer (PyTorch)

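# Equation (one commonly used form):
# X' = D_v^{-1/2} H D_e^{-1} H^T D_v^{-1/2} X W
# where H is the incidence matrix (N x E), D_v (N x N) and D_e (E x E) are the diagonal node- and hyperedge-degree
# matrices (D_v^{-1/2} and D_e^{-1} are the Dv_inv_sqrt and De_inv returned by hypergraph_normalization above),
# and W is the learnable weight matrix.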

# %%

class HyperGraphConv(nn.Module):
    """Single hypergraph convolution: degree-normalized propagation + linear map.

    Computes `linear(Dv^{-1/2} H De^{-1} H^T Dv^{-1/2} X)` where Dv and De hold
    the node and hyperedge degrees of H, each clamped to at least 1.
    """

    def __init__(self, in_dim, out_dim, bias=True):
        """
        Args:
            in_dim: Feature size of the incoming node features.
            out_dim: Feature size produced by the linear map.
            bias: Whether the linear map has a bias term.
        """
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim, bias=bias)

    def forward(self, X: torch.Tensor, H: torch.Tensor):
        """
        Args:
            X: Node features, N x F.
            H: Incidence matrix, N x E.

        Returns:
            Updated node features, N x out_dim.
        """
        # Match the feature dtype so the matmuls below do not fail on a
        # float32 incidence matrix paired with e.g. half-precision features.
        H = H.to(dtype=X.dtype)

        # Multiplying by a diagonal matrix is a per-row scaling, so the
        # degrees are kept as column vectors and broadcast; this avoids
        # building dense N x N and E x E matrices.
        node_scale = torch.clamp(H.sum(dim=1), min=1.0).pow(-0.5).unsqueeze(-1)  # N x 1
        edge_scale = (1.0 / torch.clamp(H.sum(dim=0), min=1.0)).unsqueeze(-1)    # E x 1

        edge_feats = H.t() @ (node_scale * X)   # E x F: gather nodes into hyperedges
        edge_feats = edge_scale * edge_feats    # E x F: average by hyperedge degree
        out = node_scale * (H @ edge_feats)     # N x F: scatter back to nodes
        return self.linear(out)                 # N x out_dim

# simple block
class HyperGCNBlock(nn.Module):
    """Hypergraph convolution followed by LayerNorm, GELU and dropout."""

    def __init__(self, dim, hidden_dim=None, dropout=0.1):
        """
        Args:
            dim: Input feature size.
            hidden_dim: Output feature size; a falsy value (None or 0) falls
                back to `dim`.
            dropout: Dropout probability applied to the block output.
        """
        super().__init__()
        width = hidden_dim if hidden_dim else dim
        self.conv = HyperGraphConv(dim, width)
        self.norm = nn.LayerNorm(width)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, X, H):
        """Apply conv -> norm -> activation -> dropout to features X with incidence H."""
        normed = self.norm(self.conv(X, H))
        return self.dropout(self.act(normed))

## Cross attention model

In [None]:
## 8) Cross-attention module (English<->French)

# We'll implement a simple cross-attention using PyTorch's MultiheadAttention.
# English tokens attend to French representations and vice versa. Outputs are residual-style fused representations.

# %%

class CrossAttentionFusion(nn.Module):
    """Bidirectional cross-attention between the English and French streams.

    Each stream attends to the other one, the attended output is added back
    residually and layer-normalized, and a pooled vector per example is
    produced from the mean of both fused sequences.
    """

    def __init__(self, dim, n_heads=8, dropout=0.1):
        """
        Args:
            dim: Embedding size of both streams.
            n_heads: Number of attention heads.
            dropout: Dropout probability inside the attention modules.
        """
        super().__init__()
        self.dim = dim
        self.mha_en_to_fr = nn.MultiheadAttention(embed_dim=dim, num_heads=n_heads, dropout=dropout, batch_first=True)
        self.mha_fr_to_en = nn.MultiheadAttention(embed_dim=dim, num_heads=n_heads, dropout=dropout, batch_first=True)
        self.norm_en = nn.LayerNorm(dim)
        self.norm_fr = nn.LayerNorm(dim)
        self.ff = nn.Sequential(nn.Linear(2*dim, dim), nn.GELU(), nn.Linear(dim, dim))

    @staticmethod
    def _masked_mean(seq, pad_mask):
        """Mean over the sequence axis of `seq` (B x N x D), skipping padded positions.

        `pad_mask` follows the key_padding_mask convention used in `forward`
        (non-zero / True marks padding); None means every position counts.
        """
        if pad_mask is None:
            return seq.mean(dim=1)
        keep = (~pad_mask.bool()).unsqueeze(-1).to(seq.dtype)  # B x N x 1
        # Clamp so a fully padded row yields zeros instead of dividing by zero.
        count = keep.sum(dim=1).clamp(min=1.0)                 # B x 1
        return (seq * keep).sum(dim=1) / count

    def forward(self, en_repr, fr_repr, en_mask=None, fr_mask=None):
        """
        Args:
            en_repr: English token representations, B x N_en x D.
            fr_repr: French token representations, B x N_fr x D.
            en_mask: Optional padding mask for the English tokens, passed as
                `key_padding_mask` (True marks padding).
            fr_mask: Optional padding mask for the French tokens, same
                convention.

        Returns:
            `(en_fused, fr_fused, fused)` with shapes B x N_en x D,
            B x N_fr x D and B x D.
        """
        # query=EN, key/value=FR -> EN attends to FR (and the reverse below)
        en_attended, _ = self.mha_en_to_fr(query=en_repr, key=fr_repr, value=fr_repr, key_padding_mask=fr_mask)
        fr_attended, _ = self.mha_fr_to_en(query=fr_repr, key=en_repr, value=en_repr, key_padding_mask=en_mask)

        # residual + norm
        en_fused = self.norm_en(en_repr + en_attended)
        fr_fused = self.norm_fr(fr_repr + fr_attended)

        # Pool each sequence over its real tokens only; a plain mean would
        # let padded positions leak into the pooled vector.
        en_pool = self._masked_mean(en_fused, en_mask)  # B x D
        fr_pool = self._masked_mean(fr_fused, fr_mask)  # B x D
        combined = torch.cat([en_pool, fr_pool], dim=-1)  # B x 2D
        fused = self.ff(combined)  # B x D
        return en_fused, fr_fused, fused

## Model assembly

In [None]:

# Components:
# - Transformer encoder (we'll reuse `base_model` to get contextual token embeddings for each language separately).
# - HyperGCN blocks on each stream.
# - Cross-attention fusion.
# - Prediction head for start/end logits over English tokens (or both languages; for now predict on English tokens but use French info).

# %%

class BilingualHypergraphQA(nn.Module):
    """Span-extraction QA model over parallel English/French inputs.

    Pipeline: shared encoder per language -> per-language hypergraph blocks
    -> cross-attention fusion -> start/end logits over the English tokens.
    """

    def __init__(self, encoder, hidden_dim=768, hyper_layers=2, attn_heads=8, dropout=0.1, freeze_encoder=False):
        """
        Args:
            encoder: Module called with `input_ids`/`attention_mask` whose
                output exposes `last_hidden_state`.
            hidden_dim: Feature size used by every layer after the encoder.
            hyper_layers: Number of hypergraph blocks per language.
            attn_heads: Number of heads in the cross-attention fusion.
            dropout: Dropout probability for the hypergraph and fusion layers.
            freeze_encoder: When True, the encoder parameters stop receiving
                gradients.
        """
        super().__init__()
        self.encoder = encoder
        self.hidden_dim = hidden_dim
        if freeze_encoder:
            for p in self.encoder.parameters():
                p.requires_grad = False

        # hypergraph blocks for English and French
        self.en_hyper_blocks = nn.ModuleList([HyperGCNBlock(hidden_dim, hidden_dim, dropout) for _ in range(hyper_layers)])
        self.fr_hyper_blocks = nn.ModuleList([HyperGCNBlock(hidden_dim, hidden_dim, dropout) for _ in range(hyper_layers)])

        self.cross = CrossAttentionFusion(hidden_dim, n_heads=attn_heads, dropout=dropout)

        # prediction heads: one start score and one end score per English token
        self.pred_start = nn.Linear(hidden_dim, 1)
        self.pred_end = nn.Linear(hidden_dim, 1)

    def encode_text(self, input_ids, attention_mask):
        """Return the encoder's last hidden state for a batch (B x L x D)."""
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=False, return_dict=True)
        return out.last_hidden_state

    @staticmethod
    def _fit_incidence(H, seq_len, device):
        """Move H to `device` and make its node axis exactly `seq_len` rows long.

        Missing rows are filled with zeros (such nodes belong to no
        hyperedge); surplus rows are dropped.
        """
        H = H.to(device)
        rows = H.size(0)
        if rows < seq_len:
            # F.pad lists the last dimension first: (0, 0) leaves the
            # hyperedge axis alone, (0, k) appends k zero rows.
            H = F.pad(H, (0, 0, 0, seq_len - rows))
        elif rows > seq_len:
            H = H[:seq_len]
        return H

    def forward(self, en_input_ids, en_mask, fr_input_ids, fr_mask, H_en, H_fr):
        """
        Args:
            en_input_ids: English token ids, B x L_en.
            en_mask: English attention mask (0 marks padding) or None.
            fr_input_ids: French token ids, B x L_fr.
            fr_mask: French attention mask (0 marks padding) or None.
            H_en: Per-example English incidence matrices, indexable by batch
                position.
            H_fr: Per-example French incidence matrices, same layout.

        Returns:
            `(start_logits, end_logits, pooled)` with shapes B x L_en,
            B x L_en and B x D.
        """
        B = en_input_ids.size(0)

        en_seq = self.encode_text(en_input_ids, en_mask)  # B x L_en x D
        fr_seq = self.encode_text(fr_input_ids, fr_mask)  # B x L_fr x D

        # The hypergraph layers work on one N x F matrix at a time, so each
        # example is processed separately.
        en_outs = []
        fr_outs = []
        for b in range(B):
            Xe = en_seq[b]  # L_en x D
            Xf = fr_seq[b]  # L_fr x D
            # An incidence matrix built for an unpadded example has fewer
            # rows than the padded sequence; align it so the matmuls inside
            # the hypergraph layers line up for every batch element.
            He = self._fit_incidence(H_en[b], Xe.size(0), Xe.device)  # L_en x E_en
            Hf = self._fit_incidence(H_fr[b], Xf.size(0), Xf.device)  # L_fr x E_fr

            for layer in self.en_hyper_blocks:
                Xe = layer(Xe, He)
            for layer in self.fr_hyper_blocks:
                Xf = layer(Xf, Hf)

            en_outs.append(Xe)
            fr_outs.append(Xf)

        # every per-example output has the padded length, so stacking is safe
        en_stack = torch.stack(en_outs, dim=0)
        fr_stack = torch.stack(fr_outs, dim=0)

        # cross-attention expects True at padded positions
        en_pad = en_mask == 0 if en_mask is not None else None
        fr_pad = fr_mask == 0 if fr_mask is not None else None
        en_fused, fr_fused, pooled = self.cross(en_stack, fr_stack, en_pad, fr_pad)

        # predict on English fused tokens
        start_logits = self.pred_start(en_fused).squeeze(-1)  # B x L_en
        end_logits = self.pred_end(en_fused).squeeze(-1)      # B x L_en

        return start_logits, end_logits, pooled

##Dataset & collate for bilingual inputs

In [None]:
# For demo, we create a simple dataset that tokenizes question + context separately for English and French.

# %%
class SimpleBilingualQADataset(Dataset):
    """Demo dataset that tokenizes the question and both contexts separately."""

    def __init__(self, examples, tokenizer, max_len=128):
        """
        Args:
            examples: Sequence of dicts with 'id', 'question_en',
                'context_en', 'context_fr' and optionally 'answer_text'.
            tokenizer: Callable used to tokenize each text.
            max_len: Maximum number of tokens kept per text.
        """
        self.examples = examples
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.examples)

    def _encode(self, text):
        """Tokenize one text with the dataset's own tokenizer, truncated to `max_len`."""
        return self.tokenizer(
            text,
            return_tensors='pt',
            add_special_tokens=True,
            truncation=True,
            max_length=self.max_len,
        )

    def __getitem__(self, idx):
        """Return the raw texts and their tokenizer outputs for example `idx`."""
        ex = self.examples[idx]
        q = ex['question_en']
        c_en = ex['context_en']
        c_fr = ex['context_fr']

        # Use the tokenizer handed to the constructor rather than whatever
        # global happens to be named `tokenizer`.
        return {
            'id': ex['id'],
            'q': q,
            'context_en': c_en,
            'context_fr': c_fr,
            'enc_q': self._encode(q),
            'enc_en': self._encode(c_en),
            'enc_fr': self._encode(c_fr),
            'answer_text': ex.get('answer_text', '')
        }

# collate that returns tensors and builds toy hypergraphs

def _sentence_token_bounds(text, seq_len):
    """Approximate inclusive (start, end) token spans for each sentence of `text`.

    Sentences are split on '.', each one is tokenized on its own with the
    module-level `tokenizer`, and the token counts are accumulated from
    position 1. Each sentence keeps its trailing period so that the period's
    token is counted; without it every later sentence starts one position
    too early.

    Args:
        text: Raw context string.
        seq_len: Length of the tokenized sequence; span ends are capped at
            `seq_len - 1`.

    Returns:
        List of (start, end) tuples, one per non-empty sentence.
    """
    import re  # function-scope import so the helper does not depend on another cell

    bounds = []
    cur = 1  # account for special token possibly at position 0
    for piece in re.findall(r"[^.]+\.?", text):
        sentence = piece.strip()
        if not sentence.rstrip('.').strip():
            continue
        n_tokens = len(tokenizer(sentence, add_special_tokens=False)['input_ids'])
        start = cur
        end = cur + n_tokens - 1
        bounds.append((start, min(seq_len - 1, end)))
        cur = end + 1
    return bounds


def collate_fn(batch):
    """Pad a list of dataset items into batch tensors and build toy hypergraphs.

    Args:
        batch: List of dicts as produced by the dataset's `__getitem__`
            (keys 'enc_en', 'enc_fr', 'context_en', 'context_fr', 'q').

    Returns:
        Dict with the padded 'en_input_ids', 'en_mask', 'fr_input_ids' and
        'fr_mask' tensors (moved to `device`), the per-example incidence
        matrices 'H_en' / 'H_fr' (lists, one entry per example, sized for the
        unpadded length) and the untouched input under 'raw'.
    """
    # the tokenizer outputs carry a leading batch axis of size 1; drop it
    en_ids = [b['enc_en']['input_ids'].squeeze(0) for b in batch]
    en_mask = [b['enc_en']['attention_mask'].squeeze(0) for b in batch]
    fr_ids = [b['enc_fr']['input_ids'].squeeze(0) for b in batch]
    fr_mask = [b['enc_fr']['attention_mask'].squeeze(0) for b in batch]

    # pad sequences
    en_ids_padded = torch.nn.utils.rnn.pad_sequence(en_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
    en_mask_padded = torch.nn.utils.rnn.pad_sequence(en_mask, batch_first=True, padding_value=0)
    fr_ids_padded = torch.nn.utils.rnn.pad_sequence(fr_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
    fr_mask_padded = torch.nn.utils.rnn.pad_sequence(fr_mask, batch_first=True, padding_value=0)

    # build H_en and H_fr per sample
    H_en_list = []
    H_fr_list = []
    for i, b in enumerate(batch):
        L_en = en_ids[i].size(0)
        L_fr = fr_ids[i].size(0)

        # English: sentence hyperedges plus a question hyperedge
        s_bounds = _sentence_token_bounds(b['context_en'], L_en)
        q_tokens = tokenizer(b['q'], add_special_tokens=False)
        # NOTE(review): `enc_en` is built from the context alone (see the
        # dataset's __getitem__ in this file), so these indices address the
        # first context tokens rather than question tokens. Confirm whether
        # a joint question/context encoding was intended.
        q_indices = list(range(1, 1 + len(q_tokens['input_ids']))) if len(q_tokens['input_ids'])>0 else []
        H_en = build_simple_hypergraph(L_en, s_bounds, q_indices, window_size=4)

        # French: sentence hyperedges only
        s_bounds = _sentence_token_bounds(b['context_fr'], L_fr)
        H_fr = build_simple_hypergraph(L_fr, s_bounds, [], window_size=4)

        H_en_list.append(H_en)
        H_fr_list.append(H_fr)

    return {
        'en_input_ids': en_ids_padded.to(device),
        'en_mask': en_mask_padded.to(device),
        'fr_input_ids': fr_ids_padded.to(device),
        'fr_mask': fr_mask_padded.to(device),
        'H_en': H_en_list,
        'H_fr': H_fr_list,
        'raw': batch
    }


##Instantiate model, optimizer, and a short training loop (demo)

In [None]:
# %%
# Wrap the pretrained encoder in the QA model; the encoder is frozen, so only
# the hypergraph, fusion and prediction layers are trained.
model = BilingualHypergraphQA(encoder=base_model, hidden_dim=base_model.config.hidden_size, hyper_layers=2, attn_heads=8, freeze_encoder=True).to(device)
# Use PyTorch's AdamW: the implementation shipped with transformers is
# deprecated. eps and weight_decay are passed explicitly to keep the values
# the transformers version used by default (TODO confirm against the
# installed transformers release).
optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=3e-4, eps=1e-6, weight_decay=0.0)

# small demo dataset
train_dataset = SimpleBilingualQADataset([examples[0]], tokenizer)
train_loader = DataLoader(train_dataset, batch_size=1, collate_fn=collate_fn)


##Training step (very small demo — not meaningful metrics)

In [None]:
# %%
model.train()
# model.train() also switches the wrapped encoder to training mode. When all
# of its weights are frozen, put it back in eval mode so its dropout stays off.
if not any(p.requires_grad for p in model.encoder.parameters()):
    model.encoder.eval()

# The loss module holds no state, so build it once instead of once per batch.
loss_fct = nn.BCEWithLogitsLoss()

for batch in train_loader:
    en_ids = batch['en_input_ids']
    en_mask = batch['en_mask']
    fr_ids = batch['fr_input_ids']
    fr_mask = batch['fr_mask']
    H_en = batch['H_en']
    H_fr = batch['H_fr']

    start_logits, end_logits, pooled = model(en_ids, en_mask, fr_ids, fr_mask, H_en, H_fr)

    # For demo, create toy target: find answer token positions by searching string in decoded text (approx).
    # Only the first example of the batch gets a target (the loader uses batch_size=1).
    raw = batch['raw'][0]
    answer = raw['answer_text']
    decoded_en = tokenizer.decode(en_ids[0], skip_special_tokens=True)
    # naive index search
    if answer in decoded_en:
        char_idx = decoded_en.index(answer)
        # approximate token-level mapping by tokenizing prefix
        prefix = decoded_en[:char_idx]
        prefix_toks = tokenizer(prefix, add_special_tokens=False)['input_ids']
        start_pos = len(prefix_toks) + 1  # +1 skips the leading special token
        answer_toks = tokenizer(answer, add_special_tokens=False)['input_ids']
        end_pos = start_pos + len(answer_toks) - 1
    else:
        start_pos, end_pos = 1, 1

    # build label tensors
    B, L = start_logits.shape
    # The positions above are approximations; clamp them so they can never
    # index past the sequence and so the end never precedes the start.
    start_pos = max(0, min(start_pos, L - 1))
    end_pos = max(start_pos, min(end_pos, L - 1))
    start_labels = torch.zeros((B, L), device=device)
    end_labels = torch.zeros((B, L), device=device)
    start_labels[0, start_pos] = 1.0
    end_labels[0, end_pos] = 1.0

    loss = loss_fct(start_logits, start_labels) + loss_fct(end_logits, end_labels)

    # Clear stale gradients before the backward pass, then update.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print('Demo loss:', loss.item())


##Inference helper to extract predicted span

In [None]:
def predict_span(start_logits: torch.Tensor, end_logits: torch.Tensor, input_ids: torch.Tensor):
    """Decode the best-scoring answer span for the first example of a batch.

    The span maximizes `start_logit[i] + end_logit[j]` subject to `j >= i`.
    When the independent argmax of both heads already satisfies `end >= start`
    this is the same span as picking each argmax separately; otherwise a
    valid pair is chosen instead of collapsing the span to a single token.

    Args:
        start_logits: Start scores, B x L (or a single row of length L).
        end_logits: End scores, same shape as `start_logits`.
        input_ids: Token ids, B x L; only row 0 is decoded.

    Returns:
        `(text, start_idx, end_idx)`: the decoded span and its inclusive
        token indices. Returns `('', 0, 0)` for an empty sequence.
    """
    if input_ids.size(1) == 0 or start_logits.size(-1) == 0 or end_logits.size(-1) == 0:
        return '', 0, 0

    # Work on the first example only, matching the input_ids[0] slice below.
    start_row = start_logits[0] if start_logits.dim() > 1 else start_logits
    end_row = end_logits[0] if end_logits.dim() > 1 else end_logits

    # Never consider positions beyond the tokens that can be decoded.
    n = min(start_row.size(0), end_row.size(0), input_ids.size(1))
    start_row = start_row[:n]
    end_row = end_row[:n]

    # pair_scores[i, j] scores the span i..j; entries below the diagonal
    # (end before start) are excluded.
    pair_scores = start_row.unsqueeze(1) + end_row.unsqueeze(0)  # n x n
    invalid = torch.ones(n, n, dtype=torch.bool, device=pair_scores.device).tril(diagonal=-1)
    pair_scores = pair_scores.masked_fill(invalid, float('-inf'))

    best = torch.argmax(pair_scores).item()
    start_idx, end_idx = divmod(best, n)

    tokens = input_ids[0, start_idx:(end_idx+1)]
    text = tokenizer.decode(tokens, skip_special_tokens=True)
    return text, start_idx, end_idx

##Run the demo inference

In [None]:
# %%
# Switch to eval mode and disable gradient tracking for the demo prediction.
model.eval()
with torch.no_grad():
    # The demo re-uses the training loader, so this only checks that the
    # pipeline runs end to end on the example the model was trained on.
    for batch in train_loader:
        s_logits, e_logits, pooled = model(batch['en_input_ids'], batch['en_mask'], batch['fr_input_ids'], batch['fr_mask'], batch['H_en'], batch['H_fr'])
        # Decode the span for the first example of the batch.
        pred_text, s_idx, e_idx = predict_span(s_logits, e_logits, batch['en_input_ids'])
        print('Predicted span:', pred_text)
        print('Ground truth:', batch['raw'][0]['answer_text'])
        # Exact-match and token-level F1 against the gold answer text.
        em, f1 = compute_em_f1(pred_text, batch['raw'][0]['answer_text'])
        print('EM:', em, 'F1:', f1)