In [6]:
import numpy as np
#import torchvision
import torch
import torchvision.models as models
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.nn as nn
from tqdm import tqdm
from models import VAEGenerator,AEGenerator,DCGenerator
from models import MNIST224VAEGenerator,PGenerator,PDiscriminator
import utils
import torchvision
import argparse
import matplotlib.pyplot as plt
import model_settings
import os
torch.backends.cudnn.benchmark = True
import numpy as np
from scipy import fftpack
import torch

def get_2d_dct(x):
    """Orthonormal type-II DCT applied along axis 0 of ``x`` and then along its last axis.

    For a 2-D array this is the separable 2-D DCT (rows, then columns).
    """
    # Transposing moves axis 0 to the end, where fftpack.dct operates by default.
    first_pass = fftpack.dct(x.T, norm='ortho').T
    second_pass = fftpack.dct(first_pass, norm='ortho')
    return second_pass

def get_2d_idct(x):
    """Orthonormal inverse DCT applied along axis 0 of ``x`` and then along its last axis.

    With norm='ortho' this undoes ``get_2d_dct`` up to floating-point error.
    """
    # Transposing moves axis 0 to the end, where fftpack.idct operates by default.
    first_pass = fftpack.idct(x.T, norm='ortho').T
    second_pass = fftpack.idct(first_pass, norm='ortho')
    return second_pass

def RGB_img_dct(img):
    """Per-channel orthonormal 2-D DCT of a channel-first image.

    Args:
        img: array of shape (C, H, W). Any number of channels is accepted
            (e.g. 1 for MNIST, 3 for RGB).

    Returns:
        Array of shape (C, H, W). Channel c holds the type-II DCT (norm='ortho')
        applied along H and then W, i.e. the same values ``get_2d_dct(img[c])``
        produces. Floating-point (and complex) inputs keep their dtype. Integer or
        bool inputs produce float64; previously the coefficients were written
        into an integer buffer and silently truncated or wrapped.

    Raises:
        AssertionError: if ``img`` is not 3-dimensional.
    """
    assert len(img.shape) == 3, 'expected a (C, H, W) array, got shape %r' % (img.shape,)
    img = np.asarray(img)
    out_dtype = img.dtype if np.issubdtype(img.dtype, np.inexact) else np.float64
    # Transform every channel at once: axis 1 (H) first, then axis 2 (W),
    # the same axis order get_2d_dct uses on a single (H, W) channel.
    signal = fftpack.dct(fftpack.dct(img, axis=1, norm='ortho'), axis=2, norm='ortho')
    return signal.astype(out_dtype, copy=False)

def RGB_signal_idct(signal):
    """Per-channel orthonormal inverse 2-D DCT of channel-first coefficients.

    Args:
        signal: array of shape (C, H, W) of DCT coefficients. Any number of
            channels is accepted (e.g. 1 for MNIST, 3 for RGB).

    Returns:
        Array of shape (C, H, W). Channel c holds the inverse DCT (norm='ortho')
        applied along H and then W, i.e. the same values ``get_2d_idct(signal[c])``
        produces. This undoes ``RGB_img_dct`` up to floating-point error.
        Floating-point (and complex) inputs keep their dtype. Integer or bool
        inputs produce float64; previously the result was written into an
        integer buffer and silently truncated.

    Raises:
        AssertionError: if ``signal`` is not 3-dimensional.
    """
    assert len(signal.shape) == 3, 'expected a (C, H, W) array, got shape %r' % (signal.shape,)
    signal = np.asarray(signal)
    out_dtype = signal.dtype if np.issubdtype(signal.dtype, np.inexact) else np.float64
    # Invert every channel at once: axis 1 (H) first, then axis 2 (W),
    # the same axis order get_2d_idct uses on a single (H, W) channel.
    img = fftpack.idct(fftpack.idct(signal, axis=1, norm='ortho'), axis=2, norm='ortho')
    return img.astype(out_dtype, copy=False)
import matplotlib

# Use the non-interactive Agg backend; the figures below are written with
# plt.savefig and then closed rather than displayed.
matplotlib.use('Agg')
font = dict(size=18)  # default font size for subsequent figures
matplotlib.rc('font', **font)

