In [2]:
import torch
import plotly.express as px
import numpy as np
import pandas as pd
from torch.nn.functional import pad, relu
from collections import defaultdict

# Functions

In [14]:
# Utility functions
def picture_to_coordinates(B):
    """Return the centred coordinates of the entries of ``B``.

    Entry ``i`` gets the coordinate ``i - len(B) // 2``, so an odd-length
    picture is centred on 0 (length 5 -> -2, -1, 0, 1, 2).

    Args:
        B: any sized object; only ``len(B)`` is used.

    Returns:
        1-D ``torch.float32`` tensor with ``len(B)`` entries.
    """
    n_entries = len(B)
    offset = n_entries // 2
    return torch.arange(n_entries, dtype=torch.float32) - offset

def get_picture_radians(B):
    """Return the half-width ``len(B) // 2`` of the picture ``B``.

    Note: despite the name, the result is a radius counted in entries
    (a plain ``int``), not an angle in radians.
    """
    half_width, _ = divmod(len(B), 2)
    return half_width

def loss1(B, z, x, F, relu_coeff=1.):
    """Compute the two loss terms for a reconstruction ``F`` of ``B``.

    Args:
        B: target picture.
        z: coordinates being evaluated.
        x: not used by this loss; it is part of the signature because the
            callers in this notebook pass it positionally.
        F: reconstruction compared against ``B``.
        relu_coeff: weight applied to the out-of-bounds penalty.

    Returns:
        Tuple ``(fit, penalty)``. ``fit`` is the summed squared difference
        between ``B`` and ``F``. ``penalty`` is ``relu_coeff`` times the total
        amount by which entries of ``z`` lie outside ``[-r, r]``, where
        ``r = get_picture_radians(B)``.
    """
    bound = get_picture_radians(B)

    residual = B - F
    fit = (residual ** 2).sum()

    # Each relu is non-zero only on its own side of the allowed interval.
    below = relu(-z - bound)
    above = relu(z - bound)
    penalty = relu_coeff * (below + above).sum()

    return fit, penalty


In [35]:
def model1(z, x, k=1.):
    """Render the centres ``z`` as Gaussian bumps sampled at positions ``x``.

    For every sample position ``x[i]`` the result is
    ``max_j exp(-k * (x[i] - z[j])**2)``; for ``k > 0`` that is the response
    of the centre closest to ``x[i]``.

    Args:
        z: 1-D tensor of bump centres.
        x: 1-D tensor of sample positions.
        k: multiplier on the squared distance inside the exponential.

    Returns:
        Tensor with one value per entry of ``x``.
    """
    # Pairwise differences laid out as (len(x), len(z)). This is the grid
    # torch.meshgrid(x, z) yields with its default 'ij' indexing, obtained by
    # broadcasting so that no warning about the omitted `indexing` argument
    # is raised and the two intermediate grids are never materialised.
    diff = x.reshape(-1, 1) - z.reshape(1, -1)
    bumps = torch.exp(-k * diff.pow(2))
    # max(dim) rather than amax: on ties the gradient goes to a single centre.
    return bumps.max(-1).values

In [37]:
def run(B, z, optim, n_iter:int=10,
        model_fn=model1, model_kwargs=None,
        loss_fn=loss1, loss_kwargs=None):
    """Optimise ``z`` for ``n_iter`` steps and log every step.

    Each iteration evaluates ``model_fn(z, x, **model_kwargs)``, sums the
    terms returned by ``loss_fn(B, z, x, F, **loss_kwargs)``, back-propagates
    the total and calls ``optim.step()``.

    Args:
        B: target picture.
        z: 1-D tensor being optimised. ``run`` does not register it with
            ``optim``; the caller must have done so already.
        optim: optimizer exposing ``zero_grad()`` and ``step()``.
        n_iter: number of optimisation steps.
        model_fn: callable ``(z, x, **model_kwargs) -> F``.
        model_kwargs: extra keyword arguments for ``model_fn`` (``None`` = none).
        loss_fn: callable ``(B, z, x, F, **loss_kwargs)`` returning a sequence
            of scalar loss terms.
        loss_kwargs: extra keyword arguments for ``loss_fn`` (``None`` = none).

    Returns:
        ``pandas.DataFrame`` with one row per iteration and the columns
        ``iter`` (1-based), ``z_<j>`` (value of ``z[j]`` before the step),
        ``loss`` (sum of all terms), ``loss_l<t>`` (each individual term) and
        ``zgrad_<j>`` (gradient of the total loss w.r.t. ``z[j]``).
    """
    x = picture_to_coordinates(B)
    model_kwargs = dict() if model_kwargs is None else model_kwargs
    loss_kwargs = dict() if loss_kwargs is None else loss_kwargs
    tracker = defaultdict(list)

    for step in range(1, n_iter + 1):
        tracker['iter'].append(step)
        # Logged before the update, so row `step` pairs each z value with the
        # loss and gradient evaluated at that same z.
        for j, z_j in enumerate(z):
            tracker[f'z_{j}'].append(z_j.item())

        F = model_fn(z, x, **model_kwargs)
        losses = loss_fn(B, z, x, F, **loss_kwargs)
        total = sum(losses)
        optim.zero_grad()
        total.backward()
        optim.step()

        tracker['loss'].append(total.item())
        for t, term in enumerate(losses):
            tracker[f'loss_l{t}'].append(term.item())
        for j, grad_j in enumerate(z.grad):
            tracker[f'zgrad_{j}'].append(grad_j.item())
    return pd.DataFrame(tracker)
      

