In [3]:
import torch
from transformers import ResNetConfig
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split
import transformers
from torchvision import transforms
from resnet import ResNetForMultiLabel
from resnet import OrganAMNISTDataset, compute_metrics, train_model
import random
import numpy as np
import os 

In [4]:
# Seed the Python, NumPy and PyTorch random number generators with one shared value.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x162b86675d0>

#### Import the NPZ files of a directory and concatenate them

In [5]:
import os

def import_data(directory, save_path=None, save=False):
    """
    Concatenate the arrays of every ``.npz`` archive found in ``directory``.

    Archives are visited in sorted filename order so that the row order of the
    result is reproducible (``os.listdir`` gives no ordering guarantee). The
    keys of the first archive decide which arrays are combined; every other
    archive must provide the same keys.

    Parameters
    ----------
    directory : str
        Folder scanned (non-recursively) for files ending in ``.npz``.
    save_path : str, optional
        Destination of the combined archive. Only used when ``save`` is true;
        defaults to ``datasets/{directory}_concatenated_data.npz``.
    save : bool
        When true, write the combined arrays with ``np.savez``.

    Returns
    -------
    dict
        Maps each key to the arrays of all archives joined along axis 0.

    Raises
    ------
    ValueError
        If ``directory`` contains no ``.npz`` file.
    """
    npz_names = sorted(name for name in os.listdir(directory) if name.endswith('.npz'))
    if not npz_names:
        raise ValueError(f"No .npz files found in {directory!r}")

    archives = []
    try:
        for filename in npz_names:
            archives.append(np.load(os.path.join(directory, filename)))

        # concatenate the data from all files, one key at a time so that only
        # the arrays of the current key are materialised together
        all_data = {
            key: np.concatenate([archive[key] for archive in archives], axis=0)
            for key in archives[0].keys()
        }
    finally:
        # NpzFile objects keep the underlying file open until closed explicitly.
        for archive in archives:
            archive.close()

    # check the shape of the concatenated data
    for key, value in all_data.items():
        print(f"{key}: {value.shape}")

    if save:
        if save_path is None:
            save_path = f'datasets/{directory}_concatenated_data.npz'
        # np.savez does not create missing folders.
        parent_dir = os.path.dirname(save_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        np.savez(save_path, **all_data)
        print(f"Data saved to {save_path}")

    return all_data

### Image normalizer 

In [6]:
def normalize_image(image, mean=0.5, std=0.5):
    """
    Shift an image by ``mean`` and scale it by ``std``.

    Evaluates ``(image - mean) / std``; with the default arguments a value of
    0 maps to -1 and a value of 1 maps to +1.
    """
    centered = image - mean
    return centered / std

def normalize_images(images, mean=0.5, std=0.5):
    """
    Apply ``normalize_image`` to each element of ``images``.

    Returns a new list holding one normalized result per input element, in
    the same order.
    """
    normalized = []
    for single_image in images:
        normalized.append(normalize_image(single_image, mean, std))
    return normalized

### Dataset Loader

In [7]:
#Modified CustomImageDataset Loader

class ModifiedCustomImageDataset(Dataset):
    """
    Map-style dataset pairing image arrays with integer class labels.

    Each item is a dict with two entries: ``"pixel_values"`` (the image cast
    to float32 and passed through ``transform``) and ``"labels"`` (a plain
    Python ``int``).

    Parameters
    ----------
    images : indexable
        Collection whose elements support ``.astype`` (e.g. a NumPy array
        indexed along its first axis).
    labels1 : indexable
        One label per image. Each entry must hold exactly one value: a scalar
        or a single-element array/tensor.
    transform : callable, optional
        Applied to every float32 image. When omitted, the image is converted
        with ``ToTensor`` and then repeated three times along its first axis
        (grayscale to 3-channel).
    """

    def __init__(self, images, labels1,  transform=None):
        self.images = images
        self.labels1 = labels1

        if transform is None:
            self.transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Lambda(lambda x: x.repeat(3, 1, 1))  # Grayscale to 3-channel
            ])
        else:
            self.transform = transform

    def __len__(self):
        """Return the number of images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``{"pixel_values": image, "labels": int}`` for position ``idx``."""
        # Images may be stored in a compact dtype (e.g. float16); cast per item.
        img = self.images[idx].astype(np.float32)

        # Labels can arrive as single-element arrays (e.g. rows of an (N, 1)
        # label matrix). Calling int() directly on such an array is deprecated
        # since NumPy 1.25, so extract the scalar explicitly with .item().
        raw_label = self.labels1[idx]
        if hasattr(raw_label, "item"):
            label1 = int(raw_label.item())
        else:
            label1 = int(raw_label)

        if self.transform:
            img = self.transform(img)

        return {
            "pixel_values": img,
            "labels": label1,
        }
        


