In [1]:
# !conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
# !pip install transformers[torch] datasets rouge_score nlp numpy pandas matplotlib

In [1]:
import pandas as pd
import torch

from torch.utils.data import Dataset

In [3]:
# Check if CUDA is available
if torch.cuda.is_available():
    # Get the number of available GPUs
    num_gpus = torch.cuda.device_count()
    print(f"Number of available GPUs: {num_gpus}")
else:
    print("CUDA is not available.")

Number of available GPUs: 1


In [4]:
filtered_dataset = pd.read_csv('filtered.tsv', sep='\t')
filtered_dataset.rename(columns={filtered_dataset.columns[0]: "id"}, inplace=True)

sorted = filtered_dataset.sort_values(by=['ref_tox'], ascending=False)
sorted.head()

Unnamed: 0,id,reference,translation,similarity,lenght_diff,ref_tox,trn_tox
551255,551255,His father would have used a booming voice to ...,his father would have answered with his thunde...,0.729428,0.091954,0.999724,0.004599
101676,101676,You have to send those idiots back in.,you have to get those guys back there.,0.622852,0.0,0.999723,0.000115
258368,258368,Salina could be with that stupid cop.,Salina could be with the cop.,0.774944,0.210526,0.999723,0.0005
318050,318050,And don't let those idiots in radiology hold y...,don't let them fool you in radiology.,0.711188,0.283019,0.999723,0.000874
70934,70934,My idiot friend here brought marijuana... - on...,my friend here took a marijuana...,0.715508,0.396552,0.999722,0.000161


In [5]:
class ToxicDataset(Dataset):
    def __init__(self, dataframe):
        self.dataframe = dataframe

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        ref = self.dataframe.iloc[idx, 1]
        trn = self.dataframe.iloc[idx, 2]

        return ref, trn

In [6]:
from transformers import (
    EncoderDecoderModel,
    AutoTokenizer,
    BertTokenizer,
    BertGenerationEncoder,
    BertGenerationDecoder,
    BertGenerationConfig,
)
from transformers import BertTokenizerFast


# bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained(
#     "bert-base-uncased", "bert-base-uncased"
# )

# # Set tokenizer
# tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# tokenizer.bos_token = tokenizer.cls_token
# tokenizer.eos_token = tokenizer.sep_token

# # Set model's config
# bert2bert.config.decoder_start_token_id = tokenizer.bos_token_id
# bert2bert.config.eos_token_id = tokenizer.eos_token_id
# bert2bert.config.pad_token_id = tokenizer.pad_token_id

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
from torch.utils.data import DataLoader
from torch.utils.data import random_split
import datasets


torch_dataset = ToxicDataset(sorted)

train_size = int(0.8 * len(torch_dataset))
test_size = len(torch_dataset) - train_size

train_dataset, test_dataset = random_split(torch_dataset, [train_size, test_size])

def gen_train():
    for idx in range(len(train_dataset)):
        # yield torch_dataset[idx]  # this has to be a dictionary
        yield {
            "reference": train_dataset[idx][0],
            "translation": train_dataset[idx][1],
        }

def gen_test():
    for idx in range(len(test_dataset)):
        # yield torch_dataset[idx]  # this has to be a dictionary
        yield {
            "reference": test_dataset[idx][0],
            "translation": test_dataset[idx][1],
        }

# Convert from PyTorch dataset to HuggingFace dataset
train_dataset = datasets.Dataset.from_generator(gen_train)
test_dataset = datasets.Dataset.from_generator(gen_test)

Generating train split: 462221 examples [00:18, 24374.44 examples/s]
Generating train split: 115556 examples [00:04, 24437.12 examples/s]


In [38]:
import nlp
import logging
from transformers import BertTokenizer, GPT2Tokenizer, EncoderDecoderModel, Trainer, TrainingArguments
import os

logging.basicConfig(level=logging.INFO)

if os.path.exists('checkpoint'):
    model = EncoderDecoderModel.from_encoder_decoder_pretrained("./checkpoint")
    bert_tokenizer = BertTokenizer.from_pretrained("./checkpoint")
    gpt2_tokenizer = GPT2Tokenizer.from_pretrained("./checkpoint")
