In [1]:
import torch
from torch.cuda.amp import autocast  # For mixed precision
from torch.utils.data import DataLoader
import transformers
from trl import SFTTrainer

from accelerate import PartialState
from peft import LoraConfig, PeftModel, PeftConfig
from transformers import (
    AutoModelForCausalLM, AutoTokenizer,
    BitsAndBytesConfig,
    logging,
    set_seed,
    BatchEncoding,
    EarlyStoppingCallback
)
from datasets import load_dataset

from typing import Any, DefaultDict, List, Dict
import os, time, socket, argparse
import numpy as np
import pyarrow as pa
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings("ignore")

# For loading Tfix dataset 
from prepare_data import create_data,extract_warning_types
from data_reader import GetDataAsPython

In [2]:
parser = argparse.ArgumentParser()
parser.add_argument("--datadir", type=str, default="APRDataset/TFix")
parser.add_argument("--model_id", type=str, default="starcoder2-3b")
parser.add_argument("--basemodeldir", type=str, default=f"/LLMS/starcoder2-3b") ## change to your directory
parser.add_argument("--modelsavedir", type=str, default=f"APRModels/StarCoder2-3B_Tfix")
parser.add_argument("--output_dir", type=str, default="StarCoder2-3B_TFix")
parser.add_argument("--checkpointdir", type=str, default="")

parser.add_argument("--attention_dropout", type=float, default=0.1)
parser.add_argument("--max_steps", type=int, default=5000)
parser.add_argument("--micro_batch_size", type=int, default=8)
parser.add_argument("--seed", type=int, default=48)
parser.add_argument("--max_seq_length", type=int, default=512)
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
parser.add_argument("--weight_decay", type=float, default=0.01)
parser.add_argument("--fp16", type=bool, default=True)
parser.add_argument("--learning_rate", type=float, default=2e-4)
parser.add_argument("--lr_scheduler_type", type=str, default="cosine")
parser.add_argument("--warmup_steps", type=int, default=500)
parser.add_argument("--push_to_hub", type=bool, default=False)

# Parse the arguments with the updated defaults
args = parser.parse_args(args=[])
print(args)
set_seed(args.seed)
modelname = args.basemodeldir.split("/")[-1]

Namespace(datadir='APRDataset/TFix', model_id='starcoder2-3b', basemodeldir='/media/zero/ssd2/LLMS/starcoder2-3b', modelsavedir='APRModels/StarCoder2-3B_Tfix', output_dir='StarCoder2-3B_TFix', checkpointdir='', attention_dropout=0.1, max_steps=5000, micro_batch_size=8, seed=48, max_seq_length=512, gradient_accumulation_steps=1, weight_decay=0.01, fp16=True, learning_rate=0.0002, lr_scheduler_type='cosine', warmup_steps=500, push_to_hub=False)


In [3]:
def change_eos(data):
    data = [x.replace("</s>",tokenizer.eos_token) for x in data]
    return data
def delete_eos(data):
    data = [x.replace("</s>","") for x in data]
    return data

# For batch running
class CausalBugFixDataset(torch.utils.data.Dataset):
    def __init__(self, encodings: BatchEncoding, targets: BatchEncoding, idxs):
        self.encodings = encodings
        self.target_encodings = targets
        self.idxs = idxs

    def __getitem__(self, index: int) -> Dict[str, Any]:
        item = {key: val[index] for key, val in self.encodings.items()}
        item["labels"] = self.target_encodings["input_ids"][index]
        item["idx"] = self.idxs[index]
        return item

    def __len__(self) -> int:
        return len(self.encodings["input_ids"])
  
def create_dataset(
    idxs: List[int],
    inputs: List[str],
    tokenizer,
) -> CausalBugFixDataset:
    # Based on Transformer version: padding='max_length' or padding=True
    input_encodings = tokenizer(
        inputs, truncation=True, padding='longest', return_tensors='pt', max_length=512
    ) 
    label_encodings = tokenizer(
        inputs, truncation=True, padding='longest', return_tensors='pt', max_length=512
    ) 
    label_encodings["input_ids"][label_encodings["input_ids"] == tokenizer.pad_token_id] = -100


    dataset = CausalBugFixDataset(input_encodings, label_encodings, idxs)
    return dataset

# Load model

In [4]:
#### Setting for Quantization & LoRA 
bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

