# __init__

In [None]:
#@title imports
# Load the Drive helper and mount
from google.colab import drive
drive.mount('/content/drive')

import numpy as np

import torch
import matplotlib.pyplot as plt

import plotly
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots

path='/content/drive/My Drive/Colab Notebooks/PADAM/'


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
#@title VGG
'''VGG11/13/16/19 in Pytorch.'''
import torch
import torch.nn as nn
from torch.autograd import Variable


# Layer layouts for each VGG variant, consumed by VGG._make_layers:
# an integer is the number of output channels of a 3x3 conv + BatchNorm + ReLU
# group, and 'M' marks a 2x2 max-pooling layer with stride 2.
cfg = {
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


class VGG(nn.Module):
    """VGG-style convolutional classifier built from a layout in ``cfg``.

    Args:
        vgg_name: Key into the module-level ``cfg`` dict (e.g. ``'VGG11'``).
            An unknown key raises ``KeyError``.
        num_classes: Number of outputs of the final linear layer.
        in_channels: Number of channels of the input images. Defaults to 3,
            which reproduces the previously hard-coded behaviour.

    The classifier is a single ``nn.Linear(512, num_classes)``, so the
    flattened output of ``features`` must hold exactly 512 values per sample.
    """

    def __init__(self, vgg_name, num_classes=10, in_channels=3):
        super().__init__()
        self.features = self._make_layers(cfg[vgg_name], in_channels)
        self.classifier = nn.Linear(512, num_classes)

    def forward(self, x):
        """Run the feature extractor, flatten per sample, and classify."""
        out = self.features(x)
        # Collapse everything but the batch dimension for the linear layer.
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out

    def _make_layers(self, cfg, in_channels=3):
        """Translate a layout list into an ``nn.Sequential``.

        Args:
            cfg: Sequence whose items are either ``'M'`` (2x2 max-pool,
                stride 2) or an int giving the output channels of a
                3x3 conv (padding 1) followed by BatchNorm and ReLU.
            in_channels: Channel count expected by the first convolution.

        Returns:
            The assembled ``nn.Sequential`` feature extractor.
        """
        layers = []
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           nn.BatchNorm2d(x),
                           nn.ReLU(inplace=True)]
                # The next convolution consumes what this one produced.
                in_channels = x
        # A 1x1 average pool with stride 1 leaves the tensor unchanged; it is
        # kept so the Sequential layout matches the original definition.
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)

# net = VGG('VGG11')
# x = torch.randn(2,3,32,32)
# print(net(Variable(x)).size())

In [None]:
#@title plot
def plot_re(methods, dataset, net='vggnet'):
    """Draw training-loss and test-error curves for a set of saved runs.

    Every name in ``methods`` is loaded from
    ``{path}/checkpoint/{net}/{dataset}/{name}`` onto the CPU. A run whose
    load raises is reported via ``print`` and left out of the figure.
    The test-error axis is clipped to a fixed range for 'CIFAR10' and
    'CIFAR100'; other datasets keep matplotlib's automatic limits.
    """
    loaded = {}
    for method in methods:
        try:
            ckpt_file = f'{path}/checkpoint/{net}/{dataset}/{method}'
            loaded[method] = torch.load(ckpt_file, map_location=torch.device('cpu'))
        except Exception as e:
            print(e)

    fig, ax = plt.subplots(2, figsize=(15, 15))
    loss_ax, err_ax = ax[0], ax[1]

    # Top panel: training loss per epoch.
    for method, ckpt in loaded.items():
        loss_ax.plot(ckpt['train_losses'], label=method)
    loss_ax.set_title('train_losses')
    loss_ax.legend()
    loss_ax.set_xlabel('Epochs')
    loss_ax.set_ylim([0, 1])

    # Bottom panel: test error per epoch.
    for method, ckpt in loaded.items():
        err_ax.plot(ckpt['test_errs'], label=method)
    err_ax.set_title('test_errs')
    err_ax.legend()
    err_ax.set_xlabel('Epochs')

    # Dataset-specific zoom on the interesting error range.
    err_limits = {'CIFAR10': [0, 0.2], 'CIFAR100': [0.2, 0.6]}
    if dataset in err_limits:
        err_ax.set_ylim(err_limits[dataset])

    fig.suptitle(f'{net} {dataset}')
    plt.show()

def plot_lrs(methods, dataset, net='vggnet'):
    """Plot the recorded ``'lrs'`` series of each requested run.

    Args:
        methods: Checkpoint file names to load from
            ``{path}/checkpoint/{net}/{dataset}/``.
        dataset: Dataset sub-directory of the checkpoints.
        net: Network sub-directory of the checkpoints; also used as the
            figure title.

    A run whose load raises is reported via ``print`` and skipped.
    """
    checkpoints = {}
    # Iterate over the requested methods. The previous code looped over the
    # still-empty ``checkpoints`` dict, so nothing was ever loaded or drawn.
    for i in methods:
        try:
            checkpoints[i] = torch.load(f'{path}/checkpoint/{net}/{dataset}/{i}',
                                        map_location=torch.device('cpu'))
        except Exception as e:
            print(e)
    fig, ax = plt.subplots(1, figsize=(15, 15))

    for name, ckpt in checkpoints.items():
        ax.plot(ckpt['lrs'], label=name)
    ax.set_title('lrs')
    ax.legend()
    ax.set_xlabel('Epochs')

    fig.suptitle(net)
    plt.show()

