In [1]:
import pandas as pd

import warnings
warnings.filterwarnings('ignore')

import transformers
from transformers import AutoTokenizer, AutoModel

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger

import os

In [2]:
class ToxicData(Dataset):
    """Pairwise toxicity dataset backed by a CSV file.

    The CSV is read once with ``pd.read_csv`` and must contain the text
    columns ``less_toxic`` and ``more_toxic``. Every item is one row of that
    file, tokenized into fixed-length tensors.

    Args:
        path: Location of the CSV file.
        tokenizer: Object exposing a Hugging Face style ``encode_plus`` method.
        max_length: Length every text is padded / truncated to (default 128).
    """

    def __init__(self, path, tokenizer, max_length=128):
        super().__init__()

        self.dataframe = pd.read_csv(path)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of comment pairs (rows in the CSV)."""
        return len(self.dataframe)

    def _encode(self, text):
        """Tokenize one text and return ``(input_ids, attention_mask)``."""
        # Always go through the tokenizer handed to the constructor rather than
        # a notebook-level global, so the dataset works wherever it is built.
        encoded = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            padding='max_length',
            max_length=self.max_length,
            truncation=True,
            return_tensors='pt',
        )
        # return_tensors='pt' adds a leading dimension of size 1; drop it so
        # the DataLoader's default collation stacks items along a new batch axis.
        return encoded['input_ids'].squeeze(0), encoded['attention_mask'].squeeze(0)

    def __getitem__(self, idx):
        """Return the tokenized pair stored at row ``idx``.

        Returns:
            dict with keys ``less_toxic_tokens``, ``less_toxic_attn_mask``,
            ``more_toxic_tokens``, ``more_toxic_attn_mask`` and ``targets``.
        """
        pair = self.dataframe.iloc[idx]

        more_toxic_tokens, more_toxic_attn_mask = self._encode(pair['more_toxic'])
        less_toxic_tokens, less_toxic_attn_mask = self._encode(pair['less_toxic'])

        # Target of 1 means the FIRST input passed to the margin ranking loss
        # (the more toxic text) should score higher; use -1 for the opposite order.
        targets = torch.tensor(1.0)

        return {
            'less_toxic_tokens': less_toxic_tokens,
            'less_toxic_attn_mask': less_toxic_attn_mask,
            'more_toxic_tokens': more_toxic_tokens,
            'more_toxic_attn_mask': more_toxic_attn_mask,
            'targets': targets,
        }

In [3]:
# Seed Python, NumPy and torch (and DataLoader workers) up front: the Trainer
# is later asked for deterministic=True, which only gives repeatable runs when
# shuffling, dropout and weight initialisation start from a fixed seed.
SEED = 42
pl.seed_everything(SEED, workers=True)

# Used by the manual validation loop; Lightning handles device placement during fit.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

dataPath = './data/validation_data.csv'

tokenizer = transformers.AutoTokenizer.from_pretrained('bert-base-cased')

dataset = ToxicData(dataPath, tokenizer)
train_loader = DataLoader(
    dataset,
    batch_size=128,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

In [4]:
class LitBert(pl.LightningModule):
    """Pretrained encoder plus a one-unit linear head, trained as a pairwise ranker.

    ``forward`` maps token ids to one score per sequence; ``training_step``
    scores both texts of a pair and applies a margin ranking loss so that the
    more toxic text is pushed above the less toxic one.

    Args:
        model_name: Checkpoint passed to ``AutoModel.from_pretrained``.
        dropout: Dropout probability applied before the linear head.
        lr: Learning rate for AdamW.
        margin: Margin used by the ranking loss.
    """

    def __init__(self, model_name='bert-base-cased', dropout=0.5, lr=3e-5, margin=1.0):
        super().__init__()

        self.bert = AutoModel.from_pretrained(model_name)
        # Size the head from the encoder config instead of hard-coding 768
        # (the value for bert-base), so other checkpoints also work.
        self.fc = nn.Linear(self.bert.config.hidden_size, 1)
        self.dropout = nn.Dropout(dropout)

        self.lr = lr
        self.margin = margin

    def forward(self, x, attention_mask):
        """Score a batch of token-id sequences.

        Args:
            x: Token ids fed to the encoder as ``input_ids``.
            attention_mask: Attention mask matching ``x``.

        Returns:
            Output of the linear head applied to the encoder's pooled output
            (last dimension of size 1).
        """
        outputs = self.bert(input_ids=x, attention_mask=attention_mask, output_hidden_states=False)
        # Index 1 of the BertModel output is the pooled representation
        # (``pooler_output`` in the Hugging Face docs); indexing works for both
        # tuple and ModelOutput return types.
        pooled = self.dropout(outputs[1])
        return self.fc(pooled)

    def training_step(self, batch, batch_idx):
        """Compute and log the pairwise margin ranking loss for one batch."""
        less_toxic_tokens = batch['less_toxic_tokens']
        less_toxic_attn_mask = batch['less_toxic_attn_mask']
        more_toxic_tokens = batch['more_toxic_tokens']
        more_toxic_attn_mask = batch['more_toxic_attn_mask']

        # Call this module (self), not a notebook-level `model` global.
        # Flatten scores and targets to the same 1-D shape: the head emits one
        # column per sample while the targets are collated without that column,
        # and margin_ranking_loss expects all three tensors to line up.
        less_toxic_score = self(less_toxic_tokens, less_toxic_attn_mask).reshape(-1)
        more_toxic_score = self(more_toxic_tokens, more_toxic_attn_mask).reshape(-1)
        targets = batch['targets'].reshape(-1)

        loss = F.margin_ranking_loss(more_toxic_score, less_toxic_score, targets, margin=self.margin)

        self.log("train_loss", loss)

        return loss

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters (encoder and head)."""
        return torch.optim.AdamW(self.parameters(), lr=self.lr)

