In [1]:
import numpy as np
import pickle
import sys
import os
import time
import datetime as dt
from tqdm import tqdm
from matplotlib import pyplot as plt
from itertools import chain
import imageio

import torch
import torchvision.transforms as transforms

from transform import RandomResizedCropFlip
from vit_wrapper import PretrainedViTWrapper
from single_image_dataset import SingleImageDataset

from neural_feature_field import NeuralFeatureField
from offline_denoiser import SingleImageDenoiser
from visualization_tools import visualize_offline_denoised_samples

import misc



In [2]:
torch.cuda.device_count()

2

In [3]:
import types
args = types.SimpleNamespace(
    model="vit_base_patch14_dinov2.lvd142m",
    input_size=(518, 518),
    stride_size=14,
    layer_depth_ratio=1.0,
    img_path="demo/assets/demo/cat.jpg",
    dtype="float32",
    data_root=None,#"./features_FGH",
    save_root=None,
    start_idx=0,
    num_imgs=1,
    num_views=768,
    num_iters=25_000,
    warmup_iters=100,
    n_levels=16,
    freeze_shared_artifacts_after=0.5,
    lr=0.01,
    min_lr=0.001,
    weight_decay=1e-5,
    extract_bsz=32,
    pixel_bsz=2048,
    output_dir="./work_dirs/demo",
    num_vis_samples=5,
    vis_freq=100,
    seed=0
)

In [4]:
# device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f"device: {device}")

vit = PretrainedViTWrapper(model_identifier=args.model, stride=args.stride_size)
vit = vit.to(device).eval()

device: cuda:0


In [5]:
# 4 parameters computed by specific ViT model
layer_index = int(args.layer_depth_ratio * vit.last_layer_index)
print(f"layer_index: {layer_index}")
pos_h = (args.input_size[0] - vit.patch_size) // args.stride_size + 1
pos_w = (args.input_size[1] - vit.patch_size) // args.stride_size + 1
print(f"patch spatial size: {pos_h, pos_w}")
feat_dim = vit.n_output_dims
print(f"feat_dim: {feat_dim}")
args.layer_index = layer_index
args.feat_dim = vit.n_output_dims
args.noise_map_height = pos_h
args.noise_map_width = pos_w


if not isinstance(args.dtype, torch.dtype):
    args.dtype = torch.float32 if args.dtype == "float32" else torch.bfloat16
print(f"dtype: {args.dtype}")

# normalizer and denormalizer from the model
normalizer = vit.transformation.transforms[-1]
assert isinstance(normalizer, transforms.Normalize), "last transform must be norm"
denormalizer = transforms.Normalize(
    mean=[-m / s for m, s in zip(normalizer.mean, normalizer.std)],
    std=[1 / s for s in normalizer.std],
)

layer_index: 11
patch spatial size: (37, 37)
feat_dim: 768
dtype: torch.float32


In [6]:
# placeholder for the data
num_samples = args.num_views + 1  # + 1 for the original image
global_pixel_coords = torch.zeros(
    (num_samples, pos_h, pos_w, 2),
    dtype=args.dtype,
    device=device,
)
views = torch.zeros(
    (num_samples, 3, args.input_size[0], args.input_size[1]),
    dtype=args.dtype,
    device=device,
)
vit_features = torch.zeros(
    (num_samples, pos_h, pos_w, vit.n_output_dims),
    dtype=args.dtype,
    device=device,
)

In [7]:
dataset = SingleImageDataset(
    size=args.input_size,
    base_transform=transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.Resize(args.input_size),
            transforms.ToTensor(),
            normalizer,
        ]
    ),
    final_transform=RandomResizedCropFlip(
        size=args.input_size,
        horizontal_flip=True,
        scale=(0.1, 0.5),
        patch_size=vit.patch_size,
        stride=args.stride_size,
    ),
    num_views=args.num_views,
)

In [9]:
def make_patch_coordinates(height, width, start=-1, end=1):
    patch_y, patch_x = torch.linspace(start, end, height), torch.linspace(start, end, width)
    patch_y, patch_x = torch.meshgrid(patch_y, patch_x, indexing="ij")
    patch_coordinates = torch.stack([patch_x, patch_y], dim=-1)
    return patch_coordinates

