In [None]:
import torch  # imported here because this cell runs before the main import cell

# Use the first GPU when available; fall back to CPU so the notebook also runs without CUDA.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
config = {
    "embedding": 32, # Node embedding size
    "batch": 512, # Batch size
    "epochs": 1000, # Number of training epochs
    "target": "seconds", # sat OR seconds (classification vs regression)
    "log": False, # Whether the model is trained on the natural log of 'seconds'
    "dataset": "cnf.csv", # CSV with 'filename', 'seconds' and 'sat' columns
    "layers": 3, # Number of GNN layers
    "transformation": "relu", # MLP applied to messages by SAGE_MLP / MLP_GIN / MLPGraphNorm (relu, linear, mlp)
    "gnn": "MLPGraphNorm", # Actual graph network architecture (SAGEConv, SAGE_MLP, GINConv, MLP_GIN, GATConv, MLPGraphNorm)
    "projection": "mlp", # Per-node-type projection applied after every GNN layer (linear, mlp, mlp-3)
    "readout": "multi", # How to aggregate final node embeddings (lstm, mean, sum, multi = concat of sum and mean)
    "dropout": False, # Apply dropout (p=0.5) to the pooled graph embedding before the prediction head
    "jumping": True, # Jumping knowledge: read out the sum of all layers' outputs instead of the last one
    "encoding": "lcg", # CNF graph encoding; only "lcg" is supported by the dataset cell
    "small": False, # Use only the first 2000 (shuffled) formulas, for quick experiments
}

In [None]:
from pysat.formula import CNF
from itertools import permutations 
import torch 
import pandas as pd 
from copy import deepcopy 
from tqdm import tqdm 
import matplotlib.pyplot as plt 
import numpy 

from torch_geometric.loader import DataLoader
from torch.nn import Sequential, ReLU, Linear 
import torch.nn.functional as F 

from torch_geometric.nn import MessagePassing
from torch_geometric.nn import HeteroConv, SAGEConv, GAT, GATConv, GINConv 
from torch_geometric.data import HeteroData
from torch_geometric.nn import global_mean_pool
from torch_geometric.nn.aggr import MLPAggregation, LSTMAggregation, SumAggregation, MultiAggregation 
from torch_geometric.nn.models import MLP 
from torch_geometric.utils import degree
from torch import Tensor

import torchmetrics
from torch_geometric.nn import summary 
import time 
import os 
import random 
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Dataset

In [None]:
# Read CNF file and transform to graph

In [None]:
# Every literal node and every clause node starts from the same all-ones feature
# vector of size config['embedding'] (two independent tensors with equal values).
initial_literal_embedding = torch.ones(config['embedding'])
initial_clause_embedding = torch.ones_like(initial_literal_embedding)

In [None]:
def index_literal(l, n_vars):
    """Map a DIMACS literal to its node index.

    Positive literals ``1..n_vars`` take indices ``0..n_vars-1`` and negative
    literals ``-1..-n_vars`` take ``n_vars..2*n_vars-1``; for 3 variables the
    order is [1, 2, 3, -1, -2, -3].
    """
    if l > 0:
        return l - 1
    # Negations are placed after all positive literals.
    return n_vars + abs(l) - 1

def lig_graph(filepath):
    """Build the Literal-Incidence Graph (LIG) of a CNF file.

    Nodes are the ``2 * n_vars`` literals, indexed with ``index_literal``. Two
    literals are connected (in both directions) when they occur together in a
    clause, and every literal is also connected to its negation.

    Args:
        filepath: Path to a CNF file readable by ``pysat.formula.CNF``.

    Returns:
        x: ``(2 * n_vars, embedding)`` tensor, each row a copy of
            ``initial_literal_embedding``.
        edges: ``(2, E)`` int64 edge index.
    """
    cnf = CNF(from_file=filepath)
    n_vars = cnf.nv
    edges = []
    for clause in cnf.clauses:
        # permutations(r=2) yields both (a, b) and (b, a), so edges are bidirectional.
        for (x_i, x_j) in permutations(clause, r=2):
            edges.append((index_literal(x_i, n_vars), index_literal(x_j, n_vars)))
    edges = list(set(edges))  # Remove duplicates (pairs shared by several clauses)
    # Add links between literals of the same variable, in both directions.
    edges.extend([(i, i + n_vars) for i in range(n_vars)])
    edges.extend([(i + n_vars, i) for i in range(n_vars)])
    # Build the index directly as int64: going through a float32 torch.Tensor rounds
    # indices above 2**24, and view(-1, 2) keeps a valid (2, 0) shape for no edges.
    edges = torch.tensor(edges, dtype=torch.long).view(-1, 2).t()

    # 2 literals for each variable, all initialised to the same embedding.
    x = initial_literal_embedding.repeat(n_vars * 2, 1)

    return x, edges

