In [1]:
import torch
from reconstruction import AE
from datasets import MeshData
from utils import utils, DataLoader, sap
import numpy as np
import os, sys
from math import ceil
from scipy.ndimage import zoom
from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, mean_squared_error

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Meshplot leaves a stray print statement in its code; this context manager
# suppresses stdout while such calls run.
class HiddenPrints:
    """Temporarily redirect ``sys.stdout`` to ``os.devnull``.

    On exit the previous ``sys.stdout`` is restored and the devnull handle opened
    in ``__enter__`` is closed. Exceptions raised inside the ``with`` body are not
    suppressed.
    """

    def __enter__(self):
        self._original_stdout = sys.stdout
        # Keep our own reference to the devnull handle so __exit__ closes exactly
        # the stream we opened, even if code in the block reassigned sys.stdout.
        self._devnull = open(os.devnull, 'w')
        sys.stdout = self._devnull
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore first so stdout is usable even if closing devnull fails.
        sys.stdout = self._original_stdout
        self._devnull.close()
        # Implicitly returns None (falsy), so exceptions from the body propagate.

In [3]:
# Prefer GPU index 1 (as originally hard-coded), but fall back to the first GPU
# or the CPU so the notebook still runs on machines with fewer devices.
GPU_INDEX = 1
if torch.cuda.device_count() > GPU_INDEX:
    device = torch.device('cuda', GPU_INDEX)
elif torch.cuda.is_available():
    device = torch.device('cuda', 0)
else:
    device = torch.device('cpu')

# Directory of the saved training run to evaluate. Other runs tried previously:
#   /home/jakaria/torus_bump_500_three_scale_binary_bump_variable_noise_fixed_angle/models_classification_regression_only_correlation_loss/models/65
#   <RUNS_ROOT>/torus/models_contrastive_inhib/146
#   <RUNS_ROOT>/torus/models_guided/30
#   <RUNS_ROOT>/hippocampus/models_guided/44
#   <RUNS_ROOT>/hippocampus/models_attribute/99
RUNS_ROOT = "/home/jakaria/Explaining_Shape_Variability/src/DeepLearning/compute_canada/guided_vae/data/CoMA/raw"
model_path = os.path.join(RUNS_ROOT, "hippocampus", "models_contrastive_inhib", "172")


def load_run_artifact(name, map_location=device):
    """Load ``<model_path>/<name>.pt``.

    Args:
        name: artifact file stem inside ``model_path``.
        map_location: forwarded to ``torch.load``; defaults to ``device`` so that
            tensors saved on another GPU are placed where the model will run.
            Pass ``None`` to keep torch.load's default placement.

    NOTE: torch.load unpickles the file -- only load runs you trust.
    """
    return torch.load(os.path.join(model_path, f"{name}.pt"), map_location=map_location)


# Everything the AE constructor/state needs is mapped onto `device`. The spiral
# indices and transforms are passed to the constructor; whether model.to(device)
# moves them depends on AE internals (not visible here), so place them explicitly.
model_state_dict = load_run_artifact("model_state_dict")
in_channels = load_run_artifact("in_channels")
out_channels = load_run_artifact("out_channels")
latent_channels = load_run_artifact("latent_channels")
spiral_indices_list = load_run_artifact("spiral_indices_list")
up_transform_list = load_run_artifact("up_transform_list")
down_transform_list = load_run_artifact("down_transform_list")
# std/mean/faces keep torch.load's default placement, as before, because their
# downstream use is not shown in this notebook section.
std = load_run_artifact("std", map_location=None)
mean = load_run_artifact("mean", map_location=None)
template_face = load_run_artifact("faces", map_location=None)

# Create an instance of the model and restore the trained weights.
model = AE(in_channels, out_channels, latent_channels,
           spiral_indices_list, down_transform_list,
           up_transform_list)
model.load_state_dict(model_state_dict)
model.to(device)
# Evaluation mode (affects layers such as the BatchNorm1d in cls_sq).
model.eval()

