## Imports

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel, BertTokenizer

## Model

In [9]:
class LeafEncoder(nn.Module):
    
    def __init__(self, model_name='bert-base-uncased'):
        
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name)
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        for param in self.bert.parameters():
            param.requires_grad = False
            
    def encode_query(self, input_ids):
        return self.bert.embeddings.word_embeddings(input_ids)

    def forward(self, input_ids, attention_mask):
        
        # Add CLS token
        cls_token_id = self.tokenizer.cls_token_id
        cls_tokens = torch.full((input_ids.size(0), 1), cls_token_id, dtype=torch.long, device=input_ids.device)
        input_ids = torch.cat([cls_tokens, input_ids], dim=1)
        cls_mask = torch.ones((attention_mask.size(0), 1), dtype=attention_mask.dtype, device=attention_mask.device)
        attention_mask = torch.cat([cls_mask, attention_mask], dim=1)
        
        # Process sequences
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, output_attentions=True)
        return outputs.last_hidden_state[:, 0, :], outputs.attentions[-1]


class BinaryMerger(nn.Module):
    
    def __init__(self, d_model):
        super().__init__()
        self.merger = nn.Linear(2 * d_model, d_model)

    def forward(self, left, right):
        return self.merger(torch.cat([left, right], dim=-1))


class HierarchicalBinaryTree(nn.Module):
    
    def __init__(
        self,
        d_model : int,
        tree_depth : int = 4,
        k_sampling : int = 1,
        train_seq_len : int = 512
    ):
        super().__init__()
        
        self.d_model = d_model
        self.tree_depth = tree_depth
        self.leaf_encoder = LeafEncoder()
        self.mergers = nn.ModuleList([BinaryMerger(d_model) for _ in range(2 ** tree_depth - 1)])
        self.project_vectors = nn.Parameter(torch.randn(self.tree_depth, d_model))
        self.train_seq_len : int = train_seq_len
        self.k_sampling = k_sampling
        
        self.tree_cache = None
        self.cached_tree_chunk_size = None
        
    @property
    def modeling_capacity(self):
        return self.train_seq_len * (2 ** self.tree_depth)
        
    def build_retrieval_tree(
        self,
        context_ids : torch.Tensor
    ):
        """
        Args:
            context_ids (torch.Tensor): (B, L) input ids of the context. There should be not special tokens like CLS. 
            # answer_positions (torch.Tensor, optional): Allow for sparse building of the tree. Defaults to None. Reduce the building complexity by 2 (B, 1) TODO LATER
        """
        k = 2 ** self.tree_depth
        
        assert (L % k == 0), f"The chosen sequence size is not usable by the model. It should be divisable by {2 ** self.tree_depth}"
        
        if L >  self.train_seq_len:
            self.train_seq_len = L
        
        B, L = context_ids
        
        chunks = context_ids.view(B * k, L // k)
        tree_levels = []
        m : torch.Tensor = self.leaf_encoder(chunks).reshape(B, k, self.d_model) # (B * k, D)
        D = self.d_model
        
        tree_levels.append(m)
        for level in range(self.tree_depth):
            level_l = m.shape[1]
            new_stride = (level_l * D, D, D * level_l // 2, 1)
            new_shape = (B, level_l // 2, 2, D)
            final_shape = (B, level_l // 2, 2 * D)
            m = self.mergers(torch.as_strided(m, new_shape, new_stride, final_shape).reshape(final_shape))
            tree_levels.append(m)
            
        self.tree_cache = tree_levels
        self.cached_tree_chunk_size = k
        
    def forward(
        self,
        query_ids : torch.Tensor,
        answer_positions : torch.Tensor = None
    ):
        """

        Args:
            query (torch.Tensor): (B, L)
            answer_positions (torch.Tensor, optional): (B, L, K) K is has k-top softmax elements. If it's not None, the retreival loss will be calculated
        """
        
        assert (self.tree_cache != None), "No tree to do retreival from?"
        
        # Inference
        if answer_positions == None:
            ...
            # Sample answers randomly at each level for k_sampling per token
        
            B, L = query_ids.shape
            K = self.k_sampling
            D = self.d_model
            device = query_ids.device
        
            # Encode query tokens: (B, L, D)
            with torch.no_grad():
                queries = self.leaf_encoder.encode_query(query_ids)  # (B, L, D)
            q_proj = self.q_proj(queries)  # (B, L, D)
        
            # Initialize: root index 0, repeated K times per token
            current_indices = torch.zeros(B, L, K, dtype=torch.long, device=device)
            selected_path = []
            path_logits = []
        
            for l in range(self.tree_depth):
                proj_q = q_proj.unsqueeze(2).expand(-1, -1, K, -1)  # (B, L, K, D)
        
                left_idx = current_indices * 2
                right_idx = current_indices * 2 + 1  # both (B, L, K)
        
                left_nodes = self.tree_cache[l + 1].gather(
                    1, left_idx.unsqueeze(-1).expand(-1, -1, -1, D))  # (B, L, K, D)
                right_nodes = self.tree_cache[l + 1].gather(
                    1, right_idx.unsqueeze(-1).expand(-1, -1, -1, D))  # (B, L, K, D)
        
                score_left = torch.sum(proj_q * left_nodes, dim=-1)     # (B, L, K)
                score_right = torch.sum(proj_q * right_nodes, dim=-1)   # (B, L, K)
        
                logits = torch.stack([score_left, score_right], dim=-1)  # (B, L, K, 2)
                path_logits.append(logits)
        
                # Expand to 2K candidates per token
                all_scores = torch.cat([score_left, score_right], dim=-1)  # (B, L, 2K)
                all_indices = torch.cat([left_idx, right_idx], dim=-1)     # (B, L, 2K)
        
                topk_scores, topk_idx = torch.topk(all_scores, k=K, dim=-1)       # (B, L, K)
                new_indices = torch.gather(all_indices, dim=-1, index=topk_idx)   # (B, L, K)
        
                current_indices = new_indices
                selected_path.append(current_indices)
                
            # === Final token-level prediction from selected chunks ===
            max_chunk_len = self.cached_tree_chunk_size
            top_token_scores = torch.zeros(B, L, K, max_chunk_len, device=query_ids.device)
            
            for b in range(B):
                for l in range(L):
                    query_token_id = query_ids[b, l].unsqueeze(0).unsqueeze(0)  # (1, 1)
                    for k_id, chunk_idx in enumerate(current_indices[b, l]):  # current_indices: (B, L, K)
                        # Select chunk
                        chunk = self.context_chunks[chunk_idx]  # (1, chunk_len)
            
                        # Build BERT input: [CLS] query [SEP] chunk [SEP]
                        input_ids = torch.cat([
                            torch.tensor([[self.leaf_encoder.tokenizer.cls_token_id]], device=query_ids.device),
                            query_token_id,
                            torch.tensor([[self.leaf_encoder.tokenizer.sep_token_id]], device=query_ids.device),
                            chunk,
                            torch.tensor([[self.leaf_encoder.tokenizer.sep_token_id]], device=query_ids.device)
                        ], dim=1)  # (1, T)
            
                        attention_mask = torch.ones_like(input_ids)
            
                        # Run through BERT
                        with torch.no_grad():
                            outputs = self.leaf_encoder.bert(input_ids=input_ids, attention_mask=attention_mask, output_attentions=True)
            
                        # Use last-layer attention from query token to chunk
                        last_attn = outputs.attentions[-1]  # (layers, heads, 1, T)
                        avg_attn = last_attn.mean(dim=(0, 1))  # (1, T)
                        chunk_attn = avg_attn[0, 3:-1]  # skip CLS/query/SEP
            
                        # Save
                        top_token_scores[b, l, k_id, :chunk_attn.size(0)] = chunk_attn
            
            # Final output: attention over chunk tokens for each query
            # You could take argmax over last dim for final prediction
            final_token_predictions = top_token_scores.argmax(dim=-1)  # (B, L, K)
            # Assuming:
            # - final_token_predictions: (B, L, K), local token index within chunk
            # - current_indices: (B, L, K), chunk indices from tree traversal
            # - self.cached_tree_chunk_size: scalar, number of tokens per chunk
            
            chunk_starts = current_indices * self.cached_tree_chunk_size  # (B, L, K)
            global_token_predictions = chunk_starts + final_token_predictions  # (B, L, K)

            return global_token_predictions
            # return {
            #     "final_leaf_indices": current_indices,  # (B, L, K)
            #     "path_logits": path_logits,             # list of (B, L, K, 2)
            #     "selected_path": selected_path          # list of (B, L, K)
            # }
        
        # Training
        else:
            
            # Infer the actual position of the k elements and from it the paths
            # The chunk position of the answer positions is index // chunk_size. This give the index starting at 0 to n-1
            
            # leaf_indices = (answer_positions // self.cached_tree_chunk_size).clamp(max=2 ** self.tree_depth - 1)
            # path_labels = torch.stack([(leaf_indices >> (self.tree_depth - l - 1)) & 1 for l in range(self.tree_depth)]) # (d, B, L, K)
            
            # # Encode queries (B, L) -> (B, L, D)
            # queries = self.leaf_encoder.encode_query(query_ids)
            
            # # Project each query to each level using hadamar product (B, L, D) -> (B, L, d, D)
            # query_projections = queries.unsqueeze(-2) * self.project_vectors 
            
            # self.tree_cache (d, B, number of summaries at this node, D), path labels (d, B, L, K)
            # traverse each level and select the summaries relevant to the input
            
            # Multiply the query projections and the merged_summaries selected (B, L, d, 1, D) * (B, L, d, K, D).swapaxes(-1, -2) 
            
            B, L = query_ids.shape
            D = self.d_model
            device = query_ids.device

            if answer_positions is None:
                raise NotImplementedError("Sampling for inference is not yet implemented.")

            # === Step 1: Compute chunk indices from answer positions ===
            # answer_positions: (B, L, K)
            leaf_indices = (answer_positions // self.cached_tree_chunk_size).clamp(max=2 ** self.tree_depth - 1)  # (B, L, K)

            # === Step 2: Compute binary path bits per level ===
            # path_labels: (d, B, L, K), where each bit tells you to go left (0) or right (1)
            path_labels = torch.stack([
                (leaf_indices >> (self.tree_depth - l - 1)) & 1
                for l in range(self.tree_depth)
            ], dim=0)  # (d, B, L, K)

            # === Step 3: Compute summary indices per level ===
            # This gives the node index per level in the tree cache that corresponds to each top-K answer
            level_indices = torch.zeros(self.tree_depth + 1, B, L, K := leaf_indices.shape[-1], dtype=torch.long, device=device)
            level_indices[0] = 0  # Root is always index 0

            for l in range(1, self.tree_depth + 1):
                prev = level_indices[l - 1]  # (B, L, K)
                bit = path_labels[l - 1]     # (B, L, K)
                level_indices[l] = prev * 2 + bit  # Compute child index

            # === Step 4: Encode the queries ===
            queries = self.leaf_encoder.encode_query(query_ids)  # (B, L, D)
            queries = queries.unsqueeze(-2) * self.project_vectors.view(self.tree_depth, 1, 1, D)  # (d, B, L, D)

            # === Step 5: Select corresponding summaries from tree_cache using level_indices ===
            # self.tree_cache: list of tensors per level: each (B, N_nodes, D)
            selected_summaries = []
            for l in range(1, self.tree_depth + 1):
                cache = self.tree_cache[l]  # (B, N_nodes, D)
                idx = level_indices[l]      # (B, L, K)

                # Gather summaries at this level
                gathered = torch.gather(
                    cache.unsqueeze(1).expand(-1, L, -1, -1),  # (B, L, N_nodes, D)
                    dim=2,
                    index=idx.unsqueeze(-1).expand(-1, -1, -1, D)  # (B, L, K, D)
                )  # (B, L, K, D)
                selected_summaries.append(gathered)
                
            # === Step 6: Compute logits and loss ===
            logits_per_level = []
            loss = 0.0
            d, B, L, K = path_labels.shape
            
            for l in range(self.tree_depth):
                q_proj = queries[l]                           # (B, L, D)
                summaries = selected_summaries[l]             # (B, L, K, D)
                labels = path_labels[l]                       # (B, L, K)
            
                # Compute dot products: (B, L, K)
                logits = torch.einsum('bld,blkd->blk', q_proj, summaries)
            
                # Reshape for cross-entropy: treat each (B, L) position as a batch element
                logits_flat = logits.view(-1, K)              # (B * L, K)
                labels_flat = labels.view(-1)                 # (B * L,)
            
                loss += F.cross_entropy(logits_flat, labels_flat)
                logits_per_level.append(logits)
            
            return loss


## Load Dataset

In [10]:
from datasets import load_dataset
import torch
from torch.utils.data import Dataset, DataLoader
import random

class HotpotQADataset(Dataset):
    def __init__(self, split="train", tokenizer=None, max_length=512, num_queries=8):
        self.data = load_dataset("hotpot_qa", "fullwiki", split=split)
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.num_queries = num_queries
        self.vocab = list(tokenizer.get_vocab().values()) if tokenizer else list(range(30522))  # fallback to BERT vocab

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        context = item["context"]  # list of (title, paragraph)
        full_text = " ".join([p for _, p in context])[:self.max_length * 4]  # truncate long inputs

        encoding = self.tokenizer(
            full_text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt"
        )

        context_ids = encoding["input_ids"].squeeze(0)  # (L,)
        query_ids = torch.tensor(
            random.choices(self.vocab, k=self.num_queries),
            dtype=torch.long
        )  # (num_queries,)

        return {
            "context_ids": context_ids,
            "query_ids": query_ids
        }

def get_hotpotqa_dataloader(tokenizer, batch_size=8, split="train"):
    dataset = HotpotQADataset(split=split, tokenizer=tokenizer)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

## Init Model

In [11]:
model = HierarchicalBinaryTree(
    d_model=768,
    tree_depth=3,
    k_sampling=1,
    train_seq_len=32
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
bert = BertModel.from_pretrained("bert-base-uncased")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

## Train Model

In [None]:
from torch.utils.data import DataLoader
from tqdm import tqdm
def generate_topk_labels(query_ids, context_ids, k=5):
    """
    For each token in the query, compute the top-k attended context token positions.
    Returns: (B, Lq, K) long tensor of token indices in the context
    """
    B, Lq = query_ids.shape
    _, Lc = context_ids.shape

    input_ids = torch.cat([query_ids, context_ids], dim=1)  # (B, Lq + Lc)
    attention_mask = (input_ids != tokenizer.pad_token_id).long()

    with torch.no_grad():
        outputs = model.leaf_encoder.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=True
        )
        attn = outputs.attentions[-1]  # (layers, heads, B, T, T)
        avg_attn = attn.mean(dim=(0, 1))  # (B, T, T)

        # Only look at attention from query tokens to context tokens
        query_to_context_attn = avg_attn[:, :Lq, Lq:]  # (B, Lq, Lc)

        # Get top-k positions per query token
        topk_indices = torch.topk(query_to_context_attn, k=k, dim=-1).indices  # (B, Lq, K)

    return topk_indices  # (B, Lq, K) — token-level query-dependent top-k


def train_one_epoch(model, dataloader, optimizer, k=5):
    model.train()
    total_loss = 0

    for batch in tqdm(dataloader):
        context_ids = batch["context_ids"]  # (B, L)
        query_ids = batch["query_ids"]      # (B, Lq)
        attention_mask = (context_ids != tokenizer.pad_token_id).long()

        # Build tree from context
        model.build_retrieval_tree(context_ids)

        # Get pseudo-labels using BERT attention
        topk_indices = generate_topk_labels(query_ids, context_ids, attention_mask, k=k)  # (B, k)

        # Broadcast labels to match token-level (Lq) inputs
        topk_labels = topk_indices.unsqueeze(1).expand(-1, query_ids.shape[1], -1)  # (B, Lq, k)

        # Forward pass
        loss = model(query_ids, answer_positions=topk_labels)  # output["loss"], etc.

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        total_loss += loss.item()

    return total_loss / len(dataloader)

dataloader = get_hotpotqa_dataloader(tokenizer, trust_remote_code=True)

README.md:   0%|          | 0.00/9.19k [00:00<?, ?B/s]

hotpot_qa.py:   0%|          | 0.00/6.42k [00:00<?, ?B/s]

ValueError: Loading hotpot_qa requires you to execute the dataset script in that repo on your local machine. Make sure you have read the code there to avoid malicious use, then set the option `trust_remote_code=True` to remove this error.