<a href="https://colab.research.google.com/github/anjaa7/Graph-Neural-Networks-for-Molecular-Propery-Prediction---JAX/blob/main/MolecularPropertyPredictionJAX.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
  
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


In [None]:

class MPLayer(nn.Module):
    """One round of message passing followed by a linear transform.

    Every edge carries ``relu(msg_layer(x_src))`` from its source node; each
    node's feature is replaced by the sum of the messages it receives, and the
    summed features are then mapped from ``in_feats`` to ``out_feats``.
    """

    def __init__(self, in_feats, out_feats):
        super().__init__()

        # Feed-forward layer with in_feats inputs and out_feats outputs,
        # applied to the aggregated node features.
        self.transform = nn.Linear(in_feats, out_feats)
        # Per-edge message network; keeps the feature width at in_feats.
        self.msg_layer = nn.Linear(in_feats, in_feats)

    def forward(self, g, node_feats):
        """Run message passing on ``g`` and return the transformed node features.

        Args:
            g: DGL graph; its ``ndata['x']`` is overwritten and then popped.
            node_feats: node feature tensor, presumably (num_nodes, in_feats)
                -- TODO confirm against the caller.

        Returns:
            ``self.transform`` applied to the aggregated node features.
        """
        g.ndata['x'] = node_feats

        # Send along every edge and reduce at every node in one call.
        # DGLGraph.send / DGLGraph.recv were removed in DGL 0.5.
        # NOTE(review): nodes with no in-edges receive no messages; check how
        # the installed DGL version fills their 'x'.
        g.update_all(self.message, self.reduce)

        return self.transform(g.ndata.pop('x'))

    def message(self, edges):
        """Compute one message per edge from the source node's features."""
        msg = torch.relu(self.msg_layer(edges.src['x']))
        return {'message': msg}

    def reduce(self, nodes):
        """Sum each node's incoming messages into its new ``'x'`` feature."""
        msg = nodes.mailbox['message']
        # dim 1 indexes the incoming messages in DGL's mailbox layout.
        sum_msg = torch.sum(msg, dim=1)
        return {'x': sum_msg}


In [None]:
class Model(nn.Module):
    """Molecule classifier: linear lift, three message-passing layers, summed readout.

    The model owns its Adam optimizer (``self.optimizer``), which the training
    loop steps directly.
    """

    def __init__(self, raw_features=119, num_classes=2, batch_size=10,
                 dense_size=300, lr=5e-4):
        """Build the layers and the optimizer.

        Every argument is optional and defaults to the value this notebook
        originally hard-coded.

        Args:
            raw_features: width of the input node features in ``g.ndata['x']``.
            num_classes: number of output logits per graph.
            batch_size: molecules per mini-batch (read by the train/test loops).
            dense_size: hidden width shared by the lift and message-passing layers.
            lr: Adam learning rate.
        """
        super().__init__()

        # Hyperparameters
        self.raw_features = raw_features
        self.num_classes = num_classes

        self.batch_size = batch_size
        self.dense_size = dense_size

        # Trainable layers: lifting -> message passing * 3 -> readout
        self.lift = nn.Linear(raw_features, dense_size)
        self.mpl1 = MPLayer(dense_size, dense_size)
        self.mpl2 = MPLayer(dense_size, dense_size)
        self.mpl3 = MPLayer(dense_size, dense_size)
        self.read = nn.Linear(dense_size, num_classes)

        # Created last so that it registers every layer's parameters.
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)

    def forward(self, g):
        """Return class logits for the graph(s) in ``g``.

        The raw node features ``g.ndata['x']`` are popped (consumed). Afterwards
        that key holds the per-node readout values.
        """
        # Follow the parameters' device instead of hard-coding 'cuda', so the
        # model runs on CPU-only machines and on whatever device it was moved to.
        # NOTE(review): only the features are moved; newer DGL versions also
        # require the graph itself to be on that device.
        device = next(self.parameters()).device
        out = self.lift(g.ndata.pop('x').to(device))
        out = torch.relu(self.mpl1(g, out))
        out = torch.relu(self.mpl2(g, out))
        out = torch.relu(self.mpl3(g, out))
        return self.readout(g, out)

    def readout(self, g, node_feats):
        """Project node features to class scores and sum them per graph."""
        g.ndata['x'] = self.read(node_feats)
        return dgl.sum_nodes(g, 'x')

    def accuracy_function(self, logits, labels):
        """Return the fraction of rows whose argmax over dim 1 equals ``labels``."""
        preds = torch.argmax(logits, dim=1)
        return torch.sum(torch.eq(preds, labels)).item() / preds.nelement()