#### Preprocessing functions: a combined multi-distortion training set and single-distortion validation sets

In [8]:
#Modified

def training_preprocess_data(data, float16=False, keylist=('original', 'Uniform_Noise', 'Rotate_90deg', 'Ring_Artifact_v1')):
    """
    Build one training dataset from several image arrays of a loaded archive.

    The arrays named in ``keylist`` are normalized with ``normalize_image``
    and joined along axis 0; ``data['label']`` is repeated once per selected
    array so that labels stay aligned with the images.

    Parameters
    ----------
    data : archive returned by ``np.load``
        Must expose ``.files`` and contain a ``'label'`` entry plus every key
        named in ``keylist``.
    float16 : bool
        When true, the normalized images are stored as ``np.float16``.
    keylist : iterable of str
        Names of the image arrays to combine. A ``'label'`` entry is ignored.

    Returns
    -------
    ModifiedCustomImageDataset
        Dataset over the combined images and labels.

    Raises
    ------
    ValueError
        If ``keylist`` selects no image array.
    """
    keys = data.files
    print(f'Keys found: {keys}')

    selected_keys = [key for key in keylist if key != 'label']
    if not selected_keys:
        raise ValueError("keylist must name at least one image array")

    # Read the labels once; indexing the archive re-reads the array every time.
    label_array = data['label']

    image_blocks = []
    label_blocks = []
    for key in selected_keys:  # Iterate through the different distortions
        # Normalize the whole stack with one broadcasted operation instead of
        # building a Python list with one small array per image.
        block = normalize_image(data[key])
        if float16:
            # Cast per key so the full-precision data of all keys never has to
            # coexist in memory.
            block = block.astype(np.float16)
        image_blocks.append(block)
        label_blocks.append(label_array)

    labels = np.concatenate(label_blocks, axis=0)
    normalized_images = np.concatenate(image_blocks, axis=0)

    print(f"Labels shape: {labels.shape}")
    print(f"Images shape: {normalized_images.shape}")

    dataset = ModifiedCustomImageDataset(images=normalized_images, labels1=labels)

    return dataset


def validation_preprocess_data(data, key='original'):
    """
    Build a validation dataset from a single image array of a loaded archive.

    The array stored under ``key`` is normalized with ``normalize_image`` and
    paired with ``data['label']``. Singleton axes after the first one are
    removed (an ``(N, 1)`` label matrix becomes ``(N,)``). The first (sample)
    axis is always kept, so a set holding a single sample still yields a
    dataset of length 1.

    Parameters
    ----------
    data : mapping
        Loaded archive providing ``key`` and ``'label'``.
    key : str
        Name of the image array (distortion name).

    Returns
    -------
    ModifiedCustomImageDataset
        Dataset over the normalized images and their labels.
    """
    def _squeeze_keep_samples(array):
        # Drop size-1 axes except axis 0; a bare squeeze() would also remove
        # the sample axis when there is exactly one sample.
        singleton_axes = tuple(
            axis for axis in range(1, array.ndim) if array.shape[axis] == 1
        )
        return np.squeeze(array, axis=singleton_axes) if singleton_axes else array

    print(f'\nGenerating {key} validataion set')

    # Normalize the whole stack with one broadcasted operation.
    normalized_images = _squeeze_keep_samples(np.asarray(normalize_image(data[key])))
    labels = _squeeze_keep_samples(np.asarray(data['label']))

    print(f"Labels shape: {labels.shape}")
    print(f"Images shape: {normalized_images.shape}")

    dataset = ModifiedCustomImageDataset(images=normalized_images, labels1=labels)

    return dataset


#### Preprocess validation data into distinct sets

In [9]:
# Open the concatenated validation archive.
val_set_loaded = np.load('val_concatenated_dataset_full.npz')

# Every stored array except the labels is listed as an image key.
key_list = list(filter(lambda name: name != 'label', val_set_loaded.files))
print(key_list)

# One validation dataset per image key.
val_rotate_set = validation_preprocess_data(data=val_set_loaded, key='Rotate_90deg')
val_original_set = validation_preprocess_data(data=val_set_loaded, key='original')
val_noise_set = validation_preprocess_data(data=val_set_loaded, key='Uniform_Noise')
val_ct_set = validation_preprocess_data(data=val_set_loaded, key='Ring_Artifact_v1')

