In [1]:
from __future__ import print_function

import argparse
import math
import os
import random
import shutil
from datetime import datetime

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as prune
import torch.optim as optim
import torch.utils.data
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from NeRV.model_nerv import CustomDataSet, Generator
from NeRV.utils import *



In [2]:
# Run configuration: every knob the rest of the notebook reads from `args`.
args = argparse.Namespace()
args.vid = [None]
args.scale = 1
args.frame_gap = 1
args.augment = 0
args.dataset = 'bunny'
args.test_gap = 1
args.embed = '1.25_40'
args.stem_dim_num = '512_1'
args.fc_hw_dim = '9_16_26'
args.expansion = 1
args.reduction = 2
args.strides = [5,2,2,2,2]
args.num_blocks = 1
args.norm = 'none'
args.act = 'swish'
args.lower_width = 96
args.single_res = True
args.conv_type = 'conv'
args.workers = 4
args.batchSize = 1
args.not_resume_epoch = True
args.epochs = 300
args.cycles = 1
args.warmup = 0.2
args.lr = 0.0005
args.lr_type = 'cosine'
args.lr_steps = []
args.beta = 0.5
args.loss_type = 'Fusion6'
args.lw = 1.0
args.sigmoid = True
args.eval_only = False
args.eval_freq = 50
args.quant_bit = -1
args.quant_axis = 0
args.dump_images = False
args.eval_fps = False
args.prune_steps = [0.,]
args.prune_ratio = 1.0
args.manualSeed = 1
args.init_method = 'tcp://127.0.0.1:9888'
args.distributed = False
args.debug = True
args.print_freq = 50
args.weight = 'None'
args.overwrite = False
args.outf = 'NeRV/bunny_ab'
args.suffix = ''
args.ngpus_per_node = 1
# First epoch of the training loop. A fresh run starts at 0; resuming from a
# checkpoint may replace this value.
args.start_epoch = 0


cudnn.benchmark = True
# Seed every RNG the pipeline can touch (torch, NumPy, Python's `random`) so
# Restart-&-Run-All reproduces the same run.
torch.manual_seed(args.manualSeed)
np.random.seed(args.manualSeed)
random.seed(args.manualSeed)

os.makedirs(args.outf, exist_ok=True)

device = "mps" if torch.backends.mps.is_available() else "cpu"

local_rank = 0


In [3]:
# Best-so-far train/val metrics, each its own zero tensor. The resume cell may
# replace them with values stored in a checkpoint.
train_best_psnr = torch.tensor(0)
train_best_msssim = torch.tensor(0)
val_best_psnr = torch.tensor(0)
val_best_msssim = torch.tensor(0)
# Flags marking whether the latest metric set a new best.
is_train_best = False
is_val_best = False

In [4]:
# Positional encoder: turns the `norm_idx` values yielded by the dataloader into
# the embedding fed to the generator. Its output width (80 for '1.25_40', per the
# printed Namespace) sizes the generator's input layer.
PE = PositionalEncoding(args.embed)
args.embed_length = PE.embed_length

# NeRV generator; every architectural choice comes from the config cell.
model = Generator(
    embed_length=args.embed_length,
    stem_dim_num=args.stem_dim_num,
    fc_hw_dim=args.fc_hw_dim,
    expansion=args.expansion,
    num_blocks=args.num_blocks,
    norm=args.norm,
    act=args.act,
    bias=True,
    reduction=args.reduction,
    conv_type=args.conv_type,
    stride_list=args.strides,
    sin_res=args.single_res,
    lower_width=args.lower_width,
    sigmoid=args.sigmoid,
)

In [5]:
## Prune Model
# A prune_ratio below 1 requests pruning, which this notebook does not support
# yet, so fail fast instead of training an unpruned model silently.
prune_net = 1 > args.prune_ratio
if prune_net:
    raise NotImplementedError("Not Implemented Yet")

