In [18]:
import sys
# Make the local `jigsaw` package importable. The path is relative to the notebook's
# current working directory. NOTE(review): `from jigsaw...` after appending '../jigsaw'
# implies the package lives at ../jigsaw/jigsaw — confirm the repo layout.
sys.path.append('../jigsaw')

# Project-local datasets/loaders and the Lightning regression model used below.
from jigsaw.datasets import (
    PairedDataset, RegressionDataset, 
    get_paired_loader, get_regression_loader
)
from jigsaw.deep_models.lightning_models import RegressionModel

# NOTE(review): several names imported in this cell (e.g. PairedDataset,
# RegressionDataset, get_paired_loader, get_regression_loader, LightningDataModule,
# LightningModule, partial, defaultdict, tqdm, np) are not referenced in the cells
# shown in this notebook — verify before removing.
from pytorch_lightning import LightningDataModule, LightningModule
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import callbacks
import pytorch_lightning as pl

from transformers import AutoTokenizer
import torch.optim as optim

from functools import partial
from collections import defaultdict
from tqdm import tqdm
from box import Box
import pandas as pd
import numpy as np
import warnings
import wandb
import os

# Silences *all* warnings for the whole kernel session, including library
# deprecation warnings; consider narrowing this to specific categories/modules.
warnings.filterwarnings("ignore")

