## Importing Libraries

In [5]:
import transformers
from datasets import load_dataset, Dataset, DatasetDict
from transformers import DistilBertTokenizer, TFAutoModelForSequenceClassification
import tensorflow as tf
import numpy as np

# Restrict the transformers library's logging to error-level messages only
transformers.logging.set_verbosity_error()

## Preparing Dataset

In [6]:
# Load the AG News dataset by name; the result is indexed by split
# ('train' / 'test') in the sampling cell further down.
dataset_name = 'ag_news'
ag_news = load_dataset(dataset_name)

In [7]:
# Fraction of each split to keep: 0.01 means 1% of the rows
sample_size = 0.01

def sample_dataset(dataset, sample_size, seed=42):
    """Return a reproducible random subset of ``dataset``.

    The dataset is shuffled with a fixed seed and the first
    ``int(len(dataset) * sample_size)`` rows of the shuffled result are kept.

    Args:
        dataset: Object exposing ``__len__``, ``shuffle(seed=...)`` and
            ``select(indices)`` (e.g. a ``datasets.Dataset``).
        sample_size: Fraction of rows to keep, in the range [0, 1].
        seed: Seed forwarded to ``shuffle``; defaults to 42 so repeated calls
            return the same subset.

    Returns:
        Whatever ``dataset.shuffle(...).select(...)`` returns.

    Raises:
        ValueError: If ``sample_size`` is not within [0, 1].
    """
    # Fail early with a clear message instead of an out-of-range index error
    # (fraction > 1) or a silently empty subset (negative fraction).
    if not 0 <= sample_size <= 1:
        raise ValueError(
            f"sample_size must be a fraction in [0, 1], got {sample_size!r}"
        )
    n_rows = int(len(dataset) * sample_size)
    return dataset.shuffle(seed=seed).select(range(n_rows))

## Creating Train and Test Dataset

In [8]:
# Draw the same fraction (sample_size, i.e. 1%) from the train and test splits
train_dataset = sample_dataset(ag_news['train'], sample_size)
test_dataset = sample_dataset(ag_news['test'], sample_size)

# Show the sampled training set summary and a few of its rows
print(train_dataset)
print(train_dataset[15:20])

# Print out label names
print("\nCategory labels used:", train_dataset.features['label'].names)

