In [None]:
# Import necessary packages
import os
# Request deterministic GPU kernels. These variables are set before JAX is
# imported further down so they are already in the environment when it loads.
# The flag is appended to XLA_FLAGS rather than assigned, so any flags already
# exported by the user are kept instead of being silently discarded.
xla_flags = os.environ.get('XLA_FLAGS', '').split()
if '--xla_gpu_deterministic_ops' not in xla_flags:
    xla_flags.append('--xla_gpu_deterministic_ops')
os.environ['XLA_FLAGS'] = ' '.join(xla_flags)
os.environ['TF_CUDNN_DETERMINISTIC'] = '1'

import json

from flax import serialization
from jax import random

from cellori.applications.cyto import data
from cellori.applications.cyto import training

In [None]:
# Load the 'train' and 'test' splits through cellori's data helper,
# in that order, and bind them to train_ds and test_ds respectively.
train_ds, test_ds = (data.load_dataset(split) for split in ('train', 'test'))

In [None]:
# Fresh starting point: an empty metrics history, plus a new train state
# built from a fixed PRNG key (seed 0) and the value 5e-4
# (presumably the learning rate -- confirm against training.create_train_state).
metrics_log = list()
rng = random.PRNGKey(0)
state = training.create_train_state(rng, 5e-4)

In [None]:
# Load in previous train state and metrics log (if a checkpoint exists)
#
# The state and the log are only meaningful as a pair: the training loop
# derives the next epoch number from len(metrics_log). So:
#   * both files present  -> restore both;
#   * neither present     -> keep the freshly created state / empty log, so
#                            the notebook still runs top to bottom on a
#                            first run;
#   * only one present    -> raise, rather than continue with a restored
#                            state and an empty log (or vice versa).
state_path = 'cellori_state'
log_path = 'cellori_metrics_log'
has_state = os.path.isfile(state_path)
has_log = os.path.isfile(log_path)

if has_state and has_log:
    # Read and parse both files first; assign only once both succeeded so a
    # failure cannot leave `state` restored while `metrics_log` is still empty.
    with open(state_path, 'rb') as f_state:
        bytes_output = f_state.read()
    with open(log_path, 'r') as f_log:
        restored_log = json.load(f_log)
    # from_bytes uses the existing `state` as the template to restore into.
    state = serialization.from_bytes(state, bytes_output)
    metrics_log = restored_log
elif has_state or has_log:
    missing_path = log_path if has_state else state_path
    raise FileNotFoundError(
        f'Incomplete checkpoint: {missing_path!r} is missing. '
        'Restore it or remove the other checkpoint file to start from scratch.'
    )
else:
    print('No checkpoint found; starting from a fresh train state.')

In [None]:
# Training hyperparameters: number of epochs to run and the batch size
num_epochs, batch_size = 100, 8

In [None]:
# Training loop
# Epoch numbers continue from the length of the existing metrics log, so a
# run resumed from a checkpoint carries on where the previous one stopped.
first_epoch = len(metrics_log) + 1
last_epoch = len(metrics_log) + num_epochs
for epoch in range(first_epoch, last_epoch + 1):

    # Run one training epoch and record the metrics it returns
    state, metrics = training.train_epoch(state, train_ds, test_ds, batch_size, epoch)
    metrics_log.append(metrics)

    # Checkpoint every 10th epoch, and also after the final epoch so the
    # saved train state and log never lag behind the end of the run (which
    # would otherwise happen whenever the run does not end on a multiple of 10).
    if epoch % 10 == 0 or epoch == last_epoch:
        print('Saving checkpoint...')
        # Serialize both payloads before opening either file for writing, so
        # a serialization error cannot truncate an existing checkpoint.
        bytes_output = serialization.to_bytes(state)
        log_output = json.dumps(metrics_log, indent=4)
        with open('cellori_state', 'wb') as f_state:
            f_state.write(bytes_output)
        with open('cellori_metrics_log', 'w') as f_log:
            f_log.write(log_output)
        print('Saved checkpoint!')

In [None]:
# Save model
# Bundle the parameters and batch statistics held by the train state,
# serialize the bundle with flax, and write the bytes to 'cellori_model'.
variables = {
    'params': state.params,
    'batch_stats': state.batch_stats,
}
bytes_output = serialization.to_bytes(variables)
with open('cellori_model', 'wb') as f_model:
    f_model.write(bytes_output)