In [1]:
from datetime import datetime
import os
import csv
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from multi_gene_inference import multi_data_loader, MultiGCNInferenceNetwork2, multi_train, validate

In [2]:
# Root directory holding the gene label files and the new_data/ CSVs.
# NOTE(review): absolute, user-specific path (Dylan's setup); other users must edit it.
HOME = '/home/jupyter-dylan/'    # Dylan

# Timestamp of this notebook session; used below to name the checkpoint folder.
# Formatted without spaces, colons or dots so it is safe to use in a directory name.
DATE = datetime.today().strftime('%Y-%m-%d_%H-%M-%S')

# Device the model is trained on.
devices = torch.device('cuda')

## Load data

In [3]:
# Gene labels: the input genes followed by the output genes, joined into one array.
# (Jiayu's setup instead uses: gene_name = np.loadtxt('gene_name.txt', dtype=str))
label_files = (
    '{}gene labels/input_genes.txt'.format(HOME),        # Dylan
    '{}gene labels/output_genes-1.txt'.format(HOME),
)
gene_name = np.concatenate([np.loadtxt(path, dtype=str) for path in label_files])

# GeneMANIA edges: every CSV row is kept as a list of strings.
# (Jiayu's setup instead uses: gene_edge = np.loadtxt('genemania_edge.txt', dtype=str))
with open('../genemania_edges.csv', newline='') as edge_file:    # Dylan
    gene_edge = list(csv.reader(edge_file))

In [4]:
# Lookup table from gene name to node number (the gene's position in gene_name).
node_map = {gene: idx for idx, gene in enumerate(gene_name)}

In [5]:
# Genes passed to multi_data_loader via its `genes` argument.
gene_set = {
    "App", "Apoe", "Gusb", "Lamp5", "Mbp",
    "Pvalb", "S100b", "Slc30a3", "Snca", "Mapt",
}

In [6]:
# Keyword arguments shared by the training and validation loaders.
shared_loader_args = dict(genes=gene_set, node_map=node_map, gene_edge=gene_edge)

# Training split.                                                          # Dylan
# (Jiayu's setup instead uses:
#  train_loader = data_loader('input_train_cat.csv', gene='gusb', node_map=node_map,
#                             edge_list=edge_list, multiplier=1e-5))
train_loader = multi_data_loader(
    '{}new_data/input_train.csv'.format(HOME),
    cat='{}new_data/output_train-1.csv'.format(HOME),
    **shared_loader_args,
)

# Validation split.                                                        # Dylan
# (Jiayu's setup instead uses:
#  validate_loader = data_loader('input_test_cat.csv', gene='gusb', node_map=node_map,
#                                edge_list=edge_list, multiplier=1e-5))
validate_loader = multi_data_loader(
    '{}new_data/input_test.csv'.format(HOME),
    cat='{}new_data/output_test-1.csv'.format(HOME),
    **shared_loader_args,
)

In [7]:
# Number of samples in each split, one per line (training first, then validation).
print(len(train_loader), len(validate_loader), sep='\n')

185420
46360


In [8]:
# Tensor sizes of the first training sample: node features, target, edge index.
for attr_name in ('x', 'y', 'edge_index'):
    print(getattr(train_loader[0], attr_name).size())

torch.Size([1431, 1])
torch.Size([1, 1])
torch.Size([2, 488166])


## Train GCN

In [None]:
# Training hyperparameters.
train_batch_size = 64
validate_batch_size = 64
epochs = 100

# Build the model on the device chosen in the configuration cell
# (previously a second, hardcoded torch.device('cuda')).
model = MultiGCNInferenceNetwork2().to(devices)

# Per-epoch loss histories; the plotting cell at the end reads these lists.
train_loss_lis = []
validate_loss_lis = []

# One checkpoint folder per notebook session (DATE is fixed in the config cell).
# exist_ok=True lets this cell be re-run without raising FileExistsError.
folder = 'checkpoints_{}'.format(DATE)
os.makedirs(folder, exist_ok=True)

# Start a fresh loss log containing only the header. On a re-run into the same
# folder this drops the stale rows, matching how epoch{N}.pkl files get overwritten.
log_path = os.path.join(folder, 'checkpoints.csv')
with open(log_path, 'w', newline='') as f:
    csv.writer(f).writerow(['epoch', 'train_loss', 'val_loss'])

