In [1]:
import NeuroSAT
import torch
import pickle
import explainer_NeuroSAT

from pysat.solvers import Solver

In [2]:
opts = {
        'out_dir': '/Users/trist/Documents/Bachelor-Thesis/NeuroSAT/test/files/data/dataset_train_10',
        'logging': '/Users/trist/Documents/Bachelor-Thesis/NeuroSAT/test/files/log/dataset_train_10.log',
        'n_pairs': 100,  # Anzahl der zu generierenden Paare
        'min_n': 8,
        'max_n': 8,
        'p_k_2': 0.3,
        'p_geo': 0.4,
        'max_nodes_per_batch': 4000,
        'one_pair': False,
        'emb_dim': 128,
        'iterations': 26,
    }

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
with open(opts['out_dir'], 'rb') as file:
    data = pickle.load(file)

In [37]:
from pysat.formula import WCNF
from pysat.formula import CNF
from pysat.examples.musx import MUSX
from pysat.examples.optux import OptUx
import utils

problemBatch = data[0]

current_batch_num = 1

clauses_problem_0 = problemBatch.get_clauses_for_problem(current_batch_num)

print(f"Length of sub_problem 1: {len(clauses_problem_0)}")
# Sub problem 1 Contains literals 11-20, ...
print(f"Sub_problem 1 clauses: {clauses_problem_0}")

solver = Solver(name='m22')

#for clause in clauses_problem_0:
#    solver.add_clause(clause)

# TODO: Change this according to length of literals
offset = problemBatch.n_literals + 1

# TODO: Assumptions must be unique! + 100 will not work in batch as there exists clause 100...
assumptions = [i + offset for i in range(len(clauses_problem_0))]

# Add the clauses with selector literals
for i, clause in enumerate(clauses_problem_0):
    solver.add_clause(clause + [-assumptions[i]])  # Each clause gets a unique assumption
    


    
is_sat = solver.solve(assumptions=assumptions)


# TODO: THIS UNSAT CORE REQUIRES AN ASSUMPTION!!! THEREFORE NOT SUITED FOR PGEXPLAINER?! -> Bypass with additional assumption per clause?
unsat_core = solver.get_core()

reversed_core = torch.tensor(unsat_core[::-1])

# TODO: looks like unsat_core is in reverse order, flip?
# Last element of unsat core is first clause in core_clauses

stats = solver.accum_stats()
    
# Map back to original clauses
core_clauses = [clauses_problem_0[i] for i in range(len(clauses_problem_0)) if assumptions[i] in unsat_core]
    
    
cnf = CNF()
for i, clause in enumerate(clauses_problem_0):
    print(clause)
    cnf.append(clause)  # Each clause gets a unique assumption

print("Computed MUS:")
#print(cnf.to_dimacs())
# Compute Minimally Unsatisfiable Subformulas instead of UNSAT cores? No assumptions needed?
musx = MUSX(cnf)
musx.compute()
print("------------------------------------")

"""with OptUx(cnf) as optux:
    for mus in optux.enumerate():
        print('mus {0} has cost {1}'.format(mus, optux.cost))"""



solver.delete()

print(is_sat)
#print(stats)
print(f"Computed unsat core: {unsat_core}")
print("------------------------------------")
print(core_clauses)


utils.visualize_cnf_interactive(clauses_problem_0)



# This is shifted, literals are indexed 10-19 instead of 11-20!!!!!
# positive literals match! First clause (71/70 in batch_edges) is [17, 12, -14, 20]. In batched batch_edges: 16-70, 11-70, 19-70 matches 17, 12, 20
# negative ones do not match!!!!!
print(data[0].batch_edges)

mask = utils.get_batch_mask(torch.tensor(data[0].batch_edges), batch_idx=current_batch_num, batch_size=10, n_variables=problemBatch.n_variables)

masked_batch_edges = data[0].batch_edges[mask]
print(masked_batch_edges)

print(problemBatch.n_variables)
print(problemBatch.n_literals)




# TODO: Create gt in shape of mask with 0 if edge pair not in core, 1 if edge pair in core!
# -> First clause in sub_problem = 70: 17,12,-14,20 in indexes: 16,11,1413,19 has edges 16-17, 11-70, 1413-70 and 19-70
# -> First clause in gt/core = 100 (subtract 100 because of assumption -> 0, first clause(70)) with content [17, 12, -14, 20]. Convert to edge indexes and map!

unsat_core_edges = reversed_core - offset
# connect first element in unsat_core_edges to literals of first core_clauses

gt = []
for i, idx in enumerate(unsat_core_edges):
    literals = core_clauses[i]
    
    # TODO: Swap clauses_problem_0  with current problem?
    # This only works if all sub_problems have the same amount of clauses!! WRONG!
    clause = len(clauses_problem_0)*current_batch_num + unsat_core_edges[i]
    
    # Sum of problemBatch.n_clauses_per_batch[] before current_batch_num? Or count while calculating gt for data?
    #clause = problemBatch.n_clauses_per_batch[current_batch_num] + unsat_core_edges[i]
    
    for value in literals:
        value = value -1 if value >= 1 else problemBatch.n_variables - (value + 1)
        gt.append([value, clause])
        
motif_size = len(gt)

print(torch.tensor(gt))

print(problemBatch.n_clauses_per_batch)

empty_gt = torch.zeros(len(masked_batch_edges))

print(empty_gt.shape)

test = torch.isin(torch.tensor(masked_batch_edges), torch.tensor(gt))

# We only need the right column of the isin tensor
clausesTest = test[:,1].int()
print(clausesTest)


print(problemBatch.n_variables)


"""gt_mask = []
for i, content in enumerate(masked_batch_edges):
    if masked_batch_edges[i] in unsat_core_edes:
        gt_mask.append(1)
    else:
        gt_mask.append(0)"""
        
test2 = torch.isin(torch.tensor(data[0].batch_edges), torch.tensor(gt))

