## P47 Wormhole 
- Cleaned-up version, with the original parameter-replacement bug fixed

In [1]:
# ! pip install transformers matplotlib tqdm huggingface_hub

In [2]:
# from huggingface_hub import login
# login()

In [3]:
import torch
from transformers import pipeline
from torch.nn import functional as F
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib import cm
import os
import copy
from collections import OrderedDict

from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config
from transformers import LlamaForCausalLM, PreTrainedTokenizerFast, LlamaConfig
from transformers import AutoModelForCausalLM, AutoTokenizer

# Prefer the GPU, but fall back to CPU so the notebook still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
model_id = "meta-llama/Llama-3.2-1B"

# Pretrained tokenizer and weights, both resolved from the same checkpoint id.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
model = model.to(device)
model.eval();  # trailing ';' keeps the notebook from echoing the module repr

## Configuration for this run

In [5]:
# Run configuration: every value below is read by later cells.
output_dir='/workspace/apr_26_2'  # directory that receives the saved loss arrays and PNG frames
num_points=32  # number of grid samples along each of the two scan directions
n_steps=128  # number of optimization steps taken by the training loop
lr=1e-7  # learning rate handed to the Adam optimizer
delayed_viz_start=64 #e.g. set to 10 if I only want to start rendering from optimization step 10 onward (steps with step >= this value are rendered)

## Support Functions

In [6]:
def get_random_directions(params, seed=None):
    """
    Draw one Gaussian random tensor per trainable parameter.

    Args:
        params: Iterable of (name, parameter) pairs, e.g. from model.named_parameters()
        seed: Optional seed; when given, the global torch and numpy RNGs are
            both re-seeded with it before any tensor is sampled

    Returns:
        OrderedDict mapping each trainable parameter's name to a tensor drawn
        with torch.randn_like for that parameter, in the order params was given
    """
    reseed = seed is not None
    if reseed:
        torch.manual_seed(seed)
        np.random.seed(seed)

    # Parameters with requires_grad == False get no direction entry.
    return OrderedDict(
        (name, torch.randn_like(param.data))
        for name, param in params
        if param.requires_grad
    )

def normalize_direction(direction, params):
    """
    Rescale every direction tensor so its norm equals that of its parameter.

    Args:
        direction: OrderedDict mapping parameter names to direction tensors
        params: Iterable of (name, parameter) pairs; must contain every name in direction

    Returns:
        OrderedDict with the same keys and order as direction, holding the rescaled tensors
    """
    weights_by_name = OrderedDict(params)

    def _rescale(name, vec):
        target_norm = torch.norm(weights_by_name[name].data)
        vec_norm = torch.norm(vec)
        # A direction without positive length cannot be rescaled; hand it back untouched.
        if not (vec_norm > 0):
            return vec
        return vec * (target_norm / vec_norm)

    return OrderedDict((name, _rescale(name, vec)) for name, vec in direction.items())

### Set up the example and run some computation checks

In [7]:
# Single probe sentence used for every check, scan and training step in this notebook.
text = "The capital of France is Paris"
inputs = tokenizer(text, return_tensors="pt")
inputs = inputs.to(device)
input_ids = inputs["input_ids"]

In [8]:
# Sanity check: rebuild the model's reported loss by hand from its logits.
with torch.no_grad():
    outputs = model(input_ids, labels=input_ids)

my_probs = outputs.logits.softmax(dim=-1)
y_one_hot = F.one_hot(input_ids, num_classes=model.config.vocab_size)

# The prediction made at position t is scored against the token at t+1,
# so drop the last prediction and the first target before multiplying.
# (A one-hot product is wasteful but easy to read.)
shifted_probs = my_probs[:, :-1]
shifted_targets = y_one_hot[:, 1:]
correct_next_token_probs = (shifted_probs * shifted_targets).sum(dim=-1)
my_loss = -(correct_next_token_probs.log().mean())
print(my_loss.item(), outputs.loss.item())

3.3751845359802246 3.3751840591430664


In [9]:
# Same hand-computed loss as the previous check, plus the loss restricted to
# the single ' Paris' prediction.
with torch.no_grad():
    outputs = model(input_ids, labels=input_ids)

my_probs = outputs.logits.softmax(dim=-1)
y_one_hot = F.one_hot(input_ids, num_classes=model.config.vocab_size)

