In [1]:
import torch
import ray
from ray import tune
from ray.tune.logger import DEFAULT_LOGGERS

import json
import sys
from pathlib import Path

# TODO Needed to install tensorboard separately
from ghrp.model_definitions.def_simclr_ae_module import SimCLRAEModule
from ghrp.checkpoints_to_datasets.dataset_simclr import SimCLRDataset

# set which hyper-representation to load

PATH_ROOT = Path("./../data/hyper_representations/mnist")
# PATH_ROOT = Path("./../data/hyper_representations/svhn")
# PATH_ROOT = Path("./../data/hyper_representations/cifar10")
# PATH_ROOT = Path("./../data/hyper_representations/stl10")

# load config
config_path = PATH_ROOT.joinpath('config_ae.json')
config = json.load(config_path.open('r'))
config['dataset::dump'] = PATH_ROOT.joinpath('dataset.pt').absolute()

# set resources
gpus = 1 if torch.cuda.is_available() else 0
cpus = 4
resources_per_trial = {"cpu": cpus, "gpu": gpus}

device = torch.device('cuda') if gpus>0 else torch.device('cpu')
config['device'] = device

# Instanciate model
module = SimCLRAEModule(config)

from weight_diffusion.data.modelzoo_with_latent_dataset import ModelZooDataset

ds = ModelZooDataset(
    data_dir=Path("./../data/tune_zoo_mnist_uniform/"),
    checkpoint_property_of_interest="validation_loss",
    openai_coefficient=4.185,
    split="train",
    encoder=module,
    device=device
)


init attn encoder
## encoder -- use index_dict
model: use only positive contrast loss
initialze projection head


Loading Models:   0%|          | 0/15 [00:00<?, ?it/s]

compute 2 random permutations for layer 0 - 0
compute 2 random permutations for layer 1 - 3
compute 2 random permutations for layer 2 - 6
compute 2 random permutations for layer 3 - 9
prepared 10 permutations
prepare permutation dicts
0
../data/tune_zoo_mnist_uniform/NN_tune_trainable_c0371_00549_549_seed=550_2021-07-02_19-38-33/checkpoint_000050/checkpoints_latent_rep
1
../data/tune_zoo_mnist_uniform/NN_tune_trainable_c0371_00549_549_seed=550_2021-07-02_19-38-33/checkpoint_000050/checkpoints_latent_rep
10
5
2
../data/tune_zoo_mnist_uniform/NN_tune_trainable_c0371_00549_549_seed=550_2021-07-02_19-38-33/checkpoint_000050/checkpoints_latent_rep
10
9
3
../data/tune_zoo_mnist_uniform/NN_tune_trainable_c0371_00549_549_seed=550_2021-07-02_19-38-33/checkpoint_000050/checkpoints_latent_rep
10
2
4
../data/tune_zoo_mnist_uniform/NN_tune_trainable_c0371_00549_549_seed=550_2021-07-02_19-38-33/checkpoint_000050/checkpoints_latent_rep
10
4
5
../data/tune_zoo_mnist_uniform/NN_tune_trainable_c0371_005




IndexError: list index out of range

In [2]:
ds.data_sample

# TODO Test if flattened checkpoint is same using the two methods

(OrderedDict([('module_list.0.weight',
               tensor([[[[0.6521, 0.0107, 0.1979, 0.7808, 0.2648],
                         [0.3922, 0.5463, 0.7784, 0.2338, 0.7042],
                         [0.8604, 0.5623, 0.0874, 0.1844, 0.8295],
                         [0.3164, 0.0144, 0.0860, 0.6312, 0.0269],
                         [0.9355, 0.2814, 0.9490, 0.4393, 0.9571]]],
               
               
                       [[[0.8417, 0.6177, 0.4498, 0.3087, 0.0874],
                         [0.5457, 0.1713, 0.4794, 0.7438, 0.9316],
                         [0.3874, 0.2327, 0.2339, 0.3350, 0.5142],
                         [0.4129, 0.8936, 0.6974, 0.7713, 0.9553],
                         [0.0425, 0.7793, 0.2272, 0.0851, 0.5085]]],
               
               
                       [[[0.1674, 0.3356, 0.8585, 0.8096, 0.6052],
                         [0.5438, 0.1096, 0.3534, 0.7434, 0.9908],
                         [0.6399, 0.4768, 0.9893, 0.1435, 0.2780],
                      

In [2]:
# load checkpoint
checkpoint_path = PATH_ROOT.joinpath('checkpoint_ae.pt')
checkpoint = torch.load(checkpoint_path,map_location=device)

# load checkpoint to model
module.model.load_state_dict(checkpoint)

# load dataset
dataset_path = PATH_ROOT.joinpath('dataset.pt')
dataset = torch.load(dataset_path)

# get test weights
weights_test = dataset['testset'].__get_weights__()
print(weights_test.shape)
print(weights_test)
print(weights_test[754])

# forward propagate test weights
with torch.no_grad():
    z, y = module.forward(weights_test.to(device))

# z are the latent representations, y the reconstructed weights
print(z.shape)
print(y.shape)

torch.Size([755, 2464])
tensor([[ 0.1057,  0.5063,  0.7677,  ...,  0.1089, -0.2859, -0.0415],
        [ 0.1394,  0.5325,  0.7727,  ...,  0.1310, -0.3042, -0.0491],
        [ 0.1374,  0.5240,  0.7612,  ...,  0.1481, -0.3287, -0.0630],
        ...,
        [-0.6401, -0.2447,  0.1664,  ..., -0.0060, -0.1384, -0.1888],
        [-0.6525, -0.2684,  0.1314,  ...,  0.0011, -0.1421, -0.2065],
        [-0.6510, -0.2968,  0.0874,  ...,  0.0124, -0.1574, -0.2247]])
tensor([-0.6510, -0.2968,  0.0874,  ...,  0.0124, -0.1574, -0.2247])
torch.Size([755, 700])
torch.Size([755, 2464])


In [3]:
print(y[0].shape)
print(y[1].shape)
print([y[0], y[1]])

ls = y[0].tolist()
print(ls)
test = torch.tensor([ls])
print(test)

# TODO Check if this operation loses floating point info

torch.Size([2464])
torch.Size([2464])
[tensor([-0.1885,  0.1287,  0.4737,  ...,  0.0332, -0.0883, -0.0474]), tensor([-0.1688,  0.1368,  0.4810,  ...,  0.0360, -0.0885, -0.0518])]
[-0.18846958875656128, 0.1287471204996109, 0.4737478196620941, -0.04170406609773636, -0.6226829290390015, 0.45010483264923096, 1.2612406015396118, 0.8412647247314453, 0.18783283233642578, -0.5810459852218628, 0.2705056667327881, 0.7415896654129028, 0.08787162601947784, 0.25182032585144043, -0.20636172592639923, -0.18356674909591675, 0.08819474279880524, 0.06885562092065811, 0.15804779529571533, -0.6367818713188171, -0.6543688178062439, -0.31292828917503357, -0.11141198873519897, -0.07415363192558289, -0.6733338832855225, -0.3693573772907257, -0.14116783440113068, -0.15709049999713898, 0.1236235499382019, 0.11098845303058624, 0.12124721705913544, 0.05510734021663666, 0.1836421638727188, 0.49008700251579285, 0.5012267231941223, 0.4149497151374817, 1.1869255304336548, 1.3410305976867676, 0.8311110734939575, 0.577