In [2]:
from typing import Optional
from multiprocessing import Pool

import numpy as np
import cv2
import torch
import matplotlib.pyplot as plt
import math
from torch._C import dtype
from typing import Dict
from models.siren_model import Siren
from models.very_tiny_nerf_model import VeryTinyNerfModel
import deepCABAC
from torchinfo import summary



In [3]:
# Width, in bits, that one element of each torch dtype is counted as.
# Alias names (torch.float, torch.double, torch.half, torch.cdouble,
# torch.short, torch.int, torch.long) are the very same objects as the
# canonical dtypes below, so looking them up hits the same entries.
DTYPE_BIT_SIZE: Dict[dtype, int] = dict([
    (torch.float32, 32),
    (torch.float64, 64),
    (torch.float16, 16),
    (torch.bfloat16, 16),
    (torch.complex32, 32),
    (torch.complex64, 64),
    (torch.complex128, 128),
    (torch.uint8, 8),
    (torch.int8, 8),
    (torch.int16, 16),
    (torch.int32, 32),
    (torch.int64, 64),
    (torch.bool, 1),
])


def model_size_in_bits(model):
    """Return how many bits the parameters and buffers of `model` add up to.

    Each tensor contributes `nelement()` times the per-element width listed in
    `DTYPE_BIT_SIZE` for its dtype.
    """
    total_bits = 0
    for tensor_group in (model.parameters(), model.buffers()):
        for tensor in tensor_group:
            total_bits += tensor.nelement() * DTYPE_BIT_SIZE[tensor.dtype]
    return total_bits

In [4]:
def clamp_image(img):
    """Clip an image to [0, 1] and snap every value onto the 8-bit grid.

    The result is still a float tensor; each value is a multiple of 1/255.

    Args:
        img (torch.Tensor): image whose values may fall outside [0, 1].
    """
    # Clip first, then quantize to the 256 levels {0, ..., 255}.
    levels = torch.round(torch.clamp(img, 0., 1.) * 255)
    # Scale the integer levels back to the [0, 1] range.
    return levels / 255.
    
def psnr(img1, img2):
    """Peak signal-to-noise ratio between two images, with a peak value of 1.

    Args:
        img1 (torch.Tensor): first image.
        img2 (torch.Tensor): second image, broadcastable against `img1`.

    Returns:
        A NumPy floating scalar (the peak term comes from `np.log10`).
    """
    # Mean squared error, detached so no autograd graph is retained.
    mse = (img1 - img2).detach().pow(2).mean()
    # 20 * log10(peak) with peak == 1; kept as a NumPy scalar on purpose.
    peak_term = 20. * np.log10(1.)
    return peak_term - 10. * mse.log10().to('cpu').item()

In [5]:
def get_minibatches(inputs: torch.Tensor, chunksize: Optional[int] = 1024 * 8):
  r"""Cut `inputs` into consecutive slices along dimension 0.

  Every returned slice has `chunksize` rows except possibly the last one,
  which holds whatever remains. An input with zero rows yields an empty list.
  """
  chunks = []
  for start in range(0, inputs.shape[0], chunksize):
    chunks.append(inputs[start:start + chunksize])
  return chunks

In [6]:
# Pick the compute device: the GPU when CUDA is usable, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', device)
print('CUDA version:', torch.version.cuda)
print()

# Extra diagnostics that only exist for a CUDA device.
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    # 2**30 bytes per GB, same divisor as 1024**3.
    print('Allocated:', round(torch.cuda.memory_allocated(0) / 2**30, 1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0) / 2**30, 1), 'GB')

# Tensors created without an explicit device/dtype become float32 on `device`.
torch.set_default_tensor_type(
    'torch.cuda.FloatTensor' if device.type == 'cuda' else 'torch.FloatTensor'
)

Using device: cuda
CUDA version: 11.2

Tesla T4
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB


