# Imports

In [None]:
from __future__ import annotations
from typing import Tuple, List, Dict, Optional, Any

import numpy as np
import torch
"""import torch.nn as nn
import torch.nn.functional as F"""

from transformers import AutoModelForCausalLM, AutoTokenizer

import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
import tqdm

from helper_utils.enum_keys import (
    FPKey,
    DirPath,
    QuantStyle,
    MiscPrompts,
    Contexts,
    Texts,
    ModelKey
)

from PTQ.apply_ptq import applyPTQ

import helper_utils.utils as utils

from mech_interp_utils.utils_main.src.transformer_utils import (
    logit_lens,
    activation_lens,
    dictionary_learning,
    chatbot_analysis
)

import warnings
warnings.filterwarnings('ignore')

In [None]:
# Disabled torch._dynamo toggle, kept for reference. These lines are written as
# comments rather than a bare string so that the cell does not echo the text as output.
# import torch._dynamo
# torch._dynamo.config.suppress_errors = True

In [None]:
# Output directories for reports and SAE artifacts.
# NOTE(review): none of the cells shown here reference these names; the figure
# cells pass explicit paths instead.
LL_DIR, SAE_DIR, MISC_DIR = (
    'Outputs/Report/LogitLens',
    'Outputs/SAE',
    'Outputs/Report/Misc',
)

### Params

In [None]:
# Shared generation/analysis settings. The cells below only read the 'prompt'
# entry (via PARAMS.get('prompt')).
PARAMS: Dict = dict(
    context=Contexts.C1.value,
    prompt=MiscPrompts.Q2.value,
    max_new_tokens=250,
    temperature=0.8,
    repetition_penalty=1.1,
    sample=True,
    device=None,
)

In [None]:
# Display the layer-selection value that the applyPTQ cells below pass as layers_to_quant.
QuantStyle.BITNET.value

# Models and Tokenizer

### Helper: load full-precision models

