In [12]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image

In [14]:
import os
import time
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm  # Barre de progression

# Hugging Face & LoRA
from transformers import AutoImageProcessor, AutoModelForDepthEstimation
from peft import LoraConfig, get_peft_model

# ==========================================
# 1. CONFIGURATION (SAFE MODE)
# ==========================================
# Set to True on the personal PC (avoids the RTX 5070 / driver crash).
# Set to False when launching on the school cluster.
FORCE_CPU = True

MODEL_ID = "depth-anything/Depth-Anything-V2-Small-hf"
OUTPUT_DIR = "./resultats_projet"
BATCH_SIZE = 4
LR = 1e-4
EPOCHS = 10  # Enough to get visible results

# Device selection: the CPU override wins over CUDA auto-detection.
if not FORCE_CPU:
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"üöÄ D√©marrage sur : {DEVICE}")
    if DEVICE == "cuda":
        print(f"Carte : {torch.cuda.get_device_name(0)}")
else:
    DEVICE = "cpu"
    print("‚ö†Ô∏è MODE CPU FORC√â (Lent mais stable pour g√©n√©rer le rapport)")

# Make sure the folder receiving the generated files exists before anything runs.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ==========================================
# 2. DATASET (Zivid + Correction Unit√©s)
# ==========================================
class ZividDataset(Dataset):
    """RGB / depth pairs read from a Zivid capture folder.

    ``root_dir`` must contain an ``images`` sub-folder (``.png``, ``.jpg`` or
    ``.jpeg`` files, case-insensitive) and a ``depth`` sub-folder. The depth
    file of an image is looked up as ``<stem with "_color" removed>_rawDepth.npy``
    and is read as an array whose index 2 on the last axis is taken as Z.

    Args:
        root_dir: Folder holding the ``images`` and ``depth`` sub-folders.
        processor: Callable invoked as
            ``processor(images=image, return_tensors="pt")`` that returns a
            mapping with a ``"pixel_values"`` tensor.

    Raises:
        FileNotFoundError: If ``<root_dir>/images`` does not exist.
    """

    def __init__(self, root_dir, processor):
        self.img_dir = os.path.join(root_dir, "images")
        self.depth_dir = os.path.join(root_dir, "depth")
        self.processor = processor

        valid_ext = ('.png', '.jpg', '.jpeg')
        if not os.path.exists(self.img_dir):
            raise FileNotFoundError(f"‚ùå Dossier introuvable : {self.img_dir}")

        # Sorted so that sample indices are stable from one run to the next.
        self.images = sorted(
            f for f in os.listdir(self.img_dir) if f.lower().endswith(valid_ext)
        )

    def __len__(self):
        """Return the number of image files found in the ``images`` folder."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return one sample, or ``None`` when its depth file cannot be used.

        Returns:
            A dict with ``"pixel_values"`` (processor output without its
            leading batch axis), ``"labels"`` (depth resized to the processed
            image size, non-finite pixels set to 0) and ``"mask"`` (1.0 where
            the depth is finite and strictly positive, 0.0 elsewhere); or
            ``None`` if the depth file is missing, unreadable, empty or cannot
            be indexed as ``[:, :, 2]``.

        Raises:
            IndexError: If ``idx`` is out of range (this is also what ends a
                plain ``for sample in dataset`` iteration).
        """
        img_name = self.images[idx]
        # "<name>_color.jpg" is paired with "<name>_rawDepth.npy".
        base_name = os.path.splitext(img_name)[0].replace("_color", "")
        npy_name = base_name + "_rawDepth.npy"

        img_path = os.path.join(self.img_dir, img_name)
        npy_path = os.path.join(self.depth_dir, npy_name)

        # 1. Image
        image = Image.open(img_path).convert("RGB")

        # 2. Depth (ground truth)
        try:
            point_cloud = np.load(npy_path)
            depth_Z = point_cloud[:, :, 2]  # keep Z only
            if depth_Z.size == 0:
                raise ValueError("empty depth map")

            # --- IMPORTANT: MM -> METRES ---
            # A largest depth above 100 is taken to mean millimetres. Only
            # finite values are inspected, so a stray +inf pixel cannot
            # trigger the division on data that is already in metres.
            finite_depth = depth_Z[np.isfinite(depth_Z)]
            if finite_depth.size and finite_depth.max() > 100:
                depth_Z = depth_Z / 1000.0
        except Exception as err:
            # Best effort: the sample is skipped (callers filter out None),
            # but the reason is reported instead of being silently dropped.
            print(f"[ZividDataset] skipping {img_name}: {err}")
            return None

        # 3. Model inputs
        inputs = self.processor(images=image, return_tensors="pt")
        target_h, target_w = inputs["pixel_values"].shape[-2:]

        # Two leading axes are added because interpolate() works on 4-D input.
        depth_tensor = torch.from_numpy(depth_Z).float().unsqueeze(0).unsqueeze(0)

        # 4. Mask: ignore NaN, +/-inf and values <= 0.
        mask = ~torch.isnan(depth_tensor) & ~torch.isinf(depth_tensor) & (depth_tensor > 0)
        # Infinities are zeroed as well: by default nan_to_num() turns them
        # into the largest finite float, and squaring such a label overflows
        # to inf, which a 0-valued mask turns into NaN (0 * inf) in the loss.
        depth_tensor = torch.nan_to_num(depth_tensor, nan=0.0, posinf=0.0, neginf=0.0)

        # 5. Resize (nearest neighbour, so that no depth value is invented)
        depth_resized = torch.nn.functional.interpolate(
            depth_tensor, size=(target_h, target_w), mode='nearest'
        )
        mask_resized = torch.nn.functional.interpolate(
            mask.float(), size=(target_h, target_w), mode='nearest'
        )

        return {
            "pixel_values": inputs["pixel_values"].squeeze(0),
            "labels": depth_resized.squeeze(0),
            "mask": mask_resized.squeeze(0)
        }

