In [1]:
import numpy as np
import scipy
from scipy.sparse import csr_matrix
import matplotlib.pyplot as plt
import seaborn as sns

import cvxopt
from cvxopt import matrix, solvers, spmatrix, sparse
import quadprog
import qpsolvers

import pickle
import gc
from pathlib import Path

import networkx as nx
from networkx.algorithms.approximation import clique

from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from sklearn.preprocessing import MaxAbsScaler
from sklearn.decomposition import PCA

=====================================================================================

In [2]:
# Location of the input arrays; kept in one place so the three loads below cannot drift apart.
DATA_DIR = Path('../general_t14_t3_l10.0_g0.4_r0.5_sasa_solvent_all')

y = np.load(DATA_DIR / 'labels.npy')
print(y.shape)
s = np.load(DATA_DIR / 'affinities.npy')
print(s.shape)
# The design matrix is stored as raw CSR components (data / indices / indptr / shape).
loader = np.load(DATA_DIR / 'data_array_sparse.npz')
X = csr_matrix((loader['data'], loader['indices'], loader['indptr']), shape=loader['shape'])
shape = loader['shape']
print(shape)

# Rows of X come in consecutive groups of D rows; P is the number of such groups
# and N the number of feature columns.
D = 19
N = X.shape[1]
assert X.shape[0] % D == 0, "row count of X must be a multiple of D"
P = X.shape[0] // D  # integer division: avoids a float round-trip for a row count

# Later cells index y with row ids of X and s with group ids, so the lengths must agree.
assert y.shape[0] == X.shape[0], "labels must have one entry per row of X"
assert s.shape[0] == P, "affinities must have one entry per group of D rows"

# Scale every column by its maximum absolute value; copy=False rescales X in place.
# NOTE(review): the scaler is fitted on the full matrix, i.e. before any train/test
# split is made, so held-out rows influence the scaling factors - confirm this is intended.
scaler = MaxAbsScaler(copy=False)
X = scaler.fit_transform(X)

# Take the first row of every group, densify it, and compute the row-by-row
# (group-by-group) correlation matrix of those rows.
native_indices = np.arange(P) * D
native_X = X[native_indices].toarray()
corr = np.corrcoef(native_X)

(214624,)
(11296,)
[214624  23379]


In [3]:
# Shell escape: report system memory usage in human-readable units.
!free -h

              total        used        free      shared  buff/cache   available
Mem:            62G        9.9G        4.7G        1.7G         48G         50G
Swap:            0B          0B          0B


In [4]:
# %%time
# corr_graph = nx.from_numpy_matrix(corr < 0.5)
# T = clique.max_clique(corr_graph)
# np.save('max_clique_indices', np.array(T))

In [5]:
# paircorr = np.corrcoef(native_X[corr.mean(0) < 0.2], native_X[corr.mean(0) >= 0.2])
# plt.figure(figsize=(12, 12))
# sns.heatmap(paircorr);

