In [1]:
# This notebook builds a NetworkX graph object directly from the xyz coordinates + tresp values.
import re
from itertools import combinations
from math import sqrt
import math
import pandas as pd
import networkx as nx
import numpy as np
import torch
from torch_geometric.data import DataLoader
from torch.autograd import Variable
from torch.utils.data.sampler import SubsetRandomSampler
from my_dataset_4cluster_2 import *
from torch_geometric import utils
from torch_geometric.data import Data

In [2]:
# Display the first item of a freshly constructed My_dataset().
My_dataset()[0]

Data(X=[508, 6], edge_attr=[43928], edge_index=[2, 43928], y=470.0)

In [3]:
def split_structure_connect(panda_df):
    """
    Split a parsed frame into its structure rows and its CONNECT rows.

    The first column of ``panda_df`` is scanned for the marker strings
    ``"TER"`` and ``"END"``. Rows before the last ``"TER"`` row form the
    structure part; rows strictly between that ``"TER"`` row and the last
    ``"END"`` row form the CONNECT part.

    Markers are located by row *position* (not by index label), so the
    result is correct for any index, and no label-based ``row[0]`` lookup
    is needed.

    Parameters
    ----------
    panda_df : pandas.DataFrame
        Frame whose first column holds the row tag.

    Returns
    -------
    tuple of (pandas.DataFrame, pandas.DataFrame)
        ``(structure_rows, connect_rows)``.

    Notes
    -----
    A marker that is not found is treated as being at position 0, exactly
    as before: without ``"TER"`` both parts are empty, and without
    ``"END"`` the CONNECT part is empty. No error is raised in that case.
    """
    connect_start, connect_end = 0, 0
    if panda_df.shape[1] > 0:
        first_column = panda_df.iloc[:, 0]
        ter_positions = np.flatnonzero((first_column == "TER").to_numpy())
        end_positions = np.flatnonzero((first_column == "END").to_numpy())
        # keep the last occurrence, like the original row-by-row scan did
        if ter_positions.size:
            connect_start = int(ter_positions[-1])
        if end_positions.size:
            connect_end = int(end_positions[-1])
    return (panda_df.iloc[0:connect_start], panda_df.iloc[connect_start + 1:connect_end])

def read_file(file_path):
    """
    Read one whitespace-delimited frame file and split it into two frames.

    The file is parsed without a header into six columns named
    ``'a'`` .. ``'f'``; only empty fields are treated as missing values.
    The parsed frame is then split by ``split_structure_connect``.

    Parameters
    ----------
    file_path : str or path-like
        Location of the frame file; it is also echoed to stdout.

    Returns
    -------
    tuple of (pandas.DataFrame, pandas.DataFrame)
        ``(data_structure, data_CONNECT)``. Column ``'f'`` is dropped from
        the structure part; the CONNECT part keeps all six columns.
    """
    print('reading data from path:')
    print(str(file_path))
    # sep=r'\s+' is the documented equivalent of the deprecated
    # delim_whitespace=True.
    data = pd.read_csv(file_path,
                       header=None,
                       sep=r'\s+',
                       keep_default_na=False,
                       na_values=[''],
                       names=['a', 'b', 'c', 'd', 'e', "f"],
                       engine='python')
    data_structure, data_CONNECT = split_structure_connect(data)
    data_structure = data_structure.drop(labels='f', axis=1)

    return (data_structure, data_CONNECT)

def node_gen(file_path):
    """
    Create a graph that contains one node per structure row and no edges.

    Nodes are numbered from 1 in row order. Each node stores:

    * ``atom_x``, ``atom_y``, ``atom_z`` -- the three coordinate columns,
      all converted to ``float`` (previously only ``atom_y`` was converted);
    * ``atom_tresp`` -- the fifth column, stored as read;
    * ``atom_type`` -- the first character of the first column, with
      ``'M'`` expanded to ``'Mg'``.

    Parameters
    ----------
    file_path : str or path-like
        Frame file, read through ``read_file``.

    Returns
    -------
    networkx.Graph
    """
    G = nx.Graph()
    structure_rows = read_file(file_path)[0]
    for i, (element, x, y, z, tresp) in enumerate(structure_rows.values):
        # only the first character of the name is kept as the element symbol
        element = element[0]
        if element == 'M':
            element = 'Mg'
        G.add_node(i + 1,
                   atom_x=float(x),
                   atom_y=float(y),
                   atom_z=float(z),
                   atom_tresp=tresp,
                   atom_type=element)
    return G

