In [None]:
#| default_exp experiment

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *

In [None]:
# | export
from relax.import_essentials import *
from relax.data import TabularDataModule, DataLoader
from relax.trainer import train_model_with_states, TrainingConfigs
from relax.evaluate import (
    Explanation,
    accuracy,
    evaluate_cfs,
    benchmark_cfs,
    generate_cf_explanations,
    _AuxPredFn,
    BaseEvalMetrics,
    Validity,
    Proximity,
    PredictiveAccuracy
)
from relax.module import BaseTrainingModule
from relax.methods.base import BaseCFModule, BaseParametricCFModule, BasePredFnCFModule
from relax.utils import validate_configs, proximity

from rocourse_net.module import RoCourseNetTrainingModule
from copy import deepcopy
from functools import partial
from sklearn.preprocessing import StandardScaler, MinMaxScaler, OneHotEncoder
import wandb
from pydantic import validator


### Evaluation Metrics

In [None]:
#| export
def compute_rob_validity(cf_results: Explanation, shifted_pred_fn: Callable):
    """Validity of an explanation's counterfactuals under another predictor.

    The desired label of each counterfactual is the flipped (``1 - y``) rounded
    prediction of ``cf_results.pred_fn`` on ``cf_results.X``. The score is the
    `accuracy` between those desired labels and the rounded predictions of
    ``shifted_pred_fn`` on ``cf_results.cfs``.

    Args:
        cf_results: Explanation providing `pred_fn`, `X` and `cfs`.
        shifted_pred_fn: Predictor used to score the counterfactuals.

    Returns:
        The accuracy as a Python scalar (via ``.item()``).
    """
    def _hard_labels(fn, inputs):
        # Column vector of rounded predictions.
        return fn(inputs).reshape(-1, 1).round()

    desired = 1. - _hard_labels(cf_results.pred_fn, cf_results.X)
    achieved = _hard_labels(shifted_pred_fn, cf_results.cfs)
    return accuracy(desired, achieved).item()

### Util Methods 

In [None]:
#| export
def _aggregate_default_data_encoders(default_data_config: Dict[str, Any], data_dir_list: List[str]):
    # data encoding
    data = pd.concat(
        [pd.read_csv(data_dir) for data_dir in data_dir_list]
    )
    print(f"total data length: {len(data)}")
    if len(default_data_config['continous_cols']) != 0:
        print("preprocessing continuous features...")
        normalizer = MinMaxScaler().fit(
            data[default_data_config['continous_cols']]
        )
        default_data_config.update({"normalizer": normalizer})

    if len(default_data_config['discret_cols']) != 0:
        print("preprocessing discret features...")
        encoder = OneHotEncoder(sparse=False).fit(
            data[default_data_config['discret_cols']]
        )
        default_data_config.update({"encoder": encoder})
    return default_data_config

In [None]:
#| export
def calculate_validity_matrix(
    cf_results_list: Iterable[Explanation]
) -> pd.DataFrame:
    """Cross-evaluate every explanation's counterfactuals with every predictor.

    For each pair ``(i, j)`` the counterfactuals of explanation ``i`` are
    scored with `compute_rob_validity` using the `pred_fn` of explanation ``j``.

    Args:
        cf_results_list: Explanations to compare. Any iterable is accepted,
            including one-shot generators.

    Returns:
        A DataFrame whose columns are the `dataset_name` of explanation ``i``
        and whose index is the `cf_name` of explanation ``j``.
    """
    # Materialise once: the nested iteration below would exhaust a generator
    # and silently produce a truncated matrix.
    cf_results = list(cf_results_list)

    validity_matrix_dict = {}
    for cf_results_i in cf_results:
        validity_matrix_dict[cf_results_i.dataset_name] = {
            cf_results_j.cf_name: compute_rob_validity(
                cf_results_i, cf_results_j.pred_fn
            )
            for cf_results_j in cf_results
        }
    return pd.DataFrame.from_dict(validity_matrix_dict)

### Training Models

