# Preliminaries to Breaking Down into Research Questions

Since we have many models, checkpoints, and saliency maps, easy access to all of them is important.
The `RQ/weight_parser.py` module provides this; its basic usage is shown below.

By default, all weights are saved in the `meta_brain/weights/default` directory.
Directories other than `default` contain ablation studies or checkpoints of specific models, and they may not include the target saliency methods or checkpoints you need.

### Naming Protocols
Directories are simply named after the model, e.g. `resnet10`.
Since we also study seed variability, a suffix denotes the seed number, e.g. `resnet10-42`.
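The log output further down shows seed-suffixed directories such as `resnet10-42`, and `convnext-tiny` is a model name that itself contains a hyphen. A minimal sketch of building and splitting such names, assuming the seed is always the last hyphen-separated, purely numeric token; `weight_dir_name` and `parse_weight_dir` are hypothetical helpers, not part of `weight_parser`.

```python
from typing import Optional, Tuple


def weight_dir_name(model_name: str, seed: Optional[int] = None) -> str:
    """Build a directory name such as 'resnet10' or 'resnet10-42' (hypothetical helper)."""
    return model_name if seed is None else f"{model_name}-{seed}"


def parse_weight_dir(dirname: str) -> Tuple[str, Optional[int]]:
    """Split 'convnext-tiny-43' into ('convnext-tiny', 43).

    Only the last hyphen is considered, so hyphenated model names survive;
    names without a numeric suffix get seed None.
    """
    head, sep, tail = dirname.rpartition("-")
    if sep and tail.isdigit():
        return head, int(tail)
    return dirname, None
```

Splitting on the last hyphen (rather than the first) is what keeps `convnext-tiny` intact.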

### Attributes

- Meta Information
  - `prediction`: Inference results on the test dataset
  - `config`: Hydra config YAML file used to run the experiment
  - `ckpt_dict`: Dictionary of checkpoints, with (step, performance) pairs parsed from them
  - `test_performance`: Test performance on 3,029 brains. `__repr__` also includes this value as a string.

- XAI Information
  - `xai_dict`: Dictionary mapping each RoI to its projected saliency, averaged across 3,029 brains
  - `xai_dict_indiv`: Dictionary mapping each RoI to a list of 3,029 individual projected saliency values
  - `img_dict`: `PosixPath`s of all image files, including nilearn visualizations of the saliency maps
  - `attrs`: `np.ndarray` of the saliency map averaged over all brains, without top-k extraction
  - `top_attr`: `np.ndarray` of the saliency map averaged after top-k extraction
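Based on the descriptions above, `xai_dict` should correspond to a per-RoI mean of the lists in `xai_dict_indiv`. A minimal sketch of that relationship, assuming a plain arithmetic mean (the actual `weight_parser` aggregation is not shown here); `average_saliency` is a hypothetical helper and the input is toy data.

```python
from statistics import fmean
from typing import Dict, List


def average_saliency(xai_dict_indiv: Dict[str, List[float]]) -> Dict[str, float]:
    """Collapse per-individual RoI saliency lists into one mean value per RoI."""
    return {roi: fmean(values) for roi, values in xai_dict_indiv.items()}


# Toy input with 3 "individuals" instead of 3,029
toy_indiv = {
    "Left-Hippocampus": [0.1, 0.2, 0.3],
    "Right-Hippocampus": [0.4, 0.5, 0.6],
}
toy_avg = average_saliency(toy_indiv)
```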

Note that the XAI attributes are only available after calling the `load_xai` method. Since several XAI methods are supported, this makes explicit _which_ method's information is being extracted.
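This design is why `getattr(resnet10_42, "attrs")` raises an `AttributeError` further down. A minimal sketch of the pattern, assuming the XAI fields are simply not assigned until `load_xai` runs; `LazyXAI` is a hypothetical class, not the actual `Weights` implementation, and it stores `None` placeholders instead of reading files.

```python
class LazyXAI:
    """Hypothetical stand-in: XAI attributes do not exist until load_xai() is called."""

    XAI_FIELDS = ("xai_dict", "xai_dict_indiv", "img_dict", "attrs", "top_attr")

    def __init__(self, model_name: str):
        self.model_name = model_name
        self.xai_method = None  # no XAI method chosen yet

    def load_xai(self, xai_method: str) -> None:
        # Record which method the XAI attributes belong to, then create them.
        self.xai_method = xai_method
        for field in self.XAI_FIELDS:
            setattr(self, field, None)  # placeholder; a real loader would read from disk


w = LazyXAI("resnet10")
before = hasattr(w, "attrs")  # False: w.attrs would raise AttributeError here
w.load_xai("gbp")
after = hasattr(w, "attrs")   # True once a method has been loaded
```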

## Fetching Base Information

In [2]:
# Basic Usage
from weight_parser import Weights

resnet10_42 = Weights(model_name="resnet10", seed=42)

INFO:/home/1pha/codespace/brain-age-prediction/RQ/weight_parser.py:Set base_path as /home/1pha/codespace/brain-age-prediction/meta_brain/weights/default/resnet10
INFO:/home/1pha/codespace/brain-age-prediction/RQ/weight_parser.py:Loading Basic Information


In [3]:
getattr(resnet10_42, "attrs")

AttributeError: 'Weights' object has no attribute 'attrs'

In [2]:
resnet10_42.prediction

{'loss': tensor([ 6.8574,  6.6092,  4.0870,  4.2467, 14.9120,  9.6608,  7.5648,  6.5287,
          7.5628,  9.1117,  7.2554,  7.1737, 10.8974,  9.3474,  9.5231,  7.7420,
          7.3360,  5.3722, 12.4250,  9.9102,  7.6106, 13.4158, 13.7678, 10.4643,
          5.9760, 11.8100,  8.3364, 14.3582, 11.9143, 10.4633, 10.2717,  9.8976,
          9.8213, 10.8430, 12.5011,  6.5727,  8.2385,  6.8624, 14.0359,  8.7301,
         12.5269,  5.3029, 10.8694, 10.3541, 10.7945,  8.0213, 11.8855, 15.6165,
          6.9720,  8.1804, 12.5467, 12.2646, 12.7532, 10.0480,  6.1709, 10.3209,
         13.4379,  9.7357, 10.6232, 10.8740,  9.2203,  7.6944,  6.9129, 11.4115,
          7.1686,  8.2259,  6.7797,  9.7257, 15.3016, 10.7471,  9.3316,  8.4056,
          9.4702, 11.3525,  6.9490, 15.1385,  7.5234, 10.5188,  7.6345,  8.8465,
          9.4362, 14.9369,  9.7391, 13.2086, 11.8094, 11.9764,  9.7353, 11.3397,
         10.2826,  6.4984, 11.6619,  8.9934,  9.5255, 14.0168, 15.5532]),
 'pred': tensor([67.2856, 5
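The printed `loss` tensor holds 95 values, not 3,029. With `batch_size: 32` in the `config` output below and 3,029 test brains, the number of batches is also 95, which suggests (but does not confirm) that `loss` stores one value per batch. Assuming each entry is the mean loss of its batch and the last, smaller batch is kept, a sample-weighted mean avoids over-weighting that last batch; `sample_weighted_mean` is a hypothetical helper.

```python
import math
from typing import List

N_TEST = 3029     # number of test brains stated above
BATCH_SIZE = 32   # dataloader.batch_size from the Hydra config

n_batches = math.ceil(N_TEST / BATCH_SIZE)             # batches needed to cover all brains
last_batch = N_TEST - (n_batches - 1) * BATCH_SIZE     # size of the final, partial batch


def sample_weighted_mean(batch_losses: List[float], n_total: int, batch_size: int) -> float:
    """Average per-batch mean losses, weighting each batch by its number of samples."""
    n_full = len(batch_losses) - 1
    sizes = [batch_size] * n_full + [n_total - batch_size * n_full]
    return sum(loss * size for loss, size in zip(batch_losses, sizes)) / n_total
```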

In [3]:
resnet10_42.config

{'dataloader': {'_target_': 'torch.utils.data.DataLoader', 'batch_size': 32, 'num_workers': 4, 'pin_memory': True, 'dataset': '${dataset}'}, 'misc': {'seed': 42, 'debug': False, 'modes': ['train', 'valid', 'test']}, 'module': {'_target_': 'sage.trainer.PLModule', '_recursive_': False, 'augmentation': {'_target_': 'sage.data.augment', 'spatial_size': [160, 192, 160]}, 'load_from_checkpoint': None, 'load_model_ckpt': None, 'separate_lr': None, 'save_dir': '${callbacks.checkpoint.dirpath}'}, 'metrics': {'mae': {'_target_': 'torchmetrics.MeanAbsoluteError'}, 'rmse': {'_target_': 'torchmetrics.MeanSquaredError', 'squared': False}, 'r2': {'_target_': 'torchmetrics.R2Score'}}, 'logger': {'_target_': 'pytorch_lightning.loggers.WandbLogger', 'project': 'brain-age', 'entity': '1pha', 'name': 'R ${model.name} | ${module.augmentation.spatial_size} | ${misc.seed}', 'tags': ['model=${model.name}', 'REG']}, 'trainer': {'_target_': 'pytorch_lightning.Trainer', 'max_epochs': 500, 'devices': 1, 'acceler
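`config` prints as a plain nested dict, and values such as `'${dataset}'` and `'${callbacks.checkpoint.dirpath}'` appear as unresolved interpolation strings. A minimal sketch for reading nested values by a Hydra-style dotted path, using a trimmed copy of values visible above; `get_by_path` is a hypothetical helper and does not resolve `${...}` interpolations (OmegaConf would be needed for that).

```python
from functools import reduce
from typing import Any, Mapping


def get_by_path(cfg: Mapping[str, Any], dotted: str) -> Any:
    """Follow a dotted key path such as 'dataloader.batch_size' through nested dicts."""
    return reduce(lambda node, key: node[key], dotted.split("."), cfg)


# A trimmed copy of values visible in the config output above
cfg = {
    "dataloader": {"batch_size": 32, "num_workers": 4},
    "misc": {"seed": 42, "debug": False},
    "module": {"augmentation": {"spatial_size": [160, 192, 160]}},
}
```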

In [4]:
resnet10_42.ckpt_dict

{'steps': [(50000, 6.292),
  (250, 3736.735),
  (1000, 537.31),
  (25000, 9.847),
  (100000, 7.116),
  (10000, 17.014)],
 'last': PosixPath('/home/1pha/codespace/brain-age-prediction/meta_brain/weights/default/resnet10/last.ckpt'),
 'best': PosixPath('/home/1pha/codespace/brain-age-prediction/meta_brain/weights/default/resnet10/156864-valid_mae3.465.ckpt'),
 'best_valid_mae': [(156864, 3.465)]}
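The `steps` list is not ordered by step, and the `best` file name appears to encode both the step and the validation MAE that `best_valid_mae` reports. A minimal sketch of parsing such a name and sorting the step list, assuming the `<step>-valid_mae<value>.ckpt` pattern seen above holds for every best checkpoint; `parse_best_ckpt` is hypothetical, not the parser inside `weight_parser`.

```python
import re
from pathlib import Path
from typing import Optional, Tuple

# Assumed pattern for names like '156864-valid_mae3.465.ckpt'
CKPT_PATTERN = re.compile(r"^(?P<step>\d+)-valid_mae(?P<mae>\d+(?:\.\d+)?)\.ckpt$")


def parse_best_ckpt(path: Path) -> Optional[Tuple[int, float]]:
    """Return (step, valid_mae) parsed from the file name, or None if it does not match."""
    match = CKPT_PATTERN.match(path.name)
    if match is None:
        return None
    return int(match["step"]), float(match["mae"])


steps = [(50000, 6.292), (250, 3736.735), (1000, 537.31),
         (25000, 9.847), (100000, 7.116), (10000, 17.014)]
steps_sorted = sorted(steps)  # tuples sort by their first element, the step
```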

In [5]:
resnet10_42.test_performance

{'mse': 9.836, 'mae': 2.472, 'r2': 0.8315}
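`test_performance` reports MSE, MAE, and R², and the config lists `torchmetrics.MeanAbsoluteError`, `MeanSquaredError`, and `R2Score`. A NumPy sketch of the standard definitions behind those names, on toy ages rather than the project's data; whether `weight_parser` computes them exactly this way is an assumption. If targets are ages in years, an MAE of 2.472 means predictions are off by about 2.5 years on average, and MSE is in squared years.

```python
import numpy as np


def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> dict:
    """Standard regression MSE, MAE, and coefficient of determination (R^2)."""
    err = y_pred - y_true
    mse = float(np.mean(err ** 2))
    mae = float(np.mean(np.abs(err)))
    ss_res = float(np.sum(err ** 2))                      # residual sum of squares
    ss_tot = float(np.sum((y_true - y_true.mean()) ** 2))  # total sum of squares
    return {"mse": mse, "mae": mae, "r2": 1.0 - ss_res / ss_tot}


toy_metrics = regression_metrics(np.array([60.0, 70.0, 80.0]), np.array([62.0, 69.0, 80.0]))
```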

## Fetching Information on Explainability Methods

In [6]:
resnet10_42.load_xai(xai_method="gbp")

INFO:/home/1pha/codespace/brain-age-prediction/RQ/weight_parser.py:Loading XAI information from /home/1pha/codespace/brain-age-prediction/meta_brain/weights/default/resnet10/gbpk0.99
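The log shows that GBP results are read from a directory named `gbpk0.99`. One plausible reading, not confirmed by the source, is that `k0.99` is the top-k setting behind `top_attr`, expressed as a quantile. Under that assumption, a minimal sketch of top-k extraction on a toy volume; `top_quantile_mask` is hypothetical, and it ranks raw values (the project may instead rank absolute values).

```python
import numpy as np


def top_quantile_mask(attr: np.ndarray, q: float = 0.99) -> np.ndarray:
    """Keep voxels at or above the q-th quantile of attribution; zero out the rest."""
    threshold = np.quantile(attr, q)
    return np.where(attr >= threshold, attr, 0.0)


rng = np.random.default_rng(0)
toy_attr = rng.random((10, 10, 10))  # 1,000 toy voxels with distinct positive values
masked = top_quantile_mask(toy_attr, q=0.99)
```

With 1,000 distinct voxels and `q=0.99`, exactly the 10 largest values survive.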


In [7]:
resnet10_42.xai_dict

{'Left-Cerebral-White-Matter': 0.15630731406495435,
 'Left-Lateral-Ventricle': 0.6957168006506934,
 'Left-Inf-Lat-Vent': 0.10404497136550818,
 'Left-Cerebellum-White-Matter': 1.5143362531476072,
 'Left-Cerebellum-Cortex': 0.5381404644688782,
 'Left-Thalamus': 1.2142168953930585,
 'Left-Caudate': 0.6862779158485125,
 'Left-Putamen': 1.4636496779375936,
 'Left-Pallidum': 2.0797911295214724,
 '3rd-Ventricle': 2.5183922453495002,
 '4th-Ventricle': 0.5923415455801275,
 'Brain-Stem': 0.9234419224351268,
 'Left-Hippocampus': 0.16901173787404838,
 'Left-Amygdala': 0.30976443748508337,
 'CSF': 0.43540367069232594,
 'Left-Accumbens-area': 0.5724081656030201,
 'Left-VentralDC': 1.2987201759432883,
 'Left-choroid-plexus': 0.4518537586756545,
 'Right-Cerebral-White-Matter': 0.09066763225086949,
 'Right-Lateral-Ventricle': 0.5764432871728291,
 'Right-Inf-Lat-Vent': 0.036970381215333005,
 'Right-Cerebellum-White-Matter': 1.4926739162785665,
 'Right-Cerebellum-Cortex': 0.12309100013790775,
 'Right-Tha
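Because `xai_dict` maps RoI names to scalar saliencies, ranking regions or comparing hemispheres is a plain dictionary operation. A minimal sketch using a few entries copied from the output above (the full dict is truncated here, so the ranking covers only this subset); `top_rois` and `hemisphere_gap` are hypothetical helpers.

```python
# A subset of entries copied from the xai_dict output above
xai_subset = {
    "Left-Cerebral-White-Matter": 0.15630731406495435,
    "Left-Cerebellum-White-Matter": 1.5143362531476072,
    "Left-Putamen": 1.4636496779375936,
    "Left-Pallidum": 2.0797911295214724,
    "3rd-Ventricle": 2.5183922453495002,
    "Right-Cerebral-White-Matter": 0.09066763225086949,
    "Right-Cerebellum-White-Matter": 1.4926739162785665,
}


def top_rois(xai_dict: dict, n: int = 3) -> list:
    """Return the n RoIs with the largest projected saliency, highest first."""
    return sorted(xai_dict, key=xai_dict.get, reverse=True)[:n]


def hemisphere_gap(xai_dict: dict, region: str) -> float:
    """Left minus right saliency for a bilateral region such as 'Cerebellum-White-Matter'."""
    return xai_dict[f"Left-{region}"] - xai_dict[f"Right-{region}"]
```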

In [8]:
resnet10_42.xai_dict_indiv

{'Left-Cerebral-White-Matter': [0.3014619052410126,
  0.25620996952056885,
  0.2873043119907379,
  0.17003440856933594,
  -0.00214385031722486,
  0.28820088505744934,
  0.35278433561325073,
  0.28773027658462524,
  0.40145331621170044,
  0.31434741616249084,
  0.27457964420318604,
  0.30474793910980225,
  0.2727004885673523,
  0.0024162703193724155,
  0.3509160280227661,
  0.1607201099395752,
  0.32812273502349854,
  0.28465571999549866,
  0.32350727915763855,
  -0.005454536061733961,
  -0.004244647454470396,
  0.280618816614151,
  0.378671795129776,
  0.2904545068740845,
  0.3071535527706146,
  0.26042667031288147,
  0.2933950126171112,
  0.32388007640838623,
  0.32170331478118896,
  0.23141588270664215,
  0.31118181347846985,
  0.27692022919654846,
  0.32033535838127136,
  0.01397360023111105,
  0.2954002916812897,
  0.27326470613479614,
  0.31828489899635315,
  0.3051712214946747,
  0.2836467921733856,
  0.28348013758659363,
  0.2932809591293335,
  0.3780822157859802,
  0.2694373428
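Keeping the per-individual lists makes it possible to look at spread rather than only the average; a few individuals in the output above have near-zero or slightly negative values. A minimal sketch summarizing the first ten `Left-Cerebral-White-Matter` values copied from that output (a real analysis would use all 3,029 entries).

```python
from statistics import fmean, pstdev

# First ten 'Left-Cerebral-White-Matter' values from the xai_dict_indiv output above
indiv = [0.3014619052410126, 0.25620996952056885, 0.2873043119907379,
         0.17003440856933594, -0.00214385031722486, 0.28820088505744934,
         0.35278433561325073, 0.28773027658462524, 0.40145331621170044,
         0.31434741616249084]

summary = {
    "n": len(indiv),
    "mean": fmean(indiv),
    "std": pstdev(indiv),                       # population standard deviation
    "n_negative": sum(v < 0 for v in indiv),    # individuals with negative saliency
}
```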

In [9]:
resnet10_42.attrs

array([[[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       ...,

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0.
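The printed slices of `attrs` are all zeros because NumPy only shows the edges of a large 3D array, which are presumably background outside the brain. A quick way to check where the nonzero saliency actually lies, sketched on a toy volume; `nonzero_summary` is a hypothetical helper.

```python
import numpy as np


def nonzero_summary(attr: np.ndarray) -> dict:
    """Report shape, the share of nonzero voxels, and the bounding box of the nonzero region."""
    nz = np.nonzero(attr)
    if nz[0].size == 0:
        return {"shape": attr.shape, "nonzero_frac": 0.0, "bbox": None}
    bbox = tuple((int(idx.min()), int(idx.max())) for idx in nz)
    return {"shape": attr.shape, "nonzero_frac": nz[0].size / attr.size, "bbox": bbox}


toy_volume = np.zeros((8, 8, 8))
toy_volume[2:5, 3:6, 1:4] = 0.5  # a 3x3x3 block of nonzero "saliency"
info = nonzero_summary(toy_volume)
```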

In [13]:
from weight_parser import WeightAvg

avg_dict = dict(xai_method="", seeds=[42, 43, 44])
resnet10_avg = WeightAvg(model_name="resnet10", **avg_dict)
resnet18_avg = WeightAvg(model_name="resnet18", **avg_dict)
resnet34_avg = WeightAvg(model_name="resnet34", **avg_dict)
convnext_tiny_avg = WeightAvg(model_name="convnext-tiny", **avg_dict)

INFO:/home/1pha/codespace/brain-age-prediction/RQ/weight_parser.py:Load all seeds: [42, 43, 44]
INFO:/home/1pha/codespace/brain-age-prediction/RQ/weight_parser.py:Set base_path as /home/1pha/codespace/brain-age-prediction/meta_brain/weights/default/resnet10-42
INFO:/home/1pha/codespace/brain-age-prediction/RQ/weight_parser.py:Loading Basic Information
INFO:/home/1pha/codespace/brain-age-prediction/RQ/weight_parser.py:Set base_path as /home/1pha/codespace/brain-age-prediction/meta_brain/weights/default/resnet10-43
INFO:/home/1pha/codespace/brain-age-prediction/RQ/weight_parser.py:Loading Basic Information
INFO:/home/1pha/codespace/brain-age-prediction/RQ/weight_parser.py:Set base_path as /home/1pha/codespace/brain-age-prediction/meta_brain/weights/default/resnet10-44
INFO:/home/1pha/codespace/brain-age-prediction/RQ/weight_parser.py:Loading Basic Information
INFO:/home/1pha/codespace/brain-age-prediction/RQ/weight_parser.py:Aggregate across 3 seeds
INFO:/home/1pha/codespace/brain-age-pr
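The log reports "Aggregate across 3 seeds", and the fractional values in the averaged `test_performance` outputs (e.g. `'r2': 0.8094666666666667`) are consistent with a plain arithmetic mean over seeds. A minimal sketch of that key-wise averaging on toy numbers, assuming identical metric keys per seed; `average_performance` is hypothetical, not `WeightAvg`'s code.

```python
from statistics import fmean
from typing import Dict, List


def average_performance(per_seed: List[Dict[str, float]]) -> Dict[str, float]:
    """Key-wise arithmetic mean of per-seed metric dicts (keys assumed identical)."""
    keys = per_seed[0].keys()
    return {k: fmean(d[k] for d in per_seed) for k in keys}


# Toy per-seed results, not the project's numbers
per_seed = [
    {"mse": 10.0, "mae": 2.5, "r2": 0.80},
    {"mse": 11.0, "mae": 2.6, "r2": 0.82},
    {"mse": 12.0, "mae": 2.7, "r2": 0.84},
]
avg = average_performance(per_seed)
```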

In [12]:
resnet10_avg.test_performance, resnet18_avg.test_performance, resnet34_avg.test_performance

({'mse': 10.612666666666668, 'mae': 2.5783333333333336, 'r2': 0.8182},
 {'mse': 11.122666666666666,
  'mae': 2.6229999999999998,
  'r2': 0.8094666666666667},
 {'mse': 13.956666666666665,
  'mae': 2.9876666666666662,
  'r2': 0.7609333333333334})

In [14]:
convnext_tiny_avg.test_performance

{'mse': 11.399666666666667, 'mae': 2.683666666666667, 'r2': 0.8047}
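Across the 3-seed averages, the ordering is the same for MSE, MAE, and R²: resnet10 ranks first and resnet34 last. A short sketch that ranks the four models by average MAE, using the values printed above rounded to three decimals.

```python
# Averaged (3-seed) test MAE values copied from the outputs above, rounded to 3 decimals
mae_by_model = {
    "resnet10": 2.578,
    "resnet18": 2.623,
    "resnet34": 2.988,
    "convnext-tiny": 2.684,
}

ranking = sorted(mae_by_model.items(), key=lambda item: item[1])  # lower MAE is better
for rank, (model, mae) in enumerate(ranking, start=1):
    print(f"{rank}. {model:<14} MAE {mae:.3f}")
```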
