In [15]:
import os
os.environ["WANDB_DISABLED"] = "true"

import pandas as pd
import torch
import transformers
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, EarlyStoppingCallback

In [16]:
df = pd.read_csv("topic-analysis-dataset.csv")
print(df.head())

label_encoder = LabelEncoder()
df['label'] = label_encoder.fit_transform(df['topic'])
num_labels = len(label_encoder.classes_)

train_texts, val_texts, train_labels, val_labels = train_test_split(
    df['text'], df['label'],
    test_size=0.2,
    random_state=42,
    stratify=df['label']  #to ensure equal class distribution
)

train_dataset = Dataset.from_dict({'text': train_texts.tolist(), 'label': train_labels.tolist()})
val_dataset = Dataset.from_dict({'text': val_texts.tolist(), 'label': val_labels.tolist()})

                                                text topic
0  usually , he would be tearing around the livin...  book
1  but just one look at a minion sent him practic...  book
2  that had been megan 's plan when she got him d...  book
3  he 'd seen the movie almost by mistake , consi...  book
4  she liked to think being surrounded by adults ...  book


In [3]:
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

def tokenize(batch):
    return tokenizer(batch["text"], padding="max_length", truncation=True, max_length=256)

train_dataset = train_dataset.map(tokenize, batched=True)
val_dataset = val_dataset.map(tokenize, batched=True)

train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
val_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])

Map: 100%|██████████| 60000/60000 [00:15<00:00, 3871.26 examples/s]
Map: 100%|██████████| 15000/15000 [00:03<00:00, 3801.82 examples/s]


In [4]:
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=num_labels)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
training_args = TrainingArguments(
    output_dir="./bert-topic-classifier",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,  # increase for speed on CPU with 32 GB RAM
    per_device_eval_batch_size=16,
    num_train_epochs=1,              # DistilBERT trains quickly, so 3 epochs is enough
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=50,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    save_total_limit=2,
    fp16=True,  # disable if training on CPU
    disable_tqdm=False,
    seed=42
)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


In [None]:
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = torch.argmax(torch.tensor(logits), dim=-1)
    return {
        "accuracy": accuracy_score(labels, preds),
        "f1": f1_score(labels, preds, average="weighted"),
    }

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]
)

trainer.train()

model.save_pretrained("./bert-topic-classifier")
tokenizer.save_pretrained("./bert-topic-classifier")

  self.scaler = torch.cuda.amp.GradScaler(**kwargs)
  1%|▏         | 51/3750 [00:06<07:24,  8.32it/s]

{'loss': 0.4743, 'grad_norm': 0.455658882856369, 'learning_rate': 1.9733333333333336e-05, 'epoch': 0.01}


  3%|▎         | 101/3750 [00:12<07:17,  8.33it/s]

{'loss': 0.0423, 'grad_norm': 0.12707911431789398, 'learning_rate': 1.9477333333333334e-05, 'epoch': 0.03}


  4%|▍         | 151/3750 [00:18<07:14,  8.28it/s]

{'loss': 0.0225, 'grad_norm': 0.06316158920526505, 'learning_rate': 1.921066666666667e-05, 'epoch': 0.04}


  5%|▌         | 201/3750 [00:24<07:09,  8.26it/s]

{'loss': 0.0112, 'grad_norm': 0.04801322892308235, 'learning_rate': 1.8944000000000004e-05, 'epoch': 0.05}


  7%|▋         | 251/3750 [00:30<07:05,  8.22it/s]

{'loss': 0.0222, 'grad_norm': 0.07207075506448746, 'learning_rate': 1.8677333333333335e-05, 'epoch': 0.07}


  8%|▊         | 301/3750 [00:36<07:00,  8.21it/s]

{'loss': 0.0263, 'grad_norm': 0.04337659478187561, 'learning_rate': 1.8410666666666666e-05, 'epoch': 0.08}


  9%|▉         | 351/3750 [00:42<06:54,  8.20it/s]

{'loss': 0.0113, 'grad_norm': 0.02709384448826313, 'learning_rate': 1.8144e-05, 'epoch': 0.09}


 11%|█         | 401/3750 [00:48<06:47,  8.22it/s]

{'loss': 0.0195, 'grad_norm': 4.015676021575928, 'learning_rate': 1.7877333333333336e-05, 'epoch': 0.11}


 12%|█▏        | 451/3750 [00:54<06:42,  8.19it/s]

{'loss': 0.0273, 'grad_norm': 0.022494181990623474, 'learning_rate': 1.7616000000000002e-05, 'epoch': 0.12}


 13%|█▎        | 501/3750 [01:00<06:38,  8.16it/s]

