#!/usr/bin/env python
# coding: utf-8

# ==============================================================================
# SCRIPT: V30_Training_Script.py
#
# PURPOSE:
# This script trains a multi-class classification model (V30) for emotion
# recognition. It is an updated version of the V29 script, modified to
# incorporate new 'speech_action' and 'hard_case' labels to improve
# robustness and handle ambiguity.
# ==============================================================================

#%%
# V30 changes:
    # section #1 - Added 'speech_action' and 'hard_case' to the list of labels.
    # section #2 - Updated TargetedSmoothedCrossEntropyLoss to also disable smoothing for the new labels.
    # section #8 - Added new labels to the targeted augmentation list.
    # section #10 - Added new labels to the weighted sampling list.
    # overview: Fully integrate 'speech_action' and 'hard_case' into the training pipeline.

#%%
# --------------------------
# 0. Imports
# --------------------------
# Standard Library Imports
import datasets
import csv
import gc
import glob
import multiprocessing as mp
import os
import random
import re
import shutil
import subprocess
import sys
import time

# Third-Party Imports
import accelerate
import dill
import face_recognition
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn.functional as F
import torchvision.transforms as T
import transformers

# From Imports
from collections import Counter
from datasets import ClassLabel, Dataset, Features, Image as DatasetsImage, concatenate_datasets, load_dataset
from datetime import datetime
from functools import partial
from imagehash import phash, hex_to_hash
from io import BytesIO
from pathlib import Path
from PIL import Image, ImageOps, ExifTags, UnidentifiedImageError
from sklearn.calibration import CalibratedClassifierCV
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, confusion_matrix, log_loss, precision_recall_fscore_support
from torch import nn
from torch.nn import functional as F
from torch.optim import AdamW, LBFGS
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torchvision import transforms
from torchvision.transforms import (
    GaussianBlur,
    RandAugment,
    RandomAffine,
    RandomApply,
    RandomPerspective,
    RandomAdjustSharpness,
    ToPILImage,
    ToTensor
)
from tqdm import tqdm
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    EarlyStoppingCallback,
    TrainingArguments,
    Trainer,
)


#%%
# --------------------------
# 1. Global Configurations
# --------------------------
# Switch for the inference stage.
# NOTE(review): not read anywhere in the visible part of this script —
# presumably consumed by a later section; confirm before removing.
RUN_INFERENCE = True  # Toggle this off to disable running inference

# MODIFICATION: Instruct user to update this path to their new consolidated dataset.
# --- IMPORTANT: Update this path to your new dataset folder ---
# This folder should contain subdirectories for the original 10 emotions
# PLUS the new 'speech_action' and 'hard_case' folders.
IMAGE_DIR = "/Users/natalyagrokh/AI/ml_expressions/img_datasets/consolidated_dataset_for_v30"
# Alias of IMAGE_DIR; this is the name the dataset-loading code reads.
BASE_PATH = IMAGE_DIR
# Parent folder of the versioned training-run directories ("V<n>_<timestamp>").
MODEL_ROOT = "/Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training"

# MODIFICATION: Added 'speech_action' and 'hard_case' to the list of labels.
# Class vocabulary. The position of a name in this list is its integer class id.
LABEL_NAMES = [
    'anger', 'disgust', 'fear', 'happiness', 'neutral',
    'questioning', 'sadness', 'surprise', 'contempt', 'unknown',
    'speech_action', 'hard_case'
]
id2label = {class_id: class_name for class_id, class_name in enumerate(LABEL_NAMES)}
label2id = {class_name: class_id for class_id, class_name in enumerate(LABEL_NAMES)}

HARD_CLASS_NAMES = ['contempt', 'disgust', 'fear', 'questioning']
hard_class_ids = list(map(label2id.__getitem__, HARD_CLASS_NAMES))

VALID_EXTENSIONS = (".jpg", ".jpeg", ".png", ".tif", ".tiff")


def is_valid_image(filename):
    """Return True when `filename` looks like a usable image file.

    A name qualifies when it ends (case-insensitively) with one of
    VALID_EXTENSIONS and does not start with "._".
    """
    if filename.startswith("._"):
        return False
    return filename.lower().endswith(VALID_EXTENSIONS)


