In [None]:
# https://github.com/brandonwz/ml4gfinal

In [None]:
# %%writefile train.py
from hmnet import ConvTransNet, BetterConvNet, BetterConvPoolNet, SimpleConvNet, TransformerNoConv
from data_reader import HisModDataset

import torch
import torch.nn as nn

import scipy
from sklearn import metrics

import numpy as np

import matplotlib.pyplot as plt

import os, sys

from timeit import default_timer

#Use a trial name that hasn't been used before -- models will get saved here 
TRIAL_NAME = "11betterConv2geneTest"
TRIAL_DIR = "./checkpoints/" + TRIAL_NAME

print("====== GPU Info ======")

print("cuda available:", torch.cuda.is_available())

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("device:", DEVICE)

print("======================")

MAIN_DIR = "./ProcessedData/"

def train(
	hmnet, 
	train_loader, 
	val_loader, 
	checkpoint_name = "", 
	epoch = 10
	):

	optim = torch.optim.Adam(hmnet.parameters(), amsgrad=True)

	checkpoint = "./checkpoints/" + checkpoint_name

	min_val_mse = float('inf')

	torch.set_grad_enabled(True)
	hmnet.train()

	print("train...")
	c = 0
	for i in range(epoch):
		start_time = default_timer()
		for x1, x2, y in train_loader:

			x1 = x1.to(DEVICE)
			x2 = x2.to(DEVICE)
			y = y.to(DEVICE)

			#Take diff of gene seq feature tensors 
			input_mat = x1 - x2
			input_mat = input_mat.float()

			pred = hmnet(input_mat)
			pred = pred.squeeze()

			loss = nn.MSELoss()

			y = y.float()
			y = y.squeeze()

			out = loss(pred, y)

			out.backward()

			optim.step()
			optim.zero_grad()

		mse, r2, _ = eval(val_loader, hmnet)

		epoch_time = default_timer() - start_time
		print("epoch:", (i+1))
		print("time taken:", epoch_time)
		print("val pcc:", r2)

		#Save model with best performance on validation set in terms of MSE
		if(mse < min_val_mse):
			min_val_mse = mse
			torch.save(hmnet.state_dict(), checkpoint)

		torch.set_grad_enabled(True)
		hmnet.train()	

	print("finished training")

	#Load and return best model across epochs
	hmnet.load_state_dict(torch.load(checkpoint))

	return hmnet

#Evaluates model on dataset
def eval(test_data, model):
	torch.set_grad_enabled(False)
	model.eval()
	pred_list = []; label_list = []
	for x1, x2, y in test_data:

		x1 = x1.to(DEVICE)
		x2 = x2.to(DEVICE)
		y = y.to(DEVICE)

		input_mat = x1 - x2
		input_mat = input_mat.float()

		pred = model(input_mat)
		y = y.float()

		pred_list.append(pred.squeeze().item())
		label_list.append(y.squeeze().item())

	pred_list = np.array(pred_list)
	label_list = np.array(label_list)

	#print(pred_list[0:10])

	#From https://github.com/QData/DeepDiffChrome
	R2,p=scipy.stats.pearsonr(label_list, pred_list)
	MSE=metrics.mean_squared_error(label_list, pred_list)
	return MSE, R2, p

def graph_results(pccs):
	labels = ["E123-E003", 
	"E116-E003", 
	"E123-E116", 
	"E003-E005", 
	"E003-E006",
	"E006-E007",
	"E005-E006",
	"E003-E004",
	"E004-E006",
	"E037-E038"]

	ticks = np.arange(10) + 1

	plt.figure()
	plt.xticks(ticks, labels, rotation = 45)
	plt.plot(ticks, pccs, linestyle='dashed', marker='o')
	plt.ylabel("Pearson Correlation Coefficient")
	plt.xlabel("Cell Pairs")
	plt.ylim(0, 1)
	plt.show()