# Align each prediction with the token that follows it.
# (A one-hot product is wasteful but easy to read.)
shifted_probs = my_probs[:, :-1]
shifted_targets = y_one_hot[:, 1:]
correct_next_token_probs = (shifted_probs * shifted_targets).sum(dim=-1)
my_loss = -(correct_next_token_probs.log().mean())

# Negative log-probability of token id 12366 at sequence position 5.
paris_prob = my_probs[0, 5, 12366].item()
paris_only_loss = -np.log(paris_prob)
print(my_loss.item(), outputs.loss.item(), paris_only_loss)

3.3751845359802246 3.3751840591430664 0.9376922065287221


In [10]:
# Rank the whole vocabulary by probability at position 5 and list the ten best guesses.
position_probs = my_probs[0, 5, :].detach().cpu().float().numpy()
sI = np.argsort(position_probs)[::-1]
for i in sI[:10]:
    token_prob = round(my_probs[0, 5, i].item(), 5)
    print(i, token_prob, tokenizer.decode([i]))

12366 0.39153  Paris
264 0.08419  a
279 0.0704  the
832 0.03096  one
1101 0.03061  also
2162 0.02528  home
3967 0.02462  known
539 0.01659  not
459 0.01241  an
7559 0.01172  located


In [11]:
prefix='pretrained_'

# Start from every trainable tensor, then narrow down to the slice being scanned.
filtered_params = [(pname, tensor) for pname, tensor in model.named_parameters() if tensor.requires_grad]
# layers_name='all'

layers_name='first_8'
# NOTE(review): [1:73] presumably skips the embedding at index 0 and keeps the
# tensors of the first 8 transformer blocks - confirm against model.named_parameters().
filtered_params = filtered_params[1:73]

# layers_name='last_8'
# filtered_params = filtered_params[73:] #Last 8 layers - some nice structure, but yeah more parabolic than I would like

random_seed_1=11
random_seed_2=111

# Two seeded random directions, each rescaled tensor-by-tensor to the norm of
# the parameter it perturbs.
direction1 = normalize_direction(
    get_random_directions(filtered_params, seed=random_seed_1), filtered_params
)
direction2 = normalize_direction(
    get_random_directions(filtered_params, seed=random_seed_2), filtered_params
)

# Snapshot of the untouched pretrained weights for the selected tensors.
original_params = OrderedDict(
    (pname, tensor.data.clone()) for pname, tensor in filtered_params
)

# Scan coordinates along direction1 (alphas) and direction2 (betas).
alphas=np.linspace(-2.5, 2.5, num_points)
betas=np.linspace(-2.5, 2.5, num_points)

In [12]:
# Adam is handed every model parameter, not only the tensors selected for the scan.
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

# Make sure the destination for saved landscapes and frames exists.
os.makedirs(output_dir, exist_ok=True)

In [13]:
#Move away from center
alpha_shift=-0.9 
beta_shift=0.05

alphas_shifted=alphas-alpha_shift #Shift scan points to keep thing consistent. 
betas_shifted=betas-beta_shift

#Replace actual model parameters with the shifted ones. 
for name, param in model.named_parameters():
    if name in direction1 and name in direction2:
        param.data = original_params[name] + alpha_shift * direction1[name] + beta_shift * direction2[name]

#Make copy for scaning/replacing. 
original_params_shifted = OrderedDict()
for name, param in filtered_params:
    original_params_shifted[name] = param.data.clone()