# Lower-cased label name -> canonical label name.
label_mapping = dict(zip((name.lower() for name in LABEL_NAMES), LABEL_NAMES))

# 🔢 Dynamically determine the next version
def get_next_version(base_dir):
    """Return the tag ("V<n>") that follows the highest existing run folder.

    Scans `base_dir` for directories matching "V*_*" whose text between the
    leading "V" and the first underscore is all digits, and returns "V" plus
    one more than the largest number found ("V1" when none match).
    """
    highest = 0
    for entry in glob.glob(os.path.join(base_dir, "V*_*")):
        if not os.path.isdir(entry):
            continue
        folder = os.path.basename(entry)
        if not (folder.startswith("V") and "_" in folder):
            continue
        number_part = folder[1:].split("_")[0]
        if number_part.isdigit():
            highest = max(highest, int(number_part))
    return f"V{highest + 1}"

# Automatically create a versioned output folder.
# Every path is derived from MODEL_ROOT so the folder that is scanned for
# existing versions and the folder the new run is written to cannot drift apart.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
VERSION = get_next_version(MODEL_ROOT)
VERSION_TAG = f"{VERSION}_{timestamp}"
SAVE_DIR = os.path.join(MODEL_ROOT, VERSION_TAG)
# Files the evaluation logits / labels are saved to.
LOGITS_PATH = os.path.join(SAVE_DIR, f"logits_eval_{VERSION}.npy")
LABELS_PATH = os.path.join(SAVE_DIR, f"labels_eval_{VERSION}.npy")
os.makedirs(SAVE_DIR, exist_ok=True)
print(f"üìÅ Output directory created: {SAVE_DIR}")


#%%
# --------------------------
# 2. Utility Functions (Metrics & Calibration)
# --------------------------

# ------------------------------------------
# Part A: Data Preparation & Augmentation
# ------------------------------------------

# 🗺️ Injects 'image_path' into the dataset BEFORE any map/filter
def add_image_path(example):
    """Store the source path of a record under the "image_path" key.

    The path is the `filename` attribute of example["image"] when that
    attribute exists and is not None. Otherwise it is BASE_PATH joined with
    example["file"] when a "file" key is present, else the empty string.
    The record is modified in place and returned.
    """
    resolved = getattr(example["image"], "filename", None)
    if resolved is None:
        # No filename on the image object: rebuild it from "file" if we can.
        resolved = os.path.join(BASE_PATH, example["file"]) if "file" in example else ""
    example["image_path"] = resolved
    return example

# 🏷️ Standardizes labels from various sources (int, str, filepath) to a consistent integer ID.
def reconcile_labels(example):
    """Rewrite example["label"] as an id from `label2id`, or -1 if unknown.

    The label name is taken from, in order of preference:
      * an int label, decoded through the module-level `dataset` features;
      * a str label, used directly;
      * otherwise the name of the folder containing example["image_path"].
    The name is lower-cased, looked up in `label_mapping`, and converted to
    its id. Names that are not in `label_mapping` yield -1.
    """
    raw = example.get("label", None)
    if isinstance(raw, int):
        # Relies on the global `dataset` existing at call time.
        name_key = dataset.features["label"].int2str(raw).strip().lower()
    elif isinstance(raw, str):
        name_key = raw.strip().lower()
    else:
        source_path = example["image_path"]
        name_key = os.path.basename(os.path.dirname(source_path)).lower() if source_path else None

    canonical = label_mapping.get(name_key)
    if canonical is None:
        example["label"] = -1
    else:
        example["label"] = label2id[canonical]
    return example

# 🔍 Computes a perceptual hash (pHash) for an image to find visually similar duplicates.
def compute_hash(image_path):
    """Return the perceptual hash of an image as a string, or None on failure.

    The image is converted to greyscale and resized to 64x64 before hashing.
    Any error while opening, decoding or hashing the file yields None, so
    callers can simply skip unreadable files.
    """
    try:
        # Image.open is lazy and keeps the file handle open; the context
        # manager guarantees the handle is released once the pixels are read.
        with Image.open(image_path) as source:
            thumbnail = source.convert("L").resize((64, 64))
        return str(phash(thumbnail))
    except Exception:
        # Deliberately best-effort: a broken file must not abort a dataset scan.
        return None