In [None]:
#| export
def get_datamodules(
    default_data_configs: Dict[str, Any],
    data_dir_list: List[str]
):
    """Create one `TabularDataModule` per data path.

    Each module gets its own deep copy of `default_data_configs` with the
    ``'data_dir'`` entry replaced by the corresponding path; the template
    itself is left untouched.

    Returns:
        The data modules, in the order of `data_dir_list`.
    """
    def _config_for(data_dir):
        config = deepcopy(default_data_configs)
        config['data_dir'] = data_dir
        return config

    return [TabularDataModule(_config_for(data_dir)) for data_dir in data_dir_list]

In [None]:
#| export
class FasterTabularDataModule(TabularDataModule):
    """`TabularDataModule` whose loaders all share one fixed configuration.

    Every loader is built with the configured backend, ``shuffle=True``,
    ``num_workers=0`` and ``drop_last=True``.

    NOTE(review): `test_dataloader` serves ``self.val_dataset`` rather than a
    dedicated test split -- confirm this is intended.
    """

    def _fast_loader(self, dataset, batch_size):
        # Single place that fixes the loader options shared by all splits.
        return DataLoader(
            dataset,
            self._configs.backend,
            batch_size=batch_size,
            shuffle=True,
            num_workers=0,
            drop_last=True,
        )

    def train_dataloader(self, batch_size):
        return self._fast_loader(self.train_dataset, batch_size)

    def val_dataloader(self, batch_size):
        return self._fast_loader(self.val_dataset, batch_size)

    def test_dataloader(self, batch_size):
        return self._fast_loader(self.val_dataset, batch_size)

In [None]:
#| export
def train_models(
    training_module: BaseTrainingModule,
    default_data_configs: Dict[str, Any],
    data_dir_list: List[str],
    t_configs: Dict[str, Any],
    return_data_module_list: bool = False,
    use_fast: bool = True
):
    """Train `training_module` from scratch once per dataset.

    For each path a deep copy of `default_data_configs` is pointed at that
    path, a data module is built, parameters and optimizer state are
    initialised with ``random.PRNGKey(42)`` and then trained with
    `train_model_with_states`.

    Args:
        training_module: Module providing `init_net_opt`.
        default_data_configs: Data config template (not modified).
        data_dir_list: One data path per model to train.
        t_configs: Training configs forwarded to `train_model_with_states`.
        return_data_module_list: Also return the data modules when true.
        use_fast: Use `FasterTabularDataModule` instead of `TabularDataModule`.

    Returns:
        A list of ``(params, opt_state)`` tuples, or that list together with
        the list of data modules when `return_data_module_list` is true.
    """
    trained_states, data_modules = [], []

    for data_dir in data_dir_list:
        data_config = deepcopy(default_data_configs)
        data_config['data_dir'] = data_dir
        dm = (FasterTabularDataModule if use_fast else TabularDataModule)(data_config)

        # Same seed for every dataset, so all models start from equal init keys.
        params, opt_state = training_module.init_net_opt(dm, random.PRNGKey(42))
        params, opt_state = train_model_with_states(
            training_module, params, opt_state, dm, t_configs
        )

        trained_states.append((params, opt_state))
        data_modules.append(dm)

    if not return_data_module_list:
        return trained_states
    return trained_states, data_modules

### Log to Wandb

In [None]:
#| export
def wandb_vis_table(table: pd.DataFrame):
    """Wrap a DataFrame in a `wandb.Table`."""
    return wandb.Table(dataframe=table)

def wandb_vis_heatmap(heatmap: pd.DataFrame):
    """Build a `wandb.plots.HeatMap` from a DataFrame.

    Columns become the x labels, the index the y labels and the underlying
    values the cell values; cell text is disabled.
    """
    return wandb.plots.HeatMap(
        x_labels=heatmap.columns,
        y_labels=heatmap.index,
        matrix_values=heatmap.values,
        show_text=False,
    )

# Registry mapping a result's `data_type` to the function that visualises it.
DATA_TYPE_TO_VIS_FN = dict(
    table=wandb_vis_table,
    heatmap=wandb_vis_heatmap,
)

