In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertTokenizer, BertModel
from torch.utils.data import Dataset, DataLoader
import random
import math

## Convolutional Binary Tree Attention

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertTokenizer, BertModel
from torch.utils.data import Dataset, DataLoader
import random
import math
import matplotlib.pyplot as plt
import os

class SoftmaxTeacher(nn.Module):
    def __init__(self, hidden_size=768):
        super().__init__()
        self.linear_q = nn.Linear(hidden_size, hidden_size)
        self.linear_k = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        Q = self.linear_q(x)  # (B, L, D)
        K = self.linear_k(x)  # (B, L, D)
        attn_logits = torch.matmul(Q, K.transpose(-1, -2)) / (x.size(-1) ** 0.5)  # (B, L, L)
        attn_scores = torch.softmax(attn_logits, dim=-1)
        return attn_scores

class RandomTokenDataset(Dataset):
    def __init__(self, tokenizer, num_samples=1000, seq_len=512):
        self.tokenizer = tokenizer
        self.vocab = list(tokenizer.get_vocab().values())
        self.num_samples = num_samples
        self.seq_len = seq_len

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        input_ids = torch.tensor(
            random.choices(self.vocab, k=self.seq_len), dtype=torch.long
        )
        attention_mask = torch.ones_like(input_ids)
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }

