In [2]:
import torch
import numpy as np
import pandas as pd
import time, os, sys
from tqdm.auto import tqdm, trange

import random
from tensorboardX import SummaryWriter
import argparse
from omegaconf import OmegaConf
from datetime import datetime
from functools import partial
from torchvision import transforms

import plotly.graph_objects as go
import plotly.io as pio
import plotly.express as px
# Shared plotly layout kwargs: a uniform 20px margin on every side of the figure.
plotly_layout = {'margin': dict.fromkeys(('l', 'r', 't', 'b'), 20)}

In [3]:
# Make the directory two levels above the working directory importable
# (presumably the project root that contains `loader.py` — confirm).
# The path is resolved to an absolute path once, so a later cwd change cannot
# silently redirect imports. The append is guarded, so re-running this cell
# does not keep adding duplicate entries to sys.path.
PROJECT_ROOT = os.path.abspath(os.path.join('..', '..'))
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)
from loader import get_dataloader
from torchvision import transforms

In [4]:
# Configuration handed to the project's `get_dataloader`. The 'digits' entry
# presumably restricts the dataset to these MNIST classes (the captured output
# reports 5100 training samples) — confirm against loader.py.
data_cfg = dict(
    dataset='MNIST',
    root='../../dataset',
    batch_size=100,
    n_workers=4,
    split='training',
    shuffle=True,
    digits=[3],
)
dl = get_dataloader(data_cfg)

fraction_data: None
MNIST split training | 5100
no global Laplacian, normalized K, or distance matrix is precomputed


In [5]:
# Take the full image/label tensors straight from the dataset object instead of
# iterating over the DataLoader batch by batch.
data, targets = dl.dataset.data, dl.dataset.targets
print(data.shape, targets.shape)

torch.Size([5100, 1, 28, 28]) torch.Size([5100])


In [6]:
# Data configuration
EPISODES_PER_IMAGE = 3      # Number of episodes (each with its own random start position / velocity) per original image
TIME_HORIZON = 10           # Number of frames per episode
IMAGE_SIZE = 42             # Side length of the square output canvas
RESIZE_DIM = 14             # The digit is resized to RESIZE_DIM x RESIZE_DIM before being placed on the canvas

# Velocity range for the translation (in pixels per frame), applied independently per axis
VELOCITY_RANGE = (-2, 2)

# Repeat each image/label EPISODES_PER_IMAGE times (copies are consecutive).
# Always start from the loader's tensors, not from `data`/`targets`, so that
# re-running this cell does not repeat an already-repeated tensor a second time.
data = dl.dataset.data.repeat_interleave(EPISODES_PER_IMAGE, 0)
targets = dl.dataset.targets.repeat_interleave(EPISODES_PER_IMAGE, 0)

# Output clips of shape (N, C, IMAGE_SIZE, IMAGE_SIZE, TIME_HORIZON), filled frame by frame below.
new_data = torch.zeros((data.shape[0], data.shape[1], IMAGE_SIZE, IMAGE_SIZE, TIME_HORIZON))

print(data.shape, targets.shape, new_data.shape)

torch.Size([15300, 1, 28, 28]) torch.Size([15300]) torch.Size([15300, 1, 42, 42, 10])


In [7]:

resize_transform = transforms.Resize((RESIZE_DIM, RESIZE_DIM))

# Sanity checks on the configuration: the per-step velocity bound divides by
# (TIME_HORIZON - 1), and the resized digit must fit on the canvas.
assert TIME_HORIZON >= 2, "TIME_HORIZON must be >= 2"
assert IMAGE_SIZE >= RESIZE_DIM, "RESIZE_DIM must not exceed IMAGE_SIZE"

# Dedicated, seeded RNG: re-running this cell (or Restart & Run All) regenerates
# exactly the same trajectories, independent of any other torch RNG usage.
GENERATION_SEED = 0
generator = torch.Generator().manual_seed(GENERATION_SEED)

max_offset = IMAGE_SIZE - RESIZE_DIM   # largest top-left offset at which the digit still fits
n_steps = TIME_HORIZON - 1             # number of moves between the first and last frame

