In [6]:
# EmotionCLIP Fine-tuning for ArtEmis Emotion Classification - FIXED VERSION
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

import pandas as pd
import numpy as np
import os
import os.path as osp
from PIL import Image
import unicodedata
from ast import literal_eval
from sklearn.metrics import accuracy_score
from tqdm import tqdm
import logging

# Import EmotionCLIP-V2
from EmotionCLIP_V2 import EmotionCLIP

# Import artemis modules
from artemis.emotions import ARTEMIS_EMOTIONS, IDX_TO_EMOTION, positive_negative_else
from artemis.in_out.basics import create_dir

# Configuration
num_labels = len(ARTEMIS_EMOTIONS)
model_name = 'EmotionCLIP-V2'
load_best_model = False
do_training = True
max_train_epochs = 5  # Reduced for testing
subsample_data = True  # Enable for faster testing

# Fine-tuning method selection
use_layer_norm_tuning = False  # Disabled for now
use_prefix_tuning = True
use_prompt_tuning = False  # Disabled to simplify

# Paths (matching your structure)
my_out_dir = r'finetuned_emotionclip_results/emotionclip_based'
artemis_preprocessed_dir = r'data/artemis/old_preprocessed_data'
image_hists_file = r'data/artemis/artemis/data/image-emotion-histogram.csv'

create_dir(my_out_dir)
best_model_dir = osp.join(my_out_dir, 'best_model')
create_dir(best_model_dir)

# Training parameters - CONSERVATIVE SETTINGS
batch_size = 4  # Very small batch size
learning_rate = 1e-6  # Very small learning rate
weight_decay = 0.01
max_grad_norm = 0.5  # Aggressive gradient clipping
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Fine-tuning parameters - REDUCED
prefix_length = 2  # Very small
prefix_dim = 128  # Reduced

# Emotion label mapping
emotion_mapping = {emotion: idx for idx, emotion in enumerate(ARTEMIS_EMOTIONS)}

class ArtemisEmotionDataset(Dataset):
    """Dataset class for ArtEmis emotion data"""
    
    def __init__(self, df, preprocess, emotion_mapping):
        self.df = df
        self.preprocess = preprocess
        self.emotion_mapping = emotion_mapping
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        
        try:
            # Load and preprocess image
            img_path = row['image_file']
            
            if not os.path.exists(img_path):
                return None
            
            image = Image.open(img_path).convert('RGB')
            image_tensor = self.preprocess(image)
            
            # Get emotion distribution
            emotion_dist = row['emotion_distribution']
            
            if isinstance(emotion_dist, str):
                emotion_dist = literal_eval(emotion_dist)
            
            emotion_dist = np.array(emotion_dist, dtype=np.float32)
            
            # Ensure valid distribution
            if emotion_dist.sum() <= 0:
                emotion_dist = np.ones(len(self.emotion_mapping)) / len(self.emotion_mapping)
            else:
                emotion_dist = emotion_dist / emotion_dist.sum()
            
            emotion_dist = torch.tensor(emotion_dist, dtype=torch.float32)
            
            # Get dominant emotion
            dominant_emotion_idx = torch.argmax(emotion_dist).item()
            
            return {
                'image': image_tensor,
                'emotion_distribution': emotion_dist,
                'emotion_label': dominant_emotion_idx,
                'image_path': img_path
            }
        except Exception as e:
            print(f"Error loading {row.get('image_file', 'unknown')}: {e}")
            return None

def collate_fn(batch):
    """Custom collate function"""
    batch = [item for item in batch if item is not None]
    if len(batch) == 0:
        return None
    
    images = torch.stack([item['image'] for item in batch])
    emotion_distributions = torch.stack([item['emotion_distribution'] for item in batch])
    emotion_labels = torch.tensor([item['emotion_label'] for item in batch])
    
    return {
        'images': images,
        'emotion_distributions': emotion_distributions,
        'emotion_labels': emotion_labels,
        'image_paths': [item['image_path'] for item in batch]
    }

