In [None]:
# Enhanced Emotion Detection Preprocessing for Music Recommendation System
# Optimized ResNet-18 Architecture for 48x48 Grayscale Facial Emotion Recognition

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image, ImageEnhance
import cv2
from sklearn.model_selection import train_test_split, StratifiedShuffleSplit
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import classification_report, confusion_matrix
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.utils import make_grid
from collections import Counter, defaultdict
import json
import time
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Reproducibility: seed PyTorch's CPU RNG and NumPy's global RNG, then (only when a
# GPU is present) the RNG of every CUDA device as well.
torch.manual_seed(42)
np.random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

In [None]:
# ENHANCED CONFIGURATION FOR MUSIC RECOMMENDATION SYSTEM
# Single settings dict read by the later cells: data/model paths, split fractions,
# training hyper-parameters, the emotion classes, the emotion->genre mapping and
# the augmentation parameters used by AdvancedTransforms.

CONFIG = {
    # Directory paths (relative to the notebook's working directory).
    # NOTE(review): 'dataset_root' duplicates 'train_dir'. The loaders defined below read
    # 'dataset_root' and carve train/val/test out of it; they do not read 'test_dir'.
    'dataset_root': '../../data/raw/fer2013/train',
    'train_dir': '../../data/raw/fer2013/train',
    'test_dir': '../../data/raw/fer2013/test',
    # Processed (filtered + CLAHE-enhanced) images, JSON/summary reports and figures go here
    'output_dir': '../../data/processed/FC211033_Sahan',
    'model_save_dir': '../../models',
    
    # Dataset parameters
    # Size passed to PIL's resize (width, height); square, so the order does not matter
    'image_size': (48, 48),
    'batch_size': 64,  # Samples per batch for every DataLoader
    # Fractions of 'dataset_root' held out for validation and test. Both are split off
    # together with stratification, then that held-out part is split again into val/test.
    'validation_split': 0.2,
    'test_split': 0.1,
    
    # Training hyper-parameters. They are recorded in the preprocessing report;
    # presumably the training loop (not shown here) reads them too.
    'num_epochs': 100,
    'learning_rate': 0.001,
    'weight_decay': 1e-4,
    'patience': 15,  # Early-stopping patience, in epochs
    # Presumably for the ReduceLROnPlateau scheduler named in the report -- confirm in training code
    'lr_scheduler_patience': 8,
    'lr_scheduler_factor': 0.5,
    
    # The seven FER2013 class folder names
    'original_emotions': ['angry', 'disgust', 'fear', 'happy', 'neutral', 'sad', 'surprise'],
    
    # Subset actually loaded (fear and disgust are excluded because they map poorly to music).
    # Each name must be a sub-directory of the data directory being read.
    'music_emotions': ['angry', 'happy', 'neutral', 'sad', 'surprise'],
    
    # Genre suggestions per emotion. Every entry of 'music_emotions' needs a key here,
    # because the plotting/report helpers look each class up directly.
    'emotion_to_music': {
        'happy': ['Pop', 'Dance', 'Upbeat', 'Electronic'],
        'sad': ['Blues', 'Ballads', 'Acoustic', 'Melancholic'],
        'angry': ['Rock', 'Metal', 'Punk', 'Aggressive'],
        'neutral': ['Classical', 'Ambient', 'Instrumental', 'Chill'],
        'surprise': ['Experimental', 'Fusion', 'Eclectic', 'Dynamic']
    },
    
    # Training-time augmentation (consumed by AdvancedTransforms.get_train_transforms):
    # max rotation in degrees, horizontal-flip probability, ColorJitter brightness/contrast
    # factors, RandomAffine max translation (fraction of width/height) and scale range
    'augmentation': {
        'rotation_degrees': 15,
        'horizontal_flip_prob': 0.5,
        'brightness_factor': 0.2,
        'contrast_factor': 0.2,
        'translate': (0.1, 0.1),
        'scale': (0.9, 1.1)
    }
}

# Make sure the processed-data folder and the model checkpoint folder exist
# before any later cell writes into them (no-op if they are already there).
os.makedirs(CONFIG['output_dir'], exist_ok=True)
os.makedirs(CONFIG['model_save_dir'], exist_ok=True)

In [None]:
# ENHANCED DATASET CLASS WITH DATA PROCESSING AND COPYING

import shutil
from tqdm import tqdm

class MusicEmotionDataset(Dataset):
    """
    Enhanced dataset class for music recommendation emotion detection.

    Scans ``data_dir/<emotion>/`` for every emotion in ``CONFIG['music_emotions']``
    (files ending in .png/.jpg/.jpeg, in sorted order so that seeded splits are
    reproducible across machines). It can optionally drop low-quality images. When
    ``copy_to_processed`` is set, it writes a grayscale, resized and (optionally)
    CLAHE-enhanced copy of every kept image to ``processed_dir/<emotion>/``.

    After construction:
        data:   list of image paths to load (processed copies when copying, else originals)
        labels: integer-encoded emotions (numpy array) aligned with ``data``
        label_encoder: fitted ``LabelEncoder`` mapping integers back to emotion names

    Args:
        data_dir: root directory containing one sub-directory per emotion
        transform: optional callable applied to each grayscale PIL image in ``__getitem__``
        apply_clahe: apply CLAHE contrast enhancement when writing processed copies
        filter_quality: skip images that fail the brightness/contrast check
        copy_to_processed: write processed copies and load from them. This is ignored
            when ``data_dir`` and ``processed_dir`` are the same directory, because
            processing in place would re-enhance already processed images on every run.
        processed_dir: destination for processed copies (default ``CONFIG['output_dir']``)
    """
    
    def __init__(self, data_dir, transform=None, apply_clahe=True, filter_quality=True, 
                 copy_to_processed=True, processed_dir=None):
        self.data_dir = data_dir
        self.transform = transform
        self.apply_clahe = apply_clahe
        self.filter_quality = filter_quality
        self.processed_dir = processed_dir or CONFIG['output_dir']

        # Guard against "processing" a directory onto itself: that would overwrite the
        # source files and stack another round of CLAHE on top each time.
        same_dir = os.path.abspath(data_dir) == os.path.abspath(self.processed_dir)
        if copy_to_processed and same_dir:
            print(f"Source and processed directories are identical ({data_dir}); "
                  "loading images in place without re-processing.")
        self.copy_to_processed = copy_to_processed and not same_dir

        self.data = []
        self.labels = []
        self.label_encoder = LabelEncoder()
        
        # Initialize CLAHE for contrast enhancement
        self.clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        
        # Load and process music-relevant emotions
        self._process_emotions()
        
        # Encode emotion names as integers (classes_ is sorted alphabetically)
        self.labels = self.label_encoder.fit_transform(self.labels)
    
    def _process_emotions(self):
        """Process (and optionally copy) quality-filtered images for every music emotion."""
        total_loaded = 0
        total_filtered = 0
        
        print("Processing music-relevant emotions...")
        for emotion in CONFIG['music_emotions']:
            emotion_dir = os.path.join(self.data_dir, emotion)
            if os.path.exists(emotion_dir):
                loaded, filtered = self._process_emotion_images(emotion_dir, emotion)
                total_loaded += loaded
                total_filtered += filtered
        
        print(f"Processed {total_loaded} quality images, filtered out {total_filtered} low-quality images")
        
    def _process_emotion_images(self, emotion_dir, emotion):
        """Register (and optionally process/copy) all images of one emotion.

        Returns:
            tuple: ``(loaded_count, filtered_count)`` for this emotion
        """
        loaded_count = 0
        filtered_count = 0
        
        processed_emotion_dir = None
        if self.copy_to_processed:
            processed_emotion_dir = os.path.join(self.processed_dir, emotion)
            os.makedirs(processed_emotion_dir, exist_ok=True)
        
        # Sorted: os.listdir order is filesystem-dependent, which would make the
        # seeded stratified splits differ between machines
        image_files = sorted(f for f in os.listdir(emotion_dir)
                             if f.lower().endswith(('.png', '.jpg', '.jpeg')))
        
        for img_file in tqdm(image_files, desc=f"Processing {emotion}", leave=False):
            img_path = os.path.join(emotion_dir, img_file)
            
            if self.filter_quality and not self._is_quality_image(img_path):
                filtered_count += 1
                continue
            
            if processed_emotion_dir is not None:
                processed_path = os.path.join(processed_emotion_dir, img_file)
                self._process_and_save_image(img_path, processed_path)
                self.data.append(processed_path)
            else:
                self.data.append(img_path)
            
            self.labels.append(emotion)
            loaded_count += 1
        
        return loaded_count, filtered_count
    
    def _process_and_save_image(self, src_path, dst_path):
        """Convert to grayscale, resize to ``CONFIG['image_size']``, optionally apply CLAHE,
        and save as PNG at ``dst_path``.

        Note: the data is PNG-encoded even if ``dst_path`` keeps a .jpg name; PIL detects
        the format from content when reading. If processing fails, the original file is
        copied instead, and a warning is printed so the failure is not silent.
        """
        try:
            with Image.open(src_path) as src:
                image = src.convert('L')  # also forces the pixel data to load before close
            
            if image.size != CONFIG['image_size']:
                image = image.resize(CONFIG['image_size'], Image.LANCZOS)
            
            if self.apply_clahe:
                image = Image.fromarray(self.clahe.apply(np.array(image)))
            
            image.save(dst_path, 'PNG', optimize=True)
            
        except Exception as e:
            print(f"Warning: could not process {src_path} ({e}); copying original instead")
            shutil.copy2(src_path, dst_path)
    
    def _is_quality_image(self, img_path):
        """Return True if the image is readable and neither too dark/bright nor too flat.

        Criteria (on 0-1 scaled grayscale pixels): mean brightness in [0.1, 0.95] and
        standard deviation (contrast) of at least 0.05. Unreadable files return False.
        """
        try:
            with Image.open(img_path) as img:
                img_array = np.array(img.convert('L'))
        except Exception:  # unreadable/corrupt file -> treat as low quality
            return False
        
        brightness = img_array.mean() / 255.0
        contrast = img_array.std() / 255.0
        
        if brightness < 0.1 or brightness > 0.95:  # Too dark or too bright
            return False
        if contrast < 0.05:  # Too low contrast
            return False
        return True
    
    def get_class_weights(self):
        """Return sklearn 'balanced' class weights (one per encoded class) as a FloatTensor."""
        class_weights = compute_class_weight(
            'balanced',
            classes=np.unique(self.labels),
            y=self.labels
        )
        return torch.FloatTensor(class_weights)
    
    def get_weighted_sampler(self):
        """Return a WeightedRandomSampler whose per-sample weight is 1 / (count of its class)."""
        class_counts = Counter(self.labels)
        weights = [1.0 / class_counts[label] for label in self.labels]
        return WeightedRandomSampler(weights, len(weights))
    
    def get_statistics(self):
        """Return sample count, class count, per-class counts and class names (as plain str)."""
        class_counts = Counter(self.labels)
        # Plain str (not np.str_) keeps printed output readable under NumPy 2
        class_names = [str(name) for name in self.label_encoder.classes_]
        return {
            'total_samples': len(self.labels),
            'num_classes': len(class_names),
            'class_distribution': {
                name: class_counts[i] for i, name in enumerate(class_names)
            },
            'emotions': class_names
        }
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        """Return ``(image, label)``: grayscale PIL image (transformed if set) and its int label."""
        img_path = self.data[idx]
        label = self.labels[idx]
        
        with Image.open(img_path) as img:
            image = img.convert('L')  # loads pixels so the file handle can be closed
        
        if self.transform:
            image = self.transform(image)
        
        return image, label


In [None]:
# ENHANCED TRANSFORMS WITH ADVANCED AUGMENTATION STRATEGIES

class AdvancedTransforms:
    """Advanced data augmentation strategies optimized for facial emotion recognition.

    All pipelines expect a single-channel (grayscale) PIL image and normalize with one
    ``mean``/``std`` pair. When either statistic is omitted, ``DEFAULT_MEAN`` and
    ``DEFAULT_STD`` are used.
    """

    # Fallback normalization when dataset statistics are not supplied
    DEFAULT_MEAN = 0.5
    DEFAULT_STD = 0.25

    @staticmethod
    def _resolve_stats(mean, std):
        """Return ``(mean, std)``, substituting the defaults if either one is None."""
        if mean is None or std is None:
            return AdvancedTransforms.DEFAULT_MEAN, AdvancedTransforms.DEFAULT_STD
        return mean, std
    
    @staticmethod
    def calculate_dataset_statistics(dataset, max_samples=1000):
        """Estimate pixel mean and std (on a 0-1 scale) from a random sample of images.

        Args:
            dataset: object with ``__len__`` and a ``data`` list of image paths
            max_samples: maximum number of images sampled (without replacement)
                through NumPy's global RNG

        Returns:
            tuple: ``(mean, std)`` as NumPy float32 scalars

        Raises:
            ValueError: if ``dataset`` is empty (the statistics would be NaN and poison
                every later ``Normalize``)
        """
        n_items = len(dataset)
        if n_items == 0:
            raise ValueError("Cannot compute normalization statistics for an empty dataset")

        sample_size = min(max_samples, n_items)  # Sample for efficiency
        indices = np.random.choice(n_items, sample_size, replace=False)
        
        # One flat float32 array per image, concatenated once, instead of growing a
        # Python list with one boxed float per pixel
        pixel_arrays = []
        for idx in indices:
            with Image.open(dataset.data[idx]) as img:
                image = img.convert('L') if img.mode != 'L' else img
                if image.size != CONFIG['image_size']:
                    image = image.resize(CONFIG['image_size'], Image.LANCZOS)
                pixel_arrays.append(np.asarray(image, dtype=np.float32).ravel() / 255.0)
        
        pixels = np.concatenate(pixel_arrays)
        return np.mean(pixels), np.std(pixels)
    
    @staticmethod
    def get_train_transforms(mean=None, std=None):
        """Training pipeline: geometric + photometric augmentation, normalization, random erasing."""
        mean, std = AdvancedTransforms._resolve_stats(mean, std)
        
        return transforms.Compose([
            # Geometric augmentations (fill=128: mid-gray for uncovered pixels)
            transforms.RandomRotation(
                degrees=CONFIG['augmentation']['rotation_degrees'], 
                fill=128
            ),
            transforms.RandomHorizontalFlip(p=CONFIG['augmentation']['horizontal_flip_prob']),
            transforms.RandomAffine(
                degrees=0,
                translate=CONFIG['augmentation']['translate'],
                scale=CONFIG['augmentation']['scale'],
                fill=128
            ),
            
            # Photometric augmentations
            transforms.ColorJitter(
                brightness=CONFIG['augmentation']['brightness_factor'],
                contrast=CONFIG['augmentation']['contrast_factor']
            ),
            
            # Tensor conversion and normalization; RandomErasing must come after
            # ToTensor because it operates on tensors
            transforms.ToTensor(),
            transforms.Normalize(mean=[mean], std=[std]),
            transforms.RandomErasing(p=0.1, scale=(0.02, 0.1), ratio=(0.3, 3.3)),
        ])
    
    @staticmethod
    def get_val_transforms(mean=None, std=None):
        """Validation/test transforms - only tensor conversion and normalization."""
        mean, std = AdvancedTransforms._resolve_stats(mean, std)
        
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[mean], std=[std])
        ])
    
    @staticmethod
    def get_test_time_augmentation_transforms(mean=None, std=None):
        """Return three TTA pipelines: identity, random rotation of up to +/-5 degrees, horizontal flip.

        Note: the rotation view is random, so repeated inference is not bit-identical.
        """
        mean, std = AdvancedTransforms._resolve_stats(mean, std)
        
        return [
            # Original
            transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[mean], std=[std])
            ]),
            # Slight rotation
            transforms.Compose([
                transforms.RandomRotation(degrees=5, fill=128),
                transforms.ToTensor(),
                transforms.Normalize(mean=[mean], std=[std])
            ]),
            # Horizontal flip
            transforms.Compose([
                transforms.RandomHorizontalFlip(p=1.0),
                transforms.ToTensor(),
                transforms.Normalize(mean=[mean], std=[std])
            ])
        ]

