Importing libraries

In [48]:
import json
from collections import defaultdict, Counter
import ast
import time
import random

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE

import torch
# import torch_geometric as tg
from torch_geometric.data import Data
import torch.nn as nn
from torch_geometric.nn import GCNConv

from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score, confusion_matrix
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split, cross_val_score, RepeatedStratifiedKFold
from sklearn.preprocessing import LabelEncoder

from sentence_transformers import SentenceTransformer

import re
import ast


Pandas setup

In [49]:
pd.set_option('display.max_columns', None)  # Show all columns
pd.set_option('display.max_rows', 10)  # Limit number of rows displayed
pd.set_option('display.width', 1000)  # Set max width for table
pd.set_option('display.colheader_justify', 'center')  # Center-align column headers

Method for cleaning the data

In [50]:
def clean_special_chars(value):
    if isinstance(value, str):  
        return value.replace('\n', ' ').replace('\t', ' ').replace('\r', ' ').replace('  ', ' ').strip()
    return value 

Reading gab

In [51]:
content_gab = pd.read_csv('gab_reddit_benchmark/gab.csv')

content_gab["text"] = content_gab["text"].replace(to_replace=[None, np.nan, "", "nan", "n/a"], value="") 
content_gab["response"] = content_gab["response"].replace(to_replace=[None, np.nan, "", "nan", "n/a"], value="")
content_gab["hate_speech_idx"] = content_gab["hate_speech_idx"].replace(to_replace=[None, np.nan, "", "nan", "n/a"], value="") 

# content_gab["text"] = content_gab["text"].apply(clean_special_chars)
# content_gab["response"] = content_gab["response"].apply(clean_special_chars)

for index, row in content_gab.iterrows():
    row['text'] = row['text'].replace("'", '"')
    row['response'] = row['response'].replace("'", '"')

# content_gab = content_gab.applymap(clean_special_chars)
print(content_gab.head(n=10))
print('\n- - - - - -\n')
print(content_gab.columns)
print('\n- - - - - -\n')
print(content_gab.iloc[1]['id'])

                          id                                                text                        hate_speech_idx                      response                     
