In [None]:
import torch
import numpy as np
import pandas as pd
import time, os, sys
from tqdm.auto import tqdm, trange

import random
from tensorboardX import SummaryWriter
import argparse
from omegaconf import OmegaConf
from datetime import datetime
from functools import partial
from torchvision import transforms

import plotly.graph_objects as go
import plotly.io as pio
import plotly.express as px
# Shared plotly layout: a uniform 20px margin on the left, right, top and bottom.
plotly_layout = {"margin": {side: 20 for side in ("l", "r", "t", "b")}}

In [None]:
sys.path.append('../..')
from loader import get_dataloader
from torchvision import transforms

In [None]:
# Loader configuration for the MNIST training split. The 'digits' entry
# presumably restricts the loader to the listed classes — confirm in
# loader.get_dataloader.
data_cfg = dict(
    dataset='MNIST',
    root='../../dataset',
    batch_size=100,
    n_workers=4,
    split='training',
    shuffle=True,
    digits=[3],
)

dl = get_dataloader(data_cfg)

In [None]:
# Grab the image and label tensors held by the loader's underlying dataset.
data, targets = dl.dataset.data, dl.dataset.targets

In [None]:
# Each source image yields EPISODES_PER_IMAGE episodes; each episode holds
# TIME_HORIZON frames spaced ANGLE_STEP degrees apart (36 * 10 = 360, i.e. one
# full revolution per episode).
EPISODES_PER_IMAGE = 5
TIME_HORIZON = 36
ANGLE_STEP = 10  # degrees between consecutive frames

# Build from dl.dataset rather than from `data`, so re-running this cell does
# not re-expand an already-expanded tensor (the original `data = data....`
# grew and gained a dimension on every re-run).
data = (
    dl.dataset.data
    .repeat_interleave(EPISODES_PER_IMAGE, 0)  # row k -> rows k*E ... k*E + E-1
    .unsqueeze(-1)                             # add a trailing time axis
    .repeat_interleave(TIME_HORIZON, -1)       # copy the image into every frame
)
targets = dl.dataset.targets.repeat_interleave(EPISODES_PER_IMAGE, 0)

# The rotation and plotting cells index data as data[i, :, :, :, t].
assert data.dim() == 5, f"expected a 5-D episode tensor, got shape {tuple(data.shape)}"

In [None]:
# Rotate every frame of every episode: each episode starts at a random angle in
# [0, 360) and advances ANGLE_STEP degrees per frame.
ROTATION_SEED = 0  # fixed so the generated dataset is reproducible
rotation_rng = torch.Generator().manual_seed(ROTATION_SEED)
source_images = dl.dataset.data  # unrotated originals that `data` was expanded from

