In [1]:
import torch.nn.functional as F
import torch.optim as optim
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from gcn_utils import collate
from utils import get_data
from tqdm import tqdm
from model import GCN


Using backend: pytorch


In [2]:
train_data = get_data('gdb_9_clean.tsv', dataset='gdb-9', device='cuda:0')
data_loader = DataLoader(train_data, batch_size=32, shuffle=True, collate_fn=collate)
model = GCN(27, 16, 32, 11, 2, F.relu, 0.1)
loss_func = nn.MSELoss(reduction='none')
optimizer = optim.Adam(model.parameters(), lr=0.001)
model.train()
device = torch.device('cuda:0')
model.to(device)

Reading raw data ... : : 133885it [01:21, 1637.87it/s]
Converting to DGL graphs ...: 100%|███████████████████████████████████████████| 133885/133885 [03:15<00:00, 685.07it/s]


GCN(
  (element_emebdding): Embedding(27, 16)
  (layers): ModuleList(
    (0): GCNLayer()
    (1): GCNLayer(
      (dropout): Dropout(p=0.1)
    )
  )
  (out): Linear(in_features=32, out_features=11, bias=True)
)

In [None]:
for epoch in range(80):
    epoch_loss = 0
    batch = tqdm(data_loader)
    for bg, label in batch:

        prediction = model(bg)
        loss = torch.sum(loss_func(prediction, label) * 1/label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
        batch.set_description(f'epoch {epoch} loss {epoch_loss}')
    print(label.cpu())
    print(label.cpu().size())
    if epoch > 10:
        print(prediction.detach().cpu())
        print(label.cpu())


  0%|          | 0/4184 [00:00<?, ?it/s]epoch 0 loss -6545.298828125:   0%|          | 0/4184 [00:00<?, ?it/s]epoch 0 loss -6545.298828125:   0%|          | 1/4184 [00:00<22:55,  3.04it/s]epoch 0 loss -19888.859375:   0%|          | 1/4184 [00:00<22:55,  3.04it/s]  epoch 0 loss -24431.306640625:   0%|          | 1/4184 [00:00<22:55,  3.04it/s]epoch 0 loss -37418.2978515625:   0%|          | 1/4184 [00:00<22:55,  3.04it/s]epoch 0 loss -48366.5654296875:   0%|          | 1/4184 [00:00<22:55,  3.04it/s]epoch 0 loss -59333.1298828125:   0%|          | 1/4184 [00:00<22:55,  3.04it/s]epoch 0 loss -59333.1298828125:   0%|          | 6/4184 [00:00<16:30,  4.22it/s]epoch 0 loss -71987.357421875:   0%|          | 6/4184 [00:00<16:30,  4.22it/s] epoch 0 loss -81730.93359375:   0%|          | 6/4184 [00:00<16:30,  4.22it/s] epoch 0 loss -92307.908203125:   0%|          | 6/4184 [00:00<16:30,  4.22it/s]epoch 0 loss -102483.8828125:   0%|          | 6/4184 [00:00<16:30,  4.22it/s] epoc

tensor([[ 8.4440e+01, -1.9650e-01,  4.9500e-02,  2.4600e-01,  1.6032e+03,
          1.7052e-01, -4.0317e+02, -4.0316e+02, -4.0316e+02, -4.0321e+02,
          3.4146e+01],
        [ 8.0620e+01, -2.6610e-01,  1.0900e-02,  2.7700e-01,  1.4576e+03,
          1.3396e-01, -3.8083e+02, -3.8082e+02, -3.8082e+02, -3.8086e+02,
          3.4437e+01],
        [ 7.8230e+01, -2.1490e-01,  8.2900e-02,  2.9790e-01,  1.0731e+03,
          1.7051e-01, -4.0306e+02, -4.0305e+02, -4.0305e+02, -4.0309e+02,
          3.2550e+01],
        [ 8.5370e+01, -2.3140e-01,  4.4700e-02,  2.7610e-01,  1.1692e+03,
          1.6839e-01, -3.6594e+02, -3.6593e+02, -3.6593e+02, -3.6597e+02,
          3.6203e+01],
        [ 9.5190e+01, -2.3500e-01,  2.7000e-02,  2.6200e-01,  1.5382e+03,
          2.2823e-01, -3.5238e+02, -3.5237e+02, -3.5237e+02, -3.5242e+02,
          3.9231e+01],
        [ 7.0400e+01, -2.5630e-01, -1.9700e-02,  2.3660e-01,  1.0561e+03,
          1.2555e-01, -4.3787e+02, -4.3787e+02, -4.3786e+02, -4.3790e+0