In [None]:
import keras
import os.path as osp
import tensorflow as tf
import sklearn.metrics
from tqdm import tqdm, trange
import pickle
from torchvision.models import EfficientNet_B0_Weights, efficientnet_b0
import torchvision
from torchvision.transforms import transforms
import matplotlib.pyplot as plt
import glob as glob
import numpy as np
import torch
import scipy.stats as stats
from sklearn.preprocessing import StandardScaler
from sklearn import decomposition
from scipy.interpolate import RegularGridInterpolator
import warnings

from colorviz.conv_color import visualizations, utils, hooks
from colorviz.conv_color.config_objects import ImageDatasetCfg, ExperimentConfig
from colorviz.birds_dataset.data import ImageDataset
from colorviz.birds_dataset import network


%load_ext autoreload
%autoreload 2

In [None]:
# Dataset / loader settings shared by every split.
data_cfg = ImageDatasetCfg(
    batch_size=512,
    num_workers=4,
    data_dir="/scratch/ssd004/scratch/jackk/birds_data",
    device="cuda:0",
)
# Preprocessing bundled with the ImageNet-pretrained EfficientNet-B0 weights.
transform = EfficientNet_B0_Weights.IMAGENET1K_V1.transforms()

# One ImageDataset per split, all built with the same transform and config.
dsets = {}
for split in ("train", "valid", "test"):
    dsets[split] = ImageDataset(split, transform, data_cfg, ddp=False)

In [None]:
class Normalizer():
    """Per-channel affine (de)normalisation for channels-last image arrays.

    ``fwd`` computes ``(imgs - mean) / scale`` and ``rev`` inverts it. Which
    statistics are applied depends on the layout of ``imgs``:

    * 4-D or higher, e.g. ``(b, H, W, C)``: per-channel statistics.
    * 3-D whose last axis has one entry per channel, e.g. ``(H, W, C)``:
      treated as a single colour image and given per-channel statistics.
    * any other array with fewer than 4 dims, e.g. ``(b, H, W)``: treated as
      first-channel (red) data and given the first channel's statistics only.

    A red-channel batch whose last axis happens to equal the channel count is
    indistinguishable from a single colour image; add an explicit axis to
    disambiguate in that case.
    """

    def __init__(self, means, scales):
        # Stored as (1, 1, 1, C) so they broadcast against (b, H, W, C) batches.
        self.means = np.asarray(means)[None, None, None, :]
        self.scales = np.asarray(scales)[None, None, None, :]

    def _stats_for(self, imgs):
        """Return the (means, scales) pair that matches the layout of ``imgs``."""
        if imgs.ndim >= self.means.ndim:
            return self.means, self.scales
        is_single_colour_image = (imgs.ndim == self.means.ndim - 1
                                  and imgs.shape[-1] == self.means.shape[-1])
        if is_single_colour_image:
            # (1, 1, C) broadcasts against (H, W, C) without adding a batch axis.
            return self.means[0], self.scales[0]
        # eg. if imgs is (b, 224, 224), apply red channel transformation only
        return self.means[..., 0], self.scales[..., 0]

    def fwd(self, imgs):
        """Normalise: subtract the mean and divide by the scale."""
        means, scales = self._stats_for(imgs)
        return (imgs - means) / scales

    def rev(self, imgs):
        """Undo ``fwd``: multiply by the scale and add the mean back."""
        means, scales = self._stats_for(imgs)
        return imgs * scales + means

normer = Normalizer(means=[0.485, 0.456, 0.406], scales=[0.229, 0.224, 0.225])


In [None]:
# Sanity check: redo the preprocessing by hand (resize -> centre crop -> normer.fwd)
# and compare its per-channel means with what the dataset itself returns for the same index.
fwd_transform = transforms.Compose([transforms.Resize(256, antialias=True),
                                    transforms.CenterCrop(224),
                                                         ])
raw_img = dsets['train'].load_raw(1234)[0]  # presumably a channels-first tensor, given the transpose below — TODO confirm

# Move channels last and add a batch axis so normer.fwd takes its 4-D per-channel branch.
ours = normer.fwd(fwd_transform(raw_img).numpy().transpose(1,2,0)[None,...]).squeeze()
tvision = dsets['train'][1234]['image'].numpy().transpose(1,2,0)
# Averaging over the two leading axes leaves one value per entry of the last axis.
print(ours.mean(axis=0).mean(axis=0))
print(tvision.mean(axis=0).mean(axis=0))

