# Setup

In [1]:
import os
import torch
from torch import cuda
from torch.utils.data import TensorDataset, DataLoader
import evaluate
from transformers import Seq2SeqTrainer, DataCollatorForSeq2Seq, AutoModelForSeq2SeqLM, AutoTokenizer, Seq2SeqTrainingArguments, pipeline, AdamW, get_scheduler
from argparse import ArgumentParser
import numpy as np
from datasets import load_dataset, DownloadMode
from tqdm.auto import tqdm
from script.rec_adam import RecAdam

## Model + tokenizer

In [2]:
# Load the pretrained BART checkpoint and its tokenizer, and move the model to the
# compute device used by every later cell.
model_name = 'facebook/bart-base'
# Fall back to CPU so the notebook still runs on a machine without a GPU.
device = 'cuda' if cuda.is_available() else 'cpu'
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Only install a pad token when the checkpoint does not define one.
# NOTE(review): the BART tokenizer is documented to ship its own pad token; forcing
# "[PAD]" (a BERT-style token) on top of it would make the tokenizer's padding token
# disagree with the one the model config expects -- confirm against the tokenizer docs.
if tokenizer.pad_token is None:
    tokenizer.pad_token = "[PAD]"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
model.to(device)
print("Model + tokenizer")

Model + tokenizer


## Train dataset

In [3]:
# Folder holding the per-relation JSON files, and the relation order used by the
# rest of the notebook.
dataset_dir = 'modified_dataset/'
relations = ['Physical', 'Event', 'Intent', 'Reaction']

# One loaded dataset (single 'train' split) per relation, keyed by relation name.
train_dict = {
    relation: load_dataset(
        'json',
        data_files={'train': f'{dataset_dir}{relation} train.json'},
        download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS,
    )
    for relation in relations
}

max_seq_length = 64


def preprocess_function(examples):
    """Tokenize a batch of training examples.

    The `head` column is passed as the input text and the `tail` column as
    `text_target`, with truncation enabled and `max_length=max_seq_length`.
    Returns whatever the notebook-level `tokenizer` returns for that call.
    """
    return tokenizer(
        examples['head'],
        text_target=examples['tail'],
        truncation=True,
        max_length=max_seq_length,
    )

# Tokenized counterpart of `train_dict`: every relation's dataset is mapped through
# `preprocess_function` in batches and the raw text columns are dropped afterwards.
train_tok_dict = {
    relation: train_dict[relation].map(
        preprocess_function,
        batched=True,
        load_from_cache_file=True,
        remove_columns=['head', 'tail'],
    )
    for relation in relations
}

Using custom data configuration default-a65f559f25e8bb21
Found cached dataset json (/root/.cache/huggingface/datasets/json/default-a65f559f25e8bb21/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)


  0%|          | 0/1 [00:00<?, ?it/s]

Using custom data configuration default-ebda442e4b7407f2
Found cached dataset json (/root/.cache/huggingface/datasets/json/default-ebda442e4b7407f2/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)


  0%|          | 0/1 [00:00<?, ?it/s]

Using custom data configuration default-65647a2d2c9d2c86
Found cached dataset json (/root/.cache/huggingface/datasets/json/default-65647a2d2c9d2c86/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)


  0%|          | 0/1 [00:00<?, ?it/s]

Using custom data configuration default-c12fb152fedc44d5
Found cached dataset json (/root/.cache/huggingface/datasets/json/default-c12fb152fedc44d5/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)


  0%|          | 0/1 [00:00<?, ?it/s]

Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-a65f559f25e8bb21/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab/cache-146e8d6bd159c98c.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-ebda442e4b7407f2/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab/cache-3c93bc7850aa0341.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-65647a2d2c9d2c86/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab/cache-f79984373479e4aa.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-c12fb152fedc44d5/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab/cache-18049fc3b76763fd.arrow


## Test dataset

