# Imports

In [None]:
from __future__ import annotations
from typing import Tuple, List, Dict, Optional, Any

import numpy as np
import torch
"""import torch.nn as nn
import torch.nn.functional as F"""
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset, DownloadMode

import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
import tqdm

from helper_utils.enum_keys import (
    FPKey,
    DirPath,
    QuantStyle,
    MiscPrompts,
    Contexts,
    Texts,
    ModelKey
)

from PTQ.apply_ptq import applyPTQ
from PTQ.olmo_act_fns import patch_olmo_mlp
import helper_utils.utils as utils

from mech_interp_utils.utils_main.src.transformer_utils import (
    logit_lens,
    activation_lens,
    dictionary_learning,
    chatbot_analysis
)

import warnings
warnings.filterwarnings('ignore')

In [None]:
"""import torch._dynamo
torch._dynamo.config.suppress_errors = True"""

In [None]:
# Output directories for report logit-lens figures, SAE artefacts and miscellaneous outputs.
# NOTE(review): the plotting cells visible in this notebook hard-code
# 'Outputs/LogitLens/...' paths and never reference these constants — confirm whether
# they are still needed or whether the cells should use them.
LL_DIR = 'Outputs/Report/LogitLens'
SAE_DIR = 'Outputs/SAE'
MISC_DIR = 'Outputs/Report/Misc'

In [None]:
# Local cache directory for the WikiText download (machine-specific Windows path).
filepath = r'D:\ThesisData\wikitext'

destination_path = str(Path(filepath))
# Load WikiText-103 (raw) with sliced splits: first 30% of train, first 10% of
# validation and of test. An existing cached copy is reused instead of re-downloading.
# Because `split` is a dict, the result is indexed by split name ('train' below).
dataset = load_dataset(
    'wikitext', 'wikitext-103-raw-v1',
    split={
        'train': 'train[:30%]',
        'validation': 'validation[:10%]',
        'test': 'test[:10%]',
    },
    cache_dir=destination_path,
    download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS,
    keep_in_memory=True
)

In [None]:
train_texts = dataset['train']

In [None]:
# Non-empty lines of the train split (drops non-string rows and whitespace-only lines).
# NOTE(review): the applyPTQ calls in the visible cells calibrate on a single prompt
# string instead — confirm whether this list is still needed.
calibration_texts = [t for t in dataset['train']["text"] if isinstance(t, str) and t.strip()]

In [None]:
# Small 20-row train subset for quick calibration experiments.
sub_txts = train_texts.take(20)

In [None]:
# Alternative: build calibration_texts from the 20-row subset only.
#calibration_texts = [t for t in sub_txts["text"] if isinstance(t, str) and t.strip()]

### Params

In [None]:
# Shared generation/analysis settings.
# Only 'prompt' is read by the visible cells (OLMo PTQ calibration input and the
# lens/SAE inputs). NOTE(review): the remaining keys are presumably consumed by
# generation code elsewhere; 'device': None presumably means "auto-select" — confirm.
PARAMS: Dict[str, Any] = {
    'context': Contexts.C1.value,
    'prompt': MiscPrompts.Q11.value,
    'max_new_tokens': 250,
    'temperature': 0.8,
    'repetition_penalty': 1.1,
    'sample': True,
    'device': None
}

# Models and Tokenizers

### Helper: load full-precision models

In [None]:
def load_test_model(model_path: str, dtype: Optional[torch.dtype] = None) -> AutoModelForCausalLM:
    """Load a causal-LM checkpoint from local disk for analysis.

    The model is configured to return dict outputs including all hidden states
    (presumably required by the lens / SAE utilities used in this notebook).

    Args:
        model_path: Path or identifier of the checkpoint. Only local files are
            used (``local_files_only=True``) and weights must be safetensors.
        dtype: Torch dtype to load the weights in, e.g. ``torch.float16``.
            ``None`` lets ``transformers`` pick its default dtype.

    Returns:
        The loaded model (the concrete ``PreTrainedModel`` subclass selected by
        ``AutoModelForCausalLM``).
    """
    # The previous default `dtype=torch.dtype` bound the *class* torch.dtype as the
    # default value, which is not a valid `torch_dtype`; None is the library's own
    # "use the default" value.
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        return_dict=True,
        output_hidden_states=True,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
        local_files_only=True,
        use_safetensors=True
    )

    return model

In [None]:
# OLMo-1B tokenizer.
# NOTE(review): unlike the DeepHermes tokenizer load further down, this call does not
# pass local_files_only=True, so it may try to reach the Hub — confirm that is intended.
olmo_tokenizer = AutoTokenizer.from_pretrained(FPKey.OLMO1B_TOKENIZER.value)

In [None]:
# Full-precision baseline: OLMo-1B loaded in float16.
olmo_fp16 = load_test_model(FPKey.OLMO1B_FP.value, dtype=torch.float16)

