In [None]:
class TrackstarUnbatched:
    """Per-sample gradient extractor that groups, normalizes and projects gradients.

    ``second_moment`` (dict: block key -> tensor) and ``projector`` (an object
    exposing ``project_per_block``) start as ``None`` and must be assigned by
    the caller before ``compute_gradients`` is used.
    """

    def __init__(self, model, device='cuda', eps=1e-8):
        """Move ``model`` to ``device``, put it in eval mode and reset state.

        Args:
            model: module called as ``model(input_ids=..., attention_mask=...)``
                whose result exposes ``.logits``.
            device: device the model and the inputs are moved to.
            eps: constant added to the second moment before dividing by it.
        """
        # NOTE(review): an earlier comment here spoke of disabling caching, but
        # no cache setting is changed; the model is only moved and set to eval.
        self.model = model.to(device).eval()
        self.device = device
        self.eps = eps
        self.second_moment = None
        # Summed loss; it is divided by the number of unmasked targets later.
        self.loss_fn = nn.CrossEntropyLoss(reduction='sum')

        self.projector = None
        self.counter = 0
        # Grows by one entry per compute_gradients() call and is never cleared here.
        self.grouped_grads = []
        print('Initialized TrackStar Class')

    @staticmethod
    def _block_key(name, group_size):
        """Map a parameter name to its gradient-block key, or ``None`` to skip it."""
        if name.startswith('transformer.h.'):
            layer = int(name.split('.')[2])
            # Anything in a layer whose name lacks 'attn' lands in the 'mlp' bucket.
            kind = 'attn' if 'attn' in name else 'mlp'
            return f'group{layer // group_size}_{kind}'
        if name.startswith('transformer.ln_f'):
            return 'final_ln'
        if name.startswith('transformer.lm_head'):
            return 'lm_head'
        return None

    def compute_gradients(self, sample, group_size=2, prompt_length=0):
        """Compute, normalize and project the block gradients of one sample.

        Args:
            sample: mapping with ``'input_ids'`` (nested list or tensor). A 2-D
                ``(batch, seq)`` layout is expected; a 1-D sequence is treated
                as a batch of one.
            group_size: number of consecutive ``transformer.h`` layers per block.
            prompt_length: number of leading label positions excluded from the loss.

        Returns:
            ``self.grouped_grads`` - the running list of every result computed
            so far; the entry for this call is the last element.

        Raises:
            RuntimeError: ``second_moment`` or ``projector`` has not been assigned.
            ValueError: no label position is left unmasked, which would otherwise
                produce a 0/0 loss and NaN gradients.
        """
        if self.second_moment is None or self.projector is None:
            raise RuntimeError(
                "second_moment and projector must be assigned before calling compute_gradients"
            )

        self.model.zero_grad()
        # as_tensor avoids the extra copy (and warning) torch.tensor() causes
        # when input_ids is already a tensor; input_ids is never mutated below.
        input_ids = torch.as_tensor(sample['input_ids']).to(self.device)
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)
        attention_mask = torch.ones_like(input_ids)

        # Next-token setup: position j of the input predicts token j + 1.
        x_in = input_ids[:, :-1]
        y = input_ids[:, 1:]
        m = attention_mask[:, :-1]
        labels = y.clone()
        # -100 is the default ignore_index of CrossEntropyLoss.
        # NOTE(review): labels are already shifted by one, so this also masks the
        # first token that follows a prompt of `prompt_length` tokens - confirm intended.
        labels[:, :prompt_length] = -100

        num_targets = (labels != -100).sum()
        if num_targets.item() == 0:
            raise ValueError(
                f"prompt_length={prompt_length} masks every target position "
                f"(sequence has {labels.size(1)} label positions); nothing to differentiate"
            )

        logits = self.model(input_ids=x_in, attention_mask=m).logits
        loss = self.loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
        loss = loss / num_targets  # mean over unmasked targets
        loss.backward()

        # Collect flattened per-parameter gradients, bucketed by block key.
        block_grads = defaultdict(list)
        for name, p in self.model.named_parameters():
            if p.grad is None or 'wte' in name:
                continue
            key = self._block_key(name, group_size)
            if key is None:
                continue
            block_grads[key].append(p.grad.detach().flatten())
        block_grads = {k: torch.cat(v, dim=0) for k, v in block_grads.items()}

        # Element-wise normalization by the matching second-moment block.
        block_grads = {
            key: g / (self.second_moment[key].to(self.device) + self.eps)
            for key, g in block_grads.items()
        }
        block_grads = self.projector.project_per_block(block_grads)

        self.grouped_grads.append(block_grads)
        return self.grouped_grads


