In [1]:
import torch
from transformers import ResNetConfig
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split
import transformers
from torchvision import transforms
from resnet import ResNetForMultiLabel
from resnet import OrganAMNISTDataset, compute_metrics, train_model, evaluate_model
import random
import numpy as np
import os 

In [2]:
# Seed every RNG the notebook touches (Python, NumPy, PyTorch) with one shared value.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x2d51b863530>

#### Import NPZ by concatenating

In [3]:
import os

def import_data(directory, save_path=None, save=False):
    """
    Concatenate every ``.npz`` archive in ``directory`` along axis 0, array by array.

    Files are read in sorted filename order so the row order of the result is
    reproducible (``os.listdir`` order is arbitrary). The array names come from the
    first file; every other file must contain the same names.

    Parameters
    ----------
    directory : str
        Folder containing the ``.npz`` shards.
    save_path : str, optional
        Where to write the combined archive when ``save`` is True. Defaults to
        ``f'datasets/{directory}_concatenated_data.npz'``; missing parent folders
        are created.
    save : bool
        When True, write the combined arrays with ``np.savez``.

    Returns
    -------
    dict[str, np.ndarray]
        Array name -> concatenated array.

    Raises
    ------
    ValueError
        If ``directory`` contains no ``.npz`` files.
    """
    filenames = sorted(name for name in os.listdir(directory) if name.endswith('.npz'))
    if not filenames:
        raise ValueError(f"No .npz files found in {directory!r}")

    data = []
    for filename in filenames:
        # Materialize the arrays, then close the archive: np.load on a .npz keeps
        # the file handle open until the NpzFile is closed.
        with np.load(os.path.join(directory, filename)) as archive:
            data.append({key: archive[key] for key in archive.files})

    # concatenate the data from all files
    all_data = {
        key: np.concatenate([d[key] for d in data], axis=0)
        for key in data[0]
    }

    # check the shape of the concatenated data
    for key, value in all_data.items():
        print(f"{key}: {value.shape}")

    if save:
        if save_path is None:
            save_path = f'datasets/{directory}_concatenated_data.npz'
        # `directory` is usually a nested path, so the default target sits in
        # sub-folders of datasets/ that may not exist yet.
        parent = os.path.dirname(save_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        np.savez(save_path, **all_data)
        print(f"Data saved to {save_path}")

    return all_data

#### After combining the original, uniform-noise, and rotation sets, add the CT ring-artifact set

In [None]:
#quick method to add new CT information into the full concat file. 
def combine_npz(npz_data, CT_directory, save_name, file_prefix='train'):
    """
    Add the 'Ring_Artifact_v1' arrays from ``CT_directory`` to ``npz_data`` and save.

    Every ``.npz`` file in ``CT_directory`` whose name starts with ``file_prefix`` is
    read in sorted filename order. Its 'label' and 'Ring_Artifact_v1' arrays are
    concatenated along axis 0. The concatenated ring-artifact images are stored in
    ``npz_data`` (mutating it in place), and the whole dict is written with ``np.savez``.

    The new images are attached positionally, so the concatenated CT labels must
    equal ``npz_data['label']``. Otherwise a ValueError is raised before anything
    is modified or saved. Equal labels are necessary but not sufficient for alignment.

    Parameters
    ----------
    npz_data : dict[str, np.ndarray]
        Existing combined arrays (e.g. from ``import_data``); must contain 'label'.
    CT_directory : str
        Folder with the ring-artifact ``.npz`` shards.
    save_name : str
        Output path passed to ``np.savez``.
    file_prefix : str
        Only files whose names start with this prefix are used (e.g. 'train').

    Returns
    -------
    dict[str, np.ndarray]
        ``npz_data`` itself, now including 'Ring_Artifact_v1'.

    Raises
    ------
    ValueError
        If no matching files exist, or the CT labels do not line up with
        ``npz_data['label']``.
    """
    labels = ['label', 'Ring_Artifact_v1']
    filenames = sorted(
        name for name in os.listdir(CT_directory)
        if name.startswith(file_prefix) and name.endswith('.npz')
    )
    if not filenames:
        raise ValueError(f"No '{file_prefix}*.npz' files found in {CT_directory!r}")

    data = []
    for filename in filenames:
        # Read the needed arrays eagerly and close the archive afterwards.
        with np.load(os.path.join(CT_directory, filename)) as archive:
            data.append({key: archive[key] for key in labels})

    all_data = {key: np.concatenate([d[key] for d in data], axis=0) for key in labels}

    for key, value in all_data.items():
        print(f"{key}: {value.shape}")

    # The ring-artifact rows are attached by position, so make sure they describe
    # the same samples, in the same order, as the rows already in npz_data.
    existing_labels = np.asarray(npz_data['label'])
    ct_labels = all_data['label']
    if len(ct_labels) != len(existing_labels):
        raise ValueError(
            f"CT data has {len(ct_labels)} rows but npz_data has {len(existing_labels)}"
        )
    if not np.array_equal(np.ravel(ct_labels), np.ravel(existing_labels)):
        raise ValueError(
            "CT labels do not match npz_data['label']; rows would be misaligned"
        )

    npz_data['Ring_Artifact_v1'] = all_data['Ring_Artifact_v1']

    for key, value in npz_data.items():
        print(f"{key}: {value.shape}")

    np.savez(save_name, **npz_data)
    print(f"Data saved to {save_name}")
    return npz_data
    
    
# Build the full training archive: concatenate the shards in the
# UniformNoise_Rotate90 folder, then attach the ring-artifact arrays from the CT folder.
directory = 'Distorted_OrganAMNIST/RingArtifactv1_npz'
training_data = import_data('Distorted_OrganAMNIST/UniformNoise_Rotate90_npz')
# Presumably the val/test archives were built the same way from these folders:
# val_data = import_data('Distorted_OrganAMNIST/UniformNoise_Rotate90_val_dataset')
# test_data = import_data('Distorted_OrganAMNIST/UniformNoise_Rotate90_test_dataset')

combine_npz(
    npz_data=training_data,
    CT_directory=directory,
    save_name='training_concatenated_dataset_full.npz',
    file_prefix='train',
)
    
    


### Image normalizer 

In [5]:
def normalize_image(image, mean=0.5, std=0.5):
    """
    Shift and scale ``image`` element-wise: ``(image - mean) / std``.

    Works on anything that supports arithmetic with scalars (NumPy arrays,
    tensors, plain floats). With the defaults, values in [0, 1] map to [-1, 1].
    """
    centered = image - mean
    return centered / std

def normalize_images(images, mean=0.5, std=0.5):
    """
    Apply ``normalize_image`` to each element of ``images``.

    Returns a new Python list (not an array) with one normalized entry per input element.
    """
    normalized = []
    for image in images:
        normalized.append(normalize_image(image, mean, std))
    return normalized

### Dataset Loader

In [6]:
#Modified CustomImageDataset Loader

class ModifiedCustomImageDataset(Dataset):
    """
    Map-style dataset pairing images with integer class labels.

    Each item is a dict ``{"pixel_values": <transformed float32 image>, "labels": int}``.

    Parameters
    ----------
    images : indexable collection of arrays
        Each element must support ``.astype(np.float32)``.
    labels1 : indexable collection of labels
        Each entry may be a scalar, a size-1 array (e.g. a row of an ``(N, 1)``
        label array), or a single-element tensor.
    transform : callable, optional
        Applied to the float32 image. Defaults to ``ToTensor`` followed by
        repeating the single channel three times (grayscale -> 3-channel input).
    """

    def __init__(self, images, labels1, transform=None):
        self.images = images
        self.labels1 = labels1

        if transform is None:
            self.transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Lambda(lambda x: x.repeat(3, 1, 1))  # Grayscale to 3-channel
            ])
        else:
            self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx].astype(np.float32)

        # Labels are stored with shape (N, 1), so each entry is a size-1 array.
        # `.item()` extracts the scalar. Calling int() directly on an ndim > 0
        # array is deprecated since NumPy 1.25.
        raw_label = self.labels1[idx]
        if torch.is_tensor(raw_label):
            label = int(raw_label.item())
        else:
            label = int(np.asarray(raw_label).item())

        if self.transform:
            img = self.transform(img)

        return {
            "pixel_values": img,
            "labels": label,
        }
        


