# Linear Probe

# You do not need a GPU for this task, don't worry.

In this task, you will implement the training of a linear probe.
Probing is a technique for checking whether a model has specific information encoded in its hidden state. The simplest way to do it is to train a `probe`: a separate model that takes the hidden state of the original model as input and tries to predict a specific piece of information about that input (e.g. is the word currently being processed a noun?). The simplest probe is a linear model.
In this notebook, we provide an implementation of a simple model consisting of:

1.   Embedding
2.   Two attention layers
3.   Classification head

This model is trained on a specific task:


1.   The input to the model is a sequence with exactly one 0 inside (e.g. 1,2,0,4,5,6)
2.   The model should predict the first number after 0 (you can assume that 0 is not on the last position) (in the example, it should be 4)
3. The output of the model is the output of the classification head at the last number in the sequence (in this case, 6)

Your task is to:



1.  (2p) Complete `generate_batch` for the training specified above and run the training (it should reach >90% accuracy with the provided config; you don't need a GPU)
2.  (2p) Complete the `get_model_embedding_at_positions` and `get_first_attention_output_at_positions` functions for the probe training.
3.  (6p) Write the training of the probe (complete `train_probe`). The probe should be trained to answer the question: "is the token immediately before the current token 0?" (e.g. in the sequence 1, 2, 0, 4, 5, the probe should output a number < 0 for "2" and "5", and a number > 0 for "4"). The probe should be trained using BCEWithLogitsLoss.
It is up to you how you sample negative examples for probe training, but you should sample them in a way that balances the number of positive and negative examples.
You should at least log the accuracy during training. Probe training should reach >90% accuracy when training on the output of the first attention layer, but should stay at ~50-60% when trained on the output of the embedding (as the output of the embedding has no information about the previous tokens).







In [None]:
import torch
import numpy as np
import pandas as pd
import torch.nn as nn

In [None]:
class Attention(nn.Module):
    def __init__(self, d_model, n_heads):
        super(Attention, self).__init__()
        self.norm = nn.RMSNorm(d_model)
        self.n_heads = n_heads
        self.d_model = d_model
        self.head_dim = d_model
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.attention_mechanism = nn.MultiheadAttention(d_model, n_heads, bias=False, batch_first=True)
    def forward(self, x):
        residual = x
        x = self.norm(x)
        q = self.w_q(x)
        k = self.w_k(x)
        v = self.w_v(x)
        mask = torch.triu(torch.ones(q.size(1), q.size(1)), diagonal=1).bool().to(x.device)
        attention_output, _ = self.attention_mechanism(q, k, v, attn_mask = mask, is_causal=True)
        return attention_output + residual

class FullEncoding(nn.Module):
    """Sum of a learned token embedding and a learned position embedding.

    Args:
        d_model: Width of each embedding vector.
        vocab_size: Number of rows in the token embedding table.
        seq_length: Number of rows in the position embedding table.
    """

    def __init__(self, d_model, vocab_size, seq_length):
        super().__init__()
        self.d_model = d_model
        self.seq_length = seq_length
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(seq_length, d_model)

    def forward(self, x):
        """Embed token ids ``x`` (indexed as (batch, sequence)) and add position vectors."""
        # One row of position ids 0..sequence-1, broadcast over the batch.
        position_ids = torch.arange(x.shape[1], device=x.device)[None, :]
        tokens = self.token_embedding(x)
        positions = self.position_embedding(position_ids)
        return tokens + positions




class MultiAttentionModel(nn.Module):
    """Embedding, a stack of attention blocks, and a linear classification head.

    Args:
        vocab_size: Size of the token vocabulary and of the head's output.
        d_model: Width of the hidden state.
        num_heads: Number of heads in each attention block.
        num_layers: Number of attention blocks.
        seq_length: Number of positions the position embedding covers.
    """

    def __init__(self, vocab_size, d_model, num_heads, num_layers, seq_length):
        super().__init__()
        self.full_embedding = FullEncoding(d_model, vocab_size, seq_length)
        blocks = [Attention(d_model, num_heads) for _ in range(num_layers)]
        self.layers = nn.ModuleList(blocks)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Return the head's logits for every position of the token ids ``x``."""
        hidden = self.full_embedding(x)
        for block in self.layers:
            hidden = block(hidden)
        logits = self.head(hidden)
        return logits

def generate_sequences(vocab_size, batch_size, sequence_length):
    """Draw random token ids in ``[1, vocab_size)``, shaped (batch_size, sequence_length).

    Token 0 is never produced here.
    """
    shape = (batch_size, sequence_length)
    return torch.randint(low=1, high=vocab_size, size=shape)

def generate_0_placements(batch_size, sequence_length):
    """Draw one random index per row from ``[0, sequence_length - 1)``.

    The last position is excluded, so ``placement + 1`` is always a valid index.
    """
    last_allowed_exclusive = sequence_length - 1
    return torch.randint(low=0, high=last_allowed_exclusive, size=(batch_size,))

def generate_batch(sequences, placements):
    """Build a training batch for MultiAttentionModel.

    Args:
        sequences: torch tensor of shape (batch_size, sequence_length).
        placements: torch tensor of shape (batch_size,) holding, for each row,
            the index at which the 0 token is written. Each entry must be
            smaller than ``sequence_length - 1`` so that a following token exists.

    Returns:
        sequences_with_0: torch tensor of shape (batch_size, sequence_length),
            a copy of ``sequences`` with 0 written at the given placements.
        targets: torch tensor of shape (batch_size,), the token immediately
            after the 0 in each row.
    """
    # Task 1
    # Keep every index tensor on the same device as the data being indexed.
    placements = placements.to(sequences.device)
    rows = torch.arange(sequences.size(0), device=sequences.device)

    # Write all zeros with one indexed assignment on a copy; the input is
    # left untouched.
    sequences_with_0 = sequences.clone()
    sequences_with_0[rows, placements] = 0

    # The target is read from the original sequences, one step after the 0.
    targets = sequences[rows, placements + 1]
    # End Task 1

    return sequences_with_0, targets

# Worked example used as a sanity check for generate_batch: zeros are written
# at columns 2 and 7, so the targets are the tokens at columns 3 and 8.
example_sequences = torch.tensor([
    [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
    [11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
])
example_placements = torch.tensor([2, 7])
example_sequences_with_0 = torch.tensor([
    [1, 2, 0, 4, 5, 6, 7, 8, 9, 10],
    [11, 12, 13, 14, 15, 16, 17, 0, 19, 20],
])
example_targets = torch.tensor([4, 19])
assert (generate_batch(example_sequences, example_placements)[0] == example_sequences_with_0).all()
assert (generate_batch(example_sequences, example_placements)[1] == example_targets).all()



def train(model, optimizer, batch_size, sequence_length, num_epochs, device):
    """Train ``model`` to predict the token that follows the 0 in each sequence.

    Every "epoch" is a single optimisation step on one freshly sampled batch.
    The loss is cross-entropy between the logits at the last sequence position
    and the token immediately after the 0. Loss and batch accuracy are printed
    every 100 steps.

    Args:
        model: MultiAttentionModel to train (moved to ``device``).
        optimizer: Optimizer over the model's parameters.
        batch_size: Number of sequences per step.
        sequence_length: Length of every generated sequence.
        num_epochs: Number of optimisation steps.
        device: Device on which the model and the batches are placed.
    """
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    # Sample tokens from the range the model's head can classify, instead of
    # relying on a module-level ``vocab_size`` defined elsewhere.
    vocab_size = model.head.out_features
    for epoch in range(num_epochs):
        model.train()
        sequences = generate_sequences(vocab_size, batch_size, sequence_length)
        placements = generate_0_placements(batch_size, sequence_length)
        batch, target = generate_batch(sequences, placements)
        batch = batch.to(device)
        target = target.to(device)

        optimizer.zero_grad()
        output = model(batch)
        # Only the prediction made at the final position is supervised.
        last_token_logits = output[:, -1, :]
        loss = criterion(last_token_logits, target)
        loss.backward()
        optimizer.step()

        if (epoch + 1) % 100 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}")
            predicted_token = torch.argmax(last_token_logits, dim=1)
            accuracy = (predicted_token == target).float().mean().item()
            print(f"Epoch {epoch+1}/{num_epochs}, Accuracy: {accuracy}")



# Hyperparameters for the model and its training run; the probe cells further
# down reuse several of them (d_model, learning_rate, batch_size, ...).
batch_size = 64
sequence_length = 10
vocab_size = 128
d_model = 16
num_heads = 1
num_layers = 2
# Number of optimisation steps: train() samples one fresh batch per "epoch".
num_epochs = 2000
learning_rate = 0.001
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Seed PyTorch's RNG before the model is constructed and training starts.
torch.manual_seed(0)
model = MultiAttentionModel(vocab_size, d_model, num_heads, num_layers, sequence_length)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
train(model, optimizer, batch_size, sequence_length, num_epochs, device)


In [None]:

def get_model_embedding_at_positions(model: "MultiAttentionModel", x, positions):
    """Return the embedding-layer output at one chosen position per sequence.

    Example: if model on x has embedding '[[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]]]'
    and positions is [0, 2] then the output is '[[0, 1], [10, 11]]'.

    Args:
        model: MultiAttentionModel whose ``full_embedding`` is evaluated.
        x: torch tensor of shape (batch_size, sequence_length).
        positions: torch tensor of shape (batch_size,).

    Returns:
        torch tensor of shape (batch_size, d_model), computed without gradient
        tracking.
    """
    # Task 2.1
    with torch.no_grad():
        embeddings = model.full_embedding(x)  # (batch_size, sequence_length, d_model)

    # Pair row i with positions[i]; keep both index tensors on the same device
    # as the activations they index.
    batch_indices = torch.arange(embeddings.size(0), device=embeddings.device)
    selected_embeddings = embeddings[batch_indices, positions.to(embeddings.device)]

    return selected_embeddings
    # End Task 2.1

def get_first_attention_output_at_positions(model, x, positions):
    """Return the first attention layer's output at one chosen position per sequence.

    Example: if first attention of model on x outputs '[[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]]]'
    and positions is [0, 2] then the output is '[[0, 1], [10, 11]]'.

    Args:
        model: MultiAttentionModel; ``full_embedding`` and ``layers[0]`` are evaluated.
        x: torch tensor of shape (batch_size, sequence_length).
        positions: torch tensor of shape (batch_size,).

    Returns:
        torch tensor of shape (batch_size, d_model), computed without gradient
        tracking.
    """
    # Task 2.2
    with torch.no_grad():
        x_embed = model.full_embedding(x)
        # Only the first layer is applied; later layers and the head are skipped.
        first_attention_output = model.layers[0](x_embed)  # (batch_size, sequence_length, d_model)

    # Pair row i with positions[i]; keep both index tensors on the same device
    # as the activations they index.
    batch_indices = torch.arange(first_attention_output.size(0), device=first_attention_output.device)
    selected_outputs = first_attention_output[batch_indices, positions.to(first_attention_output.device)]

    return selected_outputs
    # End Task 2.2



In [None]:
def _sample_negative_positions(placements, sequence_length):
    """Sample, per row, a position that is neither the 0 token nor the token right after it.

    Each row draws uniformly from the ``sequence_length - 2`` remaining
    positions: an offset below the placement is used as is, any other offset
    is shifted by two so that it skips the excluded pair.
    """
    if sequence_length < 3:
        # No other position exists; fall back to position 0.
        return torch.zeros_like(placements)
    offsets = torch.randint(0, sequence_length - 2, placements.shape, device=placements.device)
    return torch.where(offsets < placements, offsets, offsets + 2)


def train_probe(model, probe, optimizer, batch_size, sequence_length, num_epochs, device, model_latent_function):
    """Train ``probe`` to detect whether the previous token is 0.

    Every "epoch" is one optimisation step on a freshly sampled batch. For each
    sequence, the latent at the position right after the 0 is a positive
    example and the latent at one other random position (neither the 0 itself
    nor its successor) is a negative example, so the classes are balanced.
    Loss and batch accuracy are printed every 100 steps.

    Args:
        model: Trained MultiAttentionModel providing the latents (moved to ``device``).
        probe: Module mapping a (N, d_model) tensor to (N, 1) logits.
        optimizer: Optimizer over the probe's parameters.
        batch_size: Number of sequences per step (the probe sees twice as many examples).
        sequence_length: Length of every generated sequence.
        num_epochs: Number of optimisation steps.
        device: Device on which the model, probe and batches are placed.
        model_latent_function: either get_model_embedding_at_positions or
            get_first_attention_output_at_positions.
    """
    # Task 3
    model = model.to(device)
    probe = probe.to(device)
    criterion = nn.BCEWithLogitsLoss()
    # Sample tokens from the model's own vocabulary instead of relying on a
    # module-level ``vocab_size`` defined elsewhere.
    vocab_size = model.head.out_features

    # Labels are the same every step: first the positives, then the negatives.
    all_labels = torch.cat(
        [
            torch.ones(batch_size, 1, device=device),
            torch.zeros(batch_size, 1, device=device),
        ],
        dim=0,
    )

    for epoch in range(num_epochs):
        probe.train()

        # Generate sequences and placements
        sequences = generate_sequences(vocab_size, batch_size, sequence_length)
        placements = generate_0_placements(batch_size, sequence_length)
        batch, _ = generate_batch(sequences, placements)
        batch = batch.to(device)

        # Positive examples: the token immediately after the 0.
        positive_positions = placements + 1
        # Negative examples: any position except the 0 and its successor.
        negative_positions = _sample_negative_positions(placements, sequence_length)

        positive_latents = model_latent_function(model, batch, positive_positions)
        negative_latents = model_latent_function(model, batch, negative_positions)
        # Detach so that only the probe can ever receive gradients. No shuffle
        # is needed: the loss is a mean over examples, so order does not matter.
        all_latents = torch.cat([positive_latents, negative_latents], dim=0).detach()

        # Forward pass
        optimizer.zero_grad()
        predictions = probe(all_latents)
        loss = criterion(predictions, all_labels)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Log accuracy
        if (epoch + 1) % 100 == 0:
            # A logit above 0 corresponds to a predicted probability above 0.5.
            predicted_labels = (predictions.detach() > 0).float()
            accuracy = (predicted_labels == all_labels).float().mean().item()
            print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item():.4f}, Accuracy: {accuracy:.4f}")

    # End Task 3







In [None]:
# Probe trained on the embedding output. Per the task description this should
# stay around 50-60% accuracy.
probe = nn.Linear(in_features=d_model, out_features=1)
optimizer = torch.optim.Adam(params=probe.parameters(), lr=learning_rate)

train_probe(
    model=model,
    probe=probe,
    optimizer=optimizer,
    batch_size=batch_size,
    sequence_length=sequence_length,
    num_epochs=num_epochs,
    device=device,
    model_latent_function=get_model_embedding_at_positions,
)

In [None]:
# Fresh probe trained on the output of the first attention layer. Per the task
# description this should reach >90% accuracy.
probe = nn.Linear(in_features=d_model, out_features=1)
optimizer = torch.optim.Adam(params=probe.parameters(), lr=learning_rate)

train_probe(
    model=model,
    probe=probe,
    optimizer=optimizer,
    batch_size=batch_size,
    sequence_length=sequence_length,
    num_epochs=num_epochs,
    device=device,
    model_latent_function=get_first_attention_output_at_positions,
)