def lcg_graph(filepath):
    """Build the Literal-Clause Graph (LCG) of a CNF file.

    Literal nodes (``2 * n_vars``, indexed with ``index_literal``) and clause
    nodes (one per clause, indexed by position in the file) are separate node
    types, so their index ranges are independent.

    Args:
        filepath: Path to a CNF file readable by ``pysat.formula.CNF``.

    Returns:
        ``(x_vars, x_clauses), (literal_edges, lc_edges, cl_edges)``:
        ``x_vars`` is ``(2 * n_vars, embedding)`` and ``x_clauses`` is
        ``(n_clauses, embedding)``; each edge tensor is a ``(2, E)`` int64 edge
        index for literal <-> negation, literal -> clause containing it, and
        clause -> literal it contains, respectively.
    """
    def to_edge_index(pairs):
        # int64 directly: a float32 torch.Tensor would round indices above 2**24;
        # view(-1, 2) keeps a valid (2, 0) shape when there are no edges.
        return torch.tensor(pairs, dtype=torch.long).view(-1, 2).t()

    cnf = CNF(from_file=filepath)
    n_vars = cnf.nv
    lc_edges = []
    cl_edges = []
    for clause_idx, clause in enumerate(cnf.clauses):
        for literal in clause:
            literal_idx = index_literal(literal, n_vars)
            lc_edges.append((literal_idx, clause_idx))
            cl_edges.append((clause_idx, literal_idx))

    # Add links between literals of the same variable, in both directions.
    literal_edges = [(i, i + n_vars) for i in range(n_vars)]
    literal_edges.extend([(i + n_vars, i) for i in range(n_vars)])

    # 2 literals for each variable; copy the initial embedding for each literal ...
    x_vars = initial_literal_embedding.repeat(n_vars * 2, 1)
    # ... and for each clause.
    x_clauses = initial_clause_embedding.repeat(len(cnf.clauses), 1)

    return (x_vars, x_clauses), (to_edge_index(literal_edges), to_edge_index(lc_edges), to_edge_index(cl_edges))

## Pre-process CNF files into graphs

In [None]:
if config['encoding'] != "lcg":
    # Only the LCG encoding produces the heterogeneous literal/clause graphs HeteroGNN consumes.
    raise ValueError(f"Unsupported encoding '{config['encoding']}': only 'lcg' is implemented.")

formulas = pd.read_csv(config['dataset'])[['filename', 'seconds', 'sat']]

formulas = formulas.dropna()  # Take only valid data
# Random permutation with a fixed seed, so the later train/test split is reproducible.
formulas = formulas.sample(frac=1, random_state=12345).reset_index(drop=True)

if config['small']:
    # .copy() so later column assignments act on an independent frame, not a slice.
    formulas = formulas[:2000].copy()  # Choose a subset
print(f"SAT are {sum(formulas['sat'] == True)}, UNSAT are {sum(formulas['sat'] == False)}")


if config['log']:
    # The log of a non-positive time is -inf/NaN and would poison every loss value.
    if not (formulas['seconds'] > 0).all():
        raise ValueError("config['log'] requires strictly positive 'seconds' values")
    formulas['seconds'] = numpy.log(formulas['seconds'])  # natural log (not log10)
formulas['seconds'].hist()  # bins=[0, 10, 100, 500]


mapping = {}  # position in `dataset` -> CNF file the graph was built from
dataset = []
for i, (file, seconds, sat) in enumerate(tqdm(formulas.values)):
    x, edges = lcg_graph(file)
    data = HeteroData()
    # Add node features
    data['literals'].x = x[0]
    data['clauses'].x = x[1]
    # Add edges
    data['literals', 'negates', 'literals'].edge_index = edges[0]
    data['literals', 'inside', 'clauses'].edge_index = edges[1]
    data['clauses', 'contains', 'literals'].edge_index = edges[2]
    # Add graph label; cast to float so batches collate into a float target tensor.
    if config['target'] == "sat":
        data.y = float(sat)
    else:
        data.y = float(seconds)
    data.validate()  # Throw error if graph is not valid
    dataset.append(data)
    mapping[i] = file