['original', 'Uniform_Noise', 'Rotate_90deg', 'Ring_Artifact_v1']

Generating Rotate_90deg validataion set
Labels shape: (6491,)
Images shape: (6491, 224, 224)

Generating original validataion set
Labels shape: (6491,)
Images shape: (6491, 224, 224)

Generating Uniform_Noise validataion set
Labels shape: (6491,)
Images shape: (6491, 224, 224)

Generating Ring_Artifact_v1 validataion set
Labels shape: (6491,)
Images shape: (6491, 224, 224)


### Preprocess training data into one set

In [10]:
# Image arrays combined into the training set for this run. The archive also
# provides 'Rotate_90deg'; it is left out here on purpose.
keylist = ['original', 'Uniform_Noise', 'Ring_Artifact_v1']
train_set_loaded = np.load('training_concatenated_dataset_full.npz')
training_set = training_preprocess_data(train_set_loaded, float16=True, keylist=keylist)

Keys found: ['original', 'label', 'Uniform_Noise', 'Rotate_90deg', 'Ring_Artifact_v1']
Labels shape: (103683, 1)
Images shape: (103683, 224, 224)


## Train Model

In [None]:
# Checkpoints and the final model are written below this folder.
output_path = os.path.join('C11_model', 'results')

# Build a new model from a default ResNetConfig.
config = ResNetConfig()
model = ResNetForMultiLabel(config)

print("Starting training")
trainer = train_model(
    model=model,
    train_dataset=training_set,
    eval_dataset=val_original_set,
    num_epochs=100,
    batch_size=32,
    output_dir=output_path,
)

print("Saving final model")
trainer.save_model(output_path)

In [12]:

# Checkpoints and the final model are written below this folder.
output_path = os.path.join('C10_model_r', 'results')

# Rebuild the training set from the 'original' and 'Rotate_90deg' arrays only.
keylist = ['original', 'Rotate_90deg']
train_set_loaded = np.load('training_concatenated_dataset_full.npz')
training_set = training_preprocess_data(train_set_loaded, float16=True, keylist=keylist)

# Build a new model from a default ResNetConfig.
config = ResNetConfig()
model = ResNetForMultiLabel(config)

print("Starting training")
trainer = train_model(
    model=model,
    train_dataset=training_set,
    eval_dataset=val_original_set,
    num_epochs=100,
    batch_size=32,
    output_dir=output_path,
)

print("Saving final model")
trainer.save_model(output_path)

Keys found: ['original', 'label', 'Uniform_Noise', 'Rotate_90deg', 'Ring_Artifact_v1']
Labels shape: (69122, 1)
Images shape: (69122, 224, 224)
Starting training


  0%|          | 0/216100 [00:00<?, ?it/s]

Could not estimate the number of tokens of the input, floating-point operations will not be computed


{'loss': 1.1112, 'grad_norm': 91.74635314941406, 'learning_rate': 0.099, 'epoch': 1.0}


  0%|          | 0/812 [00:00<?, ?it/s]

{'eval_loss': 0.6746246218681335, 'eval_accuracy': 0.7314743490987521, 'eval_precision': 0.7848419833368225, 'eval_recall': 0.7442682828283833, 'eval_f1': 0.7338933867561092, 'eval_runtime': 15.3194, 'eval_samples_per_second': 423.71, 'eval_steps_per_second': 53.005, 'epoch': 1.0}
{'loss': 0.4656, 'grad_norm': 51.78986358642578, 'learning_rate': 0.098, 'epoch': 2.0}


  0%|          | 0/812 [00:00<?, ?it/s]

{'eval_loss': 1.0267927646636963, 'eval_accuracy': 0.7710676321059929, 'eval_precision': 0.8610781937115337, 'eval_recall': 0.8041851458333054, 'eval_f1': 0.7780737687800126, 'eval_runtime': 9.8318, 'eval_samples_per_second': 660.204, 'eval_steps_per_second': 82.589, 'epoch': 2.0}
{'loss': 0.2781, 'grad_norm': 89.80549621582031, 'learning_rate': 0.097, 'epoch': 3.0}


  0%|          | 0/812 [00:00<?, ?it/s]