def plotly_plot(methods, dataset, net='vggnet',name='plot',benchmark_acc = 93.74):
    """Plot train loss / test error of saved runs with plotly and save as HTML.

    Args:
        methods: Checkpoint file names under
            ``{path}/checkpoint/{net}/{dataset}/``. Names without a file are
            skipped silently, so large hyper-parameter grids can be passed.
        dataset: Dataset sub-directory; also part of the figure title.
        net: Network sub-directory; also part of the figure title.
        name: Base name of the HTML file written to the ``plots`` folder.
        benchmark_acc: Minimum best test accuracy (percent) a run needs to be
            plotted. 'sgdm' and 'adam' are always plotted as baselines.

    For every checkpoint that loads, the best accuracy
    (``100 - min(test_errs) * 100``) is printed, whether or not it is plotted.
    """
    import os  # local import: only needed to make sure the output folder exists

    checkpoints={}
    for i in methods:
        try:
            temp = torch.load(f'{path}/checkpoint/{net}/{dataset}/{i}',map_location=torch.device('cpu'))
            best_acc = (100-np.array(temp['test_errs']).min()*100)
            print(f"{i} best acc: {best_acc.round(3)}")
            if best_acc >= benchmark_acc or i in ['sgdm','adam']:
                checkpoints[i]=temp
        except FileNotFoundError:
            # Expected: most names of a grid were never trained.
            continue
        except Exception as e:
            # Still best-effort, but do not hide corrupt or malformed files.
            print(f"{i}: skipped ({type(e).__name__}: {e})")
            continue
    fig = make_subplots(rows=2, cols=1,subplot_titles=("train_losses", "test_errs"),
                        column_widths=[0.6], row_heights=[20, 20], shared_xaxes=True,vertical_spacing=0.1)

    palette = px.colors.qualitative.Alphabet

    for i, m in enumerate(checkpoints.keys()):
        # Wrap around the palette instead of raising IndexError when more
        # runs than colours pass the filter.
        colour = palette[i % len(palette)]
        train_losses = checkpoints[m]['train_losses']
        test_errs = checkpoints[m]['test_errs']
        # Epoch axis follows the recorded length rather than assuming 199 epochs.
        fig.add_trace(go.Scatter(y=train_losses, x=list(range(1, len(train_losses) + 1)),
                      mode='lines+markers',legendgroup=f'group{m}',line=dict(color=colour),
                      name=m),row=1, col=1)
        # Same legend group as the loss trace so both toggle together.
        fig.add_trace(go.Scatter(y=test_errs, x=list(range(1, len(test_errs) + 1)),
                      mode='lines+markers',showlegend = False, legendgroup=f'group{m}',
                      line=dict(color=colour),
                      name=m),row=2, col=1)

    # Marker size applies to every trace, so set it once after all are added.
    fig.update_traces(marker={'size': 3})

    fig.update_layout(
        title_text=f"{dataset} {net}",
        autosize=False,
        width=1200,
        height=800,
        margin=dict(
            l=10,
            r=10,
            b=10,
            t=50,
            pad=1
        ),
        plot_bgcolor="white",
    )
    fig.update_layout(legend=dict(
        yanchor="top",
        y=0.99,
        xanchor="right",
        x=0.99,
        font=dict(
            size=20
        )
    ))
    fig.update_xaxes(showline=True, linewidth=2, linecolor='black')
    fig.update_yaxes(showline=True, linewidth=2, linecolor='black')

    out_dir = f'{path}/checkpoint/{net}/{dataset}/plots'
    # write_html does not create missing folders.
    os.makedirs(out_dir, exist_ok=True)
    fig.write_html(f'{out_dir}/{name}.html')
    fig.show()

def check_best_acc(methods,dataset,net):
    """Print the best test accuracy (percent) of every loadable checkpoint.

    Each name in ``methods`` is loaded from
    ``{path}/checkpoint/{net}/{dataset}/{name}`` onto the CPU. The accuracy is
    ``100 - min(test_errs) * 100`` rounded to three decimals. Any run that
    raises while loading or evaluating is skipped without a message.
    """
    for method in methods:
        try:
            ckpt = torch.load(f'{path}/checkpoint/{net}/{dataset}/{method}',
                              map_location=torch.device('cpu'))
            lowest_err = np.array(ckpt['test_errs']).min()
            best = (100 - lowest_err * 100).round(3)
            print(f"{method} best acc: {best}")
        except Exception:
            # Best-effort listing: ignore runs that cannot be read.
            pass

In [None]:
# Grid of values used by the cells below to enumerate checkpoint file names.
# The literals are formatted into those names, so their spelling matters.
LRS = [
    0.001, 0.002, 0.003, 0.004, 0.005, 0.006, 0.007, 0.008, 0.009,
    0.01, 0.015, 0.02, 0.025, 0.03, 0.035, 0.04, 0.045, 0.05, 0.055,
    0.06, 0.065, 0.07, 0.075, 0.08, 0.085, 0.09, 0.095,
    0.1, 0.11, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9,
    1,
]

# SGDM and Adaptive Gradient methods

In [None]:

# Compare the baseline optimizers on CIFAR10.
dataset = 'CIFAR10'

# Checkpoint file names of the runs to compare.
methods=['sgdm','adam','amsgrad','adamw','adabound','padam']

# benchmark_acc=0 effectively disables the accuracy filter in plotly_plot.
plotly_plot(methods,dataset, net='vggnet',name='compare',benchmark_acc=0)


sgdm best acc: 93.74
adam best acc: 92.47
amsgrad best acc: 93.05
adamw best acc: 93.15
adabound best acc: 92.77
padam best acc: 94.06


# Adam Switch SGDM

In [None]:
#@title Switch = 15

# Values baked into the checkpoint file names of this experiment.
switch = 15
lr = 0.001
dataset = 'CIFAR10'

# `name` is the base name of the HTML file written by plotly_plot.
name = f'adam_switch_sgdm_switch={switch}_lr={lr}'
methods=['sgdm','adam']
# One candidate checkpoint name per value in LRS; names without a saved
# checkpoint are skipped inside plotly_plot.
for lr_sgdm in LRS:
  methods.append(f'adam_switch_sgdm_switch={switch}_lr={lr}_{lr_sgdm}')