print(f"Total dataset has {len(dataset)} graphs")

# GNN model classes

In [None]:
class MLPGraphConv(MessagePassing):
    """Message-passing layer whose messages are ``mlp(x_j)`` of each source node.

    Args:
        mlp: Module applied to every source-node feature before aggregation.
        aggr: Aggregation scheme forwarded to ``MessagePassing`` ('mean', 'add', 'max', ...).
    """

    def __init__(self, mlp, aggr='mean'):
        super(MLPGraphConv, self).__init__(aggr=aggr)
        self.mlp = mlp  # shared reference, not a copy

    def forward(self, x, edge_index):
        """Propagate MLP-transformed source features along ``edge_index``."""
        aggregated = self.propagate(edge_index, x=x)
        return aggregated

    def message(self, x_j):
        """Transform each source-node feature with the MLP."""
        transformed = self.mlp(x_j)
        return transformed


class SAGE_MLP(SAGEConv):
    """SAGEConv variant that passes each neighbour embedding through an MLP before aggregation.

    The MLP is deep-copied, so every instance owns independent weights even when
    several instances are built from the same ``mlp`` object.

    Args:
        mlp: Module applied to every source-node feature ``x_j``.
        *args, **kwargs: Forwarded unchanged to ``SAGEConv``.
    """

    def __init__(self, mlp, *args, **kwargs):
        super(SAGE_MLP, self).__init__(*args, **kwargs)
        own_mlp = deepcopy(mlp)  # private copy: weights are not tied to the caller's module
        self.mlp = own_mlp

    def message(self, x_j):
        """Return the MLP-transformed source features as messages."""
        messages = self.mlp(x_j)
        return messages


# With normalization
class MLPGraphNorm(MessagePassing):
    """Message passing with MLP-transformed, symmetrically degree-normalised messages.

    For an edge ``j -> i`` the message is ``mlp(x_j) / sqrt(deg_src(j) * deg_dst(i))``.
    ``deg_src`` counts how often a node appears in ``edge_index[0]``;
    ``deg_dst`` counts how often it appears in ``edge_index[1]``. Nodes with
    degree 0 get a factor of 0 instead of inf. Works on homogeneous graphs
    (``x`` a Tensor) and on bipartite graphs (``x = (x_src, x_dst)``).

    Args:
        mlp: Module applied to every source-node feature (shared reference, not copied).
        aggr: Aggregation scheme ('mean', 'add'/'sum', 'max', ...).
    """

    def __init__(self, mlp, aggr='mean'):
        super(MLPGraphNorm, self).__init__(aggr=aggr)
        self.mlp = mlp

    @staticmethod
    def _inv_sqrt_degree(index, num_nodes, dtype):
        """Return ``deg ** -0.5`` per node, with 0 for nodes of degree 0."""
        inv_sqrt = degree(index, num_nodes, dtype=dtype).pow(-0.5)
        inv_sqrt[inv_sqrt == float('inf')] = 0
        return inv_sqrt

    def forward(self, x, edge_index):
        """
        Args:
            x: Node feature matrix or a tuple (x_src, x_dst) for bipartite graphs.
            edge_index: Edge index tensor.

        Raises:
            TypeError: if ``x`` is neither a Tensor nor a tuple.
        """
        if isinstance(x, tuple):
            x_src, x_dst = x
        elif isinstance(x, torch.Tensor):
            x_src = x_dst = x
        else:
            raise TypeError(
                f"MLPGraphNorm expects a Tensor or a (x_src, x_dst) tuple, got {type(x).__name__}")

        row, col = edge_index
        # Source-side (row) and target-side (col) normalisation factors.
        deg_inv_sqrt_src = self._inv_sqrt_degree(row, x_src.size(0), x_src.dtype)
        deg_inv_sqrt_dst = self._inv_sqrt_degree(col, x_dst.size(0), x_dst.dtype)

        # Perform message passing
        return self.propagate(edge_index, x=(x_src, x_dst),
                              deg_inv_sqrt_src=deg_inv_sqrt_src,
                              deg_inv_sqrt_dst=deg_inv_sqrt_dst)

    def message(self, x_j, edge_index, deg_inv_sqrt_src, deg_inv_sqrt_dst):
        """
        Args:
            x_j: Features of source nodes (one row per edge).
            edge_index: Edge index tensor.
            deg_inv_sqrt_src: Normalization factor for source nodes.
            deg_inv_sqrt_dst: Normalization factor for target nodes.
        """
        row, col = edge_index
        norm = deg_inv_sqrt_src[row] * deg_inv_sqrt_dst[col]
        return norm.view(-1, 1) * self.mlp(x_j)

