# In [None]:
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.models as models
from torchvision import transforms
import cv2
import numpy as np
from PIL import Image, ImageFile
import pandas as pd
from pathlib import Path
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from collections import Counter
import albumentations as A
from imblearn.over_sampling import RandomOverSampler
from tqdm import tqdm
import os
from datetime import datetime

# For truncated image loading
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Device selection
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Data loading
# Each image's class label is the name of the directory that contains it.
EXCLUDED_LABELS = {"contempt", "fear", "disgust"}
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp", ".pgm"}

labels = []
file_names = []
base_path = Path(r"C:\Users\mibra\Downloads\archive (4)\fer_ckplus_kdef")
if not base_path.is_dir():
    # Fail here with a clear message instead of later on an empty dataset.
    raise FileNotFoundError(f"Dataset directory not found: {base_path}")

for file in sorted(base_path.rglob('*')):
    # Keep only real image files. A '*.*' glob also matched directories with a
    # dot in their name and non-image files (e.g. a README/CSV at the dataset
    # root, whose parent directory name would have become a bogus class label).
    if not file.is_file() or file.suffix.lower() not in IMAGE_EXTENSIONS:
        continue
    label = file.parent.name
    if label not in EXCLUDED_LABELS:  # Cleaned control
        labels.append(label)
        file_names.append(str(file))

print(f"Total files: {len(file_names)}, Label count: {len(labels)}")

df = pd.DataFrame({"image": file_names, "label": labels})
print(f"DataFrame shape: {df.shape}")
print("Class distribution:", Counter(labels))

y = df["label"].tolist()
x = df["image"].tolist()

