In [47]:
import math
import pandas as pd

import torch
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import pytorch_lightning as pl

from transformers import MT5ForConditionalGeneration, MT5TokenizerFast

from zemberek import TurkishSentenceExtractor


In [48]:
def tquad2df(path):
    extractor = TurkishSentenceExtractor()

    df = {'title': [], 'context': [], 'question': [], 'cloze': [], 'answer': []}

    dataset = pd.read_json(path).data


    for data in dataset:
        title = data['title']
        for para in data['paragraphs']:
            context = para['context']
            for qa in para['qas']:
                question = qa['question']

                unique_answers = set()
                for answer in qa['answers']:
                    answer_text, answer_span = answer['text'], int(answer['answer_start'])
                    spans = extractor.extract_to_spans(context)

                    for span in spans:
                        if answer_text not in unique_answers and span.in_span(answer_span):
                            unique_answers.add(answer_text)
                            cloze = span.get_sub_string(context)
                            df['title'].append(title)
                            df['context'].append(context)
                            df['question'].append(question)
                            df['cloze'].append(cloze)
                            df['answer'].append(answer_text)

    return pd.DataFrame(df)


In [49]:
class TData(Dataset):
    def __init__(self, df, tokenizer):
        super(TData, self).__init__()

        self.df = df
        self.tok = tokenizer

    def __getitem__(self, i):
        row = self.df.iloc[i]

        # extra_id_0 -> mask token
        # extra_id_1 -> bos token
        # extra_id_2 -> eos_token

        cloze = f"generate question for answer {row['answer']} : {row['cloze']}"

        model_inputs = self.tok(cloze, padding='max_length', max_length=256,
                                truncation=True, return_tensors='pt')
        with self.tok.as_target_tokenizer():
            labels = self.tok(row['question'], padding='max_length', max_length=256,
                                truncation=True, return_tensors="pt")
        model_inputs["labels"] = labels["input_ids"]

        return {k: v[0] for k, v in model_inputs.items()}

    def __len__(self):
        return len(self.df)

In [50]:
TRAIN_DIR = '../../taboo-datasets/tquad2/tquad_train_data_v2.json'
DEV_DIR= '../../taboo-datasets/tquad2/tquad_dev_data_v2.json'

train_df = tquad2df(TRAIN_DIR)
val_df = tquad2df(DEV_DIR)

In [51]:
model = MT5ForConditionalGeneration.from_pretrained('google/mt5-small')
tokenizer = MT5TokenizerFast.from_pretrained('google/mt5-small')



In [52]:
train_data = TData(train_df, tokenizer)
val_data = TData(val_df, tokenizer)

In [53]:
class AttackerModel(pl.LightningModule):
    def __init__(self, model, lr):
        super(AttackerModel, self).__init__()

        self.model = model
        self.lr = lr

    def forward(self, **batch):
        return self.model(**batch)

    def training_step(self, batch, batch_idx):
        loss = self(**batch).loss
        self.log_dict({'loss': loss, 'ppl': math.exp(loss.item())})

        if batch_idx % 500 == 0:
            with torch.no_grad():
                sentence = train_df.iloc[0].cloze
                tokenized_sent = tokenizer(sentence, padding='max_length', max_length=256,
                                           truncation=True, return_tensors='pt')
                generated_question = self.model.cuda().generate(tokenized_sent['input_ids'].cuda(), max_length=256, do_sample=True, top_k=50,
                                                                      top_p=0.95, num_beams=5, num_return_sequences=3)
                print(tokenizer.batch_decode(generated_question, skip_special_tokens=True))

        return loss

    def validation_step(self, batch, batch_idx):
        loss = self(**batch).loss
        self.log_dict({'loss': loss, 'ppl': math.exp(loss.item())}, sync_dist=True)

    def configure_optimizers(self):
        return optim.AdamW(self.model.parameters())


In [54]:
BATCH_SIZE = 4
LR = 1e-3
EPOCHS = 64

In [55]:
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_data, batch_size=BATCH_SIZE)

In [46]:
attacker_model = AttackerModel(model=model, lr=LR)

# wandb_logger = WandbLogger(project="Attacker Model")

trainer = pl.Trainer(accelerator='gpu', devices=1, max_epochs=EPOCHS)
trainer.fit(model=attacker_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

2023-02-28 08:59:55,313 - pytorch_lightning.utilities.rank_zero - INFO
Msg: GPU available: True (mps), used: True

2023-02-28 08:59:55,314 - pytorch_lightning.utilities.rank_zero - INFO
Msg: TPU available: False, using: 0 TPU cores

2023-02-28 08:59:55,314 - pytorch_lightning.utilities.rank_zero - INFO
Msg: IPU available: False, using: 0 IPUs

2023-02-28 08:59:55,314 - pytorch_lightning.utilities.rank_zero - INFO
Msg: HPU available: False, using: 0 HPUs

Msg: Missing logger folder: /Users/quimba/Desktop/adversarial-taboo/modules/notebooks/lightning_logs





2023-02-28 08:59:56,050 - pytorch_lightning.callbacks.model_summary - INFO
Msg: 
  | Name  | Type                        | Params
------------------------------------------------------
0 | model | MT5ForConditionalGeneration | 300 M 
------------------------------------------------------
300 M     Trainable params
0         Non-trainable params
300 M     Total params
1,200.707 Total estimated model params size (MB)



Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


TypeError: Operation 'abs_out_mps()' does not support input type 'int64' in MPS backend.

In [56]:
train_df.iloc[0]

title                                                 Normans
context     Rollo'nun gelişinden önce popülasyonları Picar...
question    Kim geldiğinde orijinal viking yerleşimcilerin...
cloze       Rollo'nun gelişinden önce popülasyonları Picar...
answer                                                  Rollo
Name: 0, dtype: object