# About this notebook
Copied from [Y.Nakama's Notebook](https://www.kaggle.com/yasufuminakama/jigsaw4-luke-base-starter-sub)
Thanks a lot!!

- Add data
    - Measuring Hate Speech [André Moura's dataset](https://www.kaggle.com/andre112/measuring-hate-speech)
    - toxic public dataframes [Deep Learner's dataset](https://www.kaggle.com/readoc/toxic-public-dataframes)

CV: 0.8050, PrivateLeaderboard: 0.8011

Below are Nakama's references:
- [Luke](https://arxiv.org/pdf/2010.01057v1.pdf) - base starter notebook
- [Inference notebook](https://www.kaggle.com/yasufuminakama/jigsaw4-luke-base-starter-sub)
- Approach References
    - https://www.kaggle.com/c/jigsaw-toxic-severity-rating/discussion/286471
    - https://www.kaggle.com/debarshichanda/pytorch-w-b-jigsaw-starter
    - https://www.kaggle.com/debarshichanda/0-816-jigsaw-inference
    - Thanks for sharing @debarshichanda

# Directory settings

In [None]:
# ====================================================
# Directory settings
# ====================================================
import os

# Logs, the saved tokenizer, checkpoints and the OOF csv are all written
# under this prefix (it is concatenated directly with file names).
OUTPUT_DIR = './'
# exist_ok=True replaces the separate os.path.exists() check, which left a
# window between the check and the creation.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# CFG

In [None]:
# ====================================================
# CFG
# ====================================================
class CFG:
    """Static experiment configuration, read everywhere as class attributes."""

    # ---- bookkeeping ----
    competition = 'Jigsaw4-2'
    _wandb_kernel = 'HideBu'
    debug = False                  # debug switch checked by the data-loading cell
    train = True                   # gate for the fold training loop in the main cell
    seed = 42
    print_freq = 50                # training progress is printed every N steps

    # ---- model ----
    model = "studio-ousia/luke-base"
    fc_dropout = 0.2               # dropout applied in front of the linear head
    target_size = 1                # output width of the linear head
    text = "text"                  # text column read by TestDataset
    target = "target"

    # ---- tokenisation ----
    # Sequences longer than max_len keep their first `head` and last `tail` tokens.
    head = 32
    tail = 32
    max_len = head + tail

    # ---- optimisation ----
    apex = True                    # enables torch.cuda.amp autocast / GradScaler
    epochs = 3
    batch_size = 64
    num_workers = 4
    encoder_lr = 1e-5
    decoder_lr = 1e-5
    min_lr = 1e-6
    eps = 1e-6
    betas = (0.9, 0.999)
    weight_decay = 0.01
    gradient_accumulation_steps = 1
    max_grad_norm = 1000
    margin = 0.5                   # margin passed to nn.MarginRankingLoss

    # ---- learning-rate schedule ----
    scheduler = 'cosine'           # ['linear', 'cosine']
    batch_scheduler = True         # step the scheduler after every optimizer step
    num_cycles = 0.55
    num_warmup_steps = 0

    # ---- cross-validation ----
    n_fold = 7
    trn_fold = [0, 1, 2, 3, 4, 5, 6]

In [None]:
# ====================================================
# wandb
# ====================================================
import wandb

try:
    from kaggle_secrets import UserSecretsClient
    user_secrets = UserSecretsClient()
    # get_secret() is given the *label* of a Kaggle secret. The label used
    # here is the one the hint printed below asks the user to create.
    # NOTE(review): the previous argument was a 40-character hex string that
    # looks like a raw access token rather than a label - if it was a real
    # token it should be revoked.
    secret_value_0 = user_secrets.get_secret("wandb_api")
    wandb.login(key=secret_value_0)
    anony = None
except Exception:
    # Best effort: without a usable secret (or outside Kaggle) the run falls
    # back to anonymous mode instead of aborting the notebook.
    anony = "must"
    print('If you want to use your W&B account, go to Add-ons -> Secrets and provide your W&B access token. Use the Label name as wandb_api. \nGet your W&B access token from here: https://wandb.ai/authorize')

    
def class2dict(f):
    """Return the attributes of ``f`` whose names do not start with ``__``.

    The result maps attribute name to attribute value, in ``dir(f)`` order.
    """
    kept_names = [attr for attr in dir(f) if not attr.startswith('__')]
    return {attr: getattr(f, attr) for attr in kept_names}

# Open the W&B run; the CFG attributes are uploaded as the run config and the
# model name doubles as run name and group.
run = wandb.init(
    project='Jigsaw4-Public2',
    job_type="train",
    name=CFG.model,
    group=CFG.model,
    config=class2dict(CFG),
    anonymous=anony,
)

# Library

In [None]:
# ====================================================
# Library
# ====================================================
import os
import gc
import re
import sys
import json
import time
import math
import string
import pickle
import random
import joblib
import itertools
import warnings
warnings.filterwarnings("ignore")

import scipy as sp
import numpy as np
import pandas as pd
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 1000)
from tqdm.auto import tqdm
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedKFold, GroupKFold, KFold

import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch.optim import Adam, SGD, AdamW
from torch.utils.data import DataLoader, Dataset

# Remove the preinstalled packages, then reinstall pinned versions from wheel
# files shipped as input datasets (--no-index: pip only looks in the local dir).
for _pkg in ('transformers', 'tokenizers', 'huggingface_hub'):
    os.system(f'pip uninstall -q {_pkg} -y')

for _cache_dir, _wheel, _pkg in (
    ('/tmp/pip/cache-tokenizers/',
     '../input/tokenizers-0103/tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl',
     'tokenizers'),
    ('/tmp/pip/cache-huggingface-hub/',
     '../input/huggingface-hub-008/huggingface_hub-0.0.8-py3-none-any.whl',
     'huggingface_hub'),
    ('/tmp/pip/cache-transformers/',
     '../input/transformers-470/transformers-4.7.0-py3-none-any.whl',
     'transformers'),
):
    os.system(f'mkdir -p {_cache_dir}')
    os.system(f'cp {_wheel} {_cache_dir}')
    os.system(f'pip install -q --no-index --find-links {_cache_dir} {_pkg}')

import tokenizers
import transformers
# Report which versions were actually picked up after the reinstall above.
for _lib in (tokenizers, transformers):
    print(f"{_lib.__name__}.__version__: {_lib.__version__}")
from transformers import LukeTokenizer, LukeModel, LukeConfig
from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Utils

In [None]:
# ====================================================
# Utils
# ====================================================
def get_score(df):
    """Return the fraction of rows ranked correctly.

    A row counts as correct when its ``less_toxic_pred`` is strictly smaller
    than its ``more_toxic_pred``.
    """
    ranked_correctly = df['less_toxic_pred'] < df['more_toxic_pred']
    return int(ranked_correctly.sum()) / len(df)


def get_logger(filename=OUTPUT_DIR+'train'):
    """Return this module's logger, writing to the console and to a log file.

    Args:
        filename: Path prefix of the log file; ``.log`` is appended.

    Returns:
        The ``logging.Logger`` named after this module, set to INFO, with
        exactly one stream handler and one file handler attached. Both
        handlers emit the bare message.
    """
    from logging import getLogger, INFO, StreamHandler, FileHandler, Formatter
    logger = getLogger(__name__)
    logger.setLevel(INFO)
    # getLogger() hands back the same object on every call, so handlers added
    # by an earlier call (e.g. a re-run cell) would otherwise pile up and each
    # message would be emitted once per accumulated handler.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()
    handler1 = StreamHandler()
    handler1.setFormatter(Formatter("%(message)s"))
    handler2 = FileHandler(filename=f"{filename}.log")
    handler2.setFormatter(Formatter("%(message)s"))
    logger.addHandler(handler1)
    logger.addHandler(handler2)
    return logger

LOGGER = get_logger()

def seed_everything(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Also exports ``PYTHONHASHSEED`` with the same value.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True


seed_everything(seed=42)

In [None]:
# ====================================================
# text cleaning (add: 22-1-16)
# ====================================================
from bs4 import BeautifulSoup
def text_cleaning(text):
    '''
    Reduce a raw comment to space-separated letters and digits.

    Steps, in order:
    1. Delete URLs (http://..., https://..., www....).
    2. Parse the remainder as HTML and keep only its text content.
    3. Replace every character that is neither an ASCII letter nor a digit
       with a space (this is what drops punctuation and emojis).
    4. Collapse runs of spaces and trim both ends.

    text - Text piece to be cleaned.
    '''
    without_urls = re.sub(r'https?://\S+|www\.\S+', r'', text)
    plain = BeautifulSoup(without_urls, 'lxml').get_text()
    alnum_only = re.sub(r"[^a-zA-Z\d]", " ", plain)
    return re.sub(' +', ' ', alnum_only).strip()

In [None]:
# Weight of each label column in the combined severity score.
# An earlier variant (kept for reference) used 1.5 for severe_toxic and 1.0
# for every other label.
FEATURE_WTS = dict(
    severe_toxic=1.5,
    identity_hate=1.5,
    threat=1.5,
    insult=0.64,
    toxic=0.32,
    obscene=0.16,
)
PSEUDO_LABEL_WEIGHT = 0.033

# Label column names, in weight-table order.
FEATURES = [*FEATURE_WTS]
FEATURES

In [None]:
import re
import nltk
# Fetch the 'omw-1.4' NLTK resource before the lemmatizer below is used.
nltk.download('omw-1.4')
from nltk.stem import WordNetLemmatizer

# Single shared lemmatizer instance; my_cleaner() calls it for every token.
wordnet_lemmatizer = WordNetLemmatizer()
def replaceURL(text):
    """ Replaces url address with "url" and drops the "#" in front of hashtags """
    # Raw strings: the previous plain literals relied on the invalid escape
    # sequences "\." and "\s", which newer Python versions warn about.
    text = re.sub(r'((www\.[^\s]+)|(https?://[^\s]+))', 'url', text)
    text = re.sub(r'#([^\s]+)', r'\1', text)
    return text

def replaceAbbrev(text):
    """Expand lower-case English contractions and rewrite "FC"/"fc" as "FUCK".

    The substitutions run in the listed order; later rules see the output of
    earlier ones (e.g. "can't" is handled before the generic "n't").
    """
    substitutions = (
        (r"what's", "what is "),
        (r"\'ve", " have "),
        (r"can't", "cannot "),
        (r"n't", " not "),
        (r"i'm", "i am "),
        (r"\'re", " are "),
        (r"\'d", " would "),
        (r"\'ll", " will "),
        (r"\'scuse", " excuse "),
        (r"\'s", " "),
        (r"FC", "FUCK"),
        (r"fc", "FUCK"),
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    return text

def removeUnicode(text):
    """ Blanks out literal backslash-u hex escapes and every non-ASCII character """
    without_escapes = re.sub(r'(\\u[0-9A-Fa-f]+)', r' ', text)
    return re.sub(r'[^\x00-\x7f]', r' ', without_escapes)
def removeRepeatPattern(text):
    """Shorten stretched letters and collapse runs of spaces.

    A letter repeated three or more times at the end of a word is cut to two;
    a letter repeated four or more times inside a word is cut to three;
    two or more consecutive spaces become one.
    """
    rules = (
        (r'([a-zA-Z])\1{2,}\b', r'\1\1'),
        (r'([a-zA-Z])\1\1{2,}\B', r'\1\1\1'),
        (r'[ ]{2,}', ' '),
    )
    for pattern, replacement in rules:
        text = re.sub(pattern, replacement, text)
    return text

def replaceAtUser(text):
    """ Replaces "@user" with "atUser" """
    # Raw string: the previous plain literal relied on the invalid escape
    # sequence "\s", which newer Python versions warn about.
    text = re.sub(r'@[^\s]+', 'atUser', text)
    return text

def replaceMultiToxicWords(text):
    """Rewrite stretched or obfuscated spellings of a fixed set of words.

    The rules run in the listed order and each one sees the output of the
    previous one.
    """
    rules = (
        (r'(fuckfuck)', 'fuck fuck '),
        (r'(f+)( *)([u|*|_]+)( *)([c|*|_]+)( *)(k)+', 'fuck'),
        (r'(h+)(a+)(h+)(a+)', 'ha ha '),
        (r'(s+ *h+ *[i|!]+ *t+)', 'shit'),
        (r'\b(n+)(i+)(g+)(a+)\b', 'nigga'),
        (r'\b(n+)([i|!]+)(g+)(e+)(r+)\b', 'nigger'),
        (r'\b(d+)(o+)(u+)(c+)(h+)(e+)( *)(b+)(a+)(g+)\b', 'douchebag'),
        (r'([a|@][$|s][s|$])', 'ass'),
        (r'(\bfuk\b)', 'fuck'),
    )
    for pattern, replacement in rules:
        text = re.sub(pattern, replacement, text)
    return text

def removeNumbers(text):
    """ Removes digit runs at the start of the text or after a non-word character, then maps 5/1/0 to s/i/o """
    text = re.sub(r"(^|\W)\d+", " ", text)
    # Digits that survive (those embedded in words) are read as look-alike letters.
    for digit, letter in (("5", "s"), ("1", "i"), ("0", "o")):
        text = text.replace(digit, letter)
    return text
                  
def replaceMultiPunc(text):
    """Replace runs of four or more '!', '?' or '*'.

    '!' runs become ' mxm ', '?' runs become ' mqm ' and '*' runs become a
    single '*'. Shorter runs are left untouched.
    """
    rules = (
        (r'([!])\1\1{2,}', r' mxm '),
        (r'([?])\1\1{2,}', r' mqm '),
        (r'([*])\1\1{2,}', r'*'),
    )
    for pattern, replacement in rules:
        text = re.sub(pattern, replacement, text)
    return text


# Punctuation characters that my_cleaner() substitutes before tokenising.
separators = set('"%&\'()+,-./:;<=>@[\\]^_`{|}~')
# Every separator becomes a single space, except '&' which is spelled out.
replace_pun = {punc: ' ' for punc in separators}
replace_pun['&'] = ' and '

def my_cleaner(s):
    """Run the full cleaning pipeline on one comment and return lemmatised text.

    The regex-based normalisers run first, in a fixed order; then every
    separator character is substituted via ``replace_pun``; finally the text
    is tokenised with NLTK, each token is lemmatised, and the tokens are
    joined with single spaces. Lower-casing is left disabled.
    """
    pipeline = (
        replaceURL,
        removeUnicode,
        removeNumbers,
        replaceAbbrev,
        replaceMultiToxicWords,
        replaceMultiPunc,
        removeRepeatPattern,
    )
    for step in pipeline:
        s = step(s)

    # remove & replace punctuations
    for punc, substitute in replace_pun.items():
        s = s.replace(punc, substitute)

    # split into word tokens, lemmatise each one
    lemmas = [wordnet_lemmatizer.lemmatize(token) for token in nltk.tokenize.word_tokenize(s)]
    return ' '.join(lemmas)

# Data Loading

In [None]:
# ====================================================
# Data Loading
# ====================================================
# The competition's validation pairs (worker, less_toxic, more_toxic) serve as
# training data for the ranking model.
jigsaw_train = pd.read_csv('../input/jigsaw-toxic-severity-rating/validation_data.csv')
if CFG.debug:
    # Fixed: this line used the misspelled name `igsaw_train` on both sides,
    # so enabling CFG.debug raised NameError instead of subsampling.
    jigsaw_train = jigsaw_train.sample(n=100, random_state=CFG.seed).reset_index(drop=True)
test = pd.read_csv('../input/jigsaw-toxic-severity-rating/comments_to_score.csv')
submission = pd.read_csv('../input/jigsaw-toxic-severity-rating/sample_submission.csv')

print(jigsaw_train.shape)
display(jigsaw_train.head())
print(test.shape, submission.shape)
display(test.head())
display(submission.head())

In [None]:
is_duplicate = jigsaw_train.duplicated(subset=['less_toxic', 'more_toxic'])
print("Duplicates: {} %".format(round(is_duplicate.mean() * 100, 2)))

In [None]:
jigsaw_train[is_duplicate].head(6)

In [None]:
jigsaw_train.drop_duplicates(subset=['less_toxic', 'more_toxic'], keep='last', inplace=True)
jigsaw_train.reset_index(drop=True, inplace=True)

In [None]:
jigsaw_train[jigsaw_train.worker==313].head()

In [None]:
jigsaw_train.groupby("worker").count()

In [None]:
# ====================================================
# Old_Train(jigsaw-toxic-comment-classification-challenge)
# ====================================================
old_train = pd.read_csv('../input/d/julian3833/jigsaw-toxic-comment-classification-challenge/train.csv')
display(old_train.shape)
display(old_train.head(3))

In [None]:
old_train['y'] = 0
for feat, wt in FEATURE_WTS.items(): 
    old_train.y += wt*old_train[feat]
old_train.y = old_train.y/old_train.y.max()
display(old_train.head(3))

In [None]:
is_duplicate = old_train.duplicated(subset='comment_text')
print("Duplicates: {} %".format(round(is_duplicate.mean() * 100, 2)))

In [None]:
old_train.y.hist()

In [None]:
pos = old_train[old_train.y>0]
neg = old_train[old_train.y==0].sample(len(pos), random_state=201)
old_train = pd.concat([pos, neg])
display(old_train.shape)
display(old_train.head(3))
old_train.y.hist()

In [None]:
# ====================================================
# concat dataframes
# ====================================================
old_df = pd.DataFrame(columns=["worker", "less_toxic", "more_toxic"])
old_df["less_toxic"] = neg.reset_index(drop=True, inplace=False)["comment_text"]
old_df["more_toxic"] = pos.reset_index(drop=True, inplace=False)["comment_text"]

old_df.reset_index(drop=True, inplace=True)
old_df["worker"] = old_df.index//16 + 753
old_df

In [None]:
pos[pos.id=="0002bcb3da6cb337"]

In [None]:
# ====================================================
# measuring_hate_speech
# ====================================================
hspeech = pd.read_csv('../input/measuring-hate-speech/measuring_hate_speech.csv')
display(hspeech.shape)
display(hspeech.head(3))

In [None]:
# get mean scores for each comment_id
scores_dict = hspeech.groupby('comment_id')['hate_speech_score'].apply(np.mean).to_dict()

# drop duplicate comment_ids
hspeech = hspeech.drop_duplicates(subset='comment_id')
hspeech['hate_speech_score'] = hspeech['comment_id'].map(scores_dict)

In [None]:
hspeech['hate_speech_score'].plot.hist(bins=100, title='Hate speech scores')

In [None]:
hspeech = hspeech[['comment_id', 'text','hate_speech_score']]
hspeech.rename(columns={
    "comment_id":"id", 
    "text":"comment_text", 
    "hate_speech_score":"y"
},inplace=True)

hspeech.head(3)

In [None]:
print(hspeech.shape)
hspeech.y.hist()

In [None]:
# Build a balanced subset of the hate-speech data: `pos` holds every row with
# y < -2 and `neg` an equally sized random sample of the rows with y > 0.
# NOTE(review): `pos` later fills the "more_toxic" column and `neg` the
# "less_toxic" column. The Measuring Hate Speech dataset documentation
# describes hate_speech_score as "higher = more hateful", which would make
# these two selections inverted - TODO confirm against the dataset card
# before changing the thresholds.
pos = hspeech[hspeech.y<(-2)]
neg = hspeech[hspeech.y>0].sample(len(pos), random_state=201)
hspeech = pd.concat([pos, neg])
display(hspeech.shape)
display(hspeech.head(3))
hspeech.y.hist()

In [None]:
# ====================================================
# concat dataframes
# ====================================================
# Pair the i-th row of `neg` with the i-th row of `pos`.
hate_df = pd.DataFrame({
    "less_toxic": neg["comment_text"].reset_index(drop=True),
    "more_toxic": pos["comment_text"].reset_index(drop=True),
})
# Synthetic worker ids: every 16 consecutive rows share one id, starting at 1768.
# NOTE(review): 1768 is presumably chosen above the ids already in use - verify.
hate_df.insert(0, "worker", hate_df.index // 16 + 1768)
hate_df

In [None]:
train = pd.concat([jigsaw_train, old_df])
train = pd.concat([train, hate_df])

train.reset_index(drop=True, inplace=True)
train["less_toxic"] = train["less_toxic"].apply(my_cleaner)
train["more_toxic"] = train["more_toxic"].apply(my_cleaner)
train

In [None]:
del jigsaw_train, old_df, pos, neg, hspeech, hate_df
gc.collect()

# CV split

In [None]:
# ====================================================
# CV split
# ====================================================
# Grouping by `worker` keeps all rows of one worker id inside a single fold.
Fold = GroupKFold(n_splits=CFG.n_fold)
fold_ids = np.zeros(len(train), dtype=int)
for n, (trn_index, val_index) in enumerate(Fold.split(train, train, train['worker'])):
    # Each row is a validation row in exactly one split; that split is its fold.
    fold_ids[val_index] = n
train['fold'] = fold_ids
display(train.groupby('fold').size())

# tokenizer

In [None]:
# ====================================================
# tokenizer
# ====================================================
tokenizer = LukeTokenizer.from_pretrained(CFG.model, lowercase=True)
# Stored on CFG so prepare_input() can reach it through its cfg argument.
CFG.tokenizer = tokenizer
# A copy of the tokenizer files is written next to the other outputs.
tokenizer.save_pretrained(f"{OUTPUT_DIR}tokenizer/")

# Dataset

In [None]:
# ====================================================
# Dataset
# ====================================================
def prepare_input(text, cfg):
    """Tokenize ``text`` into fixed-length ``torch.long`` tensors.

    Args:
        text: The string to encode.
        cfg: Configuration object providing ``tokenizer``, ``max_len``,
            ``head`` and ``tail``.

    Returns:
        The tokenizer's output mapping with every value replaced by a 1-D
        ``torch.long`` tensor of length ``cfg.max_len``.

    When ``cfg.tail`` is 0 the tokenizer itself truncates and pads to
    ``cfg.max_len``. Otherwise a sequence longer than ``cfg.max_len`` keeps
    its first ``cfg.head`` and last ``cfg.tail`` entries, and every sequence
    is right-padded: ``input_ids`` with the pad token id, all other fields
    with 0.
    """
    if cfg.tail == 0:
        inputs = cfg.tokenizer.encode_plus(text,
                                           return_tensors=None,
                                           add_special_tokens=True,
                                           max_length=cfg.max_len,
                                           # padding='max_length' is the current spelling of
                                           # the deprecated pad_to_max_length=True.
                                           padding='max_length',
                                           truncation=True)
        for k, v in inputs.items():
            inputs[k] = torch.tensor(v, dtype=torch.long)
    else:
        inputs = cfg.tokenizer.encode_plus(text,
                                           return_tensors=None,
                                           add_special_tokens=True,
                                           truncation=True)
        for k, v in inputs.items():
            v = np.asarray(v)
            if len(v) > cfg.max_len:
                # Keep head + tail; the extra slice guards head + tail > max_len.
                v = np.hstack([v[:cfg.head], v[-cfg.tail:]])[:cfg.max_len]
            pad_value = cfg.tokenizer.pad_token_id if k == 'input_ids' else 0
            new_v = np.full(cfg.max_len, pad_value, dtype=np.int64)
            # Use the length *after* truncation: the pre-truncation length only
            # worked while head + tail happened to equal max_len.
            new_v[:len(v)] = v
            inputs[k] = torch.tensor(new_v, dtype=torch.long)
    return inputs


class TrainDataset(Dataset):
    """Pairs of (less toxic, more toxic) comments for ranking training."""

    def __init__(self, cfg, df):
        self.cfg = cfg
        # Missing comments are replaced by the literal string "none".
        self.less_toxic, self.more_toxic = (
            df[column].fillna("none").values
            for column in ('less_toxic', 'more_toxic')
        )

    def __len__(self):
        return len(self.less_toxic)

    def __getitem__(self, item):
        """Return (less-toxic inputs, more-toxic inputs, label) for one pair."""
        less_text = str(self.less_toxic[item])
        more_text = str(self.more_toxic[item])
        encoded_less = prepare_input(less_text, self.cfg)
        encoded_more = prepare_input(more_text, self.cfg)
        # The label is the constant 1 for every pair.
        return encoded_less, encoded_more, torch.tensor(1, dtype=torch.float)


class TestDataset(Dataset):
    """Single comments, read from the column named by ``cfg.text``."""

    def __init__(self, cfg, df):
        self.cfg = cfg
        # Missing comments are replaced by the literal string "none".
        column = df[cfg.text]
        self.text = column.fillna("none").values

    def __len__(self):
        return len(self.text)

    def __getitem__(self, item):
        """Return the encoded inputs of one comment."""
        return prepare_input(str(self.text[item]), self.cfg)

# Model

In [None]:
# ====================================================
# Model
# ====================================================
class CustomModel(nn.Module):
    """LUKE encoder followed by mean pooling, dropout and a linear head."""

    def __init__(self, cfg, config_path=None, pretrained=False):
        """Build the encoder and the head.

        cfg         - configuration object (model, fc_dropout, target_size).
        config_path - optional path to a config saved with torch.save; when
                      None the config is fetched for cfg.model.
        pretrained  - True loads the pretrained encoder weights; False builds
                      the encoder from the config alone.
        """
        super().__init__()
        self.cfg = cfg
        self.config = (
            LukeConfig.from_pretrained(cfg.model, output_hidden_states=True)
            if config_path is None
            else torch.load(config_path)
        )
        self.model = (
            LukeModel.from_pretrained(cfg.model, config=self.config)
            if pretrained
            else LukeModel(self.config)
        )
        self.fc_dropout = nn.Dropout(cfg.fc_dropout)
        self.fc = nn.Linear(self.config.hidden_size, cfg.target_size)

    def feature(self, inputs):
        """Return the encoder's first output averaged over dimension 1."""
        # No attention mask is applied to the average, so every position
        # along dimension 1 contributes equally.
        return self.model(**inputs)[0].mean(dim=1)

    def forward(self, inputs):
        """Return the head output for the pooled feature of ``inputs``."""
        pooled = self.feature(inputs)
        return self.fc(self.fc_dropout(pooled))

# Helper functions

In [None]:
# ====================================================
# Helper functions
# ====================================================
class AverageMeter(object):
    """Tracks the latest value and the running weighted mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the latest value, the average, the sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def asMinutes(s):
    """Format a duration given in seconds as '<minutes>m <seconds>s'."""
    whole_minutes = math.floor(s / 60)
    leftover_seconds = s - whole_minutes * 60
    return '%dm %ds' % (whole_minutes, leftover_seconds)


def timeSince(since, percent):
    """Return '<elapsed> (remain <estimate>)' for a partially finished task.

    since   - start timestamp as returned by time.time().
    percent - completed fraction of the task; the remaining time is
              extrapolated linearly from the elapsed time.
    """
    elapsed = time.time() - since
    estimated_total = elapsed / (percent)
    remaining = estimated_total - elapsed
    return '%s (remain %s)' % (asMinutes(elapsed), asMinutes(remaining))


def train_fn(fold, train_loader, model, criterion, optimizer, epoch, scheduler, device):
    """Run one training epoch over pairs of less/more toxic comments.

    Args:
        fold: Fold number, used only in the W&B metric names.
        train_loader: Yields ``(less_toxic_inputs, more_toxic_inputs, labels)``.
        model: Called once per side of every pair.
        criterion: Called as ``criterion(more_toxic_preds, less_toxic_preds, labels)``.
        optimizer: Optimizer stepped through the AMP ``GradScaler``.
        epoch: Zero-based epoch index, used only for progress output.
        scheduler: LR scheduler; stepped per optimizer step when
            ``CFG.batch_scheduler`` is set.
        device: Device the batch tensors are moved to.

    Returns:
        The batch-size-weighted mean of the per-batch losses of the epoch.
    """
    model.train()
    scaler = torch.cuda.amp.GradScaler(enabled=CFG.apex)
    losses = AverageMeter()
    start = time.time()
    # Reported as nan until the first optimizer step computes a real norm.
    grad_norm = float('nan')
    for step, (less_toxic_inputs, more_toxic_inputs, labels) in enumerate(train_loader):
        for k, v in less_toxic_inputs.items():
            less_toxic_inputs[k] = v.to(device)
        for k, v in more_toxic_inputs.items():
            more_toxic_inputs[k] = v.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)
        with torch.cuda.amp.autocast(enabled=CFG.apex):
            less_toxic_y_preds = model(less_toxic_inputs)
            more_toxic_y_preds = model(more_toxic_inputs)
            loss = criterion(more_toxic_y_preds, less_toxic_y_preds, labels)
        losses.update(loss.item(), batch_size)
        if CFG.gradient_accumulation_steps > 1:
            loss = loss / CFG.gradient_accumulation_steps
        scaler.scale(loss).backward()
        if (step + 1) % CFG.gradient_accumulation_steps == 0:
            # Gradients still carry the GradScaler's loss scale here; unscale
            # them first so the clipping threshold applies to the real norm.
            # Clipping is done once per optimizer step, on the accumulated
            # gradients, rather than on every micro-batch.
            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            if CFG.batch_scheduler:
                scheduler.step()
        # get_last_lr() reports the rate set by the latest scheduler step;
        # get_lr() is not meant to be called outside of step().
        current_lr = scheduler.get_last_lr()[0]
        if step % CFG.print_freq == 0 or step == (len(train_loader)-1):
            print('Epoch: [{0}][{1}/{2}] '
                  'Elapsed {remain:s} '
                  'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                  'Grad: {grad_norm:.4f}  '
                  'LR: {lr:.8f}  '
                  .format(epoch+1, step, len(train_loader),
                          remain=timeSince(start, float(step+1)/len(train_loader)),
                          loss=losses,
                          grad_norm=grad_norm,
                          lr=current_lr))
        wandb.log({f"[fold{fold}] loss": losses.val,
                   f"[fold{fold}] lr": current_lr})
    return losses.avg


def inference_fn(test_loader, model, device):
    """Return sigmoid-activated model outputs for every batch of the loader.

    The model is switched to eval mode and moved to ``device``; the per-batch
    outputs are concatenated into one NumPy array in loader order.
    """
    model.eval()
    model.to(device)
    batch_outputs = []
    for inputs in tqdm(test_loader, total=len(test_loader)):
        for name, tensor in inputs.items():
            inputs[name] = tensor.to(device)
        with torch.no_grad():
            y_preds = model(inputs)
        batch_outputs.append(y_preds.sigmoid().to('cpu').numpy())
    return np.concatenate(batch_outputs)

In [None]:
# ====================================================
# train loop
# ====================================================
def _attach_pair_preds(validation, valid_folds, preds):
    """Join per-text predictions onto both sides of every validation pair."""
    valid_folds['pred'] = preds
    # Columns left over from an earlier call must go first, otherwise the
    # merges below would create suffixed duplicates.
    validation = validation.drop(columns=['less_toxic_pred', 'more_toxic_pred'], errors='ignore')
    for side in ('less_toxic', 'more_toxic'):
        rename_cols = {CFG.text: side, 'pred': f'{side}_pred'}
        validation = validation.merge(valid_folds[[CFG.text, 'pred']].rename(columns=rename_cols),
                                      on=side, how='left')
    return validation


def train_loop(folds, fold):
    """Train on all folds except ``fold`` and validate on ``fold``.

    Args:
        folds: DataFrame with ``less_toxic``, ``more_toxic`` and ``fold`` columns.
        fold: Number of the fold held out for validation.

    Returns:
        The validation rows of ``fold`` with ``less_toxic_pred`` and
        ``more_toxic_pred`` columns taken from the best-scoring epoch.

    Side effects: writes ``config.pth`` and the best checkpoint of the fold
    under ``OUTPUT_DIR`` and logs metrics to the logger and to W&B.
    """

    LOGGER.info(f"========== fold: {fold} training ==========")

    # ====================================================
    # loader
    # ====================================================

    trn_idx = folds[folds['fold'] != fold].index
    val_idx = folds[folds['fold'] == fold].index

    train_folds = folds.loc[trn_idx].reset_index(drop=True)
    validation = folds.loc[val_idx].reset_index(drop=True)

    # Every distinct validation text is scored once and joined back onto the pairs.
    valid_folds = sorted(set(validation['less_toxic'].unique()) | set(validation['more_toxic'].unique()))
    # The column is named after CFG.text because TestDataset reads df[cfg.text].
    valid_folds = pd.DataFrame({CFG.text: valid_folds}).reset_index()

    train_dataset = TrainDataset(CFG, train_folds)
    valid_dataset = TestDataset(CFG, valid_folds)

    train_loader = DataLoader(train_dataset,
                              batch_size=CFG.batch_size,
                              shuffle=True,
                              num_workers=CFG.num_workers, pin_memory=True, drop_last=True)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=CFG.batch_size,
                              shuffle=False,
                              num_workers=CFG.num_workers, pin_memory=True, drop_last=False)

    # ====================================================
    # model & optimizer
    # ====================================================
    model = CustomModel(CFG, config_path=None, pretrained=True)
    torch.save(model.config, OUTPUT_DIR+'config.pth')
    model.to(device)

    def get_optimizer_params(model, encoder_lr, decoder_lr, weight_decay=0.0):
        """Build three parameter groups: decayed encoder, undecayed encoder, head."""
        no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
        optimizer_parameters = [
            {'params': [p for n, p in model.model.named_parameters() if not any(nd in n for nd in no_decay)],
             'lr': encoder_lr, 'weight_decay': weight_decay},
            {'params': [p for n, p in model.model.named_parameters() if any(nd in n for nd in no_decay)],
             'lr': encoder_lr, 'weight_decay': 0.0},
            {'params': [p for n, p in model.named_parameters() if "model" not in n],
             'lr': decoder_lr, 'weight_decay': 0.0}
        ]
        return optimizer_parameters

    optimizer_parameters = get_optimizer_params(model,
                                                encoder_lr=CFG.encoder_lr,
                                                decoder_lr=CFG.decoder_lr,
                                                weight_decay=CFG.weight_decay)
    optimizer = AdamW(optimizer_parameters, lr=CFG.encoder_lr, eps=CFG.eps, betas=CFG.betas)

    # ====================================================
    # scheduler
    # ====================================================
    def get_scheduler(cfg, optimizer, num_train_steps):
        """Return the warmup scheduler selected by ``cfg.scheduler``."""
        if cfg.scheduler == 'linear':
            return get_linear_schedule_with_warmup(
                optimizer, num_warmup_steps=cfg.num_warmup_steps, num_training_steps=num_train_steps
            )
        if cfg.scheduler == 'cosine':
            return get_cosine_schedule_with_warmup(
                optimizer, num_warmup_steps=cfg.num_warmup_steps, num_training_steps=num_train_steps, num_cycles=cfg.num_cycles
            )
        # An unknown name previously surfaced as an UnboundLocalError.
        raise ValueError(f"Unsupported scheduler: {cfg.scheduler!r}")

    # NOTE(review): this count ignores drop_last and gradient accumulation, so
    # it can exceed the number of scheduler steps actually taken.
    num_train_steps = int(len(train_folds) / CFG.batch_size * CFG.epochs)
    scheduler = get_scheduler(CFG, optimizer, num_train_steps)

    # ====================================================
    # loop
    # ====================================================
    criterion = nn.MarginRankingLoss(margin=CFG.margin)

    best_score = 0.
    checkpoint_path = OUTPUT_DIR+f"{CFG.model.replace('/', '-')}_fold{fold}_best.pth"

    for epoch in range(CFG.epochs):

        start_time = time.time()

        # train
        avg_loss = train_fn(fold, train_loader, model, criterion, optimizer, epoch, scheduler, device)

        # eval
        preds = inference_fn(valid_loader, model, device)

        # scoring
        validation = _attach_pair_preds(validation, valid_folds, preds)
        score = get_score(validation)

        elapsed = time.time() - start_time

        LOGGER.info(f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f}  time: {elapsed:.0f}s')
        LOGGER.info(f'Epoch {epoch+1} - Score: {score:.4f}')
        wandb.log({f"[fold{fold}] epoch": epoch+1,
                   f"[fold{fold}] avg_train_loss": avg_loss,
                   f"[fold{fold}] score": score})

        if score > best_score:
            best_score = score
            LOGGER.info(f'Epoch {epoch+1} - Save Best Score: {score:.4f} Model')
            # The predictions are stored with the weights so the best epoch's
            # out-of-fold predictions can be restored after the loop.
            torch.save({'model': model.state_dict(),
                        'preds': preds},
                        checkpoint_path)

    # Restore the predictions of the best epoch for the returned frame.
    preds = torch.load(checkpoint_path,
                       map_location=torch.device('cpu'))['preds']
    validation = _attach_pair_preds(validation, valid_folds, preds)

    torch.cuda.empty_cache()
    gc.collect()

    return validation

# Main

In [None]:
if __name__ == '__main__':

    def get_result(oof_df):
        """Log the pairwise ranking score of an out-of-fold frame."""
        LOGGER.info(f'Score: {get_score(oof_df):<.4f}')

    if CFG.train:
        # Train only the folds listed in CFG.trn_fold and accumulate their
        # out-of-fold predictions.
        oof_df = pd.DataFrame()
        selected_folds = [f for f in range(CFG.n_fold) if f in CFG.trn_fold]
        for fold in selected_folds:
            _oof_df = train_loop(train, fold)
            oof_df = pd.concat([oof_df, _oof_df])
            LOGGER.info(f"========== fold: {fold} result ==========")
            get_result(_oof_df)
        oof_df = oof_df.reset_index(drop=True)
        # Score over all trained folds together, then persist the predictions.
        LOGGER.info(f"========== CV ==========")
        get_result(oof_df)
        oof_df.to_csv(OUTPUT_DIR+'oof_df.csv', index=False)

    wandb.finish()