# 📦 Applies augmentations and processes images on-the-fly for each batch.
class DataCollatorWithAugmentation:
    """Collate dataset records into a model batch, augmenting each image first.

    Args:
        processor: Callable invoked as ``processor(images=..., return_tensors="pt")``;
            its result is used as the batch and must support item assignment.
        augment_dict: Mapping from label id to the augmentation callable used
            for images of that label.
        default_augment: Augmentation callable for labels missing from
            ``augment_dict``. When None, the module-level ``data_augment``
            pipeline is looked up at call time.
    """

    def __init__(self, processor, augment_dict, default_augment=None):
        self.processor = processor
        self.augment_dict = augment_dict
        self.default_augment = default_augment

    def _pipeline_for(self, label):
        """Return the augmentation callable to use for `label`."""
        if label in self.augment_dict:
            return self.augment_dict[label]
        if self.default_augment is not None:
            return self.default_augment
        # Resolved lazily so the collator can be created before the
        # module-level pipeline is defined, and only when actually needed.
        return data_augment

    def __call__(self, features):
        """Augment and process `features` (records with "image" and "label").

        Returns the processor output with an added "labels" entry holding the
        label ids as a ``torch.long`` tensor.
        """
        processed_images = [
            self._pipeline_for(item["label"])(item["image"].convert("RGB"))
            for item in features
        ]

        batch = self.processor(
            images=processed_images,
            return_tensors="pt"
        )

        batch["labels"] = torch.tensor([item["label"] for item in features], dtype=torch.long)
        return batch

# ------------------------------------------
# Part B: Model & Training Components
# ------------------------------------------
      
# 🏋️ Defines a custom Trainer that uses the targeted loss function
class CustomLossTrainer(Trainer):
    """Trainer whose loss is TargetedSmoothedCrossEntropyLoss(smoothing=0.05)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_fct = TargetedSmoothedCrossEntropyLoss(smoothing=0.05)

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Run the model without labels and score its logits with `loss_fct`.

        "labels" is removed from `inputs` before the forward pass. Returns the
        loss, or ``(loss, outputs)`` when `return_outputs` is true. Extra
        keyword arguments are accepted and ignored.
        """
        targets = inputs.pop("labels")
        model_out = model(**inputs)
        loss_value = self.loss_fct(model_out.logits, targets)
        if return_outputs:
            return loss_value, model_out
        return loss_value
        
# 🔄 Implements Cross-Entropy Loss with Targeted Label Smoothing for multiple classes
# Smoothing turned OFF for 'contempt', 'disgust', and now 'speech_action', 'hard_case'
class TargetedSmoothedCrossEntropyLoss(nn.Module):
    """Cross-entropy with label smoothing that is disabled for selected classes.

    Samples whose target is one of `target_class_ids` use a one-hot target.
    Every other sample uses ``1 - smoothing`` on the true class and
    ``smoothing / (num_classes - 1)`` on each remaining class.

    Args:
        smoothing: Probability mass moved away from the true class.
        target_class_ids: Class ids trained without smoothing. When None, the
            ids of 'contempt', 'disgust', 'speech_action' and 'hard_case' are
            read from the module-level `label2id`.
    """

    def __init__(self, smoothing=0.05, target_class_ids=None):
        super().__init__()
        self.smoothing = smoothing
        if target_class_ids is None:
            # MODIFICATION: Added 'speech_action' and 'hard_case' to ensure the model
            # makes confident, sharp predictions for these critical new classes.
            target_class_ids = [
                label2id['contempt'], label2id['disgust'],
                label2id['speech_action'], label2id['hard_case'],
            ]
        self.target_class_ids = list(target_class_ids)

    def forward(self, logits, target):
        """Return the mean loss for `logits` (N, C) against class ids `target` (N,)."""
        num_classes = logits.size(1)
        with torch.no_grad():
            sharp_ids = torch.tensor(
                self.target_class_ids, dtype=target.dtype, device=target.device
            )
            is_sharp = torch.isin(target, sharp_ids)

            # Per-sample smoothing amount: zero for the sharp classes. Built in
            # the logits' dtype so the targets never mix float widths.
            eps = torch.full(
                target.shape, float(self.smoothing),
                dtype=logits.dtype, device=logits.device,
            ).masked_fill(is_sharp, 0.0)

            # max(..., 1) avoids a division by zero for a single-class head.
            off_value = eps / max(num_classes - 1, 1)
            soft_targets = off_value.unsqueeze(1).repeat(1, num_classes)
            soft_targets.scatter_(1, target.unsqueeze(1), (1.0 - eps).unsqueeze(1))

        log_probs = F.log_softmax(logits, dim=1)
        loss = -(soft_targets * log_probs).sum(dim=1).mean()
        return loss