In [None]:
#| exporti
class ExperimentResult(BaseParser):
    """A named experiment artifact plus the kind of visualisation it needs.

    `data_type` must be a key of `DATA_TYPE_TO_VIS_FN`; `data` is handed to
    the matching visualisation function by `wandb_vis`.
    """
    name: str
    data_type: str
    data: Any

    @validator('data_type')
    def validate_data_type(cls, v):
        # Only kinds with a registered visualiser are accepted.
        if v in DATA_TYPE_TO_VIS_FN.keys():
            return v
        raise ValueError(f"`data_type` should be one of {DATA_TYPE_TO_VIS_FN.keys()}, but got {v}")

    def wandb_vis(self):
        """Return `data` converted by the visualiser registered for `data_type`."""
        vis_fn = DATA_TYPE_TO_VIS_FN[self.data_type]
        return vis_fn(self.data)

In [None]:
#| export
class ExperimentLogger(ABC):
    """Abstract interface for objects that persist experiment results."""

    @abstractmethod
    def store_results(self, results: List[ExperimentResult]):
        """Persist `results`; concrete loggers must override this method."""
        raise NotImplementedError

In [None]:
#| export
class ExperimentLoggerWanbConfigs(BaseParser):
    """Settings consumed by `ExperimentLoggerWanb` when it calls `wandb.init`."""
    project_name: str                           # passed as `project`
    user_name: str                              # passed as `entity`
    experiment_name: str                        # passed as `name`
    hparams: Optional[Dict[str, Any]] = None    # hyperparameters, passed as `config`


class ExperimentLoggerWanb(ExperimentLogger):
    """`ExperimentLogger` that logs results to a wandb run.

    The run is created in the constructor from the given configs, using the
    ``"fork"`` start method.
    """

    def __init__(self, configs: ExperimentLoggerWanbConfigs):
        super().__init__()
        init_kwargs = dict(
            project=configs.project_name,
            entity=configs.user_name,
            name=configs.experiment_name,
            config=configs.hparams,
            settings=wandb.Settings(start_method="fork"),
        )
        self.run = wandb.init(**init_kwargs)

    def store_results(self, results: List[ExperimentResult]):
        """Log every result under its name and return the run directory.

        NOTE(review): the run is used as a context manager here, which
        presumably closes it on exit -- confirm before calling this twice on
        the same logger.
        """
        with self.run as run:
            payload = {}
            for r in results:
                payload[r.name] = r.wandb_vis()
            run.log(payload)
        return self.run.dir

### Metrics

In [None]:
#| export
class NormalizedProximity(BaseEvalMetrics):
    """Normalized proximity of counterfactuals to the original instance."""

    def __str__(self):
        return "NormalizedProximity"

    def __call__(self, cf_explanations: Explanation) -> float:
        """Return `proximity` of the test inputs to the counterfactuals, divided by the feature count.

        The inputs are read from ``cf_explanations.data_module.test_dataset``
        and the divisor is their second dimension (``X.shape[1]``).
        """
        X, _ = cf_explanations.data_module.test_dataset[:]
        return (proximity(X, cf_explanations.cfs) / X.shape[1]).item()

### Experiment

In [None]:
#| exporti
def calculate_validity_changes(val_matrix_df: pd.DataFrame):
    """Summarise a square validity matrix into mean/std statistics.

    Reports, as mean and standard deviation:
      * ``cf_validity``: the diagonal entries ``matrix[i][i]``;
      * ``cf_validity (w=1)`` / ``validity_decrease (w=1)``: the entries
        ``matrix[i][i+1]`` and their drop from ``matrix[i][i]``;
      * ``cf_validity (all)`` / ``validity_decrease (all)``: every
        off-diagonal entry ``matrix[i][j]`` and its drop from ``matrix[i][i]``.

    NOTE(review): entries are read row-wise (``matrix[i][j]``). In the frame
    built by `calculate_validity_matrix` rows are `cf_name` and columns are
    `dataset_name` -- confirm this is the intended shift direction.

    Args:
        val_matrix_df: Validity matrix with as many rows as columns.

    Returns:
        DataFrame with one column per statistic and rows ``mean`` / ``std``.
    """
    assert len(val_matrix_df.columns) == len(val_matrix_df.index), \
        f"val_matrix_df.columns={val_matrix_df.columns}, but val_matrix_df.index={val_matrix_df.index}"
    matrix = val_matrix_df.values
    n_datasets = len(matrix[0])

    def _summary(values):
        return {'mean': np.average(values), 'std': np.std(values)}

    steps = range(n_datasets - 1)
    off_diagonal = [
        (row, col)
        for row in range(n_datasets)
        for col in range(n_datasets)
        if col != row
    ]

    diagonal = [matrix[k][k] for k in range(n_datasets)]
    next_validity = [matrix[k][k + 1] for k in steps]
    next_decrease = [matrix[k][k] - matrix[k][k + 1] for k in steps]
    all_validity = [matrix[row][col] for row, col in off_diagonal]
    all_decrease = [matrix[row][row] - matrix[row][col] for row, col in off_diagonal]

    return pd.DataFrame.from_dict({
        'cf_validity': _summary(diagonal),
        'cf_validity (w=1)': _summary(next_validity),
        'cf_validity (all)': _summary(all_validity),
        'validity_decrease (w=1)': _summary(next_decrease),
        'validity_decrease (all)': _summary(all_decrease),
    })