class MLP_GIN(GINConv):
    """GINConv whose messages are ``mlp(x_j)`` instead of the raw neighbour features.

    Only the per-edge message is changed; the rest of the parent ``GINConv``
    computation is left as is.

    Args:
        mlp: Module applied to every source-node feature (shared reference, not copied).
        **kwargs: Forwarded to ``GINConv`` (e.g. ``nn=...``, ``aggr=...``).
    """

    def __init__(self, mlp, **kwargs):
        super(MLP_GIN, self).__init__(**kwargs)
        self.mlp = mlp

    def message(self, x_j: Tensor) -> Tensor:
        transformed = self.mlp(x_j)
        return transformed

In [None]:
class HeteroGNN(torch.nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        
        # Define convolution layers for each edge type
        self.convs = torch.nn.ModuleList() 
        for i in range(config['layers']):
            assert config['embedding'] == hidden_dim 

            match config['transformation']:
                case "relu":
                    node_mlp = Sequential(
                        Linear(hidden_dim, hidden_dim),
                        ReLU(),
                        Linear(hidden_dim, hidden_dim)
                    )
                case "linear":
                    node_mlp = Linear(hidden_dim, hidden_dim)
                case "mlp":
                    node_mlp = MLP(in_channels=hidden_dim, hidden_channels=hidden_dim, out_channels=hidden_dim, num_layers=3)  
                case _:
                    print(f"ERROR: '{config['transformation']}' is not defined as transformation operation.")
                    exit(1)
            
            # Used for GIN-based convolution 
            gin_nn = Sequential(
                Linear(hidden_dim, hidden_dim),
                ReLU(),
                Linear(hidden_dim, hidden_dim)
            )

            match config['gnn']:
                case "SAGEConv":
                    relation_dict = {
                        ('literals', 'inside', 'clauses'): SAGEConv(config['embedding'], hidden_dim, aggr='sum', project=False),
                        ('clauses', 'contains', 'literals'): SAGEConv(config['embedding'], hidden_dim, aggr='sum', project=False),
                        ('literals', 'negates', 'literals'): SAGEConv(config['embedding'], hidden_dim, aggr='sum', project=False),
                    }
                case "SAGE_MLP":
                    relation_dict = {
                        ('literals', 'inside', 'clauses'): SAGE_MLP(node_mlp, config['embedding'], hidden_dim, aggr='mean', project=False),
                        ('clauses', 'contains', 'literals'): SAGE_MLP(node_mlp, config['embedding'], hidden_dim, aggr='mean', project=False),
                        ('literals', 'negates', 'literals'): SAGE_MLP(node_mlp, config['embedding'], hidden_dim, aggr='mean', project=False),
                    }
                case "MLPGraphNorm":
                    relation_dict = {
                        ('literals', 'inside', 'clauses'): MLPGraphNorm(node_mlp, aggr='sum'),
                        ('clauses', 'contains', 'literals'): MLPGraphNorm(node_mlp, aggr='sum'),
                        ('literals', 'negates', 'literals'): MLPGraphNorm(node_mlp, aggr='sum'),
                    }
                case "MLP_GIN":
                    relation_dict = {
                        ('literals', 'inside', 'clauses'): MLP_GIN(node_mlp, nn=gin_nn, aggr='sum'),
                        ('clauses', 'contains', 'literals'): MLP_GIN(node_mlp, nn=gin_nn, aggr='sum'),
                        ('literals', 'negates', 'literals'): MLP_GIN(node_mlp, nn=gin_nn, aggr='sum'), 
                    }
                case "GINConv":
                    relation_dict = {
                        ('literals', 'inside', 'clauses'): GINConv(nn=gin_nn, aggr='sum'),
                        ('clauses', 'contains', 'literals'): GINConv(nn=gin_nn, aggr='sum'),
                        ('literals', 'negates', 'literals'): GINConv(nn=gin_nn, aggr='sum'), 
                    }
                case "GATConv":
                    relation_dict = {
                        ('literals', 'inside', 'clauses'): GATConv(config['embedding'], hidden_dim, add_self_loops=False), 
                        ('clauses', 'contains', 'literals'): GATConv(config['embedding'], hidden_dim, add_self_loops=False),
                        ('literals', 'negates', 'literals'): GATConv(config['embedding'], hidden_dim, add_self_loops=False),
                    }
                case _:
                    print(f"ERROR: '{config['gnn']}' is not defined as GNN architecture.")
                    exit(1)

            convs = HeteroConv(relation_dict, aggr='cat') # Aggregation across edge types 
            self.convs.append(convs) 

        self.literal_linears = torch.nn.ModuleList() 
        self.clause_linears = torch.nn.ModuleList() 
        for i in range(config['layers']):
            match config["projection"]:
                case "linear":
                    # Projection layers for each node type
                    literals_linear = torch.nn.Linear(2*hidden_dim, hidden_dim) # Concat of clause and opposite literal embeds 
                    clauses_linear = torch.nn.Linear(hidden_dim, hidden_dim) 
                case "mlp":
                    literals_linear = MLP([2*hidden_dim, hidden_dim, hidden_dim]) 
                    clauses_linear = MLP([hidden_dim, hidden_dim, hidden_dim]) 
                case "mlp-3":
                    literals_linear = MLP(in_channels=2*hidden_dim, hidden_channels=hidden_dim, out_channels=hidden_dim, num_layers=3) # Concat of clause and opposite literal embeds 
                    clauses_linear = MLP(in_channels=hidden_dim, hidden_channels=hidden_dim, out_channels=hidden_dim, num_layers=3)
                case _:
                    print(f"ERROR: '{config['projection']}' is not defined as projection operation.")
                    exit(1)

            self.literal_linears.append(literals_linear) 
            self.clause_linears.append(clauses_linear) 

        match config["readout"]:
            case "lstm":
                self.readout = LSTMAggregation(hidden_dim, hidden_dim, num_layers=3) 
            case "mean":
                self.readout = global_mean_pool 
            case "sum":
                self.readout = SumAggregation() 
            case "multi":
                self.readout = MultiAggregation(['sum', 'mean'], mode='cat') # cat is default mode
            case _:
                print(f"ERROR: '{config['readout']}' is not defined as readout operation.")
                exit(1)

        if config['target'] == "sat":
            # SAT classification 
            self.classifier = MLP(in_channels=-1, hidden_channels=hidden_dim, out_channels=1, num_layers=2) 
        else:
            self.regressor = MLP(in_channels=-1, hidden_channels=hidden_dim, out_channels=1, num_layers=2) 

    def forward(self, data):
        # HeteroConv expects node features and edge_index as a dictionary
        x_dict = data.x_dict
        edge_index_dict = data.edge_index_dict

        # Apply heterogeneous message passing
        jumping_literals = torch.zeros(x_dict['literals'].shape).to(x_dict['literals'].device) 
        jumping_clauses = torch.zeros(x_dict['clauses'].shape).to(x_dict['clauses'].device) 
        for conv, literal_linear, clause_linear in zip(self.convs, self.literal_linears, self.clause_linears):
            x_dict = conv(x_dict, edge_index_dict)
            x_dict = {key: x.relu() for key, x in x_dict.items()} 
            # Apply transformations for each node type
            x_dict['literals'] = literal_linear(x_dict['literals'])
            x_dict['clauses'] = clause_linear(x_dict['clauses'])

            jumping_literals += x_dict['literals'] 
            jumping_clauses += x_dict['clauses'] 


        # Readout operation: Aggregate the node features of both 'literals' and 'clauses'
        # We'll aggregate over all nodes of the graph
        if config['jumping']:
            x_graph = torch.cat([self.readout(jumping_literals, data.batch['literals']), self.readout(jumping_clauses, data.batch['clauses'])], 1)
        else:
            x_graph = torch.cat([self.readout(x_dict['literals'], data.batch['literals']), self.readout(x_dict['clauses'], data.batch['clauses'])], 1)
        
        if config['target'] == "sat":
            if config['dropout']:
                x_graph = F.dropout(x_graph, p=0.5, training=self.training) 
            res = torch.sigmoid(self.classifier(x_graph))
        else:            
            # Apply a final regressor
            if config['dropout']:
                x_graph = F.dropout(x_graph, p=0.5, training=self.training) # Dropout to avoid overfit
            res = self.regressor(x_graph) 

        return res 

# Training

In [None]:
# 80/20 split of the (already shuffled) dataset: first 80% train, rest test.
test_split = (len(dataset) * 8) // 10
# Create a DataLoader for batching.
# NOTE(review): shuffle=False also for training, so every epoch sees identical batches.
train_loader, test_loader = (
    DataLoader(split, batch_size=config['batch'], shuffle=False)
    for split in (dataset[:test_split], dataset[test_split:])
)

print(len(train_loader), len(test_loader))

In [None]:
def compute_batch_assignment(batch, node_type):
    """Compute the graph index of every node of ``node_type`` from the batch's ``ptr``.

    ``batch[node_type].ptr`` is the prefix sum of node counts per graph
    (``ptr[g]`` is the first node of graph ``g``), so graph ``g`` contributes
    ``ptr[g+1] - ptr[g]`` entries equal to ``g``.

    Args:
        batch (HeteroDataBatch): A batch containing multiple graphs.
        node_type (str): Node type whose assignment is computed (e.g. 'literals').

    Returns:
        batch_tensor (Tensor): int64 tensor of size (total number of nodes of
        that type,) containing the graph index for each node.
    """
    ptr = batch[node_type].ptr
    num_graphs = max(ptr.numel() - 1, 0)
    counts = ptr[1:] - ptr[:-1]
    # Vectorised replacement for the per-node Python list (O(nodes) interpreter work).
    return torch.repeat_interleave(torch.arange(num_graphs, device=ptr.device), counts)



In [None]:
def train():
    """Run one training epoch of the global ``model`` over ``train_loader``.

    Uses the global ``optimizer`` and ``criterion``; gradients are clipped
    element-wise to [-5, 5] before every optimizer step.
    """
    model.train()

    for batch in train_loader:  # Iterate in batches over the training dataset.
        batch.batch = {'literals': compute_batch_assignment(batch, 'literals'), 'clauses': compute_batch_assignment(batch, 'clauses')}
        batch = batch.to(DEVICE)
        optimizer.zero_grad()  # Clear gradients before the forward/backward pass.
        # view(-1) rather than squeeze(): a batch with a single graph must stay 1-D
        # to match batch.y; a 0-d output makes BCELoss raise on the size mismatch.
        out = model(batch).view(-1)  # Perform a single forward pass.
        loss = criterion(out, batch.y.view(-1))  # Compute the loss.
        loss.backward()  # Derive gradients.
        # Perform gradient clipping by value
        torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=5.)
        optimizer.step()  # Update parameters based on gradients.