In [6]:
# Hard-coded list of group indices; the split cell below uses it as the held-out test set.
# NOTE(review): presumably the saved result of the commented-out
# clique.max_clique(corr_graph) run (stored there as 'max_clique_indices') - confirm
# before regenerating or editing by hand.
T = [1, 10246, 2062, 15, 2064, 2069, 2070, 4120, 8217, 8221, 2077, 30, 10273, 6178, 6177, 38, 4139, 46, 8239, 4145, 2098, 2100, 52, 55, 6203, 2108, 2109, 61, 59, 64, 10306, 6210, 2119, 73, 2124, 10318, 10320, 2130, 6227, 4179, 8279, 8283, 4188, 95, 4192, 97, 6242, 8291, 99, 2149, 2150, 101, 104, 98, 2154, 8301, 8303, 10352, 4208, 2162, 2166, 6266, 2172, 4221, 6271, 2175, 10369, 10370, 6274, 8324, 2181, 2182, 8327, 6280, 2185, 137, 136, 135, 8335, 144, 145, 4243, 2197, 4247, 4249, 153, 2205, 2208, 8356, 166, 2218, 4268, 6317, 173, 10417, 4274, 4275, 2228, 2230, 2232, 2236, 189, 190, 8385, 8386, 194, 10437, 6341, 199, 2248, 202, 2253, 206, 210, 2260, 212, 10455, 8407, 6361, 218, 221, 223, 226, 227, 8420, 2279, 231, 8425, 234, 10480, 4339, 244, 2301, 253, 1272, 6402, 2306, 8452, 259, 2311, 264, 265, 6410, 8459, 2316, 2320, 8467, 8468, 277, 4374, 2329, 2331, 4380, 4383, 6432, 291, 6439, 2343, 4393, 2345, 10541, 2350, 2351, 8496, 305, 2354, 307, 8500, 306, 4406, 313, 2362, 314, 8509, 10560, 2369, 2368, 8515, 10564, 320, 2374, 326, 328, 6474, 2379, 335, 343, 8537, 2403, 8549, 359, 360, 362, 8556, 4460, 366, 369, 6514, 4467, 4466, 6517, 2422, 8567, 10617, 4473, 6524, 380, 382, 385, 386, 6531, 4483, 389, 390, 393, 2442, 2444, 4493, 398, 2449, 405, 406, 408, 4505, 410, 411, 412, 409, 416, 420, 10662, 4519, 2472, 425, 8618, 2475, 428, 431, 6577, 434, 8628, 2485, 437, 10679, 8633, 442, 4541, 10686, 6591, 448, 446, 10695, 455, 4553, 8650, 6603, 4555, 6605, 2509, 461, 458, 6609, 2514, 6611, 6612, 469, 465, 6615, 472, 4571, 476, 475, 479, 4576, 2529, 481, 6627, 6628, 484, 487, 2536, 489, 6634, 10735, 6639, 498, 501, 8695, 504, 503, 6650, 4606, 2559, 514, 6662, 8711, 521, 6666, 8720, 533, 2585, 539, 540, 10781, 6685, 542, 4640, 541, 549, 2598, 4647, 551, 553, 554, 2603, 10796, 4653, 6702, 4655, 563, 566, 568, 2618, 573, 574, 2626, 10819, 10821, 582, 4681, 10828, 589, 2638, 4687, 588, 593, 2642, 8787, 10838, 8793, 8794, 605, 606, 10850, 8802, 2658, 610, 2662, 614, 2664, 8811, 
2668, 2667, 8815, 4720, 627, 2677, 4726, 939, 4730, 4732, 637, 941, 6784, 640, 639, 643, 646, 2695, 943, 2702, 655, 656, 657, 131, 2709, 663, 665, 670, 2727, 2731, 683, 686, 8883, 693, 694, 696, 698, 2747, 4796, 699, 8895, 2752, 705, 10950, 710, 10954, 715, 716, 717, 10959, 720, 723, 724, 725, 729, 730, 4828, 4829, 734, 735, 6883, 2787, 741, 2790, 6887, 742, 10985, 6890, 6889, 6892, 4842, 750, 10991, 751, 747, 10994, 8946, 6899, 757, 755, 6904, 11002, 8955, 6907, 762, 2814, 2817, 4866, 770, 769, 2821, 8966, 774, 11016, 779, 11020, 2829, 6926, 783, 782, 2834, 6931, 2836, 4886, 8984, 792, 797, 4894, 11039, 4896, 799, 4902, 806, 812, 2861, 815, 4914, 4915, 2866, 11066, 4923, 2875, 9021, 830, 9023, 831, 829, 11074, 6983, 4937, 842, 841, 9036, 845, 844, 9039, 848, 849, 11092, 852, 2902, 9048, 4956, 2909, 2910, 11103, 9055, 11105, 4962, 11108, 869, 2918, 878, 879, 7024, 4978, 11128, 899, 900, 902, 4999, 2954, 907, 2957, 5006, 11151, 912, 5012, 2967, 919, 5018, 7067, 7068, 2972, 11166, 5023, 928, 923, 922, 2982, 2984, 936, 11178, 5034, 2987, 2986, 942, 11183, 9135, 5040, 2994, 946, 2996, 2997, 5046, 5047, 2998, 951, 945, 9149, 9150, 959, 9153, 7105, 5057, 964, 962, 5062, 9165, 5070, 977, 980, 9177, 3036, 5086, 991, 11234, 5092, 3044, 998, 997, 1001, 3050, 5099, 11244, 1005, 1006, 5103, 5104, 11249, 9204, 3063, 5114, 3068, 1021, 5119, 5120, 1023, 1026, 7172, 1029, 3079, 1034, 3087, 1039, 11285, 5142, 9241, 1050, 1051, 1052, 5150, 3102, 3109, 1061, 1065, 1076, 5178, 1084, 3134, 3136, 1090, 9283, 9284, 7237, 9286, 1094, 3144, 1097, 3147, 1101, 1102, 3152, 3155, 5210, 5211, 9308, 3169, 9320, 3176, 5234, 1139, 7284, 9333, 5238, 5237, 1145, 7291, 1148, 5249, 3201, 1153, 9351, 5256, 1164, 3214, 1170, 1171, 1172, 9367, 3226, 5276, 1181, 9376, 5283, 9381, 9384, 1201, 1202, 5299, 3252, 5301, 7350, 1207, 9400, 1210, 5309, 5310, 3261, 7364, 5318, 1222, 1224, 5324, 5325, 3278, 1231, 1229, 3281, 5332, 5333, 1236, 3287, 9435, 3293, 9444, 1255, 5352, 3305, 1262, 1263, 1265, 7411, 5364, 
1268, 7415, 5368, 1273, 1274, 9467, 5372, 5371, 9470, 3327, 1275, 7425, 9478, 9483, 9484, 3347, 1300, 5398, 1306, 1310, 9503, 1311, 1313, 9508, 1320, 5417, 3369, 5425, 5426, 3378, 3380, 7477, 1340, 3395, 1350, 3399, 1353, 7505, 5458, 5457, 5463, 3418, 9564, 3420, 1372, 9567, 3423, 9575, 1383, 5484, 7533, 3441, 9586, 7538, 1401, 3452, 3456, 1408, 1414, 1419, 1422, 7568, 5520, 9618, 1427, 1431, 1435, 1436, 1439, 1440, 1442, 9635, 1443, 5541, 5542, 7591, 5544, 1447, 7594, 1452, 1453, 1454, 5552, 1456, 3508, 3509, 1460, 1464, 1466, 7616, 5569, 7619, 5578, 7629, 1485, 3535, 3538, 9683, 1494, 7639, 1496, 1495, 1499, 1502, 1504, 1509, 1515, 3564, 1516, 301, 5619, 9718, 1526, 5633, 5635, 1540, 3595, 5648, 7699, 3604, 3606, 1562, 7707, 5662, 1567, 9760, 9767, 9768, 3625, 1581, 5678, 7728, 1585, 1589, 1590, 1591, 1594, 7751, 5704, 5703, 5714, 1619, 1622, 5720, 5721, 3675, 9823, 5729, 3682, 5731, 1635, 9829, 7781, 1640, 9833, 3691, 1644, 1646, 9839, 1647, 3697, 1651, 1652, 9848, 1667, 5767, 739, 1678, 5776, 3729, 1681, 1688, 1689, 3745, 3752, 1705, 3768, 1720, 7866, 9915, 3772, 7872, 7873, 3779, 9929, 1739, 9934, 9936, 5843, 1748, 3797, 5846, 3798, 1751, 3801, 1767, 9963, 1774, 7922, 1778, 1782, 1783, 5883, 7934, 1791, 5888, 1793, 9998, 3861, 3862, 3864, 1818, 5916, 1822, 7967, 7969, 3883, 5934, 5936, 7986, 1842, 1845, 7992, 8006, 3913, 1869, 10062, 1874, 5977, 3932, 1894, 5991, 1896, 3952, 10097, 1912, 3965, 1918, 3968, 8067, 1924, 6035, 6039, 8091, 3998, 6049, 4002, 8100, 4004, 1963, 8113, 6065, 6067, 1971, 8119, 4025, 6075, 8125, 4031, 6084, 6087, 1991, 10187, 6091, 4044, 10190, 6100, 2005, 8151, 2007, 2011, 6111, 2017, 4070, 2023, 2024, 2022, 2040, 10236, 2045, 2046]

