In [1]:
import torch
import pandas as pd

from mem_llm.tokenizer import CharTokenizer
from mem_llm.interface import ModelOutput

from mem_llm import MemLLM

# Inference

In [31]:
from torch.distributions import Categorical


@torch.no_grad()
def generate(
        seed: str,
        model: torch.nn.Module,
        tokenizer: CharTokenizer,
        *,
        device: str,
        max_length: int,
        top_k: int = 10,
) -> str:
    """Sample a continuation of ``seed`` with top-k sampling, streaming it to stdout.

    The seed is printed first, then each sampled token is printed as soon as
    it is drawn, so long generations can be watched live.

    Args:
        seed: Prompt text to continue.
        model: Language model. It is switched to eval mode, moved to
            ``device`` and has ``do_compile`` set to ``True`` as a side effect.
        tokenizer: Tokenizer providing ``encode`` (text -> 1-D tensor of ids)
            and ``decode`` (tensor of ids -> text).
        device: Device the model and the token tensor are moved to.
        max_length: Number of NEW tokens to sample (the seed is not counted).
        top_k: Sample only among the ``top_k`` most likely tokens. Clamped to
            the vocabulary size so small vocabularies do not make
            ``torch.topk`` fail.

    Returns:
        The decoded seed plus continuation.

    Raises:
        ValueError: If ``top_k`` is smaller than 1.
    """
    if top_k < 1:
        raise ValueError(f'top_k must be >= 1, got {top_k}')

    model.eval()
    model.to(device)

    # Flag consumed by the model implementation (not visible in this notebook).
    model.do_compile = True

    tokens = tokenizer.encode(seed).to(device)
    print(seed, end='')

    for _ in range(max_length):
        # The whole prefix is re-fed every step; only the last position's
        # logits are requested.
        outputs: ModelOutput = model(tokens, num_logits_to_keep=1)

        # Assumes exactly one position of logits comes back, so flattening
        # yields a (1, vocab) row -- TODO confirm against MemLLM.forward.
        logits = outputs.logits.view(1, -1)

        k = min(top_k, logits.size(-1))
        topk_logits, topk_indices = torch.topk(logits, k=k, dim=-1)

        # Renormalise over the k survivors only, then draw one of them.
        probs = torch.softmax(topk_logits, dim=-1)
        choice = Categorical(probs).sample()

        # `choice` indexes into the top-k list; map it back to a vocabulary id.
        next_token = topk_indices[0, choice]

        tokens = torch.cat((tokens, next_token), dim=0)

        print(tokenizer.decode(next_token), end='')

    return tokenizer.decode(tokens)

In [52]:
# Load the `ts_from_scratch` run: model weights and tokenizer are read from the
# same checkpoint directory so the vocabulary matches the model.
model = MemLLM.load('../runs/char_lm/ts_from_scratch/model', device='cuda')
tokenizer = CharTokenizer.load('../runs/char_lm/ts_from_scratch/model')

In [53]:
# Sample 2000 new tokens. generate() streams the text via print(); the trailing
# `;` keeps the notebook from also echoing the returned string.
generate('O God! O God! ', model, tokenizer, device='cuda', max_length=2000);

O God! O God! what comfort on?
What! it didst did stone? and hateful friends?
Mush to am I.

SICINIUS:
You will speak as you of a file; but I am,
Your bidding true, bestrong you to make your diggers,
Ere you not set too city flight too great,
And you will me such merrieve to guests;
Sufferant a watering dates charity?
When they say is that? Mismicker your commanded?
You willingly minds in yours, believe here!
By that shall give me much of your grace,
But you shall have told me scarce more than you.
Beating than I a fine, at wavers my heart
That lies incilict of the present looks,
But fly thee pessing my daughter of the lume
Of disgrains gold the law of your countenation,
By the ground that they do no great mother forth,
He do tomorrow many wounds to do it.

First Murderer:
So now, by God's grave! I have that might soul,
How not it will be our book'd trial, we know,
And I am sure to death this marriage words,
From my dear good since, whom I all said,
That setter like a scenen's biggest 

In [54]:
# Switch to the `ts_finetune` run; this rebinds the notebook-global `model` and
# `tokenizer`, so every later generate() call uses this checkpoint.
model = MemLLM.load('../runs/char_lm/ts_finetune/model', device='cuda')
tokenizer = CharTokenizer.load('../runs/char_lm/ts_finetune/model')

In [55]:
# Same prompt as for the other runs, sampled from the currently loaded checkpoint.
# NOTE(review): this seed has no trailing space, unlike the other 'O God! O God! '
# prompts -- confirm the difference is intentional before comparing samples.
generate('O God! O God!', model, tokenizer, device='cuda', max_length=2000);

