In [None]:
!pip install transformers datasets evaluate accelerate



In [None]:
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig, TrainingArguments, Trainer
import torch
import time
import evaluate
import pandas as pd
import numpy as np

In [None]:
# Select the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device('cuda')
else:
  device = torch.device('cpu')
device

device(type='cuda')

In [None]:
# Fetch every split of the "mbpp" dataset.
dataset_full = load_dataset(path="mbpp")

In [None]:
# Checkpoint identifier shared by the tokenizer and the model.
model_name = 'google/flan-t5-small'

# The tokenizer and the model are loaded independently from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_name)
original_model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,  # load the weights in bfloat16
)

In [None]:
def print_trainable_model_parameters(model):
  """Summarize how many of a model's parameters are trainable.

  Despite its name, this function does not print anything; it returns the
  summary text so the caller decides how to display it.

  Args:
    model: Object exposing ``named_parameters()`` that yields
      ``(name, parameter)`` pairs, where each parameter provides ``numel()``
      and a ``requires_grad`` flag.

  Returns:
    A three-line string with the trainable parameter count, the total
    parameter count, and the trainable share as a percentage. A model with
    no parameters reports ``0.0%`` instead of raising ``ZeroDivisionError``.
  """
  trainable_model_params = 0
  all_model_params = 0

  for _, param in model.named_parameters():
    param_count = param.numel()
    all_model_params += param_count
    if param.requires_grad:
      trainable_model_params += param_count

  # Guard the ratio: a model without parameters would otherwise divide by zero.
  percentage = trainable_model_params / all_model_params * 100 if all_model_params else 0.0

  return f"trainable model parameters: {trainable_model_params}\nall model parameters: {all_model_params}\npercentage: {percentage}%"

In [None]:
# Show the trainable vs. total parameter counts of the loaded model.
print(print_trainable_model_parameters(original_model))

trainable model parameters: 76961152
all model parameters: 76961152
percentage: 100.0%


In [None]:
# Display the loaded dataset object (its splits, features and row counts).
dataset_full

DatasetDict({
    train: Dataset({
        features: ['task_id', 'text', 'code', 'test_list', 'test_setup_code', 'challenge_test_list'],
        num_rows: 374
    })
    test: Dataset({
        features: ['task_id', 'text', 'code', 'test_list', 'test_setup_code', 'challenge_test_list'],
        num_rows: 500
    })
    validation: Dataset({
        features: ['task_id', 'text', 'code', 'test_list', 'test_setup_code', 'challenge_test_list'],
        num_rows: 90
    })
    prompt: Dataset({
        features: ['task_id', 'text', 'code', 'test_list', 'test_setup_code', 'challenge_test_list'],
        num_rows: 10
    })
})

In [None]:
# Inspect the first record of the training split.
dataset_full['train'][0]

{'task_id': 601,
 'text': 'Write a function to find the longest chain which can be formed from the given set of pairs.',
 'code': 'class Pair(object): \r\n\tdef __init__(self, a, b): \r\n\t\tself.a = a \r\n\t\tself.b = b \r\ndef max_chain_length(arr, n): \r\n\tmax = 0\r\n\tmcl = [1 for i in range(n)] \r\n\tfor i in range(1, n): \r\n\t\tfor j in range(0, i): \r\n\t\t\tif (arr[i].a > arr[j].b and\r\n\t\t\t\tmcl[i] < mcl[j] + 1): \r\n\t\t\t\tmcl[i] = mcl[j] + 1\r\n\tfor i in range(n): \r\n\t\tif (max < mcl[i]): \r\n\t\t\tmax = mcl[i] \r\n\treturn max',
 'test_list': ['assert max_chain_length([Pair(5, 24), Pair(15, 25),Pair(27, 40), Pair(50, 60)], 4) == 3',
  'assert max_chain_length([Pair(1, 2), Pair(3, 4),Pair(5, 6), Pair(7, 8)], 4) == 4',
  'assert max_chain_length([Pair(19, 10), Pair(11, 12),Pair(13, 14), Pair(15, 16), Pair(31, 54)], 5) == 5'],
 'test_setup_code': '',
 'challenge_test_list': []}

