In [None]:
import pandas as pd
import numpy as np
import os, sys

import torch
import transformers
import csv

from tqdm import tqdm 

from transformers import AutoTokenizer, AutoModel

In [None]:
# Sanity check: displays whether PyTorch can see a CUDA device in this kernel.
torch.cuda.is_available()

In [None]:
from sklearn.metrics import f1_score, precision_score, recall_score

def symbol_wize(y_true, y_pred):
    """Character-level F1 between two span annotations.

    Both arguments are whitespace-separated strings of ``start:end`` spans
    (end exclusive). Each span is expanded into the set of integer offsets it
    covers, and precision/recall/F1 are computed over those sets. A small
    epsilon keeps every division defined when a set is empty.

    Returns:
        The smoothed F1 value as a float.
    """
    eps = 1e-7

    def _covered_offsets(spans):
        # Expand every "start:end" item into the offsets start..end-1.
        offsets = set()
        for span in spans.split():
            bounds = span.split(':')
            offsets.update(range(int(bounds[0]), int(bounds[1])))
        return offsets

    gold = _covered_offsets(y_true)
    predicted = _covered_offsets(y_pred)

    n_hit = len(gold & predicted)
    n_missed = len(gold - predicted)
    n_spurious = len(predicted - gold)

    precision = (n_hit + eps) / (n_hit + n_spurious + eps)
    recall = (n_hit + eps) / (n_hit + n_missed + eps)

    return 2 * (precision * recall) / (precision + recall + 1e-7)

def get_rank(true_class, pred_class, true_span, pred_span):
    """Score one (sample, tag) pair for the symbol-wise metric.

    Args:
        true_class: Gold sentence label (1 = positive, 0 = negative).
        pred_class: Predicted sentence label.
        true_span: Gold span string (``"start:end ..."``) or a non-string
            placeholder (e.g. NaN) when the tag is absent.
        pred_span: Predicted span string, or a non-string placeholder.

    Returns:
        * the symbol-wise F1 of the two spans when both labels are 1 and both
          spans are strings;
        * 1 when both labels are 0;
        * 0 in every other case (label mismatch, or a missing span on a
          positive sample).
    """
    if true_class == pred_class == 1:
        # The string check must look at the *spans*: the class labels are
        # numeric here (they just compared equal to 1), so checking their
        # type could never succeed and every positive pair would score 0.
        if isinstance(true_span, str) and isinstance(pred_span, str):
            return symbol_wize(true_span, pred_span)
        return 0
    if true_class == pred_class == 0:
        return 1
    return 0

def gapping_metrics(true_data, pred_data, only_class=False):
    """Compute sentence-level F1 and (optionally) the mean span score.

    Args:
        true_data: Frame with a ``class`` column and the span columns
            ``cV``, ``cR1``, ``cR2``, ``R1``, ``R2``.
        pred_data: Frame with the same columns, row-aligned with ``true_data``.
        only_class: When true, skip the span scoring and report 0 for it.

    Returns:
        Dict with keys ``f1_score`` and ``f1_symbolwise_score``.
    """
    f1_class = f1_score(true_data['class'].values, pred_data['class'].values)

    f1_symbolwise_score = 0
    if not only_class:
        n_rows = len(true_data)
        per_span_scores = []
        # Scores are pooled over every (tag, row) pair before averaging.
        for tag in ('cV', 'cR1', 'cR2', 'R1', 'R2'):
            for row in range(n_rows):
                gold_row = true_data.iloc[row]
                pred_row = pred_data.iloc[row]
                per_span_scores.append(
                    get_rank(gold_row['class'], pred_row['class'], gold_row[tag], pred_row[tag])
                )
        f1_symbolwise_score = np.mean(per_span_scores)

    return {'f1_score': f1_class, 'f1_symbolwise_score': f1_symbolwise_score}


# BERT

# Data

In [None]:
# Load the training split. The file is tab-separated and read with QUOTE_NONE
# so quote characters inside the text are kept as ordinary characters.
train_df = pd.read_csv(os.path.join('data', 'train', 'train.csv'), sep="\t", quoting=csv.QUOTE_NONE)
# One dict per row, keyed by column name.
train_dict = train_df.to_dict("records")

In [None]:
# Load the gold-standard test split with the same parsing options as the training split.
test_df = pd.read_csv(os.path.join('data', 'test', 'test_gold_standard.csv'), sep="\t", quoting=csv.QUOTE_NONE)
# One dict per row, keyed by column name.
test_dict = test_df.to_dict("records")

In [None]:
# Hugging Face checkpoint id used for both the tokenizer and the encoder.
NAME_MODEL = 'sberbank-ai/ruRoberta-large'

In [None]:
# Tokenizer matching NAME_MODEL; used as a global by make_text_data.
# NOTE(review): truncation/padding are normally passed when *calling* the
# tokenizer (make_text_data does so explicitly) — confirm these load-time
# kwargs have any effect.
tokenizer = AutoTokenizer.from_pretrained(NAME_MODEL,
                                                    truncation=True,
                                                    padding=True)

In [None]:
# Token-level tag vocabulary; index 0 is the "no tag" label.
TAGS = ["[NONE]", "cV", "cR1", "cR2", "R1", "R2"]
# Tag string -> integer id (position in TAGS).
TAG2ID = {v: k for k, v in enumerate(TAGS)}

# Token-level gap vocabulary; "V" marks a token that starts a gap position.
GAPS = ["[NONE]", "V"]
# Gap string -> integer id (position in GAPS).
GAP2ID = {v: k for k, v in enumerate(GAPS)}