In [7]:
# Open the source clip and read its basic properties.
# cap = cv2.VideoCapture("../datasets/00003.mp4")
cap = cv2.VideoCapture("../datasets/fireworks_128.mp4")
if not cap.isOpened():
    # VideoCapture does not raise when the file is missing or unreadable; without
    # this check the notebook would continue with width/height/frame count of 0.
    raise IOError("Could not open video file ../datasets/fireworks_128.mp4")
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = cap.get(cv2.CAP_PROP_FPS)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
# total_frames = 10


def complex_vid(width=100, height=100, frames=100):
  """Peak complexity video. Random pixel values for all coords at each frame.

  Returns a tensor of shape (frames, height, width, 3) with values drawn
  uniformly from [0, 1). The (height, width) axis order matches the frames
  produced by `get_frame` and the `[height, width, 3]` reshape used when
  rendering, so non-square sizes line up with the rest of the notebook.
  """
  return torch.rand(frames, height, width, 3)

def get_frame(idx):
  r"""Decode one frame of the opened video as a float tensor in [0, 1].

  The tensor has the layout OpenCV decodes to: (height, width, 3) with the
  channels in BGR order (not RGB).

  Args:
    idx: zero-based index of the frame to seek to.

  Raises:
    RuntimeError: if the frame cannot be read (for example when the container
      reports more frames than can actually be decoded).
  """
  cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
  success, img = cap.read()
  if not success or img is None:
    # Fail loudly here: continuing with img=None would only surface later as
    # an unrelated TypeError inside torch.from_numpy.
    raise RuntimeError("Failed to load frame at index " + str(idx))
  # uint8 {0..255} -> float32 [0, 1]
  return torch.from_numpy(np.float32(img) / 255)

# Decode every frame up front and keep the whole clip resident on `device`.
frames = torch.stack(list(map(get_frame, range(total_frames)))).to(device)
# Synthetic alternative (random noise clip) for stress-testing:
# frames = complex_vid(frames=100)
# total_frames = 100
# width = 100
# height = 100
# fps = 24

In [8]:
def xs_and_ys(width, height, frame_ind, num_frames=None):
    r"""Build the normalized (row, column, frame) coordinate of every pixel in one frame.

    Args:
        width: frame width in pixels.
        height: frame height in pixels.
        frame_ind: zero-based index of the frame.
        num_frames: total number of frames used to normalize `frame_ind`.
            Defaults to the notebook-level `total_frames`.

    Returns:
        A float tensor of shape (height * width, 3). Column 0 is the row index,
        column 1 the column index, column 2 the frame index; rows are ordered
        row-major. For a square frame all three columns lie in [-1, 1].
    """
    if num_frames is None:
        num_frames = total_frames

    # nonzero() on an all-ones (height, width) grid enumerates every
    # (row, column) index pair in row-major order.
    coordinates = torch.ones([height, width]).nonzero(as_tuple=False).float()

    # Both spatial axes share the height-based scale, i.e. the frame is assumed
    # to be square; a wider-than-tall frame yields column values beyond 1.
    # max(..., 1) keeps a one-pixel-high frame from dividing by zero.
    spatial_extent = max(height - 1, 1)
    coordinates = coordinates / spatial_extent - 0.5

    # max(..., 1) lets a single-frame clip work instead of raising
    # ZeroDivisionError.
    temporal_extent = max(num_frames - 1, 1)
    fill_val = frame_ind / temporal_extent - 0.5
    frame_indicies = torch.full((coordinates.shape[0], 1), fill_val)
    coordinates = torch.cat([coordinates, frame_indicies], -1)

    # [-0.5, 0.5] -> [-1, 1]
    coordinates *= 2
    return coordinates