{'eval_loss': 0.11012283712625504, 'eval_accuracy': 0.9599445385918964, 'eval_precision': 0.956443224751045, 'eval_recall': 0.9585391379318888, 'eval_f1': 0.957171332609201, 'eval_runtime': 10.092, 'eval_samples_per_second': 643.185, 'eval_steps_per_second': 80.46, 'epoch': 3.0}
{'loss': 0.1934, 'grad_norm': 113.8421401977539, 'learning_rate': 0.096, 'epoch': 4.0}


  0%|          | 0/812 [00:00<?, ?it/s]

{'eval_loss': 0.2521253824234009, 'eval_accuracy': 0.9084886766291789, 'eval_precision': 0.9109704104929105, 'eval_recall': 0.8967134908569391, 'eval_f1': 0.8960745253690695, 'eval_runtime': 10.1539, 'eval_samples_per_second': 639.261, 'eval_steps_per_second': 79.969, 'epoch': 4.0}
{'loss': 0.1362, 'grad_norm': 34.817928314208984, 'learning_rate': 0.095, 'epoch': 5.0}


  0%|          | 0/812 [00:00<?, ?it/s]

{'eval_loss': 0.13106010854244232, 'eval_accuracy': 0.9554768140502233, 'eval_precision': 0.9622655600545343, 'eval_recall': 0.9626704917343627, 'eval_f1': 0.9615730728239417, 'eval_runtime': 9.8994, 'eval_samples_per_second': 655.693, 'eval_steps_per_second': 82.025, 'epoch': 5.0}
{'loss': 0.104, 'grad_norm': 70.72021484375, 'learning_rate': 0.094, 'epoch': 6.0}


  0%|          | 0/812 [00:00<?, ?it/s]

{'eval_loss': 0.1973087638616562, 'eval_accuracy': 0.9536281004467725, 'eval_precision': 0.9585319591742691, 'eval_recall': 0.9546408607505833, 'eval_f1': 0.9530401920855962, 'eval_runtime': 10.1546, 'eval_samples_per_second': 639.216, 'eval_steps_per_second': 79.964, 'epoch': 6.0}
{'loss': 0.0813, 'grad_norm': 123.73053741455078, 'learning_rate': 0.09300000000000001, 'epoch': 7.0}


  0%|          | 0/812 [00:00<?, ?it/s]

{'eval_loss': 0.2792467474937439, 'eval_accuracy': 0.9422276998921584, 'eval_precision': 0.9509828829142264, 'eval_recall': 0.9468301350657872, 'eval_f1': 0.944651514803036, 'eval_runtime': 9.9616, 'eval_samples_per_second': 651.599, 'eval_steps_per_second': 81.513, 'epoch': 7.0}
{'loss': 0.0626, 'grad_norm': 99.90606689453125, 'learning_rate': 0.09200000000000001, 'epoch': 8.0}


  0%|          | 0/812 [00:00<?, ?it/s]

{'eval_loss': 0.1345997005701065, 'eval_accuracy': 0.9690340471421969, 'eval_precision': 0.966896365595315, 'eval_recall': 0.9647403806926342, 'eval_f1': 0.9654872897214735, 'eval_runtime': 10.1235, 'eval_samples_per_second': 641.183, 'eval_steps_per_second': 80.21, 'epoch': 8.0}
{'loss': 0.0509, 'grad_norm': 198.91490173339844, 'learning_rate': 0.09100000000000001, 'epoch': 9.0}


  0%|          | 0/812 [00:00<?, ?it/s]

{'eval_loss': 0.1765868067741394, 'eval_accuracy': 0.9588661223232168, 'eval_precision': 0.9692923835109785, 'eval_recall': 0.9593203767117157, 'eval_f1': 0.963199463797888, 'eval_runtime': 10.183, 'eval_samples_per_second': 637.436, 'eval_steps_per_second': 79.741, 'epoch': 9.0}
{'loss': 0.0409, 'grad_norm': 251.77304077148438, 'learning_rate': 0.09000000000000001, 'epoch': 10.0}


  0%|          | 0/812 [00:00<?, ?it/s]

{'eval_loss': 0.225211039185524, 'eval_accuracy': 0.9488522569711909, 'eval_precision': 0.95364185661765, 'eval_recall': 0.9463694220507591, 'eval_f1': 0.9466562913687298, 'eval_runtime': 10.9362, 'eval_samples_per_second': 593.533, 'eval_steps_per_second': 74.249, 'epoch': 10.0}
{'loss': 0.0292, 'grad_norm': 136.08534240722656, 'learning_rate': 0.08900000000000001, 'epoch': 11.0}


  0%|          | 0/812 [00:00<?, ?it/s]

