In [16]:
# Install the required packages if they are not already installed:
# pip install transformers datasets scikit-learn torch

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
import numpy as np
from sklearn.metrics import accuracy_score, f1_score

# 1. Load the sentiment/emotion dataset
# The returned object is indexed by split name further down ("train", "validation", "test").
dataset = load_dataset("emotion")  # 6 emotions: sadness, joy, love, anger, fear, surprise
# Alternative: dataset = load_dataset("go_emotions", "simplified"), mapping its 27 emotions to 4 sentiments here.
# NOTE(review): switching datasets would presumably also require changing the label count used when loading the model.

# 2. Pre-process: Tokenize the text
checkpoint = "distilbert-base-uncased"  # Using DistilBERT for efficiency
# The tokenizer is loaded from the same checkpoint name that is later used for the model.
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

def tokenize_batch(batch):
    """Tokenize the "text" entries of ``batch`` with the module-level tokenizer.

    Truncation is enabled and every sequence is padded to ``max_length=64``,
    so all encoded examples come out with the same length.

    Args:
        batch: Mapping with a "text" key holding the raw input text(s).

    Returns:
        The object returned by the tokenizer call.
    """
    encode_options = {"truncation": True, "padding": "max_length", "max_length": 64}
    return tokenizer(batch["text"], **encode_options)

# Tokenize the dataset (this adds 'input_ids' and 'attention_mask' columns)
# batched=True passes groups of examples to tokenize_batch rather than one example per call.
tokenized_dataset = dataset.map(tokenize_batch, batched=True)

# 3. Set format for PyTorch
# Expose only the columns needed for training, requested in "torch" format.
tokenized_dataset.set_format("torch", columns=["input_ids", "attention_mask", "label"])

# 4. Load the pre-trained model for sequence classification
# Take the class names from the dataset's label feature instead of hard-coding
# the count, and store the index <-> name mapping in the model config so that a
# saved checkpoint carries its own label names (otherwise they default to
# generic "LABEL_<n>" ids).
class_names = dataset["train"].features["label"].names
model = AutoModelForSequenceClassification.from_pretrained(
    checkpoint,
    num_labels=len(class_names),
    id2label=dict(enumerate(class_names)),
    label2id={name: idx for idx, name in enumerate(class_names)},
)

# 5. Define training hyperparameters and Trainer
batch_size = 16  # used for both the train and eval per-device batch size below
training_args = TrainingArguments(
    output_dir="sentiment_model",
    # Evaluate and save once per epoch; the two strategies are kept identical so that
    # load_best_model_at_end can pick among the saved checkpoints.
    # NOTE(review): newer transformers releases reportedly rename this argument to
    # `eval_strategy` - confirm against the installed version.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=4,
    weight_decay=0.01,
    # Select the best checkpoint by the "accuracy" value returned from compute_metrics
    # (higher is better).
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    greater_is_better=True,
    logging_steps=50,  # emit a training-loss log entry every 50 steps
    log_level="error"  # only show errors to avoid too much log output
)

# Metric function for evaluation
def compute_metrics(eval_pred):
    """Score a set of predictions with accuracy and macro-averaged F1.

    Args:
        eval_pred: A pair that unpacks into ``(logits, labels)``.

    Returns:
        Dict with the keys "accuracy" and "macro_f1".
    """
    scores, references = eval_pred
    # The predicted class is the index of the largest score along the last axis.
    predicted = np.argmax(scores, axis=-1)
    return {
        "accuracy": accuracy_score(references, predicted),
        "macro_f1": f1_score(references, predicted, average="macro"),
    }

# Wire the model, hyperparameters, data splits and metric function together.
# The validation split is the per-epoch evaluation set; the test split is held
# back until the final evaluation below.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    compute_metrics=compute_metrics
)

# 6. Train the model
trainer.train()

# 7. Evaluate on the test set
# The metric names returned by compute_metrics are read back with an "eval_" prefix.
results = trainer.evaluate(tokenized_dataset["test"])
print(f"Test accuracy: {results['eval_accuracy']:.4f}")
print(f"Test Macro-F1: {results['eval_macro_f1']:.4f}")


