In [1]:
import os
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
import importlib
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
PRETRAINED_MODEL_NAME = "t5-small"

# Worth trying the below one..
# PRETRAINED_MODEL_NAME = "MaRiOrOsSi/t5-base-finetuned-question-answering"

BATCH_SIZE = 64

DEVICE = 'cuda:3'
NUM_TRAIN_EPOCHS = 15
MAX_INPUT_LENGTH = 256

CKPT_SAVE_PATH = "t5_finetuned_faithdial_edit_ep15_khorr_seqlen256"

### Data Loader

In [3]:
from datasets import load_dataset
# from torch.utils.data import Dataset, DataLoader

In [4]:
dataset = load_dataset("McGill-NLP/FaithDial")
dataset['train']

No config specified, defaulting to: faith_dial/plain_text
Found cached dataset faith_dial (/home/csgrad/jayashok/.cache/huggingface/datasets/McGill-NLP___faith_dial/plain_text/1.0.0/70568c8ab3bbc83b603bce58fa593ab27e7f0d0cde51034e1c2073ff3e14189a)
100%|██████████| 7/7 [00:00<00:00, 777.11it/s]


Dataset({
    features: ['dialog_idx', 'response', 'original_response', 'history', 'knowledge', 'BEGIN', 'VRM'],
    num_rows: 18357
})

In [5]:
tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)

#### Using Custom Pipeline for custom training loop

In [6]:
from torch.utils.data import DataLoader

In [7]:
import CustomDataset
from CustomDataset import Dataset

In [8]:
train_set = Dataset(dataset['train'], tokenizer, CustomDataset.DatasetMap.faithdial_edit)

100%|██████████| 18357/18357 [00:00<00:00, 20117.13it/s]


In [9]:
validation_set = Dataset(dataset['validation'], tokenizer, CustomDataset.DatasetMap.faithdial_edit)

100%|██████████| 3417/3417 [00:00<00:00, 19543.77it/s]


In [10]:
import random
print(random.choice(train_set))
print(random.choice(validation_set))

('is an American television game show created by Merv Griffin.', 'My favorite TV show is JEOPARDY!.  Please tell me a little about my favorite show, and be sure to phrase your answers in the form of a question.|Oh no! This will be challenging. Okay. American tv show that Merv Griffen created?', 'Okay, did you know Jeopardy is the show that Merv Griffin created?')
("The Horseshoe Falls lies on the border of the United States and Canada with the American Falls entirely on the United States' side, separated by Goat Island.", 'Ok... shocked over here! Had no idea America had a Falls. They are all connected I assume?|Firstly,Horseshoe Falls lies on the border of the United States and Canada with the American Falls entirely on the United States', 'I can confirm if they are connected or not. What I know is that the Horseshoe Falls is located on the border of Canada and the US, while the American Falls lies entirely on the US')


In [11]:
my_trainset_dataloader = DataLoader(train_set, batch_size=BATCH_SIZE,
                                    num_workers=16, collate_fn=lambda data: train_set.pack_minibatch(data))
my_validation_dataloader = DataLoader(validation_set, batch_size=BATCH_SIZE,
                                        num_workers=16, collate_fn=lambda data: validation_set.pack_minibatch(data))

### Model Initialization

In [12]:
# Initialize the T5 model
model = AutoModelForSeq2SeqLM.from_pretrained(PRETRAINED_MODEL_NAME)

# Resize the model's embeddings to accommodate the new tokens (No New Tokens used yet)
# model.resize_token_embeddings(len(tokenizer))


### Training

In [13]:
import torch
from tqdm import tqdm

In [14]:
model.train()
model.to(DEVICE)

T5ForConditionalGeneration(
  (shared): Embedding(32128, 512)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 512)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=512, bias=False)
              (k): Linear(in_features=512, out_features=512, bias=False)
              (v): Linear(in_features=512, out_features=512, bias=False)
              (o): Linear(in_features=512, out_features=512, bias=False)
              (relative_attention_bias): Embedding(32, 8)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=512, out_features=2048, bias=False)
              (wo): Linear(in_features=2048, out_features=512, bias=False)
              (dropout): Drop

In [15]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

