In [None]:
import glob as glob
import os.path as osp
import pickle

import keras
import tensorflow as tf
import sklearn.metrics
from tqdm import tqdm
import torch
from torchvision.models import EfficientNet_B0_Weights, efficientnet_b0
import matplotlib.pyplot as plt
import numpy as np
import matplotlib
import scipy
from scipy.interpolate import RegularGridInterpolator
import cv2

from colorviz.conv_color import visualizations, utils, hooks
from colorviz.conv_color.config_objects import ImageDatasetCfg, ExperimentConfig
from colorviz.birds_dataset.data import ImageDataset
from colorviz.birds_dataset import network


%load_ext autoreload
%autoreload 2

In [None]:
# Root directory of the birds data; the Keras checkpoint and the "test"/"valid" image folders
# are resolved relative to it.
# NOTE(review): absolute, machine-specific path; the same literal is repeated in later cells.
path = "/scratch/ssd004/scratch/jackk/birds_data"
# Load the Keras checkpoint. `custom_objects` maps the name 'F1_score' to a plain string rather
# than a callable — presumably only so deserialisation of a custom metric succeeds; TODO confirm.
model = keras.models.load_model(osp.join(path, 'EfficientNetB0-525-(224 X 224)- 98.97.h5'), custom_objects={'F1_score':'F1_score'}) 

In [None]:
# Build the Keras-side test and validation datasets (in that order) from the "test" and
# "valid" sub-directories of `path`, resizing every image to 224x224.
te_dset, va_dset = (
    tf.keras.utils.image_dataset_from_directory(osp.join(path, split_dir), image_size=(224, 224))
    for split_dir in ("test", "valid")
)

In [None]:
# Noise-robustness check for the Keras model.
# NOTE(review): the original header comment read "rotate, crop and scale back up", but the only
# perturbation implemented below is additive uniform noise.
def get_f1(dset, scale=20):
    """Micro-averaged F1 of the notebook-level Keras `model` on `dset` under additive noise.

    Every batch of images gets independent noise drawn from U(-scale, scale) added
    element-wise before it is passed to `model`; the arg-max over the last axis of the
    model output is taken as the predicted class.

    Args:
        dset: iterable yielding `(images, labels)` batches (e.g. the datasets built with
            `image_dataset_from_directory` in this notebook).
        scale: half-width of the uniform noise interval.

    Returns:
        The value of `sklearn.metrics.f1_score(..., average="micro")` over all batches.

    Notes:
        * Reads the global `model`; it is not passed in.
        * The noise is not seeded, so repeated calls give slightly different scores.
    """
    # Collect per-batch results and concatenate once at the end, instead of re-copying the
    # running tensors on every batch. The empty seeds keep the dtypes and the empty-dataset
    # behaviour of the previous implementation.
    label_batches = [np.zeros([0], dtype=np.int32)]
    pred_batches = [np.zeros([0], dtype=np.int64)]
    for images, labels in tqdm(dset):
        noise = tf.random.uniform(images.shape, minval=-scale, maxval=scale)
        preds = model(images + noise)
        label_batches.append(labels.numpy())
        pred_batches.append(tf.math.argmax(preds, axis=-1).numpy())
    y_true = np.concatenate(label_batches)
    y_pred = np.concatenate(pred_batches)
    # f1_score's signature is (y_true, y_pred); the arguments used to be passed the other way
    # round. The micro average is symmetric, so the number is unchanged, but the call is now
    # correct if `average` is ever switched.
    return sklearn.metrics.f1_score(y_true, y_pred, average="micro")

In [None]:
# Micro-F1 on the test split with uniform noise of half-width 50 added to the images.
print(get_f1(te_dset, scale=50))
# Validation-split run with the default noise scale, currently disabled:
# print(get_f1(va_dset))