To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Generating train split: 100%|██████████| 16000/16000 [00:00<00:00, 482093.52 examples/s]
Generating validation split: 100%|██████████| 2000/2000 [00:00<00:00, 979977.57 examples/s]
Generating test split: 100%|██████████| 2000/2000 [00:00<00:00, 795505.74 examples/s]
Map: 100%|██████████| 16000/16000 [00:00<00:00, 51004.69 examples/s]
Map: 100%|██████████| 2000/2000 [00:00<00:00, 41022.50 examples/s]
Map: 100%|██████████| 2000/2000 [00:00<00:00, 35934.44 examples/s]
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'classifier.weight', 'classifier.bias', 'pre_classifier.bias']
You should probably TRAIN t

{'loss': 1.5308, 'learning_rate': 1.9750000000000002e-05, 'epoch': 0.05}


  3%|▎         | 102/4000 [00:03<01:55, 33.82it/s]

{'loss': 1.2819, 'learning_rate': 1.951e-05, 'epoch': 0.1}


  4%|▍         | 154/4000 [00:04<01:49, 35.17it/s]

{'loss': 1.0329, 'learning_rate': 1.9260000000000002e-05, 'epoch': 0.15}


  5%|▌         | 206/4000 [00:06<01:46, 35.66it/s]

{'loss': 0.8524, 'learning_rate': 1.9010000000000003e-05, 'epoch': 0.2}


  6%|▋         | 254/4000 [00:07<01:43, 36.02it/s]

{'loss': 0.6901, 'learning_rate': 1.876e-05, 'epoch': 0.25}


  8%|▊         | 306/4000 [00:09<01:43, 35.58it/s]

{'loss': 0.6016, 'learning_rate': 1.851e-05, 'epoch': 0.3}


  9%|▉         | 354/4000 [00:10<01:38, 37.03it/s]

{'loss': 0.4353, 'learning_rate': 1.826e-05, 'epoch': 0.35}


 10%|█         | 406/4000 [00:11<01:38, 36.37it/s]

{'loss': 0.3975, 'learning_rate': 1.8010000000000002e-05, 'epoch': 0.4}


 11%|█▏        | 454/4000 [00:13<01:38, 35.94it/s]

{'loss': 0.3546, 'learning_rate': 1.7760000000000003e-05, 'epoch': 0.45}


 13%|█▎        | 506/4000 [00:14<01:37, 35.86it/s]

{'loss': 0.3322, 'learning_rate': 1.751e-05, 'epoch': 0.5}


 14%|█▍        | 554/4000 [00:15<01:33, 36.85it/s]

{'loss': 0.2577, 'learning_rate': 1.726e-05, 'epoch': 0.55}


 15%|█▌        | 606/4000 [00:17<01:33, 36.11it/s]

{'loss': 0.2703, 'learning_rate': 1.701e-05, 'epoch': 0.6}


 16%|█▋        | 654/4000 [00:18<01:34, 35.57it/s]

{'loss': 0.2768, 'learning_rate': 1.6760000000000002e-05, 'epoch': 0.65}


 18%|█▊        | 706/4000 [00:20<01:28, 37.02it/s]

{'loss': 0.2698, 'learning_rate': 1.6510000000000003e-05, 'epoch': 0.7}


 19%|█▉        | 754/4000 [00:21<01:29, 36.43it/s]

{'loss': 0.2305, 'learning_rate': 1.626e-05, 'epoch': 0.75}


 20%|██        | 806/4000 [00:22<01:31, 34.81it/s]

{'loss': 0.2111, 'learning_rate': 1.601e-05, 'epoch': 0.8}


 21%|██▏       | 854/4000 [00:24<01:26, 36.52it/s]

{'loss': 0.2222, 'learning_rate': 1.576e-05, 'epoch': 0.85}


 23%|██▎       | 906/4000 [00:25<01:26, 35.77it/s]

{'loss': 0.2121, 'learning_rate': 1.5510000000000002e-05, 'epoch': 0.9}


 24%|██▍       | 954/4000 [00:26<01:29, 34.12it/s]

{'loss': 0.2148, 'learning_rate': 1.5260000000000003e-05, 'epoch': 0.95}


 25%|██▌       | 1000/4000 [00:28<01:24, 35.66it/s]

{'loss': 0.2458, 'learning_rate': 1.501e-05, 'epoch': 1.0}



 25%|██▌       | 1000/4000 [00:29<01:24, 35.66it/s]