lora_config = LoraConfig(
    r=8,
    target_modules=[
        "q_proj",
        "o_proj",
        "k_proj",
        "v_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    task_type="CAUSAL_LM",
)
########################

start = time.time()
model =AutoModelForCausalLM.from_pretrained(args.basemodeldir,quantization_config=bnb_config,
                                            device_map={"": PartialState().process_index},
                                            attention_dropout=args.attention_dropout,
                                            )
end = time.time()
print(f"{modelname} loaded in", end- start,"sec")
tokenizer = AutoTokenizer.from_pretrained(args.basemodeldir)
  
if args.checkpointdir!='':
    tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_dir)
    model = PeftModel.from_pretrained(model, 
    checkpoint_dir,
     is_trainable=False 
    )

special_tokens_dict = {'pad_token': '<pad>', 'sep_token': '<sep>'}
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
model.resize_token_embeddings(len(tokenizer))  # Resizing the token embeddings to match the tokenizer

# Ensure the tokenizer uses left padding
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token  # Typically, EOS is used as the pad token in GPT models
model

starcoder2-3b loaded in 2.594926595687866 sec


Starcoder2ForCausalLM(
  (model): Starcoder2Model(
    (embed_tokens): Embedding(49154, 3072)
    (layers): ModuleList(
      (0-29): 30 x Starcoder2DecoderLayer(
        (self_attn): Starcoder2SdpaAttention(
          (q_proj): Linear4bit(in_features=3072, out_features=3072, bias=True)
          (k_proj): Linear4bit(in_features=3072, out_features=256, bias=True)
          (v_proj): Linear4bit(in_features=3072, out_features=256, bias=True)
          (o_proj): Linear4bit(in_features=3072, out_features=3072, bias=True)
          (rotary_emb): Starcoder2RotaryEmbedding()
        )
        (mlp): Starcoder2MLP(
          (c_fc): Linear4bit(in_features=3072, out_features=12288, bias=True)
          (c_proj): Linear4bit(in_features=12288, out_features=3072, bias=True)
          (act): PytorchGELUTanh()
        )
        (input_layernorm): LayerNorm((3072,), eps=1e-05, elementwise_affine=True)
        (post_attention_layernorm): LayerNorm((3072,), eps=1e-05, elementwise_affine=True)
      )
 

# Load TFix Dataset

In [5]:
###############
#  Load Dataset 
###############
data = GetDataAsPython("TFixData/data_autofix_tracking_repo_specific_final.json")
data_eslint = GetDataAsPython("TFixData/data_autofix_tracking_eslint_final.json")
data += data_eslint
all_warning_types = extract_warning_types(data)
print("# of warning types:", len(all_warning_types))

(
    train_inputs,
    train_labels,
    val_inputs,
    val_labels,
    test_inputs,
    test_labels,
    train_info,
    val_info,
    test_info,
) = create_data(data, ['no-constant-condition'], include_warning=True, model_name='')

inputs, labels, types, infos = [],[],[],[]
for warning_type in all_warning_types:
    inputs+=test_inputs[warning_type]
    labels+=test_labels[warning_type]
    types+=[warning_type]*len(test_labels[warning_type])
    infos+=test_info[warning_type]
test_data = dict()
test_data['buggy']=delete_eos(inputs)
test_data['fixed']=change_eos(labels) 
test_data['info']=infos

inputs = delete_eos(val_inputs)
labels = change_eos(val_labels)
infos = val_info
types = [x.linter_report.rule_id for x in val_info]
val_data = dict()
val_data['buggy']=pa.array(inputs)
val_data['fixed']=pa.array(labels)
val_data['info']=infos

train_inputs = delete_eos(train_inputs)
train_labels = change_eos(train_labels)

val_inputs = delete_eos(val_inputs)
val_labels = change_eos(val_labels)


# Change Input format - follow TFix paper
train_inputs = [f"{i}\nFixed: <sep>\n{o}" for i,o in zip(train_inputs, train_labels)]
val_inputs = [f"{i}\nFixed: <sep>\n{o}" for i,o in zip(val_inputs, val_labels)]
print("\n####### Train data sample")
print(train_inputs[0])
test_inputs = [f"{i}\nFixed: " for i in test_data['buggy']]
tidxs = np.arange(len(test_inputs))

train size: 1039
val size: 116
test size: 129

####### Train data sample
fix no-constant-condition Unexpected constant condition. 		if (err.code = 'ECONNRESET') {
:
	onError(err) {
		if (err.code = 'ECONNRESET') {
			if (!this.retryRegistration) { 
 
Fixed: <sep>
	onError(err) {
		if (err.code === 'ECONNRESET') {
			if (!this.retryRegistration) { 
 <|endoftext|>


# Train

In [6]:
idxs = np.arange(len(train_inputs))
dataset = create_dataset(idxs, train_inputs,tokenizer)
vidxs = np.arange(len(val_inputs))
valdataset = create_dataset(vidxs, val_inputs,tokenizer)

# Define EarlyStoppingCallback
early_stopping = EarlyStoppingCallback(
    early_stopping_patience=5,  # Number of evaluations with no improvement before stopping
    early_stopping_threshold=0.0  # Minimum change to qualify as an improvement
)

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    max_seq_length=args.max_seq_length,
    args=transformers.TrainingArguments(
        per_device_train_batch_size=args.micro_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        warmup_steps=args.warmup_steps,
        max_steps=args.max_steps,
        learning_rate=args.learning_rate,
        lr_scheduler_type=args.lr_scheduler_type,
        weight_decay=args.weight_decay,
        fp16=args.fp16,
        logging_strategy="steps",
        logging_steps=10,
        output_dir=args.output_dir,
        optim="paged_adamw_8bit",
        seed=args.seed,
        run_name=f"train-{args.model_id.split('/')[-1]}-tfix",
        report_to="wandb",
        load_best_model_at_end=True,
        eval_strategy='steps',
        eval_steps=10,
        save_steps=10,
        save_strategy='steps',
        metric_for_best_model="eval_loss", 
        save_total_limit = 3,
    ),
    peft_config=lora_config,
    eval_dataset = valdataset,
    callbacks=[early_stopping]     
)

trainer.train()

max_steps is given, it will override any value given in num_train_epochs
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mxx[0m ([33mxx[0m). Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss,Validation Loss
10,5.4624,3.58737
20,5.5552,3.573278
30,5.7617,3.518967
40,5.2587,3.407802
50,5.1072,3.211393
60,4.6042,2.929364
70,4.4039,2.621439
80,3.7701,2.392522
90,3.5098,2.262323
100,3.3665,2.126256


We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)


TrainOutput(global_step=630, training_loss=1.8572383350796169, metrics={'train_runtime': 1062.3939, 'train_samples_per_second': 37.651, 'train_steps_per_second': 4.706, 'total_flos': 3.2329168415760384e+16, 'train_loss': 1.8572383350796169, 'epoch': 4.846153846153846})

# Test

In [7]:
model.eval()

batch_size = 32 
# Faster than tokenizing every time
all_tokenized_inputs = tokenizer(test_inputs, return_tensors="pt", truncation=True, max_length=512-2, padding=True)
all_tokenized_inputs = {k: v.to('cuda') for k, v in all_tokenized_inputs.items()}

# Bar to check progress
pbar = tqdm(range(0, len(test_data['fixed']), batch_size))
ems = []
count = 0

with torch.no_grad():  # Disable gradient computation for inference
    for i in range(0, len(test_data['fixed']), batch_size):
        # Prepare the batch
        batch_inputs = {k: v[i:i + batch_size] for k, v in all_tokenized_inputs.items()}
        batch_answers = test_data['fixed'][i:i + batch_size]

        # Mixed precision inference for speedup
        with autocast():
            outputs = model.generate(input_ids=batch_inputs['input_ids'], 
                                     attention_mask=batch_inputs['attention_mask'],
                                     num_beams=1, max_length=512, # Greedy decoding
                                     pad_token_id=tokenizer.eos_token_id, 
                                     early_stopping=True)

        # Decode the outputs
        batch_outputs = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]

        # Compare predictions with answers
        for j, (tout, tanswer) in enumerate(zip(batch_outputs, batch_answers)):
            tout = tout.split('Fixed:')
            if len(tout) < 2:
                fixed = tout[0].strip()
            else:
                fixed = tout[1].strip()

            # Check exact match with developer fix
            fixed = ''.join([x.strip() for x in fixed.split()])
            tanswer = tanswer.strip()
            tanswer = ''.join([x.strip() for x in tanswer.split()]).replace('<|endoftext|>', '')
            em = 1 if fixed == tanswer else 0
            ems.append(em)

        count += batch_size
        print(f"Current Score: {sum(ems)} / {len(ems)}")
        pbar.update(1)

# Calculate overall exact match (EM) score
em_score = sum(ems) / len(ems)
print(f"Exact Match (EM) Score: {em_score}")

  0%|          | 0/5 [00:00<?, ?it/s]

GenerationMode.GREEDY_SEARCH
Current Score: 12 / 32
GenerationMode.GREEDY_SEARCH
Current Score: 26 / 64
GenerationMode.GREEDY_SEARCH
Current Score: 40 / 96
GenerationMode.GREEDY_SEARCH
Current Score: 51 / 128
GenerationMode.GREEDY_SEARCH
Current Score: 51 / 129
Exact Match (EM) Score: 0.3953488372093023