In [5]:
%%capture
# Instantiate the LightningModule (this loads the pretrained encoder weights).
# %%capture is meant to silence this cell's output. The resulting `model` is a
# notebook-level global that trainer.fit() and val_step() both use.
model = LitBert()

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [6]:
# Upper bound on training epochs; passed to pl.Trainer(max_epochs=...) below.
num_epochs = 20

In [7]:
# TensorBoard logs are written under the current working directory, using the
# run name 'lightning_logs' and a fixed version number of 1.
logger = TensorBoardLogger(
    save_dir=os.getcwd(),
    name='lightning_logs',
    version=1,
)

# Single-GPU, 16-bit mixed-precision training with deterministic ops requested.
trainer = pl.Trainer(
    logger=logger,
    max_epochs=num_epochs,
    gpus=1,
    precision=16,
    deterministic=True,
)

trainer.fit(model, train_loader)

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type      | Params
--------------------------------------
0 | bert    | BertModel | 108 M 
1 | fc      | Linear    | 769   
2 | dropout | Dropout   | 0     
--------------------------------------
108 M     Trainable params
0         Non-trainable params
108 M     Total params
216.622   Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

In [8]:
@torch.no_grad()
def val_step(dataloader):
    """Sum the per-batch margin ranking loss of the global ``model`` over ``dataloader``.

    Uses the notebook-level ``model`` and ``device``. Gradients are disabled by
    the decorator, and the model is switched to eval mode for the duration of
    the loop (its previous train/eval mode is restored afterwards).

    Args:
        dataloader: Iterable of batches shaped like ``ToxicData`` items
            (keys ``less_toxic_tokens``, ``less_toxic_attn_mask``,
            ``more_toxic_tokens``, ``more_toxic_attn_mask``, ``targets``).

    Returns:
        float: Sum of the mean loss of every batch (not divided by the number
        of batches); 0.0 for an empty dataloader.
    """
    total_loss = 0.0

    # Dropout must be inactive while validating, otherwise the scores are noisy.
    was_training = model.training
    model.eval()
    # The batches are moved to `device` below, so the model has to live there
    # too (it may not after training has finished — moving is a no-op if it does).
    model.to(device)

    try:
        for batch in dataloader:
            less_toxic_tokens = batch['less_toxic_tokens'].to(device)
            less_toxic_attn_mask = batch['less_toxic_attn_mask'].to(device)

            more_toxic_tokens = batch['more_toxic_tokens'].to(device)
            more_toxic_attn_mask = batch['more_toxic_attn_mask'].to(device)

            # Flatten scores and targets to one matching 1-D shape, as the
            # ranking loss expects all three tensors to line up element-wise.
            targets = batch['targets'].to(device).reshape(-1)
            less_toxic_score = model(less_toxic_tokens, less_toxic_attn_mask).reshape(-1)
            more_toxic_score = model(more_toxic_tokens, more_toxic_attn_mask).reshape(-1)

            # Same loss and margin as the training step; no separate
            # `criterion` object is defined in this notebook.
            batch_loss = F.margin_ranking_loss(more_toxic_score, less_toxic_score, targets, margin=1.0)
            total_loss += batch_loss.item()
    finally:
        model.train(was_training)

    return total_loss