In [24]:
import pandas as pd
import os
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split, KFold
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchinfo import summary

In [12]:
# Read every tab-separated feature table in features_ring/ and stack them
# into a single frame.
# os.listdir returns names in arbitrary order, so the names are sorted to keep
# the row order of the combined frame (and any seeded split of it) reproducible.
tsv_names = sorted(name for name in os.listdir('features_ring') if name.endswith('.tsv'))
if not tsv_names:
    # pd.concat([]) would raise a ValueError anyway; fail with a clearer message.
    raise ValueError("No .tsv files found in 'features_ring'")
dfs = [pd.read_csv(os.path.join('features_ring', name), sep='\t') for name in tsv_names]
# Each file keeps its own row labels, so index values repeat across files.
df = pd.concat(dfs)
df

Unnamed: 0,pdb_id,s_ch,s_resi,s_ins,s_resn,s_ss8,s_rsa,s_up,s_down,s_phi,...,t_down,t_phi,t_psi,t_ss3,t_a1,t_a2,t_a3,t_a4,t_a5,Interaction
0,6ihr,A,17,,D,E,0.061,24.0,24.0,-1.424,...,24.0,-2.498,2.442,H,-0.032,0.326,2.213,0.908,1.313,VDW
1,6ihr,A,142,,L,H,0.091,15.0,9.0,-1.106,...,6.0,-1.864,0.166,H,-0.032,0.326,2.213,0.908,1.313,
2,6ihr,A,53,,E,H,0.237,16.0,14.0,-1.174,...,12.0,-1.140,-0.727,H,1.538,-0.055,1.502,0.440,2.897,HBOND
3,6ihr,A,75,,N,H,0.561,6.0,17.0,-1.175,...,19.0,-1.120,-0.758,H,1.050,0.302,-3.656,-0.259,-3.242,HBOND
4,6ihr,A,5,,F,E,0.168,23.0,13.0,-2.231,...,23.0,-2.740,2.733,H,-1.006,-0.590,1.891,-0.397,0.412,PIPISTACK
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
134,6va5,A,1499,,G,S,0.429,1.0,18.0,-1.634,...,22.0,-0.903,-0.734,H,-1.019,-0.987,-1.505,1.266,-0.912,VDW
135,6va5,A,1519,,F,E,0.025,19.0,14.0,-1.351,...,13.0,1.895,-0.291,L,-0.384,1.652,1.330,1.045,2.064,HBOND
136,6va5,A,1572,,I,E,0.006,20.0,17.0,-2.065,...,8.0,-2.424,2.237,H,0.260,0.830,3.097,-0.838,1.512,VDW
137,6va5,A,1563,,K,-,0.537,6.0,11.0,-2.276,...,10.0,1.523,-0.111,L,-0.384,1.652,1.330,1.045,2.064,HBOND


In [32]:
# Discard, in place, every row that has at least one missing value in any column.
df.dropna(axis=0, how='any', inplace=True)

# Ground truth: the interaction label of each remaining row as a pandas
# categorical, together with its one-hot (dummy-column) expansion.
y = df.loc[:, 'Interaction'].astype('category')
y_oneHot = pd.get_dummies(y)
y

0            VDW
2          HBOND
3          HBOND
4      PIPISTACK
5          HBOND
         ...    
134          VDW
135        HBOND
136          VDW
137        HBOND
138        HBOND
Name: Interaction, Length: 454193, dtype: category
Categories (6, object): ['HBOND', 'IONIC', 'PICATION', 'PIPISTACK', 'SSBOND', 'VDW']

In [19]:
# Training features: the same ten per-residue descriptors, once with the
# 's_' prefix and once with the 't_' prefix (all 's_' columns first).
residue_features = ['rsa', 'up', 'down', 'phi', 'psi', 'a1', 'a2', 'a3', 'a4', 'a5']
feature_columns = [f'{side}_{feature}' for side in ('s', 't') for feature in residue_features]
X = df[feature_columns]

# Standardise each column to zero mean and unit variance.
# NOTE(review): the scaler is fitted on all rows, i.e. before the train/test
# split done later in the notebook, so held-out rows contribute to the
# mean/std used for scaling - confirm this is intended.
scaler = StandardScaler()
# The scaler returns a plain array; wrapping it restores the column names but
# gives X_scaled a fresh 0..n-1 index (it does not reuse df's row labels).
X_scaled = pd.DataFrame(scaler.fit_transform(X), columns=X.columns)
X_scaled