{'eval_loss': 0.1959214210510254, 'eval_accuracy': 0.9285, 'eval_macro_f1': 0.8955428794928096, 'eval_runtime': 0.771, 'eval_samples_per_second': 2593.905, 'eval_steps_per_second': 162.119, 'epoch': 1.0}


 26%|██▋       | 1054/4000 [00:31<01:30, 32.57it/s]

{'loss': 0.1654, 'learning_rate': 1.4760000000000001e-05, 'epoch': 1.05}


 28%|██▊       | 1106/4000 [00:32<01:21, 35.38it/s]

{'loss': 0.1536, 'learning_rate': 1.4510000000000002e-05, 'epoch': 1.1}


 29%|██▉       | 1154/4000 [00:34<01:19, 35.63it/s]

{'loss': 0.1368, 'learning_rate': 1.426e-05, 'epoch': 1.15}


 30%|███       | 1206/4000 [00:35<01:22, 33.87it/s]

{'loss': 0.1559, 'learning_rate': 1.4010000000000001e-05, 'epoch': 1.2}


 31%|███▏      | 1254/4000 [00:37<01:16, 36.02it/s]

{'loss': 0.1886, 'learning_rate': 1.376e-05, 'epoch': 1.25}


 33%|███▎      | 1306/4000 [00:38<01:16, 35.15it/s]

{'loss': 0.1436, 'learning_rate': 1.3510000000000001e-05, 'epoch': 1.3}


 34%|███▍      | 1354/4000 [00:39<01:15, 35.15it/s]

{'loss': 0.1209, 'learning_rate': 1.3260000000000002e-05, 'epoch': 1.35}


 35%|███▌      | 1406/4000 [00:41<01:14, 34.77it/s]

{'loss': 0.2047, 'learning_rate': 1.301e-05, 'epoch': 1.4}


 36%|███▋      | 1454/4000 [00:42<01:11, 35.71it/s]

{'loss': 0.1231, 'learning_rate': 1.2760000000000001e-05, 'epoch': 1.45}


 38%|███▊      | 1502/4000 [00:44<01:14, 33.56it/s]

{'loss': 0.1946, 'learning_rate': 1.251e-05, 'epoch': 1.5}


 39%|███▉      | 1554/4000 [00:45<01:11, 34.33it/s]

{'loss': 0.1192, 'learning_rate': 1.2260000000000001e-05, 'epoch': 1.55}


 40%|████      | 1606/4000 [00:47<01:07, 35.34it/s]

{'loss': 0.127, 'learning_rate': 1.2010000000000002e-05, 'epoch': 1.6}


 41%|████▏     | 1654/4000 [00:48<01:09, 33.87it/s]

{'loss': 0.2042, 'learning_rate': 1.1760000000000001e-05, 'epoch': 1.65}


 43%|████▎     | 1706/4000 [00:50<01:07, 34.09it/s]

{'loss': 0.1618, 'learning_rate': 1.1510000000000002e-05, 'epoch': 1.7}


 44%|████▍     | 1754/4000 [00:51<01:03, 35.44it/s]

{'loss': 0.15, 'learning_rate': 1.126e-05, 'epoch': 1.75}


 45%|████▌     | 1806/4000 [00:52<01:03, 34.32it/s]

{'loss': 0.1512, 'learning_rate': 1.1010000000000001e-05, 'epoch': 1.8}


 46%|████▋     | 1854/4000 [00:54<01:04, 33.42it/s]

{'loss': 0.132, 'learning_rate': 1.0760000000000002e-05, 'epoch': 1.85}


 48%|████▊     | 1906/4000 [00:55<01:01, 33.94it/s]

{'loss': 0.1535, 'learning_rate': 1.0510000000000001e-05, 'epoch': 1.9}


 49%|████▉     | 1954/4000 [00:57<00:59, 34.55it/s]

{'loss': 0.1945, 'learning_rate': 1.0260000000000002e-05, 'epoch': 1.95}


 50%|█████     | 2000/4000 [00:58<00:58, 34.19it/s]

{'loss': 0.1368, 'learning_rate': 1.0009999999999999e-05, 'epoch': 2.0}



 50%|█████     | 2000/4000 [00:59<00:58, 34.19it/s]

