In [None]:
!pip install -U ray
!pip install pytorch-ignite

In [None]:
import os
import argparse
from filelock import FileLock
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from ignite.contrib.handlers import ProgressBar
from ignite.engine import Engine
from ignite.metrics import Accuracy, Loss
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator


import ray
from ray import tune
from ray.tune.schedulers import AsyncHyperBandScheduler

In [None]:
def get_data_loaders(batch_size=256, data_dir="~/.pytorch/MNIST_data/"):
    """Build the MNIST train/test DataLoaders.

    Images are converted to tensors and normalized with mean 0.5 / std 0.5.

    Dataset creation (which may download MNIST) runs under a file lock.
    Ray Tune runs several trials in parallel processes, each calling this
    function, and unguarded concurrent ``download=True`` calls can race on
    the same files.

    Args:
        batch_size: batch size for both loaders (default 256).
        data_dir: directory where MNIST is stored / downloaded to.

    Returns:
        (train_loader, test_loader): the training loader reshuffles every
        epoch; the test loader keeps a fixed order, since shuffling does not
        change evaluation accuracy and a fixed order is reproducible.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])

    # Serialize dataset creation/download across concurrent Tune trials.
    with FileLock(os.path.expanduser("~/.data.lock")):
        train_set = datasets.MNIST(
            data_dir, train=True, download=True, transform=transform)
        test_set = datasets.MNIST(
            data_dir, train=False, download=True, transform=transform)

    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader


In [None]:
class Net(nn.Module):
    """Two-layer fully connected classifier used for MNIST in this notebook.

    Each input is flattened to a 784-vector, passed through
    Linear(784 -> 500) + ReLU and then Linear(500 -> 10). The output is
    log-probabilities over the 10 classes (log_softmax along dim 1).
    """

    def __init__(self):
        super().__init__()
        # Layer creation order is kept (fc1, then fc2) so seeded initialization is unchanged.
        self.fc1 = nn.Linear(784, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        flat = x.view(-1, 784)
        hidden = F.relu(self.fc1(flat))
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)
        
# Global model instance; the helper functions below look it up by name.
model = Net()
model = model.to(device)

In [None]:
def train_step(engine, batch=None, optimizer=None):
    """Run one optimisation step of the global ``model`` on a single batch.

    Uses the Ignite process-function signature ``(engine, batch)``; ``engine``
    itself is not used. ``batch`` and ``optimizer`` default to ``None`` only
    for signature compatibility; both are required.

    Args:
        engine: unused (Ignite passes the running Engine here).
        batch: a ``(data, targets)`` pair.
        optimizer: optimizer over the global ``model``'s parameters.

    Returns:
        float: the batch NLL loss, so callers (e.g. Ignite's
        ``engine.state.output``) can log it.

    Raises:
        TypeError: if ``batch`` or ``optimizer`` is missing.
    """
    if batch is None or optimizer is None:
        raise TypeError("train_step requires both `batch` and `optimizer`")
    data, targets = batch
    data, targets = data.to(device), targets.to(device)
    model.train()
    optimizer.zero_grad()
    outputs = model(data)
    # Net.forward returns log-probabilities, so NLL is the matching loss.
    loss = F.nll_loss(outputs, targets)
    loss.backward()
    optimizer.step()
    return loss.item()

def train_one_epoch(config):
    """Ray Tune trainable: train a fresh ``Net`` with SGD, reporting test accuracy per epoch.

    Despite the name, this runs up to ``config.get("max_epochs", 10)`` epochs.
    After each full pass over the training set it evaluates on the test set
    and calls ``tune.report`` once, so one Tune ``training_iteration`` equals
    one epoch. Tune's ``stop`` criteria or scheduler usually end the trial
    before the cap.

    Args:
        config: dict with SGD hyperparameters ``"lr"`` and ``"momentum"``,
            and optionally ``"max_epochs"``.
    """
    train_loader, test_loader = get_data_loaders()
    # A fresh model per trial keeps trials independent of each other and of
    # the notebook's global ``model``.
    net = Net().to(device)
    optimizer = optim.SGD(net.parameters(),
                          lr=config["lr"],
                          momentum=config["momentum"])

    for _ in range(config.get("max_epochs", 10)):
        net.train()  # test() switches to eval mode, so switch back each epoch
        for data, targets in train_loader:
            data, targets = data.to(device), targets.to(device)
            optimizer.zero_grad()
            loss = F.nll_loss(net(data), targets)
            loss.backward()
            optimizer.step()

        # Evaluate once per epoch, not after every batch. A full test pass per
        # batch is very slow and would make each Tune iteration a single batch.
        acc = test(net, test_loader)
        tune.report(mean_accuracy=acc)

def validation_step(engine, batch):
    """Ignite process function for evaluation: predict on one batch.

    Runs the global ``model`` in eval mode without gradient tracking and
    returns ``(y_pred, y)``, the output pair consumed by the ``Accuracy``
    metric attached to the evaluator engine.
    """
    model.eval()
    with torch.no_grad():
        inputs, labels = batch
        inputs = inputs.to(device)
        labels = labels.to(device)
        predictions = model(inputs)
    return predictions, labels

def test(model, data_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(data_loader):
            # if batch_idx * len(data) > TEST_SIZE:
            #     break
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            _, predicted = torch.max(outputs.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    return correct / total

# Find the best hyperparameters using Ray Tune


In [None]:
# Shut down any previous Ray session so re-running this cell starts clean.
ray.shutdown()
ray.init()

# Asynchronous HyperBand scheduler, used for early stopping of weak trials.
sched = AsyncHyperBandScheduler()

# Request a GPU per trial only when one exists. Otherwise the trials cannot
# be scheduled on a CPU-only machine, even though `device` falls back to CPU.
gpus_per_trial = 1 if torch.cuda.is_available() else 0

analysis = tune.run(
    train_one_epoch,
    metric="mean_accuracy",
    mode="max",
    name="exp",
    scheduler=sched,
    stop={
        "mean_accuracy": 0.98,
        "training_iteration": 3,
    },
    resources_per_trial={"cpu": 1, "gpu": gpus_per_trial},
    num_samples=3,
    config={
        "lr": tune.loguniform(1e-3, 1e-2),
        "momentum": tune.uniform(0.1, 0.4),
    },
)

print("Best config is:", analysis.best_config)

# Train the final model with the best hyperparameters using PyTorch Ignite

In [None]:
# Start from freshly initialised weights so re-running this cell retrains from
# scratch instead of continuing from a previous run's weights.
# validation_step looks `model` up globally, so it picks up this instance.
model = Net().to(device)

best_optimizer = optim.SGD(model.parameters(),
                           lr=analysis.best_config["lr"],  # best lr from Ray Tune
                           momentum=analysis.best_config["momentum"])  # best momentum from Ray Tune

# Net.forward already returns log-probabilities, so NLLLoss is the matching
# criterion (as in train_step). CrossEntropyLoss would redundantly apply
# log_softmax a second time.
criterion = nn.NLLLoss()
trainer = create_supervised_trainer(model, best_optimizer, criterion, device=device)
val_metrics = {
    "accuracy": Accuracy(),
    "loss": Loss(criterion),
}

# Accuracy + loss evaluators, available for ad-hoc evaluation
# (e.g. `val_evaluator.run(test_loader)`); the loop below reports via `evaluator`.
train_evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=device)
val_evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=device)

train_loader, test_loader = get_data_loaders()
ProgressBar().attach(trainer)

evaluator = Engine(validation_step)
Accuracy().attach(evaluator, "accuracy")

validate_every = 1
log_every = 1
max_epochs = 3

@trainer.on(Events.EPOCH_COMPLETED(every=validate_every))
def run_validation():
    """Evaluate on the test set; results land in ``evaluator.state.metrics``."""
    evaluator.run(test_loader)

@trainer.on(Events.EPOCH_COMPLETED(every=log_every))
def log_validation():
    """Print the most recent validation accuracy, if any validation has run yet."""
    metrics = evaluator.state.metrics
    if "accuracy" not in metrics:
        # log_every fired before any validation run completed (log_every and
        # validate_every are not aligned); nothing to report yet.
        return
    print(f"Epoch: {trainer.state.epoch},  Accuracy: {metrics['accuracy']}")

trainer.run(train_loader, max_epochs=max_epochs)