# 1. Configuration Initiale et Chargement des Modèles
Objectif : Cette cellule configure l'environnement et charge les modèles nécessaires (CLIP et spaCy) pour l'analyse multimodale.
Description : Configure les variables d'environnement pour la reproductibilité (graines aléatoires, déterminisme CUDA). Initialise le modèle CLIP pré-entraîné pour l'extraction de caractéristiques visuelles et textuelles, ainsi que le modèle spaCy pour le traitement du texte. Vérifie également la disponibilité et l'utilisation du GPU.

In [None]:
#!pip install pillow torchvision transformers scikit-learn
import os
import torch
import torchvision.transforms as transforms
from PIL import Image, ImageFile, ImageEnhance
import numpy as np
import pandas as pd
from transformers import CLIPModel, CLIPTokenizer, CLIPProcessor
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.preprocessing import normalize, LabelEncoder
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_auc_score, balanced_accuracy_score
from sklearn.metrics import adjusted_rand_score, confusion_matrix, precision_score, recall_score, roc_auc_score, balanced_accuracy_score
from sklearn.model_selection import train_test_split, StratifiedKFold, cross_validate, cross_val_predict
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import make_pipeline
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import seaborn as sns
import re
import spacy
from collections import Counter
import random
from scipy.interpolate import griddata
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.amp import autocast, GradScaler

import subprocess
import sys

# Configuration for large images: accept truncated files instead of raising,
# and disable PIL's decompression-bomb pixel limit so very large photos open.
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = None

# Reproducibility configuration
SEED = 42
# NOTE: PYTHONHASHSEED only affects interpreters started after this point
# (e.g. subprocesses); the running kernel's hash seed is already fixed.
os.environ['PYTHONHASHSEED'] = str(SEED)
# cuBLAS needs a fixed workspace config for deterministic GEMMs; it must be
# set before cuBLAS is first initialised.
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Seed every visible GPU, not only the current device.
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# warn_only: ops without a deterministic implementation warn instead of raising.
torch.use_deterministic_algorithms(True, warn_only=True)

# GPU/CPU configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Initialize CLIP model (ViT-B/32: the smaller public CLIP checkpoint)
MODEL_NAME = 'openai/clip-vit-base-patch32'
try:
    model = CLIPModel.from_pretrained(MODEL_NAME).to(device)
    tokenizer = CLIPTokenizer.from_pretrained(MODEL_NAME)
    processor = CLIPProcessor.from_pretrained(MODEL_NAME)
    print("✅ CLIP model loaded successfully")
except Exception as e:
    print(f"❌ Failed to load CLIP model: {str(e)}")
    raise

# Request the GPU for spaCy BEFORE loading the pipeline: spaCy allocates a
# pipeline's weights at load time, so calling require_gpu() afterwards leaves
# an already-loaded model on the CPU. Best effort: fall back to CPU.
try:
    spacy.require_gpu()
    print("✅ spaCy using GPU")
except Exception:
    print("⚠️ spaCy not using GPU (CUDA not found or configured)")

# Load spaCy model, downloading it on first use.
try:
    nlp = spacy.load("en_core_web_trf")
    print("✅ spaCy model loaded successfully")
except OSError:
    # spacy.load raises OSError when the package is not installed.
    print("⏳ Downloading spaCy model...")
    # Use the running interpreter so the model is installed into this kernel's environment.
    subprocess.run([sys.executable, "-m", "spacy", "download", "en_core_web_trf"], check=True)
    nlp = spacy.load("en_core_web_trf")

# 2. Chargement et Nettoyage des Données
Objectif : Cette cellule charge le jeu de données des produits, nettoie les informations textuelles et prépare les chemins d'accès aux images pour un traitement ultérieur.
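Description : Lit les données à partir d'un fichier CSV. Effectue un nettoyage approfondi des champs textuels (nom du produit, marque, description, spécifications), y compris la gestion des valeurs manquantes et l'application de règles de nettoyage spécifiques. Vérifie l'existence et la validité des fichiers image associés et filtre les entrées problématiques. Extrait également la catégorie principale de chaque produit. Enfin, extrait des mots-clés (lemmes les plus fréquents, via spaCy) et enregistre leurs fréquences dans un fichier CSV. Lorsque tous les produits sont traités, réécrit le fichier CSV source avec une colonne `keywords` supplémentaire ; les lignes filtrées n'y figurent plus.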

In [None]:
import re
import spacy
from collections import Counter
import pandas as pd
from PIL import Image
import os

# Ordered (compiled pattern, replacement) rules used by clean_text(). They are
# applied ONCE, in order, case-insensitively; each rule only sees the text left
# by the rules before it, so ordering is significant.
_CLEAN_TEXT_RULES = [
    (re.compile(pattern, re.IGNORECASE), replacement)
    for pattern, replacement in [
        # Boilerplate and "with"/"without" shorthands. These must run before the
        # symbol and punctuation rules below rewrite '/' and '.'.
        (r'\bflipkart\.com\b', ''),
        (r'\bw/o\b', 'without'),
        (r'\bw/', 'with '),
        # Split letter-digit-letter runs, e.g. iphone4s -> iphone s
        (r'([a-zA-Z]+)(\d+)([a-zA-Z])', r'\1 \3'),
        # Sun-protection grade (PA+ .. PA+++). A "\b" can never follow '+'
        # before a space, so a lookahead marks the end of the token instead.
        (r'\bpa\+{1,3}(?![\w+])', 'sun protection factor'),
        # Unwanted symbols: @, *, /, ±, &, %, #
        (r'[@*/±&%#]', ' '),
        # Alphanumeric codes joined by '-' or '_' (e.g. ms004-pktbl). At least one
        # digit is required so ordinary hyphenated words (t-shirt, wi-fi,
        # wrinkle-free, anti-aging) are kept for the rules below.
        (r'\b(?=[a-z0-9]*[-_]?[a-z0-9]*\d)[a-z0-9]+[-_][a-z0-9]+\b', ' '),
        # Standalone numbers
        (r'\b\d+\b', ' '),
        # Space out punctuation
        (r'\(', ' ( '),
        (r'\)', ' ) '),
        (r'\.', ' . '),
        (r'\!', ' ! '),
        (r'\?', ' ? '),
        (r'\:', ' : '),
        (r'\,', ', '),
        # Domain-specific patterns. "(?:-|~|to)" is an alternation; the former
        # character class [-~to] matched single characters '-', '~', 't', 'o'.
        (r'\b(\d+)\s*(?:-|~|to)?\s*(\d+)\s*(m|mth|mths|month|months?)\b', 'month'),
        (r'\bnewborn\s*(?:-|~|to)?\s*(\d+)\s*(m|mth|months?)\b', 'month'),
        # BB/CC cream must be expanded before "bb" is mapped to "baby".
        (r'\bBB\s*cream\b', 'blemish balm cream'),
        (r'\bCC\s*cream\b', 'color correcting cream'),
        (r'\b(nb|newborn|baby|bb|bby|babie|babies)\b', 'baby'),
        (r'\b(diaper|diapr|nappy)\b', 'diaper'),
        (r'\b(stroller|pram|buggy)\b', 'stroller'),
        (r'\b(bpa\s*free|non\s*bpa)\b', 'bisphenol a free'),
        (r'\b(\d+)\s*(oz|ounce)\b', 'ounce'),
        (r'\b(rtx\s*\d+)\b', 'ray tracing graphics'),
        (r'\b(gtx\s*\d+)\b', 'geforce graphics'),
        (r'\b(amd\s*radeon\s*rx\s*\d+)\b', 'amd radeon graphics'),
        (r'\b(intel\s*(core|xeon)\s*[i\d-]+)\b', 'intel processor'),
        (r'\b(amd\s*ryzen\s*[\d]+)\b', 'amd ryzen processor'),
        (r'\bssd\b', 'solid state drive'),
        (r'\bhdd\b', 'hard disk drive'),
        (r'\bwifi\s*([0-9])\b', 'wi-fi standard'),
        (r'\bbluetooth\s*(\d\.\d)\b', 'bluetooth version'),
        (r'\bfhd\b', 'full high definition'),
        (r'\buhd\b', 'ultra high definition'),
        (r'\bqhd\b', 'quad high definition'),
        (r'\boled\b', 'organic light emitting diode'),
        (r'\bips\b', 'in-plane switching'),
        (r'\bram\b', 'random access memory'),
        (r'\bcpu\b', 'central processing unit'),
        (r'\bgpu\b', 'graphics processing unit'),
        (r'\bhdmi\b', 'high definition multimedia interface'),
        # "usb", "usb3", "usb c", "usb type c" -- but never the following word
        # ("usb cable" must keep "cable").
        (r'\busb(?:\s*(?:type\s*)?[a-c]|\d*)\b', 'universal serial bus'),
        (r'\brgb\b', 'red green blue'),
        (r'\bfridge\b', 'refrigerator'),
        (r'\bwashing\s*machine\b', 'clothes washer'),
        (r'\bdishwasher\b', 'dish washing machine'),
        (r'\boven\b', 'cooking oven'),
        (r'\bmicrowave\b', 'microwave oven'),
        (r'\bhoover\b', 'vacuum cleaner'),
        (r'\btumble\s*dryer\b', 'clothes dryer'),
        # Energy classes A++, A+++ (lookahead for the same reason as PA+ above)
        (r'\ba\+{2,}(?!\w)', 'energy efficiency class'),
        (r'\b(\d+)\s*btu\b', 'british thermal unit'),
        (r'\bpoly\b', 'polyester'),
        (r'\bacrylic\b', 'acrylic fiber'),
        (r'\bnylon\b', 'nylon fiber'),
        (r'\bspandex\b', 'spandex fiber'),
        (r'\blycra\b', 'lycra fiber'),
        (r'\bpvc\b', 'polyvinyl chloride'),
        (r'\bvinyl\b', 'vinyl material'),
        (r'\bstainless\s*steel\b', 'stainless steel'),
        (r'\baluminum\b', 'aluminum metal'),
        (r'\bplexiglass\b', 'acrylic glass'),
        (r'\bpu\s*leather\b', 'polyurethane leather'),
        (r'\bsynthetic\s*leather\b', 'synthetic leather'),
        (r'\bfaux\s*leather\b', 'faux leather'),
        (r'\bwaterproof\b', 'water resistant'),
        (r'\bbreathable\b', 'air permeable'),
        (r'\bwrinkle-free\b', 'wrinkle resistant'),
        (r'\bSPF\b', 'sun protection factor'),
        (r'\bUV\b', 'ultraviolet'),
        (r'\bHA\b', 'hyaluronic acid'),
        (r'\bAHA\b', 'alpha hydroxy acid'),
        (r'\bBHA\b', 'beta hydroxy acid'),
        (r'\bPHA\b', 'polyhydroxy acid'),
        (r'\bNMF\b', 'natural moisturizing factor'),
        (r'\bEGF\b', 'epidermal growth factor'),
        (r'\bVit\s*C\b', 'vitamin c'),
        (r'\bVit\s*E\b', 'vitamin e'),
        (r'\bVit\s*B3\b', 'niacinamide vitamin b3'),
        (r'\bVit\s*B5\b', 'panthenol vitamin b5'),
        (r'\bSOD\b', 'superoxide dismutase'),
        (r'\bQ10\b', 'coenzyme q10'),
        (r'\bFoam\s*cl\b', 'foam cleanser'),
        (r'\bMic\s*H2O\b', 'micellar water'),
        (r'\bToner\b', 'skin toner'),
        (r'\bEssence\b', 'skin essence'),
        (r'\bAmpoule\b', 'concentrated serum'),
        (r'\bCF\b', 'cruelty free'),
        (r'\bPF\b', 'paraben free'),
        (r'\bSF\b', 'sulfate free'),
        (r'\bGF\b', 'gluten free'),
        (r'\bHF\b', 'hypoallergenic formula'),
        (r'\bNT\b', 'non-comedogenic tested'),
        (r'\bAM\b', 'morning'),
        (r'\bPM\b', 'night'),
        (r'\bBID\b', 'twice daily'),
        (r'\bQD\b', 'once daily'),
        (r'\bAIR\b', 'airless pump bottle'),
        (r'\bD-C\b', 'dropper container'),
        (r'\bT-C\b', 'tube container'),
        (r'\bPDO\b', 'polydioxanone'),
        (r'\bPCL\b', 'polycaprolactone'),
        (r'\bPLLA\b', 'poly-l-lactic acid'),
        (r'\bHIFU\b', 'high-intensity focused ultrasound'),
        (r'\b(\d+)\s*fl\s*oz\b', 'fluid ounce'),
        (r'\bpH\s*bal\b', 'ph balanced'),
        (r'\b(\d+)\s*(gb|tb|mb|go|to|mo)\b', 'byte'),
        (r'\boctet\b', 'byte'),
        (r'\b(\d+)\s*y\b', 'year'),
        (r'\b(\d+)\s*mth\b', 'month'),
        (r'\b(\d+)\s*d\b', 'day'),
        (r'\b(\d+)\s*h\b', 'hour'),
        (r'\b(\d+)\s*min\b', 'minute'),
        (r'\b(\d+)\s*rpm\b', 'revolution per minute'),
        (r'\b(\d+)\s*(mw|cw|kw)\b', 'watt'),
        (r'\b(\d+)\s*(ma|ca|ka)\b', 'ampere'),
        (r'\b(\d+)\s*(mv|cv|kv)\b', 'volt'),
        (r'\b(\d+)\s*(mm|cm|m|km)\b', 'meter'),
        (r'\binch\b', 'meter'),
        (r'\b(\d+)\s*(ml|cl|dl|l|oz|gal)\b', 'liter'),
        (r'\b(gallon|ounce)\b', 'liter'),
        (r'\b(\d+)\s*(mg|cg|dg|g|kg|lb)\b', 'gram'),
        (r'\bpound\b', 'gram'),
        (r'\b(\d+)\s*(°c|°f)\b', 'celsius'),
        (r'\bfahrenheit\b', 'celsius'),
        (r'\bapprox\.?\b', 'approximately'),
        (r'\bant-\b', 'anti'),
        # Filler tokens
        (r'\byes\b', ''),
        (r'\bno\b', ''),
        (r'\bna\b', ''),
        (r'\brs\.?\b', ''),
        # Collapse whitespace
        (r'\s+', ' '),
    ]
]


def clean_text(text):
    """Normalise product text: lowercase, strip noise and expand domain abbreviations.

    The rules in ``_CLEAN_TEXT_RULES`` are applied exactly once, in order.
    They were previously applied twice. The second pass re-expanded its own
    output, e.g. "acrylic" became "acrylic fiber fiber". Rule ordering now
    guarantees that each rule sees the text it is meant to match.

    Args:
        text: raw text; any non-``str`` value (e.g. NaN) yields ``""``.

    Returns:
        The cleaned, whitespace-normalised, stripped string.
    """
    if not isinstance(text, str):
        return ""
    text = text.lower()
    for pattern, replacement in _CLEAN_TEXT_RULES:
        text = pattern.sub(replacement, text)
    return text.strip()

def extract_keywords(text, nlp, top_n=15):
    """Return the ``top_n`` most frequent lemmas of ``text`` as keywords.

    Each token is lemmatised, lowercased and stripped. A lemma is discarded
    when it is shorter than two characters, when its token is punctuation or
    a stop word, or when it contains one of the symbols @ * / ± & % #.
    Ties in frequency keep first-seen order.

    Args:
        text: text to analyse; a falsy value returns ``[]``.
        nlp: callable spaCy pipeline (anything returning iterable tokens with
            ``lemma_``, ``is_punct`` and ``is_stop``).
        top_n: maximum number of keywords to return.

    Returns:
        A list of at most ``top_n`` lemma strings, most frequent first.
    """
    if not text:
        return []

    symbol_pattern = r'.*[@*/±&%#].*'

    def is_rejected(token, lemma):
        # Evaluated in the same short-circuit order as the filter it replaces.
        return (
            len(lemma) < 2
            or token.is_punct
            or token.is_stop
            or re.match(symbol_pattern, lemma)
        )

    lemmas_with_tokens = ((tok.lemma_.lower().strip(), tok) for tok in nlp(text))
    kept = [lemma for lemma, tok in lemmas_with_tokens if not is_rejected(tok, lemma)]
    return [keyword for keyword, _ in Counter(kept).most_common(top_n)]

def _load_spacy_pipeline(model_name="en_core_web_trf"):
    """Return the spaCy pipeline ``model_name``, downloading it if it is not installed.

    GPU use is requested first, because spaCy only places a pipeline on the
    GPU when ``require_gpu()`` is called before the pipeline is loaded.
    GPU use is best effort: on failure the pipeline loads on CPU.
    """
    import subprocess
    import sys

    try:
        spacy.require_gpu()
        print("✅ spaCy using GPU")
    except Exception:
        print("⚠️ spaCy not using GPU (CUDA not found or configured)")
    try:
        return spacy.load(model_name)
    except OSError as e:
        # spacy.load raises OSError when the model package is not installed.
        print(f"❌ Failed to load spaCy model: {str(e)}")
        print("⏳ Downloading spaCy model...")
        # Use the running interpreter so the model lands in this kernel's environment.
        subprocess.run([sys.executable, "-m", "spacy", "download", model_name], check=True)
        return spacy.load(model_name)


def _keywords_column(texts, nlp):
    """Map each text to its comma-separated keywords, or 'no_keywords_found' when none remain."""
    return texts.apply(lambda text: ", ".join(extract_keywords(text, nlp)) or 'no_keywords_found')


def process_descriptions_to_keywords(df, uniq_id=None):
    """Add a ``keywords`` column from ``processed_text`` and write keyword frequencies to CSV.

    Args:
        df: DataFrame with ``uniq_id`` and ``processed_text`` columns. It is
            modified in place: ``keywords`` is added or overwritten.
        uniq_id: when given, only rows with this ``uniq_id`` are processed.
            Their keywords are written back into ``df``, and frequencies go to
            ``keyword_frequencies_<uniq_id>.csv``. Otherwise every row is
            processed and frequencies go to ``keyword_frequencies.csv``.

    Returns:
        ``df``. It is returned unchanged when ``uniq_id`` matches no row.
        Otherwise it carries the ``keywords`` column (comma-separated
        keywords, or ``'no_keywords_found'``).
    """
    print("⏳ Loading spaCy model...")
    nlp = _load_spacy_pipeline()
    print("🔍 Extracting keywords from descriptions...")

    if uniq_id is not None:
        mask = df['uniq_id'] == uniq_id
        if not mask.any():
            print(f"❌ No product found with uniq_id: {uniq_id}")
            return df
        keywords = _keywords_column(df.loc[mask, 'processed_text'], nlp)
        # Store the result in df itself so callers (e.g. ProductDataset reading
        # row['keywords']) actually receive the extracted keywords.
        df.loc[mask, 'keywords'] = keywords
        output_csv = f'keyword_frequencies_{uniq_id}.csv'
        print(f"✅ Keywords extracted for product {uniq_id}")
    else:
        keywords = _keywords_column(df['processed_text'], nlp)
        df['keywords'] = keywords
        output_csv = 'keyword_frequencies.csv'
        print(f"✅ Keywords extracted for {len(df)} products")

    # Flatten every product's keyword list, ignoring the placeholder value.
    all_keywords = [
        keyword
        for joined in keywords
        if joined != 'no_keywords_found'
        for keyword in joined.split(", ")
    ]

    # Generate keyword frequencies
    keyword_freq = Counter(all_keywords)
    keyword_freq_df = pd.DataFrame(list(keyword_freq.items()), columns=['Mot Clé', 'Fréquence'])
    keyword_freq_df = keyword_freq_df.sort_values(by='Fréquence', ascending=False)

    # Save to CSV
    keyword_freq_df.to_csv(output_csv, index=False, encoding='utf-8')
    print(f"✅ Keyword frequencies saved to {output_csv}")

    return df

def load_data(filepath, image_folder, uniq_id=None):
    """
    Load the product CSV, clean its text, validate images and extract keywords.

    Pipeline:
      1-2. Drop rows whose required fields (product_name, product_category_tree,
           image) are NaN or empty strings.
      3.   Derive ``main_category`` from the first level of the category tree
           and drop rows where it is empty.
      4.   Resolve each image as ``<image_folder>/<uniq_id>.<ext>`` (jpg, JPG,
           jpeg, JPEG; default ``.jpg``). Drop missing or unreadable images.
           Images above PIL's default pixel limit are downscaled into a
           ``<name>_resized.jpg`` copy, and that copy is used as ``image_path``.
      5.   Build ``processed_text`` from name, brand, specifications and
           description (description/specifications may be NaN). Drop rows
           with 3 words or fewer.
      6.   Add a ``keywords`` column via ``process_descriptions_to_keywords``.
      7.   When processing all products, overwrite ``filepath`` with its
           original columns plus ``keywords``.

    Args:
        filepath: path of the source CSV (also the save target in step 7).
        image_folder: directory containing the product images.
        uniq_id: optionally restrict processing to a single product; nothing
            is written back to ``filepath`` in that case.

    Returns:
        The cleaned DataFrame with a fresh RangeIndex.

    Raises:
        ValueError: if ``uniq_id`` matches no row, or if no row survives cleaning.
    """
    # Load data
    df = pd.read_csv(filepath)
    # Capture the on-disk schema now rather than re-reading the file at save time.
    original_columns = df.columns.tolist()
    original_count = len(df)
    print(f"Initial product count: {original_count}")

    # If uniq_id is provided, filter to that product
    if uniq_id is not None:
        df = df[df['uniq_id'] == uniq_id].copy()
        if df.empty:
            raise ValueError(f"No product found with uniq_id: {uniq_id}")
        print(f"Processing single product with uniq_id: {uniq_id}")

    # Each step below computes its "rejected" mask BEFORE filtering, so the
    # "Dropped due to ..." reports list the rows that were actually removed.

    # Step 1: Drop rows with NaN in required fields (product_name, product_category_tree, image)
    required_columns = ['product_name', 'product_category_tree', 'image']
    nan_mask = df[required_columns].isna().any(axis=1)
    dropped_nan = df[nan_mask]
    df = df[~nan_mask]
    print(f"After dropping NaN in required columns ({required_columns}): {len(df)} rows remain")
    if not dropped_nan.empty:
        print("Dropped due to NaN in required columns:", dropped_nan['uniq_id'].tolist())

    # Step 2: Filter empty strings in required fields
    empty_mask = (df[required_columns] == '').any(axis=1)
    dropped_empty = df[empty_mask]
    df = df[~empty_mask].copy()
    print(f"After filtering empty strings in required columns: {len(df)} rows remain")
    if not dropped_empty.empty:
        print("Dropped due to empty strings in required columns:", dropped_empty['uniq_id'].tolist())

    # Step 3: Extract and clean main category (first ' >> ' level, without brackets/quotes)
    df['main_category'] = df['product_category_tree'].str.split(' >> ').str[0]
    df['main_category'] = df['main_category'].str.replace(r'[\[\]\"\']', '', regex=True)
    category_mask = df['main_category'] == ''
    dropped_category = df[category_mask]
    df = df[~category_mask].copy()
    print(f"After filtering empty categories: {len(df)} rows remain")
    if not dropped_category.empty:
        print("Dropped due to empty categories:", dropped_category['uniq_id'].tolist())

    # Step 4: Check image existence and validity
    image_extensions = ('jpg', 'JPG', 'jpeg', 'JPEG')
    max_pixels = 89478485  # PIL's default decompression-bomb threshold

    def resolve_image_path(uid):
        """Return the first existing ``<uid>.<ext>`` in image_folder, else the ``.jpg`` path."""
        for ext in image_extensions:
            candidate = os.path.join(image_folder, f"{uid}.{ext}")
            if os.path.exists(candidate):
                return candidate
        return os.path.join(image_folder, f"{uid}.jpg")

    def validated_image_path(path):
        """Return a usable image path (possibly a downscaled copy) or None if unreadable."""
        try:
            with Image.open(path) as img:
                img.verify()  # integrity check; leaves this handle unusable
            # Reopen in its own context manager so the handle is closed too.
            with Image.open(path) as img:
                width, height = img.size
                pixel_count = width * height
                if pixel_count <= max_pixels:
                    return path
                # Scale both sides by the same factor to fit max_pixels, keeping the aspect ratio.
                scale = (max_pixels / pixel_count) ** 0.5
                new_size = (int(width * scale), int(height * scale))
                # JPEG cannot store alpha/palette modes, so convert before saving.
                resized = img.convert('RGB').resize(new_size, Image.LANCZOS)
            # Derive the copy's name from the extension only, so directory names
            # are untouched and the original file is never overwritten.
            root, _ = os.path.splitext(path)
            temp_path = f"{root}_resized.jpg"
            resized.save(temp_path, 'JPEG', quality=95)
            return temp_path
        except Exception as e:
            print(f"Invalid image {path}: {str(e)}")
            return None

    df['image_path'] = df['uniq_id'].apply(resolve_image_path)
    valid_paths = df['image_path'].apply(validated_image_path)
    invalid_mask = valid_paths.isna()
    dropped_images = df[invalid_mask]
    if not dropped_images.empty:
        print("Dropped due to missing or invalid images:", dropped_images['uniq_id'].tolist())
    df = df[~invalid_mask].copy()
    df['image_path'] = valid_paths[~invalid_mask]  # resized copy path where applicable
    print(f"After filtering invalid images: {len(df)} rows remain")

    # Step 5: Process text, allowing NaN for description and product_specifications
    def process_specs(spec_string):
        """Turn '{"key"=>"k", "value"=>"v"}' entries into 'k v' sentences joined by '. '."""
        if not isinstance(spec_string, str):
            return ""
        matches = re.findall(r'\{"key"=>"(.*?)", "value"=>"(.*?)"\}', spec_string)
        return ". ".join(f"{k.strip().lower()} {v.strip().lower()}" for k, v in matches if k.strip() and v.strip())

    # Replace NaN with empty strings for description and product_specifications
    df['description'] = df['description'].fillna('')
    df['product_specifications'] = df['product_specifications'].fillna('')
    df['cleaned_specs'] = df['product_specifications'].apply(process_specs)
    df['combined_text'] = (
        df['product_name'].str.lower() + '. ' +
        df['brand'].fillna('').str.lower() + '. ' +
        df['cleaned_specs'].str.lower() + '. ' +
        df['description'].str.lower()
    )
    df['processed_text'] = df['combined_text'].apply(clean_text)
    short_mask = (df['processed_text'].str.strip() == '') | (df['processed_text'].str.split().str.len() <= 3)
    dropped_text = df[short_mask]
    if not dropped_text.empty:
        print("Dropped due to empty or short text:", dropped_text['uniq_id'].tolist())
    df = df[~short_mask].copy()
    print(f"After filtering short text: {len(df)} rows remain")

    # Fail before keyword extraction and saving: writing an empty frame back
    # to `filepath` would wipe the source data.
    if df.empty:
        raise ValueError("DataFrame vide après nettoyage. Vérifiez les données sources.")

    # Step 6: Extract keywords
    df = process_descriptions_to_keywords(df, uniq_id=uniq_id)

    # Step 7: Save the updated DataFrame to produits_original.csv (only if processing all products)
    if uniq_id is None:
        # Keep only the original columns plus 'keywords'. De-duplicate in case
        # the file already has a 'keywords' column from a previous run.
        # NOTE(review): this overwrites the source CSV and permanently drops
        # the rows filtered out above.
        save_columns = original_columns + [c for c in ['keywords'] if c not in original_columns]
        df[save_columns].to_csv(filepath, index=False, encoding='utf-8')
        print(f"✅ Updated DataFrame with keywords saved to {filepath}")

    return df.reset_index(drop=True)

# 3. Fine-Tuning du Modèle CLIP
Objectif : Cette cellule adapte le modèle CLIP à la tâche spécifique de classification des produits en le fine-tunant sur le jeu de données préparé.
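Description : Définit une classe personnalisée `CLIPForClassification` qui ajoute une couche de classification au modèle CLIP. Crée un `ProductDataset` personnalisé pour gérer le chargement des images et du texte, y compris le redimensionnement des images et la tokenisation du texte. Réserve 20 % des données à la validation. Configure l'entraînement avec l'optimiseur Adam, l'accumulation de gradients et un `GradScaler` pour l'entraînement en précision mixte (autocast) sur GPU. Effectue le fine-tuning sur plusieurs époques en suivant, à chaque époque, la perte et plusieurs métriques : exactitude, F1, précision, rappel, exactitude équilibrée et AUC ROC. Sauvegarde ensuite le modèle fine-tuné (format Hugging Face et fichier `.pth`) ainsi que l'historique d'entraînement.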

In [None]:
from torch.amp import autocast, GradScaler

class CLIPForClassification(CLIPModel):
    """CLIP backbone plus a linear classification head over concatenated image/text embeddings.

    NOTE(review): ``super().__init__(config)`` already builds a CLIP stack on
    ``self``, and it is not loaded from pretrained weights by this
    constructor. ``self.clip`` then loads a second, pretrained CLIP from
    ``MODEL_NAME``. Only ``self.clip`` is used in ``forward``, so the
    inherited towers look like dead weight in memory and in saved
    checkpoints. Removing them would change the ``state_dict`` keys, and
    existing ``.pth``/``save_pretrained`` checkpoints would stop loading.
    Migrate deliberately if you change this.
    """

    def __init__(self, config, num_labels):
        """Build the model.

        Args:
            config: CLIP configuration. ``config.projection_dim`` sizes the head input.
            num_labels: number of output classes.
        """
        super().__init__(config)
        # Backbone actually used in forward(); always loaded from MODEL_NAME,
        # regardless of `config`.
        self.clip = CLIPModel.from_pretrained(MODEL_NAME)
        # Input width is 2 * projection_dim because forward() concatenates the
        # image and text embeddings.
        self.classifier = nn.Linear(config.projection_dim * 2, num_labels)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, pixel_values, input_ids, attention_mask, labels=None):
        """Classify a batch of (image, text) pairs.

        Args:
            pixel_values: image tensor from the CLIP processor.
            input_ids: token ids of the paired text.
            attention_mask: attention mask for ``input_ids``.
            labels: optional integer class ids. When given, the cross-entropy loss is computed.

        Returns:
            An ad-hoc object with attributes ``loss`` (``None`` when
            ``labels`` is ``None``), ``logits``, ``image_embeds`` and
            ``text_embeds``.
        """
        outputs = self.clip(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask)
        # Fuse both modalities along the feature (last) dimension.
        pooled_output = torch.cat((outputs.image_embeds, outputs.text_embeds), dim=-1)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss = self.loss_fn(logits, labels)

        # Instantiate a throwaway class so callers can use attribute access (outputs.loss, ...).
        return type('Output', (), {
            'loss': loss,
            'logits': logits,
            'image_embeds': outputs.image_embeds,
            'text_embeds': outputs.text_embeds
        })()

class ProductDataset(Dataset):
    """Map-style dataset yielding CLIP-ready (image, text, label) samples, one per product row.

    Each item is a dict with ``pixel_values``, ``input_ids``, ``attention_mask``
    and ``labels``. ``None`` is returned when the image or the text cannot be
    processed, so a collate function can filter it out.
    """

    def __init__(self, df, processor, tokenizer, max_size=128, max_length=77):  # Reduced max_size
        self.df = df
        self.processor = processor
        self.tokenizer = tokenizer
        # Longest image side (pixels) kept before the processor sees the image.
        self.max_size = max_size
        # Padding/truncation length for the tokenizer.
        self.max_length = max_length
        # Integer class ids, assigned in order of first appearance of each category.
        codes, _ = pd.factorize(df['main_category'])
        self.labels = codes

    def __len__(self):
        return self.df.shape[0]

    def _pixel_values(self, img_path):
        """Open, RGB-convert, downscale and preprocess one image; ``None`` on any failure."""
        try:
            with Image.open(img_path) as picture:
                rgb = picture if picture.mode == 'RGB' else picture.convert('RGB')
                longest_side = max(rgb.size)
                if longest_side > self.max_size:
                    factor = self.max_size / longest_side
                    width, height = rgb.size
                    rgb = rgb.resize((int(width * factor), int(height * factor)), Image.LANCZOS)
                encoded = self.processor(images=rgb, return_tensors="pt", padding=True)
                return encoded.pixel_values.squeeze(0)
        except Exception as e:
            print(f"⚠️ Skipping image {img_path}: {str(e)}")
            return None

    def _assemble(self, idx, text, label, pixels):
        """Tokenise ``text`` and bundle it with ``pixels`` and ``label``; ``None`` on failure."""
        try:
            encoding = self.tokenizer(
                text,
                return_tensors="pt",
                padding='max_length',
                truncation=True,
                max_length=self.max_length
            )
            return {
                'pixel_values': pixels,
                'input_ids': encoding['input_ids'].squeeze(0),
                'attention_mask': encoding['attention_mask'].squeeze(0),
                'labels': torch.tensor(label, dtype=torch.long)
            }
        except Exception as e:
            print(f"⚠️ Skipping text for index {idx}: {str(e)}")
            return None

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        img_path, text = record['image_path'], record['keywords']
        label = self.labels[idx]

        pixels = self._pixel_values(img_path)
        if pixels is None:
            return None
        return self._assemble(idx, text, label, pixels)

def _epoch_metrics(labels, preds, probs):
    """Compute weighted classification metrics for one pass; all NaN when nothing was scored."""
    metric_names = ('accuracy', 'f1', 'precision', 'recall', 'balanced_accuracy', 'roc_auc')
    if len(labels) == 0:
        return dict.fromkeys(metric_names, np.nan)
    metrics = {
        'accuracy': accuracy_score(labels, preds),
        # zero_division=0 matches sklearn's default value, without the warning.
        'f1': f1_score(labels, preds, average='weighted', zero_division=0),
        'precision': precision_score(labels, preds, average='weighted', zero_division=0),
        'recall': recall_score(labels, preds, average='weighted', zero_division=0),
        'balanced_accuracy': balanced_accuracy_score(labels, preds),
    }
    try:
        metrics['roc_auc'] = roc_auc_score(labels, probs, multi_class='ovr', average='weighted')
    except ValueError:
        # e.g. a class absent from this split, so one-vs-rest AUC is undefined.
        metrics['roc_auc'] = np.nan
    return metrics


def fine_tune_clip(df, processor, tokenizer, epochs=5, batch_size=4, accum_steps=4, save_path="finetuned_clip"):
    """Fine-tune CLIP with a classification head on ``main_category`` and save it.

    The data is split 80/20 into train/validation, seeded with ``SEED``.
    Gradients are accumulated over ``accum_steps`` batches. Mixed precision
    (autocast + GradScaler) is enabled only when ``device`` is CUDA.

    Args:
        df: cleaned products with ``image_path``, ``keywords`` and ``main_category``.
        processor: CLIP processor used for images (and saved with the model).
        tokenizer: CLIP tokenizer used for keywords (and saved with the model).
        epochs: number of passes over the training split.
        batch_size: per-step batch size; the effective size is ``batch_size * accum_steps``.
        accum_steps: number of batches whose gradients are summed per optimizer step.
        save_path: directory for the Hugging Face checkpoint and ``training_history.csv``.

    Returns:
        ``(model, history)``, where ``history`` maps metric names to per-epoch lists.
    """
    num_labels = len(df['main_category'].unique())
    try:
        # Only the configuration is needed: CLIPForClassification loads the
        # pretrained weights itself, so avoid materialising a throwaway model.
        config = CLIPModel.config_class.from_pretrained(MODEL_NAME)
        model = CLIPForClassification(config, num_labels=num_labels).to(device)
    except Exception as e:
        print(f"❌ Failed to initialize CLIPForClassification: {str(e)}")
        raise

    # Mixed precision only applies on CUDA; on CPU both become no-ops (plain fp32).
    use_amp = device.type == 'cuda'
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-6)
    scaler = GradScaler('cuda', enabled=use_amp)
    dataset = ProductDataset(df, processor, tokenizer, max_size=128, max_length=77)

    def collate_fn(batch):
        """Stack valid samples; drop the ``None`` ones and return ``None`` for an all-invalid batch."""
        batch = [item for item in batch if item is not None]
        if not batch:
            return None
        return {
            'pixel_values': torch.stack([item['pixel_values'] for item in batch]),
            'input_ids': torch.stack([item['input_ids'] for item in batch]),
            'attention_mask': torch.stack([item['attention_mask'] for item in batch]),
            'labels': torch.stack([item['labels'] for item in batch])
        }

    def record_predictions(logits, labels, preds_out, labels_out, probs_out):
        """Append argmax predictions, true labels and fp32 softmax probabilities."""
        with torch.no_grad():
            # Upcast before softmax: fp16 probabilities produced under autocast
            # do not sum to 1 within roc_auc_score's tolerance, which would
            # make it raise and turn ROC-AUC into NaN on every GPU run.
            probs = torch.softmax(logits.float(), dim=1)
            preds_out.extend(torch.argmax(logits, dim=1).cpu().numpy())
            labels_out.extend(labels.cpu().numpy())
            probs_out.extend(probs.cpu().numpy())

    def optimizer_step():
        """Apply the accumulated (scaled) gradients and reset them."""
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

    # Split dataset into train and validation
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size], generator=torch.Generator().manual_seed(SEED)
    )

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        collate_fn=collate_fn,
        drop_last=True
    )

    val_dataloader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        collate_fn=collate_fn,
        drop_last=False
    )

    effective_batch_size = batch_size * accum_steps
    print(f"Using batch_size={batch_size}, accum_steps={accum_steps}, effective_batch_size={effective_batch_size}")

    # Store training history (key order = column order of training_history.csv)
    history = {
        'epoch': [],
        'train_loss': [],
        'val_loss': [],
        'train_accuracy': [],
        'val_accuracy': [],
        'train_f1': [],
        'val_f1': [],
        'train_precision': [],
        'val_precision': [],
        'train_recall': [],
        'val_recall': [],
        'train_balanced_accuracy': [],
        'val_balanced_accuracy': [],
        'train_roc_auc': [],
        'val_roc_auc': []
    }

    for epoch in range(epochs):
        # Training phase
        model.train()
        total_train_loss, n_train_batches = 0.0, 0
        train_preds, train_labels, train_probs = [], [], []
        step = 0
        optimizer.zero_grad()

        for batch in train_dataloader:
            if batch is None:
                continue
            inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
            labels = batch['labels'].to(device)

            with autocast('cuda', enabled=use_amp):
                outputs = model(**inputs, labels=labels)

            # Divide so the gradients summed over accum_steps batches average out.
            loss = outputs.loss / accum_steps
            scaler.scale(loss).backward()
            step += 1
            if step % accum_steps == 0:
                optimizer_step()

            total_train_loss += loss.item() * accum_steps
            n_train_batches += 1
            record_predictions(outputs.logits, labels, train_preds, train_labels, train_probs)

        # Apply gradients left over from a final, incomplete accumulation
        # window instead of silently discarding them.
        if step % accum_steps != 0:
            optimizer_step()

        # Validation phase
        model.eval()
        total_val_loss, n_val_batches = 0.0, 0
        val_preds, val_labels, val_probs = [], [], []

        with torch.no_grad():
            for batch in val_dataloader:
                if batch is None:
                    continue
                inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
                labels = batch['labels'].to(device)

                with autocast('cuda', enabled=use_amp):
                    outputs = model(**inputs, labels=labels)

                total_val_loss += outputs.loss.item()
                n_val_batches += 1
                record_predictions(outputs.logits, labels, val_preds, val_labels, val_probs)

        train_metrics = _epoch_metrics(train_labels, train_preds, train_probs)
        val_metrics = _epoch_metrics(val_labels, val_preds, val_probs)

        # Store metrics. Losses are averaged over the batches actually processed;
        # max(..., 1) guards an empty split.
        history['epoch'].append(epoch + 1)
        history['train_loss'].append(total_train_loss / max(n_train_batches, 1))
        history['val_loss'].append(total_val_loss / max(n_val_batches, 1))
        for name in ('accuracy', 'f1', 'precision', 'recall', 'balanced_accuracy', 'roc_auc'):
            history[f'train_{name}'].append(train_metrics[name])
            history[f'val_{name}'].append(val_metrics[name])

        train_accuracy, val_accuracy = train_metrics['accuracy'], val_metrics['accuracy']
        print(f"Epoch {epoch+1}/{epochs}")
        print(f"  Train Loss: {history['train_loss'][-1]:.4f}, Val Loss: {history['val_loss'][-1]:.4f}")
        print(f"  Train Acc: {train_accuracy:.4f}, Val Acc: {val_accuracy:.4f}")
        print(f"  Train F1: {train_metrics['f1']:.4f}, Val F1: {val_metrics['f1']:.4f}")
        print(f"  Train-Validation Gap: {train_accuracy - val_accuracy:.4f}")

    # Save model in the Hugging Face format
    os.makedirs(save_path, exist_ok=True)
    model.save_pretrained(save_path)
    processor.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)
    print(f"✅ Fine-tuned model saved to {save_path}")

    # Save model state_dict as a single .pth file
    pth_save_path = "new_clip_product_classifier.pth"
    torch.save(model.state_dict(), pth_save_path)
    print(f"✅ Model state_dict saved to {pth_save_path}")

    # Save training history
    history_path = os.path.join(save_path, 'training_history.csv')
    pd.DataFrame(history).to_csv(history_path, index=False)
    print(f"✅ Training history saved to {history_path}")

    return model, history

In [None]:
def plot_training_curves(history, save_path):
    """Save per-epoch train/validation curves and a last-epoch metrics table.

    Two files are written into ``save_path`` (created if missing):

    * ``training_curves.png`` - a 2x3 grid with loss, accuracy, F1, precision and
      recall (train vs validation), plus the train-minus-validation accuracy gap
      used as an overfitting indicator;
    * ``final_metrics_summary.csv`` - last-epoch train and validation values of
      every metric and their difference (train - validation).

    Args:
        history: Mapping with an ``epoch`` sequence and, for each metric m in
            loss, accuracy, f1, precision, recall, balanced_accuracy and roc_auc,
            ``train_<m>`` / ``val_<m>`` sequences indexed by epoch.
        save_path: Output directory.

    Returns:
        The summary DataFrame that was written to CSV. A NaN ROC AUC is reported
        as ``None`` (and so is its gap).
    """
    os.makedirs(save_path, exist_ok=True)
    epoch_axis = history['epoch']

    _fig, axes = plt.subplots(2, 3, figsize=(18, 12))

    # (axis, metric key, train label, validation label, title, y-axis label)
    curve_panels = (
        (axes[0, 0], 'loss', 'Train Loss', 'Validation Loss',
         'Training and Validation Loss', 'Loss'),
        (axes[0, 1], 'accuracy', 'Train Accuracy', 'Validation Accuracy',
         'Training and Validation Accuracy', 'Accuracy'),
        (axes[0, 2], 'f1', 'Train F1', 'Validation F1 Score',
         'Training and Validation F1 Score', 'F1 Score'),
        (axes[1, 0], 'precision', 'Train Precision', 'Validation Precision',
         'Training and Validation Precision', 'Precision'),
        (axes[1, 1], 'recall', 'Train Recall', 'Validation Recall',
         'Training and Validation Recall', 'Recall'),
    )
    for ax, key, train_label, val_label, title, y_label in curve_panels:
        ax.plot(epoch_axis, history[f'train_{key}'], label=train_label, marker='o')
        ax.plot(epoch_axis, history[f'val_{key}'], label=val_label, marker='o')
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.legend()
        ax.grid(True, alpha=0.3)

    # Overfitting indicator: how far training accuracy runs ahead of validation.
    gap_ax = axes[1, 2]
    gap_series = [tr - va for tr, va in zip(history['train_accuracy'], history['val_accuracy'])]
    gap_ax.plot(epoch_axis, gap_series, label='Accuracy Gap (Train - Val)', marker='o', color='red')
    gap_ax.axhline(y=0, color='gray', linestyle='--', alpha=0.7)
    gap_ax.set_title('Accuracy Gap (Indicator of Overfitting)')
    gap_ax.set_xlabel('Epoch')
    gap_ax.set_ylabel('Accuracy Gap')
    gap_ax.legend()
    gap_ax.grid(True, alpha=0.3)

    plt.tight_layout()
    curves_file = os.path.join(save_path, 'training_curves.png')
    plt.savefig(curves_file, dpi=300, bbox_inches='tight')
    plt.close()

    # Last-epoch table; ROC AUC can be NaN (undefined), which is reported as None.
    table_keys = ('loss', 'accuracy', 'f1', 'precision', 'recall', 'balanced_accuracy')

    train_last = [history[f'train_{key}'][-1] for key in table_keys]
    train_auc = history['train_roc_auc'][-1]
    train_auc_missing = np.isnan(train_auc)
    train_last.append(None if train_auc_missing else train_auc)

    val_last = [history[f'val_{key}'][-1] for key in table_keys]
    val_auc = history['val_roc_auc'][-1]
    val_auc_missing = np.isnan(val_auc)
    val_last.append(None if val_auc_missing else val_auc)

    gap_last = [tr - va for tr, va in zip(train_last[:-1], val_last[:-1])]
    gap_last.append(None if (train_auc_missing or val_auc_missing) else train_auc - val_auc)

    summary_df = pd.DataFrame({
        'Metric': ['Loss', 'Accuracy', 'F1 Score', 'Precision', 'Recall', 'Balanced Accuracy', 'ROC AUC'],
        'Train_Final': train_last,
        'Validation_Final': val_last,
        'Gap': gap_last,
    })
    summary_file = os.path.join(save_path, 'final_metrics_summary.csv')
    summary_df.to_csv(summary_file, index=False)

    print(f"✅ Training curves saved to {curves_file}")
    print(f"✅ Final metrics summary saved to {summary_file}")

    return summary_df

# 4. Extraction des Caractéristiques Textuelles et Visuelles
Objectif : Cette cellule utilise le modèle CLIP fine-tuné pour extraire des représentations numériques (caractéristiques) distinctes pour les modalités textuelle et visuelle de chaque produit, puis les combine.
Description : Définit des fonctions pour extraire les caractéristiques textuelles à partir des mots-clés avec l'encodeur de texte de CLIP, et les caractéristiques visuelles à partir des images avec l'encodeur d'image de CLIP (les images illisibles sont ignorées). Propose ensuite une méthode de combinaison : chaque type de caractéristiques est normalisé (norme L2), les deux sont fusionnés par une somme pondérée (alpha pour le texte, 1 - alpha pour l'image), puis le résultat est renormalisé, ce qui crée une représentation multimodale.

In [None]:
def extract_text_features(df, model, tokenizer, batch_size=None):
    """Encode each product's ``keywords`` text with the fine-tuned CLIP text encoder.

    Args:
        df: DataFrame with a ``keywords`` column (one text per product).
        model: Wrapper exposing the underlying CLIP model as ``model.clip``.
        tokenizer: CLIP tokenizer; texts are padded and truncated to 77 tokens.
        batch_size: Number of texts encoded per forward pass. ``None`` (default)
            encodes every text in a single pass, as before; a positive integer
            bounds the memory needed for large catalogues.

    Returns:
        A NumPy array with one embedding per row of ``df``, in row order (the
        embeddings are not L2-normalised here), or an empty array when ``df``
        has no rows.

    Raises:
        ValueError: if ``batch_size`` is given but is not a positive integer.
    """
    # Missing keywords (NaN) would make the tokenizer reject the batch; treat them as empty text.
    texts = df['keywords'].fillna('').astype(str).tolist()
    if not texts:
        return np.array([])
    if batch_size is not None and batch_size < 1:
        raise ValueError("batch_size must be a positive integer or None")
    step = len(texts) if batch_size is None else int(batch_size)

    chunks = []
    with torch.no_grad():
        for start in range(0, len(texts), step):
            inputs = tokenizer(texts[start:start + step], padding=True, truncation=True,
                               max_length=77, return_tensors="pt").to(device)
            chunks.append(model.clip.get_text_features(**inputs).cpu().numpy())
    return np.concatenate(chunks, axis=0)

def extract_image_features(df, model, processor, max_size=128):
    """Embed every image listed in ``df['image_path']`` with the fine-tuned CLIP image encoder.

    Each image is converted to RGB and, if its longest side exceeds ``max_size``,
    shrunk with Lanczos resampling so that side equals ``max_size`` before being
    passed to ``processor``. Any image that raises while being opened, resized or
    encoded is reported and skipped.

    NOTE(review): other code in this notebook resizes patches to 224 px for the
    processor; if the processor upsamples to that size, the 128 px default throws
    away detail. Confirm that the small default is intentional (speed/memory).

    Returns:
        ``(features, kept_df)``: the stacked embeddings, one row per successfully
        encoded image, and a copy of the matching rows of ``df`` in the same order.

    Raises:
        ValueError: if not a single image could be encoded.
    """
    embeddings = []
    kept_positions = []
    for position, path in enumerate(df['image_path']):
        try:
            with Image.open(path) as raw:
                picture = raw if raw.mode == 'RGB' else raw.convert('RGB')
                longest_side = max(picture.size)
                if longest_side > max_size:
                    scale = max_size / longest_side
                    target_size = (int(picture.size[0] * scale), int(picture.size[1] * scale))
                    picture = picture.resize(target_size, Image.LANCZOS)
                with torch.no_grad():
                    pixel_batch = processor(images=picture, return_tensors="pt").to(device)
                    embeddings.append(model.clip.get_image_features(**pixel_batch).cpu().numpy())
                    kept_positions.append(position)
        except Exception as err:
            # Best effort: a single unreadable image must not abort the whole extraction.
            print(f"⚠️ Skipping image {path}: {str(err)}")
    if not embeddings:
        raise ValueError("No valid images processed.")
    return np.vstack(embeddings), df.iloc[kept_positions].copy()

def combine_features(text_features, image_features, alpha=0.6, image_indices=None):
    """Fuse text and image embeddings into one L2-normalised multimodal vector per product.

    Each modality is L2-normalised row-wise, mixed as
    ``alpha * text + (1 - alpha) * image``, and the mix is L2-normalised again.

    Rows are paired by position. ``extract_image_features`` skips products whose
    image cannot be processed, so the image matrix may be shorter than the text
    matrix and no longer row-aligned with it. Truncating both to the shorter
    length (the fallback) then pairs products with the wrong images whenever a
    skipped image was not at the end. Pass ``image_indices`` to avoid this.

    Args:
        text_features: 2-D array, one text embedding per product.
        image_features: 2-D array, one image embedding per kept product.
        alpha: Weight of the text modality (the image gets ``1 - alpha``).
        image_indices: Optional positional row indices into ``text_features`` of
            the products that have an image embedding, in the same order as
            ``image_features``. When given, those text rows are selected first.

    Returns:
        2-D array of combined, L2-normalised features with
        ``min(n_text_rows, n_image_rows)`` rows (after the optional selection).
    """
    text_features = normalize(text_features, norm='l2')
    image_features = normalize(image_features, norm='l2')
    if image_indices is not None:
        text_features = text_features[np.asarray(image_indices, dtype=int)]

    n_text, n_image = text_features.shape[0], image_features.shape[0]
    if n_text != n_image:
        import warnings
        # Positional truncation can silently mis-pair products; make that visible.
        warnings.warn(
            f"combine_features: {n_text} text rows vs {n_image} image rows; truncating both "
            f"to {min(n_text, n_image)} rows by position, which may misalign products. "
            "Pass image_indices to align them.",
            stacklevel=2,
        )
    min_samples = min(n_text, n_image)
    return normalize(alpha * text_features[:min_samples] + (1 - alpha) * image_features[:min_samples])

# 5. Évaluation et Comparaison des Modalités
Objectif : Cette cellule évalue les performances de classification en utilisant les caractéristiques textuelles, visuelles et combinées pour déterminer l'efficacité de chaque modalité et de leur combinaison.
Description : Implémente une fonction `evaluate_classification` qui utilise la validation croisée Stratified K-Fold avec un pipeline composé d'une PCA (conservant 95 % de la variance) pour la réduction de dimensionnalité et d'un classificateur RandomForest. Calcule et affiche plusieurs métriques de performance courantes (Accuracy, F1-score, Precision, Recall, Balanced Accuracy, ARI, ROC AUC). La fonction `compare_modalities` lance cette évaluation pour chaque ensemble de caractéristiques (texte, image, texte+image), sauvegarde les résultats dans un fichier CSV et génère une visualisation comparative des métriques (le ROC AUC figure dans le CSV mais pas dans le graphique).

In [None]:
def evaluate_classification(features, true_labels, method_name, n_splits=5):
    """Cross-validate a PCA + RandomForest classifier and report several metrics.

    The pipeline (PCA keeping 95% of the variance, then a 100-tree random forest)
    is evaluated with shuffled stratified K-fold cross-validation. The models
    fitted during that single pass are reused to build the out-of-fold predictions
    (for the Adjusted Rand Index and the caller) and the out-of-fold class
    probabilities (for ROC AUC). Previously two more full cross-validations were
    run, refitting identical models on identical folds.

    Args:
        features: 2-D array of shape (n_samples, n_features).
        true_labels: 1-D array-like of class labels aligned with ``features``.
        method_name: Name used in the printed report and in the returned dict.
        n_splits: Number of stratified folds.

    Returns:
        ``(metrics, true_labels, y_pred)``:

        * ``metrics`` holds the fold-averaged accuracy, weighted F1, precision and
          recall, and balanced accuracy, plus the ARI and ROC AUC computed on the
          out-of-fold predictions (ROC AUC is ``np.nan`` if it cannot be computed);
        * ``true_labels`` is the input, unchanged;
        * ``y_pred`` holds the out-of-fold predictions in the original sample order.
    """
    pipeline = make_pipeline(
        PCA(n_components=0.95, random_state=SEED),
        RandomForestClassifier(n_estimators=100, random_state=SEED, max_features='sqrt', bootstrap=True)
    )
    cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=SEED)

    # Define scoring metrics
    scoring = {
        'accuracy': 'accuracy',
        'f1_weighted': 'f1_weighted',
        'precision_weighted': 'precision_weighted',
        'recall_weighted': 'recall_weighted',
        'balanced_accuracy': 'balanced_accuracy'
    }

    # Single cross-validation pass; keep the fitted fold models for the predictions below.
    scores = cross_validate(
        pipeline, features, true_labels, cv=cv, scoring=scoring, n_jobs=1,
        return_train_score=False, return_estimator=True
    )
    fold_models = scores['estimator']

    # A StratifiedKFold with an integer random_state yields the same folds, in the same
    # order, on every call, so these are exactly the folds cross_validate just used.
    X = np.asarray(features)
    y = np.asarray(true_labels)
    folds = [test_idx for _, test_idx in cv.split(X, y)]

    # Out-of-fold predictions, put back in the original sample order.
    fold_predictions = np.concatenate(
        [fold_model.predict(X[test_idx]) for fold_model, test_idx in zip(fold_models, folds)]
    )
    y_pred = np.empty_like(fold_predictions)
    y_pred[np.concatenate(folds)] = fold_predictions
    ari_score = adjusted_rand_score(true_labels, y_pred)

    # ROC AUC from out-of-fold probabilities: one-vs-rest for multi-class problems, the
    # positive-class column for binary ones (roc_auc_score rejects a 2-column score there).
    try:
        le = LabelEncoder()
        y_true_encoded = le.fit_transform(y)
        n_classes = len(le.classes_)
        y_score = np.zeros((len(y), n_classes))
        for fold_model, test_idx in zip(fold_models, folds):
            # Map this fold's classes to global columns; a class missing from the
            # fold's training data keeps probability 0.
            columns = le.transform(fold_model.classes_)
            y_score[np.ix_(test_idx, columns)] = fold_model.predict_proba(X[test_idx])
        if n_classes == 2:
            roc_auc = roc_auc_score(y_true_encoded, y_score[:, 1])
        else:
            roc_auc = roc_auc_score(y_true_encoded, y_score, multi_class='ovr', average='weighted')
    except Exception as e:
        print(f"⚠️ ROC AUC calculation failed for {method_name}: {str(e)}")
        roc_auc = np.nan

    # Aggregate results
    avg_accuracy = np.mean(scores['test_accuracy'])
    avg_f1 = np.mean(scores['test_f1_weighted'])
    avg_precision = np.mean(scores['test_precision_weighted'])
    avg_recall = np.mean(scores['test_recall_weighted'])
    avg_balanced_accuracy = np.mean(scores['test_balanced_accuracy'])

    print(f"\nMéthode: {method_name}")
    print(f"Accuracy (CV): {avg_accuracy:.3f}")
    print(f"F1-score (weighted): {avg_f1:.3f}")
    print(f"Precision (weighted): {avg_precision:.3f}")
    print(f"Recall (weighted): {avg_recall:.3f}")
    print(f"Balanced Accuracy: {avg_balanced_accuracy:.3f}")
    print(f"Adjusted Rand Index: {ari_score:.3f}")
    print(f"ROC AUC (OvR, weighted): {roc_auc:.3f}")

    return {
        'method': method_name,
        'accuracy': avg_accuracy,
        'f1_weighted': avg_f1,
        'precision_weighted': avg_precision,
        'recall_weighted': avg_recall,
        'balanced_accuracy': avg_balanced_accuracy,
        'ari': ari_score,
        'roc_auc': roc_auc
    }, true_labels, y_pred  # True labels and out-of-fold predicted labels

def compare_modalities(df, text_features, image_features, combined_features, true_labels, valid_df, valid_categories, save_folder="result"):
    """Evaluate text-only, image-only and combined features side by side.

    Each feature set is scored with ``evaluate_classification``. Text features are
    paired with ``true_labels``; image and combined features with
    ``valid_categories``. The metrics table is written to
    ``comparison_results.csv`` and a 2x3 bar-chart grid (accuracy, F1, precision,
    recall, balanced accuracy, ARI) to
    ``comparison_supervised_finetuned_extended.png``, both inside ``save_folder``.
    ``df`` and ``valid_df`` are accepted for call compatibility but not used.

    Returns:
        ``(results_df, true_labels_cv, pred_labels_cv)`` where the last two are the
        cross-validated true and predicted labels of the combined (text+image) run.
    """
    os.makedirs(save_folder, exist_ok=True)

    # Evaluated in this order, so the printed reports and CSV rows keep this order too.
    modalities = (
        (text_features, true_labels, "Texte seul"),
        (image_features, valid_categories, "Image seule"),
        (combined_features, valid_categories, "Texte+Image"),
    )
    metric_rows = []
    cv_label_pairs = []
    for feature_set, label_set, display_name in modalities:
        metrics, cv_true, cv_pred = evaluate_classification(feature_set, label_set, display_name)
        metric_rows.append(metrics)
        cv_label_pairs.append((cv_true, cv_pred))

    results_df = pd.DataFrame(metric_rows)

    results_df.to_csv(os.path.join(save_folder, 'comparison_results.csv'), index=False)
    print(f"✅ Results saved to '{save_folder}/comparison_results.csv'")

    # (column, palette, title, y-axis label) for each subplot, in grid order.
    bar_panels = (
        ('accuracy', "Blues_d", "Accuracy moyenne (validation croisée)", "Accuracy"),
        ('f1_weighted', "Greens_d", "F1-score moyen (pondéré)", "F1-score"),
        ('precision_weighted', "Oranges_d", "Precision moyenne (pondérée)", "Precision"),
        ('recall_weighted', "Purples_d", "Recall moyen (pondéré)", "Recall"),
        ('balanced_accuracy', "Reds_d", "Balanced Accuracy moyenne", "Balanced Accuracy"),
        ('ari', "YlOrBr_d", "Adjusted Rand Index", "ARI"),
    )
    plt.figure(figsize=(15, 10))
    for slot, (column, palette, title, y_label) in enumerate(bar_panels, start=1):
        plt.subplot(2, 3, slot)
        sns.barplot(x='method', y=column, hue='method', data=results_df, palette=palette, legend=False)
        plt.title(title)
        plt.ylim(0, 1)
        plt.ylabel(y_label)

    plt.tight_layout()
    plt.savefig(os.path.join(save_folder, 'comparison_supervised_finetuned_extended.png'), dpi=300, bbox_inches='tight')
    plt.close()
    print(f"✅ Visualization saved to '{save_folder}/comparison_supervised_finetuned_extended.png'")

    combined_true, combined_pred = cv_label_pairs[-1]
    return results_df, combined_true, combined_pred

# 6. Génération de la Matrice de Confusion
Objectif : Cette cellule visualise la matrice de confusion pour comprendre où le modèle fine-tuné fait des erreurs de classification entre les différentes catégories de produits.
Description : Divise les données (caractéristiques combinées et étiquettes) en ensembles d'entraînement (80 %) et de test (20 %), de façon stratifiée. Entraîne un pipeline PCA + RandomForest sur l'ensemble d'entraînement et prédit les étiquettes sur l'ensemble de test. Calcule la matrice de confusion normalisée par ligne, c'est-à-dire par vraie classe. La diagonale donne la proportion d'exemples correctement classés (le rappel) de chaque catégorie, et chaque cellule hors diagonale la proportion d'exemples d'une catégorie prédits dans une autre. Utilise Seaborn pour la visualiser sous forme de heatmap et sauvegarde l'image obtenue.

In [None]:
def plot_confusion_matrix(features, labels, category_names, save_path="result"):
    """Fit PCA + RandomForest on a stratified 80/20 split and save a normalised confusion matrix.

    Each row is divided by the number of test samples of that true class, so the
    diagonal is the per-class recall.

    Args:
        features: 2-D feature array, row-aligned with ``labels``.
        labels: Class labels. Integer labels are treated as indices into
            ``category_names`` (0 .. len(category_names) - 1). The matrix then
            always has one row and one column per category, even if a class is
            missing from the test split or from the predictions, so it stays
            aligned with the tick labels.
        category_names: Display names. Any run of brackets, double quotes and
            whitespace at either end is stripped for display.
        save_path: Output directory for ``confusion_matrix_finetuned.png``.
    """
    os.makedirs(save_path, exist_ok=True)
    X_train, X_test, y_train, y_test = train_test_split(
        features, labels, test_size=0.2, random_state=SEED, stratify=labels)
    pipeline = make_pipeline(
        PCA(n_components=0.95, random_state=SEED),
        RandomForestClassifier(n_estimators=100, random_state=SEED, max_features='sqrt', bootstrap=True))
    pipeline.fit(X_train, y_train)
    y_pred = pipeline.predict(X_test)

    # Without an explicit class list the matrix only covers classes seen in y_test/y_pred,
    # and the heatmap tick labels would then be misaligned.
    class_ids = (np.arange(len(category_names))
                 if np.issubdtype(np.asarray(labels).dtype, np.integer) else None)
    cm = confusion_matrix(y_test, y_pred, labels=class_ids, normalize='true')
    # str.strip removes every leading/trailing bracket, double quote and whitespace
    # character. A one-character-per-side regex would leave e.g. the quote in '["Baby Care'.
    cleaned_category_names = [str(name).strip('[]" \t\r\n') for name in category_names]

    plt.figure(figsize=(15, 12))
    sns.heatmap(cm, annot=True, fmt=".2f", cmap=sns.light_palette("#3498db", as_cmap=True),
                xticklabels=cleaned_category_names, yticklabels=cleaned_category_names,
                linewidths=0.5, linecolor='lightgray')
    plt.title("Matrice de confusion normalisée (Fine-Tuned CLIP)", fontsize=14, pad=20)
    plt.xlabel('Prédictions', fontsize=12)
    plt.ylabel('Vraies classes', fontsize=12)
    plt.xticks(rotation=45, ha='right', fontsize=10)
    plt.yticks(rotation=0, fontsize=10)
    for _, spine in plt.gca().spines.items():
        spine.set_visible(True)
        spine.set_color('lightgray')
    plt.tight_layout()
    plt.savefig(os.path.join(save_path, 'confusion_matrix_finetuned.png'), dpi=300, bbox_inches='tight')
    plt.close()
    print(f"✅ Confusion matrix saved to '{save_path}/confusion_matrix_finetuned.png'")

# 7. Visualisation t-SNE
Objectif : Cette cellule réduit la dimensionnalité des caractéristiques combinées pour visualiser leur distribution dans un espace 2D et observer la séparation des clusters par catégorie.
Description : Applique une PCA pour réduire initialement les caractéristiques combinées, puis utilise t-SNE pour projeter les caractéristiques dans un espace bidimensionnel. Crée un DataFrame avec les coordonnées t-SNE et les étiquettes de catégorie. Utilise Seaborn pour générer un nuage de points (scatterplot) coloré par catégorie, permettant d'évaluer visuellement la qualité du clustering et la distinction entre les différentes classes de produits dans l'espace des caractéristiques apprises par le modèle fine-tuné. Sauvegarde la visualisation t-SNE.

In [None]:
def plot_tsne(features, labels, category_names, save_path="result"):
    """Project features to 2-D (PCA then t-SNE) and save a scatter plot coloured by category.

    Args:
        features: 2-D array-like, one row per sample.
        labels: Integer class labels; each one indexes ``category_names``.
        category_names: Display names. Any run of brackets, double quotes and
            whitespace at either end is stripped.
        save_path: Output directory for ``tsne_finetuned.png``.
    """
    os.makedirs(save_path, exist_ok=True)
    print("⏳ Computing t-SNE...")
    features = np.asarray(features)
    n_samples, n_dims = features.shape
    # PCA cannot keep more components than there are samples or input dimensions.
    pca = PCA(n_components=min(50, n_samples, n_dims), random_state=SEED)
    features_pca = pca.fit_transform(features)
    # scikit-learn's TSNE requires perplexity < n_samples; keep 30 whenever possible.
    perplexity = min(30, max(1, n_samples - 1))
    tsne = TSNE(n_components=2, perplexity=perplexity, learning_rate=200, random_state=SEED, init='pca')
    tsne_features = tsne.fit_transform(features_pca)
    # str.strip removes every leading/trailing bracket, double quote and whitespace character.
    cleaned_category_names = [str(name).strip('[]" \t\r\n') for name in category_names]
    tsne_df = pd.DataFrame({
        'x': tsne_features[:, 0],
        'y': tsne_features[:, 1],
        'category': [cleaned_category_names[i] for i in labels]
    })

    plt.figure(figsize=(16, 12))
    sns.scatterplot(data=tsne_df, x='x', y='y', hue='category',
                    palette=sns.color_palette("husl", len(cleaned_category_names)),
                    s=70, alpha=0.8, legend='full')
    plt.title("Visualisation t-SNE (Fine-Tuned CLIP)", fontsize=16)
    plt.xlabel("t-SNE 1", fontsize=14)
    plt.ylabel("t-SNE 2", fontsize=14)
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0, fontsize=10, title='Catégories', title_fontsize=12)
    plt.grid(True, linestyle='--', alpha=0.2)
    plt.tight_layout()
    plt.savefig(os.path.join(save_path, 'tsne_finetuned.png'), dpi=300, bbox_inches='tight')
    plt.close()
    print(f"✅ t-SNE visualization saved ('{save_path}/tsne_finetuned.png')")

# 8. Analyse d'Attention CLIP
Objectif : Cette cellule explore les mécanismes d'attention du modèle CLIP fine-tuné pour comprendre quelles parties de l'image et quels mots-clés sont les plus influents dans la représentation apprise pour un produit donné.
Description : Définit une fonction `clip_attention_analysis` qui prend l'ID d'un produit, le DataFrame, le modèle, le processeur et le tokenizer. Découpe l'image en patchs et calcule la similarité entre les caractéristiques de chaque patch et celles des mots-clés du produit (les cinq premiers au plus). Visualise ces similarités sur une grille d'images. Calcule également une heatmap d'attention à partir de patchs centrés sur une grille régulière couvrant toute l'image, en moyennant leur similarité cosinus avec l'ensemble des mots-clés. Superpose ensuite cette heatmap à une version en niveaux de gris de l'image pour montrer les régions qui concentrent le plus d'attention. Affiche les scores de similarité des mots-clés et sauvegarde les visualisations.

In [None]:
def clip_attention_analysis(uniq_id, df, model, processor, tokenizer, patch_size=128, resolution=50, category_folder="Home Furnishing"):
    """Generate CLIP interpretability visualisations for a single product.

    Three analyses are run for the first row of ``df`` whose ``uniq_id`` matches:

    1. Decomposition. The image is cut into non-overlapping ``patch_size`` tiles
       (edge tiles may be smaller), each resized to 224x224. Each tile is scored
       against at most the first five keywords: a softmax over raw dot products.
       The tile grid is saved with the best keyword and its score in each title.
    2. Keyword similarity. The whole image is compared with every keyword (cosine
       similarity divided by 0.07, then softmax) and the scores are printed.
    3. Smooth heatmap. Square crops (side about min(width, height) / 10, clipped at
       the borders) centred on a ``resolution`` x ``resolution`` grid are scored
       by cosine similarity against every keyword. The per-crop mean is
       interpolated (cubic) over the image, min-max scaled and overlaid on a
       contrast-enhanced greyscale copy. The three best keywords are marked at
       their peak location.

    Figures are written to ``category_folder``; nothing is displayed.

    Args:
        uniq_id: Identifier looked up in ``df['uniq_id']``.
        df: DataFrame with ``uniq_id``, ``image_path``, ``keywords`` (comma-separated
            string) and ``product_name`` columns.
        model: Object exposing ``get_text_features`` / ``get_image_features``
            directly (presumably a plain CLIP model, unlike the ``model.clip``
            wrapper used by the feature-extraction helpers; confirm with caller).
        processor: CLIP processor used to build ``pixel_values``.
        tokenizer: CLIP tokenizer.
        patch_size: Tile size in pixels for the decomposition grid.
        resolution: Number of heatmap sample points per axis.
        category_folder: Output directory for the figures.

    Returns:
        A dict with ``decomposition`` (per-tile scores), ``keyword_similarities``,
        ``heatmap`` (array scaled to [0, 1]) and ``top_keywords``. Returns ``None``
        if any step fails; the error is printed.
    """
    try:
        product = df[df['uniq_id'] == uniq_id].iloc[0]
        # Decode the image and release the file handle immediately. convert() returns an
        # independent in-memory image, so `img` stays usable after the file is closed.
        with Image.open(product['image_path']) as source:
            img = source.convert('RGB')  # Ensure RGB format
        img_width, img_height = img.size
        # Contrast-enhanced greyscale copy, used as the heatmap background.
        img_bw = img.convert('L')
        enhancer = ImageEnhance.Contrast(img_bw)
        img_bw = enhancer.enhance(1.5)
        img_bw = np.array(img_bw)
        # De-duplicate keywords while keeping their written order. list(set(...)) ordered them
        # by per-process string hashing (PYTHONHASHSEED set at runtime does not affect the
        # running interpreter), so keywords[:5] and every plot changed between runs.
        keywords = list(dict.fromkeys(kw.strip() for kw in product['keywords'].split(',') if kw.strip()))
        print(f"🔍 Analyse du produit: {product['product_name'][:50]}...")
        print("🔠 Mots-clés uniques:", ", ".join(keywords))

        # Create the category folder
        os.makedirs(category_folder, exist_ok=True)

        # --- 1. Decomposition into non-overlapping tiles ---
        patches = []
        positions = []
        step = patch_size
        for y in range(0, img_height, step):
            for x in range(0, img_width, step):
                patch = img.crop((x, y, min(x+patch_size, img_width), min(y+patch_size, img_height)))
                if patch.size[0] > 0 and patch.size[1] > 0:
                    patch = patch.convert('RGB')  # Ensure patch is RGB
                    patch = patch.resize((224, 224), Image.LANCZOS)
                    patches.append(patch)
                    positions.append((x, y, min(x+patch_size, img_width), min(y+patch_size, img_height)))

        if not patches:
            raise ValueError("No valid patches extracted for decomposition.")

        with torch.no_grad():
            # At most the first five keywords are scored against each tile.
            text_inputs = tokenizer(keywords[:5], return_tensors="pt", padding=True, truncation=True, max_length=77).to(device)
            text_features = model.get_text_features(**text_inputs)
            patch_features = []
            for p in patches:
                inputs = processor(images=p, return_tensors="pt", padding=True).pixel_values.to(device).float()
                if inputs.shape[1] != 3:  # Check for 3 channels (RGB)
                    print(f"⚠️ Patch non-RGB détecté, saut du patch")
                    continue
                features = model.get_image_features(pixel_values=inputs)
                patch_features.append(features)
                torch.cuda.empty_cache()
            if not patch_features:
                raise ValueError("No valid patch features extracted.")
            patch_features = torch.cat(patch_features)
            # Raw (un-normalised) dot products, softmax across keywords for each tile.
            similarities = (patch_features @ text_features.T).softmax(dim=-1).cpu().numpy()

        n = int(np.ceil(len(patches)**0.5))
        # squeeze=False keeps a 2-D Axes array even for a single tile (n == 1), where
        # plt.subplots would otherwise return a bare Axes without .flatten().
        fig, axes = plt.subplots(n, n, figsize=(15, 15), squeeze=False)
        axes = axes.flatten()
        for idx, (patch, ax) in enumerate(zip(patches, axes)):
            ax.imshow(patch)
            ax.axis('off')
            if idx < len(similarities):
                top_concept = keywords[similarities[idx].argmax()]
                ax.set_title(f"{top_concept}\n{similarities[idx].max():.2f}", fontsize=8)
        for ax in axes[len(patches):]:
            ax.axis('off')
        plt.suptitle(f"CLIP Decomposition (Fine-Tuned) - {product['product_name'][:50]}...", y=0.92)
        plt.tight_layout()
        decomposition_path = os.path.join(category_folder, f'clip_decomposition_finetuned_{uniq_id}.png')
        plt.savefig(decomposition_path, dpi=300, bbox_inches='tight')
        plt.close()

        # --- 2. Whole-image keyword similarity (cosine / 0.07, softmax over keywords) ---
        with torch.no_grad():
            text_inputs = tokenizer(keywords, return_tensors="pt", padding=True, truncation=True, max_length=77).to(device)
            image_inputs = processor(images=img, return_tensors="pt").pixel_values.to(device).float()
            text_features = model.get_text_features(**text_inputs)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            image_features = model.get_image_features(pixel_values=image_inputs)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            logits_per_image = (image_features @ text_features.T) / 0.07
            probs = logits_per_image.softmax(dim=-1).cpu().numpy()[0]
        results = dict(zip(keywords, probs))
        sorted_results = sorted(results.items(), key=lambda x: x[1], reverse=True)

        print("\n📊 Scores de similarité:")
        for kw, prob in sorted_results:
            print(f"- {kw}: {prob:.4f}")

        # --- 3. Smooth heatmap from crops centred on a regular grid, in batches ---
        torch.cuda.empty_cache()
        x = np.linspace(0, img_width, resolution, dtype=int)
        y = np.linspace(0, img_height, resolution, dtype=int)
        xx, yy = np.meshgrid(x, y)
        positions = []
        batch_size = 10
        patch_features = []
        size = min(img_width, img_height) // 10
        for i in range(0, resolution * resolution, batch_size):
            batch_patches = []
            batch_positions = []
            for j in range(i, min(i + batch_size, resolution * resolution)):
                row, col = divmod(j, resolution)
                x_pos = xx[row, col]
                y_pos = yy[row, col]
                patch = img.crop((max(0, x_pos - size//2), max(0, y_pos - size//2),
                                  min(img_width, x_pos + size//2), min(img_height, y_pos + size//2)))
                if patch.size[0] > 0 and patch.size[1] > 0:
                    patch = patch.convert('RGB')  # Ensure patch is RGB
                    patch = patch.resize((224, 224), Image.LANCZOS)
                    batch_patches.append(patch)
                    batch_positions.append((x_pos, y_pos))
            if batch_patches:
                with torch.no_grad():
                    inputs = processor(images=batch_patches, return_tensors="pt").pixel_values.to(device).float()
                    if inputs.shape[1] != 3:  # Check for 3 channels (RGB)
                        print(f"⚠️ Batch non-RGB détecté, saut du batch")
                        continue
                    features = model.get_image_features(pixel_values=inputs)
                    patch_features.append(features)
                # Positions are recorded only for batches that produced features, keeping
                # `positions` row-aligned with `patch_features`.
                positions.extend(batch_positions)
                torch.cuda.empty_cache()
        if not patch_features:
            raise ValueError("No valid patches extracted for heatmap.")
        patch_features = torch.cat(patch_features)
        patch_features = patch_features / patch_features.norm(dim=-1, keepdim=True)
        with torch.no_grad():
            text_features = model.get_text_features(**text_inputs)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            attention_scores = (patch_features @ text_features.T).cpu().numpy()
        points = np.array(positions)
        grid_x, grid_y = np.mgrid[0:img_width:complex(0, img_width), 0:img_height:complex(0, img_height)]
        smooth_heatmap = griddata(points, attention_scores.mean(axis=1), (grid_x, grid_y), method='cubic', fill_value=0)
        heat_min, heat_max = smooth_heatmap.min(), smooth_heatmap.max()
        if heat_max > heat_min:
            smooth_heatmap = (smooth_heatmap - heat_min) / (heat_max - heat_min)
        else:
            # A flat map would give 0/0 = NaN everywhere; show it as uniformly zero instead.
            smooth_heatmap = np.zeros_like(smooth_heatmap)

        plt.figure(figsize=(16, 10))
        plt.imshow(img_bw, cmap='gray', vmin=0, vmax=255)
        # griddata's grid is indexed (x, y); transpose to the image's (row, column) layout.
        heatmap_layer = plt.imshow(smooth_heatmap.T, cmap='inferno', alpha=0.55,
                                   extent=[0, img_width, img_height, 0], interpolation='bicubic')
        top_keywords = sorted(zip(keywords, attention_scores.mean(axis=0)), key=lambda x: x[1], reverse=True)[:3]
        for kw, score in top_keywords:
            kw_idx = keywords.index(kw)
            max_pos_idx = np.argmax(attention_scores[:, kw_idx])
            max_pos = positions[max_pos_idx]
            plt.scatter(max_pos[0], max_pos[1], s=300, edgecolors='white', linewidths=2, facecolors='none')
            plt.text(max_pos[0], max_pos[1]+img_height*0.03, f"{kw}\n({score:.2f})",
                     color='white', ha='center', va='top', fontsize=11,
                     bbox=dict(facecolor='black', alpha=0.7, boxstyle='round,pad=0.5', edgecolor='white', linewidth=1))
        cbar = plt.colorbar(heatmap_layer, fraction=0.03, pad=0.01)
        cbar.set_label('Intensité d\'attention', rotation=270, labelpad=15)
        plt.title(f"Heatmap d'attention CLIP (Fine-Tuné) - {product['product_name'][:50]}...\nProduit: {uniq_id}", pad=20, fontsize=12)
        plt.axis('off')
        plt.tight_layout()
        heatmap_path = os.path.join(category_folder, f'smooth_attention_finetuned_{uniq_id}.png')
        plt.savefig(heatmap_path, dpi=300, bbox_inches='tight', facecolor='black')
        plt.close()
        print(f"✅ Heatmap saved as {heatmap_path}")

        return {
            'decomposition': similarities,
            'keyword_similarities': dict(sorted_results),
            'heatmap': smooth_heatmap,
            'top_keywords': top_keywords
        }

    except Exception as e:
        # Best effort: report and return None so a batch of analyses can continue.
        print(f"❌ Error: {str(e)}")
        return None

# 9. Analyse des Produits Représentatifs et des Erreurs de Prédiction par Catégorie
Objectif : Cette cellule analyse en détail les produits qui sont les plus "typiques" de chaque catégorie (représentatifs) et ceux qui sont le plus souvent mal classés, en utilisant l'analyse d'attention pour comprendre les raisons.
Description : Définit la fonction `find_closest_to_centers` pour identifier les produits dont les caractéristiques combinées sont les plus proches du centre (moyenne) de leur cluster (catégorie), les considérant comme des représentants typiques. La fonction `analyze_classification_errors` utilise la validation croisée pour identifier les paires de catégories où les erreurs de classification sont les plus fréquentes, en se basant sur la matrice de confusion. Pour les erreurs les plus courantes, elle sélectionne les produits mal classés et applique l'analyse d'attention CLIP (définie dans la cellule précédente) pour visualiser ce qui a pu conduire à la mauvaise prédiction. Sauvegarde les visualisations et génère un rapport CSV des erreurs.

In [None]:
def find_closest_to_centers(features, labels, df, n_examples=3):
    """Return, for each class, the ``n_examples`` products closest to the class centroid.

    The centroid of a class is the mean of the feature vectors carrying that
    label. Products are ranked by their Euclidean distance to it, computed exactly
    and once per class: no centroid recomputation, no k x k distance matrix.

    Args:
        features: 2-D array-like, one row per product, row-aligned with ``df``.
        labels: 1-D array-like of class labels aligned with ``features``. Lists
            are accepted as well as NumPy arrays.
        df: DataFrame of products; rows are selected by position (``iloc``).
        n_examples: Maximum number of products kept per class.

    Returns:
        A copy of the selected ``df`` rows with two extra columns, ``cluster``
        (the class label) and ``distance_to_center``, sorted by class and then
        by distance.
    """
    features = np.asarray(features)
    labels = np.asarray(labels)
    selected_indices = []
    selected_distances = []

    for label in np.unique(labels):
        member_indices = np.flatnonzero(labels == label)
        cluster_points = features[member_indices]
        center = cluster_points.mean(axis=0)
        distances = np.linalg.norm(cluster_points - center, axis=1)
        nearest = np.argsort(distances)[:n_examples]
        selected_indices.extend(member_indices[nearest])
        selected_distances.extend(distances[nearest])

    results = df.iloc[selected_indices].copy()
    results['cluster'] = labels[selected_indices]
    results['distance_to_center'] = selected_distances
    return results.sort_values(['cluster', 'distance_to_center'])

def analyze_classification_errors(features, true_labels, df, category_names, model, processor, tokenizer, top_n_errors=5, n_splits=5, save_folder="error"):
    """Analyze and visualize the most frequent cross-validated classification errors.

    Out-of-fold predictions are produced with a PCA + RandomForest pipeline under
    stratified k-fold cross-validation. Off-diagonal cells of the resulting
    confusion matrix are ranked by count; for the ``top_n_errors`` most frequent
    (true, predicted) pairs, every misclassified product is passed to
    ``clip_attention_analysis`` and gets a keyword-similarity bar chart.

    Args:
        features: Feature matrix whose rows are aligned by position with ``df``.
        true_labels: Encoded labels; ``category_names[k]`` is used as the name of
            label ``k``, so labels are expected in ``range(len(category_names))``.
        df: Product DataFrame aligned by position with ``features``; must provide
            ``uniq_id``, ``product_name`` and ``keywords`` columns.
        category_names: Category names indexed by encoded label.
        model, processor, tokenizer: CLIP components forwarded to
            ``clip_attention_analysis``.
        top_n_errors: Number of most frequent error pairs to inspect.
        n_splits: Number of stratified cross-validation folds.
        save_folder: Root folder for per-error sub-folders and the CSV report.

    Returns:
        DataFrame with one row per misclassified product whose analysis produced
        results (columns ``true_category``, ``predicted_category``, ``uniq_id``,
        ``product_name``, ``keywords``, ``error_count``).
    """
    report_columns = ['true_category', 'predicted_category', 'uniq_id',
                      'product_name', 'keywords', 'error_count']
    os.makedirs(save_folder, exist_ok=True)
    print("\n⏳ Analyzing classification errors...")

    true_labels = np.asarray(true_labels)

    # Out-of-fold predictions: each product is predicted by a model that never saw it.
    pipeline = make_pipeline(
        PCA(n_components=0.95, random_state=SEED),
        RandomForestClassifier(n_estimators=100, random_state=SEED, max_features='sqrt', bootstrap=True)
    )
    cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=SEED)
    pred_labels = cross_val_predict(pipeline, features, true_labels, cv=cv, n_jobs=1)

    # Pin the label set so cm[i, j] always refers to category_names[i] / category_names[j],
    # even if some category never appears among the true or predicted labels.
    n_categories = len(category_names)
    cm = confusion_matrix(true_labels, pred_labels, labels=np.arange(n_categories))

    # Off-diagonal, non-zero cells are error pairs; most frequent first.
    error_pairs = sorted(
        ((i, j, cm[i, j]) for i in range(n_categories) for j in range(n_categories)
         if i != j and cm[i, j] > 0),
        key=lambda pair: pair[2],
        reverse=True,
    )

    # Collect rows in a list and build the DataFrame once (repeated concat is quadratic).
    report_rows = []
    for true_idx, pred_idx, error_count in error_pairs[:top_n_errors]:
        true_cat = category_names[true_idx]
        pred_cat = category_names[pred_idx]

        print(f"\n🔴 Erreur fréquente: '{true_cat}' classé comme '{pred_cat}' ({error_count} erreurs)")

        misclassified_indices = np.flatnonzero((true_labels == true_idx) & (pred_labels == pred_idx))
        if len(misclassified_indices) == 0:
            print("⚠️ Aucun produit trouvé pour cette paire d'erreur")
            continue

        # Positional lookup in the df argument (previously read the global valid_df).
        misclassified_products = df.iloc[misclassified_indices]

        for _, row in misclassified_products.iterrows():
            try:
                error_folder = os.path.join(save_folder,
                                            f"{true_cat.replace('/', '_')}_as_{pred_cat.replace('/', '_')}")
                os.makedirs(error_folder, exist_ok=True)

                print(f"   🔍 Traitement du produit: {row['product_name'][:50]}...")

                analysis_results = clip_attention_analysis(
                    uniq_id=row['uniq_id'],
                    df=df,
                    model=model,
                    processor=processor,
                    tokenizer=tokenizer,
                    category_folder=error_folder
                )
                if not analysis_results:
                    continue

                keywords_list = list(analysis_results['keyword_similarities'].keys())
                scores_list = list(analysis_results['keyword_similarities'].values())
                barchart_path = os.path.join(error_folder, f'keyword_similarity_barchart_{row["uniq_id"]}.png')

                fig = plt.figure(figsize=(10, 6))
                try:
                    ax = sns.barplot(x=scores_list, y=keywords_list, hue=keywords_list, palette="Blues_d", legend=False)
                    plt.title(f"Scores de similarité des mots-clés - {row['product_name'][:50]}...")
                    plt.xlabel("Score de similarité")
                    plt.ylabel("Mots-clés")
                    # Write each score just past the end of its bar.
                    for i, score in enumerate(scores_list):
                        ax.text(score + 0.002, i, f'{score:.4f}', va='center', ha='left', fontsize=10, color='black')
                    plt.tight_layout()
                    plt.savefig(barchart_path, dpi=300, bbox_inches='tight')
                finally:
                    # Close even on failure so figures do not accumulate across products.
                    plt.close(fig)
                print(f"   ✅ Diagramme en barres sauvegardé: {barchart_path}")

                report_rows.append({
                    'true_category': true_cat,
                    'predicted_category': pred_cat,
                    'uniq_id': row['uniq_id'],
                    'product_name': row['product_name'],
                    'keywords': row['keywords'],
                    'error_count': error_count
                })

            except Exception as e:
                print(f"❌ Erreur lors du traitement du produit {row['uniq_id']}: {str(e)}")

    error_df = pd.DataFrame(report_rows, columns=report_columns)

    if not error_df.empty:
        error_df.to_csv(os.path.join(save_folder, 'classification_errors_report.csv'), index=False)
        print(f"\n✅ Rapport d'erreurs sauvegardé: '{save_folder}/classification_errors_report.csv'")

        print("\n📊 RÉSUMÉ DES ERREURS:")
        for _, row in error_df.iterrows():
            print(f"   - {row['true_category']} → {row['predicted_category']}: {row['product_name'][:30]}...")
    else:
        print("\n⚠️ Aucune erreur de classification trouvée")

    return error_df

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import os
from PIL import Image
import numpy as np
from sklearn.metrics import confusion_matrix
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import StratifiedKFold, cross_val_predict

def find_closest_to_centers(features, labels, df, n_examples=3):
    """Return the ``n_examples`` rows of each class closest to that class's mean.

    Args:
        features: 2-D array-like of feature vectors, aligned by *position*
            with ``labels`` and with the rows of ``df``.
        labels: 1-D array-like of class labels, one per row of ``features``.
        df: DataFrame whose rows correspond positionally to ``features``; its
            index may be arbitrary (it is preserved in the result).
        n_examples: Maximum number of rows kept per class.

    Returns:
        A copy of the selected ``df`` rows with two extra columns: ``cluster``
        (the row's label) and ``distance_to_center`` (Euclidean distance to the
        mean feature vector of its class), sorted by cluster then distance.
    """
    features = np.asarray(features)
    labels = np.asarray(labels)

    selected_positions = []
    selected_distances = []
    for label in np.unique(labels):
        # Integer positions, not df index labels: indexing numpy arrays with index
        # labels is only correct when df has a 0..n-1 RangeIndex.
        positions = np.flatnonzero(labels == label)
        cluster_points = features[positions]
        center = cluster_points.mean(axis=0)
        distances = np.linalg.norm(cluster_points - center, axis=1)
        # Stable sort: tied distances keep their original row order.
        order = np.argsort(distances, kind="stable")[:n_examples]
        selected_positions.extend(positions[order])
        selected_distances.extend(distances[order])

    results = df.iloc[selected_positions].copy()
    results['cluster'] = labels[selected_positions]
    results['distance_to_center'] = selected_distances
    return results.sort_values(['cluster', 'distance_to_center'])

def analyze_classification_errors(features, true_labels, df, category_names, model, processor, tokenizer, top_n_errors=5, n_splits=5, save_folder="error"):
    """Analyze and visualize the most frequent cross-validated classification errors.

    Out-of-fold predictions are produced with a PCA + RandomForest pipeline under
    stratified k-fold cross-validation. Off-diagonal cells of the resulting
    confusion matrix are ranked by count; for the ``top_n_errors`` most frequent
    (true, predicted) pairs, every misclassified product is passed to
    ``clip_attention_analysis`` and gets a keyword-similarity bar chart.

    Args:
        features: Feature matrix whose rows are aligned by position with ``df``.
        true_labels: Encoded labels; ``category_names[k]`` is used as the name of
            label ``k``, so labels are expected in ``range(len(category_names))``.
        df: Product DataFrame aligned by position with ``features``; must provide
            ``uniq_id``, ``product_name`` and ``keywords`` columns.
        category_names: Category names indexed by encoded label.
        model, processor, tokenizer: CLIP components forwarded to
            ``clip_attention_analysis``.
        top_n_errors: Number of most frequent error pairs to inspect.
        n_splits: Number of stratified cross-validation folds.
        save_folder: Root folder for per-error sub-folders and the CSV report.

    Returns:
        DataFrame with one row per misclassified product whose analysis produced
        results (columns ``true_category``, ``predicted_category``, ``uniq_id``,
        ``product_name``, ``keywords``, ``error_count``).
    """
    report_columns = ['true_category', 'predicted_category', 'uniq_id',
                      'product_name', 'keywords', 'error_count']
    os.makedirs(save_folder, exist_ok=True)
    print("\n⏳ Analyzing classification errors...")

    true_labels = np.asarray(true_labels)

    # Out-of-fold predictions: each product is predicted by a model that never saw it.
    pipeline = make_pipeline(
        PCA(n_components=0.95, random_state=SEED),
        RandomForestClassifier(n_estimators=100, random_state=SEED, max_features='sqrt', bootstrap=True)
    )
    cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=SEED)
    pred_labels = cross_val_predict(pipeline, features, true_labels, cv=cv, n_jobs=1)

    # Pin the label set so cm[i, j] always refers to category_names[i] / category_names[j],
    # even if some category never appears among the true or predicted labels.
    n_categories = len(category_names)
    cm = confusion_matrix(true_labels, pred_labels, labels=np.arange(n_categories))

    # Off-diagonal, non-zero cells are error pairs; most frequent first.
    error_pairs = sorted(
        ((i, j, cm[i, j]) for i in range(n_categories) for j in range(n_categories)
         if i != j and cm[i, j] > 0),
        key=lambda pair: pair[2],
        reverse=True,
    )

    # Collect rows in a list and build the DataFrame once (repeated concat is quadratic).
    report_rows = []
    for true_idx, pred_idx, error_count in error_pairs[:top_n_errors]:
        true_cat = category_names[true_idx]
        pred_cat = category_names[pred_idx]

        print(f"\n🔴 Erreur fréquente: '{true_cat}' classé comme '{pred_cat}' ({error_count} erreurs)")

        misclassified_indices = np.flatnonzero((true_labels == true_idx) & (pred_labels == pred_idx))
        if len(misclassified_indices) == 0:
            print("⚠️ Aucun produit trouvé pour cette paire d'erreur")
            continue

        # Positional lookup: df rows are aligned with features/true_labels by position.
        misclassified_products = df.iloc[misclassified_indices]

        for _, row in misclassified_products.iterrows():
            try:
                error_folder = os.path.join(save_folder,
                                            f"{true_cat.replace('/', '_')}_as_{pred_cat.replace('/', '_')}")
                os.makedirs(error_folder, exist_ok=True)

                print(f"   🔍 Traitement du produit: {row['product_name'][:50]}...")

                analysis_results = clip_attention_analysis(
                    uniq_id=row['uniq_id'],
                    df=df,
                    model=model,
                    processor=processor,
                    tokenizer=tokenizer,
                    category_folder=error_folder
                )
                if not analysis_results:
                    continue

                keywords_list = list(analysis_results['keyword_similarities'].keys())
                scores_list = list(analysis_results['keyword_similarities'].values())
                barchart_path = os.path.join(error_folder, f'keyword_similarity_barchart_{row["uniq_id"]}.png')

                fig = plt.figure(figsize=(10, 6))
                try:
                    ax = sns.barplot(x=scores_list, y=keywords_list, hue=keywords_list, palette="Blues_d", legend=False)
                    plt.title(f"Scores de similarité des mots-clés - {row['product_name'][:50]}...")
                    plt.xlabel("Score de similarité")
                    plt.ylabel("Mots-clés")
                    # Write each score just past the end of its bar.
                    for i, score in enumerate(scores_list):
                        ax.text(score + 0.002, i, f'{score:.4f}', va='center', ha='left', fontsize=10, color='black')
                    plt.tight_layout()
                    plt.savefig(barchart_path, dpi=300, bbox_inches='tight')
                finally:
                    # Close even on failure so figures do not accumulate across products.
                    plt.close(fig)
                print(f"   ✅ Diagramme en barres sauvegardé: {barchart_path}")

                report_rows.append({
                    'true_category': true_cat,
                    'predicted_category': pred_cat,
                    'uniq_id': row['uniq_id'],
                    'product_name': row['product_name'],
                    'keywords': row['keywords'],
                    'error_count': error_count
                })

            except Exception as e:
                print(f"❌ Erreur lors du traitement du produit {row['uniq_id']}: {str(e)}")

    error_df = pd.DataFrame(report_rows, columns=report_columns)

    if not error_df.empty:
        error_df.to_csv(os.path.join(save_folder, 'classification_errors_report.csv'), index=False)
        print(f"\n✅ Rapport d'erreurs sauvegardé: '{save_folder}/classification_errors_report.csv'")

        print("\n📊 RÉSUMÉ DES ERREURS:")
        for _, row in error_df.iterrows():
            print(f"   - {row['true_category']} → {row['predicted_category']}: {row['product_name'][:30]}...")
    else:
        print("\n⚠️ Aucune erreur de classification trouvée")

    return error_df

# 10. Exécution du Pipeline Complet
Objectif : Cette cellule exécute l'ensemble du pipeline d'analyse multimodale, du chargement des données à l'analyse détaillée des résultats, y compris l'évaluation, la visualisation et l'interprétabilité des erreurs.
Description : Appelle séquentiellement les fonctions définies dans les cellules précédentes : chargement et nettoyage des données (`load_data`, `process_descriptions_to_keywords`), fine-tuning du modèle CLIP (`fine_tune_clip`), extraction des caractéristiques (`extract_text_features`, `extract_image_features`, `combine_features`), évaluation comparative (`compare_modalities`), génération des visualisations (matrice de confusion et t-SNE), et analyse unifiée des produits représentatifs et des erreurs (`unified_analysis_pipeline`). La fonction `unified_analysis_pipeline` est redéfinie ici pour s'assurer qu'elle est disponible dans ce bloc d'exécution, intégrant les appels aux fonctions d'analyse d'attention et de recherche des représentants. Gère également la mémoire GPU (libération du cache) et les erreurs potentielles.

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import os
from PIL import Image
import numpy as np
from sklearn.metrics import confusion_matrix
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import StratifiedKFold, cross_val_predict
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.preprocessing import LabelEncoder # Import LabelEncoder


# Redefine the analysis helpers and unified_analysis_pipeline here so this cell can run on its own
def find_closest_to_centers(features, labels, df, n_examples=3):
    """Return the ``n_examples`` rows of each class closest to that class's mean.

    Args:
        features: 2-D array-like of feature vectors, aligned by *position*
            with ``labels`` and with the rows of ``df``.
        labels: 1-D array-like of class labels, one per row of ``features``.
        df: DataFrame whose rows correspond positionally to ``features``; its
            index may be arbitrary (it is preserved in the result).
        n_examples: Maximum number of rows kept per class.

    Returns:
        A copy of the selected ``df`` rows with two extra columns: ``cluster``
        (the row's label) and ``distance_to_center`` (Euclidean distance to the
        mean feature vector of its class), sorted by cluster then distance.
    """
    features = np.asarray(features)
    labels = np.asarray(labels)

    selected_positions = []
    selected_distances = []
    for label in np.unique(labels):
        # Integer positions, not df index labels: indexing numpy arrays with index
        # labels is only correct when df has a 0..n-1 RangeIndex.
        positions = np.flatnonzero(labels == label)
        cluster_points = features[positions]
        center = cluster_points.mean(axis=0)
        distances = np.linalg.norm(cluster_points - center, axis=1)
        # Stable sort: tied distances keep their original row order.
        order = np.argsort(distances, kind="stable")[:n_examples]
        selected_positions.extend(positions[order])
        selected_distances.extend(distances[order])

    results = df.iloc[selected_positions].copy()
    results['cluster'] = labels[selected_positions]
    results['distance_to_center'] = selected_distances
    return results.sort_values(['cluster', 'distance_to_center'])

def analyze_classification_errors(features, true_labels, df, category_names, model, processor, tokenizer, top_n_errors=5, n_splits=5, save_folder="error"):
    """Analyze and visualize the most frequent cross-validated classification errors.

    Out-of-fold predictions are produced with a PCA + RandomForest pipeline under
    stratified k-fold cross-validation. Off-diagonal cells of the resulting
    confusion matrix are ranked by count; for the ``top_n_errors`` most frequent
    (true, predicted) pairs, every misclassified product is passed to
    ``clip_attention_analysis`` and gets a keyword-similarity bar chart.

    Args:
        features: Feature matrix whose rows are aligned by position with ``df``.
        true_labels: Encoded labels; ``category_names[k]`` is used as the name of
            label ``k``, so labels are expected in ``range(len(category_names))``.
        df: Product DataFrame aligned by position with ``features``; must provide
            ``uniq_id``, ``product_name`` and ``keywords`` columns.
        category_names: Category names indexed by encoded label.
        model, processor, tokenizer: CLIP components forwarded to
            ``clip_attention_analysis``.
        top_n_errors: Number of most frequent error pairs to inspect.
        n_splits: Number of stratified cross-validation folds.
        save_folder: Root folder for per-error sub-folders and the CSV report.

    Returns:
        DataFrame with one row per misclassified product whose analysis produced
        results (columns ``true_category``, ``predicted_category``, ``uniq_id``,
        ``product_name``, ``keywords``, ``error_count``).
    """
    report_columns = ['true_category', 'predicted_category', 'uniq_id',
                      'product_name', 'keywords', 'error_count']
    os.makedirs(save_folder, exist_ok=True)
    print("\n⏳ Analyzing classification errors...")

    true_labels = np.asarray(true_labels)

    # Out-of-fold predictions: each product is predicted by a model that never saw it.
    pipeline = make_pipeline(
        PCA(n_components=0.95, random_state=SEED),
        RandomForestClassifier(n_estimators=100, random_state=SEED, max_features='sqrt', bootstrap=True)
    )
    cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=SEED)
    pred_labels = cross_val_predict(pipeline, features, true_labels, cv=cv, n_jobs=1)

    # Pin the label set so cm[i, j] always refers to category_names[i] / category_names[j],
    # even if some category never appears among the true or predicted labels.
    n_categories = len(category_names)
    cm = confusion_matrix(true_labels, pred_labels, labels=np.arange(n_categories))

    # Off-diagonal, non-zero cells are error pairs; most frequent first.
    error_pairs = sorted(
        ((i, j, cm[i, j]) for i in range(n_categories) for j in range(n_categories)
         if i != j and cm[i, j] > 0),
        key=lambda pair: pair[2],
        reverse=True,
    )

    # Collect rows in a list and build the DataFrame once (repeated concat is quadratic).
    report_rows = []
    for true_idx, pred_idx, error_count in error_pairs[:top_n_errors]:
        true_cat = category_names[true_idx]
        pred_cat = category_names[pred_idx]

        print(f"\n🔴 Erreur fréquente: '{true_cat}' classé comme '{pred_cat}' ({error_count} erreurs)")

        misclassified_indices = np.flatnonzero((true_labels == true_idx) & (pred_labels == pred_idx))
        if len(misclassified_indices) == 0:
            print("⚠️ Aucun produit trouvé pour cette paire d'erreur")
            continue

        # Positional lookup: df rows are aligned with features/true_labels by position.
        misclassified_products = df.iloc[misclassified_indices]

        for _, row in misclassified_products.iterrows():
            try:
                error_folder = os.path.join(save_folder,
                                            f"{true_cat.replace('/', '_')}_as_{pred_cat.replace('/', '_')}")
                os.makedirs(error_folder, exist_ok=True)

                print(f"   🔍 Traitement du produit: {row['product_name'][:50]}...")

                analysis_results = clip_attention_analysis(
                    uniq_id=row['uniq_id'],
                    df=df,
                    model=model,
                    processor=processor,
                    tokenizer=tokenizer,
                    category_folder=error_folder
                )
                if not analysis_results:
                    continue

                keywords_list = list(analysis_results['keyword_similarities'].keys())
                scores_list = list(analysis_results['keyword_similarities'].values())
                barchart_path = os.path.join(error_folder, f'keyword_similarity_barchart_{row["uniq_id"]}.png')

                fig = plt.figure(figsize=(10, 6))
                try:
                    ax = sns.barplot(x=scores_list, y=keywords_list, hue=keywords_list, palette="Blues_d", legend=False)
                    plt.title(f"Scores de similarité des mots-clés - {row['product_name'][:50]}...")
                    plt.xlabel("Score de similarité")
                    plt.ylabel("Mots-clés")
                    # Write each score just past the end of its bar.
                    for i, score in enumerate(scores_list):
                        ax.text(score + 0.002, i, f'{score:.4f}', va='center', ha='left', fontsize=10, color='black')
                    plt.tight_layout()
                    plt.savefig(barchart_path, dpi=300, bbox_inches='tight')
                finally:
                    # Close even on failure so figures do not accumulate across products.
                    plt.close(fig)
                print(f"   ✅ Diagramme en barres sauvegardé: {barchart_path}")

                report_rows.append({
                    'true_category': true_cat,
                    'predicted_category': pred_cat,
                    'uniq_id': row['uniq_id'],
                    'product_name': row['product_name'],
                    'keywords': row['keywords'],
                    'error_count': error_count
                })

            except Exception as e:
                print(f"❌ Erreur lors du traitement du produit {row['uniq_id']}: {str(e)}")

    error_df = pd.DataFrame(report_rows, columns=report_columns)

    # The "no errors" message belongs to this if/else, not to the summary for-loop
    # (a for/else printed it after every successful summary and never when empty).
    if not error_df.empty:
        error_df.to_csv(os.path.join(save_folder, 'classification_errors_report.csv'), index=False)
        print(f"\n✅ Rapport d'erreurs sauvegardé: '{save_folder}/classification_errors_report.csv'")

        print("\n📊 RÉSUMÉ DES ERREURS:")
        for _, row in error_df.iterrows():
            print(f"   - {row['true_category']} → {row['predicted_category']}: "
                  f"{row['product_name'][:30]}...")
    else:
        print("\n✅ Aucune erreur de classification trouvée")

    return error_df

def _save_keyword_similarity_barchart(keyword_similarities, title, save_path, palette):
    """Save a horizontal bar chart of keyword -> similarity scores to ``save_path``.

    Args:
        keyword_similarities: Mapping of keyword to similarity score, plotted in
            its iteration order.
        title: Figure title.
        save_path: Output image path.
        palette: Seaborn palette name used for the bars.
    """
    keywords_list = list(keyword_similarities.keys())
    scores_list = list(keyword_similarities.values())

    fig = plt.figure(figsize=(10, 6))
    try:
        ax = sns.barplot(x=scores_list, y=keywords_list, hue=keywords_list,
                         palette=palette, legend=False)
        plt.title(title)
        plt.xlabel("Score de similarité")
        plt.ylabel("Mots-clés")

        # Write each score just past the end of its bar.
        for i, score in enumerate(scores_list):
            ax.text(score + 0.002, i, f'{score:.4f}', va='center',
                    ha='left', fontsize=10, color='black')

        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        # Close even on failure so figures do not accumulate across products.
        plt.close(fig)


def _run_representative_analysis(features, true_labels, df, category_names, model, processor,
                                 tokenizer, n_representatives, category_save_folder):
    """Run CLIP attention analysis on the products closest to each category's centroid."""
    print("\n" + "="*80)
    print("ANALYSE DES PRODUITS REPRÉSENTATIFS PAR CATÉGORIE")
    print("="*80)

    print("\n⏳ Recherche des produits les plus proches des centres de clusters...")
    closest_products = find_closest_to_centers(features, true_labels, df, n_examples=n_representatives)

    for cluster_id in np.unique(true_labels):
        category_name = category_names[cluster_id]
        print(f"\n🏠 Top {n_representatives} produits pour la catégorie '{category_name}':")

        category_folder = os.path.join(category_save_folder, category_name.replace('/', '_').replace(' ', '_'))
        os.makedirs(category_folder, exist_ok=True)

        category_products = closest_products[closest_products['cluster'] == cluster_id].head(n_representatives)
        if category_products.empty:
            print(f"⚠️ Aucun produit trouvé pour la catégorie '{category_name}'")
            continue

        for _, row in category_products.iterrows():
            print(f"\n🔹 Produit: {row['product_name'][:50]}...")
            print(f"   📏 Distance au centre: {row['distance_to_center']:.4f}")

            try:
                analysis_results = clip_attention_analysis(
                    uniq_id=row['uniq_id'],
                    df=df,
                    model=model,
                    processor=processor,
                    tokenizer=tokenizer,
                    category_folder=category_folder
                )
                if analysis_results:
                    _save_keyword_similarity_barchart(
                        analysis_results['keyword_similarities'],
                        title=f"Scores de similarité - {row['product_name'][:50]}...",
                        save_path=os.path.join(category_folder,
                                               f'keyword_similarity_barchart_{row["uniq_id"]}.png'),
                        palette="Blues_d",
                    )
                    print(f"✅ Visualisations sauvegardées dans: {category_folder}")
            except Exception as e:
                print(f"❌ Erreur lors de l'analyse du produit {row['uniq_id']}: {str(e)}")


def _run_error_analysis(features, true_labels, df, category_names, model, processor, tokenizer,
                        top_n_errors, n_splits, error_save_folder):
    """Analyze the most frequent cross-validated confusions and write a CSV report."""
    print("\n" + "="*80)
    print("ANALYSE DES ERREURS DE CLASSIFICATION")
    print("="*80)

    # Out-of-fold predictions: each product is predicted by a model that never saw it.
    pipeline = make_pipeline(
        PCA(n_components=0.95, random_state=SEED),
        RandomForestClassifier(n_estimators=100, random_state=SEED,
                               max_features='sqrt', bootstrap=True)
    )
    cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=SEED)
    pred_labels = cross_val_predict(pipeline, features, true_labels, cv=cv, n_jobs=1)

    # Pin the label set so cm[i, j] always refers to category_names[i] / category_names[j].
    n_categories = len(category_names)
    cm = confusion_matrix(true_labels, pred_labels, labels=np.arange(n_categories))

    # Off-diagonal, non-zero cells are error pairs; most frequent first.
    error_pairs = sorted(
        ((i, j, cm[i, j]) for i in range(n_categories) for j in range(n_categories)
         if i != j and cm[i, j] > 0),
        key=lambda pair: pair[2],
        reverse=True,
    )

    # Collect rows in a list and build the DataFrame once (repeated concat is quadratic).
    report_rows = []
    for true_idx, pred_idx, error_count in error_pairs[:top_n_errors]:
        true_cat = category_names[true_idx]
        pred_cat = category_names[pred_idx]

        print(f"\n🔴 Erreur: '{true_cat}' → '{pred_cat}' ({error_count} erreurs)")

        misclassified_indices = np.flatnonzero((true_labels == true_idx) & (pred_labels == pred_idx))
        if len(misclassified_indices) == 0:
            print("⚠️ Aucun produit trouvé pour cette paire d'erreur")
            continue

        error_folder = os.path.join(error_save_folder,
                                    f"{true_cat.replace('/', '_')}_as_{pred_cat.replace('/', '_')}")
        os.makedirs(error_folder, exist_ok=True)

        for idx in misclassified_indices:
            # Positional lookup: df rows are aligned with features/true_labels by position.
            row = df.iloc[idx]

            try:
                print(f"   🔍 Traitement: {row['product_name'][:50]}...")

                analysis_results = clip_attention_analysis(
                    uniq_id=row['uniq_id'],
                    df=df,
                    model=model,
                    processor=processor,
                    tokenizer=tokenizer,
                    category_folder=error_folder
                )
                if analysis_results:
                    _save_keyword_similarity_barchart(
                        analysis_results['keyword_similarities'],
                        title=f"Erreur: {true_cat} → {pred_cat} - {row['product_name'][:30]}...",
                        save_path=os.path.join(error_folder,
                                               f'keyword_similarity_barchart_{row["uniq_id"]}.png'),
                        palette="Reds_d",
                    )
                    report_rows.append({
                        'true_category': true_cat,
                        'predicted_category': pred_cat,
                        'uniq_id': row['uniq_id'],
                        'product_name': row['product_name'],
                        'keywords': row['keywords'],
                        'error_count': error_count
                    })

            except Exception as e:
                print(f"❌ Erreur lors du traitement du produit {row['uniq_id']}: {str(e)}")

    error_df = pd.DataFrame(report_rows, columns=['true_category', 'predicted_category', 'uniq_id',
                                                  'product_name', 'keywords', 'error_count'])

    if not error_df.empty:
        error_df.to_csv(os.path.join(error_save_folder, 'classification_errors_report.csv'), index=False)
        print(f"\n✅ Rapport d'erreurs sauvegardé: '{error_save_folder}/classification_errors_report.csv'")

        print("\n📊 RÉSUMÉ DES ERREURS:")
        for _, row in error_df.iterrows():
            print(f"   - {row['true_category']} → {row['predicted_category']}: "
                  f"{row['product_name'][:30]}...")
    else:
        print("\n✅ Aucune erreur de classification trouvée")


def unified_analysis_pipeline(features, true_labels, df, category_names, model, processor, tokenizer,
                              analysis_type="both", n_representatives=3, top_n_errors=5, n_splits=5):
    """Unified analysis of representative products and classification errors.

    Results are written under ``category/`` (representatives) and ``error/``
    (misclassifications plus ``classification_errors_report.csv``).

    Args:
        features: Combined feature matrix, aligned by position with ``df``.
        true_labels: Encoded labels; ``category_names[k]`` names label ``k``.
        df: Product DataFrame (expected to be the filtered ``valid_df`` of the
            main pipeline) with ``uniq_id``, ``product_name`` and ``keywords``.
        category_names: Category names indexed by encoded label.
        model: Fine-tuned CLIP model.
        processor: CLIP processor.
        tokenizer: CLIP tokenizer.
        analysis_type: ``"representatives"``, ``"errors"`` or ``"both"``.
        n_representatives: Number of representative products per category.
        top_n_errors: Number of most frequent error pairs to analyze.
        n_splits: Number of stratified cross-validation folds for the error analysis.

    Raises:
        ValueError: If ``analysis_type`` is not one of the supported values.
    """
    valid_types = ("representatives", "errors", "both")
    if analysis_type not in valid_types:
        # Previously a typo silently skipped every analysis yet still reported success.
        raise ValueError(f"analysis_type must be one of {valid_types}, got {analysis_type!r}")

    category_save_folder = 'category'
    error_save_folder = 'error'
    os.makedirs(category_save_folder, exist_ok=True)
    os.makedirs(error_save_folder, exist_ok=True)

    # Elementwise comparisons below require an array, not a plain list.
    true_labels = np.asarray(true_labels)

    if analysis_type in ("representatives", "both"):
        _run_representative_analysis(features, true_labels, df, category_names, model, processor,
                                     tokenizer, n_representatives, category_save_folder)

    if analysis_type in ("errors", "both"):
        _run_error_analysis(features, true_labels, df, category_names, model, processor, tokenizer,
                            top_n_errors, n_splits, error_save_folder)

    print("\n" + "="*80)
    print("ANALYSE TERMINÉE AVEC SUCCÈS")
    print("="*80)


# Make sure every output folder exists before the pipeline starts writing into it.
for _output_dir in ('result', 'category', 'error', 'training_analysis'):
    os.makedirs(_output_dir, exist_ok=True)
    print(f"✅ Created '{_output_dir}' folder.")

import gc


try:
    # Load the raw catalogue, then derive keyword lists from the descriptions.
    print("\n⏳ Loading data...")
    df = load_data('produits_original.csv', 'images_original')
    df = process_descriptions_to_keywords(df)
    print(f"✅ {len(df)} products loaded")

    # Free cached GPU blocks and Python garbage before the memory-heavy fine-tuning step.
    print("\n⏳ Clearing GPU memory...")
    torch.cuda.empty_cache()
    gc.collect()
    print("✅ GPU memory cleared")

    # fine_tune_clip returns both the tuned model and its training history.
    print("\n⏳ Fine-tuning CLIP model...")
    model, training_history = fine_tune_clip(
        df, processor, tokenizer,
        epochs=5, batch_size=4, accum_steps=4,
        save_path="finetuned_clip",
    )
    model.eval()
    print("✅ Loaded fine-tuned CLIP model")

except Exception as err:
    print(f"\n❌ Error: {str(err)}")
    import traceback
    traceback.print_exc()
finally:
    # Release GPU/Python memory whether or not the run succeeded.
    torch.cuda.empty_cache()
    gc.collect()

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import os
from PIL import Image
import numpy as np
from sklearn.metrics import confusion_matrix
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import StratifiedKFold, cross_val_predict
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.preprocessing import LabelEncoder # Import LabelEncoder


# Analysis helpers and unified_analysis_pipeline, (re)defined here so this cell is self-contained.
def find_closest_to_centers(features, labels, df, n_examples=3):
    """Find the ``n_examples`` products closest to the mean of each label group.

    ``features``, ``labels`` and ``df`` must be row-aligned: row ``k`` of
    ``features`` and ``labels[k]`` describe ``df.iloc[k]``. Everything is
    matched by position, so ``df`` may carry any index (e.g. the
    non-contiguous index of a filtered DataFrame).

    Args:
        features: 2-D array-like of shape (n_samples, n_features).
        labels: 1-D array-like of length n_samples with the group of each row.
        df: DataFrame with n_samples rows.
        n_examples: Maximum number of rows to keep per group.

    Returns:
        A copy of the selected ``df`` rows (original index preserved) with two
        extra columns: ``cluster`` (the group label) and
        ``distance_to_center`` (Euclidean distance to the group mean), sorted
        by ``cluster`` then ``distance_to_center``.
    """
    features = np.asarray(features)
    labels = np.asarray(labels)

    selected_positions = []
    selected_clusters = []
    selected_distances = []

    for label in np.unique(labels):
        # Positional row numbers of this group. Using positions rather than
        # df index labels keeps features/labels/df aligned.
        member_positions = np.flatnonzero(labels == label)
        cluster_points = features[member_positions]
        center = cluster_points.mean(axis=0)
        distances = np.linalg.norm(cluster_points - center, axis=1)
        nearest = np.argsort(distances)[:n_examples]

        selected_positions.extend(member_positions[nearest].tolist())
        selected_clusters.extend([label] * len(nearest))
        # Reuse the distances computed here instead of rebuilding an N x N matrix.
        selected_distances.extend(distances[nearest].tolist())

    results = df.iloc[selected_positions].copy()
    results['cluster'] = selected_clusters
    results['distance_to_center'] = selected_distances
    return results.sort_values(['cluster', 'distance_to_center'])

def analyze_classification_errors(features, true_labels, df, category_names, model, processor, tokenizer, top_n_errors=5, n_splits=5, save_folder="error"):
    """Analyze and visualize the most frequent classification errors.

    Cross-validated predictions of a PCA + RandomForest pipeline are compared
    with the true labels. For the ``top_n_errors`` most frequent
    (true, predicted) category pairs, every misclassified product goes through
    ``clip_attention_analysis``, and a keyword-similarity bar chart is saved in
    a per-pair sub-folder of ``save_folder``.

    Args:
        features: Feature matrix, row-aligned with ``true_labels`` and ``df``.
        true_labels: Integer-encoded categories (positions in ``category_names``).
        df: Product DataFrame row-aligned with ``features``; rows are accessed
            by position, so a filtered DataFrame with a gappy index is fine.
        category_names: Category names indexed by encoded label.
        model, processor, tokenizer: CLIP objects forwarded to
            ``clip_attention_analysis``.
        top_n_errors: Number of most frequent error pairs to analyze.
        n_splits: Number of stratified cross-validation folds.
        save_folder: Output folder for the charts and the CSV report.

    Returns:
        DataFrame with one row per analyzed misclassified product (columns
        ``true_category``, ``predicted_category``, ``uniq_id``,
        ``product_name``, ``keywords``, ``error_count``). It is empty when no
        error was analyzed.
    """
    report_columns = ['true_category', 'predicted_category', 'uniq_id',
                      'product_name', 'keywords', 'error_count']

    os.makedirs(save_folder, exist_ok=True)
    print("\n⏳ Analyzing classification errors...")

    # Element-wise comparisons below require an array (a list would compare as a whole).
    true_labels = np.asarray(true_labels)

    # Out-of-fold predictions, so every sample is predicted by a model that did not see it.
    pipeline = make_pipeline(
        PCA(n_components=0.95, random_state=SEED),
        RandomForestClassifier(n_estimators=100, random_state=SEED, max_features='sqrt', bootstrap=True)
    )
    cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=SEED)
    pred_labels = cross_val_predict(pipeline, features, true_labels, cv=cv, n_jobs=1)

    # Force one row/column per category so cm[i, j] always refers to
    # category_names[i] -> category_names[j], even when a category is missing
    # from this subset (e.g. all of its images failed to load).
    n_categories = len(category_names)
    cm = confusion_matrix(true_labels, pred_labels, labels=np.arange(n_categories))

    # Off-diagonal, non-zero cells as (true, predicted, count), most frequent first.
    error_pairs = [
        (i, j, cm[i, j])
        for i in range(n_categories)
        for j in range(n_categories)
        if i != j and cm[i, j] > 0
    ]
    error_pairs.sort(key=lambda pair: pair[2], reverse=True)

    # Collect report rows and build the DataFrame once. Repeated pd.concat is
    # quadratic and warns when concatenating an empty frame.
    records = []

    for true_idx, pred_idx, error_count in error_pairs[:top_n_errors]:
        true_cat = category_names[true_idx]
        pred_cat = category_names[pred_idx]

        print(f"\n🔴 Erreur fréquente: '{true_cat}' classé comme '{pred_cat}' ({error_count} erreurs)")

        misclassified_indices = np.where((true_labels == true_idx) & (pred_labels == pred_idx))[0]
        if len(misclassified_indices) == 0:
            print("⚠️ Aucun produit trouvé pour cette paire d'erreur")
            continue

        # One sub-folder per error type. '/' is replaced so it does not create nested directories.
        error_folder = os.path.join(save_folder,
                                    f"{true_cat.replace('/', '_')}_as_{pred_cat.replace('/', '_')}")
        os.makedirs(error_folder, exist_ok=True)

        # Positional lookup: df is row-aligned with features/true_labels.
        for _, row in df.iloc[misclassified_indices].iterrows():
            try:
                print(f"   🔍 Traitement du produit: {row['product_name'][:50]}...")

                analysis_results = clip_attention_analysis(
                    uniq_id=row['uniq_id'],
                    df=df,
                    model=model,
                    processor=processor,
                    tokenizer=tokenizer,
                    category_folder=error_folder
                )
                if not analysis_results:
                    continue

                keyword_similarities = analysis_results['keyword_similarities']
                keywords_list = list(keyword_similarities.keys())
                scores_list = list(keyword_similarities.values())
                barchart_path = os.path.join(error_folder, f'keyword_similarity_barchart_{row["uniq_id"]}.png')

                fig = plt.figure(figsize=(10, 6))
                try:
                    ax = sns.barplot(x=scores_list, y=keywords_list, hue=keywords_list, palette="Blues_d", legend=False)
                    plt.title(f"Scores de similarité des mots-clés - {row['product_name'][:50]}...")
                    plt.xlabel("Score de similarité")
                    plt.ylabel("Mots-clés")

                    # Print each score just past the end of its bar.
                    for i, score in enumerate(scores_list):
                        ax.text(score + 0.002, i, f'{score:.4f}', va='center', ha='left', fontsize=10, color='black')

                    plt.tight_layout()
                    plt.savefig(barchart_path, dpi=300, bbox_inches='tight')
                finally:
                    # Release the figure even if plotting failed, so figures do not pile up in the loop.
                    plt.close(fig)
                print(f"   ✅ Diagramme en barres sauvegardé: {barchart_path}")

                records.append({
                    'true_category': true_cat,
                    'predicted_category': pred_cat,
                    'uniq_id': row['uniq_id'],
                    'product_name': row['product_name'],
                    'keywords': row['keywords'],
                    'error_count': error_count
                })

            except Exception as e:
                # Best-effort per product: one failure must not stop the analysis.
                print(f"❌ Erreur lors du traitement du produit {row['uniq_id']}: {str(e)}")

    error_df = pd.DataFrame(records, columns=report_columns)

    if not error_df.empty:
        error_df.to_csv(os.path.join(save_folder, 'classification_errors_report.csv'), index=False)
        print(f"\n✅ Rapport d'erreurs sauvegardé: '{save_folder}/classification_errors_report.csv'")

        print("\n📊 RÉSUMÉ DES ERREURS:")
        for _, row in error_df.iterrows():
            print(f"   - {row['true_category']} → {row['predicted_category']}: "
                  f"{row['product_name'][:30]}...")
    else:
        # Previously this branch was a for-else, so it printed after every non-empty summary.
        print("\n✅ Aucune erreur de classification trouvée")

    return error_df

def unified_analysis_pipeline(features, true_labels, df, category_names, model, processor, tokenizer,
                             analysis_type="both", n_representatives=3, top_n_errors=5, n_splits=5):
    """
    Unified pipeline: analyze representative products and/or classification errors.

    Args:
        features: Combined feature matrix, row-aligned with ``true_labels`` and ``df``.
        true_labels: Integer-encoded true categories (positions in ``category_names``).
        df: Product DataFrame row-aligned with ``features`` (the ``valid_df`` of
            the main pipeline). Rows are accessed by position.
        category_names: Category names indexed by encoded label.
        model: Fine-tuned CLIP model.
        processor: CLIP processor.
        tokenizer: CLIP tokenizer.
        analysis_type: "representatives", "errors" or "both".
        n_representatives: Number of representative products per category.
        top_n_errors: Number of most frequent (true, predicted) error pairs to analyze.
        n_splits: Number of stratified folds for the cross-validated predictions.

    Side effects:
        Writes charts under 'category/' and 'error/'. Writes
        'error/classification_errors_report.csv' when errors were analyzed.
    """
    category_save_folder = 'category'
    error_save_folder = 'error'
    os.makedirs(category_save_folder, exist_ok=True)
    os.makedirs(error_save_folder, exist_ok=True)

    # Element-wise comparisons below require an array.
    true_labels = np.asarray(true_labels)

    def _save_keyword_barchart(analysis_results, title, palette, save_path):
        """Plot keyword similarity scores as a horizontal bar chart and save it to ``save_path``."""
        similarities = analysis_results['keyword_similarities']
        keywords_list = list(similarities.keys())
        scores_list = list(similarities.values())

        fig = plt.figure(figsize=(10, 6))
        try:
            ax = sns.barplot(x=scores_list, y=keywords_list, hue=keywords_list,
                             palette=palette, legend=False)
            plt.title(title)
            plt.xlabel("Score de similarité")
            plt.ylabel("Mots-clés")

            # Print each score just past the end of its bar.
            for i, score in enumerate(scores_list):
                ax.text(score + 0.002, i, f'{score:.4f}', va='center',
                        ha='left', fontsize=10, color='black')

            plt.tight_layout()
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        finally:
            # Release the figure even if plotting failed, so figures do not pile up in loops.
            plt.close(fig)

    if analysis_type in ["representatives", "both"]:
        print("\n" + "="*80)
        print("ANALYSE DES PRODUITS REPRÉSENTATIFS PAR CATÉGORIE")
        print("="*80)

        # Products closest to the mean feature vector of each true category.
        print("\n⏳ Recherche des produits les plus proches des centres de clusters...")
        closest_products = find_closest_to_centers(features, true_labels, df, n_examples=n_representatives)

        for cluster_id in np.unique(true_labels):
            category_name = category_names[cluster_id]
            print(f"\n🏠 Top {n_representatives} produits pour la catégorie '{category_name}':")

            category_folder = os.path.join(category_save_folder, category_name.replace('/', '_').replace(' ', '_'))
            os.makedirs(category_folder, exist_ok=True)

            category_products = closest_products[closest_products['cluster'] == cluster_id].head(n_representatives)
            if category_products.empty:
                print(f"⚠️ Aucun produit trouvé pour la catégorie '{category_name}'")
                continue

            for _, row in category_products.iterrows():
                print(f"\n🔹 Produit: {row['product_name'][:50]}...")
                print(f"   📏 Distance au centre: {row['distance_to_center']:.4f}")

                try:
                    analysis_results = clip_attention_analysis(
                        uniq_id=row['uniq_id'],
                        df=df,
                        model=model,
                        processor=processor,
                        tokenizer=tokenizer,
                        category_folder=category_folder
                    )

                    if analysis_results:
                        _save_keyword_barchart(
                            analysis_results,
                            title=f"Scores de similarité - {row['product_name'][:50]}...",
                            palette="Blues_d",
                            save_path=os.path.join(category_folder,
                                                   f'keyword_similarity_barchart_{row["uniq_id"]}.png'),
                        )
                        print(f"✅ Visualisations sauvegardées dans: {category_folder}")

                except Exception as e:
                    # Best-effort per product: one failure must not stop the analysis.
                    print(f"❌ Erreur lors de l'analyse du produit {row['uniq_id']}: {str(e)}")

    if analysis_type in ["errors", "both"]:
        print("\n" + "="*80)
        print("ANALYSE DES ERREURS DE CLASSIFICATION")
        print("="*80)

        # Out-of-fold predictions of a PCA + RandomForest classifier.
        pipeline = make_pipeline(
            PCA(n_components=0.95, random_state=SEED),
            RandomForestClassifier(n_estimators=100, random_state=SEED,
                                   max_features='sqrt', bootstrap=True)
        )
        cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=SEED)
        pred_labels = cross_val_predict(pipeline, features, true_labels, cv=cv, n_jobs=1)

        # One row/column per category keeps cm[i, j] aligned with category_names,
        # even when a category is absent from this subset.
        n_categories = len(category_names)
        cm = confusion_matrix(true_labels, pred_labels, labels=np.arange(n_categories))

        # Off-diagonal, non-zero cells as (true, predicted, count), most frequent first.
        error_pairs = sorted(
            ((i, j, cm[i, j])
             for i in range(n_categories)
             for j in range(n_categories)
             if i != j and cm[i, j] > 0),
            key=lambda pair: pair[2],
            reverse=True,
        )

        # Report rows are collected here and turned into a DataFrame once.
        records = []

        for true_idx, pred_idx, error_count in error_pairs[:top_n_errors]:
            true_cat = category_names[true_idx]
            pred_cat = category_names[pred_idx]

            print(f"\n🔴 Erreur: '{true_cat}' → '{pred_cat}' ({error_count} erreurs)")

            misclassified_indices = np.where((true_labels == true_idx) & (pred_labels == pred_idx))[0]
            if len(misclassified_indices) == 0:
                print("⚠️ Aucun produit trouvé pour cette paire d'erreur")
                continue

            # One sub-folder per error type. '/' is replaced so it does not create nested directories.
            error_folder = os.path.join(error_save_folder,
                                        f"{true_cat.replace('/', '_')}_as_{pred_cat.replace('/', '_')}")
            os.makedirs(error_folder, exist_ok=True)

            for idx in misclassified_indices:
                # Positional lookup: df is row-aligned with features/true_labels.
                row = df.iloc[idx]

                try:
                    print(f"   🔍 Traitement: {row['product_name'][:50]}...")

                    analysis_results = clip_attention_analysis(
                        uniq_id=row['uniq_id'],
                        df=df,
                        model=model,
                        processor=processor,
                        tokenizer=tokenizer,
                        category_folder=error_folder
                    )

                    if analysis_results:
                        _save_keyword_barchart(
                            analysis_results,
                            title=f"Erreur: {true_cat} → {pred_cat} - {row['product_name'][:30]}...",
                            palette="Reds_d",
                            save_path=os.path.join(error_folder,
                                                   f'keyword_similarity_barchart_{row["uniq_id"]}.png'),
                        )
                        records.append({
                            'true_category': true_cat,
                            'predicted_category': pred_cat,
                            'uniq_id': row['uniq_id'],
                            'product_name': row['product_name'],
                            'keywords': row['keywords'],
                            'error_count': error_count
                        })

                except Exception as e:
                    # Best-effort per product: one failure must not stop the analysis.
                    print(f"❌ Erreur lors du traitement du produit {row['uniq_id']}: {str(e)}")

        error_df = pd.DataFrame(records, columns=['true_category', 'predicted_category', 'uniq_id',
                                                  'product_name', 'keywords', 'error_count'])

        if not error_df.empty:
            error_df.to_csv(os.path.join(error_save_folder, 'classification_errors_report.csv'), index=False)
            print(f"\n✅ Rapport d'erreurs sauvegardé: '{error_save_folder}/classification_errors_report.csv'")

            print("\n📊 RÉSUMÉ DES ERREURS:")
            for _, row in error_df.iterrows():
                print(f"   - {row['true_category']} → {row['predicted_category']}: "
                      f"{row['product_name'][:30]}...")
        else:
            print("\n✅ Aucune erreur de classification trouvée")

    print("\n" + "="*80)
    print("ANALYSE TERMINÉE AVEC SUCCÈS")
    print("="*80)


# Create every output folder up front so later steps can write into them
# without checking; exist_ok makes re-running the cell harmless.
for _output_folder in ('result', 'category', 'error', 'training_analysis'):
    os.makedirs(_output_folder, exist_ok=True)
    print(f"✅ Created '{_output_folder}' folder.")


try:
    # Learning curves of the fine-tuning run (training_history comes from the fine-tuning cell).
    print("\n📊 Plotting training curves...")
    summary_df = plot_training_curves(training_history, save_path="result/training_analysis")

    print("\n📊 FINAL METRICS SUMMARY:")
    print(summary_df.to_string(index=False))

    # Overfitting check on the 'Gap' column of the 'Accuracy' row
    # (presumably train minus validation accuracy; confirm in plot_training_curves).
    # A missing row only skips the check instead of aborting the whole analysis.
    accuracy_gap_values = summary_df.loc[summary_df['Metric'] == 'Accuracy', 'Gap'].values
    if len(accuracy_gap_values) == 0:
        print("\n⚠️  No 'Accuracy' row in the training summary; overfitting check skipped.")
    else:
        accuracy_gap = accuracy_gap_values[0]
        if accuracy_gap > 0.1:
            print(f"\n⚠️  WARNING: Potential overfitting detected! Accuracy gap: {accuracy_gap:.4f}")
        elif accuracy_gap > 0.05:
            print(f"\nℹ️  Moderate overfitting detected. Accuracy gap: {accuracy_gap:.4f}")
        else:
            print(f"\n✅ Good generalization. Accuracy gap: {accuracy_gap:.4f}")

    # Encode categories and extract text / image / combined features.
    categories_encoded, category_names = pd.factorize(df['main_category'])
    print("\n🔍 Extracting features...")
    text_features = extract_text_features(df, model, tokenizer)
    image_features, valid_df = extract_image_features(df, model, processor, max_size=128)
    # NOTE(review): valid_df.index is used positionally into arrays built from df,
    # which assumes df has a default RangeIndex; confirm against load_data.
    valid_categories = categories_encoded[valid_df.index]
    combined_features = combine_features(text_features[valid_df.index], image_features, alpha=0.6)

    print("\n📊 Evaluating modalities...")
    results_df, true_labels_cv, pred_labels_cv = compare_modalities(df, text_features, image_features, combined_features, categories_encoded, valid_df, valid_categories, save_folder="result")

    print("\n📊 Generating visualizations...")
    print("Calling plot_confusion_matrix with save_path='result'")
    plot_confusion_matrix(combined_features, valid_categories, category_names, save_path="result")

    print("Calling plot_tsne with save_path='result'")
    plot_tsne(combined_features, valid_categories, category_names, save_path="result")

    # Error analysis on the images that were successfully processed:
    # labels/features/DataFrame must all refer to valid_df rows.
    print("\n📊 Analyzing classification errors...")
    error_report = analyze_classification_errors(
        combined_features,   # features used for classification
        valid_categories,    # true labels of the processed rows
        valid_df,            # product info, row-aligned with the features
        category_names,
        model,
        processor,
        tokenizer,
        top_n_errors=5,
        n_splits=5,
        save_folder="error"
    )
    print("\n✅ Error analysis completed")

    # Errors were analyzed just above (same CV, same 'error' folder). Running the
    # unified pipeline with "both" would redo that work and overwrite its outputs,
    # so only the representative-product analysis is requested here.
    print("\n📊 Analyzing representative products...")
    unified_analysis_pipeline(
        features=combined_features,
        true_labels=valid_categories,
        df=valid_df,
        category_names=category_names,
        model=model,
        processor=processor,
        tokenizer=tokenizer,
        analysis_type="representatives",  # "representatives", "errors" or "both"
        n_representatives=3,
        top_n_errors=5
    )

except Exception as e:
    # Best-effort cell: report the failure with its traceback instead of raising.
    print(f"\n❌ Error: {str(e)}")
    import traceback
    traceback.print_exc()
finally:
    torch.cuda.empty_cache()
    import gc
    gc.collect()

In [None]:
# Product whose CLIP attention and keyword similarities are inspected.
product_uniq_id_to_analyze = '1120bc768623572513df956172ffefeb'

# Rows of the full DataFrame matching that id.
product_row = df[df['uniq_id'] == product_uniq_id_to_analyze]

if product_row.empty:
    print(f"❌ Product with uniq_id '{product_uniq_id_to_analyze}' not found in the DataFrame.")
else:
    print(f"✅ Found product with uniq_id: {product_uniq_id_to_analyze}")

    # Dedicated output folder for this one-off analysis.
    analysis_folder = "attention_analysis"
    os.makedirs(analysis_folder, exist_ok=True)

    attention_results = clip_attention_analysis(
        uniq_id=product_uniq_id_to_analyze,
        df=df,
        model=model,
        processor=processor,
        tokenizer=tokenizer,
        category_folder=analysis_folder,
    )

    if not attention_results:
        print(f"❌ Attention analysis failed for {product_uniq_id_to_analyze}.")
    else:
        print(f"\n✅ Attention analysis completed for {product_uniq_id_to_analyze}. Results saved in '{analysis_folder}' folder.")

        # Horizontal bar chart of the keyword similarity scores.
        plt.figure(figsize=(10, 6))
        similarities = attention_results['keyword_similarities']
        keywords_list = list(similarities.keys())
        scores_list = list(similarities.values())
        product_title = product_row['product_name'].iloc[0][:50]

        ax = sns.barplot(x=scores_list, y=keywords_list, hue=keywords_list, palette="Blues_d", legend=False)
        plt.title(f"Scores de similarité des mots-clés - {product_title}...")
        plt.xlabel("Score de similarité")
        plt.ylabel("Mots-clés")

        # Print each score just past the end of its bar.
        for position, value in enumerate(scores_list):
            ax.text(value + 0.002, position, f'{value:.4f}', va='center', ha='left', fontsize=10, color='black')

        plt.tight_layout()
        barchart_path = os.path.join(analysis_folder, f'keyword_similarity_barchart_{product_uniq_id_to_analyze}.png')
        plt.savefig(barchart_path, dpi=300, bbox_inches='tight')
        plt.close()
        print(f"   ✅ Diagramme en barres sauvegardé: {barchart_path}")

In [None]:
import torch
from transformers import CLIPModel, CLIPTokenizer, CLIPProcessor
import torch.nn as nn
import os

# CLIPForClassification must match the class used at training time so that saved
# state_dict keys line up. It is redefined here so this cell can run on its own.
class CLIPForClassification(CLIPModel):
    """CLIP backbone with a linear classifier on the concatenated image/text embeddings.

    NOTE(review): ``super().__init__(config)`` builds a full, freshly initialised
    CLIP inherited from ``CLIPModel``, and ``self.clip`` holds a second,
    pretrained CLIP. Only ``self.clip`` is used by ``forward``, so the inherited
    weights are dead memory. The layout is kept because saved checkpoints'
    state_dict keys depend on it.
    """

    def __init__(self, config, num_labels, base_model_name="openai/clip-vit-base-patch32"):
        super().__init__(config)
        # Backbone actually used in forward(); must be the same base model as at training time.
        self.clip = CLIPModel.from_pretrained(base_model_name)
        # Image and text embeddings (projection_dim each) are concatenated before the head.
        self.classifier = nn.Linear(config.projection_dim * 2, num_labels)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, pixel_values, input_ids, attention_mask, labels=None):
        """Run the wrapped CLIP and classify the concatenated embeddings.

        Returns:
            An object with attributes ``loss`` (cross-entropy, or None when
            ``labels`` is None), ``logits``, ``image_embeds`` and ``text_embeds``.
        """
        from types import SimpleNamespace

        outputs = self.clip(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = torch.cat((outputs.image_embeds, outputs.text_embeds), dim=-1)
        logits = self.classifier(pooled_output)

        loss = self.loss_fn(logits, labels) if labels is not None else None

        # SimpleNamespace gives the same attribute access as before without
        # creating a brand-new class object on every forward call.
        return SimpleNamespace(
            loss=loss,
            logits=logits,
            image_embeds=outputs.image_embeds,
            text_embeds=outputs.text_embeds,
        )

def load_finetuned_clip_model(pth_path, num_labels, device):
    """
    Load a fine-tuned ``CLIPForClassification`` from a ``.pth`` state_dict file.

    The checkpoint may come from a plain ``CLIPModel`` (keys without a
    ``clip.`` prefix) or from a ``CLIPForClassification`` (keys already
    prefixed with ``clip.``, plus ``classifier.*``). A bare CLIP key is loaded
    into the exactly-matching parameter (the inherited, unused backbone).
    It is also loaded into ``clip.<key>``, the wrapped backbone that
    ``forward`` actually uses, unless the checkpoint already stores
    ``clip.<key>`` itself.

    Args:
        pth_path: Path to the saved state_dict.
        num_labels: Number of output classes; must match training.
        device: Device the weights are mapped to and the model is moved to.

    Returns:
        The model on ``device``, switched to eval mode.

    Raises:
        FileNotFoundError: if ``pth_path`` does not exist (raised by ``torch.load``).
    """
    # Only the config is needed to build the architecture, so avoid loading the
    # full pretrained weights a second time just to read it.
    config = CLIPModel.config_class.from_pretrained("openai/clip-vit-base-patch32")
    model = CLIPForClassification(config, num_labels=num_labels).to(device)

    # The file is a state_dict (tensors only), so refuse arbitrary pickled objects.
    state_dict = torch.load(pth_path, map_location=device, weights_only=True)

    model_keys = set(model.state_dict().keys())
    new_state_dict = {}
    for key, value in state_dict.items():
        if key in model_keys:
            # Exact match: classifier.*, clip.* or the inherited backbone copy.
            new_state_dict[key] = value
        prefixed_key = 'clip.' + key
        # Because the class inherits from CLIPModel, bare keys always match the
        # unused inherited copy above. Route them into the wrapped backbone used
        # by forward() as well, without overriding weights saved under 'clip.'.
        if prefixed_key in model_keys and prefixed_key not in state_dict:
            new_state_dict[prefixed_key] = value

    # strict=False: parameters absent from the checkpoint keep their initial
    # values and are reported as missing below.
    load_result = model.load_state_dict(new_state_dict, strict=False)
    model.eval()

    print(f"✅ Model loaded successfully from {pth_path}")
    print(f"Missing keys: {load_result.missing_keys}")
    print(f"Unexpected keys: {load_result.unexpected_keys}")

    return model

# Reload the fine-tuned classifier for inference.
# num_labels must equal the number of categories used at training time; it is
# taken from the 'category_names' variable produced by the main analysis cell.
if 'category_names' not in locals():
    print("⚠️ 'category_names' variable not found. Please run the data loading cell (Cell 2) first to define it.")
    print("You will need to manually set 'num_labels' or ensure 'category_names' is available before running this cell.")
else:
    num_labels = len(category_names)
    print(f"Detected {num_labels} labels from previous execution.")
    try:
        finetuned_model = load_finetuned_clip_model(
            pth_path="new_clip_product_classifier.pth",
            num_labels=num_labels,
            device=device,  # expected to be defined by the setup cell
        )
        print("✅ Fine-tuned model loaded for inference.")
    except FileNotFoundError:
        print("❌ Error: 'new_clip_product_classifier.pth' not found. Please run the fine-tuning cell first.")
    except Exception as e:
        print(f"❌ An error occurred during model loading: {str(e)}")