In [None]:
import os

# Run from the repository root so `src`, `models` and `./ckpt` resolve.
# Guarded so re-running this cell does not keep climbing directories
# (assumes the notebook lives one level below the root, which contains `src/`).
if not os.path.isdir('src'):
    os.chdir('..')

In [None]:
# Automatically reload imported modules when their source changes, so edits to
# project code (e.g. `src`, `models`) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import numpy as np

from src.utils_freq import rgb2gray, dct, dct2, idct, idct2, batch_dct, getDCTmatrix

from models import linear_model

import matplotlib
import matplotlib.pyplot as plt

# hyper params initalization
_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.set_printoptions(linewidth=200, edgeitems=5, precision=4)

In [None]:
args = {'d': 3, "itr": 200001}

# Reference weights specified in the DCT (frequency) domain; the third
# frequency has a zero coefficient.
w_tilde_star = torch.zeros(args['d'], 1)
w_tilde_star[:3, 0] = torch.tensor([5., 10., 0.])

# Per-frequency sigma values; the third frequency is given sigma = 0.
sigma_tilde = torch.zeros(args['d'], 1)
sigma_tilde[:3, 0] = torch.tensor([0.1, 0.05, 0.])

# Same reference weights mapped back through the inverse DCT.
w_star = idct(w_tilde_star)
model = linear_model(args["d"]).to(_device)

# Initial weights: first three frequency coefficients fixed, then inverse-DCT'd.
# dtype/device follow w_init so the assigned values match element-wise assignment.
w_init = torch.randn_like(model.linear.weight)
w_init[0, :3] = torch.tensor([0.01, -0.01, 0.02], dtype=w_init.dtype, device=w_init.device)
w_init = idct(w_init.t()).t()

In [None]:
# Weight trajectories saved by the training runs, one checkpoint per optimizer.
# The last axis of w_log follows this order: GD, Adam, RMSProp, SignGD.
_ckpt_names = ['gd', 'adam', 'rmsprop', 'signgd']
w_log = torch.zeros(args["d"], args["itr"], len(_ckpt_names))
for _k, _name in enumerate(_ckpt_names):
    # map_location='cpu' lets checkpoints written from a GPU run load on a
    # CPU-only machine (w_log itself lives on the CPU).
    w_log[:, :, _k] = torch.load(f'./ckpt/synthetic-{_name}.pt', map_location='cpu')

