In [None]:
!pip install torch torchvision pandas numpy matplotlib

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!git clone https://github.com/AlvinWen428/spatial-relation-benchmark.git
%cd spatial-relation-benchmark

In [None]:
!pip install -r requirements.txt

In [None]:
# Work from the cloned repository so the relative data/ paths used below resolve.
%cd /content/spatial-relation-benchmark
# Create the dataset folder, fetch the files archive of Zenodo record 8104370,
# and unpack it into data/spatialsense/.
!mkdir -p data/spatialsense
!wget https://zenodo.org/api/records/8104370/files-archive -O spatialsense.zip
!unzip spatialsense.zip -d data/spatialsense/

In [None]:
# Extract the image tarball shipped inside the archive into data/spatialsense/images.
!mkdir -p data/spatialsense/images
!tar -zxvf data/spatialsense/images.tar.gz -C data/spatialsense/images


In [None]:
!pip install gdown
import gdown

# Google Drive file ID of the SpatialSense+ annotations
file_id = "1vIOozqk3OlxkxZgL356pD1EAGt06ZwM4"
output_path = "data/spatialsense/annots_spatialsenseplus.json"

# Download the annotations next to the original SpatialSense files
url = f"https://drive.google.com/uc?id={file_id}"
gdown.download(url, output_path, quiet=False)

In [None]:
!ls -l data/spatialsense/
!ls -l data/spatialsense/images/


In [None]:
import json
import os
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from collections import Counter

class SpatialSensePredicateAnalyzer:
    """Inventory the spatial predicates found in the SpatialSense annotation files.

    Loads ``annotations.json`` and ``annots_spatialsenseplus.json`` from
    ``data_dir``, counts the predicates of each source, and reports which ones
    belong to the nine-relation label set and which would need to be remapped.
    """

    def __init__(self, data_dir="data/spatialsense"):
        """Load both annotation files found under ``data_dir``.

        Args:
            data_dir: Directory containing the two JSON annotation files.
        """
        self.data_dir = data_dir
        self.annotations = self._load_annotations()
        self.annotations_plus = self._load_annotations_plus()

        # The nine relations this notebook treats as the official label set.
        self.official_spatialsense_relations = [
            'above', 'behind', 'in', 'in front of', 'next to',
            'on', 'to the left of', 'to the right of', 'under'
        ]

    def _load_annotations(self):
        """Load ``annotations.json``; return ``[]`` (after reporting) on failure."""
        path = os.path.join(self.data_dir, "annotations.json")
        try:
            with open(path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # OSError: missing/unreadable file; ValueError: invalid JSON or encoding.
            print(f"    Erreur chargement annotations.json: {e}")
            return []
        print(f"   Annotations originales chargées: {len(data) if isinstance(data, list) else 'structure complexe'}")
        return data

    def _load_annotations_plus(self):
        """Load ``annots_spatialsenseplus.json``; return ``{}`` (after reporting) on failure."""
        path = os.path.join(self.data_dir, "annots_spatialsenseplus.json")
        try:
            with open(path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            print(f"       Pas d'annotations SpatialSense+ trouvées: {e}")
            return {}
        print(f"   Annotations SpatialSense+ chargées")
        return data

    @staticmethod
    def _collect_positive_predicates(image_records):
        """Return the predicates of truthy-``label`` annotations in per-image records.

        Records or annotations that are not dicts are skipped instead of raising.
        """
        predicates = []
        for record in image_records:
            if not isinstance(record, dict):
                continue
            for ann in record.get('annotations') or []:
                if isinstance(ann, dict) and 'predicate' in ann and ann.get('label', False):
                    predicates.append(ann['predicate'])
        return predicates

    def extract_all_predicates(self):
        """Extract the predicates of both annotation sources.

        Returns:
            A ``(predicates_original, predicates_plus)`` tuple of lists, one
            entry per retained annotation (duplicates preserved for counting).
        """
        predicates_original = []
        predicates_plus = []

        print("\n   Extraction des prédicats...")

        # 1. Original annotations: a list of per-image records.
        if isinstance(self.annotations, list):
            predicates_original = self._collect_positive_predicates(self.annotations)

        # 2. SpatialSense+ annotations: the layout is probed rather than assumed.
        if isinstance(self.annotations_plus, dict):
            for value in self.annotations_plus.values():
                if not isinstance(value, list):
                    continue
                for item in value:
                    if not isinstance(item, dict):
                        continue
                    if 'predicate' in item:
                        # Flat entry: taken as-is, without a label check.
                        predicates_plus.append(item['predicate'])
                    else:
                        predicates_plus.extend(self._collect_positive_predicates([item]))
        elif isinstance(self.annotations_plus, list):
            predicates_plus = self._collect_positive_predicates(self.annotations_plus)

        return predicates_original, predicates_plus

    def analyze_predicates(self):
        """Count the predicates of both sources and print a short summary.

        Returns:
            A dict with the two ``Counter`` objects, the sorted list of unique
            predicates, and the total / unique counts.
        """
        predicates_original, predicates_plus = self.extract_all_predicates()

        count_original = Counter(predicates_original)
        count_plus = Counter(predicates_plus)

        # Union of the predicates seen in either source.
        all_predicates = set(predicates_original) | set(predicates_plus)

        print(f"\n   ANALYSE COMPLETE DES PREDICATS")
        print("="*60)
        print(f"Total prédicats annotations originales: {len(predicates_original)}")
        print(f"Total prédicats SpatialSense+: {len(predicates_plus)}")
        print(f"Prédicats uniques (total): {len(all_predicates)}")

        return {
            'predicates_original': count_original,
            'predicates_plus': count_plus,
            'all_unique_predicates': sorted(all_predicates),
            'total_original': len(predicates_original),
            'total_plus': len(predicates_plus),
            'unique_count': len(all_predicates)
        }

    def display_all_predicates(self):
        """Print every unique predicate with its per-source and total counts."""
        results = self.analyze_predicates()

        print(f"\n   TOUS LES PREDICATS UNIQUES TROUVES ({len(results['all_unique_predicates'])}):")
        print("="*80)

        for i, predicate in enumerate(results['all_unique_predicates'], 1):
            count_orig = results['predicates_original'].get(predicate, 0)
            count_plus = results['predicates_plus'].get(predicate, 0)
            total_count = count_orig + count_plus

            # Predicates outside the official label set are flagged with a marker.
            is_official = predicate in self.official_spatialsense_relations
            status = "  " if is_official else "❓"

            print(f"{i:2d}. {status} '{predicate}' → Total: {total_count} "
                  f"(Orig: {count_orig}, Plus: {count_plus})")

    def analyze_mapping_needs(self):
        """Classify each predicate as official, mappable, or ambiguous, and print the result.

        Returns:
            A dict with keys ``'official'`` (``(predicate, count)`` tuples),
            ``'need_mapping'`` (``(predicate, target, count)`` tuples) and
            ``'unclear'`` (``(predicate, count)`` tuples).
        """
        results = self.analyze_predicates()

        print(f"\n  ANALYSE DU MAPPING NECESSAIRE")
        print("="*60)

        official_predicates = []
        need_mapping = []
        unclear = []

        for predicate in results['all_unique_predicates']:
            count_orig = results['predicates_original'].get(predicate, 0)
            count_plus = results['predicates_plus'].get(predicate, 0)
            total_count = count_orig + count_plus

            if predicate in self.official_spatialsense_relations:
                official_predicates.append((predicate, total_count))
            else:
                mapping = self._suggest_mapping(predicate)
                if mapping:
                    need_mapping.append((predicate, mapping, total_count))
                else:
                    unclear.append((predicate, total_count))

        # Each group is printed most-frequent first.
        print(f"\n   PREDICATS OFFICIELS SPATIALSENSE+ ({len(official_predicates)}):")
        for pred, count in sorted(official_predicates, key=lambda x: x[1], reverse=True):
            print(f"   '{pred}': {count} occurrences")

        print(f"\n   PREDICATS NECESSITANT UN MAPPING ({len(need_mapping)}):")
        for pred, mapping, count in sorted(need_mapping, key=lambda x: x[2], reverse=True):
            print(f"   '{pred}' → '{mapping}': {count} occurrences")

        print(f"\n❓ PREDICATS AMBIGUS ({len(unclear)}):")
        for pred, count in sorted(unclear, key=lambda x: x[1], reverse=True):
            print(f"   '{pred}': {count} occurrences")

        return {
            'official': official_predicates,
            'need_mapping': need_mapping,
            'unclear': unclear
        }

    def _suggest_mapping(self, predicate):
        """Suggest an official relation for a non-official predicate.

        Args:
            predicate: Raw predicate string; matched case-insensitively after
                stripping surrounding whitespace.

        Returns:
            The suggested official relation, or ``None`` when no synonym is
            known (the caller then reports the predicate as ambiguous).
        """
        predicate_lower = predicate.lower().strip()

        # Synonym table. 'far' and 'outside' are deliberately absent: they do
        # not mean "close to", so they are surfaced as ambiguous for manual
        # review instead of being silently folded into 'next to'.
        mapping_suggestions = {
            'over': 'above',
            'on top of': 'on',  # implies contact, which is what 'on' denotes
            'below': 'under',
            'beneath': 'under',
            'underneath': 'under',
            'upon': 'on',
            'front': 'in front of',
            'in front': 'in front of',
            'beside': 'next to',
            'adjacent': 'next to',
            'adjacent to': 'next to',
            'left of': 'to the left of',
            'right of': 'to the right of',
            'left': 'to the left of',
            'right': 'to the right of',
            'to the left': 'to the left of',
            'to the right': 'to the right of',
            'inside': 'in',
            'within': 'in',
            'near': 'next to',
            'close to': 'next to',
            'nearby': 'next to',
            'surrounding': 'next to',
            'between': 'next to',
            'touching': 'on',
            'against': 'on'
        }

        return mapping_suggestions.get(predicate_lower, None)

    def visualize_predicate_distribution(self):
        """Plot the top-20 predicate frequencies and the category breakdown.

        Saves the figure to ``predicates_analysis.png`` and shows it. Returns
        early, without creating a figure, when no predicate was extracted.
        """
        # Local import: only needed to build explicit legend handles.
        from matplotlib.patches import Patch

        results = self.analyze_predicates()

        all_counts = {
            pred: results['predicates_original'].get(pred, 0)
            + results['predicates_plus'].get(pred, 0)
            for pred in results['all_unique_predicates']
        }

        top_predicates = sorted(all_counts.items(), key=lambda x: x[1], reverse=True)[:20]
        if not top_predicates:
            # zip(*[]) below would fail to unpack; there is nothing to plot anyway.
            print("       Aucun prédicat à visualiser")
            return

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))

        # 1. Bar chart of the most frequent predicates.
        predicates, counts = zip(*top_predicates)
        colors = ['green' if pred in self.official_spatialsense_relations else 'orange'
                 for pred in predicates]

        bars = ax1.bar(range(len(predicates)), counts, color=colors, alpha=0.7)
        ax1.set_xlabel('Prédicats')
        ax1.set_ylabel('Nombre d\'occurrences')
        ax1.set_title('Top 20 Prédicats par Fréquence')
        ax1.set_xticks(range(len(predicates)))
        ax1.set_xticklabels(predicates, rotation=45, ha='right')

        # Write each count above its bar.
        for bar, count in zip(bars, counts):
            height = bar.get_height()
            ax1.text(bar.get_x() + bar.get_width()/2., height + 5,
                    f'{count}', ha='center', va='bottom')

        # Explicit handles: a bare label list would be matched against the
        # single bar container and yield only one legend entry.
        ax1.legend(
            handles=[
                Patch(facecolor='green', alpha=0.7, label='SpatialSense+ Officiel'),
                Patch(facecolor='orange', alpha=0.7, label='Nécessite Mapping'),
            ],
            loc='upper right'
        )

        # 2. Pie chart of the number of predicates per category.
        mapping_analysis = self.analyze_mapping_needs()

        sizes = [
            len(mapping_analysis['official']),
            len(mapping_analysis['need_mapping']),
            len(mapping_analysis['unclear'])
        ]
        labels = ['Officiels SpatialSense+', 'Nécessitent Mapping', 'Ambigus']
        colors = ['green', 'orange', 'red']

        ax2.pie(sizes, labels=labels, colors=colors, autopct='%1.1f%%', startangle=90)
        ax2.set_title('Répartition des Types de Prédicats')

        plt.tight_layout()
        plt.savefig('predicates_analysis.png', dpi=300, bbox_inches='tight')
        plt.show()

    def generate_mapping_code(self):
        """Print a Python dict literal mapping every predicate to an official relation.

        Ambiguous predicates default to ``'next to'`` and are tagged with a TODO.
        """
        mapping_analysis = self.analyze_mapping_needs()

        print(f"\n   CODE DE MAPPING GENERE AUTOMATIQUEMENT:")
        print("="*70)
        print("spatialsense_exact_mapping = {")

        # repr() keeps the generated literal valid even if a predicate contains a quote.
        print("    # Relations officielles SpatialSense+")
        for pred, _ in mapping_analysis['official']:
            print(f"    {pred!r}: {pred!r},")

        print("\n    # Variantes nécessitant un mapping")
        for pred, mapping, _ in mapping_analysis['need_mapping']:
            print(f"    {pred!r}: {mapping!r},")

        print("\n    # Relations ambigües (mapping par défaut)")
        for pred, _ in mapping_analysis['unclear']:
            print(f"    {pred!r}: 'next to',  # TODO: vérifier ce mapping")

        print("}")

    def print_sample_data(self):
        """Print a small sample of both annotation files to reveal their structure."""
        print("\n=== ECHANTILLON DES DONNEES ===")

        print("\n   Structure annotations.json:")
        if isinstance(self.annotations, list):
            if self.annotations:
                print("Premier élément:")
                sample = self.annotations[0]
                print(f"  Clés: {list(sample.keys()) if isinstance(sample, dict) else 'Non-dict'}")
                if isinstance(sample, dict) and sample.get('annotations'):
                    ann_sample = sample['annotations'][0]
                    if isinstance(ann_sample, dict):
                        print(f"  Première annotation - Clés: {list(ann_sample.keys())}")
                        if 'predicate' in ann_sample:
                            print(f"  Exemple prédicat: '{ann_sample['predicate']}'")
        elif isinstance(self.annotations, dict):
            # The loader accepts a non-list root; indexing it with [0] would raise.
            print(f"  Type: Dictionnaire avec clés: {list(self.annotations.keys())}")
        elif self.annotations:
            print(f"  Type: {type(self.annotations)}")

        print("\n   Structure annots_spatialsenseplus.json:")
        if self.annotations_plus:
            if isinstance(self.annotations_plus, dict):
                print(f"  Type: Dictionnaire avec clés: {list(self.annotations_plus.keys())}")
            elif isinstance(self.annotations_plus, list):
                print(f"  Type: Liste avec {len(self.annotations_plus)} éléments")
            else:
                print(f"  Type: {type(self.annotations_plus)}")

def main():
    """Run the full predicate analysis: structure dump, listing, mapping, plots."""
    print("   ANALYSEUR COMPLET DES PREDICATS SPATIALSENSE")
    print("="*80)

    analyzer = SpatialSensePredicateAnalyzer()

    # 1. Show the structure of the loaded data
    analyzer.print_sample_data()

    # 2. List every predicate
    analyzer.display_all_predicates()

    # 3. Work out which predicates need a mapping
    analyzer.analyze_mapping_needs()

    # 4. Print the generated mapping code
    analyzer.generate_mapping_code()

    # 5. Build the figures; plotting is best-effort and must not abort the analysis
    figure_path = 'predicates_analysis.png'
    try:
        analyzer.visualize_predicate_distribution()
    except Exception as e:
        print(f"       Erreur lors de la visualisation: {e}")
        figure_saved = False
    else:
        figure_saved = os.path.exists(figure_path)

    print(f"\n   Analyse terminée!")
    # Only announce the figure when plotting succeeded and the file is there.
    if figure_saved:
        print(f"   Fichier généré: {figure_path}")

if __name__ == "__main__":
    main()

In [None]:
import json
import os
from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib.patches as patches
import numpy as np

class SpatialVisualizer:
    """Display SpatialSense images with the bounding boxes of their annotated relations."""

    def __init__(self, data_dir="data/spatialsense"):
        """Load ``annotations.json`` from ``data_dir``.

        Raises:
            OSError / ValueError: if the file is missing or is not valid JSON.
        """
        self.data_dir = data_dir
        self.annotations = self._load_annotations()

    def _load_annotations(self):
        """Read and return the parsed content of ``annotations.json``."""
        with open(os.path.join(self.data_dir, "annotations.json"), 'r', encoding='utf-8') as f:
            return json.load(f)

    def _find_image_path(self, image_url):
        """Map an annotation URL to the local path of the extracted image.

        URLs containing "staticflickr", or whose file name has exactly one
        underscore, resolve to the ``flickr`` folder; all others to ``nyu``.
        """
        base_dir = os.path.join(self.data_dir, "images/images")
        filename = os.path.basename(image_url)

        if "staticflickr" in image_url or len(filename.split('_')) == 2:
            return os.path.join(base_dir, "flickr", filename)
        else:
            return os.path.join(base_dir, "nyu", filename)

    def _draw_bbox(self, ax, bbox, color, name, obj_type):
        """Draw one labelled bounding box on ``ax``.

        Args:
            ax: Matplotlib axes to draw on.
            bbox: Box coordinates, unpacked as ``[y1, y2, x1, x2]``.
            color: Edge and label colour.
            name: Object name shown in the label.
            obj_type: Role prefix of the label (e.g. ``'subject'``).

        Drawing errors are reported on stdout and otherwise ignored.
        """
        try:
            y1, y2, x1, x2 = bbox
            width = x2 - x1
            height = y2 - y1

            # Rectangle anchored at (x1, y1).
            rect = patches.Rectangle(
                (x1, y1),
                width,
                height,
                linewidth=2,
                edgecolor=color,
                facecolor='none'
            )
            ax.add_patch(rect)

            # Label just above the box, drawn on the axes that was passed in
            # (not on whatever axes pyplot currently considers active).
            ax.text(
                x1, y1 - 5,
                f"{obj_type}: {name}",
                color=color,
                fontsize=10,
                bbox=dict(facecolor='white', alpha=0.7, edgecolor='none')
            )
        except Exception as e:
            print(f"Erreur lors du dessin de la bbox: {e}")
            print(f"bbox: {bbox}")

    def visualize_sample(self, index=0):
        """Show one figure per positively-labelled relation of image ``index``.

        Args:
            index: Position of the image record in the loaded annotations.
        """
        image_data = self.annotations[index]
        print(f"Image URL: {image_data['url']}")
        print(f"Dimensions: {image_data['width']}x{image_data['height']}")
        print(f"Split: {image_data['split']}")

        img_path = self._find_image_path(image_data['url'])
        print(f"\nChemin de l'image: {img_path}")

        # Load the pixels eagerly so the file handle can be released right away.
        try:
            with Image.open(img_path) as img:
                img_array = np.array(img)
        except Exception as e:
            print(f"Erreur lors du chargement de l'image: {e}")
            return

        valid_annotations = [
            ann for ann in image_data.get('annotations', []) if ann.get('label')
        ]
        print(f"Nombre de relations valides: {len(valid_annotations)}")

        for i, ann in enumerate(valid_annotations):
            fig, ax = plt.subplots(figsize=(12, 8))
            try:
                ax.imshow(img_array)

                # Raw coordinates, printed for debugging
                print(f"\nRelation {i+1}:")
                print(f"Subject bbox: {ann['subject']['bbox']}")
                print(f"Object bbox: {ann['object']['bbox']}")

                # Subject in red
                subject = ann['subject']
                self._draw_bbox(ax, subject['bbox'], 'red', subject['name'], 'subject')

                # Object in blue
                object_ = ann['object']
                self._draw_bbox(ax, object_['bbox'], 'blue', object_['name'], 'object')

                ax.set_title(f"Relation {i+1}: {subject['name']} → {ann['predicate']} → {object_['name']}")
                ax.axis('off')
                fig.tight_layout()
                plt.show()

                print(f"Prédicat: {ann['predicate']}")
                print(f"Sujet: {subject['name']} à ({subject['x']}, {subject['y']})")
                print(f"Objet: {object_['name']} à ({object_['x']}, {object_['y']})")
                print("---")
            except Exception as e:
                # A malformed annotation should not hide the remaining relations.
                print(f"Erreur lors de la visualisation de la relation {i+1}: {e}")
            finally:
                # One figure per relation: release it once displayed.
                plt.close(fig)

# Usage: build the visualizer and show the relations of the first annotated image
visualizer = SpatialVisualizer()
print("=== Visualisation de l'exemple ===")
visualizer.visualize_sample(0)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import json
import os
from sklearn.model_selection import KFold
from sklearn.metrics import confusion_matrix, classification_report
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import cv2
from collections import Counter, defaultdict
import warnings
warnings.filterwarnings('ignore')

# =============================================================================
# GLOBAL CONFIGURATION
# =============================================================================

# GPU configuration: use CUDA when available, otherwise fall back to the CPU
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device utilisé: {DEVICE}")

# Hyperparameters (stated by the original author to follow Haldekar et al. 2017)
BATCH_SIZE = 10  # As in the paper
LEARNING_RATE = 0.001  # AdamOptimizer with lr=0.001
EPOCHS = 10
K_FOLDS = 5
IMG_SIZE = 224  # Standard input size for VGGNet (as in the paper)
DROPOUT_RATE = 0.5  # As in the paper

# The exact SpatialSense+ spatial relations, per the table the author refers to
SPATIAL_RELATIONS = [
    'above',        # Higher position along the direction of gravity
    'behind',       # Subject is deeper in the scene than the object
    'in',          # Subject inside the object (semi-enclosed)
    'in front of', # Subject is less deep in the scene than the object
    'next to',     # Subject close to the object, with no obstacle between them
    'on',          # Subject on top of the object, in contact with it
    'to the left of',  # Subject to the left of the object (labeller's viewpoint)
    'to the right of', # Subject to the right of the object (labeller's viewpoint)
    'under'        # Lower position along the direction of gravity
]

print(f"Relations SpatialSense+ configurées: {len(SPATIAL_RELATIONS)}")
print("Relations supportées:")
for i, rel in enumerate(SPATIAL_RELATIONS):
    print(f"  {i}: {rel}")

# =============================================================================
# DATASET SPATIAL SENSE+ ADAPTÉ
# =============================================================================

class SpatialSenseDataset(Dataset):
    """PyTorch dataset of positively-labelled SpatialSense relations.

    Each sample is a whole image paired with the index of one annotated
    relation from ``SPATIAL_RELATIONS``. The original author describes the
    setup as following Haldekar et al. 2017.
    """

    def __init__(self, data_dir, split='train', transform=None):
        """Load the annotations and build the sample list for one split.

        Args:
            data_dir: Root directory containing ``annotations.json`` and the images.
            split: Split name compared against each image record's ``'split'`` field.
            transform: Optional callable applied to every loaded PIL image.
        """
        self.data_dir = data_dir
        self.split = split
        self.transform = transform

        # Label maps must exist BEFORE samples are prepared: preparation uses
        # them to reject predicates that have no class index.
        self.relation_to_idx = {rel: idx for idx, rel in enumerate(SPATIAL_RELATIONS)}
        self.idx_to_relation = {idx: rel for rel, idx in self.relation_to_idx.items()}

        self.annotations = self._load_annotations()
        self.data_samples = self._prepare_samples()

        print(f"Dataset SpatialSense+ {split} initialisé avec {len(self.data_samples)} échantillons")
        if len(self.data_samples) > 0:
            self._print_statistics()

    def _load_annotations(self):
        """Load ``annotations.json``; return ``[]`` (after reporting) on failure."""
        annotations_path = os.path.join(self.data_dir, 'annotations.json')
        try:
            with open(annotations_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except FileNotFoundError:
            print(f"Erreur: Fichier {annotations_path} non trouvé!")
            print("Vérifiez que le dataset SpatialSense+ est dans le bon répertoire")
            return []
        except json.JSONDecodeError as e:
            print(f"Erreur lors du décodage JSON: {e}")
            return []

    def _find_image_path(self, image_url):
        """Map an annotation URL to the local path of the extracted image.

        URLs containing "staticflickr", or whose file name has exactly one
        underscore, resolve to the ``flickr`` folder; all others to ``nyu``.
        """
        base_dir = os.path.join(self.data_dir, "images", "images")
        filename = os.path.basename(image_url)

        if "staticflickr" in image_url or len(filename.split('_')) == 2:
            return os.path.join(base_dir, "flickr", filename)
        else:
            return os.path.join(base_dir, "nyu", filename)

    def _should_skip_sample_spatialsense(self, annotation, relation):
        """Placeholder for the SpatialSense+ filtering criteria; currently keeps everything.

        Per the original author's notes, a full implementation would skip
        'behind' / 'in front of' samples whose objects lie in very different
        directions, and 'to the left of' / 'to the right of' samples whose
        depth or height difference is too large.

        Returns:
            Always ``False``.
        """
        return False

    def _prepare_samples(self):
        """Build the list of samples for ``self.split``.

        Keeps only positively-labelled annotations whose image exists on disk
        and whose predicate has a class index in ``self.relation_to_idx``.

        Returns:
            A list of dicts with keys ``image_path``, ``subject``, ``object``,
            ``relation``, ``original_relation``, ``subject_bbox``, ``object_bbox``.
        """
        samples = []
        images_not_found = 0
        relations_mapped = Counter()
        relations_unsupported = Counter()
        relations_filtered = 0

        for img_data in self.annotations:
            if img_data.get('split') != self.split:
                continue

            img_path = self._find_image_path(img_data['url'])

            if not os.path.exists(img_path):
                images_not_found += 1
                continue

            for ann in img_data.get('annotations', []):
                if not ann.get('label'):
                    continue  # negative annotation

                original_relation = ann['predicate']
                relation = original_relation

                # A predicate without a class index cannot be turned into a
                # label; dropping it here avoids a KeyError in __getitem__.
                if relation not in self.relation_to_idx:
                    relations_unsupported[relation] += 1
                    continue

                relations_mapped[f"{original_relation} → {relation}"] += 1

                if self._should_skip_sample_spatialsense(ann, relation):
                    relations_filtered += 1
                    continue

                samples.append({
                    'image_path': img_path,
                    'subject': ann['subject']['name'],
                    'object': ann['object']['name'],
                    'relation': relation,
                    'original_relation': original_relation,
                    'subject_bbox': ann['subject'].get('bbox', None),
                    'object_bbox': ann['object'].get('bbox', None)
                })

        # Preprocessing report
        if images_not_found > 0:
            print(f"       Images non trouvées: {images_not_found}")
        if relations_filtered > 0:
            print(f"       Relations filtrées (critères SpatialSense+): {relations_filtered}")
        if relations_unsupported:
            print(f"       Relations hors SpatialSense+ ignorées: {dict(relations_unsupported)}")

        print(f"  Mapping des relations appliqué (top 5):")
        for mapping, count in relations_mapped.most_common(5):
            print(f"   {mapping}: {count}")

        return samples

    def _print_statistics(self):
        """Print the per-relation sample counts and percentages for this split."""
        relation_counts = Counter([s['relation'] for s in self.data_samples])
        print(f"\n   Distribution SpatialSense+ dans {self.split}:")
        total = len(self.data_samples)

        # Listed in the order of SPATIAL_RELATIONS
        for relation in SPATIAL_RELATIONS:
            count = relation_counts.get(relation, 0)
            percentage = count / total * 100 if total > 0 else 0
            print(f"   {relation}: {count} ({percentage:.1f}%)")

        # Safety net in case data_samples was modified after preparation
        unexpected = set(relation_counts.keys()) - set(SPATIAL_RELATIONS)
        if unexpected:
            print(f"\n       Relations inattendues détectées: {unexpected}")

    def __len__(self):
        """Return the number of prepared samples."""
        return len(self.data_samples)

    def __getitem__(self, idx):
        """Return ``(image, label, metadata)`` for sample ``idx``.

        The image is loaded as RGB and passed through ``self.transform`` when
        one is set. If loading or transforming fails, the error is printed and
        a gray placeholder is returned with label 0 and metadata whose
        ``subject`` / ``object`` / ``original_relation`` are ``'error'``.
        """
        sample = self.data_samples[idx]

        try:
            # convert() returns an independent, fully loaded copy, so the file
            # handle can be closed as soon as the with-block exits.
            with Image.open(sample['image_path']) as img:
                image = img.convert('RGB')

            if self.transform:
                image = self.transform(image)

            label = self.relation_to_idx[sample['relation']]

            metadata = {
                'subject': sample['subject'],
                'object': sample['object'],
                'relation': sample['relation'],
                'original_relation': sample['original_relation']
            }

            return image, label, metadata

        except Exception as e:
            print(f"    Erreur chargement {sample['image_path']}: {e}")
            # Best-effort placeholder so a single bad file does not stop a DataLoader
            dummy_image = Image.new('RGB', (224, 224), color='gray')
            if self.transform:
                dummy_image = self.transform(dummy_image)
            else:
                dummy_image = torch.zeros(3, 224, 224)

            # Keep the metadata relation consistent with the returned label.
            fallback_label = 0
            return dummy_image, fallback_label, {
                'subject': 'error', 'object': 'error',
                'relation': self.idx_to_relation[fallback_label],
                'original_relation': 'error'
            }

# =============================================================================
# ARCHITECTURE HALDEKAR ET AL. 2017 - EXACTE
# =============================================================================

class VGGFeatureExtractor(nn.Module):
    """Frozen VGG16 backbone that outputs 4096-dimensional FC-7 features.

    Built from an ImageNet-pretrained VGG16; the original author describes it
    as following Haldekar et al. 2017. All parameters are frozen and the
    module is pinned to eval mode, so features are deterministic.
    """

    def __init__(self):
        super().__init__()

        # torchvision >= 0.13 selects weights through an enum; `pretrained=True`
        # is the deprecated spelling of the same ImageNet weights.
        if hasattr(models, 'VGG16_Weights'):
            vgg16 = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
        else:
            vgg16 = models.vgg16(pretrained=True)

        print("   Initialisation VGG Feature Extractor (Haldekar-style):")
        print("   - Modèle: VGG16 pré-entraîné sur ImageNet")
        print("   - Features: Couches convolutionnelles + pooling")
        print("   - Output: FC-7 (4096 dimensions)")

        # Convolutional trunk and the pooling layer that precedes the classifier
        self.features = vgg16.features
        self.avgpool = vgg16.avgpool

        # VGG16 classifier: FC1(4096) -> ReLU -> Dropout -> FC2(4096) -> ReLU -> Dropout -> FC3(1000)
        # The first six layers are kept, i.e. everything up to and including
        # the Dropout that follows FC2.
        classifier_layers = list(vgg16.classifier.children())[:6]
        self.fc7 = nn.Sequential(*classifier_layers)

        # Freeze every weight: this module is a pure feature extractor.
        for param in self.parameters():
            param.requires_grad = False

        # The kept layers include Dropout; in train mode they would randomly
        # zero the "frozen" features, so the module starts in eval mode.
        self.eval()

        print("   - Paramètres gelés:    (feature extraction seulement)")

    def train(self, mode=True):
        """Ignore ``mode`` and stay in eval mode.

        A parent module's ``train()`` call propagates to its children; without
        this override it would re-enable the Dropout layers of the frozen
        extractor.

        Returns:
            ``self``, as ``nn.Module.train`` does.
        """
        return super().train(False)

    def forward(self, x):
        """Return the FC-7 features of the image batch ``x``."""
        x = self.features(x)      # convolutions + pooling
        x = self.avgpool(x)       # average pooling
        x = torch.flatten(x, 1)   # flatten everything but the batch dimension
        x = self.fc7(x)           # FC-7 (4096 dimensions)
        return x

class SpatialRelationMLP(nn.Module):
    """Three-layer classification head placed on top of the FC-7 features.

    Layout: input_dim -> hidden1_dim -> hidden2_dim -> num_relations, with a
    ReLU and a Dropout after each hidden layer. The original author describes
    it as the MLP of Haldekar et al. 2017.
    """

    def __init__(self, input_dim=4096, hidden1_dim=512, hidden2_dim=256,
                 num_relations=len(SPATIAL_RELATIONS), dropout_rate=0.5):
        super().__init__()

        # Configuration banner, one line per setting
        banner = (
            f"   Initialisation MLP Haldekar (architecture exacte):",
            f"   - Input: {input_dim} (FC-7 VGG)",
            f"   - Hidden 1 (FC-0): {hidden1_dim}",
            f"   - Hidden 2 (FC-1): {hidden2_dim}",
            f"   - Output: {num_relations} relations SpatialSense+",
            f"   - Dropout: {dropout_rate}",
            f"   - Activation: ReLU",
        )
        for line in banner:
            print(line)

        # First hidden stage
        self.fc1 = nn.Linear(input_dim, hidden1_dim)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout_rate)

        # Second hidden stage
        self.fc2 = nn.Linear(hidden1_dim, hidden2_dim)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout_rate)

        # Output layer producing one logit per relation
        self.fc3 = nn.Linear(hidden2_dim, num_relations)

    def forward(self, x):
        """Return the relation logits for the feature batch ``x``."""
        hidden = self.dropout1(self.relu1(self.fc1(x)))
        hidden = self.dropout2(self.relu2(self.fc2(hidden)))
        return self.fc3(hidden)

class SpatialRelationModel(nn.Module):
    """
    Full model: VGG feature extractor followed by the MLP classifier.

    The feature extractor is always run under ``torch.no_grad()`` in
    ``forward``, so only the classifier receives gradients.
    """

    def __init__(self, num_relations=len(SPATIAL_RELATIONS)):
        super().__init__()

        print("\n".join([
            f"    Modèle Haldekar et al. 2017 pour SpatialSense+:",
            f"   - Dataset: SpatialSense+ ({num_relations} relations)",
            f"   - Architecture: VGG16 + MLP",
            f"   - Comparaison: Article original sur SUN09 (3 relations)",
        ]))

        # VGG feature extractor, then the trainable MLP head
        self.feature_extractor = VGGFeatureExtractor()
        self.classifier = SpatialRelationMLP(num_relations=num_relations)

    def forward(self, x):
        """Return relation logits for a batch of images."""
        # No autograd graph is built for the feature extractor.
        with torch.no_grad():
            fc7_features = self.feature_extractor(x)

        # The classifier runs with gradients enabled.
        return self.classifier(fc7_features)

    def get_features(self, x):
        """Utility accessor returning the feature extractor output for ``x``."""
        return self.feature_extractor(x)

# =============================================================================
# TRANSFORMATIONS (identiques à l'article)
# =============================================================================

def create_transforms(horizontal_flip=False):
    """
    Build the (train, validation) preprocessing pipelines for the VGG input.

    Both pipelines resize to ``IMG_SIZE`` x ``IMG_SIZE``, convert to a tensor
    and apply the standard ImageNet normalisation.

    Args:
        horizontal_flip: when True, a ``RandomHorizontalFlip(p=0.5)`` is added
            to the training pipeline. It is off by default because the flip is
            applied to the image only: a mirrored 'to the left of' sample keeps
            its label although it now shows 'to the right of' (and vice versa),
            which injects label noise for these two classes.

    Returns:
        Tuple ``(train_transform, val_transform)``.
    """
    def _base_steps():
        # Fresh instances for each pipeline so the two Compose objects share nothing.
        return [
            transforms.Resize((IMG_SIZE, IMG_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],  # ImageNet mean
                std=[0.229, 0.224, 0.225]    # ImageNet std
            ),
        ]

    train_steps = _base_steps()
    if horizontal_flip:
        # Inserted right after the resize, i.e. still on the PIL image.
        train_steps.insert(1, transforms.RandomHorizontalFlip(p=0.5))

    train_transform = transforms.Compose(train_steps)
    val_transform = transforms.Compose(_base_steps())  # never augmented

    return train_transform, val_transform

# =============================================================================
# FONCTIONS D'ENTRAÎNEMENT (identiques, mais commentées pour SpatialSense+)
# =============================================================================

def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """
    Train ``model`` for one pass over ``dataloader``.

    Batches are expected to be sequences whose first two items are
    ``(images, labels)``; extra items (e.g. metadata) are ignored and batches
    with fewer than two items are skipped.

    Returns:
        Tuple ``(epoch_loss, epoch_acc)``: the mean of the per-batch losses over
        the batches actually processed, and the accuracy in percent.
        ``(0.0, 0.0)`` is returned when no batch could be processed.
    """
    # NOTE(review): this also switches the feature extractor to train mode;
    # confirm that is intended if it contains dropout layers.
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    processed_batches = 0  # skipped batches must not dilute the average loss

    progress_bar = tqdm(dataloader, desc='Training SpatialSense+')
    for batch_data in progress_bar:
        # Flexible batch format: (images, labels) or (images, labels, metadata, ...)
        if len(batch_data) < 2:
            continue
        images, labels = batch_data[0], batch_data[1]
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Statistics
        processed_batches += 1
        running_loss += loss.item()
        predicted = outputs.detach().argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        progress_bar.set_postfix({
            'loss': running_loss / processed_batches,
            'acc': 100. * correct / max(total, 1)
        })

    # An empty loader (or only malformed batches) must not raise ZeroDivisionError.
    if processed_batches == 0 or total == 0:
        return 0.0, 0.0

    epoch_loss = running_loss / processed_batches
    epoch_acc = 100. * correct / total

    return epoch_loss, epoch_acc