else:
    model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "gpt2")
    bert_tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# cache is currently not supported by EncoderDecoder framework
model.decoder.config.use_cache = False
# CLS token will work as BOS token
bert_tokenizer.bos_token = bert_tokenizer.cls_token
# SEP token will work as EOS token
bert_tokenizer.eos_token = bert_tokenizer.sep_token


# make sure GPT2 appends EOS in begin and end
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
    outputs = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
    return outputs


GPT2Tokenizer.build_inputs_with_special_tokens = build_inputs_with_special_tokens
# set pad_token_id to unk_token_id -> be careful here as unk_token_id == eos_token_id == bos_token_id
gpt2_tokenizer.pad_token = gpt2_tokenizer.unk_token

ValueError: Unrecognized configuration class <class 'transformers.models.encoder_decoder.configuration_encoder_decoder.EncoderDecoderConfig'> for this kind of AutoModel: AutoModel.
Model type should be one of AlbertConfig, AlignConfig, AltCLIPConfig, ASTConfig, AutoformerConfig, BarkConfig, BartConfig, BeitConfig, BertConfig, BertGenerationConfig, BigBirdConfig, BigBirdPegasusConfig, BioGptConfig, BitConfig, BlenderbotConfig, BlenderbotSmallConfig, BlipConfig, Blip2Config, BloomConfig, BridgeTowerConfig, BrosConfig, CamembertConfig, CanineConfig, ChineseCLIPConfig, ClapConfig, CLIPConfig, CLIPSegConfig, LlamaConfig, CodeGenConfig, ConditionalDetrConfig, ConvBertConfig, ConvNextConfig, ConvNextV2Config, CpmAntConfig, CTRLConfig, CvtConfig, Data2VecAudioConfig, Data2VecTextConfig, Data2VecVisionConfig, DebertaConfig, DebertaV2Config, DecisionTransformerConfig, DeformableDetrConfig, DeiTConfig, DetaConfig, DetrConfig, DinatConfig, Dinov2Config, DistilBertConfig, DonutSwinConfig, DPRConfig, DPTConfig, EfficientFormerConfig, EfficientNetConfig, ElectraConfig, EncodecConfig, ErnieConfig, ErnieMConfig, EsmConfig, FalconConfig, FlaubertConfig, FlavaConfig, FNetConfig, FocalNetConfig, FSMTConfig, FunnelConfig, GitConfig, GLPNConfig, GPT2Config, GPT2Config, GPTBigCodeConfig, GPTNeoConfig, GPTNeoXConfig, GPTNeoXJapaneseConfig, GPTJConfig, GPTSanJapaneseConfig, GraphormerConfig, GroupViTConfig, HubertConfig, IBertConfig, IdeficsConfig, ImageGPTConfig, InformerConfig, JukeboxConfig, Kosmos2Config, LayoutLMConfig, LayoutLMv2Config, LayoutLMv3Config, LEDConfig, LevitConfig, LiltConfig, LlamaConfig, LongformerConfig, LongT5Config, LukeConfig, LxmertConfig, M2M100Config, MarianConfig, MarkupLMConfig, Mask2FormerConfig, MaskFormerConfig, MaskFormerSwinConfig, MBartConfig, MCTCTConfig, MegaConfig, MegatronBertConfig, MgpstrConfig, MistralConfig, MobileBertConfig, MobileNetV1Config, MobileNetV2Config, MobileViTConfig, MobileViTV2Config, MPNetConfig, MptConfig, MraConfig, MT5Config, MvpConfig, NatConfig, NezhaConfig, NllbMoeConfig, NystromformerConfig, OneFormerConfig, OpenLlamaConfig, OpenAIGPTConfig, OPTConfig, Owlv2Config, OwlViTConfig, PegasusConfig, PegasusXConfig, PerceiverConfig, PersimmonConfig, PLBartConfig, PoolFormerConfig, ProphetNetConfig, PvtConfig, QDQBertConfig, ReformerConfig, RegNetConfig, RemBertConfig, ResNetConfig, RetriBertConfig, RobertaConfig, RobertaPreLayerNormConfig, RoCBertConfig, RoFormerConfig, RwkvConfig, SamConfig, SeamlessM4TConfig, SegformerConfig, SEWConfig, SEWDConfig, Speech2TextConfig, SpeechT5Config, SplinterConfig, SqueezeBertConfig, SwiftFormerConfig, SwinConfig, Swin2SRConfig, Swinv2Config, SwitchTransformersConfig, T5Config, TableTransformerConfig, TapasConfig, TimeSeriesTransformerConfig, TimesformerConfig, TimmBackboneConfig, TrajectoryTransformerConfig, TransfoXLConfig, TvltConfig, UMT5Config, UniSpeechConfig, UniSpeechSatConfig, VanConfig, VideoMAEConfig, ViltConfig, VisionTextDualEncoderConfig, VisualBertConfig, ViTConfig, ViTHybridConfig, ViTMAEConfig, ViTMSNConfig, VitDetConfig, VitsConfig, VivitConfig, Wav2Vec2Config, Wav2Vec2ConformerConfig, WavLMConfig, WhisperConfig, XCLIPConfig, XGLMConfig, XLMConfig, XLMProphetNetConfig, XLMRobertaConfig, XLMRobertaXLConfig, XLNetConfig, XmodConfig, YolosConfig, YosoConfig.

