In [None]:
# AHMED AMIR RUSRUS
# 221001678

In [None]:
# data extraction

In [None]:
import os
import zipfile
import tarfile
import pandas as pd
import json

# 1️⃣ Extraction Step
# Each archive is unpacked into a sibling folder named after the dataset
# (e.g. datasetsnlp/goemotions/). A missing archive is reported and skipped
# instead of aborting the whole cell; its folder is still registered so the
# validation step below reports on whatever is already there.
base_dir = r"C:\Users\RusRus\Desktop\datasetsnlp"
os.makedirs(base_dir, exist_ok=True)

datasets = {
    "redditsarcasm": os.path.join(base_dir, "redditsarcasm.zip"),
    "goemotions": os.path.join(base_dir, "goemotions.zip"),
    "dailydialogue": os.path.join(base_dir, "dailydialogue.zip"),
    "empatheticdialogues": os.path.join(base_dir, "empatheticdialogues.tar.gz")
}

extract_dirs = {}

for name, archive_path in datasets.items():
    extract_to = os.path.join(base_dir, name)
    os.makedirs(extract_to, exist_ok=True)
    extract_dirs[name] = extract_to

    if not os.path.exists(archive_path):
        print(f"❌ Archive not found for {name}: {archive_path}")
        continue

    if archive_path.endswith(".zip"):
        with zipfile.ZipFile(archive_path, 'r') as zip_ref:
            zip_ref.extractall(extract_to)
        print(f"✅ Extracted ZIP: {name}")
    elif archive_path.endswith(".tar.gz"):
        with tarfile.open(archive_path, 'r:gz') as tar_ref:
            # Downloaded archives are untrusted: when this Python supports
            # extraction filters, use "data" to reject absolute paths and
            # ".." traversal outside extract_to.
            if hasattr(tarfile, "data_filter"):
                tar_ref.extractall(extract_to, filter="data")
            else:
                tar_ref.extractall(extract_to)
        print(f"✅ Extracted TAR.GZ: {name}")
    else:
        print(f"❌ Unknown archive format for {name}")

# 2️⃣ Recursive Validation Function (Fixed)
def validate_files_recursively(root_dir):
    """Walk ``root_dir`` and try to load every CSV/TSV/JSON file, printing a summary line.

    CSV/TSV files are read with pandas (python engine, malformed lines skipped)
    and reported with their shape and column names. JSON files are parsed and
    summarised by item count (list), key count (object) or type (scalar).
    Files with any other extension are ignored. Errors are printed per file and
    never raised, so one bad file does not stop the walk.

    Args:
        root_dir (str): Directory to scan recursively.
    """
    for root, _dirs, files in os.walk(root_dir):
        for fname in files:
            fpath = os.path.join(root, fname)
            if fname.endswith(".csv"):
                try:
                    df = pd.read_csv(fpath, engine="python", on_bad_lines='skip')
                    print(f"  ✅ {fname}: {df.shape}, columns: {list(df.columns)}")
                except Exception as e:
                    print(f"  ❌ CSV Error in {fname}: {e}")
            elif fname.endswith(".tsv"):
                try:
                    df = pd.read_csv(fpath, sep="\t", engine="python", on_bad_lines='skip')
                    print(f"  ✅ {fname}: {df.shape}, columns: {list(df.columns)}")
                except Exception as e:
                    print(f"  ❌ TSV Error in {fname}: {e}")
            elif fname.endswith(".json"):
                try:
                    with open(fpath, "r", encoding="utf-8") as f:
                        data = json.load(f)
                    if isinstance(data, list):
                        summary = f"{len(data)} items"
                    elif isinstance(data, dict):
                        summary = f"{len(data)} keys"
                    else:
                        # A top-level number/string/bool/null is still valid JSON;
                        # previously `.keys()` on it was misreported as a JSON error.
                        summary = f"scalar {type(data).__name__}"
                    print(f"  ✅ {fname}: JSON loaded ({summary})")
                except Exception as e:
                    print(f"  ❌ JSON Error in {fname}: {e}")

# 3️⃣ Validate every extracted dataset folder (one report section per dataset)
for name in extract_dirs:
    path = extract_dirs[name]
    print(f"\n📁 Validating dataset: {name}")
    validate_files_recursively(path)


In [None]:
# REDDIT SARCASM PREPROCESSING

In [None]:
import os
import pandas as pd
import json

# Input: folder holding the extracted RedditSarcasm CSVs.
# Output: one JSON object per line under ./processed/ (folder created if needed).
output_file = "processed/redditsarcasm.jsonl"
sarcasm_dir = r"C:\Users\RusRus\Desktop\datasetsnlp\redditsarcasm"
os.makedirs(os.path.dirname(output_file), exist_ok=True)

# Read and process function
def process_sarcasm_file(file_path):
    """Convert one RedditSarcasm CSV into unified-schema records.

    Reads ``file_path`` with pandas (malformed lines skipped). A row is kept
    when it has a non-blank ``comment`` and a ``label`` of 0 or 1. Each kept
    row becomes ``{"input": comment, "response": None, "sarcasm": bool(label),
    "emotion": None}``.

    Args:
        file_path (str): Path to the CSV file.

    Returns:
        list[dict]: The converted records. An empty list if reading/processing fails
        (the error is printed).
    """
    try:
        df = pd.read_csv(file_path, on_bad_lines='skip')
        data = []

        for _, row in df.iterrows():
            input_text = row.get("comment", "")
            label = row.get("label", None)
            # Empty CSV cells come back as NaN, which is truthy and would be
            # written out as the literal string "nan" -- reject it explicitly.
            if pd.isna(input_text) or not str(input_text).strip():
                continue
            if label in [0, 1]:
                data.append({
                    "input": str(input_text),
                    "response": None,
                    "sarcasm": bool(label),
                    "emotion": None
                })
        return data
    except Exception as e:
        print(f"Error reading {file_path}: {e}")
        return []