In [18]:
def process_data(tokenizer, questions, contexts, answers, max_input_length, device):
    def _apply(inp_tuple):
        tup_0, knowledge = inp_tuple
        prompt, resp = tup_0.split("|")
        return f"rewrite: {resp}  question:{prompt} context: {knowledge}"

    inputs = list(map(_apply, zip(questions,contexts)))
    # print("-------------------------------------")
    # print(inputs)
    # return
    encoded_inputs = tokenizer(
                            inputs,
                            padding="longest",
                            max_length=max_input_length,
                            truncation=True,
                            return_tensors="pt",
                        )
    encoded_targets = tokenizer(
                            answers,
                            padding="longest",
                            max_length=max_input_length,
                            truncation=True,
                            return_tensors="pt",
                        )
    
    input_ids, attention_mask = encoded_inputs.input_ids, encoded_inputs.attention_mask
    encoded_targets = encoded_targets.input_ids

    # replace padding target token id's of the labels by -100, crossEntropy skip target label == -100
    encoded_targets[encoded_targets == tokenizer.pad_token_id] = -100

    encoded_inputs = input_ids.to(device)
    encoded_targets = encoded_targets.to(device)
    attention_mask = attention_mask.to(device)

    return encoded_inputs, attention_mask, encoded_targets

In [19]:
for epoch in range(NUM_TRAIN_EPOCHS):
    ### Training loop
    epoch_train_loss = 0.0
    model.train()
    for contexts,questions,answers in tqdm(my_trainset_dataloader):
        # print("*"*20)
        # print('Answers:')
        # print(answers)

        encoded_inputs, attention_mask, encoded_targets = process_data(tokenizer, questions, contexts, answers, max_input_length=MAX_INPUT_LENGTH, device=DEVICE)
        # break
        optimizer.zero_grad()
        outputs = model(input_ids=encoded_inputs, attention_mask=attention_mask, labels=encoded_targets)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        epoch_train_loss += loss.item() * encoded_inputs.shape[0]

    ## Validation loop
    model.eval()
    epoch_val_loss = 0.0
    with torch.no_grad():
        # model_predictions_encoded = []
        # target_encoded = []
        for contexts, questions, answers in tqdm(my_validation_dataloader):
            encoded_inputs, attention_mask, encoded_targets = process_data(tokenizer, questions, contexts, answers, max_input_length=MAX_INPUT_LENGTH, device=DEVICE)
            outputs = model(input_ids=encoded_inputs, attention_mask=attention_mask, labels=encoded_targets)
            loss = outputs.loss
            epoch_val_loss += loss.item() * encoded_inputs.shape[0]

    print("*"*20)
    print(f"epoch={epoch + 1}/{NUM_TRAIN_EPOCHS}")
    print(f"\t Train loss = {epoch_train_loss/len(train_set):.4f}")
    print(f"\t Val loss = {epoch_val_loss/len(validation_set):.4f}")

100%|██████████| 287/287 [01:11<00:00,  4.02it/s]
100%|██████████| 54/54 [00:05<00:00,  9.71it/s]


********************
epoch=1/15
	 Train loss = 2.1419
	 Val loss = 1.9936


100%|██████████| 287/287 [01:12<00:00,  3.96it/s]
100%|██████████| 54/54 [00:05<00:00,  9.64it/s]


********************
epoch=2/15
	 Train loss = 1.9346
	 Val loss = 1.9205


100%|██████████| 287/287 [01:12<00:00,  3.96it/s]
100%|██████████| 54/54 [00:05<00:00,  9.69it/s]


********************
epoch=3/15
	 Train loss = 1.8528
	 Val loss = 1.8822


100%|██████████| 287/287 [01:12<00:00,  3.95it/s]
100%|██████████| 54/54 [00:05<00:00,  9.47it/s]


********************
epoch=4/15
	 Train loss = 1.7942
	 Val loss = 1.8557


100%|██████████| 287/287 [01:12<00:00,  3.97it/s]
100%|██████████| 54/54 [00:05<00:00,  9.53it/s]


********************
epoch=5/15
	 Train loss = 1.7466
	 Val loss = 1.8413


100%|██████████| 287/287 [01:12<00:00,  3.95it/s]
100%|██████████| 54/54 [00:05<00:00,  9.74it/s]


********************
epoch=6/15
	 Train loss = 1.7072
	 Val loss = 1.8307


100%|██████████| 287/287 [01:12<00:00,  3.94it/s]
100%|██████████| 54/54 [00:05<00:00,  9.81it/s]


********************
epoch=7/15
	 Train loss = 1.6703
	 Val loss = 1.8185