O God! O God!
William God hath behoved the bride.
Must not be taught at all, God hath met his bride.

CAPULET:
O God! my son, God! sir.

God! my face to my captain;
O God! my son.

MUSTOCAPE

ISABELLA:
The captain offers a period of written darkness:
I speaks it wrong in your mother's breast,
Holy for them as methinks: if you heaven stirr'd in
This wounds into one; indeed, you must not keep
The melody of God, Giod God! I sent me fingers too,
To that too.
Will not thou taught you spoke:
You must have you a bit of here.

MUSTOCAPE:
I speak you
To God. Will I love your cousin?

BULLIONS:
And state thy mate! I must speak in thy prayer:
You have thus taken him seriously.

CAPULET:
Although I, till morn's say we are all good about
Take the same way as he be animal, I'll do so.

ASIDE:
O God!

CAPULET:
And nobust the god of how you are, Giod God!

CAPULET:
You done sadness, I shall go on.

CAPULET:
As taught, these are thus! well, if I, thou cannot stay,
Yet speak, Giod God!

CAPULET:
I'le to

In [3]:
# Load the `mem_unet_embed` run (model and tokenizer from the same directory).
# It is also one of the entries evaluated in the ablation study below.
model = MemLLM.load('../runs/char_lm/mem_unet_embed/model', device='cuda')
tokenizer = CharTokenizer.load('../runs/char_lm/mem_unet_embed/model')

In [32]:
# Free-form sample (2000 new tokens) from the currently loaded checkpoint;
# `;` suppresses the echoed return value since the text is already printed.
generate('O God! O God! ', model, tokenizer, device='cuda', max_length=2000);

O God! O God! The God was an American civil society. However, a number of people who were not accountable for there was a great advantage to an economic change and the gap of economic process and trade. On the other hand, it was a spiritual policy. But in the Gospel and Mrs. Ben Goddard, traditional legislator, for the second century of the God and because there would not be a man or a few important parts. It is essential to earn an organisation that the God's bureau in the Greek province and the God, was a guardian political and a past so what, in a disagreement, a total of the Gospel. When the Gospel successfully suffered the Gospel with this influence and, when we had talked to musical and traditional monuments, the Gospels had not done. Out of God's stuck being handed out in agriculture, there were several persons who were irritated.
Above all, the Gospel was soon able to grow and submit, and then several Gods speaking agreements were alike, in the Gospel. The Gospel was able to fi

In [33]:
# Probe: ask the letter-counting question directly and sample 100 new tokens.
generate('How many "r" are in the "strawberry"?', model, tokenizer, device='cuda', max_length=100);

How many "r" are in the "strawberry"?
How does "r" are in the strawberry flow facing organ. What are the different functions of product o

In [50]:
# Probe: phrase the question as a completion and sample exactly one token
# (max_length=1), so the output is the model's single-character answer.
# Sampling is stochastic (top-k), so re-running the cell can give another digit.
generate('Let N be the number of "r" are in the "strawberry". We can calculate N=', model, tokenizer, device='cuda', max_length=1);

Let N be the number of "r" are in the "strawberry". We can calculate N=2

In [51]:
# Probe: one-token completion (max_length=1) of a trivial arithmetic statement.
# Sampling is stochastic (top-k), so a single draw is anecdotal, not a measurement.
generate('Everyone knows that 1 + 1 = ', model, tokenizer, device='cuda', max_length=1);

Everyone knows that 1 + 1 = 4

# Ablation study

In [3]:
import math
from torch.utils.data import DataLoader
import json
from pathlib import Path
from mem_llm.dataset import GuaranteedLengthDataset
import numpy as np
import torch
from mem_llm import MemLLM
import pandas as pd

# Allow reduced-precision internal computation for float32 matmuls (speed over
# precision) -- see the PyTorch docs for the exact meaning of 'medium'.
torch.set_float32_matmul_precision('medium')

# Run directory names under ../runs/char_lm/ to evaluate; each is expected to
# contain a `config.json` and a `model` checkpoint (both are read further down).
experiments = [
    'mem',
    'mem_unet',
    'baseline_rope1m',
    'baseline_rope10k',
    'baseline_rope500',
    'baseline_unet_embed',
    'mem_unet_embed',
]