def evaluate(model, dataloader, criterion, device):
    """
    Evaluate ``model`` on ``dataloader`` without computing gradients.

    Batches are expected to be sequences whose first two items are
    ``(images, labels)``; extra items (e.g. metadata) are ignored and batches
    with fewer than two items are skipped.

    Returns:
        Tuple ``(avg_loss, accuracy)``: the mean of the per-batch losses over
        the batches actually processed, and the accuracy in percent.
        ``(0.0, 0.0)`` is returned when no batch could be processed.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    processed_batches = 0  # skipped batches must not dilute the average loss

    with torch.no_grad():
        for batch_data in tqdm(dataloader, desc='Evaluating SpatialSense+'):
            # Flexible batch format: (images, labels) or (images, labels, metadata, ...)
            if len(batch_data) < 2:
                continue
            images, labels = batch_data[0], batch_data[1]
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            processed_batches += 1
            running_loss += loss.item()
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # An empty loader (or only malformed batches) must not raise ZeroDivisionError.
    if processed_batches == 0 or total == 0:
        return 0.0, 0.0

    avg_loss = running_loss / processed_batches
    accuracy = 100. * correct / total

    return avg_loss, accuracy

# =============================================================================
# NOUVELLE FONCTION: MATRICE DE CONFUSION ET ANALYSE D'ERREURS
# =============================================================================

def _metadata_for_sample(metadata, index, true_idx):
    """Return the metadata dict of sample ``index`` whatever the batch layout.

    ``metadata`` may be a list/tuple with one dict per sample, a dict of
    per-sample sequences, or anything else (then a placeholder built from the
    true label is returned).
    """
    fallback = {
        'subject': 'unknown',
        'object': 'unknown',
        'relation': SPATIAL_RELATIONS[true_idx],
    }
    if isinstance(metadata, (list, tuple)):
        return metadata[index] if index < len(metadata) else fallback
    if isinstance(metadata, dict):
        return {
            key: (metadata[key][index] if key in metadata else default)
            for key, default in fallback.items()
        }
    return fallback


def generate_confusion_matrix_and_analysis(model, dataloader, device, save_prefix="spatialsense"):
    """
    Run the model over ``dataloader``, plot a 2x2 diagnostic figure and print a
    per-relation classification report.

    The figure (saved as ``{save_prefix}_confusion_matrix_analysis.png``) shows
    the absolute and row-normalised confusion matrices, the ten most frequent
    (true -> predicted) errors and the confidence histograms of correct versus
    wrong predictions. Everything is computed from a single inference pass.

    Batches may be ``(images, labels, metadata)`` or start with
    ``(images, labels)``; missing metadata is replaced by placeholders.

    Returns:
        Tuple ``(cm, error_examples, report)`` where ``cm`` is the confusion
        matrix over all ``SPATIAL_RELATIONS`` classes, ``error_examples`` maps
        ``(true_relation, predicted_relation)`` to a list of error records and
        ``report`` is the ``classification_report`` dict (always containing an
        ``'accuracy'`` entry). When the loader yields no sample, a zero matrix
        and an empty report are returned and nothing is plotted.
    """
    model.eval()

    # Fix the class set explicitly so that matrix rows/columns always line up
    # with SPATIAL_RELATIONS, even if a relation is absent from this loader.
    class_ids = list(range(len(SPATIAL_RELATIONS)))

    all_predictions = []
    all_true_labels = []
    error_examples = defaultdict(list)  # (true_rel, pred_rel) -> list of error records
    correct_confidences = []
    error_confidences = []

    print("   Génération matrice de confusion et analyse d'erreurs...")

    with torch.no_grad():
        for batch_data in tqdm(dataloader, desc='Analyzing predictions'):
            # Flexible batch format
            if len(batch_data) == 3:
                images, labels, metadata = batch_data
            else:
                images, labels = batch_data[:2]
                metadata = None  # placeholders are generated per sample below

            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            probabilities = torch.softmax(outputs, dim=1)
            predicted = outputs.argmax(dim=1)

            for i in range(len(labels)):
                pred_idx = predicted[i].item()
                true_idx = labels[i].item()
                confidence = probabilities[i, pred_idx].item()

                all_predictions.append(pred_idx)
                all_true_labels.append(true_idx)

                if pred_idx == true_idx:
                    correct_confidences.append(confidence)
                    continue

                # Keep the details of every misclassification
                error_confidences.append(confidence)
                true_rel = SPATIAL_RELATIONS[true_idx]
                pred_rel = SPATIAL_RELATIONS[pred_idx]
                error_examples[(true_rel, pred_rel)].append({
                    'metadata': _metadata_for_sample(metadata, i, true_idx),
                    'predicted_relation': pred_rel,
                    'true_relation': true_rel,
                    'confidence': confidence,
                    'probabilities': probabilities[i].cpu().numpy()
                })

    if not all_true_labels:
        print("   Aucun échantillon à analyser")
        empty_cm = np.zeros((len(class_ids), len(class_ids)), dtype=int)
        return empty_cm, error_examples, {}

    # Confusion matrix over the full, fixed label set
    cm = confusion_matrix(all_true_labels, all_predictions, labels=class_ids)

    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(20, 16))

    # 1. Confusion matrix, absolute counts
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=SPATIAL_RELATIONS, yticklabels=SPATIAL_RELATIONS,
                ax=ax1, cbar_kws={'shrink': 0.8})
    ax1.set_title('Matrice de Confusion - Nombres Absolus', fontsize=14, fontweight='bold')
    ax1.set_xlabel('Prédictions')
    ax1.set_ylabel('Vérité Terrain')
    ax1.tick_params(axis='x', rotation=45)
    ax1.tick_params(axis='y', rotation=0)

    # 2. Row-normalised confusion matrix; rows without samples stay at zero
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_normalized = np.divide(cm.astype('float'), row_totals,
                              out=np.zeros(cm.shape, dtype=float),
                              where=row_totals != 0)

    sns.heatmap(cm_normalized, annot=True, fmt='.2f', cmap='Reds',
                xticklabels=SPATIAL_RELATIONS, yticklabels=SPATIAL_RELATIONS,
                ax=ax2, cbar_kws={'shrink': 0.8})
    ax2.set_title('Matrice de Confusion - Normalisée par Ligne', fontsize=14, fontweight='bold')
    ax2.set_xlabel('Prédictions')
    ax2.set_ylabel('Vérité Terrain')
    ax2.tick_params(axis='x', rotation=45)
    ax2.tick_params(axis='y', rotation=0)

    # 3. Most frequent (true -> predicted) errors
    error_counts = Counter({pair: len(examples) for pair, examples in error_examples.items()})
    top_errors = error_counts.most_common(10)
    if top_errors:
        error_pairs, error_freqs = zip(*top_errors)
        error_labels = [f"{true_rel}\n→\n{pred_rel}" for true_rel, pred_rel in error_pairs]

        bars = ax3.bar(range(len(error_labels)), error_freqs, color='lightcoral', alpha=0.8)
        ax3.set_title('Top 10 des Erreurs les Plus Fréquentes', fontsize=14, fontweight='bold')
        ax3.set_xlabel('Type d\'Erreur (Vrai → Prédit)')
        ax3.set_ylabel('Nombre d\'Erreurs')
        ax3.set_xticks(range(len(error_labels)))
        ax3.set_xticklabels(error_labels, rotation=45, ha='right')
        ax3.grid(True, axis='y', alpha=0.3)

        # Value label above each bar
        for bar, freq in zip(bars, error_freqs):
            height = bar.get_height()
            ax3.annotate(f'{freq}',
                        xy=(bar.get_x() + bar.get_width() / 2, height),
                        xytext=(0, 3), textcoords="offset points",
                        ha='center', va='bottom', fontsize=9)

    # 4. Confidence distributions (collected during the single pass above)
    if correct_confidences:
        ax4.hist(correct_confidences, bins=30, alpha=0.7, label='Prédictions Correctes',
                 color='lightgreen', density=True)
    if error_confidences:
        ax4.hist(error_confidences, bins=30, alpha=0.7, label='Erreurs',
                 color='lightcoral', density=True)
    ax4.set_title('Distribution des Confiances', fontsize=14, fontweight='bold')
    ax4.set_xlabel('Confiance de Prédiction')
    ax4.set_ylabel('Densité')
    ax4.legend()
    ax4.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(f'{save_prefix}_confusion_matrix_analysis.png', dpi=300, bbox_inches='tight')
    plt.show()

    # Detailed classification report
    print("\n" + "="*80)
    print("   RAPPORT DE CLASSIFICATION DÉTAILLÉ")
    print("="*80)

    # labels= keeps target_names aligned even when a class is missing from the data
    report = classification_report(all_true_labels, all_predictions,
                                   labels=class_ids,
                                   target_names=SPATIAL_RELATIONS,
                                   output_dict=True, zero_division=0)
    # Some scikit-learn versions emit 'micro avg' instead of 'accuracy' when
    # labels= is passed explicitly; make sure the key callers expect exists.
    report.setdefault('accuracy', float(np.trace(cm)) / float(cm.sum()))

    print(f"{'Relation':<15} {'Precision':<10} {'Recall':<10} {'F1-Score':<10} {'Support':<10}")
    print("-" * 65)

    for relation in SPATIAL_RELATIONS:
        if relation in report:
            metrics = report[relation]
            print(f"{relation:<15} {metrics['precision']:<10.3f} {metrics['recall']:<10.3f} "
                  f"{metrics['f1-score']:<10.3f} {int(metrics['support']):<10}")

    print("-" * 65)
    print(f"{'Accuracy':<15} {'':<10} {'':<10} {report['accuracy']:<10.3f} {int(report['macro avg']['support']):<10}")
    print(f"{'Macro avg':<15} {report['macro avg']['precision']:<10.3f} {report['macro avg']['recall']:<10.3f} "
          f"{report['macro avg']['f1-score']:<10.3f} {int(report['macro avg']['support']):<10}")
    print(f"{'Weighted avg':<15} {report['weighted avg']['precision']:<10.3f} {report['weighted avg']['recall']:<10.3f} "
          f"{report['weighted avg']['f1-score']:<10.3f} {int(report['weighted avg']['support']):<10}")

    return cm, error_examples, report

def analyze_error_patterns(error_examples, save_prefix="spatialsense"):
    """
    Print a detailed analysis of the misclassifications in ``error_examples``.

    Args:
        error_examples: mapping ``(true_relation, predicted_relation)`` to a list
            of error records. Each record is a dict with a ``'confidence'``
            float and a ``'metadata'`` dict; missing ``'subject'`` / ``'object'``
            metadata entries are counted under ``'unknown'``.
        save_prefix: unused here, kept for interface compatibility.

    Returns:
        Tuple ``(group_confusion, subject_errors, object_errors)``:
        ``group_confusion[true_group][pred_group]`` is the number of errors
        between semantic groups, the two others count errors per subject and per
        object. An empty ``error_examples`` is handled (all counts are zero).
    """
    print("\n" + "="*80)
    print("   ANALYSE DÉTAILLÉE DES PATTERNS D'ERREURS")
    print("="*80)

    # 1. Confusions between semantic groups
    semantic_groups = {
        'vertical': ['above', 'under', 'on'],
        'horizontal': ['to the left of', 'to the right of', 'next to'],
        'depth': ['behind', 'in front of'],
        'containment': ['in']
    }
    relation_to_group = {
        relation: group
        for group, relations in semantic_groups.items()
        for relation in relations
    }

    group_confusion = defaultdict(lambda: defaultdict(int))
    for (true_rel, pred_rel), examples in error_examples.items():
        true_group = relation_to_group.get(true_rel)
        pred_group = relation_to_group.get(pred_rel)
        # Relations outside the known groups are left out of the table
        if true_group and pred_group:
            group_confusion[true_group][pred_group] += len(examples)

    print("\n   Confusions entre Groupes Sémantiques:")
    print(f"{'Vrai→Prédit':<20} {'Vertical':<10} {'Horizontal':<12} {'Depth':<8} {'Containment':<12}")
    print("-" * 70)

    # Same widths as the header so the columns actually line up
    column_widths = {'vertical': 10, 'horizontal': 12, 'depth': 8, 'containment': 12}
    for true_group in semantic_groups:
        cells = [f"{true_group:<20}"]
        for pred_group in semantic_groups:
            count = group_confusion[true_group][pred_group]
            cells.append(f"{count:<{column_widths[pred_group]}}")
        print(" ".join(cells))

    # 2. Subjects / objects involved in the most errors
    print("\n  Objets/Sujets les Plus Problématiques:")

    subject_errors = defaultdict(int)
    object_errors = defaultdict(int)

    for examples in error_examples.values():
        for example in examples:
            metadata = example['metadata']
            subject_errors[metadata.get('subject', 'unknown')] += 1
            object_errors[metadata.get('object', 'unknown')] += 1

    print("\nTop 10 Sujets avec le Plus d'Erreurs:")
    for subject, count in Counter(subject_errors).most_common(10):
        print(f"   • {subject}: {count} erreurs")

    print("\nTop 10 Objets avec le Plus d'Erreurs:")
    for obj, count in Counter(object_errors).most_common(10):
        print(f"   • {obj}: {count} erreurs")

    # 3. Specific confusion patterns (a relation predicted as its opposite)
    print("\n   Patterns de Confusion Spécifiques:")

    directional_confusions = {
        ('to the left of', 'to the right of'): 'Inversion gauche-droite',
        ('to the right of', 'to the left of'): 'Inversion droite-gauche',
        ('above', 'under'): 'Inversion haut-bas',
        ('under', 'above'): 'Inversion bas-haut',
        ('behind', 'in front of'): 'Inversion profondeur',
        ('in front of', 'behind'): 'Inversion profondeur inverse'
    }

    confusion_analysis = {
        description: len(error_examples[pair])
        for pair, description in directional_confusions.items()
        if pair in error_examples
    }

    print("\nConfusions Directionnelles:")
    for pattern, count in sorted(confusion_analysis.items(), key=lambda x: x[1], reverse=True):
        print(f"   • {pattern}: {count} cas")

    # 4. Confidence of the wrong predictions
    print("\n       Analyse de Confiance dans les Erreurs:")

    high_confidence_errors = []
    low_confidence_errors = []

    for (true_rel, pred_rel), examples in error_examples.items():
        for example in examples:
            if example['confidence'] > 0.7:
                high_confidence_errors.append((true_rel, pred_rel, example['confidence']))
            elif example['confidence'] < 0.3:
                low_confidence_errors.append((true_rel, pred_rel, example['confidence']))

    def _print_top_pairs(entries):
        # Group the confidences per (true, predicted) pair, then list the 5 largest groups
        by_pair = defaultdict(list)
        for true_rel, pred_rel, conf in entries:
            by_pair[(true_rel, pred_rel)].append(conf)
        ranked = Counter({pair: len(confs) for pair, confs in by_pair.items()}).most_common(5)
        for (true_rel, pred_rel), count in ranked:
            avg_conf = np.mean(by_pair[(true_rel, pred_rel)])
            print(f"   • {true_rel} → {pred_rel}: {count} cas (conf. moy: {avg_conf:.2f})")

    print(f"\nErreurs Haute Confiance (>70%): {len(high_confidence_errors)}")
    _print_top_pairs(high_confidence_errors)

    print(f"\nErreurs Basse Confiance (<30%): {len(low_confidence_errors)}")
    _print_top_pairs(low_confidence_errors)

    # 5. Improvement suggestions derived from the counts above
    print("\n       SUGGESTIONS D'AMÉLIORATION:")

    suggestions = []

    # Share of directional inversions among all errors (0 when there is no error)
    total_errors = sum(len(examples) for examples in error_examples.values())
    directional_error_rate = (
        sum(confusion_analysis.values()) / total_errors * 100 if total_errors else 0.0
    )
    if directional_error_rate > 20:
        suggestions.append("Augmentation de données avec flips horizontaux/verticaux")
        suggestions.append("Pré-traitement pour normaliser l'orientation des images")

    # Largest number of errors attributed to a single true group
    max_group_errors = max(sum(group_confusion[group].values()) for group in semantic_groups)
    if max_group_errors > 50:
        suggestions.append("Équilibrage du dataset par groupe sémantique")
        suggestions.append("Perte pondérée pour compenser le déséquilibre")

    if len(high_confidence_errors) > 10:
        suggestions.append("Révision des annotations pour les erreurs haute confiance")
        suggestions.append("Ensemble de modèles pour réduire la sur-confiance")

    if subject_errors and max(subject_errors.values()) > 5:
        suggestions.append("Features spécifiques aux objets problématiques")
        suggestions.append("Augmentation ciblée pour les objets difficiles")

    print("Recommandations basées sur l'analyse:")
    for i, suggestion in enumerate(suggestions, 1):
        print(f"   {i}. {suggestion}")

    if not suggestions:
        print("      Le modèle montre des patterns d'erreurs acceptables")

    return group_confusion, subject_errors, object_errors

def visualize_error_examples(model, dataset, error_examples, device, num_examples=6, save_prefix="spatialsense"):
    """
    Plot a four-panel row (image, heatmap, overlay, text summary) for the most
    confident error of each of the most frequent confusion pairs.

    For every selected error the dataset is scanned for the first sample whose
    subject/object/relation metadata match the error record. The figure is
    saved as ``{save_prefix}_error_examples_analysis.png``. Returns ``None``.
    """
    model.eval()

    # Confusion pairs ordered by decreasing frequency; for each non-empty pair
    # keep the example the model was most confident about.
    ranked_pairs = sorted(error_examples.items(), key=lambda item: len(item[1]), reverse=True)
    candidates = [
        max(examples, key=lambda record: record['confidence'])
        for _pair, examples in ranked_pairs
        if examples
    ]
    selected_errors = candidates[:max(num_examples, 0)]

    if not selected_errors:
        print("    Aucun exemple d'erreur à visualiser")
        return

    n_rows = len(selected_errors)
    fig, axes = plt.subplots(n_rows, 4, figsize=(16, 4*n_rows))
    if n_rows == 1:
        axes = axes.reshape(1, -1)  # keep 2-D indexing with a single row

    print(f"   Visualisation de {n_rows} exemples d'erreurs...")

    match_keys = ('subject', 'object', 'relation')

    for i, error_example in enumerate(selected_errors):
        target_metadata = error_example['metadata']

        # Linear scan for the first dataset sample with the same metadata triplet
        for j in range(len(dataset)):
            image, label, metadata = dataset[j]
            if all(metadata[key] == target_metadata[key] for key in match_keys):
                break
        else:
            print(f"       Échantillon d'erreur {i+1} non trouvé dans le dataset")
            for j in range(4):
                axes[i, j].axis('off')
                axes[i, j].text(0.5, 0.5, 'Échantillon\nnon trouvé',
                               ha='center', va='center', transform=axes[i, j].transAxes)
            continue

        # Undo the ImageNet normalisation for display
        mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
        img_display = (image.clone() * std + mean).clamp(0, 1).permute(1, 2, 0)

        # Class probabilities for this sample
        with torch.no_grad():
            probs = torch.softmax(model(image.unsqueeze(0).to(device)), dim=1)

        # Occlusion heatmap for the true label
        heatmap = generate_heatmap_haldekar_style(model, image, label, device)

        # Panel 1: original image
        ax_image, ax_heat, ax_overlay, ax_text = (axes[i, col] for col in range(4))
        ax_image.imshow(img_display)
        ax_image.set_title(f"    Erreur #{i+1}\n{metadata['subject']} - {metadata['object']}")
        ax_image.axis('off')

        # Panel 2: heatmap alone
        im = ax_heat.imshow(heatmap, cmap='hot')
        ax_heat.set_title('Heatmap Attention')
        ax_heat.axis('off')
        plt.colorbar(im, ax=ax_heat, shrink=0.8)

        # Panel 3: heatmap over the image
        ax_overlay.imshow(img_display)
        ax_overlay.imshow(heatmap, cmap='jet', alpha=0.5)
        ax_overlay.set_title('Superposition')
        ax_overlay.axis('off')

        # Panel 4: textual summary
        ax_text.axis('off')

        top3 = torch.topk(probs[0], 3)
        prob_text = "Top-3 Prédictions:\n" + "".join(
            f"{' ' if rank == 0 else '  '} {SPATIAL_RELATIONS[idx]}: {prob:.1%}\n"
            for rank, (idx, prob) in enumerate(zip(top3.indices, top3.values))
        )

        confidence = error_example['confidence']
        error_kind = get_error_type(error_example['true_relation'],
                                    error_example['predicted_relation'])
        if confidence > 0.7:
            severity = '🔴 Haute'
        elif confidence > 0.4:
            severity = '🟡 Moyenne'
        else:
            severity = '🟢 Faible'

        info_text = f"""
Erreur d'Analyse:

Vérité: {error_example['true_relation']}
Prédiction: {error_example['predicted_relation']}
Confiance: {confidence:.1%}

{prob_text}

Type d'Erreur:
{error_kind}

Gravité: {severity}
                """

        color = 'lightcoral' if confidence > 0.7 else 'lightyellow'
        ax_text.text(0.1, 0.5, info_text, transform=ax_text.transAxes,
                     fontsize=9, verticalalignment='center',
                     bbox=dict(boxstyle='round', facecolor=color, alpha=0.8))

    plt.suptitle('Analyse Détaillée des Erreurs les Plus Significatives',
                 fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.savefig(f'{save_prefix}_error_examples_analysis.png', dpi=300, bbox_inches='tight')
    plt.show()

def get_error_type(true_relation, predicted_relation):
    """
    Describe the kind of confusion between the true and the predicted relation.

    Relations are mapped to a semantic group (vertical, horizontal, depth,
    containment, or 'unknown' for anything else). Errors inside one group are
    labelled as an inversion or a confusion of that group, errors across groups
    as an inter-group confusion.
    """
    group_of = {
        'above': 'vertical', 'under': 'vertical', 'on': 'vertical',
        'to the left of': 'horizontal', 'to the right of': 'horizontal',
        'next to': 'horizontal',
        'behind': 'depth', 'in front of': 'depth',
        'in': 'containment',
    }
    true_group = group_of.get(true_relation, 'unknown')
    pred_group = group_of.get(predicted_relation, 'unknown')

    # Different groups: report the direction of the confusion
    if true_group != pred_group:
        return f"   Confusion inter-groupe\n({true_group} → {pred_group})"

    # Same group from here on; opposite pairs count as inversions
    involved = {true_relation, predicted_relation}
    if true_group == 'vertical':
        return "   Inversion verticale" if involved == {'above', 'under'} else "   Confusion verticale"
    if true_group == 'horizontal':
        if involved == {'to the left of', 'to the right of'}:
            return "   Inversion horizontale"
        return "   Confusion horizontale"
    if true_group == 'depth':
        return "   Inversion profondeur"
    return f"   Confusion {true_group}"

# =============================================================================
# GÉNÉRATION DE HEATMAPS (méthode de l'article)
# =============================================================================

def generate_heatmap_haldekar_style(model, image, label, device, mask_size=16, stride=8,
                                    mask_value=0.5):
    """
    Build an occlusion-sensitivity heatmap for ``label``.

    A square mask is slid over the image; at every position the masked image is
    classified again and the influence of the region is the drop in the softmax
    probability of ``label`` (original probability minus masked probability).
    The grid of influences is then resized to the image size.

    Args:
        model: classifier returning logits for a batch of images.
        image: tensor of shape (C, H, W) or (1, C, H, W).
        label: index of the class whose probability is tracked.
        device: device on which inference is run.
        mask_size: side length of the square mask, in pixels.
        stride: step between two mask positions, in pixels.
        mask_value: value written into the masked region. NOTE(review): it is
            applied to the tensor as given, so for an already-normalised image
            0.5 is not mid-grey; pass another value if that matters.

    Returns:
        2-D numpy array of shape (H, W). All zeros when the mask does not fit
        inside the image.
    """
    model.eval()

    # Add the batch axis if needed
    if image.dim() == 3:
        image = image.unsqueeze(0)
    image = image.to(device)

    # Baseline probability on the unmasked image
    with torch.no_grad():
        original_prob = torch.softmax(model(image), dim=1)[0, label].item()

    _, _, h, w = image.shape

    # Every top-left corner for which the mask fits entirely in the image; the
    # "+ 1" includes the last position so the bottom/right borders are covered.
    row_starts = range(0, h - mask_size + 1, stride)
    col_starts = range(0, w - mask_size + 1, stride)
    if len(row_starts) == 0 or len(col_starts) == 0:
        return np.zeros((h, w))

    # One cell per mask position
    influence_map = np.zeros((len(row_starts), len(col_starts)))

    with torch.no_grad():
        for row, i in enumerate(row_starts):
            for col, j in enumerate(col_starts):
                masked_image = image.clone()
                masked_image[:, :, i:i+mask_size, j:j+mask_size] = mask_value

                masked_prob = torch.softmax(model(masked_image), dim=1)[0, label].item()

                # Positive influence: hiding the region lowers the class probability
                influence_map[row, col] = original_prob - masked_prob

    # Upsample the grid to the image resolution
    heatmap = cv2.resize(influence_map, (w, h), interpolation=cv2.INTER_CUBIC)

    return heatmap

# =============================================================================
# VALIDATION K-FOLD POUR SPATIALSENSE+
# =============================================================================

def kfold_cross_validation_spatialsense(data_dir, k_folds=5, epochs=20):
    """
    Run K-fold cross-validation of the VGG + MLP model on the 'train' split.

    For every fold a fresh ``SpatialRelationModel`` is trained for ``epochs``
    epochs with Adam on the classifier parameters only. The best checkpoint
    (by validation accuracy) and the final weights of each fold are saved to
    disk.

    Training samples use the training transform, validation samples the
    validation transform (no augmentation).

    Args:
        data_dir: root directory passed to ``SpatialSenseDataset``.
        k_folds: number of folds.
        epochs: number of training epochs per fold.

    Returns:
        List with one dict per fold (loss/accuracy histories, best and final
        accuracies and the trained model); empty list if the dataset is empty.
    """
    print("     Validation K-fold Haldekar et al. 2017 sur SpatialSense+")
    print("="*70)
    print(f"Dataset: SpatialSense+ ({len(SPATIAL_RELATIONS)} relations)")
    print(f"Architecture: VGG16 FC-7 + MLP (4096→512→256→{len(SPATIAL_RELATIONS)})")
    print(f"Comparaison: Article original SUN09 (3 relations, 55.98% test)")
    print("="*70)

    train_transform, val_transform = create_transforms()

    # Dataset view used for training
    full_dataset = SpatialSenseDataset(
        data_dir=data_dir,
        split='train',
        transform=train_transform
    )

    if len(full_dataset) == 0:
        print("    Dataset SpatialSense+ vide!")
        return []

    # Second view of the same split with the validation transform, so that
    # validation folds are not randomly augmented.
    # NOTE(review): assumes SpatialSenseDataset orders its samples
    # deterministically for a given split — confirm against the dataset class.
    eval_dataset = SpatialSenseDataset(
        data_dir=data_dir,
        split='train',
        transform=val_transform
    )
    if len(eval_dataset) != len(full_dataset):
        print("    Attention: vues train/val de tailles différentes, "
              "validation avec les transformations d'entraînement")
        eval_dataset = full_dataset

    kfold = KFold(n_splits=k_folds, shuffle=True, random_state=42)
    fold_results = []

    # Split on plain indices; only the number of samples matters here
    sample_indices = np.arange(len(full_dataset))

    for fold, (train_indices, val_indices) in enumerate(kfold.split(sample_indices)):
        print(f"\n{'='*50}")
        print(f"FOLD {fold + 1}/{k_folds} - SPATIALSENSE+")
        print(f"{'='*50}")

        train_subset = torch.utils.data.Subset(full_dataset, train_indices)
        val_subset = torch.utils.data.Subset(eval_dataset, val_indices)

        train_loader = DataLoader(
            train_subset, batch_size=BATCH_SIZE, shuffle=True,
            num_workers=2, pin_memory=True
        )

        val_loader = DataLoader(
            val_subset, batch_size=BATCH_SIZE, shuffle=False,
            num_workers=2, pin_memory=True
        )

        # Fresh model for this fold
        model = SpatialRelationModel(num_relations=len(SPATIAL_RELATIONS))
        model = model.to(DEVICE)

        # Only the classifier parameters are optimised
        optimizer = optim.Adam(model.classifier.parameters(), lr=LEARNING_RATE)
        criterion = nn.CrossEntropyLoss()

        train_losses, val_losses = [], []
        train_accs, val_accs = [], []
        best_val_acc = 0.0

        for epoch in range(epochs):
            print(f"\nEpoch {epoch+1}/{epochs}")

            train_loss, train_acc = train_one_epoch(
                model, train_loader, criterion, optimizer, DEVICE
            )
            train_losses.append(train_loss)
            train_accs.append(train_acc)

            val_loss, val_acc = evaluate(model, val_loader, criterion, DEVICE)
            val_losses.append(val_loss)
            val_accs.append(val_acc)

            print(f"Train - Loss: {train_loss:.4f}, Acc: {train_acc:.2f}%")
            print(f"Val   - Loss: {val_loss:.4f}, Acc: {val_acc:.2f}%")

            # Checkpoint whenever the validation accuracy improves
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save(model.state_dict(), f'best_spatialsense_model_fold_{fold+1}.pth')
                print(f"  Nouveau meilleur modèle sauvé: {val_acc:.2f}%")

        # NOTE: 'model' holds the weights of the LAST epoch, not the best checkpoint
        fold_results.append({
            'fold': fold + 1,
            'train_losses': train_losses,
            'val_losses': val_losses,
            'train_accs': train_accs,
            'val_accs': val_accs,
            'best_val_acc': best_val_acc,
            'final_train_acc': train_accs[-1],
            'final_val_acc': val_accs[-1],
            'model': model
        })

        torch.save(model.state_dict(), f'spatialsense_model_fold_{fold+1}_final.pth')

        print(f"\n   Résumé Fold {fold + 1}:")
        print(f"   - Meilleure val accuracy: {best_val_acc:.2f}%")
        print(f"   - Train accuracy finale: {train_accs[-1]:.2f}%")
        print(f"   - Val accuracy finale: {val_accs[-1]:.2f}%")

    return fold_results

# =============================================================================
# VISUALISATIONS SPATIALSENSE+ AVEC COMPARAISON ARTICLE
# =============================================================================

def visualize_spatialsense_results_with_comparison(fold_results):
    """
    Plot the K-fold SpatialSense+ results next to the Haldekar et al. 2017 baseline.

    Builds a 2x2 figure (per-fold loss curves, per-fold accuracy curves, a grouped
    bar chart of final/best accuracies and a text summary), saves it to
    'spatialsense_haldekar_comparison.png' and displays it.

    Args:
        fold_results: sequence of dicts, one per fold, each providing the keys
            'fold', 'train_losses', 'val_losses', 'train_accs', 'val_accs',
            'final_train_acc', 'final_val_acc' and 'best_val_acc'.

    Returns:
        None. When ``fold_results`` is empty a message is printed and nothing is drawn.
    """
    # Guard: np.argmax below raises ValueError on an empty sequence.
    if not fold_results:
        print("   Aucun résultat de fold à visualiser")
        return

    # Test accuracy quoted for the original article; drawn as the reference line.
    reference_acc = 55.98

    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))

    # 1. Loss curves
    for result in fold_results:
        fold = result['fold']
        ax1.plot(result['train_losses'], label=f'Fold {fold} Train', linewidth=2)
        ax1.plot(result['val_losses'], '--', label=f'Fold {fold} Val', alpha=0.8)

    ax1.set_title('Loss SpatialSense+ (Architecture Haldekar)', fontsize=14, fontweight='bold')
    ax1.set_xlabel('Époque')
    ax1.set_ylabel('Loss')
    ax1.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    ax1.grid(True, alpha=0.3)

    # 2. Accuracy curves
    for result in fold_results:
        fold = result['fold']
        ax2.plot(result['train_accs'], label=f'Fold {fold} Train', linewidth=2)
        ax2.plot(result['val_accs'], '--', label=f'Fold {fold} Val', alpha=0.8)

    # Reference line for the original article
    ax2.axhline(y=reference_acc, color='red', linestyle='-', alpha=0.7,
               label='Haldekar et al. 2017 (55.98%)')

    ax2.set_title('Accuracy SpatialSense+ vs Article Original', fontsize=14, fontweight='bold')
    ax2.set_xlabel('Époque')
    ax2.set_ylabel('Accuracy (%)')
    ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    ax2.grid(True, alpha=0.3)

    # 3. Final performance per fold
    folds = [r['fold'] for r in fold_results]
    train_accs = [r['final_train_acc'] for r in fold_results]
    val_accs = [r['final_val_acc'] for r in fold_results]
    best_val_accs = [r['best_val_acc'] for r in fold_results]

    x = np.arange(len(folds))
    width = 0.25

    bars1 = ax3.bar(x - width, train_accs, width, label='Train Final', alpha=0.8)
    bars2 = ax3.bar(x, val_accs, width, label='Val Final', alpha=0.8)
    bars3 = ax3.bar(x + width, best_val_accs, width, label='Val Best', alpha=0.8)

    # Reference line for the original article
    ax3.axhline(y=reference_acc, color='red', linestyle='--', alpha=0.7,
               label='Haldekar SUN09 (55.98%)')

    ax3.set_title('Performance SpatialSense+ par Fold vs Article', fontsize=14, fontweight='bold')
    ax3.set_xlabel('Fold')
    ax3.set_ylabel('Accuracy (%)')
    ax3.set_xticks(x)
    ax3.set_xticklabels(folds)
    ax3.legend()
    ax3.grid(True, axis='y', alpha=0.3)

    def autolabel(bars):
        """Write each bar's height just above it."""
        for bar in bars:
            height = bar.get_height()
            ax3.annotate(f'{height:.1f}',
                        xy=(bar.get_x() + bar.get_width() / 2, height),
                        xytext=(0, 3), textcoords="offset points",
                        ha='center', va='bottom', fontsize=9)

    for bars in (bars1, bars2, bars3):
        autolabel(bars)

    # 4. Text summary comparing with the article
    ax4.axis('off')

    mean_best = np.mean(best_val_accs)
    std_best = np.std(best_val_accs)

    # Report the fold's own label rather than assuming folds are numbered 1..N in order.
    best_idx = int(np.argmax(best_val_accs))
    best_fold = folds[best_idx]

    # Gap between the mean best validation accuracy and the article's test accuracy
    improvement = mean_best - reference_acc
    improvement_sign = "      " if improvement > 0 else "  " if improvement < -2 else "➡️"

    summary_text = f"""
    COMPARAISON HALDEKAR ET AL. 2017

    Article Original (SUN09):
    • Dataset: 4468 train, 4955 test
    • Relations: 3 (above, beside, behind)
    • Architecture: VGG FC-7 + MLP
    • Performance: 71.97% train, 55.98% test

    Notre Implementation (SpatialSense+):
    • Dataset: SpatialSense+
    • Relations: {len(SPATIAL_RELATIONS)} ({', '.join(SPATIAL_RELATIONS[:3])}...)
    • Architecture: Identique (VGG FC-7 + MLP)
    • Performance: {mean_best:.2f}% ± {std_best:.2f}% (val)

    Résultat: {improvement_sign} {improvement:+.2f}% vs article

    Défis SpatialSense+:
    • +{len(SPATIAL_RELATIONS)-3} relations supplémentaires
    • Tâche plus complexe
    • Dataset différent

    Meilleur Fold: {best_fold} ({best_val_accs[best_idx]:.2f}%)
    """

    # Box colour encodes how the result compares with the article.
    color = 'lightgreen' if improvement > 0 else 'lightyellow' if improvement > -5 else 'lightcoral'
    ax4.text(0.05, 0.95, summary_text, transform=ax4.transAxes,
             fontsize=11, verticalalignment='top',
             bbox=dict(boxstyle='round,pad=0.5', facecolor=color, alpha=0.8))

    plt.tight_layout()
    plt.savefig('spatialsense_haldekar_comparison.png', dpi=300, bbox_inches='tight')
    plt.show()

