In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
from tqdm.auto import tqdm
import os
import numpy as np
import ipdb
import time
from utils import ZCA

In [None]:
# --- Patch extraction -------------------------------------------------------
crop_side_length = 11 # side length, in pixels, of each square patch
stride = 1 # step, in pixels, between the origins of neighbouring patches
multiscale = False # multilayer pipeline to preprocess the images
full_image = False # full images instead of patches are used

# --- Data handling ----------------------------------------------------------
uniform_noise = True # adds uniform noise in the range [0,1] on each pixel
save_as_double = True # choose to save as float or double
batch_size = 1000 # images per DataLoader batch, i.e. per saved .npy file
shuffle = True # let the DataLoader shuffle the dataset before batching

# --- Optional post-processing -----------------------------------------------
whitening = False # ZCA-whiten the saved batches after export
plot_covariance = True # draw the pixel correlation heatmap at the end

# --- Reproducibility --------------------------------------------------------
# The shuffle order and the uniform noise are both random; seeding both
# generators makes a "Restart & Run All" produce the same files every time.
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)

In [None]:
# Destination folders for the exported batches, one per CIFAR-10 split.
train_folder, test_folder = 'cifar_train/', 'cifar_test/'

# ToTensor is the only preprocessing step applied when a sample is read.
transform = transforms.Compose([transforms.ToTensor()])

# Both splits share the same root and transform and are downloaded on demand;
# only the `train` flag differs (train split first, then test split).
trainset, testset = (
    torchvision.datasets.CIFAR10(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
classes = ('plane', 'car')

In [None]:
class PatchWriter(torch.nn.Module):
    """Cut every image of a batch into square patches and tag each patch with a quadrant id.

    Args:
        crop_side_length: side length, in pixels, of each square patch. It is
            also the kernel size of the average pooling used by the multiscale path.
        stride: step, in pixels, between the origins of neighbouring patches.
        multiscale: when True, patches are additionally extracted from the
            image after one and after two rounds of average pooling.
    """

    def __init__(self, crop_side_length, stride, multiscale):
        super(PatchWriter, self).__init__()
        self.pool = torch.nn.AvgPool2d(crop_side_length, stride=crop_side_length//2)
        self.crop_side_length = crop_side_length
        self.stride = stride
        self.multiscale = multiscale

    # define pipeline here
    def forward(self, images):
        """Extract patches and quadrant ids for every image in ``images``.

        Args:
            images: iterable batch of images; each element is handed to
                ``extract_patches`` (and to ``self.pool`` when multiscale).

        Returns:
            ``(patches, cells)`` as numpy arrays. ``patches`` has shape
            ``(batch, n_patches_total, channels, crop_side_length**2)`` and
            ``cells`` has shape ``(batch, n_scales, n_patches_per_scale)``.
        """
        # NOTE(review): in multiscale mode every scale must yield the same number
        # of patches for np.vstack on the cell arrays to succeed - confirm this
        # path is exercised with the intended image sizes.
        all_patches = []
        all_cells = []
        for image in images:
            patches, cells = self.extract_patches(image) # original scale
            patch_list = [patches]
            cell_list = [cells]
            if self.multiscale:
                for _ in range(2): # pool 1, then pool 2
                    image = self.pool(image)
                    patches, cells = self.extract_patches(image)
                    patch_list.append(patches)
                    cell_list.append(cells)
            all_patches.append(np.vstack(patch_list))
            all_cells.append(np.vstack(cell_list))
        # Stacking along a new leading axis yields one entry per input image.
        return np.stack(all_patches), np.stack(all_cells)

    # given an image, crop so many times and return array of patches
    def extract_patches(self, image):
        """Slide a square window over one zero-padded image.

        Args:
            image: a single image whose first axis (after squeezing) is the
                channel axis and whose last axis is the width. The same patch
                count is used for rows and columns, so the image is presumably
                square - TODO confirm.

        Returns:
            ``(patches, cells)``: ``patches`` is a float64 array of shape
            ``(n_patches**2, channels, crop_side_length**2)`` in row-major
            window order; ``cells`` holds one quadrant id per patch
            (0 top-left, 1 top-right, 2 bottom-left, 3 bottom-right).
        """
        # Use the stride this instance was built with, not a notebook global.
        stride = self.stride
        side = self.crop_side_length
        n_patches = image.shape[-1]//stride+1
        image = np.squeeze(image)
        patches = np.zeros((n_patches**2, image.shape[0], side**2), dtype=np.float64)
        cells = np.zeros((n_patches**2))
        # NOTE(review): the round trip through a PIL image presumably quantises
        # pixel values to 8 bits - confirm this is acceptable when noise has been
        # added upstream.
        image = transforms.functional.to_pil_image(image)
        image = transforms.functional.pad(image, side//2)
        # NOTE(review): the last row/column of windows can reach past the padded
        # image (e.g. odd crop size with a stride that divides the width);
        # confirm how transforms.functional.crop fills that region.
        half = n_patches/2
        for i in range(n_patches):
            for j in range(n_patches):
                image_cropped = transforms.functional.crop(image, i*stride, j*stride, side, side)
                image_tensors = transforms.functional.to_tensor(image_cropped)
                patches[i*n_patches+j][:] = image_tensors.reshape(-1, side**2)
                # ">=" puts the middle row/column (only present when n_patches
                # is even) into the lower/right half instead of silently
                # leaving it in cell 0.
                lower = 2 if i >= half else 0
                right = 1 if j >= half else 0
                cells[i*n_patches+j] = lower + right
        return patches,cells

In [None]:
def write_patches_to_disk(dataset, folder, crop_side_length, stride, multiscale, uniform_noise, full_image, save_as_double):
    """Export a dataset to ``folder`` as numbered ``.npy`` batch files plus a header.

    Existing files (not sub-directories) inside the output directories are
    deleted first. For batch ``i`` the function writes ``labels/i.npy`` (int32)
    and ``data/i.npy``; when ``full_image`` is False it also writes
    ``cells/i.npy``. Finally ``header.txt`` receives four integers:
    total sample count, batch size, ``C`` (always 1) and ``D``.

    Note: the DataLoader batch size and shuffle flag are read from the
    notebook-level ``batch_size`` and ``shuffle`` variables, not from arguments.

    Args:
        dataset: dataset yielding ``(image, label)`` pairs.
        folder: output directory (created if missing).
        crop_side_length, stride, multiscale: forwarded to ``PatchWriter``.
        uniform_noise: add U(0, 1) noise to every pixel before saving.
        full_image: save flattened whole images instead of patches.
        save_as_double: save data as float64 if True, else float32.

    Raises:
        ValueError: if the dataset yields no batch.
    """
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=4)

    data_dir = os.path.join(folder, 'data')
    labels_dir = os.path.join(folder, 'labels')
    cells_dir = os.path.join(folder, 'cells')
    output_dirs = [folder, data_dir, labels_dir]
    if not full_image:
        output_dirs.append(cells_dir)
    # Start from a clean slate: drop stale files but keep sub-directories.
    for directory in output_dirs:
        os.makedirs(directory, exist_ok=True)
        for fname in os.listdir(directory):
            path = os.path.join(directory, fname)
            if not os.path.isdir(path):
                os.remove(path)

    data_type = np.float64 if save_as_double else np.float32
    model = PatchWriter(crop_side_length, stride, multiscale)

    n_images = 0             # images actually written (the last batch may be short)
    samples_per_image = 1    # patches per image; stays 1 for full images
    D = crop_side_length**2  # length of one saved sample vector
    for index, (images, labels) in enumerate(tqdm(dataloader)):
        file_name = str(index) + '.npy'
        n_images += images.shape[0]
        np.save(os.path.join(labels_dir, file_name), labels.numpy().astype(np.int32))
        # NOTE(review): the notebook's transform is ToTensor, which presumably
        # yields pixels in [0, 1]; adding U(0, 1) then doubles the value range,
        # and PatchWriter round-trips through a PIL image - confirm intended.
        if uniform_noise: images = images + torch.Tensor(np.random.uniform(0,1, images.shape))
        if full_image:
            images_np = images.numpy().reshape(images.shape[0], 1, 1, -1)
            D = images_np.shape[-1]  # channels * height * width
            np.save(os.path.join(data_dir, file_name), images_np.astype(data_type))
        else:
            patches, cells = model(images)
            samples_per_image = patches.shape[1]
            np.save(os.path.join(data_dir, file_name), patches.astype(data_type))
            np.save(os.path.join(cells_dir, file_name), cells.astype(data_type))

    if n_images == 0:
        raise ValueError('dataset produced no batches; nothing was written to ' + folder)

    C = 1
    Total = n_images*samples_per_image
    np.savetxt(os.path.join(folder, 'header.txt'), np.array([Total, batch_size, C, D]), delimiter=',', fmt='%d')

### save the patches

In [None]:
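# Training-split export is currently disabled. Uncomment the call below to (re)generate
# `train_folder`; the whitening cell further down reads its batches from there.
# write_patches_to_disk(trainset, train_folder, crop_side_length, stride, multiscale, uniform_noise, full_image, save_as_double)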

In [None]:
# Export the test split using the settings from the configuration cell.
write_patches_to_disk(
    dataset=testset,
    folder=test_folder,
    crop_side_length=crop_side_length,
    stride=stride,
    multiscale=multiscale,
    uniform_noise=uniform_noise,
    full_image=full_image,
    save_as_double=save_as_double,
)

In [None]:
if whitening:
    # Fit ZCA on a subset of the training batches, then whiten every saved batch
    # of both splits *in place*. Each file is rewritten under its own name, so
    # data/i.npy keeps matching labels/i.npy and cells/i.npy regardless of the
    # order in which os.listdir returns the files.
    fit_every = 3  # use every 3rd training file for the fit to limit memory

    batch_files = {}
    for directory in (train_folder+"data/", test_folder+"data/"):
        names = [f for f in os.listdir(directory) if f.endswith(".npy")]
        # (length, name) ordering sorts "2.npy" before "10.npy".
        batch_files[directory] = sorted(names, key=lambda f: (len(f), f))

    train_dir = train_folder+"data/"
    fit_files = batch_files[train_dir][::fit_every]
    if not fit_files:
        raise FileNotFoundError("no .npy batches found in " + train_dir + "; export the training split first")
    print(len(fit_files))

    fit_data = np.concatenate([np.load(train_dir+f) for f in fit_files])
    # One row per leading-axis entry; everything else is flattened into features.
    fit_data = fit_data.reshape((fit_data.shape[0],-1))

    zca = ZCA(bias=1e-8)
    zca.fit(fit_data)
    del fit_data  # release the fit sample before streaming through the files

    for directory, names in batch_files.items():
        for name in tqdm(names, desc="whitening " + directory):
            batch = np.load(directory+name)
            flat = batch.reshape((batch.shape[0],-1))
            np.save(directory+name, zca.transform(flat).reshape(batch.shape))
        


### covariances

In [None]:
if plot_covariance:
    # Plotting libraries are imported here so they are only needed when the
    # plot is actually requested.
    import seaborn as sn
    import matplotlib.pyplot as plt

    data_dir = test_folder+"data/"
    # Use the patches of the first image in every saved batch file.
    # NOTE(review): indexing channels 0..2 below assumes the patch layout
    # written when full_image is False - confirm before using with full images.
    sample_patches = np.concatenate([np.load(data_dir+f)[0] for f in sorted(os.listdir(data_dir)) if f.endswith(".npy")])
    # Weighted sum of the three colour channels, computed for all patches at once.
    dataset_patches = 0.2989 * sample_patches[:, 0, :] + 0.5870 * sample_patches[:, 1, :] + 0.1140 * sample_patches[:, 2, :]
    # Rows of the transposed array are pixel positions, so this is the
    # pixel-by-pixel correlation-coefficient matrix (a normalised covariance).
    covariance = np.corrcoef(dataset_patches.T)

    fig, ax = plt.subplots()
    sn.heatmap(covariance, ax=ax)
    ax.set(title="Pixel correlation within patches (channel-weighted sum)", xlabel="pixel index in patch", ylabel="pixel index in patch")
    plt.show()