In [None]:
import numpy as np
from sdhelper import SD
from PIL import Image
import torch
import torch.nn.functional as F
from tqdm.autonotebook import tqdm, trange
import matplotlib.pyplot as plt
from datasets import load_dataset
import pickle

# Global torch setting: trade a little float32 matmul precision for speed on the GPU.
torch.set_float32_matmul_precision('high')  # for better performance (got a warning without this during torch compile)


In [None]:
# Load the NYUv2 train split and display the fields available on each record.
data = load_dataset("0jl/NYUv2", split="train", trust_remote_code=True)
data[0].keys()

In [None]:
# Stable Diffusion helper; used in the training loop to extract per-layer representations (sd.img2repr).
sd = SD()


In [4]:
# U-Net layers to probe, grouped by stage: conv_in, down_blocks[0..3], mid_block,
# up_blocks[0..3], conv_out. Each inner list is extracted together in a single
# sd.img2repr call in the training loop; the flat order of all_blocks is the x-axis
# order used by the plots.
all_blocks_separated = [
    ['conv_in'],
    [
        'down_blocks[0].resnets[0]',
        'down_blocks[0].attentions[0]',
        'down_blocks[0].resnets[1]',
        'down_blocks[0].attentions[1]',
        'down_blocks[0].downsamplers[0]',
    ],
    [
        'down_blocks[1].resnets[0]',
        'down_blocks[1].attentions[0]',
        'down_blocks[1].resnets[1]',
        'down_blocks[1].attentions[1]',
        'down_blocks[1].downsamplers[0]',
    ],
    [
        'down_blocks[2].resnets[0]',
        'down_blocks[2].attentions[0]',
        'down_blocks[2].resnets[1]',
        'down_blocks[2].attentions[1]',
        'down_blocks[2].downsamplers[0]',
    ],
    [
        'down_blocks[3].resnets[0]',
        'down_blocks[3].resnets[1]',
    ],
    [
        'mid_block.resnets[0]',
        'mid_block.attentions[0]',
        'mid_block.resnets[1]',
    ],
    [
        'up_blocks[0].resnets[0]',
        'up_blocks[0].resnets[1]',
        'up_blocks[0].upsamplers[0]',
    ],
    [
        'up_blocks[1].resnets[0]',
        'up_blocks[1].attentions[0]',
        'up_blocks[1].resnets[1]',
        'up_blocks[1].attentions[1]',
        'up_blocks[1].resnets[2]',
        'up_blocks[1].attentions[2]',
        'up_blocks[1].upsamplers[0]',
    ],
    [
        'up_blocks[2].resnets[0]',
        'up_blocks[2].attentions[0]',
        'up_blocks[2].resnets[1]',
        'up_blocks[2].attentions[1]',
        'up_blocks[2].resnets[2]',
        'up_blocks[2].attentions[2]',
        'up_blocks[2].upsamplers[0]',
    ],
    [
        'up_blocks[3].resnets[0]',
        'up_blocks[3].attentions[0]',
        'up_blocks[3].resnets[1]',
        'up_blocks[3].attentions[1]',
        'up_blocks[3].resnets[2]',
        'up_blocks[3].attentions[2]',
    ],
    ['conv_out'],
]
# Flat list of every probed layer in network order (a new list; the groups are not aliased).
all_blocks = sum(all_blocks_separated, [])

In [5]:
# config: representation-extraction settings and linear-probe training hyperparameters
n_steps = 10_000         # optimizer steps per probe
batch_size = 64          # images per training step and per evaluation batch
noise_step = 50          # passed to sd.img2repr as `step`
seed = 42                # passed to sd.img2repr as `seed`
model_lr = 1e-3          # AdamW learning rate for the probes
limit_data_to = 100_000  # maximum number of dataset records to use

