In [None]:
!wget https://nlp.stanford.edu/data/glove.6B.zip
!unzip glove.6B.zip
from gensim.models import KeyedVectors
word_vectors = KeyedVectors.load_word2vec_format('glove.6B.50d.txt', binary=False, no_header=True)

In [None]:
# NOTE(review): unpinned install. Prefer `%pip install torch==X torch-geometric==Y`
# so the packages land in the kernel's environment and re-runs get the same builds.
!pip install torch torch-geometric

import pandas as pd
import os
from collections import defaultdict
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch.nn import Linear
from torch_geometric.nn import GCNConv, global_mean_pool
from torch.optim import Adam
from torch.nn.functional import cross_entropy
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, f1_score, precision_score, recall_score, roc_auc_score, r2_score, mean_squared_error

# Seed torch's global RNG so runs are reproducible.
torch.manual_seed(42)

In [111]:
def load_graph_info(nodes_degree_txt, nodes_edges_txt):
    """Parse one transcript graph from its two text exports.

    Args:
        nodes_degree_txt: path to the node-degree file. Apart from blank lines
            and a header line starting with "Node Degree", each line is
            "<node_id> <degree>"; surrounding double quotes on the degree are
            stripped.
        nodes_edges_txt: path to the node/edge listing. Lines after a "*Nodes"
            marker are "<node_id> "<label>" ..."; lines after a
            "*DirectedEdges" marker are "<source_id> <target_id> ...". Blank
            lines and column-header lines ("id*int...", "source*int...") are
            skipped.

    Returns:
        (nodes_features, edges):
        nodes_features -- defaultdict mapping int node id to a dict with
            'label' (str) and, if listed in the degree file, 'degree' (str).
            Ids that appear only in the degree file get no 'label'.
        edges -- list of (source, target) int pairs in file order.
    """
    nodes_features = defaultdict(dict)
    edges = []

    with open(nodes_edges_txt, 'r') as nodes_edges_file:
        reading_nodes, reading_edges = False, False
        for line in nodes_edges_file:
            line = line.strip()
            if (not line) or line.startswith("id*int") or line.startswith("source*int"):
                continue

            if line.startswith("*Nodes"):
                # start reading node info
                reading_nodes, reading_edges = True, False
            elif line.startswith("*DirectedEdges"):
                # start reading edge info
                reading_nodes, reading_edges = False, True
            elif reading_nodes:
                # maxsplit=2: id, quoted label, then any remaining columns
                parts = line.split(" ", 2)
                node_id = int(parts[0])
                nodes_features[node_id]['label'] = parts[1].strip('"')
            elif reading_edges:
                parts = line.split()
                edges.append((int(parts[0]), int(parts[1])))

    with open(nodes_degree_txt, 'r') as nodes_degree_file:
        for line in nodes_degree_file:
            line = line.strip()
            # Skip blank lines and the "Node Degree" header. The previous test
            # `(not line) or (not header)` sent blank lines into the parser,
            # where parts[0] raised IndexError.
            if not line or line.startswith("Node Degree"):
                continue
            parts = line.split()
            nodes_features[int(parts[0])]['degree'] = parts[1].strip('"')

    return nodes_features, edges


def get_word_embedding(word, embedding_dim=50):
    """Look up the pre-trained vector for a node label.

    Args:
        word: token to embed (a node label).
        embedding_dim: length of the zero vector returned for out-of-vocabulary
            words; keep it equal to the loaded vectors' size (50 for
            glove.6B.50d) so all node features stack cleanly.

    Returns:
        1-D float tensor: the vector for `word` from the global `word_vectors`,
        falling back to `word.lower()`. If neither is in the vocabulary, a
        warning is printed and zeros(embedding_dim) is returned.
    """
    # The GloVe 6B vocabulary is uncased (lowercase). Without the lowercase
    # retry, capitalised labels such as "The" or "I" would become zero vectors.
    # dict.fromkeys de-duplicates while preserving order, so an already-lowercase
    # word is only looked up once.
    for candidate in dict.fromkeys((word, word.lower())):
        try:
            return torch.tensor(word_vectors[candidate], dtype=torch.float)
        except KeyError:
            continue
    print('WARNING: word not in the vocabulary: ', word)
    return torch.zeros(embedding_dim)