def test(loader):
    """Evaluate the global ``model`` on every batch of ``loader`` without gradients.

    Returns:
        (total_out, total_y): 1-D tensors on DEVICE holding the prediction and
        the target of every graph, in loader order.
    """
    model.eval()

    total_out = []
    total_y = []
    with torch.no_grad():
        for batch in loader:
            batch.batch = {'literals': compute_batch_assignment(batch, 'literals'), 'clauses': compute_batch_assignment(batch, 'clauses')}
            batch = batch.to(DEVICE)
            # view(-1) rather than squeeze(): a single-graph batch would become 0-d
            # and torch.cat below cannot concatenate zero-dimensional tensors.
            total_out.append(model(batch).view(-1))
            total_y.append(batch.y.view(-1))

    total_out = torch.cat(total_out, 0)
    total_y = torch.cat(total_y, 0)

    return total_out, total_y

def compute_main_metric(out, y):
    """Evaluate the global training ``criterion`` on predictions ``out`` against targets ``y``."""
    main_metric = criterion(out, y)
    return main_metric

def compute_metrics(out, y):
    """Compute evaluation metrics of predictions ``out`` against targets ``y``.

    Classification (``config['target'] == "sat"``): binary cross-entropy and the
    accuracy of predictions rounded to 0/1 (``out`` is expected to hold
    probabilities). Regression: MSE, L1, R2, MAPE and Spearman correlation.

    Returns:
        dict mapping metric name to a scalar tensor.
    """
    if config['target'] != "sat":
        return {
            'MSE': torch.nn.functional.mse_loss(out, y),
            'L1': torch.nn.functional.l1_loss(out, y),
            'R2': torchmetrics.functional.r2_score(out, y),
            'MAPE': torchmetrics.functional.mean_absolute_percentage_error(out, y),
            'Spearman': torchmetrics.functional.spearman_corrcoef(out, y),
        }

    entropy = torch.nn.functional.binary_cross_entropy(out, y)
    n_correct = (torch.round(out) == y).sum()  # Count correct predictions
    return {'Entropy': entropy, 'Accuracy': n_correct / len(out)}