In [9]:
def positional_encoding(
    tensor, num_encoding_functions = [6,6,6], std_dev = 1.4, include_input=True, log_sampling=True
) -> torch.Tensor:
    r"""Sinusoidal positional encoding with a separate frequency count per input column.

    Args:
    tensor (torch.Tensor): 2-D input; column `j` is encoded with
        `num_encoding_functions[j]` frequencies.
    num_encoding_functions (optional, list of int): Number of frequencies for
        each input column (default: [6, 6, 6]).
    std_dev (optional, float): Base of the frequency progression; plays the
        role of the constant 2 in the NeRF encoding (default: 1.4).
    include_input (optional, bool): Whether the raw input is kept as the
        leading columns of the result (default: True).
    log_sampling (optional, bool): If True the frequencies are
        `std_dev ** 0, ..., std_dev ** (n - 1)`; otherwise they are spaced
        linearly between those two end points (default: True).

    Returns:
    (torch.Tensor): The pieces concatenated along the last dimension. For each
        frequency level the sine of every column that has that level comes
        first, followed by the cosines. When exactly one piece exists it is
        returned as-is, without concatenation.
    """

    def _bands(count):
        # Frequencies assigned to one input column.
        if log_sampling:
            exponents = torch.linspace(
                0.0, count - 1, count, dtype=tensor.dtype, device=tensor.device
            )
            return std_dev ** exponents
        return torch.linspace(
            std_dev ** 0.0,
            std_dev ** (count - 1),
            count,
            dtype=tensor.dtype,
            device=tensor.device,
        )

    bands_per_column = [_bands(count) for count in num_encoding_functions]

    pieces = [tensor] if include_input else []

    # Walk the frequency levels; columns with fewer levels simply drop out.
    deepest = np.max(num_encoding_functions)
    for level in range(0, deepest):
        for trig in (torch.sin, torch.cos):
            for column, bands in enumerate(bands_per_column):
                if level >= len(bands):
                    continue
                pieces.append(trig(tensor[:, column:column + 1] * bands[level]))

    # A single piece (e.g. only the raw input) needs no concatenation.
    if len(pieces) == 1:
        return pieces[0]
    return torch.cat(pieces, dim=-1)

In [10]:
def gaussian_encoding(tensor, num_encoding_functions=[6, 6, 6], std_dev=1.4, include_input=True):
    """Gaussian bump encoding with a separate number of centers per input column.

    Column `j` of `tensor` is compared against the centers
    `0, 1, ..., num_encoding_functions[j] - 1`; each center contributes
    `exp(-(center - x)^2 / (2 * std_dev^2))`.

    Args:
        tensor (torch.Tensor): 2-D input.
        num_encoding_functions (list of int): number of centers per column.
        std_dev (float): width of every Gaussian.
        include_input (bool): keep the raw input as the leading columns.

    Returns:
        torch.Tensor: the pieces concatenated along the last dimension, ordered
        by center index first and column second. A single piece is returned
        without concatenation.
    """
    centers_per_column = [
        torch.linspace(0.0, count - 1, count, dtype=tensor.dtype, device=tensor.device)
        for count in num_encoding_functions
    ]

    pieces = [tensor] if include_input else []

    deepest = np.max(num_encoding_functions)
    for level in range(0, deepest):
        for column, centers in enumerate(centers_per_column):
            if level >= len(centers):
                continue
            offset = centers[level] - tensor[:, column:column + 1]
            pieces.append(torch.exp(-torch.pow(offset, 2) / (2 * std_dev**2)))

    # A single piece (e.g. only the raw input) needs no concatenation.
    if len(pieces) == 1:
        return pieces[0]
    return torch.cat(pieces, dim=-1)


In [11]:
# Cache of encoded pixel coordinates, keyed by frame index only (width, height
# and the encoding function are not part of the key).
encoded = {}
def one_iter_npc(width, height, model, frame_ind, encoding_fn, get_minibatches_fn):
  """Run `model` on every pixel coordinate of one frame.

  The encoded coordinates of a frame are computed on first use and then
  served from the module-level `encoded` cache. `get_minibatches_fn` is
  accepted but not used.

  Returns:
    The model output reshaped to [height, width, 3].
  """
  if frame_ind not in encoded:
    encoded[frame_ind] = encoding_fn(xs_and_ys(width, height, frame_ind))

  flat_rgb = model(encoded[frame_ind])
  return torch.reshape(flat_rgb, [height, width, 3])

