In [4]:
# the notebook's main objective is to filter and prepare the dataset to train a summarizer on it.
import os, sys
from pathlib import Path
HOME = os.getcwd()

# Walk upwards from the working directory until a folder containing `src` is
# found. The filesystem root is its own parent, so without the explicit stop
# condition the search would spin forever when no ancestor holds `src`.
current = Path(HOME)
while not (current / 'src').is_dir():
    if current.parent == current:
        raise FileNotFoundError(
            f"could not locate a 'src' directory in {HOME} or any of its parent directories"
        )
    current = current.parent

PARENT_DIR = str(current)
DATA_FOLDER = os.path.join(PARENT_DIR, 'src', 'data')
data_path = os.path.join(DATA_FOLDER, 'filtered.tsv')

# Make the project packages importable. Entries are only appended when they are
# missing so that re-running this cell does not grow sys.path with duplicates.
for extra_path in (
    PARENT_DIR,
    os.path.join(PARENT_DIR, 'data_analysis'),
    os.path.join(PARENT_DIR, 'evaluation'),
    os.path.join(PARENT_DIR, 'text_processing'),
):
    if extra_path not in sys.path:
        sys.path.append(extra_path)

In [5]:
# Location (relative to the project root) of the checkpoint directory that is
# later loaded as the toxicity classifier.
_toxic_checkpoint_parts = ('src', 'models', 's2s', 'toxic_classifier_checkpoints', 'checkpoint-500')
model_checkpoint = os.path.join(PARENT_DIR, *_toxic_checkpoint_parts)

In [6]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModelForSequenceClassification
import torch
import torch.nn.functional as F
from datasets import load_dataset
# Reload the three splits written by the "split the data" cell. Each CSV is a
# separate file; `split='train'` is requested because that is the split name
# under which `load_dataset` exposes a single data file.
# Previously all three variables pointed at 'train_split.csv', and the
# validation set was bound to the misspelled name `val_dat`.
train_data = load_dataset("csv", data_files=os.path.join(DATA_FOLDER, 'train_split.csv'), split='train')
val_data = load_dataset("csv", data_files=os.path.join(DATA_FOLDER, 'val_split.csv'), split='train')
test_data = load_dataset("csv", data_files=os.path.join(DATA_FOLDER, 'test_split.csv'), split='train')

OSError: Not enough disk space. Needed: Unknown size (download: Unknown size, generated: Unknown size, post-processed: Unknown size)

In [None]:
# sample = all_data.select(range(5000))

# Hub identifier shared by the seq2seq model and its tokenizer.
checkpoint = 'facebook/bart-base'
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# The sequence classifier is restored from the local checkpoint directory.
toxic_classifier = AutoModelForSequenceClassification.from_pretrained(model_checkpoint)

In [None]:
# split the data

import src.data_preparation.prepare_data as pdr

# NOTE(review): `data` is not defined in any visible cell — it presumably came
# from a cell that has since been removed; confirm before running on a fresh kernel.
train_data, val_data, test_data = pdr.data_split(data, train_portion=0.96)

# Persist every split inside DATA_FOLDER so it can be reloaded later.
_split_files = {
    'train_split.csv': train_data,
    'val_split.csv': val_data,
    'test_split.csv': test_data,
}
for _split_filename, _split in _split_files.items():
    _split.to_csv(os.path.join(DATA_FOLDER, _split_filename), index=False)

In [None]:
def prepare_labeled_data(batch):
    """Tokenize a batch of source/target pairs for seq2seq training.

    The 'source' column is tokenized as model input and the 'target' column is
    tokenized through the tokenizer's `text_target` argument; the target token
    ids are stored under the 'labels' key of the returned encoding.

    Relies on the module-level `tokenizer`.
    """
    encoded = tokenizer(batch['source'], truncation=True)
    # only the ids of the target are needed; its attention mask is discarded
    target_ids = tokenizer(text_target=batch["target"], truncation=True)["input_ids"]
    encoded["labels"] = target_ids
    return encoded

