In [9]:
from torch.optim import lr_scheduler
import os, pickle, time
import random
import numpy as np
import torch.optim as optim
import torch
from torch.utils.data import DataLoader,SubsetRandomSampler,Dataset
from scipy.stats import pearsonr,spearmanr
from scipy.sparse import load_npz,csr_matrix,save_npz
import subprocess
from torch.nn import MultiheadAttention
from einops.layers.torch import Rearrange


'''
_________________________
Imports and helpers that Zhuohan did not include in the original file
_________________________
'''

from torch import nn
from typing import Optional
from torch import Tensor

def pad_signal_matrix(matrix, pad_len=300):
    """Flank every row of a 2-D signal matrix with data from its neighbours.

    Each output row is laid out as
    ``[last pad_len values of the previous row | the row | first pad_len
    values of the next row]``.  The first row has no predecessor and the last
    row has no successor, so those flanks are filled with zeros.

    Args:
        matrix: 2-D NumPy array of shape ``(rows, width)``.  ``width`` has to
            be at least ``pad_len`` for the pieces to line up when stacked.
        pad_len: Number of columns borrowed from each neighbouring row.

    Returns:
        2-D array of shape ``(rows, width + 2 * pad_len)``.
    """
    zero_row = np.zeros((1, pad_len), dtype='float32')
    row_tails = matrix[:, -pad_len:]
    row_heads = matrix[:, :pad_len]
    # Shift the tails down by one row: row i receives the tail of row i - 1.
    left_flank = np.concatenate([zero_row, row_tails], axis=0)[:-1]
    # Shift the heads up by one row: row i receives the head of row i + 1.
    right_flank = np.concatenate([row_heads, zero_row], axis=0)[1:]
    return np.concatenate([left_flank, matrix, right_flank], axis=1)

'''
_________________________
Prepare the data
_________________________
'''

def pad_seq_matrix(matrix, pad_len=300):
    """Flank every bin of a 3-D sequence array with data from its neighbours.

    Along the last axis each output bin is laid out as
    ``[last pad_len positions of the previous bin | the bin | first pad_len
    positions of the next bin]``.  The first bin has no predecessor and the
    last bin has no successor, so those flanks are zeros.

    Args:
        matrix: 3-D NumPy array of shape ``(bins, channels, length)``.
            ``length`` has to be at least ``pad_len``.
        pad_len: Number of positions borrowed from each neighbouring bin.

    Returns:
        Array of shape ``(bins, channels, length + 2 * pad_len)``.
    """
    # Size the zero block from the input instead of assuming 4 channels.
    n_channels = matrix.shape[1]
    paddings = np.zeros((1, n_channels, pad_len), dtype='int8')
    # Prepend a zero bin and drop the last one: bin i now holds the tail of
    # bin i - 1 (zeros for the very first bin).
    dmatrix = np.concatenate((paddings, matrix[:, :, -pad_len:]), axis=0)[:-1, :, :]
    # Append a zero bin and drop the first one: bin i now holds the head of
    # bin i + 1 (zeros for the very last bin).
    umatrix = np.concatenate((matrix[:, :, :pad_len], paddings), axis=0)[1:, :, :]
    return np.concatenate((dmatrix, matrix, umatrix), axis=2)

def load_ref_genome(chr):
    """Load the sparse reference sequence of one chromosome as a padded tensor.

    The ``.npz`` file is densified, reshaped to ``(4, -1, 1000)`` and its first
    two axes are swapped, giving one ``(4, 1000)`` slice per bin; each bin is
    then flanked with 300 positions from its neighbours by ``pad_seq_matrix``.

    Args:
        chr: Chromosome label as used in the file name ``chr<label>.npz``,
            e.g. ``'1'``.

    Returns:
        ``torch.Tensor`` of shape ``(bins, 4, 1600)``.
    """
    ref_path = '/home/gridsan/gschuette/binz_group_shared/zlao/data/hg38'
    # Build the file name from the requested chromosome; the path used to be
    # pinned to chr1.npz, so every chromosome silently received chr1's data.
    ref_file = os.path.join(ref_path, 'chr%s.npz' % chr)
    ref_gen_data = load_npz(ref_file).toarray().reshape(4, -1, 1000).swapaxes(0, 1)
    return torch.tensor(pad_seq_matrix(ref_gen_data))

