<a href="https://colab.research.google.com/github/akiabe/Transformers/blob/master/Fine_tune_BERT_for_Multiclass_Classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [16]:
!pip install transformers

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/cd/40/866cbfac4601e0f74c7303d533a9c5d4a53858bd402e08e3e294dd271f25/transformers-4.2.1-py3-none-any.whl (1.8MB)
[K     |████████████████████████████████| 1.8MB 13.0MB/s 
[?25hCollecting tokenizers==0.9.4
[?25l  Downloading https://files.pythonhosted.org/packages/0f/1c/e789a8b12e28be5bc1ce2156cf87cb522b379be9cadc7ad8091a4cc107c4/tokenizers-0.9.4-cp36-cp36m-manylinux2010_x86_64.whl (2.9MB)
[K     |████████████████████████████████| 2.9MB 37.8MB/s 
Collecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)
[K     |████████████████████████████████| 890kB 49.8MB/s 
Building wheels for collected packages: sacremoses
  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone
  Created wheel for sacremoses: filename=sacremoses-0.0.43-cp36-none-any.whl size=893261 sha256=257bcfb37914

In [17]:
import pandas as pd

In [18]:
# Load the tab-separated corpus (the file has no header row, so column names
# are supplied here), then keep only the headline text and category label for
# articles from the five selected publishers.
# NOTE(review): read_csv's default quoting treats '"' as a quote character; if
# any headline contains an unbalanced double quote, rows may be merged — verify
# the row count, and consider quoting=csv.QUOTE_NONE if it is off.
df = pd.read_csv(
    'newsCorpora.csv',
    sep='\t',
    header=None,
    names=['ID', 'TITLE', 'URL', 'PUBLISHER', 'CATEGORY', 'STORY', 'HOSTNAME', 'TIMESTAMP'],
).loc[
    lambda frame: frame['PUBLISHER'].isin(
        ['Reuters', 'Huffington Post', 'Businessweek', 'Contactmusic.com', 'Daily Mail']
    ),
    ['TITLE', 'CATEGORY'],
]

In [19]:
from sklearn import model_selection

# First cut: hold out 20% of the rows, stratified on CATEGORY.
train, valid_test = model_selection.train_test_split(
    df, test_size=0.2, shuffle=True, random_state=123, stratify=df['CATEGORY']
)

# Second cut: halve the held-out rows into validation and test, again
# stratified on CATEGORY and using the same seed.
valid, test = model_selection.train_test_split(
    valid_test, test_size=0.5, shuffle=True, random_state=123, stratify=valid_test['CATEGORY']
)

# Replace the shuffled original row labels with a fresh 0..n-1 index so each
# split can be indexed by position.
train, valid, test = (part.reset_index(drop=True) for part in (train, valid, test))

In [20]:
# Preview the first three rows of the training split.
train.head(3)

Unnamed: 0,TITLE,CATEGORY
0,Fitch Lowers South Africa Credit-Rating Outloo...,b
1,FOREX-Dollar rises on US rate speculation afte...,b
2,UPDATE 2-Mexico's lower house generally approv...,b


In [21]:
# Count how many training rows fall in each CATEGORY value.
train["CATEGORY"].value_counts()

b    4502
e    4223
t    1219
m     728
Name: CATEGORY, dtype: int64

In [22]:
# Preview the first three rows of the validation split.
valid.head(3)

Unnamed: 0,TITLE,CATEGORY
0,China's IBM Scrutiny Highlights High Stakes in...,b
1,Kim Kardashian confirms wedding will NOT be te...,e
2,Amazon under investigation after worker was cr...,b


In [23]:
# Count how many validation rows fall in each CATEGORY value.
valid["CATEGORY"].value_counts()

b    562
e    528
t    153
m     91
Name: CATEGORY, dtype: int64

In [24]:
# Preview the first three rows of the test split.
test.head(3)

Unnamed: 0,TITLE,CATEGORY
0,Google set to open its first flagship store in...,t
1,How fear can be 'programmed' into infants by t...,m
2,US STOCKS-S&P 500 on path to new closing high;...,b


In [25]:
# Count how many test rows fall in each CATEGORY value.
test["CATEGORY"].value_counts()

b    563
e    528
t    152
m     91
Name: CATEGORY, dtype: int64

In [29]:
def one_hot_labels(frame, label_columns=('CATEGORY_b', 'CATEGORY_e', 'CATEGORY_t', 'CATEGORY_m')):
  """One-hot encode the CATEGORY column of `frame` as a 2-D array.

  Args:
    frame: DataFrame that contains a 'CATEGORY' column.
    label_columns: Dummy-column names to emit, in output column order.

  Returns:
    Array with one row per row of `frame` and one column per entry of
    `label_columns`.
  """
  # Pin the dtype so the labels are 0/1 integers regardless of the installed
  # pandas version's default dummy dtype.
  dummies = pd.get_dummies(frame, columns=['CATEGORY'], dtype='uint8')
  # reindex fixes the column order and turns a class that is absent from this
  # frame into an all-zero column instead of raising KeyError.
  return dummies.reindex(columns=list(label_columns), fill_value=0).values


y_train = one_hot_labels(train)
y_valid = one_hot_labels(valid)
y_test = one_hot_labels(test)

In [35]:
# Sanity-check the encoded labels: show the training label matrix, then the
# shape of each split's label matrix, one item per line.
print(
    f"y_train: {y_train}",
    f"y_train shape: {y_train.shape}",
    f"y_valid shape: {y_valid.shape}",
    f"y_test shape: {y_test.shape}",
    sep="\n",
)

y_train: [[1 0 0 0]
 [1 0 0 0]
 [1 0 0 0]
 ...
 [1 0 0 0]
 [0 1 0 0]
 [0 1 0 0]]
y_train shape: (10672, 4)
y_valid shape: (1334, 4)
y_test shape: (1334, 4)


In [30]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import transformers
from transformers import BertTokenizer, BertModel
import numpy as np
from tqdm import tqdm

In [31]:
class NewsDataset(Dataset):
  """Map-style dataset that tokenizes one text per item and pairs it with its label vector.

  Args:
    X: Indexable collection of texts; `X[index]` must return the text of item `index`.
    y: Indexable collection of label vectors aligned with `X`; its length
      defines the dataset length.
    tokenizer: Object exposing `encode_plus(text, ...)` that returns a mapping
      with "input_ids" and "attention_mask" entries.
    max_len: Number of token ids each encoded item is padded or truncated to.
  """

  def __init__(
      self, 
      X,
      y, 
      tokenizer,
      max_len,
      ):
    self.X = X
    self.y = y
    self.tokenizer = tokenizer
    self.max_len = max_len
  
  def __len__(self):
    """Return the number of items, taken from the label collection."""
    return len(self.y)
  
  def __getitem__(self, index):
    """Encode item `index`.

    Returns:
      dict with
        "ids":    LongTensor of token ids,
        "mask":   LongTensor attention mask,
        "labels": float32 tensor built from `y[index]`.
    """
    text = self.X[index]
    inputs = self.tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=self.max_len,
        # Pad short texts and cut long ones so every item has exactly
        # `max_len` ids; equal lengths are required for default batching.
        padding="max_length",
        truncation=True,
        )
    # The encoding stores the token ids under the plural key "input_ids".
    ids = inputs["input_ids"]
    mask = inputs["attention_mask"]

    return {
        "ids": torch.tensor(ids, dtype=torch.long),
        "mask": torch.tensor(mask, dtype=torch.long),
        # Explicit float32 keeps the label dtype independent of how `y` is stored.
        "labels": torch.tensor(self.y[index], dtype=torch.float32),
        }

In [None]:
max_len = 