In [None]:
%%capture
%%bash
# datasets.load_metric (used below to load ROUGE) was removed in datasets 3.0,
# so stay on the 2.x series. Package specs are quoted so the shell does not
# treat "<" as a redirect or "[torch]" as a glob pattern.
pip install "datasets<3" sentencepiece rouge_score wandb
pip install accelerate -U
pip install "transformers[torch]"

In [None]:
!pip install wandb

In [None]:
import torch
import numpy as np
import datasets
from datasets import Dataset

from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
    BartTokenizer,
    BartForConditionalGeneration,
)

from tabulate import tabulate
import nltk
from datetime import datetime

In [None]:
# Toggle Weights & Biases experiment tracking (the run is created in a later
# cell only when this is True). wandb is imported only when tracking is on, so
# the notebook also runs where wandb is not installed.
# NOTE(review): this flag presumably does not control the Trainer's own
# reporting integrations (TrainingArguments.report_to) — confirm.
WANDB_INTEGRATION = True
if WANDB_INTEGRATION:
    import wandb

In [None]:
# Checkpoint to fine-tune. Alternatives: "facebook/bart-large",
# "facebook/bart-base". The model config is left at its defaults
# (inspect with print(model.config) to change it).
model_name = "Koshti10/BART-large-ET-Synthetic"

# Fixed padding/truncation lengths, in tokens, for the dialog inputs
# (encoder side) and the gameplan targets (decoder side).
encoder_max_length = 1024  # demo
decoder_max_length = 256  # demo

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

In [None]:

import pandas as pd

train_df = pd.read_csv('/path/to/train_dataset')
valid_df = pd.read_csv('/path/to/valid_dataset')


def _strip_speaker_tags(dialogs):
    """Return the dialogs as a list of strings with speaker markers deleted.

    "<<Commander>>" and then "<<Driver>>" are removed outright (replaced by
    the empty string, not by a separator), in that order.
    """
    cleaned = []
    for text in dialogs:
        for tag in ("<<Commander>>", "<<Driver>>"):
            text = text.replace(tag, "")
        cleaned.append(text)
    return cleaned


new_train_diaog = _strip_speaker_tags(train_df['dialog'])
new_valid_diaog = _strip_speaker_tags(valid_df['dialog'])



In [None]:
# Keep only the two columns the model needs: DIALOG (speaker-tag-free input
# text) and GAMEPLAN (target text, from the raw "gameplan_prediction" column).
# Both dicts are built before train_df / valid_df are overwritten below, so
# they still see the raw frames. Because the raw frames are replaced,
# re-running this cell requires re-running the CSV-loading cell first.
tra = {
    "DIALOG": new_train_diaog,
    "GAMEPLAN": train_df["gameplan_prediction"],
}

val = {
    "DIALOG": new_valid_diaog,
    "GAMEPLAN": valid_df["gameplan_prediction"],
}

# Each frame is built exactly once (the original built train_df twice).
train_df = pd.DataFrame(tra)
valid_df = pd.DataFrame(val)

In [None]:
# Wrap both pandas frames as Hugging Face datasets so they can be tokenized with .map().
train_data_txt, validation_data_txt = (
    Dataset.from_pandas(frame) for frame in (train_df, valid_df)
)

## Pre-process dataset

In [None]:
def batch_tokenize_preprocess(batch, tokenizer, max_source_length, max_target_length):
    """Tokenize a batch of DIALOG -> GAMEPLAN pairs for seq2seq fine-tuning.

    Args:
        batch: mapping whose "DIALOG" (model input) and "GAMEPLAN" (target)
            entries are lists of strings, as passed by ``Dataset.map(batched=True)``.
        tokenizer: callable tokenizer that also exposes ``pad_token_id``.
        max_source_length: fixed length dialogs are padded/truncated to.
        max_target_length: fixed length gameplans are padded/truncated to.

    Returns:
        dict holding every field the tokenizer produced for the dialogs, plus
        "labels": the gameplan token ids with each pad id replaced by -100 so
        padding is ignored in the loss.
    """
    def encode(texts, length):
        return tokenizer(texts, padding="max_length", truncation=True, max_length=length)

    dialog_enc = encode(batch["DIALOG"], max_source_length)
    gameplan_enc = encode(batch["GAMEPLAN"], max_target_length)

    features = dict(dialog_enc.items())
    pad_id = tokenizer.pad_token_id
    features["labels"] = [
        [tok if tok != pad_id else -100 for tok in ids]
        for ids in gameplan_enc["input_ids"]
    ]
    return features


def _tokenize_split(split_txt):
    """Tokenize one text split in batches, dropping its original text columns."""
    return split_txt.map(
        lambda batch: batch_tokenize_preprocess(
            batch, tokenizer, encoder_max_length, decoder_max_length
        ),
        batched=True,
        remove_columns=split_txt.column_names,
    )


train_data = _tokenize_split(train_data_txt)
validation_data = _tokenize_split(validation_data_txt)