In [None]:
# | exporti
def _evaluate_adversarial_model(
    cf_results_list: Iterable[Explanation],
    experiment_logger_configs: Optional[ExperimentLoggerWanbConfigs] = None,
):
    cf_results_df = benchmark_cfs(
        cf_results_list,
        # metrics=[
        #     PredictiveAccuracy(),
        #     Validity(),
        #     NormalizedProximity(),
        # ],
    )
    cf_aggre_df = (
        cf_results_df.describe()
        .loc[["mean", "std"]]
        .reset_index()
        .rename(columns={"index": "stat"})
    )

    print("calculating the validity matrix...")
    validity_matrix_df = calculate_validity_matrix(cf_results_list=cf_results_list)
    valditity_changes = calculate_validity_changes(validity_matrix_df)

    experiment_results = [
        ExperimentResult(name="CF Results", data_type="table", data=cf_results_df),
        ExperimentResult(name="CF Metrics", data_type="table", data=cf_aggre_df),
        ExperimentResult(name="Heatmap", data_type="heatmap", data=validity_matrix_df),
        ExperimentResult(name="Validity Matrix", data_type="table", data=validity_matrix_df),
        ExperimentResult(name="Validity Changes", data_type="table", data=valditity_changes),
    ]

    if experiment_logger_configs:
        logger = ExperimentLoggerWanb(experiment_logger_configs)
        dir_path = logger.store_results(experiment_results)
        print(f"Results stored at {dir_path}")
    return experiment_results