In [None]:
# PyTorch-side data pipeline: one ImageDataset per split, all sharing the preprocessing
# transform that ships with the ImageNet EfficientNet-B0 weights.
# NOTE(review): `data_dir` repeats the absolute path already stored in `path` above.
data_cfg = ImageDatasetCfg(
    batch_size=512,
    num_workers=4,
    data_dir="/scratch/ssd004/scratch/jackk/birds_data",
    device="cuda:0",
)
transform = EfficientNet_B0_Weights.IMAGENET1K_V1.transforms()

dsets = dict(
    (split, ImageDataset(split, transform, data_cfg, ddp=False))
    for split in ("train", "valid", "test")
)

In [None]:
def pca_direction_grids(model, dataset, target_class, img, scales, pca_direction_grids,
                        strides=None, gaussian=False, component=0, batch_size=32):
    """Multi-scale directional-derivative saliency map for a single image.

    For each `(scale, stride)` pair a `scale x scale` window is slid over the image. Each
    window is perturbed by `alpha * direction`, where `direction` is read from
    `pca_direction_grids[s][row, col, component]`, and the derivative of the summed
    `target_class` output with respect to `alpha` (evaluated at `alpha = 0`) is stored on a
    grid. The grid is then interpolated to every pixel (treating each value as located at
    the window start plus `scale // 2`), and for every pixel the response with the largest
    magnitude over all scales is kept (ties go to the later scale).

    Args:
        model: torch module; put into eval mode, called on NCHW batches, output indexed
            as `output[:, target_class]`.
        dataset: object providing `cfg.device` and `implicit_normalization(tensor)`.
            NOTE(review): the same `implicit_normalization` is applied to the image and to
            the direction tensors; confirm this is intended if it subtracts a mean.
        target_class: column of the model output to differentiate.
        img: array indexed as (row, col, channel); `img.shape[0]` is used as the size of
            both spatial axes, so the image is assumed square.
        scales: window side lengths, one per entry of `pca_direction_grids`.
        pca_direction_grids: per-scale arrays indexed `[row, col, component]` by window
            start position; each selected entry must broadcast against a
            `(scale, scale, channels)` window.
        strides: per-scale window step, a single int for all scales, or None to reuse
            `scales`.
        gaussian: accepted for interface compatibility; not used by this function.
        component: which direction component to select from `pca_direction_grids`.
        batch_size: number of windows evaluated per forward pass.

    Returns:
        float32 array with the shape of `img`. The same scalar is written to every channel
        of a pixel, so all channels of the result are identical by construction.

    Side effects:
        Prints how many pixels each scale "won"; calls `model.eval()` and `model.zero_grad()`.
    """
    model.eval()
    im_size = img.shape[0]
    device = dataset.cfg.device
    if strides is None:
        strides = scales
    if isinstance(strides, int):
        strides = [strides]*len(pca_direction_grids)

    interpolators = []
    indices_grid = np.mgrid[0:im_size, 0:im_size]

    stacked_img = np.repeat(np.expand_dims(img, 0), batch_size, axis=0)
    stacked_img = np.transpose(stacked_img, (0, 3, 1, 2)).astype(np.float32)  # NCHW format
    img_tensor = dataset.implicit_normalization(torch.tensor(stacked_img).to(device))

    for s, (scale, stride) in enumerate(zip(scales, strides)):
        # index_windows[:, r, c] holds the (row, col) pixel indices covered by the window
        # whose top-left corner is (r, c); there are im_size - scale + 1 such positions per axis.
        index_windows = np.lib.stride_tricks.sliding_window_view(indices_grid, (scale, scale), axis=(1, 2))

        # Window start positions that are actually evaluated (indexes into pca_direction_grids).
        # NOTE(review): the exclusive upper bound drops the last valid start `im_size - scale`;
        # confirm whether that is intended.
        xs = np.mgrid[0:im_size-scale:stride, 0:im_size-scale:stride]
        num_grid = xs.shape[1]
        d_out_d_alpha_grid = np.zeros((num_grid, num_grid))

        strided_indices = xs.transpose(1, 2, 0).reshape(-1, 2)  # positions in image / direction-grid space
        unstrided_indices = np.mgrid[:num_grid, :num_grid].transpose(1, 2, 0).reshape(-1, 2)  # positions in the dense result grid
        for k in tqdm(range(0, num_grid*num_grid, batch_size)):
            actual_batch_size = min(batch_size, num_grid*num_grid-k)
            batch_locs = strided_indices[k: k+actual_batch_size]
            batch_unstrided_locs = unstrided_indices[k: k+actual_batch_size]

            pca_directions = pca_direction_grids[s][batch_locs[:, 0], batch_locs[:, 1], component]
            batch_window_indices = index_windows[:, batch_locs[:, 0], batch_locs[:, 1], ...]

            # alpha is created directly on the target device so it is a leaf tensor.
            alpha = torch.zeros((actual_batch_size, 1, 1, 1), device=device, requires_grad=True)
            direction_tensor = dataset.implicit_normalization(torch.tensor(pca_directions).to(device).float())

            # Perturb a fresh copy for every batch. Writing into the shared `img_tensor` in
            # place would chain each batch's autograd graph (and its saved direction tensors)
            # onto the previous one, growing memory with every iteration. Slicing to
            # `actual_batch_size` also avoids forwarding unperturbed images on the last batch.
            perturbed = img_tensor[:actual_batch_size].clone()
            sample_idx = torch.arange(actual_batch_size, device=perturbed.device)[:, None, None]
            win_rows = torch.as_tensor(batch_window_indices[0], device=perturbed.device)
            win_cols = torch.as_tensor(batch_window_indices[1], device=perturbed.device)
            # The indexed view has shape (batch, scale, scale, channels).
            perturbed[sample_idx, :, win_rows, win_cols] += alpha*direction_tensor
            output = model(perturbed)

            # Summing over the batch is fine: each alpha only influences its own sample.
            d_out_d_alpha = torch.autograd.grad(output[:, target_class].sum(), alpha)[0].reshape(-1)
            model.zero_grad()
            d_out_d_alpha_grid[batch_unstrided_locs[:, 0], batch_unstrided_locs[:, 1]] = d_out_d_alpha.detach().cpu().numpy()

        # Shift by scale//2 so the grid coordinates refer to window centres rather than
        # window starts; fill_value=None makes the interpolator extrapolate outside the grid.
        window_centres = xs[1, 0] + scale//2
        interpolators.append(RegularGridInterpolator((window_centres, window_centres), d_out_d_alpha_grid,
                                                     bounds_error=False, fill_value=None))

    # Per pixel, interpolate what d_output/d_alpha would be if a window were centred there,
    # then keep the largest-magnitude value over all scales.
    best_response, best_scale = _strongest_response_map(interpolators, im_size)
    # Pixels whose interpolated value is NaN at every scale keep best_scale == -1 and are not
    # attributed to any scale (they used to be counted for the last scale via index -1).
    scale_wins = [int(np.count_nonzero(best_scale == s)) for s in range(len(scales))]
    print(scale_wins)

    saliency_map = np.zeros_like(img).astype(np.float32)
    saliency_map[...] = best_response[..., None]  # same scalar for every channel
    return saliency_map  # try jacobian with respect to window itself (isnt this just the gradient?)


