### Purpose
We will use a bert model to classify text

In [1]:
# default_exp scripts.kaggle.bert_classification

In [2]:
# export
import pandas as pd
import numpy as np
import matplotlib
import os

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.metrics import Accuracy

from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import json
from argparse import ArgumentParser

from transformers import BertTokenizer, BertModel
import logging
# logging.basicConfig(level=logging.DEBUG)

SEED = 2334
torch.manual_seed(SEED)
np.random.seed(SEED)


I0626 16:38:45.967020 4709492160 file_utils.py:39] PyTorch version 1.5.0 available.
I0626 16:38:49.066612 4709492160 file_utils.py:55] TensorFlow version 2.1.0 available.


#### Download and read the data

In [3]:
# !kaggle datasets download -d rmisra/news-category-dataset
# !unzip news-category-dataset.zip

In [10]:
# export
# split into train and validation


train.shape - (196832, 8) val.shape - (4016, 8)
max_len - 1492


In [11]:
# export
class TextDS(Dataset):
    def __init__(self, df, label_encoder, max_len):
        self.df = df
        self.le = label_encoder
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
        self.max_len = max_len
        
    def __len__(self):
        return self.df.shape[0]
    
    def __getitem__(self, idx):
        tokenized_text = self.tokenizer.tokenize(self.df.at[idx,'text'])
        
        # https://huggingface.co/transformers/glossary.html
        sequence_dict = self.tokenizer.encode_plus(tokenized_text, max_length=self.max_len, pad_to_max_length=True)
        cat_tensor = torch.tensor(self.le.transform([self.df.at[idx,'category']]))[0]
        return torch.tensor(sequence_dict['input_ids']), torch.tensor(sequence_dict['attention_mask']), cat_tensor

array([ 6, 10, 10, ..., 28, 28, 28])

In [12]:
# itr = iter(train_dl)

In [13]:
# next(itr)

In [68]:
# export
class LinBnDrop(nn.Sequential):
    "Module grouping `BatchNorm1d`, `Dropout` and `Linear` layers"

    def __init__(self, n_in, n_out, bn=True, p=0., act=None, lin_first=False):
        layers = [nn.BatchNorm1d(n_out if lin_first else n_in)] if bn else []
        if p != 0: layers.append(nn.Dropout(p))
        lin = [nn.Linear(n_in, n_out, bias=not bn)]
        if act is not None: lin.append(act)
        layers = lin + layers if lin_first else layers + lin
        super().__init__(*layers)

class BertFineTuner(LightningModule):
    def __init__(self, hparams, *args, **kwargs):
        self.hparams = hparams
#         print('num_classes: ', hparams.num_cls)
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-cased', output_attentions=True)
        self.loss_func = nn.CrossEntropyLoss()
        self.accuracy = Accuracy()
        
        # freeze all bert parameters
        for param in self.named_parameters():
            if param[0].startswith('bert'):
                param[1].requires_grad = False
                
    def prepare_data(self):
        # read data
        lines = []
        filepath = 'News_Category_Dataset_v2.json'
        with open(filepath) as f:
            for line in f.readlines():
                lines.append(json.loads(line))
        df = pd.DataFrame(lines)
        df['text'] = df['headline'] + ' ' + df['short_description']
        df['text'] = df['text'].str.strip()
        del lines

        # filter lines without headline
        df['text_len'] = df.text.str.len()
        df = df[df.text_len > 0].copy()

        # split into train and test
        num_val = int(df.shape[0] * 0.02)
        val_mask = np.random.choice(np.arange(df.shape[0]),num_val, replace=False)
        self.df_train = df[~df.index.isin(val_mask)].copy()
        self.df_val = df[df.index.isin(val_mask)].copy()
        print(f'train.shape - {self.df_train.shape} val.shape - {self.df_val.shape}')
        
        # TODO. avoid hardcoding
        if not self.hparams.max_len:
            self.max_len = df['text'].str.len().max() + 5
            print(f'max_len - {self.max_len}')
        
        self.le = LabelEncoder()
        self.le.fit_transform(df['category'])
        
        # define linear layer
        self.lin = LinBnDrop(self.bert.config.hidden_size, len(self.le.classes_))
        
    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--bsz', default=8, type=int, help='batch_size', )
        parser.add_argument('--val-bsz', default=8, type=int, help='batch_size', )
        parser.add_argument('--lr', default=0.001, type=int, help='batch_size', )
        parser.add_argument('--find-lr', default=1, type=int, help='1 to find lr' )
        parser.add_argument('--max-len', default=500, type=int, help='max len of text' )
        return parser
        
    def forward(self, input_ids, attention_mask):
        h, _, attn = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        h_cls = h[:,0]
        logits = self.lin(h_cls)
        return logits, attn
    
    def train_dataloader(self):
        return DataLoader(TextDS(self.df_train.reset_index(), self.le, self.hparams.max_len), batch_size=self.hparams.bsz, num_workers=8)
    
    def val_dataloader(self):
        return DataLoader(TextDS(self.df_val.reset_index(), self.le, self.hparams.max_len),batch_size=self.hparams.val_bsz, num_workers=8)
    
    def training_step(self, batch, batch_idx):
        input_ids, attention_mask, cat_idx = batch
        logits, attn = self.forward(input_ids, attention_mask)
        loss = self.loss_func(logits, cat_idx)
        
        # get training accuracy
        yhat = torch.argmax(logits, dim=1).detach()
        metrics = {'loss': loss, 'train_accuracy': self.accuracy(yhat,cat_idx )}
        
        # log every 10 steps
        if batch_idx%10 == 0:
            self.logger.log_metrics(metrics)
        return metrics
    
    def validation_step(self, batch, batch_idx):
        input_ids, attention_mask, cat_idx = batch
        logits, attn = self(input_ids, attention_mask)
        loss = self.loss_func(logits, cat_idx)
        yhat = torch.argmax(logits, dim=1)
        return {'val_loss': loss, 'yhat': yhat, 'y': cat_idx}
        
    def validation_epoch_end(self, outputs):
        val_loss_mean = torch.stack([o['val_loss'] for o in outputs]).mean()
        yhat_all = torch.stack([o['yhat'] for o in outputs])
        y_all = torch.stack([o['y'] for o in outputs])
        metrics = {'val_loss': val_loss_mean, 'val_accuracy': self.accuracy(yhat_all,y_all)}
        self.logger.log_metrics(metrics)
        return metrics

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=(self.hparams.lr))
        return optimizer

