In [None]:
import sys
sys.path.append("../")

import sqlite3
import pandas as pd
from omegaconf import OmegaConf
from copy import deepcopy
import torch
import numpy as np
from src.utils.setup_model import setup
from tqdm import tqdm
# Import sigmoid function
from torch.nn.functional import sigmoid
# Register an "eval" resolver with OmegaConf, backed by Python's built-in
# ``eval``, so that configs loaded below can use ``${eval:...}`` interpolations.
# NOTE(review): this lets a config file execute arbitrary Python expressions
# when it is resolved -- only load configs from trusted sources.
try:
    OmegaConf.register_new_resolver("eval", eval)
except ValueError:
    # Presumably raised when the resolver is already registered (e.g. this
    # cell is re-run in the same kernel); ignored so the cell stays re-runnable.
    pass

In [None]:
def _optional_int(values, index, nan_as_none=False):
    """Read one entry of an optional batched metadata field as a Python int.

    Args:
        values: Batched metadata entry whose elements support
            ``.cpu().numpy()``, or ``None`` when the field is absent.
        index: Position of the sample inside the batch.
        nan_as_none: If true, a NaN entry is returned as ``None`` instead of
            being handed to ``int`` (which raises ``ValueError`` on NaN).

    Returns:
        ``int`` value of the entry, or ``None`` when the field is absent or
        (with ``nan_as_none``) the entry is NaN.
    """
    if values is None:
        return None
    value = values[index].cpu().numpy()
    if nan_as_none and np.isnan(value):
        return None
    return int(value)


def _collect_prediction_rows(model, dataloader, run_id, device):
    """Run ``model`` over ``dataloader`` and build one DB row per sample.

    Each batch is expected to unpack as ``(data, label, metadata)`` where
    ``metadata`` is dict-like. The keys ``harpnum``, ``start_date``,
    ``end_date`` and ``activity_in_obs`` are required; the id / flag keys are
    optional and stored as NULL when missing.

    Returns:
        List of tuples ordered to match the column list of the INSERT
        statement used in ``get_predictions``.
    """
    rows = []
    for data, label, metadata in tqdm(dataloader):
        data = data.to(device)

        with torch.no_grad():
            pred = sigmoid(model(data)).cpu().numpy()
        # The labels are only stored, so they never need to visit the device.
        label = label.cpu().numpy()

        cme_ids = metadata.get("cme_id", None)
        flare_ids = metadata.get("flare_id", None)
        cme_types = metadata.get("cme_type", None)
        # Boolean-like flags, listed in the order the INSERT statement expects.
        flag_columns = [
            metadata.get("has_cme_flare_above_threshold", None),
            metadata.get("has_cme_flare_below_threshold", None),
            metadata.get("has_cme_no_flare", None),
            metadata.get("has_flare_above_threshold", None),
            metadata.get("has_flare_below_threshold", None),
        ]

        for i in range(len(pred)):
            harpnum = int(metadata["harpnum"][i].cpu().numpy())
            start = metadata["start_date"][i]
            end = metadata["end_date"][i]
            activity_in_obs = int(metadata["activity_in_obs"][i].cpu().numpy())

            # Event ids may be NaN (no associated event) -> stored as NULL.
            flare_id = _optional_int(flare_ids, i, nan_as_none=True)
            cme_id = _optional_int(cme_ids, i, nan_as_none=True)
            cme_type = _optional_int(cme_types, i, nan_as_none=True)
            flags = tuple(_optional_int(column, i) for column in flag_columns)

            # NOTE(review): indexing ``pred[i][0]`` assumes the model emits one
            # output per sample along a trailing axis -- confirm for new models.
            spred = float(pred[i][0])
            slabel = int(label[i])

            rows.append(
                (run_id, harpnum, start, end, activity_in_obs,
                 flare_id, cme_id, cme_type, spred, slabel) + flags
            )
    return rows