{'eval_loss': 0.18374912440776825, 'eval_accuracy': 0.9355, 'eval_macro_f1': 0.9138653930010278, 'eval_runtime': 0.8394, 'eval_samples_per_second': 2382.543, 'eval_steps_per_second': 148.909, 'epoch': 2.0}


 51%|█████▏    | 2054/4000 [01:01<00:59, 32.85it/s]

{'loss': 0.1154, 'learning_rate': 9.760000000000001e-06, 'epoch': 2.05}


 53%|█████▎    | 2106/4000 [01:03<00:55, 34.14it/s]

{'loss': 0.1388, 'learning_rate': 9.51e-06, 'epoch': 2.1}


 54%|█████▍    | 2154/4000 [01:04<00:54, 33.85it/s]

{'loss': 0.1077, 'learning_rate': 9.265e-06, 'epoch': 2.15}


 55%|█████▌    | 2202/4000 [01:06<00:52, 34.07it/s]

{'loss': 0.1214, 'learning_rate': 9.015000000000001e-06, 'epoch': 2.2}


 56%|█████▋    | 2254/4000 [01:07<00:51, 34.00it/s]

{'loss': 0.095, 'learning_rate': 8.765e-06, 'epoch': 2.25}


 58%|█████▊    | 2306/4000 [01:09<00:48, 34.63it/s]

{'loss': 0.1039, 'learning_rate': 8.515e-06, 'epoch': 2.3}


 59%|█████▉    | 2354/4000 [01:10<00:48, 34.28it/s]

{'loss': 0.0791, 'learning_rate': 8.265000000000001e-06, 'epoch': 2.35}


 60%|██████    | 2403/4000 [01:11<00:45, 35.30it/s]

{'loss': 0.1095, 'learning_rate': 8.015e-06, 'epoch': 2.4}


 61%|██████▏   | 2455/4000 [01:13<00:45, 33.70it/s]

{'loss': 0.0715, 'learning_rate': 7.765000000000001e-06, 'epoch': 2.45}


 63%|██████▎   | 2507/4000 [01:14<00:42, 35.19it/s]

{'loss': 0.0952, 'learning_rate': 7.515e-06, 'epoch': 2.5}


 64%|██████▍   | 2555/4000 [01:16<00:41, 34.92it/s]

{'loss': 0.0852, 'learning_rate': 7.265000000000001e-06, 'epoch': 2.55}


 65%|██████▌   | 2603/4000 [01:17<00:41, 33.66it/s]

{'loss': 0.1155, 'learning_rate': 7.015000000000001e-06, 'epoch': 2.6}


 66%|██████▋   | 2655/4000 [01:19<00:40, 33.45it/s]

{'loss': 0.1048, 'learning_rate': 6.7650000000000005e-06, 'epoch': 2.65}


 68%|██████▊   | 2703/4000 [01:20<00:40, 32.30it/s]

{'loss': 0.0922, 'learning_rate': 6.515e-06, 'epoch': 2.7}


 69%|██████▉   | 2755/4000 [01:22<00:36, 33.81it/s]

{'loss': 0.1321, 'learning_rate': 6.265e-06, 'epoch': 2.75}


 70%|███████   | 2803/4000 [01:23<00:35, 33.63it/s]

{'loss': 0.1096, 'learning_rate': 6.015000000000001e-06, 'epoch': 2.8}


 71%|███████▏  | 2855/4000 [01:25<00:32, 35.05it/s]

{'loss': 0.0954, 'learning_rate': 5.765000000000001e-06, 'epoch': 2.85}


 73%|███████▎  | 2903/4000 [01:26<00:32, 33.60it/s]

{'loss': 0.1005, 'learning_rate': 5.5150000000000006e-06, 'epoch': 2.9}


 74%|███████▍  | 2955/4000 [01:27<00:30, 34.66it/s]

{'loss': 0.1435, 'learning_rate': 5.265e-06, 'epoch': 2.95}


 75%|███████▌  | 3000/4000 [01:29<00:28, 35.71it/s]

{'loss': 0.1195, 'learning_rate': 5.015e-06, 'epoch': 3.0}



 75%|███████▌  | 3000/4000 [01:30<00:28, 35.71it/s]

