<a href="https://colab.research.google.com/github/JayThibs/Weak-Supervised-Learning-Case-Study/blob/main/text_classifier/notebooks/01_dbpedia_14_bert_classification_exploration.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!nvidia-smi

Sat May  1 00:12:53 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   47C    P0    29W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
!pip install transformers==4.5.1 --quiet
!pip install pytorch_lightning==1.2.10 --quiet
!pip install wandb --quiet
!pip install datasets



In [3]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, AdamW, get_linear_schedule_with_warmup, AutoModelForSequenceClassification, AutoTokenizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional.classification import auroc
from datasets import load_dataset

import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt

from tqdm import tqdm
import wandb

In [4]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

RANDOM_SEED = 42
BASE_MODEL_NAME = 'bert-base-cased'

np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

<torch._C.Generator at 0x7f8e2092fbd0>

In [5]:
wandb.login()

wandb: Currently logged in as: jacquesthibs (use `wandb login --relogin` to force relogin)


True

## 💡 Configuration tips

W&B integration with Hugging Face can be configured to add extra functionalities:

* auto-logging of models as artifacts: just set the environment variable `WANDB_LOG_MODEL` to `true`
* log histograms of gradients and parameters: by default gradients are logged, you can also log parameters by setting environment variable `WANDB_WATCH` to `all`
* set custom run names with `run_name` arg present in scripts or as part of `TrainingArguments`
* organize runs by project with the `WANDB_PROJECT` environment variable

For more details refer to [W&B + HF integration documentation](https://docs.wandb.ai/integrations/huggingface).

Let's log every trained model.

In [6]:
%env WANDB_LOG_MODEL=true

env: WANDB_LOG_MODEL=true
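
The `%env` magic above sets a single variable. The other options from the configuration tips are ordinary environment variables too, so they can be set from Python before training starts. A minimal sketch, assuming a hypothetical project name `dbpedia-14-classification` that does not come from this notebook:

```python
import os

# Log every trained model as a W&B artifact (same effect as `%env WANDB_LOG_MODEL=true`).
os.environ["WANDB_LOG_MODEL"] = "true"
# Log parameter histograms in addition to gradients.
os.environ["WANDB_WATCH"] = "all"
# Hypothetical project name, chosen only for illustration.
os.environ["WANDB_PROJECT"] = "dbpedia-14-classification"

# Collect the W&B-related settings so they can be inspected in one place.
wandb_settings = {k: v for k, v in os.environ.items() if k.startswith("WANDB_")}
print(wandb_settings)
```

These variables need to be set before the run is initialised, otherwise they may not be picked up.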


In [7]:
dbpedia_dataset = load_dataset('dbpedia_14')

Reusing dataset d_bpedia14 (/root/.cache/huggingface/datasets/d_bpedia14/dbpedia_14/2.0.0/7f0577ea0f4397b6b89bfe5c5f2c6b1b420990a1fc5e8538c7ab4ec40e46fa3e)


In [8]:
label_names = [
    "Company",
    "EducationalInstitution",
    "Artist",
    "Athlete",
    "OfficeHolder",
    "MeanOfTransportation",
    "Building",
    "NaturalPlace",
    "Village",
    "Animal",
    "Plant",
    "Album",
    "Film",
    "WrittenWork"]

In [9]:
label_names[0]

'Company'

In [10]:
dbpedia_dataset['train'][0]

{'content': ' Abbott of Farnham E D Abbott Limited was a British coachbuilding business based in Farnham Surrey trading under that name from 1929. A major part of their output was under sub-contract to motor vehicle manufacturers. Their business closed in 1972.',
 'label': 0,
 'title': 'E. D. Abbott Ltd'}

In [11]:
dbpedia_dataset = dbpedia_dataset.rename_column("label", "labels")

`str2int` and `int2str` convert between class names and their integer ids.

In [12]:
dbpedia_dataset['train'].features['labels']

ClassLabel(num_classes=14, names=['Company', 'EducationalInstitution', 'Artist', 'Athlete', 'OfficeHolder', 'MeanOfTransportation', 'Building', 'NaturalPlace', 'Village', 'Animal', 'Plant', 'Album', 'Film', 'WrittenWork'], names_file=None, id=None)

In [13]:
# Let's look at the 5th label with int2str.
# We'll use this later when we are outputting the prediction.
dbpedia_dataset['train'].features['labels'].int2str(4)

'OfficeHolder'
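
To make the mapping concrete, here is a plain-Python stand-in for what `int2str` and `str2int` do. This is only a sketch: the helper names `int2str_sketch` and `str2int_sketch` are made up for illustration and are not part of the `datasets` API. The integer id is assumed to be the position of the name in the ordered list of class names, which matches the `ClassLabel` output shown above.

```python
# Class names in the same order as the ClassLabel feature printed above.
label_names = [
    "Company", "EducationalInstitution", "Artist", "Athlete", "OfficeHolder",
    "MeanOfTransportation", "Building", "NaturalPlace", "Village", "Animal",
    "Plant", "Album", "Film", "WrittenWork",
]

# Reverse lookup: class name -> integer id.
name_to_id = {name: idx for idx, name in enumerate(label_names)}


def int2str_sketch(idx: int) -> str:
    """Integer label -> class name."""
    return label_names[idx]


def str2int_sketch(name: str) -> int:
    """Class name -> integer label."""
    return name_to_id[name]


print(int2str_sketch(4))          # OfficeHolder
print(str2int_sketch("Company"))  # 0
```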

For our topic classification task, we use `content` as input and try to predict `labels`.

In [14]:
label_list = dbpedia_dataset['train'].unique('labels')
label_list.sort()
label_list

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]

This dataset covers 14 topics, each represented by one of the 14 output classes of our model.

In [15]:
num_labels = len(label_list)
num_labels

14

The `label` column was renamed to `labels` in cell 11 above so that the `Trainer` can find it.

In [20]:
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)

In [21]:
def merge_title_with_content(example):
    example["content"] = example["title"] + " " + example["content"]
    return example


def encode(batch):
    return tokenizer(
        batch["content"],
        add_special_tokens=True,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        return_tensors="np",
    )
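
As a rough illustration of what `padding="max_length"` and `truncation=True` do in `encode`, the sketch below pads or cuts a single list of token ids to a fixed length and builds the matching attention mask. It is a simplification, not the tokenizer's implementation: the real tokenizer keeps the closing `[SEP]` token when it truncates, while this sketch simply cuts the list. The helper name `pad_or_truncate` is hypothetical, and `PAD_ID = 0` assumes the `[PAD]` id of `bert-base-cased`.

```python
from typing import Dict, List

PAD_ID = 0  # assumed id of [PAD] in the bert-base-cased vocabulary


def pad_or_truncate(token_ids: List[int], max_length: int) -> Dict[str, List[int]]:
    """Simplified view of padding="max_length" + truncation=True for one sequence."""
    kept = token_ids[:max_length]   # truncation: drop everything past max_length
    n_pad = max_length - len(kept)  # number of [PAD] ids to append
    return {
        "input_ids": kept + [PAD_ID] * n_pad,
        # 1 marks real tokens, 0 marks padding the model should ignore
        "attention_mask": [1] * len(kept) + [0] * n_pad,
    }


encoded = pad_or_truncate([101, 142, 119, 141, 119, 102], max_length=8)
print(encoded["input_ids"])       # [101, 142, 119, 141, 119, 102, 0, 0]
print(encoded["attention_mask"])  # [1, 1, 1, 1, 1, 1, 0, 0]
```

With `max_length=512`, short articles are mostly padding, which is why the attention mask matters.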

In [22]:
dbpedia_dataset = dbpedia_dataset.map(merge_title_with_content, num_proc=10)

 


In [23]:
dbpedia_dataset = dbpedia_dataset.map(encode, batched=True, num_proc=10)
# dbpedia_dataset = dbpedia_dataset.map(lambda x: encode(x['content']), batched=True)
dbpedia_dataset.set_format(type="torch", 
                           columns=["input_ids", "token_type_ids", 
                                    "attention_mask", "labels"])

 


In [26]:
# Let's look at an example of text from the dataset.
# You can see that we've successfully added the title to the content.
dbpedia_dataset['train']['content'][0]

'E. D. Abbott Ltd  Abbott of Farnham E D Abbott Limited was a British coachbuilding business based in Farnham Surrey trading under that name from 1929. A major part of their output was under sub-contract to motor vehicle manufacturers. Their business closed in 1972.'

In [27]:
# Labels are equally balanced.
np.unique(dbpedia_dataset['train']['labels'], return_counts=True)

(array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13]),
 array([40000, 40000, 40000, 40000, 40000, 40000, 40000, 40000, 40000,
        40000, 40000, 40000, 40000, 40000]))

In [28]:
# Test data is balanced too.
np.unique(dbpedia_dataset['test']['labels'], return_counts=True)

(array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13]),
 array([5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000,
        5000, 5000, 5000]))

In [29]:
# We can convert our input_ids back into words to see how the text was tokenized.
print(tokenizer.convert_ids_to_tokens(dbpedia_dataset['train']['input_ids'][0]))

['[CLS]', 'E', '.', 'D', '.', 'Abbott', 'Ltd', 'Abbott', 'of', 'Far', '##nham', 'E', 'D', 'Abbott', 'Limited', 'was', 'a', 'British', 'coach', '##building', 'business', 'based', 'in', 'Far', '##nham', 'Surrey', 'trading', 'under', 'that', 'name', 'from', '1929', '.', 'A', 'major', 'part', 'of', 'their', 'output', 'was', 'under', 'sub', '-', 'contract', 'to', 'motor', 'vehicle', 'manufacturers', '.', 'Their', 'business', 'closed', 'in', '1972', '.', '[SEP]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]'

In [15]:
# DEBUGGING - Splice dataset to use smaller number of samples
#
# train_dataloader = torch.utils.data.DataLoader(
#     dbpedia_dataset["train"].select(
#         list(
#             np.random.randint(low=0, high=len(dbpedia_dataset["train"]) - 1, size=1000)
#         )
#     ),
#     batch_size=BATCH_SIZE,
#     shuffle=True,
# )
# test_dataloader = torch.utils.data.DataLoader(
#     dbpedia_dataset["test"].select(
#         list(np.random.randint(low=0, high=len(dbpedia_dataset["test"]) - 1, size=1000))
#     ),
#     batch_size=BATCH_SIZE,
#     shuffle=False,
# )
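
The commented-out snippet above draws indices with `np.random.randint`, which samples with replacement, so the subset can contain duplicate rows. A minimal alternative sketch using only the standard library follows; the helper name `sample_indices` is hypothetical, and the sizes are taken from the split sizes reported in this notebook.

```python
import random


def sample_indices(dataset_size: int, n_samples: int, seed: int = 42) -> list:
    """Draw n_samples distinct row indices from range(dataset_size)."""
    rng = random.Random(seed)  # local generator, so the global seed is untouched
    return sorted(rng.sample(range(dataset_size), n_samples))


train_subset_idx = sample_indices(560_000, 1_000)
print(len(train_subset_idx), len(set(train_subset_idx)))
```

The resulting list could then be passed to `dbpedia_dataset["train"].select(...)` in place of the `randint` call.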

In [30]:
# Let's have a look at one example from the DataLoader.
train_dataloader = DataLoader(dbpedia_dataset['train'],
                              batch_size=1,
                              shuffle=True)

sample_item = next(iter(train_dataloader))

sample_item

{'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

In [31]:
# We create the dataloaders.
train_dataloader = DataLoader(dbpedia_dataset['train'],
                              batch_size=512,
                              shuffle=True)

# Evaluation does not need shuffling; a fixed order keeps results reproducible.
test_dataloader = DataLoader(dbpedia_dataset['test'],
                             batch_size=512,
                             shuffle=False)
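
For a sense of scale, the number of batches each loader yields per epoch follows directly from the split sizes (560,000 training rows, 70,000 test rows) and the batch size of 512. A small sketch of that arithmetic, assuming the `DataLoader` default of keeping the last, smaller batch; `batches_per_epoch` is a hypothetical helper:

```python
import math


def batches_per_epoch(num_rows: int, batch_size: int) -> int:
    """Batches per epoch when the final partial batch is kept (drop_last=False)."""
    return math.ceil(num_rows / batch_size)


print(batches_per_epoch(560_000, 512))  # train split: 1094
print(batches_per_epoch(70_000, 512))   # test split: 137
```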

In [32]:
baseline_model = AutoModelForSequenceClassification.from_pretrained(
    BASE_MODEL_NAME,
    output_attentions=False,
    output_hidden_states=False,
    return_dict=True,
    num_labels=num_labels
)


Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at b

In [38]:
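# Pass tensors by keyword so the call does not depend on the argument order of forward().
prediction = baseline_model(input_ids=sample_item['input_ids'],
                            attention_mask=sample_item['attention_mask'])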

In [39]:
prediction

SequenceClassifierOutput([('logits',
                           tensor([[-0.0932, -0.7690, -0.6399, -0.1519,  0.0639,  0.8291, -0.3584, -0.0162,
                                    -0.9446, -0.0527, -0.2492, -0.0557,  0.3037, -0.2569]],
                                  grad_fn=<AddmmBackward>))])

In [40]:
top_prediction = prediction.logits.detach().numpy().argmax().item()
print(top_prediction)
print(dbpedia_dataset['train'].features['labels'].int2str(top_prediction))
print(tokenizer.convert_ids_to_tokens(sample_item['input_ids'][0]))

5
MeanOfTransportation
['[CLS]', 'New', 'Castle', 'Area', 'School', 'District', 'New', 'Castle', 'Area', 'School', 'District', 'is', 'a', 'public', 'school', 'district', 'located', 'in', 'Lawrence', 'County', 'Pennsylvania', '.', 'The', 'district', 'serves', 'the', 'city', 'of', 'New', 'Castle', 'and', 'Taylor', 'Township', '.', 'New', 'Castle', 'Area', 'School', 'District', 'encompasses', 'approximately', '13', 'square', 'miles', '(', '34', 'km', '##2', ')', '.', 'According', 'to', '2007', 'local', 'census', 'data', 'it', 'served', 'a', 'resident', 'population', 'of', '264', '##64', '.', 'The', '2010', 'census', 'found', 'the', 'population', 'declined', 'to', '242', '##86', 'people', '.', '[SEP]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PA
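
The classification head has not been fine-tuned at this point, so the predicted label for this school-district article is essentially arbitrary. One way to see how weak the preference is: turn the logits printed above into probabilities with a softmax. The sketch below uses only the standard library and copies the logits from the output of cell 39; the `softmax` helper is written here for illustration.

```python
import math

# Logits copied from the SequenceClassifierOutput above (one value per class).
logits = [-0.0932, -0.7690, -0.6399, -0.1519, 0.0639, 0.8291, -0.3584,
          -0.0162, -0.9446, -0.0527, -0.2492, -0.0557, 0.3037, -0.2569]


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    m = max(values)                           # subtract the max to avoid overflow
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax(logits)
top_class = max(range(len(probs)), key=probs.__getitem__)
print(top_class, round(probs[top_class], 3))
```

The top class is index 5 (`MeanOfTransportation`), as reported by `argmax` above, but its probability stays well below one half.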

In [54]:
MAX_TOKEN_LENGTH = 128

In [55]:
# class dbpediaDataset(Dataset):

#   def __init__(
#       self,
#       data: Dataset,
#       tokenizer: AutoTokenizer.from_pretrained(BASE_MODEL_NAME),
#       max_token_len: int = MAX_TOKEN_LENGTH):
    
#     self.tokenizer = tokenizer
#     self.data = data
#     self.max_token_len = max_token_len

#   def __len__(self):
#     return len(self.data)

#   def __getitem__(self, index: int):
#     data_row = self.data

In [56]:
# class dbpediaDataModule(pl.LightningDataModule):

#   def __init__(self, train_dataloader, test_dataloader, batch_size=8, max_token_len=128)

In [41]:
dbpedia_dataset['train']

Dataset({
    features: ['attention_mask', 'content', 'input_ids', 'labels', 'title', 'token_type_ids'],
    num_rows: 560000
})

In [55]:
def get_topic(sentence, model, tokenizer=tokenizer):
  # tokenize the input with the tokenizer passed in (defaults to the notebook's tokenizer)
  inputs = tokenizer(
    sentence,
    add_special_tokens=True,
    max_length=512,
    padding="max_length",
    truncation=True,
    return_attention_mask=True,
    return_tensors="pt",
  )
  # ensure model and inputs are on the same device (GPU if available, otherwise CPU)
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
  model = model.to(device)
  model.eval()  # make sure dropout is disabled for inference
  # get prediction - logits for the 14 labels
  with torch.no_grad():
    predictions = model(**inputs)[0].cpu().numpy()
  # get the top prediction class and convert it to its associated label
  top_prediction = predictions.argmax().item()
  return dbpedia_dataset['train'].features['labels'].int2str(top_prediction)

In [56]:
sample_item = dbpedia_dataset['train']['content'][0]
sample_item

'E. D. Abbott Ltd  Abbott of Farnham E D Abbott Limited was a British coachbuilding business based in Farnham Surrey trading under that name from 1929. A major part of their output was under sub-contract to motor vehicle manufacturers. Their business closed in 1972.'

In [57]:
get_topic(sample_item, baseline_model)

'OfficeHolder'

In [51]:
tokenizer(sample_item, return_tensors='pt').items()

dict_items([('input_ids', tensor([[  101,   142,   119,   141,   119, 15176,  4492, 15176,  1104,  8040,
         15898,   142,   141, 15176,  5975,  1108,   170,  1418,  2154, 12851,
          1671,  1359,  1107,  8040, 15898,  9757,  6157,  1223,  1115,  1271,
          1121,  3762,   119,   138,  1558,  1226,  1104,  1147,  5964,  1108,
          1223,  4841,   118,  2329,  1106,  5968,  3686,  9263,   119,  2397,
          1671,  1804,  1107,  2388,   119,   102]])), ('token_type_ids', tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0]])), ('attention_mask', tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1]]))])