In [1]:
# Enable IPython's autoreload extension so that edits to imported project
# modules are picked up without restarting the kernel (mode 2 = reload all
# modules before each cell execution).
%load_ext autoreload
%autoreload 2

# Loading model from checkpoint

In [2]:
import json

# Output directory of the training run under inspection.
RESULTS_DIR = "/home/rgura001/ML4GWsearch/results/train_20231122/130833"

# Checkpoint file saved inside that run directory.
checkpoint_path = RESULTS_DIR + "/checkpoints/epoch=4-step=190.ckpt"

# Read back the full configuration that was stored alongside the run.
with open(RESULTS_DIR + "/run_configs/full_config.json") as config_file:
    config = json.load(config_file)

# Display the configuration as the cell output.
config

{'dataset_name': 'g2net-gravitational-wave-detection',
 'ifos': ['LIGO Hanford', 'LIGO Livingston', 'Virgo'],
 'total_datapoints': 560000,
 'input_channels': 3,
 'seq_len': 4096,
 'n_classes': 2,
 'z_norm': False,
 'highpass': False,
 'whiten': False,
 'scale': True,
 'bandpass': True,
 'epochs': 10,
 'batch_size': 128,
 'num_batches': 100,
 'optimizer': 'sgd',
 'learning_rate': 0.1,
 'lr_scheduler': 'step',
 'stop_early': True,
 'accumulate_grad_batches': 2,
 'layers': [128, 256, 128],
 'kernel_sizes': [7, 5, 3],
 'batch_norm': False,
 'nesterov': True,
 'momentum': 0.9,
 'weight_decay': 0.0,
 'lr_scheduler__step_size': 10,
 'lr_scheduler__gamma': 0.1,
 'stop_early__monitor': 'val_loss',
 'stop_early__mode': 'min',
 'stop_early__patience': 10,
 'use_gpu': 1,
 'model_name': 'FCN',
 'sample_size': 12800}

### [WRONG METHOD] Loading the checkpoint's `state_dict` directly into the bare model (without the Lightning module wrapper)

In [11]:
import torch
from models.get_model import get_model
from pprint import pprint

# The checkpoint is a dict; list its top-level entries to see what it contains.
checkpoint = torch.load(checkpoint_path)
pprint(list(checkpoint.keys()))

# Build the bare network from the run configuration.
model, lossfn  = get_model(config=config)

# This load is EXPECTED to fail: the keys in checkpoint['state_dict'] carry a
# "model." prefix (the network is stored as the `model` attribute of the
# Lightning module), so none of them match the bare network's parameter names.
# The error is caught and printed so the demonstration is still visible while
# the notebook keeps running under "Restart & Run All".
try:
    model.load_state_dict(checkpoint['state_dict'])
except RuntimeError as err:
    print(f"RuntimeError: {err}")

['epoch',
 'global_step',
 'pytorch-lightning_version',
 'state_dict',
 'loops',
 'callbacks',
 'optimizer_states',
 'lr_schedulers',
 'hparams_name',
 'hyper_parameters']


RuntimeError: Error(s) in loading state_dict for FCNPlus:
	Missing key(s) in state_dict: "backbone.convblock1.0.weight", "backbone.convblock1.1.weight", "backbone.convblock1.1.bias", "backbone.convblock1.1.running_mean", "backbone.convblock1.1.running_var", "backbone.convblock2.0.weight", "backbone.convblock2.1.weight", "backbone.convblock2.1.bias", "backbone.convblock2.1.running_mean", "backbone.convblock2.1.running_var", "backbone.convblock3.0.weight", "backbone.convblock3.1.weight", "backbone.convblock3.1.bias", "backbone.convblock3.1.running_mean", "backbone.convblock3.1.running_var", "head.2.weight", "head.2.bias". 
	Unexpected key(s) in state_dict: "model.backbone.convblock1.0.weight", "model.backbone.convblock1.1.weight", "model.backbone.convblock1.1.bias", "model.backbone.convblock1.1.running_mean", "model.backbone.convblock1.1.running_var", "model.backbone.convblock1.1.num_batches_tracked", "model.backbone.convblock2.0.weight", "model.backbone.convblock2.1.weight", "model.backbone.convblock2.1.bias", "model.backbone.convblock2.1.running_mean", "model.backbone.convblock2.1.running_var", "model.backbone.convblock2.1.num_batches_tracked", "model.backbone.convblock3.0.weight", "model.backbone.convblock3.1.weight", "model.backbone.convblock3.1.bias", "model.backbone.convblock3.1.running_mean", "model.backbone.convblock3.1.running_var", "model.backbone.convblock3.1.num_batches_tracked", "model.head.2.weight", "model.head.2.bias". 