class TransformedDataset(Dataset):
    """Subset view of a path/label dataset that applies its own transform.

    ``indices`` picks which entries of the base dataset belong to this view, so one
    base dataset can back the train/val/test splits with different transforms.

    Args:
        dataset: base dataset exposing parallel ``data`` (image paths) and ``labels``
        indices: positions in ``dataset`` included in this view
        transform: optional callable applied to the grayscale PIL image
    """

    def __init__(self, dataset, indices, transform=None):
        self.dataset, self.indices, self.transform = dataset, indices, transform

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        base_idx = self.indices[idx]
        path, target = self.dataset.data[base_idx], self.dataset.labels[base_idx]

        image = Image.open(path)
        # Force single-channel input; images already in 'L' mode pass through untouched
        image = image if image.mode == 'L' else image.convert('L')

        return (self.transform(image) if self.transform else image), target

In [None]:
def calculate_dataset_statistics(dataset):
    """Return the ``(mean, std)`` pixel statistics of ``dataset``.

    Module-level alias that delegates to
    ``AdvancedTransforms.calculate_dataset_statistics``.
    """
    mean, std = AdvancedTransforms.calculate_dataset_statistics(dataset)
    return mean, std

In [None]:
def get_transforms(mean=None, std=None):
    """Return the ``(train_transforms, val_transforms)`` pair for the given normalization.

    Both pipelines come from ``AdvancedTransforms``. Passing None for either statistic
    makes them fall back to their built-in defaults.
    """
    return (
        AdvancedTransforms.get_train_transforms(mean, std),
        AdvancedTransforms.get_val_transforms(mean, std),
    )

