In [93]:
# File processing
import os

# Data processing
import numpy as np
from sklearn.preprocessing import MinMaxScaler

# Machine Learning
import random
import torch
import torch.nn.functional as F
from torch import linalg as LAtorch
from numpy import linalg as LAnumpy
from scipy.spatial import distance_matrix

# Data display 
import matplotlib
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from matplotlib.legend_handler import HandlerLine2D

# Compute device for torch tensors (CPU only).
device = torch.device(type='cpu')

In [94]:
# Models constants
SEED = 0
EMBEDDING_SIZE = 3 # Euclidean 3D space

# Trussart test set constants
NB_BINS = 202
TRUSSART_HIC_PATH = '../../data/trussart/hic_matrices/150_TADlike_alpha_150_set0.mat'
TRUSSART_STRUCTURES_PATH = '../../data/trussart/structure_matrices/'
NB_TRUSSART_STRUCTURES = 100

# TREACH3D 
TREACH3D_LINEAR_STRUCTURE_PATH = 'experiment_results/linear/non_ae_synthetic_random_linear_trussart_test_structure_150_0.1.txt'

# REACH3D
REACH3D_STRUCTURE_PATH = '../previous_work/reach3d/reach3d_trussart_output_structure_150.txt'

# GEM 
GEM_STRUCTURE_PATH = '../previous_work/gem/data/trussart/trussart_structure_formatted/trussart_structure_formatted_150.txt'

# miniMDS
MINIMDS_STRUCTURE_PATH = '../previous_work/minimds/data/trussart/trussart_structure_formatted/trussart_structure_formatted_150.txt'

In [95]:
# Seed the Python, torch and NumPy global RNGs so stochastic cells are reproducible.
# (torch.manual_seed returns a Generator; it is kept off the last line so the
# cell does not display it.)
random.seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)

In [96]:
# Load the tab-separated Hi-C matrix and min-max scale each column into [0, 1].
# NOTE(review): MinMaxScaler rescales every column independently, so if the input
# contact matrix is symmetric it will generally not be symmetric afterwards —
# confirm this per-column scaling is intended.
scaler = MinMaxScaler()
trussart_hic = scaler.fit_transform(
    np.loadtxt(TRUSSART_HIC_PATH, delimiter='\t', dtype=np.float32)
)

In [97]:
trussart_structures = []

# os.listdir returns entries in arbitrary, platform-dependent order; sort so the
# reference structures (and every per-structure distance list built from them)
# come out in the same order on every run.
file_list = sorted(f for f in os.listdir(TRUSSART_STRUCTURES_PATH) if f.endswith('.xyz'))

for file_name in file_list:
    current_trussart_structure = np.loadtxt(
        os.path.join(TRUSSART_STRUCTURES_PATH, file_name), dtype='f', delimiter='\t'
    )
    # Drop the first column and keep the remaining coordinate columns.
    # NOTE(review): the first column is presumably a bin index — confirm against the .xyz format.
    current_trussart_structure = current_trussart_structure[:, 1:]
    trussart_structures.append(current_trussart_structure)

In [98]:
def centralize_torch(z):
    """Translate every structure in a batch so that its centroid is at the origin.

    Args:
        z: tensor of shape (batch, nb_bins, EMBEDDING_SIZE); dim 1 indexes bins.

    Returns:
        Tensor of the same shape with the per-structure mean over bins subtracted.
        Broadcasting via ``keepdim`` makes this work for any number of bins, not
        only ``NB_BINS``.
    """
    return z - torch.mean(z, dim=1, keepdim=True)

In [99]:
def normalize_torch(z):
    """Scale every structure in a batch so its farthest point has Euclidean norm 1.

    Args:
        z: tensor of shape (batch, nb_bins, EMBEDDING_SIZE).

    Returns:
        Tensor of the same shape. Structures whose points are all at the origin
        are divided by 1 (left unchanged) instead of by 0.
    """
    # Euclidean norm of every point: (batch, nb_bins).
    norms = LAtorch.norm(z, 2, dim=2)
    # Largest norm per structure, shaped (batch, 1, 1) so it broadcasts over bins
    # and coordinates; derived from z itself, so any batch size works (the old
    # version reshaped to an undefined global BATCH_SIZE).
    max_norms = torch.max(norms, dim=1, keepdim=True).values.unsqueeze(-1)
    # Guard against division by zero without in-place mutation (autograd-friendly).
    max_norms = torch.where(max_norms == 0, torch.ones_like(max_norms), max_norms)

    return z / max_norms

In [100]:
def centralize_and_normalize_torch(z):
    """Move each structure's centroid to the origin, then scale it into the unit ball.

    Translation (centralize_torch) is applied first, followed by scaling
    (normalize_torch); z is presumably a (batch, nb_bins, 3) tensor.
    """
    return normalize_torch(centralize_torch(z))

