Imports

In [None]:
import sys
sys.path.insert(0, '..')
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import torch
from typing import Callable
import os

import nn_util
# from nn_util import dist_and_proximity_loss, simple_dist_loss
from main import determine_device

# Device that get_loss moves its tensors to and forwards to the loss function.
# NOTE(review): the meaning of the argument `1` is defined by main.determine_device — confirm there.
device: torch.device = determine_device(1)


# For saving to LaTeX
# matplotlib.use("pgf")
# matplotlib.rcParams.update({
#     "pgf.texsystem": "pdflatex",
#     'font.family': 'serif',
#     'text.usetex': True,
#     'pgf.rcfonts': False,
# })


Heatmap generation functions

In [None]:
def get_loss(loss_func:Callable, e:tuple[float,float], c:tuple[float,float], C:list[tuple[float,float]]) -> float:
    """Evaluate ``loss_func`` for one predicted embedding and return a Python scalar.

    Args:
        loss_func: Called as ``loss_func(prediction, class_embeddings, target_labels, device)``.
            The first element of its result is converted with ``.item()``.
        e: Predicted embedding.
        c: Class embedding of the true class; its position in ``C`` is used as the label.
        C: All class embeddings.

    Returns:
        ``loss_func(...)[0].item()``.

    Raises:
        ValueError: If ``c`` is not an element of ``C`` (raised by ``list.index``).
    """
    # Wrap the single prediction in a list so the tensor has a leading axis of length 1.
    prediction_batch = torch.Tensor([e]).to(device)
    # The label is the index of the first entry of C that equals c.
    true_class_index = C.index(c)
    embedding_table = torch.Tensor(C).to(device)

    loss_values = loss_func(prediction_batch, embedding_table, [true_class_index], device)
    return loss_values[0].item()

# Notation used by get_loss:
# e = the embedding output by the neural network
# c = the class embedding for the class of e
# C = the set of class embeddings
def make_heatmap_moving_target_class(
        x_min:int, x_max:int, y_min:int, y_max:int, resolution_upscale:float,
        loss_func:Callable, target:tuple[float,float], all_classes:list[tuple[float,float]]
        ) -> tuple[np.ndarray, tuple[int,int,int,int]]:
    """Sample the loss over a grid of candidate positions for the true class embedding.

    ``target`` is handed to ``get_loss`` as the fixed predicted embedding. Each grid
    point in turn is used as the class embedding of the true class and is
    prepended to ``all_classes`` to form the full set of class embeddings.

    Args:
        x_min, x_max, y_min, y_max: Bounds of the sampled region.
        resolution_upscale: Number of grid samples per unit length on each axis.
        loss_func: Loss callable forwarded to ``get_loss``.
        target: The fixed predicted embedding.
        all_classes: Class embeddings of the other classes.

    Returns:
        ``(heatmap, (x_min, x_max, y_min, y_max))`` where ``heatmap[i, j]`` is the
        loss at ``x = x_min + i / resolution_upscale``,
        ``y = y_min + j / resolution_upscale``.
    """
    n_x = int(np.floor((x_max - x_min) * resolution_upscale))
    n_y = int(np.floor((y_max - y_min) * resolution_upscale))
    # Every cell is written in the loop below, so no initialisation is needed.
    grid = np.empty((n_x, n_y))

    for ix in range(n_x):
        grid_x = ix / resolution_upscale + x_min
        for iy in range(n_y):
            grid_y = iy / resolution_upscale + y_min
            moving_class = (grid_x, grid_y)
            grid[ix, iy] = get_loss(loss_func, target, moving_class, [moving_class] + all_classes)

    return grid, (x_min, x_max, y_min, y_max)

