In [1]:
import time

import torch

import torch.nn.functional as F
from torch.nn import ModuleList, Embedding
from torch.nn import Sequential, ReLU, Linear
from torch.nn import CrossEntropyLoss, MSELoss, L1Loss
from torch.optim.lr_scheduler import ReduceLROnPlateau

from torch_geometric.utils import degree
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, PNAConv, BatchNorm, global_add_pool

from phylognn_model import G2Dist_GCNConv_Small

from gene_graph_dataset import GeneGraphDataset

from torch.utils.tensorboard import SummaryWriter

In [2]:
train_p, test_p = 0.7, 0.2

In [3]:
dataset = GeneGraphDataset('dataset', 20, 20, graph_num = 1000)
data_size = len(dataset)
train_size, test_size = (int)(data_size * train_p), (int)(data_size * test_p)

In [4]:
data_size

20000

In [5]:
dataset = dataset.shuffle()
train_dataset = dataset[:train_size]
test_dataset = dataset[train_size:(train_size + test_size)]
val_dataset = dataset[(train_size + test_size):]

In [6]:
# len(train_dataset), len(test_dataset), len(val_dataset)

In [7]:
train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=256)
val_loader = DataLoader(val_dataset, batch_size=1)

In [8]:
# len(train_loader), len(test_loader), len(val_loader)

In [9]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = G2Dist_GCNConv_Small().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay = 0.0001)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10,
                              min_lr=0.00001)

In [10]:
# loss_fn = MSELoss()
# l1_fn = L1Loss()

loss_fn = CrossEntropyLoss()

def train(train_loader):
    model.train()

    total_loss, counter = 0, 0
    size = len(train_loader)
    for batch, data in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch)
        #loss = (out.squeeze() - data.y).abs().sum()
        pred, y = out.softmax(axis = 1).argmax(axis = 1), data.y
        counter += (pred == y).sum().item()
        
        loss = loss_fn(out, data.y)
        
        loss.backward()
        total_loss += loss.item()
        optimizer.step()
        
    return total_loss / len(train_loader), counter

In [11]:
@torch.no_grad()
def test(loader):
    model.eval()

    total_error, counter = 0, 0
    for data in loader:
        data = data.to(device)
        out = model(data.x, data.edge_index, data.batch)
        
        pred, y = out.softmax(axis = 1).argmax(axis = 1), data.y
        counter += (pred == y).sum().item()
        
        # total_error += (out.squeeze() - data.y).abs().sum().item()
        
        total_error += loss_fn(out, data.y).item()
        
    return total_error / len(loader), counter

In [12]:
writer = SummaryWriter(log_dir='runs_g2d_10/g2dist_0020_0020_20000-small-run1')

In [None]:
import numpy as np
for epoch in range(1, 1001):
    loss, train_counter = train(train_loader)
    test_mae, test_counter = test(test_loader)
    val_mae, _ = test(val_loader)
    
    # scheduler.step(loss)
    
    writer.add_scalar('Loss/train', loss, epoch)
    writer.add_scalar('Loss/test', test_mae, epoch)
    writer.add_scalar('Loss/val', val_mae, epoch)
    writer.add_scalar('Counter/train', train_counter/len(train_loader.dataset), epoch)
    writer.add_scalar('Counter/test', test_counter/len(test_loader.dataset), epoch)
    
    print(f'{time.ctime()}\t'
          f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_mae:.4f}, '
          f'Test: {test_mae:.4f}')
    
    print(f'\t\t -- train_counter: {train_counter}, test_counter:{test_counter}')

Sat Jan  1 23:26:34 2022	Epoch: 001, Loss: 6.5481, Val: 69.6604, Test: 70.7975
		 -- train_counter: 882, test_counter:191
Sat Jan  1 23:26:50 2022	Epoch: 002, Loss: 2.5768, Val: 85.8999, Test: 86.4222
		 -- train_counter: 2197, test_counter:191
Sat Jan  1 23:27:06 2022	Epoch: 003, Loss: 2.3407, Val: 18.4953, Test: 18.5329
		 -- train_counter: 2618, test_counter:211
Sat Jan  1 23:27:23 2022	Epoch: 004, Loss: 2.0855, Val: 4.5157, Test: 4.5748
		 -- train_counter: 3222, test_counter:222
Sat Jan  1 23:27:39 2022	Epoch: 005, Loss: 1.9791, Val: 8.7308, Test: 8.6836
		 -- train_counter: 3431, test_counter:279
Sat Jan  1 23:27:56 2022	Epoch: 006, Loss: 1.8337, Val: 11.1299, Test: 11.0914
		 -- train_counter: 3923, test_counter:168
Sat Jan  1 23:28:13 2022	Epoch: 007, Loss: 1.7120, Val: 15.1981, Test: 15.1588
		 -- train_counter: 4590, test_counter:52
Sat Jan  1 23:28:29 2022	Epoch: 008, Loss: 1.6798, Val: 14.5409, Test: 14.5897
		 -- train_counter: 4671, test_counter:269
Sat Jan  1 23:28:46 20

In [None]:
model.eval()

In [None]:
tld0 = list(train_loader)[0].to(device)
tld1 = list(test_loader)[0].to(device)

In [None]:
res0 = model(tld0.x, tld0.edge_index, tld0.batch)

In [None]:
res0

In [None]:
res0.argmax(axis = 1)

In [None]:
tld0.y

In [None]:
loss_fn(res0, tld0.y)

In [None]:
L1Loss()(res0.argmax(axis = 1).to(torch.float), tld0.y.to(torch.float))

In [None]:
(res0.argmax(axis = 1) - tld0.y).abs().sum().item()/len(tld0.y)

In [None]:
res1 = model(tld1.x, tld1.edge_index, tld1.batch)

In [None]:
res1.argmax(axis = 1)

In [None]:
tld1.y

In [None]:
loss_fn(res1, tld1.y)

In [None]:
L1Loss()(res1.argmax(axis = 1).to(torch.float), tld1.y.to(torch.float))

In [None]:
train_y = [d.y.item() for d in train_dataset]

In [None]:
np.unique(train_y)

In [None]:
test_y = [d.y.item() for d in test_dataset]

In [None]:
np.unique(test_y)

In [None]:
np.unique([d.y.item() for d in val_dataset])