In [1]:
from ds_util import get_dataset
from perceptron_model import MLPModel
import numpy as np
import torch
import h5py
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter
from scipy.stats import spearmanr, pearsonr
from matplotlib import pyplot as plt

In [2]:
f = h5py.File('/clusterfs/nilah/oberon/datasets/basenji/embeddings/embeddings.h5','r')



In [3]:
dset = f['embeddings']

train_inds = np.zeros(len(dset), dtype=bool)
val_inds = np.zeros(len(dset), dtype=bool)
test_inds = np.zeros(len(dset), dtype=bool)

In [4]:
train_inds[:34021]=True
val_inds[34021:36234]=True
test_inds[36234:]=True

In [5]:
labels_full = h5py.File('/clusterfs/nilah/oberon/datasets/cs282a/dataset_14-lmnb1_4-cpg.h5')['single_bin']

In [6]:
def pearson_correlation(prediction, target):
    # Flatten the tensors to 1D
    prediction_flat = prediction.view(-1).cpu().detach().numpy()
    target_flat = target.view(-1).cpu().detach().numpy()

    # Calculate Pearson's correlation
    corr, _ = pearsonr(prediction_flat, target_flat)
    return corr

def spearman_correlation(prediction,target):
    # Flatten the tensors to 1D
    prediction_flat = prediction.view(-1).cpu().detach().numpy()
    target_flat = target.view(-1).cpu().detach().numpy()

    # Calculate Spearman's correlation
    corr, _ = spearmanr(prediction_flat, target_flat)
    return corr    

In [7]:
start_load = datetime.now()
print('starting at ',start_load)
val_tensor_dset = torch.utils.data.TensorDataset(torch.Tensor(dset[val_inds]), torch.Tensor(labels_full[val_inds]))
print(datetime.now()-start_load)
test_tensor_dset = torch.utils.data.TensorDataset(torch.Tensor(dset[test_inds]), torch.Tensor(labels_full[test_inds]))
print(datetime.now()-start_load)
train_tensor_dset = torch.utils.data.TensorDataset(torch.Tensor(dset[train_inds]), torch.Tensor(labels_full[train_inds]))
print(datetime.now()-start_load)
training_loader = torch.utils.data.DataLoader(train_tensor_dset, batch_size=4, shuffle=True)
print(datetime.now()-start_load)
validation_loader = torch.utils.data.DataLoader(val_tensor_dset, batch_size=4, shuffle=False)
print(datetime.now()-start_load)
test_loader = torch.utils.data.DataLoader(test_tensor_dset, batch_size=4, shuffle=False)
print(datetime.now()-start_load)

starting at  2023-11-27 19:37:59.952691
0:01:35.625081
0:02:52.873978
0:27:54.498324
0:27:54.503382
0:27:54.503570
0:27:54.504362


In [13]:
curr_model = MLPModel()
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(curr_model.parameters(), lr=0.000005, momentum=0.8)

In [14]:
def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.
    
    all_predictions = []
    all_labels = []
    
    for i, data in enumerate(training_loader):
        inputs, labels = data

        optimizer.zero_grad()
        outputs = curr_model(inputs.transpose(1,2))

        loss = loss_fn(outputs, labels.squeeze(1))
        loss.backward()
        # Apply torch.nan_to_num to gradients
        for param in curr_model.parameters():
            if param.grad is not None:
                param.grad = torch.nan_to_num(param.grad)
                
        optimizer.step()
        running_loss += loss.item()
        
        all_predictions.append(torch.nan_to_num(outputs).detach())
        all_labels.append(torch.nan_to_num(labels).detach())


        
        if i % 250 == 249:
            last_loss = running_loss / 250
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(training_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.
            predictions_flat = torch.cat(all_predictions).view(-1)
            labels_flat = torch.cat(all_labels).view(-1)
            pcorr = pearson_correlation(predictions_flat, labels_flat)
            scorr = spearman_correlation(predictions_flat,labels_flat)
            print("  batch {} Pearson correlation: {}, Spearman correlation: {}".format(i + 1, pcorr.item(),scorr.item()))

            all_predictions = []
            all_labels = []
            
            

    return last_loss, pcorr, scorr

In [15]:
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/mlp_model_{}'.format(timestamp))

epoch_number = 0
EPOCHS = 5
best_vloss = 1_000_000.

epoch_loss, epoch_coeff = [], []

for epoch in range(EPOCHS):
    print('EPOCH {}\n-------------------------------'.format(epoch_number + 1))

    curr_model.train(True)
    avg_loss,pcorr,scorr = train_one_epoch(epoch_number, writer)
    running_vloss = 0.0
    curr_model.eval()
    
    all_vpredictions = []
    all_vlabels = []

    with torch.no_grad():
        for i, vdata in enumerate(validation_loader):
            vinputs, vlabels = vdata
            voutputs = curr_model(vinputs.transpose(1,2))
            vloss = loss_fn(voutputs, vlabels.squeeze(1))
            running_vloss += vloss
            
            all_vpredictions.append(voutputs)
            all_vlabels.append(vlabels)

    avg_vloss = running_vloss / (i + 1)
    
    predictions_flat = torch.cat(all_vpredictions).view(-1)
    labels_flat = torch.cat(all_vlabels).view(-1)
    pval_corr = pearson_correlation(predictions_flat, labels_flat)
    sval_corr = spearman_correlation(predictions_flat,labels_flat)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))
    print('Validation Pearson Correlation: {}, Spearman Correlation: {}'.format(pval_corr.item(),sval_corr.item()))
    
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()

    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(curr_model.state_dict(), model_path)

    epoch_loss.append((avg_loss, avg_vloss))
    epoch_coeff.append((scorr, sval_corr))
    
    epoch_number += 1

