In [None]:
import torch
import torchvision
import torchvision.transforms as transforms

import torchinfo

import matplotlib.pyplot as plt
import numpy as np
import time

import os
import pathlib
from PIL import Image
import skimage
from tqdm import tqdm

# importing a module with utilities for displaying stats and data
import sys
sys.path.insert(1, '../../util')
import vcpi_util


In [None]:
# Report the installed PyTorch version for reproducibility.
print(f'{torch.__version__}')

In [None]:
def train(model, train_loader, val_loader, epochs, loss_fn, optimizer, scheduler, early_stopper, save_prefix = 'model', device = None):
    """Train an autoencoder to reconstruct its inputs.

    Each epoch runs a training pass and a validation pass, steps the LR scheduler,
    checkpoints the best model and consults the early stopper. Labels yielded by the
    loaders are ignored: every batch is its own target.

    Args:
        model: module mapping a batch of inputs to outputs of the same shape.
        train_loader, val_loader: DataLoaders yielding (inputs, labels) pairs.
        epochs: maximum number of epochs to run.
        loss_fn: reconstruction loss, called as loss_fn(outputs, inputs). The reported
            losses assume it averages over the batch (reduction='mean', the torch default).
        optimizer: optimizer over the model's parameters.
        scheduler: stepped once per epoch with the validation loss (e.g. ReduceLROnPlateau).
        early_stopper: callable taking the validation loss and returning True to stop.
        save_prefix: the best checkpoint (lowest validation loss) is written to
            f'{save_prefix}_best.pt'.
        device: device batches are moved to. Defaults to the device of the model's
            first parameter, or CPU if the model has no parameters.

    Returns:
        dict with per-epoch lists 'loss' (train) and 'val_loss' of mean per-sample losses.
    """
    if device is None:
        # Use wherever the model lives instead of depending on a notebook global.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    history = {'val_loss': [], 'loss': []}
    best_val_loss = np.inf

    for epoch in range(epochs):  # loop over the dataset multiple times
        start_time = time.time()

        # ---- training pass ----
        model.train()
        running_loss = 0.0
        for inputs, _ in tqdm(train_loader, desc=f'Epoch {epoch:03d}'):
            inputs = inputs.to(device)

            outputs = model(inputs)
            loss = loss_fn(outputs, inputs)  # the input is its own reconstruction target

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # loss is a batch mean: weight it by the batch size so that dividing by the
            # dataset size below yields the true per-sample mean (the last batch may be
            # smaller). .item() keeps history and scheduler state as plain Python floats,
            # which also keeps the checkpoint loadable with torch.load(weights_only=True).
            running_loss += loss.item() * inputs.size(0)

        # ---- validation pass ----
        model.eval()
        val_running_loss = 0.0
        with torch.no_grad():
            for val_inputs, _ in val_loader:
                val_inputs = val_inputs.to(device)
                val_outputs = model(val_inputs)
                val_running_loss += loss_fn(val_outputs, val_inputs).item() * val_inputs.size(0)

        epoch_loss = running_loss / len(train_loader.dataset)
        val_loss = val_running_loss / len(val_loader.dataset)

        # Step the scheduler with the same normalised value used for checkpointing and early stopping.
        old_lr = optimizer.param_groups[0]['lr']
        scheduler.step(val_loss)
        new_lr = optimizer.param_groups[0]['lr']

        if old_lr != new_lr:
            print('==> Learning rate updated: ', old_lr, ' -> ', new_lr)

        stop_time = time.time()
        print(f'Epoch: {epoch:03d}; Loss: {epoch_loss:0.6f}; Val Loss: {val_loss:0.6f}; Elapsed time: {(stop_time - start_time):0.4f}')

        history['val_loss'].append(val_loss)
        history['loss'].append(epoch_loss)

        ###### Saving ######
        if val_loss < best_val_loss:
            torch.save({
                'epoch': epoch,
                'model': model.state_dict(),
                'history': history,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict()
                },
                f'{save_prefix}_best.pt')
            best_val_loss = val_loss

        if early_stopper(val_loss):
            print('Early stopping!')
            break

    print('Finished Training')

    return history



