D21:

Train on:
- 34k undistorted (treat as source distortion - d=0) - from the training set

- 34k CT ring artifact (treat as target distortion - d=1) - from the training set

Validate on:

- 6k undistorted - from the new validation set (measure performance on both branches)

Separately, also track:

- 6k CT ring artifact - from the new validation set (measure performance on both branches)

Test on:

- 7k undistorted - from the new test set

- 7k rotated 90 degrees - from the new test set

- 7k uniform noise - from the new test set

- 7k CT ring artifact - from the new test set

With any additional time, please also test:

- 34k undistorted - from the training set

- 34k rotated 90 degrees - from the training set

- 34k uniform noise - from the training set

- 34k CT ring artifact - from the training set

In [1]:
import pandas as pd
import numpy as np
import os

# Root directory that holds the CT distortion data used in this notebook.
data_dir = 'ct-distortion'

# Parse every CSV file found directly under data_dir.
# NOTE(review): temp_data is overwritten on each iteration and is not used by
# the cells shown here, so only the last parsed file is kept — confirm whether
# this loop is only meant as a "can the files be read" check.
for filename in os.listdir(data_dir):
    if filename.endswith('.csv'):
        temp_data = np.loadtxt(os.path.join(data_dir, filename), delimiter=',')

In [2]:
def combine_npzs(data_dir):
    """Merge the chunked ``.npz`` archives in ``data_dir`` into one dict of arrays.

    A chunk file is recognised by an ordinal token ('first', 'second', ...,
    'tenth', 'last') appearing as one of the underscore-separated parts of its
    file name, e.g. ``train_subset_first_5000_RingArtifactv1_images.npz``.
    Chunks are concatenated along axis 0 in ordinal order. Files that share
    the same ordinal are processed in sorted file-name order so the row order
    does not depend on the arbitrary order returned by ``os.listdir``.

    Args:
        data_dir: Directory containing the chunked ``.npz`` files.

    Returns:
        dict mapping every array name found in the archives to the
        concatenation of that array over all matching chunks. The dict is
        empty when no file matches.
    """
    order = ['first', 'second', 'third', 'fourth', 'fifth', 'sixth', 'seventh', 'eighth', 'ninth', 'tenth', 'last']

    # List the directory once (instead of once per ordinal) and sort it so the
    # result is reproducible across machines and file systems.
    npz_files = sorted(f for f in os.listdir(data_dir) if f.endswith('.npz'))

    # Collect the chunks per key and concatenate once at the end; concatenating
    # after every file would re-copy the accumulated array each time.
    chunks = {}
    for pos in order:
        for filename in npz_files:
            if pos not in filename.split('_'):
                continue
            print(f"Processing file: {filename}")
            file_path = os.path.join(data_dir, filename)
            # The context manager closes the archive's file handle as soon as
            # its arrays have been read into memory.
            with np.load(file_path) as data:
                for key in data.files:
                    chunks.setdefault(key, []).append(data[key])

    # A key that occurs in a single chunk is returned as loaded (no extra copy).
    return {
        key: parts[0] if len(parts) == 1 else np.concatenate(parts, axis=0)
        for key, parts in chunks.items()
    }

In [3]:
# Merge the chunked archives of each split (train / val / test) into one dict
# of arrays per split; combine_npzs prints every chunk file it reads.
train_data = combine_npzs('ct-distortion/train')
val_data = combine_npzs('ct-distortion/val')
test_data = combine_npzs('ct-distortion/test')

Processing file: train_subset_first_5000_RingArtifactv1_images.npz
Processing file: train_subset_second_5000_RingArtifactv1_images.npz
Processing file: train_subset_third_5000_RingArtifactv1_images.npz
Processing file: train_subset_fourth_5000_RingArtifactv1_images.npz
Processing file: train_subset_fifth_5000_RingArtifactv1_images.npz
Processing file: train_subset_sixth_5000_RingArtifactv1_images.npz
Processing file: train_subset_last_4561_RingArtifactv1_images.npz
Processing file: val_subset_first_5000_RingArtifactv1_images.npz
Processing file: val_subset_last_1491_RingArtifactv1_images.npz
Processing file: test_subset_first_5000_RingArtifactv1_images.npz
Processing file: test_subset_second_5000_RingArtifactv1_images.npz
Processing file: test_subset_third_5000_RingArtifactv1_images.npz
Processing file: test_subset_last_2778_RingArtifactv1_images.npz