In [None]:
def make_text_data(train_dict, max_length=128):
    """Tokenize samples and align the character-span annotations to tokens.

    Relies on the notebook-level ``tokenizer``, ``TAGS``/``TAG2ID`` and
    ``GAPS``/``GAP2ID``.

    Args:
        train_dict: Iterable of row dicts with keys ``text``, ``class``, ``V``
            and one key per entry of ``TAGS[1:]``. Span fields are either NaN
            or a space-separated list of ``start:end`` character offsets.
        max_length: Length every sequence is padded/truncated to.

    Returns:
        A list with one dict per sample containing ``tokens_ids``,
        ``attention_mask``, ``tags_ids``, ``gaps_ids``, ``label`` and
        ``tokens_borders``. All per-token lists have exactly one entry per
        token, so they can be stacked into tensors.
    """
    special_tokens = ('[SEP]', '[CLS]', '<s>', '</s>')
    text_data = []

    for sample in tqdm(train_dict):
        text = sample['text']
        # One-for-one character replacement, so annotation offsets stay valid.
        text = text.replace("—", "-")

        tokenizer_out = tokenizer(text, padding="max_length", truncation=True, max_length=max_length)
        word_ids = tokenizer_out.word_ids()
        tokens_ids = tokenizer_out.input_ids
        tokens = tokenizer_out.tokens()

        # Character span of every token; [-1, -1] marks special/padding tokens.
        tokens_borders = []
        for i in range(len(tokens_ids)):
            # word_ids() is None for tokens the tokenizer added itself
            # (specials and padding), which does not depend on a particular
            # vocabulary id; the original id/string checks are kept as well.
            is_special = (
                word_ids[i] is None
                or tokens_ids[i] == 0
                or tokens[i] in special_tokens
            )
            token_border = None if is_special else tokenizer_out.token_to_chars(i)
            if token_border is None:
                tokens_borders.append([-1, -1])
            else:
                tokens_borders.append([token_border.start, token_border.end])

        # Flatten all annotated spans into (tag, left, right) triples.
        tags_borders = []
        for tag in TAGS[1:]:
            if not pd.isna(sample[tag]):
                for border in sample[tag].split(" "):
                    left, right = list(map(int, border.split(":")))
                    tags_borders.append((tag, left, right))

        # Exactly one tag per token: the first span that fully contains it.
        # Without the break, a token inside two overlapping spans would append
        # two tags and shift every later tag out of alignment.
        tags = []
        for token_left, token_right in tokens_borders:
            if token_left == -1 and token_right == -1:
                tags.append(TAGS[0])
                continue

            for tag, tag_left, tag_right in tags_borders:
                if tag_left <= token_left and token_right <= tag_right:
                    tags.append(tag)
                    break
            else:
                tags.append(TAGS[0])

        # Only the left offset of each "V" span is used.
        gap_index = []
        if not pd.isna(sample["V"]):
            for borders in sample["V"].split(" "):
                left, right = list(map(int, borders.split(":")))
                gap_index.append(left)

        # Mark the first token starting exactly at each gap offset.
        gaps = [GAPS[0]] * len(tokens_ids)
        for tag_left in gap_index:
            for i, (left, _right) in enumerate(tokens_borders):
                if tag_left == left:
                    gaps[i] = GAPS[1]
                    break

        tags_ids = [TAG2ID[tag] for tag in tags]
        gaps_ids = [GAP2ID[gap] for gap in gaps]
        attention_mask = tokenizer_out['attention_mask']
        label = int(sample['class'])

        text_data.append({
            'tokens_ids': tokens_ids,
            'attention_mask': attention_mask,
            'tags_ids': tags_ids,
            'gaps_ids': gaps_ids,
            'label': label,
            'tokens_borders': tokens_borders
        })

    return text_data

In [None]:
class ARRGDataset(torch.utils.data.Dataset):
    """Map-style dataset over a list of per-sample dicts.

    Every value of the selected dict is converted to a ``torch.long`` tensor
    when the item is fetched.
    """

    def __init__(self, text_data):
        self.text_data = text_data

    def __len__(self):
        return len(self.text_data)

    def __getitem__(self, idx):
        sample = self.text_data[idx]
        tensors = {}
        for field, values in sample.items():
            tensors[field] = torch.tensor(values, dtype=torch.long)
        return tensors

In [None]:
# Tokenize/align the training rows and wrap them in a tensor-producing dataset.
train_data = make_text_data(train_dict)
arrg_dataset_train = ARRGDataset(train_data)

In [None]:
# Same preprocessing for the test rows.
test_data = make_text_data(test_dict)
arrg_dataset_test = ARRGDataset(test_data)

# Model

In [None]:
from transformers import BertTokenizerFast, BertModel, BertConfig, AutoConfig, AutoModel, DistilBertConfig
from transformers import BertTokenizerFast, BertForTokenClassification, BertForSequenceClassification
from transformers import DistilBertTokenizerFast, DistilBertForTokenClassification, DistilBertForSequenceClassification
from transformers import Trainer, TrainingArguments

In [None]:
import torch.nn as nn

