In [1]:
#Import Library
import os
import numpy as np
import pickle
import re
import scispacy
import logging
from tqdm import tqdm
import torch
import random

In [3]:
torch.cuda.manual_seed_all(37)
random.seed(37)
w2v_file = "./data/embedding/wiki-news-300d-1M.vec"
dir_vocab = "./data/embedding/"
train_dir = "./data/PubMed_20k_RCT/train.txt"
test_dir = "./data/PubMed_20k_RCT/test.txt"
dev_dir = "./data/PubMed_20k_RCT/dev.txt"
language_model = "en_core_sci_sm"
vocab_dir = "./data/embedding/vocab_size_50000_min_-11.0164_max_2.3578.p"
explain_dir = "./data/explanations/explanations.jsonl"

In [9]:
import scispacy
import spacy
from gensim.models import KeyedVectors

#Prepare fasttext
VOCAB_SIZE = 50000

fasttext = KeyedVectors.load_word2vec_format(w2v_file, limit = VOCAB_SIZE)

word2vec = {}

lower_bound = float('inf')
upper_bound = float('-inf')

sum_of_vectors = None
for word in fasttext.index_to_key:
    word2vec[word] = np.reshape(fasttext[word], (1, - 1))
    
    min_coeff = np.min(word2vec[word])
    max_coeff = np.max(word2vec[word])
    lower_bound = min_coeff if min_coeff < lower_bound else lower_bound
    upper_bound = max_coeff if max_coeff > upper_bound else upper_bound

    if sum_of_vectors is not None:
        sum_of_vectors = sum_of_vectors +  word2vec[word]
    else:
        sum_of_vectors = word2vec[word]

sum_of_vectors /= VOCAB_SIZE
unk = "<###-unk-###>"
word2vec[unk] = sum_of_vectors

max_coeff = np.max(word2vec[unk])
upper_bound = max_coeff if max_coeff > upper_bound else upper_bound
min_coeff = np.min(word2vec[unk])
lower_bound = lower_bound if min_coeff < lower_bound else lower_bound

lower_bound = str(round(lower_bound, ndigits=5))
upper_bound = str(round(upper_bound, ndigits=5))

file_name = f'vocab_size_{VOCAB_SIZE}_min_{lower_bound}_max_{upper_bound}.p'
path = os.path.join(dir_vocab, file_name)
pickle.dump(word2vec, open(path, 'wb'))

In [12]:
from xgcn.xgraph import XNode, XGraph

def doc2graph(doc,to_lower):
    def add(graph, token, to_lower=to_lower):
        id = token.i + 1 
      
        if graph.contains_by_id(id):
            raise AssertionError('Node contained.')
        label = token.text 
        if to_lower:
            label = label.lower()
        n = XNode(id=id, label=label, type='TOKEN')
        graph.add_node(n)
        return n

    graph = XGraph()

    for idx in range(len(doc)):
        token = doc[idx]
        add(graph, token)

    for idx in range(len(doc)):
        parent_token = doc[idx]
        parent_id = parent_token.i + 1  
        parent_node = graph.get_node(parent_id)
        for child_token in parent_token.children:
            child_id = child_token.i + 1  
            child_node = graph.get_node(child_id)
            graph.add_edge(parent_node, child_node, t=child_token.dep_)

    return graph
    
def line_to_graph(line, nlp):
    label, sent = line.split('\t')[0], line.split('\t')[1]
    sent = sent.strip()
    doc = nlp(sent)
    print(sent)
    print(doc)
    g = doc2graph(doc=doc, to_lower=True)
    return label, g
