In [2]:
import copy
import json
import os
import sys
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from kWTA import models
from kWTA import activation
from kWTA import attack
from kWTA import training
from kWTA import utilities
from kWTA import densenet
from kWTA import resnet
from kWTA import wideresnet

# Per-channel normalisation constants. With mean 0 and scale 1 the Normalize step
# computes (x - 0) / 1 and therefore leaves the tensor values unchanged.
# NOTE: despite its name, `norm_var` is handed to Normalize as its second argument,
# which torchvision treats as a standard deviation, not a variance.
norm_mean, norm_var = 0, 1
_channel_mean = (norm_mean,) * 3
_channel_std = (norm_var,) * 3

# Both splits share the same pipeline: tensor conversion followed by normalisation.
# No augmentation is applied to either split.
transform_train = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(_channel_mean, _channel_std)]
)
transform_test = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(_channel_mean, _channel_std)]
)



In [3]:
# Build both SVHN splits under ./data (download=True is passed, so missing files are
# fetched on first use) and attach the matching transform to each split.
svhn_train = datasets.SVHN(
    root="./data", split='train', transform=transform_train, download=True
)
svhn_test = datasets.SVHN(
    root="./data", split='test', transform=transform_test, download=True
)

# Only the training loader is created here. The test loader is left disabled; the
# commented-out evaluation code in the training cells would need it re-enabled.
train_loader = DataLoader(dataset=svhn_train, batch_size=512, shuffle=True)
#test_loader = DataLoader(svhn_test, batch_size = 50, shuffle=True)

Downloading http://ufldl.stanford.edu/housenumbers/train_32x32.mat to ./data\train_32x32.mat


100%|███████████████████████████████████████████████████████████████▉| 182034432/182040794 [07:05<00:00, 423821.79it/s]

Downloading http://ufldl.stanford.edu/housenumbers/test_32x32.mat to ./data\test_32x32.mat



