# Prepare protein structure dataset

- Parse the PDB files into json documents
- Parse the CATH labels from RCSB API, merge into the json documents
- Create index on PDB ID
- Insert json documents into DocumentDB

In [1]:
import os
import json

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
plt.style.use('seaborn')
import seaborn as sns
%matplotlib inline

## Download data from [AlphaFold](https://alphafold.ebi.ac.uk/download)

In [3]:
!wget https://ftp.ebi.ac.uk/pub/databases/alphafold/UP000000805_243232_METJA.tar

--2021-08-19 20:49:07--  https://ftp.ebi.ac.uk/pub/databases/alphafold/UP000000805_243232_METJA.tar
Resolving ftp.ebi.ac.uk (ftp.ebi.ac.uk)... 193.62.197.74
Connecting to ftp.ebi.ac.uk (ftp.ebi.ac.uk)|193.62.197.74|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 178278400 (170M) [application/octet-stream]
Saving to: ‘UP000000805_243232_METJA.tar’


2021-08-19 20:49:13 (28.3 MB/s) - ‘UP000000805_243232_METJA.tar’ saved [178278400/178278400]



In [5]:
!cd data && tar -xvf UP000000805_243232_METJA.tar

AF-O06917-F1-model_v1.cif.gz
AF-O06917-F1-model_v1.pdb.gz
AF-O06918-F1-model_v1.cif.gz
AF-O06918-F1-model_v1.pdb.gz
AF-O53113-F1-model_v1.cif.gz
AF-O53113-F1-model_v1.pdb.gz
AF-P0CL56-F1-model_v1.cif.gz
AF-P0CL56-F1-model_v1.pdb.gz
AF-P0CW37-F1-model_v1.cif.gz
AF-P0CW37-F1-model_v1.pdb.gz
AF-P0CW38-F1-model_v1.cif.gz
AF-P0CW38-F1-model_v1.pdb.gz
AF-P0CW39-F1-model_v1.cif.gz
AF-P0CW39-F1-model_v1.pdb.gz
AF-P0CW76-F1-model_v1.cif.gz
AF-P0CW76-F1-model_v1.pdb.gz
AF-P43409-F1-model_v1.cif.gz
AF-P43409-F1-model_v1.pdb.gz
AF-P54009-F1-model_v1.cif.gz
AF-P54009-F1-model_v1.pdb.gz
AF-P54010-F1-model_v1.cif.gz
AF-P54010-F1-model_v1.pdb.gz
AF-P54011-F1-model_v1.cif.gz
AF-P54011-F1-model_v1.pdb.gz
AF-P54012-F1-model_v1.cif.gz
AF-P54012-F1-model_v1.pdb.gz
AF-P54013-F1-model_v1.cif.gz
AF-P54013-F1-model_v1.pdb.gz
AF-P54014-F1-model_v1.cif.gz
AF-P54014-F1-model_v1.pdb.gz
AF-P54015-F1-model_v1.cif.gz
AF-P54015-F1-model_v1.pdb.gz
AF-P54016-F1-model_v1.cif.gz
AF-P54016-F1-model_v1.pdb.gz
AF-P54017-F1-m

