# ML Training with Alluxio REST APIs

In [None]:
# Notebook shell commands: quietly (-q) install the extra packages this
# notebook relies on. torch_tb_profiler is what TensorBoard needs to display
# the traces written by the profiler configured in Section 3.
!pip3 install -q torch_tb_profiler
# NOTE(review): humanfriendly is not used directly in this notebook; it is
# presumably a dependency of the local `rest` module -- confirm.
!pip3 install -q humanfriendly

In [None]:
import time
import warnings

import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18

from rest import AlluxioRest
from rest import AlluxioRestDataset

In [None]:
# Ignore every warning of category UserWarning from this point on
# (presumably to keep the notebook output readable -- note that this also
# hides deprecation notices that libraries emit as UserWarning).
warnings.filterwarnings("ignore", category=UserWarning)

### Checking the device used for training

In [None]:
# Choose the compute backend in order of preference: a CUDA GPU first, then
# Apple's MPS backend, and the CPU when neither accelerator is available.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using {device}")

## Section 1: Data Preparation

Here, we configure the necessary (hyper-)parameters and create the PyTorch data loader.

The data loader loads data from Alluxio through the REST APIs.

In [None]:
# --- Data locations ---------------------------------------------------------
# Training images are read from `dataset_path`; the trained weights are
# written to `output_model_path` at the end of the notebook.
dataset_path = "s3://ref-arch/imagenet-mini/train"
output_model_path = "/mnt/alluxio/fuse/models/demo/ai-demo-rest.pth"

# --- Alluxio REST client settings -------------------------------------------
# These three values are handed to AlluxioRest unchanged.
endpoint = "10.244.0.119:28080"
dora_root = "s3://ref-arch/"
page_size = "20MB"

# --- Training hyper-parameters ----------------------------------------------
num_epochs = 1
batch_size = 64
learning_rate = 0.001
# Used both as the DataLoader worker count and as the AlluxioRest concurrency.
num_workers = 1

# --- Profiler switches ------------------------------------------------------
profiler_log_path = "../log/ai-demo-rest"
profiler_enabled = False

In [None]:
# Preprocessing applied to every image, in order: resize, center-crop to
# 224 pixels, convert to a tensor, then normalize each of the three channels
# with the given per-channel mean and standard deviation.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

In [None]:
# Client object for the Alluxio REST endpoint configured above.
# NOTE(review): `concurrency` reuses the DataLoader worker count; presumably it
# bounds the number of parallel REST requests -- confirm against rest.AlluxioRest.
# NOTE(review): `_logger=None` presumably selects the module's default logging
# behaviour -- confirm.
alluxio_rest = AlluxioRest(
    endpoint=endpoint,
    dora_root=dora_root,
    page_size=page_size,
    concurrency=num_workers,
    _logger=None,
)

# Dataset built from `dataset_path` using the REST client; `transform` is
# passed so samples can be preprocessed (presumably applied per image inside
# AlluxioRestDataset -- confirm).
train_dataset = AlluxioRestDataset(
    alluxio_rest=alluxio_rest,
    dataset_path=dataset_path,
    transform=transform,
    _logger=None,
)

# Standard PyTorch loader: batches of `batch_size` samples, sample order
# reshuffled every epoch, loaded by `num_workers` worker processes.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)

## Section 2: Set Up the Model

We fine-tune the ResNet18 model against a subset of the ImageNet dataset.

In [None]:
# Load ResNet18 with ImageNet-pretrained weights as the starting point for
# fine-tuning. The `weights` argument replaces `pretrained=True`, which
# torchvision has deprecated since 0.13; "IMAGENET1K_V1" names the same
# checkpoint that `pretrained=True` selected, so the loaded parameters are
# unchanged.
model = resnet18(weights="IMAGENET1K_V1")
# Move the parameters and buffers to the device chosen earlier.
model = model.to(device)

## Section 3: Set Up the PyTorch Profiler

We set up the PyTorch profiler with TensorBoard to visualize GPU utilization and other metrics.

In [None]:
# `profiler` stays None unless profiling is switched on in the configuration.
profiler = None
if profiler_enabled:
    # No wait and no warm-up steps: the very first profiler step is the one
    # (and only) recorded step, and the cycle is not repeated.
    trace_schedule = torch.profiler.schedule(
        wait=0,
        warmup=0,
        active=1,
        repeat=1,
    )
    # Handler that writes finished traces under `profiler_log_path` in the
    # format TensorBoard reads.
    trace_handler = torch.profiler.tensorboard_trace_handler(profiler_log_path)

    profiler = torch.profiler.profile(
        schedule=trace_schedule,
        on_trace_ready=trace_handler,
    )
    # Started here; the training cell is responsible for stepping/stopping it.
    profiler.start()

## Section 4: Model Training

In [None]:
# Loss function and optimizer for fine-tuning all model parameters.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    model.parameters(), lr=learning_rate
)

# Only touch the profiler when it was both requested and actually created.
# The short-circuit keeps `profiler` unevaluated when profiling is disabled.
profiling = profiler_enabled and profiler is not None

# Reset `loss` so that an empty data loader (or a stale value left over from a
# previous run of this cell) is detected instead of raising a NameError or
# silently reporting an old loss.
loss = None

start_time = time.perf_counter()
print(f"Started training at the timestamp {time.perf_counter()}")

try:
    for epoch in range(num_epochs):
        for inputs, labels in train_loader:
            # Move input and label tensors to the device
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Clear gradients accumulated by the previous step
            optimizer.zero_grad()

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Backward pass and parameter update
            loss.backward()
            optimizer.step()

        if loss is None:
            # Fail with an actionable message rather than a NameError.
            raise RuntimeError(
                "train_loader produced no batches; check dataset_path and "
                "the Alluxio endpoint configuration"
            )

        # The reported loss is the loss of the last batch of the epoch.
        print(
            f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item():0.4f} at the timestamp {time.perf_counter()}"
        )

        # The profiler is advanced once per epoch (not once per batch).
        if profiling:
            profiler.step()

    if loss is None:
        # Only reachable when num_epochs <= 0, i.e. nothing was trained.
        print("Finished Training, no epochs were run")
    else:
        print(f"Finished Training, Loss: {loss.item():0.4f}")

    end_time = time.perf_counter()
    print(f"Training time in {end_time - start_time:0.4f} seconds")
finally:
    # Always stop a started profiler, even when training raised, so the
    # profiling session is not left running.
    if profiling:
        profiler.stop()

if profiling:
    print("The profiler is completed. Please open the TensorBoard to browse the metrics.")

## Save the Trained Model

In [None]:
# Persist only the model's state dict (not the whole module object) to the
# configured output path, then report where it was written.
torch.save(
    obj=model.state_dict(),
    f=output_model_path,
)
print(f"Saved PyTorch AI demo model to {output_model_path}")