# Vanilla model


In [63]:
from plot import training_log

# Identifier of the run under inspection; the model-loading cell below uses it
# to build the lightning_logs/version_<version>/ path.
version = "8"

# Entries handed to training_log together with the current run.
# NOTE(review): presumably these are the runs whose curves are compared — confirm
# against plot.training_log.
runs_to_plot = [version, "vae8_16_", 7]
training_log(version=runs_to_plot)

In [42]:
import os

import torch
import yaml
from models.lit_model import LitModel
from numpy import random
from data.dataset import MCSims
from plot.plot import plot_field_xy_from_tensor
import holoviews as hv

# Render holoviews objects with the bokeh backend.
hv.extension("bokeh")

# Device the input tensor is moved to before the forward pass further down.
device = "cpu"

# Log folder of the run selected by `version`.
path = f"lightning_logs/version_{version}/"

# Read the hyper-parameters stored next to the checkpoints of this run.
hparams_path = os.path.join(path, "hparams.yaml")
with open(hparams_path, "r") as hparams_file:
    config = yaml.safe_load(hparams_file)

def _checkpoint_index(filename):
    """Return the integer between the first two dashes of ``filename``, or None.

    For ``model-679-16.012.ckpt`` this is ``679`` (presumably the epoch — confirm
    against the checkpoint callback's filename template). Names that do not
    follow the ``<prefix>-<integer>-...`` pattern (for example ``last.ckpt``)
    yield ``None`` so that they can be skipped instead of crashing the lookup.
    """
    parts = filename.split("-")
    if len(parts) > 1 and parts[1].isdecimal():
        return int(parts[1])
    return None


# get the latest checkpoint from the path folder: keep only the .ckpt files
# that carry a numeric index and pick the one with the highest index.
numbered_checkpoints = [
    name
    for name in os.listdir(path)
    if name.endswith(".ckpt") and _checkpoint_index(name) is not None
]
if not numbered_checkpoints:
    # Fail with an explicit message rather than max()'s "empty sequence" error.
    raise FileNotFoundError(
        f"No checkpoint named like '<prefix>-<number>-*.ckpt' found in {path!r}"
    )
checkpoint = max(numbered_checkpoints, key=_checkpoint_index)
print(f"Loading checkpoint: {checkpoint}")

# map_location pins the restored weights to `device`, the same device the input
# tensor is moved to before the forward pass, regardless of where the
# checkpoint was saved.
litmodel = LitModel.load_from_checkpoint(
    os.path.join(path, checkpoint),
    map_location=device,
)

# Dataset built with augmentation and preprocessing switched off.
dataset = MCSims(augment=False, preprocess=False)
n = 864  # index of the sample to reconstruct

original = dataset[n].float()  # Convert to float
model = litmodel.model
model.eval()

# Inference only: skip autograd bookkeeping instead of building a graph and
# detaching the result afterwards.
with torch.no_grad():
    # unsqueeze(0) adds a leading dimension of size 1 (presumably the batch
    # axis) and squeeze(0) removes it again; the first element of the model
    # output is used as the reconstruction.
    model_output = model(original.unsqueeze(0).to(device))
    reconstruct = model_output[0].squeeze(0).cpu()

# Original and reconstruction combined into one holoviews layout.
plot_field_xy_from_tensor(original) + plot_field_xy_from_tensor(reconstruct)


Loading checkpoint: model-679-16.012.ckpt



Can't initialize NVML



BokehModel(combine_events=True, render_bundle={'docs_json': {'626f39fe-7a93-4e4c-8602-f87c9a05ecb7': {'version…

In [38]:
from models.lit_model import LitModel, MCSimsDataModule

# NOTE(review): this rebinds `dataset`, which the model-loading cell used for the
# MCSims dataset, to a data module — a distinct name would be clearer.
dataset = MCSimsDataModule(batch_size=256, num_workers=4)

# Encode the test split with the model taken from the loaded checkpoint.
test_loader = dataset.test_dataloader()
encoded_data = litmodel.encode_data(model, test_loader)

In [56]:
import torch
import hdbscan

from cluster_acc import adj_rand_index, purity

# from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture, BayesianGaussianMixture
from plot.plot import plot_H_vs_T_with_hover

# Alternative clusterers tried on the same encodings (kept for reference):
# labels = hdbscan.HDBSCAN(min_cluster_size=15, cluster_selection_epsilon=0.06).fit_predict(encoded_data)
# labels = GaussianMixture(n_components=6, covariance_type='diag').fit_predict(encoded_data)

# The mixture is initialised randomly; a fixed random_state makes the cluster
# labels, and therefore the scores printed below, identical on every re-run.
labels = BayesianGaussianMixture(
    n_components=7,
    covariance_type="spherical",
    random_state=0,
).fit_predict(encoded_data)

# Keep every 10th label, 261 in total.
# NOTE(review): presumably this aligns the predictions with the reference
# labels used inside cluster_acc — confirm where 10 and 261 come from.
predicted_labels = [labels[10 * i] for i in range(261)]
# calculate the Adjusted Rand Index
ari = adj_rand_index(predicted_labels)
print(f"Adjusted Rand Index: {ari:.4f}")

# Store the score under its own name: assigning it to `purity` would replace
# the imported function with a float for the rest of the session.
purity_score = purity(predicted_labels)
print(f"Purity Score: {purity_score:.4f}")

plot_H_vs_T_with_hover(labels=labels)

Adjusted Rand Index: 0.5958
Purity Score: 0.8736


FigureWidget({
    'data': [{'customdata': {'bdata': ('AAABAAIAAwAEAAUABgAHAAgACQAKAA' ... 'sJ7AntCe4J7wnwCfEJ8gnzCfQJ9Qk='),
                             'dtype': 'i2'},
              'hovertemplate': 'T: %{x}<br>H: %{y}<extra></extra>',
              'marker': {'color': '#2E91E5', 'size': 10},
              'mode': 'markers',
              'name': 'Cluster 2',
              'type': 'scatter',
              'uid': 'f3eb4f11-5a1e-4f2f-b1b6-c5b5aea6a291',
              'x': {'bdata': ('AAAAAADQoUAAAAAAANChQAAAAAAA0K' ... 'AAABilQAAAAAAAGKVAAAAAAAAYpUA='),
                    'dtype': 'f8'},
              'y': {'bdata': ('AAAAAAAAAADEVz1FvxW/QMZXPUW/Fc' ... '5zT1AXQTL3AnGmzBdBkuwXbv1IGEE='),
                    'dtype': 'f8'}},
             {'customdata': {'bdata': ('GgAbABwAHQAeAB8AIAAhACIAIwAkAC' ... '8JIAkhCSIJIwkkCSUJJgknCSgJKQk='),
                             'dtype': 'i2'},
              'hovertemplate': 'T: %{x}<br>H: %{y}<extra></extra>',
              'marker': {'color': '#E15F