In [None]:
# Load previously computed PCA direction grids from disk.
# The same file name is written by a later cell in this notebook.
# NOTE(review): pickle.load can execute arbitrary code; only load files you trust.
with open("big_sample_pca_dirs_reshaped.pkl", "rb") as p:
    pca_direction_grids = pickle.load(p)

In [None]:
# Patch sizes are read off axis 3 of each loaded PCA grid; every grid gets stride 1.
scales = [pca_dir.shape[3] for pca_dir in pca_direction_grids]
strides = [1 for _ in pca_direction_grids]
im_size = 224

# Build 30 shufflings of the (row, col) coordinate grid. Each one maps every
# output pixel to a source pixel, so indexing with it rearranges pixel positions.
random_permutes = []
for _ in range(30):
    perm = np.mgrid[:im_size, :im_size].transpose(1, 2, 0).reshape(-1, 2)
    np.random.shuffle(perm)  # shuffles along axis 0, i.e. whole (row, col) pairs
    random_permutes.append(perm.reshape(im_size, im_size, 2))

# Draw 8192 images; alongside each, keep its first channel with the pixels
# rearranged by one of the 30 permutations (cycled by index).
full_sample = []
sample_perm = []
num_perms = len(random_permutes)
for i in trange(8192):
    image = dsets['train'].generate_one()[0]
    full_sample.append(image)
    red_channel = image[..., 0]
    perm = random_permutes[i % num_perms]
    sample_perm.append(red_channel[perm[..., 0], perm[..., 1]])

# All images are expected to share one shape before stacking.
print({x.shape for x in full_sample})
full_sample = np.stack(full_sample, axis=0).squeeze()
sample_perm = np.stack(sample_perm, axis=0).squeeze()


# Refresh the size constants from the data that was actually stacked.
im_size = full_sample.shape[1]
im_channels = full_sample.shape[-1]

In [None]:
# Show one de-normalised training image. `full_sample` is indexed directly because
# `sample` is only assigned in a later cell, and a batch axis is added so that
# normer.rev takes its 4-D per-channel branch unambiguously.
plt.imshow(normer.rev(full_sample[5][None, ...])[0])

In [None]:
# Same image index after pixel shuffling (first channel only, positions permuted).
plt.imshow(sample_perm[5])

In [None]:
# Shuffling pixel positions should leave the mean and std of the first channel unchanged.
# `full_sample` is used directly because `sample` is only assigned in a later cell.
print(sample_perm[5].mean(), sample_perm[5].std())
print(full_sample[5][...,0].mean(), full_sample[5][...,0].std())

In [None]:
def do_pca(sample, scales, num_components=4): # dont use anymore
    """Fit a separate PCA at every strided window position of an image batch.

    Marked by the author as no longer to be used; later cells in this notebook
    call ``visualizations.find_pca_directions`` instead.

    Args:
        sample: array indexed as [N, H, W, C]. Windows slide over axes 1 and 2
            and the last axis is read as the channel count.
        scales: iterable of window side lengths; one grid is produced per scale.
        num_components: number of principal components kept per position.

    Returns:
        A list with one array per scale, shaped
        [rows, cols, num_components, scale, scale, C], where rows/cols are the
        numbers of window positions left after taking every 2nd position.
    """
    print("Got sample, beginning directions")
    pca_direction_grids = []
    strides = [2] * len(scales)
    sample_size = sample.shape[0]
    im_channels = sample.shape[-1]
    for scale, stride in zip(scales, strides):
        # sliding_window_view appends the window axes at the end:
        # [N, H', W', C, scale, scale]
        windows = np.lib.stride_tricks.sliding_window_view(sample, (scale, scale), axis=(1, 2))
        strided_windows = windows[:, ::stride, ::stride]

        # Take the grid size from the data itself, so no window position is
        # dropped and no notebook-level global is needed.
        num_rows, num_cols = strided_windows.shape[1:3]
        pca_direction_grid = np.zeros((num_rows, num_cols, num_components, scale, scale, im_channels))
        pca_fitter = decomposition.PCA(n_components=num_components, copy=False)
        scale_fitter = StandardScaler()
        for i in tqdm(range(num_rows)):
            for j in range(num_cols):
                # Move channels last BEFORE flattening so the flat feature order
                # matches the (scale, scale, C) reshape of the components below.
                pca_selection = strided_windows[:, i, j].transpose(0, 2, 3, 1)
                flattened = pca_selection.reshape(sample_size, -1)
                normalized = scale_fitter.fit_transform(flattened)
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")  # gives pointless zero-division warnings
                    pca_fitter.fit(normalized)
                pca_direction_grid[i, j] = pca_fitter.components_.reshape(
                    num_components, scale, scale, im_channels)

        pca_direction_grids.append(pca_direction_grid.copy())
    return pca_direction_grids


