In [None]:
%matplotlib inline

from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

from datasets.davis import DAVISPairDataset

import argparse

# Dataset options. They are parsed from a fixed argument list (rather than
# sys.argv) so the cell runs unchanged inside the notebook.
parser = argparse.ArgumentParser()
parser.add_argument("--root", help='path to DAVIS-like folder')
parser.add_argument("--anno", default="Annotations", help='path to Annotations subfolder (of ROOT)')
parser.add_argument("--jpeg", default="JPEGImages", help='path to JPEGImages subfolder (of ROOT)')
parser.add_argument("--res", default="480p", help='path to Resolution subfolder (of ANNO and JPEG)')
parser.add_argument("--imgset", default="ImageSets", help='path to ImageSet subfolder (of ROOT)')
parser.add_argument("--year", default="2017", help='path to Year subfolder (of IMGSET)')
parser.add_argument("--phase", default="train", help='path to phase txt file (of IMGSET/YEAR)')
parser.add_argument("--mode", type=int, default=0, help='frame pair selector mode')
args = parser.parse_args(['--root', 'data/DAVIS-trainval',
                          '--res', '480p_split',
                          '--mode', '1'])

# Each parsed option maps one-to-one onto a DAVISPairDataset keyword.
dataset = DAVISPairDataset(
    root_path=args.root,
    annotation_folder=args.anno,
    jpeg_folder=args.jpeg,
    resolution=args.res,
    imageset_folder=args.imgset,
    year=args.year,
    phase=args.phase,
    mode=args.mode,
)

# One pair per batch, visited in random order.
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

# Browse the dataset: one 2x2 figure per sample. The top row shows the
# reference (support) frame and its annotation; the bottom row shows the query
# frame and its annotation.
for batch in dataloader:
    # batch[0] unpacks into (support image, support annotation, query image)
    # and batch[1] is the query annotation; zip walks the batch per sample.
    for support_img, support_anno, query_img, query_anno in zip(*batch[0], batch[1]):
        fig, ax = plt.subplots(2, 2)
        # permute(1, 2, 0) moves the leading axis last, which is the
        # (H, W, C) layout imshow expects for colour images.
        ax[0, 0].imshow(support_img.permute(1, 2, 0))
        ax[0, 0].set_title('Reference frame')
        ax[0, 1].imshow(support_anno.squeeze())
        ax[0, 1].set_title('Reference GT')
        ax[1, 0].imshow(query_img.permute(1, 2, 0))
        ax[1, 0].set_title('Query frame')
        ax[1, 1].imshow(query_anno.squeeze())
        ax[1, 1].set_title('Query GT')

        fig.tight_layout()
        plt.show()
        # Close this figure explicitly rather than whatever figure is current.
        plt.close(fig)

In [None]:
%matplotlib inline

import torch
from torch.utils.data import DataLoader
import torchvision
import numpy as np
import matplotlib.pyplot as plt

from utils.getter import get_instance, get_data
from utils.device import move_to

import argparse

# Checkpoint to inspect and an optional GPU index. They are parsed from a
# fixed argument list so the cell runs inside the notebook.
parser = argparse.ArgumentParser()
parser.add_argument('--weight')
parser.add_argument('--gpus', default=None)
args = parser.parse_args(['--weight', 'runs/UNet_Single-2020_05_03-04_23_34/best_loss.pth'])

# Use the requested GPU only when CUDA is present and an index was given;
# otherwise fall back to the CPU.
if torch.cuda.is_available() and args.gpus is not None:
    dev_id = 'cuda:{}'.format(args.gpus)
else:
    dev_id = 'cpu'
device = torch.device(dev_id)

# The checkpoint dict carries the training config under 'config' and the
# weights under 'model_state_dict'; tensors are mapped onto `dev_id`.
config = torch.load(args.weight, map_location=dev_id)
train_cfg = config['config']

model = get_instance(train_cfg['model']).to(device)
model.load_state_dict(config['model_state_dict'])

# get_data returns two loaders. Only the second one (presumably the
# validation loader; confirm in utils.getter) is used below.
_, dataloader = get_data(train_cfg['dataset'], train_cfg['seed'])