In [101]:
def centralize_numpy(z):
    """Return z translated so the mean of its rows (its centroid) is at the origin.

    Each row is treated as one point; the mean is taken over axis 0.
    """
    centroid = np.mean(z, axis=0)
    return z - centroid

In [102]:
def normalize_numpy(z):
    """Scale a point cloud so its farthest row (point) from the origin has norm 1.

    If every point is at the origin the cloud is divided by 1 rather than 0.
    """
    radius = np.max(LAnumpy.norm(z, 2, axis=1), axis=0)
    return z / (1 if radius == 0 else radius)

In [103]:
def centralize_and_normalize_numpy(z):
    """Translate a point cloud to a zero centroid, then scale it into the unit ball.

    Translation (centralize_numpy) happens before scaling (normalize_numpy), so
    the scale factor is the largest distance from the centroid.
    """
    return normalize_numpy(centralize_numpy(z))

In [104]:
def kabsch_superimposition_numpy(pred_structure, true_structure):
    """Rotate a predicted structure onto a reference structure (Kabsch algorithm).

    Both inputs are first centred and scaled into the unit ball with
    centralize_and_normalize_numpy. The predicted structure is then rotated by the
    proper rotation (det = +1) solving the constrained orthogonal Procrustes problem.

    Args:
        pred_structure: (n, EMBEDDING_SIZE) array of predicted coordinates.
        true_structure: (n, EMBEDDING_SIZE) array of reference coordinates.

    Returns:
        (rotated_pred, ref): both centred and unit-ball normalized; only the
        prediction is rotated.
    """
    pred = centralize_and_normalize_numpy(pred_structure)
    ref = centralize_and_normalize_numpy(true_structure)

    # SVD of the reference/prediction cross-covariance matrix.
    u, _, vh = np.linalg.svd(ref.T @ pred)

    # Flip the last axis when needed so the result is a rotation, not a reflection.
    correction = np.eye(EMBEDDING_SIZE)
    correction[-1, -1] = np.sign(np.linalg.det(u @ vh))
    rotation = u @ correction @ vh

    rotated_pred = (rotation @ pred.T).T
    return rotated_pred, ref

In [105]:
def kabsch_distance_numpy(pred_structure, true_structure):
    """Mean squared per-point distance between two structures after Kabsch superimposition.

    Both structures are centred, scaled into the unit ball and optimally rotated
    (see kabsch_superimposition_numpy) before comparison.
    """
    aligned_pred, aligned_true = kabsch_superimposition_numpy(pred_structure, true_structure)
    squared_dists = ((aligned_pred - aligned_true) ** 2).sum(axis=1)
    return squared_dists.mean()

In [106]:
#reach3d_structure = np.loadtxt(REACH3D_STRUCTURE_PATH, dtype='f', delimiter='\t')

In [107]:
# NOTE(review): a previous run of this cell raised FileNotFoundError for a path
# without the `data/` component; re-run to confirm MINIMDS_STRUCTURE_PATH resolves.
minimds_structure = np.loadtxt(fname=MINIMDS_STRUCTURE_PATH, dtype=np.float32, delimiter=' ')

FileNotFoundError: ../previous_work/minimds/trussart/trussart_structure_formatted/trussart_structure_formatted_150.txt not found.

In [88]:
# GEM's predicted structure (tab-separated, one row per bin).
gem_structure = np.loadtxt(fname=GEM_STRUCTURE_PATH, dtype=np.float32, delimiter='\t')

In [None]:
#treach3d_linear_structure = np.loadtxt(TREACH3D_LINEAR_STRUCTURE_PATH, dtype='f', delimiter=' ')

In [89]:
def compute_kabsch_distance_distribution(pred_structure, true_structures=None):
    """Kabsch distance from one predicted structure to each reference structure.

    Args:
        pred_structure: predicted coordinates, one row per bin.
        true_structures: iterable of reference structures. Defaults to the
            notebook-global ``trussart_structures`` (resolved at call time), which
            preserves the original behaviour while allowing other reference sets.

    Returns:
        List of kabsch_distance_numpy values, one per reference, in iteration order.
    """
    if true_structures is None:
        true_structures = trussart_structures
    return [kabsch_distance_numpy(pred_structure, true_structure)
            for true_structure in true_structures]

In [90]:
#reach3d_distribution = compute_kabsch_distance_distribution(reach3d_structure)
gem_distribution = compute_kabsch_distance_distribution(pred_structure=gem_structure)
#minimds_distribution = compute_kabsch_distance_distribution(minimds_structure)

0.4283119056535361

In [None]:
#treach3d_linear_distribution = compute_kabsch_distance_distribution(treach3d_linear_structure)
# treach3d_bi_lstm_distribution = compute_kabsch_distance_distribution(treach3d_bi_lstm_structure)

