In [None]:
! pip install pytorch-lightning --quiet
! pip install transformers

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/3a/83/e74092e7f24a08d751aa59b37a9fc572b2e4af3918cb66f7766c3affb1b4/transformers-3.5.1-py3-none-any.whl (1.3MB)
[K     |████████████████████████████████| 1.3MB 9.8MB/s 
Collecting sentencepiece==0.1.91
[?25l  Downloading https://files.pythonhosted.org/packages/d4/a4/d0a884c4300004a78cca907a6ff9a5e9fe4f090f5d95ab341c53d28cbc58/sentencepiece-0.1.91-cp36-cp36m-manylinux1_x86_64.whl (1.1MB)
[K     |████████████████████████████████| 1.1MB 31.4MB/s 
Collecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)
[K     |████████████████████████████████| 890kB 52.0MB/s 
[?25hCollecting tokenizers==0.9.3
[?25l  Downloading https://files.pythonhosted.org/packages/4c/34/b39eb9994bc3c999270b69c9eea40ecc6f0e97991dba28282b9fd32d44ee/tokenizers-0.9.3-cp36-cp36m-manylinux1_x86_64.whl (2.9MB)
[K  

In [None]:
import os
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.metrics import Accuracy 
from pytorch_lightning.metrics.functional.classification import to_categorical 
from torch.utils.data import random_split
from sklearn.metrics import classification_report , confusion_matrix

import torchtext
from torchtext import  data , vocab
from torchtext.datasets import SST

from transformers import BertTokenizer, BertModel


In [None]:
class BERTSentimentClassifier(pl.LightningModule):
    """BERT encoder with a small MLP head for sentence classification.

    The pooled output of ``bert-base-uncased`` is passed through
    ELU -> dropout -> Linear(hidden, 256) -> ReLU -> dropout -> Linear(256, 64)
    -> ReLU -> Linear(64, num_classes); ``forward`` returns raw logits.

    Args:
        num_classes: Size of the output layer (number of target classes).
        learning_rate: Learning rate handed to AdamW.
        weight_decay: Weight decay handed to AdamW.

    Batches are expected to expose ``.text`` (token ids) and ``.label``.
    Labels are shifted by -1 before the loss is computed.
    NOTE(review): this assumes label index 0 is reserved (e.g. for ``<unk>``)
    by whatever builds the label vocabulary -- confirm against the data module.

    The step / epoch-end hooks return the dict format with ``'progress_bar'``
    and ``'log'`` keys. The recorded run prints a hint to use ``self.log(...)``
    instead; the dict format is kept here so behaviour matches that run.
    """

    def __init__(self , num_classes , learning_rate=1e-5 , weight_decay=1e-5 ):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.hidden1 = nn.Linear(self.bert.config.hidden_size , 256)
        self.hidden2 = nn.Linear(256,64)
        self.out = nn.Linear(64, num_classes)
        self.dropout = nn.Dropout(0.1)
        self.learning_rate = learning_rate
        self.weight_decay  = weight_decay

    def forward(self,x):
        """Return class logits for a batch of token-id sequences ``x``.

        ``x`` is assumed to be a (batch, seq_len) integer tensor padded with the
        model's pad token id -- TODO confirm against the data pipeline.
        """
        # Mask out padding so BERT does not attend to [PAD] positions.
        pad_id = self.bert.config.pad_token_id
        attention_mask = None if pad_id is None else (x != pad_id).long()
        # Index 1 is the pooled output for both tuple and ModelOutput returns,
        # so this does not depend on the transformers return type.
        x = self.bert(x, attention_mask=attention_mask)[1]
        x = F.elu(x ,alpha=0.2)
        x = self.dropout(x)
        x = F.relu(self.hidden1(x))
        x = self.dropout(x)
        x = F.relu(self.hidden2(x))
        x = self.out(x)
        return x

    def configure_optimizers(self):
        """Build the AdamW optimizer over all parameters (BERT and head)."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate , weight_decay=self.weight_decay)
        return optimizer

    def _shared_step(self, batch, acc_key):
        """Compute loss and batch accuracy; accuracy is stored under ``acc_key``.

        Returns a dict ``{'loss': ..., 'progress_bar': {acc_key: ...}}``.
        """
        x, y = batch.text, batch.label-1
        logits = self(x)
        loss = F.cross_entropy(logits,y)
        # Accuracy is a metric only: detach from the graph instead of re-wrapping
        # the tensors with torch.tensor(...), which copies and emits a warning.
        preds = logits.detach().argmax(dim=1)
        # Kept on CPU, as before, so the per-batch values stack on one device.
        acc = (preds == y).float().mean().cpu()
        return {'loss' : loss , 'progress_bar': {acc_key: acc}}

    @staticmethod
    def _epoch_summary(step_outputs, prefix, acc_key):
        """Average per-batch loss/accuracy into ``{prefix}_loss`` / ``{prefix}_acc``.

        The result is an unweighted mean over batches (a smaller final batch
        counts the same as a full one).
        """
        avg_loss = torch.stack([out['loss'] for out in step_outputs]).mean()
        avg_acc = torch.stack([out['progress_bar'][acc_key] for out in step_outputs]).mean()
        metrics = {f'{prefix}_loss': avg_loss , f'{prefix}_acc': avg_acc}
        return {
                'progress_bar': metrics,
                'log': dict(metrics),
                }

    def training_step(self,batch,batch_idx):
        """One optimisation step; reports batch accuracy as ``training_acc``."""
        return self._shared_step(batch, 'training_acc')

    def training_epoch_end(self , train_step_outputs):
        """Aggregate training batches into ``train_loss`` / ``train_acc``."""
        return self._epoch_summary(train_step_outputs, 'train', 'training_acc')

    def validation_step(self , batch , batch_idx):
        """Evaluate one validation batch; reports batch accuracy as ``val_acc``."""
        # Only 'val_acc' is reported; the training key no longer leaks into
        # the validation progress bar.
        return self._shared_step(batch, 'val_acc')

    def validation_epoch_end(self , val_step_outputs):
        """Aggregate validation batches into ``val_loss`` / ``val_acc``."""
        return self._epoch_summary(val_step_outputs, 'val', 'val_acc')

    def test_step(self , batch , batch_idx):
        """Evaluate one test batch; reports batch accuracy as ``test_acc``."""
        return self._shared_step(batch, 'test_acc')

    def test_epoch_end(self , test_step_outputs):
        """Aggregate test batches into ``test_loss`` / ``test_acc``."""
        return self._epoch_summary(test_step_outputs, 'test', 'test_acc')
    

In [None]:
  class SSTDataModule(pl.LightningDataModule):
    """Data module for the SST dataset, tokenised with the BERT word-piece tokenizer.

    Args:
        batch_size: Batch size stored for the iterators.
        fine_grained: If true, ``num_classes`` is 5; otherwise it is 3.
        gpu: Stored as ``self.gpu`` for later device selection.

    The constructor loads the tokenizer and the three SST splits. The iterators
    returned by the ``*_dataloader`` methods (``train_iter`` / ``val_iter`` /
    ``test_iter``) are not created here; they must be set by ``setup``.
    """

    def __init__(self , batch_size=64 , fine_grained = True , gpu=0):
        super().__init__()
        self.batch_size = batch_size
        self.fine_grained = fine_grained
        self.gpu = gpu

        bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        # Leave room for the two special tokens added around each sequence
        # (init_token / eos_token below).
        max_tokens = bert_tokenizer.max_model_input_sizes['bert-base-uncased'] - 2

        def tokenize_and_cut(sentence):
            """Word-piece tokenise ``sentence`` and truncate it to ``max_tokens``."""
            return bert_tokenizer.tokenize(sentence)[:max_tokens]

        # use_vocab=False: token strings are mapped to ids by the tokenizer
        # itself, so the special tokens are given as ids as well.
        self.TEXT = data.Field(
            batch_first = True,
            use_vocab = False,
            tokenize = tokenize_and_cut,
            preprocessing = bert_tokenizer.convert_tokens_to_ids,
            init_token = bert_tokenizer.cls_token_id,
            eos_token = bert_tokenizer.sep_token_id,
            pad_token = bert_tokenizer.pad_token_id,
            unk_token = bert_tokenizer.unk_token_id,
        )
        self.LABEL = data.Field(sequential=False)

        self.num_classes = 5 if self.fine_grained else 3
        print(f"-- num classes = {self.num_classes}")

        splits = SST.splits(self.TEXT, self.LABEL , fine_grained=self.fine_grained)
        self.train, self.val, self.test = splits

    def train_dataloader(self):
        """Return the training iterator (``self.train_iter``)."""
        return self.train_iter

    def val_dataloader(self):
        """Return the validation iterator (``self.val_iter``)."""
        return self.val_iter

    def test_dataloader(self):
        """Return the test iterator (``self.test_iter``)."""
        return self.test_iter

def setup_BERT(self,stage=None):
    """Build the label vocabulary and the train/val/test bucket iterators.

    Intended to be used as the ``setup`` method of the data module.

    Args:
        stage: ``None`` always (re)builds everything. Any other value (for
            example ``'fit'`` or ``'test'``) builds the iterators only if they
            do not exist yet, so repeated calls with a stage are cheap and a
            call with a stage no longer leaves the module without iterators.

    Side effects:
        Sets ``self.train_iter``, ``self.val_iter`` and ``self.test_iter``,
        builds ``self.LABEL.vocab`` and prints the label mapping.
    """
    iterator_names = ("train_iter", "val_iter", "test_iter")
    already_built = all(hasattr(self, name) for name in iterator_names)
    if stage is not None and already_built:
        return

    self.LABEL.build_vocab(self.train)
    print(f"Label : {self.LABEL.vocab.stoi}")

    # Use CUDA only when it was requested (gpu > 0) and is actually available.
    device = 'cuda' if self.gpu > 0 and torch.cuda.is_available() else 'cpu'

    self.train_iter, self.val_iter , self.test_iter = data.BucketIterator.splits(
        (self.train,self.val, self.test),
        batch_size=self.batch_size,
        sort_within_batch = True,
        device=device
    )

# Install setup_BERT as the data module's `setup` hook; this must happen before the
# instance is created and before `setup` is called.
setattr(SSTDataModule, "setup", setup_BERT)

# Build the data module and prepare vocabulary and iterators up front.
dm = SSTDataModule(gpu=1, batch_size=32)
dm.setup()

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…


-- num classes = 5
downloading trainDevTestTrees_PTB.zip


trainDevTestTrees_PTB.zip: 100%|██████████| 790k/790k [00:00<00:00, 2.53MB/s]


extracting
Label : defaultdict(<function _default_unk_index at 0x7ff831c6b268>, {'<unk>': 0, 'positive': 1, 'negative': 2, 'neutral': 3, 'very positive': 4, 'very negative': 5})


In [None]:
# Fix the random seeds so the run is repeatable.
pl.seed_everything(1234)

# Hyper-parameters for the classifier; the keys match the constructor's argument names.
hparam = {
    "num_classes": dm.num_classes,
    "learning_rate": 5e-5,
    "weight_decay": 1e-3,
}

model = BERTSentimentClassifier(**hparam)

# Train for 5 epochs on one GPU with 16-bit precision and gradient clipping,
# then evaluate on the test split.
trainer = pl.Trainer(
    gpus=1,
    precision=16,
    max_epochs=5,
    gradient_clip_val=0.3,
    progress_bar_refresh_rate=50,
)
trainer.fit(model, dm)
trainer.test(model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Using native 16bit precision.

  | Name    | Type      | Params
--------------------------------------
0 | bert    | BertModel | 109 M 
1 | hidden1 | Linear    | 196 K 
2 | hidden2 | Linear    | 16.4 K
3 | out     | Linear    | 325   
4 | dropout | Dropout   | 0     


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



Please use self.log(...) inside the lightningModule instead.

# log on a step or aggregate epoch metric to the logger and/or progress bar
# (inside LightningModule)
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
Please use self.log(...) inside the lightningModule instead.

# log on a step or aggregate epoch metric to the logger and/or progress bar
# (inside LightningModule)
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…



--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': tensor(0.5308), 'test_loss': tensor(1.3840, device='cuda:0')}
--------------------------------------------------------------------------------





[{'test_acc': 0.5308035612106323, 'test_loss': 1.3839584589004517}]