In [None]:
# Load depth maps and RGB images from the (possibly truncated) dataset in ONE pass.
# Each row access returns all columns of that row, so iterating once for depths and
# again for images read every row twice.
data_limited = data.select(range(limit_data_to)) if len(data) > limit_data_to else data
depth_rows, images = [], []
for record in tqdm(data_limited, desc='loading data'):
    depth_rows.append(record['depth'])
    images.append(record['image'])
depths_full = torch.tensor(depth_rows, dtype=torch.float32, device='cuda')
del depth_rows  # the host-side copy is no longer needed once the tensor is on the GPU

# NOTE: despite the names, w_orig is dim 1 and h_orig is dim 2 of the depth maps; the rest
# of the notebook consistently uses them in this (dim1, dim2) order (e.g. as interpolate size).
n, w_orig, h_orig = depths_full.shape
# Deterministic, unshuffled split: first 80% of records train, the remaining 20% validate.
n_train = int(n * 0.8)
n_val = n - n_train
depths_train = depths_full[:n_train]
depths_val = depths_full[n_train:]
print(f'{n = }, {w_orig = }, {h_orig = }')

In [7]:
# Side length (pixels) of the square patch evaluated at each up1 anomaly in the training loop.
up1_scale_factor = 16
# Precomputed up1 high-norm anomaly locations (one row per anomaly, unpacked as 3 integers
# in the training loop); the file name suggests noise step 50 / seed 42, matching the config.
up1_anomalies = np.load('../data/data_labeler/high_norm_anomalies_nyuv2_step50_seed42.npy')

In [None]:

# Train one linear depth probe per U-Net layer and record its evaluation metrics.
#
# For each stage in all_blocks_separated, SD representations are extracted once (only for
# that stage's layers, to bound memory). For every layer a per-pixel Linear(features -> 1)
# head is trained with Huber loss on bilinearly upsampled predictions, then evaluated on
# whole images, on the up1 anomaly patches, on the four corner pixels and on the 1-pixel
# image border ("test" below means the validation split).
#
# final_losses[block] holds 16 floats in this fixed order (unpacked positionally later):
#    0-3   rmse train, huber train, rmse val, huber val
#    4-7   rmse / huber train anomaly,  rmse / huber val anomaly
#    8-11  rmse / huber train corner,   rmse / huber val corner
#   12-15  rmse / huber train border,   rmse / huber val border
# models[block] holds the trained probe, moved to the CPU.


def _rmse(pred, target):
    """Root-mean-squared error between two tensors, as a Python float."""
    return F.mse_loss(pred, target).item() ** .5


def _huber(pred, target):
    """Mean Huber loss (PyTorch default delta) between two tensors, as a Python float."""
    return F.huber_loss(pred, target).item()


def _mean_or_nan(values):
    """Mean of a list of floats, or NaN for an empty list (avoids numpy's empty-slice warning)."""
    return np.mean(values).item() if values else float('nan')


def _train_probe(features, targets, out_size):
    """Fit a per-pixel Linear(C -> 1) head on (N, h, w, C) features against (N, *out_size) targets.

    Each step samples `batch_size` random images, upsamples the prediction to `out_size` and
    minimizes the Huber loss. Uses the config globals model_lr, n_steps and batch_size.
    Returns (model, per-step training losses).
    """
    model = torch.nn.Linear(features.shape[-1], 1).to('cuda')
    optimizer = torch.optim.AdamW(model.parameters(), lr=model_lr)
    losses = []
    for _ in (progress := trange(n_steps, desc='training')):
        idx = torch.randint(0, features.shape[0], (batch_size,))
        pred = model(features[idx]).squeeze(-1).unsqueeze(1)  # (B, 1, h, w)
        pred_up = F.interpolate(pred, out_size, mode='bilinear').squeeze(1)
        loss = F.huber_loss(pred_up, targets[idx])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())
        progress.set_postfix(loss=losses[-1])
    return model, losses


