In [1]:
import torch
import plotly.express as px
import numpy as np
import pandas as pd
from torch.nn.functional import pad, relu
from collections import defaultdict
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Functions

In [2]:
# Util functions 
def picture_to_coordinates(B):
    """Return the centred coordinate of every pixel of ``B``.

    Pixel ``i`` is mapped to ``i - len(B)//2``, so a picture of odd length is
    centred on 0.

    Args:
        B: the picture; only its length is used.

    Returns:
        1-D ``float32`` tensor of length ``len(B)``.
    """
    n_pixels = len(B)
    pixel_index = torch.arange(n_pixels, dtype=torch.float32)
    return pixel_index - n_pixels // 2

def get_picture_radians(B):
    """Return the half-width of the picture, ``len(B) // 2``.

    Despite the name, the value is a radius measured in pixels, not an angle.
    """
    n_pixels = len(B)
    return n_pixels // 2

def loss1(B, z, x, F, relu_coeff=1.):
    """Two-term loss: reconstruction error plus an out-of-picture penalty.

    Args:
        B: target picture.
        z: coordinates being optimised.
        x: pixel coordinates; not used by this loss.
        F: model output that is compared against ``B``.
        relu_coeff: weight applied to the boundary penalty.

    Returns:
        ``(l1, l2)``: ``l1`` is the summed squared difference between ``B``
        and ``F``; ``l2`` is ``relu_coeff`` times the summed amount by which
        entries of ``z`` fall outside ``[-xrad, xrad]``, where ``xrad`` is the
        picture half-width.
    """
    bound = get_picture_radians(B)

    residual = B - F
    reconstruction = (residual ** 2).sum()

    # Each relu is non-zero on one side of the picture only.
    too_far_left = relu(-z - bound)
    too_far_right = relu(z - bound)
    boundary_penalty = relu_coeff * (too_far_left + too_far_right).sum()

    return reconstruction, boundary_penalty


In [3]:

# Models
def model1(z, x, k=1.):
    """Render coordinates ``z`` as the maximum of Gaussian bumps sampled at ``x``.

    Args:
        z: 1-D tensor of bump centres.
        x: 1-D tensor of sample positions.
        k: sharpness; each bump is ``exp(-k * (x - z)**2)``.

    Returns:
        1-D tensor with one value per entry of ``x``: the largest bump value
        at that position over all centres in ``z``.
    """
    # Broadcast to a (len(x), len(z)) grid of differences. This is the same
    # grid that torch.meshgrid(x, z) yielded, but without calling meshgrid
    # with its deprecated implicit indexing, which emits a UserWarning.
    diff = x.unsqueeze(-1) - z.unsqueeze(0)
    E = torch.exp(-k * diff.pow(2))
    # max over the z axis returns (values, indices); keep the values.
    F = E.max(-1)[0]
    return F

In [4]:
def run(B, z, optim, n_iter:int=10,
        model_fn=model1, model_kwargs=None, 
        loss_fn=loss1, loss_kwargs=None):
    """Optimise ``z`` for ``n_iter`` steps and log every step.

    Args:
        B: target picture.
        z: tensor of coordinates; expected to be a parameter of ``optim`` so
            that ``optim.step()`` updates it in place.
        optim: optimizer used for ``zero_grad()`` / ``step()``.
        n_iter: number of optimisation steps.
        model_fn: called as ``model_fn(z, x, **model_kwargs)``.
        model_kwargs: extra keyword arguments for ``model_fn`` (default none).
        loss_fn: called as ``loss_fn(B, z, x, F, **loss_kwargs)``; must return
            a sequence of scalar loss terms, which are summed.
        loss_kwargs: extra keyword arguments for ``loss_fn`` (default none).

    Returns:
        DataFrame with one row per iteration and columns ``iter``, ``z_<i>``
        (value of ``z`` before the step), ``loss``, ``loss_l<j>`` (individual
        loss terms) and ``zgrad_<i>`` (gradient used for the step).
    """
    x = picture_to_coordinates(B)
    if model_kwargs is None:
        model_kwargs = dict()
    if loss_kwargs is None:
        loss_kwargs = dict()
    tracker = defaultdict(list)

    for i in range(1, n_iter + 1):
        tracker['iter'].append(i)
        # Log z before the update so it lines up with the loss and gradient
        # recorded below for the same iteration.
        for zi, z_val in enumerate(z.detach()):
            tracker[f'z_{zi}'].append(z_val.item())

        F = model_fn(z, x, **model_kwargs)
        losses = loss_fn(B, z, x, F, **loss_kwargs)
        l = sum(losses)
        optim.zero_grad()
        l.backward()
        optim.step()

        tracker['loss'].append(l.item())
        for li, loss_term in enumerate(losses):
            tracker[f'loss_l{li}'].append(loss_term.item())
        for zi, z_grad in enumerate(z.grad):
            tracker[f'zgrad_{zi}'].append(z_grad.item())
    return pd.DataFrame(tracker)
      