In [None]:
# Position of the test-split example used for the qualitative before/after comparison.
index = 200

text = dataset_full['test'][index]['text']
# NOTE: despite its name, `summary` holds the reference *code* for the task;
# the name is kept because later cells read it.
summary = dataset_full['test'][index]['code']

# Use the same template that tokenize_function applies to the training data
# ("...query.\n\n<text>\n\nCode: ") so inference prompts match what the
# model is fine-tuned on.
prompt = f"""Generate python code based on the below query.

{text}

Code: """

print(prompt)


Generate python code based on the below query.

Write a python function to count numbers whose oth and nth bits are set.

Code.



In [None]:
# Encode the prompt, generate up to 200 new tokens and decode the first sequence.
inputs = tokenizer(prompt, return_tensors='pt')
output = tokenizer.decode(
    original_model.generate(inputs["input_ids"], max_new_tokens=200)[0],
    skip_special_tokens=True,
)

# Separator made of 99 dashes.
dash_line = '-' * 99

# One print call; sep='\n' puts each section on its own line.
print(
    dash_line,
    f'INPUT PROMPT:\n{prompt}',
    dash_line,
    f'BASELINE HUMAN SUMMARY:\n{summary}',
    dash_line,
    f'MODEL GENERATION - ZERO SHOT:\n{output}',
    sep='\n',
)

---------------------------------------------------------------------------------------------------
INPUT PROMPT:

Generate python code based on the below query.

Write a python function to count numbers whose oth and nth bits are set.

Code.

---------------------------------------------------------------------------------------------------
BASELINE HUMAN SUMMARY:
def count_Num(n): 
    if (n == 1): 
        return 1
    count = pow(2,n - 2) 
    return count 
---------------------------------------------------------------------------------------------------
MODEL GENERATION - ZERO SHOT:
nth_bits = 0 for i in range(1, nth_bits): if i == nth_bits: nth_bits = i nth_bits = i nth_bits = i nth_bits = i nth_bits = i nth_bits = i nth_bits = i nth_bits = i nth_bits = i nth_bits = i nth_bits = i nth_bits = i nth_bits = i nth_bits = i nth_bits = i nth_bits = i nth_bits = i nth_bits = i 


In [None]:
def tokenize_function(example):
  """Tokenize a batch of examples into model inputs and labels.

  Intended for ``Dataset.map(..., batched=True)``: ``example["text"]`` and
  ``example["code"]`` are lists of strings. Each text is wrapped in the
  instruction template before tokenization.

  Args:
    example: Batch dict with "text" (task descriptions) and "code"
      (reference solutions).

  Returns:
    The same dict with three added columns:
      * "input_ids": prompt token ids, padded/truncated to the tokenizer's
        maximum length.
      * "attention_mask": marks real prompt tokens (1) vs. padding (0).
      * "labels": target token ids, with padding positions set to -100 so
        the loss ignores them.
  """
  start_prompt = 'Generate python code based on the below query.\n\n'
  end_prompt = '\n\nCode: '
  prompt = [start_prompt + query + end_prompt for query in example["text"]]

  # Plain Python lists are sufficient here; no tensor conversion is requested.
  model_inputs = tokenizer(prompt, padding="max_length", truncation=True)
  example["input_ids"] = model_inputs["input_ids"]
  # Keep the mask so the encoder does not attend to the max-length padding.
  example["attention_mask"] = model_inputs["attention_mask"]

  label_ids = tokenizer(example["code"], padding="max_length", truncation=True)["input_ids"]
  # Replace pad ids with -100 so padded positions do not contribute to the loss.
  pad_id = tokenizer.pad_token_id
  example["labels"] = [
      [token_id if token_id != pad_id else -100 for token_id in sequence]
      for sequence in label_ids
  ]

  return example