#### Preprocessing for a single distortion per dataset. Use C_pipeline to load multiple distortions per set

In [7]:
#Modified

def preprocess_data(data, key='original', mean=0.5, std=0.5):
    """
    Build a ModifiedCustomImageDataset for a single distortion stored in ``data``.

    Modified from sam's implementation to preprocess 1 set of images.

    Parameters
    ----------
    data : mapping of array name -> array
        For example the ``NpzFile`` returned by ``np.load``. Must contain ``key``
        and ``'label'``.
    key : str
        Name of the distortion array to use (e.g. 'original', 'Rotate_90deg').
    mean, std : float
        Normalization constants; every pixel becomes ``(x - mean) / std``.
        NOTE(review): 0.5/0.5 assumes pixels already scaled to [0, 1]; confirm
        the stored dtype/range of the archives.

    Returns
    -------
    ModifiedCustomImageDataset
        Normalized images paired with the labels.

    Raises
    ------
    ValueError
        If the number of labels differs from the number of images.
    """
    print(f'\nGenerating {key} set')

    images = np.asarray(data[key])
    labels = np.asarray(data['label'])
    if len(labels) != len(images):
        raise ValueError(
            f"{key!r} has {len(images)} images but 'label' has {len(labels)} entries"
        )

    # Vectorized form of normalizing each image row by row. It yields the same
    # values and dtype without building nested Python lists.
    normalized_images = normalize_image(images, mean, std)

    print(f"Labels shape: {labels.shape}")
    print(f"Images shape: {normalized_images.shape}")

    return ModifiedCustomImageDataset(images=normalized_images, labels1=labels)


