# Building a Graph Neural Network for an NLP Task

In [None]:
# Import the core dependencies and report the installed torch version.
# (Nothing is installed here; the packages must already be available.)
import os
import torch
import json
# Expose the torch version through the TORCH environment variable.
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

# Imports used for graph visualization.
%matplotlib inline
import networkx as nx
import matplotlib.pyplot as plt

# KoNLPy's Okt tagger and gensim's Word2Vec loader.
from konlpy.tag import Okt
from gensim.models import Word2Vec
Tokenizer = Okt()  # NOTE(review): not referenced again in the visible cells.

# Parent directory of the current working directory, printed for reference.
path = os.path.dirname(os.getcwd())
print(path)
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

In [None]:
# Shared configuration dictionary read by the data modules, the models and
# the Lightning module defined in the cells below.
base_params = {

    # Which model to build: one of ["GNN", "RNN", "LSTM", "BERT", "FAST"].
    "NAME": "GNN",
    
    # Path params
    "DATA": os.path.join("NLP_data", "nsmc_test.json"),  # JSON file read by every data module
    "MODEL_NAME": "ko_word2vec.model",  # Word2Vec model loaded by the graph node/edge encoders
    "VOCAB": "ko_word2vec.model",  # Word2Vec model the sequence vocabulary is built from

    # Data params
    "INPUT_DIM": 128,  # size of each node feature vector in the graph dataset
    "N_CLASSES": 2,
    "MAX_LENGTH": 64,  # max token-sequence length (sequence and BERT datasets)
    "PADDING": False,  # when True, NSMC_seq pads every sample to MAX_LENGTH itself

    # Model params
    "HIDDEN": 128,
    "EMBEDDING_DIM": 128,  # embedding size of the RNN, LSTM and FastText models
    "N_LAYER": 2,  # LSTM only
    "BIDIRECT": True,  # LSTM only
    "DROPOUT": 0.5,  # passed to the LSTM and BERT classifiers
   

    # Training params
    "BATCH_FIRST": True,  # NOTE(review): not read by any visible cell
    "LR": 0.0001,
    "WARMUP_RATIO": 0.2,  # fraction of the training steps used for LR warm-up
    "MAX_EPOCHES": 100,
    "BATCH_SZ": 128
}

In [None]:
import json
import os

class Vocab:
    """Word <-> index mapping built from a saved gensim Word2Vec model.

    The five special tokens ``[PAD]``, ``[SOS]``, ``[EOS]``, ``[CLS]`` and
    ``[UNK]`` occupy indices 0-4; the Word2Vec vocabulary follows in
    ``index_to_key`` order.

    Args:
        vocab_path: Path handed to ``Word2Vec.load``.
        name: Label stored on the instance as ``self.name``.

    Attributes set by ``__init__``: ``vocab`` (list of tokens),
    ``word2index``, ``index2word`` and ``n_words``.
    """

    def __init__(self, vocab_path, name="Vocab"):
        self.name = name

        model = Word2Vec.load(vocab_path)
        self.vocab = ["[PAD]", "[SOS]", "[EOS]", "[CLS]", "[UNK]"] + list(model.wv.index_to_key)

        self.word2index = {word: i for i, word in enumerate(self.vocab)}
        self.index2word = dict(enumerate(self.vocab))
        self.n_words = len(self.vocab)  # includes the five special tokens

    @staticmethod
    def handle_unknown_word():
        """Placeholder for out-of-vocabulary handling; currently does nothing.

        Declared static because it takes no arguments: without the decorator
        calling it on an instance raised ``TypeError``. Calling it on the
        class keeps working as before.
        """
        return None

# Build the vocabulary from the saved Word2Vec model.
# NOTE(review): the path is hard-coded here; base_params["VOCAB"] currently
# holds the same value.
kor_vocab = Vocab("ko_word2vec.model")
# kor_sign_vocab.index2word[1]

# Indices of the special tokens, kept as module-level constants for later cells.
PAD_IDX = kor_vocab.word2index["[PAD]"]
EOS_IDX = kor_vocab.word2index["[EOS]"]
UNK_IDX = kor_vocab.word2index["[UNK]"]

# Record the vocabulary size so the embedding layers can be sized from it.
base_params["N_VOCAB"] = kor_vocab.n_words
print(kor_vocab.n_words)


## Build dataset

### BERT dataset

In [None]:
# from kobert import get_tokenizer
# from kobert import get_pytorch_kobert_model

# bertmodel, vocab = get_pytorch_kobert_model(cachedir=".cache")

In [None]:
from torch.utils.data import Dataset, DataLoader
import gluonnlp as nlp
import numpy as np

class BERTDataset(Dataset):
    """Dataset that pre-tokenizes every record for BERT at construction time.

    Args:
        dataset: Sequence of records; element 0 is passed to the sentence
            transform and element 2 is stored as an ``np.int32`` label.
        bert_tokenizer: Tokenizer handed to ``BERTSentenceTransform``.
        max_len: Maximum sequence length for the transform (padding enabled).
    """

    def __init__(self, dataset, bert_tokenizer, max_len):
        to_bert_input = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, pad=True, pair=False)

        # First pass: transformed sentences.
        encoded = []
        for record in dataset:
            encoded.append(to_bert_input([record[0]]))
        self.sentences = encoded

        # Second pass: labels.
        targets = []
        for record in dataset:
            targets.append(np.int32(record[2]))
        self.labels = targets

    def __getitem__(self, i):
        """Return the transformed sentence tuple with the label appended."""
        features = self.sentences[i]
        target = (self.labels[i], )
        return features + target

    def __len__(self):
        """Return the number of samples."""
        n_samples = len(self.labels)
        return n_samples

