In [None]:
!pip install -q tqdm pandas pillow

In [None]:
from google.colab import drive
import os
import subprocess
from pathlib import Path

In [None]:
# Mount Google Drive at /content/drive (requires the google.colab runtime imported above).
drive.mount('/content/drive')

In [None]:
def setup_dataset():
    """Clone the benchmark repository and fetch the SpatialSense data files.

    Each step is guarded by its own existence check and runs whether or not
    the repository was already cloned, so calling this again resumes a setup
    that was interrupted part-way instead of silently doing nothing.

    Side effects: on a fresh clone the process working directory is changed to
    the repository root and the repository's requirements are pip-installed.

    Returns:
        Path: the repository root, ``/content/spatial-relation-benchmark``.

    Raises:
        RuntimeError: if ``git clone`` exits with a non-zero status.
    """
    base_dir = Path("/content/spatial-relation-benchmark")
    annots_plus_url = "https://drive.google.com/uc?id=1vIOozqk3OlxkxZgL356pD1EAGt06ZwM4"

    if not base_dir.exists():
        clone = subprocess.run(
            ["git", "clone", "https://github.com/AlvinWen428/spatial-relation-benchmark.git"],
            cwd="/content", capture_output=True, text=True)
        if clone.returncode != 0:
            # Fail here with git's own message rather than later in os.chdir.
            raise RuntimeError(f"git clone failed: {clone.stderr.strip()}")

        os.chdir(base_dir)

        subprocess.run(["pip", "install", "-q", "-r", "requirements.txt"], capture_output=True)

    data_dir = base_dir / "data" / "spatialsense"
    data_dir.mkdir(parents=True, exist_ok=True)

    archive = data_dir / "spatialsense.zip"
    if not archive.exists():
        download = subprocess.run([
            "wget", "-q",
            "https://zenodo.org/api/records/8104370/files-archive",
            "-O", str(archive)
        ])
        # `wget -O` creates the output file even when the transfer fails; remove
        # the stub so the next call retries the download.
        if download.returncode != 0 and archive.exists():
            archive.unlink()

    if archive.exists() and not (data_dir / "annotations.json").exists():
        subprocess.run(["unzip", "-q", str(archive), "-d", str(data_dir)])

    images_dir = data_dir / "images"
    images_dir.mkdir(exist_ok=True)

    images_tar = data_dir / "images.tar.gz"
    if images_tar.exists() and not (images_dir / "images").exists():
        subprocess.run(["tar", "-zxf", str(images_tar), "-C", str(images_dir)], capture_output=True)

    spatialsense_plus = data_dir / "annots_spatialsenseplus.json"
    if not spatialsense_plus.exists():
        # Install gdown only when it is actually missing; a download error is
        # no longer mistaken for a missing package.
        try:
            import gdown
        except ImportError:
            subprocess.run(["pip", "install", "-q", "gdown"], capture_output=True)
            import gdown
        gdown.download(annots_plus_url, str(spatialsense_plus), quiet=True)

    return base_dir

In [None]:
# Run the dataset setup, then make the repository root the working directory.
base_dir = setup_dataset()
os.chdir(base_dir)

In [None]:
# Unpack the dataset archive when it is present but annotations.json has not
# been extracted yet. The path is relative to the current working directory.
data_dir = Path("data/spatialsense")
zip_file = data_dir / "spatialsense.zip"
if zip_file.exists() and not (data_dir / "annotations.json").exists():
    subprocess.run(["unzip", "-q", str(zip_file), "-d", str(data_dir)], capture_output=True)

In [None]:
import json
import os
from pathlib import Path
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
import numpy as np
from PIL import Image
from tqdm import tqdm
import pandas as pd
from collections import Counter, defaultdict

In [None]:
@dataclass
class BoundingBox:
    """Axis-aligned box stored as its two y bounds followed by its two x bounds."""
    y1: float
    y2: float
    x1: float
    x2: float

    @property
    def width(self) -> float:
        """Extent along x, i.e. ``x2 - x1``."""
        left, right = self.x1, self.x2
        return right - left

    @property
    def height(self) -> float:
        """Extent along y, i.e. ``y2 - y1``."""
        low, high = self.y1, self.y2
        return high - low

    @property
    def center(self) -> Tuple[float, float]:
        """Midpoint of the box as an ``(x, y)`` pair."""
        mid_x = (self.x1 + self.x2) / 2
        mid_y = (self.y1 + self.y2) / 2
        return (mid_x, mid_y)

@dataclass
class SpatialObject:
    """A named object taking part in a relation, together with its bounding box."""
    name: str
    bbox: BoundingBox
    # x / y are stored alongside the box; presumably a point on the object in
    # image coordinates — TODO confirm against the dataset documentation.
    x: float
    y: float

@dataclass
class SpatialAnnotation:
    """One (subject, predicate, object) relation and its truth label."""
    subject: SpatialObject
    object: SpatialObject
    predicate: str
    # True/False flag for the relation; the notebook keeps only annotations
    # whose label is truthy when building its sample lists.
    label: bool

@dataclass
class ImageData:
    """Per-image record: source URL, size, dataset split and its annotations."""
    url: str
    width: int
    height: int
    split: str
    annotations: List[SpatialAnnotation]

class SpatialSenseLoader:
    """Load SpatialSense annotations from disk and summarise them.

    Attributes:
        data_dir: dataset root directory.
        images_dir: ``<data_dir>/images/images``, searched for image files.
        annotations_file: ``<data_dir>/annotations.json``.
        spatialsense_plus_file: ``<data_dir>/annots_spatialsenseplus.json``
            (only the path is stored; this class does not read it).
        images_data: parsed images, filled by :meth:`load_annotations`.
        stats: last result of :meth:`compute_statistics` (``{}`` until then).
    """

    def __init__(self, data_dir: str = "data/spatialsense"):
        self.data_dir = Path(data_dir)
        self.images_dir = self.data_dir / "images" / "images"
        self.annotations_file = self.data_dir / "annotations.json"
        self.spatialsense_plus_file = self.data_dir / "annots_spatialsenseplus.json"

        self.images_data: List[ImageData] = []
        self.stats: Dict = {}

    def _parse_bbox(self, bbox_list: List[float]) -> BoundingBox:
        """Build a BoundingBox from a ``[y1, y2, x1, x2]`` list."""
        return BoundingBox(y1=bbox_list[0], y2=bbox_list[1],
                          x1=bbox_list[2], x2=bbox_list[3])

    def _parse_object(self, obj_data: Dict) -> SpatialObject:
        """Build a SpatialObject from a dict with 'name', 'bbox', 'x' and 'y' keys."""
        return SpatialObject(
            name=obj_data['name'],
            bbox=self._parse_bbox(obj_data['bbox']),
            x=obj_data['x'],
            y=obj_data['y']
        )

    def _parse_annotation(self, ann_data: Dict) -> SpatialAnnotation:
        """Build a SpatialAnnotation from a dict with 'subject', 'object', 'predicate', 'label'."""
        return SpatialAnnotation(
            subject=self._parse_object(ann_data['subject']),
            object=self._parse_object(ann_data['object']),
            predicate=ann_data['predicate'],
            label=ann_data['label']
        )

    def _find_image_path(self, image_url: str) -> Optional[Path]:
        """Return the local file matching the URL's basename, or None.

        The ``flickr`` sub-directory is checked before ``nyu``.
        """
        filename = Path(image_url).name

        for source in ("flickr", "nyu"):
            candidate = self.images_dir / source / filename
            if candidate.exists():
                return candidate

        return None

    def load_annotations(self) -> None:
        """Parse the annotations file into ``self.images_data``.

        Images without any annotation are skipped. When the annotations file
        does not exist a message is printed and ``images_data`` is left
        untouched.
        """
        if not self.annotations_file.exists():
            print(f"Annotations file not found: {self.annotations_file}")
            return

        with open(self.annotations_file, 'r') as f:
            raw_data = json.load(f)

        self.images_data = []

        for img_data in tqdm(raw_data, desc="Loading annotations"):
            annotations = [self._parse_annotation(ann_data)
                           for ann_data in img_data['annotations']]

            if annotations:
                self.images_data.append(ImageData(
                    url=img_data['url'],
                    width=img_data['width'],
                    height=img_data['height'],
                    split=img_data['split'],
                    annotations=annotations
                ))

    def compute_statistics(self) -> Dict:
        """Compute dataset-level statistics and cache them in ``self.stats``.

        Annotations are loaded first if none are in memory.

        Returns:
            A dict of counts and distributions, or ``{}`` when no image could
            be loaded (previously this raised on ``min()`` of an empty list).
        """
        if not self.images_data:
            self.load_annotations()

        if not self.images_data:
            # Nothing to summarise; min()/max() below need at least one image.
            self.stats = {}
            return self.stats

        total_images = len(self.images_data)
        split_counts = Counter(img.split for img in self.images_data)

        predicate_counts = Counter()
        object_counts = Counter()
        for img in self.images_data:
            for ann in img.annotations:
                predicate_counts[ann.predicate] += 1
                # Subject and object names share one tally.
                object_counts[ann.subject.name] += 1
                object_counts[ann.object.name] += 1

        widths = [img.width for img in self.images_data]
        heights = [img.height for img in self.images_data]
        relations_per_image = [len(img.annotations) for img in self.images_data]
        total_relations = sum(relations_per_image)

        self.stats = {
            'total_images': total_images,
            'total_relations': total_relations,
            'avg_relations_per_image': np.mean(relations_per_image),
            'split_distribution': dict(split_counts),
            'predicate_distribution': dict(predicate_counts),
            # Only the 20 most frequent names; see 'total_unique_objects' for
            # the full vocabulary size.
            'top_objects': dict(object_counts.most_common(20)),
            'total_unique_objects': len(object_counts),
            'image_dimensions': {
                'width': {'min': min(widths), 'max': max(widths), 'mean': np.mean(widths)},
                'height': {'min': min(heights), 'max': max(heights), 'mean': np.mean(heights)}
            },
            'relations_per_image': {
                'min': min(relations_per_image),
                'max': max(relations_per_image),
                'mean': np.mean(relations_per_image),
                'std': np.std(relations_per_image)
            }
        }

        return self.stats

    def get_images_by_split(self, split: str) -> List[ImageData]:
        """Return the loaded images whose ``split`` equals ``split``."""
        return [img for img in self.images_data if img.split == split]

    def get_images_by_predicate(self, predicate: str) -> List[ImageData]:
        """Return the loaded images having at least one annotation with this predicate."""
        return [img for img in self.images_data
                if any(ann.predicate == predicate for ann in img.annotations)]

    def verify_image_files(self) -> Dict[str, int]:
        """Count how many loaded images have a local file.

        Returns:
            ``{'found': int, 'missing': int}``.
        """
        found = 0
        missing = 0

        for img in tqdm(self.images_data, desc="Verifying images"):
            if self._find_image_path(img.url):
                found += 1
            else:
                missing += 1

        return {'found': found, 'missing': missing}

    def get_summary(self) -> pd.DataFrame:
        """Return one row per split with its image count and percentage.

        Statistics are computed on demand. When no data is available an empty
        frame with the same three columns is returned.
        """
        if not self.stats:
            self.compute_statistics()

        columns = ['Split', 'Images', 'Percentage']
        if not self.stats:
            return pd.DataFrame(columns=columns)

        total_images = self.stats['total_images']
        summary_data = [
            {
                'Split': split,
                'Images': count,
                'Percentage': f"{count/total_images*100:.1f}%"
            }
            for split, count in self.stats['split_distribution'].items()
        ]

        return pd.DataFrame(summary_data, columns=columns)

loader = SpatialSenseLoader()
loader.load_annotations()
# Only summarise when something was loaded; the report below is skipped otherwise.
stats = loader.compute_statistics() if loader.images_data else {}

subjects = []
objects = []
labels = []

for image_data in loader.images_data:
    for annotation in image_data.annotations:
        if annotation.label:  # Keep only the relations labelled True
            subjects.append(annotation.subject.name)
            objects.append(annotation.object.name)
            labels.append(annotation.predicate)

if stats:
    # stats['top_objects'] holds at most the 20 most frequent names, so the
    # vocabulary size is counted directly from the loaded annotations.
    unique_object_names = {
        name
        for image_data in loader.images_data
        for annotation in image_data.annotations
        for name in (annotation.subject.name, annotation.object.name)
    }

    print("=== SpatialSense Dataset Statistics ===")
    print(f"Total Images: {stats['total_images']:,}")
    print(f"Total Relations: {stats['total_relations']:,}")
    print(f"Average Relations per Image: {stats['avg_relations_per_image']:.2f}")
    print(f"Number of Unique Predicates: {len(stats['predicate_distribution'])}")
    print(f"Number of Unique Objects: {len(unique_object_names)}")

    print("\n=== Split Distribution ===")
    print(loader.get_summary().to_string(index=False))

    print("\n=== Top 10 Spatial Predicates ===")
    top_predicates = sorted(stats['predicate_distribution'].items(),
                           key=lambda x: x[1], reverse=True)[:10]
    for predicate, count in top_predicates:
        print(f"{predicate}: {count:,}")

    print("\n=== Top 10 Objects ===")
    top_objects = list(stats['top_objects'].items())[:10]
    for obj_name, count in top_objects:
        print(f"{obj_name}: {count:,}")


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from typing import Tuple, List, Optional
from dataclasses import dataclass
from tqdm import tqdm
import time
from collections import Counter
from skimage.draw import line
from scipy.spatial import ConvexHull

@dataclass
class ReferencePoint:
    """A 2-D point tagged with the name of the method that produced it."""
    x: float
    y: float
    method: str  # free-form tag, e.g. "coincident_centers"

