In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules
# (e.g. the local Chromatin3D package) are reloaded before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
# Optional one-time environment setup (kept disabled; uncomment to run):
# !pip3 install torch==1.8.1 torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu111

In [3]:
import torch
import random
import numpy as np
import os 
from torch_geometric.data import DataLoader

from Chromatin3D.Data_Tools.Data_Access import get_data_from_path, VanillaDataset, set_logits_data
from Chromatin3D.Data_Tools.Data_Plotting import plot_structure_in_sphere, plot_hic, plot_optimal_transport, plot_losses, plot_test_distance_matrix, plot_true_pred_structures, plot_hist_kabsch_distances, plot_grad_flow, plot_pred_conf
from Chromatin3D.Model.model import UniformLinear, train_uniform_linear, evaluate_uniform_linear, ConfLinear, train_conf_linear, evaluate_conf_linear, TransConf, train_trans_conf, evaluate_trans_conf
from Chromatin3D.Model.losses import compute_trussart_test_kabsch_loss, biological_loss_fct, kabsch_loss_fct
from Chromatin3D.Data_Tools.Data_Calculation import save_structure, import_trussart_data, kabsch_superimposition_numpy, kabsch_distance_numpy, make_gif, scale_logits, mse_unscaled_scaled, import_fission_yeast, FISH_values_Tanizawa, dist_Tanizawa_FISH, save_structure_fission_yeast
from Chromatin3D.Model.lddt_tools import lddt, get_confidence_metrics
from Chromatin3D.Model.calibration_nn import ModelWithTemperature, isotonic_calibration, beta_calibration
from scipy.spatial import distance_matrix
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from scipy.stats import pearsonr

# --- Model hyper-parameters (passed positionally to TransConf when the model is built) ---
NB_HEAD = 2
NB_HIDDEN = 250
NB_LAYERS = 1
DROPOUT = 0.1
SECD_HID = 100
ZERO_INIT = False
EXPONENT = 1  # NOTE(review): not referenced in the visible cells - confirm it is still needed
NUM_BINS_LOGITS = 100

# --- Training schedule ---
# NOTE(review): the original carried a stray "##47" next to this constant,
# presumably an alternative epoch count that was tried - confirm.
NB_EPOCHS = 44
SEED = 2
BATCH_SIZE = 10

# --- Data / output geometry ---
NB_BINS = 1258       # the fission yeast Hi-C is reshaped to (1, NB_BINS, NB_BINS)
EMBEDDING_SIZE = 3
ANGLE_PRED = 3

# --- Loss weights handed to train_trans_conf ---
LAMBDA_BIO = 0
LAMBDA_KABSCH = 0.1
LAMBDA_LDDT = 0.1

# --- Device and reproducibility ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

# --- Data locations ---
# The data root can be overridden with the CHROMATIN3D_DATA_DIR environment
# variable so the notebook runs outside the original cluster; the default is
# the original hard-coded location.
DATA_DIR = os.environ.get('CHROMATIN3D_DATA_DIR', '/dccstor/cpath_data/datasets/Chromatin3D/data')
DATA_PATH=f'{DATA_DIR}/more_fission_yeast/'
TRAIN_DATASET_SIZE = 800
TEST_DATASET_SIZE = 200
FISH_Tanizawa = f'{DATA_DIR}/fission_yeast/FISH_Tanizawa.csv'




In [4]:
# Shell escape: print the GPU status reported by nvidia-smi before training.
!nvidia-smi

Thu Sep 29 04:27:19 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.39.01    Driver Version: 510.39.01    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100-SXM...  On   | 00000000:84:00.0 Off |                    0 |
| N/A   32C    P0    64W / 400W |      2MiB / 81920MiB |      0%   E. Process |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [5]:
# Unpack the six collections returned by get_data_from_path for DATA_PATH:
# train/test Hi-C inputs, train/test structures and train/test distances.
(
    train_transfer_learning_hics,
    test_transfer_learning_hics,
    train_transfer_learning_structures,
    test_transfer_learning_structures,
    train_transfer_learning_distances,
    test_transfer_learning_distances,
) = get_data_from_path(DATA_PATH)

In [6]:
# Build the training dataset from the loaded training split and shuffle it.
train_dataset = VanillaDataset(
    root='',
    is_training=True,
    dataset_size=TRAIN_DATASET_SIZE,
    hics=train_transfer_learning_hics,
    structures=train_transfer_learning_structures,
    distances=train_transfer_learning_distances,
).shuffle()

# Report how many training samples are available.
train_size = len(train_dataset)
print(train_size)

