The idea of this notebook is to study how linear certain layers are. Inspired by ["Your transformer is secretly linear"](https://arxiv.org/pdf/2405.12250)

In [1]:
from pathlib import Path
from typing import Callable, Optional
from dataclasses import dataclass

from opt_sim_dataset import ConsecutiveOutputsDataset 

import torch
import torch.nn as nn
import torch.nn.functional as F 
from torch.utils.data import DataLoader, random_split

from tqdm.notebook import tqdm

In [2]:
def normalize(x): 
    X = x - x.mean(dim=0, keepdim=True)
    return X / X.norm()

def get_A_est(X, Y):
    U, S, Vh = torch.linalg.svd(X, full_matrices=False)
    A_estimation = Vh.T * (1 / S)[None, ...] @ U.T @ Y # Y=XA
    return A_estimation

def get_est_svd(X, Y):
    """
    X -- torch tensor with shape [n_samples, dim]
    Y -- torch tensor with shape [n_samples, dim]

    Approximates Y matrix with linear transformation Y = XA
    """
    A_estimation = get_A_est(X, Y) 
    Y_est =  X @ A_estimation
    return Y_est

def compute_linearity_score(x, y):
    """
    x -- torch tensor with shape [n_samples, dim]
    y -- torch tensor with shape [n_samples, dim]
    """
    with torch.no_grad(): 
        X, Y = normalize(x), normalize(y)
        Y_estimation = get_est_svd(X, Y)
    
        y_error = (Y_estimation - Y).square().sum()
        sim = float(1 - y_error)
    return sim

In [3]:
data = ConsecutiveOutputsDataset(Path('./data'), 13, 500, 0) 

FileNotFoundError: [Errno 2] No such file or directory: 'data/block0/100'

In [None]:
block_embeddings = [[] for _ in range(data.num_layers)]
for x, v1, v2, blocks in data:
    block = blocks[0]
    block_embeddings[block].append(x)
    if block == data.num_layers - 2:
        block_embeddings[block + 1].append(v1)

In [None]:
block_embeddings = [torch.cat(x).reshape(-1, 768) for x in block_embeddings]

In [None]:
block_embeddings[0].shape

torch.Size([22986, 768])

In [None]:
compute_linearity_score(torch.randn_like(block_embeddings[0]), torch.randn_like(block_embeddings[0]))

0.03338080644607544

In [None]:
k = 10
for idx in range(data.num_layers - k):
    print(idx, compute_linearity_score(block_embeddings[idx], block_embeddings[idx + k]))

0 0.9724007844924927
1 0.9659294486045837
2 0.47418320178985596


In [None]:
def get_linearity_dist(block1, block2):
    n_samples = block1.size(0)
    assert n_samples == block2.size(0)
    X, Y = normalize(block1), normalize(block2)
    A_est = get_A_est(X, Y)
    linearities = torch.zeros(n_samples)
    for idx, (x, y) in enumerate(zip(X, Y)): 
        err = (A_est @ x - y).square().sum() 
        linearities[idx] = 1 - err 

    return linearities

In [None]:
get_linearity_dist(block_embeddings[0], block_embeddings[-1])

tensor([0.9998, 0.9999, 0.9999,  ..., 0.9999, 0.9999, 0.9999])

In [None]:
import seaborn as sns

sns.displot(get_linearity_dist(block_embeddings[2], block_embeddings[7]))

<seaborn.axisgrid.FacetGrid at 0x7f7081827c10>

Error in callback <function flush_figures at 0x7f6ff743b880> (for post_execute), with arguments args (),kwargs {}:


KeyboardInterrupt: 

In [None]:
class SimpleAutoEncoder(nn.Module):
    def __init__(self, embed_dim: int, hidden_size: int):
        super(SimpleAutoEncoder, self).__init__()

        self.fc1 = nn.Linear(embed_dim, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(embed_dim, hidden_size)
    
    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

In [None]:
N = len(data)
train_len = int(N * 0.7)
train_data, test_data = random_split(data, lengths=[train_len, N - train_len])
train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
test_loader = DataLoader(test_data, batch_size=16, shuffle=False)
tgt_block = 3

In [None]:
criterion = nn.MSELoss()
# model = SimpleAutoEncoder(768, 768).to('cuda')
model = nn.Linear(768, 768).to('cuda')
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
for epoch in range(10):
    epoch_loss = 0
    model.train()
    for x, v1, v2, blocks in train_loader:
        x_blocks = blocks[:, 0]
        indices = x_blocks == tgt_block 
        if indices.sum() == 0:
            continue
        
        optimizer.zero_grad()
        outputs = model(x[indices])
        loss = criterion(outputs , v1[indices])
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
    
    test_loss = 0
    model.eval()
    for x, v1, v2, blocks in test_loader:
        x_blocks = blocks[:, 0]
        indices = x_blocks == tgt_block
        if indices.sum() == 0:
            continue
    
        outputs = model(x[indices])
        loss = criterion(outputs, v1[indices])
        test_loss += loss.item()

    print(f'{epoch=}, Train loss: {epoch_loss / len(train_data):.4e}, Test loss {test_loss / len(test_data):.4e}')

epoch=0, Train loss: 5.2948e-02, Test loss 4.9223e-02
epoch=1, Train loss: 3.6105e-02, Test loss 3.8368e-02
epoch=2, Train loss: 2.7115e-02, Test loss 3.0896e-02
epoch=3, Train loss: 2.3295e-02, Test loss 3.7011e-02
epoch=4, Train loss: 1.3110e-02, Test loss 1.5223e-02
epoch=5, Train loss: 1.0085e-02, Test loss 9.3757e-03
epoch=6, Train loss: 9.0032e-03, Test loss 6.2763e-03
epoch=7, Train loss: 3.3769e-03, Test loss 3.0761e-03
epoch=8, Train loss: 1.7365e-03, Test loss 2.2186e-03
epoch=9, Train loss: 2.6213e-03, Test loss 2.7589e-03


In [None]:
for idx in range(len(block_embeddings)):
    print(idx, block_embeddings[idx].norm(dim=-1).mean()) 

0 tensor(1.6551, device='cuda:0')
1 tensor(2.9450, device='cuda:0')
2 tensor(3.6688, device='cuda:0')
3 tensor(5.0450, device='cuda:0')
4 tensor(5.4629, device='cuda:0')
5 tensor(6.0037, device='cuda:0')
6 tensor(6.3741, device='cuda:0')
7 tensor(6.8520, device='cuda:0')
8 tensor(7.4376, device='cuda:0')
9 tensor(8.3768, device='cuda:0')
10 tensor(9.7573, device='cuda:0')
11 tensor(11.5002, device='cuda:0')
12 tensor(9.8634, device='cuda:0')


In [None]:
F.cosine_similarity(block_embeddings[0][0], block_embeddings[1][0], dim=0)

NameError: name 'F' is not defined

In [None]:
for idx in range(len(block_embeddings) - 1):
    print(idx, F.cosine_similarity(block_embeddings[idx], block_embeddings[idx + 1], dim=-1).mean())

0 tensor(0.7339, device='cuda:0')
1 tensor(0.9434, device='cuda:0')
2 tensor(0.9602, device='cuda:0')
3 tensor(0.9710, device='cuda:0')
4 tensor(0.9579, device='cuda:0')
5 tensor(0.9495, device='cuda:0')
6 tensor(0.9486, device='cuda:0')
7 tensor(0.9476, device='cuda:0')
8 tensor(0.9338, device='cuda:0')
9 tensor(0.9226, device='cuda:0')
10 tensor(0.9309, device='cuda:0')
11 tensor(0.9041, device='cuda:0')
