# Setup

In [2]:
import os
import torch
from torch import cuda
from torch.utils.data import TensorDataset, DataLoader
import evaluate
from transformers import Seq2SeqTrainer, DataCollatorForSeq2Seq, AutoModelForSeq2SeqLM, AutoTokenizer, Seq2SeqTrainingArguments, pipeline, AdamW, get_scheduler
from argparse import ArgumentParser
import numpy as np
from datasets import load_dataset, DownloadMode
from tqdm.auto import tqdm

## Model + tokenizer

In [3]:
# Checkpoint shared by the tokenizer and the seq2seq model.
model_name = 'facebook/bart-base'
# Prefer the GPU, but fall back to CPU so the notebook still runs without CUDA.
device = 'cuda' if cuda.is_available() else 'cpu'
tokenizer = AutoTokenizer.from_pretrained(model_name)
# NOTE: the pad token is intentionally left at the checkpoint's own value.
# The BART tokenizer already ships a padding token; overriding it with
# "[PAD]" (a string that is not part of the BART vocabulary) makes the
# tokenizer's pad id disagree with the padding id configured on the model.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
model.to(device)
print("Model + tokenizer")

Model + tokenizer


## Train dataset

In [4]:
# Folder holding the per-relation JSON files, and the relations to train on.
dataset_dir = 'modified_dataset/'
relations = ['Physical', 'Event', 'Intent', 'Reaction']

# One raw dataset per relation, loaded from "<relation> train.json".
train_dict = {}
for relation in relations:
    train_dict[relation] = load_dataset(
        'json',
        data_files={'train': f'{dataset_dir}{relation} train.json'},
        download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS,
    )

# Upper bound (in tokens) applied when tokenizing with truncation.
max_seq_length = 64


def preprocess_function(examples):
    """Tokenize a batch of examples for seq2seq training.

    The `head` column is passed as the tokenizer input and the `tail` column
    as `text_target`; both are truncated to `max_seq_length`.

    Returns whatever the tokenizer call returns for that batch.
    """
    return tokenizer(
        examples['head'],
        text_target=examples['tail'],
        truncation=True,
        max_length=max_seq_length,
    )


# Tokenized counterpart of `train_dict`; the raw text columns are dropped.
train_tok_dict = {}
for relation in relations:
    train_tok_dict[relation] = train_dict[relation].map(
        preprocess_function,
        remove_columns=['head', 'tail'],
        load_from_cache_file=True,
        batched=True,
    )

Using custom data configuration default-93aa8ff448102f0c
Found cached dataset json (/root/.cache/huggingface/datasets/json/default-93aa8ff448102f0c/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)


  0%|          | 0/1 [00:00<?, ?it/s]

Using custom data configuration default-2513e8f6c6b4c2f7
Found cached dataset json (/root/.cache/huggingface/datasets/json/default-2513e8f6c6b4c2f7/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)


  0%|          | 0/1 [00:00<?, ?it/s]

Using custom data configuration default-4dc7bdf93ec46776
Found cached dataset json (/root/.cache/huggingface/datasets/json/default-4dc7bdf93ec46776/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)


  0%|          | 0/1 [00:00<?, ?it/s]

Using custom data configuration default-7ba272b55de9335d
Found cached dataset json (/root/.cache/huggingface/datasets/json/default-7ba272b55de9335d/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)


  0%|          | 0/1 [00:00<?, ?it/s]

Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-93aa8ff448102f0c/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab/cache-2f9f01ec0bd398ed.arrow


  0%|          | 0/147 [00:00<?, ?ba/s]

Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-4dc7bdf93ec46776/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab/cache-fe5a6b55ee6352c5.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-7ba272b55de9335d/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab/cache-c0a90a6ff09b3f00.arrow


## Test dataset

In [5]:
# Same data folder / relation list / length limit as used for the train split.
dataset_dir = 'modified_dataset/'
relations = ['Physical', 'Event', 'Intent', 'Reaction']
max_seq_length = 64

# One raw test dataset per relation, loaded from "<relation> test.json".
test_dict = {}
for relation in relations:
    test_dict[relation] = load_dataset(
        'json',
        data_files={'test': f'{dataset_dir}{relation} test.json'},
        download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS,
    )