In [None]:
import pytorch_lightning as pl
from sklearn.model_selection import train_test_split
import random

class BERTData_Module(pl.LightningDataModule):
    """Lightning data module that serves BERT-tokenized batches.

    Args:
        args: Configuration dict; reads ``DATA`` (JSON path), ``BATCH_SZ``,
            ``VOCAB`` and ``MAX_LENGTH``.

    NOTE(review): ``prepare_data`` relies on the module-level names
    ``get_tokenizer`` and ``vocab``. In the visible cells their only
    definition (the KoBERT import cell) is commented out, so that cell must
    be enabled before this class is used.
    """

    def __init__(self, args):
        super().__init__()
        self.data_dir = args["DATA"]
        self.batch_size = args["BATCH_SZ"]
        self.vocab_dir = args["VOCAB"]
        self.args = args
        # Data is loaded eagerly so setup() can be called right after construction.
        self.prepare_data()

    def prepare_data(self):
        """Load the JSON records, keep a random fifth of them and build the tokenizer."""
        # JSON is UTF-8 by specification; do not depend on the platform default
        # encoding.
        with open(self.data_dir, 'r', encoding='utf-8') as f:
            records = json.load(f)

        # Drop records whose token list (element 1) is empty.
        non_empty = [record for record in records if len(record[1]) > 0]
        # self.sentences = self.sentences[:len(self.sentences)//10]
        self.sentences = random.sample(non_empty, len(non_empty)//5)

        tokenizer = get_tokenizer()
        self.tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)

        print('Vocab size: ', len(vocab))
        print("load %d samples"%(len(self.sentences)))

    def setup(self, stage: str = None):
        """Create the splits for ``stage`` ('fit', 'test', or None for both)."""
        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            self._train, self._val = train_test_split(self.sentences, test_size=0.3)

        # Assign test dataset for use in dataloader(s).
        # The test split is the whole sampled set, not a held-out part.
        if stage == 'test' or stage is None:
            self._test = self.sentences

    def _make_loader(self, samples, shuffle):
        """Wrap ``samples`` in a BERTDataset and return a DataLoader over it."""
        return torch.utils.data.DataLoader(
                BERTDataset(samples, bert_tokenizer=self.tok, max_len=self.args["MAX_LENGTH"]),
                batch_size=self.batch_size,
                num_workers= 8,
                shuffle=shuffle)

    def train_dataloader(self):
        """Return the shuffled training loader."""
        return self._make_loader(self._train, shuffle=True)

    def val_dataloader(self):
        """Return the validation loader."""
        return self._make_loader(self._val, shuffle=False)

    def test_dataloader(self):
        """Return the test loader."""
        return self._make_loader(self._test, shuffle=False)

In [None]:
# Smoke test: build the BERT data module and print its first training batch.
datamodule = BERTData_Module(args=base_params)
datamodule.setup(stage="fit")

for batch in datamodule.train_dataloader():
    # print(batch.size())
    print(batch)
    break  # a single batch is enough for inspection

### Graph dataset

In [None]:
from torch_geometric.data import Data, Dataset
import itertools

class EdgeEncoder(object):
    """Edge-weight encoder: Word2Vec similarity between two tokens.

    Args:
        model_name: Path handed to ``Word2Vec.load``.
        dtype: torch dtype of the returned scalar tensor (``None`` keeps
            torch's default floating-point dtype).
    """

    def __init__(self, model_name='ko_word2vec.model', dtype=None):
        self.dtype = dtype
        self.model = Word2Vec.load(model_name)

    def __call__(self, node1, node2):
        """Return the similarity of ``node1`` and ``node2`` as a 0-d tensor.

        A token missing from the Word2Vec vocabulary yields weight 0. Only
        that lookup failure (``KeyError``) is handled; any other error
        propagates and is no longer hidden by a bare ``except``.
        """
        try:
            similarity = float(self.model.wv.similarity(node1, node2))
        except KeyError:
            similarity = 0.0
        return torch.tensor(similarity, dtype=self.dtype)

class SequenceEncoder(object):
    """Node-feature encoder: Word2Vec vector of a single token.

    Args:
        model_name: Path handed to ``Word2Vec.load``.
        dtype: torch dtype of the returned tensor (``None`` keeps the dtype
            of the stored vector / torch's default for the zero fallback).
        size: Expected length of every word vector.
    """

    def __init__(self, model_name='ko_word2vec.model', dtype=None, size=128):

        self.model = Word2Vec.load(model_name)
        self.dtype = dtype
        self.size = size

    def __call__(self, word):
        """Return the vector of ``word`` as a 1-d tensor of length ``size``.

        A token missing from the vocabulary yields a zero vector. Only that
        lookup failure (``KeyError``) is handled: a vector whose length does
        not match ``size`` now raises instead of silently becoming zeros.
        """
        try:
            vector = self.model.wv[word]
        except KeyError:
            return torch.zeros(self.size, dtype=self.dtype)

        features = torch.from_numpy(vector).view(self.size)
        if self.dtype is None:
            return features
        return features.to(self.dtype)