In [4]:
# Same folder and relation order as the training cell, repeated here so this cell
# can be run on its own.
dataset_dir = 'modified_dataset/'
relations = ['Physical', 'Event', 'Intent', 'Reaction']

# One loaded dataset (single 'test' split) per relation, keyed by relation name.
test_dict = {
    relation: load_dataset(
        'json',
        data_files={'test': f'{dataset_dir}{relation} test.json'},
        download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS,
    )
    for relation in relations
}

def join_tail_references(examples):
    """Collapse each example's list of reference tails into one tab-separated string.

    Named distinctly from the training-side `preprocess_function` so that running
    this cell does not silently replace the tokenizing function of the same name.
    The generation/evaluation code splits the string on tabs again to recover the
    individual references.
    # presumably done so the default DataLoader collation sees one string per
    # example instead of variable-length lists -- confirm.
    """
    return {'tail': ['\t'.join(tails) for tails in examples['tail']]}


# Rewrite the `tail` column of every relation's test data; all other columns are kept.
for relation in relations:
    test_dict[relation] = test_dict[relation].map(
        join_tail_references,
        batched=True,
        load_from_cache_file=True
    )

Using custom data configuration default-1036bc633c4cf542
Found cached dataset json (/root/.cache/huggingface/datasets/json/default-1036bc633c4cf542/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)


  0%|          | 0/1 [00:00<?, ?it/s]

Using custom data configuration default-4a7ea5f69cab20da
Found cached dataset json (/root/.cache/huggingface/datasets/json/default-4a7ea5f69cab20da/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)


  0%|          | 0/1 [00:00<?, ?it/s]

Using custom data configuration default-a44d390d889f4596
Found cached dataset json (/root/.cache/huggingface/datasets/json/default-a44d390d889f4596/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)


  0%|          | 0/1 [00:00<?, ?it/s]

Using custom data configuration default-67e1f0b46e31265d
Found cached dataset json (/root/.cache/huggingface/datasets/json/default-67e1f0b46e31265d/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)


  0%|          | 0/1 [00:00<?, ?it/s]

Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-1036bc633c4cf542/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab/cache-43cd2fa2060b6560.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-4a7ea5f69cab20da/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab/cache-c6e3e0187d4f2d58.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-a44d390d889f4596/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab/cache-b7a860e5449eb85a.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-67e1f0b46e31265d/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab/cache-9ea25f7820045350.arrow


## Generation example

In [5]:
# Sanity check: for each relation, run generation on the first test example with the
# current model and print the prompt, the decoded output and the reference tails.
for relation in relations:
    single_example_loader = DataLoader(test_dict[relation]['test'], batch_size=1)
    for sample in single_example_loader:
        heads = sample['head']
        encoded = tokenizer(heads, return_tensors="pt", padding=True).to(device)
        output_ids = model.generate(**encoded)
        decoded = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        # The tails were stored tab-joined; split them back into separate references.
        references = [tail.split('\t') for tail in sample['tail']]

        print(f"{relation}")
        print(heads)
        print(decoded)
        print('labels')
        print(references)
        break  # only the first example of each relation is shown



Physical
['You are likely to find a construction in a']
['You are likely to find a construction in a']
labels
[['roadblock']]
Event
['PersonX wants to hurt PersonY. Before that,']
['PersonX wants to hurt PersonY. Before that,']
labels
[['PersonX gets punched by PersonY']]
Intent
["PersonX preaches god 's ___. PersonX did this to"]
["PersonX preaches god's ___. PersonX did this to"]
labels
[['peace']]
Reaction
["PersonX sees PersonY in PersonX's office. PersonX will be"]
["PersonX sees PersonY in PersonX's office. PersonX will be"]
labels
[['surprised', 'courteous', 'interested']]


# Incremental Training + Evaluation