In [6]:
# count number of pdb files
!ls data/*pdb.gz | wc -l

1773


In [2]:
## get uniprot ids from pdb filenames
import os

files = os.listdir('data/')

uniprot_ids = []
for file in files:
    if file.startswith('AF-') and file.endswith('pdb.gz'):
        uniprot_id = file.split('-')[1]
        uniprot_ids.append(uniprot_id)
print(len(uniprot_ids))

1773


In [3]:
# https://www.uniprot.org/help/api_idmapping
# mapping uniprot id to pdb id
import urllib.parse
import urllib.request

url = 'https://www.uniprot.org/uploadlists/'

params = {
'from': 'ACC+ID',
'to': 'PDB_ID',
'format': 'tab',
'query': ' '.join(uniprot_ids)
}

data = urllib.parse.urlencode(params)
data = data.encode('utf-8')
req = urllib.request.Request(url, data)
with urllib.request.urlopen(req) as f:
    response = f.read()

res = response.decode('utf-8')

In [4]:
import io
import pandas as pd
id_mappings = pd.read_csv(io.StringIO(res), sep='\t')
id_mappings.columns = ['uniprot_id', 'pdb_id']
print(id_mappings.shape)
id_mappings.head()

(510, 2)


Unnamed: 0,uniprot_id,pdb_id
0,Q57991,2HMF
1,Q57991,3C1M
2,Q57991,3C1N
3,Q57991,3C20
4,Q60175,1RH5


In [5]:
id_mappings.nunique()

uniprot_id    211
pdb_id        462
dtype: int64

In [6]:
id_mappings.drop_duplicates().shape

(510, 2)

In [7]:
pdb_ids = id_mappings['pdb_id'].unique()

In [16]:
with open('pdb_ids.txt', 'w') as out:
    out.write(','.join(list(pdb_ids)))

In [17]:
%%bash
# download from RCSB PDB

cd data
bash ../batch_download.sh -f ../pdb_ids.txt -p

Downloading https://files.rcsb.org/download/2HMF.pdb.gz to ./2HMF.pdb.gz
Downloading https://files.rcsb.org/download/3C1M.pdb.gz to ./3C1M.pdb.gz
Downloading https://files.rcsb.org/download/3C1N.pdb.gz to ./3C1N.pdb.gz
Downloading https://files.rcsb.org/download/3C20.pdb.gz to ./3C20.pdb.gz
Downloading https://files.rcsb.org/download/1RH5.pdb.gz to ./1RH5.pdb.gz
Downloading https://files.rcsb.org/download/1RHZ.pdb.gz to ./1RHZ.pdb.gz
Downloading https://files.rcsb.org/download/2YXQ.pdb.gz to ./2YXQ.pdb.gz
Downloading https://files.rcsb.org/download/2YXR.pdb.gz to ./2YXR.pdb.gz
Downloading https://files.rcsb.org/download/3BO0.pdb.gz to ./3BO0.pdb.gz
Downloading https://files.rcsb.org/download/3BO1.pdb.gz to ./3BO1.pdb.gz
Downloading https://files.rcsb.org/download/3DKN.pdb.gz to ./3DKN.pdb.gz
Downloading https://files.rcsb.org/download/4V4N.pdb.gz to ./4V4N.pdb.gz
Failed to download https://files.rcsb.org/download/4V4N.pdb.gz
Downloading https://files.rcsb.org/download/4V7I.pdb.gz to ./

In [18]:
ls data/*.pdb.gz | wc -l

2233


## Parse PDB files

In [8]:
# !pip install biopython==1.79

In [49]:
from tqdm import tqdm
from joblib import Parallel, delayed
from Bio.PDB import PDBParser, MMCIFParser
from Bio.PDB.Polypeptide import is_aa

import xpdb
from contact_map_utils import parse_structure, three_to_one_standard

In [50]:
def get_atom_coords(residue, target_atoms=["N", "CA", "C", "O"]):
    """Extract the coordinates of the target_atoms from an AA residue.
    Handles exception where residue doesn't contain certain atoms
    """
    atom_coords = []
    for atom in target_atoms:
        try:
            coord = residue[atom].coord
        except KeyError:
            coord = [np.nan] * 3
        atom_coords.append(coord)
    return np.asarray(atom_coords)


def chain_to_coords(chain, target_atoms=["N", "CA", "C", "O"], name=""):
    """Convert a PDB chain in to coordinates of target atoms from all
    AAs"""
    output = {}
    # get AA sequence in the pdb structure
    pdb_seq = "".join(
        [
            three_to_one_standard(res.get_resname())
            for res in chain.get_residues()
            if is_aa(res)
        ]
    )
    if len(pdb_seq) <= 1:
        # has no or only 1 AA in the chain
        return None
    output["seq"] = pdb_seq
    # get the atom coords
    coords = np.asarray(
        [
            get_atom_coords(res, target_atoms=target_atoms)
            for res in chain.get_residues()
            if is_aa(res)
        ]
    )
    output["coords"] = coords.tolist()
    output["name"] = "{}-{}".format(name, chain.id)
    return output


def parse_structure_file_to_json_record(
    pdb_parser, cif_parser, pdb_file_path, name=""
):
    """Parse a protein structure file (.pdb or .cif) to extract all the chains
    to json records for LM-GVP model."""

    try:
        struct = parse_structure(
            pdb_parser, cif_parser, name, pdb_file_path
        )
    except Exception as e:
        print(pdb_file_path, "raised an error:")
        print(e)
        return []
    else:
        records = []
        chain_ids = set()
        for chain in struct.get_chains():
            if chain.id in chain_ids:  # skip duplicated chains
                continue
            chain_ids.add(chain.id)
            record = chain_to_coords(chain, name=name)
            if record is not None:
                records.append(record)
        return records

In [51]:
# PDB parser
pdb_parser = PDBParser(
    QUIET=True,
    PERMISSIVE=True,
    structure_builder=xpdb.SloppyStructureBuilder(),
)

In [13]:
# try parsing a AF pdb file
rec = parse_structure_file_to_json_record(
    pdb_parser, None,
    'data/AF-Q58321-F1-model_v1.pdb.gz',
    name='AF-Q58321'
)

In [15]:
type(rec), len(rec)

(list, 1)

In [16]:
rec[0].keys()

dict_keys(['seq', 'coords', 'name'])

In [17]:
rec[0]['seq']

'MMIMQYIYPFTAIVGQEKMKKALILNAINPKIGGVLIRGEKGTAKSTAVRALADLLPEIEIVEGCPFNCDPNGNLCDICKEKKKRGELKTTKKKMKVVNLPIGATEDRVIGTLDIEKAIKEGIKALEPGILAEANRNILYIDEVNLLDDHIIDVLLDAAAMGWNIIEREGVKIKHPSRFILVGTMNPEEGELRPQILDRFGLMVDVEGLNDVKDRVEVIKRVEEFNENPEAFYKKFEEEQNKLRERIIKARELLNKVEISDDLLEFISKVCIELGIQTNRADITVVRTAKALAAYNGRTYVTIDDVKEAMELALPHRMRRKPFEPPQLNKEKLEQMINEFKQQNNKDNEEKEEHKDDDVKKNMMK'

In [18]:
rec[0]['name']

'AF-Q58321-A'

In [19]:
np.asarray(rec[0]['coords']).shape, len(rec[0]['seq'])

((365, 4, 3), 365)

In [20]:
# try parsing a PDB file from RCSB
rec = parse_structure_file_to_json_record(
    pdb_parser, None,
    'data/2HMF.pdb.gz',
    '2HMF'
)

In [21]:
type(rec), len(rec)

(list, 4)

In [22]:
for r in rec:
    print(r['name'], len(r['seq']))

2HMF-A 465
2HMF-B 465
2HMF-C 465
2HMF-D 465


In [24]:
r['name']

'2HMF-D'

In [36]:
pdb_id = '2HMF'
id_mappings.loc[id_mappings['pdb_id'] == pdb_id, 'uniprot_id'].tolist()

['Q57991']

In [25]:
np.asarray(rec[0]['coords']).shape, len(rec[0]['seq'])

((465, 4, 3), 465)

In [23]:
len(set([r['seq'] for r in rec]))

1

## Save parsed structure to DocumentDB

In [26]:
!pip install pymongo

Collecting pymongo
  Downloading pymongo-3.12.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (523 kB)
[K     |████████████████████████████████| 523 kB 23.4 MB/s eta 0:00:01
[?25hInstalling collected packages: pymongo
Successfully installed pymongo-3.12.0
You should consider upgrading via the '/home/ec2-user/anaconda3/envs/pytorch_latest_p36/bin/python -m pip install --upgrade pip' command.[0m


In [3]:
# !wget https://s3.amazonaws.com/rds-downloads/rds-combined-ca-bundle.pem

In [27]:
!ls -lht rds-combined-ca-bundle.pem

-r-------- 1 ec2-user ec2-user 43K Aug 19 20:29 rds-combined-ca-bundle.pem


In [28]:
from pymongo import MongoClient

In [29]:
secrets = json.load(open('DocumentDB_secrets.json', 'r')) 

In [284]:
from pymongo import MongoClient

# TLS enabled
uri = 'mongodb://{}:{}@{}:27017/?tls=true&tlsCAFile=rds-combined-ca-bundle.pem&replicaSet=rs0&readPreference=secondaryPreferred&retryWrites=false'\
    .format(secrets['db_username'], secrets['db_password'], secrets['host'])

client = MongoClient(uri)

In [285]:
db = client['proteins']

In [286]:
collection = db['proteins']

In [129]:
collection.delete_many({})

<pymongo.results.DeleteResult at 0x7fd79ed5afc8>

In [130]:
collection.create_index('id', unique=True)

'id_1'

## Parse and save data to DB

- for rcsb, use PDB-chain as id, add identifiers {uniprot_ids: [], pdb_ids: []}
- for AF, use `AF-Q58321` as id, add identifiers {uniprot_ids: [], pdb_ids: []}


In [131]:
import glob

In [132]:
# parse AF 
for pdb_file in glob.glob('data/AF-*.pdb.gz'):
    # AlphaFold structure
    id_ = '-'.join(os.path.basename(pdb_file).split('-')[:2])
    rec = parse_structure_file_to_json_record(
        pdb_parser, None,
        pdb_file,
        id_
    )[0]
    
    uniprot_id = os.path.basename(pdb_file).split('-')[1]
    # look up pdb id
    pdb_ids = list(id_mappings.loc[id_mappings['uniprot_id']==uniprot_id, 'pdb_id'])
    identifiers = {
        'uniprot_ids': [uniprot_id],
        'pdb_ids': pdb_ids
    }
    rec['id'] = id_
    rec['identifiers'] = identifiers
    rec['is_AF'] = 1

    collection.insert_one(rec)

In [133]:
# parse RCSB
for pdb_file in glob.glob('data/*.pdb.gz'):
    if not os.path.basename(pdb_file).startswith('AF-'):
        # RCSB structure
        pdb_id = os.path.basename(pdb_file).split('.')[0]
        recs = parse_structure_file_to_json_record(
            pdb_parser, None,
            pdb_file,
            pdb_id
        )
        
        # look up uniprot_ids
        uniprot_ids = list(id_mappings.loc[id_mappings['pdb_id']==pdb_id, 'uniprot_id'])
        identifiers = {
            'uniprot_ids': uniprot_ids,
            'pdb_ids': [pdb_id]
        }
        for rec in recs:
            rec['id'] = rec['name']
            rec['identifiers'] = identifiers
            rec['is_AF'] = 0
        
        collection.insert_many(recs)

In [134]:
collection.count_documents({})

3151

## Parse CATH labels

In [27]:
cath_df = pd.read_csv(
    '/home/ec2-user/SageMaker/efs/pdb-download/rcsb_cath_labels.txt.gz'
)
print(cath_df.shape)
cath_df.head()

(206938, 5)


Unnamed: 0,pdb,chain,c,a,t
0,5d8v,A,4,10,490
1,3nir,A,3,30,1350
2,1ejg,A,3,30,1350
3,5nw3,A,2,20,28
4,1ucs,A,3,90,1210


In [29]:
cath_df.count()

pdb      206938
chain    206938
c        206938
a        206938
t        206938
dtype: int64

In [30]:
cath_df['id'] = cath_df.apply(lambda row: '{}-{}'.format(
    row['pdb'].upper(), row['chain']), axis=1)

In [31]:
cath_df.nunique()

pdb       95787
chain        36
c             5
a            24
t           432
id       206938
dtype: int64

In [32]:
cath_df = cath_df.set_index('id')

In [33]:
cath_df['c'].value_counts()

3    106896
2     51414
1     45194
4      2032
6      1402
Name: c, dtype: int64

In [34]:
# join with rec_meta_df
rec_meta_df = rec_meta_df.merge(
    cath_df,
    left_index=True,
    right_index=True,
    how='left'
)
print(rec_meta_df.shape)
rec_meta_df.head()

(258917, 6)


Unnamed: 0_level_0,seq,pdb,chain,c,a,t
name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
1BXR-A,MPKRTDIKSILILGAGPIVIGQACEFDYSGAQACKALREEGYRVIL...,,,,,
1BXR-B,IKSALLVLEDGTQFHGRAIGATGSAVGEVVFNTSMTGYQEILTDPS...,1bxr,B,3.0,40.0,50.0
1BXR-C,MPKRTDIKSILILGAGPIVIGQACEFDYSGAQACKALREEGYRVIL...,,,,,
1BXR-D,IKSALLVLEDGTQFHGRAIGATGSAVGEVVFNTSMTGYQEILTDPS...,1bxr,D,3.0,40.0,50.0
1BXR-E,MPKRTDIKSILILGAGPIVIGQACEFDYSGAQACKALREEGYRVIL...,,,,,


In [35]:
rec_meta_df.count()

seq      258917
pdb      206904
chain    206904
c        206904
a        206904
t        206904
dtype: int64

## Add CATH labels to json documents

In [36]:
pdb_ids_with_CATH = set(cath_df.index)
for rec in all_records:
    if 'id' not in rec:
        pdb_id = rec['name'][:4] + '-' + rec['name'][-1]
        rec['id'] = pdb_id
    rec.pop('name', None)
        
    if rec['id'] in pdb_ids_with_CATH:
        rec['CATH'] = {key: int(cath_df.loc[rec['id'], key]) 
                       for key in ['c','a','t']}

In [37]:
all_records[0].keys()

dict_keys(['seq', 'coords', 'id'])

In [38]:
all_records[0]['CATH']

KeyError: 'CATH'

In [39]:
all_records[1]['CATH']

{'c': 3, 'a': 40, 't': 50}

# Load meta data for training

## Train a GNN classifer to distinguish AlphaFold predicted vs experimental structures

In [87]:
import torch
import torch.utils.data as data
import torch.nn.functional as F

In [88]:
torch.__version__

'1.7.1'

In [91]:
torch.version.cuda

'10.1'

In [92]:
!pip install dgl-cu101 -f https://data.dgl.ai/wheels/repo.html

Looking in links: https://data.dgl.ai/wheels/repo.html
Collecting dgl-cu101
  Downloading https://data.dgl.ai/wheels/dgl_cu101-0.7.0-cp36-cp36m-manylinux1_x86_64.whl (109.4 MB)
[K     |████████████████████████████████| 109.4 MB 9.0 MB/s eta 0:00:01    |███████████████████████         | 78.6 MB 6.1 MB/s eta 0:00:06
Installing collected packages: dgl-cu101
Successfully installed dgl-cu101-0.7.0
You should consider upgrading via the '/home/ec2-user/anaconda3/envs/pytorch_latest_p36/bin/python -m pip install --upgrade pip' command.[0m


In [93]:
import dgl

DGL backend not selected or invalid.  Assuming PyTorch for now.


Setting the default backend to "pytorch". You can change it in the ~/.dgl/config.json file or export the DGLBACKEND environment variable.  Valid options are: pytorch, mxnet, tensorflow (all lowercase)


Using backend: pytorch


## Dataset class

In [201]:
len(d1_to_index)

20

In [202]:
from Bio.PDB.Polypeptide import d1_to_index, three_to_one

d1_to_index['X'] = len(d1_to_index) # encode uncommon residue as 20

class ProteinDataset(data.Dataset):
    def __init__(self, collection, pipeline):
        """
        Args:
            - collection: pymongo.collection.Collection object
            - pipeline: a DocumentDB aggregation pipeline
            - tokenizer:
        """
        self.collection = collection
        # pre-fetch the metadata and labels from DocumentDB
        self.docs = [doc for doc in self.collection.aggregate(pipeline)]
        self.labels = [doc["y"] for doc in self.docs]
        
    def _convert_to_graph(self, protein):
        coords = torch.tensor(protein['coords'])
        X_ca = coords[:, 1]
        # construct knn graph from C-alpha coordinates
        g = dgl.knn_graph(X_ca, k=5)        
        seq = protein['seq']
        node_features = torch.tensor([d1_to_index[residue] for residue in seq])
        node_features = F.one_hot(node_features, num_classes=len(d1_to_index)).to(dtype=torch.float)
        
        # add node features
        g.ndata["h"] = node_features
        return g
        
        
    def __getitem__(self, idx):
        id_ = self.docs[idx]['id']
        protein = self.collection.find_one(
            {'id': id_}, 
            projection={"_id": False, "coords": True, "seq": True}
        )
        return self._convert_to_graph(protein), self.labels[idx]

    def __len__(self):
        return len(self.docs)



In [287]:
collection = db['proteins']

match = {"is_AF": {"$exists": True}}
project = {"y": "$is_AF", "_id": False, 'id': True}

pipeline = [
    {"$match": match},
    {"$project": project},
]

# docs = [doc for doc in collection.aggregate(pipeline)]
# docs[0]
dataset = ProteinDataset(collection, pipeline)

In [288]:
g, label = dataset[0]
g, label

(Graph(num_nodes=380, num_edges=1900,
       ndata_schemes={'h': Scheme(shape=(21,), dtype=torch.float32)}
       edata_schemes={}),
 1)

In [289]:
type(g)

dgl.heterograph.DGLHeteroGraph

In [290]:
# Compute contact map from C-alpha coordinates
# coords = torch.tensor(dataset[0]['coords'])
# coords.shape

In [291]:
# X_ca = coords[:, 1]
# X_ca.shape

In [292]:
# # construct knn graph from C-alpha coordinates
# g = dgl.knn_graph(X_ca, k=5)
# g

In [293]:
# type(g)

In [294]:
# seq = dataset[0]['seq']
# node_features = torch.tensor([d1_to_index[residue] for residue in seq])
# node_features = F.one_hot(node_features, num_classes=len(d1_to_index)).to(dtype=torch.float)
# node_features.shape

In [295]:
d1_to_index

{'A': 0,
 'C': 1,
 'D': 2,
 'E': 3,
 'F': 4,
 'G': 5,
 'H': 6,
 'I': 7,
 'K': 8,
 'L': 9,
 'M': 10,
 'N': 11,
 'P': 12,
 'Q': 13,
 'R': 14,
 'S': 15,
 'T': 16,
 'V': 17,
 'W': 18,
 'Y': 19,
 'X': 20}

In [296]:
# # add node features
# g.ndata["h"] = node_features

In [297]:
def collate(samples):
    graphs = list(zip(*samples))[0]
    targets = list(zip(*samples))[1]
    bg = dgl.batch(graphs)
    return bg, torch.tensor(targets).unsqueeze(1).to(torch.float32)

In [298]:
# train_loader = data.DataLoader(
#     dataset, batch_size=16, shuffle=True, collate_fn=collate)

In [299]:
# bg, labels = next(iter(train_loader))

In [300]:
# bg

In [301]:
# labels.shape

## GNN model

In [302]:
from dgl.dataloading import GraphDataLoader
from torch.utils.data.sampler import SubsetRandomSampler

num_examples = len(dataset)
num_train = int(num_examples * 0.8)

train_sampler = SubsetRandomSampler(torch.arange(num_train))
test_sampler = SubsetRandomSampler(torch.arange(num_train, num_examples))

# train_dataloader = GraphDataLoader(
#     dataset, sampler=train_sampler, batch_size=32, drop_last=False,
#     num_workers=16
# )
# test_dataloader = GraphDataLoader(
#     dataset, sampler=test_sampler, batch_size=32, drop_last=False)

train_dataloader = data.DataLoader(
    dataset, sampler=train_sampler, batch_size=32, 
    collate_fn=collate,
    num_workers=1
)

test_dataloader = data.DataLoader(
    dataset, sampler=test_sampler, batch_size=32, 
    collate_fn=collate,
    num_workers=1
)

In [303]:
len(train_dataloader), len(test_dataloader)

(79, 20)

In [304]:
it = iter(train_dataloader)
batch = next(it)
print(batch)

  "MongoClient opened before fork. Create MongoClient only "


RuntimeError: DataLoader worker (pid(s) 27672) exited unexpectedly

In [256]:
import torch.nn as nn
from dgl.nn import GraphConv

class GCN(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        g.ndata['h'] = h
        return dgl.mean_nodes(g, 'h')


In [257]:
# Create the model with given dimensions
dim_nfeats = len(d1_to_index)
n_classes = 2

model = GCN(dim_nfeats, 16, n_classes)

In [258]:
device = torch.device('cuda:0')
model = model.to(device)

In [259]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [260]:
model.train()
for epoch in range(10):
    print('epoch:', epoch)
    for batched_graph, labels in train_dataloader:
        
        batched_graph = batched_graph.to(device)
        labels = labels.to(device)
        
        pred = model(batched_graph, batched_graph.ndata['h'].float())
        loss = F.cross_entropy(pred, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

epoch: 0


  "MongoClient opened before fork. Create MongoClient only "
  "MongoClient opened before fork. Create MongoClient only "
  "MongoClient opened before fork. Create MongoClient only "
  "MongoClient opened before fork. Create MongoClient only "
  "MongoClient opened before fork. Create MongoClient only "
  "MongoClient opened before fork. Create MongoClient only "
  "MongoClient opened before fork. Create MongoClient only "
  "MongoClient opened before fork. Create MongoClient only "
  "MongoClient opened before fork. Create MongoClient only "
  "MongoClient opened before fork. Create MongoClient only "
  "MongoClient opened before fork. Create MongoClient only "
  "MongoClient opened before fork. Create MongoClient only "
  "MongoClient opened before fork. Create MongoClient only "
  "MongoClient opened before fork. Create MongoClient only "
  "MongoClient opened before fork. Create MongoClient only "
  "MongoClient opened before fork. Create MongoClient only "


ConnectionResetError: [Errno 104] Connection reset by peer

RuntimeError: DataLoader worker (pid(s) 26104) exited unexpectedly

In [253]:
num_correct = 0
num_tests = 0
model.eval()
with torch.no_grad():
    for batched_graph, labels in test_dataloader:
        batched_graph = batched_graph.to(device)
        labels = labels.to(device)

        pred = model(batched_graph, batched_graph.ndata['h'].float())
        num_correct += (pred.argmax(1) == labels).sum().item()
        num_tests += len(labels)

print('Test accuracy:', num_correct / num_tests)

Test accuracy: 0.312202852614897


In [None]:
# Test accuracy: 0.312202852614897