In [None]:
import glob
import math
import os
import pickle
import socket

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from adabelief_tf import AdaBeliefOptimizer
from Bio import Seq, SeqIO
from tensorflow import keras
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import EarlyStopping
from tqdm import tqdm
from tqdm.keras import TqdmCallback


# Restrict TensorFlow to one GPU - run nvidia-smi to see which ones are in use.
os.environ["CUDA_VISIBLE_DEVICES"]="1"
physical_devices = tf.config.list_physical_devices('GPU')

# Enable memory growth on every GPU that is actually visible. Indexing
# physical_devices[1..3] directly raised IndexError whenever fewer than four
# GPUs were visible, and a bare `except:` hid that (and any other) error.
for gpu in physical_devices:
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except (ValueError, RuntimeError) as err:
        # Invalid device or cannot modify virtual devices once initialized.
        print('could not enable memory growth on {0}: {1}'.format(gpu, err))

# Pick the data root (and, where configured, the figure root) from the hostname.
hostName = socket.gethostname()

# (hostname fragment, data root) pairs, checked in order; when several
# fragments match, the last matching entry wins.
hostDataRoots = [
    ('biochem1', '/avicenna/vramani/analyses/pacbio/'),
    ('assembler4', '/data/users/goodarzilab/mostrowski/pacbio/data/'),
    ('titan4', '/data/users/goodarzilab/mostrowski/pacbio/data/'),
    ('wynton', '/wynton/group/goodarzilab/ramanilab/results/pacbio/'),
]
for hostTag, hostRoot in hostDataRoots:
    if hostTag in hostName:
        dataPBase = hostRoot

# A figure folder is only configured for biochem1.
if 'biochem1' in hostName:
    figPBase = '/avicenna/cmcnally/pbanalysis/'

if 'rumi' in hostName:
    raise Exception('no pacbio results folder on rumi')
    
    
# Sequencing runs whose sample-reference tables are stacked into `sampleRef`.
# Each entry is (run folder, file-name stem); the CSV lives at
# <dataPBase><run>/<run>.<stem>.wynton.csv. Only the spike-in run uses the
# "sampleReference_MO" stem. Order matters: rows are renumbered 0..N-1 after
# concatenation and that row number is the `samp` index used below.
sampleRefRuns = [
    ('210520_NA_K562Brdu_repeat', 'sampleReference'),
    ('210930_MO_E14_K562_BrdU', 'sampleReference'),
    ('211014_MO_BrdU_invivo', 'sampleReference'),
    ('220722_BrdUTP_methcontrols', 'sampleReference'),
    ('220722_BrdU_K562_invivopulse', 'sampleReference'),
    ('211203_MO_BrdU_invo_1013', 'sampleReference'),
    ('220128_MO_BrdU_shear_CM_spike-in', 'sampleReference_MO'),
    ('220920_MO_BrdU_invivo', 'sampleReference'),
    ('221109_MO_BrdU_invivo', 'sampleReference'),
    ('221017_MO_CTCFdegron_RASAM', 'sampleReference'),
    ('230419_MO_K562_RASAM_1', 'sampleReference'),
    ('230419_MO_K562_RASAM_2', 'sampleReference'),
    ('230419_MO_CTCFdegron_RASAM', 'sampleReference'),
    ('230419_MO_NIPBLdegron_RASAM', 'sampleReference'),
    ('230425_e14_RASAM_Cell1', 'sampleReference'),
    ('230425_e14_RASAM_Cell2', 'sampleReference'),
    ('230425_e14_RASAM_Cell3', 'sampleReference'),
    ('230425_e14_RASAM_Cell4', 'sampleReference'),
]

sampleRef = pd.concat(
    [pd.read_csv('{0}{1}/{1}.{2}.wynton.csv'.format(dataPBase, runName, refStem), index_col=0)
     for runName, refStem in sampleRefRuns],
    ignore_index=True)

matplotlib.rcParams['font.sans-serif'] = "Arial"
matplotlib.rcParams['font.family'] = "sans-serif"

%matplotlib inline
matplotlib.rcParams['figure.dpi']= 120

sampleRef

## Build input files formatted for the model

In [None]:
# Merge block files if necessary
#
# For each sample, every "<cell>_<sample>_block*_full.pickle" piece is loaded
# and the per-piece dicts are combined into "<cell>_<sample>_full.pickle".