def prep_data_for_gnn(nodes_features, edges):
    """Convert parsed graph info into the tensors PyG expects.

    Args:
        nodes_features: mapping node id -> dict with at least a 'label' key
            (as produced by load_graph_info).
        edges: iterable of (source_id, target_id) pairs using the same ids.

    Returns:
        x: float tensor, one row per node (in nodes_features iteration order)
           holding that node label's word embedding.
        edge_index: long tensor of shape (2, num_edges) with row indices into
           x; shape (2, 0) when the graph has no edges.

    Raises:
        KeyError: if a node lacks 'label' or an edge references an unknown id.
        RuntimeError: from torch.stack when nodes_features is empty.
    """
    # PyG addresses nodes by row position in x, so remap file node ids to rows.
    node_id_to_index = {node_id: i for i, node_id in enumerate(nodes_features)}

    # Node features are the label's word embedding only; the 'degree' parsed by
    # load_graph_info is not used here.
    x = torch.stack([get_word_embedding(feats['label']) for feats in nodes_features.values()])

    index_pairs = [[node_id_to_index[src], node_id_to_index[dst]] for src, dst in edges]
    if index_pairs:
        edge_index = torch.tensor(index_pairs, dtype=torch.long).t().contiguous()
    else:
        # torch.tensor([]).t() would be 1-D with shape (0,). Message passing and
        # batching need a 2 x 0 index tensor for an edgeless graph.
        edge_index = torch.empty((2, 0), dtype=torch.long)
    return x, edge_index


def get_dataloader(id_label_list, transcript_folder, batch_size=None, shuffle=True):
  """Build a PyG DataLoader over the transcript graphs in id_label_list.

  Args:
    id_label_list: iterable of (transcript_id, diagnosis) pairs. Each diagnosis
      must be a key of the module-level LABEL_MAPPING, which must exist before
      this function is called.
    transcript_folder: directory holding "<id>_dir_nodes_degree.txt" and
      "<id>_dir_nodes_edges.txt" for every transcript id.
    batch_size: graphs per batch; when falsy (None or 0) the whole split is a
      single batch.
    shuffle: reshuffle graphs each epoch (default True, as before). Pass False
      for a deterministic evaluation order.

  Returns:
    A torch_geometric DataLoader (not a list) yielding batched Data objects with
    x, edge_index and the integer class id y.
  """
  data_list = []
  for transcript_id, label in id_label_list:
    filepath = os.path.join(transcript_folder, transcript_id)
    nodes_features, edges = load_graph_info(f'{filepath}_dir_nodes_degree.txt', f'{filepath}_dir_nodes_edges.txt')
    x, edge_index = prep_data_for_gnn(nodes_features, edges)
    data_list.append(Data(x=x, edge_index=edge_index, y=LABEL_MAPPING[label]))

  if not batch_size:
    # Full-batch by default. Clamp to 1 so an empty split yields an empty
    # loader instead of DataLoader rejecting batch_size=0.
    batch_size = max(len(data_list), 1)
  return DataLoader(data_list, batch_size=batch_size, shuffle=shuffle)


def calc_metrics(actual_labels, pred_vals):
  """Compute macro-averaged classification metrics for one evaluation pass.

  Args:
    actual_labels: ground-truth class ids (torch tensor, numpy array or list).
    pred_vals: predicted class ids, aligned element-wise with actual_labels.

  Returns:
    dict with 'accuracy', macro-averaged 'f1', 'precision' and 'recall', and
    'confusion_matrix' (sklearn layout: rows = true class, columns = predicted).
  """
  def _to_host(values):
    # sklearn cannot read CUDA tensors; bring torch tensors to CPU numpy first.
    if hasattr(values, 'detach'):
      return values.detach().cpu().numpy()
    return values

  actual_labels = _to_host(actual_labels)
  pred_vals = _to_host(pred_vals)

  results = {}
  # Use the arguments. The previous version read the notebook globals
  # `pred` / `batch.y`, so it only worked by accident inside the eval loop.
  results['accuracy'] = accuracy_score(actual_labels, pred_vals)
  results['f1'] = f1_score(actual_labels, pred_vals, average='macro')
  results['precision'] = precision_score(actual_labels, pred_vals, average='macro')
  results['recall'] = recall_score(actual_labels, pred_vals, average='macro')
  results['confusion_matrix'] = confusion_matrix(actual_labels, pred_vals)
  return results

In [113]:
# Diagnosis string -> integer class id used as each graph's target y.
LABEL_MAPPING = {'HC': 0, 'MCI': 1, 'Dementia': 2}

# One metadata row per transcript; column 'Tr/Tt/Dv' assigns it to a split.
metadata = pd.read_csv("PROCESS_METADATA_ALL.csv")
df_train, df_dev, df_test = (
    metadata[metadata['Tr/Tt/Dv'] == split_name]
    for split_name in ('train', 'dev', 'test')
)