# Batch the shuffled training samples.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    device=device,
)

100%|██████████| 800/800 [00:00<00:00, 16068.98it/s]

Processing...





Done!
800


In [7]:
# Build the test dataset from the loaded test split and shuffle it.
test_dataset = VanillaDataset(
    root='',
    is_training=False,
    dataset_size=TEST_DATASET_SIZE,
    hics=test_transfer_learning_hics,
    structures=test_transfer_learning_structures,
    distances=test_transfer_learning_distances,
).shuffle()

test_size = len(test_dataset)

# Split the test indices 90% / 10% into two subsets; the "calib" naming
# suggests they are meant for fitting and checking a calibration step.
test_train_idx, test_test_idx = train_test_split(
    list(range(test_size)),
    test_size=0.1,
)
test_train_calib = test_dataset.index_select(test_train_idx)
test_test_calib = test_dataset.index_select(test_test_idx)

# One loader over the full test set, plus one per calibration subset.
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, device=device)
test_train_calib_loader = DataLoader(
    test_train_calib,
    batch_size=BATCH_SIZE,
    device=device,
)
test_test_calib_loader = DataLoader(
    test_test_calib,
    batch_size=BATCH_SIZE,
    device=device,
)

100%|██████████| 200/200 [00:00<00:00, 73237.37it/s]

Processing...





Done!


In [8]:
# Load the fission yeast Hi-C data from DATA_DIR; the training cell converts it
# to a float tensor and reshapes it to (1, NB_BINS, NB_BINS).
fission_yeast_hic = import_fission_yeast(DATA_DIR)

In [9]:
# Mean-squared-error criterion handed to train_trans_conf and
# evaluate_trans_conf as the distance loss function.
distance_loss_fct = torch.nn.MSELoss()

In [10]:
# Pick the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Instantiate the TransConf model from the notebook-level hyper-parameters
# and place it on the selected device.
model = TransConf(
    NB_BINS,
    ANGLE_PRED,
    BATCH_SIZE,
    NUM_BINS_LOGITS,
    ZERO_INIT,
    NB_HEAD,
    NB_HIDDEN,
    NB_LAYERS,
    DROPOUT,
    SECD_HID,
).to(device)

# AdamW over all model parameters (an Adagrad alternative is kept below, disabled).
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=0.0005,
    weight_decay=1e-5,
)
#optimizer = torch.optim.Adagrad(model.parameters(), lr=0.001)

In [11]:
# Load the FISH reference table from the FISH_Tanizawa CSV path; the training
# cell passes it to dist_Tanizawa_FISH together with the predicted structure.
fish_table = FISH_values_Tanizawa(FISH_Tanizawa)
# Reference values from the 'FISH_dist' entry, later correlated (pearsonr)
# with the distances derived from the model's predicted structure.
dist_fish = list(fish_table['FISH_dist'])

In [12]:
# Per-epoch metric histories: one value is appended to each list per epoch.
train_biological_losses_all_epochs = []
train_kabsch_losses_all_epochs = []
train_distance_losses_all_epochs = []
train_lddt_losses_all_epochs = []

test_biological_losses_all_epochs = []
test_kabsch_losses_all_epochs = []
test_distance_losses_all_epochs = []
test_lddt_losses_all_epochs = []

# Value returned by train_trans_conf for each epoch.
losses = []

# Pearson correlation between FISH reference distances and model-derived distances, per epoch.
fission_yeast_pearson_loss_all_epochs = []

# The fission yeast Hi-C input does not change between epochs, so build it once.
# It is repeated BATCH_SIZE times along dim 0 and moved to `device` (where the
# model lives) so the forward pass never mixes CPU and CUDA tensors.
torch_fission_yeast_hic = torch.FloatTensor(fission_yeast_hic)
torch_fission_yeast_hic = torch.reshape(torch_fission_yeast_hic, (1, NB_BINS, NB_BINS))
torch_fission_yeast_hic = torch.repeat_interleave(torch_fission_yeast_hic, BATCH_SIZE, 0).to(device)