0it [00:00, ?it/s]
  0%|                                                                                     | 0/64275384 [00:00<?, ?it/s]
  0%|                                                                     | 16384/64275384 [00:00<10:12, 104892.49it/s]
  0%|                                                                     | 32768/64275384 [00:00<09:54, 108133.01it/s]
  0%|                                                                     | 49152/64275384 [00:00<09:59, 107136.46it/s]
  0%|                                                                     | 65536/64275384 [00:00<09:44, 109798.14it/s]
  0%|                                                                     | 98304/64275384 [00:01<08:11, 130501.19it/s]
  0%|▏                                                                   | 122880/64275384 [00:01<07:46, 137535.65it/s]
  0%|▏                                                                   | 163840/64275384 [00:01<06:32, 163419.32it/s]
  0%|▏              

  6%|████▏                                                              | 4014080/64275384 [00:10<02:24, 417933.09it/s]
  6%|████▎                                                              | 4079616/64275384 [00:10<02:19, 431343.45it/s]
  6%|████▎                                                              | 4145152/64275384 [00:10<02:20, 427727.63it/s]
  7%|████▍                                                              | 4210688/64275384 [00:10<02:16, 438563.17it/s]
  7%|████▍                                                              | 4276224/64275384 [00:11<02:18, 432677.06it/s]
  7%|████▌                                                              | 4341760/64275384 [00:11<02:15, 442197.94it/s]
  7%|████▌                                                              | 4399104/64275384 [00:11<02:18, 431320.58it/s]
  7%|████▋                                                              | 4464640/64275384 [00:11<02:19, 427701.51it/s]
  7%|████▋                              

 13%|████████▋                                                          | 8331264/64275384 [00:20<02:10, 428490.15it/s]
 13%|████████▊                                                          | 8396800/64275384 [00:20<02:11, 425759.29it/s]
 13%|████████▊                                                          | 8462336/64275384 [00:20<02:11, 423868.89it/s]
 13%|████████▉                                                          | 8527872/64275384 [00:20<02:07, 435701.50it/s]
 13%|████████▉                                                          | 8593408/64275384 [00:21<02:09, 430738.75it/s]
 13%|█████████                                                          | 8650752/64275384 [00:21<02:11, 423620.07it/s]
 14%|█████████                                                          | 8716288/64275384 [00:21<02:11, 422376.41it/s]
 14%|█████████▏                                                         | 8781824/64275384 [00:21<02:11, 421512.90it/s]
 14%|█████████▏                         

 20%|█████████████                                                     | 12697600/64275384 [00:30<02:02, 421280.29it/s]
 20%|█████████████                                                     | 12763136/64275384 [00:30<02:02, 420728.64it/s]
 20%|█████████████▏                                                    | 12828672/64275384 [00:30<01:58, 433386.76it/s]
 20%|█████████████▏                                                    | 12894208/64275384 [00:31<01:59, 429161.58it/s]
 20%|█████████████▎                                                    | 12959744/64275384 [00:31<01:56, 439624.61it/s]
 20%|█████████████▎                                                    | 13025280/64275384 [00:31<01:58, 433371.22it/s]
 20%|█████████████▍                                                    | 13090816/64275384 [00:31<01:59, 429144.28it/s]
 20%|█████████████▌                                                    | 13148160/64275384 [00:31<02:01, 422437.76it/s]
 21%|█████████████▌                     

 27%|█████████████████▌                                                | 17072128/64275384 [00:40<01:50, 426421.92it/s]
 27%|█████████████████▌                                                | 17137664/64275384 [00:41<01:47, 437574.28it/s]
 27%|█████████████████▋                                                | 17195008/64275384 [00:41<01:53, 413743.27it/s]
 27%|█████████████████▋                                                | 17260544/64275384 [00:41<01:49, 428182.21it/s]
 27%|█████████████████▊                                                | 17326080/64275384 [00:41<01:50, 425578.44it/s]
 27%|█████████████████▊                                                | 17391616/64275384 [00:41<01:47, 436951.38it/s]
 27%|█████████████████▉                                                | 17457152/64275384 [00:41<01:48, 431573.17it/s]
 27%|█████████████████▉                                                | 17522688/64275384 [00:41<01:49, 427885.12it/s]
 27%|██████████████████                 

 33%|██████████████████████                                            | 21438464/64275384 [00:51<01:38, 433966.34it/s]
 33%|██████████████████████                                            | 21504000/64275384 [00:51<01:39, 429541.93it/s]
 34%|██████████████████████▏                                           | 21569536/64275384 [00:51<01:37, 439902.24it/s]
 34%|██████████████████████▏                                           | 21626880/64275384 [00:51<01:42, 415193.69it/s]
 34%|██████████████████████▎                                           | 21692416/64275384 [00:51<01:42, 416432.03it/s]
 34%|██████████████████████▎                                           | 21757952/64275384 [00:51<01:38, 430247.12it/s]
 34%|██████████████████████▍                                           | 21823488/64275384 [00:51<01:39, 426973.22it/s]
 34%|██████████████████████▍                                           | 21889024/64275384 [00:52<01:36, 437999.13it/s]
 34%|██████████████████████▌            

 40%|██████████████████████████▍                                       | 25804800/64275384 [01:01<01:32, 416831.71it/s]
 40%|██████████████████████████▌                                       | 25870336/64275384 [01:01<01:29, 430492.69it/s]
 40%|██████████████████████████▋                                       | 25935872/64275384 [01:01<01:29, 427136.11it/s]
 40%|██████████████████████████▋                                       | 26001408/64275384 [01:01<01:27, 438139.42it/s]
 41%|██████████████████████████▊                                       | 26066944/64275384 [01:01<01:28, 432383.16it/s]
 41%|██████████████████████████▊                                       | 26124288/64275384 [01:02<01:32, 410476.81it/s]
 41%|██████████████████████████▉                                       | 26189824/64275384 [01:02<01:29, 425727.38it/s]
 41%|██████████████████████████▉                                       | 26255360/64275384 [01:02<01:29, 423844.60it/s]
 41%|███████████████████████████        

 47%|██████████████████████████████▉                                   | 30171136/64275384 [01:11<01:20, 423609.43it/s]
 47%|███████████████████████████████                                   | 30236672/64275384 [01:11<01:20, 422382.44it/s]
 47%|███████████████████████████████                                   | 30302208/64275384 [01:11<01:20, 420459.82it/s]
 47%|███████████████████████████████▏                                  | 30367744/64275384 [01:11<01:18, 429998.39it/s]
 47%|███████████████████████████████▏                                  | 30433280/64275384 [01:12<01:19, 426801.32it/s]
 47%|███████████████████████████████▎                                  | 30498816/64275384 [01:12<01:17, 437853.66it/s]
 48%|███████████████████████████████▍                                  | 30556160/64275384 [01:12<01:21, 413912.95it/s]
 48%|███████████████████████████████▍                                  | 30621696/64275384 [01:12<01:18, 428309.61it/s]
 48%|███████████████████████████████▌   

 54%|███████████████████████████████████▍                              | 34545664/64275384 [01:21<01:08, 432294.77it/s]
 54%|███████████████████████████████████▌                              | 34611200/64275384 [01:21<01:09, 428412.24it/s]
 54%|███████████████████████████████████▌                              | 34668544/64275384 [01:21<01:10, 422001.82it/s]
 54%|███████████████████████████████████▋                              | 34734080/64275384 [01:22<01:10, 421254.59it/s]
 54%|███████████████████████████████████▋                              | 34799616/64275384 [01:22<01:07, 433785.49it/s]
 54%|███████████████████████████████████▊                              | 34865152/64275384 [01:22<01:08, 429404.20it/s]
 54%|███████████████████████████████████▊                              | 34930688/64275384 [01:22<01:06, 439780.21it/s]
 54%|███████████████████████████████████▉                              | 34996224/64275384 [01:22<01:07, 433518.49it/s]
 55%|███████████████████████████████████

 61%|███████████████████████████████████████▉                          | 38912000/64275384 [01:31<00:59, 428885.81it/s]
 61%|████████████████████████████████████████                          | 38977536/64275384 [01:32<00:59, 426047.51it/s]
 61%|████████████████████████████████████████                          | 39043072/64275384 [01:32<00:57, 437315.98it/s]
 61%|████████████████████████████████████████▏                         | 39100416/64275384 [01:32<01:00, 413577.32it/s]
 61%|████████████████████████████████████████▏                         | 39165952/64275384 [01:32<00:58, 428057.18it/s]
 61%|████████████████████████████████████████▎                         | 39231488/64275384 [01:32<00:58, 425455.82it/s]
 61%|████████████████████████████████████████▎                         | 39297024/64275384 [01:32<00:57, 436899.01it/s]
 61%|████████████████████████████████████████▍                         | 39362560/64275384 [01:32<00:57, 431515.45it/s]
 61%|███████████████████████████████████

 67%|████████████████████████████████████████████▍                     | 43278336/64275384 [01:42<00:48, 431693.14it/s]
 67%|████████████████████████████████████████████▌                     | 43343872/64275384 [01:42<00:48, 427972.46it/s]
 68%|████████████████████████████████████████████▌                     | 43409408/64275384 [01:42<00:47, 438748.99it/s]
 68%|████████████████████████████████████████████▋                     | 43474944/64275384 [01:42<00:48, 432797.52it/s]
 68%|████████████████████████████████████████████▋                     | 43540480/64275384 [01:42<00:48, 428728.25it/s]
 68%|████████████████████████████████████████████▊                     | 43597824/64275384 [01:42<00:48, 422249.54it/s]
 68%|████████████████████████████████████████████▊                     | 43663360/64275384 [01:42<00:48, 421428.32it/s]
 68%|████████████████████████████████████████████▉                     | 43728896/64275384 [01:43<00:47, 433913.40it/s]
 68%|███████████████████████████████████

 74%|████████████████████████████████████████████████▉                 | 47644672/64275384 [01:52<00:39, 421371.74it/s]
 74%|████████████████████████████████████████████████▉                 | 47710208/64275384 [01:52<00:39, 420814.70it/s]
 74%|█████████████████████████████████████████████████                 | 47775744/64275384 [01:52<00:38, 433458.06it/s]
 74%|█████████████████████████████████████████████████                 | 47841280/64275384 [01:52<00:38, 429181.57it/s]
 75%|█████████████████████████████████████████████████▏                | 47906816/64275384 [01:52<00:38, 426237.69it/s]
 75%|█████████████████████████████████████████████████▎                | 47972352/64275384 [01:53<00:37, 437471.26it/s]
 75%|█████████████████████████████████████████████████▎                | 48029696/64275384 [01:53<00:39, 413674.10it/s]
 75%|█████████████████████████████████████████████████▍                | 48095232/64275384 [01:53<00:37, 428129.75it/s]
 75%|███████████████████████████████████

 81%|█████████████████████████████████████████████████████▍            | 52019200/64275384 [02:02<00:28, 432278.99it/s]
 81%|█████████████████████████████████████████████████████▍            | 52076544/64275384 [02:02<00:29, 410443.86it/s]
 81%|█████████████████████████████████████████████████████▌            | 52142080/64275384 [02:02<00:28, 425701.98it/s]
 81%|█████████████████████████████████████████████████████▌            | 52207616/64275384 [02:02<00:28, 423828.73it/s]
 81%|█████████████████████████████████████████████████████▋            | 52273152/64275384 [02:03<00:27, 435691.69it/s]
 81%|█████████████████████████████████████████████████████▋            | 52338688/64275384 [02:03<00:27, 430712.80it/s]
 82%|█████████████████████████████████████████████████████▊            | 52404224/64275384 [02:03<00:27, 427293.55it/s]
 82%|█████████████████████████████████████████████████████▉            | 52469760/64275384 [02:03<00:26, 438243.84it/s]
 82%|███████████████████████████████████

 88%|█████████████████████████████████████████████████████████▉        | 56385536/64275384 [02:12<00:18, 428701.03it/s]
 88%|█████████████████████████████████████████████████████████▉        | 56451072/64275384 [02:12<00:17, 439284.95it/s]
 88%|██████████████████████████████████████████████████████████        | 56516608/64275384 [02:12<00:17, 433134.39it/s]
 88%|██████████████████████████████████████████████████████████        | 56573952/64275384 [02:13<00:18, 410980.58it/s]
 88%|██████████████████████████████████████████████████████████▏       | 56639488/64275384 [02:13<00:17, 426105.19it/s]
 88%|██████████████████████████████████████████████████████████▏       | 56705024/64275384 [02:13<00:17, 424108.82it/s]
 88%|██████████████████████████████████████████████████████████▎       | 56770560/64275384 [02:13<00:17, 422720.11it/s]
 88%|██████████████████████████████████████████████████████████▎       | 56836096/64275384 [02:13<00:17, 434873.27it/s]
 89%|███████████████████████████████████

 95%|██████████████████████████████████████████████████████████████▍   | 60751872/64275384 [02:22<00:08, 423770.97it/s]
 95%|██████████████████████████████████████████████████████████████▍   | 60817408/64275384 [02:23<00:08, 422474.39it/s]
 95%|██████████████████████████████████████████████████████████████▌   | 60882944/64275384 [02:23<00:08, 421594.20it/s]
 95%|██████████████████████████████████████████████████████████████▌   | 60948480/64275384 [02:23<00:07, 434040.19it/s]
 95%|██████████████████████████████████████████████████████████████▋   | 61014016/64275384 [02:23<00:07, 429579.68it/s]
 95%|██████████████████████████████████████████████████████████████▋   | 61079552/64275384 [02:23<00:07, 439931.43it/s]
 95%|██████████████████████████████████████████████████████████████▊   | 61145088/64275384 [02:23<00:07, 433603.07it/s]
 95%|██████████████████████████████████████████████████████████████▊   | 61202432/64275384 [02:23<00:07, 411244.20it/s]
 95%|███████████████████████████████████

