In [None]:
%load_ext autoreload
%autoreload 2

from pathlib import Path

import pandas as pd
import transformers
from sklearn.model_selection import train_test_split

from utilities import utils

RANDOM_STATE = 5

In [None]:
# Folder with the BBC News CSV files, relative to the notebook's working directory.
data_path = Path('../data/')
raw_csv = data_path / 'BBC_News_Train.csv'

# Load the full labelled set and show the project's summary of it.
bbc_data = pd.read_csv(raw_csv)
utils.df_summarise(bbc_data)

In [None]:
tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-cased')


def check_num_of_tokens(row):
    """Return how many entries the tokenizer's `input_ids` has for one text."""
    encoded = tokenizer(row)
    return len(encoded['input_ids'])


# Store the per-article token count next to the text it was computed from.
bbc_data['num_of_tokens'] = bbc_data['Text'].apply(check_num_of_tokens)
utils.df_summarise(bbc_data)

In [None]:
# Histogram plus summary statistics of the per-article token counts.
token_counts = bbc_data['num_of_tokens']
token_counts.plot.hist()
token_counts.describe()

In [None]:
# Class balance: share of articles per category.
category_counts = bbc_data['Category'].value_counts()
category_counts.plot.pie(legend=True);

In [None]:
# Stratified 80/20 split so both parts keep the category proportions.
train, val = train_test_split(
    bbc_data,
    test_size=0.2,
    stratify=bbc_data['Category'],
    shuffle=True,
    random_state=RANDOM_STATE,
)

# Print the size and the per-category counts of each part.
separator = '=' * 60
for subset in (train, val):
    print(separator)
    print(subset.shape)
    print(subset['Category'].value_counts())

In [None]:
# Write each split to its own CSV file inside the data folder.
for csv_name, frame in (('bbc_train.csv', train), ('bbc_val.csv', val)):
    frame.to_csv(data_path / csv_name)

## Training

In [1]:
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

from dataset.long_text_data_module import LongTextDataModule
from model.long_text_classifier import LongTextClassifier
from model.callbacks import SwitchPretrainedWeightsState

  "`pytorch_lightning.metrics.*` module has been renamed to `torchmetrics.*` and split off to its own package"


In [None]:
# Passed to the Trainer below as `max_epochs`.
NUM_EPOCHS = 20
# Passed to LongTextDataModule as `sequence_length` and `overlap` respectively.
# NOTE(review): presumably measured in tokens per text chunk — confirm in the data module.
SEQUENCE_LENGTH = 200
SEQUENCE_OVERLAP = 50
BATCH_SIZE = 1  # TODO: at the moment model can only handle online learning
# NOTE(review): BATCH_SIZE and NUM_FREEZE_PRETRAINED are not referenced by the
# cells reviewed here — confirm whether they are meant to be passed on
# (e.g. to the data module or to SwitchPretrainedWeightsState).
NUM_FREEZE_PRETRAINED = NUM_EPOCHS
# Config file path handed to both LongTextDataModule and LongTextClassifier.
CONFIG_PATH = 'settings/bbc_config.yaml'


# Data module and classifier both receive the same config file path.
data_module = LongTextDataModule(
    overlap=SEQUENCE_OVERLAP,
    sequence_length=SEQUENCE_LENGTH,
    config_path=CONFIG_PATH,
)

# The number of output classes is taken from the data module.
model = LongTextClassifier(config_path=CONFIG_PATH, num_classes=data_module.num_classes)

# Callbacks handed to the Trainer:
#   * SwitchPretrainedWeightsState - project callback (see model/callbacks);
#     NOTE(review): presumably toggles freezing of the pretrained weights - confirm.
#   * ModelCheckpoint - keeps the single best checkpoint by validation accuracy.
#   * EarlyStopping - stops once validation accuracy has not improved for 5 checks.
callbacks = [
    SwitchPretrainedWeightsState(),
    ModelCheckpoint(
        # The monitored metric name contains '/'. With automatic metric-name
        # insertion that '/' ends up in the checkpoint file name, where it acts
        # as a path separator and produces an extra sub-folder per checkpoint.
        # The labels are therefore written by hand and auto insertion is off.
        filename=(
            f'BBCClassifier-seq_len{SEQUENCE_LENGTH}-ovlp{SEQUENCE_OVERLAP}'
            '-epoch={epoch}-val_accuracy={val/accuracy:.3f}'
        ),
        auto_insert_metric_name=False,
        monitor='val/accuracy',
        mode='max',
        save_top_k=1,
        verbose=True,
    ),
    EarlyStopping(
        monitor='val/accuracy',
        mode='max',
        patience=5,
        verbose=True,
    ),
]

# Trainer settings collected in one place, then unpacked into the constructor.
trainer_options = dict(
    max_epochs=NUM_EPOCHS,
    fast_dev_run=False,
    default_root_dir='../output',
    callbacks=callbacks,
)
trainer = Trainer(**trainer_options)

Training loop:

In [None]:
# The dataloaders are passed positionally on purpose: the keyword for the
# training loader was renamed from `train_dataloader` to `train_dataloaders`
# in PyTorch Lightning 1.4 and the old spelling was removed in 1.6, while the
# positional order (model, train loader(s), val loader(s)) is unchanged, so
# this call works on both older and newer releases.
trainer.fit(
    model,
    data_module.train_dataloader(),
    data_module.val_dataloader(),
)