class BertAgrrModel(transformers.PreTrainedModel):
    """Encoder with two linear heads: sentence classification and token tagging.

    Args:
        name: Hugging Face checkpoint id of the encoder.
        out_size: Width of the encoder's hidden states. When left as ``None``
            it is read from the loaded encoder's config, so the heads always
            match the checkpoint (a fixed 768 only fits base-sized encoders).
    """

    def __init__(self, name='DeepPavlov/rubert-base-cased-sentence', out_size=None):
        # NOTE(review): `output_last_hidden_state` does not look like a standard
        # config option; it is kept as in the original — confirm it is needed.
        super(BertAgrrModel, self).__init__(config=AutoConfig.from_pretrained(name, output_last_hidden_state=True))
        self.bert = AutoModel.from_pretrained(name)

        if out_size is None:
            out_size = self.bert.config.hidden_size

        self.dropout = nn.Dropout(0.1)
        # 2 sentence classes; 6 token tags (the notebook's TAGS list has 6 entries).
        self.sentence_classifier = nn.Linear(out_size, 2)
        self.full_annotation_classifier = nn.Linear(out_size, 6)

    def forward(self, input_ids, attention_mask):
        """Run the encoder and both heads.

        Returns:
            Dict with ``sentence_logits`` / ``sentence_probs`` (computed from
            the first token's hidden state) and ``full_annotation_logits`` /
            ``full_annotation_probs`` (computed for every token). The
            ``*_probs`` entries are softmax-normalised over the class axis.
        """
        output = self.bert(input_ids=input_ids, attention_mask=attention_mask)

        sequence_output = self.dropout(output.last_hidden_state)
        # The first position's vector stands in for the whole sentence.
        pooled_output = sequence_output[:, 0, :]

        sentence_logits = self.sentence_classifier(pooled_output)
        full_annotation_logits = self.full_annotation_classifier(sequence_output)
        sentence_probs = torch.softmax(sentence_logits, dim=1)
        full_annotation_probs = torch.softmax(full_annotation_logits, dim=2)

        return {
            'sentence_logits': sentence_logits,
            'full_annotation_logits': full_annotation_logits,
            'sentence_probs': sentence_probs,
            'full_annotation_probs': full_annotation_probs
        }

# Train

In [None]:
from torch.optim import AdamW
from transformers import get_scheduler

In [None]:
# Number of passes over the training set.
EPOCH = 15

# The experiment was pinned to GPU #4. Fall back to the first GPU, or to the
# CPU, so the notebook still runs on machines without that device.
if torch.cuda.is_available():
    device = 'cuda:4' if torch.cuda.device_count() > 4 else 'cuda:0'
else:
    device = 'cpu'

In [None]:
# Build the two-headed model on top of the NAME_MODEL encoder and move it to `device`.
model = BertAgrrModel(NAME_MODEL).to(device)

In [None]:
from torch.utils.data import Dataset, DataLoader
# Training batches are reshuffled every epoch.
train_dataloader = DataLoader(arrg_dataset_train, shuffle=True, batch_size=24)
# Test order is kept fixed so predictions stay row-aligned with test_df.
test_dataloader = DataLoader(arrg_dataset_test, shuffle=False, batch_size=4)

In [None]:
# Loss for the sentence-level classifier.
criterion = torch.nn.CrossEntropyLoss()
# All model parameters (encoder and heads) are optimised with one learning rate.
optimizer = AdamW(model.parameters(), lr=1e-5)

In [None]:
f1_scores = []
all_loss = []

# Best sentence-level F1 seen so far. It lives outside the epoch loop so that
# a checkpoint is written only when the score actually improves.
max_score = -1.0


def _decode_pred_spans(borders, annot):
    """Turn per-token tag predictions for one sample into one span per tag.

    Args:
        borders: Integer array of shape (tokens, 2) with each token's
            [start, end] character offsets; [-1, -1] marks special tokens.
        annot: Integer array of shape (tokens,) with the predicted tag id.

    Returns:
        Dict tag name -> ``"left:right"``, spanning from the smallest start to
        the largest end of the tokens predicted with that tag, or ``"0:0"``
        when no real token carries the tag.
    """
    spans = {}
    for tag_id in range(1, len(TAGS)):
        tag_borders = borders[annot == tag_id]
        # Ignore special/padding tokens.
        is_sentinel = (tag_borders[:, 0] == -1) & (tag_borders[:, 1] == -1)
        tag_borders = tag_borders[~is_sentinel]
        if tag_borders.shape[0] == 0:
            spans[TAGS[tag_id]] = '0:0'
        else:
            left = int(tag_borders[:, 0].min())
            right = int(tag_borders[:, 1].max())
            spans[TAGS[tag_id]] = str(left) + ':' + str(right)
    return spans


