In [1]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import obonet
import random
import torch
import math
from Bio import SeqIO
import Bio.PDB
import urllib.request
import py3Dmol
import pylab
import pickle as pickle
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import GATConv
from torch_geometric.nn import GATv2Conv
from torch_geometric.nn import GENConv
from torch_geometric.nn.models import MLP
from torch_geometric.data import Data
from torch_geometric.nn.pool import SAGPooling
from torch_geometric.nn.aggr import MeanAggregation
import matplotlib.pyplot as plt
import os
from Bio import PDB
from rdkit import Chem
import blosum as bl

In [2]:
class CFG:
    pdbfiles: str = "/home/paul/BioHack/pdbind-refined-set/"
    AA_mol2_files: str = "/home/paul/BioHack/AA_mol2/"

In [3]:
with open('atom2emb.pkl', 'rb') as f:
    atom2emb = pickle.load(f)
    
with open('AA_embeddings_11172023.pkl', 'rb') as f:
    AA_embeddings = pickle.load(f)
    
with open('bond_type_dict.pkl', 'rb') as f:
    bond_type_dict = pickle.load(f)

def get_atom_symbol(atomic_number):
    return Chem.PeriodicTable.GetElementSymbol(Chem.GetPeriodicTable(), atomic_number)

def remove_hetatm(input_pdb_file, output_pdb_file):
    # Open the input PDB file for reading and the output PDB file for writing
    with open(input_pdb_file, 'r') as infile, open(output_pdb_file, 'w') as outfile:
        for line in infile:
            # Check if the line starts with 'HETATM' (non-protein atoms)
            if line.startswith('HETATM'):
                continue  # Skip this line (HETATM record)
            # Write all other lines to the output file
            outfile.write(line)
            
def get_atom_types_from_sdf(sdf_file):
    supplier = Chem.SDMolSupplier(sdf_file)
    atom_types = set()

    for mol in supplier:
        if mol is not None:
            atoms = mol.GetAtoms()
            atom_types.update([atom.GetSymbol() for atom in atoms])

    return sorted(list(atom_types))

def get_atom_types_from_mol2_split(mol2_file):
    atom_types = set()

    with open(mol2_file, 'r') as mol2:
        reading_atoms = False
        for line in mol2:
            if line.strip() == '@<TRIPOS>ATOM':
                reading_atoms = True
                continue
            elif line.strip() == '@<TRIPOS>BOND':
                break

            if reading_atoms:
                parts = line.split()
                if len(parts) >= 5:
                    atom_type = parts[5]
                    atom_types.add(atom_type)
    
    atom_types_split = set()
    for atom in atom_types:
        atom_types_split.add(str(atom).split('.')[0])
        

    return sorted(list(atom_types_split))

def get_atom_types_from_mol2(mol2_file):
    atom_types = set()

    with open(mol2_file, 'r') as mol2:
        reading_atoms = False
        for line in mol2:
            if line.strip() == '@<TRIPOS>ATOM':
                reading_atoms = True
                continue
            elif line.strip() == '@<TRIPOS>BOND':
                break

            if reading_atoms:
                parts = line.split()
                if len(parts) >= 5:
                    atom_type = parts[5]
                    atom_types.add(atom_type)

    return sorted(list(atom_types))

def get_atom_list_from_mol2_split(mol2_file):
    atoms = []
    with open(mol2_file, 'r') as mol2:
        reading_atoms = False
        for line in mol2:
            if line.strip() == '@<TRIPOS>ATOM':
                reading_atoms = True
                continue
            elif line.strip() == '@<TRIPOS>BOND':
                break

            if reading_atoms:
                parts = line.split()
                if len(parts) >= 5:
                    atom_type = parts[5]
                    atoms.append(atom_type)
    
    atom_list = []
    for atom in atoms:
        atom_list.append(str(atom).split('.')[0])
        

    return atom_list

def get_atom_list_from_mol2(mol2_file):
    atoms = []
    with open(mol2_file, 'r') as mol2:
        reading_atoms = False
        for line in mol2:
            if line.strip() == '@<TRIPOS>ATOM':
                reading_atoms = True
                continue
            elif line.strip() == '@<TRIPOS>BOND':
                break

            if reading_atoms:
                parts = line.split()
                if len(parts) >= 5:
                    atom_type = parts[5]
                    atoms.append(atom_type)

    return atoms