def preprocess(path, limit = 100):
    pattern = "###[0-9]+$"
    pattern = re.compile(pattern)
    
    path_out = path.replace('.txt', '.p')
    
    f_in = open(path, 'r')
    lines = f_in.readlines()
    graphs = []
    
    nlp = spacy.load(language_model, disable = ['tagger',
                                              'ner',
                                              'textcat',
                                              'entity_ruler',
                                              'sentenizer',
                                              'merge_noun_chunks',
                                              'merge_entities',
                                              'merge_subtokens'])
    
    written = 0
    discarded = 0
    lines = random.sample(lines, int(limit / 100 * len(lines)))
    for line in tqdm(lines):
        line = line.strip()
        if len(line) == 0 or pattern.match(line.strip()):
            discarded = discarded + 1
            continue
        label, graph = line_to_graph(line.strip(), nlp)
        graphs.append((label,graph))
        written = written + 1
        if (written % 1000 == 999):
            print('Processed {} lines and discarded {} lines'.format(written + 1, discarded + 1))
    
    f_in.close()

    pickle.dump(graphs, open(path_out, 'wb'))
    return path_out

In [13]:
#Preprocessing in PubMed Dataset
preprocess(path = train_dir, limit = 40)
preprocess(path = test_dir)
preprocess(path = dev_dir)

  0%|          | 0/84016 [00:00<?, ?it/s]


Blood samples were drawn at the beginning ( t0 ) and end ( t1 ) of the operation and after 24h ( t2 ) .
Blood samples were drawn at the beginning ( t0 ) and end ( t1 ) of the operation and after 24h ( t2 ) .


  0%|          | 0/35135 [00:00<?, ?it/s]


The cumulative number of new gadolinium-enhancing T1 lesions was reduced by 67.9 % compared to placebo ( p = 0.002 ) .
The cumulative number of new gadolinium-enhancing T1 lesions was reduced by 67.9 % compared to placebo ( p = 0.002 ) .


  0%|          | 0/35212 [00:00<?, ?it/s]

No significant changes were noted in the other two groups .
No significant changes were noted in the other two groups .





'./data/PubMed_20k_RCT/dev.p'

In [5]:
from sklearn.metrics import f1_score
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision.transforms import transforms
from tqdm import tqdm
import torch.nn.functional as F

from xgcn.xgcn import XGCN
from xgcn.xgraph import XSample, Pad, ToTensor, LabelToOneHot

In [6]:
#DataLoader

class PubMedLoader(Dataset):
    
    def __init__(self, path_pickle, w2v_path, pad = 150, crop =-1):
        self.path_word2vec = w2v_path
        self.label2vec = pickle.load(open(w2v_path, 'rb'))
        self.path_pickle = path_pickle
        self.label_graph_tuples = pickle.load(open(self.path_pickle, 'rb'))
        self.label2onehot = LabelToOneHot(classes = PubMedLoader.classes())
        self.crop = crop
        self.pad = pad
        self.ops = [Pad(self.pad), ToTensor(), self.label2onehot]
        self.transforms = transforms.Compose(self.ops)
    
    @staticmethod
    def classes():
        return ["METHODS", "RESULTS", "CONCLUSIONS", "BACKGROUND", "OBJECTIVE"]
    
    def __len__(self):
        if self.crop > 0:
            return self.crop
        return len(self.label_graph_tuples)

    def __getitem__(self, index):
        label, graph = self.label_graph_tuples[index]
        embedding = graph.E(label2vec = self.label2vec)
        adjacency = graph.A_tilde()
        
        xsample = XSample(embedding, adjacency, label)
        xsample = self.transforms(xsample)
        return xsample.EMBEDDING, xsample.ADJACENCY, xsample.LABEL

In [13]:
def load(train_path, dev_path, test_path, w2v_path, pad, crop_train, crop_dev, crop_test, batch_size, num_workers):
    pin_memory = torch.cuda.is_available()
    
    train_dataset = PubMedLoader(path_pickle = train_dir.replace(".txt", ".p"), w2v_path = w2v_path, pad = pad, crop = crop_train)
    dev_dataset = PubMedLoader(path_pickle = dev_dir.replace(".txt",".p"), w2v_path = w2v_path, pad = pad, crop = crop_dev)
    test_dataset = PubMedLoader(path_pickle = test_dir.replace(".txt",".p"), w2v_path = w2v_path, pad = pad, crop = crop_test)

    train_loader = DataLoader(train_dataset, batch_size = batch_size ,pin_memory=pin_memory, num_workers=num_workers)
    dev_loader = DataLoader(dev_dataset, batch_size = batch_size ,pin_memory=pin_memory, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size = batch_size ,pin_memory=pin_memory, num_workers=num_workers)

    return train_loader, dev_loader, test_loader