def make_heatmap_moving_embedding(
        x_min:int, x_max:int, y_min:int, y_max:int, resolution_upscale:float,
        loss_func:Callable, target:tuple[float,float], all_classes:list[tuple[float,float]]
        ) -> tuple[np.ndarray, tuple[int,int,int,int]]:
    """Sample the loss over a grid of candidate positions for the predicted embedding.

    ``target`` is the fixed class embedding of the true class; it is prepended to
    ``all_classes`` to form the full set of class embeddings. Each grid point
    in turn is handed to ``get_loss`` as the predicted embedding.

    Args:
        x_min, x_max, y_min, y_max: Bounds of the sampled region.
        resolution_upscale: Number of grid samples per unit length on each axis.
        loss_func: Loss callable forwarded to ``get_loss``.
        target: The fixed class embedding of the true class.
        all_classes: Class embeddings of the other classes.

    Returns:
        ``(heatmap, (x_min, x_max, y_min, y_max))`` where ``heatmap[i, j]`` is the
        loss at ``x = x_min + i / resolution_upscale``,
        ``y = y_min + j / resolution_upscale``.
    """
    n_x = int(np.floor((x_max - x_min) * resolution_upscale))
    n_y = int(np.floor((y_max - y_min) * resolution_upscale))
    # Every cell is written in the loop below, so no initialisation is needed.
    grid = np.empty((n_x, n_y))

    # The set of class embeddings is the same for every grid point.
    for ix in range(n_x):
        grid_x = ix / resolution_upscale + x_min
        for iy in range(n_y):
            grid_y = iy / resolution_upscale + y_min
            grid[ix, iy] = get_loss(loss_func, (grid_x, grid_y), target, [target] + all_classes)

    return grid, (x_min, x_max, y_min, y_max)

Configuration

In [None]:
# Bounds of the region sampled by the heatmap helpers.
x_min, x_max = -30, 20
y_min, y_max = -30, 20

# Grid samples per unit length along each axis.
resolution_upscale = 3

# e: predicted embedding held fixed in the "moving target" plots.
e = (-7, -3)
# c: true-class embedding held fixed in the "moving embedding" plots.
c = (-7, -3)

# Embeddings of the remaining classes, drawn as "C" markers on every plot.
all_classes = [
    (-6, 3),
    (-1, 1),
    (4, 7),
    (6, 0),
    (4, -6),
]

def standardized_heatmap_moving_target(loss_func): 
    """Build the moving-target heatmap for ``loss_func`` from the notebook configuration.

    Reads the module-level ``x_min``, ``x_max``, ``y_min``, ``y_max``,
    ``resolution_upscale`` and ``all_classes``, and passes the module-level
    ``e`` as ``target``. Returns the result of
    ``make_heatmap_moving_target_class`` unchanged.
    """
    return make_heatmap_moving_target_class(
        x_min=x_min, x_max=x_max, y_min=y_min, y_max=y_max,
        resolution_upscale=resolution_upscale,
        loss_func=loss_func, target=e, all_classes=all_classes,
    )

def standardized_heatmap_moving_embedding(loss_func): 
    """Build the moving-embedding heatmap for ``loss_func`` from the notebook configuration.

    Reads the module-level ``x_min``, ``x_max``, ``y_min``, ``y_max``,
    ``resolution_upscale`` and ``all_classes``, and passes the module-level
    ``c`` as ``target``. Returns the result of
    ``make_heatmap_moving_embedding`` unchanged.
    """
    return make_heatmap_moving_embedding(
        x_min=x_min, x_max=x_max, y_min=y_min, y_max=y_max,
        resolution_upscale=resolution_upscale,
        loss_func=loss_func, target=c, all_classes=all_classes,
    )

def loss_heatmap_moving_target(loss_func, title=None): 
    """Plot, and optionally save, the loss as the true-class embedding moves.

    The heatmap comes from ``standardized_heatmap_moving_target``: the
    predicted embedding stays at the module-level ``e`` while the class
    embedding of the true class takes every grid position. ``e`` and the
    entries of ``all_classes`` are drawn on top as text markers.

    Args:
        loss_func: Loss callable forwarded to ``standardized_heatmap_moving_target``.
        title: Plot title, also used as the file name under ``loss_plots/``.
            When omitted or empty, the figure is drawn but nothing is written
            to disk.
    """
    heatmap, edges = standardized_heatmap_moving_target(loss_func)

    # Reuse the current figure so repeated calls do not accumulate figures.
    plt.clf()
    if title:
        plt.title(title)

    # heatmap is indexed [x, y]; imshow treats rows as y, hence the transpose.
    plt.imshow(heatmap.T, extent=edges, origin='lower')
    plt.scatter([e[0]], [e[1]], marker="$e$")

    # Loop variable is deliberately not named `c`, to avoid confusion with the global.
    class_xs = [point[0] for point in all_classes]
    class_ys = [point[1] for point in all_classes]
    plt.scatter(class_xs, class_ys, marker="$C$")

    if not title:
        # No title means no file name; previously this raised TypeError.
        return

    # savefig does not create missing directories.
    os.makedirs("loss_plots", exist_ok=True)
    plt.savefig("loss_plots/" + title)
    # plt.savefig("loss_plots/" + title + ".pgf") # Requires a local installation of LaTeX