class HyperFastExtendedRLM:
    """Radial-line (RLM) and force histograms for a subject/object pair of boxes.

    Every histogram has one bin per direction; directions are evenly spaced
    over [0, 2*pi).
    """

    def __init__(self, num_directions: int = 120):
        """Precompute the direction angles and their unit vectors.

        Args:
            num_directions: number of evenly spaced directions (histogram bins).
        """
        # Must be stored: both histogram methods size their output with it.
        self.num_directions = num_directions
        self.theta_step = 2 * np.pi / num_directions
        self.directions = np.linspace(0, 2*np.pi - self.theta_step, num_directions)

        self.cos_dirs = np.cos(self.directions)
        self.sin_dirs = np.sin(self.directions)
        self.direction_vectors = np.column_stack([self.cos_dirs, self.sin_dirs])

    def _compute_detailed_convex_hull(self, bbox: 'BoundingBox', num_points: int = 12) -> np.ndarray:
        """Return the convex-hull vertices of points sampled on the box perimeter.

        ``num_points // 4`` points are placed on each side. If the hull cannot
        be built (for example a zero-width or zero-height box) the four
        corners are returned instead.
        """
        x1, y1, x2, y2 = bbox.x1, bbox.y1, bbox.x2, bbox.y2

        points = []
        points_per_side = num_points // 4

        # Side y = y1, walking from x1 to x2 (both corners included).
        for i in range(points_per_side):
            x = x1 + (x2 - x1) * i / (points_per_side - 1) if points_per_side > 1 else x1
            points.append([x, y1])

        # Side x = x2, skipping the corner already emitted.
        for i in range(1, points_per_side):
            y = y1 + (y2 - y1) * i / (points_per_side - 1)
            points.append([x2, y])

        # Side y = y2, skipping the corner already emitted.
        for i in range(1, points_per_side):
            x = x2 - (x2 - x1) * i / (points_per_side - 1)
            points.append([x, y2])

        # Side x = x1, skipping the corner already emitted.
        for i in range(1, points_per_side):
            y = y2 - (y2 - y1) * i / (points_per_side - 1)
            points.append([x1, y])

        points = np.array(points)

        try:
            hull = ConvexHull(points)
            return points[hull.vertices]
        except Exception:
            # Degenerate or too-small point sets: fall back to the corners.
            # `Exception` (not a bare except) so KeyboardInterrupt still propagates.
            return np.array([[x1, y1], [x2, y1], [x2, y2], [x1, y2]])

    def _line_intersection_fast(self, p1, p2, p3, p4):
        """Intersect segment (p1, p2) with segment (p3, p4).

        Returns:
            The intersection as a length-2 array, or None when the segments
            are (near-)parallel or do not cross within both segments.
        """
        x1, y1 = p1
        x2, y2 = p2
        x3, y3 = p3
        x4, y4 = p4

        denom = (x1-x2)*(y3-y4) - (y1-y2)*(x3-x4)
        if abs(denom) < 1e-10:
            return None

        # t parametrises (p1, p2) and u parametrises (p3, p4).
        t = ((x1-x3)*(y3-y4) - (y1-y3)*(x3-x4)) / denom
        u = -((x1-x2)*(y1-y3) - (y1-y2)*(x1-x3)) / denom

        if 0 <= t <= 1 and 0 <= u <= 1:
            px = x1 + t*(x2-x1)
            py = y1 + t*(y2-y1)
            return np.array([px, py])
        return None

    def _farthest_hull_intersection(self, origin: np.ndarray, target: np.ndarray,
                                    hull: np.ndarray) -> np.ndarray:
        """Farthest point where segment origin->target crosses the hull boundary.

        Only crossings lying ahead of ``origin`` (towards ``target``) count.
        A copy of ``origin`` is returned when there is no such crossing.
        """
        heading = target - origin
        heading = heading / np.linalg.norm(heading)

        best_point = origin.copy()
        best_distance = 0
        n_vertices = len(hull)
        for i in range(n_vertices):
            crossing = self._line_intersection_fast(origin, target, hull[i], hull[(i + 1) % n_vertices])
            if crossing is None:
                continue
            offset = crossing - origin
            if np.dot(offset, heading) > 0:
                distance = np.linalg.norm(offset)
                if distance > best_distance:
                    best_distance = distance
                    best_point = crossing
        return best_point

    def _compute_reference_point_article_compliant(self, subject_bbox: 'BoundingBox',
                                                  object_bbox: 'BoundingBox') -> 'ReferencePoint':
        """Reference point between the two boxes.

        The segment joining the two box centres is intersected with each box's
        convex hull; the result is the midpoint of the two crossing points
        (the original author's note cites equation 4 of the reference article).
        When the centres coincide, that shared centre is returned.
        """
        CA = np.array([subject_bbox.center[0], subject_bbox.center[1]])
        CB = np.array([object_bbox.center[0], object_bbox.center[1]])

        if np.linalg.norm(CA - CB) < 1e-6:
            return ReferencePoint(CA[0], CA[1], "coincident_centers")

        hull_A = self._compute_detailed_convex_hull(subject_bbox, num_points=16)
        hull_B = self._compute_detailed_convex_hull(object_bbox, num_points=16)

        # Exit point of the centre line from each hull, each searched from its
        # own centre towards the other one.
        IA = self._farthest_hull_intersection(CA, CB, hull_A)
        IB = self._farthest_hull_intersection(CB, CA, hull_B)

        Rp = (IA + IB) / 2
        return ReferencePoint(Rp[0], Rp[1], "article_compliant_convex_hull")

    def _create_binary_mask_higher_res(self, bbox: 'BoundingBox', img_width: int, img_height: int) -> Tuple[np.ndarray, np.ndarray]:
        """Rasterise the box onto a 200x200 boolean grid covering the image.

        Returns:
            ``(mask, scale_info)`` where ``scale_info`` is
            ``[resolution / img_width, resolution / img_height, resolution, resolution]``.
        """
        resolution = 200
        mask = np.zeros((resolution, resolution), dtype=bool)

        # Box bounds in grid cells, clamped to the grid.
        x1_norm = max(0, min(resolution-1, int(bbox.x1 / img_width * resolution)))
        x2_norm = max(0, min(resolution-1, int(bbox.x2 / img_width * resolution)))
        y1_norm = max(0, min(resolution-1, int(bbox.y1 / img_height * resolution)))
        y2_norm = max(0, min(resolution-1, int(bbox.y2 / img_height * resolution)))

        # Guarantee a non-empty box on the grid.
        if x2_norm <= x1_norm:
            x2_norm = x1_norm + 1
        if y2_norm <= y1_norm:
            y2_norm = y1_norm + 1

        mask[y1_norm:y2_norm+1, x1_norm:x2_norm+1] = True

        scale_info = np.array([resolution / img_width, resolution / img_height, resolution, resolution])
        return mask, scale_info

    def _compute_rlm_histogram_article_accurate(self, bbox_center: np.ndarray, bbox_dims: np.ndarray,
                                               reference_point: 'ReferencePoint', img_width: int, img_height: int,
                                               bbox: 'BoundingBox') -> np.ndarray:
        """Radial-line histogram of one box as seen from the reference point.

        For each direction a ray is rasterised from the reference point and
        the number of ray pixels falling inside the box mask is recorded. The
        histogram is then divided by its maximum; if nothing was hit (or the
        reference point falls outside the grid) every bin is set to 0.01.

        ``bbox_center`` and ``bbox_dims`` are accepted for interface
        compatibility but not used.
        """
        histogram = np.zeros(self.num_directions)

        object_mask, scale_info = self._create_binary_mask_higher_res(bbox, img_width, img_height)
        mask_height, mask_width = object_mask.shape
        scale_x, scale_y = scale_info[0], scale_info[1]

        # Reference point in grid cells.
        rp_x = int(reference_point.x * scale_x)
        rp_y = int(reference_point.y * scale_y)

        # Loop-invariant: rays are only cast from a point inside the grid.
        if 0 <= rp_x < mask_width and 0 <= rp_y < mask_height:
            # Long enough to leave the grid in any direction.
            max_length = max(mask_width, mask_height) * 2

            for i, theta in enumerate(self.directions):
                end_x = rp_x + max_length * np.cos(theta)
                end_y = rp_y + max_length * np.sin(theta)

                try:
                    rr, cc = line(rp_y, rp_x, int(end_y), int(end_x))
                    # Keep only the pixels that lie on the grid.
                    on_grid = (rr >= 0) & (rr < mask_height) & (cc >= 0) & (cc < mask_width)
                    rr, cc = rr[on_grid], cc[on_grid]

                    if len(rr) > 0:
                        # Number of ray pixels inside the box (a pixel count,
                        # not a count of contiguous segments).
                        histogram[i] = np.sum(object_mask[rr, cc])
                except Exception:
                    # Best effort: a direction that cannot be rasterised
                    # contributes nothing.
                    histogram[i] = 0

        max_val = np.max(histogram)
        if max_val > 0:
            histogram = histogram / max_val
        else:
            # No intersection at all: small uniform distribution.
            histogram = np.ones(self.num_directions) * 0.01

        return histogram

    def _compute_force_histogram_120_dirs(self, subject_center: np.ndarray, object_center: np.ndarray,
                                         reference_point: 'ReferencePoint', force_type: int = 2) -> np.ndarray:
        """Force histogram between the two centres, one bin per direction.

        Each bin is the absolute projection of the centre-to-centre vector on
        that direction, divided by the distance (``force_type == 0``) or by
        the squared distance plus 1e-6 (any other value), then scaled so the
        maximum is 1. Coincident centres yield an all-zero histogram.

        ``reference_point`` is accepted for interface compatibility but not used.

        NOTE(review): both force types scale the same projections by a
        per-call constant, so after max-normalisation they give the same
        profile up to rounding — confirm this matches the intended definition.
        """
        force_vector = object_center - subject_center
        distance = np.linalg.norm(force_vector)

        if distance < 1e-6:
            return np.zeros(self.num_directions)

        # Projection of the force vector on every unit direction at once.
        projections = np.abs(force_vector[0] * self.cos_dirs + force_vector[1] * self.sin_dirs)

        if force_type == 0:
            histogram = projections / distance
        else:
            histogram = projections / (distance ** 2 + 1e-6)

        max_val = np.max(histogram)
        if max_val > 0:
            histogram = histogram / max_val

        return histogram

    def extract_features(self, subject_bbox: 'BoundingBox', object_bbox: 'BoundingBox',
                        img_width: int, img_height: int) -> dict:
        """Compute every descriptor for one subject/object pair.

        Returns:
            A dict with the reference point, the subject and object RLM
            histograms, the two force histograms, their concatenation
            ``combined_rlm_f2`` (subject RLM, object RLM, force f2) and the
            direction angles.
        """
        ref_point = self._compute_reference_point_article_compliant(subject_bbox, object_bbox)

        subject_center = np.array([subject_bbox.center[0], subject_bbox.center[1]])
        object_center = np.array([object_bbox.center[0], object_bbox.center[1]])
        subject_dims = np.array([subject_bbox.width, subject_bbox.height])
        object_dims = np.array([object_bbox.width, object_bbox.height])

        subject_rlm = self._compute_rlm_histogram_article_accurate(
            subject_center, subject_dims, ref_point, img_width, img_height, subject_bbox
        )
        object_rlm = self._compute_rlm_histogram_article_accurate(
            object_center, object_dims, ref_point, img_width, img_height, object_bbox
        )

        force_f0 = self._compute_force_histogram_120_dirs(subject_center, object_center, ref_point, force_type=0)
        force_f2 = self._compute_force_histogram_120_dirs(subject_center, object_center, ref_point, force_type=2)

        return {
            'reference_point': ref_point,
            'subject_rlm': subject_rlm,        # num_directions values
            'object_rlm': object_rlm,          # num_directions values
            'force_f0': force_f0,
            'force_f2': force_f2,              # num_directions values
            'combined_rlm_f2': np.concatenate([subject_rlm, object_rlm, force_f2]),  # 3 * num_directions values
            'directions': self.directions
        }


class HyperFastPHIDescriptor:
    """One-hot Allen interval relations between two boxes, per axis."""

    def __init__(self):
        # The 13 Allen relations; the list index is the class id returned by
        # _classify_interval_relation_instant.
        self.allen_relations = [
            'before', 'meets', 'overlaps', 'finished_by', 'contains',
            'starts', 'equals', 'started_by', 'during', 'finishes',
            'overlapped_by', 'met_by', 'after'
        ]
        self.n_relations = len(self.allen_relations)
        self.relation_to_idx = {rel: i for i, rel in enumerate(self.allen_relations)}

    def _classify_interval_relation_instant(self, interval1: Tuple[float, float],
                                           interval2: Tuple[float, float]) -> int:
        """Classify interval1 relative to interval2.

        Args:
            interval1: ``(start, end)`` of the first interval.
            interval2: ``(start, end)`` of the second interval.

        Returns:
            Index into ``self.allen_relations``. Endpoints closer than 1e-8
            are treated as equal.
        """
        a1, a2 = interval1
        b1, b2 = interval2

        eps = 1e-8

        if a2 < b1 - eps: return 0                                              # before
        elif abs(a2 - b1) <= eps: return 1                                      # meets
        elif a1 < b1 - eps and a2 > b1 + eps and a2 < b2 - eps: return 2        # overlaps
        elif abs(a1 - b1) <= eps and a2 < b2 - eps: return 5                    # starts
        elif a1 > b1 + eps and a2 < b2 - eps: return 8                          # during
        elif a1 > b1 + eps and abs(a2 - b2) <= eps: return 9                    # finishes
        elif abs(a1 - b1) <= eps and abs(a2 - b2) <= eps: return 6              # equals
        elif abs(a1 - b1) <= eps and a2 > b2 + eps: return 7                    # started_by
        elif a1 < b1 - eps and a2 > b2 + eps: return 4                          # contains
        elif a1 < b1 - eps and abs(a2 - b2) <= eps: return 3                    # finished_by
        # overlapped_by: interval1 starts strictly inside interval2 and ends
        # after it. The previous test (a1 < b1 and a2 > b2) repeated the
        # 'contains' condition, so this case fell through to 'after'.
        elif a1 > b1 + eps and a1 < b2 - eps and a2 > b2 + eps: return 10       # overlapped_by
        elif abs(a1 - b2) <= eps: return 11                                     # met_by
        else: return 12                                                         # after

    def extract_features(self, subject_bbox: 'BoundingBox', object_bbox: 'BoundingBox') -> dict:
        """Allen relations of the subject box relative to the object box.

        The x extents give the horizontal relation and the y extents the
        vertical one.

        Returns:
            A dict with the two relation names, their one-hot vectors
            (length 13, float32) and ``combined_features``, the horizontal
            vector followed by the vertical one (length 26).
        """
        subj_h_interval = (subject_bbox.x1, subject_bbox.x2)
        obj_h_interval = (object_bbox.x1, object_bbox.x2)
        subj_v_interval = (subject_bbox.y1, subject_bbox.y2)
        obj_v_interval = (object_bbox.y1, object_bbox.y2)

        h_idx = self._classify_interval_relation_instant(subj_h_interval, obj_h_interval)
        v_idx = self._classify_interval_relation_instant(subj_v_interval, obj_v_interval)

        h_features = np.zeros(self.n_relations, dtype=np.float32)
        v_features = np.zeros(self.n_relations, dtype=np.float32)
        h_features[h_idx] = 1.0
        v_features[v_idx] = 1.0

        return {
            'horizontal_relation': self.allen_relations[h_idx],
            'vertical_relation': self.allen_relations[v_idx],
            'horizontal_features': h_features,
            'vertical_features': v_features,
            'combined_features': np.concatenate([h_features, v_features])
        }



