In [None]:
import os
from PIL import Image, ImageFilter
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
# Import Annotation3D from matplotlib for 3d scatter plot
from mpl_toolkits.mplot3d.proj3d import proj_transform

from collections import defaultdict
import json
import torch
import torchvision
from sklearn.manifold import MDS
from tqdm import tqdm

# Directory holding the SVG versions of the flags.
ROOT = os.path.join("images", "svg")
# https://github.com/linssen/country-flag-icons

# os.listdir() returns entries in arbitrary order, so the listing is sorted to
# keep the position of each file stable across runs and machines.
image_paths = [os.path.join(ROOT, nm) for nm in sorted(os.listdir(ROOT))]

# Directory holding the PNG flags, one file per country.
COUNTRIES_PATH = os.path.join("images", "png", "countries")
# Sorted for the same reason: the index of a path in this list is used as the
# row/column index of the matrices computed from it.
png_paths = [os.path.join(COUNTRIES_PATH, nm) for nm in sorted(os.listdir(COUNTRIES_PATH))]

In [None]:
# Load the code -> country-name lookup table. JSON text is UTF-8, so the
# encoding is given explicitly instead of relying on the platform default
# (which would garble non-ASCII names on some systems).
with open("json/code2country.json", "r", encoding="utf-8") as f:
    code2country = json.load(f)

# Reverse lookup: country name -> code. If two codes share a name, the one
# that appears later in the file wins.
country2code = {v: k for k, v in code2country.items()}

code2country

In [None]:
def get_mse_matrix_vectorized(png_paths, size=66):
    """Return the normalised pairwise mean-squared-error matrix of the flags.

    Parameters
    ----------
    png_paths : sequence of str
        Paths of the N flag images to compare.
    size : int, default 66
        Edge length the flags are resized to before comparison (passed on to
        ``get_diff_matrix_vectorized``).

    Returns
    -------
    numpy.ndarray, shape (N, N)
        Entry ``[i, j]`` is the squared pixel difference between flag ``i``
        and flag ``j`` averaged over channels, rows and columns, divided by
        the largest entry of the matrix. When the largest entry is 0 (a single
        flag, or all flags identical) the all-zero matrix is returned as is
        rather than dividing 0 by 0.
    """
    n_flags = len(png_paths)

    # (N*N, C, size, size): block i*N + j holds flag i minus flag j.
    squared_diff = get_diff_matrix_vectorized(png_paths, size)
    # Square in place: the array is freshly created by the helper, and a
    # second N*N-sized temporary would be costly.
    np.square(squared_diff, out=squared_diff)

    # Average over channels, rows and columns, then fold the pair axis to (N, N).
    mse_matrix = squared_diff.mean(axis=(1, 2, 3)).reshape(n_flags, n_flags)

    peak = mse_matrix.max() if mse_matrix.size else 0
    if peak == 0:
        # Nothing to normalise; avoids 0 / 0 -> NaN.
        return mse_matrix
    return mse_matrix / peak

def get_diff_matrix_vectorized(png_paths, size=66, hflip=False, vflip=False):
    """Return the pixel-wise difference between every ordered pair of flags.

    Each image is opened, converted to RGB (so palette, greyscale and RGBA
    files are all handled), resized to ``size`` x ``size`` and stored
    channel-first. If any pixel value exceeds 1 the whole stack is divided by
    255.

    Parameters
    ----------
    png_paths : sequence of str
        Paths of the N flag images.
    size : int, default 66
        Edge length, in pixels, every flag is resized to.
    hflip : bool, default False
        Reverse axis 2 (the image rows, i.e. swap top and bottom) of the
        second flag of each pair before subtracting. The name is kept for
        backward compatibility.
    vflip : bool, default False
        Reverse axis 3 (the image columns, i.e. swap left and right) of the
        second flag of each pair before subtracting. The name is kept for
        backward compatibility.

    Returns
    -------
    numpy.ndarray, shape (N*N, 3, size, size)
        Block ``i*N + j`` is ``flag[i] - flipped(flag[j])``. The values are
        NOT squared. An empty ``png_paths`` yields an array with 0 blocks.
    """
    n_flags = len(png_paths)

    # Flag stack (N, 3, size, size), channel-first.
    flag_matrix = np.zeros((n_flags, 3, size, size))
    for i, path in enumerate(tqdm(png_paths)):
        # The context manager closes the file; convert("RGB") guarantees a
        # 3-channel image whatever the PNG's mode is.
        with Image.open(path) as img:
            resized = img.convert("RGB").resize((size, size))
        flag_matrix[i] = np.asarray(resized).transpose(2, 0, 1)  # (H, W, 3) -> (3, H, W)

    # Scale 0..255 pixel values to 0..1; the size check keeps .max() from
    # raising on an empty stack.
    if flag_matrix.size and flag_matrix.max() > 1:
        flag_matrix = flag_matrix / 255

    # Second operand of each pair, optionally mirrored.
    other = flag_matrix
    if hflip:
        other = other[:, :, ::-1, :]
    if vflip:
        other = other[:, :, :, ::-1]

    # Broadcasting (N, 1, ...) against (1, N, ...) yields every ordered pair
    # without first materialising two N*N-sized copies of the flag stack.
    diff_matrix = flag_matrix[:, None] - other[None, :]
    return diff_matrix.reshape(n_flags * n_flags, 3, size, size)

