# Imports

In [1]:
# General Imports
import gc
import sys
import pickle
from models_utils import *
from amino_acid_features import sanity_check_dimensions
from tqdm import tqdm

# PyTorch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import WeightedRandomSampler

# PyTorch Geometric
import torch_geometric
from torch_geometric.data.dataset import Dataset
from torch_geometric.nn import GCNConv, global_mean_pool

# Sets the seed for generating random numbers in PyTorch, numpy and Python.
torch_geometric.seed_everything(42)
dtype = torch.float

if torch.cuda.is_available():
    gc.collect()
    torch.cuda.empty_cache()

    device = torch.device("cuda")
    current_device = torch.cuda.current_device()

    print(f"Device Name: {torch.cuda.get_device_name(current_device)}")
    print(f"Device Available: {torch.cuda.is_available()}")
    print(f"Device Initialised: {torch.cuda.is_initialized()}")
    print(f"Device Properties: {torch.cuda.get_device_properties(current_device)}")
    print(f"Device Memory Summary: \n {torch.cuda.memory_summary(current_device)}")
    os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = 'max_split_size_mb:512'
else:
    device = torch.device("cpu")

Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)


Device Name: NVIDIA GeForce GTX 1050 Ti
Device Available: True
Device Initialised: True
Device Properties: _CudaDeviceProperties(name='NVIDIA GeForce GTX 1050 Ti', major=6, minor=1, total_memory=4095MB, multi_processor_count=6)
Device Memory Summary: 
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |       0 B  |       0 B  |       0 B  |       0 B  |
|       from large pool |       0 B  |       0 B  |       0 B  |       0 B  |
|       from small pool |       0 B  |       0 B  |       0 B  |       0 B  |
|---------------------------------------------------------------------------|
| Active memory         |       0 B  |       0

In [2]:
print(f"Python Version: {sys.version}")
print(f"Torch Version: {torch.__version__}")
print(f"Cuda Version: {torch.version.cuda}")
print(f"Torch Geometric Version: {torch_geometric.__version__}")

Python Version: 3.9.16 (main, Jan 11 2023, 16:16:36) [MSC v.1916 64 bit (AMD64)]
Torch Version: 1.13.0
Cuda Version: 11.7
Torch Geometric Version: 2.1.0


# Helper Methods

In [3]:
def print_useful_information(dataframe):
    binding_count = dataframe[dataframe['Molecular_Function_DNA.binding'] == 1]
    non_binding_count = dataframe[dataframe['Molecular_Function_DNA.binding'] == 0]

    print(f"DNA Binding Count: {binding_count.shape[0]}")
    print(f"DNA Non-Binding Count: {non_binding_count.shape[0]}")
    print(f"Class Imbalance: {non_binding_count.shape[0] / binding_count.shape[0]:.0f}:1 (DNA Non-Binding:DNA Binding)")

In [4]:
def remove_proteins_with_no_graphs(dataframe, dataframe_name):
    proteins_with_no_graphs = np.load("../Dataset_Files/Protein_Graph_Data/proteins_with_no_graphs.npy")

    dataframe_correct_dimensions = dataframe[dataframe["Protein_Accession"].isin(proteins_with_no_graphs) == False]
    dataframe_correct_dimensions.reset_index(inplace=True, drop=True)

    size_difference = len(dataframe) - len(dataframe_correct_dimensions)

    print(f"{dataframe_name} entries lost due to wrong graphs: {size_difference}")
    return dataframe_correct_dimensions

In [5]:
def get_dataset():
    dataset_feature_selection = load_from_pickle("Dataset_Populated_Feature_Selection")

    dataset_feature_selection_correct_dimensions = remove_proteins_with_no_graphs(dataset_feature_selection, "Training")

    protein_accession_class = dataset_feature_selection_correct_dimensions.loc[:,
                              "Protein_Accession":"Molecular_Function_DNA.binding"]

    X_normalised = normalise(dataset_feature_selection_correct_dimensions, "hydrophobicity.Group1")
    X_normalised = pd.DataFrame(X_normalised,
                                columns=dataset_feature_selection_correct_dimensions.loc[0,
                                        "hydrophobicity.Group1":].index,
                                index=protein_accession_class.index)

    dataset_normalised_feature_selection_correct_dimensions = pd.concat([protein_accession_class, X_normalised], axis=1)
    print_useful_information(dataset_normalised_feature_selection_correct_dimensions)
    return dataset_normalised_feature_selection_correct_dimensions

# Dataframe

In [6]:
dataframe = get_dataset()
dataframe

Training entries lost due to wrong graphs: 168
DNA Binding Count: 1989
DNA Non-Binding Count: 9045
Class Imbalance: 5:1 (DNA Non-Binding:DNA Binding)


