In [1]:
import numpy as np
from sdhelper import SD
from PIL import Image
import torch
import torch.nn.functional as F
from tqdm.autonotebook import tqdm, trange
import matplotlib.pyplot as plt
from datasets import load_dataset

# Let float32 matmuls use faster reduced-precision kernels ('high' allows TF32-style math on
# GPUs that support it). Set because torch.compile emitted a performance warning without it.
torch.set_float32_matmul_precision('high')


In [None]:
# Load the NYUv2 training split from the Hugging Face Hub and show which fields a sample has.
# NOTE: trust_remote_code=True executes the dataset's loading script from the Hub; only use it
# with a repository you trust.
data = load_dataset(
    "0jl/NYUv2",
    split="train",
    trust_remote_code=True,
)
data[0].keys()

In [None]:
# Instantiate the Stable Diffusion helper from sdhelper with its default arguments.
# It is deleted again after feature extraction to free VRAM.
sd = SD()

In [None]:
# Extract intermediate UNet features for every image in the dataset.
# extract_positions selects which layer(s) to tap; `step` is forwarded to sdhelper
# (presumably the diffusion timestep used for extraction — TODO confirm); `seed` is fixed
# so the extraction is repeatable.
blocks = ['up_blocks[1]']
# Alternative layer choice:
# blocks = ['up_blocks[1]', 'up_blocks[2].resnets[2]']
repr_raw = sd.img2repr(
    [sample['image'] for sample in data],
    extract_positions=blocks,
    step=50,
    seed=42,
)
# Alternative: upsample the images 2x before extraction (needs a smaller batch size):
# repr_raw = sd.img2repr([x['image'].resize(tuple(np.array(x['image'].size)*2)) for x in data], extract_positions=blocks, step=50, seed=42, batch_size=10)

In [5]:
# Free the diffusion model to save VRAM for the probes below.
# `del` only drops this name. gc.collect() clears lingering reference cycles, and
# torch.cuda.empty_cache() hands the now-unused cached blocks back to the driver;
# otherwise PyTorch's caching allocator keeps them reserved.
import gc

del sd
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Stack the extracted features into one float32 tensor and move channels last, giving the
# (n, w, h, features) layout unpacked below. This lets a Linear layer act independently on
# every spatial position.
repr_torch = (
    torch.stack([r.concat() for r in repr_raw])
    .to(dtype=torch.float32)
    .permute(0, 2, 3, 1)
)

n, w, h, features = repr_torch.shape
w_orig, h_orig = np.array(data[0]['depth']).shape  # depth-map resolution of the first sample

# Depth maps are later average-pooled down to the feature grid, which requires the grid to
# divide the depth resolution. If it does not, halve the depth resolution until it fits inside
# the feature grid, then crop the features to that size.
# NOTE(review): cropping keeps only the top-left part of the feature map instead of
# resampling it, so features and targets may be spatially misaligned — confirm this is intended.
if (w_orig % w) or (h_orig % h):
    print('fixing width/height...')
    w_tmp, h_tmp = w_orig, h_orig
    while not (w_tmp <= w and h_tmp <= h):
        w_tmp, h_tmp = w_tmp // 2, h_tmp // 2
    print(f'{w}x{h} -> {w_tmp}x{h_tmp}')
    w, h = w_tmp, h_tmp
    repr_torch = repr_torch[:, :w, :h, :]

# Contiguous 80/20 train/validation split (no shuffling).
n_train = int(n * 0.8)
n_val = n - n_train
repr_train, repr_val = repr_torch[:n_train], repr_torch[n_train:]

print(repr_train.shape)

