In [1]:

import os, sys
from pathlib import Path
HOME = os.getcwd()

# Walk upwards from the working directory until a directory that contains 'src' is found.
current = Path(HOME)
while 'src' not in os.listdir(current):
    if current.parent == current:
        # Reached the filesystem root. Without this guard the loop would never end,
        # because the parent of the root is the root itself.
        raise FileNotFoundError(f"no directory containing 'src' found at or above {HOME}")
    current = current.parent

PARENT_DIR = str(current)
DATA_FOLDER = os.path.join(PARENT_DIR, 'src', 'data')
data_path = os.path.join(DATA_FOLDER, 'filtered.tsv')

# Make the project root importable so that `import src....` works in later cells.
# Skip the append when the entry is already there, so re-running this cell does not
# keep growing sys.path.
if PARENT_DIR not in sys.path:
    sys.path.append(PARENT_DIR)



In [None]:
from datasets import load_dataset
# Read the prepared CSV as a single 'train' split, then keep only the rows whose
# 'source' and 'target' fields are both Python strings.
prepared_csv_path = os.path.join(DATA_FOLDER, 'everythingprepared.csv')
all_data = load_dataset('csv', data_files=prepared_csv_path, split='train')
all_data = all_data.filter(
    lambda row: isinstance(row['source'], str) and isinstance(row['target'], str)
)

In [44]:
import src.data.preprocess1 as pr1

def filter_data(sample):
    """Return True when the sample's source has more word tokens than its target.

    Both fields are tokenized with ``pr1.tokenize(..., tokenizer_type='word')`` and
    only the token counts are compared.
    """
    n_source_tokens, n_target_tokens = (
        len(pr1.tokenize(sample[field], tokenizer_type='word'))
        for field in ('source', 'target')
    )
    return n_source_tokens > n_target_tokens

# Keep the pairs accepted by filter_data and write them to the data folder as CSV.
summary_data = all_data.filter(filter_data)

model1_csv_path = os.path.join(DATA_FOLDER, 'model1_data.csv')
summary_data.to_csv(model1_csv_path, index=False)

Creating CSV from Arrow format:   0%|          | 0/277 [00:00<?, ?ba/s]

Creating CSV from Arrow format: 100%|██████████| 277/277 [00:01<00:00, 239.18ba/s]


31406141

In [45]:
import src.data.preprocess2 as pdr
# Number of leading rows of summary_data used for this experiment.
SAMPLE_SIZE = 5000

# Clamp to the dataset length so that a filtered dataset with fewer rows than
# SAMPLE_SIZE does not make `select` fail on out-of-range indices.
sample = summary_data.select(range(min(SAMPLE_SIZE, len(summary_data))))
train_data, val_data, test_data = pdr.data_split(all_data=sample)

In [46]:
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Pretrained checkpoint shared by the tokenizer and the seq2seq model.
CHECKPOINT = 't5-small'

# Run on the GPU when torch reports one, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

TOKENIZER = AutoTokenizer.from_pretrained(CHECKPOINT)
MODEL = AutoModelForSeq2SeqLM.from_pretrained(CHECKPOINT)
MODEL = MODEL.to(DEVICE)

In [47]:
TASK_PREFIX = 'summarize: '


def prepare_labeled_data(batch):
    """Tokenize a batch of source/target pairs for seq2seq training.

    Each entry of ``batch["source"]`` is prefixed with TASK_PREFIX and tokenized with
    truncation at ``max_length=1028``; ``batch["target"]`` is tokenized as the text
    target and its ``input_ids`` are stored under the ``"labels"`` key.

    NOTE(review): 1028 looks like it may have been meant as 1024 - confirm.

    Returns:
        The tokenizer output for the prefixed sources with ``"labels"`` added.
    """
    prefixed_sources = []
    for text in batch["source"]:
        prefixed_sources.append(TASK_PREFIX + text)

    encoded = TOKENIZER(prefixed_sources, truncation=True, max_length=1028)
    encoded["labels"] = TOKENIZER(text_target=batch["target"], truncation=True)["input_ids"]
    return encoded

In [48]:
def _tokenize_split(split):
    """Apply prepare_labeled_data in batched mode and drop the raw text columns."""
    tokenized = split.map(prepare_labeled_data, batched=True)
    return tokenized.remove_columns(['source', 'target'])


train_data = _tokenize_split(train_data)
val_data = _tokenize_split(val_data)

In [49]:

from transformers import DataCollatorForSeq2Seq
from torch.utils.data import DataLoader

BATCH_SIZE = 4

# The collator's `model` argument expects the model object, not the checkpoint name.
# The original passed the CHECKPOINT string here, which the collator cannot use as a model.
data_collator = DataCollatorForSeq2Seq(tokenizer=TOKENIZER, model=MODEL)

# Only the training loader is shuffled; validation keeps the dataset order.
train_dl = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=data_collator)
val_dl = DataLoader(dataset=val_data, batch_size=BATCH_SIZE, shuffle=False, collate_fn=data_collator)

In [50]:
# Sanity check: pull the first batch from each loader to confirm that collation works.
b1 = next(iter(train_dl))
b2 = next(iter(val_dl))