def _tokenize_split(split):
    """Apply `prepare_labeled_data` in batched mode and drop the raw text columns."""
    tokenized = split.map(prepare_labeled_data, batched=True)
    return tokenized.remove_columns(['source', 'target'])


# The same preprocessing is applied to every split, in train / val / test order.
train_data, val_data, test_data = (
    _tokenize_split(split) for split in (train_data, val_data, test_data)
)


In [None]:

from transformers import DataCollatorForSeq2Seq
# Collator that batches the tokenized examples (inputs and labels) for the seq2seq model.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)

from torch.utils.data import DataLoader

# Settings common to both loaders; only the training loader shuffles.
_loader_kwargs = dict(batch_size=4, collate_fn=data_collator)
train_dl = DataLoader(train_data, shuffle=True, **_loader_kwargs)
val_dl = DataLoader(val_data, shuffle=False, **_loader_kwargs)

In [None]:
from torch import nn
from transformers import Trainer
from torch.nn.functional import softmax

class CustomTrainer(Trainer):
    """Trainer whose loss is the seq2seq loss plus a weighted toxicity penalty.

    The penalty is the mean probability that the module-level `toxic_classifier`
    assigns to class index 1 for the greedy (argmax) predictions of the model.
    Relies on the module-level `tokenizer` and `toxic_classifier`.

    NOTE(review): `argmax` yields integer token ids, so no gradient flows from
    the toxicity term back into `model`; as written the term changes the
    reported loss value but not the parameter updates. A differentiable
    relaxation would be needed for the penalty to steer training.
    NOTE(review): the predicted ids are fed to `toxic_classifier` directly, which
    assumes it shares the vocabulary of `tokenizer` — confirm.
    """

    # Weight of the toxicity penalty relative to the seq2seq loss.
    toxicity_coeff = 0.05

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the combined loss (and optionally the model output).

        Args:
            model: the seq2seq model being trained.
            inputs: batch dict forwarded to `model`; must contain labels so that
                `model_output.loss` is populated.
            return_outputs: when True, return `(loss, model_output)`.
            num_items_in_batch: accepted for compatibility with Trainer versions
                that pass this argument; it is not used here.
        """
        model_output = model(**inputs)
        # extract the sequence to sequence loss
        s2s_loss = model_output.loss

        prediction_ids = model_output.logits.argmax(dim=-1)
        # Build the mask from the predictions themselves so that it lives on the
        # same device as `prediction_ids` (the former torch.zeros/torch.ones
        # tensors were created on the CPU and broke GPU training).
        attention_mask = (prediction_ids != tokenizer.pad_token_id).long()

        # Keep the classifier on the same device as the batch; this is a no-op
        # once it is already there.
        toxic_classifier.to(prediction_ids.device)
        toxic_output = toxic_classifier(input_ids=prediction_ids, attention_mask=attention_mask)
        toxic_loss = torch.mean(softmax(toxic_output.logits, dim=1)[:, 1])

        loss = s2s_loss + self.toxicity_coeff * toxic_loss
        return (loss, model_output) if return_outputs else loss
    

In [None]:
# Train model

from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

# Hyper-parameters of the fine-tuning run.
num_epochs = 5
batch_size = 64
learning_rate = 5e-5
weight_decay = 0.01
warmup_steps = 500

# NOTE(review): `eval_steps` is set but no evaluation strategy is given —
# confirm that periodic evaluation actually runs with the installed
# transformers version.
sc_training_args = Seq2SeqTrainingArguments(
    # where checkpoints go
    output_dir='seq_2_seq',
    overwrite_output_dir=True,
    # what to run
    do_train=True,
    do_eval=True,
    predict_with_generate=True,
    # optimisation
    num_train_epochs=num_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    warmup_steps=warmup_steps,
    # logging / checkpoint / evaluation cadence (in optimisation steps)
    logging_steps=100,
    save_steps=1000,
    eval_steps=10,
    report_to="none",
)

# NOTE(review): the stock Seq2SeqTrainer is used here, so the toxicity penalty
# defined in `CustomTrainer` is not part of this run — confirm this is intended.
trainer = Seq2SeqTrainer(
    model=model,
    args=sc_training_args,
    data_collator=data_collator,
    train_dataset=train_data,
    eval_dataset=val_data,
)

In [None]:
# Launch fine-tuning with the model, datasets and arguments held by `trainer`.
trainer.train()

In [None]:

# Qualitative check: decode every 50th validation example next to the text the
# model generates for it.

# The model is still in training mode after `trainer.train()`; switch to eval
# mode so dropout is disabled while generating.
model.eval()
# Run generation on whichever device the model currently lives on instead of
# assuming a GPU is available.
generation_device = model.device

for i in range(0, len(val_data), 50):
    example = val_data[i]  # fetch the row once instead of three times
    input_ids = example['input_ids']
    attention_mask = example['attention_mask']
    labels = example['labels']

    print(f"source: {tokenizer.decode(input_ids, skip_special_tokens=True)}")
    print(f"target: {tokenizer.decode(labels, skip_special_tokens=True)}")

    outputs = model.generate(
        # unsqueeze(0) adds the batch dimension expected by `generate`
        input_ids=torch.tensor(input_ids).unsqueeze(0).to(generation_device),
        attention_mask=torch.tensor(attention_mask).unsqueeze(0).to(generation_device),
        max_length=512,
        num_beams=5,
        early_stopping=True
    )

    print(f"generated :{tokenizer.decode(outputs[0], skip_special_tokens=True)}")
    print("#" * 100)

In [None]:
# helper computing the toxicity term of the loss: the mean toxicity probability of the decoded summaries
from src.evaluation.toxicity_classication import EvalutionSingletonInitializer
from torch.nn.functional import softmax
from typing import Union

def toxic_summary_model_loss(output_decoded: Union[str, list],
                             device: Union[str, torch.device],
                             return_tensor: bool=False) -> Union[float, torch.Tensor]:
    """Score decoded text with the toxicity classifier.

    Args:
        output_decoded: decoded text — a single string or a list of strings. It
            is handed straight to the toxicity tokenizer, so it must be text
            rather than a tensor of token ids.
        device: device on which the classifier is run; the tokenized batch is
            moved there as well.
        return_tensor: when True the 0-dim loss tensor is returned, otherwise a
            plain Python float.

    Returns:
        The mean over the batch of column 1 of the classifier's softmax output
        (assumed to be the "toxic" class — confirm against the classifier's
        label mapping).

    Note:
        The classifier parameters are frozen and its input is re-tokenized text,
        so the returned tensor is not connected to the graph of the model that
        produced `output_decoded`. `requires_grad` is switched on only so the
        value can be combined with / used like a regular loss tensor; it does
        not propagate gradients to that model.
    """
    singleton_obj = EvalutionSingletonInitializer()
    tc_tokenizer = singleton_obj.get_toxic_tokenizer()
    tc_classifier = singleton_obj.get_toxic_classifier()

    # make sure to freeze their parameters
    for p in tc_classifier.parameters():
        p.requires_grad = False

    # Evaluation mode disables dropout so the same text always gets the same score.
    tc_classifier.eval()
    tc_classifier.to(device)

    # tokenize
    model_input = tc_tokenizer(output_decoded, return_tensors='pt', padding=True, truncation=True)
    # set the input to the device
    model_input = {k: v.to(device) for k, v in model_input.items()}
    # pass through the model
    output = tc_classifier(**model_input)

    loss = torch.mean(softmax(output.logits, dim=1)[:, 1])

    if return_tensor:
        loss.requires_grad = True
        return loss

    return loss.item()


# train_custom_seq2seq(train_dataloader=train_dl, 
#                      val_dataloader=val_dl, 
#                      model=model, 
#                      tokenizer=tokenizer, 
#                      toxic_tokenizer=toxic_tokenizer,
#                      toxic_classifier=toxic_classifier,
#                      optimizer=optimizer, 
#                      scheduler=scheduler, 
#                      toxicity_loss_function=toxic_summary_model_loss,
#                      toxicity_coeff=0.5,
#                     num_epochs=2,   
#                     report_per_epoch=1,
#                     log_dir=os.path.join(HOME, 'runs')
#                     )

        