In [1]:
import torch
from torchvision import transforms, datasets
import models

In [2]:
"""
Configuration and Hyperparameters
"""
torch.set_default_tensor_type(torch.cuda.FloatTensor)  # default all in GPU

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),  # can omit
    transforms.RandomHorizontalFlip(),  # can omit
    transforms.ToTensor(),
    transforms.Normalize(
        (0.4914, 0.4822, 0.4465),
        (0.2023, 0.1994, 0.2010)
    )
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        (0.4914, 0.4822, 0.4465),
        (0.2023, 0.1994, 0.2010)
    )
])

batch_size = 128
step_size = 0.1
random_seed = 0
epochs = 100
L2_decay = 1e-4
gauss_vicinal_std = 0.25

torch.manual_seed(random_seed)

<torch._C.Generator at 0x1c041f104d0>

In [3]:
"""
Data
"""
train_set = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=8)
test_set = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=8)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
model = models.__dict__['ResNet18']()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=step_size, momentum=0.9, weight_decay=L2_decay)

In [5]:
def gauss_vicinal(inputs, gauss_vicinal_std):
    inputs_gauss = torch.normal(inputs, gauss_vicinal_std)
    return inputs_gauss

In [6]:
"""
Training
"""
model.train()
for epoch in range(epochs):
    epoch_loss = 0.
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        inputs = inputs.to('cuda')
        labels = labels.to('cuda')
        inputs_gauss = gauss_vicinal(inputs, gauss_vicinal_std)
        optimizer.zero_grad()
        outputs = model(inputs_gauss)
        loss = criterion(outputs, labels)
        epoch_loss += loss.item()
        loss.backward()
        optimizer.step()
    print('{}: {}'.format(epoch, epoch_loss))

0: 627.465212225914
1: 444.6420153975487
2: 355.1243146657944
3: 297.7835254371166
4: 260.7003476023674
5: 231.77259382605553
6: 214.85467541217804
7: 197.6988164782524
8: 186.06376150250435
9: 172.3618279993534
10: 165.77817930281162
11: 156.5470260977745
12: 149.48277381062508
13: 142.67068864405155
14: 137.3132664859295
15: 130.78425869345665
16: 128.45121905207634
17: 123.56266961991787
18: 118.02041666209698
19: 114.61374358832836
20: 111.10151681303978
21: 108.44731068611145
22: 105.71356572210789
23: 104.81287098675966
24: 101.08970533311367
25: 97.83019200712442
26: 97.67572692781687
27: 94.27546255290508
28: 92.01549385488033
29: 92.44790637493134
30: 89.16870843619108
31: 85.66438357532024
32: 85.36193803697824
33: 84.76138945668936
34: 83.02898990362883
35: 79.43821177631617
36: 81.97182720154524
37: 80.86381334066391
38: 78.70818359404802
39: 77.01568362116814
40: 76.82536789774895
41: 73.53598108887672
42: 73.92624707520008
43: 73.23817817121744
44: 71.91749671846628
45: 7

In [7]:
torch.save(model.state_dict(), './gauss_model_pytorch_cifar10')
model = models.__dict__['ResNet18']()
model.load_state_dict(torch.load('./gauss_model_pytorch_cifar10'))

<All keys matched successfully>

In [8]:
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data in test_loader:
        inputs, labels = data
        inputs = inputs.to('cuda')
        labels = labels.to('cuda')
        outputs = model(inputs)
        _, predicts = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicts == labels).sum().item()
print(correct / total)

0.8917


In [9]:
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data in train_loader:
        inputs, labels = data
        inputs = inputs.to('cuda')
        labels = labels.to('cuda')
        outputs = model(inputs)
        _, predicts = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicts == labels).sum().item()
print(correct / total)

0.96418