def preprocess_function(examples):
    """Collapse each example's `tail` entries into one tab-separated string.

    The batch is modified in place and returned; later cells recover the
    individual references again with `split('\t')`.
    """
    examples['tail'] = list(map('\t'.join, examples['tail']))
    return examples


# Replace every raw test dataset with its tab-joined version.
for relation in relations:
    test_dict[relation] = test_dict[relation].map(
        preprocess_function,
        load_from_cache_file=True,
        batched=True,
    )

Using custom data configuration default-7e6b9691b74736f5
Found cached dataset json (/root/.cache/huggingface/datasets/json/default-7e6b9691b74736f5/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)


  0%|          | 0/1 [00:00<?, ?it/s]

Using custom data configuration default-e48f3d7247e9024e
Found cached dataset json (/root/.cache/huggingface/datasets/json/default-e48f3d7247e9024e/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)


  0%|          | 0/1 [00:00<?, ?it/s]

Using custom data configuration default-9570b0225009fd85
Found cached dataset json (/root/.cache/huggingface/datasets/json/default-9570b0225009fd85/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)


  0%|          | 0/1 [00:00<?, ?it/s]

Using custom data configuration default-2afd8348b3f2243d
Found cached dataset json (/root/.cache/huggingface/datasets/json/default-2afd8348b3f2243d/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)


  0%|          | 0/1 [00:00<?, ?it/s]

Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-7e6b9691b74736f5/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab/cache-b3a10f3c6a89ee85.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-e48f3d7247e9024e/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab/cache-236e9b251d5bbe77.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-9570b0225009fd85/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab/cache-1e0364c246fb4b17.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-2afd8348b3f2243d/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab/cache-a2aadcc021fcb54f.arrow


## Generation example

In [6]:
# Sanity check: for every relation, generate from the first test example only
# and print the prompt, the model output and the reference tails.
for relation in relations:
    test_dataloader = DataLoader(test_dict[relation]['test'], batch_size=1)
    for batch in test_dataloader:
        heads = batch['head']
        input_ids = tokenizer(heads, padding=True, return_tensors="pt").to(device)
        generations = model.generate(**input_ids)
        decoded = tokenizer.batch_decode(generations, skip_special_tokens=True)
        # Undo the tab-joining applied to `tail` when the test set was prepared.
        references = [tail.split('\t') for tail in batch['tail']]

        print(f"{relation}")
        print(heads)
        print(decoded)
        print('labels')
        print(references)
        # Only the first batch is needed for the demonstration.
        break



Physical
['a long rope can be used to']
['a long rope can be used to']
labels
[['make lots of knots', 'secure the rock', 'play tug of war', 'pull a car', 'hang yourself', 'tow a boat in', 'climb a mountain', 'save someone stuck in a tunnel', 'play a game of tug of war', 'play tag of war', 'create holder']]
Event
["PersonX plays PersonY's favorite songs. Before that,"]
["PersonX plays PersonY's favorite songs. Before that,"]
labels
[["PersonX wished their friend wasn't gone"]]
Intent
['PersonX runs for class president. PersonX did this to']
['PersonX runs for class president. PersonX did this to']
labels
[[' make a difference in their classroom', ' feel important']]
Reaction
["PersonX sees PersonY's girlfriend. The effect on PersonX will be that PersonX"]
["PersonX sees PersonY's girlfriend. The effect on PersonX will be that Person"]
labels
[['observes her', 'takes note of her', 'gets called', 'gets stoped', 'waves sat her', 'stops to talk to her']]


# Incremental Training + Evaluation

In [None]:
# Incremental training: the SAME model object is fine-tuned on one relation
# after another. After each stage the checkpoint is saved and the model is
# scored with BLEU on the test split of every relation.
metric = evaluate.load('bleu')

# Neither the collator nor the epoch count depends on the relation, so they are
# created once instead of on every loop iteration.
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model
)
num_train_epochs = 3

