### Imports

In [1]:
import os
import sys
import glob
import csv
import time
import random
import re
from datetime import datetime
from typing import Tuple, List, Callable

In [2]:
import numpy as np
import numpy.typing as npt
import matplotlib.pyplot as plt
from PIL import Image

In [3]:
import torch
from torch import nn
import torch.optim as optim
from torchvision.io import decode_image
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor

### Constants

In [4]:
from displacements import VectorFieldComposer, VECTOR_FIELDS

# Root folder holding one sub-folder per image sequence.
TILES_DIR = "../tiles"
# TILE_IMAGE_PATHS = sorted(glob.glob(os.path.join(TILES_DIR, "**/*.png"), recursive=True))
# glob returns matches in arbitrary (filesystem-dependent) order. Sorting keeps the
# index -> tile mapping used by the dataset stable across runs and machines.
TILE_IMAGE_PATHS = sorted(glob.glob(os.path.join(TILES_DIR, "g*/**/*.png"), recursive=True)) # Just the graphite images
# Upper bound on how many tiles the dataset draws from.
MAX_TILES = 50000
NUM_TILES = min(MAX_TILES, len(TILE_IMAGE_PATHS))

# Side length, in pixels, used when building the displacement sampling grid.
TILE_SIZE = 256

In [5]:
def _collect_sequence_tiles(tiles_dir):
    """Index every tile image under ``tiles_dir`` by sequence and tile number.

    The expected layout is ``<tiles_dir>/<sequence>/<image folder>/tile_<N>.png``
    (e.g. sequence "78" or "g60", image folder "0001.tif").

    Args:
        tiles_dir: Path of the folder that contains the sequence folders.

    Returns:
        A dict mapping sequence name -> dict mapping tile number (int) -> list of
        tile image paths, one per image folder that contains that tile.

    Every directory listing is sorted because ``os.listdir`` returns entries in
    arbitrary order, while consumers treat consecutive list entries as
    consecutive frames. NOTE(review): lexicographic order only equals frame
    order if the image folder names are zero-padded -- confirm.
    """
    # Anchored on both ends so only names of exactly the form tile_<digits>.png are accepted.
    tile_name_pattern = re.compile(r"tile_(\d+)\.png")
    sequences = {}

    for sequence_name in sorted(os.listdir(tiles_dir)):
        sequence_path = os.path.join(tiles_dir, sequence_name)
        if not os.path.isdir(sequence_path):  # Skip stray files next to the sequence folders
            continue

        tile_arrays = {}  # tile number -> paths of that tile across the sequence
        for image_folder_name in sorted(os.listdir(sequence_path)):
            image_folder_path = os.path.join(sequence_path, image_folder_name)
            if not os.path.isdir(image_folder_path):
                continue

            for tile_image_name in sorted(os.listdir(image_folder_path)):
                tile_number_match = tile_name_pattern.fullmatch(tile_image_name)
                if tile_number_match is None:
                    continue
                tile_number = int(tile_number_match.group(1))
                tile_image_path = os.path.join(image_folder_path, tile_image_name)
                tile_arrays.setdefault(tile_number, []).append(tile_image_path)

        sequences[sequence_name] = tile_arrays  # Add the tile arrays for this sequence

    return sequences


sequence_arrays = _collect_sequence_tiles(TILES_DIR)

# Dataset

