<a href="https://colab.research.google.com/github/Samitha-Nawarathna/GNN-for-Text-Analysis/blob/main/Model_Trainer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.4.0-py3-none-any.whl (1.0 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.0 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.1/1.0 MB[0m [31m3.5 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m17.3 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torch_geometric
Successfully installed torch_geometric-2.4.0


In [3]:
data_path = "/content/drive/MyDrive/Datasets/bbc-full-text-document-classification/processed/"

In [4]:
import torch
import torch_geometric
import torch.nn as nn
import torch_geometric.nn as gnn
from torch_geometric.nn import aggr
from torch_geometric.nn import MessagePassing
from torch.utils.data import Dataset, random_split
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
import glob
import numpy as np

In [5]:
INPUT_SIZE = 50
OUTPUT_SIZE = 5
EPOCHS = 2

In [6]:
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'

In [7]:
device

'cuda'

In [8]:
class Graph_Dataset(Dataset):
  def __init__(self,path):
    super().__init__()
    self.path = path
    self.path_list = self.getPaths()

  def getPaths(self):
    return glob.glob(self.path+'/*/*.pt')

  def __len__(self):
    return len(self.path_list)

  def __getitem__(self,idx):
    #print(idx)
    if type(idx) == slice:
      graphs = []
      start ,step, stop = idx.start, idx.step, idx.stop
      if start == None:
        start =0
      if stop == None:
        stop = len(self.path_list)-1
      if step == None:
        step =1
      #print(type(stop))
      for item in range(start, stop, step):
        graphs.append(torch.load(self.path_list[item]))
      return graphs
    return torch.load(self.path_list[idx])

In [9]:
dataset = Graph_Dataset(data_path)

In [10]:
train_size = int(0.8*len(dataset))
test_size = int(0.1*len(dataset))
val_size = len(dataset) - train_size - test_size
train_set, test_set, val_set = random_split(dataset, [train_size, test_size, val_size])

In [42]:
class MLP(nn.Module):
  def __init__(self,input_size, output_size):
    super().__init__()
    self.p1 = nn.Linear(input_size, 100)
    self.relu = nn.ReLU()
    self.p2 = nn.Linear(100,50)
    self.p3 = nn.Linear(50,output_size)

  def forward(self,x):
    x = self.p1(x)
    x = self.relu(x)
    x = self.p2(x)
    x = self.relu(x)
    x = self.p3(x)
    return x

class GraphModel(nn.Module):
  def __init__(self,input_size, output_size):
    super().__init__()
    self.gconv1 = gnn.conv.GatedGraphConv(out_channels=50, num_layers=50)
    self.f1 = MLP(input_size,output_size)
    self.f2 = MLP(input_size,output_size)
    self.sigmoid = nn.Sigmoid()
    self.tanh = nn.Tanh()

  def forward(self,x,edge_index,edge_attr):
    x = self.gconv1(x, edge_index, edge_weight = edge_attr)
    #print(x)
    x1 = self.f1(x)
    #print(x)
    x2 = self.f2(x)
    #print(x)
    x = x1*x2
    #print(x)
    xg = (1/x.shape[0])*torch.sum(x,dim=0) + torch.max(x, dim=0).values
    #print(x)
    return xg
    #x = self.attention(x)


In [43]:
model = GraphModel(INPUT_SIZE, OUTPUT_SIZE).to(device)

In [44]:
def extractGraph(graph):
   x = graph.x.type(torch.float)
   edge_index = graph.edge_index.type(torch.int64)
   edge_attr = graph.edge_attr.type(torch.float)
   return x, edge_index, edge_attr

In [45]:
def getPrediction(graph, model):
  x, edge_index, edge_attr = extractGraph(graph)
  return model(x, edge_index, edge_attr)

In [15]:
graph = dataset[5]

In [46]:
error = nn.CrossEntropyLoss()
lr = 0.001
optimizer = Adam(model.parameters(),lr=lr,betas=[0.99,0.95])
scheduler = StepLR(optimizer, step_size=500, gamma=0.5)

In [47]:
def getValErrors(graph, model=model):
  graph.to(device)
  output = getPrediction(graph,model)
  return error(output, graph.y)



In [48]:
val_split,_ = random_split(val_set, [50, len(val_set) - 50])

In [49]:
best_loss = float("inf")

In [50]:
for epoch in range(EPOCHS):
  for i, graph in enumerate(train_set):
    graph.to(device)
    optimizer.zero_grad()
    output = getPrediction(graph, model)
    y = graph.y.to(device)
    loss = error(output, y)
    loss.backward()
    optimizer.step()
    scheduler.step()
    if i%25==0 and i != 0:
      val_loss = 0
      for item in val_split:
        val_loss += getValErrors(item)
      val_loss /= len(val_split)
      if val_loss < best_loss:
        checkpoint_saver = torch.save(model.state_dict(), '/content/drive/MyDrive/Datasets/bbc-full-text-document-classification/models/cheackpoints.pt')
        best_loss =val_loss
      print(epoch,i,val_loss)



0 25 tensor(1.6096, device='cuda:0', grad_fn=<DivBackward0>)
0 50 tensor(1.6094, device='cuda:0', grad_fn=<DivBackward0>)
0 75 tensor(1.6090, device='cuda:0', grad_fn=<DivBackward0>)
0 100 tensor(1.6090, device='cuda:0', grad_fn=<DivBackward0>)
0 125 tensor(1.6091, device='cuda:0', grad_fn=<DivBackward0>)
0 150 tensor(1.6087, device='cuda:0', grad_fn=<DivBackward0>)
0 175 tensor(1.6083, device='cuda:0', grad_fn=<DivBackward0>)
0 200 tensor(1.6075, device='cuda:0', grad_fn=<DivBackward0>)
0 225 tensor(1.6062, device='cuda:0', grad_fn=<DivBackward0>)
0 250 tensor(1.6052, device='cuda:0', grad_fn=<DivBackward0>)
0 275 tensor(1.6049, device='cuda:0', grad_fn=<DivBackward0>)
0 300 tensor(1.6034, device='cuda:0', grad_fn=<DivBackward0>)
0 325 tensor(1.6018, device='cuda:0', grad_fn=<DivBackward0>)
0 350 tensor(1.6014, device='cuda:0', grad_fn=<DivBackward0>)
0 375 tensor(1.5991, device='cuda:0', grad_fn=<DivBackward0>)
0 400 tensor(1.5981, device='cuda:0', grad_fn=<DivBackward0>)
0 425 tenso

KeyboardInterrupt: ignored

In [38]:
model.load_state_dict(torch.load('/content/drive/MyDrive/Datasets/bbc-full-text-document-classification/models/cheackpoints.pt'))

<All keys matched successfully>

In [41]:
165/222

0.7432432432432432

In [40]:
c = 0
for item in test_set:
  item
  model.to('cpu')
  if np.argmax(getPrediction(item, model).detach().numpy()) == np.argmax(item.y.detach().numpy()):
    c += 1

print(f"{c} from {len(test_set)}")

165 from 222




```
# 165 from 222 for m1
# 40 from 222


```



In [None]:
def sqr(num):
  return num**2
sqr = np.vectorize(sqr)

In [None]:
x = torch.randn(10,)

In [None]:
torch.tensor(sqr(x))

tensor([1.2091, 0.2429, 2.7144, 0.1128, 0.2558, 0.0573, 0.2809, 0.4439, 1.9288,
        0.9987], dtype=torch.float64)