def compute_block_shapes(grouped_second_moment, embedding_dim):
    """Derive a 2-D ``(rows, embedding_dim)`` shape for every block.

    Args:
        grouped_second_moment: mapping of block key to an object exposing
            ``numel()`` (the total element count of the block).
        embedding_dim: column count every block is factored by.

    Returns:
        Dict mapping each block key to ``(numel // embedding_dim, embedding_dim)``,
        in the iteration order of the input mapping.

    Raises:
        ValueError: a block's element count is not a multiple of ``embedding_dim``.
    """
    cols = embedding_dim
    shapes = {}
    for block_name in grouped_second_moment:
        numel = grouped_second_moment[block_name].numel()
        rows, leftover = divmod(numel, cols)
        if leftover != 0:
            raise ValueError(f"Dimension mismatch in {block_name}: total {numel} not divisible by embedding_dim {cols}")
        shapes[block_name] = (rows, cols)
    return shapes

from transformers import DataCollatorWithPadding
from torch.utils.data import DataLoader
from datasets import load_from_disk
from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2LMHeadModel, GPT2Tokenizer
import torch


In [None]:

## Second-moment computation and grouping.
## Compute block shapes from the grouped second moments.
## Use these block shapes to initialize a BlockProjector, which creates a pair of projection matrices for each block.

from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
# Load the model weights from the local 'out/wiki_model' directory.
model = GPT2LMHeadModel.from_pretrained('out/wiki_model')
# The tokenizer comes from the stock 'gpt2' files, not from the model directory.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token
from collections import defaultdict
import torch




In [None]:

# Load the saved second-moment dictionary onto the CPU.
# NOTE(review): torch.load unpickles the file - only load files from a trusted source.
second_moment = torch.load('data/trackstar/wiki/second_moment.pt', map_location='cpu')
# Print every key with the shape of its tensor as a quick sanity check.
for k, vec in second_moment.items():
    print(f"{k:12s} → {tuple(vec.shape)}")



In [None]:

from collections import defaultdict
import torch
import torch.nn as nn
import re
import math


class BlockProjector:
    """Two-sided random projection of per-block gradient vectors.

    For every block of shape ``(m, n)`` two Gaussian matrices are drawn:
    ``P0`` of shape ``(sqrt_d, m)`` and ``P1`` of shape ``(sqrt_d, n)``. A block
    vector is reshaped to ``(m, n)`` and mapped to ``P0 @ W @ P1.T``, which is
    flattened to ``sqrt_d * sqrt_d`` entries.
    """

    def __init__(self, block_shapes, d=4096, device='cpu'):
        """Draw the projection matrices for every block.

        Args:
            block_shapes: mapping of block key to ``(m, n)``.
            d: target projection size. ``sqrt_d`` is the floor of its square
                root, so the projected length equals ``d`` only when ``d`` is a
                perfect square.
            device: device the projection matrices are created on.
        """
        self.d = d
        self.sqrt_d = int(math.sqrt(d))
        self.device = device
        self.proj_matrices = {}
        # Reseeds the *global* torch RNG so the matrices are reproducible.
        # The draws are sequential, so the matrices also depend on the
        # iteration order of `block_shapes`.
        torch.manual_seed(0)
        scale = math.sqrt(self.sqrt_d)
        for key, (m, n) in block_shapes.items():
            P0 = torch.randn(self.sqrt_d, m, device=device) / scale
            P1 = torch.randn(self.sqrt_d, n, device=device) / scale
            self.proj_matrices[key] = (P0, P1)

    def project_per_block(self, block_grads):
        """Project every block vector down to ``sqrt_d * sqrt_d`` entries.

        Args:
            block_grads: mapping of block key to a tensor holding ``m * n``
                elements for that block.

        Returns:
            Dict with the same keys; each value is a 1-D tensor on the device
            and in the dtype of the projection matrices.

        Raises:
            KeyError: a key has no projection matrices.
            RuntimeError: a tensor cannot be reshaped to the block's ``(m, n)``.
        """
        out = {}
        for key, vec in block_grads.items():
            P0, P1 = self.proj_matrices[key]
            m, n = P0.shape[1], P1.shape[1]
            # Align the input with the projection matrices so a gradient from
            # another device or in another dtype does not make the matmul fail;
            # this is a no-op when they already match.
            W = vec.to(device=P0.device, dtype=P0.dtype).reshape(m, n)
            out[key] = (P0 @ W @ P1.T).flatten()   # → [sqrt_d * sqrt_d]
        return out

