In [1]:
# File processing
import glob
import os

# Data processing
import random
import numpy as np
from tqdm import tqdm

# Data display 
import matplotlib
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from matplotlib.legend_handler import HandlerLine2D

# Machine learning
import torch
import torch.nn.functional as F
from torch import linalg as LAtorch
from numpy import linalg as LAnumpy
from torch_geometric.data import DataLoader
from sklearn.metrics import mean_squared_error
from torch_geometric.data import Data, InMemoryDataset

# Constants

In [2]:
# Training schedule
NB_EPOCHS = 200
BATCH_SIZE = 10
LEARNING_RATE = 0.001

# Problem geometry
NB_NODES = 202
EMBEDDING_SIZE = 3  # Euclidean 3D space

# Reproducibility
SEED = 0

# Loss weights (structure / distance terms relative to the Hi-C term)
LAMBDA_STRUCTURE = 10
LAMBDA_DISTANCE = 1

# Dataset sizes
TRAIN_DATASET_SIZE = 800
TEST_DATASET_SIZE = 200

# Variance of the Gaussian noise added to the inputs during training
NOISE_VARIANCE = 0.1

# Seeds

In [3]:
# Seed every random number generator this notebook draws from (PyTorch, stdlib, NumPy).
for seed_rng in (torch.manual_seed, random.seed, np.random.seed):
    seed_rng(SEED)

# Create dataset

In [4]:
def last_4digits(x):
    """Return the 4 characters preceding a 4-character extension, e.g. '0042' for 'hic_0042.txt'.

    Used as the sort key for the data files so that they are ordered by their
    trailing (zero-padded) index.
    """
    index_chars = slice(-8, -4)
    return x[index_chars]

### Hi-C matrices

In [None]:
train_hic_dir = '../../../../data/ae/synthetic_biological/train/hic_matrices/'
# NOTE(review): the structure/distance cells read from '../../../data/synthetic_tad/train/...',
# and VanillaDataset pairs Hi-C, structure and distance samples by index. Confirm that this
# directory holds the matching samples.
file_list = os.listdir(train_hic_dir)

# One float32 matrix per sample, ordered by the 4-digit index in the file name.
train_transfer_learning_hics = [
    np.loadtxt(train_hic_dir + file_name, dtype='f', delimiter=' ')
    for file_name in sorted(filter(lambda name: name.endswith('.txt'), file_list), key=last_4digits)
]

In [None]:
# List and load from the SAME directory. File names used to be listed from
# '.../synthetic_biological/test/hic_matrices/' while the arrays were read from
# '.../synthetic_tad/test/hic_matrices/', so which files got loaded depended on
# another folder's contents.
# NOTE(review): the structure/distance cells use '../../../data/synthetic_tad/...'
# (one fewer '../'). Confirm which root is correct for this checkout.
test_hic_dir = '../../../../data/synthetic_tad/test/hic_matrices/'
file_list = os.listdir(test_hic_dir)

# One float32 matrix per sample, ordered by the 4-digit index in the file name.
test_transfer_learning_hics = [
    np.loadtxt(test_hic_dir + file_name, dtype='f', delimiter=' ')
    for file_name in sorted(filter(lambda name: name.endswith('.txt'), file_list), key=last_4digits)
]

### Structure matrices

In [None]:
train_structure_dir = '../../../data/synthetic_tad/train/structure_matrices/'
file_list = os.listdir(train_structure_dir)

# One float32 coordinate matrix per sample, ordered by the 4-digit index in the file name.
train_transfer_learning_structures = [
    np.loadtxt(train_structure_dir + file_name, dtype='f', delimiter=' ')
    for file_name in sorted(filter(lambda name: name.endswith('.txt'), file_list), key=last_4digits)
]

In [None]:
test_structure_dir = '../../../data/synthetic_tad/test/structure_matrices/'
file_list = os.listdir(test_structure_dir)

# One float32 coordinate matrix per sample, ordered by the 4-digit index in the file name.
test_transfer_learning_structures = [
    np.loadtxt(test_structure_dir + file_name, dtype='f', delimiter=' ')
    for file_name in sorted(filter(lambda name: name.endswith('.txt'), file_list), key=last_4digits)
]

### Distance matrices

In [None]:
train_distance_dir = '../../../data/synthetic_tad/train/distance_matrices/'
file_list = os.listdir(train_distance_dir)