def get_bond_types_from_mol2(mol2_file):
    bond_types = set()

    with open(mol2_file, 'r') as mol2:
        reading_bonds = False
        for line in mol2:
            if line.strip() == '@<TRIPOS>BOND':
                reading_bonds = True
                continue
            elif reading_bonds and line.strip().startswith('@<TRIPOS>'):
                break

            if reading_bonds:
                parts = line.split()
                if len(parts) >= 4:
                    bond_type = parts[3]
                    bond_types.add(bond_type)

    return sorted(list(bond_types))

def read_mol2_bonds(mol2_file):
    bonds = []
    bond_types = []

    with open(mol2_file, 'r') as mol2:
        reading_bonds = False
        for line in mol2:
            if line.strip() == '@<TRIPOS>BOND':
                reading_bonds = True
                continue
            elif reading_bonds and line.strip().startswith('@<TRIPOS>'):
                break

            if reading_bonds:
                parts = line.split()
                if len(parts) >= 4:
                    atom1_index = int(parts[1])
                    atom2_index = int(parts[2])
                    bond_type = parts[3]
                    bonds.append((atom1_index, atom2_index))
                    bond_types.append(bond_type)

    return bonds, bond_types

def calc_residue_dist(residue_one, residue_two) :
    """Returns the C-alpha distance between two residues"""
    diff_vector  = residue_one["CA"].coord - residue_two["CA"].coord
    return np.sqrt(np.sum(diff_vector * diff_vector))

def calc_dist_matrix(chain_one, chain_two) :
    """Returns a matrix of C-alpha distances between two chains"""
    answer = np.zeros((len(chain_one), len(chain_two)), float)
    for row, residue_one in enumerate(chain_one) :
        for col, residue_two in enumerate(chain_two) :
            answer[row, col] = calc_residue_dist(residue_one, residue_two)
    return answer

def calc_contact_map(uniID,map_distance):
    pdb_code = uniID
    pdb_filename = uniID+"_pocket_clean.pdb"
    structure = Bio.PDB.PDBParser(QUIET = True).get_structure(pdb_code, (CFG.pdbfiles +'/'+pdb_code+'/'+pdb_filename))
    model = structure[0]
    flag1 = 0
    flag2 = 0
    idx = 0
    index = []
    chain_info = []
    
    for chain1 in model:
        for resi in chain1:
            index.append(idx)
            idx += 1
            chain_info.append([chain1.id,resi.id])
        for chain2 in model:
            if flag1 == 0:
                dist_matrix = calc_dist_matrix(model[chain1.id], model[chain2.id])
            else:
                new_matrix = calc_dist_matrix(model[chain1.id], model[chain2.id])
                dist_matrix = np.hstack((dist_matrix,new_matrix))
            flag1 += 1
        flag1 = 0
        if flag2 == 0:
            top_matrix = dist_matrix
        else:
            top_matrix = np.vstack((top_matrix,dist_matrix))
        flag2 += 1
    
    contact_map = top_matrix < map_distance
    return contact_map, index, chain_info

one_letter_to_three_letter_dict = {'G':'gly',
                                   'A':'ala',
                                   'V':'val',
                                   'C':'cys',
                                   'P':'pro',
                                   'L':'leu',
                                   'I':'ile',
                                   'M':'met',
                                   'W':'trp',
                                   'F':'phe',
                                   'K':'lys',
                                   'R':'arg',
                                   'H':'his',
                                   'S':'ser',
                                   'T':'thr',
                                   'Y':'tyr',
                                   'N':'asn',
                                   'Q':'gln',
                                   'D':'asp',
                                   'E':'glu'
    
}

def BLOSUM_encode_single(seq,AA_dict):
    allowed = set("gavcplimwfkrhstynqdeuogavcplimwfkrhstynqde")
    if not set(seq).issubset(allowed):
        invalid = set(seq) - allowed
        raise ValueError(f"Sequence has broken AA: {invalid}")
    vec = AA_dict[seq]
    return vec

matrix = bl.BLOSUM(62)
allowed_AA = "GAVCPLIMWFKRHSTYNQDE"
BLOSUM_dict_three_letter = {}
for i in allowed_AA:
    vec = []
    for j in allowed_AA:
        vec.append(matrix[i][j])
    BLOSUM_dict_three_letter.update({one_letter_to_three_letter_dict[i]:torch.Tensor(vec)})

