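# Playground

Scratch notebook for trying out new features on the fly. It trains a small convolutional autoencoder on MNIST, runs `PytorchModulePCA` on the encoder's last convolutional layer over the test set, and uses the result to retrieve test images similar to a query image.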

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib notebook

In [2]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn

from torchvision.transforms import ToTensor
from torchvision.models import resnet18
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from PytorchModulePCA import PytorchModulePCA
from PytorchModulePCA.utils import torch_img2numpy_img
from fastai.vision import *
from fastai.layers import simple_cnn
from torchvision.utils import make_grid

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Flag intended to control whether the autoencoder is (re)trained.
TRAIN = False

# Seed torch's global RNG so weight initialisation and train-set shuffling are repeatable.
SEED = 0
torch.manual_seed(SEED)

DATA_DIR = '~/Documents/datasets/'
BATCH_SIZE = 128
# Never request more DataLoader worker processes than this machine has CPUs.
NUM_WORKERS = min(14, os.cpu_count() or 1)

train_ds = MNIST(root=DATA_DIR, download=True, transform=ToTensor())
train_dl = DataLoader(train_ds, num_workers=NUM_WORKERS, batch_size=BATCH_SIZE, shuffle=True)
test_ds = MNIST(root=DATA_DIR, download=True, transform=ToTensor(), train=False)
# drop_last=True discards the final incomplete test batch, so those images never reach the
# PCA step. NOTE(review): presumably PytorchModulePCA expects equal-sized batches — confirm
# before turning this off.
test_dl = DataLoader(test_ds, num_workers=NUM_WORKERS, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)


In [3]:
class Autoencoder(nn.Module):
    """Small convolutional autoencoder for single-channel images.

    encoder: Conv2d(1->32, 3x3) -> LeakyReLU -> Conv2d(32->16, 3x3)
    decoder: ConvTranspose2d(16->32, 3x3) -> LeakyReLU -> ConvTranspose2d(32->1, 3x3) -> Sigmoid

    No padding is used, so each 3x3 convolution shrinks H and W by 2 and each transposed
    convolution grows them back by 2 (e.g. 28x28 -> 24x24 code -> 28x28). The final Sigmoid
    keeps the reconstruction in (0, 1).
    """

    def __init__(self):
        super().__init__()

        # NOTE: LeakyReLU's first positional argument is `negative_slope`, not `inplace`
        # (unlike nn.ReLU(True)). Writing nn.LeakyReLU(True) sets the slope to 1.0, which
        # turns the activation into the identity and makes the whole network linear.
        # Hence the explicit keyword here.
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(32, 16, kernel_size=3),  # encoder[2]: used downstream as the PCA layer
        )

        # No activation between the last transposed conv and the Sigmoid: a LeakyReLU there
        # would scale negative logits by 0.01 and pull those outputs towards 0.5, making dark
        # pixels hard to reconstruct. Parameterised layers keep their indices, so state_dict
        # keys are unchanged. Weights trained with the old identity activations should still
        # be retrained rather than reused.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(16, 32, kernel_size=3),
            nn.LeakyReLU(inplace=True),
            nn.ConvTranspose2d(32, 1, kernel_size=3),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Encode then decode `x`; the output has the same spatial size as the input."""
        x = self.encoder(x)
        x = self.decoder(x)
        return x


model = Autoencoder()

In [None]:
from tqdm.auto import tqdm  # tqdm_notebook is deprecated; tqdm.auto picks the notebook widget

model = Autoencoder().to(device)

PATH = './autoencoder-mnist.pth'
criterion = nn.MSELoss()  # reconstruction loss: the target is the input image itself
optimizer = torch.optim.Adam(model.parameters(), weight_decay=1e-5, lr=0.001)
num_epochs = 10

# Train only when explicitly requested (TRAIN) or when no checkpoint exists yet. Otherwise
# the existing checkpoint at PATH is kept instead of being silently overwritten.
if TRAIN or not os.path.exists(PATH):
    model.train()
    bar = tqdm(range(num_epochs))
    for epoch in bar:
        # Accumulate on-device to avoid a host sync on every batch.
        running_loss = torch.zeros((), device=device)
        n_seen = 0
        for x, _labels in train_dl:  # labels unused; a bare `_` would clobber IPython's `_`
            x = x.to(device)

            output = model(x)
            loss = criterion(output, x)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch figure is a true per-sample mean.
            running_loss += loss.detach() * x.size(0)
            n_seen += x.size(0)

        epoch_loss = running_loss.item() / max(n_seen, 1)
        bar.set_description('epoch [{}/{}], loss:{:.4f}'.format(epoch + 1, num_epochs, epoch_loss))

    torch.save(model.state_dict(), PATH)


HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

In [None]:
# Restore the trained weights. map_location maps the stored tensors onto `device`, so a
# checkpoint saved on a GPU can still be loaded on a CPU-only machine.
model.load_state_dict(torch.load(PATH, map_location=device))
model

In [None]:
# Layer whose activations are handed to PytorchModulePCA: the last Conv2d in the encoder
# (currently model.encoder[2]). Selecting it by type rather than by a hard-coded index keeps
# this correct if activation layers are added to or removed from the encoder.
last_conv_layer = [layer for layer in model.encoder if isinstance(layer, nn.Conv2d)][-1]
print(last_conv_layer)

def show(img):
    """Draw a (C, H, W) image tensor, e.g. the output of ``make_grid``, with ``plt.imshow``.

    The tensor is detached and copied to the CPU first, so tensors that require grad or
    live on the GPU can be shown. Channels are moved last because ``imshow`` expects
    (H, W, C). Returns None so a trailing call does not print an artist repr.
    """
    npimg = img.detach().cpu().permute(1, 2, 0).numpy()
    plt.imshow(npimg, interpolation='nearest')
    
# Switch the network to inference mode before handing it to PytorchModulePCA.
# model.eval() already covers last_conv_layer; the explicit call mirrors the original.
model.eval()
last_conv_layer.eval()
module_pca = PytorchModulePCA(model, last_conv_layer, test_dl)

# NOTE(review): n_batches=None presumably means "use every batch of test_dl", not just 4.
# The meaning of k=4 is defined by PytorchModulePCA — confirm there.
module_pca(k=4, n_batches=None)

# Optional follow-ups:
# module_pca.plot()  # plot
# module_pca.annotate()
# plt.savefig('./images/example')
# df = module_pca.state.to_df()  # get the points as pandas df
# print(df)

In [None]:
# Look up the images closest to a query image.
# NOTE(review): argument meaning inferred from usage — 1 looks like the query index and
# 8*8 the number of neighbours (one 8x8 grid). numpy=False keeps torch tensors so that
# make_grid can consume them. Confirm against PytorchModulePCA.query.
imgs, img, info = module_pca.query(1, 8 * 8, numpy=False)

# info is a pair whose second item is presumably the neighbour distances. Unpack the first
# item into a named throwaway: assigning to `_` in IPython stops `_` from tracking the last
# output.
_info_first, dist = info

fig, ax = plt.subplots()
ax.set_title('Original')
show(make_grid(img))

In [None]:
# Display the retrieved neighbours as a single image grid, then print their distances.
fig, ax = plt.subplots()
ax.set_title('Similar images')
neighbours_grid = make_grid(imgs)
show(neighbours_grid)
print(dist)