## Regular training

### ReLU Resnet18

In [None]:
# Regular (non-adversarial) training of the ReLU ResNet18 on `train_loader`.
# Use the first CUDA device when one is available; otherwise fall back to the CPU
# instead of failing on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = resnet.ResNet18().to(device)
opt = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
eps = 0.047  # only referenced by the disabled adversarial evaluation below

# torch.save does not create missing parent directories. Create the checkpoint
# folder up front so the first epoch's work is not lost to a file-not-found error.
os.makedirs('models', exist_ok=True)

for ep in range(80):
    # Single step decay: lr is 0.1 for epochs 0-49 and 0.01 from epoch 50 onwards.
    if ep == 50:
        for param_group in opt.param_groups:
            param_group['lr'] = 0.01
    train_err, train_loss = training.epoch(train_loader, model, opt, device=device, use_tqdm=True)
    # Disabled per-epoch evaluation, kept as an inert string literal. Re-enabling it
    # requires a `test_loader`, which this notebook currently leaves commented out.
    '''test_err, test_loss = training.epoch(test_loader, model, device=device, use_tqdm=True)
    adv_err, adv_loss = training.epoch_adversarial(test_loader,
        model, attack=attack.pgd_linf_untargeted, device=device, num_iter=20, 
        use_tqdm=True, epsilon=eps, randomize=True, alpha=0.004, n_test=1000)
    print('epoch', ep, 'train err', train_err, 'test err', test_err, 'adv_err', adv_err)'''
    # The checkpoint is overwritten after every epoch, so the file always holds the
    # most recently completed epoch's weights.
    torch.save(model.state_dict(), 'models/resnet18_svhn.pth')

HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

### kWTA-0.1 Resnet18

In [None]:
# Regular (non-adversarial) training of SparseResNet18 with sparsity 0.1 in each of
# its four `sparsities` entries and sparse_func='vol'.
# Use the first CUDA device when one is available; otherwise fall back to the CPU
# instead of failing on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = resnet.SparseResNet18(sparsities=[0.1,0.1,0.1,0.1], sparse_func='vol').to(device)
opt = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
eps = 0.047  # only referenced by the disabled adversarial evaluation below

# torch.save does not create missing parent directories. Create the checkpoint
# folder up front so the first epoch's work is not lost to a file-not-found error.
os.makedirs('models', exist_ok=True)

for ep in range(80):
    # Single step decay: lr is 0.1 for epochs 0-49 and 0.01 from epoch 50 onwards.
    if ep == 50:
        for param_group in opt.param_groups:
            param_group['lr'] = 0.01
    train_err, train_loss = training.epoch(train_loader, model, opt, device=device, use_tqdm=True)
    # Disabled per-epoch evaluation, kept as an inert string literal. Re-enabling it
    # requires a `test_loader`, which this notebook currently leaves commented out.
    '''test_err, test_loss = training.epoch(test_loader, model, device=device, use_tqdm=True)
    adv_err, adv_loss = training.epoch_adversarial(test_loader,
        model, attack=attack.pgd_linf_untargeted, device=device, num_iter=20, 
        use_tqdm=True, epsilon=eps, randomize=True, alpha=0.004, n_test=1000)
    print('epoch', ep, 'train err', train_err, 'test err', test_err, 'adv_err', adv_err)'''
    # The checkpoint is overwritten after every epoch, so the file always holds the
    # most recently completed epoch's weights.
    torch.save(model.state_dict(), 'models/spresnet18_0.1_svhn.pth')

