In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
import torch
import numpy as np
import pandas as pd
from sklearn.model_selection import KFold
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm
# Hugging Face checkpoint name; it is passed to both the tokenizer loader and
# the model loader further down.
# Alternative left by the author: "distilbert-base-uncased"
model_checkpoint = "roberta-base"

In [None]:
# Load the pre-folded comparison pairs and drop repeated
# (less_toxic, more_toxic) pairs so each pair is seen only once.
df = pd.read_csv("../input/jrstc-folds/jrstc_5folds.csv")
df = df.drop_duplicates(subset=['less_toxic', 'more_toxic']).reset_index(drop=True)

# Fail fast on missing comments: string concatenation on a null would otherwise
# silently produce a null `text` that only blows up later, in the tokenizer.
assert df[['less_toxic', 'more_toxic']].notna().all().all(), "null comment text"

# Alternate the order of the two comments row by row so the classes are balanced:
#   even row -> "less</s>more", label 0
#   odd row  -> "more</s>less", label 1
# i.e. the label is 1 exactly when the more toxic comment comes first.
# Built with whole-column operations instead of a per-row `.loc` loop.
# NOTE(review): the two comments are joined with a single literal '</s>';
# RoBERTa's own sentence-pair encoding may use a different separator sequence -
# confirm this is intentional.
row_parity = np.arange(len(df)) % 2
less_first = df['less_toxic'] + '</s>' + df['more_toxic']
more_first = df['more_toxic'] + '</s>' + df['less_toxic']
df['text'] = less_first.where(row_parity == 0, more_first)
df['label'] = row_parity

# Shuffle with a fixed seed so the row order is identical on every run.
df = df.sample(frac=1, random_state=42).reset_index(drop=True)
df.head()

In [None]:
# Load the tokenizer published with `model_checkpoint`; use_fast=True requests
# the fast tokenizer implementation.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)

In [None]:
class JRSTCDataset(Dataset):
    """Tokenized view of one side of a k-fold split.

    Rows of ``df`` are selected by comparing the ``kfold`` column with ``fold``:
    the matching rows when ``is_val`` is true, every other row otherwise. The
    selected ``text`` values are tokenized once, up front, with truncation and
    without padding.
    """

    def __init__(self, tokenizer, df, fold, is_val=False):
        self.tokenizer = tokenizer

        # Build the membership mask once and pick the side of the split we need.
        in_fold = df["kfold"] == fold
        subset = df.loc[in_fold] if is_val else df.loc[~in_fold]

        self.texts = subset.text.tolist()
        self.labels = subset.label.tolist()
        self.encodings = self.tokenizer(self.texts, truncation=True, padding=False)

    def __len__(self):
        """Number of tokenized examples."""
        return len(self.encodings["input_ids"])

    def __getitem__(self, idx):
        """Return the tensors for example ``idx`` as a dict.

        Keys are ``input_ids``, ``attention_mask`` and ``labels``.
        """
        sample = {
            field: torch.tensor(self.encodings[field][idx])
            for field in ("input_ids", "attention_mask")
        }
        sample["labels"] = torch.tensor(self.labels[idx])
        return sample

In [None]:
def compute_metrics(eval_pred):
    """Compute classification accuracy for the Trainer's periodic evaluation.

    Args:
        eval_pred: the evaluation output handed over by ``Trainer``; it unpacks
            into ``(logits, labels)``.

    Returns:
        dict with a single ``'accuracy'`` entry: the fraction of examples whose
        arg-max logit equals the label.
    """
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {'accuracy': float((predictions == labels).mean())}


# Folds to hold out for validation, one training run per fold.
# range(1) trains only fold 0; the author's commented-out variant was range(5).
FOLDS_TO_TRAIN = range(1)

# Mixed precision needs a CUDA device; enabling it unconditionally makes the
# notebook fail on a CPU-only machine.
use_fp16 = torch.cuda.is_available()

for fold in FOLDS_TO_TRAIN:
    print('fold:', fold)

    # Fresh model per fold. The `label` column holds the two classes 0 and 1,
    # so the classification head size is stated explicitly.
    model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=2)

    # Train on every row outside `fold`, validate on the rows inside it.
    train_dataset = JRSTCDataset(tokenizer, df, fold, False)
    val_dataset = JRSTCDataset(tokenizer, df, fold, True)

    training_args = TrainingArguments(
        output_dir=f'./fold{fold}',      # checkpoints for this fold
        num_train_epochs=3,              # total number of training epochs
        per_device_train_batch_size=16,  # batch size per device during training
        per_device_eval_batch_size=64,   # batch size per device during evaluation
        warmup_ratio=0.1,                # fraction of training steps used for LR warmup
        weight_decay=0.01,               # strength of weight decay
        logging_dir=f'./fold{fold}',     # directory for storing logs
        fp16=use_fp16,
        report_to='none',                # no external experiment tracker
        save_total_limit=1,              # keep only the most recent checkpoint
        learning_rate=3e-5,
        seed=42,
        group_by_length=True,            # batch samples of similar length together
        # Save, evaluate and log on the same 100-step cadence.
        save_strategy='steps',
        save_steps=100,
        evaluation_strategy='steps',
        eval_steps=100,
        logging_strategy='steps',
        logging_steps=100
    )

    trainer = Trainer(
        model=model,                      # the instantiated Transformers model to be trained
        args=training_args,               # training arguments, defined above
        train_dataset=train_dataset,      # training dataset
        eval_dataset=val_dataset,         # evaluation dataset
        tokenizer=tokenizer,              # passed through to the Trainer alongside the datasets
        compute_metrics=compute_metrics   # report accuracy in addition to the eval loss
    )

    trainer.train()