for train_relation in relations:
    output_dir = f'incremental/{train_relation}'
    os.makedirs(output_dir, exist_ok=True)

    train_dataloader = DataLoader(
        train_tok_dict[train_relation]['train'],
        shuffle=True,
        collate_fn=data_collator,
        batch_size=32,
    )

    # A fresh optimizer and linear schedule are built for every relation;
    # only the model weights carry over from the previous stage.
    optimizer = AdamW(model.parameters(), lr=2e-5)

    num_update_steps_per_epoch = len(train_dataloader)
    num_training_steps = num_train_epochs * num_update_steps_per_epoch

    lr_scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=num_training_steps,
    )

    model.train()
    # The context manager closes the bar when the stage ends (or fails), and
    # the description tells the four otherwise identical bars apart.
    with tqdm(total=num_training_steps, desc=f'train {train_relation}') as progress_bar:
        for epoch in range(num_train_epochs):
            for batch in train_dataloader:
                batch = {k: v.to(device) for k, v in batch.items()}
                loss = model(**batch).loss
                loss.backward()

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1)

    # Persist the weights BEFORE evaluating, so a failure during generation or
    # scoring cannot throw away the training that has just finished.
    model.save_pretrained(output_dir)

    model.eval()  # put in testing mode (dropout modules are deactivated)
    for test_relation in relations:
        test_dataloader = DataLoader(
            test_dict[test_relation]['test'],
            batch_size=64,
        )
        for batch in test_dataloader:
            input_ids = tokenizer(batch['head'], padding=True, return_tensors="pt").to(device)
            with torch.no_grad():
                generations = model.generate(**input_ids)
            decoded_gens = tokenizer.batch_decode(generations, skip_special_tokens=True)
            # `tail` was stored as a tab-joined string; split it back into references.
            labels = [s.split('\t') for s in batch['tail']]
            metric.add_batch(predictions=decoded_gens, references=labels)
        results = metric.compute()

        # Append mode keeps the scores of every test relation in one file; the
        # `with` block guarantees the handle is closed even if the write fails.
        with open(f'{output_dir}/results.txt', "a") as f:
            f.write(f'{test_relation} \n {results} \n')

    
    



  0%|          | 0/16755 [00:00<?, ?it/s]

You're using a BartTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


  0%|          | 0/13770 [00:00<?, ?it/s]

  0%|          | 0/28092 [00:00<?, ?it/s]

  0%|          | 0/31053 [00:00<?, ?it/s]

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



## Evaluation 

In [12]:
# Re-score every saved incremental checkpoint on every test relation and append
# the results to a single summary file.
metric = evaluate.load('bleu')

# Device and tokenizer are identical for all checkpoints, so they are set up
# once instead of being rebuilt on each loop iteration.
device = 'cuda' if cuda.is_available() else 'cpu'
# The tokenizer keeps the checkpoint's own padding token: "[PAD]" is not part
# of the BART vocabulary, so assigning it as pad token would make the pad id
# disagree with the padding id configured on the model.
tokenizer = AutoTokenizer.from_pretrained('facebook/bart-base')

for train_relation in relations:
    # Checkpoint directory written by the incremental training cell.
    model_name = f'incremental/{train_relation}'
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    model.to(device)

    model.eval()  # put in testing mode (dropout modules are deactivated)
    for test_relation in relations:
        test_dataloader = DataLoader(
            test_dict[test_relation]['test'],
            batch_size=64,
        )
        for batch in test_dataloader:
            input_ids = tokenizer(batch['head'], padding=True, return_tensors="pt").to(device)
            with torch.no_grad():
                generations = model.generate(**input_ids)
            decoded_gens = tokenizer.batch_decode(generations, skip_special_tokens=True)
            # `tail` was stored as a tab-joined string; split it back into references.
            labels = [s.split('\t') for s in batch['tail']]
            metric.add_batch(predictions=decoded_gens, references=labels)

        # max_order=2 limits the BLEU computation to n-grams of order <= 2.
        results = metric.compute(max_order=2)
        # Derived score: brevity penalty times the first entry of `precisions`.
        # NOTE(review): the key is spelled 'blue-1' (sic). It is kept unchanged
        # because results.txt is opened in append mode and may already contain
        # entries written under this key.
        results['blue-1'] = results['brevity_penalty'] * results['precisions'][0]

        # `with` guarantees the handle is closed even if the write fails.
        with open('incremental/results.txt', "a") as f:
            f.write(f'{train_relation} test on {test_relation} \n {results} \n')
    