def uniID2graph(uniID,map_distance):
    atom_name = 'CA'
    node_feature = []
    edge_index = []
    edge_attr = []
    coord = []
    contact_map, index, chain_info = calc_contact_map(uniID,map_distance)
    pdb_code = uniID
    pdb_filename = uniID+"_pocket_clean.pdb"
    structure = Bio.PDB.PDBParser(QUIET = True).get_structure(pdb_code, (CFG.pdbfiles +'/'+pdb_code+'/'+pdb_filename))
    model = structure[0]
    
    for i in index:
        node_feature.append(AA_embeddings[model[chain_info[i][0]][chain_info[i][1]].get_resname()])
        coord.append(model[chain_info[i][0]][chain_info[i][1]]['CA'].coord)
        for j in index:
            if contact_map[i,j] == 1:
                edge_index.append([i,j])
                diff_vector = model[chain_info[i][0]][chain_info[i][1]]['CA'].coord - model[chain_info[j][0]][chain_info[j][1]]['CA'].coord
                dist = (np.sqrt(np.sum(diff_vector * diff_vector))/map_distance)
                bond_type = bond_type_dict['nc']
                edge_attr.append(np.hstack((dist,bond_type)))
                            
    edge_index = np.array(edge_index)
    edge_index = edge_index.transpose()
    edge_index = torch.Tensor(edge_index)
    edge_index = edge_index.to(torch.int64)
    edge_attr = torch.Tensor(edge_attr)
    node_feature = torch.stack(node_feature)
    graph = Data(x = node_feature, edge_index = edge_index,edge_attr = edge_attr)
    return graph, coord

def read_mol2_bonds_and_atoms(mol2_file):
    bonds = []
    bond_types = []
    atom_types = {}
    atom_coordinates = {}

    with open(mol2_file, 'r') as mol2:
        reading_bonds = False
        reading_atoms = False
        for line in mol2:
            if line.strip() == '@<TRIPOS>BOND':
                reading_bonds = True
                continue
            elif line.strip() == '@<TRIPOS>ATOM':
                reading_atoms = True
                continue
            elif line.strip().startswith('@<TRIPOS>SUBSTRUCTURE'):
                break
            elif reading_bonds and line.strip().startswith('@<TRIPOS>'):
                reading_bonds = False
            elif reading_atoms and line.strip().startswith('@<TRIPOS>'):
                reading_atoms = False


            if reading_bonds:
                parts = line.split()
                if len(parts) >= 4:
                    atom1_index = int(parts[1])
                    atom2_index = int(parts[2])
                    bond_type = parts[3]
                    bonds.append((atom1_index, atom2_index))
                    bond_types.append(bond_type)

            if reading_atoms:
                parts = line.split()
                if len(parts) >= 6:
                    atom_index = int(parts[0])
                    atom_type = parts[5]
                    x, y, z = float(parts[2]), float(parts[3]), float(parts[4])
                    atom_types[atom_index] = atom_type.split('.')[0]
                    atom_coordinates[atom_index] = (x, y, z)

    return bonds, bond_types, atom_types, atom_coordinates

def molecule2graph(filename,map_distance):
    node_feature = []
    edge_index = []
    edge_attr = []
    mol2_file = CFG.pdbfiles+filename+'/'+filename+'_ligand.mol2'
    bonds, bond_types, atom_types, atom_coordinates = read_mol2_bonds_and_atoms(mol2_file)
    for atom in atom_types:
        #node_feature.append(torch.zeros(20))
        node_feature.append(torch.Tensor(atom2emb[atom_types[atom]]))
    for i in range(len(bonds)):
        bond = bonds[i]
        edge_index.append([bond[0] - 1,bond[1] - 1])
        coord1 = np.array(atom_coordinates[bond[0]])
        coord2 = np.array(atom_coordinates[bond[1]])
        dist = [np.sqrt(np.sum((coord1 - coord2)*(coord1 - coord2)))/map_distance]
        bond_type = bond_type_dict[bond_types[i]]
        edge_attr.append(np.hstack((dist,bond_type)))
    
    edge_index = np.array(edge_index)
    edge_index = edge_index.transpose()
    edge_index = torch.Tensor(edge_index)
    edge_index = edge_index.to(torch.int64)
    edge_attr = torch.Tensor(edge_attr)
    node_feature = torch.stack(node_feature)
    graph = Data(x = node_feature, edge_index = edge_index,edge_attr = edge_attr)
    
    return graph, atom_coordinates

