In [10]:
! python -m pip install pandas lm-format-enforcer bitsandbytes



In [1]:
import pandas as pd
from lmformatenforcer.integrations.transformers import build_transformers_prefix_allowed_tokens_fn
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline,
    BitsAndBytesConfig,
    GenerationConfig,
)
import torch
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from lmformatenforcer import RegexParser

# The classifier must answer with exactly one character: '1' or '2'.
# The regex parser is what later restricts generation to those two digits.
LABEL_PATTERN = "[12]"
parser = RegexParser(LABEL_PATTERN)

In [3]:
# Checkpoint to evaluate. Exactly one assignment is active; the commented-out
# lines are alternative checkpoints that can be swapped in.
#model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
# NOTE(review): the base (non-instruct) Llama checkpoint below has no entry in
# PROMPT_FORMATTERS, so selecting it makes the formatter lookup raise KeyError.
#model_id = "meta-llama/Meta-Llama-3.1-8B"
#model_id = "google/gemma-2-2b-it"
model_id = "microsoft/Phi-3-mini-128k-instruct"

In [4]:
# Aggressive 4-bit NF4 quantization (with double quantization) to keep
# inference fast and the memory footprint small.
# NOTE(review): the 4-bit compute dtype is float16 while the model below is
# loaded with torch_dtype=bfloat16 -- confirm the mismatch is intentional.
# NOTE(review): llm_int8_enable_fp32_cpu_offload presumably has no effect when
# the whole model is placed on "cuda" -- confirm before relying on it.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
    llm_int8_enable_fp32_cpu_offload=True,
)

# trust_remote_code=True lets the checkpoint's own Python code run locally;
# only use it with checkpoints you trust.
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="cuda",
    trust_remote_code=True,
)