In [2]:
def printMatrixZigZag(matrix):
    """Return the elements of a 2-D matrix in zig-zag (JPEG-style) order.

    Nothing is printed despite the name; the traversal is returned as a list.
    Anti-diagonals r + c = s are visited for s = 0, 1, 2, ... Even diagonals are
    walked from bottom-left to top-right, and odd diagonals from top-right to
    bottom-left.

    Example:
        np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) -> [1, 2, 4, 7, 5, 3, 6, 8, 9]

    Args:
        matrix: object with a ``shape`` of at least two dimensions that supports
            ``matrix[r, c]`` indexing (e.g. a numpy array). Rectangular shapes
            are allowed.

    Returns:
        list of ``matrix[r, c]`` values, one per (r, c) position. A matrix with
        zero rows or zero columns yields ``[]``.
    """
    n_rows, n_cols = matrix.shape[0], matrix.shape[1]
    ordered = []
    for s in range(n_rows + n_cols - 1):
        # Valid rows on anti-diagonal s: column s - r must lie in [0, n_cols - 1].
        lo = max(0, s - (n_cols - 1))
        hi = min(s, n_rows - 1)
        rows = range(lo, hi + 1) if s % 2 else range(hi, lo - 1, -1)
        ordered.extend(matrix[r, s - r] for r in rows)
    return ordered
def printLevel(m, aR, aC, bR, bC, fromUp, internal):
    """Append one anti-diagonal segment of ``m`` to ``internal`` in place; returns None.

    (aR, aC) and (bR, bC) are the two endpoints of the segment.
    If ``fromUp`` is true, the walk starts at (aR, aC) and moves down-left
    (row + 1, col - 1) until row bR has been emitted. Otherwise it starts at
    (bR, bC) and moves up-right (row - 1, col + 1) until row aR has been emitted.
    """
    if not fromUp:
        row, col, stop = bR, bC, aR - 1
        while row != stop:
            internal.append(m[row, col])
            row, col = row - 1, col + 1
        return
    row, col, stop = aR, aC, bR + 1
    while row != stop:
        internal.append(m[row, col])
        row, col = row + 1, col - 1


In [3]:
# Mean |DCT| spectrum over N_IMAGES normalised test images for each (task, model).
# Coefficients are laid out in zig-zag order (from the [0, 0] corner to the
# [H-1, W-1] corner), smoothed, and drawn as one curve per model in each task panel.
import scipy.signal  # plain `import scipy` does not expose `scipy.signal` on older SciPy releases

N_IMAGES = 1000       # images averaged per curve
SAVGOL_WINDOW = 801   # Savitzky-Golay window length (must be odd and <= 224*224)
SAVGOL_ORDER = 3      # Savitzky-Golay polynomial order

TASKs = ['imagenet', 'celeba', 'cifar10_224', 'mnist_224']
TASK_names = ['(a) ImageNet', '(b) CelebA', '(c) CIFAR10', '(d) MNIST']
model_names = ['res18','dense121','res50','vgg16','googlenet','wideresnet']
orl = ['ResNet18','DenseNet121','ResNet50','VGG16','GoogleNet','WideResNet']
rgb_colors = ['#FF0000', '#FF8000', '#00994C', '#0066CC', '#8c564b', '#e377c2', '#bcbd22','#934BD1','#B01260','#F26749','#497E8C']


def _load_normalized_batch(task, model, batch_idx):
    """Load test_batch_<batch_idx>.npy for (task, model), normalise it, and reshape to (N, C, 224, 224).

    The raw array is divided by the square root of its sum of squares along
    axis 1. NOTE(review): presumably axis 1 is the flattened per-image axis of
    the dump; confirm against the file format. MNIST is reshaped to one channel
    and every other task to three.
    """
    path = '../raw_data/%s_%s/test_batch_%d.npy' % (task, model, batch_idx)
    raw = np.load(path)
    normalized = raw / np.sqrt((raw ** 2).sum(1, keepdims=True))
    n_channels = 1 if task == 'mnist_224' else 3
    return normalized.reshape(-1, n_channels, 224, 224)


def _mean_abs_dct(task, model, n_images):
    """Return the (224, 224) mean |DCT| over exactly ``n_images`` images and all channels.

    Batches are read from index 1 upwards until ``n_images`` images have been
    transformed. The spectra are accumulated as a running float64 sum instead of
    being stacked, which avoids holding n_images * C * 224 * 224 values in memory.
    """
    assert n_images > 0, 'n_images must be positive'
    total = None
    count = 0
    batch_idx = 1
    while count < n_images:
        batch = _load_normalized_batch(task, model, batch_idx)
        batch_idx += 1
        if total is None:
            total = np.zeros(batch.shape[1:], dtype=np.float64)
        # Take only as many images as are still needed so the average uses exactly n_images.
        for img in batch[:n_images - count]:
            total += np.abs(RGB_img_dct(img))
            count += 1
    # Average over images first, then over channels.
    return (total / count).mean(axis=0)


