## Preprocess the data

In [None]:
# import torch
# torch.__version__
# torch.version.cuda

In [None]:
# ! pip install torch==1.7.0+cu110 torchvision==0.8.1+cu110 torchaudio===0.7.0 -f https://download.pytorch.org/whl/torch_stable.html
# ! pip install --upgrade --force-reinstall torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.0+cu110.html
# ! pip install --upgrade --force-reinstall torch-sparse -f https://pytorch-geometric.com/whl/torch-1.7.0+cu110.html
# ! pip install --upgrade --force-reinstall torch-cluster -f https://pytorch-geometric.com/whl/torch-1.7.0+cu110.html
# ! pip install --upgrade --force-reinstall torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.7.0+cu110.html
# ! pip install torch-geometric
# ! pip install numpy==1.18.0

In [None]:
import os
# Number GPUs by PCI bus order and expose only devices 0 and 1 to this process.
os.environ.update(CUDA_DEVICE_ORDER="PCI_BUS_ID", CUDA_VISIBLE_DEVICES="0,1")

In [None]:
import time
import pickle
import numpy as np
import pandas as pd

import torch
import torch.nn.functional as F
from torch.nn import Sequential as Seq, Linear, ReLU

from torch.utils.tensorboard import SummaryWriter

from collections import Counter
from tqdm.notebook import tqdm

from sklearn import metrics
from sklearn.metrics import classification_report, f1_score

In [None]:
from torch_geometric.nn import MessagePassing
from torch_geometric.nn import GraphConv, TopKPooling, GatedGraphConv
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp

from torch_geometric.data import DataLoader, InMemoryDataset, Data

from torch_geometric.utils import remove_self_loops, add_self_loops

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Dataset

