In [1]:
import pandas as pd

import warnings
warnings.filterwarnings('ignore')

import transformers
from transformers import AutoTokenizer, AutoModel

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import pkbar

In [2]:
class ToxicData(Dataset):
    """Dataset of (less toxic, more toxic) comment pairs read from a CSV file.

    The CSV must provide the text columns ``less_toxic`` and ``more_toxic``.
    Each item is tokenised on the fly with the tokenizer given to the
    constructor.

    Args:
        path: Location of the CSV file, passed to ``pandas.read_csv``.
        tokenizer: Object exposing ``encode_plus`` (e.g. a Hugging Face
            tokenizer) that returns ``input_ids`` and ``attention_mask``.
        max_length: Length every sequence is padded / truncated to.
            Defaults to 128, the value previously hard-coded.
    """

    def __init__(self, path, tokenizer, max_length=128):
        super(ToxicData, self).__init__()

        self.dataframe = pd.read_csv(path)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of comment pairs (rows of the CSV)."""
        return len(self.dataframe)

    def _encode(self, text):
        """Tokenise one comment and return ``(input_ids, attention_mask)``."""
        # Use the tokenizer owned by this dataset, not a notebook-level global,
        # so the class works wherever it is instantiated.
        encoded = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            padding='max_length',
            max_length=self.max_length,
            truncation=True,
            return_tensors='pt',
        )
        # With return_tensors='pt' the tokenizer adds a leading dimension of
        # size 1; drop it so the DataLoader can stack items into a batch.
        return encoded['input_ids'].squeeze(0), encoded['attention_mask'].squeeze(0)

    def __getitem__(self, idx):
        """Return the tokenised pair at position ``idx`` as a dict of tensors.

        Keys: ``less_toxic_tokens``, ``less_toxic_attn_mask``,
        ``more_toxic_tokens``, ``more_toxic_attn_mask`` and ``targets``.
        """
        pair = self.dataframe.iloc[idx]

        more_toxic_tokens, more_toxic_attn_mask = self._encode(pair['more_toxic'])
        less_toxic_tokens, less_toxic_attn_mask = self._encode(pair['less_toxic'])

        # Scalar target of 1: the more toxic score must be the FIRST input to
        # MarginRankingLoss (use -1 if the argument order is ever swapped).
        targets = torch.tensor(1.0)

        return {
            'less_toxic_tokens': less_toxic_tokens,
            'less_toxic_attn_mask': less_toxic_attn_mask,
            'more_toxic_tokens': more_toxic_tokens,
            'more_toxic_attn_mask': more_toxic_attn_mask,
            'targets': targets,
        }

In [3]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# CSV holding the comment pairs that ToxicData reads.
dataPath = './data/validation_data.csv'

# Tokenizer belonging to the 'bert-base-cased' checkpoint.
tokenizer = transformers.AutoTokenizer.from_pretrained('bert-base-cased')

# Wrap the CSV in the pair dataset and serve it in shuffled mini-batches.
dataset = ToxicData(dataPath, tokenizer)
train_loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

In [4]:
class Model(nn.Module):
    """Pretrained encoder followed by dropout and a single-output linear head.

    The head maps the encoder's pooled representation to one score per input
    sequence.

    Args:
        model_name: Checkpoint name handed to ``AutoModel.from_pretrained``.
            Defaults to ``'bert-base-cased'``.
        dropout: Dropout probability applied before the linear head.
            Defaults to 0.5.
    """

    def __init__(self, model_name='bert-base-cased', dropout=0.5):
        super(Model, self).__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        # Read the encoder width from its config instead of hard-coding 768,
        # so checkpoints with a different hidden size work unchanged.
        self.fc = nn.Linear(self.bert.config.hidden_size, 1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attention_mask):
        """Score a batch of token-id sequences.

        Args:
            x: Token ids for the batch.
            attention_mask: Mask matching ``x`` that marks non-padding tokens.

        Returns:
            The output of the linear head: one score per sequence, with a
            trailing dimension of size 1.
        """
        encoder_out = self.bert(x, attention_mask=attention_mask, output_hidden_states=False)
        # Element 1 of the encoder output is the pooled sequence representation
        # for Hugging Face BERT models; integer indexing works for both the
        # tuple and the ModelOutput return styles.
        pooled = encoder_out[1]
        return self.fc(self.dropout(pooled))

In [5]:
%%capture
# The %%capture cell magic above must stay on the first line of the cell; it
# captures the output this cell would otherwise print.
# Earlier experiments that loaded the bare encoder directly (kept for reference):
# model = transformers.AutoModel.from_pretrained('bert-base-cased', cache_dir = './input/model/')
# model = transformers.AutoModel.from_pretrained('bert-base-cased')
# Build the scoring model and move its parameters to the selected device.
model = Model().to(device)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [6]:
def train_step(batch):
    """Run one optimisation step on a batch of comment pairs.

    Uses the notebook-level ``model``, ``criterion``, ``optimizer`` and
    ``device``.

    Args:
        batch: Dict with the keys ``less_toxic_tokens``,
            ``less_toxic_attn_mask``, ``more_toxic_tokens``,
            ``more_toxic_attn_mask`` and ``targets``.

    Returns:
        The batch loss as a Python float.
    """
    # Make sure dropout is active even if an earlier call left the model in
    # evaluation mode.
    model.train()
    optimizer.zero_grad(set_to_none=True)

    less_toxic_tokens = batch['less_toxic_tokens'].to(device)
    less_toxic_attn_mask = batch['less_toxic_attn_mask'].to(device)

    more_toxic_tokens = batch['more_toxic_tokens'].to(device)
    more_toxic_attn_mask = batch['more_toxic_attn_mask'].to(device)

    # Flatten scores and targets to 1-D so all three loss inputs share one
    # shape. Feeding (B, 1) scores with (B,) targets either broadcasts to a
    # (B, B) loss matrix or is rejected by PyTorch releases that require equal
    # dimensionality.
    targets = batch['targets'].to(device).reshape(-1)

    less_toxic_score = model(less_toxic_tokens, less_toxic_attn_mask).reshape(-1)
    more_toxic_score = model(more_toxic_tokens, more_toxic_attn_mask).reshape(-1)

    # targets == 1 means the first argument (more toxic) should score higher.
    batch_loss = criterion(more_toxic_score, less_toxic_score, targets)
    batch_loss.backward()
    optimizer.step()

    return batch_loss.item()

In [7]:
@torch.no_grad()
def val_step(dataloader):
    
    total_loss = 0
    
    for idx, batch in enumerate(dataloader):    

        less_toxic_tokens = batch['less_toxic_tokens'].to(device)
        less_toxic_attn_mask = batch['less_toxic_attn_mask'].to(device)
        
        more_toxic_tokens = batch['more_toxic_tokens'].to(device)
        more_toxic_attn_mask = batch['more_toxic_attn_mask'].to(device)
        
        targets = batch['targets'].to(device)

        less_toxic_score = model(less_toxic_tokens, less_toxic_attn_mask)
        more_toxic_score = model(more_toxic_tokens, more_toxic_attn_mask)

        batch_loss = criterion(more_toxic_score, less_toxic_score, targets)
        total_loss += batch_loss.item()
    
    return total_loss

In [8]:
# Pairwise ranking loss with a margin of 1.0 between the two scores it compares.
criterion = nn.MarginRankingLoss(margin = 1.0)
# AdamW over every parameter of `model` (encoder and head) at learning rate 3e-5.
optimizer = torch.optim.AdamW(model.parameters(), lr = 3e-5)

In [9]:
# Number of full passes over train_loader made by the training loop.
num_epochs = 20

In [None]:
# Number of mini-batches per epoch; used as the progress bar's target.
train_per_epoch = len(train_loader)

for epoch in range(num_epochs):

    # Emit a blank separator before each epoch's progress bar.
    print("\n")
    kbar = pkbar.Kbar(
        target=train_per_epoch,
        epoch=epoch,
        num_epochs=num_epochs,
        width=10,
        always_stateful=False,
    )

    for idx, batch in enumerate(train_loader):
        # One optimisation step, then report its loss on the bar.
        batch_loss = train_step(batch)
        kbar.update(idx, values=[("train_loss", batch_loss)])

    # Validation is currently switched off. NOTE(review): `dataloader` is not
    # defined in the visible cells; supply a validation loader before enabling.
    # val_loss = val_step(dataloader)
    # kbar.add(1, values = [("val_loss", val_loss)])



Epoch: 1/20

Epoch: 2/20

Epoch: 3/20

Epoch: 4/20

Epoch: 5/20

Epoch: 6/20

Epoch: 7/20

Epoch: 8/20
273/471 [====>.....] - ETA: 1:19 - train_loss: 0.4557