# 📊 Compute Metrics with Confusion Matrix Logging
def compute_metrics_with_confusion(eval_pred):
    """Compute accuracy and log detailed per-class diagnostics.

    Args:
        eval_pred: Pair ``(logits, labels)``; predictions are the argmax of
            `logits` over the last axis.

    Returns:
        ``{"accuracy": float}``.

    Side effects:
        * prints the classification report, the five most frequent confusions
          and the prediction entropies;
        * saves logits and labels to LOGITS_PATH / LABELS_PATH;
        * appends one row of per-class F1 / recall / precision / entropy to
          ``per_class_metrics.csv`` in SAVE_DIR;
        * writes the confusion-matrix heatmap PNG into SAVE_DIR.
    """
    logits, labels = eval_pred
    labels = np.asarray(labels)
    preds = np.argmax(logits, axis=-1)

    # Pin the class list explicitly: without it, a class that is absent from
    # this evaluation split makes the report raise (name count mismatch) and
    # shrinks the confusion matrix below len(LABEL_NAMES).
    class_ids = list(range(len(LABEL_NAMES)))
    report_kwargs = dict(labels=class_ids, target_names=LABEL_NAMES, zero_division=0)

    print("\nClassification Report:")
    report = classification_report(labels, preds, output_dict=True, **report_kwargs)
    print(classification_report(labels, preds, **report_kwargs))

    np.save(LOGITS_PATH, logits)
    np.save(LABELS_PATH, labels)

    f1s = [report[name]["f1-score"] for name in LABEL_NAMES]
    recalls = [report[name]["recall"] for name in LABEL_NAMES]
    precisions = [report[name]["precision"] for name in LABEL_NAMES]

    # Entropy from log-softmax directly; no epsilon needed to guard log(0).
    log_probs = F.log_softmax(torch.as_tensor(logits).float(), dim=-1)
    entropies = -(log_probs.exp() * log_probs).sum(dim=-1)
    entropy_per_class = []
    for idx, class_name in enumerate(LABEL_NAMES):
        mask = torch.from_numpy(labels == idx)
        # Classes without samples in this split are reported as 0.0.
        class_entropy = entropies[mask].mean().item() if mask.any() else 0.0
        entropy_per_class.append((class_name, class_entropy))
    sorted_entropy = sorted(entropy_per_class, key=lambda x: x[1], reverse=True)

    epoch_metrics_path = os.path.join(SAVE_DIR, "per_class_metrics.csv")
    # `trainer` is a module-level name that only exists once training is set up.
    epoch = getattr(trainer.state, "epoch", None) if "trainer" in globals() else None
    df_row = pd.DataFrame({
        "epoch": [epoch],
        **{f"f1_{n}": [f] for n, f in zip(LABEL_NAMES, f1s)},
        **{f"recall_{n}": [r] for n, r in zip(LABEL_NAMES, recalls)},
        **{f"precision_{n}": [p] for n, p in zip(LABEL_NAMES, precisions)},
        **{f"entropy_{n}": [e] for n, e in entropy_per_class}
    })
    # Write the header only when the file is created.
    file_exists = os.path.exists(epoch_metrics_path)
    df_row.to_csv(
        epoch_metrics_path,
        mode="a" if file_exists else "w",
        header=not file_exists,
        index=False,
    )

    cm = confusion_matrix(labels, preds, labels=class_ids)
    plt.figure(figsize=(10, 8)) # Increased figure size for more labels
    sns.heatmap(
        cm, annot=True, fmt="d", cmap="Blues",
        xticklabels=LABEL_NAMES,
        yticklabels=LABEL_NAMES
    )
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Confusion Matrix")
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.savefig(os.path.join(SAVE_DIR, f"confusion_matrix_epoch_{VERSION}.png"))
    plt.close()

    # Off-diagonal cells only: (true class, predicted class) -> count.
    confusion_pairs = [
        ((LABEL_NAMES[i], LABEL_NAMES[j]), cm[i, j])
        for i in class_ids
        for j in class_ids if i != j
    ]
    top_confusions = sorted(confusion_pairs, key=lambda x: x[1], reverse=True)[:5] # Increased to top 5
    print("\nTop 5 confused class pairs:")
    for (true_label, pred_label), count in top_confusions:
        print(f"  - {true_label} ‚Üí {pred_label}: {count} instances")

    avg_entropy = entropies.mean().item()
    print(f"\nüß† Avg prediction entropy: {avg_entropy:.4f}")

    print("\nüîç Class entropies (sorted):")
    for class_name, entropy in sorted_entropy:
        print(f"  - {class_name}: entropy = {entropy:.4f}")

    accuracy = float((preds == labels).mean())
    return {"accuracy": accuracy}

