Note: Running this notebook requires a local LaTeX installation, because figures are rendered through Matplotlib's PGF backend using `pdflatex`.

Imports

In [None]:
import sys
sys.path.insert(0, '..')
import re
import math
from typing import Callable
import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.backends.backend_pgf import FigureCanvasPgf
matplotlib.backend_bases.register_backend('pdf', FigureCanvasPgf)

import multiprocessing.dummy as multiprocessing

import nn_util

# Run the loss computations on the first GPU when one is available, otherwise on the CPU.
device: torch.device
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# PGF/LaTeX rendering settings (figure.figsize is in inches).
matplotlib.rcParams["pgf.texsystem"] = "pdflatex"
matplotlib.rcParams['pgf.rcfonts'] = False
matplotlib.rcParams['figure.figsize'] = (3.15,3.15)


Heatmap generation functions

In [None]:
def get_loss(loss_func:Callable, e:tuple[float,float], c:tuple[float,float], C:list[tuple[float,float]]) -> float:
    """Evaluate ``loss_func`` for one embedding against a set of class embeddings.

    Args:
        loss_func: called as ``loss_func(embeddings, class_embeddings, labels, device)``;
            element 0 of its result is returned as a Python number.
        e: the embedding, wrapped into a batch of size one.
        c: the embedding of e's class; must be an element of ``C``.
        C: all class embeddings. The label for e is the position of the first
            element of ``C`` equal to ``c``.

    Returns:
        The loss for the single embedding, as a Python scalar.

    Raises:
        ValueError: if ``c`` is not in ``C`` (from ``list.index``).
    """
    # Match the legacy torch.Tensor constructor, which produces the default float dtype.
    float_dtype = torch.get_default_dtype()
    embedding_batch = torch.tensor([e], dtype=float_dtype, device=device)
    labels: list[int] = [C.index(c)]
    class_table = torch.tensor(C, dtype=float_dtype, device=device)

    losses = loss_func(embedding_batch, class_table, labels, device)
    return losses[0].item()

# Notation used by the heatmap functions:
#   e = the embedding output by the neural network
#   c = the class embedding of e's class
#   C = the set of class embeddings
def make_heatmap_moving_target_class(
        x_min:int, x_max:int, y_min:int, y_max:int, resolution_upscale:float,
        loss_func:Callable, target:tuple[float,float], all_classes:list[tuple[float,float]]
        ) -> tuple[np.ndarray, tuple[int,int,int,int]]:
    """Sample the loss on a grid while the class embedding c sweeps the plane.

    The embedding e is held fixed at ``target``. At each grid point (x, y), c is
    placed at (x, y) and prepended to ``all_classes`` to form C.

    Args:
        x_min, x_max, y_min, y_max: bounds of the sampled region.
        resolution_upscale: number of grid samples per unit along each axis.
        loss_func: loss function forwarded to ``get_loss``.
        target: the fixed embedding e (despite the name, not a class embedding).
        all_classes: the remaining class embeddings in C.

    Returns:
        ``(heatmap, extent)``. ``heatmap[x_i, y_i]`` is the loss with c at
        ``(x_i / resolution_upscale + x_min, y_i / resolution_upscale + y_min)``,
        and ``extent`` is ``(x_min, x_max, y_min, y_max)``.
    """
    grid_shape = (
        int(np.floor((x_max - x_min) * resolution_upscale)),
        int(np.floor((y_max - y_min) * resolution_upscale)),
    )
    # np.empty is the documented way to allocate an uninitialised array;
    # every cell is written by the loop below.
    heatmap = np.empty(grid_shape)

    for x_i, y_i in np.ndindex(heatmap.shape):
        # NOTE(review): samples are taken at x_min + x_i / resolution_upscale, i.e.
        # on the lower/left edge of each cell rather than its centre.
        x = x_i / resolution_upscale + x_min
        y = y_i / resolution_upscale + y_min
        heatmap[x_i, y_i] = get_loss(loss_func, target, (x, y), [(x, y)] + all_classes)

    return heatmap, (x_min, x_max, y_min, y_max)