{'eval_loss': 0.11855398863554001, 'eval_accuracy': 0.9750423663534125, 'eval_precision': 0.9765433129768973, 'eval_recall': 0.9754954282706817, 'eval_f1': 0.9754694841584239, 'eval_runtime': 11.2946, 'eval_samples_per_second': 574.7, 'eval_steps_per_second': 71.893, 'epoch': 11.0}
{'loss': 0.023, 'grad_norm': 103.93693542480469, 'learning_rate': 0.08800000000000001, 'epoch': 12.0}


  0%|          | 0/812 [00:00<?, ?it/s]

{'eval_loss': 0.15797245502471924, 'eval_accuracy': 0.9639500847327068, 'eval_precision': 0.9622593185743366, 'eval_recall': 0.9657963943101393, 'eval_f1': 0.9635249687419383, 'eval_runtime': 11.0752, 'eval_samples_per_second': 586.082, 'eval_steps_per_second': 73.317, 'epoch': 12.0}
{'loss': 0.0169, 'grad_norm': 206.48959350585938, 'learning_rate': 0.08700000000000001, 'epoch': 13.0}


  0%|          | 0/812 [00:00<?, ?it/s]

{'eval_loss': 0.16111551225185394, 'eval_accuracy': 0.9619473116623016, 'eval_precision': 0.965153542694564, 'eval_recall': 0.9646992919280997, 'eval_f1': 0.9644123286525759, 'eval_runtime': 11.1228, 'eval_samples_per_second': 583.575, 'eval_steps_per_second': 73.003, 'epoch': 13.0}
{'loss': 0.015, 'grad_norm': 123.8583984375, 'learning_rate': 0.08600000000000001, 'epoch': 14.0}


  0%|          | 0/812 [00:00<?, ?it/s]

{'eval_loss': 0.19145958125591278, 'eval_accuracy': 0.9607148359266677, 'eval_precision': 0.9693130513649674, 'eval_recall': 0.9692820036856183, 'eval_f1': 0.9685222678591859, 'eval_runtime': 10.9343, 'eval_samples_per_second': 593.637, 'eval_steps_per_second': 74.262, 'epoch': 14.0}
{'loss': 0.0126, 'grad_norm': 52.01276779174805, 'learning_rate': 0.085, 'epoch': 15.0}


  0%|          | 0/812 [00:00<?, ?it/s]

{'eval_loss': 0.0962800607085228, 'eval_accuracy': 0.9813588044985364, 'eval_precision': 0.9789538844354078, 'eval_recall': 0.9803740652462892, 'eval_f1': 0.9794225167791953, 'eval_runtime': 10.9253, 'eval_samples_per_second': 594.125, 'eval_steps_per_second': 74.323, 'epoch': 15.0}
{'loss': 0.0101, 'grad_norm': 196.5611114501953, 'learning_rate': 0.084, 'epoch': 16.0}


  0%|          | 0/812 [00:00<?, ?it/s]

{'eval_loss': 0.10940303653478622, 'eval_accuracy': 0.9808966260976737, 'eval_precision': 0.9787273830155805, 'eval_recall': 0.9797999022632978, 'eval_f1': 0.9790548968613166, 'eval_runtime': 10.9826, 'eval_samples_per_second': 591.026, 'eval_steps_per_second': 73.935, 'epoch': 16.0}
{'loss': 0.0072, 'grad_norm': 5.167089462280273, 'learning_rate': 0.083, 'epoch': 17.0}


  0%|          | 0/812 [00:00<?, ?it/s]

{'eval_loss': 0.11583942174911499, 'eval_accuracy': 0.9764289015560006, 'eval_precision': 0.9730838604474129, 'eval_recall': 0.9748029335942451, 'eval_f1': 0.973136268919475, 'eval_runtime': 10.9142, 'eval_samples_per_second': 594.727, 'eval_steps_per_second': 74.398, 'epoch': 17.0}
{'loss': 0.0068, 'grad_norm': 130.91151428222656, 'learning_rate': 0.082, 'epoch': 18.0}


  0%|          | 0/812 [00:00<?, ?it/s]

{'eval_loss': 0.1582583785057068, 'eval_accuracy': 0.97350177168387, 'eval_precision': 0.9761296679016738, 'eval_recall': 0.9736894699913216, 'eval_f1': 0.9738427356319597, 'eval_runtime': 11.0294, 'eval_samples_per_second': 588.518, 'eval_steps_per_second': 73.621, 'epoch': 18.0}
{'loss': 0.0058, 'grad_norm': 18.852920532226562, 'learning_rate': 0.08100000000000002, 'epoch': 19.0}


  0%|          | 0/812 [00:00<?, ?it/s]

