In [35]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from sklearn.decomposition import PCA
import plotly.express as px
import plotly.graph_objects as go

In [2]:
def load_model(model_name):
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer

model, tokenizer = load_model("allenai/OLMoE-1B-7B-0924")


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [7]:
input_text = "what is principal component analysis ?"
input_ids = tokenizer.encode(input_text, return_tensors="pt")

outputs = model(input_ids, output_router_logits=True)

output = model.generate(input_ids, max_length=70)
print("Generated text :")
for token in output[0]:
    print(tokenizer.decode(token), end='', flush=True)
print()

# Print router logits for each layer  
print("\nRouter logits per layer :")
for i, router_logits in enumerate(outputs.router_logits):
    print(f"Layer {i} router logits shape: {router_logits.shape}")

Generated text :
what is principal component analysis?

Principal component analysis (PCA) is a statistical method for reducing the dimensionality of a data set. It is a technique for extracting the maximum variance from a data set. It is a statistical method for extracting the maximum variance from a data set.

PCA is a statistical method for extracting the maximum

Router logits per layer :
Layer 0 router logits shape: torch.Size([6, 64])
Layer 1 router logits shape: torch.Size([6, 64])
Layer 2 router logits shape: torch.Size([6, 64])
Layer 3 router logits shape: torch.Size([6, 64])
Layer 4 router logits shape: torch.Size([6, 64])
Layer 5 router logits shape: torch.Size([6, 64])
Layer 6 router logits shape: torch.Size([6, 64])
Layer 7 router logits shape: torch.Size([6, 64])
Layer 8 router logits shape: torch.Size([6, 64])
Layer 9 router logits shape: torch.Size([6, 64])
Layer 10 router logits shape: torch.Size([6, 64])
Layer 11 router logits shape: torch.Size([6, 64])
Layer 12 route

In [8]:
def get_router_logits(model, input_ids):
    """get router logits from model forward pass"""
    outputs = model(input_ids, output_router_logits=True)
    return outputs.router_logits

def get_last_token_router_probs(router_logits, layer_idx):
    """Get router probabilities for the last token in a specified layer"""
    layer_logits = router_logits[layer_idx]  # Shape: [sequence_length, num_experts]
    last_token_logits = layer_logits[-1]  # get last token logits
    routing_probs = torch.nn.functional.softmax(last_token_logits, dim=-1)
    return routing_probs

def topk(router_probs, k):
    """zero out all components except top k router probabilities"""
    values, indices = torch.topk(router_probs, k)
    zeroed_probs = torch.zeros_like(router_probs)
    zeroed_probs[indices] = values
    return zeroed_probs


layer_idx = 0

router_logits = get_router_logits(model, input_ids)
probs = get_last_token_router_probs(outputs.router_logits, layer_idx)
print(f"router probs shape: {probs.shape}, sum: {probs.sum():.2f}")
print(f'router probs : {probs}')
top_probs = topk(probs, k=8) 
print(f"top 8 probs : {top_probs}")
print(f"top 8 probs sum : {top_probs.sum():.2f}")



router probs shape: torch.Size([64]), sum: 1.00
router probs : tensor([0.0239, 0.0114, 0.0153, 0.0202, 0.0063, 0.0099, 0.0711, 0.0168, 0.0112,
        0.0055, 0.0166, 0.0292, 0.0047, 0.0084, 0.0108, 0.0123, 0.0129, 0.0112,
        0.0069, 0.0292, 0.0070, 0.0098, 0.0132, 0.0260, 0.0880, 0.0104, 0.0108,
        0.0134, 0.0033, 0.0114, 0.0136, 0.0096, 0.0124, 0.0067, 0.0129, 0.0086,
        0.0090, 0.0127, 0.0108, 0.0139, 0.0192, 0.0054, 0.0077, 0.0133, 0.0169,
        0.0120, 0.0047, 0.0182, 0.0115, 0.0167, 0.0111, 0.0214, 0.0127, 0.0193,
        0.0231, 0.0084, 0.0200, 0.0176, 0.0031, 0.0126, 0.0320, 0.0190, 0.0153,
        0.0215], grad_fn=<SoftmaxBackward0>)
