<a href="https://colab.research.google.com/github/JayThibs/Weak-Supervised-Learning-Case-Study/blob/main/text_classifier/notebooks/01_dbpedia_14_bert_classification_exploration.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [28]:
# Report the GPU attached to this runtime (model, driver/CUDA version, memory in use).
!nvidia-smi

Wed Apr 28 14:10:57 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   34C    P0    26W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [29]:
!pip install transformers==4.5.1 --quiet
!pip install pytorch_lightning==1.2.10 --quiet
!pip install wandb --quiet
!pip install datasets



In [30]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, AdamW, get_linear_schedule_with_warmup, AutoModelForSequenceClassification, AutoTokenizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional.classification import auroc
from datasets import load_dataset

import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt

from tqdm import tqdm
import wandb

In [51]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

RANDOM_SEED = 42
BASE_MODEL_NAME = 'bert-base-cased'

np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

<torch._C.Generator at 0x7ff3f95afb70>

In [52]:
# Load the DBpedia-14 dataset through Hugging Face `datasets`; a previously
# downloaded copy is reused from the local cache.
dbpedia_dataset = load_dataset('dbpedia_14')

Reusing dataset d_bpedia14 (/root/.cache/huggingface/datasets/d_bpedia14/dbpedia_14/2.0.0/7f0577ea0f4397b6b89bfe5c5f2c6b1b420990a1fc5e8538c7ab4ec40e46fa3e)


In [53]:
# Tokenizer for the checkpoint named by BASE_MODEL_NAME; `encode` below reads it as a global.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)

In [54]:
def merge_title_with_content(example):
    """Prefix an example's ``content`` with its ``title``, separated by one space.

    The mapping passed in is updated in place and that same object is
    returned, so it can be used directly as a per-example ``map`` function.
    """
    merged_text = " ".join((example["title"], example["content"]))
    example["content"] = merged_text
    return example


def encode(batch, max_length=512):
    """Tokenize the ``"content"`` texts of a batch with the global ``tokenizer``.

    Args:
        batch: Mapping whose ``"content"`` entry holds the text(s) to tokenize.
        max_length: Length every sequence is padded or truncated to. Defaults
            to 512, the value that used to be hard-coded, so existing
            ``map(encode, ...)`` calls behave as before. A smaller value can be
            supplied (e.g. through ``fn_kwargs`` of ``Dataset.map``) for quick
            debugging runs.

    Returns:
        The tokenizer output, requested as NumPy arrays and including the
        attention mask.
    """
    return tokenizer(
        batch["content"],
        add_special_tokens=True,
        max_length=max_length,
        padding="max_length",  # pad every example to exactly `max_length`
        truncation=True,
        return_attention_mask=True,
        return_tensors="np",
    )

In [55]:
# Preprocess the dataset: merge title into content, rename the label column to
# "labels", tokenize, and expose the model inputs as torch tensors.
#
# The cell is written to be safe to re-run: the original version failed on a
# second execution (the "label" column no longer exists after the rename) and
# would have prepended the title to the content a second time.
train_features = dbpedia_dataset["train"].info.features
# After a previous run the label column is already called "labels".
label_column = "labels" if "labels" in train_features else "label"
num_classes = train_features[label_column].num_classes

if "input_ids" not in dbpedia_dataset["train"].column_names:
    # Use intermediate names and rebind `dbpedia_dataset` only once all steps
    # have succeeded, so an interrupted run can simply be restarted.
    dbpedia_merged = dbpedia_dataset.map(merge_title_with_content, num_proc=10)
    dbpedia_renamed = dbpedia_merged.rename_column("label", "labels")
    dbpedia_dataset = dbpedia_renamed.map(encode, batched=True, num_proc=10)

# Only these columns are returned as torch tensors when indexing rows.
dbpedia_dataset.set_format(type="torch",
                           columns=["input_ids", "token_type_ids",
                                    "attention_mask", "labels"])

 