class NSMC_Graph(Dataset):
    """Graph dataset: every sentence becomes a fully connected token graph.

    Nodes carry the encoded token vectors; an edge is created for every
    ordered pair of token positions (self-pairs included) and weighted by the
    edge encoder.

    Args:
        sentences: Sequence of records; element 1 is the token list and
            element 2 the label.
        args: Configuration dict; reads ``INPUT_DIM`` and ``MODEL_NAME``.
    """

    def __init__(self, sentences, args):
        super().__init__()

        self.sentences = sentences

        self.size = args["INPUT_DIM"]
        self.node_encoder = SequenceEncoder(args["MODEL_NAME"], dtype=torch.float, size=self.size)
        self.edge_encoder = EdgeEncoder(args["MODEL_NAME"], dtype=torch.float)

    def len(self):
        """Return the number of sentences."""
        return len(self.sentences)

    def get(self, idx, return_sample=False):
        """Build the graph for record ``idx``.

        Returns the ``Data`` object, or ``(Data, tokens)`` when
        ``return_sample`` is True.
        """
        record = self.sentences[idx]
        tokens = record[1]
        n_tokens = len(tokens)

        # Node features, one row per token.
        node_features = torch.stack([self.node_encoder(tok) for tok in tokens])

        # Every ordered (source, target) pair of positions, in row-major order.
        pairs = [(src, dst) for src in range(n_tokens) for dst in range(n_tokens)]
        weights = torch.stack([self.edge_encoder(tokens[src], tokens[dst]) for src, dst in pairs])

        graph = Data(
            x=node_features,
            edge_index=torch.tensor(pairs).T,
            edge_weight=weights,
            y=torch.tensor(record[2]).to(torch.long),
        )
        if return_sample == True:
            return graph, tokens
        return graph



In [None]:
# Smoke test: load the raw JSON and build graphs from the first 200 records.
# NOTE(review): unlike GraphDataModule, records with an empty token list are
# not filtered out here.
with open(os.path.join("NLP_data", "nsmc_test.json"), 'r') as f:
    nscm_data = json.load(f)
   

dataset = NSMC_Graph(nscm_data[:200], base_params)
data = dataset[5]
print(data)

In [None]:
from torch_geometric.utils import to_networkx

def draw_graph(g, labels, edge_mask=None, draw_edge_labels=False):
    """Draw a PyG graph with matplotlib, optionally styled by edge weight.

    Args:
        g: PyG ``Data`` object carrying an ``edge_weight`` attribute.
        labels: Sequence of node labels; position ``i`` labels node ``i``.
        edge_mask: ``None`` draws plain black edges. Any other value colours
            and scales the edges by their ``edge_weight``. If it is a mapping
            ``{(u, v): value}`` its values are used as the edge labels.
        draw_edge_labels: When True (and ``edge_mask`` is not None) print a
            two-decimal label on every edge.
    """
    # g = g.copy().to_undirected()
    graph = to_networkx(g, edge_attrs=["edge_weight"], to_undirected=True)
    node_labels = dict(enumerate(labels))

    # pos = nx.planar_layout(g)
    # pos = nx.spring_layout(g, pos=pos)
    if edge_mask is None:
        edge_color = 'black'
        widths = None
    else:
        edge_color = [attrs["edge_weight"] for _, _, attrs in graph.edges(data=True)]
        widths = [weight * 5 for weight in edge_color]

    # Compute the layout once so nodes and edge labels share the same
    # positions. spring_layout is also what nx.draw falls back to when no
    # positions are given.
    pos = nx.spring_layout(graph)

    # fontprop = fm.FontProperties(fname='NanumGothic.otf', size=18)
    nx.draw(graph, pos=pos, labels=node_labels, width=widths,
            edge_color=edge_color, edge_cmap=plt.cm.Blues,
            node_color='azure', font_family="NanumBarunGothic")

    if draw_edge_labels and edge_mask is not None:
        # A mapping supplies the label values; any other truthy mask (the
        # notebook passes True) falls back to the graph's own edge weights.
        if hasattr(edge_mask, "items"):
            label_values = dict(edge_mask.items())
        else:
            label_values = nx.get_edge_attributes(graph, "edge_weight")
        edge_labels = {edge: ('%.2f' % value) for edge, value in label_values.items()}
        # draw_networkx_edge_labels requires the node positions.
        nx.draw_networkx_edge_labels(graph, pos, edge_labels=edge_labels,
                                    font_color='red')
    plt.show()

In [None]:
import random

# Pick a random sample, fetch its graph together with the token list, and draw it.
i = random.choice(range(dataset.len()))
data = dataset.get(i, return_sample=True)  # (Data, tokens) tuple
print(data)
plt.figure(figsize=(10, 5))

# Any non-None edge_mask makes draw_graph style the edges by their weight.
draw_graph(g = data[0], labels=data[1], edge_mask=True)

In [None]:
import pytorch_lightning as pl
from sklearn.model_selection import train_test_split
import random
# from torch_geometric.loader import DataLoader
import torch_geometric

