## Synthetic Mixing Experiments

In [1]:
# use hydra configs to load the dataset
import hydra
from utils import hydra_custom_resolvers
import torch
from torchvision import transforms
import torch.nn.functional as F
import numpy as np
from omegaconf import OmegaConf
import matplotlib.pyplot as plt
from PIL import Image as Image, ImageEnhance
configs_path = "configs"
# config_name = "train_root.yaml"
config_name = "train_root.yaml"

with hydra.initialize(config_path=configs_path):
    config = hydra.compose(config_name=config_name,
                            overrides=[
                                "datamodule=mixing",
                                "datamodule.dataset.num_domains=10",
                                # "model=mixing"
                            ],
                            return_hydra_config=True,
    )

    # setup the dataset with the hydra config
    datamodule = hydra.utils.instantiate(config.datamodule, _recursive_=False)
    datamodule.prepare_data()
    datamodule.setup()

    # instantiate the model with hydra
    # model = hydra.utils.instantiate(config.model, _recursive_=False)

iterator = iter(datamodule.test_dataloader())

  cpuset_checked))


In [9]:
z_dim_invariant_data = datamodule.train_dataset.dataset.z_dim_invariant
z_dim = datamodule.train_dataset.dataset.z_dim

In [3]:
batch = next(iterator)

In [4]:
batch["z"][(batch["domain"] == 1).squeeze(), :z_dim_invariant_data].min()

tensor(0.0984)

In [11]:
def _summary(values):
    """Render the min / max / mean of ``values`` in the report's text format."""
    return f"min:{values.min()}, max:{values.max()}, mean:{values.mean()}"


def _domain_rows(domain_idx):
    """Boolean row mask picking the samples of ``batch`` whose domain is ``domain_idx``."""
    return (batch["domain"] == domain_idx).squeeze()


# Section 1: the first `z_dim_invariant_data` latent columns, pooled per domain.
for domain_idx in range(datamodule.train_dataset.dataset.num_domains):
    stats = _summary(batch["z"][_domain_rows(domain_idx), :z_dim_invariant_data])
    print(f"domain:{domain_idx} -- {stats}")
print("----------------")

# Section 2: the remaining latent columns, pooled per domain.
for domain_idx in range(datamodule.train_dataset.dataset.num_domains):
    stats = _summary(batch["z"][_domain_rows(domain_idx), z_dim_invariant_data:])
    print(f"domain:{domain_idx} -- {stats}")
print("----------------")

# Section 3: the same remaining columns, one column at a time.
for domain_idx in range(datamodule.train_dataset.dataset.num_domains):
    for z_dim_spurious in range(z_dim_invariant_data, z_dim):
        stats = _summary(batch["z"][_domain_rows(domain_idx), z_dim_spurious])
        print(f"domain:{domain_idx}, z_dim:{z_dim_spurious} -- {stats}")


# # print(batch["z"][(batch["domain"] == 0).squeeze(), z_dim_invariant_data:].min())

# print(batch["z"][(batch["domain"] == 1).squeeze(), :z_dim_invariant_data].min())
# print(batch["z"][(batch["domain"] == 1).squeeze(), z_dim_invariant_data:].min())

# print(batch["z"][(batch["domain"] == 0).squeeze(), :z_dim_invariant_data].max())
# print(batch["z"][(batch["domain"] == 0).squeeze(), z_dim_invariant_data:].max())

# print(batch["z"][(batch["domain"] == 1).squeeze(), :z_dim_invariant_data].max())
# print(batch["z"][(batch["domain"] == 1).squeeze(), z_dim_invariant_data:].max())

# print(batch["z"][(batch["domain"] == 0).squeeze(), :z_dim_invariant_data].mean())
# print(batch["z"][(batch["domain"] == 0).squeeze(), z_dim_invariant_data:].mean())

# print(batch["z"][(batch["domain"] == 1).squeeze(), :z_dim_invariant_data].mean())
# print(batch["z"][(batch["domain"] == 1).squeeze(), z_dim_invariant_data:].mean())

----------------
domain:0, z_dim:2 -- min:-3.395689010620117, max:-1.9791346788406372, mean:-2.710878849029541
domain:0, z_dim:3 -- min:-1.2148573398590088, max:-0.8057138919830322, mean:-1.0142569541931152
domain:1, z_dim:2 -- min:-4.767006874084473, max:-2.539933681488037, mean:-3.676292896270752
domain:1, z_dim:3 -- min:-4.0764007568359375, max:-0.9848365783691406, mean:-2.4427056312561035
domain:2, z_dim:2 -- min:-1.7045924663543701, max:-0.3207516670227051, mean:-0.9372190237045288
domain:2, z_dim:3 -- min:1.9994860887527466, max:2.109677791595459, mean:2.0614607334136963
domain:3, z_dim:2 -- min:3.3331382274627686, max:4.46589469909668, mean:3.8830065727233887
domain:3, z_dim:3 -- min:-2.0340495109558105, max:-0.48385369777679443, mean:-1.1251734495162964
domain:4, z_dim:2 -- min:-2.998680353164673, max:-1.8430469036102295, mean:-2.490309715270996
domain:4, z_dim:3 -- min:-4.002238750457764, max:-3.2039177417755127, mean:-3.526283025741577
domain:5, z_dim:2 -- min:-4.054628849029

