In [None]:
import re
import torchmetrics
import torch

import pandas as pd
import pytorch_lightning as pl

from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset


# Central hyper-parameter / path table read by the cells below.
config = dict(
    model_name='bert-base-uncased',        # checkpoint name handed to torch.hub for model and tokenizer
    learning_rate=1e-5,                    # Adam step size
    path='../input/sentiment-analysis/',   # directory prefix for train.csv / validation.csv
    max_seq_len=64,                        # tokenized length after padding/truncation
    batch_size=32,                         # DataLoader batch size
    num_workers=2,                         # DataLoader worker processes
    num_epochs=5,                          # upper bound on training epochs
    output_units=1,                        # number of labels requested from the classification head
    dropout=0.1,                           # hidden dropout probability passed to the model
)

In [None]:
class SentimentAnalyzer(pl.LightningModule):
    """LightningModule that fine-tunes a sequence-classification model for binary targets.

    The underlying network is fetched through ``torch.hub`` from the
    ``huggingface/pytorch-transformers`` entry point. Training and validation
    both use ``binary_cross_entropy_with_logits`` on the flattened model
    output, so the module is meant to be used with a single output unit.

    Args:
        model_name: Checkpoint identifier forwarded to ``torch.hub.load``.
        output_units: Forwarded as ``num_labels`` to the model constructor.
        dropout: Forwarded as ``hidden_dropout_prob`` to the model constructor.
        learning_rate: Optional Adam learning rate. When ``None`` (default),
            ``configure_optimizers`` falls back to the notebook-level
            ``config['learning_rate']`` exactly as before.
    """

    def __init__(self, model_name, output_units, dropout, learning_rate=None):
        super().__init__()
        # Kept as-is (possibly None); resolved lazily in configure_optimizers.
        self.learning_rate = learning_rate
        self.model = torch.hub.load(
            'huggingface/pytorch-transformers',
            'modelForSequenceClassification',
            model_name,
            num_labels=output_units,
            hidden_dropout_prob=dropout,
        )

        # Separate metric objects so train and validation statistics never mix.
        self.train_acc = torchmetrics.Accuracy()
        self.val_acc = torchmetrics.Accuracy()

    def forward(self, input_ids, attention_mask):
        """Run the wrapped model and return the first element of its output.

        NOTE(review): element 0 is presumably the classification logits; confirm
        against the transformers version in use.
        """
        return self.model(input_ids, attention_mask)[0]

    def _loss_and_update_acc(self, batch, accuracy):
        """Compute the BCE-with-logits loss for ``batch`` and update ``accuracy``."""
        input_ids, attention_mask, targets = batch
        logits = self(input_ids, attention_mask).view(-1)
        loss = F.binary_cross_entropy_with_logits(logits, targets)

        # The metric's default 0.5 threshold is a probability-scale cutoff, so
        # the raw logits must be squashed with a sigmoid before being scored.
        accuracy(torch.sigmoid(logits), targets.type(torch.int64))
        return loss

    def training_step(self, batch, batch_nb):
        """Return the training loss; logs ``train_loss`` and per-step ``train_acc``."""
        loss = self._loss_and_update_acc(batch, self.train_acc)
        self.log('train_loss', loss)
        self.log('train_acc', self.train_acc, on_step=True, on_epoch=False, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_nb):
        """Return the validation loss; logs epoch-level ``val_loss`` and ``val_acc``."""
        loss = self._loss_and_update_acc(batch, self.val_acc)
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('val_acc', self.val_acc, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Build an Adam optimizer over all parameters of this module."""
        lr = self.learning_rate if self.learning_rate is not None else config['learning_rate']
        return torch.optim.Adam(self.parameters(), lr=lr)

In [None]:
class BERTDataset(Dataset):
    """Map-style dataset that cleans and tokenizes the ``text`` column of a CSV.

    Each item is ``(input_ids, attention_mask)`` in test mode, or
    ``(input_ids, attention_mask, target)`` otherwise, where ``target`` is a
    float32 scalar tensor read from the ``target`` column.

    Args:
        file_name: Path of the CSV file to load.
        model_name: Tokenizer identifier forwarded to ``torch.hub.load``.
        max_seq_len: Length every encoded example is padded/truncated to.
        test: When true, no ``target`` column is read.
        max_rows: Maximum number of leading rows to load (default 10000, the
            previous hard-coded cap). ``None`` loads the whole file.
    """

    # Matches @mentions and http(s) URLs up to the next whitespace character.
    _MENTION_OR_URL = re.compile(r"(?:\@|https?\://)\S+")
    # Matches every character that is not an ASCII letter.
    _NON_LETTER = re.compile(r"[^a-zA-Z]")

    def __init__(self, file_name, model_name, max_seq_len, test=False, max_rows=10000):
        super().__init__()
        # nrows stops parsing after the cap instead of reading everything and slicing.
        self.df = pd.read_csv(file_name, nrows=max_rows)
        self.tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'tokenizer', model_name)
        self.max_seq_len = max_seq_len
        self.test = test

    def __len__(self):
        """Number of rows available."""
        return len(self.df)

    def _clean(self, text):
        """Strip mentions/URLs, replace non-letters with spaces and lowercase."""
        if not isinstance(text, str):
            # Empty CSV cells arrive as NaN; treat them as empty text rather
            # than letting re.sub raise TypeError.
            text = '' if pd.isna(text) else str(text)
        text = self._MENTION_OR_URL.sub("", text)
        text = self._NON_LETTER.sub(' ', text)
        return text.lower()

    def __getitem__(self, idx):
        """Return the encoded example at position ``idx``."""
        # Positional access keeps indexing consistent with __len__ even when
        # the frame's index labels are not 0..n-1.
        text = self._clean(self.df['text'].iloc[idx])
        encoded_input = self.tokenizer.encode_plus(
            text,
            padding='max_length',
            max_length=self.max_seq_len,
            add_special_tokens=True,
            truncation='longest_first',
        )
        input_ids = torch.tensor(encoded_input['input_ids'])
        attention_mask = torch.tensor(encoded_input['attention_mask'])

        if self.test:
            return input_ids, attention_mask

        target = self.df['target'].iloc[idx]
        return input_ids, attention_mask, torch.tensor(target, dtype=torch.float32)

In [None]:
# Build the training and validation datasets from the CSV files under config['path'].
train_csv = config['path'] + 'train.csv'
dev_csv = config['path'] + 'validation.csv'

train_dataset = BERTDataset(train_csv, config['model_name'], config['max_seq_len'])
dev_dataset = BERTDataset(dev_csv, config['model_name'], config['max_seq_len'])

# Only the training loader shuffles; validation keeps file order.
train_loader = DataLoader(
    train_dataset,
    batch_size=config['batch_size'],
    shuffle=True,
    num_workers=config['num_workers'],
)
dev_loader = DataLoader(
    dev_dataset,
    batch_size=config['batch_size'],
    shuffle=False,
    num_workers=config['num_workers'],
)

In [None]:
# Instantiate the model from the shared config table.
sentiment_analyzer = SentimentAnalyzer(
    model_name=config['model_name'],
    output_units=config['output_units'],
    dropout=config['dropout'],
)

# Both callbacks watch validation accuracy and treat higher as better.
checkpoint_cb = pl.callbacks.ModelCheckpoint(monitor='val_acc', dirpath='./', verbose=True, mode='max')
early_stop_cb = pl.callbacks.EarlyStopping(patience=2, monitor='val_acc', verbose=True, mode='max')
callbacks = [checkpoint_cb, early_stop_cb]

trainer = pl.Trainer(
    max_epochs=config['num_epochs'],
    callbacks=callbacks,
    gpus=1,
)
trainer.fit(sentiment_analyzer, train_loader, dev_loader)