# Diffusion vs Gradient Diffusion

In [None]:
import matplotlib
import numpy as np
import torch
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from tqdm import tqdm

from loss_functions import TV_loss, TVV_loss, diffusion_loss, grad_diffusion_loss, robust_diffusion_loss, robust_grad_diffusion_loss
# Colormap shared by the plotting helpers below: each plotted curve is coloured
# by how far the optimisation has progressed.
cmap = plt.cm.RdYlBu

In [None]:
# Shared settings for the "Diffusion vs Gradient Diffusion" comparison.
# Every signal is laid out as (1, 1, 3, 100): three copies of a 100-sample curve.
momentum = 0.99  # module-level value; the optimize(...) calls pass their own momentum
image = torch.ones(1, 1, 3, 100)  # second argument handed to every loss inside optimize()
lr = 0.15
rand = torch.randn(1, 1, 3, 100)  # noise, drawn independently for each of the three rows
case1 = torch.sin(torch.linspace(0, 6, 100)).view(1, 1, 1, 100).expand(1, 1, 3, 100)
# Give the abs-shaped signal the same explicit layout as case1 instead of relying
# on broadcasting against `rand` to reach (1, 1, 3, 100).
case2 = (-1 + 2*torch.abs(torch.linspace(-1, 1, 100))).view(1, 1, 1, 100).expand(1, 1, 3, 100)
# These two are captured as default argument values when optimize() is defined.
iter_per_line = 10
lines = 100

def optimize(tensor, loss_function, weight, momentum, lr, kappa, loss_name, function_name, ax, lines=lines, iter_per_line=iter_per_line):
    """Minimise ``loss_function`` with SGD and plot how the signal evolves.

    A copy of ``tensor`` is wrapped in a ``torch.nn.Parameter`` and updated for
    ``lines`` rounds of ``iter_per_line`` SGD steps. Row ``[0, 0, 1]`` of the
    parameter is drawn on ``ax`` once before optimisation (labelled
    'initial value') and once after every round, coloured through the
    module-level ``cmap`` by round index. The run stops as soon as the loss
    of a finished round is NaN.

    Args:
        tensor: Starting values. Indexed as ``[0, 0, 1]`` for plotting; the
            callers in this file pass shape (1, 1, 3, 100). It is cloned, so
            the caller's tensor is not modified and expanded views are fine.
        loss_function: Called as ``loss_function(param, image, kappa)``, where
            ``image`` is the module-level global looked up at call time. Its
            result is scaled by ``weight`` and reduced with ``.mean()``.
        weight: Scalar multiplier applied to the loss.
        momentum: Momentum passed to ``torch.optim.SGD``.
        lr: Learning rate passed to ``torch.optim.SGD``.
        kappa: Forwarded unchanged as the third argument of ``loss_function``.
        loss_name: Loss name used in the axes title.
        function_name: Signal name used in the axes title.
        ax: Matplotlib axes to draw on. It is cleared first; the colorbar is
            attached to the figure that owns it.
        lines: Number of rounds, i.e. maximum number of curves drawn. The
            default is the value of the global ``lines`` at definition time.
        iter_per_line: SGD steps per round. The default is the value of the
            global ``iter_per_line`` at definition time.
    """
    total_iterations = lines*iter_per_line
    # Throwaway mesh: it only serves as the mappable (cmap + value range) for
    # the colorbar created at the end; ax.clear() takes it off the axes again.
    pcf = ax.pcolormesh([[0,1],[1,0]],cmap=cmap,vmin=0,vmax=total_iterations)
    ax.clear()

    # Optimise a private, contiguous copy rather than the caller's storage.
    param = torch.nn.Parameter(tensor.detach().clone())
    # Plot the same row ([0, 0, 1]) that is tracked during optimisation, and
    # copy it so the curve is a snapshot rather than a view of live storage.
    ax.plot(param.detach()[0,0,1].numpy().copy(), color=cmap(0), label='initial value')

    optimizer = torch.optim.SGD([param], lr=lr, momentum=momentum)
    for i in tqdm(range(lines)):
        for _ in range(iter_per_line):
            optimizer.zero_grad()
            loss = weight*loss_function(param, image, kappa)
            loss = loss.mean()
            loss.backward()
            optimizer.step()
        # Diverged: stop before drawing a curve full of NaNs.
        if torch.isnan(loss).item():
            break
        ax.plot(param.detach()[0,0,1].numpy().copy(), color=cmap(i/lines))

    ax.set_title("{} loss on {}\nlr = {}, momentum = {}".format(loss_name, function_name, lr, momentum))
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    # Use the figure owning `ax` instead of whatever global `fig` happens to exist.
    cbar = ax.figure.colorbar(pcf,cax=cax, ticks = [0,total_iterations])
    cbar.set_ticklabels(['0','1e+{}'.format(int(np.log10(total_iterations)))])
    cbar.ax.tick_params(axis='y', direction='in')
    cbar.set_label('# of iterations', rotation=270)