# One float32 distance matrix per sample, ordered by the 4-digit index in the file name.
train_transfer_learning_distances = [
    np.loadtxt(train_distance_dir + file_name, dtype='f', delimiter=' ')
    for file_name in sorted(filter(lambda name: name.endswith('.txt'), file_list), key=last_4digits)
]

In [None]:
test_distance_dir = '../../../data/synthetic_tad/test/distance_matrices/'
file_list = os.listdir(test_distance_dir)

# One float32 distance matrix per sample, ordered by the 4-digit index in the file name.
test_transfer_learning_distances = [
    np.loadtxt(test_distance_dir + file_name, dtype='f', delimiter=' ')
    for file_name in sorted(filter(lambda name: name.endswith('.txt'), file_list), key=last_4digits)
]

### Final dataset

#### Training

In [None]:
# Module-level flag read by VanillaDataset to pick the split (train vs. test) and its
# processed cache file name; must be True while the training dataset is built.
is_training = True

In [None]:
class VanillaDataset(InMemoryDataset):
    """In-memory PyG dataset of (Hi-C, structure, distance) matrix triples.

    ``process`` builds one ``Data`` object per sample from the module-level
    ``{train,test}_transfer_learning_{hics,structures,distances}`` lists, collates
    them, and ``torch.save``s the result to ``self.processed_paths[0]``. The
    constructor then loads that file. PyG skips ``process`` when the file already
    exists.

    Args:
        root: dataset root directory forwarded to ``InMemoryDataset``.
        transform, pre_transform: forwarded to ``InMemoryDataset``.
        is_training: split to build/load. ``None`` (default) falls back to the
            notebook-global ``is_training`` flag, which matches the original
            behaviour. Pass ``True``/``False`` to pin the split to this instance, so
            it no longer follows later changes to the global flag.
    """

    def __init__(self, root, transform=None, pre_transform=None, is_training=None):
        # Must be set before super().__init__(), which reads processed_file_names
        # and may call process().
        self._is_training = is_training
        super(VanillaDataset, self).__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    def _training_split(self):
        """Return True for the training split, False for the test split."""
        pinned = getattr(self, '_is_training', None)
        if pinned is None:
            # Backward-compatible fallback: the notebook-global flag.
            return is_training
        return pinned

    @property
    def raw_file_names(self):
        return []

    @property
    def processed_file_names(self):
        # NOTE: despite the .txt suffix these files hold torch.save pickles.
        if self._training_split():
            return ['processed_transfer_learning_train_synthetic_data.txt']
        return ['processed_transfer_learning_test_synthetic_data.txt']

    def download(self):
        pass

    def process(self):
        """Build, collate and save the Data objects for the selected split."""
        if self._training_split():
            dataset_size = TRAIN_DATASET_SIZE
            hics = train_transfer_learning_hics
            structures = train_transfer_learning_structures
            distances = train_transfer_learning_distances
        else:
            dataset_size = TEST_DATASET_SIZE
            hics = test_transfer_learning_hics
            structures = test_transfer_learning_structures
            distances = test_transfer_learning_distances

        data_list = []
        for i in tqdm(range(dataset_size)):
            data_list.append(Data(hic_matrix=torch.FloatTensor(hics[i]),
                                  structure_matrix=torch.FloatTensor(structures[i]),
                                  distance_matrix=torch.FloatTensor(distances[i])))

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

In [None]:
# Build (or load the cached) training split and apply a one-off random permutation.
train_dataset = VanillaDataset('../').shuffle()

In [None]:
# Number of graphs in the training dataset (displayed as the cell output).
train_size = len(train_dataset)
train_size

In [None]:
# Reshuffle the training graphs every epoch so that batch composition varies between
# epochs. Without `shuffle` the loader replays the same fixed batches each epoch.
# NOTE(review): Net.forward reshapes to exactly BATCH_SIZE graphs, so the dataset
# size must stay divisible by BATCH_SIZE (800 / 10 here).
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

#### Testing

In [None]:
# Switch VanillaDataset (which reads this global) to the test split and its cache file.
is_training = False

In [None]:
# Build (or load the cached) test split and apply a one-off random permutation.
test_dataset = VanillaDataset('../').shuffle()

In [None]:
# Number of graphs in the test dataset (displayed; also bounds the Procrustes loop below).
test_size = len(test_dataset)
test_size

In [None]:
# Test batches are served in dataset order (no `shuffle` argument), after the one-off shuffle above.
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

# Synthetic Network: Linear Network

