D21:

Train on:
- 34k undistorted (treat as source distortion - d=0) - from the training set

- 34k CT ring artifact (treat as target distortion - d=1) - from the training set

Validate on:

- 6k undistorted - from the new validation set (measure performance on both branches)

Separately, also track:

- 6k CT ring artifact - from the new validation set (measure performance on both branches)

Test on:

- 7k undistorted - from the new test set

- 7k rotated 90 degrees - from the new test set

- 7k uniform noise - from the new test set

- 7k CT ring artifact - from the new test set

With any additional time, please also test:

- 34k undistorted - from the training set

- 34k rotated 90 degrees - from the training set

- 34k uniform noise - from the training set

- 34k CT ring artifact - from the training set

In [1]:
import pandas as pd
import numpy as np
import os

# data_dir = 'ct-distortion'

# for filename in os.listdir(data_dir):
#     if filename.endswith('.csv'):
#         temp_data = np.loadtxt(os.path.join(data_dir, filename), delimiter=',')

In [2]:
def combine_npzs(data_dir):
    """
    Merge the arrays of every chunked ``.npz`` archive found in ``data_dir``.

    A file is treated as a chunk when one of the ordinal tokens in ``order``
    ('first', 'second', ..., 'tenth', 'last') appears as an
    underscore-separated part of its name. Chunks are processed in that
    ordinal order so the rows of the result follow the chunk order.

    Args:
        data_dir: Directory containing the ``.npz`` chunk files.

    Returns:
        dict mapping each array name stored in the archives to the
        concatenation (along axis 0) of that array over all chunks.
        The dict is empty when no matching file exists.
    """
    order = ['first', 'second', 'third', 'fourth', 'fifth', 'sixth', 'seventh', 'eighth', 'ninth', 'tenth', 'last']

    # List the directory once. Sorting makes the result reproducible when
    # several files share the same ordinal (os.listdir order is arbitrary).
    npz_files = sorted(f for f in os.listdir(data_dir) if f.endswith('.npz'))

    # Collect the pieces per key and concatenate once at the end instead of
    # re-copying the growing array for every chunk.
    pieces = {}
    for pos in order:
        for filename in npz_files:
            if pos not in filename.split('_'):
                continue
            print(f"Processing file: {filename}")
            # The context manager closes the archive's file handle; the
            # arrays are fully read into memory by data[key] before that.
            with np.load(os.path.join(data_dir, filename)) as data:
                for key in data.files:
                    pieces.setdefault(key, []).append(data[key])

    # A key seen in a single chunk is returned as loaded (no concatenation),
    # matching what happens when only one file contributes to it.
    return {
        key: parts[0] if len(parts) == 1 else np.concatenate(parts, axis=0)
        for key, parts in pieces.items()
    }

In [3]:
# train_data = combine_npzs('ct-distortion/train')
# val_data = combine_npzs('ct-distortion/val')
# test_data = combine_npzs('ct-distortion/test')

In [4]:
# train_data.keys()

In [5]:
import os

def import_data(directory, save_path=None, save=False):
    """
    Load every ``.npz`` archive in ``directory`` and concatenate them per key.

    Files are read in sorted filename order so the row order of the result
    is reproducible. The set of keys is taken from the first file; every
    other file must provide the same keys.

    Args:
        directory: Directory containing the ``.npz`` files.
        save_path: Destination for the combined archive. When omitted and
            ``save`` is true, ``datasets/{directory}_concatenated_data.npz``
            is used.
        save: When true, write the combined arrays with ``np.savez``.

    Returns:
        dict mapping each key to its arrays concatenated along axis 0.

    Raises:
        FileNotFoundError: If ``directory`` contains no ``.npz`` file.
        KeyError: If a later file lacks a key present in the first file.
    """
    filenames = sorted(f for f in os.listdir(directory) if f.endswith('.npz'))
    if not filenames:
        raise FileNotFoundError(f"No .npz files found in {directory!r}")

    keys = None
    pieces = {}
    for filename in filenames:
        # Close each archive as soon as its arrays have been read.
        with np.load(os.path.join(directory, filename)) as archive:
            if keys is None:
                keys = list(archive.keys())
            for key in keys:
                pieces.setdefault(key, []).append(archive[key])

    # concatenate the data from all files
    all_data = {key: np.concatenate(pieces[key], axis=0) for key in keys}

    # check the shape of the concatenated data
    for key, value in all_data.items():
        print(f"{key}: {value.shape}")

    if save:
        if save_path is None:
            save_path = f'datasets/{directory}_concatenated_data.npz'
        # The default path embeds `directory`, which may itself contain
        # separators, so make sure the target folder exists before writing.
        if isinstance(save_path, (str, os.PathLike)):
            parent = os.path.dirname(save_path)
            if parent:
                os.makedirs(parent, exist_ok=True)
        np.savez(save_path, **all_data)
        print(f"Data saved to {save_path}")

    return all_data