0                                    1. 39869714\r\n  1. i joined gab to remind myself how retarded ...         [1]      ["Using words that insult one group while defe...
1  1. 39845588\r\n2. \t39848775\r\n3. \t\t3991101...  1. This is what the left is really scared of. ...         [3]      ['You can disagree with someones opinion witho...
2                   1. 37485560\r\n2. \t37528625\r\n  1. It makes you an asshole.\r\n2. \tGive it to...         [2]      ['Your argument is more rational if you leave ...
3                   1. 39787626\r\n2. \t39794481\r\n  1. So they manage to provide a whole lot of da...         [2]      ["You shouldn't generalize a specific group or...
4  1. 37957930\r\n2. \t39953348\r\n3. \t\t3996521...  1. Hi there, i,m Keith, i hope you are doing w...         [3]      ['If someone is rude it 

Merge posts

In [52]:

def get_first_number(input_string):
    match = re.search(r'\d{2,}', input_string)
    if match:
        return int(match.group())
    return None

content_gab['extracted_id'] = content_gab['id'].apply(get_first_number)

# Find duplicate rows based on 'extracted_id'
duplicates = content_gab[content_gab.duplicated(subset=['extracted_id'], keep=False)]
filtered_groups = []
grouped = content_gab.groupby('extracted_id')
for key, group in grouped:
    if len(group) > 1:
        filtered_groups.append(group)

merged_df = pd.concat(filtered_groups, ignore_index=True)

merged_df.to_csv('gab_reddit_benchmark/gab_groups.csv', index=False)

merged_df = grouped.agg({
    'id': ' '.join,
    'text': ' '.join,
    'hate_speech_idx': ' '.join,
    'response': ' '.join
}).reset_index()
merged_df.to_csv('gab_reddit_benchmark/gab_merged.csv')

df = pd.read_csv('gab_reddit_benchmark/gab_merged.csv')
df = df.applymap(lambda x: x.replace('] [', ', ') if isinstance(x, str) else x)
df = df.applymap(lambda x: x.replace(']  [', ', ') if isinstance(x, str) else x)
df = df.applymap(lambda x: 'n/a' if isinstance(x, str) and x.strip() == '' else x)
#df = df.applymap(lambda x: x.replace(' ', ',n/a') if isinstance(x, str) else x)
df.to_csv('gab_reddit_benchmark/gab_merged.csv', index=False)

In [53]:
content_gab = pd.read_csv('gab_reddit_benchmark/gab_merged.csv')
content_gab = content_gab.drop('Unnamed: 0', axis=1)
content_gab = content_gab.drop('extracted_id', axis=1)


content_gab["text"] = content_gab["text"].replace(to_replace=[None, np.nan, "", "nan", "n/a"], value="") 
content_gab["response"] = content_gab["response"].replace(to_replace=[None, np.nan, "", "nan", "n/a"], value="")
content_gab["hate_speech_idx"] = content_gab["hate_speech_idx"].replace(to_replace=[None, np.nan, "", "nan", "n/a"], value="") 

for index, row in content_gab.iterrows():
    row['text'] = row['text'].replace("'", '"')
    row['response'] = row['response'].replace("'", '"')

print(content_gab.head(n=10))
print('\n- - - - - -\n')
print(content_gab.columns)
print('\n- - - - - -\n')
print(content_gab.iloc[0]['id'])

                          id                                                text                        hate_speech_idx                      response                     
0  1. 37011807\r\n2. \t37012801\r\n3. \t\t3701333...  1. 64% of Hispanics have IQ scores too low to ...     [2, 2, 3]    ['Racist and hate speech will not be tolerated...
1                    1. 37012081\r\n 1. 37012081\r\n  1. LYING CUNT BITCH !!\r\n 1. LYING CUNT BITCH...        [1, 1]    ['Please refrain from using derogatory terms f...
2                                    1. 37012105\r\n  1. Nigger Kills 2 cops https://www.google.com/...           [1]    ['The N word is highly offensive.   All ethnic...
3                   1. 37012146\r\n2. \t37013020\r\n  1. STABBED IN NECK   arrested but no descr   #...           [2]    ['Use of the n-word is unacceptable in our dis...
4  1. 37012285\r\n2. \t37013809\r\n3. \t37013868\r\n  1. Retard 1st Class  w/ Cluster & Ribbon\r\n2....           [1]    ['We should try and not 

In [54]:
def mark_text_labels(text_utterances_length, labels):
    if not labels:
        # return ['other'] * text_utterances_length
        return [0] * text_utterances_length
    new_labels = []
    int_list = ast.literal_eval(labels)
    for i in range(text_utterances_length):
        if i+1 in int_list:
            # new_labels.append('hate_speech')
            new_labels.append(1)
        else:
            # new_labels.append('other')
            new_labels.append(0)
    return new_labels

Splitting 'text' and 'response' into individual rows, so that I can construct a graph from it

In [55]:
text_column = []
text_labels_column = []
response_column = []
response_labels_column = []

for index, row in content_gab.iterrows():
    text_utterances = row['text'].split('\n')
    text_utterances = list(filter(None, text_utterances))

    for i, t in enumerate(text_utterances):
        text_utterances[i] = clean_special_chars(t)
    text_labels = mark_text_labels(len(text_utterances), row['hate_speech_idx'])

    response_utterances = ast.literal_eval(row['response']) if row['response'] else []
    for i, r in enumerate(response_utterances):
        response_utterances[i] = clean_special_chars(r)
    # response_labels = ['other'] * len(response_utterances)  
    response_labels = [0] * len(response_utterances)  

    
    text_column.append(text_utterances)
    text_labels_column.append(text_labels)
    response_column.append(response_utterances)
    response_labels_column.append(response_labels)

content_gab['text'] = text_column
content_gab['hate_speech_idx'] = text_labels_column
content_gab['response'] = response_column
content_gab['response_labels'] = response_labels_column

content_gab = content_gab.rename(columns={'hate_speech_idx': 'text_labels'})
print(content_gab.head())
print('- - - - ')
print(content_gab.columns)

for index, row in content_gab.iterrows():
    if index == 1:
        continue
    print(row['id'])
    print(row['text'])
    print(row['text_labels'])
    print(row['response'])
    print(row['response_labels'])
    break

                          id                                                text                               text_labels                             response                        response_labels  
0  1. 37011807\r\n2. \t37012801\r\n3. \t\t3701333...  [1. 64% of Hispanics have IQ scores too low to...  [0, 1, 1, 0, 0, 0, 0, 0]  [Racist and hate speech will not be tolerated ...     [0, 0, 0, 0, 0]
1                    1. 37012081\r\n 1. 37012081\r\n   [1. LYING CUNT BITCH !!, 1. LYING CUNT BITCH !!]                    [1, 0]  [Please refrain from using derogatory terms fo...  [0, 0, 0, 0, 0, 0]
2                                    1. 37012105\r\n  [1. Nigger Kills 2 cops https://www.google.com...                       [1]  [The N word is highly offensive.  All ethnicit...           [0, 0, 0]
3                   1. 37012146\r\n2. \t37013020\r\n  [1. STABBED IN NECK  arrested but no descr  #D...                    [0, 1]  [Use of the n-word is unacceptable in our disc...              [0

Encoding the labels

In [56]:
# label_encoder = LabelEncoder()
# content_gab['all_labels'] = content_gab['text_labels'] + content_gab['response_labels']
# content_gab['all_labels_encoded'] = content_gab['all_labels'].apply(label_encoder.fit_transform)
# print(content_gab.iloc[0])
# content_gab['text_labels_encoded'] = content_gab['text_labels'].apply(label_encoder.fit_transform)
# content_gab['response_labels_encoded'] = content_gab['response_labels'].apply(label_encoder.fit_transform)

Creating BERT encoding method

In [57]:
bert = SentenceTransformer('all-MiniLM-L6-v2')

def generate_embeddings(sentences):
    if isinstance(sentences, list):
        return bert.encode(sentences, show_progress_bar=True).tolist()
    elif isinstance(sentences, str):
        return bert.encode([sentences], show_progress_bar=True).tolist()
    return []

Generating BERT embeddings

In [58]:
content_gab = content_gab[:200]
before = time.time()
content_gab['text_embeddings'] = content_gab['text'].apply(generate_embeddings)
after_text = time.time()
print(content_gab.iloc[1]['text_embeddings'])
print('\nTIME FOR TEXT EMBEDDINGS: ', after_text - before)
print('\n- - - - - -\n')
content_gab['response_embeddings'] = content_gab['response'].apply(generate_embeddings)
after_response = time.time()
print(content_gab.iloc[2]['response_embeddings'])
print('\nTIME FOR RESPONSE EMBEDDINGS: ', after_response - after_text)
print('\n- - - - - -\n')

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches: 100%|██████████| 1/1 [00:00<00:00,  6.03it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 63.96it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 72.73it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 21.55it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 52.36it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 55.20it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 14.82it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 28.35it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 105.34it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 27.19it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 326.15it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 10.50it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 50.75it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 34.74it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 42.58it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 94.53it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 22.29it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 12.37it/s]
Batches:

[[-0.040111515671014786, -0.019980860874056816, 0.013976861722767353, 0.021806733682751656, -0.013460909947752953, 0.029162606224417686, 0.12978620827198029, -0.03199726715683937, 0.013665660284459591, 0.009298085235059261, 0.027889572083950043, -0.12927503883838654, 0.057423628866672516, -0.03645731136202812, -0.08924495428800583, 0.003100527450442314, -0.08778698742389679, -0.01677405834197998, 0.004921083338558674, 0.01669718511402607, 0.058399613946676254, 0.0006712392787449062, 0.05522868037223816, 0.06297323107719421, -0.0073430766351521015, -0.045072540640830994, 0.01580304279923439, -0.00940280593931675, -0.035581137984991074, -0.043321527540683746, -0.0003930269740521908, 0.08144824951887131, -0.0247819721698761, 0.026637855917215347, 0.01791100762784481, -0.07403568178415298, 0.043707575649023056, -0.005354813765734434, -0.03749185428023338, 0.019640570506453514, 0.0009533570264466107, -0.08591044694185257, 0.0563356839120388, 0.015451686456799507, 0.012978585436940193, 0.013

Batches: 100%|██████████| 1/1 [00:00<00:00, 27.43it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 28.49it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 48.46it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 56.63it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 57.03it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 50.48it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 35.30it/s]
Batches: 0it [00:00, ?it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 41.05it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 35.94it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 33.41it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 86.33it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 49.62it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 48.57it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 42.34it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 43.18it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 36.86it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 47.15it/s]
Batches: 100%|██████████| 1/1 [00:00<

[[0.07439517229795456, 0.04188426211476326, -0.09229229390621185, -0.031470734626054764, -0.052862267941236496, -0.020004451274871826, 0.03124356083571911, -0.08033803105354309, -0.010242228396236897, -0.038911741226911545, 0.048023175448179245, -0.05908317118883133, 0.03346498683094978, 0.0034259618259966373, -0.054501939564943314, 0.05408971756696701, 0.013149838894605637, -0.029767796397209167, -0.045929502695798874, -0.0669841393828392, -0.07561828196048737, 0.09951933473348618, 0.14216117560863495, 0.03618994355201721, -0.07632524520158768, -0.05548761039972305, 0.042400769889354706, 0.006625581532716751, 0.03837255761027336, 0.021637508645653725, -0.01584904082119465, 0.017154689878225327, 0.06257234513759613, 0.06082208827137947, -0.07147544622421265, -0.05742684006690979, 0.06866130977869034, 0.0014725130749866366, 0.006498364731669426, -0.04267704859375954, -0.047141071408987045, 0.010848326608538628, -0.016492504626512527, -0.03367657959461212, 0.08141005039215088, 0.05338096

Method for constructing graphs

In [59]:
def construct_graph(row):
    text_utterances = row['text_embeddings']
    response_utterances = row['response_embeddings']
    # text_utterances = row['text']
    # response_utterances = row['response']

    root = text_utterances[0]
    children = text_utterances[1:] + response_utterances
    num_nodes = len(children) +1

    # for t in text_utterances:
    #     print(t)
    # print()
    # for r in response_utterances:
    #     print(r)
    # print()
    # ids = [[0, i] for i in range(1, num_nodes)]
    # print(ids)
    # edge_index = torch.tensor(
    #     [[0]*num_nodes, list(range(1, num_nodes)
    # )], dtype=torch.long)
    edge_index = torch.tensor(
        [[0, i] for i in range(1, num_nodes)], dtype=torch.long
    ).t().contiguous()
    # edge_index = torch.tensor(
    #     [[0] * len(children), list(range(1, num_nodes))], dtype=torch.long
    # )
    

    # print(row['text_labels_encoded'])
    # print()
    # print(row['response_labels_encoded'])
    # print(type(row['text_labels_encoded']), row['text_labels_encoded'].shape, row['text_labels_encoded'])
    # print(type(row['response_labels_encoded']), row['response_labels_encoded'].shape, row['response_labels_encoded'])

    # ls = np.concatenate((row['text_labels_encoded'], row['response_labels_encoded']))
    ls = np.concatenate((row['text_labels'], row['response_labels'])).astype(int)
    
    print(ls)

    print(ls.shape)
    print(type(ls))
    print(type(ls[0]))
    
    labels = torch.tensor(ls, dtype=torch.int32)

    print(labels)

    node_features = torch.tensor([root] + children, dtype=torch.float)
    
    # print(node_features.shape)
    # print(edge_index.shape)
    # print(labels.shape)
    # print('sss')
    data = Data(x=node_features, edge_index=edge_index, y=labels)
    return data

Constructing graphs for all rows

In [60]:
graphs = []
for index, row in content_gab.iterrows():
    graphs.append(construct_graph(row))

print(graphs[0])
print('\n- - - - - -\n')
print(f"Number of nodes: {graphs[0].num_nodes}")
print(f"Number of edges: {graphs[0].num_edges}")

print()
print(len(graphs))
print('Graphs: ')
for i in range(0, 100):
    print(graphs[i])

[0 1 1 0 0 0 0 0 0 0 0 0 0]
(13,)
<class 'numpy.ndarray'>
<class 'numpy.int32'>
tensor([0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=torch.int32)
[1 0 0 0 0 0 0 0]
(8,)
<class 'numpy.ndarray'>
<class 'numpy.int32'>
tensor([1, 0, 0, 0, 0, 0, 0, 0], dtype=torch.int32)
[1 0 0 0]
(4,)
<class 'numpy.ndarray'>
<class 'numpy.int32'>
tensor([1, 0, 0, 0], dtype=torch.int32)
[0 1 0 0]
(4,)
<class 'numpy.ndarray'>
<class 'numpy.int32'>
tensor([0, 1, 0, 0], dtype=torch.int32)
[1 0 0 0 0 0]
(6,)
<class 'numpy.ndarray'>
<class 'numpy.int32'>
tensor([1, 0, 0, 0, 0, 0], dtype=torch.int32)
[1 0 0]
(3,)
<class 'numpy.ndarray'>
<class 'numpy.int32'>
tensor([1, 0, 0], dtype=torch.int32)
[0 0 1 1 0 0 0]
(7,)
<class 'numpy.ndarray'>
<class 'numpy.int32'>
tensor([0, 0, 1, 1, 0, 0, 0], dtype=torch.int32)
[0]
(1,)
<class 'numpy.ndarray'>
<class 'numpy.int32'>
tensor([0], dtype=torch.int32)
[1 0 0 0]
(4,)
<class 'numpy.ndarray'>
<class 'numpy.int32'>
tensor([1, 0, 0, 0], dtype=torch.int32)
[1 0 0 0]
(4,)
<clas

Merge to one graph

In [61]:

# Initialize empty lists to store the merged node features, edge indices, and labels (y)
merged_x = []
merged_edge_index = []
merged_y = []

# Keep track of the offset for node indices in subsequent graphs
node_offset = 0

# Iterate over each graph in the list
for graph in graphs:
    # Concatenate node features
    merged_x.append(graph.x)
    
    # Adjust edge indices: add the current node_offset to the second row of edge_index
    merged_edge_index.append(graph.edge_index + node_offset)
    
    # Concatenate labels (y), the target labels from each graph
    merged_y.append(graph.y)
    
    # Update node_offset for the next graph
    node_offset += graph.x.size(0)

# Concatenate all node features, edge indices, and labels
merged_x = torch.cat(merged_x, dim=0)
merged_edge_index = torch.cat(merged_edge_index, dim=1)
merged_y = torch.cat(merged_y, dim=0)

# Create a new graph with merged node features, edge indices, and labels (y)
merged_graph = Data(x=merged_x, edge_index=merged_edge_index, y=merged_y)

print(merged_graph)
# Print the merged graph details
print("Merged Node Features:")
print(merged_graph.x)
print("Merged Edge Index:")
print(merged_graph.edge_index)


Data(x=[1198, 384], edge_index=[2, 998], y=[1198])
Merged Node Features:
tensor([[ 0.0813, -0.0214, -0.0582,  ...,  0.0492,  0.0206,  0.0012],
        [-0.0027,  0.0063, -0.0170,  ..., -0.0939, -0.0490,  0.0221],
        [ 0.0185, -0.0804,  0.0782,  ..., -0.0679, -0.0118,  0.0497],
        ...,
        [ 0.0485,  0.0435, -0.0563,  ..., -0.0356,  0.0521, -0.0419],
        [ 0.0144,  0.0081, -0.0431,  ..., -0.0029,  0.0659, -0.0579],
        [ 0.1278,  0.0149, -0.0109,  ...,  0.0264,  0.0013, -0.0679]])
Merged Edge Index:
tensor([[   0,    0,    0,  ..., 1193, 1193, 1193],
        [   1,    2,    3,  ..., 1195, 1196, 1197]])


Node level to graph level labels

In [62]:
# Example: Aggregating node-level labels into graph-level labels
#def convert_to_graph_level(dataset):
 #   new_dataset = []
 #   for data in dataset:
 #       # Example: Majority vote for classification
  #      #graph_label = data.y.mode()[0]  # Use the most frequent label
   #     graph_label = 1 if (data.y == 1).sum().item() > 0 else 0
    #    #data.y = graph_label.unsqueeze(0)  # Ensure shape [1]
     #   data.y = graph_label
      #  new_dataset.append(data)
    #return new_dataset

# Convert dataset
#new_dataset = convert_to_graph_level(graphs)
#for i in new_dataset:
 #   print(i.y)

Given percentage of shuffled dataset is test fold

In [63]:
random.shuffle(graphs)

size_train = len(graphs) - len(graphs) // 10 # 10% test dataset

train_dataset = graphs[:size_train]
test_dataset = graphs[size_train:]
print(len(train_dataset))
print(len(test_dataset))

180
20


Cross-validation 

In [64]:
#y = []
#for index, row in content_gab.iterrows():
#    y.append(np.concatenate((row['text_labels'], row['response_labels'])).astype(int))

#for i, q in enumerate(y):
#    print(q)
#    if i >= 5:
#       print('\n- - - -')
#       break
#for i, q in enumerate(graphs):
 #   print(q)
 #   if i >= 5:
 #      break

#rskf = RepeatedStratifiedKFold(n_splits=5, n_repeats=2, random_state=36851234)
#folds = list(rskf.split(graphs, y))

Mini-batching of graphs

In [65]:
from torch_geometric.loader import DataLoader

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

for step, data in enumerate(train_loader):
    print(f'Step {step + 1}:')
    print('=======')
    print(f'Number of graphs in the current batch: {data.num_graphs}')
    print(data)
    print()

Step 1:
Number of graphs in the current batch: 64
DataBatch(x=[411, 384], edge_index=[2, 347], y=[411], batch=[411], ptr=[65])

Step 2:
Number of graphs in the current batch: 64
DataBatch(x=[368, 384], edge_index=[2, 304], y=[368], batch=[368], ptr=[65])

Step 3:
Number of graphs in the current batch: 52
DataBatch(x=[317, 384], edge_index=[2, 265], y=[317], batch=[317], ptr=[53])



Node classification

In [66]:
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
input_dim = merged_graph.x.shape[1]    # embedding dimensionality
data = merged_graph

class GCN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        torch.manual_seed(1234567)
        self.conv1 = GCNConv(input_dim, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, 2)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return x

model = GCN(hidden_channels=16)
print(model)

GCN(
  (conv1): GCNConv(384, 16)
  (conv2): GCNConv(16, 2)
)


In [67]:

def visualize(h, color):
    z = TSNE(n_components=2).fit_transform(h.detach().cpu().numpy())

    plt.figure(figsize=(10,10))
    plt.xticks([])
    plt.yticks([])

    plt.scatter(z[:, 0], z[:, 1], s=70, c=color, cmap="Set2")
    plt.show()

model = GCN(hidden_channels=16)
model.eval()

out = model(data.x, data.edge_index)
#visualize(out, color=data.y)

In [68]:
from torch_geometric.data import DataLoader

model = GCN(hidden_channels=16)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()

data = merged_graph

def train():
    model.train()
    optimizer.zero_grad()  # Clear gradients.
    out = model(data.x, data.edge_index)  # Perform a single forward pass.
    loss = criterion(out, data.y.long())  # Compute the loss for the training dataset.
    loss.backward()  # Derive gradients.
    optimizer.step()  # Update parameters based on gradients.
    return loss

def test():
    model.eval()
    out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)  # Use the class with highest probability.
    test_correct = pred == data.y.long()  # Check against ground-truth labels.
    test_acc = int(test_correct.sum()) / len(data.y)  # Derive ratio of correct predictions.
    return test_acc

for epoch in range(1, 101):
    loss = train()  # Pass the training dataset to train function.
    test_acc = test()  # Pass the test dataset to test function.
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Test Accuracy: {test_acc:.4f}')


Epoch: 001, Loss: 0.6894, Test Accuracy: 0.7922
Epoch: 002, Loss: 0.6319, Test Accuracy: 0.7922
Epoch: 003, Loss: 0.5803, Test Accuracy: 0.7922
Epoch: 004, Loss: 0.5405, Test Accuracy: 0.7922
Epoch: 005, Loss: 0.5318, Test Accuracy: 0.7922
Epoch: 006, Loss: 0.5238, Test Accuracy: 0.7922
Epoch: 007, Loss: 0.5417, Test Accuracy: 0.7922
Epoch: 008, Loss: 0.5459, Test Accuracy: 0.7922
Epoch: 009, Loss: 0.5421, Test Accuracy: 0.7922
Epoch: 010, Loss: 0.5371, Test Accuracy: 0.7922
Epoch: 011, Loss: 0.5251, Test Accuracy: 0.7922
Epoch: 012, Loss: 0.5132, Test Accuracy: 0.7922
Epoch: 013, Loss: 0.5104, Test Accuracy: 0.7922
Epoch: 014, Loss: 0.5037, Test Accuracy: 0.7922
Epoch: 015, Loss: 0.4979, Test Accuracy: 0.7922
Epoch: 016, Loss: 0.5026, Test Accuracy: 0.7922
Epoch: 017, Loss: 0.4963, Test Accuracy: 0.7922
Epoch: 018, Loss: 0.5015, Test Accuracy: 0.7922
Epoch: 019, Loss: 0.4988, Test Accuracy: 0.7922
Epoch: 020, Loss: 0.4973, Test Accuracy: 0.7922
Epoch: 021, Loss: 0.5009, Test Accuracy:

In [69]:
test_acc = test()
print(f'Test Accuracy: {test_acc:.4f}')

Test Accuracy: 0.7922


GraphNN model class

In [70]:
class GraphNN(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(GraphNN, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.relu = nn.ReLU()

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = self.relu(x)
        x = self.conv2(x, edge_index)
        return x

Method for training the NN

In [71]:
def train(model, optimizer, criterion, data):
    for train_idx, test_idx in folds:
        train_graphs = [graphs[i] for i in train_idx]
        test_graphs = [graphs[i] for i in test_idx]

        model.train()
        for data in train_graphs:
            optimizer.zero_grad()
            out = model(data)
            loss = criterion(out, data.y)
            loss.backward()
            optimizer.step()

        model.eval()
        correct = 0
        total = 0
        for data in test_graphs:
            out = model(data)
            pred = out.argmax(dim=1)
            correct += int((pred == data.y).sum())
            total += len(pred)
    print(f"Accuracy: {correct/total}")

Train the NN

In [72]:
input_dim = graphs[0].x.shape[1]    # embedding dimensionality
hidden_dim = 64
output_dim = len(label_encoder.classes_)

model = GraphNN(input_dim, hidden_dim, output_dim)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

train(model, optimizer, criterion, graphs)

NameError: name 'label_encoder' is not defined