In [None]:
#| export
def adversarial_experiment(
    pred_training_module: BaseTrainingModule,
    cf_module: BaseCFModule,
    default_data_config: Dict[str, Any],
    data_dir_list: List[str],
    t_config: Dict[str, Any],
    use_prev_model_params: bool = False,
    return_best_model: bool = False, # return last model by default
    experiment_logger_configs: Optional[ExperimentLoggerWanbConfigs] = None,
    fast_dm: bool = False,
):
    """Train one model per dataset, generate counterfactuals and evaluate them across models.

    Args:
        pred_training_module: Predictive module trained once per dataset. May
            be `None` when `cf_module` is a `BaseParametricCFModule`, whose
            `module` is then the one being trained.
        cf_module: Counterfactual method handed to `generate_cf_explanations`.
        default_data_config: Data config template; encoders fitted on all
            datasets are added before training.
        data_dir_list: One CSV path per dataset.
        t_config: Training configs, forwarded to `train_models` and
            `generate_cf_explanations`.
        use_prev_model_params: Currently unused.
        return_best_model: Currently unused.
        experiment_logger_configs: When given, its `hparams` attribute is
            overwritten with the merged hyperparameters and the results are
            logged through `ExperimentLoggerWanb`.
        fast_dm: Forwarded as `use_fast` when training `pred_training_module`.

    Returns:
        ``(cf_results_list, experiment_results)``: the explanations (named
        ``model_{i}`` / ``data_{i}``) and the output of
        `_evaluate_adversarial_model`.

    Raises:
        ValueError: If there is nothing to train, i.e. `pred_training_module`
            is `None` and `cf_module` is not a `BaseParametricCFModule`.
    """
    is_parametric_cf = isinstance(cf_module, BaseParametricCFModule)
    # Fail fast: without either trainable module no parameters would exist
    # further down, which previously surfaced as an UnboundLocalError only
    # after all the data had been read.
    if pred_training_module is None and not is_parametric_cf:
        raise ValueError(
            "Nothing to train: pass a `pred_training_module` or a `cf_module` "
            "that is a `BaseParametricCFModule`."
        )

    # Hyperparameters are snapshotted before the fitted encoders are added.
    hparams = deepcopy(default_data_config)
    hparams.update(t_config)
    if pred_training_module:
        hparams.update(pred_training_module.hparams)
    if experiment_logger_configs:
        experiment_logger_configs.hparams = hparams

    # data encoding
    print("aggregating data...")
    default_data_config = _aggregate_default_data_encoders(default_data_config, data_dir_list)

    # training models
    # TODO: when both a `pred_training_module` and a parametric `cf_module` are
    # given, the second training run below overwrites the results of the first.
    if pred_training_module is not None:
        print("start training...")
        model_params_opt_list, data_module_list = train_models(
            pred_training_module, default_data_config,
            data_dir_list, t_config, return_data_module_list=True,
            use_fast=fast_dm
        )
    if is_parametric_cf:
        print("start training CF Module...")
        # NOTE(review): `fast_dm` is not forwarded here, so `train_models`
        # falls back to its own `use_fast` default -- confirm this is intended.
        model_params_opt_list, data_module_list = train_models(
            cf_module.module, default_data_config,
            data_dir_list, t_config, return_data_module_list=True
        )

    # evaluate cfs
    print("generating cfs...")
    # The predictor wrapper does not depend on the loop variables; the
    # per-model parameters are supplied through `pred_fn_args`.
    if pred_training_module:
        pred_fn = lambda x, params, rng: pred_training_module.forward(params, rng, x)
    else:
        pred_fn = None

    cf_results_list = []
    for i, ((params, _), dm) in enumerate(zip(model_params_opt_list, data_module_list)):
        # Private copy so each explanation keeps its own parameters.
        model_params = deepcopy(params)

        if is_parametric_cf:
            cf_module.params = model_params

        cf_exp = generate_cf_explanations(
            cf_module, dm,
            pred_fn=pred_fn,
            t_configs=t_config,
            pred_fn_args={'params': model_params, 'rng': random.PRNGKey(0)},
        )
        cf_exp.cf_name = f"model_{i}"
        cf_exp.dataset_name = f"data_{i}"
        if isinstance(cf_module, BasePredFnCFModule):
            # Replace the explanation's predictor with the CF module's own one,
            # bound to this model's parameters.
            cf_pred_fn = lambda x, params, rng: cf_module.module.predict(params, rng, x)
            cf_exp.pred_fn = _AuxPredFn(cf_pred_fn, {'params': model_params, 'rng': random.PRNGKey(0)})
        cf_results_list.append(cf_exp)

    experiment_results = _evaluate_adversarial_model(
        cf_results_list, experiment_logger_configs=experiment_logger_configs)
    return cf_results_list, experiment_results

In [None]:
# #| export
# def adversarial_experiment_cfnet(
#     training_module: CounterNetTrainingModule,
#     default_data_config: Dict[str, Any],
#     data_dir_list: List[str],
#     t_config: Dict[str, Any],
#     use_prev_model_params: bool = False,
#     return_best_model: bool = False, # return last model by default
#     experiment_logger_configs: Optional[ExperimentLoggerWanbConfigs] = None
# ):
#     return adversarial_experiment(
#         pred_training_module=training_module,
#         cf_module=training_module,
#         default_data_config=default_data_config,
#         data_dir_list=data_dir_list,
#         t_config=t_config,
#         is_local_cf_module=False,
#         use_prev_model_params=use_prev_model_params,
#         return_best_model=return_best_model,
#         experiment_logger_configs=experiment_logger_configs
#     )

# def adversarial_experiment_local_exp(
#     pred_training_module: CounterNetTrainingModule,
#     cf_moudle: LocalCFExplanationModule,
#     default_data_config: Dict[str, Any],
#     data_dir_list: List[str],
#     t_config: Dict[str, Any],
#     use_prev_model_params: bool = False,
#     return_best_model: bool = False, # return last model by default
#     experiment_logger_configs: Optional[ExperimentLoggerWanbConfigs] = None
# ):
#     return adversarial_experiment(
#         pred_training_module=pred_training_module,
#         cf_module=cf_moudle,
#         default_data_config=default_data_config,
#         data_dir_list=data_dir_list,
#         t_config=t_config,
#         is_local_cf_module=True,
#         use_prev_model_params=use_prev_model_params,
#         return_best_model=return_best_model,
#         experiment_logger_configs=experiment_logger_configs
#     )

