In [1]:
import json
import os
import random

import mlflow
import numpy as np
import torch
import torch.nn.functional as F
import yaml

from model.VAE.pl_VAE import PlVAE
from dataset.datasetH5 import HDF5Dataset


  from .autonotebook import tqdm as notebook_tqdm


# Parameters — checkpoint and config paths of the trained run to export

In [2]:
# Checkpoint and config of the LES training run.
# NOTE(review): the following cell reassigns both paths, so this run is not the one loaded.
les_run_dir = "/projects/pnria/julien/autofill/runs/grid_ag_les_beta0.0001_etamin1e-07_ld64_bs64"
path_checkpoint = f"{les_run_dir}/epoch=73-step=1593590.ckpt"
path_config = f"{les_run_dir}/config_model.yaml"

In [3]:
# Checkpoint and config of the SAXS training run (the run used by the rest of the notebook).
saxs_run_dir = "/projects/pnria/julien/autofill/runs/grid_ag_saxs_beta0.0001_ld64_bs16"
path_checkpoint = f"{saxs_run_dir}/epoch=197-step=517968.ckpt"
path_config = f"{saxs_run_dir}/config_model.yaml"

In [4]:


# Load the configuration (dataset, model and training sections) saved with the run.
with open(path_config) as config_file:
    config = yaml.safe_load(config_file)



In [5]:
# Inspect the configuration loaded from the run directory.
config

