In [None]:
import functools
import os
from pathlib import Path

import numpy as np
import plotly.express as px
import torch
from torch.linalg import norm
from tqdm import tqdm

from src import settings
from src.trainers.gan_conditional_inn_trainer_hsi import GanCondinitionalDomainAdaptationINNHSI
from src.utils.config_io import load_config
from src.utils.susi import ExperimentResults

In [None]:
# All artefacts live under one project root. Set the MICCAI23_ROOT environment
# variable to run on another machine instead of editing absolute paths here.
PROJECT_ROOT = Path(os.environ.get('MICCAI23_ROOT', '/home/menjivar/DKFZ/Projects/MICCAI_23'))
# Training run whose checkpoint and hyper-parameters are evaluated in this notebook.
RUN_DIR = PROJECT_ROOT / 'results' / 'gan_cinn_hsi' / '2023_02_20_21_06_28' / 'version_0'

model_file = RUN_DIR / 'checkpoints' / 'epoch=699-step=641200.ckpt'
config_file = RUN_DIR / 'hparams.yaml'
data_folder = PROJECT_ROOT / 'intermediates' / 'semantic' / 'val_synthetic_sampled'
segmentation_folder = PROJECT_ROOT / 'intermediates' / 'semantic' / 'segmentation'

In [None]:
# Raw Lightning checkpoint, loaded on its own so the restored model's weights can be
# compared against it. The run's config is loaded together with the model below,
# so it is not loaded (a second time) here.
# NOTE: torch.load unpickles arbitrary objects — only load checkpoints you trust.
ckpt = torch.load(model_file)

In [None]:
# List the entry names stored in the checkpoint's `state_dict`.
ckpt['state_dict'].keys()

In [None]:
# Rebuild the Lightning module from the checkpoint with this run's hyper-parameters.
# strict=True asks Lightning to reject any key mismatch between checkpoint and model.
config = load_config(config_file)
model = GanCondinitionalDomainAdaptationINNHSI.load_from_checkpoint(
    model_file,
    experiment_config=config,
    strict=True,
)
model = model.cuda()
model.eval()

In [None]:
def rgetattr(obj, attr, *args):
    """Resolve a dotted attribute path such as ``"encoder.layer.weight"`` on ``obj``.

    Each ``.``-separated segment of ``attr`` is looked up with ``getattr`` on the
    result of the previous lookup, starting from ``obj``. Any extra positional
    argument (a default) is forwarded to every ``getattr`` call. With a default,
    no lookup raises ``AttributeError``; without one, a missing segment raises.
    """
    current = obj
    for part in attr.split('.'):
        current = getattr(current, part, *args)
    return current

# Sanity check: every tensor saved in the checkpoint must match the attribute that
# the restored model exposes under the same dotted name.
for param_name, saved_tensor in ckpt['state_dict'].items():
    restored_tensor = rgetattr(model, param_name)
    torch.testing.assert_close(saved_tensor, restored_tensor, rtol=1e-9, atol=1e-6)

In [None]:
# Organ classes excluded from the evaluation.
ignore_classes = ['gallbladder']
organs = [organ for organ in settings.organ_labels if organ not in ignore_classes]
# settings.mapping is keyed by stringified integer labels and yields organ names;
# invert it so labels can be looked up by organ name.
mapping_inv = {organ_name: label for label, organ_name in settings.mapping.items()}
# Integer label -> position of the organ in `organs` (which already excludes ignore_classes).
order = {int(mapping_inv[organ]): position for position, organ in enumerate(organs)}

In [None]:
# Spectra to translate. Path.glob yields entries in arbitrary filesystem order, so
# sort them to make the processing (and hence plotting) order reproducible.
# Files ending in `_ind.npy` are skipped (presumably index arrays, not spectra).
files = sorted(f for f in data_folder.glob('*.npy') if not f.name.endswith('_ind.npy'))
# The matching segmentation map has the same file name without the `_KNN_0` suffix.
label_files = [segmentation_folder / f.name.replace('_KNN_0', '') for f in files]