#  Updated main extraction function
def extract_features_from_spatialsense_fixed(loader, max_samples=None, use_splits=True,
                                           batch_size=100, show_progress=True):
    """Extract RLM and PHI feature matrices for every relation labelled True.

    Args:
        loader: object exposing ``images_data`` (images with ``annotations``,
            ``width``, ``height`` and ``split``).
        max_samples: if set and smaller than the number of samples, a
            reproducible random subset of that size is used.
        use_splits: when True the split of each sample is returned as well.
        batch_size: currently unused; kept for interface compatibility.
        show_progress: show tqdm progress bars.

    Returns:
        ``(rlm_features, phi_features, labels)`` and, when ``use_splits`` is
        True, ``splits`` as a fourth element. Samples whose extraction raised
        are dropped from all returned arrays.
    """
    rlm = HyperFastExtendedRLM(num_directions=120)
    phi = HyperFastPHIDescriptor()

    samples_data = []
    for img_data in tqdm(loader.images_data, desc="Collecting samples", disable=not show_progress):
        for annotation in img_data.annotations:
            if annotation.label:
                samples_data.append({
                    'img_data': img_data,
                    'annotation': annotation,
                    'split': img_data.split
                })

    if max_samples and len(samples_data) > max_samples:
        # A private generator seeded with 42 keeps the draw reproducible
        # without reseeding NumPy's global random state.
        rng = np.random.RandomState(42)
        chosen = rng.choice(len(samples_data), max_samples, replace=False)
        samples_data = [samples_data[j] for j in chosen]

    n_samples = len(samples_data)
    # Widths follow the descriptors instead of being hard-coded.
    n_rlm_features = len(rlm.directions) * 3  # subject_rlm + object_rlm + force_f2
    n_phi_features = 2 * phi.n_relations      # horizontal + vertical one-hot

    rlm_features_array = np.zeros((n_samples, n_rlm_features), dtype=np.float32)
    phi_features_array = np.zeros((n_samples, n_phi_features), dtype=np.float32)
    # Explicit success mask: failures are not encoded as a sentinel label, so
    # a genuine predicate can never be mistaken for a failure.
    succeeded = np.zeros(n_samples, dtype=bool)
    labels = []
    splits = []

    max_reported_errors = 5  # keep the notebook output readable

    start_time = time.time()
    processed = 0
    failed = 0

    with tqdm(total=n_samples, desc=" Fixed RLM/PHI (120 dirs + ConvexHull)",
              disable=not show_progress) as progress_bar:
        for i, sample in enumerate(samples_data):
            img_data = sample['img_data']
            annotation = sample['annotation']

            try:
                rlm_features = rlm.extract_features(
                    annotation.subject.bbox,
                    annotation.object.bbox,
                    img_data.width,
                    img_data.height
                )

                phi_features = phi.extract_features(
                    annotation.subject.bbox,
                    annotation.object.bbox
                )

                rlm_features_array[i] = rlm_features['combined_rlm_f2']
                phi_features_array[i] = phi_features['combined_features']
                labels.append(annotation.predicate)
                splits.append(sample['split'])
                succeeded[i] = True
                processed += 1

            except Exception as e:
                failed += 1
                if failed <= max_reported_errors:
                    print(f"Error processing sample {i}: {e}")
                    if failed == max_reported_errors:
                        print("Further errors are not printed; see the failure count below.")
                # The row is filtered out below; zero it so no partial data remains.
                rlm_features_array[i] = 0
                phi_features_array[i] = 0

            progress_bar.update(1)

            if (i + 1) % 100 == 0:
                elapsed = time.time() - start_time
                speed = (i + 1) / elapsed if elapsed > 0 else 0
                eta = (n_samples - i - 1) / speed if speed > 0 else 0
                progress_bar.set_postfix({
                    'speed': f'{speed:.1f} samples/s',
                    'ETA': f'{eta/60:.1f}min',
                    'failed': failed
                })

    total_time = time.time() - start_time
    print(f"\n Feature extraction completed:")
    print(f"   Total time: {total_time/60:.1f} minutes")
    print(f"   Processed: {processed}/{n_samples} samples")
    print(f"   Failed: {failed} samples")
    print(f"   RLM features shape: {rlm_features_array.shape} (120 directions × 3)")
    print(f"   PHI features shape: {phi_features_array.shape}")

    if failed > 0:
        # labels / splits already contain successful samples only.
        rlm_features_array = rlm_features_array[succeeded]
        phi_features_array = phi_features_array[succeeded]

    labels_array = np.array(labels)
    splits_array = np.array(splits)

    if use_splits:
        return rlm_features_array, phi_features_array, labels_array, splits_array
    else:
        return rlm_features_array, phi_features_array, labels_array

#  USAGE:
# NOTE: `labels` is rebound here, replacing the list of the same name built
# earlier in the notebook.
print(" Re-extracting features with corrections...")
rlm_features, phi_features, labels, splits = extract_features_from_spatialsense_fixed(
    loader, max_samples=None, use_splits=True
)

In [None]:
class HyperFastVisualizer:
    """Builds matplotlib figures for the RLM and PHI descriptors of a subject/object pair.

    Bounding boxes handed to the methods are read through their ``x1``, ``y1``,
    ``x2``, ``y2``, ``width`` and ``height`` attributes, and the reference point
    through ``x`` and ``y``.
    """

    # Labels of the thirteen bars drawn in each PHI relation panel.
    _RELATION_LABELS = (
        'before', 'meets', 'overlaps', 'finished_by', 'contains',
        'starts', 'equals', 'started_by', 'during', 'finishes',
        'overlapped_by', 'met_by', 'after',
    )

    def __init__(self):
        # Named palette. Within this class only the rlm_* and force_* entries
        # are read; the boxes and reference point use literal colour names.
        self.colors = dict(
            subject='#FF6B6B',
            object='#4ECDC4',
            reference='#FFD93D',
            radial_lines='#95E1D3',
            rlm_subject='#FF4444',
            rlm_object='#44DDDD',
            force_f0='blue',
            force_f2='red',
        )

    def visualize_complete_analysis(self, subject_bbox, object_bbox, rlm_features, phi_features,
                                   img_width: int, img_height: int, title_suffix: str = ""):
        """Create the 2x2 RLM overview figure and return it.

        Panels: the two boxes with the reference point, the same scene with a
        sample of radial lines, the subject/object RLM curves, and the f0/f2
        force curves.

        Args:
            subject_bbox: Box of the subject.
            object_bbox: Box of the object.
            rlm_features: Mapping providing 'reference_point', 'directions'
                (angles in radians, converted to degrees for the x axis),
                'subject_rlm', 'object_rlm', 'force_f0' and 'force_f2'.
            phi_features: Accepted for signature symmetry; not read here.
            img_width: Upper x limit of the scene panels.
            img_height: Upper y limit of the scene panels (y axis is inverted).
            title_suffix: Text appended to the first panel's title.

        Returns:
            The matplotlib figure.
        """
        fig, axes = plt.subplots(2, 2, figsize=(16, 12))
        (scene_ax, rays_ax), (rlm_ax, force_ax) = axes

        origin = rlm_features['reference_point']
        thetas = rlm_features['directions']

        self._draw_objects_and_radial_lines(scene_ax, subject_bbox, object_bbox,
                                            origin, thetas, img_width, img_height)
        scene_ax.set_title(f'Objects and Reference Point\nHYPER-FAST RLM: {title_suffix}')

        self._draw_radial_lines_with_hatching(rays_ax, subject_bbox, object_bbox,
                                              origin, thetas, img_width, img_height)
        rays_ax.set_title('Radial Lines (sample)')

        degrees = np.degrees(thetas)

        self._plot_angle_curves(
            rlm_ax, degrees,
            [
                (rlm_features['subject_rlm'], self.colors['rlm_subject'], 'Subject RLM'),
                (rlm_features['object_rlm'], self.colors['rlm_object'], 'Object RLM'),
            ],
            'Normalized RLM Value', 'RLM Histograms',
        )
        self._plot_angle_curves(
            force_ax, degrees,
            [
                (rlm_features['force_f0'], self.colors['force_f0'], 'Force f0 (constant)'),
                (rlm_features['force_f2'], self.colors['force_f2'], 'Force f2 (gravitational)'),
            ],
            'Normalized Force Value', 'Force Histograms',
        )

        plt.tight_layout()
        return fig

    def visualize_phi_analysis(self, subject_bbox, object_bbox, phi_features, title_suffix: str = ""):
        """Create the 1x3 PHI figure and return it.

        Panels: the two boxes on a canvas extending 50 units past the larger
        ``x2``/``y2``, then one horizontal bar chart per axis built from
        'horizontal_features' / 'vertical_features', each titled with the
        matching '*_relation' entry of ``phi_features``.
        """
        fig, axes = plt.subplots(1, 3, figsize=(18, 6))
        layout_ax, horizontal_ax, vertical_ax = axes

        canvas_w = max(subject_bbox.x2, object_bbox.x2) + 50
        canvas_h = max(subject_bbox.y2, object_bbox.y2) + 50
        self._draw_objects_only(layout_ax, subject_bbox, object_bbox, canvas_w, canvas_h)
        layout_ax.set_title('Spatial Configuration')

        bar_slots = np.arange(len(self.allen_relations))

        self._plot_relation_bars(
            horizontal_ax, bar_slots, phi_features['horizontal_features'], 'skyblue',
            f'Horizontal Relations\n({phi_features["horizontal_relation"]})',
        )
        self._plot_relation_bars(
            vertical_ax, bar_slots, phi_features['vertical_features'], 'orange',
            f'Vertical Relations\n({phi_features["vertical_relation"]})',
        )

        fig.suptitle(f'HYPER-FAST PHI: {title_suffix}', fontsize=14, fontweight='bold')
        plt.tight_layout()
        return fig

    def _plot_angle_curves(self, ax, x_degrees, curves, y_label, title):
        """Plot each ``(values, color, label)`` curve against angle on a fixed 0-360 / 0-1.1 frame."""
        for values, color, label in curves:
            ax.plot(x_degrees, values, color=color, linewidth=2, label=label)
        ax.set_xlabel('Angle (degrees)')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)
        ax.set_xlim(0, 360)
        ax.set_ylim(0, 1.1)

    def _plot_relation_bars(self, ax, bar_slots, values, color, title):
        """Draw one horizontal bar per relation label with the x axis fixed to 0-1.1."""
        ax.barh(bar_slots, values, alpha=0.7, color=color)
        ax.set_yticks(bar_slots)
        ax.set_yticklabels(self.allen_relations)
        ax.set_xlabel('Feature Value')
        ax.set_title(title)
        ax.set_xlim(0, 1.1)

    @staticmethod
    def _add_box(ax, bbox, face):
        """Add a black-edged, semi-transparent rectangle for ``bbox`` filled with ``face``."""
        ax.add_patch(patches.Rectangle(
            (bbox.x1, bbox.y1), bbox.width, bbox.height,
            linewidth=2, edgecolor='black', facecolor=face, alpha=0.7
        ))

    @staticmethod
    def _frame_image(ax, img_width, img_height):
        """Set the axis limits to the image size with y growing downwards."""
        ax.set_xlim(0, img_width)
        ax.set_ylim(img_height, 0)

    def _draw_objects_and_radial_lines(self, ax, subject_bbox, object_bbox, ref_point, directions,
                                      img_width, img_height):
        """Draw both boxes and the reference point marker (``directions`` is not used here)."""
        self._frame_image(ax, img_width, img_height)
        self._add_box(ax, subject_bbox, 'lightgray')
        self._add_box(ax, object_bbox, 'lightcyan')

        ax.plot(ref_point.x, ref_point.y, 'o', color='yellow', markersize=8,
               markeredgecolor='black', markeredgewidth=2)

        ax.set_aspect('equal')

    def _draw_radial_lines_with_hatching(self, ax, subject_bbox, object_bbox, ref_point,
                                        directions, img_width, img_height):
        """Draw the scene, then a ray from the reference point for every third direction."""
        self._draw_objects_and_radial_lines(ax, subject_bbox, object_bbox, ref_point,
                                           directions, img_width, img_height)

        # Rays are longer than the larger image side; the fixed axis limits clip them.
        reach = 1.5 * max(img_width, img_height)

        for theta in directions[::3]:
            tip_x = ref_point.x + reach * np.cos(theta)
            tip_y = ref_point.y + reach * np.sin(theta)
            ax.plot([ref_point.x, tip_x], [ref_point.y, tip_y],
                   color='lightgray', alpha=0.5, linewidth=0.8)

    def _draw_objects_only(self, ax, subject_bbox, object_bbox, img_width, img_height):
        """Draw both boxes (subject in light coral, object in light cyan) without a reference point."""
        self._frame_image(ax, img_width, img_height)
        self._add_box(ax, subject_bbox, 'lightcoral')
        self._add_box(ax, object_bbox, 'lightcyan')
        ax.set_aspect('equal')

    @property
    def allen_relations(self):
        """Fresh list of the thirteen relation labels, in bar order."""
        return list(self._RELATION_LABELS)

def create_example_configurations():
    """Return four hand-built subject/object layouts used to demo the descriptors.

    Each entry is a dict with keys 'subject' and 'object' (BoundingBox built
    from x1/y1/x2/y2), 'img_width', 'img_height' and a display 'title'.
    """
    corner_names = ('x1', 'y1', 'x2', 'y2')

    # (subject corners, object corners, image width, image height, title)
    layouts = (
        ((235, 220, 408, 400), (50, 200, 450, 500), 500, 500, '"cat" on "ground"'),
        ((150, 100, 200, 200), (300, 120, 400, 180), 500, 300, '"fork" next to "plate"'),
        ((150, 50, 250, 100), (120, 200, 280, 280), 400, 350, '"object" above "object"'),
        ((180, 140, 220, 180), (100, 100, 300, 220), 400, 320, '"object" inside "object"'),
    )

    return [
        {
            'subject': BoundingBox(**dict(zip(corner_names, subject_corners))),
            'object': BoundingBox(**dict(zip(corner_names, object_corners))),
            'img_width': width,
            'img_height': height,
            'title': caption,
        }
        for subject_corners, object_corners, width, height, caption in layouts
    ]

def analyze_hyper_fast_examples():
    """Render the RLM and PHI figures for every hand-built example configuration.

    For each configuration returned by create_example_configurations() the RLM
    features (120 directions) and PHI features are extracted, the two figures
    are built, and plt.show() is called once so both figures appear together.
    """
    rlm = HyperFastExtendedRLM(num_directions=120)
    phi = HyperFastPHIDescriptor()
    visualizer = HyperFastVisualizer()

    for example in create_example_configurations():
        subject_bbox = example['subject']
        object_bbox = example['object']

        rlm_features = rlm.extract_features(
            subject_bbox, object_bbox,
            example['img_width'], example['img_height']
        )
        phi_features = phi.extract_features(subject_bbox, object_bbox)

        # The returned figure handles are not needed: plt.show() below displays
        # every figure that is currently open.
        visualizer.visualize_complete_analysis(
            subject_bbox, object_bbox, rlm_features, phi_features,
            example['img_width'], example['img_height'], example['title']
        )
        visualizer.visualize_phi_analysis(
            subject_bbox, object_bbox, phi_features, example['title']
        )

        plt.show()