def visualize_attention(m, i, o):
    """Forward hook that plots the attention maps stored on module ``m``.

    ``i`` is the tuple of positional inputs to the hooked module and ``o`` is
    its output (unused). The spatial size ``H x W`` is read from the first
    input. ``m.attn_score`` is averaged over dim 1 and reshaped into one
    ``H x W`` map per leading entry.

    Each map gets its own axes in a new figure. The figure is displayed by
    the next ``plt.show()``.

    NOTE(review): assumes the hooked module sets ``attn_score`` during its
    forward pass and that its element count is a multiple of ``H * W``.
    Confirm against the model definition.
    """
    _, _, H, W = i[0].shape
    # detach() so imshow also works if the forward pass runs with autograd on.
    attns = m.attn_score.mean(dim=1).reshape(-1, H, W).detach().cpu()
    if len(attns) == 0:
        return
    # One axes per map. Calling plt.imshow repeatedly on the current axes
    # would draw every map on top of the previous one, so only the last map
    # would be visible.
    fig, axes = plt.subplots(1, len(attns), squeeze=False)
    for k, (ax, attn) in enumerate(zip(axes[0], attns)):
        ax.imshow(attn, vmin=0.0)
        ax.set_title('Attention {}'.format(k))
    fig.tight_layout()


# Keep the handle so the hook can be detached later with attn_hook.remove().
attn_hook = model.middle_conv.register_forward_hook(visualize_attention)

# Run the model over the loader. For each sample, plot:
#   - the reference frame with its GT,
#   - the query frame with the prediction,
#   - the false-negative / false-positive maps against the query GT.
model.eval()
with torch.no_grad():
    for batch in dataloader:
        output = model(move_to(batch[0], device))
        # Assumes output is (B, C, H, W) scores; argmax over dim 1 gives one
        # label map per sample. Moved to the CPU to match batch[1] for
        # comparison and plotting.
        preds = torch.argmax(output, dim=1).cpu()
        for support_img, support_anno, query_img, query_anno, pred in zip(*batch[0], batch[1], preds):
            print('=' * 60)

            # Foreground means any non-zero label. Explicit comparisons replace
            # the bitwise `~`, which on integer label maps is two's-complement
            # NOT (~1 == -2). That only yields a valid mask for binary {0, 1}
            # labels and raises on float masks. For binary labels the result
            # is unchanged.
            gt_fg = query_anno != 0
            pred_fg = pred != 0
            false_neg = gt_fg & ~pred_fg
            false_pos = ~gt_fg & pred_fg

            fig, ax = plt.subplots(2, 2)
            ax[0, 0].imshow(support_img.permute(1, 2, 0))
            ax[0, 0].imshow(support_anno.squeeze(0), alpha=0.2)
            ax[0, 0].set_title('Reference frame + GT')

            ax[1, 0].imshow(query_img.permute(1, 2, 0))
            ax[1, 0].imshow(pred.squeeze(0), alpha=0.2)
            ax[1, 0].set_title('Query frame + Pred')

            ax[0, 1].imshow(false_neg.squeeze(0))
            ax[0, 1].set_title('False Negative')
            ax[1, 1].imshow(false_pos.squeeze(0))
            ax[1, 1].set_title('False Positive')

            fig.tight_layout()
            plt.show()
            # Close this figure explicitly rather than whatever figure is current.
            plt.close(fig)

In [None]:
%matplotlib inline

import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torchvision
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (30,10)

from utils.getter import get_instance, get_data
from utils.device import move_to

import argparse

# Checkpoint to inspect and GPU index. They are parsed from a fixed argument
# list so the cell runs inside the notebook.
parser = argparse.ArgumentParser()
parser.add_argument('--weight')
parser.add_argument('--gpus', default=None)
args = parser.parse_args(['--weight', 'backup/STMWithSTN-2020_05_11-22_43_49/best_loss.pth',
                          '--gpus', '0'])

# Use the requested GPU only when CUDA is present and an index was given;
# otherwise fall back to the CPU.
if torch.cuda.is_available() and args.gpus is not None:
    dev_id = 'cuda:{}'.format(args.gpus)
else:
    dev_id = 'cpu'
device = torch.device(dev_id)

