In [3]:
import NeuroSAT
import torch
import pickle
import explainer_NeuroSAT

from pysat.solvers import Solver

In [4]:
opts = {
        'out_dir': '/Users/trist/Documents/Bachelor-Thesis/NeuroSAT/test/files/data/dataset_train_10',
        'logging': '/Users/trist/Documents/Bachelor-Thesis/NeuroSAT/test/files/log/dataset_train_10.log',
        'n_pairs': 100,  # Anzahl der zu generierenden Paare
        'min_n': 8,
        'max_n': 8,
        'p_k_2': 0.3,
        'p_geo': 0.4,
        'max_nodes_per_batch': 4000,
        'one_pair': False,
        'emb_dim': 128,
        'iterations': 26,
    }

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
with open(opts['out_dir'], 'rb') as file:
    data = pickle.load(file)

In [None]:
from pysat.formula import WCNF
from pysat.formula import CNF
from pysat.examples.musx import MUSX
from pysat.examples.optux import OptUx
import utils

problemBatch = data[0]

current_batch_num = 5

clauses_problem_0 = problemBatch.get_clauses_for_problem(current_batch_num)

print(f"Length of sub_problem 1: {len(clauses_problem_0)}")
# Sub problem 1 Contains literals 11-20, ...
print(f"Sub_problem 1 clauses: {clauses_problem_0}")

solver = Solver(name='m22')

#for clause in clauses_problem_0:
#    solver.add_clause(clause)

# TODO: Change this according to length of literals
offset = problemBatch.n_literals + 1

# TODO: Assumptions must be unique! + 100 will not work in batch as there exists clause 100...
assumptions = [i + offset for i in range(len(clauses_problem_0))]

# Add the clauses with selector literals
for i, clause in enumerate(clauses_problem_0):
    solver.add_clause(clause + [-assumptions[i]])  # Each clause gets a unique assumption
    


    
is_sat = solver.solve(assumptions=assumptions)


# TODO: THIS UNSAT CORE REQUIRES AN ASSUMPTION!!! THEREFORE NOT SUITED FOR PGEXPLAINER?! -> Bypass with additional assumption per clause?
unsat_core = solver.get_core()

reversed_core = torch.tensor(unsat_core[::-1])

# TODO: looks like unsat_core is in reverse order, flip?
# Last element of unsat core is first clause in core_clauses

stats = solver.accum_stats()
    
# Map back to original clauses
core_clauses = [clauses_problem_0[i] for i in range(len(clauses_problem_0)) if assumptions[i] in unsat_core]
    
    
cnf = CNF()
for i, clause in enumerate(clauses_problem_0):
    print(clause)
    cnf.append(clause)

print("Computed MUS:")
#print(cnf.to_dimacs())
# Compute Minimally Unsatisfiable Subformulas instead of UNSAT cores? No assumptions needed?
musx = MUSX(cnf)
mus = musx.compute()
print("------------------------------------")
print(f"lenght of computed MUS: {len(mus)}")

"""with OptUx(cnf) as optux:
    for mus in optux.enumerate():
        print('mus {0} has cost {1}'.format(mus, optux.cost))"""



# TODO: USE MUS INSTEAD OF UNSAT CORE
# Computed mus contains the relative clause numbers for the current problem -> Need to be mapped to edges of original clause

eval_problem = data[0]
sub_problem_start_clause = 0

current_batch_num = 12

clauses_problem_i = problemBatch.get_clauses_for_problem(current_batch_num)

cnf = CNF()
for i, clause in enumerate(clauses_problem_i):
    #print(clause)
    cnf.append(clause)

musx = MUSX(cnf)
# mus contains list of relative clause numbers for the current problem e.g. [4, 7, ..., 53], starting at 1
mus = musx.compute()

# Map mus list to 
eval_batch_mask = utils.get_batch_mask(torch.tensor(eval_problem.batch_edges), batch_idx=current_batch_num, batch_size=opts['min_n'], n_variables=eval_problem.n_variables)