for epoch in range(EPOCH):
    num_batches = 0
    losses = []

    # ---- Evaluation: runs before this epoch's training pass ----
    model.eval()
    print('EPOCH: {} Starting eval...'.format(epoch+1))
    pred_labeles = []
    pred_spans = {tag: [] for tag in TAGS[1:]}

    # Positive-class probabilities, kept for inspection after the loop.
    logits = []
    with torch.no_grad():  # no autograd graph is needed for evaluation
        for batch in tqdm(test_dataloader):
            token_ids = batch['tokens_ids']
            mask = batch['attention_mask']
            tokens_borders = batch['tokens_borders'].numpy()

            out = model(token_ids.to(device), mask.to(device))

            class_probs = out['sentence_probs'].to('cpu').numpy()[:, 1]
            logits += list(class_probs)
            pred_labeles += [int(prob >= 0.5) for prob in class_probs]

            full_annotation_probs = out['full_annotation_probs'].to('cpu').numpy()
            full_annotation_class = np.argmax(full_annotation_probs, axis=2)

            # Decode every sample independently so spans never leak from one
            # sample (or one tag) to the next.
            for border, annot in zip(tokens_borders, full_annotation_class):
                sample_spans = _decode_pred_spans(border, annot)
                for tag in TAGS[1:]:
                    pred_spans[tag].append(sample_spans[tag])

    pred_df = pd.DataFrame({
        'class': pred_labeles,
        'cV': pred_spans['cV'],
        'cR1': pred_spans['cR1'],
        'cR2': pred_spans['cR2'],
        'R1': pred_spans['R1'],
        'R2': pred_spans['R2']
    })

    metrics = gapping_metrics(test_df, pred_df)

    if metrics['f1_score'] > max_score:
        print('Saving model...')
        max_score = metrics['f1_score']
        checkpoint_path = os.path.join('checkpoints', NAME_MODEL + '_' + str(max_score) + '.model')
        # NAME_MODEL may contain '/', which puts the file in a sub-directory.
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
        torch.save(model.state_dict(), checkpoint_path)

    print('Eval metrics:', metrics)
    f1_scores.append(metrics)

    # ---- Training ----
    # NOTE(review): only the sentence-level loss is optimised; the token-tag
    # head is evaluated above but receives no training signal — confirm this
    # is intended.
    model.train()

    for i, batch in enumerate(train_dataloader):
        token_ids = batch['tokens_ids']
        mask = batch['attention_mask']
        labels = batch['label'].type(torch.long).to(device)

        out = model(token_ids.to(device), mask.to(device))

        # CrossEntropyLoss applies log-softmax itself, so it must receive the
        # raw logits rather than already-normalised probabilities.
        sentence_loss = criterion(out['sentence_logits'], labels)

        loss = sentence_loss

        loss.backward()
        losses.append(loss.item())

        optimizer.step()
        optimizer.zero_grad()

        # Report and record the running mean loss every 64 steps.
        if (i+1) % 64 == 0:
            print('Step: {}, Loss: {}'.format(i+1, np.mean(losses)))
            all_loss.append(np.mean(losses))
            losses = []


# UD

In [None]:
import sys

from deeppavlov import build_model, configs
from deeppavlov.models.morpho_tagger.common import call_model

from ufal.udpipe import Model as udModel, Pipeline

from Gapping.read_write import read_data, parse_ud_output

In [None]:
# Paths and switches for the UDPipe tagging/parsing script below.
ud_model_path = "models/russian-syntagrus-ud-2.3-181115.udpipe"
train_path = "Gapping/data/train.csv"
outfile = "results/train.out"
# tokenize: run the UDPipe tokenizer first; parse: run tagging + dependency parsing.
tokenize, parse = False, True
# Optional file for the raw tokenizer output (only used when tokenize is True).
tokenized_outfile = None


# Characters treated as hyphens/dashes by sanitize().
HYPHENS = "-—–"
# Quote characters (and the two-character `` and '' forms) normalised to '"'.
QUOTES = "«“”„»``''"

def cannot_be_before_hyphen(x):
    """Return True when ``x`` is not alphabetic, not a digit and not in HYPHENS."""
    return not x.isalpha() and not x.isdigit() and x not in HYPHENS

def fix_quotes(x):
    """Normalise quote tokens in tab-separated tagger output.

    Every non-blank line is stripped and split on tabs. If the second column
    is found in QUOTES, columns 2 and 3 become ``"``, column 4 becomes
    ``PUNCT`` and column 6 becomes ``_``. Blank lines are dropped.
    (Presumably the CoNLL-U FORM/LEMMA/UPOS/FEATS columns — confirm against
    the tagger's output format.)

    Returns:
        The rewritten lines joined with newlines.
    """
    fixed_rows = []
    for raw_row in x.split("\n"):
        row = raw_row.strip()
        if not row:
            continue
        columns = row.split("\t")
        if columns[1] in QUOTES:
            columns[1] = '"'
            columns[2] = '"'
            columns[3] = "PUNCT"
            columns[5] = "_"
        fixed_rows.append("\t".join(columns))
    return "\n".join(fixed_rows)
    
def sanitize(sent):
    """Normalise quotes and put spaces around hyphens that act as dashes.

    Every character found in QUOTES is replaced by ``"``. Then, for each
    hyphen character that is neither first nor last in the sentence:

    * letter before it and a non-alphanumeric, non-hyphen character after it:
      a space is inserted before the hyphen (and after it, unless a space
      already follows);
    * non-alphanumeric, non-hyphen character before it and a letter after it:
      a space is inserted after the hyphen (and before it, unless a space
      already precedes).

    All other hyphens are copied unchanged.
    """
    sent = "".join('"' if ch in QUOTES else ch for ch in sent)
    last = len(sent) - 1

    pieces = []
    cursor = 0
    for pos, ch in enumerate(sent):
        if ch not in HYPHENS:
            continue
        pieces.append(sent[cursor:pos])
        # Neighbours exist on both sides only for interior positions.
        interior = 0 < pos < last
        if interior and cannot_be_before_hyphen(sent[pos + 1]) and sent[pos - 1].isalpha():
            pieces.append(" " + ch)
            if sent[pos + 1] != " ":
                pieces.append(" ")
        elif interior and cannot_be_before_hyphen(sent[pos - 1]) and sent[pos + 1].isalpha():
            if sent[pos - 1] != " ":
                pieces.append(" ")
            pieces.append(ch + " ")
        else:
            pieces.append(ch)
        cursor = pos + 1
    pieces.append(sent[cursor:])
    return "".join(pieces)

