In [None]:
# The `src.*` imports and the `./figures/` output paths below are relative to the
# parent directory of this notebook. Move there once; the guard keeps re-runs of this
# cell from climbing further up the tree (a bare `cd ..` is not idempotent and also
# depends on IPython automagic being enabled).
import os
if not os.path.isdir('src'):
    os.chdir('..')

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import os

import torch
import matplotlib.pyplot as plt
import numpy as np

from src.utils_freq import rgb2gray, dct, dct2, idct, idct2, batch_dct2, getDCTmatrix, batch_idct2, batch_dct2_3channel, batch_idct2_3channel, equal_dist_from_top_left, mask_radial
from src.utils_dataset import load_dataset

_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
dataset = 'cifar100'
# Percentages used by the energy- and frequency-based filters below.
threshold_list = [90, 70, 50, 30, 10]

# The second argument is presumably the batch size (1) — confirm against src.utils_dataset.
train_loader, test_loader = load_dataset(dataset, 1)

# Only the first test batch is used. Put it on `_device`, the same device as `dct_matrix`,
# so the DCT below does not mix CPU and CUDA tensors when a GPU is available.
x, _ = next(iter(test_loader))
X_original = x.to(_device)
_dim = X_original.shape[-1]
dct_matrix = getDCTmatrix(_dim).to(_device)
X_dct = batch_dct2_3channel(X_original, dct_matrix)


In [None]:
# Channel-averaged DCT magnitude, (1, C, H, W) -> (H, W); shared by the linear and log panels.
X_dct_mag = X_dct.squeeze().abs().mean(dim=0)

fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(8, 2.4), tight_layout=True)

# Panel 1: the input image, (C, H, W) -> (H, W, C) for imshow.
axs[0].imshow(X_original.squeeze().permute(1, 2, 0).cpu().numpy())
axs[0].axis('off')
axs[0].set_title(r'$x$', fontsize=11)

# Panel 2: DCT magnitude on a linear scale.
im = axs[1].imshow(X_dct_mag.cpu().numpy(), cmap='gray')
fig.colorbar(im, ax=axs[1])
axs[1].axis('off')
axs[1].set_title(r'$|\tilde{x}|$', fontsize=11)

# Panel 3: the same magnitude on a log scale to compress its dynamic range.
im = axs[2].imshow(X_dct_mag.log().cpu().numpy(), cmap='gray')
fig.colorbar(im, ax=axs[2])
axs[2].set_xticks([])
axs[2].set_yticks([])
axs[2].set_title(r'$\log |\tilde{x}|$', fontsize=11)

# savefig does not create missing directories.
os.makedirs('./figures', exist_ok=True)
fig.savefig('./figures/OB_{}_original.pdf'.format(dataset), bbox_inches='tight')

In [None]:
plot_nrg_filtered_X = []
plot_nrg_filtered_X_dct = []

# Per-coefficient magnitude averaged over channels: (B, C, H, W) -> (B, H, W).
# Independent of the threshold, so it is computed once outside the loop.
X_dct_abs_mean = X_dct.abs().mean(dim=1)

for threshold in threshold_list:
    # One cutoff per sample at the (1 - threshold/100) quantile of its coefficient magnitudes;
    # coefficients at or above it (about `threshold` percent of them) are kept.
    threshold_for_each_sample = torch.quantile(X_dct_abs_mean.view(-1, _dim*_dim), 1.-threshold/100., dim=1, keepdim=False)
    # (B,) -> (B, 1, 1) so each sample's cutoff broadcasts over its own (H, W) grid.
    X_threshold = X_dct_abs_mean >= threshold_for_each_sample.view(-1, 1, 1)

    # The (B, 1, H, W) boolean mask broadcasts over the channel axis, zeroing the
    # same coefficients in every channel.
    X_dct_new = X_dct * X_threshold.unsqueeze(1)

    X_new = batch_idct2_3channel(X_dct_new, dct_matrix)

    plot_nrg_filtered_X.append(X_new.squeeze())
    plot_nrg_filtered_X_dct.append(X_dct_new.squeeze())

In [None]:
plot_freq_mask = []

# Candidate radii, scanned from the image diagonal down towards 0 in 2000 equal steps.
# They do not depend on the threshold, so they are built once.
spacing = _dim*np.sqrt(2)/2000
candidate_radius = np.arange(_dim*np.sqrt(2), 0, -spacing)

