In [1]:
import time
import os
import ftfy
from tqdm import tqdm

import numpy as np
import pandas as pd

from dataclasses import dataclass
from concurrent.futures import ThreadPoolExecutor

import torch
import sklearn

from transformers import AutoModelForSequenceClassification, AutoTokenizer, BitsAndBytesConfig, PreTrainedTokenizerBase
from transformers.data.data_collator import pad_without_fast_tokenizer_warning
from peft import PeftModel

In [2]:
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

USE_LEFTSIDE_TRUNCATION = True

In [3]:
@dataclass
class Config:
    gemma_dir = 'TOP5-MODEL/QUANTIZED-TOP5-LMSYS-8BIT'
    lora_dir = 'output-4-8BIT-TOP5-LMSYS-MODEL-99.9PERCENT-CUSTOM-HEAD-LEFTSIDE-NO-EXTRA-DATA-MAXLEN2048-R64-A4-BF16/checkpoint-3024'
    head_dropout = 0.1
    hdim = 3584
    num_labels = 2
    device = torch.device('cuda')    
    max_length = 2048
    batch_size = 32
    tta = True
    
cfg = Config()

In [4]:
test = pd.read_parquet('data/for-pseudolabeling/wsdm_for_pseudolabel.parquet') # data/for-pseudolabeling/hf-open-models-v1.parquet

In [5]:
orpo = pd.read_parquet('data/orpo-dpo-44k-for-wsdm.parquet')

In [6]:
mask = ~test['prompt'].isin(orpo['prompt'])
test = test[mask]

In [7]:
test = test[test['response_a'] != '']
test = test[test['response_b'] != '']

In [8]:
test = test[test['response_a'] != ' ']
test = test[test['response_b'] != ' ']

In [9]:
test = test[test['response_b'] != '\n']
test = test[test['response_b'] != '\n']

In [10]:
def process_text(text: str) -> str:
    return ftfy.fix_text(text)

In [11]:
def tokenize(
    tokenizer, prompt, response_a, response_b, max_length=cfg.max_length
):
    prompt = ['<prompt>: ' + process_text(t) for t in prompt]
    response_a = ['\n\n<response_a>: ' + process_text(t) for t in response_a]
    response_b = ['\n\n<response_b>: ' + process_text(t) for t in response_b]
    
    texts = [p + r_a + r_b for p, r_a, r_b in zip(prompt, response_a, response_b)]
    tokenized = tokenizer(texts, max_length=max_length, truncation=True)
    
    input_ids = tokenized.input_ids
    attention_mask = tokenized.attention_mask
        
    return input_ids, attention_mask

In [12]:
tokenizer = AutoTokenizer.from_pretrained(cfg.gemma_dir)
tokenizer.add_eos_token = True
if USE_LEFTSIDE_TRUNCATION:
    tokenizer.truncation_side = 'left'
    tokenizer.padding_side = 'left'

In [13]:
%%time

data = pd.DataFrame()
#data["id"] = test["id"]
data["input_ids"], data["attention_mask"] = tokenize(tokenizer, test["prompt"], test["response_a"], test["response_b"])
data["length"] = data["input_ids"].apply(len)

aug_data = pd.DataFrame()
#aug_data["id"] = test["id"]
aug_data['input_ids'], aug_data['attention_mask'] = tokenize(tokenizer, test["prompt"], test["response_b"], test["response_a"])
aug_data["length"] = aug_data["input_ids"].apply(len)

CPU times: user 27min 18s, sys: 49.7 s, total: 28min 8s
Wall time: 13min 33s


In [14]:
print(tokenizer.decode(data["input_ids"][0]))

<bos><prompt>: Given a list of numbers, sort the list. However, instead of using traditional sorting methods, implement a heap sort algorithm to sort the list in ascending order. Do not use any in-built or library functions for sorting.

List : [7, 3, 5, 6, 2]

<response_a>:   Sure, I can help you with that! Here's an example of how to implement a heap sort algorithm to sort the given list in ascending order:

First, let's define a function to swap two elements in the list:
```
def swap(arr, i, j):
    temp = arr[i]
    arr[i] = arr[j]
    arr[j] = temp
```
Now, let's define a function to heapify the list:
```
def heapify(arr, n):
    # Last index of the list
    last_idx = n - 1

    # Loop from the last index to the second last index
    for i in range(last_idx, 0, -1):
        # largest element in the heap
        largest = i

        # Compare the largest element with its children
        for j in range(2 * i + 1, n):
            if arr[j] > arr[largest]:
                largest = 

In [15]:
print(tokenizer.decode(aug_data["input_ids"][0]))

<bos><prompt>: Given a list of numbers, sort the list. However, instead of using traditional sorting methods, implement a heap sort algorithm to sort the list in ascending order. Do not use any in-built or library functions for sorting.

List : [7, 3, 5, 6, 2]

<response_a>: 
Here is a Python solution using heap sort algorithm:

```python
def heapify(arr, n, i):
    largest = i  # Initialize largest as root
    l = 2 * i + 1     # left = 2*i + 1
    r = 2 * i + 2     # right = 2*i + 2
    # See if left child of root exists and is greater than root
    if l < n and arr[i] < arr[l]:
        largest = l
    # See if right child of root exists and is greater than root
    if r < n and arr[largest] < arr[r]:
        largest = r
    # Change root, if needed
    if largest!= i:
        arr[i], arr[largest] = arr[largest], arr[i]  # swap
        # Heapify the root.
        heapify(arr, n, largest)