top 8 probs : tensor([0.0239, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0711, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0292, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0292, 0.0000, 0.0000, 0.0000, 0.0260, 0.0880, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000

In [9]:
def get_router_probs_matrix(model, prompts, layer_idx=0, k=8):
    """get router probs matrix for last token for multiple inputs"""
    num_prompts = len(prompts)
    num_experts = 64 
    prob_matrix = torch.zeros((num_prompts, num_experts))
    
    # Get router probs for each prompt
    for i, prompt in enumerate(prompts):
        input_ids = tokenizer.encode(prompt, return_tensors="pt")
        router_logits = get_router_logits(model, input_ids)
        probs = get_last_token_router_probs(router_logits, layer_idx)
        top_probs = topk(probs, k=k)
        prob_matrix[i] = top_probs
        
    return prob_matrix

In [26]:
test_prompts = [
    "The quick brown fox",
    "1+1=",
    "the grey cat",
    "the grey elephant",
    "2*8",
    "def hello_world() : \n    print('hello world')",
    "what is principal component analysis",
    'what is capital of india',
    'sqrt 16',
    'void bubbleSort(int arr[], int n) {',
    'def is_prime(n):',
    'if n <= 1:',
    'return False',
    'for i in range(2, int(n**0.5) + 1):',
    'if n % i == 0:',
    'return False',
    'return True',
    "china",
    "the united states of america",
    "london",
    "tokyo",
    'paris'
]

router_prob_matrix = get_router_probs_matrix(model, test_prompts,k=64)
print(f"router probability matrix shape : {router_prob_matrix.shape}")
print("\nrouter probs matrix :")
print(router_prob_matrix[:])
print(f"\nverify each row sums to 1 : {router_prob_matrix.sum(dim=1)}")

router probability matrix shape : torch.Size([22, 64])

router probs matrix :
tensor([[0.0078, 0.0077, 0.0108,  ..., 0.0143, 0.0193, 0.0073],
        [0.0365, 0.0136, 0.0263,  ..., 0.0138, 0.0126, 0.0178],
        [0.0164, 0.0087, 0.0129,  ..., 0.0063, 0.0130, 0.0056],
        ...,
        [0.0036, 0.0072, 0.0101,  ..., 0.0042, 0.0160, 0.0110],
        [0.0020, 0.0631, 0.0061,  ..., 0.0071, 0.0062, 0.0186],
        [0.0127, 0.0122, 0.0091,  ..., 0.0046, 0.0097, 0.0112]],
       grad_fn=<SliceBackward0>)

verify each row sums to 1 : tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000], grad_fn=<SumBackward1>)


In [None]:
def get_last_token(prompt):
    """get the last token of a prompt using the tokenizer"""
    tokens = tokenizer.encode(prompt)
    last_token = tokenizer.decode([tokens[-1]])
    return last_token



In [34]:
def pca_visualize(router_prob_matrix, test_prompts):
    """perform PCA visualization on router probability matrix"""

    pca = PCA(n_components=3)
    router_prob_matrix_np = router_prob_matrix.detach().numpy()

    # Fit and transform the data
    pca_result = pca.fit_transform(router_prob_matrix_np)

    print("\nPCA results:")
    print(f"explained variance ratio: {pca.explained_variance_ratio_}")
    print(f"cumulative explained variance: {pca.explained_variance_ratio_.sum():.3f}")

    print("\nPCA transformed data shape:", pca_result.shape)
    # print("first few transformed points:")
    # print(pca_result)

    # Get last token of each prompt using tokenizer
    last_tokens = [get_last_token(prompt) for prompt in test_prompts]

    # Create 3D scatter plot
    fig = go.Figure(data=[go.Scatter3d(
        x=pca_result[:, 0],
        y=pca_result[:, 1],
        z=pca_result[:, 2],
        mode='markers+text',
        text=last_tokens,
        textposition="top center",
        marker=dict(
            size=10,
            opacity=0.8
        )
    )])

    # Update layout for 3D
    fig.update_layout(
        title='3D PCA of Router Probabilities',
        scene=dict(
            xaxis_title='first principal component',
            yaxis_title='second principal component',
            zaxis_title='third principal component',
            xaxis=dict(showgrid=True, gridwidth=1, gridcolor='LightGray'),
            yaxis=dict(showgrid=True, gridwidth=1, gridcolor='LightGray'),
            zaxis=dict(showgrid=True, gridwidth=1, gridcolor='LightGray')
        ),
        width=1000,
        height=800,
        showlegend=False
    )

    fig.show()