def load_dnase(dnase_seq):
    """Turn a dense DNase signal array into a padded, single-channel tensor.

    Args:
        dnase_seq: NumPy array whose total size is a multiple of 1000; it is
            cut into rows of 1000 values.

    Returns:
        ``torch.Tensor`` of shape ``(rows, 1, 1600)``: every 1000-value row is
        flanked by ``pad_signal_matrix`` (default 300 values per side) and a
        length-1 axis is inserted at position 1.
    """
    binned = dnase_seq.reshape(-1, 1000)
    padded = pad_signal_matrix(binned)
    # Insert the singleton axis between the row axis and the signal axis.
    return torch.tensor(padded[:, np.newaxis, :])

def prepare_train_data(cl,chrs):
    """Load the DNase signal and reference sequence for one chromosome.

    Args:
        cl: Cell-line name; selects the pickle ``<cl>_dnase.pickle``
            (e.g. ``'GM12878'``).
        chrs: A single chromosome key (e.g. ``'1'``).  It is used both to index
            the unpickled DNase object and as the key of the returned dicts.

    Returns:
        ``(dnase_data, ref_data)``: two dicts, each with the single key
        ``chrs``, holding the tensors produced by ``load_dnase`` and
        ``load_ref_genome`` respectively.
    """
    dnase_dir = '/home/gridsan/gschuette/binz_group_shared/zlao/data'
    # Honour the requested cell line; the file name used to be pinned to
    # GM12878 regardless of ``cl``.
    dnase_file = os.path.join(dnase_dir, '%s_dnase.pickle' % cl)
    # NOTE: pickle.load executes arbitrary code from the file; only point this
    # at trusted data.
    with open(dnase_file, 'rb') as f:
        dnase = pickle.load(f)

    # The per-chromosome entry is densified with .toarray() before binning.
    dnase_data = {chrs: load_dnase(dnase[chrs].toarray())}
    ref_data = {chrs: load_ref_genome(chrs)}
    return dnase_data, ref_data

