In [1]:
import torch
import mlflow.pytorch
import torchvision.transforms as transforms
from medmnist import PathMNIST
from torch.utils.data import DataLoader
import mlflow


# Load the trained model from MLFlow
def load_model_from_mlflow(model_uri):
    """Fetch a PyTorch model that was logged to MLflow.

    Args:
        model_uri: MLflow model URI, e.g. ``runs:/<run_id>/<artifact_path>``.

    Returns:
        Whatever ``mlflow.pytorch.load_model`` returns for that URI.
    """
    print("Loading model from {}".format(model_uri))
    return mlflow.pytorch.load_model(model_uri)

In [2]:
# Load the test data (PathMNIST dataset in this case)
def load_test_data(batch_size, root="../data/raw", download=True):
    """Build a DataLoader over the PathMNIST test split.

    Args:
        batch_size: Number of samples per batch.
        root: Directory holding (or receiving) the PathMNIST data file.
            Defaults to ``"../data/raw"``, relative to the working directory.
        download: Whether PathMNIST may download the data if it is missing.

    Returns:
        A ``DataLoader`` over the test split with ``shuffle=False``. Batches
        therefore come out in dataset order, and prediction ``i`` from
        inference lines up with test sample ``i``.
    """
    # NOTE(review): these normalization constants are assumed to match the
    # training pipeline — confirm against the training code. A single-value
    # mean/std is broadcast by torchvision across all image channels.
    data_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])

    test_dataset = PathMNIST(split='test', transform=data_transform, download=download, root=root)
    return DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [3]:
# Function to run inference on the test data
def run_inference(model, test_loader):
    """Predict a class index for every sample yielded by ``test_loader``.

    The model is switched to evaluation mode (this mutates ``model``) and run
    under ``torch.no_grad()``. Each input batch is moved to the device of the
    model's parameters, so a model restored onto a GPU can consume batches
    that were built on the CPU.

    Args:
        model: A ``torch.nn.Module`` whose output holds class scores along
            dimension 1 (presumably ``(batch, num_classes)`` logits — confirm
            for new models).
        test_loader: Iterable of ``(images, labels)`` pairs; labels are ignored.

    Returns:
        List of predicted class indices as plain Python ``int`` values, in the
        order the loader produced them.
    """
    model.eval()

    # Parameter-free models have no device to infer; leave inputs where they are.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    all_preds = []
    with torch.no_grad():
        for images, _ in test_loader:  # labels are not needed for inference
            if device is not None:
                images = images.to(device)
            outputs = model(images)
            # Highest-scoring class per sample; ties resolve to the first index.
            predicted = outputs.argmax(dim=1)
            # .tolist() yields Python ints rather than numpy scalars.
            all_preds.extend(predicted.cpu().tolist())
    return all_preds

In [4]:
# Log the inference results to MLFlow
def log_inference_results(predictions):
    """Log each prediction to MLflow as a metric named ``prediction_<i>``.

    The metric names and values are the same as logging each one with
    ``mlflow.log_metric``. They are sent in batches via ``mlflow.log_metrics``
    instead of one tracking-server request per prediction. Batches are capped
    at the MLflow REST ``log-batch`` limit of 1000 metrics per request.
    Logging goes through MLflow's fluent API, i.e. to the current/active run.
    An empty ``predictions`` issues no request, as before.

    Args:
        predictions: Iterable of numeric class predictions (e.g. the output of
            ``run_inference``). Each value is converted to ``float``.
    """
    max_metrics_per_request = 1000  # MLflow log-batch per-request metric limit
    items = [(f"prediction_{i}", float(pred)) for i, pred in enumerate(predictions)]
    for start in range(0, len(items), max_metrics_per_request):
        mlflow.log_metrics(dict(items[start:start + max_metrics_per_request]))

In [7]:
# Point MLflow at the local tracking server, then name the training run to load from.
tracking_uri = "http://127.0.0.1:5000/"
mlflow.set_tracking_uri(uri=tracking_uri)
run_id = "adca5b681fb54667a3b5e82114034130"

In [8]:
# Load the model logged under the `run_id` configured in the previous cell
model_uri = f"runs:/{run_id}/PathMNIST_cnn_model"
model = load_model_from_mlflow(model_uri)

# Set the batch size for testing
batch_size = 64

# Load the test data
test_loader = load_test_data(batch_size)

# Run inference and check that every test sample received exactly one prediction
predictions = run_inference(model, test_loader)
assert len(predictions) == len(test_loader.dataset), "prediction count does not match test-set size"

  from .autonotebook import tqdm as notebook_tqdm


Loading model from runs:/adca5b681fb54667a3b5e82114034130/PathMNIST_cnn_model


Downloading artifacts: 100%|██████████| 6/6 [00:00<00:00, 16.49it/s]


Using downloaded and verified file: ../data/raw/pathmnist.npz


In [9]:
# Peek at the first ten predicted class indices
predictions[:10]

[8, 4, 0, 8, 4, 0, 8, 0, 4, 8]