HBox(children=(FloatProgress(value=0.0, description='#0', max=56000.0, style=ProgressStyle(description_width='…

 

HBox(children=(FloatProgress(value=0.0, description='#1', max=56000.0, style=ProgressStyle(description_width='…

 

HBox(children=(FloatProgress(value=0.0, description='#2', max=56000.0, style=ProgressStyle(description_width='…

 

HBox(children=(FloatProgress(value=0.0, description='#3', max=56000.0, style=ProgressStyle(description_width='…

 

HBox(children=(FloatProgress(value=0.0, description='#4', max=56000.0, style=ProgressStyle(description_width='…

 

HBox(children=(FloatProgress(value=0.0, description='#5', max=56000.0, style=ProgressStyle(description_width='…

 

HBox(children=(FloatProgress(value=0.0, description='#6', max=56000.0, style=ProgressStyle(description_width='…

 

HBox(children=(FloatProgress(value=0.0, description='#7', max=56000.0, style=ProgressStyle(description_width='…

 

HBox(children=(FloatProgress(value=0.0, description='#8', max=56000.0, style=ProgressStyle(description_width='…

 

HBox(children=(FloatProgress(value=0.0, description='#9', max=56000.0, style=ProgressStyle(description_width='…











 

HBox(children=(FloatProgress(value=0.0, description='#0', max=7000.0, style=ProgressStyle(description_width='i…

 

HBox(children=(FloatProgress(value=0.0, description='#1', max=7000.0, style=ProgressStyle(description_width='i…

 

HBox(children=(FloatProgress(value=0.0, description='#2', max=7000.0, style=ProgressStyle(description_width='i…

  

HBox(children=(FloatProgress(value=0.0, description='#3', max=7000.0, style=ProgressStyle(description_width='i…

 

HBox(children=(FloatProgress(value=0.0, description='#4', max=7000.0, style=ProgressStyle(description_width='i…

 

HBox(children=(FloatProgress(value=0.0, description='#5', max=7000.0, style=ProgressStyle(description_width='i…

  

HBox(children=(FloatProgress(value=0.0, description='#6', max=7000.0, style=ProgressStyle(description_width='i…

 

HBox(children=(FloatProgress(value=0.0, description='#7', max=7000.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='#9', max=7000.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='#8', max=7000.0, style=ProgressStyle(description_width='i…











 

HBox(children=(FloatProgress(value=0.0, description='#0', max=56.0, style=ProgressStyle(description_width='ini…

 

HBox(children=(FloatProgress(value=0.0, description='#1', max=56.0, style=ProgressStyle(description_width='ini…

 

HBox(children=(FloatProgress(value=0.0, description='#2', max=56.0, style=ProgressStyle(description_width='ini…

 

HBox(children=(FloatProgress(value=0.0, description='#3', max=56.0, style=ProgressStyle(description_width='ini…

 

HBox(children=(FloatProgress(value=0.0, description='#4', max=56.0, style=ProgressStyle(description_width='ini…

 

HBox(children=(FloatProgress(value=0.0, description='#5', max=56.0, style=ProgressStyle(description_width='ini…

 

HBox(children=(FloatProgress(value=0.0, description='#6', max=56.0, style=ProgressStyle(description_width='ini…

 

HBox(children=(FloatProgress(value=0.0, description='#7', max=56.0, style=ProgressStyle(description_width='ini…

 

HBox(children=(FloatProgress(value=0.0, description='#8', max=56.0, style=ProgressStyle(description_width='ini…

 

HBox(children=(FloatProgress(value=0.0, description='#9', max=56.0, style=ProgressStyle(description_width='ini…











 

HBox(children=(FloatProgress(value=0.0, description='#0', max=7.0, style=ProgressStyle(description_width='init…

 

HBox(children=(FloatProgress(value=0.0, description='#1', max=7.0, style=ProgressStyle(description_width='init…

  

HBox(children=(FloatProgress(value=0.0, description='#2', max=7.0, style=ProgressStyle(description_width='init…

 

HBox(children=(FloatProgress(value=0.0, description='#3', max=7.0, style=ProgressStyle(description_width='init…

HBox(children=(FloatProgress(value=0.0, description='#4', max=7.0, style=ProgressStyle(description_width='init…

 

HBox(children=(FloatProgress(value=0.0, description='#5', max=7.0, style=ProgressStyle(description_width='init…

  

HBox(children=(FloatProgress(value=0.0, description='#6', max=7.0, style=ProgressStyle(description_width='init…

HBox(children=(FloatProgress(value=0.0, description='#7', max=7.0, style=ProgressStyle(description_width='init…

 

HBox(children=(FloatProgress(value=0.0, description='#8', max=7.0, style=ProgressStyle(description_width='init…

 

HBox(children=(FloatProgress(value=0.0, description='#9', max=7.0, style=ProgressStyle(description_width='init…













In [None]:
# NOTE(review): this cell repeats the rename/encode/set_format steps of the
# preprocessing cell above and was never executed. Unguarded, it fails once the
# "label" column has been renamed and re-tokenizes the whole dataset, so each
# step now runs only if it has not been applied yet. Consider deleting the cell.
train_columns = dbpedia_dataset["train"].column_names
if "label" in train_columns:
    dbpedia_dataset = dbpedia_dataset.rename_column("label", "labels")
if "input_ids" not in train_columns:
    dbpedia_dataset = dbpedia_dataset.map(encode, batched=True, num_proc=10)
dbpedia_dataset.set_format(type="torch",
                           columns=["input_ids", "token_type_ids",
                                    "attention_mask", "labels"])

In [65]:
# Display the splits with their columns and row counts.
dbpedia_dataset

DatasetDict({
    train: Dataset({
        features: ['attention_mask', 'content', 'input_ids', 'labels', 'title', 'token_type_ids'],
        num_rows: 560000
    })
    test: Dataset({
        features: ['attention_mask', 'content', 'input_ids', 'labels', 'title', 'token_type_ids'],
        num_rows: 70000
    })
})

In [64]:
# Peek at the first two "content" strings of the training split (title already prepended).
dbpedia_dataset['train']['content'][0:2]

['E. D. Abbott Ltd  Abbott of Farnham E D Abbott Limited was a British coachbuilding business based in Farnham Surrey trading under that name from 1929. A major part of their output was under sub-contract to motor vehicle manufacturers. Their business closed in 1972.',
 "Schwan-Stabilo  Schwan-STABILO is a German maker of pens for writing colouring and cosmetics as well as markers and highlighters for office use. It is the world's largest manufacturer of highlighter pens Stabilo Boss."]

In [58]:
# Count examples per label in the training split: the classes are equally
# balanced (40,000 examples for each of the 14 labels).
np.unique(dbpedia_dataset['train']['labels'], return_counts=True)

(array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13]),
 array([40000, 40000, 40000, 40000, 40000, 40000, 40000, 40000, 40000,
        40000, 40000, 40000, 40000, 40000]))

In [59]:
# Same check on the test split: 5,000 examples for each of the 14 labels.
np.unique(dbpedia_dataset['test']['labels'], return_counts=True)

(array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13]),
 array([5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000,
        5000, 5000, 5000]))

In [60]:
# DEBUGGING - Subsample the dataset to iterate on a smaller number of samples
#
# train_dataloader = torch.utils.data.DataLoader(
#     dbpedia_dataset["train"].select(
#         list(
#             np.random.randint(low=0, high=len(dbpedia_dataset["train"]) - 1, size=1000)
#         )
#     ),
#     batch_size=BATCH_SIZE,
#     shuffle=True,
# )
# test_dataloader = torch.utils.data.DataLoader(
#     dbpedia_dataset["test"].select(
#         list(np.random.randint(low=0, high=len(dbpedia_dataset["test"]) - 1, size=1000))
#     ),
#     batch_size=BATCH_SIZE,
#     shuffle=False,
# )

In [56]:
# Loader over the training split that yields one shuffled example per batch.
train_dataloader = DataLoader(dataset=dbpedia_dataset['train'], shuffle=True, batch_size=1)

In [57]:
# Pull one batch from the loader to inspect its contents.
next(iter(train_dataloader))

{'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0