In [1]:
# %load_ext autoreload
# %autoreload 2

import sys
# Make the project modules under ../src importable (code_parser, dataset).
# The membership check keeps sys.path free of duplicate entries when this
# cell is executed more than once in the same kernel.
if "../src" not in sys.path:
    sys.path.append("../src")

import os
# Set before torch is imported in the next cell: restricts the CUDA devices
# visible to this process to device index 1.
os.environ['CUDA_VISIBLE_DEVICES'] = "1"

In [2]:
import pickle
import tqdm
import numpy as np
import pandas as pd
import networkx as nx
from gensim.models import KeyedVectors as word2vec

import torch
from torch_geometric.utils import from_networkx
from torch_geometric.data import DataLoader

import torch.nn.functional as F
from torch_geometric.data import DataLoader
import torch_geometric.transforms as T
from torch_geometric.nn import GATConv
from torch_geometric.nn import GraphConv, TopKPooling
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp


from code_parser import *
from dataset import CloneDatasetPair

In [3]:
# Loader configuration: number of items per mini-batch and number of
# DataLoader worker processes (used by the loader cell below).
batch_size, num_workers = 4, 2

In [4]:
# Feature-normalisation transform; it is created here but not passed to any of
# the datasets constructed below.
transform = T.NormalizeFeatures()

# Build the train / valid / test pair datasets in that order. All three share
# the same functions directory and differ only in their cache root and pairs file.
train_dataset, valid_dataset, test_dataset = (
    CloneDatasetPair(root=root, functions_path="../data/functions/", pairs_path=pairs_path)
    for root, pairs_path in (
        ("../data/train_pair", "../data/train.npz"),
        ("../data/valid_pair", "../data/valid.npz"),
        ("../data/test_pair", "../data/test.npz"),
    )
)

In [5]:
# Mini-batch loaders for the three splits. follow_batch=['x_s', 'x_t'] is passed
# to every loader; the test loader runs in the main process (no num_workers).
test_loader = DataLoader(test_dataset, batch_size=batch_size, follow_batch=['x_s', 'x_t'])
val_loader = DataLoader(valid_dataset, batch_size=batch_size, num_workers=num_workers, follow_batch=['x_s', 'x_t'])
# shuffle=True: draw the training pairs in a new random order every epoch
# instead of always iterating them in the order they are stored.
train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True, follow_batch=['x_s', 'x_t'])

In [6]:
from dgmc.models import SplineCNN, RelCNN, DGMC
# All modules are placed on the CPU.
device = torch.device('cpu')

# Two identically configured RelCNN networks; psi_1 is constructed before psi_2.
# NOTE(review): 384 / 128 / 1 are presumably input width, output width and
# number of layers — confirm against the dgmc RelCNN signature.
psi_1, psi_2 = (RelCNN(384, 128, 1, cat=False).to(device) for _ in range(2))

# Matching model wrapping both networks, followed by its Adam optimizer.
model = DGMC(psi_1, psi_2, num_steps=1, k=-1, detach=True)
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