In [None]:
class Net(torch.nn.Module):
    """Linear autoencoder architecture of the pretrained synthetic model.

    NOTE(review): presumably kept under the exact name ``Net`` so that ``torch.load``
    in the next cell can unpickle the saved full model. The loaded object then runs
    this class's ``forward``. The name is re-bound by the transfer-learning ``Net``
    further down. Do not change this class unless the saved model is regenerated.
    """
    def __init__(self):
        super(Net, self).__init__()

        # Encoder: each node's NB_NODES-long input row -> EMBEDDING_SIZE coordinates.
        self.linear_encoder_layer_1 = torch.nn.Linear(NB_NODES, 100)
        self.linear_encoder_layer_2 = torch.nn.Linear(100, 50)
        self.linear_encoder_layer_3 = torch.nn.Linear(50, EMBEDDING_SIZE)

        # Decoder: pairwise distance rows -> reconstructed Hi-C rows.
        self.linear_decoder_layer_1 = torch.nn.Linear(NB_NODES, NB_NODES)

        # Used in forward() to zero the diagonal of y. These are plain tensor attributes
        # (not registered buffers), so Module.to(device) does not move them.
        self.diagonal_mask = torch.eye(NB_NODES, dtype=torch.float).repeat(BATCH_SIZE, 1, 1)
        self.zero_output = torch.zeros(NB_NODES, NB_NODES, dtype=torch.float).repeat(BATCH_SIZE, 1, 1)

    def forward(self, x, is_training=False):
        """Return (y, z, w): reconstructed Hi-C, node coordinates, pairwise distances.

        x is reshaped to (BATCH_SIZE, NB_NODES, NB_NODES), so it must hold exactly
        BATCH_SIZE graphs. When is_training is True, Gaussian noise with variance
        NOISE_VARIANCE is added to x.
        """

        x = torch.reshape(x, (BATCH_SIZE, NB_NODES, NB_NODES))
        if is_training:
            x = x + (NOISE_VARIANCE**0.5)*torch.randn(BATCH_SIZE, NB_NODES, NB_NODES)

        z = self.linear_encoder_layer_1(x)
        z = F.relu(z)
        z = self.linear_encoder_layer_2(z)
        z = F.relu(z)
        z = self.linear_encoder_layer_3(z)
        z = F.relu(z)
        # Centre each structure and scale it into the unit ball
        # (centralize_and_normalize_torch is defined in a later cell).
        z = centralize_and_normalize_torch(z)

        # Pairwise Euclidean distances between predicted node positions: (B, N, N).
        w = torch.cdist(z, z, p=2)

        y = self.linear_decoder_layer_1(w)
        y = F.relu(y)

        # Set y to be symmetric
        y = (y + torch.transpose(y, 1, 2))/2

        # Clamp output between 0 and 1 (Note: need to use Min Max scaler for input)
        y = torch.clamp(y, min=0, max=1)

        # Set diagonal prediction HiC frequencies to 0
        y = torch.where(self.diagonal_mask > 0.0, self.zero_output, y)

        return y, z, w

In [None]:
device = torch.device('cpu')

# Load the pretrained synthetic model (a full pickled module; presumably needs the `Net`
# class above to be defined in this session) and move it to `device`.
# NOTE(review): torch>=2.6 defaults to torch.load(weights_only=True), which rejects
# full-module pickles; pass weights_only=False there (trusted file only).
synthetic_model = torch.load('../../../models/synthetic/linear/synthetic_linear_model.pt').to(device)

# Freeze all pretrained weights; only the new transfer-learning layers are trained.
synthetic_model.requires_grad_(False);

# Biological Network: Linear Neural Network