def analyze_real_spatialsense_examples(loader, num_examples=5):
    """Render the RLM and PHI figures for a few real annotations from ``loader``.

    Annotations are scanned in ``loader.images_data`` order. An annotation is
    selected when its ``label`` is truthy and its ``predicate`` has not been
    selected before, so every shown example illustrates a different predicate.

    Args:
        loader: Object exposing ``images_data``; each item provides
            ``annotations``, ``width`` and ``height``.
        num_examples: Maximum number of examples to render.
    """
    rlm = HyperFastExtendedRLM(num_directions=120)
    phi = HyperFastPHIDescriptor()
    visualizer = HyperFastVisualizer()

    examples = []
    seen_predicates = set()

    for img_data in loader.images_data:
        if len(examples) >= num_examples:
            break

        for annotation in img_data.annotations:
            # Stop scanning this image as soon as the quota is reached.
            if len(examples) >= num_examples:
                break
            if annotation.label and annotation.predicate not in seen_predicates:
                examples.append((img_data, annotation))
                seen_predicates.add(annotation.predicate)

    for img_data, annotation in examples:
        title = f'"{annotation.subject.name}" {annotation.predicate} "{annotation.object.name}"'
        subject_bbox = annotation.subject.bbox
        object_bbox = annotation.object.bbox

        rlm_features = rlm.extract_features(
            subject_bbox, object_bbox,
            img_data.width, img_data.height
        )
        phi_features = phi.extract_features(subject_bbox, object_bbox)

        # Figure handles are not kept: plt.show() displays all open figures.
        visualizer.visualize_complete_analysis(
            subject_bbox, object_bbox, rlm_features, phi_features,
            img_data.width, img_data.height, title
        )
        visualizer.visualize_phi_analysis(
            subject_bbox, object_bbox, phi_features, title
        )

        plt.show()

# Render the diagnostic figures: first for the hand-built configurations, then
# for at most three real annotations (one per distinct predicate) from `loader`.
analyze_hyper_fast_examples()
analyze_real_spatialsense_examples(loader, num_examples=3)

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.neural_network import MLPClassifier
from sklearn.svm import SVC
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import (classification_report, confusion_matrix,
                           accuracy_score, f1_score, log_loss)
from sklearn.model_selection import train_test_split, GridSearchCV
from collections import Counter, defaultdict
import warnings
import time
from itertools import product
warnings.filterwarnings('ignore')