In [None]:
# ENHANCED RESNET ARCHITECTURE FOR MUSIC EMOTION RECOGNITION

class AttentionModule(nn.Module):
    """Spatial attention gate that rescales each location of a feature map by a weight in (0, 1).

    A 1x1 convolution collapses ``in_channels`` into a single-channel map. That map goes
    through a sigmoid and is multiplied back onto the input, broadcast over channels.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, 1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        gate = self.sigmoid(self.conv(x))
        return x * gate

class MusicEmotionResNet(nn.Module):
    """
    ResNet-18 architecture optimized for music emotion recognition.

    - Stem: ResNet-18's conv1 is replaced by a 1-input-channel conv so the network takes
      grayscale images. A spatial ``AttentionModule`` is applied to the 64-channel stem
      output before max-pooling.
    - Head: the stock ``fc`` is replaced by a module that pools, flattens and runs a
      3-layer MLP (512 -> 256 -> num_classes). ``forward`` therefore feeds ``layer4``'s
      4-D feature map straight into ``self.resnet.fc``. Calling ``self.resnet(x)``
      directly is not supported, because torchvision's own forward flattens before ``fc``.

    Args:
        num_classes: number of output logits
        pretrained: load ImageNet weights. The pretrained RGB stem filters are summed
            over the colour axis to initialise the grayscale stem, rather than being
            discarded.
        dropout_rate: dropout before the first FC layer (x0.7 and x0.5 for the later ones)
    """
    
    def __init__(self, num_classes=5, pretrained=True, dropout_rate=0.4):
        super(MusicEmotionResNet, self).__init__()
        
        self.resnet = self._build_backbone(pretrained)
        
        # Keep the RGB stem filters (64, 3, 7, 7) before swapping in a 1-channel stem
        rgb_conv1_weight = self.resnet.conv1.weight.detach().clone()
        self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        
        # Attention applied right after the stem (see forward)
        self.attention = AttentionModule(64)
        
        # Enhanced classifier head (includes its own pooling; see class docstring)
        num_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Dropout(dropout_rate),
            nn.Linear(num_features, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate * 0.7),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate * 0.5),
            nn.Linear(256, num_classes)
        )
        
        if pretrained:
            # Grayscale replicated into R, G and B would give sum_c(w_c * x), so summing
            # the RGB filters preserves the pretrained stem's response on gray input
            with torch.no_grad():
                self.resnet.conv1.weight.copy_(rgb_conv1_weight.sum(dim=1, keepdim=True))
        else:
            nn.init.kaiming_normal_(self.resnet.conv1.weight, mode='fan_out', nonlinearity='relu')
        
        # Initialize classifier layers
        self._initialize_classifier()

    @staticmethod
    def _build_backbone(pretrained):
        """Create torchvision ResNet-18, using the ``weights=`` API when available."""
        try:
            weights = models.ResNet18_Weights.DEFAULT if pretrained else None
        except AttributeError:
            # torchvision < 0.13 has no weight enums; fall back to the legacy flag
            return models.resnet18(pretrained=pretrained)
        return models.resnet18(weights=weights)
    
    def _initialize_classifier(self):
        """Kaiming-init Linear weights, zero biases; BatchNorm weight=1, bias=0."""
        for module in self.resnet.fc.modules():
            if isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm1d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
    
    def forward(self, x):
        """Map a batch of 1-channel images to class logits ``(batch, num_classes)``."""
        # Stem
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        
        # Apply attention mechanism early in the network
        x = self.attention(x)
        
        x = self.resnet.maxpool(x)
        x = self.resnet.layer1(x)
        x = self.resnet.layer2(x)
        x = self.resnet.layer3(x)
        x = self.resnet.layer4(x)
        
        # Head pools + flattens itself, so the 4-D map goes in directly
        x = self.resnet.fc(x)
        
        return x
    
    def get_attention_weights(self, x):
        """Return the stem's spatial attention map ``(batch, 1, H, W)`` for visualization."""
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        
        attention_weights = self.attention.conv(x)
        attention_weights = self.attention.sigmoid(attention_weights)
        
        return attention_weights