# ... (Other utility functions like save_model_and_processor remain the same) ...

#%%
# --------------------------
# 3. Auto-Load V29 Golden Checkpoint
# --------------------------
# MODIFICATION: Now loading from V29 as the starting point.
# The checkpoint lives under MODEL_ROOT; build the path from it so the two
# cannot get out of sync.
model_path = os.path.join(MODEL_ROOT, "V29_20250710_082807")
print(f"‚úÖ Explicitly loading V29 checkpoint from: {model_path}")

model = AutoModelForImageClassification.from_pretrained(model_path)
processor = AutoImageProcessor.from_pretrained(model_path)

# Reset the classifier head for new training with the new labels
model.classifier = nn.Linear(model.config.hidden_size, len(LABEL_NAMES))
model.config.id2label = id2label
model.config.label2id = label2id
model.config.num_labels = len(LABEL_NAMES)
# NOTE(review): the model object is assumed to keep its own `num_labels`
# copied from the config at construction time; keep it in step with the new
# head so nothing reads the old class count.
model.num_labels = len(LABEL_NAMES)
print("‚úÖ Classifier head reset for new training.")

# Prefer CUDA when present, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print("\nüñ•Ô∏è Using device:", device)
model.to(device).eval()


#%%
# --------------------------
# 4. Load and Prepare Dataset
# --------------------------
# (This section remains largely the same, assuming IMAGE_DIR points to the new dataset)
print("üîç Counting valid image files on disk for verification...")
expected_file_count = sum(
    1 for p in Path(BASE_PATH).rglob("*") if is_valid_image(p.name)
)
print(f"‚úÖ Found {expected_file_count} image files in {BASE_PATH}")

datasets.disable_caching()
print("‚úÖ Datasets caching disabled for this run to ensure fresh data load.")

dataset = load_dataset(
    "imagefolder",
    data_dir=BASE_PATH,
    split="train"
)

# Actually perform the verification the on-disk count was computed for.
loaded_count = len(dataset)
if loaded_count != expected_file_count:
    print(
        f"WARNING: loaded {loaded_count} examples but counted "
        f"{expected_file_count} valid image files on disk."
    )

# `reconcile_labels` reads the module-level `dataset`, so the name must keep
# pointing at the current dataset between these steps.
dataset = dataset.map(add_image_path, desc="Add file path to each record")
dataset = dataset.map(reconcile_labels, desc="Re-labeling dataset (preserving image_path)")
# Only the label column is passed to the predicate, so images are not decoded
# just to drop the records marked -1.
dataset = dataset.filter(lambda label: label != -1, input_columns=["label"])

final_count = len(dataset)
print(f"‚úÖ Total examples after filtering: {final_count}")


