In [1]:
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision import datasets
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import random_split
from ray.train import Checkpoint, get_checkpoint
from ray.tune.schedulers import ASHAScheduler
import ray.cloudpickle as pickle
from ray import tune
import tempfile
from functools import partial
from ray import train
from pathlib import Path
import os
%matplotlib inline

In [2]:
# Pick the device for notebook-level work: CUDA when PyTorch sees a GPU, else CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [3]:
def load_data(data_dir="./data"):
    """Load the MNIST train and test splits from ``data_dir``.

    Each split is converted to tensors with ``ToTensor`` and downloaded into
    ``data_dir`` if it is not already present there.

    Returns:
        ``(trainset, testset)`` as ``torchvision.datasets.MNIST`` instances.
    """
    def _mnist_split(is_train):
        return datasets.MNIST(
            root=data_dir, train=is_train, transform=ToTensor(), download=True)

    # Tuple elements are evaluated left to right, so the training split loads first.
    return _mnist_split(True), _mnist_split(False)


# Download (on first run) and load both MNIST splits for exploration in this notebook.
trainset, testset = load_data(data_dir="./data")

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:16<00:00, 600198.83it/s]


Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

In [None]:
# Show a random grid of training digits with their labels to sanity-check the data.
cols, rows = 5, 2
fig, axes = plt.subplots(rows, cols, figsize=(2 * cols, 2 * rows))

# Hide every frame up front so that unused cells (tiny datasets) stay blank.
for ax in axes.flat:
    ax.axis("off")

# Sample without replacement so the grid never shows the same image twice.
sample_indices = torch.randperm(len(trainset))[: rows * cols]
for ax, sample_idx in zip(axes.flat, sample_indices.tolist()):
    img, label = trainset[sample_idx]
    ax.imshow(img.squeeze(), cmap="gray")
    ax.set_title(str(label))

fig.suptitle("Random MNIST training samples")
plt.show()

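Spatial output size of a convolution, applied separately to the width and the height. Here $W$ is the input width (or height), $K$ the kernel size, $P$ the padding and $S$ the stride:
$$
\left\lfloor \frac{W - K + 2P}{S} \right\rfloor + 1
$$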

In [None]:
class CNN(nn.Module):
    """Three-stage convolutional classifier for single-channel 28x28 images.

    Each stage is a 3x3 convolution (stride 1, padding 1, so height and width
    are unchanged), a ReLU, and 2x2 max-pooling with stride 2 (which halves the
    size, rounding down). The spatial size therefore goes 28 -> 14 -> 7 -> 3.
    The final 30x3x3 feature map is flattened and fed through three fully
    connected layers that end in 10 outputs.

    Args:
        l1: Width of the first hidden fully connected layer.
        l2: Width of the second hidden fully connected layer.
    """

    def __init__(self, l1=100, l2=25) -> None:
        super().__init__()
        # A 3x3 kernel with padding=1 and stride=1 preserves the spatial size:
        # (W - 3 + 2*1) / 1 + 1 = W. Comments give each conv's output size.
        self.conv1 = nn.Conv2d(
            in_channels=1, out_channels=10, kernel_size=(3, 3), stride=1, padding=1)  # 28 x 28
        self.conv2 = nn.Conv2d(
            in_channels=10, out_channels=20, kernel_size=(3, 3), stride=1, padding=1)  # 14 x 14
        self.conv3 = nn.Conv2d(
            in_channels=20, out_channels=30, kernel_size=(3, 3), stride=1, padding=1)  # 7 x 7

        self.pool = nn.MaxPool2d(kernel_size=(2, 2), stride=2)
        self.relu = nn.ReLU()
        # 30 channels x 3 x 3 spatial positions remain after the third pooling step.
        self.fc1 = nn.Linear(30 * 3 * 3, l1)
        self.fc2 = nn.Linear(l1, l2)
        self.fc3 = nn.Linear(l2, 10)

    def forward(self, x):
        """Map a batch of images to unnormalised logits of shape ``(batch, 10)``.

        Expects ``x`` of shape ``(batch, 1, H, W)``. ``fc1`` requires the pooled
        map to be 3x3, which holds for 28x28 MNIST inputs.
        """
        x = self.pool(self.relu(self.conv1(x)))  # 14 x 14
        x = self.pool(self.relu(self.conv2(x)))  # 7 x 7
        x = self.pool(self.relu(self.conv3(x)))  # 3 x 3 (7 // 2)
        x = torch.flatten(x, start_dim=1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [None]:
def _train_one_epoch(model, loader, criterion, optimizer, device, epoch, log_every=2000):
    """Run one optimisation pass over ``loader``, printing the mean loss every ``log_every`` batches."""
    model.train()
    running_loss = 0.0
    window_steps = 0
    for i, (inputs, labels) in enumerate(loader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        window_steps += 1
        if (i + 1) % log_every == 0:
            # Mean over the batches since the last print; both counters reset together.
            print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1, running_loss / window_steps))
            running_loss = 0.0
            window_steps = 0


def _evaluate(model, loader, criterion, device):
    """Return ``(mean_batch_loss, accuracy)`` of ``model`` on ``loader`` without tracking gradients."""
    model.eval()
    total_loss = 0.0
    n_batches = 0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            total_loss += criterion(outputs, labels).item()
            n_batches += 1
            n_correct += (outputs.argmax(dim=1) == labels).sum().item()
            n_seen += labels.size(0)
    if n_batches == 0:
        # Empty validation set: report NaN instead of dividing by zero.
        return float("nan"), float("nan")
    return total_loss / n_batches, n_correct / n_seen


def train_model(config, data_dir=None, max_epochs=10, split_seed=42):
    """Train one CNN trial for Ray Tune, reporting validation metrics after every epoch.

    Args:
        config: Sampled hyperparameters. Reads ``"l1"`` and ``"l2"`` (hidden
            layer widths), ``"lr"`` (Adam learning rate) and ``"batch_size"``.
        data_dir: Directory passed to ``load_data``.
        max_epochs: Total number of epochs, including any already completed in
            a restored checkpoint. The default matches the previous hard-coded 10.
        split_seed: Seed for the 80/20 train/validation split. Every trial, and
            every resumption of a trial, then validates on the same held-out images.

    Side effects:
        Calls ``train.report`` once per epoch with ``{"loss", "accuracy"}`` and a
        checkpoint holding the epoch index plus model and optimizer state.
    """
    cnn = CNN(config["l1"], config["l2"])

    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        if torch.cuda.device_count() > 1:
            cnn = nn.DataParallel(cnn)
    cnn.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(cnn.parameters(), lr=config["lr"])

    start_epoch = 0
    checkpoint = get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            data_path = Path(checkpoint_dir) / "data.pkl"
            # This checkpoint is written by this function below, so unpickling it is
            # trusted; never point this loader at externally supplied files.
            with open(data_path, "rb") as fp:
                checkpoint_state = pickle.load(fp)
            # The stored epoch had already finished when it was saved, so resume
            # with the following one instead of repeating it.
            start_epoch = checkpoint_state["epoch"] + 1
            cnn.load_state_dict(checkpoint_state["net_state_dict"])
            optimizer.load_state_dict(checkpoint_state["optimizer_state_dict"])

    trainset, _ = load_data(data_dir=data_dir)

    n_train = int(len(trainset) * 0.8)
    train_subset, val_subset = random_split(
        trainset,
        [n_train, len(trainset) - n_train],
        # A fixed generator keeps the split identical across trials and restarts.
        generator=torch.Generator().manual_seed(split_seed),
    )

    batch_size = int(config["batch_size"])
    num_workers = 2
    trainloader = DataLoader(
        train_subset, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    valloader = DataLoader(
        val_subset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )

    for epoch in range(start_epoch, max_epochs):
        _train_one_epoch(cnn, trainloader, criterion, optimizer, device, epoch)
        val_loss, accuracy = _evaluate(cnn, valloader, criterion, device)

        checkpoint_data = {
            "epoch": epoch,
            "net_state_dict": cnn.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        }

        with tempfile.TemporaryDirectory() as checkpoint_dir:
            data_path = Path(checkpoint_dir) / "data.pkl"
            with open(data_path, "wb") as fp:
                pickle.dump(checkpoint_data, fp)

            checkpoint = Checkpoint.from_directory(checkpoint_dir)
            train.report(
                {"loss": val_loss, "accuracy": accuracy},
                checkpoint=checkpoint,
            )

    print("Finished Training")

In [None]:
max_num_epochs = 10

# ASHA early-stopping scheduler: trials are ranked by the reported "loss"
# (lower is better). Each trial runs for at least `grace_period` and at most
# `max_t` reports (one report per epoch in train_model). `reduction_factor`
# controls how aggressively the weaker trials are stopped.
scheduler = ASHAScheduler(
    metric="loss", mode="min",
    max_t=max_num_epochs, grace_period=1, reduction_factor=2,
)

In [None]:
# Hyperparameter search space for Tune:
#   l1, l2     -> hidden widths of fc1 / fc2, one of 16, 32, 64, 128
#   lr         -> Adam learning rate, log-uniform in [1e-4, 1e-1]
#   batch_size -> one of 16, 32, 64
config = dict(
    l1=tune.choice([2 ** power for power in range(4, 8)]),
    l2=tune.choice([2 ** power for power in range(4, 8)]),
    lr=tune.loguniform(1e-4, 1e-1),
    batch_size=tune.choice([16, 32, 64]),
)

In [None]:
gpus_per_trial = 0
num_samples = 10
# Use an absolute path: Tune may run each trial in a different working
# directory, where "./data" would point somewhere else.
data_dir = os.path.abspath("./data")
# A per-trial CPU request larger than the machine provides can never be
# scheduled, so cap the original request of 12 at the number of local cores.
cpus_per_trial = min(12, os.cpu_count() or 1)

result = tune.run(
    partial(train_model, data_dir=data_dir),
    resources_per_trial={"cpu": cpus_per_trial, "gpu": gpus_per_trial},
    config=config,
    num_samples=num_samples,
    scheduler=scheduler,
)