if __name__ == "__main__":
    # Script: sanitise the training sentences, tag them with the DeepPavlov
    # morphological tagger, then dependency-parse them with UDPipe and write
    # the result to `outfile`.
    #
    # Stage-specific names are used on purpose: in a notebook this block runs
    # at module level, and rebinding `model` / `tokenizer` would silently
    # replace the BERT model and tokenizer used by the cells above.
    morpho_tagger = build_model(configs.morpho_tagger.UD2_0.morpho_ru_syntagrus_pymorphy_lemmatize, download=True)
    ud_model = udModel.load(ud_model_path)
    if ud_model is None:
        # Presumably udModel.load signals failure by returning None instead of
        # raising — verify against the ufal.udpipe docs; this guard is a no-op
        # otherwise.
        raise RuntimeError("Cannot load UDPipe model from {}".format(ud_model_path))
    sents, answers = read_data(train_path)
    symbols = sorted(set(a for sent in sents for a in sent))
    sents = [sanitize(sent) for sent in sents]

    if tokenize:
        tokenized_chunks, for_tagging = [], []
        ud_tokenizer = Pipeline(ud_model, "tokenize", Pipeline.NONE, Pipeline.NONE, "conllu")
        # Tokenize in chunks of 40 sentences.
        for start in range(0, len(sents), 40):
            if start % 400 == 0:
                print("{} sents processed".format(start))
            end = min(start + 40, len(sents))
            curr_output = ud_tokenizer.process("\n\n".join(sents[start:end]))
            tokenized_chunks.append(curr_output + "\n")
            curr_output = parse_ud_output(curr_output)
            # Keep only the second column of every token row as tagger input.
            for_tagging.extend([[elem[1] for elem in sent] for sent in curr_output])
        tokenized_data = "".join(tokenized_chunks)
        if tokenized_outfile is not None:
            with open(tokenized_outfile, "w", encoding="utf8") as fout:
                fout.write(tokenized_data)
    else:
        for_tagging = sents
    if parse:
        print(len(for_tagging))
        print("Tagging...")
        tagged_data = call_model(morpho_tagger, for_tagging, batch_size=64)
        tagged_data = [fix_quotes(elem) for elem in tagged_data]
        print("Tagging completed...")
        ud_parser = Pipeline(ud_model, "conllu", Pipeline.NONE, Pipeline.DEFAULT, "conllu")
        # Create the output directory up front so open() cannot fail on it.
        out_dir = os.path.dirname(outfile)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        with open(outfile, "w", encoding="utf8") as fout:
            # Parse in chunks of 16 sentences and append each chunk to the file.
            for start in range(0, len(tagged_data), 16):
                if start % 400 == 0:
                    print("{} sents processed".format(start))
                end = min(start + 16, len(tagged_data))
                parsed_data = ud_parser.process("\n\n".join(tagged_data[start:end]))
                fout.write(parsed_data)

# GNN

In [None]:
import torch
import os

# Export the installed torch version so the shell commands below can use it.
os.environ['TORCH_VERSION'] = torch.__version__
!echo $TORCH_VERSION

In [None]:
# Install PyTorch Geometric and its compiled extensions from the wheel index
# matching the torch version exported above.
# NOTE(review): package versions are not pinned — consider pinning for reproducibility.
!pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH_VERSION}.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-${TORCH_VERSION}.html
!pip install torch-cluster -f https://data.pyg.org/whl/torch-${TORCH_VERSION}.html
!pip install torch-spline-conv -f https://data.pyg.org/whl/torch-${TORCH_VERSION}.html
!pip install torch-geometric -f https://data.pyg.org/whl/torch-${TORCH_VERSION}.html

In [None]:
import torch
import numpy as np
import pandas as pd
import networkx as nx

from tqdm import tqdm 

import matplotlib.pyplot as plt

from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_networkx

import youtokentome as yttm