# 4x2 grid, one row per loss (TV, diffusion, TVV, gradient diffusion).
# Left column: noisy sinusoid (case1). Right column: noisy abs function (case2).
fig, axes = plt.subplots(
    nrows=4, ncols=2, figsize=(15, 20), dpi=200, sharex=True, sharey=True,
)

# --- left column: sinusoid -------------------------------------------------
optimize(
    tensor=case1 + 0.2*rand, loss_function=TV_loss, weight=1, momentum=0.9,
    lr=lr, kappa=50, loss_name="TV", function_name="a sinusoid", ax=axes[0, 0],
)
# Only the first panel gets a legend.
axes[0, 0].legend()
optimize(
    tensor=case1 + 0.2*rand, loss_function=diffusion_loss, weight=50, momentum=0.9,
    lr=lr, kappa=50, loss_name="Diffusion", function_name="Sinusoid", ax=axes[1, 0],
)
optimize(
    tensor=case1 + 0.2*rand, loss_function=TVV_loss, weight=1, momentum=0.9,
    lr=lr, kappa=50, loss_name="TVV", function_name="a sinusoid", ax=axes[2, 0],
    iter_per_line=100,
)
optimize(
    tensor=case1 + 0.2*rand, loss_function=grad_diffusion_loss, weight=100, momentum=0.99,
    lr=lr, kappa=50, loss_name="Gradient diffusion", function_name="a sinusoid", ax=axes[3, 0],
    iter_per_line=100,
)

# --- right column: abs function --------------------------------------------
optimize(
    tensor=case2 + 0.2*rand, loss_function=TV_loss, weight=1, momentum=0.9,
    lr=lr, kappa=50, loss_name="TV", function_name="Abs function", ax=axes[0, 1],
)
optimize(
    tensor=case2 + 0.2*rand, loss_function=diffusion_loss, weight=50, momentum=0.9,
    lr=lr, kappa=50, loss_name="Diffusion", function_name="Abs function", ax=axes[1, 1],
)
optimize(
    tensor=case2 + 0.2*rand, loss_function=TVV_loss, weight=4, momentum=0.9,
    lr=lr, kappa=50, loss_name="TVV", function_name="Abs function", ax=axes[2, 1],
    iter_per_line=100,
)
optimize(
    tensor=case2 + 0.2*rand, loss_function=grad_diffusion_loss, weight=100, momentum=0.99,
    lr=lr, kappa=50, loss_name="Gradient diffusion", function_name="Abs function", ax=axes[3, 1],
    iter_per_line=100,
)

plt.show()