In [None]:
def plot_std_adv_risk(log, sigma_tilde, eps = 1, n=3, robust_w = None, save=None):
    """Plot standard (left) and adversarial (right) risk curves for each optimizer.

    Parameters
    ----------
    log : torch.Tensor
        Weight trajectories indexed ``log[coordinate, iteration, method]``; the
        method axis is ordered GD, Adam, RMSProp, SignGD. Each trajectory is
        mapped to the DCT domain before the risks are computed.
    sigma_tilde : torch.Tensor
        Per-frequency sigma values; only entries 0 and 1 are read.
    eps : float
        Perturbation radius used in the adversarial-risk formula.
    n : int
        Moving-average window applied to both curves (``n=1`` disables smoothing).
    robust_w : sequence
        ``robust_w[1]`` is the DCT-domain reference weight the error is measured
        against; ``robust_w[0]`` is not used here. Required.
    save : str or path-like, optional
        When given, the finished figure is also written to this path.

    Side effects: prints the last adversarial-risk value of every method.

    Raises
    ------
    ValueError
        If ``robust_w`` is not supplied.
    """
    if robust_w is None:
        raise ValueError("robust_w is required: robust_w[1] must hold the DCT-domain reference weights")

    fig, axs = plt.subplots(ncols=2, nrows=1, figsize=(20, 3))

    w_log = log.clone().detach()
    d, _, num_methods = w_log.shape
    dct_matrix = getDCTmatrix(d)  # loop-invariant, build once
    w_tilde = torch.zeros_like(w_log)
    for i in range(num_methods):
        w_tilde[:, :, i] = batch_dct(w_log[:, :, i].t(), dct_matrix).t()

    # Broadcasting over (iteration, method) replaces an explicit repeat.
    e_tilde = w_tilde - robust_w[1].reshape(d, 1, 1)

    labels = ['GD', 'Adam', 'RMSProp', 'SignGD']

    for i, method in enumerate(labels):
        color = "C" + str(i)
        e0_sq = e_tilde[0, :, i] ** 2
        e1_sq = e_tilde[1, :, i] ** 2

        # NOTE(review): the standard risk weights e_k^2 by sigma_tilde[k], while
        # sum_term below weights it by sigma_tilde[k]**2 (and the adversarial
        # risk adds a 0.5 factor). These cannot both describe the same data
        # covariance -- confirm the intended convention before changing numbers.
        std_risk = moving_average(e0_sq * sigma_tilde[0] + e1_sq * sigma_tilde[1], n=n).numpy()
        x = np.arange(1, len(std_risk) + 1)
        axs[0].plot(x, std_risk, color=color, linewidth=3.0, marker="", label=method, alpha=0.8)

        # Adversarial risk as implemented:
        #   0.5*q + eps*sqrt(2/pi*q)*||w~|| + eps^2/2*||w~||^2,  q = sum_k sigma_k^2 e_k^2.
        # torch.sqrt keeps everything a tensor instead of round-tripping through numpy.
        sum_term = sigma_tilde[0] ** 2 * e0_sq + sigma_tilde[1] ** 2 * e1_sq
        w_sq_norm = (w_tilde[:, :, i] ** 2).sum(dim=0)
        adv_risk = moving_average(
            0.5 * sum_term
            + eps * torch.sqrt(2 / np.pi * sum_term) * torch.sqrt(w_sq_norm)
            + eps ** 2 / 2 * w_sq_norm,
            n=n,
        ).numpy()
        axs[1].plot(x, adv_risk, color=color, linewidth=3.0, marker="", alpha=0.8)
        print(adv_risk[-1])

        if method in ('GD', 'SignGD'):
            # Dashed horizontal reference line at this method's final adversarial risk.
            x_ref = np.arange(-10, len(std_risk) + 10, 1)
            axs[1].plot(x_ref, np.ones_like(x_ref) * adv_risk[-1], color=color,
                        linewidth=2.0, marker="", linestyle=(0, (5, 5)), alpha=0.8)

    for i, loss in enumerate(["s", "a"]):
        axs[i].tick_params(axis="both", labelsize=15)
        axs[i].set_ylabel(r"$\mathcal{R}_{" + loss + "} (t)$", rotation=0, labelpad=20, fontsize=15)
        axs[i].grid()
        axs[i].set_xlabel("Training iteration (t)", fontsize=15)
        axs[i].set_xscale('log')

    # Hard-coded for this experiment. NOTE(review): 123 and 141 look like the
    # rounded final GD / SignGD adversarial risks printed above (their labels are
    # coloured to match those curves) -- update them if the data changes.
    axs[1].set_yticks([0, 50, 100, 123, 141, 150])
    my_colors = ['k', 'k', 'k', '#1f77b4', '#d62728', 'k']
    my_size = [15, 15, 15, 13, 13, 15]
    for ticklabel, tickcolor, ticksize in zip(axs[1].get_yticklabels(), my_colors, my_size):
        ticklabel.set_color(tickcolor)
        ticklabel.set_size(ticksize)

    axs[0].legend(fontsize=15)

    if save is not None:
        fig.savefig(save, bbox_inches='tight')
    
def moving_average(a, n=3):
    """Trailing moving average of a 1-D sequence with window ``n``.

    Output element ``k`` is the mean of ``a[k:k + n]``, giving ``len(a) - n + 1``
    entries, except that the first entry is replaced by ``a[0]`` so the smoothed
    curve starts at the raw initial value.

    Parameters
    ----------
    a : torch.Tensor or array-like
        1-D sequence to smooth.
    n : int
        Window length. ``n == 1`` returns ``a`` itself, unchanged.

    Returns
    -------
    torch.Tensor or numpy.ndarray
        For ``n > 1``: a float64 ``torch.Tensor`` when ``a`` is a tensor (callers
        chain ``.numpy()``), otherwise a float64 ``numpy.ndarray``.

    Raises
    ------
    ValueError
        If ``n < 1`` or ``n > len(a)``.
    """
    if n == 1:
        return a

    # Convert explicitly instead of relying on numpy wrapping results back into
    # a tensor via __array_wrap__, which is fragile across numpy/torch versions.
    is_tensor = isinstance(a, torch.Tensor)
    values = a.detach().cpu().numpy() if is_tensor else np.asarray(a)
    if n < 1 or n > len(values):
        raise ValueError(f"window n={n} must satisfy 1 <= n <= len(a)={len(values)}")

    csum = np.cumsum(values, dtype=float)
    csum[n:] = csum[n:] - csum[:-n]  # window sums ending at each index >= n
    smoothed = csum[n - 1:] / n
    smoothed[0] = values[0]
    return torch.from_numpy(smoothed) if is_tensor else smoothed