def _strongest_response_map(interpolators, im_size):
    """Evaluate every interpolator at all pixels and keep the largest-|value| response per pixel.

    Returns `(best, best_scale)`, both of shape `(im_size, im_size)`: the winning value
    (0 where nothing won) and the index of the winning interpolator (-1 where nothing won).
    On equal magnitude the later interpolator wins.
    """
    pixel_coords = np.mgrid[0:im_size, 0:im_size].reshape(2, -1).T.astype(np.float64)  # rows of (i, j)
    best = np.zeros(im_size*im_size)
    best_scale = np.full(im_size*im_size, -1)
    for s, interpolate in enumerate(interpolators):
        values = interpolate(pixel_coords)
        wins = np.abs(values) >= np.abs(best)  # False for NaN, so NaNs never win
        best = np.where(wins, values, best)
        best_scale = np.where(wins, s, best_scale)
    return best.reshape(im_size, im_size), best_scale.reshape(im_size, im_size)

In [None]:
# Pixel-wise mean image of every training sample labelled 52.
# NOTE(review): samples come out of `dsets['train']` after `transform`; if that normalises
# pixel values, imshow will clip the result — confirm before reading colours off this plot.
class_ = 52
class_imgs = []
for img in dsets['train']:
    if not img['label'] == class_:
        continue
    class_imgs.append(img['image'].numpy())