class SimpleEmotionCLIPTuning(nn.Module):
    """
    Simplified EmotionCLIP fine-tuning - FIXED VERSION
    """
    
    def __init__(self, clip_model, emotion_mapping):
        super().__init__()
        self.clip_model = clip_model
        self.emotion_mapping = emotion_mapping
        self.num_emotions = len(emotion_mapping)
        
        # Get model dimensions safely
        self.feature_dim = self._get_feature_dimension()
        
        # Create text prompts
        self.emotion_texts = [f"This image shows {emotion}" for emotion in emotion_mapping.keys()]
        
        # ONLY use simple prefix tuning
        if use_prefix_tuning:
            self.vision_adapter = nn.Parameter(
                torch.zeros(self.feature_dim) * 0.01  # Very small initialization
            )
            print(f"Initialized vision adapter with dimension {self.feature_dim}")
        else:
            self.vision_adapter = None
        
        # Simple learnable temperature
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1/0.07))
        
    def _get_feature_dimension(self):
        """Get feature dimension safely"""
        try:
            dummy_image = torch.randn(1, 3, 224, 224).to(device)
            with torch.no_grad():
                features = self.clip_model.encode_image(dummy_image.to(self.clip_model.dtype))
                return features.shape[-1]
        except Exception as e:
            print(f"Error getting feature dimension: {e}")
            return 512  # Default fallback
    
    def encode_text_prompts(self):
        """Encode text prompts once and cache"""
        if not hasattr(self, '_cached_text_features'):
            try:
                # Use EmotionCLIP tokenizer
                text_tokens = EmotionCLIP.tokenizer(self.emotion_texts).to(device)
                with torch.no_grad():
                    text_features = self.clip_model.encode_text(text_tokens)
                    # Normalize
                    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
                    self._cached_text_features = text_features.detach()
            except Exception as e:
                print(f"Error encoding text: {e}")
                # Fallback: random features
                self._cached_text_features = torch.randn(self.num_emotions, self.feature_dim, dtype=torch.float32).to(device)
                self._cached_text_features = self._cached_text_features / self._cached_text_features.norm(dim=-1, keepdim=True)
        
        return self._cached_text_features
    
    def forward(self, images):
        """Simple forward pass"""
        try:
            # Encode images - ENSURE GRADIENTS FLOW
            if use_prefix_tuning and self.vision_adapter is not None:
                # Enable gradients for vision adapter
                image_features = self.clip_model.encode_image(images.to(self.clip_model.dtype))
                
                # Apply simple adapter
                image_features = image_features + self.vision_adapter.unsqueeze(0)
            else:
                # No fine-tuning, just use base model
                with torch.no_grad():
                    image_features = self.clip_model.encode_image(images.to(self.clip_model.dtype))
                # Convert to require grad for loss computation
                image_features = image_features.detach().requires_grad_(True)
            
            # Check for NaN
            if torch.isnan(image_features).any():
                print("NaN in image features, using random features")
                image_features = torch.randn_like(image_features).to(image_features.dtype).requires_grad_(True)      
            
            # Normalize image features
            image_features = image_features / (image_features.norm(dim=-1, keepdim=True) + 1e-8)
            
            # Get text features (cached) - FIX: Ensure same dtype
            text_features = self.encode_text_prompts().to(image_features.dtype)
            
            # Compute logits - FIX: Ensure consistent dtype
            logit_scale = torch.clamp(self.logit_scale.exp(), min=1.0, max=100.0).to(image_features.dtype)
            logits = logit_scale * image_features @ text_features.t()
            
            # Clamp logits to prevent overflow
            logits = torch.clamp(logits, min=-20, max=20)
            
            return logits
            
        except Exception as e:
            print(f"Error in forward pass: {e}")
            # Return dummy logits that require grad
            batch_size = images.size(0)
            dummy_logits = torch.randn(batch_size, self.num_emotions, requires_grad=True).to(device)
            return dummy_logits