AE(
  (en_layers): ModuleList(
    (0): SpiralEnblock(
      (conv): SpiralConv(3, 24, seq_length=9)
    )
    (1): SpiralEnblock(
      (conv): SpiralConv(24, 24, seq_length=9)
    )
    (2): SpiralEnblock(
      (conv): SpiralConv(24, 24, seq_length=9)
    )
    (3): SpiralEnblock(
      (conv): SpiralConv(24, 48, seq_length=9)
    )
    (4): Linear(in_features=8544, out_features=32, bias=True)
  )
  (de_layers): ModuleList(
    (0): Linear(in_features=16, out_features=8544, bias=True)
    (1): SpiralDeblock(
      (conv): SpiralConv(48, 48, seq_length=9)
    )
    (2): SpiralDeblock(
      (conv): SpiralConv(48, 24, seq_length=9)
    )
    (3): SpiralDeblock(
      (conv): SpiralConv(24, 24, seq_length=9)
    )
    (4): SpiralDeblock(
      (conv): SpiralConv(24, 24, seq_length=9)
    )
    (5): SpiralConv(24, 3, seq_length=9)
  )
  (cls_sq): Sequential(
    (0): Linear(in_features=1, out_features=8, bias=True)
    (1): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_runn

In [9]:
# CoMA data root; the template mesh lives under <data_fp>/template/.
data_fp = "/home/jakaria/Explaining_Shape_Variability/src/DeepLearning/compute_canada/guided_vae/data/CoMA"
template_fp = os.path.join(data_fp, "template", "template.ply")
split = "interpolation"
test_exp = "bareteeth"

meshdata = MeshData(
    data_fp,
    template_fp,
    split=split,
    test_exp=test_exp,
)

# Same batch size for both splits; loaders are built in (train, test) order.
train_loader, test_loader = (
    DataLoader(subset, batch_size=16)
    for subset in (meshdata.train_dataset, meshdata.test_dataset)
)

# Seed torch so the latent codes are reproducible across runs: they come from
# model.reparameterize, which presumably samples random noise -- TODO confirm.
# (If it does, using `mu` directly would give fully deterministic codes.)
SEED = 0
torch.manual_seed(SEED)

single_latent = False  # True -> keep only latent dimension 0 as a single feature


def extract_latents(loader, single_latent=False):
    """Encode every batch in ``loader`` and collect latent codes and labels.

    Args:
        loader: iterable of batches exposing ``.x`` (model input) and ``.y``
            (labels, indexed as ``y[:, :, k]``).
        single_latent: if True, keep only latent dimension 0, returned as an
            ``(N, 1)`` array.

    Returns:
        ``(latent_codes, angles, thick)`` as numpy arrays; ``angles`` is
        ``y[:, :, 1]`` and ``thick`` is ``y[:, :, 0]``, each flattened to 1-D.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    codes, angles, thick = [], [], []
    with torch.no_grad():
        for data in loader:
            _pred, mu, log_var, _re, _re2 = model(data.x.to(device))
            z = model.reparameterize(mu, log_var)
            if single_latent:
                z = z[:, 0]
            codes.append(z)
            # Labels are only consumed on the CPU, so they skip the GPU round-trip.
            y = data.y
            angles.append(y[:, :, 1])
            thick.append(y[:, :, 0])
    if not codes:
        raise ValueError("loader yielded no batches")
    latent_codes = torch.cat(codes).cpu().numpy()
    if single_latent:
        latent_codes = latent_codes.reshape(-1, 1)
    return (
        latent_codes,
        torch.cat(angles).reshape(-1).cpu().numpy(),
        torch.cat(thick).reshape(-1).cpu().numpy(),
    )


latent_codes_train, angles_train, thick_train = extract_latents(train_loader, single_latent)
latent_codes_test, angles_test, thick_test = extract_latents(test_loader, single_latent)

# Regress thickness from the latent codes with k-nearest neighbours.
X_train = latent_codes_train
y_train = thick_train
X_test = latent_codes_test
y_test = thick_test
# Exactly one label per latent code is required; fail loudly on misalignment
# instead of silently truncating predictions.
assert len(X_train) == len(y_train), (len(X_train), len(y_train))
assert len(X_test) == len(y_test), (len(X_test), len(y_test))

KNN_NEIGHBORS = 12
knn = KNeighborsRegressor(n_neighbors=KNN_NEIGHBORS)
knn.fit(X_train, y_train)
y_pred = knn.predict(X_test)
print("MSE of the KNN for thickness: ", mean_squared_error(y_test, y_pred))




Normalizing...
Done!
MSE of the KNN for thickness:  0.024993027


In [7]:
# Linear baseline on the same latent features. Depends on X_train/y_train/
# X_test/y_test from the latent-extraction cell; this cell's execution count (7)
# is lower than that cell's (9), so re-run them in order on a fresh kernel.
assert len(X_test) == len(y_test), (len(X_test), len(y_test))
LR = LinearRegression()
LR.fit(X_train, y_train)
y_pred = LR.predict(X_test)
print("MSE of the Linear Regression for thickness: ", mean_squared_error(y_test, y_pred))

MSE of the Linear Regression for thickness:  0.0017567916