for i_idx in trange(data.shape[0]):
    # Random initial top-left position for each episode.
    # NOTE: `x` indexes the row (vertical) axis and `y` the column axis; see the slicing below.
    x_init = torch.randint(0, max_offset + 1, (1,), generator=generator).item()
    y_init = torch.randint(0, max_offset + 1, (1,), generator=generator).item()

    # Velocity bounds per axis that keep the digit on the canvas for all frames,
    # intersected with VELOCITY_RANGE. Negate *after* floor-dividing:
    # -(3 // 2) == -1, whereas -3 // 2 == -2 would run off the canvas.
    max_x_velocity = min((max_offset - x_init) // n_steps, VELOCITY_RANGE[1])
    min_x_velocity = max(-(x_init // n_steps), VELOCITY_RANGE[0])

    max_y_velocity = min((max_offset - y_init) // n_steps, VELOCITY_RANGE[1])
    min_y_velocity = max(-(y_init // n_steps), VELOCITY_RANGE[0])

    # Ensure that there is a feasible velocity other than (0, 0)
    assert min_x_velocity < max_x_velocity or min_y_velocity < max_y_velocity, f"Invalid velocity range: {min_x_velocity}, {max_x_velocity}, {min_y_velocity}, {max_y_velocity}"

    # Rejection-sample a constant velocity until it is not (0, 0), i.e. the digit actually moves.
    while True:
        x_velocity = torch.randint(min_x_velocity, max_x_velocity + 1, (1,), generator=generator).item()
        y_velocity = torch.randint(min_y_velocity, max_y_velocity + 1, (1,), generator=generator).item()
        if not (x_velocity == 0 and y_velocity == 0):
            break

    # The digit is the same in every frame, so resize it once per episode rather than once per frame.
    resized_img = resize_transform(data[i_idx, :, :, :])  # (C, RESIZE_DIM, RESIZE_DIM)

    for j_idx in range(TIME_HORIZON):
        # Constant-velocity translation. The bounds above already keep it inside the
        # canvas; clamping is only a safety net.
        x_translate = min(max(x_init + x_velocity * j_idx, 0), max_offset)
        y_translate = min(max(y_init + y_velocity * j_idx, 0), max_offset)

        # A fresh zero canvas per frame, so pixels from a previous run of this cell
        # can never leak into new_data.
        padded_img = torch.zeros([data.shape[1], IMAGE_SIZE, IMAGE_SIZE])  # (C, IMAGE_SIZE, IMAGE_SIZE)

        # Place the resized digit on the canvas at the translated position
        padded_img[:, x_translate:x_translate + RESIZE_DIM, y_translate:y_translate + RESIZE_DIM] = resized_img

        # Store the frame in the output clip
        new_data[i_idx, :, :, :, j_idx] = padded_img

100%|██████████| 15300/15300 [00:08<00:00, 1876.52it/s]


In [8]:
# def draw_images(X):
    
#     fig=go.Figure(go.Image(z=X[:, :, :, 0].permute(1, 2, 0).repeat(1, 1, 3)*255))

#     fig.layout.updatemenus = [
#         {
#             "buttons": [
#                 {
#                     "args": [None, {"frame": {"duration": 100, "redraw": True},
#                                     "fromcurrent": True, 
#                                     "transition": {"duration": 1, "easing": "quadratic-in-out"}}],
#                     "label": "Play",
#                     "method": "animate"
#                 },
#                 {
#                     "args": [[None], {"frame": {"duration": 0, "redraw": False},
#                                     "mode": "immediate",
#                                     "transition": {"duration": 0}}],
#                     "label": "Pause",
#                     "method": "animate"
#                 }
#             ],
#             "direction": "down",
#             "pad": {"r": 10, "t": 30},
#             "showactive": False,
#             "type": "buttons",
#             "x": 0.1,
#             "xanchor": "right",
#             "y": 0,
#             "yanchor": "top"
#         }
#     ]

#     sliders_dict = {
#         "active": 0,
#         "yanchor": "top",
#         "xanchor": "left",
#         "currentvalue": {
#             "font": {"size": 20},
#             "prefix": "Frame:",
#             "visible": True,
#             "xanchor": "left"
#         },
#         "transition": {"duration": 1, "easing": "cubic-in-out"},
#         "pad": {"b": 10, "t": 10},
#         "len": 0.9,
#         "x": 0.1,
#         "y": 0,
#         "steps": []
#     }

#     frames = [None]*TIME_HORIZON

#     frame_idx = 0
#     for i in range(TIME_HORIZON):
#         frames[frame_idx] = go.Frame(data=[go.Image(z=X[:, :, :, i].permute(1, 2, 0).repeat(1, 1, 3)*255)], name=str(i))
#         frame_idx += 1

#         slider_step = {
#             "args": [
#                 [i],
#                 {"frame": {"duration": 1, "redraw": True},
#                 "mode": "immediate",
#                 "transition": {"duration": 1}}
#             ],
#             "label": i+1,
#             "method": "animate"
#         }
#         sliders_dict["steps"].append(slider_step)
#     fig.layout.sliders = [sliders_dict]

#     fig.update(frames=frames)
#     fig.update_layout(**plotly_layout, width=500, height=500)
    
#     return fig

In [9]:
# draw_images(new_data[15])  # draw_images expects a (C, H, W, T) clip; data[15] is a single static (1, 28, 28) image

In [10]:
# Save the dataset: one file per digit listed in data_cfg['digits'], named after
# the digit (for digits=[3] this writes 'TranslatingMNIST-digit=3.pkl').
# The file is written with torch.save, so read it back with torch.load.
SAVE_PATH = '../../dataset_42/TranslatingMNIST/'
os.makedirs(SAVE_PATH, exist_ok=True)

for digit in data_cfg['digits']:
    digit_mask = targets == digit
    # Refuse to write an empty file, e.g. when the digit is absent from the loaded split.
    assert digit_mask.any(), f"No samples with target {digit} to save"
    torch.save({
        'data': new_data[digit_mask],
        'targets': targets[digit_mask]
    }, os.path.join(SAVE_PATH, f'TranslatingMNIST-digit={digit}.pkl'))