In [30]:
def report(epoch, split, scores):
    print("Epoch: {} Split: {} F-micro: {:.3f} F-macro: {:.3f} F-weighted: {:.3f}"
        .format(epoch, split, scores['micro'], scores['macro'], scores['weighted']))

In [31]:
def validate(xgcn, dataloader,device):
    xgcn.eval()
    xgcn.to(device)
    outputs = None
    targets = None
    for (embeddings, adjacencies, labels) in tqdm(dataloader):
        embeddings = embeddings.to(device)
        adjacencies = adjacencies.to(device)
        labels = labels.to(device)
        if targets is None:
            targets = labels
        else:
            targets = torch.cat((targets, labels))

        output = xgcn(embeddings, adjacencies)
        output = torch.argmax(output, dim=1)

        if outputs is None:
            outputs = output
        else:
            outputs = torch.cat((outputs, output))

    outputs = outputs.tolist()
    targets = targets.tolist()

    outputs, targets = zip(*((output, target) for output, target in zip(outputs, targets))) # todo what does this do?
    outputs = list(outputs)
    targets = list(targets)

    f_score_micro = f1_score(y_pred=outputs, y_true=targets, average='micro')
    f_score_macro = f1_score(y_pred=outputs, y_true=targets, average='macro')
    f_score_weighted = f1_score(y_pred=outputs, y_true=targets, average='weighted')
    logging.info('...done validating.')

    return {'micro': f_score_micro,
            'macro': f_score_macro,
            'weighted': f_score_weighted}

In [16]:
def train(train_loader,dev_loader, path_model,epochs, batch_size, pad, nfeat, nhid, patience, metric, random_seed,nclasses):
    print(train_loader.__len__())
    device = torch.device('mps' if torch.cuda.is_available() else 'cpu')
    print(device)
    xgcn = XGCN(nfeat = nfeat, nhid = nhid, nclass = nclasses, pad= pad, bias = None)
    
    xgcn.to(device)
    optimizer = Adam(params = xgcn.parameters())
    
    #xgcn.load_state_dict(torch.load("model.weights"))
    scores = validate(xgcn = xgcn, dataloader = dev_loader, device = device)
    
    report(epoch = 0, split = "Dev", scores = scores)
    torch.save(xgcn.state_dict(), path_model)
    print("Saved initial model to {}.".format(path_model))
    
    wait = 0
    score_last = float('-inf')
    running_loss = 0.0
    for epoch in range(epochs):
        xgcn.train()
        for batch_idx, (embeddings, adjacencies, labels) in enumerate(train_loader):
            embeddings = embeddings.to(device)
            adjacencies = adjacencies.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            preds = xgcn(embeddings,adjacencies)
            loss = F.nll_loss(preds, labels)
            loss.backward()
            optimizer.step()
            xgcn.xfc.weight.data.clamp_(0)
            
            running_loss += loss.item()
            if batch_idx % 10 == 9:
                print('[%d, %5d, %5d] loss: %.3f' %
                    (epoch + 1, batch_idx + 1, (batch_idx + 1) * batch_size, running_loss / 10))
                running_loss = 0.0
        
        scores = validate(xgcn = xgcn, dataloader = dev_loader, device = device)
        report(epoch = epoch + 1, split = "Dev", scores = scores)
        
        score_current = scores[metric]
        
        if score_current > score_last:
            torch.save(xgcn.state_dict(), path_model)
            print("{} score improved from {:.3f} to {:.3f}. Saved model to {}."
                .format(metric, score_last, score_current, path_model))
            score_last = score_current
            wait = 0
        else:
            wait = wait + 1
            if wait >= patience:
                print("Terminating training after {} epochs w/o improvement.".format(wait))
                return xgcn

