In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import os
from dataclasses import dataclass
import numpy as np
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
from tqdm import tqdm

from utils import plot_ci, plot_ci_plus_heatmap
from llamawrapper import load_tokenizer_only
from eval_core import (
    get_latents_from_hf_model, 
    get_latents_from_bnb_model, 
    get_latents_from_gptq_model, 
    get_latents_from_awq_model, 
    get_latents_from_rtn_model, 
    get_latents_from_mpq_model,
    get_latents_from_mpqr_model,
)
# Fix the random seeds so that NumPy- and PyTorch-based randomness is
# repeatable from one "Restart & Run All" to the next.
seed = 42
np.random.seed(seed)  # NumPy's global RNG (pandas' DataFrame.sample falls back to it when no random_state is given)
torch.manual_seed(seed)  # PyTorch's default generator

In [None]:
# Experiment configuration: everything a reader might want to tune.
input_lang = 'en'  # language of the word shown on the left-hand side of each prompt line
target_lang = 'fr'  # language the prompt asks for on the right-hand side
model_size = '7b'  # key into size2tik (plot tick spacing); also part of the output file names
custom_model = "meta-llama/Llama-2-7b-chat-hf"  # model id passed to the tokenizer loader and the latent getter
single_token_only = False  # keep only target words that exist in the tokenizer vocab as one token
multi_token_only = False  # keep only target words that do NOT exist in the vocab as one token
out_dir = './visuals'  # root directory for the saved dataset, figures and tensors
# Quantisation scheme. Values handled further down in this notebook:
#   "hf", "bnb", "gptq", "awq", "rtn", "twq", "mpq", "mpqr",
#   and the random-neuron variants "twq_rand", "mpq_rand", "mpqr_rand".
quant_type = "mpq_rand"
precision = 4  # bit width; forwarded to the latent getter and embedded in model_label

In [None]:
# The two vocabulary filters exclude each other: a word cannot be both a
# single token and a multi-token word, so requesting both is a config error.
both_filters_requested = single_token_only and multi_token_only
if both_filters_requested:
    raise ValueError('single_token_only and multi_token_only cannot be True at the same time')

In [None]:
# Load the word lists for the input and the target language.
# NOTE: the variable names are historical. `df_en_fr` holds the table for
# `input_lang` and `df_en_de` the table for `target_lang`, whatever those
# languages actually are.
prefix = "./data/langs/"
df_en_fr, df_en_de = (
    pd.read_csv(f'{prefix}{lang_code}/clean.csv').reindex()
    for lang_code in (input_lang, target_lang)
)

In [None]:
# Load the tokenizer for the configured model. It is used below for
# vocabulary look-ups while the prompts are being built.
tokenizer = load_tokenizer_only(custom_model)
# No model is instantiated in this cell; the latent getters selected further
# down receive `custom_model` and handle the model themselves.

In [None]:
# Classify every target-language word as "single token" (the word itself, or
# the word with the '▁' word-start marker, is a vocabulary entry) or
# "multi token", then apply the optional single/multi-token filter.
#
# Changes relative to the earlier version of this cell:
#   * the vocabulary is fetched once instead of twice per word;
#   * rows are selected with a positional boolean mask instead of dropping the
#     `enumerate` counter as an index label, so the cell no longer relies on
#     the frame having a pristine 0..n-1 index (e.g. when it is re-run);
#   * the printed denominator is the number of words *before* filtering.
vocab = tokenizer.get_vocab()
total_words = len(df_en_de)
is_single_token = np.array(
    [word in vocab or '▁' + word in vocab for word in df_en_de['word_translation']],
    dtype=bool,
)
count = int(is_single_token.sum())

if multi_token_only:
    df_en_de = df_en_de.loc[~is_single_token]
elif single_token_only:
    df_en_de = df_en_de.loc[is_single_token]

print(f'for {target_lang} {count} of {total_words} are single tokens')

