In [1]:
import time

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from mycelya_torch import RemoteMachine  # use remote GPUs on demand in PyTorch

In [2]:
class SimpleCNN(nn.Module):
    """Small convolutional classifier that emits 10 logits per input image.

    ``features`` runs two conv -> ReLU -> max-pool stages (1 -> 32 -> 64
    channels, 3x3 kernels with padding 1, 2x2 pooling). ``classifier``
    flattens the feature maps and maps them through a 128-unit hidden layer
    to 10 outputs. The first linear layer expects ``64 * 7 * 7`` input
    features, which is what e.g. a 1x28x28 image produces after the two
    pooling steps.
    """

    def __init__(self) -> None:
        super().__init__()

        # The 3x3 convolutions use padding=1 and therefore keep the spatial
        # size; each 2x2 max-pool halves it.
        first_stage = [
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
        second_stage = [
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
        self.features = nn.Sequential(*first_stage, *second_stage)

        head = [
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the classifier output for the feature maps extracted from ``x``."""
        feature_maps = self.features(x)
        logits = self.classifier(feature_maps)
        return logits

def train_epoch(model: nn.Module, data: list, optimizer: torch.optim.Optimizer, criterion: nn.Module, device: torch.device) -> float:
    """Train ``model`` for one pass over ``data`` and return the average loss.

    Args:
        model: Network to optimise; it is put into training mode.
        data: Iterable of ``(images, labels)`` batches.
        optimizer: Optimizer whose gradients are cleared and which is stepped
            once per batch.
        criterion: Loss module called as ``criterion(logits, labels)``.
        device: Device every batch is moved to before the forward pass, and
            on which the running loss is accumulated.

    Returns:
        The batch-size-weighted average of the per-batch loss values, i.e.
        ``sum(loss * batch_size) / number_of_samples``.

    Raises:
        ZeroDivisionError: If ``data`` yields no samples.
    """
    model.train()

    # The accumulator stays a tensor on `device`; `.item()` is only called
    # once, after the loop.
    loss_sum = torch.tensor(0.0, device=device)
    samples_seen = 0

    for batch_images, batch_labels in data:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)
        batch_size = batch_images.size(0)

        optimizer.zero_grad(set_to_none=True)
        batch_loss = criterion(model(batch_images), batch_labels)
        batch_loss.backward()
        optimizer.step()

        # Weight each batch by its size so batches of different sizes
        # contribute proportionally to the average.
        loss_sum += batch_loss.detach() * batch_size
        samples_seen += batch_size

    return loss_sum.item() / samples_seen

def evaluate(model: nn.Module, data: list, device: torch.device) -> float:
    """Compute the classification accuracy of ``model`` over ``data``.

    Args:
        model: Network to evaluate; it is put into evaluation mode and is
            left in that mode when the function returns.
        data: Iterable of ``(images, labels)`` batches.
        device: Device every batch is moved to before the forward pass, and
            on which the count of correct predictions is accumulated.

    Returns:
        The fraction of samples whose arg-max prediction (over dimension 1 of
        the model output) equals the label, as a single ``float`` in
        ``[0, 1]``.

    Raises:
        ZeroDivisionError: If ``data`` yields no samples.
    """
    model.eval()
    # The counter stays a tensor on `device`; `.item()` is only called once,
    # after the loop.
    correct = torch.tensor(0.0, device=device)
    total_count = 0

    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for images, labels in data:
            images, labels = images.to(device), labels.to(device)
            logits = model(images)
            predictions = logits.argmax(dim=1)
            correct += predictions.eq(labels).sum()
            total_count += images.size(0)

    accuracy = correct.item() / total_count
    return accuracy

In [3]:
# Convert each image to a tensor, then normalise it with a fixed mean / std
# (presumably the MNIST training-set statistics -- confirm before reusing
# these values for another dataset).
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize((0.1307,), (0.3081,))
transform = transforms.Compose([to_tensor, normalize])

# Options shared by both dataset splits; only `train` differs.
mnist_options = dict(root="./data", download=True, transform=transform)
train_dataset = datasets.MNIST(train=True, **mnist_options)
test_dataset = datasets.MNIST(train=False, **mnist_options)

# Options shared by both loaders; only the training loader shuffles.
loader_options = dict(batch_size=128, num_workers=2)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

In [4]:
# Spin up remote T4 machine
machine = RemoteMachine("modal", "T4")

data_device = machine.device("cpu")
model_device = machine.device("cuda")


def _upload_batches(loader, device):
    """Iterate ``loader`` once and return its batches as a list of ``(images, labels)`` tuples moved to ``device``."""
    uploaded = []
    for images, labels in loader:
        uploaded.append((images.to(device), labels.to(device)))
    return uploaded


# Upload data to remote CPU
train_data = _upload_batches(train_loader, data_device)
test_data = _upload_batches(test_loader, data_device)

In [5]:
# Build the network, then upload it to the remote GPU device.
model = SimpleCNN()
model = model.to(model_device)

# Loss and optimizer; the optimizer is created after the move so it holds
# the parameters as they live on `model_device`.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

# Number of full passes over the training data.
num_epochs = 5

In [6]:
# Show both device handles, one per line, using the self-documenting
# `name=value` f-string form.
print(f"{data_device=}", f"{model_device=}", sep="\n")

data_device=device(type='mycelya', index=0)
model_device=device(type='mycelya', index=1)


In [7]:
# Look at the first sample of the first training batch.
first_images, first_labels = train_data[0]

print("Downscaled image:")
# Sample 0, channel 0, keeping every 4th row and column.
print(first_images[0, 0, ::4, ::4])
print("Label:", first_labels[0])

Downscaled image:
tensor([[-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
        [-0.4242, -0.4242, -0.4242, -0.4242,  1.5868, -0.4242, -0.4242],
        [-0.4242, -0.4242, -0.4242, -0.4242, -0.1187, -0.4242, -0.4242],
        [-0.4242, -0.4242, -0.4242,  1.7141,  1.3959, -0.4242, -0.4242],
        [-0.4242, -0.4242, -0.4242,  2.0960, -0.4242, -0.4242, -0.4242],
        [-0.4242, -0.4242, -0.4242,  1.7905, -0.0678, -0.4242, -0.4242],
        [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242]],
       device='mycelya:0')
Label: tensor(6, device='mycelya:0')


In [8]:
print("Starting training on remote T4...")
# perf_counter() is monotonic, so the measured durations cannot be skewed by
# system clock adjustments the way differences of time.time() can.
start_time = time.perf_counter()
try:
    for epoch in range(1, num_epochs + 1):
        epoch_start_time = time.perf_counter()
        train_loss = train_epoch(model, train_data, optimizer, criterion, model_device)
        test_accuracy = evaluate(model, test_data, model_device)
        # Per-epoch duration covers both the training and the evaluation pass.
        duration = time.perf_counter() - epoch_start_time
        print(
            f"Epoch {epoch}/{num_epochs} | "
            f"train_loss: {train_loss:.4f} | "
            f"test_acc: {test_accuracy * 100:.2f}% | "
            f"duration: {duration:.2f}s"
        )
    duration = time.perf_counter() - start_time
    print(f"Took {duration:.2f} seconds")
finally:
    # Stop the remote machine even when training raises or is interrupted,
    # so a failed run does not leave it running.
    machine.stop()

Starting training on remote T4...
Epoch 1/5 | train_loss: 0.4710 | test_acc: 95.34% | duration: 9.39s
Epoch 2/5 | train_loss: 0.1310 | test_acc: 97.42% | duration: 7.59s
Epoch 3/5 | train_loss: 0.0873 | test_acc: 98.00% | duration: 7.16s
Epoch 4/5 | train_loss: 0.0682 | test_acc: 98.28% | duration: 7.40s
Epoch 5/5 | train_loss: 0.0574 | test_acc: 98.46% | duration: 7.46s
Took 39.00 seconds


In [9]:
# Now compare with training on local CPU
#
# NOTE: this cell rebinds data_device, model_device, train_data, test_data,
# model, criterion, optimizer and num_epochs, so after it runs those names no
# longer refer to the remote objects created in the earlier cells.

data_device = torch.device("cpu")
model_device = torch.device("cpu")

# Re-iterate the loaders to build local copies of the batches. train_loader
# shuffles, so the batch composition differs from the remote run.
train_data = [(images.to(data_device), labels.to(data_device)) for images, labels in train_loader]
test_data = [(images.to(data_device), labels.to(data_device)) for images, labels in test_loader]

# A freshly initialised model: no seed is set, so its starting weights differ
# from the remote run and the loss / accuracy figures are not expected to
# match exactly.
model = SimpleCNN().to(model_device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
num_epochs = 5

# perf_counter() is monotonic, so the measured durations cannot be skewed by
# system clock adjustments the way differences of time.time() can.
start_time = time.perf_counter()
for epoch in range(1, num_epochs + 1):
    epoch_start_time = time.perf_counter()
    train_loss = train_epoch(model, train_data, optimizer, criterion, model_device)
    test_accuracy = evaluate(model, test_data, model_device)
    # Per-epoch duration covers both the training and the evaluation pass.
    duration = time.perf_counter() - epoch_start_time
    print(
        f"Epoch {epoch}/{num_epochs} | "
        f"train_loss: {train_loss:.4f} | "
        f"test_acc: {test_accuracy * 100:.2f}% | "
        f"duration: {duration:.2f}s"
    )
duration = time.perf_counter() - start_time
print(f"Took {duration:.2f} seconds")

Epoch 1/5 | train_loss: 0.5148 | test_acc: 94.68% | duration: 15.01s
Epoch 2/5 | train_loss: 0.1443 | test_acc: 96.96% | duration: 15.07s
Epoch 3/5 | train_loss: 0.0942 | test_acc: 97.71% | duration: 15.09s
Epoch 4/5 | train_loss: 0.0723 | test_acc: 98.21% | duration: 15.12s
Epoch 5/5 | train_loss: 0.0602 | test_acc: 98.44% | duration: 15.15s
Took 75.44 seconds
