In [None]:
!pip install -qqq transformers datasets bitsandbytes accelerate scikit-learn peft trl

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GemmaTokenizer
from datasets import load_dataset
import numpy as np

import os

# Hugging Face access token, read from the environment instead of being hard-coded:
# a token committed to a notebook is a leaked credential (revoke any token that was
# previously stored here). If HF_TOKEN is unset this is None, and from_pretrained()
# falls back to any cached `huggingface-cli login` credentials.
TOKEN = os.environ.get("HF_TOKEN")
MODEL_NAME = "google/gemma-7b"
device = "cuda"

In [None]:
# Load the weights quantized to 4-bit NF4, with matmuls computed in bfloat16.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    load_in_4bit=True,
)

# Tokenizer and model are fetched with the same access token.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    token=TOKEN,
    device_map="auto",
    quantization_config=bnb_config,
)

In [None]:
from datasets import load_dataset

# Load IMDB and take a small, reproducible sample from each split.
N_SAMPLES = 300  # rows per split
SEED = 42

dataset = load_dataset('imdb')

# Shuffle BEFORE selecting: the IMDB train split appears to be ordered by label
# (negatives first), so taking its first rows unshuffled would give a
# single-class training set. min() guards against splits smaller than N_SAMPLES.
train_split = dataset['train'].shuffle(seed=SEED)
test_split = dataset['test'].shuffle(seed=SEED)
train_ds = train_split.select(range(min(N_SAMPLES, len(train_split))))
test_ds = test_split.select(range(min(N_SAMPLES, len(test_split))))


def _is_valid_example(example):
    """Return True for rows with non-empty review text and a binary 0/1 label."""
    text, label = example['text'], example['label']
    return isinstance(text, str) and bool(text.strip()) and label in (0, 1)


# Drop malformed rows (missing/empty text or an unexpected label). The columns are
# a string and an int, so a per-value NaN/Inf check does not apply here.
train_ds = train_ds.filter(_is_valid_example)
test_ds = test_ds.filter(_is_valid_example)

# Proceed with training or evaluation
print(train_ds)
print(test_ds)


In [None]:
# Vocabulary ids of candidate label words, keyed by their surface form
# (with / without a leading space, capitalised / lower-case).
# NOTE(review): these ids are hard-coded and presumably specific to the
# google/gemma-7b tokenizer; verify with tokenizer.convert_ids_to_tokens(...)
# if MODEL_NAME changes.
TOKENS = dict(
    [
        (" Positive", 40695),
        (" Negative", 48314),
        (" positive", 6222),
        (" negative", 8322),
        ("Positive", 35202),
        ("Negative", 39654),
    ]
)

def get_prompt_list(item):
    """Build supervised fine-tuning strings: the review prompt followed by its gold label word.

    Used as SFTTrainer's ``formatting_func``. Depending on the TRL version, that hook
    is called either with a batch (``item['text']`` / ``item['label']`` are lists)
    or with a single example (scalars), so both shapes are accepted.

    Args:
        item: mapping with ``'text'`` (review text) and ``'label'``
            (1 -> "Positive", anything else -> "Negative").

    Returns:
        A list of strings for a batch, or a single string for one example.
    """
    def _format(text, label):
        return get_prompt(text) + ("Positive" if label == 1 else "Negative")

    texts, labels = item['text'], item['label']
    if isinstance(texts, str):
        # Single-example call: zip() over a str and an int would raise TypeError.
        return _format(texts, labels)
    return [_format(text, label) for text, label in zip(texts, labels)]

def get_prompt(query):
    """Wrap ``query`` in the REVIEW/SENTIMENT template.

    The result ends right after the ``### SENTIMENT:`` line, so the label word
    is the next thing the model should produce.
    """
    template = "### REVIEW:\n{}\n\n### SENTIMENT:\n"
    return template.format(query)

def llm(query):
    """Classify a review and return ``(decoded_sequence, positive_prob)``.

    Generates exactly one token after the prompt and compares the first-step
    scores of the ``"Positive"`` and ``"Negative"`` label tokens.

    Args:
        query: review text to classify.

    Returns:
        A tuple of:
        - the decoded prompt plus the single generated token;
        - P("Positive") renormalised over just the two label tokens, as a float in [0, 1].
    """
    prompt = get_prompt(query)
    encoded = tokenizer(prompt, add_special_tokens=True, return_tensors="pt").to(model.device)
    outputs = model.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],  # explicit mask avoids generate()'s warning
        max_new_tokens=1,
        output_scores=True,
        return_dict_in_generate=True,
    )

    first_step_scores = outputs.scores[0][0]
    # Cast to float32 before leaving torch: numpy has no bfloat16 dtype.
    label_logits = torch.stack(
        [first_step_scores[TOKENS['Positive']], first_step_scores[TOKENS['Negative']]]
    ).float().cpu()
    # torch.softmax is numerically stable (it subtracts the max), unlike a raw
    # np.exp(x) / np.sum(np.exp(x)), which overflows to inf/nan for large logits.
    positive_prob = torch.softmax(label_logits, dim=0)[0].item()

    return tokenizer.decode(outputs.sequences[0]), positive_prob