def load_data(data,ref_data,dnase_data):
    """Assemble a batch of model inputs from region index rows.

    Args:
        data: Tensor whose rows carry a start index in column 1 and an end
            index in column 2.
        ref_data: Dict mapping chromosome key to the reference tensor.
        dnase_data: Dict mapping chromosome key to the DNase tensor.

    Returns:
        Float tensor on the selected device, with one stacked entry per row of
        ``data``; each entry is the reference and DNase slices concatenated
        along dim 1.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # NOTE(review): get_args is not defined in the visible part of this file;
    # presumably it returns parsed options exposing ``curr_chr`` -- confirm.
    args = get_args()
    regions = data.numpy().astype(int)
    windows = []
    for region in regions:
        chrom = args.curr_chr
        start, stop = region[1], region[2]
        seq_part = ref_data[chrom][start:stop]
        dnase_part = dnase_data[chrom][start:stop]
        windows.append(torch.cat((seq_part, dnase_part), dim=1))

    batch = torch.stack(windows)
    return batch.float().to(device)

'''
_______________________
Model details (backbone + transformer encoder + dimension reduction)
_______________________
'''

class CNN(nn.Module):
    """Convolutional backbone: six unpadded 1-D convolutions in three stages.

    The stack takes 5 input channels and widens them to 256, then 360, then
    512.  The first two stages end with max-pooling (by 5 and by 4) before
    batch normalisation; the last stage has no pooling.  No convolution is
    padded, so the length axis shrinks at every layer.

    Attributes:
        conv_net: The ``nn.Sequential`` holding all layers.
        num_channels: Number of output channels (512).
    """

    def __init__(self):
        super().__init__()
        wide_kernel, narrow_kernel = 10, 8
        first_pool, second_pool = 5, 4
        # Layer order (and therefore the Sequential indices used as
        # state-dict keys) must stay exactly as listed.
        stages = [
            # Stage 1: 5 -> 256 channels, pool by 5.
            nn.Conv1d(5, 256, kernel_size=wide_kernel),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.1),
            nn.Conv1d(256, 256, kernel_size=wide_kernel),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(kernel_size=first_pool, stride=first_pool),
            nn.BatchNorm1d(256),
            nn.Dropout(p=0.1),
            # Stage 2: 256 -> 360 channels, pool by 4.
            nn.Conv1d(256, 360, kernel_size=narrow_kernel),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.1),
            nn.Conv1d(360, 360, kernel_size=narrow_kernel),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(kernel_size=second_pool, stride=second_pool),
            nn.BatchNorm1d(360),
            nn.Dropout(p=0.1),
            # Stage 3: 360 -> 512 channels, no pooling.
            nn.Conv1d(360, 512, kernel_size=narrow_kernel),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.2),
            nn.Conv1d(512, 512, kernel_size=narrow_kernel),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(512),
            nn.Dropout(p=0.2),
        ]
        self.conv_net = nn.Sequential(*stages)
        self.num_channels = 512

    def forward(self, x):
        """Apply the convolutional stack to ``x`` of shape ``(batch, 5, length)``."""
        return self.conv_net(x)

def _get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(F"activation should be relu/gelu, not {activation}.")

class TransformerEncoderLayer(nn.Module):
    """One encoder layer: self-attention then a feed-forward network.

    Layer normalisation is applied to the input of each sub-layer, and each
    sub-layer's output is added back to its un-normalised input through
    dropout.

    Args:
        d_model: Feature size of the tokens.
        nhead: Number of attention heads.
        dim_feedforward: Hidden width of the feed-forward network.
        dropout: Dropout probability used in attention, in the feed-forward
            network and on both residual branches.
        activation: Name understood by ``_get_activation_fn``.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu"):
        super().__init__()
        # Attention sub-layer.
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        # Feed-forward sub-layer: d_model -> dim_feedforward -> d_model.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        # One normalisation per sub-layer.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        # One dropout per residual branch.
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = _get_activation_fn(activation)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return ``tensor + pos``, or ``tensor`` itself when ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward(self, src,
                    src_mask: Optional[Tensor] = None,
                    src_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None):
        """Run the layer on ``src`` and return a tensor of the same shape.

        Args:
            src: Input tokens, passed to ``MultiheadAttention`` as is.
            src_mask: Optional attention mask (``attn_mask``).
            src_key_padding_mask: Optional key padding mask.
            pos: Optional positional embedding, added to the queries and keys
                only; the values stay free of it.
        """
        normed = self.norm1(src)
        query = key = self.with_pos_embed(normed, pos)
        attended, _ = self.self_attn(query, key, value=normed,
                                     attn_mask=src_mask,
                                     key_padding_mask=src_key_padding_mask)
        src = src + self.dropout1(attended)

        hidden = self.linear1(self.norm2(src))
        hidden = self.linear2(self.dropout(self.activation(hidden)))
        return src + self.dropout2(hidden)

def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

class TransformerEncoder(nn.Module):
    """A stack of identical encoder layers with an optional final normalisation.

    Args:
        encoder_layer: Layer to replicate; it is deep-copied ``num_layers``
            times by ``_get_clones``.
        num_layers: Number of copies in the stack.
        norm: Optional module applied to the output of the last layer.
    """

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src,
                mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        """Feed ``src`` through every layer in order, then through ``norm`` if set.

        ``mask``, ``src_key_padding_mask`` and ``pos`` are handed unchanged to
        every layer.
        """
        hidden = src
        for block in self.layers:
            hidden = block(hidden, src_mask=mask,
                           src_key_padding_mask=src_key_padding_mask, pos=pos)
        return hidden if self.norm is None else self.norm(hidden)

