In [1]:
import os
import sys
import time
import argparse
from PIL import Image
# from generate_video import generate_video

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import StepLR
from torchvision.models import vgg16
from perceptual import LossNetwork
import random

import config
import myutils
from myutils import test_images_2d, test_images_outp, test_metric, get_raft_args, test_images_optical
from loss import Loss
from loss import edge_conv2d
from torch.utils.data import DataLoader
from loss import L1_Charbonnier_loss
from pdb import set_trace as bp

import models
from utils import make_coord
from torch.autograd import Variable
from loss import PerceptualLoss, DCLoss

from core.raft import RAFT
from core.utils import flow_viz
from core.utils.utils import InputPadder

import cv2
from PWC_src import PWC_Net
from PWC_src import flow_to_image
import warnings

from dataset.GoPro_arbitrary_nosr import make_coord_3d


def double_forward(model, optimizer, preds, images, gt, device, i, epoch_id, bs):
    
    sampled_idx = sorted(random.sample(range(5), 3))
    h, w = preds[0].shape[2], preds[0].shape[3]
    for idx in range(3):
        optimizer.zero_grad()
        temp_coord = make_coord_3d((h, w), (idx + 2) / 8)
        temp_coord = [temp_coord.to(device)[None].repeat(bs, 1, 1)]
        if idx == 0:
            inputs = torch.cat([images[:, :, 1].unsqueeze(0), preds[1].detach().unsqueeze(0), preds[3].detach().unsqueeze(0), preds[5].detach().unsqueeze(0)], dim=0).permute(1, 2, 0, 3, 4)
        elif idx == 1:
            inputs = torch.cat([preds[0].detach().unsqueeze(0), preds[2].detach().unsqueeze(0), preds[4].detach().unsqueeze(0), preds[6].detach().unsqueeze(0)], dim=0).permute(1, 2, 0, 3, 4)
        else:
            inputs = torch.cat([preds[1].detach().unsqueeze(0), preds[3].detach().unsqueeze(0), preds[5].detach().unsqueeze(0), images[:, :, 2].unsqueeze(0)], dim=0).permute(1, 2, 0, 3, 4)
        new_pred_f, new_pred_b = model(inputs, temp_coord, True)
        loss = F.smooth_l1_loss(new_pred_f[0], gt[idx + 2])\
            + F.smooth_l1_loss(new_pred_b[0], gt[idx + 2])\
            + F.smooth_l1_loss(new_pred_b[0], new_pred_f[0])
        loss.backward()
        optimizer.step()
        print('Epoch %d, Iter %d, Loss: %.4f' % (epoch_id, i, loss.item()))
        
    return 
warnings.filterwarnings('ignore')

##### Tensorboard #####
writer = SummaryWriter('/output/logs')

##### Parameters #####
device_ids = [Id for Id in range(torch.cuda.device_count())]
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

parser = argparse.ArgumentParser()

parser.add_argument('--data_root', type=str, default='/data/nnice1216/vimeo_septuplet/DAVIS/JPEGImages/Full-Resolution/bmx-rider/')
parser.add_argument('--batch_size', type=int, default=4)
parser.add_argument('--lr', type=float, default=1e-5)
parser.add_argument('--epoch_num', type=int, default=30)
parser.add_argument('--num_workers', type=int, default=8)
parser.add_argument('--if_continue', type=bool, default=False)
parser.add_argument('--TEMP', type=float, default=1)

args = parser.parse_known_args()[0]

##### Dataset ###### 
## DAVIS 
# from dataset.Davis_liif import get_loader
# args.data_root = '/data/nnice1216/vimeo_septuplet/DAVIS/JPEGImages/Full-Resolution/bmx-rider/' 
# train_loader = get_loader(args.data_root, args.batch_size, shuffle=True, num_workers=8, drop_last=True)


## VIMEO ##
# from dataset.vimeo90k_septuplet import get_loader
# args.data_root = '/data/nnice1216/vimeo_septuplet/'
# train_loader = get_loader('train', args.data_root, args.batch_size, shuffle=True, num_workers=args.num_workers)