class GraphDataModule(pl.LightningDataModule):
    """Lightning data module that serves batched token graphs.

    Args:
        args: Configuration dict; reads ``DATA`` (JSON path) and ``BATCH_SZ``
            and is forwarded unchanged to ``NSMC_Graph``.
    """

    def __init__(self, args):
        super().__init__()
        self.data_dir = args["DATA"]
        self.batch_size = args["BATCH_SZ"]
        self.args = args
        # Data is loaded eagerly so setup() can be called right after construction.
        self.prepare_data()

    def prepare_data(self):
        """Load the JSON records and keep a random fifth of the non-empty ones."""
        # JSON is UTF-8 by specification; do not depend on the platform default
        # encoding.
        with open(self.data_dir, 'r', encoding='utf-8') as f:
            records = json.load(f)

        # Drop records whose token list (element 1) is empty.
        non_empty = [record for record in records if len(record[1]) > 0]
        # self.sentences = self.sentences[:len(self.sentences)//10]
        self.sentences = random.sample(non_empty, len(non_empty)//5)
        print("load %d samples"%(len(self.sentences)))

    def setup(self, stage: str = None):
        """Create the splits for ``stage`` ('fit', 'test', or None for both)."""
        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            self._train, self._val = train_test_split(self.sentences, test_size=0.3)

        # Assign test dataset for use in dataloader(s).
        # The test split is the whole sampled set, not a held-out part.
        if stage == 'test' or stage is None:
            self._test = self.sentences

    def _make_loader(self, sentences, shuffle):
        """Wrap ``sentences`` in an NSMC_Graph and return a PyG DataLoader over it."""
        return torch_geometric.loader.DataLoader(
                NSMC_Graph(sentences, args=self.args),
                batch_size=self.batch_size,
                shuffle=shuffle)

    def train_dataloader(self):
        """Return the shuffled training loader."""
        return self._make_loader(self._train, shuffle=True)

    def val_dataloader(self):
        """Return the validation loader."""
        return self._make_loader(self._val, shuffle=False)

    def test_dataloader(self):
        """Return the test loader."""
        return self._make_loader(self._test, shuffle=False)

In [None]:
# Smoke test: build the graph data module and print its first training batch.
datamodule = GraphDataModule(args=base_params)
datamodule.setup(stage="fit")

for batch in datamodule.train_dataloader():
    print(batch)
    break  # a single batch is enough for inspection

### Sequence dataset

In [None]:
from torch.utils.data import DataLoader, Dataset

class NSMC_seq(Dataset):
    """Sequence dataset that turns token lists into vocabulary indices.

    Args:
        data: Sequence of ``(_, tokens, label)`` records.
        vocab: Object exposing a ``word2index`` mapping that contains the
            special tokens ``[SOS]``, ``[EOS]``, ``[UNK]`` and ``[PAD]``.
        args: Configuration dict; reads ``PADDING`` and ``MAX_LENGTH``.
    """

    def __init__(self, data, vocab, args):
        self.data, self.vocab = data, vocab
        self.padding = args["PADDING"]
        self.max_length = args["MAX_LENGTH"]

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(token_ids, label)`` for record ``idx``.

        The ids are ``[SOS]`` + at most ``MAX_LENGTH - 2`` word ids + ``[EOS]``;
        words missing from the vocabulary map to ``[UNK]``. When padding is
        enabled the list is filled with ``[PAD]`` up to ``MAX_LENGTH``.
        """
        _, item, label = self.data[idx]
        lookup = self.vocab.word2index

        # Map every word first, then truncate to leave room for SOS and EOS.
        encoded = []
        for token in item:
            if token in lookup:
                encoded.append(lookup[token])
            else:
                encoded.append(lookup["[UNK]"])
        tokens = [lookup["[SOS]"]] + encoded[:self.max_length-2] + [lookup["[EOS]"]]

        if self.padding == True:
            n_missing = self.max_length - len(tokens)
            tokens.extend(lookup["[PAD]"] for _ in range(n_missing))

        return (tokens, label)

def my_collate(batch, pad_idx=None):
    """Collate ``(token_ids, label)`` samples into padded integer tensors.

    Args:
        batch: Non-empty list of samples; element 0 is a list of token ids
            and element 1 the label.
        pad_idx: Id used to right-pad shorter sequences. Defaults to the
            module-level ``PAD_IDX`` when omitted, which is what a DataLoader
            gets since it calls the function with ``batch`` only.

    Returns:
        Dict with ``input_ids`` (batch x longest sequence), ``src_length``
        (unpadded length of every sample) and ``labels``; all int64 tensors.

    Raises:
        ValueError: If ``batch`` is empty.
    """
    if pad_idx is None:
        pad_idx = PAD_IDX

    lengths = [len(item[0]) for item in batch]
    labels = [item[1] for item in batch]
    width = max(lengths)

    # Right-pad every sequence to the longest one in this batch.
    padded = [item[0] + [pad_idx] * (width - len(item[0])) for item in batch]

    return {
        "input_ids": torch.tensor(padded, dtype=int),
        "src_length": torch.tensor(lengths, dtype=int),
        "labels": torch.tensor(labels, dtype=int)
    }

In [None]:
import pytorch_lightning as pl
from sklearn.model_selection import train_test_split
import random

class SeqDataModule(pl.LightningDataModule):
    """Lightning data module that serves padded token-id batches.

    Args:
        args: Configuration dict; reads ``DATA`` (JSON path), ``BATCH_SZ``
            and ``VOCAB`` and is forwarded unchanged to ``NSMC_seq``.
    """

    def __init__(self, args):
        super().__init__()
        self.data_dir = args["DATA"]
        self.batch_size = args["BATCH_SZ"]
        self.vocab_dir = args["VOCAB"]
        self.args = args
        # Data is loaded eagerly so setup() can be called right after construction.
        self.prepare_data()

    def prepare_data(self):
        """Load the JSON records, keep a random fifth of them and build the vocabulary."""
        # JSON is UTF-8 by specification; do not depend on the platform default
        # encoding.
        with open(self.data_dir, 'r', encoding='utf-8') as f:
            records = json.load(f)

        # Drop records whose token list (element 1) is empty.
        non_empty = [record for record in records if len(record[1]) > 0]
        # self.sentences = self.sentences[:len(self.sentences)//10]
        self.sentences = random.sample(non_empty, len(non_empty)//5)

        self.vocab = Vocab(self.vocab_dir)
        print('Vocab size: ', self.vocab.n_words)
        print("load %d samples"%(len(self.sentences)))

    def setup(self, stage: str = None):
        """Create the splits for ``stage`` ('fit', 'test', or None for both)."""
        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            self._train, self._val = train_test_split(self.sentences, test_size=0.3)

        # Assign test dataset for use in dataloader(s).
        # The test split is the whole sampled set, not a held-out part.
        if stage == 'test' or stage is None:
            self._test = self.sentences

    def _make_loader(self, sentences, shuffle):
        """Wrap ``sentences`` in an NSMC_seq and return a DataLoader using my_collate."""
        return torch.utils.data.DataLoader(
                NSMC_seq(sentences, self.vocab, args=self.args),
                batch_size=self.batch_size,
                num_workers= 8,
                shuffle=shuffle,
                collate_fn=my_collate)

    def train_dataloader(self):
        """Return the shuffled training loader."""
        return self._make_loader(self._train, shuffle=True)

    def val_dataloader(self):
        """Return the validation loader."""
        return self._make_loader(self._val, shuffle=False)

    def test_dataloader(self):
        """Return the test loader."""
        return self._make_loader(self._test, shuffle=False)

In [None]:
# Smoke test: build the sequence data module and print its first training batch.
datamodule = SeqDataModule(args=base_params)
datamodule.setup(stage="fit")

for batch in datamodule.train_dataloader():
    print(batch["input_ids"].size())  # size of the padded id tensor
    print(batch)
    break  # a single batch is enough for inspection

## Define model

In [None]:
def initialize_weights(m):
    """Xavier-uniform initialise every weight matrix of ``m`` and its sub-modules.

    Previously only ``m.weight`` itself was inspected, so calling the function
    on a container model (which has no ``weight`` attribute of its own) did
    nothing. It now walks ``m.modules()``, so both ``initialize_weights(model)``
    and ``model.apply(initialize_weights)`` initialise the whole model.

    Only tensors with more than one dimension are touched (biases are left
    alone). For an ``nn.Embedding`` with a ``padding_idx`` the padding row is
    reset to zero afterwards.

    Args:
        m: A ``torch.nn.Module``, or any object exposing a ``weight`` tensor.
    """
    modules = m.modules() if isinstance(m, torch.nn.Module) else [m]
    for module in modules:
        weight = getattr(module, 'weight', None)
        if not isinstance(weight, torch.Tensor) or weight.dim() <= 1:
            continue
        torch.nn.init.xavier_uniform_(weight.data)

        # Keep the padding embedding at zero, as nn.Embedding sets it up.
        if isinstance(module, torch.nn.Embedding) and module.padding_idx is not None:
            weight.data[module.padding_idx].zero_()

### Graph model

In [None]:
import torch
from torch.nn import Linear
import torch.nn.functional as F

from torch_geometric.nn import GraphConv, global_add_pool

class Net(torch.nn.Module):
    """Two GraphConv layers, sum pooling per graph, and a two-layer MLP head.

    Args:
        dim: Width of the graph convolutions and of the first linear layer.
        num_features: Number of input channels per node.
        num_classes: Number of output classes.
    """

    def __init__(self, dim, num_features, num_classes):
        super().__init__()

        # Graph convolutions (the deeper conv3-conv5 variants are disabled).
        self.conv1 = GraphConv(num_features, dim)
        self.conv2 = GraphConv(dim, dim)

        # Classification head.
        self.lin1 = Linear(dim, dim)
        self.lin2 = Linear(dim, num_classes)

    def forward(self, x, edge_index, batch, edge_weight=None):
        """Return per-graph log-probabilities of shape (num graphs, num_classes)."""
        h = F.relu(self.conv1(x, edge_index, edge_weight))
        h = F.relu(self.conv2(h, edge_index, edge_weight))

        # Sum the node embeddings of every graph in the batch.
        pooled = global_add_pool(h, batch)

        hidden = F.relu(self.lin1(pooled))
        hidden = F.dropout(hidden, p=0.5, training=self.training)
        logits = self.lin2(hidden)
        return F.log_softmax(logits, dim=-1)



### RNN

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class RNN(nn.Module):
    """Embedding -> single-layer ``nn.RNN`` -> linear classifier.

    Args:
        input_dim: Vocabulary size (number of embedding rows).
        embedding_dim: Size of each token embedding.
        hidden_dim: Size of the recurrent hidden state.
        output_dim: Number of output classes.
    """

    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):

        super().__init__()

        self.embedding = nn.Embedding(input_dim, embedding_dim)

        self.rnn = nn.RNN(embedding_dim, hidden_dim)

        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, text):
        """Return log-probabilities of shape [batch size, output_dim].

        Args:
            text: Token ids of shape [sent len, batch size].
        """
        #embedded = [sent len, batch size, emb dim]
        embedded = self.embedding(text)

        #hidden = [1, batch size, hid dim]
        # For a single-layer, unidirectional RNN this is the last time step
        # of the output sequence. The former per-call
        # `assert torch.equal(output[-1], hidden)` only re-checked that
        # guarantee at the cost of a full tensor comparison on every forward
        # pass, so it was dropped.
        _, hidden = self.rnn(embedded)

        logits = self.fc(hidden.squeeze(0))

        return F.log_softmax(logits, dim=-1)
       

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class LSTM(nn.Module):
    """Embedding -> (bi)directional multi-layer LSTM -> linear classifier.

    Args:
        vocab_size: Number of embedding rows.
        embedding_dim: Size of each token embedding.
        hidden_dim: Size of the LSTM hidden state per direction.
        output_dim: Number of output classes.
        n_layers: Number of stacked LSTM layers.
        bidirectional: Whether the LSTM runs in both directions.
        dropout: Dropout probability applied to the embeddings, between LSTM
            layers, and to the final hidden state.
        pad_idx: Index of the padding token in the embedding table.

    The classifier input is ``hidden_dim`` per direction. Previously it was
    hard-coded to ``hidden_dim * 2`` and always concatenated ``hidden[-2]``
    with ``hidden[-1]``, which for a unidirectional LSTM mixed two layers
    (or raised ``IndexError`` with a single layer).
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers,
                 bidirectional, dropout, pad_idx):

        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx = pad_idx)

        self.rnn = nn.LSTM(embedding_dim,
                           hidden_dim,
                           num_layers=n_layers,
                           bidirectional=bidirectional,
                           # Inter-layer dropout has no effect with a single
                           # layer; passing 0 avoids torch's warning.
                           dropout=dropout if n_layers > 1 else 0.0)

        num_directions = 2 if bidirectional else 1
        self.fc = nn.Linear(hidden_dim * num_directions, output_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, text, text_lengths):
        """Return log-probabilities of shape [batch size, output_dim].

        Args:
            text: Token ids of shape [sent len, batch size].
            text_lengths: Unpadded length of every sequence in the batch.
        """
        #embedded = [sent len, batch size, emb dim]
        embedded = self.dropout(self.embedding(text))

        # Pack so the LSTM skips the padded positions; lengths need to be on CPU!
        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, text_lengths.to('cpu'), enforce_sorted=False)

        # Only the final hidden state is used, so the packed output is not
        # unpacked any more.
        #hidden = [num layers * num directions, batch size, hid dim]
        _, (hidden, _) = self.rnn(packed_embedded)

        if self.rnn.bidirectional:
            # Last layer: forward state is hidden[-2], backward state is hidden[-1].
            final_state = torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1)
        else:
            final_state = hidden[-1,:,:]

        #final_state = [batch size, hid dim * num directions]
        logits = self.fc(self.dropout(final_state))

        return F.log_softmax(logits, dim=-1)