def list_bond_gen(file_path):
    """
    Collect the bonds described by the CONNECT rows of a frame file.

    For every CONNECT row the first column (the row tag) is skipped and
    entries whose string form is ``"nan"`` are ignored. The first remaining
    entry is the starting atom; each later entry yields one
    ``[start_atom, connected_atom]`` pair.

    Parameters
    ----------
    file_path : str or path-like
        Frame file, read through ``read_file``.

    Returns
    -------
    list of [int, int]
    """
    connect_rows = read_file(file_path)[1]
    bonds = []
    for _, record in connect_rows.iterrows():
        # drop the leading tag, then keep only the filled-in atom numbers
        atoms = [int(entry) for entry in record[1:] if str(entry) != "nan"]
        bonds.extend([atoms[0], partner] for partner in atoms[1:])
    return (bonds)

def covalent_bond_gen(G, list_bond):
    """
    Add every bond in ``list_bond`` to ``G`` as an edge.

    ``G`` is modified in place and the same object is returned.
    """
    G.add_edges_from(list_bond)
    return G
    

def covalent_graph_gen(file_path):
    """
    Build the graph of one frame file: nodes first, then the CONNECT bonds.

    Returns the graph created by ``node_gen`` after ``covalent_bond_gen``
    has added the edges listed by ``list_bond_gen``.
    """
    graph = node_gen(file_path)
    bonds = list_bond_gen(file_path)
    covalent_bond_gen(graph, bonds)
    return graph

def combine_graph(G_local, H_local):
    """
    Merge two graphs into one, shifting the labels of ``H_local``.

    ``H_local`` is expected to have nodes labelled ``1 .. len(H_local)``.
    Each of its nodes is copied under the label
    ``old_label + len(G_local.nodes)``, so it cannot collide with a node of
    ``G_local`` when that graph is labelled ``1 .. len(G_local)``. Node and
    edge attributes of ``H_local`` are both carried over.

    The shifted edges are added once, after all nodes. Previously they
    were re-added on every pass of the node loop, which repeated the same
    work per node and made the node insertion order follow the edge
    traversal. Nodes of ``H_local`` are now inserted in ascending label
    order.

    Parameters
    ----------
    G_local, H_local : networkx.Graph
        Neither input is modified.

    Returns
    -------
    networkx.Graph
        ``nx.compose`` of ``G_local`` and the relabelled copy of ``H_local``.
    """
    offset = len(G_local.nodes)
    H_new = nx.Graph()  # relabelled copy of H_local

    # nodes of H, keeping their attribute dictionaries
    H_new.add_nodes_from(
        (i + offset + 1, H_local.nodes[i + 1])
        for i in range(len(H_local.nodes))
    )
    # edges of H, keeping their attribute dictionaries
    H_new.add_edges_from(
        (start + offset, end + offset, dict(attrs))
        for start, end, attrs in H_local.edges(data=True)
    )

    U = nx.compose(G_local, H_new)
    return U

def generate_4_cluster_graph(list_file_path):
    """
    Build one combined graph from the first four frame files given.

    Steps:

    1. one covalent graph per file (``covalent_graph_gen``);
    2. a ``molname`` node attribute: 1 for the first three graphs, 2 for
       the fourth;
    3. the four graphs are merged one after another with ``combine_graph``;
    4. every edge of the merged graph gets ``bond_weight = 1``.

    The per-graph and combined node counts are printed along the way.

    Parameters
    ----------
    list_file_path : sequence of str
        Entries 0..3 are used.

    Returns
    -------
    networkx.Graph
    """
    # first we generate the 4 separate graphs.
    graphs = [covalent_graph_gen(list_file_path[k]) for k in range(4)]

    # here we set up the node molname attribute:
    for graph, mol_label in zip(graphs, (1, 1, 1, 2)):
        graph.add_nodes_from((graph.nodes), molname = mol_label)

    print("The node number for each graph are:")
    print(*(len(graph.nodes) for graph in graphs))

    # here we combine the graphs into a giant graph.
    U = graphs[0]
    for graph in graphs[1:]:
        U = combine_graph(U, graph)

    print('The combined graph node number is: %d' %(len(U.nodes)))

    # here we set up the edge bond_weight attribute
    U.add_edges_from((U.edges), bond_weight = 1)

    return U

In [4]:
#file_path = ['./data/CLA610/frame0.csv']
#file_path_2 = ['./data/LUT620/frame1.csv']
#G = covalent_graph_gen(file_path)
#H = covalent_graph_gen(file_path_2)

In [5]:
#U = combine_graph(G,H)