# Merge every split that exists on disk, then write one JSON object per line.
all_data = []
for filename in ("train-balanced-sarcasm.csv", "test-balanced.csv", "test-unbalanced.csv"):
    file_path = os.path.join(sarcasm_dir, filename)
    if not os.path.exists(file_path):
        print(f"File not found: {file_path}")
        continue
    all_data.extend(process_sarcasm_file(file_path))

with open(output_file, "w", encoding="utf-8") as f:
    f.writelines(json.dumps(item) + "\n" for item in all_data)

print(f"✅ Processed {len(all_data)} entries from RedditSarcasm into {output_file}")


In [None]:
# GOEMOTIONS PREPROCESSING

In [None]:
import os
import pandas as pd
import json

# Paths
goemotions_dir = r"C:\Users\RusRus\Desktop\datasetsnlp\goemotions"
output_file = "processed/goemotions.jsonl"
os.makedirs(os.path.dirname(output_file), exist_ok=True)

# Label-ID -> emotion-name lookup, keyed by the ID *string* as it appears in the TSV.
# Prefer the mapping file shipped with the dataset. Otherwise fall back to the
# standard 28-label GoEmotions order, where ID i is the i-th name below.
# NOTE(review): assumes emotion_mapping.json is an {"id": "name"} object -- confirm.
mapping_file = os.path.join(goemotions_dir, "emotion_mapping.json")
if not os.path.exists(mapping_file):
    label_map = {
        str(idx): emotion_name
        for idx, emotion_name in enumerate((
            "admiration amusement anger annoyance approval caring confusion curiosity "
            "desire disappointment disapproval disgust embarrassment excitement fear "
            "gratitude grief joy love nervousness optimism pride realization relief "
            "remorse sadness surprise neutral"
        ).split())
    }
else:
    with open(mapping_file, "r", encoding="utf-8") as f:
        label_map = json.load(f)

# Process function
def process_tsv(file_path, mapping=None):
    """Convert one GoEmotions TSV split into unified-schema records.

    The file is read without a header as three columns: ``text``, ``labels``
    and ``annotator``. ``labels`` is a comma-separated list of label IDs. Each ID
    found in the mapping is translated to its emotion name, and unknown IDs are
    dropped. Rows whose text is missing/blank, or that end up with no known
    label, are skipped.

    Args:
        file_path (str): Path to the TSV file.
        mapping (dict | None): Label-ID string -> emotion name. Defaults to the
            module-level ``label_map``.

    Returns:
        list[dict]: Records ``{"input", "response": None, "sarcasm": None,
        "emotion": [names...]}``. An empty list if reading fails (error printed).
    """
    if mapping is None:
        mapping = label_map
    try:
        df = pd.read_csv(file_path, sep="\t", header=None, names=["text", "labels", "annotator"], on_bad_lines='skip')
        data = []

        for _, row in df.iterrows():
            text = row["text"]
            # Empty cells are NaN: truthy, and otherwise serialised as NaN.
            if pd.isna(text) or not str(text).strip():
                continue
            raw_labels = str(row["labels"]).split(",")
            mapped = [mapping[lbl.strip()] for lbl in raw_labels if lbl.strip() in mapping]
            if mapped:
                data.append({
                    "input": str(text),
                    "response": None,
                    "sarcasm": None,
                    "emotion": mapped
                })
        return data
    except Exception as e:
        print(f"Error reading {file_path}: {e}")
        return []

# Gather the train/dev/test splits that exist, then write them out as JSONL.
all_data = []
for split in ("train.tsv", "dev.tsv", "test.tsv"):
    path = os.path.join(goemotions_dir, split)
    if os.path.exists(path):
        all_data += process_tsv(path)
    else:
        print(f"File not found: {path}")

with open(output_file, "w", encoding="utf-8") as f:
    f.writelines(json.dumps(item) + "\n" for item in all_data)

print(f"✅ Processed {len(all_data)} entries from GoEmotions into {output_file}")


In [None]:
# empatheticdialogues preprocessing

In [None]:
import os
import pandas as pd
import json

print("📁 Processing: EmpatheticDialogues")

# Paths
input_dir = r"C:\Users\RusRus\Desktop\datasetsnlp\empatheticdialogues"
output_file = "processed/empatheticdialogues.jsonl"
os.makedirs(os.path.dirname(output_file), exist_ok=True)

all_data = []
skipped_files = 0

# Each split becomes unified-schema records: the `utterance` column is the input
# and the `context` column is used as the emotion label.
# NOTE(review): `context` values are stored as-is; confirm they match the
# emotion label set used at training time.
for filename in ["train.csv", "valid.csv", "test.csv"]:
    path = os.path.join(input_dir, filename)
    print(f"📄 Reading: {filename}")

    try:
        df = pd.read_csv(path, on_bad_lines='skip')
        for _, row in df.iterrows():
            utterance = row.get("utterance", "")
            emotion = row.get("context", "")
            # Missing cells are NaN and str(NaN) == "nan" is truthy, so test
            # for missing values *before* converting to str.
            if pd.isna(utterance) or pd.isna(emotion):
                continue
            utterance = str(utterance).strip()
            emotion = str(emotion).strip()
            if utterance and emotion:
                all_data.append({
                    "input": utterance,
                    "response": None,
                    "sarcasm": None,
                    "emotion": emotion
                })
    except Exception as e:
        print(f"❌ Error reading {filename}: {e}")
        skipped_files += 1