# We only need the right column of the isin tensor
allClausesTestgt = test2[:,1].int()
        
        
pos = utils.visualize_edge_index_interactive(masked_batch_edges, clausesTest)
pos2 = utils.visualize_edge_index_interactive(masked_batch_edges, clausesTest, "test", pos, motif_size)

Length of sub_problem 1: 73
Sub_problem 1 clauses: [[-12, -15, -16, 13], [14, -19, 16, -18, -13], [-12, 14], [19, 16, -18], [17, -11, -12, -15, -13, -14, -18, -20], [13, -20, -19], [-13, -17, 11, 14, -15], [-18, 17, -20, 11, 16, 19], [-18, 20, 11, 12, -16, -19, 15, -13, -17], [15, -12, -13, -17], [19, 16, -15, -17, -18, 14], [-11, -15, -18, 16], [17, 14, -18, 15, 20], [13, 19, -12, -18], [12, 20, 19], [-12, 14, -15, -20, 13, 19, 16], [-19, 17, -15, 11, 18, -20, -14, 16], [14, 17, 19, 16, -15], [16, 19, -20], [-14, 13], [-20, -16, -13], [-19, 16, -14, -20], [-11, -19], [-14, -16, -20], [-12, 20, 15, -16], [-20, 15, 11], [-17, 11, 13, -12], [-18, 11, -20, -14, -17], [-20, -17, -12], [13, -11, 17, -19, -16, -15], [-11, 16, 14, 12], [-18, -11], [16, -12, 20, -17, -19], [-12, -11, 18], [20, 16, -19, -17], [-20, 15, 13], [15, -12, 17, -11, 14], [-11, 17, -19], [-13, 16, -17], [12, 19, -17], [-15, -11, -16, 18, -20], [16, -18], [16, -13, 14], [-19, 14, -12], [15, 17, -18], [-16, 19, -12], [-1

In [16]:
print(data[0].batch_edges.shape)

print(data[0].n_variables)

batch = []
currentBatch = 0




problem = data[0]
n_variables = problem.n_variables

batch_edges = torch.tensor(problem.batch_edges)           # Shape: (num_edges, 2)

batch_literals = torch.cat([
    torch.arange(0, 40),                   # Positive literals
    torch.arange(n_variables, n_variables + 40)  # Negative literals
])

batch_mask = torch.isin(batch_edges[:, 0], batch_literals)

# Apply the mask
batch_edges_filtered = batch_edges[batch_mask]


print(batch_mask)



"""for i in data[0].batch_edges:
    if i[0] in []"""

(36846, 2)
1440
tensor([ True,  True,  True,  ..., False, False, False])


'for i in data[0].batch_edges:\n    if i[0] in []'

In [32]:
# Example usage:
batch_idx = 2  # Get mask for the 3rd batch (0-based index)
batch_mask = get_batch_mask(torch.tensor(problem.batch_edges), batch_idx)

# Apply the mask to filter edges
batch_edges_filtered = problem.batch_edges[batch_mask]

print(problem.batch_edges[:,0].shape)
print(batch_mask.shape)
#print(torch.tensor(batch_edges_filtered))

(36846,)
torch.Size([36846])


In [None]:
print(len(data[0].is_sat))
print(data[0].clauses)

36
[[-1, -14, 31], [-35, 19, -3, 37], [19, 22, -6, -2, 38, -31, -37, -32], [7, 28, 32], [2, -29], [-14, 4, 2, -31], [19, 17, 13, 10, 15], [-21, 10], [-35, -27, 17, 12, -6], [-35, 26], [27, -30, -15], [26, 10, -15, 18, 30, 13], [-30, -7, -1, 3, 28], [5, 1, 39, -14, 40], [35, 18, -6], [-31, -28, 9, -2, -20, 35, -14, 10, 5], [4, -40, 27], [19, 10], [35, 20, -10, 14, -18, 39, -30, 1, -27], [-19, 38, 31, -20], [20, -22, -23, -7], [-14, 4, 22], [14, 19, 17, 15, 21], [-30, -26, 38], [21, 24, -23], [-22, 8, 20, -3, -31, 36, -29, -34, 19], [-30, 5, -37, 13, -12, -11, 40, 33, 19, 10], [-34, -40], [33, 29, -40, 28, -10, -14, 35, -32], [20, 12, 38, -13, 31], [19, 7, -12, 13, -34, 8], [1, 19], [7, 5, -39, 35, 23, 11, -27], [5, 8, 31, -17, -39, -25], [34, -32, -21, 4, -24, 35, 19, -5, 8, -28], [27, 15, -25, -23, -16, -38, -13], [37, -33, 15, -18, 28, -38], [-4, 34, -7, -40], [14, -33], [-6, 2, -19, -16], [32, -34, -33, -9], [-20, -31, 16, 18, 34], [9, 20, 4], [-18, -9, -30], [11, 24, 34], [23, 3, -2

In [4]:
vote_mean, iteration_votes_sorted, all_literal_embeddings, all_clause_embeddings = downstreamTask.forward(data[0])

print(vote_mean)

print(all_literal_embeddings[-1])

NameError: name 'downstreamTask' is not defined

In [3]:
datasetName="NeuroSAT"

In [4]:
from torch_geometric import seed
import utils
import wandb
import torch.nn.functional as fn
from sklearn.metrics import roc_auc_score
import numpy as np


def trainExplainer (datasetName, opts, save_model=False, wandb_project="Explainer-NeuroSAT", runSeed=None) :
    if runSeed is not None: seed.seed_everything(runSeed)
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Check valid dataset name
    configOG = utils.loadConfig(datasetName)
    if configOG == -1:
        return
    
    params = configOG['params']
    graph_task = params['graph_task']
    epochs = params['epochs']
    t0 = params['t0']
    tT = params['tT']
    sampled_graphs = params['sampled_graphs']
    coefficient_size_reg = params['coefficient_size_reg']
    coefficient_entropy_reg = params['coefficient_entropy_reg']
    coefficient_L2_reg = params['coefficient_L2_reg']
    num_explanation_edges = params['num_explanation_edges']
    lr_mlp = params['lr_mlp']

    wandb.init(project=wandb_project, config=params)

    hidden_dim = 64 # Make loading possible
    clip_grad_norm = 2 # Make loading possible
    min_clip_value = -2
    
    
    with open(opts['out_dir'], 'rb') as file:
        data = pickle.load(file)
    
    # TODO: Split data into train and test
    # TODO: !!! each data consist of multiple problems !!! -> Extract singular problems for calculating the loss!
    dataset = data
    
    eval_problem = data[-1]
    
    # gt only needs to be calced once, not per epoch!
    countClauses = 0
    reals = []
    gt_edges_per_problem = []
    for current_batch_num in range(len(eval_problem.is_sat)):
            # This can be repeated for each sub_problem
            clauses_current_problem = eval_problem.get_clauses_for_problem(current_batch_num)

            #print(f"Length of sub_problem 0: {len(clauses_current_problem)}")
            # Sub problem 1 Contains literals 11-20, ...
            #print(f"Sub_problem 0 clauses: {clauses_current_problem}")
            
            # Next part is only for calculating unsat_core -> Move to creation of data and save?
            solver = Solver(name='m22')

            offset = eval_problem.n_literals + 1

            # Assumptions must be unique, therefore eval_problem.n_literals + 1
            assumptions = [i + offset for i in range(len(clauses_current_problem))]

            # Add the clauses with selector literals
            for i, clause in enumerate(clauses_current_problem):
                solver.add_clause(clause + [-assumptions[i]])  # Each clause gets a unique assumption
                
            # is_sat not needed right now as all problems should be unsat
            is_sat = solver.solve(assumptions=assumptions)
            
            unsat_core = solver.get_core()

            # Core contains clauses in reverse order, so we reverse it back
            reversed_core = torch.tensor(unsat_core[::-1])
            # Subtract offset from clauses in core to get original clauses
            unsat_core_clauses = reversed_core - offset
                
            # Map back to original clauses
            core_clause_literals = [clauses_current_problem[i] for i in range(len(clauses_current_problem)) if assumptions[i] in unsat_core]

            solver.delete()
            
            # Calculate mask for current sub_problem in eval_problem
            eval_batch_mask = utils.get_batch_mask(torch.tensor(eval_problem.batch_edges), batch_idx=current_batch_num, batch_size=opts['min_n'], n_variables=eval_problem.n_variables)
            
            
            gt_mask = []
            # TODO: THIS LOOKS WRONG
            for i, idx in enumerate(unsat_core_clauses):
                literals = core_clause_literals[i]
                
                # TODO: This only works if all sub_problems have the same amount of clauses!! WRONG!
                # clause = len(clauses_current_problem)*current_batch_num + unsat_core_clauses[i]
                clause = countClauses + unsat_core_clauses[i]
                
                # Sum of problemBatch.n_clauses_per_batch[] before current_batch_num? Or count while calculating gt for data?
                #clause = problemBatch.n_clauses_per_batch[current_batch_num] + unsat_core_clauses[i]
                
                for value in literals:
                    value = value -1 if value >= 1 else eval_problem.n_variables - (value + 1)
                    gt_mask.append([value, clause])
            
            countClauses = countClauses + len(clauses_current_problem)
                    
            sub_problem_edges = eval_problem.batch_edges[eval_batch_mask]
            
            # TODO: VALIDATE THIS!!!
            gt = torch.isin(torch.tensor(sub_problem_edges), torch.tensor(gt_mask))
            # We only need the right column of the isin tensor
            gt = gt[:,1].int()
            
            reals.append(gt.flatten().numpy())
            gt_edges_per_problem.append(gt_mask)
            #reals_unflattened.append(gt)
    
    allReals = np.concatenate(reals)  # Flatten the list of arrays
    
        

    downstreamTask = NeuroSAT.NeuroSAT(opts=opts,device=device)
    checkpoint = torch.load(f"models/neurosat_sr10to40_ep1024_nr26_d128_last.pth.tar", weights_only=True, map_location=device)
    downstreamTask.load_state_dict(checkpoint['state_dict'])

    mlp = explainer_NeuroSAT.MLP(GraphTask=graph_task).to(device)
    wandb.watch(mlp, log= "all", log_freq=2, log_graph=False)

    mlp_optimizer = torch.optim.Adam(params = mlp.parameters(), lr = lr_mlp)

    downstreamTask.eval()
    for param in downstreamTask.parameters():
        param.requires_grad = False


    training_iterator = dataset
    
    for epoch in range(0, epochs) :
        mlp.train()
        mlp_optimizer.zero_grad()

        temperature = t0*((tT/t0) ** ((epoch+1)/epochs))
        
        #sampledEdges = 0.0
        #sumSampledEdges = 0.0
        
        #samplePredSum = 0

        for index, content in enumerate(training_iterator):
            # stop training before last batch, used for evaluation
            if index == len(training_iterator)-2: break
            node_to_predict = None
            if graph_task: 
                # !! current_problem is really a batch of problems !!
                current_problem = content

            # MLP forward
            # TODO: Implement embeddingCalculation for SAT
            w_ij = mlp.forward(downstreamTask, current_problem, nodeToPred=node_to_predict)

            sampleLoss = torch.FloatTensor([0]).to(device)
            loss = torch.FloatTensor([0]).to(device)
            
            pOriginal, _, _, _ = downstreamTask.forward(current_problem)
            pOriginal = fn.softmax(pOriginal, dim=0)
            
            for k in range(0, sampled_graphs):
                edge_ij = mlp.sampleGraph(w_ij, temperature)
                
                #sampledEdges += torch.sum(edge_ij)
            
                # TODO: softmax needed for loss, beacuse negative values do not work witg log! Need for normalization?
                pSample, _, _, _ = downstreamTask.forward(current_problem, edge_weights=edge_ij)
                pSample = fn.softmax(pSample, dim=0)

                #samplePredSum += torch.sum(torch.argmax(pSample, dim=1))
                
                if graph_task:
                    for sub_problem_idx in range(len(current_problem.is_sat)):
                        # batch_mask needed to differentiate sub_problems in batch of problems for loss
                        # batch_edges cannot simply be divided since sub_problems have different number of edges/clauses
                        # IMPORTANT: batch_size and n variables dependant on data!!
                        batch_mask = utils.get_batch_mask(torch.tensor(current_problem.batch_edges), sub_problem_idx, opts['min_n'], current_problem.n_variables)
                        currLoss = mlp.loss(pOriginal[sub_problem_idx], pSample[sub_problem_idx], edge_ij[batch_mask], coefficient_size_reg, coefficient_entropy_reg)
                        sampleLoss.add_(currLoss)
                    

            loss += sampleLoss / sampled_graphs
            
            #sumSampledEdges += sampledEdges / sampled_graphs

        #print(samplePredSum)
        
        loss = loss / len(training_iterator)
        loss.backward()
        
        mlp_optimizer.step()

        print(f"Epoch {epoch+1}, Loss: {loss.item()}")

        torch.nn.utils.clip_grad_norm_(mlp.parameters(), max_norm=clip_grad_norm)
        
        """for param in mlp.parameters():
            if param.grad is not None:
                param.grad.data = torch.max(param.grad.data, min_clip_value * torch.ones_like(param.grad.data))"""

        mlp.eval()
        
        """if graph_task:
            #TODO: Evaluation for SAT! Needs gt
            meanAuc = evaluation.evaluateNeuroSATAUC(mlp, downstreamTask, data)"""
            
        # Get one sub problem for evaluation:
        preds = []
        highest_auc = 0
        lowest_auc = 1
        
        # Calculate weights and prediction for all sub_problems in eval_problem
        w_ij_eval = mlp.forward(downstreamTask, eval_problem, nodeToPred=node_to_predict)
        edge_ij_eval = mlp.sampleGraph(w_ij_eval, temperature).detach()
        #pSample_eval, _, _, _ = downstreamTask.forward(eval_problem, edge_weights=edge_ij_eval)
        #pOriginal_eval, _, _, _ = downstreamTask.forward(eval_problem)
        
        for current_batch_num in range(len(eval_problem.is_sat)):
            # This can be repeated for each sub_problem
            clauses_current_problem = eval_problem.get_clauses_for_problem(current_batch_num)

            #print(f"Length of sub_problem 0: {len(clauses_current_problem)}")
            # Sub problem 1 Contains literals 11-20, ...
            #print(f"Sub_problem 0 clauses: {clauses_current_problem}")
            
            # Calculate mask for current sub_problem in eval_problem
            eval_batch_mask = utils.get_batch_mask(torch.tensor(eval_problem.batch_edges), batch_idx=current_batch_num, batch_size=opts['min_n'], n_variables=eval_problem.n_variables)
            
            # Edge probabilites for current sub problem in eval_problems
            edge_ij_eval_masked = edge_ij_eval[eval_batch_mask]
                    
            sub_problem_edges = eval_problem.batch_edges[eval_batch_mask]
            
            motif_size = len(gt_mask)
            
            preds.append(edge_ij_eval_masked.cpu().flatten().numpy())
            
            if epoch == epochs-1:
                curr_roc_auc = roc_auc_score(reals[current_batch_num], preds[-1])
                if curr_roc_auc > highest_auc:
                    highest_auc = curr_roc_auc
                    highest_edge_ij = edge_ij_eval_masked
                    highest_gt = reals[current_batch_num]
                    sub_problem_edges_highest = sub_problem_edges
                    highest_index = current_batch_num
                if curr_roc_auc < lowest_auc:
                    lowest_auc = curr_roc_auc
                    lowest_edge_ij = edge_ij_eval_masked
                    lowest_gt = reals[current_batch_num]
                    sub_problem_edges_lowest = sub_problem_edges
                    lowest_index = current_batch_num
                
        
        
        all_preds = np.concatenate(preds)  # Flatten the list of arrays
    
        roc_auc = roc_auc_score(allReals, all_preds)
        
        print(f"Edge probabilites for last sub_problem in last problem batch: {edge_ij_eval_masked}")
        
        #print(f"Prediction for sub_problem: {pOriginal_eval[0]}")
        #print(f"Prediction for sampled sub_problem: {pSample_eval[0]}")
        
        print(f"roc_auc score: {roc_auc}")
    

        # TODO: VISUALIZE TOPK!
        # print sub_problem 0 with calculated edge weights
        if (epoch+1) % 5 == 0:
            pos = utils.visualize_edge_index_interactive(sub_problem_edges, edge_ij_eval_masked, f"results/replication/seed{runSeed}_vis_edge_ij_{epoch+1}", topK=len(gt_edges_per_problem[-1]))
        
        
        #sumSampledEdges = sumSampledEdges / len(training_iterator)
        #, "val/mean_AUC": meanAuc
        wandb.log({"train/Loss": loss, "val/temperature": temperature, "val/roc_auc": roc_auc})

        """for name, param in mlp.named_parameters():
            if param.requires_grad:
                print(f"{name}: {param.grad}")"""
        
    # print sub_problem 0 with calculated edge weights
    pos = utils.visualize_edge_index_interactive(sub_problem_edges, edge_ij_eval_masked, f"results/replication/seed{runSeed}_vis_edge_ij_{epoch+1}", topK=len(gt_edges_per_problem[-1]))
    # print sub_problem 0 with gt
    pos = utils.visualize_edge_index_interactive(sub_problem_edges, reals[-1], f"results/replication/seed{runSeed}_vis_gt", pos)
    
    # TODO: Show difference between topK edges and gt edges -> logical and on topK and gt, visualize
    
    sorted_weights, topk_indices_highest = torch.topk(highest_edge_ij, len(gt_edges_per_problem[highest_index]))
    mask_topK_highest = torch.zeros_like(highest_edge_ij, dtype=torch.float32)
    mask_topK_highest[topk_indices_highest] = 1
    
    common_edges_highest = np.logical_and(highest_gt, mask_topK_highest.flatten().numpy())
    print(f"Highest individual auc: {highest_auc}")
    pos = utils.visualize_edge_index_interactive(sub_problem_edges_highest, highest_edge_ij, f"results/replication/seed{runSeed}_highestAUC", topK=len(gt_edges_per_problem[highest_index]))
    pos = utils.visualize_edge_index_interactive(sub_problem_edges_highest, highest_gt, f"results/replication/seed{runSeed}_highestAUC_gt")
    pos = utils.visualize_edge_index_interactive(sub_problem_edges_highest, common_edges_highest, f"results/replication/seed{runSeed}_highestAUC_commonEdges")
    
    sorted_weights, topk_indices_lowest = torch.topk(lowest_edge_ij, len(gt_edges_per_problem[lowest_index]))
    mask_topK_lowest = torch.zeros_like(lowest_edge_ij, dtype=torch.float32)
    mask_topK_lowest[topk_indices_lowest] = 1
    
    common_edges_lowest = np.logical_and(lowest_gt, mask_topK_lowest.flatten().numpy())
    print(f"Lowest individual auc: {lowest_auc}")
    pos = utils.visualize_edge_index_interactive(sub_problem_edges_lowest, lowest_edge_ij, f"results/replication/seed{runSeed}_lowestAUC", topK=len(gt_edges_per_problem[lowest_index]))
    pos = utils.visualize_edge_index_interactive(sub_problem_edges_lowest, lowest_gt, f"results/replication/seed{runSeed}_lowestAUC_gt")
    pos = utils.visualize_edge_index_interactive(sub_problem_edges_lowest, common_edges_lowest, f"results/replication/seed{runSeed}_lowestAUC_commonEdges")
    
        
    #if save_model == "True":
    #    torch.save(mlp.state_dict(), f"models/explainer_{dataset}_{meanAuc}_{wandb.run.name}")

    wandb.finish()
    
    return mlp, downstreamTask

In [5]:
for i in range(10):
    mlp, downstreamTask = trainExplainer(datasetName,opts, wandb_project="NeuroSAT-seeded-train_val-fixed", runSeed=i)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mtristan-schulz2001[0m ([33mtristan-schulz2001-tu-dortmund[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch 1, Loss: 3.2246246337890625
Edge probabilites for last sub_problem in last problem batch: tensor([0.2640, 0.1898, 0.2090, 0.2527, 0.1919, 0.2398, 0.3254, 0.2421, 0.2595,
        0.1729, 0.2225, 0.2632, 0.2516, 0.2751, 0.2869, 0.2066, 0.2461, 0.1952,
        0.3255, 0.2609, 0.2858, 0.2116, 0.2179, 0.1837, 0.1982, 0.3132, 0.2313,
        0.2369, 0.2493, 0.2519, 0.2026, 0.2100, 0.2300, 0.2195, 0.3396, 0.3098,
        0.2232, 0.2665, 0.2732, 0.2510, 0.3348, 0.2610, 0.2523, 0.2388, 0.2823,
        0.2513, 0.2745, 0.3487, 0.2192, 0.2912, 0.3220, 0.1874, 0.2391, 0.2509,
        0.2603, 0.2082, 0.2814, 0.2155, 0.3437, 0.2862, 0.2337, 0.2469, 0.2093,
        0.2166, 0.2706, 0.2364, 0.2727, 0.2362, 0.1805, 0.2625, 0.2529, 0.2408,
        0.2854, 0.2570, 0.2789, 0.2098, 0.3145, 0.2531, 0.2424, 0.2043, 0.2636,
        0.2763, 0.1903, 0.3822, 0.1870, 0.2537, 0.2393, 0.2579, 0.1626, 0.2374,
        0.3444, 0.2381, 0.2129, 0.3147, 0.2670, 0.2661, 0.2567, 0.2322, 0.3736,
        0.2078, 0.2634, 

0,1
train/Loss,█▇▇▆▆▅▅▄▃▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/roc_auc,▁▄▅▆▆▆▇███████████████████████
val/temperature,██▇▇▆▆▆▅▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁

0,1
train/Loss,0.824
val/roc_auc,0.64932
val/temperature,1.0


Epoch 1, Loss: 3.347249984741211
Edge probabilites for last sub_problem in last problem batch: tensor([0.3886, 0.3684, 0.3915, 0.4131, 0.3562, 0.3705, 0.3938, 0.3844, 0.3757,
        0.3188, 0.3542, 0.3793, 0.4033, 0.4097, 0.4072, 0.3832, 0.3842, 0.3536,
        0.4024, 0.4014, 0.4039, 0.3851, 0.3782, 0.3358, 0.3703, 0.3879, 0.3641,
        0.3665, 0.3776, 0.3748, 0.3372, 0.3392, 0.3689, 0.3574, 0.3948, 0.3819,
        0.3578, 0.3751, 0.3882, 0.3844, 0.3921, 0.3876, 0.3527, 0.3740, 0.4108,
        0.3999, 0.4078, 0.3929, 0.3756, 0.4023, 0.3976, 0.3517, 0.3752, 0.3986,
        0.3746, 0.3840, 0.4042, 0.3738, 0.3923, 0.3980, 0.3748, 0.3911, 0.3860,
        0.3762, 0.4022, 0.3707, 0.4047, 0.3690, 0.3468, 0.3981, 0.3922, 0.3788,
        0.3993, 0.3860, 0.4120, 0.3697, 0.3891, 0.3883, 0.3719, 0.3756, 0.3664,
        0.3916, 0.3404, 0.4193, 0.3658, 0.4051, 0.3778, 0.3824, 0.3161, 0.3471,
        0.4019, 0.3722, 0.3491, 0.3853, 0.3953, 0.3906, 0.3623, 0.3629, 0.4163,
        0.3651, 0.4038, 0

0,1
train/Loss,██▇▇▆▆▅▅▄▃▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/roc_auc,▁▅▇██████▇▇▇▆▆▅▅▅▅▄▄▄▄▄▄▄▄▄▄▄▄
val/temperature,██▇▇▆▆▆▅▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁

0,1
train/Loss,0.82402
val/roc_auc,0.67095
val/temperature,1.0


Epoch 1, Loss: 3.248383045196533
Edge probabilites for last sub_problem in last problem batch: tensor([0.3513, 0.2995, 0.2282, 0.2648, 0.2177, 0.2980, 0.3518, 0.3166, 0.2708,
        0.2613, 0.3177, 0.3486, 0.2595, 0.3164, 0.3614, 0.2250, 0.3473, 0.2835,
        0.3252, 0.2480, 0.2972, 0.2148, 0.2306, 0.2715, 0.2067, 0.3111, 0.2738,
        0.3338, 0.3075, 0.3211, 0.2473, 0.3042, 0.3053, 0.2242, 0.3368, 0.3340,
        0.2882, 0.2814, 0.3344, 0.3270, 0.3671, 0.2846, 0.3254, 0.2368, 0.3694,
        0.3409, 0.3224, 0.4148, 0.2360, 0.3173, 0.3336, 0.2863, 0.2969, 0.2517,
        0.3396, 0.2228, 0.3805, 0.2337, 0.3970, 0.3037, 0.3062, 0.3811, 0.2237,
        0.2345, 0.3537, 0.3044, 0.3087, 0.3592, 0.2801, 0.3488, 0.3332, 0.3144,
        0.3144, 0.3585, 0.3742, 0.2263, 0.3159, 0.2411, 0.3366, 0.2110, 0.3297,
        0.2891, 0.2770, 0.4216, 0.2990, 0.2605, 0.3225, 0.3539, 0.2578, 0.3195,
        0.3516, 0.3191, 0.3198, 0.3484, 0.3005, 0.3450, 0.3355, 0.2909, 0.4049,
        0.2304, 0.3545, 0

0,1
train/Loss,██▇▆▆▅▅▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/roc_auc,▁▃▄▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇████████████
val/temperature,██▇▇▆▆▆▅▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁

0,1
train/Loss,0.82401
val/roc_auc,0.64713
val/temperature,1.0


Epoch 1, Loss: 3.1977319717407227
Edge probabilites for last sub_problem in last problem batch: tensor([0.2761, 0.2255, 0.2372, 0.2657, 0.2204, 0.2615, 0.2650, 0.2303, 0.2328,
        0.1984, 0.2027, 0.2706, 0.2617, 0.2781, 0.2977, 0.2341, 0.2250, 0.2089,
        0.2869, 0.2452, 0.2488, 0.2227, 0.2288, 0.2032, 0.2196, 0.2785, 0.1963,
        0.2124, 0.2726, 0.2619, 0.1796, 0.1963, 0.2217, 0.2379, 0.2760, 0.2547,
        0.2549, 0.2406, 0.2800, 0.2372, 0.2727, 0.2593, 0.2445, 0.2446, 0.2955,
        0.2923, 0.2757, 0.2941, 0.2378, 0.2518, 0.2872, 0.2156, 0.2074, 0.2535,
        0.2628, 0.2327, 0.2658, 0.2378, 0.2893, 0.2461, 0.2060, 0.2278, 0.2325,
        0.2370, 0.2485, 0.2048, 0.2700, 0.2187, 0.2123, 0.2448, 0.2868, 0.2065,
        0.2482, 0.2774, 0.2985, 0.2274, 0.2839, 0.2427, 0.2186, 0.2201, 0.2502,
        0.2453, 0.2057, 0.3142, 0.2219, 0.2614, 0.2090, 0.2714, 0.2020, 0.2376,
        0.2858, 0.2313, 0.2049, 0.2628, 0.2654, 0.2427, 0.2551, 0.1998, 0.3030,
        0.2350, 0.2457, 

0,1
train/Loss,██▇▆▆▅▄▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/roc_auc,▁▆███▇▇▆▆▅▅▄▄▄▄▃▃▃▃▃▃▃▃▃▃▃▃▃▂▂
val/temperature,██▇▇▆▆▆▅▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁

0,1
train/Loss,0.82397
val/roc_auc,0.66563
val/temperature,1.0


Epoch 1, Loss: 3.295910596847534
Edge probabilites for last sub_problem in last problem batch: tensor([0.3666, 0.3385, 0.3222, 0.3526, 0.3071, 0.3442, 0.4200, 0.4260, 0.3732,
        0.3104, 0.4164, 0.3648, 0.3503, 0.3917, 0.4020, 0.3199, 0.4345, 0.3222,
        0.4160, 0.3492, 0.3936, 0.3187, 0.3254, 0.3189, 0.3084, 0.4055, 0.3962,
        0.4284, 0.3563, 0.3502, 0.3697, 0.3980, 0.4105, 0.3170, 0.4165, 0.4099,
        0.3310, 0.3739, 0.3719, 0.4320, 0.4207, 0.3696, 0.3422, 0.3306, 0.4091,
        0.3811, 0.3960, 0.4396, 0.3296, 0.3961, 0.4097, 0.3262, 0.4000, 0.3440,
        0.3564, 0.3248, 0.4623, 0.3282, 0.4348, 0.3899, 0.4027, 0.4348, 0.3265,
        0.3296, 0.4637, 0.4083, 0.3934, 0.4286, 0.3294, 0.4552, 0.3773, 0.4094,
        0.3943, 0.3719, 0.4124, 0.3178, 0.4092, 0.3418, 0.4310, 0.3110, 0.3471,
        0.3870, 0.3176, 0.4528, 0.3351, 0.3524, 0.4094, 0.3709, 0.3095, 0.3340,
        0.4222, 0.4315, 0.4019, 0.4150, 0.3848, 0.4551, 0.3592, 0.4019, 0.4472,
        0.3239, 0.4598, 0

0,1
train/Loss,██▇▇▇▆▅▅▄▄▃▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/roc_auc,▁▄▆▇████████████▇▇▇▇▇▇▇▇▇▇▇▇▇▇
val/temperature,██▇▇▆▆▆▅▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁

0,1
train/Loss,0.8241
val/roc_auc,0.67508
val/temperature,1.0


Epoch 1, Loss: 3.3989856243133545
Edge probabilites for last sub_problem in last problem batch: tensor([0.4898, 0.4626, 0.4635, 0.5059, 0.4265, 0.4539, 0.3937, 0.4057, 0.4243,
        0.4129, 0.3910, 0.4827, 0.5005, 0.5029, 0.5003, 0.4606, 0.4169, 0.4365,
        0.4559, 0.4855, 0.4564, 0.4476, 0.4489, 0.4267, 0.4414, 0.4463, 0.4163,
        0.4048, 0.4691, 0.4464, 0.3885, 0.3805, 0.3942, 0.4530, 0.3763, 0.3807,
        0.4423, 0.4318, 0.4683, 0.4113, 0.3992, 0.4598, 0.4462, 0.4704, 0.5037,
        0.4987, 0.5149, 0.4397, 0.4565, 0.4739, 0.4667, 0.4453, 0.4339, 0.4950,
        0.4706, 0.4569, 0.4305, 0.4586, 0.4299, 0.4616, 0.4310, 0.4314, 0.4562,
        0.4571, 0.4421, 0.4340, 0.4946, 0.4157, 0.4396, 0.4395, 0.4888, 0.4419,
        0.4782, 0.4921, 0.5071, 0.4439, 0.4467, 0.4785, 0.4086, 0.4432, 0.4539,
        0.4465, 0.4310, 0.4450, 0.4522, 0.5007, 0.4497, 0.4821, 0.4139, 0.4362,
        0.3889, 0.3850, 0.3899, 0.3932, 0.4838, 0.4333, 0.4629, 0.4269, 0.4284,
        0.4492, 0.4407, 

0,1
train/Loss,██▇▇▆▆▅▄▄▃▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/roc_auc,▂██▆▅▄▄▃▃▃▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁
val/temperature,██▇▇▆▆▆▅▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁

0,1
train/Loss,0.82402
val/roc_auc,0.66737
val/temperature,1.0


Epoch 1, Loss: 3.389061689376831
Edge probabilites for last sub_problem in last problem batch: tensor([0.3767, 0.3428, 0.3349, 0.3313, 0.3717, 0.3722, 0.3953, 0.4866, 0.4157,
        0.3351, 0.4229, 0.3749, 0.3317, 0.3836, 0.3967, 0.3352, 0.4261, 0.3361,
        0.4163, 0.3223, 0.4172, 0.3254, 0.3718, 0.3408, 0.3312, 0.4182, 0.4546,
        0.4297, 0.3757, 0.3963, 0.4528, 0.4191, 0.4808, 0.3149, 0.4427, 0.3984,
        0.3720, 0.4176, 0.3990, 0.4873, 0.3941, 0.3700, 0.3751, 0.3216, 0.3974,
        0.3681, 0.3855, 0.4001, 0.3744, 0.4249, 0.4266, 0.3477, 0.4595, 0.3323,
        0.3806, 0.3328, 0.4723, 0.3778, 0.4009, 0.4259, 0.4624, 0.4435, 0.3304,
        0.3757, 0.5045, 0.4641, 0.3829, 0.4412, 0.3478, 0.5032, 0.3713, 0.4616,
        0.4302, 0.3807, 0.4027, 0.3733, 0.4133, 0.3235, 0.4248, 0.3277, 0.3734,
        0.4171, 0.3376, 0.4538, 0.3428, 0.3272, 0.4675, 0.3834, 0.3383, 0.3809,
        0.4545, 0.4653, 0.4308, 0.4039, 0.3841, 0.5004, 0.3803, 0.4625, 0.4579,
        0.3797, 0.5040, 0

0,1
train/Loss,██▇▇▆▆▅▄▄▃▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/roc_auc,▁▃▅▆▇▇▇███████████████████████
val/temperature,██▇▇▆▆▆▅▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁

0,1
train/Loss,0.82399
val/roc_auc,0.65014
val/temperature,1.0


Epoch 1, Loss: 3.2331371307373047
Edge probabilites for last sub_problem in last problem batch: tensor([0.2319, 0.1957, 0.2050, 0.2652, 0.2324, 0.2628, 0.3369, 0.2996, 0.3324,
        0.2011, 0.2649, 0.2320, 0.2613, 0.2845, 0.2580, 0.2034, 0.2615, 0.1967,
        0.3719, 0.2623, 0.3365, 0.2084, 0.2423, 0.1983, 0.2067, 0.3657, 0.2327,
        0.2594, 0.2550, 0.2734, 0.2406, 0.2685, 0.3034, 0.2656, 0.3578, 0.3339,
        0.2677, 0.3316, 0.2650, 0.2963, 0.3359, 0.2888, 0.2377, 0.2606, 0.2539,
        0.2500, 0.2811, 0.3360, 0.2416, 0.3314, 0.3665, 0.1948, 0.2291, 0.2596,
        0.2320, 0.2063, 0.3267, 0.2409, 0.3383, 0.3286, 0.2292, 0.2593, 0.2058,
        0.2420, 0.2882, 0.2250, 0.2798, 0.2570, 0.1928, 0.2929, 0.2481, 0.2252,
        0.3257, 0.2305, 0.2538, 0.2400, 0.3700, 0.2622, 0.2614, 0.2087, 0.2367,
        0.3365, 0.2001, 0.3709, 0.1957, 0.2647, 0.2249, 0.2295, 0.2040, 0.2463,
        0.3643, 0.3356, 0.2690, 0.3403, 0.2805, 0.2901, 0.2313, 0.2270, 0.3707,
        0.2360, 0.2960, 

0,1
train/Loss,██▇▇▆▅▅▄▄▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/roc_auc,▁▃▄▄▅▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇██████████
val/temperature,██▇▇▆▆▆▅▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁

0,1
train/Loss,0.82413
val/roc_auc,0.63385
val/temperature,1.0


Epoch 1, Loss: 3.313356399536133
Edge probabilites for last sub_problem in last problem batch: tensor([0.3335, 0.2993, 0.3172, 0.3478, 0.2968, 0.2727, 0.3735, 0.3414, 0.3552,
        0.2531, 0.3503, 0.3228, 0.3364, 0.3654, 0.3388, 0.3081, 0.3835, 0.2762,
        0.4185, 0.3172, 0.3859, 0.2964, 0.3175, 0.2680, 0.2883, 0.4079, 0.3429,
        0.3727, 0.2868, 0.2855, 0.3016, 0.3299, 0.3217, 0.2780, 0.3527, 0.3518,
        0.2587, 0.3629, 0.3063, 0.3512, 0.3832, 0.3230, 0.2917, 0.2982, 0.3455,
        0.3124, 0.3776, 0.4279, 0.3380, 0.4175, 0.4346, 0.2815, 0.3624, 0.3246,
        0.3147, 0.3213, 0.4300, 0.3448, 0.4141, 0.4016, 0.3591, 0.4075, 0.3195,
        0.3426, 0.3810, 0.3641, 0.3593, 0.3931, 0.2854, 0.3777, 0.3086, 0.3736,
        0.4167, 0.3370, 0.3477, 0.3078, 0.4007, 0.3063, 0.3718, 0.2866, 0.3017,
        0.3721, 0.2667, 0.4279, 0.2947, 0.3401, 0.3800, 0.3318, 0.2512, 0.2870,
        0.3674, 0.3606, 0.3464, 0.3667, 0.3462, 0.3722, 0.3106, 0.3542, 0.4131,
        0.3286, 0.3766, 0

0,1
train/Loss,██▇▇▆▆▅▅▄▄▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/roc_auc,▂▁▂▂▃▄▄▄▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇██████
val/temperature,██▇▇▆▆▆▅▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁

0,1
train/Loss,0.82414
val/roc_auc,0.64111
val/temperature,1.0


Epoch 1, Loss: 3.3611881732940674
Edge probabilites for last sub_problem in last problem batch: tensor([0.3989, 0.3869, 0.4771, 0.5211, 0.4282, 0.3727, 0.3728, 0.3516, 0.4714,
        0.3732, 0.3446, 0.3927, 0.5166, 0.4941, 0.3827, 0.4717, 0.3742, 0.3915,
        0.5192, 0.5078, 0.4858, 0.4690, 0.4519, 0.3819, 0.4601, 0.5155, 0.3507,
        0.3599, 0.3813, 0.3602, 0.3269, 0.3379, 0.3430, 0.4833, 0.3762, 0.3629,
        0.3684, 0.4745, 0.3649, 0.3534, 0.3779, 0.4707, 0.3790, 0.4940, 0.3795,
        0.3780, 0.4934, 0.4196, 0.4650, 0.5028, 0.5202, 0.3841, 0.3588, 0.5069,
        0.3896, 0.4786, 0.3976, 0.4680, 0.4113, 0.4937, 0.3578, 0.3867, 0.4792,
        0.4687, 0.3762, 0.3572, 0.4873, 0.3697, 0.3799, 0.3763, 0.3840, 0.3633,
        0.4921, 0.3896, 0.3813, 0.4425, 0.5166, 0.5001, 0.3634, 0.4629, 0.3914,
        0.4810, 0.3864, 0.4265, 0.3851, 0.5154, 0.3627, 0.3932, 0.3744, 0.3800,
        0.3867, 0.3531, 0.3412, 0.3730, 0.4784, 0.3676, 0.3820, 0.3476, 0.4158,
        0.4529, 0.3767, 

0,1
train/Loss,██▇▇▇▆▆▅▅▄▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/roc_auc,▁▅▇████████▇▇▇▇▆▆▆▆▆▆▆▅▅▅▅▅▅▅▅
val/temperature,██▇▇▆▆▆▅▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁

0,1
train/Loss,0.82422
val/roc_auc,0.67429
val/temperature,1.0


## Sweeping config

In [None]:
sweep_config = {
    "method": 'grid',                    # random, grid or Bayesian search
    "metric": {"goal": "maximize", "name": "val/mean_AUC"},
    "parameters": {
        'epochs': {
            'values': [30]
            },
        'tT': {
            'values': [1.0, 5.0]
            },
        'size_reg': {
            'values': [1.0, 0.1, 0.01]
            },
        'entropy_reg': {
            'values': [0.1, 1.0, 10.0]
            },
        'lr_mlp':{
            'values': [0.003, 0.001, 0.0003]
            },
        'seed':{
            'values': [74, 75]
            },
    },
}   

In [None]:
sweep_id = wandb.sweep(sweep_config, project="Sweep-NeuroSAT")

In [None]:
import sweepExplainerNeuroSAT

wandb.agent(sweep_id, sweepExplainerNeuroSAT.trainExplainer)

NameError: name 'wandb' is not defined