In [1]:
import os
from typing import Optional

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms

In [2]:
class RotatedMNIST(Dataset):
    '''
    MNIST split in which every sample is assigned a random rotation domain,
    together with a combined conditioning label for a cINN.

    Each sample receives one angle from ``domains`` (degrees), drawn uniformly
    at random. Its cINN label is ``[cos(angle), sin(angle)]`` concatenated with
    the one-hot encoding of its digit class.

    NOTE(review): the images themselves are not rotated yet, ``val_set_size``
    is stored but unused, and ``__getitem__`` is not implemented, so the
    dataset cannot be indexed/iterated by a DataLoader yet.

    Args:
        domains: Rotation angles in degrees; must contain at least one entry.
        train: Load the MNIST training split (True) or the test split (False).
        val_set_size: Intended size of a validation split (currently unused).
        seed: If given, seeds torch's *global* RNG via ``torch.manual_seed`` so
            that the shuffle and the domain sampling are reproducible.
        root: Directory where MNIST is stored/downloaded. Defaults to the
            current working directory, because ``__file__`` is undefined in
            notebooks.

    Raises:
        ValueError: If ``domains`` is empty.
    '''

    def __init__(self, domains: list[int], train: bool, val_set_size: int = 0,
                 seed: Optional[int] = None, root: str = ".") -> None:
        super().__init__()

        if len(domains) == 0:
            # Fail before downloading anything; otherwise randint_like(..., 0)
            # would raise an opaque RuntimeError much later.
            raise ValueError("domains must contain at least one rotation angle")

        # Set attributes
        self.domains = torch.tensor(domains)
        self.train = train
        self.val_set_size = val_set_size  # TODO: not used yet
        self.root = os.path.abspath(root)

        if seed is not None:
            # Seeds the global RNG: this drives both the DataLoader shuffle and
            # the domain sampling in _process_data (and any later torch randomness).
            torch.manual_seed(seed)
        self.classes = torch.arange(10)

        # Download (if needed) and load the requested split; ToTensor maps
        # uint8 pixels (0..255) to float32 (0..1).
        self.mnist = datasets.MNIST(root=self.root, train=self.train, download=True,
                                    transform=transforms.ToTensor())
        self.data, self.targets = self._process_data(self.mnist)

    def __len__(self) -> int:
        '''Number of samples in the loaded split.'''
        return len(self.data)

    def _process_data(self, dataset):
        '''
        Shuffle ``dataset`` and derive the domain and cINN labels.

        Besides returning the shuffled images and class labels, this stores the
        intermediate tensors (images, class_labels, domain_count,
        domain_indeces, domains_sincos, cinn_labels, domain_labels) on ``self``
        so the notebook tests can inspect them.

        Args:
            dataset: Map-style dataset yielding ``(image, label)`` pairs, where
                images are presumably ``(1, H, W)`` tensors as produced by
                ``transforms.ToTensor()`` — TODO confirm for other datasets.

        Returns:
            ``(images, class_labels)`` — the shuffled images with the channel
            dimension removed, and the matching integer class labels.
        '''
        # Load the whole split as one shuffled batch.
        loader = DataLoader(dataset, batch_size=len(dataset), shuffle=True)
        images, class_labels = next(iter(loader))
        images = images.squeeze(1)  # drop the channel dim: (N, 1, H, W) -> (N, H, W)
        self.images = images
        self.class_labels = class_labels

        # Assign each sample a domain index drawn uniformly from self.domains.
        domain_count = len(self.domains)
        domain_indeces = torch.randint_like(class_labels, domain_count)
        self.domain_count = domain_count
        self.domain_indeces = domain_indeces

        # Encode each angle as (cos, sin), so angles differing by 360 degrees
        # map to the same code; compute in float64, store as float32.
        radians = torch.deg2rad(self.domains.to(torch.float64))
        domains_sincos = torch.stack((torch.cos(radians), torch.sin(radians)), dim=1).to(torch.float32)
        classes_onehot = torch.eye(len(self.classes))
        cinn_labels = torch.cat((domains_sincos[domain_indeces], classes_onehot[class_labels]), dim=1)

        self.domains_sincos = domains_sincos
        self.cinn_labels = cinn_labels
        self.domain_labels = self.domains[domain_indeces]  # rotation angle of each sample, in degrees

        return images, class_labels


### Testing _process_data() -> cINN labels

In [3]:
# Test that the (cos, sin) domain codes and the one-hot class codes are
# combined correctly into the cINN labels.
rot = RotatedMNIST([-10, 15, 30, 180, 720], train=False, seed=1)

print(rot.domains)
print(rot.domains_sincos.T.round(decimals=4))
print("")

print(rot.domain_labels[:8])
print(rot.cinn_labels[:8, :2].T.round(decimals=4))
print("")

print(rot.class_labels[:8])
print(rot.cinn_labels[:8, 2:].T.round(decimals=4))
print("")

print(rot.cinn_labels[:8].T.round(decimals=4))