In [12]:
def output_video(fps, width: int, height: int, encode, get_minibatches):
  r"""Open 'output_video.mp4' for writing and return a per-chunk frame renderer.

  Args:
    fps: frame rate of the written video.
    width: frame width in pixels.
    height: frame height in pixels.
    encode: coordinate encoding function handed to `one_iter_npc`.
    get_minibatches: forwarded to `one_iter_npc`.

  Returns:
    `output_v(start_frame, end_frame, model)`, which renders frames
    `start_frame .. end_frame - 1` with `model`, appends them to the video and
    returns a zero-argument callable that releases the writer.
  """
  out = cv2.VideoWriter('output_video.mp4', cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))

  def output_v(start_frame, end_frame, model):
    # Pure rendering: no autograd graph is needed.
    with torch.no_grad():
      for f in range(start_frame, end_frame):
        rgb_predicted = clamp_image(one_iter_npc(width, height, model,
                                     f, encode,
                                     get_minibatches))

        # clamp_image already yields multiples of 1/255 in [0, 1], so scaling
        # by 255 recovers the exact 8-bit levels. Min-max normalization is
        # deliberately avoided: it would stretch the contrast of each frame
        # independently and turn a constant frame black.
        rgb_out = np.rint(rgb_predicted.detach().cpu().numpy() * 255).astype(np.uint8)
        out.write(rgb_out)
    return lambda: out.release()

  return output_v

In [13]:
def get_new_chunk():
  """Create a fresh model and optimizer for one chunk and return its training closure.

  Reads the notebook-level settings `filter_size`, `num_hidden_layers`,
  `num_encoding_functions`, `lr` and `device`, prints the model size, and
  returns `training_loop`, which trains the model on a range of frames and
  returns it.
  """
  # model = VeryTinyNerfModel(num_encoding_functions=num_encoding_functions,
  #                           num_hidden_layers=num_hidden_layers, filter_size=filter_size)

  model = Siren(
        dim_in=3,
        dim_hidden=filter_size,
        dim_out=3,
        num_layers=num_hidden_layers,
        final_activation=torch.nn.Identity(),
        w0_initial=30.0,
        w0=30.0,
        num_encoding_functions = num_encoding_functions,
    )

  # Weight of the L1 penalty on all model parameters.
  l1_w = 1e-6


  # model_size_in_bits counts bits; / 8 gives bytes and 1e-6 scales to MB.
  print("Model size: {:2f} MB".format(1e-6 * model_size_in_bits(model) / 8))
  print(f"Model size in bits: {model_size_in_bits(model)}")
  model.to(device)

  optimizer = torch.optim.Adam(model.parameters(), lr=lr)

  # Lists to log metrics etc.
  # NOTE(review): `psnrs` is never written to or read in this block.
  psnrs = []
  iternums = []

  # Most recent training PSNR per frame index.
  all_psnrs = {}


  def training_loop(start_frame, end_frame, num_iters = 500, mean_psnr_cutoff = 30):
    """Train the enclosing model on frames `start_frame .. end_frame - 1`.

    Runs `num_iters + 1` optimization steps, each on one randomly chosen frame
    of the range, and every `display_every` steps prints statistics and plots
    the prediction for the last frame of the range next to the original.

    Args:
      start_frame: first frame index of the chunk (inclusive).
      end_frame: end of the chunk (exclusive).
      num_iters: number of iterations; the loop runs `num_iters + 1` times.
      mean_psnr_cutoff: accepted but not used anywhere in this function.

    Returns:
      The trained model.
    """
    # These names are only read or mutated in place below (never rebound), so
    # the nonlocal declarations do not change behavior.
    nonlocal model 
    nonlocal optimizer
    nonlocal iternums
    nonlocal all_psnrs
    nonlocal l1_w

    chunk_frames = end_frame - start_frame
    # Frame used for the periodic preview. It lies inside the sampled range
    # [start_frame, end_frame), so it is also trained on.
    test_frame = end_frame - 1

    print("Start frame: " + str(start_frame) + ", end frame: " + str(end_frame))
    print("Total frames: " + str(chunk_frames))

    for i in range(num_iters + 1):
      # NOTE(review): the whole iteration, including backward() and the
      # optimizer step, runs inside autocast and no GradScaler is used;
      # confirm this is intended.
      with torch.cuda.amp.autocast():
        # Randomly pick a frame as the target (upper bound is exclusive)
        target_frame_idx = np.random.randint(start_frame, end_frame)

        target_img = frames[target_frame_idx] # get_frame(target_frame_idx)
        
        rgb_predicted = one_iter_npc(width, height, model,
                                    target_frame_idx, encode,
                                    get_minibatches)

        # L1 penalty over every parameter, flattened into one vector.
        l1_reg = l1_w * torch.abs(torch.cat([p.view(-1) for p in model.parameters()])).sum()

        # Mean-squared error between the predicted and target images plus the
        # L1 penalty. Backprop!
        loss = torch.nn.functional.mse_loss(rgb_predicted, target_img) + l1_reg
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # PSNR is measured on the clamped/quantized prediction, which was
        # computed before this iteration's optimizer step.
        clamped_rgp_predicted = clamp_image(rgb_predicted)

        img_psnr = psnr(target_img, clamped_rgp_predicted)
      
        # Keep only the latest PSNR for this frame.
        all_psnrs[target_frame_idx] = img_psnr.item()

        # Display images/plots/stats
        if (i > 0) and (i % display_every == 0):
          # Render the preview frame (last frame of the chunk)
          
          test_frame_img = frames[test_frame] # get_frame(test_frame)
          rgb_predicted = clamp_image(one_iter_npc(width, height, model,
                                      test_frame, encode,
                                      get_minibatches))

          held_out_frame_psnr = psnr(test_frame_img, rgb_predicted)
          # Mean/median over the latest PSNR of every frame sampled so far.
          mean_psnr = np.mean(list(all_psnrs.values()))
          median_psnr = np.median(list(all_psnrs.values()))
          print("Loss:", loss.item(),"Frame PSNR:", held_out_frame_psnr,"Mean PSNR:", mean_psnr, "Median PSNR:", median_psnr)

          iternums.append(i)
          # Left: current prediction, right: original frame.
          plt.figure(figsize=(10, 4))
          plt.subplot(121)
          plt.imshow(rgb_predicted.detach().cpu().numpy().astype(float))
          plt.title(f"Iteration {i}")
          plt.subplot(122)
          plt.imshow(test_frame_img.detach().cpu().numpy().astype(float))
          # plt.plot(iternums, psnrs)
          plt.title("Original")
          plt.show()
        
    return model

  return training_loop

