In [1]:
import torch
import torchvision
import torchvision.transforms as transforms

from package.dataset import DicomDataset

# Per-channel mean and standard deviation of the training set

In [2]:
def get_mean_and_std(dataloader):
    """Compute the per-channel mean and standard deviation of a dataset.

    Iterates once over ``dataloader``, which must yield ``(data, target)``
    pairs. Each ``data`` tensor is reduced over dims 0, 2 and 3, leaving one
    statistic per entry of dim 1 (the channel dim of an N x C x H x W batch).

    The statistics are accumulated as element-weighted sums rather than as an
    average of per-batch means, so a smaller final batch is not over-weighted
    and the result does not depend on how samples are grouped into batches.

    Args:
        dataloader: iterable of ``(data, target)`` pairs; ``target`` is ignored.
            ``data`` may be a floating-point or an integer tensor.

    Returns:
        ``(mean, std)``: two 1-D tensors with one value per channel. ``std``
        is the population standard deviation ``sqrt(E[X^2] - (E[X])^2)``.
        The dtype matches the first batch when it is floating point, and is
        the default float dtype otherwise.

    Raises:
        ValueError: if ``dataloader`` yields no elements at all.
    """
    channels_sum, channels_squared_sum, num_elements = 0, 0, 0
    out_dtype = None
    for data, _ in dataloader:
        if out_dtype is None:
            out_dtype = (data.dtype if data.is_floating_point()
                         else torch.get_default_dtype())
        # float64 accumulation: accepts integer pixel data and limits the
        # cancellation error in E[X^2] - (E[X])^2 for large pixel values.
        data = data.to(torch.float64)
        # Sum over batch, height and width, but not over the channels
        channels_sum += data.sum(dim=[0, 2, 3])
        channels_squared_sum += (data ** 2).sum(dim=[0, 2, 3])
        # Number of values contributing to each channel in this batch
        num_elements += data.shape[0] * data.shape[2] * data.shape[3]

    if num_elements == 0:
        raise ValueError("dataloader yielded no data; cannot compute mean and std")

    mean = channels_sum / num_elements

    # var = E[X^2] - (E[X])^2, clamped because rounding can push a
    # near-zero variance slightly negative, which would give a NaN std.
    var = (channels_squared_sum / num_elements - mean ** 2).clamp(min=0)
    std = var ** 0.5

    return mean.to(out_dtype), std.to(out_dtype)

In [3]:
# Number of samples per batch; passed to the DataLoader built in a later cell.
batch_size = 16

In [4]:
# Transform pipeline applied to each sample: centre crop with size 50,
# followed by a resize with size 224.
preprocess = transforms.Compose(
    [
        transforms.CenterCrop(50),
        transforms.Resize(224),
    ]
)

# Dataset read from ./data with the pipeline above, wrapped in a shuffling
# loader that uses two worker processes.
trainset = DicomDataset(root='./data', transform=preprocess)
trainloader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

In [5]:
# One pass over the training loader to get the per-channel statistics,
# then print both tensors.
mean, std = get_mean_and_std(trainloader)
print(mean, std)

tensor([62.2852, 62.2852, 62.2852]) tensor([76.8448, 76.8448, 76.8448])