class AdvancedSpatialClassifier:
    """Train and compare MLP and SVM classifiers on RLM, PHI and combined features.

    State kept on the instance:
      * ``scalers``          - fitted StandardScaler per feature mode
      * ``label_encoder``    - LabelEncoder fitted on the string labels
      * ``models``           - fitted estimators keyed ``'<MLP|SVM>_<feature_mode>'``
      * ``training_history`` - per-epoch log-loss lists for the MLP models
      * ``results``          - accuracies, predictions and true labels per model
    """

    # Epoch budget and early-stopping patience of the manual MLP training loop.
    MLP_MAX_EPOCHS = 500
    MLP_PATIENCE = 5

    # Upper bound on the number of training rows used for the SVM grid search.
    SVM_GRID_SUBSET_SIZE = 1000

    # Column layout of the table returned by get_comprehensive_results().
    RESULT_COLUMNS = [
        'Model', 'Features', 'Test_Accuracy', 'Val_Accuracy',
        'Macro_F1', 'Weighted_F1', 'Macro_Precision', 'Macro_Recall',
    ]

    def __init__(self, random_state=42):
        self.random_state = random_state
        self.scalers = {}
        self.label_encoder = LabelEncoder()
        self.models = {}
        self.training_history = {}
        self.results = {}

    def create_train_val_test_split(self, X, y, train_size=0.7, val_size=0.15, test_size=0.15):
        """Split ``X``/``y`` into stratified train, validation and test parts.

        The three fractions must sum to 1. The split is done in two stratified
        steps: train vs. the rest, then the rest into validation vs. test.

        Returns:
            ``(X_train, X_val, X_test, y_train, y_val, y_test)``
        """
        assert abs(train_size + val_size + test_size - 1.0) < 1e-6

        X_train, X_temp, y_train, y_temp = train_test_split(
            X, y, test_size=(val_size + test_size),
            random_state=self.random_state, stratify=y
        )

        # Express the test share relative to the held-out remainder.
        relative_test_size = test_size / (val_size + test_size)
        X_val, X_test, y_val, y_test = train_test_split(
            X_temp, y_temp, test_size=relative_test_size,
            random_state=self.random_state, stratify=y_temp
        )

        return X_train, X_val, X_test, y_train, y_val, y_test

    def prepare_features(self, rlm_features, phi_features, feature_mode='combined'):
        """Select the feature matrix for ``feature_mode``.

        Returns ``rlm_features`` for 'rlm', ``phi_features`` for 'phi', and the
        column-wise concatenation of both for 'combined'.

        Raises:
            ValueError: if ``feature_mode`` is none of the three names.
        """
        if feature_mode == 'rlm':
            return rlm_features
        elif feature_mode == 'phi':
            return phi_features
        elif feature_mode == 'combined':
            return np.concatenate([rlm_features, phi_features], axis=1)
        else:
            raise ValueError("feature_mode must be 'rlm', 'phi', or 'combined'")

    def train_improved_mlp(self, X_train, y_train, X_val, y_val, X_test, y_test,
                          feature_mode='combined'):
        """Train an MLP one epoch at a time, logging train/val/test log-loss.

        Training stops after MLP_MAX_EPOCHS epochs, or earlier once the
        validation loss has not improved for MLP_PATIENCE consecutive epochs.
        The model, its loss history and its evaluation are stored under the key
        ``'MLP_<feature_mode>'``.

        Returns:
            Accuracy on the test split.
        """
        # Simple architecture with stronger regularisation and a small step size.
        mlp = MLPClassifier(
            hidden_layer_sizes=(128, 64),
            activation='relu',
            solver='adam',
            alpha=0.1,  # raised for regularisation
            batch_size=128,  # raised
            learning_rate='adaptive',
            learning_rate_init=0.0001,  # lowered
            max_iter=1,  # one epoch per fit() call so the losses can be tracked
            random_state=self.random_state,
            warm_start=True,  # every fit() call continues from the previous weights
            tol=1e-4,
            beta_1=0.9,
            beta_2=0.999,
            verbose=False  # keep the per-epoch output under our own control
        )

        history = {
            'train_loss': [],
            'val_loss': [],
            'test_loss': []
        }

        best_val_loss = float('inf')
        patience_counter = 0

        print(f"Training MLP {feature_mode} with epoch-by-epoch monitoring...")

        for epoch in range(self.MLP_MAX_EPOCHS):
            mlp.fit(X_train, y_train)

            try:
                # Passing the model's own class list keeps log_loss defined even
                # when a split happens to miss one of the classes.
                class_ids = mlp.classes_
                train_loss = log_loss(y_train, mlp.predict_proba(X_train), labels=class_ids)
                val_loss = log_loss(y_val, mlp.predict_proba(X_val), labels=class_ids)
                test_loss = log_loss(y_test, mlp.predict_proba(X_test), labels=class_ids)
            except Exception as e:
                print(f"Erreur à l'époque {epoch+1}: {e}")
                break

            # Append only after all three losses exist so the lists stay aligned.
            history['train_loss'].append(train_loss)
            history['val_loss'].append(val_loss)
            history['test_loss'].append(test_loss)

            if (epoch + 1) % 10 == 0 or epoch < 5:
                print(f"Epoch {epoch+1:3d}: Train={train_loss:.4f}, Val={val_loss:.4f}, Test={test_loss:.4f}")

            # Early stopping on the validation loss.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
            else:
                patience_counter += 1

            if patience_counter >= self.MLP_PATIENCE:
                print(f"Early stopping at epoch {epoch+1}")
                break

        # Final evaluation with the weights of the last trained epoch.
        val_pred = mlp.predict(X_val)
        val_acc = accuracy_score(y_val, val_pred)
        y_pred = mlp.predict(X_test)
        test_acc = accuracy_score(y_test, y_pred)

        model_key = f'MLP_{feature_mode}'
        self.models[model_key] = mlp
        self.training_history[model_key] = history
        self.results[model_key] = {
            'test_accuracy': test_acc,
            'val_accuracy': val_acc,
            'predictions': y_pred,
            'true_labels': y_test,
            'feature_mode': feature_mode
        }

        print(f"MLP {feature_mode} - Val: {val_acc:.3f}, Test: {test_acc:.3f}")
        return test_acc

    def train_improved_svm(self, X_train, y_train, X_val, y_val, X_test, y_test,
                          feature_mode='combined'):
        """Grid-search SVM hyper-parameters on a subset, then fit on all training data.

        The search runs 3-fold cross-validation on at most SVM_GRID_SUBSET_SIZE
        training rows. The best parameters are then used to fit a
        probability-enabled SVC on the full training set, which is stored with
        its evaluation under ``'SVM_<feature_mode>'``.

        Returns:
            Accuracy on the test split.
        """
        param_grid = {
            'C': [0.01, 0.1, 1, 10, 100, 1000],
            'gamma': ['scale', 'auto', 0.0001, 0.001, 0.01, 0.1, 1],
            'kernel': ['rbf', 'poly', 'sigmoid']
        }

        # Draw the subset from a generator seeded with random_state so that the
        # selected hyper-parameters are reproducible between runs.
        subset_size = min(self.SVM_GRID_SUBSET_SIZE, len(X_train))
        rng = np.random.RandomState(self.random_state)
        indices = rng.choice(len(X_train), subset_size, replace=False)
        X_subset = X_train[indices]
        y_subset = y_train[indices]

        print(f"Grid search on {subset_size} samples...")

        # The search is scored on accuracy (hard predictions), so probability
        # calibration is left off here; it only adds an internal
        # cross-validation to every fit.
        svm = SVC(random_state=self.random_state, probability=False)

        grid_search = GridSearchCV(
            svm, param_grid, cv=3,
            scoring='accuracy',
            n_jobs=-1,
            verbose=1,
            return_train_score=False
        )

        grid_search.fit(X_subset, y_subset)
        print(f"Best params: {grid_search.best_params_}")

        # Fit the selected configuration on the complete training set.
        print("Training final SVM...")
        best_svm = SVC(**grid_search.best_params_,
                      random_state=self.random_state,
                      probability=True)
        best_svm.fit(X_train, y_train)

        val_pred = best_svm.predict(X_val)
        val_acc = accuracy_score(y_val, val_pred)

        y_pred = best_svm.predict(X_test)
        test_acc = accuracy_score(y_test, y_pred)

        model_key = f'SVM_{feature_mode}'
        self.models[model_key] = best_svm
        self.results[model_key] = {
            'test_accuracy': test_acc,
            'val_accuracy': val_acc,
            'predictions': y_pred,
            'true_labels': y_test,
            'feature_mode': feature_mode,
            'best_params': grid_search.best_params_
        }

        print(f"SVM {feature_mode} - Val: {val_acc:.3f}, Test: {test_acc:.3f}")
        return test_acc

    def train_all_configurations(self, rlm_features, phi_features, labels):
        """Train an MLP and an SVM for each of the 'rlm', 'phi' and 'combined' modes.

        Labels are integer-encoded with the instance's LabelEncoder. For each
        mode the data is split, standardised with a scaler fitted on the
        training part only, and passed to both trainers.

        Returns:
            DataFrame with one row per (model, feature mode) and its test accuracy.
        """
        print("=== STARTING TRAINING ===\n")

        y = self.label_encoder.fit_transform(labels)
        feature_modes = ['rlm', 'phi', 'combined']

        results_summary = []

        for i, feature_mode in enumerate(feature_modes, 1):
            print(f"--- Feature mode {i}/{len(feature_modes)}: {feature_mode.upper()} ---")

            X = self.prepare_features(rlm_features, phi_features, feature_mode)
            print(f"Feature shape: {X.shape}")

            X_train, X_val, X_test, y_train, y_val, y_test = self.create_train_val_test_split(X, y)
            print(f"Train: {len(X_train)}, Val: {len(X_val)}, Test: {len(X_test)}")

            scaler = StandardScaler()
            X_train_scaled = scaler.fit_transform(X_train)
            X_val_scaled = scaler.transform(X_val)
            X_test_scaled = scaler.transform(X_test)
            self.scalers[feature_mode] = scaler

            print(f"Training MLP {feature_mode}...")
            mlp_acc = self.train_improved_mlp(
                X_train_scaled, y_train, X_val_scaled, y_val, X_test_scaled, y_test, feature_mode
            )

            print(f"Training SVM {feature_mode}...")
            svm_acc = self.train_improved_svm(
                X_train_scaled, y_train, X_val_scaled, y_val, X_test_scaled, y_test, feature_mode
            )

            results_summary.extend([
                {'Model': 'MLP', 'Features': feature_mode, 'Test_Accuracy': mlp_acc},
                {'Model': 'SVM', 'Features': feature_mode, 'Test_Accuracy': svm_acc}
            ])

            print(f"✓ {feature_mode} completed\n")

        print("=== TRAINING COMPLETED ===")
        return pd.DataFrame(results_summary)

    def plot_training_curves(self, figsize=(18, 6)):
        """Plot the per-epoch loss curves of every trained MLP and print a summary."""
        mlp_models = [k for k in self.training_history.keys() if k.startswith('MLP')]

        if not mlp_models:
            print("Aucune courbe d'entraînement disponible")
            return

        n_models = len(mlp_models)
        # squeeze=False always yields a 2-D axes array, also for a single model.
        fig, axes = plt.subplots(1, n_models, figsize=figsize, squeeze=False)
        fig.suptitle('Courbes de Loss - Entraînement MLP', fontsize=16, fontweight='bold')

        # (history key, line format, legend label, extra plot options)
        curve_specs = (
            ('train_loss', 'b-', 'Training Loss', {}),
            ('val_loss', 'r-', 'Validation Loss', {}),
            ('test_loss', 'g--', 'Test Loss', {'alpha': 0.8}),
        )

        for i, model_key in enumerate(mlp_models):
            history = self.training_history[model_key]
            feature_mode = model_key.split('_', 1)[1].upper()

            ax_loss = axes[0, i]

            for key, fmt, label, extra in curve_specs:
                if history[key]:
                    epochs = range(len(history[key]))
                    ax_loss.plot(epochs, history[key], fmt, label=label, linewidth=2, **extra)

            ax_loss.set_title(f'{feature_mode} Features - Loss Evolution', fontsize=12, fontweight='bold')
            ax_loss.set_xlabel('Epoch')
            ax_loss.set_ylabel('Loss')
            ax_loss.legend()
            ax_loss.grid(True, alpha=0.3)
            if history['train_loss']:
                ax_loss.set_xlim(0, len(history['train_loss']))

        plt.tight_layout()
        plt.show()

        # Textual summary of the final losses and accuracies.
        print("\n" + "="*60)
        print("RÉSUMÉ DES COURBES DE LOSS")
        print("="*60)
        for model_key in mlp_models:
            history = self.training_history[model_key]
            result = self.results[model_key]
            feature_mode = model_key.split('_', 1)[1].upper()

            print(f"\n{feature_mode} Features:")
            print(f"  • Époques d'entraînement: {len(history['train_loss'])}")
            if history['train_loss']:
                print(f"  • Loss finale d'entraînement: {history['train_loss'][-1]:.4f}")
            if history['val_loss']:
                print(f"  • Loss finale de validation: {history['val_loss'][-1]:.4f}")
            if history['test_loss']:
                print(f"  • Loss finale de test: {history['test_loss'][-1]:.4f}")
            print(f"  • Accuracy validation: {result['val_accuracy']:.3f}")
            print(f"  • Accuracy test: {result['test_accuracy']:.3f}")

            # Fewer logged epochs than the budget means the loop ended early.
            if len(history['train_loss']) < self.MLP_MAX_EPOCHS:
                print(f"  • Early stopping à l'époque {len(history['train_loss'])}")
            else:
                print(f"  • Entraînement complet ({self.MLP_MAX_EPOCHS} époques)")
        print("="*60)

    def plot_confusion_matrices(self, figsize=(20, 15)):
        """Show a row-normalised confusion matrix heat map for every stored result.

        Matrices are always built over the full class list of the label
        encoder, so rows/columns line up with the tick labels even when a class
        is absent from the test split. Rows without any true sample are shown
        as zeros.
        """
        n_models = len(self.results)

        if n_models == 0:
            return

        cols = 3
        rows = (n_models + cols - 1) // cols

        # squeeze=False keeps axes 2-D regardless of the number of rows.
        fig, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
        fig.suptitle('Confusion Matrices', fontsize=16)

        class_names = self.label_encoder.classes_
        # LabelEncoder maps the classes to 0..K-1 in the order of classes_.
        class_ids = np.arange(len(class_names))

        for i, (model_key, result) in enumerate(self.results.items()):
            ax = axes[i // cols, i % cols]

            cm = confusion_matrix(result['true_labels'], result['predictions'], labels=class_ids)

            # Normalise each row by its total; rows that sum to zero stay zero
            # instead of producing NaN.
            row_totals = cm.sum(axis=1, keepdims=True)
            cm_norm = np.divide(cm.astype(float), row_totals,
                                out=np.zeros(cm.shape, dtype=float),
                                where=row_totals != 0)

            sns.heatmap(cm_norm, annot=True, fmt='.2f', cmap='Blues',
                       xticklabels=class_names, yticklabels=class_names,
                       ax=ax)

            model_name = model_key.replace('_', ' ')
            acc = result['test_accuracy']
            # A real newline: the accuracy goes on its own title line.
            ax.set_title(f'{model_name}\nAccuracy: {acc:.3f}')
            ax.set_xlabel('Predicted')
            ax.set_ylabel('True')

        # Hide the unused cells of the grid.
        for i in range(n_models, rows * cols):
            axes[i // cols, i % cols].set_visible(False)

        plt.tight_layout()
        plt.show()

    def get_comprehensive_results(self):
        """Build a metrics table with one row per trained model.

        Returns:
            DataFrame with the RESULT_COLUMNS columns, sorted by test accuracy
            in descending order. The frame is empty (but has all columns) when
            nothing has been trained yet.
        """
        summary = []

        for model_key, result in self.results.items():
            # Keys look like '<model>_<feature_mode>'; split only on the first '_'.
            model_type, feature_mode = model_key.split('_', 1)

            y_true = result['true_labels']
            y_pred = result['predictions']

            report = classification_report(y_true, y_pred, output_dict=True, zero_division=0)

            summary.append({
                'Model': model_type,
                'Features': feature_mode,
                'Test_Accuracy': result['test_accuracy'],
                'Val_Accuracy': result['val_accuracy'],
                'Macro_F1': report['macro avg']['f1-score'],
                'Weighted_F1': report['weighted avg']['f1-score'],
                'Macro_Precision': report['macro avg']['precision'],
                'Macro_Recall': report['macro avg']['recall']
            })

        # Explicit columns keep sort_values working when there are no rows.
        df = pd.DataFrame(summary, columns=self.RESULT_COLUMNS)

        df = df.sort_values('Test_Accuracy', ascending=False)

        return df

    def analyze_feature_importance(self):
        """Aggregate the results per feature mode and locate the best model.

        Returns:
            The table from get_comprehensive_results().
        """
        df = self.get_comprehensive_results()

        if df.empty:
            return df

        # NOTE(review): feature_analysis and best_model are computed but not
        # surfaced to the caller; only the detailed table is returned.
        feature_analysis = df.groupby('Features').agg({
            'Test_Accuracy': ['mean', 'std'],
            'Macro_F1': ['mean', 'std']
        }).round(4)

        # idxmax() returns an index *label*. The frame is sorted, so labels and
        # positions differ and the lookup must go through .loc, not .iloc.
        best_idx = df['Test_Accuracy'].idxmax()
        best_model = df.loc[best_idx]

        return df

def run_comprehensive_spatial_classification(rlm_features, phi_features, labels):
    """Train every model/feature combination, show the plots and time the whole run.

    Returns:
        ``(classifier, detailed_df)``: the trained AdvancedSpatialClassifier and
        the table produced by its get_comprehensive_results().
    """
    start_time = time.time()

    classifier = AdvancedSpatialClassifier(random_state=42)

    # The short summary returned by training is not needed; the detailed table
    # below is what gets handed back.
    classifier.train_all_configurations(rlm_features, phi_features, labels)
    detailed_df = classifier.get_comprehensive_results()

    # Reporting: loss curves, confusion matrices, then the feature-mode analysis.
    classifier.plot_training_curves()
    classifier.plot_confusion_matrices()
    classifier.analyze_feature_importance()

    elapsed = time.time() - start_time
    print(f"\nTemps total d'entraînement: {elapsed/60:.1f} minutes")

    return classifier, detailed_df

# Run the full training and reporting pipeline. `Classifier` receives the
# trained AdvancedSpatialClassifier instance (note the capitalised variable
# name) and `results` the detailed per-model metrics table.
Classifier, results = run_comprehensive_spatial_classification(rlm_features, phi_features, labels)

In [None]:
# ============================================================================
# ANALYSE TOP 5 ERREURS AVEC VISUALISATION D'IMAGES - VERSION CORRIGÉE
# ============================================================================

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from sklearn.metrics import confusion_matrix
from collections import Counter, defaultdict
import random
from PIL import Image

class SpatialErrorAnalyzer:
    """Error analyser for spatial-relation predictions, with image visualisation.

    Combines a trained classifier (its ``results`` dict, ``label_encoder`` and
    ``get_comprehensive_results()``) with the dataset loader (``images_data``
    and ``_find_image_path``) to rank the most frequent confusions and to draw
    example annotations for each of them.
    """

    # Predicates treated as near-synonyms of a class when searching for
    # illustrative examples (see could_be_confused).
    _CONFUSION_GROUPS = {
        'on': ['above', 'on top of', 'over'],
        'in': ['inside', 'within', 'contained in'],
        'behind': ['back of', 'rear of'],
        'in front of': ['front of', 'ahead of'],
        'above': ['over', 'on top of', 'on'],
        'below': ['under', 'beneath'],
        'next to': ['beside', 'adjacent to', 'near'],
        'left': ['left of', 'to the left'],
        'right': ['right of', 'to the right']
    }

    def __init__(self, classifier, loader, original_labels=None):
        """Store the collaborators and index the positive annotations.

        Args:
            classifier: Trained classifier exposing ``results``,
                ``label_encoder`` and ``get_comprehensive_results()``.
            loader: Dataset loader exposing ``images_data`` and
                ``_find_image_path(url)``.
            original_labels: Optional label sequence used for training; only
                its length is checked against the number of annotations.
        """
        self.classifier = classifier
        self.loader = loader
        self.original_labels = original_labels
        self.colors = {
            'subject': '#FF6B6B',
            'object': '#4ECDC4',
            'correct': '#2ECC71',
            'error': '#E74C3C',
            'text': '#2C3E50'
        }

        # Build the flat annotation index straight away.
        self.create_data_mapping()

    def create_data_mapping(self):
        """Flatten every positive annotation into ``self.all_annotations``.

        Also sets ``self.mapping_successful``: True only when original labels
        were supplied and there are at least as many annotations as labels.
        """
        print(" Création du mapping des données...")

        # Collect every annotation whose label is truthy (positive samples only).
        self.all_annotations = []
        for img_idx, img_data in enumerate(self.loader.images_data):
            for ann_idx, annotation in enumerate(img_data.annotations):
                if annotation.label:
                    self.all_annotations.append({
                        'img_data': img_data,
                        'annotation': annotation,
                        'img_idx': img_idx,
                        'ann_idx': ann_idx,
                        # Evaluated before the append, i.e. this entry's own position.
                        'global_idx': len(self.all_annotations)
                    })

        print(f" {len(self.all_annotations)} annotations trouvées")

        if self.original_labels is not None:
            print(f" Tentative de mapping avec {len(self.original_labels)} labels")

            # Length-only consistency check; it does not prove that
            # labels[i] belongs to all_annotations[i].
            if len(self.all_annotations) >= len(self.original_labels):
                self.mapping_successful = True
                print(" Mapping réussi")
            else:
                self.mapping_successful = False
                print(f" Problème de mapping: {len(self.all_annotations)} annotations vs {len(self.original_labels)} labels")
        else:
            self.mapping_successful = False
            print(" Pas de labels originaux fournis")

    def get_best_model_results(self):
        """Return ``(results_dict, summary_row)`` for the first row of the results table.

        The row is taken with ``iloc[0]`` and the results entry is looked up
        under the key ``"<Model>_<Features>"``.
        """
        # presumably the table is already sorted best-first by the classifier - verify
        best_results = self.classifier.get_comprehensive_results().iloc[0]
        model_key = f"{best_results['Model']}_{best_results['Features']}"
        return self.classifier.results[model_key], best_results

    def analyze_top_errors(self, top_k=5):
        """Rank the off-diagonal confusion-matrix cells and print the top ``top_k``.

        Returns:
            Tuple ``(top_errors, model_results)`` where ``top_errors`` is a list
            of dicts with keys ``true_class``, ``pred_class``, ``count``,
            ``true_idx`` and ``pred_idx``, sorted by decreasing count.
        """
        model_results, best_model_info = self.get_best_model_results()

        y_true = model_results['true_labels']
        y_pred = model_results['predictions']

        class_names = self.classifier.label_encoder.classes_

        # Pin the matrix to one row/column per encoded class. Without `labels`,
        # scikit-learn only keeps classes that actually occur, so a class
        # missing from this split would shift every index against class_names.
        cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(class_names)))

        # Off-diagonal, non-empty cells are the errors (row-major order).
        errors = []
        for i, j in np.argwhere(cm > 0):
            if i == j:
                continue
            errors.append({
                'true_class': class_names[i],
                'pred_class': class_names[j],
                'count': int(cm[i, j]),
                'true_idx': int(i),
                'pred_idx': int(j)
            })

        # Most frequent confusions first (stable sort keeps row-major ties).
        errors = sorted(errors, key=lambda x: x['count'], reverse=True)
        top_errors = errors[:top_k]

        print(f"=== TOP {top_k} ERREURS DE CLASSIFICATION ===")
        print(f"Modèle: {best_model_info['Model']} {best_model_info['Features']} - Accuracy: {best_model_info['Test_Accuracy']:.3f}")
        print(f"Total erreurs analysées: {len(errors)}\n")

        for i, error in enumerate(top_errors, 1):
            total_true = np.sum(cm[error['true_idx'], :])
            error_rate = error['count'] / total_true * 100 if total_true > 0 else 0
            print(f"{i}. '{error['true_class']}' prédit comme '{error['pred_class']}'")
            print(f"   Occurrences: {error['count']} ({error_rate:.1f}% des '{error['true_class']}')\n")

        return top_errors, model_results

    def find_error_examples_direct(self, error_info, model_results, max_examples=3):
        """Return up to ``max_examples`` annotations illustrating one confusion.

        Positions where ``true_labels == true_idx`` and
        ``predictions == pred_idx`` are located first. When the mapping is
        flagged as successful they are used to index ``self.all_annotations``;
        otherwise ``create_synthetic_examples`` supplies look-alike examples.
        """
        y_true = np.asarray(model_results['true_labels'])
        y_pred = np.asarray(model_results['predictions'])

        # Positions (within the evaluated split) of this exact confusion.
        positions = np.flatnonzero(
            (y_true == error_info['true_idx']) & (y_pred == error_info['pred_idx'])
        )

        # NOTE(review): these are positions inside the evaluated split, not
        # dataset rows. If the split was shuffled, indexing all_annotations
        # with them shows unrelated samples. When the classifier records the
        # dataset row of each evaluated sample under 'test_indices', translate
        # through it; otherwise the legacy direct indexing is kept - confirm
        # what the classifier actually stores.
        test_indices = model_results['test_indices'] if 'test_indices' in model_results else None
        if test_indices is not None:
            error_indices = [int(test_indices[p]) for p in positions]
        else:
            error_indices = [int(p) for p in positions]

        if self.mapping_successful and len(error_indices) > 0:
            examples = []
            for idx in error_indices[:max_examples]:
                if idx < len(self.all_annotations):
                    examples.append(self.all_annotations[idx])
                else:
                    print(f" Index {idx} hors limites")
            return examples

        # No usable mapping: fall back to annotations that merely look relevant.
        return self.create_synthetic_examples(error_info, max_examples)

    def create_synthetic_examples(self, error_info, max_examples=3):
        """Pick annotations whose predicate relates to the confused classes.

        Each (subject, object, predicate) combination is considered once.
        When nothing relates to either class, the first ``max_examples``
        annotations are returned as generic placeholders.
        """
        examples = []
        seen_combinations = set()

        for ann_data in self.all_annotations:
            if len(examples) >= max_examples:
                break

            annotation = ann_data['annotation']

            # Skip repeated subject/object/predicate triples.
            key = (annotation.subject.name, annotation.object.name, annotation.predicate)
            if key in seen_combinations:
                continue
            seen_combinations.add(key)

            # Keep only annotations that relate to the true or predicted class;
            # an unrelated first annotation is no longer forced into the result.
            if self.could_be_confused(annotation.predicate,
                                      error_info['true_class'],
                                      error_info['pred_class']):
                examples.append(ann_data)

        # Nothing relevant at all: show generic examples rather than nothing.
        if not examples and self.all_annotations:
            examples = self.all_annotations[:max_examples]
            print(f" Utilisation d'exemples génériques")

        return examples

    def could_be_confused(self, predicate, true_class, pred_class):
        """Tell whether ``predicate`` is one of the two classes or a listed synonym.

        A class always belongs to its own group, so a predicate equal to
        ``true_class`` or ``pred_class`` matches even when the class has an
        entry in the synonym table.
        """
        groups = self._CONFUSION_GROUPS
        candidates = {true_class, pred_class}
        candidates.update(groups.get(true_class, ()))
        candidates.update(groups.get(pred_class, ()))
        return predicate in candidates

    def load_original_image(self, img_data):
        """Load the image behind ``img_data.url`` as a NumPy array.

        Returns:
            The pixel array, or None when the file cannot be located or read
            (a message is printed in both cases).
        """
        try:
            image_path = self.loader._find_image_path(img_data.url)
            if image_path and image_path.exists():
                # The context manager closes the file handle; np.array() forces
                # the pixel data to be read before that happens.
                with Image.open(image_path) as image:
                    return np.array(image)
            else:
                print(f" Image non trouvée: {img_data.url}")
                return None
        except Exception as e:
            print(f" Erreur chargement image: {e}")
            return None

    def visualize_error_example(self, ann_data, true_class, pred_class, example_num, ax_bbox, ax_image):
        """Draw one example: schematic boxes on ``ax_bbox``, the picture on ``ax_image``.

        Args:
            ann_data: Entry of ``self.all_annotations`` (needs ``img_data`` and
                ``annotation``).
            true_class: Accepted for symmetry with ``pred_class``; the title
                shows the annotation's own predicate instead.
            pred_class: Predicted class shown in the titles.
            example_num: 1-based number shown in the titles.
            ax_bbox: Axes receiving the bounding-box schematic.
            ax_image: Axes receiving the original image, or an information
                placeholder when the image cannot be loaded.
        """
        img_data = ann_data['img_data']
        annotation = ann_data['annotation']

        # Image dimensions as recorded in the dataset.
        img_width, img_height = img_data.width, img_data.height

        # === SUBPLOT 1: BOUNDING BOXES ===
        ax_bbox.set_xlim(0, img_width)
        ax_bbox.set_ylim(img_height, 0)  # flip Y so the origin is top-left, as in images
        ax_bbox.set_aspect('equal')

        subject_bbox = annotation.subject.bbox
        object_bbox = annotation.object.bbox

        # Subject box
        subject_rect = patches.Rectangle(
            (subject_bbox.x1, subject_bbox.y1),
            subject_bbox.width, subject_bbox.height,
            linewidth=3, edgecolor=self.colors['subject'],
            facecolor=self.colors['subject'], alpha=0.3
        )
        ax_bbox.add_patch(subject_rect)

        # Object box
        object_rect = patches.Rectangle(
            (object_bbox.x1, object_bbox.y1),
            object_bbox.width, object_bbox.height,
            linewidth=3, edgecolor=self.colors['object'],
            facecolor=self.colors['object'], alpha=0.3
        )
        ax_bbox.add_patch(object_rect)

        # Name tags just above each box
        ax_bbox.text(subject_bbox.x1, subject_bbox.y1-10, f'S: {annotation.subject.name}',
               fontsize=10, color=self.colors['subject'], fontweight='bold',
               bbox=dict(boxstyle="round,pad=0.3", facecolor='white', alpha=0.8))

        ax_bbox.text(object_bbox.x1, object_bbox.y1-10, f'O: {annotation.object.name}',
               fontsize=10, color=self.colors['object'], fontweight='bold',
               bbox=dict(boxstyle="round,pad=0.3", facecolor='white', alpha=0.8))

        # Title: the annotation's predicate versus the wrong prediction.
        real_predicate = annotation.predicate
        ax_bbox.set_title(f'Exemple {example_num} - Annotations\n' +
                    f'Vraie relation: "{real_predicate}"\n' +
                    f'Prédiction incorrecte: {pred_class}',
                    fontsize=10, fontweight='bold',
                    color=self.colors['error'])

        # Image size in the top-left corner of the axes
        ax_bbox.text(0.02, 0.98, f'Image: {img_width}×{img_height}',
               transform=ax_bbox.transAxes, fontsize=8,
               verticalalignment='top',
               bbox=dict(boxstyle="round,pad=0.3", facecolor='lightgray', alpha=0.7))

        # Ticks add nothing to a schematic.
        ax_bbox.set_xticks([])
        ax_bbox.set_yticks([])

        # === SUBPLOT 2: ORIGINAL IMAGE ===
        original_image = self.load_original_image(img_data)

        if original_image is not None:
            ax_image.imshow(original_image)
            ax_image.set_title(f'Image Originale {example_num}', fontsize=10, fontweight='bold')

            # The file on disk may differ in size from the recorded dimensions:
            # rescale the boxes to the actual pixel grid.
            img_h, img_w = original_image.shape[:2]
            scale_x = img_w / img_width if img_width > 0 else 1
            scale_y = img_h / img_height if img_height > 0 else 1

            # Subject outline on the picture
            subject_rect_img = patches.Rectangle(
                (subject_bbox.x1 * scale_x, subject_bbox.y1 * scale_y),
                subject_bbox.width * scale_x, subject_bbox.height * scale_y,
                linewidth=2, edgecolor=self.colors['subject'],
                facecolor='none', alpha=0.8
            )
            ax_image.add_patch(subject_rect_img)

            # Object outline on the picture
            object_rect_img = patches.Rectangle(
                (object_bbox.x1 * scale_x, object_bbox.y1 * scale_y),
                object_bbox.width * scale_x, object_bbox.height * scale_y,
                linewidth=2, edgecolor=self.colors['object'],
                facecolor='none', alpha=0.8
            )
            ax_image.add_patch(object_rect_img)

        else:
            # No picture available: draw a textual placeholder instead.
            ax_image.set_xlim(0, 10)
            ax_image.set_ylim(0, 10)

            ax_image.add_patch(patches.Rectangle((0, 0), 10, 10,
                             facecolor='lightgray', alpha=0.3))

            info_text = f'Image: {img_data.url.split("/")[-1]}\n'
            info_text += f'Dimensions: {img_width}×{img_height}\n'
            info_text += f'Relation: {annotation.predicate}\n'
            info_text += f'Sujet: {annotation.subject.name}\n'
            info_text += f'Objet: {annotation.object.name}'

            ax_image.text(5, 5, info_text,
                         ha='center', va='center', fontsize=9,
                         bbox=dict(boxstyle="round,pad=0.5", facecolor='white', alpha=0.8))
            ax_image.set_title(f'Info Image {example_num}\n(Prédiction: {pred_class})', fontsize=10)

        ax_image.set_xticks([])
        ax_image.set_yticks([])

    def visualize_top_errors_with_examples(self, top_k=5, examples_per_error=3):
        """Plot examples for the ``top_k`` most frequent confusions.

        The figure uses two grid rows per error type (schematic boxes above,
        original images below) and ``examples_per_error`` columns.

        Returns:
            Tuple ``(top_errors, model_results)`` as produced by
            ``analyze_top_errors``.
        """
        print(" Analyse des erreurs en cours...")

        top_errors, model_results = self.analyze_top_errors(top_k)

        if not top_errors:
            print(" Aucune erreur trouvée!")
            return top_errors, model_results

        # Two rows per error type: bounding boxes, then original images.
        fig = plt.figure(figsize=(20, 8 * len(top_errors)))

        for error_idx, error_info in enumerate(top_errors):
            print(f"\n Recherche d'exemples pour: '{error_info['true_class']}' prédit comme '{error_info['pred_class']}'")

            examples = self.find_error_examples_direct(error_info, model_results, examples_per_error)

            if not examples:
                print(f" Aucun exemple trouvé pour cette erreur")
                continue

            for example_idx, example in enumerate(examples):
                # Upper row of the pair: bounding boxes.
                bbox_pos = (error_idx * 2) * examples_per_error + example_idx + 1
                ax_bbox = plt.subplot(len(top_errors) * 2, examples_per_error, bbox_pos)

                # Lower row of the pair: original image.
                img_pos = (error_idx * 2 + 1) * examples_per_error + example_idx + 1
                ax_image = plt.subplot(len(top_errors) * 2, examples_per_error, img_pos)

                try:
                    self.visualize_error_example(
                        example,
                        error_info['true_class'],
                        error_info['pred_class'],
                        example_idx + 1,
                        ax_bbox,
                        ax_image
                    )
                    print(f" Exemple {example_idx + 1} visualisé (bbox + image)")
                except Exception as e:
                    # Best effort: one broken example must not abort the figure.
                    print(f" Erreur lors de la visualisation: {e}")
                    ax_bbox.text(0.5, 0.5, f'Erreur bbox\n{str(e)[:30]}',
                               ha='center', va='center', transform=ax_bbox.transAxes)
                    ax_image.text(0.5, 0.5, f'Erreur image\n{str(e)[:30]}',
                                ha='center', va='center', transform=ax_image.transAxes)

        plt.suptitle(f'TOP {len(top_errors)} ERREURS DE CLASSIFICATION\nExemples avec Images Originales',
                    fontsize=16, fontweight='bold', y=0.98)
        plt.tight_layout()
        plt.subplots_adjust(top=0.95, hspace=0.3, wspace=0.2)
        plt.show()

        return top_errors, model_results

    def generate_error_summary_report(self, top_errors, model_results):
        """Print overall error statistics and the share taken by each top error.

        Args:
            top_errors: List of error dicts as returned by ``analyze_top_errors``.
            model_results: Dict holding ``true_labels`` and ``predictions``
                (arrays or plain sequences of equal length).
        """
        print("\n" + "="*80)
        print(" RAPPORT DE SYNTHÈSE DES ERREURS")
        print("="*80)

        # Element-wise comparison needs arrays: `list != list` is a single bool.
        y_true = np.asarray(model_results['true_labels'])
        y_pred = np.asarray(model_results['predictions'])
        total_samples = len(y_true)
        total_errors = int(np.sum(y_true != y_pred))

        # Guard every ratio so an empty evaluation set does not raise.
        error_pct = total_errors / total_samples * 100 if total_samples > 0 else 0.0
        accuracy = 1 - total_errors / total_samples if total_samples > 0 else float('nan')

        print(f" Statistiques générales:")
        print(f"   • Total échantillons: {total_samples:,}")
        print(f"   • Total erreurs: {total_errors:,} ({error_pct:.1f}%)")
        print(f"   • Accuracy: {accuracy:.3f}")

        print(f"\n🎯 Top {len(top_errors)} types d'erreurs:")
        total_top_errors = sum(error['count'] for error in top_errors)

        for i, error in enumerate(top_errors, 1):
            pct_of_total_errors = error['count'] / total_errors * 100 if total_errors > 0 else 0
            pct_of_samples = error['count'] / total_samples * 100 if total_samples > 0 else 0

            print(f"   {i}. '{error['true_class']}' prédit comme '{error['pred_class']}'")
            print(f"      └ {error['count']} erreurs ({pct_of_total_errors:.1f}% des erreurs, {pct_of_samples:.2f}% du dataset)")

        if total_errors > 0:
            print(f"\n Ces {len(top_errors)} types représentent {total_top_errors/total_errors*100:.1f}% de toutes les erreurs")
        print("="*80)