Unnamed: 0,Protein_Accession,Molecular_Function_DNA.binding,hydrophobicity.Group1,hydrophobicity.Group3,polarity.Group1,polarity.Group3,charge.Group1,secondarystruct.Group2,solventaccess.Group1,solventaccess.Group2,Xc1.Q,Xc1.K,UniProt_Embedding_24,UniProt_Embedding_33,UniProt_Embedding_36,UniProt_Embedding_41,UniProt_Embedding_48,UniProt_Embedding_62,UniProt_Embedding_103,UniProt_Embedding_112,UniProt_Embedding_123,UniProt_Embedding_125,UniProt_Embedding_127,UniProt_Embedding_133,UniProt_Embedding_159,UniProt_Embedding_174,UniProt_Embedding_182,UniProt_Embedding_183,UniProt_Embedding_188,UniProt_Embedding_191,UniProt_Embedding_194,UniProt_Embedding_206,UniProt_Embedding_207,UniProt_Embedding_216,UniProt_Embedding_218,UniProt_Embedding_226,UniProt_Embedding_246,UniProt_Embedding_248,UniProt_Embedding_251,UniProt_Embedding_253,UniProt_Embedding_254,UniProt_Embedding_258,UniProt_Embedding_260,UniProt_Embedding_271,UniProt_Embedding_275,UniProt_Embedding_284,UniProt_Embedding_295,UniProt_Embedding_296,UniProt_Embedding_303,UniProt_Embedding_314,UniProt_Embedding_316,UniProt_Embedding_318,UniProt_Embedding_319,UniProt_Embedding_320,UniProt_Embedding_329,UniProt_Embedding_330,UniProt_Embedding_339,UniProt_Embedding_346,UniProt_Embedding_354,UniProt_Embedding_359,UniProt_Embedding_366,UniProt_Embedding_375,UniProt_Embedding_380,UniProt_Embedding_393,UniProt_Embedding_401,UniProt_Embedding_404,UniProt_Embedding_426,UniProt_Embedding_429,UniProt_Embedding_434,UniProt_Embedding_436,UniProt_Embedding_462,UniProt_Embedding_466,UniProt_Embedding_468,UniProt_Embedding_480,UniProt_Embedding_487,UniProt_Embedding_488,UniProt_Embedding_490,UniProt_Embedding_495,UniProt_Embedding_499,UniProt_Embedding_501,UniProt_Embedding_513,UniProt_Embedding_526,UniProt_Embedding_529,UniProt_Embedding_530,UniProt_Embedding_531,UniProt_Embedding_535,UniProt_Embedding_541,UniProt_Embedding_542,UniProt_Embedding_547,UniProt_Embedding_561,UniProt_Embedding_563,UniProt_Embedding_573,UniProt_Embedding_575,UniProt_Embedding_581,UniProt_Embedding_582,UniProt_Embedding_590,UniProt_Embedding_613,UniProt_Embedding_621,UniProt_Embedding_623,UniProt_Embedding_647,UniProt_Embedding_650,UniProt_Embedding_658,UniProt_Embedding_662,UniProt_Embedding_675,UniProt_Embedding_681,UniProt_Embedding_688,UniProt_Embedding_698,UniProt_Embedding_704,UniProt_Embedding_717,UniProt_Embedding_725,UniProt_Embedding_728,UniProt_Embedding_742,UniProt_Embedding_747,UniProt_Embedding_750,UniProt_Embedding_752,UniProt_Embedding_767,UniProt_Embedding_770,UniProt_Embedding_791,UniProt_Embedding_792,UniProt_Embedding_815,UniProt_Embedding_823,UniProt_Embedding_826,UniProt_Embedding_836,UniProt_Embedding_838,UniProt_Embedding_843,UniProt_Embedding_851,UniProt_Embedding_855,UniProt_Embedding_859,UniProt_Embedding_892,UniProt_Embedding_896,UniProt_Embedding_897,UniProt_Embedding_908,UniProt_Embedding_925,UniProt_Embedding_926,UniProt_Embedding_942,UniProt_Embedding_952,UniProt_Embedding_956,UniProt_Embedding_979,UniProt_Embedding_983,UniProt_Embedding_987,UniProt_Embedding_996,UniProt_Embedding_997,UniProt_Embedding_1007,UniProt_Embedding_1014,Tripeptide_Composition_PCA_Component_1,Tripeptide_Composition_PCA_Component_5
0,A0A024RBG1,0,0.916405,-0.305236,-0.233445,0.731093,0.642717,-0.000319,-0.214877,0.916405,-0.792135,-0.701142,0.298237,-0.565300,-1.991293,-0.084925,-0.240134,-1.132587,0.453842,-0.942955,-0.030196,-0.316114,-1.167242,-0.049784,0.037247,-0.628116,0.480436,0.239593,0.172265,-1.453291,-0.883213,-0.910657,-0.298370,0.146447,0.051537,-1.085512,1.555405,1.329280,-2.049825,-2.409428,0.625404,-0.450640,-0.499359,-1.627262,-0.554801,-0.877581,0.407440,0.153725,-0.555857,-0.717183,-0.226090,0.621391,0.264532,-0.007082,-0.386296,0.942037,1.428328,1.527616,-0.392314,-0.544164,0.235616,0.717897,-0.121158,-0.875129,0.215670,-0.358604,-1.387326,-1.091579,0.610351,-1.480118,0.764419,0.942095,0.321761,-0.469506,-0.483362,-1.260828,1.301769,-0.210134,-0.770473,-1.153852,-0.427284,0.527598,0.079270,-0.219223,-0.070345,0.576500,1.350331,-1.215433,0.570158,0.160650,0.575236,-0.381409,0.778707,0.111663,-0.371106,-1.366619,0.659971,-0.436829,-0.066719,0.249351,-0.567260,-0.224450,-0.208896,-0.889211,-0.162586,0.424765,-0.381743,0.549524,-0.188183,-0.290272,-0.594326,1.213956,-0.909752,-0.175113,0.448467,0.750420,-0.423157,-0.583951,-1.793584,0.159863,1.519093,-0.214477,-0.151524,0.406613,0.295764,0.025812,0.862332,0.626206,0.105541,1.367962,-0.576330,0.154692,-1.236848,-0.654571,-0.713014,-0.261297,-0.652350,-0.513559,-0.653190,-1.074442,-0.536669,0.747246,0.267036,0.032423,-0.365465,0.040042
1,A0A075B6L2,0,-0.735651,0.981135,0.850915,-0.366682,0.039038,1.244539,0.790031,-0.735651,-0.852348,-0.702965,-0.658022,-0.286108,0.497560,1.922494,0.262432,0.926107,0.116495,-1.461235,0.626955,0.623764,-0.191673,0.264998,-0.035565,0.940450,-0.246515,0.539394,1.817020,-1.231207,0.491006,-0.587006,0.140989,1.705566,2.565071,0.049309,0.930325,0.867694,-0.056478,-2.015450,-0.927281,-1.605354,0.386419,0.427832,0.172001,1.283194,0.561242,-0.464840,0.349529,-0.009291,0.490325,0.513019,0.097516,0.265543,-1.010397,-1.757778,-2.511772,0.001211,0.137643,-1.695409,1.035987,-0.195147,-0.140070,1.009481,1.579800,0.725446,1.616056,-2.781029,1.445664,1.356340,-2.060620,1.117965,1.557815,0.550927,-0.729448,1.584981,1.147511,1.437029,-0.843479,0.176542,0.766080,-0.877841,-0.027763,0.976620,0.745617,-0.498179,0.245121,1.289088,-0.103250,-0.793414,0.770852,-0.806972,0.235493,-0.529532,-0.176840,-1.701030,-0.727079,-0.793200,-0.105832,2.344393,2.152004,-0.736517,-0.359254,-0.707819,-0.944511,-0.377993,1.600952,-1.284614,-0.005690,-0.407856,-2.213219,-0.760688,-1.393940,-0.166450,2.291248,0.583645,0.038782,-0.967307,2.131962,0.542494,0.197464,-0.089665,0.136970,0.522849,-0.134250,-0.368921,2.014311,-0.486658,-0.870377,-0.121325,-1.484830,-1.050011,1.467721,0.201248,1.105460,0.497720,1.919321,0.900146,0.022882,-0.996384,-0.947140,2.124591,2.166675,0.122764,0.114527,-1.224869
2,A0A075B6L6,0,-0.591904,-0.581802,-0.467647,-0.741001,-0.353082,-0.228781,0.773383,-0.591904,-0.825111,-0.949516,-0.433889,-0.153859,1.072179,1.825346,0.119643,1.804798,-0.131418,0.034935,1.186293,0.166227,0.809017,-1.882920,0.395949,0.732227,-0.421711,1.240594,1.666332,-0.415437,-0.971156,-0.311040,0.129290,1.877973,0.627365,-0.827460,1.138383,2.044834,0.720023,-0.046452,0.073792,-1.175307,-0.192727,1.239009,-0.181901,1.238258,0.596735,-1.641411,-0.155853,-0.154857,1.356601,0.202719,-0.477159,-1.248685,-1.765426,-1.916938,-0.876463,1.006462,-0.143416,-2.103634,0.530814,-0.819389,-0.672379,1.144639,1.226803,0.393544,1.993302,-1.534726,2.006164,-0.126583,-1.141096,0.591300,1.406622,-0.462779,-1.422134,-0.375357,0.661883,0.656681,-0.792475,-0.471142,-0.801166,-0.858002,-0.586190,0.236254,-0.915861,0.522471,0.024823,1.113901,0.139658,1.080692,1.150701,0.704335,1.606639,0.320418,0.507590,-1.001455,-1.845215,0.275421,-0.134936,0.878594,1.632987,-1.434197,0.001068,-0.888286,-0.712969,-0.311687,1.747033,-3.259330,0.836163,-0.975159,-1.974753,0.213254,-0.436051,0.872406,3.083067,0.485464,1.176948,-0.673026,2.899198,0.742920,0.988781,-0.350838,0.773468,0.514644,-0.610408,0.503581,1.705049,1.165340,-0.505666,-0.756060,-0.363629,-1.208805,1.469716,-0.641173,1.603598,-1.327167,1.850657,0.138864,1.315642,-1.053627,-2.052313,0.887810,0.847208,0.578766,-0.206280,-2.275214
3,A0A075B6N1,0,0.141412,0.139337,0.592398,-0.153705,-0.843365,0.887530,-0.016122,0.141412,-0.684033,-0.902643,0.894354,-0.158757,0.904891,1.386777,0.608044,0.913837,0.615101,0.039779,1.340134,0.567825,-0.407241,-1.154956,0.163621,1.072830,-0.128150,0.491090,1.004637,-1.320385,-0.206966,-0.438877,0.302174,1.116311,0.950062,-0.208020,0.264085,2.683136,0.944930,-1.899994,0.269745,-0.379057,0.640323,0.460214,-0.351265,1.644328,0.750537,-0.880100,0.151678,0.262624,0.841780,-0.083875,-0.125204,0.127648,-0.874712,-2.107929,-1.049703,0.577555,-0.138755,-1.914457,0.238691,-0.381150,0.263935,0.122813,1.040345,0.550859,1.076239,-1.821513,2.208614,1.183822,-1.881416,0.521141,2.288255,-0.368034,-0.886561,-0.271756,1.068404,0.768926,-0.811476,0.224486,-0.568498,-0.804523,-0.650550,0.033150,-0.393991,0.700765,0.112298,1.104979,0.261054,0.029909,1.212801,-0.995423,1.477878,-0.119917,-0.172448,-1.296861,-1.392087,0.127405,-0.458112,1.647784,1.997340,-0.190702,0.454232,0.061710,-0.197411,-0.141109,1.805020,-1.757792,0.753619,-0.381669,-1.179183,0.184475,0.070425,-0.598645,2.996687,0.355003,0.937334,-0.586725,2.759701,0.597156,0.884564,-0.619508,1.225291,0.369691,-1.294106,0.326038,2.679133,0.661764,-0.708525,-1.066218,-0.633174,-1.137453,0.713514,-0.289209,1.689299,-0.004121,-0.041350,0.114307,0.698953,-0.722316,-2.017540,1.749450,0.593327,0.469128,-0.474337,-2.777483
4,A0A075B6N2,0,-1.295328,0.553585,0.478548,-1.282997,-0.773463,0.246857,0.169075,-1.295328,-0.837127,-0.891074,0.074740,-0.173451,0.386035,1.710449,0.676998,1.419722,0.353379,-0.322714,0.542793,-0.174684,-0.003259,-1.015288,0.296523,1.119715,0.010657,-0.203021,1.016938,-1.024187,-0.256223,-0.548452,0.316472,1.306431,1.305511,-0.854720,0.628188,1.761896,0.547735,-0.002412,0.114030,-0.916063,0.481056,0.796269,-0.402294,1.532535,0.855537,-1.157902,-0.692418,0.013769,1.084107,-0.041753,0.355984,-0.104516,-1.616223,-2.119719,-1.967304,0.089516,-0.269730,-1.397537,0.400786,0.052122,-0.683503,0.733466,1.468036,0.920573,0.886574,-1.678119,1.802042,1.508670,-1.408434,0.622062,0.238310,0.300786,-0.252684,-0.346219,1.588311,0.439708,-1.193510,0.035578,-0.672462,-0.909756,-0.062868,0.733451,-0.134738,-0.571098,-0.147401,1.449675,-0.674815,0.140561,1.086530,-0.639267,1.232009,-0.736431,-0.213435,-1.156014,-1.328103,0.185035,-0.028528,1.486771,1.823342,-1.227031,-0.187330,-0.682831,-0.288004,0.165791,1.057328,-2.469851,0.153929,-1.469710,-1.985032,0.186999,0.374681,0.690961,2.933341,0.290445,1.422406,-0.574088,2.125423,0.724700,-0.174677,-0.150045,1.460744,0.454475,-0.635598,-0.256461,1.426111,0.387873,-0.457764,0.048761,0.103943,-0.636228,1.407863,0.453626,1.347566,-0.729566,1.182259,0.651639,0.905260,-0.316416,-2.200476,1.593369,0.983088,-0.176119,-0.052141,0.029200
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
11029,Q9Y6X6,0,0.152197,-0.227904,-0.267759,0.165412,-0.271806,-0.800540,-0.439321,0.152197,2.725410,2.775455,0.453920,-0.682855,0.365559,0.704696,0.456394,-0.270649,0.601755,-0.006993,-0.236267,0.101844,-0.165566,0.479260,0.272085,0.226149,0.281322,-0.903958,-1.127035,-0.726010,0.108097,0.118127,-0.038394,0.344760,1.074088,-0.060056,0.369924,0.707877,-0.621002,0.773939,-0.105308,1.307688,1.728070,-0.198434,0.723169,-0.124077,-0.458808,-0.220201,0.433401,-1.422192,0.023098,1.029480,0.325759,0.474656,-0.466155,-0.671957,-1.704588,-0.887351,1.240441,0.470591,-0.930896,-1.283017,0.484943,-0.511994,-0.645522,-0.542178,-0.093147,0.954830,0.079340,-0.988257,-0.229651,-0.245752,-0.713520,0.163994,0.167828,0.331643,0.255995,0.739925,0.842172,0.485383,-0.091763,0.284247,-0.497525,-0.479742,0.015191,-0.342547,-0.300916,0.252972,-0.921884,0.300098,0.226050,-0.223856,-0.277612,0.195874,0.668608,0.854980,0.263222,0.136887,0.584638,0.132033,-0.250774,-0.113599,-0.553343,-0.725403,-0.132231,-0.650704,0.713590,0.066669,0.638256,-0.468959,0.759377,-0.134114,0.173774,0.173579,-1.153775,-1.017186,0.072873,0.096595,-0.417136,1.672168,-1.380560,-0.869639,-0.392151,0.260293,0.675620,0.101118,-0.924606,-0.907946,0.426613,0.499873,1.363592,-1.618829,-0.323676,0.168911,-0.982437,0.919039,0.499745,0.157924,1.471766,-0.155965,0.673012,-0.345776,0.301900,-0.294153,-0.081171,0.435902
11030,Q9Y6X8,1,0.571363,-0.947613,-1.019462,0.505742,0.198593,-0.251205,-1.035372,0.571363,1.001802,1.263062,-0.183777,-1.015926,-1.447605,0.146736,1.177032,-0.937487,0.159497,0.830822,-0.908138,-0.343556,0.325789,0.240992,1.032674,-1.602350,1.283434,-0.898707,-1.863049,0.872993,-1.080534,-0.598166,0.419163,-1.047044,0.906393,-0.616048,-0.872096,-0.914917,-1.304143,-0.165776,0.034695,-0.372700,0.063558,1.046406,0.998263,-0.215045,-1.166608,-0.654206,0.142000,-0.052528,-0.500545,-0.546361,0.931822,0.130813,0.303645,2.034932,1.054561,-0.587156,-0.920874,1.246799,-1.054114,-0.425305,0.728199,-0.956550,-1.036696,-0.943868,1.212756,0.758808,-0.524034,-0.368843,-0.424279,1.705750,-1.068431,0.824965,0.048196,-0.925202,-0.877181,1.029938,0.010597,-0.202997,-2.087654,1.200937,-0.473672,0.286053,-0.756321,-0.415541,0.781371,0.942769,-0.896749,1.080044,-0.310732,0.450881,-0.748258,-0.438950,0.098406,1.253292,-0.474292,2.021319,0.710010,-0.816211,-1.288530,1.085205,-0.939724,-0.506992,-0.270592,0.241989,0.530151,0.118622,-0.497476,0.282759,-0.236114,-1.654352,-0.542677,-0.690089,-1.279229,0.578266,-0.929902,-0.147514,-0.570529,0.906905,-1.191953,0.544931,-0.390858,-0.912321,0.359519,-0.399592,-1.311647,-1.714364,0.668302,-1.125072,-0.454117,-0.193905,-0.243086,-0.014801,-1.194548,1.003278,0.309344,0.861660,-0.127945,0.297637,1.827888,-0.416514,0.163338,-0.904532,-0.048263,0.623242
11031,Q9Y6X9,0,1.295959,-0.660280,-0.680386,1.057755,1.312932,-0.317940,-1.112525,1.295959,0.774422,1.448333,1.190057,-0.315496,-0.858174,0.516413,-1.274438,-1.425878,0.236235,0.107289,0.017761,0.641707,0.608042,0.203166,0.307905,0.100664,-0.102188,0.519705,-0.876914,-0.055282,0.472418,-0.527146,-0.356864,-0.918919,-0.485699,-0.689386,0.150557,-0.221089,-1.474928,0.046537,-0.346528,0.998143,0.539339,-1.485733,0.243957,-0.910461,-0.164882,0.347958,0.552757,-0.086157,0.907671,-0.306704,0.483475,-0.186262,0.619325,-0.304122,-0.519195,-0.616821,0.486171,0.176038,0.196520,0.091861,0.875784,-1.413589,-0.623390,-1.397608,0.228216,-0.534022,0.080177,-0.725809,-0.454391,0.640141,0.140378,0.017112,-0.802362,-0.348378,-1.611994,0.491803,0.064102,-0.686463,-0.031372,0.343656,1.009766,0.134802,1.526259,-0.330390,0.637644,-0.516305,-0.221587,-0.965142,0.199269,0.268863,-0.149405,-0.548341,1.528484,1.134011,0.926787,0.445602,0.747938,-0.466237,-0.524132,2.060220,-0.673151,0.107521,1.126662,-0.097529,0.309217,0.094173,0.088788,0.306379,1.014288,0.814079,-0.168611,-0.901133,-0.811965,-0.232069,-0.534837,0.530567,-1.048145,-1.853506,-0.750757,0.473610,-1.085572,-0.456267,-0.218737,-0.487616,-0.127376,-1.057836,-0.200017,0.456469,-1.750288,-0.278940,-1.102667,0.380579,-0.393777,-0.384682,-0.878726,0.162688,0.629255,-0.567935,-0.098606,-0.916703,-0.309561,-1.162155,-0.180327,0.735362
11032,Q9Y6Y1,0,-0.183990,-0.665566,-0.660520,-0.116283,-0.433691,-0.321995,-0.798424,-0.183990,3.101483,2.268052,0.651883,-0.702447,-0.443873,1.570332,0.988976,0.928938,-0.539848,1.418852,-1.036339,-0.193682,0.200180,0.470002,1.605124,-0.137896,-0.188661,-1.165431,-0.250584,0.469455,-1.045856,-0.201466,0.432161,-1.140333,0.366364,-0.493012,-0.920492,0.077813,1.020100,0.134691,-0.347772,-0.001661,0.534722,1.273468,0.341365,0.553151,-1.338180,-0.035160,0.645769,-0.917516,0.316503,-1.717077,0.962823,-0.237457,-0.369273,0.851110,0.393964,-0.358757,0.352050,0.924036,-0.447137,-0.601925,1.204328,0.307095,-1.749883,-0.284498,0.276414,1.330094,-0.290318,-0.966234,-0.516819,-0.166958,0.107980,-0.925851,0.982303,0.184739,-1.154055,1.143794,0.280247,-0.225390,-1.911971,0.809980,-0.880987,0.578756,-0.587170,-1.029046,0.627732,-0.098618,-0.307687,0.867490,-0.512299,0.197426,0.759767,-1.132001,0.087637,-0.273125,0.128043,0.899747,0.735822,-0.406171,-0.205370,-0.059720,-0.646194,-0.868851,0.023895,-0.881171,0.326432,0.105379,-0.152381,-0.035078,0.188346,-0.550273,-0.016537,-1.082097,-0.522231,1.478040,-0.928441,-0.569773,-0.493152,1.690388,-1.023521,-0.039011,0.065817,-1.376068,0.309585,-0.164199,-0.142191,-0.675259,0.733623,-1.700659,0.020952,0.380200,-0.025728,-0.266020,-0.633204,0.712406,0.280243,-0.737142,-0.766937,-0.152495,0.588867,-0.398716,0.360900,-0.780637,0.122421,0.292612