for epoch in range(1, NB_EPOCHS+1):
    loss = train_trans_conf(model, train_loader, train_dataset, optimizer, device, BATCH_SIZE,  NB_BINS, EMBEDDING_SIZE, LAMBDA_BIO, LAMBDA_KABSCH, distance_loss_fct, LAMBDA_LDDT, NUM_BINS_LOGITS)
    losses.append(loss)

    ### Evaluation on the training set
    train_mean_biological_loss, train_mean_kabsch_loss, train_mean_distance_loss, train_true_hics, \
        train_pred_structures, train_true_structures, train_pred_distances, \
            train_true_distances, train_mean_lddt_loss = evaluate_trans_conf(train_loader, model, device, BATCH_SIZE, NB_BINS, EMBEDDING_SIZE, distance_loss_fct, NUM_BINS_LOGITS)

    # Store results
    train_biological_losses_all_epochs.append(train_mean_biological_loss)
    train_kabsch_losses_all_epochs.append(train_mean_kabsch_loss)
    train_distance_losses_all_epochs.append(train_mean_distance_loss)
    train_lddt_losses_all_epochs.append(train_mean_lddt_loss)

    ### Evaluation on the test set
    test_mean_biological_loss, test_mean_kabsch_loss, test_mean_distance_loss, test_true_hics, \
        test_pred_structures, test_true_structures, test_pred_distances, \
            test_true_distances, test_mean_lddt_loss = evaluate_trans_conf(test_loader, model, device, BATCH_SIZE, NB_BINS, EMBEDDING_SIZE, distance_loss_fct, NUM_BINS_LOGITS)

    ### Trussart test (disabled)
    #trussart_test_kabsch_loss = compute_trussart_test_kabsch_loss(trussart_hic, trussart_structures, model, NB_BINS, BATCH_SIZE, EMBEDDING_SIZE, True)
    #save_structure(model, epoch, trussart_structures, trussart_hic, NB_BINS, BATCH_SIZE, EMBEDDING_SIZE, True)
    # Store results
    test_biological_losses_all_epochs.append(test_mean_biological_loss)
    test_kabsch_losses_all_epochs.append(test_mean_kabsch_loss)
    test_distance_losses_all_epochs.append(test_mean_distance_loss)
    test_lddt_losses_all_epochs.append(test_mean_lddt_loss)

    ### Fission yeast check
    # Inference only: no autograd graph is needed for this forward pass.
    with torch.no_grad():
        fission_yeast_pred_structure, _, _ = model(torch_fission_yeast_hic)
    # .cpu() is required before .numpy() when the model runs on CUDA.
    # Every batch row is the same Hi-C matrix, so only the first prediction is kept.
    fission_yeast_pred_structure = fission_yeast_pred_structure.detach().cpu().numpy()[0]
    dist_model = dist_Tanizawa_FISH(fission_yeast_pred_structure, fish_table)
    # NOTE(review): save_structure_fission_yeast receives the model and the raw
    # Hi-C data; if it builds its own input tensor it must also place it on the
    # model's device - confirm in Data_Calculation.
    save_structure_fission_yeast(model, epoch, fission_yeast_hic, NB_BINS, BATCH_SIZE, EMBEDDING_SIZE, True)

    # Pearson correlation coefficient (first element of pearsonr's result).
    fission_yeast_pearson_loss = pearsonr(dist_fish, dist_model)[0]

    fission_yeast_pearson_loss_all_epochs.append(fission_yeast_pearson_loss)

    # The last field is the fission yeast Pearson r; it was previously labelled
    # "Trus" although the Trussart evaluation is disabled.
    print('E: {:03d}, Tr B: {:.4f}, Tr K: {:.4f}, Tr D: {:.4f}, Te B: {:.4f}, Te K: {:.4f}, Te D: {:.4f}, Tr LD: {:.4f}, Te LD: {:.4f}, FY r: {:.4f}'.format(\
        epoch, train_mean_biological_loss, train_mean_kabsch_loss, train_mean_distance_loss, \
            test_mean_biological_loss, test_mean_kabsch_loss, test_mean_distance_loss,train_mean_lddt_loss, test_mean_lddt_loss, fission_yeast_pearson_loss))

RuntimeError: Expected condition, x and y to be on the same device, but condition is on cuda:0 and x and y are on cpu and cpu respectively

In [None]:


# Plot the predicted fission yeast structure.
# Coordinate columns (0, 1, 2) of the predicted structure.
x_pred, y_pred, z_pred = (fission_yeast_pred_structure[:, axis] for axis in range(3))

# One colour label per point, pre-filled with a placeholder and then
# overwritten for three consecutive index ranges.
# NOTE(review): 558 and 1012 are presumably chromosome boundaries - confirm.
colorscale = np.full(len(x_pred), 'some non specified color')
for segment, segment_color in (
    (slice(None, 558), 'red'),
    (slice(558, 1012), 'green'),
    (slice(1012, None), 'blue'),
):
    colorscale[segment] = segment_color

color = 'Viridis'
plot_pred_conf(fission_yeast_pred_structure, colorscale, color)