In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from avcv.all import *
from models import get_stn
from utils.download import download_model, PRETRAINED_TEST_HYPERPARAMS
from utils.vis_tools.helpers import load_pil, save_image

# Key passed to download_model() and used to index PRETRAINED_TEST_HYPERPARAMS in later cells.
model_class = 'ir_face'  # choose the class you want to use
# Square side length (pixels): input images are resized to it (load_pil via DS),
# and it is passed to the STN as `supersize` and `output_resolution`.
resolution = 256  # resolution the input image will be resized to (can be any power of 2)
# image_path = 'my_image.jpeg'  # path to image you want to align



  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def load_pil(path, resolution=None):
    """Read an image as grayscale and return it as a normalized 3-channel tensor.

    The single grayscale plane is copied into three identical channels, the
    result is optionally resized to ``resolution x resolution`` with Lanczos
    filtering, and pixel values are mapped from [0, 255] to [-1, 1].

    Note: this definition shadows the ``load_pil`` imported from
    ``utils.vis_tools.helpers`` in the imports cell.

    Args:
        path: image file path understood by ``mmcv.imread``.
        resolution: target side length, or None to keep the original size.

    Returns:
        float32 CPU tensor of shape (1, 3, H, W).
    """
    gray = mmcv.imread(path, 0)  # flag 0 == cv2.IMREAD_GRAYSCALE -> 2-D array
    pil_img = Image.fromarray(np.stack((gray, gray, gray), axis=-1))
    if resolution is not None:
        pil_img = pil_img.resize((resolution, resolution), Image.LANCZOS)
    pixels = np.asarray(pil_img)
    tensor = torch.tensor(pixels, device='cpu', dtype=torch.float)
    # HWC -> NCHW with a leading batch axis of size 1.
    tensor = tensor.unsqueeze(0).permute(0, 3, 1, 2)
    return tensor.div(255.0).sub(0.5).mul(2)  # [0, 255] -> [-1, 1]

In [4]:
class DS:
    """Minimal map-style dataset yielding normalized face-crop tensors.

    Each item is ``load_pil(path, size)[0]``: the single image with the leading
    batch axis removed, so a ``DataLoader`` can re-batch the items.

    Args:
        paths: indexable sequence of image file paths.
        resolution: side length each image is resized to. When omitted (None),
            the notebook-level ``resolution`` global is read at load time,
            which matches the previous behaviour of this class.
    """

    def __init__(self, paths, resolution=None):
        self.paths = paths
        self.resolution = resolution

    def __getitem__(self, idx):
        # Only fall back to the notebook-wide setting when no explicit size was given.
        size = resolution if self.resolution is None else self.resolution
        return load_pil(self.paths[idx], size)[0]

    def __len__(self):
        return len(self.paths)

In [5]:
# df = pd.read_csv('/data/RLDD/')

In [66]:
from models import total_variation_loss

@torch.inference_mode()
def compute_flow_scores(batch, t, device='cuda'):
    """Return the total-variation penalty of the STN's predicted flow per image.

    Runs ``t`` once on ``batch`` with ``return_flow=True`` and returns
    ``total_variation_loss(flows, reduce_batch=False)``, i.e. the penalty is
    not reduced over the batch (presumably one value per image; TODO confirm
    in ``models.total_variation_loss``). The caller in this notebook negates
    the result so that lower (more negative) scores indicate worse images.

    Args:
        batch: image tensor accepted by ``t``; assumed (N, C, H, W) in [-1, 1]
            as produced by ``load_pil`` (TODO confirm).
        t: spatial transformer; must accept the ``return_flow``, ``iters`` and
            ``padding_mode`` keyword arguments and return ``(output, flows)``.
        device: device the batch is moved to before the forward pass; it
            should match the device ``t`` lives on. Defaults to ``'cuda'``.

    Returns:
        Tensor of un-reduced flow smoothness penalties on ``device``.
    """
    batch = batch.to(device)
    _, flows = t(batch, return_flow=True, iters=1, padding_mode='border')
    return total_variation_loss(flows, reduce_batch=False)

In [6]:
def unnorm(img):
    """Convert a (C, H, W) tensor into an (H, W, C) uint8 array for display.

    Values are min-max rescaled over the whole tensor (all channels jointly)
    to [0, 255] and truncated to uint8. A constant tensor, where min-max
    scaling would divide 0 by 0 and produce NaN, is mapped to an all-zero
    (black) image instead.

    Args:
        img: 3-D tensor in channel-first layout; any device and any range.

    Returns:
        numpy uint8 array of shape (H, W, C).
    """
    img = img.detach()  # .numpy() refuses tensors that require grad
    lo = img.min()
    span = img.max() - lo
    shifted = img - lo
    # Guard the constant-image case so we never compute 0 / 0.
    scaled = shifted / span if span > 0 else shifted
    scaled = scaled * 255
    return scaled.permute(1, 2, 0).cpu().numpy().astype('uint8')