def load_data():
    """Load and prepare the ArtEmis dataset"""
    print("Loading ArtEmis data...")
    
    # Load preprocessed data
    artemis_data = pd.read_csv(osp.join(artemis_preprocessed_dir, 'artemis_preprocessed.csv'))
    artemis_data['image_file'] = artemis_data['image_file'].apply(lambda x: unicodedata.normalize('NFC', x))
    artemis_data['painting'] = artemis_data['painting'].apply(lambda x: unicodedata.normalize('NFC', x))
    
    # Keep each image once
    artemis_data = artemis_data.drop_duplicates(subset=['art_style', 'painting'])
    artemis_data.reset_index(inplace=True, drop=True)
    
    # Load emotion histograms
    try:
        image_hists = pd.read_csv(image_hists_file)
        image_hists.emotion_histogram = image_hists.emotion_histogram.apply(literal_eval)
        image_hists.emotion_histogram = image_hists.emotion_histogram.apply(
            lambda x: (np.array(x) / max(sum(x), 1e-8)).astype('float32')
        )
        
        artemis_data = artemis_data[['art_style', 'painting', 'split', 'image_file']]
        artemis_data = artemis_data.merge(image_hists, on=['art_style', 'painting'], how='inner')
        artemis_data = artemis_data.rename(columns={'emotion_histogram': 'emotion_distribution'})
        
    except Exception as e:
        print(f"Error loading emotion histograms: {e}")
        return {}
    
    # Validate data
    print(f"Original dataset size: {len(artemis_data)}")
    artemis_data = artemis_data.dropna(subset=['emotion_distribution', 'image_file'])
    print(f"After removing NaN: {len(artemis_data)}")
    
    # Check image files exist
    valid_images = artemis_data['image_file'].apply(os.path.exists)
    artemis_data = artemis_data[valid_images]
    print(f"After checking image files: {len(artemis_data)}")
    
    # Create data splits
    data_splits = {}
    for split in ['train', 'val', 'test']:
        mask = (artemis_data['split'] == split)
        sub_df = artemis_data[mask].copy()
        sub_df.reset_index(drop=True, inplace=True)
        
        if len(sub_df) == 0:
            print(f"Warning: No data found for split '{split}'")
            continue
        
        if subsample_data:
            sample_size = min(50, len(sub_df))  # Very small for testing
            sub_df = sub_df.sample(sample_size)
            sub_df.reset_index(drop=True, inplace=True)
        
        data_splits[split] = sub_df
        print(f"{split.capitalize()}: {len(sub_df)} samples")
    
    return data_splits

def create_data_loaders(data_splits, preprocess):
    """Create PyTorch DataLoaders"""
    datasets = {}
    data_loaders = {}
    
    for split_name, split_data in data_splits.items():
        dataset = ArtemisEmotionDataset(
            split_data, 
            preprocess, 
            emotion_mapping
        )
        datasets[split_name] = dataset
        
        shuffle = (split_name == 'train')
        data_loaders[split_name] = DataLoader(
            dataset, 
            batch_size=batch_size, 
            shuffle=shuffle,
            num_workers=0,  # No multiprocessing
            collate_fn=collate_fn,
            drop_last=(split_name == 'train')
        )
    
    return data_loaders, datasets

def train_epoch(model, data_loader, optimizer, criterion, device):
    """Train for one epoch"""
    model.train()
    total_loss = 0
    correct_predictions = 0
    total_predictions = 0
    valid_batches = 0
    
    progress_bar = tqdm(data_loader, desc="Training")
    
    for batch_idx, batch in enumerate(progress_bar):
        if batch is None:
            continue
        
        try:
            images = batch['images'].to(device)
            emotion_distributions = batch['emotion_distributions'].to(device)
            emotion_labels = batch['emotion_labels'].to(device)
            
            # Validate inputs
            if torch.isnan(images).any() or torch.isnan(emotion_distributions).any():
                print(f"NaN in batch {batch_idx}, skipping")
                continue
            
            optimizer.zero_grad()
            
            # Forward pass
            logits = model(images)
            
            # Check outputs
            if torch.isnan(logits).any() or torch.isinf(logits).any():
                print(f"NaN/Inf in logits batch {batch_idx}, skipping")
                continue
            
            # Compute loss
            log_probs = F.log_softmax(logits, dim=-1)
            loss = criterion(log_probs, emotion_distributions)
            
            if torch.isnan(loss) or torch.isinf(loss) or loss.item() > 50:
                print(f"Invalid loss {loss.item()} in batch {batch_idx}, skipping")
                continue
            
            # Backward pass
            loss.backward()
            
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(
                [p for p in model.parameters() if p.requires_grad], 
                max_grad_norm
            )
            
            optimizer.step()
            
            total_loss += loss.item()
            valid_batches += 1
            
            # Calculate accuracy
            predictions = torch.argmax(logits, dim=-1)
            correct_predictions += (predictions == emotion_labels).sum().item()
            total_predictions += emotion_labels.size(0)
            
            if total_predictions > 0:
                accuracy = correct_predictions / total_predictions
                avg_loss = total_loss / valid_batches
                
                progress_bar.set_postfix({
                    'loss': avg_loss,
                    'accuracy': accuracy,
                    'valid_batches': valid_batches
                })
            
        except Exception as e:
            print(f"Error in training batch {batch_idx}: {e}")
            continue
    
    avg_loss = total_loss / valid_batches if valid_batches > 0 else float('inf')
    accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0
    
    return avg_loss, accuracy

