# Introduction

This notebook walks through the steps to use pre-trained hyper-representation models: instantiate a model, load its checkpoint, load the dataset, and run a forward pass.
Before running it, install the ghrp package by running `pip3 install .` in the repository's main directory, and download the data by running `bash download_data.sh` in the `data/` directory.

In [None]:
import torch
import ray
from ray import tune
from ray.tune.logger import DEFAULT_LOGGERS

import json
import sys
from pathlib import Path

from ghrp.model_definitions.def_simclr_ae_module import SimCLRAEModule
from ghrp.checkpoints_to_datasets.dataset_simclr import SimCLRDataset

In [None]:
# set which hyper-representation to load:
# one of "mnist", "svhn", "cifar10" or "stl10"
DATASET_NAME = "mnist"
PATH_ROOT = Path("./../data/hyper_representations").joinpath(DATASET_NAME)

In [None]:
# load config
# Read the autoencoder config with a context manager so the file handle is
# closed as soon as the JSON has been parsed.
config_path = PATH_ROOT.joinpath('config_ae.json')
with config_path.open('r') as config_file:
    config = json.load(config_file)
# Absolute path of the serialized dataset next to the config; presumably
# consumed by SimCLRAEModule — TODO confirm.
config['dataset::dump'] = PATH_ROOT.joinpath('dataset.pt').absolute()

In [None]:
# set resources: one GPU if CUDA is available (otherwise none), four CPUs
gpus = int(torch.cuda.is_available())
cpus = 4
resources_per_trial = dict(cpu=cpus, gpu=gpus)

# run on the GPU whenever one was allotted above, else fall back to the CPU
device = torch.device('cuda' if gpus > 0 else 'cpu')
config['device'] = device

In [None]:
# Instantiate the model wrapper from the loaded config (which by this point
# also carries the 'device' entry and the 'dataset::dump' path).
module = SimCLRAEModule(config)

In [None]:
# load checkpoint: read the trained autoencoder weights from disk,
# mapping every tensor onto the device selected above
checkpoint_path = PATH_ROOT / 'checkpoint_ae.pt'
checkpoint = torch.load(checkpoint_path, map_location=device)

In [None]:
# load checkpoint to model
# Copy the saved parameters into the wrapped model. `module.model` exposes
# `load_state_dict`, so it is presumably a torch.nn.Module; in that case the
# default strict=True raises on missing/unexpected keys.
module.model.load_state_dict(checkpoint)

In [None]:
# load dataset
# dataset.pt appears to hold pickled dataset objects rather than plain tensors
# (entries expose methods such as `__get_weights__()`), so it must be fully
# unpickled. NOTE(review): full unpickling can execute arbitrary code — only do
# this for trusted files such as the ones fetched by download_data.sh.
import inspect

dataset_path = PATH_ROOT.joinpath('dataset.pt')
# Load onto the CPU so a file saved with CUDA tensors also opens on CPU-only
# machines; the weights are moved to `device` explicitly before the forward pass.
load_kwargs = {'map_location': 'cpu'}
# torch >= 2.6 defaults to weights_only=True, which rejects non-tensor objects.
# The argument only exists from torch 1.13 on, so pass it only when supported.
if 'weights_only' in inspect.signature(torch.load).parameters:
    load_kwargs['weights_only'] = False
dataset = torch.load(dataset_path, **load_kwargs)

In [None]:
# get test weights
# `dataset` is indexed like a dict; its 'testset' entry provides the project's
# `__get_weights__()` helper. Presumably this returns all test-set model weights
# as a single tensor (it is moved with `.to(device)` below) — TODO confirm shape.
weights_test = dataset['testset'].__get_weights__()

In [None]:
# forward propagate test weights
# Put the loaded model into evaluation mode first so layers such as dropout or
# batch norm (if the architecture has any) use inference behaviour; no_grad
# skips building the autograd graph since no gradients are needed.
module.model.eval()
with torch.no_grad():
    z, y = module.forward(weights_test.to(device))

In [None]:
# z holds the latent representations, y the reconstructed weights;
# print the shape of each, latents first
for out in (z, y):
    print(out.shape)