- Training notebook: https://www.kaggle.com/code/ravaghi/wsdm-cup-gemma-2-9b-4-bit-qlora-training

# Imports and configs

In [None]:
!pip install accelerate peft bitsandbytes transformers trl unsloth seaborn 
!pip install --upgrade 'optree>=0.13.0'

In [2]:
from transformers import Gemma2ForSequenceClassification, GemmaTokenizerFast
from transformers.data.data_collator import pad_without_fast_tokenizer_warning
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression, Ridge
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score
from sklearn.base import clone
from concurrent.futures import ThreadPoolExecutor
from timeit import default_timer as timer
from peft import PeftModel
from tqdm import tqdm
from nltk.tokenize import sent_tokenize
from catboost import CatBoostClassifier
from lightgbm import LGBMClassifier
from scipy.special import logit
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import warnings
import joblib
import torch
import json
import gc
import os

warnings.filterwarnings('ignore')

In [3]:
import kagglehub

# Base model files; the returned value is used below as a local directory prefix.
gemma_2_9b_4bit_it_unsloth_transformers_default_1_path = kagglehub.model_download(
    'leimeng46/gemma-2-9b-4bit-it-unsloth/Transformers/default/1'
)
# Dataset whose path is used as the prefix of CFG.lora_dir.
wsdm_cup_gemma_2_9b_4_bit_qlora_path = kagglehub.dataset_download(
    'ravaghi/wsdm-cup-gemma-2-9b-4-bit-qlora'
)

In [4]:
class CFG:
    """Notebook-wide settings: data locations, model directories and hyper-parameters."""

    # --- data locations ---
    # NOTE: test_path points at the same parquet file as train_path; the
    # held-out frame is carved out of it with train_test_split further down.
    train_path = '../data/train.parquet'
    test_path = '../data/train.parquet'
    sample_sub_path = '/kaggle/input/wsdm-cup-multilingual-chatbot-arena/sample_submission.csv'

    data_path = '/kaggle/input/wsdm-cup-gemma-2-9b-4-bit-qlora'

    # --- model directories (sub-folders of the kagglehub download paths) ---
    gemma_dir = gemma_2_9b_4bit_it_unsloth_transformers_default_1_path + "/gemma-2-9b-it-4bit-unsloth_old"
    lora_dir = wsdm_cup_gemma_2_9b_4_bit_qlora_path + "/gemma2-9b-4bit/gemma-2-9b-it-bnb-4bit-3072-8/checkpoint-2900"

    # --- tokenization / batching ---
    max_length = 3072
    batch_size = 4

    # --- modelling ---
    target = 'winner'
    n_folds = 5
    seed = 42

    # Keyword arguments for a character n-gram vectorizer.
    char_vectorizer_params = dict(
        analyzer="char",
        lowercase=False,
        max_df=0.605,
        max_features=331,
        min_df=0.075,
        ngram_range=(1, 3),
        strip_accents="unicode",
    )

    # Keyword arguments for a word n-gram vectorizer.
    word_vectorizer_params = dict(
        analyzer="word",
        lowercase=True,
        max_df=0.985,
        max_features=769,
        min_df=0.01,
        ngram_range=(1, 2),
        strip_accents="unicode",
    )

# Gemma-2 9b 4-bit

In [5]:
from sklearn.model_selection import train_test_split

In [6]:
# Read the parquet at CFG.test_path and replace missing values with ''.
test = pd.read_parquet(CFG.test_path).fillna('')
# First split: 80% goes to `train`, the remaining 20% stays in `test`.
train, test = train_test_split(test, random_state=1, test_size=0.2)
# Second split: that 20% is halved into `val` and the final `test` frame.
val, test = train_test_split(test, random_state=1, test_size=0.5)
# `val` is appended back onto `train`, leaving `test` as the only held-out part.
train = pd.concat([train, val])

In [7]:
# Time budget in seconds: 12 hours when `test` has more than 10,000 rows, 4.75 hours otherwise.
time_limit = int(3600 * 12) if len(test) > 10_000 else int(3600 * 4.75)

## Tokenizing

In [None]:
def tokenize(tokenizer, prompt, response_a, response_b, max_length=CFG.max_length):
    """Build one text per (prompt, response_a, response_b) triple and tokenize it.

    Each text is "<prompt>: ..." followed by "\\n\\n<response_a>: ..." and
    "\\n\\n<response_b>: ...". The texts are passed to `tokenizer` with
    truncation to `max_length`.

    Returns:
        Tuple (input_ids, attention_mask) taken from the tokenizer output.
    """
    # Prefix every element of each column with its tag.
    tagged_columns = [
        [tag + value for value in column]
        for tag, column in (
            ("<prompt>: ", prompt),
            ("\n\n<response_a>: ", response_a),
            ("\n\n<response_b>: ", response_b),
        )
    ]
    # Stitch the three tagged parts of each row into a single string.
    texts = ["".join(parts) for parts in zip(*tagged_columns)]

    encoded = tokenizer(texts, truncation=True, max_length=max_length)
    return encoded['input_ids'], encoded['attention_mask']

