In [None]:
import torch
from transformers import AutoModelForCausalLM, TrainingArguments, BitsAndBytesConfig
from peft import LoraConfig
from transformers import set_seed
from transformers import AutoTokenizer
from peft import prepare_model_for_kbit_training, get_peft_model
from transformers import DataCollatorForLanguageModeling
from transformers import Trainer
import os
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from peft import PeftModelForCausalLM
import torch
from datasets import DatasetDict, Dataset
from sklearn.model_selection import train_test_split
import pandas as pd
import gc
import json


In [None]:
# Hugging Face checkpoint id used for both the tokenizer and the model.
model_name = "TinyLlama/TinyLlama_v1.1"

# Fixed seed, handed to transformers' set_seed.
seed = 42
set_seed(seed)

In [None]:
def _load_base_model(bnb_config):
    """Load ``model_name``; a truthy ``bnb_config`` adds a bitsandbytes config."""
    load_kwargs = {"device_map": "auto", "revision": "main"}
    if bnb_config:
        print("QLORA")
        load_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=bnb_config["load_in_4bit"],
            bnb_4bit_use_double_quant=bnb_config["bnb_4bit_use_double_quant"],
            # "nf4" remains the default; an explicit key can override it.
            bnb_4bit_quant_type=bnb_config.get("bnb_4bit_quant_type", "nf4"),
            bnb_4bit_compute_dtype=bnb_config["bnb_4bit_compute_dtype"],
        )
    return AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)


def _attach_lora(model, lora_config, quantized):
    """Wrap ``model`` in LoRA adapters described by ``lora_config``."""
    print("LORA")
    lora = LoraConfig(
        r=lora_config["r"],
        lora_alpha=lora_config["lora_alpha"],
        init_lora_weights=True,
        lora_dropout=lora_config["lora_dropout"],
        bias="none",
        task_type="CAUSAL_LM",
    )
    if quantized:
        # The k-bit preparation step is only meant for quantized base models.
        model = prepare_model_for_kbit_training(model)
    return get_peft_model(model, lora)


def train_model(base_config, lora_config, bnb_config, data, tokenizer, collator):
    """Train the ``model_name`` checkpoint for one epoch and return the model.

    Args:
        base_config: dict with the keys ``"fp16"``, ``"weight_decay"`` and
            ``"learning_rate"``, forwarded to ``TrainingArguments``.
        lora_config: dict with ``"r"``, ``"lora_alpha"`` and
            ``"lora_dropout"``; any falsy value trains without adapters.
        bnb_config: dict with ``"load_in_4bit"``,
            ``"bnb_4bit_use_double_quant"`` and ``"bnb_4bit_compute_dtype"``
            (optionally ``"bnb_4bit_quant_type"``, default ``"nf4"``); any
            falsy value loads the model without a quantization config.
        data: training dataset handed to ``Trainer``.
        tokenizer: tokenizer handed to ``Trainer``.
        collator: data collator handed to ``Trainer``.

    Returns:
        The trained model (the PEFT-wrapped model when ``lora_config`` is set).

    Raises:
        ValueError: if 4-bit loading is requested without a LoRA config.
            ``Trainer`` rejects that combination anyway, so it is reported
            before the checkpoint is loaded.
    """
    quantized = bool(bnb_config) and bool(bnb_config["load_in_4bit"])
    if quantized and not lora_config:
        raise ValueError(
            "A 4-bit quantized model cannot be fine-tuned without adapters; "
            "pass a lora_config or disable bnb_config."
        )

    model = _load_base_model(bnb_config)
    if lora_config:
        model = _attach_lora(model, lora_config, quantized)

    args = TrainingArguments(
        output_dir=".",
        fp16=base_config["fp16"],
        weight_decay=base_config["weight_decay"],
        learning_rate=base_config["learning_rate"],
        # NOTE(review): the collator is expected to emit a "labels" key --
        # confirm that 'input_ids' is the intended label name here.
        label_names=['input_ids'],
        num_train_epochs=1,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        gradient_accumulation_steps=8,
        no_cuda=False,
        optim="paged_adamw_8bit",
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=data,
        tokenizer=tokenizer,
        data_collator=collator,
    )
    trainer.train()

    # Nothing is saved explicitly here; add trainer.save_model(<path>) if the
    # trained weights should be persisted.
    return model

In [None]:
# Tokenizer of the base checkpoint; the EOS token doubles as the pad token.
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# mlm=False: batches are collated without masked-language-modeling masking.
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

def reformat_func(example):
    """Store head + body, wrapped in ``<func>`` markers, under ``"full"``.

    The record is updated in place and the same object is returned.
    """
    pieces = ("# <func>\n", example["head"], example["body"], "\n</func>")
    example["full"] = "".join(pieces)
    return example

def tokenize_func(example):
    """Tokenize the ``"full"`` text into fixed-length NumPy arrays.

    Each sequence is padded up to 1000 tokens. ``truncation=True`` also cuts
    longer sequences down to 1000, so every row has the same length and
    ``return_tensors="np"`` can build rectangular arrays.
    """
    return tokenizer(
        example["full"],
        return_tensors="np",
        padding="max_length",
        truncation=True,
        max_length=1000,
    )

# Load one parquet chunk and keep only the rows labelled as Python.
data = Dataset.from_parquet("../data/chunks/chunk_1.parquet")
df = pd.DataFrame(data)
filtered_df = df[df["language"] == "Python"]

# Fraction of the filtered rows kept for training; the printed label below is
# derived from it so the two cannot drift apart.
sample_frac = 0.005
sampled_df = filtered_df.sample(frac=sample_frac, random_state=seed).reset_index(drop=True)

data = Dataset.from_pandas(sampled_df)

print(f"Original size after filtering: {len(filtered_df)}")
print(f"Sampled size ({sample_frac:.1%}): {len(sampled_df)}")
print(sampled_df.head())

# Build the "full" training text, then tokenize it in batches.
data = data.map(reformat_func)
tokenized_ds = data.map(tokenize_func, batched=True)



In [None]:
# QLoRA fine-tuning run. Pass False as the third argument to train without
# quantization.
# NOTE(review): r=4096 is unusually large for a LoRA rank -- confirm that this
# is intentional.
model = train_model(
    {"weight_decay": 0.1, "learning_rate": 1e-4, "fp16": False},
    {"r": 4096, "lora_alpha": 4096, "lora_dropout": 0.1},
    {"load_in_4bit": True, "bnb_4bit_use_double_quant": True, "bnb_4bit_compute_dtype": "bfloat16"},
    tokenized_ds, tokenizer, collator,
)

prompt = """from typing import List\n# <func>\n# Python\n# Check if in given list of numbers, are any two numbers closer to each other than given threshold.\n#>>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n# False\n# >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n# True\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:"""


def build_generator(lm):
    """Wrap an already-loaded causal LM in a text-generation pipeline.

    No ``device_map`` is passed: the pipeline only forwards it to
    ``from_pretrained`` when it loads the model itself, so device placement
    has to be decided when ``lm`` is loaded.
    """
    return pipeline(task="text-generation", model=lm, tokenizer=tokenizer, max_new_tokens=512)


# 1) Fine-tuned model with the adapter still attached.
gen = build_generator(model)
print(gen(prompt))

# 2) Fine-tuned model after merging the adapter into the base weights.
gen = build_generator(model.merge_and_unload())
print(gen(prompt))

# 3) Untouched base checkpoint for comparison; device_map is given at load
#    time so the placement request is actually honoured.
gen = build_generator(AutoModelForCausalLM.from_pretrained(model_name, device_map="auto"))
print(gen(prompt))