class Early_Stopping():
    """Callable that signals when the validation loss has stopped improving.

    Call it once per epoch with that epoch's validation loss. It returns True once the
    loss has failed to beat the best value seen so far by more than `min_delta` for
    more than `patience` consecutive calls, and False otherwise.

    Args:
        patience: number of consecutive non-improving calls tolerated before stopping.
        min_delta: how far below the best loss a new loss must be to count as an improvement.
    """

    def __init__(self, patience = 3, min_delta = 0.00001):

        self.patience = patience
        self.min_delta = min_delta

        self.min_val_loss = float('inf')
        # Initialised here so that a first call without improvement (e.g. a NaN or
        # inf loss) increments the streak instead of raising AttributeError.
        self.counter = 0

    def __call__(self, val_loss):

        if val_loss + self.min_delta < self.min_val_loss:
            # Improvement: remember the new best and reset the streak.
            self.min_val_loss = val_loss
            self.counter = 0
            return False

        # No improvement: stop once the streak exceeds the allowed patience.
        self.counter += 1
        return self.counter > self.patience
    
from matplotlib import colors

def plot_scatter(x,y,targets):
    """Scatter-plot 2-D points coloured by integer class label (0-9).

    Args:
        x, y: point coordinates (e.g. the two latent dimensions).
        targets: integer class label per point, in 0..9.
    """
    cmap = colors.ListedColormap(['black', 'darkred', 'darkblue',
                                  'darkgreen', 'yellow', 'brown',
                                  'purple', 'lightgreen', 'red', 'lightblue'])
    # One equal-width bin per label, centred on the label: [-0.5, 0.5, ..., 9.5].
    bounds = [label - 0.5 for label in range(cmap.N + 1)]
    norm = colors.BoundaryNorm(bounds, cmap.N)

    plt.figure(figsize=(10,10))
    plt.scatter(x, y, c = targets, cmap=cmap, s = 1, norm=norm)
    # Put colorbar ticks on the class labels rather than on the bin edges.
    plt.colorbar(ticks=range(cmap.N))

    plt.show()


## Configuration

In [None]:
HEIGHT = 28
WIDTH = 28
NUM_CHANNELS = 1
BATCH_SIZE = 32
LATENT_SPACE_DIM = 2

# NOTE(review): MODEL_PATH and train_online are not used by the cells below.
MODEL_PATH = 'autoencoder_models'

train_online = True

# Seed torch's global RNG so the train/val split, loader shuffling and weight
# initialisation are reproducible across Restart & Run All (up to GPU non-determinism).
SEED = 42
torch.manual_seed(SEED)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)


In [None]:
# Re-import the helper module so edits to vcpi_util are picked up without restarting the kernel.
import importlib
importlib.reload(vcpi_util)

## Load and prepare MNIST dataset

In [None]:
# Images only need converting to tensors (ToTensor scales pixels to [0, 1]).
transform = transforms.Compose([transforms.ToTensor()])

# MNIST training split, divided 80/20 into training and validation subsets.
full_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_set, val_set = torch.utils.data.random_split(full_dataset, [train_size, val_size])

train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=BATCH_SIZE)

# Official MNIST test split, kept in a fixed order.
test_set = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Preview one shuffled training batch using the course helper module.
sample_images, sample_labels = next(iter(train_loader))
vcpi_util.show_images(4, 8, sample_images, sample_labels, full_dataset.classes)