def visualize_predictions_with_heatmaps_spatialsense(model, dataset, device, num_samples=6):
    """
    Show predictions on the first samples of ``dataset`` together with their heatmaps.

    For each sample one figure row is drawn with four panels: the de-normalised
    image, the heatmap returned by ``generate_heatmap_haldekar_style``, the heatmap
    overlaid on the image, and a text box with the prediction details. The figure is
    saved to 'spatialsense_predictions_haldekar.png' and displayed.

    Args:
        model: classifier called as ``model(batch)``; its output is passed to softmax over dim 1.
        dataset: indexable object whose items are ``(image, label, metadata)`` and
            which exposes ``idx_to_relation``; ``metadata`` must provide
            'relation', 'subject' and 'object'.
        device: device the input batch is moved to before the forward pass.
        num_samples: maximum number of samples (figure rows) to display.

    Returns:
        None. Nothing is drawn when there is no sample to show.
    """
    # Draw one row per sample actually shown: asking for more rows than the dataset
    # holds would leave blank axes, and zero rows makes plt.subplots raise.
    n_rows = min(num_samples, len(dataset))
    if n_rows <= 0:
        print("   Aucun échantillon à visualiser")
        return

    model.eval()

    # squeeze=False keeps `axes` two-dimensional even when a single row is drawn.
    fig, axes = plt.subplots(n_rows, 4, figsize=(16, 4*n_rows), squeeze=False)

    # Per-channel statistics used to undo the input normalisation for display.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    for i in range(n_rows):
        image, label, metadata = dataset[i]

        # De-normalise and move channels last for imshow
        img_display = (image.detach().cpu() * std + mean).clamp(0, 1).permute(1, 2, 0).numpy()

        # Prediction
        with torch.no_grad():
            image_tensor = image.unsqueeze(0).to(device)
            output = model(image_tensor)
            probs = torch.softmax(output, dim=1)
            pred_idx = int(torch.argmax(probs, dim=1).item())
            pred_prob = probs[0, pred_idx].item()
            true_prob = probs[0, label].item()

        pred_relation = dataset.idx_to_relation[pred_idx]
        true_relation = metadata['relation']

        heatmap = generate_heatmap_haldekar_style(model, image, label, device)

        # 1. Original image
        axes[i, 0].imshow(img_display)
        axes[i, 0].set_title(f"{metadata['subject']} - {metadata['object']}")
        axes[i, 0].axis('off')

        # 2. Heatmap alone
        axes[i, 1].imshow(heatmap, cmap='hot', alpha=0.8)
        axes[i, 1].set_title('Heatmap Haldekar-style')
        axes[i, 1].axis('off')

        # 3. Heatmap overlaid on the image
        axes[i, 2].imshow(img_display)
        axes[i, 2].imshow(heatmap, cmap='jet', alpha=0.5)
        axes[i, 2].set_title('Superposition')
        axes[i, 2].axis('off')

        # 4. Detailed information
        axes[i, 3].axis('off')
        info_text = f"""
        SpatialSense+ Sample:

        Vérité: {true_relation}
        Prédiction: {pred_relation}
        Original: {metadata.get('original_relation', 'N/A')}

        Confiance prédiction: {pred_prob:.2%}
        Confiance vérité: {true_prob:.2%}

        Résultat: {'   CORRECT' if pred_relation == true_relation else '    INCORRECT'}

        Heatmap (influence):
        Max: {heatmap.max():.3f}
        Min: {heatmap.min():.3f}
        """
        color = 'lightgreen' if pred_relation == true_relation else 'lightcoral'
        axes[i, 3].text(0.1, 0.5, info_text, transform=axes[i, 3].transAxes,
                       fontsize=10, verticalalignment='center',
                       bbox=dict(boxstyle='round', facecolor=color, alpha=0.8))

    plt.suptitle('Prédictions SpatialSense+ - Méthode Haldekar et al. 2017',
                 fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.savefig('spatialsense_predictions_haldekar.png', dpi=300, bbox_inches='tight')
    plt.show()

# =============================================================================
# ANALYSE DE PERFORMANCE PAR RELATION SPATIALSENSE+
# =============================================================================

def _resolve_true_relation(metadata, index, fallback):
    """Return the ground-truth relation name for one sample, or ``fallback`` if metadata lacks it."""
    if isinstance(metadata, list):
        # One metadata dict per sample
        if index < len(metadata):
            return metadata[index].get('relation', fallback)
        return fallback
    if isinstance(metadata, dict):
        # One dict holding a per-sample sequence under 'relation'
        names = metadata.get('relation')
        if names is not None and index < len(names):
            return names[index]
    return fallback


def analyze_spatialsense_performance_per_relation(model, dataloader, device):
    """
    Compute, plot and print the accuracy obtained for each SpatialSense+ relation.

    Runs ``model`` over ``dataloader`` without gradients, counts correct predictions
    per ground-truth relation, draws a two-panel figure (accuracy per relation and
    sample count per relation) saved to 'spatialsense_performance_per_relation.png',
    and prints a textual summary.

    Args:
        model: classifier called as ``model(images)``; the arg-max over dim 1 is the prediction.
        dataloader: iterable of batches, either ``(images, labels, metadata)`` or a
            sequence whose first two items are ``images`` and ``labels``.
        device: device the images and labels are moved to.

    Returns:
        tuple(dict, dict): ``(relation_accuracies, relation_total)``, both keyed by
        relation name; accuracies are percentages and are 0.0 for relations with no sample.
    """
    model.eval()

    # Per-relation counters
    relation_correct = {rel: 0 for rel in SPATIAL_RELATIONS}
    relation_total = {rel: 0 for rel in SPATIAL_RELATIONS}

    print("   Analyse performance par relation SpatialSense+...")

    with torch.no_grad():
        for batch_data in tqdm(dataloader, desc='Analyzing relations'):
            # Batches may or may not carry metadata
            if len(batch_data) == 3:
                images, labels, metadata = batch_data
            else:
                images, labels = batch_data[:2]
                metadata = None  # relation names are then derived from the labels

            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            _, predicted = torch.max(outputs, 1)

            # One transfer per batch instead of one .item() call per sample
            label_ids = labels.tolist()
            pred_ids = predicted.tolist()

            for i, (label_id, pred_id) in enumerate(zip(label_ids, pred_ids)):
                label_rel = SPATIAL_RELATIONS[label_id]
                true_rel = _resolve_true_relation(metadata, i, label_rel)
                # A metadata name outside SPATIAL_RELATIONS would raise KeyError on
                # the counters; fall back to the name implied by the label.
                if true_rel not in relation_total:
                    true_rel = label_rel

                pred_rel = SPATIAL_RELATIONS[pred_id]

                relation_total[true_rel] += 1
                if pred_rel == true_rel:
                    relation_correct[true_rel] += 1

    # Accuracy per relation (0.0 when the relation never occurs)
    relation_accuracies = {}
    for rel in SPATIAL_RELATIONS:
        if relation_total[rel] > 0:
            relation_accuracies[rel] = relation_correct[rel] / relation_total[rel] * 100
        else:
            relation_accuracies[rel] = 0.0

    # Visualisation
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

    # 1. Accuracy per relation
    relations = SPATIAL_RELATIONS  # fixed display order
    accuracies = [relation_accuracies[rel] for rel in relations]
    counts = [relation_total[rel] for rel in relations]

    bars = ax1.bar(range(len(relations)), accuracies,
                   color='skyblue', alpha=0.8)
    ax1.set_xlabel('Relations SpatialSense+')
    ax1.set_ylabel('Accuracy (%)')
    ax1.set_title('Performance par Relation SpatialSense+\n(Architecture Haldekar)')
    ax1.set_xticks(range(len(relations)))
    ax1.set_xticklabels(relations, rotation=45, ha='right')
    ax1.grid(True, axis='y', alpha=0.3)

    # Reference line for the article
    ax1.axhline(y=55.98, color='red', linestyle='--', alpha=0.7,
               label='Haldekar SUN09 (55.98%)')
    ax1.legend()

    # Annotate each bar with its accuracy and sample count
    for bar, acc, count in zip(bars, accuracies, counts):
        height = bar.get_height()
        ax1.annotate(f'{acc:.1f}%\n(n={count})',
                    xy=(bar.get_x() + bar.get_width() / 2, height),
                    xytext=(0, 3), textcoords="offset points",
                    ha='center', va='bottom', fontsize=8)

    # 2. Sample distribution
    ax2.bar(range(len(relations)), counts, color='lightcoral', alpha=0.8)
    ax2.set_xlabel('Relations SpatialSense+')
    ax2.set_ylabel('Nombre d\'échantillons')
    ax2.set_title('Distribution Échantillons SpatialSense+')
    ax2.set_xticks(range(len(relations)))
    ax2.set_xticklabels(relations, rotation=45, ha='right')
    ax2.grid(True, axis='y', alpha=0.3)

    plt.tight_layout()
    plt.savefig('spatialsense_performance_per_relation.png', dpi=300, bbox_inches='tight')
    plt.show()

    # Textual summary
    print("\n" + "="*70)
    print("ANALYSE PERFORMANCE PAR RELATION SPATIALSENSE+")
    print("="*70)

    sorted_relations = sorted(zip(relations, accuracies, counts),
                             key=lambda x: x[1], reverse=True)

    def status_of(acc):
        """Marker showing where an accuracy stands relative to the reference thresholds."""
        return "🟢" if acc > 55.98 else "🟡" if acc > 40 else "🔴"

    print(f"\n    Meilleures performances:")
    for rel, acc, count in sorted_relations[:3]:
        print(f"   {status_of(acc)} {rel}: {acc:.2f}% (n={count})")

    print(f"\n   Pires performances:")
    for rel, acc, count in sorted_relations[-3:]:
        print(f"   {status_of(acc)} {rel}: {acc:.2f}% (n={count})")

    # An empty dataloader leaves every counter at zero; avoid dividing by it.
    total_samples = sum(relation_total.values())
    overall_acc = sum(relation_correct.values()) / total_samples * 100 if total_samples else 0.0
    print(f"\n   Statistiques globales:")
    print(f"   - Accuracy globale: {overall_acc:.2f}%")
    print(f"   - Relations au-dessus Haldekar: {sum(1 for acc in accuracies if acc > 55.98)}/{len(SPATIAL_RELATIONS)}")
    print(f"   - Échantillons total: {total_samples}")

    return relation_accuracies, relation_total

# =============================================================================
# FONCTION PRINCIPALE SPATIALSENSE+ HALDEKAR AVEC ANALYSE COMPLÈTE
# =============================================================================

def main_spatialsense_haldekar():
    """
    Run the complete SpatialSense+ pipeline.

    Steps: dataset smoke test, K-fold cross-validation, selection of the best fold,
    test-set evaluation, error analysis, heatmap visualisation, per-relation
    analysis and a final printed summary.

    Returns:
        list | None: the per-fold result dicts returned by
        ``kfold_cross_validation_spatialsense``; None when the data directory is
        missing, the dataset cannot be loaded or is empty, or cross-validation
        produced no result.
    """
    DATA_DIR = "data/spatialsense"  # Adjust to your environment

    print("="*80)
    print("IMPLÉMENTATION EXACTE HALDEKAR ET AL. 2017 SUR SPATIALSENSE+")
    print("AVEC ANALYSE COMPLÈTE DES ERREURS")
    print("="*80)
    print("   Article original: 'Identifying Spatial Relations in Images using CNNs'")
    print("   Dataset original: SUN09 (3 relations, 55.98% test accuracy)")
    print("   Notre adaptation: SpatialSense+ (9 relations)")
    print("    Architecture: VGG16 FC-7 + MLP (identique à l'article)")
    print("   Nouveauté: Matrice de confusion + analyse détaillée des erreurs")
    print("="*80)

    print(f"\nConfiguration:")
    print(f"   - Device: {DEVICE}")
    print(f"   - Dataset: SpatialSense+ ({len(SPATIAL_RELATIONS)} relations)")
    print(f"   - Architecture: VGG16 FC-7 (4096) + MLP (512→256→{len(SPATIAL_RELATIONS)})")
    print(f"   - Hyperparamètres: batch_size={BATCH_SIZE}, lr={LEARNING_RATE}, epochs={EPOCHS}")
    print(f"   - Relations: {', '.join(SPATIAL_RELATIONS)}")

    # Data directory check
    if not os.path.exists(DATA_DIR):
        print(f"\n    ERREUR: Le répertoire {DATA_DIR} n'existe pas!")
        print("\n   Pour utiliser ce code:")
        print("1. Téléchargez le dataset SpatialSense+")
        print("2. Extrayez-le dans le répertoire spécifié")
        print("3. Structure attendue:")
        print("   data/spatialsense/")
        print("   ├── annotations.json")
        print("   └── images/images/{flickr,nyu}/")
        return

    # Quick dataset smoke test
    print("\n" + "="*50)
    print("🧪 Test du Dataset SpatialSense+")
    print("="*50)

    try:
        # Probe the training split with the training transform
        train_transform, test_transform = create_transforms()
        train_probe = SpatialSenseDataset(DATA_DIR, split='train', transform=train_transform)
        print(f"   Dataset d'entraînement: {len(train_probe)} échantillons")

        if len(train_probe) > 0:
            img, label, meta = train_probe[0]
            print(f"   Premier échantillon:")
            print(f"   - Relation: {meta['relation']} (mappée de '{meta['original_relation']}')")
            print(f"   - Sujet: {meta['subject']}")
            print(f"   - Objet: {meta['object']}")

            # The sample is a tensor or a PIL image depending on the transform
            if isinstance(img, torch.Tensor):
                print(f"   - Image shape: {img.shape}")
            else:
                print(f"   - Image size: {img.size} (PIL Image)")
        else:
            print("    Dataset vide!")
            return
    except Exception as e:
        print(f"    Erreur chargement dataset: {e}")
        return

    # Phase 1: K-fold validation
    print("\n" + "="*50)
    print("     PHASE 1: Validation K-fold (Haldekar et al. 2017)")
    print("="*50)

    fold_results = kfold_cross_validation_spatialsense(
        data_dir=DATA_DIR,
        k_folds=K_FOLDS,
        epochs=EPOCHS
    )

    if not fold_results:
        print("    Erreur pendant la validation K-fold")
        return

    # Result visualisation with the article comparison
    print("\n   Visualisation résultats avec comparaison article...")
    visualize_spatialsense_results_with_comparison(fold_results)

    # Best-fold selection
    best_fold_idx = np.argmax([r['best_val_acc'] for r in fold_results])
    best_model = fold_results[best_fold_idx]['model']
    best_acc = fold_results[best_fold_idx]['best_val_acc']

    # The model object kept in fold_results carries the LAST epoch's weights, while
    # best_acc was reached at the epoch whose weights the training loop wrote to
    # 'best_spatialsense_model_fold_<n>.pth'. Reload them so the evaluated and
    # saved model is the one that actually achieved best_acc.
    best_ckpt = f"best_spatialsense_model_fold_{fold_results[best_fold_idx]['fold']}.pth"
    if os.path.exists(best_ckpt):
        try:
            best_model.load_state_dict(torch.load(best_ckpt, map_location=DEVICE))
            print(f"   Poids du meilleur epoch rechargés: {best_ckpt}")
        except Exception as e:
            # Best effort: keep the final-epoch weights if the checkpoint is unusable.
            print(f"   Impossible de recharger {best_ckpt} ({e}); poids finaux du fold utilisés")

    print(f"\n    Meilleur modèle: Fold {best_fold_idx + 1}")
    print(f"   - Validation accuracy: {best_acc:.2f}%")
    print(f"   - Comparaison Haldekar: {best_acc - 55.98:+.2f}%")

    # Save the best model
    torch.save(best_model.state_dict(), 'best_spatialsense_haldekar_model.pth')
    print("   Modèle sauvegardé: best_spatialsense_haldekar_model.pth")

    # Phase 2: test-set evaluation with full analysis
    print("\n" + "="*50)
    print("  PHASE 2: Évaluation Test Set avec Analyse Complète")
    print("="*50)

    # Explicit sentinels: None means "not computed". A failure in a later analysis
    # phase must not erase a test accuracy that was already measured.
    test_acc = None
    error_examples = None

    try:
        # Test dataset
        _, test_transform = create_transforms()
        test_dataset = SpatialSenseDataset(DATA_DIR, split='test', transform=test_transform)

        print(f"Dataset de test: {len(test_dataset)} échantillons")

        if len(test_dataset) == 0:
            print("       Dataset test vide, utilisation du validation set")
            test_dataset = SpatialSenseDataset(DATA_DIR, split='val', transform=test_transform)
            print(f"Dataset de validation: {len(test_dataset)} échantillons")

        if len(test_dataset) > 0:
            test_loader = DataLoader(
                test_dataset, batch_size=BATCH_SIZE, shuffle=False,
                num_workers=2, pin_memory=True
            )

            # Final evaluation
            criterion = nn.CrossEntropyLoss()
            test_loss, test_acc = evaluate(best_model, test_loader, criterion, DEVICE)

            print(f"\n   Résultats Test Set:")
            print(f"   - Loss: {test_loss:.4f}")
            print(f"   - Accuracy: {test_acc:.2f}%")
            print(f"   - Comparaison Haldekar: {test_acc - 55.98:+.2f}%")

            # Phase 3: full error analysis
            print("\n" + "="*50)
            print("   PHASE 3: Analyse Complète des Erreurs")
            print("="*50)

            # Confusion matrix and error examples
            print("   Génération matrice de confusion...")
            _cm, error_examples, _report = generate_confusion_matrix_and_analysis(
                best_model, test_loader, DEVICE, save_prefix="spatialsense"
            )

            # Error pattern analysis
            print("\n   Analyse des patterns d'erreurs...")
            analyze_error_patterns(error_examples, save_prefix="spatialsense")

            # Error example visualisation
            print("\n   Visualisation d'exemples d'erreurs...")
            visualize_error_examples(
                best_model, test_dataset, error_examples, DEVICE,
                num_examples=6, save_prefix="spatialsense"
            )

            # Phase 4: heatmap visualisations
            print("\n" + "="*50)
            print("   PHASE 4: Visualisations Haldekar-style")
            print("="*50)

            print("   Génération heatmaps style article...")
            visualize_predictions_with_heatmaps_spatialsense(
                best_model, test_dataset, DEVICE, num_samples=6
            )

            # Phase 5: per-relation analysis
            print("       Analyse performance par relation...")
            analyze_spatialsense_performance_per_relation(
                best_model, test_loader, DEVICE
            )

        else:
            print("    Aucun échantillon de test disponible")

    except Exception as e:
        # Report and carry on to the summary; whatever was computed so far is kept.
        print(f"    Erreur évaluation test: {e}")

    # Final summary
    print("\n" + "="*80)
    print("   RÉSUMÉ FINAL - COMPARAISON HALDEKAR + ANALYSE D'ERREURS")
    print("="*80)

    # K-fold statistics
    mean_train_acc = np.mean([r['final_train_acc'] for r in fold_results])
    mean_val_acc = np.mean([r['final_val_acc'] for r in fold_results])
    std_val_acc = np.std([r['final_val_acc'] for r in fold_results])
    mean_best_acc = np.mean([r['best_val_acc'] for r in fold_results])

    print(f"\n   Résultats Validation K-fold:")
    print(f"   - Train accuracy moyenne: {mean_train_acc:.2f}%")
    print(f"   - Val accuracy moyenne: {mean_val_acc:.2f}% ± {std_val_acc:.2f}%")
    print(f"   - Best val accuracy moyenne: {mean_best_acc:.2f}%")
    print(f"   - Meilleur fold: {best_fold_idx + 1} ({best_acc:.2f}%)")

    # Compare against the test accuracy when one was measured, else the best validation accuracy.
    if test_acc is not None:
        print(f"\n  Test final:")
        print(f"   - Test accuracy: {test_acc:.2f}%")
        final_comparison = test_acc
    else:
        final_comparison = best_acc

    print(f"\n   Comparaison avec Article Original:")
    print(f"   - Haldekar et al. 2017 (SUN09): 55.98% test")
    print(f"   - Notre implémentation (SpatialSense+): {final_comparison:.2f}%")
    print(f"   - Différence: {final_comparison - 55.98:+.2f}%")

    if final_comparison > 55.98:
        print(f"   -         AMÉLIORATION malgré +{len(SPATIAL_RELATIONS)-3} relations!")
    elif final_comparison > 50:
        print(f"   -    PERFORMANCE SOLIDE avec {len(SPATIAL_RELATIONS)} relations vs 3")
    else:
        print(f"   -        Performance à améliorer (tâche plus complexe)")

    print(f"\n   Facteurs de Complexité:")
    print(f"   - Relations: {len(SPATIAL_RELATIONS)} vs 3 (+{len(SPATIAL_RELATIONS)-3})")
    print(f"   - Dataset: SpatialSense+ vs SUN09 (différent)")
    print(f"   - Architecture: Identique (VGG16 FC-7 + MLP)")
    print(f"   - Méthode heatmap: Identique (masking séquentiel)")

    # Error analysis summary (only when the analysis ran)
    if error_examples is not None:
        total_errors = sum(len(examples) for examples in error_examples.values())
        error_types = len(error_examples)
        print(f"\n   Analyse d'Erreurs:")
        print(f"   - Total erreurs analysées: {total_errors}")
        print(f"   - Types d'erreurs différents: {error_types}")

        # Three most frequent (true, predicted) confusions
        error_counts = Counter({pair: len(examples) for pair, examples in error_examples.items()})

        top_3_errors = error_counts.most_common(3)
        print(f"   - Top 3 erreurs fréquentes:")
        for i, ((true_rel, pred_rel), count) in enumerate(top_3_errors, 1):
            print(f"     {i}. {true_rel} → {pred_rel}: {count} cas")

    return fold_results

# =============================================================================
# POINT D'ENTRÉE
# =============================================================================

# Entry point: run the full pipeline when this code executes as the main module.
# NOTE(review): inside a notebook cell `__name__` is normally "__main__", so executing
# this cell presumably launches the whole training run — confirm that is intended.
if __name__ == "__main__":
    main_spatialsense_haldekar()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import ViTModel, ViTConfig
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import json
import os
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt
from tqdm import tqdm
import cv2
from collections import Counter
import warnings
warnings.filterwarnings('ignore')

# =============================================================================
# CONFIGURATION AMÉLIORÉE
# =============================================================================

# GPU configuration: use CUDA when available, otherwise fall back to the CPU
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device utilisé: {DEVICE}")

# Hyperparameters tuned to improve performance
BATCH_SIZE = 16          # Raised from 8 to 16 (more stable)
LEARNING_RATE = 2e-5     # Raised from 1e-6 to 2e-5 (faster convergence)
EPOCHS = 15              # Raised from 10 to 15 (more training)
K_FOLDS = 5
IMG_SIZE = 224
DROPOUT_RATE = 0.3       # Lowered from 0.5 to 0.3 (less regularisation)

# Additional hyperparameters for optimisation
WEIGHT_DECAY = 0.05      # Raised for stronger regularisation
WARMUP_EPOCHS = 3        # Warm-up epochs for stability
MIN_LR = 1e-7           # Minimum learning rate
SCHEDULER_PATIENCE = 3   # Patience for ReduceLROnPlateau

# The exact SpatialSense+ spatial relations.
# List order matters: SpatialSenseDataset derives each class index from the
# position of the relation in this list.
SPATIAL_RELATIONS = [
    'above', 'behind', 'in', 'in front of', 'next to',
    'on', 'to the left of', 'to the right of', 'under'
]

# Echo the active configuration
print(f"Relations SpatialSense+ : {len(SPATIAL_RELATIONS)} relations")
print(f"Hyperparamètres optimisés :")
print(f"  - Batch size: {BATCH_SIZE}")
print(f"  - Learning rate: {LEARNING_RATE}")
print(f"  - Epochs: {EPOCHS}")
print(f"  - Dropout: {DROPOUT_RATE}")
print(f"  - Weight decay: {WEIGHT_DECAY}")

# =============================================================================
# ARCHITECTURE AMÉLIORÉE
# =============================================================================

class ImprovedSpatialRelationMLP(nn.Module):
    """
    MLP classification head producing one logit per spatial relation.

    Two hidden stages (Linear -> BatchNorm1d -> ReLU -> Dropout) are followed by a
    linear output layer. The weights of the three linear layers are re-initialised
    with Xavier-uniform.
    """
    def __init__(self, input_dim=4096, hidden1_dim=512, hidden2_dim=256,
                 num_relations=len(SPATIAL_RELATIONS), dropout_rate=0.3):
        super().__init__()

        # Hidden stage 1
        self.fc1 = nn.Linear(input_dim, hidden1_dim)
        self.bn1 = nn.BatchNorm1d(hidden1_dim)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout_rate)

        # Hidden stage 2
        self.fc2 = nn.Linear(hidden1_dim, hidden2_dim)
        self.bn2 = nn.BatchNorm1d(hidden2_dim)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout_rate)

        # Output layer: raw logits, no activation
        self.fc3 = nn.Linear(hidden2_dim, num_relations)

        # Xavier/Glorot initialisation of the linear weights, in layer order
        for linear in (self.fc1, self.fc2, self.fc3):
            nn.init.xavier_uniform_(linear.weight)

        # Construction summary
        summary_lines = (
            f"MLP Amélioré avec BatchNorm:",
            f"  - Input: {input_dim}",
            f"  - Hidden 1: {hidden1_dim} + BatchNorm",
            f"  - Hidden 2: {hidden2_dim} + BatchNorm",
            f"  - Output: {num_relations}",
            f"  - Dropout: {dropout_rate}",
        )
        for line in summary_lines:
            print(line)

    def forward(self, x):
        """Map a batch of feature vectors to relation logits."""
        hidden = self.dropout1(self.relu1(self.bn1(self.fc1(x))))
        hidden = self.dropout2(self.relu2(self.bn2(self.fc2(hidden))))
        return self.fc3(hidden)

class ImprovedViTFeatureExtractor(nn.Module):
    """
    ViT backbone with a trainable projection head and optional partial fine-tuning.

    The token at position 0 of the ViT's last hidden state is projected
    hidden_size -> 2048 -> 4096 by a small feed-forward stack.

    Args:
        model_name: identifier passed to ``ViTModel.from_pretrained``.
        freeze_features: when True every ViT parameter stays frozen.
        fine_tune_layers: when ``freeze_features`` is False, number of trailing
            encoder layers left trainable; 0 (or a negative value) unfreezes none.
    """
    def __init__(self, model_name='google/vit-base-patch16-224',
                 freeze_features=False, fine_tune_layers=4):
        super(ImprovedViTFeatureExtractor, self).__init__()

        self.vit = ViTModel.from_pretrained(model_name)
        self.feature_dim = self.vit.config.hidden_size

        # Two-stage projection head (plain feed-forward stack, no residual connection)
        self.feature_projection = nn.Sequential(
            nn.Linear(self.feature_dim, 2048),
            nn.LayerNorm(2048),  # LayerNorm rather than BatchNorm
            nn.GELU(),           # GELU rather than ReLU
            nn.Dropout(0.1),
            nn.Linear(2048, 4096),
            nn.LayerNorm(4096),
            nn.GELU(),
            nn.Dropout(0.1)
        )

        # Freeze the whole backbone first; both modes start from this state.
        for param in self.vit.parameters():
            param.requires_grad = False

        if not freeze_features:
            # Re-enable gradients for the trailing encoder layers only. A plain
            # `layer[-fine_tune_layers:]` would select EVERY layer when the count
            # is 0 (because -0 == 0), hence the helper.
            for layer in self._last_layers(self.vit.encoder.layer, fine_tune_layers):
                for param in layer.parameters():
                    param.requires_grad = True

            print(f"Fine-tuning des {fine_tune_layers} dernières couches ViT")
        else:
            print("Toutes les features ViT gelées")

    @staticmethod
    def _last_layers(layers, count):
        """Return the last ``count`` items of ``layers`` as a list (empty when count <= 0)."""
        if count <= 0:
            return []
        return list(layers)[-count:]

    def forward(self, x):
        """Return the projected features of the position-0 (CLS) token for a batch of images."""
        outputs = self.vit(pixel_values=x)
        cls_token = outputs.last_hidden_state[:, 0]
        features = self.feature_projection(cls_token)
        return features

    def get_attention_weights(self, x):
        """
        Return attention from token 0 to all remaining tokens.

        Taken from the last entry of ``outputs.attentions`` and from the head at
        index -1 only (heads are not averaged).
        """
        outputs = self.vit(pixel_values=x, output_attentions=True)
        attention_weights = outputs.attentions[-1][:, -1, 0, 1:]
        return attention_weights

class ImprovedViTSpatialRelationModel(nn.Module):
    """
    End-to-end relation classifier: ViT feature extractor followed by the MLP head.

    Args:
        num_relations: number of output logits.
        vit_model: pretrained ViT identifier forwarded to the feature extractor.
        freeze_vit: forwarded as ``freeze_features`` to the feature extractor.
        fine_tune_layers: forwarded to the feature extractor.
    """
    def __init__(self, num_relations=len(SPATIAL_RELATIONS),
                 vit_model='google/vit-base-patch16-224',
                 freeze_vit=False, fine_tune_layers=4):
        super().__init__()

        # Backbone producing 4096-dimensional features
        extractor_options = dict(
            model_name=vit_model,
            freeze_features=freeze_vit,
            fine_tune_layers=fine_tune_layers,
        )
        self.feature_extractor = ImprovedViTFeatureExtractor(**extractor_options)

        # Classification head; its dropout comes from the module-level DROPOUT_RATE
        head_options = dict(
            input_dim=4096,
            hidden1_dim=512,
            hidden2_dim=256,
            num_relations=num_relations,
            dropout_rate=DROPOUT_RATE,
        )
        self.classifier = ImprovedSpatialRelationMLP(**head_options)

    def forward(self, x):
        """Return relation logits for a batch of images."""
        return self.classifier(self.feature_extractor(x))

    def get_attention_visualization(self, x):
        """Delegate to the feature extractor's ``get_attention_weights``."""
        extractor = self.feature_extractor
        return extractor.get_attention_weights(x)

# =============================================================================
# DATASET ADAPTÉ POUR SPATIALSENSE+
# =============================================================================

class SpatialSenseDataset(Dataset):
    """SpatialSense+ dataset yielding ``(image, label, metadata)`` triples.

    Only annotations whose ``label`` field is truthy and whose predicate matches
    (case-insensitively, ignoring surrounding whitespace) an entry of
    ``SPATIAL_RELATIONS`` are kept. Images are looked up under
    ``<data_dir>/images/images/{flickr,nyu}``.
    """

    def __init__(self, data_dir, split='train', transform=None):
        """Load ``annotations.json`` from ``data_dir`` and index one split.

        Args:
            data_dir: Directory containing ``annotations.json`` and ``images/``.
            split: Value compared against each image entry's ``split`` field.
            transform: Optional callable applied to every loaded PIL image.
        """
        self.data_dir = data_dir
        self.split = split
        self.transform = transform

        # Class index <-> relation name, following the order of SPATIAL_RELATIONS.
        self.relation_to_idx = {rel: idx for idx, rel in enumerate(SPATIAL_RELATIONS)}
        self.idx_to_relation = {idx: rel for rel, idx in self.relation_to_idx.items()}

        self.annotations = self._load_annotations()
        self.data_samples = self._prepare_samples()

        print(f"Dataset SpatialSense+ {split} initialisé avec {len(self.data_samples)} échantillons")
        if len(self.data_samples) > 0:
            self._print_statistics()

    def _load_annotations(self):
        """Return the parsed annotation file, or ``[]`` if it is missing or invalid."""
        annotations_path = os.path.join(self.data_dir, 'annotations.json')
        try:
            with open(annotations_path, 'r') as f:
                return json.load(f)
        except FileNotFoundError:
            print(f"Erreur: Fichier {annotations_path} non trouvé!")
            return []
        except json.JSONDecodeError as e:
            print(f"Erreur lors du décodage JSON: {e}")
            return []

    def _find_image_path(self, image_url):
        """Map an annotation URL to its expected local file path.

        URLs containing ``staticflickr`` or whose file name has exactly one
        underscore are resolved under ``flickr``; everything else under ``nyu``.
        """
        base_dir = os.path.join(self.data_dir, "images", "images")
        filename = os.path.basename(image_url)

        if "staticflickr" in image_url or len(filename.split('_')) == 2:
            return os.path.join(base_dir, "flickr", filename)
        else:
            return os.path.join(base_dir, "nyu", filename)

    def _prepare_samples(self):
        """Build the flat list of sample dicts for ``self.split``.

        Images missing on disk are skipped, as are negative annotations and
        predicates that are not part of ``SPATIAL_RELATIONS``; counts of skipped
        images and filtered relations are printed when non-zero.
        """
        samples = []
        images_not_found = 0
        relations_filtered = 0

        # Case-insensitive lookup built once instead of once per annotation.
        # setdefault keeps the first spelling if two entries collide when lowered.
        canonical_relations = {}
        for rel in SPATIAL_RELATIONS:
            canonical_relations.setdefault(rel.lower(), rel)

        for img_data in self.annotations:
            if img_data['split'] != self.split:
                continue

            img_path = self._find_image_path(img_data['url'])

            if not os.path.exists(img_path):
                images_not_found += 1
                continue

            for ann in img_data['annotations']:
                if not ann['label']:  # keep positive annotations only
                    continue

                original_relation = ann['predicate']
                relation = canonical_relations.get(original_relation.lower().strip())
                if not relation:
                    relations_filtered += 1
                    continue

                samples.append({
                    'image_path': img_path,
                    'subject': ann['subject']['name'],
                    'object': ann['object']['name'],
                    'relation': relation,
                    'original_relation': original_relation,
                    'subject_bbox': ann['subject'].get('bbox', None),
                    'object_bbox': ann['object'].get('bbox', None)
                })

        if images_not_found > 0:
            print(f"Images non trouvées: {images_not_found}")
        if relations_filtered > 0:
            print(f"Relations filtrées: {relations_filtered}")

        return samples

    def _print_statistics(self):
        """Print the per-relation sample count and percentage for this split."""
        relation_counts = Counter([s['relation'] for s in self.data_samples])
        print(f"\nDistribution SpatialSense+ dans {self.split}:")
        total = len(self.data_samples)
        for relation in SPATIAL_RELATIONS:  # fixed, predefined order
            count = relation_counts.get(relation, 0)
            percentage = count / total * 100 if total > 0 else 0
            print(f"  {relation}: {count} ({percentage:.1f}%)")

    def __len__(self):
        """Number of prepared samples."""
        return len(self.data_samples)

    def __getitem__(self, idx):
        """Return ``(image, label, metadata)`` for sample ``idx``.

        If anything fails while loading or transforming the image, the error is
        printed and a placeholder sample with class index 0 is returned instead
        of raising. The placeholder's ``metadata['relation']`` is the name of
        class 0 so that label and metadata always agree.
        """
        sample = self.data_samples[idx]

        try:
            image = Image.open(sample['image_path']).convert('RGB')

            if self.transform:
                image = self.transform(image)

            label = self.relation_to_idx[sample['relation']]

            metadata = {
                'subject': sample['subject'],
                'object': sample['object'],
                'relation': sample['relation'],
                'original_relation': sample['original_relation']
            }

            return image, label, metadata

        except Exception as e:
            # Deliberate best-effort: keep the batch alive with a placeholder.
            print(f"Erreur chargement {sample['image_path']}: {e}")
            if self.transform:
                dummy_image = self.transform(Image.new('RGB', (224, 224), color='gray'))
            else:
                dummy_image = torch.zeros(3, 224, 224)

            fallback_label = 0
            return dummy_image, fallback_label, {
                'subject': 'error', 'object': 'error',
                # Previously hard-coded to 'next to', which contradicted label 0.
                'relation': self.idx_to_relation[fallback_label],
                'original_relation': 'error'
            }

# =============================================================================
# TRANSFORMATIONS AMÉLIORÉES
# =============================================================================

def create_improved_vit_transforms():
    """Build the training and validation image pipelines.

    Both pipelines end with ``ToTensor`` followed by normalisation with mean 0.5
    and std 0.5 per channel. The training pipeline additionally resizes to
    256x256, takes a random 224x224 crop, applies a random rotation of up to
    15 degrees, colour jitter and a random translation of up to 10%.

    No horizontal flip is applied: mirroring an image swaps what is left and
    right of what, so with direction-sensitive classes such as
    'to the left of' / 'to the right of' a flip without a matching label swap
    would produce wrongly labelled training samples.

    Returns:
        ``(train_transform, val_transform)``.
    """
    normalize = transforms.Normalize(
        mean=[0.5, 0.5, 0.5],
        std=[0.5, 0.5, 0.5]
    )

    # Label-preserving augmentations only (see docstring for the missing flip).
    train_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomCrop((224, 224)),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2,
                              saturation=0.2, hue=0.1),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        transforms.ToTensor(),
        normalize
    ])

    # Deterministic preprocessing for validation / test.
    val_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize
    ])

    return train_transform, val_transform

# =============================================================================
# FONCTIONS D'ENTRAÎNEMENT AMÉLIORÉES
# =============================================================================

def get_improved_optimizer_and_scheduler(model, train_loader_len):
    """Create an AdamW optimizer with two learning rates and a warmup+cosine schedule.

    Trainable parameters whose name contains ``'vit'`` use
    ``LEARNING_RATE * 0.1``; all other trainable parameters use
    ``LEARNING_RATE``. The scheduler is meant to be stepped once per batch: the
    multiplier grows linearly from 0 to 1 over
    ``train_loader_len * WARMUP_EPOCHS`` steps, then follows a half cosine down
    to ``MIN_LR / LEARNING_RATE`` at ``train_loader_len * EPOCHS`` steps and
    stays at that floor afterwards.

    Args:
        model: Module whose ``named_parameters()`` are optimised.
        train_loader_len: Number of batches (scheduler steps) per epoch.

    Returns:
        ``(optimizer, scheduler)``.
    """
    vit_params = []
    mlp_params = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if 'vit' in name:
            vit_params.append(param)
        else:
            mlp_params.append(param)

    optimizer = optim.AdamW([
        {'params': vit_params, 'lr': LEARNING_RATE * 0.1},  # smaller LR for the backbone
        {'params': mlp_params, 'lr': LEARNING_RATE}         # full LR for everything else
    ], weight_decay=WEIGHT_DECAY)

    total_steps = train_loader_len * EPOCHS
    warmup_steps = train_loader_len * WARMUP_EPOCHS
    # At least 1 so that EPOCHS == WARMUP_EPOCHS does not divide by zero.
    decay_steps = max(1, total_steps - warmup_steps)
    min_factor = MIN_LR / LEARNING_RATE

    def lr_lambda(step):
        if step < warmup_steps:
            return step / warmup_steps
        # Clamp progress: without it the cosine turns around and the LR rises
        # again whenever training runs for more steps than total_steps.
        progress = min(1.0, (step - warmup_steps) / decay_steps)
        return max(min_factor, 0.5 * (1 + np.cos(np.pi * progress)))

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    return optimizer, scheduler

