In [None]:
# Maximum sequence length handed to SFTTrainer via its max_seq_length argument.
MAX_SEQ_LENGTH = 4096

### Load Data

In [None]:
from datasets import load_dataset

# Dataset identifier and configuration name passed to load_dataset.
DATA_NAME = "cnn_dailymail"
SUBSET = "2.0.0"
# Directory the subsampled dataset is saved to and later reloaded from.
DATA_PATH = "data/cnn_dailymail"
# Fraction of every split that is kept when subsampling.
FRAC = 0.1

def save_fraction_dataset(dataset_name, fraction, save_name, subset=None, shuffle=False, seed=None):
    """Load a dataset, keep a fraction of every split, and save the result to disk.

    Args:
        dataset_name: Dataset identifier passed to ``load_dataset``.
        fraction: Share of each split to keep. Values above 1 keep the whole
            split; negative values are rejected.
        save_name: Target directory passed to ``save_to_disk``.
        subset: Optional configuration name forwarded as ``name=``.
        shuffle: When True, shuffle each split before taking the prefix.
            Defaults to False, which keeps the first rows of every split.
        seed: Seed forwarded to ``Dataset.shuffle`` (only used when
            ``shuffle`` is True) so the sample is reproducible.

    Raises:
        ValueError: If ``fraction`` is negative.
    """
    if fraction < 0:
        raise ValueError(f"fraction must be non-negative, got {fraction!r}")

    # name=None is load_dataset's default, so one call covers both cases.
    data = load_dataset(dataset_name, name=subset, cache_dir=None)

    # Snapshot the split names because the mapping is reassigned inside the loop.
    for split in list(data.keys()):
        splitdata = data[split]
        num_to_keep = min(int(len(splitdata) * fraction), len(splitdata))
        if shuffle:
            splitdata = splitdata.shuffle(seed=seed)
        # A range avoids materialising an index list as long as the split.
        data[split] = splitdata.select(range(num_to_keep))

    # save data
    data.save_to_disk(save_name)

# Subsample every split of DATA_NAME/SUBSET down to FRAC and save it under DATA_PATH.
save_fraction_dataset(DATA_NAME, FRAC, DATA_PATH, SUBSET)

In [None]:
from datasets import load_from_disk
# Reload the subsampled dataset saved under DATA_PATH and pull out its three splits.
data = load_from_disk(DATA_PATH)
train_data, val_data, test_data = data["train"], data["validation"], data["test"]

### Format Data

In [None]:
def promptify_data(examples):
    """Build one training prompt per article/summary pair of a batch.

    Args:
        examples: Batch mapping with parallel 'article' and 'highlights'
            sequences.

    Returns:
        A dict with a single "text" key holding the formatted prompts, in the
        same order as the inputs.
    """
    prompts = [
        f"### Article: {article}\n### Summary: {summary}"
        for article, summary in zip(examples['article'], examples['highlights'])
    ]
    return {"text": prompts}

# Add the prompt-formatted "text" column to the two splits used during fine-tuning.
train_data, val_data = (
    train_data.map(promptify_data, batched=True),
    val_data.map(promptify_data, batched=True),
)


### Training

In [None]:
import torch
from datasets import load_dataset, load_from_disk, disable_caching
from transformers import (
    TrainingArguments,
    BitsAndBytesConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
)
from trl import SFTTrainer
from modeleval import evaluate_model
from peft import prepare_model_for_kbit_training, LoraConfig

In [None]:
# 4-bit NF4 quantization with double quantization; compute dtype is float16.
compute_dtype = torch.float16
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=compute_dtype,
)

# Load the base checkpoint with the quantization config; device_map={"": 0}
# maps the whole model to device 0.
model = AutoModelForCausalLM.from_pretrained(
    'sarvamai/OpenHathi-7B-Hi-v0.1-Base',
    cache_dir="./hub",
    quantization_config=bnb_config,
    device_map={"": 0},
)
# Run peft's k-bit training preparation on the quantized model before LoRA is attached.
model = prepare_model_for_kbit_training(model)

In [None]:
# NOTE(review): add_eos_token=True presumably applies to every tokenizer call,
# including prompts encoded for generation -- confirm that is intended.
tokenizer = AutoTokenizer.from_pretrained(
    'sarvamai/OpenHathi-7B-Hi-v0.1-Base',
    use_fast=True,
    add_eos_token=True,
)
# Pad on the left and reuse the unknown token as the padding token.
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.unk_token

In [None]:
# LoRA adapter settings: rank 16, alpha 16, dropout 0.05, no bias terms,
# applied to the down/up/gate projection modules.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    target_modules=["down_proj", "up_proj", "gate_proj"],
)

In [None]:
training_arguments = TrainingArguments(
    output_dir="./results/",
    # Optimisation: linear schedule with 100 warmup steps, 3000 steps total.
    optim="paged_adamw_8bit",
    learning_rate=1e-4,
    lr_scheduler_type="linear",
    warmup_steps=100,
    max_steps=3000,
    fp16=True,
    # Batching: 32 sequences per device, accumulated over 3 steps.
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    gradient_accumulation_steps=3,
    # Evaluation, logging and checkpointing cadence (in steps).
    do_eval=True,
    evaluation_strategy="steps",
    eval_steps=100,
    logging_steps=100,
    log_level="debug",
    save_steps=50,
)