# Check the invariants explicitly instead of relying on eyeballing the printout.
radians = torch.deg2rad(rot.domains.to(torch.float64))
expected_sincos = torch.stack((radians.cos(), radians.sin()), dim=1).to(torch.float32)
assert torch.allclose(rot.domains_sincos, expected_sincos, atol=1e-6)
assert torch.equal(rot.cinn_labels[:, :2], rot.domains_sincos[rot.domain_indeces])
assert torch.equal(rot.cinn_labels[:, 2:], torch.eye(10)[rot.class_labels])
assert torch.equal(rot.domain_labels, rot.domains[rot.domain_indeces])

tensor([-10,  15,  30, 180, 720])
tensor([[ 0.9848,  0.9659,  0.8660, -1.0000,  1.0000],
        [-0.1736,  0.2588,  0.5000,  0.0000, -0.0000]])

tensor([180, 180,  15,  15, 720,  30, 180, 720])
tensor([[-1.0000, -1.0000,  0.9659,  0.9659,  1.0000,  0.8660, -1.0000,  1.0000],
        [ 0.0000,  0.0000,  0.2588,  0.2588, -0.0000,  0.5000,  0.0000, -0.0000]])

tensor([3, 9, 4, 9, 9, 0, 8, 3])
tensor([[0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 1., 0., 1., 1., 0., 0., 0.]])