# Build the working table `df_en_de_fr`: one row per English source word with
# a column for the target-language word and, when the languages differ, one
# for the input-language word. A column that would collide with the 'en'
# column is named 'en_tgt' / 'en_in' instead.
target_col = target_lang if target_lang != 'en' else 'en_tgt'
if input_lang == target_lang:
    # Same language on both sides: the target table already has everything.
    df_en_de_fr = df_en_de.copy().rename(
        columns={'word_original': 'en', 'word_translation': target_col}
    )
else:
    input_col = input_lang if input_lang != 'en' else 'en_in'
    merged = df_en_de.merge(
        df_en_fr, on=['word_original'], suffixes=(f'_{target_lang}', f'_{input_lang}')
    )
    df_en_de_fr = merged.rename(
        columns={
            'word_original': 'en',
            f'word_translation_{target_lang}': target_col,
            f'word_translation_{input_lang}': input_col,
        }
    )

# Remove rows whose target word contains the English word (case-insensitive).
if target_lang != 'en':
    keep_row = np.array(
        [
            en_word.lower() not in tgt_word.lower()
            for en_word, tgt_word in zip(df_en_de_fr['en'], df_en_de_fr[target_lang])
        ],
        dtype=bool,
    )
    df_en_de_fr = df_en_de_fr.loc[keep_row]

print(f'final length of df_en_de_fr: {len(df_en_de_fr)}')

In [None]:
def token_prefixes(token_str: str):
    """Return every non-empty prefix of ``token_str``, shortest first.

    The full string is the last element; an empty string yields ``[]``.
    """
    prefixes = []
    for end in range(len(token_str)):
        prefixes.append(token_str[:end + 1])
    return prefixes

def add_spaces(tokens):
    """Return the '▁'-prefixed form of every token, followed by the tokens themselves.

    ``tokens`` must be a list; the result is a new list twice as long.
    """
    with_marker = []
    for tok in tokens:
        with_marker.append('▁' + tok)
    return with_marker + tokens

def capitalizations(tokens):
    """Return the distinct tokens, in order of first appearance.

    Despite its name, this function does not generate any upper/lower-case
    variants; it only removes duplicates.

    The result used to be ``list(set(tokens))``, whose ordering changes between
    Python processes because string hashing is randomised. Keeping first-seen
    order makes everything derived from it (e.g. the saved token-id lists)
    reproducible, while the set of returned tokens is unchanged.
    """
    return list(dict.fromkeys(tokens))

def unicode_prefix_tokid(zh_char = "云", tokenizer=tokenizer):
    """Return the id of the byte token ``<0xXX>`` for the first UTF-8 byte of ``zh_char``.

    Args:
        zh_char: String whose first UTF-8 byte is looked up.
        tokenizer: Object exposing ``get_vocab()`` (token string -> token id).

    Returns:
        The token id, or ``None`` when ``zh_char`` is empty or the vocabulary
        has no ``<0xXX>`` entry for that byte.

    The byte is read directly from the encoded string. The previous version
    parsed ``repr(bytes)`` and split on ``\\x``, which raised ``IndexError``
    for all-ASCII or empty input and picked a later byte when the string
    started with an ASCII character. For strings that start with a non-ASCII
    character the result is the same as before.
    """
    encoded = zh_char.encode("utf-8")
    if not encoded:
        return None
    start_key = '<0x%02X>' % encoded[0]
    return tokenizer.get_vocab().get(start_key)

