In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from data_generation import *
from torchvision import transforms
from PIL import Image, ImageDraw

In [None]:
def get_distances(positions):
    """
    Return the Euclidean distance between each pair of consecutive points.

    Only the first two coordinates of every point are used. For ``n`` points
    the result is a list of ``n - 1`` distances (empty when ``n < 2``).
    """
    return [
        np.linalg.norm((curr[0] - prev[0], curr[1] - prev[1]))
        for prev, curr in zip(positions[:-1], positions[1:])
    ]

In [None]:
# Parameters passed to generate_random_sequence by every generation cell below.
# NOTE(review): no random seed is set in these cells, so generated data differs
# between runs; if generate_random_sequence draws from numpy's global RNG,
# seeding here would make the notebook reproducible — confirm.
sequence_length = 10
speed_min, speed_max = 10, 15
direction_min, direction_max = 0, 2 * np.pi  # presumably an angle range in radians (full circle)
position_x_min, position_x_max = 2, 14
position_y_min, position_y_max = 6, 14
gravity_min, gravity_max = 0, 4
restitution_min, restitution_max = 0.7, 1

In [None]:
# Sample 100 random sequences and pool the distances between consecutive positions.
generation_kwargs = dict(
    sequence_length=sequence_length,
    speed_min=speed_min,
    speed_max=speed_max,
    direction_min=direction_min,
    direction_max=direction_max,
    position_x_min=position_x_min,
    position_x_max=position_x_max,
    position_y_min=position_y_min,
    position_y_max=position_y_max,
    gravity_min=gravity_min,
    gravity_max=gravity_max,
    restitution_min=restitution_min,
    restitution_max=restitution_max,
)
distances = []
for _ in range(100):
    sequence, positions = generate_random_sequence(**generation_kwargs)
    distances.extend(get_distances(positions))

In [None]:
# Histogram of the distances between consecutive positions, pooled over all
# sequences collected in `distances`.
# NOTE(review): `plt` is never imported explicitly in this notebook; it presumably
# arrives via `from data_generation import *` — consider `import matplotlib.pyplot as plt`.
fig, ax = plt.subplots()
ax.hist(distances, bins=40)
ax.set(
    title="Distance between consecutive positions",
    xlabel="Euclidean distance",
    ylabel="Count",
)
plt.show()

In [None]:
# Visual sanity check: render ten freshly generated sequences with display_sequence.
generation_kwargs = {
    "sequence_length": sequence_length,
    "speed_min": speed_min,
    "speed_max": speed_max,
    "direction_min": direction_min,
    "direction_max": direction_max,
    "position_x_min": position_x_min,
    "position_x_max": position_x_max,
    "position_y_min": position_y_min,
    "position_y_max": position_y_max,
    "gravity_min": gravity_min,
    "gravity_max": gravity_max,
    "restitution_min": restitution_min,
    "restitution_max": restitution_max,
}
for _ in range(10):
    sequence, positions = generate_random_sequence(**generation_kwargs)
    display_sequence(sequence)

In [None]:
# Hyperparameters
N = 100_000  # Number of sequences to generate for the saved dataset
# Length of each sequence: reuse the value explored above so the saved dataset
# always matches the sequences that were inspected.
L = sequence_length
H, W = 16, 16  # Dimensions of the images
# NOTE(review): H and W are not used in the cells shown here; presumably they must
# match the frame size produced by generate_random_sequence — verify.


class PixelDataset(Dataset):
    """
    In-memory dataset of randomly generated image sequences and their positions.

    All sequences are generated eagerly in ``__init__`` with
    ``generate_random_sequence``, using the module-level generation parameters
    (``speed_min`` ... ``restitution_max``). Images are converted to tensors
    lazily in ``__getitem__``.
    """

    def __init__(self, num_sequences, sequence_length, log_every=1000):
        """
        Args:
            num_sequences: how many sequences to generate.
            sequence_length: forwarded to ``generate_random_sequence`` as the
                length of every sequence.
            log_every: print a progress line after every ``log_every``
                generated sequences; ``0`` or ``None`` disables progress output.
        """
        self.data = []  # per sequence: the images returned by generate_random_sequence
        self.targets = []  # per sequence: the matching positions
        for i in range(num_sequences):
            images, positions = generate_random_sequence(
                sequence_length=sequence_length,
                speed_min=speed_min,
                speed_max=speed_max,
                direction_min=direction_min,
                direction_max=direction_max,
                position_x_min=position_x_min,
                position_x_max=position_x_max,
                position_y_min=position_y_min,
                position_y_max=position_y_max,
                gravity_min=gravity_min,
                gravity_max=gravity_max,
                restitution_min=restitution_min,
                restitution_max=restitution_max,
            )
            self.data.append(images)
            self.targets.append(positions)

            # `i` is 0-based, so i + 1 sequences have been generated at this point.
            generated = i + 1
            if log_every and generated % log_every == 0:
                print(f"Generated {generated} sequences")

        # NOTE(review): ToTensor expects PIL images or HxWxC ndarrays — assumes
        # generate_random_sequence returns one of those; confirm.
        self.transform = transforms.ToTensor()
        self.sequence_length = sequence_length

    def __len__(self):
        """Return the number of stored sequences."""
        return len(self.data)

    def __getitem__(self, idx):
        """
        Return ``(images, positions)`` for sequence ``idx``.

        ``images`` stacks the per-frame tensors produced by ``self.transform``
        along a new leading dimension; ``positions`` is a float32 tensor built
        from the stored positions.
        """
        frames = torch.stack([self.transform(image) for image in self.data[idx]])
        # Go through np.asarray: building a tensor from a list of ndarrays is slow.
        # torch.tensor copies, so callers cannot mutate the stored targets.
        positions = torch.tensor(
            np.asarray(self.targets[idx], dtype=np.float32), dtype=torch.float32
        )
        return frames, positions


# Generate the full dataset once and persist it to disk (no DataLoader is built here).
# The whole PixelDataset object is pickled, so loading it back requires this class
# definition to be importable and, on PyTorch versions where torch.load defaults to
# weights_only=True, torch.load(..., weights_only=False) — only do that for trusted files.
dataset = PixelDataset(num_sequences=N, sequence_length=L)
torch.save(obj=dataset, f="dataset.pt")