In [6]:
def new_edges(list_edges, len_other_graph):
    """
    Shift both endpoints of every edge by ``len_other_graph``.

    Only the first two items of each edge are used; the result is a new
    list of ``(start, end)`` tuples in the same order as ``list_edges``.
    """
    return [
        (edge[0] + len_other_graph, edge[1] + len_other_graph)
        for edge in list_edges
    ]

In [7]:
# Earlier data location, kept for reference only (it used to be assigned
# here and was then immediately overwritten by the list below):
#   ./data/CLA610/frame0.csv, ./data/CLA611/frame0.csv,
#   ./data/CLA612/frame0.csv, ./data/LUT620/frame0.csv

# frame1 files of the four pigments from the 7_1_1 data set
list_file_path = ['./data/ThreeCopies/7_1_1_conect_lifetimes/CLA610/frame1.csv',
                  './data/ThreeCopies/7_1_1_conect_lifetimes/CLA611/frame1.csv',
                  './data/ThreeCopies/7_1_1_conect_lifetimes/CLA612/frame1.csv',
                  './data/ThreeCopies/7_1_1_conect_lifetimes/LUT620/frame1.csv']
# NOTE(review): not referenced by the cells shown in this notebook; the graph
# builders hard-code bond_weight = 1 -- confirm before removing.
covalent_bond_weight = 1

In [8]:
# Build the combined graph for the four frame files in list_file_path.
U_test = generate_4_cluster_graph(list_file_path)

reading data from path:
./data/ThreeCopies/7_1_1_conect_lifetimes/CLA610/frame1.csv
reading data from path:
./data/ThreeCopies/7_1_1_conect_lifetimes/CLA610/frame1.csv
reading data from path:
./data/ThreeCopies/7_1_1_conect_lifetimes/CLA611/frame1.csv
reading data from path:
./data/ThreeCopies/7_1_1_conect_lifetimes/CLA611/frame1.csv
reading data from path:
./data/ThreeCopies/7_1_1_conect_lifetimes/CLA612/frame1.csv
reading data from path:
./data/ThreeCopies/7_1_1_conect_lifetimes/CLA612/frame1.csv
reading data from path:
./data/ThreeCopies/7_1_1_conect_lifetimes/LUT620/frame1.csv
reading data from path:
./data/ThreeCopies/7_1_1_conect_lifetimes/LUT620/frame1.csv
The node number for each graph are:
136 137 137 98
The combined graph node number is: 508


In [9]:
# Print the first (label, attribute-dict) pair of the combined graph and its
# 'atom_y' value; the loop stops after that single node.
for node in U_test.nodes.items():
    print(node)
    print(node[1]['atom_y'])
    break

(1, {'atom_x': 83.213, 'atom_y': 65.916, 'atom_z': 29.292, 'atom_tresp': -0.000723, 'atom_type': 'C', 'molname': 1})
65.916


In [10]:
import math
def adding_intermol_edge(Graph, distance_cutoff, inter_mol_K):
    """
    Return a copy of ``Graph`` with distance-based edges added.

    Every pair of nodes whose Euclidean distance (from the node attributes
    ``atom_x``, ``atom_y``, ``atom_z``) is strictly below
    ``distance_cutoff`` is linked with ``bond_weight = inter_mol_K``. This
    applies to *all* pairs, not only pairs from different molecules.
    Afterwards every edge that already existed in ``Graph`` is set back to
    ``bond_weight = 1``.

    The pairwise distances are computed with numpy in one step instead of a
    Python double loop; pairs are still emitted in the original order (outer
    node, then inner node, keeping only ``label_1 < label_2``).

    Parameters
    ----------
    Graph : networkx.Graph
        Input graph; it is not modified.
    distance_cutoff : float
        Upper bound (exclusive) on the distance of linked pairs.
    inter_mol_K : float
        ``bond_weight`` assigned to the distance-based edges.

    Returns
    -------
    networkx.Graph
        The augmented copy. The number of pairs found is printed.
    """
    # gather labels and coordinates once (coordinates may be stored as text)
    labels = []
    coords = []
    for label, attrs in Graph.nodes.items():
        labels.append(label)
        coords.append((float(attrs['atom_x']),
                       float(attrs['atom_y']),
                       float(attrs['atom_z'])))

    inter_mol_edge_list = []
    n_nodes = len(labels)
    if n_nodes:
        xyz = np.asarray(coords, dtype=float)
        dx = xyz[:, 0][:, None] - xyz[:, 0][None, :]
        dy = xyz[:, 1][:, None] - xyz[:, 1][None, :]
        dz = xyz[:, 2][:, None] - xyz[:, 2][None, :]
        distance = np.sqrt(dx * dx + dy * dy + dz * dz)

        # Node labels are unique, so comparing their sort ranks is the same
        # as comparing the labels themselves (label_1 < label_2).
        order = sorted(range(n_nodes), key=labels.__getitem__)
        rank = np.empty(n_nodes, dtype=np.intp)
        rank[order] = np.arange(n_nodes)

        keep = (distance < distance_cutoff) & (rank[:, None] < rank[None, :])
        # np.nonzero walks the mask row by row, i.e. in the same order as
        # the nested loops over the nodes did.
        first, second = np.nonzero(keep)
        inter_mol_edge_list = [[labels[i], labels[j]]
                               for i, j in zip(first.tolist(), second.tolist())]

    print('within the cutoff, the inter mol are:')
    print(len(inter_mol_edge_list))

    # now that all close pairs are known, give them an edge and bond_weight
    Graph_new = Graph.copy()

    Graph_new.add_edges_from(inter_mol_edge_list, bond_weight = inter_mol_K)
    Graph_new.add_edges_from((Graph.edges), bond_weight = 1) #make the original covalent bond to 1.

    return Graph_new