In [7]:
#Start Train Process
pad = 350
crop_train = -1
crop_dev = -1
crop_test = -1
batch_size = 8
num_workers = 64
epochs = 100
nclasses = 5
path_model = "./data/model/model.weights"
nfeat = 300
nhid = 300
patience = 3
metric = "weighted"
random_seed = 37

In [None]:
train_loader, dev_loader, test_loader = load(train_path=train_dir,dev_path =  dev_dir,test_path = test_dir, w2v_path = vocab_dir, pad = pad, crop_train = crop_train, crop_dev = crop_dev, crop_test = crop_test, batch_size = batch_size, num_workers = num_workers)

print("Start dump loader")
#pickle.dump(train_loader, open("train_loader.p", 'wb'))
#pickle.dump(dev_loader, open("dev_loader.p", 'wb'))
#pickle.dump(test_loader, open("test_loader.p", 'wb'))

In [14]:
xgcn = train(train_loader, dev_loader,path_model, epochs, batch_size, pad, nfeat, nhid,patience, metric, random_seed, nclasses = nclasses)

8982
cpu


  0%|          | 0/3777 [00:00<?, ?it/s]Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/opt/homebrew/Caskroom/miniforge/base/envs/tf/lib/python3.9/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/tf/lib/python3.9/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'PubMedLoader' on <module '__main__' (built-in)>
  0%|          | 0/3777 [00:16<?, ?it/s]


KeyboardInterrupt: 

In [None]:
#Testing
xgcn = XGCN(nfeat = nfeat, nhid = nhid, nclass = nclasses, pad= pad, bias = None)
xgcn.load_state_dict(torch.load("model.weights"))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
scores = validate(xgcn = xgcn, dataloader = test_loader, device = device)
report(0, split = "Test", scores = scores)

In [9]:
#Explain
from explain import *
explain(nfeat=nfeat,
            nhid=nhid,
            path_model=path_model,
            padding=pad,
            path_text=test_dir,
            path_out="./data/explanations/explanations.jsonl",
            path_label2vec=vocab_dir,
            lower_bound=float(lower_bound),
            upper_bound=float(upper_bound),
            to_lower=True,
            language_model=language_model,
            crop=-1,
            do_occlude= True,
            drop=1.0,
            step=0.1,                           
            verbose=False)

  global_normalized_relevance_matrix = (relevance_matrices[0] + relevance_matrices[1]) / np.sum(
100%|██████████| 35135/35135 [2:04:30<00:00,  4.70it/s]


True

In [13]:
#postprocess
from postprocess import *

print('Summarizing occlusion experiments...')
top, bottom = read_explanations(explain_dir)
res_top, percentages = occlusion_predictions(top)
res_bottom, percentages = occlusion_predictions(bottom)
f1_top = [f1_score(t[0], t[1], average='weighted') for t in res_top]

f1_top = list(zip(percentages, f1_top))
f1_top = [f'{tup[0]},{tup[1]}' for tup in f1_top]
f1_top = '\n'.join(f1_top)
f1_bottom = [f1_score(b[0], b[1], average='weighted') for b in res_bottom]
f1_bottom = list(zip(percentages, f1_bottom))
f1_bottom = [f'{tup[0]},{tup[1]}' for tup in f1_bottom]
f1_bottom = '\n'.join(f1_bottom)
with open("./data/top_masked_predictions.csv", 'w+') as fout:
    fout.write(f1_top)
    fout.close()
with open("./data/bottom_masked_predictions.csv", 'w+') as fout:
    fout.write(f1_bottom)
    fout.close()

Summarizing occlusion experiments...


30135it [13:35, 36.94it/s]


Converting to latex...


FileNotFoundError: [Errno 2] No such file or directory: './data/explanations/explanations.jsonl'

In [14]:
print('Converting to latex...')
to_latex(path_in=explain_dir,
            path_out="./data/explanations/explanations.tex",
            max_seq_len=10,
            crop=250,
            weight=15,
            base=0.5)
print('...done converting to latex.')

Converting to latex...


1110it [00:25, 43.42it/s]

...done converting to latex.



