In [3]:
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.autonotebook import tqdm, trange
from datasets import load_dataset
from sdhelper import SD

# Allow reduced-precision (e.g. TF32) float32 matmuls: faster on supported GPUs,
# and silences the warning torch.compile emitted without it.
torch.set_float32_matmul_precision('high')


In [4]:
# Extraction positions passed to sd.img2repr: the four down blocks, the mid block,
# then the four up blocks, in that order.
blocks = (
    [f'down_blocks[{idx}]' for idx in range(4)]
    + ['mid_block']
    + [f'up_blocks[{idx}]' for idx in range(4)]
)

In [None]:
# NYUv2 train split from the Hugging Face Hub.
# NOTE: trust_remote_code=True lets the dataset repository's own loading code run
# locally — only acceptable because that source is trusted.
data = load_dataset(
    "0jl/NYUv2",
    split="train",
    trust_remote_code=True,
)
data[0].keys()  # peek at the fields available per sample

In [None]:
# Project-local `sdhelper` wrapper; used below to extract representations at the
# positions listed in `blocks` (constructed with its default settings).
sd = SD()


In [None]:
# One representation record per dataset image, holding an entry for every name in `blocks`.
# NOTE(review): the meaning of `step` and `seed` is defined by sdhelper (presumably the
# diffusion step and noise seed) — confirm there.
repr_raw = sd.img2repr(
    [sample['image'] for sample in data],
    extract_positions=blocks,
    step=50,
    seed=42,
)


In [None]:
# Build train/val tensors (depth targets and block representations) for the probe.
block = 'up_blocks[1]'
TRAIN_FRACTION = 0.8  # first 80% of the split is used for training, the rest for validation

# Stack through numpy first: torch.tensor on a Python list of per-image arrays/lists is
# a very slow path, while np.asarray converts the whole batch in one go.
depths_full = torch.from_numpy(np.asarray([x['depth'] for x in data], dtype=np.float32))
# NOTE: dim 1 is stored as `w_orig` and dim 2 as `h_orig`. Every use passes them to
# F.interpolate in that same (dim1, dim2) order, so results are correct even though
# torch's usual convention would call dim 1 the height.
n, w_orig, h_orig = depths_full.shape
n_train = int(n * TRAIN_FRACTION)
n_val = n - n_train
depths_train = depths_full[:n_train]
depths_val = depths_full[n_train:]
print(n, w_orig, h_orig)

# Move channels last so the linear probe acts on the feature dim
# (assumes each x[block] is (1, C, fh, fw) — TODO confirm against sdhelper).
repr_torch = torch.stack([x[block].squeeze(0) for x in repr_raw]).to(dtype=torch.float32).permute(0, 2, 3, 1)
print(repr_torch.shape)
assert repr_torch.shape[0] == n, 'expected exactly one representation per depth map'

repr_train = repr_torch[:n_train]
repr_val = repr_torch[n_train:]
features = repr_torch.shape[-1]


In [None]:
# Precomputed mean "anomaly" representation (the file name suggests it was derived from an
# ImageNet subset at step 50 / seed 42 — see the data_labeler output for provenance).
# weights_only=True keeps torch.load from unpickling arbitrary objects.
MEAN_ANOMALY_PATH = "../data/data_labeler/imagenet_subset_high_norm_anomalies_step50_seed42_heavy_only_reprs_of_patches_mean.pt"
mean_anomaly = torch.load(MEAN_ANOMALY_PATH, weights_only=True)
mean_anomaly = mean_anomaly.to(dtype=torch.float32).to('cuda')
mean_anomaly.shape

In [None]:

# Linear probe: maps each feature-map position's `features`-dim vector to one depth value,
# then bilinearly upsamples the prediction to the full depth-map resolution.
N_STEPS = 1000
BATCH_SIZE = 64
LR = 1e-3

torch.manual_seed(42)  # reproducible probe init and batch sampling
model = torch.nn.Linear(features, 1).to('cuda')
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
losses = []

progress = trange(N_STEPS)
for i in progress:
    idx = torch.randint(0, n_train, (BATCH_SIZE,))
    repr_batch = repr_train[idx].to('cuda')
    # Take targets from the training split, matching repr_train (idx < n_train).
    depths = depths_train[idx].to('cuda')

    # (B, fh, fw, C) -> (B, fh, fw, 1) -> (B, 1, fh, fw) for F.interpolate
    pred = model(repr_batch).squeeze(-1).unsqueeze(1)
    pred_full = F.interpolate(pred, (w_orig, h_orig), mode='bilinear').squeeze(1)
    loss = F.huber_loss(pred_full, depths)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # One host sync per step; report via the progress bar instead of 1000 printed lines.
    loss_value = loss.item()
    losses.append(loss_value)
    progress.set_postfix(loss=f'{loss_value:.4f}')

