<a href="https://colab.research.google.com/github/ahanam05/deep-learning/blob/main/Transaction_Classification_with_DistilBERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [49]:
import torch
import torch.nn as nn
import torch.optim as optim
from time import time
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import DistilBertModel, DistilBertTokenizer

In [50]:
# configuration and data
# Training hyperparameters, followed by DATA: the labelled examples as
# (transaction description, category) pairs.

NUM_EPOCHS = 10  # number of full passes over the training split
MAX_LENGTH = 32  # token length every text is padded/truncated to
LEARNING_RATE = 2e-5  # learning rate handed to the AdamW optimizer
BATCH_SIZE = 8  # examples per batch for both the train and test loaders
TEST_SPLIT_RATIO = 0.2  # fraction of the examples held out for testing

DATA = [
    ("monthly rent payment", "housing"),
    ("apartment rent due", "housing"),
    ("house mortgage payment", "housing"),
    ("home loan EMI", "housing"),
    ("property tax annual payment", "housing"),
    ("housing society maintenance", "housing"),
    ("apartment maintenance charges", "housing"),
    ("condo association fees", "housing"),
    ("house insurance premium", "housing"),
    ("home security system subscription", "housing"),
    ("pest control service", "housing"),
    ("plumbing repair charges", "housing"),
    ("electrical wiring fix", "housing"),
    ("refrigerator maintenance", "housing"),
    ("geyser installation", "housing"),
    ("ceiling fan replacement", "housing"),
    ("light fixtures purchase", "housing"),
    ("door lock replacement", "housing"),
    ("window screen repair", "housing"),
    ("garage door maintenance", "housing"),
    ("garden maintenance service", "housing"),
    ("lawn mowing service", "housing"),
    ("landscaping charges", "housing"),
    ("home cleaning service", "housing"),
    ("deep cleaning charges", "housing"),
    ("carpet cleaning service", "housing"),
    ("chimney sweep service", "housing"),
    ("septic tank cleaning", "housing"),
    ("water tank cleaning", "housing"),
    ("balcony waterproofing", "housing"),
    ("wall seepage repair", "housing"),
    ("kitchen cabinet repair", "housing"),
    ("wardrobe installation", "housing"),
    ("shelving unit purchase", "housing"),
    ("home decoration items", "housing"),
    ("wall paint cans", "housing"),
    ("pillow and cushion set", "housing"),
    ("towel set purchase", "housing"),
    ("kitchen utensils and cookware", "housing"),
    ("dinner plate set", "housing"),
    ("glassware collection", "housing"),
    ("trash bins purchase", "housing"),
    ("vacuum cleaner purchase", "housing"),
    ("iron and ironing board", "housing"),
    ("mop and cleaning supplies", "housing"),
    ("air purifier purchase", "housing"),
    ("room heater purchase", "housing"),
    ("mosquito net installation", "housing"),
    ("house rent deposit", "housing"),
    ("moving truck rental", "housing"),
    ("security deposit payment", "housing"),
    ("uber ride to work", "transportation"),
    ("ola cab to airport", "transportation"),
    ("oil change service", "transportation"),
    ("tire replacement", "transportation"),
    ("brake pad replacement", "transportation"),
    ("battery replacement", "transportation"),
    ("monthly parking subscription", "transportation"),
    ("toll tax payment", "transportation"),
    ("highway toll booth", "transportation"),
    ("bridge toll charge", "transportation"),
    ("car accessories purchase", "transportation"),
    ("seat covers installation", "transportation"),
    ("car stereo upgrade", "transportation"),
    ("GPS device purchase", "transportation"),
    ("dashcam installation", "transportation"),
    ("bike helmet purchase", "transportation"),
    ("riding jacket and gloves", "transportation"),
    ("puncture repair charges", "transportation"),
    ("towing service payment", "transportation"),
    ("roadside assistance fee", "transportation"),
    ("vehicle inspection fee", "transportation"),
    ("driving license renewal", "transportation"),
    ("learner permit fee", "transportation"),
    ("traffic challan payment", "transportation"),
    ("speeding ticket fine", "transportation"),
    ("parking violation fine", "transportation"),
    ("ride sharing payment", "transportation"),
    ("carpool contribution", "transportation"),
    ("bike rental for weekend", "transportation"),
    ("car rental for road trip", "transportation"),
    ("scooter rental charges", "transportation"),
    ("airport shuttle service", "transportation"),
    ("hotel pickup charges", "transportation"),
    ("intercity bus ticket", "transportation"),
    ("sleeper bus booking", "transportation"),
    ("train sleeper berth", "transportation"),
    ("AC coach upgrade", "transportation"),
    ("baggage handling fee", "transportation"),
    ("extra luggage charge", "transportation"),
    ("seat reservation fee", "transportation"),
    ("priority boarding fee", "transportation"),
    ("lounge access at airport", "transportation"),
    ("travel insurance purchase", "transportation"),
    ("visa processing fee", "transportation"),
    ("passport renewal charges", "transportation"),
    ("mutual fund SIP payment", "investments"),
    ("monthly SIP investment", "investments"),
    ("stock market purchase", "investments"),
    ("equity shares bought", "investments"),
    ("IPO application money", "investments"),
    ("bonds investment", "investments"),
    ("government bonds purchase", "investments"),
    ("fixed deposit renewal", "investments"),
    ("recurring deposit installment", "investments"),
    ("PPF account deposit", "investments"),
    ("NPS contribution", "investments"),
    ("retirement fund contribution", "investments"),
    ("401k contribution", "investments"),
    ("pension plan payment", "investments"),
    ("gold ETF purchase", "investments"),
    ("sovereign gold bond", "investments"),
    ("digital gold investment", "investments"),
    ("cryptocurrency purchase", "investments"),
    ("bitcoin investment", "investments"),
    ("ethereum bought", "investments"),
    ("real estate investment", "investments"),
    ("REIT investment", "investments"),
    ("property down payment", "investments"),
    ("land purchase payment", "investments"),
    ("sukanya samriddhi deposit", "investments"),
    ("ELSS tax saving fund", "investments"),
    ("tax saving investment", "investments"),
    ("long term investment fund", "investments"),
    ("wealth management fees", "investments"),
    ("financial advisor consultation", "investments"),
    ("portfolio management charge", "investments"),
    ("demat account charges", "investments"),
    ("trading account maintenance", "investments"),
    ("brokerage fees paid", "investments"),
    ("investment app subscription", "investments"),
    ("robo advisor fees", "investments"),
    ("cryptocurrency exchange fee", "investments"),
    ("forex trading investment", "investments"),
    ("commodities investment", "investments"),
    ("precious metals purchase", "investments"),
    ("silver coins investment", "investments"),
    ("art investment purchase", "investments"),
    ("wine collection investment", "investments"),
    ("antique purchase", "investments"),
    ("collectibles investment", "investments"),
    ("startup equity investment", "investments"),
    ("angel investing contribution", "investments"),
    ("venture capital investment", "investments"),
    ("peer to peer lending", "investments"),
    ("grocery shopping at supermarket", "food"),
    ("weekly grocery purchase", "food"),
    ("monthly provisions shopping", "food"),
    ("vegetables and fruits", "food"),
    ("fresh produce market", "food"),
    ("meat and chicken purchase", "food"),
    ("fish market shopping", "food"),
    ("dairy products purchase", "food"),
    ("milk and eggs", "food"),
    ("bread and bakery items", "food"),
    ("lunch at cafe", "food"),
    ("breakfast at diner", "food"),
    ("fast food order", "food"),
    ("pizza delivery", "food"),
    ("burger meal", "food"),
    ("chinese food takeout", "food"),
    ("indian restaurant bill", "food"),
    ("italian cuisine dinner", "food"),
    ("sushi restaurant", "food"),
    ("thai food order", "food"),
    ("mexican food delivery", "food"),
    ("food court meal", "food"),
    ("street food purchase", "food"),
    ("food truck order", "food"),
    ("buffet lunch payment", "food"),
    ("fine dining experience", "food"),
    ("coffee shop purchase", "food"),
    ("starbucks order", "food"),
    ("cafe latte and pastry", "food"),
    ("tea and snacks", "food"),
    ("bakery pastries", "food"),
    ("cake purchase", "food"),
    ("dessert shop visit", "food"),
    ("ice cream parlor", "food"),
    ("donut shop", "food"),
    ("smoothie purchase", "food"),
    ("juice bar order", "food"),
    ("meal prep service", "food"),
    ("tiffin service payment", "food"),
    ("lunch box subscription", "food"),
    ("food delivery app order", "food"),
    ("swiggy order", "food"),
    ("zomato food delivery", "food"),
    ("uber eats order", "food"),
    ("online grocery delivery", "food"),
    ("grocery app payment", "food"),
    ("office canteen meal", "food"),
    ("workplace cafeteria", "food"),
    ("vending machine snacks", "food"),
    ("catering service payment", "food"),
    ("party food order", "food"),
    ("birthday cake order", "food"),
    ("wedding catering advance", "food"),
    ("buffet catering", "food"),
    ("sandwich and wrap", "food"),
    ("salad bar purchase", "food"),
    ("healthy meal subscription", "food"),
    ("craft beer purchase", "food"),
    ("meal kit subscription", "food"),
    ("cooking ingredients", "food"),
    ("baking supplies", "food"),
    ("kitchen condiments", "food"),
    ("sauces and dressings", "food"),
    ("herbs and seasonings", "food"),
    ("instant meals", "food"),
    ("ready to eat food", "food"),
    ("electricity bill payment", "utilities"),
    ("power bill for apartment", "utilities"),
    ("electric company charge", "utilities"),
    ("electricity meter reading", "utilities"),
    ("water bill payment", "utilities"),
    ("municipal water charges", "utilities"),
    ("water supply bill", "utilities"),
    ("gas cylinder booking", "utilities"),
    ("LPG gas refill", "utilities"),
    ("piped gas bill", "utilities"),
    ("broadband monthly charges", "utilities"),
    ("wifi service payment", "utilities"),
    ("fiber optic internet", "utilities"),
    ("internet service provider", "utilities"),
    ("mobile phone bill", "utilities"),
    ("postpaid mobile plan", "utilities"),
    ("phone service charges", "utilities"),
    ("mobile recharge", "utilities"),
    ("prepaid mobile topup", "utilities"),
    ("data pack purchase", "utilities"),
    ("roaming charges", "utilities"),
    ("international calling pack", "utilities"),
    ("landline phone bill", "utilities"),
    ("telephone service payment", "utilities"),
    ("cable TV subscription", "utilities"),
    ("DTH recharge", "utilities"),
    ("satellite TV payment", "utilities"),
    ("TV channels package", "utilities"),
    ("streaming service subscription", "utilities"),
    ("spotify subscription", "utilities"),
    ("apple music payment", "utilities"),
    ("cloud storage subscription", "utilities"),
    ("daily newspaper subscription", "utilities"),
    ("magazine subscription", "utilities"),
    ("digital news subscription", "utilities"),
    ("sewage charges", "utilities"),
    ("trash collection fee", "utilities"),
    ("waste management charges", "utilities"),
    ("recycling service fee", "utilities"),
    ("HOA utility charges", "utilities"),
    ("common area electricity", "utilities"),
    ("building maintenance utilities", "utilities"),
    ("generator maintenance fee", "utilities"),
    ("water softener maintenance", "utilities"),
    ("solar panel maintenance", "utilities"),
    ("backup power charges", "utilities"),
    ("antenna installation fee", "utilities"),
    ("router replacement", "utilities"),
    ("modem rental charges", "utilities"),
    ("VPN subscription", "utilities"),
    ("movie ticket purchase", "entertainment"),
    ("cinema hall booking", "entertainment"),
    ("IMAX movie experience", "entertainment"),
    ("theater show tickets", "entertainment"),
    ("play performance tickets", "entertainment"),
    ("concert ticket purchase", "entertainment"),
    ("music festival pass", "entertainment"),
    ("stand-up comedy show", "entertainment"),
    ("sports match tickets", "entertainment"),
    ("cricket match booking", "entertainment"),
    ("F1 race tickets", "entertainment"),
    ("theme park entry", "entertainment"),
    ("amusement park tickets", "entertainment"),
    ("water park visit", "entertainment"),
    ("zoo entry tickets", "entertainment"),
    ("aquarium visit", "entertainment"),
    ("museum entry fee", "entertainment"),
    ("art gallery tickets", "entertainment"),
    ("science center visit", "entertainment"),
    ("planetarium show", "entertainment"),
    ("adventure park booking", "entertainment"),
    ("trampoline park entry", "entertainment"),
    ("bowling alley charges", "entertainment"),
    ("gaming arcade tokens", "entertainment"),
    ("laser tag game", "entertainment"),
    ("gaming console purchase", "entertainment"),
    ("playstation 5 purchase", "entertainment"),
    ("xbox series x", "entertainment"),
    ("nintendo switch", "entertainment"),
    ("video game purchase", "entertainment"),
    ("steam game download", "entertainment"),
    ("playstation plus subscription", "entertainment"),
    ("xbox game pass", "entertainment"),
    ("gaming accessories", "entertainment"),
    ("controller purchase", "entertainment"),
    ("gaming headset", "entertainment"),
    ("VR headset purchase", "entertainment"),
    ("board games purchase", "entertainment"),
    ("card games set", "entertainment"),
    ("puzzle purchase", "entertainment"),
    ("hobby supplies", "entertainment"),
    ("craft materials", "entertainment"),
    ("painting supplies", "entertainment"),
    ("musical instrument purchase", "entertainment"),
    ("guitar purchase", "entertainment"),
    ("keyboard instrument", "entertainment"),
    ("music lessons payment", "entertainment"),
    ("guitar class fees", "entertainment"),
    ("dance class subscription", "entertainment"),
    ("yoga class membership", "entertainment"),
    ("photography course", "entertainment"),
    ("cooking class fees", "entertainment"),
    ("art workshop payment", "entertainment"),
    ("pottery class", "entertainment"),
    ("spa day package", "entertainment"),
    ("massage therapy session", "entertainment"),
    ("salon visit", "entertainment"),
    ("haircut and styling", "entertainment"),
    ("manicure pedicure", "entertainment"),
    ("book purchase", "entertainment"),
    ("novel collection", "entertainment"),
    ("kindle books", "entertainment"),
    ("audiobook subscription", "entertainment"),
    ("doctor consultation fee", "healthcare"),
    ("general physician visit", "healthcare"),
    ("specialist doctor appointment", "healthcare"),
    ("pediatrician consultation", "healthcare"),
    ("dermatologist visit", "healthcare"),
    ("dentist appointment", "healthcare"),
    ("dental cleaning", "healthcare"),
    ("tooth filling charges", "healthcare"),
    ("root canal treatment", "healthcare"),
    ("orthodontic braces", "healthcare"),
    ("eye checkup", "healthcare"),
    ("optometrist consultation", "healthcare"),
    ("eyeglasses purchase", "healthcare"),
    ("contact lenses", "healthcare"),
    ("prescription medicine", "healthcare"),
    ("pharmacy purchase", "healthcare"),
    ("bandages and gauze", "healthcare"),
    ("medical equipment", "healthcare"),
    ("blood pressure monitor", "healthcare"),
    ("glucometer purchase", "healthcare"),
    ("thermometer purchase", "healthcare"),
    ("diagnostic tests", "healthcare"),
    ("blood test charges", "healthcare"),
    ("urine test", "healthcare"),
    ("x-ray charges", "healthcare"),
    ("MRI scan payment", "healthcare"),
    ("CT scan charges", "healthcare"),
    ("ultrasound test", "healthcare"),
    ("ECG test", "healthcare"),
    ("health checkup package", "healthcare"),
    ("full body checkup", "healthcare"),
    ("annual physical exam", "healthcare"),
    ("vaccination charges", "healthcare"),
    ("flu shot", "healthcare"),
    ("covid vaccine", "healthcare"),
    ("operation theater charges", "healthcare"),
    ("hospital room charges", "healthcare"),
    ("ICU charges", "healthcare"),
    ("medical procedure", "healthcare"),
    ("physiotherapy session", "healthcare"),
    ("physical therapy", "healthcare"),
    ("occupational therapy", "healthcare"),
    ("speech therapy", "healthcare"),
    ("chiropractic treatment", "healthcare"),
    ("acupuncture session", "healthcare"),
    ("alternative medicine", "healthcare"),
    ("ayurvedic treatment", "healthcare"),
    ("homeopathy consultation", "healthcare"),
    ("mental health counseling", "healthcare"),
    ("psychiatrist visit", "healthcare"),
    ("psychologist consultation", "healthcare"),
    ("therapy session", "healthcare"),
    ("counseling fees", "healthcare"),
    ("health insurance premium", "healthcare"),
    ("medical insurance payment", "healthcare"),
    ("health insurance renewal", "healthcare"),
    ("insurance co-payment", "healthcare"),
    ("insurance deductible", "healthcare"),
    ("medical claim reimbursement", "healthcare"),
    ("hearing aid purchase", "healthcare"),
    ("wheelchair rental", "healthcare"),
    ("crutches purchase", "healthcare"),
    ("medical mobility aids", "healthcare"),
    ("amazon online shopping", "miscellaneous"),
    ("flipkart purchase", "miscellaneous"),
    ("online shopping delivery", "miscellaneous"),
    ("e-commerce purchase", "miscellaneous"),
    ("clothing purchase", "miscellaneous"),
    ("new shirt and pants", "miscellaneous"),
    ("dress shopping", "miscellaneous"),
    ("shoes purchase", "miscellaneous"),
    ("wallet bought", "miscellaneous"),
    ("belt and accessories", "miscellaneous"),
    ("jewelry purchase", "miscellaneous"),
    ("gold jewelry", "miscellaneous"),
    ("silver ornaments", "miscellaneous"),
    ("fashion jewelry", "miscellaneous"),
    ("cosmetics purchase", "miscellaneous"),
    ("makeup products", "miscellaneous"),
    ("skincare items", "miscellaneous"),
    ("perfume purchase", "miscellaneous"),
    ("hair care products", "miscellaneous"),
    ("personal care items", "miscellaneous"),
    ("shaving supplies", "miscellaneous"),
    ("toiletries purchase", "miscellaneous"),
    ("bath products", "miscellaneous"),
    ("laundry detergent", "miscellaneous"),
    ("cleaning supplies", "miscellaneous"),
    ("household items", "miscellaneous"),
    ("home essentials", "miscellaneous"),
    ("stationery purchase", "miscellaneous"),
    ("phone case and accessories", "miscellaneous"),
    ("charger and cables", "miscellaneous"),
    ("computer accessories", "miscellaneous"),
    ("mouse and keyboard", "miscellaneous"),
    ("webcam purchase", "miscellaneous"),
    ("monitor purchase", "miscellaneous"),
    ("printer purchase", "miscellaneous"),
    ("camera purchase", "miscellaneous"),
    ("photography equipment", "miscellaneous"),
    ("tripod stand", "miscellaneous"),
    ("pet supplies", "miscellaneous"),
    ("postal charges", "miscellaneous"),
]