def process_tokens(token_str: str, tokenizer, lang):
    """Collect the ids of all vocabulary tokens that could start ``token_str``.

    Candidates are every prefix of ``token_str``, each with and without the
    '▁' word-start marker. For ``lang`` in ``['zh', 'ru']`` the byte token for
    the first UTF-8 byte of the string is added as well, if the vocabulary
    has one.

    Args:
        token_str: The word to cover.
        tokenizer: Object exposing ``get_vocab()`` (token string -> token id).
        lang: Language code of ``token_str``.

    Returns:
        A list of token ids; empty when no candidate is in the vocabulary.
    """
    # Fetch the vocabulary once; it was previously re-fetched up to twice per
    # candidate inside the loop.
    vocab = tokenizer.get_vocab()
    with_prefixes = token_prefixes(token_str)
    with_spaces = add_spaces(with_prefixes)
    candidates = capitalizations(with_spaces)
    final_tokens = [vocab[tok] for tok in candidates if tok in vocab]
    if lang in ['zh', 'ru']:
        tokid = unicode_prefix_tokid(token_str, tokenizer)
        if tokid is not None:
            final_tokens.append(tokid)
    return final_tokens

In [None]:

# Inverse vocabulary: token id -> token string.
id2voc = dict((tok_id, tok_str) for tok_str, tok_id in tokenizer.get_vocab().items())
def get_tokens(token_ids, id2voc=id2voc):
    """Translate token ids into token strings, keeping their order.

    Raises ``KeyError`` for an id that is missing from ``id2voc``.
    """
    token_strs = []
    for tokid in token_ids:
        token_strs.append(id2voc[tokid])
    return token_strs

def compute_entropy(probas):
    """Shannon entropy in bits, summed over the last dimension of ``probas``.

    Entries that are exactly zero contribute 0 (the usual ``0 * log 0 = 0``
    convention). Without this, ``0 * log2(0)`` evaluates to ``0 * -inf = NaN``
    and a single zero probability turns the whole entropy into NaN.
    NaN inputs still propagate to the result.
    """
    plogp = probas * torch.log2(probas)
    plogp = torch.where(probas == 0, torch.zeros_like(plogp), plogp)
    return (-plogp).sum(dim=-1)

# Display name used for each language code inside the prompts.
lang2name = {'fr': 'Français', 'de': 'Deutsch', 'ru': 'Русский', 'en': 'English', 'zh': '中文'}
def sample(df, ind, k=5, tokenizer=tokenizer, lang1='fr', lang2='de', lang_latent='en'):
    """Build one k-shot translation prompt whose final, open line is row ``ind``.

    This is a generator that yields at most one item; callers use
    ``next(sample(...))``.

    Args:
        df: Table with one column per language code used here.
        ind: Position (after resetting the index) of the row to be completed.
        k: Number of prompt lines: ``k - 1`` randomly drawn demonstrations
            plus the open line for row ``ind``.
        tokenizer: Tokenizer handed to ``process_tokens``.
        lang1: Column / language of the word shown on the left of each line.
        lang2: Column / language of the word to be produced.
        lang_latent: Column / language of the latent word being tracked.

    Yields:
        ``None`` when the sample is unusable: no vocabulary token covers the
        target or the latent word, or (for a non-English ``lang2``) the two
        token-id sets overlap. Otherwise a dict with the keys ``prompt``,
        ``out_token_id``, ``out_token_str``, ``latent_token_id``,
        ``latent_token_str`` and ``in_token_str``.
        Nothing is yielded when ``ind`` does not match any row.

    Fixes relative to the earlier version:
        * after yielding ``None`` the generator now stops; it used to fall
          through and also yield the rejected sample on later iterations;
        * the latent word is processed with ``lang_latent`` rather than a
          hard-coded ``'en'`` (identical for the default).
    """
    df = df.reset_index(drop=True)
    # The demonstrations are drawn from every row except the queried one.
    demo_rows = df[df.index != ind].sample(k - 1)
    query_rows = df[df.index == ind]
    if query_rows.empty:
        return
    query = query_rows.iloc[0]

    src_name, tgt_name = lang2name[lang1], lang2name[lang2]
    demo_lines = [
        f'{src_name}: "{row[lang1]}" - {tgt_name}: "{row[lang2]}"\n'
        for _, row in demo_rows.iterrows()
    ]
    # The last line ends with an opening quote for the model to continue.
    prompt = "".join(demo_lines) + f'{src_name}: "{query[lang1]}" - {tgt_name}: "'

    in_token_str = query[lang1]
    out_token_str = query[lang2]
    out_token_id = process_tokens(out_token_str, tokenizer, lang2)
    latent_token_str = query[lang_latent]
    latent_token_id = process_tokens(latent_token_str, tokenizer, lang_latent)

    if len(out_token_id) == 0 or len(latent_token_id) == 0:
        yield None
        return
    # Shared token ids would make target and latent probabilities indistinguishable.
    if lang2 != 'en' and set(out_token_id) & set(latent_token_id):
        yield None
        return
    yield {'prompt': prompt,
           'out_token_id': out_token_id,
           'out_token_str': out_token_str,
           'latent_token_id': latent_token_id,
           'latent_token_str': latent_token_str,
           'in_token_str': in_token_str}