def make_heatmap_moving_embedding(
        x_min:int, x_max:int, y_min:int, y_max:int, resolution_upscale:float,
        loss_func:Callable, target:tuple[float,float], all_classes:list[tuple[float,float]]
        ) -> tuple[np.ndarray, tuple[int,int,int,int]]:
    """Sample the loss on a grid while the embedding e sweeps the plane.

    The class embedding c is held fixed at ``target`` and C is
    ``[target] + all_classes``. At each grid point (x, y), e is placed at (x, y).

    Args:
        x_min, x_max, y_min, y_max: bounds of the sampled region.
        resolution_upscale: number of grid samples per unit along each axis.
        loss_func: loss function forwarded to ``get_loss``.
        target: the fixed class embedding c of e's class.
        all_classes: the remaining class embeddings in C.

    Returns:
        ``(heatmap, extent)``. ``heatmap[x_i, y_i]`` is the loss with e at
        ``(x_i / resolution_upscale + x_min, y_i / resolution_upscale + y_min)``,
        and ``extent`` is ``(x_min, x_max, y_min, y_max)``.
    """
    grid_shape = (
        int(np.floor((x_max - x_min) * resolution_upscale)),
        int(np.floor((y_max - y_min) * resolution_upscale)),
    )
    # np.empty is the documented way to allocate an uninitialised array;
    # every cell is written by the loop below.
    heatmap = np.empty(grid_shape)

    # C does not depend on the grid position, so build it once.
    class_embeddings = [target] + all_classes

    for x_i, y_i in np.ndindex(heatmap.shape):
        # NOTE(review): samples are taken at x_min + x_i / resolution_upscale, i.e.
        # on the lower/left edge of each cell rather than its centre.
        x = x_i / resolution_upscale + x_min
        y = y_i / resolution_upscale + y_min
        heatmap[x_i, y_i] = get_loss(loss_func, (x, y), target, class_embeddings)

    return heatmap, (x_min, x_max, y_min, y_max)

def get_lowest_loss(heatmap:np.ndarray) -> tuple[int, int]:
    """Return the ``(x_i, y_i)`` index of the smallest value in a 2-D heatmap.

    NaN cells are ignored. Plain ``argmin`` would return the index of the first
    NaN, so the reported "lowest loss" would itself be NaN. If every cell is
    NaN, ``(0, 0)`` is returned, as before.

    Raises:
        ValueError: if ``heatmap`` is empty.
    """
    values = np.asarray(heatmap)

    if np.isnan(values).all():
        # nanargmin raises on all-NaN input; keep the previous result instead.
        # This also raises ValueError for an empty array.
        flat_index = values.argmin()
    else:
        flat_index = np.nanargmin(values)

    x, y = np.unravel_index(flat_index, values.shape)
    return int(x), int(y)


Configuration

In [None]:
# Plotted region of the embedding plane.
x_min = -15
x_max = 15
y_min = -15
y_max = 15

print_plots = False  # True: show plots in the notebook; False: save them under loss_plots/
resolution_upscale = 5  # grid samples per unit along each axis

# e: the embedding; c: the embedding of e's class; all_classes: the other class embeddings.
e = (-2,2)
c = (-2,2)
all_classes = [(-1, 8),(4,6)]

# Alternative configuration with more classes:
# e = (-7,-3)
# c = (-7,-3)
# all_classes = [(-6, 3),(-1,1),(4,7),(6,0),(4,-6)]

# text_color = "#F80"
text_color = "#3E3"  # colour of the e / c / C_i labels drawn on the heatmaps

# Axis ticks every 5 units, including both bounds.
x_ticks = list(range(x_min, x_max + 1, 5))
y_ticks = list(range(y_min, y_max + 1, 5))


def standardized_heatmap_moving_target(loss_func):
    """Heatmap over the configured region with the class embedding moving and e fixed."""
    region = (x_min, x_max, y_min, y_max)
    return make_heatmap_moving_target_class(*region, resolution_upscale, loss_func, e, all_classes)

def standardized_heatmap_moving_embedding(loss_func):
    """Heatmap over the configured region with the embedding moving and c fixed."""
    region = (x_min, x_max, y_min, y_max)
    return make_heatmap_moving_embedding(*region, resolution_upscale, loss_func, c, all_classes)