def _anomaly_metrics(pred, depths, anomalies, n_train, patch):
    """RMSE / Huber on `patch` x `patch` crops at each anomaly, as (train rmse, train huber, val rmse, val huber) lists."""
    rmse_tr, huber_tr, rmse_va, huber_va = [], [], [], []
    for row in tqdm(anomalies):
        img_idx, w_idx, h_idx = row.tolist()
        if img_idx >= len(pred):  # anomaly refers to an image outside the loaded subset
            continue
        # NOTE(review): columns are unpacked as (img, w, h) but the crop indexes (dim1, dim2)
        # with (h, w), and the coordinates are used directly as pixel offsets (not multiplied
        # by the patch size); confirm both against how the .npy file was produced.
        pred_patch = pred[img_idx, h_idx:h_idx + patch, w_idx:w_idx + patch]
        depth_patch = depths[img_idx, h_idx:h_idx + patch, w_idx:w_idx + patch]
        if img_idx < n_train:
            rmse_tr.append(_rmse(pred_patch, depth_patch))
            huber_tr.append(_huber(pred_patch, depth_patch))
        else:
            rmse_va.append(_rmse(pred_patch, depth_patch))
            huber_va.append(_huber(pred_patch, depth_patch))
    return rmse_tr, huber_tr, rmse_va, huber_va


def _corner_metrics(pred, target):
    """Mean RMSE and mean Huber loss over the four corner pixels of (N, dim1, dim2) maps."""
    corners = [(r, c) for r in (0, -1) for c in (0, -1)]
    rmse = sum(_rmse(pred[:, r, c], target[:, r, c]) for r, c in corners) / 4
    huber = sum(_huber(pred[:, r, c], target[:, r, c]) for r, c in corners) / 4
    return rmse, huber


def _border(maps):
    """Concatenate the four 1-pixel borders of (N, dim1, dim2) maps along dim 1 (corner pixels repeat)."""
    return torch.cat([maps[:, :, -1], maps[:, -1, :], maps[:, :, 0], maps[:, 0, :]], dim=1)


def _border_metrics(pred, target):
    """RMSE and Huber loss over the 1-pixel border of (N, dim1, dim2) maps."""
    pred_border, target_border = _border(pred), _border(target)
    return _rmse(pred_border, target_border), _huber(pred_border, target_border)