#%%
# --------------------------
# 5. Data Curation (If needed, e.g., removing hard negatives from a previous run)
# --------------------------
# For V30, we are starting fresh with the new labels, so we skip the hard-negative
# removal from V24. If you have a new exclusion list, you can add that logic here.
# No curation is applied: `curated_dataset` is another name for the same
# `dataset` object (no copy). The sections below read `curated_dataset`.
curated_dataset = dataset
print("‚úÖ Using the full dataset for V30 training.")


#%%
# --------------------------
# 6. Dataset Label Overview
# --------------------------
# (This section remains the same, will now show counts for new labels)
# ...


#%%
# --------------------------
# 7. Perceptual Clustering
# --------------------------
# (This section can remain the same)
# ...

#%%
# --------------------------
# 8. Class Frequency-Aware Augmentation Targeting
# --------------------------
label_freqs = Counter(curated_dataset["label"])
# Same contents as id2label / label2id, kept under the names used below.
label_id2name = dict(id2label)
label_name2id = dict(label2id)

# The three least frequent labels present in the data, rarest first.
minority_by_count = sorted(label_freqs, key=label_freqs.get)[:3]
# NOTE(review): "unknown" is dropped after the top-3 cut, so fewer than three
# names can remain — confirm that is intended.
minority_by_name = [
    label_id2name[i] for i in minority_by_count if label_id2name[i] != "unknown"
]

# MODIFICATION: Added 'speech_action' and 'hard_case' to the manual focus list
# to ensure they receive strong augmentation.
manual_focus_classes = ['disgust', 'questioning', 'contempt', 'speech_action', 'hard_case']

# Sorted so that the list (and the log line) is identical from run to run;
# plain set iteration order for strings changes between interpreter runs.
minority_class_names = sorted(set(minority_by_name) | set(manual_focus_classes))
minority_classes = [label_name2id[name] for name in minority_class_names]
print(f"üéØ Targeted minority augmentation will apply to: {minority_class_names}")


#%%
# --------------------------
# 9. Define Data Augmentation
# --------------------------
# (This section remains the same)
# Baseline pipeline: the collator's fallback for labels without their own pipeline.
_baseline_steps = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
]
data_augment = transforms.Compose(_baseline_steps)

# Stronger pipeline, mapped onto the targeted minority classes.
_minority_steps = [
    transforms.RandAugment(num_ops=2, magnitude=9),
    transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
]
minority_aug = transforms.Compose(_minority_steps)
print("‚úÖ Augmentation pipelines defined.")


#%%
# --------------------------
# 10. Balance Dataset
# --------------------------
# (This section's logic is mostly the same, but the hard_classes list is updated)
MINORITY_CAP = 2250 # You may need to adjust this based on the size of your new classes

# --- Part A: Train/Validation Split on CURATED Data ---
# The split happens BEFORE balancing. Oversampling first would put copies of
# the same image on both sides of the split, and re-splitting afterwards would
# discard the balanced dataset altogether.
split_dataset = curated_dataset.train_test_split(test_size=0.2, seed=42)
base_train = split_dataset["train"]
eval_dataset = split_dataset["test"]
print(f"‚úÖ Curated data split into {len(base_train)} training and {len(eval_dataset)} validation samples.")

# --- Part B: Oversample minority classes inside the training split only ---
label_counts = Counter(base_train["label"])
print("Original label distribution:", label_counts)

# Group row indices by label using the label column alone; filtering whole
# records once per class would decode every image each time.
indices_by_label = {}
for row_idx, label in enumerate(base_train["label"]):
    indices_by_label.setdefault(label, []).append(row_idx)

balanced_indices = []
for label, row_ids in indices_by_label.items():
    # 'unknown' is never oversampled; classes at or above the cap are kept as is.
    if LABEL_NAMES[label] != "unknown" and len(row_ids) < MINORITY_CAP:
        repeats, remainder = divmod(MINORITY_CAP, len(row_ids))
        row_ids = row_ids * repeats + row_ids[:remainder]
    balanced_indices.extend(row_ids)

train_dataset = base_train.select(balanced_indices).shuffle(seed=42)
print("After balancing:", Counter(train_dataset['label']))