In [None]:
def random_prediction(nb_bins):
    """Random baseline structure: nb_bins points drawn in spherical coordinates.

    r ~ U[0, 1), theta ~ U[0, pi), phi ~ U[0, 2*pi) are drawn, in that order, from
    the global NumPy RNG (so the result depends on np.random.seed).

    NOTE: sampling r and theta uniformly does not spread points uniformly through
    the ball's volume; they concentrate near the centre and near the poles. This
    is kept as-is so the baseline's distribution is unchanged.

    Args:
        nb_bins: number of points to generate.

    Returns:
        (nb_bins, 3) float array of Cartesian coordinates inside the unit ball.
    """
    r_s = np.random.uniform(low=0, high=1, size=nb_bins)
    theta_s = np.random.uniform(low=0, high=np.pi, size=nb_bins)
    phi_s = np.random.uniform(low=0, high=2*np.pi, size=nb_bins)

    # Vectorized spherical -> Cartesian conversion (replaces a per-point Python loop
    # and a reshape through the unrelated EMBEDDING_SIZE global).
    x = r_s * np.cos(phi_s) * np.sin(theta_s)
    y = r_s * np.sin(phi_s) * np.sin(theta_s)
    z = r_s * np.cos(theta_s)
    return np.column_stack((x, y, z))

In [None]:
# Draws from the global NumPy RNG, so each re-run of this cell yields a new baseline.
random_structure = random_prediction(nb_bins=NB_BINS)
random_distribution = compute_kabsch_distance_distribution(pred_structure=random_structure)

In [21]:
# Bin-by-bin mean of all reference structures, used as a consensus reference.
# NOTE(review): the structures are averaged without first being superimposed; if
# they are not already in a common frame, this mean is not a meaningful consensus.
perfect_structure = np.asarray(trussart_structures).mean(axis=0)
perfect_distribution = compute_kabsch_distance_distribution(pred_structure=perfect_structure)

In [22]:
# Centre the consensus structure and scale it into the unit ball.
# NOTE(review): perfect_structure is overwritten in place, so its value depends on
# whether this cell has already run. centralize + normalize is idempotent up to
# floating-point rounding, so re-running this cell is harmless in practice.
perfect_structure = centralize_and_normalize_numpy(perfect_structure)

#treach3d_linear_structure, treach3d_linear_structure_perfect = \
#    kabsch_superimposition_numpy(treach3d_linear_structure, perfect_structure)

#reach3d_structure, reach3d_structure_perfect = \
#    kabsch_superimposition_numpy(reach3d_structure, perfect_structure)

#minimds_structure, minimds_structure_perfect = \
#    kabsch_superimposition_numpy(minimds_structure, perfect_structure)

# Rotate the method's structure onto the consensus; both come back centred and in
# the unit ball, and *_perfect is the consensus after that same centring/scaling.
# NOTE(review): gem_structure is overwritten with its aligned copy, so any cell that
# reads gem_structure after this one (e.g. re-running the GEM distribution cell)
# sees the aligned version, not the loaded file contents.
gem_structure, gem_structure_perfect = \
    kabsch_superimposition_numpy(gem_structure, perfect_structure)

#random_structure, random_structure_perfect = \
#    kabsch_superimposition_numpy(random_structure, perfect_structure)

In [24]:
def consecutive_bin_distances(structure):
    """Return the Euclidean distance between each pair of consecutive rows (bins).

    Gives the same values, up to floating-point rounding, as
    ``np.diagonal(distance_matrix(structure, structure), offset=1)``, but in O(n)
    instead of building the full n x n distance matrix.
    """
    return LAnumpy.norm(np.diff(structure, axis=0), axis=1)


d_perfect_structure = consecutive_bin_distances(perfect_structure)
#d_treach3d_linear_structure = consecutive_bin_distances(treach3d_linear_structure)
#d_reach3d_linear_structure = consecutive_bin_distances(reach3d_structure)
#d_minimds_structure = consecutive_bin_distances(minimds_structure)
d_gem_structure = consecutive_bin_distances(gem_structure)
# NOTE(review): random_structure is not passed through kabsch_superimposition_numpy
# (that call is commented out), so unlike perfect_structure and gem_structure it has
# not been rescaled to unit max-norm; compare its spread with that in mind.
d_random_structure = consecutive_bin_distances(random_structure)

0.431049598430217

In [None]:
# Variance of consecutive-bin distances, one per line: perfect, GEM, random.
print(*map(np.var, [
    d_perfect_structure,
    #d_minimds_structure,
    #d_treach3d_linear_structure,
    d_gem_structure,
    #d_reach3d_linear_structure,
    d_random_structure,
]), sep='\n')

In [None]:
#reach3d_mean = np.mean(reach3d_distribution)
gem_mean = np.asarray(gem_distribution).mean()
#minimds_mean = np.mean(minimds_distribution)