In [4]:
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing, global_mean_pool, MLP, global_add_pool
import torch
from torch_geometric.nn import GCNConv

class MPNN(MessagePassing):
    """Minimal message-passing layer: sum neighbour features, then project.

    Messages are the unmodified neighbour features, they are combined with
    'add' aggregation, and the aggregated result is passed through a single
    linear layer mapping ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__(aggr="add")
        # Linear projection applied after aggregation (see ``update``).
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        """Run one round of message passing over ``edge_index`` with features ``x``."""
        propagated = self.propagate(edge_index, x=x)
        return propagated

    def message(self, x_j):
        """Return the neighbour features unchanged as the message.

        The parameter name ``x_j`` must be kept: ``propagate`` fills it from
        the ``x`` keyword based on the ``_j`` suffix.
        """
        return x_j

    def update(self, aggr_out):
        """Project the aggregated messages with the linear layer."""
        projected = self.lin(aggr_out)
        return projected

class Net(torch.nn.Module):
    """Predict a single value (PCE) from ETL/HTL graphs, absorber features and bandgap.

    The ETL and HTL graphs are each encoded by a ``GCNConv`` layer and
    mean-pooled into one 16-dimensional row. These two rows are concatenated
    with the absorber features and the bandgap along ``dim=1`` and passed
    through leaky-ReLU activated linear layers, followed by a linear output
    layer with one output feature.
    """

    def __init__(self):
        super().__init__()

        self.etl_embedding_dimensions = 16
        self.htl_embedding_dimensions = 16
        self.absorber_dimensions = 76

        self.bandgap_dimension = 1
        # arbitrary choice:
        self.hidden_dimension = 32
        # Total number of linear layers in the regression head, counting the
        # input layer (fc1) and the output layer (fc_out).
        number_of_regression_layers = 3

        # in_channels are describing the number of node features, atom-type, weight, polarity...
        self.etl_mpnn = GCNConv(in_channels=1, out_channels=self.etl_embedding_dimensions)
        self.htl_mpnn = GCNConv(in_channels=1, out_channels=self.htl_embedding_dimensions)

        # Width of the concatenated input: 16 + 16 + 76 + 1 = 109 features.
        regression_input_dimension = (
            self.etl_embedding_dimensions
            + self.htl_embedding_dimensions
            + self.absorber_dimensions
            + self.bandgap_dimension
        )
        self.fc1 = torch.nn.Linear(regression_input_dimension, self.hidden_dimension)

        # NOTE: fc1 is registered both as ``fc1`` and as ``regression_layers.0``,
        # so the state dict contains both key prefixes. Keep both so that
        # previously saved checkpoints continue to load.
        self.regression_layers = torch.nn.ModuleList([self.fc1])
        self.regression_layers.extend(
            torch.nn.Linear(self.hidden_dimension, self.hidden_dimension)
            for _ in range(number_of_regression_layers - 2)
        )
        # 1, because we just want to predict pce:
        self.fc_out = torch.nn.Linear(self.hidden_dimension, 1)

    @staticmethod
    def _embed_graph(conv, features, edge_indices):
        """Encode one graph with ``conv`` and mean-pool its nodes into a single row."""
        node_embeddings = conv(features, edge_indices)
        # All nodes belong to graph 0. The batch vector is allocated on the same
        # device as the embeddings so pooling also works when the model runs on
        # a GPU (a CPU-only batch vector would cause a device mismatch there).
        batch = torch.zeros(
            node_embeddings.size(0),
            dtype=torch.long,
            device=node_embeddings.device,
        )
        return global_mean_pool(node_embeddings, batch)

    def forward(self, etl_features, htl_features, etl_edge_indices, htl_edge_indices, absorbers, bandgap):
        """Return the prediction for one ETL/HTL/absorber/bandgap combination.

        Args:
            etl_features: Node features of the ETL graph; the encoder expects
                one feature per node (``in_channels=1``).
            htl_features: Node features of the HTL graph, same layout.
            etl_edge_indices: Edge index of the ETL graph, passed to ``GCNConv``.
            htl_edge_indices: Edge index of the HTL graph, passed to ``GCNConv``.
            absorbers: Absorber features; must be concatenable with the pooled
                graph rows along ``dim=1`` (presumably shape [1, 76] -- confirm
                against the data loader).
            bandgap: Bandgap value(s); presumably shape [1, 1] -- confirm
                against the data loader.

        Returns:
            Output of the final linear layer (one feature per row).
        """
        etl_x = self._embed_graph(self.etl_mpnn, etl_features, etl_edge_indices)
        htl_x = self._embed_graph(self.htl_mpnn, htl_features, htl_edge_indices)

        # To isolate the effect of the absorber, etl_x / htl_x can be replaced
        # here by zeros of shape [1, <embedding dimensions>].

        x = torch.cat([etl_x, htl_x, absorbers, bandgap], dim=1)
        for layer in self.regression_layers:
            x = F.leaky_relu(layer(x))

        return self.fc_out(x)

# mlp = MLP(in_channels=16, hidden_channels=32,
#           out_channels=128, num_layers=3)


In [None]:
# Build the network and restore the trained weights from disk.
model = Net()
# The model is created on the CPU here, so map the checkpoint tensors to the
# CPU as well; this also lets a checkpoint that was saved from a GPU load on a
# machine without CUDA.
state_dict = torch.load('NORTH/GNN_no_ctl_graph.pth', map_location='cpu')
model.load_state_dict(state_dict)
# Switch to evaluation mode before running inference.
model.eval()

In [None]:
# All metrics are imported up front; mean_squared_error was previously used
# without being imported in this cell.
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

all_predictions = []
all_labels = []

# Evaluate the model
with torch.no_grad():  # Disable gradient calculation for evaluation
    for batch in eval_loader:
        # squeeze(0) drops the leading dimension when it has size 1
        # (presumably the loader's batch dimension with batch_size=1 -- confirm).
        etl_features = batch['etl_features'].squeeze(0)
        htl_features = batch['htl_features'].squeeze(0)
        etl_edge_indices = batch['etl_edge_indices'].squeeze(0)
        htl_edge_indices = batch['htl_edge_indices'].squeeze(0)
        absorber = batch['absorber']
        bandgap = batch['bandgap']
        true_labels = batch['pce']

        # Forward pass
        predictions = model(
            etl_features,
            htl_features,
            etl_edge_indices,
            htl_edge_indices,
            absorber,
            bandgap,
        )

        # Store predictions and true labels
        all_predictions.extend(predictions.cpu().numpy())
        all_labels.extend(true_labels.cpu().numpy())

# Compute evaluation metrics: mean squared error, R^2 and mean absolute error
mse = mean_squared_error(all_labels, all_predictions)
print(f"Mean Squared Error: {mse}")

r2 = r2_score(all_labels, all_predictions)
mae = mean_absolute_error(all_labels, all_predictions)

print(f"R^2 Score: {r2}")
print(f"Mean Absolute Error: {mae}")