In [14]:
def process_chunk(tup):
    """Train one fresh chunk model from a packed argument tuple.

    `tup` is indexed as (start_frame, end_frame, num_iters, mean_psnr_cutoff).
    Returns whatever the training closure returns.
    """
    # Build the model/optimizer first, then unpack the arguments for training.
    train = get_new_chunk()
    return train(
        start_frame=tup[0],
        end_frame=tup[1],
        num_iters=tup[2],
        mean_psnr_cutoff=tup[3],
    )


def chunk_processing(total_frames, outputter, frames_per_chunk, max_iters, psnr_cutoff, train_loop= None):
    """Train one model per chunk of frames, then render every chunk.

    All chunk models are trained first; only afterwards is each chunk rendered
    through `outputter`, in frame order.

    Args:
        total_frames: number of frames in the clip.
        outputter: callable `(start_frame, end_frame, model)` that renders the
            frames and returns a zero-argument "release" callable.
        frames_per_chunk: number of frames per model; the last chunk may be
            shorter.
        max_iters: iteration count handed to each training closure.
        psnr_cutoff: PSNR cutoff handed to each training closure.
        train_loop: accepted for compatibility; not used.

    Returns:
        `(release_out, models)`: the release callable returned for the last
        rendered chunk and the list of trained models. When there is nothing
        to process, `release_out` is a no-op and `models` is empty.

    Raises:
        ValueError: if `frames_per_chunk` is not positive.
    """
    if frames_per_chunk <= 0:
        raise ValueError("frames_per_chunk must be a positive integer")

    # [start, end) frame range of every chunk.
    chunk_bounds = [
        (start, min(total_frames, start + frames_per_chunk))
        for start in range(0, total_frames, frames_per_chunk)
    ]

    models = [
        get_new_chunk()(start, end, max_iters, psnr_cutoff)
        for start, end in chunk_bounds
    ]

    # Stays a no-op when there are no chunks, so the caller can always call it.
    release_out = lambda: None
    for (start, end), model in zip(chunk_bounds, models):
        release_out = outputter(start, end, model)

    return release_out, models



