In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt

In [None]:
from pathlib import Path
# Root folder of the dataset, given relative to the notebook's working directory.
data_path = Path('ReDWeb_V1/')

In [None]:
# Display the top-level entries of the dataset folder (later cells read its 'Imgs' and 'RDs' subfolders).
os.listdir(data_path)

## Creating the dataset

In [None]:
from torch.utils.data import Dataset
import cv2

class ReDWebDataset(Dataset):
    """Map-style dataset that pairs each RGB image with its depth map.

    Args:
        rgb_paths: Sequence of paths (``str`` or ``pathlib.Path``) to the RGB images.
        depth_paths: Sequence of paths to the depth maps; entry ``i`` belongs
            to ``rgb_paths[i]``.
        transform: Optional callable ``transform(img, depth) -> (img, depth)``
            applied jointly to every sample.

    Raises:
        ValueError: If the two path sequences have different lengths.
    """

    def __init__(self, rgb_paths, depth_paths, transform=None):
        if len(rgb_paths) != len(depth_paths):
            raise ValueError(
                f"rgb_paths and depth_paths must have the same length "
                f"(got {len(rgb_paths)} and {len(depth_paths)})"
            )
        self.rgb_paths = rgb_paths
        self.depth_paths = depth_paths
        self.transform = transform

    def __len__(self):
        """Return the number of image/depth pairs."""
        return len(self.rgb_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` from disk and apply the transform, if any.

        Without a transform the image is returned exactly as OpenCV decodes it
        (BGR channel order) and the depth map is read with
        ``cv2.IMREAD_UNCHANGED``.

        Raises:
            FileNotFoundError: If OpenCV cannot read either file.
        """
        # Convert to str: not every OpenCV build accepts pathlib.Path objects.
        rgb_path = str(self.rgb_paths[idx])
        depth_path = str(self.depth_paths[idx])

        img = cv2.imread(rgb_path)
        depth = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)

        # cv2.imread signals failure by returning None instead of raising, so
        # fail here with the offending path rather than deep inside a transform.
        if img is None:
            raise FileNotFoundError(f"Could not read RGB image: {rgb_path}")
        if depth is None:
            raise FileNotFoundError(f"Could not read depth map: {depth_path}")

        if self.transform is not None:
            img, depth = self.transform(img, depth)
        return img, depth


In [None]:
# os.listdir returns entries in arbitrary, platform-dependent order. Sorting
# makes the index -> file mapping stable, so the seeded train/val/test split
# below selects the same files on every run and every machine.
filenames = sorted(os.listdir(data_path / 'Imgs'))

images_filenames = [data_path / 'Imgs' / filename for filename in filenames]
# Each depth map lives in 'RDs' under the image's base name with a .png
# extension. Path(...).stem strips only the final extension, so names that
# contain additional dots are not truncated.
depths_filenames = [data_path / 'RDs' / f"{Path(filename).stem}.png" for filename in filenames]

In [None]:
# Dataset over every image/depth pair with no transform attached; its length drives the index split below.
full_dataset = ReDWebDataset(images_filenames, depths_filenames)

## Test/Train/Validation split

- Train: 70% (2520 images)
- Test: 15% (540 images)
- Validation: 15% (540 images)

In [None]:
from sklearn.model_selection import train_test_split

# Split sample indices in two stages: first hold out 30% of the data,
# then cut that held-out part in half to get validation and test sets.
indices = [*range(len(full_dataset))]

# Stage 1: 70% train / 30% held out.
train_indices, temp_indices = train_test_split(
    indices,
    test_size=0.30,
    shuffle=True,
    random_state=42,
)

# Stage 2: the held-out 30% becomes 15% validation + 15% test.
val_indices, test_indices = train_test_split(
    temp_indices,
    test_size=0.50,
    shuffle=True,
    random_state=42,
)


In [None]:
def _take(paths, selected):
    """Return the entries of ``paths`` at the positions listed in ``selected``."""
    return [paths[pos] for pos in selected]


# Image and depth lists are indexed with the same positions so pairs stay aligned.
train_imgs = _take(images_filenames, train_indices)
train_depth = _take(depths_filenames, train_indices)

val_imgs = _take(images_filenames, val_indices)
val_depth = _take(depths_filenames, val_indices)

test_imgs = _take(images_filenames, test_indices)
test_depth = _take(depths_filenames, test_indices)

## Augmentation and preprocessing functions


In [None]:
import random
import cv2
import numpy as np
import torch

def train_transform(img, depth):
    """Randomly augment an image/depth pair and convert both to tensors.

    Steps, in order: BGR -> RGB conversion, optional brightness jitter (50%),
    optional horizontal flip (50%), random crop keeping 80-100% of each side,
    resize to 384x384, and scaling to floats by dividing by 255.

    Args:
        img: Image array in BGR channel order, as returned by ``cv2.imread``
            (assumed to be 8-bit -- TODO confirm for this dataset).
        depth: Depth map array with the same height and width as ``img``.

    Returns:
        Tuple ``(img, depth)``: ``img`` is a float tensor of shape
        ``(3, 384, 384)`` in RGB order; ``depth`` is a float tensor holding the
        resized depth map divided by 255 (``(384, 384)`` for a single-channel
        depth map -- TODO confirm).
    """
    # Convert unconditionally. Doing this only inside the brightness branch
    # would leave roughly half of the samples in BGR and the rest in RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # 50% chance of brightness jitter on the HSV value channel
    if random.random() < 0.5:
        hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
        # Scale in float and clip before casting back: writing the float
        # product straight into the uint8 channel lets values above 255
        # wrap around, turning bright pixels dark.
        value = hsv[:, :, 2].astype(np.float32) * random.uniform(0.9, 1.1)
        hsv[:, :, 2] = np.clip(value, 0, 255).astype(hsv.dtype)
        img = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)

    # 50% chance of horizontal flip, applied to both arrays so they stay aligned
    if random.random() < 0.5:
        img = cv2.flip(img, 1)
        depth = cv2.flip(depth, 1)

    # Random crop that keeps 80-100% of each side; same window for both arrays
    h, w = img.shape[:2]
    scale = random.uniform(0.8, 1.0)
    new_h, new_w = int(scale * h), int(scale * w)
    y = random.randint(0, h - new_h)
    x = random.randint(0, w - new_w)

    img = img[y:y + new_h, x:x + new_w]
    depth = depth[y:y + new_h, x:x + new_w]

    # Resize into 384x384
    img = cv2.resize(img, (384, 384))
    depth = cv2.resize(depth, (384, 384))

    # HWC uint8 -> CHW float, scaled by 1/255
    img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0
    depth = torch.from_numpy(depth).float() / 255.0

    return img, depth


def val_transform(img, depth):
    """Deterministic preprocessing for validation and test samples.

    Converts the image from BGR to RGB, resizes both arrays to 384x384 and
    scales them to floats by dividing by 255. No random augmentation is applied.

    Args:
        img: Image array in BGR channel order, as returned by ``cv2.imread``.
        depth: Depth map array matching ``img``.

    Returns:
        Tuple ``(img, depth)`` of float tensors; ``img`` has shape
        ``(3, 384, 384)`` in RGB order.
    """
    # cv2.imread yields BGR. Convert so evaluation images use the RGB channel
    # order that the training transform produces and that matplotlib expects.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    img = cv2.resize(img, (384, 384))
    depth = cv2.resize(depth, (384, 384))

    img_tensor = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0
    depth_tensor = torch.from_numpy(depth).float() / 255.0
    return img_tensor, depth_tensor

In [None]:
# Only the training split is augmented; validation and test share the
# deterministic transform.
train_dataset = ReDWebDataset(
    rgb_paths=train_imgs, depth_paths=train_depth, transform=train_transform
)
val_dataset = ReDWebDataset(
    rgb_paths=val_imgs, depth_paths=val_depth, transform=val_transform
)
test_dataset = ReDWebDataset(
    rgb_paths=test_imgs, depth_paths=test_depth, transform=val_transform
)

In [None]:
import matplotlib.pyplot as plt

def visualize_samples(dataset, num_samples=4):
    """Plot randomly drawn samples: images on the top row, depth maps below.

    Indices are drawn independently with ``random.randint``, so the same
    sample can appear more than once.

    Args:
        dataset: Indexable dataset whose items are ``(img, depth)`` pairs,
            either as tensors or as NumPy arrays.
        num_samples: Number of columns (samples) to show.
    """
    # squeeze=False keeps `axes` two-dimensional even when num_samples == 1;
    # with the default, subplots returns a 1-D array and axes[0, i] fails.
    fig, axes = plt.subplots(
        2, num_samples, figsize=(num_samples * 3, 6), squeeze=False
    )

    for i in range(num_samples):
        idx = random.randint(0, len(dataset) - 1)
        img, depth = dataset[idx]  # get sample

        # Convert tensor -> numpy for plotting
        if torch.is_tensor(img):
            img_np = img.permute(1, 2, 0).cpu().numpy()  # CHW -> HWC
        else:
            img_np = img[:, :, ::-1]  # BGR -> RGB if still numpy

        if torch.is_tensor(depth):
            depth_np = depth.squeeze().cpu().numpy()
        else:
            depth_np = depth

        # Show RGB
        axes[0, i].imshow(img_np, cmap=None)
        axes[0, i].set_title("RGB image")
        axes[0, i].axis('off')

        # Show depth
        axes[1, i].imshow(depth_np, cmap="gray")
        axes[1, i].set_title("Relative Depth")
        axes[1, i].axis('off')

    plt.tight_layout()
    plt.show()

visualize_samples(train_dataset, num_samples=4)

## Training

In [None]:
import depth_perception_model

# Instantiate the network so the next cell can print its layer summary;
# a new instance is created again later, before the optimizer is built.
model = depth_perception_model.DepthEstimationModel()

In [None]:
from torchsummary import summary
# torchsummary builds its dummy input on CUDA by default whenever a GPU is
# available, which fails if the model's weights are still on the CPU. Pass the
# device the model actually lives on so the summary works in both cases.
summary_device = next(model.parameters()).device.type
summary(model, (3, 384, 384), device=summary_device)

## Data loaders

In [None]:
from torch.utils.data import DataLoader

# Settings shared by the training and validation loaders.
_loader_kwargs = dict(batch_size=4, num_workers=2, pin_memory=True)

# Only the training data is reshuffled every epoch.
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

# The test loader uses a larger batch and the default worker/pinning settings.
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)


In [None]:
import loss_function

# Fresh model instance for the training run (replaces the one created for the summary cell).
model = depth_perception_model.DepthEstimationModel()

# Loss function from the local loss_function module; the training loop calls it as criterion(pred, depth).
criterion = loss_function.ranking_loss
# Adam over all model parameters with a learning rate of 1e-4.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [None]:
from tqdm import tqdm

num_epochs = 15

# Run every batch on the device that holds the model's weights, so the loop
# works unchanged whether the model is on the CPU or has been moved to a GPU.
device = next(model.parameters()).device

for epoch in range(num_epochs):
    # --- TRAINING ---
    model.train()
    train_loss = 0.0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

    for img, depth in pbar:
        # non_blocking only has an effect for pinned host memory -> GPU copies
        img = img.to(device, non_blocking=True)
        depth = depth.to(device, non_blocking=True)

        optimizer.zero_grad()
        pred = model(img)              # forward pass
        loss = criterion(pred, depth)  # compute loss
        loss.backward()                # backprop
        optimizer.step()

        batch_loss = loss.item()
        train_loss += batch_loss
        pbar.set_postfix(loss=f"{batch_loss:.4f}")  # live per-batch loss

    # max(..., 1) avoids a ZeroDivisionError if a loader is empty
    train_loss /= max(len(train_loader), 1)

    # --- VALIDATION ---
    model.eval()
    val_loss = 0.0

    with torch.no_grad():
        for img, depth in val_loader:
            img = img.to(device, non_blocking=True)
            depth = depth.to(device, non_blocking=True)
            pred = model(img)
            val_loss += criterion(pred, depth).item()

    val_loss /= max(len(val_loader), 1)
    print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")


In [None]:
# To save model with optimizer
# 'loss' stores the last validation loss as a plain number. Storing the
# criterion function itself would pickle a reference to project code, which
# makes the checkpoint unloadable without that module and is rejected by
# torch.load when it runs in weights-only mode.
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': val_loss,
}, "checkpoint.pth")

In [None]:
# To load the model
# map_location remaps the stored tensors onto the device the model currently
# lives on, so a checkpoint written on a GPU can also be restored on a
# CPU-only machine.
checkpoint = torch.load(
    "checkpoint.pth", map_location=next(model.parameters()).device
)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# Bookkeeping values stored alongside the weights
epoch = checkpoint['epoch']
loss = checkpoint['loss']

In [None]:
# Qualitative check: plot image, ground-truth depth and prediction side by
# side for the first few samples of the first test batch.
model.eval()  # make sure layers such as dropout/batch-norm run in inference mode

# Use the device the model's weights are on instead of assuming CUDA.
device = next(model.parameters()).device
max_shown = 4

for img, depth in test_loader:
    with torch.no_grad():
        pred = model(img.to(device)).cpu()  # expected shape (B,1,H,W)

    # The batch may hold fewer than max_shown samples
    for i in range(min(max_shown, img.size(0))):
        image_np = img[i].permute(1, 2, 0).cpu().numpy()  # (H,W,3)
        depth_np = depth[i].squeeze().cpu().numpy()       # (H,W)
        pred_np = pred[i].squeeze().numpy()               # (H,W)

        fig, axes = plt.subplots(1, 3, figsize=(10, 3))

        axes[0].set_title("Image")
        axes[0].imshow(image_np)
        axes[0].axis("off")

        axes[1].set_title("Ground Truth Depth")
        axes[1].imshow(depth_np, cmap='inferno')
        axes[1].axis("off")

        axes[2].set_title("Prediction")
        axes[2].imshow(pred_np, cmap='inferno')
        axes[2].axis("off")

        plt.show()

    break  # only the first batch is visualised