In [1]:
import os
import transformers
import torch
import numpy as np
import pandas as pd
from pprint import pprint
import logging
from transformers import AutoTokenizer, AutoModelForCausalLM
from unsloth import FastLanguageModel
# from prompts import QUERY_INTRO_NO_ANS, SYSTEM_MSG
from sklearn.model_selection import train_test_split
from torch import nn
from sklearn.metrics import roc_auc_score
from collections import Counter

# from config import *

  from .autonotebook import tqdm as notebook_tqdm


🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


In [3]:
%reload_ext autoreload
%autoreload 2
from utils_train import fit, validate, to_dataloader

In [4]:
# MODEL_ID = "google/gemma-2-2b-it"
MODEL_ID = "unsloth/gemma-2-9b-it-bnb-4bit"
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
HUG_TOKEN = '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'
DATA_DIR = "data"
DATA_FILE = "test_gemma_resp.csv"
V1 = 'gemma-2-9b-it-bnb-4bit__v1'
V2 = 'gemma-2-9b-it-bnb-4bit__v2'

In [4]:
df = pd.read_csv(os.path.join("..", DATA_DIR, DATA_FILE))

In [120]:
class MLP(nn.Module):

    def __init__(self, input_dim, hidden_dim, output_dim):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.batch_norm = nn.BatchNorm1d(hidden_dim)
        self.elu = nn.ELU()
        self.dropout = nn.Dropout(0.25)

    def forward(self, x, mode: str = 'train'):
        x = self.fc1(x)
        x = self.batch_norm(x)
        x = self.elu(x)

        if mode == 'train':
            x = self.dropout(x)

        x = self.fc2(x)
        return x

In [15]:
def prepare_prompt(
        tokenizer,
        user_input: str, 
        system_input: str = "",
        has_system_role: bool = False) -> list:
    
    messages = []
    
    if has_system_role:
        messages.append({"role": "system", "content": system_input})

    messages = [
        {
            "role": "user", 
            "content": f"{system_input}{user_input}" 
                if not has_system_role 
                else user_input
        },
    ]

    prompt = tokenizer.apply_chat_template(
        messages, 
        tokenize=False, 
        add_generation_prompt=True
    )

    return prompt

In [1]:
pip list

Package                            Version
---------------------------------- ------------
accelerate                         0.33.0
aiohappyeyeballs                   2.3.5
aiohttp                            3.10.1
aiohttp-retry                      2.8.3
aiosignal                          1.3.1
alembic                            1.13.2
amqp                               5.2.0
aniso8601                          9.0.1
annotated-types                    0.7.0
antlr4-python3-runtime             4.9.3
anyio                              4.4.0
appdirs                            1.4.4
asttokens                          2.4.1
async-timeout                      4.0.3
asyncssh                           2.15.0
atpublic                           5.0
attrs                              24.2.0
auto_gptq                          0.7.1+cu118
bcrypt                             4.2.0
billiard                           4.2.0
bitsandbytes                       0.43.3
blinker                            1.8

In [16]:
def get_zero_indices(tensor):

    zero_mask = (tensor == 0).all(dim=(1, 2, 3))
    indices = torch.nonzero(zero_mask).squeeze()
    
    return indices

In [5]:
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=MODEL_ID,
    max_seq_length=8192,
    dtype=None,
    device_map={"": DEVICE},
    load_in_4bit=True
)

FastLanguageModel.for_inference(model)

Unsloth: If you want to finetune Gemma 2, install flash-attn to make it faster!
To install flash-attn, do the below:

pip install --no-deps --upgrade "flash-attn>=2.6.3"
==((====))==  Unsloth 2024.8: Fast Gemma2 patching. Transformers = 4.43.2.
   \\   /|    GPU: NVIDIA L4. Max memory: 21.964 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.3.0+cu121. CUDA = 8.9. CUDA Toolkit = 12.1.
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.26.post1. FA2 = False]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth


In [21]:
input_text = "W którym roku Korona Kielce wygrała ekstraklasę?"
input_ids = tokenizer(input_text, return_tensors="pt").to(DEVICE)

In [11]:
input_ids

