# Graph Neural Network for Simulated X-Ray Transient Detection
The present work aims to train a GNN to label a particular sort of X-Ray transient using simulated events overlayed onto real data from XMM-Newton observations. We will experiment with Graph Convolutional Networks (GCNs). We will therefore  have to trandsform our point-cloud data into a "k nearest neighbors"-type graph. Data stored in the `raw` folder at the current working directory is taken from icaro.iusspavia.it `/mnt/data/PPS_ICARO_SIM2`. Observations store data for each photon detected, with no filter applied, in FITS files ending in `EVLI0000.FTZ` for the original observations and `EVLF0000.FTZ` for the observation and simulation combined. We will refer to the former data as "genuine" and to the latter as "faked" for brevity. 

IMPORTANT! Some of the "faked" data files downloaded are empty, so sun the `prune.py` script present in the current repository to remove these artifacts.

In [1]:
import numpy as np

from astropy.table import Table, setdiff
from astropy.table.operations import _join

import torch
import pyg_lib #new in torch_geometric 2.2.0!
from torch_geometric.data import Data
from torch_geometric.data import Dataset
import torch_geometric.transforms as ttr
from torch_geometric.loader import DataLoader, NeighborLoader
import torch_directml

import os
import os.path as osp
import sys
from glob import glob
from icecream import ic
from tqdm import tqdm, trange
import gc

In [2]:
from sklearn.preprocessing import StandardScaler

I define a `log` function for future use.

In [3]:
def log(logfile, forcemode=None, **loggings):
    if not forcemode is None:
        assert forcemode in ["w", "a"], f"Error: `forcemode` is '{forcemode}'. Must be either 'w' or 'a'"
    print(*(f"{key}: {value}" for key, value in loggings.items()), sep="\n\t", file=sys.stderr)
    mode = "w+"
    if osp.exists(logfile) and forcemode is None:
        usrinpt=""
        while not usrinpt in ["O","E","C"]:
            usrinpt = input(f"Do you want to overwrite [O] or extend [E] already existing log file {logfile}? (C to cancel) [O,E,C] ")
        if usrinpt == "C":
            return
        elif usrinpt == "E":
            mode = "a"
    elif not forcemode is None:
        mode = forcemode
    with open(logfile, mode) as lf:
        print(*(f"{key}: {value}" for key, value in loggings.items()), sep="\n\t", file=lf)