In [None]:
# Tokenize every split in batches, then drop the columns that are not needed afterwards.
tokenized_datasets = (
    dataset_full
    .map(tokenize_function, batched=True)
    .remove_columns(['test_list', 'test_setup_code', 'challenge_test_list'])
)
tokenized_datasets

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['task_id', 'text', 'code', 'input_ids', 'labels'],
        num_rows: 374
    })
    test: Dataset({
        features: ['task_id', 'text', 'code', 'input_ids', 'labels'],
        num_rows: 500
    })
    validation: Dataset({
        features: ['task_id', 'text', 'code', 'input_ids', 'labels'],
        num_rows: 90
    })
    prompt: Dataset({
        features: ['task_id', 'text', 'code', 'input_ids', 'labels'],
        num_rows: 10
    })
})

In [None]:
# Report the (rows, columns) shape of the train and test splits in one print call.
print("\n".join([
    "Shapes of datasets",
    f"Training: {tokenized_datasets['train'].shape}",
    f"Test: {tokenized_datasets['test'].shape}",
]))

Shapes of datasets
Training: (374, 5)
Test: (500, 5)


In [None]:
# Write checkpoints to a dedicated folder under the working directory
# instead of the filesystem root ("/").
output_dir = "./flan-t5-small-mbpp-checkpoints"

training_args = TrainingArguments(
    output_dir=output_dir,
    learning_rate=1e-5,
    weight_decay=0.01,
    logging_steps=100,
    # Training length is controlled by the step budget alone. A positive
    # max_steps overrides num_train_epochs, so the epoch count previously
    # passed alongside it had no effect and was removed.
    max_steps=1000
)

# Fine-tune the full model on the tokenized train split; the validation
# split is registered as the evaluation dataset.
trainer = Trainer(
    model=original_model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation'],
)

In [None]:
# Run fine-tuning with the configured trainer; the returned training summary is shown as the cell output.
trainer.train()

Step,Training Loss
100,49.9825
200,46.98
300,45.5875
400,44.75
500,44.075
600,43.9875
700,43.6975
800,43.76
900,43.7725
1000,43.7175


TrainOutput(global_step=1000, training_loss=45.031, metrics={'train_runtime': 888.5246, 'train_samples_per_second': 9.004, 'train_steps_per_second': 1.125, 'total_flos': 1479316636434432.0, 'train_loss': 45.031, 'epoch': 21.28})

In [None]:
# The model may still be in training mode after fine-tuning; switch to eval
# mode so dropout is disabled while generating.
original_model.eval()

# Put the prompt on whichever device the model currently lives on, so the cell
# also works on CPU-only machines (previously hard-coded to 'cuda').
input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(original_model.device)
# `num_beams` was previously misspelled as `num_beans`, so the setting never
# reached the generation config.
model_outputs = original_model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_new_tokens=200, num_beams=1))
model_text_output = tokenizer.decode(model_outputs[0], skip_special_tokens=True)

# Separator made of 99 dashes.
dash_line = '-' * 99
print(dash_line)
print(f'INPUT PROMPT:\n{prompt}')
print(dash_line)
print(f'BASELINE HUMAN SUMMARY:\n{summary}')
print(dash_line)
print(f'MODEL GENERATION - ZERO SHOT:\n{model_text_output}')

---------------------------------------------------------------------------------------------------
INPUT PROMPT:

Generate python code based on the below query.

Write a python function to count numbers whose oth and nth bits are set.

Code.

---------------------------------------------------------------------------------------------------
BASELINE HUMAN SUMMARY:
def count_Num(n): 
    if (n == 1): 
        return 1
    count = pow(2,n - 2) 
    return count 
---------------------------------------------------------------------------------------------------
MODEL GENERATION - ZERO SHOT:
Using a nth-bits function to count numbers, a nth-bits function to count numbers, b nth bits, a nth nth bits, a nth bits, a nth bits, a nth bits, a nth bits, a nth bits, a nth bits, a nth bits, a nth bits, a nth bits, a nth bits, a nth bits, a nth bits, a nth bits, a nth bits, a nth bits, a nth bits, nth bits, a nth bits, a nth bits, 
