In [1]:
import torch
from torchvision import transforms, datasets, models

In [2]:
# --- Configuration and Hyperparameters ---

# Make new float tensors (including layer parameters created later) CUDA
# tensors by default. The model is never moved with .to('cuda') anywhere,
# so its parameters rely on this default to live on the GPU next to the inputs.
# NOTE(review): set_default_tensor_type is deprecated in newer PyTorch releases
# in favour of set_default_dtype / set_default_device.
torch.set_default_tensor_type(torch.cuda.FloatTensor)

# Optimisation settings.
batch_size = 128
step_size = 0.01         # SGD learning rate
epochs = 80
L2_decay = 1e-4          # SGD weight_decay
# Std of the Gaussian noise used for vicinal (noisy-copy) training.
gauss_vicinal_std = 0.25
random_seed = 0

# PIL image -> tensor, then standardise with a fixed mean/std
# (presumably the commonly quoted MNIST pixel statistics).
mnist_mean, mnist_std = (0.1307,), (0.3081,)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mnist_mean, mnist_std)])

# Seed torch's RNGs; left as the last expression so the cell displays the generator.
torch.manual_seed(random_seed)

<torch._C.Generator at 0x1d1354514d0>

In [3]:
# --- Data ---
# Both MNIST splits are stored under ./data (downloaded if missing) and share
# the same `transform`. Only the training loader shuffles.
train_set = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_set = datasets.MNIST('./data', train=False, download=True, transform=transform)

loader_kwargs = dict(batch_size=batch_size, num_workers=8)
train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, **loader_kwargs)
test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, **loader_kwargs)

In [4]:
# --- Model ---
# torchvision ResNet-18 starting from its pretrained weights, adapted for
# single-channel, 10-class input: the stem convolution is rebuilt with one
# input channel and the classifier head is replaced. Both replacements are
# freshly initialised; every other layer keeps its pretrained weights.
model = models.resnet18(pretrained=True)
# Fine-tune every layer. Parameters are already trainable by default; kept explicit.
model.requires_grad_(True)
model.conv1 = torch.nn.Conv2d(
    in_channels=1, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False
)
model.fc = torch.nn.Linear(model.fc.in_features, 10)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    model.parameters(), lr=step_size, momentum=0.9, weight_decay=L2_decay
)

In [5]:
def gauss_vicinal(inputs, gauss_vicinal_std):
    """Draw a Gaussian "vicinal" sample around ``inputs``.

    Every element of the result is sampled independently from a normal
    distribution centred on the matching element of ``inputs`` with standard
    deviation ``gauss_vicinal_std``. Values are not clipped to any range.

    Args:
        inputs: tensor of means; its shape, dtype and device are kept.
        gauss_vicinal_std: non-negative float noise standard deviation.

    Returns:
        A new tensor. ``inputs`` itself is left untouched. Randomness comes
        from torch's default generator, so results follow the current seed.
    """
    return torch.normal(mean=inputs, std=gauss_vicinal_std)

In [6]:
# --- Training ---
# Each optimisation step minimises the sum of two cross-entropy terms:
#   * one on a Gaussian-perturbed copy of the batch (noise is added after the
#     Normalize transform, i.e. in standardised pixel space), and
#   * one on the clean batch.
# After every epoch the per-epoch SUMS of the noisy, clean and total batch
# losses are printed (not averages).
model.train()
for epoch in range(epochs):
    running = {'gauss': 0., 'org': 0., 'total': 0.}
    for inputs, labels in train_loader:
        inputs, labels = inputs.to('cuda'), labels.to('cuda')

        noisy_inputs = gauss_vicinal(inputs, gauss_vicinal_std)
        optimizer.zero_grad()

        # Noisy forward first, then clean, matching the original evaluation order.
        gauss_loss = criterion(model(noisy_inputs), labels)
        loss_org = criterion(model(inputs), labels)
        total_loss = gauss_loss + loss_org

        running['gauss'] += gauss_loss.item()
        running['org'] += loss_org.item()
        running['total'] += total_loss.item()

        total_loss.backward()
        optimizer.step()
    print('{}: {} {} {}'.format(epoch, running['gauss'], running['org'], running['total']))

0: 79.05954274348915 76.25249055260792 155.31203251332045
1: 23.715662434580736 21.932152105029672 45.64781444054097
2: 15.561147149885073 14.43421734625008 29.99536444316618
3: 12.56916669983184 11.293473368830746 23.86264003108954
4: 9.80826793028973 8.594138295622543 18.40240619733231
5: 8.273872400313849 7.019435923837591 15.293308278953191
6: 7.573212740273448 6.3866553793195635 13.959868192439899
7: 6.064281513710739 5.30011637568532 11.364397897501476
8: 4.884694334730739 4.055035146368027 8.93972944834968
9: 4.192387229788437 3.1204996980159194 7.3128869218198815
10: 3.406780550030817 2.5840504018051433 5.990830956128775
11: 2.773184236295492 1.9674943714617257 4.740678614602075
12: 3.5567236715833133 2.7312019145429076 6.28792560337024
13: 3.2624796344098286 2.4496154176340497 5.712095046961622
14: 2.554334854316039 1.7586563883069175 4.3129912359436275
15: 3.274953437412478 2.420976756791788 5.695930189860519
16: 2.24120836532893 1.808267082549719 4.049475440275273
17: 1.9408

In [7]:
# --- Save / reload round-trip ---
# Persist the trained weights, rebuild the same architecture from scratch
# (no pretrained weights), and load the checkpoint back. The final expression
# displays whether every state-dict key matched.
checkpoint_path = './gauss_model_pytorch_mnist'
torch.save(model.state_dict(), checkpoint_path)

model = models.resnet18(pretrained=False)
model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
model.fc = torch.nn.Linear(512, 10)
model.load_state_dict(torch.load(checkpoint_path))

<All keys matched successfully>

In [8]:
# --- Test-set accuracy ---
# Top-1 accuracy of the reloaded model on the held-out MNIST split,
# computed without gradient tracking and with the model in eval mode.
model.eval()
n_correct, n_seen = 0, 0
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to('cuda'), labels.to('cuda')
        _, predicts = model(inputs).max(dim=1)
        n_seen += labels.size(0)
        n_correct += (predicts == labels).sum().item()
print(n_correct / n_seen)

0.9929


In [9]:
# --- Training-set accuracy ---
# Same top-1 measurement on the clean (un-noised) training split, for
# comparison with the test-set accuracy.
model.eval()
correct = total = 0
with torch.no_grad():
    for batch_inputs, batch_labels in train_loader:
        batch_labels = batch_labels.to('cuda')
        logits = model(batch_inputs.to('cuda'))
        predicted = logits.max(1)[1]
        correct += (predicted == batch_labels).sum().item()
        total += batch_labels.size(0)
print(correct / total)

0.9999666666666667