usesamples = range(0,10)

for samp in usesamples:
    cellName = sampleRef['cell'][samp]
    sampleName = sampleRef['sampleName'][samp]

    hmmPieces = sorted(glob.glob('{0}{1}/processed/full/{1}_{2}_block*_full.pickle'.format(
        dataPBase, cellName, sampleName)))

    # A sample with no block pieces has nothing to merge. Skip it: writing the
    # (empty) merged dict would overwrite an existing _full.pickle with {}.
    if not hmmPieces:
        print('no block files for {0}_{1}; leaving any existing merged file untouched'.format(
            cellName, sampleName))
        continue

    hmmAll = {}

    for piece in tqdm(hmmPieces, position=0):
        # NOTE(review): pickle.load executes arbitrary code; only use on files this pipeline wrote.
        with open(piece,'rb') as fopen:
            hmmAll.update(pickle.load(fopen))

    with open('{0}{1}/processed/full/{1}_{2}_full.pickle'.format(dataPBase, cellName, sampleName),
              'wb') as fout:
        pickle.dump(hmmAll, fout)

In [None]:
def makeCNNinput(samp):
    """Cut every long-enough molecule of one sample into 500-base windows and save them as an .npz.

    Reads "<cell>_<sample>_full_zmwinfo.pickle" (columns 'zmw' and 'cclen' are
    used) and "<cell>_<sample>_full.pickle" (a dict keyed by zmw) from
    <dataPBase>/<cell>/processed/full/. Molecules shorter than 500 bases are
    skipped. Every other molecule is covered by ceil(length / 500) windows whose
    start positions are spread evenly from 0 to length - 500, so neighbouring
    windows may overlap.

    Writes <dataPBase>/<cell>/processed/forNN/<cell>_<sample>_forCNN_preds.npz with
    forwardIPD / reverseIPD   float32 (nchunk, 500, 1), NaN replaced by -1
    forwardNsub / reverseNsub int16   (nchunk, 500, 1)
    forwardSeq / reverseSeq   bool    (nchunk, 500, 4), one-hot A,C,G,T; the
                                      reverse array encodes the complement base
    zmwNum                    int64   (nchunk, 1), zmw id of each window
    zmwStartPos               int32   (nchunk, 1), start of each window in the read

    Parameters
    ----------
    samp : row index into the module-level `sampleRef` table.

    Raises
    ------
    KeyError if a read contains a character other than A, C, G or T.
    """
    baseDic = {'A':0, 'C':1, 'G':2, 'T':3}
    compBase = {'A':'T', 'C':'G', 'G':'C', 'T':'A'}

    Nbases = 500

    cellName = sampleRef['cell'][samp]
    sampleName = sampleRef['sampleName'][samp]
    fullDir = os.path.join(dataPBase, cellName, 'processed', 'full')

    inff = os.path.join(fullDir, cellName + '_' + sampleName + '_full_zmwinfo.pickle')
    zmwinfo = pd.read_pickle(inff)

    finff = os.path.join(fullDir, cellName + '_' + sampleName + '_full.pickle')

    with open(finff,'rb') as fopen:
        ipdfull = pickle.load(fopen)

    # Number of output windows, taken from the recorded lengths in zmwinfo.
    # NOTE(review): the fill loop below uses len(read) instead of 'cclen'; this
    # assumes both agree for every zmw - confirm.
    nchunk = 0
    for ccl in zmwinfo['cclen']:
        if ccl >= Nbases:
            nchunk += int(np.ceil(ccl / Nbases))

    forwardIPD = np.full((nchunk, Nbases, 1), -1, dtype=np.float32)
    reverseIPD = np.full((nchunk, Nbases, 1), -1, dtype=np.float32)
    forwardNsub = np.full((nchunk, Nbases, 1), -1, dtype=np.int16)
    reverseNsub = np.full((nchunk, Nbases, 1), -1, dtype=np.int16)
    # Builtin bool: the np.bool alias was removed in NumPy 1.24.
    forwardSeq = np.full((nchunk, Nbases, 4), 0, dtype=bool)
    reverseSeq = np.full((nchunk, Nbases, 4), 0, dtype=bool)
    zmwNum = np.full((nchunk, 1), -1, dtype=np.int64)
    zmwStartPos = np.full((nchunk, 1), -1, dtype=np.int32)

    imol = 0
    for zmw in tqdm(zmwinfo['zmw'], position=0, mininterval=1,
                    desc='{0}__{1}'.format(cellName, sampleName)):
        ipdz = ipdfull[zmw]
        seq = ipdz['read']
        ccl = len(seq)
        if ccl < Nbases:
            continue
        npiece = int(np.ceil(ccl / Nbases))
        # First window starts at 0, last at ccl - Nbases, the rest evenly in between.
        interstarts = np.rint(np.interp(np.arange(1,npiece-1), [0, npiece-1], [0, ccl-Nbases])).astype('int')
        startp = np.concatenate([[0], interstarts, [ccl-Nbases]])

        for ich in range(npiece):
            usebases = np.arange(startp[ich], startp[ich]+Nbases)
            zmwNum[imol,0] = zmw
            zmwStartPos[imol,0] = startp[ich]
            forwardIPD[imol,:,0] = ipdz['forwardM'][usebases]
            reverseIPD[imol,:,0] = ipdz['reverseM'][usebases]
            # NOTE(review): 'forwardNSub' (capital S) vs 'reverseNsub' - the key
            # casing differs between strands; left as-is, verify against the
            # code that writes the _full.pickle files.
            forwardNsub[imol,:,0] = ipdz['forwardNSub'][usebases]
            reverseNsub[imol,:,0] = ipdz['reverseNsub'][usebases]
            for ib, b in enumerate(usebases):
                forwardSeq[imol,ib,baseDic[seq[b]]] = True
                reverseSeq[imol,ib,baseDic[compBase[seq[b]]]] = True
            imol += 1

    forwardIPD[np.isnan(forwardIPD)] = -1
    reverseIPD[np.isnan(reverseIPD)] = -1

    outDir = os.path.join(dataPBase, cellName, 'processed', 'forNN')
    os.makedirs(outDir, exist_ok=True)

    np.savez(os.path.join(outDir, '{0}_{1}_forCNN_preds.npz'.format(cellName, sampleName)),
             forwardIPD = forwardIPD,
             reverseIPD = reverseIPD,
             forwardNsub = forwardNsub,
             reverseNsub = reverseNsub,
             forwardSeq = forwardSeq,
             reverseSeq = reverseSeq,
             zmwNum = zmwNum,
             zmwStartPos = zmwStartPos)