In [4]:
# Show which arrays the merged training archive exposes.
train_data.keys()

dict_keys(['original', 'label', 'Ring_Artifact_v1'])

In [5]:
import os

def import_data(directory, save_path=None, save=False):
    """Load every ``.npz`` archive in ``directory`` and concatenate them key by key.

    Files are read in sorted file-name order so the row order of the result is
    reproducible. The set of keys is taken from the first file; every other
    file must provide the same keys (a missing key raises ``KeyError``).

    Args:
        directory: Directory containing the ``.npz`` files.
        save_path: Where to write the concatenated archive when ``save`` is
            true. Defaults to ``datasets/{directory}_concatenated_data.npz``.
        save: Whether to write the concatenated arrays to ``save_path``.

    Returns:
        dict mapping each key to the axis-0 concatenation of that array over
        all files.

    Raises:
        FileNotFoundError: If ``directory`` contains no ``.npz`` file.
    """
    # os.listdir returns entries in arbitrary order; sort for reproducibility.
    filenames = sorted(f for f in os.listdir(directory) if f.endswith('.npz'))
    if not filenames:
        raise FileNotFoundError(f"No .npz files found in {directory!r}")

    data = []
    for filename in filenames:
        file_path = os.path.join(directory, filename)
        # Read the arrays eagerly and close the archive right away instead of
        # keeping one open file handle per archive.
        with np.load(file_path) as loaded_data:
            data.append({key: loaded_data[key] for key in loaded_data.files})

    # concatenate the data from all files
    all_data = {}
    for key in data[0]:
        all_data[key] = np.concatenate([d[key] for d in data], axis=0)

    # check the shape of the concatenated data
    for key, value in all_data.items():
        print(f"{key}: {value.shape}")

    if save:
        if save_path is None:
            save_path = f'datasets/{directory}_concatenated_data.npz'
        # The default path embeds `directory`, which may itself contain
        # separators, so make sure the target folder exists before writing.
        parent = os.path.dirname(save_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        np.savez(save_path, **all_data)
        print(f"Data saved to {save_path}")

    return all_data

In [6]:
def combine_data(orig_data, new_data):
    """Return a dict copy of ``orig_data`` extended with the ring-artifact arrays.

    Args:
        orig_data: Mapping of array names to arrays; it is copied, not modified.
        new_data: Mapping providing 'Ring_Artifact_v1' (images) and 'label'.

    Returns:
        dict holding every entry of ``orig_data`` plus 'Ring_Artifact_v1' and
        'ring_labels' (the latter taken from ``new_data['label']``).
    """
    merged = dict(orig_data)
    merged.update({
        'Ring_Artifact_v1': new_data['Ring_Artifact_v1'],
        # Stored under its own key so it does not clash with the labels
        # already present in orig_data.
        'ring_labels': new_data['label'],
    })
    return merged


In [7]:
# Open the original archive of each split.
orig_train_data = np.load('datasets/training.npz')
orig_val_data = np.load('datasets/validation.npz')
orig_test_data = np.load('datasets/test.npz')

# For each split, copy the original archive into a dict and attach the
# ring-artifact images ('Ring_Artifact_v1') and their labels ('ring_labels').
combined_train_data = combine_data(orig_train_data, train_data)
combined_val_data = combine_data(orig_val_data, val_data)
combined_test_data = combine_data(orig_test_data, test_data)

In [8]:
# orig_data = np.load('datasets/training.npz')
# print(orig_data.files)
# print(len(orig_data['original']))

In [9]:
# data = dict(orig_data)

# data['Ring_Artifact_v1'] = train_data['Ring_Artifact_v1']
# data['ring_labels'] = train_data['label']

# data.keys()

In [10]:
def normalize_image(image, mean=0.5, std=0.5):
    """
    Shift an image by ``mean`` and scale it by ``std``: ``(image - mean) / std``.

    With the defaults, a value of 0.5 maps to 0 and a value of 1 maps to 1.
    """
    centered = image - mean
    return centered / std

def normalize_images(images, mean=0.5, std=0.5):
    """
    Apply ``normalize_image`` to every image and return the results as a list.
    """
    normalized = []
    for single_image in images:
        normalized.append(normalize_image(single_image, mean, std))
    return normalized

In [None]:
from branched_resnet import CustomImageDataset

# def preprocess_data(data, distortions, include_original=True):

#     keys = list(data.keys())

#     if 'Ring_Artifact_v1' in distortions:
#         ring_flag = True
#         distortions.remove('Ring_Artifact_v1')
#     else:
#         ring_flag = False

#     if include_original:
#         images = [data[keys[0]]]
#     else:
#         images = []

#     for distortion in distortions:
#         images.append(data[keys[distortion]])

#     labels = data[keys[1]]
#     ring_labels = data[keys[-1]]

#     normalized_images = []
#     for image in images:
#         normalized_images.append(normalize_images(image))

#     zero_labels = np.zeros_like(labels)
#     one_labels = np.ones_like(labels)

#     if include_original:
#         domain_label_list = [zero_labels]
#         expanded_label_list = [labels]
#     else:
#         domain_label_list = []
#         expanded_label_list = []

#     for _ in distortions:
#         domain_label_list.append(one_labels)
#         expanded_label_list.append(labels)

#     domain_labels = np.concatenate(domain_label_list, axis=0)
#     expanded_labels = np.concatenate(expanded_label_list, axis=0)

#     print(f"Domain labels shape: {domain_labels.shape}")
#     print(f"Expanded labels shape: {expanded_labels.shape}")

#     concatenated_images = np.concatenate(normalized_images, axis=0)

#     if ring_flag:
#         ring_images = data['Ring_Artifact_v1']
#         ring_labels = data['ring_labels']
#         ring_images = normalize_images(ring_images)
#         concatenated_images = np.concatenate((concatenated_images, ring_images), axis=0)
#         domain_labels = np.concatenate((domain_labels, one_labels), axis=0)
#         expanded_labels = np.concatenate((expanded_labels, ring_labels), axis=0)

#     print(len(concatenated_images), len(expanded_labels), len(domain_labels))
#     assert len(concatenated_images) == len(expanded_labels) == len(domain_labels), "Dataset length mismatch!"


#     dataset = CustomImageDataset(images=concatenated_images, labels1=expanded_labels, labels2=domain_labels)

#     return dataset

# from branched_resnet import CustomImageDataset

def preprocess_data(data, distortions, include_original=True):
    """Build a shuffled dataset with a class label and a domain label per image.

    Every selected image set is normalized with ``normalize_images`` and
    stacked along axis 0. Undistorted images get domain label 0 and every
    distorted image gets domain label 1.

    Args:
        data: Mapping of array names to arrays. The first key is used as the
            undistorted images and the second key as their class labels
            (assumes the ``['original', 'label', ...]`` key order — TODO
            confirm for every archive passed in). 'Ring_Artifact_v1' images
            are paired with ``data['ring_labels']``; every other distortion
            reuses the class labels of the undistorted images.
        distortions: Names of the distorted image sets in ``data`` to include.
            The list is not modified.
        include_original: Whether to include the undistorted images.

    Returns:
        CustomImageDataset built from the images, the class labels
        (``labels1``) and the domain labels (``labels2``), all shuffled with
        the same fixed permutation.

    Raises:
        ValueError: If neither the undistorted images nor any distortion is
            selected.
    """
    keys = list(data.keys())

    # Ring-artifact images carry their own labels and are handled separately.
    # Filter into a new list instead of calling .remove(), which would
    # silently mutate the caller's list.
    ring_flag = 'Ring_Artifact_v1' in distortions
    distortions = [name for name in distortions if name != 'Ring_Artifact_v1']

    labels = data[keys[1]]

    image_parts = []
    class_parts = []
    domain_parts = []

    if include_original:
        image_parts.append(normalize_images(data[keys[0]]))
        class_parts.append(labels)
        domain_parts.append(np.zeros_like(labels))

    for distortion in distortions:
        image_parts.append(normalize_images(data[distortion]))
        class_parts.append(labels)
        domain_parts.append(np.ones_like(labels))

    if ring_flag:
        ring_labels = data['ring_labels']
        image_parts.append(normalize_images(data['Ring_Artifact_v1']))
        class_parts.append(ring_labels)
        # Size the domain labels from the ring labels themselves; the ring set
        # is not guaranteed to have as many samples as the undistorted set.
        domain_parts.append(np.ones_like(ring_labels))

    if not image_parts:
        raise ValueError("Nothing to preprocess: set include_original=True or pass at least one distortion.")

    # Stack everything once. This also covers the ring-only and
    # include_original=False combinations, which previously hit undefined
    # variables.
    concatenated_images = np.concatenate(image_parts, axis=0)
    expanded_labels = np.concatenate(class_parts, axis=0)
    domain_labels = np.concatenate(domain_parts, axis=0)

    print(len(concatenated_images), len(expanded_labels), len(domain_labels))
    assert len(concatenated_images) == len(expanded_labels) == len(domain_labels), "Dataset length mismatch!"

    # Shuffle images and both label arrays with the same permutation.
    # A private RandomState(42) yields the same permutation as seeding the
    # global NumPy RNG with 42, without resetting that global RNG as a side
    # effect of building a dataset.
    seed = 42
    shuffled_indices = np.random.RandomState(seed).permutation(len(concatenated_images))
    concatenated_images = concatenated_images[shuffled_indices]
    expanded_labels = expanded_labels[shuffled_indices]
    domain_labels = domain_labels[shuffled_indices]

    dataset = CustomImageDataset(images=concatenated_images, labels1=expanded_labels, labels2=domain_labels)

    return dataset

In [28]:
# D21 training and validation sets: undistorted images (domain label 0) plus
# ring-artifact images (domain label 1).
train_ds = preprocess_data(combined_train_data, distortions=['Ring_Artifact_v1'], include_original=True)
val_ds = preprocess_data(combined_val_data, distortions=['Ring_Artifact_v1'], include_original=True)

69122 69122 69122
12982 12982 12982


In [29]:
import branched_resnet as br 
from branched_resnet import CustomImageDataset
from transformers import Trainer, TrainingArguments, PreTrainedModel, ResNetConfig

# Branched ResNet built from the default ResNetConfig; the two class counts
# are passed for the d1 and d2 heads (presumably 11 image classes and the 2
# domain labels 0/1 — confirm against branched_resnet).
config = ResNetConfig()
model = br.ResNetForMultiLabel(config=config, num_d1_classes=11, num_d2_classes=2)

# Sanity check: print name, mean value and trainable flag of the first
# parameter only.
first_named_param = next(iter(model.named_parameters()), None)
if first_named_param is not None:
    name, param = first_named_param
    print(name, param.data.mean(), param.requires_grad)

resnet.embedder.embedder.convolution.weight tensor(0.0005) True


In [None]:
# Train the model on the D21 datasets and keep the returned trainer for the
# evaluation cells.
# NOTE(review): this comment used to read "Uncomment to train the model", but
# the call is live and passes train=True — comment the line out (or switch to
# the train=False checkpoint path) if retraining is not intended.
trainer = br.train_model(train_dataset=train_ds, eval_dataset= val_ds, model=model, output_dir= "./D21_Experiment", num_epochs=50, batch_size=32, train=True)

In [None]:
# from branched_resnet import CustomImageDataset

# def preprocess_data(data, distortions, include_original=True):

#     keys = list(data.keys())

#     if 'Ring_Artifact_v1' in distortions:
#         ring_flag = True
#         distortions.remove('Ring_Artifact_v1')
#     else:
#         ring_flag = False

#     if include_original:
#         images = [data[keys[0]]]
#     else:
#         images = []

#     for distortion in distortions:
#         images.append(data[distortion])

#     labels = data[keys[1]]
#     ring_labels = data[keys[-1]]

#     normalized_images = []
#     for image in images:
#         normalized_images.append(normalize_images(image))

#     zero_labels = np.zeros_like(labels)
#     one_labels = np.ones_like(labels)

#     if include_original:
#         domain_label_list = [zero_labels]
#         expanded_label_list = [labels]
#     else:
#         domain_label_list = []
#         expanded_label_list = []

#     for _ in distortions:
#         domain_label_list.append(one_labels)
#         expanded_label_list.append(labels)

#     if domain_label_list != []:
#         domain_labels = np.concatenate(domain_label_list, axis=0)
        
#     expanded_labels = np.concatenate(expanded_label_list, axis=0)

#     print(f"Domain labels shape: {domain_labels.shape}")
#     print(f"Expanded labels shape: {expanded_labels.shape}")

#     concatenated_images = np.concatenate(normalized_images, axis=0)

#     if ring_flag:
#         ring_images = data['Ring_Artifact_v1']
#         ring_labels = data['ring_labels']
#         ring_images = normalize_images(ring_images)
#         concatenated_images = np.concatenate((concatenated_images, ring_images), axis=0)
#         domain_labels = np.concatenate((domain_labels, one_labels), axis=0)
#         expanded_labels = np.concatenate((expanded_labels, ring_labels), axis=0)

#     print(len(concatenated_images), len(expanded_labels), len(domain_labels))
#     assert len(concatenated_images) == len(expanded_labels) == len(domain_labels), "Dataset length mismatch!"


#     dataset = CustomImageDataset(images=concatenated_images, labels1=expanded_labels, labels2=domain_labels)

#     return dataset

In [None]:
# Load the model from the saved checkpoint
# from safetensors.torch import load_file

# model_path = './D21_Experiment_0801/checkpoint-108050/model.safetensors'

# state_dict = load_file(model_path)
# model.load_state_dict(state_dict)

# trainer = br.train_model(train_dataset=train_ds, eval_dataset= val_ds, model=model, output_dir= "./D20_Experiment", num_epochs=50, batch_size=32, train=False)



In [31]:
# Clear up memory: drop this notebook's references to the combined
# train/validation arrays and the datasets built from them.
combined_train_data = combined_val_data = train_ds = val_ds = None

# Rebuild the combined test set: merged ring-artifact chunks attached to a
# dict copy of the original test archive.
test_data = combine_npzs('ct-distortion/test')
orig_test_data = np.load('datasets/test.npz')
combined_test_data = combine_data(orig_test_data, test_data)

print(combined_test_data.keys())


Processing file: test_subset_first_5000_RingArtifactv1_images.npz
Processing file: test_subset_second_5000_RingArtifactv1_images.npz
Processing file: test_subset_third_5000_RingArtifactv1_images.npz
Processing file: test_subset_last_2778_RingArtifactv1_images.npz
dict_keys(['original', 'label', 'Uniform_Noise', 'Rotate_90deg', 'Ring_Artifact_v1', 'ring_labels'])


In [None]:
# Undistorted test images only, so every sample has domain label 0.
# NOTE(review): with a single domain class present, the recorded branch-2
# precision of exactly 0.5 and recall equal to accuracy/2 look like
# macro-averaging over two classes with one class absent — confirm in the
# metric code of branched_resnet before reading anything into those numbers.
baseline = preprocess_data(combined_test_data, distortions=[], include_original=True)
trainer.evaluate(baseline)
#baseline = None

Domain labels shape: (17778, 1)
Expanded labels shape: (17778, 1)
17778 17778 17778


{'eval_loss': 0.13260479271411896,
 'eval_model_preparation_time': 0.005,
 'eval_accuracy_branch1': 0.9655191810102374,
 'eval_precision_branch1': 0.9635336002766532,
 'eval_recall_branch1': 0.9588807434978147,
 'eval_f1_branch1': 0.9607240512411571,
 'eval_accuracy_branch2': 0.47013162335470804,
 'eval_precision_branch2': 0.5,
 'eval_recall_branch2': 0.23506581167735402,
 'eval_f1_branch2': 0.31978879706152435,
 'eval_runtime': 31.6768,
 'eval_samples_per_second': 561.231,
 'eval_steps_per_second': 70.178}

In [None]:
# All available images and distortions: undistorted images (domain label 0)
# plus uniform noise, 90-degree rotation and ring artifact (domain label 1).
test_ds1 = preprocess_data(combined_test_data, distortions=['Uniform_Noise', 'Rotate_90deg', 'Ring_Artifact_v1'], include_original=True)
trainer.evaluate(test_ds1)
#test_ds1 = None

Domain labels shape: (53334, 1)
Expanded labels shape: (53334, 1)
71112 71112 71112


{'eval_loss': 0.0331798680126667,
 'eval_model_preparation_time': 0.005,
 'eval_accuracy_branch1': 0.7816121048486894,
 'eval_precision_branch1': 0.7857890939395983,
 'eval_recall_branch1': 0.7728506095613782,
 'eval_f1_branch1': 0.776401121742958,
 'eval_accuracy_branch2': 0.4874985937675779,
 'eval_precision_branch2': 0.4862818779635574,
 'eval_recall_branch2': 0.48170960362995463,
 'eval_f1_branch2': 0.4526181586742759,
 'eval_runtime': 126.8752,
 'eval_samples_per_second': 560.488,
 'eval_steps_per_second': 70.061}

In [None]:
# Uniform Noise and Rotate 90 degrees, no undistorted images. Neither
# distortion is part of the training set built above (undistorted + ring
# artifact only), and every sample here has domain label 1.
test_ds2 = preprocess_data(combined_test_data, distortions=['Uniform_Noise', 'Rotate_90deg'], include_original=False)
trainer.evaluate(test_ds2)
#test_ds2 = None

Domain labels shape: (35556, 1)
Expanded labels shape: (35556, 1)
35556 35556 35556


{'eval_loss': 0.0,
 'eval_model_preparation_time': 0.005,
 'eval_accuracy_branch1': 0.5984643941950726,
 'eval_precision_branch1': 0.6192614654873654,
 'eval_recall_branch1': 0.5880634768560274,
 'eval_f1_branch1': 0.5940642440759605,
 'eval_accuracy_branch2': 0.4751096861289234,
 'eval_precision_branch2': 0.5,
 'eval_recall_branch2': 0.2375548430644617,
 'eval_f1_branch2': 0.32208431047303093,
 'eval_runtime': 59.3141,
 'eval_samples_per_second': 599.453,
 'eval_steps_per_second': 74.94}

In [None]:
# Ring Artifact only: the dataset is assembled by hand from the ring-artifact
# images and their own labels; every sample gets domain label 1.
d3images = [normalize_image(ring_image) for ring_image in combined_test_data['Ring_Artifact_v1']]

d3labels = combined_test_data['ring_labels']
d3domain_labels = np.ones_like(d3labels)
test_ds3 = br.CustomImageDataset(
    images=d3images,
    labels1=d3labels,
    labels2=d3domain_labels,
)

trainer.evaluate(test_ds3)


{'eval_loss': 0.0,
 'eval_model_preparation_time': 0.0056,
 'eval_accuracy_branch1': 0.9640004499943751,
 'eval_precision_branch1': 0.9619952781884255,
 'eval_recall_branch1': 0.9563947410356438,
 'eval_f1_branch1': 0.958774862716702,
 'eval_accuracy_branch2': 0.5296433794577567,
 'eval_precision_branch2': 0.5,
 'eval_recall_branch2': 0.2648216897288784,
 'eval_f1_branch2': 0.3462528498933588,
 'eval_runtime': 27.9197,
 'eval_samples_per_second': 636.754,
 'eval_steps_per_second': 79.621}

In [None]:
# Uniform Noise only: no undistorted images, so every sample has domain
# label 1.
test_ds4 = preprocess_data(combined_test_data, distortions=['Uniform_Noise'], include_original=False)
trainer.evaluate(test_ds4)
#test_ds4 = None

17778 17778 17778


{'eval_loss': 0.0,
 'eval_model_preparation_time': 0.0056,
 'eval_accuracy_branch1': 0.7901901226234672,
 'eval_precision_branch1': 0.8445835236369393,
 'eval_recall_branch1': 0.7784693650722848,
 'eval_f1_branch1': 0.7892261792566025,
 'eval_accuracy_branch2': 0.4440881988975138,
 'eval_precision_branch2': 0.5,
 'eval_recall_branch2': 0.2220440994487569,
 'eval_f1_branch2': 0.30752152066373234,
 'eval_runtime': 28.8307,
 'eval_samples_per_second': 616.635,
 'eval_steps_per_second': 77.105}

In [None]:
# Rotate 90 degrees only: no undistorted images, so every sample has domain
# label 1.
test_ds5 = preprocess_data(combined_test_data, distortions=['Rotate_90deg'], include_original=False)
trainer.evaluate(test_ds5)
#test_ds5 = None

17778 17778 17778


{'eval_loss': 0.0,
 'eval_model_preparation_time': 0.0056,
 'eval_accuracy_branch1': 0.40673866576667794,
 'eval_precision_branch1': 0.40807267046597073,
 'eval_recall_branch1': 0.39765758863976985,
 'eval_f1_branch1': 0.3888983603711902,
 'eval_accuracy_branch2': 0.506131173360333,
 'eval_precision_branch2': 0.5,
 'eval_recall_branch2': 0.2530655866801665,
 'eval_f1_branch2': 0.33604720645354047,
 'eval_runtime': 29.2435,
 'eval_samples_per_second': 607.93,
 'eval_steps_per_second': 76.017}