In [1]:
import time

import torch

import torch.nn.functional as F
from torch.nn import ModuleList, Embedding
from torch.nn import Sequential, ReLU, Linear, MSELoss
from torch.optim.lr_scheduler import ReduceLROnPlateau

from torch_geometric.utils import degree
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, PNAConv, BatchNorm, global_add_pool

from phylognn_model import G2Dist_GCNConv

from gene_graph_dataset import GeneGraphDataset

from torch.utils.tensorboard import SummaryWriter

In [2]:
# Fractions of the dataset assigned to the training and test splits;
# whatever is left over after these two becomes the validation split.
train_p = 0.7
test_p = 0.2

In [3]:
# NOTE(review): the meaning of the positional arguments (100, 20) and of
# `graph_num` is defined by GeneGraphDataset, which is not shown here --
# confirm against its definition before changing them.
dataset = GeneGraphDataset('dataset', 100, 20, graph_num=50)
data_size = len(dataset)

# Split sizes are truncated towards zero so they are whole numbers of graphs.
train_size = int(data_size * train_p)
test_size = int(data_size * test_p)

In [4]:
# Show the total number of graphs in the dataset.
data_size

1000

In [5]:
# Shuffle once, then carve the shuffled dataset into three contiguous,
# non-overlapping slices: train | test | validation (the remainder).
dataset = dataset.shuffle()
test_end = train_size + test_size

train_dataset = dataset[:train_size]
test_dataset = dataset[train_size:test_end]
val_dataset = dataset[test_end:]

In [6]:
# Pick the first graph of the (shuffled) dataset as a sample to inspect.
t = dataset[0]

In [7]:
# Display the sample graph's fields and their shapes.
t

Data(x=[200], edge_index=[2, 398], edge_attr=[398, 2], dtype=torch.int64, y=[1], num_nodes=200)

In [8]:
# Display the sample's target value (`y` is what the training loss is computed against).
t.y

tensor([4])

In [9]:
# len(train_dataset), len(test_dataset), len(val_dataset)

In [10]:
# Training batches hold 64 graphs and are reshuffled on every pass;
# the evaluation loaders yield one graph at a time in a fixed order.
train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

In [11]:
# len(train_loader), len(test_loader), len(val_loader)

In [14]:
# Prefer the GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

model = G2Dist_GCNConv().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

# Halves the learning rate when the monitored (minimised) metric has stopped
# improving for `patience` scheduler steps, never going below `min_lr`.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=20,
    min_lr=1e-5,
)

In [15]:
mse_loss = MSELoss()


def train(train_loader):
    """Run one optimisation pass over ``train_loader`` and return the mean loss.

    Uses the notebook-level ``model``, ``optimizer``, ``device`` and
    ``mse_loss`` objects.

    Args:
        train_loader: iterable of batches exposing ``x``, ``edge_index``,
            ``batch``, ``y`` and ``num_graphs``; it must also expose a
            ``dataset`` attribute whose length is the number of graphs.

    Returns:
        float: the per-batch mean MSE weighted by the number of graphs in the
        batch, summed over all batches and divided by
        ``len(train_loader.dataset)``.
    """
    model.train()

    total_loss = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch)

        # Flatten to 1-D rather than squeeze(): squeeze() turns a single-graph
        # batch into a 0-d tensor, which no longer matches the target's shape
        # and makes MSELoss warn and fall back to broadcasting.
        pred = out.reshape(-1)
        target = data.y.to(torch.float).reshape(-1)
        loss = mse_loss(pred, target)

        loss.backward()
        total_loss += loss.item() * data.num_graphs
        optimizer.step()
    return total_loss / len(train_loader.dataset)

In [16]:
@torch.no_grad()
def test(loader):
    model.eval()

    total_error = 0
    for data in loader:
        data = data.to(device)
        out = model(data.x, data.edge_index, data.batch)
        total_error += (out.squeeze() - data.y).abs().sum().item()
    return total_error / len(loader.dataset)

In [17]:
# TensorBoard writer for this run; scalars logged through it go to the given log directory.
writer = SummaryWriter(log_dir='runs_g2d_10/g2distance_0100_0020_01000-gcn-run4')

In [18]:
# Per-epoch metrics, one row per epoch: [train loss, val MAE, test MAE].
# NOTE: 1000 rows are allocated, but the loop below only fills the first 100.
result = torch.zeros(1000, 3)

for epoch in range(1, 101):
    loss = train(train_loader)
    test_mae = test(test_loader)
    val_mae = test(val_loader)

    # Learning-rate scheduling on the validation metric is currently disabled.
    # scheduler.step(val_mae)
    result[epoch - 1] = torch.tensor([loss, val_mae, test_mae])

    # Mirror the three metrics to TensorBoard under their respective tags.
    for tag, value in (
        ('Loss/train', loss),
        ('Loss/test', test_mae),
        ('Loss/val', val_mae),
    ):
        writer.add_scalar(tag, value, epoch)

    print(f'{time.ctime()}\t'
          f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_mae:.4f}, '
          f'Test: {test_mae:.4f}')

