In [6]:
# Import necessary libraries
import logging
import numpy as np
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
    AutoConfig,
    DataCollatorWithPadding,
)
from datasets import load_dataset
from evaluate import load  # Updated for metric loading
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import precision_recall_fscore_support

In [7]:
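# Install the third-party dependencies once (before running the import cell), then restart the kernel.
# %pip install datasets transformers scikit-learn
# %pip install evaluate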

In [9]:
# Create this notebook's logger and configure root logging at INFO level
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

In [15]:
# Configuration variables
# Everything a reader might tune lives here; later cells read these names directly.
model_name_or_path = 'cardiffnlp/twitter-roberta-base-sentiment'  # Checkpoint used for tokenizer, config and model
train_file = 'data/training_data/train.csv'  # Local path to train CSV file
# NOTE(review): this is the same path as train_file, so the "validation" split is the
# training data and the reported eval metrics are not held-out — confirm whether a
# separate validation CSV was intended.
validation_file = 'data/training_data/train.csv'  # Local path to validation CSV file
output_dir = 'data/output_model'  # Local output directory for model artifacts
max_seq_length = 128  # Passed to the tokenizer as max_length (truncation is enabled there)
do_train = True  # Gates the training cell and whether the Trainer receives a train dataset
do_eval = True  # Gates the evaluation cell and the per-epoch evaluation strategy
pad_to_max_length = True  # True -> tokenizer pads every example to max_seq_length; False -> no padding at tokenization time

In [16]:
# Read the local CSV files into a single dataset object keyed by split name
data_files = dict(train=train_file, validation=validation_file)
raw_datasets = load_dataset("csv", data_files=data_files)

Generating train split: 4682 examples [00:00, 102014.72 examples/s]
Generating validation split: 4682 examples [00:00, 260329.97 examples/s]


In [17]:
# Build a deterministic label -> index mapping from the labels present in the train split
label_list = sorted(raw_datasets['train'].unique('label'))
num_labels = len(label_list)
label_map = dict(zip(label_list, range(num_labels)))

In [18]:
# Re-encode every example's raw label as its integer index from label_map
def _label_to_index(example):
    """Return the replacement ``label`` column value for a single example."""
    return {'label': label_map[example['label']]}


raw_datasets = raw_datasets.map(_label_to_index)


Map: 100%|██████████| 4682/4682 [00:00<00:00, 29919.69 examples/s]
Map: 100%|██████████| 4682/4682 [00:00<00:00, 27647.99 examples/s]


In [19]:
# Derive "balanced" per-class weights from the train labels to counter class imbalance
labels = raw_datasets['train']['label']
observed_classes = np.unique(labels)
balanced_weights = compute_class_weight(class_weight="balanced", classes=observed_classes, y=labels)
# Stored as a float tensor so it can be handed to the loss function as its `weight`
class_weights = torch.tensor(balanced_weights, dtype=torch.float)

In [21]:
# Load config, tokenizer and model from the same checkpoint
# The config carries this task's label count rather than the checkpoint's own.
config = AutoConfig.from_pretrained(model_name_or_path, num_labels=num_labels)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name_or_path,
    ignore_mismatched_sizes=True,  # tolerate a classifier head whose shape differs from the checkpoint's
    config=config,
)

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at cardiffnlp/twitter-roberta-base-sentiment and are newly initialized because the shapes did not match:
- classifier.out_proj.weight: found shape torch.Size([3, 768]) in the checkpoint and torch.Size([2, 768]) in the model instantiated
- classifier.out_proj.bias: found shape torch.Size([3]) in the checkpoint and torch.Size([2]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [24]:
# Tokenize the dataset
def preprocess_function(examples):
    """Tokenize the ``emailtext`` column of a batch of examples.

    Truncates to ``max_seq_length`` tokens and, when ``pad_to_max_length`` is
    set, pads every example to that same length; otherwise no padding is
    applied here.
    """
    padding_mode = "max_length" if pad_to_max_length else False
    return tokenizer(
        examples['emailtext'],
        padding=padding_mode,
        truncation=True,
        max_length=max_seq_length,
    )

# Collator that pads each batch using the tokenizer's padding settings
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
# Apply the tokenizer to every split, processing examples in batches
tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)

Map: 100%|██████████| 4682/4682 [00:00<00:00, 12878.54 examples/s]
Map: 100%|██████████| 4682/4682 [00:00<00:00, 14069.03 examples/s]


In [25]:
# Define a custom Trainer class to handle class weights
class WeightedTrainer(Trainer):
    """Trainer whose loss is cross-entropy weighted by the notebook-level ``class_weights``."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the class-weighted cross-entropy loss for one batch.

        Args:
            model: The model to run the forward pass with.
            inputs: Mapping of model inputs; must contain a ``"labels"`` entry.
            return_outputs: When true, return ``(loss, outputs)`` instead of
                just the loss.
            num_items_in_batch: Accepted only so that Trainer versions which
                pass this keyword do not raise ``TypeError``; it is not used,
                because the loss below is already averaged over the batch.

        Returns:
            The scalar loss tensor, or ``(loss, outputs)`` if
            ``return_outputs`` is true.
        """
        labels = inputs.get("labels")
        # Forward without the labels so the model is not asked to compute its own
        # (unweighted) loss as well; `inputs` itself is left unmodified for the caller.
        model_inputs = {key: value for key, value in inputs.items() if key != "labels"}
        outputs = model(**model_inputs)
        logits = outputs.get("logits")
        # Match both device and dtype of the logits so the weight tensor is usable
        # regardless of where / in which precision the model runs.
        weights = class_weights.to(device=logits.device, dtype=logits.dtype)
        loss_fct = torch.nn.CrossEntropyLoss(weight=weights)
        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss


In [33]:
# Define custom metrics for evaluation, focusing on the "Negative" class
def compute_metrics(p):
    """Return precision, recall and F1 for label index 1 only.

    Args:
        p: Evaluation prediction object exposing ``predictions`` (the model's
            scores, or a tuple whose first element is the scores) and
            ``label_ids`` (the reference label indices).

    Returns:
        dict with ``precision_negative``, ``recall_negative`` and
        ``f1_negative`` as plain Python floats.
    """
    scores = p.predictions
    # Some models hand back a tuple of outputs; the class scores come first.
    if isinstance(scores, tuple):
        scores = scores[0]
    preds = np.argmax(scores, axis=1)

    # Calculate precision, recall, and F1 score for label 1 (Negative).
    # zero_division=0 keeps the metric at 0.0 (without a warning) when label 1
    # is never predicted or never present.
    precision, recall, f1, _ = precision_recall_fscore_support(
        p.label_ids, preds, average=None, labels=[1], zero_division=0
    )

    return {
        "precision_negative": float(precision[0]),
        "recall_negative": float(recall[0]),
        "f1_negative": float(f1[0]),
    }

In [34]:
# Training arguments
# Evaluate once per epoch only when evaluation is enabled; checkpoints are saved every epoch.
eval_schedule = "epoch" if do_eval else "no"
training_args = TrainingArguments(
    output_dir=output_dir,
    logging_dir=f"{output_dir}/logs",
    logging_steps=100,
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    evaluation_strategy=eval_schedule,
    save_strategy="epoch",
)

In [35]:
# Initialize Trainer with class weighting
# A split is only handed to the trainer when the corresponding stage is enabled.
train_split = tokenized_datasets["train"] if do_train else None
eval_split = tokenized_datasets["validation"] if do_eval else None

trainer = WeightedTrainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    train_dataset=train_split,
    eval_dataset=eval_split,
)

In [36]:
# Train the model
# Runs only when do_train is set. After training finishes, the model and the
# tokenizer are both written to output_dir.
if do_train:
    trainer.train()
    # Save only after training has completed so the stored model is the trained one
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)


[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
                                                 
[A                                              

 33%|███▎      | 293/879 [39:51<35:59,  3.68s/it]  
[A

{'loss': 0.8649, 'grad_norm': 0.051305219531059265, 'learning_rate': 4.431171786120592e-05, 'epoch': 0.34}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
                                                 
[A                                              

 33%|███▎      | 293/879 [47:19<35:59,  3.68s/it]
[A

{'loss': 0.5223, 'grad_norm': 0.1161932498216629, 'learning_rate': 3.862343572241183e-05, 'epoch': 0.68}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[

{'eval_loss': 0.7862402200698853, 'eval_precision_negative': 0.0, 'eval_recall_negative': 0.0, 'eval_f1_negative': 0.0, 'eval_runtime': 341.3593, 'eval_samples_per_second': 13.716, 'eval_steps_per_second': 0.858, 'epoch': 1.0}



[A
[A
[A
[A
[A
[A
                                                 
[A                                              

 33%|███▎      | 293/879 [1:00:11<35:59,  3.68s/it]
[A

{'loss': 0.7732, 'grad_norm': 0.13836641609668732, 'learning_rate': 3.293515358361775e-05, 'epoch': 1.02}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
                                                   
[A                                              

 33%|███▎      | 293/879 [1:07:21<35:59,  3.68s/it]
[A

{'loss': 0.6976, 'grad_norm': 0.11379458010196686, 'learning_rate': 2.7246871444823664e-05, 'epoch': 1.37}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
                                                   
[A                                              

 33%|███▎      | 293/879 [1:14:22<35:59,  3.68s/it]
[A

{'loss': 0.6056, 'grad_norm': 16.700416564941406, 'learning_rate': 2.1558589306029582e-05, 'epoch': 1.71}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

{'eval_loss': 0.7453038692474365, 'eval_precision_negative': 0.0, 'eval_recall_negative': 0.0, 'eval_f1_negative': 0.0, 'eval_runtime': 340.4139, 'eval_samples_per_second': 13.754, 'eval_steps_per_second': 0.861, 'epoch': 2.0}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
                                                   
[A                                                

 33%|███▎      | 293/879 [1:27:02<35:59,  3.68s/it]
[A

{'loss': 0.4342, 'grad_norm': 0.07788524776697159, 'learning_rate': 1.5870307167235497e-05, 'epoch': 2.05}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
                                                   
[A                                                

 33%|███▎      | 293/879 [1:34:04<35:59,  3.68s/it]
[A

{'loss': 0.7107, 'grad_norm': 0.11831828951835632, 'learning_rate': 1.0182025028441412e-05, 'epoch': 2.39}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
                                                   
[A                                                

 33%|███▎      | 293/879 [1:41:05<35:59,  3.68s/it]
[A

{'loss': 0.5882, 'grad_norm': 17.875883102416992, 'learning_rate': 4.493742889647327e-06, 'epoch': 2.73}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[

{'eval_loss': 0.6755397319793701, 'eval_precision_negative': 0.0, 'eval_recall_negative': 0.0, 'eval_f1_negative': 0.0, 'eval_runtime': 345.2898, 'eval_samples_per_second': 13.56, 'eval_steps_per_second': 0.849, 'epoch': 3.0}


                                                   
[A                                                

 33%|███▎      | 293/879 [1:52:32<35:59,  3.68s/it]
100%|██████████| 879/879 [1:20:27<00:00,  5.49s/it]


{'train_runtime': 4827.0953, 'train_samples_per_second': 2.91, 'train_steps_per_second': 0.182, 'train_loss': 0.6694491042483246, 'epoch': 3.0}


In [37]:
# Evaluate the model
# Runs only when do_eval is set: evaluates with the trainer's eval dataset and
# logs the returned metrics at INFO level.
if do_eval:
    eval_result = trainer.evaluate()
    logger.info(f"Evaluation result: {eval_result}")


[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A