### kWTA-0.2 Resnet18

In [None]:
# Regular (non-adversarial) training of SparseResNet18 with sparsity 0.2 in each of
# its four `sparsities` entries and sparse_func='vol'.
# Use the first CUDA device when one is available; otherwise fall back to the CPU
# instead of failing on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = resnet.SparseResNet18(sparsities=[0.2,0.2,0.2,0.2], sparse_func='vol').to(device)
opt = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
eps = 0.047  # only referenced by the disabled adversarial evaluation below

# torch.save does not create missing parent directories. Create the checkpoint
# folder up front so the first epoch's work is not lost to a file-not-found error.
os.makedirs('models', exist_ok=True)

for ep in range(80):
    # Single step decay: lr is 0.1 for epochs 0-49 and 0.01 from epoch 50 onwards.
    if ep == 50:
        for param_group in opt.param_groups:
            param_group['lr'] = 0.01
    train_err, train_loss = training.epoch(train_loader, model, opt, device=device, use_tqdm=True)
    # Disabled per-epoch evaluation, kept as an inert string literal. Re-enabling it
    # requires a `test_loader`, which this notebook currently leaves commented out.
    '''test_err, test_loss = training.epoch(test_loader, model, device=device, use_tqdm=True)
    adv_err, adv_loss = training.epoch_adversarial(test_loader,
        model, attack=attack.pgd_linf_untargeted, device=device, num_iter=20, 
        use_tqdm=True, epsilon=eps, randomize=True, alpha=0.004, n_test=1000)
    print('epoch', ep, 'train err', train_err, 'test err', test_err, 'adv_err', adv_err)'''
    # The checkpoint is overwritten after every epoch, so the file always holds the
    # most recently completed epoch's weights.
    torch.save(model.state_dict(), 'models/spresnet18_0.2_svhn.pth')

