In [None]:
import torch
from sklearn.metrics import f1_score
from graph_qa_dataset import GraphQADataset
from torch_geometric.loader import DataListLoader
from model.graph_qa_model import GAT, HGT

In [None]:
# Small train/validation subsets for fast iteration.
# NOTE(review): `data_size` presumably caps the number of examples per split — confirm in GraphQADataset.
train_dataset = GraphQADataset(data_size=20, split="train")
val_dataset = GraphQADataset(data_size=2, split="validation")

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Build the project-defined HGT model (model.graph_qa_model) from the training set's
# metadata and place it on the selected device. nn.Module.to() returns the module itself.
model = HGT(
    hidden_channels=64,
    out_channels=2,
    num_heads=4,
    num_layers=2,
    metadata=train_dataset.metadata,
).to(device)

In [None]:
# Adam over every trainable parameter of `model`.
# NOTE(review): lr=0.05 is unusually high for Adam; consider 1e-3 if training diverges.
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.05)
# BCEWithLogitsLoss applies the sigmoid internally, so the model should emit raw logits.
loss_op = torch.nn.BCEWithLogitsLoss()

In [None]:
# DataListLoader yields plain Python lists of graph objects instead of a collated batch.
# Training order is shuffled each epoch, and validation order stays fixed.
train_dataloader, val_dataloader = (
    DataListLoader(train_dataset, batch_size=1, shuffle=True),
    DataListLoader(val_dataset, batch_size=1, shuffle=False),
)

In [None]:
def train(dataloader):
    """Run one training epoch over ``dataloader`` and return the mean batch loss.

    Relies on the notebook-level ``model``, ``optimizer``, ``loss_op`` and ``device``.

    Args:
        dataloader: an iterable of lists of heterogeneous graphs (as yielded by
            ``DataListLoader``). Only the first graph of each list is used.

    Returns:
        float: average of ``loss.item()`` over all batches, or 0.0 if the loader is empty.
    """
    model.train()

    total_loss = 0.0
    num_batches = 0
    for batch in dataloader:
        # DataListLoader yields a list of graphs; only the first is used here.
        # NOTE(review): with batch_size > 1 the remaining graphs are ignored.
        # Move the graph to the model's device; otherwise the forward pass fails on CUDA.
        data = batch[0].to(device)

        optimizer.zero_grad()
        out = model(data.x_dict, data.edge_index_dict)

        loss = loss_op(out, data["context"].y)
        loss.backward()
        # Apply the gradients. Without this step the parameters never change.
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # Guard against an empty loader so this cannot raise ZeroDivisionError.
    return total_loss / num_batches if num_batches else 0.0

train(train_dataloader)

In [None]:
# NOTE(review): `ll` is not referenced in the visible cells; possibly leftover scratch — confirm before removing.
ll = [*range(100)]

In [None]:
import stanza

In [2]:
# English stanza pipeline: tokenization, POS tagging and constituency parsing, on GPU if available.
nlp = stanza.Pipeline(
    lang="en",
    processors="tokenize,pos,constituency",
    use_gpu=True,
)

2022-10-31 18:13:42 INFO: Loading: constituency
2022-10-31 18:13:43 INFO: Done loading processors!


In [3]:
# Run the pipeline on a sample passage. The returned Document is the cell's displayed output.
nlp(
    """No! Just as in regular PyTorch, you do not have to use datasets, e.g., when you want to create synthetic data on the fly without saving them explicitly to disk. In this case, simply pass a regular python list holding torch_geometric.data.Data objects and pass them to torch_geometric.loader.DataLoader:"""
)

[
  [
    {
      "id": 1,
      "text": "No",
      "upos": "INTJ",
      "xpos": "UH",
      "feats": "Polarity=Neg",
      "start_char": 0,
      "end_char": 2
    },
    {
      "id": 2,
      "text": "!",
      "upos": "PUNCT",
      "xpos": ".",
      "start_char": 2,
      "end_char": 3
    }
  ],
  [
    {
      "id": 1,
      "text": "Just",
      "upos": "ADV",
      "xpos": "RB",
      "start_char": 4,
      "end_char": 8
    },
    {
      "id": 2,
      "text": "as",
      "upos": "SCONJ",
      "xpos": "IN",
      "start_char": 9,
      "end_char": 11
    },
    {
      "id": 3,
      "text": "in",
      "upos": "ADP",
      "xpos": "IN",
      "start_char": 12,
      "end_char": 14
    },
    {
      "id": 4,
      "text": "regular",
      "upos": "ADJ",
      "xpos": "JJ",
      "feats": "Degree=Pos",
      "start_char": 15,
      "end_char": 22
    },
    {
      "id": 5,
      "text": "PyTorch",
      "upos": "NOUN",
      "xpos": "NN",
      "feats": "Number=Sing",
 