{'loss': 0.0013, 'grad_norm': 0.018728483468294144, 'learning_rate': 1.7349333333333337e-05, 'epoch': 0.13}


 15%|█▍        | 551/3750 [01:06<06:30,  8.20it/s]

{'loss': 0.0044, 'grad_norm': 0.012642888352274895, 'learning_rate': 1.7082666666666668e-05, 'epoch': 0.15}


 16%|█▌        | 601/3750 [01:16<07:32,  6.96it/s]

{'loss': 0.0038, 'grad_norm': 0.011263432912528515, 'learning_rate': 1.6816e-05, 'epoch': 0.16}


 17%|█▋        | 651/3750 [01:22<06:15,  8.25it/s]

{'loss': 0.0105, 'grad_norm': 0.06334268301725388, 'learning_rate': 1.6549333333333334e-05, 'epoch': 0.17}


 19%|█▊        | 701/3750 [01:28<06:10,  8.24it/s]

{'loss': 0.0194, 'grad_norm': 0.010712861083447933, 'learning_rate': 1.628266666666667e-05, 'epoch': 0.19}


 20%|██        | 751/3750 [01:34<06:05,  8.21it/s]

{'loss': 0.0013, 'grad_norm': 0.009780782274901867, 'learning_rate': 1.6016e-05, 'epoch': 0.2}


 21%|██▏       | 801/3750 [01:40<06:00,  8.19it/s]

{'loss': 0.0181, 'grad_norm': 0.008610731922090054, 'learning_rate': 1.5749333333333335e-05, 'epoch': 0.21}


 23%|██▎       | 851/3750 [01:51<06:00,  8.04it/s]

{'loss': 0.0006, 'grad_norm': 0.006961142178624868, 'learning_rate': 1.5482666666666667e-05, 'epoch': 0.23}


 24%|██▍       | 901/3750 [01:57<05:45,  8.25it/s]

{'loss': 0.0048, 'grad_norm': 0.0072786081582307816, 'learning_rate': 1.5216000000000001e-05, 'epoch': 0.24}


 25%|██▌       | 951/3750 [02:03<05:40,  8.22it/s]

{'loss': 0.012, 'grad_norm': 0.011805880814790726, 'learning_rate': 1.4949333333333333e-05, 'epoch': 0.25}


 27%|██▋       | 1001/3750 [02:09<05:34,  8.21it/s]

{'loss': 0.0121, 'grad_norm': 5.335690021514893, 'learning_rate': 1.4682666666666667e-05, 'epoch': 0.27}


 28%|██▊       | 1051/3750 [02:15<05:29,  8.20it/s]

{'loss': 0.0146, 'grad_norm': 0.00791691243648529, 'learning_rate': 1.4416e-05, 'epoch': 0.28}


 29%|██▉       | 1101/3750 [02:26<05:21,  8.23it/s]

{'loss': 0.0012, 'grad_norm': 0.0064302110113203526, 'learning_rate': 1.4149333333333335e-05, 'epoch': 0.29}


 31%|███       | 1151/3750 [02:32<05:16,  8.21it/s]

{'loss': 0.0053, 'grad_norm': 0.004470003768801689, 'learning_rate': 1.3882666666666668e-05, 'epoch': 0.31}


 32%|███▏      | 1201/3750 [02:38<05:10,  8.20it/s]

{'loss': 0.0003, 'grad_norm': 0.004985154140740633, 'learning_rate': 1.3616e-05, 'epoch': 0.32}


 33%|███▎      | 1251/3750 [02:44<05:04,  8.21it/s]

{'loss': 0.0003, 'grad_norm': 0.004128571599721909, 'learning_rate': 1.3349333333333334e-05, 'epoch': 0.33}


 35%|███▍      | 1301/3750 [02:50<04:57,  8.23it/s]

{'loss': 0.0114, 'grad_norm': 5.688532829284668, 'learning_rate': 1.3088e-05, 'epoch': 0.35}


 36%|███▌      | 1351/3750 [02:56<04:54,  8.14it/s]

{'loss': 0.039, 'grad_norm': 0.005249334499239922, 'learning_rate': 1.2821333333333334e-05, 'epoch': 0.36}


 37%|███▋      | 1400/3750 [03:06<15:16,  2.56it/s]

{'loss': 0.0115, 'grad_norm': 0.010301298461854458, 'learning_rate': 1.2554666666666669e-05, 'epoch': 0.37}


 39%|███▊      | 1451/3750 [03:13<04:38,  8.24it/s]