masked_batch_edges = eval_problem.batch_edges[eval_batch_mask]


print(mus)
print(len(torch.unique(torch.tensor(masked_batch_edges[:, 1]))))





solver.delete()
"""
print(is_sat)
#print(stats)
print(f"Computed unsat core length: {len(unsat_core)}")
print(f"Computed unsat core: {unsat_core}")
print("------------------------------------")
print(core_clauses)


utils.visualize_cnf_interactive(clauses_problem_0)



# This is shifted, literals are indexed 10-19 instead of 11-20!!!!!
# positive literals match! First clause (71/70 in batch_edges) is [17, 12, -14, 20]. In batched batch_edges: 16-70, 11-70, 19-70 matches 17, 12, 20
# negative ones do not match!!!!!
print(data[0].batch_edges)

mask = utils.get_batch_mask(torch.tensor(data[0].batch_edges), batch_idx=current_batch_num, batch_size=10, n_variables=problemBatch.n_variables)

masked_batch_edges = data[0].batch_edges[mask]
print(masked_batch_edges)

print(problemBatch.n_variables)
print(problemBatch.n_literals)




# TODO: Create gt in shape of mask with 0 if edge pair not in core, 1 if edge pair in core!
# -> First clause in sub_problem = 70: 17,12,-14,20 in indexes: 16,11,1413,19 has edges 16-17, 11-70, 1413-70 and 19-70
# -> First clause in gt/core = 100 (subtract 100 because of assumption -> 0, first clause(70)) with content [17, 12, -14, 20]. Convert to edge indexes and map!

unsat_core_edges = reversed_core - offset
# connect first element in unsat_core_edges to literals of first core_clauses

gt = []
for i, idx in enumerate(unsat_core_edges):
    literals = core_clauses[i]
    
    # TODO: Swap clauses_problem_0  with current problem?
    # This only works if all sub_problems have the same amount of clauses!! WRONG!
    clause = len(clauses_problem_0)*current_batch_num + unsat_core_edges[i]
    
    # Sum of problemBatch.n_clauses_per_batch[] before current_batch_num? Or count while calculating gt for data?
    #clause = problemBatch.n_clauses_per_batch[current_batch_num] + unsat_core_edges[i]
    
    for value in literals:
        value = value -1 if value >= 1 else problemBatch.n_variables - (value + 1)
        gt.append([value, clause])
        
motif_size = len(gt)

print(torch.tensor(gt))

print(problemBatch.n_clauses_per_batch)

empty_gt = torch.zeros(len(masked_batch_edges))

print(empty_gt.shape)

test = torch.isin(torch.tensor(masked_batch_edges), torch.tensor(gt))

# We only need the right column of the isin tensor
clausesTest = test[:,1].int()
print(clausesTest)


print(problemBatch.n_variables)


gt_mask = []
for i, content in enumerate(masked_batch_edges):
    if masked_batch_edges[i] in unsat_core_edes:
        gt_mask.append(1)
    else:
        gt_mask.append(0)
        
test2 = torch.isin(torch.tensor(data[0].batch_edges), torch.tensor(gt))

# We only need the right column of the isin tensor
allClausesTestgt = test2[:,1].int()
        
        
pos = utils.visualize_edge_index_interactive(masked_batch_edges, clausesTest)
pos2 = utils.visualize_edge_index_interactive(masked_batch_edges, clausesTest, "test", pos, motif_size)"""