In [10]:
# set decoding params
model.config.decoder_start_token_id = gpt2_tokenizer.bos_token_id
model.config.eos_token_id = gpt2_tokenizer.eos_token_id
model.config.max_length = 142
model.config.min_length = 56
model.config.no_repeat_ngram_size = 3
model.early_stopping = True
model.length_penalty = 2.0
model.num_beams = 4

# load train and validation data
# train_dataset = nlp.load_dataset("cnn_dailymail", "3.0.0", split="train")
# val_dataset = nlp.load_dataset("cnn_dailymail", "3.0.0", split="validation[:5%]")

# load rouge for validation
rouge = nlp.load_metric("rouge", experiment_id=2)

encoder_length = 128
decoder_length = 128
batch_size = 32


# map data correctly
def map_to_encoder_decoder_inputs(batch):    # Tokenizer will automatically set [BOS] <text> [EOS]
    # use bert tokenizer here for encoder
    inputs = bert_tokenizer(batch["reference"], padding="max_length", truncation=True, max_length=encoder_length)
    # force summarization <= 128
    outputs = gpt2_tokenizer(batch["translation"], padding="max_length", truncation=True, max_length=decoder_length)

    batch["input_ids"] = inputs.input_ids
    batch["attention_mask"] = inputs.attention_mask
    batch["decoder_input_ids"] = outputs.input_ids
    batch["labels"] = outputs.input_ids.copy()
    batch["decoder_attention_mask"] = outputs.attention_mask

    # complicated list comprehension here because pad_token_id alone is not good enough to know whether label should be excluded or not
    batch["labels"] = [
        [-100 if mask == 0 else token for mask, token in mask_and_tokens] for mask_and_tokens in [zip(masks, labels) for masks, labels in zip(batch["decoder_attention_mask"], batch["labels"])]
    ]

    assert all([len(x) == encoder_length for x in inputs.input_ids])
    assert all([len(x) == decoder_length for x in outputs.input_ids])

    return batch


def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    # all unnecessary tokens are removed
    pred_str = gpt2_tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = gpt2_tokenizer.eos_token_id
    label_str = gpt2_tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])["rouge2"].mid

    return {
        "rouge2_precision": round(rouge_output.precision, 4),
        "rouge2_recall": round(rouge_output.recall, 4),
        "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
    }


# make train dataset ready
train_dataset = train_dataset.map(
    map_to_encoder_decoder_inputs, batched=True, batch_size=batch_size, remove_columns=["reference", "translation"],
)
train_dataset.set_format(
    type="torch", columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
)