Unnamed: 0,s_rsa,s_up,s_down,s_phi,s_psi,s_a1,s_a2,s_a3,s_a4,s_a5,t_rsa,t_up,t_down,t_phi,t_psi,t_a1,t_a2,t_a3,t_a4,t_a5
0,-0.612926,1.212548,1.272426,-0.080637,1.435561,1.032979,0.438314,-1.675155,-0.535573,-1.967293,-0.926898,0.482516,1.366067,-1.502723,1.362867,-0.045132,0.558611,1.029760,0.741296,0.813673
1,0.235709,0.077481,-0.558473,0.268180,-0.783304,1.337499,-1.430106,0.702567,-0.124882,-0.467409,-0.118231,-0.267989,-0.712127,0.257241,-0.791289,1.476577,0.135358,0.690168,0.232221,1.818554
2,1.797971,-1.341352,-0.009203,0.266785,-0.772920,0.928827,0.998308,0.620114,-0.436212,0.636456,1.398591,-1.318697,0.500153,0.283161,-0.812362,1.003587,0.531949,-1.773429,-0.528128,-2.075994
3,-0.096994,1.070665,-0.741563,-1.206618,1.219451,-1.006413,-0.511333,0.894341,-0.687926,0.311533,-0.926898,0.932819,1.192884,-1.816354,1.560677,-0.989173,-0.458974,0.875965,-0.678240,0.242084
4,-0.096994,1.070665,-0.741563,-1.206618,1.219451,-1.006413,-0.511333,0.894341,-0.687926,0.311533,-0.926898,0.932819,1.192884,-1.816354,1.560677,-0.989173,-0.458974,0.875965,-0.678240,0.242084
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
454188,1.161494,-2.050769,0.173887,-0.373643,-0.294621,-0.389438,1.875560,0.634473,0.904054,1.341807,0.854911,-1.618899,1.019702,0.564392,-0.796048,-1.001773,-0.900002,-0.746055,1.130717,-0.597855
454189,-0.786511,0.503131,-0.558473,0.021218,1.453084,-1.006413,-0.511333,0.894341,-0.687926,0.311533,1.901152,-2.219303,-0.538944,4.190592,-0.494914,-0.386305,2.031665,0.608016,0.890320,1.290104
454190,-0.878125,0.645015,-0.009203,-0.975003,1.163638,-1.237531,-0.465554,1.005515,0.184240,0.563489,-0.433474,1.082920,-1.404858,-1.406819,1.223517,0.237886,1.118505,1.451982,-1.157946,0.939918
454191,1.682248,-1.341352,-1.107742,-1.269405,-0.158335,1.807671,-0.480459,0.265285,-0.555445,1.082368,2.335182,-2.369404,-1.058493,3.708481,-0.372558,-0.386305,2.031665,0.608016,0.890320,1.290104


In [25]:
# Hold out 10% of the samples as a test set; the fixed seed makes the
# partition identical on every run.
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.1, random_state=0)
# 10-fold cross-validation splitter. shuffle=True permutes the samples before
# folding, so it is seeded too - otherwise the folds change on each run even
# though the hold-out split above is deterministic.
kf = KFold(n_splits=10, shuffle=True, random_state=0)

In [None]:
class ContactClassNet(nn.Module):
    """Fully connected classifier built from a stack of ``nn.Linear`` layers.

    Every hidden layer is followed by a ReLU; the last layer maps to
    ``num_classes`` outputs and has no activation, so ``forward`` returns raw
    (unnormalised) logits. That is the input ``nn.CrossEntropyLoss`` expects,
    since it applies log-softmax internally.

    Args:
        input_dim: Number of input features per sample.
        num_classes: Number of output classes (width of the final layer).
        hidden_layers_dim: Optional sequence with the width of each hidden
            layer, in order. ``None`` or an empty sequence yields a single
            linear layer from ``input_dim`` to ``num_classes``.
    """

    def __init__(self, input_dim, num_classes, hidden_layers_dim=None):
        super().__init__()
        # None replaces a mutable default list; both mean "no hidden layers".
        hidden_dims = [] if hidden_layers_dim is None else list(hidden_layers_dim)
        # Consecutive pairs of this list are the (in, out) sizes of each layer.
        layer_dims = [input_dim] + hidden_dims + [num_classes]
        self.layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(layer_dims[:-1], layer_dims[1:])]
        )
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise linear layers: weights ~ N(0, 0.1**2), biases zero."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=.1)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, x):
        """Return class logits for ``x``.

        Args:
            x: Tensor whose last dimension equals ``input_dim``.

        Returns:
            Tensor with last dimension ``num_classes`` holding raw logits.
        """
        # All layers but the last are hidden layers with a ReLU activation.
        for hidden_layer in self.layers[:-1]:
            x = F.relu(hidden_layer(x))
        # The output layer must always be applied: skipping it would return
        # the last hidden activations instead of one score per class.
        return self.layers[-1](x)

# Example usage: size the network from the data prepared in the earlier cells.
input_size = len(X.columns)  # one input unit per feature column
num_classes = len(y_oneHot.columns)  # one output unit per one-hot label column

# Three hidden layers of 256 units each.
model = ContactClassNet(input_size, num_classes, hidden_layers_dim=[256, 256, 256])

# Loss function and optimiser (Adam with its default hyper-parameters).
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters())

# Show the layer / parameter overview of the model.
summary(model)

In [35]:
# Sanity check: (rows, columns) of the unscaled feature frame X.
X.shape

(454193, 20)