In [None]:
import os

import torch
from pytorch_lightning import LightningModule, Trainer
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.datasets import MNIST

In [None]:
# Report whether CUDA is usable. Device details are queried only when a GPU is
# present, because the torch.cuda device queries raise on CPU-only hosts.
print(torch.cuda.is_available())
if torch.cuda.is_available():
    print(torch.cuda.get_device_name())
    print(torch.cuda.get_device_capability())

In [None]:
# Dataset root directory; overridable through the PATH_DATASETS environment variable.
PATH_DATASETS = os.getenv("PATH_DATASETS", ".")
# Use at most one GPU: 1 when CUDA reports any device, otherwise 0.
AVAIL_GPUS = 1 if torch.cuda.device_count() > 0 else 0
# Larger batches when a GPU is available, smaller ones otherwise.
BATCH_SIZE = 256 if AVAIL_GPUS else 64

In [None]:
class MNISTModel(LightningModule):
    """Single-layer linear classifier for 28x28 inputs with 10 output classes.

    Each input is flattened to a 784-element vector and mapped by one linear
    layer to 10 class logits; training minimises cross-entropy with Adam.
    """

    def __init__(self):
        super().__init__()
        # 28 * 28 input values -> 10 output classes.
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        """Return raw (unnormalised) class logits, one row of 10 per sample.

        Args:
            x: Batch of inputs whose non-batch dimensions hold 28 * 28 values
                in total.

        Returns:
            Tensor of logits with shape ``(batch, 10)``.
        """
        # flatten(1) keeps the batch dimension and, unlike view(), also accepts
        # non-contiguous tensors.
        # No activation on the output: F.cross_entropy expects unnormalised
        # logits, and a ReLU here would clamp negative logits to zero and
        # block their gradients.
        return self.l1(x.flatten(1))

    def training_step(self, batch, batch_nb):
        """Compute the cross-entropy loss for one ``(inputs, targets)`` batch.

        Args:
            batch: Pair ``(x, y)`` of inputs and class targets.
            batch_nb: Index of the batch; unused.

        Returns:
            The loss tensor produced by ``F.cross_entropy``.
        """
        x, y = batch
        return F.cross_entropy(self(x), y)

    def configure_optimizers(self):
        """Return an Adam optimizer (learning rate 0.02) over all parameters."""
        return torch.optim.Adam(self.parameters(), lr=0.02)

In [None]:
# Model to be trained.
mnist_model = MNISTModel()

# MNIST training split, stored under PATH_DATASETS (downloaded if missing),
# with each image passed through transforms.ToTensor().
train_ds = MNIST(
    PATH_DATASETS, train=True, download=True, transform=transforms.ToTensor()
)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE)
# The Trainer is created in its own cell below.

In [None]:
# Trainer configuration: AVAIL_GPUS GPUs (0 or 1), at most three epochs, and a
# progress-bar refresh rate of 20.
trainer = Trainer(
    max_epochs=3,
    gpus=AVAIL_GPUS,
    progress_bar_refresh_rate=20,
)

In [None]:
# Run training: fit mnist_model on the batches yielded by train_loader.
trainer.fit(mnist_model, train_loader)