In [None]:
import math
import time

import matplotlib.pyplot as plt
import torch
from IPython.display import clear_output, display

# Snapshot file written by the training script; re-read on every poll.
TRAINING_DATA_PATH = "/home/andrej/work/epsclassifiers/epsclassifiers/cr_body_part_classifier/current_training_data.pt"

# Seconds to sleep between successive reloads of the snapshot.
RELOAD_INTERVAL_SEC = 0.5

# Number of columns in the plotted image grid.
NUM_VIEW_GRID_COLS = 4

# Display names for the integer class labels stored in the snapshot.
LABEL_TO_STRING_MAPPING = {0: "non-chest", 1: "chest"}

def visualize_training_data(data_path):
    """Load a training-data snapshot from *data_path* and plot it as an image grid.

    The file is expected to be a ``torch.save``-d dict holding an ``"inputs"``
    tensor and a ``"labels"`` sequence whose elements support ``.item()``.
    For every image index ``i``, channel 0 of ``inputs[i]`` is drawn in
    grayscale in a grid ``NUM_VIEW_GRID_COLS`` wide, titled with the label
    name from ``LABEL_TO_STRING_MAPPING``. (Assumes ``inputs`` is
    (N, C, H, W) -- TODO confirm against the training script.)

    If the file cannot be loaded (e.g. it is missing or still being written
    by the training script), or it contains no images, the function returns
    without touching the current output, so the previous plot stays visible
    until the next successful read.
    """
    try:
        # The tensors are only plotted here, so load them onto the CPU: this
        # avoids allocating GPU memory and failing on a machine that lacks the
        # device the snapshot was saved from.
        training_data = torch.load(data_path, map_location="cpu")
    except Exception:
        # Deliberate best-effort: the training script may be mid-write, so a
        # failed read is expected. Skip it and let the caller retry later.
        return

    inputs = training_data["inputs"]

    # Unknown label values fall back to their raw value instead of raising
    # KeyError, which would otherwise kill the caller's polling loop.
    label_names = []
    for label in training_data["labels"]:
        value = label.item()
        label_names.append(LABEL_TO_STRING_MAPPING.get(value, str(value)))

    num_images = inputs.size(0)
    if num_images == 0:
        # Nothing to draw; a zero-height figure would be useless, so keep the
        # previous plot on screen.
        return
    num_rows = math.ceil(num_images / NUM_VIEW_GRID_COLS)

    # wait=True defers the clear until the new figure is ready, avoiding flicker.
    clear_output(wait=True)
    fig = plt.figure(figsize=(NUM_VIEW_GRID_COLS * 2, num_rows * 2))

    for i in range(num_images):
        ax = fig.add_subplot(num_rows, NUM_VIEW_GRID_COLS, i + 1)
        ax.imshow(inputs[i, 0, :, :].cpu().numpy(), cmap="gray")
        ax.set_title(f"Label: {label_names[i]}", fontsize=12)
        ax.axis("off")

    fig.tight_layout()
    plt.show()
    # This runs on every poll; close explicitly so figures cannot pile up on
    # backends that do not auto-close after show().
    plt.close(fig)

# Re-read and redraw the training-data snapshot until the cell is interrupted.
try:
    while True:
        visualize_training_data(TRAINING_DATA_PATH)
        time.sleep(RELOAD_INTERVAL_SEC)
except KeyboardInterrupt:
    # Interrupting the kernel is the intended way to stop monitoring; exit
    # quietly instead of dumping a traceback under the last plot.
    pass