def plot_predictions(out, y, title='Test data', file=None, limit=500.):
    """Plot predictions against targets, then save or show the figure.

    For classification (``config['target'] == "sat"``) this draws a confusion
    matrix of predictions thresholded at 0.5, normalised per true class. For
    regression it draws a predicted-vs-real scatter plot with the identity line.

    Args:
        out: Model outputs (probabilities for "sat", predicted times otherwise).
        y: Targets, same length as ``out``.
        title: Figure title.
        file: If given, save the figure to this path (and close it) instead of showing it.
        limit: Upper bound of both axes of the regression scatter plot.
    """
    if config['target'] == "sat":
        # Confusion matrix
        disp = ConfusionMatrixDisplay.from_predictions(y.cpu(), out.cpu() > 0.5, normalize='true', display_labels=['UNSAT', 'SAT'], cmap=plt.cm.Blues)
        fig = disp.figure_
        disp.ax_.set_title(title)
    else:
        # Regression task
        fig, ax = plt.subplots(figsize=(5, 5))
        preds = out.cpu()
        reals = y.cpu()

        ax.set_xlim((0, limit))
        ax.set_ylim((0, limit))
        ax.set_aspect("equal")

        ax.set_title(title)
        ax.set_ylabel("Predictions")
        ax.set_xlabel("Real times")
        ax.scatter(reals, preds, s=2, color='blue', label='Model outputs')
        ax.axline((0, 0), slope=1., color='red', label='Perfect')
        ax.grid()
        ax.legend()

    if file:
        fig.savefig(file)
        # Close saved figures: left open they keep their memory and may be
        # re-displayed by a later plt.show() or at the end of the cell.
        plt.close(fig)
    else:
        plt.show()