# Only runs reaching benchmark_acc are plotted (sgdm/adam are always kept).
plotly_plot(methods,dataset, net='vggnet',name=name,benchmark_acc=93.74)


sgdm best acc: 93.74
adam best acc: 92.47
adam_switch_sgdm_switch=15_lr=0.001_0.001 best acc: 91.35
adam_switch_sgdm_switch=15_lr=0.001_0.003 best acc: 92.14
adam_switch_sgdm_switch=15_lr=0.001_0.005 best acc: 92.79
adam_switch_sgdm_switch=15_lr=0.001_0.007 best acc: 93.1
adam_switch_sgdm_switch=15_lr=0.001_0.009 best acc: 92.95
adam_switch_sgdm_switch=15_lr=0.001_0.01 best acc: 93.05
adam_switch_sgdm_switch=15_lr=0.001_0.03 best acc: 93.5
adam_switch_sgdm_switch=15_lr=0.001_0.05 best acc: 93.79
adam_switch_sgdm_switch=15_lr=0.001_0.07 best acc: 93.74
adam_switch_sgdm_switch=15_lr=0.001_0.1 best acc: 93.74
adam_switch_sgdm_switch=15_lr=0.001_0.11 best acc: 93.64


In [None]:
#@title Switch = 25
# Values baked into the checkpoint file names of this experiment.
switch = 25
lr = 0.001
dataset = 'CIFAR10'

# `name` is the base name of the HTML file written by plotly_plot.
name = f'adam_switch_sgdm_switch={switch}_lr={lr}'
methods=['sgdm','adam']
# One candidate checkpoint name per value in LRS; names without a saved
# checkpoint are skipped inside plotly_plot.
for lr_sgdm in LRS:
  methods.append(f'adam_switch_sgdm_switch={switch}_lr={lr}_{lr_sgdm}')

# NOTE(review): threshold is 93.7 here while the sibling cells use 93.74 — confirm intended.
plotly_plot(methods,dataset, net='vggnet',name=name,benchmark_acc=93.7)

sgdm best acc: 93.74
adam best acc: 92.47
adam_switch_sgdm_switch=25_lr=0.001_0.001 best acc: 92.04
adam_switch_sgdm_switch=25_lr=0.001_0.003 best acc: 92.91
adam_switch_sgdm_switch=25_lr=0.001_0.005 best acc: 92.75
adam_switch_sgdm_switch=25_lr=0.001_0.007 best acc: 92.82
adam_switch_sgdm_switch=25_lr=0.001_0.009 best acc: 93.08
adam_switch_sgdm_switch=25_lr=0.001_0.01 best acc: 92.96
adam_switch_sgdm_switch=25_lr=0.001_0.03 best acc: 93.37
adam_switch_sgdm_switch=25_lr=0.001_0.04 best acc: 93.2
adam_switch_sgdm_switch=25_lr=0.001_0.05 best acc: 93.7
adam_switch_sgdm_switch=25_lr=0.001_0.07 best acc: 93.38
adam_switch_sgdm_switch=25_lr=0.001_0.09 best acc: 93.71
adam_switch_sgdm_switch=25_lr=0.001_0.1 best acc: 93.34
adam_switch_sgdm_switch=25_lr=0.001_0.11 best acc: 93.19
adam_switch_sgdm_switch=25_lr=0.001_0.2 best acc: 92.66


In [None]:
#@title Switch = 50
# Values baked into the checkpoint file names of this experiment.
switch = 50
lr = 0.001
dataset = 'CIFAR10'

# `name` is the base name of the HTML file written by plotly_plot.
name = f'adam_switch_sgdm_switch={switch}_lr={lr}'
methods=['sgdm','adam']
# One candidate checkpoint name per value in LRS; names without a saved
# checkpoint are skipped inside plotly_plot.
for lr_sgdm in LRS:
  methods.append(f'adam_switch_sgdm_switch={switch}_lr={lr}_{lr_sgdm}')

# Only runs reaching benchmark_acc are plotted (sgdm/adam are always kept).
plotly_plot(methods,dataset, net='vggnet',name=name,benchmark_acc=93.74)

sgdm best acc: 93.74
adam best acc: 92.47
adam_switch_sgdm_switch=50_lr=0.001_0.001 best acc: 92.29
adam_switch_sgdm_switch=50_lr=0.001_0.003 best acc: 92.4
adam_switch_sgdm_switch=50_lr=0.001_0.005 best acc: 92.69
adam_switch_sgdm_switch=50_lr=0.001_0.007 best acc: 92.51
adam_switch_sgdm_switch=50_lr=0.001_0.009 best acc: 92.93
adam_switch_sgdm_switch=50_lr=0.001_0.01 best acc: 92.9
adam_switch_sgdm_switch=50_lr=0.001_0.02 best acc: 93.06
adam_switch_sgdm_switch=50_lr=0.001_0.03 best acc: 93.29
adam_switch_sgdm_switch=50_lr=0.001_0.04 best acc: 92.88
adam_switch_sgdm_switch=50_lr=0.001_0.05 best acc: 92.97
adam_switch_sgdm_switch=50_lr=0.001_0.07 best acc: 93.05
adam_switch_sgdm_switch=50_lr=0.001_0.09 best acc: 92.66
adam_switch_sgdm_switch=50_lr=0.001_0.1 best acc: 92.98
adam_switch_sgdm_switch=50_lr=0.001_0.11 best acc: 92.89


In [None]:
#@title Switch = 100
# Values baked into the checkpoint file names of this experiment.
switch = 100
lr = 0.001
dataset = 'CIFAR10'

