In [None]:
import pandas as pd
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
import evaluate
import torch

# Checkpoints: the pretrained base model and the directory that receives the
# fine-tuned multi-label model.
BASE_MODEL_NAME = "distilbert-base-uncased"
FINE_TUNED_MULTILABEL_MODEL_DIR = "fine_tuned_multilabel_model"

# Training hyperparameters.
EPOCHS = 3
BATCH_SIZE = 32
# Token budget per review: inputs are truncated / padded to this length.
MAX_LENGTH = 152

# Auto-clean CSV load
# The Python parser engine together with on_bad_lines="skip" drops malformed
# rows instead of aborting the whole load.
df = pd.read_csv(
    "sarcasm_data.csv",
    quotechar='"',
    escapechar='\\',
    engine="python",
    on_bad_lines="skip"
)

label_columns = ["sarcastic", "irony", "multipolarity"]

# Fail fast when a column used later is absent. "review" is the text column
# consumed by tokenization, so it is validated here alongside the labels.
for col in ["review"] + label_columns:
    if col not in df.columns:
        raise KeyError(f"Column '{col}' not found in CSV!")

# Coerce labels to numbers; anything unparseable becomes NaN.
df[label_columns] = df[label_columns].apply(pd.to_numeric, errors="coerce")

# Drop rows with a missing review text or a missing/unparseable label.
# A NaN review would otherwise make the tokenizer fail much later.
df = df.dropna(subset=["review"] + label_columns)

# Keep only strictly binary labels. The model is trained as a multi-label
# classifier, so targets must be 0 or 1; values such as 2 or 0.5 are invalid
# and would otherwise be silently kept or truncated by the int cast below.
is_binary = df[label_columns].isin([0, 1]).all(axis=1)
df = df[is_binary].copy()

# Convert to int after filtering, and make sure every review is a string.
df[label_columns] = df[label_columns].astype(int)
df["review"] = df["review"].astype(str)

print(f"Dataset cleaned: {len(df)} rows remaining")

# Convert to Dataset
dataset = Dataset.from_pandas(df)

# Tokenization
# Load the tokenizer published under the same checkpoint name that is used for
# the model, so token ids line up with the pretrained embeddings.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)

def tokenize(batch):
    """Tokenize the ``"review"`` texts of a batch into fixed-length encodings.

    Args:
        batch: Mapping with a ``"review"`` entry holding the texts to encode.

    Returns:
        The tokenizer output for those texts, truncated and padded to
        ``MAX_LENGTH`` tokens.
    """
    reviews = batch["review"]
    encoding_options = {
        "padding": "max_length",
        "truncation": True,
        "max_length": MAX_LENGTH,
    }
    return tokenizer(reviews, **encoding_options)

# Tokenize all reviews in batches (adds input_ids / attention_mask columns).
dataset = dataset.map(tokenize, batched=True)

# Create multi-label "labels" column as float.
# The three binary targets are packed into one vector per example, in the
# fixed order [sarcastic, irony, multipolarity]. Floats are required because
# the multi-label loss compares logits against float targets. Batched mapping
# avoids one Python call per row.
dataset = dataset.map(
    lambda batch: {
        "labels": [
            [float(sarcastic), float(irony), float(multipolarity)]
            for sarcastic, irony, multipolarity in zip(
                batch["sarcastic"], batch["irony"], batch["multipolarity"]
            )
        ]
    },
    batched=True,
)

# Now set format for PyTorch: expose only the columns the model consumes.
dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])

# Split train/test with a fixed seed so the held-out set (and therefore the
# reported metrics) is reproducible across runs.
dataset = dataset.train_test_split(test_size=0.2, seed=42)
train_dataset = dataset["train"]
eval_dataset = dataset["test"]

# Metrics
# Load the F1 and accuracy metric objects (in that order) from `evaluate`.
f1, accuracy = (
    evaluate.load(metric_name) for metric_name in ("f1", "accuracy")
)