## GOPRO ##
from dataset.GoPro_arbitrary_nosr import get_loader
random_seed = 0
interval = 8
train_data_root = '/data/nnice1216/high_FPS_video/GOPRO_Large_all/'
# train_loader = get_loader('train', train_data_root, args.batch_size, shuffle=True, num_workers=args.num_workers, random_seed=random_seed, seted_interval=interval)
# train_loader = get_loader('train', train_data_root, args.batch_size, shuffle=True, num_workers=args.num_workers, random_seed=random_seed)


##### Model #####
## LIIF ###
model_args = {'encoder_spec': {'name': 'edsr-baseline', 'args': {'no_upsampling': True}}, 'imnet_spec': {'name': 'mlp', 'args': {'out_dim': 3, 'hidden_list': [64, 64]}}}
model_spec = {'name': 'liif_foroptical', 'args': model_args}
model = models.make(model_spec).to(device)
model = nn.DataParallel(model, device_ids=device_ids)
# for k, v in model.named_parameters():
#     if k[:14] == 'module.encoder':
#         v.requires_grad=False
        
# model.load_state_dict(torch.load('/model/nnice1216/video/FLAVR_2x.pth')['state_dict'], strict=False)

if args.if_continue:
    name = 'final-968.pth'
#     name = 'vimeo_epoch1_iter499.pth'
    print('Load model ' + name)
    model.load_state_dict(torch.load('/model/nnice1216/video/' + name))

##### Loss & Optimizer #####
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
scheduler = StepLR(optimizer, step_size=1, gamma=0.8)

##### Training #####
model.train()
loss_f = PerceptualLoss(nn.MSELoss())
index = 0
# test_metric(model, 0, 0)
perc_lambda = 0.02
dcp_lambda = 0.005
FLOW_SCALE = 20.0

