In [31]:
import numpy as np

import torch
from torch_geometric.data import Dataset, Data

from dig.xgraph.evaluation import XCollector

import import_ipynb
from influenceDataset import get_dataloader, influenceDataset
from simulation import simul_fromData

In [39]:
def greedy(data, sparsity):
    """Greedily delete the edges whose removal most lowers the simulated total.

    Each round tentatively masks out every still-present edge, re-runs
    ``simul_fromData`` and permanently removes the edge giving the smallest
    ``simul_fromData(...).sum()``. ``int(m * sparsity)`` rounds are run (capped
    at ``m``), where ``m`` is the number of edges, so the cost is
    O(rounds * m) simulations.

    Args:
        data: graph object with ``x`` (one row per node; its sum is reported as
            the ``'zero'`` baseline, presumably because it marks seed nodes --
            confirm) and ``edge_attr`` (one row per edge).
        sparsity: fraction of edges to remove; values above 1 remove every edge.

    Returns:
        ``(None, hard_mask, related_preds)``. ``hard_mask`` is a list with one
        entry per edge (1 = kept, 0 = removed). ``related_preds`` maps
        ``'zero'`` -> ``data.x.sum()``, ``'masked'`` -> simulated total under
        ``hard_mask``, ``'maskout'`` -> simulated total under the complementary
        mask, and ``'origin'`` -> simulated total with every edge kept.
    """
    n_edges = len(data.edge_attr)
    # The mask is indexed by edge id, so it needs one slot per edge (not per node).
    hard_mask = [1] * n_edges
    live_edges = set(range(n_edges))

    def spread(mask):
        # Single summation path so every total is computed the same way.
        return simul_fromData(data, mask).sum()

    def spread_without(edge):
        # Temporarily drop `edge`, simulate, then restore it.
        hard_mask[edge] = 0
        try:
            return spread(hard_mask)
        finally:
            hard_mask[edge] = 1

    zero = data.x.sum()
    origin = spread(hard_mask)

    n_remove = min(int(n_edges * sparsity), n_edges)
    for _ in range(n_remove):
        # min() always returns a live edge, even when no removal lowers the
        # total; ties are broken in favour of the lowest edge id.
        best = min(sorted(live_edges), key=spread_without)
        hard_mask[best] = 0
        live_edges.remove(best)

    hard_mask_out = [1 - kept for kept in hard_mask]
    related_preds = {
        'zero': zero,
        'masked': spread(hard_mask),
        'maskout': spread(hard_mask_out),
        'origin': origin,
    }
    return None, hard_mask, related_preds
        

            
def pipeline(root='..', graph_dir='../graphs', sparsity=0.0000001, max_graphs=1):
    """Run ``greedy`` over ``influenceDataset`` graphs and print XCollector metrics.

    The defaults reproduce the original debugging run: only the first graph is
    processed, and with ``sparsity=1e-7`` ``greedy`` removes ``int(m * 1e-7)``
    edges, i.e. none for graphs with fewer than 10**7 edges.

    Args:
        root: first positional argument forwarded to ``influenceDataset``.
        graph_dir: second positional argument forwarded to ``influenceDataset``.
        sparsity: fraction of edges ``greedy`` removes; also handed to
            ``XCollector`` so its ``sparsity`` property is populated.
        max_graphs: stop after this many graphs; ``None`` processes all of them.
    """
    from itertools import islice

    dataset = influenceDataset(root, graph_dir)

    # Without an explicit sparsity, XCollector.sparsity can be None and the
    # ':.4f' format below would raise.
    x_collector = XCollector(sparsity=sparsity)
    # islice stops without fetching the graph after the last one processed.
    for i, data in enumerate(islice(dataset, max_graphs)):
        print(i)
        _, hard_mask, related_preds = greedy(data, sparsity=sparsity)
        # collect_data takes per-class lists and reads related_preds[label];
        # a bare dict would raise KeyError(0).
        x_collector.collect_data([hard_mask], [related_preds], label=0)

    print(f'Fidelity: {x_collector.fidelity:.4f}\n'
          f'Fidelity_inv: {x_collector.fidelity_inv:.4f}\n'
          f'Sparsity: {x_collector.sparsity:.4f}')

In [40]:
# NOTE(review): the "4:34" comment below is unexplained; it is presumably the elapsed
# time of the last run, which the captured output shows ended in KeyboardInterrupt -- confirm.
pipeline() # 4:34

0


KeyboardInterrupt: 

In [27]:
# Disabled toy sanity check for `greedy` on a 5-node / 5-edge graph: the code is
# wrapped in a string literal, so running this cell executes nothing.
# The captured output below appears to come from an earlier, un-disabled run. It also
# contains extra mask lines printed before each index that this snippet does not produce.
# NOTE(review): data11 and data12 are built from identical inputs (both edge_attr1),
# yet that output reports slightly different 'origin'/'masked' values for them. This
# suggests simul_fromData is stochastic and the duplicate is deliberate -- confirm.
"""
x = torch.tensor([[1], [0], [1], [1], [0]], dtype=torch.float)
edge_index = torch.tensor([[0, 1, 2, 3, 3], [1, 2, 1, 1, 4]], dtype=torch.long)
edge_attr1 = torch.tensor([[0], [0.8], [1], [0.5], [0.1]], dtype=torch.float)
edge_attr2 = torch.tensor([[1], [0.2], [1], [0.8], [1]], dtype=torch.float)
data11 = Data(x=x, edge_index=edge_index, edge_attr=edge_attr1)
data12 = Data(x=x, edge_index=edge_index, edge_attr=edge_attr1)
data21 = Data(x=x, edge_index=edge_index, edge_attr=edge_attr2)

for i, data in enumerate([data11,data12,data21]):
    _, hard_mask, related_preds = greedy(data, sparsity=0.5)
    
    print(i)
    print(related_preds)
    print(hard_mask)
    print()
"""

[1, 1, 1, 1, 1]
[1, 1, 0, 1, 1]
0
{'zero': tensor(3.), 'masked': 3.0997, 'maskout': 4.0, 'origin': 4.1007}
[1, 1, 0, 0, 1]

[1, 1, 1, 1, 1]
[1, 1, 0, 1, 1]
1
{'zero': tensor(3.), 'masked': 3.0933, 'maskout': 4.0, 'origin': 4.0969}
[1, 1, 0, 0, 1]

[1, 1, 1, 1, 1]
[1, 1, 1, 1, 0]
2
{'zero': tensor(3.), 'masked': 4.0, 'maskout': 5.0, 'origin': 5.0}
[0, 1, 1, 1, 0]