In [None]:
class Net(torch.nn.Module):
    """Transfer-learning model wrapped around the frozen global ``synthetic_model``.

    A trainable linear layer re-encodes the input rows. ``synthetic_model`` then
    produces (Hi-C, coordinates, distances), and one trainable linear head per
    output adapts each of them.

    This definition intentionally re-binds the name ``Net``; the earlier ``Net`` is
    presumably only needed so that ``synthetic_model`` can be unpickled.
    """
    def __init__(self):
        super(Net, self).__init__()

        self.linear_encoder_layer_x = torch.nn.Linear(NB_NODES, NB_NODES)

        self.linear_decoder_layer_y = torch.nn.Linear(NB_NODES, NB_NODES)
        self.linear_decoder_layer_z = torch.nn.Linear(EMBEDDING_SIZE, EMBEDDING_SIZE)
        self.linear_decoder_layer_w = torch.nn.Linear(NB_NODES, NB_NODES)

        # Non-persistent buffers: Module.to(device) moves them with the model (plain
        # tensor attributes would stay on the CPU), and they add no state_dict entries.
        self.register_buffer('diagonal_mask',
                             torch.eye(NB_NODES, dtype=torch.float).repeat(BATCH_SIZE, 1, 1),
                             persistent=False)
        self.register_buffer('zero_output',
                             torch.zeros(NB_NODES, NB_NODES, dtype=torch.float).repeat(BATCH_SIZE, 1, 1),
                             persistent=False)

    def forward(self, x, is_training=False):
        """Return adapted (y, z, w) from the frozen synthetic model.

        Args:
            x: Hi-C input, reshaped to (BATCH_SIZE, NB_NODES, NB_NODES); it must hold
                exactly BATCH_SIZE graphs.
            is_training: if True, add Gaussian noise with variance NOISE_VARIANCE to x.

        Returns:
            y: Hi-C reconstruction, symmetric, clamped to [0, 1], with a zero diagonal.
            z: adapted node coordinates.
            w: adapted distance matrix.
        """
        x = torch.reshape(x, (BATCH_SIZE, NB_NODES, NB_NODES))
        if is_training:
            # randn_like samples on x's device/dtype instead of always on the CPU.
            x = x + (NOISE_VARIANCE**0.5) * torch.randn_like(x)

        x = self.linear_encoder_layer_x(x)
        x = F.relu(x)

        y, z, w = synthetic_model(x)

        # Deal with y
        y = self.linear_decoder_layer_y(y)
        y = F.relu(y)

        # Set y to be symmetric
        y = (y + torch.transpose(y, 1, 2)) / 2

        # Clamp output between 0 and 1 (Note: need to use Min Max scaler for input)
        y = torch.clamp(y, min=0, max=1)

        # Set diagonal prediction HiC frequencies to 0
        y = torch.where(self.diagonal_mask > 0.0, self.zero_output, y)

        # Deal with z
        # NOTE(review): this ReLU confines coordinates to the non-negative octant, while
        # synthetic_model's z is centred; confirm this is intended.
        z = self.linear_decoder_layer_z(z)
        z = F.relu(z)

        # Deal with w
        w = self.linear_decoder_layer_w(w)
        w = F.relu(w)

        return y, z, w

# Procrustes analysis functions

### Torch

In [None]:
def centralize_torch(z):
    """Translate each structure in a batch so that its centroid is at the origin.

    Args:
        z: tensor of shape (batch, nodes, dim).

    Returns:
        Tensor of the same shape with the mean over the node axis (dim 1) subtracted.
        Broadcasting via keepdim works for any node count and dimensionality,
        instead of requiring NB_NODES / EMBEDDING_SIZE.
    """
    return z - torch.mean(z, dim=1, keepdim=True)

In [None]:
def normalize_torch(z):
    """Scale each structure in a batch so that its farthest point lies on the unit sphere.

    Args:
        z: tensor of shape (batch, nodes, dim).

    Returns:
        Tensor of the same shape; each batch item is divided by its largest per-node
        2-norm. Items whose largest norm is 0 are divided by 1 (left unchanged).
        Works for any batch size, not only BATCH_SIZE.
    """
    norms = LAtorch.norm(z, 2, dim=2)                     # (batch, nodes)
    max_norms, _ = torch.max(norms, dim=1, keepdim=True)  # (batch, 1)
    # Avoid division by zero for degenerate all-zero structures.
    max_norms = torch.where(max_norms == 0, torch.ones_like(max_norms), max_norms)
    return z / max_norms.unsqueeze(-1)

In [None]:
def centralize_and_normalize_torch(z):
    """Centre each batched structure at the origin, then scale it into the unit ball.

    Translation is applied first (centralize_torch), then scaling (normalize_torch).
    """
    return normalize_torch(centralize_torch(z))