In [15]:
"""
Parameters for NPC training
"""

# --- Model / encoding hyper-parameters --------------------------------------
num_hidden_layers = 4
filter_size = 128
std_dev = 1.4
# Number of encoding frequencies for each input column (keep the model in
# sync when this changes).
num_encoding_functions = [64, 64, 32]


def encode(x):
    """Encoding applied to the raw coordinates before they reach the model."""
    return positional_encoding(
        x,
        num_encoding_functions=num_encoding_functions,
        std_dev=std_dev,
        include_input=True,
        log_sampling=True,
    )
# Alternative encoders:
# def encode(x): return gaussian_encoding(
#     x, num_encoding_functions=num_encoding_functions, std_dev=0.0003, include_input=True)
# def encode(x): return x


# --- Optimizer / schedule -----------------------------------------------------
lr = 2e-5

seed_iters = 500
baseline_iters = 250
num_iters = 25000

# Number of iterations between two stats/plot outputs.
display_every = 5000

# --- Train-Eval-Repeat! -------------------------------------------------------

# Seed both RNGs for repeatability.
seed = 1337
torch.manual_seed(seed)
np.random.seed(seed)

outputter = output_video(fps, width, height, encode, get_minibatches)

print(f"Video dims: {width}x{height}")
print(f"Framerate: {fps}")

initial_chunk_size = 2
chunk_increment = 2


# Horrible idea, absolute standards make the algorithm only compress two frames
# at a time for most of the sections:
# mean_psnr_one_fram, train_loop = get_baseline_psnr(total_frames, initial_chunk_size, seed_iters)
# print(mean_psnr_one_fram)

frames_per_chunk = 15 # 24
release_out, models = chunk_processing(
    total_frames=total_frames,
    outputter=outputter,
    frames_per_chunk=frames_per_chunk,
    max_iters=num_iters,
    psnr_cutoff=40.0,
)
# release_out = chunk_processing(total_frames, outputter, chunk_increment = chunk_increment , seed_iters = seed_iters, baseline_iters = baseline_iters, iters_per_frame = iters_per_frame, initial_chunk_size = initial_chunk_size, baseline_mean_psnr_percentage = 0.9)


print('Done training. Storing output...')
release_out()
print("Output complete")

Video dims: 128x128
Framerate: 25.0
Dim_in: 323
Model size: 0.365580 MB
Model size in bits: 2924640
Start frame: 0, end frame: 15
Total frames: 15


In [None]:

# deepCABAC rate-distortion quantization settings.
interv = 0.1
stepsize = 3**(-0.5*13)        # step size for entries whose name contains '.weight'
stepsize_other = 3**(-0.5*17)  # step size for every other entry
_lambda = 0.

# Compress the state of each trained chunk model into weights_<i>.bin.
for i in range(len(models)):
    encoder = deepCABAC.Encoder()

    for name, param in models[i].state_dict().items():
        # Batch-norm step counters are skipped.
        if '.num_batches_tracked' in name:
            continue
        param = param.cpu().numpy()
        encoder.encodeWeightsRD(
            param,
            interv,
            stepsize if '.weight' in name else stepsize_other,
            _lambda,
        )

    stream = encoder.finish().tobytes()
    print("Compressed size: {:2f} MB".format(1e-6 * len(stream)))
    with open(f"weights_{i}.bin", 'wb') as f:
        f.write(stream)