# Contact Maps Loading & Plotting

In [7]:
contact_map_A0A0A0MRZ7 = np.load("../Dataset_Files/Protein_Graph_Data/raw/Contact_Map_Files/A0A0A0MRZ7.npy")
fig = px.imshow(contact_map_A0A0A0MRZ7, color_continuous_scale=["white", "black"], template=template)

fig.show()

# Dataset Class

In [8]:
class MyDataset(Dataset):
    def __init__(self, root, dataframe, unique_proteins_list=None, transform=None, pre_transform=None,
                 pre_filter=None):
        self.dataframe = dataframe
        self.unique_proteins_list = unique_proteins_list
        super(MyDataset, self).__init__(root, transform, pre_transform, pre_filter)

    @property
    def raw_file_names(self):
        return os.listdir(f"{self.root}/raw")

    @property
    def processed_file_names(self):
        return os.listdir(f"{self.root}/processed")

    def download(self):
        pass

    def process(self):
        print("Creating Protein Graphs")
        for i in tqdm(range(len(self.unique_proteins_list))):
            try:
                accession = self.unique_proteins_list[i]

                amino_acid_descriptors = np.load(
                    f"{self.root}/raw/Amino_Acid_Descriptors_And_PSSM/{accession}_Descriptors.npy")
                pssm = np.load(f"{self.root}/raw/Amino_Acid_Descriptors_And_PSSM/{accession}_PSSM.npy")
                uniprot_per_residue_embedding = np.load(f"{self.root}/raw/Amino_Acid_Embeddings/{accession}.npy")
                contact_map = np.load(f"{self.root}/raw/Contact_Map_Files/{accession}.npy")

                if (contact_map.shape[0]) == (amino_acid_descriptors.shape[0]) == (pssm.shape[0]) == (
                        uniprot_per_residue_embedding.shape[0]):
                    amino_acid_features = np.hstack((amino_acid_descriptors, pssm, uniprot_per_residue_embedding))

                    contact_map_sparse = []
                    index_row, index_col = np.where(contact_map == 1)
                    for row, column in zip(index_row, index_col):
                        contact_map_sparse.append([row, column])
                    contact_map_sparse = np.array(contact_map_sparse)

                    data = torch_geometric.data.Data(x=torch.Tensor(amino_acid_features),
                                                     edge_index=torch.LongTensor(contact_map_sparse).transpose(1, 0),
                                                     size=contact_map.shape[0],
                                                     accession=accession)

                    torch.save(data, f"{self.root}/processed/protein_graph_{accession}.pt")
                else:
                    continue

            except:
                continue

    def len(self):
        return self.dataframe.shape[0]

    def get(self, idx):
        accession = self.dataframe.loc[idx, "Protein_Accession"]
        protein_graph = torch.load(f"{self.root}/processed/protein_graph_{accession}.pt")

        protein_sequence_descriptors = torch.tensor(self.dataframe.loc[idx, "hydrophobicity.Group1":],
                                                    dtype=torch.float32)
        label = torch.tensor(self.dataframe.loc[idx, "Molecular_Function_DNA.binding"], dtype=torch.long)

        return protein_sequence_descriptors, protein_graph, label

