Importing libraries

In [52]:
import json
from collections import defaultdict, Counter
import ast
import time
import random

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE

import torch
# import torch_geometric as tg
from torch_geometric.data import Data
import torch.nn as nn
from torch_geometric.nn import GCNConv

from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score, confusion_matrix
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split, cross_val_score, RepeatedStratifiedKFold
from sklearn.preprocessing import LabelEncoder

from sentence_transformers import SentenceTransformer

import re
import ast


Pandas setup

In [53]:
pd.set_option('display.max_columns', None)  # Show all columns
pd.set_option('display.max_rows', 10)  # Limit number of rows displayed
pd.set_option('display.width', 1000)  # Set max width for table
pd.set_option('display.colheader_justify', 'center')  # Center-align column headers

Method for cleaning the data

In [54]:
def clean_special_chars(value):
    if isinstance(value, str):  
        return value.replace('\n', ' ').replace('\t', ' ').replace('\r', ' ').replace('  ', ' ').strip()
    return value 

Reading gab

In [55]:
content_gab = pd.read_csv('gab_reddit_benchmark/gab.csv')

content_gab["text"] = content_gab["text"].replace(to_replace=[None, np.nan, "", "nan", "n/a"], value="") 
content_gab["response"] = content_gab["response"].replace(to_replace=[None, np.nan, "", "nan", "n/a"], value="")
content_gab["hate_speech_idx"] = content_gab["hate_speech_idx"].replace(to_replace=[None, np.nan, "", "nan", "n/a"], value="") 

# content_gab["text"] = content_gab["text"].apply(clean_special_chars)
# content_gab["response"] = content_gab["response"].apply(clean_special_chars)

for index, row in content_gab.iterrows():
    row['text'] = row['text'].replace("'", '"')
    row['response'] = row['response'].replace("'", '"')

# content_gab = content_gab.applymap(clean_special_chars)
print(content_gab.head(n=10))
print('\n- - - - - -\n')
print(content_gab.columns)
print('\n- - - - - -\n')
print(content_gab.iloc[1]['id'])

                          id                                                text                        hate_speech_idx                      response                     