## FastText

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class FastText(nn.Module):
    """Bag-of-embeddings classifier: average the token embeddings, then classify.

    Args:
        vocab_size: Number of embedding rows.
        embedding_dim: Size of each token embedding.
        output_dim: Number of output classes.
        pad_idx: Index of the padding token in the embedding table.
    """

    def __init__(self, vocab_size, embedding_dim, output_dim, pad_idx):

        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)

        self.fc = nn.Linear(embedding_dim, output_dim)

    def forward(self, text):
        """Return log-probabilities of shape [batch size, output_dim].

        Args:
            text: Token ids of shape [sent len, batch size].
        """
        # [sent len, batch size, emb dim] -> [batch size, sent len, emb dim]
        token_vectors = self.embedding(text).permute(1, 0, 2)

        # Average over the sentence axis with one pooling window that spans
        # the whole sentence; result is [batch size, emb dim].
        sent_len = token_vectors.shape[1]
        sentence_vector = F.avg_pool2d(token_vectors, (sent_len, 1)).squeeze(1)

        logits = self.fc(sentence_vector)

        return F.log_softmax(logits, dim=-1)


def initialize_weights_FAST(m):
    """Initialise a FastText model and load pretrained Word2Vec embeddings.

    Steps:
      1. Xavier-uniform initialise every weight matrix of ``m`` and its
         sub-modules. Previously only ``m.weight`` was inspected, which a
         container model does not have, so this step did nothing.
      2. Copy the word vectors of ``ko_word2vec.model`` into ``m.embedding``,
         preceded by zero rows for the special tokens.

    The vectors are taken from ``wv.vectors`` (the trained word vectors, in
    ``index_to_key`` order, matching how the vocabulary and the graph node
    encoder read the model). The former ``syn1neg`` matrix holds the
    output-layer weights and only exists when negative sampling was used.

    The number of zero rows is derived from the embedding table size, so it
    no longer depends on a hard-coded 5 or on the global ``base_params``.

    Args:
        m: Model exposing an ``embedding`` attribute (``nn.Embedding``).

    Raises:
        ValueError: If the embedding table has fewer rows than the Word2Vec
            vocabulary.
    """
    for module in m.modules():
        weight = getattr(module, 'weight', None)
        if isinstance(weight, torch.Tensor) and weight.dim() > 1:
            torch.nn.init.xavier_uniform_(weight.data)

    w2v = Word2Vec.load('ko_word2vec.model')
    pretrained_embeddings = torch.tensor(w2v.wv.vectors)

    n_special = m.embedding.num_embeddings - pretrained_embeddings.size(0)
    if n_special < 0:
        raise ValueError(
            "embedding table has %d rows but the Word2Vec model has %d vectors"
            % (m.embedding.num_embeddings, pretrained_embeddings.size(0)))

    # Zero rows for the special tokens that precede the Word2Vec vocabulary.
    special_rows = torch.zeros(n_special, pretrained_embeddings.size(1),
                               dtype=pretrained_embeddings.dtype)
    m.embedding.weight.data.copy_(torch.cat([special_rows, pretrained_embeddings]))
    # m.embedding.weight.data[UNK_IDX] = torch.zeros(base_params["EMBEDDING_DIM"])
    # m.embedding.weight.data[PAD_IDX] = torch.zeros(base_params["EMBEDDING_DIM"])

