In [8]:
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
from torch.utils.data.sampler import SubsetRandomSampler
from torch import nn
from torch.nn import functional as F
import torch.optim as optim

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

#dataset
trtransform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Normalize the test set same as training set without augmentation
tetransform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

train = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=trtransform)
test = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=tetransform)

num = len(train)
indices = list(range(num))
split = int(np.floor(0.1 * num))
train_idx, valid_idx = indices[split:], indices[:split]
np.random.seed(10)
np.random.shuffle(indices)
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

trainloader = torch.utils.data.DataLoader(train, batch_size=64, sampler=train_sampler)
validloader = torch.utils.data.DataLoader(train, batch_size=64, sampler=valid_sampler)
testloader = torch.utils.data.DataLoader(test, batch_size=64, shuffle=False)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.conv1_bn = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv2_bn = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.conv3_bn = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.dropout1 = nn.Dropout2d(p=0.05)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(128 * 4 * 4, 120)
        self.fc1_bn = nn.BatchNorm1d(120)
        self.fc2 = nn.Linear(120, 60)
        self.fc2_bn = nn.BatchNorm1d(60)
        self.fc3 = nn.Linear(60, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1_bn(self.conv1(x))))
        x = self.pool(F.relu(self.conv2_bn(self.conv2(x))))
        x = self.pool(F.relu(self.conv3_bn(self.conv3(x))))
        x = F.relu(self.conv4(x))
        x = x.view(-1, 128 *4 * 4)
        x = F.relu(self.fc1_bn(self.fc1(x)))
        x = F.relu(self.fc2_bn(self.fc2(x)))
        x = self.dropout1(x)
        x = self.fc3(x)
        return x

model = Net().to(device)

#hyperparameter
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.05)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.9)
min_val_loss = 1
count=0

#Train
for epoch in range(1,101):
  model.train()
  train_loss = 0.0
  val_loss=0.0
  correct = 0
  for image, label in trainloader:
    image, label = image.to(device), label.to(device)
    
    optimizer.zero_grad()
    output = model(image)
    loss = criterion(output, label)
    loss.backward()
    optimizer.step()
    train_loss += loss.item()
  
  model.eval()
  with torch.no_grad():
    for image, label in validloader:
      image, label = image.to(device), label.to(device)
      output = model(image)
      loss = criterion(output, label)
      val_loss += loss.item() / 5000
      _, pred = torch.max(output.data, 1)
      correct += (pred == label).sum().item()
  print('Epoch:{}/{}, train_loss:{}, validation Accuracy: {}'.format(epoch, 100, train_loss / 45000, 100 * correct / 5000))
  scheduler.step()
  if val_loss < min_val_loss:
    count=0
    min_val_loss = val_loss
  else:
    count+=1
  if count == 7:
    print("Early Stopped!")
    break

#Test acc
correct = 0
tot = 0
with torch.no_grad():
  for image, label in testloader:
    image, label = image.to(device), label.to(device)
    output = model(image)
    _, pred = torch.max(output.data, 1)
    correct += (pred == label).sum().item()

print('Test Accuracy: {}'.format(100 * correct / len(testloader.dataset)))

Files already downloaded and verified
Files already downloaded and verified
Epoch:1/100, train_loss:0.026910161736276413, validation Accuracy: 41.6
Epoch:2/100, train_loss:0.021499770487679374, validation Accuracy: 50.76
Epoch:3/100, train_loss:0.018492825645870632, validation Accuracy: 59.8
Epoch:4/100, train_loss:0.01671929239961836, validation Accuracy: 62.48
Epoch:5/100, train_loss:0.015444826420148213, validation Accuracy: 69.16
Epoch:6/100, train_loss:0.01447360364596049, validation Accuracy: 68.4
Epoch:7/100, train_loss:0.013588668402036031, validation Accuracy: 70.66
Epoch:8/100, train_loss:0.012979773398902682, validation Accuracy: 72.06
Epoch:9/100, train_loss:0.012333800266186397, validation Accuracy: 74.0
Epoch:10/100, train_loss:0.011831282095776665, validation Accuracy: 73.44
Epoch:11/100, train_loss:0.011219521139065424, validation Accuracy: 75.48
Epoch:12/100, train_loss:0.010872043123510148, validation Accuracy: 77.7
Epoch:13/100, train_loss:0.010553104105922912, valid