In [1]:
# File processing
import os

# Data processing
import numpy as np
from sklearn.preprocessing import MinMaxScaler

# Machine Learning
import random
import torch
import torch.nn.functional as F
from torch import linalg as LAtorch
from numpy import linalg as LAnumpy

# Plotting
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Set device type: all torch work is pinned to the CPU
# NOTE(review): `device` is not referenced in the cells visible here — presumably consumed further down or by imported model code; confirm.
device = torch.device('cpu')

# Constants

In [2]:
# Models constants
# NOTE: every path below is relative, so it is resolved against the notebook's working directory.
SEED = 0 # Shared seed passed to the torch, random and numpy generators
BATCH_SIZE = 1 # Batch dimension used for the torch structure tensors
EMBEDDING_SIZE = 3 # Euclidean 3D space
MODELS_ROOT = '../../../t_reach_3d_1.0/t_reach_3d_1.0_dev/models/'

# Trussart test set constants
NB_BINS = 202 # Number of bins (points) per structure
TRUSSART_HIC_PATH = '../../data/trussart/hic_matrices/150_TADlike_alpha_150_set0.mat'
TRUSSART_STRUCTURES_PATH = '../../data/trussart/structure_matrices/'
NB_TRUSSART_STRUCTURES = 100 # presumably the expected number of .xyz files in TRUSSART_STRUCTURES_PATH — TODO confirm

# REACH3D predicted structure (tab-separated coordinates)
REACH3D_STRUCTURE_PATH = '../../../previous_works/reach3d/reach3d_trussart_output_structure_150.txt'

# GEM predicted structure (tab-separated coordinates)
GEM_STRUCTURE_PATH = '../../../previous_works/gem/data/trussart/trussart_structure_formatted/trussart_structure_formatted_150.txt'

# miniMDS predicted structure (space-separated coordinates)
MINIMDS_STRUCTURE_PATH = '../../../previous_works/minimds/data/trussart/trussart_structure_formatted/trussart_structure_formatted_150.txt'

# Seeds

In [3]:
# Seed every random number generator imported by this notebook with the same value
torch.manual_seed(SEED)  # PyTorch generator
random.seed(SEED)  # Python standard-library generator
np.random.seed(SEED)  # NumPy legacy global generator (used by random_prediction)

# Structure analysis functions

### Torch

In [4]:
def centralize_torch(z):
    """Translate every structure of a batch so that its centroid is at the origin.

    Args:
        z: Tensor of shape (batch, nb_bins, embedding_size).

    Returns:
        Tensor with the same shape as ``z`` in which, for each batch element,
        the mean over the bin axis (dim 1) is zero.
    """
    # keepdim=True leaves the bin axis with size 1, so the centroid broadcasts
    # against z for any batch size, bin count or embedding size instead of
    # relying on the NB_BINS / EMBEDDING_SIZE globals.
    return z - torch.mean(z, dim=1, keepdim=True)

In [5]:
def normalize_torch(z):
    """Scale every structure of a batch so that its farthest point has norm 1.

    Args:
        z: Tensor of shape (batch, nb_bins, embedding_size).

    Returns:
        Tensor with the same shape as ``z``; each batch element is divided by
        the largest Euclidean norm among its points. A structure whose points
        are all at the origin is returned unchanged (divisor forced to 1).
    """
    # Euclidean norm of each point, kept as (batch, nb_bins, 1)
    point_norms = LAtorch.norm(z, 2, dim=2, keepdim=True)

    # Largest norm per structure, kept as (batch, 1, 1) so it broadcasts
    # against z for any batch size rather than the BATCH_SIZE global
    max_norms, _ = torch.max(point_norms, dim=1, keepdim=True)

    # Avoid dividing by zero for degenerate (all-zero) structures
    max_norms[max_norms == 0] = 1

    return z / max_norms

In [6]:
def centralize_and_normalize_torch(z):
    """Center a batch of structures, then rescale them (torch version).

    Applies ``centralize_torch`` (translation) followed by
    ``normalize_torch`` (scaling) and returns the result.
    """
    return normalize_torch(centralize_torch(z))

### Numpy

In [7]:
def centralize_numpy(z):
    """Shift a point cloud so that its centroid lies at the origin.

    The centroid is the mean taken along axis 0 (one row per point).
    """
    centroid = np.mean(z, axis=0)
    return z - centroid

In [8]:
def normalize_numpy(z):
    """Divide a point cloud by the largest Euclidean norm among its rows.

    If every row has norm zero, the divisor falls back to 1 so the input is
    returned unscaled instead of producing NaNs.
    """
    row_norms = LAnumpy.norm(z, 2, axis=1)
    largest = np.max(row_norms, axis=0)

    # Guard against division by zero for an all-zero point cloud
    scale = 1 if largest == 0 else largest

    return z / scale