In [None]:
def structure_loss_fct(pred_structure, true_structure):
    """Mean squared per-node distance after the optimal rigid rotation (Kabsch).

    For each batch item, finds the proper rotation R (a dim x dim matrix with
    det(R) = +1) that minimises ||pred @ R - true||_F. The rotation acts on the
    coordinate axis, so nodes are never re-mixed or permuted. No translation or
    scaling is applied.

    Args:
        pred_structure: tensor (batch, nodes, dim), predicted coordinates.
        true_structure: tensor (batch, nodes, dim), reference coordinates.

    Returns:
        0-d tensor: mean over batch and nodes of the squared Euclidean distance
        between the rotated prediction and the reference.
    """
    # Cross-covariance in coordinate space: (batch, dim, dim).
    m = torch.matmul(torch.transpose(pred_structure, 1, 2), true_structure)
    u, _, vh = torch.linalg.svd(m, full_matrices=False)

    # If U @ Vh is a reflection (det = -1), flip the direction belonging to the
    # smallest singular value so that R is a proper rotation (det(R) = +1).
    det_sign = torch.sign(torch.det(torch.matmul(u, vh))).detach()
    correction = torch.ones(u.shape[:-1], dtype=u.dtype, device=u.device)
    correction[..., -1] = det_sign
    r = torch.matmul(u * correction.unsqueeze(-2), vh)

    aligned = torch.matmul(pred_structure, r)
    return torch.mean(torch.sum(torch.square(aligned - true_structure), dim=2))

### Numpy

In [None]:
def centralize_numpy(z):
    """Subtract the mean over axis 0, i.e. move the centroid of row-points to the origin."""
    centroid = np.mean(z, axis=0)
    return z - centroid

In [None]:
def normalize_numpy(z):
    """Scale a point cloud so that its farthest row lies on the unit sphere.

    Rows of ``z`` are treated as points (2-norm along axis 1). If every row has
    zero norm, ``z`` is divided by 1 and so returned unchanged.
    """
    radius = np.max(LAnumpy.norm(z, 2, axis=1), axis=0)
    scale = 1 if radius == 0 else radius
    return z / scale

In [None]:
def centralize_and_normalize_numpy(z):
    """Move the point cloud's centroid to the origin, then scale it into the unit ball."""
    return normalize_numpy(centralize_numpy(z))

In [None]:
def procrustes_superimposition_numpy(pred_structure, true_structure):
    """Superimpose a predicted structure onto the true one (centre, scale, rotate).

    Both (nodes, dim) point clouds are centred and scaled into the unit ball. The
    prediction is then rotated by the proper rotation R (dim x dim, det(R) = +1)
    that minimises ||pred @ R - true||_F (Kabsch). The rotation acts on the
    coordinate axis, so nodes are never re-mixed or permuted.

    Returns:
        (aligned_pred, true): both normalised, with the prediction rotated onto the truth.
    """
    # Centralize and normalize to unit ball
    pred_structure_unit_ball = centralize_and_normalize_numpy(pred_structure)
    true_structure_unit_ball = centralize_and_normalize_numpy(true_structure)

    # Cross-covariance in coordinate space: (dim, dim).
    m = np.matmul(np.transpose(pred_structure_unit_ball), true_structure_unit_ball)
    u, _, vh = np.linalg.svd(m)

    # Flip the smallest-singular-value direction if U @ Vh is a reflection -> det(R) = +1.
    correction = np.ones(u.shape[-1], dtype=u.dtype)
    correction[-1] = np.sign(np.linalg.det(np.matmul(u, vh)))
    r = np.matmul(u * correction, vh)

    pred_structure_unit_ball = np.matmul(pred_structure_unit_ball, r)

    return pred_structure_unit_ball, true_structure_unit_ball

In [None]:
def procrustes_distance_numpy(pred_structure, true_structure):
    """Procrustes distance between two (nodes, dim) structures.

    Both structures are centred and scaled into the unit ball. The prediction is
    rotated by the optimal proper rotation in coordinate space (Kabsch, det(R) = +1).
    The function returns the mean over nodes of the squared Euclidean distance to
    the truth.
    """
    # Centralize and normalize to unit ball
    pred_structure_unit_ball = centralize_and_normalize_numpy(pred_structure)
    true_structure_unit_ball = centralize_and_normalize_numpy(true_structure)

    # Cross-covariance in coordinate space: (dim, dim).
    m = np.matmul(np.transpose(pred_structure_unit_ball), true_structure_unit_ball)
    u, _, vh = np.linalg.svd(m)

    # Flip the smallest-singular-value direction if U @ Vh is a reflection -> det(R) = +1.
    correction = np.ones(u.shape[-1], dtype=u.dtype)
    correction[-1] = np.sign(np.linalg.det(np.matmul(u, vh)))
    r = np.matmul(u * correction, vh)

    pred_structure_unit_ball = np.matmul(pred_structure_unit_ball, r)

    # Structure comparison
    d = np.mean(np.sum(np.square(pred_structure_unit_ball - true_structure_unit_ball), axis=1))

    return d

# Train and test