In [9]:
CFG.gemma_dir

'/home/ya.pristalov/.cache/kagglehub/models/leimeng46/gemma-2-9b-4bit-it-unsloth/Transformers/default/1/gemma-2-9b-it-4bit-unsloth_old'

In [10]:
# Load the fast Gemma tokenizer from the base-model directory.
tokenizer = GemmaTokenizerFast.from_pretrained(CFG.gemma_dir)

# Tokenizer options used for the rest of the notebook: right-side padding
# and the add_eos_token flag switched on.
tokenizer.padding_side = "right"
tokenizer.add_eos_token = True

In [11]:
# Shorten over-long texts by keeping a head and a tail and dropping the middle.
# For every column: texts whose token count exceeds `max_no` keep the characters
# up to the end of token `s_no` and from the start of token `e_no` onwards,
# joined by a "\n(snip)\n" marker.
for col in ['prompt', 'response_a', 'response_b']:
    test[col] = test[col].fillna('')
    # (token threshold, last head token index, first tail token index)
    max_no, s_no, e_no = (512, 255, -256) if col == "prompt" else (3072, 1535, -1536)

    text_list = []
    for text in tqdm(test[col]):
        encoded = tokenizer(text, return_offsets_mapping=True)
        if len(encoded['input_ids']) > max_no:
            offsets = encoded['offset_mapping']
            head_end = offsets[s_no][1]      # character offset where the kept head stops
            tail_start = offsets[e_no][0]    # character offset where the kept tail begins
            text = text[:head_end] + "\n(snip)\n" + text[tail_start:]
        text_list.append(text)
    test[col] = text_list

100%|██████████| 4844/4844 [00:01<00:00, 3290.47it/s]
100%|██████████| 4844/4844 [00:02<00:00, 1687.36it/s]
100%|██████████| 4844/4844 [00:02<00:00, 1660.52it/s]


In [12]:
# Tokenized inputs in the original order: prompt, response_a, response_b.
data = pd.DataFrame({"id": test["id"]})
ids_ab, mask_ab = tokenize(tokenizer, test["prompt"], test["response_a"], test["response_b"])
data["input_ids"] = ids_ab
data["attention_mask"] = mask_ab
data["length"] = data["input_ids"].map(len)

# Augmented copy with response_a and response_b swapped.
aug_data = pd.DataFrame({"id": test["id"]})
ids_ba, mask_ba = tokenize(tokenizer, test["prompt"], test["response_b"], test["response_a"])
aug_data['input_ids'] = ids_ba
aug_data['attention_mask'] = mask_ba
aug_data["length"] = aug_data["input_ids"].map(len)

In [28]:
test