# Model

In [9]:
class Model(nn.Module):
    def __init__(self, n_output=2, protein_sequence_features=144, amino_acid_features=1110, output_dim=128,
                 dropout=0.2):
        super(Model, self).__init__()

        # Protein Structure Layers
        self.pro_conv1 = GCNConv(amino_acid_features, amino_acid_features)
        self.pro_batch_normalisation1 = nn.BatchNorm1d(amino_acid_features)

        self.pro_conv2 = GCNConv(amino_acid_features, amino_acid_features * 2)
        self.pro_batch_normalisation2 = nn.BatchNorm1d(amino_acid_features * 2)

        self.pro_conv3 = GCNConv(amino_acid_features * 2, amino_acid_features * 4)
        self.pro_batch_normalisation3 = nn.BatchNorm1d(amino_acid_features * 4)

        self.pro_fc1 = nn.Linear(amino_acid_features * 4, 1024)
        self.pro_fc2 = nn.Linear(1024, output_dim)

        # Protein Sequence Descriptors Layers
        self.psd_fc1 = nn.Linear(protein_sequence_features, protein_sequence_features)
        self.psd_batch_normalisation1 = nn.BatchNorm1d(protein_sequence_features)

        self.psd_fc2 = nn.Linear(protein_sequence_features, protein_sequence_features * 2)
        self.psd_batch_normalisation2 = nn.BatchNorm1d(protein_sequence_features * 2)

        self.psd_fc3 = nn.Linear(protein_sequence_features * 2, protein_sequence_features * 4)
        self.psd_batch_normalisation3 = nn.BatchNorm1d(protein_sequence_features * 4)

        self.psd_fc4 = nn.Linear(protein_sequence_features * 4, 1024)
        self.psd_fc5 = nn.Linear(1024, output_dim)

        # Combined Layers
        self.combined_fc1 = nn.Linear(2 * output_dim, 1024)
        self.combined_fc2 = nn.Linear(1024, 512)
        self.out = nn.Linear(512, n_output)

        # Other Layers
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, protein_sequence_descriptors, protein_graph):
        # Protein Graphs
        x_protein = self.pro_conv1(protein_graph.x, protein_graph.edge_index)
        x_protein = self.relu(x_protein)
        x_protein = self.pro_batch_normalisation1(x_protein)

        x_protein = self.pro_conv2(x_protein, protein_graph.edge_index)
        x_protein = self.relu(x_protein)
        x_protein = self.pro_batch_normalisation2(x_protein)

        x_protein = self.pro_conv3(x_protein, protein_graph.edge_index)
        x_protein = self.relu(x_protein)
        x_protein = self.pro_batch_normalisation3(x_protein)

        x_protein = global_mean_pool(x_protein, protein_graph.batch)

        # Flatten
        x_protein = self.pro_fc1(x_protein)
        x_protein = self.relu(x_protein)
        x_protein = self.dropout(x_protein)

        x_protein = self.pro_fc2(x_protein)
        x_protein = self.dropout(x_protein)

        # Protein Sequence Descriptors
        x_protein_sequence_descriptors = self.psd_fc1(protein_sequence_descriptors)
        x_protein_sequence_descriptors = self.relu(x_protein_sequence_descriptors)
        x_protein_sequence_descriptors = self.psd_batch_normalisation1(x_protein_sequence_descriptors)

        x_protein_sequence_descriptors = self.psd_fc2(x_protein_sequence_descriptors)
        x_protein_sequence_descriptors = self.relu(x_protein_sequence_descriptors)
        x_protein_sequence_descriptors = self.psd_batch_normalisation2(x_protein_sequence_descriptors)

        x_protein_sequence_descriptors = self.psd_fc3(x_protein_sequence_descriptors)
        x_protein_sequence_descriptors = self.relu(x_protein_sequence_descriptors)
        x_protein_sequence_descriptors = self.psd_batch_normalisation3(x_protein_sequence_descriptors)

        x_protein_sequence_descriptors = self.psd_fc4(x_protein_sequence_descriptors)
        x_protein_sequence_descriptors = self.relu(x_protein_sequence_descriptors)
        x_protein_sequence_descriptors = self.dropout(x_protein_sequence_descriptors)

        x_protein_sequence_descriptors = self.psd_fc5(x_protein_sequence_descriptors)
        x_protein_sequence_descriptors = self.dropout(x_protein_sequence_descriptors)

        # Combine
        x_combined = torch.cat((x_protein, x_protein_sequence_descriptors), 1)
        x_combined = self.combined_fc1(x_combined)
        x_combined = self.relu(x_combined)
        x_combined = self.dropout(x_combined)

        x_combined = self.combined_fc2(x_combined)
        x_combined = self.relu(x_combined)
        x_combined = self.dropout(x_combined)
        out = self.out(x_combined)

        return out