class FocalLoss(nn.Module):
    """Focal Loss for handling class imbalance.

    ``loss_i = alpha_i * (1 - p_i) ** gamma * CE_i``, where ``p_i`` is the predicted
    probability of the true class.

    Args:
        alpha: scalar weight for every sample, or a per-class sequence/tensor of length
            ``num_classes`` (indexed by each sample's target, e.g. balanced class weights)
        gamma: focusing parameter; 0 reduces to (alpha-weighted) cross-entropy
        reduction: 'mean', 'sum', or anything else for the unreduced per-sample loss
    """
    
    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__()
        # Non-scalar alpha is stored as a float tensor so it can be indexed by class
        self.alpha = alpha if isinstance(alpha, (int, float)) else torch.as_tensor(alpha, dtype=torch.float32)
        self.gamma = gamma
        self.reduction = reduction
    
    def forward(self, inputs, targets):
        """Compute the focal loss for logits ``(N, C)`` and integer targets ``(N,)``."""
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)

        alpha = self.alpha
        if torch.is_tensor(alpha):
            alpha = alpha.to(device=ce_loss.device, dtype=ce_loss.dtype)
            if alpha.dim() > 0:
                # Per-class weights: pick each sample's weight by its target. Plain
                # broadcasting would wrongly pair class weights with batch positions.
                alpha = alpha[targets]
        focal_loss = alpha * (1 - pt) ** self.gamma * ce_loss
        
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

In [None]:
# ENHANCED VISUALIZATION AND ANALYSIS FUNCTIONS