### Test

#### Unit Tests

In [None]:
# Hand-built fixture: two explanations over 3-feature binary inputs. The first
# predictor reads the first feature, the second reads the last feature.
cf_results_list = [
    Explanation(
        cf_name='m1',
        dataset_name='d1',
        X=jnp.array([
            [1, 0, 1],
            [1, 1, 0],
            [0, 1, 0],
            [0, 1, 1],
        ]), # y_pred = [1, 1, 0, 0]
        y=jnp.ones((4,1)),
        cfs=jnp.array([
            [0, 0, 1],
            [0, 1, 0],
            [1, 1, 0],
            [1, 1, 1]
        ]),
        pred_fn=lambda x: x[:, 0], 
        total_time=0.1,
        data_module=None
    ),
    Explanation(
        cf_name='m2',
        dataset_name='d2',
        X=jnp.array([
            [1, 0, 1],
            [1, 1, 0],
            [0, 1, 1],
            [1, 0, 0]
        ]), # y_pred = [1, 0, 1, 0]
        y=jnp.ones((4,1)),
        cfs=jnp.array([
            [1, 0, 0],
            [1, 1, 1],
            [0, 1, 0],
            [1, 0, 1]
        ]),
        pred_fn=lambda x: x[:, -1],
        total_time=0.1,
        data_module=None
    ),
]



In [None]:
# Each explanation is fully valid under its own predictor; under the other
# explanation's predictor only part of the counterfactuals stay valid.
assert compute_rob_validity(cf_results_list[0], cf_results_list[0].pred_fn) == 1.0
assert compute_rob_validity(cf_results_list[0], cf_results_list[1].pred_fn) == 0.5
assert compute_rob_validity(cf_results_list[1], cf_results_list[0].pred_fn) == 0.75
assert compute_rob_validity(cf_results_list[1], cf_results_list[1].pred_fn) == 1.0

In [None]:
# Cross-evaluate the fixture: columns are dataset names, rows are predictor names.
calculate_validity_matrix(cf_results_list)

Unnamed: 0,d1,d2
m1,1.0,0.75
m2,0.5,1.0


In [None]:
# Model hyperparameters passed to `RoCourseNet` below.
m_configs = {
    "enc_sizes": [50,10],
    "dec_sizes": [10],
    "exp_sizes": [50, 50],
    "dropout_rate": 0.3,
    'lr': 0.003,
    "lambda_1": 1.0,
    "lambda_3": 0.1,
    "lambda_2": 0.2,
    'adv_lr': 0.03
}
# Training configs; kept short (2 epochs) so the notebook runs quickly.
t_configs = {
    'n_epochs': 2,
    'monitor_metrics': 'val/val_loss',
    'batch_size': 128,
}
# Data config template. `data_dir` is overridden per dataset by `train_models`;
# the `continous_cols` / `discret_cols` keys are read by
# `_aggregate_default_data_encoders`.
data_configs = {
    "data_dir": "../assets/data/loan/year=2008.csv",
    "data_name": "loan",
    'sample_frac': 0.1,
    'batch_size': 128,
    "continous_cols": [
        "NoEmp", "NewExist", "CreateJob", "RetainedJob", "DisbursementGross", "GrAppv", "SBA_Appv"
    ],
    "discret_cols": [
        "State", "Term", "UrbanRural", "LowDoc", "Sector_Points"
    ],
}


#### RoCourseNet

In [None]:
from relax.module import PredictiveTrainingModule
from rocourse_net.module import RoCourseNet

In [None]:
# wandb settings; only used if `experiment_logger_configs` is passed below
# (currently commented out, so nothing is logged).
experiment_logger_configs = ExperimentLoggerWanbConfigs(
    project_name='debug',
    user_name='birkhoffg',
    experiment_name='rocoursenet',
    
)
# No separate predictive module: only the CF module itself is trained,
# once per yearly loan dataset (2007-2009).
cf_results_list, experiment_results = adversarial_experiment(
    pred_training_module=None, #PredictiveTrainingModule({'lr': 0.003, 'sizes': [200, 10]}),
    cf_module=RoCourseNet(m_configs),
    default_data_config=data_configs,
    data_dir_list=[ 
        f"assets/data/loan/year={year}.csv" for year in range(2007, 2010) 
    ],
    # data_dir_list=[ 
    #     f"assets/data/loan/year={year}.csv" for year in range(1994, 2010) 
    # ],
    t_config=t_configs,
    # experiment_logger_configs=experiment_logger_configs
)