### [CORRECT METHOD] Loading weights using the PyTorch Lightning module object

In [4]:
from GWDetectionLightningModule import GWDetectionLightningModule

# `load_from_checkpoint` is a classmethod: call it on the class itself rather
# than on a freshly constructed instance. Building an instance first creates a
# throw-away model, and newer Lightning releases reject instance calls.
# The checkpoint stores a 'hyper_parameters' entry, so the constructor
# arguments are presumably restored from it -- TODO confirm for this module.
model = GWDetectionLightningModule.load_from_checkpoint(checkpoint_path)

# Switch to inference mode. This is required before predicting: the network
# contains BatchNorm layers, which behave differently in training mode and
# would otherwise give wrong results.
model.eval()

GWDetectionLightningModule(
  (model): FCNPlus(
    (backbone): _FCNBlockPlus(
      (convblock1): ConvBlock(
        (0): Conv1d(3, 128, kernel_size=(7,), stride=(1,), padding=(3,), bias=False)
        (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (convblock2): ConvBlock(
        (0): Conv1d(128, 256, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
        (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (convblock3): ConvBlock(
        (0): Conv1d(256, 128, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
        (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (add): Sequential()
    )
    (head): Sequential(
      (0): AdaptiveAvgPool1d(output_size=1)
      (1): Squeeze(dim=-1)
      (2): Linear(in_features=128, out_features=2, bias=True)
    )
  )
  (m

In [9]:
from pprint import pprint

# Show the hyperparameters attached to the restored module as a plain dict
# so that pprint lays them out one key per line.
pprint(dict(model.hparams.items()))

{'config': {'accumulate_grad_batches': 2,
            'bandpass': True,
            'batch_norm': False,
            'batch_size': 128,
            'dataset_name': 'g2net-gravitational-wave-detection',
            'epochs': 10,
            'highpass': False,
            'ifos': ['LIGO Hanford', 'LIGO Livingston', 'Virgo'],
            'input_channels': 3,
            'kernel_sizes': [7, 5, 3],
            'layers': [128, 256, 128],
            'learning_rate': 0.1,
            'lr_scheduler': 'step',
            'lr_scheduler__gamma': 0.1,
            'lr_scheduler__step_size': 10,
            'model_name': 'FCN',
            'momentum': 0.9,
            'n_classes': 2,
            'nesterov': True,
            'num_batches': 100,
            'optimizer': 'sgd',
            'sample_size': 12800,
            'scale': True,
            'seq_len': 4096,
            'stop_early': True,
            'stop_early__mode': 'min',
            'stop_early__monitor': 'val_loss',
            'stop_e

# Listing FN (false negative) samples metadata

In [2]:
import pandas as pd

In [None]:
# Count how many test samples the classifier assigned to each class.
run_clf = pd.read_csv("/data/bchen158/ML4GW/ML4GWsearch/src/results/train_20231202/224821/plots/test_metrics/testset_preds.csv")

# Summing a boolean mask counts the matching rows without building a filtered copy.
count_pred_1 = int((run_clf['prediction'] == 1).sum())
count_pred_0 = int((run_clf['prediction'] == 0).sum())

# These are counts of the `prediction` column, so say so in the message
# (the ground-truth `label` column is a different quantity).
print(f"Number of samples with prediction 1: {count_pred_1}")
print(f"Number of samples with prediction 0: {count_pred_0}")

In [3]:
# Class balance of the ground-truth `label` column in the saved test-set predictions.
run_clf = pd.read_csv(
    "/data/bchen158/ML4GW/ML4GWsearch/src/results/train_20231202/224821/plots/test_metrics/testset_preds.csv"
)

# Summing the column counts the 1s (assumes labels are encoded as 0/1 -- TODO confirm);
# every remaining row is counted as label 0.
count_label_1 = run_clf['label'].sum()
count_label_0 = run_clf.shape[0] - count_label_1

print(f"Number of samples with label 1: {count_label_1}")
print(f"Number of samples with label 0: {count_label_0}")

Number of samples with label 1: 0
Number of samples with label 0: 56000


In [3]:
# Per-sample test-set predictions saved by the first of the two runs being compared.
run1_testpreds = pd.read_csv(
    "/home/rgura001/ML4GWsearch/results/train_20230904/202204/plots/test_metrics/testset_preds.csv"
)

# Report the table size, then preview the first five rows as the cell output.
print(run1_testpreds.shape)
run1_testpreds.head(n=5)

(56000, 5)


Unnamed: 0,id,label,prediction,prediction_proba_0,prediction_proba_1
0,d3f2689122,1,1,-0.69907,0.600289
1,91d46dc05f,0,0,1.207059,-1.286181
2,c5037fc763,0,0,1.040363,-1.134061
3,c60d13040e,1,0,0.297779,-0.376915
4,12701bd0c3,1,1,-9.217836,8.962153


In [4]:
# Per-sample test-set predictions saved by the second of the two runs being compared.
run2_testpreds = pd.read_csv(
    "/home/rgura001/ML4GWsearch/results/train_20230904/202148/plots/test_metrics/testset_preds.csv"
)

# Report the table size, then preview the first five rows as the cell output.
print(run2_testpreds.shape)
run2_testpreds.head(n=5)

(56000, 5)


Unnamed: 0,id,label,prediction,prediction_proba_0,prediction_proba_1
0,d3f2689122,1,1,-0.074259,0.002105
1,91d46dc05f,0,0,1.389192,-1.462251
2,c5037fc763,0,0,0.57062,-0.654549
3,c60d13040e,1,1,-0.072561,0.000794
4,12701bd0c3,1,1,-8.310558,8.058495


In [5]:
# False negatives of each run: rows whose label is 1 but whose prediction is 0.
fn_run1_testpreds = run1_testpreds[
    run1_testpreds['label'].eq(1) & run1_testpreds['prediction'].eq(0)
]
fn_run2_testpreds = run2_testpreds[
    run2_testpreds['label'].eq(1) & run2_testpreds['prediction'].eq(0)
]

print(fn_run1_testpreds.shape)
print(fn_run2_testpreds.shape)

# Samples that are false negatives in BOTH runs: inner join on the sample id.
common_fn = fn_run1_testpreds.merge(fn_run2_testpreds, how='inner', on=['id'])
common_fn.shape

(12917, 5)
(12495, 5)


(10696, 9)

In [8]:
# False negatives that occur in only one of the two runs. `common_fn` holds the
# ids that both runs got wrong, so dropping those ids leaves the run-specific misses.
fn_exists_in_run1_not_in_run2 = fn_run1_testpreds.loc[
    ~fn_run1_testpreds['id'].isin(common_fn['id'])
]
print(fn_exists_in_run1_not_in_run2.shape)

fn_exists_in_run2_not_in_run1 = fn_run2_testpreds.loc[
    ~fn_run2_testpreds['id'].isin(common_fn['id'])
]
print(fn_exists_in_run2_not_in_run1.shape)

# Preview the samples that only run 1 missed.
fn_exists_in_run1_not_in_run2.head(n=5)

(2221, 5)
(1799, 5)


Unnamed: 0,id,label,prediction,prediction_proba_0,prediction_proba_1
3,c60d13040e,1,0,0.297779,-0.376915
50,5126fac4f3,1,0,0.134256,-0.217168
66,8d70f2bd0e,1,0,0.469865,-0.55217
68,e43ebd6d60,1,0,-0.029368,-0.066118
72,46ec4d54a1,1,0,0.08223,-0.165095
