In [None]:
import torch
import plotly.express as px
import numpy as np
import pandas as pd
from torch.nn.functional import pad, relu
from collections import defaultdict
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Functions

In [None]:
# Util functions 
def picture_to_coordinates(B):
    """Return float32 pixel coordinates of the 1-D picture ``B``, centred on its middle pixel.

    Pixel ``i`` is mapped to ``i - len(B) // 2``, e.g. 5 pixels give
    ``[-2, -1, 0, 1, 2]`` and 4 pixels give ``[-2, -1, 0, 1]``.
    """
    centre = len(B) // 2
    return torch.arange(len(B), dtype=torch.float32) - centre

def get_picture_radians(B):
    """Half-width of the picture in pixels, i.e. ``len(B) // 2``.

    Despite the name this is a radius measured in pixels, not an angle.
    """
    half_width, _ = divmod(len(B), 2)
    return half_width

def loss1(B, z, x, F, relu_coeff=1.):
    """Reconstruction loss plus a penalty for coordinates outside the picture.

    Args:
        B: Target picture.
        z: Tensor of source coordinates.
        x: Pixel coordinates. Accepted but unused, presumably to match the
            ``loss_fn(B, z, x, F, ...)`` call made by ``run``.
        F: Model output, compared element-wise with ``B``.
        relu_coeff: Weight of the out-of-bounds penalty.

    Returns:
        Tuple ``(reconstruction, boundary)``:

        - ``reconstruction`` is the sum of squared residuals ``(B - F)**2``.
        - ``boundary`` is ``relu_coeff`` times the total distance by which
          each coordinate lies outside ``[-r, r]``, where
          ``r = get_picture_radians(B)``.
    """
    bound = get_picture_radians(B)
    residual = B - F
    reconstruction = residual.pow(2).sum()
    # Non-zero only where z < -bound or z > bound.
    overshoot = relu(z - bound) + relu(-z - bound)
    boundary = relu_coeff * overshoot.sum()
    return reconstruction, boundary

## Models

In [None]:

# Models
def model1(z, x, k=1.):
    """Render sources at coordinates ``z`` as Gaussian bumps sampled on grid ``x``.

    Each pixel ``x[i]`` gets ``max_j exp(-k * (x[i] - z[j])**2)``, i.e. the
    brightest bump at that pixel.

    Args:
        z: Tensor of source coordinates (0-D or 1-D).
        x: Tensor of pixel coordinates (0-D or 1-D).
        k: Sharpness of each bump (larger is narrower).

    Returns:
        Tensor with one value per pixel in ``x``. If ``z`` is empty, returns
        all zeros (no sources means an empty picture).
    """
    x_col = x.reshape(-1)[:, None]
    z_row = z.reshape(-1)[None, :]
    # (len(x), len(z)) squared distances via broadcasting. Identical to
    # torch.meshgrid(x, z) with its default 'ij' indexing, but without the
    # warning about the missing `indexing` argument.
    E = torch.exp(-k * (x_col - z_row).pow(2))
    if E.shape[-1] == 0:
        return E.new_zeros(E.shape[0])
    # .max (not .amax) keeps the original gradient routing to the argmax source.
    return E.max(dim=-1).values

In [None]:
def run(B, z, optim, n_iter:int=10,
        model_fn=model1, model_kwargs=None,
        loss_fn=loss1, loss_kwargs=None):
    """Optimise the coordinates ``z`` so the model output matches picture ``B``.

    Each iteration:

    1. Renders ``F = model_fn(z, x, **model_kwargs)`` on the pixel
       coordinates ``x = picture_to_coordinates(B)``.
    2. Sums the loss terms returned by ``loss_fn(B, z, x, F, **loss_kwargs)``.
    3. Back-propagates and calls ``optim.step()``.

    ``optim`` is expected to optimise ``z``, so ``z`` is modified in place.

    Args:
        B: Target picture.
        z: 1-D tensor of coordinates with ``requires_grad=True``.
        optim: torch optimizer over ``z``.
        n_iter: Number of optimisation steps.
        model_fn, model_kwargs: Forward model and extra keyword arguments.
        loss_fn, loss_kwargs: Loss function returning a sequence of scalar
            loss terms, and its extra keyword arguments.

    Returns:
        DataFrame with one row per iteration and these columns:

        - ``iter``: 1-based iteration number.
        - ``z_{i}``: coordinate *before* the step.
        - ``loss``: sum of all loss terms.
        - ``loss_l{j}``: each individual loss term.
        - ``zgrad_{i}``: gradient at that coordinate.
    """
    x = picture_to_coordinates(B)
    if loss_kwargs is None:
        loss_kwargs = dict()
    if model_kwargs is None:
        model_kwargs = dict()
    tracker = defaultdict(list)

    for i in range(1, n_iter + 1):
        tracker['iter'].append(i)
        # Record coordinates before the step so each row pairs z with the loss it produced.
        for zi in range(len(z)):
            tracker[f'z_{zi}'].append(z[zi].item())

        F = model_fn(z, x, **model_kwargs)
        losses = loss_fn(B, z, x, F, **loss_kwargs)
        l = sum(losses)
        optim.zero_grad()
        l.backward()
        optim.step()

        tracker['loss'].append(l.item())
        for li, term in enumerate(losses):
            tracker[f'loss_l{li}'].append(term.item())
        # z.grad still holds this iteration's gradient: step() does not clear it.
        for zi in range(len(z)):
            tracker[f'zgrad_{zi}'].append(z.grad[zi].item())
    return pd.DataFrame(tracker)
      

