# General Metrics

In [None]:
import os 
import glob
import PIL.Image
from typing import List, Optional, Union
import numpy as np
import torch
from metrics.metrics import Metrics, preprocess_image

# Prefer the GPU whenever torch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Example directory whose PNG frames are loaded as the reference images.
test_image_dir = "example_data/2c21b97ff3dc4fc3b1ef9e4bb0164318"

def load_images_from_dir(image_dir: str, pattern: str = "*.png"):
    """
    Load the images in a directory as one float32 tensor on ``device``.

    Files are selected with ``pattern`` (default ``"*.png"``) and loaded in
    sorted filename order. ``glob`` itself returns files in arbitrary order,
    so sorting keeps the frame order reproducible across runs and platforms.

    Args:
        image_dir: Directory containing the image files.
        pattern: Glob pattern, relative to ``image_dir``, selecting the files
            to load.

    Returns:
        Tensor of shape (num_frames, channels, height, width), moved to the
        module-level ``device``.

    Raises:
        FileNotFoundError: If no file in ``image_dir`` matches ``pattern``.
    """
    image_paths = sorted(glob.glob(os.path.join(image_dir, pattern)))
    if not image_paths:
        # Without this check an empty directory fails later with an opaque
        # permute/dimension error.
        raise FileNotFoundError(
            f"No images matching {pattern!r} found in {image_dir!r}"
        )

    frames = []
    for path in image_paths:
        # Close each file handle promptly. PIL decodes lazily, so materialize
        # the preprocessed frame as an array while the file is still open.
        with PIL.Image.open(path) as image:
            frames.append(np.asarray(preprocess_image(image)))

    torch_image_tensor = torch.tensor(np.stack(frames), dtype=torch.float32)
    # Channels-last per frame -> (num_frames, channels, height, width).
    # Assumes preprocess_image yields (height, width, channels) — TODO confirm.
    torch_image_tensor = torch_image_tensor.permute(0, 3, 1, 2)
    return torch_image_tensor.to(device)

In [None]:
# Metrics helper bound to the device selected at the top of the notebook.
metrics = Metrics(device=device)
# Reference frames from the example directory, as returned by load_images_from_dir.
target = load_images_from_dir(image_dir=test_image_dir)

In [None]:
# Score 2 example runs. Each run's input is the target plus Gaussian noise
# (std 0.001), clamped back into [0, 1]. The final expression displays
# metrics.get_total_metrics() (presumably aggregated over the runs — confirm).
for _ in range(2):
    # randn_like and clamp keep target's device, so no extra .to(device) is
    # needed. The name avoids shadowing the builtin `input` notebook-wide.
    noisy_input = torch.clamp(target + torch.randn_like(target) * 0.001, 0, 1)
    metrics.compute_image(noisy_input, target)
metrics.get_total_metrics()