In [6]:
def combine_data(orig_data, new_data):
    """
    Return a copy of ``orig_data`` extended with the ring-artifact arrays.

    The result holds every entry of ``orig_data`` plus
    ``'Ring_Artifact_v1'`` (taken from ``new_data['Ring_Artifact_v1']``) and
    ``'ring_labels'`` (taken from ``new_data['label']``). Neither input is
    modified.
    """
    merged = dict(orig_data)
    merged.update({
        'Ring_Artifact_v1': new_data['Ring_Artifact_v1'],
        'ring_labels': new_data['label'],
    })
    return merged


In [7]:
# orig_train_data = np.load('datasets/training.npz')
# orig_val_data = np.load('datasets/validation.npz')
# orig_test_data = np.load('datasets/test.npz')

# combined_train_data = combine_data(orig_train_data, train_data)
# combined_val_data = combine_data(orig_val_data, val_data)
# combined_test_data = combine_data(orig_test_data, test_data)

In [8]:
# orig_data = np.load('datasets/training.npz')
# print(orig_data.files)
# print(len(orig_data['original']))

In [9]:
# data = dict(orig_data)

# data['Ring_Artifact_v1'] = train_data['Ring_Artifact_v1']
# data['ring_labels'] = train_data['label']

# data.keys()

In [10]:
def normalize_image(image, mean=0.5, std=0.5):
    """
    Shift ``image`` by ``mean`` and scale it by ``std``.

    Returns ``(image - mean) / std``; with the defaults a value of 0.5 maps
    to 0 and a value of 1 maps to 1.
    """
    centered = image - mean
    return centered / std

def normalize_images(images, mean=0.5, std=0.5):
    """
    Apply ``normalize_image`` to every element of ``images``.

    Returns a new list with one normalized entry per input image, in the
    same order.
    """
    normalized = []
    for img in images:
        normalized.append(normalize_image(img, mean, std))
    return normalized

In [11]:
from branched_resnet_v2 import CustomImageDataset, dataset_load