In [None]:
# Optimization loop. Each step:
#   1. records the model's current top predictions for the ' Paris' slot,
#   2. (from step delayed_viz_start onward) scans the 2D loss landscape around
#      the current weights and writes it to output_dir as .npy and .png,
#   3. takes one Adam step on the full-sequence loss.
model_outputs=[]
for step in range(n_steps):
    losses=[]
    model.eval()

    with torch.no_grad(): #Check current outputs
        outputs = model(input_ids, labels=input_ids)
        my_probs=F.softmax(outputs.logits, dim=-1)
        # Position 5 is the prediction being tracked; 12366 is the ' Paris' token id.
        sI=np.argsort(my_probs[0,5, :].detach().cpu().float().numpy())[::-1]
        current_outs=[[12366,  round(my_probs[0, 5, 12366].item(), 7), ' Paris']] #Put paris at top
        for i in sI[:10]:
            current_outs.append([i, round(my_probs[0, 5, i].item(),7), tokenizer.decode([i])])
        model_outputs.append(current_outs)
        print(step, 'loss=', -np.log(my_probs[0, 5, 12366].item()), current_outs[0], current_outs[1])

    if step>=delayed_viz_start: #Do I want to compute loss landscape at this step?
        with torch.no_grad():
            # losses[i][j] is the Paris-only loss at (alphas_shifted[i], betas_shifted[j]).
            for i, alpha in enumerate(tqdm(alphas_shifted)):
                losses.append([])
                for j, beta in enumerate(betas_shifted):
                    # Only tensors with a direction entry are moved; all other
                    # parameters stay at their current (trained) values.
                    for name, param in model.named_parameters():
                        if name in direction1:
                            param.data = original_params_shifted[name] + alpha * direction1[name] + beta*direction2[name]

                    outputs = model(input_ids, labels=input_ids)
                    my_probs=F.softmax(outputs.logits, dim=-1)
                    paris_only_loss=-np.log(my_probs[0, 5, 12366].item()) #Just Paris
                    losses[-1].append(paris_only_loss)

            for name, param in model.named_parameters(): # Restore the weights the scan started from
                if name in original_params_shifted:
                    param.data.copy_(original_params_shifted[name])
        losses=np.array(losses)
        np.save(output_dir +'/'+str(step).zfill(3), losses) #Save loss landscape

        fig, ax = plt.subplots(figsize=(10, 8))
        # Rows of `losses` follow alpha and columns follow beta, so beta is the
        # first (horizontal) coordinate and alpha the second (vertical) one.
        # This matches the marker drawn at (beta_shift, alpha_shift) below.
        contourf = ax.contourf(betas, alphas, losses, 20, cmap='viridis', alpha=0.8)
        contour = ax.contour(betas, alphas, losses, 30, colors='white', linewidths=0.5)
        ax.scatter(beta_shift, alpha_shift, c='m')
        fig.savefig(output_dir +'/'+str(step).zfill(3)+'.png')
        # One figure is created per rendered step; close it once the frame is on
        # disk so open figures do not pile up over the run.
        plt.close(fig)

    model.train()
    optimizer.zero_grad()
    outputs = model(**inputs, labels=inputs['input_ids'])
    loss = outputs.loss #Ok not just paris loss here -> not sure how much I'm worried about that
    loss.backward()
    optimizer.step()

    #After training I need to replace original_params_shifted with the new trained values
    original_params_shifted = OrderedDict()
    for name, param in filtered_params:
        original_params_shifted[name] = param.data.clone()

0 loss= 12.032893424603056 [12366, 5.9e-06, ' Paris'] [37180, 0.0486275, 'adar']
1 loss= 11.69476953080568 [12366, 8.3e-06, ' Paris'] [37180, 0.047558, 'adar']
2 loss= 11.35649891006215 [12366, 1.17e-05, ' Paris'] [37180, 0.0464448, 'adar']
3 loss= 11.019169033754162 [12366, 1.64e-05, ' Paris'] [37180, 0.0452946, 'adar']
4 loss= 10.683672679717075 [12366, 2.29e-05, ' Paris'] [37180, 0.0441116, 'adar']
5 loss= 10.350730993766888 [12366, 3.2e-05, ' Paris'] [37180, 0.0429004, 'adar']
6 loss= 10.020812051325684 [12366, 4.45e-05, ' Paris'] [37180, 0.0416649, 'adar']
7 loss= 9.694163320889235 [12366, 6.16e-05, ' Paris'] [37180, 0.0404064, 'adar']
8 loss= 9.370973957108722 [12366, 8.52e-05, ' Paris'] [37180, 0.0391275, 'adar']
9 loss= 9.05123692036561 [12366, 0.0001172, ' Paris'] [37180, 0.0378297, 'adar']
10 loss= 8.734785279797308 [12366, 0.0001609, ' Paris'] [37180, 0.0365111, 'adar']
11 loss= 8.421417450904402 [12366, 0.0002201, ' Paris'] [37180, 0.0351736, 'adar']
12 loss= 8.110884671537

 34%|███▍      | 11/32 [00:16<00:31,  1.51s/it]

In [None]:
# tokenizer.decode(12366)


In [None]:
# Show the token table built on the most recent loop step: the ' Paris' entry
# first, followed by the ten most probable tokens at position 5.
current_outs