In [None]:
# Write the CNN input file for sampleRef rows 0 through 9.
for samp in range(10):
    makeCNNinput(samp)

## Start setting up the model

In [None]:
# read in the data and get it set up for the model

usesamples = [0,1,2,3,8,9,10,11,12,13,14,15,16,17,19,20,27,28,29,30,31]


weight24h = 0.8
weightPCR50 = 0.5

# Training target for each sampleRef row (0, 1, or one of the soft weights above).
sampBrdud = {0:0, 1:0, 2:0, 3:0, 8:weight24h, 9:weight24h, 10:weight24h, 11:weight24h, 12:0, 13:0, 14:0, 15:1, 
             16:0, 17:1, 19:0, 20:0, 27:weight24h, 28:weight24h, 29:weightPCR50, 30:weightPCR50, 31:1, 43:0, 
             44:0, 51:1, 52:1, 53:0, 56:1, 57:0, 60:1, 67:1, 71:1, 74:1}

## Load in data
forwardIPD, reverseIPD = [], []
forwardNsub, reverseNsub = [], []
forwardSeq, reverseSeq = [], []
sampB, sampW, sampO = [], [], []
# each sample contributes at most its first 40000 windows

for samp in tqdm(usesamples, position=0):
    cellName = sampleRef['cell'][samp]
    npzName = '{0}_{1}_forCNN_preds.npz'.format(cellName, sampleRef['sampleName'][samp])

    with np.load(os.path.join(dataPBase, cellName, 'processed', 'forNN', npzName)) as data:
        usemol = np.arange(data['forwardIPD'].shape[0])[0:40000]

        # Same shift/scale for both strands: IPD - 1, (Nsub - 25) / 100, sequence unchanged.
        for strand, ipdParts, nsubParts, seqParts in (
                ('forward', forwardIPD, forwardNsub, forwardSeq),
                ('reverse', reverseIPD, reverseNsub, reverseSeq)):
            ipdParts.append(data[strand + 'IPD'][usemol,:,:] - 1)
            nsubParts.append((data[strand + 'Nsub'][usemol,:,:] - 25) / 100)
            seqParts.append(data[strand + 'Seq'][usemol,:,:])

        # Per-window label and originating sample index.
        sampB.append(np.full((len(usemol),1), sampBrdud[samp]))
        sampO.append(np.full((len(usemol),1), samp))

