# Dataset tests

In [8]:
from torchvision import datasets, transforms
import torch

# MNIST

In [9]:
# Preprocessing applied to every image: convert to a tensor, then normalise
# the single channel with mean 0.1307 and std 0.3081.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

# Both splits live under 'example_data' and are downloaded on demand.
# train=True  -> `dataset`
# train=False -> `val_set`
dataset, val_set = (
    datasets.MNIST('example_data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Report the split sizes, held-out split first.
print(len(val_set), len(dataset), sep="\n")

10000
60000


# Labels

In [10]:
# Four class indices packed into a single row of shape (1, 4).
labels = torch.tensor([1, 0, 5, 2]).unsqueeze(0)

# Start from an all-zero (rows, 10) matrix and write 1.0 at every column index
# listed in `labels`. With a single row this yields one multi-hot vector.
target = torch.zeros(labels.size(0), 10)
target.scatter_(1, labels, 1.)

print(target)
print(labels)

tensor([[1., 1., 1., 0., 0., 1., 0., 0., 0., 0.]])
tensor([[1, 0, 5, 2]])


# OCT MNIST

In [18]:
# ---------------------------------------------------------------------------
# Model configuration
# ---------------------------------------------------------------------------
model_kwargs = {
    # -- architecture --
    # Not in use yet.
    'in_channels': 1,
    # Filled in by the dataset.
    'n_classes': None,
    # Widths for: entry, decent1, decent2, decent3.
    # Alternative tried: [1, 16, 24, 32].
    'out_dim': [1, 8, 16, 32],
    'grid_size': 18 * 18,

    # -- loss / optimisation --
    # Alternative: torch.nn.BCEWithLogitsLoss().
    'criterion': torch.nn.CrossEntropyLoss(),
    # Enables the new connection-cost loss term.
    'new_cc_mode': True,
    'reset_optimiser_at_update': True,
    # Options: "sgd", "adamw".
    'optimizer': "sgd",
    'base_lr': 0.001,
    'min_lr': 0.00001,
    'momentum': 0.9,
    'lr_update': 100,

    # -- decentnet --
    # High weight, because the connection cost does not change a lot.
    'cc_weight': 5,
    # Connection-cost metric (distance metric used in the loss).
    # NOTE(review): the author was unsure how the torch variant behaves.
    'cc_metric': 'l2_torch',
    # Channel-importance metric (used for pruning). TODO: should be l2.
    'ci_metric': 'l2',
    # Crossing minimisation; 'count' is the planned value.
    'cm_metric': 'not implemented yet',
    # Previously 5.
    'update_every_nth_epoch': 3,
    # Previously 20.
    'pretrain_epochs': 15,
    # Applied in each epoch; previously 0.97.
    'prune_keep': 0.95,
    # Not exact - the effective value depends on 'prune_keep'.
    'prune_keep_total': 0.4,
}

# ---------------------------------------------------------------------------
# Training / data configuration
# ---------------------------------------------------------------------------
train_kwargs = {
    # -- paths and experiment bookkeeping --
    'input_data_csv': ["data_prep/data_octa500.csv"],
    # Possibly unused; alternative: "example_results/lightning_logs".
    'result_path': "examples/example_results",
    # Must include the dataset name, otherwise mnist is used.
    'exp_name': "every3_95",
    # Checkpoint used when loading a file for "test"; keep None for training.
    # Earlier checkpoints:
    #   "version_0/checkpoints/epoch=94-unpruned=1600-val_f1=0.67.ckpt"
    #   'version_94/checkpoints/epoch=26-step=1080.ckpt'
    'load_ckpt_file': "version_16/checkpoints/mu_epoch=8-val_f1_macro=0.64-unpruned=2373.ckpt",
    # True or False.
    'load_mode': False,

    # -- data --
    'dataset': 'octmnist',
    # Total epoch count; it already includes the pretrain epochs.
    'epochs': 40,
    # MNIST / MedMNIST images are 28 x 28 pixels. Keep the original size:
    # training did not work when the size was increased (168 was tried).
    'img_size': 28,
    # Probability of the torchvision training transforms (does not apply to
    # all transforms): 0.1 low, 0.5 half, 1 always.
    'p_augment': 0.2,

    # -- runtime --
    # Laptop: 2, PC: 128. A higher batch size trains faster, since every
    # iteration adds a lot of computational cost.
    'batch_size': 1,
    # Lightning default is 50. Author's note: needs to be bigger than the
    # number of steps in an epoch (train-set size / batch size).
    'log_every_n_steps': 50,
    'device': "cuda",
    # 18 on the "seri" computer, 0 or 8 on the laptop. Keep it smaller than
    # the active dataset sizes.
    'num_workers': 0,

    # -- subset sizes: total sample count (batch size * forward passes per
    #    epoch); 0 = none, -1 = all --
    # Set to 0 to skip training and only run testing.
    'train_size': -1,
    'val_size': -1,
    'test_size': -1,

    # Not in use - preselected data from a csv is used instead.
    'octa500_id': 200 - 1,
    # DO NOT CHANGE, WILL BE CHANGED IN CODE.
    'xai_done': False,
}

print("train kwargs", train_kwargs)
print("model kwargs", model_kwargs)

# Bundle both configuration dicts under their own names.
kwargs = dict(train_kwargs=train_kwargs, model_kwargs=model_kwargs)

train kwargs {'input_data_csv': ['data_prep/data_octa500.csv'], 'result_path': 'examples/example_results', 'exp_name': 'every3_95', 'load_ckpt_file': 'version_16/checkpoints/mu_epoch=8-val_f1_macro=0.64-unpruned=2373.ckpt', 'load_mode': False, 'dataset': 'octmnist', 'epochs': 40, 'img_size': 28, 'p_augment': 0.2, 'batch_size': 1, 'log_every_n_steps': 50, 'device': 'cuda', 'num_workers': 0, 'train_size': -1, 'val_size': -1, 'test_size': -1, 'octa500_id': 199, 'xai_done': False}
model kwargs {'in_channels': 1, 'n_classes': None, 'out_dim': [1, 8, 16, 32], 'grid_size': 324, 'criterion': CrossEntropyLoss(), 'new_cc_mode': True, 'reset_optimiser_at_update': True, 'optimizer': 'sgd', 'base_lr': 0.001, 'min_lr': 1e-05, 'momentum': 0.9, 'lr_update': 100, 'cc_weight': 5, 'cc_metric': 'l2_torch', 'ci_metric': 'l2', 'cm_metric': 'not implemented yet', 'update_every_nth_epoch': 3, 'pretrain_epochs': 15, 'prune_keep': 0.95, 'prune_keep_total': 0.4}


In [22]:
import sys

import torch

# Make the project's helper package importable. The membership check keeps
# sys.path from growing when this cell is re-run.
if "../helper" not in sys.path:
    sys.path.insert(0, "../helper")
from data.octmnist import DataLoaderOCTMNIST

# Build the OCTMNIST loaders from the configuration dicts defined above.
dataloader = DataLoaderOCTMNIST(train_kwargs, model_kwargs)

# Compute inverse-frequency class weights from the training labels.
# Assumes labels are integers from 0 to num_classes-1.
all_labels = []
# Extract all labels from the training DataLoader
for inputs, labels in dataloader.train_dataloader:
    # Flatten to 1-D regardless of batch size. `squeeze(0)` only flattened
    # correctly for batch_size == 1; larger batches stayed 2-D and broke
    # `torch.bincount`, which requires a 1-D input.
    all_labels.append(labels.reshape(-1))
# Concatenate all labels into a single 1-D tensor
all_labels = torch.cat(all_labels)
# Sorting is not needed for the counts below; the sorted view is kept so the
# names stay available to later cells.
sorted_labels, sorted_indices = torch.sort(all_labels)
# Count the occurrences of each class
class_counts = torch.bincount(sorted_labels)
# Calculate weights (inverse of class frequency). Counts are clamped to at
# least 1 so a class with no samples cannot yield inf, which would turn every
# normalised weight into NaN.
class_weights = 1.0 / class_counts.float().clamp(min=1)
# Normalize weights so they sum to 1
class_weights = class_weights / class_weights.sum()
print("class_counts", class_counts, "class_weights:", class_weights)

# Pass the weights to CrossEntropyLoss. The original passed the undefined
# name `weights`, which raised NameError.
# NOTE(review): this does not update model_kwargs['criterion'], and the weight
# tensor stays on its current device - move it if the model runs on "cuda".
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)



Using downloaded and verified file: C:\Users\Prinzessin\.medmnist\octmnist.npz
Using downloaded and verified file: C:\Users\Prinzessin\.medmnist\octmnist.npz
Using downloaded and verified file: C:\Users\Prinzessin\.medmnist\octmnist.npz
********** DECENT INFO: DataLoader infos **********
python_class : OCTMNIST
description : The OCTMNIST is based on a prior dataset of 109,309 valid optical coherence tomography (OCT) images for retinal diseases. The dataset is comprised of 4 diagnosis categories, leading to a multi-class classification task. We split the source training set with a ratio of 9:1 into training and validation set, and use its source validation set as the test set. The source images are gray-scale, and their sizes are (384−1,536)×(277−512). We center-crop the images and resize them into 1×28×28.
url : https://zenodo.org/records/10519652/files/octmnist.npz?download=1
MD5 : c68d92d5b585d8d81f7112f81e2d0842
task : multi-class
label : {'0': 'cnv', '1': 'dme', '2': 'drusen', '3':