In [6]:
## Model size: trainable-parameter count in millions (FLOPs are not computed here).
total_params = sum(p.numel() for p in model.parameters()) / 1e6
# Same number under a second name; both names are kept in case other code reads either.
params = total_params

print(f'{args}\n {model}\n Model Params: {params}M')
with open('{}/rank0.txt'.format(args.outf), 'a') as f:
    f.write(str(model) + '\n' + f'Params: {params}M\n')
# One TensorBoard run directory per model size.
writer = SummaryWriter(os.path.join(args.outf, f'param_{total_params}M', 'tensorboard'))

Namespace(vid=[None], scale=1, frame_gap=1, augment=0, dataset='bunny', test_gap=1, embed='1.25_40', stem_dim_num='512_1', fc_hw_dim='9_16_26', expansion=1, reduction=2, strides=[5, 2, 2, 2, 2], num_blocks=1, norm='none', act='swish', lower_width=96, single_res=True, conv_type='conv', workers=4, batchSize=1, not_resume_epoch=True, epochs=300, cycles=1, warmup=0.2, lr=0.0005, lr_type='cosine', lr_steps=[], beta=0.5, loss_type='Fusion6', lw=1.0, sigmoid=True, eval_only=False, eval_freq=50, quant_bit=-1, quant_axis=0, dump_images=False, eval_fps=False, prune_steps=[0.0], prune_ratio=1.0, manualSeed=1, init_method='tcp://127.0.0.1:9888', distributed=False, debug=True, print_freq=50, weight='None', overwrite=False, outf='NeRV/bunny_ab', suffix='', ngpus_per_node=1, embed_length=80)
 Generator(
  (stem): Sequential(
    (0): Linear(in_features=80, out_features=512, bias=True)
    (1): SiLU(inplace=True)
    (2): Linear(in_features=512, out_features=3744, bias=True)
    (3): SiLU(inplace=True

In [None]:
# Device placement. Only single-process, single-device training is supported:
# any multi-GPU request (distributed or not) is rejected up front, otherwise the
# model is moved to the selected `device`.
if args.ngpus_per_node > 1:
    raise NotImplementedError("Distributed model not implemented")
model = model.to(device)

In [None]:
# Adam with a configurable beta1. No lr is passed here; the training loop calls
# adjust_lr(optimizer, ...) before every step, presumably to set the schedule.
optimizer = optim.Adam(
    params=model.parameters(),
    betas=(args.beta, 0.999),
)

In [None]:
if args.weight != 'None':
    raise NotImplementedError("Resume From Weight Not Implemented")

## Resume from model_latest.pth in the output directory, if present.
checkpoint = None
checkpoint_path = os.path.join(args.outf, 'model_latest.pth')
if os.path.isfile(checkpoint_path):
    # Load onto CPU first so a checkpoint written on another device still opens.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    if prune_net:
        raise NotImplementedError("Prune Not Implemented")
    model.load_state_dict(checkpoint['state_dict'])
    print("=> Auto resume loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
else:
    print("=> No resume checkpoint found at '{}'".format(checkpoint_path))

# The training loop reads args.start_epoch, so it must exist even on a fresh run.
args.start_epoch = 0
if checkpoint is not None:
    # With not_resume_epoch set, weights and optimizer state are restored but the
    # epoch counter restarts from 0.
    if not args.not_resume_epoch:
        args.start_epoch = checkpoint['epoch']
    # Move best-so-far metrics to the training device (not a hard-coded 'mps:0',
    # which fails on CPU-only machines).
    train_best_psnr = checkpoint['train_best_psnr'].to(device)
    train_best_msssim = checkpoint['train_best_msssim'].to(device)
    val_best_psnr = checkpoint['val_best_psnr'].to(device)
    val_best_msssim = checkpoint['val_best_msssim'].to(device)
    optimizer.load_state_dict(checkpoint['optimizer'])

In [7]:
# Train and validation read the same frame folder. They differ in frame sampling
# (frame_gap vs test_gap), shuffling, and whether a short final batch is dropped.
img_transforms = transforms.ToTensor()
DataSet = CustomDataSet
train_data_dir = f'NeRV/data/{args.dataset.lower()}'
val_data_dir = f'NeRV/data/{args.dataset.lower()}'
# Pinned host memory only helps CUDA transfers. On MPS/CPU it has no effect and
# recent PyTorch versions emit a warning, so enable it only for CUDA.
use_pin_memory = torch.device(device).type == 'cuda'

train_dataset = DataSet(train_data_dir, img_transforms, vid_list=args.vid, frame_gap=args.frame_gap)
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=args.batchSize, shuffle=(train_sampler is None),
    num_workers=args.workers, pin_memory=use_pin_memory, sampler=train_sampler,
    drop_last=True, worker_init_fn=worker_init_fn)

val_dataset = DataSet(val_data_dir, img_transforms, vid_list=args.vid, frame_gap=args.test_gap)
val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset) if args.distributed else None
val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=args.batchSize, shuffle=False,
    num_workers=args.workers, pin_memory=use_pin_memory, sampler=val_sampler,
    drop_last=False, worker_init_fn=worker_init_fn)
data_size = len(train_dataset)

In [12]:
# Debug only: a separate iterator over the training loader so the next cell can
# pull single batches. With args.workers > 0 it keeps its own worker processes
# alive while referenced.
data_iter = train_dataloader.__iter__()

In [25]:
# Debug peek: draw one batch and display its positional embedding next to the
# raw `norm_idx`. Note this rebinds the global `data`.
batch = next(data_iter)
(data, norm_idx) = batch
(PE(norm_idx), norm_idx)

(tensor([[ 0.9525,  0.3045,  1.0000, -0.0059,  0.9210, -0.3895,  0.6272, -0.7789,
           0.0620, -0.9981, -0.6502, -0.7597, -0.9951,  0.0993, -0.2652,  0.9642,
           0.9443,  0.3292,  0.0266, -0.9996, -0.7303,  0.6832,  0.8539, -0.5205,
          -0.4740,  0.8805, -0.8155, -0.5788, -0.3954, -0.9185, -0.9618, -0.2737,
           0.7390,  0.6737, -0.5067,  0.8621,  0.7874,  0.6165, -0.4238,  0.9058,
           0.5201, -0.8541, -0.9948,  0.1015,  0.9649,  0.2625,  0.0606,  0.9982,
          -0.0758, -0.9971,  0.7709,  0.6370,  0.4534, -0.8913,  0.1959,  0.9806,
          -0.9698,  0.2440, -0.9964, -0.0844,  0.8784,  0.4780, -0.9736, -0.2282,
           0.6291, -0.7773, -0.9979, -0.0652, -0.4568,  0.8896, -0.5584,  0.8295,
          -0.7377, -0.6752,  0.9685, -0.2489, -0.6497, -0.7602,  0.9951, -0.0990]]),
 tensor([0.4015]))

In [30]:
# Sanity check of the positional encoding for embed '1.25_40'. The printed
# PE(norm_idx) row alternates sin/cos of 1.25**k * pi * x, so entry 6 (k = 3)
# should equal sin(1.25**3 * pi * x). Check it against the batch drawn above
# instead of a hand-copied literal.
pe_x = float(norm_idx[0])
pe_expected = math.sin(pe_x * 1.25 ** 3 * math.pi)
assert math.isclose(float(PE(norm_idx)[0, 6]), pe_expected, abs_tol=1e-4), (float(PE(norm_idx)[0, 6]), pe_expected)
pe_expected

0.6272518154951441

In [None]:
def _sync_device(dev):
    """Wait for queued work on ``dev`` so wall-clock timings include the compute.

    CUDA and MPS execute asynchronously; without a sync, timing a forward pass
    mostly measures kernel launch. CPU needs no synchronisation.
    """
    if dev.type == 'cuda':
        torch.cuda.synchronize(dev)
    elif dev.type == 'mps' and hasattr(torch, 'mps'):
        torch.mps.synchronize()


def _mean_metrics(psnr_list, msssim_list):
    """Average per-batch metric tensors over the batch axis.

    Each entry is expected to be shaped (batchsize, num_stage); the result is one
    value per stage: ``(psnr, msssim)``.
    """
    val_psnr = torch.mean(torch.cat(psnr_list, dim=0), dim=0)
    val_msssim = torch.mean(torch.cat(msssim_list, dim=0).float(), dim=0)
    return val_psnr, val_msssim


@torch.no_grad()
def evaluate(model, val_dataloader, pe, local_rank, args):
    """Run ``model`` over ``val_dataloader`` and report PSNR / MS-SSIM per stage.

    Args:
        model: the generator. It is put in eval mode for the pass and always
            switched back to train mode afterwards, even on error.
        val_dataloader: yields ``(data, norm_idx)`` batches.
        pe: positional encoder applied to ``norm_idx`` to build the model input.
        local_rank: process rank; only ``0`` / ``None`` (single process) is supported.
        args: run config. Reads quant_bit, dump_images, outf, debug, eval_fps,
            batchSize and print_freq.

    Returns:
        ``(val_psnr, val_msssim)``: per-stage means over every evaluated batch.

    Raises:
        NotImplementedError: quantization requested or non-zero ``local_rank``.
        ValueError: the dataloader yielded no batches.
    """
    if args.quant_bit != -1:
        raise NotImplementedError("HuffmanCodec Not Implemented")
    if local_rank not in [0, None]:
        raise NotImplementedError("Distributed Model not Implemented")

    visual_dir = None
    if args.dump_images:
        from torchvision.utils import save_image
        visual_dir = f'{args.outf}/visualize'
        print(f'Saving predictions to {visual_dir}')
        os.makedirs(visual_dir, exist_ok=True)

    psnr_list, msssim_list, time_list = [], [], []
    # Repeat each forward pass only when measuring throughput.
    fwd_num = 10 if args.eval_fps else 1
    num_frames = 0
    model.eval()
    try:
        for i, (data, norm_idx) in enumerate(val_dataloader):
            if i > 10 and args.debug:
                break
            embed_input = pe(norm_idx)
            data = data.to(device=device, non_blocking=True)
            embed_input = embed_input.to(device=device, non_blocking=True)

            for _ in range(fwd_num):
                _sync_device(data.device)
                start_time = datetime.now()
                output_list = model(embed_input)
                _sync_device(data.device)
                time_list.append((datetime.now() - start_time).total_seconds())
            num_frames += data.size(0)

            if visual_dir is not None:
                # Use the real batch length: the final batch may be shorter than args.batchSize.
                for batch_ind in range(data.size(0)):
                    full_ind = i * args.batchSize + batch_ind
                    save_image(output_list[-1][batch_ind], f'{visual_dir}/pred_{full_ind}.png')
                    save_image(data[batch_ind], f'{visual_dir}/gt_{full_ind}.png')

            # Targets are the ground truth pooled to each stage's output resolution.
            target_list = [F.adaptive_avg_pool2d(data, x.shape[-2:]) for x in output_list]
            psnr_list.append(psnr_fn(output_list, target_list))
            msssim_list.append(msssim_fn(output_list, target_list))

            if i % args.print_freq == 0:
                val_psnr, val_msssim = _mean_metrics(psnr_list, msssim_list)
                total_time = sum(time_list)
                fps = fwd_num * num_frames / total_time if total_time > 0 else float('inf')
                print_str = 'Rank:{}, Step [{}/{}], PSNR: {}, MSSSIM: {} FPS: {}'.format(
                    local_rank, i+1, len(val_dataloader),
                    RoundTensor(val_psnr, 2, False), RoundTensor(val_msssim, 4, False), round(fps, 2))
                print(print_str)
                # local_rank is 0/None here (checked above), so this process owns the log.
                with open('{}/rank0.txt'.format(args.outf), 'a') as f:
                    f.write(print_str + '\n')
    finally:
        model.train()

    if not psnr_list:
        raise ValueError('val_dataloader yielded no batches to evaluate')
    # Aggregate once at the end rather than re-concatenating every step.
    return _mean_metrics(psnr_list, msssim_list)

In [None]:
def _append_rank0_log(message):
    """Append one line to the rank-0 text log in ``args.outf``."""
    with open('{}/rank0.txt'.format(args.outf), 'a') as f:
        f.write(message + '\n')


# Main loop. Each epoch runs the optimisation steps, then (once per epoch) logs
# metrics, optionally evaluates, and writes checkpoints.
start = datetime.now()
total_epochs = args.epochs * args.cycles
# args.start_epoch is set by the resume/config cells; default to a fresh run otherwise.
start_epoch = getattr(args, 'start_epoch', 0)
for epoch in range(start_epoch, total_epochs):
    model.train()
    epoch_start_time = datetime.now()
    psnr_list = []
    msssim_list = []
    is_train_best, is_val_best = False, False

    # ---- optimisation steps ----
    for i, (data, norm_idx) in enumerate(train_dataloader):
        if i > 10 and args.debug:
            break

        embed_input = PE(norm_idx)
        data, embed_input = data.to(device), embed_input.to(device)

        output_list = model(embed_input)
        target_list = [F.adaptive_avg_pool2d(data, x.shape[-2:]) for x in output_list]
        loss_list = [loss_fn(output, target, args) for output, target in zip(output_list, target_list)]
        # Intermediate stages are weighted by args.lw; the final stage keeps weight 1.
        last_stage = len(loss_list) - 1
        loss_sum = sum(loss * (args.lw if k < last_stage else 1) for k, loss in enumerate(loss_list))
        lr = adjust_lr(optimizer, epoch % args.epochs, i, data_size, args)
        optimizer.zero_grad()
        loss_sum.backward()
        optimizer.step()

        # Metrics are for logging only. Compute them without autograd so the
        # per-epoch lists cannot retain each step's graph (in case psnr_fn /
        # msssim_fn do not detach their inputs).
        with torch.no_grad():
            psnr_list.append(psnr_fn(output_list, target_list))
            msssim_list.append(msssim_fn(output_list, target_list))
        if i % args.print_freq == 0 or i == len(train_dataloader) - 1:
            train_psnr = torch.mean(torch.cat(psnr_list, dim=0), dim=0)
            train_msssim = torch.mean(torch.cat(msssim_list, dim=0).float(), dim=0)
            time_now_string = datetime.now().strftime("%Y/%m/%d %H:%M:%S")
            print_str = '[{}] Rank:{}, Epoch[{}/{}], Step [{}/{}], lr:{:2e} PSNR: {}, MSSSIM: {}'.format(
                time_now_string, local_rank, epoch+1, args.epochs, i+1, len(train_dataloader), lr,
                RoundTensor(train_psnr, 2, False), RoundTensor(train_msssim, 4, False)
            )
            print(print_str)
            if local_rank in [0, None]:
                _append_rank0_log(print_str)

    if not psnr_list:
        raise RuntimeError('train_dataloader yielded no batches; check the dataset path and batch size')
    # Epoch averages over every step that ran. The debug break can end the epoch
    # before the last-step print above refreshes these.
    train_psnr = torch.mean(torch.cat(psnr_list, dim=0), dim=0)
    train_msssim = torch.mean(torch.cat(msssim_list, dim=0).float(), dim=0)
    h, w = output_list[-1].shape[-2:]

    # ---- epoch-level training summary ----
    if local_rank in [0, None]:
        is_train_best = train_psnr[-1] > train_best_psnr
        if is_train_best:
            train_best_psnr = train_psnr[-1]
        if train_msssim[-1] > train_best_msssim:
            train_best_msssim = train_msssim[-1]
        writer.add_scalar(f'Train/PSNR_{h}X{w}_gap{args.frame_gap}', train_psnr[-1].item(), epoch+1)
        writer.add_scalar(f'Train/MSSSIM_{h}X{w}_gap{args.frame_gap}', train_msssim[-1].item(), epoch+1)
        writer.add_scalar(f'Train/best_PSNR_{h}X{w}_gap{args.frame_gap}', train_best_psnr.item(), epoch+1)
        writer.add_scalar(f'Train/best_MSSSIM_{h}X{w}_gap{args.frame_gap}', train_best_msssim.item(), epoch+1)
        print_str = '\t{}p: current: {:.2f}\t best: {:.2f}\t msssim_best: {:.4f}\t'.format(
            h, train_psnr[-1].item(), train_best_psnr.item(), train_best_msssim.item())
        print(print_str, flush=True)
        _append_rank0_log(print_str)
        writer.add_scalar('Train/lr', lr, epoch+1)
        epoch_end_time = datetime.now()
        print("Time/epoch: \tCurrent:{:.2f} \tAverage:{:.2f}".format(
            (epoch_end_time - epoch_start_time).total_seconds(),
            (epoch_end_time - start).total_seconds() / (epoch + 1 - start_epoch)))

    # ---- periodic evaluation (and every epoch near the end) ----
    if (epoch + 1) % args.eval_freq == 0 or epoch > total_epochs - 10:
        val_start_time = datetime.now()
        val_psnr, val_msssim = evaluate(model, val_dataloader, PE, local_rank, args)
        val_end_time = datetime.now()

        if local_rank in [0, None]:
            print_str = f'Eval best_PSNR at epoch{epoch+1}:'
            is_val_best = val_psnr[-1] > val_best_psnr
            if is_val_best:
                val_best_psnr = val_psnr[-1]
            if val_msssim[-1] > val_best_msssim:
                val_best_msssim = val_msssim[-1]
            writer.add_scalar(f'Val/PSNR_{h}X{w}_gap{args.test_gap}', val_psnr[-1].item(), epoch+1)
            writer.add_scalar(f'Val/MSSSIM_{h}X{w}_gap{args.test_gap}', val_msssim[-1].item(), epoch+1)
            writer.add_scalar(f'Val/best_PSNR_{h}X{w}_gap{args.test_gap}', val_best_psnr.item(), epoch+1)
            writer.add_scalar(f'Val/best_MSSSIM_{h}X{w}_gap{args.test_gap}', val_best_msssim.item(), epoch+1)
            print_str += '\t{}p: current: {:.2f}\tbest: {:.2f} \tbest_msssim: {:.4f}\t Time/epoch: {:.2f}'.format(
                h, val_psnr[-1].item(), val_best_psnr.item(), val_best_msssim.item(),
                (val_end_time - val_start_time).total_seconds())
            print(print_str)
            _append_rank0_log(print_str)

    # ---- checkpoints: once per epoch, built after evaluation so val_best_* are current ----
    if local_rank in [0, None]:
        save_checkpoint = {
            'epoch': epoch+1,
            'state_dict': model.state_dict(),
            'train_best_psnr': train_best_psnr,
            'train_best_msssim': train_best_msssim,
            'val_best_psnr': val_best_psnr,
            'val_best_msssim': val_best_msssim,
            'optimizer': optimizer.state_dict(),
        }
        torch.save(save_checkpoint, '{}/model_latest.pth'.format(args.outf))
        if is_train_best:
            torch.save(save_checkpoint, '{}/model_train_best.pth'.format(args.outf))
        if is_val_best:
            torch.save(save_checkpoint, '{}/model_val_best.pth'.format(args.outf))

print("Training complete in: " + str(datetime.now() - start))