def id2fullgraph(filename, map_distance):
    prot_graph, prot_coord = uniID2graph(filename,map_distance)
    prot_graph = prot_graph.to('cpu')
    mol_graph, mol_coord = molecule2graph(filename,map_distance)
    mol_graph = mol_graph.to('cpu')
    mol_coord = [mol_coord[i] for i in mol_coord]
    node_features = torch.cat((prot_graph.x,mol_graph.x),dim = 0)
    update_edge_index = mol_graph.edge_index + prot_graph.x.size()[0]
    edge_index = torch.cat((prot_graph.edge_index,update_edge_index), dim = 1)
    edge_attr = torch.cat((prot_graph.edge_attr,mol_graph.edge_attr), dim = 0)
    
    new_edge_index = []
    new_edge_attr = []
    for i in range(len(mol_coord)):
        for j in range(len(prot_coord)):
            dist_vec = mol_coord[i] - prot_coord[j]
            dist = np.sqrt(np.sum(dist_vec*dist_vec))/map_distance
            if dist < 1.0:
                new_edge_index.append([j,i + len(prot_coord)])
                new_edge_attr.append((np.hstack(([dist],bond_type_dict['nc']))))
                
    new_edge_index = np.array(new_edge_index)
    new_edge_index = new_edge_index.transpose()
    new_edge_index = torch.Tensor(new_edge_index)
    new_edge_index = new_edge_index.to(torch.int64)
    new_edge_attr = torch.Tensor(new_edge_attr)
    
    edge_index = torch.cat((edge_index,new_edge_index), dim = 1)
    edge_attr = torch.cat((edge_attr,new_edge_attr), dim = 0)
    
    graph = Data(x = node_features, edge_index = edge_index,edge_attr = edge_attr)
    
    return graph

def molecule2graph_AA(filename,map_distance):
    node_feature = []
    edge_index = []
    edge_attr = []
    mol2_file = filename
    bonds, bond_types, atom_types, atom_coordinates = read_mol2_bonds_and_atoms(mol2_file)
    for atom in atom_types:
        #node_feature.append(torch.zeros(20))
        node_feature.append(torch.Tensor(atom2emb[atom_types[atom]]))
    for i in range(len(bonds)):
        bond = bonds[i]
        edge_index.append([bond[0] - 1,bond[1] - 1])
        coord1 = np.array(atom_coordinates[bond[0]])
        coord2 = np.array(atom_coordinates[bond[1]])
        dist = [np.sqrt(np.sum((coord1 - coord2)*(coord1 - coord2)))/map_distance]
        bond_type = bond_type_dict[bond_types[i]]
        edge_attr.append(np.hstack((dist,bond_type)))
    
    #Master_node
    node_feature.append(torch.zeros(len(atom2emb['N'])))
    
    for i in range(len(node_feature) - 1):
        edge_index.append([i,int(len(node_feature)-1)])
        bond_type = bond_type_dict['1']
        edge_attr.append(np.hstack((1.0,bond_type)))
    
    edge_index = np.array(edge_index)
    edge_index = edge_index.transpose()
    edge_index = torch.Tensor(edge_index)
    edge_index = edge_index.to(torch.int64)
    edge_attr = torch.Tensor(edge_attr)
    node_feature = torch.stack(node_feature)
    graph = Data(x = node_feature, edge_index = edge_index,edge_attr = edge_attr)
    
    return graph, atom_coordinates

upper2lower = {
    "ala": "ALA",
    "arg": "ARG",
    "asn": "ASN",
    "asp": "ASP",
    "cys": "CYS",
    "gln": "GLN",
    "glu": "GLU",
    "gly": "GLY",
    "his": "HIS",
    "ile": "ILE",
    "leu": "LEU",
    "lys": "LYS",
    "met": "MET",
    "phe": "PHE",
    "pro": "PRO",
    "ser": "SER",
    "thr": "THR",
    "trp": "TRP",
    "tyr": "TYR",
    "val": "VAL",
}

def createMask(graph,indicies,num_masked):
    size = graph.x.size()[0]
    protein_mask = [False]*size
    true_mask = [True] * num_masked
    indicies_mask = [False]*(len(indicies) - num_masked)
    design_mask = np.hstack((true_mask,indicies_mask))
    random.shuffle(design_mask)
    
    count = 0
    for i in range(len(protein_mask)):
        if i in indicies:
            protein_mask[i] = design_mask[count]
            count += 1
            
    for i, j in enumerate(protein_mask):
        if j == 1.0:
            protein_mask[i] = True
    
    return protein_mask