In [None]:
device = torch.device('cpu')
# Only the new Net's layers are optimised; synthetic_model is a frozen global, not a submodule.
model = Net().to(device)
# Use the configured LEARNING_RATE instead of a duplicated literal, so the config cell is authoritative.
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
hic_loss_fct = torch.nn.MSELoss()
distance_loss_fct = torch.nn.MSELoss()

In [None]:
def plot_grad_flow(named_parameters):
    '''Plot per-layer gradient magnitudes to spot vanishing / exploding gradients.

    Call right after ``loss.backward()``, e.g. ``plot_grad_flow(model.named_parameters())``.
    For every non-bias parameter that requires grad and has a gradient, it draws one
    bar for the max absolute gradient (cyan) and one for the mean absolute gradient
    (blue).

    Args:
        named_parameters: iterable of ``(name, parameter)`` pairs, such as
            ``torch.nn.Module.named_parameters()`` yields.
    '''
    ave_grads = []
    max_grads = []
    layers = []
    for n, p in named_parameters:
        # p.grad is None for parameters that took no part in the last backward pass;
        # skip them instead of raising AttributeError.
        if p.requires_grad and ("bias" not in n) and p.grad is not None:
            layers.append(n)
            # .item() gives plain floats, so plotting also works for GPU tensors.
            ave_grads.append(p.grad.abs().mean().item())
            max_grads.append(p.grad.abs().max().item())
    plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.1, lw=1, color="c")
    plt.bar(np.arange(len(max_grads)), ave_grads, alpha=0.1, lw=1, color="b")
    plt.hlines(0, 0, len(ave_grads)+1, lw=2, color="k")
    plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical")
    plt.xlim(left=0, right=len(ave_grads))
    plt.ylim(bottom=-0.001, top=0.02)  # zoom in on the lower gradient regions
    plt.xlabel("Layers")
    plt.ylabel("average gradient")
    plt.title("Gradient flow")
    plt.grid(True)
    plt.legend([matplotlib.lines.Line2D([0], [0], color="c", lw=4),
                matplotlib.lines.Line2D([0], [0], color="b", lw=4),
                matplotlib.lines.Line2D([0], [0], color="k", lw=4)],
               ['max-gradient', 'mean-gradient', 'zero-gradient'])

In [None]:
def train():
    """Run one training epoch over ``train_loader`` and return the mean loss per graph.

    Each batch goes through ``model`` with input noise enabled. The weighted sum
    ``hic + LAMBDA_STRUCTURE * structure + LAMBDA_DISTANCE * distance`` is minimised
    with ``optimizer``. To inspect gradients, call
    ``plot_grad_flow(model.named_parameters())`` right after ``backward()``.
    """
    model.train()

    epoch_loss_sum = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        hic_out, structure_out, distance_out = model(batch.hic_matrix, is_training=True)

        # Hi-C and distances are compared in the stacked layout: (graphs * nodes, nodes).
        hic_out = torch.reshape(hic_out, (BATCH_SIZE*NB_NODES, NB_NODES))
        hic_ref = batch.hic_matrix.to(device)

        # Structures are compared per graph: (graphs, nodes, EMBEDDING_SIZE).
        structure_ref = torch.reshape(batch.structure_matrix.to(device),
                                      (BATCH_SIZE, NB_NODES, EMBEDDING_SIZE))

        distance_out = torch.reshape(distance_out, (BATCH_SIZE*NB_NODES, NB_NODES))
        distance_ref = batch.distance_matrix.to(device)

        hic_term = hic_loss_fct(hic_out, hic_ref)
        structure_term = structure_loss_fct(structure_out, structure_ref)
        distance_term = distance_loss_fct(distance_out, distance_ref)
        loss = hic_term + LAMBDA_STRUCTURE * structure_term + LAMBDA_DISTANCE * distance_term

        loss.backward()
        # Weight by the number of graphs so the returned value is a per-graph mean.
        epoch_loss_sum += batch.num_graphs * loss.item()
        optimizer.step()

    return epoch_loss_sum / len(train_dataset)