In [None]:
# Build a model and print a layer-by-layer summary using the first training batch.
# summary() runs a forward pass, which also sizes the lazily initialised (in_channels=-1) head.
model = HeteroGNN(hidden_dim=config['embedding']).to(DEVICE)
batch = next(iter(train_loader), None)
if batch is not None:
    batch.batch = {
        node_type: compute_batch_assignment(batch, node_type)
        for node_type in ('literals', 'clauses')
    }
    batch = batch.to(DEVICE)
    print(summary(model, batch, max_depth=6))

### Main loop

In [None]:
torch.manual_seed(12345)
random.seed(12345)
numpy.random.seed(12345)

# Initialize model
model = HeteroGNN(hidden_dim=config['embedding']).to(DEVICE)
print(model)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
if config['target'] == "sat":
    criterion = torch.nn.BCELoss()  # only 2 classes; the model already outputs sigmoid probabilities
else:
    criterion = torch.nn.MSELoss()


def _to_linear_space(*tensors):
    """Undo the log transform (torch.exp) when the model is trained in log-space."""
    if config['log']:
        return tuple(torch.exp(t) for t in tensors)
    return tensors


def _table_row(*values):
    """Format scalar (tensor) values as a ' & '-separated LaTeX table row."""
    return " & ".join(str(float(v)) for v in values)


# Save model configuration
timestamp = int(time.time())
log_dir = f"logs/{timestamp}"
# makedirs also creates the parent 'logs/' directory; it still fails if this run
# directory already exists, so earlier results are never overwritten.
os.makedirs(log_dir)
with open(f"{log_dir}/config.txt", 'w') as f:
    for key in config:
        f.write(f'{key}={config[key]}\n')

# Before training (evaluate initial random model), reported in linear space
train_out, train_y = test(train_loader)
test_out, test_y = test(test_loader)
train_out, train_y, test_out, test_y = _to_linear_space(train_out, train_y, test_out, test_y)