# ============================================================================
# EXÉCUTION DE L'ANALYSE
# ============================================================================

def run_error_analysis_with_visualization():
    """Run the full error analysis (ranking, figures, text report).

    Relies on notebook-level names: ``Classifier`` and ``loader`` are
    required, ``labels`` is used when present.

    Returns:
        Tuple ``(analyzer, top_errors, model_results)``. When a required name
        is missing the tuple is ``(None, None, None)`` so that callers which
        unpack three values keep working.
    """

    print(" Démarrage de l'analyse d'erreurs avec visualisation...")

    # Check that the earlier cells have been executed.
    required_vars = ['Classifier', 'loader']
    missing_vars = [var for var in required_vars if var not in globals()]

    if missing_vars:
        print(f" Variables manquantes: {missing_vars}")
        print("Assurez-vous d'avoir exécuté le pipeline complet d'abord")
        # Same arity as the success path: the notebook unpacks three values.
        return None, None, None

    # Pick up the original labels when the notebook defined them.
    original_labels = None
    if 'labels' in globals():
        original_labels = labels
        print(" Labels originaux trouvés")
    else:
        print(" Labels originaux non trouvés")

    analyzer = SpatialErrorAnalyzer(Classifier, loader, original_labels)

    # Ranking + figures
    print("\n" + "="*60)
    top_errors, model_results = analyzer.visualize_top_errors_with_examples(
        top_k=5,
        examples_per_error=3
    )

    # Text report
    analyzer.generate_error_summary_report(top_errors, model_results)

    print("\n Analyse d'erreurs terminée!")
    return analyzer, top_errors, model_results

# Launch the analysis; the call result is unpacked into three names.
print(" Analyse des erreurs de classification...")
error_analyzer, top_errors, error_results = run_error_analysis_with_visualization()

In [None]:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import time
import warnings

from sklearn.neural_network import MLPClassifier
from sklearn.svm import SVC
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import (classification_report, confusion_matrix,
                             accuracy_score, f1_score, log_loss)
from sklearn.model_selection import train_test_split, GridSearchCV

from transformers import BertTokenizer, BertModel
import torch
import torch.nn as nn

# Silence every warning for the rest of the session.
warnings.filterwarnings('ignore')

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# BERT EXTRACTEUR DE FEATURES
class BERTFeatureExtractor:
    """Frozen BERT encoder followed by a small projection head.

    NOTE(review): the projection layers are freshly initialised and never
    trained in this class, so the output is a fixed random projection of the
    [CLS] embedding. It also depends on the torch RNG state at construction
    time - confirm this is intended.
    """

    def __init__(self, model_name='bert-base-uncased', output_dim=128):
        """Load tokenizer and model, freeze BERT and build the projection head.

        Args:
            model_name: Name passed to ``from_pretrained``.
            output_dim: Size of each returned feature vector.
        """
        self.output_dim = output_dim
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.model = BertModel.from_pretrained(model_name).to(device)
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.eval()

        self.projection = nn.Sequential(
            nn.Linear(self.model.config.hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, output_dim),
            nn.ReLU(),
            nn.Dropout(0.1)
        ).to(device)
        # Modules start in training mode, and no_grad() does not disable
        # Dropout. Without eval() every encode() call would randomly zero
        # features and return different vectors for the same text.
        self.projection.eval()

    def encode(self, texts, max_length=64, batch_size=32):
        """Encode texts into ``(len(texts), output_dim)`` feature vectors.

        Args:
            texts: A string or a sequence of strings.
            max_length: Truncation length passed to the tokenizer.
            batch_size: Number of texts per forward pass; bounds peak memory.

        Returns:
            NumPy array with one row per input text (zero rows for empty input).
        """
        if isinstance(texts, str):
            texts = [texts]
        texts = list(texts)
        if not texts:
            return np.empty((0, self.output_dim), dtype=np.float32)

        step = max(1, int(batch_size))
        chunks = []
        with torch.no_grad():
            for start in range(0, len(texts), step):
                batch = texts[start:start + step]
                inputs = self.tokenizer(batch, return_tensors='pt', padding=True,
                                        truncation=True, max_length=max_length).to(device)
                outputs = self.model(**inputs)
                # First token of the last hidden state: the [CLS] position.
                cls_embeddings = outputs.last_hidden_state[:, 0, :]
                projected = self.projection(cls_embeddings)
                chunks.append(projected.cpu().numpy())
        return np.concatenate(chunks, axis=0)