In [None]:
# Inline PCA over sliding windows of the first 1024 images, one fit per strided
# window position. Intermediate names are left at notebook scope because later
# cells inspect `windows`, `strided_windows` and `pca_selection`.
num_components = 4
sample = full_sample[:1024]
scales = [15]
print("Got sample, beginning directions")
sample_size, im_channels = sample.shape[0], sample.shape[-1]
strides = [2 for _ in scales]
pca_direction_grids = []
for scale, stride in zip(scales, strides):
    windows = np.lib.stride_tricks.sliding_window_view(sample, (scale, scale), axis=(1, 2))
    # [N, abs_posx, abs_posy, C, within_windowx, within_windowy]
    strided_windows = windows[:, ::stride, ::stride, :]
    num_rows, num_cols = strided_windows.shape[1], strided_windows.shape[2]

    pca_direction_grid = np.zeros((num_rows, num_cols, num_components, scale, scale, im_channels))
    pca_fitter = decomposition.PCA(n_components=num_components, copy=False)
    scale_fitter = StandardScaler()
    for i in tqdm(range(num_rows)):
        for j in range(num_cols):
            # [N, within_windowx, within_windowy, C] so the flat order matches the reshape below
            pca_selection = strided_windows[:, i, j].transpose(0, 2, 3, 1)
            flattened = pca_selection.reshape(sample_size, -1)
            normalized = scale_fitter.fit_transform(flattened)
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")  # gives pointless zero-division warnings
                pca_fitter.fit(normalized)
            # One (scale, scale, C) patch per fitted component.
            pca_direction_grid[i, j] = pca_fitter.components_.reshape(
                num_components, scale, scale, im_channels)

    pca_direction_grids.append(pca_direction_grid.copy())


In [None]:
# Show the first 35 raw windows from the last fitted grid position as a 5x7 grid.
# Indexing pca_selection directly works for any patch size, and the added batch
# axis makes normer.rev take its 4-D per-channel branch.
for i in range(5):
    for j in range(7):
        plt.subplot(5, 7, i*7+j+1)
        plt.imshow(normer.rev(pca_selection[i*7 + j][None, ...])[0])

In [None]:
# NOTE(review): `split_channels_medium` is assigned in a later cell, so running the
# notebook top to bottom on a fresh kernel raises NameError here.
plt.imshow(split_channels_medium[0][0,0,0,:,:,0])

In [None]:
# Quick run on only the first 5 images.
# Only channel 0 is passed: this is just for comparison, so the other channels are not needed.
# Positional arguments [7] and 2 are presumably the patch scales and the stride — TODO confirm against find_pca_directions.
split_channels_medium2 = visualizations.find_pca_directions(None, [7], 2, 
                                                           sample=full_sample[:5,...,0], num_components=5, split_channels=True)

In [None]:
# Same call on the first 4096 images.
# Only channel 0 is passed: this is just for comparison, so the other channels are not needed.
split_channels_medium = visualizations.find_pca_directions(None, [7], 2, 
                                                           sample=full_sample[:4096,...,0], num_components=5, split_channels=True)

In [None]:
# Five PCA directions from the first 4096 full-colour images (all channels passed together).
medium_patch_fixed = visualizations.find_pca_directions(None, [7], 2, sample=full_sample[:4096], num_components=5)

In [None]:
# Same settings on the pixel-shuffled first-channel images, as a no-spatial-structure baseline.
permuted_medium_patch_fixed = visualizations.find_pca_directions(None, [7], 2, sample=sample_perm[:4096], num_components=5)