def train_one_epoch_improved(model, dataloader, criterion, optimizer, scheduler, device, epoch):
    """Train ``model`` for one pass over ``dataloader``.

    Uses CUDA automatic mixed precision with a gradient scaler when
    ``device.type == 'cuda'`` and plain full-precision training otherwise.
    Gradients are clipped to a norm of 1.0 and ``scheduler.step()`` is called
    after every batch.

    Args:
        model: Module to train.
        dataloader: Iterable of ``(images, labels, metadata)`` batches.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        scheduler: LR scheduler stepped once per batch.
        device: ``torch.device`` the batches are moved to.
        epoch: Zero-based epoch index, used only for the progress-bar label.

    Returns:
        ``(mean batch loss, accuracy in percent)``; ``(0.0, 0.0)`` when the
        dataloader yields no batch.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    # Gradient scaler is only meaningful on CUDA.
    scaler = torch.cuda.amp.GradScaler() if device.type == 'cuda' else None

    progress_bar = tqdm(dataloader, desc=f'Epoch {epoch+1}')
    for images, labels, _ in progress_bar:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()

        if scaler is not None:
            with torch.cuda.amp.autocast():
                outputs = model(images)
                loss = criterion(outputs, labels)

            scaler.scale(loss).backward()
            # Unscale first so that clipping sees the true gradient norm.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

        scheduler.step()

        num_batches += 1
        running_loss += loss.item()
        _, predicted = torch.max(outputs.detach(), 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        current_lr = optimizer.param_groups[0]['lr']
        progress_bar.set_postfix({
            'loss': running_loss / num_batches,
            'acc': 100. * correct / total,
            'lr': f'{current_lr:.2e}'
        })

    # An empty loader used to raise ZeroDivisionError here.
    if num_batches == 0 or total == 0:
        return 0.0, 0.0

    return running_loss / num_batches, 100. * correct / total

def evaluate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without computing gradients.

    Args:
        model: Module to evaluate (switched to eval mode).
        dataloader: Iterable of ``(images, labels, metadata)`` batches.
        criterion: Loss callable taking ``(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        ``(mean batch loss, accuracy in percent)``; ``(0.0, 0.0)`` when the
        dataloader yields no batch.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    with torch.no_grad():
        for images, labels, _ in tqdm(dataloader, desc='Evaluating'):
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            num_batches += 1
            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # An empty loader used to raise ZeroDivisionError here.
    if num_batches == 0 or total == 0:
        return 0.0, 0.0

    return running_loss / num_batches, 100. * correct / total

# =============================================================================
# CORRECTION ANALYSE PERFORMANCE PAR RELATION
# =============================================================================

def analyze_performance_per_relation_fixed(model, dataloader, device):
    """Compute and plot the accuracy of ``model`` for every relation class.

    The true relation of each sample is derived from the label tensor (indexing
    ``SPATIAL_RELATIONS``), i.e. the same ground truth the loss and ``evaluate``
    use. The batched metadata is ignored, so the result does not depend on how
    the collate function arranges it and cannot fail on an unknown metadata
    string.

    A two-panel figure (accuracy per relation, sample count per relation) is
    saved to ``spatialsense_performance_per_relation.png`` and shown.

    Args:
        model: Module to evaluate (switched to eval mode).
        dataloader: Iterable of ``(images, labels, metadata)`` batches.
        device: Device the batches are moved to.

    Returns:
        ``(relation_accuracies, relation_total)``: two dicts keyed by relation
        name, holding the accuracy in percent (0.0 for classes without samples)
        and the number of samples.
    """
    model.eval()

    relation_correct = {rel: 0 for rel in SPATIAL_RELATIONS}
    relation_total = {rel: 0 for rel in SPATIAL_RELATIONS}

    with torch.no_grad():
        for images, labels, _ in tqdm(dataloader, desc='Analyzing per relation'):
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            _, predicted = torch.max(outputs, 1)

            # One host transfer per batch instead of one .item() per sample.
            for true_idx, pred_idx in zip(labels.tolist(), predicted.tolist()):
                true_rel = SPATIAL_RELATIONS[true_idx]
                relation_total[true_rel] += 1
                if pred_idx == true_idx:
                    relation_correct[true_rel] += 1

    relation_accuracies = {}
    for rel in SPATIAL_RELATIONS:
        if relation_total[rel] > 0:
            relation_accuracies[rel] = relation_correct[rel] / relation_total[rel] * 100
        else:
            relation_accuracies[rel] = 0.0

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

    # Panel 1: accuracy per relation.
    relations = list(relation_accuracies.keys())
    accuracies = list(relation_accuracies.values())
    counts = [relation_total[rel] for rel in relations]

    bars = ax1.bar(range(len(relations)), accuracies,
                   color='skyblue', alpha=0.8)
    ax1.set_xlabel('Relations SpatialSense+')
    ax1.set_ylabel('Accuracy (%)')
    ax1.set_title('Performance par Relation SpatialSense+')
    ax1.set_xticks(range(len(relations)))
    ax1.set_xticklabels(relations, rotation=45, ha='right')
    ax1.grid(True, axis='y', alpha=0.3)

    # Label every bar with its accuracy and sample count.
    for bar, acc, count in zip(bars, accuracies, counts):
        height = bar.get_height()
        ax1.annotate(f'{acc:.1f}%\n(n={count})',
                    xy=(bar.get_x() + bar.get_width() / 2, height),
                    xytext=(0, 3), textcoords="offset points",
                    ha='center', va='bottom', fontsize=8)

    # Panel 2: number of samples per relation.
    ax2.bar(range(len(relations)), counts, color='lightcoral', alpha=0.8)
    ax2.set_xlabel('Relations SpatialSense+')
    ax2.set_ylabel('Nombre d\'échantillons')
    ax2.set_title('Distribution des Échantillons par Relation')
    ax2.set_xticks(range(len(relations)))
    ax2.set_xticklabels(relations, rotation=45, ha='right')
    ax2.grid(True, axis='y', alpha=0.3)

    plt.tight_layout()
    plt.savefig('spatialsense_performance_per_relation.png', dpi=300, bbox_inches='tight')
    plt.show()

    return relation_accuracies, relation_total

# =============================================================================
# HEATMAP GÉNÉRATION (inspirée de l'article)
# =============================================================================

def generate_attention_heatmap_like_article(model, image, target_class, device):
    """Produce a 2-D attention heatmap and the model output for a single image.

    The image gets a leading batch axis, is moved to ``device`` and passed both
    through ``model`` and through ``model.get_attention_visualization``. The
    first row of the returned attention is reshaped to a square grid (side =
    integer square root of its length) and each cell is repeated
    ``224 // side`` times along both axes.

    Args:
        model: Module exposing ``get_attention_visualization``.
        image: Single image tensor without batch axis.
        target_class: Accepted for interface compatibility; not used.
        device: Device the image is moved to.

    Returns:
        ``(heatmap, output)``: the upsampled attention as a NumPy array and the
        raw model output for the one-image batch.
    """
    model.eval()

    with torch.no_grad():
        batch = image.unsqueeze(0).to(device)

        output = model(batch)

        # Attention of the only sample in the batch, as a flat NumPy vector.
        patch_attention = model.get_attention_visualization(batch)[0].cpu().numpy()

        grid_side = int(np.sqrt(len(patch_attention)))
        cell = 224 // grid_side

        # Kronecker product with a block of ones = nearest-neighbour upsampling.
        heatmap = np.kron(
            patch_attention.reshape(grid_side, grid_side),
            np.ones((cell, cell)),
        )

        return heatmap, output

def visualize_spatialsense_predictions(model, dataset, device, num_samples=6):
    """Plot predictions and attention heatmaps for the first samples of ``dataset``.

    One row is drawn per sample with four panels: the de-normalised image, the
    attention heatmap, the heatmap overlaid on the image, and a text box with
    ground truth, prediction and confidence. The figure is saved to
    ``spatialsense_predictions_analysis.png`` and shown.

    The number of rows is ``min(num_samples, len(dataset))`` so that a short
    dataset does not leave blank rows; nothing is drawn for an empty dataset.

    Args:
        model: Model passed to ``generate_attention_heatmap_like_article``.
        dataset: Indexable dataset returning ``(image, label, metadata)`` and
            exposing ``idx_to_relation``.
        device: Device used for inference.
        num_samples: Maximum number of samples to display.
    """
    model.eval()

    n_rows = min(num_samples, len(dataset))
    if n_rows <= 0:
        print("Aucun échantillon à visualiser")
        return

    # squeeze=False keeps `axes` two-dimensional even when there is a single row.
    fig, axes = plt.subplots(n_rows, 4, figsize=(16, 4 * n_rows), squeeze=False)

    for i in range(n_rows):
        image, label, metadata = dataset[i]

        # Undo the mean=0.5 / std=0.5 normalisation for display (CHW -> HWC).
        img_display = image.clone()
        img_display = (img_display + 1) / 2
        img_display = img_display.clamp(0, 1).permute(1, 2, 0)

        attention_map, output = generate_attention_heatmap_like_article(
            model, image, label, device
        )

        probs = torch.softmax(output, dim=1)
        pred_idx = torch.argmax(probs)
        pred_prob = probs[0, pred_idx].item()

        pred_relation = dataset.idx_to_relation[pred_idx.item()]
        true_relation = metadata['relation']

        # Panel 1: input image.
        axes[i, 0].imshow(img_display)
        axes[i, 0].set_title(f"{metadata['subject']} - {metadata['object']}")
        axes[i, 0].axis('off')

        # Panel 2: attention heatmap alone.
        axes[i, 1].imshow(attention_map, cmap='hot', alpha=0.8)
        axes[i, 1].set_title('Attention Heatmap\n(équivalent article)')
        axes[i, 1].axis('off')

        # Panel 3: heatmap blended over the image.
        axes[i, 2].imshow(img_display)
        axes[i, 2].imshow(attention_map, cmap='jet', alpha=0.4)
        axes[i, 2].set_title('Superposition\n(style Figure 2)')
        axes[i, 2].axis('off')

        # Panel 4: textual summary of the prediction.
        axes[i, 3].axis('off')
        info_text = f"""
        SpatialSense+ Sample:

        Vérité: {true_relation}
        Prédiction: {pred_relation}
        Original: {metadata.get('original_relation', 'N/A')}

        Confiance: {pred_prob:.2%}

        Résultat: {'✓ CORRECT' if pred_relation == true_relation else '✗ INCORRECT'}

        Attention:
        Max: {attention_map.max():.3f}
        Min: {attention_map.min():.3f}
        """
        color = 'lightgreen' if pred_relation == true_relation else 'lightcoral'
        axes[i, 3].text(0.1, 0.5, info_text, transform=axes[i, 3].transAxes,
                       fontsize=10, verticalalignment='center',
                       bbox=dict(boxstyle='round', facecolor=color, alpha=0.8))

    plt.suptitle('Analyse SpatialSense+ - Style Haldekar et al. 2017', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.savefig('spatialsense_predictions_analysis.png', dpi=300, bbox_inches='tight')
    plt.show()

# =============================================================================
# VISUALISATIONS RÉSULTATS K-FOLD
# =============================================================================

def visualize_spatialsense_results(fold_results):
    """Draw a 2x2 summary figure of the k-fold training results.

    Panels: loss curves, accuracy curves, best validation accuracy per fold and
    a text box with aggregate statistics. The figure is saved to
    ``spatialsense_results_analysis.png`` and shown.

    Args:
        fold_results: List of per-fold dicts providing ``fold``,
            ``train_losses``, ``val_losses``, ``train_accs``, ``val_accs`` and
            ``best_val_acc``.
    """
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))

    def _draw_curves(ax, train_key, val_key, title, ylabel):
        # One solid (train) and one dashed (validation) line per fold.
        for fold_result in fold_results:
            fold_id = fold_result['fold']
            xs = range(1, len(fold_result[train_key]) + 1)

            ax.plot(xs, fold_result[train_key],
                    label=f'Fold {fold_id} Train', linewidth=2)
            ax.plot(xs, fold_result[val_key], '--',
                    label=f'Fold {fold_id} Val', alpha=0.8)

        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.set_xlabel('Époque')
        ax.set_ylabel(ylabel)
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        ax.grid(True, alpha=0.3)

    # Panels 1 and 2: learning curves.
    _draw_curves(ax1, 'train_losses', 'val_losses', 'Loss SpatialSense+', 'Loss')
    _draw_curves(ax2, 'train_accs', 'val_accs', 'Accuracy SpatialSense+', 'Accuracy (%)')

    # Panel 3: best validation accuracy of each fold.
    folds = []
    best_val_accs = []
    for fold_result in fold_results:
        folds.append(fold_result['fold'])
        best_val_accs.append(fold_result['best_val_acc'])

    x = np.arange(len(folds))
    bars = ax3.bar(x, best_val_accs, color='lightblue', alpha=0.8)

    ax3.set_title('Performance SpatialSense+ par Fold', fontsize=14, fontweight='bold')
    ax3.set_xlabel('Fold')
    ax3.set_ylabel('Best Validation Accuracy (%)')
    ax3.set_xticks(x)
    ax3.set_xticklabels(folds)
    ax3.grid(True, axis='y', alpha=0.3)

    for bar, acc in zip(bars, best_val_accs):
        top = bar.get_height()
        ax3.annotate(f'{acc:.1f}%',
                    xy=(bar.get_x() + bar.get_width() / 2, top),
                    xytext=(0, 3), textcoords="offset points",
                    ha='center', va='bottom', fontsize=9)

    # Panel 4: aggregate statistics as text.
    ax4.axis('off')

    mean_best = np.mean(best_val_accs)
    std_best = np.std(best_val_accs)

    summary_text = f"""
    RÉSULTATS SPATIALSENSE+

    Dataset:
    • Relations: {len(SPATIAL_RELATIONS)} classes
    • Architecture: ViT + MLP Haldekar-style
    • Modèle: ViT-Base-224 + MLP(4096→512→256→{len(SPATIAL_RELATIONS)})

    Performance:
    • Accuracy moyenne: {mean_best:.2f}% ± {std_best:.2f}%
    • Min: {min(best_val_accs):.2f}%, Max: {max(best_val_accs):.2f}%

    Paramètres:
    • Epochs: {EPOCHS}
    • K-folds: {len(fold_results)}
    • Learning rate: {LEARNING_RATE}

    Comparaison article Haldekar:
    • Article (SUN09): 55.98% test accuracy
    • Notre approche: {mean_best:.2f}% (ViT vs VGG)
    """

    # Green box when the mean beats the 55.98% reference, yellow otherwise.
    if mean_best > 55.98:
        color = 'lightgreen'
    else:
        color = 'lightyellow'
    ax4.text(0.05, 0.95, summary_text, transform=ax4.transAxes,
             fontsize=11, verticalalignment='top',
             bbox=dict(boxstyle='round,pad=0.5', facecolor=color, alpha=0.8))

    plt.tight_layout()
    plt.savefig('spatialsense_results_analysis.png', dpi=300, bbox_inches='tight')
    plt.show()

# =============================================================================
# FONCTION D'ENTRAÎNEMENT PRINCIPALE AMÉLIORÉE
# =============================================================================

def train_improved_spatialsense_model(data_dir, epochs=EPOCHS, k_folds=K_FOLDS):
    """Train the ViT + MLP model with k-fold cross-validation on the 'train' split.

    For every fold a fresh model, optimizer and scheduler are created. After
    each epoch the model is validated, and whenever the validation accuracy
    improves its weights are saved to
    ``improved_spatialsense_model_fold_<k>.pth``.

    Validation samples are read through a second dataset object that uses the
    deterministic validation transform, so validation metrics are not perturbed
    by the random training augmentations. After the last epoch the best saved
    weights are loaded back, so the returned ``'model'`` corresponds to
    ``'best_val_acc'`` (``'final_*'`` metrics still describe the last epoch).

    Args:
        data_dir: Dataset directory passed to ``SpatialSenseDataset``.
        epochs: Number of epochs per fold.
        k_folds: Number of cross-validation folds.

    Returns:
        A list with one result dict per fold (``[]`` if the dataset is empty).
    """
    print("="*60)
    print("ENTRAÎNEMENT AMÉLIORÉ HALDEKAR-STYLE + ViT")
    print("="*60)

    train_transform, val_transform = create_improved_vit_transforms()

    full_dataset = SpatialSenseDataset(
        data_dir=data_dir,
        split='train',
        transform=train_transform
    )

    if len(full_dataset) == 0:
        print("Erreur: Dataset vide!")
        return []

    # Same files and split, hence the same sample order and indices as
    # full_dataset, but without random augmentations.
    val_view_dataset = SpatialSenseDataset(
        data_dir=data_dir,
        split='train',
        transform=val_transform
    )

    kfold = KFold(n_splits=k_folds, shuffle=True, random_state=42)
    fold_results = []
    sample_indices = np.arange(len(full_dataset))

    for fold, (train_indices, val_indices) in enumerate(kfold.split(sample_indices)):
        print(f"\n{'='*40}")
        print(f"FOLD {fold + 1}/{k_folds} - AMÉLIORÉ")
        print(f"{'='*40}")

        train_subset = torch.utils.data.Subset(full_dataset, train_indices)
        val_subset = torch.utils.data.Subset(val_view_dataset, val_indices)

        train_loader = DataLoader(
            train_subset, batch_size=BATCH_SIZE, shuffle=True,
            num_workers=4, pin_memory=True, persistent_workers=True
        )

        val_loader = DataLoader(
            val_subset, batch_size=BATCH_SIZE, shuffle=False,
            num_workers=4, pin_memory=True, persistent_workers=True
        )

        # Fresh model per fold, with partial fine-tuning of the backbone.
        model = ImprovedViTSpatialRelationModel(
            num_relations=len(SPATIAL_RELATIONS),
            vit_model='google/vit-base-patch16-224',
            freeze_vit=False,
            fine_tune_layers=4
        )
        model = model.to(DEVICE)

        optimizer, scheduler = get_improved_optimizer_and_scheduler(
            model, len(train_loader)
        )

        criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

        train_losses, val_losses = [], []
        train_accs, val_accs = [], []
        best_val_acc = 0.0
        checkpoint_path = f'improved_spatialsense_model_fold_{fold+1}.pth'
        checkpoint_saved = False

        for epoch in range(epochs):
            print(f"\nEpoch {epoch+1}/{epochs}")

            train_loss, train_acc = train_one_epoch_improved(
                model, train_loader, criterion, optimizer, scheduler, DEVICE, epoch
            )
            train_losses.append(train_loss)
            train_accs.append(train_acc)

            val_loss, val_acc = evaluate(model, val_loader, criterion, DEVICE)
            val_losses.append(val_loss)
            val_accs.append(val_acc)

            print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
            print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
            print(f"LR: {optimizer.param_groups[0]['lr']:.2e}")

            # Checkpoint whenever validation accuracy improves.
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save(model.state_dict(), checkpoint_path)
                checkpoint_saved = True

        # Hand back the best weights rather than whatever the last epoch left.
        if checkpoint_saved:
            model.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))

        fold_result = {
            'fold': fold + 1,
            'train_losses': train_losses,
            'val_losses': val_losses,
            'train_accs': train_accs,
            'val_accs': val_accs,
            'best_val_acc': best_val_acc,
            'final_train_acc': train_accs[-1],
            'final_val_acc': val_accs[-1],
            'total_epochs': len(train_losses),
            'model': model
        }

        fold_results.append(fold_result)
        print(f"\nRésumé Fold {fold + 1}: Meilleure acc: {best_val_acc:.2f}%")

    return fold_results

# =============================================================================
# FONCTION PRINCIPALE AMÉLIORÉE
# =============================================================================

def main_improved_spatialsense_experiment():
    """Run the whole experiment: k-fold training, plots, test evaluation, report.

    Steps:
      1. Abort if ``data/spatialsense`` does not exist or training returns nothing.
      2. Plot the cross-validation results.
      3. Take the model of the fold with the highest ``best_val_acc`` and
         evaluate it on the 'test' split (falling back to 'val' if 'test' is empty).
      4. Save that model's weights to ``improved_spatialsense_final_model.pth``
         and print a comparison with the reference accuracy of 55.98%.

    Returns:
        The list of fold results, or ``None`` when the experiment was aborted.
    """
    DATA_DIR = "data/spatialsense"

    if not os.path.exists(DATA_DIR):
        print(f"\nERREUR: Le répertoire {DATA_DIR} n'existe pas!")
        return

    print("\n     Entraînement avec améliorations")
    results = train_improved_spatialsense_model(DATA_DIR)

    if not results:
        print("Erreur pendant l'entraînement")
        return

    print("\n   Analyse des résultats améliorés")
    visualize_spatialsense_results(results)

    print("\n   Test du meilleur modèle")
    best_fold_idx = np.argmax([r['best_val_acc'] for r in results])
    best_model = results[best_fold_idx]['model']

    _, test_transform = create_improved_vit_transforms()
    test_dataset = SpatialSenseDataset(DATA_DIR, split='test', transform=test_transform)

    # Fall back to the validation split when no test sample is available.
    if len(test_dataset) == 0:
        test_dataset = SpatialSenseDataset(DATA_DIR, split='val', transform=test_transform)

    # Stays None when no evaluation set exists; checked explicitly further down
    # instead of probing locals().
    test_acc = None

    if len(test_dataset) > 0:
        test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE,
                               shuffle=False, num_workers=4, pin_memory=True)

        visualize_spatialsense_predictions(best_model, test_dataset, DEVICE, num_samples=8)

        relation_accs, relation_counts = analyze_performance_per_relation_fixed(
            best_model, test_loader, DEVICE
        )

        criterion = nn.CrossEntropyLoss()
        test_loss, test_acc = evaluate(best_model, test_loader, criterion, DEVICE)

        print(f"\n  Résultats finaux améliorés:")
        print(f"  - Test Loss: {test_loss:.4f}")
        print(f"  - Test Accuracy: {test_acc:.2f}%")
        print(f"  - Amélioration vs article: {test_acc - 55.98:.2f}%")

        print("\n       Performance par relation:")
        sorted_relations = sorted(relation_accs.items(), key=lambda x: x[1], reverse=True)

        print("\nMeilleures performances:")
        for rel, acc in sorted_relations[:3]:
            count = relation_counts[rel]
            print(f"  {rel}: {acc:.2f}% (n={count})")

        print("\nPires performances:")
        for rel, acc in sorted_relations[-3:]:
            count = relation_counts[rel]
            print(f"  {rel}: {acc:.2f}% (n={count})")

        # Per-relation accuracies weighted by their sample counts.
        overall_acc = sum([relation_counts[rel] * relation_accs[rel] for rel in SPATIAL_RELATIONS]) / sum(relation_counts.values())
        print(f"\nAccuracy globale pondérée: {overall_acc:.2f}%")

    else:
        print("Aucun dataset de test disponible")

    torch.save(best_model.state_dict(), 'improved_spatialsense_final_model.pth')

    mean_acc = np.mean([r['best_val_acc'] for r in results])
    std_acc = np.std([r['best_val_acc'] for r in results])

    print("\n" + "="*80)
    print("COMPARAISON FINALE AVEC HALDEKAR ET AL. 2017")
    print("="*80)

    print(f"\nArticle original:")
    print(f"  - Dataset: SUN09 (3 relations)")
    print(f"  - Architecture: VGGNet FC-7 + MLP")
    print(f"  - Performance: 55.98% test accuracy")

    print(f"\nNotre implémentation améliorée:")
    print(f"  - Dataset: SpatialSense+ ({len(SPATIAL_RELATIONS)} relations)")
    print(f"  - Architecture: ViT-Base + MLP amélioré")
    print(f"  - Performance: {mean_acc:.2f}% ± {std_acc:.2f}%")

    if test_acc is not None:
        print(f"  - Test accuracy: {test_acc:.2f}%")
        if test_acc > 55.98:
            print(f"          Amélioration: +{test_acc - 55.98:.2f}%")
        else:
            print(f"        À améliorer: {test_acc - 55.98:.2f}% (tâche plus difficile)")

    print(f"\nTechniques modernes ajoutées:")
    print(f"     Vision Transformer (2021) vs VGGNet (2014)")
    print(f"     Fine-tuning partiel")
    print(f"     BatchNormalization")
    print(f"     Learning rate scheduling avec warmup")
    print(f"     Mixed precision training")
    print(f"     Label smoothing")
    print(f"     Data augmentation renforcée")
    print(f"     Gradient clipping")

    print("\n   Expérience améliorée terminée avec succès!")
    print(f"       {len(SPATIAL_RELATIONS)} relations vs 3 dans l'article original")
    print(f"     Architecture moderne adaptée de Haldekar et al. 2017")

    return results

# =============================================================================
# POINT D'ENTRÉE
# =============================================================================

# Entry point: run the full experiment when this code is executed as the main
# module (in a notebook cell __name__ is "__main__", so the cell runs it directly).
if __name__ == "__main__":
    main_improved_spatialsense_experiment()

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image, ImageDraw
import numpy as np
import json
import os
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt
from tqdm import tqdm
import cv2
from collections import Counter
import warnings
warnings.filterwarnings('ignore')

# =============================================================================
# CONFIGURATION GLOBALE
# =============================================================================

# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Device utilisé: {DEVICE}")

# Hyperparameters
BATCH_SIZE = 8          # kept small because every sample carries two images
LEARNING_RATE = 1e-3
EPOCHS = 15
K_FOLDS = 5
IMG_SIZE = 224
DROPOUT_RATE = 0.4

# SpatialSense+ spatial relations; the list position is the class index.
SPATIAL_RELATIONS = [
    'above',
    'behind',
    'in',
    'in front of',
    'next to',
    'on',
    'to the left of',
    'to the right of',
    'under',
]

# Short banner describing the dual-input setup.
print("\n".join([
    "Architecture DUALE avec masquage bounding box:",
    f"  - Relations: {len(SPATIAL_RELATIONS)}",
    "  - Entrée 1: Image originale",
    "  - Entrée 2: Image masquée (seules les bounding boxes visibles)",
]))

# =============================================================================
# DATASET DUAL INPUT AVEC MASQUAGE BOUNDING BOX
# =============================================================================

class DualInputSpatialSenseDataset(Dataset):
    """
    SpatialSense+ dataset yielding two views of every positive relation.

    - View 1: the original RGB image.
    - View 2: a masked copy in which only the pixels inside the subject and
      object bounding boxes are kept; everything else is filled with
      ``mask_background_color``.

    Each item is ``(original_image, masked_image, label, metadata)`` where
    ``label`` is the index of the relation in ``SPATIAL_RELATIONS`` and
    ``metadata`` is a dict with the names, relation and boxes of the sample.

    Args:
        data_dir: directory containing ``annotations.json`` and the
            ``images/images/{flickr,nyu}`` folders.
        split: value compared against each image entry's ``'split'`` field.
        transform: optional callable applied to both views.
        mask_background_color: RGB fill used outside the bounding boxes.
    """

    def __init__(self, data_dir, split='train', transform=None, mask_background_color=(128, 128, 128)):
        self.data_dir = data_dir
        self.split = split
        self.transform = transform
        self.mask_background_color = mask_background_color  # fill colour outside the boxes

        self.relation_to_idx = {rel: idx for idx, rel in enumerate(SPATIAL_RELATIONS)}
        self.idx_to_relation = {idx: rel for rel, idx in self.relation_to_idx.items()}

        self.annotations = self._load_annotations()
        self.data_samples = self._prepare_samples()

        print(f"Dataset DUAL SpatialSense+ {split}: {len(self.data_samples)} échantillons")
        if len(self.data_samples) > 0:
            self._print_statistics()

    def _load_annotations(self):
        """Return the parsed ``annotations.json``, or ``[]`` when the file is missing."""
        annotations_path = os.path.join(self.data_dir, 'annotations.json')
        try:
            with open(annotations_path, 'r') as f:
                return json.load(f)
        except FileNotFoundError:
            print(f"Erreur: {annotations_path} non trouvé!")
            return []

    def _find_image_path(self, image_url):
        """
        Map an annotation URL to a local file path.

        URLs containing ``"staticflickr"`` or whose file name splits into
        exactly two ``_``-separated parts are looked up under ``flickr/``;
        everything else under ``nyu/``.
        """
        base_dir = os.path.join(self.data_dir, "images", "images")
        filename = os.path.basename(image_url)

        if "staticflickr" in image_url or len(filename.split('_')) == 2:
            return os.path.join(base_dir, "flickr", filename)
        else:
            return os.path.join(base_dir, "nyu", filename)

    def _prepare_samples(self):
        """
        Build the flat list of samples for ``self.split``.

        Only annotations with a truthy ``'label'``, a predicate matching one of
        ``SPATIAL_RELATIONS`` (case-insensitive, surrounding whitespace
        ignored) and a ``'bbox'`` on both subject and object are kept. Images
        missing on disk are skipped and counted.
        """
        samples = []
        images_not_found = 0

        # Built once: lower-cased predicate -> canonical spelling.
        canonical_relation = {rel.lower(): rel for rel in SPATIAL_RELATIONS}

        for img_data in self.annotations:
            if img_data['split'] != self.split:
                continue

            img_path = self._find_image_path(img_data['url'])

            if not os.path.exists(img_path):
                images_not_found += 1
                continue

            for ann in img_data['annotations']:
                if not ann['label']:
                    continue

                relation = canonical_relation.get(ann['predicate'].lower().strip())
                if relation is None:
                    continue

                if 'bbox' not in ann['subject'] or 'bbox' not in ann['object']:
                    continue

                samples.append({
                    'image_path': img_path,
                    'subject': ann['subject']['name'],
                    'object': ann['object']['name'],
                    'relation': relation,
                    'original_relation': ann['predicate'],
                    'subject_bbox': ann['subject']['bbox'],  # [y1, y2, x1, x2]
                    'object_bbox': ann['object']['bbox'],   # [y1, y2, x1, x2]
                    'image_width': img_data['width'],
                    'image_height': img_data['height']
                })

        if images_not_found > 0:
            print(f"Images non trouvées: {images_not_found}")

        return samples

    def _create_masked_image(self, image, subject_bbox, object_bbox):
        """
        Return a copy of ``image`` where only the two bounding boxes are visible.

        Boxes are ``[y1, y2, x1, x2]``; coordinates are truncated to int and
        clamped to the image bounds. Empty or inverted boxes are ignored.
        """
        img_array = np.array(image)

        # Canvas of the same shape, pre-filled with the background colour.
        masked_array = np.full_like(img_array, self.mask_background_color)

        def apply_bbox(bbox):
            """Copy the pixels of one box from the source image onto the canvas."""
            try:
                y1, y2, x1, x2 = bbox
                # Clamp so that slicing never leaves the image.
                y1 = max(0, min(int(y1), img_array.shape[0]))
                y2 = max(0, min(int(y2), img_array.shape[0]))
                x1 = max(0, min(int(x1), img_array.shape[1]))
                x2 = max(0, min(int(x2), img_array.shape[1]))

                if y2 > y1 and x2 > x1:
                    masked_array[y1:y2, x1:x2] = img_array[y1:y2, x1:x2]
            except Exception as e:
                # Best effort: a malformed box leaves the canvas untouched.
                print(f"Erreur application bbox {bbox}: {e}")

        apply_bbox(subject_bbox)
        apply_bbox(object_bbox)

        return Image.fromarray(masked_array)

    def _transform_pair(self, first_image, second_image):
        """
        Apply ``self.transform`` to both views with the same random draws.

        The torch RNG state is captured before the first call and restored
        before the second, so random augmentations (flips, jitter, ...) pick
        identical parameters for the two views and keep them spatially
        aligned. This assumes the transform draws its randomness from torch's
        global RNG, as torchvision transforms do.
        """
        if not self.transform:
            return first_image, second_image

        rng_state = torch.get_rng_state()
        first_image = self.transform(first_image)
        torch.set_rng_state(rng_state)
        second_image = self.transform(second_image)
        return first_image, second_image

    def _print_statistics(self):
        """Print the count and percentage of samples for every relation."""
        relation_counts = Counter([s['relation'] for s in self.data_samples])
        print(f"\nDistribution DUAL SpatialSense+ dans {self.split}:")
        total = len(self.data_samples)

        for relation in SPATIAL_RELATIONS:
            count = relation_counts.get(relation, 0)
            percentage = count / total * 100 if total > 0 else 0
            print(f"   {relation}: {count} ({percentage:.1f}%)")

    def __len__(self):
        return len(self.data_samples)

    def __getitem__(self, idx):
        """
        Return ``(original_image, masked_image, label, metadata)`` for ``idx``.

        If loading or processing fails, the error is printed and a placeholder
        sample (flat images, label 0) is returned instead of raising, so a
        single bad file does not abort a whole epoch.
        """
        sample = self.data_samples[idx]

        try:
            # The context manager closes the file handle once pixels are loaded.
            with Image.open(sample['image_path']) as img_file:
                original_image = img_file.convert('RGB')

            masked_image = self._create_masked_image(
                original_image,
                sample['subject_bbox'],
                sample['object_bbox']
            )

            # Same random augmentation for both views (see _transform_pair).
            original_image, masked_image = self._transform_pair(original_image, masked_image)

            label = self.relation_to_idx[sample['relation']]

            metadata = {
                'subject': sample['subject'],
                'object': sample['object'],
                'relation': sample['relation'],
                'original_relation': sample['original_relation'],
                'subject_bbox': sample['subject_bbox'],
                'object_bbox': sample['object_bbox']
            }

            return original_image, masked_image, label, metadata

        except Exception as e:
            print(f"Erreur chargement {sample['image_path']}: {e}")
            # Placeholder images so that batching can continue.
            dummy_original = Image.new('RGB', (224, 224), color='gray')
            dummy_masked = Image.new('RGB', (224, 224), color=self.mask_background_color)

            if self.transform:
                dummy_original, dummy_masked = self._transform_pair(dummy_original, dummy_masked)
            else:
                dummy_original = torch.zeros(3, 224, 224)
                dummy_masked = torch.zeros(3, 224, 224)

            # Keep the metadata relation consistent with the returned label.
            fallback_label = 0
            return dummy_original, dummy_masked, fallback_label, {
                'subject': 'error', 'object': 'error',
                'relation': self.idx_to_relation[fallback_label], 'original_relation': 'error',
                'subject_bbox': [0, 0, 0, 0], 'object_bbox': [0, 0, 0, 0]
            }

# =============================================================================
# ARCHITECTURE DUALE VGG
# =============================================================================

class DualVGGFeatureExtractor(nn.Module):
    """
    Two-branch VGG16 feature extractor with shared, frozen weights.

    - Branch 1: original image.
    - Branch 2: masked image (bounding boxes only).

    Both inputs go through the same VGG16 convolutional stack, average pool
    and the first six classifier layers (4096-d output), then the two vectors
    are fused.

    Args:
        fusion_method: ``'concat'`` (8192-d output), ``'add'`` (4096-d) or
            ``'attention'`` (4096-d, softmax-weighted sum of the branches).

    Raises:
        ValueError: if ``fusion_method`` is not one of the three values above.
    """

    SUPPORTED_FUSIONS = ('concat', 'add', 'attention')

    def __init__(self, fusion_method='concat'):
        # Zero-argument super(): the notebook rebinds this class name in a
        # later cell, which would break super(DualVGGFeatureExtractor, self).
        super().__init__()

        # Fail early instead of leaving fusion_dim undefined.
        if fusion_method not in self.SUPPORTED_FUSIONS:
            raise ValueError(
                f"fusion_method inconnu: {fusion_method!r} "
                f"(attendu: 'concat', 'add' ou 'attention')"
            )

        self.fusion_method = fusion_method  # 'concat', 'add', 'attention'

        # Pre-trained VGG16 shared by both branches
        vgg16 = models.vgg16(pretrained=True)

        self.features = vgg16.features
        self.avgpool = vgg16.avgpool

        # First six classifier layers (up to and including the second Dropout)
        classifier_layers = list(vgg16.classifier.children())[:6]
        self.fc7 = nn.Sequential(*classifier_layers)

        # Freeze everything registered so far (the VGG weights). The attention
        # head below is created afterwards and therefore stays trainable.
        for param in self.parameters():
            param.requires_grad = False

        if fusion_method == 'concat':
            self.fusion_dim = 4096 * 2
        elif fusion_method == 'add':
            self.fusion_dim = 4096
        else:  # 'attention'
            self.fusion_dim = 4096
            # Produces one softmax weight per branch from the concatenated features.
            self.attention = nn.Sequential(
                nn.Linear(4096 * 2, 4096),
                nn.ReLU(),
                nn.Linear(4096, 2),
                nn.Softmax(dim=1)
            )

        print(f"Dual VGG Feature Extractor:")
        print(f"  - Méthode fusion: {fusion_method}")
        print(f"  - Dimension sortie: {self.fusion_dim}")
        print(f"  - Branche 1: Image originale")
        print(f"  - Branche 2: Image masquée (bounding boxes)")

    def train(self, mode=True):
        """
        Set the training mode while keeping the frozen VGG layers in eval mode.

        ``fc7`` contains Dropout layers; leaving them active during training
        would add random noise to features that are meant to be fixed.
        """
        super().train(mode)
        for frozen_part in (self.features, self.avgpool, self.fc7):
            frozen_part.eval()
        return self

    def forward(self, original_img, masked_img):
        """Return the fused feature vector for a batch of image pairs."""
        features_orig = self._extract_features(original_img)
        features_masked = self._extract_features(masked_img)

        if self.fusion_method == 'concat':
            fused_features = torch.cat([features_orig, features_masked], dim=1)
        elif self.fusion_method == 'add':
            # Plain (unweighted) element-wise sum
            fused_features = features_orig + features_masked
        else:  # 'attention' (validated in __init__)
            combined = torch.cat([features_orig, features_masked], dim=1)
            attention_weights = self.attention(combined)  # [batch, 2]

            # Slices keep a trailing dim of 1 so the weights broadcast over features.
            weighted_orig = features_orig * attention_weights[:, 0:1]
            weighted_masked = features_masked * attention_weights[:, 1:2]
            fused_features = weighted_orig + weighted_masked

        return fused_features

    def _extract_features(self, x):
        """Run one image batch through the shared VGG stack up to ``fc7``."""
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc7(x)
        return x

class DualSpatialRelationMLP(nn.Module):
    """
    Three-layer MLP classifier on top of the fused dual features.

    Layout: Linear -> BatchNorm -> ReLU -> Dropout, twice, followed by a final
    Linear layer producing one logit per relation. Linear weights are
    initialised with Xavier-uniform.
    """

    def __init__(self, input_dim, hidden1_dim=512, hidden2_dim=256,
                 num_relations=len(SPATIAL_RELATIONS), dropout_rate=0.4):
        super().__init__()

        # One-shot architecture summary.
        print("\n".join((
            "Dual MLP Architecture:",
            f"  - Input: {input_dim} (features fusionnées)",
            f"  - Hidden 1: {hidden1_dim}",
            f"  - Hidden 2: {hidden2_dim}",
            f"  - Output: {num_relations}",
            f"  - Dropout: {dropout_rate}",
        )))

        # First hidden stage
        self.fc1 = nn.Linear(input_dim, hidden1_dim)
        self.bn1 = nn.BatchNorm1d(hidden1_dim)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout_rate)

        # Second hidden stage
        self.fc2 = nn.Linear(hidden1_dim, hidden2_dim)
        self.bn2 = nn.BatchNorm1d(hidden2_dim)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout_rate)

        # Output logits
        self.fc3 = nn.Linear(hidden2_dim, num_relations)

        # Xavier-uniform on every linear weight, in layer order.
        for linear in (self.fc1, self.fc2, self.fc3):
            nn.init.xavier_uniform_(linear.weight)

    def forward(self, x):
        """Map fused features to relation logits."""
        hidden = self.dropout1(self.relu1(self.bn1(self.fc1(x))))
        hidden = self.dropout2(self.relu2(self.bn2(self.fc2(hidden))))
        return self.fc3(hidden)

class DualInputSpatialRelationModel(nn.Module):
    """
    End-to-end dual-input model: dual VGG feature extractor + MLP classifier.
    """

    def __init__(self, num_relations=len(SPATIAL_RELATIONS), fusion_method='concat'):
        super().__init__()

        print("\n".join((
            "Modèle Dual Input Spatial Relation:",
            f"  - Relations: {num_relations}",
            f"  - Fusion: {fusion_method}",
        )))

        # Dual feature extractor; its fusion_dim sizes the classifier input.
        extractor = DualVGGFeatureExtractor(fusion_method=fusion_method)
        self.feature_extractor = extractor

        # MLP head producing one logit per relation
        self.classifier = DualSpatialRelationMLP(
            input_dim=extractor.fusion_dim,
            num_relations=num_relations
        )

    def forward(self, original_img, masked_img):
        """Return relation logits for a batch of (original, masked) image pairs."""
        return self.classifier(self.feature_extractor(original_img, masked_img))

# =============================================================================
# TRANSFORMATIONS
# =============================================================================

def create_dual_transforms():
    """
    Build the (train, validation) transforms used for both image views.

    Both pipelines resize to ``IMG_SIZE`` x ``IMG_SIZE``, convert to a tensor
    and normalise with the ImageNet mean/std. The training pipeline adds a
    light colour jitter.

    No horizontal flip is applied: mirroring an image turns a
    'to the left of' scene into a 'to the right of' scene (and vice versa)
    while the label stays unchanged, which would corrupt the supervision for
    those two classes.

    Returns:
        tuple: ``(train_transform, val_transform)``.
    """
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )

    train_transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        # Photometric augmentation only; it leaves the spatial layout intact.
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
        transforms.ToTensor(),
        normalize
    ])

    val_transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        normalize
    ])

    return train_transform, val_transform

# =============================================================================
# FONCTIONS D'ENTRAÎNEMENT DUALES
# =============================================================================

def train_dual_epoch(model, dataloader, criterion, optimizer, device):
    """
    Train the dual-input model for one epoch.

    Args:
        model: callable as ``model(original_imgs, masked_imgs)`` -> logits.
        dataloader: yields ``(original_imgs, masked_imgs, labels, metadata)``.
        criterion: loss function; assumed to return the mean loss over the
            batch (the default reduction of ``nn.CrossEntropyLoss``).
        optimizer: optimizer stepping the trainable parameters.
        device: device the batches are moved to.

    Returns:
        tuple: ``(mean loss per sample, accuracy in percent)``. Returns
        ``(0.0, 0.0)`` when the dataloader yields no sample.
    """
    model.train()
    loss_sum = 0.0  # sum of per-sample losses
    correct = 0
    total = 0

    progress_bar = tqdm(dataloader, desc='Training Dual')
    for original_imgs, masked_imgs, labels, _ in progress_bar:
        original_imgs = original_imgs.to(device)
        masked_imgs = masked_imgs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(original_imgs, masked_imgs)
        loss = criterion(outputs, labels)
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Weight the batch mean by the batch size so that a smaller last
        # batch does not get the same weight as a full one.
        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size
        predicted = outputs.detach().argmax(dim=1)
        total += batch_size
        correct += (predicted == labels).sum().item()

        progress_bar.set_postfix({
            'loss': loss_sum / total,
            'acc': 100. * correct / total
        })

    # Empty loader: avoid dividing by zero.
    if total == 0:
        return 0.0, 0.0

    return loss_sum / total, 100. * correct / total

def evaluate_dual(model, dataloader, criterion, device):
    """
    Evaluate the dual-input model without updating it.

    Args:
        model: callable as ``model(original_imgs, masked_imgs)`` -> logits.
        dataloader: yields ``(original_imgs, masked_imgs, labels, metadata)``.
        criterion: loss function; assumed to return the mean loss over the
            batch (the default reduction of ``nn.CrossEntropyLoss``).
        device: device the batches are moved to.

    Returns:
        tuple: ``(mean loss per sample, accuracy in percent)``. Returns
        ``(0.0, 0.0)`` when the dataloader yields no sample.
    """
    model.eval()
    loss_sum = 0.0  # sum of per-sample losses
    correct = 0
    total = 0

    with torch.no_grad():
        for original_imgs, masked_imgs, labels, _ in tqdm(dataloader, desc='Evaluating Dual'):
            original_imgs = original_imgs.to(device)
            masked_imgs = masked_imgs.to(device)
            labels = labels.to(device)

            outputs = model(original_imgs, masked_imgs)
            loss = criterion(outputs, labels)

            # Weight the batch mean by the batch size so that a smaller last
            # batch does not get the same weight as a full one.
            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            predicted = outputs.argmax(dim=1)
            total += batch_size
            correct += (predicted == labels).sum().item()

    # Empty loader: avoid dividing by zero.
    if total == 0:
        return 0.0, 0.0

    return loss_sum / total, 100. * correct / total

# =============================================================================
# VISUALISATION DUAL INPUT
# =============================================================================

def visualize_dual_input_samples(dataset, num_samples=4):
    """
    Show the first samples of a dual-input dataset, one row per sample.

    Each row displays the original image, the masked image and a text panel
    with the sample metadata. Images are expected to be tensors normalised
    with the ImageNet mean/std; they are de-normalised for display.

    Args:
        dataset: indexable returning
            ``(original_img, masked_img, label, metadata)``.
        num_samples: maximum number of rows; fewer are drawn when the dataset
            is smaller. Nothing is drawn when this yields zero rows.
    """
    # Size the figure by what will actually be drawn, so that a short dataset
    # does not leave empty rows and zero rows does not crash plt.subplots.
    n_rows = min(num_samples, len(dataset))
    if n_rows <= 0:
        print("Aucun échantillon à visualiser.")
        return

    # ImageNet statistics, shaped to broadcast over (C, H, W) tensors.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    def denormalize(tensor):
        """Undo the normalisation and move channels last for imshow."""
        return (tensor * std + mean).clamp(0, 1).permute(1, 2, 0)

    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(n_rows, 3, figsize=(15, 4*n_rows), squeeze=False)

    for i in range(n_rows):
        original_img, masked_img, label, metadata = dataset[i]

        original_display = denormalize(original_img)
        masked_display = denormalize(masked_img)

        # Original image
        axes[i, 0].imshow(original_display)
        axes[i, 0].set_title(f"Original\n{metadata['subject']} - {metadata['object']}")
        axes[i, 0].axis('off')

        # Masked image
        axes[i, 1].imshow(masked_display)
        axes[i, 1].set_title(f"Masquée (BBox seulement)\nRelation: {metadata['relation']}")
        axes[i, 1].axis('off')

        # Text panel
        axes[i, 2].axis('off')
        info_text = f"""
        Échantillon {i+1}:

        Sujet: {metadata['subject']}
        Objet: {metadata['object']}
        Relation: {metadata['relation']}
        Original: {metadata['original_relation']}

        Subject BBox: {metadata['subject_bbox']}
        Object BBox: {metadata['object_bbox']}

        Architecture:
        • Image complète
        • Image masquée (objets seulement)
        • Fusion des features VGG
        """
        axes[i, 2].text(0.1, 0.5, info_text, transform=axes[i, 2].transAxes,
                       fontsize=10, verticalalignment='center',
                       bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8))

    plt.suptitle('Architecture Duale: Image Originale + Image Masquée (BBox)',
                 fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()

# =============================================================================
# ENTRAÎNEMENT DUAL K-FOLD
# =============================================================================

def train_dual_kfold(data_dir, k_folds=5, epochs=15, fusion_method='concat'):
    """
    K-fold cross-validated training of the dual-input model.

    The 'train' split is loaded once and divided into ``k_folds`` folds. For
    every fold a fresh model is trained for ``epochs`` epochs; the weights of
    the epoch with the best validation accuracy are saved to
    ``best_dual_model_fold_{fold}_{fusion_method}.pth``.

    Args:
        data_dir: dataset directory passed to ``DualInputSpatialSenseDataset``.
        k_folds: number of folds.
        epochs: epochs per fold.
        fusion_method: fusion used by the feature extractor.

    Returns:
        list[dict]: one entry per fold with keys ``'fold'``,
        ``'train_losses'``, ``'val_losses'``, ``'train_accs'``,
        ``'val_accs'``, ``'best_val_acc'`` and ``'model'``. Empty list when
        the dataset has no sample.
    """
    import copy  # local import: only used to build the validation view below

    print("="*70)
    print("ENTRAÎNEMENT DUAL INPUT SPATIALSENSE+")
    print("="*70)
    print(f"Architecture: VGG Dual + MLP")
    print(f"Fusion: {fusion_method}")
    print(f"Entrée 1: Image originale")
    print(f"Entrée 2: Image masquée (bounding boxes)")

    # Transforms
    train_transform, val_transform = create_dual_transforms()

    # Dual dataset (training view, with augmentation)
    full_dataset = DualInputSpatialSenseDataset(
        data_dir=data_dir,
        split='train',
        transform=train_transform,
        mask_background_color=(128, 128, 128)  # grey
    )

    if len(full_dataset) == 0:
        print("Dataset vide!")
        return []

    # Validation view: same samples, but with the deterministic transform so
    # that validation metrics are not perturbed by random augmentation. A
    # shallow copy shares the sample list and avoids re-parsing annotations.
    val_dataset = copy.copy(full_dataset)
    val_dataset.transform = val_transform

    # Sample visualisation
    print("\nVisu échantillons dual input:")
    visualize_dual_input_samples(full_dataset, num_samples=3)

    # K-fold
    kfold = KFold(n_splits=k_folds, shuffle=True, random_state=42)
    fold_results = []

    for fold, (train_indices, val_indices) in enumerate(kfold.split(np.arange(len(full_dataset)))):
        print(f"\n{'='*50}")
        print(f"FOLD {fold + 1}/{k_folds} - DUAL INPUT")
        print(f"{'='*50}")

        # Subsets: augmented view for training, deterministic view for validation
        train_subset = torch.utils.data.Subset(full_dataset, train_indices)
        val_subset = torch.utils.data.Subset(val_dataset, val_indices)

        # DataLoaders. BatchNorm1d cannot run in training mode on a batch of
        # one sample, so the trailing batch is dropped only in that case.
        train_loader = DataLoader(
            train_subset, batch_size=BATCH_SIZE, shuffle=True,
            num_workers=2, pin_memory=True,
            drop_last=(len(train_subset) % BATCH_SIZE == 1)
        )

        val_loader = DataLoader(
            val_subset, batch_size=BATCH_SIZE, shuffle=False,
            num_workers=2, pin_memory=True
        )

        # Dual model
        model = DualInputSpatialRelationModel(
            num_relations=len(SPATIAL_RELATIONS),
            fusion_method=fusion_method
        )
        model = model.to(DEVICE)

        # Optimise every parameter that requires gradients. For 'concat' and
        # 'add' this is the MLP only (VGG is frozen); for 'attention' it also
        # includes the attention head, which would otherwise stay at its
        # random initialisation.
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        optimizer = optim.AdamW(trainable_params,
                               lr=LEARNING_RATE, weight_decay=0.01)
        criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

        # Scheduler
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

        # Training history
        train_losses, val_losses = [], []
        train_accs, val_accs = [], []
        best_val_acc = 0.0

        for epoch in range(epochs):
            print(f"\nEpoch {epoch+1}/{epochs}")

            # Training
            train_loss, train_acc = train_dual_epoch(
                model, train_loader, criterion, optimizer, DEVICE
            )
            train_losses.append(train_loss)
            train_accs.append(train_acc)

            # Validation
            val_loss, val_acc = evaluate_dual(model, val_loader, criterion, DEVICE)
            val_losses.append(val_loss)
            val_accs.append(val_acc)

            scheduler.step()

            print(f"Train: {train_loss:.4f} / {train_acc:.2f}%")
            print(f"Val: {val_loss:.4f} / {val_acc:.2f}%")

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save(model.state_dict(),
                          f'best_dual_model_fold_{fold+1}_{fusion_method}.pth')

        # Fold results.
        # NOTE(review): 'model' holds the weights of the LAST epoch (the best
        # ones are only in the .pth file) and keeps one full VGG per fold on
        # DEVICE - consider dropping it if memory becomes a problem.
        fold_results.append({
            'fold': fold + 1,
            'train_losses': train_losses,
            'val_losses': val_losses,
            'train_accs': train_accs,
            'val_accs': val_accs,
            'best_val_acc': best_val_acc,
            'model': model
        })

        print(f"Fold {fold + 1} - Meilleur: {best_val_acc:.2f}%")

    return fold_results

# =============================================================================
# FONCTION PRINCIPALE DUAL
# =============================================================================

def main_dual_experiment():
    """
    Run the dual-architecture experiment for every fusion method.

    For 'concat', 'add' and 'attention' in turn, runs a 3-fold, 10-epoch
    cross-validation and prints the mean of the best validation accuracies.
    Returns ``None``; stops early with a message when the data directory
    does not exist.
    """
    DATA_DIR = "data/spatialsense"

    wide_rule = "=" * 80
    section_rule = "=" * 60

    # Intro banner
    for banner_line in (
        wide_rule,
        "ARCHITECTURE DUALE - HALDEKAR + BOUNDING BOX MASKING",
        wide_rule,
        "Innovation:",
        "     Entrée 1: Image complète (comme Haldekar)",
        "     Entrée 2: Image masquée (seules les bounding boxes)",
        "     Fusion des features VGG",
        "     Focus sur les objets pertinents",
    ):
        print(banner_line)

    if not os.path.exists(DATA_DIR):
        print(f"\nERREUR: {DATA_DIR} n'existe pas!")
        return

    # Try each fusion strategy in turn
    for fusion_method in ('concat', 'add', 'attention'):
        print("\n" + section_rule)
        print("TEST FUSION: " + fusion_method.upper())
        print(section_rule)

        fold_results = train_dual_kfold(
            data_dir=DATA_DIR,
            k_folds=3,  # reduced for quick testing
            epochs=10,
            fusion_method=fusion_method
        )

        if not fold_results:
            continue

        mean_acc = np.mean([fold['best_val_acc'] for fold in fold_results])
        print(f"\n  {fusion_method}: {mean_acc:.2f}% (moyenne)")

    print("\n   Expérience architecture duale terminée!")

# Entry point: launch the dual-architecture experiment when this code is
# executed directly (as opposed to being imported as a module).
if __name__ == "__main__":
    main_dual_experiment()

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image, ImageDraw
import numpy as np
import json
import os
from sklearn.model_selection import KFold
from sklearn.metrics import confusion_matrix, classification_report
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import cv2
from collections import Counter
import warnings
warnings.filterwarnings('ignore')

# =============================================================================
# CONFIGURATION GLOBALE
# =============================================================================

# Use the GPU when PyTorch can see one, otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
print("Device utilisé: " + str(DEVICE))

# Hyperparameters
BATCH_SIZE = 8  # kept small because every sample carries two images
LEARNING_RATE = 0.001
EPOCHS = 15
K_FOLDS = 5
IMG_SIZE = 224
DROPOUT_RATE = 0.4

# SpatialSense+ spatial relations; the position in this list is the class index.
SPATIAL_RELATIONS = [
    'above',
    'behind',
    'in',
    'in front of',
    'next to',
    'on',
    'to the left of',
    'to the right of',
    'under',
]

# Short banner describing the two inputs of the binary-mask variant.
print("\n".join((
    "Architecture DUALE avec masque binaire bounding box (CONCAT SEULEMENT):",
    "  - Relations: " + str(len(SPATIAL_RELATIONS)),
    "  - Entrée 1: Image originale",
    "  - Entrée 2: Masque binaire des bounding boxes",
    "  - Fusion: Concaténation uniquement",
)))

# =============================================================================
# DATASET DUAL INPUT AVEC MASQUE BINAIRE
# =============================================================================

class DualInputBinaryMaskDataset(Dataset):
    """
    SpatialSense+ dataset yielding an image and a binary box mask per relation.

    - View 1: the original RGB image.
    - View 2: a three-channel mask of the same size, white (255) inside the
      subject and object bounding boxes and black (0) elsewhere.

    Each item is ``(original_image, binary_mask, label, metadata)`` where
    ``label`` is the index of the relation in ``SPATIAL_RELATIONS`` and
    ``metadata`` is a dict with the names, relation and boxes of the sample.

    Args:
        data_dir: directory containing ``annotations.json`` and the
            ``images/images/{flickr,nyu}`` folders.
        split: value compared against each image entry's ``'split'`` field.
        transform: optional callable applied to both views.
    """

    def __init__(self, data_dir, split='train', transform=None):
        self.data_dir = data_dir
        self.split = split
        self.transform = transform

        self.relation_to_idx = {rel: idx for idx, rel in enumerate(SPATIAL_RELATIONS)}
        self.idx_to_relation = {idx: rel for rel, idx in self.relation_to_idx.items()}

        self.annotations = self._load_annotations()
        self.data_samples = self._prepare_samples()

        print(f"Dataset DUAL Binary Mask SpatialSense+ {split}: {len(self.data_samples)} échantillons")
        if len(self.data_samples) > 0:
            self._print_statistics()

    def _load_annotations(self):
        """Return the parsed ``annotations.json``, or ``[]`` when the file is missing."""
        annotations_path = os.path.join(self.data_dir, 'annotations.json')
        try:
            with open(annotations_path, 'r') as f:
                return json.load(f)
        except FileNotFoundError:
            print(f"Erreur: {annotations_path} non trouvé!")
            return []

    def _find_image_path(self, image_url):
        """
        Map an annotation URL to a local file path.

        URLs containing ``"staticflickr"`` or whose file name splits into
        exactly two ``_``-separated parts are looked up under ``flickr/``;
        everything else under ``nyu/``.
        """
        base_dir = os.path.join(self.data_dir, "images", "images")
        filename = os.path.basename(image_url)

        if "staticflickr" in image_url or len(filename.split('_')) == 2:
            return os.path.join(base_dir, "flickr", filename)
        else:
            return os.path.join(base_dir, "nyu", filename)

    def _prepare_samples(self):
        """
        Build the flat list of samples for ``self.split``.

        Only annotations with a truthy ``'label'``, a predicate matching one of
        ``SPATIAL_RELATIONS`` (case-insensitive, surrounding whitespace
        ignored) and a ``'bbox'`` on both subject and object are kept. Images
        missing on disk are skipped and counted.
        """
        samples = []
        images_not_found = 0

        # Built once: lower-cased predicate -> canonical spelling.
        canonical_relation = {rel.lower(): rel for rel in SPATIAL_RELATIONS}

        for img_data in self.annotations:
            if img_data['split'] != self.split:
                continue

            img_path = self._find_image_path(img_data['url'])

            if not os.path.exists(img_path):
                images_not_found += 1
                continue

            for ann in img_data['annotations']:
                if not ann['label']:
                    continue

                relation = canonical_relation.get(ann['predicate'].lower().strip())
                if relation is None:
                    continue

                if 'bbox' not in ann['subject'] or 'bbox' not in ann['object']:
                    continue

                samples.append({
                    'image_path': img_path,
                    'subject': ann['subject']['name'],
                    'object': ann['object']['name'],
                    'relation': relation,
                    'original_relation': ann['predicate'],
                    'subject_bbox': ann['subject']['bbox'],  # [y1, y2, x1, x2]
                    'object_bbox': ann['object']['bbox'],   # [y1, y2, x1, x2]
                    'image_width': img_data['width'],
                    'image_height': img_data['height']
                })

        if images_not_found > 0:
            print(f"Images non trouvées: {images_not_found}")

        return samples

    def _create_binary_mask(self, image_size, subject_bbox, object_bbox):
        """
        Return an RGB mask with the two boxes in white on a black background.

        Args:
            image_size: ``(width, height)`` of the source image.
            subject_bbox, object_bbox: boxes as ``[y1, y2, x1, x2]``;
                coordinates are truncated to int and clamped to the image
                bounds. Empty or inverted boxes are ignored.
        """
        # Black canvas, indexed as (height, width)
        mask = np.zeros((image_size[1], image_size[0]), dtype=np.uint8)

        def apply_bbox(bbox):
            """Paint one box in white on the mask."""
            try:
                y1, y2, x1, x2 = bbox
                # Clamp so that slicing never leaves the mask.
                y1 = max(0, min(int(y1), mask.shape[0]))
                y2 = max(0, min(int(y2), mask.shape[0]))
                x1 = max(0, min(int(x1), mask.shape[1]))
                x2 = max(0, min(int(x2), mask.shape[1]))

                if y2 > y1 and x2 > x1:
                    mask[y1:y2, x1:x2] = 255
            except Exception as e:
                # Best effort: a malformed box leaves the mask untouched.
                print(f"Erreur application bbox {bbox}: {e}")

        apply_bbox(subject_bbox)
        apply_bbox(object_bbox)

        # A 2-D uint8 array is read as a greyscale ('L') image; the explicit
        # `mode` argument is deprecated in recent Pillow releases.
        mask_pil = Image.fromarray(mask)
        # Replicate the single channel three times to obtain an RGB image.
        mask_rgb = Image.merge('RGB', (mask_pil, mask_pil, mask_pil))

        return mask_rgb

    def _transform_pair(self, first_image, second_image):
        """
        Apply ``self.transform`` to both views with the same random draws.

        The torch RNG state is captured before the first call and restored
        before the second, so random augmentations pick identical parameters
        for the image and its mask and keep them spatially aligned. This
        assumes the transform draws its randomness from torch's global RNG,
        as torchvision transforms do.
        """
        if not self.transform:
            return first_image, second_image

        rng_state = torch.get_rng_state()
        first_image = self.transform(first_image)
        torch.set_rng_state(rng_state)
        second_image = self.transform(second_image)
        return first_image, second_image

    def _print_statistics(self):
        """Print the count and percentage of samples for every relation."""
        relation_counts = Counter([s['relation'] for s in self.data_samples])
        print(f"\nDistribution DUAL Binary Mask dans {self.split}:")
        total = len(self.data_samples)

        for relation in SPATIAL_RELATIONS:
            count = relation_counts.get(relation, 0)
            percentage = count / total * 100 if total > 0 else 0
            print(f"   {relation}: {count} ({percentage:.1f}%)")

    def __len__(self):
        return len(self.data_samples)

    def __getitem__(self, idx):
        """
        Return ``(original_image, binary_mask, label, metadata)`` for ``idx``.

        If loading or processing fails, the error is printed and a placeholder
        sample (flat images, label 0) is returned instead of raising, so a
        single bad file does not abort a whole epoch.
        """
        sample = self.data_samples[idx]

        try:
            # The context manager closes the file handle once pixels are loaded.
            with Image.open(sample['image_path']) as img_file:
                original_image = img_file.convert('RGB')

            binary_mask = self._create_binary_mask(
                original_image.size,  # (width, height)
                sample['subject_bbox'],
                sample['object_bbox']
            )

            # Same random augmentation for both views (see _transform_pair).
            original_image, binary_mask = self._transform_pair(original_image, binary_mask)

            label = self.relation_to_idx[sample['relation']]

            metadata = {
                'subject': sample['subject'],
                'object': sample['object'],
                'relation': sample['relation'],
                'original_relation': sample['original_relation'],
                'subject_bbox': sample['subject_bbox'],
                'object_bbox': sample['object_bbox']
            }

            return original_image, binary_mask, label, metadata

        except Exception as e:
            print(f"Erreur chargement {sample['image_path']}: {e}")
            # Placeholder images so that batching can continue.
            dummy_original = Image.new('RGB', (224, 224), color='gray')
            dummy_mask = Image.new('RGB', (224, 224), color='black')

            if self.transform:
                dummy_original, dummy_mask = self._transform_pair(dummy_original, dummy_mask)
            else:
                dummy_original = torch.zeros(3, 224, 224)
                dummy_mask = torch.zeros(3, 224, 224)

            # Keep the metadata relation consistent with the returned label.
            fallback_label = 0
            return dummy_original, dummy_mask, fallback_label, {
                'subject': 'error', 'object': 'error',
                'relation': self.idx_to_relation[fallback_label], 'original_relation': 'error',
                'subject_bbox': [0, 0, 0, 0], 'object_bbox': [0, 0, 0, 0]
            }

# =============================================================================
# ARCHITECTURE DUALE VGG (CONCAT SEULEMENT)
# =============================================================================

class DualVGGFeatureExtractor(nn.Module):
    """
    Frozen VGG16 feature extractor applied to two inputs, fused by concatenation.

    - Branch 1: original image
    - Branch 2: binary mask of the bounding boxes

    Both branches run through the very same VGG16 modules (shared weights).
    Every parameter is marked as non-trainable at construction time, and the
    output of ``forward`` has ``fusion_dim`` features per sample.
    """

    def __init__(self):
        super().__init__()

        # A single pretrained VGG16 serves both branches.
        backbone = models.vgg16(pretrained=True)

        # Reuse the convolutional trunk and the pooling layer unchanged.
        self.features = backbone.features
        self.avgpool = backbone.avgpool

        # Keep the first six classifier modules, i.e. everything before the
        # last classifier layer.
        # NOTE(review): in torchvision's VGG this slice presumably still
        # contains Dropout modules, which would be active whenever the model
        # is in train() mode even though the weights are frozen — confirm.
        self.fc7 = nn.Sequential(*list(backbone.classifier.children())[:6])

        # Freeze every VGG weight registered so far.
        for frozen in self.parameters():
            frozen.requires_grad = False

        # Two 4096-wide branch outputs are concatenated.
        self.fusion_dim = 4096 * 2

        print(f"Dual VGG Feature Extractor (CONCAT uniquement):")
        print(f"  - Dimension sortie: {self.fusion_dim}")
        print(f"  - Branche 1: Image originale")
        print(f"  - Branche 2: Masque binaire (bounding boxes)")

    def forward(self, original_img, binary_mask):
        """Return the image features followed by the mask features (dim=1)."""
        # The list is evaluated left to right: image branch first, then mask.
        return torch.cat(
            [self._extract_features(original_img), self._extract_features(binary_mask)],
            dim=1,
        )

    def _extract_features(self, x):
        """Run one input through features -> avgpool -> flatten -> fc7."""
        pooled = self.avgpool(self.features(x))
        return self.fc7(torch.flatten(pooled, 1))

class DualSpatialRelationMLP(nn.Module):
    """
    Three-layer MLP classifier operating on the concatenated dual features.

    Layout: Linear -> BatchNorm1d -> ReLU -> Dropout (twice), followed by a
    final Linear layer producing one logit per relation.

    Args:
        input_dim: size of the incoming feature vector.
        hidden1_dim: width of the first hidden layer.
        hidden2_dim: width of the second hidden layer.
        num_relations: number of output classes.
        dropout_rate: dropout probability used after each hidden layer.
    """

    def __init__(self, input_dim, hidden1_dim=512, hidden2_dim=256,
                 num_relations=len(SPATIAL_RELATIONS), dropout_rate=0.4):
        super().__init__()

        print(f"Dual MLP Architecture:")
        print(f"  - Input: {input_dim} (features concaténées)")
        print(f"  - Hidden 1: {hidden1_dim}")
        print(f"  - Hidden 2: {hidden2_dim}")
        print(f"  - Output: {num_relations}")
        print(f"  - Dropout: {dropout_rate}")

        # First hidden stage
        self.fc1 = nn.Linear(input_dim, hidden1_dim)
        self.bn1 = nn.BatchNorm1d(hidden1_dim)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout_rate)

        # Second hidden stage
        self.fc2 = nn.Linear(hidden1_dim, hidden2_dim)
        self.bn2 = nn.BatchNorm1d(hidden2_dim)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout_rate)

        # Output logits
        self.fc3 = nn.Linear(hidden2_dim, num_relations)

        # Xavier-uniform initialisation of the three weight matrices
        # (biases keep the nn.Linear default initialisation).
        for linear in (self.fc1, self.fc2, self.fc3):
            nn.init.xavier_uniform_(linear.weight)

    def forward(self, x):
        """Map a batch of fused features to relation logits."""
        hidden = self.dropout1(self.relu1(self.bn1(self.fc1(x))))
        hidden = self.dropout2(self.relu2(self.bn2(self.fc2(hidden))))
        return self.fc3(hidden)

class DualInputSpatialRelationModel(nn.Module):
    """
    End-to-end dual-input model: dual VGG feature extractor + MLP classifier.

    Args:
        num_relations: number of relation classes predicted by the classifier.
    """

    def __init__(self, num_relations=len(SPATIAL_RELATIONS)):
        super().__init__()

        print(f"Modèle Dual Input Spatial Relation (CONCAT):")
        print(f"  - Relations: {num_relations}")
        print(f"  - Fusion: Concaténation")

        # Feature extractor shared by the image and the mask branch.
        extractor = DualVGGFeatureExtractor()
        self.feature_extractor = extractor

        # Classifier head sized from the extractor's fused output width.
        self.classifier = DualSpatialRelationMLP(
            input_dim=extractor.fusion_dim,
            num_relations=num_relations
        )

    def forward(self, original_img, binary_mask):
        """Return the relation logits for a batch of (image, mask) pairs."""
        return self.classifier(self.feature_extractor(original_img, binary_mask))

# =============================================================================
# FONCTIONS D'ÉVALUATION AVEC MATRICE DE CONFUSION
# =============================================================================

def evaluate_dual_with_confusion_matrix(model, dataloader, criterion, device, class_names=None):
    """Evaluate the dual-input model and return metrics plus a confusion matrix.

    Args:
        model: module called as ``model(original_imgs, binary_masks)``.
        dataloader: iterable yielding ``(original_imgs, binary_masks, labels, metadata)``.
        criterion: loss called as ``criterion(outputs, labels)``.
        device: device the tensors are moved to before the forward pass.
        class_names: display names, one per class index; defaults to
            ``SPATIAL_RELATIONS``.

    Returns:
        Tuple ``(avg_loss, accuracy, cm, report, all_predictions, all_labels)``:
        mean per-batch loss, accuracy in percent, confusion matrix with one
        row/column per entry of ``class_names``, the ``classification_report``
        dictionary, and the flat lists of predicted / true class indices.
    """
    if class_names is None:
        class_names = SPATIAL_RELATIONS
    # Explicit label set: without it scikit-learn infers the classes from the
    # data, so a fold that lacks one class yields a smaller confusion matrix
    # and makes classification_report raise (target_names length mismatch).
    label_ids = list(range(len(class_names)))

    model.eval()
    running_loss = 0.0
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for original_imgs, binary_masks, labels, _ in tqdm(dataloader, desc='Evaluating Dual'):
            original_imgs = original_imgs.to(device)
            binary_masks = binary_masks.to(device)
            labels = labels.to(device)

            outputs = model(original_imgs, binary_masks)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)

            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Metrics
    accuracy = 100. * np.mean(np.array(all_predictions) == np.array(all_labels))
    avg_loss = running_loss / len(dataloader)

    # Confusion matrix, always len(class_names) x len(class_names)
    cm = confusion_matrix(all_labels, all_predictions, labels=label_ids)

    # Per-class report
    report = classification_report(
        all_labels, all_predictions,
        labels=label_ids,
        target_names=class_names,
        output_dict=True,
        zero_division=0
    )
    # Some scikit-learn versions emit 'micro avg' instead of 'accuracy' when
    # `labels` is a strict superset of the observed classes; keep the key
    # that downstream display code reads.
    report.setdefault('accuracy', accuracy / 100.)

    return avg_loss, accuracy, cm, report, all_predictions, all_labels

def plot_confusion_matrix(cm, class_names, title="Matrice de Confusion", normalize=False):
    """Draw a confusion matrix as an annotated heatmap.

    Args:
        cm: square confusion matrix (rows = true labels, columns = predictions).
        class_names: tick labels, in class-index order.
        title: figure title; " (Normalisée)" is appended when ``normalize`` is set.
        normalize: when True each row is divided by its total so cells show
            proportions; rows with no true samples are shown as zeros instead
            of NaN.
    """
    if normalize:
        cm = np.asarray(cm, dtype=float)
        row_totals = cm.sum(axis=1, keepdims=True)
        # Safe division: a class absent from the true labels has a zero row
        # total, which would otherwise produce NaN cells and a RuntimeWarning.
        cm = np.divide(cm, row_totals, out=np.zeros_like(cm), where=row_totals != 0)
        fmt = '.2f'
        title += " (Normalisée)"
    else:
        fmt = 'd'

    plt.figure(figsize=(12, 10))
    sns.heatmap(cm,
                annot=True,
                fmt=fmt,
                cmap='Blues',
                xticklabels=class_names,
                yticklabels=class_names,
                cbar_kws={'label': 'Proportion' if normalize else 'Nombre de prédictions'})

    plt.title(title, fontsize=16, fontweight='bold', pad=20)
    plt.xlabel('Prédictions', fontsize=14, fontweight='bold')
    plt.ylabel('Vraies étiquettes', fontsize=14, fontweight='bold')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()

def display_classification_metrics(report, title="Métriques de Classification"):
    """Print a classification report dictionary as a fixed-width table.

    Args:
        report: dictionary in the layout produced by
            ``classification_report(..., output_dict=True)``: one entry per
            class (a dict with 'precision', 'recall', 'f1-score', 'support')
            plus the summary entries 'macro avg' and 'weighted avg', and
            usually 'accuracy'.
        title: heading printed above the table.

    Class rows are taken from the report itself, in its own order, so the
    table works for any set of class names. The accuracy line falls back to
    the 'micro avg' F1 score when the report has no 'accuracy' entry and is
    omitted when neither is present.
    """
    summary_keys = ('accuracy', 'micro avg', 'macro avg', 'weighted avg', 'samples avg')

    print(f"\n{'='*60}")
    print(f"{title}")
    print(f"{'='*60}")

    # Per-class metrics
    print(f"{'Classe':<15} {'Précision':<10} {'Rappel':<10} {'F1-Score':<10} {'Support':<10}")
    print("-" * 60)

    for class_name, metrics in report.items():
        # Skip the aggregate entries; 'accuracy' is a bare float, not a dict.
        if class_name in summary_keys or not isinstance(metrics, dict):
            continue
        print(f"{class_name:<15} {metrics['precision']:<10.3f} {metrics['recall']:<10.3f} "
              f"{metrics['f1-score']:<10.3f} {metrics['support']:<10.0f}")

    # Global metrics
    print("-" * 60)
    overall_accuracy = report.get('accuracy')
    if overall_accuracy is None and 'micro avg' in report:
        overall_accuracy = report['micro avg']['f1-score']
    if overall_accuracy is not None:
        print(f"{'Accuracy':<15} {'':<10} {'':<10} {overall_accuracy:<10.3f} {report['macro avg']['support']:<10.0f}")
    print(f"{'Macro avg':<15} {report['macro avg']['precision']:<10.3f} {report['macro avg']['recall']:<10.3f} "
          f"{report['macro avg']['f1-score']:<10.3f} {report['macro avg']['support']:<10.0f}")
    print(f"{'Weighted avg':<15} {report['weighted avg']['precision']:<10.3f} {report['weighted avg']['recall']:<10.3f} "
          f"{report['weighted avg']['f1-score']:<10.3f} {report['weighted avg']['support']:<10.0f}")

# =============================================================================
# VISUALISATION DUAL INPUT AVEC MASQUE BINAIRE
# =============================================================================

def visualize_dual_binary_samples(dataset, num_samples=4):
    """Show the first samples of a dual dataset: image, mask and a text panel.

    Args:
        dataset: indexable dataset whose items are
            ``(original_img, binary_mask, label, metadata)``; both images are
            expected to be normalised tensors (assumes the ImageNet
            mean/std used below — confirm against the dataset's transform).
        num_samples: maximum number of rows to draw.

    The grid has exactly ``min(num_samples, len(dataset))`` rows, so no empty
    axes are left when the dataset is smaller than ``num_samples``. Nothing
    is drawn when there is no sample to show.
    """
    n_rows = min(num_samples, len(dataset))
    if n_rows <= 0:
        return

    # Statistics of the Normalize step, shaped to broadcast over (3, H, W).
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    def denormalize(tensor):
        """Undo the normalisation and move channels last for imshow."""
        return (tensor * std + mean).clamp(0, 1).permute(1, 2, 0)

    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(n_rows, 3, figsize=(15, 4*n_rows), squeeze=False)

    for i in range(n_rows):
        original_img, binary_mask, label, metadata = dataset[i]

        original_display = denormalize(original_img)
        mask_display = denormalize(binary_mask)

        # Original image
        axes[i, 0].imshow(original_display)
        axes[i, 0].set_title(f"Original\n{metadata['subject']} - {metadata['object']}")
        axes[i, 0].axis('off')

        # Binary mask
        axes[i, 1].imshow(mask_display, cmap='gray')
        axes[i, 1].set_title(f"Masque Binaire\nRelation: {metadata['relation']}")
        axes[i, 1].axis('off')

        # Text panel
        axes[i, 2].axis('off')
        info_text = f"""
        Échantillon {i+1}:

        Sujet: {metadata['subject']}
        Objet: {metadata['object']}
        Relation: {metadata['relation']}
        Original: {metadata['original_relation']}

        Subject BBox: {metadata['subject_bbox']}
        Object BBox: {metadata['object_bbox']}

        Architecture:
        • Image complète (RGB)
        • Masque binaire (Blanc=objets, Noir=fond)
        • Concaténation des features VGG
        """
        axes[i, 2].text(0.1, 0.5, info_text, transform=axes[i, 2].transAxes,
                       fontsize=10, verticalalignment='center',
                       bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8))

    plt.suptitle('Architecture Duale: Image Originale + Masque Binaire (CONCAT)',
                 fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()

# =============================================================================
# TRANSFORMATIONS
# =============================================================================

def create_dual_transforms():
    """Build the preprocessing pipelines for the dual (image + mask) inputs.

    Returns:
        Tuple ``(train_transform, val_transform)``. Both resize to
        ``IMG_SIZE`` x ``IMG_SIZE``, convert to a tensor and normalise with
        the mean/std given below; the training pipeline additionally applies
        a mild colour jitter.

    The training pipeline deliberately contains no random horizontal flip:
    - mirroring an image swaps 'to the left of' and 'to the right of' while
      the label stays unchanged, so the augmentation is not label-preserving
      for this task;
    - the dataset applies the pipeline to the image and to the mask in two
      separate calls, so a random flip could mirror one and not the other.
    """
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )

    train_transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        # Photometric augmentation only; geometry must stay untouched (see docstring).
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
        transforms.ToTensor(),
        normalize
    ])

    val_transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        normalize
    ])

    return train_transform, val_transform

# =============================================================================
# FONCTIONS D'ENTRAÎNEMENT DUALES
# =============================================================================

def train_dual_epoch(model, dataloader, criterion, optimizer, device):
    """Run one training epoch of the dual-input model.

    Args:
        model: module called as ``model(original_imgs, binary_masks)``.
        dataloader: iterable yielding ``(original_imgs, binary_masks, labels, metadata)``.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer stepped once per batch.
        device: device the batch tensors are moved to.

    Returns:
        Tuple ``(mean per-batch loss, accuracy in percent)`` over the epoch.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    progress_bar = tqdm(dataloader, desc='Training Dual')
    for step, batch in enumerate(progress_bar, start=1):
        original_imgs, binary_masks, labels, _ = batch
        original_imgs = original_imgs.to(device)
        binary_masks = binary_masks.to(device)
        labels = labels.to(device)

        # Forward / backward pass
        optimizer.zero_grad()
        outputs = model(original_imgs, binary_masks)
        loss = criterion(outputs, labels)
        loss.backward()

        # Clip the global gradient norm to 1.0 before updating the weights
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Running statistics
        loss_sum += loss.item()
        predicted = torch.max(outputs.data, 1)[1]
        n_seen += labels.size(0)
        n_correct += (predicted == labels).sum().item()

        progress_bar.set_postfix({
            'loss': loss_sum / step,
            'acc': 100. * n_correct / n_seen
        })

    epoch_loss = loss_sum / len(dataloader)
    epoch_acc = 100. * n_correct / n_seen
    return epoch_loss, epoch_acc

# =============================================================================
# ENTRAÎNEMENT DUAL K-FOLD AVEC MATRICE DE CONFUSION (CONCAT SEULEMENT)
# =============================================================================

def train_dual_kfold_with_confusion_matrix(data_dir, k_folds=5, epochs=15):
    """K-fold cross-validation of the dual-input (image + mask) model.

    For every fold a fresh ``DualInputSpatialRelationModel`` is trained for
    ``epochs`` epochs (only the MLP classifier is optimised), the weights of
    the best validation epoch are saved to
    ``best_dual_binary_mask_model_fold_<k>_concat.pth`` and per-fold metrics
    and confusion matrices are displayed. A global report over all folds is
    displayed at the end.

    Args:
        data_dir: dataset root passed to ``DualInputBinaryMaskDataset``.
        k_folds: number of folds.
        epochs: training epochs per fold.

    Returns:
        List with one dict per fold (keys: 'fold', 'train_losses',
        'val_losses', 'train_accs', 'val_accs', 'best_val_acc',
        'final_confusion_matrix', 'final_report', 'final_predictions',
        'final_true_labels', 'model'), or ``[]`` when the dataset is empty.

    Notes:
        - Validation samples are read through the deterministic validation
          transform, not the training augmentation.
        - The "final" per-fold evaluation uses the last-epoch weights, which
          may differ from the epoch that produced ``best_val_acc``.
    """
    import copy  # stdlib; local import keeps this block self-contained

    print("="*70)
    print("ENTRAÎNEMENT DUAL INPUT AVEC MASQUE BINAIRE (CONCAT UNIQUEMENT)")
    print("="*70)
    print(f"Architecture: VGG Dual + MLP")
    print(f"Fusion: Concaténation des features")
    print(f"Entrée 1: Image originale")
    print(f"Entrée 2: Masque binaire des bounding boxes")

    # Transforms
    train_transform, val_transform = create_dual_transforms()

    # Dual dataset with binary mask (training view)
    full_dataset = DualInputBinaryMaskDataset(
        data_dir=data_dir,
        split='train',
        transform=train_transform
    )

    if len(full_dataset) == 0:
        print("Dataset vide!")
        return []

    # Validation view of the same samples: a shallow copy shares the sample
    # list (identical indexing, no second parse) and only swaps the
    # `transform` attribute read by the dataset's __getitem__, so validation
    # metrics are not computed on randomly augmented inputs.
    eval_dataset = copy.copy(full_dataset)
    eval_dataset.transform = val_transform

    # Sample visualisation
    print("\nVisu échantillons dual input avec masque binaire:")
    visualize_dual_binary_samples(full_dataset, num_samples=3)

    # Fixed label set so every report / matrix covers all relations.
    label_ids = list(range(len(SPATIAL_RELATIONS)))

    # K-fold
    kfold = KFold(n_splits=k_folds, shuffle=True, random_state=42)
    fold_results = []

    for fold, (train_indices, val_indices) in enumerate(kfold.split(np.arange(len(full_dataset)))):
        print(f"\n{'='*50}")
        print(f"FOLD {fold + 1}/{k_folds} - DUAL INPUT BINARY MASK (CONCAT)")
        print(f"{'='*50}")

        # Subsets: augmented view for training, deterministic view for validation
        train_subset = torch.utils.data.Subset(full_dataset, train_indices)
        val_subset = torch.utils.data.Subset(eval_dataset, val_indices)

        # BatchNorm1d cannot run in training mode on a batch of one sample;
        # drop the trailing batch only when it would contain a single item.
        drop_singleton_batch = len(train_indices) % BATCH_SIZE == 1

        # DataLoaders
        train_loader = DataLoader(
            train_subset, batch_size=BATCH_SIZE, shuffle=True,
            num_workers=2, pin_memory=True, drop_last=drop_singleton_batch
        )

        val_loader = DataLoader(
            val_subset, batch_size=BATCH_SIZE, shuffle=False,
            num_workers=2, pin_memory=True
        )

        # Dual model with concatenation
        model = DualInputSpatialRelationModel(num_relations=len(SPATIAL_RELATIONS))
        model = model.to(DEVICE)

        # Optimiser (only the MLP classifier is trainable)
        optimizer = optim.AdamW(model.classifier.parameters(),
                               lr=LEARNING_RATE, weight_decay=0.01)
        criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

        # Scheduler
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

        # Training history
        train_losses, val_losses = [], []
        train_accs, val_accs = [], []
        best_val_acc = 0.0

        for epoch in range(epochs):
            print(f"\nEpoch {epoch+1}/{epochs}")

            # Training
            train_loss, train_acc = train_dual_epoch(
                model, train_loader, criterion, optimizer, DEVICE
            )
            train_losses.append(train_loss)
            train_accs.append(train_acc)

            # Validation with confusion matrix
            val_loss, val_acc, cm, report, predictions, true_labels = evaluate_dual_with_confusion_matrix(
                model, val_loader, criterion, DEVICE, SPATIAL_RELATIONS
            )
            val_losses.append(val_loss)
            val_accs.append(val_acc)

            scheduler.step()

            print(f"Train: {train_loss:.4f} / {train_acc:.2f}%")
            print(f"Val: {val_loss:.4f} / {val_acc:.2f}%")

            # Checkpoint the best validation epoch
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save(model.state_dict(),
                          f'best_dual_binary_mask_model_fold_{fold+1}_concat.pth')

        # Final evaluation with confusion matrix (last-epoch weights)
        print(f"\n{'='*40}")
        print(f"ÉVALUATION FINALE FOLD {fold + 1}")
        print(f"{'='*40}")

        final_val_loss, final_val_acc, final_cm, final_report, final_predictions, final_true_labels = evaluate_dual_with_confusion_matrix(
            model, val_loader, criterion, DEVICE, SPATIAL_RELATIONS
        )

        # Metrics table
        display_classification_metrics(final_report, f"Métriques Fold {fold + 1}")

        # Confusion matrices
        plot_confusion_matrix(final_cm, SPATIAL_RELATIONS,
                            f"Matrice de Confusion - Fold {fold + 1}", normalize=False)
        plot_confusion_matrix(final_cm, SPATIAL_RELATIONS,
                            f"Matrice de Confusion Normalisée - Fold {fold + 1}", normalize=True)

        # Fold results
        fold_results.append({
            'fold': fold + 1,
            'train_losses': train_losses,
            'val_losses': val_losses,
            'train_accs': train_accs,
            'val_accs': val_accs,
            'best_val_acc': best_val_acc,
            'final_confusion_matrix': final_cm,
            'final_report': final_report,
            'final_predictions': final_predictions,
            'final_true_labels': final_true_labels,
            'model': model
        })

        print(f"Fold {fold + 1} - Meilleur: {best_val_acc:.2f}%")

    # Global analysis of the results
    print(f"\n{'='*70}")
    print("ANALYSE GLOBALE DES RÉSULTATS")
    print(f"{'='*70}")

    # Mean / spread of the best validation accuracies
    mean_acc = np.mean([r['best_val_acc'] for r in fold_results])
    std_acc = np.std([r['best_val_acc'] for r in fold_results])
    print(f"Accuracy moyenne: {mean_acc:.2f}% ± {std_acc:.2f}%")

    # Global confusion matrix, rebuilt from the pooled predictions with the
    # full label set. This equals the sum of the per-fold matrices whenever
    # those share one shape and stays well-defined when they do not.
    global_predictions = np.concatenate([r['final_predictions'] for r in fold_results])
    global_true_labels = np.concatenate([r['final_true_labels'] for r in fold_results])
    global_cm = confusion_matrix(global_true_labels, global_predictions, labels=label_ids)

    # Global report; `labels` keeps it aligned with target_names even if a
    # relation never occurs in the pooled labels.
    global_report = classification_report(
        global_true_labels, global_predictions,
        labels=label_ids,
        target_names=SPATIAL_RELATIONS,
        output_dict=True,
        zero_division=0
    )
    # Some scikit-learn versions emit 'micro avg' instead of 'accuracy' when
    # `labels` is a strict superset of the observed classes.
    global_report.setdefault('accuracy', float(np.mean(global_predictions == global_true_labels)))

    # Final display
    display_classification_metrics(global_report, "Métriques Globales (tous les folds)")
    plot_confusion_matrix(global_cm, SPATIAL_RELATIONS,
                        "Matrice de Confusion Globale", normalize=False)
    plot_confusion_matrix(global_cm, SPATIAL_RELATIONS,
                        "Matrice de Confusion Globale Normalisée", normalize=True)

    return fold_results

# =============================================================================
# ANALYSE DES ERREURS
# =============================================================================

def analyze_prediction_errors(fold_results, top_k=5):
    """Print the most frequent (true relation -> predicted relation) confusions.

    Args:
        fold_results: list of per-fold dicts holding 'final_predictions' and
            'final_true_labels' (class indices into ``SPATIAL_RELATIONS``).
        top_k: number of confusion pairs to list.
    """
    from collections import Counter

    print(f"\n{'='*60}")
    print("ANALYSE DES ERREURS DE PRÉDICTION")
    print(f"{'='*60}")

    # Pool the predictions and ground truth of every fold
    pooled_predictions = np.concatenate([fold['final_predictions'] for fold in fold_results])
    pooled_truth = np.concatenate([fold['final_true_labels'] for fold in fold_results])

    # One (true, predicted) name pair per misclassified sample
    confusions = [
        (SPATIAL_RELATIONS[expected], SPATIAL_RELATIONS[guessed])
        for expected, guessed in zip(pooled_truth, pooled_predictions)
        if expected != guessed
    ]
    n_errors = len(confusions)

    # Frequency of each confusion pair
    tally = Counter(confusions)

    print(f"Nombre total d'erreurs: {n_errors}")
    print(f"Accuracy globale: {100 * (1 - n_errors / len(pooled_predictions)):.2f}%")
    print(f"\nTop {top_k} erreurs les plus fréquentes:")
    print("-" * 60)

    for rank, ((true_rel, pred_rel), count) in enumerate(tally.most_common(top_k), start=1):
        share = 100 * count / n_errors
        print(f"{rank:2d}. {true_rel:>15} → {pred_rel:<15} : {count:3d} ({share:5.1f}%)")

def plot_learning_curves(fold_results):
    """Plot train/validation loss and accuracy histories of every fold.

    Args:
        fold_results: list of per-fold dicts holding 'fold', 'train_losses',
            'val_losses', 'train_accs' and 'val_accs'.

    Draws a 2x2 grid (losses on the top row, accuracies on the bottom row,
    training on the left, validation on the right) with one line per fold.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # (axis, history key, panel title, y-axis label) for the four panels
    panels = [
        (axes[0, 0], 'train_losses', 'Courbes de Loss - Training', 'Loss'),
        (axes[0, 1], 'val_losses', 'Courbes de Loss - Validation', 'Loss'),
        (axes[1, 0], 'train_accs', 'Courbes d\'Accuracy - Training', 'Accuracy (%)'),
        (axes[1, 1], 'val_accs', 'Courbes d\'Accuracy - Validation', 'Accuracy (%)'),
    ]

    for ax, _, panel_title, _ in panels:
        ax.set_title(panel_title, fontweight='bold')

    # Colours cycle when there are more folds than entries
    palette = ['blue', 'red', 'green', 'orange', 'purple']

    for position, result in enumerate(fold_results):
        x_values = range(1, len(result['train_losses']) + 1)
        line_color = palette[position % len(palette)]
        fold_label = f'Fold {result["fold"]}'

        for ax, history_key, _, _ in panels:
            ax.plot(x_values, result[history_key],
                    color=line_color, label=fold_label, alpha=0.7)

    # Axis decoration
    for ax in axes.flat:
        ax.set_xlabel('Époque')
        ax.legend()
        ax.grid(True, alpha=0.3)

    for ax, _, _, y_label in panels:
        ax.set_ylabel(y_label)

    plt.suptitle('Courbes d\'Apprentissage - Architecture Duale avec Concaténation',
                 fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()

# =============================================================================
# FONCTION PRINCIPALE DUAL AVEC CONCAT SEULEMENT
# =============================================================================

def main_dual_binary_mask_concat_only():
    """Run the full dual-architecture experiment (concatenation fusion only).

    Trains with ``K_FOLDS`` folds and ``EPOCHS`` epochs, then prints the
    per-fold scores, the error analysis, the learning curves and the
    per-relation accuracy.

    Returns:
        The list of fold results, or ``None`` when the data directory is
        missing or training produced no result.
    """
    DATA_DIR = "data/spatialsense"

    print("="*80)
    print("ARCHITECTURE DUALE - MASQUE BINAIRE (CONCATÉNATION UNIQUEMENT)")
    print("="*80)
    print("Configuration:")
    print("     Entrée 1: Image complète (RGB)")
    print("     Entrée 2: Masque binaire des bounding boxes (Blanc=objets, Noir=fond)")
    print("     Fusion: Concaténation des features VGG (4096 + 4096 = 8192)")
    print("     Matrice de confusion détaillée")
    print("     Analyse des erreurs de prédiction")
    print("     Courbes d'apprentissage")

    if not os.path.exists(DATA_DIR):
        print(f"\nERREUR: {DATA_DIR} n'existe pas!")
        return

    # Training with concatenation only
    print(f"\n{'='*60}")
    print("ENTRAÎNEMENT AVEC CONCATÉNATION")
    print(f"{'='*60}")

    results = train_dual_kfold_with_confusion_matrix(
        data_dir=DATA_DIR,
        k_folds=K_FOLDS,
        epochs=EPOCHS
    )

    # Nothing to report when training returned no fold
    if not results:
        print("    Aucun résultat obtenu!")
        return None

    best_scores = [fold_result['best_val_acc'] for fold_result in results]
    mean_acc = np.mean(best_scores)
    std_acc = np.std(best_scores)

    print(f"\n  Résultats finaux avec concaténation:")
    print(f"   Accuracy moyenne: {mean_acc:.2f}% ± {std_acc:.2f}%")
    print(f"   Meilleurs résultats par fold:")

    for fold_result in results:
        print(f"     Fold {fold_result['fold']}: {fold_result['best_val_acc']:.2f}%")

    # Error analysis
    analyze_prediction_errors(results)

    # Learning curves
    plot_learning_curves(results)

    # Detailed statistics per relation
    print(f"\n{'='*60}")
    print("PERFORMANCE PAR RELATION SPATIALE")
    print(f"{'='*60}")

    # Pool predictions and ground truth over all folds
    global_predictions = np.concatenate([fold_result['final_predictions'] for fold_result in results])
    global_true_labels = np.concatenate([fold_result['final_true_labels'] for fold_result in results])

    # Per-class accuracy, skipping relations without any true sample
    for class_idx, relation in enumerate(SPATIAL_RELATIONS):
        mask = global_true_labels == class_idx
        support = np.sum(mask)
        if support > 0:
            class_acc = 100 * np.mean(global_predictions[mask] == global_true_labels[mask])
            print(f"  {relation:<15}: {class_acc:6.2f}% ({support:3d} échantillons)")

    print(f"\n   Expérience architecture duale avec concaténation terminée!")
    print(f"    Performance globale: {mean_acc:.2f}% ± {std_acc:.2f}%")

    return results

# =============================================================================
# FONCTION DE TEST RAPIDE
# =============================================================================

def quick_test_concat():
    """Smoke-test the dual pipeline with 3 folds and 5 epochs.

    Returns:
        The list of fold results, or ``None`` when the data directory is
        missing or training produced no result.
    """
    DATA_DIR = "data/spatialsense"

    print("="*60)
    print("TEST RAPIDE - ARCHITECTURE DUALE CONCAT")
    print("="*60)

    if not os.path.exists(DATA_DIR):
        print(f"ERREUR: {DATA_DIR} n'existe pas!")
        return

    # Reduced settings: fewer folds and fewer epochs than the full run
    results = train_dual_kfold_with_confusion_matrix(
        data_dir=DATA_DIR,
        k_folds=3,
        epochs=5
    )

    if not results:
        print("    Échec du test rapide!")
        return None

    mean_acc = np.mean([fold_result['best_val_acc'] for fold_result in results])
    print(f"\n     Test rapide terminé!")
    print(f"   Accuracy moyenne: {mean_acc:.2f}%")

    return results

if __name__ == "__main__":
    # Entry point: a first command-line argument equal to "quick" selects the
    # reduced smoke test; anything else (or no argument) runs the full experiment.
    import sys

    if sys.argv[1:2] == ["quick"]:
        print("Mode test rapide activé...")
        quick_test_concat()
    else:
        print("Mode expérience complète activé...")
        main_dual_binary_mask_concat_only()

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import json
import os
from sklearn.model_selection import KFold
from sklearn.metrics import confusion_matrix, classification_report
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from collections import Counter
import warnings
warnings.filterwarnings('ignore')

# =============================================================================
# CONFIGURATION GLOBALE
# =============================================================================

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Device utilisé: {DEVICE}")

# Hyperparameters
BATCH_SIZE = 16        # raised because this variant takes a single image as input
LEARNING_RATE = 1e-3
EPOCHS = 15
K_FOLDS = 5
IMG_SIZE = 224         # side length, in pixels, of the square network input
DROPOUT_RATE = 0.4

# SpatialSense+ spatial relations; the list position is the class index.
SPATIAL_RELATIONS = [
    'above',
    'behind',
    'in',
    'in front of',
    'next to',
    'on',
    'to the left of',
    'to the right of',
    'under',
]

print(f"Architecture IMAGE + BOUNDING BOX COORDINATES:")
print(f"  - Relations: {len(SPATIAL_RELATIONS)}")
print(f"  - Entrée: Image originale + coordonnées des bounding boxes")
print(f"  - Features VGG + features géométriques")

# =============================================================================
# DATASET IMAGE + BOUNDING BOX COORDINATES
# =============================================================================

class ImageBBoxDataset(Dataset):
    """
    SpatialSense+ dataset pairing each full image with bounding-box geometry.

    Every item is a tuple ``(image, spatial_features, label, metadata)``:
    - ``image``: the original RGB image, passed through ``transform`` if given;
    - ``spatial_features``: float32 tensor of ``NUM_SPATIAL_FEATURES`` values
      computed from the subject and object bounding boxes;
    - ``label``: index of the relation in ``SPATIAL_RELATIONS``;
    - ``metadata``: dict with entity names, relation strings and raw boxes.

    Only annotations with a truthy ``label`` whose predicate matches one of
    ``SPATIAL_RELATIONS`` (case-insensitive, surrounding whitespace ignored)
    and that carry a ``bbox`` for both subject and object are kept.

    Args:
        data_dir: directory holding ``annotations.json`` and ``images/images``.
        split: value of the per-image ``split`` field to keep.
        transform: optional callable applied to the loaded PIL image.
    """

    # 10 features per box (subject + object) plus 8 pairwise features.
    NUM_SPATIAL_FEATURES = 28
    # Class index returned when a sample cannot be loaded.
    FALLBACK_LABEL = 0

    def __init__(self, data_dir, split='train', transform=None):
        self.data_dir = data_dir
        self.split = split
        self.transform = transform

        self.relation_to_idx = {rel: idx for idx, rel in enumerate(SPATIAL_RELATIONS)}
        self.idx_to_relation = {idx: rel for rel, idx in self.relation_to_idx.items()}

        self.annotations = self._load_annotations()
        self.data_samples = self._prepare_samples()

        print(f"Dataset Image + BBox Coordinates SpatialSense+ {split}: {len(self.data_samples)} échantillons")
        if len(self.data_samples) > 0:
            self._print_statistics()

    def _load_annotations(self):
        """Load ``annotations.json``; return an empty list if the file is missing."""
        annotations_path = os.path.join(self.data_dir, 'annotations.json')
        try:
            with open(annotations_path, 'r') as f:
                return json.load(f)
        except FileNotFoundError:
            print(f"Erreur: {annotations_path} non trouvé!")
            return []

    def _find_image_path(self, image_url):
        """Map an image URL to its local path under ``images/images/{flickr,nyu}``."""
        base_dir = os.path.join(self.data_dir, "images", "images")
        filename = os.path.basename(image_url)

        # A "staticflickr" URL or a two-part "<a>_<b>" file name goes to the
        # flickr folder; everything else goes to the nyu folder.
        if "staticflickr" in image_url or len(filename.split('_')) == 2:
            return os.path.join(base_dir, "flickr", filename)
        else:
            return os.path.join(base_dir, "nyu", filename)

    def _prepare_samples(self):
        """Build the flat list of sample dicts for the requested split."""
        samples = []
        images_not_found = 0

        # Lower-cased predicate -> canonical relation name, built once instead
        # of rebuilding and scanning the relation list for every annotation.
        canonical_relation = {rel.lower(): rel for rel in SPATIAL_RELATIONS}

        for img_data in self.annotations:
            if img_data['split'] != self.split:
                continue

            img_path = self._find_image_path(img_data['url'])

            if not os.path.exists(img_path):
                images_not_found += 1
                continue

            for ann in img_data['annotations']:
                # Keep only annotations whose label is truthy.
                if not ann['label']:
                    continue

                relation = canonical_relation.get(ann['predicate'].lower().strip())
                if relation is None:
                    continue

                if 'bbox' not in ann['subject'] or 'bbox' not in ann['object']:
                    continue

                samples.append({
                    'image_path': img_path,
                    'subject': ann['subject']['name'],
                    'object': ann['object']['name'],
                    'relation': relation,
                    'original_relation': ann['predicate'],
                    'subject_bbox': ann['subject']['bbox'],  # [y1, y2, x1, x2]
                    'object_bbox': ann['object']['bbox'],   # [y1, y2, x1, x2]
                    'image_width': img_data['width'],
                    'image_height': img_data['height']
                })

        if images_not_found > 0:
            print(f"Images non trouvées: {images_not_found}")

        return samples

    def _normalize_bbox(self, bbox, img_width, img_height):
        """
        Scale a box by the image size and derive simple geometric features.

        Args:
            bbox: ``[y1, y2, x1, x2]`` in pixels.
            img_width, img_height: image size in pixels.

        Returns:
            List of 10 floats: ``[x1, y1, x2, y2, width, height, area,
            center_x, center_y, aspect_ratio]`` computed on the scaled box.
        """
        y1, y2, x1, x2 = bbox

        # Normalisation by the image size
        norm_y1 = y1 / img_height
        norm_y2 = y2 / img_height
        norm_x1 = x1 / img_width
        norm_x2 = x2 / img_width

        # Geometric features
        width = abs(norm_x2 - norm_x1)
        height = abs(norm_y2 - norm_y1)
        area = width * height
        center_x = (norm_x1 + norm_x2) / 2
        center_y = (norm_y1 + norm_y2) / 2
        aspect_ratio = width / (height + 1e-8)  # epsilon avoids division by zero

        return [norm_x1, norm_y1, norm_x2, norm_y2, width, height, area, center_x, center_y, aspect_ratio]

    def _compute_spatial_features(self, subject_bbox, object_bbox, img_width, img_height):
        """
        Compute the full geometric feature vector for a subject/object pair.

        Returns:
            List of ``NUM_SPATIAL_FEATURES`` values: the 10 per-box features of
            the subject, the 10 of the object, then 8 pairwise features
            (centre distance, angle, area ratio, IoU, intersection area,
            relative x, relative y, "very close" flag).
        """
        subj_features = self._normalize_bbox(subject_bbox, img_width, img_height)
        obj_features = self._normalize_bbox(object_bbox, img_width, img_height)

        # Per-box features (10 + 10)
        individual_features = subj_features + obj_features

        subj_center_x, subj_center_y = subj_features[7], subj_features[8]
        obj_center_x, obj_center_y = obj_features[7], obj_features[8]

        # Offset of the object centre relative to the subject centre
        relative_x = obj_center_x - subj_center_x
        relative_y = obj_center_y - subj_center_y

        # Distance between centres and direction of the object from the subject
        distance = np.sqrt((subj_center_x - obj_center_x)**2 + (subj_center_y - obj_center_y)**2)
        angle = np.arctan2(obj_center_y - subj_center_y, obj_center_x - subj_center_x)

        # Size ratio; epsilon on both sides keeps degenerate boxes finite
        area_ratio = (subj_features[6] + 1e-8) / (obj_features[6] + 1e-8)

        # Overlap between the two boxes
        subj_x1, subj_y1, subj_x2, subj_y2 = subj_features[:4]
        obj_x1, obj_y1, obj_x2, obj_y2 = obj_features[:4]

        inter_x1 = max(subj_x1, obj_x1)
        inter_y1 = max(subj_y1, obj_y1)
        inter_x2 = min(subj_x2, obj_x2)
        inter_y2 = min(subj_y2, obj_y2)

        if inter_x2 > inter_x1 and inter_y2 > inter_y1:
            intersection = (inter_x2 - inter_x1) * (inter_y2 - inter_y1)
            union = subj_features[6] + obj_features[6] - intersection
            iou = intersection / (union + 1e-8)
        else:
            iou = 0.0
            intersection = 0.0

        # Pairwise features (8)
        relational_features = [
            distance, angle, area_ratio, iou,
            intersection, relative_x, relative_y,
            1.0 if distance < 0.1 else 0.0  # centres are very close
        ]

        return individual_features + relational_features

    def _print_statistics(self):
        """Print the per-relation sample count and percentage for this split."""
        relation_counts = Counter([s['relation'] for s in self.data_samples])
        print(f"\nDistribution Image + BBox dans {self.split}:")
        total = len(self.data_samples)

        for relation in SPATIAL_RELATIONS:
            count = relation_counts.get(relation, 0)
            percentage = count / total * 100 if total > 0 else 0
            print(f"   {relation}: {count} ({percentage:.1f}%)")

    def __len__(self):
        return len(self.data_samples)

    def __getitem__(self, idx):
        """
        Return ``(image, spatial_features, label, metadata)`` for sample ``idx``.

        If loading or feature computation fails, the error is printed and a
        grey placeholder image with all-zero features and ``FALLBACK_LABEL``
        is returned so that one bad sample does not abort an epoch.
        """
        sample = self.data_samples[idx]

        try:
            # Only the original image is loaded (no crops, no masks)
            image = Image.open(sample['image_path']).convert('RGB')

            if self.transform:
                image = self.transform(image)

            spatial_features = torch.tensor(
                self._compute_spatial_features(
                    sample['subject_bbox'],
                    sample['object_bbox'],
                    sample['image_width'],
                    sample['image_height']
                ),
                dtype=torch.float32
            )

            label = self.relation_to_idx[sample['relation']]

            metadata = {
                'subject': sample['subject'],
                'object': sample['object'],
                'relation': sample['relation'],
                'original_relation': sample['original_relation'],
                'subject_bbox': sample['subject_bbox'],
                'object_bbox': sample['object_bbox'],
                'spatial_features_count': len(spatial_features)
            }

            return image, spatial_features, label, metadata

        except Exception as e:
            print(f"Erreur chargement {sample['image_path']}: {e}")
            # Placeholder image and features
            dummy_image = Image.new('RGB', (224, 224), color='gray')

            if self.transform:
                dummy_image = self.transform(dummy_image)
            else:
                dummy_image = torch.zeros(3, 224, 224)

            dummy_features = torch.zeros(self.NUM_SPATIAL_FEATURES, dtype=torch.float32)

            # The metadata relation is derived from the returned label so the
            # two always agree (previously label 0 was paired with 'next to').
            fallback_label = self.FALLBACK_LABEL

            return dummy_image, dummy_features, fallback_label, {
                'subject': 'error', 'object': 'error',
                'relation': self.idx_to_relation[fallback_label],
                'original_relation': 'error',
                'subject_bbox': [0, 0, 0, 0], 'object_bbox': [0, 0, 0, 0],
                'spatial_features_count': self.NUM_SPATIAL_FEATURES
            }

# =============================================================================
# ARCHITECTURE VGG + BOUNDING BOX FEATURES
# =============================================================================

class VGGBBoxFeatureExtractor(nn.Module):
    """
    Feature extractor that fuses:
    - frozen VGG16 FC-7 features of the image (4096 values),
    - a trainable projection of the bounding-box geometric features (512 values).

    The VGG16 part is frozen (``requires_grad=False``) and is always kept in
    eval mode, so its internal ``Dropout`` layers never perturb the image
    features, even while the rest of the model is training.

    Args:
        spatial_features_dim: length of the geometric feature vector.
    """

    def __init__(self, spatial_features_dim=28):
        super(VGGBBoxFeatureExtractor, self).__init__()

        self.spatial_features_dim = spatial_features_dim

        # Pre-trained VGG16 for the image
        vgg16 = models.vgg16(pretrained=True)

        # Convolutional trunk and pooling
        self.features = vgg16.features
        self.avgpool = vgg16.avgpool

        # First six classifier modules, i.e. everything up to (and including)
        # the dropout that follows the second 4096-wide linear layer (FC-7).
        classifier_layers = list(vgg16.classifier.children())[:6]
        self.fc7 = nn.Sequential(*classifier_layers)

        # Freeze the VGG weights. This runs before the spatial projection is
        # created, so only VGG parameters are affected.
        for param in self.parameters():
            param.requires_grad = False

        # Trainable projection of the geometric features
        self.spatial_projection = nn.Sequential(
            nn.Linear(spatial_features_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Dropout(0.3)
        )

        # Output width after concatenation
        self.fusion_dim = 4096 + 512  # VGG FC-7 + projected spatial features

        # A freshly built module is in training mode; pin the backbone now.
        self._set_backbone_eval()

        print(f"VGG + BBox Feature Extractor:")
        print(f"  - VGG FC-7: 4096 features")
        print(f"  - Spatial features: {spatial_features_dim} → 512 (projetées)")
        print(f"  - Dimension sortie: {self.fusion_dim}")

    def _set_backbone_eval(self):
        """Put every frozen VGG sub-module in eval mode."""
        for frozen_module in (self.features, self.avgpool, self.fc7):
            frozen_module.eval()

    def train(self, mode=True):
        """
        Set train/eval mode for the trainable parts only.

        ``nn.Module.train`` propagates the mode to every child, which would
        re-enable the dropout layers inside the frozen FC-7 stack. The
        backbone is therefore switched back to eval mode after each call.
        """
        super(VGGBBoxFeatureExtractor, self).train(mode)
        self._set_backbone_eval()
        return self

    def forward(self, image, spatial_features):
        """Return the concatenation of VGG FC-7 and projected spatial features."""
        # VGG features of the image
        vgg_features = self._extract_vgg_features(image)

        # Projection of the geometric features
        projected_spatial = self.spatial_projection(spatial_features)

        # Concatenate along the feature dimension
        fused_features = torch.cat([vgg_features, projected_spatial], dim=1)

        return fused_features

    def _extract_vgg_features(self, x):
        """Run the image through the VGG trunk, pooling and the FC-7 stack."""
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc7(x)
        return x

class SpatialRelationMLP(nn.Module):
    """
    Three-layer MLP head that classifies spatial relations.

    Layout: Linear -> BatchNorm -> ReLU -> Dropout, twice, followed by a final
    Linear layer that outputs one logit per relation. The three linear weight
    matrices are re-initialised with Xavier-uniform.
    """

    def __init__(self, input_dim, hidden1_dim=512, hidden2_dim=256,
                 num_relations=len(SPATIAL_RELATIONS), dropout_rate=0.4):
        super().__init__()

        # Log the configuration of the head.
        print(f"Spatial Relation MLP:")
        for caption, value in (
            ("Input", input_dim),
            ("Hidden 1", hidden1_dim),
            ("Hidden 2", hidden2_dim),
            ("Output", num_relations),
            ("Dropout", dropout_rate),
        ):
            print(f"  - {caption}: {value}")

        # First hidden stage
        self.fc1 = nn.Linear(input_dim, hidden1_dim)
        self.bn1 = nn.BatchNorm1d(hidden1_dim)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout_rate)

        # Second hidden stage
        self.fc2 = nn.Linear(hidden1_dim, hidden2_dim)
        self.bn2 = nn.BatchNorm1d(hidden2_dim)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout_rate)

        # Output layer (logits)
        self.fc3 = nn.Linear(hidden2_dim, num_relations)

        # Xavier-uniform initialisation of the linear weights, in layer order
        for linear in (self.fc1, self.fc2, self.fc3):
            nn.init.xavier_uniform_(linear.weight)

    def forward(self, x):
        """Return the relation logits for a batch of fused feature vectors."""
        hidden_stages = (
            self.fc1, self.bn1, self.relu1, self.dropout1,
            self.fc2, self.bn2, self.relu2, self.dropout2,
        )
        for stage in hidden_stages:
            x = stage(x)
        return self.fc3(x)

class ImageBBoxSpatialRelationModel(nn.Module):
    """
    End-to-end model: image + bounding-box features -> relation logits.

    Combines ``VGGBBoxFeatureExtractor`` (feature fusion) with
    ``SpatialRelationMLP`` (classification head).
    """

    def __init__(self, num_relations=len(SPATIAL_RELATIONS), spatial_features_dim=28):
        super().__init__()

        print("Modèle Image + BBox Spatial Relation:")
        print(f"  - Relations: {num_relations}")
        print(f"  - Features spatiales: {spatial_features_dim}")

        # Fused feature extractor (VGG FC-7 + projected geometry)
        extractor = VGGBBoxFeatureExtractor(spatial_features_dim=spatial_features_dim)
        self.feature_extractor = extractor

        # MLP head sized to the extractor's output width
        self.classifier = SpatialRelationMLP(
            input_dim=extractor.fusion_dim,
            num_relations=num_relations
        )

    def forward(self, image, spatial_features):
        """Fuse the two inputs and return the classifier output."""
        return self.classifier(self.feature_extractor(image, spatial_features))

# =============================================================================
# EARLY STOPPING CLASS
# =============================================================================

class EarlyStopping:
    """Early-stopping helper driven by a validation score where higher is better."""

    def __init__(self, patience=5, min_delta=0.001, restore_best_weights=True, verbose=True):
        """
        Args:
            patience (int): number of consecutive non-improving calls tolerated
                before ``early_stop`` is set (default: 5)
            min_delta (float): minimum gain over the best score that counts as
                an improvement
            restore_best_weights (bool): keep a copy of the best weights so they
                can be restored later
            verbose (bool): print progress messages
        """
        # Configuration
        self.patience, self.min_delta = patience, min_delta
        self.restore_best_weights, self.verbose = restore_best_weights, verbose

        # Running state
        self.best_score = None
        self.best_weights = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_score, model):
        """
        Record a new validation score and update the stopping state.

        Args:
            val_score (float): validation score (accuracy %)
            model: PyTorch model whose weights are snapshotted on improvement
        """
        # First call: the score becomes the baseline.
        if self.best_score is None:
            self.best_score = val_score
            self.save_checkpoint(model)
            if self.verbose:
                print(f"   Early Stopping: Score initial = {val_score:.3f}%")
            return

        # No sufficient improvement: count it and possibly trigger the stop.
        if val_score < self.best_score + self.min_delta:
            self.counter += 1
            if self.verbose:
                print(f"     Early Stopping: {self.counter}/{self.patience} (Best: {self.best_score:.3f}%)")

            if self.counter >= self.patience:
                self.early_stop = True
                if self.verbose:
                    print(f"🛑 Early Stopping déclenché! Restauration du meilleur modèle (Accuracy: {self.best_score:.3f}%)")
            return

        # Improvement: report the gain, snapshot the weights, reset the counter.
        if self.verbose:
            gain = val_score - self.best_score
            print(f"    Amélioration: {gain:.3f}% (Nouveau best: {val_score:.3f}%)")

        self.counter = 0
        self.best_score = val_score
        self.save_checkpoint(model)

    def save_checkpoint(self, model):
        """Store a CPU copy of the model's state dict (if restoring is enabled)."""
        if not self.restore_best_weights:
            return

        snapshot = {}
        for name, tensor in model.state_dict().items():
            snapshot[name] = tensor.cpu().clone()
        self.best_weights = snapshot

    def restore_best_weights_to_model(self, model):
        """Load the stored best weights back into ``model``, if any were saved."""
        if self.best_weights is None:
            return

        # Move the snapshot to the device the model currently lives on.
        target_device = next(model.parameters()).device
        model.load_state_dict(
            {name: tensor.to(target_device) for name, tensor in self.best_weights.items()}
        )

# =============================================================================
# FONCTIONS D'ÉVALUATION AVEC MATRICE DE CONFUSION
# =============================================================================

def evaluate_with_confusion_matrix(model, dataloader, criterion, device, class_names=None):
    """
    Evaluate ``model`` on ``dataloader`` and collect classification metrics.

    Args:
        model: network called as ``model(images, spatial_features)``.
        dataloader: yields ``(images, spatial_features, labels, metadata)``.
        criterion: loss called as ``criterion(outputs, labels)``.
        device: device the batches are moved to.
        class_names: class names in label-index order; defaults to
            ``SPATIAL_RELATIONS``.

    Returns:
        Tuple ``(avg_loss, accuracy, cm, report, all_predictions, all_labels)``:
        mean batch loss, accuracy in percent, confusion matrix with one row and
        column per entry of ``class_names``, the ``classification_report``
        dict, and the flat lists of predicted and true label indices.
    """
    if class_names is None:
        class_names = SPATIAL_RELATIONS

    # Explicit label ids keep the matrix and the report at a fixed size even
    # when a class is absent from this evaluation set. Without them the matrix
    # shrinks and classification_report rejects the target_names length.
    label_ids = list(range(len(class_names)))

    model.eval()
    running_loss = 0.0
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for images, spatial_features, labels, _ in tqdm(dataloader, desc='Evaluating'):
            images = images.to(device)
            spatial_features = spatial_features.to(device)
            labels = labels.to(device)

            outputs = model(images, spatial_features)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            predicted = outputs.argmax(dim=1)

            all_predictions.extend(predicted.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    # Metrics
    accuracy = 100. * np.mean(np.array(all_predictions) == np.array(all_labels))
    avg_loss = running_loss / len(dataloader)

    # Confusion matrix (rows: true labels, columns: predictions)
    cm = confusion_matrix(all_labels, all_predictions, labels=label_ids)

    # Classification report
    report = classification_report(
        all_labels, all_predictions,
        labels=label_ids,
        target_names=class_names,
        output_dict=True,
        zero_division=0
    )
    # Guarantee the 'accuracy' entry read by downstream display code, in case
    # the report built with explicit labels does not provide it.
    report.setdefault('accuracy', accuracy / 100.0)

    return avg_loss, accuracy, cm, report, all_predictions, all_labels

def plot_confusion_matrix(cm, class_names, title="Matrice de Confusion", normalize=False):
    """
    Draw a confusion matrix as an annotated heatmap.

    Args:
        cm: square confusion matrix (rows: true labels, columns: predictions).
        class_names: tick labels, in label-index order.
        title: figure title; " (Normalisée)" is appended when normalising.
        normalize: if True, divide each row by its total so that cells show
            per-class proportions. Rows with no samples are shown as zeros
            instead of NaN.
    """
    if normalize:
        counts = np.asarray(cm, dtype=float)
        row_totals = counts.sum(axis=1, keepdims=True)
        # Divide only where the row has samples; empty rows stay at 0.
        cm = np.divide(counts, row_totals,
                       out=np.zeros_like(counts), where=row_totals != 0)
        fmt = '.2f'
        title += " (Normalisée)"
    else:
        fmt = 'd'

    plt.figure(figsize=(12, 10))
    sns.heatmap(cm,
                annot=True,
                fmt=fmt,
                cmap='Blues',
                xticklabels=class_names,
                yticklabels=class_names,
                cbar_kws={'label': 'Proportion' if normalize else 'Nombre de prédictions'})

    plt.title(title, fontsize=16, fontweight='bold', pad=20)
    plt.xlabel('Prédictions', fontsize=14, fontweight='bold')
    plt.ylabel('Vraies étiquettes', fontsize=14, fontweight='bold')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()

def display_classification_metrics(report, title="Métriques de Classification"):
    """
    Print a classification report dict as a fixed-width text table.

    One row is printed per relation of ``SPATIAL_RELATIONS`` present in
    ``report``, followed by the accuracy, macro-average and weighted-average
    rows.
    """
    banner = '=' * 60
    rule = "-" * 60

    def metric_row(caption, stats):
        # Precision / recall / F1 / support row for one report entry.
        return (f"{caption:<15} {stats['precision']:<10.3f} {stats['recall']:<10.3f} "
                f"{stats['f1-score']:<10.3f} {stats['support']:<10.0f}")

    # Title block
    print(f"\n{banner}")
    print(f"{title}")
    print(f"{banner}")

    # Per-class rows
    print(f"{'Classe':<15} {'Précision':<10} {'Rappel':<10} {'F1-Score':<10} {'Support':<10}")
    print(rule)

    present_classes = [name for name in SPATIAL_RELATIONS if name in report]
    for class_name in present_classes:
        print(metric_row(class_name, report[class_name]))

    # Global rows
    print(rule)
    print(f"{'Accuracy':<15} {'':<10} {'':<10} {report['accuracy']:<10.3f} {report['macro avg']['support']:<10.0f}")
    for caption, key in (('Macro avg', 'macro avg'), ('Weighted avg', 'weighted avg')):
        print(metric_row(caption, report[key]))

# =============================================================================
# VISUALISATION AVEC BOUNDING BOXES
# =============================================================================

def visualize_image_bbox_samples(dataset, num_samples=4):
    """
    Show the first samples of ``dataset`` next to a text panel describing them.

    Args:
        dataset: indexable dataset returning
            ``(image, spatial_features, label, metadata)``; the image is
            expected to be a normalised 3xHxW tensor.
        num_samples: maximum number of samples (rows) to display.

    The grid has one row per sample actually shown, so no blank rows are drawn
    when the dataset holds fewer than ``num_samples`` items. Nothing is drawn
    when there is no sample to show.
    """
    n_rows = min(num_samples, len(dataset))
    if n_rows <= 0:
        print("Aucun échantillon à visualiser.")
        return

    # Normalisation statistics used to undo the normalisation for display
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    def denormalize(tensor):
        # Undo the normalisation and move channels last for imshow.
        return (tensor * std + mean).clamp(0, 1).permute(1, 2, 0)

    # squeeze=False always yields a 2-D axes array, even for a single row.
    fig, axes = plt.subplots(n_rows, 2, figsize=(12, 4*n_rows), squeeze=False)

    for i in range(n_rows):
        image, spatial_features, label, metadata = dataset[i]

        image_display = denormalize(image)

        # Left column: the image
        axes[i, 0].imshow(image_display)
        axes[i, 0].set_title(f"Image Originale\n{metadata['subject']} {metadata['relation']} {metadata['object']}")
        axes[i, 0].axis('off')

        # Right column: textual details
        axes[i, 1].axis('off')
        info_text = f"""
        Échantillon {i+1}:

        Sujet: {metadata['subject']}
        Objet: {metadata['object']}
        Relation: {metadata['relation']}
        Original: {metadata['original_relation']}

        Subject BBox: {metadata['subject_bbox']}
        Object BBox: {metadata['object_bbox']}

        Features Spatiales ({metadata['spatial_features_count']}):
        • Coordinates normalisées (8)
        • Dimensions et aires (6)
        • Centres et ratios (6)
        • Features relationnelles (8)

        Architecture:
        • VGG features (4096)
        • Spatial features projetées (512)
        • Total: 4608 features
        """
        axes[i, 1].text(0.1, 0.5, info_text, transform=axes[i, 1].transAxes,
                       fontsize=9, verticalalignment='center',
                       bbox=dict(boxstyle='round', facecolor='lightgreen', alpha=0.8))

    plt.suptitle('Architecture: Image + Coordonnées Bounding Boxes',
                 fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()

# =============================================================================
# TRANSFORMATIONS
# =============================================================================

def create_transforms():
    """
    Build the image transforms for training and validation.

    Returns:
        Tuple ``(train_transform, val_transform)``. Both resize to
        ``IMG_SIZE`` x ``IMG_SIZE``, convert to a tensor and apply the
        normalisation; the training pipeline additionally applies a mild
        colour jitter.

    No horizontal flip is applied: the bounding-box features and the
    "to the left of" / "to the right of" labels are computed from the
    unflipped annotations, so mirroring only the image would make the image
    contradict its own features and label.
    """
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )

    train_transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        # Photometric augmentation only; it leaves the geometry untouched.
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
        transforms.ToTensor(),
        normalize
    ])

    val_transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        normalize
    ])

    return train_transform, val_transform

