In [None]:
import os

import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

In [None]:
def imshow(img, title=None, cmap=None):
    """Display an image with matplotlib after min-max scaling it to [0, 1].

    Parameters
    ----------
    img : array-like
        Image to display (2-D grayscale or an image array matplotlib accepts).
    title : str, optional
        Figure title; no title is set when None.
    cmap : str or Colormap, optional
        Passed through to ``plt.imshow``; None keeps matplotlib's default.
    """
    img = np.asarray(img, dtype=float)
    lo, hi = np.min(img), np.max(img)
    span = hi - lo
    # Shift by the minimum before dividing so the result really lies in [0, 1]
    # (needed for float RGB input); a constant image has zero span and is shown
    # as all zeros instead of NaN/inf.
    scaled = (img - lo) / span if span > 0 else np.zeros_like(img)
    plt.imshow(scaled, cmap=cmap)
    if title is not None:
        plt.title(title)
    plt.show()

In [None]:
# Raw MNIST test split as a DataFrame. No later visible cell reads it; the
# arrays used downstream come from read_csv() further below.
# NOTE(review): the name `data` is reused as a loop variable in a later cell.
data = pd.read_csv(filepath_or_buffer='mnist/mnist_test.csv')

In [None]:
def _split_images_labels(frame):
    """Separate the ``label`` column from the pixel columns of one CSV split."""
    # Drop the label by name instead of assuming it is the first column, so the
    # pixels and the labels are always taken from complementary columns.
    pixels = frame.drop(columns='label').to_numpy()
    # ``np.float`` was removed in NumPy 1.24; ``np.float64`` is the same dtype.
    images = pixels.reshape((-1, 28, 28)).astype(np.float64)
    return images, np.asarray(frame['label'])


def read_csv(train_file, test_file, data_dir='data/'):
    """Load MNIST-style CSV splits into image arrays and label vectors.

    Each CSV must contain a ``label`` column; every other column is treated as
    a pixel value and each row is reshaped to 28x28.

    Parameters
    ----------
    train_file, test_file : str or file-like
        Anything accepted by ``pd.read_csv``.
    data_dir : str
        Unused; kept only so existing calls keep working.

    Returns
    -------
    x_train : np.ndarray of float64, shape (n_train, 28, 28)
    y_train : np.ndarray, shape (n_train,)
    x_test : np.ndarray of float64, shape (n_test, 28, 28)
    y_test : np.ndarray, shape (n_test,)
    """
    x_train, y_train = _split_images_labels(pd.read_csv(train_file))
    x_test, y_test = _split_images_labels(pd.read_csv(test_file))
    return x_train, y_train, x_test, y_test