os.makedirs('./plots', exist_ok=True)
n = len(TASKs)
fig_size = (6.5 * n, 5)
fig = plt.figure(figsize=fig_size)
for i, TASK in enumerate(TASKs):
    plt.subplot(1, n, i + 1)
    for idx, model in enumerate(model_names):
        mean_spectrum = _mean_abs_dct(TASK, model, N_IMAGES)
        X_dec = printMatrixZigZag(mean_spectrum)
        X_dec = scipy.signal.savgol_filter(X_dec, SAVGOL_WINDOW, SAVGOL_ORDER)
        plt.plot(np.arange(len(X_dec)), X_dec, rgb_colors[idx], label=orl[idx])
        print("done:", TASK, model)
    plt.xlabel('%s' % TASK_names[i])
    if i == 0:
        plt.ylabel('DCT Coefficient Distribution')
# One shared legend, placed to the right of the last panel.
anchor_loc = (1.7, 1.0)
plt.legend(bbox_to_anchor=anchor_loc)
plt.savefig('./plots/dct_coefficient_distribution.png', bbox_inches='tight')
plt.close(fig)

done: imagenet res18
done: imagenet dense121
done: imagenet res50
done: imagenet vgg16
done: imagenet googlenet
done: imagenet wideresnet
done: celeba res18
done: celeba dense121
done: celeba res50
done: celeba vgg16
done: celeba googlenet
done: celeba wideresnet
done: cifar10_224 res18
done: cifar10_224 dense121
done: cifar10_224 res50
done: cifar10_224 vgg16
done: cifar10_224 googlenet
done: cifar10_224 wideresnet
done: mnist_224 res18
done: mnist_224 dense121
done: mnist_224 res50
done: mnist_224 vgg16
done: mnist_224 googlenet
done: mnist_224 wideresnet


In [None]:
# Per-channel |DCT| spectra of individual images (zig-zag ordered and smoothed),
# with one faint curve per image in each of the R, G and B panels.
# NOTE: the images are not randomly sampled. The first image of each of the
# first N_SAMPLES test batches is used.
import scipy.signal  # plain `import scipy` does not expose `scipy.signal` on older SciPy releases

N_SAMPLES = 10        # images to overlay (one per batch)
SAVGOL_WINDOW = 801   # Savitzky-Golay window length (must be odd and <= 224*224)
SAVGOL_ORDER = 3      # Savitzky-Golay polynomial order

R = []
G = []
B = []
TASKs = ['cifar10_224']  # other candidates: 'imagenet', 'celeba'
TASK_names = ['(a) Red Channel', '(b) Green Channel', '(c) Blue Channel']
model_names = ['res18']  # other candidates: 'dense121', 'res50', 'vgg16', 'googlenet', 'wideresnet'
rgb_colors = ['#FF0000', '#00FF00', '#0000FF', '#0066CC', '#8c564b', '#e377c2', '#bcbd22','#934BD1','#B01260','#F26749','#497E8C']
fig_size = (6.5*len(TASK_names), 5)
TASK = TASKs[0]
# The three panels index channels 0, 1 and 2, so single-channel MNIST cannot be used here.
assert TASK != 'mnist_224', 'per-channel R/G/B plot needs 3-channel images'

for model in model_names:
    # NOTE(review): batches are indexed from 0 here but from 1 in the averaged-spectrum
    # cell; confirm test_batch_0.npy exists for this task/model.
    for j in range(N_SAMPLES):
        temp = '../raw_data/%s_%s/test_batch_%d.npy' % (TASK, model, j)
        data = np.load(temp)
        data = data / np.sqrt((data**2).sum(1, keepdims=True))
        data = data.reshape(-1, 3, 224, 224)
        if data.shape[0] == 0:
            continue  # empty batch: no image to take
        ori_signal = np.abs(RGB_img_dct(data[0]))
        for channel_curves, channel_signal in zip((R, G, B), ori_signal):
            channel_curves.append(
                scipy.signal.savgol_filter(printMatrixZigZag(channel_signal), SAVGOL_WINDOW, SAVGOL_ORDER))

os.makedirs('./plots', exist_ok=True)
fig = plt.figure(figsize=fig_size)
for i, C in enumerate([R, G, B]):
    plt.subplot(1, 3, i + 1)
    for curve in C[:N_SAMPLES]:
        plt.plot(np.arange(len(curve)), curve, rgb_colors[i], alpha=0.2)
    if i == 0:
        plt.ylabel('Coefficient Distribution')
    plt.xlabel('%s' % TASK_names[i], fontsize=20)
# No legend: none of the curves has a label (plt.legend would only warn), and each
# panel's channel is already named in its x-label.
plt.savefig('./plots/%s.png' % ('long_tail_distribution'), bbox_inches='tight')
plt.close(fig)