In [None]:
# Build one prompt per row of the word table. `sample` yields None for rows
# it rejects; those are left out of the dataset.
candidate_samples = (
    next(sample(df_en_de_fr, ind, lang1=input_lang, lang2=target_lang))
    for ind in tqdm(range(len(df_en_de_fr)))
)
dataset = [d for d in candidate_samples if d is not None]

In [None]:
# Pick the neuron file (if any) for the chosen quantisation scheme and derive
# the label used for all output paths.
# NOTE(review): this cell overwrites `quant_type` for the "*_rand" variants,
# so re-running it without re-running the config cell takes the "combined"
# branch instead. Run the config cell first.
# NOTE(review): torch.load unpickles the file; only load neuron files from a
# trusted source.
random_variants = ("mpq_rand", "mpqr_rand", "twq_rand")
guided_variants = ("mpq", "mpqr", "twq")

if quant_type in random_variants:
    path, label_tail = "../neurons/random.neuron.pth", "_random"
elif quant_type in guided_variants:
    path, label_tail = "../neurons/combined.neuron.pth", "_combined"
else:
    label_tail = ""

if label_tail:
    target_neurons = torch.load(path)
    if label_tail == "_random":
        quant_type = quant_type[:-5]  # strip the trailing "_rand"
else:
    target_neurons = None
model_label = f"{quant_type}_{precision}bit{label_tail}"

In [None]:
# Persist the prompt dataset next to the figures. The file name carries a
# tag when one of the token filters is active.
df = pd.DataFrame(dataset)
translation_dir = f'{os.path.join(out_dir, model_label)}/translation'
os.makedirs(translation_dir, exist_ok=True)

if single_token_only:
    token_filter_tag = '_single_token'
elif multi_token_only:
    token_filter_tag = '_multi_token'
else:
    token_filter_tag = ''
df.to_csv(f'{translation_dir}/{model_size}_{input_lang}_{target_lang}_dataset{token_filter_tag}.csv', index=False)

In [None]:
# Map each quantisation scheme to the eval_core function that produces its latents.
latent_getters = {
    "hf": get_latents_from_hf_model,
    "bnb": get_latents_from_bnb_model,
    "gptq": get_latents_from_gptq_model,
    # The AWQ getter is only used at 4 bit; any other precision goes through the hf getter.
    "awq": get_latents_from_awq_model if precision == 4 else get_latents_from_hf_model,
    "rtn": get_latents_from_rtn_model,
    "twq": get_latents_from_hf_model,
    "mpq": get_latents_from_mpq_model,
    "mpqr": get_latents_from_mpqr_model,
}
if quant_type not in latent_getters:
    raise ValueError("Not yet implemented!")
latent_getter = latent_getters[quant_type]

# Run the model over the dataset, then move the two probability tensors to
# the CPU for plotting and saving.
latent_token_prob, out_token_prob, entropy, energy, latents = latent_getter(
    dataset, custom_model, precision=precision, targets=target_neurons)