In [None]:
class Autoencoder(torch.nn.Module):
    """Convolutional autoencoder for 1-channel 28x28 images with a `latent_space_dim` bottleneck.

    Encoder: two stride-2 3x3 convolutions (28 -> 14 -> 7, each preceded by a one-pixel
    right/bottom pad), each followed by ReLU then BatchNorm, then a linear projection of
    the flattened 64x7x7 map to the latent vector.
    Decoder: a linear layer back to 64x7x7, then two stride-2 transposed convolutions
    (7 -> 14 -> 28) and a final sigmoid, so outputs lie in [0, 1].
    """

    def __init__(self, latent_space_dim):

        super().__init__()

        # NOTE: attribute names are the state_dict keys of saved checkpoints, and the
        # construction order fixes the random initialisation sequence - keep both as-is.

        # Encoder
        self.econv1 = torch.nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=2)
        self.erelu1 = torch.nn.ReLU()
        self.eb1 = torch.nn.BatchNorm2d(num_features=32)

        self.econv2 = torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2)
        self.erelu2 = torch.nn.ReLU()
        self.eb2 = torch.nn.BatchNorm2d(num_features=64)

        self.efc1 = torch.nn.Linear(in_features=64 * 7 * 7, out_features=latent_space_dim)

        # Decoder
        self.dfc1 = torch.nn.Linear(in_features=latent_space_dim, out_features=64 * 7 * 7)
        self.dconv1 = torch.nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=3,
                                               stride=2, padding=1, output_padding=1)
        self.drelu1 = torch.nn.ReLU()
        self.db1 = torch.nn.BatchNorm2d(num_features=32)

        self.dconv2 = torch.nn.ConvTranspose2d(in_channels=32, out_channels=1, kernel_size=3,
                                               stride=2, padding=1, output_padding=1)
        self.dsig = torch.nn.Sigmoid()

    def encoder(self, x):
        """Map images of shape (N, 1, 28, 28) to latent codes of shape (N, latent_space_dim)."""
        stages = (
            (self.econv1, self.erelu1, self.eb1),
            (self.econv2, self.erelu2, self.eb2),
        )
        for conv, act, norm in stages:
            # Pad one pixel right/bottom so the unpadded stride-2 conv halves 28 -> 14 -> 7.
            padded = torch.nn.functional.pad(x, (0, 1, 0, 1))
            x = norm(act(conv(padded)))
        return self.efc1(torch.flatten(x, 1))

    def decoder(self, x):
        """Map latent codes of shape (N, latent_space_dim) to images of shape (N, 1, 28, 28) in [0, 1]."""
        h = self.dfc1(x).reshape(-1, 64, 7, 7)
        h = self.db1(self.drelu1(self.dconv1(h)))
        return self.dsig(self.dconv2(h))

    def forward(self,x):
        """Reconstruct `x` by encoding it and decoding the resulting latent code."""
        return self.decoder(self.encoder(x))

In [None]:
# Build the model on the selected device and show a layer-by-layer summary for one batch.
AE = Autoencoder(LATENT_SPACE_DIM).to(device)
torchinfo.summary(AE, input_size=(BATCH_SIZE, NUM_CHANNELS, HEIGHT, WIDTH))

In [None]:
# Pixel-wise reconstruction loss, Adam, LR cut by 10x after 3 stagnant epochs,
# and early stopping once validation loss stagnates for more than 9 epochs.
loss_fn = torch.nn.MSELoss()

opt = torch.optim.Adam(params=AE.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=opt, mode='min', factor=0.1, patience=3)
early_stop = Early_Stopping(patience=9)

In [None]:
# Train for at most 100 epochs; the best checkpoint is written to auto_<latent dim>_best.pt.
history = train(
    AE, train_loader, val_loader,
    epochs=100,
    loss_fn=loss_fn,
    optimizer=opt,
    scheduler=scheduler,
    early_stopper=early_stop,
    save_prefix=f'auto_{LATENT_SPACE_DIM}',
)

In [None]:
# Restore the lowest-validation-loss weights saved by `train`.
# Named `checkpoint` (not `reload`) so importlib.reload is not shadowed; map_location
# lets a checkpoint written on GPU be loaded on a CPU-only machine.
checkpoint = torch.load(f'auto_{LATENT_SPACE_DIM}_best.pt', map_location=device)
# The cells below only run inference, often on single images: use BatchNorm running stats.
AE.eval()
AE.load_state_dict(checkpoint['model'])