# same for validation dataset
# TODO: rename test_dataset to val_dataset
val_dataset = test_dataset.map(
    map_to_encoder_decoder_inputs, batched=True, batch_size=batch_size, remove_columns=["reference", "translation"],
)
val_dataset.set_format(
    type="torch", columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
)

INFO:nlp.load:Checking /root/.cache/huggingface/datasets/5ecb6e4b474317b41ae1fe5d702d1af8d86d452f0b1d70f77a12f6f014ded6ac.35bc2c477aa456d2f589656477ccb0b463c21cdfb83a9de86d63de8560a96d1b.py for additional imports.
INFO:nlp.load:Found main folder for metric https://s3.amazonaws.com/datasets.huggingface.co/nlp/metrics/rouge/rouge.py at /root/anaconda3/envs/torch/lib/python3.11/site-packages/nlp/metrics/rouge
INFO:nlp.load:Found specific version folder for metric https://s3.amazonaws.com/datasets.huggingface.co/nlp/metrics/rouge/rouge.py at /root/anaconda3/envs/torch/lib/python3.11/site-packages/nlp/metrics/rouge/06783dbed5f6b6a5413f84d2a5f0d9dc9cb871f1aeb3787f2c90a8e3fe60b1c1
INFO:nlp.load:Found script file from https://s3.amazonaws.com/datasets.huggingface.co/nlp/metrics/rouge/rouge.py to /root/anaconda3/envs/torch/lib/python3.11/site-packages/nlp/metrics/rouge/06783dbed5f6b6a5413f84d2a5f0d9dc9cb871f1aeb3787f2c90a8e3fe60b1c1/rouge.py
INFO:nlp.load:Couldn't find dataset infos file at htt

In [25]:
# set training arguments - these params are not really tuned, feel free to change
training_args = TrainingArguments(
    output_dir="./checkpoints",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # predict_from_generate=True,
    # evaluate_during_training=True,
    do_train=True,
    do_eval=True,
    logging_steps=1000,
    save_steps=1000,
    eval_steps=1000,
    overwrite_output_dir=True,
    warmup_steps=2000,
    save_total_limit=5,
    fp16=True,
    num_train_epochs=0.01,
    use_cpu=False,
)

# instantiate trainer
trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

# start training
trainer.train()



Step,Training Loss


TrainOutput(global_step=145, training_loss=1.0736075486859371e-05, metrics={'train_runtime': 59.7492, 'train_samples_per_second': 77.36, 'train_steps_per_second': 2.427, 'total_flos': 709387930828800.0, 'train_loss': 1.0736075486859371e-05, 'epoch': 0.01})

In [26]:
trainer.save_model('./checkpoint')

In [27]:
def generate_summary(s):
    # Tokenizer will automatically set [BOS] <text> [EOS]
    # cut off at BERT max length 512
    inputs = bert_tokenizer([s], padding="max_length", truncation=True, max_length=encoder_length, return_tensors="pt")
    input_ids = inputs.input_ids.to("cuda")
    attention_mask = inputs.attention_mask.to("cuda")

    model.eval()
    outputs = model.generate(input_ids, attention_mask=attention_mask)

    # all special tokens including will be removed
    output_str = gpt2_tokenizer.batch_decode(outputs, skip_special_tokens=True)

    # batch["pred"] = output_str

    return output_str

In [34]:
result = generate_summary("example")
print(result)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


['...--- - - ------------------ --- --- -------------+++++++++++++++++++ ++ ++ ++++++++)++)++)+)+)+)++)++)++++++;++;++;++++++++.+.+.+,+.+.+++,+.+,+,+.++.+,++plusplusplus plus plus plusplusplusPlusPlusPlusplusplusminusminusminus minus minus minusminusminusplusplus Plus Plus PlusPlusPlus Plus Plusplusplus PLUSPlusPlus plus plusPlusPlus PLUS PLUS PLUSPlusplusPlusplus plusplusPlus PlusPlusplus PlusPlus PlusplusPlus plusplus plusPlusplusminusplusminus minusminus minus']
