In [64]:
import torch
import pickle
import numpy as np
import networkx as nx
import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool, GINConv, GATConv
from torch_geometric.explain import Explainer, GNNExplainer
from torch_geometric.loader import DataLoader
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report

In [65]:
# Load the pre-split train / val / test graph lists for the selected dataset.
# dataset = 'fer2013'
dataset = 'ck'
split_suffix = '_70_20_10.pkl'
train_data_path = f'{dataset}_data/train_data{split_suffix}'
val_data_path = f'{dataset}_data/val_data{split_suffix}'
test_data_path = f'{dataset}_data/test_data{split_suffix}'


def _load_pickle(path):
    """Unpickle and return the object stored at ``path``."""
    # NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


train_data, val_data, test_data = (
    _load_pickle(split_path)
    for split_path in (train_data_path, val_data_path, test_data_path)
)

# Landmark mesh connectivity as a dense adjacency matrix and a networkx graph
# (G is not used in the cells shown here — presumably consumed elsewhere).
adjacency_matrix = np.loadtxt('standard_mesh_adj_matrix.csv', delimiter=',')
G = nx.from_numpy_array(adjacency_matrix)

# Give every graph an all-zeros `batch` vector (one entry per node) so a single
# graph can be fed straight to the models / explainer, and tag it as CPU-resident.
for split in (train_data, val_data, test_data):
    for graph in split:
        graph.batch = torch.zeros(graph.x.size(0), dtype=torch.long)
        graph.device = 'cpu'
    
class EarlyStopping:
    """Signal when a monitored loss has stopped improving.

    Call the instance once per epoch with the current validation loss. A loss
    counts as an improvement only when it is lower than the best loss seen so
    far by more than ``delta``. Each improvement saves ``model.state_dict()`` to
    ``path`` and resets the patience counter. After ``patience`` consecutive
    calls without improvement, ``early_stop`` becomes True.

    Args:
        patience: number of non-improving calls tolerated before stopping.
        delta: minimum decrease in loss that counts as an improvement.
        path: checkpoint file written on every improvement.
    """

    def __init__(self, patience=10, delta=0, path='checkpoint.pt'):
        self.patience = patience
        self.delta = delta
        self.path = path
        self.best_score = None
        self.early_stop = False
        self.counter = 0
        # Lowest loss checkpointed so far; previously only created on first save.
        self.val_loss_min = float('inf')

    def __call__(self, val_loss, model):
        """Record ``val_loss`` for this epoch and checkpoint ``model`` on improvement."""
        # Previously any loss within `delta` *above* the best was treated as an
        # improvement (raising best_score and resetting the counter), so a slowly
        # worsening loss never triggered stopping. Require a real decrease.
        if self.best_score is None or val_loss < self.best_score - self.delta:
            self.best_score = val_loss
            self.save_checkpoint(val_loss, model)
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        """Persist ``model``'s weights to ``self.path`` and remember ``val_loss``."""
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss


In [66]:
from torch_geometric.nn import BatchNorm
from torch_geometric.nn import GATConv, GCNConv, SAGEConv, GraphConv