tensor([[-1.0000, -1.0000,  0.9659,  0.9659,  1.0000,  0.8660, -1.0000,  1.0000],
        [ 0.0000,  0.0000,  0.2588,  0.2588, -0.0000,  0.5000,  0.0000, -0.0000],
        [ 0.0000,  0

### Testing _process_data() -> domain indices

In [4]:
# Test that a seed makes the domain sampling reproducible
# (and that unseeded instances do not repeat each other).
rot_sets = [
    RotatedMNIST([-10, 15, 30, 180, 720], train=False),
    RotatedMNIST([-10, 15, 30, 180, 720], train=False),
    RotatedMNIST([-10, 15, 30, 180, 720], train=False, seed=1),
    RotatedMNIST([-10, 15, 30, 180, 720], train=False, seed=1),
]

for rot in rot_sets:
    print(rot.domain_indeces[:10])
    print("")

assert not torch.equal(rot_sets[0].domain_indeces, rot_sets[1].domain_indeces), "unseeded instances should differ"
assert torch.equal(rot_sets[2].domain_indeces, rot_sets[3].domain_indeces), "same seed must reproduce the domain sampling"
    

tensor([3, 4, 1, 3, 1, 3, 3, 4, 0, 1])

tensor([2, 1, 2, 1, 1, 2, 3, 4, 1, 1])

tensor([3, 3, 1, 1, 4, 2, 3, 4, 1, 3])

tensor([3, 3, 1, 1, 4, 2, 3, 4, 1, 3])



In [5]:
# Test that domain_count and domain_labels are produced correctly
domain_configs = [
    [0, 15, 30],
    [0],
    [-10, 15, 30, 180, 720],
]
rot_sets = [RotatedMNIST(domains, train=False) for domains in domain_configs]

for domains, rot in zip(domain_configs, rot_sets):
    print(rot.domain_count)
    print(rot.domain_labels.shape)
    print(rot.domain_labels[:10])
    print("")

    assert rot.domain_count == len(domains)
    assert rot.domain_labels.shape == (len(rot.mnist),)
    # With 10k samples every configured domain should actually be drawn.
    assert set(rot.domain_labels.tolist()) == set(domains)

3
torch.Size([10000])
tensor([30, 15,  0, 30, 30,  0,  0, 15,  0, 30])

1
torch.Size([10000])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

5
torch.Size([10000])
tensor([ 15,  15,  30, -10,  30, -10,  15,  15, 720, 180])



### Testing _process_data() -> dataloader

In [6]:
# Test that _process_data returns well-formed tensors for the test split
rot = RotatedMNIST([0, 15, 30], train=False)
n_samples = len(rot.mnist)

for tensor in (rot.images, rot.class_labels):
    print(type(tensor))
    print(tensor.shape)
    print(tensor.dtype)
    print(tensor.min(), tensor.max())
    print("")

assert rot.images.shape == (n_samples, 28, 28)
assert rot.images.dtype == torch.float32
assert rot.images.min() >= 0 and rot.images.max() <= 1
assert rot.class_labels.shape == (n_samples,)
assert rot.class_labels.dtype == torch.int64
# Shuffling must permute the samples, not change the label distribution.
assert torch.equal(torch.bincount(rot.class_labels), torch.bincount(rot.mnist.targets))

<class 'torch.Tensor'>
torch.Size([10000, 28, 28])
torch.float32
tensor(0.) tensor(1.)

<class 'torch.Tensor'>
torch.Size([10000])
torch.int64
tensor(0) tensor(9)



In [7]:
# Test that _process_data returns well-formed tensors for the training split
rot = RotatedMNIST([0, 15, 30], train=True)
n_samples = len(rot.mnist)

for tensor in (rot.images, rot.class_labels):
    print(type(tensor))
    print(tensor.shape)
    print(tensor.dtype)
    print(tensor.min(), tensor.max())
    print("")

assert rot.images.shape == (n_samples, 28, 28)
assert rot.images.dtype == torch.float32
assert rot.images.min() >= 0 and rot.images.max() <= 1
assert rot.class_labels.shape == (n_samples,)
assert rot.class_labels.dtype == torch.int64
# Shuffling must permute the samples, not change the label distribution.
assert torch.equal(torch.bincount(rot.class_labels), torch.bincount(rot.mnist.targets))

<class 'torch.Tensor'>
torch.Size([60000, 28, 28])
torch.float32
tensor(0.) tensor(1.)

<class 'torch.Tensor'>
torch.Size([60000])
torch.int64
tensor(0) tensor(9)



In [8]:
# Test that a seed makes the shuffle reproducible and leaves the source dataset untouched
rot_set = [
    RotatedMNIST([0, 15, 30], train=True),
    RotatedMNIST([0, 15, 30], train=True),
    RotatedMNIST([0, 15, 30], train=True, seed=1),
    RotatedMNIST([0, 15, 30], train=True, seed=1),
]

for rot in rot_set:
    print(rot.mnist.targets[:10])
    print(rot.class_labels[:10])
    print("")

assert not torch.equal(rot_set[0].class_labels, rot_set[1].class_labels), "unseeded instances should differ"
assert torch.equal(rot_set[2].class_labels, rot_set[3].class_labels), "same seed must reproduce the shuffle"
# The shuffle only affects the collated batch; the raw MNIST targets keep their order.
assert all(torch.equal(rot.mnist.targets, rot_set[0].mnist.targets) for rot in rot_set)

tensor([5, 0, 4, 1, 9, 2, 1, 3, 1, 4])
tensor([3, 1, 5, 2, 0, 8, 5, 2, 5, 4])

tensor([5, 0, 4, 1, 9, 2, 1, 3, 1, 4])
tensor([3, 6, 9, 3, 9, 2, 5, 2, 2, 1])

tensor([5, 0, 4, 1, 9, 2, 1, 3, 1, 4])
tensor([4, 8, 8, 6, 7, 1, 0, 7, 1, 8])

tensor([5, 0, 4, 1, 9, 2, 1, 3, 1, 4])
tensor([4, 8, 8, 6, 7, 1, 0, 7, 1, 8])



### Testing __init__() -> MNIST dataset

In [9]:
# Test that train=False loads the raw MNIST test split
rot = RotatedMNIST([0, 15, 30], train=False)

for tensor in (rot.mnist.data, rot.mnist.targets):
    print(type(tensor))
    print(tensor.shape)
    print(tensor.dtype)
    print(tensor.min(), tensor.max())
    print("")

assert rot.mnist.train is False
assert rot.mnist.data.shape == (10000, 28, 28)  # size of the official MNIST test split
assert rot.mnist.data.dtype == torch.uint8
assert rot.mnist.targets.shape == (10000,)
assert rot.mnist.targets.dtype == torch.int64
assert rot.mnist.targets.min() >= 0 and rot.mnist.targets.max() <= 9

<class 'torch.Tensor'>
torch.Size([10000, 28, 28])
torch.uint8
tensor(0, dtype=torch.uint8) tensor(255, dtype=torch.uint8)

<class 'torch.Tensor'>
torch.Size([10000])
torch.int64
tensor(0) tensor(9)



In [10]:
# Test that train=True loads the raw MNIST training split
rot = RotatedMNIST([0, 15, 30], train=True)

print(rot.mnist.classes)
print("")

for tensor in (rot.mnist.data, rot.mnist.targets):
    print(type(tensor))
    print(tensor.shape)
    print(tensor.dtype)
    print(tensor.min(), tensor.max())
    print("")

assert rot.mnist.train is True
assert len(rot.mnist.classes) == 10
assert rot.mnist.data.shape == (60000, 28, 28)  # size of the official MNIST training split
assert rot.mnist.data.dtype == torch.uint8
assert rot.mnist.targets.shape == (60000,)
assert rot.mnist.targets.dtype == torch.int64
assert rot.mnist.targets.min() >= 0 and rot.mnist.targets.max() <= 9

['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']

<class 'torch.Tensor'>
torch.Size([60000, 28, 28])
torch.uint8
tensor(0, dtype=torch.uint8) tensor(255, dtype=torch.uint8)

<class 'torch.Tensor'>
torch.Size([60000])
torch.int64
tensor(0) tensor(9)



In [11]:
# Test that the constructor stores the basic attributes as given
rot = RotatedMNIST([0, 15, 30], train=True)

print(rot.root)
print(rot.domains)
print(rot.train)
print(rot.val_set_size)
print(rot.classes)

assert rot.root == os.path.abspath(".")
assert torch.equal(rot.domains, torch.tensor([0, 15, 30]))
assert rot.train is True
assert rot.val_set_size == 0
assert torch.equal(rot.classes, torch.arange(10))

/home/birk/Documents/Programmieren/Bachelor Thesis/Own Code/Rotated-cINN/rotated_cinn/code_tests
tensor([ 0, 15, 30])
True
0
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