Wed Dec 29 23:31:13 2021	Epoch: 001, Loss: 131.6797, Val: 10.6940, Test: 9.1010
Wed Dec 29 23:31:15 2021	Epoch: 002, Loss: 101.3037, Val: 9.9220, Test: 8.3069
Wed Dec 29 23:31:17 2021	Epoch: 003, Loss: 70.7683, Val: 9.2629, Test: 7.6386
Wed Dec 29 23:31:19 2021	Epoch: 004, Loss: 45.9278, Val: 9.0215, Test: 7.3829
Wed Dec 29 23:31:21 2021	Epoch: 005, Loss: 32.2752, Val: 9.3279, Test: 7.6283
Wed Dec 29 23:31:23 2021	Epoch: 006, Loss: 29.0565, Val: 8.3758, Test: 6.6687
Wed Dec 29 23:31:25 2021	Epoch: 007, Loss: 27.2096, Val: 7.2313, Test: 5.6692
Wed Dec 29 23:31:27 2021	Epoch: 008, Loss: 25.7433, Val: 6.3189, Test: 5.0125
Wed Dec 29 23:31:29 2021	Epoch: 009, Loss: 24.6015, Val: 5.7001, Test: 4.5915
Wed Dec 29 23:31:31 2021	Epoch: 010, Loss: 23.4974, Val: 5.0510, Test: 4.3246
Wed Dec 29 23:31:33 2021	Epoch: 011, Loss: 22.4077, Val: 4.8185, Test: 4.2774
Wed Dec 29 23:31:35 2021	Epoch: 012, Loss: 21.7950, Val: 4.6875, Test: 4.2714
Wed Dec 29 23:31:37 2021	Epoch: 013, Loss: 21.1261, Val: 4.70

In [19]:
# Close the TensorBoard writer now that the training loop has finished logging.
writer.close()

In [20]:
# Show how many graphs ended up in the test split.
len(test_dataset)

200

In [21]:
model.eval()

# Inference only: disable autograd so no computation graph is built per sample.
with torch.no_grad():
    for td in test_loader:
        td = td.to(device)

        out = model(td.x, td.edge_index, td.batch)
        # Prints "<prediction>, <target>". `.item()` needs exactly one element
        # on each side, which holds because test_loader uses batch_size=1.
        print(f'{out.squeeze().item():.4f}, {td.y.item():.2f}')

15.8690, 6.00
9.8693, 13.00
9.5195, 16.00
12.6847, 10.00
11.0214, 8.00
12.1926, 1.00
12.4089, 14.00
12.6962, 6.00
8.4674, 5.00
13.1534, 3.00
10.9987, 15.00
9.1817, 10.00
10.4410, 10.00
13.2640, 9.00
16.0047, 17.00
13.3262, 15.00
12.9492, 5.00
9.8725, 9.00
17.0109, 1.00
4.4897, 1.00
11.6419, 14.00
10.8838, 10.00
17.4414, 3.00
14.9151, 18.00
20.1962, 16.00
13.2005, 3.00
8.4702, 6.00
9.7967, 14.00
10.4376, 2.00
16.4852, 12.00
15.2414, 17.00
13.8360, 12.00
8.5370, 19.00
15.0889, 18.00
19.0239, 14.00
14.0629, 2.00
16.2450, 18.00
5.5416, 7.00
13.0683, 14.00
8.7653, 4.00
14.0036, 17.00
7.9410, 7.00
9.5773, 9.00
7.5399, 5.00
8.7327, 4.00
14.5951, 11.00
15.1789, 2.00
18.2782, 12.00
15.0488, 12.00
9.6260, 8.00
9.9108, 3.00
7.6402, 1.00
12.1432, 18.00
9.5015, 13.00
10.7635, 18.00
12.0166, 12.00
14.7583, 3.00
9.7486, 11.00
14.2026, 15.00
7.7697, 9.00
7.8131, 5.00
8.3330, 2.00
12.8150, 8.00
10.9716, 20.00
11.9946, 14.00
15.3974, 5.00
12.2330, 20.00
11.7888, 5.00
15.6698, 7.00
13.7966, 9.00
10.4494,

In [37]:
# `td` is a single batch (the loop variable left over from the evaluation
# cell), not a loader, so it has no `.dataset`; ask the loader instead.
len(test_loader.dataset)

200

In [38]:
# Show the dataset object that backs the test loader.
test_loader.dataset

GeneGraphDataset(200)