The following function definition is a copy-paste of the original `setdiff` function from [astropy sourcecode](https://docs.astropy.org/en/stable/_modules/astropy/table/operations.html), modified to return the indices of elemnts prensent in `table1` but not in `table2`. This will be used to mark simulated data overlayed onto the real observations.

In [4]:
def setdiff_idx(table1, table2, keys=None):
    if keys is None:
        keys = table1.colnames

    # Check that all keys are in table1 and table2
    for tbl, tbl_str in ((table1, 'table1'), (table2, 'table2')):
        diff_keys = np.setdiff1d(keys, tbl.colnames)
        if len(diff_keys) != 0:
            raise ValueError("The {} columns are missing from {}, cannot take "
                             "a set difference.".format(diff_keys, tbl_str))

    # Make a light internal copy of both tables
    t1 = table1.copy(copy_data=False)
    t1.meta = {}
    t1.keep_columns(keys)
    t1['__index1__'] = np.arange(len(table1))  # Keep track of rows indices

    # Make a light internal copy to avoid touching table2
    t2 = table2.copy(copy_data=False)
    t2.meta = {}
    t2.keep_columns(keys)
    # Dummy column to recover rows after join
    t2['__index2__'] = np.zeros(len(t2), dtype=np.uint8)  # dummy column

    t12 = _join(t1, t2, join_type='left', keys=keys,
                metadata_conflicts='silent')

    # If t12 index2 is masked then that means some rows were in table1 but not table2.
    if hasattr(t12['__index2__'], 'mask'):
        # Define bool mask of table1 rows not in table2
        diff = t12['__index2__'].mask
        # Get the row indices of table1 for those rows
        idx = t12['__index1__'][diff]
    else:
        idx = []

    return idx

TODO: check why the following does not prevent warnings from happening

In [5]:
from astropy import units as u
newunits = [u.def_unit("PIXELS", u.pixel),
            u.def_unit("CHAN", u.chan),
            u.def_unit("CHANNEL", u.chan),
            u.def_unit("0.05 arcsec", 0.05*u.arcsec)
           ]

Let's define a function that reads from a XMM observation FITS file and returns a table with the relevent event attributes and a flag `ISFAKE` which is `True` for simulated events and `False` for genuine events. The function takes two arguments: the path to the genuine file and the path to the faked file. A column with name `ISFAKE` will be added where `True` values will label simulated events. The function will return the faked observations table's `TIME`, `X`, `Y`, `PI`, `FLAG`, and `ISFAKE` columns.

In [6]:
def read_events(genuine, simulated):
    with u.add_enabled_units(newunits):
        I_dat = Table.read(genuine, hdu=1)
        F_dat = Table.read(simulated, hdu=1)

    # assert all(I_dat['X'].mask), f"{genuine}"

    I_dat = I_dat[np.logical_not(I_dat['X'].mask)]
    F_dat = F_dat[np.logical_not(F_dat['X'].mask)]
    
    D_dat_idx = setdiff_idx(F_dat, I_dat)
    
    dat = F_dat
    dat["ISFAKE"] = np.zeros(len(dat), dtype=bool)
    dat["ISFAKE"][D_dat_idx] = True
    return dat["TIME", "X", "Y", "PI", "FLAG", "ISFAKE"]

Define search patterns for event files, both genuine and faked, to be searched within the `raw` directory.

In [7]:
genuine_pattern = "*EVLI0000.FTZ"
faked_pattern   = "*EVLF0000.FTZ"

We will now set up `Data` and `Dataset` specialized classes for our observation data precessing and handling.

First we will define `IcaroData` as a data type in which the `pos` attribute is overridden by a `@property`. This new `pos` gets and sets values from the last three features of each row of data.

In [8]:
class IcaroData(Data):    
    @property
    def pos(self):
        return self.x[:, -3:]
    
    @pos.setter
    def pos(self, replace):
        assert replace.shape == self.pos.shape
        self.x[:, -3:] = replace

The following dataset structure is quite standard. Notice how we use as feature (`x` attribute) values from the `PI`, `FLAG`, `TIME`, `X`, and `Y` columns, where the last three will be used as `pos` for the data. Notice then that this `pos` is then transformed through the use of a `Standard Scaler` and saved into the `processed` folder in the current working directory. As for the target (`y` attribute) we will use the `ISFAKE` column. Notice that, since we need to transform bools into numerical values for computation on CUDA, simulated data is now labeled with `1`, while genuine data with `0`.

In [9]:
from torch.multiprocessing import Manager
import concurrent.futures

In [10]:
class IcaroDataset(Dataset):
    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None, device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')):
        self.device = device
        super().__init__(root, transform, pre_transform, pre_filter)

    @property
    def raw_file_names(self):
        return list(sorted(list(glob(osp.join(self.raw_dir, genuine_pattern))) +
                           list(glob(osp.join(self.raw_dir, faked_pattern)))))

    @property
    def processed_file_names(self):
        return list(map(lambda name: osp.join(self.processed_dir, osp.basename(name)+".pt"), 
                        glob(osp.join(self.raw_dir, faked_pattern))))
    
    @property
    def num_classes(self):
        return 2

    def _hidden_process(self, raw_path):
        # Read data from `raw_path`.
        raw_path = raw_path.split('|', 2)
        # try:
        #     dat = read_events(*raw_path)
        # except Exception as e:
        #     print("error at ", *raw_path)
        #     raise e

        dat = read_events(*raw_path)
        
        data = IcaroData(x  =torch.from_numpy(np.array([dat["PI"], dat["FLAG"], dat["TIME"], dat["X"], dat["Y"]]).T).float(), 
                            y  =torch.from_numpy(np.array(dat["ISFAKE"])).long())
        
        ss2 = StandardScaler()
        ss2.fit(data.pos)
        new_pos = ss2.transform(data.pos)
        data.pos = torch.tensor(new_pos)
        data.to(self.device)

        if self.pre_filter is not None and not self.pre_filter(data):
            return

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        torch.save(data, osp.join(self.processed_dir, osp.basename(raw_path[-1])+".pt"))
        # del data

    def process(self):
        fnames = list(zip(sorted(glob(osp.join(self.raw_dir, genuine_pattern))), 
                          sorted(glob(osp.join(self.raw_dir, faked_pattern))))
                     )
        already_processed = list(map(lambda name: osp.basename(name), glob(osp.join(self.processed_dir, "*"))))
        fnames = np.array([gname+'|'+fname for gname, fname in fnames if not osp.basename(fname)+".pt" in already_processed])
        # manager = Manager()
        # with concurrent.futures.ThreadPoolExecutor() as executor:
        #     executor.map(self._hidden_process, tqdm(fnames))

        # manager = Manager()
        # with concurrent.futures.ProcessPoolExecutor() as executor:
        #     executor.map(self._hidden_process, tqdm(fnames))

        # hidden_process = np.vectorize(self._hidden_process)
        # hidden_process(fnames)

        for fname in tqdm(fnames):
            self._hidden_process(fname)

    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        data = torch.load(osp.join(self.processed_dir, self.processed_file_names[idx]))
        return data

Let's load data from the current working directory and pre-transform it into a `KNNGraph`. This might take some time the first load, as the processed files are built, but subsequent runs will be speedy.

TODO: lots of warnings from astropy units when first processing. Gotta see what we can do about it

In [11]:
import warnings
with warnings.catch_warnings(): # to avoid useless astropy units warnings (gotta check how to solve it)
    warnings.simplefilter('ignore')
    ds = IcaroDataset(os.getcwd(), pre_transform = ttr.KNNGraph(k=20), device='cpu')

Processing...
  0%|          | 0/4885 [00:00<?, ?it/s]

reading
preprocessing
preprocessing done
data transformed


  0%|          | 1/4885 [02:04<168:37:58, 124.30s/it]

data saved
reading
preprocessing
preprocessing done
data transformed


  0%|          | 2/4885 [02:50<106:03:06, 78.19s/it] 

data saved
reading
preprocessing
preprocessing done


We now define a `Net` model, with parametrable number of GCN layers, in channels, hidden channels, and out channels. Each layer but the last has a user-given activation function (`relu` is the default) and a `softmax` output activation function.

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import MLP, GINConv, global_add_pool

class Net(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, activation_function=F.relu):
        super().__init__()

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers - 1):
            self.convs.append(GCNConv(in_channels, hidden_channels))
            in_channels = hidden_channels

        self.last_conv = GCNConv(hidden_channels, out_channels)
        
        self.activation_function = activation_function

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        for conv in self.convs:
            x = conv(x, edge_index)
            x = self.activation_function(x)
        x = F.dropout(x, training=self.training)
        x = self.last_conv(x, edge_index)
        
        return F.log_softmax(x, dim=1)