In [6]:
# Plotting functions
def loss_over_iter(df, color='rgb(34,100,192)'):
    """Build a Plotly scatter trace of the total loss against iteration.

    Args:
        df: DataFrame with ``'iter'`` and ``'loss'`` columns.
        color: marker colour of the trace.

    Returns:
        A ``go.Scatter`` named ``'loss'``; it is not attached to any figure.
    """
    marker_style = dict(color=color)
    return go.Scatter(
        x=df['iter'],
        y=df['loss'],
        name='loss',
        marker=marker_style,
    )

def losses_over_iter(df, colors=None):
    """Build one Plotly scatter trace per individual loss term.

    Args:
        df: DataFrame with an ``'iter'`` column and loss-term columns whose
            names start with ``'loss_'``.
        colors: optional list of colour strings, one per loss-term column in
            column order. When ``None``, a random RGB colour is drawn for
            each column with ``np.random``.

    Returns:
        List of ``go.Scatter`` traces, each named after its column.
    """
    loss_df = df.filter(regex='^loss_')
    if colors is None:
        colors = list()
        for _ in loss_df.columns:
            # randint's upper bound is exclusive: 256 is required for the
            # full 0-255 channel range (255 itself was unreachable before).
            r, g, b = np.random.randint(0, 256, 3)
            colors.append(f'rgb({r},{g},{b})')

    return [
        go.Scatter(
            x=df['iter'],
            y=df[col],
            name=col,
            marker=dict(color=colors[li]),
        )
        for li, col in enumerate(loss_df)
    ]

def loss_over_space(df, color='rgb(34,100,192)'):
    """Build a Plotly scatter trace of the total loss against position.

    Args:
        df: DataFrame with ``'x'`` and ``'loss'`` columns.
        color: marker colour of the trace.

    Returns:
        A ``go.Scatter`` named ``'loss'``; it is not attached to any figure.
    """
    marker_style = dict(color=color)
    return go.Scatter(
        x=df['x'],
        y=df['loss'],
        name='loss',
        marker=marker_style,
    )

In [13]:
def trace_z(run_df):
    """Build one Plotly trace per z coordinate: its position against the loss.

    Args:
        run_df: DataFrame with a ``'loss'`` column and coordinate columns
            whose names start with ``'z_'``.

    Returns:
        List of ``go.Scatter`` traces, one per ``z_*`` column, with the
        coordinate value on the x axis and the total loss on the y axis.
    """
    traces = list()
    for z_label in run_df.filter(regex='^z_'):
        z_go = go.Scatter(
            x=run_df[z_label],
            y=run_df['loss'],
            # Name each trace after its coordinate so the legend entries can
            # be told apart (previously every trace was labelled 'loss').
            name=z_label,
        )
        traces.append(z_go)
    return traces


