<a href="https://colab.research.google.com/github/RobbieEarle/al_da/blob/master/replicating_maxout_sgd.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.utils.data

import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms

import numpy as np
from matplotlib import pyplot as plt

import json
from google.colab import drive
from datetime import datetime
from pytz import timezone  

import math
import time
import pprint

In [0]:
##@title PyTorch NN Implementation - MaxOut

class MaxNet(nn.Module):

  def __init__(self, 
               hidden_u1=32,
               hidden_u2=16,
               k=2,
               l0_dropout_prob=0.5,
               l1_dropout_prob=0.5,
               l2_dropout_prob=0.5
               ):
    super(MaxNet, self).__init__()

    self.k = k
    self.fc1 = nn.Linear(784, hidden_u1)
    self.fc2 = nn.Linear(int(math.ceil(hidden_u1 / k)), hidden_u2)
    self.fc3 = nn.Linear(int(math.ceil(hidden_u2 / k)), 10)
    self.l0_dropout = nn.Dropout(p=l0_dropout_prob)
    self.l1_dropout = nn.Dropout(p=l1_dropout_prob)
    self.l2_dropout = nn.Dropout(p=l2_dropout_prob)

  def forward(self, x):
    x = self.l0_dropout(x.view(-1, 28*28))
    x = self.l1_dropout(self.maxout(self.fc1(x)))
    x = self.l2_dropout(self.maxout(self.fc2(x)))
    x = self.fc3(x)
    
    return x

  def maxout(self, x):
    batch_size = x.size()[0]
    num_hidden_nodes = x.size()[1]
    num_groups = math.floor(num_hidden_nodes / self.k)
    remainder = num_hidden_nodes % self.k

    if remainder != 0:
      y = x[:, num_hidden_nodes-remainder:]
      x = x[:, :num_hidden_nodes-remainder]
      y = torch.max(y, dim=1).values
      y = y.view(y.size()[0],1)

    x = x.view(batch_size,num_groups,self.k)
    x = torch.max(x, dim=2).values

    if remainder != 0:
      x = torch.cat((x, y), dim=1)

    return x

In [0]:
# Initializes weights within given range
irange = 0.005
def weights_init(m):
  if type(m) == nn.Linear:
    m.weight.data.uniform_(-1 * irange, irange)
    m.bias.data.fill_(0)

# Applies max norm regularization
def max_norm(model, max_val=1, eps=1e-8):
  for name, param in model.named_parameters():
    if 'bias' not in name:
      norm = param.norm(2, dim=1, keepdim=True)
      desired = torch.clamp(norm, 0, max_val)
      param = param * (desired / (eps + norm))

def train_model(model, 
                train_loader,
                test_loader,
                goal_loss=0,
                curr_best_acc=0,
                save_path=None,
                curr_iteration=0,
                curr_round=0,
                init_lr=0.1,
                init_momentum=0.5):

  # ---- Initialization
  if curr_round == 1:
    model.apply(weights_init)
    max_norm(model, max_val=1.9365)
  optimizer = optim.SGD(model.parameters(), lr=init_lr, momentum=init_momentum)
  criterion = nn.CrossEntropyLoss()
  scheduler = StepLR(optimizer, step_size=1, gamma=1/1.000004)
  epoch = 1
  curr_best_misclasses = 0
  train_loss = 0
  logs = {
      "iteration":curr_iteration,
      "round":curr_round,
      "curr_best_acc":float(curr_best_acc),
      "data":[]
      }

  while ((curr_round == 1 and curr_best_misclasses > 100) or (curr_round == 2 and train_loss > goal_loss) or epoch == 1) and epoch <= 250:

    # ---- Training
    model.train()
    for batch_idx, (x, target) in enumerate(train_loader):
      if torch.cuda.is_available():
        x, target = x.cuda(), target.cuda()
      optimizer.zero_grad()
      out = model(x)
      train_loss = criterion(out, target)
      train_loss.backward()
      optimizer.step()
      scheduler.step()
      max_norm(model, max_val=1.9365)
    for i in range(len(optimizer.param_groups)):
      optimizer.param_groups[i]['momentum'] = min(0.7, optimizer.param_groups[i]['momentum'] + 0.0008)

    # ---- Testing
    num_correct = 0
    num_total = 0
    misclasses = 0
    model.eval()
    with torch.no_grad():
      for batch_idx2, (x, target) in enumerate(test_loader):
        if torch.cuda.is_available():
          x, target = x.cuda(), target.cuda()
        out = model(x)
        test_loss = criterion(out, target)
        _, prediction = torch.max(out.data, 1)
        num_correct += torch.sum(prediction == target.data)
        num_total += len(prediction)
      misclasses += num_total - num_correct
      accuracy = num_correct * 1.0 / num_total

      if accuracy > curr_best_acc:
        curr_best_acc = accuracy
        path = save_path + "best_acc_weights.pt"
        torch.save(model.state_dict(), path)

    if misclasses < curr_best_misclasses or epoch == 1:
      curr_best_misclasses = int(misclasses)

    # ---- Logging
    print (
        "  Epoch {} : lr = {:1.6f}  |  momentum = {:1.6f}  |  num_misclass = {}  |  loss = {:1.6f}  |  Accuracy = {:1.5f}"
        .format(epoch, optimizer.param_groups[0]['lr'], optimizer.param_groups[0]['momentum'], misclasses, train_loss, accuracy)
    )
    log = {
        "epoch":epoch,
        "lr":optimizer.param_groups[0]['lr'],
        "momentum":optimizer.param_groups[0]['momentum'],
        "cross_entropy":float(train_loss),
        "accuracy":float(accuracy),
        "num_misclasses":float(misclasses)
    }
    logs["data"].append(log)

    if epoch % 50 == 0 or curr_best_misclasses <= 100:
      log_file_path = save_path + "log.json"
      with open(log_file_path) as f:
        log_file = json.load(f)
      log_file["entries"].append(logs)
      with open(log_file_path, 'w') as f:
          json.dump(log_file, f)

    epoch += 1


  return train_loss, curr_best_acc, model, optimizer.param_groups[0]['lr'], optimizer.param_groups[0]['momentum']

