## Setup

In [10]:
import numpy as np
import pandas as pd

import torch
from datasets import load_metric
from datasets import Dataset, DatasetDict
from transformers import TrainingArguments, Trainer
from transformers import AutoModelForSequenceClassification, AutoTokenizer

In [2]:
import logging
import warnings

warnings.simplefilter('ignore')
logging.disable(logging.WARNING)

## Loading Data

In [3]:
# Loading data
train_df = pd.read_csv('data/train_df.csv')
val_df = pd.read_csv('data/val_df.csv')
train_df.score = (train_df.score >= 0.5).astype(float)
val_df.score = (val_df.score >= 0.5).astype(float)
train_df.shape, val_df.shape

((27383, 4), (9090, 4))

## Loading Model & Tokenizer

In [5]:
model_ckpt = "distilbert-base-uncased"
# model_ckpt = "microsoft/deberta-v3-small"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)

In [6]:
def tok_func(x): return tokenizer(x["inputs"])

In [7]:
sep = " " + tokenizer.sep_token + " "
sep

' [SEP] '

## Preparing Data

In [17]:
def get_dds(train_df, val_df):
    train_ds = Dataset.from_pandas(train_df).rename_column('score', 'label')
    val_ds = Dataset.from_pandas(val_df).rename_column('score', 'label')

    inps = "anchor","target","context"
    train_ds = train_ds.map(tok_func, batched=True, remove_columns=inps+('inputs',))
    val_ds = val_ds.map(tok_func, batched=True, remove_columns=inps+('inputs',))

    dds = DatasetDict({"train":train_ds, "valid": val_ds})
    return dds

## Training Setup

In [12]:
metric = load_metric('accuracy')

def accuracy(eval_pred):
    predictions, labels = eval_pred
    predictions = np.where(predictions>= 0.5, 1, 0)
    return metric.compute(predictions=predictions, references=labels)

In [13]:
lr,bs = 8e-5, 128
wd,epochs = 0.01, 10

In [14]:
args = TrainingArguments('outputs', learning_rate=lr, warmup_ratio=0.1, lr_scheduler_type='cosine',
                         fp16=True, evaluation_strategy="epoch", 
                         per_device_train_batch_size=bs, per_device_eval_batch_size=int(bs*1.5),
                         num_train_epochs=epochs, weight_decay=wd, report_to='none')

## Training

### Without context

In [15]:
train_df['inputs'] = "TEXT1: " + train_df.anchor + " TEXT2: " + train_df.target
val_df['inputs'] = "TEXT1: " + val_df.anchor + " TEXT2: " + val_df.target

In [18]:
dds = get_dds(train_df, val_df)

                                                               

In [19]:
model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, num_labels=1)
trainer = Trainer(model, args, 
                  train_dataset=dds['train'], eval_dataset=dds['valid'], 
                  tokenizer=tokenizer, compute_metrics=accuracy)

In [20]:
# model : distilbert-base-uncased
trainer.train();

Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.166296,0.775578
2,No log,0.149712,0.793069
3,0.159800,0.161773,0.783938
4,0.159800,0.154818,0.791969
5,0.076300,0.163993,0.791419
6,0.076300,0.17908,0.787019
7,0.076300,0.174618,0.791419
8,0.040600,0.17953,0.790979
9,0.040600,0.180879,0.792079
10,0.026200,0.182508,0.791639


### With Context

In [21]:
train_df['inputs'] = "TEXT1: " + train_df.anchor + " TEXT2: " + train_df.target + " CONTEXT: " + train_df.target
val_df['inputs'] = "TEXT1: " + val_df.anchor + " TEXT2: " + val_df.target + " CONTEXT: " + val_df.target

In [22]:
dds = get_dds(train_df, val_df)

                                                               

In [23]:
model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, num_labels=1)
trainer = Trainer(model, args, 
                  train_dataset=dds['train'], eval_dataset=dds['valid'], 
                  tokenizer=tokenizer, compute_metrics=accuracy)

In [24]:
# model : distilbert-base-uncased
trainer.train();

Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.158723,0.780528
2,No log,0.150953,0.784928
3,0.157200,0.159375,0.789879
4,0.157200,0.163825,0.782398
5,0.070100,0.169011,0.787789
6,0.070100,0.176864,0.787679
7,0.070100,0.180012,0.789329
8,0.035200,0.184525,0.788229
9,0.035200,0.188165,0.786359
10,0.021000,0.188673,0.787569