def predict(query, print_res = False):
    """Return the model's P(positive) for ``query``.

    When ``print_res`` is true, also print the decoded prompt plus generated token.
    """
    decoded, positive_prob = llm(query)
    if not print_res:
        return positive_prob
    print(decoded)
    return positive_prob

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score

def print_metrics(y_true, y_pred):
    """Print accuracy, precision, recall and F1 for binary 0/1 predictions.

    ``zero_division=0`` makes precision/F1 report 0 instead of emitting an
    UndefinedMetricWarning when no positives are predicted, which can happen
    with the untuned model.

    Args:
        y_true: gold labels (0 = negative, 1 = positive).
        y_pred: predicted labels, already thresholded to 0/1.
    """
    from sklearn.metrics import f1_score

    metrics = {
        "Accuracy": accuracy_score(y_true, y_pred),
        "Precision": precision_score(y_true, y_pred, zero_division=0),
        "Recall": recall_score(y_true, y_pred, zero_division=0),
        "F1": f1_score(y_true, y_pred, zero_division=0),
    }
    for name, value in metrics.items():
        print(f"{name}:", value)

In [None]:
from tqdm import tqdm

# Baseline: score every test review with the base (not yet fine-tuned) model.
y_pred = []  # P(positive) per review
y_test = []  # gold labels, aligned with test_ds order
for ex in tqdm(test_ds, desc="eval (before fine-tuning)"):
    y_pred.append(predict(ex['text']))
    y_test.append(ex['label'])

print('before fine tuning')
print_metrics(y_test, np.round(y_pred))

In [None]:
import os
os.environ["WANDB_DISABLED"] = "true"
from peft import LoraConfig

# LoRA adapters of rank 8 on the seven projection modules listed below;
# lora_alpha / lora_dropout are left at PEFT's defaults.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,
    target_modules=[
        "q_proj", "o_proj", "k_proj", "v_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
)

In [None]:
import inspect

import transformers
from trl import SFTTrainer

# Mixed precision should match both the hardware and the bfloat16 compute dtype
# used for 4-bit loading: use bf16 where the GPU supports it, otherwise fall back
# to fp16.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

training_args = transformers.TrainingArguments(
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    warmup_steps=2,
    learning_rate=2e-5,
    num_train_epochs=2,
    bf16=use_bf16,
    fp16=not use_bf16,
    logging_steps=20,
    output_dir="outputs",
    optim="paged_adamw_8bit",
    report_to="none",
)

# Pass the already-loaded tokenizer explicitly, so training tokenizes exactly
# like inference and SFTTrainer does not try to fetch one itself (which would
# also need the access token). Newer TRL releases renamed the keyword from
# `tokenizer` to `processing_class`; use whichever this version accepts.
tokenizer_kwarg = (
    "processing_class"
    if "processing_class" in inspect.signature(SFTTrainer.__init__).parameters
    else "tokenizer"
)

trainer = SFTTrainer(
    model=model,
    train_dataset=train_ds,
    args=training_args,
    peft_config=lora_config,
    formatting_func=get_prompt_list,
    **{tokenizer_kwarg: tokenizer},
)

trainer.train()

In [None]:
# this is very slow if you want to run on all 25k samples :)

from tqdm import tqdm

# Re-score the same test reviews after fine-tuning.
# Trainer.train() leaves the model in training mode; switch back to eval mode so
# inference does not use training-time behaviour such as dropout.
model.eval()

y_pred = []  # P(positive) per review; reused by the error-inspection cell below
y_test = []  # gold labels, aligned with test_ds order
for ex in tqdm(test_ds, desc="eval (after fine-tuning)"):
    y_pred.append(predict(ex['text']))
    y_test.append(ex['label'])

print('after fine tuning')
print_metrics(y_test, np.round(y_pred))

In [None]:
# Show up to five test reviews the model misclassified. Each one prints the
# decoded prompt plus generated token, then P(positive) and the gold label.
cnt = 0
for i, (prob, gold) in enumerate(zip(y_pred, y_test)):
    if np.round(prob) == gold:
        continue
    example = test_ds[i]
    print(predict(example['text'], print_res=True), example['label'])
    cnt += 1
    if cnt == 5:
        break