In [None]:
class UDDataset(Dataset):
    """Graph dataset built from parsed sentences and their class labels.

    Raw inputs (in the dataset's raw directory):
        * ``train.out`` — tab-separated parser output, one token per line,
          sentences separated by blank lines;
        * ``train.csv`` — tab-separated table with a ``class`` column.

    Each sentence becomes one graph: a one-hot part-of-speech vector per token
    as node features, one directed edge token -> head for every non-root
    token, and the sentence class as the graph label. Graphs are stored as
    ``UD_<idx>.pt`` in the processed directory.
    """

    # Part-of-speech vocabulary; a tag's position is its one-hot column.
    PROP = ['ADJ', 'ADP', 'ADV', 'AUX', 'CCONJ', 'DET', 'INTJ', 'NOUN', 'NUM',
            'PAD', 'PART', 'PRON', 'PROPN', 'PUNCT', 'SCONJ', 'SYM', 'VERB', 'X']

    def __init__(self, root, transform=None, pre_transform=None):
        super(UDDataset, self).__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        return ['train.out', 'train.csv']

    @property
    def processed_file_names(self):
        # NOTE(review): this file is never written by process(), so processing
        # presumably reruns on every construction; len() depends on that
        # because it reads self.parsed_sen — confirm this is intended.
        return 'no.pt'

    def download(self):
        pass

    @staticmethod
    def _read_parsed_sentences(path):
        """Read blank-line separated sentences as lists of tab-split token rows."""
        sentences, current = [], []
        with open(path, 'r', encoding="utf8") as parsed_file:
            for line in parsed_file:
                line = line.strip()
                if line == '':
                    if current:
                        sentences.append(current)
                        current = []
                    continue
                if line.startswith('#'):
                    # Comment/metadata lines carry no token.
                    continue
                current.append(line.split('\t'))
        # Keep the last sentence even if the file has no trailing blank line.
        if current:
            sentences.append(current)
        return sentences

    def process(self):
        import csv  # local import: this cell is also run as a standalone notebook

        self.parsed_sen = self._read_parsed_sentences(self.raw_paths[0])

        # QUOTE_NONE keeps one row per line even when the text contains quote
        # characters, matching how the BERT cells read the same data.
        labels = pd.read_csv(self.raw_paths[1], sep='\t', encoding='utf-8',
                             quoting=csv.QUOTE_NONE)['class'].values

        idx = 0
        for ud_graph, label in tqdm(zip(self.parsed_sen, list(labels))):
            edge_index = []
            x = []
            for term in ud_graph:
                feat_prop = [0] * len(self.PROP)
                feat_prop[self.PROP.index(term[3])] = 1
                x.append(feat_prop)

                # term[-4] is the head index; 0 marks the root, which has no edge.
                head = int(term[-4])
                if head != 0:
                    edge_index.append([int(term[0]) - 1, head - 1])

            # Float features can be fed to linear layers without a cast.
            x = torch.tensor(np.array(x), dtype=torch.float)

            # reshape(-1, 2) keeps the [2, num_edges] layout even when a
            # sentence has no edges at all (a single root token).
            edge_index = np.array(edge_index, dtype=np.int64).reshape(-1, 2).T
            edge_index = torch.tensor(edge_index, dtype=torch.long)

            label = torch.tensor([label])
            torch.save(Data(x=x, edge_index=edge_index, y=label),
                       os.path.join(self.processed_dir, f'UD_{idx}.pt'))
            idx += 1

    def len(self):
        return len(self.parsed_sen)

    def get(self, idx):
        return torch.load(os.path.join(self.processed_dir, f'UD_{idx}.pt'))

In [None]:
# Build the graph dataset from the given root directory.
# NOTE(review): hard-coded absolute path (looks like a Colab Drive mount) —
# consider a configurable data directory.
ud_dataset = UDDataset('/content/drive/MyDrive/Диплом/')

In [None]:
# Number of graphs per mini-batch, shared by all loaders below.
batch_size = 8
# Loader over the full dataset in storage order.
loader = DataLoader(ud_dataset, batch_size=batch_size)

In [None]:
# Contiguous 80% / 10% / 10% split in storage order (no shuffling).
n_graphs = len(ud_dataset)
train_end = int(n_graphs * 0.8)
val_end = int(n_graphs * 0.9)

train_loader = DataLoader(ud_dataset[:train_end], batch_size=batch_size)
val_loader = DataLoader(ud_dataset[train_end:val_end], batch_size=batch_size)
test_loader = DataLoader(ud_dataset[val_end:], batch_size=batch_size)

# GCN

In [None]:
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

In [None]:
from torch import nn

class GCN(MessagePassing):
    """Graph convolution layer with sum aggregation.

    Self-loops are added, features are linearly projected, and every message
    is scaled by ``1 / sqrt(deg(source) * deg(target))``, where the degrees
    are counted separately over the first and second rows of the edge index
    (self-loops included).
    """

    def __init__(self, in_channels, out_channels):
        super(GCN, self).__init__(aggr='add')
        self.lin = nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        x = self.lin(x)

        sources, targets = edge_index[0], edge_index[1]
        source_degree = degree(sources)[sources]
        target_degree = degree(targets)[targets]
        norm = 1 / torch.sqrt(source_degree * target_degree)

        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        # One scalar weight per edge, broadcast over the feature dimension.
        return norm.view(-1, 1) * x_j

In [None]:
import torch.nn as nn
import torch.nn.functional as F

import torch_geometric.nn as pyg_nn


