In [1]:
# Params
input_size = (100,100)
batch_size = 512
num_workers = 8
num_classes = 10
learning_rate = 1e-03
SGD_momentum = 0.9
epochs = 10

# Training an image classifier

In [2]:
import torch

print("torch-version:", torch.__version__)
print("Available GPU:", torch.cuda.current_device())

torch-version: 1.1.0
Available GPU: 0


## 1. Loading and normalizing CIFAR10

In [3]:
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose(
    [transforms.Resize(input_size), # Must be ahead of transforms.ToTensor()
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers, drop_last=True)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

print()
print(trainset.__getitem__(0)[0].size(), trainset.__len__())
print(testset.__getitem__(0)[0].size(), testset.__len__())

Files already downloaded and verified
Files already downloaded and verified

torch.Size([3, 100, 100]) 50000
torch.Size([3, 100, 100]) 10000


## 2. Define a GoogLeNet

In [4]:
from googlenet import *

net = GoogLeNet(in_channel=3, num_classes=num_classes, aux_block=True).cuda()

#for i in net.named_children():
#    print(i)

## 3. Define a Loss function and optimizer

In [5]:
from torch.optim.lr_scheduler import StepLR

#criterion = torch.nn.BCEWithLogitsLoss()
criterion = torch.nn.CrossEntropyLoss()
#criterion = torch.nn.BCELoss()
#optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate, momentum=SGD_momentum)
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
scheduler = StepLR(optimizer, step_size=8, gamma=0.96)

## 4. Train the network

In [6]:
times = 1
aux_loss_weight = 0.3

for epoch in range(epochs):  # loop over the dataset multiple times
    net.train()
    running_loss = 0.0
    for i, data in enumerate(trainloader):
        # get the inputs
        inputs, labels = data
        inputs = inputs.cuda()
        labels = labels.cuda()
        
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        aux_1, aux_2, outputs = net(inputs) #tuple: (aux1, aux2, input)
        aux_1 = aux_1.reshape(batch_size, num_classes)
        aux_2 = aux_2.reshape(batch_size, num_classes)
        outputs = outputs.reshape(batch_size, num_classes)
        
        loss1 = criterion(aux_1, labels)
        loss2 = criterion(aux_2, labels)
        loss3 = criterion(outputs, labels)
        
        loss = aux_loss_weight*loss1 + aux_loss_weight*loss2 + loss3
        loss.backward()
        ##scheduler.step()
        optimizer.step()
        
        running_loss += loss.item()
        
    if epoch % times == 0:
        print('[%d] loss: %.3f' %(epoch, running_loss / times))
        running_loss = 0.0
        #torch.save(model,'./data/cifar_model.pkl')

print('\nFinished Training')

[0] loss: 277.942
[1] loss: 200.853
[2] loss: 164.138
[3] loss: 135.534
[4] loss: 115.763
[5] loss: 98.941
[6] loss: 87.949
[7] loss: 76.329
[8] loss: 65.530
[9] loss: 56.735

Finished Training


## 5. Test the network on the test data

In [7]:
correct = 0
total = 0

with torch.no_grad():
    net.eval()
    for data in testloader:
        images, labels = data
        images = images.cuda() #.cpu()
        labels = labels.cuda() #.cpu()
        
        outputs = net(images).reshape(batch_size, num_classes)
        #print(outputs.size())
        
        _, predicted = torch.max(outputs, dim=1)
        
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' %(100 * correct / total))

Accuracy of the network on the 10000 test images: 80 %
