In [1]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import time
import os
from datasets.amos import AmosDataset
import numpy as np
from PIL import Image

def _amos_transform(image_size=256, normalize=True):
    """Build the AMOS preprocessing pipeline: ToTensor -> Resize -> CenterCrop [-> Normalize].

    ToTensor maps pixel values to [0, 1]; Normalize(0.5, 0.5) then maps them to [-1, 1].
    """
    steps = [
        transforms.ToTensor(),
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
    ]
    if normalize:
        steps.append(transforms.Normalize(0.5, 0.5))
    return transforms.Compose(steps)


def load_amos(path, image_size=256, index_range=range(0, 500), normalize_val=True):
    """Load the AMOS train and validation splits with matching preprocessing.

    Args:
        path: Root directory passed through to ``AmosDataset``.
        image_size: Side length used for both Resize and CenterCrop.
        index_range: Forwarded as ``index_range`` to ``AmosDataset`` for both splits;
            its exact meaning is defined by ``AmosDataset``.
        normalize_val: If True (default), the validation split gets the same
            Normalize(0.5, 0.5) as the training split, so both are in [-1, 1].
            Set to False to keep validation images un-normalized in [0, 1].

    Returns:
        ``(train, val)`` tuple of ``AmosDataset`` instances.
    """
    train = AmosDataset(path, split='train',
                        transform=_amos_transform(image_size, normalize=True),
                        index_range=index_range)
    # Validation must see the same value range as training, otherwise
    # evaluation runs on inputs the model never saw during training.
    val = AmosDataset(path, split='val',
                      transform=_amos_transform(image_size, normalize=normalize_val),
                      index_range=index_range)
    return train, val
    


def data_loaders(train_data, val_data, batch_size, num_workers=0,
                 shuffle_val=True, collate_fn=None):
    """Wrap the train and validation datasets in pinned-memory DataLoaders.

    Args:
        train_data: Training dataset; always shuffled.
        val_data: Validation dataset.
        batch_size: Batch size for both loaders.
        num_workers: Worker processes per loader (0 = load in the main process,
            the DataLoader default).
        shuffle_val: Whether to shuffle the validation loader. Set to False for a
            deterministic validation order.
        collate_fn: Optional batch-collation function passed to both loaders
            (``None`` uses torch's default collate).
            NOTE(review): ``AmosDataset`` items appear to be (tensor, PIL image)
            pairs, and the default collate cannot batch PIL images. Iterating
            these loaders likely needs a custom ``collate_fn``; confirm.

    Returns:
        ``(train_loader, val_loader)``.
    """
    def make_loader(dataset, shuffle):
        # Single place for options shared by both loaders.
        return DataLoader(dataset,
                          batch_size=batch_size,
                          shuffle=shuffle,
                          num_workers=num_workers,
                          collate_fn=collate_fn,
                          pin_memory=True)

    train_loader = make_loader(train_data, shuffle=True)
    val_loader = make_loader(val_data, shuffle=shuffle_val)
    return train_loader, val_loader

def calc_x_train_var(train_data, scale=1.0):
    """Population variance (ddof=0, like ``np.var``) of all pixel values in a dataset.

    Streams over ``train_data`` once and merges per-image mean and sum of squared
    deviations in float64. This avoids materializing every pixel in one array and
    stays numerically stable.

    Args:
        train_data: Iterable of ``(image, target)`` pairs; ``image`` must be
            convertible with ``np.asarray`` (e.g. a CPU tensor, ndarray or PIL image).
        scale: Each pixel is divided by this before the variance is taken. The
            default 1.0 is correct for already-transformed tensors such as those
            produced by ``load_amos`` (values in [-1, 1]). Use 255.0 only for raw
            uint8 pixel data.

    Returns:
        The variance as a float64 scalar.

    Raises:
        ValueError: If ``train_data`` yields no pixels.
    """
    print('Calculating x_train_var')
    count = 0
    mean = 0.0
    m2 = 0.0  # running sum of squared deviations from the mean
    for img, _ in train_data:
        pixels = np.asarray(img, dtype=np.float64).ravel() / scale
        n_img = pixels.size
        if n_img == 0:
            continue
        mean_img = pixels.mean()
        m2_img = np.square(pixels - mean_img).sum()
        # Chan et al. pairwise merge of (count, mean, M2) statistics.
        delta = mean_img - mean
        total = count + n_img
        mean += delta * n_img / total
        m2 += m2_img + delta * delta * count * n_img / total
        count = total
    if count == 0:
        raise ValueError('train_data contains no pixels; cannot compute variance')
    return m2 / count
        