In [None]:
def train(model, train_data):
    """Run one training epoch over ``train_data`` in mini-batches.

    Molecules are turned into graphs with ``build_graph`` and merged per batch
    with ``dgl.batch``. The model is then updated with its own ``optimizer``
    using cross-entropy loss. The final, possibly shorter, batch is included,
    so every molecule is trained on.

    Args:
        model: object exposing ``batch_size``, ``optimizer`` and ``forward``.
        train_data: sliceable sequence of molecules; each has a ``label``
            supporting ``.item()``.
    """
    n = model.batch_size
    loss_func = nn.CrossEntropyLoss()
    model.train()

    for start in tqdm(range(0, len(train_data), n)):
        molecules = train_data[start:start + n]

        # Batching graphs and labels
        inputs = dgl.batch([build_graph(mol) for mol in molecules])
        labels = torch.tensor([mol.label.item() for mol in molecules],
                              dtype=torch.long)

        # Forward pass and loss (labels are built on CPU, so compare there)
        logits = model.forward(inputs)
        loss = loss_func(logits.cpu(), labels)

        # Back propagation
        model.optimizer.zero_grad()
        loss.backward()
        model.optimizer.step()

In [None]:
def test(model, test_data):
    """Return the model's classification accuracy over all of ``test_data``.

    Molecules are evaluated in mini-batches of ``model.batch_size``. The final,
    possibly shorter, batch is included and every molecule is weighted equally.
    Gradients are disabled, and the model's train/eval mode is restored afterwards.

    Args:
        model: object exposing ``batch_size``, ``forward`` and ``accuracy_function``.
        test_data: sliceable sequence of molecules; each has a ``label``
            supporting ``.item()``.

    Raises:
        ValueError: if ``test_data`` is empty, because accuracy is undefined.
    """
    n = model.batch_size
    total = len(test_data)
    if total == 0:
        raise ValueError("test_data is empty; accuracy is undefined")

    was_training = model.training
    model.eval()
    correct = 0.0
    try:
        with torch.no_grad():
            for start in tqdm(range(0, total, n)):
                molecules = test_data[start:start + n]

                # Batching graphs and labels
                inputs = dgl.batch([build_graph(mol) for mol in molecules])
                labels = torch.tensor([mol.label.item() for mol in molecules],
                                      dtype=torch.long)

                # accuracy_function returns a fraction; scale it by this
                # batch's real size so a short last batch is weighted correctly.
                logits = model.forward(inputs)
                correct += model.accuracy_function(logits.cpu(), labels) * len(molecules)
    finally:
        model.train(was_training)

    return correct / total

In [None]:
def main(num_epochs=15):
    """Train a fresh ``Model`` and print its test accuracy after every epoch.

    Reads module-level ``train_data`` and ``test_data``. They are not created
    here; presumably an earlier notebook cell loads them.

    Args:
        num_epochs: number of train/test passes (default 15, as before).
    """
    # Model is defined in this notebook; there is no `gnn` module in scope.
    model = Model()

    for epoch in range(num_epochs):
        train(model, train_data)
        acc = test(model, test_data)
        print(f"Test Accuracy (Epoch: {epoch + 1} / {num_epochs}) : {acc:.3f}")


In [None]:
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()

NameError: ignored