class Transformer(nn.Module):
    """Encoder-only transformer working on channel-first feature maps.

    Args:
        d_model: Feature size of the tokens.
        nhead: Number of attention heads.
        num_encoder_layers: Number of encoder layers in the stack.
        num_decoder_layers: No decoder is built; the value only gates whether
            the encoder is created (see the note in ``__init__``).
        dim_feedforward: Hidden width of each layer's feed-forward network.
        dropout: Dropout probability inside the layers.
        activation: Activation name for the feed-forward network.
    """

    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu"
                 ):
        super().__init__()
        self.num_encoder_layers = num_encoder_layers
        # NOTE(review): the encoder is gated on num_decoder_layers, which looks
        # like a slip for num_encoder_layers.  With num_decoder_layers <= 0 no
        # encoder exists and forward() fails with AttributeError.  Left
        # unchanged to preserve behaviour -- confirm the intent.
        if num_decoder_layers > 0:
            encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                                    dropout, activation)
            encoder_norm = nn.LayerNorm(d_model)
            self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        self._reset_parameters()
        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        """Xavier-uniform initialise every parameter with more than one dimension."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, pos_embed=None, mask=None):
        """Encode a channel-first feature map.

        Args:
            src: Tensor of shape ``(batch, channels, length)``.
            pos_embed: Optional positional embedding forwarded to the encoder.
            mask: Optional key padding mask; flattened from dim 1 onwards
                before use.

        Returns:
            Tensor of shape ``(batch, length, channels)``.
        """
        # (batch, channels, length) -> (length, batch, channels), the
        # sequence-first layout the attention layers are called with.
        src = src.permute(2, 0, 1)
        # The debug print of src.shape that ran on every call was removed.
        if mask is not None:
            mask = mask.flatten(1)
        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        # Back to batch-first for the caller.
        return memory.transpose(0, 1)

class AttentionPool(nn.Module):
    """Collapse the position axis with learned, per-feature softmax weights.

    Args:
        dim: Feature size of the input tokens.  The logit projection is a
            ``(dim, dim)`` parameter initialised to the identity matrix.
    """

    def __init__(self, dim):
        super().__init__()
        # Groups all positions into a single pool: (b, p, d) -> (b, 1, p, d).
        self.pool_fn = Rearrange('b (n p) d-> b n p d', n=1)
        self.to_attn_logits = nn.Parameter(torch.eye(dim))

    def forward(self, x):
        """Pool ``x`` of shape ``(batch, positions, dim)`` to ``(batch, dim)``."""
        # Use torch.einsum explicitly: the bare name ``einsum`` was never
        # imported in this file, so forward() raised NameError.
        attn_logits = torch.einsum('b n d, d e -> b n e', x, self.to_attn_logits)
        x = self.pool_fn(x)
        logits = self.pool_fn(attn_logits)

        # Normalise over the position axis, independently for every feature.
        attn = logits.softmax(dim = -2)
        # Drop only the singleton pool axis.  A bare squeeze() would also
        # remove the batch axis when batch == 1 (or the feature axis when
        # dim == 1).
        return (x * attn).sum(dim = -2).squeeze(1)

class Tranmodel(nn.Module):
    """Backbone + transformer encoder + pooling/projection head.

    Pipeline of ``forward``: every sub-sequence is passed through the CNN
    backbone, projected to the transformer width, encoded, attention-pooled to
    one vector, and the vectors of one sample are then mixed along the
    sequence axis and down-sampled by 5.

    Args:
        backbone: Feature extractor exposing ``num_channels``.
        transfomer: Encoder exposing ``d_model`` (parameter name kept as
            spelled for caller compatibility).
        bins: Number of output positions per sample; each sample must provide
            ``bins * 5`` sub-sequences.
        max_bin: Stored on the instance; not used in this class.
        in_dim: Accepted for signature compatibility; not used in this class.
        embed_dim: Channel width of the output embedding.  It used to be read
            from an undefined name, so construction raised NameError.
    """

    def __init__(self, backbone, transfomer, bins=200, max_bin=10, in_dim=64,
                 embed_dim=256):
        super().__init__()
        self.backbone = backbone
        self.transformer = transfomer
        hidden_dim = transfomer.d_model
        # 1x1 convolution matching backbone channels to the transformer width.
        self.input_proj = nn.Conv1d(backbone.num_channels, hidden_dim, kernel_size=1)
        self.bins=bins
        self.max_bin=max_bin
        self.attention_pool = AttentionPool(hidden_dim)
        self.project=nn.Sequential(
            # Regroup the pooled vectors per sample: (b * n, c) -> (b, c, n).
            Rearrange('(b n) c -> b c n', n=bins*5),
            # Depthwise convolution along the sequence axis.
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=15, padding=7,groups=hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.Conv1d(hidden_dim, embed_dim, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2)
        )

        self.cnn=nn.Sequential(
            nn.Conv1d(embed_dim, embed_dim, kernel_size=15, padding=7),
            nn.BatchNorm1d(embed_dim),
            # Down-sample bins * 5 positions to bins positions.
            nn.MaxPool1d(kernel_size=5, stride=5),
            nn.ReLU(inplace=True),
            nn.Conv1d(embed_dim, embed_dim, kernel_size=1),
            nn.Dropout(0.2),
            Rearrange('b c n -> b n c')
        )

    def forward(self, input):
        """Embed a batch of samples.

        Args:
            input: Tensor of shape ``(batch, bins * 5, channels, length)``.

        Returns:
            Tensor of shape ``(batch, bins, embed_dim)``.

        Raises:
            ValueError: If ``input`` is not 4-dimensional.
        """
        if input.dim() != 4:
            raise ValueError(
                "expected a 4-D input (batch, n, channels, length), got %d-D"
                % input.dim())
        # Merge the batch and sub-sequence axes: (b, n, c, l) -> (b * n, c, l).
        # Same result as einops' rearrange, a name this file never imported.
        src = self.backbone(input.flatten(0, 1))
        src = self.input_proj(src)
        src = self.transformer(src)
        src = self.attention_pool(src)
        src = self.project(src)
        src = self.cnn(src)
        return src

def build_backbone():
    """Create the convolutional backbone (a freshly initialised ``CNN``)."""
    return CNN()

def build_transformer(hidden_dim, dropout, nheads, dim_feedforward, enc_layers, dec_layers):
    """Create a ``Transformer`` from flat hyper-parameters.

    Args:
        hidden_dim: Token feature size (``d_model``).
        dropout: Dropout probability.
        nheads: Number of attention heads (``nhead``).
        dim_feedforward: Hidden width of the feed-forward networks.
        enc_layers: Number of encoder layers.
        dec_layers: Passed on as ``num_decoder_layers``.

    Returns:
        The configured ``Transformer`` instance.
    """
    settings = dict(
        d_model=hidden_dim,
        nhead=nheads,
        num_encoder_layers=enc_layers,
        num_decoder_layers=dec_layers,
        dim_feedforward=dim_feedforward,
        dropout=dropout,
    )
    return Transformer(**settings)

'''
_______________________
Create the model with the current params
_______________________
'''

def build_pretrain_model_hic(device, bins=200, nheads=4, hidden_dim=512, embed_dim=256, dim_feedforward=1024, enc_layers=1, dec_layers=2, dropout=0.2, param_file="Parameter File"):
    """Build the model and load matching weights from a checkpoint.

    Args:
        device: Not used here; the checkpoint is always mapped to the CPU and
            the caller moves the model afterwards.
        bins: Number of output positions per sample, forwarded to ``Tranmodel``.
        nheads: Number of attention heads.
        hidden_dim: Transformer feature size.
        embed_dim: Output embedding width (see the note in the body).
        dim_feedforward: Hidden width of the feed-forward networks.
        enc_layers: Number of encoder layers.
        dec_layers: Passed on to ``build_transformer``.
        dropout: Dropout probability inside the transformer.
        param_file: Path of the checkpoint (a state dict) to load.  The
            default is the placeholder path that used to be hard-coded.

    Returns:
        The ``Tranmodel`` with every checkpoint entry whose key exists in the
        model loaded into it.
    """
    backbone = build_backbone()
    transformer = build_transformer(hidden_dim, dropout, nheads, dim_feedforward, enc_layers, dec_layers)
    # ``bins`` used to be dropped here, so a non-default value had no effect.
    # NOTE(review): embed_dim is still not forwarded to Tranmodel -- confirm
    # whether it should be once Tranmodel exposes a matching argument.
    pretrain_model = Tranmodel(
            backbone=backbone,
            transfomer=transformer,
            bins=bins
        )

    model_dict = pretrain_model.state_dict()
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    pretrain_dict = torch.load(param_file, map_location='cpu')
    # Keep only entries the model knows.  Checkpoint keys that do not match
    # are skipped silently, and model weights absent from the checkpoint keep
    # their freshly initialised values.
    pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict}
    model_dict.update(pretrain_dict)
    pretrain_model.load_state_dict(model_dict)
    return pretrain_model

'''
_______________________
Create the embeddings
_______________________
'''

def create_embedding(start_pos, end_pos, curr_chr='Your interested chromosome', bins=200, nheads=4, hidden_dim=512, embed_dim=256, dim_feedforward=1024, enc_layers=1, dec_layers=2, dropout=0.2):
    """Compute and save the model embedding for one genomic window.

    Args:
        start_pos: First bin index of the window (inclusive).
        end_pos: Last bin index of the window (exclusive).
        curr_chr: Chromosome key the sequence and DNase data are fetched from,
            e.g. ``'15'``.
        bins, nheads, hidden_dim, embed_dim, dim_feedforward, enc_layers,
        dec_layers, dropout: Model hyper-parameters, forwarded to
            ``build_pretrain_model_hic``.

    Returns:
        The model output for the window.  It is also written to
        ``chr_<start_pos>_<end_pos>.pt`` in the working directory.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    cl='GM12878'
    dnase_data, ref_data = prepare_train_data(cl,curr_chr)

    ### Create the input: sequence and DNase channels side by side, plus a
    ### leading batch axis of size 1.
    window = torch.cat((ref_data[curr_chr][start_pos:end_pos],
                        dnase_data[curr_chr][start_pos:end_pos]), dim=1)
    input_x = window.unsqueeze(0).float().to(device)

    # Forward the hyper-parameters; they used to be accepted and then ignored.
    model = build_pretrain_model_hic(
        device, bins=bins, nheads=nheads, hidden_dim=hidden_dim,
        embed_dim=embed_dim, dim_feedforward=dim_feedforward,
        enc_layers=enc_layers, dec_layers=dec_layers, dropout=dropout)
    # Follow the device chosen above; the unconditional model.cuda() crashed
    # on CPU-only machines even though a CPU fallback had been selected.
    model.to(device)
    # Inference only: disable dropout, use the BatchNorm running statistics,
    # and skip building an autograd graph so the result is deterministic.
    model.eval()
    with torch.no_grad():
        output = model(input_x)
    torch.save(output, "chr_%d_%d.pt"%(start_pos, end_pos))

    return output