if __name__ == '__main__':
	cell_pairs = [
	["E123", "E003"],
	["E116", "E003"],
	["E123", "E116"],
	["E003", "E005"],
	["E003", "E006"],
	["E006", "E007"],
	["E005", "E006"],
	["E003", "E004"],
	["E004", "E006"],
	["E037", "E038"]
	]

	#Make sure trial name is unique
	if(os.path.exists(TRIAL_DIR)):
		print("Error: trial already exists. Please choose a different name.")
		sys.exit()
	else:
		os.makedirs(TRIAL_DIR)

	for cell_pair in cell_pairs:
		TRIAL_PAIR_NAME = TRIAL_NAME + "_" + cell_pair[0] + "_" + cell_pair[1]
		TRIAL_SAVE = TRIAL_NAME + "/" + TRIAL_PAIR_NAME
		
		print("=======CELL PAIR: " + str(cell_pair) + "========")
		
		cellA_expr_file = cell_pair[0] + ".expr.csv"
		cellA_file = cell_pair[0] + ".train.csv"
		cellB_file = cell_pair[1] + ".train.csv"
		cellB_expr_file = cell_pair[1] + ".expr.csv"

		cellA_val = cell_pair[0] + ".valid.csv"
		cellB_val = cell_pair[1] + ".valid.csv"

		hmnet = ConvTransNet() #BetterConvNet() #Change me to one of the available models
		hmnet = hmnet.to(DEVICE)

		print("loading data...")

		#Load train data
		dataset = HisModDataset(
			cellA_file, 
			cellA_expr_file, 
			cellB_file, 
			cellB_expr_file, 
			MAIN_DIR
		)
		
		#Load val data
		val_data = HisModDataset(
			cellA_val, 
			cellA_expr_file, 
			cellB_val, 
			cellB_expr_file, 
			MAIN_DIR
		)

		print("data loaded!")

		dataloader = torch.utils.data.DataLoader(dataset)
		val_loader = torch.utils.data.DataLoader(val_data)

		hmnet = train(hmnet, dataloader, val_loader = val_loader, checkpoint_name = TRIAL_SAVE)

		MSE, R2, p = eval(dataloader, hmnet)
		print("eval on train set:", R2, p)

		cellA_expr_file = cell_pair[0] + ".expr.csv"
		cellA_file = cell_pair[0] + ".test.csv"
		cellB_file = cell_pair[1] + ".test.csv"
		cellB_expr_file = cell_pair[1] + ".expr.csv"

		#Load test data 
		dataset = HisModDataset(
			cellA_file, 
			cellA_expr_file, 
			cellB_file, 
			cellB_expr_file, 
			MAIN_DIR, 
			ignore_B = False
		)

		dataloader = torch.utils.data.DataLoader(dataset)


		MSE, R2, p = eval(dataloader, hmnet)
		print("eval on test set: ", R2, p)

	