# CLASSIFICATEUR COMPLET
class AdvancedSpatialClassifier:
    """Trains one MLP per feature configuration (RLM, PHI, combined, +BERT).

    Results are stored per configuration in ``self.results`` under the key
    ``"MLP_<feature_mode>"``; fitted models and scalers are kept in
    ``self.models`` and ``self.scalers``.
    """

    def __init__(self, random_state=42):
        """Create empty registries and the BERT feature extractor."""
        self.random_state = random_state
        self.scalers = {}
        self.label_encoder = LabelEncoder()
        self.models = {}
        self.training_history = {}
        self.results = {}
        self.bert_extractor = BERTFeatureExtractor(output_dim=128)

    def create_train_val_test_split(self, X, y, train_size=0.7, val_size=0.15, test_size=0.15):
        """Stratified train/validation/test split done in two steps.

        ``train_size`` is accepted but not read: the training share is whatever
        remains after ``val_size + test_size`` has been held out.

        Returns:
            ``X_train, X_val, X_test, y_train, y_val, y_test``.
        """
        X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=(val_size + test_size),
                                                            random_state=self.random_state, stratify=y)
        X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp,
                                                        test_size=test_size / (val_size + test_size),
                                                        random_state=self.random_state, stratify=y_temp)
        return X_train, X_val, X_test, y_train, y_val, y_test

    def prepare_features(self, rlm_features, phi_features, texts=None, feature_mode='combined'):
        """Assemble the feature matrix for one configuration.

        Args:
            rlm_features: RLM feature matrix.
            phi_features: PHI feature matrix.
            texts: Texts to embed; required only for ``'with_bert'``.
            feature_mode: ``'rlm'``, ``'phi'``, ``'combined'`` or ``'with_bert'``.

        Raises:
            ValueError: For an unknown ``feature_mode``.
            AssertionError: For ``'with_bert'`` without ``texts``.
        """
        if feature_mode == 'rlm':
            return rlm_features
        elif feature_mode == 'phi':
            return phi_features
        elif feature_mode == 'combined':
            return np.concatenate([rlm_features, phi_features], axis=1)
        elif feature_mode == 'with_bert':
            assert texts is not None, "texts are required when feature_mode='with_bert'"
            bert_feats = self.bert_extractor.encode(texts)
            return np.concatenate([rlm_features, phi_features, bert_feats], axis=1)
        else:
            raise ValueError("feature_mode must be 'rlm', 'phi', 'combined', or 'with_bert'")

    def train_mlp_and_svm(self, X, y, feature_mode):
        """Fit and evaluate an MLP on ``X``/``y`` for one feature configuration.

        Despite the name, only an MLP is trained here. The scaler is fitted on
        the training split only and then applied to validation and test.

        Returns:
            Test accuracy of the fitted MLP.
        """
        X_train, X_val, X_test, y_train, y_val, y_test = self.create_train_val_test_split(X, y)

        scaler = StandardScaler()
        X_train_scaled = scaler.fit_transform(X_train)
        X_val_scaled = scaler.transform(X_val)
        X_test_scaled = scaler.transform(X_test)

        self.scalers[feature_mode] = scaler

        mlp = MLPClassifier(hidden_layer_sizes=(512,256 ,128, 64), activation='relu',
                            solver='adam', alpha=0.1, batch_size=32,
                            learning_rate='adaptive', learning_rate_init=0.0001,
                            max_iter=500, random_state=self.random_state)
        mlp.fit(X_train_scaled, y_train)
        y_pred = mlp.predict(X_test_scaled)
        acc = accuracy_score(y_test, y_pred)
        # The validation split used to be scaled and then ignored; report it so
        # configurations can be compared without looking at the test set.
        val_acc = accuracy_score(y_val, mlp.predict(X_val_scaled))

        self.models[f'MLP_{feature_mode}'] = mlp
        self.results[f'MLP_{feature_mode}'] = {
            'test_accuracy': acc,
            'val_accuracy': val_acc,
            'predictions': y_pred,
            'true_labels': y_test
        }

        return acc

    def train_all_configurations(self, rlm_features, phi_features, labels, texts):
        """Train one MLP per feature mode and return a summary table.

        Returns:
            DataFrame with columns ``Model``, ``Features`` and ``Test_Accuracy``,
            one row per feature mode in training order.
        """
        print("=== TRAINING ALL CONFIGURATIONS ===")

        y = self.label_encoder.fit_transform(labels)
        modes = ['rlm', 'phi', 'combined', 'with_bert']
        summary = []

        for mode in modes:
            print(f"--- {mode.upper()} ---")
            X = self.prepare_features(rlm_features, phi_features, texts, mode)
            acc = self.train_mlp_and_svm(X, y, mode)
            summary.append({'Model': 'MLP', 'Features': mode, 'Test_Accuracy': acc})

        return pd.DataFrame(summary)

    def plot_confusion_matrices(self):
        """Show one row-normalised confusion matrix per trained configuration."""
        # classes_ only exists once the encoder has been fitted.
        class_names = getattr(self.label_encoder, 'classes_', None)

        for model_key, res in self.results.items():
            if class_names is not None:
                # One row/column per encoded class, even if a class is missing
                # from this split, so the tick labels line up.
                cm = confusion_matrix(res['true_labels'], res['predictions'],
                                      labels=np.arange(len(class_names)))
                tick_labels = list(class_names)
            else:
                cm = confusion_matrix(res['true_labels'], res['predictions'])
                tick_labels = 'auto'

            # Rows without any true sample stay at 0 instead of becoming NaN.
            row_totals = cm.sum(axis=1, keepdims=True)
            cm_norm = np.divide(cm, row_totals,
                                out=np.zeros(cm.shape, dtype=float),
                                where=row_totals > 0)

            plt.figure(figsize=(8, 6))
            sns.heatmap(cm_norm, annot=True, fmt='.2f', cmap='Blues',
                        xticklabels=tick_labels, yticklabels=tick_labels)
            plt.title(f"{model_key} - Confusion Matrix")
            plt.xlabel("Predicted")
            plt.ylabel("True")
            plt.show()

# UTILISATION
# ===========

def generate_basic_texts(subjects, objects):
    """Build one ``"<subject> <object>"`` string per aligned pair.

    Pairs are formed with ``zip``, so the result is as long as the shorter
    of the two inputs.
    """
    texts = []
    for subject_name, object_name in zip(subjects, objects):
        texts.append(f"{subject_name} {object_name}")
    return texts

def run_classifier_with_bert(rlm_features, phi_features, labels, subjects, objects):
    """Train all feature configurations, including the BERT-augmented one.

    The BERT input for each sample is the text built from its subject and
    object names.

    Returns:
        Tuple ``(classifier, summary)``: the trained classifier and the
        summary table returned by ``train_all_configurations``.
    """
    pair_texts = generate_basic_texts(subjects, objects)

    model = AdvancedSpatialClassifier()
    summary_table = model.train_all_configurations(
        rlm_features, phi_features, labels, pair_texts
    )

    print(summary_table)
    model.plot_confusion_matrices()

    return model, summary_table

# Train every configuration on the feature arrays already present in the kernel.
# NOTE: this rebinds the notebook-level name `results` to the summary table returned here.
classifier, results = run_classifier_with_bert(rlm_features, phi_features, labels, subjects, objects)

In [None]:
# ============================================================================
# CODE COMPLET CORRIGÉ - RELATIONS DIRECTIONNELLES SANS DATA LEAKAGE
# ============================================================================

!pip install -q sentence-transformers transformers torch scikit-learn

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import random
import torch
import time
import warnings
import pickle
import os
from sentence_transformers import SentenceTransformer
from sklearn.neural_network import MLPClassifier
from sklearn.svm import SVC
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import (classification_report, confusion_matrix,
                           accuracy_score, f1_score, log_loss)
from sklearn.model_selection import train_test_split, GridSearchCV
from collections import Counter, defaultdict
from typing import List, Dict
from tqdm import tqdm

warnings.filterwarnings('ignore')

# Global configuration
GLOBAL_SEED = 42