args1 = get_raft_args()
raft = RAFT(args1)
raft = nn.DataParallel(raft, device_ids=device_ids)
raft.load_state_dict(torch.load(args1.model))
raft.eval()
# new_state_dict = OrderedDict()
# pre_trained_dict = torch.load(args1.model)
# for k, v in pre_trained_dict.items():
#     name = k[7:]
#     new_state_dict[name] = v
# self.raft.load_state_dict(new_state_dict)
# self.raft.eval()
 
        
for epoch_id in range(args.epoch_num):
    
    print('Epoch {} Begin'.format(epoch_id))

    out_dir = '/output/models/'
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    out_dir2 = '/output/tempimgs22/'
    if not os.path.exists(out_dir2):
        os.makedirs(out_dir2)
    train_loader = get_loader('train', train_data_root, args.batch_size, shuffle=True, num_workers=args.num_workers, random_seed=random_seed, seted_interval=interval)
    for i, data in enumerate(train_loader):
        
        ## VIMEO PRE_PROCESS FOR LIIF_3D ##
        images, gt_image, coords, cells, times = data
        images = [img_.to(device) for img_ in images]
        images = torch.stack(images, dim=2)
        gt = [g_.to(device) for g_ in gt_image]
        # gt = [g_.view(args.batch_size, 3, -1).permute(0, 2, 1).to(device) for g_ in gt_image]
        # gt = torch.cat(gt).view(args.batch_size, 3, -1).permute(0, 2, 1)
        coords = [c_.to(device) for c_ in coords]
        # coord, cell = coord.to(device), cell.to(device)
 
        # Forward
        optimizer.zero_grad()
        # pred_inter, pred_0, pred_1 = model(img, coord, cell)
        # bp()
        with torch.no_grad():
            padder = InputPadder(images[:, :, 1].shape)
            padder = InputPadder(images[:, :, 1].shape)
            image1, image2 = padder.pad(images[:, :, 1].clamp(0, 1) * 255, images[:, :, 2].clamp(0, 1) * 255)
            _, flow_f = raft(image1, image2, iters=20, test_mode=True)
        
        preds_f = model(images, coords, True)
        
        ## VISA ##
        # pred = model(img)
        # gt = img.clone().detach()
        ##  ##  ##
        
        # loss1 = F.smooth_l1_loss(preds[0], images[1].view(args.batch_size, 3, -1).permute(0, 2, 1)) + F.mse_loss(preds[-1], images[2].view(args.batch_size, 3, -1).permute(0, 2, 1))
        # loss1 = F.smooth_l1_loss(preds[0], images[1]) + F.smooth_l1_loss(preds[-1], images[2])
        # with torch.no_grad():
            # feature1 = vgg_model(images[1])
            # feature2 = vgg_model(images[2])
        # bp()
        loss1 = 0
        loss2 = 0
        loss3 = 0
        loss4 = 0
        for idx in range(1):
            '''
            with torch.no_grad():
                time = coords[idx][0, 0, -1].item()
                flow_f = FLOW_SCALE*pwc(images[:, :, 1], images[:, :, 2])
                flow_f = (flow_f - flow_f.min()) / (flow_f.max() - flow_f.min())
                # flow_b = FLOW_SCALE*pwc(images[:, :, 2], images[:, :, 1])
                # flow2 = FLOW_SCALE*pwc(images[:, :, 1], preds_f[idx])
            '''
            # loss1 += F.mse_loss(feature1, feature2) * 0.04
            # loss1 += F.smooth_l1_loss(preds[idx], gt[idx - 1])
            loss1 += F.smooth_l1_loss(preds_f[idx], flow_f)
            # loss2 += F.smooth_l1_loss(feats[idx][0], feats[idx][1])
            # loss2 += F.smooth_l1_loss(preds_b[idx], gt[idx])
            # loss3 += F.smooth_l1_loss(preds_f[idx], preds_b[idx])
            # loss4 += F.smooth_l1_loss(flow_f * time * (1 - time) - flow_b * (time ** 2), flow2) * 1e-3
            # loss2 += loss_f.get_loss(preds_f[idx], gt[idx]) * perc_lambda
            # loss3 += DCLoss(preds_f[idx].clamp(0, 1)) * dcp_lambda
            # loss2 += F.smooth_l1_loss(preds_b[idx], gt[idx]) + loss_f.get_loss(preds_b[idx], gt[idx]) * perc_lambda
            # loss3 += F.smooth_l1_loss(preds_f[idx], preds_b[idx]) + loss_f.get_loss(preds_f[idx], preds_b[idx]) * perc_lambda
            # loss2 += 0.1 * F.mse_loss(edge_conv2d(preds[idx]), edge_conv2d(gt[idx]))
            # loss2 += 0.05 * F.mse_loss(preds[idx], preds[idx - 1])
        
        # loss = F.mse_loss(preds[0], images[0].view(args.batch_size, 3, -1).permute(0, 2, 1)) + F.mse_loss(preds[-1], images[1].view(args.batch_size, 3, -1).permute(0, 2, 1))\
        #        + F.mse_loss(preds[1], gt[0]) + F.mse_loss(preds[2], gt[1]) + F.mse_loss(preds[3], gt[2]) + F.mse_loss(preds[4], gt[3]) + F.mse_loss(preds[5], gt[4])
        
        # loss1 = F.smooth_l1_loss(preds, gt[0])
        
        # Backward & Update
        # bp()
        loss = loss1 + loss2 + loss3 + loss4
        # loss *= 0.05
        # print(loss1.item(), loss2.item(), loss3.item(), loss4.item())
        # loss = loss1
        loss.backward()
        optimizer.step()
        
        print('Epoch %d, Iter %d, Loss: %.4f' % (epoch_id, i, loss.item()))
        # print('Loss1: %.4f, Loss2: %.4f' % (loss1.item(), loss2.item()))
        # print('Loss1: %.4f, Loss2: %.4f, Loss3: %.4f' % (loss1.item(), loss2.item(), loss3.item()))
        # print('Epoch %d, Iter %d, Loss1: %.4f, Loss2: %.4f, Loss: %.4f' % (epoch_id, i, loss1.item(), loss2.item(), loss.item()))
        writer.add_scalar('Training Loss', loss.item(), index)
        index += 1
        # index += 1
        
        if i % 50 == 0:
            Image.fromarray((gt_image[0][0].permute(1, 2, 0).numpy() * 255).astype(np.uint8)).save(os.path.join(out_dir2, 'Epoch{}_iter{}_GT.jpg'.format(epoch_id, i)))
            flow = preds_f[0].data.cpu()
            flow = flow[0].numpy().transpose((1,2,0))
            flow_im = flow_to_image(flow)
            Image.fromarray(flow_im).save(os.path.join(out_dir2, 'Epoch{}_iter{}_PRED.jpg'.format(epoch_id, i)))
        # Image.fromarray((preds[0][0].clamp(0, 1).view(192, 192, 3).detach().cpu().numpy() * 255).astype(np.uint8)).save(os.path.join(out_dir2, 'Epoch{}_iter{}.jpg'.format(epoch_id, i)))
        if i % 100 == 0:
            test_images_optical(model, device, 1 / 64, epoch_id, i)
            # test_images_3d(model, device, 0.85, epoch_id, i)
        if i % 100 == 20:
            test_images_optical(model, device, 8 / 64, epoch_id, i)
            # test_images_3d(model, device, 0.9, epoch_id, i)
        elif i % 100 == 40:
            test_images_optical(model, device, 32 / 64, epoch_id, i)
            # test_images_3d(model, device, 0.95, epoch_id, i)
        elif i % 100 == 60:
            test_images_optical(model, device, 48 / 64, epoch_id, i)
        elif i % 100 == 80:
            test_images_optical(model, device, 63 / 64, epoch_id, i)  
        model.train()
        # double_forward(model, optimizer, preds_f, images, gt, device, i, epoch_id, args.batch_size)
        
        if (i + 1) % 100 == 0:
            torch.save(model.state_dict(), '/output/models/vimeo_epoch{}_iter{}.pth'.format(epoch_id, i))
        if (i + 1) % 1000 == 0:
            test_metric(model, epoch_id, i)
    if epoch_id % 7 == 6:
        test_metric(model, epoch_id, i)
    print("Epoch {} Done. Index={}".format(epoch_id, index))
    # if epoch_id == 29:
    #     generate_video(model, epoch_id, device)
    
    # scheduler.step()