In [1]:
import os
import sys
import torch
class HisModDataset(torch.utils.data.Dataset):
    def __init__(
                self, 
                cellA_file, 
                cellA_expr_file, 
                cellB_file, 
                cellB_expr_file, 
                main_dir, 
                ignore_B=False,
                shuffle_cols = False #experimental
                ):
        
        print('in constructor')

        cell_cols = ["A", "B", "C", "D", "E", "F"] #dummy cols
        expr_cols = ["A", "B"] #dummy cols 

        #read in data into dataframes
        cellA_df = pd.read_csv(main_dir + cellA_file, names=cell_cols)
        cellB_df = pd.read_csv(main_dir + cellB_file, names=cell_cols)
        cellA_expr_df = pd.read_csv(main_dir + cellA_expr_file, names=expr_cols)
        cellB_expr_df = pd.read_csv(main_dir + cellB_expr_file, names=expr_cols)

        #200 entries per gene
        self.offset = 200
        self.length = len(cellA_df)//self.offset

        #all cols except one with gene info
        hm_cols = ["B", "C", "D", "E", "F"]

        #convert to tensor
        self.cellA_tensor = torch.tensor(cellA_df[hm_cols].values)
        self.cellB_tensor = torch.tensor(cellB_df[hm_cols].values)

        #using expr file, maps gene id -> expr value
        self.gene_to_valA = dict(zip(cellA_expr_df.A, cellA_expr_df.B))
        self.gene_to_valB = dict(zip(cellB_expr_df.A, cellB_expr_df.B))

        #extract gene info col from the train/test/val data
        self.geneA_names = cellA_df["A"]
        self.geneB_names = cellB_df["A"]

        self.ignore_B = ignore_B
        self.shuffle_cols = shuffle_cols

    def __getitem__(self, idx):
        print('in get items',idx)
        idx = idx*self.offset

        #find the relevant slice of the tensor based on the id
        tensorA = self.cellA_tensor[idx:idx+self.offset]
        tensorB = self.cellB_tensor[idx:idx+self.offset]

        #transpose b/c data is initially 200x5
        tensorA = torch.transpose(tensorA, 0, 1)
        tensorB = torch.transpose(tensorB, 0, 1)

        if(self.shuffle_cols): #experimental
            tensorA = tensorA[:, torch.randperm(tensorA.size()[1])]
            tensorB = tensorB[:, torch.randperm(tensorB.size()[1])]

        #just takes the first gene window id corresponding to the 
        #group of 200 and uses it to find the gene id
        geneA = self.geneA_names[idx].split("_")[0]
        geneB = self.geneB_names[idx].split("_")[0]

        #print(geneA); print(geneB) #<-- should be the same

        #Get the expression values for cell A and cell B
        cA = self.gene_to_valA[geneA]
        cB = self.gene_to_valB[geneB]

        if(self.ignore_B):
            return tensorA, tensorB, cA

		#Find the log-fold change
        label = self.getlabel(cA, cB) 

        return tensorA, tensorB, label[0]

    def __len__(self):
        return self.length

	#From https://github.com/QData/DeepDiffChrome/blob/master/data.py
    def getlabel(self, c1,c2):
        print('in get label')
        # get log fold change of expression

        label1=math.log((float(c1)+1.0),2)
        label2=math.log((float(c2)+1.0),2)
        label=[]
        label.append(label1)
        label.append(label2)

        fold_change=(float(c2)+1.0)/(float(c1)+1.0)
        log_fold_change=math.log((fold_change),2)
        return (log_fold_change, label)

In [2]:
import os
import sys
import pandas as pd
# from data_reader import HisModDataset
TRIAL_NAME = "11awawqswa21aas"
TRIAL_DIR = "./checkpoints/" + TRIAL_NAME
MAIN_DIR = "./ProcessedData/"
cell_pairs = [
                ["E123", "E003"]
#                 ["E116", "E003"],
#                 ["E123", "E116"],
#                 ["E003", "E005"],
#                 ["E003", "E006"],
#                 ["E006", "E007"],
#                 ["E005", "E006"],
#                 ["E003", "E004"],
#                 ["E004", "E006"],
#                 ["E037", "E038"]
            ]

#Make sure trial name is unique
if(os.path.exists(TRIAL_DIR)):
    print("Error: trial already exists. Please choose a different name.")
    sys.exit()
else:
    os.makedirs(TRIAL_DIR)

for cell_pair in cell_pairs:
    TRIAL_PAIR_NAME = TRIAL_NAME + "_" + cell_pair[0] + "_" + cell_pair[1]
    TRIAL_SAVE = TRIAL_NAME + "/" + TRIAL_PAIR_NAME

    print("=======CELL PAIR: " + str(cell_pair) + "========")

    cellA_expr_file = cell_pair[0] + ".expr.csv"
    cellA_file = cell_pair[0] + ".train.csv"
    cellB_file = cell_pair[1] + ".train.csv"
    cellB_expr_file = cell_pair[1] + ".expr.csv"

    cellA_val = cell_pair[0] + ".valid.csv"
    cellB_val = cell_pair[1] + ".valid.csv"

#     hmnet = ConvTransNet() #BetterConvNet() #Change me to one of the available models
#     hmnet = hmnet.to(DEVICE)

    print("loading data...")

    #Load train data
    dataset = HisModDataset(
        cellA_file, 
        cellA_expr_file, 
        cellB_file, 
        cellB_expr_file, 
        MAIN_DIR
    )