#### Preprocess data into distinct sets

In [8]:
val_set_loaded = np.load('val_concatenated_dataset_full.npz')
# Every array in the archive except the labels is one distortion variant.
key_list = [name for name in val_set_loaded.files if name != 'label']
print(key_list)

# One validation Dataset per distortion.
val_rotate_set = preprocess_data(data=val_set_loaded, key='Rotate_90deg')
val_original_set = preprocess_data(data=val_set_loaded, key='original')
val_noise_set = preprocess_data(data=val_set_loaded, key='Uniform_Noise')
val_ct_set = preprocess_data(data=val_set_loaded, key='Ring_Artifact_v1')


['original', 'Uniform_Noise', 'Rotate_90deg', 'Ring_Artifact_v1']

Generating Rotate_90deg set
Labels shape: (6491, 1)
Images shape: (6491, 224, 224)

Generating original set
Labels shape: (6491, 1)
Images shape: (6491, 224, 224)

Generating Uniform_Noise set
Labels shape: (6491, 1)
Images shape: (6491, 224, 224)

Generating Ring_Artifact_v1 set
Labels shape: (6491, 1)
Images shape: (6491, 224, 224)


In [None]:
train_set_loaded = np.load('training_concatenated_dataset_full.npz')

# One training Dataset per distortion.
train_rotate_set = preprocess_data(data=train_set_loaded, key='Rotate_90deg')
train_original_set = preprocess_data(data=train_set_loaded, key='original')
train_noise_set = preprocess_data(data=train_set_loaded, key='Uniform_Noise')
train_ct_set = preprocess_data(data=train_set_loaded, key='Ring_Artifact_v1')


## Train Model

In [None]:
# Prefer the first GPU when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print("CUDA is available!")
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
    print("CUDA is not available. Using CPU.")

#### A Original

In [None]:
output_path = os.path.join('A1_Original', 'results')

# Re-seed right before building the model so its initial weights do not depend on
# which cells happened to run earlier in this kernel session.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

config = ResNetConfig()
model = ResNetForMultiLabel(config)

print("Starting training")
trainer = train_model(
    train_dataset=train_original_set,
    eval_dataset=val_original_set,
    model=model,
    output_dir=output_path,  # Checkpoints will go here
    num_epochs=100,
    batch_size=32
)

print("Saving final model")
# Written into the same folder that holds the checkpoint sub-folders.
trainer.save_model(output_path)

#### A Uniform

In [None]:
output_path = os.path.join('A4_UNI', 'results')

# Re-seed right before building the model so its initial weights do not depend on
# which cells happened to run earlier in this kernel session.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

config = ResNetConfig()
model = ResNetForMultiLabel(config)

print("Starting training")
trainer = train_model(
    train_dataset=train_noise_set,
    eval_dataset=val_noise_set,
    model=model,
    output_dir=output_path,  # Checkpoints will go here
    num_epochs=100,
    batch_size=32
)

print("Saving final model")
# Written into the same folder that holds the checkpoint sub-folders.
trainer.save_model(output_path)

#### A Ring

In [None]:
output_path = os.path.join('A2_CT', 'results')

# Re-seed right before building the model so its initial weights do not depend on
# which cells happened to run earlier in this kernel session.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

config = ResNetConfig()
model = ResNetForMultiLabel(config)

print("Starting training")
trainer = train_model(
    train_dataset=train_ct_set,
    eval_dataset=val_ct_set,
    model=model,
    output_dir=output_path,  # Checkpoints will go here
    num_epochs=100,
    batch_size=32
)