class GNN(torch.nn.Module):
    """Graph classifier: three GCN layers, global max pooling, MLP head.

    Args:
        input_dim: Size of the node feature vectors.
        hidden_dim: Width of every hidden layer.
        output_dim: Number of graph classes.
        dropout: Dropout probability inside the MLP head.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.0):
        super(GNN, self).__init__()

        self.convs = pyg_nn.Sequential(
            'x, edge_index', [
                (GCN(input_dim, hidden_dim), 'x, edge_index -> x'),
                nn.ReLU(),
                (GCN(hidden_dim, hidden_dim), 'x, edge_index -> x'),
                nn.ReLU(),
                (GCN(hidden_dim, hidden_dim), 'x, edge_index -> x'),
                nn.ReLU(),
            ]
        )

        self.post_mp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, data):
        """Return one row of class logits per graph in the batch."""
        # .float() converts the features without moving them to another
        # device; casting through torch.FloatTensor would force them onto
        # the CPU.
        x, edge_index, batch = data.x.float(), data.edge_index, data.batch

        x = self.convs(x, edge_index)
        # Collapse node embeddings into one vector per graph.
        x = pyg_nn.global_max_pool(x, batch)
        x = self.post_mp(x)

        return x

In [None]:
def cross_entropy_loss(x, labels):
    """Cross-entropy between the logits ``x`` and the integer class ``labels``."""
    return F.cross_entropy(input=x, target=labels)


def train(model, optimizer, train_loader, val_loader, epochs):
    """Train a graph classifier and track its validation accuracy.

    Args:
        model: Callable module mapping a batch to class logits.
        optimizer: Optimizer over the model's parameters.
        train_loader: Iterable of training batches exposing ``.y`` labels.
        val_loader: Iterable of validation batches exposing ``.y`` labels.
        epochs: Number of passes over ``train_loader``.

    Returns:
        ``(model, train_loss, val_accuracy)`` where ``train_loss`` holds the
        mean batch loss per epoch and ``val_accuracy`` the fraction of
        correctly classified validation samples per epoch.
    """
    train_loss = []
    val_accuracy = []

    for epoch in tqdm(range(epochs)):
        batch_train_loss = []

        model.train()
        for batch in train_loader:
            # Clear gradients first so nothing stale leaks into this step.
            optimizer.zero_grad()

            logits = model(batch)
            loss = cross_entropy_loss(logits, batch.y)

            loss.backward()
            optimizer.step()

            batch_train_loss.append(loss.item())

        train_loss.append(np.mean(batch_train_loss))

        # Accuracy is counted over samples, not averaged over batches, so a
        # smaller final batch does not get extra weight.
        model.eval()
        n_correct, n_seen = 0, 0
        with torch.no_grad():
            for batch in val_loader:
                pred = torch.argmax(model(batch), dim=1)
                n_correct += int((pred == batch.y).sum())
                n_seen += int(batch.y.numel())

        val_accuracy.append(n_correct / n_seen if n_seen else float('nan'))

    return model, train_loss, val_accuracy


def plot_progress(train_loss, val_accuracy):
    """Plot per-epoch training loss and validation accuracy side by side.

    Args:
        train_loss: Sequence with one loss value per epoch.
        val_accuracy: Sequence with one accuracy value per epoch.
    """
    fig, (train_ax, val_ax) = plt.subplots(nrows=1, ncols=2, figsize=(15, 5))

    train_ax.plot(train_loss)
    train_ax.set_title('Train loss')
    train_ax.set_xlabel('Epoch')
    train_ax.set_ylabel('Loss')

    # The plotted series is accuracy, so the panel is labelled accordingly.
    val_ax.plot(val_accuracy)
    val_ax.set_title('Val accuracy')
    val_ax.set_xlabel('Epoch')
    val_ax.set_ylabel('Accuracy')

    plt.show()
    
def evaluate(model, loader):
    """Return the accuracy of ``model`` over every sample yielded by ``loader``.

    Args:
        model: Module mapping a batch to class logits of shape (samples, classes).
        loader: Iterable of batches exposing the true labels as ``.y``.

    Returns:
        Fraction of samples whose arg-max prediction equals the label, or NaN
        when the loader yields nothing.
    """
    model.eval()

    predictions, labels = [], []
    with torch.no_grad():  # inference only
        for batch in loader:
            predictions.append(torch.argmax(model(batch), dim=1).cpu())
            labels.append(batch.y.cpu())

    if not predictions:
        return float('nan')

    predictions = torch.cat(predictions).numpy()
    labels = torch.cat(labels).numpy()
    return np.mean(predictions == labels)

In [None]:
import torch.optim as optim

# GCN-based graph classifier: 18 input features (one-hot POS), 2 classes, dropout 0.2.
hidden_dim = 5
model = GNN(18, hidden_dim, 2, 0.2)

optimizer = optim.Adam(model.parameters(), lr=0.001)

In [None]:
# Train the GCN classifier and plot the learning curves.
epochs = 30
model, train_loss, val_accuracy = train(model, optimizer, train_loader, val_loader, epochs)
plot_progress(train_loss, val_accuracy)

# GAT

In [None]:
import torch_geometric.utils as pyg_utils

class GAT(MessagePassing):
    """Multi-head attention message-passing layer with sum aggregation.

    Features are projected to ``num_heads * out_channels``. For every edge an
    attention score per head is computed from the concatenated target and
    source features, passed through leaky-ReLU and softmax-normalised per
    ``index`` group. The weighted source features are then averaged over the
    heads, so the layer outputs ``out_channels`` features per node.
    """

    def __init__(self, in_channels, out_channels, num_heads):
        super(GAT, self).__init__(aggr='add')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_heads = num_heads

        self.lin = nn.Linear(in_channels, num_heads * out_channels, bias=False)
        # One attention vector per head over the [target ; source] features.
        self.att = nn.Parameter(torch.Tensor(1, num_heads, 2 * out_channels))
        nn.init.xavier_uniform_(self.att)

    def forward(self, x, edge_index):
        looped_edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        projected = self.lin(x)
        return self.propagate(edge_index=looped_edge_index, x=projected)

    def message(self, x_i, x_j, index):
        per_head = (-1, self.num_heads, self.out_channels)
        target = x_i.view(*per_head)
        source = x_j.view(*per_head)

        pair = torch.cat([target, source], dim=-1)
        scores = (pair * self.att).sum(dim=-1, keepdim=True)
        weights = pyg_utils.softmax(F.leaky_relu(scores), index=index)

        # Average the heads' weighted messages.
        return (weights * source).mean(dim=1)

In [None]:
import torch.nn as nn
import torch.nn.functional as F

import torch_geometric.nn as pyg_nn


class GNN(torch.nn.Module):
    """Graph classifier: three GAT layers, global max pooling, MLP head.

    Args:
        input_dim: Size of the node feature vectors.
        hidden_dim: Width of every hidden layer.
        output_dim: Number of graph classes.
        num_heads: Number of attention heads in every GAT layer.
        dropout: Dropout probability inside the MLP head.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_heads, dropout=0.0):
        super(GNN, self).__init__()

        # The convolution stack below is fixed at three layers; the attribute
        # is kept for introspection instead of reading an undefined global.
        self.num_layers = 3
        self.dropout = dropout

        self.convs = pyg_nn.Sequential(
            'x, edge_index', [
                (GAT(input_dim, hidden_dim, num_heads), 'x, edge_index -> x'),
                nn.ReLU(),
                (GAT(hidden_dim, hidden_dim, num_heads), 'x, edge_index -> x'),
                nn.ReLU(),
                (GAT(hidden_dim, hidden_dim, num_heads), 'x, edge_index -> x'),
                nn.ReLU(),
            ]
        )

        self.post_mp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, data):
        """Return one row of class logits per graph in the batch."""
        # Linear layers need floating-point input; one-hot features may be
        # stored as integers.
        x, edge_index, batch = data.x.float(), data.edge_index, data.batch

        x = self.convs(x, edge_index)
        # Collapse node embeddings into one vector per graph.
        x = pyg_nn.global_max_pool(x, batch)

        x = self.post_mp(x)

        return x