class GIN(torch.nn.Module):
    """Two-layer GIN graph classifier (GINConv -> BatchNorm -> ReLU, twice),
    followed by global mean pooling and a linear head returning log-probabilities.

    Args:
        input_dim: number of input features per node.
        hidden_dim: width of the node embeddings and of each GIN MLP.
        output_dim: number of classes.
        device: kept for backward compatibility; ``forward`` now moves its
            inputs to wherever the parameters live, so ``model.to(...)`` is
            enough to change devices.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, device='cpu'):
        super(GIN, self).__init__()
        self.device = device
        nn1 = torch.nn.Sequential(torch.nn.Linear(input_dim, hidden_dim), torch.nn.ReLU(), torch.nn.Linear(hidden_dim, hidden_dim))
        self.conv1 = GINConv(nn1)
        self.bn1 = BatchNorm(hidden_dim)

        nn2 = torch.nn.Sequential(torch.nn.Linear(hidden_dim, hidden_dim), torch.nn.ReLU(), torch.nn.Linear(hidden_dim, hidden_dim))
        self.conv2 = GINConv(nn2)
        self.bn2 = BatchNorm(hidden_dim)

        self.lin = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, data, edge_index, batch):
        """Return per-graph class log-probabilities.

        Args:
            data: node feature matrix (name kept for keyword callers); cast to float.
            edge_index: connectivity passed to the conv layers; cast to int64.
            batch: node-to-graph assignment passed to ``global_mean_pool``.
        """
        # `self.device` is not updated by `model.to(...)`; following the
        # parameters avoids a device mismatch once the model has been moved.
        device = next(self.parameters()).device
        x = data.to(device=device, dtype=torch.float)
        edge_index = edge_index.to(device=device, dtype=torch.int64)
        if batch is not None:
            batch = batch.to(device)
        x = F.relu(self.bn1(self.conv1(x, edge_index)))
        x = F.relu(self.bn2(self.conv2(x, edge_index)))
        x = global_mean_pool(x, batch)
        return F.log_softmax(self.lin(x), dim=1)

class GAT(torch.nn.Module):
    """Two-layer graph attention classifier (4 concatenated heads per layer,
    ELU activations), global mean pooling and a linear head returning
    log-probabilities.

    Args:
        input_dim: number of input features per node.
        hidden_dim: per-head width; each layer outputs ``hidden_dim * 4`` features.
        output_dim: number of classes.
        device: kept for backward compatibility; ``forward`` follows the
            device of the parameters.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, device='cpu'):
        super(GAT, self).__init__()
        self.conv1 = GATConv(input_dim, hidden_dim, heads=4, concat=True)
        self.conv2 = GATConv(hidden_dim * 4, hidden_dim, heads=4, concat=True)
        self.lin = torch.nn.Linear(hidden_dim * 4, output_dim)
        self.device = device

    def forward(self, data, edge_index, batch):
        """Return per-graph class log-probabilities.

        Args:
            data: node feature matrix (name kept for keyword callers); cast to float.
            edge_index: connectivity passed to the attention layers; cast to int64.
            batch: node-to-graph assignment passed to ``global_mean_pool``.
        """
        # `self.device` is not updated by `model.to(...)`; follow the parameters.
        device = next(self.parameters()).device
        x = data.to(device=device, dtype=torch.float)
        edge_index = edge_index.to(device=device, dtype=torch.int64)
        if batch is not None:
            batch = batch.to(device)
        x = F.elu(self.conv1(x, edge_index))
        x = F.elu(self.conv2(x, edge_index))
        x = global_mean_pool(x, batch)
        return F.log_softmax(self.lin(x), dim=1)
    
class GCNSAGE(torch.nn.Module):
    """Graph classifier with a GCN layer followed by a GraphSAGE layer (ReLU
    after each), global mean pooling and a linear head returning log-probabilities.

    Args:
        input_dim: number of input features per node.
        hidden_dim: width of the node embeddings.
        output_dim: number of classes.
        device: kept for backward compatibility; ``forward`` follows the
            device of the parameters.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, device='cpu'):
        super(GCNSAGE, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, hidden_dim)
        self.lin = torch.nn.Linear(hidden_dim, output_dim)
        self.device = device

    def forward(self, data, edge_index, batch):
        """Return per-graph class log-probabilities.

        Args:
            data: node feature matrix (name kept for keyword callers); cast to float.
            edge_index: connectivity passed to the conv layers; cast to int64.
            batch: node-to-graph assignment passed to ``global_mean_pool``.
        """
        # `self.device` is not updated by `model.to(...)`; follow the parameters.
        device = next(self.parameters()).device
        x = data.to(device=device, dtype=torch.float)
        edge_index = edge_index.to(device=device, dtype=torch.int64)
        if batch is not None:
            batch = batch.to(device)
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = global_mean_pool(x, batch)
        return F.log_softmax(self.lin(x), dim=1)
    
class GConv(torch.nn.Module):
    """Two-layer GraphConv classifier (ReLU after each layer), global mean
    pooling and a linear head returning log-probabilities.

    Args:
        input_dim: number of input features per node.
        hidden_dim: width of the node embeddings.
        output_dim: number of classes.
        device: kept for backward compatibility; ``forward`` follows the
            device of the parameters.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, device='cpu'):
        super(GConv, self).__init__()
        self.conv1 = GraphConv(input_dim, hidden_dim)
        self.conv2 = GraphConv(hidden_dim, hidden_dim)
        self.lin = torch.nn.Linear(hidden_dim, output_dim)
        self.device = device

    def forward(self, data, edge_index, batch):
        """Return per-graph class log-probabilities.

        Args:
            data: node feature matrix (name kept for keyword callers); cast to float.
            edge_index: connectivity passed to the conv layers; cast to int64.
            batch: node-to-graph assignment passed to ``global_mean_pool``.
        """
        # `self.device` is not updated by `model.to(...)`; follow the parameters.
        device = next(self.parameters()).device
        x = data.to(device=device, dtype=torch.float)
        edge_index = edge_index.to(device=device, dtype=torch.int64)
        if batch is not None:
            batch = batch.to(device)
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = global_mean_pool(x, batch)
        return F.log_softmax(self.lin(x), dim=1)
    
