In [25]:
import sys
import os
import glob
import shutil

from matplotlib import pyplot as plt
import matplotlib
from tqdm import tqdm
import numpy as np
import cv2 as cv2

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn.functional as F
from torchvision.transforms import Compose
from torch.utils.tensorboard import SummaryWriter



sys.path.append('C:\\Users\\LuCo\\Documents\\repos\\Depth-Anything-V2')
from depth_anything_v2.dpt import DepthAnythingV2
from metric_depth.dataset.transform import Resize, NormalizeImage, PrepareForNet, Crop


In [26]:
# Global Parameters
# ------------------------------

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DATA_DIR = os.path.join(os.getcwd(), 'data', "train", "train")
CHECKPOINT_DIR = os.path.join(os.getcwd(), 'checkpoints')
BEST_CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, 'best_depth_anything_v2.pth')
BACKUP_CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, 'depth_anything_v2.pth')
ENCODER = 'vitl'
NUM_EPOCHS = 10
LOG_DIR = os.path.join(os.getcwd(), 'logs')
RESULT_DIR = os.path.join(os.getcwd(), 'results')

In [None]:
class DepthDataset(Dataset):
    def __init__(self, root, mode='train', size = (518, 518), stack_size = 10):
        self.root = root
        self.mode = mode
        self.size = size
        self.stack_size = stack_size

        self.rgb_paths = sorted(glob.glob(os.path.join(root, '*_rgb.png')))
        self.depth_paths = [p.replace('_rgb.png', '_depth.npy') for p in self.rgb_paths]
        self.focal_stack_paths = [p.replace('_rgb.png', '_focal_stack_*.png') for p in self.rgb_paths]


        net_w, net_h = size
        self.transform = Compose([
            Resize(width = net_w, height = net_h, 
                   resize_target = True if mode == 'train' else False, 
                   keep_aspect_ratio = True, 
                   ensure_multiple_of = 14, 
                   resize_method = "lower_bound",
                   image_interpolation_method = cv2.INTER_CUBIC),
            NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            PrepareForNet(),
        ] + ([Crop(size[0])] if mode == 'train' else []))

    def __len__(self):
        return len(self.rgb_paths)
    
    def __getitem__(self, idx):
        rgb = cv2.imread(self.rgb_paths[idx])
        rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB) / 255.0

        depth = np.load(self.depth_paths[idx])


        # focal stack
        sample = self.transform({'image': rgb, 'depth': depth})

        sample['image'] = torch.from_numpy(sample['image'])
        sample['depth'] = torch.from_numpy(sample['depth'])

        sample['valid_mask'] = (torch.isnan(sample['depth']) == 0)
        sample['depth'][sample['valid_mask'] == 0] = 0

        return sample

In [28]:
class ScaleInvariantRMSELoss(nn.Module):
    def forward(self, pred, target):
        pred = torch.clamp(pred, min=1e-6)
        target = torch.clamp(target, min=1e-6)
        diff = torch.log(pred) - torch.log(target)
        alpha = torch.mean(diff)
        return torch.sqrt(torch.mean((diff - alpha)**2))

In [29]:
if not os.path.exists(LOG_DIR):
    os.makedirs(LOG_DIR)
writer = SummaryWriter(log_dir=LOG_DIR)

In [None]:
depthDataset = DepthDataset(DATA_DIR, mode='train', size=(518, 518), stack_size=10)
print(depthDataset.focal_stack_paths)

