In [None]:
import wandb
import json
from pathlib import Path
import pickle
import os
import glob
import re
import torch
import warnings
import hashlib

# Script constants
# Keyword a selection-file line has to start with so that its model gets uploaded.
UPLOAD_KEY = 'UPLOAD'
# Placeholder written into the description column of the selection file;
# the upload step replaces it by an empty description.
DESCRIPTION_PLACEHOLDER = 'YOUR_DESCRIPTION_HERE'
# Text file that lists the found models/summaries; it is written by step 1,
# edited by the user and read again by the upload step.
SELECTION_FILEPATH = "./nnunet_wandb_selection.txt"

# wandb constants
# "<team>/<project>" is queried for the hashes of models that were uploaded before.
WANDB_TEAM_NAME = "test-team-mdl"
WANDB_PROJECT = "nnUNet"

In [None]:
# Adjust paths if necessary

# Directory that is searched recursively for "*model.pkl" and "summary.json" files.
# Read from the environment; raises KeyError if RESULTS_FOLDER is not set.
TRAINED_MODELS_PATH = os.environ['RESULTS_FOLDER']
# TRAINED_MODELS_PATH = "/share/data_rechenknecht01_2/weihsbach/nnunet/nnUNet_trained_models"
# Base directory below which "nnUNet_raw_data/<task>/dataset.json" is looked up.
# Read from the environment; raises KeyError if nnUNet_raw_data_base is not set.
RAW_DATA_BASE_PATH = os.environ['nnUNet_raw_data_base']
# RAW_DATA_BASE_PATH = "/share/data_rechenknecht01_2/weihsbach/nnunet/nnUNet_raw_data_base"

# Second directory that is searched recursively for "summary.json" files.
# NOTE(review): hard-coded, machine-specific absolute path; unlike the commented-out
# examples above it has no "/share" prefix -- confirm that this is intended.
INFERENCE_PATH = "/data_rechenknecht01_2/weihsbach/nnunet/nnUNet_inference_output"

# If True, the ".model.pkl" and ".model" files themselves are added to the uploaded artifact.
UPLOAD_MODEL_FILES = False
# If True, the found summary.json files and their model mappings are appended to the selection file.
COLLECT_PREDICTIONS = True

In [None]:
def sha256sum(filename):
    """Return the hexadecimal SHA-256 digest of the file at ``filename``.

    The file is read unbuffered in chunks of 128 KiB into one reusable buffer,
    so files of any size can be hashed without loading them completely.
    see https://stackoverflow.com/questions/22058048/hashing-a-file-in-python
    """
    chunk_size = 128 * 1024
    digest = hashlib.sha256()
    chunk_buffer = bytearray(chunk_size)
    chunk_view = memoryview(chunk_buffer)

    with open(filename, 'rb', buffering=0) as stream:
        bytes_read = stream.readinto(chunk_view)
        while bytes_read:
            # Only hash the part of the buffer that was filled by this read
            digest.update(chunk_view[:bytes_read])
            bytes_read = stream.readinto(chunk_view)

    return digest.hexdigest()

def retrieve_model_infos(model_pkl_path):
    """Collect everything needed to describe one trained model as a wandb run.

    Parameters
    ----------
    model_pkl_path : str or os.PathLike
        Path of a ``*.model.pkl`` file. The checkpoint itself (same name without
        the ``.pkl`` suffix) and a ``debug.json`` are looked up in the same directory.

    Returns
    -------
    tuple
        ``(model_basepath, config_dict, plot_data, metadata)`` where
        ``model_basepath`` is the directory of the model files, ``config_dict`` holds
        the run configuration (task, fold, epoch, hashes, ...), ``plot_data`` maps
        ``'training/loss'``, ``'validation/loss'`` and ``'eval/dice'`` to the values
        stored in the checkpoint and ``metadata`` is the parsed ``debug.json``.
        ``(None, None, None, None)`` is returned if one of the three files is missing.

    Raises
    ------
    ValueError
        If model type or configuration cannot be derived from the checkpoint path.
    KeyError
        If the ``USER`` environment variable or an expected key of ``debug.json``
        or of the checkpoint is missing.
    """
    # os.fspath leaves plain strings untouched and additionally allows Path objects
    # (Path.replace would rename the file instead of replacing text).
    model_path = Path(os.fspath(model_pkl_path).replace(".model.pkl", ".model"))
    model_pkl_path = Path(model_pkl_path)

    model_basepath = model_path.parent
    metadata_filepath = model_basepath / "debug.json"

    if not (metadata_filepath.is_file() and model_pkl_path.is_file() and model_path.is_file()):
        # Same arity as the success case so that callers can always unpack four values
        return None, None, None, None

    # Derive the path based keys first so that a bad path fails before the checkpoint is loaded
    model_type_match = re.match(r".*model_(.*?)\.model", str(model_path))
    configuration_match = re.match(r".*/nnUNet/(.*?)/Task\d{3}", str(model_path))
    if model_type_match is None or configuration_match is None:
        raise ValueError(
            f"Cannot derive model type and configuration from path '{model_path}'. "
            "Expected '.../nnUNet/<configuration>/TaskXXX.../model_<type>.model'."
        )
    model_type = model_type_match.groups()[0]
    configuration = configuration_match.groups()[0]

    # Collect data keys
    with open(metadata_filepath, 'r') as metadata_file:
        metadata = json.load(metadata_file)
    task_name = metadata['dataset_directory'].split("/")[-1]
    trainer_name = metadata['experiment_name']
    fold = metadata['fold']

    # Only bookkeeping values are read from the checkpoint, so tensors are mapped to the
    # CPU; this also allows loading checkpoints saved on a GPU on machines without one.
    # NOTE: torch.load unpickles the file -- only use it on trusted checkpoints.
    torch_model = torch.load(model_path, map_location="cpu")
    epoch = torch_model['epoch']
    user = os.environ['USER']

    # Entries 0, 1 and 3 of 'plot_stuff' are used as training loss, validation loss and
    # evaluation metric; entry 2 is not uploaded.
    recorded_curves = torch_model['plot_stuff']
    plot_data = {
        'training/loss': recorded_curves[0],
        'validation/loss': recorded_curves[1],
        'eval/dice': recorded_curves[3],
    }

    config_dict = dict(
        task_name=task_name,
        last_epoch=epoch,
        fold=fold,
        configuration=configuration,
        uploading_user=user,
        trainer_name=trainer_name,
        model_type=model_type,
        model_path=model_path,
        model_hash=sha256sum(model_path),
        model_pkl_path=model_pkl_path,
        model_pkl_hash=sha256sum(model_pkl_path),
    )
    return model_basepath, config_dict, plot_data, metadata

def get_summary_json_paths(*base_paths):
    """Return the sorted paths of all ``summary.json`` files below the given directories.

    Every directory in ``base_paths`` (given as strings) is searched recursively;
    the hits of all directories are returned as one sorted list of ``Path`` objects.
    """
    summary_paths = [
        Path(hit)
        for base_dir in base_paths
        for hit in glob.glob(base_dir + "/**/summary.json", recursive=True)
    ]
    summary_paths.sort()
    return summary_paths

def get_model_pkl_paths(*base_paths):
    """Return the sorted paths of all files matching ``*model.pkl`` below the given directories.

    Every directory in ``base_paths`` (given as strings) is searched recursively;
    the hits of all directories are returned as one sorted list of ``Path`` objects.
    """
    model_pkl_paths = [
        Path(hit)
        for base_dir in base_paths
        for hit in glob.glob(base_dir + "/**/*model.pkl", recursive=True)
    ]
    model_pkl_paths.sort()
    return model_pkl_paths

def extract_summary_metrics(summary_json_path):
    """Build one wandb table per predicted file and metric from a ``summary.json``.

    The per-file entries are read from ``summary['results']['all']``. Every entry
    needs a ``'test'`` key (path of the predicted file) and one key per class id
    (numeric string) that maps to a dict containing the metrics 'Dice', 'Jaccard',
    'Precision' and 'Recall'. The class ids are taken from the first entry.

    Parameters
    ----------
    summary_json_path : str or os.PathLike
        Path of the summary file to read.

    Returns
    -------
    dict
        Maps ``'prediction/<filename>/<metric>'`` to a ``wandb.Table`` with the columns
        ``"class id"`` and ``<metric>``. Empty if the summary contains no file entries.
    """
    with open(summary_json_path, 'r') as json_file:
        summary = json.load(json_file)

    file_results = summary['results']['all']
    if not file_results:
        # Nothing was evaluated; file_results[0] below would raise an IndexError
        return {}

    class_numstrings = [key for key in file_results[0].keys() if key.isnumeric()]
    metrics = ['Dice', 'Jaccard', 'Precision', 'Recall']

    table_dict = {}
    # "prediction/filename/metric/" per class
    for file_entry in file_results:
        f_name = file_entry['test'].split("/")[-1]
        for met in metrics:
            data = [[int(c_numstr), file_entry[c_numstr][met]] for c_numstr in class_numstrings]
            table_dict[f'prediction/{f_name}/{met}'] = wandb.Table(data=data, columns=["class id", met])

    return table_dict

def get_best_model_path(task_name, all_model_paths):
    """Pick the preferred checkpoint of a task from ``all_model_paths``.

    A path belongs to the task if ``task_name`` occurs in its string form.
    Preference order is final checkpoint, then best checkpoint, then latest
    checkpoint; within one kind the first path in input order is returned.
    ``None`` is returned if the task has none of these checkpoints.
    """
    def _matching(checkpoint_filename):
        # All paths of this task that contain the given checkpoint file name
        return [
            candidate for candidate in all_model_paths
            if checkpoint_filename in str(candidate) and task_name in str(candidate)
        ]

    checkpoints_by_preference = (
        _matching("model_final_checkpoint.model.pkl"),
        _matching("model_best_checkpoint.model.pkl"),
        _matching("model_latest_checkpoint.model.pkl"),
    )
    for checkpoints in checkpoints_by_preference:
        if checkpoints:
            return checkpoints[0]
    return None

def find_best_fitting_model(summary_json_path, all_trained_model_paths):
    """Find the model whose path is most similar to the path of a summary file.

    Similarity is ``difflib.SequenceMatcher(...).ratio()`` of the two path strings.
    If several models reach the same similarity, the first one wins.

    Parameters
    ----------
    summary_json_path : str or os.PathLike
        Path of the ``summary.json`` a model is searched for.
    all_trained_model_paths : iterable of str or os.PathLike
        Candidate model paths.

    Returns
    -------
    tuple
        ``(best_model_path, index)`` where ``index`` is the position of the model
        in ``all_trained_model_paths``.

    Raises
    ------
    ValueError
        If there are no candidates or none of them has a similarity above zero.
    """
    from difflib import SequenceMatcher

    # SequenceMatcher caches its analysis of the second sequence, so the summary path
    # is set once and only the first sequence is swapped for every candidate.
    matcher = SequenceMatcher(None, "", str(summary_json_path))

    best_model_path = None
    best_index = None
    best_similarity = 0.0

    for index, model_path in enumerate(all_trained_model_paths):
        matcher.set_seq1(str(model_path))
        similarity = matcher.ratio()

        if similarity > best_similarity:
            best_model_path = model_path
            best_index = index
            best_similarity = similarity

    if best_model_path is None:
        raise ValueError(f"No trained model path is similar to summary path '{summary_json_path}'.")

    return best_model_path, best_index

# 1. Check available models (and prediction scores) and write selection file for user

In [None]:

# Get existing hashes of online project
api = wandb.Api()
# Query the project only once: hashes and notes then describe the same runs in the
# same order (two separate queries could differ if a run is added in between).
existing_runs = list(api.runs(f"{WANDB_TEAM_NAME}/{WANDB_PROJECT}"))
existing_model_hashes = [run.config.get('model_pkl_hash', "") for run in existing_runs]
existing_notes = [run.notes for run in existing_runs]

# hash -> notes of the first run with that hash
notes_by_hash = {}
for _hash, _notes in zip(existing_model_hashes, existing_notes):
    notes_by_hash.setdefault(_hash, _notes)


def _as_selection_field(text):
    """Make ``text`` usable as one field of a ';'-separated selection file line.

    Line breaks would split the entry into several lines and ';' would add fields,
    both of which make the entry unreadable for the upload step.
    """
    return ' '.join(text.splitlines()).replace(';', ',')


all_trained_model_paths = get_model_pkl_paths(TRAINED_MODELS_PATH)
all_summary_json_paths = get_summary_json_paths(TRAINED_MODELS_PATH, INFERENCE_PATH)

# One line per found model: "<SKIP|EXISTS>; <description>; MODEL <idx>; <path>"
with open(SELECTION_FILEPATH, 'w') as selection_file:
    selection_file.write(f"# Change SKIP/EXISTS entry to {UPLOAD_KEY} for all models that shall be uploaded\n")
    for m_idx, m_path in enumerate(all_trained_model_paths):
        sha = sha256sum(m_path)
        if sha in notes_by_hash:
            # Model was uploaded before: show the notes of the existing run
            existing_note = notes_by_hash[sha]
            description = _as_selection_field(existing_note) if existing_note is not None else DESCRIPTION_PLACEHOLDER
            file_key = "EXISTS"
        else:
            description = DESCRIPTION_PLACEHOLDER
            file_key = "SKIP"

        selection_file.write(f"{file_key};\t{description};\t\tMODEL {m_idx:4d};\t\t{m_path}\n")

if COLLECT_PREDICTIONS:
    with open(SELECTION_FILEPATH, 'a') as mapping_file:
        mapping_file.write("\n")
        mapping_file.write("\n")
        mapping_file.write("# Remove ? before mapping to upload summary data alongside model. Modify mapping numbers if needed.\n")

        # Proposed (deactivated) mapping of every summary to the model with the most similar path
        for summary_idx, _path in enumerate(all_summary_json_paths):
            _, model_idx = find_best_fitting_model(_path, all_trained_model_paths)
            mapping_file.write(f"?MAP SUMMARY{summary_idx:4d} -> MODEL {model_idx:4d}\n")
        mapping_file.write("\n")

        # Legend that resolves the summary numbers to their files
        for summary_idx, _path in enumerate(all_summary_json_paths):
            mapping_file.write(f"# SUMMARY{summary_idx:4d}; {_path}\n")
        mapping_file.write("\n")

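# 2. Now select all models to be uploaded in file "./nnunet_wandb_selection.txt" (change SKIP/EXISTS -> UPLOAD), then remove the leading "?" from each prediction-score mapping that shall be uploaded and adjust its model number if needed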

# 3. Run upload

In [None]:
def shrink_double_whitespace(whitespace_string):
    """Collapse every run of whitespace into one space and drop leading/trailing whitespace."""
    tokens = whitespace_string.split()
    return ' '.join(tokens)
# Lookup tables parsed from the selection file. All ids are whitespace-normalized
# ("MODEL 3", "SUMMARY 12") so that the padded numbers of the file compare equal.
summary_mapping_dict = {}  # model id -> summary id, from activated "MAP ..." lines
summary_path_dict = {}     # summary id -> path of the summary.json, from "# SUMMARY ..." lines

# Read the file once; all passes below work on the same content
with open(SELECTION_FILEPATH, 'r') as selection_file:
    selection_lines = selection_file.readlines()

# Check mappings (lines that still start with "?MAP" are deactivated and ignored)
for line in selection_lines:
    if not line.startswith("MAP"):
        continue
    mapping_string, *_ = line.split(';')
    mapping_match = re.match(r"MAP (SUMMARY\s*\d{1,4}) -> (MODEL\s*\d{1,4})", mapping_string)
    if mapping_match is None:
        warnings.warn(f"Error in entry {line}")
        continue
    summary_id, model_id = mapping_match.groups()
    summary_mapping_dict[shrink_double_whitespace(model_id)] = shrink_double_whitespace(summary_id)

# Check summaries
for line in selection_lines:
    if not line.startswith("# SUMMARY"):
        continue
    # Split at the first ';' only; everything behind it is the path and is kept as is
    summary_id, _, json_path = line.partition(';')
    summary_match = re.match(r".*?(SUMMARY\s*\d{1,4})", summary_id)
    if summary_match is None or not json_path.strip():
        warnings.warn(f"Error in entry {line}")
        continue
    summary_path_dict[shrink_double_whitespace(summary_match.groups()[0])] = json_path.strip()

# Check upload selections
for line in selection_lines:
    # Check all file lines with model paths to upload to wandb
    if not line.startswith(UPLOAD_KEY):
        continue

    split_line = line.split(";")
    if len(split_line) != 4:
        warnings.warn(f"Error in entry {line}")
        continue

    _, user_description, model_id, model_pkl_path = split_line
    user_description = user_description.strip()
    model_id = shrink_double_whitespace(model_id)
    model_pkl_path = model_pkl_path.strip()

    # Prediction metrics are optional: only models with an activated mapping get tables
    summary_tables = {}
    mapped_summary_id = summary_mapping_dict.get(model_id)
    if mapped_summary_id is not None:
        json_summary_path = summary_path_dict.get(mapped_summary_id)
        if json_summary_path is None:
            warnings.warn(f"{model_id} is mapped to unknown {mapped_summary_id}. Uploading without summary data.")
        else:
            summary_tables = extract_summary_metrics(json_summary_path)

    model_infos = retrieve_model_infos(model_pkl_path)
    if model_infos[0] is None:
        # The .model.pkl, the .model or the debug.json file does not exist (anymore)
        warnings.warn(f"Model files for entry {model_pkl_path} are incomplete. Skipping upload.")
        continue
    model_basepath, config_dict, plot_data, metadata = model_infos

    long_name = \
        f"{config_dict['task_name']}" \
        f"__{config_dict['configuration']}" \
        f"__{config_dict['trainer_name']}" \
        f"__fold.{config_dict['fold']}" \
        f"__type.{config_dict['model_type']}" \
        f"__epoch.{config_dict['last_epoch']}" \
        f"__user.{config_dict['uploading_user']}"
    long_name = long_name[:128] # Limit to wandb max artifact name size of 128

    description = "" if user_description == DESCRIPTION_PLACEHOLDER else user_description
    dataset_json_filepath = Path(RAW_DATA_BASE_PATH, f"nnUNet_raw_data/{config_dict['task_name']}/dataset.json")

    # Upload to the same team/project that was searched for already existing models;
    # otherwise uploaded models would never be marked as EXISTS in the selection file.
    # The context manager finishes the run, no explicit wandb.finish() is needed.
    with wandb.init(name=long_name, entity=WANDB_TEAM_NAME, project=WANDB_PROJECT,
                    job_type="train", config=config_dict, notes=description) as run:

        # Replay the training curves stored in the checkpoint, one step per epoch
        for epoch_idx, (loss_train, loss_val, eval_dice) in \
            enumerate(zip(plot_data['training/loss'], plot_data['validation/loss'], plot_data['eval/dice'])):

            run.log({'training/loss': loss_train, 'validation/loss':loss_val, 'eval/dice':eval_dice}, step=epoch_idx)

        for wandb_path, table in summary_tables.items():
            title = " ".join(wandb_path.split("/")[-2:])
            # The key is truncated because wandb stores the table as an artifact whose
            # name (run prefix + key + "_table") cannot be longer than 128.
            # NOTE(review): keys that only differ behind character 111 collide -- confirm acceptable.
            run.log({wandb_path[:111]: wandb.plot.bar(table, *table.columns, title=title)})

        model_artifact = wandb.Artifact(
            name=long_name, type="model",
            description=description,
            metadata=metadata
        )

        if UPLOAD_MODEL_FILES:
            model_artifact.add_file(model_pkl_path)
            # The checkpoint path is only known from the collected model infos
            model_artifact.add_file(str(config_dict['model_path']))

        for log_filepath in [_path for _path in model_basepath.iterdir() if "training_log" in str(_path)]:
            model_artifact.add_file(str(log_filepath))

        if dataset_json_filepath.is_file():
            model_artifact.add_file(str(dataset_json_filepath))

        run.log_artifact(model_artifact)