100%|██████████| 287/287 [01:12<00:00,  3.94it/s]
100%|██████████| 54/54 [00:05<00:00,  9.54it/s]


********************
epoch=8/15
	 Train loss = 1.6383
	 Val loss = 1.8146


100%|██████████| 287/287 [01:12<00:00,  3.94it/s]
100%|██████████| 54/54 [00:05<00:00,  9.37it/s]


********************
epoch=9/15
	 Train loss = 1.6088
	 Val loss = 1.8124


100%|██████████| 287/287 [01:12<00:00,  3.94it/s]
100%|██████████| 54/54 [00:05<00:00,  9.57it/s]


********************
epoch=10/15
	 Train loss = 1.5825
	 Val loss = 1.8106


100%|██████████| 287/287 [01:12<00:00,  3.95it/s]
100%|██████████| 54/54 [00:05<00:00,  9.79it/s]


********************
epoch=11/15
	 Train loss = 1.5562
	 Val loss = 1.8096


 15%|█▌        | 44/287 [00:12<01:06,  3.66it/s]


KeyboardInterrupt: 

In [None]:
# interrupt and change ckpt path...
# see if this model is indeed better??

In [13]:
CKPT_SAVE_PATH = 't5_finetuned_faithdial_edit_ep10_khorr_seqlen256'

In [22]:
model.save_pretrained(CKPT_SAVE_PATH)

### Inference

In [19]:
from tqdm import tqdm

In [14]:
model.from_pretrained(CKPT_SAVE_PATH)

T5ForConditionalGeneration(
  (shared): Embedding(32128, 512)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 512)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=512, bias=False)
              (k): Linear(in_features=512, out_features=512, bias=False)
              (v): Linear(in_features=512, out_features=512, bias=False)
              (o): Linear(in_features=512, out_features=512, bias=False)
              (relative_attention_bias): Embedding(32, 8)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=512, out_features=2048, bias=False)
              (wo): Linear(in_features=2048, out_features=512, bias=False)
              (dropout): Drop

In [32]:
test_sample = dataset['test'][50]

In [33]:
test_sample

{'dialog_idx': 11,
 'response': "I don't quite know what you mean but cheerleading is about participants cheering for their team as encouragement.",
 'original_response': 'So, I see you are excited about cheerleading, it all about team and ecouragement.',
 'history': ['2, 4 , 6, 8 who do you appreciate??'],
 'knowledge': "Cheerleading is an activity wherein the participants (referred to as ''cheerleaders'') cheer for their team as a form of encouragement.",
 'BEGIN': ['Hallucination', 'Entailment'],
 'VRM': ['Disclosure', 'Ack.']}

In [17]:
def infer(model, prompt, knowledge, org_resp, max_input_length, max_output_length, device):
    # question = "What is 42?"
    # context = "42 is the answer to life, the universe and everything"
    input = f"rewrite: {org_resp}  question:{prompt} context: {knowledge}"

    encoded_input = tokenizer([input],
                                return_tensors='pt',
                                max_length=max_input_length,
                                truncation=True)
    inp = encoded_input.input_ids.to(device)
    attn_mask = encoded_input.attention_mask.to(device)
    output = model.generate(input_ids = inp,
                                attention_mask = attn_mask, max_length=max_output_length)
    output = tokenizer.decode(output[0], skip_special_tokens=True)
    # print(output)
    return output

In [35]:
res = infer(model, prompt=test_sample['history'][-1], knowledge=test_sample['knowledge'],
            org_resp=test_sample['original_response'],
             max_input_length=MAX_INPUT_LENGTH, max_output_length=100, device=DEVICE)

In [36]:
print('p:', test_sample['history'][-1])
print('k:', test_sample['knowledge'])
print('res:', res)
print('org:', test_sample['original_response'])
print('gt:', test_sample['response'])

p: 2, 4 , 6, 8 who do you appreciate??
k: Cheerleading is an activity wherein the participants (referred to as ''cheerleaders'') cheer for their team as a form of encouragement.
res: Well, Cheerleading is about cheerleading, which is a group where the participants cheer for their team as a form of encouragement.
org: So, I see you are excited about cheerleading, it all about team and ecouragement.
gt: I don't quite know what you mean but cheerleading is about participants cheering for their team as encouragement.