def visualize_sample_images(dataset, num_samples=20):
    """Enhanced visualization of sample images with emotion-music mapping info.

    Draws a grid with one row per emotion class and ``num_samples // n_classes``
    randomly chosen images per row (at least one). It prints each emotion's first three
    genres from ``CONFIG['emotion_to_music']`` and saves the figure as
    ``music_emotion_samples.png`` in ``CONFIG['output_dir']``.

    Args:
        dataset: ``MusicEmotionDataset`` (uses ``data``, ``labels`` and ``label_encoder``)
        num_samples: approximate total number of images to show
    """
    emotions = dataset.label_encoder.classes_
    n_emotions = len(emotions)
    if n_emotions == 0:
        print("No images to visualize.")
        return

    # At least one column, so plt.subplots never gets ncols=0 when num_samples < n_emotions
    samples_per_emotion = max(1, num_samples // n_emotions)
    
    # squeeze=False always yields a 2-D array of Axes, so axes[i, j] is valid for every
    # grid shape (1x1, single row, single column)
    fig, axes = plt.subplots(n_emotions, samples_per_emotion,
                             figsize=(samples_per_emotion * 2, n_emotions * 2),
                             squeeze=False)
    
    # Color mapping for emotions
    colors = {'angry': 'red', 'happy': 'gold', 'neutral': 'gray', 
              'sad': 'blue', 'surprise': 'orange'}
    labels = np.asarray(dataset.labels)
    
    print("Sample Images with Music Genre Mapping:")
    print("-" * 50)
    
    for i, emotion in enumerate(emotions):
        # classes_[i] == emotion, so these are the indices of this emotion's images
        emotion_indices = np.flatnonzero(labels == i)
        sample_indices = np.random.choice(emotion_indices,
                                          min(samples_per_emotion, len(emotion_indices)),
                                          replace=False)
        
        music_genres = ', '.join(CONFIG['emotion_to_music'][emotion][:3])
        print(f"{emotion:>8}: {music_genres}")
        
        for j in range(samples_per_emotion):
            ax = axes[i, j]
            if j < len(sample_indices):
                try:
                    with Image.open(dataset.data[sample_indices[j]]) as img:
                        image = img.convert('L')
                    ax.imshow(image, cmap='gray')
                except Exception:
                    ax.text(0.5, 0.5, 'Error', ha='center', va='center',
                            transform=ax.transAxes)
            # else: fewer images than columns for this emotion -> leave the cell blank

            # Hide ticks and spines instead of ax.axis('off'), which would also hide
            # the row's y-label and the emotion name would never be drawn
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)
            
            if j == 0:
                ax.set_ylabel(emotion.upper(), fontsize=12, fontweight='bold',
                              color=colors.get(emotion, 'black'))
                # Music genre info as title of the row's first cell
                genre_text = ' | '.join(CONFIG['emotion_to_music'][emotion][:2])
                ax.set_title(genre_text, fontsize=8, style='italic')
    
    plt.suptitle('Music-Relevant Emotion Samples with Genre Mapping', 
                 fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(os.path.join(CONFIG['output_dir'], 'music_emotion_samples.png'), 
                dpi=300, bbox_inches='tight')
    plt.show()

def plot_comprehensive_analysis(dataset):
    """Create a 2x3 dashboard describing the dataset and the training setup.

    Panels:
      1. per-class image counts (with percentages)
      2. class balance normalized to the largest class
      3. genre share (each genre is credited once per image of every emotion mapped to it)
      4. emotion -> genre mapping text
      5. training recommendations
      6. model architecture summary

    The figure is saved as ``comprehensive_analysis.png`` in ``CONFIG['output_dir']``
    and shown.

    Args:
        dataset: ``MusicEmotionDataset`` (uses ``labels`` and the fitted ``label_encoder``)

    Returns:
        float: ratio of the largest to the smallest class count
    """
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    
    # 1. Class distribution
    class_counts = Counter(dataset.labels)
    emotions = list(dataset.label_encoder.classes_)
    counts = [class_counts[i] for i in range(len(emotions))]
    colors = ['#ff4444', '#44ff44', '#888888', '#4444ff', '#ffaa44']
    
    bars = axes[0, 0].bar(emotions, counts, color=colors, alpha=0.8)
    axes[0, 0].set_title('Music-Relevant Emotion Distribution', fontweight='bold')
    axes[0, 0].set_ylabel('Number of Images')
    axes[0, 0].tick_params(axis='x', rotation=45)
    
    # Value labels and percentages, offset in points (not data units) so they sit just
    # above each bar whatever the y-scale is
    total = sum(counts)
    for bar, count in zip(bars, counts):
        percentage = (count / total) * 100
        axes[0, 0].annotate(f'{count}\n({percentage:.1f}%)',
                            xy=(bar.get_x() + bar.get_width() / 2, bar.get_height()),
                            xytext=(0, 3), textcoords='offset points',
                            ha='center', va='bottom', fontweight='bold')
    
    # 2. Class imbalance visualization
    max_count = max(counts)
    min_count = min(counts)
    imbalance_ratio = max_count / min_count
    
    normalized_counts = [c / max_count for c in counts]
    axes[0, 1].bar(emotions, normalized_counts, color=colors, alpha=0.8)
    axes[0, 1].set_title(f'Class Balance (Ratio: {imbalance_ratio:.2f})', fontweight='bold')
    axes[0, 1].set_ylabel('Normalized Count')
    axes[0, 1].tick_params(axis='x', rotation=45)
    axes[0, 1].axhline(y=0.5, color='red', linestyle='--', alpha=0.7, label='50% line')
    axes[0, 1].legend()
    
    # 3. Music genre mapping pie chart (class index i corresponds to emotions[i])
    genre_counts = defaultdict(int)
    for i, emotion in enumerate(emotions):
        for genre in CONFIG['emotion_to_music'][emotion]:
            genre_counts[genre] += class_counts[i]
    
    top_genres = sorted(genre_counts.items(), key=lambda x: x[1], reverse=True)[:8]
    genre_names = [g[0] for g in top_genres]
    genre_values = [g[1] for g in top_genres]
    
    axes[0, 2].pie(genre_values, labels=genre_names, autopct='%1.1f%%', startangle=90)
    axes[0, 2].set_title('Music Genre Distribution', fontweight='bold')
    
    # 4. Emotion-Music mapping text
    axes[1, 0].axis('off')
    mapping_text = "Emotion -> Music Genre Mapping:\n\n"
    for emotion in emotions:
        genres = ', '.join(CONFIG['emotion_to_music'][emotion])
        mapping_text += f"{emotion.upper():>8}: {genres}\n"
    
    axes[1, 0].text(0.1, 0.9, mapping_text, transform=axes[1, 0].transAxes,
                    fontsize=11, verticalalignment='top', fontfamily='monospace',
                    bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8))
    
    # 5. Training recommendations
    axes[1, 1].axis('off')
    recommendations = f"""
Training Recommendations:

Class Imbalance:
- Ratio: {imbalance_ratio:.2f}
- Use weighted loss function
- Apply focal loss for hard examples
- Implement balanced sampling

Data Augmentation:
- Rotation: ±{CONFIG['augmentation']['rotation_degrees']}°
- Horizontal flip: {CONFIG['augmentation']['horizontal_flip_prob']*100}%
- Brightness/contrast: ±{CONFIG['augmentation']['brightness_factor']*100}%
- Random erasing for robustness

    """
    
    axes[1, 1].text(0.05, 0.95, recommendations, transform=axes[1, 1].transAxes,
                    fontsize=9, verticalalignment='top',
                    bbox=dict(boxstyle='round', facecolor='lightgreen', alpha=0.8))
    
    # 6. Model architecture summary (mirrors MusicEmotionResNet). The original call
    # omitted the text argument entirely and raised a TypeError.
    axes[1, 2].axis('off')
    height, width = CONFIG['image_size']
    architecture = (
        "Model Architecture:\n\n"
        "Enhanced ResNet-18\n"
        f"Input: 1 x {height} x {width} (grayscale)\n"
        "Conv1 7x7/2 + BN + ReLU\n"
        "Spatial attention (1x1 conv + sigmoid)\n"
        "MaxPool -> ResNet layers 1-4\n"
        "AdaptiveAvgPool + Flatten\n"
        "Dropout -> FC 512 -> BN -> ReLU\n"
        "Dropout -> FC 256 -> BN -> ReLU\n"
        f"Dropout -> FC {len(emotions)} (logits)\n"
    )
    
    axes[1, 2].text(0.05, 0.95, architecture, transform=axes[1, 2].transAxes,
                    fontsize=9, verticalalignment='top', fontfamily='monospace',
                    bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8))
    
    plt.tight_layout()
    plt.savefig(os.path.join(CONFIG['output_dir'], 'comprehensive_analysis.png'), 
                dpi=300, bbox_inches='tight')
    plt.show()
    
    return imbalance_ratio


