In [1]:
import os, sys
import toml
import argparse
from munch import Munch, munchify

# Project layout. The root can be overridden with the TRANSMISSION_PHASE_DIR
# environment variable; otherwise the original hard-coded checkout location is used.
# (The "Transmisstion" spelling is kept as-is; presumably it matches the directory on disk.)
PROJ_DIR = os.path.expanduser(
    os.environ.get("TRANSMISSION_PHASE_DIR", "~/GitWS/Transmisstion-Phase")
)
DATA_DIR = os.path.join(PROJ_DIR, "data")
SRC_DIR = os.path.join(PROJ_DIR, "src")
LOGS_DIR = os.path.join(PROJ_DIR, "logs", "exp1")
SCRIPTS_DIR = os.path.join(PROJ_DIR, "scripts")
# NOTE(review): identical to DATA_DIR — confirm checkpoints really live under data/.
CHECKPOINTS_DIR = os.path.join(PROJ_DIR, "data")
RESULTS_DIR = os.path.join(PROJ_DIR, "results")

# Make project modules importable; guard so re-running this cell does not
# append duplicate entries to sys.path.
if PROJ_DIR not in sys.path:
    sys.path.append(PROJ_DIR)

In [2]:
import numpy as np
# Print every float element of NumPy arrays with exactly two decimals.
np.set_printoptions(formatter={'float': '{0:0.2f}'.format})
from torchvision import datasets, transforms
# Name of the torchvision dataset class to load, and the sample index whose
# label is used as the source class in the analysis below.
DATASET, INDEX = 'MNIST', 100

In [3]:
# Load the stored matrices for base sample INDEX; the cells below index the
# first axis by target offset (grad_matrix[idx_shift]).
npz_fname = os.path.join(DATA_DIR, DATASET, "GradMatrix", f"{INDEX}.npz")
# np.load on an .npz returns an NpzFile that holds an open file handle. Use it as a
# context manager so the handle is closed explicitly once 'data' has been read
# into memory, instead of relying on garbage collection.
with np.load(npz_fname) as npz_file:
    grad_matrix = npz_file['data']

In [4]:
# Resolve the torchvision dataset class by name. getattr does the same lookup as
# eval('datasets.' + DATASET) without evaluating an arbitrary constructed string.
dataset = getattr(datasets, DATASET)(DATA_DIR, train=True, download=True,
                                     transform=transforms.ToTensor())

In [5]:
# Inspect grad_matrix[0] (offset 0, i.e. presumably sample INDEX paired with itself).
# Its label becomes the source class used by the following cells.
base_matrix = grad_matrix[0]
src_cls = dataset[INDEX][1]
# Convert the flat argmax into (row, col) using the matrix's real shape, which is
# correct for non-square matrices too. Cast to int for a clean printout on NumPy 2.
base_max_pos = tuple(int(k) for k in np.unravel_index(base_matrix.argmax(), base_matrix.shape))
print('base class', src_cls)
print('base max val', base_matrix.max())
print('base max pos', base_max_pos)
# axis=0 reduces over rows, so these are the maxima (and their row indices) of each column.
print('base all row max vals', base_matrix.max(axis=0))
print('base all row max pos', np.argmax(base_matrix, axis=0))

base class 5
base max val 10360.0
base max pas (6, 6)
base all row max vals [7912.00 9312.00 7584.00 9168.00 9912.00 9816.00 10360.00 8256.00 7888.00
 6880.00]
base all row max pos [0 1 2 3 4 5 6 7 8 9]