# (transcript id, diagnosis) pairs consumed by get_dataloader.
train_id_label_list = list(zip(df_train['anyon_IDs'], df_train['diagnosis']))
dev_id_label_list = list(zip(df_dev['anyon_IDs'], df_dev['diagnosis']))
test_id_label_list = list(zip(df_test['anyon_IDs'], df_test['diagnosis']))

transcript_folder = "SFT_transcripts_2"
train_dataloader, dev_dataloader, test_dataloader = (
    get_dataloader(pairs, transcript_folder)
    for pairs in (train_id_label_list, dev_id_label_list, test_id_label_list)
)



In [118]:
class GNNClassifier(torch.nn.Module):
    """Two-layer GCN graph classifier producing one logit vector per graph.

    Pipeline: GCNConv -> ReLU -> (optional dropout) -> GCNConv -> mean pooling
    over each graph's nodes -> (optional dropout) -> Linear.

    Args:
        input_dim: size of each node feature vector (data.x.size(1)).
        hidden_dim: width of both GCN layers.
        output_dim: number of classes (width of the returned logits).
        dropout: dropout probability used during training. The default 0.0
            skips dropout entirely, reproducing the original model exactly.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.0):
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.fc = Linear(hidden_dim, output_dim)
        self.dropout = dropout

    def forward(self, data):
        """Return unnormalised class scores, one row per graph in `data`."""
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = self.conv1(x, edge_index).relu()
        x = self._maybe_dropout(x)
        x = self.conv2(x, edge_index)
        # `batch` assigns each node to its graph; average node vectors per graph.
        x = global_mean_pool(x, batch)
        x = self._maybe_dropout(x)
        return self.fc(x)

    def _maybe_dropout(self, x):
        # Only call dropout when enabled, so the default draws no random numbers
        # and keeps seeded runs identical to the no-dropout model.
        if self.dropout > 0:
            x = torch.nn.functional.dropout(x, p=self.dropout, training=self.training)
        return x

# Re-seed right before building the model so weight init and shuffling are reproducible.
torch.manual_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

NUM_EPOCHS = 350
LOG_EVERY = 50

num_node_features = next(iter(train_dataloader)).x.size(1)
num_classes = len(LABEL_MAPPING)
model = GNNClassifier(input_dim=num_node_features, hidden_dim=64, output_dim=num_classes).to(device)
optimizer = Adam(model.parameters(), lr=0.01)

# Train
# NOTE(review): the logged loss falls to ~3e-4 while dev accuracy stays ~0.4,
# i.e. the model overfits. Consider dropout and/or early stopping on dev.
for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0.0

    for batch in train_dataloader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch)
        loss = cross_entropy(out, batch.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    if epoch % LOG_EVERY == 0:
        # Mean loss over the epoch's batches. This equals the last batch's loss
        # when the loader is full-batch, and stays correct for smaller batches.
        mean_loss = total_loss / max(len(train_dataloader), 1)
        print(f'Epoch {epoch}, Loss: {mean_loss:.5f}')


Epoch 0, Loss: 1.09032
Epoch 50, Loss: 0.17506
Epoch 100, Loss: 0.00253
Epoch 150, Loss: 0.00099
Epoch 200, Loss: 0.00059
Epoch 250, Loss: 0.00039
Epoch 300, Loss: 0.00028


In [119]:
# Evaluate on the held-out splits. eval() switches off training-only behaviour
# (e.g. dropout); no_grad() skips building autograd graphs nobody backprops through.
model.eval()
with torch.no_grad():
    for split_name, split_loader in (('dev', dev_dataloader), ('test', test_dataloader)):
        for batch in split_loader:
            # Graphs are built from CPU tensors; move them to the model's device.
            batch = batch.to(device)
            pred = model(batch).argmax(dim=1)
            results = calc_metrics(batch.y.cpu(), pred.cpu())
            print(f'{split_name}: {results}')


{'accuracy': 0.4, 'f1': 0.2749354005167959, 'precision': 0.2707070707070707, 'recall': 0.2793650793650793, 'confusion_matrix': array([[12,  9,  0],
       [ 8,  4,  3],
       [ 2,  2,  0]])}
{'accuracy': 0.575, 'f1': 0.49425287356321834, 'precision': 0.5416666666666666, 'recall': 0.476984126984127, 'confusion_matrix': array([[15,  6,  0],
       [ 7,  7,  1],
       [ 2,  1,  1]])}
