In [1]:
from ds_util import get_dataset
from model import MLPModel
import numpy as np
import torch
import h5py
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter
from scipy.stats import spearmanr, pearsonr
from matplotlib import pyplot as plt

In [2]:
f = h5py.File('/clusterfs/nilah/oberon/datasets/basenji/embeddings/embeddings.h5','r')



In [3]:
dset = f['embeddings']

train_inds = np.zeros(len(dset), dtype=bool)
val_inds = np.zeros(len(dset), dtype=bool)
test_inds = np.zeros(len(dset), dtype=bool)

In [6]:
train_inds[:34021]=True
val_inds[34021:36234]=True
test_inds[36234:]=True

In [7]:
labels_full = h5py.File('/clusterfs/nilah/oberon/datasets/cs282a/dataset_14-lmnb1_4-cpg.h5')['single_bin']

In [8]:
def pearson_correlation(prediction, target):
    # Flatten the tensors to 1D
    prediction_flat = prediction.view(-1).cpu().detach().numpy()
    target_flat = target.view(-1).cpu().detach().numpy()

    # Calculate Pearson's correlation
    corr, _ = pearsonr(prediction_flat, target_flat)
    return corr

def spearman_correlation(prediction,target):
    # Flatten the tensors to 1D
    prediction_flat = prediction.view(-1).cpu().detach().numpy()
    target_flat = target.view(-1).cpu().detach().numpy()

    # Calculate Spearman's correlation
    corr, _ = spearmanr(prediction_flat, target_flat)
    return corr    

In [9]:
start_load = datetime.now()
print('starting at ',start_load)
val_tensor_dset = torch.utils.data.TensorDataset(torch.Tensor(dset[val_inds]), torch.Tensor(labels_full[val_inds]))
print(datetime.now()-start_load)
test_tensor_dset = torch.utils.data.TensorDataset(torch.Tensor(dset[test_inds]), torch.Tensor(labels_full[test_inds]))
print(datetime.now()-start_load)
train_tensor_dset = torch.utils.data.TensorDataset(torch.Tensor(dset[train_inds]), torch.Tensor(labels_full[train_inds]))
print(datetime.now()-start_load)
training_loader = torch.utils.data.DataLoader(train_tensor_dset, batch_size=4, shuffle=True)
print(datetime.now()-start_load)
validation_loader = torch.utils.data.DataLoader(val_tensor_dset, batch_size=4, shuffle=False)
print(datetime.now()-start_load)
test_loader = torch.utils.data.DataLoader(test_tensor_dset, batch_size=4, shuffle=False)
print(datetime.now()-start_load)

starting at  2023-11-28 06:54:34.427071
0:01:34.659593
0:02:54.480318
0:27:20.080791
0:27:20.097129
0:27:20.102772
0:27:20.102925


In [10]:
curr_model = MLPModel()
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(curr_model.parameters(), lr=0.000005, momentum=0.8)

In [11]:
def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.
    
    all_predictions = []
    all_labels = []
    
    for i, data in enumerate(training_loader):
        inputs, labels = data

        optimizer.zero_grad()
        outputs = curr_model(inputs.transpose(1,2))

        loss = loss_fn(outputs, labels.squeeze(1))
        loss.backward()
        # Apply torch.nan_to_num to gradients
        for param in curr_model.parameters():
            if param.grad is not None:
                param.grad = torch.nan_to_num(param.grad)
                
        optimizer.step()
        running_loss += loss.item()
        
        all_predictions.append(torch.nan_to_num(outputs).detach())
        all_labels.append(torch.nan_to_num(labels).detach())


        
        if i % 250 == 249:
            last_loss = running_loss / 250
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(training_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.
            predictions_flat = torch.cat(all_predictions).view(-1)
            labels_flat = torch.cat(all_labels).view(-1)
            pcorr = pearson_correlation(predictions_flat, labels_flat)
            scorr = spearman_correlation(predictions_flat,labels_flat)
            print("  batch {} Pearson correlation: {}, Spearman correlation: {}".format(i + 1, pcorr.item(),scorr.item()))
            
            all_predictions = []
            all_labels = []
            
            

    return last_loss, pcorr, scorr

In [12]:
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/mlp_model_{}'.format(timestamp))

epoch_number = 0
EPOCHS = 5
best_vloss = 1_000_000.

epoch_loss, epoch_coeff = [], []

for epoch in range(EPOCHS):
    print('EPOCH {}\n-------------------------------'.format(epoch_number + 1))

    curr_model.train(True)
    avg_loss,pcorr,scorr = train_one_epoch(epoch_number, writer)
    running_vloss = 0.0
    curr_model.eval()
    
    all_vpredictions = []
    all_vlabels = []

    with torch.no_grad():
        for i, vdata in enumerate(validation_loader):
            vinputs, vlabels = vdata
            voutputs = curr_model(vinputs.transpose(1,2))
            vloss = loss_fn(voutputs, vlabels.squeeze(1))
            running_vloss += vloss
            
            all_vpredictions.append(voutputs)
            all_vlabels.append(vlabels)

    avg_vloss = running_vloss / (i + 1)
    
    predictions_flat = torch.cat(all_vpredictions).view(-1)
    labels_flat = torch.cat(all_vlabels).view(-1)
    pval_corr = pearson_correlation(predictions_flat, labels_flat)
    sval_corr = spearman_correlation(predictions_flat,labels_flat)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))
    print('Validation Pearson Correlation: {}, Spearman Correlation: {}'.format(pval_corr.item(),sval_corr.item()))
    
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()

    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(curr_model.state_dict(), model_path)

    epoch_loss.append((avg_loss, avg_vloss))
    epoch_coeff.append((scorr, sval_corr))
    
    epoch_number += 1