Length of sub_problem 1: 53
Sub_problem 1 clauses: [[-42, 47, -48, 46], [42, -47, 46, 43, -45], [-44, 42, -47, -41], [41, 42, 45, 48], [47, 43, -44, 42, 41], [48, -44, 46, 41], [-48, 42, 47], [-44, 41, -43], [42, 44, -48, -46, 43, 41, 47], [-47, -44, 45], [46, -42, 48, 43], [46, -43, 42, -48], [-44, -46, 41, 43, -42], [42, -44, -41, -47, -45], [46, -45, -41, -48], [-42, -46, 47], [-48, -45], [-46, -45, 44], [46, -48, -42, -43, 45, -44, 47, 41], [44, 48, 41], [47, 41, -46, -43, 42], [44, 46, 43, 47], [-46, -47, -45, 44, 42], [-41, -46, 48, -43], [48, -45, 44, -41, -47], [43, 44, -46, -42], [-47, -43, -46, -42, -44], [44, 46, 47], [-41, -42, 44, 45, 46, 43], [44, -41, -43, -47], [47, 42, -41], [-41, 48, 44], [-42, -46], [-44, 45, -43], [-44, -41], [-43, -44], [41, 42, 48, 43, 45, 47, -46, 44], [46, -42, -44, -41, 47, 48], [48, 44, -42, -45, -46, -41, -43], [-45, 41, 46], [43, 42, -41, 47], [-44, 46], [-43, 48, -45, 44, -47], [43, -45, 47, -46, -42], [47, 45, 41, 43, 48, -44, 42, 46], [46

'\nprint(is_sat)\n#print(stats)\nprint(f"Computed unsat core length: {len(unsat_core)}")\nprint(f"Computed unsat core: {unsat_core}")\nprint("------------------------------------")\nprint(core_clauses)\n\n\nutils.visualize_cnf_interactive(clauses_problem_0)\n\n\n\n# This is shifted, literals are indexed 10-19 instead of 11-20!!!!!\n# positive literals match! First clause (71/70 in batch_edges) is [17, 12, -14, 20]. In batched batch_edges: 16-70, 11-70, 19-70 matches 17, 12, 20\n# negative ones do not match!!!!!\nprint(data[0].batch_edges)\n\nmask = utils.get_batch_mask(torch.tensor(data[0].batch_edges), batch_idx=current_batch_num, batch_size=10, n_variables=problemBatch.n_variables)\n\nmasked_batch_edges = data[0].batch_edges[mask]\nprint(masked_batch_edges)\n\nprint(problemBatch.n_variables)\nprint(problemBatch.n_literals)\n\n\n\n\n# TODO: Create gt in shape of mask with 0 if edge pair not in core, 1 if edge pair in core!\n# -> First clause in sub_problem = 70: 17,12,-14,20 in indexe

In [16]:
print(data[0].batch_edges.shape)

print(data[0].n_variables)

batch = []
currentBatch = 0




problem = data[0]
n_variables = problem.n_variables

batch_edges = torch.tensor(problem.batch_edges)           # Shape: (num_edges, 2)

batch_literals = torch.cat([
    torch.arange(0, 40),                   # Positive literals
    torch.arange(n_variables, n_variables + 40)  # Negative literals
])

batch_mask = torch.isin(batch_edges[:, 0], batch_literals)

# Apply the mask
batch_edges_filtered = batch_edges[batch_mask]


print(batch_mask)



"""for i in data[0].batch_edges:
    if i[0] in []"""

(36846, 2)
1440
tensor([ True,  True,  True,  ..., False, False, False])


'for i in data[0].batch_edges:\n    if i[0] in []'

In [32]:
# Example usage:
batch_idx = 2  # Get mask for the 3rd batch (0-based index)
batch_mask = get_batch_mask(torch.tensor(problem.batch_edges), batch_idx)

# Apply the mask to filter edges
batch_edges_filtered = problem.batch_edges[batch_mask]

print(problem.batch_edges[:,0].shape)
print(batch_mask.shape)
#print(torch.tensor(batch_edges_filtered))

(36846,)
torch.Size([36846])


In [None]:
print(len(data[0].is_sat))
print(data[0].clauses)

36
[[-1, -14, 31], [-35, 19, -3, 37], [19, 22, -6, -2, 38, -31, -37, -32], [7, 28, 32], [2, -29], [-14, 4, 2, -31], [19, 17, 13, 10, 15], [-21, 10], [-35, -27, 17, 12, -6], [-35, 26], [27, -30, -15], [26, 10, -15, 18, 30, 13], [-30, -7, -1, 3, 28], [5, 1, 39, -14, 40], [35, 18, -6], [-31, -28, 9, -2, -20, 35, -14, 10, 5], [4, -40, 27], [19, 10], [35, 20, -10, 14, -18, 39, -30, 1, -27], [-19, 38, 31, -20], [20, -22, -23, -7], [-14, 4, 22], [14, 19, 17, 15, 21], [-30, -26, 38], [21, 24, -23], [-22, 8, 20, -3, -31, 36, -29, -34, 19], [-30, 5, -37, 13, -12, -11, 40, 33, 19, 10], [-34, -40], [33, 29, -40, 28, -10, -14, 35, -32], [20, 12, 38, -13, 31], [19, 7, -12, 13, -34, 8], [1, 19], [7, 5, -39, 35, 23, 11, -27], [5, 8, 31, -17, -39, -25], [34, -32, -21, 4, -24, 35, 19, -5, 8, -28], [27, 15, -25, -23, -16, -38, -13], [37, -33, 15, -18, 28, -38], [-4, 34, -7, -40], [14, -33], [-6, 2, -19, -16], [32, -34, -33, -9], [-20, -31, 16, 18, 34], [9, 20, 4], [-18, -9, -30], [11, 24, 34], [23, 3, -2

In [4]:
vote_mean, iteration_votes_sorted, all_literal_embeddings, all_clause_embeddings = downstreamTask.forward(data[0])

print(vote_mean)

print(all_literal_embeddings[-1])

NameError: name 'downstreamTask' is not defined

In [3]:
datasetName="NeuroSAT"

In [6]:
from torch_geometric import seed
import utils
import wandb
import torch.nn.functional as fn
from sklearn.metrics import roc_auc_score
import numpy as np


def trainExplainer (datasetName, opts, save_model=False, wandb_project="Explainer-NeuroSAT", runSeed=None) :
    if runSeed is not None: seed.seed_everything(runSeed)
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Check valid dataset name
    configOG = utils.loadConfig(datasetName)
    if configOG == -1:
        return
    
    params = configOG['params']
    graph_task = params['graph_task']
    epochs = params['epochs']
    t0 = params['t0']
    tT = params['tT']
    sampled_graphs = params['sampled_graphs']
    coefficient_size_reg = params['coefficient_size_reg']
    coefficient_entropy_reg = params['coefficient_entropy_reg']
    coefficient_L2_reg = params['coefficient_L2_reg']
    coefficient_consistency = params['coefficient_consistency']
    num_explanation_edges = params['num_explanation_edges']
    lr_mlp = params['lr_mlp']

    wandb.init(project=wandb_project, config=params)

    hidden_dim = 64 # Make loading possible
    clip_grad_norm = 2 # Make loading possible
    min_clip_value = -2
    
    
    with open(opts['out_dir'], 'rb') as file:
        data = pickle.load(file)
    
    # TODO: Split data into train and test
    # TODO: !!! each data consist of multiple problems !!! -> Extract singular problems for calculating the loss!
    dataset = data
    
    eval_problem = data[-1]
    
    # gt only needs to be calced once, not per epoch!
    countClauses = 0
    reals = []
    gt_edges_per_problem = []
    for current_batch_num in range(len(eval_problem.is_sat)):
            # This can be repeated for each sub_problem
            clauses_current_problem = eval_problem.get_clauses_for_problem(current_batch_num)

            #print(f"Length of sub_problem 0: {len(clauses_current_problem)}")
            # Sub problem 1 Contains literals 11-20, ...
            #print(f"Sub_problem 0 clauses: {clauses_current_problem}")
            
            # Next part is only for calculating unsat_core -> Move to creation of data and save?
            solver = Solver(name='m22')

            offset = eval_problem.n_literals + 1

            # Assumptions must be unique, therefore eval_problem.n_literals + 1
            assumptions = [i + offset for i in range(len(clauses_current_problem))]

            # Add the clauses with selector literals
            for i, clause in enumerate(clauses_current_problem):
                solver.add_clause(clause + [-assumptions[i]])  # Each clause gets a unique assumption
                
            # is_sat not needed right now as all problems should be unsat
            is_sat = solver.solve(assumptions=assumptions)
            
            unsat_core = solver.get_core()

            # Core contains clauses in reverse order, so we reverse it back
            reversed_core = torch.tensor(unsat_core[::-1])
            # Subtract offset from clauses in core to get original clauses
            unsat_core_clauses = reversed_core - offset
                
            # Map back to original clauses
            core_clause_literals = [clauses_current_problem[i] for i in range(len(clauses_current_problem)) if assumptions[i] in unsat_core]

            solver.delete()
            
            # Calculate mask for current sub_problem in eval_problem
            eval_batch_mask = utils.get_batch_mask(torch.tensor(eval_problem.batch_edges), batch_idx=current_batch_num, batch_size=opts['min_n'], n_variables=eval_problem.n_variables)
            
            
            gt_mask = []
            # TODO: THIS LOOKS WRONG
            for i, idx in enumerate(unsat_core_clauses):
                literals = core_clause_literals[i]
                
                # TODO: This only works if all sub_problems have the same amount of clauses!! WRONG!
                # clause = len(clauses_current_problem)*current_batch_num + unsat_core_clauses[i]
                clause = countClauses + unsat_core_clauses[i]
                
                # Sum of problemBatch.n_clauses_per_batch[] before current_batch_num? Or count while calculating gt for data?
                #clause = problemBatch.n_clauses_per_batch[current_batch_num] + unsat_core_clauses[i]
                
                for value in literals:
                    value = value -1 if value >= 1 else eval_problem.n_variables - (value + 1)
                    gt_mask.append([value, clause])
            
            countClauses = countClauses + len(clauses_current_problem)
                    
            sub_problem_edges = eval_problem.batch_edges[eval_batch_mask]
            
            # TODO: VALIDATE THIS!!!
            gt = torch.isin(torch.tensor(sub_problem_edges), torch.tensor(gt_mask))
            # We only need the right column of the isin tensor
            gt = gt[:,1].int()
            
            reals.append(gt.flatten().numpy())
            gt_edges_per_problem.append(gt_mask)
            #reals_unflattened.append(gt)
    
    allReals = np.concatenate(reals)  # Flatten the list of arrays
    
        

    downstreamTask = NeuroSAT.NeuroSAT(opts=opts,device=device)
    checkpoint = torch.load(f"models/neurosat_sr10to40_ep1024_nr26_d128_last.pth.tar", weights_only=True, map_location=device)
    downstreamTask.load_state_dict(checkpoint['state_dict'])

    mlp = explainer_NeuroSAT.MLP(GraphTask=graph_task).to(device)
    wandb.watch(mlp, log= "all", log_freq=2, log_graph=False)

    mlp_optimizer = torch.optim.Adam(params = mlp.parameters(), lr = lr_mlp)

    downstreamTask.eval()
    for param in downstreamTask.parameters():
        param.requires_grad = False


    training_iterator = dataset
    
    for epoch in range(0, epochs) :
        mlp.train()
        mlp_optimizer.zero_grad()

        temperature = t0*((tT/t0) ** ((epoch+1)/epochs))
        
        #sampledEdges = 0.0
        #sumSampledEdges = 0.0
        
        #samplePredSum = 0

        for index, content in enumerate(training_iterator):
            # stop training before last batch, used for evaluation
            if index == len(training_iterator)-2: break
            node_to_predict = None
            if graph_task: 
                # !! current_problem is really a batch of problems !!
                current_problem = content

            # MLP forward
            # TODO: Implement embeddingCalculation for SAT
            w_ij = mlp.forward(downstreamTask, current_problem, nodeToPred=node_to_predict)

            sampleLoss = torch.FloatTensor([0]).to(device)
            loss = torch.FloatTensor([0]).to(device)
            
            pOriginal, _, _, _ = downstreamTask.forward(current_problem)
            pOriginal = fn.softmax(pOriginal, dim=0)
            
            for k in range(0, sampled_graphs):
                edge_ij = mlp.sampleGraph(w_ij, temperature)
                
                #sampledEdges += torch.sum(edge_ij)
            
                # TODO: softmax needed for loss, beacuse negative values do not work witg log! Need for normalization?
                pSample, _, _, _ = downstreamTask.forward(current_problem, edge_weights=edge_ij)
                pSample = fn.softmax(pSample, dim=0)

                #samplePredSum += torch.sum(torch.argmax(pSample, dim=1))
                
                if graph_task:
                    for sub_problem_idx in range(len(current_problem.is_sat)):
                        # batch_mask needed to differentiate sub_problems in batch of problems for loss
                        # batch_edges cannot simply be divided since sub_problems have different number of edges/clauses
                        # IMPORTANT: batch_size and n variables dependant on data!!
                        batch_mask = utils.get_batch_mask(torch.tensor(current_problem.batch_edges), sub_problem_idx, opts['min_n'], current_problem.n_variables)
                        currLoss = mlp.loss(pOriginal[sub_problem_idx], pSample[sub_problem_idx], edge_ij[batch_mask], current_problem.batch_edges[batch_mask], coefficient_size_reg, coefficient_entropy_reg, coefficient_consistency)
                        sampleLoss.add_(currLoss)
                    

            loss += sampleLoss / sampled_graphs
            
            #sumSampledEdges += sampledEdges / sampled_graphs

        #print(samplePredSum)
        
        loss = loss / len(training_iterator)
        loss.backward()
        
        mlp_optimizer.step()

        print(f"Epoch {epoch+1}, Loss: {loss.item()}")

        torch.nn.utils.clip_grad_norm_(mlp.parameters(), max_norm=clip_grad_norm)
        
        """for param in mlp.parameters():
            if param.grad is not None:
                param.grad.data = torch.max(param.grad.data, min_clip_value * torch.ones_like(param.grad.data))"""

        mlp.eval()
        
        """if graph_task:
            #TODO: Evaluation for SAT! Needs gt
            meanAuc = evaluation.evaluateNeuroSATAUC(mlp, downstreamTask, data)"""
            
        # Get one sub problem for evaluation:
        preds = []
        highest_auc = 0
        lowest_auc = 1
        
        # FIXME: THE EMBEDDINGS CALCULATED IN THIS FORWARD PATH ARE QUITE SMALL, BUT FIXED (GOOD)
        
        # Calculate weights and prediction for all sub_problems in eval_problem
        w_ij_eval = mlp.forward(downstreamTask, eval_problem, nodeToPred=node_to_predict)
        edge_ij_eval = mlp.sampleGraph(w_ij_eval, temperature).detach()
        #pSample_eval, _, _, _ = downstreamTask.forward(eval_problem, edge_weights=edge_ij_eval)
        #pOriginal_eval, _, _, _ = downstreamTask.forward(eval_problem)
        
        for current_batch_num in range(len(eval_problem.is_sat)):
            # This can be repeated for each sub_problem
            clauses_current_problem = eval_problem.get_clauses_for_problem(current_batch_num)

            #print(f"Length of sub_problem 0: {len(clauses_current_problem)}")
            # Sub problem 1 Contains literals 11-20, ...
            #print(f"Sub_problem 0 clauses: {clauses_current_problem}")
            
            # Calculate mask for current sub_problem in eval_problem
            eval_batch_mask = utils.get_batch_mask(torch.tensor(eval_problem.batch_edges), batch_idx=current_batch_num, batch_size=opts['min_n'], n_variables=eval_problem.n_variables)
            
            # Edge probabilites for current sub problem in eval_problems
            edge_ij_eval_masked = edge_ij_eval[eval_batch_mask]
            
            sub_problem_edges = eval_problem.batch_edges[eval_batch_mask]
            
            motif_size = len(gt_mask)
            
            preds.append(edge_ij_eval_masked.cpu().flatten().numpy())
            
            if epoch == epochs-1:
                curr_roc_auc = roc_auc_score(reals[current_batch_num], preds[-1])
                if curr_roc_auc > highest_auc:
                    highest_auc = curr_roc_auc
                    highest_edge_ij = edge_ij_eval_masked
                    highest_gt = reals[current_batch_num]
                    sub_problem_edges_highest = sub_problem_edges
                    highest_index = current_batch_num
                if curr_roc_auc < lowest_auc:
                    lowest_auc = curr_roc_auc
                    lowest_edge_ij = edge_ij_eval_masked
                    lowest_gt = reals[current_batch_num]
                    sub_problem_edges_lowest = sub_problem_edges
                    lowest_index = current_batch_num
                
        
        
        all_preds = np.concatenate(preds)  # Flatten the list of arrays
    
        roc_auc = roc_auc_score(allReals, all_preds)
        
        weights_eval_masked = w_ij_eval[eval_batch_mask]
        print(f"Edge weights for last sub_problem in last problem batch: {weights_eval_masked}")
        
        print(f"Edge probabilites for last sub_problem in last problem batch: {edge_ij_eval_masked}")
        
        #print(f"Prediction for sub_problem: {pOriginal_eval[0]}")
        #print(f"Prediction for sampled sub_problem: {pSample_eval[0]}")
        
        print(f"roc_auc score: {roc_auc}")
    

        # TODO: VISUALIZE TOPK!
        # print sub_problem 0 with calculated edge weights
        if (epoch+1) % 5 == 0:
            pos = utils.visualize_edge_index_interactive(sub_problem_edges, edge_ij_eval_masked, f"results/replication/seed{runSeed}_vis_edge_ij_{epoch+1}", topK=len(gt_edges_per_problem[-1]))
        
        
        #sumSampledEdges = sumSampledEdges / len(training_iterator)
        #, "val/mean_AUC": meanAuc
        wandb.log({"train/Loss": loss, "val/temperature": temperature, "val/roc_auc": roc_auc})

        """for name, param in mlp.named_parameters():
            if param.requires_grad:
                print(f"{name}: {param.grad}")"""
        
    # print sub_problem 0 with calculated edge weights
    pos = utils.visualize_edge_index_interactive(sub_problem_edges, edge_ij_eval_masked, f"results/replication/seed{runSeed}_vis_edge_ij_{epoch+1}", topK=len(gt_edges_per_problem[-1]))
    # print sub_problem 0 with gt
    pos = utils.visualize_edge_index_interactive(sub_problem_edges, reals[-1], f"results/replication/seed{runSeed}_vis_gt", pos)
    
    # TODO: Show difference between topK edges and gt edges -> logical and on topK and gt, visualize
    
    sorted_weights, topk_indices_highest = torch.topk(highest_edge_ij, len(gt_edges_per_problem[highest_index]))
    mask_topK_highest = torch.zeros_like(highest_edge_ij, dtype=torch.float32)
    mask_topK_highest[topk_indices_highest] = 1
    
    common_edges_highest = np.logical_and(highest_gt, mask_topK_highest.flatten().numpy())
    print(f"Highest individual auc: {highest_auc}")
    pos = utils.visualize_edge_index_interactive(sub_problem_edges_highest, highest_edge_ij, f"results/replication/seed{runSeed}_highestAUC", topK=len(gt_edges_per_problem[highest_index]))
    pos = utils.visualize_edge_index_interactive(sub_problem_edges_highest, highest_gt, f"results/replication/seed{runSeed}_highestAUC_gt")
    pos = utils.visualize_edge_index_interactive(sub_problem_edges_highest, common_edges_highest, f"results/replication/seed{runSeed}_highestAUC_commonEdges")
    
    # TO EVALUATE SATISFIABILITY OF EXPLANATION:
    print(f"Highest AUC example masked edges: {sub_problem_edges_highest[mask_topK_highest.bool()]}")
    
    
    sorted_weights, topk_indices_lowest = torch.topk(lowest_edge_ij, len(gt_edges_per_problem[lowest_index]))
    mask_topK_lowest = torch.zeros_like(lowest_edge_ij, dtype=torch.float32)
    mask_topK_lowest[topk_indices_lowest] = 1
    
    common_edges_lowest = np.logical_and(lowest_gt, mask_topK_lowest.flatten().numpy())
    print(f"Lowest individual auc: {lowest_auc}")
    pos = utils.visualize_edge_index_interactive(sub_problem_edges_lowest, lowest_edge_ij, f"results/replication/seed{runSeed}_lowestAUC", topK=len(gt_edges_per_problem[lowest_index]))
    pos = utils.visualize_edge_index_interactive(sub_problem_edges_lowest, lowest_gt, f"results/replication/seed{runSeed}_lowestAUC_gt")
    pos = utils.visualize_edge_index_interactive(sub_problem_edges_lowest, common_edges_lowest, f"results/replication/seed{runSeed}_lowestAUC_commonEdges")
    
        
    #if save_model == "True":
    #    torch.save(mlp.state_dict(), f"models/explainer_{dataset}_{meanAuc}_{wandb.run.name}")

    wandb.finish()
    
    return mlp, downstreamTask

In [7]:
for i in range(1):
    mlp, downstreamTask = trainExplainer(datasetName,opts, wandb_project="NeuroSAT-seeded-train_val-threeEmbeddings-consistencyReg-lowerLR", runSeed=i)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mtristan-schulz2001[0m ([33mtristan-schulz2001-tu-dortmund[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Max input embedding for MLP: 0.982603907585144
Min input embedding for MLP: -0.9755895137786865
------------------------------------------
Max input embedding for MLP: 0.9890250563621521
Min input embedding for MLP: -0.9954797029495239
------------------------------------------
Epoch 1, Loss: 3.2832558155059814
Max input embedding for MLP: 0.9771515727043152
Min input embedding for MLP: -0.9932307600975037
------------------------------------------
Edge weights for last sub_problem in last problem batch: tensor([-0.4973, -0.5713, -0.4564, -0.2737, -0.6204, -0.7220, -0.2868, -0.3077,
        -0.3309, -0.7305, -0.3926, -0.5093, -0.2476, -0.1903, -0.4388, -0.4245,
        -0.3969, -0.6623, -0.2749, -0.3379, -0.3019, -0.5271, -0.5577, -0.6581,
        -0.5361, -0.2593, -0.3520, -0.3576, -0.6619, -0.5547, -0.3339, -0.3722,
        -0.3000, -0.4140, -0.2353, -0.2757, -0.7112, -0.3016, -0.6002, -0.3064,
        -0.2776, -0.3068, -0.6697, -0.3859, -0.4220, -0.5506, -0.2283, -0.3198,
        -0

0,1
train/Loss,███▇▇▇▇▇▆▆▆▅▅▅▄▄▄▄▃▃▂▂▂▂▁▁▁▁▁▁
val/roc_auc,▁▃▅▆▇▇██████████████████████▇▇
val/temperature,██▇▇▆▆▆▅▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁

0,1
train/Loss,0.93089
val/roc_auc,0.70173
val/temperature,1.0


## Sweeping config

In [None]:
sweep_config = {
    "method": 'grid',                    # random, grid or Bayesian search
    "metric": {"goal": "maximize", "name": "val/mean_AUC"},
    "parameters": {
        'epochs': {
            'values': [30]
            },
        'tT': {
            'values': [1.0, 5.0]
            },
        'size_reg': {
            'values': [1.0, 0.1, 0.01]
            },
        'entropy_reg': {
            'values': [0.1, 1.0, 10.0]
            },
        'lr_mlp':{
            'values': [0.003, 0.001, 0.0003]
            },
        'seed':{
            'values': [74, 75]
            },
    },
}   

In [None]:
sweep_id = wandb.sweep(sweep_config, project="Sweep-NeuroSAT")

In [None]:
import sweepExplainerNeuroSAT

wandb.agent(sweep_id, sweepExplainerNeuroSAT.trainExplainer)

NameError: name 'wandb' is not defined