# `name` is the base name of the HTML file written by plotly_plot.
name = f'adam_switch_sgdm_switch={switch}_lr={lr}'
methods=['sgdm','adam']
# One candidate checkpoint name per value in LRS; names without a saved
# checkpoint are skipped inside plotly_plot.
for lr_sgdm in LRS:
  methods.append(f'adam_switch_sgdm_switch={switch}_lr={lr}_{lr_sgdm}')

# Only runs reaching benchmark_acc are plotted (sgdm/adam are always kept).
plotly_plot(methods,dataset, net='vggnet',name=name,benchmark_acc=93.74)

sgdm best acc: 93.74
adam best acc: 92.47
adam_switch_sgdm_switch=100_lr=0.001_0.001 best acc: 92.31
adam_switch_sgdm_switch=100_lr=0.001_0.003 best acc: 92.68
adam_switch_sgdm_switch=100_lr=0.001_0.005 best acc: 93.06
adam_switch_sgdm_switch=100_lr=0.001_0.007 best acc: 92.97
adam_switch_sgdm_switch=100_lr=0.001_0.009 best acc: 92.91
adam_switch_sgdm_switch=100_lr=0.001_0.01 best acc: 92.81
adam_switch_sgdm_switch=100_lr=0.001_0.03 best acc: 92.84
adam_switch_sgdm_switch=100_lr=0.001_0.05 best acc: 92.69
adam_switch_sgdm_switch=100_lr=0.001_0.07 best acc: 92.37
adam_switch_sgdm_switch=100_lr=0.001_0.09 best acc: 91.87
adam_switch_sgdm_switch=100_lr=0.001_0.1 best acc: 91.75
adam_switch_sgdm_switch=100_lr=0.001_0.11 best acc: 91.47


# LCSA-CW

In [None]:
# Switch = 200 i.e. no switch
switch = 200
# `name` is the base name of the HTML file written by plotly_plot.
name=f'LinearCombineSGDMAdam_switch={switch}'
methods=['sgdm','adam']

# Candidate checkpoint names for every (a, lr) pair; the switch and lr_sgdm
# parts of the name are fixed literals ('200' and 'NA').
for a in [0.1,0.3,0.5,0.7,0.9]:
  for lr in LRS:
    methods.append(f'LinearCombineSGDMAdam_a={a}_lr={lr}_switch=200_lr_sgdm=NA')

# Only runs reaching benchmark_acc are plotted (sgdm/adam are always kept).
plotly_plot(methods,'CIFAR10', net='vggnet',name=name,benchmark_acc=93.2)

sgdm best acc: 93.74
adam best acc: 92.47
LinearCombineSGDMAdam_a=0.1_lr=0.001_switch=200_lr_sgdm=NA best acc: 93.18
LinearCombineSGDMAdam_a=0.1_lr=0.003_switch=200_lr_sgdm=NA best acc: 90.92
LinearCombineSGDMAdam_a=0.1_lr=0.005_switch=200_lr_sgdm=NA best acc: 78.95
LinearCombineSGDMAdam_a=0.1_lr=0.007_switch=200_lr_sgdm=NA best acc: 74.64
LinearCombineSGDMAdam_a=0.1_lr=0.009_switch=200_lr_sgdm=NA best acc: 73.96
LinearCombineSGDMAdam_a=0.3_lr=0.001_switch=200_lr_sgdm=NA best acc: 93.42
LinearCombineSGDMAdam_a=0.3_lr=0.003_switch=200_lr_sgdm=NA best acc: 91.93
LinearCombineSGDMAdam_a=0.3_lr=0.005_switch=200_lr_sgdm=NA best acc: 90.73
LinearCombineSGDMAdam_a=0.3_lr=0.007_switch=200_lr_sgdm=NA best acc: 78.74
LinearCombineSGDMAdam_a=0.3_lr=0.009_switch=200_lr_sgdm=NA best acc: 75.25
LinearCombineSGDMAdam_a=0.3_lr=0.01_switch=200_lr_sgdm=NA best acc: 87.27
LinearCombineSGDMAdam_a=0.3_lr=0.05_switch=200_lr_sgdm=NA best acc: 30.93
LinearCombineSGDMAdam_a=0.5_lr=0.001_switch=200_lr_sgdm=NA b

# LCSA-DDW

In [None]:
#@title Switch = 0
# Value baked into the checkpoint file names of this experiment.
switch = 0
# `name` is the base name of the HTML file written by plotly_plot.
name=f'LinearCombineSGDMAdam_a=0.7_switch={switch}'
methods=['sgdm','adam']

# Candidate checkpoint names for every (lr, lr_sgdm) pair with a fixed at 0.7;
# names without a saved checkpoint are skipped inside plotly_plot.
for lr in LRS:
  for lr_sgdm in LRS:
    methods.append(f'LinearCombineSGDMAdam_a=0.7_lr={lr}_switch={switch}_lr_sgdm={lr_sgdm}')

# benchmark_acc=0 effectively disables the accuracy filter in plotly_plot.
plotly_plot(methods,'CIFAR10', net='vggnet',name=name,benchmark_acc=0)

sgdm best acc: 93.74
adam best acc: 92.47
LinearCombineSGDMAdam_a=0.7_lr=0.001_switch=0_lr_sgdm=0.05 best acc: 93.75


In [None]:
#@title Switch = 5
# Value baked into the checkpoint file names of this experiment.
switch = 5
# `name` is the base name of the HTML file written by plotly_plot.
name=f'LinearCombineSGDMAdam_a=0.7_switch={switch}'
methods=['sgdm','adam']