In [7]:
def loss_landscape(z, B, n_samples:int=100, 
        model_fn=model1, model_kwargs=None, 
        loss_fn=loss1, loss_kwargs=None):
    """Sample the loss for a single coordinate swept across the picture.

    The coordinate takes ``n_samples`` evenly spaced values from ``-1.2`` to
    ``1.2`` times the picture half-width; at each value the model and loss are
    evaluated with a one-element coordinate tensor.

    Args:
        z: not used; every evaluation builds its own one-element tensor.
        B: target picture.
        n_samples: number of positions to evaluate.
        model_fn: called as ``model_fn(z, x, **model_kwargs)``.
        model_kwargs: extra keyword arguments for ``model_fn`` (default none).
        loss_fn: called as ``loss_fn(B, z, x, F, **loss_kwargs)``; must return
            a sequence of scalar loss terms, which are summed.
        loss_kwargs: extra keyword arguments for ``loss_fn`` (default none).

    Returns:
        DataFrame with one row per position and columns ``x``, ``loss`` and
        ``loss_l<j>`` for the individual loss terms.
    """
    x = picture_to_coordinates(B)
    xrad = get_picture_radians(B)
    loss_kwargs = dict() if loss_kwargs is None else loss_kwargs
    model_kwargs = dict() if model_kwargs is None else model_kwargs

    tracker = defaultdict(list)
    for xcoord in np.linspace(-1.2*xrad, 1.2*xrad, n_samples):
        z_probe = torch.tensor([xcoord], dtype=torch.float32)
        F = model_fn(z_probe, x, **model_kwargs)
        loss_terms = loss_fn(B, z_probe, x, F, **loss_kwargs)
        total = sum(loss_terms)
        tracker['x'].append(xcoord)
        tracker['loss'].append(total.item())
        for li, term in enumerate(loss_terms):
            tracker[f'loss_l{li}'].append(term.item())
    return pd.DataFrame(tracker)

# Setup

In [8]:
# Image: a 1-D, 5-pixel target with bright pixels at indices 1 and 3.
B = torch.tensor([0, 1, 0, 1, 0])
# Initial coordinates: two trainable positions. Both start outside the
# picture half-width (5 // 2 = 2), so the boundary penalty is active at first.
z = torch.tensor([-2.4, 2.1], requires_grad=True)
# Optimizer: Adam over z only, learning rate 0.1; it updates z in place.
optim = torch.optim.Adam([z], lr=1e-1)

# Run

In [9]:
# Optimise z for 10 steps; run() returns one row per step.
# Note: optim.step() inside run() modifies z in place, so re-running this
# cell continues from the updated z rather than the initial values.
rundf = run(B, z, optim, n_iter=10)
# Sample the loss at 100 positions of a single coordinate. loss_landscape
# overwrites its z argument internally, so the value passed here is not used.
landscapedf = loss_landscape(z, B, n_samples=100)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [12]:
# Display the per-iteration log (10 rows: z, losses and gradients per step).
rundf

Unnamed: 0,iter,z_0,z_1,loss,loss_l0,loss_l1,zgrad_0,zgrad_1
0,1,-2.4,2.1,3.437147,2.937147,0.5,-0.515858,1.527493
1,2,-2.3,2.0,3.200191,2.90019,0.3,-0.78013,0.927493
2,3,-2.200975,1.904097,3.000808,2.799834,0.200975,-1.125598,1.262972
3,4,-2.102927,1.806685,2.735438,2.63251,0.102927,-1.516757,1.512182
4,5,-2.005347,1.708001,2.411409,2.406062,0.005347,-1.909531,1.641218
5,6,-1.907606,1.60837,2.135229,2.135229,0.0,-1.25732,1.63618
6,7,-1.809443,1.508076,1.839657,1.839657,0.0,-1.517065,1.506088
7,8,-1.71019,1.407633,1.540418,1.540418,0.0,-1.659519,1.27891
8,9,-1.609854,1.307942,1.258566,1.258566,0.0,-1.67334,0.992635
9,10,-1.508665,1.210329,1.011615,1.011615,0.0,-1.570801,0.683819


In [10]:
# Loss against iteration: the total loss in green, then one trace per loss term.
iterfig = make_subplots()
iterfig.add_trace(loss_over_iter(rundf, color='rgb(0,200,0)'))
# No colours are passed, so losses_over_iter draws random ones on each run.
for g in losses_over_iter(rundf):
    iterfig.add_trace(g)
# Bare last expression so the notebook renders the figure.
iterfig


In [14]:
# Loss landscape of a single coordinate swept across the picture.
landscapefig = make_subplots()
landscapefig.add_trace(loss_over_space(landscapedf))
# Vertical lines at -/+ the picture half-width, the bounds beyond which
# loss1's relu penalty on z becomes non-zero.
landscapefig.add_vline(x=-get_picture_radians(B))
landscapefig.add_vline(x=get_picture_radians(B))
# Overlay the optimisation path: each z coordinate against the total loss
# recorded at the same step.
for g in trace_z(rundf):
    landscapefig.add_trace(g)

In [15]:
# Render the landscape figure built in the previous cell.
landscapefig

233