In [51]:
# build label map

def build_label_map(data):
  """Map every distinct label in ``data`` to an integer index.

  Args:
    data: Iterable of ``(text, label)`` pairs; only the label is read.

  Returns:
    Dict from label to index. Indices follow the sorted order of the
    distinct labels, so the mapping does not depend on the order of ``data``.
  """
  # A set comprehension de-duplicates directly and sorted() accepts any
  # iterable, so no intermediate list()/set() wrappers are needed.
  unique_labels = sorted({label for _, label in data})
  return {label: idx for idx, label in enumerate(unique_labels)}

In [52]:
# build dataset

class TransactionDataset(Dataset):
  """Dataset over ``(text, label)`` pairs that tokenizes one item per access."""

  def __init__(self, data, tokenizer, label_map, max_len):
    # Only references are stored here; tokenization happens in __getitem__.
    self.data, self.tokenizer = data, tokenizer
    self.label_map, self.max_len = label_map, max_len

  def __len__(self):
    return len(self.data)

  def __getitem__(self, idx):
    """Return flattened input ids, attention mask and label id for ``idx``."""
    text, label = self.data[idx]

    tokenizer_options = dict(
        max_length=self.max_len,
        padding='max_length',
        truncation=True,
        add_special_tokens=True,
        return_tensors='pt',
    )
    encoded = self.tokenizer(text, **tokenizer_options)

    # return_tensors='pt' presumably adds a leading batch axis of size 1;
    # flattening leaves 1-D tensors for the DataLoader to stack.
    item = {
        key: encoded[key].flatten()
        for key in ('input_ids', 'attention_mask')
    }
    item['label'] = torch.tensor(self.label_map[label], dtype=torch.long)
    return item