Dataset({
    features: ['text', 'label'],
    num_rows: 1200
})
{'text': ['They #146;re in the wrong ATHENS -- Matt Emmons was focusing on staying calm. He should have been focusing on the right target.', "Mularkey Sticking With Bledsoe As Bills QB (AP) AP - Mike Mularkey has a message to those clamoring for rookie quarterback J.P. Losman to replace Drew Bledsoe as Buffalo's starter. Not yet.", 'Greek membership of eurozone not in doubt BRUSSELS - Greece #39;s membership of the eurozone is not in doubt despite a damaging review of its budget data stretching back five years, a European Commission official said Monday.', "Some fear it's a passport to identity theft It's December 2005 and you're all set for Christmas in Vienna. You have your most fashionable cold-weather gear, right down to the red maple leaves embroidered on your jacket and backpack, to conceal your American citizenship from hostile denizens of Europe.", 'U.S. Plans Crackdown on Piracy, Counterfeiting  WASHINGTON (Reute

## Encoding Text

In [9]:
# Checkpoint name; reused later when loading the classification model so the
# tokenizer and the model come from the same pretrained checkpoint.
model_name = 'distilbert-base-uncased'
db_tokenizer = DistilBertTokenizer.from_pretrained(model_name)

In [10]:
def tokenize(batch):
    """Tokenize the 'text' column of a batch with the DistilBERT tokenizer.

    Args:
        batch: Mapping with a 'text' entry holding the raw strings.

    Returns:
        The tokenizer output for those strings, with padding and truncation
        both enabled.
    """
    texts = batch['text']
    encoded = db_tokenizer(texts, truncation=True, padding=True)
    return encoded

# Apply tokenization.
# batched=True with batch_size=None hands each whole split to `tokenize` as a
# single batch, so padding=True pads every row to that split's longest sequence.
enc_train_dataset = train_dataset.map(tokenize, batched=True, batch_size=None)
enc_test_dataset = test_dataset.map(tokenize, batched=True, batch_size=None)

# Print tokenized data (first five rows, including the columns added by the tokenizer)
print(enc_train_dataset[0:5])

{'text': ['Bangladesh paralysed by strikes Opposition activists have brought many towns and cities in Bangladesh to a halt, the day after 18 people died in explosions at a political rally.', 'Desiring Stability Redskins coach Joe Gibbs expects few major personnel changes in the offseason and wants to instill a culture of stability in Washington.', 'Will Putin #39;s Power Play Make Russia Safer? Outwardly, Russia has not changed since the barrage of terrorist attacks that culminated in the school massacre in Beslan on Sept.', 'U2 pitches for Apple New iTunes ads airing during baseball games Tuesday will feature the advertising-shy Irish rockers.', 'S African TV in beheading blunder Public broadcaster SABC apologises after news bulletin shows footage of American beheaded in Iraq.'], 'label': [0, 1, 0, 3, 0], 'input_ids': [[101, 7269, 11498, 2135, 6924, 2011, 9326, 4559, 10134, 2031, 2716, 2116, 4865, 1998, 3655, 1999, 7269, 2000, 1037, 9190, 1010, 1996, 2154, 2044, 2324, 2111, 2351, 1999

In [11]:
# Inspect a single encoded row: its raw text next to its attention mask
second_example = enc_train_dataset[1]
print("Text:", second_example.get('text'))
print("Attention Mask:", second_example.get('attention_mask'))

Text: Desiring Stability Redskins coach Joe Gibbs expects few major personnel changes in the offseason and wants to instill a culture of stability in Washington.
Attention Mask: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]


## Prepare TensorFlow datasets

In [None]:
# Names of the inputs the tokenizer reports for the model; used as the feature
# columns when converting the encoded datasets to TensorFlow datasets.
tokenizer_columns = db_tokenizer.model_input_names

def to_tf_dataset(dataset, tokenizer_columns, batch_size=64, shuffle=True):
    """Convert an encoded dataset into a batched TensorFlow dataset.

    Thin wrapper around ``dataset.to_tf_dataset`` that always uses the
    "label" column as the target.

    Args:
        dataset: Object exposing ``to_tf_dataset(...)`` (e.g. an encoded
            ``datasets.Dataset``).
        tokenizer_columns: Column names to feed to the model as features.
        batch_size: Number of examples per batch; defaults to 64.
        shuffle: Whether to shuffle the rows. Defaults to True (the previous
            hard-coded behaviour); pass False for validation/test data where
            shuffling is unnecessary.

    Returns:
        Whatever ``dataset.to_tf_dataset`` returns.
    """
    return dataset.to_tf_dataset(
        columns=tokenizer_columns,
        label_cols=["label"],
        shuffle=shuffle,
        batch_size=batch_size,
    )

In [13]:
# Batch size shared by the training, validation and (later) inference datasets
batch_size = 64
# Training set from the sampled train split; validation set from the sampled test split
train_dataset_tf = to_tf_dataset(enc_train_dataset, tokenizer_columns, batch_size)
val_dataset_tf = to_tf_dataset(enc_test_dataset, tokenizer_columns, batch_size)

## Load and Compile Model

In [14]:
# Load transformer model with one output class per AG News category
num_labels = len(train_dataset.features['label'].names)
sentiment_model = TFAutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)

# Keep the first layer trainable so it is fine-tuned along with the rest of the
# model. Set this to False instead to freeze it (uncommon in practice).
# NOTE(review): layers[0] is presumably the DistilBERT backbone — confirm.
sentiment_model.layers[0].trainable = True

# Compile the model. The loss is configured with from_logits=True, i.e. it
# expects raw (un-softmaxed) scores from the model.
sentiment_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=5e-5),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    # Keras documents `metrics` as a list, so pass the metric in list form
    metrics=[tf.metrics.SparseCategoricalAccuracy()],
)



## Train the model

In [None]:
# Fine-tune on the training set and evaluate on the validation set after each epoch.
# The returned History object is kept so per-epoch metrics can be inspected
# afterwards (history.history) and its repr no longer shows up as cell output.
history = sentiment_model.fit(
    train_dataset_tf,
    validation_data=val_dataset_tf,
    epochs=5  # Adjust epochs as needed
)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tf_keras.src.callbacks.History at 0x3222cd250>

## Prepare Inference Data

In [21]:
# Prepare data for inference: four hand-written headlines plus their expected
# class ids, kept only for comparison against the predictions.
infer_data = {
    'text': [
        'The stock market soared to new heights today as major indices hit record highs.',
        'Scientists have discovered a new species of dinosaur in the Arctic region.',
        'The local soccer team won the championship game after a thrilling penalty shootout.',
        'The latest breakthrough in AI technology promises to revolutionize the way we interact with machines.'
    ],
    # Class ids are positions in the dataset's label names; per the printed
    # results: 1 = Sports, 2 = Business, 3 = Sci/Tech
    'label': [2, 3, 1, 3]  # Corresponding labels for AG News classes
}

# Wrap the dict in a Dataset under a single 'infer' split so it can go through
# the same map/to_tf_dataset steps as the training data
infer_dataset = Dataset.from_dict(infer_data)
ds_dict = DatasetDict({'infer': infer_dataset})

In [22]:
# Tokenize the inference data with the same `tokenize` function used for the
# train/test splits
enc_infer_dataset = ds_dict.map(tokenize, batched=True, batch_size=None)

Map:   0%|          | 0/4 [00:00<?, ? examples/s]

In [23]:
# Convert inference dataset to TensorFlow format.
# No label column is passed (features only), and shuffle=False keeps the rows in
# their original order so predictions can be matched to infer_data by position.
infer_final_dataset = enc_infer_dataset["infer"].to_tf_dataset(
    columns=tokenizer_columns,
    shuffle=False,
    batch_size=batch_size
)

## Perform inference

In [24]:
# Run the fine-tuned model on the inference batch
predictions = sentiment_model.predict(infer_final_dataset)
# Pick the highest-scoring class along axis 1 of the logits
# (assumes logits are shaped (n_examples, num_labels) — TODO confirm)
pred_label_ids = np.argmax(predictions.logits, axis=1)



In [25]:
# Print predictions: map predicted and true class ids back to label names
labels = enc_train_dataset.features['label'].names
# Walk texts, true ids and predicted ids in lockstep instead of re-indexing by
# position (the previous loop variable for the prediction was never used).
for text, true_label_id, pred_label_id in zip(infer_data['text'], infer_data['label'], pred_label_ids):
    print("\nText:", text,
          "\n\tPredicted Label:", labels[pred_label_id],
          "\n\tTrue Label:", labels[true_label_id])


Text: The stock market soared to new heights today as major indices hit record highs. 
	Predicted Label: Business 
	True Label: Business

Text: Scientists have discovered a new species of dinosaur in the Arctic region. 
	Predicted Label: Sci/Tech 
	True Label: Sci/Tech

Text: The local soccer team won the championship game after a thrilling penalty shootout. 
	Predicted Label: Sports 
	True Label: Sports

Text: The latest breakthrough in AI technology promises to revolutionize the way we interact with machines. 
	Predicted Label: Sci/Tech 
	True Label: Sci/Tech