for epoch in range(epochs):
    # Run one training pass and record the value multi_train returns.
    train_loss = multi_train(model, train_loader, train_batch_size)
    train_loss_lis.append(train_loss)

    # Evaluate on the validation set.
    val_loss = validate(model, validate_loader, validate_batch_size)
    validate_loss_lis.append(val_loss)
    print('Validation loss: {}'.format(val_loss))

    # Checkpoint the weights for this epoch.
    torch.save(model.state_dict(), os.path.join(folder, 'epoch{}.pkl'.format(epoch)))

    # Append this epoch's losses right away so an interrupted run still
    # keeps a log of every completed epoch.
    with open(log_path, 'a', newline='') as f:
        csv.writer(f).writerow([epoch, train_loss, val_loss])

100%|██████████| 2897/2897 [36:06<00:00,  1.34it/s, loss=0.00062] 
  0%|          | 0/2897 [00:00<?, ?it/s]

Validation loss: 5.5234519263096685e-06


100%|██████████| 2897/2897 [36:09<00:00,  1.34it/s, loss=1.29e-6]
  0%|          | 0/2897 [00:00<?, ?it/s]

Validation loss: 1.20420661954273e-06


100%|██████████| 2897/2897 [36:16<00:00,  1.33it/s, loss=1.23e-6]
  0%|          | 0/2897 [00:00<?, ?it/s]

Validation loss: 1.2023390545327864e-06


100%|██████████| 2897/2897 [36:15<00:00,  1.33it/s, loss=1.23e-6]
  0%|          | 0/2897 [00:00<?, ?it/s]

Validation loss: 1.203420621270737e-06


100%|██████████| 2897/2897 [36:24<00:00,  1.33it/s, loss=1.23e-6]
  0%|          | 0/2897 [00:00<?, ?it/s]

Validation loss: 1.2012008545553012e-06


100%|██████████| 2897/2897 [36:26<00:00,  1.33it/s, loss=1.23e-6]
  0%|          | 0/2897 [00:00<?, ?it/s]

Validation loss: 1.2007800270326869e-06


100%|██████████| 2897/2897 [36:34<00:00,  1.32it/s, loss=1.23e-6]
  0%|          | 0/2897 [00:00<?, ?it/s]

Validation loss: 1.2007734394378833e-06


100%|██████████| 2897/2897 [36:27<00:00,  1.32it/s, loss=1.22e-6]
  0%|          | 0/2897 [00:00<?, ?it/s]

Validation loss: 1.1982486896731963e-06


100%|██████████| 2897/2897 [36:47<00:00,  1.31it/s, loss=1.22e-6]
  0%|          | 0/2897 [00:00<?, ?it/s]

Validation loss: 1.1949761870605327e-06


100%|██████████| 2897/2897 [36:47<00:00,  1.31it/s, loss=1.22e-6]
  0%|          | 0/2897 [00:00<?, ?it/s]

Validation loss: 1.198888842002132e-06


100%|██████████| 2897/2897 [36:48<00:00,  1.31it/s, loss=1.22e-6]
  0%|          | 0/2897 [00:00<?, ?it/s]

Validation loss: 1.195950263021001e-06


100%|██████████| 2897/2897 [36:28<00:00,  1.32it/s, loss=1.22e-6]
  0%|          | 0/2897 [00:00<?, ?it/s]

Validation loss: 1.195666987857086e-06


100%|██████████| 2897/2897 [36:49<00:00,  1.31it/s, loss=1.22e-6]
  0%|          | 0/2897 [00:00<?, ?it/s]

Validation loss: 1.1938522851755304e-06


100%|██████████| 2897/2897 [36:53<00:00,  1.31it/s, loss=1.22e-6]
  0%|          | 0/2897 [00:00<?, ?it/s]

Validation loss: 1.2062426675334635e-06


100%|██████████| 2897/2897 [36:51<00:00,  1.31it/s, loss=1.22e-6]
  0%|          | 0/2897 [00:00<?, ?it/s]

Validation loss: 1.1924532797802673e-06


 54%|█████▍    | 1568/2897 [20:10<16:54,  1.31it/s, loss=1.24e-6]

In [None]:
# Plot the per-epoch loss histories collected by the training loop.
# Validation is drawn first so the curves keep the colors they had before.
fig, ax = plt.subplots()
ax.plot(range(len(validate_loss_lis)), validate_loss_lis, label='validation')
ax.plot(range(len(train_loss_lis)), train_loss_lis, label='train')

# Label everything so the figure can be read on its own.
ax.set_xlabel('epoch')
ax.set_ylabel('loss')
ax.set_title('GCN training and validation loss')
# Log scale: the first-epoch loss is several times larger than later epochs
# (see the printed validation losses) and would otherwise flatten the rest of the curve.
ax.set_yscale('log')
ax.legend()
plt.show()