In [1]:
import time

import torch

import torch.nn.functional as F
from torch.nn import ModuleList, Embedding
from torch.nn import Sequential, ReLU, Linear
from torch.nn import CrossEntropyLoss, MSELoss, L1Loss
from torch.optim.lr_scheduler import ReduceLROnPlateau

from torch_geometric.utils import degree
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, PNAConv, BatchNorm, global_add_pool

from phylognn_model import G2Dist_PNAConv

from gene_graph_dataset import GeneGraphDataset

from torch.utils.tensorboard import SummaryWriter

In [2]:
train_p, test_p = 0.7, 0.2

In [3]:
dataset = GeneGraphDataset('dataset_adj1', 20, 20, graph_num = 100)
data_size = len(dataset)
train_size, test_size = (int)(data_size * train_p), (int)(data_size * test_p)

In [4]:
data_size

2000

In [5]:
dataset = dataset.shuffle()
train_dataset = dataset[:train_size]
test_dataset = dataset[train_size:(train_size + test_size)]
val_dataset = dataset[(train_size + test_size):]

In [6]:
# len(train_dataset), len(test_dataset), len(val_dataset)

In [7]:
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64)
val_loader = DataLoader(val_dataset, batch_size=1)

In [8]:
# len(train_loader), len(test_loader), len(val_loader)

In [9]:
deg = torch.zeros(5, dtype=torch.long)
for data in dataset:
    d = degree(data.edge_index[1].type(torch.int64), 
               num_nodes=data.num_nodes, dtype=torch.long)
    deg += torch.bincount(d, minlength=deg.numel())

In [10]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = G2Dist_PNAConv(deg).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay = 0.0001)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10,
                              min_lr=0.00001)

In [11]:
# loss_fn = MSELoss()
# l1_fn = L1Loss()

loss_fn = CrossEntropyLoss()

def train(train_loader):
    model.train()

    total_loss, counter = 0, 0
    size = len(train_loader)
    for batch, data in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch)
        #loss = (out.squeeze() - data.y).abs().sum()
        pred, y = out.softmax(axis = 1).argmax(axis = 1), data.y
        counter += (pred == y).sum().item()
        
        loss = loss_fn(out, data.y)
        
        loss.backward()
        total_loss += loss.item()
        optimizer.step()
        
    return total_loss / len(train_loader), counter

In [12]:
@torch.no_grad()
def test(loader):
    model.eval()

    total_error, counter = 0, 0
    for data in loader:
        data = data.to(device)
        out = model(data.x, data.edge_index, data.batch)
        
        pred, y = out.softmax(axis = 1).argmax(axis = 1), data.y
        counter += (pred == y).sum().item()
        
        # total_error += (out.squeeze() - data.y).abs().sum().item()
        
        total_error += loss_fn(out, data.y).item()
        
    return total_error / len(loader), counter

In [13]:
writer = SummaryWriter(log_dir='runs_g2d_10/g2dist_adjone_02000-pna-global-run1')

In [14]:
import numpy as np
for epoch in range(1, 1001):
    loss, train_counter = train(train_loader)
    test_mae, test_counter = test(test_loader)
    val_mae, _ = test(val_loader)
    
    # scheduler.step(loss)
    
    writer.add_scalar('Loss/train', loss, epoch)
    writer.add_scalar('Loss/test', test_mae, epoch)
    writer.add_scalar('Loss/val', val_mae, epoch)
    writer.add_scalar('Counter/train', train_counter/len(train_loader.dataset), epoch)
    writer.add_scalar('Counter/test', test_counter/len(test_loader.dataset), epoch)
    
    print(f'{time.ctime()}\t'
          f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_mae:.4f}, '
          f'Test: {test_mae:.4f}')
    
    print(f'\t\t -- train_counter: {train_counter}, test_counter:{test_counter}')

Sun Jan  2 14:15:39 2022	Epoch: 001, Loss: 3.1139, Val: 2.9204, Test: 2.9613
		 -- train_counter: 117, test_counter:28
Sun Jan  2 14:15:42 2022	Epoch: 002, Loss: 2.7273, Val: 2.8902, Test: 2.9204
		 -- train_counter: 181, test_counter:41
Sun Jan  2 14:15:46 2022	Epoch: 003, Loss: 2.6004, Val: 2.7530, Test: 2.8115
		 -- train_counter: 263, test_counter:47
Sun Jan  2 14:15:50 2022	Epoch: 004, Loss: 2.4950, Val: 2.6956, Test: 2.7768
		 -- train_counter: 293, test_counter:53
Sun Jan  2 14:15:53 2022	Epoch: 005, Loss: 2.4119, Val: 2.7598, Test: 2.8190
		 -- train_counter: 356, test_counter:34
Sun Jan  2 14:15:57 2022	Epoch: 006, Loss: 2.3259, Val: 2.8666, Test: 2.7650
		 -- train_counter: 429, test_counter:61
Sun Jan  2 14:16:01 2022	Epoch: 007, Loss: 2.2537, Val: 2.8493, Test: 2.7790
		 -- train_counter: 489, test_counter:59
Sun Jan  2 14:16:05 2022	Epoch: 008, Loss: 2.1544, Val: 3.3997, Test: 3.3631
		 -- train_counter: 568, test_counter:33
Sun Jan  2 14:16:08 2022	Epoch: 009, Loss: 2.049

KeyboardInterrupt: 

In [None]:
model.eval()

In [None]:
model.train()

In [None]:
tld0 = list(train_loader)[0].to(device)
tld1 = list(test_loader)[0].to(device)

In [None]:
res0 = model(tld0.x, tld0.edge_index, tld0.batch)

In [None]:
res0.argmax(axis = 1)

In [None]:
tld0.y

In [None]:
loss_fn(res0, tld0.y)

In [None]:
L1Loss()(res0.argmax(axis = 1).to(torch.float), tld0.y.to(torch.float))

In [None]:
(res0.argmax(axis = 1) == tld0.y).abs().sum().item()/len(tld0.y)

In [None]:
res1 = model(tld1.x, tld1.edge_index, tld1.batch)

In [None]:
res1.argmax(axis = 1)

In [None]:
tld1.y

In [None]:
loss_fn(res1, tld1.y)

In [None]:
L1Loss()(res1.argmax(axis = 1).to(torch.float), tld1.y.to(torch.float))

In [None]:
train_y = [d.y.item() for d in train_dataset]

In [None]:
np.unique(train_y)

In [None]:
test_y = [d.y.item() for d in test_dataset]

In [None]:
np.unique(test_y)

In [None]:
np.unique([d.y.item() for d in val_dataset])

In [None]:
data = train_dataset[0]

In [None]:
data = data.to(device)

In [None]:
data.x

In [None]:
model.node_emb

In [None]:
x = model.node_emb(data.x.squeeze()).view(-1, 80)

In [None]:
x

In [None]:
x = model.convs[0](x, data.edge_index)

In [None]:
x