In [2]:
"""

"""

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchsummary import summary
import sys
sys.path.append('../')
from wideresnet import Wide_ResNet


'''
Function that loads the dataset and returns the data-loaders
'''
def getData(batch_size,test_batch_size,val_percentage):
    # Normalize the training set with data augmentation
    transform_train = transforms.Compose([ 
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        torchvision.transforms.RandomRotation(20),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    
    # Normalize the test set same as training set without augmentation
    transform_test = transforms.Compose([ 
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    # Download/Load data
    full_training_data = torchvision.datasets.CIFAR100('/home/test/data',train=True,transform=transform_train,download=True)  
    test_data = torchvision.datasets.CIFAR100('/home/test/data',train=False,transform=transform_test,download=True)  

    # Create train and validation splits
    num_samples = len(full_training_data)
    training_samples = int((1-val_percentage)*num_samples+1)
    validation_samples = num_samples - training_samples
    training_data, validation_data = torch.utils.data.random_split(full_training_data, [training_samples, validation_samples])

    # Initialize dataloaders
    train_loader = torch.utils.data.DataLoader(training_data,batch_size=batch_size,shuffle=True)
    val_loader = torch.utils.data.DataLoader(validation_data,batch_size=batch_size,shuffle=False)
    test_loader = torch.utils.data.DataLoader(test_data,batch_size=test_batch_size,shuffle=False)

    return train_loader, val_loader, test_loader

'''
Function to test that returns the loss per sample and the total accuracy
'''
def test(data_loader,net,cost_fun,device):
  
    net.eval()
    samples = 0.
    cumulative_loss = 0.
    cumulative_accuracy = 0.

    for batch_idx, (inputs,targets) in enumerate(data_loader):

        inputs = inputs.to(device)
        targets = targets.to(device)

        outputs = net(inputs)[0]
        loss = cost_fun(outputs,targets)

        # Metrics computation
        samples+=inputs.shape[0]
        cumulative_loss += loss.item()
        _, predicted = outputs.max(1)
        cumulative_accuracy += predicted.eq(targets).sum().item()

    return cumulative_loss/samples, cumulative_accuracy/samples*100

'''
Function to train the nework on the data for one epoch that returns the loss per sample and the total accuracy
'''
def train(data_loader,net,cost_fun,device,optimizer):
    
    net.train()
    samples = 0.
    cumulative_loss = 0.
    cumulative_accuracy = 0.

    for batch_idx, (inputs,targets) in enumerate(data_loader):

        inputs = inputs.to(device)
        targets = targets.to(device)

        outputs = net(inputs)[0]
        loss = cost_fun(outputs,targets)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Metrics computation
        samples+=inputs.shape[0]
        cumulative_loss += loss.item()
        _, predicted = outputs.max(1)
        cumulative_accuracy += predicted.eq(targets).sum().item()

    return cumulative_loss/samples, cumulative_accuracy/samples*100

def main(epochs, batch_size, test_batch_size,val_percentage,lr,test_freq, net_depth, net_width):
    
    # Define cost function
    cost_function = torch.nn.CrossEntropyLoss()

    # Create the network: Wide_ResNet(depth, width, dropout, num_classes)
    net = Wide_ResNet(net_depth,net_width,0,100)
    net = net.to(device)
    summary(net,input_size=(3,32,32))

    # Create the optimizer anche the learning rate scheduler
    optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, 
                    milestones=[int(epochs*0.3),int(epochs*0.6),int(epochs*0.8)], gamma=0.20)

    # Get the data
    train_loader, val_loader, test_loader = getData(batch_size,test_batch_size,val_percentage)
    
    save_filename = './CIFAR100-teacher-' + str(net_depth) + '-' + str(net_width) + '.pth'

    for e in range(epochs):
        net.train() 

        train_loss, train_accuracy = train(train_loader,net,cost_function,device,optimizer)

        val_loss, val_accuracy = test(val_loader,net,cost_function,device)
        
        scheduler.step()

        print('Epoch: {:d}:'.format(e+1))
        print('\t Training loss: \t {:.6f}, \t Training accuracy \t {:.2f}'.format(train_loss, train_accuracy))
        print('\t Validation loss: \t {:.6f},\t Validation accuracy \t {:.2f}'.format(val_loss, val_accuracy))
        
        if((e+1) % test_freq) == 0:
            test_loss, test_accuracy = test(test_loader,net,cost_function,device)
            torch.save(net.state_dict(), save_filename)
            print('Test loss: \t {:.6f}, \t \t Test accuracy \t {:.2f}'.format(test_loss, test_accuracy))

    print('After training:')
    train_loss, train_accuracy = test(train_loader,net,cost_function,device)
    val_loss, val_accuracy = test(val_loader,net,cost_function,device)
    test_loss, test_accuracy = test(test_loader,net,cost_function,device)

    print('\t Training loss: \t {:.6f}, \t Training accuracy \t {:.2f}'.format(train_loss, train_accuracy))
    print('\t Validation loss: \t {:.6f},\t Validation accuracy \t {:.2f}'.format(val_loss, val_accuracy))
    print('Test loss: \t {:.6f}, \t \t Test accuracy \t {:.2f}'.format(test_loss, test_accuracy))
    
    torch.save(net.state_dict(), save_filename)

    net2 = Wide_ResNet(net_depth,net_width,0,100)
    net2 = net.to(device)
    net2.load_state_dict(torch.load(save_filename))
    
    print('loaded net test:')
    test_loss, test_accuracy = test(test_loader,net2,cost_function,device)
    print('\t Test loss: \t {:.6f}, \t Test accuracy \t {:.2f}'.format(test_loss, test_accuracy))

    
# Parameters
epochs = 2
batch_size = 128
test_batch_size = 128
val_percentage = 0.01
lr = 0.1
test_freq = 1
device = 'cuda:0'
net_depth = 40
net_width = 2


# Call the main
main(epochs, batch_size, test_batch_size,val_percentage,lr,test_freq, net_depth, net_width)





| Wide-Resnet 40x1
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 32, 32]             432
       BatchNorm2d-2           [-1, 16, 32, 32]              32
            Conv2d-3           [-1, 16, 32, 32]           2,304
           Dropout-4           [-1, 16, 32, 32]               0
       BatchNorm2d-5           [-1, 16, 32, 32]              32
            Conv2d-6           [-1, 16, 32, 32]           2,304
        wide_basic-7           [-1, 16, 32, 32]               0
       BatchNorm2d-8           [-1, 16, 32, 32]              32
            Conv2d-9           [-1, 16, 32, 32]           2,304
          Dropout-10           [-1, 16, 32, 32]               0
      BatchNorm2d-11           [-1, 16, 32, 32]              32
           Conv2d-12           [-1, 16, 32, 32]           2,304
       wide_basic-13           [-1, 16, 32, 32]               0
      BatchNorm2d-14

100.0%

Extracting /home/test/data/cifar-100-python.tar.gz to /home/test/data
Files already downloaded and verified
Epoch: 1:
	 Training loss: 	 0.032007, 	 Training accuracy 	 7.52
	 Validation loss: 	 0.030731,	 Validation accuracy 	 8.42
Test loss: 	 0.029458, 	 	 Test accuracy 	 11.78
Epoch: 2:
	 Training loss: 	 0.028587, 	 Training accuracy 	 13.32
	 Validation loss: 	 0.029195,	 Validation accuracy 	 12.83
Test loss: 	 0.028319, 	 	 Test accuracy 	 13.93
After training:
	 Training loss: 	 0.028175, 	 Training accuracy 	 14.49
	 Validation loss: 	 0.029193,	 Validation accuracy 	 12.63
Test loss: 	 0.028319, 	 	 Test accuracy 	 13.93
| Wide-Resnet 40x1
loaded net test:
	 Test loss: 	 0.028319, 	 Test accuracy 	 13.93