# Stack the per-sample pieces along the window axis.
forwardIPD = np.concatenate(forwardIPD, axis=0)
forwardNsub = np.concatenate(forwardNsub, axis=0)
forwardSeq = np.concatenate(forwardSeq, axis=0)
reverseIPD = np.concatenate(reverseIPD, axis=0)
reverseNsub = np.concatenate(reverseNsub, axis=0)
reverseSeq = np.concatenate(reverseSeq, axis=0)
sampB = np.concatenate(sampB, axis=0).astype(np.float32)
sampO = np.concatenate(sampO, axis=0)

# Six channels per position: IPD, Nsub, and the 4-way one-hot base.
forward = np.concatenate([forwardIPD, forwardNsub, forwardSeq], axis=2)
reverse = np.concatenate([reverseIPD, reverseNsub, reverseSeq], axis=2)

In [None]:
# select training and validation data (s1)
#
# 70% of the windows are drawn without replacement for training; the rest are
# validation. A fixed seed makes the split reproducible after a kernel restart
# (the unseeded np.random.choice gave a different split on every run).
# NOTE(review): the split is over windows, not molecules, so overlapping windows
# of one molecule can end up on both sides - confirm this is acceptable.

splitSeed = 0
splitRng = np.random.default_rng(splitSeed)

nWindows = forwardIPD.shape[0]
trainMol = np.sort(splitRng.choice(nWindows, int(nWindows*0.7), replace=False))
# setdiff1d already returns sorted, unique indices.
validMol = np.setdiff1d(np.arange(nWindows), trainMol)


In [None]:
# model with both max and average pooling (cnn2)
#
# Two parallel branches read the same forward / reverse inputs. Each branch is
# conv(21) -> BN -> ReLU -> conv(1) -> BN -> ReLU, applied with shared weights
# to both strands; the reverse strand is flipped along the position axis before
# the branch and flipped back after it. One branch ends in global max pooling,
# the other in global average pooling.
#
# NOTE(review): the previous version created convLayerA / convLayerA2 /
# convLayerAA2 for the average branch but then called convLayerM / convLayerM2 /
# convLayerMA2, so both branches shared the max branch's convolution weights and
# the "A" layers were never used. Each branch now owns its layers. This changes
# the architecture: models must be retrained, and already-saved models keep the
# old shared-weight graph.

def _flipPositions(x):
    """Reverse a (batch, position, channel) tensor along the position axis."""
    return K.reverse(x, axes=1)


def _strandBranch(forwardIn, reverseIn, poolLayer):
    """Run both strands through one freshly created conv stack and pool the result.

    Every layer is created once and called on both strands, so the two strands
    share weights within a branch. The reverse-strand features are flipped back
    and appended to the forward features along the position axis before pooling.
    """
    strandFeats = [forwardIn, reverseIn]
    branchLayers = (
        keras.layers.Conv1D(200, 21, kernel_initializer="he_uniform", padding="valid"),
        keras.layers.BatchNormalization(),
        keras.layers.Activation(activation="relu"),
        keras.layers.Conv1D(200, 1, kernel_initializer="he_uniform", padding="valid"),
        keras.layers.BatchNormalization(),
        keras.layers.Activation(activation="relu"),
    )
    for layer in branchLayers:
        strandFeats = [layer(feat) for feat in strandFeats]

    forwardFeat, reverseFeat = strandFeats
    reverseFeat = keras.layers.Lambda(_flipPositions)(reverseFeat)
    return poolLayer(keras.layers.Concatenate(axis=1)([forwardFeat, reverseFeat]))


forwardInput = keras.layers.Input(forward.shape[1:])
reverseInput = keras.layers.Input(reverse.shape[1:])

reverseFlip1 = keras.layers.Lambda(_flipPositions)(reverseInput)

poolLayerM = _strandBranch(forwardInput, reverseFlip1, keras.layers.GlobalMaxPool1D())
poolLayerA = _strandBranch(forwardInput, reverseFlip1, keras.layers.GlobalAveragePooling1D())

