In [11]:
import math
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import time as time
import numpy as np
from gradient_descent_the_ultimate_optimizer import gdtuo
from gradient_descent_the_ultimate_optimizer.gdtuo import Optimizable
import os
import matplotlib.pyplot as plt
import imageio
from IPython.display import Video, Image
from poly_fit_relu import train_poly_fit_relu as pfr
from poly_fit_relu import plot_poly_fit_relu as ppfr

torch.cuda.empty_cache()
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

class MNIST_CNN(nn.Module):
    def __init__(self, poly_act):
        super(MNIST_CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(12544, 128)  # Adjusted input dimensions
        self.fc2 = nn.Linear(128, 10)
        self.bn1 = nn.BatchNorm2d(32)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm1d(128)

        self.poly_act = poly_act
        self.dict_stats = {}
        self.gather_stats = False

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        if self.gather_stats:
            self.dict_stats['conv1_mean'] = x.clone().detach().cpu().numpy().mean()
            self.dict_stats['conv1_std'] = x.clone().detach().cpu().numpy().std()
        x = self.poly_act(x)

        x = self.conv2(x)
        x = self.bn2(x)
        if self.gather_stats:
            self.dict_stats['conv2_mean'] = x.clone().detach().cpu().numpy().mean()
            self.dict_stats['conv2_std'] = x.clone().detach().cpu().numpy().std()
        x = self.poly_act(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)

        x = self.fc1(x)
        x = self.bn3(x)
        if self.gather_stats:
            self.dict_stats['fc1_mean'] = x.clone().detach().cpu().numpy().mean()
            self.dict_stats['fc1_std'] = x.clone().detach().cpu().numpy().std()
        x = self.poly_act(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

BATCH_SIZE = 256
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(DEVICE)

mnist_train = torchvision.datasets.MNIST('./data', train=True, download=True, transform=torchvision.transforms.ToTensor())
mnist_test = torchvision.datasets.MNIST('./data', train=False, download=True, transform=torchvision.transforms.ToTensor())
dl_train = torch.utils.data.DataLoader(mnist_train, batch_size=BATCH_SIZE, shuffle=True)
dl_test = torch.utils.data.DataLoader(mnist_test, batch_size=256, shuffle=False)

cuda


In [12]:
rank = 8

class PolyAct(Optimizable):
    def __init__(self, optimizer, ranks = rank, coefs_list = None):
        self.n = ranks
        self.coefs = torch.randn(np.min([self.n, 4]))
        if self.n > 4:
            self.coefs = torch.cat((self.coefs, torch.zeros(self.n - 4)))
        self.coefs = nn.Parameter(self.coefs)
        self.parameters = {'coefs': self.coefs}
        if coefs_list is not None:
            self.n = len(coefs_list)
            self.coefs = nn.Parameter(torch.tensor(coefs_list))
            self.parameters = {'coefs': self.coefs}
                
        self.optimizer = optimizer
        self.all_params_with_gradients = [self.parameters['coefs']]
        super().__init__(self.parameters, self.optimizer)

    def __call__(self, x):
        out = 0

        for i in range(self.n):
            out += self.parameters['coefs'][i] * x ** i
        return out
    
    def step(self):
        self.optimizer.step(self.parameters)

poly_act = PolyAct(gdtuo.Adam(0.0001), ranks = rank)
poly_act_init = PolyAct(gdtuo.Adam(0.001), ranks = rank)
poly_act.initialize()

model = MNIST_CNN(poly_act).to(DEVICE)
optim = torch.optim.SGD(model.parameters(), lr=0.045)

In [13]:
init_time = time.time()
EPOCHS = 10
coefs_list = [poly_act.parameters['coefs'].detach().cpu().numpy()]
for i in range(1, EPOCHS+1):
    running_acc = 0.0
    running_loss = 0.0
    for j, (features_, labels_) in enumerate(dl_train):

        if j % 100 == 0:
            print('coefs so far', poly_act.parameters['coefs'].detach())
            coefs_list.append(poly_act.parameters['coefs'].detach().cpu().numpy())
        poly_act.begin()
        optim.zero_grad()
        poly_act.zero_grad()
        features, labels = features_.to(DEVICE), labels_.to(DEVICE)
        pred = model.forward(features)
        loss = F.nll_loss(pred, labels)
        loss.backward(create_graph=True)

        optim.step()
        poly_act.step()
        running_loss += loss.item() * features_.size(0)
        running_acc += (torch.argmax(pred, dim=1) == labels).sum().item()
    train_loss = running_loss / len(dl_train.dataset)
    train_acc = running_acc / len(dl_train.dataset)
    print("EPOCH: {}, TRAIN LOSS: {}, ACC: {}".format(i, train_loss, train_acc))
    print(model.dict_stats)

print("Time taken: {}".format(time.time() - init_time))

coefs so far tensor([-0.2352,  0.5826,  0.9822, -0.2571,  0.0000,  0.0000,  0.0000,  0.0000])


coefs so far tensor([-2.4279e-01,  5.7846e-01,  9.8638e-01, -2.5771e-01, -3.6745e-04,
         1.1149e-04, -9.5481e-04,  2.5590e-05])
coefs so far tensor([-2.4939e-01,  5.7664e-01,  9.9008e-01, -2.5804e-01,  2.5385e-04,
         2.3853e-04, -9.4190e-04,  6.9518e-05])
EPOCH: 1, TRAIN LOSS: 0.310711980565389, ACC: 0.91845
{}
coefs so far tensor([-2.5187e-01,  5.7634e-01,  9.9110e-01, -2.5826e-01,  4.5258e-04,
         2.2792e-04, -9.1979e-04,  6.1382e-05])
coefs so far tensor([-2.5796e-01,  5.7594e-01,  9.9424e-01, -2.5871e-01,  1.0286e-03,
         2.3197e-04, -8.9454e-04,  6.4794e-05])
coefs so far tensor([-2.6465e-01,  5.7636e-01,  9.9692e-01, -2.5895e-01,  1.4801e-03,
         2.6491e-04, -8.7818e-04,  7.3570e-05])
EPOCH: 2, TRAIN LOSS: 0.11630143987337749, ACC: 0.9711333333333333
{}
coefs so far tensor([-2.6708e-01,  5.7710e-01,  9.9767e-01, -2.5889e-01,  1.5508e-03,
         2.9472e-04, -8.9600e-04,  8.0034e-05])
coefs so far tensor([-2.7357e-01,  5.7808e-01,  1.0004e+00, -2.5921e-

KeyboardInterrupt: 

In [None]:
x = np.linspace(-4, 4, 1000)

for i in range(len(coefs_list)):
    curr_poly_act = PolyAct(gdtuo.Adam(0.001), coefs_list = torch.tensor(coefs_list[i]))
    y = curr_poly_act(torch.tensor(x)).detach()

    fig, ax = plt.subplots()
    ax.plot(x, y)
    ax.plot(x, np.maximum(x, 0))
    # set small cross at 0.0
    ax.plot([0.0], [0.0], 'x', color='red')

    ax.set_xlim([-4, 4])
    ax.set_ylim([-10, 10])
    ax.set_yscale('linear')
    os.makedirs('plots', exist_ok=True)
    plt.savefig('plots/{}.png'.format(i))
    plt.close()

video = './polyact.mp4'
imageio.mimsave(video, [imageio.imread('plots/{}.png'.format(i)) for i in range(len(coefs_list))], fps = 3)
#play it here
Video(video)

  self.coefs = nn.Parameter(torch.tensor(coefs_list))
  self.coefs = nn.Parameter(torch.tensor(coefs_list))
  self.coefs = nn.Parameter(torch.tensor(coefs_list))
  self.coefs = nn.Parameter(torch.tensor(coefs_list))
  self.coefs = nn.Parameter(torch.tensor(coefs_list))
  self.coefs = nn.Parameter(torch.tensor(coefs_list))
  self.coefs = nn.Parameter(torch.tensor(coefs_list))
  self.coefs = nn.Parameter(torch.tensor(coefs_list))
  self.coefs = nn.Parameter(torch.tensor(coefs_list))
  self.coefs = nn.Parameter(torch.tensor(coefs_list))
  self.coefs = nn.Parameter(torch.tensor(coefs_list))
  self.coefs = nn.Parameter(torch.tensor(coefs_list))
  self.coefs = nn.Parameter(torch.tensor(coefs_list))
  self.coefs = nn.Parameter(torch.tensor(coefs_list))
  self.coefs = nn.Parameter(torch.tensor(coefs_list))
  self.coefs = nn.Parameter(torch.tensor(coefs_list))
  self.coefs = nn.Parameter(torch.tensor(coefs_list))
  self.coefs = nn.Parameter(torch.tensor(coefs_list))
  self.coefs = nn.Parameter(

In [None]:
mean_weight_CNN = 0
std_weight_CNN = 0

for name, param in model.named_parameters():
    if 'weight' in name:
        mean_weight_CNN += param.data.mean()
        
print(mean_weight_CNN) 


tensor(3.0436, device='cuda:0')