def heap_sort(arr):
    n = len(arr)
    # Build a maxheap.
    for i in range(n // 2 - 1, -1, -

In [16]:
model = AutoModelForSequenceClassification.from_pretrained(
    cfg.gemma_dir,
    torch_dtype=torch.float16,
    num_labels=cfg.num_labels,
    device_map='auto',
    use_cache=False
)
model.score = torch.nn.Sequential(
    torch.nn.Dropout(cfg.head_dropout),
    torch.nn.Linear(cfg.hdim, cfg.hdim // 2),
    torch.nn.Dropout(cfg.head_dropout),
    torch.nn.GELU(),
    torch.nn.Linear(cfg.hdim // 2, cfg.num_labels),
).to('cuda:0')

Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [17]:
model = PeftModel.from_pretrained(model, cfg.lora_dir)

In [18]:
@torch.no_grad()
@torch.cuda.amp.autocast()
def inference(df, model, device, batch_size=cfg.batch_size, max_length=cfg.max_length):
    a_win, b_win = [], []
    
    for start_idx in tqdm(range(0, len(df), batch_size)):
        end_idx = min(start_idx + batch_size, len(df))
        tmp = df.iloc[start_idx:end_idx]
        input_ids = tmp['input_ids'].to_list()
        attention_mask = tmp['attention_mask'].to_list()
        inputs = pad_without_fast_tokenizer_warning(
            tokenizer,
            {'input_ids': input_ids, 'attention_mask': attention_mask},
            padding='longest',
            pad_to_multiple_of=None,
            return_tensors='pt',
        )
        outputs = model(**inputs.to(device))
        proba = outputs.logits.softmax(-1).cpu()
        
        a_win.extend(proba[:, 0].tolist())
        b_win.extend(proba[:, 1].tolist())
    
    df['winner_model_a'] = a_win
    df['winner_model_b'] = b_win
    
    return df

  @torch.cuda.amp.autocast()


In [19]:
st = time.time()

data = data.sort_values('length', ascending=False)
result_df = inference(data, model, 'cuda:0')
proba = result_df[['winner_model_a', 'winner_model_b']].values

print(f'elapsed time: {time.time() - st}')

 71%|███████   | 12513/17702 [8:44:24<1:08:22,  1.26it/s]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



In [20]:
proba

array([[0.03584693, 0.96415305],
       [0.77210134, 0.22789866],
       [0.08889936, 0.91110069],
       ...,
       [0.37307605, 0.62692398],
       [0.20577459, 0.79422539],
       [0.49821472, 0.50178528]])

In [21]:
st = time.time()

if cfg.tta:
    aug_data = aug_data.sort_values('length', ascending=False)
    tta_result_df = inference(aug_data, model, 'cuda:0')
    tta_proba = tta_result_df[['winner_model_b', 'winner_model_a']].values

print(f'elapsed time: {time.time() - st}')

 17%|█▋        | 2999/17702 [4:03:13<12:44:07,  3.12s/it]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

100%|██████████| 17702/17702 [9:22:51<00:00,  1.91s/it]


elapsed time: 33772.72440576553


In [22]:
tta_proba

array([[0.00454726, 0.99545276],
       [0.75976056, 0.24023944],
       [0.07613248, 0.92386752],
       ...,
       [0.45786574, 0.54213423],
       [0.63411862, 0.36588141],
       [0.61172998, 0.38827005]])

In [23]:
proba = (proba + tta_proba) / 2

In [24]:
proba

array([[0.0201971 , 0.97980291],
       [0.76593095, 0.23406905],
       [0.08251592, 0.9174841 ],
       ...,
       [0.4154709 , 0.5845291 ],
       [0.4199466 , 0.5800534 ],
       [0.55497235, 0.44502766]])

In [25]:
test['winner'] = proba.tolist()

In [26]:
test = test[['prompt', 'response_a', 'response_b', 'winner']]

In [27]:
test.to_parquet('data/wsdm-pseudolabeled/huuuuuge_pseudolabeled_df.parquet', index=False) # data/wsdm-pseudolabeled/8k_pseudolabeled.parquet

In [28]:
test

Unnamed: 0,prompt,response_a,response_b,winner
0,"Given a list of numbers, sort the list. Howeve...","Sure, I can help you with that! Here's an ex...",\nHere is a Python solution using heap sort al...,"[0.020197095349431038, 0.9798029065132141]"
1,Write Python code to solve the task:\nMedicine...,Step 1: The problem requires us to find the m...,Step 1: We need to determine the minimum cost...,"[0.7659309506416321, 0.23406904935836792]"
2,Write Python code to solve the task:\nJiro is ...,"Step 1: To solve this problem, we need to cal...",Step 1: We need to calculate the rank of each...,"[0.08251592144370079, 0.9174841046333313]"
3,Refine a Python function using PyTorch and ski...,"Sure, I'd be happy to help! Here's a refined...",\nBelow is a simplified Python function that c...,"[0.8477382957935333, 0.15226170420646667]"
4,what are alternatives to the egg drop challang...,Here are a few alternative challenges that a g...,The egg drop challenge is a classic engineerin...,"[0.5739484578371048, 0.4260515198111534]"
...,...,...,...,...
585442,You will be given a definition of a task first...,The object of the main clause in the given sen...,Here's my prediction: The 'chamber's heavy bla...,"[0.3252182975411415, 0.6747817397117615]"
585445,"In addition to the given prompt, what are some...",In enhancing your Java performance review temp...,I'm not capable of modifying code directly. ho...,"[0.43432340025901794, 0.5656766295433044]"
585447,Describe three different approaches to estimat...,Estimating the effectiveness of an organizatio...,Three different approaches to estimating the e...,"[0.4154708981513977, 0.5845291018486023]"
585448,How can I develop a code in Mark down that per...,Markdown is a lightweight markup language desi...,To develop code in Markdown that performs mult...,"[0.4199466034770012, 0.5800534039735794]"