# Dense head: three (dropout -> 400-unit ReLU) stages, a final dropout, sigmoid output.
hidden = keras.layers.Concatenate()([poolLayerM, poolLayerA])
for _ in range(3):
    hidden = keras.layers.Dropout(0.5)(hidden)
    hidden = keras.layers.Dense(400, activation="relu", kernel_initializer="he_uniform")(hidden)
drop4 = keras.layers.Dropout(0.5)(hidden)
outLayer = keras.layers.Dense(1, activation="sigmoid")(drop4)

model = keras.models.Model(inputs=[forwardInput, reverseInput],
                           outputs=[outLayer])

# `learning_rate` replaces the deprecated `lr` keyword.
optimizer = keras.optimizers.Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999)
model.compile(loss="binary_crossentropy", optimizer=optimizer)

In [None]:
# Training

# Stop when val_loss has not improved for 5 epochs and restore the best weights seen.
earlystop = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

# Recompile with a default Adam optimizer and track accuracy alongside the loss.
opt = tf.keras.optimizers.Adam() 
model.compile(optimizer=opt, loss=tf.keras.losses.binary_crossentropy, metrics=['accuracy'])

# Keras' own progress output is silenced (verbose=0); the tqdm callback reports progress instead.
fitCallbacks = [earlystop, TqdmCallback(verbose=1, miniters=40, mininterval=0.4)]

history = model.fit(
    x=[forward[trainMol,:,:], reverse[trainMol,:,:]],
    y=sampB[trainMol,:],
    validation_data=([forward[validMol,:,:], reverse[validMol,:,:]],
                     sampB[validMol,:]),
    batch_size=32,
    epochs=40,
    shuffle=True,
    verbose=0,
    callbacks=fitCallbacks)

In [None]:
# Save training and validation molecules
#
# For each split this writes the row indices, the forward/reverse inputs and
# the labels next to the model files.

# Row of sampleRef whose cell folder holds the model files. Set explicitly:
# this cell used to read `samp`, which at this point is whatever the data
# loading loop left behind, not the sample the model is saved under.
savesamp = 14

modelDir = dataPBase + '%s/processed/NNmodels' % (sampleRef['cell'][savesamp])
# np.save does not create missing folders.
os.makedirs(modelDir, exist_ok=True)

splitPrefix = modelDir + '/brduModel_230619_s1_cnn2_t2_forROC_'

for splitName, splitIdx in (('trainMol', trainMol), ('validMol', validMol)):
    np.save(splitPrefix + splitName + '_idx', splitIdx)
    np.save(splitPrefix + splitName + '_forward', forward[splitIdx,:,:])
    np.save(splitPrefix + splitName + '_reverse', reverse[splitIdx,:,:])
    np.save(splitPrefix + splitName + '_sampB', sampB[splitIdx,:])


In [None]:
samp = 14

# The model and its training history share one path prefix inside this
# sample's NNmodels folder; the history dict gets a "_history" suffix.
modelPath = dataPBase + '%s/processed/NNmodels/brduModel_230619_s1_cnn2_t2_forROC' % (sampleRef['cell'][samp])

model.save(modelPath)
np.save(modelPath + '_history', history.history)

In [None]:
print(history.history.keys())

# summarize history for loss: training curve first, then validation
for curveKey in ('loss', 'val_loss'):
    plt.plot(history.history[curveKey])