def balance_with_augmentation_optimized(X_paths, y_labels, img_size=(224, 224)):
    """Load, enhance, augment and class-balance grayscale images.

    Every readable image is read as grayscale, resized, sharpened with a 3x3
    kernel and contrast-equalised with CLAHE. Images of classes whose original
    count is below the mean class count also get up to three augmented copies.
    Finally ``RandomOverSampler`` duplicates samples so that every class matches
    the majority class.

    Args:
        X_paths: sequence of image file paths.
        y_labels: sequence of labels aligned index-by-index with ``X_paths``.
        img_size: target size as ``(width, height)``, the ``cv2.resize``
            ``dsize`` convention.

    Returns:
        (X_final, y_resampled): uint8 array of shape ``(N, height, width)`` and
        the matching label array.

    Raises:
        ValueError: if no labels are given or none of the images can be read.
    """
    # Albumentations transform - optimized for grayscale
    aug_transform = A.Compose([
        A.HorizontalFlip(p=0.5),
        A.Rotate(limit=20, p=0.6),
        A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5),
        A.Affine(
            translate_percent={"x": (-0.1, 0.1), "y": (-0.1, 0.1)},
            scale={"x": (0.85, 1.15), "y": (0.85, 1.15)},
            rotate=(-15, 15),
            p=0.5
        ),
    ])

    # The enhancement operators are the same for every image: build them once.
    sharpen_kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))

    def _enhance(gray):
        """Sharpen with the 3x3 kernel, then equalise local contrast with CLAHE."""
        return clahe.apply(cv2.filter2D(gray, -1, sharpen_kernel))

    # Calculate class distribution
    class_counts = Counter(y_labels)
    if not class_counts:
        raise ValueError("balance_with_augmentation_optimized: no labels were provided")
    mean_count = int(np.mean(list(class_counts.values())))

    print(f"Original class distribution: {class_counts}")
    print(f"Average class count: {mean_count}")
    print(f"Target class count (mean-based): {mean_count}")

    X_balanced = []
    y_balanced = []

    print("Starting augmentation and balancing process...")

    for idx, path in enumerate(tqdm(X_paths)):
        # Read as grayscale
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            print(f"Warning: {path} file could not be read, skipping...")
            continue

        img = _enhance(cv2.resize(img, img_size))
        label = y_labels[idx]

        # Add original image
        X_balanced.append(img.flatten())
        y_balanced.append(label)

        # Augmentation only for minority classes (those below mean).
        current_count = class_counts[label]
        if current_count >= mean_count:
            continue
        # Floor division: only classes at or below half the mean get augmented
        # copies here; RandomOverSampler tops up the remaining classes below.
        needed_augmentations = min((mean_count - current_count) // current_count, 3)
        if needed_augmentations == 0:
            continue

        # Albumentations expects a 3-channel image; build it once per source image.
        img_3ch = np.stack([img, img, img], axis=-1)
        for _ in range(needed_augmentations):
            try:
                augmented = aug_transform(image=img_3ch)["image"]
                augmented_gray = cv2.cvtColor(augmented, cv2.COLOR_RGB2GRAY)
                augmented_gray = _enhance(cv2.resize(augmented_gray, img_size))
            except Exception as e:
                print(f"Augmentation error: {e}")
                continue
            X_balanced.append(augmented_gray.flatten())
            y_balanced.append(label)

    if not X_balanced:
        raise ValueError("balance_with_augmentation_optimized: none of the images could be read")

    # Convert to NumPy array
    X_array = np.array(X_balanced, dtype=np.uint8)
    y_array = np.array(y_balanced)

    print(f"Total samples after augmentation: {len(X_array)}")

    # Final balancing with RandomOverSampler
    ros = RandomOverSampler(random_state=42)
    X_resampled, y_resampled = ros.fit_resample(X_array, y_array)

    # cv2.resize(dsize=(width, height)) yields arrays of shape (height, width),
    # so restore rows=height, cols=width (matters for non-square img_size).
    X_final = X_resampled.reshape(-1, img_size[1], img_size[0])

    print(f"Final balanced dataset: {len(X_final)} samples")
    print(f"Final class distribution: {Counter(y_resampled)}")

    return X_final, y_resampled

# Data preparation
# Split the raw file paths BEFORE any augmentation or oversampling. Balancing
# first lets augmented variants and exact RandomOverSampler duplicates of the
# same source image land on both sides of the split, which leaks test data into
# training and inflates the reported test accuracy.
x_train_paths, x_test_paths, y_train_labels, y_test_labels = train_test_split(
    x, y, test_size=0.2, random_state=42, stratify=y
)


def _load_eval_images(paths, path_labels, img_size=(224, 224)):
    """Load held-out images with the training enhancement but no augmentation.

    Applies the same per-image steps as ``balance_with_augmentation_optimized``:
    grayscale read, resize, 3x3 sharpening kernel, then CLAHE. Unreadable files
    are skipped with a warning, together with their labels.

    Args:
        paths: image file paths.
        path_labels: labels aligned with ``paths``.
        img_size: target size as ``(width, height)`` (``cv2.resize`` convention).

    Returns:
        (images, labels): uint8 array of shape ``(N, height, width)`` and a NumPy
        array of the N matching labels.
    """
    kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    images, kept_labels = [], []
    for path, label in zip(paths, path_labels):
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            print(f"Warning: {path} file could not be read, skipping...")
            continue
        img = cv2.resize(img, img_size)
        img = clahe.apply(cv2.filter2D(img, -1, kernel))
        images.append(img)
        kept_labels.append(label)
    images_arr = np.array(images, dtype=np.uint8).reshape(-1, img_size[1], img_size[0])
    return images_arr, np.array(kept_labels)


# Only the training portion is augmented and oversampled.
X_final, y_final = balance_with_augmentation_optimized(x_train_paths, y_train_labels)
X_test_images, y_test_raw = _load_eval_images(x_test_paths, y_test_labels)

# Label encoding (fit on training labels; stratification puts every class in train)
le = LabelEncoder()
y_encoded = le.fit_transform(y_final)

# Save label encoder
with open("label_encoder.pkl", "wb") as f:
    pickle.dump(le, f)

print(f"Label classes: {le.classes_}")
print(f"Total samples after balancing: {len(y_encoded)}")

# Final train/test arrays consumed by the datasets below.
X_train, y_train = X_final, y_encoded
X_test = X_test_images
y_test = le.transform(y_test_raw)

print(f"Train samples: {len(X_train)}, Test samples: {len(X_test)}")

class EmotionDataset(Dataset):
    """PyTorch dataset over in-memory grayscale images and integer labels.

    Args:
        X: indexable collection of image arrays; each item is cast to uint8 and
            converted to a PIL image.
        y: indexable collection of integer class labels aligned with ``X``.
        transform: optional callable applied to the PIL image; when omitted,
            ``transforms.ToTensor()`` is used.
    """

    def __init__(self, X, y, transform=None):
        self.X = X
        self.y = y
        self.transform = transform

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        img = self.X[idx].astype(np.uint8)
        label = self.y[idx]

        try:
            # Pillow infers mode 'L' for a 2-D uint8 array; passing ``mode=``
            # explicitly is deprecated in recent Pillow releases.
            img_pil = Image.fromarray(img)

            if self.transform:
                img_tensor = self.transform(img_pil)
            else:
                img_tensor = transforms.ToTensor()(img_pil)

        except Exception as e:
            print(f"Transform error: {e}")
            # Best-effort blank image so the batch can still be collated. Use
            # the source's spatial size rather than a hard-coded 224x224.
            # NOTE(review): 3 channels assumes the transform ends in
            # Grayscale(num_output_channels=3) as in this file; plain ToTensor
            # on a grayscale image yields 1 channel.
            height, width = img.shape[:2] if img.ndim >= 2 else (224, 224)
            channels = 3 if self.transform else 1
            img_tensor = torch.zeros((channels, height, width))

        return img_tensor, torch.tensor(label, dtype=torch.long)

def _grayscale_imagenet_transform():
    """Build the shared preprocessing: grey -> 3 channels -> tensor -> ImageNet-normalised."""
    return transforms.Compose([
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


# Transforms: train and test use the same deterministic recipe. Augmentation is
# done earlier, on the raw NumPy images, not at tensor-conversion time.
transform_train = _grayscale_imagenet_transform()
transform_test = _grayscale_imagenet_transform()

# Datasets and DataLoaders
train_dataset, test_dataset = (
    EmotionDataset(X_train, y_train, transform=transform_train),
    EmotionDataset(X_test, y_test, transform=transform_test),
)

train_loader = DataLoader(train_dataset, shuffle=True, batch_size=32, num_workers=0)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=32, num_workers=0)

print(f"Train batches: {len(train_loader)}, Test batches: {len(test_loader)}")

# Model creation: ImageNet-pretrained MobileNetV3-Small.
# torchvision >= 0.13 replaced ``pretrained=True`` with the ``weights=`` enum API
# (the boolean flag is deprecated). Fall back to the flag on older releases.
if hasattr(models, "MobileNet_V3_Small_Weights"):
    model = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.DEFAULT)
else:
    model = models.mobilenet_v3_small(pretrained=True)

# Adjust classifier for number of classes - THIS IS IMPORTANT!
# Read in_features from the layer being replaced instead of hard-coding 1024.
num_classes = len(le.classes_)
model.classifier[3] = nn.Linear(
    in_features=model.classifier[3].in_features, out_features=num_classes
)

# Move model to device
model = model.to(device)

print(f"Model output classes: {num_classes}")
print(f"Model classes: {le.classes_}")

# Loss and optimizer. The training loop steps the scheduler once per epoch,
# so the learning rate is multiplied by 0.7 after every epoch.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.7)

def train_model(model, train_loader, test_loader, criterion, optimizer, scheduler,
                num_epochs=10, device=None):
    """Train ``model`` and evaluate it on ``test_loader`` after every epoch.

    Whenever the epoch's test accuracy beats the best so far, the model's
    ``state_dict`` is saved to ``best_emotion_model_<YYYYmmdd_HHMM>.pth`` in the
    working directory. ``scheduler.step()`` is called once per epoch.

    Args:
        model: the network to train (updated in place).
        train_loader: DataLoader yielding ``(inputs, labels)`` training batches.
        test_loader: DataLoader yielding ``(inputs, labels)`` evaluation batches.
        criterion: loss function taking ``(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: learning-rate scheduler, stepped per epoch.
        num_epochs: number of training epochs.
        device: device to move batches to. Defaults to the device of the model's
            first parameter (CPU if it has none).

    Returns:
        (train_losses, test_accuracies, best_acc): per-epoch mean training loss,
        per-epoch test accuracy in percent, and the best test accuracy seen.
    """
    if device is None:
        # Batches must live where the model's weights live.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    best_acc = 0.0
    train_losses = []
    test_accuracies = []

    # Training start time
    start_time = datetime.now()
    print(f"Training start time: {start_time.strftime('%Y-%m-%d %H:%M:%S')}")

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        running_loss = 0.0
        correct_train = 0
        total_train = 0

        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        print("-" * 40)

        for inputs, labels in tqdm(train_loader, desc="Training"):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            predicted = outputs.detach().argmax(dim=1)
            total_train += labels.size(0)
            correct_train += (predicted == labels).sum().item()

        # Guard the averages so an empty loader cannot raise ZeroDivisionError.
        epoch_loss = running_loss / len(train_loader) if len(train_loader) else 0.0
        train_acc = 100 * correct_train / total_train if total_train else 0.0
        train_losses.append(epoch_loss)

        # Validation phase
        model.eval()
        correct_test = 0
        total_test = 0
        test_loss = 0.0

        with torch.no_grad():
            for inputs, labels in tqdm(test_loader, desc="Testing"):
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                test_loss += loss.item()

                predicted = outputs.argmax(dim=1)
                total_test += labels.size(0)
                correct_test += (predicted == labels).sum().item()

        test_acc = 100 * correct_test / total_test if total_test else 0.0
        mean_test_loss = test_loss / len(test_loader) if len(test_loader) else 0.0
        test_accuracies.append(test_acc)

        print(f"Train Loss: {epoch_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Test Loss: {mean_test_loss:.4f}, Test Acc: {test_acc:.2f}%")

        # Save best model
        if test_acc > best_acc:
            best_acc = test_acc
            # Timestamped model save
            timestamp = datetime.now().strftime("%Y%m%d_%H%M")
            model_path = f'best_emotion_model_{timestamp}.pth'
            torch.save(model.state_dict(), model_path)
            print(f"New best accuracy: {best_acc:.2f}% - Model saved: {model_path}")

        scheduler.step()
        print(f"Learning rate: {scheduler.get_last_lr()[0]:.6f}")

    end_time = datetime.now()
    training_duration = end_time - start_time
    print("\nTraining completed!")
    print(f"Training duration: {training_duration}")
    print(f"Best test accuracy: {best_acc:.2f}%")

    return train_losses, test_accuracies, best_acc

# Model training
print("Starting model training...")
train_losses, test_accuracies, best_accuracy = train_model(
    model, train_loader, test_loader, criterion, optimizer, scheduler, num_epochs=10
)

# Save final model (weights after the last epoch, not necessarily the best)
final_timestamp = datetime.now().strftime("%Y%m%d_%H%M")
final_model_path = f'final_emotion_model_{final_timestamp}.pth'
torch.save(model.state_dict(), final_model_path)
print(f"Final model saved: {final_model_path}")

# Save metadata.
# The weights are copied to CPU before pickling: a plainly-pickled CUDA tensor
# fails to unpickle on a machine without CUDA.
model_info = {
    'model_state_dict': {name: tensor.cpu() for name, tensor in model.state_dict().items()},
    'num_classes': num_classes,
    'class_names': le.classes_.tolist(),
    'best_accuracy': best_accuracy,
    'training_date': datetime.now().isoformat(),
    'model_architecture': 'MobileNetV3-Small',
    'input_size': (224, 224),
    'final_train_loss': train_losses[-1] if train_losses else None,
    'final_test_accuracy': test_accuracies[-1] if test_accuracies else None
}

metadata_path = f'model_info_{final_timestamp}.pkl'
with open(metadata_path, 'wb') as f:
    pickle.dump(model_info, f)
print(f"Model metadata saved: {metadata_path}")

# Final evaluation
print("\n" + "="*50)
print("FINAL EVALUATION")
print("="*50)

model.eval()
correct = 0
total = 0
# Per-class tallies are accumulated as CPU tensors and turned into plain
# Python lists of ints after the loop.
class_correct_t = torch.zeros(num_classes, dtype=torch.long)
class_total_t = torch.zeros(num_classes, dtype=torch.long)

with torch.no_grad():
    for inputs, labels in tqdm(test_loader, desc="Final Evaluation"):
        inputs, labels = inputs.to(device), labels.to(device)
        predicted = model(inputs).argmax(dim=1)
        hits = predicted == labels

        total += labels.size(0)
        correct += hits.sum().item()

        # Per-class accuracy, vectorised: replaces a per-sample Python loop that
        # issued two .item() device syncs for every test image.
        class_total_t += torch.bincount(labels, minlength=num_classes).cpu()
        class_correct_t += torch.bincount(labels[hits], minlength=num_classes).cpu()

class_correct = class_correct_t.tolist()
class_total = class_total_t.tolist()

final_accuracy = 100 * correct / total if total else 0.0
print(f'Final Test Accuracy: {final_accuracy:.2f}%')
print(f'Total test samples: {total}')

# Class-wise accuracy
print("\nClass-wise Accuracies:")
print("-" * 50)
for i, class_name in enumerate(le.classes_):
    if class_total[i] > 0:
        class_acc = 100 * class_correct[i] / class_total[i]
        print(f'{class_name:>12}: {class_acc:>6.2f}% ({class_correct[i]:>4}/{class_total[i]:>4})')
    else:
        print(f'{class_name:>12}: No samples in test set')

# Detailed classification report and confusion matrix
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# Collect all predictions
all_predictions = []
all_labels = []

model.eval()
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        predicted = model(inputs).argmax(dim=1)

        all_predictions.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Pin the label set to every encoded class so that target_names and the
# confusion-matrix axes always line up with le.classes_, even if some class is
# absent from both the true labels and the predictions.
class_indices = list(range(num_classes))

# Classification report
print("\nDetailed Classification Report:")
print("-" * 60)
report = classification_report(all_labels, all_predictions,
                               labels=class_indices,
                               target_names=list(le.classes_), digits=3,
                               zero_division=0)
print(report)

# Confusion Matrix
plt.figure(figsize=(10, 8))
cm = confusion_matrix(all_labels, all_predictions, labels=class_indices)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=le.classes_, yticklabels=le.classes_)
plt.title(f'Confusion Matrix - Accuracy: {final_accuracy:.2f}%')
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.tight_layout()