{'eval_loss': 0.18004994094371796, 'eval_accuracy': 0.934, 'eval_macro_f1': 0.907203587982894, 'eval_runtime': 0.7825, 'eval_samples_per_second': 2555.85, 'eval_steps_per_second': 159.741, 'epoch': 3.0}


 76%|███████▋  | 3055/4000 [01:32<00:27, 34.18it/s]

{'loss': 0.1292, 'learning_rate': 4.765e-06, 'epoch': 3.05}


 78%|███████▊  | 3103/4000 [01:33<00:26, 33.77it/s]

{'loss': 0.0776, 'learning_rate': 4.515000000000001e-06, 'epoch': 3.1}


 79%|███████▉  | 3155/4000 [01:35<00:24, 34.38it/s]

{'loss': 0.0867, 'learning_rate': 4.265000000000001e-06, 'epoch': 3.15}


 80%|████████  | 3203/4000 [01:36<00:23, 34.55it/s]

{'loss': 0.0606, 'learning_rate': 4.0150000000000005e-06, 'epoch': 3.2}


 81%|████████▏ | 3255/4000 [01:38<00:22, 33.79it/s]

{'loss': 0.0612, 'learning_rate': 3.7650000000000004e-06, 'epoch': 3.25}


 83%|████████▎ | 3303/4000 [01:39<00:20, 33.59it/s]

{'loss': 0.082, 'learning_rate': 3.5150000000000002e-06, 'epoch': 3.3}


 84%|████████▍ | 3355/4000 [01:41<00:17, 35.96it/s]

{'loss': 0.0645, 'learning_rate': 3.2650000000000005e-06, 'epoch': 3.35}


 85%|████████▌ | 3403/4000 [01:42<00:16, 36.47it/s]

{'loss': 0.0624, 'learning_rate': 3.0150000000000004e-06, 'epoch': 3.4}


 86%|████████▋ | 3455/4000 [01:43<00:14, 36.96it/s]

{'loss': 0.0724, 'learning_rate': 2.7650000000000006e-06, 'epoch': 3.45}


 88%|████████▊ | 3503/4000 [01:45<00:14, 35.28it/s]

{'loss': 0.0823, 'learning_rate': 2.515e-06, 'epoch': 3.5}


 89%|████████▉ | 3555/4000 [01:46<00:13, 34.18it/s]

{'loss': 0.0794, 'learning_rate': 2.2650000000000003e-06, 'epoch': 3.55}


 90%|█████████ | 3603/4000 [01:47<00:11, 35.66it/s]

{'loss': 0.1033, 'learning_rate': 2.015e-06, 'epoch': 3.6}


 91%|█████████▏| 3655/4000 [01:49<00:09, 35.91it/s]

{'loss': 0.0904, 'learning_rate': 1.765e-06, 'epoch': 3.65}


 93%|█████████▎| 3703/4000 [01:50<00:08, 35.75it/s]

{'loss': 0.0548, 'learning_rate': 1.5150000000000001e-06, 'epoch': 3.7}


 94%|█████████▍| 3755/4000 [01:52<00:06, 36.62it/s]

{'loss': 0.0778, 'learning_rate': 1.2650000000000002e-06, 'epoch': 3.75}


 95%|█████████▌| 3803/4000 [01:53<00:05, 34.14it/s]

{'loss': 0.0682, 'learning_rate': 1.0150000000000002e-06, 'epoch': 3.8}


 96%|█████████▋| 3855/4000 [01:55<00:04, 36.17it/s]

{'loss': 0.0957, 'learning_rate': 7.650000000000001e-07, 'epoch': 3.85}


 98%|█████████▊| 3903/4000 [01:56<00:02, 34.66it/s]

{'loss': 0.0821, 'learning_rate': 5.15e-07, 'epoch': 3.9}


 99%|█████████▉| 3955/4000 [01:57<00:01, 35.82it/s]

{'loss': 0.0712, 'learning_rate': 2.65e-07, 'epoch': 3.95}


100%|██████████| 4000/4000 [01:59<00:00, 36.74it/s]

{'loss': 0.0917, 'learning_rate': 1.5000000000000002e-08, 'epoch': 4.0}



100%|██████████| 4000/4000 [01:59<00:00, 36.74it/s]