In [None]:
# Anisotropic comparison: diffusion (left column) against gradient diffusion
# (right column) on three signals, using an `image` that is 1 in the middle and
# 0 near both ends. optimize() reads `image` from module scope when it runs, so
# the tensor built here is the one the losses receive.
momentum = 0.9  # module-level value; the calls below pass momentum explicitly
lr = 0.1
rand = torch.randn(100)  # multiplied by 0 below, i.e. no noise is added
case1 = 1.5 + torch.sin(torch.linspace(0, 6, 100)).view(1, 1, 1, 100).expand(1, 1, 3, 100)
case2 = 2.5*torch.abs(torch.linspace(-1, 1, 100)).view(1, 1, 1, 100).expand(1, 1, 3, 100)
case3 = 2.5*torch.linspace(1, 0, 100).view(1, 1, 1, 100).expand(1, 1, 3, 100)
#tensor[:,:,:,20:80] += 0.2
image = torch.ones(1, 1, 3, 100)
# NOTE(review): 19 samples are zeroed on the left but 20 on the right — confirm
# the asymmetry is intended.
image[..., :19] = 0
image[..., 80:] = 0
iter_per_line = 10
lines = 100

fig, axes = plt.subplots(
    nrows=3, ncols=2, figsize=(15, 15), dpi=200, sharex=True, sharey=True,
)

# Left column: anisotropic diffusion.
optimize(
    tensor=case1 + 0*rand, loss_function=diffusion_loss, weight=100, momentum=0.9,
    lr=lr, kappa=0.1, loss_name="Anisotropic diffusion", function_name="a sinusoid",
    ax=axes[0, 0], lines=100,
)
optimize(
    tensor=case2 + 0*rand, loss_function=diffusion_loss, weight=100, momentum=0.9,
    lr=lr, kappa=0.1, loss_name="Anisotropic diffusion", function_name="Abs function",
    ax=axes[1, 0], lines=100,
)
optimize(
    tensor=case3 + 0*rand, loss_function=diffusion_loss, weight=100, momentum=0.9,
    lr=lr, kappa=0.1, loss_name="Anisotropic diffusion", function_name="a linear function",
    ax=axes[2, 0], lines=100,
)

# Right column: anisotropic gradient diffusion (ten times more steps per curve).
optimize(
    tensor=case1 + 0*rand, loss_function=grad_diffusion_loss, weight=100, momentum=0.99,
    lr=lr, kappa=0.1, loss_name="Anisotropic gradient diffusion", function_name="a sinusoid",
    ax=axes[0, 1], lines=100, iter_per_line=100,
)
optimize(
    tensor=case2 + 0*rand, loss_function=grad_diffusion_loss, weight=100, momentum=0.99,
    lr=lr, kappa=0.1, loss_name="Anisotropic gradient diffusion", function_name="Abs function",
    ax=axes[1, 1], lines=100, iter_per_line=100,
)
optimize(
    tensor=case3 + 0*rand, loss_function=grad_diffusion_loss, weight=100, momentum=0.99,
    lr=lr, kappa=0.1, loss_name="Anisotropic gradient diffusion", function_name="a linear function",
    ax=axes[2, 1], lines=100, iter_per_line=100,
)

# Overlay the (half-scaled) image profile on every panel, row by row.
for panel in axes.flat:
    panel.plot(image[0, 0, 0].numpy()*0.5, '--', label='image value')
axes[0, 0].legend()


# Gradient Diffusion Loss vs Robust Gradient Diffusion Loss

In [None]:
# Settings for the regular-vs-robust comparisons that follow.
image = torch.ones((1, 1, 3, 100))  # second argument handed to every loss
lines = 100          # captured as a default when optimize_inverse() is defined
iter_per_line = 10   # captured as a default when optimize_inverse() is defined
momentum = 0.99      # module-level value; the calls pass momentum explicitly
randomness = 0.1     # scale applied to `rand` when perturbing the signal
num_lines = 10
iter_btw_lines = 1000
# Standard-normal noise with its lower tail cut off at -0.3.
rand = torch.clamp(torch.randn(100), min=-0.3)

#optimizer = torch.optim.Adam([param], lr=0.05)

