### Purpose
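We fine-tune a pretrained BERT model (`bert-base-cased`) to classify news articles into categories, using each article's headline and short description as the input text.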

In [1]:
# default_exp scripts.kaggle.bert_classification

In [48]:
# export
import pandas as pd
import numpy as np
import matplotlib
import os

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import json
from argparse import ArgumentParser

from transformers import BertTokenizer, BertModel

#### Download and read the data

In [51]:
# !kaggle datasets download -d rmisra/news-category-dataset
# !unzip news-category-dataset.zip

In [81]:
# export
# Read the dataset. The file holds one JSON object per line, so each line is
# decoded on its own rather than parsing the file as a single JSON document.
filepath = 'News_Category_Dataset_v2.json'
with open(filepath) as f:
    # Iterate the file handle directly; readlines() would first build a second
    # full copy of the file in memory.
    lines = [json.loads(line) for line in f]
df = pd.DataFrame(lines)
del lines  # the raw records are no longer needed once the DataFrame exists

# Model input: headline and short description joined by a single space.
df['text'] = df['headline'] + ' ' + df['short_description']
df['text_len'] = df.text.str.len()

# filter lines without headline
# `text` always contains at least the joining space, so `text_len > 0` alone can
# never drop a row with an empty headline; the headline is tested explicitly.
# The `text_len > 0` test is kept because it still removes rows whose text is
# missing (NaN compares as False).
has_headline = df['headline'].str.strip().str.len() > 0
df = df[has_headline & (df.text_len > 0)].copy()

In [64]:
df.head(2)

Unnamed: 0,category,headline,authors,link,short_description,date,headline_len
0,CRIME,There Were 2 Mass Shootings In Texas Last Week...,Melissa Jeltsen,https://www.huffingtonpost.com/entry/texas-ama...,She left her husband. He killed their children...,2018-05-26,64
1,ENTERTAINMENT,Will Smith Joins Diplo And Nicky Jam For The 2...,Andy McDonald,https://www.huffingtonpost.com/entry/will-smit...,Of course it has a song.,2018-05-26,75


In [77]:
# export
# split into train and validation
# Hold out 5% of the rows for validation.
num_val = int(df.shape[0] * 0.05)

# Fixed seed so that re-running the notebook/script yields the same split.
split_rng = np.random.RandomState(42)
# Sample from the actual index labels: rows are filtered after loading, so the
# index may have gaps and positions 0..n-1 need not be valid labels.
val_mask = split_rng.choice(df.index.values, num_val, replace=False)
in_val = df.index.isin(val_mask)
df_train = df[~in_val].copy()
df_val = df[in_val].copy()
print(f'train.shape - {df_train.shape} val.shape - {df_val.shape}')

# Longest text in characters (+5); printed for reference only.
longest_text = df['text'].str.len().max() + 5
print(f'max_len - {longest_text}')

# Sequence length actually used downstream (handed to the tokenizer as
# `max_length` by TextDS). Kept separate from the character count above so the
# printed value is not silently overwritten.
max_len = 500

train.shape - (190805, 8) val.shape - (10042, 8)
max_len - 1492


In [78]:
# export
class TextDS(Dataset):
    """Dataset that turns DataFrame rows into BERT inputs plus an encoded label.

    Args:
        df: DataFrame with a 'text' column (model input) and a 'category'
            column (label).
        label_encoder: fitted encoder whose `transform` maps a list of category
            names to integer ids.
        max_len: value passed to the tokenizer as `max_length` (together with
            `pad_to_max_length=True`).

    Each item is a tuple `(input_ids, attention_mask, label)` of tensors.
    """

    def __init__(self, df, label_encoder, max_len):
        self.df = df
        self.le = label_encoder
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
        self.max_len = max_len

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        # Positional lookup: works whatever the DataFrame's index labels are,
        # and matches label lookup when the index is 0..n-1.
        text = self.df['text'].iloc[idx]
        category = self.df['category'].iloc[idx]

        tokenized_text = self.tokenizer.tokenize(text)

        # https://huggingface.co/transformers/glossary.html
        # Use the length given to this dataset, not a module-level global.
        sequence_dict = self.tokenizer.encode_plus(
            tokenized_text, max_length=self.max_len, pad_to_max_length=True)

        # transform() takes and returns a sequence; take the single encoded id.
        cat_tensor = torch.tensor(self.le.transform([category]))[0]
        return (torch.tensor(sequence_dict['input_ids']),
                torch.tensor(sequence_dict['attention_mask']),
                cat_tensor)

# Fit the label encoder on the categories of the full DataFrame, so the train
# and validation datasets share one category-to-id mapping.
le = LabelEncoder()
# The returned array of encoded labels is not stored (it only shows up as the
# cell output); the fitted state lives on `le`.
le.fit_transform(df['category'])

array([ 6, 10, 10, ..., 28, 28, 28])

In [6]:
# itr = iter(train_dl)

In [7]:
# next(itr)

In [70]:
# export
class LinBnDrop(nn.Sequential):
    """Sequential stack of optional `BatchNorm1d`, optional `Dropout`, and a `Linear` layer.

    Args:
        n_in: input features of the linear layer.
        n_out: output features of the linear layer.
        bn: include a `BatchNorm1d`; when True the linear layer has no bias.
        p: dropout probability; no `Dropout` layer is added when it is 0.
        act: optional activation module placed right after the linear layer.
        lin_first: put the linear layer (and activation) before the
            norm/dropout layers instead of after them.
    """

    def __init__(self, n_in, n_out, bn=True, p=0., act=None, lin_first=False):
        # Normalisation/regularisation part. The norm sees n_out features when
        # it follows the linear layer and n_in features when it precedes it.
        norm_drop = []
        if bn:
            norm_drop.append(nn.BatchNorm1d(n_out if lin_first else n_in))
        if p != 0:
            norm_drop.append(nn.Dropout(p))

        # Linear part, optionally followed by the activation.
        linear = [nn.Linear(n_in, n_out, bias=not bn)]
        if act is not None:
            linear.append(act)

        if lin_first:
            ordered = linear + norm_drop
        else:
            ordered = norm_drop + linear
        super().__init__(*ordered)

class BertFineTuner(LightningModule):
    """Frozen pretrained BERT ('bert-base-cased') with a trainable linear classification head.

    Args:
        hparams: namespace providing `num_cls`, `bsz`, `val_bsz` and `lr`.
        *args, **kwargs: accepted and ignored, so callers may pass extra flags.

    The dataloader hooks read the module-level `df_train`, `df_val`, `le` and
    `max_len` defined elsewhere in this file.
    """

    def __init__(self, hparams, *args, **kwargs):
        # Initialise the base module before assigning any attribute on it.
        super().__init__()
        self.hparams = hparams
        self.bert = BertModel.from_pretrained('bert-base-cased', output_attentions=True)
        self.lin = LinBnDrop(self.bert.config.hidden_size, hparams.num_cls)
        self.loss_func = nn.CrossEntropyLoss()

        # freeze all bert parameters; only the classification head is trained
        for param in self.bert.parameters():
            param.requires_grad = False

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a parser extending `parent_parser` with the model's options."""
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--bsz', default=8, type=int, help='training batch size')
        parser.add_argument('--val-bsz', default=8, type=int, help='validation batch size')
        # A learning rate is fractional: with type=int, `--lr 0.001` could not be parsed.
        parser.add_argument('--lr', default=0.001, type=float, help='learning rate')
        return parser

    def forward(self, input_ids, attention_mask):
        """Return `(logits, attentions)` for a batch of token ids and attention masks."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Index rather than unpack, so this does not depend on the output being
        # exactly a 3-tuple. Assumes the hidden states come first and, because
        # of output_attentions=True, the attentions come last.
        h = outputs[0]
        attn = outputs[-1]
        # Hidden state of the first token of every sequence feeds the classifier.
        h_cls = h[:, 0]
        logits = self.lin(h_cls)
        return logits, attn

    def train_dataloader(self):
        """Training loader over `df_train`, reshuffled every epoch."""
        # shuffle=True: without it every epoch sees the rows in file order.
        return DataLoader(TextDS(df_train.reset_index(), le, max_len),
                          batch_size=self.hparams.bsz, num_workers=8, shuffle=True)

    def val_dataloader(self):
        """Validation loader over `df_val` (not shuffled)."""
        return DataLoader(TextDS(df_val.reset_index(), le, max_len),
                          batch_size=self.hparams.val_bsz, num_workers=8)

    def training_step(self, batch, batch_idx):
        """Compute the cross-entropy loss for one training batch."""
        input_ids, attention_mask, cat_idx = batch
        logits, attn = self.forward(input_ids, attention_mask)
        loss = self.loss_func(logits, cat_idx)

        # Log every 10th batch against the global step. Using the epoch number
        # as the step would give every log within an epoch the same step.
        if batch_idx % 10 == 0 and self.logger is not None:
            self.logger.log_metrics({'loss': loss.item()}, step=self.global_step)
        return {'loss': loss}

    def validation_step(self, batch, batch_ids):
        """Compute and return the cross-entropy loss for one validation batch."""
        input_ids, attention_mask, cat_idx = batch
        logits, attn = self.forward(input_ids, attention_mask)
        loss = self.loss_func(logits, cat_idx)
        # Return the loss; previously it was computed and discarded, so
        # validation produced no metric at all.
        return {'val_loss': loss}

    def validation_epoch_end(self, outputs):
        """Average the per-batch validation losses into a single `val_loss`."""
        # NOTE(review): uses the dict-returning hook style of the Lightning
        # releases that also accept `Trainer(gpus=...)` / `trainer.lr_find` as
        # used in this file - confirm against the installed version.
        avg_loss = torch.stack([out['val_loss'] for out in outputs]).mean()
        return {'val_loss': avg_loss, 'log': {'val_loss': avg_loss}}

    def configure_optimizers(self):
        """Adam over the module's parameters with the configured learning rate."""
        optimizer = torch.optim.Adam(self.parameters(), lr=(self.hparams.lr))
        return optimizer

### Add Params

In [34]:
# export
# Command-line interface: start from the model's own options, then add the
# run-level options below (order is kept, since it is the order shown in --help).
parser = BertFineTuner.add_model_specific_args(ArgumentParser())
# Number of output classes defaults to the number of categories the label encoder knows.
parser.add_argument('--num-cls', default=len(le.classes_), type=int)
# Passed to the training Trainer as `gpus`.
parser.add_argument('--gpus', default=0)
# Passed as `gpus` to the separate Trainer used for the learning-rate finder.
parser.add_argument('--lr-gpu', default=0)
parser.add_argument('--val-check-interval', default=100, type=int)
parser.add_argument('--max-epochs', default=3, type=int)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]


In [None]:
# Notebook-only cell (not exported): parse an explicit argument list instead of
# the process command line.
hparams = parser.parse_args('--bsz 64'.split())

In [None]:
# export 
# Exported-script path: with no argument list, argparse reads the process
# command line (sys.argv).
hparams = parser.parse_args()

### Find Learning Rate

In [None]:
# export 
# Learning-rate sweep on its own Trainer and model instance.
trainer = Trainer(gpus=hparams.lr_gpu)
model = BertFineTuner(hparams, finding_lr=True)

# find learning rate
# Both loaders for the sweep share the same fixed batch size and worker count.
loader_opts = dict(batch_size=8, num_workers=8)
train_dl = DataLoader(TextDS(df_train.reset_index(), le, max_len), **loader_opts)
val_dl = DataLoader(TextDS(df_val.reset_index(), le, max_len), **loader_opts)
lr_finder = trainer.lr_find(
    model,
    train_dataloader=train_dl,
    val_dataloaders=[val_dl],
)

# Plot the sweep with the suggested point marked, then keep the suggestion.
fig = lr_finder.plot(suggest=True)
fig.show()
new_lr = lr_finder.suggestion()
print('new lr: ', new_lr)

In [None]:
# export
# Logger for the training run (Weights & Biases project 'bert-text-cls').
wandb_logger = WandbLogger(name='achinta',project='bert-text-cls')
# Train with the learning rate suggested by the LR finder.
hparams.lr = new_lr
# Build a new model instance for the training run (the earlier one was used
# for the learning-rate sweep).
model = BertFineTuner(hparams)

# Validation frequency and epoch count come from the parsed command-line options.
trainer = Trainer(gpus=hparams.gpus,max_epochs=hparams.max_epochs,
                  logger=wandb_logger, val_check_interval=hparams.val_check_interval)

# No dataloaders are passed here; the model's train_dataloader/val_dataloader
# hooks supply them.
trainer.fit(model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]



  | Name      | Type             | Params
-----------------------------------------------
0 | bert      | BertModel        | 108 M 
1 | lin       | LinBnDrop        | 33 K  
2 | loss_func | CrossEntropyLoss | 0     


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

In [82]:
!nbdev_build_lib --fname 104-bert-text-classification.ipynb
!scp ../ml/scripts/kaggle/bert_classification.py rc:/n/home11/tatacomm/kaggle/datasets/news-category-dataset

Converted 104-bert-text-classification.ipynb.
bert_classification.py                        100% 6558    10.2KB/s   00:00    


### Playground

In [75]:
df.headline.str.len().mean()

57.94203547974329