# =============================================================================
# FONCTIONS D'ENTRAÎNEMENT
# =============================================================================

def train_epoch(model, dataloader, criterion, optimizer, device):
    """
    Run one training epoch.

    Returns:
        Tuple ``(mean batch loss, accuracy in percent)`` over the epoch.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    progress_bar = tqdm(dataloader, desc='Training')
    for step, batch in enumerate(progress_bar, start=1):
        images, spatial_features, labels, _ = batch
        images = images.to(device)
        spatial_features = spatial_features.to(device)
        labels = labels.to(device)

        # Forward / backward pass
        optimizer.zero_grad()
        logits = model(images, spatial_features)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()

        # Clip the gradient norm before the optimiser step
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Running statistics
        loss_sum += batch_loss.item()
        predicted = torch.max(logits.data, 1)[1]
        n_seen += labels.size(0)
        n_correct += (predicted == labels).sum().item()

        progress_bar.set_postfix({
            'loss': loss_sum / step,
            'acc': 100. * n_correct / n_seen
        })

    epoch_loss = loss_sum / len(dataloader)
    epoch_acc = 100. * n_correct / n_seen
    return epoch_loss, epoch_acc

# =============================================================================
# ENTRAÎNEMENT K-FOLD AVEC MATRICE DE CONFUSION ET EARLY STOPPING
# =============================================================================

def train_kfold_with_confusion_matrix(data_dir, k_folds=5, epochs=15,
                                     early_stopping_patience=5, min_delta=0.001):
    """
    K-fold cross-validated training of the image + bounding-box model.

    For every fold a fresh ``ImageBBoxSpatialRelationModel`` is trained with
    AdamW, cosine learning-rate annealing and early stopping on validation
    accuracy. The best weights are restored before the final evaluation, whose
    metrics and confusion matrices are printed and plotted. Pooled metrics
    over all folds are reported at the end.

    Args:
        data_dir: dataset root passed to ``ImageBBoxDataset``.
        k_folds: number of folds.
        epochs: maximum number of epochs per fold.
        early_stopping_patience: non-improving epochs tolerated before stopping.
        min_delta: minimum accuracy gain counted as an improvement.

    Returns:
        List with one result dict per fold (loss/accuracy curves, best score,
        final confusion matrix, report, predictions, labels, trained model,
        epochs run, early-stop flag). Empty list when the dataset is empty.

    Side effects:
        Writes ``best_image_bbox_model_fold_<k>.pth`` in the working directory
        whenever a fold reaches a new best validation accuracy.
    """
    import copy  # local import: only needed to clone the dataset wrapper

    print("="*70)
    print("ENTRAÎNEMENT IMAGE + BOUNDING BOX COORDINATES + EARLY STOPPING")
    print("="*70)
    print(f"Architecture: VGG + Spatial Features MLP")
    print(f"Entrée: Image originale + 28 features spatiales des bounding boxes")
    print(f"Early Stopping: Patience={early_stopping_patience}, Min Delta={min_delta}")

    # Transforms
    train_transform, val_transform = create_transforms()

    # Dataset (training view, with augmentation)
    full_dataset = ImageBBoxDataset(
        data_dir=data_dir,
        split='train',
        transform=train_transform
    )

    if len(full_dataset) == 0:
        print("Dataset vide!")
        return []

    # Validation view of the same samples: a shallow copy shares the sample
    # list (same indices) but uses the deterministic validation transform, so
    # validation scores are not computed on augmented images.
    val_dataset = copy.copy(full_dataset)
    val_dataset.transform = val_transform

    # Sample visualisation
    print("\nVisu échantillons image + bounding box coordinates:")
    visualize_image_bbox_samples(full_dataset, num_samples=3)

    # Fixed label ids so that reports and matrices always cover every class
    label_ids = list(range(len(SPATIAL_RELATIONS)))

    # K-fold (the split depends only on the number of samples)
    kfold = KFold(n_splits=k_folds, shuffle=True, random_state=42)
    fold_results = []

    for fold, (train_indices, val_indices) in enumerate(kfold.split(np.arange(len(full_dataset)))):
        print(f"\n{'='*50}")
        print(f"FOLD {fold + 1}/{k_folds} - IMAGE + BBOX COORDINATES + EARLY STOPPING")
        print(f"{'='*50}")

        # Subsets: augmented view for training, clean view for validation
        train_subset = torch.utils.data.Subset(full_dataset, train_indices)
        val_subset = torch.utils.data.Subset(val_dataset, val_indices)

        # BatchNorm1d cannot run in training mode on a single-sample batch;
        # drop the trailing batch only when it would contain exactly one item.
        drop_single_tail = len(train_indices) > 1 and len(train_indices) % BATCH_SIZE == 1

        # DataLoaders
        train_loader = DataLoader(
            train_subset, batch_size=BATCH_SIZE, shuffle=True,
            num_workers=2, pin_memory=True, drop_last=drop_single_tail
        )

        val_loader = DataLoader(
            val_subset, batch_size=BATCH_SIZE, shuffle=False,
            num_workers=2, pin_memory=True
        )

        # Model
        model = ImageBBoxSpatialRelationModel(
            num_relations=len(SPATIAL_RELATIONS),
            spatial_features_dim=28
        )
        model = model.to(DEVICE)

        # Optimiser: only the non-VGG layers are trained
        trainable_params = []
        trainable_params.extend(model.feature_extractor.spatial_projection.parameters())
        trainable_params.extend(model.classifier.parameters())

        optimizer = optim.AdamW(trainable_params, lr=LEARNING_RATE, weight_decay=0.01)
        criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

        # Scheduler
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

        # Early stopping
        early_stopping = EarlyStopping(
            patience=early_stopping_patience,
            min_delta=min_delta,
            restore_best_weights=True,
            verbose=True
        )

        # Training loop
        train_losses, val_losses = [], []
        train_accs, val_accs = [], []
        best_val_acc = 0.0
        actual_epochs = 0

        for epoch in range(epochs):
            print(f"\n   Epoch {epoch+1}/{epochs}")

            # Training
            train_loss, train_acc = train_epoch(
                model, train_loader, criterion, optimizer, DEVICE
            )
            train_losses.append(train_loss)
            train_accs.append(train_acc)

            # Validation (only loss and accuracy are used per epoch)
            val_loss, val_acc, *_ = evaluate_with_confusion_matrix(
                model, val_loader, criterion, DEVICE, SPATIAL_RELATIONS
            )
            val_losses.append(val_loss)
            val_accs.append(val_acc)

            scheduler.step()
            actual_epochs = epoch + 1

            print(f"       Train: {train_loss:.4f} / {train_acc:.2f}%")
            print(f"   Val: {val_loss:.4f} / {val_acc:.2f}%")

            # Persist the best weights of this fold to disk
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save(model.state_dict(),
                          f'best_image_bbox_model_fold_{fold+1}.pth')

            # Early-stopping check
            early_stopping(val_acc, model)

            if early_stopping.early_stop:
                print(f"🛑 Arrêt anticipé à l'époque {epoch+1}")
                break

        # Always evaluate the best weights, whether or not training stopped early
        early_stopping.restore_best_weights_to_model(model)
        if not early_stopping.early_stop:
            print(f"   Entraînement terminé ({epochs} époques). Meilleurs poids restaurés.")

        # Final evaluation with confusion matrix
        print(f"\n{'='*40}")
        print(f"ÉVALUATION FINALE FOLD {fold + 1}")
        print(f"{'='*40}")

        final_val_loss, final_val_acc, final_cm, final_report, final_predictions, final_true_labels = evaluate_with_confusion_matrix(
            model, val_loader, criterion, DEVICE, SPATIAL_RELATIONS
        )

        # Metrics table
        display_classification_metrics(final_report, f"Métriques Fold {fold + 1}")

        # Confusion matrices (raw counts and row-normalised)
        plot_confusion_matrix(final_cm, SPATIAL_RELATIONS,
                            f"Matrice de Confusion - Fold {fold + 1}", normalize=False)
        plot_confusion_matrix(final_cm, SPATIAL_RELATIONS,
                            f"Matrice de Confusion Normalisée - Fold {fold + 1}", normalize=True)

        # Fold results
        fold_results.append({
            'fold': fold + 1,
            'train_losses': train_losses,
            'val_losses': val_losses,
            'train_accs': train_accs,
            'val_accs': val_accs,
            'best_val_acc': early_stopping.best_score,  # best score seen by early stopping
            'final_confusion_matrix': final_cm,
            'final_report': final_report,
            'final_predictions': final_predictions,
            'final_true_labels': final_true_labels,
            'model': model,
            'actual_epochs': actual_epochs,
            'early_stopped': early_stopping.early_stop
        })

        print(f"   Fold {fold + 1} - Meilleur: {early_stopping.best_score:.2f}% (après {actual_epochs} époques)")
        if early_stopping.early_stop:
            print(f"   🛑 Arrêt anticipé activé")
        else:
            print(f"      Entraînement complet")

    # Global analysis
    print(f"\n{'='*70}")
    print("ANALYSE GLOBALE DES RÉSULTATS AVEC EARLY STOPPING")
    print(f"{'='*70}")

    # Mean and spread of the per-fold best accuracy
    mean_acc = np.mean([r['best_val_acc'] for r in fold_results])
    std_acc = np.std([r['best_val_acc'] for r in fold_results])
    print(f"Accuracy moyenne: {mean_acc:.2f}% ± {std_acc:.2f}%")

    # Early-stopping statistics
    early_stopped_folds = [r for r in fold_results if r['early_stopped']]
    mean_epochs = np.mean([r['actual_epochs'] for r in fold_results])

    print(f"   Statistiques Early Stopping:")
    print(f"   Folds avec arrêt anticipé: {len(early_stopped_folds)}/{k_folds}")
    print(f"   Époques moyennes: {mean_epochs:.1f}/{epochs}")

    for r in fold_results:
        status = "🛑 Arrêté" if r['early_stopped'] else "   Complet"
        print(f"   Fold {r['fold']}: {r['actual_epochs']:2d} époques - {status}")

    # Pooled predictions over all folds
    global_predictions = np.concatenate([r['final_predictions'] for r in fold_results])
    global_true_labels = np.concatenate([r['final_true_labels'] for r in fold_results])

    # Global confusion matrix, rebuilt from the pooled predictions with fixed
    # label ids so it is always full-size (it equals the sum of the per-fold
    # matrices whenever those are full-size too).
    global_cm = confusion_matrix(global_true_labels, global_predictions, labels=label_ids)

    # Global report
    global_report = classification_report(
        global_true_labels, global_predictions,
        labels=label_ids,
        target_names=SPATIAL_RELATIONS,
        output_dict=True,
        zero_division=0
    )
    # Guarantee the 'accuracy' entry read by display_classification_metrics.
    global_report.setdefault(
        'accuracy', float(np.mean(global_predictions == global_true_labels))
    )

    # Final display
    display_classification_metrics(global_report, "Métriques Globales (tous les folds)")
    plot_confusion_matrix(global_cm, SPATIAL_RELATIONS,
                        "Matrice de Confusion Globale", normalize=False)
    plot_confusion_matrix(global_cm, SPATIAL_RELATIONS,
                        "Matrice de Confusion Globale Normalisée", normalize=True)

    return fold_results

# =============================================================================
# ANALYSE DES ERREURS
# =============================================================================

def analyze_prediction_errors(fold_results, top_k=5):
    """Print the most frequent misclassifications pooled over all folds.

    Args:
        fold_results: List of per-fold result dicts; each must provide
            'final_predictions' and 'final_true_labels' (class indices into
            SPATIAL_RELATIONS).
        top_k: Number of (true, predicted) confusion pairs to list.

    Returns:
        None. The report is written to stdout.
    """
    print(f"\n{'='*60}")
    print("ANALYSE DES ERREURS DE PRÉDICTION")
    print(f"{'='*60}")

    # np.concatenate raises on an empty list, so stop before pooling.
    if not fold_results:
        print("Aucune prédiction à analyser.")
        return

    # Pool the predictions of every fold
    all_predictions = np.concatenate([r['final_predictions'] for r in fold_results])
    all_true_labels = np.concatenate([r['final_true_labels'] for r in fold_results])

    # The global accuracy below divides by the number of pooled samples.
    if len(all_predictions) == 0:
        print("Aucune prédiction à analyser.")
        return

    # Keep only the (true relation, predicted relation) pairs that disagree
    errors = [
        (SPATIAL_RELATIONS[true_idx], SPATIAL_RELATIONS[pred_idx])
        for true_idx, pred_idx in zip(all_true_labels, all_predictions)
        if true_idx != pred_idx
    ]

    # Count how often each confusion pair occurs
    error_counts = Counter(errors)

    print(f"Nombre total d'erreurs: {len(errors)}")
    print(f"Accuracy globale: {100 * (1 - len(errors) / len(all_predictions)):.2f}%")
    print(f"\nTop {top_k} erreurs les plus fréquentes:")
    print("-" * 60)

    # most_common yields nothing when there are no errors, so len(errors) is
    # never zero inside the loop.
    for i, ((true_rel, pred_rel), count) in enumerate(error_counts.most_common(top_k)):
        percentage = 100 * count / len(errors)
        print(f"{i+1:2d}. {true_rel:>15} → {pred_rel:<15} : {count:3d} ({percentage:5.1f}%)")

def plot_learning_curves_with_early_stopping(fold_results):
    """Plot per-fold loss and accuracy curves on a 2x2 grid.

    Folds that stopped early are drawn dashed, get an "ES@<epoch>" suffix in
    the legend and an 'X' marker on the validation-accuracy panel at the
    epoch where training stopped.

    Args:
        fold_results: List of per-fold result dicts providing 'fold',
            'early_stopped', 'actual_epochs', 'train_losses', 'val_losses',
            'train_accs' and 'val_accs'.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    val_acc_ax = axes[1, 1]

    # One entry per panel: (axis, history key, title, y-axis label)
    panels = [
        (axes[0, 0], 'train_losses', "Courbes de Loss - Training", 'Loss'),
        (axes[0, 1], 'val_losses', "Courbes de Loss - Validation", 'Loss'),
        (axes[1, 0], 'train_accs', "Courbes d'Accuracy - Training", 'Accuracy (%)'),
        (val_acc_ax, 'val_accs', "Courbes d'Accuracy - Validation", 'Accuracy (%)'),
    ]
    for panel_ax, _, panel_title, _ in panels:
        panel_ax.set_title(panel_title, fontweight='bold')

    palette = ['blue', 'red', 'green', 'orange', 'purple']

    for fold_pos, result in enumerate(fold_results):
        stopped_early = result['early_stopped']
        epoch_axis = range(1, len(result['train_losses']) + 1)
        fold_color = palette[fold_pos % len(palette)]

        # Dashed and slightly more opaque for folds that stopped early
        line_kwargs = {
            'color': fold_color,
            'alpha': 0.8 if stopped_early else 0.7,
            'linestyle': '--' if stopped_early else '-',
        }

        legend_label = f'Fold {result["fold"]}'
        if stopped_early:
            legend_label += f' (ES@{result["actual_epochs"]})'

        for panel_ax, history_key, _, _ in panels:
            panel_ax.plot(epoch_axis, result[history_key], label=legend_label, **line_kwargs)

        # Mark the stopping epoch on the validation-accuracy panel
        if stopped_early:
            stop_epoch = result['actual_epochs']
            val_acc_ax.scatter(stop_epoch, result['val_accs'][stop_epoch-1],
                               color=fold_color, s=100, marker='X', zorder=5)

    # Shared axis decoration
    for panel_ax, _, _, y_label in panels:
        panel_ax.set_xlabel('Époque')
        panel_ax.legend()
        panel_ax.grid(True, alpha=0.3)
        panel_ax.set_ylabel(y_label)

    plt.suptitle("Courbes d'Apprentissage avec Early Stopping\n(X = Arrêt anticipé)",
                 fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()

def analyze_spatial_features_importance(fold_results):
    """Print the catalogue of the 28 spatial features, grouped by category.

    The listing is static: `fold_results` is accepted for call-site symmetry
    with the other analysis helpers but is not read.

    Args:
        fold_results: Unused.

    Returns:
        None. The listing is written to stdout.
    """
    rule = '=' * 60
    print(f"\n{rule}")
    print("ANALYSE DES FEATURES SPATIALES")
    print(rule)

    # The same ten per-box measurements exist for the subject and the object
    per_box_names = [
        'x1', 'y1', 'x2', 'y2', 'width', 'height',
        'area', 'center_x', 'center_y', 'aspect_ratio',
    ]
    relational_names = [
        'distance', 'angle', 'area_ratio', 'iou', 'intersection',
        'relative_x', 'relative_y', 'very_close',
    ]

    grouped_features = [
        ('Sujet (10)', ['subj_' + name for name in per_box_names]),
        ('Objet (10)', ['obj_' + name for name in per_box_names]),
        ('Relationnelles (8)', relational_names),
    ]

    print("Features Spatiales Calculées (28 au total):")
    print("-" * 60)

    for heading, names in grouped_features:
        print(f"\n{heading}:")
        # Numbering restarts at 1 inside every category
        for rank, name in enumerate(names, start=1):
            print(f"  {rank:2d}. {name}")

    print(f"\nCes features sont projetées de 28 → 512 dimensions avant fusion avec VGG.")

# =============================================================================
# FONCTION PRINCIPALE
# =============================================================================

def main_image_bbox_experiment():
    """Run the full image + bounding-box experiment and print a report.

    Trains with `train_kfold_with_confusion_matrix` using the module-level
    K_FOLDS and EPOCHS settings, then prints per-fold results, the share of
    epochs saved by early stopping, the error analysis, the learning curves,
    the spatial feature listing and the per-relation accuracy.

    Returns:
        The list of per-fold result dicts, or None when the data directory is
        missing or training returned nothing.
    """
    DATA_DIR = "data/spatialsense"
    rule = '=' * 60

    if not os.path.exists(DATA_DIR):
        print(f"\nERREUR: {DATA_DIR} n'existe pas!")
        return

    # Training with early stopping
    print(f"\n{rule}")
    print("ENTRAÎNEMENT AVEC EARLY STOPPING")
    print(rule)

    results = train_kfold_with_confusion_matrix(
        data_dir=DATA_DIR,
        k_folds=K_FOLDS,
        epochs=EPOCHS,
        early_stopping_patience=5,  # early-stopping patience
        min_delta=0.001             # minimal improvement that counts
    )

    if not results:
        print("    Aucun résultat obtenu!")
        return None

    best_accuracies = [fold['best_val_acc'] for fold in results]
    mean_acc = np.mean(best_accuracies)
    std_acc = np.std(best_accuracies)

    print(f"\n  Résultats finaux avec Early Stopping:")
    print(f"   Accuracy moyenne: {mean_acc:.2f}% ± {std_acc:.2f}%")
    print(f"   Détails par fold:")

    for fold in results:
        marker = "🛑" if fold['early_stopped'] else "  "
        print(f"     Fold {fold['fold']}: {fold['best_val_acc']:.2f}% ({fold['actual_epochs']}/{EPOCHS} époques) {marker}")

    # Epochs not run thanks to early stopping, relative to the full budget
    epoch_budget = K_FOLDS * EPOCHS
    total_epochs_saved = sum(EPOCHS - fold['actual_epochs'] for fold in results)
    efficiency = 100 * total_epochs_saved / epoch_budget
    print(f"\n     Efficacité Early Stopping:")
    print(f"   Époques économisées: {total_epochs_saved}/{epoch_budget} ({efficiency:.1f}%)")

    # Error analysis, learning curves, feature listing
    analyze_prediction_errors(results)
    plot_learning_curves_with_early_stopping(results)
    analyze_spatial_features_importance(results)

    # Per-relation accuracy over the pooled fold predictions
    print(f"\n{rule}")
    print("PERFORMANCE PAR RELATION SPATIALE")
    print(rule)

    global_predictions = np.concatenate([fold['final_predictions'] for fold in results])
    global_true_labels = np.concatenate([fold['final_true_labels'] for fold in results])

    for class_idx, relation in enumerate(SPATIAL_RELATIONS):
        mask = global_true_labels == class_idx
        support = np.sum(mask)
        if support > 0:
            class_acc = 100 * np.mean(global_predictions[mask] == global_true_labels[mask])
            print(f"  {relation:<15}: {class_acc:6.2f}% ({support:3d} échantillons)")

    print(f"\n   Expérience avec Early Stopping terminée!")
    print(f"    Performance globale: {mean_acc:.2f}% ± {std_acc:.2f}%")
    print(f"     Efficacité: {efficiency:.1f}% d'époques économisées")

    # Closing summary
    print(f"\n{rule}")
    print("AVANTAGES DE L'EARLY STOPPING")
    print(rule)
    print("   Prévention de l'overfitting")
    print("   Réduction du temps d'entraînement")
    print("   Automatisation de l'arrêt optimal")
    print("   Restauration des meilleurs poids")
    print("   Meilleure généralisation")

    return results

# =============================================================================
# FONCTION DE TEST RAPIDE
# =============================================================================

def quick_test_image_bbox():
    """Smoke-test the pipeline with a reduced training budget.

    Trains 3 folds for at most 15 epochs with an early-stopping patience of 3,
    prints the early-stopping statistics, then loads one training sample from
    `ImageBBoxDataset` to print the size and range of its spatial features.

    Returns:
        The list of per-fold result dicts, or None when the data directory is
        missing or training returned nothing.
    """
    DATA_DIR = "data/spatialsense"
    quick_epochs = 15

    rule = "=" * 60
    print(rule)
    print("TEST RAPIDE - IMAGE + BBOX COORDINATES + EARLY STOPPING")
    print(rule)

    if not os.path.exists(DATA_DIR):
        print(f"ERREUR: {DATA_DIR} n'existe pas!")
        return

    # Reduced settings: fewer folds, short patience
    results = train_kfold_with_confusion_matrix(
        data_dir=DATA_DIR,
        k_folds=3,
        epochs=quick_epochs,
        early_stopping_patience=3,
        min_delta=0.001
    )

    if not results:
        print("    Échec du test rapide!")
        return None

    mean_acc = np.mean([fold['best_val_acc'] for fold in results])

    print(f"\n     Test rapide terminé!")
    print(f"   Accuracy moyenne: {mean_acc:.2f}%")

    # Early-stopping statistics
    early_stopped_count = len([fold for fold in results if fold['early_stopped']])
    mean_epochs = np.mean([fold['actual_epochs'] for fold in results])

    print(f"\n   Efficacité Early Stopping:")
    print(f"   Folds avec arrêt anticipé: {early_stopped_count}/{len(results)}")
    print(f"   Époques moyennes: {mean_epochs:.1f}/{quick_epochs}")

    for fold in results:
        if fold['early_stopped']:
            status = "🛑 Arrêté"
        else:
            status = "   Complet"
        print(f"   Fold {fold['fold']}: {fold['actual_epochs']:2d} époques - {status}")

    # Sanity check of the spatial features on the first training sample
    print(f"\n   Vérification des features spatiales:")
    sample_dataset = ImageBBoxDataset(
        data_dir=DATA_DIR,
        split='train',
        transform=None
    )

    if len(sample_dataset) > 0:
        # The dataset item is unpacked as a 4-tuple; only the features are used
        _, spatial_features, _, _ = sample_dataset[0]
        print(f"   Nombre de features: {len(spatial_features)}")
        print(f"   Features min: {spatial_features.min():.3f}")
        print(f"   Features max: {spatial_features.max():.3f}")
        print(f"   Features moyennes: {spatial_features.mean():.3f}")

    return results

if __name__ == "__main__":
    # Run the quick smoke test when the first command-line argument is
    # exactly "quick"; any other invocation runs the full experiment.
    import sys

    if sys.argv[1:2] == ["quick"]:
        print("Mode test rapide avec Early Stopping activé...")
        quick_test_image_bbox()
    else:
        print("Mode expérience complète avec Early Stopping activé...")
        main_image_bbox_experiment()

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import json
import os
from sklearn.model_selection import KFold
from sklearn.metrics import confusion_matrix, classification_report
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from collections import Counter
import warnings
warnings.filterwarnings('ignore')

# BERT imports
from transformers import BertTokenizer, BertModel

# =============================================================================
# CONFIGURATION GLOBALE
# =============================================================================

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Device utilisé: {DEVICE}")

# Hyperparameters
BATCH_SIZE = 8  # kept small because of BERT
LEARNING_RATE = 1e-3
EPOCHS = 15
K_FOLDS = 5
IMG_SIZE = 224
DROPOUT_RATE = 0.4

# SpatialSense+ spatial relations; the list position is the class index
SPATIAL_RELATIONS = [
    'above',
    'behind',
    'in',
    'in front of',
    'next to',
    'on',
    'to the left of',
    'to the right of',
    'under',
]

# Architecture summary printed once when the cell runs
print("\n".join([
    "Architecture MULTIMODALE: IMAGE + BBOX + BERT TEXT:",
    f"  - Relations: {len(SPATIAL_RELATIONS)}",
    "  - Modalité 1: Image originale → VGG16 FC-7 (4096)",
    "  - Modalité 2: Coordonnées BBox → MLP (28 → 512)",
    "  - Modalité 3: Texte 'subject object' → BERT (768 → 512)",
    "  - Fusion: Concaténation triple (5120 features)",
]))

# =============================================================================
# BERT TEXT ENCODER
# =============================================================================

class BERTTextEncoder(nn.Module):
    """BERT encoder for 'subject object' texts, followed by a small projection MLP.

    Args:
        bert_model_name: Name passed to `BertTokenizer.from_pretrained` and
            `BertModel.from_pretrained`.
        freeze_bert: When True, every BERT parameter gets
            `requires_grad = False`; the projection stays trainable.
    """

    def __init__(self, bert_model_name='bert-base-uncased', freeze_bert=True):
        super(BERTTextEncoder, self).__init__()

        # Load the pre-trained tokenizer and model
        self.tokenizer = BertTokenizer.from_pretrained(bert_model_name)
        self.bert = BertModel.from_pretrained(bert_model_name)

        # Freeze BERT if requested
        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False

        self.bert_dim = self.bert.config.hidden_size  # 768 for bert-base

        # Projection of the BERT features: bert_dim -> 256 -> 512
        self.text_projection = nn.Sequential(
            nn.Linear(self.bert_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Dropout(0.3)
        )

        print(f"BERT Text Encoder:")
        print(f"  - Modèle: {bert_model_name}")
        print(f"  - BERT dim: {self.bert_dim}")
        print(f"  - Projection: {self.bert_dim} → 512")
        print(f"  - Paramètres gelés: {freeze_bert}")

    def encode_text(self, texts, max_length=64):
        """Tokenize `texts` and return the [CLS] embedding of each one.

        Args:
            texts: Batch of strings handed to the tokenizer.
            max_length: Truncation length passed to the tokenizer.

        Returns:
            Tensor of shape [batch_size, bert_dim] taken from position 0 of
            BERT's last hidden state.
        """
        # Tokenization (padded to the longest text of the batch)
        encoded = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors='pt'
        )

        # Move the inputs to the device BERT lives on
        input_ids = encoded['input_ids'].to(self.bert.device)
        attention_mask = encoded['attention_mask'].to(self.bert.device)

        # Track gradients through BERT only when it has trainable parameters
        # AND autograd is currently enabled. Entering torch.enable_grad()
        # here would override an outer torch.no_grad() during evaluation and
        # build an autograd graph that is never used.
        bert_trainable = any(p.requires_grad for p in self.bert.parameters())
        with torch.set_grad_enabled(bert_trainable and torch.is_grad_enabled()):
            outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)

        # Use the [CLS] token as the global representation
        cls_embeddings = outputs.last_hidden_state[:, 0, :]  # [batch_size, bert_dim]

        return cls_embeddings

    def forward(self, texts):
        """Encode `texts` with BERT and project the result to 512 dimensions."""
        bert_features = self.encode_text(texts)
        projected_features = self.text_projection(bert_features)
        return projected_features