def evaluate(model, data_loader, criterion, device):
    """Evaluate the model"""
    model.eval()
    total_loss = 0
    all_predictions = []
    all_labels = []
    valid_batches = 0
    
    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(data_loader, desc="Evaluating")):
            if batch is None:
                continue
            
            try:
                images = batch['images'].to(device)
                emotion_distributions = batch['emotion_distributions'].to(device)
                emotion_labels = batch['emotion_labels'].to(device)
                
                if torch.isnan(images).any() or torch.isnan(emotion_distributions).any():
                    continue
                
                logits = model(images)
                
                if torch.isnan(logits).any() or torch.isinf(logits).any():
                    continue
                
                log_probs = F.log_softmax(logits, dim=-1)
                loss = criterion(log_probs, emotion_distributions)
                
                if torch.isnan(loss) or torch.isinf(loss):
                    continue
                
                total_loss += loss.item()
                valid_batches += 1
                
                predictions = torch.argmax(logits, dim=-1)
                all_predictions.extend(predictions.cpu().numpy())
                all_labels.extend(emotion_labels.cpu().numpy())
                
            except Exception as e:
                print(f"Error in evaluation batch {batch_idx}: {e}")
                continue
    
    avg_loss = total_loss / valid_batches if valid_batches > 0 else float('inf')
    accuracy = accuracy_score(all_labels, all_predictions) if all_labels else 0
    
    return avg_loss, accuracy, np.array(all_predictions), np.array(all_labels)