0                                    1. 39869714\r\n  1. i joined gab to remind myself how retarded ...         [1]      ["Using words that insult one group while defe...
1  1. 39845588\r\n2. \t39848775\r\n3. \t\t3991101...  1. This is what the left is really scared of. ...         [3]      ['You can disagree with someones opinion witho...
2                   1. 37485560\r\n2. \t37528625\r\n  1. It makes you an asshole.\r\n2. \tGive it to...         [2]      ['Your argument is more rational if you leave ...
3                   1. 39787626\r\n2. \t39794481\r\n  1. So they manage to provide a whole lot of da...         [2]      ["You shouldn't generalize a specific group or...
4  1. 37957930\r\n2. \t39953348\r\n3. \t\t3996521...  1. Hi there, i,m Keith, i hope you are doing w...         [3]      ['If someone is rude it 

Merge posts

In [56]:

def get_first_number(input_string):
    match = re.search(r'\d{2,}', input_string)
    if match:
        return int(match.group())
    return None

content_gab['extracted_id'] = content_gab['id'].apply(get_first_number)

# Find duplicate rows based on 'extracted_id'
duplicates = content_gab[content_gab.duplicated(subset=['extracted_id'], keep=False)]
filtered_groups = []
grouped = content_gab.groupby('extracted_id')
for key, group in grouped:
    if len(group) > 1:
        filtered_groups.append(group)

merged_df = pd.concat(filtered_groups, ignore_index=True)

merged_df.to_csv('gab_reddit_benchmark/gab_groups.csv', index=False)

merged_df = grouped.agg({
    'id': ' '.join,
    'text': ' '.join,
    'hate_speech_idx': ' '.join,
    'response': ' '.join
}).reset_index()
merged_df.to_csv('gab_reddit_benchmark/gab_merged.csv')

df = pd.read_csv('gab_reddit_benchmark/gab_merged.csv')
df = df.applymap(lambda x: x.replace('] [', ', ') if isinstance(x, str) else x)
df.to_csv('gab_reddit_benchmark/gab_merged.csv', index=False)

In [57]:
content_gab = pd.read_csv('gab_reddit_benchmark/gab_merged.csv')

In [58]:
def mark_text_labels(text_utterances_length, labels):
    if not labels:
        # return ['other'] * text_utterances_length
        return [0] * text_utterances_length
    new_labels = []
    int_list = ast.literal_eval(labels)
    for i in range(text_utterances_length):
        if i+1 in int_list:
            # new_labels.append('hate_speech')
            new_labels.append(1)
        else:
            # new_labels.append('other')
            new_labels.append(0)
    return new_labels

Splitting 'text' and 'response' into individual rows, so that I can construct a graph from it

In [59]:
text_column = []
text_labels_column = []
response_column = []
response_labels_column = []

for index, row in content_gab.iterrows():
    text_utterances = row['text'].split('\n')
    text_utterances = list(filter(None, text_utterances))

    for i, t in enumerate(text_utterances):
        text_utterances[i] = clean_special_chars(t)

    text_labels = mark_text_labels(len(text_utterances), row['hate_speech_idx'])

    response_utterances = ast.literal_eval(row['response']) if row['response'] else []
    for i, r in enumerate(response_utterances):
        response_utterances[i] = clean_special_chars(r)
    # response_labels = ['other'] * len(response_utterances)  
    response_labels = [0] * len(response_utterances)  

    
    text_column.append(text_utterances)
    text_labels_column.append(text_labels)
    response_column.append(response_utterances)
    response_labels_column.append(response_labels)

content_gab['text'] = text_column
content_gab['hate_speech_idx'] = text_labels_column
content_gab['response'] = response_column
content_gab['response_labels'] = response_labels_column

content_gab = content_gab.rename(columns={'hate_speech_idx': 'text_labels'})
print(content_gab.head())
print('- - - - ')
print(content_gab.columns)

for index, row in content_gab.iterrows():
    if index == 1:
        continue
    print(row['id'])
    print(row['text'])
    print(row['text_labels'])
    print(row['response'])
    print(row['response_labels'])
    break

ValueError: malformed node or string: nan

Encoding the labels

In [7]:
# label_encoder = LabelEncoder()
# content_gab['all_labels'] = content_gab['text_labels'] + content_gab['response_labels']
# content_gab['all_labels_encoded'] = content_gab['all_labels'].apply(label_encoder.fit_transform)
# print(content_gab.iloc[0])
# content_gab['text_labels_encoded'] = content_gab['text_labels'].apply(label_encoder.fit_transform)
# content_gab['response_labels_encoded'] = content_gab['response_labels'].apply(label_encoder.fit_transform)

Creating BERT encoding method

In [8]:
bert = SentenceTransformer('all-MiniLM-L6-v2')

def generate_embeddings(sentences):
    if isinstance(sentences, list):
        return bert.encode(sentences, show_progress_bar=True).tolist()
    elif isinstance(sentences, str):
        return bert.encode([sentences], show_progress_bar=True).tolist()
    return []

Generating BERT embeddings

In [9]:
content_gab = content_gab[:200]
before = time.time()
content_gab['text_embeddings'] = content_gab['text'].apply(generate_embeddings)
after_text = time.time()
print(content_gab.iloc[1]['text_embeddings'])
print('\nTIME FOR TEXT EMBEDDINGS: ', after_text - before)
print('\n- - - - - -\n')
content_gab['response_embeddings'] = content_gab['response'].apply(generate_embeddings)
after_response = time.time()
print(content_gab.iloc[2]['response_embeddings'])
print('\nTIME FOR RESPONSE EMBEDDINGS: ', after_response - after_text)
print('\n- - - - - -\n')

Batches: 100%|██████████| 1/1 [00:00<00:00,  4.55it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 29.89it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 29.29it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 24.89it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 42.71it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 100.03it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 49.66it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 38.30it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 44.74it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 22.32it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 77.25it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 90.92it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 32.28it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 52.35it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 70.60it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 68.76it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 109.13it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 64.70it/s]
Batches:

[[-0.023572057485580444, 0.01794440858066082, 0.0405656173825264, 0.06729649752378464, 0.09804009646177292, 0.035114046186208725, 0.06723955273628235, -0.07981331646442413, 0.012592255137860775, -0.06395190209150314, 0.014613417908549309, -0.028686098754405975, 0.06557455658912659, -0.05138685926795006, -0.1029239371418953, 0.015551798976957798, -0.06676264107227325, -0.0029045091941952705, -0.027871334925293922, 0.0603627972304821, -0.027235597372055054, 0.02632732316851616, 0.03128805756568909, 0.017424048855900764, 0.013893608935177326, -0.06276204437017441, -0.013789285905659199, -0.01572624407708645, -0.03531145676970482, -0.05476396158337593, 0.013463081791996956, -0.028276406228542328, -0.03120226040482521, -0.054336175322532654, -0.011610085144639015, -0.04129831865429878, 0.10698013752698898, -0.050249602645635605, -0.02981388196349144, 0.06209121271967888, -0.017547253519296646, -0.01501332875341177, 0.08790478855371475, 0.07822359353303909, -0.09159855544567108, 0.0481598377

Batches: 100%|██████████| 1/1 [00:00<00:00, 66.65it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 65.13it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 70.18it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 65.26it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 73.66it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 62.49it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 66.67it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 83.32it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 76.91it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 66.45it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 86.84it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 66.65it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 76.90it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 90.88it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 63.48it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 89.02it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 66.69it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 64.66it/s]
Batches: 1

[[0.09269341826438904, 0.0053440057672560215, -0.021850084885954857, -0.0007701154099777341, 0.04873264953494072, 0.028574705123901367, 0.06419847160577774, -0.009560780599713326, 0.053361181169748306, -0.0879056453704834, -0.014494847506284714, -0.0013367488281801343, 0.0982813909649849, -0.013169684447348118, 0.03653031960129738, 0.08743903785943985, 0.05540039762854576, 0.001535855932161212, -0.03637176379561424, 0.00894644483923912, -0.08317050337791443, 0.09316875785589218, -0.046798184514045715, -0.020233850926160812, -0.0499594546854496, -0.0417790450155735, -0.015589868649840355, 0.033974144607782364, -0.014390970580279827, 0.1086299866437912, 0.020122447982430458, -0.013854089193046093, 0.04655636101961136, 0.014802754856646061, 0.0009073888068087399, -0.0037359357811510563, 0.00721629848703742, 0.047422803938388824, 0.015439050272107124, 0.013636752963066101, 0.012506725266575813, 0.008980403654277325, -0.002120441058650613, -0.07505074888467789, -0.0013401504838839173, -0.00




Method for constructing graphs

In [10]:
def construct_graph(row):
    text_utterances = row['text_embeddings']
    response_utterances = row['response_embeddings']
    # text_utterances = row['text']
    # response_utterances = row['response']

    root = text_utterances[0]
    children = text_utterances[1:] + response_utterances
    num_nodes = len(children) +1

    # for t in text_utterances:
    #     print(t)
    # print()
    # for r in response_utterances:
    #     print(r)
    # print()
    # ids = [[0, i] for i in range(1, num_nodes)]
    # print(ids)
    # edge_index = torch.tensor(
    #     [[0]*num_nodes, list(range(1, num_nodes)
    # )], dtype=torch.long)
    edge_index = torch.tensor(
        [[0, i] for i in range(1, num_nodes)], dtype=torch.long
    ).t().contiguous()
    # edge_index = torch.tensor(
    #     [[0] * len(children), list(range(1, num_nodes))], dtype=torch.long
    # )
    

    # print(row['text_labels_encoded'])
    # print()
    # print(row['response_labels_encoded'])
    # print(type(row['text_labels_encoded']), row['text_labels_encoded'].shape, row['text_labels_encoded'])
    # print(type(row['response_labels_encoded']), row['response_labels_encoded'].shape, row['response_labels_encoded'])

    # ls = np.concatenate((row['text_labels_encoded'], row['response_labels_encoded']))
    ls = np.concatenate((row['text_labels'], row['response_labels'])).astype(int)
    
    print(ls)

    print(ls.shape)
    print(type(ls))
    print(type(ls[0]))
    
    labels = torch.tensor(ls, dtype=torch.int32)

    print(labels)

    node_features = torch.tensor([root] + children, dtype=torch.float)
    
    # print(node_features.shape)
    # print(edge_index.shape)
    # print(labels.shape)
    # print('sss')
    data = Data(x=node_features, edge_index=edge_index, y=labels)
    return data

Constructing graphs for all rows

In [11]:
graphs = []
for index, row in content_gab.iterrows():
    graphs.append(construct_graph(row))

print(graphs[0])
print('\n- - - - - -\n')
print(f"Number of nodes: {graphs[0].num_nodes}")
print(f"Number of edges: {graphs[0].num_edges}")

print()
print(len(graphs))
print('Graphs: ')
for i in range(0, 100):
    print(graphs[i])

[1 0 0 0]
(4,)
<class 'numpy.ndarray'>
<class 'numpy.int32'>
tensor([1, 0, 0, 0], dtype=torch.int32)
[0 0 1 0 0 0]
(6,)
<class 'numpy.ndarray'>
<class 'numpy.int32'>
tensor([0, 0, 1, 0, 0, 0], dtype=torch.int32)
[0 1 0 0 0]
(5,)
<class 'numpy.ndarray'>
<class 'numpy.int32'>
tensor([0, 1, 0, 0, 0], dtype=torch.int32)
[0 1 0 0 0]
(5,)
<class 'numpy.ndarray'>
<class 'numpy.int32'>
tensor([0, 1, 0, 0, 0], dtype=torch.int32)
[0 0 1 0 0 0]
(6,)
<class 'numpy.ndarray'>
<class 'numpy.int32'>
tensor([0, 0, 1, 0, 0, 0], dtype=torch.int32)
[1 0 0 0]
(4,)
<class 'numpy.ndarray'>
<class 'numpy.int32'>
tensor([1, 0, 0, 0], dtype=torch.int32)
[0 0 1 0 0 0]
(6,)
<class 'numpy.ndarray'>
<class 'numpy.int32'>
tensor([0, 0, 1, 0, 0, 0], dtype=torch.int32)
[0 1 0 0]
(4,)
<class 'numpy.ndarray'>
<class 'numpy.int32'>
tensor([0, 1, 0, 0], dtype=torch.int32)
[1 0 1 0 0 0]
(6,)
<class 'numpy.ndarray'>
<class 'numpy.int32'>
tensor([1, 0, 1, 0, 0, 0], dtype=torch.int32)
[0 0 1 0 0 0]
(6,)
<class 'numpy.ndarray'

Node level to graph level labels

In [12]:
# Example: Aggregating node-level labels into graph-level labels
#def convert_to_graph_level(dataset):
 #   new_dataset = []
 #   for data in dataset:
 #       # Example: Majority vote for classification
  #      #graph_label = data.y.mode()[0]  # Use the most frequent label
   #     graph_label = 1 if (data.y == 1).sum().item() > 0 else 0
    #    #data.y = graph_label.unsqueeze(0)  # Ensure shape [1]
     #   data.y = graph_label
      #  new_dataset.append(data)
    #return new_dataset

# Convert dataset
#new_dataset = convert_to_graph_level(graphs)
#for i in new_dataset:
 #   print(i.y)

Given percentage of shuffled dataset is test fold

In [13]:
random.shuffle(graphs)

size_train = len(graphs) - len(graphs) // 10 # 10% test dataset

train_dataset = graphs[:size_train]
test_dataset = graphs[size_train:]
print(len(train_dataset))
print(len(test_dataset))

180
20


Cross-validation 

In [14]:
#y = []
#for index, row in content_gab.iterrows():
#    y.append(np.concatenate((row['text_labels'], row['response_labels'])).astype(int))

#for i, q in enumerate(y):
#    print(q)
#    if i >= 5:
#       print('\n- - - -')
#       break
#for i, q in enumerate(graphs):
 #   print(q)
 #   if i >= 5:
 #      break

#rskf = RepeatedStratifiedKFold(n_splits=5, n_repeats=2, random_state=36851234)
#folds = list(rskf.split(graphs, y))

Mini-batching of graphs

In [15]:
from torch_geometric.loader import DataLoader

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

for step, data in enumerate(train_loader):
    print(f'Step {step + 1}:')
    print('=======')
    print(f'Number of graphs in the current batch: {data.num_graphs}')
    print(data)
    print()

Step 1:
Number of graphs in the current batch: 64
DataBatch(x=[360, 384], edge_index=[2, 296], y=[360], batch=[360], ptr=[65])

Step 2:
Number of graphs in the current batch: 64
DataBatch(x=[364, 384], edge_index=[2, 300], y=[364], batch=[364], ptr=[65])

Step 3:
Number of graphs in the current batch: 52
DataBatch(x=[272, 384], edge_index=[2, 220], y=[272], batch=[272], ptr=[53])



In [16]:
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
input_dim = graphs[0].x.shape[1]    # embedding dimensionality
data = graphs[0]

class GCN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        torch.manual_seed(1234567)
        self.conv1 = GCNConv(input_dim, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, 2)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return x

model = GCN(hidden_channels=16)
print(model)

GCN(
  (conv1): GCNConv(384, 16)
  (conv2): GCNConv(16, 2)
)


In [17]:

def visualize(h, color):
    z = TSNE(n_components=2).fit_transform(h.detach().cpu().numpy())

    plt.figure(figsize=(10,10))
    plt.xticks([])
    plt.yticks([])

    plt.scatter(z[:, 0], z[:, 1], s=70, c=color, cmap="Set2")
    plt.show()

model = GCN(hidden_channels=16)
model.eval()

out = model(data.x, data.edge_index)
#visualize(out, color=data.y)

In [None]:
from torch_geometric.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
print(train_dataset)

model = GCN(hidden_channels=16)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()

data = graphs[0]

def train():
    model.train()
    optimizer.zero_grad()  # Clear gradients.
    out = model(data.x, data.edge_index)  # Perform a single forward pass.
    loss = criterion(out, data.y.long())  # Compute the loss for the training dataset.
    loss.backward()  # Derive gradients.
    optimizer.step()  # Update parameters based on gradients.
    return loss

def test():
    model.eval()
    out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)  # Use the class with highest probability.
    test_correct = pred == data.y.long()  # Check against ground-truth labels.
    test_acc = int(test_correct.sum()) / len(data.y)  # Derive ratio of correct predictions.
    return test_acc

for epoch in range(1, 101):
    loss = train()  # Pass the training dataset to train function.
    test_acc = test()  # Pass the test dataset to test function.
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Test Accuracy: {test_acc:.4f}')


[Data(x=[5, 384], edge_index=[2, 4], y=[5]), Data(x=[4, 384], edge_index=[2, 3], y=[4]), Data(x=[6, 384], edge_index=[2, 5], y=[6]), Data(x=[12, 384], edge_index=[2, 11], y=[12]), Data(x=[5, 384], edge_index=[2, 4], y=[5]), Data(x=[5, 384], edge_index=[2, 4], y=[5]), Data(x=[6, 384], edge_index=[2, 5], y=[6]), Data(x=[4, 384], edge_index=[2, 3], y=[4]), Data(x=[5, 384], edge_index=[2, 4], y=[5]), Data(x=[9, 384], edge_index=[2, 8], y=[9]), Data(x=[6, 384], edge_index=[2, 5], y=[6]), Data(x=[15, 384], edge_index=[2, 14], y=[15]), Data(x=[3, 384], edge_index=[2, 2], y=[3]), Data(x=[4, 384], edge_index=[2, 3], y=[4]), Data(x=[4, 384], edge_index=[2, 3], y=[4]), Data(x=[3, 384], edge_index=[2, 2], y=[3]), Data(x=[4, 384], edge_index=[2, 3], y=[4]), Data(x=[5, 384], edge_index=[2, 4], y=[5]), Data(x=[6, 384], edge_index=[2, 5], y=[6]), Data(x=[4, 384], edge_index=[2, 3], y=[4]), Data(x=[9, 384], edge_index=[2, 8], y=[9]), Data(x=[6, 384], edge_index=[2, 5], y=[6]), Data(x=[2, 384], edge_ind



Epoch: 031, Loss: 0.3571, Test Accuracy: 0.8000
Epoch: 032, Loss: 0.3825, Test Accuracy: 0.8000
Epoch: 033, Loss: 0.4025, Test Accuracy: 0.8000
Epoch: 034, Loss: 0.4158, Test Accuracy: 0.8000
Epoch: 035, Loss: 0.3057, Test Accuracy: 0.8000
Epoch: 036, Loss: 0.2970, Test Accuracy: 0.8000
Epoch: 037, Loss: 0.2943, Test Accuracy: 0.8000
Epoch: 038, Loss: 0.2919, Test Accuracy: 0.8000
Epoch: 039, Loss: 0.3938, Test Accuracy: 0.8000
Epoch: 040, Loss: 0.4471, Test Accuracy: 0.8000
Epoch: 041, Loss: 0.4488, Test Accuracy: 0.8000
Epoch: 042, Loss: 0.3368, Test Accuracy: 0.8000
Epoch: 043, Loss: 0.3192, Test Accuracy: 0.8000
Epoch: 044, Loss: 0.2822, Test Accuracy: 0.8000
Epoch: 045, Loss: 0.1937, Test Accuracy: 1.0000
Epoch: 046, Loss: 0.3685, Test Accuracy: 1.0000
Epoch: 047, Loss: 0.2306, Test Accuracy: 1.0000
Epoch: 048, Loss: 0.4012, Test Accuracy: 1.0000
Epoch: 049, Loss: 0.4167, Test Accuracy: 0.8000
Epoch: 050, Loss: 0.3275, Test Accuracy: 0.8000
Epoch: 051, Loss: 0.4047, Test Accuracy:

In [165]:
test_acc = test()
print(f'Test Accuracy: {test_acc:.4f}')

Test Accuracy: 1.0000


GraphNN model class

In [161]:
class GraphNN(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(GraphNN, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.relu = nn.ReLU()

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = self.relu(x)
        x = self.conv2(x, edge_index)
        return x

Method for training the NN

In [166]:
def train(model, optimizer, criterion, data):
    for train_idx, test_idx in folds:
        train_graphs = [graphs[i] for i in train_idx]
        test_graphs = [graphs[i] for i in test_idx]

        model.train()
        for data in train_graphs:
            optimizer.zero_grad()
            out = model(data)
            loss = criterion(out, data.y)
            loss.backward()
            optimizer.step()

        model.eval()
        correct = 0
        total = 0
        for data in test_graphs:
            out = model(data)
            pred = out.argmax(dim=1)
            correct += int((pred == data.y).sum())
            total += len(pred)
    print(f"Accuracy: {correct/total}")

Train the NN

In [167]:
input_dim = graphs[0].x.shape[1]    # embedding dimensionality
hidden_dim = 64
output_dim = len(label_encoder.classes_)

model = GraphNN(input_dim, hidden_dim, output_dim)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

train(model, optimizer, criterion, graphs)

NameError: name 'label_encoder' is not defined