aggregating data...
total data length: 89366
preprocessing continuous features...
preprocessing discret features...
start training CF Module...


Epoch 1: 100%|██████████| 30/30 [00:00<00:00, 49.78batch/s, train/adv_loss=0.06253081, train/train_loss_1=0.0621, train/train_loss_2=0.0625, train/train_loss_3=0.0391] 
Epoch 1: 100%|██████████| 17/17 [00:00<00:00, 51.66batch/s, train/adv_loss=nan, train/train_loss_1=0.143, train/train_loss_2=nan, train/train_loss_3=nan]            
Epoch 1: 100%|██████████| 7/7 [00:00<00:00, 51.30batch/s, train/adv_loss=0.14924568, train/train_loss_1=0.108, train/train_loss_2=0.149, train/train_loss_3=0.0417]


generating cfs...
calculating the validity matrix...


#### VanillaCF

In [None]:
from relax.methods import VanillaCF
from relax.module import PredictiveTrainingModule


In [None]:
# Same experiment with a separately trained predictive module and `VanillaCF`
# as the counterfactual method; wandb logging stays disabled.
cf_results_list, experiment_results = adversarial_experiment(
    pred_training_module=PredictiveTrainingModule({'lr': 0.003, 'sizes': [200, 10]}),
    cf_module=VanillaCF(),
    default_data_config=data_configs,
    data_dir_list=[ 
        f"assets/data/loan/year={year}.csv" for year in range(2007, 2010) 
    ],
    # data_dir_list=[ 
    #     f"assets/data/loan/year={year}.csv" for year in range(1994, 2010) 
    # ],
    t_config=t_configs,
    # experiment_logger_configs=experiment_logger_configs
)

aggregating data...
total data length: 89366
preprocessing continuous features...
preprocessing discret features...
start training...


Epoch 1: 100%|██████████| 30/30 [00:00<00:00, 91.07batch/s, train/train_loss_1=0.0644]
Epoch 1: 100%|██████████| 17/17 [00:00<00:00, 90.14batch/s, train/train_loss_1=0.15]  
Epoch 1: 100%|██████████| 7/7 [00:00<00:00, 94.06batch/s, train/train_loss_1=0.0994]


generating cfs...


100%|██████████| 1000/1000 [00:22<00:00, 44.12it/s]
100%|██████████| 1000/1000 [00:12<00:00, 83.25it/s]
100%|██████████| 1000/1000 [00:05<00:00, 182.37it/s]


calculating the validity matrix...


In [None]:
# Inspect the counterfactuals generated for the first dataset.
cf_results_list[0].cfs

DeviceArray([[ 0.3246443 , -0.66034925, -0.09777143, ...,  0.        ,
               0.        ,  1.        ],
             [-0.5055199 , -0.4706281 ,  0.07288305, ...,  0.        ,
               0.        ,  0.        ],
             [ 0.26521772, -0.57555026, -0.05948754, ...,  0.        ,
               0.        ,  1.        ],
             ...,
             [ 0.3639497 , -0.70889324, -0.12165295, ...,  0.        ,
               0.        ,  1.        ],
             [ 0.4209727 , -0.80577576, -0.15645298, ...,  0.        ,
               0.        ,  1.        ],
             [ 0.25610036, -0.54156506, -0.06003434, ...,  0.        ,
               0.        ,  1.        ]], dtype=float32)

In [None]:
# Count how many rounded predictions on the first dataset's inputs agree
# between the first and the second explanation's predictor.
y_pred_0 = cf_results_list[0].pred_fn(cf_results_list[0].X).reshape(-1, 1).round()
y_pred_1 = cf_results_list[1].pred_fn(cf_results_list[0].X).reshape(-1, 1).round()
(y_pred_0 == y_pred_1).sum()

DeviceArray(12466, dtype=int32)