# =============================================================================
# DATASET MULTIMODAL: IMAGE + BBOX + TEXT
# =============================================================================

class MultiModalDataset(Dataset):
    """Multimodal SpatialSense+ dataset.

    Each item provides:
    - the original image (transformed when a transform is set),
    - 28 geometric features computed from the two bounding boxes,
    - the text 'subject object' to be encoded by BERT,
    - the class index of the relation and a metadata dict.

    Args:
        data_dir: Directory holding 'annotations.json' and the images.
        split: Only annotations whose image 'split' equals this value are kept.
        transform: Optional callable applied to the loaded PIL image.
    """

    def __init__(self, data_dir, split='train', transform=None):
        self.data_dir = data_dir
        self.split = split
        self.transform = transform

        self.relation_to_idx = {rel: idx for idx, rel in enumerate(SPATIAL_RELATIONS)}
        self.idx_to_relation = {idx: rel for rel, idx in self.relation_to_idx.items()}

        self.annotations = self._load_annotations()
        self.data_samples = self._prepare_samples()

        print(f"Dataset MultiModal SpatialSense+ {split}: {len(self.data_samples)} échantillons")
        if len(self.data_samples) > 0:
            self._print_statistics()

    def _load_annotations(self):
        """Load 'annotations.json'; return an empty list when the file is missing."""
        annotations_path = os.path.join(self.data_dir, 'annotations.json')
        try:
            with open(annotations_path, 'r') as f:
                return json.load(f)
        except FileNotFoundError:
            print(f"Erreur: {annotations_path} non trouvé!")
            return []

    def _find_image_path(self, image_url):
        """Map an annotation URL to its local path under images/images/{flickr,nyu}."""
        base_dir = os.path.join(self.data_dir, "images", "images")
        filename = os.path.basename(image_url)

        # Flickr files are recognised by the URL host or a two-part
        # underscore-separated file name; everything else goes to 'nyu'.
        if "staticflickr" in image_url or len(filename.split('_')) == 2:
            return os.path.join(base_dir, "flickr", filename)
        else:
            return os.path.join(base_dir, "nyu", filename)

    def _prepare_samples(self):
        """Build the flat list of samples for `self.split`.

        An annotation is kept when its 'label' is truthy, its predicate
        matches one of SPATIAL_RELATIONS case-insensitively (after stripping
        whitespace) and both subject and object carry a 'bbox'. Images that
        do not exist on disk are skipped and counted.
        """
        samples = []
        images_not_found = 0

        # Lower-cased predicate -> canonical relation name, built once instead
        # of once per annotation. Reversed iteration keeps the first entry of
        # SPATIAL_RELATIONS should two entries ever differ only by case.
        canonical_relation = {rel.lower(): rel for rel in reversed(SPATIAL_RELATIONS)}

        for img_data in self.annotations:
            if img_data['split'] != self.split:
                continue

            img_path = self._find_image_path(img_data['url'])

            if not os.path.exists(img_path):
                images_not_found += 1
                continue

            for ann in img_data['annotations']:
                # The label is checked first so the predicate of a rejected
                # annotation is never touched.
                if not ann['label']:
                    continue

                relation = canonical_relation.get(ann['predicate'].lower().strip())

                if relation and 'bbox' in ann['subject'] and 'bbox' in ann['object']:
                    sample = {
                        'image_path': img_path,
                        'subject': ann['subject']['name'],
                        'object': ann['object']['name'],
                        'relation': relation,
                        'original_relation': ann['predicate'],
                        'subject_bbox': ann['subject']['bbox'],  # [y1, y2, x1, x2]
                        'object_bbox': ann['object']['bbox'],   # [y1, y2, x1, x2]
                        'image_width': img_data['width'],
                        'image_height': img_data['height']
                    }
                    samples.append(sample)

        if images_not_found > 0:
            print(f"Images non trouvées: {images_not_found}")

        return samples

    def _normalize_bbox(self, bbox, img_width, img_height):
        """Scale a box by the image size and derive its geometric features.

        Args:
            bbox: Box unpacked as [y1, y2, x1, x2].
            img_width: Image width used to scale the x coordinates.
            img_height: Image height used to scale the y coordinates.

        Returns:
            List of 10 floats: [x1, y1, x2, y2, width, height, area,
            center_x, center_y, aspect_ratio], all on the scaled coordinates.
        """
        y1, y2, x1, x2 = bbox

        # Normalization
        norm_y1 = y1 / img_height
        norm_y2 = y2 / img_height
        norm_x1 = x1 / img_width
        norm_x2 = x2 / img_width

        # Geometric features
        width = abs(norm_x2 - norm_x1)
        height = abs(norm_y2 - norm_y1)
        area = width * height
        center_x = (norm_x1 + norm_x2) / 2
        center_y = (norm_y1 + norm_y2) / 2
        aspect_ratio = width / (height + 1e-8)  # epsilon avoids a division by zero

        return [norm_x1, norm_y1, norm_x2, norm_y2, width, height, area, center_x, center_y, aspect_ratio]

    def _compute_spatial_features(self, subject_bbox, object_bbox, img_width, img_height):
        """Compute the 28 spatial features of a subject/object box pair.

        Returns:
            List of 28 floats: the 10 features of the subject box, the 10 of
            the object box, then distance between centres, angle, area ratio,
            IoU, intersection area, relative x, relative y and a 0/1 flag set
            when the centre distance is below 0.1.
        """
        # Normalize both boxes
        subj_features = self._normalize_bbox(subject_bbox, img_width, img_height)
        obj_features = self._normalize_bbox(object_bbox, img_width, img_height)

        # Individual features (10 + 10)
        individual_features = subj_features + obj_features

        # Box centres
        subj_center_x, subj_center_y = subj_features[7], subj_features[8]
        obj_center_x, obj_center_y = obj_features[7], obj_features[8]

        # Distance between centres
        distance = np.sqrt((subj_center_x - obj_center_x)**2 + (subj_center_y - obj_center_y)**2)

        # Direction from the subject centre to the object centre
        angle = np.arctan2(obj_center_y - subj_center_y, obj_center_x - subj_center_x)

        # Size ratio (epsilon on both sides avoids a division by zero)
        area_ratio = (subj_features[6] + 1e-8) / (obj_features[6] + 1e-8)

        # Overlap
        subj_x1, subj_y1, subj_x2, subj_y2 = subj_features[:4]
        obj_x1, obj_y1, obj_x2, obj_y2 = obj_features[:4]

        inter_x1 = max(subj_x1, obj_x1)
        inter_y1 = max(subj_y1, obj_y1)
        inter_x2 = min(subj_x2, obj_x2)
        inter_y2 = min(subj_y2, obj_y2)

        if inter_x2 > inter_x1 and inter_y2 > inter_y1:
            intersection = (inter_x2 - inter_x1) * (inter_y2 - inter_y1)
            union = subj_features[6] + obj_features[6] - intersection
            iou = intersection / (union + 1e-8)
        else:
            iou = 0.0
            intersection = 0.0

        # Position of the object centre relative to the subject centre
        relative_x = obj_center_x - subj_center_x
        relative_y = obj_center_y - subj_center_y

        # Relational features (8)
        relational_features = [
            distance, angle, area_ratio, iou,
            intersection, relative_x, relative_y,
            1.0 if distance < 0.1 else 0.0  # very close
        ]

        # Total: 28 features (20 individual + 8 relational)
        return individual_features + relational_features

    def _print_statistics(self):
        """Print the number and share of samples for every relation."""
        relation_counts = Counter([s['relation'] for s in self.data_samples])
        print(f"\nDistribution MultiModal dans {self.split}:")
        total = len(self.data_samples)

        for relation in SPATIAL_RELATIONS:
            count = relation_counts.get(relation, 0)
            percentage = count / total * 100 if total > 0 else 0
            print(f"   {relation}: {count} ({percentage:.1f}%)")

    def __len__(self):
        return len(self.data_samples)

    def __getitem__(self, idx):
        """Return (image, spatial_features, text, label, metadata) for sample `idx`.

        When loading or feature computation fails, the error is printed and a
        placeholder sample with label 0 is returned instead of raising.
        """
        sample = self.data_samples[idx]

        try:
            # Load the original image
            image = Image.open(sample['image_path']).convert('RGB')

            # Apply the transformations
            if self.transform:
                image = self.transform(image)

            # Spatial features of the two bounding boxes
            spatial_features = self._compute_spatial_features(
                sample['subject_bbox'],
                sample['object_bbox'],
                sample['image_width'],
                sample['image_height']
            )
            spatial_features = torch.tensor(spatial_features, dtype=torch.float32)

            # Text fed to BERT: 'subject object'
            text = f"{sample['subject']} {sample['object']}"

            label = self.relation_to_idx[sample['relation']]

            metadata = {
                'subject': sample['subject'],
                'object': sample['object'],
                'relation': sample['relation'],
                'original_relation': sample['original_relation'],
                'subject_bbox': sample['subject_bbox'],
                'object_bbox': sample['object_bbox'],
                'spatial_features_count': len(spatial_features),
                'text': text
            }

            return image, spatial_features, text, label, metadata

        except Exception as e:
            # Deliberate best effort: a broken sample must not abort an epoch.
            print(f"Erreur chargement {sample['image_path']}: {e}")

            # The grey placeholder image is only needed when a transform will
            # consume it; otherwise a zero tensor is returned directly.
            if self.transform:
                dummy_image = self.transform(Image.new('RGB', (224, 224), color='gray'))
            else:
                dummy_image = torch.zeros(3, 224, 224)

            dummy_features = torch.zeros(28, dtype=torch.float32)
            dummy_text = "error error"
            dummy_label = 0

            return dummy_image, dummy_features, dummy_text, dummy_label, {
                'subject': 'error', 'object': 'error',
                # Keep the metadata consistent with the returned label.
                'relation': self.idx_to_relation[dummy_label],
                'original_relation': 'error',
                'subject_bbox': [0, 0, 0, 0], 'object_bbox': [0, 0, 0, 0],
                'spatial_features_count': 28, 'text': dummy_text
            }

