In [1]:
import torch,imageio,sys,time,ffmpeg,cv2,cmapy,os
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from omegaconf import OmegaConf

sys.path.append('..')
# from models.sparseCoding import sparseCoding 
from models.FactorFields import FactorFields 

from utils import SimpleSampler
from dataLoader import dataset_dict
from torch.utils.data import DataLoader

device = 'cuda'
torch.cuda.set_device(0)

%load_ext autoreload
%autoreload 2

In [2]:
def PSNR(a,b):
    if type(a).__module__ == np.__name__:
        mse = np.mean((a-b)**2)
    else:
        mse = torch.mean((a-b)**2).item()
    psnr = -10.0 * np.log(mse) / np.log(10.0)
    return psnr

@torch.no_grad()
def eval_img(reso, chunk=10240,target_region=[0.0,0.0,1.0,1.0]):
    y = torch.linspace(target_region[0],target_region[2],reso[0])*(H-1)
    x = torch.linspace(target_region[1],target_region[3],reso[1])*(W-1)
    # y = torch.arange(0, reso[0])
    # x = torch.arange(0, reso[1])
    yy, xx = torch.meshgrid((y, x), indexing='ij')
    res = []
    
    coordiantes = torch.stack((xx,yy),dim=-1).reshape(-1,2) + 0.5 #/(torch.FloatTensor(reso[::-1])-1)*2-1
    # if normalize:
    #     coordiantes = coordiantes/torch.tensor([W,H])*2-1
    coordiantes = torch.split(coordiantes,chunk,dim=0)
    for coordiante in coordiantes:

        feats,_ = model.get_coding(coordiante.to(model.device))
        y_recon = model.linear_mat(feats)
        
        res.append(y_recon.cpu())
    return torch.cat(res).reshape(reso[0],reso[1],-1)

def srgb_to_linear(img):
	limit = 0.04045
	return np.where(img > limit, np.power((img + 0.055) / 1.055, 2.4), img / 12.92)

def zoom_in_animation(target_region=[0.0,0.0,0.5,0.5], n_frames=150):
    shiftment_y = 0.5 - (target_region[0]+target_region[2])/2
    shiftment_x = 0.5 - (target_region[1]+target_region[3])/2
    scale_y = 1.0 / (target_region[2] - target_region[0])
    scale_x = 1.0 / (target_region[3] - target_region[1])
    return

In [3]:
base_conf = OmegaConf.load('../configs/defaults.yaml')
second_conf = OmegaConf.load('../configs/image_intro.yaml')
cfg = OmegaConf.merge(base_conf, second_conf)

# Please pick one of the follow models

In [4]:
data_mode = 'rgb' # or 'binary'

if data_mode == 'binary':
    out_dim = 1
    cfg.dataset.datadir = '../data/image/cat_occupancy.png'
else:
    out_dim = 3
    cfg.dataset.datadir = '../data/image/cat_rgb.png'

## i. Implicit Network

In [7]:
model_name = 'occ'
cfg.model.coeff_type = 'none'
cfg.model.basis_type = 'x'
cfg.model.basis_mapping = 'x'
cfg.model.num_layers = 8
cfg.model.hidden_dim = 256
cfg.model.freq_bands=[1.]
cfg.model.basis_dims=[1]
# cfg.model.basis_resos=[2160]

## ii. NeRF

In [11]:
model_name = 'nerf'
cfg.model.coeff_type = 'none'
cfg.model.basis_type = 'x'
cfg.model.basis_mapping = 'trigonometric'
cfg.model.num_layers = 8
cfg.model.hidden_dim = 256
cfg.model.freq_bands=[1.,2.,4.,8.,16.,32.,64,128,256.,512.]
cfg.model.basis_dims=[1,1,1,1,1,1,1,1,1,1]
cfg.model.basis_resos=[1024,512,256,128,64,32,16,8,4,2]

## iii. Dense Grid

In [5]:
model_name = 'dense-grid'
cfg.model.coeff_type = 'grid'
cfg.model.basis_type = 'none'
cfg.model.coeff_reso = 128
cfg.model.num_layers = 2
cfg.model.hidden_dim = 64
cfg.model.basis_dims = [12]
cfg.model.basis_resos=[1]
cfg.model.T_coeff = 2048000

# learning rate
cfg.training.lr_small: 0.002
cfg.training.lr_large: 0.002

## v. Tensor Decomposition

In [22]:
model_name = 'cp'
cfg.model.coeff_type = 'vec'
cfg.model.basis_type = 'cp'
cfg.model.num_layers = 2
cfg.model.hidden_dim = 64
cfg.model.basis_dims = [320]
cfg.model.freq_bands =  [1.]
cfg.model.basis_resos =  [1024]
cfg.model.T_basis = 2048000

