This example notebook shows how to train an [image/digit classification](https://pytorch.org/tutorials/beginner/nn_tutorial.html?highlight=mnist)
model on the MNIST dataset and store it as a TileDB array. First, let's import what we need.

In [1]:
import glob
import json
import pickle
import os
import shutil
from pprint import pprint

import tiledb
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision

from tiledb.ml.models.pytorch import PyTorchTileDBModel

Next, let's define the parameters and hyperparameters we will need.

In [2]:
epochs = 1
batch_size_train = 128
learning_rate = 0.01
momentum = 0.5
log_interval = 10

# Set random seeds for anything using random number generation
random_seed = 1

# Disable nondeterministic algorithms
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)

<torch._C.Generator at 0x15dae6650>

We will also need the DataLoader API for the dataset, together with TorchVision, which lets us load the MNIST
dataset in a handy way. We'll use a batch size of 128 for training (the batch_size_train value defined above). The values 0.1307 and 0.3081 used for
the Normalize() transformation below are the global mean and standard deviation of the MNIST dataset;
we'll take them as given here.
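
To make the effect of Normalize() concrete, here is a minimal sketch in plain Python of the per-pixel formula it applies, (x - mean) / std. The helper name `normalize` is hypothetical and only illustrates the arithmetic; it is not part of TorchVision.

```python
# Mean and standard deviation taken from the Normalize() call in this notebook.
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def normalize(pixel: float, mean: float = MNIST_MEAN, std: float = MNIST_STD) -> float:
    """Apply the per-channel formula used by Normalize: (x - mean) / std."""
    return (pixel - mean) / std


# ToTensor() scales pixel values to [0, 1], so 0.0 is black and 1.0 is white.
black = normalize(0.0)
white = normalize(1.0)
print(f"black -> {black:.4f}, white -> {white:.4f}")
```

After normalization a black pixel maps to roughly -0.42 and a white pixel to roughly 2.82, so the inputs are centred around zero with unit variance over the dataset.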

In [3]:
data_home = os.path.join(os.path.pardir, 'data')
dataset = torchvision.datasets.MNIST(
    root=data_home, 
    train=True, 
    download=True,
    transform=torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,))
    ])
)
train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size_train, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw



Moving on, we build our network. We'll use two 2-D convolutional layers followed by two fully-connected
layers, with ReLU as the activation function and two dropout layers as a means of regularization.

In [4]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim = 1)
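
The value 320 used in fc1 and in x.view(-1, 320) follows from the layer shapes. The sketch below reproduces that arithmetic in plain Python, assuming 28x28 MNIST inputs, no padding, stride 1 for the convolutions and stride 2 for the 2x2 max pooling. The helper `conv_out` is hypothetical and not a PyTorch function.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output side length of a convolution or pooling window over a square input."""
    return (size + 2 * padding - kernel) // stride + 1


side = 28                                   # MNIST images are 28x28
side = conv_out(side, kernel=5)             # conv1 -> 24
side = conv_out(side, kernel=2, stride=2)   # max_pool2d(2) -> 12
side = conv_out(side, kernel=5)             # conv2 -> 8
side = conv_out(side, kernel=2, stride=2)   # max_pool2d(2) -> 4

channels = 20                               # output channels of conv2
flat_features = channels * side * side
print(flat_features)  # 320
```

If the kernel sizes or the input resolution change, the same calculation gives the new input size for fc1.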


We will now initialise our Neural Network and optimizer.

In [5]:
network = Net()
optimizer = optim.SGD(network.parameters(), lr=learning_rate, momentum=momentum)

We continue with the training loop, which iterates over all training data once per epoch. Loading the individual batches
is handled by the DataLoader. We first set the gradients to zero using optimizer.zero_grad(), since PyTorch accumulates
gradients by default. We then produce the output of the network (forward pass) and compute a negative log-likelihood
loss between the output and the ground truth label. The backward() call computes a new set of gradients, which
optimizer.step() then uses to update each of the network's parameters.
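
As an illustration of the loss itself, the following sketch computes a log-softmax and the negative log-likelihood for a single sample in plain Python. It is a simplified stand-in for F.log_softmax and F.nll_loss (which by default average over the batch), not their actual implementation. The helper names and the example logits are made up for the demonstration.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of raw scores."""
    peak = max(logits)
    log_norm = peak + math.log(sum(math.exp(v - peak) for v in logits))
    return [v - log_norm for v in logits]


def nll_loss(log_probs, target):
    """Negative log-likelihood for one sample: minus the log-probability of the true class."""
    return -log_probs[target]


# An untrained classifier is close to uniform over the 10 digits.
uniform_loss = nll_loss(log_softmax([0.0] * 10), target=3)

# A confident, correct prediction drives the loss towards zero.
confident_logits = [0.0] * 10
confident_logits[3] = 10.0
confident_loss = nll_loss(log_softmax(confident_logits), target=3)

print(f"uniform: {uniform_loss:.4f}, confident: {confident_loss:.4f}")
```

A uniform prediction over 10 classes gives a loss of ln(10), about 2.30, which is consistent with the first training losses stored in the model metadata later in this notebook.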

In [6]:
train_losses = []
train_counter = []

from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()


def train(epoch):
  network.train()
  for batch_idx, (data, target) in enumerate(train_loader):
    optimizer.zero_grad()
    output = network(data)
    loss = F.nll_loss(output, target)
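    # Use a global step so that every batch gets its own point on the TensorBoard curve
    writer.add_scalar("Loss/train", loss.item(), (epoch - 1) * len(train_loader) + batch_idx)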
    loss.backward()
    optimizer.step()
    if batch_idx % log_interval == 0:
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * len(data), len(train_loader.dataset),
        100. * batch_idx / len(train_loader), loss.item()))
      train_losses.append(loss.item())
      # Number of training samples seen so far
      train_counter.append(
        (batch_idx * batch_size_train) + ((epoch - 1) * len(train_loader.dataset)))

for epoch in range(1, epochs + 1):
  train(epoch)
writer.flush()
writer.close()
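
As a sanity check on the logging arithmetic in the loop above, this sketch works out how many batches one epoch contains and how many loss values end up in train_losses. It assumes the standard MNIST training split of 60,000 images and the hyperparameters defined earlier. No PyTorch is needed to run it.

```python
import math

dataset_size = 60_000      # size of the MNIST training split
batch_size_train = 128     # value set in the hyperparameter cell
log_interval = 10          # value set in the hyperparameter cell

# DataLoader keeps the final, smaller batch by default (drop_last=False).
batches_per_epoch = math.ceil(dataset_size / batch_size_train)
last_batch_size = dataset_size - (batches_per_epoch - 1) * batch_size_train

# The loop logs when batch_idx % log_interval == 0, i.e. batch_idx = 0, 10, 20, ...
logged_batches = list(range(0, batches_per_epoch, log_interval))

print(batches_per_epoch, last_batch_size, len(logged_batches))
```

With these settings an epoch has 469 batches, the last one holding 96 samples, and 47 loss values are recorded per epoch.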



We can now save the trained model as a TileDB array. In case we want to train the model further at a later time, we can also save
the optimizer in our TileDB array. If we will use our model only for inference, we don't have to save the optimizer and can
keep only the model. We first declare a PyTorchTileDBModel object and initialize it with the corresponding TileDB uri, model and optimizer,
and then save the model as a TileDB array. Finally, we can save any kind of metadata (in any structure, e.g., list, tuple or dictionary)
by passing a dictionary to the meta argument of save().

In [7]:
uri = os.path.join(data_home, 'pytorch-mnist-1')
tiledb_model_1 = PyTorchTileDBModel(uri=uri, model=network, optimizer=optimizer)

tiledb_model_1.save(update=False,
                    meta={'epochs': epochs,
                          'train_loss': train_losses},
                    summary_writer=writer)

The above step will create a TileDB array at the given uri (here, inside the data directory). For information about the structure of a dense
TileDB array in terms of files on disk, please take a look [here](https://docs.tiledb.com/main/concepts/data-format).
Let's open our TileDB array model and check its metadata. Metadata of type list, dict or tuple are JSON
serialized while saving, so we need json.loads to deserialize them.
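
The JSON round trip can be sketched without TileDB. In the example below a plain dict, `fake_meta`, stands in for the array's metadata store. It is a hypothetical stand-in, and the keys and values are made up for illustration.

```python
import json

# A plain dict stands in for the array's metadata store (hypothetical, no TileDB calls).
fake_meta = {}

# Structured values are encoded to JSON bytes before being stored...
fake_meta["train_config"] = json.dumps({"epochs": 1, "layers": ["conv1", "conv2"]}).encode("utf-8")
# ...while simple scalars can be stored directly.
fake_meta["epochs"] = 1

decoded = {}
for key, value in fake_meta.items():
    # Same check as in the notebook: only bytes values go through json.loads.
    if isinstance(value, bytes):
        value = json.loads(value)
    decoded[key] = value

print(decoded)
```

Note that JSON has no tuple type, so a tuple serialized this way comes back as a list after json.loads.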

In [8]:
# Check array directory
pprint(glob.glob(f'{uri}/*'))

# Open in read mode (the default) in order to inspect metadata
model_array_1 = tiledb.open(uri)
for key, value in model_array_1.meta.items():
    if isinstance(value, bytes):
        value = json.loads(value)
    print("Key: {}, Value: {}".format(key, value))

['../data/pytorch-mnist-1/__meta',
 '../data/pytorch-mnist-1/__fragment_meta',
 '../data/pytorch-mnist-1/__commits',
 '../data/pytorch-mnist-1/__schema',
 '../data/pytorch-mnist-1/__fragments']
Key: TILEDB_ML_MODEL_ML_FRAMEWORK, Value: PYTORCH
Key: TILEDB_ML_MODEL_ML_FRAMEWORK_VERSION, Value: 1.10.2
Key: TILEDB_ML_MODEL_PREVIEW, Value: Net(
  (conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2_drop): Dropout2d(p=0.5, inplace=False)
  (fc1): Linear(in_features=320, out_features=50, bias=True)
  (fc2): Linear(in_features=50, out_features=10, bias=True)
)
Key: TILEDB_ML_MODEL_PYTHON_VERSION, Value: 3.9.13
Key: TILEDB_ML_MODEL_STAGE, Value: STAGING
Key: epochs, Value: 1
Key: train_loss, Value: (2.358812093734741, 2.285137891769409, 2.3066349029541016, 2.2708795070648193, 2.2367401123046875, 2.24334716796875, 2.1832549571990967, 2.1485116481781006, 2.1049115657806396, 2.0044069290161133, 1.8622523546218872, 1.884370

As we can see, the array's metadata contains by default information about the backend we used for training (PyTorch),
the PyTorch version and the Python version, as well as the extra metadata about epochs and training loss that we added.
We can load and check any of these without having to load the entire model in memory.
Moreover, we can add any kind of extra information to the model's metadata by opening the TileDB array in write mode and adding new keys.

In [9]:
# Open the array in write mode
with tiledb.Array(uri, "w") as A:
    # Add a new metadata key; structured values are JSON serialized first
    A.meta['new_meta'] = json.dumps(['Any kind of info'])

# Check that everything is there
model_array_1 = tiledb.open(uri)
for key, value in model_array_1.meta.items():
    if isinstance(value, bytes):
        value = json.loads(value)
    print("Key: {}, Value: {}".format(key, value))

Key: TILEDB_ML_MODEL_ML_FRAMEWORK, Value: PYTORCH
Key: TILEDB_ML_MODEL_ML_FRAMEWORK_VERSION, Value: 1.10.2
Key: TILEDB_ML_MODEL_PREVIEW, Value: Net(
  (conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2_drop): Dropout2d(p=0.5, inplace=False)
  (fc1): Linear(in_features=320, out_features=50, bias=True)
  (fc2): Linear(in_features=50, out_features=10, bias=True)
)
Key: TILEDB_ML_MODEL_PYTHON_VERSION, Value: 3.9.13
Key: TILEDB_ML_MODEL_STAGE, Value: STAGING
Key: epochs, Value: 1
Key: new_meta, Value: ["Any kind of info"]
Key: train_loss, Value: (2.358812093734741, 2.285137891769409, 2.3066349029541016, 2.2708795070648193, 2.2367401123046875, 2.24334716796875, 2.1832549571990967, 2.1485116481781006, 2.1049115657806396, 2.0044069290161133, 1.8622523546218872, 1.8843708038330078, 1.7973158359527588, 1.6879109144210815, 1.508046269416809, 1.764279842376709, 1.4700727462768555, 1.3514467477798462, 1.2905819416046143, 1

For PyTorch models, we internally save the model's state_dict and the optimizer's state_dict
as pickled [variable-sized attributes](https://docs.tiledb.com/main/how-to/arrays/writing-arrays/var-length-attributes),
so we can open the TileDB array and get only the state_dict of the model or optimizer
without bringing the whole model in memory. For example, we can load the model's and optimizer's state_dict
for model `pytorch-mnist-1` as follows.
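
The pickling step on its own looks roughly like the sketch below. It uses a toy OrderedDict with plain lists instead of tensors so that it runs without PyTorch or TileDB. The name `toy_state_dict` and its contents are invented for the example.

```python
import pickle
from collections import OrderedDict

# Stand-in for a state_dict: parameter names mapped to values.
# Real state_dicts hold tensors; plain lists keep this sketch dependency-free.
toy_state_dict = OrderedDict(
    [("fc2.weight", [[0.1, -0.2], [0.3, 0.4]]), ("fc2.bias", [0.0, 0.5])]
)

# Serialize to bytes, which is what a variable-length attribute cell can hold.
blob = pickle.dumps(toy_state_dict, protocol=4)

# Deserialize; only unpickle data that comes from a trusted source.
restored = pickle.loads(blob)

print(type(blob).__name__, list(restored.keys()))
```

Because unpickling can execute arbitrary code, model arrays should only be loaded from locations you trust.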

In [10]:
# Open the array and read all of its attributes
model_array_1 = tiledb.open(uri)[:]

# Load model state_dict
model_1_state_dict = pickle.loads(model_array_1['model_state_dict'].item(0))

# Load optimizer state_dict
optimizer_1_state_dict = pickle.loads(model_array_1['optimizer_state_dict'].item(0))

print(f'Type: {type(model_1_state_dict)} , Keys: {model_1_state_dict.keys()}')
print(f'Type: {type(optimizer_1_state_dict)}, Keys: {optimizer_1_state_dict.keys()}')

Type: <class 'collections.OrderedDict'> , Keys: odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias'])
Type: <class 'dict'>, Keys: dict_keys(['state', 'param_groups'])


Moving on, we can load the trained models for prediction, evaluation or retraining, as usual with
PyTorch models.

In [11]:
# Placeholder for the loaded model
network = Net()
optimizer = optim.SGD(network.parameters(), lr=learning_rate, momentum=momentum)

# Load returns possible extra attributes, other than model's and optimizer's state dicts. In case there were
# no extra attributes it will return an empty dict
_ = tiledb_model_1.load(model=network, optimizer=optimizer)

What is really nice about saving models as TileDB arrays is native versioning based on fragments, as described [here](https://docs.tiledb.com/main/concepts/data-format#immutable-fragments). We can load a model, retrain it with new data and update the existing TileDB model array with the new model parameters and metadata. All information, old and new, will be there and accessible. This is extremely useful when you retrain with new data or try different architectures for the same problem and want to keep track of all your experiments without having to store separate model instances. In our case, let's continue training model_1 for 2 more epochs on the same training data. After training is done, you will notice the extra directories and files (fragments) added to the `pytorch-mnist-1` TileDB array directory, which keep all versions of the model.

In [12]:
train_losses = []
train_counter = []

# We train for some extra 2 epochs
for epoch in range(1, 2 + 1):
  train(epoch)

# and update
tiledb_model_1 = PyTorchTileDBModel(uri=uri, model=network, optimizer=optimizer)
tiledb_model_1.save(update=True, 
                    meta={'epochs': epochs,
                          'train_loss': train_losses})

# Check array directory
print()
pprint(glob.glob(f'{uri}/*'))

# tiledb.array_fragments() requires TileDB-Py version > 0.8.5
fragments_info = tiledb.array_fragments(uri)

print()
print("====== FRAGMENTS  INFO ======")
print("array uri: {}".format(fragments_info.array_uri))
print("number of fragments: {}".format(len(fragments_info)))

for fragment_num, fragment in enumerate(fragments_info, start=1):
    print()
    print("===== FRAGMENT NUMBER {} =====".format(fragment.num))
    print("fragment uri: {}".format(fragment.uri))
    print("timestamp range: {}".format(fragment.timestamp_range))
    print(
        "number of unconsolidated metadata: {}".format(
            fragment.unconsolidated_metadata_num
        )
    )
    print("version: {}".format(fragment.version))


['../data/pytorch-mnist-1/__meta',
 '../data/pytorch-mnist-1/__fragment_meta',
 '../data/pytorch-mnist-1/__commits',
 '../data/pytorch-mnist-1/__schema',
 '../data/pytorch-mnist-1/__fragments']

array uri: ../data/pytorch-mnist-1
number of fragments: 2

===== FRAGMENT NUMBER 0 =====
fragment uri: file:///Users/konstantinostsitsimpikos/tileroot/TileDB-ML/examples/data/pytorch-mnist-1/__fragments/__1660811273615_1660811273615_23699d36dbc744809486d88176c2920f_13
timestamp range: (1660811273615, 1660811273615)
number of unconsolidated metadata: 2
version: 13

===== FRAGMENT NUMBER 1 =====
fragment uri: file:///Users/konstantinostsitsimpikos/tileroot/TileDB-ML/examples/data/pytorch-mnist-1/__fragments/__1660811314379_1660811314379_0309938da153404e88a7a64ff044fc20_13
timestamp range: (1660811314379, 1660811314379)
number of unconsolidated metadata: 2
version: 13
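
The fragment URIs printed above appear to encode the timestamp range and the format version in the directory name. The sketch below parses the two names from that listing. The layout `__<start>_<end>_<uuid>_<version>` is inferred from this output rather than taken from a specification, and `parse_fragment_name` is a hypothetical helper, not a TileDB function.

```python
from datetime import datetime, timezone
from typing import NamedTuple


class FragmentName(NamedTuple):
    start_ms: int
    end_ms: int
    uuid: str
    version: int


def parse_fragment_name(name: str) -> FragmentName:
    """Split a name of the form __<start>_<end>_<uuid>_<version> (layout inferred from the listing)."""
    start, end, uuid, version = name.lstrip("_").split("_")
    return FragmentName(int(start), int(end), uuid, int(version))


first = parse_fragment_name("__1660811273615_1660811273615_23699d36dbc744809486d88176c2920f_13")
second = parse_fragment_name("__1660811314379_1660811314379_0309938da153404e88a7a64ff044fc20_13")

# Timestamps are milliseconds since the Unix epoch.
saved_at = datetime.fromtimestamp(first.start_ms / 1000, tz=timezone.utc)
gap_seconds = (second.start_ms - first.start_ms) / 1000

print(saved_at.isoformat(), gap_seconds, first.version)
```

Here the two saved versions are about 41 seconds apart, which matches the timestamp ranges reported for the two fragments.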


Finally, a very interesting TileDB feature that is useful for machine learning models, described
[here](https://docs.tiledb.com/main/concepts/data-format#groups) and [here](https://docs.tiledb.com/main/how-to/object-management#creating-tiledb-groups),
is groups. Suppose we want to solve the MNIST problem and try several architectures. We can save each architecture
as a separate TileDB array, with native versioning each time it is retrained, and then organise all models that solve the same problem (MNIST)
as a TileDB group with any kind of hierarchy. Let's first define a new model architecture.

In [13]:
class OtherNet(nn.Module):
    # For the sake of simplicity we just tweak the initial architecture by replacing a relu with relu6.
    def __init__(self):
        super(OtherNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu6(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim = 1)

Then train it and save it as a new TileDB array.

In [14]:
network = OtherNet()
optimizer = optim.Adam(network.parameters(), lr=learning_rate)

train_losses = []
train_counter = []

for epoch in range(1, epochs + 1):
    train(epoch)

uri2 = os.path.join(data_home, 'pytorch-mnist-2')
tiledb_model_2 = PyTorchTileDBModel(uri=uri2, model=network, optimizer=optimizer)

tiledb_model_2.save(update=False, 
                    meta={'epochs': epochs,
                          'train_loss': train_losses})



Now we can create a TileDB group and organise (in hierarchies, e.g., sophisticated vs less sophisticated) all our
MNIST models as follows.

In [15]:
group = os.path.join(data_home, 'tiledb-pytorch-mnist')
tiledb.group_create(group)
shutil.move(uri, group)
shutil.move(uri2, group)

'../data/tiledb-pytorch-mnist/pytorch-mnist-2'

At any time, we can check and query all the available models, including their metadata, for a specific problem like MNIST.

In [None]:
tiledb.ls(group, lambda obj_path, obj_type: print(obj_path, obj_type))