# Candidate checkpoint names for every (lr, lr_sgdm) pair with a fixed at 0.7;
# names without a saved checkpoint are skipped inside plotly_plot.
for lr in LRS:
  for lr_sgdm in LRS:
    methods.append(f'LinearCombineSGDMAdam_a=0.7_lr={lr}_switch={switch}_lr_sgdm={lr_sgdm}')

# Only runs reaching benchmark_acc are plotted (sgdm/adam are always kept).
plotly_plot(methods,'CIFAR10', net='vggnet',name=name,benchmark_acc=93.74)

sgdm best acc: 93.74
adam best acc: 92.47
LinearCombineSGDMAdam_a=0.7_lr=0.001_switch=5_lr_sgdm=0.05 best acc: 93.75


In [None]:
#@title Switch = 10

# Value baked into the checkpoint file names of this experiment.
switch = 10
# `name` is the base name of the HTML file written by plotly_plot.
name=f'LinearCombineSGDMAdam_a=0.7_switch={switch}'
methods=['sgdm','adam']

# Candidate checkpoint names for every (lr, lr_sgdm) pair with a fixed at 0.7;
# names without a saved checkpoint are skipped inside plotly_plot.
for lr in LRS:
  for lr_sgdm in LRS:
    methods.append(f'LinearCombineSGDMAdam_a=0.7_lr={lr}_switch={switch}_lr_sgdm={lr_sgdm}')

# Only runs reaching benchmark_acc are plotted (sgdm/adam are always kept).
plotly_plot(methods,'CIFAR10', net='vggnet',name=name,benchmark_acc=93.74)

sgdm best acc: 93.74
adam best acc: 92.47
LinearCombineSGDMAdam_a=0.7_lr=0.001_switch=10_lr_sgdm=0.05 best acc: 94.06
LinearCombineSGDMAdam_a=0.7_lr=0.003_switch=10_lr_sgdm=0.05 best acc: 93.74


In [None]:
#@title Switch = 25

# Value baked into the checkpoint file names of this experiment.
switch = 25
# NOTE(review): the output file name says a=0.7 although the grid below
# sweeps several values of `a` — confirm this is intended.
name=f'LinearCombineSGDMAdam_a=0.7_switch={switch}'
methods=['sgdm','adam']

# Candidate checkpoint names for every (a, lr, lr_sgdm) combination;
# names without a saved checkpoint are skipped inside plotly_plot.
for a in [0.1,0.3,0.5,0.7,0.9]:
  for lr in LRS:
    for lr_sgdm in LRS:
      methods.append(f'LinearCombineSGDMAdam_a={a}_lr={lr}_switch={switch}_lr_sgdm={lr_sgdm}')

# Only runs reaching benchmark_acc are plotted (sgdm/adam are always kept).
plotly_plot(methods,'CIFAR10', net='vggnet',name=name,benchmark_acc=94)

sgdm best acc: 93.74
adam best acc: 92.47
LinearCombineSGDMAdam_a=0.7_lr=0.001_switch=25_lr_sgdm=0.01 best acc: 93.39
LinearCombineSGDMAdam_a=0.7_lr=0.001_switch=25_lr_sgdm=0.03 best acc: 93.99
LinearCombineSGDMAdam_a=0.7_lr=0.001_switch=25_lr_sgdm=0.05 best acc: 94.18
LinearCombineSGDMAdam_a=0.7_lr=0.001_switch=25_lr_sgdm=0.07 best acc: 93.94
LinearCombineSGDMAdam_a=0.7_lr=0.001_switch=25_lr_sgdm=0.09 best acc: 93.86
LinearCombineSGDMAdam_a=0.7_lr=0.001_switch=25_lr_sgdm=0.1 best acc: 93.7
LinearCombineSGDMAdam_a=0.7_lr=0.003_switch=25_lr_sgdm=0.01 best acc: 93.45
LinearCombineSGDMAdam_a=0.7_lr=0.003_switch=25_lr_sgdm=0.03 best acc: 94.02
LinearCombineSGDMAdam_a=0.7_lr=0.003_switch=25_lr_sgdm=0.05 best acc: 94.05
LinearCombineSGDMAdam_a=0.7_lr=0.003_switch=25_lr_sgdm=0.07 best acc: 93.99
LinearCombineSGDMAdam_a=0.7_lr=0.003_switch=25_lr_sgdm=0.09 best acc: 93.68
LinearCombineSGDMAdam_a=0.7_lr=0.003_switch=25_lr_sgdm=0.1 best acc: 93.68
LinearCombineSGDMAdam_a=0.7_lr=0.005_switch=25_lr

In [None]:
#@title Switch = 50
# Value baked into the checkpoint file names of this experiment.
switch = 50
# NOTE(review): the output file name says a=0.7 although the grid below
# sweeps several values of `a` — confirm this is intended.
name=f'LinearCombineSGDMAdam_a=0.7_switch={switch}'
methods=['sgdm','adam']

# Candidate checkpoint names for every (a, lr, lr_sgdm) combination;
# names without a saved checkpoint are skipped inside plotly_plot.
for a in [0.1,0.3,0.5,0.7,0.9]:
  for lr in LRS:
    for lr_sgdm in LRS:
      methods.append(f'LinearCombineSGDMAdam_a={a}_lr={lr}_switch={switch}_lr_sgdm={lr_sgdm}')

# Only runs reaching benchmark_acc are plotted (sgdm/adam are always kept).
plotly_plot(methods,'CIFAR10', net='vggnet',name=name,benchmark_acc=93.74)