In [None]:
# Supervised fine-tuning trainer: reads the prompt from the "text" column of
# each split, applies the LoRA config, and does not pack sequences.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    peft_config=peft_config,
    args=training_arguments,
    train_dataset=train_data,
    eval_dataset=val_data,
    dataset_text_field="text",
    max_seq_length=MAX_SEQ_LENGTH,
    packing=False,
)

In [None]:
# Run fine-tuning with the arguments configured in training_arguments.
trainer.train()

### Eval

In [None]:
from tqdm import tqdm
import evaluate

# Name of the metric loaded through evaluate.load below.
METRIC = "rouge"
# NOTE(review): the same value is already assigned in the first cell, and the
# evaluation functions in this cell do not reference it.
MAX_SEQ_LENGTH = 4096
# Upper bound on newly generated tokens (passed as max_new_tokens).
MAX_OUT_LENGTH = 1000
# Number of examples printed by example_input_output.
NUM_EXAMPLES = 5

# Metric object used by rouge_test.
metric = evaluate.load(METRIC)

def promptify_single(article):
    """Format one article as an inference prompt that ends at the summary marker.

    The layout matches the training prompt "### Article: ...\\n### Summary: ..."
    (including the space after "###"), so the model is prompted with the same
    template it was fine-tuned on.

    Args:
        article: Article text to summarise.

    Returns:
        The prompt string, ending in "### Summary:".
    """
    return f"### Article: {article}\n### Summary:"

def _generate_from_prompt(model, tokenizer, prompt):
    """Sample a continuation of ``prompt`` and return only the newly generated text."""
    encoded = tokenizer(prompt, return_tensors='pt')
    input_ids = encoded["input_ids"]
    eos_id = tokenizer.eos_token_id
    # A tokenizer configured to append EOS ends the prompt with an
    # end-of-sequence token; drop it so the model continues the prompt
    # instead of starting after a finished sequence.
    ends_with_eos = (
        eos_id is not None
        and input_ids.shape[1] > 1
        and input_ids[0, -1].item() == eos_id
    )
    # Inputs must live on the same device as the model.
    inputs = {
        name: (tensor[:, :-1] if ends_with_eos else tensor).to(model.device)
        for name, tensor in encoded.items()
    }
    prompt_len = inputs["input_ids"].shape[1]
    generated = model.generate(**inputs,
                               do_sample=True,
                               temperature=0.1,
                               max_new_tokens=MAX_OUT_LENGTH,)
    # Skip the echoed prompt tokens and decode only what the model added.
    return tokenizer.decode(generated[0, prompt_len:], skip_special_tokens=True)


def example_input_output(model, tokenizer, data):
    """Print the prompt and generated summary for the first NUM_EXAMPLES articles.

    Puts ``model`` in eval mode. Stops early when ``data`` holds fewer than
    NUM_EXAMPLES articles.

    Args:
        model: Causal LM exposing ``generate``, ``eval`` and ``device``.
        tokenizer: Tokenizer matching ``model``.
        data: Dataset-like object with an "article" column.
    """
    print("\nExample Input/Output...\n")
    # Generation should not run with training-time behaviour such as dropout.
    model.eval()
    # Read the column once instead of once per example.
    articles = data["article"]
    for i in range(min(NUM_EXAMPLES, len(articles))):
        prompt = promptify_single(articles[i])
        print("INPUT:")
        print(prompt)
        output = _generate_from_prompt(model, tokenizer, prompt)
        print("OUTPUT:")
        print(output)


def rouge_test(model, tokenizer, data):
    """Generate a summary for every article in ``data`` and print ROUGE scores.

    Puts ``model`` in eval mode. Predictions are compared against the
    "highlights" column using the module-level ``metric``.

    Args:
        model: Causal LM exposing ``generate``, ``eval`` and ``device``.
        tokenizer: Tokenizer matching ``model``.
        data: Dataset-like object with "article" and "highlights" columns.
    """
    print("\nTesting Model...\n")
    # Generation should not run with training-time behaviour such as dropout.
    model.eval()
    targets = data["highlights"]
    outputs = [
        _generate_from_prompt(model, tokenizer, promptify_single(model_in))
        for model_in in tqdm(data["article"])
    ]

    results = metric.compute(predictions=outputs, references=targets)
    rouge1 = results['rouge1']
    rouge2 = results['rouge2']
    rougeL = results['rougeL']
    rougeLsum = results['rougeLsum']
    print(f"Rouge test results:\nrouge1:{rouge1}\nrouge2:{rouge2}\nrougeL:{rougeL}\nrougeLSum:{rougeLsum}")

In [None]:
# Print a few prompt/summary pairs from the test split.
example_input_output(model, tokenizer, test_data)

In [None]:
# Generate summaries for the whole test split and print ROUGE scores.
rouge_test(model, tokenizer, test_data)