# Save confusion matrix (before show(), which may clear the figure)
cm_path = f'confusion_matrix_{final_timestamp}.png'
plt.savefig(cm_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"Confusion matrix saved: {cm_path}")

# Final summary
# train_model names each best checkpoint with the timestamp of the epoch in which
# it was saved, which generally differs from final_timestamp. Report the most
# recently written best checkpoint in the working directory instead of guessing.
# NOTE: if no checkpoint was written in this run, this may be a file from an
# earlier run.
best_checkpoints = sorted(Path('.').glob('best_emotion_model_*.pth'),
                          key=lambda p: p.stat().st_mtime)
best_model_desc = str(best_checkpoints[-1]) if best_checkpoints else "not found"

print("\n" + "="*60)
print("TRAINING SUMMARY")
print("="*60)
print(f"Best model saved as: {best_model_desc}")
print(f"Final model saved as: {final_model_path}")
print(f"Metadata saved as: {metadata_path}")
print(f"Confusion matrix: {cm_path}")
print(f"Best accuracy achieved: {best_accuracy:.2f}%")
print(f"Final test accuracy: {final_accuracy:.2f}%")
print(f"Number of classes: {num_classes}")
print(f"Classes: {', '.join(le.classes_)}")
print("="*60)

# Prediction function
import torch.nn.functional as F