In [4]:
graph_list = torch.load('graphs_w_designable_indicies_mn_11172023')

In [5]:
smallest = 12
count = 0
graph_list_clean = []
for entry in graph_list:
    if len(entry.designable_indicies) >= smallest:
        graph_list_clean.append(entry)
        count += 1

In [15]:
for i, graph in enumerate(graph_list_clean):
    graph_list_clean[i].mask = createMask(graph,graph.designable_indicies,int(len(graph.designable_indicies)))
    graph_list_clean[i].inv_mask = [not i for i in graph_list_clean[i].mask]

In [16]:
graph = random.choice(graph_list_clean)

In [18]:
print(graph.mask)
print(graph.inv_mask)

[True, False, True, False, True, False, True, True, True, True, False, True, False, False, True, True, True, True, False, False, False, False, False, True, True, True, True, False, True, False, False, False, False, True, False, True, False, True, False, True, True, True, True, True, True, False, False, False, True, True, True, True, False, False, False, False, False, False, True, False, True, True, True, False, True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False]
[False, True, False

In [40]:
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.node_feature_size = 133
        self.node_feature_hidden_size = 250
        self.node_feature_size_out = 133
        self.dropout = 0.1
        self.Droput = nn.Dropout(p = self.dropout)
        self.conv1 = GENConv(self.node_feature_size,self.node_feature_hidden_size,aggr = 'mean',edge_dim = 7,num_layer = 3,norm = 'layer',expansion = 4)
        self.conv2 = GATv2Conv(self.node_feature_hidden_size,self.node_feature_hidden_size, edge_dim = 7, heads = 7,concat = False, dropout = self.dropout)
        self.conv3 = GENConv(self.node_feature_hidden_size,self.node_feature_hidden_size,aggr = 'mean',edge_dim = 7,num_layer = 3, norm = 'layer',expansion = 4)
        self.conv4 = GATv2Conv(self.node_feature_hidden_size,self.node_feature_size_out, edge_dim = 7, heads = 7,concat = False, dropout = self.dropout)
        self.ReLu = nn.ReLU()
        self.tanh = nn.Tanh()
        
    def forward(self,graph):
        x, edge_index, edge_attr,inv_mask = graph.x,graph.edge_index,graph.edge_attr,graph.inv_mask
        x1 = self.conv1(x, edge_index,edge_attr)
        #x1[inv_mask] = x[inv_mask]
        x1 = self.ReLu(x1)
        x1 = self.conv2(x1, edge_index,edge_attr)
        #x1[inv_mask] = x[inv_mask]
        x1 = self.ReLu(x1)
        x1 = self.conv3(x1, edge_index,edge_attr)
        #x1[inv_mask] = x[inv_mask]
        x1 = self.ReLu(x1)
        x1 = self.conv4(x1, edge_index,edge_attr)
        x1 = self.tanh(x1)

        return x1

In [20]:
#for i,graph in enumerate(graphs):
#    mask1 =  createMask(graph.label,graph,1)
#    mask2 = createMask(graph.label,graph,1)
#    if mask1.all() == mask2.all():
#        mask2 = createMask(graph.label,graph,1)
#        if mask1.all() == mask2.all():
#            mask2 = createMask(graph.label,graph,1)
#    graphs[i].train_mask = mask1
#    graphs[i].val_mask = mask2
            

In [8]:
torch.save(graph_list_clean,'full_graphs_mn_rm_12.pt')

In [21]:
#def quick_split(prot_list, split_frac=0.8):
#    '''
#    Given a df of samples, randomly split indices between
#    train and test at the desired fraction
#    '''
#
#    # shuffle indices
#    idxs = list(range(len(prot_list)))
#    random.shuffle(idxs)
#
#    # split shuffled index list by split_frac
#    train_idxs = idxs[:split]
#    test_idxs = idxs[split:]
#    
#    # split dfs and return
#    train_data = [prot_list[i] for i in train_idxs]
#    test_data = [prot_list[i] for i in test_idxs]
#        
#    return train_data, test_data, train_idxs, test_idxs
    
    
#full_train_data, test_data,train_idxs, test_idxs = quick_split(graph_list_clean)
#train_data, val_data,train_idxs, test_idxs = quick_split(graph_list_clean)

#print("Train:", len(train_data))
#print("Val:", len(val_data))
#print("Test:", len(test_data))

In [22]:
#idxs = [torch.Tensor(train_idxs),torch.Tensor(test_idxs)]
#torch.save(idxs, 'initial_training_idx_11092023.pt')

In [23]:
idxs = torch.load('initial_training_idx_11092023.pt')
train_idxs = idxs[0].detach().numpy()
val_idxs = idxs[1].detach().numpy()
train_data = [graph_list_clean[int(i)] for i in train_idxs]
val_data = [graph_list_clean[int(i)] for i in val_idxs]

In [24]:
from torch_geometric.loader import DataLoader
full_dl = DataLoader(graph_list,batch_size = 1, shuffle = True)
train_dl = DataLoader(train_data,batch_size = 1, shuffle = True)
val_dl = DataLoader(val_data,batch_size = 1, shuffle = True)

In [25]:
graph = random.choice(graph_list)
mask = createMask(graph,graph.designable_indicies,int(len(graph.designable_indicies)))

In [26]:
truth = graph.x[mask]
with torch.no_grad():
    for i, j in enumerate(mask):
        if j == True:
            graph.x[i] = torch.zeros(133)

In [27]:
print(mask)
print(graph.x[mask])

[False, False, False, True, False, False, False, False, True, True, False, False, True, False, False, False, False, True, True, True, False, True, True, False, False, False, False, True, True, True, True, False, False, False, False, True, False, True, False, False, False, True, True, False, True, False, False, False, False, True, False, False, True, True, True, True, True, False, True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False]
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], grad_fn=<IndexBackward0>)