class_imgs = np.asarray(class_imgs)
avg_img = np.mean(class_imgs, axis=0)
# Move the channel axis last for matplotlib.
plt.imshow(np.transpose(avg_img, (1, 2, 0)))

In [None]:
# Pixel-wise mean image of every training sample labelled 51 (same procedure as the
# class-52 cell; a shared helper taking the class index would remove the duplication).
class_ = 51
class_imgs = []
for img in dsets['train']:
    if not img['label'] == class_:
        continue
    class_imgs.append(img['image'].numpy())
class_imgs = np.asarray(class_imgs)
avg_img = np.mean(class_imgs, axis=0)
# Move the channel axis last for matplotlib.
plt.imshow(np.transpose(avg_img, (1, 2, 0)))

In [None]:
# Experiment hyper-parameters handed to the project network wrapper below.
exp_cfg = ExperimentConfig(epochs=30,
                            lr_max=4e-3,
                            step_size=7,
                            weight_decay=1e-4,
                            lr_decay=0.1,
                            full=False,
                            fc_layers=[2_000])

# Wrap an ImageNet-initialised EfficientNet-B0 in the project's network class, pointing it at
# the saved birds state dict and sizing it from the number of training classes.
# NOTE(review): absolute, machine-specific checkpoint path.
net = network.OurEfficientNet(efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1),
                        "/scratch/ssd004/scratch/jackk/birds_data/efficientnet_birds.dict", 
                        exp_cfg, dsets['train'].num_classes)
net.load_model_state_dict()
net.to("cuda:0")
# From here on `net` refers to the GuidedBackprop hook wrapper, not the bare network.
net = hooks.GuidedBackprop(net)

In [None]:
class Normalizer():
    """Per-channel affine normalisation and its inverse for channels-last NumPy images.

    `fwd` computes `(imgs - mean) / scale`; `rev` computes `imgs * scale + mean` and clips the
    result to [0, 1].

    Which statistics are applied depends on the input:
      * 4 or more dimensions (e.g. `(b, H, W, C)`): all channels, broadcast over the last axis.
      * fewer than 4 dimensions with a last axis of length C (e.g. one `(H, W, C)` image):
        all channels as well. Previously such inputs were treated as single-channel, so a
        lone RGB image had the red statistics applied to its green and blue channels.
      * anything else (e.g. a `(b, H, W)` stack of single-channel planes): the first (red)
        channel's statistics only.

    Caveat: a single-channel stack whose last axis happens to have length C is
    indistinguishable from a channels-last image and is treated as one.
    """

    def __init__(self, means, scales):
        # Stored with shape (1, 1, 1, C) so they broadcast against (b, H, W, C) batches.
        self.means = np.asarray(means)[None, None, None, :]
        self.scales = np.asarray(scales)[None, None, None, :]

    def _stats_for(self, imgs):
        """Pick the (mean, scale) arrays matching the layout of `imgs`."""
        if imgs.ndim >= self.means.ndim:
            return self.means, self.scales
        if imgs.ndim >= 1 and imgs.shape[-1] == self.means.shape[-1]:
            # Flat (C,) vectors: broadcasting over the last axis without adding leading axes.
            return self.means.reshape(-1), self.scales.reshape(-1)
        # eg. if imgs is (b, 224, 224), apply red channel transformation only
        return self.means[..., 0], self.scales[..., 0]

    def fwd(self, imgs):
        """Normalise `imgs`: subtract the mean and divide by the scale."""
        means, scales = self._stats_for(imgs)
        return (imgs - means) / scales

    def rev(self, imgs):
        """Undo `fwd` and clip the result to the displayable range [0, 1]."""
        means, scales = self._stats_for(imgs)
        return np.clip(imgs * scales + means, 0, 1)

