In [None]:
%%capture
!pip install pytorch-ignite
!pip install zenml
!pip install pyparsing==2.4.2

Before proceeding:
* Restart the runtime so that the newly installed package versions are picked up.

* Then run the cells below.

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
import torch.nn.functional as F
import torch.optim as optim

# Pick the accelerator once; the model is moved here before training.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

from ignite.contrib.handlers import ProgressBar
from ignite.engine import Engine
from ignite.metrics import Accuracy
from ignite.engine import Events

from zenml.integrations.constants import PYTORCH
from zenml.pipelines import pipeline
from zenml.steps import Output, step, BaseStepConfig

In [None]:
class Net(nn.Module):
    """Two-layer fully connected classifier that outputs log-probabilities.

    The default sizes (784 inputs, 500 hidden units, 10 classes) are the ones
    this notebook uses for flattened MNIST images; they can be overridden to
    reuse the same architecture on other inputs.

    Args:
        in_features: number of input values per sample after flattening.
        hidden_features: width of the hidden layer.
        num_classes: number of output classes.
    """

    def __init__(self, in_features=784, hidden_features=500, num_classes=10):
        super(Net, self).__init__()
        self.in_features = in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, num_classes)

    def forward(self, x):
        """Return per-class log-probabilities with shape (rows, num_classes)."""
        # Reshape into rows of `in_features` values (one row per image for MNIST input).
        x = x.view(-1, self.in_features)
        x = F.relu(self.fc1(x))
        # log_softmax pairs with the F.nll_loss used during training.
        return F.log_softmax(self.fc2(x), dim=1)

In [None]:
@step
def importer_mnist() -> Output(
    train_dataloader=DataLoader,
    test_dataloader=DataLoader,
):
    """Download MNIST and wrap the train/test splits in DataLoaders.

    Images are converted to tensors and normalised with mean 0.5 / std 0.5.

    Returns:
        train_dataloader: shuffled loader over the MNIST training split.
        test_dataloader: loader over the MNIST test split in a fixed order.
    """
    data_root = "~/.pytorch/MNIST_data/"
    batch_size = 256

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, ), (0.5, )),
    ])

    train_set = datasets.MNIST(
        data_root, train=True, download=True, transform=transform)
    test_set = datasets.MNIST(
        data_root, train=False, download=True, transform=transform)

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    # Evaluation gains nothing from shuffling; a fixed order keeps runs deterministic.
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader

In [None]:
def ignite_train_step(engine, batch, net=None, opt=None):
    """Run one optimisation step on a single ``(data, targets)`` batch.

    Used as an ignite ``Engine`` process function. ignite calls it as
    ``fn(engine, batch)``, so ``net`` and ``opt`` fall back to the
    notebook-level ``model`` and ``optimizer`` globals.

    Args:
        engine: the ignite Engine driving the loop (unused here).
        batch: ``(data, targets)`` pair yielded by the DataLoader.
        net: model to train; defaults to the global ``model``.
        opt: optimizer over ``net``'s parameters; defaults to the global ``optimizer``.

    Returns:
        The batch negative log-likelihood loss as a Python float.
    """
    net = model if net is None else net
    opt = optimizer if opt is None else opt
    data, targets = batch
    # Move the batch to wherever the model lives; otherwise a CUDA model
    # receives CPU tensors and the forward pass fails.
    param = next(net.parameters(), None)
    if param is not None:
        data, targets = data.to(param.device), targets.to(param.device)
    net.train()
    opt.zero_grad()
    outputs = net(data)
    loss = F.nll_loss(outputs, targets)
    loss.backward()
    opt.step()
    # Plain float so engine.state.output holds a number, not a graph-attached tensor.
    return loss.item()