In [None]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = Net()
#model.load_state_dict(torch.load('large_model_ckpt_layer_norm1.pt'))
model.to(DEVICE) # put on GPU

# Define a loss function (e.g., Mean Squared Error) and an optimizer (e.g., Adam)
criterion = loss_func = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5, weight_decay=1e-07)
#optimizer.load_state_dict(torch.load('large_model_ckpt_opt_layer_norm1.pt'))

# Training loop
num_epochs = 10000  # Adjust the number of epochs as needed
losses = []

for epoch in range(num_epochs):
    total_loss = 0.0
    val_loss = 0.0
        
    for batch in train_dl:
        model.train()
        inputs = batch[0].to(DEVICE)
        
        inputs.mask = createMask(inputs,inputs.designable_indicies,12)
        #mask = createMask(inputs,inputs.designable_indicies,int(len(inputs.designable_indicies)))
        
        # Zero the parameter gradients
        optimizer.zero_grad()
        
        # Forward pass
        truth = inputs.x[inputs.mask]
        with torch.no_grad():
            for i, j in enumerate(inputs.mask):
                if j == True:
                    inputs.x[i] = torch.zeros(133)
        
        outputs = model(inputs)
        
        # Compute the loss
        loss = criterion(outputs[inputs.mask], truth)
        
        # Backpropagation and optimization
        loss.backward()
        optimizer.step()
        
        inputs= inputs.to('cpu')
        
        total_loss += loss.item()
        
    for batch in val_dl:
        with torch.no_grad():
            model.eval()
            inputs = batch[0].to(DEVICE)
            
            inputs.mask = createMask(inputs,inputs.designable_indicies,12)
            #mask = createMask(inputs,inputs.designable_indicies,int(len(inputs.designable_indicies)))
        
            truth = inputs.x[inputs.mask]
            for i, j in enumerate(inputs.mask):
                if j == True:
                    inputs.x[i] = torch.zeros(133)
                    
            outputs = model(inputs)
            
            loss = criterion(outputs[inputs.mask], truth)
        
            inputs= inputs.to('cpu')
        
        val_loss += loss.item()
    
    # Print the average loss for this epoch
    avg_loss = total_loss / len(train_dl)
    avg_val_loss = val_loss / len(val_dl)
    print(f'Epoch [{epoch+1}/{num_epochs}] Loss: {avg_loss:.4f} Val Loss: {avg_val_loss:.4f}')
    losses.append([avg_loss,avg_val_loss])
    
    #if epoch > 100:
    #    torch.save(autoencoder,('autoencoder_98var_10062023_'+str(epoch)))

print('Training complete')


  truth = inputs.x[inputs.mask]
  loss = criterion(outputs[inputs.mask], truth)
  truth = inputs.x[inputs.mask]
  loss = criterion(outputs[inputs.mask], truth)