{'input_ids': tensor([[     2, 235325,  87004,  21653,  19849,   2977,  76950,    532,   6498,
           1025,   8222,  79476,  88182, 235508, 235336]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')}

In [22]:
pred = model.forward(
    input_ids=input_ids.get("input_ids"),
    attention_mask=input_ids.get("attention_mask"),
    output_attentions=True,
)

AssertionError: 

In [19]:
for n in model.model.layers:
    print(n)

Gemma2DecoderLayer(
  (self_attn): Gemma2Attention(
    (q_proj): Linear4bit(in_features=3584, out_features=4096, bias=False)
    (k_proj): Linear4bit(in_features=3584, out_features=2048, bias=False)
    (v_proj): Linear4bit(in_features=3584, out_features=2048, bias=False)
    (o_proj): Linear4bit(in_features=4096, out_features=3584, bias=False)
    (rotary_emb): GemmaFixedRotaryEmbedding()
  )
  (mlp): Gemma2MLP(
    (gate_proj): Linear4bit(in_features=3584, out_features=14336, bias=False)
    (up_proj): Linear4bit(in_features=3584, out_features=14336, bias=False)
    (down_proj): Linear4bit(in_features=14336, out_features=3584, bias=False)
    (act_fn): PytorchGELUTanh()
  )
  (input_layernorm): Gemma2RMSNorm()
  (post_attention_layernorm): Gemma2RMSNorm()
  (pre_feedforward_layernorm): Gemma2RMSNorm()
  (post_feedforward_layernorm): Gemma2RMSNorm()
)
Gemma2DecoderLayer(
  (self_attn): Gemma2Attention(
    (q_proj): Linear4bit(in_features=3584, out_features=4096, bias=False)
    (k_p

In [14]:
pred.hidden_states

(tensor([[[-0.6719,  0.1553,  0.4805,  ..., -0.0378,  0.2598,  0.1982],
          [-0.9141, -0.0952, -2.4844,  ..., -1.0234,  0.5898, -1.0469],
          [ 0.3008, -0.8789,  1.4297,  ..., -0.7422,  0.1338, -0.6055],
          ...,
          [-2.9375, -1.6250, -6.1562,  ...,  2.2969,  1.0859,  0.5586],
          [-0.3809, -3.1406, -0.6406,  ...,  1.2500, -1.7422, -2.6719],
          [ 0.3281, -3.1406, -1.6016,  ..., -0.3379, -3.1875, -1.0703]]],
        device='cuda:0', dtype=torch.bfloat16),
 tensor([[[-0.2334, -1.9766, -0.0742,  ..., -0.1406,  0.0352,  0.2178],
          [-0.4590,  0.7617, -2.1562,  ..., -1.0000,  1.9531, -0.2461],
          [ 1.5000,  2.7188,  0.6953,  ..., -0.4258, -0.7227, -0.3828],
          ...,
          [-2.1875, -2.3125, -4.6562,  ...,  3.7344,  1.2422,  1.0234],
          [ 0.0869, -1.4688,  0.4219,  ...,  0.7305, -1.6641, -2.1250],
          [ 0.7422,  1.1094, -0.3633,  ..., -0.3848, -1.7734, -0.4805]]],
        device='cuda:0', dtype=torch.bfloat16),
 tenso

In [8]:
del input_ids, pred

In [9]:
for i, row in df.iterrows():
    
    query = row[QUERY_COL]
    context = row[CONTEXT_COL]
    previous_answer = row['gemma-2-9b-it-bnb-4bit']
    
    augumented_prompt = QUERY_INTRO_NO_ANS.format(query = query, context = context)
    prompt = prepare_prompt(tokenizer, augumented_prompt, SYSTEM_MSG)

    input_ids = tokenizer(prompt, return_tensors="pt").to(DEVICE)

    try:

        pred = model.forward(
            input_ids=input_ids.get("input_ids"),
            output_hidden_states=True,
        )

        del input_ids

        pred.hidden_states = [h.detach().cpu() for h in pred.hidden_states]
        hidden_states = torch.stack(pred.hidden_states[-8:])[:, :, -1, :]

    except Exception as e:
        print(f"Error: {e}")

    else:
        print(f"{hidden_states.shape = }, {type(hidden_states) = }")
        print(f"Saving hidden states for query: {i}")
        torch.save(hidden_states, os.path.join(DATA_DIR, 'hidden_states', f'hidden_states_{i}.pt'))
        print(f"Hidden states saved for query: {i}")
        
    finally:
        torch.cuda.empty_cache()

hidden_states.shape = torch.Size([8, 1, 3584]), type(hidden_states) = <class 'torch.Tensor'>
Saving hidden states for query: 0
Hidden states saved for query: 0
hidden_states.shape = torch.Size([8, 1, 3584]), type(hidden_states) = <class 'torch.Tensor'>
Saving hidden states for query: 1
Hidden states saved for query: 1
hidden_states.shape = torch.Size([8, 1, 3584]), type(hidden_states) = <class 'torch.Tensor'>
Saving hidden states for query: 2
Hidden states saved for query: 2
hidden_states.shape = torch.Size([8, 1, 3584]), type(hidden_states) = <class 'torch.Tensor'>
Saving hidden states for query: 3
Hidden states saved for query: 3
hidden_states.shape = torch.Size([8, 1, 3584]), type(hidden_states) = <class 'torch.Tensor'>
Saving hidden states for query: 4
Hidden states saved for query: 4
hidden_states.shape = torch.Size([8, 1, 3584]), type(hidden_states) = <class 'torch.Tensor'>
Saving hidden states for query: 5
Hidden states saved for query: 5
hidden_states.shape = torch.Size([8, 1, 

Unsloth: Input IDs of length 9125 > the model's max sequence length of 8192.
We shall truncate it ourselves. It's imperative if you correct this issue first.


Error: CUDA out of memory. Tried to allocate 56.00 MiB. GPU 
hidden_states.shape = torch.Size([8, 1, 3584]), type(hidden_states) = <class 'torch.Tensor'>
Saving hidden states for query: 273
Hidden states saved for query: 273
hidden_states.shape = torch.Size([8, 1, 3584]), type(hidden_states) = <class 'torch.Tensor'>
Saving hidden states for query: 274
Hidden states saved for query: 274
hidden_states.shape = torch.Size([8, 1, 3584]), type(hidden_states) = <class 'torch.Tensor'>
Saving hidden states for query: 275
Hidden states saved for query: 275
hidden_states.shape = torch.Size([8, 1, 3584]), type(hidden_states) = <class 'torch.Tensor'>
Saving hidden states for query: 276
Hidden states saved for query: 276
hidden_states.shape = torch.Size([8, 1, 3584]), type(hidden_states) = <class 'torch.Tensor'>
Saving hidden states for query: 277
Hidden states saved for query: 277
hidden_states.shape = torch.Size([8, 1, 3584]), type(hidden_states) = <class 'torch.Tensor'>
Saving hidden states for q

In [28]:
sorted_files = sorted(os.listdir(os.path.join(DATA_DIR, 'hidden_states')), key=lambda x: int(x.split('.')[0].split('_')[-1]))

In [29]:
dataset_vec = None
missing = []

for f in sorted_files[:300]:

    print(f"Processing hidden states for query: {f}")
    
    hidden_states = torch.load(os.path.join(DATA_DIR, 'hidden_states', f))
    if torch.all(hidden_states == 0):
        print(f"Hidden states are all zero for query: {f}")
        missing.append(f)
    
    if dataset_vec is None:
        dataset_vec = hidden_states
    else:
        dataset_vec = torch.cat((dataset_vec, hidden_states), dim=0)

Processing hidden states for query: hidden_states_0.pt
Processing hidden states for query: hidden_states_1.pt
Processing hidden states for query: hidden_states_2.pt
Processing hidden states for query: hidden_states_3.pt
Processing hidden states for query: hidden_states_4.pt
Processing hidden states for query: hidden_states_5.pt
Processing hidden states for query: hidden_states_6.pt
Processing hidden states for query: hidden_states_7.pt
Processing hidden states for query: hidden_states_8.pt
Processing hidden states for query: hidden_states_9.pt
Processing hidden states for query: hidden_states_10.pt
Processing hidden states for query: hidden_states_11.pt
Processing hidden states for query: hidden_states_12.pt
Processing hidden states for query: hidden_states_13.pt
Processing hidden states for query: hidden_states_14.pt
Processing hidden states for query: hidden_states_15.pt
Processing hidden states for query: hidden_states_16.pt
Processing hidden states for query: hidden_states_17.pt
Pr

In [35]:
missing

[164, 176, 269, 272]

In [20]:
# a, b, c = hidden_states.shape
# dataset_vec = dataset_vec.reshape(len(df), a, b, c).to(torch.float32).numpy()

In [13]:
X = np.load(os.path.join(DATA_DIR, 'X.npy'))

In [18]:
df[V1].value_counts(), df[V2].value_counts()

(gemma-2-9b-it-bnb-4bit__v1
  1    3147
 -1     548
  0     415
 Name: count, dtype: int64,
 gemma-2-9b-it-bnb-4bit__v2
  1    2888
 -1     813
  0     409
 Name: count, dtype: int64)

In [6]:
missing = []
to_skip_indicies = set(df.loc[df[V1] != df[V2]].index.tolist()).union(set(missing))
df_filtered = df.drop(to_skip_indicies)

In [17]:
df_filtered['context']

0       `Dokument [ "ProceduraMZS_Monitorowanie_SOZ_i_...
1       `Dokument [ "ProceduraMZS_Monitorowanie_SOZ_i_...
2       `Dokument [ "ProceduraMZS_WSKUE_włączanie_wyłą...
3       `Dokument [ "ProceduraMZS_Monitorowanie_SOZ_i_...
4       `Dokument [ "ProceduraMZS_Wsad_v3.21.docx" ]:`...
                              ...                        
4105    `Dokument [ "History of the New Jersey State C...
4106    `Dokument [1]:` Kino Polonia – kino w Łodzi zn...
4107    `Dokument [ "Josh Fogg" ]:` Joshua Smith Fogg ...
4108    `Dokument [1]:` Abigail Deveraux to fikcyjna p...
4109    `Dokument [1]:` Legenda o niebieskich oczach J...
Name: context, Length: 3706, dtype: object

In [20]:
df_filtered['tokens'] = df_filtered['context'].apply(lambda x: len(x.split(" ")) if isinstance(x, str) else 0)

In [24]:
df_filtered.loc[df_filtered[V2] == -1].sort_values(by='tokens', ascending=True)[['context', 'tokens']]

Unnamed: 0,context,tokens
886,`Dokument [1]:` Ludwik Tadeusz Waryński (ur. 2...,32
2916,`Dokument [1]:` Wacław Stanisław Sitkowski (ur...,37
412,`Dokument [1]:` Paul William Hodes (ur. 21 mar...,93
2380,`Dokument [1]:` Refektarz Infirmerii – jadalni...,94
984,`Dokument [1]:` Puchar Interkontynentalny 1969...,97
...,...,...
260,"`Dokument [ ""ProceduraMZS_Interakcja_v3.91.doc...",3677
268,"`Dokument [ ""ProceduraMZS_Weryfikacja_Monitora...",3746
176,"`Dokument [ ""ProceduraMZS_Monitorowanie_SOZ_i_...",3976
275,"`Dokument [ ""ProceduraMZS_Monitor_domen_TUXEDO...",4588


In [28]:
print(df_filtered.iloc[2380]['formatted'])

You are a helpful assistant. Your job will be to answer questions accurately based on the given context and not your internal knowledge.
    If you can not answer the question only based on the provided context, return the answer: `Nie mogę udzielić odpowiedzi na to pytanie na podstawie podanego kontekstu`.
    Pay special attention to the names of applications, services, tools, and components - it is crucial to return consistent information for the subject. 
    Think of it step by step: 
        1. Find relevant information in the provided context. 
        2a. If there is no information relevant to the query, return the answer: `Nie mogę udzielić odpowiedzi na to pytanie na podstawie podanego kontekstu`
        2b. If information is relevant to the query, based on the context's relevant information, formulate the final answer.
Your answers MUST be written in POLISH language.
The context will be provided by `CONTEXT`, the user query by `QUERY`, and your job is to return the answer `A

In [37]:
# X = dataset_vec[df_filtered.index]
y = df_filtered[V2].values


y = np.where(y == 0, 1, y)
y = np.where(y == -1, 0, y)
y = y.astype(np.int8)

In [39]:
y = np.load(os.path.join(DATA_DIR, 'y.npy'))

In [40]:
X.shape, y.shape

((3702, 8, 1, 3584), (3702,))

In [41]:
y.shape

(3702,)

In [42]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
X_train, X_test = X_train[:, -2, :, :].squeeze(), X_test[:, -2, :, :].squeeze()
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=42, stratify=y_train)

In [43]:
counter_train, counter_val, counter_test = Counter(y_train), Counter(y_val), Counter(y_test)
print(f"{counter_train = }, {counter_val = }, {counter_test = }")

counter_train = Counter({1: 2052, 0: 316}), counter_val = Counter({1: 514, 0: 79}), counter_test = Counter({1: 642, 0: 99})


In [44]:
pos_weight = torch.tensor([counter_train[0] / counter_train[1]])
print(f"{pos_weight = }")

pos_weight = tensor([0.1540])


In [45]:
BATCH_SIZE = 128
EPOCHS = 10

In [46]:
train_data, _ = to_dataloader(torch.from_numpy(X_train), torch.from_numpy(y_train), batch_size=BATCH_SIZE)
val_data, _ = to_dataloader(torch.from_numpy(X_val), torch.from_numpy(y_val), batch_size=BATCH_SIZE)
test_data, _ = to_dataloader(torch.from_numpy(X_test), torch.from_numpy(y_test), batch_size=BATCH_SIZE)

In [121]:
hallu_cls = MLP(
    input_dim=X_train.shape[1], 
    hidden_dim=128, 
    output_dim=1
).to(DEVICE)

optimiser = torch.optim.Adam(hallu_cls.parameters(), lr=1e-5, weight_decay=0.003)
loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight.to(DEVICE))

In [122]:
metrics = fit(
    model=hallu_cls, 
    optimiser=optimiser, 
    loss_fn=loss_fn,
    train_dl=train_data, 
    val_dl=val_data, 
    epochs=EPOCHS
)

Epoch 0: train loss = 0.001 (acc: 16.811), validation loss = 0.001 (acc: 16.180)
Epoch 1: train loss = 0.001 (acc: 16.811), validation loss = 0.001 (acc: 15.943)
Epoch 2: train loss = 0.001 (acc: 16.811), validation loss = 0.001 (acc: 16.260)
Epoch 3: train loss = 0.001 (acc: 16.784), validation loss = 0.001 (acc: 16.339)
Epoch 4: train loss = 0.001 (acc: 16.730), validation loss = 0.001 (acc: 16.260)
Epoch 5: train loss = 0.001 (acc: 16.919), validation loss = 0.001 (acc: 15.943)
Epoch 6: train loss = 0.001 (acc: 16.811), validation loss = 0.001 (acc: 16.339)
Epoch 7: train loss = 0.001 (acc: 16.865), validation loss = 0.001 (acc: 16.260)
Epoch 8: train loss = 0.001 (acc: 16.811), validation loss = 0.001 (acc: 16.180)
Epoch 9: train loss = 0.001 (acc: 16.946), validation loss = 0.001 (acc: 16.339)


In [123]:
for n, x, y in zip(['train', 'val', 'test'], [X_train, X_val, X_test], [y_train, y_val, y_test]):

    preds = hallu_cls.forward(torch.from_numpy(x).to(DEVICE), mode='eval').squeeze()
    preds_probes = torch.sigmoid(preds).detach().cpu().numpy()
    preds = np.where(preds_probes > 0.5, 1, 0)
    print(f"{n} AUC: {roc_auc_score(y, preds_probes)}")


train AUC: 0.8638469415451426
val AUC: 0.8034773186228636
test AUC: 0.7513294943201485


In [124]:
torch.save(hallu_cls.state_dict(), os.path.join(DATA_DIR, 'hallu_cls.pth'))

In [125]:
model = MLP(
    input_dim=X_train.shape[1], 
    hidden_dim=128, 
    output_dim=1
).to(DEVICE)

In [126]:
model.load_state_dict(torch.load(os.path.join(DATA_DIR, 'hallu_cls.pth')))

<All keys matched successfully>

In [138]:
for n, x, y in zip(['train', 'val', 'test'], [X_train, X_val, X_test], [y_train, y_val, y_test]):

    preds = model.forward(torch.from_numpy(x).to(DEVICE), mode='eval').squeeze()
    preds_probes = torch.sigmoid(preds).detach().cpu().numpy()
    preds = np.where(preds_probes > 0.5, 1, 0)
    print(f"{n} AUC: {roc_auc_score(y, preds_probes)}")


train AUC: 0.8614904878229328
val AUC: 0.7997340294537754
test AUC: 0.748623304698071
