# Wrap-up codes for Occlusion Training

In [1]:
from collections import defaultdict

import wandb

# Open target RoI indices
with open("assets/dkt_indices.txt", mode="r") as f:
    gt_rois = {int(roi.strip()) for roi in f.readlines()}

api = wandb.Api()
runs = api.runs(path="1pha/brain-age",
                filters={"config.dataloader.dataset._target_": "sage.data.mask.UKB_MaskDataset"})
len(runs)

83

In [2]:
runs_dict = defaultdict(dict)
runs_idx = defaultdict(set)
for run in runs:
    state = run.state
    name: str = run.name.split(" ")[1]
    if not name.isnumeric():
        # Skip outlier run name that does not follow `M XXXX | SEED`
        continue
    idx = int(name)
    if state == "finished":
        runs_dict["success"][idx] = run
        runs_idx["success"].add(idx)
    elif state == "failed":
        runs_dict["fail"][idx] = run
        runs_idx["fail"].add(idx)
    else:
        runs_dict["etc"][idx] = (run, state)
        runs_idx["etc"].add(idx)

In [3]:
len(runs_idx["success"]), len(runs_idx["fail"]), len(runs_idx["etc"])

(66, 11, 3)

In [4]:
assert gt_rois.issuperset(runs_idx["success"])
leftover = gt_rois - runs_idx["success"]

In [5]:
len(leftover)

34

In [6]:
true_fail = {idx for idx in runs_idx["fail"] if idx not in runs_idx["success"]}
print(f"Actual failed runs: {len(true_fail)}")
for fail in true_fail:
    print(fail)

Actual failed runs: 9
1025
2018
2024
1002
77
2030
1009
54
24


In [7]:
# Runs checked manually that has final test performance 
manual_check_success: set = {1025, 2018, 2024, 1002, 77, 54, 24}

# Was not able to get test performance but has final checkpoint
manual_check_ckpt: set = {2030, 1009}

# Sanity check for empty runs
assert len(true_fail - (manual_check_ckpt | manual_check_success)) == 0
assert len((manual_check_ckpt | manual_check_success) - true_fail) == 0

In [8]:
leftover = leftover - true_fail
len(leftover)

25

In [9]:
leftover & runs_idx["etc"]

{2017}

Mask idx 2017 turned out to be crashed. Crahsed runs should be re-batched.

In [10]:
leftover

{2,
 46,
 47,
 49,
 50,
 251,
 252,
 253,
 254,
 255,
 1010,
 1011,
 1012,
 1013,
 1014,
 1015,
 1016,
 2005,
 2006,
 2007,
 2008,
 2017,
 2031,
 2034,
 2035}

Split runs and re-batch

In [12]:
import random

def split_into_chunks(data_set, num_chunks):
    # Convert set to list for shuffling
    data_list = list(data_set)
    
    # Calculate the size of each chunk
    chunk_size = len(data_list) // num_chunks
    remainder = len(data_list) % num_chunks
    
    # Split the list into chunks
    chunks = []
    start = 0
    for i in range(num_chunks):
        chunk_length = chunk_size + (1 if i < remainder else 0)
        chunks.append(data_list[start:start + chunk_length])
        start += chunk_length
        
    flattened_chunks = [item for sublist in chunks for item in sublist]
    assert set(flattened_chunks) == data_set, "Chunks do not contain all elements of the data set"
    return chunks


machines = ["185-0", "185-1", "245-0", "245-1"]
chunks = split_into_chunks(data_set=leftover, num_chunks=4)

In [13]:
chunks

[[2, 46, 47, 49, 50, 2005, 2006],
 [2007, 2008, 2017, 2031, 2034, 2035],
 [1010, 1011, 1012, 1013, 1016, 1014],
 [1015, 251, 252, 253, 254, 255]]

In [14]:
for chunk, machine in zip(chunks, machines):
    with open(f"assets/dkt_leftover_{machine}.txt", mode="w") as f:
        for roi in chunk:
            f.write(f"{roi}\n")