PCA Results:
Explained variance ratio: [0.2713409  0.18755594 0.12105728]
Cumulative explained variance: 0.580

PCA transformed data shape: (22, 3)
First few transformed points:
[[-0.06236592  0.0963617  -0.03788715]
 [-0.06566197 -0.04998831  0.00261806]
 [-0.05967391  0.11169526 -0.03712185]
 [-0.05831084  0.12667473 -0.04036308]
 [-0.04960285 -0.0310475   0.00121144]
 [-0.02540681 -0.00746514 -0.00165672]
 [-0.00849235 -0.00898064  0.0054354 ]
 [-0.04082667 -0.05800136 -0.00263951]
 [ 0.05595113  0.04762353 -0.01735783]
 [-0.04446524 -0.01787041  0.00956392]
 [-0.06131588 -0.05367101  0.00433186]
 [-0.05683717 -0.06704464 -0.00959211]
 [ 0.10764203 -0.02419738 -0.04624162]
 [-0.05611155 -0.04148081  0.00125383]
 [-0.03913319 -0.06342007 -0.00941174]
 [ 0.10764203 -0.02419738 -0.04624161]
 [ 0.13262796 -0.02510761 -0.05662976]
 [ 0.0889605   0.03747458  0.091026  ]
 [-0.01364151 -0.0114995   0.00064074]
 [ 0.00123032  0.03440759  0.14570974]
 [ 0.07475901  0.04048082  0.02719185]
 [

In [29]:
def compute_cosine_similarity(router_prob_matrix, idx1, idx2):

    # Get the probability vectors for the two tokens
    vec1 = router_prob_matrix[idx1]
    vec2 = router_prob_matrix[idx2]
    
    # Convert numpy arrays to torch tensors
    vec1 = torch.from_numpy(vec1).float()
    vec2 = torch.from_numpy(vec2).float()
    
    # Compute cosine similarity using torch.nn.functional
    cos_sim = torch.nn.functional.cosine_similarity(vec1.unsqueeze(0), vec2.unsqueeze(0))
    
    return cos_sim.item()

print(f"Cosine similarity between tokens 2 and 3: {compute_cosine_similarity(router_prob_matrix_np, 2, 3):.4f}")
print(f"Cosine similarity between tokens 0 and 5: {compute_cosine_similarity(router_prob_matrix_np, 0, 5):.4f}")
print(f"Cosine similarity between tokens 10 and 15: {compute_cosine_similarity(router_prob_matrix_np, 10, 15):.4f}")


Cosine similarity between tokens 2 and 3: 0.9859
Cosine similarity between tokens 0 and 5: 0.6527
Cosine similarity between tokens 10 and 15: 0.4542


In [33]:
test_prompts = [
    "The quick brown fox",
    "1+1=",
    "the grey cat",
    "the grey elephant",
    "2*8",
    "def hello_world() : \n    print('hello world')",
    "what is principal component analysis",
    'what is capital of india',
    'sqrt 16',
    'void bubbleSort(int arr[], int n) {',
    'def is_prime(n):',
    'if n <= 1:',
    'return False',
    'for i in range(2, int(n**0.5) + 1):',
    'if n % i == 0:',
    'return False',
    'return True',
    "china",
    "the united states of america",
    "london",
    "tokyo",
    'paris'
]

def get_last_token(test_prompts, idx):
    """ returns the last token from the prompt at the given index."""
    if idx < 0 or idx >= len(test_prompts):
        raise IndexError("Index out of range")
        
    prompt = test_prompts[idx]
    tokens = tokenizer.encode(prompt)
    
    if len(tokens) == 0:
        return ""
        
    last_token_id = tokens[-1]
    return tokenizer.decode([last_token_id])

print(f"Last token of prompt at index 0: {get_last_token(test_prompts, 5)}")  # Should print 'fox'
print(f"Last token of prompt at index 2: {get_last_token(test_prompts, 2)}")  # Should print 'cat'



Last token of prompt at index 0: ')
Last token of prompt at index 2:  cat
