In [41]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import torch.nn.functional as F
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from tqdm import tqdm

In [2]:
# CIFAR-10 train/test splits, downloaded into ./data when missing and
# converted with ToTensor().
train_dataset = CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=ToTensor(),
)
test_dataset = CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=ToTensor(),
)

# Batches of 32 samples; only the training split is shuffled.
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [16:45<00:00, 169494.18it/s] 


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [39]:
# Default to the CPU and switch to the GPU only when CUDA is usable.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"


class CNN(nn.Module):
    """Small two-stage convolutional classifier for 32x32 images.

    Layout: conv(5x5) -> ReLU -> max-pool(2) -> conv(5x5) -> ReLU ->
    max-pool(2) -> flatten -> linear(300) -> ReLU -> linear(output_size).

    The first fully connected layer is sized for a 128 x 5 x 5 feature map,
    which is what the two unpadded 5x5 convolutions and 2x2 pools produce
    from a 32 x 32 input.

    Args:
        in_channels: Number of channels of the input images (3 for RGB).
        output_size: Number of output logits, i.e. number of classes.
    """

    def __init__(self, in_channels, output_size):
        super().__init__()
        # 32 x 32 x in_channels -> 28 x 28 x 64 (5x5 kernel, no padding)
        self.cnn1 = nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=5)
        # Shared 2x2 pooling layer, reused after both convolutions:
        # 28 x 28 x 64 -> 14 x 14 x 64
        self.maxpool = nn.MaxPool2d(kernel_size=2)
        # 14 x 14 x 64 -> 10 x 10 x 128, then pooled to 5 x 5 x 128
        self.cnn2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5)
        self.fc1 = nn.Linear(128 * 5 * 5, 300)
        # Honour the requested number of classes instead of a hard-coded 10.
        self.fc2 = nn.Linear(300, output_size)

    def forward(self, x):
        """Map a batch of images to unnormalised class scores (logits).

        Args:
            x: Tensor of shape (batch, in_channels, 32, 32).

        Returns:
            Tensor of shape (batch, output_size).
        """
        x = self.maxpool(F.relu(self.cnn1(x)))
        x = self.maxpool(F.relu(self.cnn2(x)))
        # flatten() also works on non-contiguous tensors, where view() raises.
        x = x.flatten(start_dim=1)

        x = F.relu(self.fc1(x))
        return self.fc2(x)

In [40]:
# Smoke test: feed one training batch through an untrained network and
# display the shape of the resulting logits.
data, _ = next(iter(train_loader))

model = CNN(in_channels=3, output_size=10)
model(data).shape


torch.Size([32, 10])

In [42]:
def _count_correct(logits, labels):
  """Return how many rows of `logits` have their arg-max equal to `labels`."""
  return (logits.argmax(dim=1) == labels).sum().item()


def _train_one_epoch(model, loader, criterion, optimizer):
  """Do one optimisation pass over `loader`.

  Returns:
    (sum of per-batch losses, number of correctly classified samples).
  """
  model.train()
  loss_sum = 0.0
  num_correct = 0

  for images, labels in tqdm(loader):
    images = images.to(device=device)
    labels = labels.to(device=device)

    # Forward pass
    output = model(images)
    loss = criterion(output, labels)

    # TODO: consider updating only every 4 batches (gradient accumulation)?
    # Backward pass and parameter update
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    loss_sum += loss.item()
    num_correct += _count_correct(output, labels)

  return loss_sum, num_correct


def _evaluate(model, loader):
  """Return the number of correctly classified samples in `loader` (no gradients)."""
  model.eval()
  num_correct = 0

  with torch.no_grad():
    for images, labels in loader:
      images = images.to(device=device)
      labels = labels.to(device=device)
      num_correct += _count_correct(model(images), labels)

  return num_correct


def run(model, input_size, normal_init=False, learning_rate=1e-3,
        batch_size=64, num_epochs=100, patience=5, seed=42,
        checkpoint_path="best_model.pth"):
  """Train `model` on CIFAR-10 with Adam, early stopping and checkpointing.

  After every epoch the model is scored on the CIFAR-10 test split. Whenever
  that accuracy improves, the weights are saved to `checkpoint_path`; training
  stops early once more than `patience` consecutive epochs bring no improvement.

  NOTE(review): early stopping and checkpoint selection use the test split,
  so the reported best accuracy is optimistically biased; a held-out
  validation split would be the cleaner choice.

  Args:
    model: The `nn.Module` to train; it is moved to the global `device` in place.
    input_size: Unused; kept so existing calls keep working.
    normal_init: If True, re-initialise every `nn.Linear` layer with Kaiming-normal
      weights and zero biases before training. Other layer types are left as-is.
    learning_rate: Adam learning rate.
    batch_size: Batch size of both data loaders.
    num_epochs: Maximum number of epochs.
    patience: Number of consecutive non-improving epochs that are tolerated.
    seed: Value passed to `torch.manual_seed` before training.
    checkpoint_path: File the best `state_dict` is written to.
  """
  torch.manual_seed(seed)
  model.to(device=device)

  def init_weight(m):
    if isinstance(m, nn.Linear):
      nn.init.kaiming_normal_(m.weight)
      # Layers created with bias=False have no bias tensor to reset.
      if m.bias is not None:
        nn.init.zeros_(m.bias)

  if normal_init:
    model.apply(init_weight)

  # Scale every channel from [0, 1] to [-1, 1].
  transform = transforms.Compose([
      transforms.ToTensor(),
      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
  ])

  # Load the dataset
  train_dataset = CIFAR10(root='./data', train=True, download=True,
                          transform=transform)
  test_dataset = CIFAR10(root='./data', train=False, download=True,
                         transform=transform)

  # Create the data loaders
  train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
  test_loader = DataLoader(test_dataset, batch_size=batch_size)

  # Define the loss and optimizer
  criterion = nn.CrossEntropyLoss()
  optimizer = optim.Adam(model.parameters(), lr=learning_rate)  # NAdam = Adam + Nesterov momentum

  # Early-stopping state
  best_accuracy = 0.0
  epoch_num_improve = 0  # consecutive epochs without a new best accuracy

  for epoch in range(num_epochs):
    train_loss, train_correct = _train_one_epoch(model, train_loader, criterion, optimizer)
    train_accuracy = train_correct / len(train_dataset)
    train_loss /= len(train_loader)  # mean of the per-batch mean losses

    test_accuracy = _evaluate(model, test_loader) / len(test_dataset)

    print(f"Epoch: {epoch + 1}/{num_epochs} | Train loss: {train_loss:.4f} | Train Acc: {train_accuracy:.4f} | Test Acc: {test_accuracy:.4f}")

    # Early stopping + checkpoint of the best weights so far
    if test_accuracy > best_accuracy:
      best_accuracy = test_accuracy
      epoch_num_improve = 0
      torch.save(model.state_dict(), checkpoint_path)
    else:
      epoch_num_improve += 1
      if epoch_num_improve > patience:
        print(f"Early stopping, best accuracy {best_accuracy}")
        break

In [None]:
# Run settings. CIFAR-10 images are 32 x 32 RGB, i.e. 3 channels per pixel.
# `hidden_size` and `num_layers` are not referenced by the CNN built below.
input_size = 3 * 32 * 32
hidden_size = 100
num_layers = 20
output_size = 10

# Build a fresh network and train it with the default settings of run().
model = CNN(in_channels=3, output_size=output_size)
run(model=model, input_size=input_size)