def pgf_fix_png_path(filename, latex_rel_path):
    """Prefix every ``\\includegraphics`` target in a PGF file with ``latex_rel_path``.

    The file is rewritten in place, line by line. For each
    ``\\includegraphics[...]{name}``, ``latex_rel_path`` is inserted verbatim
    directly before ``name``.

    Args:
        filename: path of the PGF file to rewrite.
        latex_rel_path: text to insert, e.g. a directory with a trailing slash.
    """
    include_pattern = re.compile(r"(\\includegraphics\[.*?\]\{)(.*?\})")

    def _insert_prefix(match):
        # Build the replacement manually so latex_rel_path is inserted literally.
        # Splicing it into a template (r"\1" + path + r"\2") lets re.sub parse
        # backslashes in the path as escapes. A leading digit merges with "\1"
        # into a different group or octal escape, which corrupts the output or raises.
        return match.group(1) + latex_rel_path + match.group(2)

    with open(filename, "r", encoding="utf-8") as fid:
        lines = fid.readlines()

    with open(filename, "w", encoding="utf-8") as fid:
        fid.writelines(include_pattern.sub(_insert_prefix, line) for line in lines)

def save_plot(name):
    """Show the current figure, or save it as ``loss_plots/<name>.pdf``.

    The choice is controlled by the global ``print_plots``. When saving, the
    target directory is created if it does not exist yet, because ``savefig``
    does not create directories.
    """
    if print_plots:
        plt.show()
        return

    from pathlib import Path

    save_path = Path("loss_plots") / (name + ".pdf")
    save_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(save_path, bbox_inches="tight")


def plot_base(title, heatmap, edges):
    """Draw ``heatmap`` on the current axes with a colour bar and ``C_i`` labels.

    Also prints the plane coordinates and value of the lowest loss cell.

    Args:
        title: axes title; skipped when falsy.
        heatmap: 2-D array indexed ``[x_i, y_i]`` (transposed for ``imshow`` so
            x runs horizontally).
        edges: ``(left, right, bottom, top)`` extent passed to ``imshow``.

    Returns:
        The main ``Axes``, so callers can add further annotations.
    """
    lowest_x_i, lowest_y_i = get_lowest_loss(heatmap)
    # Map grid indices back to plane coordinates (inverse of the heatmap sampling).
    lowest_x = lowest_x_i / resolution_upscale + x_min
    lowest_y = lowest_y_i / resolution_upscale + y_min
    print(f"Lowest loss: ({lowest_x}, {lowest_y}) = {heatmap[lowest_x_i, lowest_y_i]}")

    if title:
        plt.title(title)

    plt.set_cmap(matplotlib.colormaps["magma"])

    # Force tick locations.
    plt.xticks(x_ticks)
    plt.yticks(y_ticks)

    axes = plt.gca()

    # The heatmap is indexed [x, y]; imshow expects [row, column] = [y, x].
    heatmap_image = axes.imshow(heatmap.T, extent=edges, origin='lower')

    # Put the colour bar in its own axes on the right so it matches the plot height.
    divider = make_axes_locatable(axes)
    cax = divider.append_axes("right", size="5%", pad=0.05)

    # NaN-aware range so undefined cells do not break the tick placement.
    heatmap_min = np.nanmin(heatmap)
    heatmap_max = np.nanmax(heatmap)
    heatmap_range = heatmap_max - heatmap_min
    tick_intervals = 7
    colorbar_ticks = [i / tick_intervals * heatmap_range + heatmap_min for i in range(tick_intervals + 1)]
    plt.colorbar(heatmap_image, ticks=colorbar_ticks, cax=cax)

    # Label each class embedding. The subscript is braced so multi-digit indices
    # (C_{10}, ...) are fully subscripted instead of only their first digit.
    for i, (class_x, class_y) in enumerate(all_classes):
        axes.text(class_x, class_y, "\\textbf{$C_{" + str(i) + "}$}", fontsize=12, color=text_color)

    axes.grid(False)

    return axes

def loss_heatmap_moving_target(loss_func, title=None, save_name=None):
    """Plot (or save) the loss as the class embedding sweeps the plane, with e fixed.

    Args:
        loss_func: loss function to visualise.
        title: axes title.
        save_name: file name (without extension) used by ``save_plot``;
            defaults to ``title``.
    """
    heatmap, edges = standardized_heatmap_moving_target(loss_func)

    plt.clf()

    if save_name is None:
        save_name = title

    axes = plot_base(title, heatmap, edges)

    # Mark the fixed embedding e. It is typeset in math mode, like the "$c$"
    # marker of the moving-embedding plot.
    axes.text(e[0], e[1], "$e$", fontsize=12, color=text_color)

    save_plot(save_name)