# Save
with open(output_file, "w", encoding="utf-8") as f:
    for item in all_data:
        f.write(json.dumps(item) + "\n")

print(f"✅ Saved {len(all_data)} entries to: {output_file}")
if skipped_files > 0:
    print(f"⚠️ Skipped {skipped_files} file(s) due to errors.")


In [None]:
# CSV INSPECTION

In [None]:
import os
import pandas as pd

def inspect_csv_columns(dataset_dir):
    """Print the column names and first row of each standard split CSV.

    Looks for ``train.csv``, ``valid.csv`` and ``test.csv`` in ``dataset_dir``
    and parses at most 5 rows of each. Missing or unreadable files are
    reported rather than raised.
    """
    for split_name in ("train.csv", "valid.csv", "test.csv"):
        csv_path = os.path.join(dataset_dir, split_name)
        if not os.path.exists(csv_path):
            print(f"❌ File not found: {split_name}")
            continue
        try:
            preview = pd.read_csv(csv_path, on_bad_lines='skip', nrows=5)
        except Exception as err:
            print(f"❌ Could not read {split_name}: {err}")
            continue
        print(f"📄 {split_name} → Columns: {list(preview.columns)}")
        print(preview.head(1), "\n")


inspect_csv_columns(r"C:\Users\RusRus\Desktop\datasetsnlp\empatheticdialogues")


In [None]:
# DailyDialog preprocessing

In [None]:
import os
import json
import pandas as pd
import ast
from pathlib import Path

def process_dailydialogue(dataset_dir, output_file):
    """
    Processes the DailyDialog dataset where each row contains a stringified list of utterances.

    Every non-blank utterance becomes one JSONL record. Records use the same
    unified schema as the other preprocessing cells (``input``/``response``/
    ``sarcasm``/``emotion``) plus a ``source`` tag. The later cleaning and
    merge steps read ``entry["input"]``, so a ``text`` key would arrive there
    as empty. Rows whose ``dialog`` cell does not parse to a Python list are
    counted and skipped.

    Args:
        dataset_dir (str): Path to the dataset folder containing train.csv, etc.
        output_file (str): Path to save the processed JSONL file.
    """
    print("📁 Processing: DailyDialog")

    csv_path = os.path.join(dataset_dir, "train.csv")
    if not os.path.exists(csv_path):
        print(f"❌ File not found: {csv_path}")
        return

    try:
        df = pd.read_csv(csv_path, on_bad_lines='skip')
    except Exception as e:
        print(f"❌ Error reading {csv_path}: {e}")
        return

    if 'dialog' not in df.columns:
        print("❌ Column 'dialog' not found in CSV.")
        return

    output_dir = os.path.dirname(output_file)
    if output_dir:  # dirname is "" for a bare file name, and makedirs("") raises
        os.makedirs(output_dir, exist_ok=True)

    count = 0
    skipped = 0
    with open(output_file, "w", encoding="utf-8") as out_f:
        for dialog in df['dialog']:
            try:
                # literal_eval only evaluates Python literals, so it is safe on file content.
                utterances = ast.literal_eval(dialog)
                if not isinstance(utterances, list):
                    raise ValueError("Parsed 'dialog' is not a list")
            except Exception:
                skipped += 1
                continue

            for utt in utterances:
                if isinstance(utt, str) and utt.strip():
                    out_f.write(json.dumps({
                        "input": utt.strip(),
                        "response": None,
                        "sarcasm": None,
                        "emotion": None,
                        "source": "dailydialogue"
                    }, ensure_ascii=False) + "\n")
                    count += 1

    print(f"✅ Saved {count} utterances to: {output_file}")
    if skipped > 0:
        print(f"⚠️ Skipped {skipped} rows due to formatting issues.")

# ✅ Run it
process_dailydialogue(
    dataset_dir=r"C:\Users\RusRus\Desktop\datasetsnlp\dailydialogue",
    output_file="processed/dailydialogue.jsonl"
)


In [None]:
# VALIDATION COLUMNS

In [None]:
import pandas as pd

# Peek at the raw DailyDialog training CSV to confirm its column layout.
dailydialog_path = r"C:\Users\RusRus\Desktop\datasetsnlp\dailydialogue\train.csv"  # Adjust if path is different
df = pd.read_csv(dailydialog_path, on_bad_lines='skip')

print("Columns:", list(df.columns))
print("First few rows:")
print(df.head())


In [None]:
###############################################################################################################################

In [None]:
#STEP 2 : THE REAL PREPROCESSING

In [None]:
import os
import json
import pandas as pd
import re
import nltk
from nltk.corpus import stopwords
from tqdm import tqdm

# Directories: read the per-dataset JSONL files, write cleaned copies alongside.
input_dir = r"C:\Users\RusRus\Desktop\datasetsnlp\processed"
output_dir = "processed_cleaned"
os.makedirs(output_dir, exist_ok=True)

# Fetch NLTK's English stopword list (no-op if already present) and keep it
# as a set for fast membership tests.
nltk.download('stopwords')
stop_words = {word for word in stopwords.words('english')}