In [12]:
# Load the padded DNase and reference-genome tensors for chromosome '1'.
cell_type = 'GM12878'
dnase_data, ref_data = prepare_train_data(cell_type,'1')

  dnase = pickle.load(f)


In [17]:
(dnase_data['1']!=0).any(-1).squeeze()

tensor([False, False, False,  ..., False, False, False])

In [18]:
np.where((dnase_data['1']!=0).any(-1).squeeze())

(array([     9,     10,     15, ..., 248944, 248945, 248946]),)

In [23]:
dnase_data['1'][784]

tensor([[0., 0., 0.,  ..., 0., 0., 0.]])

In [25]:
import pandas as pd
# Unpickle a previously saved object; presumably a dict keyed by chromosome
# -- TODO confirm.
a = pd.read_pickle("../../data/raw_embeddings/my_dict.pickle")

In [30]:
a['1'][0]

array([ 780, 2080])

In [32]:
ref_data['1'].shape

torch.Size([248956, 4, 1600])

In [33]:
ref_data['1'][0,:,0]

tensor([0, 0, 0, 0], dtype=torch.int8)

In [36]:
ref_data['1'][780,:,300]

tensor([0, 0, 0, 1], dtype=torch.int8)

In [42]:
ref_data['1'].shape

torch.Size([248956, 4, 1600])

In [43]:
dnase_data['1'].shape

torch.Size([248956, 1, 1600])