In this notebook, we show how to train a model with PyTorch and save it as a TileDB array on TileDB-Cloud.
First, let's import what we need and define the hyperparameters used for training the model.

In [None]:
import tiledb.cloud
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision

from tiledb.ml.models.pytorch import PyTorchTileDBModel

epochs = 1
batch_size_train = 128
batch_size_test = 1000
learning_rate = 0.01
momentum = 0.5
log_interval = 10

# Set random seeds for anything using random number generation
random_seed = 1

# Disable nondeterministic algorithms
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)

Next, we load our TileDB-Cloud credentials, which must already be exported as environment variables. TileDB-Cloud
also lets you log in with an API token instead of a username and password.
You also have to set up your AWS credentials in your TileDB-Cloud account.

In [None]:
# This is also our namespace on TileDB-Cloud.
TILEDB_USER_NAME = os.environ.get('TILEDB_USER_NAME')
TILEDB_PASSWD = os.environ.get('TILEDB_PASSWD')
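
Because os.environ.get() returns None for a variable that is not set, a missing credential would only surface later as a
failed login. A minimal sketch of an early check is shown below. The helper name `require_env` is hypothetical and is
not part of TileDB-Cloud; the demo mapping stands in for the real environment.

```python
import os


def require_env(name, environ=os.environ):
    """Return the value of an environment variable or fail with a clear message.

    `require_env` is a hypothetical helper used for illustration only.
    """
    value = environ.get(name)
    if not value:
        raise RuntimeError(f"Environment variable {name} is not set")
    return value


# Use an explicit mapping so the sketch runs without real credentials.
demo_env = {"TILEDB_USER_NAME": "my_username"}
demo_user = require_env("TILEDB_USER_NAME", environ=demo_env)
```

In the notebook itself, the same helper could wrap the two lookups above so that an unset variable stops the run before
the login call.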

We then create a TileDB-Cloud context and set up our communication with TileDB-Cloud.

In [None]:
ctx = tiledb.cloud.Ctx()
tiledb.cloud.login(username=TILEDB_USER_NAME, password=TILEDB_PASSWD)

We will also need the DataLoader API for the dataset, and we will use TorchVision, which lets us load the MNIST
dataset in a handy way. We'll use a batch size of 128 for training (batch_size_train above). The values 0.1307 and
0.3081 used for the Normalize() transformation below are the global mean and standard deviation of the MNIST dataset;
we'll take them as given here.

In [None]:
import logging
logging.getLogger("lightning").setLevel(logging.ERROR)

train_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_train, shuffle=True)
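
To make the Normalize() step concrete, here is a minimal pure-Python sketch of the arithmetic it applies to each pixel,
(x - mean) / std. The function name `normalize` is illustrative only; the real transformation operates on whole tensors.

```python
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def normalize(pixel, mean=MNIST_MEAN, std=MNIST_STD):
    """Apply the per-channel formula used by Normalize: (x - mean) / std."""
    return (pixel - mean) / std


# ToTensor() scales raw 0..255 intensities into [0, 1] before Normalize() runs.
black = normalize(0 / 255)       # darkest pixel maps to a negative value
white = normalize(255 / 255)     # brightest pixel maps to a positive value
average = normalize(MNIST_MEAN)  # a pixel at the dataset mean maps to zero
```

After this step the inputs are centred around zero with roughly unit spread over the dataset.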

Moving on, we build our network. We'll use two 2-D convolutional layers followed by two fully-connected
layers. As the activation function we'll use ReLU, and for regularization we'll use two dropout layers.

In [None]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim = 1)
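
The value 320 used by fc1 and by x.view(-1, 320) follows from the layer sizes. The sketch below traces the spatial size
of a 28x28 MNIST image through the network, assuming no padding and unit stride for the convolutions, and a pooling
stride equal to the pooling window. The helper name `conv_out` is illustrative only.

```python
def conv_out(size, kernel_size, stride=1, padding=0):
    """Spatial output size of a convolution or pooling window."""
    return (size + 2 * padding - kernel_size) // stride + 1


size = 28                                        # MNIST images are 28x28
size = conv_out(size, kernel_size=5)             # conv1: 28 -> 24
size = conv_out(size, kernel_size=2, stride=2)   # max_pool2d(2): 24 -> 12
size = conv_out(size, kernel_size=5)             # conv2: 12 -> 8
size = conv_out(size, kernel_size=2, stride=2)   # max_pool2d(2): 8 -> 4

conv2_channels = 20                              # output channels of conv2
flattened_features = conv2_channels * size * size
```

Here flattened_features evaluates to 20 * 4 * 4 = 320, which is the input size of fc1.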


We will now initialise our Neural Network and optimizer.

In [None]:
model = Net()
optimizer = optim.SGD(model.parameters(), lr=learning_rate,
                      momentum=momentum)
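
As a rough illustration of what the optimizer does with lr and momentum, the sketch below applies the momentum update
to a single scalar parameter. It assumes the plain form of the rule (no dampening, weight decay or Nesterov momentum);
the function name `sgd_momentum_step` and the toy objective f(w) = w**2 are illustrative only.

```python
def sgd_momentum_step(param, grad, velocity, lr=0.01, momentum=0.5):
    """One SGD step with momentum on a scalar parameter."""
    velocity = momentum * velocity + grad   # running accumulation of gradients
    param = param - lr * velocity           # move against the accumulated direction
    return param, velocity


# Minimise the toy objective f(w) = w**2, whose gradient is 2 * w.
w, v = 1.0, 0.0
for _ in range(3):
    w, v = sgd_momentum_step(w, 2 * w, v)
```

Each step moves w closer to the minimum at zero, and the velocity term lets earlier gradients keep contributing to
later updates.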

We continue with the training loop, iterating over all training data once per epoch. Loading the individual batches
is handled by the DataLoader. We first set the gradients to zero using optimizer.zero_grad(), since PyTorch accumulates
gradients by default. We then produce the output of the network (forward pass) and compute a negative log-likelihood
loss between the output and the ground-truth label. The backward() call then computes a new set of gradients, which
optimizer.step() uses to update each of the network's parameters.

In [None]:
train_losses = []
train_counter = []

def train(epoch):
  model.train()
  for batch_idx, (data, target) in enumerate(train_loader):
    optimizer.zero_grad()
    output = model(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
    if batch_idx % log_interval == 0:
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * len(data), len(train_loader.dataset),
        100. * batch_idx / len(train_loader), loss.item()))
      train_losses.append(loss.item())
      train_counter.append(
        (batch_idx * batch_size_train) + ((epoch-1)*len(train_loader.dataset)))

for epoch in range(1, epochs + 1):
  train(epoch)
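
The loss used in the loop pairs the log_softmax output of the network with a negative log-likelihood. A minimal
pure-Python sketch of that computation on a made-up batch is shown below; it assumes the loss is averaged over the
batch, and the logits and targets are invented for illustration.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one row of logits."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]


def nll_loss(log_probs_batch, targets):
    """Mean negative log-probability assigned to the true class."""
    picked = [-row[t] for row, t in zip(log_probs_batch, targets)]
    return sum(picked) / len(picked)


# Made-up batch: two samples, three classes.
batch_logits = [[2.0, 0.5, 0.1], [0.0, 0.0, 0.0]]
targets = [0, 2]
log_probs = [log_softmax(row) for row in batch_logits]
loss = nll_loss(log_probs, targets)
```

The loss is small when the network assigns high probability to the correct class and grows as that probability drops.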

We can move on by defining a PyTorchTileDBModel and using its save functionality to save the model directly to
our bucket on S3 (defined with the AWS credentials in your TileDB-Cloud account) and register it on TileDB-Cloud.

In [None]:
# Define array model uri.
uri = "tiledb-pytorch-model"

print('Defining PyTorchTileDBModel model...')
# In order to save our model on S3 and register it on TileDB-Cloud we have to pass our Namespace and TileDB Context.
tiledb_model = PyTorchTileDBModel(uri=uri, namespace=TILEDB_USER_NAME, ctx=ctx, model=model)

# We will need the uri that was created from our model class
# (and follows pattern tiledb://my_username/s3://my_bucket/my_array),
# in order to interact with our model on TileDB-Cloud.
tiledb_cloud_model_uri = tiledb_model.uri

print('Saving model on S3 and registering on TileDB-Cloud...')
tiledb_model.save(meta={'epochs': epochs,
                        'train_loss': train_losses})
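
The URI pattern mentioned in the comment above, tiledb://my_username/s3://my_bucket/my_array, combines the TileDB-Cloud
namespace with the underlying storage location. The sketch below only illustrates that string layout; the helper name
`split_tiledb_cloud_uri` is hypothetical and is not part of the TileDB-Cloud API.

```python
def split_tiledb_cloud_uri(uri):
    """Split 'tiledb://<namespace>/<storage uri>' into its two parts.

    Hypothetical helper for illustration only.
    """
    prefix = "tiledb://"
    if not uri.startswith(prefix):
        raise ValueError(f"Not a TileDB-Cloud URI: {uri}")
    namespace, _, storage_uri = uri[len(prefix):].partition("/")
    return namespace, storage_uri


# Placeholder URI taken from the comment in the cell above.
namespace, storage_uri = split_tiledb_cloud_uri(
    "tiledb://my_username/s3://my_bucket/my_array"
)
```

With the placeholder URI, namespace is the TileDB-Cloud user name and storage_uri is the S3 location of the array.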


Finally, we can use the TileDB-Cloud API, as described in our [cloud documentation](https://docs.tiledb.com/cloud/), to
list our models, get information about them, load them back and deregister them.

In [None]:
# List all our models. Here, we filter with file_type = 'ml_model'. All machine learning model TileDB arrays are of type
# 'ml_model'
print(
tiledb.cloud.client.list_arrays(
    file_type=['ml_model'],
    namespace=TILEDB_USER_NAME))

# Get model's info
print(tiledb.cloud.array.info(tiledb_cloud_model_uri))

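# Load our model for inference
# Placeholders for the loaded model and optimizer. The optimizer must wrap the
# parameters of the loaded model, not those of the original one.
loaded_model = Net()
loaded_optimizer = optim.SGD(loaded_model.parameters(), lr=learning_rate,
                             momentum=momentum)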

PyTorchTileDBModel(uri=tiledb_cloud_model_uri, ctx=ctx).load(model=loaded_model, optimizer=loaded_optimizer)


# Check model parameters
for key_item_1, key_item_2 in zip(
    model.state_dict().items(), loaded_model.state_dict().items()
):
    print(torch.equal(key_item_1[1], key_item_2[1]))

# Check optimizer parameters
for key_item_1, key_item_2 in zip(
    optimizer.state_dict().items(), loaded_optimizer.state_dict().items()
):
    print(all([a == b for a, b in zip(key_item_1[1], key_item_2[1])]))

# Deregister model
tiledb.cloud.deregister_array(tiledb_cloud_model_uri)