In [None]:
def train(epoch):
    """Run one optimisation pass over ``train_loader`` and return the mean batch loss.

    Uses the notebook-level ``model``, ``optimizer``, ``device`` and
    ``train_loader``. For every batch the model is called on the source/target
    pair, the losses of its two outputs (``S_0`` and ``S_L``) are summed, and a
    single optimizer step is taken. The current batch loss and the running
    accuracy are printed on one line that is overwritten via ``\\r``.

    Args:
        epoch: Epoch number supplied by the caller; not used inside the function.

    Returns:
        The sum of the per-batch loss values divided by the number of batches
        processed (``0.0`` if the loader yields no batches).
    """
    model.train()

    loss_all = 0
    steps = 0
    total_examples = total_correct = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        # NOTE(review): batch_s / batch_t are passed as None even though the
        # loaders are created with follow_batch=['x_s', 'x_t'] — confirm that
        # treating each mini-batch without batch vectors is intended.
        S_0, S_L = model(x_s=data.x_s,
                         edge_index_s=data.edge_index_s,
                         edge_attr_s=None,
                         batch_s=None,
                         x_t=data.x_t,
                         edge_index_t=data.edge_index_t,
                         edge_attr_t=None,
                         batch_t=None,
                         y=data.y)

        # Total loss is the loss on S_L plus the loss on S_0.
        loss = model.loss(S_0, data.y)
        loss = model.loss(S_L, data.y) + loss
        loss.backward()
        optimizer.step()

        loss_all += loss.item()
        steps += 1
        total_correct += model.acc(S_L, data.y, reduction='sum')
        total_examples += data.y.size()[0]

        print(f"loss = {loss.item()}; acc: {total_correct/total_examples}", end="\r")

        # Drop the references to this batch's objects before the next batch is
        # loaded, then ask the CUDA caching allocator to release unused memory.
        del loss
        del S_0
        del S_L
        del data

        torch.cuda.empty_cache()

    # Average over the batches actually seen; max() avoids dividing by zero
    # when the loader is empty.
    return loss_all / max(steps, 1)
# Run a single training pass, passing 1 as the epoch number.
train(epoch=1)

loss = 13.425775527954102; acc: 0.010479041916167664

In [None]:
# Fetch the first mini-batch from the training loader and list its attributes.
# next(iter(...)) raises immediately if the loader is empty instead of leaving
# `data` undefined or pointing at an object from an earlier run.
data = next(iter(train_loader))
dir(data)

In [None]:
# Length of the first dimension of the batch's `y` attribute.
data.y.shape[0]

In [None]:
def train(epoch):
    """Run one pass over ``train_loader`` using an NLL loss and return the mean loss.

    Uses the notebook-level ``model``, ``optimizer``, ``device``,
    ``train_loader`` and ``train_dataset``. Each batch loss is weighted by
    ``num_graphs`` of the batch, and the weighted sum is divided by
    ``len(train_dataset)``.

    NOTE(review): this definition reuses the name ``train`` from an earlier
    cell and replaces it when the notebook is run top to bottom. It calls
    ``model(data)`` with a single argument, whereas the earlier cell calls the
    model with keyword arguments (``x_s=...``, ``x_t=...``) — confirm which
    variant matches the current ``model``.

    Args:
        epoch: Epoch number supplied by the caller; not used inside the function.

    Returns:
        Weighted sum of the batch losses divided by ``len(train_dataset)``.
    """
    model.train()

    weighted_loss_sum = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        batch_loss = F.nll_loss(model(batch), batch.y)
        batch_loss.backward()
        # Progress line, overwritten in place on every batch.
        print(f"loss = {batch_loss.item()}", end="\r")
        weighted_loss_sum += batch.num_graphs * batch_loss.item()
        optimizer.step()
    return weighted_loss_sum / len(train_dataset)


@torch.no_grad()
def test(loader):
    model.eval()

    correct = 0
    for data in loader:
        data = data.to(device)
        pred = model(data).max(dim=1)[1]
        correct += pred.eq(data.y).sum().item()
    return correct / len(loader.dataset)


In [None]:
# Train for epochs 1..200. After every epoch, measure accuracy on the training
# and validation loaders and save the model weights whenever the validation
# accuracy strictly exceeds the best value seen so far.
best_val_acc = 0
for epoch in range(1, 201):
    loss = train(epoch)
    train_acc, val_acc = test(train_loader), test(val_loader)

    if best_val_acc < val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "../data/pg_play_2.pt")

    # One summary line per epoch.
    print(
        'Epoch: {:03d}, Loss: {:.5f}, Train Acc: {:.5f}, Val Acc: {:.5f}, Best: {:.5f}'
        .format(epoch, loss, train_acc, val_acc, best_val_acc)
    )