## Installation

First, we install the Dataflux Dataset for PyTorch.

In [None]:
! pip install gcs-torch-dataflux

Then we install the remaining packages that this demo expects.

In [None]:
! pip install torch
! pip install lightning

## Preparation

After installing all required packages, we must perform some additional preparation.

First, we set up the [authentication](https://github.com/GoogleCloudPlatform/dataflux-pytorch?tab=readme-ov-file#configuration) needed to run the notebook.

Set `PROJECT_ID` to your own project ID.

In [None]:
from google.colab import auth
PROJECT_ID = "YOUR_PROJECT_ID"
auth.authenticate_user(project_id=PROJECT_ID)

## Lightning Checkpointing

The Dataflux Dataset for PyTorch offers an optional checkpointing backend for PyTorch Lightning by implementing Lightning's `CheckpointIO` interface.

The supported methods are `save_checkpoint`, `load_checkpoint`, `remove_checkpoint`, and `teardown`.
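To illustrate the shape of these four methods, here is a minimal, dependency-free sketch. `InMemoryCheckpointStore` is a hypothetical stand-in that keeps checkpoints in a dictionary; it is not the Dataflux implementation and does not subclass `CheckpointIO`. The optional `storage_options` and `map_location` parameters are assumptions about typical signatures and are ignored here.

```python
import pickle
from typing import Any, Dict, Optional


class InMemoryCheckpointStore:
    """Hypothetical stand-in: keeps serialized checkpoints in a dict keyed by path."""

    def __init__(self) -> None:
        self._blobs: Dict[str, bytes] = {}

    def save_checkpoint(self, checkpoint: Dict[str, Any], path: str,
                        storage_options: Optional[Any] = None) -> None:
        # Serialize the checkpoint dict and store it under the given path.
        self._blobs[path] = pickle.dumps(checkpoint)

    def load_checkpoint(self, path: str,
                        map_location: Optional[Any] = None) -> Dict[str, Any]:
        # Deserialize the stored bytes; raises KeyError if the path is unknown.
        return pickle.loads(self._blobs[path])

    def remove_checkpoint(self, path: str) -> None:
        # Removing a path that does not exist is treated as a no-op.
        self._blobs.pop(path, None)

    def teardown(self) -> None:
        # Release held resources; here that means dropping all stored bytes.
        self._blobs.clear()


store = InMemoryCheckpointStore()
store.save_checkpoint({"epoch": 0, "global_step": 3}, "demo/manual.ckpt")
restored = store.load_checkpoint("demo/manual.ckpt")
store.remove_checkpoint("demo/manual.ckpt")
```

A real backend would replace the dictionary with reads and writes against a storage service, which is the role the Dataflux checkpoint plays for a Cloud Storage bucket.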

First, construct a `DatafluxLightningCheckpoint`.

Set `BUCKET_NAME` to the name of a bucket in your project and `CKPT_PATH` to the path you would like to test with.

In [None]:
from dataflux_pytorch.lightning import DatafluxLightningCheckpoint

BUCKET_NAME = "YOUR_BUCKET_NAME"
dataflux_ckpt = DatafluxLightningCheckpoint(project_name=PROJECT_ID, bucket_name=BUCKET_NAME)
CKPT_PATH = "gcs://YOUR_BUCKET_NAME/demo/"
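Because the bucket name appears both in `BUCKET_NAME` and inside `CKPT_PATH`, it is easy to update one and forget the other. The sketch below shows one way to check that they agree. `split_bucket_path` is a hypothetical helper written for this demo, not part of Dataflux, and it assumes the path has the form `<scheme>://<bucket>/<prefix>`.

```python
from typing import Tuple


def split_bucket_path(path: str) -> Tuple[str, str]:
    """Split '<scheme>://<bucket>/<prefix>' into (bucket, prefix)."""
    _scheme, sep, rest = path.partition("://")
    if not sep or not rest:
        raise ValueError(f"expected '<scheme>://<bucket>/<prefix>', got {path!r}")
    # Everything before the first '/' is the bucket; the remainder is the prefix.
    bucket, _, prefix = rest.partition("/")
    return bucket, prefix


bucket, prefix = split_bucket_path("gcs://YOUR_BUCKET_NAME/demo/")
print(bucket, prefix)  # YOUR_BUCKET_NAME demo/
```

In the notebook, you could compare the returned bucket against `BUCKET_NAME` before training, so a mismatch is caught before any checkpoint is written.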

Ensure your dataset, dataloader, and model have been defined. For simplicity, this example uses the demo dataset and model from `lightning.pytorch.demos`.

In [None]:
from lightning.pytorch.demos import WikiText2, LightningTransformer
from torch.utils.data import DataLoader

dataset = WikiText2()
dataloader = DataLoader(dataset, num_workers=1)

model = LightningTransformer(vocab_size=dataset.vocab_size)

Lightning supports automatic checkpointing through callbacks. A `ModelCheckpoint` callback controls the checkpointing settings. For example, `filename` sets the naming convention for your checkpoint files and `every_n_train_steps` sets the checkpointing frequency.

In [None]:
from lightning.pytorch.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    save_top_k=-1,
    every_n_train_steps=1,
    filename="checkpoint-{epoch:02d}-{step:02d}",
    enable_version_counter=True,
)
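To make the naming convention concrete, the sketch below approximates how the `filename` template above might expand. It assumes Lightning's default behavior of prefixing each placeholder with its metric name (so `{epoch:02d}` renders as `epoch=00`) and a `.ckpt` extension. `expand_checkpoint_name` is a hypothetical helper for illustration, not Lightning's own code, and the real output may differ across versions.

```python
import re
from typing import Dict


def expand_checkpoint_name(template: str, metrics: Dict[str, int],
                           extension: str = ".ckpt") -> str:
    """Approximate the expansion of a checkpoint filename template."""
    # Turn "{epoch:02d}" into "epoch={epoch:02d}" so each value is labelled.
    labelled = re.sub(r"\{(\w+)", r"\1={\1", template)
    # Fill in the placeholders with the current metric values.
    return labelled.format(**metrics) + extension


name = expand_checkpoint_name(
    "checkpoint-{epoch:02d}-{step:02d}", {"epoch": 0, "step": 3}
)
print(name)  # checkpoint-epoch=00-step=03.ckpt
```

With `every_n_train_steps=1`, a name of this shape would be produced for each training step, which is why the step number is part of the template.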

Calling `trainer.fit()` trains the model, and the checkpoint callback saves a checkpoint at every training step. Check `CKPT_PATH` to see your checkpoints.

In [None]:
from lightning import Trainer

trainer = Trainer(
    default_root_dir=CKPT_PATH,
    callbacks=[checkpoint_callback],
    plugins=[dataflux_ckpt],
    min_epochs=4,
    max_epochs=5,
    max_steps=3,
    accelerator="cpu",
)
trainer.fit(model, dataloader)

You can also save, load, and remove checkpoints manually.

For example, call `save_checkpoint` on the trainer to save a checkpoint to `CKPT_PATH`.

In [None]:
trainer.save_checkpoint(CKPT_PATH)

You can verify the result by checking your bucket for the checkpoint saved to `CKPT_PATH`.
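If you prefer to check from code rather than the Cloud console, one option is to list the objects under the checkpoint prefix. The sketch below assumes a client object that behaves like `google.cloud.storage.Client` and exposes `list_blobs(bucket_name, prefix=...)`. `list_object_names` and the stub client are hypothetical and exist only so the example runs offline.

```python
from typing import Any, List


def list_object_names(client: Any, bucket_name: str, prefix: str) -> List[str]:
    """Return the sorted names of all objects in the bucket under `prefix`."""
    blobs = client.list_blobs(bucket_name, prefix=prefix)
    return sorted(blob.name for blob in blobs)


class _StubBlob:
    """Offline stand-in for a storage blob; only carries a name."""

    def __init__(self, name: str) -> None:
        self.name = name


class _StubClient:
    """Offline stand-in for a storage client, for demonstration only."""

    def __init__(self, names: List[str]) -> None:
        self._names = names

    def list_blobs(self, bucket_name: str, prefix: str = "") -> List[_StubBlob]:
        # Mimic prefix filtering on object names.
        return [_StubBlob(n) for n in self._names if n.startswith(prefix)]


stub = _StubClient(["demo/b.ckpt", "demo/a.ckpt", "other/c.ckpt"])
names = list_object_names(stub, "YOUR_BUCKET_NAME", "demo/")
print(names)  # ['demo/a.ckpt', 'demo/b.ckpt']
```

Against a real bucket, you would pass `storage.Client(project=PROJECT_ID)` from `google.cloud` in place of the stub, together with `BUCKET_NAME` and the prefix portion of `CKPT_PATH`.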