In [None]:
import torch.optim as optim

# GAT-based graph classifier: 18 input features (one-hot POS), 2 classes,
# 3 attention heads, dropout 0.2. All arguments are passed by keyword, since
# positional arguments may not follow keyword arguments.
hidden_dim = 5
model = GNN(input_dim=18, hidden_dim=hidden_dim, output_dim=2, num_heads=3, dropout=0.2)

optimizer = optim.Adam(model.parameters(), lr=0.001)

In [None]:
# Train the GAT classifier and plot the learning curves.
epochs = 30
model, train_loss, val_accuracy = train(model, optimizer, train_loader, val_loader, epochs)
plot_progress(train_loss, val_accuracy)

# Full annotation

In [None]:
import torch.nn as nn
import torch.nn.functional as F

import torch_geometric.nn as pyg_nn


class GNN(torch.nn.Module):
    """Node classifier: a stack of GCN layers followed by a per-node MLP head.

    Args:
        input_dim: Size of the node feature vectors.
        hidden_dim: Width of every hidden layer.
        output_dim: Number of node classes.
        num_layers: Number of GCN layers. This is the fourth positional
            argument, so the depth comes from the call instead of a
            notebook-level global.
        dropout: Dropout probability inside the MLP head.
        num_heads: Unused; accepted only so callers that passed it by keyword
            keep working.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=1, dropout=0.0, num_heads=None):
        super(GNN, self).__init__()

        self.num_layers = num_layers
        self.dropout = dropout

        self.convs = nn.ModuleList()
        self.convs.append(GCN(input_dim, hidden_dim))
        for _ in range(self.num_layers - 1):
            self.convs.append(GCN(hidden_dim, hidden_dim))

        self.post_mp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(self.dropout),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, data):
        """Return one row of class logits per node."""
        # Linear layers need floating-point input.
        x, edge_index = data.x.float(), data.edge_index

        for conv in self.convs:
            x = F.relu(conv(x, edge_index))

        # No pooling: the head is applied to every node separately.
        x = self.post_mp(x)

        return x

In [None]:
def train_for_node_classification(model, optimizer, graph, epochs):
    """Full-batch training of a node classifier on a single graph.

    Args:
        model: Module mapping the graph to per-node class logits.
        optimizer: Optimizer over the model's parameters.
        graph: Object exposing ``y``, ``train_mask`` and ``val_mask``.
        epochs: Number of optimisation steps (one per epoch).

    Returns:
        ``(model, train_loss, val_accuracy)`` with one float loss and one
        validation accuracy per epoch.
    """
    train_loss = []
    val_accuracy = []

    for epoch in tqdm(range(epochs)):
        model.train()
        optimizer.zero_grad()

        logits = model(graph)

        # Only the training nodes contribute to the loss.
        loss = cross_entropy_loss(logits[graph.train_mask], graph.y[graph.train_mask])
        loss.backward()
        optimizer.step()

        train_loss.append(loss.item())

        model.eval()
        with torch.no_grad():  # validation needs no autograd graph
            pred = model(graph).argmax(dim=1)

        val_pred = pred[graph.val_mask]
        val_labels = graph.y[graph.val_mask]

        val_accuracy.append(np.mean((val_labels == val_pred).cpu().numpy()))

    return model, train_loss, val_accuracy


def evaluate_node_classification(model, graph):
    """Return the accuracy of ``model`` on the graph's test nodes.

    Args:
        model: Module mapping the graph to per-node class logits.
        graph: Object exposing ``y`` and a boolean ``test_mask``.

    Returns:
        Fraction of test nodes whose arg-max prediction equals the label.
    """
    model.eval()
    mask = graph.test_mask

    with torch.no_grad():  # inference only
        predictions = model(graph).argmax(dim=1)

    # Select the test nodes on the tensor side, then convert once.
    correct = predictions[mask] == graph.y[mask]
    return np.mean(correct.cpu().numpy())

In [None]:
# Hyperparameters for the node-classification model.
num_layers = 3
dropout = 0.2
hidden_dim = 16
lr = 0.001

# NOTE(review): `dataset` is not defined anywhere in the visible notebook —
# confirm which dataset object is expected here before running this cell.
model = GNN(dataset.num_node_features, hidden_dim, dataset.num_classes, num_layers, dropout)
optimizer = optim.Adam(model.parameters(), lr=lr)

In [None]:
# Train on the first graph of `dataset` and plot the learning curves.
# NOTE(review): relies on the same undefined `dataset` as the setup cell above it.
epochs = 200
model, train_loss, val_accuracy = train_for_node_classification(model, optimizer, dataset[0], epochs)
plot_progress(train_loss, val_accuracy)