In [3]:
import torch
import pickle
import numpy as np
import networkx as nx
import torch.nn.functional as F
from torch_geometric.nn import GINConv, global_mean_pool
from torch_geometric.loader import DataLoader
from torch_geometric.utils.convert import to_networkx
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report
from torch_geometric.nn.models import Node2Vec

In [4]:
# Select the dataset whose pre-split pickles are loaded below
# (the file names suggest a 70/20/10 train/val/test split).
# Alternative: dataset = 'fer2013'
dataset = 'ck'
train_data_path = f'{dataset}_data/train_data_70_20_10.pkl'
val_data_path = f'{dataset}_data/val_data_70_20_10.pkl'
test_data_path = f'{dataset}_data/test_data_70_20_10.pkl'

# NOTE(review): pickle.load can execute arbitrary code stored in the file;
# only load split files that come from a trusted source.
with open(train_data_path, 'rb') as f:
    train_data = pickle.load(f)
with open(val_data_path, 'rb') as f:
    val_data = pickle.load(f)
with open(test_data_path, 'rb') as f:
    test_data = pickle.load(f)

# Adjacency matrix read from CSV, plus the NetworkX graph built from it.
adjacency_matrix = np.loadtxt('standard_mesh_adj_matrix.csv', delimiter=',')
G = nx.from_numpy_array(adjacency_matrix)

# Attach an all-zero `batch` vector (one entry per node) to every graph
# in every split.
for split in (train_data, val_data, test_data):
    for data in split:
        data.batch = torch.zeros(data.x.size(0), dtype=torch.long)

class EarlyStopping:
    """Stop training once the validation loss has stopped improving.

    An epoch counts as an improvement only when its validation loss is lower
    than the best loss seen so far by more than ``delta``. On every
    improvement the model's ``state_dict`` is written to ``path``; after
    ``patience`` consecutive epochs without improvement ``early_stop`` is set
    to ``True``.

    Args:
        patience: Number of consecutive non-improving epochs tolerated before
            ``early_stop`` is raised.
        delta: Minimum decrease of the validation loss that qualifies as an
            improvement.
        path: File the best ``state_dict`` is saved to.

    Attributes:
        best_score: Lowest validation loss that triggered a checkpoint, or
            ``None`` before the first checkpoint.
        val_loss_min: Validation loss stored with the latest checkpoint
            (``inf`` until the first one is written).
        counter: Consecutive epochs without improvement.
        early_stop: ``True`` once ``counter`` has reached ``patience``.
    """

    def __init__(self, patience=10, delta=0, path='checkpoint.pt'):
        self.patience = patience
        self.delta = delta
        self.path = path
        self.best_score = None
        self.val_loss_min = float('inf')
        self.early_stop = False
        self.counter = 0

    def __call__(self, val_loss, model):
        """Record one epoch's validation loss and checkpoint on improvement.

        Args:
            val_loss: Validation loss of the epoch that just finished.
            model: Object exposing ``state_dict()``; saved when the loss
                improved.
        """
        if val_loss != val_loss:
            # NaN compares unequal to itself; a diverged epoch must never
            # replace the best checkpoint.
            improved = False
        else:
            improved = (
                self.best_score is None
                or val_loss < self.best_score - self.delta
            )

        if improved:
            self.best_score = val_loss
            self.save_checkpoint(val_loss, model)
            self.counter = 0
        else:
            # Losses that are worse, or better by no more than `delta`, leave
            # the stored checkpoint and `best_score` untouched.
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        """Write ``model.state_dict()`` to ``self.path`` and remember its loss."""
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [9]:
# Number of graphs in the training split.
len(train_data)

291

In [48]:
from karateclub import Graph2Vec

# Convert every PyG graph of each split into a NetworkX graph.
train_graphs = list(map(to_networkx, train_data))
val_graphs = list(map(to_networkx, val_data))
test_graphs = list(map(to_networkx, test_data))
all_graphs = [*train_graphs, *val_graphs, *test_graphs]

# One Graph2Vec model is fit on the graphs of all three splits together, so
# validation and test graphs take part in learning the embedding.
# NOTE(review): to_networkx is called without node attributes and Graph2Vec
# with its defaults, so the embedding presumably sees graph structure only.
# The displayed training graphs all share the same x/edge_index sizes; if
# their topology is identical the embeddings may carry little class
# information - confirm.
graph2vec = Graph2Vec()
graph2vec.fit(all_graphs)

all_embeddings = torch.tensor(graph2vec.get_embedding(), dtype=torch.float32)

# Rows follow the order of `all_graphs`: train first, then val, then test.
train_data_embeddings, val_data_embeddings, test_data_embeddings = torch.split(
    all_embeddings,
    [len(train_graphs), len(val_graphs), len(test_graphs)],
)

In [49]:
train_data