In [None]:
def plot_w_tilde_LR_freq_only(log, n=3, robust_w = None, transition_iter=650):
    """Plot the per-frequency error |e~_k(t)| of every optimizer, one panel per k.

    Parameters
    ----------
    log : torch.Tensor
        Weight trajectories indexed ``log[coordinate, iteration, method]``; the
        method axis is ordered GD, Adam, RMSProp, SignGD. Trajectories are mapped
        to the DCT domain before plotting. Three panels are drawn (k = 0, 1, 2).
    n : int
        Moving-average window applied to each curve (``n=1`` disables smoothing).
    robust_w : sequence
        ``robust_w[1]`` is the DCT-domain reference weight; ``robust_w[0]`` is
        not used here. Required.
    transition_iter : float
        Iteration separating the green (before) and magenta (after) shaded
        regions. Defaults to 650, the value previously hard-coded.

    Raises
    ------
    ValueError
        If ``robust_w`` is not supplied.
    """
    if robust_w is None:
        raise ValueError("robust_w is required: robust_w[1] must hold the DCT-domain reference weights")

    fig, axs = plt.subplots(ncols=3, nrows=1, figsize=(20, 4))

    w_log = log.clone().detach()
    d = w_log.shape[0]
    dct_matrix = getDCTmatrix(d)  # loop-invariant, build once
    w_tilde = torch.zeros_like(w_log)
    for i in range(w_log.shape[2]):
        w_tilde[:, :, i] = batch_dct(w_log[:, :, i].t(), dct_matrix).t()

    # Tensor .abs() keeps e_tilde a tensor (moving_average(...).numpy() below
    # needs one) without relying on numpy's __array_wrap__ round-trip.
    e_tilde = (w_tilde - robust_w[1].reshape(d, 1, 1)).abs()

    labels = ['GD', 'Adam', 'RMSProp', 'SignGD']

    for dim, ax in enumerate(axs):
        for i, method in enumerate(labels):
            y = moving_average(e_tilde[dim, :, i], n=n).numpy()
            x = np.arange(1, len(y) + 1)
            ax.semilogx(x, y, color="C" + str(i), linewidth=3.0, marker="",
                        label=method if dim == 0 else None, alpha=0.8)
        ax.grid()
        ax.set_ylabel(r"$|\tilde{e}_{" + str(dim) + "}(t)|$", rotation=0, labelpad=30, fontsize=15)
        ax.tick_params(axis="both", labelsize=15)
        ax.set_xlabel("Training iteration (t)", fontsize=15)

    # Remember the data-driven limits: the shaded rectangles below are
    # deliberately oversized and would otherwise stretch the axes.
    limits = [(ax.get_xlim(), ax.get_ylim()) for ax in axs]

    axs[0].legend(fontsize=15)

    green_label = r'$|\tilde{e}_2|$ grows till $\tilde{e}_0$ begins' + '\n' + 'oscillating around 0'
    magenta_label = r'$|\tilde{e}_2|$ cannot be corrected'
    heights = [8.5, 13, 8.5]
    last = len(axs) - 1
    for dim, (ax, height) in enumerate(zip(axs, heights)):
        # Only the last panel's patches are labelled, for its legend.
        ax.add_patch(matplotlib.patches.Rectangle((transition_iter, -1), -1000, height, alpha=0.1, color='g',
                                                  label=green_label if dim == last else None))
        ax.add_patch(matplotlib.patches.Rectangle((transition_iter, -1), 400000, height, alpha=0.1, color='m',
                                                  label=magenta_label if dim == last else None))
    axs[2].legend(fontsize=14, loc=(0.40, 0.2))

    fig.tight_layout()

    for ax, (xlim, ylim) in zip(axs, limits):
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)

In [None]:
# Per-frequency distance to the reference weights for each optimizer (no smoothing).
plot_w_tilde_LR_freq_only(log=w_log, n=1, robust_w=[w_star, w_tilde_star])

In [None]:
# Standard and adversarial risk curves with perturbation radius eps = sqrt(2), no smoothing.
plot_std_adv_risk(log=w_log, sigma_tilde=sigma_tilde, eps=np.sqrt(2), n=1, robust_w=[w_star, w_tilde_star])