def denoise_an_image(
    args,
    all_raw_features: torch.Tensor,
    all_pixel_coords: torch.Tensor,
    all_transformed_views: torch.Tensor,
    device: torch.device,
    denormalizer: transforms.Normalize = None,
    identifier: str = '',
    # img_pth: str = None,
    should_save_vis: bool = False,
):
    # ---- build the models and optimizer ---- #
    denoiser = SingleImageDenoiser(
        noise_map_height=args.noise_map_height,
        noise_map_width=args.noise_map_width,
        feat_dim=args.feat_dim,
        layer_index=args.layer_index,
    ).to(device)
    # ---- build a neural field ----#
    neural_field = NeuralFeatureField(feat_dim=args.feat_dim, n_levels=args.n_levels)
    neural_field = neural_field.to(device)
    optimizer = torch.optim.Adam(
        chain(denoiser.parameters(), neural_field.parameters()),
        lr=args.lr,
        eps=1e-15,
        weight_decay=args.weight_decay,
        betas=(0.9, 0.99),
    )
    grad_scaler = torch.amp.GradScaler("cuda", 2**10)

    # ----- shared artifact (G) coordinates ----- #
    shared_artifact_coords = make_patch_coordinates(args.noise_map_height, args.noise_map_width)
    shared_artifact_coords = shared_artifact_coords.to(device)
    num_views = all_raw_features.shape[0]
    batched_shared_artifact_coords = shared_artifact_coords.unsqueeze(0).repeat(num_views, 1, 1, 1)
    batched_shared_artifact_coords = batched_shared_artifact_coords.reshape(-1, 2)

    batched_raw_features = all_raw_features.reshape(-1, all_raw_features.shape[-1])
    batched_pixel_coordinates = all_pixel_coords.reshape(-1, 2)

    
    
    for step in range(args.num_iters):
        step_display = step + 1
        denoiser.train()
        neural_field.train()
        if step > int(args.freeze_shared_artifacts_after * args.num_iters):
            denoiser.stop_shared_artifacts_grad()
            # if denoiser.residual_predictor_start == False:
            #     save_FGH_full_snapshot(neural_field=neural_field,
            #            denoiser=denoiser,
            #            img_name=identifier,
            #            output_dir="FGH1210_beforeH",
            #            args=args)
            denoiser.start_residual_predictor()
        random_pixel_indices = np.random.randint(0, batched_raw_features.shape[0], args.pixel_bsz)
        raw_features = batched_raw_features[random_pixel_indices]
        shared_artifact_coords = batched_shared_artifact_coords[random_pixel_indices]
        pixel_coordinates = batched_pixel_coordinates[random_pixel_indices]
        misc.adjust_learning_rate(optimizer, step, args)
        with torch.autocast(device, dtype=args.dtype, enabled=args.dtype != torch.float32):
            output = denoiser(
                raw_vit_outputs=raw_features,
                global_pixel_coords=pixel_coordinates,
                neural_field=neural_field,
                shared_artifact_coords=shared_artifact_coords,
                return_visualization=False,
            )
            loss = output["loss"]
        optimizer.zero_grad()
        grad_scaler.scale(loss).backward()
        optimizer.step()

        if step_display % 1000 == 0:# or step == args.num_iters - 1:
            print(
                f"Step {step}/{args.num_iters}: "
                f"Loss = {loss.item():.4f}, "
                f"Patch Loss = {output['patch_l2_loss'].item():.4f}, "
                f"CosSim Loss = {output['cosine_similarity_loss'].item():.4f}, "
                f"Residual Loss = {output['residual_loss'].item() if 'residual_loss' in output else 0:.4f}, "
                f"Residual Sparsity Loss = {output['residual_sparsity_loss'].item() if 'residual_sparsity_loss' in output else 0:.4f}, "
                f"LR = {optimizer.param_groups[0]['lr']:.4f}"
            )
    if should_save_vis:
        vis_indices = np.random.randint(0, num_views, args.num_vis_samples)
        vis_indices = np.concatenate([vis_indices, [-1]])
        train_pca_samples, pred_full_denoised_features = visualize_offline_denoised_samples(
            denoiser=denoiser,
            neural_field=neural_field,
            raw_features=all_raw_features[vis_indices],
            coord=all_pixel_coords[vis_indices],
            patch_images=all_transformed_views[vis_indices],
            device=device,
            denormalizer=denormalizer,
            dtype=args.dtype,
        )
        os.makedirs(f"{args.output_dir}/visualization", exist_ok=True)
        # img_name = os.path.basename(img_pth)
        imageio.imsave(f"{args.output_dir}/visualization/{identifier}.jpg", train_pca_samples)
        
        print(f"Saved visualization to {args.output_dir}/visualization/{identifier}")
    else:
        pred_full_denoised_features = None
    pred_full_denoised_features = None
    # if args.data_root is not None:
    #     if pred_full_denoised_features is None:
    #         with torch.no_grad():
    #             output = denoiser(
    #                 raw_vit_outputs=all_raw_features[-1:],
    #                 global_pixel_coords=all_pixel_coords[-1:],
    #                 neural_field=neural_field,
    #                 return_visualization=True,
    #             )
    #         pred_full_denoised_features = output["denoised_feats"].float().detach().cpu().numpy()
    #     raw_feat_dir = f"{args.save_root}/raw_features/{args.model}/"
    #     denoised_feat_dir = f"{args.save_root}/denoised_features/{args.model}/"
    #     # img_extention = os.path.splitext(img_pth)[1]
    #     # raw_feat_save_path = img_pth.replace(args.data_root, raw_feat_dir).replace(
    #     #     img_extention, ".npy"
    #     # )
    #     # denoised_feat_save_path = img_pth.replace(args.data_root, denoised_feat_dir).replace(
    #     #     img_extention, ".npy"
    #     # )
    #     raw_feat_save_path = os.path.join(args.data_root, "rwa_feat", identifier+".npy")
    #     denoised_feat_save_path = os.path.join(args.data_root, "denoised_feat", identifier+".npy")
    #     os.makedirs(os.path.dirname(raw_feat_save_path), exist_ok=True)
    #     os.makedirs(os.path.dirname(denoised_feat_save_path), exist_ok=True)
    #     np.save(raw_feat_save_path, all_raw_features[-1].float().detach().cpu().numpy())
    #     np.save(denoised_feat_save_path, pred_full_denoised_features)
    #     print(
    #         f"Saved denoised features to {denoised_feat_save_path} and raw features to {raw_feat_save_path}"
    #     )
    # save_FGH_full_snapshot(neural_field=neural_field,
    #                        denoiser=denoiser,
    #                        img_name=identifier,
    #                        output_dir="FGH1210",
    #                        args=args)
    
    # del denoiser, neural_field, optimizer
    # torch.cuda.empty_cache()

    return {"denoiser": denoiser, 
            "neural_field": neural_field
           }