{'eval_loss': 0.18607589602470398, 'eval_accuracy': 0.937, 'eval_macro_f1': 0.9125431369879897, 'eval_runtime': 0.7278, 'eval_samples_per_second': 2748.124, 'eval_steps_per_second': 171.758, 'epoch': 4.0}


100%|██████████| 4000/4000 [02:00<00:00, 33.11it/s]


{'train_runtime': 120.8143, 'train_samples_per_second': 529.739, 'train_steps_per_second': 33.109, 'train_loss': 0.2095928320288658, 'epoch': 4.0}


100%|██████████| 125/125 [00:00<00:00, 164.23it/s]

Test accuracy: 0.9235
Test Macro-F1: 0.8740





In [19]:
# Write the trained model and its tokenizer into the same directory so that
# both can later be reloaded from one path.
trainer.save_model("final_model_directory")
tokenizer.save_pretrained("final_model_directory")


('final_model_directory\\tokenizer_config.json',
 'final_model_directory\\special_tokens_map.json',
 'final_model_directory\\vocab.txt',
 'final_model_directory\\added_tokens.json',
 'final_model_directory\\tokenizer.json')

In [22]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import numpy as np

# Load the saved model and tokenizer
# Note: this rebinds the module-level `model` and `tokenizer` names to the
# versions read back from disk.
model_dir = "final_model_directory"  # or use the output_dir you used during training
model = AutoModelForSequenceClassification.from_pretrained(model_dir)
tokenizer = AutoTokenizer.from_pretrained(model_dir)

# Function to predict sentiment for an individual prompt
def predict_sentiment(text):
    """Classify ``text`` with the module-level ``model`` and ``tokenizer``.

    Args:
        text: The input text to classify.

    Returns:
        A tuple ``(label, probabilities)``: the name of the class with the
        highest probability, and a NumPy array holding the softmax
        probabilities of the first row of the model output.
    """
    # Class names by output index (assumed to follow the training label order; update if needed).
    emotions = ["sadness", "joy", "love", "anger", "fear", "surprise"]

    # Encode with the same truncation length that was used for training.
    encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=64)

    # Inference only, so gradient tracking is switched off for the forward pass.
    with torch.no_grad():
        logits = model(**encoded).logits

    # Normalise the logits into per-class probabilities and move them to NumPy.
    class_probs = torch.softmax(logits, dim=-1).numpy()
    best_idx = int(np.argmax(class_probs, axis=-1)[0])
    return emotions[best_idx], class_probs[0]

# Test on individual prompts
# Hand-written sentences; most name an emotion fairly directly.
text_examples = [
    "I feel so happy and excited about life!",
    "I'm really upset and angry with what happened.",
    "I love spending time with my family.",
    "I am terrified of the dark.",
    "The news left me completely shocked.",
    "I'm feeling a bit down today.",
    "I want a burger."  # no explicit emotion word - presumably a neutral probe
]

# Print the predicted label and the full probability vector for each example.
for text in text_examples:
    label, prob = predict_sentiment(text)
    print(f"Text: {text}\nPredicted Sentiment: {label}\nProbabilities: {prob}\n")


Text: I feel so happy and excited about life!
Predicted Sentiment: joy
Probabilities: [6.6874069e-05 9.9965417e-01 1.0660033e-04 5.1774790e-05 2.8712684e-05
 9.1870250e-05]

Text: I'm really upset and angry with what happened.
Predicted Sentiment: anger
Probabilities: [5.12196450e-03 3.49217822e-04 1.15422496e-04 9.91796374e-01
 2.46886816e-03 1.48101841e-04]

Text: I love spending time with my family.
Predicted Sentiment: joy
Probabilities: [0.05378823 0.5414191  0.07344555 0.30905557 0.01956067 0.00273079]

Text: I am terrified of the dark.
Predicted Sentiment: fear
Probabilities: [1.7986086e-04 1.3891693e-04 7.8151359e-05 4.0574523e-04 9.9888700e-01
 3.1029904e-04]

Text: The news left me completely shocked.
Predicted Sentiment: surprise
Probabilities: [4.7630229e-04 1.1978641e-03 3.7912233e-04 1.3507752e-03 2.1683846e-03
 9.9442756e-01]

Text: I'm feeling a bit down today.
Predicted Sentiment: sadness
Probabilities: [9.9840039e-01 4.4780798e-04 9.9361489e-05 5.6936336e-04 3.8814213