In [6]:
# Inspect one example target: grad_matrix[idx_shift] corresponds to dataset sample
# INDEX + idx_shift. The offset 7888 is a hand-picked example value.
idx_shift = 7888
# Wrap around so offsets past the end of the dataset stay valid. This matches the loop
# cell below. NOTE(review): assumes cyclic pairing in GradMatrix — confirm.
tgt_cls = dataset[(INDEX + idx_shift) % len(dataset)][1]
sim_matrix = grad_matrix[idx_shift]
sim_flat_argmax = sim_matrix.argmax()
# Unravel using the matrix's real shape (correct for non-square matrices too).
sim_max_pos = tuple(int(k) for k in np.unravel_index(sim_flat_argmax, sim_matrix.shape))
print('src class', src_cls)
print('tgt class:', tgt_cls)
print('sim max val:', sim_matrix.max())
print('sim max pos:', sim_flat_argmax, sim_max_pos)
# axis=0 reduces over rows, so these are per-column maxima and their row indices.
print('sim all row max vals:', sim_matrix.max(axis=0))
print('sim all row max pos:', np.argmax(sim_matrix, axis=0))

src class 5
tgt class: 5
sim max val: 6824.0
sim max pos: 66 (6, 6)
sim all row max vals: [5264.00 6628.00 5128.00 6688.00 6800.00 6744.00 6824.00 5316.00 5932.00
 4288.00]
sim all row max pos: [0 1 2 3 4 5 6 7 8 9]


In [7]:
# For every target offset in grad_matrix, locate where its matrix peaks and compare
# the peak position with the (source, target) class labels.
NUM_IJ_IS_MAX = 0        # peak is exactly at (row=src_cls, col=tgt_cls)
NUM_I_IS_MAX = 0         # peak row equals the source or the target label
NUM_J_IS_MAX = 0         # peak column equals the source or the target label
IJ_MAX_OVER_MEAN = []    # max/mean ratio for targets whose peak is at (src_cls, tgt_cls)
NIJ_MAX_OVER_MEAN = []   # max/mean ratio for all other targets
NUM_MAX_ON_DIAGONAL = 0  # peak lies on the main diagonal (row == col)

num_tgt = grad_matrix.shape[0]
num_samples = len(dataset)
for idx_shift in range(num_tgt):
    # Plain INDEX + idx_shift runs past the end of the dataset whenever
    # num_tgt > len(dataset) - INDEX, which makes dataset[...] fail, so wrap around.
    # NOTE(review): assumes GradMatrix pairs targets cyclically — TODO confirm
    # against the script that generated the .npz file.
    tgt_cls = dataset[(INDEX + idx_shift) % num_samples][1]
    sim_matrix = grad_matrix[idx_shift]
    # Unravel with the matrix's real shape; `// shape[0]` and `% shape[0]` are only
    # correct for square matrices.
    i, j = np.unravel_index(sim_matrix.argmax(), sim_matrix.shape)
    max_over_mean = sim_matrix.max() / sim_matrix.mean()
    if i in (src_cls, tgt_cls):
        NUM_I_IS_MAX += 1
    if j in (src_cls, tgt_cls):
        NUM_J_IS_MAX += 1
    if i == src_cls and j == tgt_cls:
        NUM_IJ_IS_MAX += 1
        IJ_MAX_OVER_MEAN.append(max_over_mean)
    else:
        NIJ_MAX_OVER_MEAN.append(max_over_mean)
    if i == j:
        NUM_MAX_ON_DIAGONAL += 1

In [8]:
# Summary of the aggregation loop. Fractions are taken over all num_tgt targets.
# The ratio lists can be empty (e.g. no peak ever lands on the labels). np.mean([])
# would warn and return nan, so return nan explicitly in that case.
print('frac peak at (src_cls, tgt_cls):', NUM_IJ_IS_MAX / num_tgt)
print('frac peak row in {src_cls, tgt_cls}:', NUM_I_IS_MAX / num_tgt)
print('frac peak col in {src_cls, tgt_cls}:', NUM_J_IS_MAX / num_tgt)
print('mean max/mean, peak at labels:',
      np.mean(IJ_MAX_OVER_MEAN) if IJ_MAX_OVER_MEAN else float('nan'))
print('mean max/mean, peak elsewhere:',
      np.mean(NIJ_MAX_OVER_MEAN) if NIJ_MAX_OVER_MEAN else float('nan'))
print('frac peak on diagonal:', NUM_MAX_ON_DIAGONAL / num_tgt)

0.006583333333333333
0.13711666666666666
0.13711666666666666
23.02
31.4
1.0
