# Overview

This is a toy project for exploring mechanistic interpretability methods.

In [2]:
import math
import torch
import torch.nn as nn
from torch.nn import functional as F

In [None]:
class SelfAttention(nn.Module):
    """
    Single-head, non-causal self-attention over a sequence of token embeddings.

    Informed by Andrej Karpathy's minGPT project (https://github.com/karpathy/minGPT)
    and a slightly embarrassing conversation with ChatGPT-5.

    How the mechanism works:
    - The input is a sequence of n_dim-dimensional vectors (the embedded tokens).
    - Each vector is projected into a query, a key and a value.
    - Every query is scored against every key in the sequence with a dot product.
      No causal mask is applied, so each position sees the whole sequence; that
      is fine here because the model is not used for generation.
    - The scores are softmax-normalised per query and used as weights to mix the
      value vectors.
    """

    def __init__(self, n_dim=10, scale=False):
        """
        Build the query, key and value projections.

        Multi-head attention normally shards the input dimensions across heads.
        For simplicity this module uses a single head spanning the full input
        dimension.

        Args:
            n_dim: Size of each input vector, and of each projection's output.
            scale: When true, divide the raw scores by sqrt(n_dim) before the
                softmax.
        """
        super().__init__()
        self.n_dim = n_dim
        self.scale = scale

        # Created in Q, K, V order so seeded initialisation stays reproducible.
        self.Q = nn.Linear(n_dim, n_dim)
        self.K = nn.Linear(n_dim, n_dim)
        self.V = nn.Linear(n_dim, n_dim)

    def forward(self, x):
        """
        Apply self-attention to x.

        No output projection is needed because there is only one head. With
        several heads, a linear layer would be required to map the concatenated
        heads back into the d_model space.

        Args:
            x: Input of shape (batch, seq_len, n_dim).

        Returns:
            A tuple (out, attn): out is the attention-weighted mix of the value
            vectors, and attn holds the softmax-normalised weights, where each
            row (taken over the last dimension) sums to 1.
        """
        # Project the input into query, key and value space.
        queries = self.Q(x)
        keys = self.K(x)
        values = self.V(x)

        # Dot-product similarity of every query with every key.
        scores = queries @ keys.transpose(-2, -1)

        # Optional scaling keeps the logits from growing large enough to
        # saturate the softmax (which would kill gradients during backprop).
        # For tiny models this is probably not an issue, so it can be omitted.
        if self.scale:
            scores = scores / math.sqrt(self.n_dim)

        # Normalise the logits into per-query weights over all tokens, then use
        # them to blend the value vectors.
        weights = torch.softmax(scores, dim=-1)
        return weights @ values, weights


In [3]:
import numpy as np


In [29]:
# Small random 1x3x3 tensor to experiment with.
test = torch.rand((1, 3, 3))

In [30]:
# Inspect the shape of the sample tensor.
test.shape

torch.Size([1, 3, 3])

In [31]:
# A 3 -> 3 linear layer, used to inspect its parameters and output shape.
linear = nn.Linear(in_features=3, out_features=3)

In [32]:
# Print the layer's learnable tensors: the 3x3 weight matrix, then the bias.
for param in linear.parameters():
    print(param)

Parameter containing:
tensor([[-0.0082, -0.0069,  0.1016],
        [ 0.2007, -0.2645, -0.2846],
        [-0.0276,  0.2386,  0.5372]], requires_grad=True)
Parameter containing:
tensor([-0.5452, -0.2175, -0.2108], requires_grad=True)


In [33]:
# Run the sample tensor through the linear layer.
out = linear(test)

In [34]:
# The layer's output has the same (1, 3, 3) shape as the input.
out.shape

torch.Size([1, 3, 3])

In [37]:
# transpose(-1, 0) swaps the last and first dims: (1, 3, 3) -> (3, 3, 1).
# This is a different swap from the transpose(-2, -1) applied to the keys in
# SelfAttention.forward, which exchanges the last two dims.
test.transpose(-1,0).shape

torch.Size([3, 3, 1])