final_losses = {}
models = {}
for blocks in all_blocks_separated:
    # release cached GPU memory before extracting the next stage's representations
    torch.cuda.empty_cache()

    # extract representations only for the layers of the current stage to save memory
    repr_raw = sd.img2repr(images, extract_positions=blocks, step=noise_step, seed=seed)

    for block in blocks:
        print('-' * 100)
        print(f'{block = }')

        # one float32 (N, h, w, C) tensor on the GPU; assumes each x[block] is (1, C, h, w),
        # as implied by squeeze(0) + permute(0, 2, 3, 1) feeding a Linear over the last dim
        repr_torch = torch.stack([x[block].squeeze(0) for x in repr_raw]).permute(0, 2, 3, 1).to(dtype=torch.float32, device='cuda')
        print(f'{repr_torch.shape = }')
        repr_train = repr_torch[:n_train]

        # seed each probe so its init and batch sampling are reproducible across reruns
        # (some CUDA backward kernels may still add small nondeterminism)
        torch.manual_seed(seed)
        model, losses = _train_probe(repr_train, depths_train, (w_orig, h_orig))

        # training-loss curve
        fig, ax = plt.subplots()
        ax.plot(losses)
        ax.set_yscale('log')
        ax.set_title(f'{block} train loss')
        ax.set_xlabel('step')
        ax.set_ylabel('huber loss')
        plt.show()

        # predict every image (train + val) in batches, upsampled to the depth resolution
        with torch.no_grad():
            pred_all = torch.cat([model(repr_torch[i:i + batch_size]).squeeze(-1) for i in range(0, n, batch_size)]).unsqueeze(1)
            pred_full = F.interpolate(pred_all, (w_orig, h_orig), mode='bilinear').squeeze(1)

        rmse_train = _rmse(pred_full[:n_train], depths_train)
        print(f'rmse train: {rmse_train}')
        huber_loss_train = _huber(pred_full[:n_train], depths_train)
        print(f'huber train: {huber_loss_train}')
        rmse_test = _rmse(pred_full[n_train:], depths_val)
        print(f'rmse val: {rmse_test}')
        huber_loss_test = _huber(pred_full[n_train:], depths_val)
        print(f'huber val: {huber_loss_test}')

        rmse_train_anomaly, huber_train_anomaly, rmse_test_anomaly, huber_test_anomaly = _anomaly_metrics(
            pred_full, depths_full, up1_anomalies, n_train, up1_scale_factor)
        rmse_train_corner, huber_train_corner = _corner_metrics(pred_full[:n_train], depths_train)
        rmse_test_corner, huber_test_corner = _corner_metrics(pred_full[n_train:], depths_val)
        rmse_train_border, huber_train_border = _border_metrics(pred_full[:n_train], depths_train)
        rmse_test_border, huber_test_border = _border_metrics(pred_full[n_train:], depths_val)

        # store loss / performance (order documented at the top of this cell)
        final_losses[block] = [
            rmse_train, huber_loss_train, rmse_test, huber_loss_test,
            _mean_or_nan(rmse_train_anomaly), _mean_or_nan(huber_train_anomaly),
            _mean_or_nan(rmse_test_anomaly), _mean_or_nan(huber_test_anomaly),
            rmse_train_corner, huber_train_corner, rmse_test_corner, huber_test_corner,
            rmse_train_border, huber_train_border, rmse_test_border, huber_test_border,
        ]
        models[block] = model.to('cpu')

        # drop this layer's GPU tensors before the next one
        del repr_torch, repr_train, pred_all, pred_full
    del repr_raw


In [9]:
# Persist the trained probes (torch.nn.Linear modules moved to the CPU, keyed by block name).
# The context manager guarantees the file is flushed and closed even if pickling fails.
with open('depth_estimation_nyu_like_beyond_surface_statistics_models.pkl', 'wb') as models_file:
    pickle.dump(models, models_file)

In [10]:
# Stack the per-block metric lists into a (num_blocks, num_metrics) array, rows in the
# order the blocks were trained, and save it to the current working directory.
final_losses_np = np.asarray([metrics for metrics in final_losses.values()])
np.save('depth_estimation_nyu_like_beyond_surface_statistics_final_losses.npy', final_losses_np)

In [11]:
# Split the metrics array into one 1-D array per metric (one value per probed block).
# Column order follows the per-block metric list built in the training loop; "test" here
# refers to the validation split.
(
    y_train_rmse, y_train_huber, y_test_rmse, y_test_huber,
    y_train_rmse_anomaly, y_train_huber_anomaly, y_test_rmse_anomaly, y_test_huber_anomaly,
    y_train_rmse_corner, y_train_huber_corner, y_test_rmse_corner, y_test_huber_corner,
    y_train_rmse_border, y_train_huber_border, y_test_rmse_border, y_test_huber_border,
) = final_losses_np.T

In [None]:

# Plot train/val RMSE (whole image and up1 anomaly patches) for every probed block, with a
# two-level x-axis: block type per tick, U-Net stage name centred underneath each stage.


def _block_kind(block):
    """Short tick label for a block name: attn / res / down / up / conv, or '?' if unknown."""
    for key, label in (('attentions', 'attn'), ('resnets', 'res'), ('downsamplers', 'down'),
                       ('upsamplers', 'up'), ('conv', 'conv')):
        if key in block:
            return label
    return '?'