{'loss': 0.0005, 'grad_norm': 0.004811098799109459, 'learning_rate': 1.2288e-05, 'epoch': 0.39}


 40%|████      | 1501/3750 [03:19<04:32,  8.24it/s]

{'loss': 0.0003, 'grad_norm': 0.005671632941812277, 'learning_rate': 1.2021333333333333e-05, 'epoch': 0.4}


 41%|████▏     | 1551/3750 [03:25<04:27,  8.23it/s]

{'loss': 0.0147, 'grad_norm': 0.007145676761865616, 'learning_rate': 1.1754666666666668e-05, 'epoch': 0.41}


 43%|████▎     | 1601/3750 [03:31<04:23,  8.16it/s]

{'loss': 0.0087, 'grad_norm': 0.01032024621963501, 'learning_rate': 1.1488e-05, 'epoch': 0.43}


 44%|████▍     | 1651/3750 [03:37<04:19,  8.07it/s]

{'loss': 0.0085, 'grad_norm': 0.007098798174411058, 'learning_rate': 1.1221333333333336e-05, 'epoch': 0.44}


 45%|████▌     | 1700/3750 [03:47<14:51,  2.30it/s]

{'loss': 0.0141, 'grad_norm': 0.004619240295141935, 'learning_rate': 1.0954666666666667e-05, 'epoch': 0.45}


 47%|████▋     | 1751/3750 [03:54<04:01,  8.27it/s]

{'loss': 0.0007, 'grad_norm': 0.003545518033206463, 'learning_rate': 1.0688e-05, 'epoch': 0.47}


 48%|████▊     | 1801/3750 [04:00<03:56,  8.24it/s]

{'loss': 0.0035, 'grad_norm': 0.003170543583109975, 'learning_rate': 1.0421333333333335e-05, 'epoch': 0.48}


 49%|████▉     | 1851/3750 [04:06<03:51,  8.21it/s]

{'loss': 0.0022, 'grad_norm': 0.003631798317655921, 'learning_rate': 1.0154666666666668e-05, 'epoch': 0.49}


 51%|█████     | 1901/3750 [04:12<03:49,  8.06it/s]

{'loss': 0.0002, 'grad_norm': 0.003186260350048542, 'learning_rate': 9.888000000000001e-06, 'epoch': 0.51}


 52%|█████▏    | 1951/3750 [04:18<03:41,  8.14it/s]

{'loss': 0.0066, 'grad_norm': 0.003792922245338559, 'learning_rate': 9.621333333333334e-06, 'epoch': 0.52}


 53%|█████▎    | 2001/3750 [04:30<03:32,  8.23it/s]

{'loss': 0.0024, 'grad_norm': 0.003110329620540142, 'learning_rate': 9.354666666666667e-06, 'epoch': 0.53}


 55%|█████▍    | 2051/3750 [04:36<03:25,  8.25it/s]

{'loss': 0.006, 'grad_norm': 0.00438734982162714, 'learning_rate': 9.088000000000002e-06, 'epoch': 0.55}


 56%|█████▌    | 2101/3750 [04:42<03:19,  8.26it/s]

{'loss': 0.0021, 'grad_norm': 0.0024841176345944405, 'learning_rate': 8.821333333333333e-06, 'epoch': 0.56}


 57%|█████▋    | 2151/3750 [04:48<03:15,  8.18it/s]

{'loss': 0.0002, 'grad_norm': 0.0019794481340795755, 'learning_rate': 8.554666666666668e-06, 'epoch': 0.57}


 59%|█████▊    | 2201/3750 [04:54<03:10,  8.13it/s]

{'loss': 0.0054, 'grad_norm': 0.0017471067840233445, 'learning_rate': 8.288000000000001e-06, 'epoch': 0.59}


 60%|██████    | 2251/3750 [05:05<03:03,  8.17it/s]

{'loss': 0.0022, 'grad_norm': 0.0024947526399046183, 'learning_rate': 8.021333333333334e-06, 'epoch': 0.6}


 61%|██████▏   | 2301/3750 [05:11<02:55,  8.26it/s]

{'loss': 0.0018, 'grad_norm': 0.00197832053527236, 'learning_rate': 7.754666666666667e-06, 'epoch': 0.61}


 63%|██████▎   | 2351/3750 [05:17<02:49,  8.28it/s]

{'loss': 0.0015, 'grad_norm': 0.0022134361788630486, 'learning_rate': 7.488000000000001e-06, 'epoch': 0.63}


 64%|██████▍   | 2401/3750 [05:23<02:44,  8.18it/s]