In [None]:
# Fine-tune the same model on each relation in turn (weights carry over from one
# relation to the next) and, after every stage, score it with BLEU on the test split
# of every relation. Results are appended to `<output_dir>results.txt` and a
# checkpoint is saved per stage.
# NOTE(review): `model` is not reloaded here, so this cell continues from whatever
# weights the kernel currently holds.
metric = evaluate.load('bleu')
USE_REC_ADAM = True
output_dir = 'rec_adam/' if USE_REC_ADAM else 'incremental/'

# The collator only depends on the tokenizer and the model, so build it once.
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model
)

for train_relation in relations:
    os.makedirs(f'{output_dir}{train_relation}', exist_ok=True)

    train_dataloader = DataLoader(
        train_tok_dict[train_relation]['train'],
        shuffle=True,
        collate_fn=data_collator,
        batch_size=64,
    )

    if USE_REC_ADAM and train_relation != relations[0]:
        # Snapshot the weights reached on the previous relation. Passing the live
        # parameters themselves would make any `param - pretrain_param` term
        # identically zero, because both sides would be the same tensors.
        # NOTE(review): assumes RecAdam (script.rec_adam, not shown here) pulls the
        # weights towards `pretrain_params` -- confirm against its implementation.
        pretrain_params = [p.detach().clone() for p in model.parameters()]
        optimizer = RecAdam(model.parameters(), lr=1e-3, pretrain_params=pretrain_params)
    else:
        optimizer = AdamW(model.parameters(), lr=2e-5)

    num_train_epochs = 3
    num_update_steps_per_epoch = len(train_dataloader)
    num_training_steps = num_train_epochs * num_update_steps_per_epoch

    # Linear decay of the learning rate to zero over the whole stage, no warmup.
    lr_scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=num_training_steps,
    )

    progress_bar = tqdm(range(num_training_steps))

    model.train()
    for epoch in range(num_train_epochs):
        for batch in train_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            loss = outputs.loss
            loss.backward()

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            progress_bar.update(1)
    progress_bar.close()

    model.eval()  # put in testing mode (dropout modules are deactivated)
    for test_relation in relations:
        test_dataloader = DataLoader(
            test_dict[test_relation]['test'],
            batch_size=64,
        )
        for batch in test_dataloader:
            input_ids = tokenizer(batch['head'], padding=True, return_tensors="pt").to(device)
            with torch.no_grad():
                generations = model.generate(**input_ids)
            decoded_gens = tokenizer.batch_decode(generations, skip_special_tokens=True)
            # Each tail string holds tab-separated references for one example.
            labels = [s.split('\t') for s in batch['tail']]
            metric.add_batch(predictions=decoded_gens, references=labels)
        results = metric.compute(max_order=2)
        # Unigram precision scaled by the brevity penalty. The key spelling 'blue-1'
        # is kept as-is because it is what ends up in results.txt.
        results['blue-1'] = results['brevity_penalty'] * results['precisions'][0]
        # Context manager so the file is closed even if the write raises.
        with open(f'{output_dir}results.txt', "a") as f:
            f.write(f'{train_relation} test on {test_relation} \n {results} \n')

    model.save_pretrained(f'{output_dir}{train_relation}')
    

# Elastic Weight Consolidation

In [6]:
# State recorded after the most recently finished task, keyed by parameter name:
# a copy of each parameter and the squared accumulated gradient used as its Fisher
# estimate.
optpar_dict = {}
fisher_dict = {}


def on_task_update(train_dataloader):
    """Record reference weights and a Fisher estimate after finishing a task.

    Makes one pass over `train_dataloader`, accumulating the gradients of the model
    loss across all batches without taking any optimizer step. It then overwrites the
    module-level `optpar_dict` (copy of every named parameter) and `fisher_dict`
    (element-wise square of every accumulated gradient).

    Uses the notebook-level `model` and `device`. Gradients are cleared through
    `model.zero_grad()`, so the function does not depend on an `optimizer` global
    created in a later cell. The model is left in training mode with its gradients
    cleared.

    Args:
        train_dataloader: iterable of batches; each batch is a mapping of tensors that
            is moved to `device` and passed as `model(**batch)`, whose result must
            expose `.loss`.
    """
    model.train()
    model.zero_grad()
    # Accumulate gradients over the whole dataloader; no optimizer step is taken.
    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        model(**batch).loss.backward()

    optpar_dict.clear()
    fisher_dict.clear()
    # The accumulated gradients are what the Fisher estimate is computed from.
    for name, param in model.named_parameters():
        optpar_dict[name] = param.detach().clone()
        if param.grad is None:
            # Frozen or unused parameter (or an empty dataloader): no gradient was
            # produced, so store zero importance instead of crashing on `None`.
            fisher_dict[name] = torch.zeros_like(param.detach())
        else:
            fisher_dict[name] = param.grad.detach().clone().pow(2)
    model.zero_grad()

