<a href="https://colab.research.google.com/github/Samitha-Nawarathna/GNN-for-Text-Analysis/blob/main/Model_Trainer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive at /content/drive (Colab only); data_path below points
# inside this mount.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.3.1.tar.gz (661 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m661.6/661.6 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: torch_geometric
  Building wheel for torch_geometric (pyproject.toml) ... [?25l[?25hdone
  Created wheel for torch_geometric: filename=torch_geometric-2.3.1-py3-none-any.whl size=910454 sha256=2461d4ccb37d261527c6b4944fe5dc91322e5c833adb135d537edd48157d99ea
  Stored in directory: /root/.cache/pip/wheels/ac/dc/30/e2874821ff308ee67dcd7a66dbde912411e19e35a1addda028
Successfully built torch_geometric
Installing collected packages: torch_geometric
Successfully installed torch_geometric-2.3.1


In [3]:
# Location of the processed graph files (*.pt) on the mounted Drive.
# Keep the trailing slash: Graph_Dataset builds its glob pattern by plain
# string concatenation (path + '*.pt').
data_path = "/content/drive/MyDrive/Datasets/bbc-full-text-document-classification/processed/"

In [52]:
import torch
import torch_geometric
import torch.nn as nn
import torch_geometric.nn as gnn
from torch_geometric.nn import aggr
from torch_geometric.nn import MessagePassing
from torch.utils.data import Dataset, random_split
from torch.optim import Adam
import glob
import numpy as np

In [7]:
# Sizes passed to GraphModel(input_size, output_size) when the model is built.
INPUT_SIZE = 50  # presumably the per-node feature width of the stored graphs — TODO confirm
OUTPUT_SIZE = 5  # presumably the number of target classes — TODO confirm

In [34]:
# Pick the compute device once: the GPU when CUDA is usable, otherwise the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

'cpu'

In [70]:
class Graph_Dataset(Dataset):
  """Map-style dataset over serialized graph files (``*.pt``) on disk.

  Each item is whatever object ``torch.load`` returns for the corresponding
  file. Nothing is cached, so every access reads from disk.

  Args:
    path: Prefix to which ``'*.pt'`` is appended to form the glob pattern.
      For a directory it must therefore end with a path separator.
  """

  def __init__(self,path):
    super().__init__()
    self.path = path
    self.path_list = self.getPaths()

  def getPaths(self):
    """Return the paths of all matching ``*.pt`` files in sorted order."""
    # glob gives no ordering guarantee; sorting keeps the index -> file
    # mapping stable between runs and machines.
    return sorted(glob.glob(self.path+'*.pt'))

  def __len__(self):
    """Return the number of graph files found at construction time."""
    return len(self.path_list)

  def _load(self, idx):
    """Load the object stored in the file at position ``idx``."""
    # NOTE: torch.load unpickles the file; only point this at trusted data.
    return torch.load(self.path_list[idx])

  def __getitem__(self,idx):
    """Load one item (int index) or a list of items (slice).

    Slices follow normal Python list semantics, including omitted and
    negative bounds.

    Raises:
      IndexError: If an integer index is out of range.
    """
    if isinstance(idx, slice):
      # slice.indices() resolves None / negative / out-of-range bounds the
      # way a list would, so ds[:] includes the last file.
      positions = range(*idx.indices(len(self.path_list)))
      return [self._load(position) for position in positions]
    return self._load(idx)

In [71]:
# Index the graph files under data_path. Construction only lists the files;
# each graph is read from disk when it is accessed.
dataset = Graph_Dataset(data_path)

In [72]:
# 80/10/10 train/test/validation split. The validation set takes the
# remainder so the three sizes always add up to len(dataset).
# The generator is seeded so the random permutation is the same on every run;
# which file each index refers to still depends on the order in which
# Graph_Dataset lists the files.
SPLIT_SEED = 42
train_size = int(0.8*len(dataset))
test_size = int(0.1*len(dataset))
val_size = len(dataset) - train_size - test_size
train_set, test_set, val_set = random_split(
    dataset,
    [train_size, test_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [62]:
class MLP(nn.Module):
  """Three-layer perceptron: input_size -> 100 -> 50 -> output_size.

  A ReLU follows each of the first two linear layers; the output of the last
  layer is returned without an activation.
  """

  def __init__(self,input_size, output_size):
    super().__init__()
    # Layers are created in the order p1, p2, p3, which also fixes the order
    # in which their random initial weights are drawn.
    self.p1 = nn.Linear(in_features=input_size, out_features=100)
    self.relu = nn.ReLU()
    self.p2 = nn.Linear(in_features=100, out_features=50)
    self.p3 = nn.Linear(in_features=50, out_features=output_size)

  def forward(self,x):
    """Apply the three layers to ``x`` and return the final linear output."""
    hidden = self.relu(self.p1(x))
    hidden = self.relu(self.p2(hidden))
    return self.p3(hidden)

class GraphModel(nn.Module):
  """Graph-level model: gated graph convolution followed by a pooled readout.

  Node features go through a ``GatedGraphConv``; two separate MLPs are then
  applied to every node and multiplied element-wise. The graph output is the
  mean over nodes plus the max over nodes of that product.

  Args:
    input_size: Width of the node embeddings fed to the two MLPs; also used
      as the convolution's output width.
    output_size: Width of the returned graph-level vector.
  """

  def __init__(self,input_size, output_size):
    super().__init__()
    # The convolution output is fed straight into f1/f2, which expect
    # input_size features, so its width must equal input_size (a hard-coded
    # 50 only works when input_size happens to be 50).
    self.gconv1 = gnn.conv.GatedGraphConv(out_channels=input_size, num_layers=5)
    self.f1 = MLP(input_size,output_size)
    self.f2 = MLP(input_size,output_size)
    # NOTE(review): sigmoid and tanh are constructed but never applied in
    # forward(). If the readout was meant to be sigmoid(f1(x)) * tanh(f2(x)),
    # that gating is currently missing — confirm the intent before changing.
    self.sigmoid = nn.Sigmoid()
    self.tanh = nn.Tanh()

  def forward(self,x,edge_index,edge_attr):
    """Return one vector of length ``output_size`` for the whole graph.

    Args:
      x: Node feature matrix, one row per node.
      edge_index: Edge list passed to the convolution.
      edge_attr: Passed to the convolution as ``edge_weight``
        (presumably one scalar per edge — TODO confirm the stored shape).
    """
    x = self.gconv1(x, edge_index, edge_weight = edge_attr)
    x = self.f1(x) * self.f2(x)
    # Readout over the node dimension: mean pooling plus max pooling.
    return x.mean(dim=0) + x.max(dim=0).values


In [63]:
# Build the model and move its parameters to the selected device.
model = GraphModel(INPUT_SIZE, OUTPUT_SIZE).to(device)

In [11]:
def extractGraph(graph):
  """Return ``(x, edge_index, edge_attr)`` of ``graph`` with fixed dtypes.

  ``x`` and ``edge_attr`` are cast to float32 and ``edge_index`` to int64.
  """
  node_features = graph.x.float()
  edges = graph.edge_index.long()
  edge_features = graph.edge_attr.float()
  return node_features, edges, edge_features

In [16]:
def getPrediction(graph, model):
  """Run ``model`` on a single graph and return the model's output.

  The graph's tensors are first converted by ``extractGraph`` and then passed
  positionally as ``(x, edge_index, edge_attr)``.
  """
  return model(*extractGraph(graph))

In [69]:
# Load one sample (index 5), presumably as a quick manual check.
# Note: the training cell reuses the name `graph` as its loop variable.
graph = dataset[5]

5


In [31]:
# Loss function and optimizer used by the training cell.
lr = 0.001
error = nn.CrossEntropyLoss()
# NOTE(review): betas=[0.99, 0.95] differs from Adam's defaults (0.9, 0.999)
# and puts beta1 above beta2 — confirm this is intentional.
optimizer = Adam(params=model.parameters(), lr=lr, betas=[0.99,0.95])

In [59]:
def getValErrors(graph):
  """Compute the loss of the current ``model`` on validation graphs.

  Args:
    graph: A single graph or an iterable of graphs (e.g. ``val_set``). Each
      graph must provide ``x``, ``edge_index``, ``edge_attr`` and ``y``.

  Returns:
    A 1-D tensor with one loss value per graph, in input order (an empty
    tensor when no graphs are given).
  """
  # A single graph exposes `y` directly; wrap it so both cases share one path.
  graphs = [graph] if hasattr(graph, 'y') else graph
  losses = []
  # Validation needs no gradients; skipping them saves memory and time.
  with torch.no_grad():
    for item in graphs:
      item = item.to(device)
      # getPrediction takes (graph, model).
      output = getPrediction(item, model)
      losses.append(error(output, item.y))
  if not losses:
    return torch.empty(0, device=device)
  # The losses are stacked into a torch tensor (rather than wrapping the
  # function in np.vectorize) so callers can reduce them with torch.sum.
  return torch.stack(losses)

In [73]:
# One pass over train_set, one graph per optimizer step.
MAX_TRAIN_STEPS = None  # set to a small int (e.g. 1) for a quick smoke test
VAL_EVERY = 25          # report the mean validation loss every N steps

for i, graph in enumerate(train_set):
  if MAX_TRAIN_STEPS is not None and i >= MAX_TRAIN_STEPS:
    break
  # Keep the returned object so later lines use the copy on `device`.
  graph = graph.to(device)
  optimizer.zero_grad()
  output = getPrediction(graph, model)
  y = graph.y.to(device)
  loss = error(output, y)
  loss.backward()
  optimizer.step()
  # Use `and`, not `&`: `&` binds tighter than the comparisons, so
  # `i%25==0 &i != 0` is the chain `i%25 == (0 & i) != 0`, always False.
  if i % VAL_EVERY == 0 and i != 0 and len(val_set) > 0:
    # Validation loss is computed inline, without gradient tracking.
    with torch.no_grad():
      val_errors = torch.stack([
          error(getPrediction(val_graph.to(device), model), val_graph.y.to(device))
          for val_graph in val_set
      ])
    val_loss = torch.sum(val_errors,axis=0)/len(val_set)
    print(val_loss)

In [53]:
@np.vectorize
def sqr(num):
  """Return ``num`` squared; wrapped by np.vectorize to accept array-likes."""
  return num**2

In [54]:
# Scratch input: a 1-D tensor of 10 samples drawn from a standard normal.
x = torch.randn(10,)

In [56]:
# Scratch check: apply the np.vectorize-wrapped sqr to a torch tensor and
# convert the NumPy result back to a tensor (the recorded output is float64).
torch.tensor(sqr(x))

tensor([1.2091, 0.2429, 2.7144, 0.1128, 0.2558, 0.0573, 0.2809, 0.4439, 1.9288,
        0.9987], dtype=torch.float64)