# ==========================================
# 3. FONCTIONS M√âTRIQUES & VISUALISATION
# ==========================================
def compute_metrics(pred, target, mask):
    """Compute AbsRel, RMSE and the delta < 1.25 accuracy over masked pixels.

    Args:
        pred: Predicted depth tensor.
        target: Ground-truth depth tensor. The selected values are used as
            divisors, so they are expected to be non-zero.
        mask: Tensor selecting the pixels to score; it is converted to bool,
            so any non-zero entry counts as selected. It is used to index
            both ``pred`` and ``target``.

    Returns:
        Tuple ``(abs_rel, rmse, delta1)`` of Python floats, or
        ``(0.0, 0.0, 0.0)`` when the mask selects no pixel.
    """
    # Accept 0/1 float masks as well as boolean ones.
    mask = mask.bool()
    pred = pred[mask]
    target = target[mask]

    if target.numel() == 0:
        return 0.0, 0.0, 0.0

    # AbsRel (mean relative error)
    abs_rel = torch.mean(torch.abs(pred - target) / target)
    # RMSE (root mean squared error)
    rmse = torch.sqrt(torch.mean((pred - target) ** 2))
    # Delta < 1.25 accuracy (share of pixels within a 1.25 ratio of the truth).
    # A negative prediction makes both ratios negative, which would pass the
    # "< 1.25" test, so non-positive predictions are explicitly counted as
    # misses.
    max_ratio = torch.max(pred / target, target / pred)
    delta1 = ((max_ratio < 1.25) & (pred > 0)).float().mean()

    return abs_rel.item(), rmse.item(), delta1.item()