print("Saving final model")
# Written into the same folder that holds the checkpoint sub-folders.
trainer.save_model(output_path)

#### A Rotation90

In [None]:
output_path = os.path.join('A3_ROT', 'results')

# Re-seed right before building the model so its initial weights do not depend on
# which cells happened to run earlier in this kernel session.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

config = ResNetConfig()
model = ResNetForMultiLabel(config)

print("Starting training")
trainer = train_model(
    train_dataset=train_rotate_set,
    eval_dataset=val_rotate_set,
    model=model,
    output_dir=output_path,  # Checkpoints will go here
    num_epochs=100,
    batch_size=32
)

print("Saving final model")
# Written into the same folder that holds the checkpoint sub-folders.
trainer.save_model(output_path)

### Evaluate other val sets

#### C10 Model

In [None]:
import pandas as pd


# Load in checkpoints.

checkpoint_directory = 'C10_model/results/'
model_name = 'C10'

# Change variables above ^^^^^

# Only "checkpoint-<step>" sub-folders are loadable checkpoints. Anything else in the
# results folder (e.g. the final model, if save_model wrote there too) is skipped
# instead of crashing the loop. Checkpoints are ordered by training step.
checkpoints = sorted(
    (
        name
        for name in os.listdir(checkpoint_directory)
        if name.startswith('checkpoint-')
        and name.split('-')[-1].isdecimal()
        and os.path.isdir(os.path.join(checkpoint_directory, name))
    ),
    key=lambda name: int(name.split('-')[-1]),
)
orig, uni, rot, ring = [], [], [], []
val_sets = [
    ("original", val_original_set, orig),
    ("Uniform_Noise", val_noise_set, uni),
    ("Rotate_90deg", val_rotate_set, rot),
    ("Ring_Artifact_v1", val_ct_set, ring)
]


val_sets = [v for v in val_sets if v[0] != "original"]

# Change which val sets to use ^^^^


for checkpoint in checkpoints:
    load_checkpoint = os.path.join(checkpoint_directory, checkpoint)
    # Load config and model
    config = ResNetConfig.from_pretrained(load_checkpoint)
    model = ResNetForMultiLabel.from_pretrained(load_checkpoint, config=config)

    for val_name, val_data, val_list in val_sets:
        metrics = evaluate_model(eval_dataset=val_data, model=model, output_dir="val_results", num_epochs=0, batch_size=32)
        metrics["checkpoint"] = checkpoint
        val_list.append(metrics)
        # Rewrite the CSV after every checkpoint so results gathered so far survive
        # an interrupted run.
        pd.DataFrame(val_list).to_csv(f"{model_name}_val_{val_name}.csv", index=False)
        



#### C11 Model (orig, ct, uni)

In [None]:
import pandas as pd


# Load in checkpoints.

checkpoint_directory = 'C11_model/results/'
model_name = 'C11'

# Change variables above ^^^^^

# Only "checkpoint-<step>" sub-folders are loadable checkpoints. Anything else in the
# results folder (e.g. the final model, if save_model wrote there too) is skipped
# instead of crashing the loop. Checkpoints are ordered by training step.
checkpoints = sorted(
    (
        name
        for name in os.listdir(checkpoint_directory)
        if name.startswith('checkpoint-')
        and name.split('-')[-1].isdecimal()
        and os.path.isdir(os.path.join(checkpoint_directory, name))
    ),
    key=lambda name: int(name.split('-')[-1]),
)
orig, uni, rot, ring = [], [], [], []
val_sets = [
    ("original", val_original_set, orig),
    ("Uniform_Noise", val_noise_set, uni),
    ("Rotate_90deg", val_rotate_set, rot),
    ("Ring_Artifact_v1", val_ct_set, ring)
]


# NOTE(review): the section header lists this model as (orig, ct, uni), yet only
# "original" is excluded here; confirm whether Uniform_Noise / Ring_Artifact_v1
# should also be excluded from the "other val sets" evaluation.
val_sets = [v for v in val_sets if v[0] != "original"]

# Change which val sets to use ^^^^


for checkpoint in checkpoints:
    load_checkpoint = os.path.join(checkpoint_directory, checkpoint)
    # Load config and model
    config = ResNetConfig.from_pretrained(load_checkpoint)
    model = ResNetForMultiLabel.from_pretrained(load_checkpoint, config=config)

    for val_name, val_data, val_list in val_sets:
        metrics = evaluate_model(eval_dataset=val_data, model=model, output_dir="val_results", num_epochs=0, batch_size=32)
        metrics["checkpoint"] = checkpoint
        val_list.append(metrics)
        # Rewrite the CSV after every checkpoint so results gathered so far survive
        # an interrupted run.
        pd.DataFrame(val_list).to_csv(f"{model_name}_val_{val_name}.csv", index=False)
        



