<a href="https://colab.research.google.com/github/DSPagan/llms-computational-complexity/blob/main/llm_complexity_estimation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Prepare the datasets

In [None]:
user = "DSPagan"
repo = "llms-computational-complexity"
src_dir = "data"
train_file = "train_data.jsonl"
test_file = "test_data.jsonl"

url_1 = f"https://raw.githubusercontent.com/{user}/{repo}/main/{src_dir}/{train_file}"
url_2 = f"https://raw.githubusercontent.com/{user}/{repo}/main/{src_dir}/{test_file}"

!wget --no-cache --backups=1 {url_1}
!wget --no-cache --backups=1 {url_2}

# Install dependencies

In [None]:
!pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo unsloth
!pip install sentencepiece protobuf "datasets>=3.4.1" huggingface_hub hf_transfer

# Import libraries

In [None]:
import json
import os

from unsloth import FastLanguageModel, is_bfloat16_supported
from unsloth.chat_templates import get_chat_template, train_on_responses_only
from trl import SFTTrainer
from transformers import TrainingArguments, DataCollatorForSeq2Seq
from datasets import Dataset

# Load model

In [None]:
model_path = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit"
max_seq_length = 2048

# Options used to load the checkpoint and its tokenizer.
load_options = {
    "model_name": model_path,
    "max_seq_length": max_seq_length,  # maximum sequence length for the model
    "load_in_4bit": True,              # use 4-bit quantization for memory efficiency
    "dtype": None,                     # no explicit dtype requested
}
model, tokenizer = FastLanguageModel.from_pretrained(**load_options)

# Attach the Llama 3.1 chat template so conversations are rendered in the
# format this instruct model expects.
tokenizer = get_chat_template(tokenizer, chat_template = "llama-3.1")

# Fine-tuning with QLoRA

In [None]:
# Fine-tuning configuration.
train_data_path = "train_data.jsonl"
output_dir = "outputs"
num_epochs = 2
lora_r = 16
max_seq_length = 2048  # same value as used when the model was loaded

# Load the training examples: one JSON object per line (JSONL). Later code
# reads the 'src' and 'complexity' keys of every record.
# The file is read as UTF-8 regardless of the platform default, and blank
# lines (e.g. a trailing newline at end of file) are skipped instead of
# crashing json.loads.
with open(train_data_path, 'r', encoding='utf-8') as f:
    train_data = [json.loads(line) for line in f if line.strip()]

def formatting_prompts_func(examples):
    """Render one conversation into a single training string.

    Args:
        examples: Mapping whose "conversations" entry is the list of chat
            messages (dicts with "role" and "content") of ONE conversation.

    Returns:
        The conversation rendered with the tokenizer's chat template,
        without a trailing generation prompt.
    """
    # Render the whole conversation in a single call. The previous version
    # rendered each message on its own and cut the repeated header off the
    # second one by the length of a hard-coded string, which corrupts the
    # text as soon as the template's header differs by even one character
    # and fails on conversations that do not have exactly two messages.
    return tokenizer.apply_chat_template(
        examples["conversations"],
        tokenize = False,
        add_generation_prompt = False,
    )

def gen():
    """Yield one training record per entry of ``train_data``.

    Every record holds the two-message conversation (user prompt, assistant
    answer) under "conversations" and the string produced by
    ``formatting_prompts_func`` for that conversation under "text".
    """
    for code in train_data:
        # The continuation lines of this literal are part of the prompt text;
        # the inference cell builds the prompt with the same layout.
        prompt = f"""Analyze the time complexity of the following code.
    Choose exactly one of the following options: O(1), O(logn), O(n), O(nlogn), O(n^2), O(n^3) or exponential (O(2^n), O(3^n), etc.).
    Give the time complexity of the code:
    {code['src']}"""
        conversation = [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": code['complexity']},
        ]
        rendered = formatting_prompts_func({"conversations": conversation})
        yield {"conversations": conversation, "text": rendered}

# --- Dataset -----------------------------------------------------------------
# Materialise every record produced by gen() into a Hugging Face dataset.
dataset = Dataset.from_list([record for record in gen()])