In [None]:
# Display the module tree (notebook output only).
olmo_fp16

In [None]:
# Display the model config (notebook output only).
olmo_fp16.config

In [None]:
# Post-training quantization of a separately loaded float32 copy of OLMo-1B, so the
# fp16 baseline above is not the object passed to applyPTQ.
# Calibration uses the single prompt from PARAMS; the commented line switches to the
# 20-row WikiText subset instead.
# mode='1.58bit' presumably means ternary weights (log2(3) ~ 1.58 bits); act_quant with
# act_bits=8 requests 8-bit activation quantization. Exact semantics of every flag live
# in PTQ.apply_ptq — confirm there.
olmo1b_bitnet_fp32 = applyPTQ(
    load_test_model(FPKey.OLMO1B_FP.value, dtype=torch.float32),
    tokenizer=olmo_tokenizer,
    #calibration_input=sub_txts['text'],
    calibration_input=PARAMS.get('prompt'),
    mode='1.58bit',
    safer_quant=True,
    model_half=False,
    quant_half=False,
    layers_to_quant_weights=QuantStyle.BITNET.value,
    layers_to_quant_activations=QuantStyle.BITNET.value,
    fragile_layers=False,
    act_quant=True,
    act_bits=8,
    debugging=True,
    plot_debugging=False,
    plot_quantization=False,
    freeze_modules=True
)

In [None]:
# Display the quantized model's module tree (notebook output only).
olmo1b_bitnet_fp32

### DeepHermes 3B: full-precision models and tokenizer

In [None]:
# DeepHermes 3B tokenizer, loaded from local files only.
deep3b_tokenizer = AutoTokenizer.from_pretrained(FPKey.TOKENIZER_3B.value, local_files_only=True)

In [None]:
# Same FP16_3B checkpoint, loaded as float32.
deep3b_fp32 = load_test_model(FPKey.FP16_3B.value, dtype=torch.float32)

In [None]:
# Display the module tree (notebook output only).
deep3b_fp32

In [None]:
# Same checkpoint again in float16. NOTE: this keeps a second full copy of the 3B
# model in memory alongside deep3b_fp32.
deep3b_fp16 = load_test_model(FPKey.FP16_3B.value, dtype=torch.float16)

In [None]:
# Summarise the parameter dtypes of the float32 DeepHermes model (should be float32 only).
# A Counter prints one entry per distinct dtype instead of one line per parameter tensor.
from collections import Counter

print(Counter(param.dtype for param in deep3b_fp32.parameters()))

### Quantized DeepHermes 3B

In [None]:
# 1.58-bit PTQ of a freshly loaded float32 copy of DeepHermes 3B.
# NOTE(review): this call uses a different keyword set than the OLMo applyPTQ call
# earlier in the notebook (qkv_safety, ffq, layers_to_quant, dropout_prob, redundancy,
# frame_dropout_prob vs. layers_to_quant_weights / layers_to_quant_activations /
# fragile_layers / freeze_modules). One of the two is presumably stale relative to
# PTQ.apply_ptq's current signature — check before re-running, or it may raise TypeError.
# NOTE(review): the calibration string contains a double space ("the  lazy"); likely a
# typo, but changing it changes the calibration tokens and therefore the quantization.
# NOTE(review): quant_half=True although the variable name says fp32 — confirm the
# intended precision of the quantized layers.
deep3b_fp32_158bit_ptsq = applyPTQ(
    load_test_model(FPKey.FP16_3B.value, dtype=torch.float32),
    tokenizer=deep3b_tokenizer,
    calibration_input="The quick brown fox jumps over the  lazy dog",
    mode='1.58bit',
    qkv_safety=False,
    safer_quant=False,
    ffq=False,
    model_half=False,
    quant_half=True,
    layers_to_quant=QuantStyle.BITNET.value,
    act_quant=True,
    act_bits=8,
    dropout_prob=0.0,
    redundancy=0,
    frame_dropout_prob=0.0
)

# Chatbot Analysis (text template only)

#### DeepHermes 3B model dictionaries

In [None]:
"""dict_3b_fp16_sym = {
    #'3b-fp16': deep3b_fp16,
    #'3b-fp16-1.58bit-ptsq': deep3b_fp16_158bit_ptsq,
    #'3b-fp16-2bit-ptsq-sym': deep3b_fp16_2bit_ptsq_sym,
    #'3b-fp16-2bit-ffsq-sym': deep3b_fp16_2bit_ffsq_sym,
}"""