EPOCH 1
-------------------------------
  batch 250 loss: 3279.786724243164
  batch 250 Pearson correlation: 0.6894281268379249, Spearman correlation: 0.6523257050599491
  batch 500 loss: 538.5323609008789
  batch 500 Pearson correlation: 0.9412819787657906, Spearman correlation: 0.9079783060343478
  batch 750 loss: 514.6707052001954
  batch 750 Pearson correlation: 0.9443298605854294, Spearman correlation: 0.9147097326673045
  batch 1000 loss: nan
  batch 1000 Pearson correlation: 0.948649805484788, Spearman correlation: 0.9212098150777837
  batch 1250 loss: 488.90921905517575
  batch 1250 Pearson correlation: 0.9476468753485154, Spearman correlation: 0.9185681405327344
  batch 1500 loss: 465.7970531005859
  batch 1500 Pearson correlation: 0.949775832025745, Spearman correlation: 0.9211608049084798
  batch 1750 loss: 464.2482590332031
  batch 1750 Pearson correlation: 0.9505877611041958, Spearman correlation: 0.9236150204008549
  batch 2000 loss: 473.7239056396484
  batch 2000 Pearson

  batch 7000 loss: 392.3660195007324
  batch 7000 Pearson correlation: 0.9573841366447077, Spearman correlation: 0.9333758817968569
  batch 7250 loss: 407.49762774658205
  batch 7250 Pearson correlation: 0.9555485332182743, Spearman correlation: 0.9313954613067474
  batch 7500 loss: 414.966072265625
  batch 7500 Pearson correlation: 0.9554544053193875, Spearman correlation: 0.9311305866807961
  batch 7750 loss: 429.2652268066406
  batch 7750 Pearson correlation: 0.9540675841169267, Spearman correlation: 0.9314227650907891
  batch 8000 loss: nan
  batch 8000 Pearson correlation: 0.9571080840176451, Spearman correlation: 0.9319939117522714
  batch 8250 loss: 389.50419952392576
  batch 8250 Pearson correlation: 0.9583958631622511, Spearman correlation: 0.9386249382587175
  batch 8500 loss: 406.2132159423828
  batch 8500 Pearson correlation: 0.9559802432578623, Spearman correlation: 0.9327588409721504
LOSS train 406.2132159423828 valid 506.2987365722656
Validation Pearson Correlation: 0.94

  batch 5000 loss: 384.2023699645996
  batch 5000 Pearson correlation: 0.9586742654038378, Spearman correlation: 0.9385369600721981
  batch 5250 loss: 374.3661752319336
  batch 5250 Pearson correlation: 0.9602778612916012, Spearman correlation: 0.9407236252469822
  batch 5500 loss: 368.11929946899414
  batch 5500 Pearson correlation: 0.9595254005474174, Spearman correlation: 0.9398692500952738
  batch 5750 loss: 396.49502029418943
  batch 5750 Pearson correlation: 0.9576599945380139, Spearman correlation: 0.9365927927540821
  batch 6000 loss: 385.8930966796875
  batch 6000 Pearson correlation: 0.9590336619880777, Spearman correlation: 0.9370725052974254
  batch 6250 loss: 385.945269317627
  batch 6250 Pearson correlation: 0.9582135046574213, Spearman correlation: 0.9344298509529265
  batch 6500 loss: 372.2580586395264
  batch 6500 Pearson correlation: 0.9596703167882401, Spearman correlation: 0.9405581461749425
  batch 6750 loss: 386.50347763061524
  batch 6750 Pearson correlation: 0.9

In [31]:
print(epoch_coeff)

[(0.8023441110268037, 0.8198872713965447), (0.8579737297804155, 0.8433671005928176), (0.8603200804955299, 0.8590490294150954), (0.8825058004687943, 0.8827406891320867), (0.8728390414949895, 0.8743633067935532)]
