In [None]:
import csv
import json
import logging
import math
import os
import pathlib
import pickle as pkl
import traceback
from copy import deepcopy
from itertools import islice
from pathlib import Path
from threading import BoundedSemaphore, Thread
from time import time

import cv2
import matplotlib.pyplot as plt
import numpy as np
import piq
import skvideo.io
import torch
import torchvision.transforms as T

import sal_ssim


# Default size (inches) for every matplotlib figure drawn in this notebook.
plt.rcParams.update({"figure.figsize": (10, 10)})

In [2]:
# Dataset root; json_representation_of_dataset reads its ref/, masks/ and seq/ folders.
DATASET = Path('/home/a_mos/2019')

# Everything produced by a run is written under OUTPUT.
OUTPUT = Path('results_16_08')
RESULTS, ERRORS = OUTPUT / 'raw.csv', OUTPUT / 'errors.txt'

# Column names, in the order compute() writes its result rows.
HEADERS = ['Name', 'Reference', 'Distorted', 'Metric', 'Metric_val', 'Mask', 'Time']

In [3]:
class VideoReader:
    """Iterate over a video file in batches of frames scaled to [0, 1].

    Each yielded batch is a float32 ``torch.Tensor`` of shape ``(n, C, H, W)``
    holding frames divided by 255. ``n`` equals ``batch_size``, except that the
    last batch may be shorter. Frames come from a one-shot
    ``skvideo.io.vreader`` generator, so an instance can be iterated only once.

    Args:
        path: video file. Paths ending in ``yuv`` are read as raw ``yuvj420p``
            with height 1080 and width 1920.
        shape: optional ``(T, H, W, C)`` used instead of probing the file with
            ``FFmpegReader``. At most ``T`` frames are read, and any frame whose
            shape differs from ``(H, W, C)`` is resized to ``(H, W)``.
        batch_size: number of frames per batch.
        verbose: print a progress line whenever a batch's starting frame offset
            is a multiple of ``time_out``.
    """

    def __init__(self, path, shape=None, batch_size=10, verbose=False):
        if str(path).endswith('yuv'):
            # Positional arguments after the path are (height, width) for vreader.
            self.vr = skvideo.io.vreader(path, 1080, 1920, inputdict={'-pix_fmt': 'yuvj420p'})
        else:
            self.vr = skvideo.io.vreader(path)
        self.batch_size = batch_size
        if not shape:
            # NOTE(review): probing a raw .yuv without `shape` may require '-s'
            # in FFmpegReader's inputdict — confirm before relying on it.
            probe = skvideo.io.FFmpegReader(path)
            try:
                shape = probe.getShape()
            finally:
                probe.close()  # terminate the ffmpeg subprocess used for probing
        self.shape = tuple(int(x) for x in shape)
        self.len = self.shape[0]
        self.pos = 0  # number of frames consumed from self.vr so far
        self.time_out = 10
        self.verbose = verbose

    def __iter__(self):
        """Yield ``(n, C, H, W)`` batches until ``self.len`` frames are read or the stream ends."""
        for batch_start in range(0, self.len, self.batch_size):
            if self.verbose and batch_start % self.time_out == 0:
                print(f'batch {batch_start // self.batch_size} / {self.len // self.batch_size}')

            batch_size = min(self.batch_size, self.len - batch_start)
            shape = (batch_size, *self.shape[1:])
            batch = np.zeros(shape)

            # islice takes exactly `batch_size` frames, so no frame is pulled
            # from the generator and thrown away at a batch boundary.
            n_read = 0
            for n_read, frame in enumerate(islice(self.vr, batch_size), start=1):
                if frame.shape != self.shape[1:]:
                    frame = cv2.resize(frame, (shape[2], shape[1]))
                batch[n_read - 1, ...] = frame / 255

            self.pos = batch_start + n_read
            if n_read == 0:
                return  # the stream ended before the probed/declared frame count
            # Drop unfilled rows rather than yielding all-zero padding frames.
            yield torch.Tensor(batch[:n_read]).permute(0, 3, 1, 2)
            
