Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update disentanglement interpolation #241

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion avae/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,12 @@ def train(
x_train, p_train, vae, data_dim, device
)

if pose and config.VIS_POSE_CLASS:
if (
pose
and config.VIS_POS
and config.VIS_POSE_CLASS
and (epoch + 1) % config.FREQ_POS == 0
):
vis.pose_class_disentanglement_plot(
x_train,
y_train,
Expand Down
3 changes: 2 additions & 1 deletion avae/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,9 @@ def save_imshow_png(

fig, _ = plt.subplots(figsize=(10, 10))
plt.imshow(array, cmap=cmap, vmin=min, vmax=max) # channels last
plt.axis("off")

plt.savefig("plots/" + fname)
plt.savefig("plots/" + fname, bbox_inches="tight", pad_inches=0)

if writer:
writer.add_figure(figname, fig, epoch)
Expand Down
54 changes: 32 additions & 22 deletions avae/vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -1077,30 +1077,35 @@ def latent_disentamglement_plot(
"################################################################"
)
logging.info("Visualising latent content disentanglement ...\n")
number_of_samples = 7

# every SAMPLING_STEP interval from -SIGMA to SIGMA
sigma = 2
sampling_rate = 15
sampling_step = (sigma * 2) / (sampling_rate - 1)

padding = 0
lats = np.asarray(lats)
if poses is not None:
poses = np.asarray(poses)

lat_means = np.mean(lats, axis=0)
lat_stds = np.std(lats, axis=0)

lat_dims = lats.shape[-1]
lat_grid = np.zeros((lat_dims * number_of_samples, lat_dims))
lat_grid = np.zeros((lat_dims * sampling_rate, lat_dims))
if poses is not None:
pos_means = np.mean(poses, axis=0)
pos_dims = poses.shape[-1]
pos_grid = (
np.zeros((lat_dims * number_of_samples, pos_dims)) + pos_means
)
pos_grid = np.zeros((lat_dims * sampling_rate, pos_dims)) + pos_means

# Generate vectors representing single transversals along each lat_dim
for l_dim in range(lat_dims):
for grid_spot in range(7):
for grid_spot in range(sampling_rate):
means = copy.deepcopy(lat_means)
# every 0.4 interval from -1.2 to 1.2 sigma
means[l_dim] += lat_stds[l_dim] * (-1.2 + 0.4 * grid_spot)
lat_grid[l_dim * number_of_samples + grid_spot, :] = means
means[l_dim] += lat_stds[l_dim] * (
-sigma + sampling_step * grid_spot
)
lat_grid[l_dim * sampling_rate + grid_spot, :] = means

# Decode interpolated vectors
with torch.no_grad():
Expand All @@ -1121,13 +1126,13 @@ def latent_disentamglement_plot(
return

recon = np.reshape(
np.array(recon.cpu()), (lat_dims, number_of_samples, *dsize)
np.array(recon.cpu()), (lat_dims, sampling_rate, *dsize)
)
grid_for_napari = create_grid_for_plotting(
lat_dims, number_of_samples, dsize, padding
lat_dims, sampling_rate, dsize, padding
)
grid_for_napari = fill_grid_for_plottting(
lat_dims, number_of_samples, grid_for_napari, dsize, recon, padding
lat_dims, sampling_rate, grid_for_napari, dsize, recon, padding
)

if data_dim == 3:
Expand Down Expand Up @@ -1182,27 +1187,32 @@ def pose_disentanglement_plot(
"Visualising pose disentanglement for class {}...\n".format(label)
)

number_of_samples = 7
padding = 0
# every SAMPLING_STEP interval from -SIGMA to SIGMA
sigma = 2
sampling_rate = 15
sampling_step = (sigma * 2) / (sampling_rate - 1)

padding = 0
lats = np.asarray(lats)
poses = np.asarray(poses)

pos_means = np.mean(poses, axis=0)
pos_stds = np.std(poses, axis=0)
pos_dims = poses.shape[-1]
pos_grid = np.zeros((pos_dims * number_of_samples, pos_dims))
pos_grid = np.zeros((pos_dims * sampling_rate, pos_dims))

lat_means = np.mean(lats, axis=0)
lat_dims = lats.shape[-1]
lat_grid = np.zeros((pos_dims * number_of_samples, lat_dims)) + lat_means
lat_grid = np.zeros((pos_dims * sampling_rate, lat_dims)) + lat_means

# Generate vectors representing single transversals along each lat_dim
for p_dim in range(pos_dims):
for grid_spot in range(number_of_samples):
for grid_spot in range(sampling_rate):
means = copy.deepcopy(pos_means)
means[p_dim] += pos_stds[p_dim] * (-1.2 + 0.4 * grid_spot)
pos_grid[p_dim * number_of_samples + grid_spot, :] = means
means[p_dim] += pos_stds[p_dim] * (
-sigma + sampling_step * grid_spot
)
pos_grid[p_dim * sampling_rate + grid_spot, :] = means

# Decode interpolated vectors
with torch.no_grad():
Expand All @@ -1222,15 +1232,15 @@ def pose_disentanglement_plot(

recon = np.reshape(
np.array(recon.cpu()),
(pos_dims, number_of_samples, *dsize),
(pos_dims, sampling_rate, *dsize),
)
grid_for_napari = create_grid_for_plotting(
pos_dims, number_of_samples, dsize, padding
pos_dims, sampling_rate, dsize, padding
)

# Create and save the mrc file with single transversals
grid_for_napari = fill_grid_for_plottting(
pos_dims, number_of_samples, grid_for_napari, dsize, recon, padding
pos_dims, sampling_rate, grid_for_napari, dsize, recon, padding
)

if data_dim == 3:
Expand Down
Loading