In [35]:
import random
from pathlib import Path
from typing import List

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms


In [36]:
def get_mean_and_std(dataloader):
    channels_sum, channels_squared_sum, num_batches = 0, 0, 0
    for data, _ in dataloader:
        # Mean over batch, height and width, but not over the channels
        channels_sum += torch.mean(data, dim=[0, 2, 3])
        channels_squared_sum += torch.mean(data**2, dim=[0, 2, 3])
        num_batches += 1

    mean = channels_sum / num_batches

    # std = sqrt(E[X^2] - (E[X])^2)
    std = (channels_squared_sum / num_batches - mean ** 2) ** 0.5

    return mean, std


def min_max_scale_normalization(tensor: torch.Tensor) -> torch.Tensor:
    '''
    Only support RGB image (3 channels)
    '''
    dtype = tensor.dtype

    if not isinstance(tensor, torch.Tensor):
        raise TypeError('Input tensor should be a torch tensor. Got {}.'.format(type(tensor)))

    if tensor.ndim < 3:
        raise ValueError('Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = '
                         '{}.'.format(tensor.size()))

    channel_min = [0, 0, 0]
    channel_max = [1, 1, 1]
    new_channel_min = [-1, -1, -1]
    new_channel_max = [1, 1, 1]

    src_min = torch.as_tensor(channel_min, dtype=dtype, device=tensor.device)
    src_max = torch.as_tensor(channel_max, dtype=dtype, device=tensor.device)
    new_min = torch.as_tensor(new_channel_min, dtype=dtype, device=tensor.device)
    new_max = torch.as_tensor(new_channel_max, dtype=dtype, device=tensor.device)

    # for batch images
    if tensor.ndim == 4:
        src_min = src_min.view(1, 3, 1, 1).expand(tensor.shape[0], -1, -1, -1)
        src_max = src_max.view(1, 3, 1, 1).expand(tensor.shape[0], -1, -1, -1)
        new_min = new_min.view(1, 3, 1, 1).expand(tensor.shape[0], -1, -1, -1)
        new_max = new_max.view(1, 3, 1, 1).expand(tensor.shape[0], -1, -1, -1)
    else:
        src_min = src_min.view(3, 1, 1)
        src_max = src_max.view(3, 1, 1)
        new_min = new_min.view(3, 1, 1)
        new_max = new_max.view(3, 1, 1)

    return (tensor - src_min)/(src_max - src_min)*(new_max - new_min) + new_min

class MinMaxScaler(object):
    def __init__(
        self,
        channel_min: List[float],
        channel_max: List[float],
        new_channel_min: List[float],
        new_channel_max: List[float],
    ):
        assert isinstance(channel_min, list)
        assert isinstance(channel_max, list)
        assert isinstance(new_channel_min, list)
        assert isinstance(new_channel_max, list)
        self.channel_min = channel_min
        self.channel_max = channel_max
        self.new_channel_min = new_channel_min
        self.new_channel_max = new_channel_max

    def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
        '''
        Only support RGB image (3 channels)
        '''
        assert isinstance(tensor, torch.Tensor)
        dtype = tensor.dtype

        if not isinstance(tensor, torch.Tensor):
            raise TypeError('Input tensor should be a torch tensor. Got {}.'.format(type(tensor)))

        if tensor.ndim < 3:
            raise ValueError('Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = '
                             '{}.'.format(tensor.size()))

        src_min = torch.as_tensor(self.channel_min, dtype=dtype, device=tensor.device)
        src_max = torch.as_tensor(self.channel_max, dtype=dtype, device=tensor.device)
        new_min = torch.as_tensor(self.new_channel_min, dtype=dtype, device=tensor.device)
        new_max = torch.as_tensor(self.new_channel_max, dtype=dtype, device=tensor.device)

        # for batch images
        if tensor.ndim == 4:
            src_min = src_min.view(1, 3, 1, 1).expand(tensor.shape[0], -1, -1, -1)
            src_max = src_max.view(1, 3, 1, 1).expand(tensor.shape[0], -1, -1, -1)
            new_min = new_min.view(1, 3, 1, 1).expand(tensor.shape[0], -1, -1, -1)
            new_max = new_max.view(1, 3, 1, 1).expand(tensor.shape[0], -1, -1, -1)
        else:
            src_min = src_min.view(3, 1, 1)
            src_max = src_max.view(3, 1, 1)
            new_min = new_min.view(3, 1, 1)
            new_max = new_max.view(3, 1, 1)

        return (tensor - src_min)/(src_max - src_min)*(new_max - new_min) + new_min

    def __repr__(self):
        return self.__class__.__name__ + f'(channel_min={self.channel_min}, channel_max={self.channel_max}, new_channel_min={self.new_channel_min}, new_channel_max={self.new_channel_max})'