def preprocess_data(data, distortions, include_original=True, save_data = False, save_path=None):
    """
    Build a normalized, shuffled two-label dataset from a dict of image arrays.

    Images are stacked in this order: the undistorted images (if
    ``include_original``), each entry of ``distortions`` in the given order
    (except the ring artifact), and finally the ring-artifact images if
    ``'Ring_Artifact_v1'`` was requested. Undistorted images get domain
    label 0 and every distorted image gets domain label 1. The stacked
    arrays are then shuffled with a fixed seed.

    Args:
        data: Mapping of array name to array. The first key is used as the
            undistorted images and the second key as the class labels
            (assumed layout: 'original', 'label', ... — confirm for new data).
            ``'Ring_Artifact_v1'`` and ``'ring_labels'`` are read when the
            ring artifact is requested.
        distortions: Names of the distortion arrays to include. The list is
            not modified.
        include_original: Whether to include the undistorted images.
        save_data: When true, also write the shuffled arrays to ``save_path``
            with ``np.savez_compressed`` (keys ``images``, ``labels1``,
            ``labels2``).
        save_path: Destination file; required when ``save_data`` is true.

    Returns:
        CustomImageDataset built from the shuffled images, class labels
        (``labels1``) and domain labels (``labels2``).

    Raises:
        ValueError: If ``save_data`` is true without ``save_path``, if no
            image source was selected, or if the image and label counts
            disagree.
    """
    # Fail before doing any heavy work.
    if save_data and save_path is None:
        raise ValueError("save_path must be specified if save_data is True")

    keys = list(data.keys())

    # Work on a copy so the caller's list is left untouched.
    plain_distortions = list(distortions)
    ring_flag = 'Ring_Artifact_v1' in plain_distortions
    if ring_flag:
        plain_distortions.remove('Ring_Artifact_v1')

    labels = np.asarray(data[keys[1]])

    image_parts = []
    class_parts = []
    domain_parts = []

    # normalize_image is elementwise, so applying it to the whole array gives
    # the same values as normalizing image by image, without a Python loop.
    if include_original:
        image_parts.append(normalize_image(np.asarray(data[keys[0]])))
        class_parts.append(labels)
        domain_parts.append(np.zeros_like(labels))

    for distortion in plain_distortions:
        image_parts.append(normalize_image(np.asarray(data[distortion])))
        class_parts.append(labels)
        domain_parts.append(np.ones_like(labels))

    if ring_flag:
        # The ring-artifact images carry their own labels, so size the
        # domain labels from them rather than from the original labels.
        ring_labels = np.asarray(data['ring_labels'])
        image_parts.append(normalize_image(np.asarray(data['Ring_Artifact_v1'])))
        class_parts.append(ring_labels)
        domain_parts.append(np.ones_like(ring_labels, dtype=labels.dtype))

    if not image_parts:
        raise ValueError("No images selected: set include_original=True or pass at least one distortion")

    concatenated_images = np.concatenate(image_parts, axis=0)
    expanded_labels = np.concatenate(class_parts, axis=0)
    domain_labels = np.concatenate(domain_parts, axis=0)

    print(len(concatenated_images), len(expanded_labels), len(domain_labels))
    # Explicit raise instead of assert so the check survives `python -O`.
    if not (len(concatenated_images) == len(expanded_labels) == len(domain_labels)):
        raise ValueError("Dataset length mismatch!")

    # Shuffle the concatenated images and labels.
    # NOTE: this reseeds NumPy's global RNG, as before, so the permutation
    # (and any later global-RNG draws) stay reproducible.
    seed = 42
    np.random.seed(seed)
    shuffled_indices = np.random.permutation(len(concatenated_images))
    concatenated_images = concatenated_images[shuffled_indices]
    expanded_labels = expanded_labels[shuffled_indices]
    domain_labels = domain_labels[shuffled_indices]

    if save_data:
        np.savez_compressed(save_path, images=concatenated_images, labels1=expanded_labels, labels2=domain_labels)

    dataset = CustomImageDataset(images=concatenated_images, labels1=expanded_labels, labels2=domain_labels)

    return dataset




In [12]:
# Recipe for building the train/validation datasets from the combined arrays
# (undistorted images plus the CT ring artifact).
# NOTE(review): the recipe uses save_data=False and a save_path outside
# 'datasets/D21_preprocessed/', so the files loaded below were presumably
# produced by a separate run - confirm before regenerating.
# train_ds = preprocess_data(combined_train_data, distortions=['Ring_Artifact_v1'], include_original=True, save_data=False, save_path='datasets/D21_processed_train_data.npz')
# val_ds = preprocess_data(combined_val_data, distortions=['Ring_Artifact_v1'], include_original=True, save_data=False, save_path='datasets/D21_processed_val_data.npz')

# Load the already-preprocessed train/validation datasets from disk.
train_ds = dataset_load('datasets/D21_preprocessed/D21_processed_train_data.npz')
val_ds = dataset_load('datasets/D21_preprocessed/D21_processed_val_data.npz')

In [13]:
import branched_resnet_v2 as br 
from branched_resnet_v2 import CustomImageDataset
from transformers import Trainer, TrainingArguments, PreTrainedModel, ResNetConfig

# Build the two-branch model: 11 classes on branch 1, 2 classes on branch 2.
config = ResNetConfig()
model = br.ResNetForMultiLabel(
    config=config,
    num_d1_classes=11,
    num_d2_classes=2,
    lamb=0,
)

# Sanity check: show the name, mean value and requires_grad flag of the
# first parameter only.
first_param = next(iter(model.named_parameters()), None)
if first_param is not None:
    name, param = first_param
    print(name, param.data.mean(), param.requires_grad)

resnet.embedder.embedder.convolution.weight tensor(-6.0013e-05) True