In [0]:
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
mnist_train_p2 = datasets.MNIST(root='./data', train=True, download=True, transform=trans)
mnist_test_p2 = datasets.MNIST(root='./data', train=False, download=True, transform=trans)

train_set_indices = np.arange(0,50000)
mnist_train_p1 = torch.utils.data.Subset(mnist_train_p2, train_set_indices)
test_set_indices = np.arange(50000, 60000)
mnist_test_p1 = torch.utils.data.Subset(mnist_train_p2, test_set_indices)

batch_size = 100

train_loader_p1 = torch.utils.data.DataLoader(dataset=mnist_train_p1, batch_size=batch_size, shuffle=True)
test_loader_p1 = torch.utils.data.DataLoader(dataset=mnist_test_p1, batch_size=batch_size, shuffle=True)
train_loader_p2 = torch.utils.data.DataLoader(dataset=mnist_train_p2, batch_size=batch_size, shuffle=True)
test_loader_p2 = torch.utils.data.DataLoader(dataset=mnist_test_p2, batch_size=batch_size, shuffle=True)

curr_best_acc = 0
iteration = 1
drive.mount('/content/drive')
path = F"/content/drive/My Drive/combinact_outputs/maxout_sgd_training/training_iterations/"
log_file_path = path + "log.json"
if not os.path.exists(log_file_path):
  log_file = {"entries":[]}
  with open(log_file_path, 'w') as f:
      json.dump(log_file, f)

while curr_best_acc < 0.99:
  
  model = MaxNet(240, 240, 5, 0, 0.2, 0)
  if torch.cuda.is_available():
    model.cuda()
  print("------> Iteration " + str(iteration) + "  |  Round 1")
  cross_entropy, best_acc, model, round1_lr, round1_momentum = train_model(model,
                                                                           train_loader=train_loader_p1, 
                                                                           test_loader=test_loader_p1, 
                                                                           curr_best_acc = curr_best_acc, 
                                                                           save_path=path, 
                                                                           curr_iteration=iteration,
                                                                           curr_round=1)
  
  if best_acc > curr_best_acc:
    curr_best_acc = best_acc

  print("Best Accuracy: " + str(curr_best_acc))
  print("------> Iteration " + str(iteration) + "  |  Round 2")
  _, best_acc, _, _, _ = train_model(model, 
                                     train_loader=train_loader_p2, 
                                     test_loader=test_loader_p2, 
                                     goal_loss=cross_entropy,
                                     curr_best_acc = curr_best_acc, 
                                     save_path=path, 
                                     curr_iteration=iteration, 
                                     curr_round = 2, 
                                     init_lr=round1_lr, 
                                     init_momentum=round1_momentum)
  
  if best_acc > curr_best_acc:
    curr_best_acc = best_acc

  print("Best Accuracy: " + str(curr_best_acc))
  iteration += 1

drive.flush_and_unmount()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
------> Iteration 1  |  Round 1
  Epoch 1 : lr = 0.099800  |  momentum = 0.500800  |  num_misclass = 7945  |  loss = 2.192814  |  Accuracy = 0.20550
  Epoch 2 : lr = 0.099601  |  momentum = 0.501600  |  num_misclass = 1230  |  loss = 0.440342  |  Accuracy = 0.87700
  Epoch 3 : lr = 0.099402  |  momentum = 0.502400  |  num_misclass = 587  |  loss = 0.187316  |  Accuracy = 0.94130
  Epoch 4 : lr = 0.099203  |  momentum = 0.503200  |  num_misclass = 486  |  loss = 0.256267  |  Accuracy = 0.95140
  Epoch 5 : lr = 0.099005  |  momentum = 0.504000  |  num_misclass = 420  |  loss = 0.319946  |  Accuracy = 0.95800
  Epoch 6 : lr = 0.098807  |  momentum = 0.504800  |  num_misclass = 340  |  loss = 0.283217  |  Accuracy = 0.96600
  Epoch 7 : lr = 0.098610  |  momentum = 0.505600  |  num_misclass = 409  |  loss = 0.340215  |  Accuracy = 0.95910
  Epoch 8 : lr = 0.098413