#     #Load val data
#     val_data = HisModDataset(
#         cellA_val, 
#         cellA_expr_file, 
#         cellB_val, 
#         cellB_expr_file, 
#         MAIN_DIR
#     )

    print("data loaded!")


loading data...
in constructor
data loaded!


In [3]:
dataloader1 = torch.utils.data.DataLoader(dataset)

In [4]:
import math

In [6]:
for x1, x2, y in dataloader1:
    print(y)

in get items 0
in get label
tensor([2.7388], dtype=torch.float64)
in get items 1
in get label
tensor([1.2595], dtype=torch.float64)
in get items 2
in get label
tensor([-7.9794], dtype=torch.float64)
in get items 3
in get label
tensor([1.5605], dtype=torch.float64)
in get items 4
in get label
tensor([3.0940], dtype=torch.float64)
in get items 5
in get label
tensor([-0.2579], dtype=torch.float64)
in get items 6
in get label
tensor([-0.6517], dtype=torch.float64)
in get items 7
in get label
tensor([5.9305], dtype=torch.float64)
in get items 8
in get label
tensor([-0.2507], dtype=torch.float64)
in get items 9
in get label
tensor([-1.7564], dtype=torch.float64)
in get items 10
in get label
tensor([-0.5049], dtype=torch.float64)
in get items 11
in get label
tensor([-0.6534], dtype=torch.float64)
in get items 12
in get label
tensor([-2.5250], dtype=torch.float64)
in get items 13
in get label
tensor([-0.7851], dtype=torch.float64)
in get items 14
in get label
tensor([0.0951], dtype=torch.float

tensor([0.5163], dtype=torch.float64)
in get items 385
in get label
tensor([10.6222], dtype=torch.float64)
in get items 386
in get label
tensor([1.7546], dtype=torch.float64)
in get items 387
in get label
tensor([1.6395], dtype=torch.float64)
in get items 388
in get label
tensor([5.1293], dtype=torch.float64)
in get items 389
in get label
tensor([-0.7075], dtype=torch.float64)
in get items 390
in get label
tensor([3.5063], dtype=torch.float64)
in get items 391
in get label
tensor([8.9513], dtype=torch.float64)
in get items 392
in get label
tensor([-0.0345], dtype=torch.float64)
in get items 393
in get label
tensor([1.2374], dtype=torch.float64)
in get items 394
in get label
tensor([-0.7759], dtype=torch.float64)
in get items 395
in get label
tensor([1.2153], dtype=torch.float64)
in get items 396
in get label
tensor([-1.9509], dtype=torch.float64)
in get items 397
in get label
tensor([-1.0150], dtype=torch.float64)
in get items 398
in get label
tensor([-3.7830], dtype=torch.float64)
in 

in get label
tensor([-4.0393], dtype=torch.float64)
in get items 782
in get label
tensor([6.4094], dtype=torch.float64)
in get items 783
in get label
tensor([-0.3657], dtype=torch.float64)
in get items 784
in get label
tensor([-3.1393], dtype=torch.float64)
in get items 785
in get label
tensor([0.2385], dtype=torch.float64)
in get items 786
in get label
tensor([-0.7500], dtype=torch.float64)
in get items 787
in get label
tensor([-2.7061], dtype=torch.float64)
in get items 788
in get label
tensor([0.0218], dtype=torch.float64)
in get items 789
in get label
tensor([3.], dtype=torch.float64)
in get items 790
in get label
tensor([-1.6924], dtype=torch.float64)
in get items 791
in get label
tensor([3.2740], dtype=torch.float64)
in get items 792
in get label
tensor([-0.4340], dtype=torch.float64)
in get items 793
in get label
tensor([-1.9386], dtype=torch.float64)
in get items 794
in get label
tensor([2.4002], dtype=torch.float64)
in get items 795
in get label
tensor([-1.7256], dtype=torch.f

tensor([9.2021], dtype=torch.float64)
in get items 1176
in get label
tensor([5.4263], dtype=torch.float64)
in get items 1177
in get label
tensor([-1.3785], dtype=torch.float64)
in get items 1178
in get label
tensor([4.1066], dtype=torch.float64)
in get items 1179
in get label
tensor([-0.8675], dtype=torch.float64)
in get items 1180
in get label
tensor([-0.2486], dtype=torch.float64)
in get items 1181
in get label
tensor([1.7341], dtype=torch.float64)
in get items 1182
in get label
tensor([1.8213], dtype=torch.float64)
in get items 1183
in get label
tensor([2.5220], dtype=torch.float64)
in get items 1184
in get label
tensor([-1.1233], dtype=torch.float64)
in get items 1185
in get label
tensor([-2.7503], dtype=torch.float64)
in get items 1186
in get label
tensor([6.6724], dtype=torch.float64)
in get items 1187
in get label
tensor([-2.1179], dtype=torch.float64)
in get items 1188
in get label
tensor([0.], dtype=torch.float64)
in get items 1189
in get label
tensor([2.4699], dtype=torch.flo

in get label
tensor([-1.9321], dtype=torch.float64)
in get items 1592
in get label
tensor([-1.2173], dtype=torch.float64)
in get items 1593
in get label
tensor([4.3219], dtype=torch.float64)
in get items 1594
in get label
tensor([-0.1895], dtype=torch.float64)
in get items 1595
in get label
tensor([8.9298], dtype=torch.float64)
in get items 1596
in get label
tensor([-1.4022], dtype=torch.float64)
in get items 1597
in get label
tensor([3.2090], dtype=torch.float64)
in get items 1598
in get label
tensor([2.4659], dtype=torch.float64)
in get items 1599
in get label
tensor([-0.8950], dtype=torch.float64)
in get items 1600
in get label
tensor([10.9687], dtype=torch.float64)
in get items 1601
in get label
tensor([9.1015], dtype=torch.float64)
in get items 1602
in get label
tensor([-0.3754], dtype=torch.float64)
in get items 1603
in get label
tensor([0.9097], dtype=torch.float64)
in get items 1604
in get label
tensor([-2.8028], dtype=torch.float64)
in get items 1605
in get label
tensor([-0.53

tensor([-0.2307], dtype=torch.float64)
in get items 1976
in get label
tensor([-0.3674], dtype=torch.float64)
in get items 1977
in get label
tensor([-2.5212], dtype=torch.float64)
in get items 1978
in get label
tensor([0.6444], dtype=torch.float64)
in get items 1979
in get label
tensor([0.1795], dtype=torch.float64)
in get items 1980
in get label
tensor([-0.1465], dtype=torch.float64)
in get items 1981
in get label
tensor([-0.4667], dtype=torch.float64)
in get items 1982
in get label
tensor([-2.0803], dtype=torch.float64)
in get items 1983
in get label
tensor([0.4621], dtype=torch.float64)
in get items 1984
in get label
tensor([0.4804], dtype=torch.float64)
in get items 1985
in get label
tensor([3.0810], dtype=torch.float64)
in get items 1986
in get label
tensor([-1.6743], dtype=torch.float64)
in get items 1987
in get label
tensor([0.3932], dtype=torch.float64)
in get items 1988
in get label
tensor([2.4010], dtype=torch.float64)
in get items 1989
in get label
tensor([-1.1677], dtype=tor

tensor([1.7741], dtype=torch.float64)
in get items 2413
in get label
tensor([-1.3295], dtype=torch.float64)
in get items 2414
in get label
tensor([-2.7912], dtype=torch.float64)
in get items 2415
in get label
tensor([-2.0822], dtype=torch.float64)
in get items 2416
in get label
tensor([1.3257], dtype=torch.float64)
in get items 2417
in get label
tensor([-1.0117], dtype=torch.float64)
in get items 2418
in get label
tensor([2.2563], dtype=torch.float64)
in get items 2419
in get label
tensor([0.7224], dtype=torch.float64)
in get items 2420
in get label
tensor([-1.0679], dtype=torch.float64)
in get items 2421
in get label
tensor([2.], dtype=torch.float64)
in get items 2422
in get label
tensor([-0.7195], dtype=torch.float64)
in get items 2423
in get label
tensor([8.0945], dtype=torch.float64)
in get items 2424
in get label
tensor([4.0875], dtype=torch.float64)
in get items 2425
in get label
tensor([-3.3977], dtype=torch.float64)
in get items 2426
in get label
tensor([-0.1199], dtype=torch.f

tensor([0.], dtype=torch.float64)
in get items 2833
in get label
tensor([0.], dtype=torch.float64)
in get items 2834
in get label
tensor([6.2479], dtype=torch.float64)
in get items 2835
in get label
tensor([-1.3414], dtype=torch.float64)
in get items 2836
in get label
tensor([-4.7549], dtype=torch.float64)
in get items 2837
in get label
tensor([-0.4731], dtype=torch.float64)
in get items 2838
in get label
tensor([-1.3219], dtype=torch.float64)
in get items 2839
in get label
tensor([7.8473], dtype=torch.float64)
in get items 2840
in get label
tensor([0.3350], dtype=torch.float64)
in get items 2841
in get label
tensor([0.1252], dtype=torch.float64)
in get items 2842
in get label
tensor([-0.6509], dtype=torch.float64)
in get items 2843
in get label
tensor([0.2764], dtype=torch.float64)
in get items 2844
in get label
tensor([-2.0211], dtype=torch.float64)
in get items 2845
in get label
tensor([3.5508], dtype=torch.float64)
in get items 2846
in get label
tensor([0.6632], dtype=torch.float64

in get label
tensor([0.], dtype=torch.float64)
in get items 3244
in get label
tensor([2.0128], dtype=torch.float64)
in get items 3245
in get label
tensor([-1.1699], dtype=torch.float64)
in get items 3246
in get label
tensor([10.2992], dtype=torch.float64)
in get items 3247
in get label
tensor([12.5433], dtype=torch.float64)
in get items 3248
in get label
tensor([1.5254], dtype=torch.float64)
in get items 3249
in get label
tensor([8.9513], dtype=torch.float64)
in get items 3250
in get label
tensor([0.4711], dtype=torch.float64)
in get items 3251
in get label
tensor([12.9726], dtype=torch.float64)
in get items 3252
in get label
tensor([16.2651], dtype=torch.float64)
in get items 3253
in get label
tensor([6.8549], dtype=torch.float64)
in get items 3254
in get label
tensor([2.8557], dtype=torch.float64)
in get items 3255
in get label
tensor([10.0269], dtype=torch.float64)
in get items 3256
in get label
tensor([0.6956], dtype=torch.float64)
in get items 3257
in get label
tensor([-1.6503], d

tensor([-0.2957], dtype=torch.float64)
in get items 3657
in get label
tensor([-0.5570], dtype=torch.float64)
in get items 3658
in get label
tensor([6.5236], dtype=torch.float64)
in get items 3659
in get label
tensor([0.8075], dtype=torch.float64)
in get items 3660
in get label
tensor([7.3249], dtype=torch.float64)
in get items 3661
in get label
tensor([-0.1255], dtype=torch.float64)
in get items 3662
in get label
tensor([-1.7946], dtype=torch.float64)
in get items 3663
in get label
tensor([1.6507], dtype=torch.float64)
in get items 3664
in get label
tensor([-3.9827], dtype=torch.float64)
in get items 3665
in get label
tensor([-1.6060], dtype=torch.float64)
in get items 3666
in get label
tensor([1.], dtype=torch.float64)
in get items 3667
in get label
tensor([6.4026], dtype=torch.float64)
in get items 3668
in get label
tensor([0.7734], dtype=torch.float64)
in get items 3669
in get label
tensor([0.2257], dtype=torch.float64)
in get items 3670
in get label
tensor([-1.0385], dtype=torch.fl

KeyboardInterrupt: 