for _threshold in threshold_list:
    # Shrink the radius until the mask's nonzero area (in percent of the _dim x _dim grid,
    # assuming mask_radial returns a 0/1 mask) drops below `_threshold`. Keep the mask from
    # the stopping iteration rather than calling mask_radial again.
    for _r in candidate_radius:
        radial_mask = mask_radial(_dim, _r)
        area = radial_mask.sum()/_dim/_dim*100
        if area < _threshold:
            break

    # Repeat the (H, W) mask over every channel -> (H, W, C). No trailing squeeze(), so the
    # channel axis survives for single-channel data too (later cells index [:, :, 0] and permute).
    freq_mask = torch.as_tensor(radial_mask, dtype=torch.float16, device=_device).unsqueeze(2).expand(-1, -1, X_original.shape[1])
    plot_freq_mask.append(freq_mask.cpu())


In [None]:
plot_freq_filtered_X = []
plot_freq_filtered_X_dct = []
for freq_mask in plot_freq_mask:
    # X_dct already holds batch_dct2_3channel(X_original, dct_matrix); reuse it instead of
    # recomputing the forward DCT for every mask.
    # The stored masks are (H, W, C), float16, on the CPU: reshape to (1, C, H, W) and move them
    # to X_dct's device and dtype so the product works on GPU and needs no cast afterwards.
    mask_bchw = freq_mask.permute(2, 0, 1).unsqueeze(0).to(device=X_dct.device, dtype=X_dct.dtype)
    X_dct_new = X_dct * mask_bchw
    X_new = batch_idct2_3channel(X_dct_new, dct_matrix)
    plot_freq_filtered_X.append(X_new.squeeze())
    plot_freq_filtered_X_dct.append(X_dct_new.squeeze())

In [None]:
# One column per threshold, six rows:
#   rows 1-3: energy-based filter    -> filtered image, |x - filtered| x 10, kept-coefficient mask
#   rows 4-6: frequency-based filter -> filtered image, |x - filtered| x 10, radial mask
n_cols = len(threshold_list)
# squeeze=False keeps `axs` 2-D even when threshold_list has a single entry.
fig, axs = plt.subplots(nrows=6, ncols=n_cols, figsize=(2.4 * n_cols, 12), tight_layout=True, squeeze=False)

# Reference image as (H, W, C) for the pixel-space differences.
X_ref = X_original.squeeze().permute(1, 2, 0)


def _to_display(img_hwc):
    """Return an (H, W, C) tensor as a CPU numpy array clipped to [0, 1].

    imshow clips float RGB data to [0, 1] anyway; clipping here leaves the picture
    unchanged but avoids a "Clipping input data" warning for every panel.
    """
    return img_hwc.clamp(0, 1).cpu().numpy()


def _show_image(ax, img_chw):
    """Draw a (C, H, W) image tensor with the axes hidden."""
    ax.imshow(_to_display(img_chw.permute(1, 2, 0)))
    ax.axis('off')


def _show_with_colorbar(ax, data):
    """Draw `data` in grayscale with a colorbar and no tick marks."""
    im = ax.imshow(data, cmap='gray')
    fig.colorbar(im, ax=ax)
    ax.set_xticks([])
    ax.set_yticks([])


for i, _threshold in enumerate(threshold_list):
    pct = str(100 - _threshold)
    X_nrg = plot_nrg_filtered_X[i]
    X_freq = plot_freq_filtered_X[i]

    # rows 1-3: energy-based filtering
    _show_image(axs[0, i], X_nrg)
    _show_with_colorbar(axs[1, i], _to_display(((X_nrg.permute(1, 2, 0) - X_ref) * 10).abs()))
    # .cpu() so the boolean mask can be converted for imshow even when it lives on a GPU.
    _show_with_colorbar(axs[2, i], (plot_nrg_filtered_X_dct[i][0] != 0).cpu().numpy())

    # rows 4-6: frequency-based filtering
    _show_image(axs[3, i], X_freq)
    _show_with_colorbar(axs[4, i], _to_display(((X_freq.permute(1, 2, 0) - X_ref) * 10).abs()))
    _show_with_colorbar(axs[5, i], (plot_freq_mask[i][:, :, 0] != 0).cpu().numpy())

    axs[0, i].set_title(r'$\Phi_{nrg}(x,$' + pct + r'$)$', fontsize=11)
    axs[1, i].set_title(r'$|x - \Phi_{nrg}(x,$' + pct + r'$)| \times 10$', fontsize=11)
    axs[2, i].set_title(r'$M_{nrg}(\tilde{x},$' + pct + r'$)$', fontsize=11)
    axs[3, i].set_title(r'$\Phi_{freq}(x,$' + pct + r'$)$', fontsize=11)
    axs[4, i].set_title(r'$|x - \Phi_{freq}(x,$' + pct + r'$)| \times 10$', fontsize=11)
    axs[5, i].set_title(r'$M_{freq}($' + pct + r'$)$', fontsize=11)

# savefig does not create missing directories.
os.makedirs('./figures', exist_ok=True)
fig.savefig('./figures/OB_{}_modified.pdf'.format(dataset), bbox_inches='tight')