In [6]:
class CustomDataset(Dataset):
    """Synthetic displacement dataset built from the tile images.

    Each sample pairs a tile with a warped copy of itself and the displacement
    field that produced the warp. Every tile is served ``variations_per_image``
    times, each time with a different, reproducible set of vector fields.

    Args:
        variations_per_image: Number of distinct warps generated per tile.
    """

    def __init__(self, variations_per_image: int = 10):
        self.variations_per_image = variations_per_image

    def __len__(self):
        return NUM_TILES * self.variations_per_image

    def __getitem__(self, index):
        """Build the sample at ``index``.

        Returns:
            A tuple ``(inputs, targets)`` of float32 arrays: ``inputs`` stacks
            ``[image, warped image]`` and ``targets`` stacks ``[dx, dy]``.
            Assumes tiles are single-channel TILE_SIZE x TILE_SIZE images, so
            both stack to (2, TILE_SIZE, TILE_SIZE) -- TODO confirm.

        Raises:
            IndexError: If ``index`` falls outside ``[-len(self), len(self))``.
        """
        sample_count = len(self)
        index = int(index)
        if index < 0:
            index += sample_count  # Python-style negative indexing
        if not 0 <= index < sample_count:
            # Raising here lets plain iteration over the dataset terminate, and
            # avoids a modulo-by-zero when no tiles were found.
            raise IndexError(f"index {index} is out of range for a dataset of {sample_count} samples")

        # Indexes work like this:
        # [1_0, ..., n_0, 1_1, ..., n_1, 1_v, ..., n_v, ...]
        # [1  , ..., n  , n+1, ..., n+n, vn+1,..., vn+n,...]
        # Where n is the number of images (NUM_TILES)
        # And v is the variation number (index // NUM_TILES)
        path_index = index % NUM_TILES

        # The index identifies the (tile, variation) pair uniquely, so seeding with
        # it gives every sample its own repeatable field recipe. Dividing by
        # variations_per_image, as before, made runs of consecutive tiles share a seed.
        # NOTE(review): presumably VectorFieldComposer's randomize=True also draws
        # from the `random` module -- verify in displacements.py.
        random.seed(index)

        composer = VectorFieldComposer()

        available_fields = list(VECTOR_FIELDS.keys())
        num_fields = random.randint(1, 3)
        for _ in range(num_fields):
            field_type = random.choice(available_fields)
            composer.add_field(field_type, randomize=True)

        # Close the file handle as soon as the pixels are copied into the array.
        with Image.open(TILE_IMAGE_PATHS[path_index]) as tile:
            image = np.array(tile)
        image2 = composer.apply_to_image(image)

        # Sample the combined field on a [-1, 1] x [-1, 1] grid with one point per pixel.
        grid_X, grid_Y = np.meshgrid(np.linspace(-1, 1, TILE_SIZE), np.linspace(-1, 1, TILE_SIZE))
        dx, dy = composer.compute_combined_field(grid_X, grid_Y)

        return np.array([image, image2]).astype(np.float32), np.array([dx, dy]).astype(np.float32)

# Model

In [7]:
# Pick the compute device: the active accelerator when one is usable, otherwise the CPU.
# `torch.accelerator` is missing from older PyTorch builds, so fall back to the plain
# CUDA check there instead of failing with an AttributeError.
if hasattr(torch, "accelerator"):
    device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

Using cuda device


In [8]:
class ConvolutionBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> LeakyReLU(0.1) with an additive shortcut branch.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the block.
        kernel_size: Kernel size of the main convolution.
        stride: Stride of the main convolution (and of the shortcut projection).
        padding: Zero padding of the main convolution.

    The shortcut is the identity when the input already has the output's channel
    count and the block is unstrided. Otherwise it is a 1x1 convolution followed
    by batch norm, so both branches produce the same shape before being added.
    The shapes only line up when ``kernel_size == 2 * padding + 1`` (as with the
    defaults).
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()

        self.conv = nn.Sequential(
            # bias is redundant here because BatchNorm2d adds its own shift
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.1, inplace=True),
        )

        # A strided main branch shrinks the feature map, so the shortcut must be
        # projected with the same stride even when the channel counts match.
        if in_channels != out_channels or stride != 1:
            self.residual = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            # Parameter-free pass-through; adds no entries to the state dict.
            self.residual = nn.Identity()

    def forward(self, x):
        # Residual connection: add the (possibly projected) input to the conv output.
        return self.conv(x) + self.residual(x)