fig, ax = plt.subplots()
ax.plot(losses)
ax.set(yscale='log', xlabel='step', ylabel='Huber loss', title=f'Probe training loss ({block})')
plt.show()

In [None]:
# Sanity check: shape of the last training batch's depth targets (set in the training loop).
depths.shape

In [None]:
# Evaluate the probe on the held-out split: over all pixels, then restricted to the
# image borders, the image corners, and "anomaly" patches (representations close to
# `mean_anomaly`).
EDGE_PX = 16                 # width, in depth-map pixels, of the border / corner strips
ANOMALY_SIM_THRESHOLD = 0.8  # cosine similarity to mean_anomaly above which a patch is flagged
ANOMALY_NORM_FRACTION = 0.0  # also require norm > fraction * max norm (0.0 only drops zero-norm patches)


def _report_region(name, pred, target, mask):
    """Print the Huber loss of `pred` vs `target` restricted to the boolean `mask`.

    `mask` has the same shape as `target`; the printed share is relative to all
    pixels of `target`.
    """
    n_px = int(mask.sum())
    if n_px == 0:
        # The mean over an empty selection would be NaN; say so explicitly instead.
        print(f'Validation loss on {name}: n/a (0 pixels selected)')
        return
    region_loss = F.huber_loss(pred[mask], target[mask])
    print(f'Validation loss on {name}: {region_loss.item():.4f} ({n_px} - {n_px / target.numel():.2%} of total)')


with torch.no_grad():
    repr = repr_val.to('cuda')
    depths = depths_val.to('cuda')

    # A single forward pass; every metric below reuses this prediction.
    pred = model(repr).squeeze(-1).unsqueeze(1)
    pred_full = F.interpolate(pred, (w_orig, h_orig), mode='bilinear').squeeze(1)

    val_loss = F.huber_loss(pred_full, depths)
    print(f'Validation loss: {val_loss.item()}')

    # Border frame and corners as pixel masks. Each pixel is counted once; concatenating
    # the four strips instead would count every corner pixel twice.
    n_rows, n_cols = depths.shape[1:]
    rows = torch.arange(n_rows, device=depths.device)
    cols = torch.arange(n_cols, device=depths.device)
    near_row_edge = (rows < EDGE_PX) | (rows >= n_rows - EDGE_PX)
    near_col_edge = (cols < EDGE_PX) | (cols >= n_cols - EDGE_PX)
    border_mask = (near_row_edge[:, None] | near_col_edge[None, :]).expand_as(depths)
    corner_mask = (near_row_edge[:, None] & near_col_edge[None, :]).expand_as(depths)
    _report_region('borders', pred_full, depths, border_mask)
    _report_region('corners', pred_full, depths, corner_mask)

    # Anomaly patches at feature-map resolution...
    similarities = F.cosine_similarity(repr, mean_anomaly, dim=-1)
    norms = repr.norm(dim=-1)
    anomaly_patches = (similarities > ANOMALY_SIM_THRESHOLD) & (norms > ANOMALY_NORM_FRACTION * norms.max())
    # ...nearest-neighbour upsampled to depth-map resolution. The scale factor is derived
    # from the shapes rather than hard-coded, so other blocks / image sizes also work.
    scale_r, rem_r = divmod(n_rows, anomaly_patches.shape[1])
    scale_c, rem_c = divmod(n_cols, anomaly_patches.shape[2])
    assert rem_r == 0 and rem_c == 0, (
        f'depth map {(n_rows, n_cols)} is not an integer multiple of feature map {tuple(anomaly_patches.shape[1:])}'
    )
    similarities_map_upscaled = anomaly_patches.repeat_interleave(scale_r, dim=1).repeat_interleave(scale_c, dim=2)
    _report_region('anomalies', pred_full, depths, similarities_map_upscaled)


In [None]:
# Distributions of the two quantities behind the validation anomaly mask: cosine
# similarity to mean_anomaly, and each patch's norm relative to the largest norm.
sims_np = similarities.cpu().numpy().ravel()
norms_np = repr.norm(dim=-1).cpu().numpy().ravel()  # computed once, reused for the max

fig, ax = plt.subplots()
ax.hist(sims_np, bins=100, label='cosine similarity to mean_anomaly')
ax.hist(norms_np / norms_np.max(), bins=100, alpha=0.5, label='norm / max norm')
ax.set(yscale='log', xlabel='value', ylabel='patch count', title=f'Validation patches ({block})')
ax.legend()
plt.show()