{'loss': 0.0126, 'grad_norm': 0.0021320602390915155, 'learning_rate': 7.221333333333333e-06, 'epoch': 0.64}


 65%|██████▌   | 2451/3750 [05:29<02:39,  8.17it/s]

{'loss': 0.0044, 'grad_norm': 0.001906339661218226, 'learning_rate': 6.954666666666667e-06, 'epoch': 0.65}


 67%|██████▋   | 2501/3750 [05:35<02:35,  8.06it/s]

{'loss': 0.0137, 'grad_norm': 0.004570585675537586, 'learning_rate': 6.688e-06, 'epoch': 0.67}


 68%|██████▊   | 2550/3750 [05:44<07:42,  2.59it/s]

{'loss': 0.0003, 'grad_norm': 0.0023480840027332306, 'learning_rate': 6.421333333333334e-06, 'epoch': 0.68}


 69%|██████▉   | 2601/3750 [05:52<02:19,  8.24it/s]

{'loss': 0.0015, 'grad_norm': 0.0023851811420172453, 'learning_rate': 6.154666666666668e-06, 'epoch': 0.69}


 71%|███████   | 2651/3750 [05:58<02:13,  8.22it/s]

{'loss': 0.0001, 'grad_norm': 0.0016720193671062589, 'learning_rate': 5.888e-06, 'epoch': 0.71}


 72%|███████▏  | 2701/3750 [06:04<02:09,  8.13it/s]

{'loss': 0.0002, 'grad_norm': 0.002168837934732437, 'learning_rate': 5.621333333333334e-06, 'epoch': 0.72}


 73%|███████▎  | 2751/3750 [06:10<02:02,  8.14it/s]

{'loss': 0.0106, 'grad_norm': 0.0024712455924600363, 'learning_rate': 5.354666666666667e-06, 'epoch': 0.73}


 75%|███████▍  | 2800/3750 [06:17<04:19,  3.66it/s]

{'loss': 0.0162, 'grad_norm': 0.0019406556384637952, 'learning_rate': 5.088000000000001e-06, 'epoch': 0.75}


 76%|███████▌  | 2851/3750 [06:27<01:48,  8.26it/s]

{'loss': 0.0174, 'grad_norm': 0.004092625807970762, 'learning_rate': 4.821333333333334e-06, 'epoch': 0.76}


 77%|███████▋  | 2901/3750 [06:33<01:42,  8.25it/s]

{'loss': 0.0006, 'grad_norm': 0.0018610279075801373, 'learning_rate': 4.554666666666667e-06, 'epoch': 0.77}


 79%|███████▊  | 2951/3750 [06:39<01:38,  8.12it/s]

{'loss': 0.0002, 'grad_norm': 0.0028861667960882187, 'learning_rate': 4.288e-06, 'epoch': 0.79}


 80%|████████  | 3001/3750 [06:45<01:31,  8.20it/s]

{'loss': 0.0017, 'grad_norm': 0.00245651975274086, 'learning_rate': 4.021333333333333e-06, 'epoch': 0.8}


 81%|████████▏ | 3051/3750 [06:51<01:26,  8.06it/s]

{'loss': 0.0009, 'grad_norm': 0.002112875459715724, 'learning_rate': 3.754666666666667e-06, 'epoch': 0.81}


 83%|████████▎ | 3101/3750 [06:57<01:20,  8.06it/s]

{'loss': 0.0001, 'grad_norm': 0.0016546031692996621, 'learning_rate': 3.4880000000000003e-06, 'epoch': 0.83}


 84%|████████▍ | 3151/3750 [07:03<01:14,  8.04it/s]

{'loss': 0.0001, 'grad_norm': 0.002248524921014905, 'learning_rate': 3.2213333333333334e-06, 'epoch': 0.84}


 85%|████████▌ | 3201/3750 [07:15<01:06,  8.25it/s]

{'loss': 0.0006, 'grad_norm': 0.0014898525550961494, 'learning_rate': 2.954666666666667e-06, 'epoch': 0.85}


 87%|████████▋ | 3251/3750 [07:21<01:00,  8.23it/s]

{'loss': 0.0083, 'grad_norm': 0.0014701859327033162, 'learning_rate': 2.688e-06, 'epoch': 0.87}


 88%|████████▊ | 3301/3750 [07:27<00:54,  8.20it/s]

{'loss': 0.0001, 'grad_norm': 0.0017805667594075203, 'learning_rate': 2.4213333333333334e-06, 'epoch': 0.88}


 89%|████████▉ | 3351/3750 [07:33<00:48,  8.17it/s]

