In [None]:
# Get unscaled_features from SMILES: get_unscaled_features(SMILES, DESCRIPTORS) in scale_graph_features

In [6]:
%load_ext autoreload
%autoreload 2
    
import pandas as pd
import numpy as np
import json
import gc
import re
import dgl
import torch


from rdkit.Chem import rdMolDescriptors, Descriptors
from rdkit import Chem

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
# Architecture name. seq_to_dgl reads this global and adds self-loops only
# when it is 'GCN' or 'GAT'.
MODEL = 'MPNN'

# Input files handed to get_unscaled_features: SMILES is read with
# pd.read_csv, DESCRIPTORS with pd.read_json.
SMILES = '../data_preprocessing/SMILES.txt'
DESCRIPTORS = '../monomer_data/unique_descriptors.json'

# Full table; a later cell filters it on the 'ID' column and uses 'sequence'.
df = pd.read_csv('../data_preprocessing/db.csv')

# Nested dict {type: {molecule: feature array}} that seq_to_dgl reads as a global.
# NOTE: get_unscaled_features is defined in a cell further down this notebook,
# so that cell has to be executed before this one.
unscaled_feats = get_unscaled_features(SMILES, DESCRIPTORS)

In [8]:
# Select the rows whose ID contains 'poly'; rows with a missing ID are
# treated as non-matching (na=False).
poly_mask = df['ID'].str.contains('poly', na=False)
infer_set = df[poly_mask]
infer_set.head()

Unnamed: 0,sequence,MIC_ecoli,ID,binary_class,3_classes,4_classes,5_classes,6_classes,7_classes,8_classes,9_classes,10_classes
11003,NiMoNiTmaTmaNiTmaNiNiNiMoNiTmaMoTmaMoTmaTmaMoN...,1024.0,polyID1_S1,0,2,3,4,,6,,,9
11004,NiTmaTmaTmaTmaMoTmaNiNiNiNiMoNiMoNiNiNiNiTmaTm...,1024.0,polyID1_S2,0,2,3,4,,6,,,9
11005,TmaNiNiTmaNiNiTmaTmaNiTmaNiNiNiTmaTmaTmaTmaTma...,1024.0,polyID1_S3,0,2,3,4,,6,,,9
11006,TmaMoNiTmaNiTmaMoNiMoTmaTmaNiTmaNiTmaTmaNiNiNi...,1024.0,polyID1_S4,0,2,3,4,,6,,,9
11007,TmaNiNiNiNiNiTmaNiTmaNiTmaTmaMoTmaTmaTmaTmaTma...,1024.0,polyID1_S5,0,2,3,4,,6,,,9


In [9]:
# Directory of the trained model; this cell only reads configure.json from it.
MODEL_PATH = 'model/'
# Forwarded to the DataLoader as num_workers.
NUM_WORKERS = 0

# Open the config with a context manager so the file handle is closed as soon
# as the JSON is parsed instead of being left to the garbage collector.
with open(MODEL_PATH + 'configure.json') as config_file:
    hyperparameters = json.load(config_file)

# One (ID, graph) pair per sequence -- the sample layout collate_molgraphs unpacks.
data = list(zip(infer_set["ID"], infer_set["sequence"].apply(seq_to_dgl)))

# Reuse the notebook's own helper rather than repeating the DataLoader
# arguments inline; it builds the same loader (same collate_fn, shuffle=True).
# NOTE(review): batches are therefore shuffled during inference. Each batch
# carries its IDs (see collate_molgraphs), so predictions stay attributable,
# but their order is not the order of infer_set.
data_loader = get_dataloader(
    DATA=data,
    BATCH_SIZE=hyperparameters['batch_size'],
    NUM_WORKERS=NUM_WORKERS,
)

In [15]:
from utils.infer import infer

# Run inference: the sequences come from a DataFrame with 'sequence' and 'ID'
# columns (infer_set), already converted to batches by data_loader.
# NOTE(review): the recorded run of this cell failed with
# "ModuleNotFoundError: No module named 'utils.infer'", so the constructor
# arguments and whatever predict() returns are unverified here -- TODO confirm
# once the module is importable.

a = infer(GPU = 0, HYPERPARAMETERS = hyperparameters, MODEL_PATH = MODEL_PATH)
a.predict(data_loader)

ModuleNotFoundError: No module named 'utils.infer'

In [5]:
from torch.utils.data import DataLoader


def collate_molgraphs(data: list[tuple[str, dgl.DGLGraph]]):
    """
    Merge a list of (ID, graph) samples into a single batch.

    Args:
        data: Samples as (ID, graph) tuples.

    Returns:
        A tuple ``(IDs, bg)``: the sample IDs as a list, in sample order,
        and the graphs combined with ``dgl.batch``.
    """
    # Transpose [(id, graph), ...] into one tuple of IDs and one of graphs.
    id_column, graph_column = zip(*data)
    IDs = list(id_column)

    bg = dgl.batch(list(graph_column))

    # Register zero initializers for node and edge features on the batch.
    bg.set_n_initializer(dgl.init.zero_initializer)
    bg.set_e_initializer(dgl.init.zero_initializer)

    return IDs, bg