# The checkpoint dict carries the training config under 'config' and the
# weights under 'model_state_dict'; tensors are mapped onto `dev_id`.
config = torch.load(args.weight, map_location=dev_id)
train_cfg = config['config']

model = get_instance(train_cfg['model']).to(device)
model.load_state_dict(config['model_state_dict'])

# Force shuffling of the validation loader, overriding the saved config value.
val_loader_args = train_cfg['dataset']['val']['loader']['args']
val_loader_args['shuffle'] = True
# config['config']['dataset']['val']['args']['shuffle'] = True
_, dataloader = get_data(train_cfg['dataset'], train_cfg['seed'])

# Disabled debugging snippet, kept for reference. It compared a reference
# frame with its model.stn-warped version and plotted their absolute
# difference.
# model.eval()
# with torch.no_grad():
#     for idx, batch in enumerate(dataloader):
#         a_im, a_seg, b_im, c_im, nobjs = batch[0]
#         c_seg = batch[1]
#
#         a_seg = F.one_hot(a_seg, 11).permute(0, 3, 1, 2)
#         a_im_stn = model.stn(a_im, a_seg)
#
#         fig, ax = plt.subplots(3, 1)
#         ax[0].imshow(a_im[0].permute(1, 2, 0))
#         ax[1].imshow(a_im_stn[0].permute(1, 2, 0))
#         ax[2].imshow(torch.abs(a_im[0] - a_im_stn[0]).mean(0))
#         plt.show()
#         plt.close()
# #         output = model.stn(move_to(batch[0], device))
# #         break

# Run the STM model over the (shuffled) validation loader. For each sample,
# plot:
#   - the reference frame with its GT,
#   - the query frame with the prediction,
#   - the false-negative / false-positive maps against the query GT.
model.eval()
with torch.no_grad():
    for batch in dataloader:
        output = model(move_to(batch[0], device))
        # Assumes output is (B, C, H, W) scores; argmax over dim 1 gives one
        # label map per sample. Moved to the CPU to match batch[1] for
        # comparison and plotting.
        preds = torch.argmax(output, dim=1).cpu()
        # batch[0] unpacks into:
        #   (reference image, reference mask, <unused input>, query image, object count).
        # batch[1] is the query GT. The unused slots are bound to `_`.
        for support_img, support_anno, _, query_img, _, query_anno, pred in zip(*batch[0], batch[1], preds):
            print('=' * 60)

            # Label maps here are multi-object (ids beyond 1), so bitwise `~`
            # is wrong: ~1 == -2, and e.g. 3 & ~1 == 2 would mark a pixel as
            # missed. Foreground means any non-zero id, compared explicitly.
            # For binary {0, 1} labels this matches `gt & ~pred` / `~gt & pred`.
            gt_fg = query_anno != 0
            pred_fg = pred != 0
            false_neg = gt_fg & ~pred_fg
            false_pos = ~gt_fg & pred_fg

            fig, ax = plt.subplots(2, 2)
            ax[0, 0].imshow(support_img.permute(1, 2, 0))
            ax[0, 0].imshow(support_anno.squeeze(0), alpha=0.2)
            ax[0, 0].set_title('Reference frame + GT')

            ax[1, 0].imshow(query_img.permute(1, 2, 0))
            ax[1, 0].imshow(pred.squeeze(0), alpha=0.2)
            ax[1, 0].set_title('Query frame + Pred')

            ax[0, 1].imshow(false_neg.squeeze(0))
            ax[0, 1].set_title('False Negative')
            ax[1, 1].imshow(false_pos.squeeze(0))
            ax[1, 1].set_title('False Positive')

            fig.tight_layout()
            plt.show()
            # Close this figure explicitly rather than whatever figure is current.
            plt.close(fig)

In [172]:
from IPython.display import Video
# Sequence whose rendered result videos are displayed by the following cells.
video_name = "deer"

In [174]:
# Rendered result from the `STM_DAVIS_17test-dev_` folder. Note the trailing
# underscore; this is presumably a second run kept for comparison (confirm).
Video('viz/STM_DAVIS_17test-dev_/{}.mp4'.format(video_name))

In [173]:
# Rendered result from the `STM_DAVIS_17test-dev` folder.
Video('viz/STM_DAVIS_17test-dev/{}.mp4'.format(video_name))