In [None]:
def create_enhanced_data_loaders():
    """Create optimized data loaders with advanced sampling strategies.

    Loads ``CONFIG['dataset_root']`` through ``MusicEmotionDataset`` (quality filter and
    CLAHE). With that class's default ``copy_to_processed=True``, this also writes
    processed copies to ``CONFIG['output_dir']``. It then makes a stratified
    train / validation / test split sized by ``CONFIG['validation_split']`` and
    ``CONFIG['test_split']``.

    The training loader draws batches through a class-balancing ``WeightedRandomSampler``
    (inverse class frequency). The validation and test loaders iterate in order.

    NOTE(review): mean/std are estimated from the whole dataset, including the held-out
    splits.

    Returns:
        tuple: ``(train_loader, val_loader, test_loader, full_dataset, mean, std)``
    """
    
    print("Creating enhanced data loaders for music emotion recognition...")
    
    # Load dataset with music-focused preprocessing
    full_dataset = MusicEmotionDataset(
        CONFIG['dataset_root'], 
        apply_clahe=True, 
        filter_quality=True
    )
    
    # Calculate dataset statistics for normalization
    mean, std = calculate_dataset_statistics(full_dataset)
    
    # First split: train vs. (validation + test), stratified by class
    sss = StratifiedShuffleSplit(
        n_splits=1, 
        test_size=CONFIG['validation_split'] + CONFIG['test_split'],
        random_state=42
    )
    
    train_idx, temp_idx = next(sss.split(
        np.zeros(len(full_dataset)), 
        full_dataset.labels
    ))
    
    # Second split: divide the held-out part into validation and test
    sss_val = StratifiedShuffleSplit(
        n_splits=1,
        test_size=CONFIG['test_split'] / (CONFIG['validation_split'] + CONFIG['test_split']),
        random_state=42
    )
    
    val_idx, test_idx = next(sss_val.split(
        np.zeros(len(temp_idx)),
        np.asarray(full_dataset.labels)[temp_idx]
    ))
    
    # Positions returned above are relative to temp_idx; map back to dataset indices
    val_idx = temp_idx[val_idx]
    test_idx = temp_idx[test_idx]
    
    print(f"Data splits:")
    print(f"  Training: {len(train_idx)} samples")
    print(f"  Validation: {len(val_idx)} samples") 
    print(f"  Test: {len(test_idx)} samples")
    
    # Get enhanced transforms
    train_transforms, val_transforms = get_transforms(mean, std)
    
    # Create transformed datasets
    train_dataset = TransformedDataset(full_dataset, train_idx, train_transforms)
    val_dataset = TransformedDataset(full_dataset, val_idx, val_transforms)
    test_dataset = TransformedDataset(full_dataset, test_idx, val_transforms)
    
    # Balanced sampling for training: take the per-sample inverse-frequency weights and
    # keep only the training rows (a direct tensor index instead of an O(n^2) scan)
    all_weights = full_dataset.get_weighted_sampler().weights
    train_weights = all_weights[torch.as_tensor(train_idx, dtype=torch.long)]
    train_sampler = WeightedRandomSampler(train_weights, len(train_weights))
    
    # Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=CONFIG['batch_size'],
        sampler=train_sampler,  # Use weighted sampler instead of shuffle
        num_workers=4,
        pin_memory=True,
        drop_last=True  # Avoid a tiny final batch destabilizing BatchNorm
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=CONFIG['batch_size'],
        shuffle=False,
        num_workers=4,
        pin_memory=True
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=CONFIG['batch_size'],
        shuffle=False,
        num_workers=4,
        pin_memory=True
    )
    
    print("Enhanced data loaders created successfully")
    return train_loader, val_loader, test_loader, full_dataset, mean, std

def setup_enhanced_model_and_training():
    """Build the emotion classifier and place it on the best available device.

    Returns:
        tuple: ``(model, device)``. ``model`` is a ``MusicEmotionResNet`` with one
        output per entry of ``CONFIG['music_emotions']`` (pretrained backbone,
        dropout 0.4), already moved to ``device``, which is CUDA when available and
        CPU otherwise.
    """
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device_name)

    model = MusicEmotionResNet(
        num_classes=len(CONFIG['music_emotions']),
        pretrained=True,
        dropout_rate=0.4,
    ).to(device)

    return model, device

