In [1]:
import os
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sn
from pprint import pprint
import textwrap

# Appearance of the Notebook
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

# PyTorch
import torch

# Hugging Face 
from transformers import pipeline, set_seed

# Import this module with autoreload
%load_ext autoreload
%autoreload 2
import transformermodels as tm
print(f'Package version: {tm.__version__}')
print(f'PyTorch version: {torch.__version__}')

Package version: 0.0.post1.dev3+g8362a6c.d20240713
PyTorch version: 2.3.1+cu121


In [2]:
# GPU checks
is_cuda = torch.cuda.is_available()
print(f'CUDA available: {is_cuda}')
print(f'Number of GPUs found:  {torch.cuda.device_count()}')

if is_cuda:
    print(f'Current device ID:     {torch.cuda.current_device()}')
    print(f'GPU device name:       {torch.cuda.get_device_name(0)}')
    print(f'CUDNN version:         {torch.backends.cudnn.version()}')
    device_str = 'cuda:0'
    torch.cuda.empty_cache() 
else:
    device_str = 'cpu'
device = torch.device(device_str)
print()
print(f'Device for model training/inference: {device}')

CUDA available: True
Number of GPUs found:  1
Current device ID:     0
GPU device name:       NVIDIA GeForce RTX 3070 Laptop GPU
CUDNN version:         8902

Device for model training/inference: cuda:0


In [16]:
# Helper functions
def wrap(x):
    """Wrap text to the default width of 70 columns, keeping existing newlines."""
    return textwrap.fill(x, replace_whitespace=False, fix_sentence_endings=True)

# Directories
# expanduser('~') also works where the HOME environment variable is not set
data_dir = os.path.join(os.path.expanduser('~'), 'data', 'transformers')

# The BBC News Data Set
csv_file = os.path.join(data_dir, 'bbc_text_cls.csv')

In [4]:
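# No model is specified, so the pipeline falls back to its default checkpoint
# (distilbert/distilroberta-base, as the message below reports)
mlm = pipeline('fill-mask', device=device)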

No model was supplied, defaulted to distilbert/distilroberta-base and revision ec58a5b (https://huggingface.co/distilbert/distilroberta-base).
Using a pipeline without specifying a model name and revision in production is not recommended.


Some weights of the model checkpoint at distilbert/distilroberta-base were not used when initializing RobertaForMaskedLM: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing RobertaForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).



In [6]:
display(mlm('The cat <mask> over the box'))

[{'score': 0.10449142009019852,
  'token': 13855,
  'token_str': ' jumps',
  'sequence': 'The cat jumps over the box'},
 {'score': 0.057583652436733246,
  'token': 33265,
  'token_str': ' crawling',
  'sequence': 'The cat crawling over the box'},
 {'score': 0.04840477555990219,
  'token': 33189,
  'token_str': ' leaping',
  'sequence': 'The cat leaping over the box'},
 {'score': 0.04716692492365837,
  'token': 10907,
  'token_str': ' climbing',
  'sequence': 'The cat climbing over the box'},
 {'score': 0.03080764412879944,
  'token': 32564,
  'token_str': ' leaps',
  'sequence': 'The cat leaps over the box'}]
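
The raw pipeline output is a list of dictionaries, which is verbose to scan. A minimal sketch of a helper that condenses it into (token, score) pairs follows. The name `summarize_predictions` is hypothetical and not part of `transformers`; it assumes each entry carries the `token_str` and `score` keys shown in the output above.

```python
def summarize_predictions(results, ndigits=3):
    """Condense fill-mask results into (token, rounded score) pairs.

    Assumes `results` is a list of dicts with 'token_str' and 'score' keys,
    as in the output above. Hypothetical helper, not a transformers API.
    """
    pairs = [(r['token_str'].strip(), round(r['score'], ndigits)) for r in results]
    # Sort by descending score in case the input is not already ordered
    return sorted(pairs, key=lambda p: p[1], reverse=True)


# Two entries copied from the output above
sample = [
    {'score': 0.10449142009019852, 'token': 13855, 'token_str': ' jumps',
     'sequence': 'The cat jumps over the box'},
    {'score': 0.057583652436733246, 'token': 33265, 'token_str': ' crawling',
     'sequence': 'The cat crawling over the box'},
]
print(summarize_predictions(sample))
```

In the notebook, the same helper could be applied directly to the pipeline result, for example `summarize_predictions(mlm('The cat <mask> over the box'))`.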

In [13]:
df = pd.read_csv(csv_file)
display(df.head())
print(df['labels'].unique())

# Pick a label
label = 'business'
texts = df.loc[df['labels'] == label, 'text']
print(len(texts))
# Use positional indexing: after filtering, index labels need not start at 0
print(texts.iloc[5])

Unnamed: 0  text                                               labels
0           Ad sales boost Time Warner profit\n\nQuarterly...  business
1           Dollar gains on Greenspan speech\n\nThe dollar...  business
2           Yukos unit buyer faces loan claim\n\nThe owner...  business
3           High fuel prices hit BA's profits\n\nBritish A...  business
4           Pernod takeover talk lifts Domecq\n\nShares in...  business


['business' 'entertainment' 'politics' 'sport' 'tech']
510
Japan narrowly escapes recession

Japan's economy teetered on the brink of a technical recession in the three months to September, figures show.

Revised figures indicated growth of just 0.1% - and a similar-sized contraction in the previous quarter. On an annual basis, the data suggests annual growth of just 0.2%, suggesting a much more hesitant recovery than had previously been thought. A common technical definition of a recession is two successive quarters of negative growth.

The government was keen to play down the worrying implications of the data. "I maintain the view that Japan's economy remains in a minor adjustment phase in an upward climb, and we will monitor developments carefully," said economy minister Heizo Takenaka. But in the face of the strengthening yen making exports less competitive and indications of weakening economic conditions ahead, observers were less sanguine. "It's painting a picture of a recovery..

In [30]:
np.random.seed(1234)
doc = np.random.choice(texts, size=1, replace=False)[0]
print(wrap(doc))

EU aiming to fuel development aid

European Union finance ministers
meet on Thursday to discuss proposals, including a tax on jet fuel, to
boost development aid for poorer nations.

The policy makers are to
ask for a report into how more development money can be raised, the EU
said.  The world's richest countries have said they want to increase
the amount of aid they give to 0.7% of their annual gross national
income by 2015. Airlines have reacted strongly against the proposed
fuel levy.

Profits have been under pressure in the airline industry,
with low-cost firms driving down prices and demand dipping after the
11 September terrorist attacks and the outbreak of the killer SARS
virus.

Things have picked up, but some European and US companies are
teetering on the brink of bankruptcy.  At present, the fuel used by
airlines enjoys either a very low tax rate or is untaxed in EU member
states.  "Of course we applaud humanitarian initiatives, but why
target the airlines?"  said Ulrich Schu
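
The wrapped output above breaks lines in odd places because `textwrap.fill` with `replace_whitespace=False` keeps the original newlines inside the text it is filling. A possible alternative, sketched here under the assumption that paragraphs are separated by blank lines, is to wrap each paragraph on its own. The name `wrap_paragraphs` is hypothetical and is not used elsewhere in this notebook.

```python
import textwrap


def wrap_paragraphs(text, width=70):
    """Wrap each blank-line-separated paragraph separately.

    Hypothetical helper: assumes paragraphs are delimited by '\n\n'.
    """
    # Drop empty paragraphs produced by repeated blank lines
    paragraphs = [p for p in text.split('\n\n') if p.strip()]
    wrapped = [
        # Collapse internal whitespace so stray newlines do not split lines
        textwrap.fill(' '.join(p.split()), width=width, fix_sentence_endings=True)
        for p in paragraphs
    ]
    return '\n\n'.join(wrapped)


example = 'Title line\n\n' + 'word ' * 40
print(wrap_paragraphs(example, width=30))
```

With this variant, `print(wrap_paragraphs(doc))` would keep the headline on its own line and fill every paragraph to the full width.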

In [26]:
mlm('EU aiming to fuel development <mask>')

[{'score': 0.08735514432191849,
  'token': 4026,
  'token_str': ' agenda',
  'sequence': 'EU aiming to fuel development agenda'},
 {'score': 0.08316706866025925,
  'token': 8600,
  'token_str': ' boom',
  'sequence': 'EU aiming to fuel development boom'},
 {'score': 0.05308239907026291,
  'token': 1042,
  'token_str': ' costs',
  'sequence': 'EU aiming to fuel development costs'},
 {'score': 0.03144887089729309,
  'token': 4,
  'token_str': '.',
  'sequence': 'EU aiming to fuel development.'},
 {'score': 0.02905908413231373,
  'token': 1170,
  'token_str': ' efforts',
  'sequence': 'EU aiming to fuel development efforts'}]
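
The prompt above was built by hand, replacing the last word of the headline ("aid") with `<mask>`; the true word does not appear among the five predictions. A minimal sketch for generating such prompts from any position follows. `mask_word` is a hypothetical helper that splits on whitespace only, and it assumes the `<mask>` token used in this notebook; other checkpoints may use a different token, available as `mlm.tokenizer.mask_token`.

```python
def mask_word(text, index, mask_token='<mask>'):
    """Replace the whitespace-separated word at `index` with the mask token.

    Returns the masked text and the word that was removed.
    Hypothetical helper; raises IndexError if `index` is out of range.
    """
    words = text.split()
    target = words[index]
    words[index] = mask_token
    return ' '.join(words), target


masked, target = mask_word('EU aiming to fuel development aid', -1)
print(masked)   # EU aiming to fuel development <mask>
print(target)   # aid
```

The masked string can then be passed to `mlm(masked)` and the returned `token_str` values compared against `target`.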