In [9]:
def centralize_and_normalize_numpy(z):
    """Center a point cloud, then rescale it (numpy version).

    Applies ``centralize_numpy`` (translation) followed by
    ``normalize_numpy`` (scaling) and returns the result.
    """
    return normalize_numpy(centralize_numpy(z))

In [10]:
def kabsch_superimposition_numpy(pred_structure, true_structure):
    """Rotate a predicted structure onto a reference structure (Kabsch algorithm).

    Both structures are first centered and normalized with
    ``centralize_and_normalize_numpy``. The rotation applied to the prediction
    is the solution of the orthogonal Procrustes problem constrained to
    det(R) = 1, i.e. a proper rotation without reflection.

    Returns:
        Tuple ``(rotated_prediction, reference)``, both centered and normalized.
    """
    # Centralize and normalize both point clouds
    pred_unit = centralize_and_normalize_numpy(pred_structure)
    true_unit = centralize_and_normalize_numpy(true_structure)

    # Cross-covariance between reference and prediction, and its SVD
    covariance = true_unit.T @ pred_unit
    left, _, right_h = np.linalg.svd(covariance)

    # Flip the last axis when left @ right_h would be a reflection (det = -1)
    correction = np.eye(EMBEDDING_SIZE)
    correction[-1, -1] = np.sign(np.linalg.det(left @ right_h))

    rotation = left @ correction @ right_h

    # Apply the rotation to each predicted point (points are rows)
    pred_rotated = (rotation @ pred_unit.T).T

    return pred_rotated, true_unit

In [11]:
def kabsch_distance_numpy(pred_structure, true_structure):
    """Mean squared point-to-point distance after Kabsch superimposition.

    The two structures are aligned with ``kabsch_superimposition_numpy``; the
    squared Euclidean distance between matching rows is then averaged.
    """
    aligned_pred, aligned_true = kabsch_superimposition_numpy(pred_structure, true_structure)

    # Squared distance per point, then average over all points
    residuals = aligned_pred - aligned_true
    squared_distances = np.sum(np.square(residuals), axis=1)

    return np.mean(squared_distances)

# Create dataset

## Hi-C matrix

In [12]:
# Read the tab-separated Trussart Hi-C matrix as single-precision floats
trussart_hic = np.loadtxt(TRUSSART_HIC_PATH, delimiter='\t', dtype='f')

# Rescale with a default MinMaxScaler.
# NOTE(review): MinMaxScaler rescales each column independently, so a symmetric
# contact matrix will generally not stay symmetric — confirm this is intended.
scaler = MinMaxScaler()
trussart_hic = scaler.fit(trussart_hic).transform(trussart_hic)

## Structure matrices

In [13]:
trussart_structures = []

# os.listdir returns entries in arbitrary order; sort the .xyz files so the
# list of structures is built in the same order on every run and machine.
file_list = sorted(f for f in os.listdir(TRUSSART_STRUCTURES_PATH) if f.endswith('.xyz'))

for file_name in file_list:
    current_trussart_structure = np.loadtxt(
        os.path.join(TRUSSART_STRUCTURES_PATH, file_name), dtype='f', delimiter='\t')
    # Drop the first column (presumably a bin index — TODO confirm) and keep the coordinates
    current_trussart_structure = current_trussart_structure[:, 1:]
    # Center and rescale each structure before storing it
    current_trussart_structure = centralize_and_normalize_numpy(current_trussart_structure)
    trussart_structures.append(current_trussart_structure)

# Import previous works structures

## REACH-3D

In [14]:
# Load the REACH-3D structure (tab-separated), then center and rescale it
reach3d_structure = centralize_and_normalize_numpy(
    np.loadtxt(REACH3D_STRUCTURE_PATH, dtype='f', delimiter='\t')
)

OSError: ../../../previous_works/reach3d/reach3d_trussart_output_structure_150.txt not found.

## MiniMDS

In [None]:
# Load the miniMDS structure (space-separated), then center and rescale it
minimds_structure = centralize_and_normalize_numpy(
    np.loadtxt(MINIMDS_STRUCTURE_PATH, dtype='f', delimiter=' ')
)

## GEM

In [None]:
# Load the GEM structure (tab-separated), then center and rescale it
gem_structure = centralize_and_normalize_numpy(
    np.loadtxt(GEM_STRUCTURE_PATH, dtype='f', delimiter='\t')
)

# Threshold structures

## Random prediction threshold