def save_enhanced_preprocessing_info(dataset, train_size, val_size, test_size, 
                                   mean, std, class_weights, imbalance_ratio):
    """Save comprehensive preprocessing information.

    Writes ``enhanced_preprocessing_info.json`` (machine-readable) and
    ``enhanced_preprocessing_summary.txt`` (human-readable) into ``CONFIG['output_dir']``.

    Args:
        dataset: ``MusicEmotionDataset`` (uses ``len``, ``labels``, ``label_encoder``)
        train_size, val_size, test_size: number of samples in each split
        mean, std: normalization statistics
        class_weights: per-class weight tensor (anything with ``.tolist()``)
        imbalance_ratio: largest / smallest class count

    Returns:
        dict: the structure written to the JSON file
    """
    total_size = len(dataset)
    class_counts = Counter(dataset.labels)  # computed once, not once per class
    # Preserve FER2013 order so the JSON is identical across runs (set order is not)
    excluded_emotions = [emotion for emotion in CONFIG['original_emotions']
                         if emotion not in CONFIG['music_emotions']]

    def _pct(n):
        """Percentage of the whole dataset; 0.0 for an empty dataset instead of dividing by zero."""
        return n / total_size * 100 if total_size else 0.0
    
    preprocessing_info = {
        'dataset_info': {
            'total_size': total_size,
            'train_size': train_size,
            'val_size': val_size,
            'test_size': test_size,
            'num_classes': len(CONFIG['music_emotions']),
            'original_emotions': CONFIG['original_emotions'],
            'music_emotions': CONFIG['music_emotions'],
            'excluded_emotions': excluded_emotions,
            'class_imbalance_ratio': float(imbalance_ratio)
        },
        'preprocessing_settings': {
            'image_size': CONFIG['image_size'],
            'batch_size': CONFIG['batch_size'],
            'quality_filtering': True,
            'clahe_enhancement': True,
            'advanced_augmentation': True
        },
        'dataset_statistics': {
            'pixel_mean': float(mean),
            'pixel_std': float(std),
            'normalization_applied': True
        },
        'class_distribution': {
            emotion: int(class_counts[i]) 
            for i, emotion in enumerate(dataset.label_encoder.classes_)
        },
        'class_weights': class_weights.tolist(),
        'emotion_music_mapping': CONFIG['emotion_to_music'],
        'augmentation_config': CONFIG['augmentation'],
        'model_config': {
            'architecture': 'Enhanced ResNet-18',
            'attention_mechanism': True,
            'dropout_rate': 0.4,
            'loss_function': 'Focal Loss',
            'optimizer': 'AdamW'
        },
        'training_config': {
            'num_epochs': CONFIG['num_epochs'],
            'learning_rate': CONFIG['learning_rate'],
            'weight_decay': CONFIG['weight_decay'],
            'lr_scheduler': 'ReduceLROnPlateau',
            'early_stopping': True,
            'patience': CONFIG['patience']
        },
        'timestamp': pd.Timestamp.now().isoformat(),
        'version': '2.0_music_focused'
    }
    
    # Save to JSON
    output_file = os.path.join(CONFIG['output_dir'], 'enhanced_preprocessing_info.json')
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(preprocessing_info, f, indent=2)
    
    print(f"Enhanced preprocessing info saved to: {output_file}")
    
    # Built outside the report f-string (no backslashes allowed in f-string expressions before 3.12)
    mapping_lines = '\n'.join(
        f"- {emotion}: {', '.join(genres)}"
        for emotion, genres in CONFIG['emotion_to_music'].items()
    )
    
    # Create detailed summary report
    summary_report = f"""
ENHANCED EMOTION DETECTION PREPROCESSING REPORT
==============================================

DATASET OVERVIEW:
- Total Images: {total_size:,}
- Music-Relevant Emotions: {len(CONFIG['music_emotions'])}
- Excluded Emotions: {len(excluded_emotions)}
- Class Imbalance Ratio: {imbalance_ratio:.2f}

DATA SPLITS:
- Training: {train_size:,} samples ({_pct(train_size):.1f}%)
- Validation: {val_size:,} samples ({_pct(val_size):.1f}%)
- Test: {test_size:,} samples ({_pct(test_size):.1f}%)

DATASET STATISTICS:
- Pixel Mean: {mean:.4f}
- Pixel Std: {std:.4f}
- Image Size: {CONFIG['image_size'][0]}x{CONFIG['image_size'][1]}
- Batch Size: {CONFIG['batch_size']}

MODEL ARCHITECTURE:
- Base: Enhanced ResNet-18
- Input: Grayscale (1 channel)
- Attention: Spatial attention module
- Classifier: Multi-layer with dropout
- Output: {len(CONFIG['music_emotions'])} emotion classes

MUSIC EMOTION MAPPING:
{mapping_lines}

TRAINING OPTIMIZATIONS:
- Loss Function: Focal Loss (handles imbalance)
- Optimizer: AdamW with weight decay
- Learning Rate: {CONFIG['learning_rate']} with scheduler
- Early Stopping: {CONFIG['patience']} epochs patience
- Regularization: Dropout + BatchNorm

OUTPUT LOCATION: {CONFIG['output_dir']}
MODEL SAVE LOCATION: {CONFIG['model_save_dir']}
    """
    
    summary_file = os.path.join(CONFIG['output_dir'], 'enhanced_preprocessing_summary.txt')
    with open(summary_file, 'w', encoding='utf-8') as f:
        f.write(summary_report)
    
    print(f"Detailed summary saved to: {summary_file}")
    return preprocessing_info

In [None]:
def create_processed_dataset():
    """Build the processed (CLAHE-enhanced, quality-filtered) dataset on disk.

    Loads the raw images from ``CONFIG['dataset_root']`` through
    ``MusicEmotionDataset`` with ``copy_to_processed=True`` and
    ``processed_dir=CONFIG['output_dir']``, computes normalization statistics
    and class weights, and writes a ``dataset_info.json`` summary into
    ``CONFIG['output_dir']``.

    Returns:
        tuple: ``(dataset, mean, std, class_weights)`` where ``dataset`` is the
        ``MusicEmotionDataset`` built from the raw data, ``mean``/``std`` are
        the values returned by ``calculate_dataset_statistics`` and
        ``class_weights`` is the object returned by
        ``dataset.get_class_weights()``.
    """

    def _json_default(obj):
        """Convert values the ``json`` module cannot serialize natively."""
        # numpy arrays/scalars and torch tensors all expose ``tolist()``,
        # which yields plain Python numbers/lists.
        if hasattr(obj, 'tolist'):
            return obj.tolist()
        if isinstance(obj, Path):
            return str(obj)
        if isinstance(obj, (set, frozenset)):
            return list(obj)
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    print("Creating processed dataset for music emotion recognition...")

    # Raw images are enhanced, filtered and copied into CONFIG['output_dir']
    # as part of dataset construction.
    dataset = MusicEmotionDataset(
        CONFIG['dataset_root'],
        apply_clahe=True,
        filter_quality=True,
        copy_to_processed=True,
        processed_dir=CONFIG['output_dir']
    )

    # Normalization statistics used later by the transforms.
    mean, std = calculate_dataset_statistics(dataset)

    # Class weights and summary statistics for the (imbalanced) label set.
    class_weights = dataset.get_class_weights()
    stats = dataset.get_statistics()

    dataset_info = {
        'statistics': stats,
        'class_weights': class_weights.tolist(),
        'normalization': {'mean': float(mean), 'std': float(std)},
        'config': CONFIG,
        'processed_location': CONFIG['output_dir']
    }

    # Do not rely on MusicEmotionDataset having created the output directory.
    os.makedirs(CONFIG['output_dir'], exist_ok=True)
    info_file = os.path.join(CONFIG['output_dir'], 'dataset_info.json')
    with open(info_file, 'w') as f:
        # ``default`` handles numpy/torch values that may appear in ``stats``.
        json.dump(dataset_info, f, indent=2, default=_json_default)

    print(f"Dataset processed and saved to: {CONFIG['output_dir']}")
    print(f"Total samples: {stats['total_samples']}")
    print(f"Classes: {stats['emotions']}")
    print(f"Class distribution: {stats['class_distribution']}")

    return dataset, mean, std, class_weights