class MyWriter:
    """Append-only writer for a CSV results file or a plain-text log.

    The file, and any missing parent directories, are created on construction.
    Every write opens the file in append mode and closes it again, so no handle
    is held between calls.
    """

    def __init__(self, f):
        self.filename = f
        parent = os.path.dirname(os.fspath(self.filename))
        if parent:
            os.makedirs(parent, exist_ok=True)
        # Create the file if it is missing; the with-block closes the handle at once.
        with open(self.filename, 'a'):
            pass

    def write_row(self, row):
        """Append one CSV record built from the iterable ``row``."""
        # The csv module requires newline='' so that it controls line endings itself.
        with open(self.filename, 'a', newline='') as f:
            csv.writer(f).writerow(row)

    def write_text(self, text):
        """Append ``text`` as its own newline-terminated line and echo it to stdout."""
        with open(self.filename, 'a') as f:
            f.write(text if text.endswith('\n') else text + '\n')
        print(text)
            
def folder2name(folder):
    """Return the folder's base name with its last ``_``-separated part removed.

    Example: ``'/data/seq/crowd_run_x265/'`` -> ``'crowd_run'``. A base name
    that contains no underscore yields ``''``.
    """
    base = os.path.basename(os.path.normpath(folder))
    head, _sep, _tail = base.rpartition("_")
    return head

class CommandHandler:
    """Run callables on background threads, at most ``max_processes`` at a time.

    ``run`` blocks the caller until a slot is free, then starts a new thread.
    """

    def __init__(self, max_processes=2):
        self.maxp = max_processes
        self.sem = BoundedSemaphore(self.maxp)

    def run(self, target, args):
        """Start ``target(*args)`` on a new thread once a slot is available.

        The slot is released when ``target`` finishes, whether it returns or
        raises, so a failing job cannot permanently shrink the pool.

        Returns:
            The started ``threading.Thread``, which callers may ``join``.
        """
        def func_with_end(job_args):
            try:
                target(*job_args)
            finally:
                self.sem.release()

        self.sem.acquire()
        try:
            thread = Thread(target=func_with_end, args=(args,))
            thread.start()
        except BaseException:
            # The worker never ran, so nothing else will release this slot.
            self.sem.release()
            raise
        return thread

In [4]:
def _ssim_of_video(ref, dis, mask, metric, mask_mode, verbose):
    """Score ``dis`` against ``ref`` batch by batch and return the mean of the finite scores.

    With a truthy ``mask_mode``, each batch is scored by
    ``sal_ssim.ssim(ref, dis, mask)``; otherwise by ``piq.ssim(ref, dis)``.
    The raw per-batch values are pickled to
    ``OUTPUT / 'raw_values' / '{file name of dis}_{metric}_{mask_mode}'``.
    """
    vr_dis = VideoReader(dis, verbose=verbose)
    # The reference (and mask) are read at the distorted video's shape.
    vr_ref = VideoReader(ref, shape=vr_dis.shape)

    ssim_vals = []
    if mask_mode:
        vr_mask = VideoReader(mask, shape=vr_dis.shape)
        for batch_ref, batch_dis, batch_mask in zip(vr_ref, vr_dis, vr_mask):
            if verbose:
                print(batch_ref.shape, batch_dis.shape, batch_mask.shape)
                print(batch_ref.mean(), batch_dis.mean(), batch_mask.mean())
            val = sal_ssim.ssim(batch_ref, batch_dis, batch_mask)
            if verbose:
                print(val)
            ssim_vals.append(val)
    else:
        for batch_ref, batch_dis in zip(vr_ref, vr_dis):
            if verbose:
                print(batch_ref.shape, batch_dis.shape)
                print(batch_ref.mean(), batch_dis.mean())
            val = piq.ssim(batch_ref, batch_dis)
            if verbose:
                print(val)
            ssim_vals.append(val)

    raw_dir = OUTPUT / 'raw_values'
    raw_dir.mkdir(parents=True, exist_ok=True)
    # Use only the file name of `dis`: joining an absolute path onto OUTPUT
    # would discard OUTPUT and write next to the dataset instead.
    with open(raw_dir / f'{Path(dis).name}_{metric}_{mask_mode}', 'wb') as f:
        pkl.dump(ssim_vals, f)

    # Drop NaN/inf batch scores before averaging.
    ssim_vals = np.array(ssim_vals)
    ssim_vals = ssim_vals[np.isfinite(ssim_vals)]
    return float(ssim_vals.mean())