# def save_FGH_full_snapshot(neural_field, denoiser, img_name, output_dir, args):
#     base_name = os.path.splitext(img_name)[0]
#     # model_dir = os.path.join(args.output_dir, "models")
#     os.makedirs(output_dir, exist_ok=True)

#     with torch.no_grad():
#         torch.save(
#             {
#                 "F_neural_field": neural_field.state_dict(),
#                 "G_shared_artifacts": denoiser.shared_artifacts.detach().cpu(),
#                 "H_residual_predictor": (
#                     denoiser.residual_predictor.state_dict()
#                     if hasattr(denoiser, "residual_predictor")
#                     and denoiser.residual_predictor is not None
#                     else None
#                 ),
#                 "args": vars(args),
#             },
#             os.path.join(output_dir, f"{base_name}.pt")
#         )


In [10]:
tik = dt.datetime.now()
autocast_ctx = torch.autocast("cuda", dtype=args.dtype, enabled=args.dtype != torch.float32)

np.random.seed(42)
voc2012_path = "" # e.g "YOUR_PATH/VOCdevkit/VOC2012/JPEGImages"
img_filenames = os.listdir(voc2012_path)
img_filenames = sorted([fnm for fnm in img_filenames if fnm[-4:]==".jpg"])
# selected_filenames = np.random.choice(
#     img_filenames,
#     size=5_000,
#     replace=False
# )
# selected_filenames = np.random.permutation(img_filenames)[::-1][:2000]
selected_filenames = np.random.permutation(img_filenames)[:1]