In [None]:
def random_prediction(nb_bins):
    """Draw a random 3D structure from independent spherical coordinates.

    For each bin, a radius in [0, 1), a polar angle theta in [0, pi) and an
    azimuth phi in [0, 2*pi) are sampled uniformly from NumPy's global
    generator and converted to Cartesian coordinates.

    Args:
        nb_bins: Number of points to generate.

    Returns:
        Array of shape (nb_bins, 3) holding the (x, y, z) coordinates; every
        point lies inside the unit ball.
    """
    # Draw order (r, theta, phi) is kept so results for a given seed match
    # the sequence of values consumed from the global generator.
    r_s = np.random.uniform(low=0, high=1, size=nb_bins)
    theta_s = np.random.uniform(low=0, high=np.pi, size=nb_bins)
    phi_s = np.random.uniform(low=0, high=2*np.pi, size=nb_bins)

    # Spherical -> Cartesian conversion, vectorized over all bins
    sin_theta = np.sin(theta_s)
    x = r_s * np.cos(phi_s) * sin_theta
    y = r_s * np.sin(phi_s) * sin_theta
    z = r_s * np.cos(theta_s)

    # Always three columns: the conversion above is inherently 3D
    return np.column_stack((x, y, z))

In [None]:
# Random baseline: one random structure with NB_BINS points, centered and rescaled
random_structure = centralize_and_normalize_numpy(random_prediction(NB_BINS))

## Perfect prediction threshold

In [None]:
# Reference structure: element-wise mean of all loaded Trussart structures,
# centered and rescaled again after averaging
perfect_structure = centralize_and_normalize_numpy(
    np.mean(trussart_structures, axis=0)
)

# Final plot

In [None]:
def compute_between_bin_distances(pred_structure):
    """Euclidean distance between each pair of consecutive bins.

    Args:
        pred_structure: Array-like of shape (nb_bins, dims), one row per bin.

    Returns:
        1-D array of length ``nb_bins - 1`` whose i-th entry is the distance
        between rows ``i`` and ``i + 1``. Empty when there is a single bin.
    """
    # Use the structure's own length rather than the NB_BINS global so the
    # function works for structures of any size.
    steps = np.diff(np.asarray(pred_structure), axis=0)

    return np.sqrt(np.sum(steps ** 2, axis=1))

In [None]:
def compute_sizes(pred_structure):
    """Derive one marker size per bin from the consecutive-bin distances.

    The distances returned by ``compute_between_bin_distances`` are padded by
    repeating the last value (so there is one entry per bin), then mapped
    through ``5 * log(1 / distance)``: closer neighbours give larger sizes.
    """
    gaps = compute_between_bin_distances(pred_structure)

    # The last bin has no successor; reuse the previous gap for it
    padded_gaps = np.append(gaps, gaps[-1])

    return np.log(1 / padded_gaps) * 5

In [None]:
# Structures shown in the final plots. Each prediction is rotated onto the
# perfect structure with the Kabsch superimposition; the second return value
# (the normalized reference) is discarded.
# Earlier alternative, kept for reference:
# structure_1, structure_2 = procrustes_superimposition_numpy(reach3d_structure, perfect_structure)
structure_1 = perfect_structure  # reference (mean Trussart structure)
structure_2, _ = kabsch_superimposition_numpy(minimds_structure, perfect_structure)  # miniMDS, aligned
structure_3, _ = kabsch_superimposition_numpy(reach3d_structure, perfect_structure)  # REACH-3D, aligned

In [None]:
# Structure drawn in the single-panel figure below (structure_3 is the aligned REACH-3D prediction)
display_structure = structure_3

In [None]:
# Initialize a figure with a single 3D scatter subplot
fig = make_subplots(
    rows=1, cols=1,
    specs=[[{'type': 'scatter3d'}]])

# Draw the selected structure as a 3D chain of markers joined by a line
fig.add_trace(
    go.Scatter3d(
    x=display_structure[:,0], y=display_structure[:,1], z=display_structure[:,2], opacity=0.7,
    marker=dict(
        size=6, #compute_sizes(display_structure),
        # Color each marker by its bin index. The index range is taken from the
        # structure actually being plotted (it previously came from structure_2,
        # which only worked because both structures have the same length).
        color=np.arange(len(display_structure)),
        colorscale='Viridis',
        line=dict(width=3)
    ),
    line=dict(
        color='darkblue',
        width=2
    )
), row=1, col=1)

fig.update_layout(
    height=1000,
    width=1000
)

fig.show()

In [None]:
# Initialize a figure with three side-by-side 3D scatter subplots
fig = make_subplots(
    rows=1, cols=3,
    specs=[[{'type': 'scatter3d'}, {'type': 'scatter3d'}, {'type': 'scatter3d'}]])

# One panel per structure, in order: structure_1, structure_2, structure_3.
# Every panel uses the same styling; markers are colored by bin index.
for panel_col, panel_structure in enumerate((structure_1, structure_2, structure_3), start=1):
    fig.add_trace(
        go.Scatter3d(
            x=panel_structure[:,0], y=panel_structure[:,1], z=panel_structure[:,2], opacity=0.4,
            marker=dict(
                size=5,
                color=np.asarray(range(len(panel_structure[:,0]))),
                colorscale='Viridis',
                line=dict(width=3)
            ),
            line=dict(
                color='darkblue',
                width=2
            )
        ), row=1, col=panel_col)

fig.update_layout(
    height=1000,
    width=1000
)

fig.show()