## Adversarial training

### ReLU Resnet18

In [None]:
# Adversarial training of the ReLU ResNet18: every epoch goes through
# training.epoch_adversarial with attack.pgd_linf_untargeted as the attack.
# Use the first CUDA device when one is available; otherwise fall back to the CPU
# instead of failing on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = resnet.ResNet18().to(device)
opt = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
eps = 0.047  # forwarded to the attack as `epsilon`

# torch.save does not create missing parent directories. Create the checkpoint
# folder up front so the first epoch's work is not lost to a file-not-found error.
os.makedirs('models', exist_ok=True)

for ep in range(80):
    # Single step decay: lr is 0.1 for epochs 0-49 and 0.01 from epoch 50 onwards.
    if ep == 50:
        for param_group in opt.param_groups:
            param_group['lr'] = 0.01
    # Passing `opt` makes this a training pass; the disabled evaluation call below
    # uses the same function without an optimizer.
    train_err, train_loss = training.epoch_adversarial(
        train_loader, model, opt=opt,
        attack=attack.pgd_linf_untargeted, device=device, num_iter=10,
        use_tqdm=True, epsilon=eps, randomize=True, alpha=0.007)
    # Disabled per-epoch evaluation, kept as an inert string literal. Re-enabling it
    # requires a `test_loader`, which this notebook currently leaves commented out.
    '''test_err, test_loss = training.epoch(test_loader, model, device=device, use_tqdm=True)
    adv_err, adv_loss = training.epoch_adversarial(test_loader,
        model, attack=attack.pgd_linf_untargeted, device=device, num_iter=20, 
        use_tqdm=True, epsilon=eps, randomize=True, alpha=0.004, n_test=1000)
    print('epoch', ep, 'train err', train_err, 'test err', test_err, 'adv_err', adv_err)'''
    # The checkpoint is overwritten after every epoch, so the file always holds the
    # most recently completed epoch's weights.
    torch.save(model.state_dict(), 'models/resnet18_svhn_adv.pth')

### kWTA-0.1 Resnet18

In [None]:
# Adversarial training of SparseResNet18 (sparsity 0.1 in each of its four
# `sparsities` entries, sparse_func='vol'): every epoch goes through
# training.epoch_adversarial with attack.pgd_linf_untargeted as the attack.
# Use the first CUDA device when one is available; otherwise fall back to the CPU
# instead of failing on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = resnet.SparseResNet18(sparsities=[0.1,0.1,0.1,0.1], sparse_func='vol').to(device)
opt = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
eps = 0.047  # forwarded to the attack as `epsilon`