def compute(seq, ref, dis, mask, metric, mask_mode, iteration, length, verbose):
    """Score one distorted video with SSIM and record the result.

    If the notebook-level ``IDLE`` flag is truthy, no video is read and the
    placeholder value 0.6666666 is recorded. Otherwise the videos are scored
    with masked SSIM (truthy ``mask_mode``) or plain SSIM. ``metric`` is only a
    label: the computation is always SSIM.

    A row ``[seq, ref, dis, metric, value (4 d.p.), mask_mode, seconds]`` is
    appended via ``writer_results`` and a progress line is printed. Any
    exception is printed and logged via ``writer_errors`` instead of being
    raised, so one bad video does not stop the other threaded jobs.

    Args:
        seq: sequence name.
        ref: path of the reference video.
        dis: path of the distorted video.
        mask: path of the mask video.
        metric: metric label written to the results.
        mask_mode: whether to use the mask.
        iteration: 0-based index of this job, used for the progress line.
        length: total number of jobs, used for the progress line.
        verbose: print inputs, per-batch details and video-reading progress.
    """
    if verbose:
        print(seq, ref, dis, mask, metric, mask_mode, iteration, length)

    time_start_video = time()
    try:
        # IDLE is an optional notebook-level switch for dry runs; it defaults to off.
        if not globals().get('IDLE', False):
            metric_value = _ssim_of_video(ref, dis, mask, metric, mask_mode, verbose)
        else:
            metric_value = 0.6666666  # placeholder recorded during dry runs

        time_calc = int(time() - time_start_video)
        row = [seq, ref, dis, metric, round(metric_value, 4), mask_mode, time_calc]
        writer_results.write_row(row)

        done = iteration + 1
        percent = round(100 * done / length) if length else 100
        print("%3d%% [%d/%d] :: %s   %s   %-5s   %02d:%02d   %-70s " %
              (percent, done, length, round(metric_value, 3), metric, mask_mode,
               time_calc // 60, time_calc % 60, os.path.basename(str(dis))))
    except Exception:
        print(traceback.format_exc())
        writer_errors.write_text(
            f'ERROR:  seq: {seq}, ref: {ref}, dis: {dis}, mask: {mask}, '
            f'metric: {metric}, mask_mode: {mask_mode}\n'
        )

In [None]:
def listdir(dir_name: Path):
    """Return the sorted full paths of the entries in ``dir_name``.

    Any entry whose name contains ``.ipynb`` (for example Jupyter's
    ``.ipynb_checkpoints``) is skipped.
    """
    entries = (dir_name / entry for entry in os.listdir(dir_name) if '.ipynb' not in entry)
    return sorted(entries)

def _find_by_sequence_name(candidates, sequence_name, kind):
    """Return, as ``str``, the first path in ``candidates`` whose file name contains ``sequence_name``.

    Raises:
        FileNotFoundError: if no candidate matches.
    """
    for candidate in candidates:
        if sequence_name in candidate.name:
            return str(candidate)
    raise FileNotFoundError(f'no {kind} file found for sequence {sequence_name!r}')


def json_representation_of_dataset(dataset_path):
    """Index the dataset as ``{sequence: {'mask': str, 'ref': str, 'dis': [str, ...]}}``.

    Layout read below:

    * ``dataset_path/seq/{name}_x265/`` holds the distorted videos of one sequence.
    * ``dataset_path/ref/`` holds the reference videos.
    * ``dataset_path/masks/`` holds the mask videos.

    A sequence's reference and mask are the first files (in sorted order) whose
    *file name* contains the sequence name. Matching on the file name, not the
    full path, keeps directory names from causing false matches. Non-directory
    entries in ``seq/`` are ignored.

    Raises:
        FileNotFoundError: if a sequence has no matching reference or mask file.
    """
    dataset_path = Path(dataset_path)
    # Listed once here rather than once per sequence.
    references = listdir(dataset_path / 'ref')
    masks = listdir(dataset_path / 'masks')

    json_data = {}
    for sequence_dir in listdir(dataset_path / 'seq'):
        if not sequence_dir.is_dir():
            continue  # stray files next to the per-sequence folders
        sequence_name = sequence_dir.name.removesuffix('_x265')
        json_data[sequence_name] = {
            'mask': _find_by_sequence_name(masks, sequence_name, 'mask'),
            'ref': _find_by_sequence_name(references, sequence_name, 'reference'),
            'dis': [str(file) for file in listdir(sequence_dir)],
        }
    return json_data

def make_runconfig(dataset, metrics, mask_modes):
    """Expand a dataset index into the experiments to run; ``dataset`` is not modified.

    Each sequence keeps its non-``'dis'`` entries (e.g. ``'ref'``, ``'mask'``),
    deep-copied, and gains:

    * ``'exp'``: a flat list of ``{'dis', 'metric', 'mask_mode'}`` dicts, ordered
      metric, then mask mode, then distorted video (innermost).
    * ``'dis'``: a mapping from distorted path to
      ``[{'metric', 'mask_mode'}, ...]``, the shape read by the job-counting and
      job-launching cells of this notebook.
    """
    regimes = [{'metric': metric, 'mask_mode': mask_mode}
               for metric in metrics for mask_mode in mask_modes]

    runconfig = {}
    for seq, entry in dataset.items():
        seq_config = {key: deepcopy(value) for key, value in entry.items() if key != 'dis'}
        seq_config['exp'] = [{'dis': dis, **regime} for regime in regimes for dis in entry['dis']]
        # A fresh list of fresh dicts per video, so editing one entry cannot affect another.
        seq_config['dis'] = {dis: [dict(regime) for regime in regimes] for dis in entry['dis']}
        runconfig[seq] = seq_config
    return runconfig

In [None]:
# Every distorted video is scheduled once per (metric, mask mode) combination.
metrics, mask_modes = ['ssim'], [True, False]

dataset = json_representation_of_dataset(DATASET)
global_runconfig = make_runconfig(dataset, metrics=metrics, mask_modes=mask_modes)

In [None]:
runconfig_preview = json.dumps(global_runconfig, indent=4)
print(runconfig_preview[:1000])  # first 1000 characters only, to keep the cell output short

In [8]:
command_handler = CommandHandler(2)
writer_results = MyWriter(RESULTS)
writer_errors = MyWriter(ERRORS)
# Write the CSV header exactly once: only while the results file is still empty.
if os.path.getsize(RESULTS) == 0:
    writer_results.write_row(HEADERS)

all_sequences = np.array(['crowd_run', 'kayak_trip', 'tractor', 'making_alcohol', 'wedding_party'])
sequences = all_sequences[[-3]]  # selects ['tractor']
runconfig = {seq: global_runconfig[seq] for seq in sequences}
# One job per 'exp' entry (distorted video x metric x mask mode).
length = sum(len(runconfig[seq]['exp']) for seq in runconfig)
print(length)

216


In [12]:
# print(json.dumps(runconfig, indent=4))

In [24]:
# 'exp' lists each distorted video once per (metric, mask mode), with videos
# innermost, so index 60 (below the number of videos) is the 61st video.
runconfig['tractor']['exp'][60]['dis']

'/home/a_mos/2019/seq/tractor_x265/enc_res_t265_mv_subjective_tractor_short_4000.mkv'

In [10]:
# IDLE = False

# iteration = 0
# for seq in runconfig:
#     for dis in runconfig[seq]['dis']:
#         for regime in runconfig[seq]['dis'][dis]:
#             ref = runconfig[seq]['ref']
#             mask = runconfig[seq]['mask']
#             metric = regime['metric']
#             mask_mode = regime['mask_mode']
            
#             args = (seq, ref, dis, mask, metric, mask_mode, iteration, length, True)
#             command_handler.run(target=compute, args=args)
            
#             iteration += 1