#### A rotate 90 val test

In [None]:
import pandas as pd


# Load in checkpoints.

checkpoint_directory = 'A3_ROT/results/'
model_name = 'a_rotate_90'

# Change variables above ^^^^^

# Only "checkpoint-<step>" sub-folders are loadable checkpoints. Anything else in the
# results folder (e.g. the final model written by trainer.save_model) is skipped
# instead of crashing the loop. Checkpoints are ordered by training step.
checkpoints = sorted(
    (
        name
        for name in os.listdir(checkpoint_directory)
        if name.startswith('checkpoint-')
        and name.split('-')[-1].isdecimal()
        and os.path.isdir(os.path.join(checkpoint_directory, name))
    ),
    key=lambda name: int(name.split('-')[-1]),
)
orig, uni, rot, ring = [], [], [], []
val_sets = [
    ("original", val_original_set, orig),
    ("Uniform_Noise", val_noise_set, uni),
    ("Rotate_90deg", val_rotate_set, rot),
    ("Ring_Artifact_v1", val_ct_set, ring)
]


val_sets = [v for v in val_sets if v[0] != "Rotate_90deg"]

# Change which val sets to use ^^^^


for checkpoint in checkpoints:
    load_checkpoint = os.path.join(checkpoint_directory, checkpoint)
    # Load config and model
    config = ResNetConfig.from_pretrained(load_checkpoint)
    model = ResNetForMultiLabel.from_pretrained(load_checkpoint, config=config)

    for val_name, val_data, val_list in val_sets:
        metrics = evaluate_model(eval_dataset=val_data, model=model, output_dir="val_results", num_epochs=0, batch_size=32)
        metrics["checkpoint"] = checkpoint
        val_list.append(metrics)
        # Rewrite the CSV after every checkpoint so results gathered so far survive
        # an interrupted run.
        pd.DataFrame(val_list).to_csv(f"{model_name}_val_{val_name}.csv", index=False)
        



#### A Uniform val test

In [None]:
import pandas as pd


# Load in checkpoints.

checkpoint_directory = 'A4_UNI/results/'
model_name = 'a_uniform_noise'

# Change variables above ^^^^^

# Only "checkpoint-<step>" sub-folders are loadable checkpoints. Anything else in the
# results folder (e.g. the final model written by trainer.save_model) is skipped
# instead of crashing the loop. Checkpoints are ordered by training step.
checkpoints = sorted(
    (
        name
        for name in os.listdir(checkpoint_directory)
        if name.startswith('checkpoint-')
        and name.split('-')[-1].isdecimal()
        and os.path.isdir(os.path.join(checkpoint_directory, name))
    ),
    key=lambda name: int(name.split('-')[-1]),
)
orig, uni, rot, ring = [], [], [], []
val_sets = [
    ("original", val_original_set, orig),
    ("Uniform_Noise", val_noise_set, uni),
    ("Rotate_90deg", val_rotate_set, rot),
    ("Ring_Artifact_v1", val_ct_set, ring)
]


val_sets = [v for v in val_sets if v[0] != "Uniform_Noise"]

# Change which val sets to use ^^^^


for checkpoint in checkpoints:
    load_checkpoint = os.path.join(checkpoint_directory, checkpoint)
    # Load config and model
    config = ResNetConfig.from_pretrained(load_checkpoint)
    model = ResNetForMultiLabel.from_pretrained(load_checkpoint, config=config)

    for val_name, val_data, val_list in val_sets:
        metrics = evaluate_model(eval_dataset=val_data, model=model, output_dir="val_results", num_epochs=0, batch_size=32)
        metrics["checkpoint"] = checkpoint
        val_list.append(metrics)
        # Rewrite the CSV after every checkpoint so results gathered so far survive
        # an interrupted run. index=False keeps the column layout the same as the
        # other models' validation CSVs.
        pd.DataFrame(val_list).to_csv(f"{model_name}_val_{val_name}.csv", index=False)
        



#### EVALUATE TEST SETS

In [40]:
test_set_loaded = np.load('test_concatenated_dataset_full.npz')
# Distortion names stored in the *test* archive (everything except the labels).
key_list = [key for key in test_set_loaded.files if key != 'label']
print(key_list)