In [None]:
def _random_extent(src, limit):
    """Draw a random length from [src // 2, 3 * src), clipped so it lies in [1, limit]."""
    high = min(3 * src, limit + 1)          # exclusive upper bound; never exceed the canvas
    low = min(max(src // 2, 1), high - 1)   # keep low < high even for tiny canvases
    return int(np.random.randint(low, high))


def transform(x, size=(256, 256)):
    """Rescale `x` by a random factor and paste it at a random spot on a blank canvas.

    The new height and width are drawn independently from roughly
    [0.5, 3) times the source extent, capped so the resized image always fits
    inside `size`. Everything outside the pasted image is zero.

    Parameters
    ----------
    x : np.ndarray
        2-D image (rows, cols).
    size : tuple of int
        Canvas shape (rows, cols).

    Returns
    -------
    np.ndarray
        float64 canvas of shape `size` containing the resized image.
    """
    canvas = np.zeros(size)
    new_h = _random_extent(x.shape[0], size[0])
    new_w = _random_extent(x.shape[1], size[1])

    # cv.resize expects dsize as (width, height) made of plain Python ints.
    resized = cv.resize(x, (new_w, new_h))

    # Offsets are bounded per axis; "+ 1" allows the image to sit flush with
    # the bottom/right edge (and avoids randint(0, 0) when it fills the canvas).
    top = int(np.random.randint(0, size[0] - new_h + 1))
    left = int(np.random.randint(0, size[1] - new_w + 1))

    canvas[top:top + new_h, left:left + new_w] = resized
    return canvas

In [None]:
# Load both MNIST CSV splits as 28x28 image arrays plus label vectors.
x_train, y_train, x_test, y_test = read_csv(train_file='mnist/mnist_train.csv',
                                            test_file='mnist/mnist_test.csv')

In [None]:
# Canvas shape (rows, cols) every transformed digit is pasted onto.
size = (128, 128)

# One preallocated canvas per source image. np.empty leaves the contents
# uninitialised, so only rows that a later cell writes are meaningful.
# NOTE(review): float64 buffers cost N * 128 * 128 * 8 bytes each
# (about 7.9 GB if the train split has 60000 rows) — consider a smaller dtype.
x_train_new = np.empty(x_train.shape[:1] + tuple(size))
x_test_new = np.empty(x_test.shape[:1] + tuple(size))

In [None]:
# Number of examples kept for the small train / validation / test subsets.
train_small, val_small, test_small = 8000, 2000, 2000

In [None]:
# Paste every training digit onto its own randomly scaled, randomly placed canvas.
for idx, digit in enumerate(x_train):
    x_train_new[idx] = transform(digit, size=size)

In [None]:
# Draw distinct row indices for the train + validation subsets. np.random.choice
# samples WITH replacement by default, which would repeat images and let the
# same image land in both train and validation; replace=False prevents that.
train_val_idx = np.random.choice(x_train.shape[0], size=train_small + val_small, replace=False)

In [None]:
# Write the two positional slices of the sampled indices: the first
# `train_small` entries form the training subset, the next `val_small` the
# validation subset.
split_bounds = {
    'data/train_small': (0, train_small),
    'data/val_small': (train_small, train_small + val_small),
}
for out_path, (start, stop) in split_bounds.items():
    chosen = train_val_idx[start:stop]
    np.savez(out_path, x=x_train_new[chosen], y=y_train[chosen])

In [None]:
# Show the last sampled index (the final validation entry). train_val_idx holds
# exactly train_small + val_small items, so valid positions end one before that.
train_val_idx[train_small + val_small - 1]

In [None]:
# Build the small test split: only the first `test_small` test digits are
# transformed, and only those rows of the preallocated buffer are saved.
for idx in range(test_small):
    x_test_new[idx] = transform(x_test[idx], size=size)
x_test = None

np.savez('data/test_small', x=x_test_new[:test_small], y=y_test[:test_small])

# Drop the references to the large arrays so their memory can be reclaimed.
x_test_new = None
y_test = None

In [None]:
from models import AGCNN

# Build the AGCNN model and move it to the target device.
# NOTE(review): `image_size`, `dropout`, `num_classes`, `device` and
# `models_dir` are not defined in any visible cell — presumably a missing
# config cell (image_size perhaps equal to `size` above); confirm before
# Restart & Run All.
ag_model = AGCNN(input_shape=(1, *image_size), dropout=dropout, num_classes=num_classes)
ag_model = ag_model.to(device)

# Initialise only the global branch from the saved checkpoint. map_location
# lets a checkpoint saved on a different device (e.g. a GPU) load onto
# `device` instead of failing.
ag_model.global_branch.load_state_dict(
    torch.load(os.path.join(models_dir, 'cnn_best'), map_location=device)
)

In [None]:
def _extract_local_images(loader, desc):
    """Run every batch of `loader` through ``ag_model._get_local_img``.

    Parameters
    ----------
    loader : sized iterable of dict batches
        Each batch's values are unpacked as ``(x, y)``. This assumes the dict
        holds exactly the images followed by the labels — TODO confirm against
        the Dataset.
    desc : str
        Progress-bar label.

    Returns
    -------
    images : np.ndarray
        Concatenated extractor outputs with all singleton dimensions squeezed.
    labels : np.ndarray
        Column index of every entry equal to 1 in the concatenated label
        tensor (the class index when labels are one-hot).

    Raises
    ------
    ValueError
        If `loader` yields no batches.
    """
    image_chunks = []
    label_chunks = []
    with torch.no_grad():
        # NOTE(review): `tqdm_notebook` is not imported in any visible cell
        # (deprecated alias of `tqdm.notebook.tqdm`) — confirm it is in scope.
        for batch in tqdm_notebook(loader, total=len(loader), desc=desc):
            x, y = batch.values()
            # Move each result to host memory right away so the accumulated
            # dataset does not pile up on the accelerator.
            image_chunks.append(ag_model._get_local_img(x.to(device)).cpu())
            label_chunks.append(y)
    if not image_chunks:
        raise ValueError(f'{desc!r}: loader yielded no batches')
    # Concatenate once at the end instead of re-copying everything per batch.
    images = torch.cat(image_chunks, dim=0).squeeze().numpy()
    labels = torch.cat(label_chunks, dim=0).cpu().numpy()
    return images, np.where(labels == 1)[1]


# Inference mode: AGCNN is built with a `dropout` argument, and dropout must be
# disabled for the extracted crops to be deterministic.
# NOTE(review): this leaves ag_model in eval mode; any later training cell must
# call ag_model.train().
ag_model.eval()
im_train, labels_train = _extract_local_images(train_loader, 'Processing Train Data: ')
im_val, labels_val = _extract_local_images(val_loader, 'Processing Val Data: ')

In [None]:
# Persist the extracted local-branch images and their labels for both splits.
# NOTE(review): `data_dir` is not defined in any visible cell (only as a
# read_csv parameter) — presumably set in a config cell; confirm before Run All.
for split_name, images, labels in (('local_train_small', im_train, labels_train),
                                    ('local_val_small', im_val, labels_val)):
    np.savez(os.path.join(data_dir, split_name), x=images, y=labels)