{'dataset': {'conversion_dict_path': '/projects/pnria/DATA/AUTOFILL/v2/all_data_saxs_v2.json',
  'hdf5_file': '/projects/pnria/DATA/AUTOFILL/v2/all_data_saxs_v2.h5',
  'metadata_filters': {'material': ['ag'],
   'technique': ['saxs'],
   'type': ['simulation']},
  'requested_metadata': ['shape', 'material'],
  'sample_frac': 1.0,
  'transform': {'q': {'PaddingTransformer': {'pad_size': 500, 'value': 0}},
   'y': {'MinMaxNormalizer': {},
    'PaddingTransformer': {'pad_size': 500, 'value': 0}}}},
 'devices': '1',
 'experiment_name': 'grid_ag_saxs_beta0.0001_ld64_bs16',
 'model': {'args': {'dilation': 1,
   'down_channels': [16, 32, 64, 128, 256, 512],
   'in_channels': 1,
   'input_dim': 500,
   'latent_dim': 64,
   'output_channels': 1,
   'strat': 'y',
   'up_channels': [512, 256, 128, 64, 32, 16]},
  'output_transform_log': True,
  'vae_class': 'ResVAE'},
 'name': 'saxs_ag_gpu1',
 'training': {'T_max': 200,
  'batch_size': 16,
  'beta': 0.0001,
  'eta_min': 1e-08,
  'max_lr': 0.0001,

In [6]:
# Select the VAE implementation explicitly (the saved config already names "ResVAE").
config["model"].update(vae_class="ResVAE")

In [7]:
# Extra metadata fields to return with every dataset item (and export later).
# Only append fields that are not requested yet, so re-running this cell does
# not accumulate duplicate entries in the list.
for extra_field in ("concentration", "d", "h", "opticalPathLength"):
    if extra_field not in config["dataset"]["requested_metadata"]:
        config["dataset"]["requested_metadata"].append(extra_field)

In [8]:
# Check which metadata fields the dataset will return for each item.
config["dataset"]["requested_metadata"]

['shape', 'material', 'concentration', 'd', 'h', 'opticalPathLength']

In [9]:
# Read samples from this HDF5 file instead of the path stored in the config
# (relative path, resolved against the notebook's working directory).
config["dataset"].update(hdf5_file="data_vae_saxs.h5")

In [10]:
# Use the categorical-code mapping that matches the HDF5 file chosen for this export
# (relative path, resolved against the notebook's working directory).
config["dataset"].update(conversion_dict_path="conversion_dict_vae_saxs.json")

In [11]:
# Build the VAE wrapper from the (modified) config; the architecture summary
# shown below is emitted during construction. Weights are loaded in a later cell.
model = PlVAE(config)

VAE Architecture:
	Input Dimension: 500
	Latent Dimension: 64
	In Channels: 1
	Down Channels: [16, 32, 64, 128, 256, 512]
	Up Channels: [512, 256, 128, 64, 32, 16]
	Output Channels: 1
	Flattened Size: 2048
	Encoder Architecture: Sequential(
  (0): ResidualBlock(
    (conv1): Conv1d(1, 16, kernel_size=(3,), stride=(2,), padding=(1,))
    (relu): ReLU()
    (conv2): Conv1d(16, 16, kernel_size=(3,), stride=(1,), padding=(1,))
    (skip_connection): Conv1d(1, 16, kernel_size=(1,), stride=(2,))
  )
  (1): ResidualBlock(
    (conv1): Conv1d(16, 32, kernel_size=(3,), stride=(2,), padding=(1,))
    (relu): ReLU()
    (conv2): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,))
    (skip_connection): Conv1d(16, 32, kernel_size=(1,), stride=(2,))
  )
  (2): ResidualBlock(
    (conv1): Conv1d(32, 64, kernel_size=(3,), stride=(2,), padding=(1,))
    (relu): ReLU()
    (conv2): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))
    (skip_connection): Conv1d(32, 64, kernel_size=(1,)

In [12]:
# Load the trained weights into the model.
# The model is constructed without any `.to(device)` call, so its parameters
# live on the CPU; map the checkpoint tensors straight to the CPU rather than
# staging them on the GPU only to have `load_state_dict` copy them back.
checkpoint = torch.load(path_checkpoint, map_location="cpu")
load_result = model.load_state_dict(checkpoint["state_dict"])
model.eval()  # weights are final from here on: switch to inference mode
load_result

<All keys matched successfully>

In [15]:
# Build the dataset from the settings in config["dataset"]; each key below is
# passed through unchanged as the HDF5Dataset argument of the same name.
dataset_cfg = config["dataset"]
dataset_kwargs = {
    arg_name: dataset_cfg[arg_name]
    for arg_name in (
        "hdf5_file",
        "metadata_filters",
        "conversion_dict_path",
        "sample_frac",
        "transform",
        "requested_metadata",
    )
}
dataset = HDF5Dataset(**dataset_kwargs)
print("========================================")

['concentration', 'd', 'h', 'material', 'method', 'opticalPathLength', 'pair_index', 'shape', 'technique', 'type']


Applying filters: 100%|██████████| 3/3 [00:00<00:00, 164.59it/s]






In [16]:
# Number of items left after applying config["dataset"]["metadata_filters"].
len(dataset)

741

In [17]:
# Spot-check a raw metadata value: the 'd' field stored for item 45.
dataset.metadata_datasets['d'][45]

33.0

In [18]:
# Inspect one item's metadata as returned by the dataset; categorical fields
# (e.g. 'shape', 'material') come back as numeric codes, decoded via conversion_dict.
dataset[6]["metadata"]

{'shape': tensor(0., dtype=torch.float64),
 'material': tensor(1., dtype=torch.float64),
 'concentration': tensor(9.4200e+12, dtype=torch.float64),
 'd': tensor(66., dtype=torch.float64),
 'h': tensor(50., dtype=torch.float64),
 'opticalPathLength': tensor(-1., dtype=torch.float64)}

In [19]:
# Load the label -> code mapping used to encode categorical metadata.
conversion_dict_path = config["dataset"]["conversion_dict_path"]
with open(conversion_dict_path) as json_file:
    conversion_dict = json.load(json_file)

In [20]:
# Label -> code mappings for each categorical metadata field.
conversion_dict

{'material': {'au': 0, 'ag': 1, 'silica': 2, 'sio2': 3},
 'type': {'simulation': 0, 'experimental': 1},
 'method': {'pysaxs': 0, 'chemsaxs': 1},
 'shape': {'cylinder': 0, 'sphere': 1, 'cube': 2},
 'technique': {'saxs': 0}}

In [21]:
# Export a random subset of VAE reconstructions. For each sampled item write:
#   sample_XXXX.txt  - one "q reconstructed_y" pair per line, with y mapped back
#                      from the min-max normalised scale using the item's min/max;
#   sample_XXXX.yaml - the item's metadata with categorical codes decoded to labels.
output_dir = "/projects/pnria/caroline/export_vae_saxs"
N_EXPORT = 500  # number of items to export (capped at len(dataset))
SEED = 42

os.makedirs(output_dir, exist_ok=True)

# conversion_dict maps label -> code per field; invert each mapping once here
# instead of rebuilding it for every metadata key of every sample.
inverse_conversion = {
    field: {code: label for label, code in mapping.items()}
    for field, mapping in conversion_dict.items()
}


def _batch_inputs(item, device):
    """Add a leading batch dimension to the `data_*` tensors (min/max stats excluded)."""
    return {
        key: value.unsqueeze(0).to(device)
        if ("data_" in key and "min" not in key and "max" not in key)
        else value
        for key, value in item.items()
    }


def _to_numpy(value):
    """Return a tensor (any device) or array-like as a NumPy array."""
    if torch.is_tensor(value):
        return value.detach().cpu().numpy()
    return np.asarray(value)


def _convert_metadata(metadata):
    """Turn metadata tensors into Python scalars and decode categorical codes."""
    converted = {}
    for key, value in metadata.items():
        scalar = value.cpu().numpy().item()
        # Codes come back as floats (e.g. 1.0) while the JSON keys are ints;
        # the lookup still matches because 1.0 == 1 and both hash alike.
        converted[key] = inverse_conversion.get(key, {}).get(scalar, scalar)
    return converted


random.seed(SEED)
torch.manual_seed(SEED)  # in case the forward pass draws random numbers
indices = random.sample(range(len(dataset)), min(N_EXPORT, len(dataset)))

model.eval()  # inference mode for any layer whose behaviour depends on `self.training`
device = next(model.parameters()).device

with torch.no_grad():  # no autograd graph needed for export
    for file_idx, data_idx in enumerate(indices):
        data_item = dataset[data_idx]
        output = model(_batch_inputs(data_item, device))

        # NOTE(review): data_q / data_y are padded to `pad_size` with zeros per the
        # config transforms, so trailing rows may be padding — confirm whether
        # they should be trimmed before export.
        q_values = _to_numpy(data_item["data_q"]).flatten()
        recon_norm = _to_numpy(output["recon"]).flatten()
        y_min = _to_numpy(data_item["data_y_min"])
        y_max = _to_numpy(data_item["data_y_max"])
        recon_values = np.ravel(recon_norm * (y_max - y_min) + y_min)

        stem = os.path.join(output_dir, f"sample_{file_idx:04d}")
        with open(f"{stem}.txt", "w") as txt_file:
            txt_file.writelines(
                f"{q_val} {recon_val}\n" for q_val, recon_val in zip(q_values, recon_values)
            )
        with open(f"{stem}.yaml", "w") as yaml_file:
            yaml.dump(_convert_metadata(data_item["metadata"]), yaml_file)


yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
yea