Unnamed: 0,id,prompt,response_a,response_b,winner,model_a,model_b,language
7419,271ba404fc25609b8ceb6f37c3375f278c42e2479514ea...,Улучшь сообщение: Добрый вечер! \nЭто Алексей ...,"Добрый вечер!\n\nМеня зовут Алексей Королев, я...","Добрый вечер!\n\nЯ, Алексей Королев, заместите...",model_b,qwen-max-0919,c4ai-aya-expanse-32b,Russian
39112,ce7be103ac8f040accc9b8ec421a6cbf78189c503671df...,i have a notification icon in my webpage. when...,Certainly! To implement a notification system ...,Here's a simple example of how you can achieve...,model_b,qwen2.5-72b-instruct,llama-3.1-405b-instruct-bf16,English
33946,b39725e0a7992eb0c5f34542f960727a4fd306741568fb...,I have a raspberry pi that is booted from USB....,**Hotplugging the SD Card on a Raspberry Pi**\...,"No, you generally **cannot** safely hotplug th...",model_a,llama-3.1-8b-instruct,gemini-1.5-pro-exp-0827,English
35392,bb33f84e8e85da9e41fa271f6210fdffba0400d725a04c...,"Imagine video on YouTube ""Asking people what i...",Here are some imagined YouTube comments relati...,Here are some comments that might appear under...,model_a,gemini-1.5-flash-002,gemma-2-27b-it,English
10254,364c8146a6a236fbdfa34492dba10405a0e448d3062316...,generate a wine description,Please provide some context! To generate a co...,**Wine Description:**\n\n**Name:** Sunset on t...,model_b,gemini-1.5-flash-8b-001,llama-3.1-70b-instruct,English
...,...,...,...,...,...,...,...,...
26506,8bd008cebcfd4a9e504f0dbca5a5664d412ca9b4c4045f...,"""La existencia de diferencias con respecto a c...",La presencia de variaciones en los espectros d...,La presencia de diferencias según el gas utili...,model_a,gpt-4-1106-preview,gpt-4o-2024-08-06,Spanish
4869,198aa3635a6fc9f6a7d069e82cc9cddcbc783e3a0e83e8...,Thế nào là tham nhũng?\n\n Tham nhũng là hà...,"Theo Khoản 1, Điều 353 Bộ luật Hình sự năm 201...",**Câu 1: Thế nào là tham nhũng?**\n\n**Đáp án ...,model_b,claude-3-opus-20240229,o1-mini,Vietnamese
16686,584414c294899b255a9576ade65032610ca1937cffcad4...,Which number is greater: 9.11 or 9.9?\nIs 9.11...,Let's compare the two numbers step by step:\n\...,9.11 is greater than 9.9.,model_a,grok-2-2024-08-13,deepseek-v2.5,English
28451,95fb0a8dabf14619a9564334dc390cbf170a0e09905086...,```\nto order the special call commission for ...,This is a transcript of a Kentucky Fish and Wi...,This transcript details a special call commiss...,model_b,gemini-1.5-pro-002,gemini-1.5-flash-8b-001,English


## Model

In [13]:
CFG.gemma_dir

'/home/ya.pristalov/.cache/kagglehub/models/leimeng46/gemma-2-9b-4bit-it-unsloth/Transformers/default/1/gemma-2-9b-it-4bit-unsloth_old'

In [14]:
# Load the base checkpoint from CFG.gemma_dir as a sequence-classification model,
# placed on the CUDA device, with use_cache disabled.
# NOTE: the recorded output below reports that 'score.weight' is not present in
# this checkpoint and is newly initialized at load time.
model = Gemma2ForSequenceClassification.from_pretrained(
    CFG.gemma_dir,
    device_map=torch.device("cuda"),
    use_cache=False,
)

Some weights of Gemma2ForSequenceClassification were not initialized from the model checkpoint at /home/ya.pristalov/.cache/kagglehub/models/leimeng46/gemma-2-9b-4bit-it-unsloth/Transformers/default/1/gemma-2-9b-it-4bit-unsloth_old and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [15]:
# Wrap the base model in a PeftModel loaded from CFG.lora_dir (the checkpoint-2900 directory).
model = PeftModel.from_pretrained(model, CFG.lora_dir)

In [16]:
# Put the model in evaluation mode before running inference.
model.eval()

# Replace the `score` module with an identity mapping. The inference cell reads
# `outputs.logits` and stores it as embeddings, so presumably this makes the
# model emit the head's input features instead of class scores - TODO confirm.
# NOTE(review): this discards whatever `score` module was there before.
model.base_model.model.score = torch.nn.Identity()

## Inference

In [18]:
@torch.no_grad()
@torch.autocast(device_type="cuda")  # replaces the deprecated torch.cuda.amp.autocast()
def inference(df, model, device, batch_size, max_length=CFG.max_length):
    """Run `model` over `df` in batches and collect `outputs.logits` per row.

    Args:
        df: DataFrame with "input_ids" and "attention_mask" columns holding
            per-row token lists. Rows are processed in the order given.
        model: Callable model; invoked as `model(**inputs)` and expected to
            return an object with a `.logits` tensor.
        device: Device the padded batch is moved to before the forward pass.
        batch_size: Number of rows per forward pass.
        max_length: Unused inside this function; kept so existing callers that
            pass it keep working.

    Returns:
        A list with one entry per row of `df` (same order), each entry being
        the row's logits converted to a plain Python list.

    Note:
        Padding uses the notebook-level `tokenizer` variable, not a parameter.
    """
    all_embeddings = []

    for start_idx in tqdm(range(0, len(df), batch_size)):
        # iloc slicing clamps at the end of the frame, so the last batch may be shorter.
        batch = df.iloc[start_idx:start_idx + batch_size]
        inputs = pad_without_fast_tokenizer_warning(
            tokenizer,
            {
                "input_ids": batch["input_ids"].to_list(),
                "attention_mask": batch["attention_mask"].to_list(),
            },
            padding="longest",  # pad only up to the longest sequence in this batch
            pad_to_multiple_of=None,
            return_tensors="pt",
        )
        outputs = model(**inputs.to(device))

        # Move results off the GPU right away and store them as Python lists.
        all_embeddings.extend(outputs.logits.cpu().tolist())

    return all_embeddings

In [19]:
# Record the start time of the run (`timer` is timeit.default_timer).
global_timer = timer()

In [20]:
# Remember each row's original position in an `index` column, then order the
# rows from longest to shortest token sequence and renumber the frame.
data = (
    data
    .assign(index=np.arange(len(data), dtype=np.int32))
    .sort_values("length", ascending=False)
    .reset_index(drop=True)
)

In [21]:
data

Unnamed: 0,id,input_ids,attention_mask,length,index
0,79c362d0af7595cb69e665027d591df7ab83fe2d7dd34e...,"[2, 235322, 39038, 78880, 3893, 105821, 1982, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",3072,3881
1,a83127929a103f8e0f35928e5902da56c2307c032378dd...,"[2, 235322, 39038, 78880, 206468, 3901, 190158...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",3072,4208
2,49b5e589b60986b9088ab224965e0a9d0fc4a228484d7d...,"[2, 235322, 39038, 78880, 1004, 68574, 1454, 3...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",3072,2184
3,159fa657ef6484164d6253ad7b4a90e5de752213032bc0...,"[2, 235322, 39038, 78880, 182483, 6520, 12788,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",3072,2216
4,e4242d5a354ccda2f1da52cc889ff5908578432aff732a...,"[2, 235322, 39038, 78880, 7717, 11809, 42765, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",3072,2224
...,...,...,...,...,...
4839,1cb8843f3e3352e005a362da220956fadf1d8bdfa46c52...,"[2, 235322, 39038, 78880, 5823, 131345, 2218, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",37,3814
4840,944a871333310cc97c688233b45c883837404577e44ac7...,"[2, 235322, 39038, 78880, 3337, 1297, 235248, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",36,2161
4841,b63122dabc0457a9c5f75541790f8627ba39b70c769229...,"[2, 235322, 39038, 78880, 123781, 109, 235322,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",36,3490
4842,1e2833f1fc663fe31efc6da3fd5626c5602d94db6e7ba2...,"[2, 235322, 39038, 78880, 3233, 109, 235322, 4...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",35,808


In [22]:
# One embedding per row of the length-sorted `data` frame, in that same order.
embeddings = inference(
    df=data,
    model=model,
    device=torch.device('cuda'),
    batch_size=CFG.batch_size,
)

100%|██████████| 1211/1211 [18:30<00:00,  1.09it/s]


In [31]:
# One row per embedding; columns are the embedding dimensions.
eval_df = pd.DataFrame(embeddings)
# `embeddings` follows the row order of `data`, so ids line up positionally.
eval_df['id'] = data['id']

# Attach the label by id, then drop the id so only features + 'winner' remain
# ('winner' ends up as the last column).
eval_df = eval_df.merge(test[['id', 'winner']], on='id').drop(columns=['id'])

# 50/50 split of features (all but the last column) and label (last column).
# random_state is fixed so the split, and the scores computed on it, are
# reproducible across runs.
train_eval_X, test_eval_X, train_eval_y, test_eval_y = train_test_split(
    eval_df.iloc[:, :-1],
    eval_df.iloc[:, -1],
    test_size=0.5,
    random_state=CFG.seed,
)

In [35]:
# Depth-1 CatBoost classifier on the embeddings; accuracy is the reported
# metric and progress is logged every 100 iterations.
clf = CatBoostClassifier(
    depth=1,
    eval_metric='Accuracy',
    verbose=100,
)

# The second half of the split is passed as the evaluation set during fitting.
clf.fit(
    train_eval_X,
    train_eval_y,
    eval_set=(test_eval_X, test_eval_y),
)

Learning rate set to 0.039428
0:	learn: 0.6787779	test: 0.6696945	best: 0.6696945 (0)	total: 4.59ms	remaining: 4.59s
100:	learn: 0.7369942	test: 0.7080925	best: 0.7089182 (97)	total: 299ms	remaining: 2.66s
200:	learn: 0.7456647	test: 0.7085054	best: 0.7113955 (118)	total: 603ms	remaining: 2.4s
300:	learn: 0.7551610	test: 0.7089182	best: 0.7113955 (118)	total: 892ms	remaining: 2.07s
400:	learn: 0.7630058	test: 0.7105698	best: 0.7113955 (118)	total: 1.17s	remaining: 1.75s
500:	learn: 0.7696119	test: 0.7101569	best: 0.7113955 (118)	total: 1.49s	remaining: 1.48s
600:	learn: 0.7762180	test: 0.7126342	best: 0.7142857 (581)	total: 1.85s	remaining: 1.23s
700:	learn: 0.7824112	test: 0.7138728	best: 0.7163501 (675)	total: 2.22s	remaining: 949ms
800:	learn: 0.7910818	test: 0.7118084	best: 0.7163501 (675)	total: 2.6s	remaining: 646ms
900:	learn: 0.7960363	test: 0.7118084	best: 0.7163501 (675)	total: 2.98s	remaining: 328ms
999:	learn: 0.8047069	test: 0.7118084	best: 0.7163501 (675)	total: 3.36s	rem

<catboost.core.CatBoostClassifier at 0x7f5440fbf0d0>