normer = Normalizer(means=[0.485, 0.456, 0.406], scales=[0.229, 0.224, 0.225])

In [None]:
# Collect [de-normalised image, saliency map] pairs for every pickled map whose file name
# starts with "guided" in bird_pca_split_channels/, together with a title per pair.
img_grid = []
titles = []
# NOTE(review): glob returns matches in arbitrary order, so the grid order is not stable.
map_names = glob.glob("bird_pca_split_channels/guided*")
for name in tqdm(map_names):
    fname = osp.basename(name)
    # The validation-set index is the text after the last "-" in the file name ...
    data_idx = int(fname.split("-")[-1])
    # ... and the third "_"-separated field is used as the component label in the title.
    comp = fname.split("_")[2]
    # Relies on the sample dict yielding exactly (image, label) in that order.
    image, label_idx = dsets['valid'][data_idx].values()
    
    image = image.numpy().transpose(1,2,0)  # channels last for plotting
    label = dsets['valid'].idx_to_class_name[label_idx]
    titles.append(f"{comp=} {label=}")
    # pickle.load executes arbitrary code from the file; only use on maps produced locally.
    with open(name, "rb") as p:
        saliency_map = pickle.load(p)
    img_grid.append([normer.rev(image), saliency_map])
# what the saliency maps look like with the right PCA directions (SiLUs are suppressed), no cropping

# TODO / open questions:
# try doing channels individually for PCA maps, see if it affects the results (DONE, though probably has a bug)
# do a visualization that combines the saliency map directly with the image (DONE, though blending could be better)
# try doing PCA directions in HSV space (or in H space only)
# investigate why the border artifacts occur
# investigate why "squares" appear (only 1 particular pixel has a very high d_out_d_alpha for some reason)
# figure out why the saliency maps are identical in all channels when the PCA maps differ across channels (probably a bug)
#   NOTE(review): `pca_direction_grids` writes one scalar per pixel into every channel
#   (`saliency_map[i,j] = best_d_out_d_alpha`); if these maps were produced by it, identical channels are expected.

In [None]:
# Overlay every saliency map on its image ("bone" colouring, alpha 0.7) and show the
# overlays in one titled grid.
salienced_imgs = list(map(
    lambda pair: visualizations.combine_saliency_and_img(*pair, method="bone", alpha=0.7),
    img_grid,
))
utils.image_grid(salienced_imgs, titles=titles, force_linear=True)

In [None]:
# Alternative blend for the first image/map pair: mix the saliency magnitude into the HSV
# value (brightness) channel, leaving hue and saturation untouched.
hsv_version = matplotlib.colors.rgb_to_hsv(img_grid[0][0])
alpha = 0.4  # weight of the original brightness; (1 - alpha) goes to |saliency|
# Only channel 0 of the saliency map is used.
hsv_version[...,2] = hsv_version[...,2]*alpha + abs(img_grid[0][1][...,0])*(1-alpha)
# Left: original image. Right: HSV-blended version converted back to RGB.
plt.subplot(1,2,1)
plt.imshow(img_grid[0][0])
plt.subplot(1,2,2)
plt.imshow(matplotlib.colors.hsv_to_rgb(hsv_version))

In [None]:
# Same first pair, this time with the library overlay: original image on the left,
# combine_saliency_and_img output ("bone", alpha 0.7) on the right.
plt.subplot(1,2,1)
plt.imshow(img_grid[0][0])
plt.subplot(1,2,2)
plt.imshow(visualizations.combine_saliency_and_img(img_grid[0][0], img_grid[0][1], method="bone", alpha=0.7))

In [None]:
# Largest absolute difference between channel 2 and channel 0 of every saliency map;
# all zeros means the channels are identical.
[
    abs(saliency[..., 2] - saliency[..., 0]).max()
    for _image, saliency in img_grid
]  # uh oh