[Data(x=[468, 3], edge_index=[2, 2644], y=[1], bbox=[6], batch=[468]),
 Data(x=[468, 3], edge_index=[2, 2644], y=[1], bbox=[6], batch=[468]),
 Data(x=[468, 3], edge_index=[2, 2644], y=[1], bbox=[6], batch=[468]),
 Data(x=[468, 3], edge_index=[2, 2644], y=[1], bbox=[6], batch=[468]),
 Data(x=[468, 3], edge_index=[2, 2644], y=[1], bbox=[6], batch=[468]),
 Data(x=[468, 3], edge_index=[2, 2644], y=[1], bbox=[6], batch=[468]),
 Data(x=[468, 3], edge_index=[2, 2644], y=[1], bbox=[6], batch=[468]),
 Data(x=[468, 3], edge_index=[2, 2644], y=[1], bbox=[6], batch=[468]),
 Data(x=[468, 3], edge_index=[2, 2644], y=[1], bbox=[6], batch=[468]),
 Data(x=[468, 3], edge_index=[2, 2644], y=[1], bbox=[6], batch=[468]),
 Data(x=[468, 3], edge_index=[2, 2644], y=[1], bbox=[6], batch=[468]),
 Data(x=[468, 3], edge_index=[2, 2644], y=[1], bbox=[6], batch=[468]),
 Data(x=[468, 3], edge_index=[2, 2644], y=[1], bbox=[6], batch=[468]),
 Data(x=[468, 3], edge_index=[2, 2644], y=[1], bbox=[6], batch=[468]),
 Data(

In [50]:
# NOTE: this import rebinds `DataLoader` to the plain PyTorch loader and
# shadows the torch_geometric `DataLoader` imported at the top of the notebook.
from torch.utils.data import DataLoader, TensorDataset

# Mini-batch size for all three loaders. It is defined here so this cell runs
# on a fresh kernel; the model cell further down assigns the same value.
batch_size = 32


def _graph_labels(graphs):
    """Return the graph-level labels of `graphs` as a 1-D int64 tensor.

    Each graph's `y` is read with `.item()`, so it must hold exactly one
    value. Building the tensor from Python scalars with an explicit dtype
    avoids relying on `torch.tensor` accepting a list of tensors.
    """
    return torch.tensor([data.y.item() for data in graphs], dtype=torch.long)


train_labels = _graph_labels(train_data)
val_labels = _graph_labels(val_data)
test_labels = _graph_labels(test_data)

# Pair every graph embedding with its label.
train_dataset = TensorDataset(train_data_embeddings, train_labels)
val_dataset = TensorDataset(val_data_embeddings, val_labels)
test_dataset = TensorDataset(test_data_embeddings, test_labels)

# Only the training loader shuffles; val/test keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [51]:
# Sanity check: print every [embeddings, labels] batch yielded by train_loader.
# Note that this rebinds the module-level name `data` to the last batch.
for data in train_loader:
    print(data)

[tensor([[0.1404, 0.1256, 0.1549,  ..., 0.1707, 0.0206, 0.0878],
        [0.1371, 0.1180, 0.1566,  ..., 0.1710, 0.0230, 0.0805],
        [0.1422, 0.1277, 0.1553,  ..., 0.1867, 0.0141, 0.0822],
        ...,
        [0.1398, 0.1208, 0.1682,  ..., 0.1810, 0.0216, 0.0781],
        [0.1424, 0.1198, 0.1598,  ..., 0.1710, 0.0126, 0.0863],
        [0.1394, 0.1271, 0.1591,  ..., 0.1809, 0.0227, 0.0818]]), tensor([0, 5, 5, 0, 2, 5, 0, 1, 1, 3, 0, 3, 1, 0, 5, 4, 5, 7, 6, 1, 3, 0, 4, 1,
        4, 1, 5, 3, 1, 0, 0, 0])]
[tensor([[0.1373, 0.1114, 0.1579,  ..., 0.1715, 0.0143, 0.0770],
        [0.1465, 0.1277, 0.1621,  ..., 0.1723, 0.0234, 0.0849],
        [0.1396, 0.1256, 0.1609,  ..., 0.1683, 0.0112, 0.0926],
        ...,
        [0.1429, 0.1271, 0.1516,  ..., 0.1700, 0.0252, 0.0780],
        [0.1447, 0.1195, 0.1602,  ..., 0.1757, 0.0155, 0.0913],
        [0.1451, 0.1230, 0.1558,  ..., 0.1740, 0.0161, 0.0822]]), tensor([0, 3, 3, 0, 1, 0, 6, 0, 3, 4, 6, 0, 0, 0, 0, 5, 3, 6, 0, 2, 1, 1, 1, 6,
      

In [69]:
import torch.nn as nn

# Input width is the embedding size; output width is the number of distinct
# labels present in the training split.
input_dim = train_data_embeddings.shape[-1]
output_dim = len({data.y.item() for data in train_data})

# MLP classifier: Linear -> ReLU -> Linear -> ReLU -> Linear.
layer_widths = [input_dim, 128, 64, output_dim]
layers = []
for fan_in, fan_out in zip(layer_widths[:-1], layer_widths[1:]):
    layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
# The output layer emits raw logits, so the trailing ReLU is dropped.
model = nn.Sequential(*layers[:-1])

# Mini-batch size; the loaders themselves are built from TensorDatasets in an
# earlier cell.
batch_size = 32

# Define the training and evaluation functions
def train():
    """Run one training epoch over `train_loader`.

    Uses the module-level `model`, `optimizer`, `criterion`, `train_loader`
    and `device`. Inputs and labels are both moved to `device`, so the loss
    and the accuracy comparison operate on tensors of one device.

    Returns:
        tuple: (mean of the per-batch losses, fraction of correctly
        classified training samples).
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    for x, y in train_loader:
        # Labels must sit on the same device as the logits.
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        out = model(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        pred = out.argmax(dim=1)
        correct += pred.eq(y).sum().item()
        total += y.size(0)
    return total_loss / len(train_loader), correct / total

def evaluate(loader):
    """Evaluate the module-level `model` on every batch of `loader`.

    Runs without gradient tracking and uses the module-level `criterion` and
    `device`. Inputs and labels are both moved to `device`, so the loss and
    the accuracy comparison operate on tensors of one device.

    Args:
        loader: Iterable of `(x, y)` batches that also supports `len()`.

    Returns:
        tuple: (accuracy, mean of the per-batch losses, list of true labels,
        list of predicted labels).
    """
    model.eval()
    correct = 0
    total = 0
    val_loss = 0.0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for x, y in loader:
            # Labels must sit on the same device as the logits.
            x, y = x.to(device), y.to(device)
            out = model(x)
            pred = out.argmax(dim=1)
            correct += pred.eq(y).sum().item()
            total += y.size(0)
            val_loss += criterion(out, y).item()
            all_preds.extend(pred.cpu().numpy())
            all_labels.extend(y.cpu().numpy())
    return correct / total, val_loss / len(loader), all_labels, all_preds


# Number of classes = number of distinct labels in the training split.
output_dim = len(np.unique([data.y.item() for data in train_data]))

# Pick the compute device and move the model onto it before the optimizer is
# created, so the model lives on the same device as the loss weights below.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Inverse-frequency class weights, normalised to sum to 1.
label_counts = np.bincount([data.y.item() for data in train_data])
# A label id with no training samples gets weight 0 instead of 1/0 = inf,
# which would turn every weight into NaN after normalisation.
class_weights = np.zeros(len(label_counts), dtype=float)
seen = label_counts > 0
class_weights[seen] = 1.0 / label_counts[seen]
class_weights = class_weights / class_weights.sum()
class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)

early_stopping = EarlyStopping(patience=20, delta=0.001)

# Per-epoch metric histories filled by the training loop.
train_losses, val_losses = [], []
train_accuracies, val_accuracies = [], []

# Train for at most 500 epochs; early stopping checkpoints the model and may
# end the loop sooner.
for epoch in range(1, 501):
    train_loss, train_acc = train()
    val_acc, val_loss, _, _ = evaluate(val_loader)

    # Record this epoch's metrics.
    train_losses.append(train_loss)
    train_accuracies.append(train_acc)
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)

    print(
        f'Epoch: {epoch:03d}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, '
        f'Train Acc: {int(100 * train_acc):02d}%, Val Acc: {int(100 * val_acc):02d}%'
    )

    early_stopping(val_loss, model)
    if not early_stopping.early_stop:
        continue
    print("Early stopping")
    break

# Restore the weights stored in the checkpoint written by early stopping.
model.load_state_dict(torch.load('checkpoint.pt'))

Epoch: 001, Train Loss: 2.0831, Val Loss: 2.0774, Train Acc: 22%, Val Acc: 21%
Epoch: 002, Train Loss: 2.0825, Val Loss: 2.0787, Train Acc: 22%, Val Acc: 21%
Epoch: 003, Train Loss: 2.0822, Val Loss: 2.0788, Train Acc: 22%, Val Acc: 21%
Epoch: 004, Train Loss: 2.0786, Val Loss: 2.0792, Train Acc: 22%, Val Acc: 21%
Epoch: 005, Train Loss: 2.0803, Val Loss: 2.0792, Train Acc: 22%, Val Acc: 21%
Epoch: 006, Train Loss: 2.0824, Val Loss: 2.0795, Train Acc: 22%, Val Acc: 21%
Epoch: 007, Train Loss: 2.0800, Val Loss: 2.0787, Train Acc: 22%, Val Acc: 21%
Epoch: 008, Train Loss: 2.0764, Val Loss: 2.0773, Train Acc: 22%, Val Acc: 21%
Epoch: 009, Train Loss: 2.0839, Val Loss: 2.0772, Train Acc: 18%, Val Acc: 10%
Epoch: 010, Train Loss: 2.0795, Val Loss: 2.0776, Train Acc: 10%, Val Acc: 19%
Epoch: 011, Train Loss: 2.0827, Val Loss: 2.0784, Train Acc: 20%, Val Acc: 21%
Epoch: 012, Train Loss: 2.0779, Val Loss: 2.0795, Train Acc: 22%, Val Acc: 21%
Epoch: 013, Train Loss: 2.0803, Val Loss: 2.0791, Tr

<All keys matched successfully>