In [None]:
def load_test_model(model_path: str, dtype: Optional[torch.dtype] = None) -> AutoModelForCausalLM:
    """Load a local causal-LM checkpoint for analysis.

    Args:
        model_path: Location of the checkpoint. Only local files are used
            (``local_files_only=True``), and weights are read in safetensors
            format (``use_safetensors=True``).
        dtype: Torch dtype to load the weights in, e.g. ``torch.float16`` or
            ``torch.float32``. ``None`` (the default) lets ``from_pretrained``
            use its own default dtype. The previous default was the
            ``torch.dtype`` class itself, which is not a valid dtype value.

    Returns:
        The loaded model. It is configured to return dict outputs and to
        include hidden states (``output_hidden_states=True``).
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        return_dict=True,
        output_hidden_states=True,
        torch_dtype=dtype,
        # Reduce peak host memory while materialising the weights.
        low_cpu_mem_usage=True,
        local_files_only=True,
        use_safetensors=True
    )

    return model

In [None]:
# Display the configured OLMo checkpoint locations that are passed to load_test_model.
FPKey.OLMO7B_FP.value

In [None]:
FPKey.OLMO1B_FP.value

In [None]:
# OLMo-1B tokenizer.
# NOTE(review): the key is spelled `OLMO1B_TOkENIZER` (lowercase k); confirm it
# matches the FPKey member name. Unlike the DeepHermes tokenizer cell, this call
# does not pass local_files_only=True.
olmo1b_tokenizer = AutoTokenizer.from_pretrained(FPKey.OLMO1B_TOkENIZER.value)

In [None]:
# OLMo-1B loaded in fp16; used as the unquantized model in the first logit-lens cell.
olmo1b_fp16 = load_test_model(FPKey.OLMO1B_FP.value, dtype=torch.float16)

In [None]:
olmo1b_fp16

In [None]:
# 1.58-bit PTQ of OLMo-1B loaded in fp32. Layers are selected by QuantStyle.BITNET
# and calibration uses MiscPrompts.Q11, the same prompt the OLMo logit-lens cells use.
# This loads the 1B checkpoint so that it matches the variable name, the OLMo-1B
# tokenizer, and the olmo1b_fp16 model it is compared against.
olmo1b_bitnet_fp32 = applyPTQ(
    load_test_model(FPKey.OLMO1B_FP.value, dtype=torch.float32),
    tokenizer=olmo1b_tokenizer,
    calibration_input=MiscPrompts.Q11.value,
    mode='1.58bit',
    qkv_safety=False,
    safer_quant=False,
    ffq=False,
    model_half=False,
    quant_half=True,
    layers_to_quant=QuantStyle.BITNET.value,
    act_quant=True,
    act_bits=8,
    dropout_prob=0.0,
    redundancy=0,
    frame_dropout_prob=0.0
)

## NousResearch/DeepHermes-3-Llama-3-8B-Preview (not run yet)

## NousResearch/DeepHermes-3-Llama-3-3B-Preview (28 layers 0-27)

### FP model and tokenizer

In [None]:
# DeepHermes-3B tokenizer, loaded from local files only.
deep3b_tokenizer = AutoTokenizer.from_pretrained(FPKey.TOKENIZER_3B.value, local_files_only=True)

In [None]:
# Unquantized baseline: the FP16_3B checkpoint loaded in fp32.
deep3b_fp32 = load_test_model(FPKey.FP16_3B.value, dtype=torch.float32)

In [None]:
deep3b_fp32

In [None]:
# The same checkpoint loaded in fp16. In the cells shown, it is only referenced
# by the commented-out fp16 model dicts.
deep3b_fp16 = load_test_model(FPKey.FP16_3B.value, dtype=torch.float16)

In [None]:
# Sanity check: print the dtype of every parameter tensor in the fp32 baseline.
for param in deep3b_fp32.parameters():
    print(param.data.dtype)

### Quantized DeepHermes 3B

In [None]:
# 1.58-bit PTQ of a fresh fp32 copy of the 3B checkpoint. Layers are selected by
# QuantStyle.BITNET, with activation quantization enabled (act_quant=True, act_bits=8).
# NOTE(review): the calibration string contains a double space ("the  lazy dog").
# Confirm this is intended, since it changes the tokenized calibration input.
deep3b_fp32_158bit_ptsq = applyPTQ(
    load_test_model(FPKey.FP16_3B.value, dtype=torch.float32),
    tokenizer=deep3b_tokenizer,
    calibration_input="The quick brown fox jumps over the  lazy dog",
    mode='1.58bit',
    qkv_safety=False,
    safer_quant=False,
    ffq=False,
    model_half=False,
    quant_half=True,
    layers_to_quant=QuantStyle.BITNET.value,
    act_quant=True,
    act_bits=8,
    dropout_prob=0.0,
    redundancy=0,
    frame_dropout_prob=0.0
)

In [None]:
# Inspect the 1.58-bit model: display it, print its weights, then fetch them for display.
deep3b_fp32_158bit_ptsq

In [None]:
utils.print_model_weights(deep3b_fp32_158bit_ptsq)

In [None]:
ws=utils.get_weights(deep3b_fp32_158bit_ptsq)
ws

In [None]:
# Symmetric PTQ of another fp32 copy of the 3B checkpoint, calibrated on
# PARAMS['prompt'], without activation quantization (act_quant=False).
# NOTE(review): the variable name says 2-bit, but mode='8bit_sym'. Confirm which
# bit-width is intended; later cells treat this as the 2-bit symmetric model.
deep3b_fp32_2bit_ptsq_sym = applyPTQ(
    load_test_model(FPKey.FP16_3B.value, dtype=torch.float32),
    tokenizer=deep3b_tokenizer,
    calibration_input=PARAMS.get('prompt'),
    mode='8bit_sym',
    qkv_safety=False,
    safer_quant=False,
    ffq=False,
    model_half=False,
    quant_half=False,
    layers_to_quant=QuantStyle.BITNET.value,
    act_quant=False,
    act_bits=8,
    dropout_prob=0.0,
    redundancy=0,
    frame_dropout_prob=0.0
)

# Chatbot Analysis

#### 3B Version Dicts

In [None]:
# Model dictionaries for chatbot_analysis.run_chatbot_analysis.
# Only dict_3b_fp32_sym is live. The other variants are kept as comments, not
# string literals, because a trailing bare string would be echoed as cell output.
# Several of them reference models that are not built in this notebook
# (fp16 PTQ, FFSQ and asymmetric variants).
#
# dict_3b_fp16_sym = {
#     #'3b-fp16': deep3b_fp16,
#     #'3b-fp16-1.58bit-ptsq': deep3b_fp16_158bit_ptsq,
#     #'3b-fp16-2bit-ptsq-sym': deep3b_fp16_2bit_ptsq_sym,
#     #'3b-fp16-2bit-ffsq-sym': deep3b_fp16_2bit_ffsq_sym,
# }

dict_3b_fp32_sym = {
    #'3b-fp32': deep3b_fp32,
    '3b-fp32-1.58bit-ptsq': deep3b_fp32_158bit_ptsq,
    #'3b-fp32-2bit-ffsq-sym': deep3b_fp32_2bit_ffsq_sym,
    #'3b-fp32-2bit-ptsq-sym': deep3b_fp32_2bit_ptsq_sym
}

# dict_3b_fp16_asym = {
#     '3b-fp16': deep3b_fp16,
#     '3b-fp16-1.58bit-ptsq': deep3b_fp16_158bit_ptsq,
#     '3b-fp16-2bit-ffsq-asym': deep3b_fp16_2bit_ffsq_asym,
#     '3b-fp16-2bit-ptsq-asym': deep3b_fp16_2bit_ptsq_asym
# }
#
# dict_3b_fp32_asym = {
#     '3b-fp32': deep3b_fp32,
#     '3b-fp32-1.58bit-ptsq': deep3b_fp32_158bit_ptsq,
#     '3b-fp32-2bit-ffsq-asym': deep3b_fp32_2bit_ffsq_asym,
#     '3b-fp32-2bit-ptsq-asym': deep3b_fp32_2bit_ptsq_asym
# }

In [None]:
# Run the chatbot analysis over every model in dict_3b_fp32_sym.
# Presumably this writes its logs under full_path, the same directory that
# the next cell plots from.
chatbot_analysis.run_chatbot_analysis(
    models=dict_3b_fp32_sym,
    tokenizer=deep3b_tokenizer,
    #device='cpu',
    full_path='logs/chatbot_logs/DeepHermes3B/SymPTQ'
)

In [None]:
# Plot the chatbot-analysis logs from the directory used above.
chatbot_analysis.plot_chatbot_analysis(
    json_logs='logs/chatbot_logs/DeepHermes3B/SymPTQ',
    parallel_plot=True
)

# Logit Lens and Activation Lens

In [None]:
# Logit lens for the OLMo-1B fp16 model on MiscPrompts.Q11 (figure not saved).
logit_lens.plot_logit_lens(
    model=olmo1b_fp16,
    tokenizer=olmo1b_tokenizer,
    input_ids=MiscPrompts.Q11.value,
    start_ix=0, end_ix=15,
    save_fig_path=None,
    #save_fig_path='Outputs/FP/ll_nwd_deep3b_fp16_math.png',
    #entropy=True,
)

In [None]:
# Move the 1.58-bit 3B model to CPU.
deep3b_fp32_158bit_ptsq.cpu()

In [None]:
# Logit lens for the 1.58-bit OLMo model on the same prompt and token range as the fp16 cell.
logit_lens.plot_logit_lens(
    model=olmo1b_bitnet_fp32,
    tokenizer=olmo1b_tokenizer,
    input_ids=MiscPrompts.Q11.value,
    start_ix=0, end_ix=15,
    save_fig_path=None,
    #save_fig_path='Outputs/FP/ll_logits_deep3b_fp32_math.png',
    #probs=True,
)

In [None]:
# Logit lens for the 1.58-bit DeepHermes-3B model on PARAMS['prompt'].
logit_lens.plot_logit_lens(
    model=deep3b_fp32_158bit_ptsq,
    tokenizer=deep3b_tokenizer,
    input_ids=PARAMS.get('prompt'),
    start_ix=0, end_ix=16,
    save_fig_path=None,
    #save_fig_path='Outputs/PTQ/ll_logits_deep3b_fp32_158bit_sym-math.png',
    #probs=True,
)

In [None]:
# Compare the fp32 baseline and the 1.58-bit PTQ model layer by layer (wasserstein=True).
# The figure is saved under Outputs/PTQ.
logit_lens.plot_comparing_lens(
    models=(deep3b_fp32, deep3b_fp32_158bit_ptsq),
    tokenizer=deep3b_tokenizer,
    input_ids=PARAMS.get('prompt'),
    start_ix=0, end_ix=16,
    save_fig_path='Outputs/PTQ/ll_nwd_deep3b_fp32_158bit_sym-math.png',
    #save_fig_path=None,
    wasserstein=True,
    #top_down=False,
)

In [None]:
# NOTE(review): `deep3b_2bit_ffsq_sym` is never built in this notebook (FFSQ models
# only appear in the commented-out model dicts), so this cell raises NameError as
# written. Build the FFSQ model first, or point this at an existing model.
logit_lens.plot_comparing_lens(
    models=(deep3b_fp32, deep3b_2bit_ffsq_sym),
    tokenizer=deep3b_tokenizer,
    input_ids=PARAMS.get('prompt'),
    start_ix=0, end_ix=16,
    #save_fig_path=f"{DirPath.LENS_VIS.value}/compare_nwd_deep3bfpptsq158bit_math.jpg",
    save_fig_path=None,
    wasserstein=True,
    #top_down=False,
)

In [None]:
logit_lens.plot_topk_lens(
    model=deep3b_fp32,
    tokenizer=deep3b_tokenizer,
    input_ids=PARAMS.get('prompt'),
    start_ix=0, end_ix=16,
    topk_n=5,
    #save_fig_path=f"{DirPath.LENS_VIS.value}/topk_5_logits_deep3bfp32_math.jpg",
    save_fig_path=None,
    #entropy=True,
    #top_down=False,
)

In [None]:
logit_lens.plot_topk_lens(
    model=deep3b_2bit_ffsq_sym,
    tokenizer=deep3b_tokenizer,
    input_ids=PARAMS.get('prompt'),
    start_ix=0, end_ix=16,
    topk_n=5,
    #save_fig_path=f"{DirPath.LENS_VIS.value}/topk_5_logits_deep3bfp32_math.jpg",
    save_fig_path=None,
    #entropy=True,
    #top_down=False,
)

In [None]:
logit_lens.plot_topk_lens(
    model=deep3b_2bit_ptsq_sym,
    tokenizer=deep3b_tokenizer,
    input_ids=PARAMS.get('prompt'),
    start_ix=0, end_ix=16,
    topk_n=5,
    #save_fig_path=f"{DirPath.LENS_VIS.value}/topk_5_logits_deep3bfp32_math.jpg",
    save_fig_path=None,
    #entropy=True,
    #top_down=False,
)

In [None]:
# Activation lens (metric='norm') for the fp32 baseline; the figure is saved under Outputs/FP.
activation_lens.plot_activation_lens(
    model=deep3b_fp32,
    tokenizer=deep3b_tokenizer,
    input_ids=PARAMS.get('prompt'),
    start_ix=0, end_ix=16,
    metric='norm',
    save_fig_path='Outputs/FP/act_norm_deep3b_fp32_math.png',
    #save_fig_path=None,
)

In [None]:
# Activation lens for the symmetric PTQ model (figure not saved).
# NOTE(review): the commented save path below names a 1.58-bit model ("158bi"),
# not this one. Rename it before enabling it.
activation_lens.plot_activation_lens(
    model=deep3b_fp32_2bit_ptsq_sym,
    tokenizer=deep3b_tokenizer,
    input_ids=PARAMS.get('prompt'),
    start_ix=0, end_ix=16,
    metric='norm',
    #save_fig_path='Outputs/PTQ/act_norm_deep3b_fp32_158bi-math.png',
    save_fig_path=None,
)

In [None]:
# Activation lens for the 1.58-bit PTQ model (figure not saved).
activation_lens.plot_activation_lens(
    model=deep3b_fp32_158bit_ptsq,
    tokenizer=deep3b_tokenizer,
    input_ids=PARAMS.get('prompt'),
    start_ix=0, end_ix=16,
    metric='norm',
    #save_fig_path=f"{DirPath.LENS_VIS.value}/act_norm_deep3b_fp32_math.jpg",
    save_fig_path=None,
)

In [None]:
activation_lens.plot_comparing_act_lens(
    models=(deep3b_fp32, deep3b_fp32_158bit_ptsq),
    tokenizer=deep3b_tokenizer,
    input_ids=PARAMS.get('prompt'),
    start_ix=0, end_ix=16,
    metric='norm',
    metric_name='l2',
    save_fig_path='Outputs/PTQ/act_norm-l2_deep3b_fp32_158bit_sym-math.png',
    #save_fig_path=None,
)

In [None]:
activation_lens.plot_comparing_act_lens(
    models=(deep3b_fp32, deep3b_158bit_ptsq),
    tokenizer=deep3b_tokenizer,
    input_ids=PARAMS.get('prompt'),
    start_ix=0, end_ix=16,
    metric='norm',
    metric_name='l2',
    #save_fig_path=f"{DirPath.LENS_VIS.value}/act_norml2_deep3b_fp_ffsq2bit_math.jpg",
    save_fig_path=None,
)

# Dictionary Learning: SAE

In [None]:
# SAE token plot for the 1.58-bit model at layer 5, with multi_tokens=False and no logging or saving.
dictionary_learning.plot_sae_tokens(
    model=deep3b_fp32_158bit_ptsq,
    tokenizer=deep3b_tokenizer,
    inputs=PARAMS.get('prompt'),
    multi_tokens=False,
    do_log=False,
    target_layers=[5],
    vis_projection=None,
    log_path=None,
    log_name=None,
    fig_path=None
)

In [None]:
# Same plot with multi_tokens=True.
dictionary_learning.plot_sae_tokens(
    model=deep3b_fp32_158bit_ptsq,
    tokenizer=deep3b_tokenizer,
    inputs=PARAMS.get('prompt'),
    multi_tokens=True,
    do_log=False,
    target_layers=[5],
    vis_projection=None,
    log_path=None,
    log_name=None,
    fig_path=None
)

In [None]:
dictionary_learning.plot_sae_heatmap(
    model=deep3b_fp16_2bit_ptsq_sym,
    tokenizer=deep3b_tokenizer,
    inputs=Texts.T1.value,
    do_log=False,
    top_k=5,
    tokens_per_row=30,
    target_layers=[5],
    log_path=None,
    log_name=None,
    fig_path=None
)

In [None]:
dictionary_learning.plot_sae_heatmap(
    model=deep3b_fp32,
    tokenizer=deep3b_tokenizer,
    inputs=Texts.T1.value,
    do_log=False,
    top_k=5,
    tokens_per_row=30,
    target_layers=[5],
    log_path=None,
    log_name=None,
    fig_path='Outputs/FP/sae5_deep3b_fp32_math.png',
)

In [None]:
dictionary_learning.plot_sae_heatmap(
    model=deep3b_fp32_158bit_ptsq,
    tokenizer=deep3b_tokenizer,
    inputs=Texts.T1.value,
    do_log=False,
    top_k=5,
    tokens_per_row=30,
    target_layers=[5],
    log_path=None,
    log_name=None,
    #fig_path='Outputs/PTQ/sae5_deep3b_fp32_158bit_sym-math.png',
)

In [None]:
dictionary_learning.plot_comparing_heatmap(
    models=(deep3b_fp32, deep3b_fp32_158bit_ptsq),
    tokenizer=deep3b_tokenizer,
    inputs=Texts.T1.value,
    top_k=5,
    tokens_per_row=30,
    target_layers=[5],
    fig_path='Outputs/PTQ/sae5_deep3b_fp32_158bit_sym-math.png',
)

In [None]:
dictionary_learning.plot_comparing_heatmap(
    models=(deep3b_fp32, deep3b_158bit_ptsq),
    tokenizer=deep3b_tokenizer,
    inputs=Texts.T1.value,
    top_k=5,
    tokens_per_row=30,
    target_layers=[5],
    fig_path=None
)