def main():
    """Main function"""
    # Setup
    logging.basicConfig(level=logging.INFO)
    torch.manual_seed(42)
    np.random.seed(42)
    
    print(f"Using device: {device}")
    
    # Load EmotionCLIP model
    print("Loading EmotionCLIP-V2...")
    try:
        clip_model = EmotionCLIP.model.to(device)
        preprocess = EmotionCLIP.preprocess
        print("EmotionCLIP-V2 loaded successfully")
    except Exception as e:
        print(f"Error loading EmotionCLIP: {e}")
        return
    
    # Freeze base model
    for param in clip_model.parameters():
        param.requires_grad = False
    
    print("Base EmotionCLIP model frozen")
    
    # Load data
    print("Loading data...")
    data_splits = load_data()
    
    if len(data_splits) == 0:
        print("Error: No data splits found!")
        return
    
    # Initialize model
    print("Initializing fine-tuned model...")
    model = SimpleEmotionCLIPTuning(clip_model, emotion_mapping)
    model.to(device)
    
    # Create data loaders
    print("Creating data loaders...")
    data_loaders, datasets = create_data_loaders(data_splits, preprocess)
    
    # Setup training
    trainable_params = []
    if use_prefix_tuning and model.vision_adapter is not None:
        trainable_params.append(model.vision_adapter)
    trainable_params.append(model.logit_scale)
    
    print(f"Trainable parameters: {sum(p.numel() for p in trainable_params)}")
    
    if do_training and len(trainable_params) > 0:
        print("Starting training...")
        
        optimizer = optim.Adam(trainable_params, lr=learning_rate, weight_decay=weight_decay)
        criterion = nn.KLDivLoss(reduction='batchmean')
        
        best_val_accuracy = 0
        
        for epoch in range(max_train_epochs):
            print(f"\nEpoch {epoch + 1}/{max_train_epochs}")
            
            # Training
            train_loss, train_accuracy = train_epoch(
                model, data_loaders['train'], optimizer, criterion, device
            )
            
            print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}")
            
            if np.isnan(train_loss) or np.isinf(train_loss):
                print("Training failed due to NaN/Inf loss. Stopping...")
                break
            
            # Validation
            if 'val' in data_loaders:
                val_loss, val_accuracy, _, _ = evaluate(model, data_loaders['val'], criterion, device)
                print(f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")
                
                # Save best model
                if val_accuracy > best_val_accuracy:
                    best_val_accuracy = val_accuracy
                    torch.save({
                        'model_state_dict': model.state_dict(),
                        'val_accuracy': val_accuracy,
                    }, osp.join(best_model_dir, 'model.pt'))
                    print(f"New best model saved: {val_accuracy:.4f}")
    
    # Final evaluation
    if 'test' in data_loaders:
        print("Evaluating on test set...")
        criterion = nn.KLDivLoss(reduction='batchmean')
        test_loss, test_accuracy, predictions, labels = evaluate(
            model, data_loaders['test'], criterion, device
        )
        
        print(f"\nTest Results:")
        print(f"Test Loss: {test_loss:.4f}")
        print(f"Test Accuracy: {test_accuracy:.4f}")
        
        if len(predictions) > 0 and len(labels) > 0:
            # Ternary analysis
            gt = pd.Series(labels)
            predictions_series = pd.Series(predictions)
            
            gt_pne = gt.apply(lambda x: positive_negative_else(IDX_TO_EMOTION[x]))
            predictions_pne = predictions_series.apply(lambda x: positive_negative_else(IDX_TO_EMOTION[x]))
            
            ternary_accuracy = (gt_pne == predictions_pne).mean()
            print(f'Ternary accuracy (pos/neg/else): {ternary_accuracy:.4f}')
            
            # Save results
            results = {
                'test_accuracy': test_accuracy,
                'test_loss': test_loss,
                'ternary_accuracy': ternary_accuracy,
            }
            
            results_df = pd.DataFrame([results])
            results_df.to_csv(osp.join(my_out_dir, 'test_results.csv'), index=False)
            print(f"Results saved to {my_out_dir}")

main()

Using device: cuda
Loading EmotionCLIP-V2...
EmotionCLIP-V2 loaded successfully
Base EmotionCLIP model frozen
Loading data...
Loading ArtEmis data...
Original dataset size: 79388
After removing NaN: 79388
After checking image files: 79388
Train: 50 samples
Val: 50 samples
Test: 50 samples
Initializing fine-tuned model...
Initialized vision adapter with dimension 512
Creating data loaders...
Trainable parameters: 513
Starting training...

Epoch 1/5


Training:  25%|██▌       | 3/12 [00:00<00:01,  7.28it/s]

NaN in image features, using random features
NaN/Inf in logits batch 0, skipping
NaN in image features, using random features
NaN/Inf in logits batch 1, skipping
NaN in image features, using random features
NaN/Inf in logits batch 2, skipping


Training:  42%|████▏     | 5/12 [00:00<00:00,  8.88it/s]

NaN in image features, using random features
NaN/Inf in logits batch 3, skipping
NaN in image features, using random features
NaN/Inf in logits batch 4, skipping


Training:  67%|██████▋   | 8/12 [00:00<00:00, 11.17it/s]

NaN in image features, using random features
NaN/Inf in logits batch 5, skipping
NaN in image features, using random features
NaN/Inf in logits batch 6, skipping
NaN in image features, using random features
NaN/Inf in logits batch 7, skipping
NaN in image features, using random features
NaN/Inf in logits batch 8, skipping


Training:  83%|████████▎ | 10/12 [00:01<00:00, 12.20it/s]

NaN in image features, using random features
NaN/Inf in logits batch 9, skipping


Training: 100%|██████████| 12/12 [00:01<00:00,  8.94it/s]


NaN in image features, using random features
NaN/Inf in logits batch 10, skipping
NaN in image features, using random features
NaN/Inf in logits batch 11, skipping
Train Loss: inf, Train Accuracy: 0.0000
Training failed due to NaN/Inf loss. Stopping...
Evaluating on test set...


Evaluating:   8%|▊         | 1/13 [00:00<00:05,  2.36it/s]

NaN in image features, using random features


Evaluating:  31%|███       | 4/13 [00:00<00:01,  6.76it/s]

NaN in image features, using random features
NaN in image features, using random features
NaN in image features, using random features
NaN in image features, using random features


Evaluating:  54%|█████▍    | 7/13 [00:00<00:00,  9.63it/s]

NaN in image features, using random features
NaN in image features, using random features
NaN in image features, using random features


Evaluating:  77%|███████▋  | 10/13 [00:01<00:00, 12.39it/s]

NaN in image features, using random features
NaN in image features, using random features
NaN in image features, using random features


Evaluating: 100%|██████████| 13/13 [00:02<00:00,  6.44it/s]

NaN in image features, using random features
NaN in image features, using random features

Test Results:
Test Loss: inf
Test Accuracy: 0.0000