In [38]:
# Plotting functions
def loss_over_iter(df):
    """Scatter-plot the ``loss`` column of ``df`` against its ``iter`` column.

    Args:
        df: data frame containing ``iter`` and ``loss`` columns.

    Returns:
        The figure object returned by ``px.scatter``.
    """
    figure = px.scatter(df, x='iter', y='loss')
    return figure

In [43]:
def loss_landscape(z, B, n_samples:int=100,
        model_fn=model1, model_kwargs=None,
        loss_fn=loss1, loss_kwargs=None):
    """Sample the loss of a single coordinate swept across the picture.

    A one-element coordinate tensor is placed at ``n_samples`` evenly spaced
    positions in ``[-1.2 * r, 1.2 * r]`` (``r = get_picture_radians(B)``) and
    the model and loss are evaluated at each position.

    Args:
        z: ignored. The sweep builds its own one-element tensor for every
            sample; the parameter is kept only so existing calls keep working.
        B: target picture.
        n_samples: number of positions to evaluate.
        model_fn: callable ``(z, x, **model_kwargs) -> F``.
        model_kwargs: extra keyword arguments for ``model_fn`` (``None`` = none).
        loss_fn: callable ``(B, z, x, F, **loss_kwargs)`` returning a sequence
            of scalar loss terms.
        loss_kwargs: extra keyword arguments for ``loss_fn`` (``None`` = none).

    Returns:
        ``pandas.DataFrame`` with one row per sample and the columns ``x``
        (sample position), ``loss`` (sum of all terms) and ``loss_l<t>`` (each
        individual term).
    """
    x = picture_to_coordinates(B)
    xrad = get_picture_radians(B)
    model_kwargs = dict() if model_kwargs is None else model_kwargs
    loss_kwargs = dict() if loss_kwargs is None else loss_kwargs
    tracker = defaultdict(list)

    # Sweep slightly past the picture edges so the out-of-bounds part of the
    # loss is visible too.
    positions = np.linspace(-1.2*xrad, 1.2*xrad, n_samples)
    for xcoord in positions:
        # A separate name: the `z` argument must not be clobbered by the sweep.
        z_probe = torch.tensor([xcoord], dtype=torch.float32)
        F = model_fn(z_probe, x, **model_kwargs)
        losses = loss_fn(B, z_probe, x, F, **loss_kwargs)
        total = sum(losses)
        tracker['x'].append(xcoord)
        tracker['loss'].append(total.item())
        for t, term in enumerate(losses):
            tracker[f'loss_l{t}'].append(term.item())
    return pd.DataFrame(tracker)

# Setup

In [22]:
# Target picture: a five-entry 0/1 pattern
B = torch.tensor([0, 1, 0, 1, 0])
# Starting values of the two coordinates to optimise (gradient-tracked leaf)
z = torch.tensor([-0.4, 2.1]).requires_grad_()
# Adam over z alone, step size 0.1
optim = torch.optim.Adam(params=[z], lr=0.1)

# Run

In [45]:
# Ten optimiser steps on z, one logged row per step
rundf = run(B=B, z=z, optim=optim, n_iter=10)
# Loss sampled at 100 positions of a single coordinate; loss_landscape
# overwrites the z it is given, so the value passed here has no effect
landscapedf = loss_landscape(z=z, B=B, n_samples=100)

In [47]:
# Display the sampled landscape (pandas abbreviates the middle rows)
landscapedf

Unnamed: 0,x,loss,loss_l0,loss_l1
0,-2.400000,2.864264,2.464264,0.400000
1,-2.351515,2.836533,2.485018,0.351515
2,-2.303030,2.802617,2.499586,0.303030
3,-2.254545,2.761463,2.506917,0.254545
4,-2.206061,2.712161,2.506100,0.206061
...,...,...,...,...
95,2.206061,2.712161,2.506100,0.206061
96,2.254545,2.761463,2.506917,0.254545
97,2.303030,2.802617,2.499586,0.303030
98,2.351515,2.836533,2.485018,0.351515


In [30]:
# Plot the optimisation log. `rundf` is the frame returned by `run` and holds
# the `iter` and `loss` columns the plot needs; the previous name `df` is not
# defined anywhere in this notebook and fails on a fresh kernel.
p = loss_over_iter(rundf)

In [32]:
# Render the loss-per-iteration figure
p.show()