sgdm best acc: 93.74
adam best acc: 92.47
LinearCombineSGDMAdam_a=0.5_lr=0.001_switch=50_lr_sgdm=0.001 best acc: 92.88
LinearCombineSGDMAdam_a=0.7_lr=0.001_switch=50_lr_sgdm=0.001 best acc: 93.12
LinearCombineSGDMAdam_a=0.7_lr=0.001_switch=50_lr_sgdm=0.003 best acc: 92.74
LinearCombineSGDMAdam_a=0.7_lr=0.001_switch=50_lr_sgdm=0.005 best acc: 93.28
LinearCombineSGDMAdam_a=0.7_lr=0.001_switch=50_lr_sgdm=0.007 best acc: 92.89
LinearCombineSGDMAdam_a=0.7_lr=0.001_switch=50_lr_sgdm=0.009 best acc: 90.99
LinearCombineSGDMAdam_a=0.7_lr=0.001_switch=50_lr_sgdm=0.01 best acc: 93.3
LinearCombineSGDMAdam_a=0.7_lr=0.001_switch=50_lr_sgdm=0.015 best acc: 93.49
LinearCombineSGDMAdam_a=0.7_lr=0.001_switch=50_lr_sgdm=0.02 best acc: 93.32
LinearCombineSGDMAdam_a=0.7_lr=0.001_switch=50_lr_sgdm=0.025 best acc: 93.43
LinearCombineSGDMAdam_a=0.7_lr=0.001_switch=50_lr_sgdm=0.03 best acc: 93.67
LinearCombineSGDMAdam_a=0.7_lr=0.001_switch=50_lr_sgdm=0.04 best acc: 93.7
LinearCombineSGDMAdam_a=0.7_lr=0.001_swi

In [None]:
#@title Switch = 100
# Value baked into the checkpoint file names of this experiment.
switch = 100
# `name` is the base name of the HTML file written by plotly_plot.
name=f'LinearCombineSGDMAdam_a=0.7_switch={switch}'
methods=['sgdm','adam']

# Candidate checkpoint names for every (lr, lr_sgdm) pair with a fixed at 0.7;
# names without a saved checkpoint are skipped inside plotly_plot.
for lr in LRS:
  for lr_sgdm in LRS:
    methods.append(f'LinearCombineSGDMAdam_a=0.7_lr={lr}_switch={switch}_lr_sgdm={lr_sgdm}')

# Only runs reaching benchmark_acc are plotted (sgdm/adam are always kept).
plotly_plot(methods,'CIFAR10', net='vggnet',name=name,benchmark_acc=93.74)

sgdm best acc: 93.74
adam best acc: 92.47
LinearCombineSGDMAdam_a=0.7_lr=0.001_switch=100_lr_sgdm=0.001 best acc: 93.06
LinearCombineSGDMAdam_a=0.7_lr=0.001_switch=100_lr_sgdm=0.003 best acc: 93.26
LinearCombineSGDMAdam_a=0.7_lr=0.001_switch=100_lr_sgdm=0.005 best acc: 93.21
LinearCombineSGDMAdam_a=0.7_lr=0.001_switch=100_lr_sgdm=0.007 best acc: 93.39
LinearCombineSGDMAdam_a=0.7_lr=0.001_switch=100_lr_sgdm=0.009 best acc: 93.38
LinearCombineSGDMAdam_a=0.7_lr=0.001_switch=100_lr_sgdm=0.01 best acc: 93.06
LinearCombineSGDMAdam_a=0.7_lr=0.001_switch=100_lr_sgdm=0.03 best acc: 93.67
LinearCombineSGDMAdam_a=0.7_lr=0.001_switch=100_lr_sgdm=0.05 best acc: 93.45
LinearCombineSGDMAdam_a=0.7_lr=0.001_switch=100_lr_sgdm=0.07 best acc: 93.0
LinearCombineSGDMAdam_a=0.7_lr=0.001_switch=100_lr_sgdm=0.09 best acc: 92.58
LinearCombineSGDMAdam_a=0.7_lr=0.001_switch=100_lr_sgdm=0.1 best acc: 92.68
LinearCombineSGDMAdam_a=0.7_lr=0.001_switch=100_lr_sgdm=0.11 best acc: 92.43
LinearCombineSGDMAdam_a=0.7_lr=

In [None]:
#@title CIFAR100

# LinearCombineSGDMAdam_switch(switch=50,lr=0.003,lr_sgdm=0.03,dataset='CIFAR100')
# LinearCombineSGDMAdam_switch(switch=100,lr=0.003,lr_sgdm=0.003,dataset='CIFAR100')

# `name` is the base name of the HTML file written by plotly_plot.
name=f'LinearCombineSGDMAdam_a=0.7'
methods=['sgdm','adam']
# Hand-picked checkpoint names to compare against the two baselines.
methods.append(f'LinearCombineSGDMAdam_a=0.7_lr=0.003_switch=50_lr_sgdm=0.03')
methods.append(f'LinearCombineSGDMAdam_a=0.7_lr=0.001_switch=50_lr_sgdm=0.05')
methods.append(f'LinearCombineSGDMAdam_a=0.7_lr=0.003_switch=100_lr_sgdm=0.003')
methods.append(f'LinearCombineSGDMAdam_a=0.7_lr=0.003_switch=25_lr_sgdm=0.03')
methods.append(f'LinearCombineSGDMAdam_a=0.7_lr=0.001_switch=25_lr_sgdm=0.05')



# Checkpoints are read from the CIFAR100 folder; runs below 60% best
# accuracy are left out of the plot (sgdm/adam are always kept).
plotly_plot(methods,'CIFAR100', net='vggnet',name=name,benchmark_acc = 60)

sgdm best acc: 73.77
adam best acc: 68.03
LinearCombineSGDMAdam_a=0.7_lr=0.003_switch=50_lr_sgdm=0.03 best acc: 68.43
LinearCombineSGDMAdam_a=0.7_lr=0.001_switch=50_lr_sgdm=0.05 best acc: 74.2
LinearCombineSGDMAdam_a=0.7_lr=0.003_switch=100_lr_sgdm=0.003 best acc: 64.74
LinearCombineSGDMAdam_a=0.7_lr=0.003_switch=25_lr_sgdm=0.03 best acc: 71.27
LinearCombineSGDMAdam_a=0.7_lr=0.001_switch=25_lr_sgdm=0.05 best acc: 74.35