dict_3b_fp32_sym = {
    #'3b-fp32': deep3b_fp32,
    '3b-fp32-1.58bit-ptsq': deep3b_fp32_158bit_ptsq,
    #'3b-fp32-2bit-ffsq-sym': deep3b_fp32_2bit_ffsq_sym,
    #'3b-fp32-2bit-ptsq-sym': deep3b_fp32_2bit_ptsq_sym
}
"""
dict_3b_fp16_asym = {
    '3b-fp16': deep3b_fp16,
    '3b-fp16-1.58bit-ptsq': deep3b_fp16_158bit_ptsq,
    '3b-fp16-2bit-ffsq-asym': deep3b_fp16_2bit_ffsq_asym,
    '3b-fp16-2bit-ptsq-asym': deep3b_fp16_2bit_ptsq_asym
}

dict_3b_fp32_asym = {
    '3b-fp32': deep3b_fp32,
    '3b-fp32-1.58bit-ptsq': deep3b_fp32_158bit_ptsq,
    '3b-fp32-2bit-ffsq-asym': deep3b_fp32_2bit_ffsq_asym,
    '3b-fp32-2bit-ptsq-asym': deep3b_fp32_2bit_ptsq_asym
}"""

In [None]:
# Run the chatbot analysis for every model in dict_3b_fp32_sym, writing logs under
# full_path. Device selection is left to the helper (explicit 'cpu' is commented out).
chatbot_analysis.run_chatbot_analysis(
    models=dict_3b_fp32_sym,
    tokenizer=deep3b_tokenizer,
    #device='cpu',
    full_path='logs/chatbot_logs/DeepHermes3B/SymPTQ'
)

In [None]:
# Plot the logs written by the cell above; this path must match its full_path.
chatbot_analysis.plot_chatbot_analysis(
    json_logs='logs/chatbot_logs/DeepHermes3B/SymPTQ',
    parallel_plot=True
)

# Activation Lens and Logit Lens

In [None]:
# Logit lens of the fp16 OLMo-1B baseline; the start_ix/end_ix window semantics are
# defined by logit_lens.plot_logit_lens.
# The prompt is read from PARAMS (same value as MiscPrompts.Q11) so every lens cell
# uses a single source of truth for the input text.
logit_lens.plot_logit_lens(
    model=olmo_fp16,
    tokenizer=olmo_tokenizer,
    input_ids=PARAMS.get('prompt'),
    start_ix=0, end_ix=15,
    #save_fig_path=None,
    save_fig_path='Outputs/LogitLens/logits_allenai-olmo1b-fp16.jpg',
    #entropy=True,
)

In [None]:
# Logit lens of the 1.58-bit PTQ OLMo-1B model (calibrated on a single prompt).
# The prompt is read from PARAMS (same value as MiscPrompts.Q11) so every lens cell
# uses a single source of truth for the input text.
logit_lens.plot_logit_lens(
    model=olmo1b_bitnet_fp32,
    tokenizer=olmo_tokenizer,
    input_ids=PARAMS.get('prompt'),
    start_ix=0, end_ix=15,
    #save_fig_path=None,
    save_fig_path='Outputs/LogitLens/logits_allenai-olmo1b-singlecalibration.jpg',
    #entropy=True,
)

In [None]:
# Compare the logit lens of the fp16 baseline against the PTQ model on the same prompt.
# wasserstein=True presumably selects a Wasserstein-distance comparison (the 'nwd'
# filename prefix suggests a normalised variant) — confirm in logit_lens.
logit_lens.plot_comparing_lens(
    models=(olmo_fp16, olmo1b_bitnet_fp32),
    tokenizer=olmo_tokenizer,
    input_ids=PARAMS.get('prompt'),
    start_ix=0, end_ix=15,
    save_fig_path='Outputs/LogitLens/nwd_allenai-olmo1bfp16-ptsq-singlecalibration.jpg',
    #save_fig_path=None,
    wasserstein=True,
    #top_down=False,
)

In [None]:
# Display the quantized model's module tree (notebook output only).
olmo1b_bitnet_fp32

In [None]:
# Top-5 logit lens of the fp16 OLMo-1B baseline.
# The 'topk5-' filename prefix matches the quantized-model counterpart figure
# ('topk5-logits_allenai-olmo1b-ptdq.jpg'); it previously read 'tokk5-' (typo).
logit_lens.plot_topk_lens(
    model=olmo_fp16,
    tokenizer=olmo_tokenizer,
    input_ids=PARAMS.get('prompt'),
    start_ix=0, end_ix=15,
    topk_n=5,
    save_fig_path='Outputs/LogitLens/topk5-logits_allenai-olmo1b-fp16.jpg',
    #save_fig_path=None,
    #entropy=True,
    #top_down=False,
)