def ignite_validation_step(engine, batch, net=None):
    """Run inference on one ``(x, y)`` batch for an ignite evaluation Engine.

    ignite calls this as ``fn(engine, batch)``, so ``net`` falls back to the
    notebook-level ``model`` global when not supplied.

    Args:
        engine: the ignite Engine driving evaluation (unused here).
        batch: ``(x, y)`` pair yielded by the DataLoader.
        net: model to evaluate; defaults to the global ``model``.

    Returns:
        ``(y_pred, y)`` on the model's device -- the pair consumed by metrics
        such as ``Accuracy`` attached to the Engine.
    """
    net = model if net is None else net
    net.eval()
    param = next(net.parameters(), None)
    with torch.no_grad():
        x, y = batch
        if param is not None:
            # Keep inputs, predictions and targets on the model's device so the
            # forward pass and the metric comparison never mix CPU and CUDA tensors.
            x, y = x.to(param.device), y.to(param.device)
        y_pred = net(x)
    return y_pred, y

In [None]:
# https://docs.zenml.io/developer-guide/runtime-configuration
class zenml_trainer_config(BaseStepConfig):
    """Runtime configuration for the ``zenml_trainer`` step.

    An instance is passed when the step is instantiated for the pipeline.
    """
    # Number of passes over the training DataLoader; forwarded to ignite's `max_epochs`.
    max_train_epochs: int

@step
def zenml_trainer(config: zenml_trainer_config, train_dataloader: DataLoader) -> nn.Module:
    """Train the notebook-level ``model`` for ``config.max_train_epochs`` epochs.

    Each batch is processed by ``ignite_train_step``, and a progress bar is
    displayed while the ignite Engine runs.

    Returns:
        The global ``model`` object after training.
    """
    ignite_trainer = Engine(ignite_train_step)
    ProgressBar().attach(ignite_trainer)
    n_epochs = config.max_train_epochs
    ignite_trainer.run(train_dataloader, max_epochs=n_epochs)
    # NOTE(review): this trains and returns the notebook-global `model` in place,
    # so results depend on that global's state when the pipeline runs.
    return model

@step
def zenml_evaluator(test_dataloader: DataLoader, model: nn.Module) -> float:
    """Measure the accuracy of the ``model`` artifact on the test DataLoader.

    Args:
        test_dataloader: batches of ``(images, labels)`` from the importer step.
        model: the trained network emitted by the trainer step.

    Returns:
        The value of ignite's ``Accuracy`` metric over the whole test set.
    """
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    def _evaluate_batch(engine, batch):
        # Bound to this step's `model` argument -- ignite_validation_step reads the
        # notebook-global `model` instead, which would ignore the artifact passed in.
        model.eval()
        with torch.no_grad():
            x, y = batch
            x, y = x.to(model_device), y.to(model_device)
            y_pred = model(x)
        return y_pred, y

    evaluator = Engine(_evaluate_batch)
    Accuracy().attach(evaluator, "accuracy")
    evaluator.run(test_dataloader)
    accuracy = evaluator.state.metrics['accuracy']
    print("Accuracy: ", accuracy)
    return float(accuracy)

In [None]:
# Seed torch's global RNG so weight initialisation is reproducible on Restart & Run All.
SEED = 42
LEARNING_RATE = 0.01
torch.manual_seed(SEED)

# `model` and `optimizer` are notebook globals: ignite_train_step / ignite_validation_step
# and the zenml_trainer step read them directly.
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE)

@pipeline(required_integrations=[PYTORCH])
def zenml_pipeline(
    importer,
    trainer,
    evaluator,
):
    """Connect the importer, trainer and evaluator steps into one pipeline.

    The importer's two DataLoader outputs feed the trainer (train split) and the
    evaluator (test split); the trainer's model output also feeds the evaluator.
    """
    train_loader, test_loader = importer()
    trained_net = trainer(train_loader)
    evaluator(model=trained_net, test_dataloader=test_loader)

# Train for two epochs, then evaluate on the test split.
trainer_config = zenml_trainer_config(max_train_epochs=2)

p = zenml_pipeline(
    importer=importer_mnist(),
    trainer=zenml_trainer(trainer_config),
    evaluator=zenml_evaluator(),
)
p.run()