In [None]:
# Plotting functions
def loss_over_iter(df, color='rgb(34,100,192)'):
    """Build a plotly trace of the total loss against the iteration number.

    Args:
        df: DataFrame with ``iter`` and ``loss`` columns (as returned by ``run``).
        color: plotly colour string for the marker.

    Returns:
        A ``go.Scatter`` named ``'loss'``.
    """
    iterations, total_loss = df['iter'], df['loss']
    return go.Scatter(x=iterations, y=total_loss, name='loss', marker={'color': color})

def losses_over_iter(df, colors=None):
    """Build one plotly trace per individual loss term against the iteration number.

    Args:
        df: DataFrame with an ``iter`` column and one ``loss_*`` column per
            loss term. The total ``loss`` column is not matched by ``^loss_``.
        colors: Optional sequence of plotly colour strings, one per loss
            column. If shorter than the number of loss columns it is cycled.
            If ``None`` or empty, random ``rgb(r,g,b)`` colours are drawn from
            the global NumPy RNG; pass ``colors`` for reproducible output.

    Returns:
        List of ``go.Scatter`` traces, in column order.
    """
    loss_df = df.filter(regex='^loss_')
    if colors is None or len(colors) == 0:
        # randint's upper bound is exclusive: 256 lets every channel reach 255.
        colors = [
            'rgb({},{},{})'.format(*np.random.randint(0, 256, 3))
            for _ in loss_df.columns
        ]

    return [
        go.Scatter(
            x=df['iter'],
            y=df[col],
            name=col,
            marker=dict(color=colors[li % len(colors)]),
        )
        for li, col in enumerate(loss_df)
    ]


def loss_over_space(df, color='rgb(34,100,192)'):
    """Build a plotly trace of the total loss against the probe position.

    Args:
        df: DataFrame with ``x`` and ``loss`` columns (as returned by ``loss_landscape``).
        color: plotly colour string for the marker.

    Returns:
        A ``go.Scatter`` named ``'loss'``.
    """
    positions, total_loss = df['x'], df['loss']
    return go.Scatter(x=positions, y=total_loss, name='loss', marker={'color': color})

In [None]:
def trace_z(run_df):
    """One plotly trace per tracked coordinate: recorded total loss versus that coordinate.

    Args:
        run_df: DataFrame with ``z_{i}`` columns and a ``loss`` column (as
            returned by ``run``).

    Returns:
        List of ``go.Scatter`` traces, each named after its ``z_{i}`` column.
        ``y`` is the total loss of the whole configuration at that iteration,
        not a per-coordinate loss.
    """
    # Iterating a DataFrame yields its column labels. The name must be the
    # label itself, not the literal string 'z_label'.
    return [
        go.Scatter(x=run_df[z_label], y=run_df['loss'], name=z_label)
        for z_label in run_df.filter(regex='^z_')
    ]


## Loss landscape

