In [None]:
% load_ext autoreload
% autoreload 2

%matplotlib inline

In [None]:
from dataset import *
from model import *

In [None]:
# Test split: image root directory and its annotation CSV (test.csv).
root = os.path.join("..", "data", "test")
anno_path = os.path.join(root, "test.csv")
data_test = fashion_ai_dataset(
    root,
    anno_path,
    is_train=False,  # build the dataset in test (non-training) mode
)

In [None]:
import matplotlib.pyplot as plt
import cv2
import numpy as np

from random import randint
from torchvision.transforms import ToPILImage, Normalize

def show_data(x, y, y_pred=None):
    """Plot an input image next to its keypoint heatmaps, optionally with a prediction.

    Args:
        x: image tensor as produced by the dataset, normalized with the
            ResNet (ImageNet) mean/std; the normalization is undone for display.
        y: tuple ``(heatmaps, mask, size)``. ``heatmaps`` (assumed
            ``(num_keypoints, H, W)`` — TODO confirm) is multiplied by ``mask``
            and drawn un-blurred in the second panel.
        y_pred: optional predicted heatmaps. When given, they are decoded with
            ``HeatmapToKeyPoints`` (the decoded keypoints are printed), masked,
            blurred and drawn in a third panel.

    None of the input tensors are modified.
    """
    to_img = ToPILImage()

    def kpts_img(kpts, blur=True):
        # Collapse all keypoint channels into a single grayscale image.
        # Clamp to [0, 1] first: ToPILImage scales floats by 255 and casts to
        # uint8, so overlapping peaks (> 1) or negative predictions would wrap.
        img = torch.sum(kpts, 0, keepdim=True).clamp(0, 1)
        img = np.array(to_img(img))
        if blur:
            img = cv2.GaussianBlur(img, (5, 5), 1)
        return img

    num_img = 2 if y_pred is None else 3
    plt.figure(figsize=(18, 18 // num_img))

    # Undo the ResNet input normalization. Normalize computes (x - m) / s, so
    # m = -mean / std and s = 1 / std yields x * std + mean.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    denormalize = Normalize(-mean / std, 1 / std)
    # clone(): some torchvision versions normalize in place, which would
    # corrupt the caller's tensor.
    img = denormalize(x.detach().cpu().clone()).clamp(0, 1)
    plt.subplot(1, num_img, 1)
    plt.imshow(to_img(img))

    heatmaps, mask, size = y
    plt.subplot(1, num_img, 2)
    # Out-of-place multiply so the caller's heatmaps are left untouched.
    plt.imshow(kpts_img(heatmaps.detach() * mask, blur=False))

    if y_pred is not None:
        y_pred = y_pred.detach()
        to_kpts = HeatmapToKeyPoints()
        kpts = to_kpts(y_pred.unsqueeze(0), mask.unsqueeze(0), size.unsqueeze(0))
        print(kpts)
        plt.subplot(1, num_img, 3)
        plt.imshow(kpts_img(y_pred * mask))

In [None]:
# 24 is presumably the number of keypoint heatmaps — TODO confirm against model.py.
# eval() switches any BatchNorm/Dropout layers to inference behaviour; this
# notebook only runs inference. The checkpoint is a PyTorch state_dict despite
# the ``.h5`` extension; map_location="cpu" lets it load regardless of the
# device it was saved from (the model is moved to the GPU later).
model = CascadePyramidNet(24).eval()
model.load_state_dict(torch.load("w.h5", map_location="cpu"))

In [None]:
# Batched inference over the test split. The loader is not shuffled, so
# prediction row i corresponds to dataset row i.
test_loader = data.DataLoader(data_test, batch_size=32, num_workers=6)

to_kpts = HeatmapToKeyPoints()
kpt_cols = FashionAIDataset.kpt_names

# eval(): inference-mode BatchNorm/Dropout, independent of earlier cells.
model = model.cuda().eval()

batch_frames = []
for x, (mask, size) in test_loader:
    x, mask, size = x.cuda(), mask.cuda(), size.cuda()
    with torch.no_grad():
        _, heatmap = model(x)
        bat_kpts = to_kpts(heatmap, mask, size)
    bat_kpts = bat_kpts.cpu().numpy()
    # One row per image; each keypoint is serialized as its values joined by "_".
    rows = [["_".join(kpt.astype("str")) for kpt in kpts] for kpts in bat_kpts]
    batch_frames.append(pd.DataFrame(data=rows, columns=kpt_cols))

# Build the frame once: DataFrame.append is removed in pandas 2.0 and is
# quadratic when called in a loop.
if batch_frames:
    result = pd.concat(batch_frames, ignore_index=True)
else:
    result = pd.DataFrame(columns=kpt_cols)

# join() aligns on the index; reset it so the alignment is positional, matching
# the loader order, and make sure every image got exactly one prediction row.
df = data_test.dataset.df[["image_id", "image_category"]].reset_index(drop=True)
assert len(df) == len(result), "expected one prediction row per test image"
result = df.join(result)

In [None]:
# ``result`` already carries image_id/image_category as columns, so the
# positional pandas index is not written as an extra leading column.
result.to_csv("r.csv", index=False)

In [None]:
# Visual sanity check on a random test image. The test split yields
# ``x, (mask, size)``, so the predicted heatmap is prepended to fill
# show_data's ``(heatmaps, mask, size)`` tuple.
idx = randint(0, len(data_test) - 1)
print(data_test.dataset.df.iloc[idx])

x, y = data_test[idx]
model = model.cuda().eval()
with torch.no_grad():
    heatmap = model(x.unsqueeze(0).cuda())[1].squeeze(0).cpu()
show_data(x, (heatmap, *y))