@torch.no_grad()
def evaluate(
        path: Path,
        *,
        num_batches: int = 25,
        example_length: int = 100000,
        val_path: str = '../data/fineweb-edu__char-vocab_size128-unk_token0-eos_token_id127/val',
):
    """Compute validation perplexity for the checkpoint stored under ``path``.

    Loads ``path / 'model'`` onto CUDA, compiles it, and scores up to
    ``num_batches`` examples of ``example_length`` tokens each with
    next-token prediction. Prints the run path, the number of scored tokens
    and the perplexity.

    Args:
        path: Experiment directory containing a ``model`` checkpoint.
        num_batches: Maximum number of validation examples to score.
        example_length: Length of each validation example, passed to
            ``GuaranteedLengthDataset``.
        val_path: Location of the tokenized validation split.

    Returns:
        ``exp(-mean log-likelihood)`` over all scored target tokens.

    Raises:
        ValueError: If no tokens were scored (empty dataset or
            ``num_batches <= 0``).
    """
    print(path)
    model = MemLLM.load(path / 'model', device='cuda')

    model.eval()
    model = torch.compile(model.to('cuda'))

    val_dataset = GuaranteedLengthDataset(
        val_path,
        example_length=example_length,
        source_dtype=np.uint8,
    )

    # batch_size=None disables automatic batching: every item yielded by the
    # loader is a single dataset example.
    dataloader = DataLoader(val_dataset, batch_size=None)

    total_log_likelihood = 0.0
    total_tokens = 0

    # range() comes first on purpose: zip stops as soon as it is exhausted,
    # without first pulling an extra, unused example from the loader.
    for _, batch in zip(range(num_batches), dataloader):
        tokens = batch.to('cuda')

        logits = model(tokens).logits
        log_probs = torch.log_softmax(logits, dim=-1)

        # Position i predicts token i + 1: drop the last prediction and the
        # first token so predictions and targets line up.
        targets = tokens[1:].unsqueeze(-1)
        token_log_probs = torch.gather(log_probs[:-1], dim=-1, index=targets)

        total_log_likelihood += token_log_probs.sum().item()
        total_tokens += targets.numel()

    if total_tokens == 0:
        raise ValueError(
            f'No validation tokens were scored for {path} '
            f'(num_batches={num_batches}, val_path={val_path!r})'
        )

    avg_log_likelihood = total_log_likelihood / total_tokens
    perplexity = math.exp(-avg_log_likelihood)

    print(f"Total Tokens: {total_tokens}")
    print(f"Perplexity: {perplexity:.2f}")

    return perplexity

# Build one summary row per experiment: architecture flags read from the run's
# config.json, plus the validation perplexity measured by evaluate().
all_data = []
for experiment in experiments:
    exp_path = Path('../runs/char_lm') / experiment
    config_path = exp_path / 'config.json'
    with open(config_path, 'r') as f:
        config = json.load(f)

    model_config = config['model_config']
    # evaluate() stays last so a missing config key fails before the
    # expensive evaluation starts.
    all_data.append({
        'rope_freq': model_config['rotary_inv_freq_base'],
        'is_mem': model_config['mem_freq'] < 100000,
        'is_unet': model_config['unet_design'],
        'is_embed_residual': model_config['embeds_residual'],
        'perplexity': evaluate(exp_path),
    })

data = pd.DataFrame(all_data)
data

../runs/char_lm/mem
Total Tokens: 2499975
Perplexity: 2.61
../runs/char_lm/mem_unet
Total Tokens: 2499975
Perplexity: 2.58
../runs/char_lm/baseline_rope1m
Total Tokens: 2499975
Perplexity: 2.63
../runs/char_lm/baseline_rope10k
Total Tokens: 2499975
Perplexity: 2.62
../runs/char_lm/baseline_rope500
Total Tokens: 2499975
Perplexity: 2.67
../runs/char_lm/baseline_unet_embed
Total Tokens: 2499975
Perplexity: 2.59
../runs/char_lm/mem_unet_embed
Total Tokens: 2499975
Perplexity: 2.57


Unnamed: 0,rope_freq,is_mem,is_unet,is_embed_residual,perplexity
0,1000000.0,True,False,False,2.611141
1,1000000.0,True,True,False,2.582623
2,1000000.0,False,False,False,2.632282
3,10000.0,False,False,False,2.617912
4,500.0,False,False,False,2.670399
5,1000000.0,False,True,True,2.585341
6,1000000.0,True,True,True,2.572131


In [5]:
# Negate perplexity so a POSITIVE correlation means "this setting goes with a
# better (lower-perplexity) model". Note this adds a column to `data` in place.
# NOTE(review): only 7 runs feed these coefficients and the flags are not varied
# independently, so treat the values as indicative rather than conclusive.
data['-perplexity'] = -data['perplexity']
data.corr()['-perplexity']

rope_freq            0.680400
is_mem               0.592370
is_unet              0.827897
is_embed_residual    0.630527
perplexity          -1.000000
-perplexity          1.000000
Name: -perplexity, dtype: float64