# MODIFICATION: Added 'speech_action' and 'hard_case' to the weighted sampling.
# This gives them 2x importance during training.
hard_classes = ['contempt', 'disgust', 'questioning', 'surprise', 'fear', 'speech_action', 'hard_case']
hard_class_ids = [label2id[c] for c in hard_classes]

_hard_id_set = set(hard_class_ids)
weights = torch.tensor(
    [2.0 if l in _hard_id_set else 1.0 for l in train_dataset["label"]],
    dtype=torch.double,
)
sampler = torch.utils.data.WeightedRandomSampler(
    weights=weights,
    num_samples=len(weights),
    replacement=True
)

#%%
# --------------------------
# 11. Optimizer, Scheduler, and Training
# --------------------------

# --- Part A: Define Training Arguments ---
training_args = TrainingArguments(
    output_dir=SAVE_DIR,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    learning_rate=4e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=5,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    logging_dir=os.path.join(SAVE_DIR, "logs"),
    logging_strategy="epoch",
    remove_unused_columns=False
)

early_stop_callback = EarlyStoppingCallback(
    early_stopping_patience=3,
    early_stopping_threshold=0.001
)

# --- Part B: Instantiate the Data Collators ---
minority_augment_map = {label_id: minority_aug for label_id in minority_classes}
data_collator = DataCollatorWithAugmentation(
    processor=processor,
    augment_dict=minority_augment_map
)


def collate_for_eval(features):
    """Build an evaluation batch: RGB conversion and processing, no augmentation."""
    images = [item["image"].convert("RGB") for item in features]
    batch = processor(images=images, return_tensors="pt")
    batch["labels"] = torch.tensor([item["label"] for item in features], dtype=torch.long)
    return batch


# --- Part C: Discriminative Learning Rate Optimizer Setup ---
head_lr = 5e-5
backbone_lr = 2e-7

# Freeze everything, then unfreeze the head and the last four encoder layers.
trainable_backbone = model.vit.encoder.layer[-4:]
for param in model.parameters():
    param.requires_grad = False
for param in model.classifier.parameters():
    param.requires_grad = True
for param in trainable_backbone.parameters():
    param.requires_grad = True

optimizer_grouped_parameters = [
    {'params': list(model.classifier.parameters()), 'lr': head_lr},
    {'params': list(trainable_backbone.parameters()), 'lr': backbone_lr}
]

optimizer = torch.optim.AdamW(optimizer_grouped_parameters, weight_decay=0.01)


# --- Part D: Trainer Initialization and Execution ---
class WeightedSamplingTrainer(CustomLossTrainer):
    """CustomLossTrainer that honours a custom train sampler and eval collator.

    The stock Trainer builds its own training sampler and uses the single
    `data_collator` for evaluation too. That would ignore the weighted sampler
    and apply random augmentation to the validation images.
    """

    def __init__(self, *args, train_sampler=None, eval_collator=None, **kwargs):
        super().__init__(*args, **kwargs)
        self._weighted_train_sampler = train_sampler
        self._plain_eval_collator = eval_collator

    def _get_train_sampler(self, *args, **kwargs):
        # Arguments are passed through untouched because this hook's signature
        # differs between transformers releases.
        if self._weighted_train_sampler is not None:
            return self._weighted_train_sampler
        return super()._get_train_sampler(*args, **kwargs)

    def get_eval_dataloader(self, *args, **kwargs):
        if self._plain_eval_collator is None:
            return super().get_eval_dataloader(*args, **kwargs)
        # Swap the collator only while the evaluation dataloader is built.
        training_collator = self.data_collator
        self.data_collator = self._plain_eval_collator
        try:
            return super().get_eval_dataloader(*args, **kwargs)
        finally:
            self.data_collator = training_collator


trainer = WeightedSamplingTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics_with_confusion,
    optimizers=(optimizer, None),
    data_collator=data_collator,
    callbacks=[early_stop_callback],
    train_sampler=sampler,
    eval_collator=collate_for_eval,
)

# --- Part E: Train the Model and Finalize ---
print(f"\n--- Starting {VERSION} Training with on-the-fly processing ---")
trainer.train()
print("--- Training Finished ---")
print("--- Training Finished ---")

# --- (The rest of the script, sections 12-19, for inference and analysis can remain the same) ---