In [None]:
def show_preds(set1, set2, count):
    """Plot `count` (original, reconstruction) pairs, two pairs per row.

    Args:
        set1: indexable batch of original images, each item shaped (C, H, W).
        set2: matching batch of reconstructions, same layout as `set1`.
        count: number of pairs to show (items 0..count-1 of each set).
    """
    columns = 4
    # Each pair uses two panels; round up so no empty trailing row is allocated.
    rows = max(1, int(np.ceil(count * 2 / columns)))
    plt.figure(figsize=(count, 2*rows))

    panels = (('original', set1), ('recon', set2))
    for n in range(count):
        for offset, (title, images) in enumerate(panels, start=1):
            plt.subplot(rows, columns, n*2 + offset)
            plt.title(title)
            plt.imshow(np.transpose(images[n].cpu().detach().numpy(), (1, 2, 0)), cmap= plt.cm.gray)
            plt.axis('off')

test_images, _ = next(iter(test_loader))
with torch.no_grad():  # display only: no autograd graph needed
    recon = AE(test_images.to(device))
show_preds(test_images, recon, 10)
    

In [None]:
# Encode the whole test set and scatter the latent codes coloured by digit label.
# Assumes LATENT_SPACE_DIM == 2: only the first two latent coordinates are plotted.
encoded = []
targets = []

with torch.no_grad():  # inference only: skip building autograd graphs
    for batch_images, batch_labels in test_loader:
        encoded.extend(AE.encoder(batch_images.to(device)).cpu().numpy())
        targets.extend(batch_labels.tolist())  # plain ints for matplotlib

encoded_arr = np.array(encoded)
x = encoded_arr[:, 0]
y = encoded_arr[:, 1]

plot_scatter(x, y, targets)


In [None]:
# Decode one hand-picked latent point and show the generated image.
latent_point = torch.Tensor([[-4, 0]]).to(device)
pred = AE.decoder(latent_point)
plt.imshow(pred[0].cpu().detach().numpy().transpose(1, 2, 0), cmap=plt.cm.gray)

In [None]:
# Decode a steps x steps grid of 2-D latent points (assumes LATENT_SPACE_DIM == 2).
limit = 5
steps = 20

# `steps` points per axis spanning [-limit, limit] inclusive means steps - 1 intervals.
step = (2*limit)/(steps - 1) if steps > 1 else 0.0
# Row-major order: y goes from +limit down to -limit; within a row, x goes from -limit up to +limit.
vector_generation = [[-limit + j * step, limit - i * step] for i in range(steps) for j in range(steps)]

with torch.no_grad():  # pure inference: no autograd graph needed
    predictions = AE.decoder(torch.Tensor(vector_generation).to(device))

vcpi_util.show_predicted_images(steps, steps, predictions.cpu().detach(),10)

In [None]:
import glob

def show_anomaly_sample(image, recon, err):
    """Plot an input image beside its autoencoder reconstruction.

    Args:
        image: the input, as anything plt.imshow accepts (a PIL image in this notebook).
        recon: the reconstruction. Singleton axes are squeezed, so (1, 28, 28),
            (28, 28, 1) and (28, 28) arrays all display as a 2-D grayscale image.
        err: scalar reconstruction error, shown as the title of the input panel.
    """
    plt.figure(figsize=(5,2))
    # Subplot indices are 1-based: (1, 2, 1) is the left panel, (1, 2, 2) the right.
    plt.subplot(1, 2, 1)
    plt.imshow(image, cmap='gray')
    plt.title(f'MSE: {err:.5f}')
    plt.axis('off')
    plt.subplot(1, 2, 2)
    plt.imshow(np.squeeze(recon), cmap='gray')
    plt.title('recon')
    plt.axis('off')

