In [None]:
import torch


2.4.1+cu121


In [None]:
import numpy as np
import os

file_path = os.path.join("dataset", 'validation.npz')
images=np.load(file=file_path)
for key in images.files:
    print(f"Length of {key} is {len(images[key])}")

In [3]:
labels = images['label']
# print(labels.shape)
# print(labels)
organ_count = {}
for arr in labels:
    label = arr[0]
    organ_count[label] = organ_count.get(label,0)+1


print(f"Class Distribution:")
for key in organ_count:
    print(f"{key} : {organ_count[key]}")
    


Class Distribution:
3 : 392
6 : 1033
8 : 1009
2 : 225
9 : 529
5 : 637
10 : 511
7 : 1033
1 : 233
4 : 568
0 : 321


In [None]:
from torch.utils.data import Dataset
from torchvision import transforms
import torch

class CustomImageDataset(Dataset):
    def __init__(self, images, labels1, labels2, transform=None):
        # Need to add shuffling
        self.images = images  # Should be torch.Tensor of shape [N, 3, 224, 224]
        self.labels1 = labels1
        self.labels2 = labels2
        if transform is None:
            self.transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Lambda(lambda x: x.repeat(3, 1, 1))  # Grayscale to 3-channel
            ])
        else:
            self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx]
        label1 = int(self.labels1[idx])
        label2 = int(self.labels2[idx])

        if self.transform:
            img = self.transform(img)

        # return {"pixel_values": img, "labels1": label1, "labels2": label2}
        return {
            "pixel_values": img,
            "labels1": int(label1) if torch.is_tensor(label1) else label1,
            "labels2": int(label2) if torch.is_tensor(label2) else label2,
                }



def normalize_image(image, mean=0.5, std=0.5):
    """
    Normalize an image tensor to have a mean and standard deviation.
    """
    return (image - mean) / std

def normalize_images(images, mean=0.5, std=0.5):
    """
    Normalize a list of images.
    """
    return [normalize_image(image, mean, std) for image in images]

#### KFOLD SPLIT

In [None]:
from collections import defaultdict


def get_label_indices(label_dataset):
    """Returns dict of indices per label

    Args:
        label_dataset (np array): the "label" array from the npz

    Returns:
        defaultdict: indices sorted by key
    """

    organ_indices = defaultdict(list)
    for idx, lab in enumerate(label_dataset):
        organ_indices[int(lab)].append(idx)
    return organ_indices


keylist=['original', 'Uniform_Noise', 'Rotate_90deg', 'Ring_Artifact_v1', 'label']


def kfold_split(data, first_key, second_key, num_folds=5, seed=43):
    """Data is full concatenated npz file.
    
        keys are strings like 'original'
        
        Return:
        normalized images, list of labels, fold indices
    """
    if first_key not in data or second_key not in data or 'label' not in data:
        missing = [k for k in (first_key, second_key, 'label') if k not in data]
        raise KeyError(f"Missing keys in npz: {missing}")


    first_set  = data[first_key]
    second_set = data[second_key]
    
    
    first_normalized_images = np.array(normalize_images(first_set))
    second_normalized_images = np.array(normalize_images(second_set))
        
    labels     = data['label'] 

    #domain labels. Set to 0 for original, 1 otherwise.
    first_domain_labels  = np.zeros_like(labels) if first_key  == 'original' else np.ones_like(labels)
    second_domain_labels = np.zeros_like(labels) if second_key == 'original' else np.ones_like(labels)
    #build class buckets and shuffle in place
    organ_indices = get_label_indices(labels)   # {label: [idx, ...]}
    rng = np.random.default_rng(seed)
    for _, idxs in organ_indices.items():
        rng.shuffle(idxs)

    folds=[list() for _ in range(num_folds)]
    for label, idxs in organ_indices.items():
        for id, idx in enumerate(idxs):
            fold_id = id % num_folds
            folds[fold_id].append(idx)

    
    
    return first_normalized_images, second_normalized_images, folds, labels, first_domain_labels, second_domain_labels