In [None]:
def _load_pickle(path):
    """Unpickle and return the object stored at ``path``, closing the file afterwards.

    NOTE: ``pickle.load`` can execute arbitrary code -- only load files you trust.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


conll_data = _load_pickle('conll_graph_all.pickle')
vocabulary = _load_pickle('vocabulary_all.pickle')
# voc2id[feature][token] -> index of the token within that feature's own vocabulary.
voc2id = {key: {tok: i for i, tok in enumerate(vocabulary[key])} for key in vocabulary}
labels = _load_pickle('labels.pickle')
label2id = {lab: i for i, lab in enumerate(labels)}

In [None]:
class CoNLLDataset(InMemoryDataset):
    """In-memory PyG dataset holding one small graph per document of ``data[split]``.

    Node 0 is the centre word. Every node carries a single integer id (``x`` has shape
    ``[num_nodes, 1]``) into one shared id space laid out as ``[word | pos | chunk | extra]``:
    pos/chunk/extra ids are shifted by the sizes of the vocabularies preceding them.
    The centre word's pos, chunk and extra feature nodes all point at node 0. Context
    words from the left and right windows are appended as word nodes and chained in
    sentence order toward the centre. Node 0 also gets a self loop.

    Args:
        data: mapping from split name to a sequence of documents. Each document is read
            through the keys 'word', 'pos', 'chunk', 'extra', 'left_context',
            'right_context' and 'label'.
        split: key into ``data``; also used as the cached file name.
        voc2id: ``voc2id[feature][token] -> index`` for 'word', 'pos', 'chunk' and 'extra'.
        labels: sequence of label values; the position of a label is its class id.
        window_size: number of context words kept on each side, or ``'all'`` to use the
            length of the longer of the two contexts.
        root, transform, pre_transform: forwarded to ``InMemoryDataset``. ``pre_transform``
            (if given) is applied to every graph before it is collated and cached.

    NOTE: the cache file name depends only on ``split``. Delete cached files after
    changing ``window_size`` or the graph construction, otherwise stale graphs are loaded.
    """

    # Absolute cache directory; being absolute, it is not placed under ``root``.
    PROCESSED_DIR = '/data/graphner_embeddings/node_selfloop_simple_features/'

    def __init__(self, data, split, voc2id, labels, window_size=3, root='.', transform=None, pre_transform=None):
        self.dataset = data[split]
        self.voc2id = voc2id
        self.labels = labels
        self.label2id = {l: i for i, l in enumerate(labels)}

        self.window_size = window_size
        self.split = split

        # Set every attribute above first: PyG's base constructor invokes process()
        # when the cached file is missing, and process() reads them.
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        # Nothing to download: the documents are passed in memory through ``data``.
        return []

    @property
    def processed_file_names(self):
        # Absolute path, so joining it onto processed_dir yields this path unchanged.
        return [self.PROCESSED_DIR + self.split]

    def download(self):
        pass

    def _feature_offsets(self):
        """Return the start offsets of the pos, chunk and extra blocks in the shared id space."""
        pos_offset = len(self.voc2id['word'])
        chunk_offset = pos_offset + len(self.voc2id['pos'])
        extra_offset = chunk_offset + len(self.voc2id['chunk'])
        return pos_offset, chunk_offset, extra_offset

    def _context_windows(self, doc):
        """Return the (left, right) context words kept for ``doc`` according to ``window_size``."""
        left, right = doc['left_context'], doc['right_context']
        if self.window_size == 'all':
            window_size = max(len(left), len(right))
        else:
            window_size = self.window_size
        # Explicit start index: ``left[-0:]`` would return the whole list for a 0-wide window.
        return left[max(len(left) - window_size, 0):], right[:window_size]

    def _doc_to_graph(self, doc, offsets):
        """Build the ``Data`` graph (x, y, edge_index) for a single document."""
        pos_offset, chunk_offset, extra_offset = offsets

        # Centre word (node 0) followed by its feature nodes; each feature node points at node 0.
        ids = [
            self.voc2id['word'][doc['word'][0]],
            pos_offset + self.voc2id['pos'][doc['pos'][0]],
            chunk_offset + self.voc2id['chunk'][doc['chunk'][0]],
        ]
        ids.extend(extra_offset + self.voc2id['extra'][v] for v in doc['extra'])
        edges = [(i, 0) for i in range(1, len(ids))]

        left_window, right_window = self._context_windows(doc)
        # Chain context words in sentence order. The word adjacent to the centre links to
        # node 0 and every further word links to its inner neighbour. The left window is
        # walked right-to-left so that its nearest word comes first.
        for window in (reversed(left_window), right_window):
            for i, w in enumerate(window):
                edges.append((len(ids), 0 if i == 0 else len(ids) - 1))
                ids.append(self.voc2id['word'][w])

        edges.append((0, 0))  # self loop on the centre node

        x = torch.LongTensor(ids).unsqueeze(1)
        y = torch.tensor([self.label2id[doc['label']]])
        edge_index = torch.tensor(list(zip(*edges)))
        return Data(x=x, y=y, edge_index=edge_index)

    def process(self):
        """Convert every document to a graph, collate them and save the result to the cache path.

        Errors while building a graph propagate instead of silently truncating the split.
        """
        offsets = self._feature_offsets()
        data_list = [self._doc_to_graph(doc, offsets)
                     for doc in tqdm(self.dataset, total=len(self.dataset))]
        if self.pre_transform is not None:
            data_list = [self.pre_transform(d) for d in data_list]

        data, slices = self.collate(data_list)
        out_path = self.processed_paths[0]
        # The absolute cache directory lies outside ``root``; make sure it exists.
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        torch.save((data, slices), out_path)

In [None]:
# Build (or load the cached version of) each split using the full available context.
train_dataset, val_dataset, test_dataset = (
    CoNLLDataset(data=conll_data, split=split_name, voc2id=voc2id, labels=labels, window_size='all')
    for split_name in ('train', 'validation', 'test')
)

# Model

In [None]:
class SAGEConv(MessagePassing):
    """GraphSAGE-style convolution with max aggregation.

    Each source-node feature goes through ``Linear(in_channels, out_channels)`` + ReLU.
    The results are max-aggregated per target node, concatenated with the node's own
    input features, and projected to ``out_channels`` (+ ReLU).

    Args:
        in_channels: width of the input node features.
        out_channels: width of the output node features.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='max')  # "Max" aggregation.
        self.lin = torch.nn.Linear(in_channels, out_channels)
        self.act = torch.nn.ReLU()
        # Input is [aggregated message | own features]; output width is out_channels.
        self.update_lin = torch.nn.Linear(in_channels + out_channels, out_channels, bias=False)
        self.update_act = torch.nn.ReLU()

    def forward(self, x, edge_index):
        """Apply the layer.

        Args:
            x: node features, shape [N, in_channels].
            edge_index: graph connectivity, shape [2, E].

        Returns:
            New node features, shape [N, out_channels].
        """
        num_nodes = x.size(0)
        # Normalise to exactly one self loop per node, so every node's own features
        # also take part in the max aggregation.
        edge_index, _ = remove_self_loops(edge_index)
        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)

        return self.propagate(edge_index, size=(num_nodes, num_nodes), x=x)

    def message(self, x_j):
        # x_j: source-node features per edge, shape [E, in_channels] -> [E, out_channels].
        return self.act(self.lin(x_j))

    def update(self, aggr_out, x):
        # aggr_out: [N, out_channels]; x: [N, in_channels] -> [N, out_channels].
        new_embedding = torch.cat([aggr_out, x], dim=1)
        return self.update_act(self.update_lin(new_embedding))