In [14]:
# Train the model (train=True). The returned `trainer` is reused by the
# evaluation cells further down, so this cell must run before them.
trainer = br.train_model(train_dataset=train_ds, eval_dataset= val_ds, model=model, output_dir= "./D21_Experiment_0808", num_epochs=50, batch_size=32, train=True)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33msamuelsavine[0m ([33msamuelsavine-johns-hopkins-university[0m). Use [1m`wandb login --relogin`[0m to force relogin


Could not estimate the number of tokens of the input, floating-point operations will not be computed


Epoch,Training Loss,Validation Loss,Accuracy Branch1,Precision Branch1,Recall Branch1,F1 Branch1,Accuracy Branch2,Precision Branch2,Recall Branch2,F1 Branch2,Lambda
1,1.0456,0.946216,0.842705,0.741815,0.773954,0.74966,0.50647,0.50651,0.50647,0.505717,0.0
2,1.2679,0.78806,0.946772,0.941833,0.93836,0.938601,0.508935,0.508952,0.508935,0.508715,0.099668
3,2.6548,2.265007,0.958558,0.957685,0.9511,0.951867,0.5,0.25,0.5,0.333333,0.197375
4,3.4774,0.883536,0.9611,0.9557,0.961567,0.957984,0.50208,0.502749,0.50208,0.469826,0.291313
5,5.3014,0.9393,0.964335,0.959497,0.958933,0.958581,0.498845,0.498584,0.498845,0.474686,0.379949
6,7.0502,1.772734,0.810815,0.916579,0.862322,0.835682,0.5,0.25,0.5,0.333333,0.462117
7,7.1363,0.852029,0.935372,0.948993,0.918821,0.926842,0.49091,0.490209,0.49091,0.48163,0.53705
8,9.0351,2.986841,0.92428,0.914782,0.87347,0.854905,0.5,0.25,0.5,0.333333,0.604368
9,10.0755,3.801847,0.691804,0.790656,0.790473,0.751124,0.5,0.25,0.5,0.333333,0.664037
10,9.8623,4.804426,0.699661,0.850411,0.702916,0.693628,0.5,0.25,0.5,0.333333,0.716298


In [15]:
# Load the model from the saved checkpoint
from safetensors.torch import load_file

# model_path = './D21_Experiment_0801/checkpoint-108050/model.safetensors'

# state_dict = load_file(model_path)
# model.load_state_dict(state_dict)

# trainer = br.train_model(train_dataset=train_ds, eval_dataset= val_ds, model=model, output_dir= "./D20_Experiment", num_epochs=50, batch_size=32, train=False)

In [16]:
# Clear up memory: drop this notebook's references to the train/validation
# data before loading the test arrays.
# NOTE(review): `trainer` may still hold its own references to the datasets -
# confirm whether this actually frees them.
combined_train_data = None
combined_val_data = None
train_ds = None
val_ds = None

# Ring-artifact test chunks, merged in chunk order.
test_data = combine_npzs('ct-distortion/test')

# combine_data copies every array out of the archive (via dict()), so the
# file can be closed as soon as the combined dict has been built.
with np.load('datasets/test.npz') as orig_test_data:
    combined_test_data = combine_data(orig_test_data, test_data)

print(combined_test_data.keys())


Processing file: test_subset_first_5000_RingArtifactv1_images.npz
Processing file: test_subset_second_5000_RingArtifactv1_images.npz
Processing file: test_subset_third_5000_RingArtifactv1_images.npz
Processing file: test_subset_last_2778_RingArtifactv1_images.npz
dict_keys(['original', 'label', 'Uniform_Noise', 'Rotate_90deg', 'Ring_Artifact_v1', 'ring_labels'])


In [17]:
# Test condition: undistorted test images only.
# Recipe that builds this dataset from combined_test_data.
# NOTE(review): its save_path is not inside 'datasets/D21_preprocessed/',
# where the file is loaded from below - confirm the file was moved there.
# baseline = preprocess_data(combined_test_data, distortions=[], include_original=True, save_data=True, save_path='datasets/D21_processed_undistorted_test_data.npz')
baseline = dataset_load('datasets/D21_preprocessed/D21_processed_undistorted_test_data.npz')
# Kept as the last expression so the returned metrics dict is displayed.
trainer.evaluate(baseline)
#baseline = None

{'eval_loss': 1.0989922285079956,
 'eval_accuracy_branch1': 0.878445269434132,
 'eval_precision_branch1': 0.869077424493261,
 'eval_recall_branch1': 0.8625485199907562,
 'eval_f1_branch1': 0.8609263276413132,
 'eval_accuracy_branch2': 0.4185510181122736,
 'eval_precision_branch2': 0.5,
 'eval_recall_branch2': 0.2092755090561368,
 'eval_f1_branch2': 0.2950553154367739,
 'eval_lambda': 0.9998891029505543,
 'eval_runtime': 142.8588,
 'eval_samples_per_second': 124.445,
 'eval_steps_per_second': 15.561,
 'epoch': 50.0}

In [18]:
# Test condition: undistorted images plus all three distortions
# (uniform noise, 90-degree rotation, CT ring artifact).
# Recipe that builds this dataset from combined_test_data.
# NOTE(review): its save_path is not inside 'datasets/D21_preprocessed/',
# where the file is loaded from below - confirm the file was moved there.
# test_ds1 = preprocess_data(combined_test_data, distortions=['Uniform_Noise', 'Rotate_90deg', 'Ring_Artifact_v1'], include_original=True, save_data=True, save_path='datasets/D21_processed_all_distortions_test_data.npz')
test_ds1 = dataset_load('datasets/D21_preprocessed/D21_processed_all_distortions_test_data.npz')
# Kept as the last expression so the returned metrics dict is displayed.
trainer.evaluate(test_ds1)
#test_ds1 = None

{'eval_loss': 1.061236023902893,
 'eval_accuracy_branch1': 0.6571183485206435,
 'eval_precision_branch1': 0.6613203942713818,
 'eval_recall_branch1': 0.6394690071088301,
 'eval_f1_branch1': 0.635435500617461,
 'eval_accuracy_branch2': 0.5964534818314771,
 'eval_precision_branch2': 0.5301326129212903,
 'eval_recall_branch2': 0.5371526605917426,
 'eval_f1_branch2': 0.5252914505024343,
 'eval_lambda': 0.9998891029505543,
 'eval_runtime': 357.4699,
 'eval_samples_per_second': 198.931,
 'eval_steps_per_second': 24.866,
 'epoch': 50.0}

In [19]:
# Test condition: uniform noise and 90-degree rotation only (no undistorted
# images, no ring artifact).
# Recipe that builds this dataset from combined_test_data.
# NOTE(review): its save_path is not inside 'datasets/D21_preprocessed/',
# where the file is loaded from below - confirm the file was moved there.
# test_ds2 = preprocess_data(combined_test_data, distortions=['Uniform_Noise', 'Rotate_90deg'], include_original=False, save_data=True, save_path='datasets/D21_processed_uniform_noise_rotate_test_data.npz')
test_ds2 = dataset_load('datasets/D21_preprocessed/D21_processed_uniform_noise_rotate_test_data.npz')
# Kept as the last expression so the returned metrics dict is displayed.
trainer.evaluate(test_ds2)
#test_ds2 = None

{'eval_loss': 0.6401367783546448,
 'eval_accuracy_branch1': 0.43972887838902014,
 'eval_precision_branch1': 0.45005733959837074,
 'eval_recall_branch1': 0.4219098346324429,
 'eval_f1_branch1': 0.4082836851814925,
 'eval_accuracy_branch2': 0.6851726853414333,
 'eval_precision_branch2': 0.5,
 'eval_recall_branch2': 0.34258634267071664,
 'eval_f1_branch2': 0.40658900497346373,
 'eval_lambda': 0.9998891029505543,
 'eval_runtime': 112.1134,
 'eval_samples_per_second': 317.143,
 'eval_steps_per_second': 39.647,
 'epoch': 50.0}

In [20]:
# Test condition: CT ring artifact images only.
# The commented recipe below builds this dataset by hand rather than through
# preprocess_data: it normalizes each ring image, uses 'ring_labels' as the
# class labels and marks every sample with domain label 1.
# d3images = []
# for image in combined_test_data['Ring_Artifact_v1']:
#     d3images.append(normalize_image(image))

# d3labels = combined_test_data['ring_labels']
# d3domain_labels = np.ones_like(d3labels)

# # Save the dataset as an npz file

# np.savez_compressed('datasets/D21_processed_ring_artifact_test_data.npz', images=d3images, labels1=d3labels, labels2=d3domain_labels)

# test_ds3 = br.CustomImageDataset(images=d3images, labels1=d3labels, labels2=d3domain_labels)
# NOTE(review): the recipe saves outside 'datasets/D21_preprocessed/', where
# the file is loaded from below - confirm the file was moved there.
test_ds3 = dataset_load('datasets/D21_preprocessed/D21_processed_ring_artifact_test_data.npz')

# Kept as the last expression so the returned metrics dict is displayed.
trainer.evaluate(test_ds3)


{'eval_loss': 0.6650639176368713,
 'eval_accuracy_branch1': 0.8705703678704017,
 'eval_precision_branch1': 0.8635510939073875,
 'eval_recall_branch1': 0.8515078391796789,
 'eval_f1_branch1': 0.8510729350588605,
 'eval_accuracy_branch2': 0.5969175385307683,
 'eval_precision_branch2': 0.5,
 'eval_recall_branch2': 0.29845876926538417,
 'eval_f1_branch2': 0.37379358929200424,
 'eval_lambda': 0.9998891029505543,
 'eval_runtime': 210.5483,
 'eval_samples_per_second': 84.437,
 'eval_steps_per_second': 10.558,
 'epoch': 50.0}

In [21]:
# Test condition: uniform-noise images only (no undistorted images).
# Recipe that builds this dataset from combined_test_data.
# NOTE(review): its save_path is not inside 'datasets/D21_preprocessed/',
# where the file is loaded from below - confirm the file was moved there.
# test_ds4 = preprocess_data(combined_test_data, distortions=['Uniform_Noise'], include_original=False, save_data=True, save_path='datasets/D21_processed_uniform_noise_test_data.npz')
test_ds4 = dataset_load('datasets/D21_preprocessed/D21_processed_uniform_noise_test_data.npz')
# Kept as the last expression so the returned metrics dict is displayed.
trainer.evaluate(test_ds4)
#test_ds4 = None

{'eval_loss': 0.6523623466491699,
 'eval_accuracy_branch1': 0.45995050061874226,
 'eval_precision_branch1': 0.500974633482635,
 'eval_recall_branch1': 0.4372007485889999,
 'eval_f1_branch1': 0.3847833983854388,
 'eval_accuracy_branch2': 0.6820789740128248,
 'eval_precision_branch2': 0.5,
 'eval_recall_branch2': 0.3410394870064124,
 'eval_f1_branch2': 0.4054975922953451,
 'eval_lambda': 0.9998891029505543,
 'eval_runtime': 58.1566,
 'eval_samples_per_second': 305.692,
 'eval_steps_per_second': 38.224,
 'epoch': 50.0}

In [22]:
# Test condition: images rotated by 90 degrees only (no undistorted images).
# Recipe that builds this dataset from combined_test_data.
# NOTE(review): its save_path is not inside 'datasets/D21_preprocessed/',
# where the file is loaded from below - confirm the file was moved there.
# test_ds5 = preprocess_data(combined_test_data, distortions=['Rotate_90deg'], include_original=False, save_data=True, save_path='datasets/D21_processed_rotate_90_test_data.npz')
test_ds5 = dataset_load('datasets/D21_preprocessed/D21_processed_rotate_90_test_data.npz')
# Kept as the last expression so the returned metrics dict is displayed.
trainer.evaluate(test_ds5)
#test_ds5 = None

{'eval_loss': 0.6279111504554749,
 'eval_accuracy_branch1': 0.419507256159298,
 'eval_precision_branch1': 0.3922679614309215,
 'eval_recall_branch1': 0.40661892067588584,
 'eval_f1_branch1': 0.3807842638796789,
 'eval_accuracy_branch2': 0.6882663966700416,
 'eval_precision_branch2': 0.5,
 'eval_recall_branch2': 0.3441331983350208,
 'eval_f1_branch2': 0.40767641767175317,
 'eval_lambda': 0.9998891029505543,
 'eval_runtime': 53.1018,
 'eval_samples_per_second': 334.791,
 'eval_steps_per_second': 41.863,
 'epoch': 50.0}