In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# General import, names of train, test, val
import numpy as np
import pysam
from tqdm.notebook import tqdm
import h5py
import pandas as pd
import matplotlib.pyplot as plt
import time
# Reproducibility: single seeded generator shared by the stochastic cells below
rng = np.random.default_rng(seed=0)

# Dataset identity and directory layout
basedir = '/data/leslie/shared/ASA/'
ctype = 'cd8'                 # cell type; selects BAM subdirectory and data directory
ident = '_vi_chrom'           # suffix used to pick the summit source and name the output file
aligndir = f'{basedir}pseudodiploid/atac/'
datadir = f'{basedir}mouseASA/{ctype}/cast/data/'

# Scope: autosomes 1..19 and the replicate names
chroms = [*range(1, 20)]
reps = ['r1','r2','r3','r4','r5']

# Window widths (bp) centred on each summit
seqlen = 2114                 # input sequence window
outlen = 1000                 # coverage/profile window

# Output switch: set to False to skip writing (and overwriting) the HDF5 file
save = True

# Preprocessing of Model Input

In [3]:
# Get replicate counts
from utils import get_shifts, one_hot, get_neg_summits

# N[k]: number of mapped reads in replicate reps[k], summed over the contigs named in `chroms`.
N = []
wanted_contigs = {str(c) for c in chroms}
for rep in reps:
    bamfile = aligndir+ctype+'/'+rep+'.combined.rmDup.Aligned.sortedByCoord.out.bam'
    # `with` guarantees the file is closed even if reading the index raises.
    with pysam.AlignmentFile(bamfile, "rb") as bamf:
        # Select contigs by name, not by their position in the index: the BAM header
        # is not guaranteed to list chromosomes 1..19 as its first 19 entries.
        N.append(sum(stat.mapped for stat in bamf.get_index_statistics()
                     if stat.contig in wanted_contigs))

In [4]:
# Load peak summits (background "uneg" summits are sampled later, per chromosome).
if ident.startswith('_vi'):
    summits = pd.read_csv(aligndir+'cd8_old/yi_cd8_peaks_33143.csv',sep=',',index_col=0)
    # Offset 'start' by half the sequence window — presumably turning a window start
    # into the window centre (summit); TODO confirm against how the CSV was produced.
    summits['start'] += seqlen//2
    # Keep data columns 1-2 and rename them 0/1. Downstream cells use summits[0] as the
    # chromosome (compared to ints) and summits[1] as the summit position.
    summits = summits.iloc[:,1:3]
    summits.columns = range(2)
else:
    # Fail loudly instead of leaving `summits` undefined (or stale from an earlier run).
    raise ValueError(f"No summit source configured for ident={ident!r}; expected it to start with '_vi'.")

In [5]:
# Augment the data by shifting every summit by +/- seqlen//3 (704 bp for seqlen=2114).
augment = True      # for augmentation of data by shifting sequence window
# NOTE: the extraction loop reads `summits_aug` unconditionally, so augment=False
# currently fails there.

if augment:
    # Guard so that re-running this cell does not append '_aug' again to the output name.
    if not ident.endswith('_aug'):
        ident += '_aug'
    frac = 1.0      # not read by any visible cell — NOTE(review): confirm before removing
    # The sign of each shift is drawn uniformly from {-1, +1} with the shared seeded rng
    # (re-running this cell advances rng and therefore draws new shifts).
    summits_aug = summits.copy()
    shifts = seqlen//3 * rng.choice([-1,1], len(summits_aug), replace=True)
    summits_aug[1] += shifts

In [6]:
# Generate x, y and p (profile)
import gzip
from Bio import SeqIO
from bisect import bisect

# Open the first replicate's BAM; it is passed to neg_summit_generator for read counts
# and closed at the end of the first replicate pass of the extraction loop.
bamfile = f'{aligndir}{ctype}/{reps[0]}.combined.rmDup.Aligned.sortedByCoord.out.bam'
bamf = pysam.AlignmentFile(bamfile, mode="rb")