def optimize_inverse(tensor, loss_function, weight, momentum, lr, kappa,
                     loss_name, function_name, ax1, ax2, lines=lines, iter_per_line=iter_per_line,
                     gamma=0, iterations=0, inverse=False):
    """Minimise a loss with SGD and plot both the signal and its reciprocal.

    A copy of ``tensor`` is wrapped in a ``torch.nn.Parameter`` and updated for
    ``lines`` rounds of ``iter_per_line`` SGD steps. Row ``[0, 0, 1]`` of the
    parameter is drawn on ``ax1`` and its element-wise reciprocal on ``ax2``,
    once before optimisation (labelled 'initial value') and once after every
    round, coloured through the module-level ``cmap`` by round index. The run
    stops as soon as the loss of a finished round is NaN.

    Args:
        tensor: Starting values. Indexed as ``[0, 0, 1]`` for plotting; the
            callers in this file pass shape (1, 1, 3, 100). It is cloned, so
            the caller's tensor is not modified.
        loss_function: With ``gamma == 0`` it is called as
            ``loss_function(x, image, kappa)``; otherwise as
            ``loss_function(x, image, kappa, gamma, iterations,
            loss_function='abs')``. ``image`` is the module-level global looked
            up at call time. The result is scaled by ``weight`` and reduced
            with ``.mean()``.
        weight: Scalar multiplier applied to the loss.
        momentum: Momentum passed to ``torch.optim.SGD``.
        lr: Learning rate passed to ``torch.optim.SGD``.
        kappa: Forwarded unchanged to ``loss_function``.
        loss_name: Loss name used in the title.
        function_name: Signal name used in the title.
        ax1: Axes showing the parameter itself; cleared first.
        ax2: Axes showing ``1/parameter``; not cleared.
        lines: Number of rounds. The default is the value of the global
            ``lines`` at definition time.
        iter_per_line: SGD steps per round. The default is the value of the
            global ``iter_per_line`` at definition time.
        gamma: Selects the call form above and is forwarded when non-zero.
        iterations: Forwarded to ``loss_function`` when ``gamma`` is non-zero.
        inverse: If true the loss is evaluated on ``1/param`` instead of
            ``param``, and the title and colorbar go on ``ax1``; otherwise
            they go on ``ax2``.
    """
    total_iterations = lines*iter_per_line
    # Throwaway mesh: it only serves as the mappable (cmap + value range) for
    # the colorbar created at the end; ax1.clear() takes it off the axes again.
    pcf = ax1.pcolormesh([[0,1],[1,0]],cmap=cmap,vmin=0,vmax=total_iterations)
    ax1.clear()

    # Optimise a private, contiguous copy rather than the caller's storage.
    param = torch.nn.Parameter(tensor.detach().clone())
    # Plot the same row ([0, 0, 1]) that is tracked during optimisation, and
    # copy it so the curves are snapshots rather than views of live storage.
    start = param.detach()[0,0,1].numpy().copy()
    ax1.plot(start, color=cmap(0), label='initial value')
    ax2.plot(1/start, color=cmap(0), label='initial value')

    optimizer = torch.optim.SGD([param], lr=lr, momentum=momentum)
    for i in tqdm(range(lines)):
        for _ in range(iter_per_line):
            optimizer.zero_grad()
            to_smooth = 1/param if inverse else param
            if gamma==0:
                loss = weight*loss_function(to_smooth, image, kappa)
            else:
                loss = weight*loss_function(to_smooth, image, kappa, gamma, iterations, loss_function='abs')
            loss = loss.mean()
            loss.backward()
            optimizer.step()
        # Diverged: stop before drawing curves full of NaNs.
        if torch.isnan(loss).item():
            break
        current = param.detach()[0,0,1].numpy().copy()
        ax1.plot(current, color=cmap(i/lines))
        ax2.plot(1/current, color=cmap(i/lines))

    # The title and colorbar go on ax1 when `inverse` is set, on ax2 otherwise.
    titled_ax, other_ax = (ax1, ax2) if inverse else (ax2, ax1)
    titled_ax.set_title("{} loss on \n{}, lr = {}, momentum = {}".format(loss_name, function_name, lr, momentum))
    divider = make_axes_locatable(titled_ax)
    divider2 = make_axes_locatable(other_ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    # Reserve the same strip on the other panel and drop it again — presumably
    # so both panels end up the same width.
    spacer = divider2.append_axes("right", size="5%", pad=0.05)
    spacer.remove()
    # Use the figure owning the axes instead of whatever global `fig` exists.
    cbar = titled_ax.figure.colorbar(pcf,cax=cax, ticks = [0,total_iterations])
    cbar.set_ticklabels(['0','1e+{}'.format(int(np.log10(total_iterations)))])
    cbar.ax.tick_params(axis='y', direction='in')
    cbar.set_label('# of iterations', rotation=270, labelpad=-17)

In [None]:
# Three columns: regular gradient diffusion on the inverted signal, regular
# gradient diffusion applied through 1/param, and the robust variant.
# In each call the axes are ordered so that the titled panel is in the top row.
lr = 0.1
fig, axes = plt.subplots(
    nrows=2, ncols=3, figsize=(20, 10), dpi=100, sharex=True, sharey=False,
)

# 2 + sin(x) stays >= 1, replicated over three rows.
case = (2 + torch.sin(torch.linspace(0, 6, 100))).view(1, 1, 1, 100).expand(1, 1, 3, 100)

optimize_inverse(
    tensor=1/(case + randomness*rand), loss_function=grad_diffusion_loss, weight=100,
    momentum=0.99, lr=lr, kappa=500,
    loss_name="Regular gradient diffusion", function_name="an inverse sinusoid",
    ax1=axes[1, 0], ax2=axes[0, 0], lines=100, iter_per_line=100,
)
optimize_inverse(
    tensor=case + randomness*rand, loss_function=grad_diffusion_loss, weight=100,
    momentum=0.99, lr=lr, kappa=500,
    loss_name="Regular inverse (unstable) gradient diffusion", function_name="a sinusoid",
    ax1=axes[0, 1], ax2=axes[1, 1], lines=100, iter_per_line=100, inverse=True,
)
optimize_inverse(
    tensor=case + randomness*rand, loss_function=robust_grad_diffusion_loss, weight=1,
    momentum=0.99, lr=lr, kappa=500,
    loss_name="robust inverse gradient diffusion", function_name="a sinusoid",
    ax1=axes[0, 2], ax2=axes[1, 2], lines=100, iter_per_line=100, inverse=True,
    gamma=0.3, iterations=10,
)
plt.show()

# Diffusion Loss vs Robust Diffusion Loss

In [None]:
# Same three-column layout as the gradient-diffusion comparison, but using the
# plain diffusion losses: regular diffusion on the inverted signal, regular
# diffusion applied through 1/param, and the robust variant.
lr = 15
# NOTE(review): a lower bound of -10.6 is practically never reached by
# standard-normal samples, so this clamp is effectively a no-op — confirm the
# value is intended.
rand = torch.randn(100).clamp(min=-10.6)
fig, axes = plt.subplots(2, 3, figsize=(20, 10), dpi=100,
                         sharex=True, sharey=False)

case = (2 + torch.sin(torch.linspace(0, 6, 100))).view(1, 1, 1, 100).expand(1, 1, 3, 100)
# This panel runs diffusion_loss, so it is titled "Regular diffusion" (the
# former "Regular gradient diffusion" label did not match the loss being run).
optimize_inverse(1/(case + randomness*rand), diffusion_loss, 1, 0.9, lr, 500, "Regular diffusion", "an inverse sinusoid", axes[1,0], axes[0,0], lines=100, iter_per_line=10)
optimize_inverse(case + randomness*rand, diffusion_loss, 10, 0.9, lr, 500, "Regular inverse (unstable) diffusion", "a sinusoid", axes[0,1], axes[1,1], lines=100, iter_per_line=10, inverse=True)
optimize_inverse(case + randomness*rand, robust_diffusion_loss, 0.01, 0.9, lr, 500, "robust inverse diffusion", "a sinusoid", axes[0,2], axes[1,2], lines=100, iter_per_line=10, inverse=True, gamma=0.3, iterations=10)
plt.show()