In [None]:
def evaluate(loader):
    """Evaluate ``model`` (eval mode, no input noise, no grad) on every batch of ``loader``.

    Returns:
        (mean_hic_loss, mean_structure_loss, mean_distance_loss,
         pred_hics, true_hics, pred_structures, true_structures,
         pred_distances, true_distances)

        The means are unweighted averages of the per-batch losses. Hi-C and distance
        arrays are stacked row-wise in the (graphs * NB_NODES, NB_NODES) layout.
        True structures are reshaped per batch to (BATCH_SIZE, NB_NODES,
        EMBEDDING_SIZE) before stacking.
    """
    model.eval()

    output_keys = ('pred_hic', 'true_hic', 'pred_structure', 'true_structure',
                   'pred_distance', 'true_distance')
    collected = {key: [] for key in output_keys}
    batch_losses = {'hic': [], 'structure': [], 'distance': []}

    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)

            hic_out, structure_out, distance_out = model(batch.hic_matrix)

            hic_out = torch.reshape(hic_out.detach().cpu(), (BATCH_SIZE*NB_NODES, NB_NODES))
            structure_out = structure_out.detach().cpu()
            distance_out = torch.reshape(distance_out.detach().cpu(), (BATCH_SIZE*NB_NODES, NB_NODES))

            hic_ref = batch.hic_matrix.detach().cpu()
            structure_ref = torch.reshape(batch.structure_matrix.detach().cpu(),
                                          (BATCH_SIZE, NB_NODES, EMBEDDING_SIZE))
            distance_ref = batch.distance_matrix.detach().cpu()

            batch_losses['hic'].append(hic_loss_fct(hic_out, hic_ref).numpy())
            batch_losses['structure'].append(structure_loss_fct(structure_out, structure_ref).numpy())
            batch_losses['distance'].append(distance_loss_fct(distance_out, distance_ref).numpy())

            batch_outputs = (hic_out, hic_ref, structure_out, structure_ref, distance_out, distance_ref)
            for key, tensor in zip(output_keys, batch_outputs):
                collected[key].append(tensor.numpy())

    # Format results
    stacked = [np.vstack(collected[key]) for key in output_keys]
    mean_losses = [np.mean(np.asarray(batch_losses[name]).flatten())
                   for name in ('hic', 'structure', 'distance')]

    return (*mean_losses, *stacked)

In [None]:
# Per-epoch history of mean losses (used by the plotting and saving cells below).
train_hic_losses_all_epochs, train_structure_losses_all_epochs, train_distance_losses_all_epochs = [], [], []
test_hic_losses_all_epochs, test_structure_losses_all_epochs, test_distance_losses_all_epochs = [], [], []

# Mean training-mode loss returned by train() for each epoch.
losses = []

epoch_log_format = 'E: {:03d}, Tr H: {:.4f}, Tr S: {:.4f}, Tr D: {:.4f}, Te H: {:.4f}, Te S: {:.4f}, Te D: {:.4f}'

for epoch in range(1, NB_EPOCHS + 1):
    loss = train()
    losses.append(loss)

    # Re-evaluate both splits in eval mode; the latest predictions stay in module-level
    # variables for the inspection cells further down.
    (train_mean_hic_loss, train_mean_structure_loss, train_mean_distance_loss,
     train_pred_hics, train_true_hics,
     train_pred_structures, train_true_structures,
     train_pred_distances, train_true_distances) = evaluate(train_loader)
    train_hic_losses_all_epochs.append(train_mean_hic_loss)
    train_structure_losses_all_epochs.append(train_mean_structure_loss)
    train_distance_losses_all_epochs.append(train_mean_distance_loss)

    (test_mean_hic_loss, test_mean_structure_loss, test_mean_distance_loss,
     test_pred_hics, test_true_hics,
     test_pred_structures, test_true_structures,
     test_pred_distances, test_true_distances) = evaluate(test_loader)
    test_hic_losses_all_epochs.append(test_mean_hic_loss)
    test_structure_losses_all_epochs.append(test_mean_structure_loss)
    test_distance_losses_all_epochs.append(test_mean_distance_loss)

    print(epoch_log_format.format(
        epoch, train_mean_hic_loss, train_mean_structure_loss, train_mean_distance_loss,
        test_mean_hic_loss, test_mean_structure_loss, test_mean_distance_loss))

In [None]:
# Mean combined training loss per epoch, as returned by train().
fig, ax = plt.subplots()
ax.plot(losses, label='Losses')
ax.legend()

In [None]:
# Mean Hi-C reconstruction loss per epoch, train vs. test.
for series, label in ((train_hic_losses_all_epochs, 'Train Hic'),
                      (test_hic_losses_all_epochs, 'Test Hic')):
    plt.plot(series, label=label)

plt.legend()

In [None]:
# Mean structure (Procrustes-aligned) loss per epoch, train vs. test.
for series, label in ((train_structure_losses_all_epochs, 'Train Struct'),
                      (test_structure_losses_all_epochs, 'Test Struct')):
    plt.plot(series, label=label)