BATCHNORM:  False
Epoch 0 Begin
1891
1891
1891


RuntimeError: module must have its parameters and buffers on device cuda:0 (device_ids[0]) but found one of them on device: cpu

In [1]:
import torchvision.transforms as tfs
from PIL import Image
import os
import myutils
from torch.utils.data import Dataset, DataLoader
from dataset.temp_test import TEMP

losses, psnrs, ssims = myutils.init_meters('1*L1')

test_loader = DataLoader(TEMP(), batch_size=8, shuffle=False, num_workers=8, drop_last=False)

for i, data in enumerate(test_loader):
    print(i, "DONE!")
    gt, pred = data
    gt, pred = gt.cuda(), pred.cuda()
    
    myutils.eval_metrics(gt, pred, psnrs, ssims)

    print('Val_PSNR:{0:.4f}, Val_SSIM:{1:.4f}'
          .format(psnrs.avg, ssims.avg))

['000019.jpg', '000020.jpg', '000021.jpg', '000022.jpg', '000023.jpg', '000024.jpg', '000025.jpg', '000026.jpg', '000027.jpg', '000028.jpg', '000029.jpg', '000030.jpg', '000031.jpg', '000032.jpg', '000033.jpg', '000034.jpg', '000035.jpg', '000036.jpg', '000037.jpg', '000038.jpg', '000039.jpg', '000040.jpg', '000041.jpg', '000042.jpg', '000043.jpg', '000044.jpg', '000045.jpg', '000046.jpg', '000047.jpg', '000048.jpg', '000049.jpg', '000050.jpg', '000051.jpg', '000052.jpg', '000053.jpg', '000054.jpg', '000055.jpg', '000056.jpg', '000057.jpg', '000058.jpg', '000059.jpg', '000060.jpg', '000061.jpg', '000062.jpg', '000063.jpg', '000064.jpg', '000065.jpg', '000066.jpg', '000067.jpg', '000068.jpg', '000069.jpg', '000070.jpg', '000071.jpg', '000072.jpg', '000073.jpg', '000074.jpg', '000075.jpg', '000076.jpg', '000077.jpg', '000078.jpg', '000079.jpg', '000080.jpg', '000081.jpg', '000082.jpg', '000083.jpg', '000084.jpg', '000085.jpg', '000086.jpg', '000087.jpg', '000088.jpg', '000089.jpg', '0000

In [1]:
import os
import sys
import time
import copy
import shutil
import random
import pdb

import torch
import numpy as np
from tqdm import tqdm

import config
import myutils
import models

from PIL import Image
import torch.nn as nn
from torch.utils.data import DataLoader
from pdb import set_trace as bp


##### Parse CmdLine Arguments #####
# os.environ["CUDA_VISIBLE_DEVICES"]='7'
args, unparsed = config.get_args()
cwd = os.getcwd()

device_ids = [Id for Id in range(torch.cuda.device_count())]
device = torch.device('cuda' if args.cuda else 'cpu')

torch.manual_seed(args.random_seed)
if args.cuda:
    torch.cuda.manual_seed(args.random_seed)

##### FLAVR #####

model_args = {'encoder_spec': {'name': 'edsr-baseline', 'args': {'no_upsampling': True}}, 'imnet_spec': {'name': 'mlp', 'args': {'out_dim': 3, 'hidden_list': [64, 64]}}}
model_spec = {'name': 'liif_bidi', 'args': model_args}
model = models.make(model_spec).to(device)
model = nn.DataParallel(model, device_ids=device_ids)
name = '633-best-21-22.pth'
model.load_state_dict(torch.load('/model/nnice1216/video/' + name))


def get_name(index):
    if index >= 0 and index <= 9:
        text = '00000' + str(index) + '.jpg'
    elif index >= 10 and index <= 99:
        text = '0000' + str(index) + '.jpg'
    elif index >= 100 and index <= 999:
        text = '000' + str(index) + '.jpg'
    else:
        text = '00' + str(index) + '.jpg'
    return text

def save_image(img, index):
    if index < 1600:
        img.save(os.path.join(out_path1, get_name(index)))
    elif index < 3200:
        img.save(os.path.join(out_path2, get_name(index)))
    elif index < 4800:
        img.save(os.path.join(out_path3, get_name(index)))
    elif index < 6400:
        img.save(os.path.join(out_path4, get_name(index)))
    elif index < 8000:
        img.save(os.path.join(out_path5, get_name(index)))
    elif index < 9600:
        img.save(os.path.join(out_path6, get_name(index)))
    else:
        img.save(os.path.join(out_path7, get_name(index)))
        

def test(model, epoch_id, iter_id):
    
    out_path1 = '/output/Imagestry11/'
    out_path2 = '/output/Imagestry2/'
    out_path3 = '/output/Imagestry3/'
    out_path4 = '/output/Imagestry4/'
    out_path5 = '/output/Imagestry5/'
    out_path6 = '/output/Imagestry6/'
    out_path7 = '/output/Imagestry7/'
    if not os.path.exists(out_path1):
        os.makedirs(out_path1)
    if not os.path.exists(out_path2):
        os.makedirs(out_path2)
    if not os.path.exists(out_path3):
        os.makedirs(out_path3)
    if not os.path.exists(out_path4):
        os.makedirs(out_path4)
    if not os.path.exists(out_path5):
        os.makedirs(out_path5)
    if not os.path.exists(out_path6):
        os.makedirs(out_path6)
    if not os.path.exists(out_path7):
        os.makedirs(out_path7)
        
    torch.cuda.empty_cache()
    from dataset.GoPro_singlevideo import get_loader as gl
    test_loader = gl('train', '/data/nnice1216/high_FPS_video/GOPRO_Large_all/GOPR0384_11_00/', 1, shuffle=False, num_workers=8, interFrames=15)    
    # test_loader = gl('train', '/output/liif_result/', 1, shuffle=False, num_workers=8, interFrames=15)
    
    time_taken = []
    losses, psnrs, ssims = myutils.init_meters('1*L1')
    model.eval()
    save_index = 17
    psnr_list = []
    outpath = '/output/liif_result/'
    if not os.path.exists(outpath):
        os.makedirs(outpath)
        
    with torch.no_grad():
        for i, (images, coords, cells, times, names) in enumerate(test_loader):
            print(names[1][0].split('/')[-1][:6], "SHOULD EQUAL TO {}, DONE".format(save_index))

            images = [img_.cuda() for img_ in images]
            coords = [coords[0].cuda()]
            # coords = [c_.cuda() for c_ in coords]
            # cells = [c_.cuda() for c_ in cells]
            
            Image.fromarray((images[1][0].permute(1, 2, 0).detach().cpu() * 255).numpy().astype(np.uint8)).save(os.path.join(out_path1, get_name(save_index)))
            print(save_index, "DONE!")
            save_index += 1
            
            torch.cuda.synchronize()
            start_time = time.time()
            # out = model(images)
            
            out = model(images, coords, 8, save_index)
            save_index += 14
            torch.cuda.synchronize()
            time_taken.append(time.time() - start_time)
            print(os.path.join(out_path1, get_name(save_index)))
            Image.fromarray((out[0].clamp(0, 1).permute(1, 2, 0).detach().cpu() * 255).numpy().astype(np.uint8)).save(os.path.join(out_path1, get_name(save_index)))
            print(save_index, "DONE!")
            save_index += 1
            
            del images
            torch.cuda.empty_cache()        
    # print("PSNR: %f, SSIM: %f" %
    #       (psnrs.avg, ssims.avg))
    # print("Time: " , sum(time_taken)/len(time_taken))
    torch.cuda.empty_cache()
    return


""" Entry Point """
def main(args):
    
    assert args.load_from is not None

    test(model, 1, 1)


if __name__ == "__main__":
    main(args)

Unparsed args: ['-f', '/root/.local/share/jupyter/runtime/kernel-264af218-60f9-4bb3-ba8b-ebed003b01fe.json']
BATCHNORM:  False
66
000017 SHOULD EQUAL TO 17, DONE
17 DONE!
INDEX 18, DONE!
INDEX 19, DONE!
INDEX 20, DONE!
INDEX 21, DONE!
INDEX 22, DONE!
INDEX 23, DONE!
INDEX 24, DONE!
INDEX 25, DONE!
INDEX 26, DONE!
INDEX 27, DONE!
INDEX 28, DONE!
INDEX 29, DONE!
INDEX 30, DONE!
INDEX 31, DONE!
/output/Imagestry11/000032.jpg
32 DONE!
000033 SHOULD EQUAL TO 33, DONE
33 DONE!
INDEX 34, DONE!
INDEX 35, DONE!
INDEX 36, DONE!
INDEX 37, DONE!
INDEX 38, DONE!
INDEX 39, DONE!
INDEX 40, DONE!
INDEX 41, DONE!
INDEX 42, DONE!
INDEX 43, DONE!
INDEX 44, DONE!
INDEX 45, DONE!
INDEX 46, DONE!
INDEX 47, DONE!
/output/Imagestry11/000048.jpg
48 DONE!
000049 SHOULD EQUAL TO 49, DONE
49 DONE!
INDEX 50, DONE!
INDEX 51, DONE!
INDEX 52, DONE!
INDEX 53, DONE!
INDEX 54, DONE!
INDEX 55, DONE!
INDEX 56, DONE!
INDEX 57, DONE!
INDEX 58, DONE!
INDEX 59, DONE!
INDEX 60, DONE!
INDEX 61, DONE!
INDEX 62, DONE!
INDEX 63, D