Epoch [1/10000] Loss: 0.4327 Val Loss: 0.4286
Epoch [2/10000] Loss: 0.4281 Val Loss: 0.4240
Epoch [3/10000] Loss: 0.4221 Val Loss: 0.4183
Epoch [4/10000] Loss: 0.4164 Val Loss: 0.4199
Epoch [5/10000] Loss: 0.4092 Val Loss: 0.4039
Epoch [6/10000] Loss: 0.4033 Val Loss: 0.4019
Epoch [7/10000] Loss: 0.3982 Val Loss: 0.3966
Epoch [8/10000] Loss: 0.3943 Val Loss: 0.3939
Epoch [9/10000] Loss: 0.3918 Val Loss: 0.3939
Epoch [10/10000] Loss: 0.3887 Val Loss: 0.3883
Epoch [11/10000] Loss: 0.3859 Val Loss: 0.3846
Epoch [12/10000] Loss: 0.3826 Val Loss: 0.3835
Epoch [13/10000] Loss: 0.3805 Val Loss: 0.3793
Epoch [14/10000] Loss: 0.3770 Val Loss: 0.3779
Epoch [15/10000] Loss: 0.3746 Val Loss: 0.3771
Epoch [16/10000] Loss: 0.3725 Val Loss: 0.3767
Epoch [17/10000] Loss: 0.3706 Val Loss: 0.3728
Epoch [18/10000] Loss: 0.3689 Val Loss: 0.3688
Epoch [19/10000] Loss: 0.3661 Val Loss: 0.3664
Epoch [20/10000] Loss: 0.3637 Val Loss: 0.3673
Epoch [21/10000] Loss: 0.3616 Val Loss: 0.3641
Epoch [22/10000] Loss:

Epoch [174/10000] Loss: 0.2028 Val Loss: 0.2448
Epoch [175/10000] Loss: 0.2028 Val Loss: 0.2436
Epoch [176/10000] Loss: 0.2025 Val Loss: 0.2445
Epoch [177/10000] Loss: 0.2010 Val Loss: 0.2455
Epoch [178/10000] Loss: 0.2015 Val Loss: 0.2432
Epoch [179/10000] Loss: 0.2013 Val Loss: 0.2413
Epoch [180/10000] Loss: 0.1993 Val Loss: 0.2428
Epoch [181/10000] Loss: 0.1983 Val Loss: 0.2423
Epoch [182/10000] Loss: 0.1982 Val Loss: 0.2441
Epoch [183/10000] Loss: 0.1974 Val Loss: 0.2411
Epoch [184/10000] Loss: 0.1974 Val Loss: 0.2436
Epoch [185/10000] Loss: 0.1975 Val Loss: 0.2391
Epoch [186/10000] Loss: 0.1966 Val Loss: 0.2447
Epoch [187/10000] Loss: 0.1971 Val Loss: 0.2417
Epoch [188/10000] Loss: 0.1950 Val Loss: 0.2416
Epoch [189/10000] Loss: 0.1956 Val Loss: 0.2430
Epoch [190/10000] Loss: 0.1947 Val Loss: 0.2385
Epoch [191/10000] Loss: 0.1959 Val Loss: 0.2405
Epoch [192/10000] Loss: 0.1943 Val Loss: 0.2436
Epoch [193/10000] Loss: 0.1939 Val Loss: 0.2412
Epoch [194/10000] Loss: 0.1930 Val Loss:

Epoch [345/10000] Loss: 0.1462 Val Loss: 0.2168
Epoch [346/10000] Loss: 0.1452 Val Loss: 0.2127
Epoch [347/10000] Loss: 0.1438 Val Loss: 0.2152
Epoch [348/10000] Loss: 0.1456 Val Loss: 0.2099
Epoch [349/10000] Loss: 0.1443 Val Loss: 0.2093
Epoch [350/10000] Loss: 0.1436 Val Loss: 0.2134
Epoch [351/10000] Loss: 0.1451 Val Loss: 0.2140


In [72]:
#torch.save(model.state_dict(), 'flexible_number_of_binding_sites_11172023.pt')
#torch.save(optimizer.state_dict(),'flexible_number_of_binding_sites_opt_11172023.pt')

In [73]:
#torch.save(torch.Tensor(losses),'flexible_number_of_binding_sites_loss_11172023.pt')