In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell executes and edits to local packages take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
from fractions import Fraction

import av
import cv2
import matplotlib.pyplot as plt
import torch
import tqdm
from torch2trt import torch2trt

import pix2pix
from pix2pix import config
from pix2pix import models

In [None]:
# Video locations: the reference ("gt") clip, the noisy input clip, and the
# file the predicted video is written to.
gt_path, noisy_path, pred_save_path = (
    'videos/input_video_gt.mp4',
    'videos/input_video_noisy.mp4',
    'videos/predicted_video_v5.mp4',
)

# Batch size; passed to torch2trt as `max_batch_size` when the engine is built.
bs = 256

# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

In [None]:
# Load the trained generator weights.
# `map_location="cpu"` keeps the checkpoint tensors in host memory: the load
# then works on machines without a GPU and does not leave a second copy of the
# weights on the GPU; `load_state_dict` copies them to wherever the model lives.
#
# Alternative (earlier) checkpoint:
# '../../models/pix2pix_32_11112021_051302/checkpoints/gen_epoch_1_pix2pix_32_11112021_051302.pt'
state_dict = torch.load(
    '../../models/pix2pix_32_13112021_183033/gen_pix2pix_32_13112021_183033.pt',
    map_location="cpu",
)

In [None]:
# Build the generator with the project's configured arguments.
model = models.unet(norm_layer=torch.nn.InstanceNorm2d, **config.model.GEN_ARGS)

# The weights are loaded through a DataParallel wrapper and the bare module is
# taken back out afterwards (presumably the checkpoint was saved from a
# DataParallel model, so its keys carry the "module." prefix -- TODO confirm).
model = torch.nn.DataParallel(model)
model.load_state_dict(state_dict)

# Put the unwrapped model on the same device as the example input below and
# switch it to inference mode before conversion.
model = model.module.to(device).eval()

if device == "cuda":
    # Convert to a TensorRT engine using a single 3x256x256 example input;
    # `bs` is the largest batch the engine will accept.
    model_trt = torch2trt(
        model,
        [torch.ones(1, 3, 256, 256, device=device)],
        max_batch_size=bs,
    )
else:
    # TensorRT needs a CUDA device; without one, run the plain PyTorch model.
    model_trt = model

In [None]:
# Optional sanity check (disabled): run the PyTorch model and the TensorRT
# engine on the same dummy batch and print the largest absolute difference.
# x = torch.ones((32, 3, 256, 256)).cuda()

# y = model(x)
# y_trt = model_trt(x)

# print(torch.max(torch.abs(y - y_trt)))

# Remove the notebook's `model` name; from here on only `model_trt` is used.
del model

In [None]:
import gc

# Force a garbage-collection pass right away. As the last expression of the
# cell, the count returned by gc.collect() is shown as the cell output.
gc.collect()

In [None]:
class img_dataset(torch.utils.data.Dataset):
    """Map-style dataset that serves stored images as normalised tensors.

    Args:
        imgs: Indexable collection of images. Each image is treated as a
            3-D array whose last axis is moved to the front on access
            (presumably height x width x channel with values in 0..255 --
            confirm against the caller).

    Each item is returned as a float32 tensor with axes permuted (2, 0, 1)
    and values rescaled by ``x / 127.5 - 1`` (0 -> -1, 255 -> 1).
    """

    def __init__(self, imgs):
        self.imgs = imgs

    def __len__(self):
        """Return the number of stored images."""
        return len(self.imgs)

    def __getitem__(self, idx):
        """Return the image at ``idx`` as a normalised, channel-first tensor."""
        if torch.is_tensor(idx):
            # torch.Tensor provides `tolist()`; `to_list()` does not exist and
            # raised AttributeError for tensor indices.
            idx = idx.tolist()

        inp = torch.as_tensor(self.imgs[idx], dtype=torch.float32)
        inp = (inp / (255 / 2)) - 1
        return inp.permute((2, 0, 1))

In [None]:
%%time

# --- 1. Decode the noisy input video into a list of RGB frames ---------------
noisy_cap = cv2.VideoCapture(noisy_path)
try:
    fps = noisy_cap.get(cv2.CAP_PROP_FPS)
    width = int(noisy_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(noisy_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    frame_list = []
    while noisy_cap.isOpened():
        ret, frame = noisy_cap.read()
        if not ret:
            break
        # OpenCV decodes to BGR; the frames are stored as RGB.
        frame_list.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
finally:
    # Release the capture exactly once, even if decoding fails part-way.
    noisy_cap.release()

if not frame_list:
    raise RuntimeError(f"No frames could be read from {noisy_path!r}")

# --- 2. Batch the frames for inference ---------------------------------------
img_dataloader = torch.utils.data.DataLoader(
    img_dataset(frame_list),
    # Never exceed `bs`, the max batch size the TensorRT engine was built for.
    batch_size=min(128, bs),
    shuffle=False,  # keep the frames in video order
    num_workers=1,
)

model_trt.eval()

# --- 3. Run the model and encode each batch as soon as it is produced --------
# Encoding batch by batch keeps only one batch of predictions alive at a time,
# instead of holding the predictions for the whole video on the device.
with av.open(pred_save_path, mode='w') as container:
    # cv2 reports the frame rate as a float; hand the encoder a compact
    # rational instead (e.g. 29.97002997... -> 30000/1001).
    stream = container.add_stream(
        'h264_nvenc', rate=Fraction(fps).limit_denominator(10000)
    )
    stream.width = width
    stream.height = height

    with torch.no_grad():
        for ip in tqdm.tqdm(img_dataloader, total=len(img_dataloader)):
            preds = model_trt(ip.to(device))

            # Map [-1, 1] -> [0, 255]. Clamp before the uint8 cast so values
            # slightly outside the range saturate instead of wrapping around.
            frames_u8 = (
                (((preds + 1) / 2).clamp(0, 1) * 255)
                .permute(0, 2, 3, 1)  # NCHW -> NHWC for the video encoder
                .to(torch.uint8)
                .contiguous()
                .cpu()
                .numpy()
            )

            for pred in frames_u8:
                video_frame = av.VideoFrame.from_ndarray(pred, format='rgb24')
                for packet in stream.encode(video_frame):
                    container.mux(packet)

    # Flush the frames still buffered inside the encoder.
    for packet in stream.encode():
        container.mux(packet)

# --- 4. Free what is no longer needed -----------------------------------------
# The loop ran at least once (an empty frame list raises above), so these
# names are guaranteed to exist here.
del ip, preds, frames_u8, pred
del img_dataloader
torch.cuda.empty_cache()