# Training Loop

In [10]:
def training_loop(n_epochs, optimizer, model, device, loss_fn, train_loader):
    model = model.to(device)

    train_losses = []
    avg_train_losses = []

    for epoch in range(1, n_epochs + 1):

        # Training
        model.train()
        for protein_sequence_descriptors, protein_graphs, labels in train_loader:
            protein_sequence_descriptors = protein_sequence_descriptors.to(device)
            protein_graphs = protein_graphs.to(device)
            outputs = model(protein_sequence_descriptors, protein_graphs)
            loss = loss_fn(outputs, labels.to(device))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_losses.append(loss.item())

        # Printing Stats
        train_loss = np.average(train_losses)
        avg_train_losses.append(train_loss)

        print(f"[{epoch}:/{n_epochs}] Training Loss: {train_loss:.2f}")

        # Clear lists to track next epoch
        train_losses = []

    torch.save(model.state_dict(), "Dataset_Files/Neural_Networks/mf_model.pt")
    return model, avg_train_losses

# Model Hyperparameters

In [11]:
model = Model()

BATCH_SIZE = 32
N_EPOCHS = 200
LR_RATE = 0.001
WEIGHT_DECAY = 0.0001

optimizer = optim.Adam(model.parameters(), lr=LR_RATE, weight_decay=WEIGHT_DECAY)
loss_fn = nn.CrossEntropyLoss()