plt.legend()

In [None]:
# Mean distance-matrix loss per epoch, train vs. test.
for series, label in ((train_distance_losses_all_epochs, 'Train Dist'),
                      (test_distance_losses_all_epochs, 'Test Dist')):
    plt.plot(series, label=label)

plt.legend()

# Model evaluation

In [None]:
# Index of the test graph inspected below; in the stacked Hi-C/distance arrays it covers
# rows GRAPH_TESTED*NB_NODES up to (but excluding) GRAPH_TESTED*NB_NODES + NB_NODES.
GRAPH_TESTED = 0

### Test reconstruction

In [None]:
# Left: ground-truth Hi-C of the selected test graph; right: the model's reconstruction.
fig, axes = plt.subplots(1, 2, figsize=(15, 15))

graph_rows = slice(GRAPH_TESTED * NB_NODES, (GRAPH_TESTED + 1) * NB_NODES)
ground_truth_matrix = test_true_hics[graph_rows, :]
reconstruction_matrix = test_pred_hics[graph_rows, :]

for panel, matrix in zip(axes, (ground_truth_matrix, reconstruction_matrix)):
    panel.imshow(matrix, cmap='hot', interpolation='nearest')

plt.show()

### Test distance matrix

In [None]:
# Left: ground-truth distance matrix of the selected test graph; right: the model's prediction.
fig, axes = plt.subplots(1, 2, figsize=(15, 15))

graph_rows = slice(GRAPH_TESTED * NB_NODES, (GRAPH_TESTED + 1) * NB_NODES)
ground_truth_matrix = test_true_distances[graph_rows, :]
reconstruction_matrix = test_pred_distances[graph_rows, :]

for panel, matrix in zip(axes, (ground_truth_matrix, reconstruction_matrix)):
    panel.imshow(matrix, cmap='hot', interpolation='nearest')

plt.show()

### Test latent space

In [None]:
# 3D chains of the selected test graph after Procrustes superimposition:
# left = ground truth, right = prediction.
fig = plt.figure(figsize=(50, 50))

test_true_structure = test_true_structures[GRAPH_TESTED]
test_pred_structure = test_pred_structures[GRAPH_TESTED]

test_pred_structure_pro, test_true_structure_pro = \
        procrustes_superimposition_numpy(test_pred_structure, test_true_structure)

for subplot_position, coords in ((1, test_true_structure_pro), (2, test_pred_structure_pro)):
    ax = fig.add_subplot(1, 2, subplot_position, projection='3d')
    ax.plot3D(coords[:, 0], coords[:, 1], coords[:, 2], c='blue', alpha=0.5, lw=4.5)

# Shape comparison: mean squared per-node distance between the superimposed structures.
d = np.mean(np.sum(np.square(test_pred_structure_pro - test_true_structure_pro), axis=1))
print('Procrustes distance is ' + str(d))

plt.show()

In [None]:
# Procrustes distance between predicted and true structure for every test graph.
procruste_distances = [
    procrustes_distance_numpy(test_pred_structures[graph_index], test_true_structures[graph_index])
    for graph_index in range(test_size)
]

In [None]:
# Distribution (100 bins) and summary statistics of the test-set Procrustes distances.
n, bins, patches = plt.hist(procruste_distances, 100, facecolor='blue', alpha=0.5)
plt.show()

for stat_name, stat_fn in (('mean', np.mean), ('median', np.median), ('variance', np.var)):
    print(stat_name + ': ' + str(stat_fn(procruste_distances)))

# Save trained model

In [None]:
# Pickle the whole trained module (architecture + weights), not just its state_dict.
# NOTE(review): torch.save does not create missing directories; confirm the target exists.
torch.save(model, '../../../models/synthetic_tad_extended/linear/synthetic_tad_extended_linear_model.pt')

# Save results

In [None]:
# Write the structure-loss histories and the per-graph test Procrustes distances as CSV text.
results_to_save = {
    '../../../results/synthetic_tad_extended/linear/synthetic_tad_extended_linear_structure_train_losses.txt':
        train_structure_losses_all_epochs,
    '../../../results/synthetic_tad_extended/linear/synthetic_tad_extended_linear_structure_test_losses.txt':
        test_structure_losses_all_epochs,
    '../../../results/synthetic_tad_extended/linear/synthetic_tad_extended_linear_structure_procruste_distances.txt':
        procruste_distances,
}
for output_path, values in results_to_save.items():
    np.savetxt(output_path, values, delimiter=',')