env: cpa

# import

In [1]:
import sys
import json
from anndata import AnnData
import cpa
import scanpy as sc
import importlib
import numpy as np
from tqdm import tqdm
import pickle
import os
import shutil
import torch
import pandas as pd

sys.path.append("/data1/lichen/code/single_cell_perturbation/scPerturb/Byte_Pert_Data/")
import v1
from v1.utils import *
from v1.dataloader import *

Global seed set to 0


In [2]:
# Reload the project package so edits to v1 on disk are picked up without restarting the kernel.
# importlib.reload re-executes a module in place, so attribute access (v1.utils.f) sees new code,
# but names bound earlier via `from ... import *` still point to the OLD objects.
# Submodules are reloaded before the package so that anything v1/__init__ re-imports from them
# comes from the fresh versions (assumes v1/__init__ imports from its submodules — TODO confirm).
importlib.reload(v1.utils)
importlib.reload(v1.dataloader)
importlib.reload(v1)

# Re-run the star imports so the notebook namespace is rebound to the reloaded functions/classes.
from v1.utils import *
from v1.dataloader import *

<module 'v1.dataloader' from '/data1/lichen/code/single_cell_perturbation/scPerturb/Byte_Pert_Data/v1/dataloader.py'>

# Load K562 genome-wide Perturb-seq (GWPS) data

In [3]:
prefix = 'ReplogleWeissman2022_K562_gwps'

print('='*20, prefix)

# - init para
data_dir = '/nfs/public/lichen/data/single_cell/perturb_data/scPerturb/raw/scPerturb_rna/statistic_20240520'
pert_cell_filter = 100 # this is used to filter perts, cell number less than this will be filtered
seed = 2024 # this is the random seed
split_type = 1 # 1 for unseen perts; 0 for unseen celltypes
split_ratio = [0.7, 0.2, 0.1] # train:test:val; val is used to choose data, test is for final validation
var_num = 5000 # selecting hvg number
num_de_genes = 20 # number of de genes
bs_train = 32 # batch size of trainloader
bs_test = 32 # batch size of testloader
lr = 1e-4 # not referenced by the cells below

tmp_dir = '/nfs/public/lichen/results/single_cell_perturbation/perturbation_benchmark/scPerturb'
# Directory name encodes the settings the pickled pert_data was produced with.
save_prefix = (
    f'GEARS-prefix_{prefix}-pert_cell_filter_{pert_cell_filter}-'
    f'seed_{seed}-split_type_{split_type}-var_num_{var_num}-num_de_genes_{num_de_genes}-bs_train_{bs_train}-'
    f'bs_test_{bs_test}'
)

save_dir = os.path.join(tmp_dir, prefix, save_prefix)

# - load pert_data
# NOTE: pickle.load can execute arbitrary code; only load pickles produced by this project.
# The `with` block closes the file handle (the previous bare open() leaked it).
with open(os.path.join(save_dir, 'pert_data.pkl'), 'rb') as pkl_file:
    pert_data = pickle.load(pkl_file)



In [None]:
adata_ori = pert_data.adata_ori
# adata_ori is the SAME object as pert_data.adata_ori (no copy), and both calls below mutate it
# in place. sc.pp.log1p records a 'log1p' entry in .uns, so use it as a marker to make this cell
# idempotent. Otherwise re-running it, or running the equivalent pert_data.adata_ori cell as well,
# would normalize and log-transform the matrix twice.
# Assumes an existing 'log1p' entry means the data are already normalized — confirm for new inputs.
if 'log1p' not in adata_ori.uns:
    sc.pp.normalize_per_cell(adata_ori, key_n_counts='n_counts_all')
    sc.pp.log1p(adata_ori)

# Store an adata_ori subset that can be used for downstream computation

In [12]:
# Normalize per cell and log1p-transform pert_data.adata_ori in place, but only once.
# sc.pp.log1p records a 'log1p' entry in .uns, which marks the matrix as already transformed.
# The guard keeps this cell safe to re-run, and safe to run after another cell that normalized
# the same object (e.g. through an alias such as adata_ori).
if 'log1p' not in pert_data.adata_ori.uns:
    sc.pp.normalize_per_cell(pert_data.adata_ori, key_n_counts='n_counts_all')
    sc.pp.log1p(pert_data.adata_ori)

In [13]:
# Sanity check: largest value in the transformed expression matrix.
pert_data.adata_ori.X.max()

8.268049

In [4]:
# - get cell line name
common_cell_line = \
{   'A549': 'A549',
    'HEPG2': 'HepG2',
    'HT29': 'HT29',
    'MCF7': 'MCF7',
    # 'SKBR3': 'SK-BR-3',
    'SW480': 'SW480',
    'PC3': 'PC3',
    'A375': 'A375',
} # L1000 cell line : single-cell cell line

# - read adata_L1000, this is processed data
adata_L1000 = sc.read('/nfs/public/lichen/results/single_cell_perturbation/perturbation_benchmark/benchmark_data/L1000/GSE92742/adata_gene_pert.h5ad')
adata_L1000

AnnData object with n_obs × n_vars = 36720 × 978
    obs: 'sig_id', 'pert_id', 'pert_iname', 'pert_type', 'cell_id', 'pert_dose', 'pert_dose_unit', 'pert_idose', 'pert_time', 'pert_time_unit', 'pert_itime', 'distil_id'
    var: 'pr_gene_id', 'pr_gene_symbol', 'pr_gene_title', 'pr_is_lm', 'pr_is_bing'

In [5]:
# Sorted, de-duplicated union of L1000 perturbation names across every mapped cell line.
perts_seen = []
for cell_line_bulk, cell_line_single in common_cell_line.items():
    adata_L1000_sub = adata_L1000[adata_L1000.obs['cell_id'] == cell_line_bulk]
    perts_seen.extend(np.unique(adata_L1000_sub.obs['pert_iname']))
L1000_total_perts = np.unique(perts_seen)

import json

direct_change_path = '/nfs/public/lichen/results/single_cell_perturbation/perturbation_benchmark/benchmark_data/L1000/utils_data/direct_change_dict.json'
with open(direct_change_path, 'r') as f:
    direct_change_dict = json.load(f)

# - single_total_perts: the perts used in K562 (every key except the special 'gene_list' entry)
gene_list = direct_change_dict['gene_list']
single_total_perts = [key for key in direct_change_dict if key != 'gene_list']

# Perturbations present both in K562 and in L1000 (any mapped cell line).
common_perts = np.intersect1d(single_total_perts, L1000_total_perts)
common_perts.size

2359

In [10]:
# perturbation_group labels look like "<pert> | K562"; keep the common perts plus the controls.
common_pert_groups = [f'{p} | K562' for p in common_perts]
common_pert_groups.append('control | K562')
keep_mask = pert_data.adata_split.obs['perturbation_group'].isin(common_pert_groups)
adata_sub = pert_data.adata_split[keep_mask]

In [11]:
# Cells per perturbation group in the subset (largest first).
group_sizes = adata_sub.obs['perturbation_group'].value_counts()
group_sizes

control | K562    75328
DUSP9 | K562       1351
RAP1GAP | K562     1270
STK38L | K562      1169
PPP2R1A | K562     1061
                  ...  
CDK6 | K562          96
HIC2 | K562          93
DDB1 | K562          70
MED15 | K562         59
AP2M1 | K562         51
Name: perturbation_group, Length: 2360, dtype: int64

In [14]:
# Pull the same cells out of pert_data.adata_ori (all genes) as an independent copy,
# then attach the obs annotations of the filtered split data.
cells_to_keep = adata_sub.obs_names
adata_sub_ori = pert_data.adata_ori[cells_to_keep].copy()
adata_sub_ori.obs = adata_sub.obs

In [17]:
# Inspect the K562 subset (common perturbations + controls).
adata_sub_ori

AnnData object with n_obs × n_vars = 607014 × 8248
    obs: 'batch', 'gene', 'gene_id', 'transcript', 'gene_transcript', 'guide_id', 'percent_mito', 'UMI_count', 'z_gemgroup_UMI', 'core_scale_factor', 'core_adjusted_UMI_count', 'disease', 'cancer', 'cell_line', 'sex', 'age', 'perturbation', 'organism', 'perturbation_type', 'tissue_type', 'ncounts', 'ngenes', 'nperts', 'percent_ribo', 'perturbation_new', 'perturbation_type_new', 'nperts_new', 'celltype', 'celltype_new', 'perturbation_group', 'data_split', 'retain', 'n_genes', 'control_barcode', 'sgRNA_new', 'sgRNA_ID', 'pert_sgRNA'
    var: 'chr', 'start', 'end', 'class', 'strand', 'length', 'in_matrix', 'mean', 'std', 'cv', 'fano', 'ensembl_id', 'ncounts', 'ncells'
    uns: 'log1p'

In [16]:
# Check the storage type of X (dense numpy array vs. scipy sparse matrix).
type(adata_sub_ori.X)

numpy.ndarray

In [18]:
import scipy.sparse

# Store X in CSR format; a matrix that is already scipy-sparse is left untouched.
x_is_dense = not scipy.sparse.issparse(adata_sub_ori.X)
if x_is_dense:
    adata_sub_ori.X = scipy.sparse.csr_matrix(adata_sub_ori.X)
    
    

In [23]:
# Inspect adata_sub_ori again.
adata_sub_ori

AnnData object with n_obs × n_vars = 607014 × 8248
    obs: 'batch', 'gene', 'gene_id', 'transcript', 'gene_transcript', 'guide_id', 'percent_mito', 'UMI_count', 'z_gemgroup_UMI', 'core_scale_factor', 'core_adjusted_UMI_count', 'disease', 'cancer', 'cell_line', 'sex', 'age', 'perturbation', 'organism', 'perturbation_type', 'tissue_type', 'ncounts', 'ngenes', 'nperts', 'percent_ribo', 'perturbation_new', 'perturbation_type_new', 'nperts_new', 'celltype', 'celltype_new', 'perturbation_group', 'data_split', 'retain', 'n_genes', 'control_barcode', 'sgRNA_new', 'sgRNA_ID', 'pert_sgRNA'
    var: 'chr', 'start', 'end', 'class', 'strand', 'length', 'in_matrix', 'mean', 'std', 'cv', 'fano', 'ensembl_id', 'ncounts', 'ncells'
    uns: 'log1p'

In [29]:
# Keep only the group label and 'control_barcode' (the barcode used to look up control cells).
obs_columns_to_keep = ['perturbation_group', 'control_barcode']
adata_sub_ori.obs = adata_sub_ori.obs.loc[:, obs_columns_to_keep]

In [30]:
# Inspect adata_sub_ori after trimming obs.
adata_sub_ori

AnnData object with n_obs × n_vars = 607014 × 8248
    obs: 'perturbation_group', 'control_barcode'
    var: 'chr', 'start', 'end', 'class', 'strand', 'length', 'in_matrix', 'mean', 'std', 'cv', 'fano', 'ensembl_id', 'ncounts', 'ncells'
    uns: 'log1p'

In [31]:
# Persist the K562 subset so later analyses can load it directly.
k562_sub_path = '/nfs/public/lichen/results/single_cell_perturbation/perturbation_benchmark/benchmark_data/L1000/utils_data/adata_K562_sub.h5ad'
adata_sub_ori.write_h5ad(k562_sub_path)

: 

# get common perts

In [None]:
# - get cell line name
common_cell_line = \
{   'A549': 'A549',
    'HEPG2': 'HepG2',
    'HT29': 'HT29',
    'MCF7': 'MCF7',
    # 'SKBR3': 'SK-BR-3',
    'SW480': 'SW480',
    'PC3': 'PC3',
    'A375': 'A375',
} # L1000 cell line : single-cell cell line

# - read adata_L1000, this is processed data
adata_L1000 = sc.read('/nfs/public/lichen/results/single_cell_perturbation/perturbation_benchmark/benchmark_data/L1000/GSE92742/adata_gene_pert.h5ad')
adata_L1000

cell_line_bulk = 'A549'
cell_line_single = common_cell_line[cell_line_bulk]

In [None]:
    
import json

direct_change_path = '/nfs/public/lichen/results/single_cell_perturbation/perturbation_benchmark/benchmark_data/L1000/utils_data/direct_change_dict.json'
with open(direct_change_path, 'r') as f:
    direct_change_dict = json.load(f)

# - single_total_perts: the perts used in K562 (every key except the special 'gene_list' entry)
gene_list = direct_change_dict['gene_list']
single_total_perts = [key for key in direct_change_dict if key != 'gene_list']

# - get common pert: perts present in K562 and in L1000 for the chosen cell line
in_target_line = adata_L1000.obs['cell_id'] == cell_line_bulk
adata_L1000_sub = adata_L1000[in_target_line]
L1000_total_perts = np.unique(adata_L1000_sub.obs['pert_iname'])
common_perts = np.intersect1d(single_total_perts, L1000_total_perts)

In [None]:
# Number of common perturbations (common_perts is a 1-D array from np.intersect1d).
common_perts.size

2184

In [None]:
print('='*20, f'cell line is {cell_line_single}')

#===================prepare data
# PC3 and A375 come from the SCP542 collection; every other cell line from CNP0003658.
# save_dir_adata holds the file read into adata_rna; save_dir holds the file read into adata_rna_raw.
if cell_line_bulk not in ['PC3', 'A375']:
    save_dir_adata = '/nfs/public/lichen/results/single_cell_perturbation/perturbation_benchmark/benchmark_data/L1000/single_cell_data/CNP0003658'
    save_dir = f'/nfs/public/lichen/data/single_cell/cell_line/CNP0003658/process/RNA/{cell_line_single}'
    raw_file_name = f'adata_rna_{cell_line_single}.h5ad'
else:
    save_dir_adata = '/nfs/public/lichen/results/single_cell_perturbation/perturbation_benchmark/benchmark_data/L1000/single_cell_data/SCP542'
    save_dir = f'/nfs/public/lichen/data/single_cell/cell_line/SCP542/process/{cell_line_bulk}'
    raw_file_name = 'adata.h5ad'

adata_rna = sc.read(os.path.join(save_dir_adata, cell_line_bulk, f'adata_{cell_line_bulk}.h5ad'))
adata_rna_raw = sc.read(os.path.join(save_dir, raw_file_name))

# Densify adata_rna.X so later cells can treat it as a plain numpy array.
x_is_dense = isinstance(adata_rna.X, np.ndarray)
if not x_is_dense:
    adata_rna.X = adata_rna.X.toarray()



In [None]:
# - get common var
common_var = np.intersect1d(adata_rna.var_names, direct_change_dict['gene_list'])
common_var_2 = np.intersect1d(common_var, adata_L1000.var_names)

print('common_perts num: ', len(common_perts))
print('common var of direct change and single-cell data is: ', len(common_var))
print('common var to L1000 data is: ', len(common_var_2))

common_perts num:  2184
common var of direct change and single-cell data is:  2764
common var to L1000 data is:  712


In [None]:
# Take the first common perturbation as a worked example for the step-by-step cells below.
pert = common_perts[0]

In [None]:
# Cells carrying this perturbation in the split data, plus the control cells listed in their
# 'control_barcode' column (presumably one control per perturbed cell — confirm).
adata_split_ref = pert_data.adata_split
is_this_pert = adata_split_ref.obs['perturbation_group'] == pert + ' | K562'
adata_pert = adata_split_ref[adata_split_ref[is_this_pert].obs_names]
adata_ctrl = adata_split_ref[adata_pert.obs['control_barcode']]

# Re-select the same cells from adata_ori, restricted to the common genes.
adata_pert = adata_ori[adata_pert.obs_names, common_var]
adata_ctrl = adata_ori[adata_ctrl.obs_names, common_var]

In [None]:
# Inspect the control and perturbed subsets.
adata_ctrl, adata_pert

