# Linear Probe

# YOU DO NOT NEED A GPU FOR THIS TASK, don't worry.

In this task, you will implement the training of a linear probe.
Probing is a technique for checking whether a model encodes a specific piece of information in its hidden state. The simplest way to do this is to train a `probe`: a separate model that takes the hidden state of the original model as input and tries to predict a specific property of that input (e.g. is the word currently being processed a noun?). The simplest probe is a linear model.
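
To make the idea concrete, here is a minimal sketch of a linear probe trained on synthetic data. It is not part of the assignment: the tensor `hidden` is a random stand-in for frozen hidden states, and `direction`, `n_examples`, the learning rate, and the number of steps are all illustrative assumptions. The label is deliberately encoded along a single direction, so a linear probe should be able to recover it.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

d_model = 16       # hypothetical hidden size
n_examples = 512   # hypothetical number of probe training examples

# Synthetic stand-in for frozen hidden states: the binary label is
# encoded linearly along one random direction of the hidden space.
direction = torch.randn(d_model)
hidden = torch.randn(n_examples, d_model)
labels = (hidden @ direction > 0).float()

# The probe itself is a single linear layer producing one logit.
probe = nn.Linear(d_model, 1)
optimizer = torch.optim.Adam(probe.parameters(), lr=0.05)
criterion = nn.BCEWithLogitsLoss()

for step in range(200):
    optimizer.zero_grad()
    logits = probe(hidden).squeeze(-1)  # shape: (n_examples,)
    loss = criterion(logits, labels)
    loss.backward()
    optimizer.step()

# A logit > 0 corresponds to a predicted probability > 0.5.
with torch.no_grad():
    predictions = (probe(hidden).squeeze(-1) > 0).float()
    accuracy = (predictions == labels).float().mean().item()
print(f"probe accuracy on synthetic data: {accuracy:.3f}")
```

If the information were not linearly readable from `hidden`, the same probe would stay near chance, which is the contrast the tasks below ask you to reproduce on a real model.
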
In this notebook, we provide an implementation of a simple model consisting of:

1.   Embedding
2.   Two attention layers
3.   Classification head

This model is trained on a specific task:


1.   The input to the model is a sequence containing exactly one 0 (e.g. 1, 2, 0, 4, 5, 6).
2.   The model should predict the first number after the 0 (in the example, 4). You can assume that the 0 is never in the last position.
3.   The output of the model is the output of the classification head at the last position in the sequence (in this example, the position of 6).
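
The task definition above can be illustrated in plain Python on a single sequence. The helper name `target_after_zero` is hypothetical and only serves this illustration; the notebook itself works on batched tensors.

```python
def target_after_zero(sequence):
    """Return the token that immediately follows the single 0 in `sequence`."""
    assert sequence.count(0) == 1, "expected exactly one 0"
    zero_position = sequence.index(0)
    assert zero_position < len(sequence) - 1, "0 must not be the last token"
    return sequence[zero_position + 1]


example = [1, 2, 0, 4, 5, 6]
target = target_after_zero(example)   # the label the model is trained on
readout_position = len(example) - 1   # the prediction is read at the last token
print(target, readout_position)       # 4 5
```

Because the prediction is read at the last position, the model has to carry the information "which token followed the 0" from an arbitrary earlier position to the end of the sequence.
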

Your task is to:



1.  (2p) Complete `generate_batch` for the training task specified above and run the training (it should reach >90% accuracy with the provided config; you don't need a GPU).
2.  (2p) Complete the `get_model_embedding_at_positions` and `get_first_attention_output_at_positions` functions for the probe training.
3.  (6p) Write the training of the probe (complete `train_probe`). The probe should be trained to answer the question: "is the token immediately before the current token 0?" (e.g. in the sequence 1, 2, 0, 4, 5, the probe should output a number < 0 for "2" and "5" and a number > 0 for "4"). The probe should be trained using BCEWithLogitsLoss.
How you sample negative examples for probe training is up to you, but the sampling should balance the number of positive and negative examples.
You should at least log the accuracy during training. Probe training should reach >90% accuracy when training on the output of the first attention layer, but should stay at ~50-60% when trained on the output of the embedding (as the output of the embedding contains no information about the previous tokens).
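
As a sketch of what the probe labels look like, the plain-Python snippet below labels every position of one sequence and draws one positive and one negative position from it. The helper names `probe_labels` and `sample_balanced_positions` are hypothetical, and drawing exactly one negative per sequence is only one possible way to balance the classes, not the required one.

```python
import random


def probe_labels(sequence):
    """Label position i with 1 if the token at position i - 1 is 0, else 0."""
    return [1 if i > 0 and sequence[i - 1] == 0 else 0
            for i in range(len(sequence))]


def sample_balanced_positions(sequence, rng):
    """Return (positive_position, negative_position) for one sequence."""
    labels = probe_labels(sequence)
    positive = labels.index(1)  # exactly one position follows the 0
    negatives = [i for i, label in enumerate(labels) if label == 0]
    return positive, rng.choice(negatives)


sequence = [1, 2, 0, 4, 5]
print(probe_labels(sequence))  # [0, 0, 0, 1, 0]

rng = random.Random(0)
positive, negative = sample_balanced_positions(sequence, rng)
print(positive, negative != positive)  # 3 True
```

Each sequence contains only one positive position but several negative ones, which is why sampling every position uniformly would produce a heavily imbalanced training set.
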







In [1]:
import torch
import numpy as np
import pandas as pd
import torch.nn as nn

In [2]:
class Attention(nn.Module):
    def __init__(self, d_model, n_heads):
        super(Attention, self).__init__()
        self.norm = nn.RMSNorm(d_model)
        self.n_heads = n_heads
        self.d_model = d_model
        self.head_dim = d_model
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.attention_mechanism = nn.MultiheadAttention(d_model, n_heads, bias=False, batch_first=True)
    def forward(self, x):
        residual = x
        x = self.norm(x)
        q = self.w_q(x)
        k = self.w_k(x)
        v = self.w_v(x)
        mask = torch.triu(torch.ones(q.size(1), q.size(1)), diagonal=1).bool().to(x.device)
        attention_output, _ = self.attention_mechanism(q, k, v, attn_mask = mask, is_causal=True)
        return attention_output + residual

class FullEncoding(nn.Module):
    def __init__(self, d_model, vocab_size, seq_length):
        super(FullEncoding, self).__init__()
        self.d_model = d_model
        self.seq_length = seq_length
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(seq_length, d_model)
    def forward(self, x):
        token_embedding = self.token_embedding(x)
        position_embedding = self.position_embedding(torch.arange(x.size(1)).unsqueeze(0).to(x.device))
        return token_embedding + position_embedding




class MultiAttentionModel(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, num_layers, seq_length):
        super(MultiAttentionModel, self).__init__()
        self.full_embedding = FullEncoding(d_model, vocab_size, seq_length)
        self.layers = nn.ModuleList([
            Attention(d_model, num_heads) for _ in range(num_layers)
        ])
        self.head = nn.Linear(d_model, vocab_size)
    def forward(self, x):
        x = self.full_embedding(x)
        for layer in self.layers:
            x = layer(x)
        return self.head(x)

def generate_sequences(vocab_size, batch_size, sequence_length):
    return torch.randint(1, vocab_size, (batch_size, sequence_length))

def generate_0_placements(batch_size, sequence_length):
    return torch.randint(0, sequence_length-1, (batch_size,))

def generate_batch(sequences, placements):
    """
    This function, given sequences and placements of 0 tokens, should
    generate a batch for MultiAttentionModel
    sequences: torch tensor of shape (batch_size, sequence_length)
    placements: torch tensor of shape (batch_size,)
    return:
    sequences_with_0: torch tensor of shape (batch_size, sequence_length), which
    is the same as sequences but with 0 at the given placements
    targets: torch tensor of shape (batch_size,), which is the number after
    0 in the sequences
    """
    batch_size = sequences.size(0)
    sequences_with_0 = sequences.clone()
    # Task 1

    indices = torch.arange(batch_size)

    sequences_with_0[indices, placements] = 0
    targets = sequences[indices, placements+1]
    # End Task 1

    return sequences_with_0, targets

example_sequences = torch.tensor(
    [
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        [11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
    ]
)
example_placements = torch.tensor([2, 7])
example_targets = torch.tensor([4, 19])
example_sequences_with_0 = torch.tensor(
    [
        [1, 2, 0, 4, 5, 6, 7, 8, 9, 10],
        [11, 12, 13, 14, 15, 16, 17, 0, 19, 20],
    ]
)
assert torch.all(generate_batch(example_sequences, example_placements)[0] == example_sequences_with_0)
assert torch.all(generate_batch(example_sequences, example_placements)[1] == example_targets)



def train(model, optimizer, batch_size, sequence_length, num_epochs, device):
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    for epoch in range(num_epochs):
        model.train()
        sequences = generate_sequences(vocab_size, batch_size, sequence_length)
        placements = generate_0_placements(batch_size, sequence_length)
        batch, target = generate_batch(sequences, placements)
        batch = batch.to(device)
        target = target.to(device)
        optimizer.zero_grad()
        output = model(batch)
        last_token_logits = output[:, -1, :]
        loss = criterion(last_token_logits, target)
        loss.backward()
        optimizer.step()
        if (epoch+1) % 100 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}")
            predicted_token = torch.argmax(last_token_logits, dim=1)
            accuracy = (predicted_token == target).float().mean().item()
            print(f"Epoch {epoch+1}/{num_epochs}, Accuracy: {accuracy}")



batch_size = 64
sequence_length = 10
vocab_size = 128
d_model = 16
num_heads = 1
num_layers = 2
num_epochs = 2000
learning_rate = 0.001
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)
model = MultiAttentionModel(vocab_size, d_model, num_heads, num_layers, sequence_length)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
train(model, optimizer, batch_size, sequence_length, num_epochs, device)


Epoch 100/2000, Loss: 4.999610424041748
Epoch 100/2000, Accuracy: 0.0
Epoch 200/2000, Loss: 4.955053806304932
Epoch 200/2000, Accuracy: 0.0
Epoch 300/2000, Loss: 4.854390621185303
Epoch 300/2000, Accuracy: 0.0
Epoch 400/2000, Loss: 4.742520809173584
Epoch 400/2000, Accuracy: 0.015625
Epoch 500/2000, Loss: 4.804018020629883
Epoch 500/2000, Accuracy: 0.015625
Epoch 600/2000, Loss: 4.85786247253418
Epoch 600/2000, Accuracy: 0.015625
Epoch 700/2000, Loss: 4.441585063934326
Epoch 700/2000, Accuracy: 0.15625
Epoch 800/2000, Loss: 2.950584650039673
Epoch 800/2000, Accuracy: 0.25
Epoch 900/2000, Loss: 0.8388757109642029
Epoch 900/2000, Accuracy: 0.859375
Epoch 1000/2000, Loss: 0.17074620723724365
Epoch 1000/2000, Accuracy: 1.0
Epoch 1100/2000, Loss: 0.055065739899873734
Epoch 1100/2000, Accuracy: 1.0
Epoch 1200/2000, Loss: 0.025743212550878525
Epoch 1200/2000, Accuracy: 1.0
Epoch 1300/2000, Loss: 0.015844233334064484
Epoch 1300/2000, Accuracy: 1.0
Epoch 1400/2000, Loss: 0.012686382047832012
Ep

In [3]:
def get_model_embedding_at_positions(model: MultiAttentionModel, x, positions):
    """
    model: MultiAttentionModel
    x: torch tensor of shape (batch_size, sequence_length)
    positions: torch tensor of shape (batch_size,)

    Example: if model on x has embedding '[[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]]]'
    and positions is [0, 2] then the output should be '[[0, 1], [10, 11]]'
    return: torch tensor of shape (batch_size, d_model)

    """
    # Task 2.1
    batch_size = x.size(0)
    all_embeddings = model.full_embedding(x)
    indices = torch.arange(batch_size, device=x.device)
    chosen_embeddings = all_embeddings[indices, positions]

    return chosen_embeddings
    # End Task 2.1

# batch, target = generate_batch(example_sequences, example_placements)
# batch_size = batch.size(0)
# print(get_model_embedding_at_positions(model, batch, torch.arange(batch_size)))

def get_first_attention_output_at_positions(model, x, positions):
    """
    model: MultiAttentionModel
    x: torch tensor of shape (batch_size, sequence_length)
    positions: torch tensor of shape (batch_size,)

    Example: if first attention of model on x outputs '[[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]]]'
    and positions is [0, 2] then the output should be '[[0, 1], [10, 11]]'
    return: torch tensor of shape (batch_size, d_model)
    """
    # Task 2.2
    batch_size = x.size(0)
    x = model.full_embedding(x)
    all_attentions = model.layers[0](x)
    indices = torch.arange(batch_size, device=x.device)
    chosen_attentions = all_attentions[indices, positions]

    return chosen_attentions
    # End Task 2.2

# batch, target = generate_batch(example_sequences, example_placements)
# batch_size = batch.size(0)
# print(get_first_attention_output_at_positions(model, batch, torch.arange(batch_size)))
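
Both helper functions rely on the same indexing pattern: pairing a row index with a per-row position. The standalone sketch below reproduces the example from the docstrings on a hand-written tensor; the variable names are illustrative, and the `gather` variant is shown only as an equivalent alternative.

```python
import torch

# The example tensor from the docstrings: shape (batch_size=2, seq_len=3, d_model=2).
hidden = torch.tensor([[[0, 1], [2, 3], [4, 5]],
                       [[6, 7], [8, 9], [10, 11]]])
positions = torch.tensor([0, 2])

# Advanced indexing: row b of the result is hidden[b, positions[b]].
rows = torch.arange(hidden.size(0))
selected = hidden[rows, positions]
print(selected.tolist())  # [[0, 1], [10, 11]]

# Equivalent formulation with gather along the sequence dimension.
index = positions.view(-1, 1, 1).expand(-1, 1, hidden.size(-1))
gathered = hidden.gather(1, index).squeeze(1)
print(torch.equal(selected, gathered))  # True
```
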

In [22]:
def train_probe(model, probe, optimizer, batch_size, sequence_length, num_epochs, device, model_latent_function):
    """
    Implement the training of the probe.
    model_latent_function: either get_model_embedding_at_positions or get_first_attention_output_at_positions
    """
    # Task 3
    model = model.to(device)
    probe = probe.to(device)
    model.eval()  # the model stays frozen; only the probe is optimized

    criterion = nn.BCEWithLogitsLoss()
    for epoch in range(num_epochs):
        probe.train()
        # Note: vocab_size is read from the notebook's global config.
        sequences = generate_sequences(vocab_size, batch_size, sequence_length).to(device)
        placements = generate_0_placements(batch_size, sequence_length).to(device)
        batch, _ = generate_batch(sequences, placements)

        # Offsets lie in [1, sequence_length - 2], so a negative position never
        # coincides with the positive one (nor with the position of the 0 itself).
        offsets = torch.randint(low=1, high=sequence_length-1, size=(batch_size, ), device=device)
        incorrect_placements = (placements+1+offsets) % sequence_length

        # The hidden states are inputs to the probe, so no model gradients are needed.
        with torch.no_grad():
            # Positive examples: the position right after the 0 token.
            positive_instances = model_latent_function(model, batch, placements+1)
            # Negative examples: one other position per sequence.
            negative_instances = model_latent_function(model, batch, incorrect_placements)
        positive_targets = torch.ones(batch_size, device=device)
        negative_targets = torch.zeros(batch_size, device=device)

        double_batch = torch.cat((positive_instances, negative_instances))
        double_target = torch.cat((positive_targets, negative_targets))

        # Subsample batch_size examples; the classes are balanced in expectation.
        batch_indices = torch.randperm(2*batch_size, device=device)[:batch_size]
        probe_inputs = double_batch[batch_indices]
        target = double_target[batch_indices]

        optimizer.zero_grad()
        # squeeze(-1) keeps the batch dimension even when batch_size == 1
        output = probe(probe_inputs).squeeze(-1)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if (epoch+1) % 100 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}")
            predictions = (output > 0).float()
            accuracy = (predictions == target).float().mean().item()
            print(f"Epoch {epoch+1}/{num_epochs}, Accuracy: {accuracy}")

    # End Task 3






In [23]:
probe = nn.Linear(d_model, 1)
optimizer = torch.optim.Adam(probe.parameters(), lr=learning_rate)

train_probe(model, probe, optimizer, batch_size, sequence_length, num_epochs, device, get_model_embedding_at_positions)

Epoch 100/2000, Loss: 0.7376176118850708
Epoch 100/2000, Accuracy: 0.484375
Epoch 200/2000, Loss: 0.6592475771903992
Epoch 200/2000, Accuracy: 0.578125
Epoch 300/2000, Loss: 0.6951236724853516
Epoch 300/2000, Accuracy: 0.453125
Epoch 400/2000, Loss: 0.7018994092941284
Epoch 400/2000, Accuracy: 0.4375
Epoch 500/2000, Loss: 0.7027159929275513
Epoch 500/2000, Accuracy: 0.421875
Epoch 600/2000, Loss: 0.6955924034118652
Epoch 600/2000, Accuracy: 0.46875
Epoch 700/2000, Loss: 0.6729242205619812
Epoch 700/2000, Accuracy: 0.46875
Epoch 800/2000, Loss: 0.6636433601379395
Epoch 800/2000, Accuracy: 0.5625
Epoch 900/2000, Loss: 0.6671609878540039
Epoch 900/2000, Accuracy: 0.5625
Epoch 1000/2000, Loss: 0.6507031321525574
Epoch 1000/2000, Accuracy: 0.546875
Epoch 1100/2000, Loss: 0.686193585395813
Epoch 1100/2000, Accuracy: 0.484375
Epoch 1200/2000, Loss: 0.6740529537200928
Epoch 1200/2000, Accuracy: 0.5
Epoch 1300/2000, Loss: 0.7059669494628906
Epoch 1300/2000, Accuracy: 0.46875
Epoch 1400/2000, Lo

In [24]:
probe = nn.Linear(d_model, 1)
optimizer = torch.optim.Adam(probe.parameters(), lr=learning_rate)

train_probe(model, probe, optimizer, batch_size, sequence_length, num_epochs, device, get_first_attention_output_at_positions)


Epoch 100/2000, Loss: 0.3288019895553589
Epoch 100/2000, Accuracy: 0.890625
Epoch 200/2000, Loss: 0.20135006308555603
Epoch 200/2000, Accuracy: 1.0
Epoch 300/2000, Loss: 0.1499752402305603
Epoch 300/2000, Accuracy: 0.984375
Epoch 400/2000, Loss: 0.12741243839263916
Epoch 400/2000, Accuracy: 0.96875
Epoch 500/2000, Loss: 0.12439523637294769
Epoch 500/2000, Accuracy: 0.96875
Epoch 600/2000, Loss: 0.12360329926013947
Epoch 600/2000, Accuracy: 0.96875
Epoch 700/2000, Loss: 0.06919854879379272
Epoch 700/2000, Accuracy: 1.0
Epoch 800/2000, Loss: 0.09740914404392242
Epoch 800/2000, Accuracy: 0.96875
Epoch 900/2000, Loss: 0.057605065405368805
Epoch 900/2000, Accuracy: 0.984375
Epoch 1000/2000, Loss: 0.12252090126276016
Epoch 1000/2000, Accuracy: 0.96875
Epoch 1100/2000, Loss: 0.09889624267816544
Epoch 1100/2000, Accuracy: 0.953125
Epoch 1200/2000, Loss: 0.04261656105518341
Epoch 1200/2000, Accuracy: 1.0
Epoch 1300/2000, Loss: 0.05559863895177841
Epoch 1300/2000, Accuracy: 1.0
Epoch 1400/2000, 