class BinaryTreeAttention(nn.Module):
    def __init__(self, hidden_size=768, chunk_size=64):
        super().__init__()
        self.hidden_size = hidden_size
        self.chunk_size = chunk_size
        self.depth = int(math.log2(512 // chunk_size))

        # self.convs = nn.ModuleList()
        # max_chunks = 512 // self.chunk_size
        # for i in range(self.depth):
        #     k = min(2 ** (i + 1), max_chunks)
        #     self.convs.append(
        #         nn.Conv1d(
        #             in_channels=hidden_size,
        #             out_channels=hidden_size,
        #             kernel_size=k,
        #             stride=k,
        #             padding=0
        #         )
        #     )
        #     max_chunks = max_chunks // 2
        self.convs = nn.ModuleList([
            nn.Conv1d(
                in_channels=hidden_size,
                out_channels=hidden_size,
                kernel_size=2 ** (i + 1),
                stride=2 ** (i + 1),
                padding=0
            )
            for i in range(self.depth)
        ])


        self.query_vectors = nn.ParameterList([
            nn.Parameter(torch.randn(hidden_size)) for _ in range(self.depth)
        ])

        os.makedirs("visuals", exist_ok=True)
       
    # def forward(self, x, query, label_positions=None):
    #     B, L, D = x.shape
    #     assert L % self.chunk_size == 0, "Sequence length must be divisible by chunk size"
    #     num_chunks = L // self.chunk_size

    #     x_chunks = x.view(B, num_chunks, self.chunk_size, D)
    #     summaries = x_chunks.mean(dim=2)  # (B, num_chunks, D)
    #     summaries = summaries.transpose(1, 2)  # (B, D, num_chunks)

    #     tree_levels = [summaries.transpose(1, 2)]
    #     for level in range(self.depth):
    #         print(summaries.shape[-1], self.convs[level].kernel_size[0])
    #         if summaries.shape[-1] < self.convs[level].kernel_size[0]:
    #             break
    #         # summaries = self.convs[level](summaries)  # (B, D, reduced_chunks)
    #         tree_levels.append(self.convs[level](summaries).transpose(1, 2))  # (B, chunks, D)

    #     for level in range(len(tree_levels) - 1):
    #         q_proj = query[0] * self.query_vectors[level]  # visualize batch 0 only
    #         q_proj = q_proj.detach().cpu().numpy()
    #         plt.figure()
    #         plt.plot(q_proj)
    #         plt.title(f"Query projection at level {level}")
    #         plt.xlabel("Dimension")
    #         plt.ylabel("Value")
    #         plt.savefig(f"visuals/q_proj_level_{level}.png")
    #         plt.close()

    #     attn_weights = torch.zeros(B, L, device=x.device)
    #     per_level_loss = 0.0

    #     for b in range(B):
    #         q_i = query[b]  # (D,)
    #         idx_range = list(range(num_chunks))
    #         current_level = 0

    #         if label_positions is not None:
    #             target_chunk = label_positions[b].item() // self.chunk_size

    #         while len(idx_range) > 1 and current_level < len(tree_levels) - 1:
    #             next_range = []
    #             losses = []
    #             max_index = tree_levels[current_level].shape[1]
    #             for j in range(0, len(idx_range), 2):
    #                 left_idx = idx_range[j]
    #                 if j + 1 >= len(idx_range):
    #                     next_range.append(left_idx)
    #                     continue
    #                 right_idx = idx_range[j + 1]
    #                 if left_idx >= max_index or right_idx >= max_index:
    #                     continue
    #                 left_sum = tree_levels[current_level][b, left_idx]
    #                 right_sum = tree_levels[current_level][b, right_idx]
    #                 q_proj = q_i * self.query_vectors[current_level]
    #                 score_left = torch.dot(q_proj, left_sum)
    #                 score_right = torch.dot(q_proj, right_sum)
    #                 logit = torch.stack([score_left, score_right])

    #                 if label_positions is not None:
    #                     decision = 0 if target_chunk % 2 == 0 else 1
    #                     loss = F.cross_entropy(logit.unsqueeze(0), torch.tensor([decision], device=x.device))
    #                     target_chunk = target_chunk // 2
    #                 else:
    #                     loss = torch.tensor(0.0, device=x.device)
    #                     decision = torch.argmax(logit).item()

    #                 losses.append(loss)
    #                 chosen = left_idx if decision == 0 else right_idx
    #                 next_range.append(chosen)

    #             idx_range = next_range
    #             current_level += 1
    #             if losses:
    #                 per_level_loss += sum(losses) / len(losses)

    #         final_chunk = idx_range[0]
    #         start = final_chunk * self.chunk_size
    #         end = start + self.chunk_size
    #         local_scores = torch.matmul(q_i, x[b, start:end].T)
    #         local_attn = F.softmax(local_scores, dim=-1)
    #         attn_weights[b, start:end] = local_attn

    #     return attn_weights, per_level_loss / B

    def build_summaries(self, x):
        B, L, D = x.shape
        assert L % self.chunk_size == 0, "Sequence length must be divisible by chunk size"
        num_chunks = L // self.chunk_size

        x_chunks = x.view(B, num_chunks, self.chunk_size, D)
        leaf_summaries = x_chunks.mean(dim=2)  # (B, num_chunks, D)

        all_levels = [leaf_summaries]
        # for conv in self.convs:
        #     input_summary = leaf_summaries.transpose(1, 2)  # (B, D, num_chunks)
        #     print(input_summary.size(-1), conv.kernel_size[0])
        #     if input_summary.size(-1) < conv.kernel_size[0]:
        #         break
        #     merged = conv(input_summary)  # (B, D, reduced_chunks)
        #     all_levels.append(merged.transpose(1, 2))  # back to (B, chunks, D)
        for conv in self.convs:
            if leaf_summaries.size(1) < conv.kernel_size[0]:
                break
            input_summary = leaf_summaries.transpose(1, 2)  # (B, D, chunks)
            merged = conv(input_summary)  # (B, D, reduced_chunks)
            all_levels.append(merged.transpose(1, 2))  # (B, reduced_chunks, D)

        # for i, level in enumerate(all_levels):
        #     print(f"Level {i}: {level.shape}")
            
        return all_levels

    def forward(self, x, query, label_positions=None):
        B, L, D = x.shape
        tree_levels = self.build_summaries(x)  # List of (B, chunks, D)
        attn_weights = torch.zeros(B, L, device=x.device)
        per_level_loss = 0.0

        for level in range(self.depth):
            q_proj = query[0] * self.query_vectors[level]
            q_proj = q_proj.detach().cpu().numpy()
            plt.figure()
            plt.plot(q_proj)
            plt.title(f"Query projection at level {level}")
            plt.xlabel("Dimension")
            plt.ylabel("Value")
            plt.savefig(f"visuals/q_proj_level_{level}.png")
            plt.close()

        for b in range(B):
            q_i = query[b]  # (D,)

            if label_positions is not None:
                target_token = label_positions[b].item()
                # print(target_token)
                target_chunk = target_token // self.chunk_size
                # print(f'{target_chunk}/{512 // self.chunk_size}')
                path = []
                node = target_chunk
                for level in reversed(range(self.depth)):
                    path.append(node)
                    node = node // 2
                path = list(reversed(path))

                for level, chunk_idx in enumerate(path):
                    # if level >= len(tree_levels):
                    #     break

                    summaries = tree_levels[self.depth - level - 1][b]  # (chunks, D)
                    # if chunk_idx * 2 + 1 >= summaries.size(0):
                    #     continue
                    # print(len(summaries), chunk_idx)
                    left = summaries[chunk_idx // 2 * 2]
                    right = summaries[chunk_idx // 2 * 2 + 1]

                    q_proj = q_i * self.query_vectors[level]
                    score_left = torch.dot(q_proj, left)
                    score_right = torch.dot(q_proj, right)
                    logit = torch.stack([score_left, score_right])
                    decision = 0 if (chunk_idx % 2 == 0) else 1
                    loss = F.cross_entropy(logit.unsqueeze(0), torch.tensor([decision], device=x.device))
                    per_level_loss += loss

                final_chunk = path[-1]
                start = final_chunk * self.chunk_size
                end = start + self.chunk_size
                if end <= L:
                    local_scores = torch.matmul(q_i, x[b, start:end].T)
                    local_attn = F.softmax(local_scores, dim=-1)
                    attn_weights[b, start:end] = local_attn

        return attn_weights, per_level_loss / B


class TreeAttentionModel(nn.Module):
    def __init__(self, encoder, tree_model):
        super().__init__()
        self.encoder = encoder
        self.tree_model = tree_model

    def forward(self, input_ids, attention_mask, label_pos, teacher_attn):
        with torch.no_grad():
            x = self.encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        query = x[:, 0, :]  # (B, D), assuming CLS token is used as the query
        tree_attn, gate_loss = self.tree_model(x, query, label_pos)
        attn_loss = F.kl_div((tree_attn + 1e-8).log(), teacher_attn[:, 0], reduction='batchmean')
        return attn_loss + gate_loss

if __name__ == '__main__':
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    bert = BertModel.from_pretrained("bert-base-uncased")

    teacher = SoftmaxTeacher(hidden_size=768)
    dataset = RandomTokenDataset(tokenizer=tokenizer, num_samples=1000, seq_len=512)
    dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

    tree_model = BinaryTreeAttention(hidden_size=768, chunk_size=64)
    model = TreeAttentionModel(bert, tree_model)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    for epoch in range(3):
        for batch in dataloader:
            input_ids = batch["input_ids"]
            attention_mask = batch["attention_mask"]

            with torch.no_grad():
                x = bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
                teacher_attn = teacher(x)
                topk = torch.topk(teacher_attn[:, 0], k=1, dim=-1).indices.squeeze(-1)

            loss = model(input_ids, attention_mask, topk, teacher_attn)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            print("Loss:", loss.item())


Loss: 14.052316665649414
Loss: 16.453035354614258
Loss: 15.346094131469727
Loss: 13.340925216674805
Loss: 13.54625129699707
Loss: 12.48184585571289
Loss: 15.070412635803223
Loss: 11.933829307556152
Loss: 15.686492919921875
Loss: 14.974651336669922
Loss: 13.273134231567383
Loss: 12.10181713104248
Loss: 12.965959548950195
Loss: 16.144683837890625
Loss: 14.920731544494629
Loss: 16.661026000976562
Loss: 11.881268501281738
Loss: 15.943611145019531
Loss: 15.138113975524902
Loss: 11.625199317932129
Loss: 14.276535034179688
Loss: 13.001980781555176
Loss: 12.072905540466309
Loss: 19.355669021606445
Loss: 12.75156021118164
Loss: 13.264352798461914
Loss: 15.138164520263672
Loss: 13.575176239013672
Loss: 13.412993431091309
Loss: 12.051664352416992
Loss: 16.39854621887207
Loss: 15.619536399841309
Loss: 13.621010780334473
Loss: 14.802406311035156
Loss: 13.913017272949219
Loss: 12.637627601623535
Loss: 13.242046356201172
Loss: 15.567933082580566
Loss: 13.310282707214355
Loss: 13.25949478149414
Loss: 

KeyboardInterrupt: 

## Binary Mergers Tree Attention

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertTokenizer, BertModel
from torch.utils.data import Dataset, DataLoader
import random
import math
import matplotlib.pyplot as plt
import os

class SoftmaxTeacher(nn.Module):
    def __init__(self, hidden_size=768):
        super().__init__()
        self.linear_q = nn.Linear(hidden_size, hidden_size)
        self.linear_k = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        Q = self.linear_q(x)
        K = self.linear_k(x)
        attn_logits = torch.matmul(Q, K.transpose(-1, -2)) / (x.size(-1) ** 0.5)
        attn_scores = torch.softmax(attn_logits, dim=-1)
        return attn_scores

class RandomTokenDataset(Dataset):
    def __init__(self, tokenizer, num_samples=1000, seq_len=512):
        self.tokenizer = tokenizer
        self.vocab = list(tokenizer.get_vocab().values())
        self.num_samples = num_samples
        self.seq_len = seq_len

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        input_ids = torch.tensor(
            random.choices(self.vocab, k=self.seq_len), dtype=torch.long
        )
        attention_mask = torch.ones_like(input_ids)
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }

class BinaryTreeAttention(nn.Module):
    def __init__(self, hidden_size=768, chunk_size=64):
        super().__init__()
        self.hidden_size = hidden_size
        self.chunk_size = chunk_size
        self.depth = int(math.log2(512 // chunk_size))

        self.merge_layers = nn.ModuleList([
            nn.Linear(2 * hidden_size, hidden_size) for _ in range(self.depth)
        ])

        self.query_vectors = nn.ParameterList([
            nn.Parameter(torch.randn(hidden_size)) for _ in range(self.depth)
        ])

        os.makedirs("visuals", exist_ok=True)

    def forward(self, x, query, label_positions=None):
        B, L, D = x.shape
        assert L % self.chunk_size == 0, "Sequence length must be divisible by chunk size"
        num_chunks = L // self.chunk_size

        x_chunks = x.view(B, num_chunks, self.chunk_size, D)
        summaries = x_chunks.mean(dim=2)  # (B, num_chunks, D)

        attn_weights = torch.zeros(B, L, device=x.device)
        per_level_loss = 0.0

        for level in range(self.depth):
            q_proj = query[0] * self.query_vectors[level]
            q_proj = q_proj.detach().cpu().numpy()
            plt.figure()
            plt.plot(q_proj)
            plt.title(f"Query projection at level {level}")
            plt.xlabel("Dimension")
            plt.ylabel("Value")
            plt.savefig(f"visuals/q_proj_level_{level}.png")
            plt.close()

        for b in range(B):
            q_i = query[b]  # (D,)
            idx_range = list(range(num_chunks))
            current_level = 0

            if label_positions is not None:
                target_chunk = label_positions[b].item() // self.chunk_size

            while len(idx_range) > 1 and current_level < self.depth:
                next_range = []
                losses = []
                for j in range(0, len(idx_range), 2):
                    left_idx = idx_range[j]
                    if j + 1 >= len(idx_range):
                        next_range.append(left_idx)
                        continue
                    right_idx = idx_range[j + 1]
                    left = summaries[b, left_idx]
                    right = summaries[b, right_idx]
                    merged = self.merge_layers[current_level](torch.cat([left, right], dim=-1))
                    q_proj = q_i * self.query_vectors[current_level]
                    score_left = torch.dot(q_proj, left)
                    score_right = torch.dot(q_proj, right)
                    logit = torch.stack([score_left, score_right])

                    if label_positions is not None:
                        decision = 0 if target_chunk % 2 == 0 else 1
                        loss = F.cross_entropy(logit.unsqueeze(0), torch.tensor([decision], device=x.device))
                        target_chunk = target_chunk // 2
                    else:
                        loss = torch.tensor(0.0, device=x.device)
                        decision = torch.argmax(logit).item()

                    losses.append(loss)
                    chosen = left_idx if decision == 0 else right_idx
                    next_range.append(chosen)
                idx_range = next_range
                current_level += 1
                if losses:
                    per_level_loss += sum(losses) / len(losses)

            final_chunk = idx_range[0]
            start = final_chunk * self.chunk_size
            end = start + self.chunk_size
            local_scores = torch.matmul(q_i, x[b, start:end].T)
            local_attn = F.softmax(local_scores, dim=-1)
            attn_weights[b, start:end] = local_attn

        return attn_weights, per_level_loss / B

class TreeAttentionModel(nn.Module):
    def __init__(self, encoder, tree_model):
        super().__init__()
        self.encoder = encoder
        self.tree_model = tree_model

    def forward(self, input_ids, attention_mask, label_pos, teacher_attn):
        with torch.no_grad():
            x = self.encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        query = x[:, 0, :]
        tree_attn, gate_loss = self.tree_model(x, query, label_pos)
        attn_loss = F.kl_div((tree_attn + 1e-8).log(), teacher_attn[:, 0], reduction='batchmean')
        return attn_loss + gate_loss

if __name__ == '__main__':
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    bert = BertModel.from_pretrained("bert-base-uncased")

    teacher = SoftmaxTeacher(hidden_size=768)
    dataset = RandomTokenDataset(tokenizer=tokenizer, num_samples=1000, seq_len=512)
    dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

    tree_model = BinaryTreeAttention(hidden_size=768, chunk_size=64)
    model = TreeAttentionModel(bert, tree_model)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    for epoch in range(3):
        for batch in dataloader:
            input_ids = batch["input_ids"]
            attention_mask = batch["attention_mask"]

            with torch.no_grad():
                x = bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
                teacher_attn = teacher(x)
                topk = torch.topk(teacher_attn[:, 0], k=1, dim=-1).indices.squeeze(-1)

            loss = model(input_ids, attention_mask, topk, teacher_attn)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            print("Loss:", loss.item())


Loss: 13.965141296386719
Loss: 13.2096586227417
Loss: 13.895458221435547
Loss: 17.250762939453125
Loss: 13.641131401062012
Loss: 13.940834045410156
Loss: 15.438117980957031
Loss: 16.266212463378906
Loss: 13.824352264404297
Loss: 14.223379135131836
Loss: 15.071582794189453
Loss: 17.11172103881836
Loss: 14.512250900268555
Loss: 16.729019165039062
Loss: 16.687368392944336
Loss: 15.245388984680176
Loss: 13.9462308883667
Loss: 13.948664665222168
Loss: 18.147167205810547
Loss: 13.607330322265625
Loss: 15.21592903137207
Loss: 16.46346092224121
Loss: 13.606891632080078
Loss: 15.262789726257324
Loss: 17.201038360595703
Loss: 15.359214782714844
Loss: 14.675223350524902
Loss: 13.768728256225586
Loss: 14.734827995300293
Loss: 13.728902816772461
Loss: 14.884159088134766
Loss: 20.467571258544922
Loss: 12.927183151245117
Loss: 13.422428131103516
Loss: 13.62571907043457
Loss: 14.110284805297852
Loss: 14.000645637512207
Loss: 14.118144989013672
Loss: 13.953058242797852
Loss: 20.6766414642334
Loss: 13.0

KeyboardInterrupt: 

## Self Attention Tree

In [54]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import os

class BinaryTreeAttention(nn.Module):
    def __init__(self, hidden_size=768, chunk_size=64, depth=5):
        super().__init__()
        self.hidden_size = hidden_size
        self.chunk_size = chunk_size
        self.depth = depth

        self.query_vectors = nn.ParameterList([
            nn.Parameter(torch.randn(hidden_size)) for _ in range(depth)
        ])

        self.router_vectors = nn.ParameterList([
            nn.Parameter(torch.randn(hidden_size)) for _ in range(depth)
        ])

        os.makedirs("visuals", exist_ok=True)
        
    def build_summaries(self, x):
        B, L, D = x.shape
        assert L % self.chunk_size == 0, "Sequence length must be divisible by chunk size"
        num_chunks = L // self.chunk_size

        x_chunks = x.view(B, num_chunks, self.chunk_size, D)
        leaf_summaries = x_chunks.mean(dim=2)  # (B, num_chunks, D)

        all_levels = [leaf_summaries]

        for level, router in enumerate(self.router_vectors):
            group_size = 2 ** level
            num_groups = num_chunks // group_size
            summaries = []

            if num_groups == 0:
                break

            for g in range(num_groups):
                group = leaf_summaries[:, g * group_size:(g + 1) * group_size, :]  # (B, group_size, D)
                logits = torch.einsum('bnd,d->bn', group, router)  # (B, group_size)
                attn_weights = torch.softmax(logits, dim=-1).unsqueeze(-1)  # (B, group_size, 1)
                summary = (attn_weights * group).sum(dim=1)  # (B, D)
                summaries.append(summary.unsqueeze(1))

            summaries = torch.cat(summaries, dim=1)  # (B, num_groups, D)
            all_levels.append(summaries)

        return all_levels

    def forward(self, x, query, label_positions=None):
        B, L, D = x.shape
        tree_levels = self.build_summaries(x)  # List of (B, chunks, D)
        attn_weights = torch.zeros(B, L, device=x.device)
        per_level_loss = 0.0

        for level in range(self.depth):
            q_proj = query[0] * self.query_vectors[level]
            q_plot = q_proj.detach().cpu().numpy()
            plt.figure()
            plt.plot(q_plot)
            plt.title(f"Query projection at level {level}")
            plt.xlabel("Dimension")
            plt.ylabel("Value")
            plt.savefig(f"visuals/q_proj_level_{level}.png")
            plt.close()

        for b in range(B):
            q_i = query[b]  # (D,)

            if label_positions is not None:
                target_token = label_positions[b].item()
                target_chunk = target_token // self.chunk_size
                path = []
                node = target_chunk
                for level in reversed(range(self.depth)):
                    path.append(node)
                    node = node // 2
                path = list(reversed(path))

                for level, chunk_idx in enumerate(path):
                    summaries = tree_levels[self.depth - level - 2][b]  # (num_summaries, D)

                    q_proj = q_i * self.query_vectors[level]  # (D,)
                    logits = torch.matmul(summaries, q_proj)  # (num_summaries,)
                    weights = F.softmax(logits, dim=0)  # (num_summaries,)

                    # Find the correct index for supervision
                    target_idx = chunk_idx // (2 ** level)
                    target_weight = weights[target_idx]
                    loss = -torch.log(target_weight + 1e-8)
                    per_level_loss += loss

                final_chunk = path[-1]
                start = final_chunk * self.chunk_size
                end = start + self.chunk_size
                if end <= L:
                    local_scores = torch.matmul(q_i, x[b, start:end].T)
                    local_attn = F.softmax(local_scores, dim=-1)
                    attn_weights[b, start:end] = local_attn

        return attn_weights, per_level_loss / B

    # def build_summaries(self, x):
    #     B, L, D = x.shape
    #     assert L % self.chunk_size == 0, "Sequence length must be divisible by chunk size"
    #     num_chunks = L // self.chunk_size

    #     x_chunks = x.view(B, num_chunks, self.chunk_size, D)
    #     leaf_summaries = x_chunks.mean(dim=2)  # (B, num_chunks, D)

    #     all_levels = [leaf_summaries]

    #     for level, router in enumerate(self.router_vectors):
    #         group_size = 2 ** level
    #         num_groups = num_chunks // group_size
            
    #         if num_groups == 1:
    #             break
    #         summaries = []

    #         for g in range(num_groups):
    #             group = leaf_summaries[:, g * group_size:(g + 1) * group_size, :]  # (B, group_size, D)
    #             logits = torch.einsum('bnd,d->bn', group, router)  # (B, group_size)
    #             attn_weights = torch.softmax(logits, dim=-1).unsqueeze(-1)  # (B, group_size, 1)
    #             summary = (attn_weights * group).sum(dim=1)  # (B, D)
    #             summaries.append(summary.unsqueeze(1))

    #         summaries = torch.cat(summaries, dim=1)  # (B, num_groups, D)
    #         all_levels.append(summaries)

    #     return all_levels

    # def forward(self, x, query, label_positions=None):
    #     B, L, D = x.shape
    #     tree_levels = self.build_summaries(x)  # List of (B, chunks, D)
    #     attn_weights = torch.zeros(B, L, device=x.device)
    #     per_level_loss = 0.0

    #     for level in range(self.depth):
    #         q_proj = query[0] * self.query_vectors[level]
    #         q_plot = q_proj.detach().cpu().numpy()
    #         plt.figure()
    #         plt.plot(q_plot)
    #         plt.title(f"Query projection at level {level}")
    #         plt.xlabel("Dimension")
    #         plt.ylabel("Value")
    #         plt.savefig(f"visuals/q_proj_level_{level}.png")
    #         plt.close()

    #     for b in range(B):
    #         q_i = query[b]  # (D,)

    #         if label_positions is not None:
    #             target_token = label_positions[b].item()
    #             # print(target_token)
    #             target_chunk = target_token // self.chunk_size
    #             path = []
    #             node = target_chunk
    #             # print(f'{target_chunk}/{512 // self.chunk_size}')
    #             for level in reversed(range(self.depth - 1)):
    #                 path.append(node)
    #                 node = node // 2
    #             path = list(reversed(path))

    #             for level, chunk_idx in enumerate(path):
    #                 # if self.depth - level - 1 >= len(tree_levels):
    #                 #     break
    #                 # print(len(tree_levels), self.depth)
    #                 summaries = tree_levels[self.depth - level - 2][b]  # (chunks, D)
    #                 # if chunk_idx * 2 + 1 >= summaries.size(0):
    #                 #     continue
    #                 # print(len(summaries), chunk_idx)
    #                 left = summaries[chunk_idx // 2 * 2]
    #                 right = summaries[chunk_idx // 2 * 2 + 1]

    #                 q_proj = q_i * self.query_vectors[level]
    #                 score_left = torch.dot(q_proj, left)
    #                 score_right = torch.dot(q_proj, right)
    #                 logit = torch.stack([score_left, score_right])
    #                 decision = 0 if (chunk_idx % 2 == 0) else 1
    #                 label = torch.tensor([decision], dtype=torch.long, device=x.device)
    #                 loss = F.cross_entropy(logit.unsqueeze(0), label)

    #                 # loss = F.cross_entropy(logit.unsqueeze(0), torch.tensor([decision], device=x.device))
    #                 per_level_loss += loss
    #                 # print(per_level_loss.requires_grad)

    #             final_chunk = path[-1]
    #             start = final_chunk * self.chunk_size
    #             end = start + self.chunk_size
    #             if end <= L:
    #                 local_scores = torch.matmul(q_i, x[b, start:end].T)
    #                 local_attn = F.softmax(local_scores, dim=-1)
    #                 attn_weights[b, start:end] = local_attn

    #     return attn_weights, per_level_loss / B

if __name__ == '__main__':
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    bert = BertModel.from_pretrained("bert-base-uncased")

    teacher = SoftmaxTeacher(hidden_size=768)
    dataset = RandomTokenDataset(tokenizer=tokenizer, num_samples=1000, seq_len=512)
    dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

    tree_model = BinaryTreeAttention(hidden_size=768, chunk_size=64)
    model = TreeAttentionModel(bert, tree_model)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(3):
        for batch in dataloader:
            input_ids = batch["input_ids"]
            attention_mask = batch["attention_mask"]

            with torch.no_grad():
                x = bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
                teacher_attn = teacher(x)
                topk = torch.topk(teacher_attn[:, 0], k=1, dim=-1).indices.squeeze(-1)

            loss = model(input_ids, attention_mask, topk, teacher_attn)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            print("Loss:", loss.item())


Loss: 23.362707138061523
Loss: 24.99822998046875
Loss: 21.739498138427734
Loss: 21.880748748779297
Loss: 21.961990356445312
Loss: 24.311073303222656
Loss: 20.525001525878906
Loss: 20.038055419921875
Loss: 19.34503936767578
Loss: 16.988056182861328
Loss: 20.15441131591797
Loss: 23.3741397857666
Loss: 19.297767639160156
Loss: 22.010738372802734
Loss: 24.37759780883789
Loss: 23.678829193115234
Loss: 20.929767608642578
Loss: 20.05202865600586
Loss: 21.520214080810547
Loss: 17.408260345458984
Loss: 20.376911163330078
Loss: 23.20523452758789
Loss: 21.76390838623047
Loss: 19.359477996826172
Loss: 23.303855895996094
Loss: 19.793548583984375
Loss: 18.40351104736328
Loss: 19.889476776123047
Loss: 20.13365936279297
Loss: 19.169321060180664
Loss: 17.88991355895996
Loss: 18.278301239013672
Loss: 21.093997955322266
Loss: 21.939958572387695


KeyboardInterrupt: 