You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [51]:
from src.models.train.model1 import train2 as tc
# Build the evaluation helper and pull out the toxicity classifier, its tokenizer and
# the device it reports.
# NOTE(review): this cell imports train2 from `src.models.train.model1`, while a later
# cell imports EvalutionSingletonInitializer from `src.models.train.model2.train2` -
# confirm which package is intended.
singleton_obj = tc.EvalutionSingletonInitializer()
tx_classifier = singleton_obj.get_toxic_classifier()
tx_tokenizer = singleton_obj.get_toxic_tokenizer()
tx_device = singleton_obj.get_device()

from torch.optim import Adam
from torch.optim.lr_scheduler import LinearLR


# Base learning rate for the summarizer's parameters.
LEARNING_RATE = 2 * 10 ** -5

optimizer = Adam(params=MODEL.parameters(), lr=LEARNING_RATE)

# Linear schedule: the learning-rate factor goes from 1 to 0.5 over 100 scheduler steps.
scheduler = LinearLR(optimizer, start_factor=1, end_factor=0.5, total_iters=100)

In [52]:
from src.models.train.model2.train2 import EvalutionSingletonInitializer
from torch.nn.functional import softmax
from typing import Union

def toxic_summary_model_loss(output_decoded: Union[str, list],
                             device,
                             return_tensor: bool=False) -> Union[float, torch.Tensor]:
    """Score text with the toxicity classifier and return the mean class-1 probability.

    Args:
        output_decoded: text to score. It is passed straight to the classifier's
            tokenizer, so it is expected to be a string or a list of strings rather
            than a tensor (the previous ``torch.Tensor`` annotation did not match
            this usage).
        device: device the classifier and its tokenized inputs are moved to.
        return_tensor: when True, return the loss as a tensor with ``requires_grad``
            set; otherwise return a plain Python float.

    Returns:
        The batch mean of the softmax probability at index 1 of the classifier logits
        (presumably the "toxic" class - confirm against the classifier's label mapping).
    """
    singleton_obj = EvalutionSingletonInitializer()
    tc_tokenizer, tc_classifier = singleton_obj.get_toxic_tokenizer(), singleton_obj.get_toxic_classifier()

    # The classifier is only used as a scorer here: freeze its weights ...
    for p in tc_classifier.parameters():
        p.requires_grad = False

    # ... and put it in inference mode so training-only layers such as dropout
    # cannot make the score vary between calls on the same input.
    tc_classifier.eval()

    tc_classifier.to(device)
    model_input = tc_tokenizer(output_decoded, return_tensors='pt', padding=True, truncation=True)
    model_input = {k: v.to(device) for k, v in model_input.items()}

    output = tc_classifier(**model_input)

    # Probability of class index 1 for every item, averaged over the batch.
    loss = torch.mean(softmax(output.logits, dim=1)[:, 1])

    if return_tensor:
        # NOTE(review): the classifier is frozen and is fed freshly tokenized text, so
        # this flag presumably does not connect the loss to the summarizer's
        # parameters - confirm that gradients flow as intended.
        loss.requires_grad_(True)
        return loss

    return loss.item()


In [None]:
from src.models.train.model2 import train  as ss

# Gather the training configuration in one place, then launch the run.
training_config = dict(
    train_dataloader=train_dl,
    val_dataloader=val_dl,
    summary_model=MODEL,
    summary_tokenizer=TOKENIZER,
    toxicity_loss_function=toxic_summary_model_loss,
    toxicity_coeff=0.5,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=20,
    report_per_epoch=1,
    log_dir=os.path.join(HOME, 'runs'),
)

# Only the third returned value is kept, as best_model.
_, _, best_model = ss.train_custom_summarizer(**training_config)


In [54]:
# Qualitative check: for every 20th validation example print the decoded source, the
# decoded reference and the model's beam-search output.
for i in range(0, len(val_data), 20):
    example = val_data[i]  # fetch the row once instead of once per field
    input_ids = example['input_ids']
    attention_mask = example['attention_mask']
    labels = example['labels']

    print(f"source: {TOKENIZER.decode(input_ids, skip_special_tokens=True)}")
    print(f"target: {TOKENIZER.decode(labels, skip_special_tokens=True)}")

    # Use DEVICE (cuda with a cpu fallback) rather than a hard-coded 'cuda', so this
    # cell also runs on machines without a GPU.
    outputs = best_model.generate(
        input_ids=torch.tensor(input_ids).unsqueeze(0).to(DEVICE),
        attention_mask=torch.tensor(attention_mask).unsqueeze(0).to(DEVICE),
        max_length=512,
        num_beams=5,
        early_stopping=True
    )

    print(f"generated :{TOKENIZER.decode(outputs[0], skip_special_tokens=True)}")
    print("#" * 100)


source: summarize: we don't fear these demons. we destroy them.
target: we don't fight demons, we destroy them.
generated :we don't fear these demons.
####################################################################################################
source: summarize: the old witch shook her head.
target: she shook her head.
generated :the old witch shook her head.
####################################################################################################
source: summarize: you want me to fucking leave her
target: you want to leave it
generated :you want me to leave her
####################################################################################################
source: summarize: why don't you get rid of her before you go, huh
target: why not make amends before you leave, huh
generated :why don't you get rid of her before you go
####################################################################################################
source: summarize: darn it. fractions a