def set_all_seeds(seed=GLOBAL_SEED):
    """Seed NumPy, ``random`` and torch (CPU and CUDA), and pin cuDNN settings.

    Args:
        seed: Value handed to every seeding function.
    """
    seeders = (
        np.random.seed,
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Deterministic kernels on, autotuner off.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

set_all_seeds(GLOBAL_SEED)

# ============================================================================
# 1. FILTRAGE RELATIONS DIRECTIONNELLES BASIQUES
# ============================================================================

class BasicDirectionalFilter:
    """Keeps a fixed set of spatial predicates and maps them to coarse classes.

    ``mode='4_classes'`` merges predicates into ABOVE/BELOW/BESIDE/DEPTH; any
    other mode value selects the seven-class mapping.
    """

    _MAPPING_4_CLASSES = {
        'on': 'ABOVE',
        'above': 'ABOVE',
        'under': 'BELOW',
        'next to': 'BESIDE',
        'to the left of': 'BESIDE',
        'to the right of': 'BESIDE',
        'behind': 'DEPTH',
        'in front of': 'DEPTH',
    }
    _TARGETS_4_CLASSES = ['ABOVE', 'BELOW', 'BESIDE', 'DEPTH']

    _MAPPING_7_CLASSES = {
        'on': 'ABOVE',
        'above': 'ABOVE',
        'under': 'BELOW',
        'to the left of': 'LEFT',
        'to the right of': 'RIGHT',
        'next to': 'BESIDE',
        'behind': 'BEHIND',
        'in front of': 'FRONT',
    }
    _TARGETS_7_CLASSES = ['ABOVE', 'BELOW', 'LEFT', 'RIGHT', 'BESIDE', 'BEHIND', 'FRONT']

    def __init__(self, mode='4_classes'):
        """Select the predicate-to-class mapping and print it.

        Args:
            mode: ``'4_classes'`` for the merged mapping; every other value
                falls through to the seven-class mapping.
        """
        self.mode = mode

        # Copies, so per-instance tweaks never leak into the class constants.
        if mode == '4_classes':
            self.relation_mapping = dict(self._MAPPING_4_CLASSES)
            self.target_relations = list(self._TARGETS_4_CLASSES)
        else:
            self.relation_mapping = dict(self._MAPPING_7_CLASSES)
            self.target_relations = list(self._TARGETS_7_CLASSES)

        print(f"Mode: {mode} - {len(self.target_relations)} classes")
        print("Mapping:")
        for orig, target in self.relation_mapping.items():
            print(f"  {orig} -> {target}")

    @staticmethod
    def _as_indexable(values):
        """Turn plain lists/tuples into arrays so boolean-mask indexing works."""
        if isinstance(values, (list, tuple)):
            return np.asarray(values)
        return values

    def filter_and_map_relations(self, rlm_features, phi_features, labels, splits=None):
        """Drop samples with unmapped predicates and relabel the remaining ones.

        Args:
            rlm_features: Per-sample RLM features (array or list).
            phi_features: Per-sample PHI features (array or list).
            labels: Original predicate of every sample (array or list).
            splits: Optional per-sample split tags, filtered alongside.

        Returns:
            ``(rlm, phi, mapped_labels)``, or
            ``(rlm, phi, mapped_labels, splits)`` when ``splits`` is given.
        """
        # Boolean-mask indexing below needs arrays; plain lists used to raise.
        rlm_features = self._as_indexable(rlm_features)
        phi_features = self._as_indexable(phi_features)
        labels = np.asarray(labels)
        if splits is not None:
            splits = self._as_indexable(splits)

        print(f"\nFiltrage dataset original: {len(labels)} échantillons")

        # Original label distribution
        original_counts = Counter(labels)
        print("Distribution originale (top 10):")
        for rel, count in original_counts.most_common(10):
            print(f"  {rel}: {count}")

        # Explicit bool dtype: an empty input would otherwise give a float mask.
        keep_mask = np.array([label in self.relation_mapping for label in labels], dtype=bool)

        filtered_rlm = rlm_features[keep_mask]
        filtered_phi = phi_features[keep_mask]
        filtered_labels_original = labels[keep_mask]
        filtered_splits = splits[keep_mask] if splits is not None else None

        # Map to the coarse classes (string dtype even when nothing is kept).
        mapped_labels = np.array(
            [self.relation_mapping[label] for label in filtered_labels_original],
            dtype=str,
        )

        # Final distribution
        final_counts = Counter(mapped_labels)
        print(f"\nDataset filtré: {len(mapped_labels)} échantillons")
        total = len(mapped_labels)
        for relation in self.target_relations:
            count = final_counts.get(relation, 0)
            pct = count/total*100 if total > 0 else 0
            print(f"  {relation}: {count} ({pct:.1f}%)")

        # Class balance over every *target* class. A Counter only holds classes
        # that occur, so a missing class was ignored before and an empty
        # selection crashed on max([]); now both report an infinite ratio.
        counts = [final_counts.get(relation, 0) for relation in self.target_relations]
        ratio = max(counts) / min(counts) if min(counts) > 0 else float('inf')
        print(f"Ratio déséquilibre: {ratio:.1f}")

        if filtered_splits is not None:
            return filtered_rlm, filtered_phi, mapped_labels, filtered_splits
        else:
            return filtered_rlm, filtered_phi, mapped_labels

# ============================================================================
# 2. BERT CONTEXTUEL (SANS DATA LEAKAGE)
# ============================================================================

class ContextualBERTEmbedder:
    """Sentence embeddings of label-free descriptions derived from the features."""

    # Number of interval relations read per axis from the PHI vector
    # (the horizontal block is followed by the vertical block).
    _RELATIONS_PER_AXIS = 13
    # Dominant indices that the description calls "overlapping".
    _OVERLAP_INDICES = (2, 6, 10)

    def __init__(self, model_name='all-MiniLM-L6-v2'):
        """Load the sentence-transformer model named ``model_name``."""
        print(f"Chargement BERT contextuel: {model_name}")
        self.model = SentenceTransformer(model_name)

    def create_contextual_descriptions(self, rlm_features, phi_features):
        """Build one generic English description per sample from its features.

        The text depends only on the dominant RLM index and on the dominant
        horizontal/vertical PHI indices; the target label is never used.

        Returns:
            List of description strings, one per row of ``rlm_features``.
        """
        n_rel = self._RELATIONS_PER_AXIS
        descriptions = []

        for i in range(len(rlm_features)):
            # RLM: index of the strongest response.
            rlm_vec = rlm_features[i]
            if hasattr(rlm_vec, '__len__') and len(rlm_vec) >= 120:
                # argsort (not argmax) is kept so tie-breaking stays unchanged.
                dominant_dir = np.argsort(rlm_vec)[-1]

                # Bucket boundaries as in the original notebook.
                # NOTE(review): the bucket-to-direction wording assumes 120
                # evenly spaced directions - confirm against the RLM extractor.
                if dominant_dir < 30:
                    spatial_desc = "rightward spatial configuration"
                elif dominant_dir < 60:
                    spatial_desc = "upward spatial configuration"
                elif dominant_dir < 90:
                    spatial_desc = "leftward spatial configuration"
                else:
                    spatial_desc = "downward spatial configuration"
            else:
                spatial_desc = "spatial configuration"

            # PHI: dominant interval relation on each axis.
            phi_vec = phi_features[i]
            if hasattr(phi_vec, '__len__') and len(phi_vec) >= 2 * n_rel:
                h_dominant = np.argmax(phi_vec[:n_rel])
                # Stop at 2 * n_rel: trailing extra entries of a longer vector
                # must not be mistaken for vertical relations.
                # presumably PHI is 13 horizontal + 13 vertical Allen relations - verify
                v_dominant = np.argmax(phi_vec[n_rel:2 * n_rel])

                h_desc = "overlapping" if h_dominant in self._OVERLAP_INDICES else "adjacent"
                v_desc = "overlapping" if v_dominant in self._OVERLAP_INDICES else "separated"

                interval_desc = f"{h_desc} horizontally and {v_desc} vertically"
            else:
                interval_desc = "with interval relationships"

            # Final description; it never mentions the target relation.
            desc = f"Objects in {spatial_desc} {interval_desc} with geometric properties"
            descriptions.append(desc)

        return descriptions

    def get_contextual_embeddings(self, rlm_features, phi_features):
        """Embed the contextual description of every sample.

        Only a handful of distinct descriptions can occur, so each distinct
        sentence is encoded once and its vector is reused for every sample
        that shares it.

        Returns:
            Array with one embedding row per sample.
        """
        descriptions = self.create_contextual_descriptions(rlm_features, phi_features)
        if not descriptions:
            return self.model.encode(descriptions, show_progress_bar=False, convert_to_numpy=True)

        # Order-preserving de-duplication.
        unique_descriptions = list(dict.fromkeys(descriptions))
        row_of = {text: row for row, text in enumerate(unique_descriptions)}

        unique_embeddings = self.model.encode(
            unique_descriptions, show_progress_bar=False, convert_to_numpy=True
        )
        return unique_embeddings[[row_of[text] for text in descriptions]]

# ============================================================================
# 3. CLASSIFICATEUR CORRIGÉ
# ============================================================================

class CorrectedDirectionalClassifier:
    """Compare several classifiers on RLM / PHI features, optionally with text embeddings.

    For every feature configuration the samples are split once (70/15/15,
    stratified on the label), standardized with statistics fitted on the
    training split only, and four models (two MLPs, two SVMs) are trained.

    Attributes:
        random_state: Seed forwarded to the splits, the models and ``set_all_seeds``.
        use_contextual_bert: Whether the extra "RLM+PHI+BERT contextuel"
            configuration is evaluated.
        label_encoder: ``LabelEncoder`` fitted on the labels in ``train_and_evaluate``.
        models: ``{"<config>_<model>": {"model": ..., "scaler": ...}}``.
        results: ``{"<config>_<model>": {"predictions", "true_labels",
            "accuracy", "config"}}`` computed on the test split.
        bert_embedder: ``ContextualBERTEmbedder`` instance, or ``None`` when
            ``use_contextual_bert`` is False.
    """

    def __init__(self, random_state=GLOBAL_SEED, use_contextual_bert=False):
        self.random_state = random_state
        self.use_contextual_bert = use_contextual_bert
        self.label_encoder = LabelEncoder()
        self.models = {}
        self.results = {}

        # Always define the attribute so it can be inspected safely even when
        # the BERT configuration is disabled.
        self.bert_embedder = ContextualBERTEmbedder() if use_contextual_bert else None

        set_all_seeds(random_state)

    def _build_features(self, feature_mode, rlm_features, phi_features):
        """Assemble the feature matrix for one configuration name."""
        if feature_mode == 'rlm_only':
            return rlm_features
        if feature_mode == 'phi_only':
            return phi_features
        if feature_mode == 'combined':
            return np.concatenate([rlm_features, phi_features], axis=1)
        if feature_mode == 'combined_bert_contextual':
            # Contextual embeddings are derived from the features, never from the labels.
            print("  Génération BERT contextuel...")
            bert_contextual = self.bert_embedder.get_contextual_embeddings(rlm_features, phi_features)
            return np.concatenate([rlm_features, phi_features, bert_contextual], axis=1)
        raise ValueError(f"Unknown feature mode: {feature_mode!r}")

    def _make_models(self):
        """Return fresh, unfitted ``(name, estimator)`` pairs to evaluate."""
        return [
            ('MLP_small', MLPClassifier(
                hidden_layer_sizes=(64, 32),
                random_state=self.random_state,
                max_iter=300,
                alpha=0.01
            )),
            ('MLP_large', MLPClassifier(
                hidden_layer_sizes=(128, 64, 32),
                random_state=self.random_state,
                max_iter=500,
                alpha=0.001
            )),
            ('SVM_rbf', SVC(
                C=10, gamma='scale',
                random_state=self.random_state
            )),
            ('SVM_linear', SVC(
                C=1, kernel='linear',
                random_state=self.random_state
            ))
        ]

    def train_and_evaluate(self, rlm_features, phi_features, labels):
        """Train every (configuration, model) pair and report validation/test accuracy.

        Args:
            rlm_features: 2-D array-like, one row per sample.
            phi_features: 2-D array-like, one row per sample.
            labels: Sequence of class labels, one per sample.

        Returns:
            pandas.DataFrame with columns ``Config``, ``Model``, ``Features``,
            ``Val_Acc``, ``Test_Acc`` and ``BERT``, sorted best-first by
            **validation** accuracy. The first row is therefore the model
            selected without looking at the test split, and its ``Test_Acc``
            is an unbiased estimate.

        Raises:
            ValueError: If the three inputs do not have the same number of samples.
        """
        print("\n=== ENTRAÎNEMENT CORRIGÉ (SANS DATA LEAKAGE) ===")

        # Index-array slicing and ``.shape`` below require real ndarrays.
        rlm_features = np.asarray(rlm_features)
        phi_features = np.asarray(phi_features)
        if not (len(rlm_features) == len(phi_features) == len(labels)):
            raise ValueError(
                "rlm_features, phi_features and labels must have the same number of samples "
                f"(got {len(rlm_features)}, {len(phi_features)}, {len(labels)})"
            )

        y = self.label_encoder.fit_transform(labels)
        unique_labels = self.label_encoder.classes_
        n_classes = len(unique_labels)

        print(f"Classes: {n_classes}")
        for i, label in enumerate(unique_labels):
            count = np.sum(y == i)
            print(f"  {label}: {count}")

        # Fixed-seed split: 70% train, then the remaining 30% halved into val/test.
        set_all_seeds(self.random_state)
        train_idx, temp_idx = train_test_split(
            np.arange(len(y)), test_size=0.3, random_state=self.random_state, stratify=y
        )
        val_idx, test_idx = train_test_split(
            temp_idx, test_size=0.5, random_state=self.random_state, stratify=y[temp_idx]
        )

        print(f"Split: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}")

        results_summary = []

        # Feature configurations (none of them uses the labels as input).
        configs = [
            ('rlm_only', 'RLM seulement'),
            ('phi_only', 'PHI seulement'),
            ('combined', 'RLM+PHI'),
        ]

        if self.use_contextual_bert:
            configs.append(('combined_bert_contextual', 'RLM+PHI+BERT contextuel'))

        for feature_mode, description in configs:
            print(f"\nConfig: {description}")

            X = self._build_features(feature_mode, rlm_features, phi_features)
            print(f"  Features shape: {X.shape}")

            X_train, X_val, X_test = X[train_idx], X[val_idx], X[test_idx]
            y_train, y_val, y_test = y[train_idx], y[val_idx], y[test_idx]

            # Scaling statistics come from the training split only.
            scaler = StandardScaler()
            X_train_scaled = scaler.fit_transform(X_train)
            X_val_scaled = scaler.transform(X_val)
            X_test_scaled = scaler.transform(X_test)

            for model_name, model in self._make_models():
                set_all_seeds(self.random_state)
                model.fit(X_train_scaled, y_train)

                val_pred = model.predict(X_val_scaled)
                val_acc = accuracy_score(y_val, val_pred)
                test_pred = model.predict(X_test_scaled)
                test_acc = accuracy_score(y_test, test_pred)

                print(f"    {model_name}: Val={val_acc:.3f}, Test={test_acc:.3f}")

                results_summary.append({
                    'Config': description,
                    'Model': model_name,
                    'Features': feature_mode,
                    'Val_Acc': val_acc,
                    'Test_Acc': test_acc,
                    'BERT': 'contextual' in feature_mode
                })

                # Keep the fitted model (with its scaler) and its test predictions.
                key = f"{description}_{model_name}"
                self.models[key] = {'model': model, 'scaler': scaler}
                self.results[key] = {
                    'predictions': test_pred,
                    'true_labels': y_test,
                    'accuracy': test_acc,
                    'config': description
                }

        # Rank on the validation split: choosing the winner by test accuracy
        # would make the reported test score optimistically biased.
        # A stable sort keeps ties in evaluation order, so the ranking is deterministic.
        results_df = pd.DataFrame(results_summary)
        results_df = results_df.sort_values('Val_Acc', ascending=False, kind='mergesort')

        print("\n=== RÉSULTATS CORRIGÉS ===")
        print(results_df.to_string(index=False, float_format='%.3f'))

        best_result = results_df.iloc[0]
        print(f"\nMeilleur: {best_result['Config']} {best_result['Model']} - {best_result['Test_Acc']:.3f}")

        # Coarse interpretation of the selected model's test accuracy.
        print("\n=== ANALYSE ===")
        best_acc = best_result['Test_Acc']
        if best_acc > 0.95:
            print("ATTENTION: Accuracy suspecte, vérifier data leakage")
        elif best_acc > 0.8:
            # Report the actual number of classes instead of assuming four.
            print(f"Excellent résultat pour classification {n_classes} classes")
        elif best_acc > 0.6:
            print("Bon résultat, features discriminantes")
        else:
            print("Résultat modéré, amélioration possible")

        # Baseline vs. BERT: compare the validation-selected model of each family.
        baseline_results = results_df[~results_df['BERT']]
        bert_results = results_df[results_df['BERT']]

        if len(baseline_results) > 0 and len(bert_results) > 0:
            best_baseline = baseline_results.iloc[0]['Test_Acc']
            best_bert = bert_results.iloc[0]['Test_Acc']
            improvement = (best_bert - best_baseline) * 100
            print(f"Amélioration BERT contextuel: {improvement:+.1f} points")

        return results_df

    def plot_confusion_matrix(self, config_key):
        """Plot the test-split confusion matrix stored under ``config_key``.

        Args:
            config_key: Key of ``self.results``, i.e. ``"<config>_<model>"``.
                If unknown, the first available keys are printed and nothing is drawn.
        """
        if config_key not in self.results:
            available = list(self.results.keys())
            print(f"Config non trouvée. Disponibles: {available[:3]}...")
            return

        result = self.results[config_key]
        y_true = result['true_labels']
        y_pred = result['predictions']

        # Pin the label set so the matrix always has one row/column per known
        # class, even if a class is missing from this particular test split;
        # otherwise the tick labels would not line up with the matrix.
        class_names = self.label_encoder.classes_
        cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(class_names)))

        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=class_names,
                    yticklabels=class_names)

        config_name = result['config']
        acc = result['accuracy']
        plt.title(f'Matrice Confusion - {config_name}\nAccuracy: {acc:.3f}')
        plt.ylabel('True')
        plt.xlabel('Predicted')
        plt.tight_layout()
        plt.show()

# ============================================================================
# 4. PIPELINE COMPLET
# ============================================================================

def run_complete_corrected_pipeline():
    """Run filtering, baseline training, BERT-augmented training and the final comparison.

    Reads the notebook-level globals ``rlm_features``, ``phi_features`` and
    ``labels`` (plus ``splits`` when it is defined) instead of taking arguments.

    Returns:
        dict | None: ``None`` when one of the required globals is missing.
        Otherwise a dict with the keys ``filtered_data``, ``baseline_classifier``,
        ``bert_classifier``, ``baseline_results`` and ``bert_results``.
    """
    print("=== PIPELINE COMPLET CORRIGÉ ===")

    # The inputs are expected to have been produced by earlier notebook cells.
    required = ('rlm_features', 'phi_features', 'labels')
    if any(name not in globals() for name in required):
        print("ERREUR: Variables rlm_features, phi_features, labels manquantes")
        print("Exécutez d'abord l'extraction des features")
        return None

    print(f"Dataset original: {len(labels)} échantillons, {len(set(labels))} relations")

    # Step 1: keep only the 4 directional classes.
    print("\n--- ÉTAPE 1: FILTRAGE ---")
    basic_filter = BasicDirectionalFilter(mode='4_classes')

    if 'splits' in globals():
        # With splits the filter returns a fourth value, which is not used here.
        filtered_rlm, filtered_phi, filtered_labels, _ = basic_filter.filter_and_map_relations(
            rlm_features, phi_features, labels, splits
        )
    else:
        filtered_rlm, filtered_phi, filtered_labels = basic_filter.filter_and_map_relations(
            rlm_features, phi_features, labels
        )

    # Step 2: baseline without BERT.
    print("\n--- ÉTAPE 2: BASELINE SANS BERT ---")
    classifier_baseline = CorrectedDirectionalClassifier(use_contextual_bert=False)
    results_baseline = classifier_baseline.train_and_evaluate(
        filtered_rlm, filtered_phi, filtered_labels
    )

    # Step 3: same training with the contextual BERT configuration enabled.
    print("\n--- ÉTAPE 3: AVEC BERT CONTEXTUEL ---")
    classifier_bert = CorrectedDirectionalClassifier(use_contextual_bert=True)
    results_bert = classifier_bert.train_and_evaluate(
        filtered_rlm, filtered_phi, filtered_labels
    )

    # Step 4: final comparison (both tables are sorted best-first).
    print("\n--- ÉTAPE 4: COMPARAISON FINALE ---")
    best_baseline = results_baseline.iloc[0]

    # ``results_bert`` also contains the non-BERT configurations, so its first
    # row may be a plain baseline model. Restrict the choice to rows that really
    # use BERT; otherwise "Meilleur BERT" and the improvement would be mislabeled.
    bert_only = results_bert[results_bert['BERT'].astype(bool)]
    best_bert = bert_only.iloc[0] if len(bert_only) > 0 else results_bert.iloc[0]

    print(f"Meilleur baseline: {best_baseline['Config']} {best_baseline['Model']} - {best_baseline['Test_Acc']:.3f}")
    print(f"Meilleur BERT: {best_bert['Config']} {best_bert['Model']} - {best_bert['Test_Acc']:.3f}")

    improvement = (best_bert['Test_Acc'] - best_baseline['Test_Acc']) * 100
    print(f"Amélioration BERT: {improvement:+.1f} points")

    # Confusion matrices for the two selected models.
    print("\n--- MATRICES DE CONFUSION ---")

    baseline_key = f"{best_baseline['Config']}_{best_baseline['Model']}"
    print(f"Baseline ({baseline_key}):")
    classifier_baseline.plot_confusion_matrix(baseline_key)

    bert_key = f"{best_bert['Config']}_{best_bert['Model']}"
    print(f"BERT contextuel ({bert_key}):")
    classifier_bert.plot_confusion_matrix(bert_key)

    print("=== PIPELINE TERMINÉ ===")

    return {
        'filtered_data': (filtered_rlm, filtered_phi, filtered_labels),
        'baseline_classifier': classifier_baseline,
        'bert_classifier': classifier_bert,
        'baseline_results': results_baseline,
        'bert_results': results_bert
    }

# ============================================================================
# 5. LANCEMENT
# ============================================================================

# Announce the run and the fixed seed used for reproducibility.
print("Démarrage pipeline corrigé...")
print(f"Seed fixe: {GLOBAL_SEED}")

# Execute the whole pipeline; a falsy result (None) means a precondition failed.
pipeline_results = run_complete_corrected_pipeline()

if not pipeline_results:
    print("Erreur dans le pipeline")
else:
    print("\nPipeline terminé avec succès!")
    print("Variables créées: pipeline_results contient tous les résultats")