def neg_summit_generator(c, chromsummits, seqlen, bamf, chrom_len=None, max_count=20, max_tries=1000):
    """Sample low-coverage background ("uneg") summits on chromosome `c`.

    Candidates come from `get_neg_summits` (utils). Only candidates whose
    [i - seqlen//2, i + seqlen//2) window holds fewer than `max_count` reads in
    `bamf` are kept. Sampling repeats with an incremented seed until
    len(chromsummits) summits have been collected.

    Parameters
    ----------
    c : chromosome; str(c) must be a contig name in `bamf`.
    chromsummits : peak summits on chromosome `c`, passed through to get_neg_summits.
    seqlen : width of the window used for the coverage check.
    bamf : open pysam.AlignmentFile used for read counts.
    chrom_len : chromosome length handed to get_neg_summits. Defaults to
        len(seq_b6), the notebook-global B6 sequence of the chromosome currently
        being processed (the value this function always used).
    max_count : keep a candidate only if its window count is strictly below this.
    max_tries : maximum number of sampling rounds before giving up.

    Returns
    -------
    np.ndarray of the accepted summit positions, sorted ascending.

    Raises
    ------
    RuntimeError if `max_tries` rounds do not yield enough low-coverage summits.
    """
    if chrom_len is None:
        chrom_len = len(seq_b6)     # hidden dependency on the extraction loop's global
    target = len(chromsummits)
    neg_summits = np.empty(0, dtype=np.int64)
    seed = 0
    while len(neg_summits) < target:
        if seed >= max_tries:
            raise RuntimeError(
                f"chr{c}: only {len(neg_summits)}/{target} background summits with "
                f"< {max_count} reads after {max_tries} sampling rounds")
        # A new seed each round so that rejected draws are not simply repeated.
        candidates = np.asarray(get_neg_summits(chromsummits, target - len(neg_summits), chrom_len, seed))
        counts = np.array([bamf.count(str(c), i - seqlen//2, i + seqlen//2) for i in candidates])
        neg_summits = np.concatenate((neg_summits, candidates[counts < max_count]))
        seed += 1
    return np.sort(neg_summits)

def parse_bed(path, chroms=range(1, 20)):
    """Read a stranded BED file and return one insertion position per read.

    Parameters
    ----------
    path : BED file (compression inferred by pandas from the extension, e.g. .gz).
    chroms : chromosomes to keep; compared against column 0 by string name.

    Returns
    -------
    DataFrame with column 0 = chromosome (int) and column 1 = position:
    the start (column 1) for '+' strand reads, and end - 1 for '-' strand reads.
    BED ends are exclusive, so end - 1 is the read's last base. Rows whose strand
    is neither '+' nor '-' keep their start.
    Columns 2-5 are dropped.
    """
    # Read the chromosome column as text. Letting pandas infer it yields strings whenever
    # 'X'/'Y'/'MT' appear in the same (low_memory) chunk, and those rows — including the
    # numeric chromosomes — would then be dropped by an integer-based filter.
    df = pd.read_csv(path, sep='\t', header=None, dtype={0: str})
    keep = df[0].isin([str(c) for c in chroms])
    df = df.loc[keep].reset_index(drop=True)
    df[0] = df[0].astype(int)
    df[1] = np.where(df[5] == '-', df[2] - 1, df[1]).astype(int)
    df = df.drop([2, 3, 4, 5], axis=1)
    return df

# Random-access FASTA indices for the B6 and CAST genomes (chromosomes fetched on demand later).
gen_b6 = SeqIO.index(f'{basedir}pseudodiploid/gen/b6.fa', 'fasta')
gen_cast = SeqIO.index(f'{basedir}pseudodiploid/gen/cast.fa', 'fasta')

# CAST .mod records passed to get_shifts (presumably to map B6 positions into CAST
# coordinates): keep every line except those starting with 's' or '#'.
modfile = f'{basedir}pseudodiploid/gen/cast.mod'
with gzip.open(modfile,'rt') as f:
    # Filter empty lines explicitly rather than slicing off the last element, which
    # dropped a real record whenever the file did not end with a newline.
    mods = [line for line in f.read().split('\n')
            if line and not line.startswith(('s', '#'))]
    mods = [x for x in mods if not (x.startswith('s') or x.startswith('#'))][:-1]

In [13]:
# Build model inputs per chromosome for labels '' (peak summits), '_unegs' (sampled
# low-coverage background) and '_aug' (shifted summits), and alleles 'b6' / 'ca' (CAST):
#   x['chr{c}_{allele}{label}']        one-hot sequences, built once during the first replicate
#   p['{rep}_chr{c}_{allele}{label}']  per-base insertion profiles per replicate (pooled later)

ALLELES = ['b6', 'ca']
LABELS = ['', '_unegs', '_aug']


def _chrom_sites(bed, c):
    """Sorted insertion positions on chromosome `c` from a frame returned by parse_bed()."""
    return np.sort(bed.loc[bed[0] == c, 1].to_numpy(dtype=np.int64))


def _window_counts(sites, centers, width):
    """Per-base counts of `sites` in [i - width//2, i + width//2) for every center i.

    `sites` must be sorted ascending, so each window is located by binary search
    instead of scanning the whole chromosome. Returns an int64 array of shape
    (len(centers), width); for odd `width` the last column is always zero.
    """
    centers = np.asarray(centers, dtype=np.int64)
    starts = centers - width//2
    lo = np.searchsorted(sites, starts, side='left')
    hi = np.searchsorted(sites, centers + width//2, side='left')
    counts = np.zeros((len(centers), width), dtype=np.int64)
    for row, (start, a, b) in enumerate(zip(starts, lo, hi)):
        counts[row] = np.bincount(sites[a:b] - start, minlength=width)
    return counts


def _allele_profiles(sites_b6, sites_cast, sites_both, centers, width):
    """Return (b6_rows, cast_rows): allele-specific counts plus half of the 'both' counts each."""
    both_half = _window_counts(sites_both, centers, width) / 2
    b6_rows = _window_counts(sites_b6, centers, width) + both_half
    cast_rows = _window_counts(sites_cast, centers, width) + both_half
    return list(b6_rows), list(cast_rows)


x = dict()
p = dict()
neg_summits = dict()

for cnt, rep in enumerate(tqdm(reps)):
    first_rep = (cnt == 0)      # sequences and background summits do not depend on the replicate
    bed_b6 = parse_bed(f'{datadir}bed/{rep}_b6.bed.gz')
    bed_cast = parse_bed(f'{datadir}bed/{rep}_cast.bed.gz')
    bed_both = parse_bed(f'{datadir}bed/{rep}_both.bed.gz')

    for c in tqdm(chroms):
        # Create the keys up front so dict order is allele-major, label-minor.
        for allele in ALLELES:
            for label in LABELS:
                p[f'{rep}_chr{c}_{allele}{label}'] = []
                if first_rep:
                    x[f'chr{c}_{allele}{label}'] = []

        chromsummits = summits.iloc[np.where(summits[0]==c)[0], 1]
        augsummits = summits_aug.iloc[np.where(summits_aug[0]==c)[0], 1]

        # Sequences (x): first replicate only.
        if first_rep:
            seq_b6 = ''.join(gen_b6.get_raw(str(c)).decode().split('\n')[1:])
            seq_cast = ''.join(gen_cast.get_raw(str(c)).decode().split('\n')[1:])
            # neg_summit_generator reads the global seq_b6, so it must run after seq_b6 is set.
            neg_summits[c] = neg_summit_generator(c, chromsummits, seqlen, bamf)
            for label, centers in zip(LABELS, (chromsummits, neg_summits[c], augsummits)):
                # Per-center offset added to the B6 position to index the CAST sequence.
                cast_shifts = get_shifts(centers, mods, c)
                x[f'chr{c}_b6{label}'] = one_hot([seq_b6[i-seqlen//2:i+seqlen//2] for i in centers])
                x[f'chr{c}_ca{label}'] = one_hot([seq_cast[i+j-seqlen//2:i+j+seqlen//2]
                                                  for i, j in zip(centers, cast_shifts)])

        # Profiles (p): every replicate; neg_summits[c] was filled during the first replicate.
        site_sets = [_chrom_sites(bed, c) for bed in (bed_b6, bed_cast, bed_both)]
        for label, centers in zip(LABELS, (chromsummits, neg_summits[c], augsummits)):
            prof_b6, prof_cast = _allele_profiles(*site_sets, centers, outlen)
            p[f'{rep}_chr{c}_b6{label}'] += prof_b6
            p[f'{rep}_chr{c}_ca{label}'] += prof_cast

    if first_rep:
        # All sequences are extracted: release the file handles and the large chromosome strings.
        bamf.close()
        gen_b6.close()
        gen_cast.close()
        seq_b6 = seq_cast = None
    # Only delete names that are bound in every replicate pass.
    del bed_b6, bed_cast, bed_both

  0%|          | 0/5 [00:00<?, ?it/s]

  df = pd.read_csv(path, sep='\t', header=None)
  df = pd.read_csv(path, sep='\t', header=None)
  df = pd.read_csv(path, sep='\t', header=None)


  0%|          | 0/19 [00:00<?, ?it/s]

  df = pd.read_csv(path, sep='\t', header=None)
  df = pd.read_csv(path, sep='\t', header=None)
  df = pd.read_csv(path, sep='\t', header=None)


  0%|          | 0/19 [00:00<?, ?it/s]

  df = pd.read_csv(path, sep='\t', header=None)
  df = pd.read_csv(path, sep='\t', header=None)
  df = pd.read_csv(path, sep='\t', header=None)


  0%|          | 0/19 [00:00<?, ?it/s]

  df = pd.read_csv(path, sep='\t', header=None)
  df = pd.read_csv(path, sep='\t', header=None)
  df = pd.read_csv(path, sep='\t', header=None)


  0%|          | 0/19 [00:00<?, ?it/s]

  df = pd.read_csv(path, sep='\t', header=None)
  df = pd.read_csv(path, sep='\t', header=None)
  df = pd.read_csv(path, sep='\t', header=None)


  0%|          | 0/19 [00:01<?, ?it/s]

In [22]:
# Sequences (x) are final; profiles (p) are still stored per replicate.
# Pool them: p['chr{c}_{allele}{label}'] = sum over reps of p['{rep}_chr{c}_{allele}{label}'],
# removing each per-replicate entry once it has been added.
for c in chroms:
    for allele in ['b6','ca']:
        for label in ['','_unegs','_aug']:
            key = 'chr'+str(c)+'_'+allele+label
            # Build replicate keys explicitly instead of relying on `rep` leaked from the previous cell.
            rep_keys = [r+'_'+key for r in reps]
            if key in p and rep_keys[0] not in p:
                continue        # already pooled by an earlier run of this cell
            pooled = np.zeros_like(p[rep_keys[0]])
            for rep_key in rep_keys:
                pooled += np.array(p.pop(rep_key))
            p[key] = pooled

In [27]:
# Write every sequence array (prefix 'x_') and pooled profile array (prefix 'p_') to one HDF5 file.
if save:
    with h5py.File(datadir+'data'+ident+'.h5','w') as f:
        for prefix, arrays in (('x_', x), ('p_', p)):
            for key, value in arrays.items():
                f.create_dataset(prefix+key, data=value)