In [None]:
# Model hyper-parameters used by Net.
embed_dim = 512
# Size of the shared embedding table: one row per entry of each vocabulary listed here.
# NOTE(review): CoNLLDataset offsets only cover word/pos/chunk/extra, so the 'classes'
# rows are presumably never looked up -- confirm they are intentional.
dim_input = sum(map(len, (voc2id[key] for key in ('word', 'pos', 'chunk', 'extra', 'classes'))))
n_labels = len(labels)
dropout_rate = 0.5

class Net(torch.nn.Module):
    """Graph classifier: id embedding -> 3 x (SAGEConv + TopKPooling) -> summed readouts -> MLP.

    After every conv/pool stage the per-graph max- and mean-pooled features are
    concatenated. The three stage readouts are summed and fed to a 3-layer MLP that
    produces raw class logits.

    Args:
        embedding_dim: node embedding width; defaults to the global ``embed_dim``.
        vocab_size: number of embedding rows; defaults to the global ``dim_input``.
        num_classes: number of output logits; defaults to the global ``n_labels``.
        dropout: dropout probability before the last layer. ``None`` means the global
            ``dropout_rate`` is read at every forward pass.

    ``Net()`` with no arguments therefore behaves as before.
    """

    def __init__(self, embedding_dim=None, vocab_size=None, num_classes=None, dropout=None):
        super().__init__()
        embedding_dim = embed_dim if embedding_dim is None else embedding_dim
        vocab_size = dim_input if vocab_size is None else vocab_size
        num_classes = n_labels if num_classes is None else num_classes
        self.dropout = dropout

        # Modules are created in the same order as before, so parameter init and
        # state_dict keys are unchanged.
        self.conv1 = SAGEConv(embedding_dim, embedding_dim)
        self.pool1 = TopKPooling(embedding_dim, ratio=0.8)
        self.conv2 = SAGEConv(embedding_dim, embedding_dim)
        self.pool2 = TopKPooling(embedding_dim, ratio=0.8)
        self.conv3 = SAGEConv(embedding_dim, embedding_dim)
        self.pool3 = TopKPooling(embedding_dim, ratio=0.8)
        self.item_embedding = torch.nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
        self.lin1 = torch.nn.Linear(embedding_dim * 2, embedding_dim)
        self.lin2 = torch.nn.Linear(embedding_dim, embedding_dim // 2)
        self.lin3 = torch.nn.Linear(embedding_dim // 2, num_classes)
        # NOTE(review): bn1/bn2 are not used in forward(); kept so checkpoints still load.
        self.bn1 = torch.nn.BatchNorm1d(embedding_dim)
        self.bn2 = torch.nn.BatchNorm1d(embedding_dim // 2)
        self.act1 = torch.nn.ReLU()
        self.act2 = torch.nn.ReLU()

    @staticmethod
    def _readout(x, batch):
        """Concatenate per-graph max- and mean-pooled node features -> [num_graphs, 2 * dim]."""
        return torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        # One id per node ([N, 1] as built by CoNLLDataset) -> [N, 1, D] -> [N, D].
        x = self.item_embedding(x).squeeze(1)

        readouts = []
        for conv, pool in ((self.conv1, self.pool1), (self.conv2, self.pool2), (self.conv3, self.pool3)):
            x = F.relu(conv(x, edge_index))
            x, edge_index, _, batch, _, _ = pool(x, edge_index, None, batch)
            readouts.append(self._readout(x, batch))
        x = readouts[0] + readouts[1] + readouts[2]

        x = self.act1(self.lin1(x))
        x = self.act2(self.lin2(x))
        p = dropout_rate if self.dropout is None else self.dropout
        x = F.dropout(x, p=p, training=self.training)

        # Raw logits, [num_graphs, num_classes]; squeeze(1) only has an effect when
        # num_classes == 1.
        return self.lin3(x).squeeze(1)

In [None]:
model = Net().to(device)

# TensorBoard run tag: keep it in sync with the settings actually used below
# (SGD lr=5e-4, momentum=0.9, weight_decay=5e-4, weights2 class weights, train batch size 64).
writer = SummaryWriter(comment='gcn-justids-winall-wei2-lr5e4-mom0.9-wd5e4-embdim512-dr0.5-bs64-notall', log_dir=None)
logs = []  # one (train_loss, (micro_F1, macro_F1, weighted_F1)) tuple per epoch

In [None]:
# Only the training split is shuffled; validation and test share one evaluation batch size.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader, test_loader = (
    DataLoader(eval_set, batch_size=128, shuffle=False)
    for eval_set in (val_dataset, test_dataset)
)

In [None]:
optimizer_params = {'lr': 5e-4,
                    'momentum': 0.9,
                    'weight_decay': 5e-4,
                   }

# Per-class weights for CrossEntropyLoss, indexed like `labels`.
# `weights1` is an alternative setting kept for reference; only `weights2` is used.
weights1 = [0.5530, 1.0000, 0.0317, 0.4590, 0.4120]
weights2 = [0.7436, 1.0000, 0.1780, 0.6775, 0.6419]
assert len(weights2) == len(labels), 'need exactly one class weight per label'

# Reuse the notebook-wide `device` (no redefinition) so the loss weights sit with the model.
weights = torch.tensor(weights2, dtype=torch.float, device=device)

optimizer = torch.optim.SGD(model.parameters(), **optimizer_params)
loss_fn = torch.nn.CrossEntropyLoss(weight=weights)

In [None]:
# Sanity check: number of CUDA devices visible to this process (an earlier cell sets
# CUDA_VISIBLE_DEVICES="0,1" before torch is imported).
torch.cuda.device_count()

In [None]:
def train(loader):
    """Run one training epoch over ``loader``.

    Uses the notebook-level ``model``, ``optimizer``, ``loss_fn`` and ``device``.

    Args:
        loader: PyG DataLoader yielding batched graphs with targets in ``data.y``.

    Returns:
        The sum over batches of ``num_graphs * batch_loss``, divided by the number of
        graphs in ``loader.dataset`` (0.0 when that dataset is empty).
    """
    model.train()

    loss_all = 0.0
    # At least 1, so loaders with fewer than 5 batches don't trigger ``ii % 0``.
    print_loss_every = max(1, len(loader) // 5)
    for ii, data in enumerate(tqdm(loader, total=len(loader))):
        data = data.to(device)  # moves every tensor attribute, data.y included
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, data.y)
        loss.backward()
        optimizer.step()

        batch_loss = data.num_graphs * loss.item()
        loss_all += batch_loss
        if ii % print_loss_every == 0:
            print('== Loss at', ii, ' : ', batch_loss)

    # Normalise by this loader's dataset rather than the global ``train_dataset``.
    n_graphs = len(loader.dataset)
    return loss_all / n_graphs if n_graphs else 0.0

In [None]:
def evaluate(loader, split=''):
    """Predict every graph in ``loader``, print a classification report and return F1 scores.

    Uses the notebook-level ``model`` and ``device``.

    Args:
        loader: PyG DataLoader over the split to evaluate.
        split: name shown in the printed header.

    Returns:
        ``(micro_F1, macro_F1, weighted_F1)`` computed with ``sklearn.metrics.f1_score``.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    print('==== Evaluation on the', split.upper(), ' set ====')
    model.eval()

    batch_preds = []
    batch_gold = []  # not called `labels`, to avoid shadowing the global label list
    with torch.no_grad():
        for data in tqdm(loader, total=len(loader)):
            data = data.to(device)
            # argmax over the logits selects the same class as argmax over their softmax.
            batch_preds.append(model(data).argmax(dim=1).cpu().numpy())
            batch_gold.append(data.y.cpu().numpy())

    if not batch_preds:
        raise ValueError(f'no batches to evaluate in the {split!r} loader')

    predictions = np.hstack(batch_preds)
    gold = np.hstack(batch_gold)

    # sklearn's argument order is (y_true, y_pred); swapping it exchanges precision and recall.
    print(classification_report(gold, predictions, digits=4))

    micro_F1 = metrics.f1_score(gold, predictions, average='micro')
    macro_F1 = metrics.f1_score(gold, predictions, average='macro')
    weighted_F1 = metrics.f1_score(gold, predictions, average='weighted')

    return (micro_F1, macro_F1, weighted_F1)

In [None]:
for epoch in range(5):
    loss = train(train_loader)
    val_acc = evaluate(val_loader, 'val')  # (micro_F1, macro_F1, weighted_F1) on the dev set
    logs.append((loss, val_acc))
    # Global epoch counter: keeps increasing if this cell is re-run.
    step = len(logs)

    writer.add_scalar("Learning_rate", optimizer_params['lr'], step)
    writer.add_scalar("Loss/train", loss, step)
    for tag, value in zip(("micro_F1/dev", "macro_F1/dev", "weighted_F1/dev"), val_acc):
        writer.add_scalar(tag, value, step)
    # The writer stays open across cells; push this epoch's events to disk now.
    writer.flush()

    print(f'Epoch: {step:03d}, Loss: {loss:.5f}')
    print('val_acc', val_acc)

In [None]:
# Final held-out evaluation: (micro_F1, macro_F1, weighted_F1) on the test split.
test_acc = evaluate(test_loader, split='test')
print('test_acc', test_acc)