In [None]:

# Elastic Weight Consolidation run: fine-tune the same model on each relation in
# turn, adding (from the second relation on) a quadratic penalty that ties every
# parameter to the value stored in `optpar_dict`, weighted by `fisher_dict` and
# `ewc_lambda`. After every stage the model is scored with BLEU on all relations.
# NOTE(review): `model` is not reloaded here, so this cell continues from whatever
# weights the kernel currently holds.
metric = evaluate.load('bleu')
ewc_lambda = 1000
output_dir = f'ewc_l={ewc_lambda}/'

# The collator only depends on the tokenizer and the model, so build it once.
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model
)

for train_relation in relations:
    # exist_ok=False: raises if this run directory already exists.
    os.makedirs(f'{output_dir}{train_relation}', exist_ok=False)

    train_dataloader = DataLoader(
        train_tok_dict[train_relation]['train'],
        shuffle=True,
        collate_fn=data_collator,
        batch_size=64,
    )

    optimizer = AdamW(model.parameters(), lr=2e-5)

    num_train_epochs = 3
    num_update_steps_per_epoch = len(train_dataloader)
    num_training_steps = num_train_epochs * num_update_steps_per_epoch

    # Linear decay of the learning rate to zero over the whole stage, no warmup.
    lr_scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=num_training_steps,
    )

    progress_bar = tqdm(range(num_training_steps))

    model.train()
    for epoch in range(num_train_epochs):
        for batch in train_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            loss = outputs.loss

            if train_relation != relations[0]:
                # EWC penalty: the dictionaries are only filled once the first
                # relation has finished, hence the guard above.
                for name, param in model.named_parameters():
                    fisher = fisher_dict[name]
                    optpar = optpar_dict[name]
                    loss += (fisher * (optpar - param).pow(2)).sum() * ewc_lambda

            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            progress_bar.update(1)
    progress_bar.close()

    # Refresh the stored weights and Fisher estimate from the relation just trained.
    on_task_update(train_dataloader)

    model.eval()  # put in testing mode (dropout modules are deactivated)
    for test_relation in relations:
        test_dataloader = DataLoader(
            test_dict[test_relation]['test'],
            batch_size=64,
        )
        for batch in test_dataloader:
            input_ids = tokenizer(batch['head'], padding=True, return_tensors="pt").to(device)
            with torch.no_grad():
                generations = model.generate(**input_ids)
            decoded_gens = tokenizer.batch_decode(generations, skip_special_tokens=True)
            # Each tail string holds tab-separated references for one example.
            labels = [s.split('\t') for s in batch['tail']]
            metric.add_batch(predictions=decoded_gens, references=labels)
        results = metric.compute(max_order=2)
        # Unigram precision scaled by the brevity penalty. The key spelling 'blue-1'
        # is kept as-is because it is what ends up in results.txt.
        results['blue-1'] = results['brevity_penalty'] * results['precisions'][0]
        # Context manager so the file is closed even if the write raises.
        with open(f'{output_dir}results.txt', "a") as f:
            f.write(f'{train_relation} test on {test_relation} \n {results} \n')

    model.save_pretrained(f'{output_dir}{train_relation}')
    



  0%|          | 0/8379 [00:00<?, ?it/s]

You're using a BartTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