{'eval_loss': 0.09180068224668503, 'eval_accuracy': 0.9812047450315822, 'eval_precision': 0.9808044651657446, 'eval_recall': 0.9812886013480759, 'eval_f1': 0.980909945146178, 'eval_runtime': 10.8425, 'eval_samples_per_second': 598.665, 'eval_steps_per_second': 74.891, 'epoch': 19.0}
{'loss': 0.0056, 'grad_norm': 122.7890396118164, 'learning_rate': 0.08000000000000002, 'epoch': 20.0}


  0%|          | 0/812 [00:00<?, ?it/s]

{'eval_loss': 0.29844722151756287, 'eval_accuracy': 0.9559389924510862, 'eval_precision': 0.9620092851264609, 'eval_recall': 0.9594227095428064, 'eval_f1': 0.958194698238243, 'eval_runtime': 10.8603, 'eval_samples_per_second': 597.68, 'eval_steps_per_second': 74.768, 'epoch': 20.0}
{'loss': 0.0048, 'grad_norm': 108.02734375, 'learning_rate': 0.07900000000000001, 'epoch': 21.0}


  0%|          | 0/812 [00:00<?, ?it/s]

{'eval_loss': 0.10927681624889374, 'eval_accuracy': 0.9816669234324449, 'eval_precision': 0.9797200756877497, 'eval_recall': 0.979774293351571, 'eval_f1': 0.9796015355123772, 'eval_runtime': 11.1527, 'eval_samples_per_second': 582.012, 'eval_steps_per_second': 72.808, 'epoch': 21.0}
{'loss': 0.0036, 'grad_norm': 141.32801818847656, 'learning_rate': 0.07800000000000001, 'epoch': 22.0}


  0%|          | 0/812 [00:00<?, ?it/s]

{'eval_loss': 0.08976449817419052, 'eval_accuracy': 0.9859805885071637, 'eval_precision': 0.9840075110228106, 'eval_recall': 0.9843214910013546, 'eval_f1': 0.9840804937042453, 'eval_runtime': 11.0078, 'eval_samples_per_second': 589.675, 'eval_steps_per_second': 73.766, 'epoch': 22.0}
{'loss': 0.0027, 'grad_norm': 4.436644554138184, 'learning_rate': 0.07700000000000001, 'epoch': 23.0}


  0%|          | 0/812 [00:00<?, ?it/s]

{'eval_loss': 0.102291040122509, 'eval_accuracy': 0.9781235556924973, 'eval_precision': 0.9753354689212742, 'eval_recall': 0.9769817111722042, 'eval_f1': 0.97591745649537, 'eval_runtime': 11.0411, 'eval_samples_per_second': 587.894, 'eval_steps_per_second': 73.543, 'epoch': 23.0}
{'loss': 0.0026, 'grad_norm': 0.000385869963793084, 'learning_rate': 0.07600000000000001, 'epoch': 24.0}


  0%|          | 0/812 [00:00<?, ?it/s]

{'eval_loss': 0.12896493077278137, 'eval_accuracy': 0.9782776151594516, 'eval_precision': 0.9784043403110506, 'eval_recall': 0.9784245366265657, 'eval_f1': 0.978180281843127, 'eval_runtime': 10.8147, 'eval_samples_per_second': 600.204, 'eval_steps_per_second': 75.083, 'epoch': 24.0}
{'loss': 0.0024, 'grad_norm': 0.012109609320759773, 'learning_rate': 0.07500000000000001, 'epoch': 25.0}


  0%|          | 0/812 [00:00<?, ?it/s]

{'eval_loss': 0.18925584852695465, 'eval_accuracy': 0.9701124634108766, 'eval_precision': 0.9712731234959201, 'eval_recall': 0.9717825268491033, 'eval_f1': 0.9701752664648481, 'eval_runtime': 11.2321, 'eval_samples_per_second': 577.896, 'eval_steps_per_second': 72.293, 'epoch': 25.0}


KeyboardInterrupt: 

In [3]:

# Select the compute device once and report which one was chosen.
has_cuda = torch.cuda.is_available()
if not has_cuda:
    print("CUDA is not available. Using CPU.")
    device = torch.device("cpu")
else:
    print("CUDA is available!")
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
    device = torch.device("cuda")

CUDA is available!
Using GPU: NVIDIA GeForce RTX 3090