In [None]:
# We want to link flags together at their top, bottom, left, and right borders (just the outer row of pixels)
# Create link_matrix (N,N,4) where each flag has MSE values for every other flag for each of the 4 borders
def get_link_matrix_vectorized(png_paths, size=66):
    """Score how well the outer pixel lines of every pair of flags match.

    For each ordered pair (i, j) the outermost column or row of flag ``i`` is
    compared with the opposite outermost column or row of flag ``j`` (the one
    it would touch if ``j`` were placed next to ``i``) using the mean squared
    pixel difference over channels and border pixels.

    Note that ``get_diff_matrix_vectorized`` is called twice (once per
    mirroring direction), so every image is loaded twice.

    Parameters
    ----------
    png_paths : sequence of str
        Paths of the N flag images.
    size : int, default 66
        Edge length the flags are resized to before comparison.

    Returns
    -------
    numpy.ndarray, shape (N, N, 4)
        ``[i, j, k]`` is the border cost of linking flag ``j`` to flag ``i``
        with ``k`` = 0 left, 1 right, 2 top, 3 bottom border of flag ``i``.
        Costs are divided by the largest cost; when that is 0 (all borders
        identical) the all-zero matrix is returned instead of dividing 0 by 0.
    """
    n_flags = len(png_paths)

    def _border_mse(diff, border):
        # diff[border] is (N*N, C, size): one line of pixels per flag pair.
        # Only that line is squared and averaged, never the whole image.
        return np.square(diff[border]).mean(axis=(1, 2)).reshape(n_flags, n_flags)

    # Second flag mirrored left<->right: column 0 pairs flag i's left edge
    # with flag j's right edge, column -1 pairs right with left.
    diff_matrix = get_diff_matrix_vectorized(png_paths, size, vflip=True)
    link_matrix_left = _border_mse(diff_matrix, np.s_[:, :, :, 0])
    link_matrix_right = _border_mse(diff_matrix, np.s_[:, :, :, -1])

    # Second flag mirrored top<->bottom: row 0 pairs flag i's top edge with
    # flag j's bottom edge, row -1 pairs bottom with top.
    diff_matrix = get_diff_matrix_vectorized(png_paths, size, hflip=True)
    link_matrix_top = _border_mse(diff_matrix, np.s_[:, :, 0, :])
    link_matrix_bottom = _border_mse(diff_matrix, np.s_[:, :, -1, :])

    # Combine the four (N, N) matrices into one (N, N, 4).
    link_matrix = np.stack(
        [link_matrix_left, link_matrix_right, link_matrix_top, link_matrix_bottom],
        axis=2,
    )

    peak = link_matrix.max() if link_matrix.size else 0
    if peak == 0:
        # Nothing to normalise; avoids 0 / 0 -> NaN.
        return link_matrix
    return link_matrix / peak

# Border-matching costs for every ordered pair of flags, with flags resized to 66 x 66.
link_matrix = get_link_matrix_vectorized(png_paths=png_paths, size=66)

It seems infeasible to get a decent border match for most flags.

In [None]:
# A flag must never be picked as its own neighbour: overwrite the diagonal
# (i == j) of every border slice with 1.
diag = np.arange(len(png_paths))
link_matrix[diag, diag, :] = 1

# Locate the flag of the chosen country in png_paths.
country = "Russia"
code = country2code[country]
idx = png_paths.index(os.path.join(COUNTRIES_PATH, f"{code}.png"))

# Lowest-cost partner per border. The last axis of link_matrix is ordered
# left, right, top, bottom, so one argmin over the flag axis yields all four.
best_left, best_right, best_top, best_bottom = np.argmin(link_matrix[idx], axis=0)

# 3x3 grid: the chosen flag in the centre, its best partner on each side;
# the corner cells are left blank for now.
fig, ax = plt.subplots(3, 3, figsize=(10, 10))
grid_layout = {
    (0, 1): best_top,
    (1, 0): best_left,
    (1, 1): idx,
    (1, 2): best_right,
    (2, 1): best_bottom,
}
for (row, col), flag_idx in grid_layout.items():
    ax[row, col].imshow(Image.open(png_paths[flag_idx]))

# We also want to make sure that the links at top and bottom are to flags with the same aspect ratio