In [1]:
%load_ext autoreload
%autoreload 2

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '7'
os.chdir('..')

import pickle
import re
from pathlib import Path

import torch
import datasets
import numpy as np
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from transformers.modeling_outputs import CausalLMOutputWithPast

from tqdm import tqdm
from nltk import sent_tokenize

from model import MemoryCell
from train import calculate_accuracy

pd.set_option('display.max_rows', 200)

In [2]:
# get token-lvl accuracy (all tokens except the first one)
# get ppl on non-prefix text (all tokens except the first one)
# args:
# - model name (model)
# - prefix (left context) length
# - suffix length (compressed text) length

In [20]:
model_name = 'EleutherAI/pythia-410m'
dtype = 'float32'
device = 'cuda'
use_flash_attention_2 = False

# model_name = 'meta-llama/Llama-3.2-1B'
# dtype = 'bfloat16'
# device = 'cuda'
# use_flash_attention_2 = True

dtype = getattr(torch, dtype)
N_mem_tokens = 1
max_length = 64
prefix_length = 2
texts_path = './data/pg19_valid_1k_chunks.csv'
mem_results_path = Path(f'./runs/{model_name}/mem_{N_mem_tokens}_len_{max_length}.pkl')

with_prefix_results_path = mem_results_path.parent / 'with_prefix' / f'mem_{N_mem_tokens}_len_{max_length}.json'

In [21]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name,
                                             use_flash_attention_2=use_flash_attention_2)
model = model.to(device)



In [None]:
import pandas as pd
texts_df = pd.read_csv(texts_path, index_col=0)

In [24]:
mem_results = pickle.load(open(mem_results_path, 'rb'))
mem_result = mem_results[0]
sample_idx = mem_result['args']['sample_idx']
sample_idx

0

In [26]:
mem_result.keys()

dict_keys(['losses', 'accuracies', 'original_loss', 'original_accuracy', 'best_memory_params', 'best_loss', 'best_accuracy', 'max_length', 'n_mem_tokens', 'args'])

In [225]:
from torch.nn import CrossEntropyLoss