## BERT

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class BERTClassifier(nn.Module):
    """Linear classification head on top of a BERT encoder's pooled output.

    Args:
        bert: Callable encoder invoked with ``input_ids``, ``token_type_ids``
            and ``attention_mask``; it must return a pair whose second
            element is the pooled representation.
        hidden_size: Size of the pooled representation.
        num_classes: Number of output classes.
        dr_rate: Dropout probability applied to the pooled output; falsy
            values disable dropout.
        params: Unused; kept for interface compatibility.
    """

    def __init__(self,
                 bert,
                 hidden_size = 768,
                 num_classes=2,
                 dr_rate=None,
                 params=None):
        super(BERTClassifier, self).__init__()
        self.bert = bert
        self.dr_rate = dr_rate

        self.classifier = nn.Linear(hidden_size , num_classes)
        if dr_rate:
            self.dropout = nn.Dropout(p=dr_rate)

    def gen_attention_mask(self, token_ids, valid_length):
        """Return a float mask with ones on the first ``valid_length[i]`` positions of row ``i``.

        Built with one broadcast comparison, replacing the per-sample Python
        loop. Assumes ``token_ids`` is 2-d [batch size, seq len] — TODO confirm.
        """
        lengths = torch.as_tensor(valid_length, device=token_ids.device)
        positions = torch.arange(token_ids.size(1), device=token_ids.device)
        return (positions.unsqueeze(0) < lengths.unsqueeze(1)).float()

    def forward(self, token_ids, valid_length, segment_ids):
        """Return log-probabilities of shape [batch size, num_classes]."""
        # The mask is already a float tensor on the device of token_ids.
        attention_mask = self.gen_attention_mask(token_ids, valid_length)

        _, pooler = self.bert(input_ids = token_ids, token_type_ids = segment_ids.long(), attention_mask = attention_mask)
        out = self.dropout(pooler) if self.dr_rate else pooler

        out = self.classifier(out)
        return  F.log_softmax(out, dim=-1)