# === Text Cleaning Function ===
def clean_text(text, stop_set=None):
    """Lower-case, replace unusual characters with spaces, normalise whitespace, drop stopwords.

    Characters other than ASCII letters, digits, whitespace and ``. , ! ? '``
    become a space. Runs of whitespace collapse to one space, and
    whitespace-separated tokens found in ``stop_set`` are removed.

    Args:
        text: Value to clean; converted with ``str()`` first.
        stop_set (set[str] | None): Tokens to drop. Defaults to the module-level
            ``stop_words``.

    Returns:
        str: The remaining tokens joined by single spaces.

    NOTE(review): NLTK's English stopword list contains negations such as
    "not"/"no", which carry signal for sarcasm/emotion -- consider passing a
    smaller ``stop_set``.
    """
    if stop_set is None:
        stop_set = stop_words
    text = str(text).lower()
    # Text is already lower-case, so a-z covers every ASCII letter.
    text = re.sub(r"[^a-z0-9\s.,!?']", " ", text)  # remove special characters but keep basic punctuation
    text = re.sub(r"\s+", " ", text).strip()  # normalize whitespace
    return " ".join(token for token in text.split() if token not in stop_set)

# === Process All Datasets in JSONL ===
def preprocess_all_cleaned():
    """Write a cleaned copy of every ``*.jsonl`` file in ``input_dir`` to ``output_dir``.

    For each record, ``input`` is replaced by ``clean_text(record.get("input", ""))``
    and ``response`` is cleaned when it is truthy; other fields are copied
    unchanged. Lines that are not valid JSON are reported and dropped. Each
    output file keeps the input file's name.
    """
    jsonl_names = [fn for fn in os.listdir(input_dir) if fn.endswith(".jsonl")]

    for filename in jsonl_names:
        source_path = os.path.join(input_dir, filename)
        target_path = os.path.join(output_dir, filename)

        print(f"\n📁 Processing file: {filename}")
        kept = []

        with open(source_path, "r", encoding="utf-8") as infile:
            for raw_line in tqdm(infile, desc="  → Cleaning entries"):
                try:
                    record = json.loads(raw_line)
                except json.JSONDecodeError:
                    print("⚠️ Skipped invalid JSON line.")
                    continue

                record["input"] = clean_text(record.get("input", ""))
                if record.get("response"):
                    record["response"] = clean_text(record["response"])
                kept.append(record)

        with open(target_path, "w", encoding="utf-8") as outfile:
            outfile.writelines(json.dumps(rec) + "\n" for rec in kept)

        print(f"✅ Saved: {target_path} ({len(kept)} entries)")


# === Run Script ===
preprocess_all_cleaned()


In [None]:
# MERGING THE DATASETS

In [None]:
import os
import json
from tqdm import tqdm

# Input and output paths
input_dir = r"C:\Users\RusRus\Desktop\datasetsnlp\processed_cleaned"
output_path = "final_dataset/merged_dataset.jsonl"
os.makedirs(os.path.dirname(output_path), exist_ok=True)

# Dataset-specific source labeling
DATASETS = {
    "dailydialogue.jsonl": "dailydialogue",
    "empatheticdialogues.jsonl": "empatheticdialogues",
    "goemotions.jsonl": "goemotions",
    "redditsarcasm.jsonl": "redditsarcasm"
}

# Cap on RedditSarcasm records (first N lines in file order), so it does not
# swamp the smaller datasets.
REDDIT_SARCASM_LIMIT = 150_000

merged_data = []

for file_name, source in DATASETS.items():
    input_path = os.path.join(input_dir, file_name)
    if not os.path.exists(input_path):
        print(f"❌ File missing: {input_path}")
        continue

    print(f"📁 Merging from: {input_path}")
    with open(input_path, "r", encoding="utf-8") as f:
        lines = f.readlines()

    if source == "redditsarcasm":
        lines = lines[:REDDIT_SARCASM_LIMIT]

    for line in tqdm(lines, desc=f"Processing {source}"):
        if not line.strip():
            continue  # blank line (e.g. trailing newline) -- not an error
        try:
            entry = json.loads(line)
            merged_data.append({
                # `input` may be missing *or* JSON null; both become "".
                "input": str(entry.get("input") or "").strip(),
                "response": entry.get("response", None),
                "sarcasm": entry.get("sarcasm", None),
                "emotion": entry.get("emotion", None),
                "source": source
            })
        except Exception as e:
            print(f"⚠️ Skipped line due to error: {e}")

# Save merged output
with open(output_path, "w", encoding="utf-8") as f:
    for item in merged_data:
        f.write(json.dumps(item) + "\n")

print(f"\n✅ Merged {len(merged_data)} entries to: {output_path}")


In [None]:
#✅ Merged 304068 entries to: final_dataset/merged_dataset.jsonl

In [None]:
# # # # # #  # # # # # # # # # # # #  # # # # # # # # # # # # # # 

In [None]:
# TRAINING THE UNIFIED MULTITASK MODEL — STEP 1: DATA FILTERING

In [None]:
import os
import json
from tqdm import tqdm
import re

# ✅ Updated paths for Kaggle
# Source: the uploaded Kaggle dataset folder containing merged_dataset.jsonl.
# Destination: the writable Kaggle working directory.
input_dir = "/kaggle/input/mnbjvghvu/datasetsnlp/final_merged"  # Upload this folder with merged_dataset.jsonl
output_dir = "/kaggle/working/final_dataset_cleaned"
input_file = os.path.join(input_dir, "merged_dataset.jsonl")
output_file = os.path.join(output_dir, "merged_dataset_cleaned.jsonl")
os.makedirs(output_dir, exist_ok=True)