In [37]:
transform_test = transforms.Compose([
    transforms.ToTensor(),
    # MinMaxScaler(
        # channel_min=[0, 0, 0],
        # channel_max=[1, 1, 1],
        # new_channel_min=[-1, -1, -1],
        # new_channel_max=[1, 1, 1],
    # ),
    transforms.Grayscale(num_output_channels=3),
    transforms.Lambda(lambda x: x.float()),
])


In [38]:
train_data = datasets.CIFAR10(
    "../data",
    train=True,
    download=False,
    transform=transform_test
)  # ! Change
test_data = datasets.CIFAR10(
    "../data",
    train=False,
    download=False,
    transform=transform_test
)  # ! Change

train_loader = DataLoader(train_data,
                          batch_size=4,
                          shuffle=False
                          )
test_loader = DataLoader(test_data,
                         batch_size=100,
                         shuffle=False)


In [39]:
_x, _y = None, None
batch: 'list[torch.FloatTensor]'
for batch in train_loader:
    # _x, _y = batch[0].to('cuda'), batch[1].to('cuda')
    _x = batch[0]
    print(_x.shape)

torch.Size([4, 3, 32, 32])
torch.Size([4, 3, 32, 32])
torch.Size([4, 3, 32, 32])
torch.Size([4, 3, 32, 32])
torch.Size([4, 3, 32, 32])
torch.Size([4, 3, 32, 32])
torch.Size([4, 3, 32, 32])
torch.Size([4, 3, 32, 32])
torch.Size([4, 3, 32, 32])
torch.Size([4, 3, 32, 32])
torch.Size([4, 3, 32, 32])
torch.Size([4, 3, 32, 32])
torch.Size([4, 3, 32, 32])
torch.Size([4, 3, 32, 32])
torch.Size([4, 3, 32, 32])
torch.Size([4, 3, 32, 32])
torch.Size([4, 3, 32, 32])
torch.Size([4, 3, 32, 32])
torch.Size([4, 3, 32, 32])
torch.Size([4, 3, 32, 32])
torch.Size([4, 3, 32, 32])
torch.Size([4, 3, 32, 32])
torch.Size([4, 3, 32, 32])
torch.Size([4, 3, 32, 32])
torch.Size([4, 3, 32, 32])
torch.Size([4, 3, 32, 32])
torch.Size([4, 3, 32, 32])
torch.Size([4, 3, 32, 32])
torch.Size([4, 3, 32, 32])
torch.Size([4, 3, 32, 32])
torch.Size([4, 3, 32, 32])
torch.Size([4, 3, 32, 32])
torch.Size([4, 3, 32, 32])
torch.Size([4, 3, 32, 32])
torch.Size([4, 3, 32, 32])
torch.Size([4, 3, 32, 32])
torch.Size([4, 3, 32, 32])
t

KeyboardInterrupt: 

In [None]:
# _data = np.array(
#     [
#         [
#             [1, 200, 30],
#             [1, 200, 30],
#             [1, 200, 30],
#         ],
#         [
#             [5, 500, 50],
#             [5, 500, 50],
#             [5, 500, 50],
#         ],
#         [
#             [1, 200, 10],
#             [1, 200, 10],
#             [1, 200, 10],
#         ],
#     ],
# )
# _data = torch.from_numpy(_data.transpose((2, 0, 1))).contiguous()
# print(_data.shape)
# print(_data)

# # src_min = torch.as_tensor([2, 200, 20]).unsqueeze(-1).unsqueeze(-1)
# src_min = torch.as_tensor([2, 200, 20])
# src_min = src_min.view(1, 3, 1, 1)
# print(src_min.shape)
# print(_data - src_min)

torch.Size([3, 3, 3])
tensor([[[  1,   1,   1],
         [  5,   5,   5],
         [  1,   1,   1]],

        [[200, 200, 200],
         [500, 500, 500],
         [200, 200, 200]],

        [[ 30,  30,  30],
         [ 50,  50,  50],
         [ 10,  10,  10]]])
torch.Size([1, 3, 1, 1])
tensor([[[[ -1,  -1,  -1],
          [  3,   3,   3],
          [ -1,  -1,  -1]],

         [[  0,   0,   0],
          [300, 300, 300],
          [  0,   0,   0]],

         [[ 10,  10,  10],
          [ 30,  30,  30],
          [-10, -10, -10]]]])