In [7]:
# Fraction of the P groups that are listed in T (P replaces the hard-coded 11296).
len(T) / P

0.08640226628895184

In [8]:
# train_indices, test_indices = train_test_split(np.arange(P), test_size=0.2)
# Deterministic split: the groups listed in T are held out for testing and every
# remaining group is used for training.
test_indices = np.sort(np.array(T))
# setdiff1d returns the sorted group ids that are absent from T in one vectorised
# pass (no per-element `in T` scan over a Python list) and always yields an
# integer array, even when nothing is left for training.
train_indices = np.setdiff1d(np.arange(native_X.shape[0]), test_indices)

# Expand each group id g into its D consecutive row ids g*D, g*D + 1, ..., g*D + D - 1,
# keeping the rows of one group adjacent (group-major order).
row_offsets = np.arange(D)
train_indices_D = (train_indices[:, None] * D + row_offsets).ravel()
test_indices_D = (test_indices[:, None] * D + row_offsets).ravel()

# X and y are indexed per row; s is indexed per group.
X_test = X[test_indices_D]
X_train = X[train_indices_D]
y_test = y[test_indices_D]
y_train = y[train_indices_D]
s_test = s[test_indices]
s_train = s[train_indices]

In [31]:
# Persist the split to disk.
# NOTE(review): this cell's recorded execution count (31) is out of sequence with the
# cells above; re-run the notebook top to bottom before trusting the saved files.
OUT_DIR = Path('train_test_split2')
# np.save / np.savez do not create missing directories, so make sure the target exists.
OUT_DIR.mkdir(parents=True, exist_ok=True)

# The test matrix is written densely; the training matrix is kept sparse and stored
# as its raw CSR components (data / indices / indptr / shape).
np.save(OUT_DIR / 'X_test.npy', X_test.toarray())
np.savez(OUT_DIR / 'X_train.npz',
         data=X_train.data,
         indices=X_train.indices,
         indptr=X_train.indptr,
         shape=X_train.shape)

np.save(OUT_DIR / 'y_test.npy', y_test)
np.save(OUT_DIR / 'y_train.npy', y_train)
np.save(OUT_DIR / 's_test.npy', s_test)
np.save(OUT_DIR / 's_train.npy', s_train)

**=================================================================================**