In [53]:
# build model architecture with DistilBERT

class TransactionClassifier(nn.Module):
  """DistilBERT encoder followed by dropout and a linear classification head.

  Args:
    num_classes: Number of output categories (width of the returned logits).
    pretrained_name: Checkpoint name passed to ``DistilBertModel.from_pretrained``.
    dropout: Dropout probability applied before the classification head.
  """

  def __init__(self, num_classes, pretrained_name='distilbert-base-uncased', dropout=0.3):
    super().__init__()
    self.distilbert = DistilBertModel.from_pretrained(pretrained_name)
    self.dropout = nn.Dropout(dropout)
    # Size the head from the loaded checkpoint's config instead of hard-coding
    # 768, so a checkpoint with a different hidden width still fits.
    hidden_size = self.distilbert.config.dim
    self.classifier = nn.Linear(hidden_size, num_classes)

  def forward(self, input_ids, attention_mask):
    """Return unnormalized class scores for a batch of tokenized texts."""
    outputs = self.distilbert(
        input_ids=input_ids,
        attention_mask=attention_mask
    )

    # Use the hidden state at sequence position 0 as the sentence summary
    # (presumably the [CLS] token, given the tokenizer adds special tokens).
    pooled_output = outputs.last_hidden_state[:, 0, :]
    pooled_output = self.dropout(pooled_output)
    logits = self.classifier(pooled_output)

    return logits

