# View `.npz`

> Joseph P. Vantassel, The University of Texas at Austin

In [None]:
import numpy as np
import matplotlib.pyplot as plt

## Load the `.npz` file saved by `gns/train.py`

In [None]:
# Dict-like archive of the arrays written by gns/train.py; the cells below
# list its keys and index it by name (e.g. "position_10", "label_10").
loaded = np.load(file="train.npz")

In [None]:
# Print a sparse sample of the archive's keys: every 100th one, starting at the first.
for position, name in enumerate(loaded.keys()):
    if position % 100:
        continue
    print(name)

## Confirm that `label` is the position at the next step

In [None]:
# Overlay, for one trajectory, the positions at `step`, the labels at `step`,
# and the positions at `step + 1`. If `label` is the next-step position, the
# label points should sit on top of the `step + 1` points.
trajectory = 10
value, label = loaded[f"position_{trajectory}"], loaded[f"label_{trajectory}"]

step = 10

fig, ax = plt.subplots(figsize=(3, 3), dpi=150)
# By the indexing below, `value` is 4-D and `label` is 3-D; only the first two
# coordinates are plotted. NOTE(review): the meaning of value's axis 2 (index 0
# here) is not visible in this notebook -- confirm against gns/train.py.
ax.scatter(value[step, :, 0, 0], value[step, :, 0, 1], s=0.5,
           label=f"position[{step}]")
ax.scatter(label[step, :, 0], label[step, :, 1], s=0.5,
           label=f"label[{step}]")
ax.scatter(value[step + 1, :, 0, 0], value[step + 1, :, 0, 1], s=0.5, alpha=0.5,
           label=f"position[{step + 1}]")
ax.set_xlim(0.1, 0.9)
ax.set_ylim(0.1, 0.9)
ax.set_xticks([])
ax.set_yticks([])
# Without a legend the three overlaid point clouds are indistinguishable.
ax.legend(fontsize=4, markerscale=8, loc="upper right")

# Numeric version of the visual check, restricted to the plotted coordinates.
# The shape test guards np.allclose against a broadcasting error.
label_xy = label[step, :, :2]
next_xy = value[step + 1, :, 0, :2]
matches = label_xy.shape == next_xy.shape and np.allclose(label_xy, next_xy)
print(f"label[{step}] == position[{step + 1}]: {matches}")

plt.show()

## Check that `particle_type` does not change within a trajectory

In [None]:
# For every trajectory in the archive, compare each step's particle types with
# the first step's and report the first mismatch found in that trajectory.
# Iterate the keys actually present instead of assuming exactly 1000
# trajectories: a fixed range raises KeyError on smaller archives and silently
# skips anything past index 999.
ptype_keys = [key for key in loaded.keys() if key.startswith("particle_type_")]
for ptype_key in ptype_keys:
    ptype = loaded[ptype_key]

    a = ptype[0, :]
    for b in ptype:
        if not np.allclose(a, b.squeeze()):
            print(f"{ptype_key}: a and b are not close")
            print(a, b)
            break  # one report per trajectory is enough

# Make an empty archive (nothing checked) distinguishable from a clean pass.
print(f"Checked particle_type in {len(ptype_keys)} trajectories.")

## Animate a single trajectory

In [None]:
# need this line if you're using jupyter notebooks
%matplotlib qt5

x = [] # Some array of images
fig, ax = plt.subplots(figsize=(5,5), dpi=150)
plt.ion() # Turns interactive mode on (probably unnecessary)
fig.show() # Initially shows the figure

for values in loaded["position_100"]:
    ax.clear() # Clears the previous image
    ax.scatter(values[:, 0, 0], values[:, 0, 1], s=0.5) # Loads the new image
    ax.set_xlim(0.1, 0.9)
    ax.set_ylim(0.1, 0.9)
    plt.pause(.000001) # Delay in seconds
    fig.canvas.draw() # Draws the image to the screen

plt.close()  

## Write `position` and `particle_type` to a simplified `.npz` file

In [None]:
# Repack each trajectory as a (positions, particle_types) pair stored under
# `simulation_trajectory_<n>`, and write all pairs to a new compressed archive.
dataset = {}
for key in loaded.keys():
    if not key.startswith("position_"):
        continue

    position_number = int(key[len("position_"):])

    # squeeze() drops every length-1 axis of the position array.
    positions = loaded[key].squeeze()
    # Keep only the first row of particle types. NOTE(review): assumes types
    # are constant along the trajectory, as the particle_type check cell
    # inspects.
    particle_types = loaded[f"particle_type_{position_number}"][0]

    # Store the pair as an explicit 1-D object array. np.savez_compressed calls
    # np.asanyarray on every value, and a tuple of two differently shaped
    # arrays is ragged: NumPy >= 1.24 raises ValueError for that, and older
    # versions build an object array whose shape depends on the sizes involved.
    pair = np.empty(2, dtype=object)
    pair[0] = positions
    pair[1] = particle_types
    dataset[f"simulation_trajectory_{position_number}"] = pair

# Object arrays are pickled; read back with np.load(..., allow_pickle=True).
np.savez_compressed("train_sand.npz", **dataset)