In [5]:
import torch
from transformers import ResNetConfig
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split
import transformers
from torchvision import transforms
from branched_resnet import ResNetForMultiLabel
from branched_resnet import OrganAMNISTDataset, compute_metrics, train_model
import random
import numpy as np
import os 

In [6]:
#SET SEEDS
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

<torch._C.Generator at 0x1df280ad530>

#### Import NPZ by concatenating

In [9]:
import os

def import_data(directory, save_path=None, save=False):

    data = []
    for filename in os.listdir(directory):
        if filename.endswith('.npz'):
            file_path = os.path.join(directory, filename)
            loaded_data = np.load(file_path)
            data.append(loaded_data)

    # concatenate the data from all files
    all_data = {}
    for key in data[0].keys():
        all_data[key] = np.concatenate([d[key] for d in data], axis=0)

    # check the shape of the concatenated data
    for key, value in all_data.items():
        print(f"{key}: {value.shape}")  

    if save:
        if save_path is None:
            save_path = f'datasets/{directory}_concatenated_data.npz'
        np.savez(save_path, **all_data)
        print(f"Data saved to {save_path}")

    return all_data

### Image normalizer 

In [10]:
def normalize_image(image, mean=0.5, std=0.5):
    """
    Normalize an image tensor to have a mean and standard deviation.
    """
    return (image - mean) / std

def normalize_images(images, mean=0.5, std=0.5):
    """
    Normalize a list of images.
    """
    return [normalize_image(image, mean, std) for image in images]

### Dataset Loader

In [16]:
#Modified CustomImageDataset Loader

class ModifiedCustomImageDataset(Dataset):
    def __init__(self, images, labels1,  transform=None):
        self.images = images 
        self.labels1 = labels1

        if transform is None:
            self.transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Lambda(lambda x: x.repeat(3, 1, 1))  # Grayscale to 3-channel
            ])
        else:
            self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx].astype(np.float32)
        label1 = int(self.labels1[idx])


        if self.transform:
            img = self.transform(img)

        return {
            "pixel_values": img,
            "labels": int(label1) if torch.is_tensor(label1) else label1,
        }
        


#### Preprocessing functions for single distortion and multi.

In [None]:
#Modified

def validation_preprocess_data(data, key='original'):
    """
    Modified from sam's implementation to preprocess 1 set of images. 
    
    key : str (distortion name)
    """
    keys = data.files 
    print(f'\nGenerating {key} validataion set')
    
    images = data[key]
    labels = data['label']
    
    normalized_images = []
    for image in images:
        normalized_images.append(normalize_images(image))
        
    labels = np.array(labels)
    normalized_images = np.array(normalized_images)
    
    
    print(f"Labels shape: {labels.shape}")
    print(f"Images shape: {normalized_images.shape}")

    dataset = ModifiedCustomImageDataset(images=normalized_images, labels1=labels)

    return dataset


#### Preprocess data into distinct sets

In [None]:
val_set_loaded = np.load('val_concatenated_data.npz')
key_list = [key for key in val_set_loaded.files if key != 'label']
print(key_list)

val_rotate_set = validation_preprocess_data(val_set_loaded, 'Rotate_90deg')
val_original_set = validation_preprocess_data(val_set_loaded, 'original')
val_noise_set = validation_preprocess_data(val_set_loaded, 'Uniform_Noise')




train_set_loaded = np.load('training_concatenated_data.npz')
train_rotate_set = validation_preprocess_data(train_set_loaded, 'Rotate_90deg')
train_original_set = validation_preprocess_data(train_set_loaded, 'original')
train_noise_set = validation_preprocess_data(train_set_loaded, 'Uniform_Noise')




['original', 'Uniform_Noise', 'Rotate_90deg']

Generating Rotate_90deg validataion set
Labels shape: (1, 6491, 1)
Images shape: (1, 6491, 224, 224)

Generating original validataion set
Labels shape: (1, 6491, 1)
Images shape: (1, 6491, 224, 224)

Generating Uniform_Noise validataion set
Labels shape: (1, 6491, 1)
Images shape: (1, 6491, 224, 224)


## Train Model

In [None]:
output_path = os.path.join('C_model', 'results')


config = ResNetConfig()
model = ResNetForMultiLabel(config)

print("Starting training")
trainer = train_model(
    train_dataset=train_original_set,
    eval_dataset=val_original_set,
    model=model,
    output_dir=output_path,  # Checkpoints will go here
    num_epochs=100,
    batch_size=32
)

print("Saving final model")
trainer.save_model(output_path)

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).


Starting training


***** Running training *****
  Num examples = 103,683
  Num Epochs = 100
  Instantaneous batch size per device = 32
  Total train batch size (w. parallel, distributed & accumulation) = 32
  Gradient Accumulation steps = 1
  Total optimization steps = 324,100
  Number of trainable parameters = 24,562,763


  0%|          | 0/324100 [00:00<?, ?it/s]

Could not estimate the number of tokens of the input, floating-point operations will not be computed


In [2]:

if torch.cuda.is_available():
    print("✅ CUDA is available!")
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
    device = torch.device("cuda")
else:
    print("❌ CUDA is NOT available. Using CPU.")
    device = torch.device("cpu")

✅ CUDA is available!
Using GPU: NVIDIA GeForce RTX 3090