# 🔧 Text cleaning functions
def normalize_text(text):
    """Lower-case ``text``, strip URLs and unusual characters, and collapse whitespace.

    URLs (``http...``/``www...`` up to the next whitespace) are removed.
    Characters other than ASCII letters, digits, whitespace and ``. , ! ? '``
    are replaced by a space rather than deleted, so words separated only by
    such a character (an emoji, a dash, a slash) are not glued together.

    Args:
        text (str): Raw text.

    Returns:
        str: Normalised text with single spaces and no leading/trailing whitespace.
    """
    text = text.lower()
    text = re.sub(r"http\S+|www\S+", " ", text)
    text = re.sub(r"[^a-z0-9\s.,!?']", " ", text)  # already lower-case
    return re.sub(r"\s+", " ", text).strip()

def remove_stopwords(text, stopwords):
    """Return ``text`` without the whitespace-separated tokens contained in ``stopwords``.

    Surviving tokens are re-joined with single spaces.
    """
    return " ".join(filter(lambda tok: tok not in stopwords, text.split()))

# 🧠 Load stopwords
from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS as sklearn_stopwords
stopwords = set(sklearn_stopwords)

# 📦 Processing loop
valid_entries = []
sarcasm_count = 0
emotion_count = 0
multitask_count = 0
skip_count = 0

print(f"📁 Processing file: {input_file}\n")
with open(input_file, "r", encoding="utf-8") as f:
    for line in tqdm(f, desc="→ Cleaning entries"):
        try:
            item = json.loads(line)

            # Defensive coding: missing or null text fields become ""
            input_text = str(item.get("input") or "").strip()
            response = str(item.get("response") or "").strip()
            sarcasm = item.get("sarcasm")
            emotion = item.get("emotion")

            # Clean text
            input_text = remove_stopwords(normalize_text(input_text), stopwords)
            if response:
                response = remove_stopwords(normalize_text(response), stopwords)

            # Checked *after* cleaning: text made only of stopwords/symbols
            # ends up "" and would otherwise be kept as an empty training input.
            if not input_text:
                skip_count += 1
                continue

            # Count sample type
            if sarcasm is not None and emotion is not None:
                multitask_count += 1
            elif sarcasm is not None:
                sarcasm_count += 1
            elif emotion is not None:
                emotion_count += 1

            valid_entries.append({
                "input": input_text,
                "response": response if response else None,
                "sarcasm": sarcasm,
                "emotion": emotion
            })

        except Exception:
            # Malformed JSON or a line that is not a JSON object.
            skip_count += 1

# 💾 Save cleaned data
with open(output_file, "w", encoding="utf-8") as f:
    for item in valid_entries:
        f.write(json.dumps(item) + "\n")

print(f"\n✅ Saved cleaned dataset: {output_file} ({len(valid_entries)} entries)")
print(f"📊 Stats:")
print(f"   - Sarcasm-only samples: {sarcasm_count}")
print(f"   - Emotion-only samples: {emotion_count}")
print(f"   - Multitask (sarcasm + emotion): {multitask_count}")
print(f"   - Skipped invalid lines: {skip_count}")


In [None]:
# Recorded output of the filtering cell above:
# 📊 Stats:
#    - Sarcasm-only samples: 149300
#    - Emotion-only samples: 142793
#    - Multitask (sarcasm + emotion): 0
#    - Skipped invalid lines: 11975

In [None]:
########################################################################################################

In [None]:
# architecture : loading

In [None]:
# ✅ Step 5 - Cell 1: Load libraries, set seed, load dataset
import os
import json
import torch
import random
import numpy as np
from tqdm import tqdm
from sklearn.model_selection import train_test_split

# Transformers
from transformers import DistilBertTokenizerFast

