In [8]:
import os
import sys
import glob
import csv
import time
import platform
import socket
import random
from datetime import datetime
from typing import Tuple, List, Callable

In [9]:
import numpy as np
import numpy.typing as npt
import matplotlib.pyplot as plt
from PIL import Image
import psutil  # For system resource monitoring
# import GPUtil  # For GPU information

In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

In [11]:
from displacements import VectorFieldComposer, VECTOR_FIELDS

# Root directory that is searched recursively for tile images.
TILES_DIR = "../tiles"
# glob returns matches in arbitrary (filesystem-dependent) order; sorting makes
# the list stable so that a seeded `random` selects the same tiles on every run.
TILE_IMAGE_PATHS = sorted(glob.glob(os.path.join(TILES_DIR, "**/*.png"), recursive=True))

# Side length, in pixels, of the square buffers allocated for training samples.
TILE_SIZE = 256

In [None]:
def random_image() -> np.ndarray:
    """
    Load one randomly chosen tile image and return it as a NumPy array.

    The path is drawn uniformly from the module-level TILE_IMAGE_PATHS list
    using the global `random` generator.

    Returns:
        The decoded image as produced by `np.array` on the opened PIL image.
        No mode conversion or resizing is done here, so shape and dtype
        follow the file on disk.
        NOTE(review): generate_training_data stores this result in a
        (TILE_SIZE, TILE_SIZE, 3) slot, so tiles are presumably
        TILE_SIZE x TILE_SIZE RGB images -- confirm.

    Raises:
        ValueError: if TILE_IMAGE_PATHS is empty.
    """
    if not TILE_IMAGE_PATHS:
        # Same exception type as the old randint(0, -1) failure, but with a
        # message that points at the actual cause.
        raise ValueError("TILE_IMAGE_PATHS is empty; no tile images were found.")

    path = random.choice(TILE_IMAGE_PATHS)
    # Image.open is lazy and keeps the file handle open; the context manager
    # releases it as soon as the pixels have been copied into the array.
    with Image.open(path, mode="r") as image:
        return np.array(image)

In [12]:
def generate_training_data(num_samples=100) -> Tuple[np.ndarray, np.ndarray]:
    """
    Build a batch of (image pair, motion field) training samples.

    One VectorFieldComposer is configured per call with 1-3 randomly chosen
    entries of VECTOR_FIELDS (each added with randomize=True). The same
    composer is used for every sample, so within one returned batch only the
    base image varies.

    Args:
        num_samples: number of samples to generate.

    Returns:
        A tuple (image_pairs, motion_vectors):
            image_pairs: float32 array of shape
                (num_samples, 2, TILE_SIZE, TILE_SIZE, 3). Along axis 1,
                index 0 is the base image and index 1 is the result of
                composer.apply_to_image on it.
            motion_vectors: float32 array of shape
                (num_samples, 2, TILE_SIZE, TILE_SIZE). Along axis 1,
                index 0 is dx and index 1 is dy as returned by
                composer.compute_combined_field on a [-1, 1] x [-1, 1] grid.
    """
    composer = VectorFieldComposer()

    available_fields = list(VECTOR_FIELDS)
    num_fields = random.randint(1, 3)
    for _ in range(num_fields):
        field_type = random.choice(available_fields)
        composer.add_field(field_type, randomize=True)

    image_pairs = np.zeros((num_samples, 2, TILE_SIZE, TILE_SIZE, 3), dtype=np.float32)
    # dx and dy are per-pixel scalar components, so there is no colour-channel
    # axis here. The previous trailing "3" was copied from the image buffer and
    # cannot be filled from a (TILE_SIZE, TILE_SIZE) array.
    # NOTE(review): assumes compute_combined_field returns arrays shaped like
    # its grid inputs -- confirm against displacements.VectorFieldComposer.
    motion_vectors = np.zeros((num_samples, 2, TILE_SIZE, TILE_SIZE), dtype=np.float32)

    # The sampling grid does not depend on the sample, so build it once.
    axis = np.linspace(-1, 1, TILE_SIZE)
    grid_X, grid_Y = np.meshgrid(axis, axis)

    for i in range(num_samples):
        base_image = random_image()

        # Kept inside the loop so the composer sees the same call sequence as
        # before; whether it is stateful is not visible from this file.
        dx, dy = composer.compute_combined_field(grid_X, grid_Y)

        transformed_image = composer.apply_to_image(base_image)

        image_pairs[i, 0] = base_image
        image_pairs[i, 1] = transformed_image
        motion_vectors[i, 0] = dx
        motion_vectors[i, 1] = dy

    return image_pairs, motion_vectors