In [2]:
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AdamW
from datasets import load_dataset
import torch.nn as nn

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from scipy.linalg import sqrtm


In [37]:
def make_weight_list(samples=16, epochs=100):
    """Train several small attention classifiers on TREC and record their attention weights.

    For each of ``samples`` independent runs, the TREC training split is randomly
    re-split 80/20 into train/validation, a freshly initialised ``CustomClassifier``
    with two attention projection layers is trained for ``epochs`` epochs, and the
    weight matrix of each projection layer is snapshotted after every optimizer step.
    The last-batch training loss and the validation accuracy are printed once per
    epoch.

    Args:
        samples: Number of independent training runs.
        epochs: Number of passes over the training split per run.

    Returns:
        A tuple ``(weights_samples_1, weights_samples_2)``. Each is a list with one
        entry per run; every entry is itself a list of NumPy arrays (one per
        optimizer step) holding an independent copy of the weight of attention
        layer 0 and attention layer 1 respectively.
    """
    weights_samples_1 = []
    weights_samples_2 = []

    # The raw dataset, the tokenizer and the device do not depend on the run
    # index, so they are created once instead of once per sample.
    raw_dataset = load_dataset("trec", split="train")
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def preprocess(data):
        """Tokenize a batch of examples, truncating to at most 32 tokens."""
        return tokenizer(data['text'], truncation=True, padding=True, max_length=32)

    def collate_fn(batch):
        """Pad every example of a batch to the batch's longest sequence."""
        input_ids = torch.nn.utils.rnn.pad_sequence([b['input_ids'] for b in batch], batch_first=True, padding_value=tokenizer.pad_token_id)
        attention_mask = torch.nn.utils.rnn.pad_sequence([b['attention_mask'] for b in batch], batch_first=True, padding_value=0)
        labels = torch.tensor([b['label'] for b in batch])
        return {"input_ids": input_ids, "attention_mask": attention_mask, "label": labels}

    class CustomClassifier(nn.Module):
        """Embedding -> stacked single-projection self-attention -> mean pool -> linear head."""

        def __init__(self, input_dim, hidden_dim, num_classes, num_attention_layers=2, vocab_size=30522):
            super().__init__()

            # Embedding layer to map token IDs to embeddings
            self.embedding = nn.Embedding(vocab_size, input_dim)  # vocab_size is typically 30522 for BERT

            # One linear projection per attention layer; its output is used as
            # both query and key when computing the attention scores.
            self.attention_layers = nn.ModuleList([nn.Linear(input_dim, hidden_dim) for _ in range(num_attention_layers)])

            # Initialise the attention projections with small weights and zero bias
            for layer in self.attention_layers:
                nn.init.normal_(layer.weight, mean=0, std=1e-1)
                nn.init.constant_(layer.bias, 0)

            # Final fully connected layer for classification
            self.fc = nn.Linear(input_dim, num_classes)

        def _attention_layer(self, x, layer):
            """Apply the attention layer with index ``layer`` to ``x``."""
            # Project the input using the hidden layer
            x_new = self.attention_layers[layer](x)

            # Self attention (scaled dot-product) of the projection with itself
            attention_scores = torch.matmul(x_new, x_new.transpose(1, 2)) / x_new.size(-1) ** 0.5
            attention_weights = torch.nn.functional.softmax(attention_scores, dim=-1)

            # The attention weights mix the *unprojected* inputs, so the feature
            # dimension stays at input_dim from layer to layer.
            return torch.einsum("nij,njd->nid", attention_weights, x)

        def forward(self, input_ids, attention_mask):
            """Return class logits for a batch of token ids."""
            # NOTE(review): attention_mask is accepted but never used, so padding
            # tokens take part in both the attention and the mean pooling.
            # Left as is because masking would change the experiment's results.
            embedded = self.embedding(input_ids)

            # Apply attention layers
            for layer in range(len(self.attention_layers)):
                embedded = self._attention_layer(embedded, layer)

            # Mean pooling (average the token embeddings across the sequence)
            pooled_output = embedded.mean(dim=1)

            # Pass through the final fully connected layer
            return self.fc(pooled_output)

    def snapshot(layer):
        """Return an independent NumPy copy of ``layer``'s weight matrix.

        On CPU, ``.cpu()`` returns the tensor itself and ``.numpy()`` shares its
        memory, so without the explicit copy every stored snapshot would alias
        the live parameter and silently end up equal to the final weights.
        """
        return layer.weight.detach().cpu().numpy().copy()

    for _ in tqdm(range(samples)):
        # Step 1: Draw a fresh random train/validation split of the TREC dataset
        dataset = raw_dataset.train_test_split(test_size=0.2)
        train_data = dataset['train']
        val_data = dataset['test']

        # Step 2: Preprocess the data with the tokenizer
        train_data = train_data.map(preprocess, batched=True)
        val_data = val_data.map(preprocess, batched=True)

        train_data = train_data.rename_column("coarse_label", "label")
        val_data = val_data.rename_column("coarse_label", "label")

        train_data.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
        val_data.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])

        # Step 3: Create DataLoaders
        train_loader = DataLoader(train_data, batch_size=16, shuffle=True, collate_fn=collate_fn)
        val_loader = DataLoader(val_data, batch_size=16, collate_fn=collate_fn)

        # Step 4: Build a freshly initialised model for this run
        model = CustomClassifier(input_dim=128, hidden_dim=64, num_classes=6)
        model.to(device)

        # Step 5: Define optimizer and loss function
        optimizer = AdamW(model.parameters(), lr=5e-5)
        criterion = torch.nn.CrossEntropyLoss()

        weights_attention_1 = []
        weights_attention_2 = []

        # Step 6: Train the model
        def train(epoch):
            """Run one training epoch, recording both attention weights after every step."""
            model.train()
            for batch in train_loader:
                optimizer.zero_grad()
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['label'].to(device)

                logits = model(input_ids=input_ids, attention_mask=attention_mask)
                loss = criterion(logits, labels)
                loss.backward()
                optimizer.step()

                weights_attention_1.append(snapshot(model.attention_layers[0]))
                weights_attention_2.append(snapshot(model.attention_layers[1]))

            # Reports the loss of the last batch only, not an epoch average.
            print(f"Epoch {epoch}: Loss = {loss.item():.4f}")

        # Step 7: Evaluate the model
        def evaluate():
            """Print the classification accuracy on the validation split."""
            model.eval()
            correct = 0
            total = 0
            with torch.no_grad():
                for batch in val_loader:
                    input_ids = batch['input_ids'].to(device)
                    attention_mask = batch['attention_mask'].to(device)
                    labels = batch['label'].to(device)

                    logits = model(input_ids=input_ids, attention_mask=attention_mask)
                    predictions = torch.argmax(logits, dim=-1)
                    correct += (predictions == labels).sum().item()
                    total += labels.size(0)

            print(f"Validation Accuracy: {correct / total:.4f}")

        # Run training and evaluation
        for epoch in range(epochs):
            train(epoch)
            evaluate()

        weights_samples_1.append(weights_attention_1)
        weights_samples_2.append(weights_attention_2)

    return weights_samples_1, weights_samples_2