In [11]:
# Ground-truth depth as one (n, w_orig, h_orig) float32 tensor, plus a copy average-pooled
# down to the feature grid (w, h), so that every feature vector has exactly one target value.
# Reading only the 'depth' column avoids decoding every RGB image. Stacking into a single
# numpy array first avoids torch's very slow path for building a tensor from a list of arrays.
assert w_orig % w == 0 and h_orig % h == 0, (
    f'depth resolution {w_orig}x{h_orig} is not divisible by feature grid {w}x{h}'
)
depths_full = torch.from_numpy(np.stack([np.asarray(d, dtype=np.float32) for d in data['depth']]))
# Block-average pooling: split each spatial axis into (grid cells, pixels per cell) and average
# over the pixel axes.
depths_scaled = depths_full.reshape(n, w, w_orig // w, h, h_orig // h).mean(dim=(2, 4))
depths_train = depths_scaled[:n_train]
depths_val = depths_scaled[n_train:]


In [8]:
# define SiLog loss (https://arxiv.org/abs/1406.2283)

def silog_loss(pred, target, lambd=0.5, eps=1e-6):
    """Scale-invariant log loss (Eigen et al., 2014), pooled over all elements.

    Computes sqrt(mean(d**2) - lambd * mean(d)**2) with d = log(pred) - log(target).
    The statistics are taken over the whole batch at once, not per image as in the paper.

    Args:
        pred: predicted depths. Values below ``eps`` are clamped to ``eps``, so exact zeros
            (e.g. from a final ReLU) do not produce log(0) = -inf and a NaN loss.
        target: ground-truth depths, same shape as ``pred``. Non-positive entries are
            treated as invalid and excluded.
        lambd: weight of the scale-invariance term (0 gives plain log-RMSE).
        eps: lower bound applied to predictions before taking the log.

    Returns:
        A scalar tensor.
    """
    valid = target > 0
    diff = torch.log(pred[valid].clamp_min(eps)) - torch.log(target[valid])
    return ((diff**2).mean() - lambd * diff.mean()**2)**0.5

## Simple linear model

In [None]:
# Linear probe: map every feature vector independently to one depth value, trained with MSE.
model = torch.nn.Linear(features, 1).to('cuda')

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

losses = []
pbar = trange(500)
for i in pbar:
    idx = torch.randint(0, n_train, (128,))  # random mini-batch of 128 training images
    # named repr_batch so the builtin `repr` is not shadowed for the rest of the notebook
    repr_batch = repr_train[idx].to('cuda')
    depths = depths_train[idx].to('cuda')

    pred = model(repr_batch).squeeze(-1)
    loss = F.mse_loss(pred, depths)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # one GPU->CPU sync per step; report progress in the bar instead of printing every step
    loss_value = loss.item()
    losses.append(loss_value)
    pbar.set_postfix(loss=f'{loss_value:.4g}')

In [None]:
# Training-loss curve of the linear probe (log scale).
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(xlabel='step', ylabel='training loss (MSE)', yscale='log')
plt.show()

In [None]:
# Validation MSE of the linear probe at feature-grid resolution.
with torch.no_grad():
    pred_test = model(repr_val.to('cuda')).squeeze(-1)
    loss_test = F.mse_loss(pred_test, depths_val.to('cuda'))
print(f'val: {loss_test.item()}')

In [None]:
# Show one random validation sample: RGB input, target depth, predicted depth
# (each rotated 90 degrees clockwise for display).
idx = torch.randint(0, n_val, (1,)).item()
fig, axs = plt.subplots(1, 3, figsize=(12, 4))
panels = [
    ('input', data[n_train + idx]['image'].rotate(-90, expand=True), None),
    ('target', np.rot90(depths_val[idx].squeeze().cpu().numpy(), k=-1), 'gray_r'),
    ('pred', np.rot90(pred_test[idx].squeeze().cpu().numpy(), k=-1), 'gray_r'),
]
for ax, (title, img, cmap) in zip(axs, panels):
    ax.imshow(img, cmap=cmap)
    ax.axis('off')
    ax.set_title(title)
plt.show()


## Classification model

In [None]:
# Depth classification probe: quantise depth into num_classes levels and predict the level of
# each spatial position with a linear layer.
num_classes = 20
model = torch.nn.Linear(features, num_classes).to('cuda')

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# The binning range comes from the training split only; the validation cell reuses it.
# NOTE(review): .long() truncates, so the top class only holds values exactly equal to
# max_depth (effectively num_classes - 1 populated bins).
min_depth = depths_train.min()
max_depth = depths_train.max()
depths_train_class = ((depths_train - min_depth) / (max_depth - min_depth) * (num_classes - 1)).long()

losses = []
pbar = trange(1000)
for i in pbar:
    idx = torch.randint(0, n_train, (64,))
    # named repr_batch so the builtin `repr` is not shadowed for the rest of the notebook
    repr_batch = repr_train[idx].to('cuda')
    depths = depths_train_class[idx].flatten().to('cuda')

    pred = model(repr_batch)
    # flatten (batch, w, h) into one axis: one classification sample per spatial position
    loss = F.cross_entropy(pred.flatten(0, 2), depths)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    loss_value = loss.item()
    losses.append(loss_value)
    pbar.set_postfix(loss=f'{loss_value:.4g}')

In [None]:
# Training-loss curve of the linear classifier (log scale).
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(xlabel='step', ylabel='training loss (cross-entropy)', yscale='log')
plt.show()

In [None]:
# Validation cross-entropy of the linear classifier, binned with the training-split range.
# Validation depths can fall outside [min_depth, max_depth]. That yields class ids < 0 or
# >= num_classes, which makes cross_entropy fail (a device-side assert on CUDA), so clamp
# them into the valid range.
depths_val_class = (
    ((depths_val - min_depth) / (max_depth - min_depth) * (num_classes - 1))
    .long()
    .clamp(0, num_classes - 1)
)
with torch.no_grad():
    pred_test = model(repr_val.to('cuda'))
    loss_test = F.cross_entropy(pred_test.flatten(0, 2), depths_val_class.flatten().to('cuda'))
    print(f'val: {loss_test.item()}')

In [None]:
# Show one random validation sample: RGB input, target depth, and the predicted depth class
# (argmax over logits); each rotated 90 degrees clockwise for display.
idx = torch.randint(0, n_val, (1,)).item()
fig, axs = plt.subplots(1, 3, figsize=(12, 4))
panels = [
    ('input', data[n_train + idx]['image'].rotate(-90, expand=True), None),
    ('target', np.rot90(depths_val[idx].squeeze().cpu().numpy(), k=-1), 'gray_r'),
    ('pred', np.rot90(pred_test[idx].argmax(dim=-1).squeeze().cpu().numpy(), k=-1), 'gray_r'),
]
for ax, (title, img, cmap) in zip(axs, panels):
    ax.imshow(img, cmap=cmap)
    ax.axis('off')
    ax.set_title(title)
plt.show()


## CNN model

In [None]:
# Single 3x3 convolution: like the linear probe, but each prediction also sees neighbouring
# positions. Trained with MSE.
model = torch.nn.Conv2d(features, 1, kernel_size=3, padding=1).to('cuda')

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

losses = []
pbar = trange(1000)
for i in pbar:
    idx = torch.randint(0, n_train, (64,))
    # channels-first for Conv2d; named repr_batch so the builtin `repr` is not shadowed
    repr_batch = repr_train[idx].permute(0, 3, 1, 2).to('cuda')
    depths = depths_train[idx].to('cuda')

    pred = model(repr_batch).squeeze(1)
    loss = F.mse_loss(pred, depths)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    loss_value = loss.item()
    losses.append(loss_value)
    pbar.set_postfix(loss=f'{loss_value:.4g}')

In [None]:
# Training-loss curve of the single-conv model (log scale).
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(xlabel='step', ylabel='training loss (MSE)', yscale='log')
plt.show()

In [None]:
# Validation MSE of the single-conv model (inputs moved to channels-first for Conv2d).
with torch.no_grad():
    pred_test = model(repr_val.permute(0, 3, 1, 2).to('cuda')).squeeze(1)
    loss_test = F.mse_loss(pred_test, depths_val.to('cuda'))
print(f'val: {loss_test.item()}')

In [None]:
# Show one random validation sample: RGB input, target depth, predicted depth
# (each rotated 90 degrees clockwise for display).
idx = torch.randint(0, n_val, (1,)).item()
fig, axs = plt.subplots(1, 3, figsize=(12, 4))
panels = [
    ('input', data[n_train + idx]['image'].rotate(-90, expand=True), None),
    ('target', np.rot90(depths_val[idx].squeeze().cpu().numpy(), k=-1), 'gray_r'),
    ('pred', np.rot90(pred_test[idx].squeeze().cpu().numpy(), k=-1), 'gray_r'),
]
for ax, (title, img, cmap) in zip(axs, panels):
    ax.imshow(img, cmap=cmap)
    ax.axis('off')
    ax.set_title(title)
plt.show()

## MLP (trained with SiLog loss)

In [9]:
# Per-position MLP regressor: two hidden layers of width 512 with GELU and dropout. The final
# ReLU keeps predicted depths non-negative, as required by the log inside silog_loss.
# Compiled with torch.compile. `losses` is reset here so the next cell can be re-run to
# continue training.
model = torch.compile(
    torch.nn.Sequential(
        torch.nn.Linear(features, 512), torch.nn.GELU(), torch.nn.Dropout(p=0.2),
        torch.nn.Linear(512, 512), torch.nn.GELU(), torch.nn.Dropout(p=0.2),
        torch.nn.Linear(512, 1), torch.nn.ReLU(),
    ).to('cuda')
)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

losses = []


In [None]:
# Train the MLP with the scale-invariant log loss. Re-running this cell continues training,
# because model, optimizer and losses are created in the previous cell.
model.train()  # enable dropout; the evaluation cell switches to eval mode
pbar = trange(1000)
for i in pbar:
    idx = torch.randint(0, n_train, (128,))
    # named repr_batch so the builtin `repr` is not shadowed for the rest of the notebook
    repr_batch = repr_train[idx].to('cuda')
    depths = depths_train[idx].to('cuda')

    # small offset keeps the ReLU output strictly positive before the log in silog_loss
    pred = model(repr_batch).squeeze(-1) + 1e-6
    loss = silog_loss(pred, depths)
    # alternative: loss = F.mse_loss(pred, depths)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    loss_value = loss.item()
    losses.append(loss_value)
    pbar.set_postfix(loss=f'{loss_value:.4g}')

In [None]:
# Training-loss curve of the MLP (log scale).
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(xlabel='step', ylabel='training loss (SiLog)', yscale='log')
plt.show()

In [None]:
# Evaluate the MLP on the validation split with both MSE and SiLog.
# Apply the same +1e-6 offset as in training. The final ReLU can output exact zeros, and
# log(0) = -inf would turn the SiLog value into NaN.
with torch.no_grad():
    model.eval()  # disable dropout
    pred_test = model(repr_val.to('cuda')).squeeze(-1) + 1e-6
    target_val = depths_val.to('cuda')
    loss_test = F.mse_loss(pred_test, target_val)
    print(f'val (mse): {loss_test.item()}')
    loss_test = silog_loss(pred_test, target_val)
    print(f'val (silog): {loss_test.item()}')

In [None]:
# Show one random validation sample: RGB input, target depth, predicted depth, and the per-position
# L2 norm of the feature vectors (each rotated 90 degrees clockwise for display).
idx = torch.randint(0, n_val, (1,)).item()
fig, axs = plt.subplots(1, 4, figsize=(16, 4))
panels = [
    ('input', data[n_train + idx]['image'].rotate(-90, expand=True), None),
    ('target', np.rot90(depths_val[idx].squeeze().cpu().numpy(), k=-1), 'inferno_r'),
    ('pred', np.rot90(pred_test[idx].squeeze().cpu().numpy(), k=-1), 'inferno_r'),
    ('repr norm', np.rot90(repr_val[idx].norm(dim=2).cpu().numpy(), k=-1), 'viridis'),
]
for ax, (title, img, cmap) in zip(axs, panels):
    ax.imshow(img, cmap=cmap)
    ax.axis('off')
    ax.set_title(title)
plt.show()


## Complex CNN Classifier

In [None]:
# Convolutional depth classifier: two 3x3 conv layers, then a 1x1 conv producing num_classes
# logits per spatial position.
num_classes = 20
model = torch.nn.Sequential(
    torch.nn.Conv2d(features, 512, kernel_size=3, padding=1),
    torch.nn.GELU(),
    torch.nn.Conv2d(512, 128, kernel_size=3, padding=1),
    torch.nn.GELU(),
    torch.nn.Conv2d(128, num_classes, kernel_size=1),
).to('cuda')
model = torch.compile(model)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Binning range comes from the training split only; the validation cell reuses it.
# NOTE(review): .long() truncates, so the top class only holds values exactly equal to
# max_depth (effectively num_classes - 1 populated bins).
min_depth = depths_train.min()
max_depth = depths_train.max()
depths_train_class = ((depths_train - min_depth) / (max_depth - min_depth) * (num_classes - 1)).long()

losses = []
pbar = trange(1000)
for i in pbar:
    idx = torch.randint(0, n_train, (64,))
    # channels-first for Conv2d; named repr_batch so the builtin `repr` is not shadowed
    repr_batch = repr_train[idx].permute(0, 3, 1, 2).to('cuda')
    depths = depths_train_class[idx].flatten().to('cuda')

    pred = model(repr_batch)
    # move the class axis last, then flatten (batch, w, h) so rows line up with `depths`
    loss = F.cross_entropy(pred.permute(0, 2, 3, 1).flatten(0, 2), depths)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    loss_value = loss.item()
    losses.append(loss_value)
    pbar.set_postfix(loss=f'{loss_value:.4g}')

In [None]:
# Training-loss curve of the CNN classifier (log scale).
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(xlabel='step', ylabel='training loss (cross-entropy)', yscale='log')
plt.show()

In [None]:
# Validation cross-entropy of the CNN classifier, binned with the training-split range.
# Validation depths can fall outside [min_depth, max_depth]. That yields class ids < 0 or
# >= num_classes, which makes cross_entropy fail (a device-side assert on CUDA), so clamp
# them into the valid range.
depths_val_class = (
    ((depths_val - min_depth) / (max_depth - min_depth) * (num_classes - 1))
    .long()
    .clamp(0, num_classes - 1)
)
with torch.no_grad():
    # channels-first in, then classes back to the last axis
    pred_test = model(repr_val.permute(0, 3, 1, 2).to('cuda')).permute(0, 2, 3, 1)
    loss_test = F.cross_entropy(pred_test.flatten(0, 2), depths_val_class.flatten().to('cuda'))
    print(f'val: {loss_test.item()}')

In [None]:
# Show one random validation sample: RGB input, target depth, and the predicted depth class
# (argmax over logits); each rotated 90 degrees clockwise for display.
idx = torch.randint(0, n_val, (1,)).item()
fig, axs = plt.subplots(1, 3, figsize=(12, 4))
panels = [
    ('input', data[n_train + idx]['image'].rotate(-90, expand=True), None),
    ('target', np.rot90(depths_val[idx].squeeze().cpu().numpy(), k=-1), 'gray_r'),
    ('pred', np.rot90(pred_test[idx].argmax(dim=-1).squeeze().cpu().numpy(), k=-1), 'gray_r'),
]
for ax, (title, img, cmap) in zip(axs, panels):
    ax.imshow(img, cmap=cmap)
    ax.axis('off')
    ax.set_title(title)
plt.show()


## Complex CNN regression model

In [52]:
# Three-layer CNN regressor compiled with torch.compile. The final ELU bounds outputs below
# by -1; the training loop adds +1, so predictions are >= 0. `losses` is reset here so the
# training cell can be re-run to continue training.
model = torch.compile(
    torch.nn.Sequential(
        torch.nn.Conv2d(features, 512, kernel_size=3, padding=1),
        torch.nn.SiLU(),
        torch.nn.Conv2d(512, 512, kernel_size=3, padding=1),
        torch.nn.SiLU(),
        torch.nn.Conv2d(512, 1, kernel_size=1),
        torch.nn.ELU(),
    ).to('cuda')
)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
losses = []


In [None]:
# Train the CNN regressor with MSE. Re-running this cell continues training.
pbar = trange(1000)
for i in pbar:
    idx = torch.randint(0, n_train, (64,))
    # channels-first for Conv2d; named repr_batch so the builtin `repr` is not shadowed
    repr_batch = repr_train[idx].permute(0, 3, 1, 2).to('cuda')
    depths = depths_train[idx].to('cuda')

    # +1 shifts the ELU output (>= -1) to be non-negative
    pred = model(repr_batch).squeeze(1) + 1
    loss = F.mse_loss(pred, depths)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    loss_value = loss.item()
    losses.append(loss_value)
    pbar.set_postfix(loss=f'{loss_value:.4g}')

In [None]:
# Training-loss curve of the CNN regressor (log scale).
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(xlabel='step', ylabel='training loss (MSE)', yscale='log')
plt.show()

In [None]:
# Validation MSE of the CNN regressor. Apply the same +1 offset as the training loop: the
# model was trained on `model(x) + 1`, so omitting it shifts every prediction by -1.
with torch.no_grad():
    pred_test = model(repr_val.permute(0, 3, 1, 2).to('cuda')).squeeze(1) + 1
    loss_test = F.mse_loss(pred_test, depths_val.to('cuda'))
    print(f'val: {loss_test.item()}')

In [None]:
# Show one random validation sample: RGB input, target depth, predicted depth, and the per-position
# L2 norm of the feature vectors (each rotated 90 degrees clockwise for display).
idx = torch.randint(0, n_val, (1,)).item()
fig, axs = plt.subplots(1, 4, figsize=(16, 4))
panels = [
    ('input', data[n_train + idx]['image'].rotate(-90, expand=True), None),
    ('target', np.rot90(depths_val[idx].squeeze().cpu().numpy(), k=-1), 'inferno_r'),
    ('pred', np.rot90(pred_test[idx].squeeze().cpu().numpy(), k=-1), 'inferno_r'),
    ('repr norm', np.rot90(repr_val[idx].norm(dim=2).cpu().numpy(), k=-1), 'viridis'),
]
for ax, (title, img, cmap) in zip(axs, panels):
    ax.imshow(img, cmap=cmap)
    ax.axis('off')
    ax.set_title(title)
plt.show()


## Linear probe with full-resolution Huber loss (as in "Beyond Surface Statistics")

In [27]:
# Linear probe whose loss is computed at the full depth-map resolution (see the training cell).
# `losses` is reset here so the training cell can be re-run to continue training.
model = torch.nn.Linear(in_features=features, out_features=1).to(device='cuda')
# model = torch.compile(model)  # left disabled

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
losses = []


In [None]:
# Train against full-resolution depth: bilinearly upsample the (w, h) prediction map to
# (w_orig, h_orig) and apply a Huber loss. Re-running this cell continues training.
pbar = trange(1000)
for i in pbar:
    idx = torch.randint(0, n_train, (128,))
    # named repr_batch so the builtin `repr` is not shadowed for the rest of the notebook
    repr_batch = repr_train[idx].to('cuda')
    depths = depths_full[idx].to('cuda')  # full-resolution targets

    # add a channel axis, as F.interpolate expects (batch, channels, height, width)
    pred = model(repr_batch).squeeze(-1).unsqueeze(1)
    pred_full = F.interpolate(pred, (w_orig, h_orig), mode='bilinear').squeeze(1)
    loss = F.huber_loss(pred_full, depths)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    loss_value = loss.item()
    losses.append(loss_value)
    pbar.set_postfix(loss=f'{loss_value:.4g}')

In [None]:
# Training-loss curve of the full-resolution linear probe (log scale).
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(xlabel='step', ylabel='training loss (Huber)', yscale='log')
plt.show()

In [None]:
# Evaluate at full resolution: upsample the per-position predictions to the depth-map size and
# compare with the full-resolution validation depths using MSE and Huber loss.
with torch.no_grad():
    pred_test = model(repr_val.to('cuda')).squeeze(-1).unsqueeze(1)  # channel axis for interpolate
    pred_test_full = F.interpolate(pred_test, (w_orig, h_orig), mode='bilinear').squeeze(1)
    mse_loss_test = F.mse_loss(pred_test_full, depths_full[n_train:].to('cuda'))
    huber_loss_test = F.huber_loss(pred_test_full, depths_full[n_train:].to('cuda'))
print(f'mse val: {mse_loss_test.item()}')
print(f'huber val: {huber_loss_test.item()}')


In [None]:
# Show one random validation sample at full resolution: RGB input, target depth, upsampled
# prediction, and the per-position L2 norm of the feature vectors
# (each rotated 90 degrees clockwise for display).
idx = torch.randint(0, n_val, (1,)).item()
fig, axs = plt.subplots(1, 4, figsize=(16, 4))
panels = [
    ('input', data[n_train + idx]['image'].rotate(-90, expand=True), None),
    ('target', np.rot90(depths_full[n_train + idx].squeeze().cpu().numpy(), k=-1), 'inferno_r'),
    ('pred', np.rot90(pred_test_full[idx].squeeze().cpu().numpy(), k=-1), 'inferno_r'),
    ('repr norm', np.rot90(repr_val[idx].norm(dim=2).cpu().numpy(), k=-1), 'viridis'),
]
for ax, (title, img, cmap) in zip(axs, panels):
    ax.imshow(img, cmap=cmap)
    ax.axis('off')
    ax.set_title(title)
plt.show()