train_loss = compute_main_metric(train_out, train_y)
metrics = compute_metrics(test_out, test_y)
for metric in metrics:
    print(f'Test {metric} = {metrics[metric]}')

# Store in files
with open(f"{log_dir}/init-metrics.txt", 'w') as f:
    for metric in metrics:
        f.write(f'Test {metric} = {metrics[metric]}\n')
    if config['target'] == 'seconds':
        # For regression
        f.write(_table_row(train_loss, metrics['MSE'], metrics['L1'], metrics['R2'], metrics['MAPE'], metrics['Spearman']) + "\n")
plot_predictions(train_out, train_y, title='Train data', file=f'{log_dir}/init-train.png')
plot_predictions(test_out, test_y, title='Test data', file=f'{log_dir}/init-test.png')


# Actual training. Model selection uses the training loss in the space the model
# is trained in (log-space when config['log']); reported metrics are linear-space.
best_loss = float('inf')
best_epoch = 0
best_model = deepcopy(model)
history = {}
for epoch in tqdm(range(1, config['epochs'] + 1), desc='Epochs'):
    train()
    train_out, train_y = test(train_loader)
    train_loss = compute_main_metric(train_out, train_y)

    # Save best model so far
    if train_loss < best_loss:
        best_loss = train_loss
        best_model = deepcopy(model)
        best_epoch = epoch

        print(f'New best: {epoch:03d}, Train Loss: {train_loss:.6f}')

        test_out, test_y = test(test_loader)
        train_out, train_y, test_out, test_y = _to_linear_space(train_out, train_y, test_out, test_y)

        metrics = compute_metrics(test_out, test_y)
        for metric in metrics:
            print(f'Test {metric} = {metrics[metric]}')

        plot_predictions(train_out, train_y, title='Train data')
        plot_predictions(test_out, test_y, title='Test data')
        print()  # Empty line

        history[epoch] = {
            "train_loss": train_loss.cpu().item()
        }
        history[epoch].update({k: metrics[k].cpu().item() for k in metrics})

    if epoch % 10 == 0:
        test_out, test_y = test(test_loader)
        # Same linear-space reporting as the "New best" metrics above.
        test_out, test_y = _to_linear_space(test_out, test_y)
        metrics = compute_metrics(test_out, test_y)
        print(f'Train Loss: {train_loss:.6f}')
        for metric in metrics:
            print(f'Test {metric} = {metrics[metric]}')
        print("-" * 80)

print(f"Best result is {best_loss} at {best_epoch}")


# After training
model = best_model  # Evaluate the best model
train_out, train_y = test(train_loader)
test_out, test_y = test(test_loader)
train_out, train_y, test_out, test_y = _to_linear_space(train_out, train_y, test_out, test_y)

train_loss = compute_main_metric(train_out, train_y)
metrics = compute_metrics(test_out, test_y)
for metric in metrics:
    print(f'Test {metric} = {metrics[metric]}')

# Store in files
with open(f"{log_dir}/metrics.txt", 'w') as f:
    f.write(f"Train loss: {train_loss} at {best_epoch}\n")
    for metric in metrics:
        f.write(f'Test {metric} = {metrics[metric]}\n')
    if config['target'] == 'seconds':
        # For regression
        f.write(f"{best_epoch} & " + _table_row(train_loss, metrics['MSE'], metrics['L1'], metrics['R2'], metrics['MAPE'], metrics['Spearman']) + "\n")
plot_predictions(train_out, train_y, title='Train data', file=f'{log_dir}/train.png')
plot_predictions(test_out, test_y, title='Test data', file=f'{log_dir}/test.png')

with open(f"{log_dir}/architecture.txt", 'w') as f:
    for batch in train_loader:
        batch.batch = {'literals': compute_batch_assignment(batch, 'literals'), 'clauses': compute_batch_assignment(batch, 'clauses')}
        batch = batch.to(DEVICE)
        f.write(summary(model, batch, max_depth=5))
        break
# Store best model
torch.save(best_model, f"{log_dir}/model.pt")


t_end = int(time.time())
with open(f"{log_dir}/time.txt", 'w') as f:
    f.write(f"{t_end - timestamp}")  # in seconds

df_history = pd.DataFrame([history[k] for k in sorted(history.keys())], index=sorted(history.keys()))
df_history.to_csv(f"{log_dir}/history.csv", index_label="Epoch")

print(f"Best result is {best_loss} at {best_epoch}")