def anomalyDetection(image_path):
    """Reconstruct each *.jpg in `image_path` with the autoencoder and plot it beside
    its reconstruction, titled with the per-image mean squared error.

    Images are converted to grayscale, resized to 28x28 and passed through the same
    `transform` as the MNIST data. The error is therefore measured on [0, 1] pixel
    values, like the MSELoss used in training. Uses the notebook globals `AE`,
    `transform` and `device`.
    """
    files = sorted(glob.glob(f'{image_path}/*.jpg'))  # sorted for a stable display order
    if not files:
        print(f'No .jpg images found in {image_path!r}')
        return

    AE.eval()  # single-image batches: BatchNorm must use its running statistics
    with torch.no_grad():
        for f in files:
            with Image.open(f) as raw:
                img = raw.convert('L').resize((28,28))
            data = transform(img).to(device).view(1,1,28,28)

            testing = AE(data)
            # Both tensors are on the [0, 1] scale; squaring stops positive and
            # negative pixel errors from cancelling out.
            err = torch.mean((testing - data) ** 2).item()
            show_anomaly_sample(img, testing[0, 0].cpu().numpy(), err)

image_path = 'anomalyDetectionImages'
anomalyDetection(image_path)

In [None]:
# Encode every training image once; features[k] is the latent vector of train_set[k].
features = []
AE.eval()  # BatchNorm running stats, so batching does not change the codes
with torch.no_grad():  # no autograd graphs kept alive for tens of thousands of tensors
    # shuffle=False keeps dataset order, so list positions still match train_set indices.
    for batch_images, _ in torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=False):
        features.extend(AE.encoder(batch_images.to(device)).cpu())

In [None]:
# Pair every training-set index with its latent vector for the retrieval search.
indexes = list(range(len(train_set)))
data = {
    'indexes': indexes,
    'features': features,
}

def show_content_retrieval(train_X, top50):
    """Show retrieved training images in a 5-column grid, each titled with its distance.

    Args:
        train_X: dataset where train_X[idx][0] is an image shaped (C, H, W).
        top50: sequence of [distance, idx] pairs, shown in the given order.
    """
    columns = 5
    rows = max(1, int(np.ceil(len(top50) / columns)))
    # Size the figure to the rows actually needed (about 2 inches per row).
    plt.figure(figsize=(10, 2 * rows))
    for pos, (dist, idx) in enumerate(top50, start=1):
        plt.subplot(rows, columns, pos)
        image = np.asarray(train_X[idx][0])
        plt.imshow(np.transpose(image, (1, 2, 0)), cmap='gray')
        plt.title(f'{dist:.3f}')
        plt.axis('off')

In [None]:
image_path = 'contentRetrievalImages'

# Accept either a single image file or a folder (like the anomaly-detection images).
# For a folder, the first *.jpg in sorted order is used as the query.
query_file = pathlib.Path(image_path)
if query_file.is_dir():
    candidates = sorted(query_file.glob('*.jpg'))
    if not candidates:
        raise FileNotFoundError(f'No .jpg query image found in {image_path!r}')
    query_file = candidates[0]

img = Image.open(query_file).convert('L').resize((28,28))
img_tensor = transform(img).to(device).view(1,1,28,28)

with torch.no_grad():  # inference only
    dataLatent = AE.encoder(img_tensor).cpu().numpy().squeeze()

In [None]:
# Mean squared latent-space distance from the query to every training image, computed
# in one vectorised pass instead of one device-to-host copy per image.
train_latents = torch.stack(data['features']).detach().cpu().numpy()
sq_dist = np.sum((dataLatent - train_latents) ** 2, axis=1) / float(dataLatent.shape[0])
# [distance, index] pairs, so sorting ranks by distance and breaks ties by index.
results = [[err, i] for i, err in enumerate(sq_dist)]
    

In [None]:
# Rank by latent distance (ties broken by index) and show the 50 closest training images.
ranked = sorted(results)
top50 = ranked[:50]
show_content_retrieval(train_set, top50)