In [1]:
# import model and everything needed to run it
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning import Trainer

import os
import numpy as np
from copy import deepcopy
from rbm_torch.models.pool_crbm_base import pool_CRBM
from rbm_torch.utils.utils import load_run_file

findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family ['Arial'] not found. Falling back to DejaVu Sans.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Ari

In [5]:
# Pull the run description and its baseline model config from the preconfigured run file
run_file = "./run_files/toy_pcrbm.json"
run_data, base_config = load_run_file(run_file)

# Mirror the run-level model name and GPU count into the model config
base_config.update({
    "model_name": run_data["model_name"],
    "gpus": run_data["gpus"],
})

# Guard: the rest of this notebook is written for the pool_crbm model type only
model_type = run_data["model_type"]
assert model_type == "pool_crbm"

# Directory later handed to the TensorBoard logger in train_model
server_model_dir = run_data["server_model_dir"]

# Notebook-specific overrides applied on top of the baseline config.
# NOTE(review): the l* keys look like loss/regularization weights read by pool_CRBM — confirm.
base_config.update({
    'seed': 69,
    'lgap': 0.0,
    'lbs': 0.0,
    'lcorr': 0.0,
    'ld': 0.0,
    'l1_2': 1.0,
    'epochs': 1000,
})

In [10]:
# Training Code for individual model
def train_model(config):
    """Build a pool_CRBM from ``config`` and fit it with a Lightning ``Trainer``.

    The hardware setup is chosen from ``config["gpus"]``:

    * ``0``  -> CPU training
    * ``>1`` -> CUDA with that many devices and the ``"ddp"`` strategy
    * otherwise -> CUDA with ``devices=config["gpus"]`` and the default strategy

    Scalars are logged through a ``TensorBoardLogger`` rooted at the
    notebook-level ``server_model_dir`` under ``config["model_name"]``.

    Args:
        config: dict-like model configuration. This function reads the keys
            ``"precision"``, ``"model_name"``, ``"gpus"`` and ``"epochs"``;
            the whole object is also passed to ``pool_CRBM``.

    Returns:
        None. The trained model is not returned; results are whatever the
        trainer/logger write out.
    """
    model = pool_CRBM(config, debug=False, precision=config["precision"])
    logger = TensorBoardLogger(server_model_dir, name=config["model_name"])

    # Use the per-model config for the device count (not the global run_data) so the
    # branch condition and the number of devices requested can never disagree.
    n_gpus = config["gpus"]

    # Options shared by every hardware setup
    trainer_kwargs = dict(
        max_epochs=config['epochs'],
        logger=logger,
        enable_progress_bar=False,
        enable_model_summary=False,
    )

    if n_gpus == 0:
        trainer_kwargs["accelerator"] = "cpu"
    else:
        # `devices=` replaces the deprecated `gpus=` argument and matches the single-GPU path
        trainer_kwargs.update(accelerator="cuda", devices=n_gpus)
        if n_gpus > 1:
            # distributed data parallel, multi-gpus on single machine or across multiple machines
            # NOTE(review): Lightning may reject "ddp" inside an interactive kernel — confirm
            # whether "ddp_notebook" is needed when this runs from Jupyter.
            trainer_kwargs["strategy"] = "ddp"

    trainer = Trainer(**trainer_kwargs)
    trainer.fit(model)

In [38]:
import tbparse
import pandas as pd
from rbm_torch.analysis import analysis_methods as am