# LCSA-CDW

In [None]:
#@title target 10,15,25
# target_epoch = 25

# `name` is the base name of the HTML file written by plotly_plot.
name=f'LinearCombineSGDMAdam_dynamic_target=10_15_25'
methods=['sgdm','adam']

# Candidate checkpoint names for every (target_epoch, a, lr, lr_sgdm)
# combination; names without a saved checkpoint are skipped in plotly_plot.
# NOTE(review): `a` is swept over LRS (the learning-rate grid) rather than a
# dedicated list of weights, which makes this a very large grid — confirm intended.
for target_epoch in [10,15,25]:
  for a in LRS:
    for lr in LRS:
      for lr_sgdm in LRS:
        methods.append(f'LinearCombineSGDMAdam_a={a}_lr={lr}_lr_sgdm={lr_sgdm}_target_epoch={target_epoch}')

# Only runs reaching benchmark_acc are plotted (sgdm/adam are always kept).
plotly_plot(methods,'CIFAR10', net='vggnet',name=name,benchmark_acc = 93)

sgdm best acc: 93.74
adam best acc: 92.47
LinearCombineSGDMAdam_a=0.7_lr=0.001_lr_sgdm=0.001_target_epoch=15 best acc: 90.31
LinearCombineSGDMAdam_a=0.7_lr=0.003_lr_sgdm=0.003_target_epoch=15 best acc: 91.26
LinearCombineSGDMAdam_a=0.9_lr=0.001_lr_sgdm=0.001_target_epoch=15 best acc: 88.1
LinearCombineSGDMAdam_a=0.9_lr=0.003_lr_sgdm=0.003_target_epoch=15 best acc: 90.4
LinearCombineSGDMAdam_a=0.7_lr=0.001_lr_sgdm=0.001_target_epoch=25 best acc: 91.37
LinearCombineSGDMAdam_a=0.7_lr=0.003_lr_sgdm=0.003_target_epoch=25 best acc: 91.79
LinearCombineSGDMAdam_a=0.9_lr=0.001_lr_sgdm=0.001_target_epoch=25 best acc: 88.61
LinearCombineSGDMAdam_a=0.9_lr=0.003_lr_sgdm=0.003_target_epoch=25 best acc: 91.34


In [None]:
#@title target = 50
# Value baked into the checkpoint file names of this experiment.
target_epoch = 50

# `name` is the base name of the HTML file written by plotly_plot.
name=f'LinearCombineSGDMAdam_dynamic_target={target_epoch}'
methods=['sgdm','adam']

# Candidate checkpoint names for every (a, lr, lr_sgdm) combination;
# names without a saved checkpoint are skipped inside plotly_plot.
# NOTE(review): `a` is swept over LRS (the learning-rate grid) — confirm intended.
for a in LRS:
  for lr in LRS:
    for lr_sgdm in LRS:
      methods.append(f'LinearCombineSGDMAdam_a={a}_lr={lr}_lr_sgdm={lr_sgdm}_target_epoch={target_epoch}')

# Only runs reaching benchmark_acc are plotted (sgdm/adam are always kept).
plotly_plot(methods,'CIFAR10', net='vggnet',name=name,benchmark_acc = 93)


sgdm best acc: 93.74
adam best acc: 92.47
LinearCombineSGDMAdam_a=0.1_lr=0.001_lr_sgdm=0.001_target_epoch=50 best acc: 92.4
LinearCombineSGDMAdam_a=0.1_lr=0.005_lr_sgdm=0.005_target_epoch=50 best acc: 89.78
LinearCombineSGDMAdam_a=0.1_lr=0.01_lr_sgdm=0.01_target_epoch=50 best acc: 71.08
LinearCombineSGDMAdam_a=0.1_lr=0.05_lr_sgdm=0.05_target_epoch=50 best acc: 33.21
LinearCombineSGDMAdam_a=0.1_lr=0.09_lr_sgdm=0.09_target_epoch=50 best acc: 25.76
LinearCombineSGDMAdam_a=0.3_lr=0.001_lr_sgdm=0.001_target_epoch=50 best acc: 92.53
LinearCombineSGDMAdam_a=0.3_lr=0.005_lr_sgdm=0.005_target_epoch=50 best acc: 90.68
LinearCombineSGDMAdam_a=0.3_lr=0.01_lr_sgdm=0.01_target_epoch=50 best acc: 77.07
LinearCombineSGDMAdam_a=0.3_lr=0.05_lr_sgdm=0.05_target_epoch=50 best acc: 38.35
LinearCombineSGDMAdam_a=0.3_lr=0.09_lr_sgdm=0.09_target_epoch=50 best acc: 26.41
LinearCombineSGDMAdam_a=0.5_lr=0.001_lr_sgdm=0.001_target_epoch=50 best acc: 92.56
LinearCombineSGDMAdam_a=0.5_lr=0.005_lr_sgdm=0.005_target_

In [None]:
#@title target = 100
# Value baked into the checkpoint file names of this experiment.
target_epoch = 100
# `name` is the base name of the HTML file written by plotly_plot.
name=f'LinearCombineSGDMAdam_dynamic_a=0.7_target={target_epoch}'
methods=['sgdm','adam']

# Candidate checkpoint names for every (lr, lr_sgdm) pair with a fixed at 0.7;
# names without a saved checkpoint are skipped inside plotly_plot.
for lr in LRS:
  for lr_sgdm in LRS:
    methods.append(f'LinearCombineSGDMAdam_a=0.7_lr={lr}_lr_sgdm={lr_sgdm}_target_epoch={target_epoch}')