In [None]:
# Translate every synthetic image into the target domain and collect, per image and
# organ label, the mean ('agg') and standard deviation ('variation') of the translated spectra.
wavelengths = np.arange(500, 1000, 5)  # plot x-axis; must match the model's output channels (checked below)
# Labels to drop, derived from `ignore_classes` rather than a hardcoded label value.
ignore_labels = [int(mapping_inv[organ]) for organ in ignore_classes]

results = ExperimentResults()
for f, label_file in zip(tqdm(files), label_files):
    subject_id, image_id = f.name.split('#')
    image_id = image_id.split('.')[0].replace('_KNN_0', '')
    x = np.load(f)
    y = np.load(label_file)
    keep = ~np.isin(y, ignore_labels)  # drop pixels of ignored classes (gallbladder)
    y = y[keep]
    x_tensor = torch.tensor(x[keep], dtype=torch.float32)
    # L2-normalise every spectrum individually (last axis). Without `dim`,
    # `norm(..., ord=2)` of a 2-D tensor is the matrix spectral norm (largest
    # singular value): one scalar for the whole image, so the input scale would
    # depend on image content and pixel count.
    x_tensor = x_tensor / norm(x_tensor, ord=2, dim=-1, keepdim=True)
    if config.normalization == "standardize":
        x_tensor = (x_tensor - config.data.mean_a) / config.data.std_a
    y_tensor = torch.tensor(y, dtype=torch.float32)
    batch = dict(spectra_a=x_tensor, seg_a=y_tensor, order=order, spectra_b=x_tensor)
    # Pure inference: do not record an autograd graph.
    with torch.no_grad():
        spectra_a, spectra_b = model.get_spectra(batch)
        output = model.translate_spectrum(spectra_a, input_domain="a")[0].detach().cpu().numpy()
    # NOTE(review): the output is always de-standardised, while the input above is only
    # standardised when config.normalization == "standardize" — confirm this is intended.
    output = output * config.data.std_b + config.data.mean_b
    assert output.shape[-1] == len(wavelengths), (
        f"model returned {output.shape[-1]} channels but {len(wavelengths)} wavelengths are plotted"
    )
    for label in np.unique(y):
        mask = y == label
        agg = np.mean(output[mask], axis=0)
        variation = np.std(output[mask], axis=0)
        n_points = len(agg)
        results.append(value=variation, name='variation')
        results.append(value=agg, name='agg')
        results.append(value=wavelengths, name='wavelength')
        results.append(value=[label] * n_points, name='label')
        # settings.mapping keys are stringified integers; int() also handles float label maps.
        results.append(value=[settings.mapping.get(str(int(label)))] * n_points, name='organ')
        results.append(value=[image_id] * n_points, name='image_id')
        results.append(value=[subject_id] * n_points, name='subject_id')

In [None]:
# Flatten the collected results into a table (presumably a pandas DataFrame) whose
# columns are the `name`s passed to `results.append` ('agg', 'wavelength', 'organ', ...).
results_df = results.get_df()

In [None]:
# Mean translated spectrum of each organ, one line per image.
px.line(data_frame=results_df,
        x='wavelength',
        y='agg',
        color='organ',
        line_group='image_id',
        title='Mean translated spectrum per organ (one line per image)',
        labels={'wavelength': 'Wavelength [nm]', 'agg': 'Mean translated spectrum'},
        )

In [None]:
# Per-wavelength standard deviation of the translated spectra of each organ, one line per image.
px.line(data_frame=results_df,
        x='wavelength',
        y='variation',
        color='organ',
        line_group='image_id',
        title='Std. dev. of translated spectra per organ (one line per image)',
        labels={'wavelength': 'Wavelength [nm]', 'variation': 'Std. dev. of translated spectrum'},
        )

In [None]:
# Quick look at the first 10 translated spectra of the LAST image processed by the loop
# above. `output` is left over from that loop's final iteration (hidden state), so this
# cell only works after the loop has run. The array is transposed so each spectrum is one line.
px.line(output[:10].T, line_shape='spline')