In [20]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from fedlern.models.resnet18 import ResNet
from fedlern.models.resnet import ResNet18
from torchvision.utils import make_grid
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

In [21]:
# Device configuration: use CUDA when it is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters
num_epochs = 80
batch_size = 100
# Single source of truth for the optimizer step size. A shadowed
# `learning_rate = 0.01` used to precede this line; it was overwritten
# immediately and never took effect, so it has been removed.
learning_rate = 0.001
# (mean, std) tuples with one entry per RGB channel, unpacked into
# transforms.Normalize below.
# NOTE(review): presumably the CIFAR-10 training-set statistics -- confirm.
stats = (0.49139968, 0.48215841, 0.44653091), (0.24703223, 0.24348513, 0.26158784)
# Path of the checkpoint file holding the trained weights.
output_name = 'resnet18_cifar10.pt'

# Data augmentation and normalization for training:
# random 32x32 crop after 4-pixel padding, random horizontal flip,
# then tensor conversion and per-channel normalization.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(*stats)
])

# Normalization only for testing (no augmentation, so evaluation is deterministic).
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(*stats)
])

# CIFAR-10 dataset; only the training split requests a download.
train_dataset = datasets.CIFAR10(root='./data', train=True, transform=transform_train, download=True)
test_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform_test)

# Data loaders: shuffle the training data each epoch, keep the test order fixed.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)


Files already downloaded and verified


In [22]:
# Build the ResNet18 network and move its parameters onto `device`.
resnet = ResNet18().to(device)
# Bare final expression so the notebook renders the module structure.
resnet

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=

In [23]:
# Training objective (cross-entropy over the class logits) and the Adam
# optimizer that updates every parameter of `resnet`.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=resnet.parameters(), lr=learning_rate)


In [24]:
# Train the model: `num_epochs` full passes over `train_loader`.
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Move the mini-batch to the device the model lives on.
        images, labels = images.to(device), labels.to(device)

        # Clear gradients left over from the previous step, then run the
        # forward pass and compute the loss.
        optimizer.zero_grad()
        outputs = resnet(images)
        loss = criterion(outputs, labels)

        # Back-propagate and apply one optimizer update.
        loss.backward()
        optimizer.step()

        # Report progress only on every 100th mini-batch.
        if (i + 1) % 100 != 0:
            continue
        print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
               .format(epoch+1, num_epochs, i+1, total_step, loss.item()))

# Save the model checkpoint (state_dict only) once training has finished.
torch.save(resnet.state_dict(), output_name)


Epoch [1/80], Step [100/500], Loss: 1.5893
Epoch [1/80], Step [200/500], Loss: 1.6888
Epoch [1/80], Step [300/500], Loss: 1.4042
Epoch [1/80], Step [400/500], Loss: 1.1451
Epoch [1/80], Step [500/500], Loss: 1.2718
Epoch [2/80], Step [100/500], Loss: 1.0475
Epoch [2/80], Step [200/500], Loss: 1.1194
Epoch [2/80], Step [300/500], Loss: 0.9110
Epoch [2/80], Step [400/500], Loss: 0.9690
Epoch [2/80], Step [500/500], Loss: 0.9167
Epoch [3/80], Step [100/500], Loss: 0.8683
Epoch [3/80], Step [200/500], Loss: 0.4993
Epoch [3/80], Step [300/500], Loss: 0.6908
Epoch [3/80], Step [400/500], Loss: 0.6737
Epoch [3/80], Step [500/500], Loss: 0.5372
Epoch [4/80], Step [100/500], Loss: 0.5285
Epoch [4/80], Step [200/500], Loss: 0.6403
Epoch [4/80], Step [300/500], Loss: 0.5730
Epoch [4/80], Step [400/500], Loss: 0.6314
Epoch [4/80], Step [500/500], Loss: 0.6161
Epoch [5/80], Step [100/500], Loss: 0.6180
Epoch [5/80], Step [200/500], Loss: 0.4865
Epoch [5/80], Step [300/500], Loss: 0.4524
Epoch [5/80

In [25]:
# Test the model
# Reload the saved weights into a fresh network. `map_location=device` remaps
# the stored tensors so the load also works on a CPU-only machine when the
# checkpoint was written from a CUDA model.
model_dict = torch.load(output_name, map_location=device)
model = ResNet18()
model.load_state_dict(model_dict)
model.to(device)
model.eval()  # put BatchNorm layers into inference mode
with torch.no_grad():  # no gradients are needed for evaluation
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        # Index of the largest logit per sample is the predicted class.
        # argmax avoids the legacy `.data` accessor and does not rebind `_`,
        # which IPython uses for the last cell output.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Accuracy of the model on the test images: {} %'.format(100 * correct / total))

Accuracy of the model on the test images: 92.25 %