# Only runs reaching benchmark_acc are plotted (sgdm/adam are always kept).
plotly_plot(methods,'CIFAR10', net='vggnet',name=name,benchmark_acc = 93.74)


sgdm best acc: 93.74
adam best acc: 92.47


In [None]:

# Ad-hoc comparison of hand-picked checkpoints against the sgdm baseline.
methods=['sgdm','LinearCombineSGDMAdam_a=0.7_lr=0.001_lr_sgdm=0.05_target_epoch=50',
         'LinearCombineSGDMAdam_a=0.7_lr=0.001_lr_sgdm=0.001_target_epoch=50',
         'LinearCombineSGDMAdam_a=0.7_lr=0.001_lr_sgdm=NA_target_epoch=200',
         'LinearCombineSGDMAdam_a=0.7_lr=0.001_switch=50_lr_sgdm=0.05',
         'LinearCombineSGDMAdam_a=0.7_lr=0.001_switch=200_lr_sgdm=NA']
# benchmark_acc=0 effectively disables the accuracy filter in plotly_plot.
# NOTE(review): name='test' is also used by another cell of this notebook,
# so both write the same test.html file.
plotly_plot(methods,'CIFAR10', net='vggnet',name='test',benchmark_acc = 0)

sgdm best acc: 93.74
LinearCombineSGDMAdam_a=0.7_lr=0.001_lr_sgdm=0.05_target_epoch=50 best acc: 92.69
LinearCombineSGDMAdam_a=0.7_lr=0.001_lr_sgdm=0.001_target_epoch=50 best acc: 92.11
LinearCombineSGDMAdam_a=0.7_lr=0.001_lr_sgdm=NA_target_epoch=200 best acc: 92.89
LinearCombineSGDMAdam_a=0.7_lr=0.001_switch=50_lr_sgdm=0.05 best acc: 94.05
LinearCombineSGDMAdam_a=0.7_lr=0.001_switch=200_lr_sgdm=NA best acc: 92.88


In [None]:

# Ad-hoc comparison of two hand-picked checkpoints against sgdm, padam and adam.
methods=['sgdm','padam','adam','LinearCombineSGDMAdam_a=0.7_initial_lr=0.001_target_lr=0.05_target_epoch=25',
         'LinearCombineSGDMAdam_a=0.7_lr=0.001_switch=25_lr_sgdm=0.05']
# benchmark_acc=0 effectively disables the accuracy filter in plotly_plot.
# NOTE(review): name='test' is also used by another cell of this notebook,
# so both write the same test.html file.
plotly_plot(methods,'CIFAR10', net='vggnet',name='test',benchmark_acc = 0)

sgdm best acc: 93.74
padam best acc: 94.06
adam best acc: 92.47
LinearCombineSGDMAdam_a=0.7_initial_lr=0.001_target_lr=0.05_target_epoch=25 best acc: 92.61
LinearCombineSGDMAdam_a=0.7_lr=0.001_switch=25_lr_sgdm=0.05 best acc: 94.18


In [None]:
#@title target = 200
# Value baked into the checkpoint file names of this experiment.
target_epoch = 200
# `name` is the base name of the HTML file written by plotly_plot.
name=f'LinearCombineSGDMAdam_dynamic_target={target_epoch}'
methods=['sgdm','adam']

# Candidate checkpoint names for every (a, lr) pair; the lr_sgdm part of the
# name is the fixed literal 'NA'.
# NOTE(review): `a` is swept over LRS (the learning-rate grid) — confirm intended.
for a in LRS:
  for lr in LRS:
      methods.append(f'LinearCombineSGDMAdam_a={a}_lr={lr}_lr_sgdm=NA_target_epoch={target_epoch}')

# Only runs reaching benchmark_acc are plotted (sgdm/adam are always kept).
plotly_plot(methods,'CIFAR10', net='vggnet',name=name,benchmark_acc = 93)


sgdm best acc: 93.74
adam best acc: 92.47
LinearCombineSGDMAdam_a=0.3_lr=0.001_lr_sgdm=NA_target_epoch=200 best acc: 93.1
LinearCombineSGDMAdam_a=0.3_lr=0.003_lr_sgdm=NA_target_epoch=200 best acc: 92.26
LinearCombineSGDMAdam_a=0.3_lr=0.005_lr_sgdm=NA_target_epoch=200 best acc: 90.71
LinearCombineSGDMAdam_a=0.3_lr=0.007_lr_sgdm=NA_target_epoch=200 best acc: 79.2
LinearCombineSGDMAdam_a=0.3_lr=0.009_lr_sgdm=NA_target_epoch=200 best acc: 76.91
LinearCombineSGDMAdam_a=0.3_lr=0.01_lr_sgdm=NA_target_epoch=200 best acc: 73.13
LinearCombineSGDMAdam_a=0.3_lr=0.03_lr_sgdm=NA_target_epoch=200 best acc: 48.57
LinearCombineSGDMAdam_a=0.3_lr=0.05_lr_sgdm=NA_target_epoch=200 best acc: 31.06
LinearCombineSGDMAdam_a=0.3_lr=0.07_lr_sgdm=NA_target_epoch=200 best acc: 23.27
LinearCombineSGDMAdam_a=0.3_lr=0.09_lr_sgdm=NA_target_epoch=200 best acc: 36.58
LinearCombineSGDMAdam_a=0.3_lr=0.1_lr_sgdm=NA_target_epoch=200 best acc: 24.9
LinearCombineSGDMAdam_a=0.5_lr=0.001_lr_sgdm=NA_target_epoch=200 best acc: 93