In [None]:
import os
import random
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

from skimage import io
from skimage.metrics import structural_similarity as ssim

In [None]:
# Helper functions
def read_img(img_path):
    """Load an image from disk and return a single 2-D channel.

    Multi-channel images (e.g. RGB/RGBA PNGs) are reduced to their first
    channel; images that are already 2-D (grayscale) are returned unchanged.

    Args:
        img_path: Path to an image file readable by ``skimage.io.imread``.

    Returns:
        A 2-D array holding the pixel values of the first channel.
    """
    img = io.imread(img_path)
    if img.ndim == 2:
        # Already single-channel; indexing a third axis would raise IndexError.
        return img
    return img[:, :, 0]

def show_img(img, ax):
    """Draw ``img`` in grayscale on ``ax`` with all ticks and tick labels hidden.

    Args:
        img: 2-D image array.
        ax: Matplotlib Axes to draw on.
    """
    # Draw on the given axes rather than pyplot's "current" axes, so the image
    # lands on ``ax`` even when it is not the most recently created subplot.
    ax.imshow(img, cmap='gray')
    # Removing the ticks hides both the tick marks and their labels.
    ax.set_xticks([])
    ax.set_yticks([])

In [None]:
# Input: expected layout is root_dir/<model_idx>/<val_idx>/ containing
# '0_cond_img.png' plus the generated samples for that validation input.
root_dir = '../results/week_0101/cond_gif_2_one'
# Output: one summary grid image per model is written here.
save_dir = '../results/week_0101/imgs_for_final_model'
# exist_ok so the cell (and Restart & Run All) can be re-run once the folder exists.
os.makedirs(save_dir, exist_ok=True)

In [None]:
# For each trained model, build a grid with one row per validation input:
# column 0 is the conditioning image, the next `topk` columns are that model's
# generated samples ranked by SSIM against the input (most similar first).
num_val_imgs = 20   # grid rows per figure (validation inputs per model)
topk = 5            # top-k samples to keep per input (ranked by SSIM)
cell_inches = 0.8   # size of one grid cell in the saved figure
# Checkpoint schedule, used only for the optional figure title: model k is
# labelled epoch 1 + k // 3, iteration 10K / 20K / 30K within that epoch.
ckpts_per_epoch = 3
iter_step_k = 10
show_title = False


def _numeric_first_key(name):
    """Sort key ordering digit-only names numerically ('2' before '10'),
    with any other names after them in alphabetical order."""
    return (0, int(name), '') if name.isdigit() else (1, 0, name)


def _topk_by_ssim(ref_img, sample_dir, k):
    """Return up to `k` sample images from `sample_dir`, most SSIM-similar to
    `ref_img` first. Files whose name contains 'cond' and hidden files are skipped."""
    # Sorted listing + stable argsort make the tie order reproducible.
    sample_files = sorted(
        f for f in os.listdir(sample_dir)
        if 'cond' not in f and not f.startswith('.'))
    samples = [read_img(os.path.join(sample_dir, f)) for f in sample_files]
    scores = [ssim(ref_img, sample) for sample in samples]
    best_first = np.argsort(scores, kind='stable')[::-1][:k]
    return [samples[i] for i in best_first]


n_cols = topk + 1
model_dirs = sorted((d for d in os.listdir(root_dir) if d.isdigit()), key=int)
for model_name in tqdm(model_dirs):
    model_idx = int(model_name)
    model_dir = os.path.join(root_dir, model_name)
    val_dirs = sorted(
        (d for d in os.listdir(model_dir) if not d.startswith('.')),
        key=_numeric_first_key)
    assert len(val_dirs) <= num_val_imgs, (
        '%s has %i validation folders but the grid only has %i rows'
        % (model_dir, len(val_dirs), num_val_imgs))

    fig = plt.figure(figsize=(n_cols * cell_inches, num_val_imgs * cell_inches))
    for row, val_idx in enumerate(val_dirs):
        val_img_path = os.path.join(model_dir, val_idx)
        ref_img = read_img(os.path.join(val_img_path, '0_cond_img.png'))
        # Subplot index is derived from (row, col) so a row with fewer than
        # `topk` samples does not shift the following rows.
        row_start = row * n_cols + 1
        show_img(ref_img, fig.add_subplot(num_val_imgs, n_cols, row_start))
        top_samples = _topk_by_ssim(ref_img, val_img_path, topk)
        for col, sample in enumerate(top_samples, start=1):
            show_img(sample, fig.add_subplot(num_val_imgs, n_cols, row_start + col))

    if show_title:
        epoch = 1 + model_idx // ckpts_per_epoch
        iteration_k = iter_step_k * (model_idx % ckpts_per_epoch + 1)
        fig.suptitle("Epoch %i, Iteration %iK" % (epoch, iteration_k), fontsize=10)

    fig.savefig(os.path.join(save_dir, '%02i.png' % model_idx), dpi=200)
    plt.close(fig)