Let us now set up the parameters.

In [None]:
dataset    = ds
lr         = 0.01
device = torch_directml.device()
# device     = torch.device('cpu') # cries in low GPU memory space YoY
# device     = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 1
epochs     = 100
hidden_channels = 5
num_layers = 2
activation_function = F.relu #torch.sigmoid

Then we'll split and load the dataset.

In [None]:
torch.cuda.empty_cache()
#print(torch.cuda.memory_summary(device=None, abbreviated=False))

model = Net(dataset.num_node_features, hidden_channels, dataset.num_classes, num_layers, activation_function).to(device)
#print(torch.cuda.memory_summary(device=None, abbreviated=False))

train_dataset = dataset[len(dataset) // 5:]
train_loader  = DataLoader(train_dataset, batch_size, shuffle=True)
valid_dataset = dataset[len(dataset) // 10:len(dataset) // 5]
valid_loader  = DataLoader(valid_dataset, batch_size+1)
test_dataset  = dataset[:len(dataset) // 10]
test_loader   = DataLoader(test_dataset, batch_size+1)

optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)

FileNotFoundError: [Errno 2] No such file or directory: '/home/xaco/tesi/Transient-Data-Analysis/processed/P0745170301M1S001MIEVLF0000.FTZ.pt'

In [None]:
import subprocess as sp
import time

In [None]:
def total_len(dataset):
    """Returns the number of target rows of the dataset"""
    return np.sum([len(data.y) for data in dataset])

def total_positives(dataset):
    """Returns the number of target value '1' of the dataset (only if the other class is '0')"""
    return np.sum([data.y.sum().item() for data in dataset])

def train():
    model.train()

    total_loss = 0
    for data in train_loader:
        loader = train_loader
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data).to(device)
        pred = out.argmax(dim=-1)
        totpos = total_positives(loader.dataset)
        totlen = total_len(loader.dataset)
        true_positives = torch.logical_and(pred == 1, pred == data.y).sum().int()/totpos
        true_negatives = torch.logical_and(pred == 0, pred == data.y).sum().int()/(totlen-totpos)
        frac, rev_frac = data.y.sum().item()/len(data.y), (len(data.y) - data.y.sum().item())/len(data.y)
        assert not np.isnan(frac) and not np.isnan(rev_frac)
        if frac == 0: # in this case placeholder parameters must be enforced to avoid unwanted behavior
            frac = rev_frac = 0.5
            true_positives = 1.
        addloss = (true_positives*true_negatives)**(-0.5) - 1 # scares the model out of giving a constant answer
        loss = F.cross_entropy(out, data.y, weight=torch.tensor([frac, rev_frac]).to(device)) + addloss
        assert not torch.isnan(loss.detach()), f"out: {out}\ndata.y: {data.y}\nLoss: {total_loss}\nWeight: {frac}"
        loss.backward()
        optimizer.step()
        total_loss += float(loss) * data.num_graphs
        #print(torch.cuda.memory_summary(device=None, abbreviated=False))
        del data
        torch.cuda.empty_cache()
        gc.collect
        #print(torch.cuda.memory_summary(device=None, abbreviated=False))
    return total_loss / total_len(train_loader.dataset)

@torch.no_grad()
def test(loader):
    model.eval()

    total_correct         = 0
    total_true_positives  = 0
    total_false_positives = 0
    for data in loader:
        data = data.to(device)
        pred = model(data).argmax(dim=-1)
        total_correct += int((pred == data.y).sum())
        total_true_positives += int(np.logical_and(pred == 1, pred == data.y).sum())
        total_false_positives += int(np.logical_and(pred == 1, pred != data.y).sum())
        del data
        torch.cuda.empty_cache()
        gc.collect
    totlen = total_len(loader.dataset)
    totpos = total_positives(loader.dataset)
    return (total_correct/totlen, 
            total_true_positives/totpos, 
            total_false_positives/(totlen-totpos)
           )

In [None]:
for epoch in range(1, epochs + 1):
    loss = train()
    train_acc, train_tp, train_fp = test(train_loader)
    test_acc, test_tp, test_fp = test(valid_loader)
    log(Epoch=epoch, 
        AbsLogLoss=np.log(loss), 
        Train_accuracy=train_acc,
        Train_true_positives=train_tp,
        Train_false_positives=train_fp,
        Test_accuracy=test_acc,
        Test_true_positives=test_tp,
        Test_false_positives=test_fp,
        logfile="logs.log",
        forcemode="w"
       )
    # if not epoch % 10: # for acoustic feedback
    #     sp.run(["spd-say", "'Epoch! Epoch! Epoch!'"])

In [None]:
# while True: # warns that the process is finished
#     sp.run(["spd-say", "'Your process is done'"])
#     time.sleep(5)