In [54]:
# evaluation function

def evaluate(model, data_loader, criterion, device):
  """Compute the sample-weighted mean loss and the accuracy of ``model``.

  The model is switched to eval mode (and left in it on return) and every
  batch is run under ``torch.no_grad()``.

  Args:
    model: Module called as ``model(input_ids, attention_mask)`` that returns
      per-class scores with the classes along dim 1.
    data_loader: Iterable of batches, each a mapping with ``'input_ids'``,
      ``'attention_mask'`` and ``'label'`` tensors.
    criterion: Loss callable invoked as ``criterion(outputs, labels)``;
      assumed to return the mean loss over the batch.
    device: Device the batch tensors are moved to before the forward pass.

  Returns:
    Tuple ``(average_loss, accuracy)`` with accuracy expressed in percent.
    Both values are NaN when ``data_loader`` yields no samples.
  """
  model.eval()
  total_samples = 0
  total_loss = 0.0
  correct_predictions = 0

  with torch.no_grad():
    for batch in data_loader:
      input_ids = batch['input_ids'].to(device)
      attention_mask = batch['attention_mask'].to(device)
      labels = batch['label'].to(device)
      batch_size = labels.size(0)

      outputs = model(input_ids, attention_mask)
      loss = criterion(outputs, labels)

      # Weight the batch loss by its size so a smaller final batch does not
      # skew the overall mean.
      total_loss += loss.item() * batch_size

      predicted = outputs.argmax(dim=1)
      correct_predictions += (predicted == labels).sum().item()
      total_samples += batch_size

  # An empty loader (e.g. a zero-sized split) has no defined metrics; report
  # NaN instead of failing with ZeroDivisionError.
  if total_samples == 0:
    return float('nan'), float('nan')

  average_loss = total_loss / total_samples
  accuracy = (correct_predictions / total_samples) * 100

  return average_loss, accuracy