def load_data_and_data_loaders(path, batch_size, return_x_train_var=False):
    """Load the AMOS splits and wrap them in DataLoaders.

    Args:
        path: Dataset root passed to ``load_amos``.
        batch_size: Batch size for both loaders.
        return_x_train_var: If True, also compute the training-set pixel variance
            via ``calc_x_train_var``. This requires a full pass over the training data.

    Returns:
        ``(training_data, validation_data, training_loader, validation_loader)``,
        followed by ``x_train_var`` when ``return_x_train_var`` is True.
    """
    training_data, validation_data = load_amos(path)
    training_loader, validation_loader = data_loaders(
        training_data, validation_data, batch_size)

    if not return_x_train_var:
        return training_data, validation_data, training_loader, validation_loader

    # Opt-in because it iterates the entire training set.
    x_train_var = calc_x_train_var(training_data)
    return training_data, validation_data, training_loader, validation_loader, x_train_var

In [2]:
# Dataset root: override with the AMOS_DATA_DIR environment variable instead of
# editing the user-specific absolute default below.
DATA_DIR = os.environ.get('AMOS_DATA_DIR', '/vol/aimspace/users/hunecke/diffusion/data/amos_slices/')
BATCH_SIZE = 16
tr, va, trl, val = load_data_and_data_loaders(DATA_DIR, BATCH_SIZE)

Loading Amos train data
Loaded 26069 train images
Loading Amos val data
Loaded 15361 val images


In [5]:
# Rebinds `trl` from the training DataLoader to a single-pass iterator (presumably so
# batches can be pulled with next(trl)). The DataLoader itself is no longer reachable
# under this name; re-run the loading cell to get a fresh loader / new epoch.
trl = iter(trl)

In [9]:
# Inspect the first training sample. Per the output below it is a pair of
# (transformed image tensor, PIL image of mode L, 512x512).
tr[0]

(tensor([[[-0.9864, -0.9840, -0.9868,  ..., -0.9807, -0.9875, -0.9830],
          [-0.9866, -0.9884, -0.9857,  ..., -0.9836, -0.9876, -0.9815],
          [-0.9836, -0.9884, -0.9817,  ..., -0.9873, -0.9839, -0.9866],
          ...,
          [-0.9368, -0.9381, -0.9355,  ..., -0.9192, -0.9240, -0.9374],
          [-0.9377, -0.9315, -0.9284,  ..., -0.9344, -0.9275, -0.9324],
          [-0.9320, -0.9199, -0.9301,  ..., -0.9380, -0.9394, -0.9433]]]),
 <PIL.PngImagePlugin.PngImageFile image mode=L size=512x512>)

In [3]:
# Minimum value of the first channel of sample 100's image tensor. After
# Normalize(0.5, 0.5) the lowest possible value is -1.
tr[100][0][0].min()

tensor(-1.)

In [4]:
# Show the first channel of sample 100's image tensor. The visible values are all -1
# (the normalized minimum), presumably an empty/background slice; confirm.
tr[100][0][0]


tensor([[-1., -1., -1.,  ..., -1., -1., -1.],
        [-1., -1., -1.,  ..., -1., -1., -1.],
        [-1., -1., -1.,  ..., -1., -1., -1.],
        ...,
        [-1., -1., -1.,  ..., -1., -1., -1.],
        [-1., -1., -1.,  ..., -1., -1., -1.],
        [-1., -1., -1.,  ..., -1., -1., -1.]])