# Testing the combination model using a CNN for the fusion of the pretrained depth maps

In [1]:
import sys, os
import torch, wandb
import torch.nn as nn
from torch.utils.data import DataLoader
sys.path.append(os.path.abspath(os.path.join(os.curdir, '..')))
from configs import combination_model_config as config
from models.unet_convnextv2 import Unet
from models.fusion_models import CNNFusionModel
from models.combination_model import CombinedModel
from datasets.combination_depth_dataset import CombDepthDataset
from utils.train_utils import train_model

  from .autonotebook import tqdm as notebook_tqdm
  check_for_updates()


In [2]:
# Seed PyTorch's global random number generator with the value from the config module.
torch.manual_seed(config.random_seed)

<torch._C.Generator at 0x7f815bfcacf0>

In [3]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import torchvision.transforms as T

# Resolve the on-disk locations of the training split from the config paths.
train_images_dir = os.path.join(config.dataset_path, 'train/train')
train_depths_dir = os.path.join(config.depth_maps_path, 'train')
train_list_file = os.path.join(config.dataset_path, 'train_list.txt')

# Training dataset built from the paths above and the transforms defined in
# the config; no uncertainty directory is supplied.
dataset = CombDepthDataset(
    data_dir=train_images_dir,
    depths_dir=train_depths_dir,
    list_file=train_list_file,
    transform=config.padded_transform,
    target_transform=config.target_transform,
    has_gt=True,
    depth_model_names=config.depth_model_names,
    uncertainty_dir=None,
    use_uncertainty=None,
)

# Shuffled loader with a batch size of 2.
dataloader = DataLoader(dataset, shuffle=True, batch_size=2)

# Build the networks. The fusion CNN gets one input channel per entry in
# dataset.depth_model_names. Construction order (fusion first, then U-Net)
# is kept as-is because it determines how the seeded RNG is consumed.
num_depth_inputs = len(dataset.depth_model_names)
fusion_model = CNNFusionModel(input_channels=num_depth_inputs)
unet_model = Unet(uncertainty_included=False, features_included=True)

# Wrap both networks in the combined model, with uncertainty turned off.
model = CombinedModel(
    fusion_model=fusion_model,
    unet_model=unet_model,
    use_uncertainty=False,
)

# Select the compute device.
# The default CUDA device did not have enough free memory on the author's
# machine, so GPU index 3 is preferred. 'gpu:3' is not a valid PyTorch device
# string (the CUDA device type is spelled 'cuda') and raises a RuntimeError,
# so the device is built as 'cuda:<index>'. If that index does not exist we
# fall back to the default CUDA device, and to the CPU when CUDA is absent.
GPU_INDEX = 3
if torch.cuda.is_available():
    if torch.cuda.device_count() > GPU_INDEX:
        device = torch.device(f'cuda:{GPU_INDEX}')
    else:
        device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Release cached, unused CUDA memory before the model is moved onto the GPU.
if device.type == 'cuda':
    torch.cuda.empty_cache()
# Place the model's parameters on the selected device.
model = model.to(device)

# Placeholder objective and optimizer, used only to exercise the backward pass.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Smoke test: one forward/backward/optimizer step on the first batch only.
model.train()
batch = next(iter(dataloader), None)
if batch is not None:
    rgb, depth_stack, gt_depth, filenames, uncertainty_map = batch

    # Move every tensor input to the model's device; the uncertainty map may be absent.
    rgb, depth_stack, gt_depth = (
        tensor.to(device) for tensor in (rgb, depth_stack, gt_depth)
    )
    if uncertainty_map is not None:
        uncertainty_map = uncertainty_map.to(device)

    # Forward pass, then undo the padding via the config helper before
    # comparing against the ground truth.
    output = config.unpad_to_original(
        model(rgb, depth_stack, uncertainty_map),
        config.img_size,
    )
    loss = criterion(output, gt_depth)

    # Standard update: clear old gradients, backpropagate, apply the step.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f"Loss: {loss.item():.4f}")

[DEBUG] idx=21744, stacked_depths=[torch.Size([448, 576]), torch.Size([448, 576]), torch.Size([448, 576]), torch.Size([448, 576])]
[DEBUG] idx=9954, stacked_depths=[torch.Size([448, 576]), torch.Size([448, 576]), torch.Size([448, 576]), torch.Size([448, 576])]
Loss: 4.8876