# =============================================================================
# ARCHITECTURE MULTIMODALE: VGG + BBOX + BERT
# =============================================================================

class MultiModalFeatureExtractor(nn.Module):
    """Fuse three feature streams by concatenation.

    - VGG16 image features (4096, frozen)
    - bounding-box geometry projected from `spatial_features_dim` to 512
    - BERT text features projected to 512

    Args:
        spatial_features_dim: Size of the incoming spatial feature vector.
    """

    def __init__(self, spatial_features_dim=28):
        super(MultiModalFeatureExtractor, self).__init__()

        self.spatial_features_dim = spatial_features_dim

        # 1. Pre-trained VGG16: conv stack, pooling and the first six
        #    children of its classifier head
        backbone = models.vgg16(pretrained=True)
        self.vgg_features = backbone.features
        self.vgg_avgpool = backbone.avgpool
        self.vgg_fc7 = nn.Sequential(*list(backbone.classifier.children())[:6])

        # Freeze every VGG weight
        for frozen_part in (self.vgg_features, self.vgg_avgpool, self.vgg_fc7):
            for weight in frozen_part.parameters():
                weight.requires_grad = False

        # 2. Trainable projection of the bounding-box features
        projection_layers = [
            nn.Linear(spatial_features_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
        ]
        self.spatial_projection = nn.Sequential(*projection_layers)

        # 3. BERT encoder for the text (BERT itself frozen)
        self.text_encoder = BERTTextEncoder(freeze_bert=True)

        # Width of the concatenated vector: VGG + spatial + text = 5120
        self.fusion_dim = 4096 + 512 + 512

        print(f"MultiModal Feature Extractor:")
        print(f"  - VGG FC-7: 4096 features (gelées)")
        print(f"  - Spatial features: {spatial_features_dim} → 512 (entraînables)")
        print(f"  - BERT text features: 768 → 512 (projection entraînable)")
        print(f"  - Fusion dim: {self.fusion_dim}")

    def forward(self, image, spatial_features, texts):
        """Return the concatenation [VGG | projected spatial | text] along dim 1."""
        streams = [
            self._extract_vgg_features(image),
            self.spatial_projection(spatial_features),
            self.text_encoder(texts),
        ]
        return torch.cat(streams, dim=1)

    def _extract_vgg_features(self, x):
        """Run the image through the VGG conv stack, pooling and truncated classifier."""
        pooled = self.vgg_avgpool(self.vgg_features(x))
        return self.vgg_fc7(torch.flatten(pooled, 1))

class MultiModalMLP(nn.Module):
    """Classification head applied to the fused multimodal features.

    Two hidden blocks (Linear -> BatchNorm1d -> ReLU -> Dropout) followed by a
    linear output layer with one logit per relation.

    Args:
        input_dim: Width of the fused feature vector.
        hidden1_dim: Width of the first hidden layer.
        hidden2_dim: Width of the second hidden layer.
        num_relations: Number of output classes.
        dropout_rate: Dropout probability used in both hidden blocks.
    """

    def __init__(self, input_dim, hidden1_dim=1024, hidden2_dim=512,
                 num_relations=len(SPATIAL_RELATIONS), dropout_rate=0.4):
        super(MultiModalMLP, self).__init__()

        summary = [
            ("Input", input_dim),
            ("Hidden 1", hidden1_dim),
            ("Hidden 2", hidden2_dim),
            ("Output", num_relations),
            ("Dropout", dropout_rate),
        ]
        print(f"MultiModal Classification MLP:")
        for caption, value in summary:
            print(f"  - {caption}: {value}")

        # First hidden block
        self.fc1 = nn.Linear(input_dim, hidden1_dim)
        self.bn1 = nn.BatchNorm1d(hidden1_dim)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout_rate)

        # Second hidden block
        self.fc2 = nn.Linear(hidden1_dim, hidden2_dim)
        self.bn2 = nn.BatchNorm1d(hidden2_dim)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout_rate)

        # Output layer
        self.fc3 = nn.Linear(hidden2_dim, num_relations)

        # Xavier initialisation of the three linear weight matrices
        for linear in (self.fc1, self.fc2, self.fc3):
            nn.init.xavier_uniform_(linear.weight)

    def forward(self, x):
        """Return the class logits for a batch of fused feature vectors."""
        hidden = self.dropout1(self.relu1(self.bn1(self.fc1(x))))
        hidden = self.dropout2(self.relu2(self.bn2(self.fc2(hidden))))
        return self.fc3(hidden)