In [None]:
# Plot the directions fitted on the pixel-shuffled images.
visualizations.visualize_pca_directions(permuted_medium_patch_fixed, "Permuted medium patches, fixed (4096)", [7], lines=False)

In [None]:
# These directions were fitted on channel 0 of the un-permuted images with
# split_channels=True, so the title must not carry the "Permuted" label.
visualizations.visualize_pca_directions(split_channels_medium, "Split-channel medium patches (4096)", [7], lines=False)

In [None]:
# First grid position, first component, last-axis index 0 of the first (only) scale.
# Presumably the layout is [grid_x, grid_y, component, patch_x, patch_y, channel] — TODO confirm.
plt.imshow(medium_patch_fixed[0][0, 0, 0, ..., 0])

In [None]:
# NOTE(review): index 2 needs at least three grids. The inline PCA cell above rebinds
# `pca_direction_grids` with a single scale ([15]), so this only works on the pickled list.
pca_direction_grids[2].shape

In [None]:
# Re-interpret each stored patch: the flat values were laid out channel-first
# (C, x, y) but stored under an (x, y, C) shape, so reshape to channel-first and
# then move the channel axis to the end.
# NOTE(review): only apply this to grids with that mismatch. Applying it to grids
# that are already correct scrambles them, and the output overwrites the file
# loaded near the top of the notebook.
fixed_pca_direction_grids = []
for pca_grid in pca_direction_grids:
    num_x, num_y, num_comp, scalex, scaley, num_channels = pca_grid.shape
    channels_first = pca_grid.reshape(num_x, num_y, num_comp, num_channels, scalex, scaley)
    fixed_pca_direction_grids.append(np.moveaxis(channels_first, 3, -1))
with open("big_sample_pca_dirs_reshaped.pkl", "wb") as p:
    pickle.dump(fixed_pca_direction_grids, p)

In [None]:
# Reload the reshaped grids so the conversion cell does not have to be rerun.
# NOTE(review): pickle.load can execute arbitrary code; only load files you trust.
with open("big_sample_pca_dirs_reshaped.pkl", "rb") as p:
    fixed_pca_direction_grids = pickle.load(p)

In [None]:
plt.imshow(medium_patch_fixed[0][0,0,0, :, :, 0])
plt.colorbar()
# => zero-centering the colorbars causes all contrast to be lost, which is why the component-0 plots looked uniform

In [None]:
# Component 0, channel 0 at every 20th grid position, with zero-centering of the colour map turned off.
utils.plt_grid_figure(medium_patch_fixed[0][::20, ::20, 0, :, :, 0], first_cmap="bwr", cmap="bwr", colorbar=True, zero_centered_cmap=False)

In [None]:
utils.plt_grid_figure(fixed_pca_direction_grids[2][::16, ::16, 0, :, :, 0], first_cmap="bwr", cmap="bwr", colorbar=False, zero_centered_cmap=False)
# After reshaping and transposing the old direction grids and removing the zero-centering,
# the result matches the correct computation, so those computations do not have to be redone.

In [None]:
# medium_patch_fixed[0].shape is (109, 109, 5, 7, 7, 3)
utils.plt_grid_figure(medium_patch_fixed[0][::8, ::8, 0, :, :, 0], first_cmap="bwr", cmap="bwr", colorbar=False, zero_centered_cmap=False)
# An example of what the 7x7 patches are supposed to look like for component 0.

In [None]:
utils.plt_grid_figure(medium_patch_fixed[0][::8, ::8, 1, :, :, 0], first_cmap="bwr", cmap="bwr", colorbar=False, zero_centered_cmap=False)
# What the 7x7 patches are supposed to look like for component 1 (every 8th grid position).

In [None]:
utils.plt_grid_figure(medium_patch_fixed[0][::8, ::8, 2, :, :, 0], first_cmap="bwr", cmap="bwr", colorbar=False, zero_centered_cmap=False)
# What the 7x7 patches are supposed to look like for component 2 (every 8th grid position).

In [None]:
utils.plt_grid_figure(medium_patch_fixed[0][::8, ::8, 3, :, :, 0], first_cmap="bwr", cmap="bwr", colorbar=False, zero_centered_cmap=False)
# What the 7x7 patches are supposed to look like for component 3 (every 8th grid position).