def get_dataloader(DATA, BATCH_SIZE, NUM_WORKERS, SHUFFLE=True) -> DataLoader:
    """
    Build a DataLoader over (ID, graph) samples batched by collate_molgraphs.

    Args:
        DATA: Dataset of (ID, graph) tuples, e.g. a list.
        BATCH_SIZE: Number of samples per batch.
        NUM_WORKERS: Number of worker processes used for loading.
        SHUFFLE: Whether to reshuffle the samples on every pass. Defaults to
            True, which is what this helper always did; pass False to get
            batches in dataset order (e.g. for inference).

    Returns:
        The configured DataLoader.
    """
    return DataLoader(
        dataset=DATA,
        batch_size=BATCH_SIZE,
        shuffle=SHUFFLE,
        collate_fn=collate_molgraphs,
        num_workers=NUM_WORKERS,
    )

def get_unscaled_features(SMILES, DESCRIPTORS):
    """
    Compute the selected RDKit descriptors for every molecule in a SMILES table.

    Args:
        SMILES: Path of a CSV file with the columns 'type', 'molecule' and
            'SMILES'.
        DESCRIPTORS: Path of a JSON file; the first record of the frame that
            pd.read_json builds from it is used as a mapping from each type
            to the descriptor names to keep for that type.

    Returns:
        A nested dict ``{type: {molecule: np.ndarray}}`` in which each array
        holds the kept descriptor values in the order they are listed.
        Individual descriptors that fail to compute are stored as -9999.

    Raises:
        ValueError: If a SMILES string cannot be parsed by RDKit. Without
            this check the unparsable molecule would reach
            CalcMolDescriptors as None and, because descriptor errors are
            silenced, come back with every feature set to -9999.
    """
    df_smiles = pd.read_csv(SMILES)
    descriptors_to_keep = pd.read_json(DESCRIPTORS).to_dict(orient='records')[0]

    unscaled_feats = {}

    for _type in df_smiles['type'].unique():
        df_type = df_smiles[df_smiles['type'] == _type]

        feats_dict = {}
        for molecule, smiles in zip(df_type['molecule'], df_type['SMILES']):
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                # MolFromSmiles signals a parse failure by returning None.
                raise ValueError(
                    f"Could not parse SMILES {smiles!r} for {_type} {molecule!r}"
                )

            all_descriptors = Descriptors.CalcMolDescriptors(
                mol, missingVal=-9999, silent=True
            )
            feats_dict[molecule] = np.array(
                [all_descriptors[key] for key in descriptors_to_keep[_type]]
            )

        unscaled_feats[_type] = feats_dict

    return unscaled_feats

def seq_to_dgl(sequence):
    """
    Convert a monomer sequence string into a linear DGL graph.

    Monomer tokens are extracted with the pattern ``[A-Z][a-z]+``. Each token
    becomes one node, and consecutive nodes are joined by a directed edge
    (0->1, 1->2, ...). Node features are stored under ``ndata["h"]`` and edge
    features under ``edata["e"]``.

    Reads two notebook globals: ``unscaled_feats`` (node features from its
    "monomer" entry, the shared edge feature from ``["bond"]["Cc"]``) and
    ``MODEL`` (self-loops are added when it is "GCN" or "GAT").

    Args:
        sequence: String of concatenated monomer names.

    Returns:
        The featurized dgl graph.

    Raises:
        ValueError: If no monomer token is found in ``sequence``.
        KeyError: If a monomer has no entry in ``unscaled_feats["monomer"]``.
    """
    monomers = re.findall(r"[A-Z][a-z]+", sequence)
    if not monomers:
        # Fail with a readable message instead of torch.stack's error on an
        # empty list further down.
        raise ValueError(f"No monomer tokens found in sequence {sequence!r}")

    # Edges are between sequential monomers, i.e., (0->1, 1->2, etc.)
    src_nodes = list(range(len(monomers) - 1))  # Start nodes of edges
    dst_nodes = list(range(1, len(monomers)))  # End nodes of edges
    g = dgl.graph((src_nodes, dst_nodes), num_nodes=len(monomers))

    # Featurize nodes
    monomer_feats = unscaled_feats["monomer"]
    g.ndata["h"] = torch.stack(
        [
            torch.tensor(monomer_feats[monomer], dtype=torch.float32)
            for monomer in monomers
        ]
    )

    # Featurize edges: every edge carries the same "Cc" bond feature.
    bond_feat = torch.tensor(unscaled_feats["bond"]["Cc"], dtype=torch.float32)
    n_edges = len(src_nodes)
    if n_edges:
        g.edata["e"] = torch.stack([bond_feat] * n_edges)
    else:
        # A single-monomer sequence has no edges and torch.stack rejects an
        # empty list, so store an empty tensor with the matching feature shape.
        g.edata["e"] = bond_feat.new_empty((0, *bond_feat.shape))

    if MODEL == "GCN" or MODEL == "GAT":
        g = dgl.add_self_loop(g)

    return g