In [34]:
print(batch["x"][(batch["domain"] == 0).squeeze(), :5].min())
print(batch["x"][(batch["domain"] == 0).squeeze(), 5:].min())

print(batch["x"][(batch["domain"] == 1).squeeze(), :5].min())
print(batch["x"][(batch["domain"] == 1).squeeze(), 5:].min())

tensor(2.0189)
tensor(1.6085)
tensor(3.2615)
tensor(2.8083)


In [35]:
print(batch["x"][(batch["domain"] == 0).squeeze(), :5].max())
print(batch["x"][(batch["domain"] == 0).squeeze(), 5:].max())

print(batch["x"][(batch["domain"] == 1).squeeze(), :5].max())
print(batch["x"][(batch["domain"] == 1).squeeze(), 5:].max())

tensor(7.8686)
tensor(8.8918)
tensor(10.6636)
tensor(11.8312)


In [36]:
print(batch["x"][(batch["domain"] == 0).squeeze(), :5].mean())
print(batch["x"][(batch["domain"] == 0).squeeze(), 5:].mean())

print(batch["x"][(batch["domain"] == 1).squeeze(), :5].mean())
print(batch["x"][(batch["domain"] == 1).squeeze(), 5:].mean())

tensor(4.5176)
tensor(5.1726)
tensor(6.8904)
tensor(7.8361)


## Coloured MNIST Experiments

In [None]:
# golden standard
import numpy as np
import os
import math
import torch
import torchvision
from torchvision import transforms



# load MNIST
transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
        ])
root = "/network/datasets/torchvision"
data = torchvision.datasets.MNIST(root, True, transform=transform)

# color it

In [None]:
import matplotlib.pyplot as plt
plt.imshow(data[0][0].permute(1,2,0).repeat(1, 1, 3), cmap='gray')

In [None]:
# use hydra configs to load the dataset
import hydra
from utils import hydra_custom_resolvers
import torch
from torchvision import transforms
import torch.nn.functional as F
import numpy as np
from omegaconf import OmegaConf
import matplotlib.pyplot as plt
from PIL import Image as Image, ImageEnhance
configs_path = "configs"
# config_name = "train_root.yaml"
config_name = "evaluate_root.yaml"

with hydra.initialize(config_path=configs_path):
    config = hydra.compose(config_name=config_name,
                            overrides=[
                                "datamodule/dataset=mnist",
                                # "datamodule/dataset=multi_domain_mnist",
                                "+ckpt_path='/home/mila/s/sayed.mansouri-tehrani/MD-CRL/autoencoder_multi_domain_mnist_32-epoch=52-val_loss=0.01.ckpt'",
                                "model.autoencoder.num_channels=3",
                            ],
                            return_hydra_config=True,
    )

    transform = transforms.Compose([hydra.utils.instantiate(t) for _, t in config.datamodule.transforms.items()])
    def renormalize():
        """Build the function that undoes the configured image standardization.

        Scans ``config.datamodule.transforms``; if any transform's ``_target_``
        contains "Standardize", the returned callable maps ``x -> x / 2 + 0.5``
        (i.e. [-1, 1] back to [0, 1]). Otherwise the identity is returned, so the
        result can always be called (the loop used to fall through and return
        ``None``, which only failed later at the first call site).
        """
        for _, t in config.datamodule.transforms.items():
            if "Standardize" in t["_target_"]:
                return lambda x: x / 2.0 + 0.5

        # No standardizing transform found: leave images untouched.
        return lambda x: x
            
    # setup the dataset with the hydra config
    datamodule = hydra.utils.instantiate(config.datamodule, _recursive_=False)
    datamodule.prepare_data()
    datamodule.setup()

    # instantiate the model with hydra
    # model = hydra.utils.instantiate(config.model, _recursive_=False)

renormalize = datamodule.renormalize()
iterator = iter(datamodule.test_dataloader())

  cpuset_checked))


In [None]:
# Load the raw checkpoint dict for inspection. `map_location="cpu"` makes the load
# succeed on machines without the device the checkpoint was saved from; the
# comparison in the next cell is done on CPU anyway.
# NOTE(review): torch.load unpickles the file -- only use it on checkpoints you trust.
m = torch.load(
    "/home/mila/s/sayed.mansouri-tehrani/MD-CRL/autoencoder_multi_domain_mnist_32-epoch=52-val_loss=0.01.ckpt",
    map_location="cpu",
)

In [None]:
# Check, parameter by parameter, that the live model matches the checkpoint on disk.
live_state = model.state_dict()
for key, live_tensor in live_state.items():
    saved_tensor = m["state_dict"][key].to('cpu')
    print((live_tensor == saved_tensor).all())

In [None]:
model.model.hparams