# ✅ Set seeds for reproducibility
def set_seed(seed=42):
    """Seed Python's ``random``, NumPy's global RNG, and PyTorch (CPU and every CUDA device)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

set_seed(42)

# ✅ Load the cleaned, merged dataset (one JSON object per line)
merged_path = "/kaggle/working/final_dataset_cleaned/merged_dataset_cleaned.jsonl"
with open(merged_path, "r", encoding="utf-8") as f:
    raw_data = list(map(json.loads, f))

print(f"✅ Loaded {len(raw_data):,} examples from {merged_path}")

# ✅ Label coverage. An example carrying both labels is counted in all three totals.
sarcasm_count = sum(x["sarcasm"] is not None for x in raw_data)
emotion_count = sum(x["emotion"] is not None for x in raw_data)
multitask_count = sum(x["sarcasm"] is not None and x["emotion"] is not None for x in raw_data)

print(f"📊 Distribution:")
print(f"   - Sarcasm examples: {sarcasm_count:,}")
print(f"   - Emotion examples: {emotion_count:,}")
print(f"   - Multitask examples (both labels): {multitask_count:,}")

# ✅ DistilBERT (uncased) tokenizer shared by all tasks
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")


In [None]:
# Recorded output of the loading cell above:
# ✅ Loaded 292,093 examples from /kaggle/working/final_dataset_cleaned/merged_dataset_cleaned.jsonl
# 📊 Distribution:
#    - Sarcasm examples: 149,300
#    - Emotion examples: 142,793

In [None]:
#Apply tokenization

#Build dataset objects

#Create DataLoaders

In [None]:
from torch.utils.data import Dataset, DataLoader

# ✅ Label mapping (you can update as needed)
# Emotion name -> class index, in the same 28-name order as the GoEmotions
# fallback label map used during preprocessing.
emotion_label2id = {
    name: idx
    for idx, name in enumerate([
        "admiration", "amusement", "anger", "annoyance", "approval",
        "caring", "confusion", "curiosity", "desire", "disappointment",
        "disapproval", "disgust", "embarrassment", "excitement", "fear",
        "gratitude", "grief", "joy", "love", "nervousness",
        "optimism", "pride", "realization", "relief", "remorse",
        "sadness", "surprise", "neutral",
    ])
}

# ✅ Tokenize and prepare features
def tokenize_entry(entry):
    """Tokenize one example and attach integer task labels (-1 means "no label").

    ``entry["input"]`` is padded/truncated to 128 tokens, and the leading size-1
    dimension from ``return_tensors="pt"`` is squeezed away.

    - ``sarcasm`` becomes ``int(value)``, or -1 if absent/None.
    - ``emotion`` is looked up in ``emotion_label2id``:
      - a list uses its first element;
      - a string is used directly;
      - an int is passed through unchanged;
      - an unknown name or anything else gives -1.

    NOTE(review): only the first label of a multi-label list survives, and
    string emotions that are not GoEmotions names (e.g. the EmpatheticDialogues
    `context` values) become -1 -- confirm this is intended.
    """
    encoded = tokenizer(
        entry["input"],
        truncation=True,
        padding="max_length",
        max_length=128,
        return_tensors="pt"
    )
    features = {key: tensor.squeeze(0) for key, tensor in encoded.items()}

    sarcasm = entry.get("sarcasm")
    features["sarcasm"] = -1 if sarcasm is None else int(entry["sarcasm"])

    emotion = entry.get("emotion")
    if isinstance(emotion, list) and emotion:
        features["emotion"] = emotion_label2id.get(emotion[0], -1)
    elif isinstance(emotion, str):
        features["emotion"] = emotion_label2id.get(emotion, -1)
    elif isinstance(emotion, int):
        features["emotion"] = emotion
    else:
        features["emotion"] = -1
    return features

# ✅ Dataset class
class MultiTaskDataset(Dataset):
    """Map-style dataset that tokenizes every example once, at construction time.

    Each item is the feature dict produced by ``tokenize_entry`` (token tensors
    plus integer ``sarcasm``/``emotion`` labels).
    """

    def __init__(self, data):
        self.data = []
        for example in tqdm(data):
            self.data.append(tokenize_entry(example))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# ✅ 90/10 shuffled train/validation split with a fixed seed
train_data, val_data = train_test_split(raw_data, test_size=0.1, random_state=42, shuffle=True)

# ✅ Tokenize both splits up front, then batch them
train_dataset = MultiTaskDataset(train_data)
val_dataset = MultiTaskDataset(val_data)

train_loader = DataLoader(train_dataset, shuffle=True, batch_size=32)
val_loader = DataLoader(val_dataset, batch_size=32)

print(f"✅ DataLoaders ready: Train = {len(train_dataset)}, Val = {len(val_dataset)}")


In [None]:
# Recorded output of the DataLoader cell above:
# 100%|██████████| 262883/262883 [01:04<00:00, 4059.14it/s]
# 100%|██████████| 29210/29210 [00:07<00:00, 3763.12it/s]
#
# ✅ DataLoaders ready: Train = 262883, Val = 29210


In [None]:
#A shared DistilBERT encoder

#Two task-specific heads:

    #One for sarcasm detection (binary classification)

    #One for emotion classification (multi-class)

In [None]:
import torch
import torch.nn as nn
from transformers import DistilBertModel

class MultiTaskModel(nn.Module):
    """A DistilBERT encoder shared by two classification heads.

    The hidden state of the first (CLS) token feeds a 2-way sarcasm head and a
    ``num_emotions``-way emotion head. ``forward`` returns
    ``(sarcasm_logits, emotion_logits)``.
    """

    def __init__(self, num_emotions: int):
        super().__init__()
        self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased")
        hidden = self.bert.config.hidden_size

        # ✅ Sarcasm classification head (binary)
        self.sarcasm_head = self._make_head(hidden, 64, 2)
        # ✅ Emotion classification head (multi-class)
        self.emotion_head = self._make_head(hidden, 128, num_emotions)

    @staticmethod
    def _make_head(in_features, hidden_units, out_features):
        """Build a Linear -> ReLU -> Dropout(0.1) -> Linear classifier."""
        return nn.Sequential(
            nn.Linear(in_features, hidden_units),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_units, out_features),
        )

    def forward(self, input_ids, attention_mask):
        hidden_states = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        cls_vector = hidden_states[:, 0]  # CLS token
        return self.sarcasm_head(cls_vector), self.emotion_head(cls_vector)


# ✅ One emotion output per entry in the label map
num_emotions = len(emotion_label2id)
base_model = MultiTaskModel(num_emotions)

# ✅ Prefer CUDA when present
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"🖥️ Available GPUs: {torch.cuda.device_count()} - Using: {device}")

# ✅ Replicate across GPUs only when there is more than one
if torch.cuda.device_count() <= 1:
    model = base_model
else:
    print("🚀 Using DataParallel for multi-GPU training")
    model = nn.DataParallel(base_model)

model.to(device)
print(f"✅ Model ready and sent to device: {device}")


In [None]:
#Weighted CrossEntropyLoss for emotion classification (to handle imbalance)

#CrossEntropyLoss for sarcasm (standard binary)

#AdamW optimizer with weight decay

#Linear learning rate scheduler with warmup

In [None]:
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from collections import Counter
import numpy as np

# ✅ Inverse-frequency emotion class weights: weight = total / count.
# Counts use the first label of a non-empty list, or a plain string label.
emotion_counts = Counter(
    item["emotion"][0] if isinstance(item["emotion"], list) else item["emotion"]
    for item in train_data
    if (isinstance(item["emotion"], list) and item["emotion"]) or isinstance(item["emotion"], str)
)

total = sum(emotion_counts.values())
emotion_weights = torch.tensor(
    # A class never seen in training counts as 1 to avoid division by zero.
    [total / emotion_counts.get(label, 1) for label in emotion_label2id],
    dtype=torch.float,
).to(device)

# ✅ Loss functions: plain CE for sarcasm, class-weighted CE for emotion
sarcasm_criterion = nn.CrossEntropyLoss()
emotion_criterion = nn.CrossEntropyLoss(weight=emotion_weights)

# ✅ AdamW over all parameters (works for a plain or DataParallel-wrapped model)
optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)

# ✅ Learning rate scheduler with warmup: linear ramp 0 -> 1, then linear decay to 0.
# The decay must span the whole run. A fixed 10,000-step total reaches LR 0
# early in epoch 1 (≈262,883 train examples / batch size 8 ≈ 32,861 steps per
# epoch), leaving the rest of training at LR 0. Keep the batch size and epoch
# count below in sync with the training cell's DataLoader and `epochs` argument.
import math

SCHEDULER_BATCH_SIZE = 8
SCHEDULER_EPOCHS = 10
SCHEDULER_WARMUP_STEPS = 500
SCHEDULER_TOTAL_STEPS = max(
    SCHEDULER_WARMUP_STEPS + 1,
    math.ceil(len(train_dataset) / SCHEDULER_BATCH_SIZE) * SCHEDULER_EPOCHS,
)


def lr_lambda(current_step, warmup_steps=SCHEDULER_WARMUP_STEPS, total_steps=SCHEDULER_TOTAL_STEPS):
    """Return the LR multiplier for ``LambdaLR`` at ``current_step``.

    Args:
        current_step (int): Number of scheduler steps taken so far.
        warmup_steps (int): Steps over which the multiplier rises from 0 to 1.
        total_steps (int): Step at which the multiplier reaches 0.

    Returns:
        float: Multiplier in [0, 1].
    """
    if current_step < warmup_steps:
        return float(current_step) / float(max(1, warmup_steps))
    decay_span = max(1, total_steps - warmup_steps)  # guard against division by zero
    return max(0.0, 1.0 - (current_step - warmup_steps) / decay_span)


scheduler = LambdaLR(optimizer, lr_lambda)

print("✅ Loss functions, optimizer, and scheduler are ready.")


In [None]:
#Training Loop + Validation + Early Stopping

#This includes:

    #Epoch-based training with validation at each epoch

    #Logging loss and accuracy for both sarcasm and emotion tasks

    #Early stopping if validation doesn’t improve for N epochs

In [None]:
# CUDA DEVICE CHECK

In [None]:
# Wrap for multi-GPU only when the model is not already a DataParallel wrapper.
# The model-construction cell already wraps it. Wrapping twice nests the module
# and prefixes every state_dict key with "module.module.".
if torch.cuda.device_count() > 1 and not isinstance(model, nn.DataParallel):
    print(f"✅ Using {torch.cuda.device_count()} GPUs!")
    model = nn.DataParallel(model)

model.to(device)
# Convenience attribute recording the target device.
# NOTE(review): train_model takes its device from next(model.parameters()).device,
# so it does not read this attribute.
model.device = device

In [None]:
from sklearn.metrics import accuracy_score
import copy
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch
from torch.amp import autocast, GradScaler

# 🧠 Curriculum Helper
def sort_dataset_by_input_length(dataset):
    """Return a new list of the examples, ordered from fewest to most non-zero token ids.

    Token id 0 is treated as padding and not counted. ``sorted`` is stable, so
    examples of equal length keep their original relative order.
    """
    def non_padding_tokens(example):
        return (example['input_ids'] != 0).sum()

    return sorted(dataset, key=non_padding_tokens)

def _masked_multitask_loss(sarcasm_logits, emotion_logits, sarcasm_labels, emotion_labels):
    """Sum the per-task losses over only the examples that carry that task's label.

    A label of -1 means "no label for this task". A task with no labelled
    example in the batch contributes nothing: ``CrossEntropyLoss`` with mean
    reduction over an empty selection returns NaN, which would poison the
    total loss, the GradScaler and the early-stopping comparison.

    Returns:
        tuple: ``(loss, (sarcasm_logits, sarcasm_labels), (emotion_logits, emotion_labels))``.
        The pairs are masked to labelled examples. ``loss`` is ``None`` when
        neither task has a label in the batch.
    """
    sarcasm_mask = sarcasm_labels != -1
    emotion_mask = emotion_labels != -1
    sarcasm_pair = (sarcasm_logits[sarcasm_mask], sarcasm_labels[sarcasm_mask])
    emotion_pair = (emotion_logits[emotion_mask], emotion_labels[emotion_mask])

    loss = None
    if sarcasm_pair[1].numel() > 0:
        loss = sarcasm_criterion(*sarcasm_pair)
    if emotion_pair[1].numel() > 0:
        emotion_loss = emotion_criterion(*emotion_pair)
        loss = emotion_loss if loss is None else loss + emotion_loss
    return loss, sarcasm_pair, emotion_pair


def _safe_accuracy(labels, preds):
    """Return accuracy of ``preds`` against ``labels``, or NaN when there are no labelled examples."""
    return accuracy_score(labels, preds) if labels else float("nan")


def train_model(model, train_loader, val_loader, epochs=5, patience=2):
    """Train the two-head model with per-epoch validation and early stopping.

    Uses the module-level ``optimizer``, ``scheduler``, ``sarcasm_criterion``
    and ``emotion_criterion``. Mixed precision is enabled only on CUDA.
    Batches with no label for either task are skipped. After each epoch, the
    summed validation loss is compared with the best so far. If it fails to
    improve for ``patience`` consecutive epochs, training stops. The best
    epoch's weights are restored before returning.

    Args:
        model (nn.Module): Model returning ``(sarcasm_logits, emotion_logits)``.
        train_loader (DataLoader): Training batches with ``input_ids``,
            ``attention_mask``, ``sarcasm`` and ``emotion`` (-1 = no label).
        val_loader (DataLoader): Validation batches in the same format.
        epochs (int): Maximum number of epochs.
        patience (int): Epochs without validation improvement before stopping.

    Returns:
        nn.Module: ``model`` with the best validation weights loaded.
    """
    # Get device from model properly
    device = next(model.parameters()).device
    # autocast/GradScaler here target CUDA; fall back to full precision elsewhere.
    use_amp = device.type == "cuda"
    scaler = GradScaler('cuda', enabled=use_amp)
    best_val_loss = float('inf')
    patience_counter = 0
    best_model_state = copy.deepcopy(model.state_dict())

    for epoch in range(epochs):
        model.train()
        total_train_loss = 0
        train_sarcasm_preds, train_sarcasm_labels = [], []
        train_emotion_preds, train_emotion_labels = [], []

        for batch in tqdm(train_loader, desc=f"🧠 Training Epoch {epoch+1}"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            sarcasm_labels = batch["sarcasm"].to(device)
            emotion_labels = batch["emotion"].to(device)

            optimizer.zero_grad()

            with autocast('cuda', enabled=use_amp):
                sarcasm_logits, emotion_logits = model(input_ids=input_ids, attention_mask=attention_mask)
                loss, (s_logits, s_labels), (e_logits, e_labels) = _masked_multitask_loss(
                    sarcasm_logits, emotion_logits, sarcasm_labels, emotion_labels
                )

            if loss is None:
                continue  # no labelled example for either task: nothing to learn

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            total_train_loss += loss.item()

            train_sarcasm_preds += torch.argmax(s_logits, dim=1).cpu().tolist()
            train_sarcasm_labels += s_labels.cpu().tolist()
            train_emotion_preds += torch.argmax(e_logits, dim=1).cpu().tolist()
            train_emotion_labels += e_labels.cpu().tolist()

        train_sarcasm_acc = _safe_accuracy(train_sarcasm_labels, train_sarcasm_preds)
        train_emotion_acc = _safe_accuracy(train_emotion_labels, train_emotion_preds)

        # 📊 Validation
        model.eval()
        total_val_loss = 0
        val_sarcasm_preds, val_sarcasm_labels = [], []
        val_emotion_preds, val_emotion_labels = [], []

        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"🔍 Validating Epoch {epoch+1}"):
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                sarcasm_labels = batch["sarcasm"].to(device)
                emotion_labels = batch["emotion"].to(device)

                with autocast('cuda', enabled=use_amp):
                    sarcasm_logits, emotion_logits = model(input_ids=input_ids, attention_mask=attention_mask)
                    loss, (s_logits, s_labels), (e_logits, e_labels) = _masked_multitask_loss(
                        sarcasm_logits, emotion_logits, sarcasm_labels, emotion_labels
                    )

                if loss is None:
                    continue

                total_val_loss += loss.item()

                val_sarcasm_preds += torch.argmax(s_logits, dim=1).cpu().tolist()
                val_sarcasm_labels += s_labels.cpu().tolist()
                val_emotion_preds += torch.argmax(e_logits, dim=1).cpu().tolist()
                val_emotion_labels += e_labels.cpu().tolist()

        val_sarcasm_acc = _safe_accuracy(val_sarcasm_labels, val_sarcasm_preds)
        # Compare predictions (not the labels themselves) against the labels.
        val_emotion_acc = _safe_accuracy(val_emotion_labels, val_emotion_preds)

        print(f"\n📉 Epoch {epoch+1} Summary:")
        print(f"   Train Loss: {total_train_loss:.4f} | Sarcasm Acc: {train_sarcasm_acc:.4f} | Emotion Acc: {train_emotion_acc:.4f}")
        print(f"   Val Loss:   {total_val_loss:.4f} | Sarcasm Acc: {val_sarcasm_acc:.4f} | Emotion Acc: {val_emotion_acc:.4f}")

        if total_val_loss < best_val_loss:
            best_val_loss = total_val_loss
            best_model_state = copy.deepcopy(model.state_dict())
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("⏹️ Early stopping triggered.")
                break

        torch.cuda.empty_cache()

    model.load_state_dict(best_model_state)
    return model

# ✅ SAFE ENTRY POINT
if __name__ == "__main__":
    import os

    # Optional: suppress HuggingFace tokenizers fork warning
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Curriculum ordering: shortest examples first, so the order must not be shuffled.
    sorted_train_data = sort_dataset_by_input_length(train_dataset)
    loader_kwargs = dict(batch_size=8, shuffle=False, num_workers=2, pin_memory=True)
    train_loader = DataLoader(sorted_train_data, **loader_kwargs)
    val_loader = DataLoader(val_dataset, **loader_kwargs)

    trained_model = train_model(model, train_loader, val_loader, epochs=10)