# learning rate
cfg.training.lr_small: 0.0002
cfg.training.lr_large: 0.002

# cfg.model.coef_init: 0.001


## iv. Hash Grid

In [8]:
model_name = 'hash-grid'
cfg.model.coeff_type = 'none'
cfg.model.basis_type = 'hash'
cfg.model.coeff_reso = 0
cfg.model.num_layers = 2
cfg.model.hidden_dim = 64
cfg.model.basis_dims = [2,2,2,2,2,2]
cfg.model.freq_bands =  [1.,2.,4.,8.,16.,32.]
cfg.model.T_basis = 2048000

# learning rate
cfg.training.lr_small: 0.0002
cfg.training.lr_large: 0.002

## vi. Dictionary Factorization

In [14]:
dataset = dataset_dict[cfg.dataset.dataset_name]
train_dataset = dataset(cfg.dataset, cfg.training.batch_size, split='train', tolinear=False, perscent=0.5)
train_dataset.image = train_dataset.image[...,:out_dim]
train_loader = DataLoader(train_dataset,
              num_workers=8,
              persistent_workers=True,
              batch_size=None,
              pin_memory=True)

cfg.model.out_dim = out_dim
batch_size = cfg.training.batch_size
n_iter = cfg.training.n_iters

H,W = train_dataset.HW
cfg.dataset.aabb = train_dataset.scene_bbox

In [15]:
model = FactorFields(cfg, device)
print(model)
print('total parameters: ',model.n_parameters())

grad_vars = model.get_optparam_groups(lr_small=cfg.training.lr_small,lr_large=cfg.training.lr_large)
optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))#

H,W = train_dataset.HW

imgs = []

        
psnrs,times = [],[0.0]
loss_scale = 1.0
lr_factor = 0.1 ** (1 / n_iter)
pbar = tqdm(range(n_iter))
start = time.time()
for (iteration, sample) in zip(pbar,train_loader):
    iteration_start = time.time()
    loss_scale *= lr_factor

    coordiantes, pixel_rgb = sample['xy'], sample['rgb']
    feats,coeff = model.get_coding(coordiantes.to(device))
    
    y_recon = model.linear_mat(feats)
    
    loss = torch.mean((y_recon.squeeze()-pixel_rgb.squeeze().to(device))**2) 
    
    
    psnr = -10.0 * np.log(loss.item()) / np.log(10.0)
    psnrs.append(psnr)
    times.append(time.time()-start)
    
    if iteration%10==0:
        pbar.set_description(
                    f'Iteration {iteration:05d}:'
                    + f' loss_dist = {loss.item():.8f}'
                    + f' psnr = {psnr:.3f}'
                )
    
    loss = loss * loss_scale
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    iteration_end = time.time()
    times.append(times[-1] + iteration_end-iteration_start)
    
    if iteration%(n_iter//150) == 0 or iteration==n_iter-1:
        imgs.append((eval_img(train_dataset.HW).clamp(0,1.)*255).to(torch.uint8))
        
iteration_time = time.time()-start  
    

# img = eval_img(train_dataset.HW).clamp(0,1.)
# print(PSNR(img,train_dataset.image.view(img.shape)),iteration_time)
# plt.figure(figsize=(10, 10))
# plt.imshow(img)


imageio.mimwrite('/mnt/qb/home/geiger/zyu30/Projects/Anpei/FactorFields/video/image/temp.mp4', imgs, fps=30, quality=10)
# np.savetxt(f'/mnt/qb/home/geiger/zyu30/Projects/Anpei/FactorFields/video/image/cat_{model_name}_psnr_{data_mode}.txt',psnrs)
# np.savetxt(f'/mnt/qb/home/geiger/zyu30/Projects/Anpei/FactorFields/video/image/cat_{model_name}_time_{data_mode}.txt',times)

stream = ffmpeg.input('/mnt/qb/home/geiger/zyu30/Projects/Anpei/FactorFields/video/image/temp.mp4')
stream = ffmpeg.output(stream, f'/mnt/qb/home/geiger/zyu30/Projects/Anpei/FactorFields/video/image/cat_sparse_{model_name}_{data_mode}.mp4')
ffmpeg.run(stream,overwrite_output=True)

48 2048000 2048000
=====> total parameters:  2047852
FactorFields(
  (coeffs): ParameterList(  (0): Parameter containing: [torch.float32 of size 1x12x413x413 (GPU 0)])
  (linear_mat): MLPMixer(
    (backbone): ModuleList(
      (0): Linear(in_features=12, out_features=64, bias=True)
      (1): Linear(in_features=64, out_features=3, bias=False)
    )
  )
)
total parameters:  2047852


Iteration 09990: loss_dist = 0.00055990 psnr = 32.519: 100%|█| 10000/1
ffmpeg version 4.3 Copyright (c) 2000-2020 the FFmpeg developers
  built with gcc 7.3.0 (crosstool-NG 1.23.0.449-a04d0)
  configuration: --prefix=/mnt/lustre/geiger/zyu30/.conda/envs/sdfstudio --cc=/opt/conda/conda-bld/ffmpeg_1597178665428/_build_env/bin/x86_64-conda_cos6-linux-gnu-cc --disable-doc --disable-openssl --enable-avresample --enable-gnutls --enable-hardcoded-tables --enable-libfreetype --enable-libopenh264 --enable-pic --enable-pthreads --enable-shared --disable-static --enable-version3 --enable-zlib --enable-libmp3lame
  libavutil      56. 51.100 / 56. 51.100
  libavcodec     58. 91.100 / 58. 91.100
  libavformat    58. 45.100 / 58. 45.100
  libavdevice    58. 10.100 / 58. 10.100
  libavfilter     7. 85.100 /  7. 85.100
  libavresample   4.  0.  0 /  4.  0.  0
  libswscale      5.  7.100 /  5.  7.100
  libswresample   3.  7.100 /  3.  7.100
Input #0, mov,mp4,m4a,3gp,3g2,mj2, from '/mnt/qb/home/geiger/zy

(None, None)

# vis

In [7]:
model = FactorFields(cfg, device)
print(model)
print('total parameters: ',model.n_parameters())

grad_vars = model.get_optparam_groups(lr_small=cfg.training.lr_small,lr_large=cfg.training.lr_large)
optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))#