img_paths = [os.path.join(voc2012_path, fnm) for fnm in selected_filenames]

for img_path in img_paths:
    # img_path = img_paths[1]
    dataset.set_image(img_path)
    
    
    
    collect_loader = torch.utils.data.DataLoader(dataset, args.extract_bsz, num_workers=8)
    torch.cuda.empty_cache()
    
    pbar = tqdm(collect_loader, total=len(collect_loader), desc="Collecting features")
    for i, data in enumerate(pbar):
        with torch.no_grad(), autocast_ctx:
            batch_vit_features = vit.get_intermediate_layers(
                data["transformed_view"].to(device),
                n=[layer_index],
                reshape=True,
            )[-1]
            # (B, C, H, W) -> (B, H, W, C)
            batch_vit_features = batch_vit_features.permute(0, 2, 3, 1)
            batch_pixel_coords = data["pixel_coords"].to(device)
            batch_views = data["transformed_view"].to(device)
            slicer = slice(i * args.extract_bsz, i * args.extract_bsz + batch_views.shape[0])
            global_pixel_coords[slicer] = batch_pixel_coords
            views[slicer] = batch_views
            vit_features[slicer] = batch_vit_features
    with torch.no_grad(), autocast_ctx:
        original_vit_features = vit.get_intermediate_layers(
            data["full_image"][:1].to(device), n=[layer_index], reshape=True
        )[-1]
        # (B, C, H, W) -> (B, H, W, C)
        original_vit_features = original_vit_features.permute(0, 2, 3, 1)
    global_pixel_coords[-1] = make_patch_coordinates(pos_h, pos_w, start=0, end=1)
    views[-1] = data["full_image"][0].to(device)
    vit_features[-1] = original_vit_features[0]
    
    
    denoised_results = denoise_an_image(
        args,
        all_raw_features=vit_features,
        all_pixel_coords=global_pixel_coords,
        all_transformed_views=views,
        device=device,
        # img_pth=filename,
        identifier = img_path.split("/")[-1].split(".")[0],
        denormalizer=denormalizer,
        should_save_vis=True#idx % args.vis_freq == 0,
    )
    tok = dt.datetime.now()
    print(f"Training time elapse = {tok-tik}")

Collecting features: 100%|██████████████████████| 24/24 [00:15<00:00,  1.54it/s]


Step 999/25000: Loss = 0.8172, Patch Loss = 0.6908, CosSim Loss = 0.1264, Residual Loss = 0.0000, Residual Sparsity Loss = 0.0000, LR = 0.0100
Step 1999/25000: Loss = 0.7469, Patch Loss = 0.6329, CosSim Loss = 0.1141, Residual Loss = 0.0000, Residual Sparsity Loss = 0.0000, LR = 0.0099
Step 2999/25000: Loss = 0.7094, Patch Loss = 0.6009, CosSim Loss = 0.1086, Residual Loss = 0.0000, Residual Sparsity Loss = 0.0000, LR = 0.0097
Step 3999/25000: Loss = 0.7351, Patch Loss = 0.6223, CosSim Loss = 0.1128, Residual Loss = 0.0000, Residual Sparsity Loss = 0.0000, LR = 0.0095
Step 4999/25000: Loss = 0.7084, Patch Loss = 0.6010, CosSim Loss = 0.1074, Residual Loss = 0.0000, Residual Sparsity Loss = 0.0000, LR = 0.0092
Step 5999/25000: Loss = 0.6716, Patch Loss = 0.5696, CosSim Loss = 0.1020, Residual Loss = 0.0000, Residual Sparsity Loss = 0.0000, LR = 0.0088
Step 6999/25000: Loss = 0.6755, Patch Loss = 0.5730, CosSim Loss = 0.1025, Residual Loss = 0.0000, Residual Sparsity Loss = 0.0000, LR = 

  warn("standard k-means should use a non-inverted distance measure.")
  result[selector] = overlay


Saved visualization to ./work_dirs/demo/visualization/2012_000521
Training time elapse = 0:02:19.633496