class MultiModalSpatialRelationModel(nn.Module):
    """End-to-end model: multimodal feature extractor followed by the MLP classifier.

    Args:
        num_relations: Number of relation classes to predict.
        spatial_features_dim: Size of the spatial feature vector per sample.
    """

    def __init__(self, num_relations=len(SPATIAL_RELATIONS), spatial_features_dim=28):
        super(MultiModalSpatialRelationModel, self).__init__()

        # Feature extraction and fusion
        extractor = MultiModalFeatureExtractor(spatial_features_dim=spatial_features_dim)
        self.feature_extractor = extractor

        # Classification head sized to the fused vector (5120)
        self.classifier = MultiModalMLP(
            input_dim=extractor.fusion_dim,
            num_relations=num_relations
        )

    def forward(self, image, spatial_features, texts):
        """Return the relation logits for a batch of (image, spatial features, texts)."""
        return self.classifier(self.feature_extractor(image, spatial_features, texts))

# =============================================================================
# EARLY STOPPING
# =============================================================================

class EarlyStopping:
    """Track a validation score and flag when it has stopped improving.

    The score is treated as "higher is better". A call counts as an
    improvement unless the new score is below `best_score + min_delta`.

    Args:
        patience: Number of consecutive non-improving calls before
            `early_stop` is set.
        min_delta: Margin above the best score required to count as progress.
        restore_best_weights: When True, a CPU copy of the model's state_dict
            is kept for the best score.
        verbose: When True, progress messages are printed.
    """

    def __init__(self, patience=5, min_delta=0.001, restore_best_weights=True, verbose=True):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.verbose = verbose

        self.best_score = None
        self.counter = 0
        self.best_weights = None
        self.early_stop = False

    def __call__(self, val_score, model):
        """Record `val_score`; sets `self.early_stop` once patience is exhausted."""
        # First call: whatever we get is the best so far
        if self.best_score is None:
            self.best_score = val_score
            self.save_checkpoint(model)
            if self.verbose:
                print(f"   Early Stopping: Score initial = {val_score:.3f}%")
            return

        # No sufficient progress: consume one unit of patience
        if val_score < self.best_score + self.min_delta:
            self.counter += 1
            if self.verbose:
                print(f"     Early Stopping: {self.counter}/{self.patience} (Best: {self.best_score:.3f}%)")

            if self.counter >= self.patience:
                self.early_stop = True
                if self.verbose:
                    print(f"Early Stopping déclenché! Restauration du meilleur modèle (Accuracy: {self.best_score:.3f}%)")
            return

        # Progress: remember the new best and reset the patience counter
        gain = val_score - self.best_score
        if self.verbose:
            print(f"    Amélioration: {gain:.3f}% (Nouveau best: {val_score:.3f}%)")

        self.best_score = val_score
        self.save_checkpoint(model)
        self.counter = 0

    def save_checkpoint(self, model):
        """Keep a CPU clone of every state_dict tensor (only if restoring is enabled)."""
        if not self.restore_best_weights:
            return

        snapshot = {}
        for name, tensor in model.state_dict().items():
            snapshot[name] = tensor.cpu().clone()
        self.best_weights = snapshot

    def restore_best_weights_to_model(self, model):
        """Load the stored best weights into `model`, on the model's own device."""
        if self.best_weights is None:
            return

        # The device of the first parameter is used for every tensor
        target_device = next(model.parameters()).device
        restored = {}
        for name, tensor in self.best_weights.items():
            restored[name] = tensor.to(target_device)
        model.load_state_dict(restored)

# =============================================================================
# FONCTIONS D'ÉVALUATION
# =============================================================================

def evaluate_multimodal_with_confusion_matrix(model, dataloader, criterion, device, class_names=None):
    """Evaluate the multimodal model on a dataloader.

    Each batch is unpacked as ``(images, spatial_features, texts, labels, _)``;
    the last element is ignored.

    Args:
        model: Callable as ``model(images, spatial_features, texts)`` and
            returning one score per class for each sample.
        dataloader: Iterable of batches (must support ``len``).
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the image, spatial-feature and label tensors are moved to.
        class_names: Names of the classes, indexed by label id. Defaults to
            ``SPATIAL_RELATIONS``.

    Returns:
        Tuple ``(avg_loss, accuracy, cm, report, all_predictions, all_labels)``
        where ``avg_loss`` is the mean of the per-batch losses, ``accuracy`` is
        a percentage, ``cm`` is the confusion matrix, ``report`` is the dict
        form of the classification report, and the last two are flat lists of
        predicted / true label ids.
    """
    if class_names is None:
        class_names = SPATIAL_RELATIONS

    # Always report on the full label set. Without it sklearn infers the labels
    # from the data, so a class absent from this split would shrink the matrix
    # and make classification_report reject `target_names` (length mismatch).
    label_ids = list(range(len(class_names)))

    model.eval()
    running_loss = 0.0
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for images, spatial_features, texts, labels, _ in tqdm(dataloader, desc='Evaluating MultiModal'):
            images = images.to(device)
            spatial_features = spatial_features.to(device)
            labels = labels.to(device)

            outputs = model(images, spatial_features, texts)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            # Gradients are already disabled here, so `.data` is not needed.
            _, predicted = torch.max(outputs, 1)

            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Metrics
    accuracy = 100. * np.mean(np.array(all_predictions) == np.array(all_labels))
    avg_loss = running_loss / len(dataloader)

    # Confusion matrix with one row/column per known class
    cm = confusion_matrix(all_labels, all_predictions, labels=label_ids)

    # Per-class report, aligned with class_names
    report = classification_report(
        all_labels, all_predictions,
        labels=label_ids,
        target_names=class_names,
        output_dict=True,
        zero_division=0
    )

    return avg_loss, accuracy, cm, report, all_predictions, all_labels

def plot_confusion_matrix(cm, class_names, title="Matrice de Confusion", normalize=False):
    """Draw a confusion matrix as an annotated heatmap.

    Args:
        cm: Square matrix of counts (rows = true labels, columns = predictions).
        class_names: Tick labels, in label-id order.
        title: Figure title; " (Normalisée)" is appended when normalizing.
        normalize: When true, each row is divided by its total so cells show
            the proportion of that true class. Rows with no samples are shown
            as zeros.
    """
    if normalize:
        counts = np.asarray(cm, dtype=float)
        row_totals = counts.sum(axis=1, keepdims=True)
        # A class with no true samples has a zero row total; dividing blindly
        # would fill the row with NaN and emit a RuntimeWarning.
        cm = np.divide(counts, row_totals, out=np.zeros_like(counts), where=row_totals != 0)
        fmt = '.2f'
        title += " (Normalisée)"
    else:
        fmt = 'd'

    plt.figure(figsize=(12, 10))
    sns.heatmap(cm,
                annot=True,
                fmt=fmt,
                cmap='Blues',
                xticklabels=class_names,
                yticklabels=class_names,
                cbar_kws={'label': 'Proportion' if normalize else 'Nombre de prédictions'})

    plt.title(title, fontsize=16, fontweight='bold', pad=20)
    plt.xlabel('Prédictions', fontsize=14, fontweight='bold')
    plt.ylabel('Vraies étiquettes', fontsize=14, fontweight='bold')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()

def display_classification_metrics(report, title="Métriques de Classification"):
    """Print a classification report dict as a fixed-width table.

    Args:
        report: Dict shaped like the ``output_dict=True`` result of
            ``classification_report``: one entry per class holding
            ``precision``/``recall``/``f1-score``/``support``, plus the
            ``macro avg`` and ``weighted avg`` aggregates and, when present,
            a scalar ``accuracy``.
        title: Heading printed above the table.
    """
    # Aggregate entries are printed separately below; every other key is a class.
    aggregate_keys = {'accuracy', 'micro avg', 'macro avg', 'weighted avg', 'samples avg'}

    print(f"\n{'='*60}")
    print(f"{title}")
    print(f"{'='*60}")

    # Per-class metrics
    print(f"{'Classe':<15} {'Précision':<10} {'Rappel':<10} {'F1-Score':<10} {'Support':<10}")
    print("-" * 60)

    # Rows are taken from the report itself (in its own order) so the table
    # matches whatever class names the report was built with.
    for class_name, metrics in report.items():
        if class_name in aggregate_keys:
            continue
        print(f"{class_name:<15} {metrics['precision']:<10.3f} {metrics['recall']:<10.3f} "
              f"{metrics['f1-score']:<10.3f} {metrics['support']:<10.0f}")

    # Global metrics
    macro = report['macro avg']
    weighted = report['weighted avg']
    print("-" * 60)
    # 'accuracy' is absent when the report carries a 'micro avg' entry instead.
    if 'accuracy' in report:
        print(f"{'Accuracy':<15} {'':<10} {'':<10} {report['accuracy']:<10.3f} {macro['support']:<10.0f}")
    print(f"{'Macro avg':<15} {macro['precision']:<10.3f} {macro['recall']:<10.3f} "
          f"{macro['f1-score']:<10.3f} {macro['support']:<10.0f}")
    print(f"{'Weighted avg':<15} {weighted['precision']:<10.3f} {weighted['recall']:<10.3f} "
          f"{weighted['f1-score']:<10.3f} {weighted['support']:<10.0f}")

# =============================================================================
# VISUALISATION
# =============================================================================

def visualize_multimodal_samples(dataset, num_samples=3):
    """Show the first samples of a multimodal dataset next to a text summary.

    Each sample is unpacked as ``(image, spatial_features, text, label,
    metadata)``. The image is expected to be a tensor normalized with the
    mean/std constants below; ``metadata`` must provide the ``subject``,
    ``object``, ``relation``, ``subject_bbox`` and ``object_bbox`` keys.

    Args:
        dataset: Indexable dataset supporting ``len``.
        num_samples: Maximum number of samples to display. Fewer rows are
            drawn when the dataset is smaller; nothing is drawn when there is
            nothing to show.
    """
    # Size the grid on what will actually be drawn: no empty framed axes when
    # the dataset is short, and no plt.subplots(0, ...) error when it is empty.
    n_rows = min(num_samples, len(dataset))
    if n_rows <= 0:
        return

    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(n_rows, 2, figsize=(14, 4*n_rows), squeeze=False)

    # Normalization constants, built once instead of on every iteration.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    def denormalize(tensor):
        """Undo the mean/std normalization and move channels last for imshow."""
        return (tensor * std + mean).clamp(0, 1).permute(1, 2, 0)

    for i in range(n_rows):
        image, spatial_features, text, label, metadata = dataset[i]

        image_display = denormalize(image)

        # Image + text
        axes[i, 0].imshow(image_display)
        axes[i, 0].set_title(f"Image + Texte BERT\n'{text}'\nRelation: {metadata['relation']}")
        axes[i, 0].axis('off')

        # Detailed information panel
        axes[i, 1].axis('off')
        info_text = f"""
        Échantillon {i+1} - MULTIMODAL:

        Sujet: {metadata['subject']}
        Objet: {metadata['object']}
        Relation: {metadata['relation']}
        Texte BERT: "{text}"

        Subject BBox: {metadata['subject_bbox']}
        Object BBox: {metadata['object_bbox']}

        Features Multimodales (5120 total):
        • VGG features: 4096 (gelées)
        • Spatial features: 28 → 512 (entraînables)
        • BERT features: 768 → 512 (projection entraînable)

        Architecture:
        1. Image → VGG16 FC-7
        2. BBox coords → MLP projection
        3. "subject object" → BERT encoding
        4. Fusion → Concaténation (5120)
        5. Classification → MLP (5120→1024→512→9)
        """
        axes[i, 1].text(0.1, 0.5, info_text, transform=axes[i, 1].transAxes,
                       fontsize=9, verticalalignment='center',
                       bbox=dict(boxstyle='round', facecolor='lightcyan', alpha=0.8))

    plt.suptitle('Architecture Multimodale: Image + BBox + BERT Text',
                 fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()

# =============================================================================
# TRANSFORMATIONS
# =============================================================================

def create_transforms(horizontal_flip=False):
    """Build the image transforms for training and validation.

    Both pipelines resize to ``IMG_SIZE`` x ``IMG_SIZE``, convert to a tensor
    and normalize with the mean/std constants below; the training pipeline
    additionally applies a mild color jitter.

    Args:
        horizontal_flip: When true, the training pipeline also mirrors images
            at random (p=0.5). Off by default: the transform only sees the
            image, so a mirrored picture keeps its original label, which turns
            "to the left of" / "to the right of" samples into wrong examples.
            Enable it only if labels and box features are flipped accordingly
            elsewhere.

    Returns:
        Tuple ``(train_transform, val_transform)``.
    """
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )

    train_steps = [transforms.Resize((IMG_SIZE, IMG_SIZE))]
    if horizontal_flip:
        train_steps.append(transforms.RandomHorizontalFlip(p=0.5))
    train_steps.extend([
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
        transforms.ToTensor(),
        normalize
    ])
    train_transform = transforms.Compose(train_steps)

    val_transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        normalize
    ])

    return train_transform, val_transform

# =============================================================================
# FONCTIONS D'ENTRAÎNEMENT
# =============================================================================

def train_multimodal_epoch(model, dataloader, criterion, optimizer, device):
    """Run one training pass over ``dataloader``.

    Each batch is unpacked as ``(images, spatial_features, texts, labels, _)``.
    Gradients are clipped to a global norm of 1.0 before every optimizer step,
    and the running loss / accuracy are shown on the progress bar.

    Returns:
        Tuple ``(mean batch loss, accuracy in percent)`` for the epoch.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    progress_bar = tqdm(dataloader, desc='Training MultiModal')
    for step, batch in enumerate(progress_bar, start=1):
        images, spatial_features, texts, labels, _ = batch
        images = images.to(device)
        spatial_features = spatial_features.to(device)
        labels = labels.to(device)

        # Forward / backward
        optimizer.zero_grad()
        logits = model(images, spatial_features, texts)
        loss = criterion(logits, labels)
        loss.backward()

        # Clip before stepping to keep updates bounded
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Running statistics
        loss_sum += loss.item()
        predicted = torch.max(logits.data, 1)[1]
        n_seen += labels.size(0)
        n_correct += (predicted == labels).sum().item()

        progress_bar.set_postfix({
            'loss': loss_sum / step,
            'acc': 100. * n_correct / n_seen
        })

    epoch_loss = loss_sum / len(dataloader)
    epoch_acc = 100. * n_correct / n_seen
    return epoch_loss, epoch_acc

# =============================================================================
# ENTRAÎNEMENT K-FOLD MULTIMODAL
# =============================================================================

def train_multimodal_kfold_with_confusion_matrix(data_dir, k_folds=5, epochs=15,
                                                early_stopping_patience=5, min_delta=0.001):
    """Train the multimodal model with K-fold cross-validation and early stopping.

    For every fold a fresh ``MultiModalSpatialRelationModel`` is trained on the
    training indices and validated on the held-out indices after each epoch.
    Only the spatial projection, the text projection and the classifier are
    handed to the optimizer. The best weights seen (by validation accuracy)
    are restored before the fold's final evaluation, and the best state dict
    is also written to ``best_multimodal_model_fold_<n>.pth``.

    Args:
        data_dir: Directory passed to ``MultiModalDataset``.
        k_folds: Number of cross-validation folds.
        epochs: Maximum number of epochs per fold.
        early_stopping_patience: Non-improving epochs tolerated per fold.
        min_delta: Minimum validation-accuracy gain counted as an improvement.

    Returns:
        A list with one dict per fold (keys: ``fold``, ``train_losses``,
        ``val_losses``, ``train_accs``, ``val_accs``, ``best_val_acc``,
        ``final_confusion_matrix``, ``final_report``, ``final_predictions``,
        ``final_true_labels``, ``model``, ``actual_epochs``,
        ``early_stopped``), or an empty list when the dataset is empty.

    Raises:
        RuntimeError: If the training and validation views of the dataset do
            not contain the same number of samples.
    """

    # Transforms: augmented pipeline for training, deterministic one for validation
    train_transform, val_transform = create_transforms()

    # Training view of the dataset (augmented)
    full_dataset = MultiModalDataset(
        data_dir=data_dir,
        split='train',
        transform=train_transform
    )

    if len(full_dataset) == 0:
        print("Dataset vide!")
        return []

    # Validation view of the SAME samples, without random augmentation.
    # Validating through the training transform makes the metrics noisy and
    # lets `best_val_acc` disagree with the final evaluation of the very same
    # weights.
    # NOTE(review): assumes MultiModalDataset orders its samples identically
    # on every construction, so fold indices address the same samples in both
    # views - confirm against the dataset implementation.
    eval_dataset = MultiModalDataset(
        data_dir=data_dir,
        split='train',
        transform=val_transform
    )
    if len(eval_dataset) != len(full_dataset):
        raise RuntimeError(
            "Les vues d'entraînement et de validation du dataset n'ont pas la même taille "
            f"({len(full_dataset)} vs {len(eval_dataset)})"
        )

    # Fixed label set used for the aggregated matrix/report
    label_ids = list(range(len(SPATIAL_RELATIONS)))

    # Sample visualisation
    print("\nVisualisation échantillons multimodaux:")
    visualize_multimodal_samples(full_dataset, num_samples=3)

    # K-fold cross-validation
    kfold = KFold(n_splits=k_folds, shuffle=True, random_state=42)
    fold_results = []

    # Split on plain indices: KFold only needs the number of samples.
    sample_indices = np.arange(len(full_dataset))

    for fold, (train_indices, val_indices) in enumerate(kfold.split(sample_indices)):
        print(f"\n{'='*50}")
        print(f"FOLD {fold + 1}/{k_folds} - MULTIMODAL + EARLY STOPPING")
        print(f"{'='*50}")

        # Subsets: augmented view for training, clean view for validation
        train_subset = torch.utils.data.Subset(full_dataset, train_indices)
        val_subset = torch.utils.data.Subset(eval_dataset, val_indices)

        # DataLoaders
        train_loader = DataLoader(
            train_subset, batch_size=BATCH_SIZE, shuffle=True,
            num_workers=2, pin_memory=True
        )

        val_loader = DataLoader(
            val_subset, batch_size=BATCH_SIZE, shuffle=False,
            num_workers=2, pin_memory=True
        )

        # Multimodal model (fresh for every fold)
        model = MultiModalSpatialRelationModel(
            num_relations=len(SPATIAL_RELATIONS),
            spatial_features_dim=28
        )
        model = model.to(DEVICE)

        # Optimizer: only the layers listed here are updated
        trainable_params = []
        # Spatial projection
        trainable_params.extend(model.feature_extractor.spatial_projection.parameters())
        # Text projection on top of the text encoder
        trainable_params.extend(model.feature_extractor.text_encoder.text_projection.parameters())
        # Final classifier
        trainable_params.extend(model.classifier.parameters())

        optimizer = optim.AdamW(trainable_params, lr=LEARNING_RATE, weight_decay=0.01)
        criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

        # Scheduler
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

        # Early stopping
        early_stopping = EarlyStopping(
            patience=early_stopping_patience,
            min_delta=min_delta,
            restore_best_weights=True,
            verbose=True
        )

        # Per-fold training history
        train_losses, val_losses = [], []
        train_accs, val_accs = [], []
        actual_epochs = 0

        for epoch in range(epochs):
            print(f"\n   Epoch {epoch+1}/{epochs}")

            # Training
            train_loss, train_acc = train_multimodal_epoch(
                model, train_loader, criterion, optimizer, DEVICE
            )
            train_losses.append(train_loss)
            train_accs.append(train_acc)

            # Validation
            val_loss, val_acc, cm, report, predictions, true_labels = evaluate_multimodal_with_confusion_matrix(
                model, val_loader, criterion, DEVICE, SPATIAL_RELATIONS
            )
            val_losses.append(val_loss)
            val_accs.append(val_acc)

            scheduler.step()
            actual_epochs = epoch + 1

            print(f"       Train: {train_loss:.4f} / {train_acc:.2f}%")
            print(f"   Val: {val_loss:.4f} / {val_acc:.2f}%")

            # Persist the best model so far (strictly better than all earlier epochs)
            if len(val_accs) == 1 or val_acc > max(val_accs[:-1]):
                torch.save(model.state_dict(), f'best_multimodal_model_fold_{fold+1}.pth')

            # Early-stopping check
            early_stopping(val_acc, model)

            if early_stopping.early_stop:
                print(f"🛑 Arrêt anticipé à l'époque {epoch+1}")
                early_stopping.restore_best_weights_to_model(model)
                break

        # Restore the best weights when training ran to completion
        if not early_stopping.early_stop:
            early_stopping.restore_best_weights_to_model(model)
            print(f"   Entraînement terminé ({epochs} époques). Meilleurs poids restaurés.")

        # Final evaluation with the restored weights
        print(f"\n{'='*40}")
        print(f"ÉVALUATION FINALE FOLD {fold + 1}")
        print(f"{'='*40}")

        final_val_loss, final_val_acc, final_cm, final_report, final_predictions, final_true_labels = evaluate_multimodal_with_confusion_matrix(
            model, val_loader, criterion, DEVICE, SPATIAL_RELATIONS
        )

        # Metrics display
        display_classification_metrics(final_report, f"Métriques Multimodales Fold {fold + 1}")

        # Confusion matrices
        plot_confusion_matrix(final_cm, SPATIAL_RELATIONS,
                            f"Matrice de Confusion Multimodale - Fold {fold + 1}", normalize=False)
        plot_confusion_matrix(final_cm, SPATIAL_RELATIONS,
                            f"Matrice de Confusion Multimodale Normalisée - Fold {fold + 1}", normalize=True)

        # Store the fold results
        # NOTE(review): every fold's model stays referenced here (on DEVICE),
        # so memory grows with k_folds - drop or offload it if that matters.
        fold_results.append({
            'fold': fold + 1,
            'train_losses': train_losses,
            'val_losses': val_losses,
            'train_accs': train_accs,
            'val_accs': val_accs,
            'best_val_acc': early_stopping.best_score,
            'final_confusion_matrix': final_cm,
            'final_report': final_report,
            'final_predictions': final_predictions,
            'final_true_labels': final_true_labels,
            'model': model,
            'actual_epochs': actual_epochs,
            'early_stopped': early_stopping.early_stop
        })

        print(f"   Fold {fold + 1} - Meilleur: {early_stopping.best_score:.2f}% (après {actual_epochs} époques)")
        if early_stopping.early_stop:
            print(f"   🛑 Arrêt anticipé activé")
        else:
            print(f"      Entraînement complet")

    # Global analysis of the results
    print(f"\n{'='*70}")
    print("ANALYSE GLOBALE DES RÉSULTATS MULTIMODAUX")
    print(f"{'='*70}")

    # General statistics
    mean_acc = np.mean([r['best_val_acc'] for r in fold_results])
    std_acc = np.std([r['best_val_acc'] for r in fold_results])
    print(f"Accuracy moyenne: {mean_acc:.2f}% ± {std_acc:.2f}%")

    # Early-stopping statistics
    early_stopped_folds = [r for r in fold_results if r['early_stopped']]
    mean_epochs = np.mean([r['actual_epochs'] for r in fold_results])

    print(f"\n   Statistiques Early Stopping:")
    print(f"   Folds avec arrêt anticipé: {len(early_stopped_folds)}/{k_folds}")
    print(f"   Époques moyennes: {mean_epochs:.1f}/{epochs}")

    for r in fold_results:
        status = "🛑 Arrêté" if r['early_stopped'] else "   Complet"
        print(f"   Fold {r['fold']}: {r['actual_epochs']:2d} époques - {status}")

    # Global confusion matrix, rebuilt from the pooled predictions on the full
    # label set: summing per-fold matrices fails when folds saw different
    # subsets of classes (matrices of different shapes).
    global_predictions = np.concatenate([r['final_predictions'] for r in fold_results])
    global_true_labels = np.concatenate([r['final_true_labels'] for r in fold_results])
    global_cm = confusion_matrix(global_true_labels, global_predictions, labels=label_ids)

    # Global report
    global_report = classification_report(
        global_true_labels, global_predictions,
        labels=label_ids,
        target_names=SPATIAL_RELATIONS,
        output_dict=True,
        zero_division=0
    )

    # Final display
    display_classification_metrics(global_report, "Métriques Globales Multimodales")
    plot_confusion_matrix(global_cm, SPATIAL_RELATIONS,
                        "Matrice de Confusion Globale Multimodale", normalize=False)
    plot_confusion_matrix(global_cm, SPATIAL_RELATIONS,
                        "Matrice de Confusion Globale Multimodale Normalisée", normalize=True)

    return fold_results

# =============================================================================
# ANALYSE DES ERREURS
# =============================================================================

def analyze_prediction_errors(fold_results, top_k=5):
    """Print the most frequent (true relation -> predicted relation) confusions.

    Predictions and true labels of all folds are pooled, misclassified pairs
    are counted, and the ``top_k`` most common pairs are listed with their
    share of the total number of errors.
    """
    separator = '=' * 60
    print(f"\n{separator}")
    print("ANALYSE DES ERREURS DE PRÉDICTION MULTIMODALES")
    print(f"{separator}")

    # Pool the final predictions of every fold
    all_predictions = np.concatenate([fold['final_predictions'] for fold in fold_results])
    all_true_labels = np.concatenate([fold['final_true_labels'] for fold in fold_results])

    # Keep only the misclassified pairs, as relation names
    errors = [
        (SPATIAL_RELATIONS[true_idx], SPATIAL_RELATIONS[pred_idx])
        for true_idx, pred_idx in zip(all_true_labels, all_predictions)
        if true_idx != pred_idx
    ]
    n_errors = len(errors)

    error_counts = Counter(errors)

    print(f"Nombre total d'erreurs: {n_errors}")
    print(f"Accuracy globale: {100 * (1 - n_errors / len(all_predictions)):.2f}%")
    print(f"\nTop {top_k} erreurs les plus fréquentes:")
    print("-" * 60)

    for rank, (pair, count) in enumerate(error_counts.most_common(top_k), start=1):
        true_rel, pred_rel = pair
        percentage = 100 * count / n_errors
        print(f"{rank:2d}. {true_rel:>15} → {pred_rel:<15} : {count:3d} ({percentage:5.1f}%)")

def plot_learning_curves_multimodal(fold_results):
    """Plot per-fold training/validation loss and accuracy curves on a 2x2 grid.

    Folds that stopped early are drawn dashed, get an ``(ES@<epoch>)`` suffix
    in the legend, and their last epoch is marked with an X on the validation
    accuracy panel.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # One entry per panel: (axis, key in the fold result, title, y-axis label)
    panels = [
        (axes[0, 0], 'train_losses', 'Courbes de Loss - Training', 'Loss'),
        (axes[0, 1], 'val_losses', 'Courbes de Loss - Validation', 'Loss'),
        (axes[1, 0], 'train_accs', 'Courbes d\'Accuracy - Training', 'Accuracy (%)'),
        (axes[1, 1], 'val_accs', 'Courbes d\'Accuracy - Validation', 'Accuracy (%)'),
    ]
    for ax, _, panel_title, _ in panels:
        ax.set_title(panel_title, fontweight='bold')

    palette = ['blue', 'red', 'green', 'orange', 'purple']

    for fold_idx, result in enumerate(fold_results):
        stopped = result['early_stopped']
        x_values = range(1, len(result['train_losses']) + 1)

        # Line style depends on whether the fold stopped early
        line_style = dict(
            color=palette[fold_idx % len(palette)],
            alpha=0.8 if stopped else 0.7,
            linestyle='--' if stopped else '-',
        )

        label = f'Fold {result["fold"]}'
        if stopped:
            label += f' (ES@{result["actual_epochs"]})'

        for ax, key, _, _ in panels:
            ax.plot(x_values, result[key], label=label, **line_style)

        # Mark the stopping point on the validation-accuracy panel
        if stopped:
            stop_epoch = result['actual_epochs']
            axes[1, 1].scatter(stop_epoch, result['val_accs'][stop_epoch-1],
                             color=line_style['color'], s=100, marker='X', zorder=5)

    # Shared axis configuration
    for ax, _, _, y_label in panels:
        ax.set_xlabel('Époque')
        ax.legend()
        ax.grid(True, alpha=0.3)
        ax.set_ylabel(y_label)

    plt.suptitle('Courbes d\'Apprentissage Multimodales\n(X = Arrêt anticipé)',
                 fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()
 

# =============================================================================
# FONCTION PRINCIPALE
# =============================================================================

def main_multimodal_experiment():
    """Run the full multimodal experiment and print a summary.

    Trains with ``K_FOLDS`` folds and ``EPOCHS`` epochs on
    ``data/spatialsense``, then reports the per-fold accuracy, the share of
    epochs saved by early stopping, the error analysis, the learning curves
    and the per-relation accuracy.

    Returns:
        The list of fold results, or ``None`` when the data directory is
        missing or no result was produced.
    """
    data_dir = "data/spatialsense"
    banner = '=' * 60

    if not os.path.exists(data_dir):
        print(f"\nERREUR: {data_dir} n'existe pas!")
        return

    # Multimodal training
    print(f"\n{banner}")
    print("LANCEMENT ENTRAÎNEMENT MULTIMODAL")
    print(f"{banner}")

    results = train_multimodal_kfold_with_confusion_matrix(
        data_dir=data_dir,
        k_folds=K_FOLDS,
        epochs=EPOCHS,
        early_stopping_patience=5,
        min_delta=0.001
    )

    if not results:
        print("    Aucun résultat obtenu!")
        return None

    best_accs = [fold['best_val_acc'] for fold in results]
    mean_acc = np.mean(best_accs)
    std_acc = np.std(best_accs)

    print(f"\n  RÉSULTATS FINAUX MULTIMODAUX:")
    print(f"   Accuracy moyenne: {mean_acc:.2f}% ± {std_acc:.2f}%")
    print(f"   Détails par fold:")

    for fold in results:
        status = "🛑" if fold['early_stopped'] else "  "
        print(f"     Fold {fold['fold']}: {fold['best_val_acc']:.2f}% ({fold['actual_epochs']}/{EPOCHS} époques) {status}")

    # Epochs not run thanks to early stopping, as a share of the full budget
    total_epochs_saved = sum(EPOCHS - fold['actual_epochs'] for fold in results)
    efficiency = 100 * total_epochs_saved / (K_FOLDS * EPOCHS)
    print(f"\n     Efficacité Early Stopping: {efficiency:.1f}% d'époques économisées")

    # Analyses
    analyze_prediction_errors(results)
    plot_learning_curves_multimodal(results)

    # Per-relation performance
    print(f"\n{banner}")
    print("PERFORMANCE PAR RELATION SPATIALE")
    print(f"{banner}")

    global_predictions = np.concatenate([fold['final_predictions'] for fold in results])
    global_true_labels = np.concatenate([fold['final_true_labels'] for fold in results])

    for class_idx, relation in enumerate(SPATIAL_RELATIONS):
        mask = global_true_labels == class_idx
        if np.sum(mask) <= 0:
            continue
        class_acc = 100 * np.mean(global_predictions[mask] == global_true_labels[mask])
        support = np.sum(mask)
        print(f"  {relation:<15}: {class_acc:6.2f}% ({support:3d} échantillons)")

    print(f"\n   Expérience Multimodale terminée!")
    print(f"    Performance: {mean_acc:.2f}% ± {std_acc:.2f}%")
    print(f"\n{banner}")
    print("RÉCAPITULATIF ARCHITECTURE MULTIMODALE")
    print(f"{banner}")
    recap_lines = (
        "   3 modalités fusionnées intelligemment",
        "   VGG: Représentation visuelle riche",
        "   BBox: Information spatiale précise",
        "   BERT: Compréhension sémantique",
        "   Early Stopping: Optimisation automatique",
        "   Évaluation complète avec matrices de confusion",
    )
    for recap_line in recap_lines:
        print(recap_line)

    return results

# =============================================================================
# TEST RAPIDE
# =============================================================================

def quick_test_multimodal():
    """Run a shortened experiment (3 folds, 8 epochs, patience 3) as a smoke test.

    After training, prints the mean accuracy and early-stopping statistics,
    then loads the first sample of the dataset (without image transform) to
    show what each modality looks like.

    Returns:
        The list of fold results, or ``None`` when the data directory is
        missing or training produced nothing.
    """
    data_dir = "data/spatialsense"
    n_folds = 3
    n_epochs = 8

    print("="*60)
    print("TEST RAPIDE MULTIMODAL")
    print("="*60)

    if not os.path.exists(data_dir):
        print(f"ERREUR: {data_dir} n'existe pas!")
        return

    results = train_multimodal_kfold_with_confusion_matrix(
        data_dir=data_dir,
        k_folds=n_folds,
        epochs=n_epochs,
        early_stopping_patience=3,
        min_delta=0.001
    )

    if not results:
        print("Échec du test rapide!")
        return None

    mean_acc = np.mean([fold['best_val_acc'] for fold in results])
    early_stopped_count = len([fold for fold in results if fold['early_stopped']])
    mean_epochs = np.mean([fold['actual_epochs'] for fold in results])

    print(f"\nTest rapide terminé!")
    print(f"   Accuracy moyenne: {mean_acc:.2f}%")
    print(f"   Folds avec arrêt anticipé: {early_stopped_count}/{n_folds}")
    print(f"   Époques moyennes: {mean_epochs:.1f}/{n_epochs}")

    # Modality check on the first sample
    sample_dataset = MultiModalDataset(data_dir=data_dir, split='train', transform=None)
    if len(sample_dataset) > 0:
        first_sample = sample_dataset[0]
        _, spatial_features, text, _, metadata = first_sample
        print(f"\nVérification modalités:")
        print(f"   Spatial features: {len(spatial_features)} dimensions")
        print(f"   Texte BERT: '{text}'")
        print(f"   Relation: {metadata['relation']}")
        print(f"   Subject: {metadata['subject']}")
        print(f"   Object: {metadata['object']}")

    return results

if __name__ == "__main__":
    import sys

    # "quick" as the first command-line argument selects the shortened run;
    # anything else (or no argument) launches the full experiment.
    quick_mode = sys.argv[1:2] == ["quick"]

    if not quick_mode:
        print("Mode expérience complète multimodale avec Early Stopping...")
        main_multimodal_experiment()
    else:
        print("Mode test rapide multimodal avec Early Stopping...")
        quick_test_multimodal()