"decoder_fc":   {'_target_': 'models.modules.fc_ae.Decoder', 'latent_size': 32, 'width': 28, 'height': 28, 'num_channels': 3, 'decoder_layers': {'Linear1': {'_target_': 'torch.nn.Linear', 'in_features': 32, 'out_features': 64}, 'LeakyReLU1': {'_target_': 'torch.nn.LeakyReLU'}, 'Dropout1': {'_target_': 'torch.nn.Dropout', 'p': 0.5}, 'Linear2': {'_target_': 'torch.nn.Linear', 'in_features': 64, 'out_features': 128}, 'LeakyReLU2': {'_target_': 'torch.nn.LeakyReLU'}, 'Dropout2': {'_target_': 'torch.nn.Dropout', 'p': 0.5}, 'Linear3': {'_target_': 'torch.nn.Linear', 'in_features': 128, 'out_features': 256}, 'LeakyReLU3': {'_target_': 'torch.nn.LeakyReLU'}, 'Dropout3': {'_target_': 'torch.nn.Dropout', 'p': 0.5}, 'Linear4': {'_target_': 'torch.nn.Linear', 'in_features': 256, 'out_features': 512}, 'LeakyReLU4': {'_target_': 'torch.nn.LeakyReLU'}, 'Dropout4': {'_target_': 'torch.nn.Dropout', 'p': 0.5}, 'Linear5': {'_target_': 'torch.nn.Linear', 'in_features': 512, 'out_features': 2352}, 'LeakyRe

In [None]:
batch = next(iterator)
img, labels = batch["image"], batch["label"]
# img, labels, domains, colors = batch["image"], batch["label"], batch["domain"], batch["color"]

In [None]:
img.shape

torch.Size([256, 1, 28, 28])

In [None]:
z, recons = model.model(img)

In [None]:
# For each of the 2 domains and each of the first 16 latent dimensions, collect the
# 5 smallest and the 5 largest values that dimension takes within the domain, then
# compare the two domains' extremes with an MSE.
n_domains, n_dims, top_k = 2, 16, 5
domain_z_mins = torch.zeros((n_domains, n_dims, top_k))
domain_z_maxs = torch.zeros((n_domains, n_dims, top_k))

for domain_idx in range(n_domains):
    # rows of z (shape [batch_size, latent_dim]) that belong to this domain
    domain_z = z[(domains == domain_idx).squeeze()]
    for i in range(n_dims):
        ascending, _ = torch.sort(domain_z[:, i], dim=0)
        descending, _ = torch.sort(domain_z[:, i], dim=0, descending=True)
        # head of the ascending sort = smallest values, head of the descending sort = largest
        domain_z_mins[domain_idx, i, :] = ascending.squeeze()[:top_k]
        domain_z_maxs[domain_idx, i, :] = descending.squeeze()[:top_k]

mse_mins, mse_maxs = (
    F.mse_loss(extremes[0], extremes[1], reduction="mean")
    for extremes in (domain_z_mins, domain_z_maxs)
)


In [None]:
# print(domain_z_mins[0])
# print(domain_z_mins[1])
# print(domain_z_maxs[0])
# print(domain_z_maxs[1])
mse_mins = F.mse_loss(domain_z_mins[0], domain_z_mins[1], reduction="sum")
print(mse_mins)

tensor(0.4542, grad_fn=<MseLossBackward0>)


In [None]:
import matplotlib.pyplot as plt

def clamp(x):
    """Clip every value of the array ``x`` to the closed interval [0, 1]."""
    return np.clip(x, 0, 1)

for i in range(10):
    plt.figure()
    plt.imshow(clamp(recons[i].detach().cpu().numpy()))

In [None]:
# iterator = iter(datamodule.train_dataloader())
iterator = iter(datamodule.test_dataloader())
# iterator = iter(datamodule.val_dataloader())
print(next(iterator)["image"].shape)
import matplotlib.pyplot as plt
for i in range(20,40):
    sample = next(iterator)
    sample_img, sample_label, sample_domain, sample_color = sample["image"][i], sample["label"][i], sample["domain"][i], sample["color"][i]
    print(f"sample_label: {sample_label}, sample_domain: {sample_domain}, sample_color: {sample_color}")
    plt.figure()
    plt.imshow(sample_img)

In [None]:
model(next(iterator)[0]).shape

In [None]:
datamodule.valid_dataset[0][0]

In [None]:
plt.matshow(datamodule.valid_dataset[0][0].squeeze())

In [None]:
plt.imshow(data[0][0].squeeze())

In [None]:
test_sample = data[0][0].permute(1,2,0).repeat(1, 1, 3)
test_sample
non_bg_pixels = test_sample[:, :, 0] > 0.0
# test_sample[non_bg_pixels, :] = test_sample[non_bg_pixels, :] * torch.tensor([1.0, 0.0, 0.0])
# test_sample[non_bg_pixels, :] = test_sample[non_bg_pixels, :] * torch.tensor([0.0, 1.0, 0.0])
test_sample[non_bg_pixels, :] = test_sample[non_bg_pixels, :] * torch.tensor([0.0, 0.0, 1.0])
test_sample[~non_bg_pixels, :] = torch.tensor([0.0, 0.0, 0.0])
plt.imshow(test_sample)
# should normalize again to be between -1,1 or -0.5 and 0.5
# plt.matshow(data[0][0].squeeze())

In [None]:
test_sample

In [None]:
data[0][0]

In [None]:
import matplotlib.pyplot as plt

plt.imshow(data[0][0])

In [None]:
type(data[0][0])

In [None]:
# dataset for all properties

import hydra
from slot_based_disentanglement.utils import hydra_custom_resolvers
import torch
import torch.nn.functional as F
import numpy as np
from omegaconf import OmegaConf
import pickle
import matplotlib.pyplot as plt
from PIL import Image as Image, ImageEnhance
from itertools import product
import scipy
configs_path = "configs"
# config_name = "config.yaml"
config_name = "evaluate_root.yaml" # "dummy.yaml" # "evaluate_root.yaml"
# config_name = "train_root.yaml" # "dummy.yaml" # "evaluate_root.yaml"
from slot_based_disentanglement.utils.lp_solver import lp_solver_cvxpy, lp_solver_pulp


n_balls = 4
from slot_based_disentanglement.utils import hydra_custom_resolvers
with hydra.initialize(config_path=configs_path):
    config = hydra.compose(config_name=config_name,
                            overrides=[
                                # "model=inertia_balls_saae_contrastive_recons",
                                # "model.encoder.slot_size=64",
                                # f"model.encoder.num_slots={n_balls+1}",
                                # "model.z_dim=2",
                                # "model.pl_model_ckpt_path=/home/mila/s/sayed.mansouri-tehrani/mechanism-based-disentanglement/disentanglement_by_mechanisms/slot_attention_inertia_balls_autoencoder_2_cyclic_fixed-epoch\=1796-train_loss\=0.00-Linear_Disentanglement\=0.00-Permutation_Disentanglement\=0.00.ckpt",
                                # "ckpt_path=null",
                                # "datamodule=inertia_balls",
                                # "datamodule/dataset=all_properties_sparse_offset",
                                # "datamodule.dataset.signed=False",
                                # "ckpt_path=/home/mila/s/sayed.mansouri-tehrani/mechanism-based-disentanglement/disentanglement_by_mechanisms/SA_inertia_balls_contrastive_recons_zdim_3_n_balls_2_cyclic_fixed_argmin_known_mech_False_sparsity_1-epoch\=394-train_loss\=0.18-Linear_Disentanglement\=0.37-Permutation_Disentanglement\=0.47.ckpt",
                                # "ckpt_path=/home/mila/s/sayed.mansouri-tehrani/mechanism-based-disentanglement/disentanglement_by_mechanisms/SA_inertia_balls_contrastive_recons_zdim_4_n_balls_2_cyclic_fixed_argmin_known_mech_False_sparsity_1-epoch\=174-train_loss\=0.33-Linear_Disentanglement\=0.01-Permutation_Disentanglement\=0.04.ckpt",
                                # z_dim=4, z_dis=4
                                # "model.signed_change=False",
                                # "model.z_dim=4",
                                # "ckpt_path=/home/mila/s/sayed.mansouri-tehrani/mechanism-based-disentanglement/disentanglement_by_mechanisms/SA_inertia_balls_contrastive_recons_zdim_4_n_balls_2_cyclic_fixed_argmin_known_mech_False_sparsity_1-epoch\=399-train_loss\=0.15-Linear_Disentanglement\=0.21-Permutation_Disentanglement\=0.27.ckpt"
                                
                                # z_dim=2, z_dis=2, load from ckpt
                                "model=inertia_balls_saae_contrastive_recons_ckpt",
                                f"model.num_slots={n_balls+1}",
                                "model.z_dim=2",
                                "+model.encoder_ckpt_path=null",
                                # frozen enc-dec
                                # "ckpt_path=/home/mila/s/sayed.mansouri-tehrani/mechanism-based-disentanglement/disentanglement_by_mechanisms/SA_inertia_balls_contrastive_recons_zdim_2_n_balls_2_cyclic_fixed_argmin_known_mech_False_sparsity_1_z_dis_2-epoch\=399-train_loss\=0.00-Linear_Disentanglement\=0.99-Permutation_Disentanglement\=0.99.ckpt",
                                # finetuned enc-dec
                                "ckpt_path=/home/mila/s/sayed.mansouri-tehrani/mechanism-based-disentanglement/disentanglement_by_mechanisms/SA_inertia_balls_contrastive_recons_zdim_2_n_balls_2_cyclic_fixed_argmin_known_mech_False_sparsity_1_z_dis_2-epoch\=399-train_loss\=0.01-Linear_Disentanglement\=1.00-Permutation_Disentanglement\=0.98.ckpt"

                                # z_dim=3, z_dis=3, load from ckpt
                                # "model=inertia_balls_saae_contrastive_recons_ckpt",
                                # f"model.num_slots={n_balls+1}",
                                # "model.z_dim=3",
                                # "+model.encoder_ckpt_path=null",
                                # frozen enc-dec
                                # "ckpt_path=/home/mila/s/sayed.mansouri-tehrani/mechanism-based-disentanglement/disentanglement_by_mechanisms/SA_inertia_balls_contrastive_recons_zdim_3_n_balls_2_cyclic_fixed_argmin_known_mech_False_sparsity_1_z_dis_3-epoch\=399-train_loss\=0.97-Linear_Disentanglement\=0.66-Permutation_Disentanglement\=0.67.ckpt",
                                # finetuned enc-dec
                                # "ckpt_path=/home/mila/s/sayed.mansouri-tehrani/mechanism-based-disentanglement/disentanglement_by_mechanisms/SA_inertia_balls_contrastive_recons_zdim_3_n_balls_2_cyclic_fixed_argmin_known_mech_False_sparsity_1_z_dis_3-epoch\=399-train_loss\=0.84-Linear_Disentanglement\=0.66-Permutation_Disentanglement\=0.67.ckpt",
                                
                                # z_dim=4, z_dis=4, load from ckpt
                                # "model=inertia_balls_saae_contrastive_recons_ckpt",
                                # f"model.num_slots={n_balls+1}",
                                # "model.z_dim=4",
                                # "+model.encoder_ckpt_path=null",
                                # frozen enc-dec
                                # "ckpt_path=/home/mila/s/sayed.mansouri-tehrani/mechanism-based-disentanglement/disentanglement_by_mechanisms/SA_inertia_balls_contrastive_recons_zdim_4_n_balls_2_cyclic_fixed_argmin_known_mech_False_sparsity_1_z_dis_4-epoch\=299-train_loss\=1.31-Linear_Disentanglement\=0.50-Permutation_Disentanglement\=0.52.ckpt",
                                # finetuned enc-dec
                                # "ckpt_path=/home/mila/s/sayed.mansouri-tehrani/mechanism-based-disentanglement/disentanglement_by_mechanisms/SA_inertia_balls_contrastive_recons_zdim_4_n_balls_2_cyclic_fixed_argmin_known_mech_False_sparsity_1_z_dis_4-epoch\=299-train_loss\=0.34-Linear_Disentanglement\=0.50-Permutation_Disentanglement\=0.46.ckpt",

                                # z_dim=3, z_dis=3, load from ckpt, position change + shape change (colour fixed)
                                # "model=inertia_balls_saae_contrastive_recons_ckpt",
                                # f"model.num_slots={n_balls+1}",
                                # "model.z_dim=3",
                                # "+model.encoder_ckpt_path=null",
                                # "ckpt_path=/home/mila/s/sayed.mansouri-tehrani/mechanism-based-disentanglement/disentanglement_by_mechanisms/SA_inertia_balls_contrastive_recons_zdim_3_n_balls_2_cyclic_fixed_argmin_known_mech_False_sparsity_1_z_dis_3-epoch\=16-train_loss\=1.26-Linear_Disentanglement\=0.66-Permutation_Disentanglement\=0.66.ckpt",

                            ],
                            return_hydra_config=True,
    )
    
    # model = hydra.utils.instantiate(config.model, _recursive_=False)

    from torchvision import transforms
    transform = transforms.Compose([hydra.utils.instantiate(t) for _, t in config.datamodule.transforms.items()])
    def renormalize():
        """Build the function that undoes the configured image standardization.

        Scans ``config.datamodule.transforms``; if any transform's ``_target_``
        contains "Standardize", the returned callable maps ``x -> x / 2 + 0.5``
        (i.e. [-1, 1] back to [0, 1]). Otherwise the identity is returned, so the
        result can always be called (the loop used to fall through and return
        ``None``, which only failed later at the first call site).
        """
        for _, t in config.datamodule.transforms.items():
            if "Standardize" in t["_target_"]:
                return lambda x: x / 2.0 + 0.5

            # TODO: add more options if required

        # No standardizing transform found: leave images untouched.
        return lambda x: x

device = 'cuda'
# model.to(device)
renormalize = renormalize()

In [None]:
n_balls = 3
z = np.zeros((n_balls, 6))
z[0, :2] = [0.5, 0.25]
z[1, :2] = [0., 0.4]
z[2, :2] = [0.25, 0.25]
# colour
z[0, 2] = 0
z[1, 2] = 1
z[2, 2] = 2
# shape
z[0, 3] = 0
z[1, 3] = 1
z[2, 3] = 2
# size
z[0, 4] = 0.1
z[1, 4] = 0.07
z[2, 4] = 0.1
# rotation angle
z[0, 5] = 0.0 # 0.2
z[1, 5] = -0.0 # 0.4
z[2, 5] = 0.0 # 0.4

hsv_colours = [COLOURS_[z[i,2].astype(int)] for i in range(z.shape[0])]
rgb_colours = [[255.*channel for channel in colorsys.hls_to_rgb(*c)] for c in hsv_colours]
x, seg_mask = draw_scene(z, rgb_colours)
x = transform(x)

In [None]:
plt.figure()
plt.imshow(renormalize(x.permute(1,2,0)))
plt.figure()
plt.imshow(seg_mask[3])

In [None]:
# dataset for all properties

import hydra
from slot_based_disentanglement.utils import hydra_custom_resolvers
import torch
import torch.nn.functional as F
import numpy as np
from omegaconf import OmegaConf
import pickle
import matplotlib.pyplot as plt
from PIL import Image as Image, ImageEnhance
from itertools import product
import scipy
configs_path = "configs"
# config_name = "config.yaml"
config_name = "evaluate_root.yaml" # "dummy.yaml" # "evaluate_root.yaml"
# config_name = "train_root.yaml" # "dummy.yaml" # "evaluate_root.yaml"
from slot_based_disentanglement.utils.lp_solver import lp_solver_cvxpy, lp_solver_pulp


n_balls = 4
from slot_based_disentanglement.utils import hydra_custom_resolvers
with hydra.initialize(config_path=configs_path):
    config = hydra.compose(config_name=config_name,
                            overrides=[
                                "datamodule=inertia_balls",
                                f"datamodule.n_balls={n_balls}",
                                f"datamodule.num_samples.train={10}",
                                f"datamodule.num_samples.valid={10}",
                                f"datamodule.num_samples.test={10}",
                                "datamodule/dataset=all_p_sparse_offset",
                                "datamodule.dataset.signed=True",
                                "datamodule.dataset.properties_list=['x','y','c','l','p']",
                                "datamodule.dataset.z_dim=5",
                            ],
                            return_hydra_config=True,
    )
    
    datamodule = hydra.utils.instantiate(config.datamodule, _recursive_=False)
    datamodule.prepare_data()
    datamodule.setup()

    from torchvision import transforms
    transform = transforms.Compose([hydra.utils.instantiate(t) for _, t in config.datamodule.transforms.items()])
    def renormalize():
        """Build the function that undoes the configured image standardization.

        Scans ``config.datamodule.transforms``; if any transform's ``_target_``
        contains "Standardize", the returned callable maps ``x -> x / 2 + 0.5``
        (i.e. [-1, 1] back to [0, 1]). Otherwise the identity is returned, so the
        result can always be called (the loop used to fall through and return
        ``None``, which only failed later at the first call site).
        """
        for _, t in config.datamodule.transforms.items():
            if "Standardize" in t["_target_"]:
                return lambda x: x / 2.0 + 0.5

            # TODO: add more options if required

        # No standardizing transform found: leave images untouched.
        return lambda x: x

renormalize = datamodule.renormalize()
iterator = iter(datamodule.test_dataloader())

In [None]:
batch = next(iterator)

In [None]:
z1,z2 = batch["latents"]
x1,x2 = batch["images"]
s1,s2 = batch["segmentation_masks"]

In [None]:
print(x2.min())
print(x2.max())

In [None]:
plt.matshow(x1[0].permute(1,2,0))

In [None]:
n = 4
for i in range(n):
    plt.figure()
    plt.imshow(x1[i].permute(1,2,0))
    plt.figure()
    plt.imshow(x2[i].permute(1,2,0))
# for i in range(0,1):
#     for j in range(4):
#         plt.figure()
#         plt.imshow(s1[i][j])
#         plt.figure()
#         plt.imshow(s2[i][j])

In [None]:
for i in range(0,5):
    plt.figure()
    plt.imshow(renormalize(x1[i].permute(1,2,0)))
    plt.figure()
    plt.imshow(renormalize(x2[i].permute(1,2,0)))

In [None]:
num_samples = 5
print(z1[:num_samples, :2])
print(z1[:num_samples, 6:8])
print(f"--------------------======--------------------")
print(z2[:num_samples, :2])
print(z2[:num_samples, 6:8])
print(f"--------------------======--------------------")
print((z2[:num_samples]-z1[:num_samples]).reshape(num_samples, n_balls, -1))
print((np.linalg.norm(z2[:num_samples, :2]-z1[:num_samples, :2])))

In [None]:
dists = np.linalg.norm(z1.reshape(z1.shape[0], -1, 6)[:num_samples, 0, :2]-z1.reshape(z1.shape[0], -1, 6)[:num_samples, 1, :2], ord=2, axis=1)
print(dists)
print(z1[:num_samples, 4] * 4)

In [None]:
import numpy as np
import pygame
import os
from pygame import gfxdraw
import colorsys
import math


if "SDL_VIDEODRIVER" not in os.environ:
    os.environ["SDL_VIDEODRIVER"] = "dummy"
    os.environ["SDL_AUDIODRIVER"] = "dsp"

# HSV colours
COLOURS_ = [
    [0.05, 0.6, 0.6],
    # [0.15, 0.6, 0.6],
    [0.25, 0.6, 0.6],
    # [0.35, 0.6, 0.6],
    [0.45, 0.6, 0.6],
    # [0.55, 0.6, 0.6],
    [0.65, 0.6, 0.6],
    # [0.75, 0.6, 0.6],
    [0.85, 0.6, 0.6],
    # [0.95, 0.6, 0.6],
]

SHAPES_ = [
    "circle",
    "square",
    "triangle",
    "heart"
]

SCREEN_DIM = 128
Y_SHIFT = 0.0

def draw_shape(
    x_,
    y_,
    surf,
    color=(204, 204, 0),
    radius=0.1,
    screen_width=SCREEN_DIM,
    y_shift=Y_SHIFT,
    offset=None,
    shape="circle",
    rotation_angle=0.
):
    """Draw one rotated shape onto ``surf`` and return a mask array for it.

    Args:
        x_, y_: shape position in screen fractions; converted to pixels as
            ``screen_width * value + offset``.
        surf: pygame surface the rotated shape is blitted onto (modified in place).
        color: colour passed straight to the pygame drawing calls.
        radius: shape size as a fraction of ``screen_width``.
        screen_width: side length in pixels of the (square) working surfaces.
        y_shift: accepted for interface compatibility; not used by this function.
        offset: pixel offset added to both coordinates; defaults to
            ``screen_width / 2`` when ``None``.
        shape: one of "circle", "square", "triangle", "heart". Any other value
            draws nothing (blank sprite and blank mask).
        rotation_angle: rotation in radians (converted with ``math.degrees``).

    Returns:
        Array of shape [screen_width, screen_width, 1] taken from the first channel
        of the vertically flipped mask surface.
    """
    if offset is None:
        offset = screen_width / 2
    scale = screen_width
    x = scale * x_ + offset
    y = scale * y_ + offset

    # Full-size surface used only to derive the returned mask.
    temp_surf = pygame.Surface((screen_width, screen_width), pygame.SRCALPHA)
    # Small sprite that gets rotated and blitted onto `surf`.
    # NOTE(review): the sprite is fixed at 20x20 px, so geometry larger than that
    # is presumably clipped before rotation -- confirm this is intended.
    temp_surf_rotation = pygame.Surface((20, 20), pygame.SRCALPHA)

    # The same geometry is drawn on both surfaces, anchored at their origin.
    # NOTE(review): the mask surface therefore never sees (x, y) or the rotation,
    # so the returned mask looks position-independent -- verify against callers.
    targets = (temp_surf_rotation, temp_surf)

    if shape == "circle":
        pixel_radius = int(radius * scale)
        for target in targets:
            gfxdraw.aacircle(target, 0, 0, pixel_radius, color)
            gfxdraw.filled_circle(target, 0, 0, pixel_radius, color)

    elif shape == "square":
        side = int(radius * scale) * 2
        corners = [(0, 0), (side, 0), (side, side), (0, side)]
        for target in targets:
            pygame.draw.polygon(surface=target, color=color, points=corners)

    elif shape == "triangle":
        side = (radius * scale) * 2
        corners = [(int(px), int(py)) for px, py in [(side // 2, side), (side, 0), (0, 0)]]
        for target in targets:
            pygame.draw.polygon(surface=target, color=color, points=corners)

    elif shape == "heart":
        size = (radius * scale) * 2
        s = 3.4  # 3.5
        k = 1.33
        # Two circular lobes plus a triangle.
        lobe_radius = int(size / s)
        lobe_a = (int(3 * size / (s * k)), int(size / (s * k) + size / 2))
        lobe_b = (int(size / (s * k)), int(size / (s * k) + size / 2))
        tip = [
            (int(np.floor(px)), int(np.floor(py)))
            for px, py in [
                (2 * size / (s * k), 0),
                (2 * size / (s * k) - size / 2.0, size / 30 + size / 2),
                (2 * size / (s * k) + size / 2.0, size / 30 + size / 2),
            ]
        ]
        for target in targets:
            pygame.draw.circle(surface=target, color=color, center=lobe_a, radius=lobe_radius)
            pygame.draw.circle(surface=target, color=color, center=lobe_b, radius=lobe_radius)
            pygame.draw.polygon(surface=target, color=color, points=tip)

    # Rotate the sprite once, centre it on (x, y), and composite it onto `surf`
    # through a transparent full-size layer.
    rotated_shape = pygame.transform.rotate(temp_surf_rotation, math.degrees(rotation_angle))
    rotated_rect = rotated_shape.get_rect(center=(int(x), int(y)))
    rotated_surf = pygame.Surface((screen_width, screen_width), pygame.SRCALPHA)
    rotated_surf.blit(rotated_shape, rotated_rect)
    surf.blit(rotated_surf, (0, 0))

    # Mask to surface, with the same vertical flip the caller applies to the screen.
    ball_mask = pygame.mask.from_surface(temp_surf)
    mask_surf = pygame.transform.flip(ball_mask.to_surface(), False, True)
    mask_surf.set_colorkey((0, 0, 0))

    # pixels3d is indexed [x, y, channel]; transpose to [row, col, channel], keep one channel.
    return np.transpose(np.array(pygame.surfarray.pixels3d(mask_surf)), axes=(1, 0, 2))[:, :, :1]  # [screen_width, screen_width, 1]
    
    # temp_surf_pos = (0,0)
    # ball_mask = pygame.mask.from_surface(temp_surf)

    # # mask -› surface
    # new_temp_surf = ball_mask.to_surface()
    # # do the same flip as the one occurring for the screen
    # new_temp_surf = pygame.transform.flip(new_temp_surf, False, True)
    # new_temp_surf.set_colorkey((0,0,0))

    # return np.transpose(np.array(pygame.surfarray.pixels3d(new_temp_surf)), axes=(1, 0, 2))[:, :, :1] # [screen_width, screen_width, 1]


In [None]:
pygame.init()
screen_dim = 128
ball_rad = 0.08
screen = pygame.display.set_mode((screen_dim, screen_dim))
surf = pygame.Surface((screen_dim, screen_dim))

def draw_scene(z, colours=None, human_mode=True):
    """Render all objects described by ``z`` and return the image plus masks.

    Args:
        z: array with one row per object. Columns read here: 0-1 position,
            3 index into ``SHAPES_``, 4 size, 5 rotation angle. A 1-D array is
            treated as a single object row.
        colours: one colour per object, passed to ``draw_shape``. When ``None``,
            every object gets ``COLOURS_[3]`` converted with
            ``colorsys.hls_to_rgb`` and scaled to 0-255, the same conversion the
            calling cells apply to palette entries.
        human_mode: when true, ``pygame.display.flip()`` is called after drawing.

    Returns:
        Tuple ``(image, masks)``: ``image`` is the screen pixel array transposed
        to [row, col, channel]; ``masks`` stacks the background mask followed by
        one mask per object (as returned by ``draw_shape``).

    Side effects:
        Repaints the module-level ``surf`` (and rebinds it to its vertically
        flipped copy) and blits it onto the module-level ``screen``.
    """
    global surf
    global screen
    surf.fill((255, 255, 255))
    bg_surf = pygame.Surface((screen_dim, screen_dim), pygame.SRCALPHA)

    if z.ndim == 1:
        # Keep every column: the loop below reads up to column 5, which the
        # previous fixed (1, 2) reshape could never provide.
        z = z.reshape((1, -1))
    if colours is None:
        import colorsys  # local import keeps this cell self-contained

        # COLOURS_ holds colour-space fractions, not RGB; convert like the callers do.
        default_rgb = [255. * channel for channel in colorsys.hls_to_rgb(*COLOURS_[3])]
        colours = [default_rgb] * z.shape[0]

    obj_masks = []
    for i in range(z.shape[0]):
        shape_kwargs = dict(
            color=colours[i],
            radius=z[i, 4],
            screen_width=screen_dim,
            y_shift=0.0,
            offset=0.0,
            shape=SHAPES_[int(z[i, 3])],
            rotation_angle=z[i, 5],
        )
        # Visible rendering; the returned array is this object's mask.
        obj_masks.append(draw_shape(z[i, 0], z[i, 1], surf, **shape_kwargs))
        # Same object on the transparent layer used to derive the background mask.
        draw_shape(z[i, 0], z[i, 1], bg_surf, **shape_kwargs)

    # Background mask, computed once after all objects are drawn. Previously this
    # ran on every loop iteration and was undefined when `z` had no rows.
    bg_mask = pygame.mask.from_surface(bg_surf)
    bg_mask.invert()  # so that mask bits for balls are cleared and the bg gets set.

    # Mask to surface, with the same vertical flip as the one applied to the screen.
    new_bg_surf = bg_mask.to_surface()
    new_bg_surf.set_colorkey((0, 0, 0))
    new_bg_surf = pygame.transform.flip(new_bg_surf, False, True)
    bg_mask = np.transpose(np.array(pygame.surfarray.pixels3d(new_bg_surf)), axes=(1, 0, 2))[:, :, :1]  # [screen_width, screen_width, 1]

    # ------------------------------------------ #
    surf = pygame.transform.flip(surf, False, True)
    screen.blit(surf, (0, 0))
    if human_mode:
        pygame.display.flip()
    return (
        np.transpose(
            np.array(pygame.surfarray.pixels3d(screen)), axes=(1, 0, 2)
        ),
        np.array([bg_mask] + obj_masks),
    )


In [None]:
n_balls = 2
z = np.zeros((n_balls, 6))
for i in range(n_balls):
    z[i, :2] = [0.5, 0.5]
    # colour
    z[i, 2] = i % 5
    # shape
    z[i, 3] = 3
    # size
    z[i, 4] = 0.16
    # rotation angle
    z[i, 5] = 0.4*i # 0.2


hsv_colours = [COLOURS_[z[i,2].astype(int)] for i in range(z.shape[0])]
rgb_colours = [[255.*channel for channel in colorsys.hls_to_rgb(*c)] for c in hsv_colours]
x, seg_mask = draw_scene(z, rgb_colours)
x = transform(x)
img = renormalize(x.permute(1,2,0))
img[int(128*0.5)-2:int(128*0.5)+2, int(128*0.5)-2:int(128*0.5)+2, :] = 0.8
plt.imshow(img)

In [None]:
import torch

# data = torch.load("/home/mila/s/sayed.mansouri-tehrani/train_dataset_xy_2000.pt")
data = torch.load("/home/mila/s/sayed.mansouri-tehrani/test_dataset_xy_100.pt")


In [None]:
x1,x2 = data[0]["images"]
import matplotlib.pyplot as plt
from PIL import Image as Image, ImageEnhance

plt.figure()
plt.imshow(x2.permute(1,2,0))

In [None]:
import torch.nn as nn
import torchvision.models as models

# Autoencoder with ResNet18 Encoder
class ResNet18Autoencoder(nn.Module):
    """Autoencoder with a pretrained ResNet18 encoder and a transposed-conv decoder.

    The encoder is torchvision's ResNet18 with its final fully connected layer
    replaced by ``Linear(512, 64)``. The 64 features are viewed as a 64-channel
    1x1 map and upsampled by six ``ConvTranspose2d`` layers; with the kernel /
    stride / padding values below the spatial size goes 1 -> 2 -> 2 -> 4 -> 8 -> 14 -> 28,
    ending in 3 channels squashed to [0, 1] by a sigmoid.
    """

    def __init__(self):
        super().__init__()

        # Pretrained backbone whose classification head is swapped for a 64-d projection.
        backbone = models.resnet18(pretrained=True)
        backbone.fc = nn.Linear(512, 64)
        self.encoder = backbone

        decoder_layers = [
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=1, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=2),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 8, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(8, 4, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(4, 4, kernel_size=2, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(4, 3, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),  # output range [0, 1] for images
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` to 64 features per sample and decode them back to an image."""
        latent = self.encoder(x)
        batch_size = latent.size(0)
        # Treat the feature vector as a 64-channel 1x1 feature map for the decoder.
        return self.decoder(latent.view(batch_size, 64, 1, 1))

In [None]:
o = ResNet18Autoencoder()

In [None]:
o(img.permute(0,3,1,2)).shape

torch.Size([256, 3, 28, 28])