# Rebuild one model per chunk from its compressed weights_<i>.bin file.
decoded_models = []
for i in range(len(models)):

    # Fresh model with the same construction arguments as the trained ones.
    model = Siren(
        dim_in=3, dim_hidden=filter_size, dim_out=3,
        num_layers=num_hidden_layers,
        final_activation=torch.nn.Identity(),
        w0_initial=30.0, w0=30.0,
        num_encoding_functions=num_encoding_functions,
    )

    with open(f"weights_{i}.bin", 'rb') as f:
        stream = f.read()

    decoder = deepCABAC.Decoder()
    decoder.getStream(np.frombuffer(stream, dtype=np.uint8))

    # Entries are decoded in state_dict order, mirroring the encoder; the
    # batch-norm step counters were never encoded and are left untouched.
    state_dict = model.state_dict()
    for name in state_dict.keys():
        if '.num_batches_tracked' not in name:
            param = decoder.decodeWeights()
            state_dict[name] = torch.tensor(param)
    decoder.finish()

    model.load_state_dict(state_dict)
    decoded_models.append(model)

# Render every frame with the decoded (compressed) models, report PSNR against
# the original frames and write the result to 'output_video_compressed.mp4'.
out = cv2.VideoWriter('output_video_compressed.mp4', cv2.VideoWriter_fourcc(
    *'mp4v'), fps, (width, height))

all_psnrs = []
try:
  # Pure evaluation: no autograd graph is needed.
  with torch.no_grad():
    for i in range(total_frames):
      # Each model covers `frames_per_chunk` consecutive frames.
      chunk_model = decoded_models[i // frames_per_chunk]
      # Go through one_iter_npc so the encoded coordinates are computed on
      # demand instead of requiring an earlier cell to have filled the cache.
      rgb_predicted = clamp_image(one_iter_npc(width, height, chunk_model,
                                               i, encode, get_minibatches))

      img_psnr = psnr(frames[i], rgb_predicted)
      all_psnrs.append(img_psnr.item())

      mean_psnr = np.mean(all_psnrs)
      median_psnr = np.median(all_psnrs)
      print("Mean PSNR:", mean_psnr, "Median PSNR:", median_psnr)

      # NOTE(review): this discards the statistics of frames 0..10, so the
      # final summary below only covers the later frames. Kept as found;
      # confirm whether the reset is intended.
      if i == 10:
        all_psnrs = []

      # clamp_image already yields multiples of 1/255 in [0, 1], so scaling by
      # 255 recovers the exact 8-bit levels. Min-max normalization is avoided
      # because it would stretch the contrast of each frame independently.
      rgb_out = np.rint(rgb_predicted.detach().cpu().numpy() * 255).astype(np.uint8)
      out.write(rgb_out)
finally:
  # Finalize the file even if rendering fails part-way.
  out.release()

mean_psnr = np.mean(all_psnrs)
median_psnr = np.median(all_psnrs)
print("Mean PSNR:", mean_psnr, "Median PSNR:", median_psnr)


Compressed size: 0.048472 MB
Compressed size: 0.045162 MB
Dim_in: 323
Dim_in: 323
Mean PSNR: 27.68512725830078 Median PSNR: 27.68512725830078
Mean PSNR: 28.491586446762085 Median PSNR: 28.491586446762085
Mean PSNR: 28.96071990331014 Median PSNR: 29.29804563522339
Mean PSNR: 29.202473759651184 Median PSNR: 29.59851622581482
Mean PSNR: 29.317337036132812 Median PSNR: 29.776790142059326
Mean PSNR: 29.29075002670288 Median PSNR: 29.537417888641357
Mean PSNR: 29.321529184068954 Median PSNR: 29.50620412826538
Mean PSNR: 29.15383070707321 Median PSNR: 29.402124881744385
Mean PSNR: 29.189258416493733 Median PSNR: 29.47268009185791
Mean PSNR: 29.19486355781555 Median PSNR: 29.38536286354065
Mean PSNR: 29.215681769631125 Median PSNR: 29.423863887786865
Mean PSNR: 29.13794994354248 Median PSNR: 29.13794994354248
Mean PSNR: 28.922892808914185 Median PSNR: 28.922892808914185
Mean PSNR: 28.991137345631916 Median PSNR: 29.127626419067383
Mean PSNR: 28.858569860458374 Median PSNR: 28.917731046676636
M