text_sample = texts_df['text'][sample_idx]
sentences = sent_tokenize(text_sample)
prefix_text = ' '.join(sentences[:len(sentences)//2])
suffix_text = ' '.join(sentences[len(sentences)//2:])

inp = tokenizer(suffix_text, max_length=max_length, truncation=True, return_tensors='pt').to(device)

with torch.cuda.amp.autocast(dtype=dtype):
    with torch.no_grad():
        output = model(**inp, labels=inp['input_ids'])
        loss = output.loss.item()
        accuracy = calculate_accuracy(output.logits, inp['input_ids'])

        labels = inp['input_ids']
        logits = output.logits
        labels = labels.to(logits.device)
        shift_logits = logits[:, :-1, :].contiguous()
        labels = labels[:, 1:].contiguous()
        loss_fct = CrossEntropyLoss(reduction='none')
        loss_1 = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))

We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)
The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.


In [226]:
loss, accuracy, loss_1[0:].mean().item(), loss_1[1:].mean().item()

(2.95055890083313, 0.3809524178504944, 2.950559139251709, 2.842134714126587)

In [2]:
# option 1: tokenize(prefix) + tokenize(suffix)
# -- no space token, the first suffix token will glue to last prefix token
# option 2: tokenizer(prefix+' ') + tokenize(suffix)
# -- will produce space token (unnatural to real texts that model was trained on)
# option 3: tokenize(prefix + ' ' + suffix) -- THIS ONE
# -- will look like natural text, but the first suffix (compressed text) token will change
# 
# llama, opt adds bos -- need to remove it from suffix
# pythia has no bos -- no need to remove

from torch.nn import CrossEntropyLoss

def eval_model_with_text_prefix(model, tokenizer, max_length, prefix_length, dtype,
                                sample_idx=None, text_sample=None, sample=None, texts_df=None):    
    option = 2
    
    if sample is not None:
        # take all needed params from saved results from run with mem token
        sample_idx = sample['args']['sample_idx']
        text_sample = texts_df['text'][sample_idx]
        assert max_length == sample['max_length']

    assert sample_idx is not None
    assert text_sample is not None

    sentences = sent_tokenize(text_sample)
    prefix_text = ' '.join(sentences[:len(sentences)//2])
    suffix_text = ' '.join(sentences[len(sentences)//2:])

    has_special_tokens = (tokenizer('text text', add_special_tokens=True)['input_ids'] !=
                          tokenizer('text text', add_special_tokens=False)['input_ids'])
    
    if option == 3:
        suffix_inp = tokenizer(suffix_text, max_length=max_length, truncation=True, return_tensors='pt')
        if has_special_tokens:
            suffix_inp['input_ids'] = suffix_inp['input_ids'][:,1:]
            suffix_inp['attention_mask'] = suffix_inp['attention_mask'][:,1:]
        suffix_len = suffix_inp['input_ids'].shape[-1]
        suffix_text = tokenizer.decode(suffix_inp['input_ids'][0])
        # concat prefix text with suffix text and add space between them
        # inp_text = prefix_text[-(prefix_length+1)*20:] + ' ' + suffix_text
        inp_text = prefix_text[-(prefix_length+1)*20:] + ' ' + suffix_text
        # tokenize it
        inp = tokenizer(inp_text, return_tensors='pt')
        # cut inp to have length == suffix_len + desired_prefix_len
        # mb take only tokens that are the same in suffix and inp
        new_inp_len = suffix_len + prefix_length
        inp['input_ids'] = inp['input_ids'][:,-new_inp_len:]
        inp['attention_mask'] = inp['attention_mask'][:,-new_inp_len:]
        # check that last tokens from inp[suffix_len-1:] == suffix_tokens[1:]
        # (except the first one as it may change because of space)
        assert (inp['input_ids'][:,-(suffix_len-1):] == suffix_inp['input_ids'][:,1:]).all(), "ne ok"
    elif option == 2:
        suffix_inp = tokenizer(suffix_text, max_length=max_length, truncation=True, return_tensors='pt')
        if has_special_tokens:
            # remove bos token from text that was compressed
            suffix_inp['input_ids'] = suffix_inp['input_ids'][:,1:]
            suffix_inp['attention_mask'] = suffix_inp['attention_mask'][:,1:]
        suffix_len = suffix_inp['input_ids'].shape[-1]

        prefix_text = prefix_text[-(prefix_length+1)*20:] + ' '
        prefix_inp = tokenizer(prefix_text, return_tensors='pt')
        inp = prefix_inp
        inp['input_ids'] = inp['input_ids'][:,-prefix_length:]
        inp['attention_mask'] = inp['attention_mask'][:,-prefix_length:]
        inp['input_ids'] = torch.cat([inp['input_ids'], suffix_inp['input_ids']], axis=1)
        inp['attention_mask'] = torch.cat([inp['attention_mask'], suffix_inp['attention_mask']], axis=1)
        # check that last tokens from inp[-suffix_len:] == suffix_tokens
        assert (inp['input_ids'][:,-suffix_len:] == suffix_inp['input_ids']).all(), "not ok"


    with torch.cuda.amp.autocast(dtype=dtype):
        with torch.no_grad():
            inp = inp.to(device)
            output = model(**inp, labels=inp['input_ids'])

            labels = inp['input_ids'][:,-suffix_len:]
            logits = output.logits[:,-suffix_len:]
            labels = labels.to(logits.device)

            accuracy = calculate_accuracy(logits, labels)
            shift_logits = logits[:, :-1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)).item()
            # print(f'{sample_idx} {accuracy:.3f} {loss:.3f}')
            # print(f'{original_accuracy:.3f} {original_loss:.3f} {best_accuracy:.3f} {best_loss:.3f}')
            # print('-----')
    res = {
        'sample_idx': sample_idx,
        'prefix_length': prefix_length,
        'max_length': max_length,
        'loss': loss,
        'accuracy': accuracy
    }
    
    if sample is not None:
        res.update(
            {
                'n_mem_tokens': sample['n_mem_tokens'],
                'original_loss': sample['original_loss'],
                'original_accuracy': sample['original_accuracy'],
                'best_loss': sample['best_loss'],
                'best_accuracy': sample['best_accuracy'],
            })
    return res

In [38]:


prefix_lengths = [64, 128, 512, 1024]
max_lengths = [32, 64, 96, 128, 256, 512, 1024, 1568]

results = {}

prefix_length = 512

option = 2

for sample in tqdm(mem_results):
    res = eval_model_with_text_prefix(model, tokenizer, max_length, prefix_length, sample=sample, texts_df=texts_df)
res

100%|██████████| 50/50 [00:03<00:00, 13.47it/s]


{'sample_idx': 49,
 'prefix_length': 512,
 'max_length': 64,
 'loss': 2.7170255184173584,
 'accuracy': 0.3968254327774048,
 'n_mem_tokens': 1,
 'original_loss': 3.432558536529541,
 'original_accuracy': 0.3650793731212616,
 'best_loss': 0.13693493604660034,
 'best_accuracy': 1.0}

In [115]:
model_name = 'EleutherAI/pythia-410m'
print(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
print(f"prefix_text: {tokenizer.tokenize(prefix_text[-30:])}")
print(f"prefix_text + ' ': {tokenizer.tokenize(prefix_text[-30:] + ' ')}")
print(f"suffix_text: {tokenizer.tokenize(suffix_text[:10])}")
print(f"prefix_text + ' ' + suffix_text: {tokenizer.tokenize(prefix_text[-30:] + ' ' + suffix_text[:10])}")
print('no special tokens:', tokenizer('text text', add_special_tokens=True)['input_ids'] == tokenizer('text text', add_special_tokens=False)['input_ids'])

EleutherAI/pythia-410m
prefix_text: ['turned', 'Ġhis', 'Ġhead', 'Ġaway', 'Ġfrom', 'Ġhim', '.']
prefix_text + ' ': ['turned', 'Ġhis', 'Ġhead', 'Ġaway', 'Ġfrom', 'Ġhim', '.', 'Ġ']
suffix_text: ['His', 'Ġeyes', 'Ġm']
prefix_text + ' ' + suffix_text: ['turned', 'Ġhis', 'Ġhead', 'Ġaway', 'Ġfrom', 'Ġhim', '.', 'ĠHis', 'Ġeyes', 'Ġm']
no special tokens: True


In [216]:
model_name = "meta-llama/Llama-3.2-1B"
print(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
print(f"prefix_text: {tokenizer.tokenize(prefix_text[-30:])}")
print(f"prefix_text + ' ': {tokenizer.tokenize(prefix_text[-30:] + ' ')}")
print(f"suffix_text: {tokenizer.tokenize(suffix_text[:10])}")
print(f"prefix_text + ' ' + suffix_text: {tokenizer.tokenize(prefix_text[-30:] + ' ' + suffix_text[:10])}")
print('no special tokens:', tokenizer('text text', add_special_tokens=True)['input_ids'] == tokenizer('text text', add_special_tokens=False)['input_ids'])

meta-llama/Llama-3.2-1B
prefix_text: ['t', 'ance', 'Ġto', 'Ġits', 'Ġspiritual', 'Ġvalue', '.', 'Ġ']
prefix_text + ' ': ['t', 'ance', 'Ġto', 'Ġits', 'Ġspiritual', 'Ġvalue', '.', 'ĠĠ']
suffix_text: ['This', 'Ġfact', 'Ġ']
prefix_text + ' ' + suffix_text: ['t', 'ance', 'Ġto', 'Ġits', 'Ġspiritual', 'Ġvalue', '.', 'Ġ', 'ĠThis', 'Ġfact', 'Ġ']
no special tokens: False


In [40]:
print(tokenizer('text text', add_special_tokens=True)['input_ids'])
print(tokenizer('text text', add_special_tokens=False)['input_ids'])

[1156, 2505]
[1156, 2505]


## eval models with prefixes, dump results

In [3]:
import json
# model_names = ['EleutherAI/pythia-410m', 'EleutherAI/pythia-1.4b',
#                'meta-llama/Llama-3.2-1B', 'meta-llama/Meta-Llama-3.1-8B']
model_names = ['meta-llama/Llama-3.2-1B', 'meta-llama/Meta-Llama-3.1-8B']

prefix_lengths = [64, 128, 512, 1024]
max_lengths = [2048] #[64, 96, 128, 256, 512, 1024, 1568]
N_mem_tokens = 1

texts_path = './data/pg19_valid_1k_chunks.csv'

import pandas as pd
texts_df = pd.read_csv(texts_path, index_col=0)

desc = f"Running:"
progress_bar = tqdm(total=len(model_names) * len(max_lengths) * len(prefix_lengths), desc=desc, leave=False)

for model_name in model_names:
    for max_length in max_lengths:

        mem_results_path = Path(f'./runs/{model_name}/mem_{N_mem_tokens}_len_{max_length}.pkl')
        prefix_results_path = mem_results_path.parent / 'with_prefix' / f'mem_{N_mem_tokens}_len_{max_length}.json'
        if not mem_results_path.exists():
            print(f'skipping {model_name} with text_length: {max_length}')
            progress_bar.update(len(prefix_lengths))
            continue

        mem_results = pickle.load(open(mem_results_path, 'rb'))

        device = 'cuda'
        dtype = mem_results[0]['args']['dtype']
        use_flash_attention_2 = mem_results[0]['args']['use_flash_attention_2']

        # dtype = getattr(torch, dtype)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name,
                                                     use_flash_attention_2=use_flash_attention_2)
        model = model.to(device)

        model_max_length = getattr(model.config, 'max_position_embeddings')

        results = {}

        for prefix_length in prefix_lengths:
            progress_bar.set_postfix(m=model_name, l=max_length, p=prefix_length)
            if model_max_length < prefix_length + max_length:
                print(f'skipping {model_name} with text_length: {max_length}, prefix_length: {prefix_length}')
                progress_bar.update(1)
                continue
            
            results[prefix_length] = []
            
            for sample in mem_results:
                res = eval_model_with_text_prefix(model, tokenizer, max_length, prefix_length, dtype,
                                                  sample=sample, texts_df=texts_df)
                results[prefix_length] += [res]
            progress_bar.update(1)
        
        prefix_results_path.parent.mkdir(parents=True, exist_ok=True)
        json.dump(results, prefix_results_path.open('w'), indent=4)

progress_bar.close()

Running::   0%|          | 0/8 [00:00<?, ?it/s]The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation="flash_attention_2"` instead.
You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour
You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in LlamaForCausalLM is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`
Flash Attention 2.0 only 

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

                                                                                                       