def predict_emotion_from_path(image_path, model, le, transform, device, show_image=True,
                              show_probabilities=True, enhance=True):
    """
    Performs emotion prediction using a single image path

    Args:
        image_path (str): Path to image file
        model: Trained PyTorch model
        le: LabelEncoder object; ``le.classes_[i]`` names output logit ``i``
        transform: Preprocessing transforms applied to the PIL image. If falsy,
            a Resize(224) / Grayscale(3) / ToTensor / ImageNet-Normalize pipeline
            is used.
        device: PyTorch device (cuda/cpu)
        show_image (bool): Show/hide the input image
        show_probabilities (bool): Print probabilities for all classes
        enhance (bool): Before ``transform``, apply the enhancement that this
            script applies to its training images: grayscale, resize to
            224x224, 3x3 sharpening kernel, CLAHE. Keep enabled for models
            trained by this script so inputs match the training distribution.

    Returns:
        dict | None: keys ``predicted_emotion``, ``confidence``,
        ``all_probabilities``, ``sorted_probabilities`` and ``image_path``; or
        ``None`` if the file is missing or any step raises (the error is printed).
    """
    try:
        # Load image and check
        if not os.path.exists(image_path):
            print(f"Error: {image_path} file not found!")
            return None

        # Read the pixels into memory and release the file handle right away;
        # a bare Image.open keeps the file open until garbage collection.
        with Image.open(image_path) as opened:
            image = opened.copy()

        # Show the image as loaded (before enhancement), if requested
        if show_image:
            plt.figure(figsize=(6, 4))
            plt.imshow(image, cmap='gray' if image.mode == 'L' else None)
            plt.title(f"Input Image: {os.path.basename(image_path)}")
            plt.axis('off')
            plt.show()

        if enhance:
            # Mirror the training preprocessing (resize -> sharpen -> CLAHE).
            gray = cv2.resize(np.array(image.convert('L')), (224, 224))
            sharpen_kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])
            gray = cv2.filter2D(gray, -1, sharpen_kernel)
            gray = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(gray)
            image = Image.fromarray(gray)

        # Set model to evaluation mode
        model.eval()

        # Transform image and add the batch dimension
        if transform:
            input_tensor = transform(image).unsqueeze(0)
        else:
            # Fallback transform
            input_tensor = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.Grayscale(num_output_channels=3),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])(image).unsqueeze(0)

        input_tensor = input_tensor.to(device)

        # Make prediction
        with torch.no_grad():
            probabilities = F.softmax(model(input_tensor), dim=1)

        # Highest probability and class
        max_prob, predicted_class = torch.max(probabilities, 1)
        predicted_emotion = le.classes_[predicted_class.item()]
        confidence = max_prob.item()

        # Probabilities for all classes, as plain Python floats
        all_probabilities = probabilities[0].cpu().numpy()
        emotion_probabilities = {
            emotion: float(all_probabilities[i]) for i, emotion in enumerate(le.classes_)
        }

        # Sort results (highest to lowest)
        sorted_emotions = sorted(emotion_probabilities.items(), key=lambda x: x[1], reverse=True)

        # Print results
        print(f"\n{'='*60}")
        print("EMOTION PREDICTION RESULTS")
        print(f"{'='*60}")
        print(f"Image: {os.path.basename(image_path)}")
        print(f"Predicted Emotion: {predicted_emotion.upper()}")
        print(f"Confidence: {confidence*100:.2f}%")
        print(f"Device: {device}")

        if show_probabilities:
            print("\nAll Emotion Probabilities:")
            print(f"{'-'*40}")
            for emotion, prob in sorted_emotions:
                bar = "█" * int(prob * 20)  # Simple progress bar (U+2588 FULL BLOCK)
                print(f"{emotion:>12}: {prob*100:>6.2f}% {bar}")

        return {
            'predicted_emotion': predicted_emotion,
            'confidence': confidence,
            'all_probabilities': emotion_probabilities,
            'sorted_probabilities': sorted_emotions,
            'image_path': image_path
        }

    except Exception as e:
        print(f"Prediction error: {e}")
        return None

# Transform for prediction: bring the image to the 224x224 training resolution,
# replicate the grey channel to 3 channels for the ImageNet-pretrained backbone,
# and normalise with ImageNet mean/std (same statistics as the train/test transforms).
transform_predict = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Example usage:
# image_path = r"path/to/your/image.jpg"
# result = predict_emotion_from_path(image_path, model, le, transform_predict, device)