In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchinfo import summary
from torchmetrics import Accuracy
from torchvision import datasets
from torchvision.transforms import ToTensor

import mlflow

In [2]:
training_data = datasets.FashionMNIST(
  root="data",
  train=True,
  download=True,
  transform=ToTensor(),
)

test_data = datasets.FashionMNIST(
  root="data",
  train=False,
  download=True,
  transform=ToTensor(),
)

In [3]:
print(f"Image size: {training_data[0][0].shape}")
print(f"Size of training dataset: {len(training_data)}")
print(f"Size of test dataset: {len(test_data)}")

Image size: torch.Size([1, 28, 28])
Size of training dataset: 60000
Size of test dataset: 10000


In [4]:
train_dataloader = DataLoader(training_data, batch_size=64,num_workers=0)
test_dataloader = DataLoader(test_data, batch_size=64,num_workers=0)


In [5]:
class ImageClassifier(nn.Module):
  def __init__(self):
      super().__init__()
      self.model = nn.Sequential(
          nn.Conv2d(1, 8, kernel_size=3),
          nn.ReLU(),
          nn.Conv2d(8, 16, kernel_size=3),
          nn.ReLU(),
          nn.Flatten(),
          nn.LazyLinear(10),  # 10 classes in total.
      )

  def forward(self, x):
      return self.model(x)

In [6]:
import os

mlflow.set_tracking_uri("http://127.0.0.1:8080")
os.environ['MLFLOW_TRACKING_USERNAME'] = 'name'
os.environ['MLFLOW_TRACKING_PASSWORD'] = 'pass'
mlflow.set_experiment("mlflow-pytorch-quickstart")


<Experiment: artifact_location='mlflow-artifacts:/613627314193825433', creation_time=1756888004599, experiment_id='613627314193825433', last_update_time=1756888883686, lifecycle_stage='active', name='mlflow-pytorch-quickstart', tags={}>

In [7]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

def train(dataloader, model, loss_fn, metrics_fn, optimizer, epoch):
  """Train the model on a single pass of the dataloader.

  Args:
      dataloader: an instance of `torch.utils.data.DataLoader`, containing the training data.
      model: an instance of `torch.nn.Module`, the model to be trained.
      loss_fn: a callable, the loss function.
      metrics_fn: a callable, the metrics function.
      optimizer: an instance of `torch.optim.Optimizer`, the optimizer used for training.
      epoch: an integer, the current epoch number.
  """
  model.train()
  for batch, (X, y) in enumerate(dataloader):
      X = X.to(device)
      y = y.to(device)

      pred = model(X)
      loss = loss_fn(pred, y)
      accuracy = metrics_fn(pred, y)

      # Backpropagation.
      loss.backward()
      optimizer.step()
      optimizer.zero_grad()

      if batch % 100 == 0:
          loss_value = loss.item()
          current = batch
          step = batch // 100 * (epoch + 1)
          mlflow.log_metric("loss", f"{loss_value:2f}", step=step)
          mlflow.log_metric("accuracy", f"{accuracy:2f}", step=step)
          print(f"loss: {loss_value:2f} accuracy: {accuracy:2f} [{current} / {len(dataloader)}]")

In [8]:
def evaluate(dataloader, model, loss_fn, metrics_fn, epoch):
  """Evaluate the model on a single pass of the dataloader.

  Args:
      dataloader: an instance of `torch.utils.data.DataLoader`, containing the eval data.
      model: an instance of `torch.nn.Module`, the model to be trained.
      loss_fn: a callable, the loss function.
      metrics_fn: a callable, the metrics function.
      epoch: an integer, the current epoch number.
  """
  num_batches = len(dataloader)
  model.eval()
  eval_loss = 0
  eval_accuracy = 0
  with torch.no_grad():
      for X, y in dataloader:
          X = X.to(device)
          y = y.to(device)
          pred = model(X)
          eval_loss += loss_fn(pred, y).item()
          eval_accuracy += metrics_fn(pred, y)

  eval_loss /= num_batches
  eval_accuracy /= num_batches
  mlflow.log_metric("eval_loss", f"{eval_loss:2f}", step=epoch)
  mlflow.log_metric("eval_accuracy", f"{eval_accuracy:2f}", step=epoch)

  print(f"Eval metrics: Accuracy: {eval_accuracy:.2f}, Avg loss: {eval_loss:2f}")

In [9]:
epochs = 3
loss_fn = nn.CrossEntropyLoss()
metric_fn = Accuracy(task="multiclass", num_classes=10).to(device)
model = ImageClassifier().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

In [10]:
with mlflow.start_run() as run:
  params = {
      "epochs": epochs,
      "learning_rate": 1e-3,
      "batch_size": 64,
      "loss_function": loss_fn.__class__.__name__,
      "metric_function": metric_fn.__class__.__name__,
      "optimizer": "SGD",
  }
  # Log training parameters.
  mlflow.log_params(params)

  # Log model summary.
  with open("model_summary.txt", "w") as f:
      f.write(str(summary(model)))
  mlflow.log_artifact("model_summary.txt")

  for t in range(epochs):
      print(f"Epoch {t + 1}-------------------------------")
      train(train_dataloader, model, loss_fn, metric_fn, optimizer, epoch=t)
      evaluate(test_dataloader, model, loss_fn, metric_fn, epoch=0)

  # Save the trained model to MLflow.
  model_info = mlflow.pytorch.log_model(model, name="model")

Epoch 1-------------------------------
loss: 2.296156 accuracy: 0.109375 [0 / 938]
loss: 2.025807 accuracy: 0.500000 [100 / 938]
loss: 1.547872 accuracy: 0.703125 [200 / 938]
loss: 1.282949 accuracy: 0.578125 [300 / 938]
loss: 0.949073 accuracy: 0.687500 [400 / 938]
loss: 0.881768 accuracy: 0.703125 [500 / 938]
loss: 0.867905 accuracy: 0.718750 [600 / 938]
loss: 0.710541 accuracy: 0.765625 [700 / 938]
loss: 0.765847 accuracy: 0.718750 [800 / 938]
loss: 0.790817 accuracy: 0.750000 [900 / 938]
Eval metrics: Accuracy: 0.75, Avg loss: 0.715142
Epoch 2-------------------------------
loss: 0.664400 accuracy: 0.812500 [0 / 938]
loss: 0.751565 accuracy: 0.703125 [100 / 938]
loss: 0.486596 accuracy: 0.843750 [200 / 938]
loss: 0.768619 accuracy: 0.734375 [300 / 938]
loss: 0.692129 accuracy: 0.640625 [400 / 938]
loss: 0.671169 accuracy: 0.765625 [500 / 938]
loss: 0.702956 accuracy: 0.687500 [600 / 938]
loss: 0.630668 accuracy: 0.796875 [700 / 938]
loss: 0.675652 accuracy: 0.718750 [800 / 938]
los



🏃 View run powerful-sloth-810 at: http://127.0.0.1:8080/#/experiments/613627314193825433/runs/c17d71980b9143c2a88d652aeb79310b
🧪 View experiment at: http://127.0.0.1:8080/#/experiments/613627314193825433
