In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.cuda
import torchvision.transforms.functional as TF
import torchvision.transforms as transforms
import os
from PIL import Image
from tqdm.notebook import tqdm
from torch.nn.functional import relu
from torch.utils.data import DataLoader, Dataset
from sklearn.manifold import TSNE
from SHG import SHG
from utils import *
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from sklearn.base import clone
from sklearn_extra.cluster import KMedoids
from AE import AE
from SHG import SHG
from SHG_AE import SHG_AE

In [2]:
""" PATHS """

# Root of the bsc_data tree. Set the BSC_DATA_ROOT environment variable to run
# on another machine; the default is the original author's location.
# The root is normalised to end in exactly one "/" because the paths below are
# combined with other names by plain string concatenation later in the notebook.
DATA_ROOT = os.environ.get(
    "BSC_DATA_ROOT", "C:/Users/André/OneDrive 2/OneDrive/Skrivebord/bsc_data/"
).rstrip("/\\") + "/"

# MODEL: trained SHG_AE checkpoint
MODEL_PATH = DATA_ROOT + "SHG_AE/models/Wed_May_26_10-42-05_2021/epoch_15.pth"

# Input images
TRAIN_IMGS_PATH = DATA_ROOT + "train/image/"

# Ground truth heatmaps (one sub-directory per image)
HEATMAPS_PATH = DATA_ROOT + "train/heatmaps/"

# SAVING PATH: latents for every image (ALL) and for the subset whose keypoint
# flags contain no 0 (FULL)
ALL_SAVING_PATH = DATA_ROOT + "SHG_AE/latent_space/all/"
FULL_SAVING_PATH = DATA_ROOT + "SHG_AE/latent_space/full/"

# Per-channel mean RGB values, later subtracted from every input image.
# np.loadtxt parses text, so this file is expected to be a plain-text dump
# even though it carries a .npy extension.
_AVERAGE_RGB_FILE = "./average_rgb.npy"
average_rgb = np.loadtxt(_AVERAGE_RGB_FILE)

In [3]:
""" DATASET AND DATALOADER """

class dataset(Dataset):
    """Pairs each image in ``X_path`` with its stack of per-keypoint heatmaps.

    Parameters
    ----------
    X_path : str
        Directory containing the input images; every entry is treated as an image.
    y_path : str
        Directory containing one sub-directory per image, named after the image
        file without its extension, holding ``0.npy`` .. ``{num_keypoints-1}.npy``.
    average_rgb : sequence of 3 floats
        Per-channel mean subtracted from each image (std is fixed at 1).
    num_keypoints : int, default 17
        Number of heatmap files loaded per image.

    Each item is ``(x, y, ID)``: ``x`` is a 3-channel, mean-subtracted image
    tensor, ``y`` is the heatmaps stacked along a new first dimension
    (length ``num_keypoints``), and ``ID`` is the image file name.
    """

    def __init__(self, X_path, y_path, average_rgb, num_keypoints=17):
        self.X_path = X_path
        self.y_path = y_path
        # Sorted so the index -> file mapping is stable across runs and platforms.
        self.X_data = sorted(os.listdir(self.X_path))
        self.average_rgb = average_rgb
        self.num_keypoints = num_keypoints
        self.norm = transforms.Normalize(mean=self.average_rgb, std=[1, 1, 1])

    def __len__(self):
        return len(self.X_data)

    def __getitem__(self, i):
        ID = self.X_data[i]
        # splitext handles any extension length (.jpg, .jpeg, .png, ...).
        stem = os.path.splitext(ID)[0]

        # `with` releases the file handle. convert("RGB") yields 3 channels for
        # grayscale, palette, RGBA and CMYK inputs alike, so the 3-value
        # Normalize below always matches the channel count.
        with Image.open(os.path.join(self.X_path, ID)) as img:
            x = TF.to_tensor(img.convert("RGB"))

        x = self.norm(x)  # subtracts mean rgb

        heatmap_dir = os.path.join(self.y_path, stem)
        y = torch.stack([
            torch.from_numpy(np.load(os.path.join(heatmap_dir, f"{k}.npy")))
            for k in range(self.num_keypoints)
        ])
        return x, y, ID

# One image per batch; each batch yields (image, heatmaps, [file name]).
# No shuffling: in this notebook the loader drives a single pass that writes one
# file per image, so a fixed order costs nothing and keeps runs reproducible.
train_data = dataset(TRAIN_IMGS_PATH, HEATMAPS_PATH, average_rgb)
train_dataloader = DataLoader(train_data, batch_size = 1, shuffle = False)

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the combined SHG + autoencoder model and restore the trained weights.
SHG_model = SHG(num_hourglasses = 1).to(device)
AE_model = AE().to(device)
model = SHG_AE(SHG_model, AE_model).to(device)
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
# Evaluation mode for inference; the trailing ';' suppresses the module repr in Jupyter.
model.eval();




In [5]:
""" Saving bottleneck data - train"""
 
# np.save raises if the target directory is missing, so create both up front.
os.makedirs(ALL_SAVING_PATH, exist_ok=True)
os.makedirs(FULL_SAVING_PATH, exist_ok=True)

with torch.no_grad():
    for x, y, id_ in tqdm(train_dataloader):
        x = x.to(device, dtype = torch.float)
        # batch_size is 1, so id_ holds a single image file name.
        out_name = os.path.splitext(id_[0])[0] + ".npy"

        # Bottleneck code of the SHG pre-processing output, with encoder noise off.
        # No .data needed: under no_grad the output does not require grad.
        x_latent = model.encode(model.SHG_prework(x), add_noise = False).cpu().numpy()
        np.save(os.path.join(ALL_SAVING_PATH, out_name), x_latent)

        # NOTE(review): presumably one (x, y, visibility) triple per keypoint —
        # confirm against utils.turn_featuremaps_to_keypoints.
        keypoints = np.array(turn_featuremaps_to_keypoints(y)).reshape((-1, 3))
        # Also keep the latent in FULL only when no keypoint's last value is 0.
        if not (keypoints[:, -1] == 0).any():
            np.save(os.path.join(FULL_SAVING_PATH, out_name), x_latent)

HBox(children=(FloatProgress(value=0.0, max=124040.0), HTML(value='')))


