In [7]:
# Importation des librairies

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import cv2
import os
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import mean_absolute_error, mean_squared_error
from depth_anything_v2.dpt import DepthAnythingV2

xFormers not available
xFormers not available


In [18]:
# 1. Chargement des données

# Transformation des images RGB et des nuages de points
def transform_image_and_point_cloud(image, point_cloud, target_height=1200, target_width=1944):
    """
    Redimensionne l'image et le nuage de points pour qu'ils aient des dimensions multiples de 14
    tout en maintenant le ratio d'aspect.
    """
    # Redimensionner l'image RGB
    image_resized = resize_to_multiple_of_14(image, target_height, target_width)
    
    # Redimensionner le nuage de points de manière similaire
    point_cloud_resized = cv2.resize(point_cloud, (image_resized.shape[1], image_resized.shape[0]), interpolation=cv2.INTER_LINEAR)
    
    return image_resized, point_cloud_resized

class DepthDataset(Dataset):
    def __init__(self, rgb_dir, point_cloud_dir, transform=None):
        self.rgb_dir = rgb_dir
        self.point_cloud_dir = point_cloud_dir
        self.rgb_files = sorted(os.listdir(rgb_dir))  # Liste des fichiers PNG
        self.point_cloud_files = sorted(os.listdir(point_cloud_dir))  # Liste des fichiers NPY
        self.transform = transform

    def __len__(self):
        return len(self.rgb_files)

    def __getitem__(self, idx):
        # Chargement de l'image RGB
        rgb_image = cv2.imread(os.path.join(self.rgb_dir, self.rgb_files[idx]))
        rgb_image = cv2.cvtColor(rgb_image, cv2.COLOR_BGR2RGB)  # Conversion BGR -> RGB

        # Chargement du nuage de points
        point_cloud = np.load(os.path.join(self.point_cloud_dir, self.point_cloud_files[idx]))

        # Redimensionner les images et les nuages de points
        rgb_image, point_cloud = resize_image_and_point_cloud(rgb_image, point_cloud)

        # Optionnel : Transformer les images et nuages de points (normalisation ou autres)
        if self.transform:
            rgb_image = self.transform(rgb_image)
            point_cloud = self.transform(point_cloud)

        return torch.tensor(rgb_image, dtype=torch.float32), torch.tensor(point_cloud, dtype=torch.float32)

In [19]:
# 2. Définir LoRA

class LoRA(nn.Module):
    def __init__(self, rank, input_dim, output_dim):
        super(LoRA, self).__init__()
        self.rank = rank
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.A = nn.Parameter(torch.randn(input_dim, rank))
        self.B = nn.Parameter(torch.randn(rank, output_dim))

    def forward(self, x):
        return x + torch.matmul(torch.matmul(x, self.A), self.B)

In [20]:
# 3. Modèle Depth Anything + Intégration de LoRA

class DepthAnythingWithLoRA(nn.Module):
    def __init__(self, base_model, rank=8):
        super(DepthAnythingWithLoRA, self).__init__()
        self.base_model = base_model  # Modèle pré-entraîné Depth Anything (tu devras charger ce modèle)
        self.lora_layers = nn.ModuleList([
            LoRA(rank=rank, input_dim=512, output_dim=512),  # Exemple de taille
            # Ajoute d'autres couches LoRA si nécessaire pour les couches du modèle
        ])

    def forward(self, x):
        x = self.base_model(x)  # Passage à travers le modèle Depth Anything
        for layer in self.lora_layers:
            x = layer(x)  # Application de LoRA
        return x

In [21]:
# 4. Entraînement et fine-tuning

def train(model, dataloader, criterion, optimizer, num_epochs=10):
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        for images_rgb, point_clouds in dataloader:
            images_rgb = images_rgb.permute(0, 3, 1, 2)  # Convertir en (batch, 3, H, W)
            optimizer.zero_grad()

            # Prédiction de la profondeur
            predicted_depth = model(images_rgb)

            # Calcul de la perte
            loss = criterion(predicted_depth, point_clouds)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(dataloader)}")

In [22]:
# 5. Évaluation du modèle

def evaluate(model, dataloader):
    model.eval()
    all_predicted = []
    all_true = []
    with torch.no_grad():
        for images_rgb, point_clouds in dataloader:
            images_rgb = images_rgb.permute(0, 3, 1, 2)
            predicted_depth = model(images_rgb)

            all_predicted.append(predicted_depth.cpu().numpy())
            all_true.append(point_clouds.cpu().numpy())

    all_predicted = np.concatenate(all_predicted, axis=0)
    all_true = np.concatenate(all_true, axis=0)

    # Calcul de l'erreur MAE et RMSE
    mae = mean_absolute_error(all_true, all_predicted)
    rmse = np.sqrt(mean_squared_error(all_true, all_predicted))
    print(f"MAE: {mae}, RMSE: {rmse}")

In [23]:
# Code principal

# Paramètres
rgb_dir = 'dataset/images'
point_cloud_dir = 'dataset/depth'
batch_size = 8
learning_rate = 1e-4
num_epochs = 10
rank = 8  # Taille du rang pour LoRA

# 1. Préparation des données
dataset = DepthDataset(rgb_dir, point_cloud_dir)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 2. Charger le modèle pré-entraîné Depth Anything
# Remplace cette partie par le chargement réel du modèle Depth Anything
#base_model = torch.hub.load('depth-anything-v2.github.io', 'depth_anything_pretrained', pretrained=True)
base_model = DepthAnythingV2()

# 3. Ajouter LoRA au modèle
model = DepthAnythingWithLoRA(base_model, rank=rank)

# 4. Optimiseur et critère
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()  # Perte pour prédiction de la profondeur

# 5. Fine-tuning
train(model, dataloader, criterion, optimizer, num_epochs=num_epochs)

# 6. Sauvegarde du modèle fine-tuné
#torch.save(model.state_dict(), 'fine_tuned_depth_anything.pth')

# 7. Évaluation du modèle
evaluate(model, dataloader)

AssertionError: Input image height 1196 is not a multiple of patch height 14