In [29]:
hparams = parser.parse_args('--bsz 64 --find-lr 0'.split())

#### Adjust params

### Find Learning Rate

In [26]:
# # export 
# trainer = Trainer(gpus=hparams.lr_gpu)
# model = BertFineTuner(hparams, finding_lr=True)

# # find learning rate
# if hparams.find_lr == 1:
#     train_dl = DataLoader(TextDS(df_train.reset_index(), le, max_len), batch_size=8, num_workers=8)
#     val_dl = DataLoader(TextDS(df_val.reset_index(), le, max_len), batch_size=8, num_workers=8)
#     lr_finder = trainer.lr_find(model, train_dataloader=train_dl,val_dataloaders=[val_dl] )
#     fig = lr_finder.plot(suggest=True)
#     fig.show()
#     new_lr = lr_finder.suggestion()
#     logging.info(f'new_lr is {new_lr}')
# else:
#     logging.info(f'Not finding lr')
#     new_lr = hparams.lr

In [None]:
# export
def main(hparams):
    model = BertFineTuner(hparams)

#     logging.info(f'hparams.gpu - {hparams.gpus}')
    trainer = Trainer(gpus=hparams.gpus,max_epochs=hparams.max_epochs,
                      logger=wandb_logger, val_check_interval=hparams.val_check_interval,
                     distributed_backend='ddp')
    trainer.fit(model)
    
if __name__ == '__main__':
    wandb_logger = WandbLogger(name='achinta',project='bert-text-cls')
    
    parser = ArgumentParser()
    parser = BertFineTuner.add_model_specific_args(parser)
    parser.add_argument('--gpus', type=int, default=0)
    parser.add_argument('--lr-gpu', default=0)
    parser.add_argument('--val-check-interval', type=int, default=100)
    parser.add_argument('--max-epochs', type=int, default=3)
    hparams = parser.parse_args()
    
    # set to cpu if no gpu is passed. Note that 0 is cpu and [0] is gpu-0
    if not hparams.gpus:
        hparams.gpus = 0
    print(f'hparams.gpus - {hparams.gpus}')
        
    main(hparams)

In [73]:
%%time
!nbdev_build_lib --fname 104-bert-text-classification.ipynb
!scp ../ml/scripts/kaggle/bert_classification.py rc:/n/home11/tatacomm/kaggle/datasets/news-category-dataset

Converted 104-bert-text-classification.ipynb.
bert_classification.py                        100% 7593     3.7KB/s   00:02    
CPU times: user 481 ms, sys: 142 ms, total: 624 ms
Wall time: 26.9 s


### Playground

In [None]:
df.headline.str.len().mean()

In [None]:
a = torch.randn(4, 3)

In [None]:
torch.argmax(a, dim=1)

In [57]:
SEED = 2334
torch.manual_seed(SEED)
np.random.seed(SEED)

np.random.choice(np.arange(100),3)

array([73, 17, 36])

5