(View of AnnData object with n_obs × n_vars = 182 × 2764
     obs: 'batch', 'gene', 'gene_id', 'transcript', 'gene_transcript', 'guide_id', 'percent_mito', 'UMI_count', 'z_gemgroup_UMI', 'core_scale_factor', 'core_adjusted_UMI_count', 'disease', 'cancer', 'cell_line', 'sex', 'age', 'perturbation', 'organism', 'perturbation_type', 'tissue_type', 'ncounts', 'ngenes', 'nperts', 'percent_ribo', 'perturbation_new', 'perturbation_type_new', 'nperts_new', 'celltype', 'celltype_new', 'perturbation_group', 'data_split', 'retain', 'n_counts_all'
     var: 'chr', 'start', 'end', 'class', 'strand', 'length', 'in_matrix', 'mean', 'std', 'cv', 'fano', 'ensembl_id', 'ncounts', 'ncells'
     uns: 'log1p',
 View of AnnData object with n_obs × n_vars = 182 × 2764
     obs: 'batch', 'gene', 'gene_id', 'transcript', 'gene_transcript', 'guide_id', 'percent_mito', 'UMI_count', 'z_gemgroup_UMI', 'core_scale_factor', 'core_adjusted_UMI_count', 'disease', 'cancer', 'cell_line', 'sex', 'age', 'perturbation', 'o

In [None]:
# Restrict the cell-line single-cell data to the common genes (this is a view, not a copy).
adata_rna_common = adata_rna[:, common_var]
adata_rna_common

View of AnnData object with n_obs × n_vars = 500 × 2764
    obs: 'n_genes', 'n_counts_all'
    var: 'n_cells', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'hvg', 'log1p'

In [None]:
# - model initial
model_params = {
    "n_latent": 64,
    "recon_loss": "nb",
    "doser_type": "linear",
    "n_hidden_encoder": 128,
    "n_layers_encoder": 2,
    "n_hidden_decoder": 512,
    "n_layers_decoder": 2,
    "use_batch_norm_encoder": True,
    "use_layer_norm_encoder": False,
    "use_batch_norm_decoder": False,
    "use_layer_norm_decoder": True,
    "dropout_rate_encoder": 0.0,
    "dropout_rate_decoder": 0.1,
    "variational": False,
    "seed": 6977,
}

trainer_params = {
    "n_epochs_kl_warmup": None,
    "n_epochs_pretrain_ae": 30,
    "n_epochs_adv_warmup": 50,
    "n_epochs_mixup_warmup": 0,
    "mixup_alpha": 0.0,
    "adv_steps": None,
    "n_hidden_adv": 64,
    "n_layers_adv": 3,
    "use_batch_norm_adv": True,
    "use_layer_norm_adv": False,
    "dropout_rate_adv": 0.3,
    "reg_adv": 20.0,
    "pen_adv": 5.0,
    "lr": 0.0003,
    "wd": 4e-07,
    "adv_lr": 0.0003,
    "adv_wd": 4e-07,
    "adv_loss": "cce",
    "doser_lr": 0.0003,
    "doser_wd": 4e-07,
    "do_clip_grad": True,
    "gradient_clip_value": 1.0,
    "step_size_lr": 10,
}

In [None]:
# - generate adata_train to input to scGen model
np_list, obs_list, pert_list, celltype_list = [], [], [], []
pert_list_2 = []
adata_list = [adata_pert, adata_rna_common, adata_ctrl, adata_rna_common]
for j, adata_ in enumerate(adata_list):
    if j in [0, 1]:
        pert_list.extend(['stimulated']*len(adata_))
    else:
        pert_list.extend(['ctrl']*len(adata_))

    if j in [0, 2]:
        celltype_list.extend(['K562']*len(adata_))
    else:
        celltype_list.extend([cell_line_bulk]*len(adata_))
    obs_list.extend([obs+f'_{j}' for obs in adata_.obs_names])
    np_list.append(adata_.X.toarray())

adata_train = AnnData(X = np.vstack(np_list))
adata_train.obs_names = obs_list
adata_train.var_names = adata_pert.var_names

adata_train.obs['condition'] = pert_list
# adata_train.obs['condition_2'] = pert_list_2
adata_train.obs['cell_type'] = celltype_list

# - transform the adata_train.X to count
adata_train.obs['cov_cond'] = adata_train.obs['cell_type'] + '_' + adata_train.obs['condition']
adata_train.X = np.exp(adata_train.X)-1

# - initial model
cpa.CPA.setup_anndata(adata_train, 
                    perturbation_key='condition',
                    control_group='ctrl',
                    #   dosage_key='dose',
                    categorical_covariate_keys=['cell_type'],
                    is_count_data=True,
                    #   deg_uns_key='rank_genes_groups_cov',
                    deg_uns_cat_key='cov_cond',
                    max_comb_len=1,
                    )



100%|██████████| 1364/1364 [00:00<00:00, 72502.54it/s]
100%|██████████| 1364/1364 [00:00<00:00, 1131085.54it/s]

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        





In [None]:
# - set the train and validation split for CPA
# Rows of the target cell line labelled 'stimulated' are held out (split_key 'ood') and predicted.
# All other rows are shuffled with a fixed seed and split 90% train / 10% valid.
is_target = ((adata_train.obs['cell_type'] == cell_line_bulk)
             & (adata_train.obs['condition'] == 'stimulated')).to_numpy()
# Only the obs names are needed, so index them directly instead of copying the whole AnnData.
obs_df_sub_idx = np.array(adata_train.obs_names[~is_target])

np.random.seed(2024)
np.random.shuffle(obs_df_sub_idx)

split_point = int(len(obs_df_sub_idx) * 0.9)
train = obs_df_sub_idx[:split_point]
valid = obs_df_sub_idx[split_point:]

adata_train.obs['split_key'] = 'ood'
adata_train.obs.loc[train, 'split_key'] = 'train'
adata_train.obs.loc[valid, 'split_key'] = 'valid'

# - initial the model and training
model = cpa.CPA(adata=adata_train,
                split_key='split_key',
                train_split='train',
                valid_split='valid',
                test_split='ood',
                **model_params,
            )

model.train(max_epochs=2000,
            use_gpu=True,
            batch_size=500,
            plan_kwargs=trainer_params,
            early_stopping_patience=5,
            check_val_every_n_epoch=5,
            progress_bar_refresh_rate = 0
        )

# - predict result (model.predict writes obsm['CPA_pred'] on adata_train)
model.predict(adata_train, batch_size=2048)

# - get the predictions for the held-out target rows and map them back to log1p space,
#   the scale of adata_rna_common.X
cat = cell_line_bulk + '_' + 'stimulated'
cat_adata = adata_train[adata_train.obs['cov_cond'] == cat].copy()
x_pred = cat_adata.obsm['CPA_pred']
x_pred = np.log1p(x_pred)

Global seed set to 6977


100%|██████████| 2/2 [00:00<00:00, 708.92it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.692965367965368
disnt_after = 1.7727272727272727
val_r2_mean = 0.3614012598991394
val_r2_var = -2.6723037560780845



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.6858585858585857
disnt_after = 1.7727272727272727
val_r2_mean = 0.7982175350189209
val_r2_var = -2.674081643422445



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.6983953823953826
disnt_after = 1.7535285285285285
val_r2_mean = 0.9206221302350363
val_r2_var = -2.6612748305002847



Epoch 00039: cpa_metric reached. Module best state updated.



disnt_basal = 1.6958730158730158
disnt_after = 1.7252252252252251
val_r2_mean = 0.9358939925829569
val_r2_var = -2.6237972577412925



Epoch 00044: cpa_metric reached. Module best state updated.



disnt_basal = 1.6902958152958152
disnt_after = 1.7252252252252251
val_r2_mean = 0.9437996745109558
val_r2_var = -2.581672509511312



Epoch 00054: cpa_metric reached. Module best state updated.

Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.6884559884559884
disnt_after = 1.7252252252252251
val_r2_mean = 0.94074547290802
val_r2_var = -2.503347476323446



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.690079365079365
disnt_after = 1.7252252252252251
val_r2_mean = 0.943392833073934
val_r2_var = -2.4379239877065024



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.692965367965368
disnt_after = 1.7252252252252251
val_r2_mean = 0.9436059792836508
val_r2_var = -2.3871254126230874



Epoch 00084: cpa_metric reached. Module best state updated.

Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 1.7011202761202762
disnt_after = 1.7254231504231505
val_r2_mean = 0.9396191239356995
val_r2_var = -2.3239022890726724



Epoch 00094: cpa_metric reached. Module best state updated.

Epoch 00099: cpa_metric reached. Module best state updated.



disnt_basal = 1.5094290784290785
disnt_after = 1.7252252252252251
val_r2_mean = 0.9363454778989156
val_r2_var = -2.2833077112833657

disnt_basal = 1.6963924963924963
disnt_after = 1.7361111111111112
val_r2_mean = 0.9404150446256002
val_r2_var = -2.2283284664154053

disnt_basal = 1.697150072150072
disnt_after = 1.7291837291837293
val_r2_mean = 0.933383027712504
val_r2_var = -2.1212982336680093


100%|██████████| 1/1 [00:00<00:00, 20.75it/s]


In [None]:
# Predicted expression for the target cell line (log1p scale).
x_pred

array([[0.59670675, 0.57625264, 0.918021  , ..., 1.582566  , 1.4443177 ,
        0.34309846],
       [0.7520548 , 0.5868137 , 1.0911624 , ..., 1.6446931 , 1.3728073 ,
        0.3184219 ],
       [0.66804856, 0.58108974, 0.913697  , ..., 1.6078756 , 1.4655265 ,
        0.32640746],
       ...,
       [0.6869034 , 0.546764  , 1.0182772 , ..., 1.6756613 , 1.4602138 ,
        0.35031497],
       [0.6638175 , 0.57233644, 1.0785822 , ..., 1.6767553 , 1.427946  ,
        0.30222225],
       [0.6788921 , 0.57769287, 0.99439704, ..., 1.9601779 , 1.5734761 ,
        0.4031645 ]], dtype=float32)

In [None]:
# Observed cell-line expression on the common genes, for comparison with x_pred.
adata_rna_common.X

ArrayView([[0.        , 0.7986983 , 0.7986983 , ..., 0.7986983 ,
            0.7986983 , 0.7986983 ],
           [0.5719357 , 0.9334965 , 2.072573  , ..., 1.9704001 ,
            0.5719357 , 0.        ],
           [0.        , 0.8364108 , 2.1802375 , ..., 1.2854061 ,
            1.0858992 , 0.50321716],
           ...,
           [0.        , 1.5602019 , 1.7939309 , ..., 1.2546245 ,
            0.        , 0.        ],
           [0.        , 1.1712079 , 1.1712079 , ..., 2.038315  ,
            0.90983987, 0.        ],
           [0.9632273 , 0.59336513, 1.7682205 , ..., 1.4446287 ,
            1.7682205 , 0.59336513]], dtype=float32)

In [None]:
# Release unused cached GPU memory held by PyTorch's caching allocator.
torch.cuda.empty_cache()

In [None]:
# ?model.train

In [None]:
# Position of each common gene within direct_change_dict['gene_list'].
# Build a gene -> first-position lookup once (O(n)) instead of calling list.index per gene (O(n*m)).
# setdefault keeps the FIRST occurrence, matching list.index if the list contains duplicates.
# Every gene in common_var is present, because common_var was intersected with this list.
_gene_first_pos = {}
for _pos, _gene in enumerate(direct_change_dict['gene_list']):
    _gene_first_pos.setdefault(_gene, _pos)
common_idx = [_gene_first_pos[gene] for gene in common_var]

# Parallelizing this loop did not work out: CPA depends on torch, which makes it complicated.

In [None]:
def _build_cpa_train_adata(adata_pert, adata_ctrl, adata_rna_common, cell_line_bulk):
    """Stack K562 perturbed/control cells and target cell-line cells into one AnnData for CPA.

    Groups (index j, condition, cell_type):
      0: adata_pert        -> ('stimulated', 'K562')
      1: adata_rna_common  -> ('stimulated', cell_line_bulk)   held out and predicted
      2: adata_ctrl        -> ('ctrl', 'K562')
      3: adata_rna_common  -> ('ctrl', cell_line_bulk)
    Obs names get a `_<j>` suffix because groups 1 and 3 share barcodes. X is mapped out of
    log1p space with np.expm1, since CPA is set up with is_count_data=True.
    """
    import scipy.sparse

    groups = [
        (adata_pert, 'stimulated', 'K562'),
        (adata_rna_common, 'stimulated', cell_line_bulk),
        (adata_ctrl, 'ctrl', 'K562'),
        (adata_rna_common, 'ctrl', cell_line_bulk),
    ]
    blocks, obs_names, conditions, cell_types = [], [], [], []
    for j, (adata_, condition, cell_type) in enumerate(groups):
        x = adata_.X
        # X may be scipy-sparse or a dense (view of an) ndarray; densify either way.
        blocks.append(x.toarray() if scipy.sparse.issparse(x) else np.asarray(x))
        obs_names.extend(f'{name}_{j}' for name in adata_.obs_names)
        conditions.extend([condition] * adata_.n_obs)
        cell_types.extend([cell_type] * adata_.n_obs)

    adata_train = AnnData(X=np.vstack(blocks))
    adata_train.obs_names = obs_names
    adata_train.var_names = adata_pert.var_names
    adata_train.obs['condition'] = conditions
    adata_train.obs['cell_type'] = cell_types
    adata_train.obs['cov_cond'] = adata_train.obs['cell_type'] + '_' + adata_train.obs['condition']
    # expm1 is the exact inverse of log1p (more accurate than exp(x) - 1 near zero).
    adata_train.X = np.expm1(adata_train.X)
    return adata_train


def _assign_cpa_splits(adata_train, cell_line_bulk, seed=2024, train_frac=0.9):
    """Write obs['split_key']: target-line 'stimulated' rows -> 'ood'; the remaining rows are
    shuffled (np.random.seed(seed)) and split train_frac / rest into 'train' / 'valid'."""
    is_target = ((adata_train.obs['cell_type'] == cell_line_bulk)
                 & (adata_train.obs['condition'] == 'stimulated')).to_numpy()
    # Only obs names are needed; no AnnData copy required.
    candidates = np.array(adata_train.obs_names[~is_target])
    np.random.seed(seed)
    np.random.shuffle(candidates)
    n_train = int(len(candidates) * train_frac)

    adata_train.obs['split_key'] = 'ood'
    adata_train.obs.loc[candidates[:n_train], 'split_key'] = 'train'
    adata_train.obs.loc[candidates[n_train:], 'split_key'] = 'valid'


def _rank_genes_vs_ctrl(adata_rna_common, x_pred, pert):
    """Wilcoxon-rank genes of predicted cells ('pert') against observed cell-line cells ('ctrl').

    Returns (gene names, scores) for the 'pert' group, in scanpy's rank order (rankby_abs=False).
    """
    import anndata as ad

    adata_ctrl = adata_rna_common.copy()
    adata_pert = adata_ctrl.copy()
    adata_pert.X = x_pred
    # Distinct obs names so the two groups can be concatenated.
    adata_pert.obs_names = [f'{name}_{pert}' for name in adata_pert.obs_names]
    adata_ctrl.obs['batch'] = 'ctrl'
    adata_pert.obs['batch'] = 'pert'
    adata_concat = ad.concat([adata_ctrl, adata_pert])

    sc.tl.rank_genes_groups(
        adata_concat,
        groupby='batch',
        reference='ctrl',
        rankby_abs=False,
        n_genes=len(adata_concat.var),
        use_raw=False,
        method='wilcoxon',
    )
    ranking = adata_concat.uns['rank_genes_groups']
    de_genes = pd.DataFrame(ranking['names'])
    scores = pd.DataFrame(ranking['scores'])
    return (list(de_genes['pert']), list(scores['pert']))


def cpa_single_pert(pert):
    """Train CPA for one K562 perturbation, transfer it to `cell_line_bulk`, and rank genes.

    Relies on notebook globals: pert_data, adata_ori, adata_rna, common_var, cell_line_bulk,
    model_params, trainer_params.

    Parameters
    ----------
    pert : str
        Perturbation name; perturbed cells are those with perturbation_group == pert + ' | K562'.

    Returns
    -------
    tuple[list, list]
        (gene names, Wilcoxon scores) of predicted cell-line cells vs. observed cell-line cells.
    """
    # - get adata_pert and adata_ctrl (control cells come from the 'control_barcode' column)
    adata_split = pert_data.adata_split
    is_pert = adata_split.obs['perturbation_group'] == pert + ' | K562'
    pert_cells = adata_split[adata_split[is_pert].obs_names]
    ctrl_cells = adata_split[pert_cells.obs['control_barcode']]

    adata_pert = adata_ori[pert_cells.obs_names, common_var]
    adata_ctrl = adata_ori[ctrl_cells.obs_names, common_var]
    adata_rna_common = adata_rna[:, common_var]

    # - build training data and register it with CPA
    adata_train = _build_cpa_train_adata(adata_pert, adata_ctrl, adata_rna_common, cell_line_bulk)
    cpa.CPA.setup_anndata(adata_train,
                        perturbation_key='condition',
                        control_group='ctrl',
                        categorical_covariate_keys=['cell_type'],
                        is_count_data=True,
                        deg_uns_cat_key='cov_cond',
                        max_comb_len=1,
                        )
    _assign_cpa_splits(adata_train, cell_line_bulk)

    # - initial the model and training
    model = cpa.CPA(adata=adata_train,
                    split_key='split_key',
                    train_split='train',
                    valid_split='valid',
                    test_split='ood',
                    **model_params,
                )
    model.train(max_epochs=2000,
                use_gpu=True,
                batch_size=500,
                plan_kwargs=trainer_params,
                early_stopping_patience=5,
                check_val_every_n_epoch=5,
                progress_bar_refresh_rate = 0
            )

    # - predict and keep the held-out target rows, mapped back to log1p space
    model.predict(adata_train, batch_size=2048)
    cat = cell_line_bulk + '_' + 'stimulated'
    cat_adata = adata_train[adata_train.obs['cov_cond'] == cat].copy()
    x_pred = np.log1p(cat_adata.obsm['CPA_pred'])

    # - rank genes of predicted vs. observed cell-line cells
    return _rank_genes_vs_ctrl(adata_rna_common, x_pred, pert)
    
# Run CPA once per common perturbation (about 20 s each on the GPU per the progress bar).
# Results are kept only in memory, keyed by perturbation name.
pert_gene_rank_dict = dict()
for pert in tqdm(common_perts, desc='cpa_single_pert'):
    genes_and_scores = cpa_single_pert(pert)
    pert_gene_rank_dict[pert] = genes_and_scores


cpa_single_pert:   0%|          | 0/2184 [00:00<?, ?it/s]



100%|██████████| 1364/1364 [00:00<00:00, 86824.36it/s]
100%|██████████| 1364/1364 [00:00<00:00, 1085378.61it/s]
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        


Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 950.98it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.692965367965368
disnt_after = 1.7727272727272727
val_r2_mean = 0.3614012598991394
val_r2_var = -2.6723037560780845



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.6858585858585857
disnt_after = 1.7727272727272727
val_r2_mean = 0.7982175350189209
val_r2_var = -2.674081643422445



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.6983953823953826
disnt_after = 1.7535285285285285
val_r2_mean = 0.9206221302350363
val_r2_var = -2.6612748305002847



Epoch 00039: cpa_metric reached. Module best state updated.



disnt_basal = 1.6958730158730158
disnt_after = 1.7252252252252251
val_r2_mean = 0.9358939925829569
val_r2_var = -2.6237972577412925



Epoch 00044: cpa_metric reached. Module best state updated.



disnt_basal = 1.6902958152958152
disnt_after = 1.7252252252252251
val_r2_mean = 0.9437996745109558
val_r2_var = -2.581672509511312



Epoch 00054: cpa_metric reached. Module best state updated.

Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.6884559884559884
disnt_after = 1.7252252252252251
val_r2_mean = 0.94074547290802
val_r2_var = -2.503347476323446



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.690079365079365
disnt_after = 1.7252252252252251
val_r2_mean = 0.943392833073934
val_r2_var = -2.4379239877065024



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.692965367965368
disnt_after = 1.7252252252252251
val_r2_mean = 0.9436059792836508
val_r2_var = -2.3871254126230874



Epoch 00084: cpa_metric reached. Module best state updated.

Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 1.7011202761202762
disnt_after = 1.7254231504231505
val_r2_mean = 0.9396191239356995
val_r2_var = -2.3239022890726724



Epoch 00094: cpa_metric reached. Module best state updated.

Epoch 00099: cpa_metric reached. Module best state updated.



disnt_basal = 1.5094290784290785
disnt_after = 1.7252252252252251
val_r2_mean = 0.9363454778989156
val_r2_var = -2.2833077112833657

disnt_basal = 1.6963924963924963
disnt_after = 1.7361111111111112
val_r2_mean = 0.9404150446256002
val_r2_var = -2.2283284664154053

disnt_basal = 1.697150072150072
disnt_after = 1.7291837291837293
val_r2_mean = 0.933383027712504
val_r2_var = -2.1212982336680093


100%|██████████| 1/1 [00:00<00:00, 28.24it/s]
cpa_single_pert:   0%|          | 1/2184 [00:20<12:31:30, 20.66s/it]



100%|██████████| 1470/1470 [00:00<00:00, 67029.34it/s]
100%|██████████| 1470/1470 [00:00<00:00, 853610.26it/s]

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        



Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 1022.63it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.6813796296296297
disnt_after = 1.8574074074074074
val_r2_mean = 0.3187953233718872
val_r2_var = -2.5854386488596597



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.6801574074074075
disnt_after = 1.8574074074074074
val_r2_mean = 0.7836709022521973
val_r2_var = -2.5872008005777993



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.6909215797430084
disnt_after = 1.8556712962962965
val_r2_mean = 0.9320425391197205
val_r2_var = -2.5695308844248452



Epoch 00034: cpa_metric reached. Module best state updated.

Epoch 00039: cpa_metric reached. Module best state updated.



disnt_basal = 1.68944992441421
disnt_after = 1.8416666666666668
val_r2_mean = 0.9443625807762146
val_r2_var = -2.5329365730285645



Epoch 00044: cpa_metric reached. Module best state updated.

Epoch 00049: cpa_metric reached. Module best state updated.



disnt_basal = 1.6856018518518519
disnt_after = 1.8365740740740741
val_r2_mean = 0.9429282546043396
val_r2_var = -2.503542105356852



Epoch 00054: cpa_metric reached. Module best state updated.

Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.6885
disnt_after = 1.836342592592593
val_r2_mean = 0.9447444677352905
val_r2_var = -2.41489847501119



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.6822962962962964
disnt_after = 1.836342592592593
val_r2_mean = 0.9443283279736837
val_r2_var = -2.335684299468994



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.6827592592592593
disnt_after = 1.8365740740740741
val_r2_mean = 0.949158251285553
val_r2_var = -2.257342576980591



Epoch 00084: cpa_metric reached. Module best state updated.

Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 1.6853981481481481
disnt_after = 1.8365740740740741
val_r2_mean = 0.9466696381568909
val_r2_var = -2.192651907602946



Epoch 00094: cpa_metric reached. Module best state updated.

Epoch 00099: cpa_metric reached. Module best state updated.



disnt_basal = 1.6900234315948601
disnt_after = 1.8368055555555558
val_r2_mean = 0.9545074303944906
val_r2_var = -2.139136791229248



Epoch 00104: cpa_metric reached. Module best state updated.

Epoch 00109: cpa_metric reached. Module best state updated.



disnt_basal = 1.5248860544217688
disnt_after = 1.837037037037037
val_r2_mean = 0.9515902002652487
val_r2_var = -2.0765206813812256

disnt_basal = 1.6828156651549508
disnt_after = 1.8483796296296298
val_r2_mean = 0.9438965320587158
val_r2_var = -1.9478801091512044

disnt_basal = 1.6793570483749054
disnt_after = 1.8473379629629632
val_r2_mean = 0.9525395234425863
val_r2_var = -1.8502254486083984


100%|██████████| 1/1 [00:00<00:00, 23.48it/s]
cpa_single_pert:   0%|          | 2/2184 [00:35<10:22:02, 17.10s/it]



100%|██████████| 1488/1488 [00:00<00:00, 89263.49it/s]
100%|██████████| 1488/1488 [00:00<00:00, 1046290.75it/s]

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        



Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 683.50it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.6477777777777778
disnt_after = 1.8626666666666667
val_r2_mean = 0.3384040395418803
val_r2_var = -2.5896519819895425



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.6471111111111112
disnt_after = 1.8626666666666667
val_r2_mean = 0.7986541589101156
val_r2_var = -2.5912588437398276



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.6586666666666665
disnt_after = 1.8620770975056689
val_r2_mean = 0.9225162665049235
val_r2_var = -2.5755370457967124



Epoch 00034: cpa_metric reached. Module best state updated.

Epoch 00039: cpa_metric reached. Module best state updated.



disnt_basal = 1.6506928104575163
disnt_after = 1.8458049886621315
val_r2_mean = 0.9421955148379008
val_r2_var = -2.538591464360555



Epoch 00044: cpa_metric reached. Module best state updated.

Epoch 00049: cpa_metric reached. Module best state updated.



disnt_basal = 1.6513333333333333
disnt_after = 1.8434467120181408
val_r2_mean = 0.9326207041740417
val_r2_var = -2.473722060521444



Epoch 00054: cpa_metric reached. Module best state updated.

Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.6557777777777778
disnt_after = 1.8430929705215422
val_r2_mean = 0.9393472274144491
val_r2_var = -2.3826725482940674



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.6562222222222223
disnt_after = 1.8435646258503402
val_r2_mean = 0.9471829930941263
val_r2_var = -2.313684860865275



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.6604444444444444
disnt_after = 1.843918367346939
val_r2_mean = 0.9383940100669861
val_r2_var = -2.2210141817728677



Epoch 00084: cpa_metric reached. Module best state updated.

Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 1.6622222222222223
disnt_after = 1.8432108843537416
val_r2_mean = 0.9576320648193359
val_r2_var = -2.1928062438964844



Epoch 00094: cpa_metric reached. Module best state updated.

Epoch 00099: cpa_metric reached. Module best state updated.



disnt_basal = 1.65956462585034
disnt_after = 1.843328798185941
val_r2_mean = 0.9520739714304606
val_r2_var = -2.057819048563639



Epoch 00104: cpa_metric reached. Module best state updated.



disnt_basal = 1.6408197945845004
disnt_after = 1.8518185941043086
val_r2_mean = 0.9439069231351217
val_r2_var = -2.0358313719431558

disnt_basal = 1.6558020541549954
disnt_after = 1.8493424036281179
val_r2_mean = 0.9403394858042399
val_r2_var = -1.8776757717132568

disnt_basal = 1.6539717220221422
disnt_after = 1.8486349206349206
val_r2_mean = 0.9522263209025065
val_r2_var = -1.806034803390503


100%|██████████| 1/1 [00:00<00:00, 18.30it/s]
cpa_single_pert:   0%|          | 3/2184 [00:46<8:49:34, 14.57s/it] 



100%|██████████| 1402/1402 [00:00<00:00, 65019.34it/s]
100%|██████████| 1402/1402 [00:00<00:00, 769586.99it/s]

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        



Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 907.07it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.670026350461133
disnt_after = 1.7968599033816424
val_r2_mean = 0.31569822629292804
val_r2_var = -2.5409911473592124



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.678425559947299
disnt_after = 1.7968599033816424
val_r2_mean = 0.7781637112299601
val_r2_var = -2.543226639429728



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.6830368906455864
disnt_after = 1.7932884748102138
val_r2_mean = 0.9114195307095846
val_r2_var = -2.5287721951802573



Epoch 00034: cpa_metric reached. Module best state updated.

Epoch 00039: cpa_metric reached. Module best state updated.



disnt_basal = 1.6776021080368908
disnt_after = 1.772015182884748
val_r2_mean = 0.921968142191569
val_r2_var = -2.4915614128112793



Epoch 00044: cpa_metric reached. Module best state updated.



disnt_basal = 1.6776679841897233
disnt_after = 1.7626984126984127
val_r2_mean = 0.9238333503405253
val_r2_var = -2.4576502641042075



Epoch 00054: cpa_metric reached. Module best state updated.

Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.6822793148880106
disnt_after = 1.7626984126984127
val_r2_mean = 0.9359955986340841
val_r2_var = -2.371142864227295



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.6816205533596837
disnt_after = 1.7626984126984127
val_r2_mean = 0.9375585913658142
val_r2_var = -2.3009591897328696



Epoch 00074: cpa_metric reached. Module best state updated.



disnt_basal = 1.6814119455423802
disnt_after = 1.7626984126984127
val_r2_mean = 0.932242751121521
val_r2_var = -2.2698351542154946



Epoch 00084: cpa_metric reached. Module best state updated.

Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 1.684287193138125
disnt_after = 1.7626984126984127
val_r2_mean = 0.944982131322225
val_r2_var = -2.222517728805542



Epoch 00094: cpa_metric reached. Module best state updated.



disnt_basal = 1.662238399971319
disnt_after = 1.7699965493443752
val_r2_mean = 0.9409712354342142
val_r2_var = -2.2184998989105225



Epoch 00109: cpa_metric reached. Module best state updated.



disnt_basal = 1.6856253305011069
disnt_after = 1.7790027605244996
val_r2_mean = 0.946609099706014
val_r2_var = -2.0826830863952637



Epoch 00114: cpa_metric reached. Module best state updated.

Epoch 00119: cpa_metric reached. Module best state updated.



disnt_basal = 1.6794645657999696
disnt_after = 1.7692201518288475
val_r2_mean = 0.9421578248341879
val_r2_var = -1.9916698932647705



Epoch 00124: cpa_metric reached. Module best state updated.

Epoch 00129: cpa_metric reached. Module best state updated.



disnt_basal = 1.640609063124591
disnt_after = 1.76735679779158
val_r2_mean = 0.9432130853335062
val_r2_var = -1.91215976079305



Epoch 00134: cpa_metric reached. Module best state updated.

Epoch 00139: cpa_metric reached. Module best state updated.



disnt_basal = 1.4313635467362176
disnt_after = 1.764095928226363
val_r2_mean = 0.9477550188700358
val_r2_var = -1.8773539066314697



Epoch 00144: cpa_metric reached. Module best state updated.



disnt_basal = 1.6065638640172804
disnt_after = 1.7628536922015183
val_r2_mean = 0.9475875298182169
val_r2_var = -1.8754913806915283

disnt_basal = 1.6605460102354512
disnt_after = 1.7626984126984127
val_r2_mean = 0.94826207558314
val_r2_var = -1.713260014851888

disnt_basal = 1.5794316277235532
disnt_after = 1.762853692201518
val_r2_mean = 0.95298304160436
val_r2_var = -1.6754644711812336


100%|██████████| 1/1 [00:00<00:00, 31.55it/s]
cpa_single_pert:   0%|          | 4/2184 [01:04<9:26:55, 15.60s/it]



100%|██████████| 1262/1262 [00:00<00:00, 85225.92it/s]
100%|██████████| 1262/1262 [00:00<00:00, 1033830.40it/s]

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        



Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 914.89it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.6317460317460317
disnt_after = 1.6317460317460317
val_r2_mean = 0.29072050253550213
val_r2_var = -2.3359766006469727



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.6317460317460317
disnt_after = 1.6317460317460317
val_r2_mean = 0.7589184840520223
val_r2_var = -2.3367292881011963



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.6245608465608465
disnt_after = 1.6307539682539682
val_r2_mean = 0.8955092827479044
val_r2_var = -2.3194520473480225



Epoch 00034: cpa_metric reached. Module best state updated.



disnt_basal = 1.6285502645502645
disnt_after = 1.6049603174603175
val_r2_mean = 0.9067819118499756
val_r2_var = -2.285635232925415

disnt_basal = 1.6317460317460317
disnt_after = 1.5708994708994708
val_r2_mean = 0.9052253762880961
val_r2_var = -2.228798786799113



Epoch 00054: cpa_metric reached. Module best state updated.

Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.6317460317460317
disnt_after = 1.5755291005291006
val_r2_mean = 0.9135512709617615
val_r2_var = -2.168065071105957



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.6305502645502645
disnt_after = 1.5672619047619047
val_r2_mean = 0.9236167470614115
val_r2_var = -2.1182024478912354



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.5748015873015873
disnt_after = 1.5672619047619047
val_r2_mean = 0.9235551953315735
val_r2_var = -2.0444088776906333



Epoch 00084: cpa_metric reached. Module best state updated.



disnt_basal = 1.622116402116402
disnt_after = 1.5722222222222222
val_r2_mean = 0.926965336004893
val_r2_var = -1.9594817161560059

disnt_basal = 1.5960952380952382
disnt_after = 1.5781746031746031
val_r2_mean = 0.9303377866744995
val_r2_var = -1.8846064408620198



Epoch 00104: cpa_metric reached. Module best state updated.

Epoch 00109: cpa_metric reached. Module best state updated.



disnt_basal = 1.3418068783068782
disnt_after = 1.5692460317460317
val_r2_mean = 0.9265432953834534
val_r2_var = -1.8160611788431804



Epoch 00114: cpa_metric reached. Module best state updated.



disnt_basal = 1.499473544973545
disnt_after = 1.5715608465608466
val_r2_mean = 0.9194307327270508
val_r2_var = -1.7889052232106526

disnt_basal = 1.465666666666667
disnt_after = 1.5689153439153438
val_r2_mean = 0.9342906475067139
val_r2_var = -1.7011624177296956



Epoch 00139: cpa_metric reached. Module best state updated.



disnt_basal = 1.259994708994709
disnt_after = 1.570568783068783
val_r2_mean = 0.9310408234596252
val_r2_var = -1.6343406836191814



Epoch 00144: cpa_metric reached. Module best state updated.

Epoch 00149: cpa_metric reached. Module best state updated.



disnt_basal = 1.1498703703703703
disnt_after = 1.570568783068783
val_r2_mean = 0.9310978055000305
val_r2_var = -1.5516708691914876

disnt_basal = 1.2221349206349206
disnt_after = 1.5718915343915343
val_r2_mean = 0.9407820502916971
val_r2_var = -1.4956664244333904



Epoch 00169: cpa_metric reached. Module best state updated.



disnt_basal = 1.1193835978835978
disnt_after = 1.5712301587301587
val_r2_mean = 0.9275315205256144
val_r2_var = -1.408171017964681



Epoch 00174: cpa_metric reached. Module best state updated.

Epoch 00179: cpa_metric reached. Module best state updated.



disnt_basal = 1.0629550264550267
disnt_after = 1.5689153439153438
val_r2_mean = 0.9392244021097819
val_r2_var = -1.3721817334493

disnt_basal = 1.1339497354497352
disnt_after = 1.5708994708994708
val_r2_mean = 0.9430321455001831
val_r2_var = -1.3701399167378743



Epoch 00199: cpa_metric reached. Module best state updated.



disnt_basal = 1.07484126984127
disnt_after = 1.5702380952380952
val_r2_mean = 0.9453375736872355
val_r2_var = -1.327462116877238



Epoch 00204: cpa_metric reached. Module best state updated.



disnt_basal = 1.0827989417989419
disnt_after = 1.5682539682539682
val_r2_mean = 0.9398440917332967
val_r2_var = -1.2666839758555095

disnt_basal = 1.138399470899471
disnt_after = 1.5699074074074073
val_r2_mean = 0.9429340163866679
val_r2_var = -1.2378525733947754

disnt_basal = 1.1235079365079366
disnt_after = 1.5695767195767196
val_r2_mean = 0.9431518117586771
val_r2_var = -1.1939315001169841


100%|██████████| 1/1 [00:00<00:00, 33.78it/s]
cpa_single_pert:   0%|          | 5/2184 [01:22<10:06:22, 16.70s/it]



100%|██████████| 1340/1340 [00:00<00:00, 76293.20it/s]
100%|██████████| 1340/1340 [00:00<00:00, 1032017.51it/s]

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        



Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 858.78it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.7374358974358974
disnt_after = 1.74
val_r2_mean = 0.27734172344207764
val_r2_var = -2.363736867904663



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.736302294197031
disnt_after = 1.74
val_r2_mean = 0.75567626953125
val_r2_var = -2.3636085987091064



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.7366666666666668
disnt_after = 1.7373557692307693
val_r2_mean = 0.9032248258590698
val_r2_var = -2.34852663675944



Epoch 00034: cpa_metric reached. Module best state updated.



disnt_basal = 1.7366666666666666
disnt_after = 1.6971634615384614
val_r2_mean = 0.9157938361167908
val_r2_var = -2.319752852121989

disnt_basal = 1.7366666666666666
disnt_after = 1.678125
val_r2_mean = 0.8990533351898193
val_r2_var = -2.2628565629323325



Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.7366666666666666
disnt_after = 1.678389423076923
val_r2_mean = 0.9106890161832174
val_r2_var = -2.2021677494049072



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.7366666666666666
disnt_after = 1.678125
val_r2_mean = 0.9221992691357931
val_r2_var = -2.14946715037028



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.7328006916329286
disnt_after = 1.678125
val_r2_mean = 0.9263072808583578
val_r2_var = -2.0821568965911865



Epoch 00084: cpa_metric reached. Module best state updated.

Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 1.6201581477732794
disnt_after = 1.678125
val_r2_mean = 0.9360714753468832
val_r2_var = -2.011984427769979



Epoch 00094: cpa_metric reached. Module best state updated.



disnt_basal = 1.709926197705803
disnt_after = 1.704831730769231
val_r2_mean = 0.9276338815689087
val_r2_var = -1.977919340133667

disnt_basal = 1.720193572874494
disnt_after = 1.7080048076923078
val_r2_mean = 0.9380089243253072
val_r2_var = -1.8860702514648438

disnt_basal = 1.6952749662618083
disnt_after = 1.6961057692307693
val_r2_mean = 0.9393298029899597
val_r2_var = -1.8110002676645915


100%|██████████| 1/1 [00:00<00:00, 22.36it/s]
cpa_single_pert:   0%|          | 6/2184 [01:33<8:55:06, 14.74s/it] 



100%|██████████| 1612/1612 [00:00<00:00, 81036.72it/s]
100%|██████████| 1612/1612 [00:00<00:00, 1031302.33it/s]

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        



Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 912.60it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.6406432748538011
disnt_after = 2.0
val_r2_mean = 0.3359534740447998
val_r2_var = -2.7536628246307373



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.634551656920078
disnt_after = 2.0
val_r2_mean = 0.7977495590845743
val_r2_var = -2.7555975119272866



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.6326842991316675
disnt_after = 2.0
val_r2_mean = 0.9374423623085022
val_r2_var = -2.7415647506713867



Epoch 00034: cpa_metric reached. Module best state updated.

Epoch 00039: cpa_metric reached. Module best state updated.



disnt_basal = 1.6175948077263866
disnt_after = 2.0
val_r2_mean = 0.9517888824144999
val_r2_var = -2.706390857696533



Epoch 00044: cpa_metric reached. Module best state updated.



disnt_basal = 1.635840864788233
disnt_after = 2.0
val_r2_mean = 0.954476793607076
val_r2_var = -2.6405778725941977



Epoch 00054: cpa_metric reached. Module best state updated.

Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.6354775828460038
disnt_after = 2.0
val_r2_mean = 0.9521381060282389
val_r2_var = -2.5698916912078857



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.6397417153996101
disnt_after = 2.0
val_r2_mean = 0.9549921154975891
val_r2_var = -2.5041064421335855



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.64208089668616
disnt_after = 2.0
val_r2_mean = 0.9605269829432169
val_r2_var = -2.4390904108683267



Epoch 00084: cpa_metric reached. Module best state updated.

Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 1.6419812156654263
disnt_after = 2.0
val_r2_mean = 0.9617913762728373
val_r2_var = -2.3690088589986167



Epoch 00094: cpa_metric reached. Module best state updated.

Epoch 00099: cpa_metric reached. Module best state updated.



disnt_basal = 1.6337674995569733
disnt_after = 2.0
val_r2_mean = 0.9555482864379883
val_r2_var = -2.291666030883789



Epoch 00104: cpa_metric reached. Module best state updated.

Epoch 00109: cpa_metric reached. Module best state updated.



disnt_basal = 1.6137452960982375
disnt_after = 2.0
val_r2_mean = 0.9597216447194418
val_r2_var = -2.227461576461792



Epoch 00114: cpa_metric reached. Module best state updated.



disnt_basal = 1.5978904108161074
disnt_after = 2.0
val_r2_mean = 0.9572152098019918
val_r2_var = -2.1818639437357583

disnt_basal = 1.6296517809675706
disnt_after = 2.0
val_r2_mean = 0.9507816831270853
val_r2_var = -2.0543415546417236

disnt_basal = 1.6305334042176147
disnt_after = 2.0
val_r2_mean = 0.9551775058110555
val_r2_var = -1.9355225563049316


100%|██████████| 1/1 [00:00<00:00, 23.55it/s]
cpa_single_pert:   0%|          | 7/2184 [01:46<8:33:44, 14.16s/it]



100%|██████████| 1414/1414 [00:00<00:00, 85499.32it/s]
100%|██████████| 1414/1414 [00:00<00:00, 1081660.74it/s]

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        



Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 416.60it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.6862111801242237
disnt_after = 1.8166666666666667
val_r2_mean = 0.33938004573186237
val_r2_var = -2.6993443965911865



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.6692132505175983
disnt_after = 1.8166666666666664
val_r2_mean = 0.7934506336847941
val_r2_var = -2.7002137502034507



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.6732091097308488
disnt_after = 1.8163852813852812
val_r2_mean = 0.9307314356168112
val_r2_var = -2.6866437594095864



Epoch 00034: cpa_metric reached. Module best state updated.



disnt_basal = 1.684337474120083
disnt_after = 1.8014718614718617
val_r2_mean = 0.9368169903755188
val_r2_var = -2.6594194571177163

disnt_basal = 1.6845341614906832
disnt_after = 1.78754329004329
val_r2_mean = 0.9418168266614279
val_r2_var = -2.6139701207478843



Epoch 00054: cpa_metric reached. Module best state updated.

Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.6903623188405796
disnt_after = 1.7876839826839825
val_r2_mean = 0.9462791681289673
val_r2_var = -2.5552965005238852



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.6923084886128363
disnt_after = 1.7881060606060604
val_r2_mean = 0.9446269075075785
val_r2_var = -2.482820908228556



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.6949482401656315
disnt_after = 1.7871212121212119
val_r2_mean = 0.9531773527463278
val_r2_var = -2.4027557373046875



Epoch 00084: cpa_metric reached. Module best state updated.

Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 1.6951449275362318
disnt_after = 1.787121212121212
val_r2_mean = 0.941952665646871
val_r2_var = -2.308884938557943



Epoch 00094: cpa_metric reached. Module best state updated.

Epoch 00099: cpa_metric reached. Module best state updated.



disnt_basal = 1.691181603076013
disnt_after = 1.787121212121212
val_r2_mean = 0.9482053319613138
val_r2_var = -2.228065808614095



Epoch 00104: cpa_metric reached. Module best state updated.

Epoch 00109: cpa_metric reached. Module best state updated.



disnt_basal = 1.5243392890753138
disnt_after = 1.791904761904762
val_r2_mean = 0.9448910554250082
val_r2_var = -2.211301247278849

disnt_basal = 1.6924534161490683
disnt_after = 1.7957034632034632
val_r2_mean = 0.9455838600794474
val_r2_var = -2.082740068435669

disnt_basal = 1.6893921916592725
disnt_after = 1.7955627705627706
val_r2_mean = 0.954897940158844
val_r2_var = -1.9851163228352864


100%|██████████| 1/1 [00:00<00:00, 33.04it/s]
cpa_single_pert:   0%|          | 8/2184 [01:59<8:21:57, 13.84s/it]



100%|██████████| 1392/1392 [00:00<00:00, 66089.42it/s]
100%|██████████| 1392/1392 [00:00<00:00, 991588.17it/s]

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        



Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 826.71it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.6549242424242425
disnt_after = 1.7960784313725489
val_r2_mean = 0.35773982604344684
val_r2_var = -2.6311829884847007



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.6647504456327986
disnt_after = 1.7960784313725489
val_r2_mean = 0.8033348123232523
val_r2_var = -2.6335842609405518



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.670254010695187
disnt_after = 1.781360877684407
val_r2_mean = 0.9303862651189169
val_r2_var = -2.6208993593851724

disnt_basal = 1.680557412358883
disnt_after = 1.7626984126984127
val_r2_mean = 0.9419266780217489
val_r2_var = -2.588535944620768



Epoch 00044: cpa_metric reached. Module best state updated.

Epoch 00049: cpa_metric reached. Module best state updated.



disnt_basal = 1.6761586452762924
disnt_after = 1.7626984126984127
val_r2_mean = 0.9447603623072306
val_r2_var = -2.547029892603556



Epoch 00054: cpa_metric reached. Module best state updated.

Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.6729055258467023
disnt_after = 1.7631535947712416
val_r2_mean = 0.9464993476867676
val_r2_var = -2.4679462909698486



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.672860962566845
disnt_after = 1.7652777777777777
val_r2_mean = 0.9414335489273071
val_r2_var = -2.3818023999532065



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.6700311942959
disnt_after = 1.7657329598506069
val_r2_mean = 0.9423782030741373
val_r2_var = -2.302607456843058



Epoch 00084: cpa_metric reached. Module best state updated.

Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 1.6290446481622953
disnt_after = 1.7646708683473389
val_r2_mean = 0.943507969379425
val_r2_var = -2.2230656941731772



Epoch 00094: cpa_metric reached. Module best state updated.



disnt_basal = 1.5626997389865038
disnt_after = 1.7645191409897292
val_r2_mean = 0.9428917169570923
val_r2_var = -2.237464984258016

disnt_basal = 1.627375116713352
disnt_after = 1.765581232492997
val_r2_mean = 0.9405388832092285
val_r2_var = -2.062814394632975

disnt_basal = 1.608384793311264
disnt_after = 1.7660364145658263
val_r2_mean = 0.9371350208918253
val_r2_var = -1.965319315592448


100%|██████████| 1/1 [00:00<00:00, 32.03it/s]
cpa_single_pert:   0%|          | 9/2184 [02:13<8:16:40, 13.70s/it]



100%|██████████| 1434/1434 [00:00<00:00, 77416.36it/s]
100%|██████████| 1434/1434 [00:00<00:00, 1101782.73it/s]

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        



Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 913.69it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.6714431516636048
disnt_after = 1.82018779342723
val_r2_mean = 0.34310275316238403
val_r2_var = -2.6182135740915933



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.6568993672177994
disnt_after = 1.82018779342723
val_r2_mean = 0.7905231316884359
val_r2_var = -2.6187607447306314



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.624171113638352
disnt_after = 1.7975378195096505
val_r2_mean = 0.9321639537811279
val_r2_var = -2.5971460342407227



Epoch 00034: cpa_metric reached. Module best state updated.

Epoch 00039: cpa_metric reached. Module best state updated.



disnt_basal = 1.6318180303352203
disnt_after = 1.7933333333333332
val_r2_mean = 0.9472194910049438
val_r2_var = -2.556496540705363



Epoch 00049: cpa_metric reached. Module best state updated.



disnt_basal = 1.6443118337353313
disnt_after = 1.7933333333333334
val_r2_mean = 0.9454571604728699
val_r2_var = -2.501921812693278



Epoch 00054: cpa_metric reached. Module best state updated.

Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.6417125944070219
disnt_after = 1.7933333333333334
val_r2_mean = 0.9489310185114542
val_r2_var = -2.4176193873087564



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.6404539035454966
disnt_after = 1.7933333333333334
val_r2_mean = 0.9514341553052267
val_r2_var = -2.349461555480957



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.6461458190134595
disnt_after = 1.7933333333333334
val_r2_mean = 0.9540879527727762
val_r2_var = -2.3033103148142495



Epoch 00084: cpa_metric reached. Module best state updated.

Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 1.641886724071135
disnt_after = 1.7933333333333334
val_r2_mean = 0.9547244906425476
val_r2_var = -2.2316083113352456



Epoch 00094: cpa_metric reached. Module best state updated.

Epoch 00099: cpa_metric reached. Module best state updated.



disnt_basal = 1.6454453428646414
disnt_after = 1.7933333333333334
val_r2_mean = 0.9562251170476278
val_r2_var = -2.1801846027374268



Epoch 00104: cpa_metric reached. Module best state updated.

Epoch 00109: cpa_metric reached. Module best state updated.



disnt_basal = 1.5639976875284372
disnt_after = 1.7933333333333334
val_r2_mean = 0.9603047172228495
val_r2_var = -2.0978084405263266

disnt_basal = 1.6475913451724842
disnt_after = 1.8163901930099113
val_r2_mean = 0.9537652532259623
val_r2_var = -1.9823774496714275



Epoch 00124: cpa_metric reached. Module best state updated.

Epoch 00129: cpa_metric reached. Module best state updated.



disnt_basal = 1.6444172280057154
disnt_after = 1.8146270213875848
val_r2_mean = 0.9512648781140646
val_r2_var = -1.885379155476888



Epoch 00134: cpa_metric reached. Module best state updated.

Epoch 00139: cpa_metric reached. Module best state updated.



disnt_basal = 1.6469348180178378
disnt_after = 1.8121857068335943
val_r2_mean = 0.9513763189315796
val_r2_var = -1.802688757578532



Epoch 00144: cpa_metric reached. Module best state updated.



disnt_basal = 1.6377330233991945
disnt_after = 1.8067605633802817
val_r2_mean = 0.9618248144785563
val_r2_var = -1.7748138109842937



Epoch 00154: cpa_metric reached. Module best state updated.

Epoch 00159: cpa_metric reached. Module best state updated.



disnt_basal = 1.5982316026362917
disnt_after = 1.7994366197183098
val_r2_mean = 0.9607738057772318
val_r2_var = -1.6992289225260417



Epoch 00164: cpa_metric reached. Module best state updated.

Epoch 00169: cpa_metric reached. Module best state updated.



disnt_basal = 1.319843621193459
disnt_after = 1.793740219092332
val_r2_mean = 0.9567431410153707
val_r2_var = -1.6562615235646565

disnt_basal = 1.5935783645975232
disnt_after = 1.7933333333333334
val_r2_mean = 0.9548410971959432
val_r2_var = -1.5239485104878743

disnt_basal = 1.6094270754501705
disnt_after = 1.7933333333333334
val_r2_mean = 0.9580127000808716
val_r2_var = -1.5137157440185547


100%|██████████| 1/1 [00:00<00:00, 22.08it/s]
cpa_single_pert:   0%|          | 10/2184 [02:32<9:23:30, 15.55s/it]



100%|██████████| 1520/1520 [00:00<00:00, 80132.50it/s]
100%|██████████| 1520/1520 [00:00<00:00, 1129378.58it/s]

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        



Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 894.31it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.6715250965250965
disnt_after = 1.912162162162162
val_r2_mean = 0.33875419696172077
val_r2_var = -2.709535837173462



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.6763030888030888
disnt_after = 1.912162162162162
val_r2_mean = 0.7972262899080912
val_r2_var = -2.712682088216146



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.6616518661518662
disnt_after = 1.9099705474705475
val_r2_mean = 0.9323592583338419
val_r2_var = -2.699217955271403



Epoch 00034: cpa_metric reached. Module best state updated.



disnt_basal = 1.6764227799227798
disnt_after = 1.9041580041580042
val_r2_mean = 0.93873530626297
val_r2_var = -2.6613501707712808



Epoch 00044: cpa_metric reached. Module best state updated.

Epoch 00049: cpa_metric reached. Module best state updated.



disnt_basal = 1.669550193050193
disnt_after = 1.896153846153846
val_r2_mean = 0.947451651096344
val_r2_var = -2.6237974166870117



Epoch 00054: cpa_metric reached. Module best state updated.

Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.663377734877735
disnt_after = 1.896153846153846
val_r2_mean = 0.9511892795562744
val_r2_var = -2.5485036373138428



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.6622265122265123
disnt_after = 1.8963444213444212
val_r2_mean = 0.9548285206158956
val_r2_var = -2.49703311920166



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.6647245817245817
disnt_after = 1.8962491337491336
val_r2_mean = 0.9503642320632935
val_r2_var = -2.4225684801737466



Epoch 00084: cpa_metric reached. Module best state updated.

Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 1.669362290862291
disnt_after = 1.896153846153846
val_r2_mean = 0.9598462184270223
val_r2_var = -2.3571643034617105



Epoch 00094: cpa_metric reached. Module best state updated.

Epoch 00099: cpa_metric reached. Module best state updated.



disnt_basal = 1.6681737451737453
disnt_after = 1.8962491337491336
val_r2_mean = 0.961110790570577
val_r2_var = -2.282841761906942



Epoch 00104: cpa_metric reached. Module best state updated.

Epoch 00109: cpa_metric reached. Module best state updated.



disnt_basal = 1.6497419067419068
disnt_after = 1.896153846153846
val_r2_mean = 0.9566459059715271
val_r2_var = -2.1881515979766846



Epoch 00114: cpa_metric reached. Module best state updated.



disnt_basal = 1.664461191961192
disnt_after = 1.9016805266805266
val_r2_mean = 0.9592636227607727
val_r2_var = -2.1870513757069907

disnt_basal = 1.6825990495990495
disnt_after = 1.9061590436590437
val_r2_mean = 0.9551216761271158
val_r2_var = -2.0008467038472495

disnt_basal = 1.67994653994654
disnt_after = 1.9054920304920304
val_r2_mean = 0.9567233721415201
val_r2_var = -1.8855384190877278


100%|██████████| 1/1 [00:00<00:00, 18.48it/s]
cpa_single_pert:   1%|          | 11/2184 [02:46<9:06:30, 15.09s/it]



100%|██████████| 1458/1458 [00:00<00:00, 83725.29it/s]
100%|██████████| 1458/1458 [00:00<00:00, 853733.80it/s]

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        



Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 881.99it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.6645833333333333
disnt_after = 1.8375000000000001
val_r2_mean = 0.2879616419474284
val_r2_var = -2.350093682607015



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.6550925925925926
disnt_after = 1.8375000000000001
val_r2_mean = 0.7695302168528239
val_r2_var = -2.351834853490194



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.6450648148148148
disnt_after = 1.831219806763285
val_r2_mean = 0.9181086818377177
val_r2_var = -2.3357690970102944



Epoch 00034: cpa_metric reached. Module best state updated.

Epoch 00039: cpa_metric reached. Module best state updated.



disnt_basal = 1.6414074074074074
disnt_after = 1.8131642512077295
val_r2_mean = 0.938613216082255
val_r2_var = -2.294973373413086

disnt_basal = 1.650425925925926
disnt_after = 1.8115942028985508
val_r2_mean = 0.9287051558494568
val_r2_var = -2.2685186862945557



Epoch 00054: cpa_metric reached. Module best state updated.

Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.65
disnt_after = 1.811725040257649
val_r2_mean = 0.9312011003494263
val_r2_var = -2.196088949839274



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.6569444444444446
disnt_after = 1.8115942028985508
val_r2_mean = 0.9396840731302897
val_r2_var = -2.1393309434254966



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.6585648148148149
disnt_after = 1.8115942028985508
val_r2_mean = 0.943692147731781
val_r2_var = -2.0860716501871743



Epoch 00084: cpa_metric reached. Module best state updated.

Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 1.6551932367149758
disnt_after = 1.8115942028985508
val_r2_mean = 0.9499649604161581
val_r2_var = -2.0334231853485107



Epoch 00094: cpa_metric reached. Module best state updated.

Epoch 00099: cpa_metric reached. Module best state updated.



disnt_basal = 1.6418341384863124
disnt_after = 1.811725040257649
val_r2_mean = 0.9535052378972372
val_r2_var = -1.9587825934092205



Epoch 00104: cpa_metric reached. Module best state updated.

Epoch 00109: cpa_metric reached. Module best state updated.



disnt_basal = 1.3687946859903382
disnt_after = 1.8143417874396137
val_r2_mean = 0.9509349664052328
val_r2_var = -1.917322079340617

disnt_basal = 1.6541453301127214
disnt_after = 1.8280797101449275
val_r2_mean = 0.9467145005861918
val_r2_var = -1.7950173219045003

disnt_basal = 1.6505712560386474
disnt_after = 1.825462962962963
val_r2_mean = 0.9574272036552429
val_r2_var = -1.7075150807698567


100%|██████████| 1/1 [00:00<00:00, 26.54it/s]
cpa_single_pert:   1%|          | 12/2184 [02:59<8:43:10, 14.45s/it]



100%|██████████| 1590/1590 [00:00<00:00, 82295.04it/s]
100%|██████████| 1590/1590 [00:00<00:00, 1102851.56it/s]

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        



Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 889.09it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.6614920634920636
disnt_after = 1.9813333333333334
val_r2_mean = 0.34738367795944214
val_r2_var = -2.618999640146891



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.6592698412698412
disnt_after = 1.9813333333333334
val_r2_mean = 0.7981757322947184
val_r2_var = -2.621148109436035



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.6595184059439378
disnt_after = 1.9813333333333334
val_r2_mean = 0.9346696337064108
val_r2_var = -2.605388561884562



Epoch 00034: cpa_metric reached. Module best state updated.

Epoch 00039: cpa_metric reached. Module best state updated.



disnt_basal = 1.657047619047619
disnt_after = 1.9787513227513227
val_r2_mean = 0.9481182893117269
val_r2_var = -2.571783701578776



Epoch 00044: cpa_metric reached. Module best state updated.

Epoch 00049: cpa_metric reached. Module best state updated.



disnt_basal = 1.654984126984127
disnt_after = 1.977862433862434
val_r2_mean = 0.9530457655588785
val_r2_var = -2.525882085164388



Epoch 00054: cpa_metric reached. Module best state updated.

Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.6625396825396825
disnt_after = 1.977904761904762
val_r2_mean = 0.9560842712720236
val_r2_var = -2.4607207775115967



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.6643492063492062
disnt_after = 1.977904761904762
val_r2_mean = 0.9556785225868225
val_r2_var = -2.372267166773478



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.663015873015873
disnt_after = 1.977904761904762
val_r2_mean = 0.9623288909594218
val_r2_var = -2.2979396184285483



Epoch 00084: cpa_metric reached. Module best state updated.

Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 1.6631111111111112
disnt_after = 1.977904761904762
val_r2_mean = 0.9615311423937479
val_r2_var = -2.2501818339029946



Epoch 00094: cpa_metric reached. Module best state updated.

Epoch 00099: cpa_metric reached. Module best state updated.



disnt_basal = 1.663873015873016
disnt_after = 1.977904761904762
val_r2_mean = 0.9655895829200745
val_r2_var = -2.171518246332804



Epoch 00104: cpa_metric reached. Module best state updated.

Epoch 00109: cpa_metric reached. Module best state updated.



disnt_basal = 1.6589730946752224
disnt_after = 1.977820105820106
val_r2_mean = 0.957540770371755
val_r2_var = -2.0861969788869223



Epoch 00114: cpa_metric reached. Module best state updated.

Epoch 00119: cpa_metric reached. Module best state updated.



disnt_basal = 1.4498131261961047
disnt_after = 1.9784550264550265
val_r2_mean = 0.9602910081545512
val_r2_var = -2.0192414919535318

disnt_basal = 1.6549253630530227
disnt_after = 1.980952380952381
val_r2_mean = 0.9654593467712402
val_r2_var = -1.9430538813273113

disnt_basal = 1.6557028031070584
disnt_after = 1.9806560846560848
val_r2_mean = 0.9653578003247579
val_r2_var = -1.8459004561106365


100%|██████████| 1/1 [00:00<00:00, 30.50it/s]
cpa_single_pert:   1%|          | 13/2184 [03:12<8:21:31, 13.86s/it]



100%|██████████| 1452/1452 [00:00<00:00, 84329.80it/s]
100%|██████████| 1452/1452 [00:00<00:00, 1111540.32it/s]

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        



Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 426.90it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.6618225134008338
disnt_after = 1.8246575342465752
val_r2_mean = 0.33392155170440674
val_r2_var = -2.5672881603240967



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.6570875521143538
disnt_after = 1.8246575342465752
val_r2_mean = 0.7898481289545695
val_r2_var = -2.5704508622487388



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.6381562437959103
disnt_after = 1.8190291840381179
val_r2_mean = 0.927831212679545
val_r2_var = -2.5530714193979898



Epoch 00034: cpa_metric reached. Module best state updated.

Epoch 00039: cpa_metric reached. Module best state updated.



disnt_basal = 1.6373845542981935
disnt_after = 1.8006700416914831
val_r2_mean = 0.9420348207155863
val_r2_var = -2.5124780337015786



Epoch 00044: cpa_metric reached. Module best state updated.



disnt_basal = 1.6502755608497122
disnt_after = 1.7999999999999998
val_r2_mean = 0.9394235809644064
val_r2_var = -2.4586164156595864



Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.6480841770895376
disnt_after = 1.7999999999999998
val_r2_mean = 0.94125896692276
val_r2_var = -2.389039675394694



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.6476970418900139
disnt_after = 1.7999999999999998
val_r2_mean = 0.9432112177213033
val_r2_var = -2.334916591644287



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.6503672821123685
disnt_after = 1.7999999999999998
val_r2_mean = 0.94239874680837
val_r2_var = -2.295102914174398



Epoch 00084: cpa_metric reached. Module best state updated.

Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 1.6570875521143538
disnt_after = 1.7999999999999998
val_r2_mean = 0.9416929483413696
val_r2_var = -2.2311163743336997



Epoch 00094: cpa_metric reached. Module best state updated.

Epoch 00099: cpa_metric reached. Module best state updated.



disnt_basal = 1.6502467738733375
disnt_after = 1.7999999999999998
val_r2_mean = 0.9483636617660522
val_r2_var = -2.1465214093526206



Epoch 00104: cpa_metric reached. Module best state updated.

Epoch 00109: cpa_metric reached. Module best state updated.



disnt_basal = 1.6354816358943816
disnt_after = 1.7999999999999998
val_r2_mean = 0.9543981949488322
val_r2_var = -2.1001929442087808



Epoch 00114: cpa_metric reached. Module best state updated.

Epoch 00119: cpa_metric reached. Module best state updated.



disnt_basal = 1.3762898550724638
disnt_after = 1.8002680166765932
val_r2_mean = 0.9541044433911642
val_r2_var = -2.06018058458964

disnt_basal = 1.6544867976970419
disnt_after = 1.8184931506849313
val_r2_mean = 0.9544829924901327
val_r2_var = -1.999021848042806

disnt_basal = 1.6547548143736353
disnt_after = 1.8170190589636688
val_r2_mean = 0.9533861875534058
val_r2_var = -1.874489386876424


100%|██████████| 1/1 [00:00<00:00, 27.37it/s]
cpa_single_pert:   1%|          | 14/2184 [03:27<8:36:40, 14.29s/it]



100%|██████████| 1920/1920 [00:00<00:00, 78699.30it/s]
100%|██████████| 1920/1920 [00:00<00:00, 1063672.39it/s]

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        



Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 1102.89it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.6292446941323346
disnt_after = 2.0
val_r2_mean = 0.5729731122652689
val_r2_var = -2.726543744405111



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.6276273581634069
disnt_after = 2.0
val_r2_mean = 0.9450042247772217
val_r2_var = -2.7167762915293374



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.6032598141212375
disnt_after = 2.0
val_r2_mean = 0.964261015256246
val_r2_var = -2.6684793631235757



Epoch 00034: cpa_metric reached. Module best state updated.

Epoch 00039: cpa_metric reached. Module best state updated.



disnt_basal = 1.5893813288944376
disnt_after = 2.0
val_r2_mean = 0.9642144640286764
val_r2_var = -2.6173930962880454



Epoch 00044: cpa_metric reached. Module best state updated.



disnt_basal = 1.591850213731374
disnt_after = 2.0
val_r2_mean = 0.9673049251238505
val_r2_var = -2.5885046323140464

disnt_basal = 1.6304801290054098
disnt_after = 2.0
val_r2_mean = 0.9552059570948283
val_r2_var = -2.4285077253977456



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.5993285713764227
disnt_after = 2.0
val_r2_mean = 0.9613840381304423
val_r2_var = -2.3289989630381265



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.6142408794562353
disnt_after = 2.0
val_r2_mean = 0.9618838429450989
val_r2_var = -2.1682868798573813



Epoch 00084: cpa_metric reached. Module best state updated.

Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 1.5958482269969556
disnt_after = 2.0
val_r2_mean = 0.9558196266492208
val_r2_var = -2.0754186312357583



Epoch 00094: cpa_metric reached. Module best state updated.

Epoch 00099: cpa_metric reached. Module best state updated.



disnt_basal = 1.5358963484971273
disnt_after = 2.0
val_r2_mean = 0.9640516638755798
val_r2_var = -1.979591687520345



Epoch 00104: cpa_metric reached. Module best state updated.



disnt_basal = 1.5386207527505822
disnt_after = 2.0
val_r2_mean = 0.9645289977391561
val_r2_var = -1.9092304706573486

disnt_basal = 1.4618477360171132
disnt_after = 2.0
val_r2_mean = 0.9734293421109518
val_r2_var = -1.8591898282368977



Epoch 00124: cpa_metric reached. Module best state updated.



disnt_basal = 1.522600080491491
disnt_after = 2.0
val_r2_mean = 0.9673162500063578
val_r2_var = -1.7868017355600994



Epoch 00139: cpa_metric reached. Module best state updated.



disnt_basal = 1.2073724775682444
disnt_after = 2.0
val_r2_mean = 0.9729268153508505
val_r2_var = -1.7348755995432537

disnt_basal = 1.4396647328631609
disnt_after = 2.0
val_r2_mean = 0.972183624903361
val_r2_var = -1.6552414894104004



Epoch 00159: cpa_metric reached. Module best state updated.



disnt_basal = 1.1130238015711353
disnt_after = 2.0
val_r2_mean = 0.9770786960919698
val_r2_var = -1.654726505279541

disnt_basal = 1.2529832911711238
disnt_after = 2.0
val_r2_mean = 0.9752334753672282
val_r2_var = -1.620368480682373



Epoch 00179: cpa_metric reached. Module best state updated.



disnt_basal = 1.1068259020157551
disnt_after = 2.0
val_r2_mean = 0.9767802556355795
val_r2_var = -1.559661070505778

disnt_basal = 1.2305442484175482
disnt_after = 2.0
val_r2_mean = 0.9765205780665079
val_r2_var = -1.526133934656779



Epoch 00199: cpa_metric reached. Module best state updated.



disnt_basal = 1.0752894362675312
disnt_after = 2.0
val_r2_mean = 0.9704486926396688
val_r2_var = -1.479473352432251

disnt_basal = 1.1625953577764636
disnt_after = 2.0
val_r2_mean = 0.978067954381307
val_r2_var = -1.4969589710235596

disnt_basal = 1.1318030778412949
disnt_after = 2.0
val_r2_mean = 0.978255033493042
val_r2_var = -1.4442949295043945


100%|██████████| 1/1 [00:00<00:00, 16.90it/s]
cpa_single_pert:   1%|          | 15/2184 [04:00<12:03:00, 20.00s/it]



100%|██████████| 1498/1498 [00:00<00:00, 80177.98it/s]
100%|██████████| 1498/1498 [00:00<00:00, 1086096.35it/s]

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        



Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 941.69it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.651931623931624
disnt_after = 1.8793333333333333
val_r2_mean = 0.3227066198984782
val_r2_var = -2.5638760725657144



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.6528376068376067
disnt_after = 1.8793333333333333
val_r2_mean = 0.7893996834754944
val_r2_var = -2.5662038326263428



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.6572976370035195
disnt_after = 1.8685555555555555
val_r2_mean = 0.9319630861282349
val_r2_var = -2.552162488301595



Epoch 00034: cpa_metric reached. Module best state updated.

Epoch 00039: cpa_metric reached. Module best state updated.



disnt_basal = 1.6348179989944698
disnt_after = 1.8607777777777779
val_r2_mean = 0.9374167323112488
val_r2_var = -2.5185322761535645



Epoch 00044: cpa_metric reached. Module best state updated.



disnt_basal = 1.6593760683760683
disnt_after = 1.8606666666666665
val_r2_mean = 0.9361593127250671
val_r2_var = -2.479434013366699



Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.666179487179487
disnt_after = 1.8606666666666665
val_r2_mean = 0.944792648156484
val_r2_var = -2.405449310938517



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.660854700854701
disnt_after = 1.8606666666666665
val_r2_mean = 0.940283477306366
val_r2_var = -2.3304686546325684



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.6585128205128206
disnt_after = 1.8606666666666665
val_r2_mean = 0.9354501366615295
val_r2_var = -2.2582857608795166



Epoch 00084: cpa_metric reached. Module best state updated.

Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 1.6561452991452992
disnt_after = 1.8606666666666665
val_r2_mean = 0.9435293674468994
val_r2_var = -2.208865165710449



Epoch 00094: cpa_metric reached. Module best state updated.

Epoch 00099: cpa_metric reached. Module best state updated.



disnt_basal = 1.6577094017094018
disnt_after = 1.8606666666666665
val_r2_mean = 0.9433780908584595
val_r2_var = -2.1429710388183594



Epoch 00104: cpa_metric reached. Module best state updated.

Epoch 00109: cpa_metric reached. Module best state updated.



disnt_basal = 1.642034188034188
disnt_after = 1.8606666666666665
val_r2_mean = 0.9512623945871989
val_r2_var = -2.0743475755055747



Epoch 00114: cpa_metric reached. Module best state updated.

Epoch 00119: cpa_metric reached. Module best state updated.



disnt_basal = 1.3903348416289592
disnt_after = 1.8613333333333333
val_r2_mean = 0.9593100945154825
val_r2_var = -2.0369327068328857

disnt_basal = 1.6600437405731525
disnt_after = 1.8707777777777777
val_r2_mean = 0.9534353613853455
val_r2_var = -1.8746368885040283

disnt_basal = 1.6504449472096532
disnt_after = 1.8695555555555554
val_r2_mean = 0.9566797216733297
val_r2_var = -1.789516846338908


100%|██████████| 1/1 [00:00<00:00, 24.40it/s]
cpa_single_pert:   1%|          | 16/2184 [04:16<11:15:46, 18.70s/it]



100%|██████████| 1332/1332 [00:00<00:00, 82967.95it/s]
100%|██████████| 1332/1332 [00:00<00:00, 962579.76it/s]


[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        


Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 995.09it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.7089608636977058
disnt_after = 1.738974358974359
val_r2_mean = 0.3431936502456665
val_r2_var = -2.6055612564086914



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.7152091767881241
disnt_after = 1.738974358974359
val_r2_mean = 0.7910579442977905
val_r2_var = -2.6070214907328286



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.7131708607869598
disnt_after = 1.7267909867909867
val_r2_mean = 0.9231439828872681
val_r2_var = -2.586860259373983



Epoch 00034: cpa_metric reached. Module best state updated.



disnt_basal = 1.7109441401391865
disnt_after = 1.6835275835275836
val_r2_mean = 0.9356913963953654
val_r2_var = -2.5535847345987954

disnt_basal = 1.7186898997115714
disnt_after = 1.6800466200466202
val_r2_mean = 0.9396381775538126
val_r2_var = -2.515809694925944



Epoch 00054: cpa_metric reached. Module best state updated.

Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.7189338731443993
disnt_after = 1.6825330225330224
val_r2_mean = 0.9369866649309794
val_r2_var = -2.443882942199707



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.722847503373819
disnt_after = 1.6825330225330224
val_r2_mean = 0.9469824632008871
val_r2_var = -2.4020979404449463



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.7195512880652202
disnt_after = 1.682284382284382
val_r2_mean = 0.9450995127360026
val_r2_var = -2.3253190517425537



Epoch 00084: cpa_metric reached. Module best state updated.

Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 1.5449893553299126
disnt_after = 1.6825330225330224
val_r2_mean = 0.949434240659078
val_r2_var = -2.2526838779449463

disnt_basal = 1.7187680568485524
disnt_after = 1.6865112665112665
val_r2_mean = 0.9432789087295532
val_r2_var = -2.125180800755819

disnt_basal = 1.7173810503222269
disnt_after = 1.6865112665112665
val_r2_mean = 0.9464794794718424
val_r2_var = -2.0211211840311685


100%|██████████| 1/1 [00:00<00:00, 37.96it/s]
cpa_single_pert:   1%|          | 17/2184 [04:27<9:55:55, 16.50s/it] 



100%|██████████| 1416/1416 [00:00<00:00, 80386.76it/s]
100%|██████████| 1416/1416 [00:00<00:00, 1090349.64it/s]

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        



Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 821.37it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.6482815734989646
disnt_after = 1.8166666666666664
val_r2_mean = 0.319965660572052
val_r2_var = -2.5382744471232095



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.65972049689441
disnt_after = 1.8166666666666664
val_r2_mean = 0.7833959062894186
val_r2_var = -2.5411651134490967



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.6675362318840579
disnt_after = 1.8166666666666667
val_r2_mean = 0.9245308836301168
val_r2_var = -2.530617634455363



Epoch 00034: cpa_metric reached. Module best state updated.



disnt_basal = 1.6716977225672878
disnt_after = 1.8068181818181819
val_r2_mean = 0.9300942023595175
val_r2_var = -2.5075294971466064

disnt_basal = 1.671935817805383
disnt_after = 1.7923268398268397
val_r2_mean = 0.9227483669916788
val_r2_var = -2.461564540863037



Epoch 00054: cpa_metric reached. Module best state updated.

Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.6692546583850931
disnt_after = 1.7909199134199132
val_r2_mean = 0.9284729957580566
val_r2_var = -2.376403490702311



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.663902691511387
disnt_after = 1.7907792207792206
val_r2_mean = 0.9405759572982788
val_r2_var = -2.3182316621144614



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.6682298136645963
disnt_after = 1.7910606060606058
val_r2_mean = 0.9395920038223267
val_r2_var = -2.262680689493815



Epoch 00084: cpa_metric reached. Module best state updated.

Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 1.6614906832298137
disnt_after = 1.789090909090909
val_r2_mean = 0.9456813534100851
val_r2_var = -2.185255289077759



Epoch 00094: cpa_metric reached. Module best state updated.

Epoch 00099: cpa_metric reached. Module best state updated.



disnt_basal = 1.6682401656314698
disnt_after = 1.788246753246753
val_r2_mean = 0.9545066754023234
val_r2_var = -2.133148988087972



Epoch 00104: cpa_metric reached. Module best state updated.

Epoch 00109: cpa_metric reached. Module best state updated.



disnt_basal = 1.6647361592858487
disnt_after = 1.7881060606060606
val_r2_mean = 0.9569498896598816
val_r2_var = -2.067066192626953



Epoch 00114: cpa_metric reached. Module best state updated.

Epoch 00119: cpa_metric reached. Module best state updated.



disnt_basal = 1.2710747895996342
disnt_after = 1.7888095238095234
val_r2_mean = 0.9546862045923868
val_r2_var = -2.0341062545776367

disnt_basal = 1.662870049205453
disnt_after = 1.801190476190476
val_r2_mean = 0.9500458240509033
val_r2_var = -1.8803922335306804

disnt_basal = 1.6635239036325993
disnt_after = 1.8038636363636362
val_r2_mean = 0.949347178141276
val_r2_var = -1.7839840253194172


100%|██████████| 1/1 [00:00<00:00, 33.03it/s]
cpa_single_pert:   1%|          | 18/2184 [04:42<9:38:25, 16.02s/it]



100%|██████████| 1254/1254 [00:00<00:00, 81227.72it/s]
100%|██████████| 1254/1254 [00:00<00:00, 876317.43it/s]

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        



Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 938.64it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.5881720430107529
disnt_after = 1.5881720430107529
val_r2_mean = 0.2892739772796631
val_r2_var = -2.371535380681356



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.5881720430107527
disnt_after = 1.5881720430107527
val_r2_mean = 0.751863936583201
val_r2_var = -2.374580144882202



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.5881720430107529
disnt_after = 1.5610008271298594
val_r2_mean = 0.9026029706001282
val_r2_var = -2.36378280321757

disnt_basal = 1.5881720430107527
disnt_after = 1.5267576509511995
val_r2_mean = 0.9110826055208842
val_r2_var = -2.329702297846476



Epoch 00044: cpa_metric reached. Module best state updated.

Epoch 00049: cpa_metric reached. Module best state updated.



disnt_basal = 1.5881720430107529
disnt_after = 1.5256410256410258
val_r2_mean = 0.9170685807863871
val_r2_var = -2.2748100757598877



Epoch 00054: cpa_metric reached. Module best state updated.

Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.5881720430107527
disnt_after = 1.5256410256410258
val_r2_mean = 0.91954638560613
val_r2_var = -2.213196277618408



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.5881720430107527
disnt_after = 1.5256410256410258
val_r2_mean = 0.9353243708610535
val_r2_var = -2.1421643098195395



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.3941321044546853
disnt_after = 1.5256410256410258
val_r2_mean = 0.9258395234743754
val_r2_var = -2.09499724706014

disnt_basal = 1.5881720430107527
disnt_after = 1.5371794871794873
val_r2_mean = 0.930893063545227
val_r2_var = -1.9975697994232178

disnt_basal = 1.5701139075977786
disnt_after = 1.533829611248966
val_r2_mean = 0.9375860095024109
val_r2_var = -1.9007246494293213


100%|██████████| 1/1 [00:00<00:00, 26.85it/s]
cpa_single_pert:   1%|          | 19/2184 [04:53<8:43:29, 14.51s/it]



100%|██████████| 1432/1432 [00:00<00:00, 84950.33it/s]
100%|██████████| 1432/1432 [00:00<00:00, 1071204.45it/s]


[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        


Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 875.09it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.667309655031639
disnt_after = 1.82018779342723
val_r2_mean = 0.338659663995107
val_r2_var = -2.593356450398763



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.6605531741171666
disnt_after = 1.82018779342723
val_r2_mean = 0.7958119114240011
val_r2_var = -2.5958775679270425



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.6721888265410267
disnt_after = 1.8111006781429315
val_r2_mean = 0.936845044294993
val_r2_var = -2.5798420906066895



Epoch 00034: cpa_metric reached. Module best state updated.

Epoch 00039: cpa_metric reached. Module best state updated.



disnt_basal = 1.6591770082190869
disnt_after = 1.7945539906103285
val_r2_mean = 0.9475898345311483
val_r2_var = -2.5426501433054605



Epoch 00049: cpa_metric reached. Module best state updated.



disnt_basal = 1.6732904674423352
disnt_after = 1.794960876369327
val_r2_mean = 0.9471261103947958
val_r2_var = -2.4866249561309814



Epoch 00054: cpa_metric reached. Module best state updated.

Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.6712696468667074
disnt_after = 1.797130933750652
val_r2_mean = 0.9382565220197042
val_r2_var = -2.4031283855438232



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.673229230455195
disnt_after = 1.79848721961398
val_r2_mean = 0.9463387727737427
val_r2_var = -2.354393800099691



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.6708001632986325
disnt_after = 1.7967240479916537
val_r2_mean = 0.9553853273391724
val_r2_var = -2.292421817779541



Epoch 00084: cpa_metric reached. Module best state updated.

Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 1.675913451724842
disnt_after = 1.7952321335419927
val_r2_mean = 0.9546962181727091
val_r2_var = -2.2541622320810952



Epoch 00094: cpa_metric reached. Module best state updated.

Epoch 00099: cpa_metric reached. Module best state updated.



disnt_basal = 1.6721002527703603
disnt_after = 1.79509650495566
val_r2_mean = 0.9544816811879476
val_r2_var = -2.1703023115793862



Epoch 00104: cpa_metric reached. Module best state updated.

Epoch 00109: cpa_metric reached. Module best state updated.



disnt_basal = 1.4044821447664595
disnt_after = 1.7950965049556595
val_r2_mean = 0.9613906145095825
val_r2_var = -2.1382813453674316

disnt_basal = 1.666136362689595
disnt_after = 1.8102869066249347
val_r2_mean = 0.9581615130106608
val_r2_var = -2.0396703084309897

disnt_basal = 1.6689461822696012
disnt_after = 1.809337506520605
val_r2_mean = 0.9624332785606384
val_r2_var = -1.932869831720988


100%|██████████| 1/1 [00:00<00:00, 30.25it/s]
cpa_single_pert:   1%|          | 20/2184 [05:06<8:21:15, 13.90s/it]



100%|██████████| 1426/1426 [00:00<00:00, 80482.78it/s]
100%|██████████| 1426/1426 [00:00<00:00, 1034430.56it/s]

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        



Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 875.55it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.6648136645962732
disnt_after = 1.8166666666666664
val_r2_mean = 0.35956428448359173
val_r2_var = -2.7094593048095703



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.6503312629399587
disnt_after = 1.8166666666666667
val_r2_mean = 0.804447074731191
val_r2_var = -2.710476875305176



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.6689855072463768
disnt_after = 1.7941558441558443
val_r2_mean = 0.9378485282262167
val_r2_var = -2.696113665898641



Epoch 00034: cpa_metric reached. Module best state updated.

Epoch 00039: cpa_metric reached. Module best state updated.



disnt_basal = 1.6760351966873706
disnt_after = 1.787121212121212
val_r2_mean = 0.9525734782218933
val_r2_var = -2.6709434191385903



Epoch 00044: cpa_metric reached. Module best state updated.

Epoch 00049: cpa_metric reached. Module best state updated.



disnt_basal = 1.6697619047619048
disnt_after = 1.7872619047619045
val_r2_mean = 0.9459466735521952
val_r2_var = -2.6339422861735025



Epoch 00054: cpa_metric reached. Module best state updated.

Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.670527950310559
disnt_after = 1.787121212121212
val_r2_mean = 0.9496904611587524
val_r2_var = -2.579047203063965



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.6700207039337474
disnt_after = 1.7878246753246754
val_r2_mean = 0.9503888885180155
val_r2_var = -2.5136331717173257



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.6697619047619048
disnt_after = 1.787121212121212
val_r2_mean = 0.9586944381395975
val_r2_var = -2.4814866383870444



Epoch 00084: cpa_metric reached. Module best state updated.

Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 1.67
disnt_after = 1.787121212121212
val_r2_mean = 0.9566695292790731
val_r2_var = -2.4106899897257485



Epoch 00094: cpa_metric reached. Module best state updated.

Epoch 00099: cpa_metric reached. Module best state updated.



disnt_basal = 1.631339436960555
disnt_after = 1.787121212121212
val_r2_mean = 0.9535489082336426
val_r2_var = -2.3436867396036782



Epoch 00109: cpa_metric reached. Module best state updated.



disnt_basal = 1.673290446613428
disnt_after = 1.7916233766233767
val_r2_mean = 0.9572346607844034
val_r2_var = -2.2365134557088218



Epoch 00114: cpa_metric reached. Module best state updated.

Epoch 00119: cpa_metric reached. Module best state updated.



disnt_basal = 1.6744676131322094
disnt_after = 1.7941558441558443
val_r2_mean = 0.9600429137547811
val_r2_var = -2.1268988450368247



Epoch 00124: cpa_metric reached. Module best state updated.

Epoch 00129: cpa_metric reached. Module best state updated.



disnt_basal = 1.6725598935226265
disnt_after = 1.7951406926406923
val_r2_mean = 0.9551815390586853
val_r2_var = -2.032339096069336



Epoch 00134: cpa_metric reached. Module best state updated.

Epoch 00139: cpa_metric reached. Module best state updated.



disnt_basal = 1.674390712806862
disnt_after = 1.7934523809523808
val_r2_mean = 0.9592291514078776
val_r2_var = -1.983866850535075



Epoch 00144: cpa_metric reached. Module best state updated.

Epoch 00149: cpa_metric reached. Module best state updated.



disnt_basal = 1.6595314726681185
disnt_after = 1.791060606060606
val_r2_mean = 0.956028421719869
val_r2_var = -1.9257745742797852



Epoch 00154: cpa_metric reached. Module best state updated.

Epoch 00159: cpa_metric reached. Module best state updated.



disnt_basal = 1.5771649727084511
disnt_after = 1.7900757575757575
val_r2_mean = 0.9619842966397604
val_r2_var = -1.9213543732961018



Epoch 00164: cpa_metric reached. Module best state updated.



disnt_basal = 1.3658544405904653
disnt_after = 1.787121212121212
val_r2_mean = 0.9563153386116028
val_r2_var = -1.9177482922871907

disnt_basal = 1.6595908284262322
disnt_after = 1.7878246753246754
val_r2_mean = 0.9615426858266195
val_r2_var = -1.7775553862253826

disnt_basal = 1.6460021241698262
disnt_after = 1.7883874458874458
val_r2_mean = 0.9656227032343546
val_r2_var = -1.705542802810669


100%|██████████| 1/1 [00:00<00:00, 32.80it/s]
cpa_single_pert:   1%|          | 21/2184 [05:25<9:14:32, 15.38s/it]



100%|██████████| 1304/1304 [00:00<00:00, 86273.15it/s]
100%|██████████| 1304/1304 [00:00<00:00, 965126.60it/s]

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        



Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 795.20it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.6738095238095239
disnt_after = 1.6738095238095239
val_r2_mean = 0.28327323993047077
val_r2_var = -2.2644272645314536



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.6738095238095239
disnt_after = 1.6738095238095239
val_r2_mean = 0.7563994924227396
val_r2_var = -2.265320618947347



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.6677248677248677
disnt_after = 1.6671075837742504
val_r2_mean = 0.8998583356539408
val_r2_var = -2.2511433760325112



Epoch 00034: cpa_metric reached. Module best state updated.



disnt_basal = 1.6720458553791886
disnt_after = 1.6325396825396825
val_r2_mean = 0.9104765057563782
val_r2_var = -2.22329052289327



Epoch 00049: cpa_metric reached. Module best state updated.



disnt_basal = 1.6738095238095239
disnt_after = 1.6166666666666667
val_r2_mean = 0.9187938769658407
val_r2_var = -2.1693476835886636



Epoch 00054: cpa_metric reached. Module best state updated.

Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.6738095238095239
disnt_after = 1.6166666666666667
val_r2_mean = 0.9153626362482706
val_r2_var = -2.1031951904296875



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.6559082892416226
disnt_after = 1.6166666666666667
val_r2_mean = 0.9232852856318156
val_r2_var = -2.034376939137777



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.6002204585537918
disnt_after = 1.6166666666666667
val_r2_mean = 0.935504674911499
val_r2_var = -1.9707162380218506



Epoch 00084: cpa_metric reached. Module best state updated.

Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 1.5455908289241622
disnt_after = 1.6166666666666667
val_r2_mean = 0.9318453073501587
val_r2_var = -1.9230742454528809

disnt_basal = 1.6271604938271604
disnt_after = 1.6198412698412699
val_r2_mean = 0.9265917936960856
val_r2_var = -1.7892110347747803



Epoch 00104: cpa_metric reached. Module best state updated.

Epoch 00109: cpa_metric reached. Module best state updated.



disnt_basal = 1.5653439153439155
disnt_after = 1.6180776014109348
val_r2_mean = 0.9243407050768534
val_r2_var = -1.721236228942871



Epoch 00114: cpa_metric reached. Module best state updated.

Epoch 00119: cpa_metric reached. Module best state updated.



disnt_basal = 1.2929453262786597
disnt_after = 1.6166666666666667
val_r2_mean = 0.9368294676144918
val_r2_var = -1.7012054920196533

disnt_basal = 1.4966049382716051
disnt_after = 1.6166666666666667
val_r2_mean = 0.9294304847717285
val_r2_var = -1.600153923034668

disnt_basal = 1.4807760141093476
disnt_after = 1.6166666666666667
val_r2_mean = 0.9411884943644205
val_r2_var = -1.5308455626169841


100%|██████████| 1/1 [00:00<00:00, 36.89it/s]
cpa_single_pert:   1%|          | 22/2184 [05:38<8:56:12, 14.88s/it]



100%|██████████| 1474/1474 [00:00<00:00, 88320.06it/s]
100%|██████████| 1474/1474 [00:00<00:00, 1061356.93it/s]

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        



Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 391.61it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.6981891891891892
disnt_after = 1.8585585585585587
val_r2_mean = 0.33952565987904865
val_r2_var = -2.660771131515503



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.6983423423423423
disnt_after = 1.8585585585585587
val_r2_mean = 0.7930688063303629
val_r2_var = -2.6621956825256348



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.6842068539127362
disnt_after = 1.8540446696696697
val_r2_mean = 0.9258845249811808
val_r2_var = -2.6465397675832114



Epoch 00034: cpa_metric reached. Module best state updated.

Epoch 00039: cpa_metric reached. Module best state updated.



disnt_basal = 1.67148595654478
disnt_after = 1.837331081081081
val_r2_mean = 0.9397350152333578
val_r2_var = -2.6153148810068765



Epoch 00044: cpa_metric reached. Module best state updated.



disnt_basal = 1.687216216216216
disnt_after = 1.8361111111111112
val_r2_mean = 0.9412307143211365
val_r2_var = -2.558656613032023



Epoch 00054: cpa_metric reached. Module best state updated.

Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.6877117117117117
disnt_after = 1.8361111111111112
val_r2_mean = 0.9427894552548727
val_r2_var = -2.477170467376709



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.691054054054054
disnt_after = 1.8361111111111112
val_r2_mean = 0.9460509022076925
val_r2_var = -2.4182024002075195



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.6890450450450452
disnt_after = 1.8361111111111112
val_r2_mean = 0.9538228313128153
val_r2_var = -2.3521553675333657



Epoch 00084: cpa_metric reached. Module best state updated.

Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 1.68614820703056
disnt_after = 1.8361111111111112
val_r2_mean = 0.9531999627749125
val_r2_var = -2.305747111638387



Epoch 00094: cpa_metric reached. Module best state updated.

Epoch 00099: cpa_metric reached. Module best state updated.



disnt_basal = 1.6888009185656245
disnt_after = 1.8361111111111112
val_r2_mean = 0.9598461190859476
val_r2_var = -2.236586411794027



Epoch 00104: cpa_metric reached. Module best state updated.

Epoch 00109: cpa_metric reached. Module best state updated.



disnt_basal = 1.6826364599894013
disnt_after = 1.8361111111111112
val_r2_mean = 0.9556497732798258
val_r2_var = -2.164353370666504



Epoch 00114: cpa_metric reached. Module best state updated.



disnt_basal = 1.6125738606253313
disnt_after = 1.8401370120120122
val_r2_mean = 0.9518794020016988
val_r2_var = -2.130057175954183

disnt_basal = 1.6839066198551493
disnt_after = 1.8425769519519521
val_r2_mean = 0.9516083796819051
val_r2_var = -2.018190622329712

disnt_basal = 1.6844880763116057
disnt_after = 1.842698948948949
val_r2_mean = 0.9521786570549011
val_r2_var = -1.9436016082763672


100%|██████████| 1/1 [00:00<00:00, 26.06it/s]
cpa_single_pert:   1%|          | 23/2184 [05:52<8:40:47, 14.46s/it]



100%|██████████| 1306/1306 [00:00<00:00, 87107.59it/s]
100%|██████████| 1306/1306 [00:00<00:00, 1079149.14it/s]

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        



Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 903.85it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.6984375
disnt_after = 1.6984375
val_r2_mean = 0.3133279085159302
val_r2_var = -2.323213736216227



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.6984375
disnt_after = 1.6984375
val_r2_mean = 0.7703620394070944
val_r2_var = -2.32456111907959



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.6984375
disnt_after = 1.6852370689655172
val_r2_mean = 0.9140551686286926
val_r2_var = -2.3101905981699624

disnt_basal = 1.6984375
disnt_after = 1.636835488505747
val_r2_mean = 0.9278884331385294
val_r2_var = -2.2752816677093506

disnt_basal = 1.6984375
disnt_after = 1.6365211925287357
val_r2_mean = 0.9239024519920349
val_r2_var = -2.237175703048706



Epoch 00054: cpa_metric reached. Module best state updated.

Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.6984375
disnt_after = 1.636206896551724
val_r2_mean = 0.9324952363967896
val_r2_var = -2.15458353360494



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.6984375
disnt_after = 1.636206896551724
val_r2_mean = 0.9325738151868185
val_r2_var = -2.1016484101613364



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.6984375
disnt_after = 1.636206896551724
val_r2_mean = 0.9421114126841227
val_r2_var = -2.0377093156178794



Epoch 00084: cpa_metric reached. Module best state updated.

Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 1.529302720396877
disnt_after = 1.636206896551724
val_r2_mean = 0.9375332991282145
val_r2_var = -2.0073235034942627

disnt_basal = 1.6984374999999998
disnt_after = 1.6572647270114942
val_r2_mean = 0.9250019192695618
val_r2_var = -1.9286341667175293

disnt_basal = 1.6984374999999998
disnt_after = 1.651293103448276
val_r2_mean = 0.939553459485372
val_r2_var = -1.8176924387613933


100%|██████████| 1/1 [00:00<00:00, 25.50it/s]
cpa_single_pert:   1%|          | 24/2184 [06:03<8:05:35, 13.49s/it]



100%|██████████| 1310/1310 [00:00<00:00, 84504.06it/s]
100%|██████████| 1310/1310 [00:00<00:00, 1060107.71it/s]

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        



Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 670.23it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.6857142857142855
disnt_after = 1.6857142857142855
val_r2_mean = 0.3272705078125
val_r2_var = -2.4037667910257974



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.6857142857142857
disnt_after = 1.6857142857142857
val_r2_mean = 0.7770458261171976
val_r2_var = -2.4066304365793862



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.6845063392233204
disnt_after = 1.6833994708994708
val_r2_mean = 0.9151291251182556
val_r2_var = -2.3903070290883384



Epoch 00034: cpa_metric reached. Module best state updated.



disnt_basal = 1.6857142857142857
disnt_after = 1.6546296296296297
val_r2_mean = 0.9284403324127197
val_r2_var = -2.354764779408773

disnt_basal = 1.6857142857142857
disnt_after = 1.6390873015873015
val_r2_mean = 0.9255573352177938
val_r2_var = -2.3094331423441568



Epoch 00054: cpa_metric reached. Module best state updated.

Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.6857142857142857
disnt_after = 1.6304894179894178
val_r2_mean = 0.9297114014625549
val_r2_var = -2.2373456160227456



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.6857142857142857
disnt_after = 1.6261904761904762
val_r2_mean = 0.9326857328414917
val_r2_var = -2.1710216999053955



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.6242063492063492
disnt_after = 1.626521164021164
val_r2_mean = 0.9339876969655355
val_r2_var = -2.1278932094573975



Epoch 00084: cpa_metric reached. Module best state updated.



disnt_basal = 1.682091694119996
disnt_after = 1.6321428571428571
val_r2_mean = 0.9409315586090088
val_r2_var = -2.0955821673075357



Epoch 00099: cpa_metric reached. Module best state updated.



disnt_basal = 1.6754429969052609
disnt_after = 1.6367724867724869
val_r2_mean = 0.936696986357371
val_r2_var = -1.9993003209431965



Epoch 00104: cpa_metric reached. Module best state updated.

Epoch 00109: cpa_metric reached. Module best state updated.



disnt_basal = 1.618810522112409
disnt_after = 1.636441798941799
val_r2_mean = 0.9341301918029785
val_r2_var = -1.9168124198913574



Epoch 00114: cpa_metric reached. Module best state updated.

Epoch 00119: cpa_metric reached. Module best state updated.



disnt_basal = 1.3626846860337427
disnt_after = 1.6318121693121692
val_r2_mean = 0.9283532698949178
val_r2_var = -1.8602386315663655

disnt_basal = 1.67049765398822
disnt_after = 1.630820105820106
val_r2_mean = 0.9298275113105774
val_r2_var = -1.770856459935506

disnt_basal = 1.625545322950983
disnt_after = 1.6304894179894178
val_r2_mean = 0.9323822259902954
val_r2_var = -1.6841583251953125


100%|██████████| 1/1 [00:00<00:00, 29.02it/s]
cpa_single_pert:   1%|          | 25/2184 [06:16<7:59:34, 13.33s/it]



100%|██████████| 1494/1494 [00:00<00:00, 88947.89it/s]
100%|██████████| 1494/1494 [00:00<00:00, 1066965.81it/s]

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        



Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 815.77it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.6662222222222223
disnt_after = 1.8626666666666667
val_r2_mean = 0.3108977675437927
val_r2_var = -2.4727625846862793



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.6622222222222223
disnt_after = 1.8626666666666667
val_r2_mean = 0.7808197538057963
val_r2_var = -2.4735941092173257



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.674
disnt_after = 1.8575963718820863
val_r2_mean = 0.9254482984542847
val_r2_var = -2.4611124992370605



Epoch 00034: cpa_metric reached. Module best state updated.



disnt_basal = 1.6817777777777778
disnt_after = 1.8445079365079367
val_r2_mean = 0.9315300583839417
val_r2_var = -2.4350082874298096



Epoch 00044: cpa_metric reached. Module best state updated.

Epoch 00049: cpa_metric reached. Module best state updated.



disnt_basal = 1.6846666666666668
disnt_after = 1.842857142857143
val_r2_mean = 0.9359542727470398
val_r2_var = -2.3663503328959146



Epoch 00054: cpa_metric reached. Module best state updated.

Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.6857777777777778
disnt_after = 1.842857142857143
val_r2_mean = 0.9395005504290262
val_r2_var = -2.3079883257548013



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.6800000000000002
disnt_after = 1.842857142857143
val_r2_mean = 0.9442593653996786
val_r2_var = -2.2362000942230225



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.6735555555555557
disnt_after = 1.842857142857143
val_r2_mean = 0.9521293242772421
val_r2_var = -2.1712427934010825



Epoch 00084: cpa_metric reached. Module best state updated.

Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 1.6671111111111112
disnt_after = 1.842857142857143
val_r2_mean = 0.9517532189687093
val_r2_var = -2.0952123006184897



Epoch 00094: cpa_metric reached. Module best state updated.

Epoch 00099: cpa_metric reached. Module best state updated.



disnt_basal = 1.6651111111111112
disnt_after = 1.842857142857143
val_r2_mean = 0.9557366172472636
val_r2_var = -2.008717695871989



Epoch 00104: cpa_metric reached. Module best state updated.

Epoch 00109: cpa_metric reached. Module best state updated.



disnt_basal = 1.6590292116846739
disnt_after = 1.842857142857143
val_r2_mean = 0.9557372530301412
val_r2_var = -1.9297052224477131



Epoch 00114: cpa_metric reached. Module best state updated.



disnt_basal = 1.6210140056022408
disnt_after = 1.8472199546485262
val_r2_mean = 0.9575667778650919
val_r2_var = -1.9465992450714111

disnt_basal = 1.6433842870481525
disnt_after = 1.8487528344671205
val_r2_mean = 0.9480304718017578
val_r2_var = -1.7428747812906902



Epoch 00139: cpa_metric reached. Module best state updated.



disnt_basal = 1.6545092703748168
disnt_after = 1.852172335600907
val_r2_mean = 0.948683520158132
val_r2_var = -1.6722920735677083



Epoch 00144: cpa_metric reached. Module best state updated.

Epoch 00149: cpa_metric reached. Module best state updated.



disnt_basal = 1.6395913031879419
disnt_after = 1.8508752834467121
val_r2_mean = 0.9565751949946085
val_r2_var = -1.594262997309367



Epoch 00154: cpa_metric reached. Module best state updated.

Epoch 00159: cpa_metric reached. Module best state updated.



disnt_basal = 1.5996529278378016
disnt_after = 1.848281179138322
val_r2_mean = 0.9539555311203003
val_r2_var = -1.5549581050872803



Epoch 00164: cpa_metric reached. Module best state updated.

Epoch 00169: cpa_metric reached. Module best state updated.



disnt_basal = 1.3411159130318797
disnt_after = 1.8435646258503402
val_r2_mean = 0.9557815591494242
val_r2_var = -1.564677635828654

disnt_basal = 1.5801766039749234
disnt_after = 1.842857142857143
val_r2_mean = 0.9600072503089905
val_r2_var = -1.4943254788716633

disnt_basal = 1.6109795918367347
disnt_after = 1.8429750566893426
val_r2_mean = 0.9635943571726481
val_r2_var = -1.4319743315378826


100%|██████████| 1/1 [00:00<00:00, 23.85it/s]
cpa_single_pert:   1%|          | 26/2184 [06:35<9:02:52, 15.09s/it]



100%|██████████| 1812/1812 [00:00<00:00, 84854.51it/s]
100%|██████████| 1812/1812 [00:00<00:00, 1055273.38it/s]

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        



Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 380.40it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.654510921177588
disnt_after = 2.0
val_r2_mean = 0.5896567900975546
val_r2_var = -2.871702194213867



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.6391670058336727
disnt_after = 2.0
val_r2_mean = 0.9511598149935404
val_r2_var = -2.8668487866719565



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.64010989010989
disnt_after = 2.0
val_r2_mean = 0.9614291985829672
val_r2_var = -2.825930118560791



Epoch 00034: cpa_metric reached. Module best state updated.

Epoch 00039: cpa_metric reached. Module best state updated.



disnt_basal = 1.6245183828517162
disnt_after = 2.0
val_r2_mean = 0.9639302690823873
val_r2_var = -2.76961350440979



Epoch 00044: cpa_metric reached. Module best state updated.

Epoch 00049: cpa_metric reached. Module best state updated.



disnt_basal = 1.5205799755799756
disnt_after = 2.0
val_r2_mean = 0.9572197596232096
val_r2_var = -2.713587919871012

disnt_basal = 1.643050468050468
disnt_after = 2.0
val_r2_mean = 0.9629099369049072
val_r2_var = -2.5617249806722007

disnt_basal = 1.6420451770451772
disnt_after = 2.0
val_r2_mean = 0.9635969599088033
val_r2_var = -2.4664461612701416



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.6046275946275945
disnt_after = 2.0
val_r2_mean = 0.9645687937736511
val_r2_var = -2.378301461537679



Epoch 00084: cpa_metric reached. Module best state updated.

Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 1.6318376068376068
disnt_after = 2.0
val_r2_mean = 0.9654983878135681
val_r2_var = -2.1921419302622476



Epoch 00094: cpa_metric reached. Module best state updated.



disnt_basal = 1.6247815764482432
disnt_after = 2.0
val_r2_mean = 0.9651546676953634
val_r2_var = -2.151333491007487

disnt_basal = 1.5643121693121693
disnt_after = 2.0
val_r2_mean = 0.9655494491259257
val_r2_var = -2.03354819615682



Epoch 00114: cpa_metric reached. Module best state updated.

Epoch 00119: cpa_metric reached. Module best state updated.



disnt_basal = 1.3228150861484194
disnt_after = 2.0
val_r2_mean = 0.9587020675341288
val_r2_var = -1.9087332089742024

disnt_basal = 1.4326604259937594
disnt_after = 2.0
val_r2_mean = 0.9618972539901733
val_r2_var = -1.8370265165964763



Epoch 00134: cpa_metric reached. Module best state updated.



disnt_basal = 1.3606735856735854
disnt_after = 2.0
val_r2_mean = 0.9709634780883789
val_r2_var = -1.8284138043721516

disnt_basal = 1.2588834622167955
disnt_after = 2.0
val_r2_mean = 0.964774509270986
val_r2_var = -1.7384210427602131



Epoch 00154: cpa_metric reached. Module best state updated.
100%|██████████| 1/1 [00:00<00:00, 21.92it/s]
cpa_single_pert:   1%|          | 27/2184 [06:56<10:00:19, 16.70s/it]



100%|██████████| 1260/1260 [00:00<00:00, 51798.79it/s]
100%|██████████| 1260/1260 [00:00<00:00, 661926.73it/s]

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        



Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 714.41it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.6049180327868853
disnt_after = 1.6049180327868853
val_r2_mean = 0.29253751039505005
val_r2_var = -2.3411072889963784



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.6049180327868853
disnt_after = 1.6049180327868853
val_r2_mean = 0.7563416163126627
val_r2_var = -2.341919501622518



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.595672131147541
disnt_after = 1.603446826397646
val_r2_mean = 0.9027062853177389
val_r2_var = -2.3248708248138428

disnt_basal = 1.6001584699453553
disnt_after = 1.5475409836065575
val_r2_mean = 0.91940704981486
val_r2_var = -2.2833391030629477



Epoch 00044: cpa_metric reached. Module best state updated.



disnt_basal = 1.6049180327868853
disnt_after = 1.5442307692307693
val_r2_mean = 0.9157058000564575
val_r2_var = -2.2255709966023765



Epoch 00054: cpa_metric reached. Module best state updated.

Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.6049180327868853
disnt_after = 1.5442307692307693
val_r2_mean = 0.9247613747914633
val_r2_var = -2.1619760990142822



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.6049180327868853
disnt_after = 1.5442307692307693
val_r2_mean = 0.9318892359733582
val_r2_var = -2.092792828877767



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.5336757741347906
disnt_after = 1.5442307692307693
val_r2_mean = 0.936679482460022
val_r2_var = -2.0496952533721924



Epoch 00084: cpa_metric reached. Module best state updated.



disnt_basal = 1.5915180047639064
disnt_after = 1.553793610760824
val_r2_mean = 0.9243518908818563
val_r2_var = -1.9872064590454102



Epoch 00099: cpa_metric reached. Module best state updated.



disnt_basal = 1.5470403530895336
disnt_after = 1.5508511979823454
val_r2_mean = 0.9328565796216329
val_r2_var = -1.901870886484782



Epoch 00104: cpa_metric reached. Module best state updated.

Epoch 00109: cpa_metric reached. Module best state updated.



disnt_basal = 1.29949418523189
disnt_after = 1.5457019756200086
val_r2_mean = 0.9242237408955892
val_r2_var = -1.8692196210225422

disnt_basal = 1.4529234972677596
disnt_after = 1.5442307692307693
val_r2_mean = 0.9355891744295756
val_r2_var = -1.8121793270111084



Epoch 00129: cpa_metric reached. Module best state updated.



disnt_basal = 1.3616387838027182
disnt_after = 1.5442307692307693
val_r2_mean = 0.9340313076972961
val_r2_var = -1.713729699452718



Epoch 00134: cpa_metric reached. Module best state updated.

Epoch 00139: cpa_metric reached. Module best state updated.



disnt_basal = 1.1614681238615665
disnt_after = 1.5442307692307693
val_r2_mean = 0.920773983001709
val_r2_var = -1.5852205753326416



Epoch 00144: cpa_metric reached. Module best state updated.



disnt_basal = 1.2103913408995375
disnt_after = 1.5445985708280792
val_r2_mean = 0.934138298034668
val_r2_var = -1.5573915640513103



Epoch 00159: cpa_metric reached. Module best state updated.



disnt_basal = 1.1340718789407314
disnt_after = 1.5445985708280792
val_r2_mean = 0.9312589168548584
val_r2_var = -1.4714175860087078



Epoch 00164: cpa_metric reached. Module best state updated.



disnt_basal = 1.1261658960347485
disnt_after = 1.5445985708280792
val_r2_mean = 0.928454061349233
val_r2_var = -1.4497816562652588

disnt_basal = 1.2082072299285413
disnt_after = 1.5442307692307693
val_r2_mean = 0.9287031888961792
val_r2_var = -1.3975172837575276



Epoch 00184: cpa_metric reached. Module best state updated.

Epoch 00189: cpa_metric reached. Module best state updated.



disnt_basal = 1.0957067395264117
disnt_after = 1.5442307692307693
val_r2_mean = 0.9356321096420288
val_r2_var = -1.3430599371592205



Epoch 00194: cpa_metric reached. Module best state updated.



disnt_basal = 1.1000218579234975
disnt_after = 1.5449663724253888
val_r2_mean = 0.9380439718564352
val_r2_var = -1.3418233394622803

disnt_basal = 1.124304189435337
disnt_after = 1.5449663724253888
val_r2_mean = 0.9398115674654642
val_r2_var = -1.2977380752563477



Epoch 00219: cpa_metric reached. Module best state updated.



disnt_basal = 1.0924445845593387
disnt_after = 1.5445985708280792
val_r2_mean = 0.9355281790097555
val_r2_var = -1.270508050918579



Epoch 00224: cpa_metric reached. Module best state updated.

Epoch 00229: cpa_metric reached. Module best state updated.



disnt_basal = 1.0816902059688944
disnt_after = 1.5445985708280792
val_r2_mean = 0.9416446487108866
val_r2_var = -1.2501373291015625

disnt_basal = 1.1293714445845593
disnt_after = 1.5442307692307693
val_r2_mean = 0.9422548214594523
val_r2_var = -1.2187918821970622



Epoch 00249: cpa_metric reached. Module best state updated.



disnt_basal = 1.096590724394003
disnt_after = 1.5442307692307693
val_r2_mean = 0.9439952969551086
val_r2_var = -1.1956807772318523



Epoch 00254: cpa_metric reached. Module best state updated.

Epoch 00259: cpa_metric reached. Module best state updated.



disnt_basal = 1.0582108729157909
disnt_after = 1.5445985708280792
val_r2_mean = 0.9433347781499227
val_r2_var = -1.1928545633951824

disnt_basal = 1.0663208631077485
disnt_after = 1.5442307692307693
val_r2_mean = 0.944778839747111
val_r2_var = -1.1876786549886067



Epoch 00279: cpa_metric reached. Module best state updated.



disnt_basal = 1.0697382653776095
disnt_after = 1.5442307692307693
val_r2_mean = 0.9387099742889404
val_r2_var = -1.1366674900054932



Epoch 00284: cpa_metric reached. Module best state updated.

Epoch 00289: cpa_metric reached. Module best state updated.



disnt_basal = 1.054915931063472
disnt_after = 1.5442307692307693
val_r2_mean = 0.9412481586138407
val_r2_var = -1.1452419757843018

disnt_basal = 1.069421185372005
disnt_after = 1.5442307692307693
val_r2_mean = 0.9379874467849731
val_r2_var = -1.1234926382700603



Epoch 00309: cpa_metric reached. Module best state updated.



disnt_basal = 1.0594071738825837
disnt_after = 1.5442307692307693
val_r2_mean = 0.9465424418449402
val_r2_var = -1.1323231061299641



Epoch 00314: cpa_metric reached. Module best state updated.

Epoch 00319: cpa_metric reached. Module best state updated.



disnt_basal = 1.0532344122180186
disnt_after = 1.5442307692307693
val_r2_mean = 0.9410990476608276
val_r2_var = -1.068858027458191

disnt_basal = 1.0507770772033067
disnt_after = 1.5442307692307693
val_r2_mean = 0.9427252213160197
val_r2_var = -1.0923302173614502

disnt_basal = 1.052321843912008
disnt_after = 1.5442307692307693
val_r2_mean = 0.9440907041231791
val_r2_var = -1.094449758529663


100%|██████████| 1/1 [00:00<00:00, 22.94it/s]
cpa_single_pert:   1%|▏         | 28/2184 [07:28<12:47:20, 21.35s/it]



100%|██████████| 1726/1726 [00:00<00:00, 83022.19it/s]
100%|██████████| 1726/1726 [00:00<00:00, 1058077.86it/s]

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        



Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 814.11it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.6877375730994153
disnt_after = 1.993421052631579
val_r2_mean = 0.5917532642682394
val_r2_var = -2.6800769170125327



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.692364289346555
disnt_after = 1.993421052631579
val_r2_mean = 0.9423079689343771
val_r2_var = -2.6731647650400796



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.6835001906941267
disnt_after = 1.9935728744939272
val_r2_mean = 0.9614123106002808
val_r2_var = -2.631368557612101



Epoch 00034: cpa_metric reached. Module best state updated.

Epoch 00039: cpa_metric reached. Module best state updated.



disnt_basal = 1.6423491927281972
disnt_after = 1.9935897435897436
val_r2_mean = 0.9653044740358988
val_r2_var = -2.579827388127645



Epoch 00044: cpa_metric reached. Module best state updated.



disnt_basal = 1.6745431286549706
disnt_after = 1.993550382366172
val_r2_mean = 0.9608315229415894
val_r2_var = -2.4601271947224936



Epoch 00054: cpa_metric reached. Module best state updated.

Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.6637426900584795
disnt_after = 1.993550382366172
val_r2_mean = 0.96456378698349
val_r2_var = -2.3517483870188394



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.664855072463768
disnt_after = 1.9935278902384166
val_r2_mean = 0.9646859367688497
val_r2_var = -2.265652656555176



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.5931874523264686
disnt_after = 1.993544759334233
val_r2_mean = 0.9531797369321188
val_r2_var = -2.2538839181264243

disnt_basal = 1.6884717454869058
disnt_after = 1.9935728744939272
val_r2_mean = 0.9617907206217448
val_r2_var = -2.1317054430643716



Epoch 00099: cpa_metric reached. Module best state updated.



disnt_basal = 1.6710902814449724
disnt_after = 1.9935841205578049
val_r2_mean = 0.9690947731335958
val_r2_var = -2.030385732650757



Epoch 00104: cpa_metric reached. Module best state updated.



disnt_basal = 1.6654835685227563
disnt_after = 1.9934997750787224
val_r2_mean = 0.9580827355384827
val_r2_var = -1.903363545735677

disnt_basal = 1.6686250708990982
disnt_after = 1.9934941520467837
val_r2_mean = 0.9605032801628113
val_r2_var = -1.8293663660685222



Epoch 00129: cpa_metric reached. Module best state updated.



disnt_basal = 1.4764888199456276
disnt_after = 1.9935166441745389
val_r2_mean = 0.968933622042338
val_r2_var = -1.8162755966186523

disnt_basal = 1.6443869917268088
disnt_after = 1.9935672514619884
val_r2_mean = 0.9543606241544088
val_r2_var = -1.7310787041982014



Epoch 00149: cpa_metric reached. Module best state updated.



disnt_basal = 1.300550629290618
disnt_after = 1.9935897435897436
val_r2_mean = 0.9679493308067322
val_r2_var = -1.7111637592315674

disnt_basal = 1.607475637602926
disnt_after = 1.9935166441745389
val_r2_mean = 0.9657252232233683
val_r2_var = -1.6366580327351887



Epoch 00169: cpa_metric reached. Module best state updated.



disnt_basal = 1.268466586966301
disnt_after = 1.9935616284300495
val_r2_mean = 0.9648369550704956
val_r2_var = -1.5841223398844402

disnt_basal = 1.4544108407166187
disnt_after = 1.9935728744939272
val_r2_mean = 0.9750189582506815
val_r2_var = -1.6089510917663574



Epoch 00189: cpa_metric reached. Module best state updated.



disnt_basal = 1.2519793683623774
disnt_after = 1.993578497525866
val_r2_mean = 0.9714750448862711
val_r2_var = -1.562255859375



Epoch 00194: cpa_metric reached. Module best state updated.



disnt_basal = 1.3639032520682977
disnt_after = 1.9935616284300497
val_r2_mean = 0.9702209234237671
val_r2_var = -1.538427432378133

disnt_basal = 1.271234707797923
disnt_after = 1.993578497525866
val_r2_mean = 0.9742078383763632
val_r2_var = -1.554869016011556



Epoch 00214: cpa_metric reached. Module best state updated.



disnt_basal = 1.2410917727708346
disnt_after = 1.993578497525866
val_r2_mean = 0.9714258511861166
val_r2_var = -1.4788129329681396

disnt_basal = 1.2322052431105635
disnt_after = 1.993578497525866
val_r2_mean = 0.9743609627087911
val_r2_var = -1.5345291296641033

disnt_basal = 1.1977017079348313
disnt_after = 1.9935728744939272
val_r2_mean = 0.9753010869026184
val_r2_var = -1.4767871697743733


100%|██████████| 1/1 [00:00<00:00, 19.92it/s]
cpa_single_pert:   1%|▏         | 29/2184 [08:00<14:47:31, 24.71s/it]



100%|██████████| 1546/1546 [00:00<00:00, 89037.10it/s]
100%|██████████| 1546/1546 [00:00<00:00, 1069326.18it/s]

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        



Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 940.85it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.640095536647261
disnt_after = 1.9484848484848487
val_r2_mean = 0.3480852246284485
val_r2_var = -2.7934927940368652



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.6455067920585162
disnt_after = 1.9484848484848487
val_r2_mean = 0.8021127382914225
val_r2_var = -2.7962719599405923



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.6479762437890515
disnt_after = 1.9483329535961116
val_r2_mean = 0.9334076642990112
val_r2_var = -2.778345743815104



Epoch 00034: cpa_metric reached. Module best state updated.



disnt_basal = 1.6534823960932334
disnt_after = 1.9468140047087419
val_r2_mean = 0.9392191966374716
val_r2_var = -2.736867904663086



Epoch 00044: cpa_metric reached. Module best state updated.



disnt_basal = 1.6804597701149424
disnt_after = 1.9429406850459483
val_r2_mean = 0.9348637660344442
val_r2_var = -2.6399147510528564



Epoch 00054: cpa_metric reached. Module best state updated.

Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.6685251530079115
disnt_after = 1.9421052631578948
val_r2_mean = 0.9603113929430643
val_r2_var = -2.5793863932291665



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.6676742797432453
disnt_after = 1.9421052631578948
val_r2_mean = 0.9615769783655802
val_r2_var = -2.518244902292887



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.662718316166592
disnt_after = 1.942257158046632
val_r2_mean = 0.9653994838396708
val_r2_var = -2.4394630591074624



Epoch 00084: cpa_metric reached. Module best state updated.

Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 1.6612845369132638
disnt_after = 1.9421052631578948
val_r2_mean = 0.9674908518791199
val_r2_var = -2.3870636622111



Epoch 00094: cpa_metric reached. Module best state updated.

Epoch 00099: cpa_metric reached. Module best state updated.



disnt_basal = 1.6503366753950108
disnt_after = 1.9421052631578948
val_r2_mean = 0.9618961016337076
val_r2_var = -2.2656455039978027



Epoch 00104: cpa_metric reached. Module best state updated.



disnt_basal = 1.6093860378667224
disnt_after = 1.9430925799346852
val_r2_mean = 0.9597633679707845
val_r2_var = -2.2603464126586914

disnt_basal = 1.6589526323646107
disnt_after = 1.9441558441558442
val_r2_mean = 0.9624156951904297
val_r2_var = -2.082161029179891



Epoch 00124: cpa_metric reached. Module best state updated.

Epoch 00129: cpa_metric reached. Module best state updated.



disnt_basal = 1.6603334354695516
disnt_after = 1.9450672134882663
val_r2_mean = 0.9582409063975016
val_r2_var = -1.9218391577402751



Epoch 00134: cpa_metric reached. Module best state updated.

Epoch 00139: cpa_metric reached. Module best state updated.



disnt_basal = 1.6558863197701674
disnt_after = 1.9454469507101089
val_r2_mean = 0.969336748123169
val_r2_var = -1.884665886561076



Epoch 00144: cpa_metric reached. Module best state updated.

Epoch 00149: cpa_metric reached. Module best state updated.



disnt_basal = 1.653925566266764
disnt_after = 1.9449912660438977
val_r2_mean = 0.959964652856191
val_r2_var = -1.796756903330485



Epoch 00159: cpa_metric reached. Module best state updated.



disnt_basal = 1.648480864049959
disnt_after = 1.9445355813776866
val_r2_mean = 0.9636337955792745
val_r2_var = -1.7943561871846516



Epoch 00164: cpa_metric reached. Module best state updated.

Epoch 00169: cpa_metric reached. Module best state updated.



disnt_basal = 1.5593962889010853
disnt_after = 1.9430166324903166
val_r2_mean = 0.9665084083875021
val_r2_var = -1.724589506785075



Epoch 00174: cpa_metric reached. Module best state updated.



disnt_basal = 1.6047993360023436
disnt_after = 1.9425609478241057
val_r2_mean = 0.9646068016688029
val_r2_var = -1.6739269892374675

disnt_basal = 1.6438840795076213
disnt_after = 1.9424090529353688
val_r2_mean = 0.9680962761243185
val_r2_var = -1.6517505645751953

disnt_basal = 1.6398004078714474
disnt_after = 1.9424850003797371
val_r2_mean = 0.9658639430999756
val_r2_var = -1.5803298155466716


100%|██████████| 1/1 [00:00<00:00, 23.74it/s]
cpa_single_pert:   1%|▏         | 30/2184 [08:19<13:36:34, 22.75s/it]



100%|██████████| 1996/1996 [00:00<00:00, 86914.14it/s]
100%|██████████| 1996/1996 [00:00<00:00, 1096076.30it/s]

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        



Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 830.14it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.6395708020050126
disnt_after = 2.0
val_r2_mean = 0.585692286491394
val_r2_var = -2.804046869277954



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.6385432330827068
disnt_after = 2.0
val_r2_mean = 0.9525553782780966
val_r2_var = -2.800686518351237



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.6290717866093807
disnt_after = 2.0
val_r2_mean = 0.9691138863563538
val_r2_var = -2.7605907122294107



Epoch 00034: cpa_metric reached. Module best state updated.

Epoch 00039: cpa_metric reached. Module best state updated.



disnt_basal = 1.6192866093805942
disnt_after = 2.0
val_r2_mean = 0.9648238817850748
val_r2_var = -2.713593085606893



Epoch 00044: cpa_metric reached. Module best state updated.



disnt_basal = 1.6244587396362054
disnt_after = 2.0
val_r2_mean = 0.9666974743207296
val_r2_var = -2.658498764038086



Epoch 00054: cpa_metric reached. Module best state updated.

Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.6450358037952024
disnt_after = 2.0
val_r2_mean = 0.9661541779836019
val_r2_var = -2.4738582770029702



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.6254949611423517
disnt_after = 2.0
val_r2_mean = 0.9669811526934305
val_r2_var = -2.339651584625244



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.598015486896513
disnt_after = 2.0
val_r2_mean = 0.9728502631187439
val_r2_var = -2.233444611231486



Epoch 00084: cpa_metric reached. Module best state updated.



disnt_basal = 1.6188865809481687
disnt_after = 2.0
val_r2_mean = 0.9656796058019003
val_r2_var = -2.0714929898579917



Epoch 00094: cpa_metric reached. Module best state updated.

Epoch 00099: cpa_metric reached. Module best state updated.



disnt_basal = 1.2687328176884791
disnt_after = 2.0
val_r2_mean = 0.9690450231234232
val_r2_var = -1.9917240937550862

disnt_basal = 1.6057528186713284
disnt_after = 2.0
val_r2_mean = 0.9729076623916626
val_r2_var = -1.879092852274577

disnt_basal = 1.3110905853569497
disnt_after = 2.0
val_r2_mean = 0.9744779070218405
val_r2_var = -1.844575564066569


100%|██████████| 1/1 [00:00<00:00, 14.33it/s]
cpa_single_pert:   1%|▏         | 31/2184 [08:35<12:28:15, 20.85s/it]



100%|██████████| 1522/1522 [00:00<00:00, 82477.14it/s]
100%|██████████| 1522/1522 [00:00<00:00, 758613.27it/s]

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        



Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 893.07it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.6636904761904763
disnt_after = 1.9171052631578946
val_r2_mean = 0.3400599956512451
val_r2_var = -2.6270230611165366



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.6773496240601504
disnt_after = 1.9171052631578944
val_r2_mean = 0.7961304783821106
val_r2_var = -2.628861745198568



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.6537268416138384
disnt_after = 1.9124420721615358
val_r2_mean = 0.9317131837209066
val_r2_var = -2.6135186354319253



Epoch 00034: cpa_metric reached. Module best state updated.

Epoch 00039: cpa_metric reached. Module best state updated.



disnt_basal = 1.6437570642292005
disnt_after = 1.903401191658391
val_r2_mean = 0.94233771165212
val_r2_var = -2.5813022454579673



Epoch 00049: cpa_metric reached. Module best state updated.



disnt_basal = 1.6611842105263157
disnt_after = 1.9037818603111551
val_r2_mean = 0.9461734096209208
val_r2_var = -2.538571357727051



Epoch 00054: cpa_metric reached. Module best state updated.

Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.6607142857142856
disnt_after = 1.9029253558424362
val_r2_mean = 0.9484102129936218
val_r2_var = -2.460632642110189



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.6645989974937343
disnt_after = 1.9029253558424362
val_r2_mean = 0.9495706160863241
val_r2_var = -2.390897274017334



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.6670426065162907
disnt_after = 1.9028301886792451
val_r2_mean = 0.9496411085128784
val_r2_var = -2.321360429128011



Epoch 00084: cpa_metric reached. Module best state updated.

Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 1.6651942355889724
disnt_after = 1.9028301886792451
val_r2_mean = 0.9537545045216879
val_r2_var = -2.2711332639058432



Epoch 00094: cpa_metric reached. Module best state updated.

Epoch 00099: cpa_metric reached. Module best state updated.



disnt_basal = 1.6639724310776942
disnt_after = 1.9028301886792451
val_r2_mean = 0.9529598752657572
val_r2_var = -2.207559665044149



Epoch 00104: cpa_metric reached. Module best state updated.

Epoch 00109: cpa_metric reached. Module best state updated.



disnt_basal = 1.6420864406669655
disnt_after = 1.9028301886792451
val_r2_mean = 0.9533632596333822
val_r2_var = -2.142082691192627



Epoch 00114: cpa_metric reached. Module best state updated.



disnt_basal = 1.651244764704955
disnt_after = 1.90616103939093
val_r2_mean = 0.9554365873336792
val_r2_var = -2.0811290740966797

disnt_basal = 1.6580767494021773
disnt_after = 1.905114200595829
val_r2_mean = 0.946493407090505
val_r2_var = -1.979207436243693

disnt_basal = 1.657872147534949
disnt_after = 1.904352863290301
val_r2_mean = 0.9491136272748312
val_r2_var = -1.8928772608439128


100%|██████████| 1/1 [00:00<00:00, 24.38it/s]
cpa_single_pert:   1%|▏         | 32/2184 [08:47<10:53:48, 18.23s/it]



100%|██████████| 1346/1346 [00:00<00:00, 84799.60it/s]
100%|██████████| 1346/1346 [00:00<00:00, 1066603.66it/s]

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        



Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 889.94it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.710699096225412
disnt_after = 1.7398989898989898
val_r2_mean = 0.30897001425425213
val_r2_var = -2.334444046020508



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.7119484316852738
disnt_after = 1.7398989898989898
val_r2_mean = 0.7719915707906088
val_r2_var = -2.3338491916656494



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.7075348627980207
disnt_after = 1.712878787878788
val_r2_mean = 0.9169334173202515
val_r2_var = -2.3150511582692466



Epoch 00034: cpa_metric reached. Module best state updated.



disnt_basal = 1.7061500633869056
disnt_after = 1.680808080808081
val_r2_mean = 0.928422192732493
val_r2_var = -2.277638832728068



Epoch 00044: cpa_metric reached. Module best state updated.

Epoch 00049: cpa_metric reached. Module best state updated.



disnt_basal = 1.7084396597554492
disnt_after = 1.6797979797979798
val_r2_mean = 0.9219240744908651
val_r2_var = -2.2195531527201333



Epoch 00054: cpa_metric reached. Module best state updated.

Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.7084396597554492
disnt_after = 1.6797979797979798
val_r2_mean = 0.9305001894632975
val_r2_var = -2.131161610285441



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.7101940457203615
disnt_after = 1.6797979797979798
val_r2_mean = 0.9348088900248209
val_r2_var = -2.0330758094787598



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.7099251625567415
disnt_after = 1.6797979797979798
val_r2_mean = 0.938102145989736
val_r2_var = -2.001776854197184



Epoch 00084: cpa_metric reached. Module best state updated.

Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 1.5812845049687154
disnt_after = 1.6797979797979798
val_r2_mean = 0.944955031077067
val_r2_var = -1.95359206199646



Epoch 00094: cpa_metric reached. Module best state updated.



disnt_basal = 1.6857972436919804
disnt_after = 1.6934343434343435
val_r2_mean = 0.9375527103741964
val_r2_var = -1.8792461554209392

disnt_basal = 1.6994622336727598
disnt_after = 1.7085858585858587
val_r2_mean = 0.9324017961819967
val_r2_var = -1.773634433746338

disnt_basal = 1.6611065104486158
disnt_after = 1.6967171717171716
val_r2_mean = 0.9458846052487692
val_r2_var = -1.7406307856241863


100%|██████████| 1/1 [00:00<00:00, 23.85it/s]
cpa_single_pert:   2%|▏         | 33/2184 [08:58<9:35:38, 16.06s/it] 



100%|██████████| 1234/1234 [00:00<00:00, 62903.60it/s]
100%|██████████| 1234/1234 [00:00<00:00, 1073811.44it/s]

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        



Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 839.87it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.5311475409836066
disnt_after = 1.5311475409836066
val_r2_mean = 0.28604594866434735
val_r2_var = -2.1814793745676675



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.5311475409836066
disnt_after = 1.5311475409836066
val_r2_mean = 0.7524718244870504
val_r2_var = -2.1832507451375327



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.5178988535304831
disnt_after = 1.5185079591351864
val_r2_mean = 0.8967121839523315
val_r2_var = -2.16507355372111



Epoch 00034: cpa_metric reached. Module best state updated.



disnt_basal = 1.5216597021322191
disnt_after = 1.485103349964362
val_r2_mean = 0.9008798996607462
val_r2_var = -2.121001402537028



Epoch 00049: cpa_metric reached. Module best state updated.



disnt_basal = 1.5308207435979857
disnt_after = 1.472463768115942
val_r2_mean = 0.909799059232076
val_r2_var = -2.065092404683431



Epoch 00054: cpa_metric reached. Module best state updated.

Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.5290206793099754
disnt_after = 1.4729151817533856
val_r2_mean = 0.913582980632782
val_r2_var = -2.0029873053232827



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.516205935926283
disnt_after = 1.472463768115942
val_r2_mean = 0.9211909770965576
val_r2_var = -1.9440340995788574



Epoch 00074: cpa_metric reached. Module best state updated.



disnt_basal = 1.505372438647773
disnt_after = 1.4738180090282729
val_r2_mean = 0.9091137647628784
val_r2_var = -1.896720012029012



Epoch 00084: cpa_metric reached. Module best state updated.

Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 1.5093378334940533
disnt_after = 1.4832976954145878
val_r2_mean = 0.920198400815328
val_r2_var = -1.7879133224487305



Epoch 00094: cpa_metric reached. Module best state updated.

Epoch 00099: cpa_metric reached. Module best state updated.



disnt_basal = 1.3089080104509225
disnt_after = 1.476526490852934
val_r2_mean = 0.9329630931218466
val_r2_var = -1.7800265947977703

disnt_basal = 1.4080912947082078
disnt_after = 1.4738180090282729
val_r2_mean = 0.9311217466990153
val_r2_var = -1.710644801457723



Epoch 00119: cpa_metric reached. Module best state updated.



disnt_basal = 1.332954324336237
disnt_after = 1.4738180090282729
val_r2_mean = 0.9297945300738016
val_r2_var = -1.601075569788615



Epoch 00124: cpa_metric reached. Module best state updated.

Epoch 00129: cpa_metric reached. Module best state updated.



disnt_basal = 1.1362016393084273
disnt_after = 1.4733665953908293
val_r2_mean = 0.917973796526591
val_r2_var = -1.5384290218353271



Epoch 00134: cpa_metric reached. Module best state updated.



disnt_basal = 1.1517621687635522
disnt_after = 1.4747208363031596
val_r2_mean = 0.9248737096786499
val_r2_var = -1.4518404801686604



Epoch 00144: cpa_metric reached. Module best state updated.

Epoch 00149: cpa_metric reached. Module best state updated.



disnt_basal = 1.0944953659955128
disnt_after = 1.4751722499406035
val_r2_mean = 0.9208947022755941
val_r2_var = -1.3799715042114258



Epoch 00154: cpa_metric reached. Module best state updated.

Epoch 00159: cpa_metric reached. Module best state updated.



disnt_basal = 1.0140862054199022
disnt_after = 1.4747208363031596
val_r2_mean = 0.9333638747533163
val_r2_var = -1.347568194071452

disnt_basal = 1.1087038046389854
disnt_after = 1.4738180090282729
val_r2_mean = 0.9262394110361735
val_r2_var = -1.2535663445790608

disnt_basal = 1.0746870618050024
disnt_after = 1.4747208363031599
val_r2_mean = 0.9346061944961548
val_r2_var = -1.2436909675598145



Epoch 00184: cpa_metric reached. Module best state updated.

Epoch 00189: cpa_metric reached. Module best state updated.



disnt_basal = 1.0513674819490089
disnt_after = 1.4792349726775957
val_r2_mean = 0.9333545565605164
val_r2_var = -1.1815520127614338

disnt_basal = 1.158271688155131
disnt_after = 1.4805892135899263
val_r2_mean = 0.9368560115496317
val_r2_var = -1.1743805408477783

disnt_basal = 1.085684110075121
disnt_after = 1.476977904490378
val_r2_mean = 0.9388951063156128
val_r2_var = -1.1073819796244304



Epoch 00214: cpa_metric reached. Module best state updated.

Epoch 00219: cpa_metric reached. Module best state updated.



disnt_basal = 1.02093370272063
disnt_after = 1.4760750772154907
val_r2_mean = 0.9408150116602579
val_r2_var = -1.0983458360036213

disnt_basal = 1.0855624860467423
disnt_after = 1.4747208363031599
val_r2_mean = 0.942037840684255
val_r2_var = -1.0534868637720745

disnt_basal = 1.079323986845684
disnt_after = 1.47472083630316
val_r2_mean = 0.9433406790097555
val_r2_var = -1.0466297467549641


100%|██████████| 1/1 [00:00<00:00, 25.34it/s]
cpa_single_pert:   2%|▏         | 34/2184 [09:24<11:22:45, 19.05s/it]



100%|██████████| 1516/1516 [00:00<00:00, 84083.53it/s]
100%|██████████| 1516/1516 [00:00<00:00, 1083415.38it/s]


[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        


Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 706.05it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.6427619047619046
disnt_after = 1.9166666666666665
val_r2_mean = 0.32826022307078045
val_r2_var = -2.63876740137736



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.647531746031746
disnt_after = 1.9166666666666665
val_r2_mean = 0.7898779908816019
val_r2_var = -2.64150063196818



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.6702222222222223
disnt_after = 1.9120545073375261
val_r2_mean = 0.9328638116518656
val_r2_var = -2.631096919377645



Epoch 00034: cpa_metric reached. Module best state updated.

Epoch 00039: cpa_metric reached. Module best state updated.



disnt_basal = 1.6584761904761904
disnt_after = 1.9029224318658278
val_r2_mean = 0.9476055900255839
val_r2_var = -2.607769330342611

disnt_basal = 1.6863095238095238
disnt_after = 1.9028301886792451
val_r2_mean = 0.944433351357778
val_r2_var = -2.5645602544148765



Epoch 00054: cpa_metric reached. Module best state updated.

Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.6852698412698413
disnt_after = 1.9028301886792451
val_r2_mean = 0.9487901528676351
val_r2_var = -2.515448729197184



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.678936507936508
disnt_after = 1.9030146750524108
val_r2_mean = 0.9471489389737447
val_r2_var = -2.4555321534474692



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.676031746031746
disnt_after = 1.9028301886792451
val_r2_mean = 0.953313966592153
val_r2_var = -2.383591334025065



Epoch 00084: cpa_metric reached. Module best state updated.

Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 1.6750555555555555
disnt_after = 1.9028301886792454
val_r2_mean = 0.9531379739443461
val_r2_var = -2.316852887471517



Epoch 00094: cpa_metric reached. Module best state updated.

Epoch 00099: cpa_metric reached. Module best state updated.



disnt_basal = 1.5338252470799643
disnt_after = 1.9028301886792451
val_r2_mean = 0.9542541901270548
val_r2_var = -2.275167385737101

disnt_basal = 1.6771419586702605
disnt_after = 1.90956394129979
val_r2_mean = 0.9551197091738383
val_r2_var = -2.183579921722412

disnt_basal = 1.6882976939203354
disnt_after = 1.9078113207547167
val_r2_mean = 0.9585692286491394
val_r2_var = -2.068634827931722


100%|██████████| 1/1 [00:00<00:00, 27.91it/s]
cpa_single_pert:   2%|▏         | 35/2184 [09:36<10:05:29, 16.91s/it]



100%|██████████| 1788/1788 [00:00<00:00, 82171.87it/s]
100%|██████████| 1788/1788 [00:00<00:00, 1096245.51it/s]


[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        


Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 890.60it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.6630632911392405
disnt_after = 2.0
val_r2_mean = 0.5820907155672709
val_r2_var = -2.7379111448923745



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.6601833098921706
disnt_after = 2.0
val_r2_mean = 0.9467112421989441
val_r2_var = -2.729804515838623



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.677368495077356
disnt_after = 2.0
val_r2_mean = 0.9628198544184366
val_r2_var = -2.6880087852478027



Epoch 00034: cpa_metric reached. Module best state updated.

Epoch 00039: cpa_metric reached. Module best state updated.



disnt_basal = 1.6675630567276136
disnt_after = 2.0
val_r2_mean = 0.9640020728111267
val_r2_var = -2.641817331314087



Epoch 00044: cpa_metric reached. Module best state updated.



disnt_basal = 1.6619266626481817
disnt_after = 2.0
val_r2_mean = 0.9656092723210653
val_r2_var = -2.5781522591908774



Epoch 00054: cpa_metric reached. Module best state updated.

Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.6596574576384704
disnt_after = 2.0
val_r2_mean = 0.9608611265818278
val_r2_var = -2.437530517578125



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.648819503047351
disnt_after = 2.0
val_r2_mean = 0.9663953979810079
val_r2_var = -2.307945807774862



Epoch 00074: cpa_metric reached. Module best state updated.



disnt_basal = 1.6517759024847631
disnt_after = 2.0
val_r2_mean = 0.9639490246772766
val_r2_var = -2.1970104376475015



Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 1.621313140446052
disnt_after = 2.0
val_r2_mean = 0.9736519257227579
val_r2_var = -2.090534766515096



Epoch 00094: cpa_metric reached. Module best state updated.



disnt_basal = 1.6425059272654208
disnt_after = 2.0
val_r2_mean = 0.9755843281745911
val_r2_var = -2.063404401143392

disnt_basal = 1.5584574040586698
disnt_after = 2.0
val_r2_mean = 0.9693816701571146
val_r2_var = -1.9088459809621174



Epoch 00114: cpa_metric reached. Module best state updated.



disnt_basal = 1.497324258254638
disnt_after = 2.0
val_r2_mean = 0.9726394017537435
val_r2_var = -1.8585090637207031

disnt_basal = 1.5007174335275602
disnt_after = 2.0
val_r2_mean = 0.9722664753595988
val_r2_var = -1.7407111326853435



Epoch 00134: cpa_metric reached. Module best state updated.



disnt_basal = 1.3992795860960419
disnt_after = 2.0
val_r2_mean = 0.9679755171140035
val_r2_var = -1.701666037241618

disnt_basal = 1.3442530306074612
disnt_after = 2.0
val_r2_mean = 0.975699782371521
val_r2_var = -1.6993837356567383



Epoch 00154: cpa_metric reached. Module best state updated.



disnt_basal = 1.4485691179425357
disnt_after = 2.0
val_r2_mean = 0.9751897056897482
val_r2_var = -1.6377362410227458

disnt_basal = 1.3118263009845288
disnt_after = 2.0
val_r2_mean = 0.9757348299026489
val_r2_var = -1.5907539526621501



Epoch 00174: cpa_metric reached. Module best state updated.



disnt_basal = 1.2770551871944278
disnt_after = 2.0
val_r2_mean = 0.9739026029904684
val_r2_var = -1.547671635945638

disnt_basal = 1.2082240305404863
disnt_after = 2.0
val_r2_mean = 0.9775111476580302
val_r2_var = -1.5451339880625408



Epoch 00194: cpa_metric reached. Module best state updated.



disnt_basal = 1.2918249949768936
disnt_after = 2.0
val_r2_mean = 0.9717827240626017
val_r2_var = -1.4668157895406086


100%|██████████| 1/1 [00:00<00:00, 16.76it/s]
cpa_single_pert:   2%|▏         | 36/2184 [10:05<12:11:08, 20.42s/it]



100%|██████████| 1384/1384 [00:00<00:00, 83180.49it/s]
100%|██████████| 1384/1384 [00:00<00:00, 1075582.13it/s]


[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        


Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 942.65it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.7082428765264586
disnt_after = 1.7907960199004973
val_r2_mean = 0.32188816865285236
val_r2_var = -2.432789166768392



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.7127883310719132
disnt_after = 1.7907960199004973
val_r2_mean = 0.7815403938293457
val_r2_var = -2.4346089363098145



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.6939850746268654
disnt_after = 1.7863311646893736
val_r2_mean = 0.9188138047854105
val_r2_var = -2.41943359375



Epoch 00034: cpa_metric reached. Module best state updated.



disnt_basal = 1.697903211216644
disnt_after = 1.7638282944253092
val_r2_mean = 0.9295396010080973
val_r2_var = -2.3889804681142173



Epoch 00049: cpa_metric reached. Module best state updated.



disnt_basal = 1.708815015829941
disnt_after = 1.7734723816813367
val_r2_mean = 0.9324401418368021
val_r2_var = -2.3495686848958335



Epoch 00054: cpa_metric reached. Module best state updated.

Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.7142582541836273
disnt_after = 1.7757941063911211
val_r2_mean = 0.9353885253270467
val_r2_var = -2.290321667989095



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.7145296246042514
disnt_after = 1.772043628013777
val_r2_mean = 0.9417128364245096
val_r2_var = -2.234018564224243



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.7132075983717774
disnt_after = 1.7586490623804056
val_r2_mean = 0.9401480555534363
val_r2_var = -2.170130968093872



Epoch 00084: cpa_metric reached. Module best state updated.

Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 1.7093022417052268
disnt_after = 1.7509695114172725
val_r2_mean = 0.9456554253896078
val_r2_var = -2.117931842803955



Epoch 00094: cpa_metric reached. Module best state updated.

Epoch 00099: cpa_metric reached. Module best state updated.



disnt_basal = 1.518632246691948
disnt_after = 1.750255134583493
val_r2_mean = 0.9531817038853964
val_r2_var = -2.0678362051645913

disnt_basal = 1.6955884447227731
disnt_after = 1.764721265467534
val_r2_mean = 0.9370224873224894
val_r2_var = -1.983551025390625

disnt_basal = 1.6988643959688734
disnt_after = 1.767935961219543
val_r2_mean = 0.9377460877100626
val_r2_var = -1.879076639811198


100%|██████████| 1/1 [00:00<00:00, 25.95it/s]
cpa_single_pert:   2%|▏         | 37/2184 [10:15<10:27:11, 17.53s/it]



100%|██████████| 1524/1524 [00:00<00:00, 80458.67it/s]
100%|██████████| 1524/1524 [00:00<00:00, 1042590.00it/s]

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        



Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 937.90it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.6476817042606517
disnt_after = 1.9214912280701753
val_r2_mean = 0.3356865843137105
val_r2_var = -2.6169772942860923



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.6480889724310779
disnt_after = 1.9214912280701753
val_r2_mean = 0.794481118520101
val_r2_var = -2.618765195210775



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.6193809523809524
disnt_after = 1.9137183235867448
val_r2_mean = 0.933531920115153
val_r2_var = -2.5978306929270425



Epoch 00034: cpa_metric reached. Module best state updated.

Epoch 00039: cpa_metric reached. Module best state updated.



disnt_basal = 1.6217894736842107
disnt_after = 1.9100552306692657
val_r2_mean = 0.9462761084238688
val_r2_var = -2.5628546873728433

disnt_basal = 1.6560150375939848
disnt_after = 1.9098765432098765
val_r2_mean = 0.9436261455217997
val_r2_var = -2.495734532674154



Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.6552318295739348
disnt_after = 1.9098765432098765
val_r2_mean = 0.9451069037119547
val_r2_var = -2.424853563308716



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.6492167919799499
disnt_after = 1.9098765432098765
val_r2_mean = 0.9455073674519857
val_r2_var = -2.3550473054250083



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.6441979949874685
disnt_after = 1.9098765432098765
val_r2_mean = 0.948945144812266
val_r2_var = -2.300179402033488



Epoch 00084: cpa_metric reached. Module best state updated.

Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 1.6406290726817043
disnt_after = 1.9098765432098765
val_r2_mean = 0.9596871336301168
val_r2_var = -2.272189219792684



Epoch 00094: cpa_metric reached. Module best state updated.

Epoch 00099: cpa_metric reached. Module best state updated.



disnt_basal = 1.643777824190105
disnt_after = 1.9098765432098765
val_r2_mean = 0.958449920018514
val_r2_var = -2.1867775917053223



Epoch 00104: cpa_metric reached. Module best state updated.

Epoch 00109: cpa_metric reached. Module best state updated.



disnt_basal = 1.6110033416875522
disnt_after = 1.9098765432098765
val_r2_mean = 0.9405744671821594
val_r2_var = -2.124051094055176



Epoch 00114: cpa_metric reached. Module best state updated.



disnt_basal = 1.635260280330456
disnt_after = 1.9158625730994152
val_r2_mean = 0.9547902941703796
val_r2_var = -2.1309216022491455

disnt_basal = 1.65784335839599
disnt_after = 1.919525666016894
val_r2_mean = 0.9515783389409384
val_r2_var = -1.9321650664011638

disnt_basal = 1.6472994987468672
disnt_after = 1.9191682910981156
val_r2_mean = 0.9548614025115967
val_r2_var = -1.8410189151763916


100%|██████████| 1/1 [00:00<00:00, 21.52it/s]
cpa_single_pert:   2%|▏         | 38/2184 [10:29<9:49:12, 16.47s/it] 



100%|██████████| 1342/1342 [00:00<00:00, 79834.85it/s]
100%|██████████| 1342/1342 [00:00<00:00, 1115488.70it/s]

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        



Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 674.43it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.7065656565656566
disnt_after = 1.7398989898989898
val_r2_mean = 0.30735133091608685
val_r2_var = -2.3979581197102866



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.7084396597554492
disnt_after = 1.7398989898989898
val_r2_mean = 0.7711598873138428
val_r2_var = -2.399221420288086



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.6985589702694965
disnt_after = 1.71489898989899
val_r2_mean = 0.9087949395179749
val_r2_var = -2.3820042610168457



Epoch 00034: cpa_metric reached. Module best state updated.

Epoch 00039: cpa_metric reached. Module best state updated.



disnt_basal = 1.6960628552733816
disnt_after = 1.6845959595959596
val_r2_mean = 0.9263973037401835
val_r2_var = -2.3436620235443115



Epoch 00044: cpa_metric reached. Module best state updated.

Epoch 00049: cpa_metric reached. Module best state updated.



disnt_basal = 1.7036161207213838
disnt_after = 1.6797979797979798
val_r2_mean = 0.923003077507019
val_r2_var = -2.2997759183247886



Epoch 00054: cpa_metric reached. Module best state updated.

Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.7119484316852738
disnt_after = 1.6797979797979798
val_r2_mean = 0.9234435359636942
val_r2_var = -2.2307563622792563



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.7087274567537725
disnt_after = 1.6797979797979798
val_r2_mean = 0.9340178569157919
val_r2_var = -2.150183916091919



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.7078339058602217
disnt_after = 1.6797979797979798
val_r2_mean = 0.9352787931760153
val_r2_var = -2.065797726313273



Epoch 00084: cpa_metric reached. Module best state updated.

Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 1.6842406862143704
disnt_after = 1.6797979797979798
val_r2_mean = 0.9440385897954305
val_r2_var = -1.9830148220062256



Epoch 00094: cpa_metric reached. Module best state updated.

Epoch 00099: cpa_metric reached. Module best state updated.



disnt_basal = 1.442345008792377
disnt_after = 1.6797979797979798
val_r2_mean = 0.9342739582061768
val_r2_var = -1.9567835330963135

disnt_basal = 1.6948886639676113
disnt_after = 1.6866161616161617
val_r2_mean = 0.9324008027712504
val_r2_var = -1.820184866587321

disnt_basal = 1.6512886966834337
disnt_after = 1.68510101010101
val_r2_mean = 0.9425960779190063
val_r2_var = -1.726441224416097


100%|██████████| 1/1 [00:00<00:00, 32.77it/s]
cpa_single_pert:   2%|▏         | 39/2184 [10:42<9:02:50, 15.18s/it]



100%|██████████| 1200/1200 [00:00<00:00, 47674.74it/s]
100%|██████████| 1200/1200 [00:00<00:00, 665586.46it/s]

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        



Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 749.45it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.4724137931034482
disnt_after = 1.4724137931034482
val_r2_mean = 0.2854882876078288
val_r2_var = -2.1569059689839682



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.4724137931034482
disnt_after = 1.471867816091954
val_r2_mean = 0.7545482118924459
val_r2_var = -2.158172369003296



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.4724137931034482
disnt_after = 1.4724137931034482
val_r2_mean = 0.8900537689526876
val_r2_var = -2.147404670715332



Epoch 00034: cpa_metric reached. Module best state updated.



disnt_basal = 1.4724137931034482
disnt_after = 1.4396551724137931
val_r2_mean = 0.9033295114835104
val_r2_var = -2.1236544450124106

disnt_basal = 1.4724137931034482
disnt_after = 1.42
val_r2_mean = 0.9081450899442037
val_r2_var = -2.087268352508545



Epoch 00054: cpa_metric reached. Module best state updated.

Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.4724137931034482
disnt_after = 1.42
val_r2_mean = 0.9199134111404419
val_r2_var = -2.034167448679606



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.4565268199233716
disnt_after = 1.42
val_r2_mean = 0.9203456044197083
val_r2_var = -1.9845216274261475



Epoch 00074: cpa_metric reached. Module best state updated.



disnt_basal = 1.4537164750957854
disnt_after = 1.4260057471264367
val_r2_mean = 0.9134201407432556
val_r2_var = -1.9303817749023438



Epoch 00084: cpa_metric reached. Module best state updated.

Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 1.4106302681992338
disnt_after = 1.4494827586206895
val_r2_mean = 0.9193650881449381
val_r2_var = -1.8470563888549805



Epoch 00094: cpa_metric reached. Module best state updated.



disnt_basal = 1.4161743295019158
disnt_after = 1.4216379310344824
val_r2_mean = 0.9054044087727865
val_r2_var = -1.8250211874643962

disnt_basal = 1.387800766283525
disnt_after = 1.4210919540229885
val_r2_mean = 0.9167782266934713
val_r2_var = -1.7473111152648926



Epoch 00114: cpa_metric reached. Module best state updated.

Epoch 00119: cpa_metric reached. Module best state updated.



disnt_basal = 1.230250957854406
disnt_after = 1.422183908045977
val_r2_mean = 0.9127552310625712
val_r2_var = -1.6622037887573242



Epoch 00124: cpa_metric reached. Module best state updated.



disnt_basal = 1.1744099616858237
disnt_after = 1.4260057471264367
val_r2_mean = 0.9208822449048361
val_r2_var = -1.5953550338745117

disnt_basal = 1.1960881226053641
disnt_after = 1.4281896551724138
val_r2_mean = 0.922776977221171
val_r2_var = -1.5346468289693196



Epoch 00144: cpa_metric reached. Module best state updated.

Epoch 00149: cpa_metric reached. Module best state updated.



disnt_basal = 1.0984444444444446
disnt_after = 1.424367816091954
val_r2_mean = 0.9277490178743998
val_r2_var = -1.4827464421590169

disnt_basal = 1.131256704980843
disnt_after = 1.4238218390804596
val_r2_mean = 0.9183834195137024
val_r2_var = -1.4340113798777263



Epoch 00164: cpa_metric reached. Module best state updated.



disnt_basal = 1.0991896551724136
disnt_after = 1.4227298850574712
val_r2_mean = 0.9318670829137167
val_r2_var = -1.382994810740153



Epoch 00174: cpa_metric reached. Module best state updated.

Epoch 00179: cpa_metric reached. Module best state updated.



disnt_basal = 1.0425613026819924
disnt_after = 1.424367816091954
val_r2_mean = 0.9320949713389078
val_r2_var = -1.3292584419250488



Epoch 00189: cpa_metric reached. Module best state updated.



disnt_basal = 1.037704980842912
disnt_after = 1.424367816091954
val_r2_mean = 0.9288763205210367
val_r2_var = -1.2703559398651123

disnt_basal = 1.0455555555555556
disnt_after = 1.4249137931034483
val_r2_mean = 0.9277613162994385
val_r2_var = -1.266111175219218



Epoch 00209: cpa_metric reached. Module best state updated.



disnt_basal = 1.0389233716475097
disnt_after = 1.4249137931034483
val_r2_mean = 0.9314190745353699
val_r2_var = -1.20913827419281



Epoch 00214: cpa_metric reached. Module best state updated.



disnt_basal = 1.0517509578544062
disnt_after = 1.4238218390804596
val_r2_mean = 0.9243671695391337
val_r2_var = -1.1926218271255493

disnt_basal = 1.0598467432950192
disnt_after = 1.4238218390804596
val_r2_mean = 0.9306682745615641
val_r2_var = -1.1889361540476482



Epoch 00239: cpa_metric reached. Module best state updated.



disnt_basal = 1.030300766283525
disnt_after = 1.4238218390804596
val_r2_mean = 0.9327283898989359
val_r2_var = -1.150606910387675

disnt_basal = 1.0660249042145593
disnt_after = 1.424367816091954
val_r2_mean = 0.9266292055447897
val_r2_var = -1.1432095368703206



Epoch 00254: cpa_metric reached. Module best state updated.

Epoch 00259: cpa_metric reached. Module best state updated.



disnt_basal = 1.0375862068965516
disnt_after = 1.4254597701149425
val_r2_mean = 0.9265593687693278
val_r2_var = -1.085330327351888



Epoch 00264: cpa_metric reached. Module best state updated.

Epoch 00269: cpa_metric reached. Module best state updated.



disnt_basal = 1.0220402298850575
disnt_after = 1.424367816091954
val_r2_mean = 0.9317479133605957
val_r2_var = -1.0857720772425334

disnt_basal = 1.0341494252873562
disnt_after = 1.4260057471264367
val_r2_mean = 0.9308292667071024
val_r2_var = -1.0931244293848674

disnt_basal = 1.0419252873563218
disnt_after = 1.4238218390804596
val_r2_mean = 0.9291637539863586
val_r2_var = -1.1020429531733196


100%|██████████| 1/1 [00:00<00:00, 24.05it/s]
cpa_single_pert:   2%|▏         | 40/2184 [11:09<11:13:50, 18.86s/it]



100%|██████████| 1438/1438 [00:00<00:00, 80437.02it/s]
100%|██████████| 1438/1438 [00:00<00:00, 1083421.80it/s]


[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        


Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 942.12it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.6435701163502756
disnt_after = 1.8234741784037558
val_r2_mean = 0.3149514396985372
val_r2_var = -2.451725800832113



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.6519902020820576
disnt_after = 1.8234741784037558
val_r2_mean = 0.7842726906140646
val_r2_var = -2.4531131585439048



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.6554194733619105
disnt_after = 1.818753827311696
val_r2_mean = 0.9337297280629476
val_r2_var = -2.4407875537872314



Epoch 00034: cpa_metric reached. Module best state updated.

Epoch 00039: cpa_metric reached. Module best state updated.



disnt_basal = 1.6555011226780976
disnt_after = 1.7999999999999998
val_r2_mean = 0.9500730037689209
val_r2_var = -2.408998648325602



Epoch 00049: cpa_metric reached. Module best state updated.



disnt_basal = 1.6612573994692794
disnt_after = 1.7999999999999998
val_r2_mean = 0.9358778198560079
val_r2_var = -2.3580544789632163



Epoch 00054: cpa_metric reached. Module best state updated.

Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.6642784241681976
disnt_after = 1.7999999999999998
val_r2_mean = 0.9426230192184448
val_r2_var = -2.297736724217733



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.669585629720351
disnt_after = 1.8005103082261686
val_r2_mean = 0.9372543692588806
val_r2_var = -2.239373763402303



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.6642784241681976
disnt_after = 1.7999999999999998
val_r2_mean = 0.9351919094721476
val_r2_var = -2.179274797439575



Epoch 00084: cpa_metric reached. Module best state updated.

Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 1.6693713002653603
disnt_after = 1.8002551541130842
val_r2_mean = 0.9456000725428263
val_r2_var = -2.128601551055908



Epoch 00094: cpa_metric reached. Module best state updated.

Epoch 00099: cpa_metric reached. Module best state updated.



disnt_basal = 1.6570639416207391
disnt_after = 1.8001275770565421
val_r2_mean = 0.9515374104181925
val_r2_var = -2.044280211130778



Epoch 00104: cpa_metric reached. Module best state updated.

Epoch 00109: cpa_metric reached. Module best state updated.



disnt_basal = 1.3075176906851738
disnt_after = 1.8005103082261686
val_r2_mean = 0.9444850484530131
val_r2_var = -1.9788976510365803

disnt_basal = 1.6500918554807105
disnt_after = 1.8178607879159012
val_r2_mean = 0.9383877515792847
val_r2_var = -1.8967369397481282

disnt_basal = 1.659399664897598
disnt_after = 1.8163298632373954
val_r2_mean = 0.9459583759307861
val_r2_var = -1.762426773707072


100%|██████████| 1/1 [00:00<00:00, 24.00it/s]
cpa_single_pert:   2%|▏         | 41/2184 [11:23<10:22:03, 17.42s/it]



100%|██████████| 1300/1300 [00:00<00:00, 80065.13it/s]
100%|██████████| 1300/1300 [00:00<00:00, 799969.95it/s]

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        



Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 880.14it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.655026455026455
disnt_after = 1.655026455026455
val_r2_mean = 0.28868240118026733
val_r2_var = -2.1998306115468345



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.655026455026455
disnt_after = 1.655026455026455
val_r2_mean = 0.7593177556991577
val_r2_var = -2.199819008509318



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.655026455026455
disnt_after = 1.6282186948853616
val_r2_mean = 0.9074740608533224
val_r2_var = -2.1835307280222573



Epoch 00039: cpa_metric reached. Module best state updated.



disnt_basal = 1.655026455026455
disnt_after = 1.5950617283950617
val_r2_mean = 0.9282337228457133
val_r2_var = -2.1438839435577393



Epoch 00044: cpa_metric reached. Module best state updated.

Epoch 00049: cpa_metric reached. Module best state updated.



disnt_basal = 1.655026455026455
disnt_after = 1.5954144620811288
val_r2_mean = 0.9117046991984049
val_r2_var = -2.1005664666493735



Epoch 00054: cpa_metric reached. Module best state updated.

Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.655026455026455
disnt_after = 1.5950617283950617
val_r2_mean = 0.9202151695887247
val_r2_var = -2.04038135210673



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.655026455026455
disnt_after = 1.5950617283950617
val_r2_mean = 0.9200352430343628
val_r2_var = -1.9668527444203694



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.6524587612822907
disnt_after = 1.5950617283950617
val_r2_mean = 0.9308096766471863
val_r2_var = -1.9143868287404378



Epoch 00084: cpa_metric reached. Module best state updated.



disnt_basal = 1.5477861165430533
disnt_after = 1.5950617283950617
val_r2_mean = 0.9137888352076212
val_r2_var = -1.8660860061645508

disnt_basal = 1.620453268658966
disnt_after = 1.600352733686067
val_r2_mean = 0.9076457420984904
val_r2_var = -1.7657086849212646



Epoch 00104: cpa_metric reached. Module best state updated.

Epoch 00109: cpa_metric reached. Module best state updated.



disnt_basal = 1.5125272331154682
disnt_after = 1.6
val_r2_mean = 0.9236433903376261
val_r2_var = -1.6944154103597004



Epoch 00114: cpa_metric reached. Module best state updated.

Epoch 00119: cpa_metric reached. Module best state updated.



disnt_basal = 1.2425957245232124
disnt_after = 1.595767195767196
val_r2_mean = 0.928682784239451
val_r2_var = -1.6561675866444905

disnt_basal = 1.5201799098395474
disnt_after = 1.5950617283950617
val_r2_mean = 0.9267716209093729
val_r2_var = -1.5987431208292644

disnt_basal = 1.461889395870158
disnt_after = 1.5950617283950617
val_r2_mean = 0.9258437951405843
val_r2_var = -1.4929076830546062


100%|██████████| 1/1 [00:00<00:00, 20.74it/s]
cpa_single_pert:   2%|▏         | 42/2184 [11:37<9:41:09, 16.28s/it] 



100%|██████████| 1348/1348 [00:00<00:00, 77574.25it/s]
100%|██████████| 1348/1348 [00:00<00:00, 1099770.82it/s]

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        



Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 908.15it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.7165736310473152
disnt_after = 1.7398989898989898
val_r2_mean = 0.31150086720784503
val_r2_var = -2.406089464823405



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.7224481658692183
disnt_after = 1.7398989898989898
val_r2_mean = 0.7790359059969584
val_r2_var = -2.4070212046305337



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.7309675704412546
disnt_after = 1.7378787878787878
val_r2_mean = 0.9092292388280233
val_r2_var = -2.3940370082855225



Epoch 00034: cpa_metric reached. Module best state updated.



disnt_basal = 1.7324694311536417
disnt_after = 1.6946969696969696
val_r2_mean = 0.9187950094540914
val_r2_var = -2.3607449531555176

disnt_basal = 1.7309675704412546
disnt_after = 1.6803030303030302
val_r2_mean = 0.9241755406061808
val_r2_var = -2.3200505574544272



Epoch 00054: cpa_metric reached. Module best state updated.

Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.724069643806486
disnt_after = 1.6992424242424242
val_r2_mean = 0.9303204218546549
val_r2_var = -2.244027773539225



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.7294657097288675
disnt_after = 1.6921717171717172
val_r2_mean = 0.9351280331611633
val_r2_var = -2.169191281000773



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.726581605528974
disnt_after = 1.6810606060606061
val_r2_mean = 0.9405287106831869
val_r2_var = -2.1126961708068848



Epoch 00084: cpa_metric reached. Module best state updated.

Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 1.7241759702286017
disnt_after = 1.6797979797979798
val_r2_mean = 0.9426804979642233
val_r2_var = -2.0616015593210855



Epoch 00094: cpa_metric reached. Module best state updated.

Epoch 00099: cpa_metric reached. Module best state updated.



disnt_basal = 1.5216450946714104
disnt_after = 1.683838383838384
val_r2_mean = 0.9368222951889038
val_r2_var = -2.0170138676961265

disnt_basal = 1.718795751032593
disnt_after = 1.6904040404040404
val_r2_mean = 0.9380606214205424
val_r2_var = -1.8888194561004639

disnt_basal = 1.7007417290312028
disnt_after = 1.6876262626262628
val_r2_mean = 0.9387651085853577
val_r2_var = -1.8169821898142497


100%|██████████| 1/1 [00:00<00:00, 28.47it/s]
cpa_single_pert:   2%|▏         | 43/2184 [11:48<8:44:08, 14.69s/it]



100%|██████████| 1284/1284 [00:00<00:00, 78785.86it/s]
100%|██████████| 1284/1284 [00:00<00:00, 1092169.20it/s]

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        



Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 1020.14it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.6543010752688172
disnt_after = 1.6543010752688172
val_r2_mean = 0.34304624795913696
val_r2_var = -2.521113157272339



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.6543010752688172
disnt_after = 1.6543010752688172
val_r2_mean = 0.7849728465080261
val_r2_var = -2.521933078765869



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.6525330851943756
disnt_after = 1.6483771405814418
val_r2_mean = 0.9215022126833597
val_r2_var = -2.5026097297668457



Epoch 00034: cpa_metric reached. Module best state updated.



disnt_basal = 1.6543010752688172
disnt_after = 1.6110911987256076
val_r2_mean = 0.9355577230453491
val_r2_var = -2.458153009414673

disnt_basal = 1.6543010752688172
disnt_after = 1.5968040621266428
val_r2_mean = 0.9273869196573893
val_r2_var = -2.4199863274892173



Epoch 00054: cpa_metric reached. Module best state updated.

Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.6543010752688172
disnt_after = 1.5978494623655912
val_r2_mean = 0.9365293184916178
val_r2_var = -2.343870004018148



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.6543010752688172
disnt_after = 1.597152528872959
val_r2_mean = 0.9341022372245789
val_r2_var = -2.2763936519622803



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.6543010752688172
disnt_after = 1.5957586618876942
val_r2_mean = 0.9367160797119141
val_r2_var = -2.2019657293955484



Epoch 00084: cpa_metric reached. Module best state updated.

Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 1.6539805624483044
disnt_after = 1.5961071286340103
val_r2_mean = 0.9415059487024943
val_r2_var = -2.1526381174723306



Epoch 00094: cpa_metric reached. Module best state updated.



disnt_basal = 1.6363665482735692
disnt_after = 1.6051672640382317
val_r2_mean = 0.9329812924067179
val_r2_var = -2.1142261028289795



Epoch 00109: cpa_metric reached. Module best state updated.



disnt_basal = 1.6528019023986766
disnt_after = 1.6299084030266826
val_r2_mean = 0.9430886109670004
val_r2_var = -1.9600738684336345



Epoch 00114: cpa_metric reached. Module best state updated.

Epoch 00119: cpa_metric reached. Module best state updated.



disnt_basal = 1.5968266775509026
disnt_after = 1.6110911987256074
val_r2_mean = 0.9400871793429056
val_r2_var = -1.9162033398946126



Epoch 00124: cpa_metric reached. Module best state updated.

Epoch 00129: cpa_metric reached. Module best state updated.



disnt_basal = 1.281126167036983
disnt_after = 1.6006371963361212
val_r2_mean = 0.9398277600606283
val_r2_var = -1.8777410984039307

disnt_basal = 1.5207268561316634
disnt_after = 1.595410195141378
val_r2_mean = 0.9372723897298177
val_r2_var = -1.8487006823221843

disnt_basal = 1.467082912109794
disnt_after = 1.595410195141378
val_r2_mean = 0.9456833600997925
val_r2_var = -1.7393542130788167



Epoch 00154: cpa_metric reached. Module best state updated.

Epoch 00159: cpa_metric reached. Module best state updated.



disnt_basal = 1.2325633051676697
disnt_after = 1.5968040621266426
val_r2_mean = 0.9478431741396586
val_r2_var = -1.681105136871338



Epoch 00164: cpa_metric reached. Module best state updated.



disnt_basal = 1.2518727328262431
disnt_after = 1.6013341298287536
val_r2_mean = 0.9355857570966085
val_r2_var = -1.5624390443166096


100%|██████████| 1/1 [00:00<00:00, 28.06it/s]
cpa_single_pert:   2%|▏         | 44/2184 [12:03<8:50:56, 14.89s/it]



100%|██████████| 1308/1308 [00:00<00:00, 79961.37it/s]
100%|██████████| 1308/1308 [00:00<00:00, 1057469.09it/s]

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        



Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 749.38it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.6857142857142857
disnt_after = 1.6857142857142857
val_r2_mean = 0.3304462432861328
val_r2_var = -2.3160115083058677



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.6857142857142857
disnt_after = 1.6857142857142857
val_r2_mean = 0.7809137304623922
val_r2_var = -2.316676219304403



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.677193770589997
disnt_after = 1.6784391534391534
val_r2_mean = 0.9167486429214478
val_r2_var = -2.2963670094807944

disnt_basal = 1.6769292203254467
disnt_after = 1.626851851851852
val_r2_mean = 0.9292843739191691
val_r2_var = -2.2570574283599854



Epoch 00049: cpa_metric reached. Module best state updated.



disnt_basal = 1.6845562543675752
disnt_after = 1.6261904761904762
val_r2_mean = 0.9260458747545878
val_r2_var = -2.1957576274871826



Epoch 00054: cpa_metric reached. Module best state updated.

Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.6851352700409303
disnt_after = 1.6261904761904762
val_r2_mean = 0.9324094454447428
val_r2_var = -2.145029067993164



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.6845562543675752
disnt_after = 1.6261904761904762
val_r2_mean = 0.9328614672025045
val_r2_var = -2.0778594811757407



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.6667939502845162
disnt_after = 1.6261904761904762
val_r2_mean = 0.9352286060651144
val_r2_var = -2.009737491607666



Epoch 00084: cpa_metric reached. Module best state updated.

Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 1.5613394728960768
disnt_after = 1.6261904761904762
val_r2_mean = 0.9437793095906576
val_r2_var = -1.9943079153696697

disnt_basal = 1.6845063392233204
disnt_after = 1.642063492063492
val_r2_mean = 0.9489092230796814
val_r2_var = -1.8672505219777424



Epoch 00109: cpa_metric reached. Module best state updated.



disnt_basal = 1.6673193071777979
disnt_after = 1.6390873015873015
val_r2_mean = 0.9475409785906473
val_r2_var = -1.781889835993449



Epoch 00114: cpa_metric reached. Module best state updated.

Epoch 00119: cpa_metric reached. Module best state updated.



disnt_basal = 1.4863906359189378
disnt_after = 1.6321428571428571
val_r2_mean = 0.9426412185033163
val_r2_var = -1.72282870610555



Epoch 00124: cpa_metric reached. Module best state updated.



disnt_basal = 1.4093915343915344
disnt_after = 1.6261904761904762
val_r2_mean = 0.9324842691421509
val_r2_var = -1.6842435201009114

disnt_basal = 1.5204964061096138
disnt_after = 1.6261904761904762
val_r2_mean = 0.9471523761749268
val_r2_var = -1.5992660522460938



Epoch 00149: cpa_metric reached. Module best state updated.



disnt_basal = 1.328150893481082
disnt_after = 1.6261904761904762
val_r2_mean = 0.9436269601186117
val_r2_var = -1.5186044375101726



Epoch 00154: cpa_metric reached. Module best state updated.

Epoch 00159: cpa_metric reached. Module best state updated.



disnt_basal = 1.1655959868224017
disnt_after = 1.6271825396825397
val_r2_mean = 0.9494761824607849
val_r2_var = -1.4734242757161458

disnt_basal = 1.3370931915743238
disnt_after = 1.6298280423280422
val_r2_mean = 0.9464943607648214
val_r2_var = -1.3912618160247803



Epoch 00179: cpa_metric reached. Module best state updated.



disnt_basal = 1.2019479385045422
disnt_after = 1.6291666666666667
val_r2_mean = 0.949959913889567
val_r2_var = -1.3529733816782634



Epoch 00184: cpa_metric reached. Module best state updated.

Epoch 00189: cpa_metric reached. Module best state updated.



disnt_basal = 1.127098931815913
disnt_after = 1.6271825396825397
val_r2_mean = 0.9487897555033366
val_r2_var = -1.3175447781880696

disnt_basal = 1.1907894080063892
disnt_after = 1.626851851851852
val_r2_mean = 0.946760912736257
val_r2_var = -1.2451714674631755



Epoch 00209: cpa_metric reached. Module best state updated.



disnt_basal = 1.1383785065388838
disnt_after = 1.6265211640211639
val_r2_mean = 0.9521733919779459
val_r2_var = -1.261273940404256



Epoch 00214: cpa_metric reached. Module best state updated.

Epoch 00219: cpa_metric reached. Module best state updated.



disnt_basal = 1.0802061495457722
disnt_after = 1.6275132275132274
val_r2_mean = 0.9496567845344543
val_r2_var = -1.1957556406656902

disnt_basal = 1.1374538284915643
disnt_after = 1.628505291005291
val_r2_mean = 0.9496109088261923
val_r2_var = -1.1587957541147869

disnt_basal = 1.1631364180892483
disnt_after = 1.6285052910052908
val_r2_mean = 0.9497319261233012
val_r2_var = -1.1427871386210124


100%|██████████| 1/1 [00:00<00:00, 28.10it/s]
cpa_single_pert:   2%|▏         | 45/2184 [12:25<10:10:36, 17.13s/it]



100%|██████████| 1332/1332 [00:00<00:00, 84194.54it/s]
100%|██████████| 1332/1332 [00:00<00:00, 1002838.44it/s]

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        



Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 964.76it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.7147368421052631
disnt_after = 1.738974358974359
val_r2_mean = 0.3317577640215556
val_r2_var = -2.6166324615478516



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.7187179487179487
disnt_after = 1.738974358974359
val_r2_mean = 0.7868611216545105
val_r2_var = -2.618832270304362



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.7072522558281074
disnt_after = 1.7322610722610723
val_r2_mean = 0.9200757145881653
val_r2_var = -2.60154390335083



Epoch 00034: cpa_metric reached. Module best state updated.



disnt_basal = 1.7031227011722367
disnt_after = 1.6991919191919194
val_r2_mean = 0.9321349263191223
val_r2_var = -2.5616026719411216

disnt_basal = 1.710593527559472
disnt_after = 1.6927272727272726
val_r2_mean = 0.923226793607076
val_r2_var = -2.5010201136271157



Epoch 00054: cpa_metric reached. Module best state updated.

Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.7145227699716865
disnt_after = 1.692975912975913
val_r2_mean = 0.925504744052887
val_r2_var = -2.447021802266439



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.7126699478711862
disnt_after = 1.6820357420357421
val_r2_mean = 0.9324873487154642
val_r2_var = -2.3843088944753013



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.7094974994046201
disnt_after = 1.6817871017871018
val_r2_mean = 0.9405357042948405
val_r2_var = -2.3210325241088867



Epoch 00084: cpa_metric reached. Module best state updated.

Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 1.5685215575618052
disnt_after = 1.6817871017871018
val_r2_mean = 0.9361584782600403
val_r2_var = -2.2471540768941245

disnt_basal = 1.7104666094139778
disnt_after = 1.7014296814296814
val_r2_mean = 0.9374831914901733
val_r2_var = -2.190711339314779

disnt_basal = 1.7036508451988328
disnt_after = 1.7051592851592852
val_r2_mean = 0.9462971091270447
val_r2_var = -2.0834268728892007


100%|██████████| 1/1 [00:00<00:00, 26.71it/s]
cpa_single_pert:   2%|▏         | 46/2184 [12:35<8:54:43, 15.01s/it] 



100%|██████████| 1522/1522 [00:00<00:00, 85291.54it/s]
100%|██████████| 1522/1522 [00:00<00:00, 971500.64it/s]


[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        


Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 665.02it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.6635964912280703
disnt_after = 1.9171052631578946
val_r2_mean = 0.34173675378163654
val_r2_var = -2.6320009231567383



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.649874686716792
disnt_after = 1.9171052631578944
val_r2_mean = 0.7973394791285197
val_r2_var = -2.6339941024780273



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.6520596835225319
disnt_after = 1.906922376696458
val_r2_mean = 0.9280040860176086
val_r2_var = -2.619366486867269



Epoch 00034: cpa_metric reached. Module best state updated.

Epoch 00039: cpa_metric reached. Module best state updated.



disnt_basal = 1.6504926532016313
disnt_after = 1.9031156901688182
val_r2_mean = 0.9454671541849772
val_r2_var = -2.5834477742513022



Epoch 00044: cpa_metric reached. Module best state updated.



disnt_basal = 1.674127106983144
disnt_after = 1.9031156901688182
val_r2_mean = 0.942569633324941
val_r2_var = -2.528730551401774



Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.6642857142857144
disnt_after = 1.903210857332009
val_r2_mean = 0.9536665876706442
val_r2_var = -2.4574131965637207



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.6661340852130326
disnt_after = 1.903401191658391
val_r2_mean = 0.9585198958714803
val_r2_var = -2.3795316219329834



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.6650062656641604
disnt_after = 1.903496358821582
val_r2_mean = 0.9615811904271444
val_r2_var = -2.2940184275309243



Epoch 00084: cpa_metric reached. Module best state updated.

Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 1.668577694235589
disnt_after = 1.903210857332009
val_r2_mean = 0.9568602641423544
val_r2_var = -2.248454968134562



Epoch 00094: cpa_metric reached. Module best state updated.

Epoch 00099: cpa_metric reached. Module best state updated.



disnt_basal = 1.662218045112782
disnt_after = 1.9031156901688182
val_r2_mean = 0.9585126241048177
val_r2_var = -2.173448165257772



Epoch 00104: cpa_metric reached. Module best state updated.

Epoch 00109: cpa_metric reached. Module best state updated.



disnt_basal = 1.6595027733039591
disnt_after = 1.9028301886792451
val_r2_mean = 0.9646463592847189
val_r2_var = -2.10721762975057



Epoch 00114: cpa_metric reached. Module best state updated.



disnt_basal = 1.6259966763931657
disnt_after = 1.906636875206885
val_r2_mean = 0.9525313973426819
val_r2_var = -2.070439577102661

disnt_basal = 1.651780185758514
disnt_after = 1.908635385633896
val_r2_mean = 0.961102028687795
val_r2_var = -1.9334542751312256

disnt_basal = 1.6479612483854846
disnt_after = 1.907683714001986
val_r2_mean = 0.9646147092183431
val_r2_var = -1.8667391141255696


100%|██████████| 1/1 [00:00<00:00, 30.31it/s]
cpa_single_pert:   2%|▏         | 47/2184 [12:49<8:39:58, 14.60s/it]



100%|██████████| 1230/1230 [00:00<00:00, 61760.69it/s]
100%|██████████| 1230/1230 [00:00<00:00, 1067672.58it/s]

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        



Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 1178.84it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.5174999999999998
disnt_after = 1.5174999999999998
val_r2_mean = 0.34603158632914227
val_r2_var = -2.2757205168406167



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.5175
disnt_after = 1.5175
val_r2_mean = 0.7810220917065939
val_r2_var = -2.276139259338379



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.510751633986928
disnt_after = 1.5146212121212121
val_r2_mean = 0.9153726696968079
val_r2_var = -2.257659991582235

disnt_basal = 1.515032679738562
disnt_after = 1.46760101010101
val_r2_mean = 0.9196497797966003
val_r2_var = -2.22585121790568



Epoch 00049: cpa_metric reached. Module best state updated.



disnt_basal = 1.5175
disnt_after = 1.4613636363636364
val_r2_mean = 0.9327340722084045
val_r2_var = -2.1723763942718506



Epoch 00054: cpa_metric reached. Module best state updated.

Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.5174999999999998
disnt_after = 1.4623232323232322
val_r2_mean = 0.9324835737546285
val_r2_var = -2.0910534063975015



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.5156862745098039
disnt_after = 1.4613636363636364
val_r2_mean = 0.9401386777559916
val_r2_var = -1.9982295036315918



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.3949265277206453
disnt_after = 1.4613636363636364
val_r2_mean = 0.9421463211377462
val_r2_var = -1.9601251284281414

disnt_basal = 1.512516339869281
disnt_after = 1.4685606060606062
val_r2_mean = 0.9225635329882304
val_r2_var = -1.8943760395050049



Epoch 00099: cpa_metric reached. Module best state updated.



disnt_basal = 1.3008946935417525
disnt_after = 1.4666414141414141
val_r2_mean = 0.9304317434628805
val_r2_var = -1.8342969417572021



Epoch 00104: cpa_metric reached. Module best state updated.



disnt_basal = 1.2923235751176927
disnt_after = 1.4628030303030304
val_r2_mean = 0.9345953464508057
val_r2_var = -1.791979153951009

disnt_basal = 1.3944245623657388
disnt_after = 1.4618434343434346
val_r2_mean = 0.9342469175656637
val_r2_var = -1.683752139409383



Epoch 00129: cpa_metric reached. Module best state updated.



disnt_basal = 1.245692559074912
disnt_after = 1.4618434343434346
val_r2_mean = 0.94505375623703
val_r2_var = -1.5877575874328613



Epoch 00134: cpa_metric reached. Module best state updated.

Epoch 00139: cpa_metric reached. Module best state updated.



disnt_basal = 1.102790461172814
disnt_after = 1.4623232323232322
val_r2_mean = 0.9426237742106119
val_r2_var = -1.5119786262512207

disnt_basal = 1.1591510123863065
disnt_after = 1.465681818181818
val_r2_mean = 0.9450513124465942
val_r2_var = -1.4625600179036458



Epoch 00159: cpa_metric reached. Module best state updated.



disnt_basal = 1.095882810000457
disnt_after = 1.4632828282828283
val_r2_mean = 0.9414555629094442
val_r2_var = -1.4008112748463948



Epoch 00164: cpa_metric reached. Module best state updated.



disnt_basal = 1.128559691942045
disnt_after = 1.463282828282828
val_r2_mean = 0.9491080244382223
val_r2_var = -1.4021526177724202

disnt_basal = 1.1410313542666484
disnt_after = 1.4637626262626262
val_r2_mean = 0.9449966748555502
val_r2_var = -1.3326012293497722



Epoch 00189: cpa_metric reached. Module best state updated.



disnt_basal = 1.0921658896658895
disnt_after = 1.4652020202020202
val_r2_mean = 0.9459897677103678
val_r2_var = -1.3031527996063232



Epoch 00194: cpa_metric reached. Module best state updated.



disnt_basal = 1.1224008181361123
disnt_after = 1.467121212121212
val_r2_mean = 0.9438211917877197
val_r2_var = -1.2213788827260335



Epoch 00209: cpa_metric reached. Module best state updated.



disnt_basal = 1.0963413547237075
disnt_after = 1.4666414141414141
val_r2_mean = 0.9482378760973612
val_r2_var = -1.200268030166626



Epoch 00214: cpa_metric reached. Module best state updated.

Epoch 00219: cpa_metric reached. Module best state updated.



disnt_basal = 1.0785399241281595
disnt_after = 1.4685606060606062
val_r2_mean = 0.948823094367981
val_r2_var = -1.1863210201263428



Epoch 00224: cpa_metric reached. Module best state updated.



disnt_basal = 1.081477672654143
disnt_after = 1.465681818181818
val_r2_mean = 0.9484245975812277
val_r2_var = -1.1695003509521484

disnt_basal = 1.1010545500251383
disnt_after = 1.465681818181818
val_r2_mean = 0.9475379387537638
val_r2_var = -1.144802172978719



Epoch 00244: cpa_metric reached. Module best state updated.

Epoch 00249: cpa_metric reached. Module best state updated.



disnt_basal = 1.0627346999405822
disnt_after = 1.4685606060606062
val_r2_mean = 0.9454999963442484
val_r2_var = -1.1045081615447998



Epoch 00254: cpa_metric reached. Module best state updated.

Epoch 00259: cpa_metric reached. Module best state updated.



disnt_basal = 1.0477133324192147
disnt_after = 1.465681818181818
val_r2_mean = 0.9468554059664408
val_r2_var = -1.0802559852600098



Epoch 00264: cpa_metric reached. Module best state updated.

Epoch 00269: cpa_metric reached. Module best state updated.



disnt_basal = 1.0491585538644361
disnt_after = 1.469040404040404
val_r2_mean = 0.9483922918637594
val_r2_var = -1.066567103068034



Epoch 00274: cpa_metric reached. Module best state updated.



disnt_basal = 1.0517971570912747
disnt_after = 1.47
val_r2_mean = 0.9500649174054464
val_r2_var = -1.0870505174001057

disnt_basal = 1.061274167009461
disnt_after = 1.4685606060606062
val_r2_mean = 0.9498991370201111
val_r2_var = -1.0591905911763508



Epoch 00299: cpa_metric reached. Module best state updated.



disnt_basal = 1.0480259609671374
disnt_after = 1.469040404040404
val_r2_mean = 0.9498939911524454
val_r2_var = -1.0485432147979736



Epoch 00304: cpa_metric reached. Module best state updated.



disnt_basal = 1.063820787056081
disnt_after = 1.469520202020202
val_r2_mean = 0.9504003524780273
val_r2_var = -1.059181849161784

disnt_basal = 1.0551750537044653
disnt_after = 1.4666414141414141
val_r2_mean = 0.9501621127128601
val_r2_var = -1.0259068806966145



Epoch 00329: cpa_metric reached. Module best state updated.



disnt_basal = 1.0457216966040497
disnt_after = 1.468080808080808
val_r2_mean = 0.9484994411468506
val_r2_var = -1.0295904874801636



Epoch 00334: cpa_metric reached. Module best state updated.

Epoch 00339: cpa_metric reached. Module best state updated.



disnt_basal = 1.046625074272133
disnt_after = 1.4704797979797979
val_r2_mean = 0.9481314420700073
val_r2_var = -1.01981520652771



Epoch 00344: cpa_metric reached. Module best state updated.



disnt_basal = 1.0492171717171717
disnt_after = 1.4671212121212123
val_r2_mean = 0.946446935335795
val_r2_var = -0.9972227811813354



Epoch 00354: cpa_metric reached. Module best state updated.



disnt_basal = 1.041471959413136
disnt_after = 1.468560606060606
val_r2_mean = 0.9481274286905924
val_r2_var = -1.0020426114400227

disnt_basal = 1.050138260432378
disnt_after = 1.4671212121212123
val_r2_mean = 0.948376476764679
val_r2_var = -1.005109707514445

disnt_basal = 1.0395763060468943
disnt_after = 1.466161616161616
val_r2_mean = 0.9475937883059183
val_r2_var = -1.010766625404358


100%|██████████| 1/1 [00:00<00:00, 23.91it/s]
cpa_single_pert:   2%|▏         | 48/2184 [13:26<12:35:09, 21.21s/it]



100%|██████████| 1396/1396 [00:00<00:00, 79980.99it/s]
100%|██████████| 1396/1396 [00:00<00:00, 1070429.32it/s]


[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        


Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 932.27it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.662121212121212
disnt_after = 1.7941176470588234
val_r2_mean = 0.304043710231781
val_r2_var = -2.401758829752604



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.6634135472370766
disnt_after = 1.7941176470588234
val_r2_mean = 0.7739771803220113
val_r2_var = -2.4041152795155845



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.6466659391029137
disnt_after = 1.7812051649928262
val_r2_mean = 0.9172187248865763
val_r2_var = -2.387101729710897

disnt_basal = 1.6505547673614902
disnt_after = 1.7573170731707315
val_r2_mean = 0.927110513051351
val_r2_var = -2.356032053629557



Epoch 00044: cpa_metric reached. Module best state updated.



disnt_basal = 1.6678921568627452
disnt_after = 1.7594153515064561
val_r2_mean = 0.9328897595405579
val_r2_var = -2.313055912653605



Epoch 00054: cpa_metric reached. Module best state updated.

Epoch 00059: cpa_metric reached. Module best state updated.



disnt_basal = 1.6695855614973263
disnt_after = 1.7579626972740314
val_r2_mean = 0.9460864861806234
val_r2_var = -2.2197279135386148



Epoch 00064: cpa_metric reached. Module best state updated.

Epoch 00069: cpa_metric reached. Module best state updated.



disnt_basal = 1.6816844919786096
disnt_after = 1.760060975609756
val_r2_mean = 0.951788067817688
val_r2_var = -2.177745262781779



Epoch 00074: cpa_metric reached. Module best state updated.

Epoch 00079: cpa_metric reached. Module best state updated.



disnt_basal = 1.6849094183127797
disnt_after = 1.7578012912482064
val_r2_mean = 0.9558479984601339
val_r2_var = -2.1233274936676025



Epoch 00084: cpa_metric reached. Module best state updated.

Epoch 00089: cpa_metric reached. Module best state updated.



disnt_basal = 1.6888637273091054
disnt_after = 1.7579626972740314
val_r2_mean = 0.9589632749557495
val_r2_var = -2.056976079940796



Epoch 00094: cpa_metric reached. Module best state updated.

Epoch 00099: cpa_metric reached. Module best state updated.



disnt_basal = 1.5968105534896888
disnt_after = 1.7574784791965565
val_r2_mean = 0.9458276033401489
val_r2_var = -1.9597012996673584



Epoch 00104: cpa_metric reached. Module best state updated.



disnt_basal = 1.6654700239297493
disnt_after = 1.760060975609756
val_r2_mean = 0.9283552964528402
val_r2_var = -1.8560153643290203

disnt_basal = 1.6759350081496014
disnt_after = 1.7624820659971305
val_r2_mean = 0.9441032409667969
val_r2_var = -1.7351877689361572

disnt_basal = 1.6729043613010615
disnt_after = 1.7619978479196554
val_r2_mean = 0.9528518915176392
val_r2_var = -1.670147180557251


100%|██████████| 1/1 [00:00<00:00, 22.17it/s]
cpa_single_pert:   2%|▏         | 49/2184 [13:37<10:52:15, 18.33s/it]



100%|██████████| 1256/1256 [00:00<00:00, 82096.43it/s]
100%|██████████| 1256/1256 [00:00<00:00, 1062748.80it/s]

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        



Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 770.59it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

Epoch 00004: cpa_metric reached. Module best state updated.

Epoch 00009: cpa_metric reached. Module best state updated.



disnt_basal = 1.6174863387978142
disnt_after = 1.6174863387978142
val_r2_mean = 0.3224262197812398
val_r2_var = -2.4001025358835855



Epoch 00014: cpa_metric reached. Module best state updated.

Epoch 00019: cpa_metric reached. Module best state updated.



disnt_basal = 1.6174863387978142
disnt_after = 1.6174863387978142
val_r2_mean = 0.7749389012654623
val_r2_var = -2.400305986404419



Epoch 00024: cpa_metric reached. Module best state updated.

Epoch 00029: cpa_metric reached. Module best state updated.



disnt_basal = 1.6174863387978142
disnt_after = 1.5806719287593607
val_r2_mean = 0.9061456521352133
val_r2_var = -2.384357213973999



Epoch 00039: cpa_metric reached. Module best state updated.



disnt_basal = 1.6174863387978142
disnt_after = 1.5555555555555556
val_r2_mean = 0.924484392007192
val_r2_var = -2.3337789376576743


100%|██████████| 1/1 [00:00<00:00, 35.02it/s]
cpa_single_pert:   2%|▏         | 49/2184 [13:41<9:56:39, 16.77s/it] 

KeyboardInterrupt



# 尝试看看cell_line_bulk的并行能否成功

In [52]:
import os
import numpy as np
import pandas as pd
import scanpy as sc
from tqdm import tqdm
from sklearn.metrics import precision_recall_curve, auc
from scipy.spatial.distance import cdist
import concurrent.futures
import json
import anndata as ad


Process SpawnProcess-436:
Traceback (most recent call last):
  File "/data1/lichen/anaconda3/envs/cpa/lib/python3.9/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/data1/lichen/anaconda3/envs/cpa/lib/python3.9/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/data1/lichen/anaconda3/envs/cpa/lib/python3.9/concurrent/futures/process.py", line 240, in _process_worker
    call_item = call_queue.get(block=True)
  File "/data1/lichen/anaconda3/envs/cpa/lib/python3.9/multiprocessing/queues.py", line 122, in get
    return _ForkingPickler.loads(res)
AttributeError: Can't get attribute 'process_cell_line' on <module '__main__' (built-in)>
Process SpawnProcess-437:
Traceback (most recent call last):
  File "/data1/lichen/anaconda3/envs/cpa/lib/python3.9/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/data1/lichen/anaconda3/envs/cpa/lib/python3.9/multiprocessing/process.py", line 108, in r

Error: A process in the process pool was terminated abruptly while the future was running or pending.
Error: A process in the process pool was terminated abruptly while the future was running or pending.
Error: A process in the process pool was terminated abruptly while the future was running or pending.
Error: A process in the process pool was terminated abruptly while the future was running or pending.
Error: A process in the process pool was terminated abruptly while the future was running or pending.
Error: A process in the process pool was terminated abruptly while the future was running or pending.
Error: A process in the process pool was terminated abruptly while the future was running or pending.


In [57]:


    
    


# 定义处理每个 cell_line_single 的函数
def process_cell_line(cell_line_bulk, cell_line_single, common_cell_line, adata_L1000, gpu_id):
    print('=' * 20, f'cell line is {cell_line_single}')
    
    torch.cuda.set_device(gpu_id)  # 设置每个进程使用不同的 GPU

    #===================prepare data
    if cell_line_bulk in ['PC3', 'A375']:
        save_dir_adata = '/nfs/public/lichen/results/single_cell_perturbation/perturbation_benchmark/benchmark_data/L1000/single_cell_data/SCP542'
        adata_rna = sc.read(os.path.join(save_dir_adata, cell_line_bulk, f'adata_{cell_line_bulk}.h5ad'))
        
        # - read adata_rna_raw
        save_dir = f'/nfs/public/lichen/data/single_cell/cell_line/SCP542/process/{cell_line_bulk}'
        adata_rna_raw = sc.read(os.path.join(save_dir, f'adata.h5ad'))
    
    else:
        save_dir_adata = '/nfs/public/lichen/results/single_cell_perturbation/perturbation_benchmark/benchmark_data/L1000/single_cell_data/CNP0003658'
        adata_rna = sc.read(os.path.join(save_dir_adata, cell_line_bulk, f'adata_{cell_line_bulk}.h5ad'))

        # - read adata_rna
        save_dir = f'/nfs/public/lichen/data/single_cell/cell_line/CNP0003658/process/RNA/{cell_line_single}'
        adata_rna_raw = sc.read(os.path.join(save_dir, f'adata_rna_{cell_line_single}.h5ad'))

    # - consctrut corr mtx
    if not isinstance(adata_rna.X, np.ndarray):
        adata_rna.X = adata_rna.X.toarray()
    corr_mtx = np.corrcoef(adata_rna.X.T)
    
    # - get var_names
    var_names = list(adata_rna.var_names)
    
    # - get common pert
    adata_L1000_sub = adata_L1000[adata_L1000.obs['cell_id']==cell_line_bulk]
    L1000_total_perts = np.unique(adata_L1000_sub.obs['pert_iname'])
    common_perts = np.intersect1d(adata_rna.var_names, L1000_total_perts)

    
    
    print('common_perts num: ', len(common_perts))
    print('common var to L1000 data is: ', len(np.intersect1d(var_names, adata_L1000.var_names)))

    
    adata_pert_list = []
    pert_gene_rank_dict = {}
    for pert in tqdm(common_perts, desc='cpa_single_pert'):
        pert_gene_rank_dict[pert] = cpa_single_pert(pert)
        
        
    save_dir = '/nfs/public/lichen/results/single_cell_perturbation/perturbation_benchmark_202410/zero_shot/result'
    save_prefix = f'CPA/{cell_line_bulk}' # use result of K562 to do the direct transfer
    os.makedirs(os.path.join(save_dir, save_prefix), exist_ok=True)

    import json
    with open(os.path.join(save_dir, save_prefix, 'pert_gene_rank_dict.json'), 'w') as f:
        json.dump(pert_gene_rank_dict, f)


In [70]:
import cpa

In [68]:
sys.path.append("/data1/lichen/code/single_cell_perturbation/scPerturb_202410/zero_shot/method/CPA/")



In [83]:
import importlib
importlib.reload(cpa_func_lc)

<module 'cpa_func_lc' from '/data1/lichen/code/single_cell_perturbation/scPerturb_202410/zero_shot/method/CPA/cpa_func_lc.py'>

In [84]:
import cpa_func_lc
from cpa_func_lc import process_cell_line

In [90]:

import multiprocessing
from multiprocessing import Manager

# 主函数
if __name__ == "__main__":
    # - get cell line name
    common_cell_line = \
    {   'A549': 'A549',
        'HEPG2': 'HepG2',
        'HT29': 'HT29',
        'MCF7': 'MCF7',
        # 'SKBR3': 'SK-BR-3',
        'SW480': 'SW480',
        'PC3': 'PC3',
        'A375': 'A375',
    } # L1000 cell line : single-cell cell line
    
    multiprocessing.set_start_method('spawn', force=True)

    # - read adata_L1000, this is processed data
    adata_L1000 = sc.read('/nfs/public/lichen/results/single_cell_perturbation/perturbation_benchmark/benchmark_data/L1000/GSE92742/adata_gene_pert.h5ad')

    

    # 使用并行执行
    with concurrent.futures.ProcessPoolExecutor() as executor:
        futures = [
            executor.submit(process_cell_line, adata_ori, cell_line_bulk, common_cell_line[cell_line_bulk], adata_L1000)
            for cell_line_bulk in common_cell_line.keys()
        ]
        
        # 等待所有任务完成
        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()  # 获取每个任务的结果，如果有异常，将在此处抛出
            except Exception as e:
                print(f"Error: {e}")


Global seed set to 0


Error: name 'pert_data' is not defined
common_perts num:  3096
common var to L1000 data is:  933
AARS


cpa_single_pert:   0%|          | 0/3096 [00:00<?, ?it/s]


In [50]:
import multiprocessing
multiprocessing.set_start_method('spawn', force=True)

In [None]:
multiprocessing.get_start_method()

'spawn'

# cell_line_bulk并行-torch multiprocessing

In [56]:
import torch.multiprocessing as mp

# 主函数
if __name__ == "__main__":
    # - get cell line name
    common_cell_line = {
        'A549': 'A549',
        'HEPG2': 'HepG2',
        'HT29': 'HT29',
        'MCF7': 'MCF7',
        'SW480': 'SW480',
        'PC3': 'PC3',
        'A375': 'A375',
    } # L1000 cell line : single-cell cell line

    # 设置多进程启动方式为 'spawn'
    mp.set_start_method('spawn', force=True)

    # - read adata_L1000, this is processed data
    adata_L1000 = sc.read('/nfs/public/lichen/results/single_cell_perturbation/perturbation_benchmark/benchmark_data/L1000/GSE92742/adata_gene_pert.h5ad')

    # 获取 GPU 数量，并为每个进程分配一个 GPU
    num_gpus = torch.cuda.device_count()

    # 使用 torch.multiprocessing 进行并行处理
    processes = []
    for i, cell_line_bulk in enumerate(common_cell_line.keys()):
        gpu_id = i % num_gpus  # 为每个进程分配不同的 GPU
        cell_line_single = common_cell_line[cell_line_bulk]
        p = mp.Process(target=process_cell_line, args=(cell_line_bulk, cell_line_single, common_cell_line, adata_L1000, gpu_id))
        p.start()
        processes.append(p)

    # 等待所有进程完成
    for p in processes:
        p.join()

    print("All processes completed.")

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/data1/lichen/anaconda3/envs/cpa/lib/python3.9/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/data1/lichen/anaconda3/envs/cpa/lib/python3.9/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'process_cell_line' on <module '__main__' (built-in)>


KeyboardInterrupt: 

In [61]:
import torch.multiprocessing as mp
import scanpy as sc
import dill as pickle  # 使用 dill 代替 pickle


# 必须使用 dill 来序列化多进程函数
if __name__ == "__main__":
    # mp.set_start_method('spawn', force=True)
    # mp.set_executable(mp.get_executable())

    # - get cell line name
    common_cell_line = {
        'A549': 'A549',
        'HEPG2': 'HepG2',
        'HT29': 'HT29',
        'MCF7': 'MCF7',
        'SW480': 'SW480',
        'PC3': 'PC3',
        'A375': 'A375',
    }

    # - read adata_L1000, this is processed data
    adata_L1000 = sc.read('/nfs/public/lichen/results/single_cell_perturbation/perturbation_benchmark/benchmark_data/L1000/GSE92742/adata_gene_pert.h5ad')

    # 获取 GPU 数量，并为每个进程分配一个 GPU
    num_gpus = torch.cuda.device_count()

    # 使用 torch.multiprocessing 进行并行处理
    processes = []
    for i, cell_line_bulk in enumerate(common_cell_line.keys()):
        gpu_id = i % num_gpus  # 为每个进程分配不同的 GPU
        cell_line_single = common_cell_line[cell_line_bulk]
        p = mp.Process(target=process_cell_line, args=(cell_line_bulk, cell_line_single, common_cell_line, adata_L1000, gpu_id))
        p.start()
        processes.append(p)

    # 等待所有进程完成
    for p in processes:
        p.join()

    print("All processes completed.")


Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/data1/lichen/anaconda3/envs/cpa/lib/python3.9/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/data1/lichen/anaconda3/envs/cpa/lib/python3.9/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'process_cell_line' on <module '__main__' (built-in)>


KeyboardInterrupt: 

# gpt改写并行版本

In [None]:
import concurrent.futures
import pandas as pd
import numpy as np
from anndata import AnnData
import scanpy as sc

# 改进为返回值形式以便收集结果
def cpa_single_pert(pert, gpu_id = 0):
    
    # 设置 GPU
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    print(f"Starting training for {pert} on GPU {gpu_id}")
    
    print('='*20, pert)
    global pert_data, adata_ori, adata_rna, common_var, cell_line_bulk, model_params, trainer_params

    # - get adata_pert and adata_ctrl
    adata_pert = pert_data.adata_split[pert_data.adata_split[pert_data.adata_split.obs['perturbation_group']==pert+' | K562'].obs_names]
    adata_ctrl = pert_data.adata_split[adata_pert.obs['control_barcode']]

    adata_pert = adata_ori[adata_pert.obs_names, common_var]
    adata_ctrl = adata_ori[adata_ctrl.obs_names, common_var]

    # - get adata_rna_common
    adata_rna_common = adata_rna[:, common_var]

    # - generate adata_train to input to scGen model
    np_list, obs_list, pert_list, celltype_list = [], [], [], []
    adata_list = [adata_pert, adata_rna_common, adata_ctrl, adata_rna_common]
    for j, adata_ in enumerate(adata_list):
        if j in [0, 1]:
            pert_list.extend(['stimulated']*len(adata_))
        else:
            pert_list.extend(['ctrl']*len(adata_))

        if j in [0, 2]:
            celltype_list.extend(['K562']*len(adata_))
        else:
            celltype_list.extend([cell_line_bulk]*len(adata_))
        obs_list.extend([obs+f'_{j}' for obs in adata_.obs_names])
        np_list.append(adata_.X.toarray())

    adata_train = AnnData(X=np.vstack(np_list))
    adata_train.obs_names = obs_list
    adata_train.var_names = adata_pert.var_names

    adata_train.obs['condition'] = pert_list
    adata_train.obs['cell_type'] = celltype_list

    # - transform the adata_train.X to count
    adata_train.obs['cov_cond'] = adata_train.obs['cell_type'] + '_' + adata_train.obs['condition']
    adata_train.X = np.exp(adata_train.X) - 1

    # - initialize model
    cpa.CPA.setup_anndata(adata_train, 
                          perturbation_key='condition',
                          control_group='ctrl',
                          categorical_covariate_keys=['cell_type'],
                          is_count_data=True,
                          deg_uns_cat_key='cov_cond',
                          max_comb_len=1)
    
    print('adata_train.shape: ', adata_train.shape)

    adata_train_new = adata_train[~((adata_train.obs["cell_type"] == cell_line_bulk) &
                                    (adata_train.obs["condition"] == "stimulated"))].copy()

    obs_df_sub_idx = np.array(adata_train_new.obs.index)
    np.random.seed(2024)
    np.random.shuffle(obs_df_sub_idx)

    split_point_1 = int(len(obs_df_sub_idx) * 0.9)
    split_point_2 = int(len(obs_df_sub_idx) * (0.9 + 0.1))
    train = obs_df_sub_idx[:split_point_1]
    valid = obs_df_sub_idx[split_point_1:split_point_2]

    adata_train.obs['split_key'] = 'ood'
    adata_train.obs.loc[train, 'split_key'] = 'train'
    adata_train.obs.loc[valid, 'split_key'] = 'valid'
    
    print('start training......')

    # - train the model
    model = cpa.CPA(adata=adata_train, 
                    split_key='split_key',
                    train_split='train',
                    valid_split='valid',
                    test_split='ood',
                    **model_params)

    model.train(max_epochs=2000,
                use_gpu=True, 
                batch_size=500,
                plan_kwargs=trainer_params,
                early_stopping_patience=5,
                check_val_every_n_epoch=5,
                progress_bar_refresh_rate=0)
    
    print('finish training')

    model.predict(adata_train, batch_size=2048)

    # - get the predicted data
    cat = cell_line_bulk + '_' + 'stimulated'
    cat_adata = adata_train[adata_train.obs['cov_cond'] == cat].copy()
    x_pred = cat_adata.obsm['CPA_pred']
    x_pred = np.log1p(x_pred)

    adata_ctrl = adata_rna.copy()
    adata_pert = adata_ctrl.copy()
    adata_pert.X = x_pred
    adata_pert.obs_names = [i + f'_{pert}' for i in adata_pert.obs_names]

    adata_ctrl.obs['batch'] = 'ctrl'
    adata_pert.obs['batch'] = 'pert'

    adata_concat = sc.concat([adata_ctrl, adata_pert])

    # - calculate DE genes
    sc.tl.rank_genes_groups(
        adata_concat,
        groupby='batch',
        reference='ctrl',
        rankby_abs=False,
        n_genes=len(adata_concat.var),
        use_raw=False,
        method='wilcoxon'
    )
    
    de_genes = pd.DataFrame(adata_concat.uns['rank_genes_groups']['names'])
    scores = pd.DataFrame(adata_concat.uns['rank_genes_groups']['scores'])

    # - return gene score for the current perturbation
    return pert, (list(de_genes['pert']), list(scores['pert']))



In [None]:
# - concurrent使用gpu

if __name__ == "__main__":
    gpu_ids = [0, 1, 2, 3]  # 假设有 4 张 GPU
    
    with concurrent.futures.ProcessPoolExecutor(max_workers=len(gpu_ids)) as executor:
        futures = {executor.submit(cpa_single_pert, pert, gpu_ids[i % len(gpu_ids)]): pert for i, pert in enumerate(common_perts[0:4])}
        
        for future in concurrent.futures.as_completed(futures):
            pert, result = future.result()
            print(f"Task for {pert} completed with result: {result}")

Starting training for ABCB7 on GPU 2Starting training for ABCB6 on GPU 1Starting training for AATF on GPU 0
Starting training for ABCC1 on GPU 3

AATFABCC1





  0%|          | 0/1488 [00:00<?, ?it/s]



100%|██████████| 1364/1364 [00:00<00:00, 83104.99it/s]
100%|██████████| 1488/1488 [00:00<00:00, 82249.93it/s]
100%|██████████| 1364/1364 [00:00<00:00, 647908.34it/s]
100%|██████████| 1470/1470 [00:00<00:00, 75589.72it/s]
100%|██████████| 1488/1488 [00:00<00:00, 770604.32it/s]
100%|██████████| 1402/1402 [00:00<00:00, 83408.95it/s]
100%|██████████| 1470/1470 [00:00<00:00, 591938.06it/s]
100%|██████████| 1402/1402 [00:00<00:00, 1088763.97it/s]
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        
adata_train.shape:  (1470, 2764)adata_train.shape: 
 (1488, 2764)adata_train.shape: 
 (1402, 2764)
start training......
start training......start training......



Global seed set to 6977


[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        


Global seed set to 6977
Global seed set to 6977


adata_train.shape:  (1364, 2764)
start training......


Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 770.80it/s]
100%|██████████| 2/2 [00:00<00:00, 628.74it/s]
100%|██████████| 2/2 [00:00<00:00, 967.21it/s]
100%|██████████| 2/2 [00:00<00:00, 934.77it/s]


RuntimeError: CUDA error: initialization error
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
import torch.multiprocessing as mp
import os

if __name__ == "__main__":
    gpu_ids = [0, 1, 2, 3]  # 假设有 4 张 GPU
    
    # 使用 torch.multiprocessing
    with mp.Pool(processes=len(gpu_ids)) as pool:
        results = [pool.apply_async(cpa_single_pert, args=(pert, gpu_ids[i % len(gpu_ids)])) for i, pert in enumerate(common_perts[0:4])]
        pool.close()
        pool.join()

        for result in results:
            print(result.get())

Starting training for ABCB6 on GPU 1

  Starting training for ABCC1 on GPU 3AATFABCB7






  0%|          | 0/1402 [00:00<?, ?it/s]



100%|██████████| 1402/1402 [00:00<00:00, 81915.89it/s]
100%|██████████| 1488/1488 [00:00<00:00, 78793.12it/s]
100%|██████████| 1402/1402 [00:00<00:00, 678483.24it/s]
100%|██████████| 1470/1470 [00:00<00:00, 68405.88it/s]
100%|██████████| 1364/1364 [00:00<00:00, 74246.07it/s]
100%|██████████| 1470/1470 [00:00<00:00, 405579.98it/s]
100%|██████████| 1488/1488 [00:00<00:00, 205964.11it/s]
100%|██████████| 1364/1364 [00:00<00:00, 706823.65it/s]
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        
adata_train.shape: adata_train.shape: adata_train.shape:    (1364, 2764)(1470, 2764)(1488, 2764)


[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generati

Global seed set to 6977
Global seed set to 6977
Global seed set to 6977


start training......


Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 835.27it/s]
100%|██████████| 2/2 [00:00<00:00, 598.03it/s]
100%|██████████| 2/2 [00:00<00:00, 495.17it/s]
100%|██████████| 2/2 [00:00<00:00, 1058.63it/s]


RuntimeError: CUDA error: initialization error
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:

# 使用并行处理
pert_gene_rank_dict = {}

# 并行执行每个 perturbation
with concurrent.futures.ProcessPoolExecutor() as executor:
    futures = {executor.submit(cpa_single_pert, pert): pert for pert in common_perts[0:5]}

    # 使用 tqdm 显示进度条
    for future in tqdm(concurrent.futures.as_completed(futures), total=len(common_perts[0:5])):
        pert, gene_scores = future.result()
        pert_gene_rank_dict[pert] = gene_scores


In [None]:
import multiprocessing

def collect_result(result):
    pert, data = result
    pert_gene_rank_dict[pert] = data
    
if __name__ == "__main__":
    manager = multiprocessing.Manager()
    pert_gene_rank_dict = manager.dict()  # 用 Manager 使其在进程之间共享
    
    pool = multiprocessing.Pool(processes=5)
    # pool = multiprocessing.cpu_count()
    # common_perts = [...]  # 你的 perturbations 列表
    
    # 使用 pool 并行执行
    for pert in common_perts[0:5]:
        pool.apply_async(cpa_single_pert, args=(pert,), callback=collect_result)

    pool.close()
    pool.join()  # 等待所有任务完成


ABCC10



100%|██████████| 1262/1262 [00:00<00:00, 79446.64it/s]
100%|██████████| 1364/1364 [00:00<00:00, 79390.40it/s]
100%|██████████| 1262/1262 [00:00<00:00, 506527.43it/s]
100%|██████████| 1402/1402 [00:00<00:00, 71205.25it/s]
100%|██████████| 1470/1470 [00:00<00:00, 76336.55it/s]
100%|██████████| 1364/1364 [00:00<00:00, 365071.19it/s]
100%|██████████| 1488/1488 [00:00<00:00, 69090.19it/s]
100%|██████████| 1402/1402 [00:00<00:00, 373834.34it/s]
100%|██████████| 1470/1470 [00:00<00:00, 467943.75it/s]
100%|██████████| 1488/1488 [00:00<00:00, 323257.08it/s]
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
A

[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        
adata_train.shape:  [34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                                                        
(1262, 2764)
[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Generating sequential column names                                        

Global seed set to 6977


start training......start training......
start training......start training......




Global seed set to 6977
Global seed set to 6977
Global seed set to 6977
Global seed set to 6977
100%|██████████| 2/2 [00:00<00:00, 842.91it/s]
100%|██████████| 2/2 [00:00<00:00, 1026.76it/s]
100%|██████████| 2/2 [00:00<00:00, 797.40it/s]
100%|██████████| 2/2 [00:00<00:00, 783.10it/s]
100%|██████████| 2/2 [00:00<00:00, 570.19it/s]


In [None]:
print(pert_gene_rank_dict)

{}