In [None]:
def loss_landscape(z, B, n_samples:int=100,
        model_fn=model1, model_kwargs=None,
        loss_fn=loss1, loss_kwargs=None, extent: float = 1.2):
    """Evaluate the loss of a single source swept across the picture.

    ``n_samples`` probe positions are taken uniformly in ``[-extent*r, extent*r]``,
    where ``r = get_picture_radians(B)``. At each position a one-element
    coordinate tensor is rendered with ``model_fn`` and scored with ``loss_fn``.

    Args:
        z: Unused; kept for backward compatibility. The landscape always
            probes one source, regardless of ``len(z)``.
        B: Target picture.
        n_samples: Number of probe positions.
        model_fn, model_kwargs: Forward model and extra keyword arguments.
        loss_fn, loss_kwargs: Loss function and extra keyword arguments.
        extent: Sweep half-width as a multiple of the picture radius (default
            1.2, so the sweep extends slightly beyond the picture edges).

    Returns:
        DataFrame with columns ``x`` (probe position), ``loss`` (sum of all
        terms) and ``loss_l{j}`` (each term).
    """
    x = picture_to_coordinates(B)
    xrad = get_picture_radians(B)
    if loss_kwargs is None:
        loss_kwargs = dict()
    if model_kwargs is None:
        model_kwargs = dict()
    tracker = defaultdict(list)

    for xcoord in np.linspace(-extent * xrad, extent * xrad, n_samples):
        # Separate name so the `z` argument is not silently shadowed.
        probe = torch.tensor([xcoord], dtype=torch.float32)
        F = model_fn(probe, x, **model_kwargs)
        losses = loss_fn(B, probe, x, F, **loss_kwargs)
        tracker['x'].append(xcoord)
        tracker['loss'].append(sum(losses).item())
        for li, term in enumerate(losses):
            tracker[f'loss_l{li}'].append(term.item())
    return pd.DataFrame(tracker)

# Setup

In [None]:
# Target picture: a 1-D strip of pixel intensities (int64 tensor)
B = torch.tensor([0, 1, 0, 1, 0])
# Initial source coordinates; a leaf tensor that tracks gradients
z = torch.tensor([-2.4, 2.1]).requires_grad_()
# Adam over the coordinates only
optim = torch.optim.Adam(params=[z], lr=0.1)

# Run

In [None]:
# Optimise z, then sweep a single probe source across the picture.
# NOTE: `run` steps `optim`, which updates z in place; re-running this cell
# without re-running the setup cell continues from the optimised coordinates.
rundf = run(B=B, z=z, optim=optim, n_iter=100)
landscapedf = loss_landscape(z=z, B=B, n_samples=100)

In [None]:
# Per-iteration record from `run`: coordinates, loss terms and gradients.
rundf

## Plotting

In [None]:
# Total loss (green) and each individual loss term across optimisation iterations.
iterfig = make_subplots()
iterfig.add_trace(loss_over_iter(rundf, color='rgb(0,200,0)'))
iterfig.add_traces(losses_over_iter(rundf))
iterfig.update_layout(
    title='Loss over iterations',
    xaxis_title='iteration',
    yaxis_title='loss',
)
iterfig


In [None]:
# Single-source loss landscape, picture bounds (±radius, dashed) and the z trajectories from `run`.
# NOTE: the trajectories plot run's total loss over *all* coordinates, while the
# landscape probes one source, so the trajectory points need not lie on the curve.
landscapefig = make_subplots()
landscapefig.add_trace(loss_over_space(landscapedf))
landscapefig.add_vline(x=-get_picture_radians(B), line_dash='dash')
landscapefig.add_vline(x=get_picture_radians(B), line_dash='dash')
landscapefig.add_traces(trace_z(rundf))
landscapefig.update_layout(
    title='Loss landscape and optimisation trajectories',
    xaxis_title='coordinate',
    yaxis_title='loss',
)

In [None]:
# Render the loss-landscape figure assembled in the previous cell.
landscapefig

In [None]:
# Include app2
x = torch.tensor([1,-1.])
w = torch.arange(10)-5
# Pairwise differences l[i, j] = w[i] - x[j], shape (10, 2).
# `w.T` is a no-op on a 1-D tensor (and deprecated in recent PyTorch), so the
# original `w.T - x` tried to broadcast shapes (10,) and (2,) and raised.
# `w[:, None]` makes w a column vector so broadcasting builds the full matrix.
l = w[:, None] - x

In [None]:
# Include app2
# NOTE(review): this cell cannot run as written. It re-creates x and w as in
# the previous "Include app2" cell and then raises on the `l` line.
x = torch.tensor([1,-1.])
w = torch.arange(10)-5
# `w.T` leaves the 1-D w (shape (10,)) unchanged; (10,) minus x of shape (2,)
# is not broadcastable -> RuntimeError.
l  = (w.T-x)
# NOTE(review): never reached. `w@x.T` would also fail (length 10 vs 2, and
# int64 vs float32). np.linalg.eig expects a square 2-D array and returns an
# (eigenvalues, eigenvectors) pair, so subtracting its result from a tensor is
# not meaningful. Intended computation unknown -- TODO clarify or delete.
y = w.T-np.linalg.eig(w@x.T)