def group_second_moment(second_moment: dict[str, torch.Tensor],
                        group_size: int) -> dict[str, torch.Tensor]:
    """Merge per-layer second-moment vectors into coarser layer groups.

    Keys of the form ``group<N>_attn`` / ``group<N>_mlp`` are renamed to
    ``group<N // group_size>_<type>``; every other key is kept unchanged.
    Vectors that end up under the same key are concatenated along dim 0 in the
    order they appear in ``second_moment``, and the output keys follow the
    order in which each target key is first seen.
    """
    layer_key = re.compile(r"^group(\d+)_(attn|mlp)$")

    def _target_key(name):
        # Non-layer entries (anything the pattern does not match) pass through.
        hit = layer_key.match(name)
        if hit is None:
            return name
        index, kind = hit.groups()
        return f"group{int(index) // group_size}_{kind}"

    merged = {}
    for name, tensor in second_moment.items():
        merged.setdefault(_target_key(name), []).append(tensor)

    # One flat vector per target key.
    return {name: torch.cat(parts, dim=0) for name, parts in merged.items()}



# Merge the per-layer second moments into groups of 2 layers; this must match
# the group_size later passed to compute_gradients.
grouped_second_moment = group_second_moment(second_moment, 2)
# Show the merged keys and their vector shapes.
for k, vec in grouped_second_moment.items():
    print(k, vec.shape)
# Factor every block into (rows, n_embd) and draw the projection matrices.
block_shapes = compute_block_shapes(grouped_second_moment, model.config.n_embd)
proj = BlockProjector(block_shapes, d=4096, device='cpu')

## Compute projected gradients
## Go through the dataset in batches.
## For each batch, compute the gradient of every sample in the batch and group the gradients by the blocks defined above.
## Normalize each gradient block by its corresponding second-moment block.
## Each gradient is a dictionary keyed by grouped block name, which is passed to the projector.
## The projector (initialized above) projects each gradient dictionary, and the result is appended to a list of gradients.
## Once the list reaches a certain cutoff, store it in a file.

# Build the gradient extractor on the CPU, then attach the two attributes its
# constructor leaves as None: the grouped second moments and the projector.
trackstar = TrackstarUnbatched(model, device='cpu')
trackstar.second_moment = grouped_second_moment
trackstar.projector = proj


In [None]:
# Load the saved per-block matrices onto the CPU.
# NOTE(review): the file name suggests these are inverse square roots of
# autocorrelation matrices - confirm against the code that wrote the file.
auto_corr_matrix = torch.load('data/trackstar/wiki/gradients/autocorr_matrices_inv_sqrt.pt', map_location='cpu')
print(auto_corr_matrix.keys())
# NOTE(review): `device` is not read by any of the cells visible here, which
# hard-code 'cpu' - confirm whether later cells use it.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Names of the query examples; each one selects a checkpoint file to read and
# a gradient file to write.
examples = [
    'everest',
    'jazz',
    'ancient_rome',
    'dna',
    'feminism',
    'impressionism',
    'internet',
    'philosophy_mind',
    'solar_system',
    'ww2',
    'thermodynamics',
    'iss_dot_model',
    'ww1_dot_model',
    'ancient_egypt',
    'ancient_greece',
    'art_deco',
    'big_bang',
    'buddhism',
    'democracy',
    'ecology',
    'genetics',
    'gothic_architecture',
    'probability',
    'renaissance',
    'shakespeare'
]

for example in examples:
    ckpt = torch.load(f'out/wiki_models_finetuned/fisher_regularized_models/{example}_finetuned_fisher.pt', map_location='cpu')
    output_ids = ckpt['output_ids']
    output_text = tokenizer.decode(output_ids[0])
    print(output_text)

    # Treat everything up to and including the first '.' as the prompt; its
    # token count is the number of label positions excluded from the loss.
    # NOTE(review): if the text contains no '.', the whole text becomes the prompt.
    split_text = output_text.split('.')[0]
    prompt = split_text + '.'
    prompt_length = len(tokenizer.encode(prompt))
    sample = {
        'input_ids': output_ids
    }
    gradient_query_sample = trackstar.compute_gradients(sample, group_size=2, prompt_length=prompt_length)

    # compute_gradients returns the running list; this sample's result is last.
    latest_gradient = gradient_query_sample[-1]
    sample_gradient = {}
    for key, R in auto_corr_matrix.items():
        # as_tensor accepts tensors and arrays without the extra copy (and
        # warning) that torch.tensor() produces for an existing tensor.
        grad = torch.as_tensor(R).to('cpu') @ latest_gradient[key].to('cpu')
        # L2-normalize; normalize() clamps the norm from below, so an all-zero
        # gradient stays zero instead of turning into NaN.
        sample_gradient[key] = torch.nn.functional.normalize(grad, dim=0)
    torch.save(sample_gradient, f'data/trackstar/wiki/testing/gradients/{example}_gradient.pt')