H,W = train_dataset.HW

imgs = []
if model_name == 'dense-grid':
    for c in range(model.coeffs[0].shape[1]):
        feat = model.coeffs[0][0,c].cpu().detach().numpy()
        feat = (feat-np.min(feat))/(np.max(feat) - np.min(feat))
        img = cmapy.colorize((feat * 255).astype('uint8'), 'coolwarm')
        imgs.append(img)

psnrs,times = [],[0.0]
loss_scale = 1.0
lr_factor = 0.1 ** (1 / n_iter)
pbar = tqdm(range(n_iter))
start = time.time()
for (iteration, sample) in zip(pbar,train_loader):
    iteration_start = time.time()
    loss_scale *= lr_factor

    coordiantes, pixel_rgb = sample['xy'], sample['rgb']
    feats,coeff = model.get_coding(coordiantes.to(device))
    
    y_recon = model.linear_mat(feats)
    
    loss = torch.mean((y_recon.squeeze()-pixel_rgb.squeeze().to(device))**2) 
    
    
    psnr = -10.0 * np.log(loss.item()) / np.log(10.0)
    psnrs.append(psnr)
    times.append(time.time()-start)
    
    if iteration%10==0:
        pbar.set_description(
                    f'Iteration {iteration:05d}:'
                    + f' loss_dist = {loss.item():.8f}'
                    + f' psnr = {psnr:.3f}'
                )
    
    loss = loss * loss_scale
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    iteration_end = time.time()
    times.append(times[-1] + iteration_end-iteration_start)
    
    if iteration%(n_iter//150) == 0 or iteration==n_iter-1:
        if model_name == 'dense-grid':
            for c in range(model.coeffs[0].shape[1]):
                feat = model.coeffs[0][0,c].cpu().detach().numpy()
                feat = (feat-np.min(feat))/(np.max(feat) - np.min(feat))
                img = cmapy.colorize((feat * 255).astype('uint8'), 'coolwarm')
                imgs.append(img)
        elif model_name == 'cp':
             for c in range(model.coeffs[0].shape[1]):
                for item in [model.coeffs[0],model.basises[0]]:
                    feat = item[0,c].cpu().detach().numpy()
                    feat = (feat-np.min(feat))/(np.max(feat) - np.min(feat))
                    feat = cv2.resize(feat,(64,feat.shape[0]))
                    img = cmapy.colorize((feat * 255).astype('uint8'), 'coolwarm')
                    imgs.append(img)
        
        
iteration_time = time.time()-start  
    
if model_name == 'dense-grid':
    imgs = np.stack(imgs).reshape(-1,model.coeffs[0].shape[1],*imgs[0].shape)
    for i in range(model.coeffs[0].shape[1]):
        imageio.mimwrite('/mnt/qb/home/geiger/zyu30/Projects/Anpei/FactorFields/video/image/temp.mp4', imgs[:,i], fps=30, quality=10)

        os.makedirs(f'/mnt/qb/home/geiger/zyu30/Projects/Anpei/FactorFields/video/image/vis_sparse_cat_{model_name}_{data_mode}',exist_ok=True)
        stream = ffmpeg.input('/mnt/qb/home/geiger/zyu30/Projects/Anpei/FactorFields/video/image/temp.mp4')
        stream = ffmpeg.output(stream, f'/mnt/qb/home/geiger/zyu30/Projects/Anpei/FactorFields/video/image/vis_sparse_cat_{model_name}_{data_mode}/{i:02d}.mp4')
        ffmpeg.run(stream,overwrite_output=True)
elif model_name == 'cp':
    vis_coef, vis_basis = imgs[0::2], imgs[1::2]
    vis_coef = np.stack(vis_coef).reshape(-1,model.coeffs[0].shape[1],*vis_coef[0].shape)
    vis_basis = np.stack(vis_basis).reshape(-1,model.coeffs[0].shape[1],*vis_basis[0].shape)
    for (name,item) in zip(['coef','basis'],[vis_coef,vis_basis]):
        for i in range(item.shape[1]):
            imageio.mimwrite('/mnt/qb/home/geiger/zyu30/Projects/Anpei/FactorFields/video/image/temp.mp4', item[:,i], fps=30, quality=10)

            os.makedirs(f'/mnt/qb/home/geiger/zyu30/Projects/Anpei/FactorFields/video/image/vis_cat_{model_name}_{data_mode}',exist_ok=True)
            stream = ffmpeg.input('/mnt/qb/home/geiger/zyu30/Projects/Anpei/FactorFields/video/image/temp.mp4')
            stream = ffmpeg.output(stream, f'/mnt/qb/home/geiger/zyu30/Projects/Anpei/FactorFields/video/image/vis_cat_{model_name}_{data_mode}/{name}-{i:02d}.mp4')
            ffmpeg.run(stream,overwrite_output=True)
        

48 2048000 2048000
=====> total parameters:  2047852
FactorFields(
  (coeffs): ParameterList(  (0): Parameter containing: [torch.float32 of size 1x12x413x413 (GPU 0)])
  (linear_mat): MLPMixer(
    (backbone): ModuleList(
      (0): Linear(in_features=12, out_features=64, bias=True)
      (1): Linear(in_features=64, out_features=3, bias=False)
    )
  )
)
total parameters:  2047852


  feat = (feat-np.min(feat))/(np.max(feat) - np.min(feat))
Iteration 09990: loss_dist = 0.00056620 psnr = 32.470: 100%|█| 10000/1


In [9]:
import os

In [10]:
        os.makedirs(f'/mnt/qb/home/geiger/zyu30/Projects/Anpei/FactorFields/video/image/vis_sparse_cat_{model_name}_{data_mode}',exist_ok=True)
        stream = ffmpeg.input('/mnt/qb/home/geiger/zyu30/Projects/Anpei/FactorFields/video/image/temp.mp4')
        stream = ffmpeg.output(stream, f'/mnt/qb/home/geiger/zyu30/Projects/Anpei/FactorFields/video/image/vis_sparse_cat_{model_name}_{data_mode}/{i:02d}.mp4')
        ffmpeg.run(stream,overwrite_output=True)

ffmpeg version 4.3 Copyright (c) 2000-2020 the FFmpeg developers
  built with gcc 7.3.0 (crosstool-NG 1.23.0.449-a04d0)
  configuration: --prefix=/mnt/lustre/geiger/zyu30/.conda/envs/sdfstudio --cc=/opt/conda/conda-bld/ffmpeg_1597178665428/_build_env/bin/x86_64-conda_cos6-linux-gnu-cc --disable-doc --disable-openssl --enable-avresample --enable-gnutls --enable-hardcoded-tables --enable-libfreetype --enable-libopenh264 --enable-pic --enable-pthreads --enable-shared --disable-static --enable-version3 --enable-zlib --enable-libmp3lame
  libavutil      56. 51.100 / 56. 51.100
  libavcodec     58. 91.100 / 58. 91.100
  libavformat    58. 45.100 / 58. 45.100
  libavdevice    58. 10.100 / 58. 10.100
  libavfilter     7. 85.100 /  7. 85.100
  libavresample   4.  0.  0 /  4.  0.  0
  libswscale      5.  7.100 /  5.  7.100
  libswresample   3.  7.100 /  3.  7.100
Input #0, mov,mp4,m4a,3gp,3g2,mj2, from '/mnt/qb/home/geiger/zyu30/Projects/Anpei/FactorFields/video/image/temp.mp4':
  Metadata:
    

(None, None)