def _stage_name(block_list):
    """Stage label for a group of blocks: 'mid', 'in'/'out' for conv_*, else e.g. 'down0' or 'up2'."""
    first = block_list[0]
    if 'mid' in first:
        return 'mid'
    if 'conv' in first:
        return first[len('conv_'):]
    prefix, index, *_ = first.split('[')
    return prefix.replace('_blocks', '') + index.split(']')[0]


fig, ax1 = plt.subplots(figsize=(10, 4))

x = np.arange(len(final_losses))
lines = []
lines += ax1.plot(x, y_train_rmse, label='train', color='tab:green', linestyle='-')
lines += ax1.plot(x, y_test_rmse, label='test', color='tab:green', linestyle='--')
lines += ax1.plot(x, y_train_rmse_anomaly, label='train anomaly', color='tab:red', linestyle='-')
lines += ax1.plot(x, y_test_rmse_anomaly, label='test anomaly', color='tab:red', linestyle='--')

# per-block tick labels (block type)
ax1.set_xticks(x)
ax1.set_xticklabels([_block_kind(block) for block in all_blocks], rotation=90)

# stage names and the x index of each stage's first block
main_blocks = []
main_block_positions = []
start = 0
for block_list in all_blocks_separated:
    main_blocks.append(_stage_name(block_list))
    main_block_positions.append(start)
    start += len(block_list)

# dashed separators between stages, extended below the axis by long secondary ticks
# (a single colour keyword: passing both `color` and its alias `c` is ambiguous)
for p in main_block_positions[1:]:
    ax1.axvline(x=p - 0.5, linestyle='--', color='lightgray')
ax_x3 = ax1.secondary_xaxis(location=0)
ax_x3.set_xticks([p - 0.5 for p in main_block_positions[1:]], labels=[])
ax_x3.tick_params(axis='x', length=34, width=1.5, color='lightgray')

# stage names centred under each stage, pushed below the block-type labels with newlines
ax_x2 = ax1.secondary_xaxis(location=0)
ax_x2.set_xticks([p + len(bl) / 2 - 0.5 for p, bl in zip(main_block_positions, all_blocks_separated)],
                 labels=[f'\n\n\n{b}' for b in main_blocks], ha='center')
ax_x2.tick_params(length=0)

ax1.set_ylabel('rmse')
ax1.set_yscale('log')
ax1.set_yticks([2**i for i in range(-1, 2)])
ax1.set_yticklabels([f'{y:.1f}' for y in ax1.get_yticks()])
ax1.yaxis.set_minor_formatter(plt.NullFormatter())
ax1.set_title('rmse per block')

# legend for all plotted lines (single axis)
ax1.legend(lines, [line.get_label() for line in lines])

plt.show()


In [None]:
fig, ax1 = plt.subplots(figsize=(10, 7))

# One colour per block type; solid = train split, dashed = validation ("test") split.
# 'samplers' matches both downsamplers and upsamplers; conv_in / conv_out are not plotted.
palette = ['tab:blue', 'tab:orange', 'tab:green']
lines = []
for color, block_type in zip(palette, ['attention', 'resnet', 'samplers']):
    positions = [k for k, name in enumerate(all_blocks) if block_type in name]
    train_vals = [y_train_huber[k] for k in positions]
    test_vals = [y_test_huber[k] for k in positions]
    lines += ax1.plot(positions, train_vals, label=f'{block_type} train', color=color, linestyle='-')
    lines += ax1.plot(positions, test_vals, label=f'{block_type} test', color=color, linestyle='--')

ax1.set_xticks(np.arange(len(all_blocks)))
ax1.set_xticklabels(all_blocks, rotation=90)
ax1.set_ylabel('huber loss')

# legend for all plotted lines (single axis)
ax1.legend(lines, [line.get_label() for line in lines])

ax1.set_title('huber loss')
fig.tight_layout()
plt.show()


In [14]:
# loss_data = np.array([y_train_huber, y_test_huber, y_test_rmse])
# np.save('depth_estimation_nyu_like_beyond_surface_statistics_all_loss_data.npy', loss_data)