Downloading shards: 100%|██████████| 2/2 [00:44<00:00, 22.43s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [00:06<00:00,  3.01s/it]


In [5]:
# Token filter that only lets the model emit text accepted by `parser`.
prefix_function = build_transformers_prefix_allowed_tokens_fn(tokenizer, parser)


def classify(input_s: str) -> int:
    """Run one constrained generation step and return the model's label.

    Args:
        input_s: A fully formatted prompt (chat markup included).

    Returns:
        The integer the model answered with (1 or 2, as constrained by
        ``parser``).

    Raises:
        ValueError: If the newly generated text is not '1' or '2'.
    """
    # Some prompt formatters already start with the tokenizer's BOS marker.
    # Letting the tokenizer add its special tokens on top of that would feed
    # the model two BOS tokens, so only add them when the prompt has none.
    bos_token = tokenizer.bos_token
    prompt_has_bos = bool(bos_token) and input_s.startswith(bos_token)
    inputs = tokenizer(
        input_s,
        return_tensors="pt",
        add_special_tokens=not prompt_has_bos,
    ).to(model.device)

    outputs = model.generate(
        **inputs,
        # A single new token is enough: the answer is one digit.
        generation_config=GenerationConfig(max_new_tokens=1),
        prefix_allowed_tokens_fn=prefix_function,
    )

    # Decode only what was generated instead of re-decoding the whole prompt
    # and trusting that its last character is the answer.
    prompt_length = inputs["input_ids"].shape[1]
    answer = tokenizer.decode(
        outputs[0, prompt_length:], skip_special_tokens=True
    ).strip()
    if answer not in ("1", "2"):
        raise ValueError(f"expected the model to answer '1' or '2', got {answer!r}")
    return int(answer)

In [6]:
def fmt_prompt_llama(contents: str) -> str:
    """Render a Llama 3 instruct prompt asking whether an email is spam.

    The layout is: begin-of-text marker, then a system turn with the
    instructions, a user turn with the email, and an opened assistant turn.
    Each role header is followed by a blank line and each completed turn is
    closed with ``<|eot_id|>``.

    Args:
        contents: Raw text of the email to classify.

    Returns:
        The prompt string, ending right after the assistant header so the next
        generated token is the answer ('1' = spam, '2' = not spam).
    """
    # Join the sentences with a space; concatenating them directly produced
    # "detector.Given ..." in the rendered prompt.
    system_prompt = " ".join([
        "You are a helpful spam email detector.",
        "Given an email, respond with the number '1' if the email is certainly spam, otherwise the number '2'.",
    ])
    # NOTE(review): the prompt already carries <|begin_of_text|>; presumably
    # the tokenizer should not add a second BOS on top -- verify at call site.
    return "".join([
        "<|begin_of_text|>",
        f"<|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|>",
        f"<|start_header_id|>user<|end_header_id|>\n\nEMAIL: {contents}<|eot_id|>",
        "<|start_header_id|>assistant<|end_header_id|>\n\n",
    ])


In [9]:
def fmt_prompt_gemma(contents: str) -> str:
    """Render a Gemma instruct prompt asking whether an email is spam.

    The prompt has no separate system turn: the instructions and the email are
    sent together in a single user turn, followed by an opened model turn.

    Args:
        contents: Raw text of the email to classify.

    Returns:
        The prompt string, ending with the newline that follows the model-turn
        header so the next generated token is the answer ('1' = spam,
        '2' = not spam).
    """
    user_turn = " ".join([
        "You are a helpful spam email detector.",
        "Given an email, respond with the number '1' if the email is likely spam, otherwise the number '2'.",
        f"EMAIL: {contents}<end_of_turn>",
    ])
    # The generation prompt ends with a newline after "<start_of_turn>model";
    # without it the model's first token would normally be that newline, which
    # the digit-only constraint forbids.
    return "\n".join([
        "<bos><start_of_turn>user",
        user_turn,
        "<start_of_turn>model",
        "",
    ])

In [7]:
def fmt_prompt_phi(contents: str) -> str:
    """Render a Phi-3 instruct prompt asking whether an email is spam.

    The layout is a system turn with the instructions, a user turn with the
    email, and an opened assistant turn; completed turns end with ``<|end|>``.

    Args:
        contents: Raw text of the email to classify.

    Returns:
        The prompt string, ending with the newline that follows
        ``<|assistant|>`` so the next generated token is the answer
        ('1' = trying to sell something, '2' = otherwise).
    """
    turns = [
        "<|system|>",
        "You are a helpful spam email detector.  Given an email, respond with the number '1' if it is trying to sell something, otherwise the number '2'.<|end|>",
        "<|user|>",
        f"Classify the following email: {contents}<|end|>",
        "<|assistant|>",
    ]
    # Every role tag, including the final one, is followed by a newline.
    return "\n".join(turns) + "\n"

In [10]:
# Registry: checkpoint name -> function that renders its chat prompt.
PROMPT_FORMATTERS = {
    "google/gemma-2-2b-it": fmt_prompt_gemma,
    "microsoft/Phi-3-mini-128k-instruct": fmt_prompt_phi,
    "meta-llama/Meta-Llama-3.1-8B-Instruct": fmt_prompt_llama,
}

df = pd.read_csv("./emails.csv")

# Small balanced sanity-check sample: the first 10 spam rows followed by the
# first 10 non-spam rows.
first_spam = df.loc[df["spam"] == 1].head(10)
first_ham = df.loc[df["spam"] == 0].head(10)
manual_init = pd.concat([first_spam, first_ham])

prompt_formatter = PROMPT_FORMATTERS[model_id]

# The model answers 1 or 2; `% 2` folds that onto the dataset's encoding
# (1 -> 1, 2 -> 0) so it can be compared with the `spam` column.
manual_init["bootstrap_spam"] = [
    classify(prompt_formatter(email_text)) % 2
    for email_text in manual_init["text"]
]
manual_init

The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` model input instead.


Unnamed: 0,text,spam,bootstrap_spam
0,Subject: naturally irresistible your corporate...,1,1
1,Subject: the stock trading gunslinger fanny i...,1,0
2,Subject: unbelievable new homes made easy im ...,1,1
3,Subject: 4 color printing special request add...,1,1
4,"Subject: do not have money , get software cds ...",1,1
5,"Subject: great nnews hello , welcome to medzo...",1,1
6,Subject: here ' s a hot play in motion homela...,1,0
7,Subject: save your money buy getting this thin...,1,1
8,Subject: undeliverable : home based business f...,1,1
9,Subject: save your money buy getting this thin...,1,1


In [11]:
# Fraction of the sanity-check sample where the model agrees with the label.
sample_matches = manual_init["spam"] == manual_init["bootstrap_spam"]
int(sample_matches.sum()) / len(manual_init)

0.9

In [32]:
# Hand-recorded agreement scores per model family. The phi value matches the
# 0.9 shown for the 20-email sample; presumably the other two were noted the
# same way from earlier runs -- TODO confirm.
SCORES = dict(
    gemma=0.6,
    phi=0.9,
    llama=0.65,
)

Now we'll create an initial bootstrapped labeling of the entire dataset using the model selected above.

In [13]:
# Label every email in the dataset with the selected model.
df = pd.read_csv("./emails.csv")

# Let tqdm wrap the column directly: it takes the total from the Series
# length, advances once per email, and closes the bar when the loop ends.
# The model answers 1 or 2; `% 2` folds that onto the dataset's encoding
# (1 -> 1, 2 -> 0) so it can be compared with the `spam` column.
bootstrap_spam = [
    classify(prompt_formatter(email_text)) % 2
    for email_text in tqdm(df["text"])
]
df["bootstrap_spam"] = bootstrap_spam

100%|██████████| 5728/5728 [13:27<00:00,  6.80it/s]

In [14]:
# Persist the fully labelled dataset; the row index is not written.
df.to_csv(path_or_buf="emails_bootstrapped.csv", index=False)

In [15]:
# Export a balanced sample for manual inspection: the first 50 spam rows
# followed by the first 50 non-spam rows. The row index is written too.
spam_sample = df.loc[df["spam"] == 1].head(50)
ham_sample = df.loc[df["spam"] == 0].head(50)
pd.concat([spam_sample, ham_sample]).to_csv("emails_bootstrapped_mixed_sample.csv")

In [17]:
# Fraction of the whole dataset where the model agrees with the label.
label_matches = df["spam"] == df["bootstrap_spam"]
int(label_matches.sum()) / len(df)

0.8170391061452514