def save_comparison_image(pixel_values, true_depth, pred_depth, epoch):
    """Write the three-panel figure (input, ground truth, prediction) of an epoch.

    The figure is saved as ``resultat_epoch_<epoch>.png`` inside ``OUTPUT_DIR``.
    """
    # Move axis 0 last for imshow, then min-max rescale into [0, 1] for display.
    rgb = pixel_values.permute(1, 2, 0).cpu().numpy()
    lowest = rgb.min()
    rgb = (rgb - lowest) / (rgb.max() - lowest)

    truth = true_depth.squeeze().cpu().numpy()
    estimate = pred_depth.squeeze().detach().cpu().numpy()

    # (data, colormap, title) for each panel, left to right.
    panels = (
        (rgb, None, "Input RGB"),
        (truth, 'inferno', "V√©rit√© (Z)"),
        (estimate, 'inferno', f"Pr√©diction (Ep {epoch})"),
    )

    plt.figure(figsize=(15, 5))
    for slot, (data, colormap, title) in enumerate(panels, start=1):
        plt.subplot(1, 3, slot)
        plt.imshow(data, cmap=colormap)
        plt.title(title)
        plt.axis('off')

    target_file = os.path.join(OUTPUT_DIR, f"resultat_epoch_{epoch}.png")
    plt.savefig(target_file)
    plt.close()

# ==========================================
# 4. FONCTION MAIN (TRAIN)
# ==========================================
def run_project():
    """Fine-tune the depth model with LoRA on the Zivid data and save it.

    Steps: load the processor and model named by ``MODEL_ID``, wrap the model
    with LoRA on the ``query`` / ``value`` modules, preload every usable
    sample of ``<cwd>/DATASET_DEVOIR``, train for ``EPOCHS`` epochs with a
    masked mean squared error, write one comparison figure per epoch, and
    finally call ``save_pretrained`` into ``OUTPUT_DIR/modele_final_lora``.

    Returns early, after printing the error, when the dataset cannot be built
    or contains no usable sample.
    """
    base_dir = os.getcwd()
    dataset_dir = os.path.join(base_dir, "DATASET_DEVOIR")

    print("‚è≥ Chargement du mod√®le...")
    processor = AutoImageProcessor.from_pretrained(MODEL_ID)
    model = AutoModelForDepthEstimation.from_pretrained(MODEL_ID)

    # --- LoRA CONFIG ---
    # The attention modules are targeted.
    lora_config = LoraConfig(
        r=16, lora_alpha=16, target_modules=["query", "value"],
        lora_dropout=0.1, bias="none"
    )
    model = get_peft_model(model, lora_config)
    model.to(DEVICE)
    model.print_trainable_parameters()

    # --- DATASET ---
    print(f"üìÇ Donn√©es : {dataset_dir}")
    try:
        dataset = ZividDataset(dataset_dir, processor)
        # Samples whose depth file is unusable come back as None: drop them.
        dataset = [d for d in dataset if d is not None]
        if not dataset:
            # Without this guard the per-epoch averages below would divide by
            # len(loader) == 0; report it through the same error path instead.
            raise ValueError(f"no usable sample found in {dataset_dir}")
    except Exception as e:
        print(f"‚ùå Erreur : {e}")
        return

    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    # Only parameters that require gradients are handed to the optimizer.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=LR)

    print(f"üî• C'est parti pour {EPOCHS} √©poques !")

    model.train()
    for epoch in range(1, EPOCHS + 1):
        total_loss = 0
        total_delta = 0

        loop = tqdm(loader, desc=f"Epoch {epoch}/{EPOCHS}")

        for batch_idx, batch in enumerate(loop):
            pixel_values = batch["pixel_values"].to(DEVICE)
            labels = batch["labels"].to(DEVICE)
            mask = batch["mask"].to(DEVICE)

            optimizer.zero_grad()

            # Forward
            outputs = model(pixel_values=pixel_values)
            predicted_depth = outputs.predicted_depth

            # Bring the prediction to the spatial size of the labels.
            prediction = torch.nn.functional.interpolate(
                predicted_depth.unsqueeze(1), size=labels.shape[-2:],
                mode="bilinear", align_corners=False
            )

            # Loss on valid pixels only; the epsilon keeps an all-invalid
            # batch from dividing by zero.
            loss = torch.sum(mask * (prediction - labels)**2) / (torch.sum(mask) + 1e-6)

            loss.backward()
            optimizer.step()

            # Metrics (computed on the prediction made before this step).
            with torch.no_grad():
                abs_rel, rmse, d1 = compute_metrics(prediction, labels, mask.bool())

            loss_value = loss.item()
            total_loss += loss_value
            total_delta += d1

            loop.set_postfix(Loss=f"{loss_value:.4f}", Delta=f"{d1:.3f}")

            # Save a reference figure (first batch of the epoch only).
            if batch_idx == 0:
                save_comparison_image(pixel_values[0], labels[0], prediction[0], epoch)

        # End-of-epoch summary (averages over the batches of the epoch).
        avg_loss = total_loss / len(loader)
        avg_delta = total_delta / len(loader)
        print(f"üèÅ Epoch {epoch} termin√© | Loss: {avg_loss:.4f} | Pr√©cision (Delta): {avg_delta:.4f}")

    # Final save
    save_path = os.path.join(OUTPUT_DIR, "modele_final_lora")
    model.save_pretrained(save_path)
    print(f"üéâ Termin√© ! Mod√®le sauvegard√© dans : {save_path}")

