### 0. Install and load library

In [None]:
#Restart kernel after installation: Runtime -> Restart runtime

#!pip install -U transformers sentencepiece datasets

In [29]:
from datasets import load_dataset
from datasets import DatasetDict
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification, pipeline
import torch
from tqdm import tqdm

### 1. Load dataset

In [2]:
datasets = load_dataset("ag_news")

Using custom data configuration default
Reusing dataset ag_news (/root/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548)


### 2. Split dataset

In [3]:
datasets.keys() #dict_keys(['train', 'test'])

# split original train set into train/val with 8/2 ratio
# gather together to process splits at once
train_valid = datasets['train'].train_test_split(test_size=.2)
train_valid_test_datasets = DatasetDict({
    'train': train_valid['train'],
    'valid': train_valid['test'],
    'test': datasets['test']
})

Loading cached split indices for dataset at /root/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548/cache-c0909072cca7619f.arrow and /root/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548/cache-0b15a6520adbb42b.arrow


In [4]:
# ['World', 'Sports', 'Business', 'Sci/Tech']
dataset_labels = datasets['train'].features['label'].names
id2label = {id: label for id, label in enumerate(dataset_labels)}
label2id = {label: id for id, label in id2label.items()}

### 3. Load model and tokenizer with correct configs 

In [7]:
model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=len(dataset_labels), label2id=label2id)
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_projector.bias', 'vocab_transform.weight']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'classifier.weight', 'pre_clas

### 4. Tokenize and convert dataset to DataLoader

In [9]:
encoded_datasets = train_valid_test_datasets.map(lambda examples: tokenizer(examples['text'], truncation=True, padding='max_length'), batched=True)
encoded_datasets = encoded_datasets.map(lambda examples: {'labels': examples['label']}, batched=True)
encoded_datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'], device='cuda')

train_dataloader = torch.utils.data.DataLoader(encoded_datasets['train'], batch_size=8)
val_dataloader = torch.utils.data.DataLoader(encoded_datasets['valid'], batch_size=8)
test_dataloader = torch.utils.data.DataLoader(encoded_datasets['test'], batch_size=8)

Loading cached processed dataset at /root/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548/cache-ac74273562359cbd.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548/cache-8d6a8250d98f19bc.arrow


HBox(children=(FloatProgress(value=0.0, max=8.0), HTML(value='')))

Loading cached processed dataset at /root/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548/cache-e0245248aaaccc68.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548/cache-941fdf2fe65c22db.arrow





HBox(children=(FloatProgress(value=0.0, max=8.0), HTML(value='')))




### 5. Prepare training arguments

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9, verbose=True) # unnecessary, for fun only
epochs = 3

model.to(device)

### 6. Training

In [None]:
print('*'*50)

for epoch in range(epochs):
    model.train()
     
    training_loss = 0.0 
    correct_pred = {label: 0 for label in dataset_labels}
    total_pred = {label: 0 for label in dataset_labels}

    for i, batch in enumerate(train_dataloader):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        training_loss += loss.item()
        if i%2000 == 1999:
            print(f"\n[Epoch-{epoch+1} {i+1}] {training_loss/2000}")
            training_loss = 0.0
    
    scheduler.step()
    
    # validation
    model.eval()
    with torch.no_grad():
        for batch in val_dataloader:
            batch = {k: v.to(device) for k,v in batch.items()}
            outputs = model(**batch)
            logits = outputs.logits
            _, predictions = torch.max(logits, 1)
            labels = batch['labels']

            # collect info for each class
            for label, prediction in zip(labels, predictions):
                if label == prediction:
                    correct_pred[id2label[label.item()]] += 1
                total_pred[id2label[label.item()]] += 1
    
    for label, correct_count in correct_pred.items():
        acc = 100 * float(correct_count) / total_pred[label]
        print(f"\nAccuracy for class {label} is {acc:.3f}")

### 7. Predicts on Test set

In [28]:
device = 'cpu'
model.to(device)

correct_pred = {label: 0 for label in dataset_labels}
total_pred = {label: 0 for label in dataset_labels}

model.eval()
with torch.no_grad():
    for batch in tqdm(test_dataloader, position=0, leave=True):
        batch = {k: v.to(device) for k,v in batch.items()}
        outputs = model(**batch)
        logits = outputs.logits
        _, predictions = torch.max(logits, 1)
        labels = batch['labels']

        # collect info for each class
        for label, prediction in zip(labels, predictions):
            if label == prediction:
                correct_pred[id2label[label.item()]] += 1
            total_pred[id2label[label.item()]] += 1

for label, correct_count in correct_pred.items():
    acc = 100 * float(correct_count) / total_pred[label]
    print(f"\nAccuracy for class {label} is {acc:.3f}")

100%|██████████| 13/13 [00:48<00:00,  3.71s/it]


Accuracy for class World is 90.000

Accuracy for class Sports is 100.000

Accuracy for class Business is 91.667

Accuracy for class Sci/Tech is 91.892





### 8. Inference using pipeline

In [30]:
ag_news_pipeline = pipeline('text-classification', model=model, tokenizer=tokenizer)

In [None]:
"""
datasets['test'][0]

>>> {
    'label': 2,
    'text': "Fears for T N pension after talks Unions representing workers at Turner   Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul."
    }

"""

In [40]:
test_input = datasets['test'][0]['text']
outputs = ag_news_pipeline(test_input)
print(outputs)

"""
>>> [{'label': 'LABEL_2', 'score': 0.9857332110404968}]
"""

[{'label': 'LABEL_2', 'score': 0.9857332110404968}]


"\n>>> [{'label': 'LABEL_2', 'score': 0.9857332110404968}]\n"