test_rotate_set = preprocess_data(test_set_loaded, 'Rotate_90deg')
test_original_set = preprocess_data(test_set_loaded, 'original')
test_noise_set = preprocess_data(test_set_loaded, 'Uniform_Noise')
test_ct_set = preprocess_data(test_set_loaded, 'Ring_Artifact_v1')


['original', 'Uniform_Noise', 'Rotate_90deg', 'Ring_Artifact_v1']

Generating Rotate_90deg set
Labels shape: (17778, 1)
Images shape: (17778, 224, 224)

Generating original set
Labels shape: (17778, 1)
Images shape: (17778, 224, 224)

Generating Uniform_Noise set
Labels shape: (17778, 1)
Images shape: (17778, 224, 224)

Generating Ring_Artifact_v1 set
Labels shape: (17778, 1)
Images shape: (17778, 224, 224)


#### Best model checkpoints based on validation F1


##### A_original checkpoint-10810
##### A_ring checkpoint-24863
##### A_rotate checkpoint-5405
##### A_uniform checkpoint-43240
##### C10 checkpoint-155568
##### C11 checkpoint-6482

In [45]:
import pandas as pd


# Best checkpoint per model, selected by validation F1 (see the table above).
testing_models = [
    ('A_original', os.path.join('A1_Original','results', 'checkpoint-10810')),
    ('A_ring', os.path.join('A2_CT','results', 'checkpoint-24863')),
    ('A_rotate', os.path.join('A3_ROT','results', 'checkpoint-5405')),
    ('A_uniform', os.path.join('A4_UNI','results', 'checkpoint-43240')),
    ('C10', os.path.join('C10_model','results', 'checkpoint-155568')),
    ('C11', os.path.join('C11_model','results', 'checkpoint-6482')),
]

# (dataset, name) pairs every model is scored on.
# Change which test sets to use here.
test_sets = [
    (test_rotate_set, 'Rotate_90deg'),
    (test_original_set, 'original'),
    (test_noise_set, 'Uniform_Noise'),
    (test_ct_set, 'Ring_Artifact_v1')
]

for model_name, checkpoint_path in testing_models:
    config = ResNetConfig.from_pretrained(checkpoint_path)
    model = ResNetForMultiLabel.from_pretrained(checkpoint_path, config=config)
    result = []

    for testing_set, test_name in test_sets:
        # NOTE(review): output_dir is "val_results" even for test sets; confirm
        # whether evaluate_model's artifacts for test runs should go elsewhere.
        metrics = evaluate_model(eval_dataset=testing_set, model=model, output_dir="val_results", num_epochs=0, batch_size=32)
        metrics['test_set'] = test_name
        result.append(metrics)

    df = pd.DataFrame(result)
    # index=False matches the validation CSVs; the row index carries no information.
    df.to_csv(f"{model_name}_test_results.csv", index=False)
        
        





  0%|          | 0/2223 [00:00<?, ?it/s]



  0%|          | 0/2223 [00:00<?, ?it/s]



  0%|          | 0/2223 [00:00<?, ?it/s]



  0%|          | 0/2223 [00:00<?, ?it/s]



  0%|          | 0/2223 [00:00<?, ?it/s]



  0%|          | 0/2223 [00:00<?, ?it/s]



  0%|          | 0/2223 [00:00<?, ?it/s]



  0%|          | 0/2223 [00:00<?, ?it/s]



  0%|          | 0/2223 [00:00<?, ?it/s]



  0%|          | 0/2223 [00:00<?, ?it/s]



  0%|          | 0/2223 [00:00<?, ?it/s]



  0%|          | 0/2223 [00:00<?, ?it/s]



  0%|          | 0/2223 [00:00<?, ?it/s]



  0%|          | 0/2223 [00:00<?, ?it/s]



  0%|          | 0/2223 [00:00<?, ?it/s]



  0%|          | 0/2223 [00:00<?, ?it/s]



  0%|          | 0/2223 [00:00<?, ?it/s]



  0%|          | 0/2223 [00:00<?, ?it/s]



  0%|          | 0/2223 [00:00<?, ?it/s]



  0%|          | 0/2223 [00:00<?, ?it/s]



  0%|          | 0/2223 [00:00<?, ?it/s]



  0%|          | 0/2223 [00:00<?, ?it/s]



  0%|          | 0/2223 [00:00<?, ?it/s]



  0%|          | 0/2223 [00:00<?, ?it/s]