In [31]:
# Good Editing
print('p:', test_sample['history'][-1])
print('k:', test_sample['knowledge'])
print('res:', res)
print('org:', test_sample['original_response'])
print('gt:', test_sample['response'])

p: My ex girlfriend broke my heart. What are some ways to deal with heartbreak?
k: The concept is cross-cultural, often cited with reference to a desired or lost lover, and dates back at least 3,000 years.
res: Well, the concept of heartbreak dates back at least 3,000 years.
org: Best medicine is to just take up a hobby and get involved. But the concept of heartbreak dates back at least 3,000 years.
gt: I'm unable to solve that but you may be curious to know that it has 3000 years of history.


In [27]:
# No-Editing
print('p:', test_sample['history'][-1])
print('k:', test_sample['knowledge'])
print('res:', res)
print('org:', test_sample['original_response'])
print('gt:', test_sample['response'])

p: Oh interesting, I'm not really familiar with Coco Chanel, can you tell me more about it?
k: She was the founder and namesake of the Chanel brand.
res: Yes of course, she is the founder and nameake of the Chanel brand!
org: Yes of course she is the founder and namesake of the chanel brand!
gt: Yes, of course, she is the founder and namesake of the Chanel brand!


In [39]:
# Retrieve results for all samples

In [38]:
all_results = []
for test_sample in tqdm(dataset['test'], total=len(dataset['test'])):
    input_ = {'knowledge': test_sample['knowledge'],
            'prompt': test_sample['history'][-1],
            'org_resp': test_sample['original_response']}
    
    resp = infer(model, max_input_length=MAX_INPUT_LENGTH, max_output_length=100, device=DEVICE, **input_)
    all_results.append([input_['knowledge'], input_['prompt'], resp,
                         test_sample['original_response'], test_sample['response']])

# clean_all-results
all_results = [[cell if cell is not None else '' for cell in row] for row in all_results]
data_dump = "\n".join(['|'.join(res) for res in all_results])

with open("T5_edit_FaithDial_khorr.txt", 'w') as f:
    f.write(data_dump)

100%|██████████| 3539/3539 [08:12<00:00,  7.19it/s]


In [39]:
all_results

[["Dylan's Candy Bar is a chain of boutique candy shops and candy supplier currently located in New York City; East Hampton, New York; Los Angeles, Chicago and Miami Beach, as well as in wholesale venues around the globe.",
  "I love candy, what's a good brand?",
  "I don't have any preferences. I know that there are several stores that are located in New York City, and in Los Angeles, New York, and in Miami Beach.",
  "Dylan's Candy Bar is a great brand of candy",
  "I don't know how good they are, but Dylan's Candy Bar has a chain of candy shops in various cities."],
 ["Dylan's Candy Bar is a chain of boutique candy shops and candy supplier currently located in New York City; East Hampton, New York; Los Angeles, Chicago and Miami Beach, as well as in wholesale venues around the globe.",
  'Oh, they do? What kind of candy do they sell?',
  "I don't know, but they are a candy supplier.",
  "Dylan's Candy Bar is a candy supplier",
  "I don't know, really, but they also are a supplier of

In [None]:
# take the responses from gen_WoW

In [21]:
input_file = "T5_gen_WoW.txt"

with open(input_file, 'r') as f:
    input_data = f.read().split("\n")

# pipe delimited knowledge, history and pred_response
input_data = [x.split("|") for x in input_data]

all_results_pipeline = []
for fields in tqdm(input_data):
    if len(fields) == 3:
        knowledge, history, pred_response = fields
    else:
        knowledge, history, pred_response, org_resp, gt_resp = fields
    # print(len(knowledge), len(history), len(pred_response))
    
    input_ = {'knowledge': knowledge,
        'prompt': history,
        'org_resp': pred_response}

    resp = infer(model, max_input_length=MAX_INPUT_LENGTH, max_output_length=100, device=DEVICE, **input_)
    all_results_pipeline.append([input_['knowledge'], input_['prompt'], resp,
                         input_['org_resp']])

# clean_all-results
all_results_pipeline = [[cell if cell is not None else '' for cell in row] for row in all_results_pipeline]
data_dump = "\n".join(['|'.join(res) for res in all_results_pipeline])

with open("T5_edit_FaithDial_khorr.txt", 'w') as f:
    f.write(data_dump)

100%|██████████| 3539/3539 [03:03<00:00, 19.29it/s]