# --- LoRA adapters -----------------------------------------------------------
# Adapters are attached to the attention projections (q/k/v/o) and the MLP
# projections (gate/up/down).
# NOTE(review): `model` is rebound to the wrapped model, so re-running this
# cell passes the already-wrapped model back into get_peft_model.
peft_options = dict(
    r = lora_r,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_alpha = 16,
    lora_dropout = 0,
    bias = "none",
    use_gradient_checkpointing = "unsloth",
    random_state = 3407,
    use_rslora = False,
    loftq_config = None,
)
model = FastLanguageModel.get_peft_model(model, **peft_options)

# --- Trainer -----------------------------------------------------------------
collator = DataCollatorForSeq2Seq(tokenizer = tokenizer)

# Pick bf16 when the hardware supports it, otherwise fall back to fp16.
use_bf16 = is_bfloat16_supported()
training_args = TrainingArguments(
    per_device_train_batch_size = 2,
    gradient_accumulation_steps = 4,  # 2 x 4 = 8 sequences per optimizer step per device
    warmup_steps = 5,
    num_train_epochs = num_epochs,
    learning_rate = 2e-4,
    fp16 = not use_bf16,
    bf16 = use_bf16,
    logging_steps = 1,
    optim = "adamw_8bit",
    weight_decay = 0.01,
    lr_scheduler_type = "linear",
    seed = 3407,
    output_dir = output_dir,
    report_to = "none",
)

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    dataset_text_field = "text",  # column filled in by gen()
    max_seq_length = max_seq_length,
    data_collator = collator,
    dataset_num_proc = 2,
    packing = False,
    args = training_args,
)

# Tell the trainer where user and assistant turns start in the rendered text;
# going by the helper's name, only the assistant responses are trained on.
trainer = train_on_responses_only(
    trainer,
    instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n",
    response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n",
)

# --- Train -------------------------------------------------------------------
trainer_stats = trainer.train()

# Inference

In [None]:
test_data_path = "test_data.jsonl"
save_path = "outputs/test_results.jsonl"
max_new_tokens = 1024

# Load the evaluation examples: one JSON object per line (JSONL), read as
# UTF-8; blank lines (e.g. a trailing newline) are skipped.
with open(test_data_path, 'r', encoding='utf-8') as f:
    test_data = [json.loads(line) for line in f if line.strip()]

# Make sure the results directory exists so this cell does not depend on an
# earlier cell having created it.
os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)

# NOTE(review): Unsloth's example notebooks call
# FastLanguageModel.for_inference(model) before generating -- consider adding.

skipped = 0  # examples dropped because the prompt alone is too long
with open(save_path, 'w', encoding='utf-8') as f:
    for code in test_data:
        # The prompt must match the training data format exactly, including
        # the layout of the continuation lines of this literal.
        prompt = f"""Analyze the time complexity of the following code.
    Choose exactly one of the following options: O(1), O(logn), O(n), O(nlogn), O(n^2), O(n^3) or exponential (O(2^n), O(3^n), etc.).
    Give the time complexity of the code:
    {code['src']}"""
        messages = [
            {"role": "user", "content": prompt}
            ]

        inputs = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to("cuda")

        # Prompts longer than the model's sequence length are not evaluated;
        # they are counted and reported below instead of vanishing silently.
        prompt_len = len(inputs[0])
        if prompt_len > max_seq_length:
            skipped += 1
            continue

        tokens = model.generate(input_ids=inputs,
                                do_sample=False, # deterministic generation
                                max_new_tokens=max_new_tokens, # maximum number of tokens to generate
                                use_cache=True, # use cache for faster generation
                                no_repeat_ngram_size=4) # to avoid repetition

        # The returned sequence starts with the prompt tokens, so decode only
        # what comes after them. Searching the decoded text for the word
        # "assistant" would cut at the wrong place whenever the analysed
        # source code itself contains that word.
        result = tokenizer.decode(tokens[0][prompt_len:], skip_special_tokens=True).lstrip()

        entry = {
            "src": code['src'],
            "complexity": code['complexity'],
            "model": result,
        }
        f.write(json.dumps(entry) + '\n')

if skipped:
    print(f"Skipped {skipped} example(s) whose prompt exceeded {max_seq_length} tokens.")