<a href="https://colab.research.google.com/github/JayThibs/Weak-Supervised-Learning-Case-Study/blob/main/text_classifier/notebooks/01_dbpedia_14_bert_classification_exploration.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!nvidia-smi

Sun May  2 13:57:13 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   49C    P0    39W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
!pip install transformers==4.5.1 --quiet
!pip install pytorch_lightning==1.2.10 --quiet
!pip install wandb --quiet
!pip install datasets



In [3]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, AdamW, get_linear_schedule_with_warmup, AutoModelForSequenceClassification, AutoTokenizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional.classification import auroc
from datasets import load_dataset

import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt

from tqdm import tqdm
import wandb

In [4]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

RANDOM_SEED = 42
BASE_MODEL_NAME = 'bert-base-cased'

np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

<torch._C.Generator at 0x7fef3f106bd0>

In [5]:
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mjacquesthibs[0m (use `wandb login --relogin` to force relogin)


True

## 💡 Configuration tips

W&B integration with Hugging Face can be configured to add extra functionalities:

* auto-logging of models as artifacts: just set the environment variable `WANDB_LOG_MODEL` to `true`
* log histograms of gradients and parameters: gradients are logged by default; to log parameters as well, set the environment variable `WANDB_WATCH` to `all`
* set custom run names with `run_name` arg present in scripts or as part of `TrainingArguments`
* organize runs by project with the `WANDB_PROJECT` environment variable

For more details refer to [W&B + HF integration documentation](https://docs.wandb.ai/integrations/huggingface).
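
The settings listed above can also be applied from Python before the `Trainer` is created. This is a minimal sketch, assuming the variables are read when training starts; the project name `dbpedia14-demo` is a made-up placeholder and not a value used in this notebook.

```python
import os

# Log every trained model as a W&B artifact.
os.environ["WANDB_LOG_MODEL"] = "true"
# Log parameter histograms in addition to gradients.
os.environ["WANDB_WATCH"] = "all"
# Group runs under one project (hypothetical name, chosen for illustration).
os.environ["WANDB_PROJECT"] = "dbpedia14-demo"

# Collect the W&B-related settings so they can be inspected in one place.
wandb_settings = {k: v for k, v in os.environ.items() if k.startswith("WANDB_")}
print(wandb_settings)
```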

Let's log every trained model.

In [6]:
%env WANDB_LOG_MODEL=true

env: WANDB_LOG_MODEL=true


In [7]:
dbpedia_dataset = load_dataset('dbpedia_14')

Reusing dataset d_bpedia14 (/root/.cache/huggingface/datasets/d_bpedia14/dbpedia_14/2.0.0/7f0577ea0f4397b6b89bfe5c5f2c6b1b420990a1fc5e8538c7ab4ec40e46fa3e)


In [8]:
label_names = [
    "Company",
    "EducationalInstitution",
    "Artist",
    "Athlete",
    "OfficeHolder",
    "MeanOfTransportation",
    "Building",
    "NaturalPlace",
    "Village",
    "Animal",
    "Plant",
    "Album",
    "Film",
    "WrittenWork"]

In [9]:
label_names[0]

'Company'

In [10]:
dbpedia_dataset['train'][0]

{'content': ' Abbott of Farnham E D Abbott Limited was a British coachbuilding business based in Farnham Surrey trading under that name from 1929. A major part of their output was under sub-contract to motor vehicle manufacturers. Their business closed in 1972.',
 'label': 0,
 'title': 'E. D. Abbott Ltd'}

In [11]:
dbpedia_dataset = dbpedia_dataset.rename_column("label", "labels")

`str2int` and `int2str` convert between class names and their integer ids.

In [12]:
dbpedia_dataset['train'].features['labels']

ClassLabel(num_classes=14, names=['Company', 'EducationalInstitution', 'Artist', 'Athlete', 'OfficeHolder', 'MeanOfTransportation', 'Building', 'NaturalPlace', 'Village', 'Animal', 'Plant', 'Album', 'Film', 'WrittenWork'], names_file=None, id=None)

In [13]:
# Let's look at the 5th label with int2str.
# We'll use this later when we are outputting the prediction.
dbpedia_dataset['train'].features['labels'].int2str(4)

'OfficeHolder'
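
To make the mapping concrete, here is a plain-Python sketch of what `int2str` and `str2int` do for this label set. It does not call the `datasets` API; the dictionaries `id_to_name` and `name_to_id` are illustrative names introduced only for this example.

```python
label_names = [
    "Company", "EducationalInstitution", "Artist", "Athlete", "OfficeHolder",
    "MeanOfTransportation", "Building", "NaturalPlace", "Village", "Animal",
    "Plant", "Album", "Film", "WrittenWork",
]

# Integer id -> class name (the direction int2str covers).
id_to_name = dict(enumerate(label_names))
# Class name -> integer id (the direction str2int covers).
name_to_id = {name: idx for idx, name in id_to_name.items()}

print(id_to_name[4])               # OfficeHolder
print(name_to_id["OfficeHolder"])  # 4
```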

For our topic classification task, we use `content` as input and try to predict `labels`.

In [14]:
label_list = dbpedia_dataset['train'].unique('labels')
label_list.sort()
label_list

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]

This dataset covers 14 topics, so the model outputs 14 classes, one per topic.

In [15]:
num_labels = len(label_list)
num_labels

14

The `label` column was renamed to `labels` earlier so that the `Trainer` can find it.

In [16]:
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)

In [17]:
def merge_title_with_content(example):
    example["content"] = example["title"] + " " + example["content"]
    return example


def encode(batch):
    return tokenizer(
        batch["content"],
        add_special_tokens=True,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        return_tensors="np",
    )
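
The `padding="max_length"` and `truncation=True` arguments make every example exactly `max_length` tokens long. The sketch below mimics that behaviour on a plain list of ids. It is a simplification and not the tokenizer's implementation: `pad_or_truncate` is a hypothetical helper, a pad id of 0 is assumed, and the real tokenizer also keeps the final `[SEP]` token when it truncates.

```python
def pad_or_truncate(token_ids, max_length, pad_id=0):
    """Return (input_ids, attention_mask), both exactly max_length long."""
    clipped = token_ids[:max_length]          # drop anything past max_length
    n_pad = max_length - len(clipped)         # how many pad tokens are needed
    input_ids = clipped + [pad_id] * n_pad
    # 1 marks a real token, 0 marks padding the model should ignore.
    attention_mask = [1] * len(clipped) + [0] * n_pad
    return input_ids, attention_mask


ids, mask = pad_or_truncate([101, 142, 119, 102], max_length=6)
print(ids)   # [101, 142, 119, 102, 0, 0]
print(mask)  # [1, 1, 1, 1, 0, 0]
```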

In [18]:
dbpedia_dataset = dbpedia_dataset.map(merge_title_with_content, num_proc=10)

 


In [19]:
dbpedia_dataset = dbpedia_dataset.map(encode, batched=True, num_proc=10)
# dbpedia_dataset = dbpedia_dataset.map(lambda x: encode(x['content']), batched=True)
dbpedia_dataset.set_format(type="torch", 
                           columns=["input_ids", "token_type_ids", 
                                    "attention_mask", "labels"])

 


In [20]:
# Let's look at an example of text from the dataset.
# You can see that we've successfully added the title to the content.
dbpedia_dataset['train']['content'][0]

'E. D. Abbott Ltd  Abbott of Farnham E D Abbott Limited was a British coachbuilding business based in Farnham Surrey trading under that name from 1929. A major part of their output was under sub-contract to motor vehicle manufacturers. Their business closed in 1972.'

In [21]:
# Labels are equally balanced.
np.unique(dbpedia_dataset['train']['labels'], return_counts=True)

(array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13]),
 array([40000, 40000, 40000, 40000, 40000, 40000, 40000, 40000, 40000,
        40000, 40000, 40000, 40000, 40000]))

In [22]:
# Test data is balanced too.
np.unique(dbpedia_dataset['test']['labels'], return_counts=True)

(array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13]),
 array([5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000,
        5000, 5000, 5000]))

In [23]:
# We can convert our input_ids back into words to see how the text was tokenized.
print(tokenizer.convert_ids_to_tokens(dbpedia_dataset['train']['input_ids'][0]))

['[CLS]', 'E', '.', 'D', '.', 'Abbott', 'Ltd', 'Abbott', 'of', 'Far', '##nham', 'E', 'D', 'Abbott', 'Limited', 'was', 'a', 'British', 'coach', '##building', 'business', 'based', 'in', 'Far', '##nham', 'Surrey', 'trading', 'under', 'that', 'name', 'from', '1929', '.', 'A', 'major', 'part', 'of', 'their', 'output', 'was', 'under', 'sub', '-', 'contract', 'to', 'motor', 'vehicle', 'manufacturers', '.', 'Their', 'business', 'closed', 'in', '1972', '.', '[SEP]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]'

In [24]:
# #DEBUGGING - Splice dataset to use smaller number of samples

# BATCH_SIZE = 16

# train_dataloader = torch.utils.data.DataLoader(
#     dbpedia_dataset["train"].select(
#         list(
#             np.random.randint(low=0, high=len(dbpedia_dataset["train"]) - 1, size=1000)
#         )
#     ),
#     batch_size=BATCH_SIZE,
#     shuffle=True,
# )
# test_dataloader = torch.utils.data.DataLoader(
#     dbpedia_dataset["test"].select(
#         list(np.random.randint(low=0, high=len(dbpedia_dataset["test"]) - 1, size=1000))
#     ),
#     batch_size=BATCH_SIZE,
#     shuffle=False,)

In [25]:
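# Draw 1,000 random indices per split for quick debugging runs.
# Note: np.random.randint samples with replacement, so a few rows may repeat,
# and `high` is exclusive, so `len(...) - 1` never selects the last row.
debug_train = dbpedia_dataset["train"].select(
    list(np.random.randint(low=0, high=len(dbpedia_dataset["train"]) - 1, size=1000))
)

debug_test = dbpedia_dataset["test"].select(
    list(np.random.randint(low=0, high=len(dbpedia_dataset["test"]) - 1, size=1000))
)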
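
Because `np.random.randint` draws with replacement, the debug subsets can contain repeated rows. If unique rows are preferred, one alternative is to sample indices without replacement. The sketch below uses only the standard library; `sample_indices` is a hypothetical helper, and the resulting list could be passed to `.select(...)` in place of the `randint` output.

```python
import random


def sample_indices(dataset_size, sample_size, seed=42):
    """Pick sample_size distinct row indices from range(dataset_size)."""
    rng = random.Random(seed)  # local generator, so global state is untouched
    return rng.sample(range(dataset_size), sample_size)


# 560,000 is the size of the training split reported in this notebook.
train_idx = sample_indices(560_000, 1000)
print(len(train_idx), len(set(train_idx)))  # 1000 1000
```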

In [63]:
print(len(debug_train['labels']))
debug_train['labels']

1000


tensor([ 3,  3,  9,  6,  2,  1,  3, 13,  2,  4,  4,  6,  1,  8,  1,  8,  8,  4,
         2,  5,  5,  4,  8, 10,  6, 12, 12,  3,  9,  3, 13, 11,  1,  3,  2,  7,
         8, 13,  6,  8, 13, 12,  6, 11,  8,  6,  5,  5,  5, 10,  1,  6,  6,  2,
        11, 11,  5,  8,  7,  4,  6, 10,  5,  4,  5, 13, 12,  8,  6,  5,  0,  0,
        11, 13,  8,  6, 10, 12,  5,  9,  1,  7,  6,  3, 11,  3,  3, 12,  7,  0,
         3,  0,  7,  3,  7,  7,  0,  9,  8, 12, 11,  6,  0,  5, 12,  3, 11,  3,
         4,  5,  2,  5,  5,  3, 13,  6,  1,  3,  4,  2, 10,  2,  4,  4,  8,  9,
        10,  1,  9,  1, 11,  2, 13,  5,  0,  3, 13,  7,  1, 10, 10,  9, 11,  3,
         3, 10, 12, 11,  0,  5, 13,  4,  8,  1,  2, 11, 10,  0,  4,  5,  7,  1,
        12, 10,  2,  5,  9,  8, 10,  6,  7,  6,  4, 12,  3,  5, 12,  1,  4,  2,
         9, 11,  9,  2, 13,  2,  0, 10,  8,  1,  0,  7,  3,  7,  2, 13,  3,  5,
        10,  2, 11,  1,  6,  7,  0, 11,  9, 13,  3,  7,  8,  2,  3, 11,  9,  9,
         9,  2,  7,  4,  6,  9, 13, 13, 

In [26]:
# Let's have a look at one example from the DataLoader.
train_dataloader = DataLoader(dbpedia_dataset['train'],
                              batch_size=1,
                              shuffle=True)

sample_item = next(iter(train_dataloader))

sample_item

{'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

In [27]:
# We create the dataloaders.
train_dataloader = DataLoader(dbpedia_dataset['train'],
                              batch_size=512,
                              shuffle=True)

test_dataloader = DataLoader(dbpedia_dataset['test'],
                             batch_size=512,
                             shuffle=False)  # keep evaluation order fixed
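
As a quick sanity check on the loaders above, the number of batches per pass follows from the split sizes reported in this notebook (560,000 training rows and 70,000 test rows) and the batch size of 512. This is a small arithmetic sketch; `n_batches` is a helper name introduced here, and it assumes the `DataLoader` default of `drop_last=False`, so the last partial batch is kept.

```python
import math


def n_batches(num_rows, batch_size):
    """Batches per pass when the final partial batch is kept."""
    return math.ceil(num_rows / batch_size)


train_batches = n_batches(560_000, 512)
test_batches = n_batches(70_000, 512)
print(train_batches, test_batches)  # 1094 137
```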

In [28]:
baseline_model = AutoModelForSequenceClassification.from_pretrained(
    BASE_MODEL_NAME,
    output_attentions=False,
    output_hidden_states=False,
    return_dict=True,
    num_labels=num_labels
)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at b

In [29]:
prediction = baseline_model(sample_item['input_ids'], sample_item['attention_mask'])

In [30]:
prediction

SequenceClassifierOutput([('logits',
                           tensor([[-0.0932, -0.7690, -0.6399, -0.1519,  0.0639,  0.8291, -0.3584, -0.0162,
                                    -0.9446, -0.0527, -0.2492, -0.0557,  0.3037, -0.2569]],
                                  grad_fn=<AddmmBackward>))])
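
The logits above are unnormalised scores, one per class. The sketch below turns the printed values into probabilities with a softmax and picks the most likely class. It uses plain Python rather than `torch`, and the `softmax` helper is written here for illustration only.

```python
import math

# Logits copied from the output above (one value per class, 14 in total).
logits = [-0.0932, -0.7690, -0.6399, -0.1519, 0.0639, 0.8291, -0.3584, -0.0162,
          -0.9446, -0.0527, -0.2492, -0.0557, 0.3037, -0.2569]


def softmax(values):
    """Convert raw scores into probabilities that sum to 1."""
    shift = max(values)  # subtract the max for numerical stability
    exps = [math.exp(v - shift) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax(logits)
top_class = max(range(len(probs)), key=probs.__getitem__)
print(top_class)  # 5
```

Softmax preserves the ordering of the scores, so taking the argmax of the probabilities gives the same class as taking the argmax of the logits directly.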

In [31]:
top_prediction = prediction.logits.detach().numpy().argmax().item()
print(top_prediction)
print(dbpedia_dataset['train'].features['labels'].int2str(top_prediction))
print(tokenizer.convert_ids_to_tokens(sample_item['input_ids'][0]))

5
MeanOfTransportation
['[CLS]', 'New', 'Castle', 'Area', 'School', 'District', 'New', 'Castle', 'Area', 'School', 'District', 'is', 'a', 'public', 'school', 'district', 'located', 'in', 'Lawrence', 'County', 'Pennsylvania', '.', 'The', 'district', 'serves', 'the', 'city', 'of', 'New', 'Castle', 'and', 'Taylor', 'Township', '.', 'New', 'Castle', 'Area', 'School', 'District', 'encompasses', 'approximately', '13', 'square', 'miles', '(', '34', 'km', '##2', ')', '.', 'According', 'to', '2007', 'local', 'census', 'data', 'it', 'served', 'a', 'resident', 'population', 'of', '264', '##64', '.', 'The', '2010', 'census', 'found', 'the', 'population', 'declined', 'to', '242', '##86', 'people', '.', '[SEP]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PA

In [32]:
MAX_TOKEN_LENGTH = 128

In [33]:
dbpedia_dataset['train']

Dataset({
    features: ['attention_mask', 'content', 'input_ids', 'labels', 'title', 'token_type_ids'],
    num_rows: 560000
})

In [34]:
def get_topic(sentence, model, tokenizer=tokenizer):
    # Tokenize the input the same way as the training data.
    inputs = tokenizer(
        sentence,
        add_special_tokens=True,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        return_tensors="pt",
    )
    # Move the model and inputs to the same device (GPU if available).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    model = model.to(device)
    # Get the logits for the 14 labels.
    with torch.no_grad():
        predictions = model(**inputs)[0].cpu().numpy()
    # Take the top class and convert it to its label name.
    top_prediction = predictions.argmax().item()
    return dbpedia_dataset['train'].features['labels'].int2str(top_prediction)

In [35]:
sample_item = dbpedia_dataset['train']['content'][0]
sample_item

'E. D. Abbott Ltd  Abbott of Farnham E D Abbott Limited was a British coachbuilding business based in Farnham Surrey trading under that name from 1929. A major part of their output was under sub-contract to motor vehicle manufacturers. Their business closed in 1972.'

In [36]:
get_topic(sample_item, baseline_model)

'OfficeHolder'

In [37]:
tokenizer(sample_item, return_tensors='pt').items()

dict_items([('input_ids', tensor([[  101,   142,   119,   141,   119, 15176,  4492, 15176,  1104,  8040,
         15898,   142,   141, 15176,  5975,  1108,   170,  1418,  2154, 12851,
          1671,  1359,  1107,  8040, 15898,  9757,  6157,  1223,  1115,  1271,
          1121,  3762,   119,   138,  1558,  1226,  1104,  1147,  5964,  1108,
          1223,  4841,   118,  2329,  1106,  5968,  3686,  9263,   119,  2397,
          1671,  1804,  1107,  2388,   119,   102]])), ('token_type_ids', tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0]])), ('attention_mask', tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1]]))])

## Training the model

In [38]:
from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    report_to = 'wandb',                     # enable logging to W&B
    output_dir = 'dbpedia14_classification',    # output directory
    overwrite_output_dir = True,
    num_train_epochs = 3,                   # ignored here: max_steps takes precedence
    evaluation_strategy = 'steps',          # evaluate every eval_steps steps
    learning_rate = 5e-5,                   # we can customize learning rate
    max_steps = 3000,                       # total number of training steps
    logging_steps = 25,                     # log every 25 steps
    eval_steps = 50,                        # evaluate every 50 steps
    load_best_model_at_end = True,
    metric_for_best_model = 'f1',
    per_device_train_batch_size = 16,
    per_device_eval_batch_size = 16,
    run_name = 'debugging_training_3'            # name of the W&B run
)

In [39]:
# !pip install bert_score --quiet

In [55]:
from datasets import load_metric
from sklearn.metrics import precision_recall_fscore_support, accuracy_score

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='macro')
    acc = accuracy_score(labels, predictions)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }
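
To show what `average='macro'` means, the sketch below computes macro-averaged F1 by hand: F1 is calculated per class and the per-class scores are averaged with equal weight. It is a pure-Python illustration and not scikit-learn's implementation; `macro_f1` is a hypothetical helper, and undefined ratios are set to 0, which is assumed to match the behaviour scikit-learn warns about.

```python
def macro_f1(labels, predictions):
    """Unweighted mean of per-class F1 scores."""
    classes = sorted(set(labels) | set(predictions))
    pairs = list(zip(labels, predictions))
    f1_scores = []
    for c in classes:
        tp = sum(1 for y, p in pairs if y == c and p == c)  # true positives
        fp = sum(1 for y, p in pairs if y != c and p == c)  # false positives
        fn = sum(1 for y, p in pairs if y == c and p != c)  # false negatives
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        denom = precision + recall
        f1_scores.append(2 * precision * recall / denom if denom else 0.0)
    return sum(f1_scores) / len(f1_scores)


toy_labels = [0, 0, 1, 1, 2, 2]
toy_predictions = [0, 0, 1, 2, 2, 2]
print(round(macro_f1(toy_labels, toy_predictions), 4))  # 0.8222
```

In the toy example the per-class F1 scores are 1.0, 0.667, and 0.8, so the macro average is about 0.822 even though 5 of 6 predictions (0.833) are correct.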

In [57]:
dbpedia_dataset

DatasetDict({
    train: Dataset({
        features: ['attention_mask', 'content', 'input_ids', 'labels', 'title', 'token_type_ids'],
        num_rows: 560000
    })
    test: Dataset({
        features: ['attention_mask', 'content', 'input_ids', 'labels', 'title', 'token_type_ids'],
        num_rows: 70000
    })
})

In [58]:
trainer = Trainer(
    model=baseline_model,             # model to be trained
    args=args,                        # training args
    train_dataset=debug_train,
    eval_dataset=debug_test,
    tokenizer=tokenizer,              # for padding batched data
    compute_metrics=compute_metrics,  # for custom metrics
)

In [59]:
trainer.evaluate()

  _warn_prf(average, modifier, msg_start, len(result))


{'eval_accuracy': 0.086,
 'eval_f1': 0.019159501274629936,
 'eval_loss': 2.69828462600708,
 'eval_mem_cpu_alloc_delta': 118784,
 'eval_mem_cpu_peaked_delta': 0,
 'eval_mem_gpu_alloc_delta': 0,
 'eval_mem_gpu_peaked_delta': 780402176,
 'eval_precision': 0.011266644854216695,
 'eval_recall': 0.06477591036414566,
 'eval_runtime': 17.9894,
 'eval_samples_per_second': 55.588,
 'init_mem_cpu_alloc_delta': 0,
 'init_mem_cpu_peaked_delta': 0,
 'init_mem_gpu_alloc_delta': 0,
 'init_mem_gpu_peaked_delta': 0}
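
These pre-training numbers are close to what random guessing would give on 14 balanced classes. As a rough reference, the sketch below computes the chance accuracy and the cross-entropy of a uniform prediction; it is plain arithmetic under the assumption of perfectly balanced classes.

```python
import math

num_labels = 14

# Guessing uniformly among 14 balanced classes is right 1 time in 14.
chance_accuracy = 1 / num_labels
# Cross-entropy when every class is assigned probability 1/14.
uniform_loss = math.log(num_labels)

print(round(chance_accuracy, 4))  # 0.0714
print(round(uniform_loss, 4))     # 2.6391
```

The reported `eval_accuracy` of 0.086 and `eval_loss` of about 2.70 sit near these reference values, which is what one would expect from a classification head that has not been trained yet.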

In [64]:
trainer.train()

| Step | Training Loss | Validation Loss | Accuracy | F1 | Precision | Recall | Runtime | Samples Per Second |
|---|---|---|---|---|---|---|---|---|
| 50 | 1.1249 | 0.55635 | 0.978 | 0.977066 | 0.976816 | 0.978161 | 17.969 | 55.652 |
| 100 | 0.126 | 0.105369 | 0.98 | 0.979843 | 0.980255 | 0.98024 | 17.9898 | 55.587 |
| 150 | 0.0531 | 0.078403 | 0.984 | 0.983828 | 0.983117 | 0.985259 | 17.9776 | 55.625 |
| 200 | 0.0738 | 0.102686 | 0.982 | 0.981203 | 0.981015 | 0.982518 | 17.9727 | 55.64 |
| 250 | 0.0085 | 0.104899 | 0.982 | 0.981758 | 0.981905 | 0.982377 | 17.9776 | 55.625 |
| 300 | 0.0149 | 0.093708 | 0.982 | 0.981663 | 0.981957 | 0.98192 | 17.9945 | 55.573 |
| 350 | 0.0035 | 0.114157 | 0.979 | 0.978539 | 0.978779 | 0.979558 | 17.9747 | 55.634 |
| 400 | 0.0035 | 0.140508 | 0.976 | 0.975398 | 0.976105 | 0.976738 | 17.9743 | 55.635 |
| 450 | 0.0023 | 0.085967 | 0.985 | 0.984968 | 0.985051 | 0.985323 | 17.9741 | 55.636 |
| 500 | 0.0023 | 0.133289 | 0.98 | 0.979663 | 0.979891 | 0.980624 | 17.9765 | 55.628 |


TrainOutput(global_step=3000, training_loss=0.035843222200094414, metrics={'train_runtime': 4813.7728, 'train_samples_per_second': 0.623, 'total_flos': 1.5847468381323264e+16, 'epoch': 47.62, 'train_mem_cpu_alloc_delta': -491737088, 'train_mem_gpu_alloc_delta': 1740206080, 'train_mem_cpu_peaked_delta': 503443456, 'train_mem_gpu_peaked_delta': 0})
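
The `epoch` value of 47.62 in the output above follows from the settings used in this run: 1,000 debug samples, a per-device batch size of 16, and `max_steps = 3000`. The arithmetic sketch below reproduces it, assuming a single device and no gradient accumulation.

```python
import math

train_samples = 1000   # size of the debug training subset
batch_size = 16        # per_device_train_batch_size
max_steps = 3000       # total optimisation steps
eval_steps = 50        # evaluation interval

steps_per_epoch = math.ceil(train_samples / batch_size)
epochs_run = max_steps / steps_per_epoch
evaluations = max_steps // eval_steps

print(steps_per_epoch, round(epochs_run, 2), evaluations)  # 63 47.62 60
```

Since `max_steps` is set, the run lasts about 47.6 passes over the debug subset rather than the 3 epochs given by `num_train_epochs`.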

In [65]:
trainer.save_model()

In [66]:
trainer.save_state()

In [67]:
wandb.finish()


| Metric | Value |
|---|---|
| eval/loss | 0.12223 |
| eval/accuracy | 0.985 |
| eval/f1 | 0.98481 |
| eval/precision | 0.98472 |
| eval/recall | 0.98535 |
| eval/runtime | 17.9933 |
| eval/samples_per_second | 55.576 |
| train/global_step | 3000.0 |
| _runtime | 4902.0 |
| _timestamp | 1619969129.0 |


Run history (sparkline plots omitted): eval/loss drops sharply after the first evaluation and then stays low; eval/accuracy, eval/f1, eval/precision, and eval/recall jump after the first evaluation and then stay flat.