def parse_tb_files_crbm(model_str, model_dir="./", version=None):
    """Collect a model's TensorBoard scalars into one long-format DataFrame.

    Each training scalar listed below is read from its own
    ``"<version_dir>/Train Scalars_<scalar>"`` event directory, and its generic
    ``"Train Scalars"`` tag is replaced by the scalar's name. The main event
    directory is then read and filtered with ``removal_tags`` so that (per the
    original author's note) only the validation free energy remains.

    Args:
        model_str: model name; passed to ``am.get_checkpoint_path`` and stored
            in the ``"Model"`` column of the result.
        model_dir: directory holding the trained models (forwarded as ``rbmdir``).
        version: optional version selector forwarded to ``am.get_checkpoint_path``.

    Returns:
        pandas.DataFrame with the tbparse columns ``step``/``tag``/``value``
        renamed to ``"Epoch"``/``"Scalar"``/``"Value"``, a fresh 0..n-1 index,
        and a ``"Model"`` column set to ``model_str``.
    """
    # Only the version directory is needed; the checkpoint path itself is unused here.
    _, version_dir = am.get_checkpoint_path(model_str, rbmdir=model_dir, version=version)

    # Read in all scalar event files and extract info
    scalars = ["weight_reg", "field_reg", "distance_reg", "gap_reg", "free_energy_diff", "loss", "free_energy_pos", "free_energy_neg", "Input Correlation Reg"]
    frames = []
    for scalar in scalars:
        reader = tbparse.SummaryReader(f"{version_dir}/Train Scalars_{scalar}")
        # Every per-scalar file carries the same generic tag; relabel it with the scalar name
        frames.append(reader.scalars.replace({"tag": {"Train Scalars": scalar}}))

    # Read in the main event file and drop almost all the information in it
    main_df = tbparse.SummaryReader(version_dir).scalars

    removal_tags = ["Train Scalars", "Val Scalars", "hp_metric", "train_free_energy_epoch", "val_free_energy_epoch", "ptl/free_energy_diff_step",
                    "val_free_energy_step", "train_free_energy_step", "train_loss_step", "epoch", "train_loss_epoch"]
    rename_tags = {"ptl/train_free_energy_epoch": "train_free_energy_epoch",
                   "ptl/val_free_energy": "val_free_energy"}

    # Everything is Removed except for the Validation Free Energy.
    # Tags are matched as literal substrings (regex=False) so a tag containing
    # regex metacharacters can never change which rows are dropped.
    for rtag in removal_tags:
        matches = main_df["tag"].str.contains(rtag, regex=False)
        main_df = main_df.drop(index=main_df[matches].index)

    main_df = main_df.replace({"tag": rename_tags})

    main_df = main_df.iloc[1:, :]  # Removes a false 0 step value for the Validation Free Energy
    frames.append(main_df)

    # Join all frames, switch to plotting-friendly column names, and renumber the rows
    combined = (
        pd.concat(frames)
        .rename(columns={"step": "Epoch", "tag": "Scalar", "value": "Value"})
        .reset_index(drop=True)
    )
    combined["Model"] = model_str
    return combined

## Define Models

In [None]:
# Build one config per value of a single swept hyperparameter
vary = "l1_2"

# Sweep values: 0 plus ten log-spaced points from 0.1 to 100, rounded to 3 decimals
vrange = [0.] + np.geomspace(0.1, 100.0, num=10).tolist()
vrange = [round(x, 3) for x in vrange]

# Each model gets an independent deep copy of the baseline, the swept value,
# and a model name suffixed with the swept key and the value's position in the sweep
model_configs = []
for vid, vval in enumerate(vrange):
    cfg = deepcopy(base_config)
    cfg[vary] = vval
    cfg["model_name"] += f"_{vary}_{vid}"
    model_configs.append(cfg)



In [None]:
# Train Models
# Training is launched from the project root and the original working directory is
# restored afterwards.
# NOTE(review): presumably the configs contain paths relative to the project root — confirm.
# NOTE(review): this absolute path is machine-specific; consider moving it to a config value.
cdir = os.getcwd()
os.chdir("/home/jonah/PycharmProjects/phage_display_ML/")
try:
    for model_config in model_configs:
        train_model(model_config)
finally:
    # Always switch back, even if a training run raises, so later cells that use
    # relative paths are not silently run from the wrong directory.
    os.chdir(cdir)

## Analysis

In [None]:
# Gather Tensorboard Data for all models
# NOTE(review): logs are read from ./trained_crbms/ while training logs under
# server_model_dir — confirm both refer to the same location.
dfs = [
    parse_tb_files_crbm(model_config["model_name"], model_dir="./trained_crbms/")
    for model_config in model_configs
]

# Join all dfs together. Each per-model frame is indexed 0..n-1, so a plain concat would
# repeat index labels across models; ignore_index gives every row a unique label, which
# keeps label-based operations (and plotting libraries that align on the index) well-defined.
df = pd.concat(dfs, ignore_index=True)

In [None]:
# Plot any scalar we're interested in
import seaborn as sns
import matplotlib.pyplot as plt

# 'distance_reg' scalar over epochs, one line per model, with the y-axis clipped to [0, 0.2]
key = 'distance_reg'
distance_reg_rows = df.loc[df['Scalar'] == key]
sns.lineplot(x="Epoch", y="Value", hue="Model", data=distance_reg_rows)
plt.ylim(bottom=0.0, top=0.2)


In [None]:
# 'free_energy_pos' scalar over epochs, one line per model
key = 'free_energy_pos'
free_energy_pos_rows = df.loc[df['Scalar'] == key]
sns.lineplot(x="Epoch", y="Value", hue="Model", data=free_energy_pos_rows)