EPOCH 1
-------------------------------
  batch 250 loss: 6258.500887939453
  batch 250 Pearson correlation: 0.4885842766722295, Spearman correlation: 0.41591395644651225
  batch 500 loss: 1528.721437133789
  batch 500 Pearson correlation: 0.8308102605076149, Spearman correlation: 0.7209943193161956
  batch 750 loss: 1267.768804321289
  batch 750 Pearson correlation: 0.8596376582232973, Spearman correlation: 0.7696116168773964
  batch 1000 loss: 1102.4890606689453
  batch 1000 Pearson correlation: 0.8790601997473338, Spearman correlation: 0.808449254492517
  batch 1250 loss: 971.8797003173828
  batch 1250 Pearson correlation: 0.892697735618405, Spearman correlation: 0.8254352372660414
  batch 1500 loss: 927.9501573486328
  batch 1500 Pearson correlation: 0.8977073741979487, Spearman correlation: 0.8257128052160804
  batch 1750 loss: 920.7919936523438
  batch 1750 Pearson correlation: 0.8972298209505021, Spearman correlation: 0.8289878942558431
  batch 2000 loss: 851.7142744140625
  bat

  batch 7000 loss: 535.1036639404297
  batch 7000 Pearson correlation: 0.9425185073643791, Spearman correlation: 0.9122064434266535
  batch 7250 loss: 504.6178358154297
  batch 7250 Pearson correlation: 0.9455133195706107, Spearman correlation: 0.9166542642065639
  batch 7500 loss: 506.56117993164065
  batch 7500 Pearson correlation: 0.9457576360238681, Spearman correlation: 0.913790573669175
  batch 7750 loss: 559.8061134643555
  batch 7750 Pearson correlation: 0.9408207443896559, Spearman correlation: 0.9116657101295798
  batch 8000 loss: 503.13255047607424
  batch 8000 Pearson correlation: 0.9463194630345492, Spearman correlation: 0.9179569730048922
  batch 8250 loss: 519.1504570007324
  batch 8250 Pearson correlation: 0.9436020770489084, Spearman correlation: 0.9127364729756144
  batch 8500 loss: 519.7018712768555
  batch 8500 Pearson correlation: 0.9436869552669198, Spearman correlation: 0.9147040410992655
LOSS train 519.7018712768555 valid 636.3772583007812
Validation Pearson Cor

  batch 5000 loss: 486.3881654052734
  batch 5000 Pearson correlation: 0.9476562173682305, Spearman correlation: 0.9183542274551276
  batch 5250 loss: 495.434389251709
  batch 5250 Pearson correlation: 0.9469473385461082, Spearman correlation: 0.915395268735309
  batch 5500 loss: 469.77723889160154
  batch 5500 Pearson correlation: 0.9495135318380995, Spearman correlation: 0.9247631637810705
  batch 5750 loss: 486.8671820678711
  batch 5750 Pearson correlation: 0.9479753012068003, Spearman correlation: 0.9182951485855863
  batch 6000 loss: 479.1629215698242
  batch 6000 Pearson correlation: 0.9485871576052153, Spearman correlation: 0.9192041927362243
  batch 6250 loss: 468.47275384521487
  batch 6250 Pearson correlation: 0.9496198345955711, Spearman correlation: 0.9214520818407796
  batch 6500 loss: 471.9086069335938
  batch 6500 Pearson correlation: 0.9486781681167544, Spearman correlation: 0.9210683678867916
  batch 6750 loss: 475.11847619628907
  batch 6750 Pearson correlation: 0.94

In [13]:
print(epoch_coeff)

[(0.9008030417153318, 0.8945205164632815), (0.9147040410992655, 0.9039814078263012), (0.9177554508763093, 0.9091419272684197), (0.9192885341816951, 0.912128601300309), (0.9289113618717284, 0.9127458289255737)]