# Weighted Random Sampler
To balance our batches, given the clear imbalance between the classes

In [12]:
# Reference
# https://www.maskaravivek.com/post/pytorch-weighted-random-sampler/

training_class_counts = dataframe["Molecular_Function_DNA.binding"].value_counts(ascending=True).to_numpy()

training_weights = 1. / training_class_counts

training_weights_all = np.array([training_weights[t] for t in dataframe["Molecular_Function_DNA.binding"].astype(int)])
training_weights_all = torch.from_numpy(training_weights_all)

print(training_weights)
#print(training_weights_all)
print(len(training_weights_all))

[0.00050277 0.00011056]
11034


# Dataloaders

In [13]:
training_sampler = WeightedRandomSampler(training_weights_all, len(training_weights_all))

trainloader = torch_geometric.loader.DataLoader(
    MyDataset(root="../Dataset_Files/Protein_Graph_Data",
              dataframe=dataframe,
              unique_proteins_list=dataframe.loc[:, "Protein_Accession"]),
    batch_size=BATCH_SIZE,
    sampler=training_sampler)

In [14]:
for protein_sequence_descriptors, protein_graphs, labels in trainloader:
    print(protein_sequence_descriptors.shape)
    print(protein_sequence_descriptors.dtype)
    protein_sequence_descriptors = protein_sequence_descriptors.to(device)
    print(protein_graphs)
    protein_graphs = protein_graphs.to(device)
    print(labels.shape)
    labels = labels.to(device)
    print("Check OK")
    break