"""Use outputs from prev functiona and selct fold 0-4"""
def retrieve_fold_data(fold_index, folds, labels, first_norm, second_norm, first_domain, second_domain):
    val_idx = np.array(folds[fold_index])
    all_idx = np.arange(len(labels))
    train_idx = np.setdiff1d(all_idx, val_idx)

    def build(indices):
        images = np.concatenate([first_norm[indices], second_norm[indices]], axis=0)
        organ_labels = np.concatenate([labels[indices], labels[indices]], axis=0)
        domain_labels = np.concatenate([first_domain[indices], second_domain[indices]], axis=0)
        return CustomImageDataset(images, organ_labels, domain_labels)

    val_dataset = build(val_idx)
    train_dataset = build(train_idx)
    return val_dataset, train_dataset
    

In [16]:
file_path = os.path.join("dataset", 'validation.npz')
images=np.load(file=file_path)

# unpack from kfold_split
first_norm, second_norm, folds, labels, first_domain, second_domain = kfold_split(images, "original", "Uniform_Noise")

# now pass the exact same names
val_dataset, train_dataset = retrieve_fold_data(
    fold_index=0,               # pick which fold 0–4
    folds=folds,
    labels=labels,
    first_norm=first_norm,
    second_norm=second_norm,
    first_domain=first_domain,
    second_domain=second_domain
    )



#example usage:

# for fold in range(0,5):
    
#     val_dataset, train_dataset = retrieve_fold_data(
#     fold_index=fold,               # pick which fold 0–4
#     folds=folds,
#     labels=labels,
#     first_norm=first_norm,
#     second_norm=second_norm,
#     first_domain=first_domain,
#     second_domain=second_domain
#     )
    
    #train model

In [17]:
val_dataset[5]

  label1 = int(self.labels1[idx])
  label2 = int(self.labels2[idx])


{'pixel_values': tensor([[[-1.0000, -1.0000, -1.0000,  ..., -1.0000, -1.0000, -1.0000],
          [-1.0000, -1.0000, -1.0000,  ..., -1.0000, -1.0000, -1.0000],
          [-1.0000, -1.0000, -1.0000,  ..., -1.0000, -1.0000, -1.0000],
          ...,
          [-0.2157, -0.2157, -0.1686,  ..., -1.0000, -1.0000, -1.0000],
          [-0.2863, -0.2863, -0.2471,  ..., -1.0000, -1.0000, -1.0000],
          [-0.2863, -0.2863, -0.2471,  ..., -1.0000, -1.0000, -1.0000]],
 
         [[-1.0000, -1.0000, -1.0000,  ..., -1.0000, -1.0000, -1.0000],
          [-1.0000, -1.0000, -1.0000,  ..., -1.0000, -1.0000, -1.0000],
          [-1.0000, -1.0000, -1.0000,  ..., -1.0000, -1.0000, -1.0000],
          ...,
          [-0.2157, -0.2157, -0.1686,  ..., -1.0000, -1.0000, -1.0000],
          [-0.2863, -0.2863, -0.2471,  ..., -1.0000, -1.0000, -1.0000],
          [-0.2863, -0.2863, -0.2471,  ..., -1.0000, -1.0000, -1.0000]],
 
         [[-1.0000, -1.0000, -1.0000,  ..., -1.0000, -1.0000, -1.0000],
          [-

#### Creating kfold splitter for A and C models

In [None]:
import numpy as np
import os

train_data_path = os.path.join("..", "dataset", "merged_train_data.npz")
val_data_path = os.path.join("..", "dataset", "merged_val_data.npz")
train_data = np.load(train_data_path)
val_data =np.load(val_data_path)


merged = {name: [train_data[name], val_data[name]] for name in train_data.files}

merged = {name: np.concatenate(merged[name],axis=0) for name in train_data.files}


np.savez('full_train_val_dataset.npz', **merged)



In [None]:
#merging npz file
import glob
files = glob.glob("../datafull/val*")
names = ['original', 'label', 'Uniform_Noise', 'Rotate_90deg']
merged = {name:[] for name in names}
for data in sorted(files):
    # print(data)
    info = np.load(data)
    # print(info.files)
    for name in names:
        merged[name].append(info[name])

merged = {name: np.concatenate(merged[name], axis=0) for name in names}
np.savez('full_val.npz', **merged)

In [3]:
train_data_path = os.path.join("..", "dataset", "full_train_val_dataset.npz")

data = np.load(train_data_path)

for name in data.files:
    print(data[name].shape)

(41052, 224, 224)
(41052, 1)
(41052, 224, 224)
(41052, 224, 224)
(41052, 224, 224)