def compute_metrics(pred):
    """Compute subset accuracy and macro-averaged F1 for multi-label output.

    Args:
        pred: A ``(logits, labels)`` pair, e.g. the ``EvalPrediction`` that
            ``Trainer`` passes in. Both are array-likes of shape
            ``(num_examples, num_labels)``; ``labels`` holds 0/1 targets.
            ``logits`` may also be a tuple whose first element is the logits.

    Returns:
        dict with two floats:
            ``"accuracy"``: fraction of examples whose *entire* label vector
            is predicted correctly (exact-match / subset accuracy).
            ``"f1"``: unweighted mean of the per-label F1 scores. A label with
            no positive predictions and no positive targets scores 0.
    """
    # Local import keeps the function self-contained; numpy is already a
    # dependency of pandas and torch.
    import numpy as np

    logits, labels = pred
    if isinstance(logits, tuple):
        # Some models return extra outputs next to the logits.
        logits = logits[0]

    # Trainer hands over NumPy arrays (not tensors), so no `.numpy()` call.
    logits = np.asarray(logits, dtype=np.float64)
    truth = np.asarray(labels) > 0.5

    # sigmoid(z) > 0.5 is equivalent to z > 0, so threshold the logits
    # directly; this also avoids overflow for very large magnitudes.
    preds = logits > 0.0

    exact_match = (preds == truth).all(axis=1)
    accuracy_value = float(exact_match.mean()) if exact_match.size else 0.0

    # Per-label confusion counts (column-wise).
    true_pos = (preds & truth).sum(axis=0)
    false_pos = (preds & ~truth).sum(axis=0)
    false_neg = (~preds & truth).sum(axis=0)

    denominator = 2 * true_pos + false_pos + false_neg
    per_label_f1 = np.divide(
        2 * true_pos,
        denominator,
        out=np.zeros(denominator.shape, dtype=np.float64),
        where=denominator > 0,
    )
    f1_value = float(per_label_f1.mean()) if per_label_f1.size else 0.0

    return {
        "accuracy": accuracy_value,
        "f1": f1_value,
    }

# Model (multi-label)
# One output logit per target, in the same order as the "labels" vector:
# sarcastic, irony, multipolarity. The names are stored in the model config so
# the saved checkpoint reports real label names instead of LABEL_0..LABEL_2.
label_names = ["sarcastic", "irony", "multipolarity"]
model = AutoModelForSequenceClassification.from_pretrained(
    BASE_MODEL_NAME,
    num_labels=len(label_names),
    # Selects a per-label sigmoid / binary cross-entropy loss rather than softmax.
    problem_type="multi_label_classification",
    id2label=dict(enumerate(label_names)),
    label2id={name: index for index, name in enumerate(label_names)},
)

# Training Arguments
# Evaluate and checkpoint once per epoch. Fixed step intervals are fragile:
# when the interval is larger than the total number of optimizer steps, no
# evaluation or checkpoint ever happens and load_best_model_at_end has nothing
# to choose from.
training_args = TrainingArguments(
    output_dir="multilabel_results",
    evaluation_strategy="epoch",
    # Must match the evaluation strategy when load_best_model_at_end is set.
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=EPOCHS,
    weight_decay=0.01,
    load_best_model_at_end=True,
    # Pick the best checkpoint by the "f1" value returned from compute_metrics
    # (higher is better) instead of the default evaluation loss.
    metric_for_best_model="f1",
    greater_is_better=True,
    # Log the training loss often enough to be visible on short runs.
    logging_strategy="steps",
    logging_steps=50,
    logging_dir="multilabel_results/logs",
)

# Trainer
# Bundle the model, its training configuration, both dataset splits, the
# tokenizer and the metric callback into a single training driver.
trainer = Trainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    eval_dataset=eval_dataset,
    train_dataset=train_dataset,
)

# Train & Save
# Run fine-tuning, then write the model and the tokenizer into the same
# directory and report where they were saved.
trainer.train()
trainer.save_model(FINE_TUNED_MULTILABEL_MODEL_DIR)
tokenizer.save_pretrained(FINE_TUNED_MULTILABEL_MODEL_DIR)
print(f"Multi-label model (sarcasm/irony/multipolarity) saved to {FINE_TUNED_MULTILABEL_MODEL_DIR}")

Dataset cleaned: 10765 rows remaining




Map:   0%|          | 0/10765 [00:00<?, ? examples/s]

Map:   0%|          | 0/10765 [00:00<?, ? examples/s]

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
The git executable must be specified in one of the following ways:
    - be included in your $PATH
    - be set via $GIT_PYTHON_GIT_EXECUTABLE
    - explicitly set via git.refresh(<full-path-to-git-executable>)

All git commands will error until this is rectified.

This initial message can be silenced or aggravated in the future by setting the
$GIT_PYTHON_REFRESH environment variable. Use one of the following values:
    - quiet|q|silence|s|silent|none|n|0: for no message or exception
    - error|e|exception|raise|r|2: for a raised exception

Example:
    export GIT_PYTHON_REFRESH=quiet



  0%|          | 0/810 [00:00<?, ?it/s]



{'train_runtime': 43044.6188, 'train_samples_per_second': 0.6, 'train_steps_per_second': 0.019, 'train_loss': 0.0424851264482663, 'epoch': 3.0}
Multi-label model (sarcasm/irony/multipolarity) saved to fine_tuned_multilabel_model