In [None]:
# Total number of predictions, for comparison with the agreement count above.
len(y_pred_0)

12757

In [None]:
# Inspect the 'counter_net_model/Explainer_1' parameters bound to the first explanation's predictor.
cf_results_list[0].pred_fn.fn_args['params']['counter_net_model/Explainer_1']

{'b': DeviceArray([-0.05756189, -0.23633571,  0.01710565, -0.05331689,
              -0.3896599 , -0.27063757, -0.32480994, -0.14294973,
              -0.04665818, -0.03992986,  0.03972355,  0.08615527,
              -0.0744281 , -0.13791753, -0.13287549, -0.15250969,
               0.564096  ,  0.08776585, -0.16287263, -0.13107836,
              -0.09720206, -0.07217125, -0.06066768, -0.11944605,
              -0.07659575, -0.08929113, -0.07949819, -0.09658174,
              -0.09656797, -0.03666689, -0.05324696, -0.1618365 ,
              -0.12976535,  0.62580806, -0.08158124, -0.03045881,
              -0.09196167, -0.12518756, -0.08453243, -0.0775774 ,
              -0.12637912, -0.09183856, -0.0062271 , -0.13629813,
              -0.1257055 ,  0.11852361, -0.02613323, -0.11727765,
              -0.15367144,  0.00479156, -0.06008328, -0.06138969,
              -0.09986464,  0.25155354, -0.06922829, -0.13039733,
              -0.11837475, -0.11905003, -0.00613969, -0.13261838,
     

In [None]:
# NOTE(review): same expression as the preceding cell -- likely a leftover duplicate.
cf_results_list[0].pred_fn.fn_args['params']['counter_net_model/Explainer_1']

{'b': DeviceArray([-0.05756189, -0.23633571,  0.01710565, -0.05331689,
              -0.3896599 , -0.27063757, -0.32480994, -0.14294973,
              -0.04665818, -0.03992986,  0.03972355,  0.08615527,
              -0.0744281 , -0.13791753, -0.13287549, -0.15250969,
               0.564096  ,  0.08776585, -0.16287263, -0.13107836,
              -0.09720206, -0.07217125, -0.06066768, -0.11944605,
              -0.07659575, -0.08929113, -0.07949819, -0.09658174,
              -0.09656797, -0.03666689, -0.05324696, -0.1618365 ,
              -0.12976535,  0.62580806, -0.08158124, -0.03045881,
              -0.09196167, -0.12518756, -0.08453243, -0.0775774 ,
              -0.12637912, -0.09183856, -0.0062271 , -0.13629813,
              -0.1257055 ,  0.11852361, -0.02613323, -0.11727765,
              -0.15367144,  0.00479156, -0.06008328, -0.06138969,
              -0.09986464,  0.25155354, -0.06922829, -0.13039733,
              -0.11837475, -0.11905003, -0.00613969, -0.13261838,
     

In [None]:
# Same parameter subtree for the second explanation, to check that the models differ.
cf_results_list[1].pred_fn.fn_args['params']['counter_net_model/Explainer_1']

{'b': DeviceArray([-0.02574765,  0.09567051, -0.01645298,  0.02813452,
               0.05758457,  0.06874906,  0.04412473,  0.1518215 ,
              -0.03008602, -0.04655835, -0.03157707, -0.02536911,
              -0.02293781, -0.06237506, -0.04882557, -0.0657744 ,
               0.04763514,  0.08136597, -0.04336939, -0.03505868,
              -0.05165615, -0.0383948 ,  0.00507054, -0.03148966,
              -0.01416376,  0.0120643 ,  0.24699654, -0.06389464,
              -0.00368025, -0.01580356,  0.01647298, -0.03387755,
              -0.06293271, -0.076146  ,  0.01791069,  0.2717298 ,
               0.04025881, -0.04208563, -0.05477343, -0.00336144,
              -0.06703535, -0.02773977,  0.1618863 , -0.02318909,
              -0.04799687,  0.0113631 , -0.00982577, -0.0662851 ,
              -0.04215523, -0.03200383,  0.00206313, -0.02934637,
              -0.09422823, -0.03384609, -0.03540881, -0.02038316,
              -0.03813815,  0.01359431, -0.10401475, -0.10064937,
     