latent_token_probs, out_token_probs = latent_token_prob.cpu(), out_token_prob.cpu()

In [None]:
# Figure 1: per-layer probability of the latent (English) token with the
# entropy heatmap, plus the probability of the target-language token.
size2tik = {'7b': 5, '13b': 5, '70b': 10, "tiny": 5}
fig, ax, ax2 = plot_ci_plus_heatmap(
    latent_token_probs,
    entropy,
    'en',
    color='tab:orange',
    tik_step=size2tik[model_size],
    do_colorbar=True,
    nums=[.99, 0.18, 0.025, 0.6],
)
if target_lang != 'en':
    plot_ci(ax2, out_token_probs, target_lang, color='tab:blue', do_lines=False)
ax2.set_xlabel('layer')
ax2.set_ylabel('probability')

# The small models use the exact layer count as x-limit; the others round it to a multiple of 10.
n_layers = out_token_probs.shape[1]
x_upper = n_layers + 1 if model_size in ('7b', 'tiny') else round(n_layers / 10) * 10 + 1
ax2.set_xlim(0, x_upper)
ax2.set_ylim(0, 1)
ax2.legend(loc='upper left')

figure_dir = f'{os.path.join(out_dir, model_label)}/translation'
os.makedirs(figure_dir, exist_ok=True)
if single_token_only:
    figure_tag = '_single_token'
elif multi_token_only:
    figure_tag = '_multi_token'
else:
    figure_tag = ''
plt.savefig(f'{figure_dir}/{model_size}_{input_lang}_{target_lang}_probas_ent{figure_tag}.pdf', dpi=300, bbox_inches='tight')

In [None]:
# Figure 2: per-layer energy.
#
# Fix: in single-/multi-token mode this cell used to save to the
# `..._probas_ent_single_token.pdf` / `..._probas_ent_multi_token.pdf` names,
# silently overwriting the probability figure written by the previous cell.
# All three modes now save under an `_energy` name; the unfiltered file name
# is unchanged.
size2tik = {'7b': 5, '13b': 5, '70b': 10, "tiny": 5}

fig, ax2 = plt.subplots(figsize=(5,3))
plot_ci(ax2, energy.cpu(), 'energy', color='tab:green', do_lines=True, tik_step=size2tik[model_size])
ax2.set_xlabel('layer')
ax2.set_ylabel('energy')
# Same x-range rule as the probability figure, so the two plots line up.
if model_size == '7b' or model_size == 'tiny':
    ax2.set_xlim(0, out_token_probs.shape[1]+1)
else:
    ax2.set_xlim(0, round(out_token_probs.shape[1]/10)*10+1)
os.makedirs(f'{os.path.join(out_dir, model_label)}/translation', exist_ok=True)
if single_token_only:
    plt.savefig(f'{os.path.join(out_dir, model_label)}/translation/{model_size}_{input_lang}_{target_lang}_energy_single_token.pdf', dpi=300, bbox_inches='tight')
elif multi_token_only:
    plt.savefig(f'{os.path.join(out_dir, model_label)}/translation/{model_size}_{input_lang}_{target_lang}_energy_multi_token.pdf', dpi=300, bbox_inches='tight')
else:
    plt.savefig(f'{os.path.join(out_dir, model_label)}/translation/{model_size}_{input_lang}_{target_lang}_energy.pdf', dpi=300, bbox_inches='tight')

In [None]:
# Save the per-layer probability tensors for later comparison across models.
suffix = "_rtn" if quant_type == "rtn" else ""  # not used by the saves below
tensor_dir = f'{os.path.join(out_dir, model_label)}/translation'
for probs_tensor, probs_kind in ((latent_token_probs, 'latent'), (out_token_probs, 'out')):
    torch.save(probs_tensor, f'{tensor_dir}/{model_label}_{input_lang}_{target_lang}_{probs_kind}_probs.pt')