def loss_heatmap_moving_embedding(loss_func, title=None): 
    """Plot, and optionally save, the loss as the predicted embedding moves.

    The heatmap comes from ``standardized_heatmap_moving_embedding``: the
    class embedding of the true class stays at the module-level ``c`` while
    the predicted embedding takes every grid position. ``c`` and the entries
    of ``all_classes`` are drawn on top as text markers.

    Args:
        loss_func: Loss callable forwarded to ``standardized_heatmap_moving_embedding``.
        title: Plot title, also used as the file name under ``loss_plots/``.
            When omitted or empty, the figure is drawn but nothing is written
            to disk.
    """
    heatmap, edges = standardized_heatmap_moving_embedding(loss_func)

    # Reuse the current figure so repeated calls do not accumulate figures.
    plt.clf()
    if title:
        plt.title(title)

    # heatmap is indexed [x, y]; imshow treats rows as y, hence the transpose.
    plt.imshow(heatmap.T, extent=edges, origin='lower')
    plt.scatter([c[0]], [c[1]], marker="$c$")

    # Loop variable is deliberately not named `c`, so it cannot be mistaken for the global.
    class_xs = [point[0] for point in all_classes]
    class_ys = [point[1] for point in all_classes]
    plt.scatter(class_xs, class_ys, marker="$C$")

    if not title:
        # No title means no file name; previously this raised TypeError.
        return

    # savefig does not create missing directories.
    os.makedirs("loss_plots", exist_ok=True)
    plt.savefig("loss_plots/" + title)
    # plt.savefig("loss_plots/" + title + ".pgf") # Requires a local installation of LaTeX

    

Plotting

In [None]:
# Simple distance loss: first with e fixed (moving target class), then with c fixed (moving prediction).
for plot_fn, fixed_point in ((loss_heatmap_moving_target, "e"), (loss_heatmap_moving_embedding, "c")):
    plot_fn(nn_util.simple_dist_loss, title="Simple Loss - Given " + fixed_point)

In [None]:
# Argument passed to nn_util.dist_and_proximity_loss; also shown in the plot titles.
r = 1000
# A fresh loss callable is built for each plot, first with e fixed, then with c fixed.
for plot_fn, fixed_point in ((loss_heatmap_moving_target, "e"), (loss_heatmap_moving_embedding, "c")):
    plot_fn(nn_util.dist_and_proximity_loss(r), title=f"Proximity Loss (r={r}) - Given {fixed_point}")

In [None]:
# Cone loss: a fresh loss callable is built for each plot, first with e fixed, then with c fixed.
for plot_fn, fixed_point in ((loss_heatmap_moving_target, "e"), (loss_heatmap_moving_embedding, "c")):
    plot_fn(nn_util.cone_loss_hyperparam(), title="Cone Loss - Given " + fixed_point)

In [None]:
# Comparison loss: first with e fixed (moving target class), then with c fixed (moving prediction).
for plot_fn, fixed_point in ((loss_heatmap_moving_target, "e"), (loss_heatmap_moving_embedding, "c")):
    plot_fn(nn_util.comparison_dist_loss, title="Comparison Loss - Given " + fixed_point)

In [None]:
# Cosine loss: first with e fixed (moving target class), then with c fixed (moving prediction).
for plot_fn, fixed_point in ((loss_heatmap_moving_target, "e"), (loss_heatmap_moving_embedding, "c")):
    plot_fn(nn_util.cosine_loss, title="Cosine Loss - Given " + fixed_point)