plt.title('model loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend(['train', 'test'], loc='upper left')
plt.show()

In [None]:
print(history.history.keys())

# summarize history for accuracy: training curve first, then validation
for curveKey in ('accuracy', 'val_accuracy'):
    plt.plot(history.history[curveKey])
plt.title('model accuracy')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.legend(['train', 'test'], loc='upper left')
plt.show()

In [None]:
# Model output for every validation window, in validMol order.
v_preds = model.predict(
    [forward[validMol,:,:], reverse[validMol,:,:]],
    batch_size=1024,
    verbose=1,
)

In [None]:
# Store the validation predictions in the NNmodels folder of sampleRef row `samp`.
validPredsPath = (dataPBase + '%s/processed/NNmodels/brduModel_230619_s1_cnn2_t2_forROC_validMol_preds' % 
                  (sampleRef['cell'][samp]))
np.save(validPredsPath, v_preds)


In [None]:
# roc curve and auc for validation data 
from sklearn import metrics

validLabels = sampB[validMol, 0]
validScores = np.ravel(v_preds)

# ROC needs binary truth. Some samples carry soft labels (weight24h,
# weightPCR50); casting those with astype(int) turned them into 0, i.e. counted
# them as negatives. Restrict the curve to windows whose label is exactly 0 or 1.
# NOTE(review): confirm that excluding soft-labelled samples is the intended evaluation.
hardLabel = (validLabels == 0) | (validLabels == 1)
rocLabels = validLabels[hardLabel].astype(int)
rocScores = validScores[hardLabel]

fpr, tpr, _ = metrics.roc_curve(rocLabels, rocScores)
auc = metrics.roc_auc_score(rocLabels, rocScores)

#create ROC curve
plt.plot(fpr,tpr,label="AUC="+str(auc))
plt.ylabel('True Positive Rate')
plt.xlabel('False Positive Rate')
plt.legend(loc=4)
plt.show()

In [None]:
# Distribution of validation predictions for each sample (20 equal bins on [0, 1]).
bins = np.linspace(0,1,21)
binc = bins[0:-1] + (bins[1]-bins[0])/2

# Sample of origin for each validation window, aligned row-for-row with v_preds.
# (The previous code indexed an undefined `preds` with a mask over all windows.)
validSampO = sampO[validMol]

samphd = {}
for histSamp in usesamples:
    hist, bine = np.histogram(v_preds[validSampO == histSamp], bins=bins, density=True)
    samphd[histSamp] = hist
    

plt.figure(figsize=(10,8))
for isamp, histSamp in enumerate(usesamples):
    # The first nine samples are drawn solid, the rest dashed, to keep the lines distinguishable.
    if isamp < 9:
        plt.plot(binc, samphd[histSamp], label=sampleRef['sampleName'][histSamp])
    else:
        plt.plot(binc, samphd[histSamp], label=sampleRef['sampleName'][histSamp], ls='--')
plt.xlabel('predicted value')
plt.ylabel('density')
plt.legend()

## Make predictions using the trained model

In [None]:
# Load model, set chunk size, and define function

# The saved model lives in the NNmodels folder of this sampleRef row's cell.
savesamp = 14
brdumodelPath = '{0}{1}/processed/NNmodels/brduModel_220919_s1_cnn2_t2'.format(
    dataPBase, sampleRef['cell'][savesamp])
brdumodel = keras.models.load_model(brdumodelPath)

# Number of windows handed to the model per prediction pass.
chunksize = 150000

def predictBrdU(samp):
    """Predict on every window of one sample and save per-molecule tracks in pieces.

    Loads <dataPBase>/<cell>/processed/forNN/<cell>_<sample>_forCNN_preds.npz,
    runs the module-level `brdumodel` on the windows in passes of about
    `chunksize` windows, and for each pass writes a pickle
    ".../brduPrediction/<cell>_<sample>_brdu_220919_s1_cnn2_t2_adj_piece<NNN>.pickle"
    holding {zmw: float32 array}. Each window's single prediction is spread over
    the positions that window covers, and positions covered by several windows
    get the mean of their predictions.

    Pass boundaries are moved forward to the next change of zmw so one molecule
    is never split over two piece files. With fixed boundaries a split molecule
    appeared in two pieces, and merging the pieces with dict.update kept only
    the second, partial entry.

    Parameters
    ----------
    samp : row index into the module-level `sampleRef` table.
    """
    cellName = sampleRef['cell'][samp]
    sampleName = sampleRef['sampleName'][samp]

    outDir = dataPBase + '%s/processed/brduPrediction' % (cellName)
    os.makedirs(outDir, exist_ok=True)

    npzPath = os.path.join(dataPBase, cellName, 'processed', 'forNN',
                           '{0}_{1}_forCNN_preds.npz'.format(cellName, sampleName))

    with np.load(npzPath) as data:
        # One value per window, so these are read once; the large per-position
        # arrays are still read slice by slice inside the loop.
        zmwNumAll = data['zmwNum']
        zmwStartAll = data['zmwStartPos']
        nWindows = zmwNumAll.shape[0]

        print('dimension of npz file is: ' + str(data['forwardIPD'].shape))
        print('number of 500 bp chunks is: ' + str(nWindows))
        top = math.ceil(nWindows/chunksize) # number of prediction passes
        print('number of loops is: ' + str(top))

        # Nominal boundaries sit at multiples of chunksize; each one is pushed
        # forward while the window before it belongs to the same zmw.
        bounds = []
        for k in range(top + 1):
            b = min(k * chunksize, nWindows)
            while 0 < b < nWindows and zmwNumAll[b, 0] == zmwNumAll[b - 1, 0]:
                b += 1
            bounds.append(b)

        print('starting loop')

        for cycle in range(top):
            start, end = bounds[cycle], bounds[cycle + 1]

            print('starting index: ' + str(start))
            print('ending index: ' + str(end))

            brduPD = {}

            # A pass can be empty if the previous boundary was pushed past it;
            # an empty piece is still written so the piece numbering is stable.
            if end > start:
                forward = np.concatenate([data['forwardIPD'][start:end] - 1,
                                          (data['forwardNsub'][start:end] - 25) / 100,
                                          data['forwardSeq'][start:end]], axis=2)
                reverse = np.concatenate([data['reverseIPD'][start:end] - 1,
                                          (data['reverseNsub'][start:end] - 25) / 100,
                                          data['reverseSeq'][start:end]], axis=2)

                zmwN = zmwNumAll[start:end]
                zmwP = zmwStartAll[start:end]
                winLen = forward.shape[1]

                # run predictions
                prs = brdumodel.predict([forward, reverse], batch_size=1024, verbose=1)

                # Group consecutive window rows by zmw.
                zmwIx = {}
                lastzmw = None
                for row in range(zmwN.shape[0]):
                    thiszmw = zmwN[row,0]
                    if thiszmw != lastzmw:
                        zmwIx[thiszmw] = [row]
                    else:
                        zmwIx[thiszmw].append(row)
                    lastzmw = thiszmw

                for zmw in sorted(zmwIx.keys()):
                    startp = [zmwP[ix,0] for ix in zmwIx[zmw]]
                    maxstartp = np.max(startp)
                    # One row per window; NaN marks positions a window does not cover.
                    brduRawp = np.full((len(startp), maxstartp+winLen), np.nan, dtype=np.float32)
                    for eix, ix in enumerate(zmwIx[zmw]):
                        brduRawp[eix,zmwP[ix,0]:(zmwP[ix,0]+winLen)] = prs[ix]
                    brduPD[zmw] = np.nanmean(brduRawp, axis=0)

            piece = "{:0>3d}".format(cycle)

            outf = '{0}{1}/processed/brduPrediction/{1}_{2}_brdu_220919_s1_cnn2_t2_adj_piece{3}.pickle'.format(
                dataPBase, cellName, sampleName, piece)
            print(outf)
            with open(outf, 'wb') as fout:
                pickle.dump(brduPD, fout, protocol=4)
        

In [None]:
# Run function that actually makes the predictions

# sampleRef rows 0 through 9
usesamples = range(10)

for samp in usesamples:
    predictBrdU(samp)

In [None]:
# Unite all piece files into one prediction file
#
# For each sample, the per-piece {zmw: prediction track} dicts are merged in
# sorted file-name order and written to a single pickle.

for samp in range(0,10):
    cellName = sampleRef['cell'][samp]
    sampleName = sampleRef['sampleName'][samp]

    predPieces = sorted(glob.glob(
        '{0}{1}/processed/brduPrediction/{1}_{2}_brdu_220919_s1_cnn2_t2_adj_piece*.pickle'.format(
            dataPBase, cellName, sampleName)))

    # No pieces means nothing was predicted for this sample; skip it instead of
    # writing an empty merged file over a possibly valid one.
    if not predPieces:
        print('no prediction pieces found for {0}_{1}; skipping'.format(cellName, sampleName))
        continue

    predDict = {}
    for piece in tqdm(predPieces, position=0):
        # NOTE(review): pickle.load executes arbitrary code; only use on files this pipeline wrote.
        with open(piece,'rb') as fopen:
            # If a zmw occurs in more than one piece, the later piece wins.
            predDict.update(pickle.load(fopen))

    # save the output as a file
    with open(dataPBase + '{0}/processed/brduPrediction/{0}_{1}_brdu_220919_s1_cnn2_t2_adj.pickle'.format(
            cellName, sampleName), 'wb') as fout:
        pickle.dump(predDict, fout, protocol=4)