{'loss': 0.0001, 'grad_norm': 0.0016189850866794586, 'learning_rate': 2.154666666666667e-06, 'epoch': 0.89}


 91%|█████████ | 3401/3750 [07:39<00:43,  8.05it/s]

{'loss': 0.0143, 'grad_norm': 0.0018365562427788973, 'learning_rate': 1.8880000000000002e-06, 'epoch': 0.91}


 92%|█████████▏| 3451/3750 [07:45<00:36,  8.13it/s]

{'loss': 0.0003, 'grad_norm': 0.0019030782859772444, 'learning_rate': 1.6213333333333335e-06, 'epoch': 0.92}


 93%|█████████▎| 3501/3750 [07:51<00:31,  8.02it/s]

{'loss': 0.0099, 'grad_norm': 0.0016853839624673128, 'learning_rate': 1.354666666666667e-06, 'epoch': 0.93}


 95%|█████████▍| 3551/3750 [08:03<00:24,  7.99it/s]

{'loss': 0.0001, 'grad_norm': 0.0017265642527490854, 'learning_rate': 1.088e-06, 'epoch': 0.95}


 96%|█████████▌| 3601/3750 [08:09<00:18,  8.23it/s]

{'loss': 0.0102, 'grad_norm': 0.0016302590956911445, 'learning_rate': 8.213333333333334e-07, 'epoch': 0.96}


 97%|█████████▋| 3651/3750 [08:15<00:12,  8.16it/s]

{'loss': 0.009, 'grad_norm': 0.0016861767508089542, 'learning_rate': 5.546666666666667e-07, 'epoch': 0.97}


 99%|█████████▊| 3701/3750 [08:21<00:06,  8.11it/s]

{'loss': 0.0091, 'grad_norm': 0.00156536849681288, 'learning_rate': 2.8800000000000004e-07, 'epoch': 0.99}


100%|██████████| 3750/3750 [08:27<00:00,  8.23it/s]

{'loss': 0.0039, 'grad_norm': 0.0018155795987695456, 'learning_rate': 2.1333333333333336e-08, 'epoch': 1.0}


                                                   
100%|██████████| 3750/3750 [09:02<00:00,  8.23it/s]

{'eval_loss': 0.001749207847751677, 'eval_accuracy': 0.9994, 'eval_f1': 0.9993999799859982, 'eval_runtime': 34.7837, 'eval_samples_per_second': 431.236, 'eval_steps_per_second': 26.967, 'epoch': 1.0}


100%|██████████| 3750/3750 [09:03<00:00,  6.89it/s]


{'train_runtime': 543.9451, 'train_samples_per_second': 110.305, 'train_steps_per_second': 6.894, 'train_loss': 0.014095021969079972, 'epoch': 1.0}


In [20]:
modelpath = "bert-topic-classifier/checkpoint-3750" #cahnge if needed
model = AutoModelForSequenceClassification.from_pretrained(
    modelpath,
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

df = pd.read_csv("../test-datasets/sentiment-topic-test.tsv", sep="\t")
sentences = df["sentence"].tolist()
true_labels = df["topic"].tolist()
print(true_labels)
model.eval()
inputs = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)

label_map = {0: "book", 1: "movie", 2: "sports"}
inverse_label_map = {v: k for k, v in label_map.items()}
predicted_labels = [label_map[p.item()] for p in predictions]
print(predicted_labels)
true_label_ids = [inverse_label_map[label] for label in true_labels]

print("Classification Report:")
print(classification_report(true_labels, predicted_labels))

print("\nConfusion Matrix:")
print(confusion_matrix(true_labels, predicted_labels, labels=["negative", "neutral", "positive"]))

for index in range(len(sentences)):
    print(f"Sentence: {sentences[index]}, true label: {true_labels[index]}, predicted labels: {predicted_labels[index]}")




['sports', 'sports', 'book', 'book', 'book', 'movie', 'movie', 'sports', 'sports', 'movie', 'book', 'book', 'movie', 'sports', 'sports', 'movie', 'movie', 'book']
['book', 'book', 'book', 'book', 'book', 'movie', 'book', 'book', 'book', 'book', 'book', 'book', 'movie', 'book', 'book', 'book', 'book', 'book']
Classification Report:
              precision    recall  f1-score   support

        book       0.38      1.00      0.55         6
       movie       1.00      0.33      0.50         6
      sports       0.00      0.00      0.00         6

    accuracy                           0.44        18
   macro avg       0.46      0.44      0.35        18
weighted avg       0.46      0.44      0.35        18


Confusion Matrix:


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


ValueError: At least one label specified must be in y_true