In [1]:
import torch
import torchvision
import os
from unet import UNet
from data import Dataset
import cv2
from PIL import Image
from torchvision.transforms import v2
import wandb

# Hyperparameters shared by the cells below.
LR = 0.001       # learning rate handed to the AdamW optimizer
EPOCHS = 10      # number of passes over the training dataloader
BATCH_SIZE = 32  # batch size for the training and validation dataloaders

# Seed torch's global RNG before the model is built so that a
# Restart Kernel -> Run All starts from the same random state every time.
SEED = 0
torch.manual_seed(SEED)

In [2]:
# Image-to-image U-Net.
#   dim=64              -> base dimension; it gets multiplied by dim_mults
#   with_time_emb=False -> no time embedding (only relevant for diffusion, not for i2i)
model = UNet(dim=64, with_time_emb=False)

In [3]:
# Folder that holds the vkitti_2.0.3_* directories. Set the VKITTI_ROOT
# environment variable to point somewhere else; the fallback is the original
# local path, so behaviour on the original machine is unchanged.
DATA_ROOT = os.environ.get(
    'VKITTI_ROOT',
    r'D:\scalable_analytics_institute\semantic_depth_transformer\data\virtual_kitti',
)

depth = os.path.join(DATA_ROOT, 'vkitti_2.0.3_depth')
rgb = os.path.join(DATA_ROOT, 'vkitti_2.0.3_rgb')
semantic = os.path.join(DATA_ROOT, 'vkitti_2.0.3_classSegmentation')  # defined but not used in the cells shown here

rgb_depth_dataset = Dataset(rgb, depth)
# Shuffle the training set so each epoch sees the samples in a different order
# instead of always walking the dataset front to back.
dataloader = torch.utils.data.DataLoader(
    rgb_depth_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

In [4]:
# TODO: move the hyperparameters (LR, EPOCHS, BATCH_SIZE) into their own cell

In [None]:
def save(model, optimizer, path):
    """Write the model and optimizer state dicts to ``path + 'model.pt'``.

    Args:
        model: object whose ``state_dict()`` is stored under the key ``'model'``.
        optimizer: object whose ``state_dict()`` is stored under ``'optimizer'``.
        path: prefix that ``'model.pt'`` is appended to verbatim, so a directory
            prefix has to end with a path separator (e.g. ``'./checkpoints/'``).

    The file name is fixed, so every call with the same ``path`` overwrites the
    previous checkpoint. The parent directory is created when it is missing.
    """
    target = path + 'model.pt'
    parent = os.path.dirname(target)
    if parent:
        # torch.save does not create missing directories; without this the
        # first checkpoint fails when the folder has not been made by hand.
        os.makedirs(parent, exist_ok=True)
    torch.save({
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict()
    }, target)

def load(model, optimizer, path, map_location=None):
    """Restore model and optimizer state from a checkpoint file.

    Args:
        model: object whose ``load_state_dict`` receives ``checkpoint['model']``.
        optimizer: object whose ``load_state_dict`` receives
            ``checkpoint['optimizer']``.
        path: the checkpoint file, or a directory that contains ``model.pt``
            (the file name ``save`` appends), so the directory given to
            ``save`` can be passed here unchanged.
        map_location: forwarded to ``torch.load``; pass e.g. ``'cpu'`` to open a
            checkpoint that was written from a GPU on a machine without one.
            ``None`` keeps ``torch.load``'s default behaviour.

    Returns:
        The same ``(model, optimizer)`` objects that were passed in.
    """
    if os.path.isdir(path):
        # Accept the directory form that save() takes.
        path = os.path.join(path, 'model.pt')
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return model, optimizer


In [5]:
# Checkpointing: the training loop calls save() every SAVE_EVERY epochs,
# writing under SAVE_PATH.
SAVE_PATH = './checkpoints/'
SAVE_EVERY = 5

# Optimisation: AdamW over all model parameters, scored with an L1 loss.
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
loss_fn = torch.nn.L1Loss()

In [None]:
# training
# Use the GPU when one is present, otherwise fall back to CPU so the cell
# still runs on a machine without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)
model.train()

# save() writes under SAVE_PATH; create the folder up front so the first
# checkpoint cannot fail after several epochs of work.
os.makedirs(SAVE_PATH, exist_ok=True)

wandb.init(project='unet_depth', name='training_run')

for epoch in range(1, EPOCHS+1):
    # The batch names differ from the module-level `rgb` / `depth` path strings
    # on purpose, so running this cell does not overwrite those paths.
    for rgb_batch, depth_batch in dataloader:
        rgb_batch, depth_batch = rgb_batch.to(device), depth_batch.to(device)
        optimizer.zero_grad()
        pred_depth = model(rgb_batch.float())
        # (input, target) order, matching the validation cell.
        loss = loss_fn(pred_depth, depth_batch)
        loss.backward()
        optimizer.step()
        # Record the per-batch loss so the wandb run actually contains a curve.
        wandb.log({'loss': loss.item(), 'epoch': epoch})
    if epoch % SAVE_EVERY == 0:
        save(model, optimizer, SAVE_PATH)

wandb.finish()


In [None]:
# validation
# Chosen here as well so this cell does not rely on the training cell having
# been run first in the current kernel.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.eval()
model.to(device)

val_dataloader = torch.utils.data.DataLoader(
    # NOTE(review): the training set is built as Dataset(rgb, depth); this
    # no-argument call presumably needs the validation directories - TODO supply them.
    Dataset(), # using validation data
    batch_size = BATCH_SIZE,
    shuffle = False  # sample order is irrelevant for evaluation; keep it deterministic
)

wandb.init(project='unet_depth', name='validation_run')

with torch.no_grad():
    # Evaluate on val_dataloader, not on the training `dataloader`.
    for video, target in val_dataloader:
        video, target = video.to(device), target.to(device)
        # Same float cast the training loop applies before the forward pass.
        prediction = model(video.float())
        loss = loss_fn(prediction, target)
        wandb.log({
            'loss': loss.item()
        }, commit = True)

wandb.finish()