class GINBNorm(torch.nn.Module):
    """Two-layer GIN graph classifier with batch norm (GINConv -> BatchNorm ->
    ReLU, twice), global mean pooling and a linear head returning log-probabilities.

    Args:
        input_dim: number of input features per node.
        hidden_dim: width of the node embeddings and of each GIN MLP.
        output_dim: number of classes.
        device: kept for backward compatibility; ``forward`` follows the
            device of the parameters.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, device='cpu'):
        super(GINBNorm, self).__init__()
        nn1 = torch.nn.Sequential(torch.nn.Linear(input_dim, hidden_dim), torch.nn.ReLU(), torch.nn.Linear(hidden_dim, hidden_dim))
        self.conv1 = GINConv(nn1)
        self.bn1 = BatchNorm(hidden_dim)

        nn2 = torch.nn.Sequential(torch.nn.Linear(hidden_dim, hidden_dim), torch.nn.ReLU(), torch.nn.Linear(hidden_dim, hidden_dim))
        self.conv2 = GINConv(nn2)
        self.bn2 = BatchNorm(hidden_dim)

        self.lin = torch.nn.Linear(hidden_dim, output_dim)
        self.device = device

    def forward(self, data, edge_index, batch):
        """Return per-graph class log-probabilities.

        Args:
            data: node feature matrix (name kept for keyword callers); cast to float.
            edge_index: connectivity passed to the conv layers; cast to int64.
            batch: node-to-graph assignment passed to ``global_mean_pool``.
        """
        # `self.device` is not updated by `model.to(...)`; follow the parameters.
        device = next(self.parameters()).device
        x = data.to(device=device, dtype=torch.float)
        edge_index = edge_index.to(device=device, dtype=torch.int64)
        if batch is not None:
            batch = batch.to(device)
        x = F.relu(self.bn1(self.conv1(x, edge_index)))
        x = F.relu(self.bn2(self.conv2(x, edge_index)))
        x = global_mean_pool(x, batch)
        return F.log_softmax(self.lin(x), dim=1)

In [67]:
def train(model, device, train_loader, optimizer, criterion):
    """Run one optimisation epoch over ``train_loader``.

    Each batch is moved to ``device``, passed through
    ``model(x, edge_index, batch)``, scored with ``criterion`` against ``y``,
    and back-propagated with one ``optimizer`` step.

    Returns:
        (mean of the per-batch losses, fraction of correctly classified samples)
    """
    model.train()
    model.to(device)
    loss_sum, n_correct, n_seen = 0, 0, 0
    for batch_data in train_loader:
        batch_data = batch_data.to(device)
        optimizer.zero_grad()
        logits = model(batch_data.x, batch_data.edge_index, batch_data.batch)
        batch_loss = criterion(logits, batch_data.y)
        batch_loss.backward()
        optimizer.step()
        loss_sum += batch_loss.item()
        predicted = logits.argmax(dim=1)
        n_correct += predicted.eq(batch_data.y).sum().item()
        n_seen += batch_data.y.size(0)
    mean_loss = loss_sum / len(train_loader)
    accuracy = n_correct / n_seen
    return mean_loss, accuracy

def evaluate(model, device, loader, criterion):
    """Score ``model`` on every batch of ``loader`` in eval mode, without gradients.

    Returns:
        (accuracy, mean of the per-batch losses, true labels, predicted labels),
        where both label lists have one entry per sample, in loader order.
    """
    model.eval()
    model.to(device)
    n_correct = n_seen = 0
    loss_sum = 0
    labels_seen, preds_seen = [], []
    with torch.no_grad():
        for batch_data in loader:
            batch_data = batch_data.to(device)
            targets = batch_data.y
            logits = model(batch_data.x, batch_data.edge_index, batch_data.batch)
            predicted = logits.argmax(dim=1)
            n_correct += predicted.eq(targets).sum().item()
            n_seen += targets.size(0)
            loss_sum += criterion(logits, targets).item()
            preds_seen.extend(predicted.cpu().numpy())
            labels_seen.extend(targets.cpu().numpy())
    return n_correct / n_seen, loss_sum / len(loader), labels_seen, preds_seen


In [97]:
# Train the chosen architecture and record per-epoch train / val / test curves.
SEED = 0  # fixes weight init and DataLoader shuffling for reproducible runs
torch.manual_seed(SEED)

train_labels = np.array([data.y.item() for data in train_data])
# Size the classifier by the largest label rather than the number of distinct
# labels: the two only agree when labels are exactly 0..K-1, and a mismatch
# makes the class-weight vector and the logits disagree in length.
output_dim = int(train_labels.max()) + 1
input_dim = train_data[0].x.size(1)  # node feature width (hard-coded as 3 before)
device = 'cpu'
### CHANGE THE NAME OF THE MODEL HERE
model = GIN(input_dim=input_dim, hidden_dim=64, output_dim=output_dim, device=device)
train_losses = []
train_accuracies = []
val_losses = []
val_accuracies = []
test_losses = []
test_accuracies = []
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
batch_size = 32
n_epochs = 250
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

# Inverse-frequency class weights normalised to sum to 1. Labels absent from the
# training split get weight 0 instead of the inf/NaN produced by 1/0.
label_counts = np.bincount(train_labels, minlength=output_dim)
class_weights = np.zeros(output_dim, dtype=float)
seen = label_counts > 0
class_weights[seen] = 1.0 / label_counts[seen]
class_weights = class_weights / class_weights.sum()
class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)
# The models output log-probabilities; CrossEntropyLoss re-applies log_softmax,
# which leaves log-probabilities unchanged, so this acts as a weighted NLL loss.
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)

for epoch in range(n_epochs):
    train_loss, train_acc = train(model, device, train_loader, optimizer, criterion)
    train_losses.append(train_loss)
    train_accuracies.append(train_acc)
    # Monitor generalisation on the validation split; the test curves are kept
    # for reporting only and should not drive epoch / hyper-parameter choices.
    val_acc, val_loss, val_labels, val_preds = evaluate(model, device, val_loader, criterion)
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)
    test_acc, test_loss, test_labels, test_preds = evaluate(
        model,
        device,
        test_loader,
        criterion
    )
    test_losses.append(test_loss)
    test_accuracies.append(test_acc)

In [77]:
# Explain the trained model's graph-level prediction for the first test graph.
# Node masks are per node feature ('attributes'); edges get one mask value each
# ('object'). The models return log-probabilities, hence return_type below.
explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config={
        'mode': 'multiclass_classification',
        'task_level': 'graph',
        'return_type': 'log_probs',
    },
)
first_graph = test_data[0]
explanation = explainer(
    first_graph.x,
    first_graph.edge_index,
    target=None,
    batch=first_graph.batch,
)
explanations = [explanation]

In [79]:
import plotly.graph_objects as go
# Template face whose landmark coordinates position the markers.
# NOTE(review): the explanation above is for test_data[0], but markers are placed
# at train_data[0]'s coordinates — only meaningful if all graphs share the same
# landmark ordering (standard mesh); confirm.
selected_node = train_data[0]

# Per-axis coordinate lists taken from feature columns 0, 1 and 2.
x_vals, y_vals, z_vals = (
    [row[axis] for row in selected_node.x] for axis in range(3)
)
# Per-node importance: average the per-feature node mask over the feature axis.
weights = explanations[0].node_mask.numpy().mean(axis=1)

marker_style = dict(
    size=5,
    color=weights,
    colorscale='Viridis',
    colorbar=dict(title='Weights'),
    opacity=0.8,
)
fig = go.Figure(go.Scatter3d(x=x_vals, y=y_vals, z=z_vals, mode='markers', marker=marker_style))

fig.update_layout(
    title='3D Landmarks for Expression',
    scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'),
)
fig.show(renderer='browser')

In [84]:
def run_explanation_vis(graph_id, graphs=None, coords_graph=None):
    """Explain one graph with GNNExplainer and show its node importances in 3D.

    Args:
        graph_id: index of the graph to explain within ``graphs``.
        graphs: sequence of graphs to index into; defaults to the global ``test_data``.
        coords_graph: graph whose node features (columns 0-2) give the marker
            positions; defaults to the explained graph itself. Previously the
            function silently read the global ``selected_node`` (``train_data[0]``),
            which required another cell to have run and plotted a different face
            than the one being explained.
    """
    if graphs is None:
        graphs = test_data
    graph = graphs[graph_id]
    if coords_graph is None:
        coords_graph = graph

    explainer = Explainer(
        model=model,
        algorithm=GNNExplainer(),
        explanation_type="model",
        node_mask_type="attributes",
        edge_mask_type='object',
        model_config=dict(
            mode='multiclass_classification',
            task_level='graph',
            return_type='log_probs'
        )
    )
    # Single-graph explanation (per-feature node mask, per-edge mask).
    explanation = explainer(graph.x, graph.edge_index, target=None, batch=graph.batch)

    coords = coords_graph.x
    x_vals = [value[0] for value in coords]
    y_vals = [value[1] for value in coords]
    z_vals = [value[2] for value in coords]
    # Per-node importance: mean of the node mask over the feature axis.
    weights = np.mean(explanation.node_mask.detach().cpu().numpy(), 1)

    fig = go.Figure(go.Scatter3d(
        x=x_vals,
        y=y_vals,
        z=z_vals,
        mode='markers',
        marker=dict(
            size=5,
            color=weights,
            colorscale='Viridis',
            colorbar=dict(title='Weights'),
            opacity=0.8,
        ),
    ))

    fig.update_layout(
        title='3D Landmarks for Expression',
        scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'),
    )

    fig.show(renderer='browser')

run_explanation_vis(3)

In [98]:
# Aggregate GNNExplainer node importances over the whole test split and plot
# them on the template face `selected_node` (defined in an earlier cell).
explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(),
    explanation_type="model",
    node_mask_type="attributes",
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='graph',
        return_type='log_probs'
    )
)
# Running sum of the per-graph node masks (same shape as a graph's node features).
explanations = np.zeros_like(test_data[0].x)
for graph in test_data:
    graph_explanation = explainer(graph.x, graph.edge_index, target=None, batch=graph.batch)
    explanations += graph_explanation.node_mask.detach().cpu().numpy()

x_vals = [value[0] for value in selected_node.x]
y_vals = [value[1] for value in selected_node.x]
z_vals = [value[2] for value in selected_node.x]
weights = np.mean(explanations, 1)
# Min-max scale to [0, 1]; a constant aggregate would otherwise divide by zero
# and produce NaN colours.
weight_range = weights.max() - weights.min()
if weight_range > 0:
    weights = (weights - weights.min()) / weight_range
else:
    weights = np.zeros_like(weights)
fig = go.Figure(go.Scatter3d(
    x=x_vals,
    y=y_vals,
    z=z_vals,
    mode='markers',
    marker=dict(
        size=5,
        color=weights,
        colorscale='Viridis',
        colorbar=dict(title='Weights'),
        opacity=0.8,
    ),
))

fig.update_layout(
    title='3D Landmarks for Expression',
    scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'),
)

fig.show(renderer='browser')

# Observations from running this cell with each architecture:
# GIN looks at half of the face, nose is symmetrically considered, cheeks and eyes are the most important parts
# GAT looks at the whole face, lower face and edges are the most important parts
# SAGE only considers half of the face, top right, nose is not important, cheek, eye and the edge of the lips
# (presumably GConv — confirm) similar to SAGE, but also looks at the other half of the face for the space between eye and nose, much more focus on one side of the face


In [None]:
# Aggregate GNNExplainer node importances over the whole test split and plot
# them on the template face `selected_node` (defined in an earlier cell).
# NOTE(review): this repeats the previous aggregate cell verbatim — presumably to
# re-plot after switching the trained model; consider a single reusable function.
explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(),
    explanation_type="model",
    node_mask_type="attributes",
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='graph',
        return_type='log_probs'
    )
)
# Running sum of the per-graph node masks (same shape as a graph's node features).
explanations = np.zeros_like(test_data[0].x)
for graph in test_data:
    graph_explanation = explainer(graph.x, graph.edge_index, target=None, batch=graph.batch)
    explanations += graph_explanation.node_mask.detach().cpu().numpy()

x_vals = [value[0] for value in selected_node.x]
y_vals = [value[1] for value in selected_node.x]
z_vals = [value[2] for value in selected_node.x]
weights = np.mean(explanations, 1)
# Min-max scale to [0, 1]; a constant aggregate would otherwise divide by zero
# and produce NaN colours.
weight_range = weights.max() - weights.min()
if weight_range > 0:
    weights = (weights - weights.min()) / weight_range
else:
    weights = np.zeros_like(weights)
fig = go.Figure(go.Scatter3d(
    x=x_vals,
    y=y_vals,
    z=z_vals,
    mode='markers',
    marker=dict(
        size=5,
        color=weights,
        colorscale='Viridis',
        colorbar=dict(title='Weights'),
        opacity=0.8,
    ),
))

fig.update_layout(
    title='3D Landmarks for Expression',
    scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'),
)

fig.show(renderer='browser')