In [7]:
# Fetch the pretrained checkpoint selected by `model_class`.
ckpt = download_model(model_class)  # download model weights
# STN composed of a 'similarity' stage followed by a 'flow' stage; flow_size and
# supersize are forwarded to get_stn (their exact semantics live in models/).
stn = get_stn(['similarity', 'flow'], flow_size=128, supersize=resolution).to('cuda')  # instantiate STN
# Weights are taken from the checkpoint's 't_ema' entry (presumably the EMA copy of the transformer).
stn.load_state_dict(ckpt['t_ema'])  # load weights
test_kwargs = PRETRAINED_TEST_HYPERPARAMS[model_class]  # load test-time hyperparameters
# NOTE(review): stn.eval() is never called; confirm the STN has no train/eval-dependent
# layers (dropout, batch-norm) before trusting the inference outputs below.


In [88]:
import torch

# Every cropped-face frame of a single recording.
# NOTE(review): absolute path on the author's machine; consider a DATA_DIR constant.
paths = glob('/data/DMS_Drowsiness/all_video_symlink/4999d3cbd30f082fedae237fd867814b/croped_faces/*')
print(len(paths))

# glob order is not guaranteed, so sort by filename first. Then keep one window
# of 1% of the frames, starting 60% of the way through the sorted list.
paths = sorted(paths)
l = len(paths) // 100  # window length; 0 (empty window) if there are fewer than 100 frames
s = 60 * l             # window start index
paths = paths[s:l + s]

ds = DS(paths)
dl = torch.utils.data.DataLoader(ds, batch_size=10, num_workers=10)

90209


In [89]:
len(dl)  # number of batches = ceil(len(ds) / batch_size), since drop_last is left False

91

In [90]:
outs = []
scores = []
with torch.no_grad():
    pbar = tqdm(dl, total=len(dl))
    for input_img in pbar:
        # Align the batch with the pretrained STN using its test-time hyperparameters.
        aligned_img = stn.forward(input_img.cuda(), output_resolution=resolution, **test_kwargs)
        # Turn every aligned (C, H, W) tensor into a displayable uint8 (H, W, C) frame.
        aligned_img = list(map(unnorm, aligned_img))
        outs += aligned_img
        # A second, separate STN pass whose only purpose is scoring flow smoothness.
        scores.append(compute_flow_scores(input_img, stn))
# Negate the per-image penalties: lower (more negative) scores mean worse images.
scores = -torch.cat(scores)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:09<00:00,  9.69it/s]


In [91]:
# Min-max scale the scores to [0, 1] for colouring the overlay text.
# clamp_min guards the degenerate case where every score is equal: the numerator
# is then all zeros, so the result is 0 instead of 0/0 = NaN (NaN would make
# int(255 * score) in the overlay cell raise ValueError).
norm_scores = (scores - scores.min()) / (scores.max() - scores.min()).clamp_min(1e-12)

In [96]:
# Build side-by-side [original | aligned] frames annotated with their normalized score.
cat_ims = []
# Copy the scores to the host once instead of forcing a GPU sync for every frame
# (each int()/format of a CUDA scalar would otherwise trigger its own transfer).
for imwarp, score, path in zip(outs, norm_scores.cpu(), paths):
    img = mmcv.imread(path, channel_order='bgr')
    img = mmcv.imresize_like(img, imwarp)
    cat_im = np.concatenate([img, imwarp], 1)  # concatenate along width
    # Text colour fades red -> green as the score goes 0 -> 1 (assuming BGR order).
    green = int(255 * score)
    cat_im = put_text(cat_im, (10, 40), f'{score*100:0.2f}', (0, green, 255 - green))
    cat_ims.append(cat_im)

In [97]:
# Write the annotated side-by-side frames to vis.mp4. (320, 160) keeps the 2:1 aspect
# of each [original | aligned] frame; presumably (width, height); confirm against
# avcv.utils.images_to_video.
images_to_video(cat_ims, 'vis.mp4', output_size=(320,160))

2022-07-07 04:57:08.970 | INFO     | avcv.utils:images_to_video:283 - Write video, output_size: (320, 160)


[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 902/902, 794.5 task/s, elapsed: 1s, ETA:     0s

2022-07-07 04:57:10.107 | INFO     | avcv.utils:images_to_video:293 - -> /home/anhvth8/gitprojects/gangealing/vis.mp4