# torch.save does not create missing parent directories. Create the checkpoint
# folder up front so the first epoch's work is not lost to a file-not-found error.
os.makedirs('models', exist_ok=True)

for ep in range(80):
    # Single step decay: lr is 0.1 for epochs 0-49 and 0.01 from epoch 50 onwards.
    if ep == 50:
        for param_group in opt.param_groups:
            param_group['lr'] = 0.01
    # Passing `opt` makes this a training pass; the disabled evaluation call below
    # uses the same function without an optimizer.
    train_err, train_loss = training.epoch_adversarial(
        train_loader, model, opt=opt,
        attack=attack.pgd_linf_untargeted, device=device, num_iter=10,
        use_tqdm=True, epsilon=eps, randomize=True, alpha=0.007)
    # Disabled per-epoch evaluation, kept as an inert string literal. Re-enabling it
    # requires a `test_loader`, which this notebook currently leaves commented out.
    '''test_err, test_loss = training.epoch(test_loader, model, device=device, use_tqdm=True)
    adv_err, adv_loss = training.epoch_adversarial(test_loader,
        model, attack=attack.pgd_linf_untargeted, device=device, num_iter=20, 
        use_tqdm=True, epsilon=eps, randomize=True, alpha=0.004, n_test=1000)
    print('epoch', ep, 'train err', train_err, 'test err', test_err, 'adv_err', adv_err)'''
    # The checkpoint is overwritten after every epoch, so the file always holds the
    # most recently completed epoch's weights.
    torch.save(model.state_dict(), 'models/spresnet18_0.1_svhn_adv.pth')

### kWTA-0.2 Resnet18

In [None]:
# Adversarial training of SparseResNet18 (sparsity 0.2 in each of its four
# `sparsities` entries, sparse_func='vol'): every epoch goes through
# training.epoch_adversarial with attack.pgd_linf_untargeted as the attack.
# Use the first CUDA device when one is available; otherwise fall back to the CPU
# instead of failing on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = resnet.SparseResNet18(sparsities=[0.2,0.2,0.2,0.2], sparse_func='vol').to(device)
opt = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
eps = 0.047  # forwarded to the attack as `epsilon`

# torch.save does not create missing parent directories. Create the checkpoint
# folder up front so the first epoch's work is not lost to a file-not-found error.
os.makedirs('models', exist_ok=True)

for ep in range(80):
    # Single step decay: lr is 0.1 for epochs 0-49 and 0.01 from epoch 50 onwards.
    if ep == 50:
        for param_group in opt.param_groups:
            param_group['lr'] = 0.01
    # Passing `opt` makes this a training pass; the disabled evaluation call below
    # uses the same function without an optimizer.
    train_err, train_loss = training.epoch_adversarial(
        train_loader, model, opt=opt,
        attack=attack.pgd_linf_untargeted, device=device, num_iter=10,
        use_tqdm=True, epsilon=eps, randomize=True, alpha=0.007)
    # Disabled per-epoch evaluation, kept as an inert string literal. Re-enabling it
    # requires a `test_loader`, which this notebook currently leaves commented out.
    '''test_err, test_loss = training.epoch(test_loader, model, device=device, use_tqdm=True)
    adv_err, adv_loss = training.epoch_adversarial(test_loader,
        model, attack=attack.pgd_linf_untargeted, device=device, num_iter=20, 
        use_tqdm=True, epsilon=eps, randomize=True, alpha=0.004, n_test=1000)
    print('epoch', ep, 'train err', train_err, 'test err', test_err, 'adv_err', adv_err)'''
    # The checkpoint is overwritten after every epoch, so the file always holds the
    # most recently completed epoch's weights.
    torch.save(model.state_dict(), 'models/spresnet18_0.2_svhn_adv.pth')