In [None]:
def train():
    dataset = DepthDataset(DATA_DIR, mode='train')
    val_size = int(len(dataset) * 0.1)
    train_size = len(dataset) - val_size

    print(dataset.root)
    print(len(dataset), train_size, val_size)
    train_set, val_set = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_set, batch_size=5, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_set, batch_size=5, shuffle=False, num_workers=0)

    model_configs = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
    'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
    }

    # Load the model with the specified encoder
    model = DepthAnythingV2(**model_configs[ENCODER])

    for param in model.pretrained.parameters():
        param.requires_grad = False

    for param in model.depth_head.parameters():
        param.requires_grad = True


    model.load_state_dict(torch.load(BEST_CHECKPOINT_PATH, map_location=DEVICE), strict=False)
    
    model = model.to(DEVICE)
    loss_fn = ScaleInvariantRMSELoss()
    optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-6, weight_decay=1e-4)

    best_val_loss = 0.113

    
    patience = 5
    patience_counter = 0

    for epoch in range(NUM_EPOCHS):
        print(f"Epoch {epoch+1}/{NUM_EPOCHS}")
        model.train()
        train_loss = 0.0

        for i, sample in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1} [Train]")):

            
            rgb = sample['image']
            depth = sample['depth']

            rgb = rgb.to(DEVICE)
            depth = depth.to(DEVICE)

            # print(f"Model device: {next(model.parameters()).device}")
            # print(f"RGB device: {rgb.device}, Depth device: {depth.device}")

            optimizer.zero_grad()
            pred = model(rgb)
            loss = loss_fn(pred, depth)
            loss.backward()
            optimizer.step()

            # logging the loss
            if i % 50 == 0:
                writer.add_scalar('Loss/train', loss.item(), epoch * len(train_loader) + i)
            train_loss += loss.item()

        train_loss /= len(train_loader)
        print(f"Train Loss: {train_loss:.4f}")
        writer.add_scalar('Loss/Train_Epoch', train_loss, epoch)

        #Validation
        model.eval()
        val_loss = 0.0

        with torch.no_grad():
            for sample in tqdm(val_loader, desc=f"Epoch {epoch+1} [Validation]"):
                rgb = sample['image']
                depth = sample['depth']
                rgb = rgb.to(DEVICE)
                depth = depth.to(DEVICE)

                pred = model(rgb)
                loss = loss_fn(pred, depth)
                val_loss += loss.item()
        val_loss /= len(val_loader)
        print(f"Validation Loss: {val_loss:.4f}")
        writer.add_scalar('Loss/Validation_Epoch', val_loss, epoch)

        # Save the model if validation loss improves
        if not os.path.exists(CHECKPOINT_DIR):
            os.makedirs(CHECKPOINT_DIR)
        
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            # Move the previous best checkpoint to a backup location
            if os.path.exists(BEST_CHECKPOINT_PATH):
                shutil.move(BEST_CHECKPOINT_PATH, BACKUP_CHECKPOINT_PATH)

            torch.save(model.state_dict(), BEST_CHECKPOINT_PATH)
            print(f"Model saved to {BEST_CHECKPOINT_PATH}")
            print(f"Previous best model moved to {BACKUP_CHECKPOINT_PATH}")
        else:
            patience_counter += 1
            print(f"Patience counter: {patience_counter}/{patience}")
        
        # Save a sample image and depth map for visualization

        # Load a sample from the training set
        rgb = cv2.imread(os.path.join(DATA_DIR, "sample_000000_rgb.png"))
        gt = np.load(os.path.join(DATA_DIR, "sample_000000_depth.npy"))

        # set the model to evaluation mode
        with torch.no_grad():
            pred = model.infer_image(rgb)
        d_min = np.min(pred)
        d_max = np.max(pred)

        depth_vis = (pred - d_min) / (d_max - d_min + 1e-6)
        
        cmap = matplotlib.colormaps.get_cmap('plasma')
        
        plt.figure(figsize=(15, 5))

        plt.subplot(1, 3, 1)
        plt.title('Input Image')
        plt.axis('off')
        plt.imshow(rgb)

        plt.subplot(1, 3, 2)
        plt.title('Predicted Depth Map')
        plt.axis('off')



        plt.imshow(depth_vis, cmap=cmap)

        plt.subplot(1, 3, 3)
        plt.title('Ground Truth')
        plt.axis('off')
        plt.imshow(gt, cmap=cmap)
        plt.savefig(os.path.join(RESULT_DIR, f"epoch_{epoch+1}_sample.png"))
        plt.close()

        if patience_counter >= patience:
            print("Early stopping triggered.")
            break



c:\Users\LuCo\Documents\repos\DeepAnything\DepthAnything\data\train\train
23971 21574 2397
Epoch 1/10


Epoch 1 [Train]: 100%|██████████| 4315/4315 [1:02:01<00:00,  1.16it/s]


Train Loss: 0.1126


Epoch 1 [Validation]: 100%|██████████| 480/480 [05:42<00:00,  1.40it/s]


Validation Loss: 0.1113
Model saved to c:\Users\LuCo\Documents\repos\DeepAnything\DepthAnything\checkpoints\best_depth_anything_v2.pth
Previous best model moved to c:\Users\LuCo\Documents\repos\DeepAnything\DepthAnything\checkpoints\depth_anything_v2.pth
Epoch 2/10


Epoch 2 [Train]: 100%|██████████| 4315/4315 [1:01:01<00:00,  1.18it/s]


Train Loss: 0.1112


Epoch 2 [Validation]: 100%|██████████| 480/480 [05:33<00:00,  1.44it/s]


Validation Loss: 0.1113
Patience counter: 1/5
Epoch 3/10


Epoch 3 [Train]: 100%|██████████| 4315/4315 [1:00:41<00:00,  1.19it/s]


Train Loss: 0.1101


Epoch 3 [Validation]: 100%|██████████| 480/480 [05:34<00:00,  1.44it/s]


Validation Loss: 0.1112
Model saved to c:\Users\LuCo\Documents\repos\DeepAnything\DepthAnything\checkpoints\best_depth_anything_v2.pth
Previous best model moved to c:\Users\LuCo\Documents\repos\DeepAnything\DepthAnything\checkpoints\depth_anything_v2.pth
Epoch 4/10


Epoch 4 [Train]:  10%|▉         | 431/4315 [05:59<53:56,  1.20it/s]  


KeyboardInterrupt: 