In [None]:
utils.plt_grid_figure(split_channels_medium[0][::8, ::8, 4, :, :, 0], first_cmap="bwr", cmap="bwr", colorbar=False, zero_centered_cmap=False)
# For comparison: component 4 of the split-channel fit.
# Qualitatively they are quite different => worthwhile to check whether doing the channels individually is better.

In [None]:
utils.plt_grid_figure(medium_patch_fixed[0][::8, ::8, 4, :, :, 0], first_cmap="bwr", cmap="bwr", colorbar=False, zero_centered_cmap=False)
# What the 7x7 patches are supposed to look like for component 4 (every 8th grid position).

In [None]:
# Thinned view: keep every 10th grid position on both grid axes, re-wrapped in a one-element list.
visualizations.visualize_pca_directions([medium_patch_fixed[0][::10, ::10]], "Medium patches fixed (4096)", [7], lines=False)

In [None]:
# Full-resolution view of the same directions (no thinning of the grid).
visualizations.visualize_pca_directions(medium_patch_fixed, "Medium patches fixed (4096)", [7], lines=False)

In [None]:
# NOTE(review): the title and the [3] scale list look stale. In notebook order,
# `pca_direction_grids` was last assigned from un-permuted images with scale 15 — confirm which grids are meant.
visualizations.visualize_pca_directions(pca_direction_grids, "Randomly permuted PCA (8192)", [3], lines=False)

In [None]:
# Left over from the inline PCA cell: [N, abs_posx, abs_posy, C, within_windowx, within_windowy] after striding.
strided_windows.shape

In [None]:
# Un-strided sliding-window view from the inline PCA cell, for comparison with the strided shape.
windows.shape

In [None]:
# sample_perm has no channel axis; add one because do_pca reads sample.shape[-1] as the channel count.
random_pca_dirs = do_pca(sample_perm[...,None], [3])

In [None]:
# 25x25 windows on only the first 100 permuted images.
random_big_pca_dirs = do_pca(sample_perm[:100, ..., None], [25])

In [None]:
# 222x222 windows on all permuted images.
random_full_pca_dirs = do_pca(sample_perm[..., None], [222])

In [None]:
# 222x222 windows on `sample`, which the inline PCA cell sets to full_sample[:1024].
big_patches = do_pca(sample, [222])

In [None]:
visualizations.visualize_pca_directions(random_pca_dirs, "Randomly permuted PCA (8192)", [3], lines=False)
# An example of what the PCA directions look like for permuted images (i.e. each pixel has the same
# marginal distribution, but any conditional structure is removed). Clearly it's uniform noise, which is good.

In [None]:
# Shape of the single grid returned for scale 222.
big_patches[0].shape

In [None]:
# Fix the third and last axes at 0, leaving [rows, cols, scale, scale].
# The two concatenations then tile those patches into one 2-D array of shape (rows*scale, cols*scale).
np.concatenate(np.concatenate(big_patches[0][:, :, 0, :, :, 0], 1), 1).shape

In [None]:
# Toy 6-D array (2 * 3 * 2 * 5 * 5 * 2 = 600 entries) for checking the tiling trick by eye.
x = np.arange(600).reshape(2, 3, 2, 5, 5, 2)
x[0,1, 0, :, :, 1]

In [None]:
# Tiles the 2x3 grid of 5x5 patches into one (10, 15) array; compare with the single patch shown above.
np.concatenate(np.concatenate(x[:, :, 0, :, :, 1], 1), 1)

In [None]:
# NOTE(review): the title says 8192, but `big_patches` is computed from `sample`,
# which this notebook assigns as full_sample[:1024] — confirm the intended count.
visualizations.visualize_pca_directions(big_patches, "Big patches, regular (8192)", [222], lines=True)

In [None]:
# Directions from 25x25 windows on the first 100 permuted images.
visualizations.visualize_pca_directions(random_big_pca_dirs, "Bigger patches, permuted (100)", [25], lines=False)

In [None]:
# Directions from 222x222 windows on all permuted images.
visualizations.visualize_pca_directions(random_full_pca_dirs, "Largest possible patches, permuted (8192)", [222], lines=True)

In [None]:
# NOTE(review): every other call passes a scales list as the third positional argument;
# here it is omitted, so this presumably relies on a default — confirm against visualize_pca_directions.
visualizations.visualize_pca_directions(pca_direction_grids, "", lines=False)