In [1]:
%load_ext autoreload
%load_ext tensorboard
%autoreload 2

In [2]:
# System imports
import os
import sys
import yaml

# External imports
import matplotlib.pyplot as plt
import scipy as sp
from sklearn.decomposition import PCA
from sklearn.metrics import auc
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from pytorch_lightning.loggers import TensorBoardLogger
import nersc_tensorboard_helper

# Make the directory one level above the working directory importable; the
# project-local `LightningModules` package imported below is expected to live there.
# Insert it at the front (so the local checkout wins over any same-named installed
# package) and only once, so re-running this cell does not keep growing sys.path.
PROJECT_ROOT = os.path.abspath("..")
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
from LightningModules.Embedding.Models.layerless_embedding import LayerlessEmbedding

ModuleNotFoundError: No module named 'LightningModules.Embedding.Models.layerless_embedding'

In [None]:
# Experiment name: used for the TensorBoard run name and the saved figure filenames.
run_name = "high_warmup"

# Load the model hyperparameters from the example config. safe_load only builds plain
# Python types (dicts, lists, strings, numbers), which is all a config mapping needs,
# and unlike yaml.load it cannot instantiate arbitrary Python objects from YAML tags.
with open("example_embedding.yaml") as f:
    hparams = yaml.safe_load(f)

model = LayerlessEmbedding(hparams)
# Call the module's setup hook for the "fit" stage up front.
# NOTE(review): presumably this prepares the datasets — confirm in LayerlessEmbedding.
model.setup(stage="fit")

In [None]:
from pytorch_lightning import Trainer

# Log under tb_logs/embedding_<run_name>; the %tensorboard cell reads tb_logs/.
logger = TensorBoardLogger("tb_logs", name="embedding_" + run_name)

# Request a GPU only when PyTorch can see one: a hard-coded gpus=1 makes Trainer
# raise on a CPU-only machine instead of training on the CPU.
n_gpus = 1 if torch.cuda.is_available() else 0
trainer = Trainer(gpus=n_gpus, max_epochs=10, logger=logger)
trainer.fit(model)

In [None]:
from LightningModules.Embedding.utils import get_metrics

# Baseline test pass using the model's current r_test setting.
# NOTE(review): which weights Lightning evaluates with ckpt_path=None (current vs. best
# checkpoint) depends on the installed pytorch_lightning version — confirm.
test_results = trainer.test(ckpt_path=None)

# Put the module in eval mode for the radius sweep that follows; ';' hides the repr.
model.eval();

In [None]:
# Sweep the neighbourhood radius used at test time (model.hparams.r_test) and record
# the mean efficiency and purity that get_metrics reports for each value.
# np.linspace yields exactly the intended grid 0.5, 0.6, ..., 1.1 (7 points); np.arange
# with a non-integer step such as 0.1 can gain or lose an endpoint through rounding.
all_radius = np.linspace(0.5, 1.1, 7)
all_efficiencies, all_purities = [], []

with torch.no_grad():
    for r in all_radius:
        # Store a builtin float rather than a numpy scalar in the hyperparameters.
        model.hparams.r_test = float(r)
        test_results = trainer.test(ckpt_path=None)

        mean_efficiency, mean_purity = get_metrics(test_results, model)
        all_efficiencies.append(mean_efficiency)
        all_purities.append(mean_purity)

In [None]:
def plot_metric_vs_radius(radii, values, title, ylabel, filename):
    """Plot one metric against the neighbourhood radius and save the figure to disk.

    Args:
        radii: x values (radius of neighbourhood), one per point.
        values: metric values, one per radius.
        title: figure title.
        ylabel: y-axis label.
        filename: path the figure is written to via ``Figure.savefig``.

    Returns:
        The matplotlib Figure (still shown inline by the notebook backend).
    """
    fig, ax = plt.subplots(figsize=(12, 8))
    ax.plot(radii, values)
    ax.set_title(title, fontsize=24)
    ax.set_xlabel("Radius of neighborhood", fontsize=18)
    ax.set_ylabel(ylabel, fontsize=18)
    fig.savefig(filename)
    return fig


plot_metric_vs_radius(
    all_radius, all_efficiencies, "Embedding efficiency", "Efficiency", run_name + "_eff.png"
)
# Trailing ';' suppresses the Figure repr so each figure is rendered only once.
plot_metric_vs_radius(
    all_radius, all_purities, "Embedding purity", "Purity", run_name + "_purity.png"
);

In [None]:
os.environ[
    "TENSORBOARD_BINARY"
] = "/global/homes/j/jferguso/.conda/envs/exatrkx-tracking/bin/tensorboard"
%tensorboard --logdir tb_logs/ --port 0
nersc_tensorboard_helper.tb_address()