In [None]:
# Top-5 logit lens of the PTQ OLMo-1B model.
# NOTE(review): figures for this same model use different suffixes across cells
# ('ptdq', 'ptsdq', 'ptsq-singlecalibration', 'singlecalibration') — confirm the
# intended naming so report figures can be matched up.
logit_lens.plot_topk_lens(
    model=olmo1b_bitnet_fp32,
    tokenizer=olmo_tokenizer,
    input_ids=PARAMS.get('prompt'),
    start_ix=0, end_ix=15,
    topk_n=5,
    save_fig_path='Outputs/LogitLens/topk5-logits_allenai-olmo1b-ptdq.jpg'
    #save_fig_path=None,
    #entropy=True,
    #top_down=False,
)

In [None]:
# Activation lens (metric='norm') of the fp16 baseline.
activation_lens.plot_activation_lens(
    model=olmo_fp16,
    tokenizer=olmo_tokenizer,
    input_ids=PARAMS.get('prompt'),
    start_ix=0, end_ix=15,
    metric='norm',
    save_fig_path='Outputs/LogitLens/actnorm_allenai-olmo1b-fp16.jpg'
    #save_fig_path=None,
)

In [None]:
# Activation lens (metric='norm') of the PTQ model.
activation_lens.plot_activation_lens(
    model=olmo1b_bitnet_fp32,
    tokenizer=olmo_tokenizer,
    input_ids=PARAMS.get('prompt'),
    start_ix=0, end_ix=15,
    metric='norm',
    save_fig_path='Outputs/LogitLens/actnorm_allenai-olmo1b-ptsdq.jpg'
    #save_fig_path=None,
)

In [None]:
# Compare activation norms of baseline vs PTQ model; metric_name='l2' presumably
# labels/selects the L2 comparison — confirm in activation_lens.
activation_lens.plot_comparing_act_lens(
    models=(olmo_fp16, olmo1b_bitnet_fp32),
    tokenizer=olmo_tokenizer,
    input_ids=PARAMS.get('prompt'),
    start_ix=0, end_ix=15,
    metric='norm',
    metric_name='l2',
    save_fig_path='Outputs/LogitLens/actnorml2_allenai-olmo1bfp16-ptdq.jpg'
    #save_fig_path=None,
)

# Dictionary Learning: SAE

In [None]:
# SAE token view for layer 5 of the PTQ model, single-token mode.
# do_log=False and None paths: presumably nothing is logged or saved, display only.
dictionary_learning.plot_sae_tokens(
    model=olmo1b_bitnet_fp32,
    tokenizer=olmo_tokenizer,
    inputs=PARAMS.get('prompt'),
    multi_tokens=False,
    do_log=False,
    target_layers=[5],
    vis_projection=None,
    log_path=None,
    log_name=None,
    fig_path=None
)

In [None]:
# Same SAE token view for the PTQ model, multi-token mode.
dictionary_learning.plot_sae_tokens(
    model=olmo1b_bitnet_fp32,
    tokenizer=olmo_tokenizer,
    inputs=PARAMS.get('prompt'),
    multi_tokens=True,
    do_log=False,
    target_layers=[5],
    vis_projection=None,
    log_path=None,
    log_name=None,
    fig_path=None
)

In [None]:
# Multi-token SAE view for the fp16 baseline (for comparison with the cell above).
dictionary_learning.plot_sae_tokens(
    model=olmo_fp16,
    tokenizer=olmo_tokenizer,
    inputs=PARAMS.get('prompt'),
    multi_tokens=True,
    do_log=False,
    target_layers=[5],
    vis_projection=None,
    log_path=None,
    log_name=None,
    fig_path=None
)

In [None]:
# SAE heatmap of layer 5 for the fp16 baseline on the longer text Texts.T1
# (top_k=5 features, 30 tokens per heatmap row).
dictionary_learning.plot_sae_heatmap(
    model=olmo_fp16,
    tokenizer=olmo_tokenizer,
    inputs=Texts.T1.value,
    do_log=False,
    top_k=5,
    tokens_per_row=30,
    target_layers=[5],
    log_path=None,
    log_name=None,
    fig_path=None
)

In [None]:
# Same SAE heatmap for the PTQ model.
dictionary_learning.plot_sae_heatmap(
    model=olmo1b_bitnet_fp32,
    tokenizer=olmo_tokenizer,
    inputs=Texts.T1.value,
    do_log=False,
    top_k=5,
    tokens_per_row=30,
    target_layers=[5],
    log_path=None,
    log_name=None,
    fig_path=None
)

In [None]:
# Side-by-side SAE heatmap comparison of baseline vs PTQ model.
# NOTE(review): the commented fig_path names a DeepHermes 3B figure, not OLMo-1B —
# update it before re-enabling saving.
dictionary_learning.plot_comparing_heatmap(
    models=(olmo_fp16, olmo1b_bitnet_fp32),
    tokenizer=olmo_tokenizer,
    inputs=Texts.T1.value,
    top_k=5,
    tokens_per_row=30,
    target_layers=[5],
    #fig_path='Outputs/PTQ/sae5_deep3b_fp32_158bit_sym-math.png',
    fig_path=None
)