def _gram_similarity_to_final(weights, count):
    """Cosine similarity between each step's Gram matrix and the final step's.

    For the first ``count`` entries ``W_i`` of ``weights`` this evaluates
    ``<G_i, G_last>_F / (||G_i||_F * ||G_last||_F)`` with ``G = W @ W.T`` and
    ``G_last`` built from ``weights[-1]``.
    """
    final_weight = np.asarray(weights[-1])
    final_gram = final_weight @ final_weight.T
    final_norm = np.linalg.norm(final_gram, ord="fro")

    similarity = np.zeros(count)
    for i in range(count):
        weight = np.asarray(weights[i])
        gram = weight @ weight.T
        # Gram matrices are symmetric, so trace(gram @ final_gram) equals the
        # elementwise sum below; this avoids a full matrix product per step.
        inner = np.sum(gram * final_gram)
        similarity[i] = inner / final_norm / np.linalg.norm(gram, ord="fro")
    return similarity


def make_distance_measure(weights_attention_1, weights_attention_2):
    """Measure how close each recorded weight matrix is to the final one.

    Despite the name, the returned quantity is a similarity: for every recorded
    weight matrix ``W_i`` the Gram matrix ``W_i @ W_i.T`` is compared with the
    Gram matrix of the last recorded weight via the Frobenius inner product
    divided by both Frobenius norms. The last entry is therefore 1 (up to
    rounding) whenever the final Gram matrix is non-zero. A Gram matrix that is
    entirely zero leads to a division by zero and a non-finite entry.

    Args:
        weights_attention_1: Sequence of 2-D weight arrays of the first
            attention layer, ordered by training step.
        weights_attention_2: Sequence of 2-D weight arrays of the second
            attention layer; must provide at least as many entries as
            ``weights_attention_1``.

    Returns:
        A tuple ``(dist_1, dist_2)`` of 1-D float arrays, both of length
        ``len(weights_attention_1)``. Both are empty when no weights were
        recorded.
    """
    count = len(weights_attention_1)
    if count == 0:
        # Nothing recorded: there is no final matrix to compare against.
        return np.zeros(0), np.zeros(0)

    dist_1 = _gram_similarity_to_final(weights_attention_1, count)
    dist_2 = _gram_similarity_to_final(weights_attention_2, count)
    return dist_1, dist_2


In [None]:
# Experiment size: number of independent training runs and epochs per run.
samples = 1
epochs = 10

# Train every run and collect the per-step weights of both attention layers.
weights_samples_1, weights_samples_2 = make_weight_list(samples, epochs)

# One row per run, one column per recorded optimizer step (taken from run 0).
n_steps = len(weights_samples_1[0])
dist_samples_1 = np.zeros((samples, n_steps))
dist_samples_2 = np.zeros((samples, n_steps))

n_runs = len(weights_samples_1)
for i in tqdm(range(n_runs)):
    run_dist_1, run_dist_2 = make_distance_measure(weights_samples_1[i], weights_samples_2[i])
    dist_samples_1[i] = run_dist_1
    dist_samples_2[i] = run_dist_2

In [12]:
# Persist the per-run similarity curves of both attention layers to disk.
for out_path, dist_array in (
    ("realistic_dist_samples_1.npy", dist_samples_1),
    ("realistic_dist_samples_2.npy", dist_samples_2),
):
    np.save(out_path, dist_array)