In [55]:
# execution

# set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# initialize tokenizer and build label map
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
label_map = build_label_map(DATA)
NUM_CLASSES = len(label_map)

# prepare training and test datasets with their data loaders
full_dataset = TransactionDataset(DATA, tokenizer, label_map, MAX_LENGTH)
train_size = int((1 - TEST_SPLIT_RATIO) * len(full_dataset))
test_size = len(full_dataset) - train_size

# seed the split so the same examples land in train/test on every run;
# otherwise the reported test metrics are not comparable between runs
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, test_dataset = random_split(
    full_dataset, [train_size, test_size], generator = split_generator)

train_loader = DataLoader(train_dataset, batch_size = BATCH_SIZE, shuffle = True)
test_loader = DataLoader(test_dataset, batch_size = BATCH_SIZE, shuffle = False)

# initialize the model and define loss function, optimizer
model = TransactionClassifier(NUM_CLASSES).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr = LEARNING_RATE)

print(f"Total Examples: {len(full_dataset)} | Train: {train_size} | Test: {test_size}")
print(f"Number of Classes: {NUM_CLASSES}")
print(f"Starting Training ({NUM_EPOCHS} epochs)")

# training loop with forward pass, backward pass and evaluation
start_time = time()
for epoch in range(NUM_EPOCHS):
  # evaluate() leaves the model in eval mode, so re-enable training mode
  # (dropout active) at the start of every epoch
  model.train()

  for batch in train_loader:
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    labels = batch['label'].to(device)

    outputs = model(input_ids, attention_mask)
    loss = criterion(outputs, labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

  # metrics are measured after the epoch's updates, in eval mode, on both
  # splits; the running in-loop loss was accumulated but never reported, so
  # that accumulator has been dropped
  training_loss, training_accuracy = evaluate(model, train_loader, criterion, device)
  testing_loss, testing_accuracy = evaluate(model, test_loader, criterion, device)

  print(f"Epoch {epoch+1:2d}/{NUM_EPOCHS} | "
        f"Train Loss: {training_loss:.4f} | Train Acc: {training_accuracy:.2f}% | "
        f"Test Loss: {testing_loss:.4f} | Test Acc: {testing_accuracy:.2f}%")

end_time = time()
print(f"Training completed in {((end_time - start_time)/60):.2f} minutes")

Using device: cpu
Total Examples: 427 | Train: 341 | Test: 86
Number of Classes: 8
Starting Training (10 epochs)
Epoch  1/10 | Train Loss: 1.2187 | Train Acc: 81.52% | Test Loss: 1.3634 | Test Acc: 72.09%
Epoch  2/10 | Train Loss: 0.4124 | Train Acc: 93.55% | Test Loss: 0.6990 | Test Acc: 81.40%
Epoch  3/10 | Train Loss: 0.1411 | Train Acc: 97.95% | Test Loss: 0.5207 | Test Acc: 82.56%
Epoch  4/10 | Train Loss: 0.0576 | Train Acc: 99.41% | Test Loss: 0.5565 | Test Acc: 84.88%
Epoch  5/10 | Train Loss: 0.0268 | Train Acc: 100.00% | Test Loss: 0.5835 | Test Acc: 84.88%
Epoch  6/10 | Train Loss: 0.0160 | Train Acc: 100.00% | Test Loss: 0.6210 | Test Acc: 84.88%
Epoch  7/10 | Train Loss: 0.0101 | Train Acc: 100.00% | Test Loss: 0.6221 | Test Acc: 84.88%
Epoch  8/10 | Train Loss: 0.0103 | Train Acc: 100.00% | Test Loss: 0.6507 | Test Acc: 84.88%
Epoch  9/10 | Train Loss: 0.0062 | Train Acc: 100.00% | Test Loss: 0.6808 | Test Acc: 84.88%
Epoch 10/10 | Train Loss: 0.0072 | Train Acc: 100.00% 

In [57]:
# Test with sample predictions
print("Sample Predictions:")

# inference only: put the model in eval mode before predicting
model.eval()
test_samples = [
    "bought milk and eggs from store",
    "uber to airport",
    "netflix monthly subscription",
    "paid rent for this month",
    "had a gallbladder surgery",
    "paid the electricity bills"
]

# class index -> label name (the reverse direction of label_map)
inverse_label_map = dict(zip(label_map.values(), label_map.keys()))

# tokenizer settings shared by every sample below
encode_options = dict(
    add_special_tokens=True,
    max_length=MAX_LENGTH,
    padding='max_length',
    truncation=True,
    return_tensors='pt',
)

with torch.no_grad():
    for sample in test_samples:
        encoding = tokenizer(sample, **encode_options)

        input_ids = encoding['input_ids'].to(device)
        attention_mask = encoding['attention_mask'].to(device)

        outputs = model(input_ids, attention_mask)
        # index of the highest-scoring class for this single-sample batch
        predicted = outputs.argmax(dim=1)

        predicted_label = inverse_label_map[predicted.item()]
        print(f"Text: '{sample}'")
        print(f"Predicted: {predicted_label}\n")

Sample Predictions:
Text: 'bought milk and eggs from store'
Predicted: food

Text: 'uber to airport'
Predicted: transportation

Text: 'netflix monthly subscription'
Predicted: utilities

Text: 'paid rent for this month'
Predicted: housing

Text: 'had a gallbladder surgery'
Predicted: healthcare

Text: 'paid the electricity bills'
Predicted: utilities