def get_predictions(problem_type, table, db_path, device="cuda"):
    """Recompute validation predictions for every stored run and save them.

    Reads all rows of ``table`` from ``../out/{problem_type}/results.db``.
    For each run it rebuilds the model through ``setup`` with that run's
    train/val splits, loads the ``checkpoint_{best_epoch}.pth`` weights,
    applies a sigmoid to the model output on the validation dataloader and
    writes one row per sample into ``{problem_type}_{table}_predictions`` in
    the SQLite database at ``db_path``. The predictions table is dropped and
    recreated on every call.

    Args:
        problem_type: Sub-folder of ``../out`` holding the results database
            and Hydra config; also the prefix of the predictions table name.
        table: Name of the runs table inside the results database. It is
            interpolated into SQL as an identifier, so it must be trusted.
        db_path: Path of the SQLite database that receives the predictions.
        device: Torch device string used for the config, the checkpoint
            ``map_location`` and the input batches.

    Returns:
        None. Results are persisted to ``db_path``.
    """
    # Load the run table, then release the results database straight away.
    res_conn = sqlite3.connect(f'../out/{problem_type}/results.db')
    try:
        runs = pd.read_sql("SELECT * FROM " + table, res_conn)
    finally:
        res_conn.close()

    config_path = f"../out/{problem_type}/24/rotary_transformer/.hydra/config.yaml"

    omega_conf = OmegaConf.load(config_path)
    omega_conf.local.db_path = "../data/features.db"

    base_config = OmegaConf.to_container(omega_conf, resolve=True)

    base_config["logger"]["args"]["mode"] = "disabled"

    # Change device in config
    base_config["shared"]["device"] = device

    # Weights are loaded explicitly below, so disable any checkpoint handling
    # that the config itself would trigger.
    base_config["train"]["load_checkpoint"] = False
    base_config["cross_validation"]["use_checkpoints"] = False

    predictions_table = f"{problem_type}_{table}_predictions"

    pred_conn = sqlite3.connect(db_path)
    try:
        pred_cur = pred_conn.cursor()

        # Start from a clean predictions table on every call.
        pred_cur.execute(f"DROP TABLE IF EXISTS {predictions_table}")
        pred_cur.execute(f"""CREATE TABLE {predictions_table} 
                (run_id INTEGER, 
                harpnum INT,
                start TEXT,
                end TEXT,
                activity_in_obs REAL,
                flare_id INT,
                cme_id INT,
                cme_type INT,
                has_cme_flare_above_threshold INT,
                has_cme_flare_below_threshold INT,
                has_cme_no_flare INT,
                has_flare_above_threshold INT,
                has_flare_below_threshold INT,
                pred REAL, 
                target REAL,
                PRIMARY KEY (run_id, harpnum, end)
                )"""
                )
        pred_conn.commit()

        for _, run in tqdm(runs.iterrows(), total=len(runs)):
            config = deepcopy(base_config)

            # Set the splits correctly
            config["data"]["dataset"]["train"]["args"]["splits"] = pd.eval(run["train_splits"])
            config["data"]["dataset"]["val"]["args"]["splits"] = pd.eval(run["val_splits"])

            # Re-root the stored checkpoint directory under the local ../out.
            # NOTE(review): this relies on "out" appearing exactly once in the
            # stored path; any other occurrence changes the result.
            checkpoint_dir = "../out" + run["checkpoint_dir"].split("out")[1]

            best_epoch = run["best_epoch"]
            checkpoint_path = checkpoint_dir + f"/checkpoint_{best_epoch}.pth"

            # NOTE(review): torch.load unpickles the file; only load
            # checkpoints from trusted sources.
            checkpoint = torch.load(checkpoint_path, map_location=device)

            # Build the model and dataloaders for this run's splits.
            model_parts = setup(config)

            model = model_parts["model"]
            model.load_state_dict(checkpoint["model_state_dict"])
            model.eval()

            results = _collect_prediction_rows(
                model, model_parts["dataloaders"]["val"], run["run_id"], device
            )

            # Insert into database
            pred_cur.executemany(f"""INSERT INTO {predictions_table}
                                    (run_id, harpnum, start, end, activity_in_obs, flare_id, cme_id, cme_type, pred, target, 
                                    has_cme_flare_above_threshold, has_cme_flare_below_threshold, has_cme_no_flare,
                                    has_flare_above_threshold, has_flare_below_threshold)
                                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", results)

            # Commit per run so completed runs survive a later failure.
            pred_conn.commit()
    finally:
        pred_conn.close()

In [None]:
# Flare-forecast runs: adjust `device` to whatever hardware is available.
get_predictions(
    problem_type="flare_forecast",
    table="rotary_transformer_24",
    db_path="../out/main_predictions.db",
    device="cuda",
)

In [None]:
# CME-forecast runs: adjust `device` to whatever hardware is available.
get_predictions(
    problem_type="cme_forecast",
    table="rotary_transformer_24",
    db_path="../out/main_predictions.db",
    device="cuda",
)

In [None]:
# Flare/CME-association runs: adjust `device` to whatever hardware is available.
get_predictions(
    problem_type="flare_cme_assoc",
    table="rotary_transformer_24",
    db_path="../out/main_predictions.db",
    device="cuda",
)