## Trainer

In [None]:
import pytorch_lightning as pl
from transformers.optimization import AdamW, get_cosine_schedule_with_warmup
import numpy as np

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class Module(pl.LightningModule):
    """Lightning wrapper that builds, trains and evaluates the selected model.

    Args:
        kwargs: Configuration dict. ``NAME`` selects the model ("GNN", "RNN",
            "LSTM", "FAST" or "BERT"); the remaining keys supply model sizes
            and the training hyper-parameters (``LR``, ``WARMUP_RATIO``,
            ``MAX_EPOCHES``).

    Raises:
        ValueError: If ``NAME`` is not one of the supported models.

    NOTE(review): the "BERT" branch reads the module-level name ``bertmodel``;
    in the visible cells its only definition is commented out.
    """

    # Models fed by dict batches with "input_ids" / "src_length" / "labels".
    _SEQUENCE_MODELS = ("RNN", "LSTM", "FAST")

    def __init__(self, kwargs):
        super(Module, self).__init__()

        name = kwargs["NAME"]

        if name == "GNN":
            # Node features have INPUT_DIM channels and HIDDEN is the layer
            # width. The two were previously passed the other way round,
            # which only worked because both are 128 in base_params.
            self.model = Net(dim=kwargs["HIDDEN"], num_features=kwargs["INPUT_DIM"], num_classes=kwargs["N_CLASSES"])
            initialize_weights(self.model)

        elif name == "RNN":

            self.model = RNN(kwargs["N_VOCAB"], kwargs["EMBEDDING_DIM"], kwargs["HIDDEN"], kwargs["N_CLASSES"])
            initialize_weights(self.model)

        elif name == "LSTM":

            self.model = LSTM(vocab_size=kwargs["N_VOCAB"], embedding_dim=kwargs["EMBEDDING_DIM"], hidden_dim=kwargs["HIDDEN"], output_dim=kwargs["N_CLASSES"],
                                n_layers=kwargs["N_LAYER"], bidirectional=kwargs["BIDIRECT"], dropout=kwargs["DROPOUT"], pad_idx=PAD_IDX)
            initialize_weights(self.model)

        elif name == "FAST":
            self.model = FastText(vocab_size=kwargs["N_VOCAB"], embedding_dim=kwargs["EMBEDDING_DIM"], output_dim=kwargs["N_CLASSES"], pad_idx=PAD_IDX)
            initialize_weights_FAST(self.model)

        elif name == "BERT":

            self.model = BERTClassifier(bert=bertmodel, num_classes=kwargs["N_CLASSES"], dr_rate=kwargs["DROPOUT"])

        else:
            # Fail here instead of with an AttributeError on the first forward pass.
            raise ValueError("Unsupported model name: %r" % (name,))

        # Every model returns log-probabilities, hence NLLLoss.
        self.loss_function = torch.nn.NLLLoss()
        self.kwargs = kwargs

    def configure_optimizers(self):
        """Return AdamW plus a per-step cosine schedule with linear warm-up."""
        optimizer = torch.optim.AdamW(self.parameters(), lr= self.kwargs["LR"])
        # warm up lr
        # NOTE(review): this reads a private Lightning attribute to count the
        # training batches and may break on other Lightning versions.
        num_train_steps = len(self.trainer._data_connector._train_dataloader_source.dataloader()) * self.kwargs["MAX_EPOCHES"] ##because of lighting problem from 1.5 version
        num_warmup_steps = int(num_train_steps * self.kwargs["WARMUP_RATIO"])
        scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=num_warmup_steps, num_training_steps=num_train_steps)
        lr_scheduler = {'scheduler': scheduler, 'name': 'cosine_schedule_with_warmup',
                        'monitor': 'loss', 'interval': 'step',
                        'frequency': 1}
        return [optimizer], [lr_scheduler]

    def forward(self, batch):
        """Run the selected model on ``batch`` and return its log-probabilities."""
        name = self.kwargs["NAME"]

        if name == "GNN":
            # Use the module's own device instead of a global one, so the
            # batch always ends up where the model lives.
            batch = batch.to(self.device)
            output = self.model(x=batch.x, edge_index=batch.edge_index, batch=batch.batch, edge_weight=batch.edge_weight)

        elif name in ["RNN", "FAST"]:
            # The models expect [sent len, batch size].
            x = torch.swapaxes(batch["input_ids"], 0, 1)
            output = self.model(x)

        elif name == "LSTM":
            x = torch.swapaxes(batch["input_ids"], 0, 1)
            output = self.model(x, batch["src_length"])

        elif name == "BERT":
            token_ids, valid_length, segment_ids, _ =  batch
            token_ids = token_ids.long()
            segment_ids = segment_ids.long()
            output = self.model(token_ids, valid_length, segment_ids)

        return output

    def cal_loss(self, outputs, trg):
        """Return the NLL loss of ``outputs`` against the targets ``trg``."""
        return self.loss_function(outputs, trg)

    def _targets(self, batch):
        """Extract the label tensor from ``batch`` for the configured model."""
        name = self.kwargs["NAME"]
        if name == "GNN":
            return batch.y
        if name in self._SEQUENCE_MODELS:
            return batch["labels"]
        if name == "BERT":
            return batch[3].long()
        raise ValueError("Unsupported model name: %r" % (name,))

    def _shared_step(self, batch):
        """Run the forward pass and return ``(outputs, target, loss)``."""
        outputs = self(batch)
        target = self._targets(batch)
        loss = self.cal_loss(outputs, target)
        return outputs, target, loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss."""
        _, target, loss = self._shared_step(batch)

        # Log with the real batch size so a short last batch is weighted correctly.
        self.log('train_loss', loss, on_epoch=True, prog_bar=True, logger=True, batch_size=target.size(0))
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss."""
        _, target, loss = self._shared_step(batch)
        self.log('val_loss', loss, batch_size=target.size(0))

    def test_step(self, batch, batch_idx):
        """Log the test loss and return the predicted and true labels."""
        outputs, target, loss = self._shared_step(batch)
        self.log('test_loss', loss, batch_size=target.size(0))

        _, outs = torch.max(outputs, dim=-1)
        return {'output': outs, 'label': target} #be careful, it gathers data on each GPUs so the data will be split

    def test_epoch_end(self, test_step_outputs):
        """Print the accuracy over all test samples.

        Counts correct predictions over the real number of samples. The
        previous version averaged per-batch ratios that divided by the
        configured batch size, which under-counted a short last batch.
        """
        correct = 0
        total = 0

        for out in test_step_outputs:
            labels = out["label"]
            outputs = out["output"]

            correct += outputs.eq(labels).sum().item()
            total += labels.numel()

        accuracy = correct / total if total else float("nan")
        print("Accuracy: ", accuracy)