class MotionVectorRegressionNetwork(nn.Module):
    """Fully convolutional regressor that predicts an xy motion vector per pixel.

    The encoder (``self.convolution``) halves the spatial size twice with max
    pooling while widening to 128 channels. The decoder (``self.output``) doubles
    it twice with transposed convolutions and ends in a 2-channel map. The output
    therefore has the input's height and width when both are multiples of 4.

    Args:
        input_images: Number of input channels, i.e. how many images are stacked
            per sample.
    """

    def __init__(self, input_images = 2):
        super().__init__()
        self.input_images = input_images
        self.vector_channels = 2  # one output channel per displacement component

        # Encoder: input_images -> 32 -> 64 -> 128 -> 128 channels, 1/4 resolution.
        encoder_layers = [
            ConvolutionBlock(input_images, 32, kernel_size=3),
            nn.MaxPool2d(kernel_size=2),
            ConvolutionBlock(32, 64, kernel_size=3),
            nn.MaxPool2d(kernel_size=2),
            ConvolutionBlock(64, 128, kernel_size=3),
            ConvolutionBlock(128, 128, kernel_size=3),
        ]
        self.convolution = nn.Sequential(*encoder_layers)

        # Decoder: 128 -> 64 -> 32 -> vector_channels, back to full resolution.
        # Layer order is kept exactly so checkpoint keys (output.<i>.*) still match.
        decoder_layers = [
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.1),
            nn.BatchNorm2d(64),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.1),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, self.vector_channels, kernel_size=3, stride=1, padding=1),
        ]
        self.output = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x``, then decode the features into the motion-vector map."""
        features = self.convolution(x)
        return self.output(features)

# Testing

In [9]:
# Checkpoint holding the trained weights (a state dict).
MODEL_FILE = "model2/tx7.pth"
model = MotionVectorRegressionNetwork().to(device)
# map_location remaps the stored tensors onto the device chosen for this session, so a
# checkpoint saved on a GPU still loads on a CPU-only machine. weights_only=True
# restricts unpickling to tensors and other plain data.
model.load_state_dict(torch.load(MODEL_FILE, map_location=device, weights_only=True))

<All keys matched successfully>

In [10]:
# Inference mode: batch norm uses its running statistics instead of batch statistics.
model.eval()

num_examples = 1
sequence_name = 'g60'
tile_number = 5  # which tile position within each frame to run the model on

if sequence_name not in sequence_arrays:
    print(f"Sequence '{sequence_name}' not found in sequence_arrays. Please check your data.")
elif tile_number not in sequence_arrays[sequence_name]:
    print(f"Tile {tile_number} not found in sequence '{sequence_name}'. Please check your data.")
else:
    image_paths = sequence_arrays[sequence_name][tile_number]

    # Each example consumes two consecutive frames; never index past the end of the list.
    available_examples = len(image_paths) // 2
    if available_examples < num_examples:
        print(f"Only {available_examples} frame pair(s) available for tile {tile_number}; requested {num_examples}.")

    for example_index in range(min(num_examples, available_examples)):
        base_image_path = image_paths[example_index * 2] #0, 2, 4
        next_time_path = image_paths[example_index * 2 + 1] #1, 3, 5

        # Copy the pixels out and release the file handles right away.
        with Image.open(base_image_path) as base_image_pil:
            base_image = np.array(base_image_pil)
        with Image.open(next_time_path) as next_time_pil:
            next_time = np.array(next_time_pil)

        with torch.no_grad():
            X = torch.from_numpy(np.array([base_image, next_time])).float()
            X = X.unsqueeze(0)  # add the batch dimension
            X = X.to(device)
            pred = model(X)

        # One device-to-host transfer instead of one per element read.
        flow = pred[0].cpu().numpy()
        u_field, v_field = flow[0], flow[1]
        height, width = u_field.shape

        # The first example keeps the original file name; later examples get a
        # numbered file so they no longer overwrite each other.
        if example_index == 0:
            output_path = "image_displacements.txt"
        else:
            output_path = f"image_displacements_{example_index}.txt"

        # One line per pixel: "x y u v".
        with open(output_path, "w") as f:
            for y in range(height):
                for x in range(width):
                    u = u_field[y, x]
                    v = v_field[y, x]

                    f.write(f"{x} {y} {u:.6f} {v:.6f}\n")