In [24]:
from utils import stories_dataset
import polars as pl

stories = stories_dataset()

small_sample = stories.sample(10000, seed=42)

small_sample

id,title,by,text,score,descendants,time,log_score,serialized,split
i64,str,str,str,i64,i64,datetime[μs],f64,str,str
34724037,"""Ask HN: How can I get into Neu…","""notmindthegap""","""Hi all – I’m in my mid-30s, we…",2,0,2023-02-09 14:05:38,0.693147,""" Ask HN: How can I get into Ne…","""test"""
22604483,"""Ask HN: Does having some stres…","""jbms""","""I&#x27;ve really enjoyed some …",1,0,2020-03-17 13:06:20,0.0,""" Ask HN: Does having some stre…","""train"""
33173205,"""Coaching for “Normals”?""","""wanderingCoder""","""I&#x27;m interested in finding…",73,61,2022-10-12 05:26:38,4.290459,""" Coaching for “Normals”? wande…","""train"""
38838197,"""Internet, Blockchain, AI, Amar…","""bernardlunn""","""Amara’s Law (coined by Roy Ama…",3,2,2024-01-02 04:42:10,1.098612,""" Internet, Blockchain, AI, Ama…","""train"""
22675416,"""Show HN: Self-Published Book –…","""anconia""","""I just self-published a book t…",24,13,2020-03-24 15:32:53,3.178054,""" Show HN: Self-Published Book …","""train"""
…,…,…,…,…,…,…,…,…,…
24970083,"""Ask HN: Would you like an Appl…","""ciccionamente""","""Today the Raspberry Pi 400 has…",5,11,2020-11-02 16:47:55,1.609438,""" Ask HN: Would you like an App…","""train"""
19907580,"""ASIC Verification Course in Ba…","""mavensilicon""","""ASIC Design and Verification c…",1,0,2019-05-14 07:46:53,0.0,""" ASIC Verification Course in B…","""train"""
34455854,"""Ask HN: Google spam filter get…","""jgwil2""","""I have noticed an uptick in un…",158,91,2023-01-20 16:50:27,5.062595,""" Ask HN: Google spam filter ge…","""test"""
11965471,"""Ask HN: How can I get iOS proj…","""selfthrow""","""I am a long time lurker here. …",1,0,2016-06-24 00:55:03,0.0,""" Ask HN: How can I get iOS pro…","""train"""


In [27]:
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from peft import PeftModel
from liger_kernel.transformers import _apply_liger_kernel_to_instance

# Configuration
base_model = "unsloth/Meta-Llama-3.1-8B"
run_name = "stories_model_v2"
output_dir = f"./models/{run_name}"
max_length = 4096

print("Loading tokenizer and model...")
tokenizer = AutoTokenizer.from_pretrained(
    output_dir,
    truncation=True,
    padding=True,
    max_length=max_length,
)

model = AutoModelForSequenceClassification.from_pretrained(
    base_model,
    num_labels=1,
    device_map="auto",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)

# Apply the Liger kernel to the model
_apply_liger_kernel_to_instance(model=model)

# Load the trained PEFT model
model = PeftModel.from_pretrained(model, output_dir)

print("Model loaded and Liger kernel applied successfully.")


Loading tokenizer and model...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at unsloth/Meta-Llama-3.1-8B and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model loaded and Liger kernel applied successfully.


In [19]:
# import torch
# from transformers import AutoModelForSequenceClassification, AutoTokenizer
# from peft import PeftModel
# from liger_kernel.transformers import _apply_liger_kernel_to_instance

# # Configuration
# base_model = "unsloth/Meta-Llama-3.1-8B"
# run_name = "stories_model_v2"
# output_dir = f"./models/{run_name}"
# max_length = 4096

# print("Loading tokenizer and model...")
# tokenizer = AutoTokenizer.from_pretrained(
#     output_dir,
#     truncation=True,
#     padding=True,
#     max_length=max_length,
# )

# model = AutoModelForSequenceClassification.from_pretrained(
#     base_model,
#     num_labels=1,
#     device_map="auto",
#     attn_implementation="flash_attention_2",
#     torch_dtype=torch.bfloat16,
# )

# model = PeftModel.from_pretrained(model, output_dir)
# model = model.merge_and_unload()

# # Apply the Liger kernel to the model
# _apply_liger_kernel_to_instance(model=model)

# # Load the trained PEFT model

# print("Model loaded, merged, and Liger kernel applied successfully.")


Loading tokenizer and model...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at unsloth/Meta-Llama-3.1-8B and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model loaded, merged, and Liger kernel applied successfully.


In [12]:
from sklearn.metrics import mean_squared_error
import numpy as np


def calculate_rmse(actual, predicted):
    return np.sqrt(mean_squared_error(actual, predicted))


Learnings:
 - Time to process 10000 stories with unmerged PEFT model: 1m 13s
 - Time to process 10000 stories with merged PEFT model: 1m 24s
 - Probably not worth merging for small sample sizes, but worth it for large batches.

In [28]:
from tqdm import tqdm


def run_inference_transformers(prompts: list[str]) -> list[float]:


rewards = run_inference_transformers(small_sample["serialized"].to_list())

small_sample = small_sample.with_columns(
    pl.Series(name="log_predicted_score", values=rewards)
)

small_sample = small_sample.with_columns(
    pl.Series(name="predicted_score", values=small_sample["log_predicted_score"].exp())
)

print(
    "RMSE",
    calculate_rmse(small_sample["log_score"], small_sample["log_predicted_score"]),
)
small_sample

100%|██████████| 2500/2500 [01:21<00:00, 30.82it/s]

RMSE 1.1261623220488448





id,title,by,text,score,descendants,time,log_score,serialized,split,log_predicted_score,predicted_score
i64,str,str,str,i64,i64,datetime[μs],f64,str,str,f64,f64
34724037,"""Ask HN: How can I get into Neu…","""notmindthegap""","""Hi all – I’m in my mid-30s, we…",2,0,2023-02-09 14:05:38,0.693147,""" Ask HN: How can I get into Ne…","""test""",0.984375,2.676139
22604483,"""Ask HN: Does having some stres…","""jbms""","""I&#x27;ve really enjoyed some …",1,0,2020-03-17 13:06:20,0.0,""" Ask HN: Does having some stre…","""train""",0.605469,1.832111
33173205,"""Coaching for “Normals”?""","""wanderingCoder""","""I&#x27;m interested in finding…",73,61,2022-10-12 05:26:38,4.290459,""" Coaching for “Normals”? wande…","""train""",0.6875,1.988737
38838197,"""Internet, Blockchain, AI, Amar…","""bernardlunn""","""Amara’s Law (coined by Roy Ama…",3,2,2024-01-02 04:42:10,1.098612,""" Internet, Blockchain, AI, Ama…","""train""",0.644531,1.905094
22675416,"""Show HN: Self-Published Book –…","""anconia""","""I just self-published a book t…",24,13,2020-03-24 15:32:53,3.178054,""" Show HN: Self-Published Book …","""train""",0.9765625,2.655313
…,…,…,…,…,…,…,…,…,…,…,…
24970083,"""Ask HN: Would you like an Appl…","""ciccionamente""","""Today the Raspberry Pi 400 has…",5,11,2020-11-02 16:47:55,1.609438,""" Ask HN: Would you like an App…","""train""",1.1015625,3.008864
19907580,"""ASIC Verification Course in Ba…","""mavensilicon""","""ASIC Design and Verification c…",1,0,2019-05-14 07:46:53,0.0,""" ASIC Verification Course in B…","""train""",0.116699,1.123781
34455854,"""Ask HN: Google spam filter get…","""jgwil2""","""I have noticed an uptick in un…",158,91,2023-01-20 16:50:27,5.062595,""" Ask HN: Google spam filter ge…","""test""",1.3203125,3.744591
11965471,"""Ask HN: How can I get iOS proj…","""selfthrow""","""I am a long time lurker here. …",1,0,2016-06-24 00:55:03,0.0,""" Ask HN: How can I get iOS pro…","""train""",0.7109375,2.035899