## Test model

In [None]:
# data_module = BERTData_Module(args=base_params)
# data_module.setup(stage="fit")

In [None]:
# model = Module(kwargs=base_params)
# loss = torch.nn.NLLLoss()
# batch = next(iter(data_module.train_dataloader()))

# outputs = model(batch)

# print(outputs.type())
# print(outputs.size())
# l = loss(outputs, batch[3].long())
# print(l)

## Training

In [None]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Build the data module matching base_params["NAME"] (GNN here) and create its splits.
data_module = GraphDataModule(args=base_params)
# data_module = BERTData_Module(args=base_params)
data_module.setup(stage="fit")

# Save checkpoints according to the logged validation loss (lower is better).
checkpoint_callback = ModelCheckpoint(
    # dirpath='model_chp',
    filename='{epoch:02d}',
    verbose=True,
    monitor='val_loss',
    mode='min'
)

model = Module(kwargs=base_params)
# Single-GPU training with gradient clipping at 1.0.
trainer = Trainer(
    callbacks=checkpoint_callback,
    max_epochs=base_params["MAX_EPOCHES"],
    gradient_clip_val=1.0,
    accelerator="gpu",
    devices=1,
    strategy="dp",
    )

trainer.fit(model, data_module)
print('best model path {}'.format(checkpoint_callback.best_model_path))

In [None]:
# from tkinter import SE
# from pytorch_lightning import Trainer

# data_module = BERTData_Module(args=base_params)
# data_module.setup(stage="test")
# model = Module(kwargs=base_params)
# trainer = Trainer(accelerator="gpu", strategy="dp", devices=1)
# trainer.test(model=model, datamodule=data_module, ckpt_path="lightning_logs/BERT/checkpoints/epoch=01.ckpt")