# Launch the training pipeline only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    run_project()

‚ö†Ô∏è MODE CPU FORC√â (Lent mais stable pour g√©n√©rer le rapport)
‚è≥ Chargement du mod√®le...
trainable params: 294,912 || all params: 25,080,001 || trainable%: 1.1759
üìÇ Donn√©es : c:\Users\simon\Documents\git\Transformer-projet\DATASET_DEVOIR
üî• C'est parti pour 10 √©poques !


Epoch 1/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 15/15 [01:21<00:00,  5.46s/it, Delta=0.110, Loss=1.6556]


üèÅ Epoch 1 termin√© | Loss: 2.0889 | Pr√©cision (Delta): 0.1665


Epoch 2/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 15/15 [01:18<00:00,  5.20s/it, Delta=0.142, Loss=1.2178]


üèÅ Epoch 2 termin√© | Loss: 1.0025 | Pr√©cision (Delta): 0.2290


Epoch 3/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 15/15 [01:17<00:00,  5.18s/it, Delta=0.183, Loss=0.8615]


üèÅ Epoch 3 termin√© | Loss: 0.6454 | Pr√©cision (Delta): 0.3133


Epoch 4/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 15/15 [01:18<00:00,  5.26s/it, Delta=0.479, Loss=0.3158]


üèÅ Epoch 4 termin√© | Loss: 0.3881 | Pr√©cision (Delta): 0.3962


Epoch 5/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 15/15 [01:21<00:00,  5.44s/it, Delta=0.489, Loss=0.1750]


üèÅ Epoch 5 termin√© | Loss: 0.2432 | Pr√©cision (Delta): 0.5268


Epoch 6/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 15/15 [01:17<00:00,  5.15s/it, Delta=0.724, Loss=0.1319]


üèÅ Epoch 6 termin√© | Loss: 0.1795 | Pr√©cision (Delta): 0.6078


Epoch 7/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 15/15 [01:24<00:00,  5.63s/it, Delta=0.648, Loss=0.1218]


üèÅ Epoch 7 termin√© | Loss: 0.1377 | Pr√©cision (Delta): 0.6703


Epoch 8/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 15/15 [01:25<00:00,  5.68s/it, Delta=0.629, Loss=0.1410]


üèÅ Epoch 8 termin√© | Loss: 0.1176 | Pr√©cision (Delta): 0.7266


Epoch 9/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 15/15 [01:30<00:00,  6.05s/it, Delta=0.878, Loss=0.0576]


üèÅ Epoch 9 termin√© | Loss: 0.0972 | Pr√©cision (Delta): 0.7754


Epoch 10/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 15/15 [01:34<00:00,  6.33s/it, Delta=0.821, Loss=0.1045]


üèÅ Epoch 10 termin√© | Loss: 0.0898 | Pr√©cision (Delta): 0.7977
üéâ Termin√© ! Mod√®le sauvegard√© dans : ./resultats_projet\modele_final_lora