def create_data_loaders_from_processed(num_workers=2):
    """Create stratified train/validation/test DataLoaders from the processed dataset.

    The images in ``CONFIG['output_dir']`` are loaded without re-applying
    CLAHE or quality filtering. They are split by two successive
    ``StratifiedShuffleSplit`` passes (``random_state=42``) into training,
    validation (``CONFIG['validation_split']``) and test
    (``CONFIG['test_split']``) subsets. Each subset is wrapped in a
    ``TransformedDataset`` using the transforms from ``get_transforms``.

    Args:
        num_workers: Worker processes per DataLoader. Defaults to 2, the
            previously hard-coded value. Pass 0 if worker processes cannot
            import classes defined in the notebook (e.g. spawn-based
            multiprocessing on Windows/macOS).

    Returns:
        tuple: ``(train_loader, val_loader, test_loader, processed_dataset)``.
    """
    print("Creating data loaders from processed dataset...")

    processed_dataset = MusicEmotionDataset(
        CONFIG['output_dir'],  # Use processed directory
        apply_clahe=False,     # Already processed
        filter_quality=False,  # Already filtered
        copy_to_processed=False
    )

    # NOTE(review): statistics are computed over the whole dataset, including
    # the validation/test samples, before splitting. Consider computing them
    # on the training split only to avoid leakage.
    mean, std = calculate_dataset_statistics(processed_dataset)

    labels = np.asarray(processed_dataset.labels)
    holdout_fraction = CONFIG['validation_split'] + CONFIG['test_split']

    # First pass: training vs. combined (validation + test) hold-out.
    first_split = StratifiedShuffleSplit(
        n_splits=1,
        test_size=holdout_fraction,
        random_state=42
    )
    train_idx, temp_idx = next(first_split.split(np.zeros(len(labels)), labels))

    # Second pass: split the hold-out into validation and test while keeping
    # class proportions; positions are relative to ``temp_idx``.
    second_split = StratifiedShuffleSplit(
        n_splits=1,
        test_size=CONFIG['test_split'] / holdout_fraction,
        random_state=42
    )
    val_pos, test_pos = next(second_split.split(np.zeros(len(temp_idx)), labels[temp_idx]))
    val_idx = temp_idx[val_pos]
    test_idx = temp_idx[test_pos]

    # Sanity check: the three splits together cover every sample exactly once.
    assert len(train_idx) + len(val_idx) + len(test_idx) == len(processed_dataset)

    train_transforms, val_transforms = get_transforms(mean, std)

    # Pinned host memory only speeds up host-to-CUDA copies; without a GPU it
    # is useless and PyTorch warns about it.
    pin_memory = torch.cuda.is_available()

    def _make_loader(indices, transform, shuffle):
        """Wrap the given subset of ``processed_dataset`` in a DataLoader."""
        subset = TransformedDataset(processed_dataset, indices, transform)
        return DataLoader(
            subset,
            batch_size=CONFIG['batch_size'],
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=pin_memory
        )

    train_loader = _make_loader(train_idx, train_transforms, shuffle=True)
    # Validation and test share the deterministic (non-augmenting) transforms.
    val_loader = _make_loader(val_idx, val_transforms, shuffle=False)
    test_loader = _make_loader(test_idx, val_transforms, shuffle=False)

    print("Data loaders created:")
    print(f"  Training: {len(train_loader.dataset)} samples")
    print(f"  Validation: {len(val_loader.dataset)} samples")
    print(f"  Test: {len(test_loader.dataset)} samples")

    return train_loader, val_loader, test_loader, processed_dataset

def main():
    """Run the preprocessing pipeline and collect everything training needs.

    The pipeline builds the processed dataset on disk, creates stratified
    data loaders from it, and sets up the model via
    ``setup_enhanced_model_and_training``.

    Returns:
        dict: with keys ``train_loader``, ``val_loader``, ``test_loader``,
        ``model``, ``device``, ``class_weights``, ``dataset_stats``
        (``{'mean': ..., 'std': ...}``), ``focal_loss`` (a
        ``FocalLoss(alpha=1, gamma=2)`` instance) and ``processed_dataset``.
    """
    banner = "Music Emotion Detection - Data Preprocessing"
    print(banner)
    print("=" * 50)

    # 1) Process raw images; the raw dataset object itself is not reused below.
    _raw_dataset, mean, std, class_weights = create_processed_dataset()

    # 2) Reload the processed images and build the split loaders.
    train_loader, val_loader, test_loader, processed_dataset = (
        create_data_loaders_from_processed()
    )

    # 3) Model and device setup.
    model, device = setup_enhanced_model_and_training()

    print("\nPreprocessing completed successfully!")
    print(f"Processed data saved to: {CONFIG['output_dir']}")

    pipeline_outputs = dict(
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        model=model,
        device=device,
        class_weights=class_weights,
        dataset_stats=dict(mean=mean, std=std),
        focal_loss=FocalLoss(alpha=1, gamma=2),
        processed_dataset=processed_dataset,
    )
    return pipeline_outputs


In [None]:
if __name__ == "__main__":
    # Run preprocessing end to end (dataset processing, loaders, model setup).
    results = main()

    print("\nPreprocessing completed! Ready for model training.")
    print(f"Processed dataset location: {CONFIG['output_dir']}")

    # Hand-off file for the training notebook. Values are converted to
    # plain Python floats so they can be written as JSON.
    training_info = dict(
        processed_data_dir=CONFIG['output_dir'],
        dataset_stats={
            stat_name: float(results['dataset_stats'][stat_name])
            for stat_name in ('mean', 'std')
        },
        class_weights=list(map(float, results['class_weights'].tolist())),
        config=CONFIG,
    )

    info_file = os.path.join(CONFIG['output_dir'], 'training_ready_info.json')
    with open(info_file, 'w') as f:
        json.dump(training_info, f, indent=2)

    print(f"Training info saved to: {info_file}")