Map:   0%|          | 0/4258 [00:00<?, ? examples/s]

Map:   0%|          | 0/462 [00:00<?, ? examples/s]

In [None]:
# Adapted from Hugging Face's run_summarization.py example:
# https://github.com/huggingface/transformers/blob/main/examples/pytorch/summarization/run_summarization.py

# Sentence-splitter data for nltk.sent_tokenize (used by postprocess_text).
# Recent nltk releases (3.9+) load the "punkt_tab" resource instead of the
# pickled "punkt" models, so fetch both to work across nltk versions.
nltk.download("punkt", quiet=True)
nltk.download("punkt_tab", quiet=True)

# ROUGE scorer used by compute_metrics. datasets.load_metric only exists in
# datasets < 3.0 (it was removed in 3.0).
metric = datasets.load_metric("rouge")


def postprocess_text(preds, labels):
    """Trim each text and put one sentence per line, for rougeLsum scoring.

    rougeLsum expects a newline after each sentence, so every stripped text is
    split with nltk.sent_tokenize and re-joined with newlines.

    Returns:
        (preds, labels) as two new lists of strings.
    """
    def one_sentence_per_line(text):
        return "\n".join(nltk.sent_tokenize(text.strip()))

    return (
        [one_sentence_per_line(p) for p in preds],
        [one_sentence_per_line(lab) for lab in labels],
    )


def compute_metrics(eval_preds):
    """Compute ROUGE F-measures and mean generated length for a Seq2SeqTrainer eval.

    Args:
        eval_preds: ``(predictions, label_ids)`` as passed by the Trainer; with
            ``predict_with_generate=True`` the predictions are generated token
            ids. ``predictions`` may arrive wrapped in a tuple.

    Returns:
        dict mapping each ROUGE key to its F-measure x 100, plus ``gen_len``
        (mean count of non-pad generated tokens), all rounded to 4 decimals.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    # When generations from different eval batches have different lengths the
    # Trainer pads them to a common width with -100. That is not a valid token
    # id (batch_decode fails on it) and would be counted as a generated token
    # below, so map it to the pad id first, as run_summarization.py does.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    # Labels use -100 for positions ignored by the loss; they can't be decoded either.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    scores = metric.compute(
        predictions=decoded_preds, references=decoded_labels, use_stemmer=True
    )
    result = {}
    for key, value in scores.items():
        # datasets' rouge returns aggregate scores (take the mid F-measure);
        # other rouge implementations return the F-measure as a plain float.
        fmeasure = value.mid.fmeasure if hasattr(value, "mid") else value
        result[key] = fmeasure * 100

    prediction_lens = [
        np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds
    ]
    result["gen_len"] = np.mean(prediction_lens)
    return {k: round(v, 4) for k, v in result.items()}

In [None]:
# Training configuration. Values tagged "# demo" look like demo settings —
# presumably to be tuned for a real run.
training_args = Seq2SeqTrainingArguments(
    output_dir="BART_large_Synthetic_Gameplan",
    num_train_epochs=20,  # demo
    do_train=True,
    do_eval=True,
    per_device_train_batch_size=8,  # demo
    per_device_eval_batch_size=8,
    learning_rate=5e-05,
    warmup_steps=500,
    weight_decay=0.01,
    label_smoothing_factor=0.1,
    # Evaluate with generate() so compute_metrics receives generated token ids.
    predict_with_generate=True,
    # Allow generations as long as the training targets (decoder_max_length).
    # Without this, generate() falls back to the model's generation config,
    # whose max_length may be shorter and would truncate predictions before
    # ROUGE is scored.
    generation_max_length=decoder_max_length,
    logging_dir="logs",
    logging_steps=50,
    save_total_limit=3,
    push_to_hub=True,
)

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

# No eval_dataset is attached here; evaluation is run explicitly later with
# trainer.evaluate(eval_dataset=validation_data).
# NOTE(review): newer transformers releases deprecate Trainer(tokenizer=...) in
# favour of processing_class=... — confirm against the installed version.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_data,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [None]:
# Weights & Biases run setup; skipped entirely when WANDB_INTEGRATION is False.
if WANDB_INTEGRATION:
    run_config = {
        "per_device_train_batch_size": training_args.per_device_train_batch_size,
        "learning_rate": training_args.learning_rate,
        "dataset": "ET",
    }
    wandb_run = wandb.init(project="TEACh", config=run_config)

    # Name the run after the local wall-clock time (HHMMSS) once init returns.
    started_at = datetime.now().strftime("%H%M%S")
    wandb_run.name = "run_ET_" + started_at

In [None]:
# Fine-tune; the returned training summary is shown as the cell output.
train_result = trainer.train()
train_result

In [None]:
# Score the fine-tuned model on the validation split (metrics from compute_metrics).
eval_metrics = trainer.evaluate(eval_dataset=validation_data)
eval_metrics