for i_idx in trange(data.shape[0]):
    # `data` was built with repeat_interleave(EPISODES_PER_IMAGE, 0), so episode
    # i_idx comes from source image i_idx // EPISODES_PER_IMAGE. Rotating that
    # original (instead of the frame stored in `data`) means re-running this
    # cell does not compound rotations.
    base_image = source_images[i_idx // EPISODES_PER_IMAGE]
    theta_init = (torch.rand((), generator=rotation_rng) * 360).item()
    for j_idx in range(data.shape[-1]):
        theta = theta_init + ANGLE_STEP * j_idx
        data[i_idx, :, :, :, j_idx] = transforms.functional.affine(
            img=base_image, angle=theta, translate=[0, 0], scale=1., shear=0
        )

In [None]:
def draw_images(X):
    """Build an animated plotly figure that plays back one episode frame by frame.

    Parameters
    ----------
    X : torch.Tensor
        One episode, indexed as ``X[:, :, :, t]`` for frame ``t``. Each frame is
        permuted ``(1, 2, 0)`` and its channel axis tiled 3x to form RGB, so a
        single-channel layout ``(1, H, W, T)`` is expected. Values are scaled by
        255 for display, which presumably assumes inputs in [0, 1] — TODO
        confirm against the loader's dtype.

    Returns
    -------
    plotly.graph_objects.Figure
        Figure showing frame 0, with Play/Pause buttons and a slider covering
        all ``X.shape[-1]`` frames.
    """
    # Derive the frame count from X itself rather than the global TIME_HORIZON,
    # so episodes of any length render fully instead of raising or truncating.
    n_frames = X.shape[-1]

    def _frame_image(t):
        """Render frame ``t`` of ``X`` as a plotly RGB image trace."""
        # (C, H, W) -> (H, W, C); tile channels so 1 channel becomes RGB.
        return go.Image(z=X[:, :, :, t].permute(1, 2, 0).repeat(1, 1, 3) * 255)

    fig = go.Figure(_frame_image(0))

    # Play / Pause buttons below the plot, both driving plotly's "animate" method.
    fig.layout.updatemenus = [
        {
            "buttons": [
                {
                    "args": [None, {"frame": {"duration": 100, "redraw": True},
                                    "fromcurrent": True,
                                    "transition": {"duration": 1, "easing": "quadratic-in-out"}}],
                    "label": "Play",
                    "method": "animate"
                },
                {
                    "args": [[None], {"frame": {"duration": 0, "redraw": False},
                                      "mode": "immediate",
                                      "transition": {"duration": 0}}],
                    "label": "Pause",
                    "method": "animate"
                }
            ],
            "direction": "down",
            "pad": {"r": 10, "t": 30},
            "showactive": False,
            "type": "buttons",
            "x": 0.1,
            "xanchor": "right",
            "y": 0,
            "yanchor": "top"
        }
    ]

    # One animation frame per time step, named by its 0-based index.
    frames = [go.Frame(data=[_frame_image(t)], name=str(t)) for t in range(n_frames)]

    # One slider step per frame. Each step targets its frame by the same string
    # name the frame was registered under; the visible label is 1-based.
    slider_steps = [
        {
            "args": [
                [str(t)],
                {"frame": {"duration": 1, "redraw": True},
                 "mode": "immediate",
                 "transition": {"duration": 1}}
            ],
            "label": t + 1,
            "method": "animate"
        }
        for t in range(n_frames)
    ]

    fig.layout.sliders = [{
        "active": 0,
        "yanchor": "top",
        "xanchor": "left",
        "currentvalue": {
            "font": {"size": 20},
            "prefix": "Frame:",
            "visible": True,
            "xanchor": "left"
        },
        "transition": {"duration": 1, "easing": "cubic-in-out"},
        "pad": {"b": 10, "t": 10},
        "len": 0.9,
        "x": 0.1,
        "y": 0,
        "steps": slider_steps,
    }]

    fig.update(frames=frames)
    fig.update_layout(**plotly_layout, width=500, height=500)

    return fig

In [None]:
PREVIEW_EPISODE = 15  # arbitrary episode index to visualise
draw_images(data[PREVIEW_EPISODE])

In [None]:
# Save one file per digit. The output directory defaults to
# <data_cfg['root']>/RotatingMNIST, resolved the same way as the loader's
# root, rather than a user-specific absolute path. Set ROTATING_MNIST_DIR to
# override it.
SAVE_PATH = os.environ.get(
    'ROTATING_MNIST_DIR', os.path.join(data_cfg['root'], 'RotatingMNIST')
)
os.makedirs(SAVE_PATH, exist_ok=True)

# Iterate only over digits actually present in `targets`. Looping over
# range(10) writes empty tensors for every digit that was not loaded (e.g. with
# digits=[3]), overwriting any previously generated files for those digits.
for digit in tqdm(targets.unique().tolist(), desc='saving'):
    mask = targets == digit
    torch.save({
        'data': data[mask],
        'targets': targets[mask],
    }, os.path.join(SAVE_PATH, f'RotatingMNIST-digit={digit}.pkl'))