torch.Size([32, 144])
torch.float32
DataBatch(x=[18669, 1110], edge_index=[2, 289741], size=[32], accession=[32], batch=[18669], ptr=[33])
torch.Size([32])
Check OK


# Training

In [15]:
# model, train_loss = training_loop(n_epochs=N_EPOCHS,
#                                   optimizer=optimizer,
#                                   model=model,
#                                   device=device,
#                                   loss_fn=loss_fn,
#                                   train_loader=trainloader)
#
# load_to_numpy(train_loss, "Neural_Networks/mf_train_loss")

In [16]:
load_from_numpy("Neural_Networks/mf_train_loss")

array([0.11339996, 0.07785756, 0.0835734 , 0.07071257, 0.0736931 ,
       0.07098129, 0.06006887, 0.06130498, 0.06483565, 0.056689  ,
       0.06247982, 0.05400618, 0.05634175, 0.05724875, 0.06004448,
       0.05764959, 0.05130492, 0.04501449, 0.05616299, 0.0453704 ,
       0.04122233, 0.04674479, 0.04582046, 0.04735578, 0.04336906,
       0.03680433, 0.04174923, 0.03975666, 0.05054074, 0.04216495,
       0.0409558 , 0.03771227, 0.03926447, 0.03598504, 0.04684586,
       0.03512827, 0.03394992, 0.03380969, 0.02796467, 0.03230669,
       0.03106207, 0.03247793, 0.03245615, 0.03524276, 0.02988787,
       0.03032487, 0.02951382, 0.02612126, 0.03039642, 0.03714302,
       0.032536  , 0.03256359, 0.03042663, 0.0277351 , 0.02868623,
       0.03104124, 0.03226861, 0.02589645, 0.02394617, 0.03002077,
       0.022396  , 0.02685013, 0.02325123, 0.03411857, 0.02292241,
       0.025338  , 0.02395948, 0.0255581 , 0.0311159 , 0.02811889,
       0.01993447, 0.02135175, 0.02442882, 0.01988658, 0.02168

In [17]:
train_loss = pd.Series(load_from_numpy("Neural_Networks/mf_train_loss"))

fig = px.line(train_loss,
              labels={
                  "index": "Epoch",
                  "value": "Cross-Entropy Loss",
              },
              title="Cross-Entropy Loss Over Epochs",
              template=template)
fig.update_layout(showlegend=False)
fig.show()

# Get Protein Embeddings

In [20]:
# batch_accession_list = []
# batch_embeddings = []
# protein_embeddings = {}
# model.load_state_dict(torch.load("Dataset_Files/Neural_Networks/mf_model.pt",map_location='cuda:0'))
# model.to(device)

Model(
  (pro_conv1): GCNConv(1110, 1110)
  (pro_batch_normalisation1): BatchNorm1d(1110, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pro_conv2): GCNConv(1110, 2220)
  (pro_batch_normalisation2): BatchNorm1d(2220, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pro_conv3): GCNConv(2220, 4440)
  (pro_batch_normalisation3): BatchNorm1d(4440, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pro_fc1): Linear(in_features=4440, out_features=1024, bias=True)
  (pro_fc2): Linear(in_features=1024, out_features=128, bias=True)
  (psd_fc1): Linear(in_features=144, out_features=144, bias=True)
  (psd_batch_normalisation1): BatchNorm1d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (psd_fc2): Linear(in_features=144, out_features=288, bias=True)
  (psd_batch_normalisation2): BatchNorm1d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (psd_fc3): Linear(in_features=288, out_features=576, bias=Tru

In [None]:
# # Forward hook
# def get_protein_embeddings():
#     def hook(model, input, output):
#         batch_embeddings.append(output.detach())
#     return hook

In [None]:
# # Register hook
# hook = model.pro_fc2.register_forward_hook((get_protein_embeddings()))
#
# for protein_sequence_descriptors, protein_graphs, labels in trainloader:
#     gc.collect()
#     torch.cuda.empty_cache()
#
#     output = model(protein_sequence_descriptors.to(device), protein_graphs.to(device))
#
#     for i in range(protein_sequence_descriptors.shape[0]):
#         batch_accession_list.append(protein_graphs[i].accession)
#
#     for i in range(len(batch_accession_list)):
#         protein_embeddings[batch_accession_list[i]] = batch_embeddings[0][i].cpu().numpy()
#
#     batch_accession_list = []
#     batch_embeddings = []
#
# hook.remove()

In [21]:
if os.path.exists("Dataset_Files/protein_embeddings.pkl"):
    with open('Dataset_Files/protein_embeddings.pkl', 'rb') as file:
        protein_embeddings = pickle.load(file)
else:
    with open('Dataset_Files/protein_embeddings.pkl', 'wb') as file:
        pickle.dump(protein_embeddings, file, protocol=pickle.HIGHEST_PROTOCOL)
protein_embeddings

{'Q6JQN1': array([ 1.30066089e-02,  3.88356030e-01, -2.56263018e-02,  2.89109349e-02,
         2.31596529e-01, -2.84079313e-01, -2.70526391e-03,  1.41009502e-02,
        -8.58326256e-02,  1.62036810e-03, -8.80797356e-02, -1.25764245e-02,
         2.64684185e-02, -2.44270731e-02,  7.58701423e-03, -1.09423190e-01,
        -7.93765634e-02, -1.95755720e-01, -1.90979242e-02, -2.09489778e-01,
         2.47414112e-02,  4.19138297e-02, -5.28238639e-02, -2.32324705e-01,
         8.30572750e-03, -1.16548873e-02, -1.54799148e-02,  4.90930304e-02,
         4.07092012e-02, -2.24016514e-02, -5.84624615e-03, -7.26676732e-02,
        -3.05888001e-02,  3.56669836e-02,  5.36373351e-03,  3.70354094e-02,
        -3.00989989e-02, -1.54267862e-01,  1.77055262e-02,  2.74378899e-03,
        -8.77451152e-04, -3.39035690e-02,  2.34820712e-02, -7.58502446e-03,
         9.07757040e-03, -2.38958467e-03, -3.04873195e-03, -3.43019329e-03,
         1.90302264e-02, -1.12659484e-02, -6.55687414e-03,  2.42341477e-02,
  