def loss_heatmap_moving_embedding(loss_func, title=None, save_name=None):
    """Plot (or save) the loss as the embedding sweeps the plane, with c fixed.

    Args:
        loss_func: loss function to visualise.
        title: axes title.
        save_name: file name (without extension) used by ``save_plot``;
            defaults to ``title``.
    """
    heatmap, edges = standardized_heatmap_moving_embedding(loss_func)

    plt.clf()

    if save_name is None:
        save_name = title

    axes = plot_base(title, heatmap, edges)

    # Mark the fixed class embedding c.
    axes.text(c[0], c[1], "$c$", fontsize=12, color=text_color)

    save_plot(save_name)

    

Plotting

In [None]:
# Euclidean loss: first with the class embedding moving, then with the embedding moving.
euclidean_loss = nn_util.simple_dist_loss
loss_heatmap_moving_target(euclidean_loss, title="Euclidean Loss - c=(x,y)")
loss_heatmap_moving_embedding(euclidean_loss, title="Euclidean Loss - e=(x,y)")

In [None]:
r = 500
# The title shows r; the file name deliberately does not.
loss_heatmap_moving_target(
    nn_util.dist_and_proximity_loss(r),
    title=f"Proximity Loss (r={r}) - c=(x,y)",
    save_name="Proximity Loss - c=(x,y)",
)
loss_heatmap_moving_embedding(
    nn_util.dist_and_proximity_loss(r),
    title=f"Proximity Loss (r={r}) - e=(x,y)",
    save_name="Proximity Loss - e=(x,y)",
)

In [None]:
# Cone loss with default hyperparameters; a fresh loss object is created for each plot.
loss_heatmap_moving_target(loss_func=nn_util.cone_loss_hyperparam(), title="Cone Loss - c=(x,y)")
loss_heatmap_moving_embedding(loss_func=nn_util.cone_loss_hyperparam(), title="Cone Loss - e=(x,y)")

In [None]:
# Comparison loss: class embedding moving, then embedding moving.
comparison_loss = nn_util.comparison_dist_loss
loss_heatmap_moving_target(comparison_loss, title="Comparison Loss - c=(x,y)")
loss_heatmap_moving_embedding(comparison_loss, title="Comparison Loss - e=(x,y)")

In [None]:
# Cosine loss: class embedding moving, then embedding moving.
cosine_loss = nn_util.cosine_loss
loss_heatmap_moving_target(cosine_loss, title="Cosine Loss - c=(x,y)")
loss_heatmap_moving_embedding(cosine_loss, title="Cosine Loss - e=(x,y)")

In [None]:
q = 0.7

loss_heatmap_moving_target(nn_util.pnp_hyperparam(q), title="Push \\& Pull Loss (q="+str(q)+") - c=(x,y)", save_name="Pnp Loss - c=(x,y)")
# The embedding e is what moves in this plot, hence "e=(x,y)" in the title (matching its save name).
loss_heatmap_moving_embedding(nn_util.pnp_hyperparam(q), title="Push \\& Pull Loss (q="+str(q)+") - e=(x,y)", save_name="Pnp Loss - e=(x,y)")

# Optional sweep: save one moving-embedding plot per q in 0.00 .. 1.50 (step 0.01).
# startIndex = 0
# endIndex = 150
# for i in range (startIndex, endIndex + 1):
#     q = round(i/100, 4)
#     pnp_loss = nn_util.pnp_hyperparam(q)

#     q_str = "{:.2f}".format(q)

#     # loss_heatmap_moving_target(nn_util.simple_dist_loss, title="Simple Loss - c=(x,y)")
#     # loss_heatmap_moving_target(pnp_loss, title="pnp Loss (r="+str(r)+") - c=(x,y)")
#     # loss_heatmap_moving_embedding(pnp_loss, title=str(i-startIndex)+"_pnp Loss (q="+str(q)+") - e=(x,y)")
#     loss_heatmap_moving_embedding(pnp_loss, title="pnp Loss (q="+q_str+") - e=(x,y)", save_name = "pnp Loss - e=(x,y)_"+str(i-startIndex))