<a href="https://colab.research.google.com/github/Tanmay06/automated_code_comment/blob/main/prompting_LLMs_automated_code_gen.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from pathlib import Path
import json
from tqdm import tqdm

import torch
import pandas as pd
from torch.utils.data import DataLoader
from transformers import T5ForConditionalGeneration, AutoTokenizer, AutoModelForCausalLM, get_linear_schedule_with_warmup
from peft import get_peft_config, get_peft_model, PromptTuningInit, PromptTuningConfig, TaskType, PeftType
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device("cuda")

In [4]:
# model_family = "t5"
# model_config = "t5-base"
# tokenizer_config = "t5-base"

# model_family = "t5"
# model_config = "google/flan-t5-base"
# tokenizer_config = "google/flan-t5-base"

# model_family = "llama"
# model_config = "openlm-research/open_llama_3b"
# tokenizer_config = "openlm-research/open_llama_3b"

model_family = "llama"
model_config = "meta-llama/Llama-2-7b-chat-hf"
tokenizer_config = "meta-llama/Llama-2-7b-chat-hf"

In [5]:
# configs
experiment_name = model_config
path = "/home/thv200000/projects/automated_code_comment/data/python_code_comment_samples.csv"
infer_batchsize = 15

In [6]:
datapath = Path(path)
outfile_path = datapath.parent / ("gen_" + experiment_name.replace('/', '-') + ".jsonl")

In [7]:
dataset = load_dataset('csv', data_files=str(datapath))

In [8]:
prompt_path = "/home/thv200000/projects/automated_code_comment/prompt_python.json"

In [9]:
with open(prompt_path) as prompt_file:
    prompt_data = json.loads(prompt_file.read())
    prompt = prompt_data[0]['prompt']
    context = prompt_data[0]['context']

In [10]:
tokenizer = AutoTokenizer.from_pretrained(
    tokenizer_config,
    model_max_length=512
)
tokenizer.pad_token = tokenizer.eos_token

In [11]:
peft_config = PromptTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    prompt_tuning_init=PromptTuningInit.TEXT,
    num_virtual_tokens=8,
    prompt_tuning_init_text=prompt,
    tokenizer_name_or_path=tokenizer_config,
)

In [12]:
if model_family == "t5":
  model = T5ForConditionalGeneration.from_pretrained(model_config,  device_map="auto")
  setattr(model, 'model_parallel', True)
  setattr(model, 'is_parallelizable', True)
else:
  model = AutoModelForCausalLM.from_pretrained(model_config,  device_map="auto")

Loading checkpoint shards: 100%|██████████| 2/2 [00:16<00:00,  8.28s/it]


In [13]:
model = get_peft_model(model, peft_config)
print(model.print_trainable_parameters())

trainable params: 32,768 || all params: 6,738,448,384 || trainable%: 0.0004862840543203603
None


In [13]:
# def prepare_prompt(context_samples, task):
#     if context_samples:
#         prompt = task + "\n" + '\n'.join(
#             ["####\n" + sample['code'] + " => " + sample['comment'] for sample in context_samples]
#             ) + "\n####\n$code => "
#     else:
#         prompt = "Q: " + task + "\n$code ?\nA:"
    
#     return Template(prompt)

In [14]:
# input_template = prepare_prompt(context, prompt)

In [13]:
def preprocess(examples, input_template, tokenizer):
    # examples["input_ids"] = tokenizer(
    #     # [prompt + example + "A:" for example in examples["func_code_string"]],
    #     [input_template.substitute(code=example) for example in examples["func_code_string"]],
    #     return_tensors='pt',
    #     padding=True,
    #     truncation=True,
    # ).input_ids

    examples["input_ids"] = tokenizer(
        examples["func_code_string"],
        return_tensors='pt',
        padding=True,
        truncation=True,
    ).input_ids
    
    return examples

In [14]:
# tokenized = dataset.map(
#     preprocess,
#     fn_kwargs={"input_template":input_template, "tokenizer":tokenizer},
#     batched=True
# )

tokenized = dataset.map(
    preprocess,
    fn_kwargs={"input_template":None, "tokenizer":tokenizer},
    batched=True
)

In [15]:
tokenized.set_format(type="torch", columns=["input_ids"])

In [16]:
tokenized['train']['input_ids'].shape

torch.Size([113, 512])

In [17]:
train_dataloader = DataLoader(tokenized['train'], batch_size=infer_batchsize)
eval_dataloader = DataLoader(tokenized['train'], batch_size=infer_batchsize)

In [18]:
num_epochs = 50
lr = 3e-2

In [19]:
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=(len(train_dataloader) * num_epochs),
)

In [20]:
batch = next(iter(train_dataloader))

In [21]:
out = model(batch["input_ids"].to(device))

AttributeError: 'T5Stack' object has no attribute 'first_device'

In [30]:
out.loss['logits'].shape

torch.Size([15, 520, 32000])

In [22]:
# model = model.to(device)

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for step, batch in enumerate(tqdm(train_dataloader)):
        # batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(batch["input_ids"].to(device))
        loss = outputs.loss
        total_loss += loss.detach().float()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

    model.eval()
    eval_loss = 0
    eval_preds = []
    for step, batch in enumerate(tqdm(eval_dataloader)):
        # batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model(batch["input_ids"].to(device))
        loss = outputs.loss
        eval_loss += loss.detach().float()
        eval_preds.extend(
            tokenizer.batch_decode(torch.argmax(outputs.logits, -1).detach().cpu().numpy(), skip_special_tokens=True)
        )

    eval_epoch_loss = eval_loss / len(eval_dataloader)
    eval_ppl = torch.exp(eval_epoch_loss)
    train_epoch_loss = total_loss / len(train_dataloader)
    train_ppl = torch.exp(train_epoch_loss)
    print(f"{epoch=}: {train_ppl=} {train_epoch_loss=} {eval_ppl=} {eval_epoch_loss=}")

  0%|          | 0/8 [00:00<?, ?it/s]


AttributeError: 'T5Stack' object has no attribute 'first_device'