In [11]:
# Add distance-based edges: cutoff 7.5, bond_weight 0.2 for the new edges.
U_test2 = adding_intermol_edge(U_test, 7.5, 0.2)

within the cutoff, the inter mol are:
21303


In [12]:
# Node count and edge count of the graph with the distance-based edges.
print(len(U_test2.nodes), len(U_test2.edges))

508 21303


In [13]:
# List every edge together with its attribute dictionary (long output).
list(U_test2.edges(data= True))

[(1, 4, {'bond_weight': 0.2}),
 (1, 5, {'bond_weight': 0.2}),
 (1, 6, {'bond_weight': 0.2}),
 (1, 9, {'bond_weight': 0.2}),
 (1, 10, {'bond_weight': 0.2}),
 (1, 11, {'bond_weight': 0.2}),
 (1, 12, {'bond_weight': 0.2}),
 (1, 13, {'bond_weight': 0.2}),
 (1, 14, {'bond_weight': 0.2}),
 (1, 15, {'bond_weight': 0.2}),
 (1, 16, {'bond_weight': 0.2}),
 (1, 17, {'bond_weight': 0.2}),
 (1, 20, {'bond_weight': 0.2}),
 (1, 21, {'bond_weight': 0.2}),
 (1, 24, {'bond_weight': 0.2}),
 (1, 25, {'bond_weight': 0.2}),
 (1, 26, {'bond_weight': 0.2}),
 (1, 28, {'bond_weight': 0.2}),
 (1, 29, {'bond_weight': 0.2}),
 (1, 30, {'bond_weight': 0.2}),
 (1, 31, {'bond_weight': 0.2}),
 (1, 32, {'bond_weight': 0.2}),
 (1, 33, {'bond_weight': 0.2}),
 (1, 35, {'bond_weight': 0.2}),
 (1, 36, {'bond_weight': 0.2}),
 (1, 37, {'bond_weight': 0.2}),
 (1, 38, {'bond_weight': 0.2}),
 (1, 39, {'bond_weight': 0.2}),
 (1, 42, {'bond_weight': 0.2}),
 (1, 43, {'bond_weight': 0.2}),
 (1, 44, {'bond_weight': 0.2}),
 (1, 45, {'b

In [None]:
# Here we generate the graph for each snapshot and save it to its own file.
# This can reduce the dataloader time.

MDset_list = ["./data/ThreeCopies/7_1_1_conect_lifetimes/",
              "./data/ThreeCopies/7_2_1_conect_lifetimes/",
              "./data/ThreeCopies/7_3_1_conect_lifetimes/"]

pigment_dir_list = ['CLA610/',
            'CLA611/',
            'CLA612/',
            'LUT620/']

label_path = "./data/ThreeCopies/ThreeCopies_all_lifetimes.csv"

distance_cutoff = 7.5
weak_bond_weight = 0.2 #the rigid covalent bond weight is 1

frames_per_set = 1000  # frame1.csv .. frame1000.csv are read from every MD set
save_path = './data/ThreeCopies/data/'

# Everything is stored together, one file per graph, numbered consecutively
# across the MD sets: set 0 -> 0..999, set 1 -> 1000..1999, set 2 -> 2000..2999.
# Each graph is written as soon as it is built, so a full set never has to
# be held in memory.
for num, MDset_path in enumerate(MDset_list):
    print("processing MD data set:")
    print(MDset_path)
    for index in range(frames_per_set):
        # here we define the work paths for the 4 pigments (files are 1-based).
        work_path = [MDset_path + pigment_dir + 'frame' + str(index + 1) + '.csv'
                     for pigment_dir in pigment_dir_list]
        G = generate_4_cluster_graph(work_path)
        G = adding_intermol_edge(G, distance_cutoff, weak_bond_weight)
        data = utils.from_networkx(G)

        print('Graph generation for frame %d has been done.' %index)

        idx = num * frames_per_set + index
        torch.save(data, save_path + 'data_as_graphs_25062021_frame' + str(idx) + '.pt')
    # reported once per MD set (it used to be printed once per saved file)
    print("data set No.%d saved" %num)


processing MD data set:
./data/ThreeCopies/7_1_1_conect_lifetimes/
reading data from path:
./data/ThreeCopies/7_1_1_conect_lifetimes/CLA610/frame1.csv
reading data from path:
./data/ThreeCopies/7_1_1_conect_lifetimes/CLA610/frame1.csv
reading data from path:
./data/ThreeCopies/7_1_1_conect_lifetimes/CLA611/frame1.csv
reading data from path:
./data/ThreeCopies/7_1_1_conect_lifetimes/CLA611/frame1.csv
reading data from path:
./data/ThreeCopies/7_1_1_conect_lifetimes/CLA612/frame1.csv
reading data from path:
./data/ThreeCopies/7_1_1_conect_lifetimes/CLA612/frame1.csv
reading data from path:
./data/ThreeCopies/7_1_1_conect_lifetimes/LUT620/frame1.csv
reading data from path:
./data/ThreeCopies/7_1_1_conect_lifetimes/LUT620/frame1.csv
The node number for each graph are:
136 137 137 98
The combined graph node number is: 508
within the cutoff, the inter mol are:
21303
Graph generation for frame 0 has been done.
reading data from path:
./data/ThreeCopies/7_1_1_conect_lifetimes/CLA610/frame2.csv

within the cutoff, the inter mol are:
21382
Graph generation for frame 9 has been done.
reading data from path:
./data/ThreeCopies/7_1_1_conect_lifetimes/CLA610/frame11.csv
reading data from path:
./data/ThreeCopies/7_1_1_conect_lifetimes/CLA610/frame11.csv
reading data from path:
./data/ThreeCopies/7_1_1_conect_lifetimes/CLA611/frame11.csv
reading data from path:
./data/ThreeCopies/7_1_1_conect_lifetimes/CLA611/frame11.csv
reading data from path:
./data/ThreeCopies/7_1_1_conect_lifetimes/CLA612/frame11.csv
reading data from path:
./data/ThreeCopies/7_1_1_conect_lifetimes/CLA612/frame11.csv
reading data from path:
./data/ThreeCopies/7_1_1_conect_lifetimes/LUT620/frame11.csv
reading data from path:
./data/ThreeCopies/7_1_1_conect_lifetimes/LUT620/frame11.csv
The node number for each graph are:
136 137 137 98
The combined graph node number is: 508
within the cutoff, the inter mol are:
21200
Graph generation for frame 10 has been done.
reading data from path:
./data/ThreeCopies/7_1_1_cone

In [None]:
#torch.save(data_list, './data_as_graphs_20042021.pt')

In [None]:
import torch


In [None]:
#Test the graph generation.

from torch_geometric import utils
from torch_geometric.data import Data

In [None]:
# Convert the test graph with torch_geometric's from_networkx helper.
data = utils.from_networkx(U_test2)

In [None]:
data

In [None]:
data.edge_index

In [None]:
from my_dataset_4cluster_2 import *

In [None]:
My_dataset()

In [None]:
# Print the first item yielded by My_dataset() and keep it in `a`;
# the loop stops after that single item.
for element in My_dataset():
    print(element)
    a = element
    break

In [None]:
# Fix the numpy and torch seeds so the random split below is repeatable.
np.random.seed(1)
torch.random.manual_seed(1)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
############################################################################
# data loader initialisation
#DATA_PATH = './data/'
dataset_size = 1000
perm = torch.randperm(dataset_size).numpy()

# 80 % / 10 % / 10 % split of the shuffled indices
train_end = int(dataset_size*8/10)
validation_end = int(dataset_size*9/10)
partition = {
    "train": perm[:train_end],
    "validation": perm[train_end:validation_end],
    "test": perm[validation_end:],
}

train_subset = torch.utils.data.Subset(My_dataset(), partition["train"])
train_loader = DataLoader(train_subset,
                          batch_size=16,
                          #sampler=train_sampler,
                          shuffle=True,
                          num_workers=8)


In [None]:
# Print the first batch produced by train_loader, then stop.
for element in train_loader:
    print(element)
    break