In [1]:
# Model.py 
%load_ext autoreload
%autoreload 2

In [2]:
# model jupyter code

# internal
from preprocess import *
import config

# util
import networkx
from tqdm import tqdm_notebook as tqdm

# function
import torch
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

In [3]:
# Build the "train" split dataset from config.TRAIN_FILE.
# NOTE(review): MNLI_Dataset is not imported by name here — presumably it comes from
# `from preprocess import *`; confirm. Items print as (id tensor, id tensor, label tensor).
train_set = MNLI_Dataset(config.TRAIN_FILE, "train")

In [None]:
# TODO: decide how to batch graph inputs for the GCN; the collate function below only pads token sequences.

def create_mini_batch(samples):
    """Collate a list of dataset samples into a zero-padded mini-batch.

    Intended for use as a ``DataLoader`` ``collate_fn``. Each sample is indexed
    as ``sample[0]``, ``sample[1]`` (1-D integer tensors of possibly different
    lengths) and ``sample[2]`` (a label tensor, or ``None`` when unlabeled).

    NOTE(review): the names below treat ``sample[1]`` as segment ids, but the
    ``train_set[0]`` output in this notebook shows a second token-id sequence
    (it also starts with 101). Confirm the dataset's tuple layout.

    Args:
        samples: non-empty list of ``(seq_a, seq_b, label_or_None)`` tuples.

    Returns:
        tokens_tensors  : (batch_size, max len of sample[0] in batch), padded with 0
        segments_tensors: (batch_size, max len of sample[1] in batch), padded with 0
        masks_tensors   : same shape as tokens_tensors, dtype long, 1 where token != 0
        label_ids       : (batch_size,) stacked labels, or None if unlabeled
    """
    # Local import: the notebook never imports pad_sequence explicitly, so relying
    # on a star import from `preprocess` to provide it is fragile.
    from torch.nn.utils.rnn import pad_sequence

    tokens_tensors = [s[0] for s in samples]
    segments_tensors = [s[1] for s in samples]

    # Labels are optional; only the first sample is inspected to decide.
    if samples[0][2] is not None:
        label_ids = torch.stack([s[2] for s in samples])
    else:
        label_ids = None

    # Zero-pad each list to the longest sequence within that list.
    tokens_tensors = pad_sequence(tokens_tensors, batch_first=True)
    segments_tensors = pad_sequence(segments_tensors, batch_first=True)

    # Attention mask: 1 on non-padding positions (token id != 0) so the LM attends
    # to them. Built from the tokens tensor so it stays on the same device.
    masks_tensors = (tokens_tensors != 0).long()

    return tokens_tensors, segments_tensors, masks_tensors, label_ids

In [4]:
# Number of examples in the training dataset.
len(train_set)

392702

In [5]:
# Peek at one sample: the printed tuple holds two 1-D id tensors and a 0-d label tensor.
train_set[0]

(tensor([  101, 17158,  2135,  6949,  8301, 25057,  2038,  2048,  3937,  9646,
          1011,  4031,  1998, 10505,  1012,   100]),
 tensor([  101,  4031,  1998, 10505,  2024,  2054,  2191,  6949,  8301, 25057,
          2147,  1012,   100]),
 tensor(1))

In [6]:
class GCNConv(MessagePassing):
    """GCN-style graph convolution built on PyG's ``MessagePassing``.

    Adds self-loops, linearly transforms node features, then sums neighbour
    messages scaled by ``1 / sqrt(deg(i) * deg(j))``, where ``deg`` is the
    degree of each node as an aggregation target (after self-loops).

    Args:
        in_channels: size of each input node feature vector.
        out_channels: size of each output node feature vector.
    """

    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add')  # "Add" aggregation.
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        """Apply one convolution.

        Args:
            x: node features, shape [N, in_channels].
            edge_index: COO edge list, shape [2, E].

        Returns:
            Node embeddings of shape [N, out_channels].
        """
        # Step 1: Add self-loops so each node also receives its own features
        # (this also guarantees every node has degree >= 1).
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)

        # Step 3-5: Start propagating messages.
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)

    def message(self, x_j, edge_index, size):
        """Scale each source-node message by the symmetric GCN normalization.

        x_j has shape [E, out_channels] (features of each edge's source node).
        """
        # Step 3: Normalize node features.
        row, col = edge_index
        # With the default flow="source_to_target", messages travel row -> col
        # and are summed at col, so the normalizing degree is the target's
        # (col) degree. Counting row would use out-degree, which only matches
        # on undirected graphs.
        deg = degree(col, size[1], dtype=x_j.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        # Defensive: a zero-degree node would give inf. add_self_loops above
        # prevents this today, but keep the layer safe if that step changes.
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        return norm.view(-1, 1) * x_j

    def update(self, aggr_out):
        """Step 5: return the aggregated messages unchanged.

        aggr_out has shape [N, out_channels].
        """
        return aggr_out

In [7]:
# Instantiate a single GCN layer with 500 input and 500 output channels.
# NOTE(review): 500 is a magic number — presumably the node-feature/embedding width;
# confirm and move it to a named constant in the config cell.
con1 = GCNConv(500, 500)