In [3]:
# Authenticate with Weights & Biases before any run is started; the cell output
# shows an existing login was reused and the call returned True.
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33msashnevskiy[0m (use `wandb login --relogin` to force relogin)


True

In [4]:
# Raw regression-style training data (version 1 of the prepared file).
regression_data_raw = pd.read_csv('data/jigsaw-regression/train_data_version1.csv')

# Post-processed regression data (version 2) — not very informative yet.
# Duplicate (text, y) rows are dropped without mutating in place, so re-running
# this cell is idempotent. ignore_index=True re-numbers rows 0..n-1, so a
# label-based row lookup downstream cannot hit gaps left by removed duplicates.
regression_data_postprocess = pd.read_csv(
    'data/jigsaw-regression/train_data_version2.csv'
).drop_duplicates(subset=['text', 'y'], ignore_index=True)

# Competition validation pairs; duplicated (less_toxic, more_toxic) pairs are dropped.
competition_data = pd.read_csv(
    'data/jigsaw-rate-severity/validation_data.csv'
).drop_duplicates(subset=['less_toxic', 'more_toxic'], ignore_index=True)

rudit_data = pd.read_csv('data/ruddit/Dataset/ruddit_with_text.csv')
toxic_comment_data = pd.read_csv('data/jigsaw-toxic-comment/train.csv')  # looks fine as-is

In [5]:
# Registry of every loaded frame, keyed by the name that `cfg.dataset.name` selects
# when the model is built.
data_dict = dict(
    regression_postprocess=regression_data_postprocess,
    regression_raw=regression_data_raw,
    competition_data=competition_data,
    rudit_data=rudit_data,
    toxic_comment_data=toxic_comment_data,
)

In [6]:
# Experiment configuration: the cells below read everything from `cfg`.
cfg = {
    'seed': 42,
    'logger': {
        'save_dir': 'models',   # root directory for checkpoints (see ModelCheckpoint dirpath)
        'project': 'Jigsaw',    # W&B project passed to wandb.init
        'log_model': True,
    },
    'dataset': {
        # One of: 'paired', 'regression'. 'regression_postprocess' is a (text, y)
        # frame, so it is declared as a regression dataset here rather than being
        # patched in a later cell.
        'type': 'regression',
        'name': 'regression_postprocess',
        # Column names of the paired (less/more toxic) data. Kept because
        # competition_data is also passed to the model — presumably for
        # validation; TODO confirm against RegressionModel.
        'more_toxic_col': 'more_toxic',
        'less_toxic_col': 'less_toxic',
        # Columns of the regression frame holding the input text and the target.
        'text_col': 'text',
        'target_col': 'y',
    },
    'model_name': 'roberta-base',
    'max_length': 128,
    'batch_size': 16,
    'acc_step': 1,          # presumably gradient-accumulation steps — TODO confirm
    'epoch': 5,             # passed to pl.Trainer as max_epochs
    'num_classes': 1,
    'margin': 0.5,          # presumably a ranking-loss margin — TODO confirm
    'optimizer': {
        # Optimizer given by name as a string; resolved inside RegressionModel
        # (not visible here).
        'name': 'optim.AdamW',
        'params': {
            'lr': 1e-5,
            'weight_decay': 1e-5,
        },
    },
    'scheduler': {
        'name': 'get_cosine_schedule_with_warmup',
        'params': {
            # A float, so presumably a fraction of total training steps — TODO confirm.
            'num_warmup_steps': 0.06,
        },
    },
    # Unpacked directly into pl.Trainer(**cfg.trainer).
    'trainer': {
        'gpus': 1,
        'auto_lr_find': False,
        'progress_bar_refresh_rate': 3,
        'fast_dev_run': False,
        'num_sanity_val_steps': 2,
        # 'overfit_batches': 1,  # uncomment for a quick overfit sanity check
        'resume_from_checkpoint': None,
    },
}

cfg = Box(cfg)
# The tokenizer object is stored on the config so it travels with `cfg` into RegressionModel.
cfg['tokenizer'] = AutoTokenizer.from_pretrained(cfg['model_name'])

In [7]:
# Global seed for reproducibility; the Trainer below also runs with deterministic=True.
# workers=True additionally makes DataLoader worker seeding reproducible
# (see Lightning's seed_everything documentation).
seed_everything(cfg.seed, workers=True)

Global seed set to 42


42

In [8]:
# Switch the dataset section to the regression setup: its type plus the frame
# columns holding the input text and the regression target.
regression_overrides = {'type': 'regression', 'text_col': 'text', 'target_col': 'y'}
for key, value in regression_overrides.items():
    cfg.dataset[key] = value

In [None]:
# Build the model on the frame selected by cfg.dataset.name; competition_data is
# passed as the third argument (presumably the validation set — TODO confirm).
train_frame = data_dict[cfg.dataset.name]
model = RegressionModel(cfg, train_frame, competition_data)

In [10]:
# Stop training once validation accuracy stops improving for 2 validation rounds.
# mode="max" is required: EarlyStopping defaults to mode="min", which would treat a
# *rising* val_acc as "no improvement" and stop training prematurely.
# (Name kept as-is because the training cell references `earystopping`.)
earystopping = EarlyStopping(monitor="val_acc", mode="max", patience=2)
lr_monitor = callbacks.LearningRateMonitor()
# Keep only the single best checkpoint by val_acc (despite the variable name, it does
# not track the loss). Checkpoints go to <save_dir>/<model_name>/<dataset name>/.
loss_checkpoint = callbacks.ModelCheckpoint(
    dirpath=os.path.join(cfg.logger.save_dir, cfg.model_name, cfg.dataset.name),
    filename=f"{cfg.model_name}",
    monitor="val_acc",
    save_top_k=1,
    mode="max",
    save_last=False,
)
# Read the flag from the config instead of duplicating the literal.
wandb_logger = WandbLogger(
    log_model=cfg.logger.log_model,
)

In [None]:
# Start the W&B run explicitly so its name encodes model + dataset. Lightning's
# WandbLogger (created in the previous cell) reuses an already-active wandb run.
wandb.init(project=cfg.logger.project, name=f'{cfg.model_name}_{cfg.dataset.name}')
# Report the best value of each metric in the run summary.
wandb.define_metric("val_acc", summary="max")
wandb.define_metric("val_loss", summary="min")

trainer = pl.Trainer(
    max_epochs=cfg.epoch,
    logger=wandb_logger,
    callbacks=[lr_monitor, loss_checkpoint, earystopping],
    deterministic=True,
    **cfg.trainer,
)

try:
    trainer.fit(model)
finally:
    # Close the run even if training raises, so re-running this cell in the same
    # kernel starts a fresh W&B run instead of appending to a stale one.
    wandb.finish()