In [None]:
from torch.optim import lr_scheduler
import os, pickle, time
import random
import numpy as np
import torch.optim as optim
import torch
from torch.utils.data import DataLoader,SubsetRandomSampler,Dataset
from scipy.stats import pearsonr,spearmanr
from scipy.sparse import load_npz,csr_matrix,save_npz
import subprocess
from torch.nn import MultiheadAttention
from einops.layers.torch import Rearrange


'''
_________________________
Helper functions that were missing from the file Zhuohan provided
_________________________
'''

from torch import nn
from typing import Optional
from torch import Tensor

def pad_signal_matrix(matrix, pad_len=300):
    """Flank every row of a 2-D signal matrix with context from its neighbouring rows.

    Row ``i`` of the result is ``[last pad_len cols of row i-1 | row i | first pad_len
    cols of row i+1]``. The first row's left flank and the last row's right flank are
    zero-filled. Output shape is ``(rows, cols + 2 * pad_len)``.
    """
    zero_row = np.zeros((1, pad_len), dtype='float32')
    tails = matrix[:, -pad_len:]
    heads = matrix[:, :pad_len]
    # Shift the tails down one row so row i receives the end of row i-1.
    left_flank = np.concatenate((zero_row, tails), axis=0)[:-1]
    # Shift the heads up one row so row i receives the start of row i+1.
    right_flank = np.concatenate((heads, zero_row), axis=0)[1:]
    return np.concatenate((left_flank, matrix, right_flank), axis=1)

'''
_________________________
Prepare the data
_________________________
'''

def pad_seq_matrix(matrix, pad_len=300):
    """Flank each window of a (windows, 4, width) array with its neighbours' edges.

    Along the last axis, window ``i`` becomes ``[last pad_len cols of window i-1 |
    window i | first pad_len cols of window i+1]``. The first and last windows get
    int8 zeros where no neighbour exists. The result has width ``width + 2 * pad_len``.
    """
    blank = np.zeros((1, 4, pad_len)).astype('int8')
    tails, heads = matrix[:, :, -pad_len:], matrix[:, :, :pad_len]
    # Prepend a blank window and drop the last one: window i now holds window i-1's tail.
    before = np.concatenate((blank, tails), axis=0)[:-1]
    # Append a blank window and drop the first one: window i now holds window i+1's head.
    after = np.concatenate((heads, blank), axis=0)[1:]
    return np.concatenate((before, matrix, after), axis=2)

def load_ref_genome(chr, ref_dir='/home/gridsan/gschuette/binz_group_shared/zlao/data/hg38'):
    """Load the one-hot reference genome of chromosome *chr* and flank-pad its windows.

    Args:
        chr: chromosome identifier, either bare (e.g. ``'15'`` or ``15``) or already
            prefixed (``'chr15'``). Previously this argument was ignored and chr1 was
            always loaded.
        ref_dir: directory holding ``chr<N>.npz`` sparse matrices.

    Returns:
        A tensor built from the dense matrix reshaped to (n_windows, 4, 1000) and
        passed through ``pad_seq_matrix``.
    """
    chrom = str(chr)
    if not chrom.startswith('chr'):
        chrom = f'chr{chrom}'
    ref_file = os.path.join(ref_dir, f'{chrom}.npz')
    ref_gen_data = load_npz(ref_file).toarray().reshape(4, -1, 1000).swapaxes(0, 1)
    return torch.tensor(pad_seq_matrix(ref_gen_data))

def load_dnase(dnase_seq):
    """Cut a DNase signal array into 1000-column windows, flank-pad them, return a tensor.

    The windows are padded by ``pad_signal_matrix`` and a singleton channel axis is
    inserted at position 1, giving (n_windows, 1, padded_width).
    """
    windows = dnase_seq.reshape(-1, 1000)
    padded = pad_signal_matrix(windows)
    return torch.tensor(padded[:, np.newaxis, :])

def prepare_train_data(cl, chrs):
    """Load DNase and reference-genome tensors for one chromosome of cell line *cl*.

    Args:
        cl: cell-line name. It selects ``<cl>_dnase.pickle`` in the data directory;
            previously it was ignored and GM12878 was always used.
        chrs: chromosome key. It indexes the unpickled DNase mapping and is passed to
            ``load_ref_genome``.

    Returns:
        ``(dnase_data, ref_data)``: two single-entry dicts keyed by *chrs*.
    """
    data_dir = '/home/gridsan/gschuette/binz_group_shared/zlao/data'
    # NOTE: pickle.load can execute arbitrary code; only point this at trusted files.
    with open(os.path.join(data_dir, f'{cl}_dnase.pickle'), 'rb') as f:
        dnase = pickle.load(f)
    dnase_seq = dnase[chrs]
    del dnase  # free the other chromosomes before the large padded copies are built

    dnase_data = {chrs: load_dnase(dnase_seq.toarray())}
    ref_data = {chrs: load_ref_genome(chrs)}
    return dnase_data, ref_data

def load_data(data, ref_data, dnase_data, curr_chr=None):
    """Assemble the model input for a batch of regions of one chromosome.

    Args:
        data: tensor convertible with ``.numpy()``. For each row, ``row[1]`` and
            ``row[2]`` are the start and end indices along dim 0 of the per-chromosome
            tensors; column 0 is not read here.
        ref_data, dnase_data: dicts mapping chromosome -> tensor. Slices of each are
            concatenated along dim 1.
        curr_chr: chromosome key to read. When None, falls back to
            ``get_args().curr_chr``. NOTE(review): ``get_args`` is not defined in this
            notebook; it is presumably provided elsewhere.

    Returns:
        A float tensor stacking one region per row of *data*, on ``cuda:0`` when
        available and otherwise on CPU. All regions must have equal ``end - start``
        for the stack to succeed.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    if curr_chr is None:
        curr_chr = get_args().curr_chr
    rows = data.numpy().astype(int)
    ref_chr = ref_data[curr_chr]
    dnase_chr = dnase_data[curr_chr]
    regions = [
        torch.cat((ref_chr[start:end], dnase_chr[start:end]), dim=1)
        for start, end in rows[:, 1:3]
    ]
    return torch.stack(regions).float().to(device)

'''
_______________________
Model details (backbone + transformer encode + dimension reduction)
_______________________
'''

class CNN(nn.Module):
    """Convolutional feature extractor for 5-channel inputs (4 sequence + 1 DNase channel).

    Three conv stages widen the channels 5 -> 256 -> 360 -> 512. Two max-pools (x5,
    then x4) shorten the length axis. Convolutions are unpadded, so each one also
    trims ``kernel_size - 1`` positions.

    The layer order inside ``conv_net`` is load-bearing. Checkpoints address parameters
    as ``conv_net.<position>.*``, so inserting, removing or reordering layers would
    break loading of pretrained weights.
    """

    def __init__(self):
        super().__init__()
        wide_k, narrow_k = 10, 8   # conv kernel sizes: stage 1 / stages 2-3
        pool_a, pool_b = 5, 4      # max-pool factors after stage 1 and stage 2

        layers = [
            # Stage 1: 5 -> 256 channels, then pool by pool_a.
            nn.Conv1d(5, 256, kernel_size=wide_k),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.1),
            nn.Conv1d(256, 256, kernel_size=wide_k),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(kernel_size=pool_a, stride=pool_a),
            nn.BatchNorm1d(256),
            nn.Dropout(p=0.1),
            # Stage 2: 256 -> 360 channels, then pool by pool_b.
            nn.Conv1d(256, 360, kernel_size=narrow_k),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.1),
            nn.Conv1d(360, 360, kernel_size=narrow_k),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(kernel_size=pool_b, stride=pool_b),
            nn.BatchNorm1d(360),
            nn.Dropout(p=0.1),
            # Stage 3: 360 -> 512 channels, no pooling.
            nn.Conv1d(360, 512, kernel_size=narrow_k),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.2),
            nn.Conv1d(512, 512, kernel_size=narrow_k),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(512),
            nn.Dropout(p=0.2),
        ]
        self.conv_net = nn.Sequential(*layers)
        self.num_channels = 512  # output channel count; read by Tranmodel's input projection

    def forward(self, x):
        """Run ``conv_net`` on ``x`` (Conv1d layout, presumably (batch, 5, length))."""
        return self.conv_net(x)

def _get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(F"activation should be relu/gelu, not {activation}.")

class TransformerEncoderLayer(nn.Module):
    """Pre-norm transformer encoder layer (LayerNorm before attention and before the FFN).

    ``MultiheadAttention`` is built without ``batch_first``, so inputs use its default
    (sequence, batch, d_model) layout.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu"):
        super().__init__()
        # Submodule names are part of the state_dict keys; keep them stable.
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = _get_activation_fn(activation)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return ``tensor + pos``, or *tensor* unchanged when *pos* is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward(self, src,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        # Self-attention sub-block; the positional embedding is added to queries/keys only.
        normed = self.norm1(src)
        query_key = self.with_pos_embed(normed, pos)
        attended, _ = self.self_attn(query_key, query_key, value=normed,
                                     attn_mask=src_mask,
                                     key_padding_mask=src_key_padding_mask)
        src = src + self.dropout1(attended)
        # Position-wise feed-forward sub-block.
        hidden = self.activation(self.linear1(self.norm2(src)))
        src = src + self.dropout2(self.linear2(self.dropout(hidden)))
        return src

def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

class TransformerEncoder(nn.Module):
    """Stack of ``num_layers`` deep-copied encoder layers followed by an optional norm."""

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src,
                mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        hidden = src
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, src_mask=mask,
                                   src_key_padding_mask=src_key_padding_mask,
                                   pos=pos)
        return hidden if self.norm is None else self.norm(hidden)

class Transformer(nn.Module):
    """Encoder-only transformer applied to convolutional feature maps.

    ``forward`` expects ``src`` as (batch, channels, length), the layout produced by a
    Conv1d, with ``channels == d_model``. It returns (batch, length, d_model).

    NOTE(review): the encoder is only constructed when ``num_decoder_layers > 0``, even
    though no decoder is ever built. With ``num_decoder_layers=0`` there is no
    ``self.encoder`` and ``forward`` raises AttributeError. This is kept unchanged so
    the parameter layout still matches existing checkpoints.
    """

    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu"
                 ):
        super().__init__()
        self.num_encoder_layers = num_encoder_layers
        if num_decoder_layers > 0:
            encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                                    dropout, activation)
            encoder_norm = nn.LayerNorm(d_model)
            self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        self._reset_parameters()
        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        """Xavier-uniform initialise every parameter with more than one dimension."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, pos_embed=None, mask=None):
        # (batch, channels, length) -> (length, batch, channels) for MultiheadAttention.
        # (The debug print of the input shape that ran on every call has been removed.)
        src = src.permute(2, 0, 1)
        if mask is not None:
            mask = mask.flatten(1)
        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        return memory.transpose(0, 1)

class AttentionPool(nn.Module):
    """Attention-weighted pooling over the sequence axis of a (batch, seq, dim) tensor.

    A learned (dim, dim) matrix, initialised to the identity, maps each position to
    per-feature logits. These are softmaxed over the sequence axis and used to weight
    ``x``. Returns (batch, dim).
    """

    def __init__(self, dim):
        super().__init__()
        # Adds a singleton group axis: (b, seq, d) -> (b, 1, seq, d).
        self.pool_fn = Rearrange('b (n p) d-> b n p d', n=1)
        self.to_attn_logits = nn.Parameter(torch.eye(dim))

    def forward(self, x):
        # torch.einsum: the bare name `einsum` is not imported in this module.
        attn_logits = torch.einsum('bnd,de->bne', x, self.to_attn_logits)
        x = self.pool_fn(x)
        logits = self.pool_fn(attn_logits)

        attn = logits.softmax(dim=-2)
        # Squeeze only the singleton group axis so a batch of size 1 keeps its batch dim.
        return (x * attn).sum(dim=-2).squeeze(1)

class Tranmodel(nn.Module):
    """CNN backbone + transformer encoder + attention pooling + convolutional head.

    ``forward`` takes ``input`` of shape (b, n, c, l) and works in stages:
      1. Folds the ``n`` regions into the batch.
      2. Encodes each region into a single ``hidden_dim`` vector.
      3. Regroups ``bins * 5`` consecutive region vectors as one 1-D sequence.
      4. Max-pools that sequence by 5.
    It returns (num_groups, bins, embed_dim). ``b * n`` must be a multiple of
    ``bins * 5``.

    Args:
        backbone: module exposing ``num_channels`` (e.g. CNN).
        transformer: module exposing ``d_model`` (e.g. Transformer).
        bins: output positions per group.
        max_bin: stored as ``self.max_bin``; not otherwise used here.
        in_dim: accepted for backward compatibility; unused.
        embed_dim: channel width of the projection head and of the output.
            Previously this was an undefined name.
    """

    def __init__(self, backbone, transformer, bins=200, max_bin=10, in_dim=64, embed_dim=256):
        super().__init__()
        self.backbone = backbone
        self.transformer = transformer
        hidden_dim = transformer.d_model
        self.input_proj = nn.Conv1d(backbone.num_channels, hidden_dim, kernel_size=1)
        self.bins = bins
        self.max_bin = max_bin
        self.attention_pool = AttentionPool(hidden_dim)
        self.project = nn.Sequential(
            # Regroup bins*5 consecutive region vectors into one sequence per sample.
            Rearrange('(b n) c -> b c n', n=bins * 5),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=15, padding=7, groups=hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.Conv1d(hidden_dim, embed_dim, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2)
        )

        self.cnn = nn.Sequential(
            nn.Conv1d(embed_dim, embed_dim, kernel_size=15, padding=7),
            nn.BatchNorm1d(embed_dim),
            nn.MaxPool1d(kernel_size=5, stride=5),
            nn.ReLU(inplace=True),
            nn.Conv1d(embed_dim, embed_dim, kernel_size=1),
            nn.Dropout(0.2),
            Rearrange('b c n -> b n c')
        )

    def forward(self, input):
        # (b, n, c, l) -> (b*n, c, l). This matches einops rearrange('b n c l -> (b n) c l');
        # the bare name `rearrange` is not imported in this module.
        input = input.flatten(0, 1)
        src = self.backbone(input)
        src = self.input_proj(src)
        src = self.transformer(src)
        src = self.attention_pool(src)
        src = self.project(src)
        src = self.cnn(src)
        return src

def build_backbone():
    """Construct and return the default CNN backbone."""
    return CNN()

def build_transformer(hidden_dim, dropout, nheads, dim_feedforward, enc_layers, dec_layers):
    """Create a Transformer, translating this notebook's hyper-parameter names to its kwargs."""
    transformer_kwargs = dict(
        d_model=hidden_dim,
        dropout=dropout,
        nhead=nheads,
        dim_feedforward=dim_feedforward,
        num_encoder_layers=enc_layers,
        num_decoder_layers=dec_layers,
    )
    return Transformer(**transformer_kwargs)

'''
_______________________
Create the model with the current params
_______________________
'''

def build_pretrain_model_hic(device, bins=200, nheads=4, hidden_dim=512, embed_dim=256, dim_feedforward=1024, enc_layers=1, dec_layers=2, dropout=0.2, param_filepath="Parameter File"):
    """Build a Tranmodel and copy in the checkpoint weights whose names match.

    Checkpoint entries whose key is not a model key are dropped. Model entries missing
    from the checkpoint keep their freshly initialised values.

    Args:
        device: unused here; the caller moves the model.
        bins: forwarded to Tranmodel (previously ignored).
        nheads, hidden_dim, dim_feedforward, enc_layers, dec_layers, dropout:
            forwarded to build_transformer.
        embed_dim: NOTE(review) accepted but not forwarded to Tranmodel. Confirm the
            head width matches the checkpoint.
        param_filepath: checkpoint path. The placeholder default is kept for
            backward compatibility.

    Returns:
        The model, with checkpoint tensors loaded via ``map_location='cpu'``.
    """
    backbone = build_backbone()
    transformer = build_transformer(hidden_dim, dropout, nheads, dim_feedforward, enc_layers, dec_layers)
    pretrain_model = Tranmodel(
        backbone=backbone,
        transformer=transformer,
        bins=bins,
    )

    model_dict = pretrain_model.state_dict()
    pretrain_dict = torch.load(param_filepath, map_location='cpu')
    # Partial load: keep only keys that the freshly built model actually has.
    model_dict.update({k: v for k, v in pretrain_dict.items() if k in model_dict})
    pretrain_model.load_state_dict(model_dict)
    return pretrain_model

'''
_______________________
Create the embeddings
_______________________
'''

def create_embedding(start_pos, end_pos, curr_chr='Your interested chromosome', bins=200, nheads=4, hidden_dim=512, embed_dim=256, dim_feedforward=1024, enc_layers=1, dec_layers=2, dropout=0.2):
    """Embed windows ``[start_pos, end_pos)`` of chromosome *curr_chr* and save the result.

    Args:
        start_pos, end_pos: window indices along dim 0 of the loaded tensors.
        curr_chr: chromosome key used to fetch the genomic data, e.g. ``'15'``.
        bins, nheads, hidden_dim, embed_dim, dim_feedforward, enc_layers, dec_layers,
            dropout: forwarded to ``build_pretrain_model_hic`` (previously ignored).

    Returns:
        The model output. It is also written to ``chr_<start>_<end>.pt``.
        NOTE(review): the file name omits the chromosome, so runs on different
        chromosomes with the same range overwrite each other.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    cl = 'GM12878'
    dnase_data, ref_data = prepare_train_data(cl, curr_chr)

    # One region -> a batch of one.
    region = torch.cat((ref_data[curr_chr][start_pos:end_pos],
                        dnase_data[curr_chr][start_pos:end_pos]), dim=1)
    input_x = region.unsqueeze(0).float().to(device)

    model = build_pretrain_model_hic(
        device, bins=bins, nheads=nheads, hidden_dim=hidden_dim, embed_dim=embed_dim,
        dim_feedforward=dim_feedforward, enc_layers=enc_layers, dec_layers=dec_layers,
        dropout=dropout,
    )
    # Use the same device as the input: the old unconditional .cuda() crashed on
    # CPU-only machines.
    model.to(device)
    # Inference: disable dropout and use BatchNorm running stats so embeddings are
    # deterministic.
    model.eval()

    with torch.no_grad():
        output = model(input_x)
    torch.save(output, "chr_%d_%d.pt" % (start_pos, end_pos))

    return output


In [1]:
import os
import numpy as np
import torch
import pandas as pd
from scipy.sparse import load_npz

##########
# Load functions for the genome. Adapted from what Zhuohan sent me 
def pad_seq_matrix(matrix, pad_len=300):
    """Widen each (4, width) window of ``matrix`` with ``pad_len`` columns from its neighbours.

    Window ``i`` becomes ``[tail of window i-1 | window i | head of window i+1]`` along
    the last axis. Missing neighbours at either end contribute int8 zeros.
    """
    zeros = np.zeros((1, 4, pad_len)).astype('int8')
    pieces = (
        np.concatenate((zeros, matrix[:, :, -pad_len:]), axis=0)[:-1],  # previous window's tail
        matrix,
        np.concatenate((matrix[:, :, :pad_len], zeros), axis=0)[1:],    # next window's head
    )
    return np.concatenate(pieces, axis=2)
    
def load_ref_genome(fp):
    """Load a sparse .npz reference chromosome and return its flank-padded windows.

    The matrix is densified, reshaped to (4, n_windows, 1000), reordered to
    (n_windows, 4, 1000) and padded by ``pad_seq_matrix``.
    """
    dense = load_npz(fp).toarray()
    windows = dense.reshape(4, -1, 1000).swapaxes(0, 1)
    return torch.tensor(pad_seq_matrix(windows))

##########
# Load functions for DNase data. Adapted from what Zhuohan sent me 
def pad_signal_matrix(matrix, pad_len=300):
    """Give every row of a 2-D signal ``pad_len`` columns of context from each neighbour row.

    The result row ``i`` is ``[end of row i-1 | row i | start of row i+1]``. The outer
    edges (before the first row, after the last row) are float32 zeros.
    """
    blank = np.zeros((1, pad_len), dtype='float32')
    left_context = np.concatenate([blank, matrix[:, -pad_len:]], axis=0)[:-1]
    right_context = np.concatenate([matrix[:, :pad_len], blank], axis=0)[1:]
    return np.concatenate([left_context, matrix, right_context], axis=1)
    
def load_dnase(fp, chroms):
    """Load the requested chromosomes' DNase tracks from a pickle and pad them into windows.

    Args:
        fp: path readable by ``pd.read_pickle``. The unpickled mapping is indexed by
            ``int(suffix)`` for numbered chromosomes (``'chr1'`` -> ``1``) and by the
            bare suffix otherwise (``'chrX'`` -> ``'X'``).
        chroms: chromosome names of the form ``'chr<suffix>'``.

    Returns:
        dict mapping each name in *chroms* to a tensor. Each track has ``.toarray()``
        called on it (presumably a scipy sparse matrix), is reshaped to 1000-column
        windows, padded by ``pad_signal_matrix``, and given a channel axis at dim 1.

    NOTE: ``pd.read_pickle`` unpickles arbitrary objects; only use trusted files.
    """
    raw = pd.read_pickle(fp)
    dnase_seq = {}
    for chrom in chroms:
        suffix = chrom[3:]
        key = int(suffix) if suffix.isnumeric() else suffix
        dnase_seq[chrom] = raw[key]
    del raw  # release the unrequested chromosomes before the padded copies are built

    for chrom, track in dnase_seq.items():
        windows = track.toarray().reshape(-1, 1000)
        dnase_seq[chrom] = torch.tensor(np.expand_dims(pad_signal_matrix(windows), axis=1))

    return dnase_seq

##########
# Class by Greg
class SequencesDataset:
    """Batched iterator over (reference sequence, DNase) regions for one cell type.

    All chromosomes are loaded eagerly in ``__init__``. Positions are tracked in units
    of 1 kb windows (dim 0 of the loaded tensors), so ``resolution`` and
    ``region_length`` are converted with ``// 1000``.

    Iteration proceeds as follows:
      - Each chromosome is walked in steps of ``self.res`` windows.
      - Each region takes ``self.length`` windows.
      - Regions rejected by ``_is_valid_region_seq`` are skipped.
      - Stacked ``(ref_seq, dnase_seq)`` batches are yielded.
    The final batch of an epoch may hold fewer than ``batch_size`` regions.
    """

    def __init__(
        self,
        cell_type = 'GM12878',
        alignment = 'hg19',
        data_dir = '../../data/',
        resolution=20_000,
        region_length=1_300_000,
        batch_size=64,
        chroms = None
    ):
        self.dnase_fp = os.path.join(data_dir, 'outside', f'{cell_type}_{alignment}.pkl')
        self.res = resolution // 1000       # step between regions, in 1 kb windows
        self.length = region_length // 1000  # region length, in 1 kb windows
        self.batch_size = batch_size

        if chroms is None:
            chroms = [f'chr{k}' for k in [*range(1, 23), 'X']]
        elif isinstance(chroms, str):
            chroms = [chroms]
        self.chroms = chroms

        # Fail fast if any input file is missing.
        self.genome_fps = {}
        for chrom in chroms:
            fp = os.path.join(data_dir, 'outside', alignment, f'{chrom}.npz')
            self.genome_fps[chrom] = fp
            assert os.path.exists(fp), fp

        assert os.path.exists(self.dnase_fp), self.dnase_fp

        # Load the data
        print('Loading sequencing data')
        self.load_genome()
        self.load_dnase()
        print('Sequencing data loading complete')

        # Iteration state: chromosome position in self.chroms, window offset, and
        # number of valid regions yielded so far in this epoch.
        self.curr_chrom = 0
        self.curr_idx = 0
        self.inner_idx = 0

    ##################################################
    # Loading functions
    def load_genome(self):
        """Load and pad the reference genome for every chromosome in ``self.genome_fps``."""
        self.genome = {}
        for chrom, fp in self.genome_fps.items():
            self.genome[chrom] = load_ref_genome(fp)

    def load_dnase(self):
        """Load the DNase tracks for ``self.chroms`` (module-level ``load_dnase`` helper)."""
        self.dnase = load_dnase(self.dnase_fp, self.chroms)

    ##################################################
    # Data loading functions
    @staticmethod
    def _is_valid_region_seq(ref_seq):
        """Return a 0-d bool tensor: True unless some position is all-zero across dim 1.

        Presumably dim 1 is the one-hot base channel, so an all-zero column marks an
        unknown base or padding. TODO confirm.
        """
        return ~(ref_seq == 0).all(1).any()

    def _update_sample_idx_(self):
        """Advance to the next region, moving to the next chromosome near the end."""
        self.curr_idx += self.res

        chrom = self.chroms[self.curr_chrom]
        # Must read this instance's genome; the old code referenced a notebook global
        # `seq_ds`, which breaks any other instance.
        # NOTE(review): `>=` also skips a region that would end exactly at the last
        # window; confirm whether that is intended.
        if self.curr_idx + self.length >= self.genome[chrom].shape[0]:
            self.curr_idx = 0
            self.curr_chrom = (self.curr_chrom + 1) % len(self.chroms)

    def _get_sample_(self, idx=None):
        """Return ``(ref_seq, dnase_seq)`` slices for the current region or for *idx*.

        *idx* is ``(chrom, start, length)`` in 1 kb windows, unlike ``fetch``'s bp units.
        """
        if idx is None:
            chrom, start, length = self.chroms[self.curr_chrom], self.curr_idx, self.length
        else:
            chrom, start, length = idx

        ref_seq = self.genome[chrom][start:start + length, ...]
        dnase_seq = self.dnase[chrom][start:start + length, ...]
        return ref_seq, dnase_seq

    def _region_is_valid_(self):
        """Return ``(is_valid, ref_seq, dnase_seq)`` for the current region."""
        ref_seq, dnase_seq = self._get_sample_()
        return self._is_valid_region_seq(ref_seq), ref_seq, dnase_seq

    def __iter__(self):
        self.curr_chrom = 0
        self.curr_idx = 0
        self.inner_idx = 0
        return self

    def __next__(self):
        # Back at the start after yielding at least one region: the epoch is over.
        # inner_idx is deliberately left untouched, so further next() calls keep
        # raising StopIteration; __iter__ resets it for the next epoch. (The old
        # `self.inner_idx == 0` here was a no-op comparison.)
        if self.curr_chrom == 0 and self.curr_idx == 0 and self.inner_idx > 0:
            raise StopIteration

        batch_ref_seq = []
        batch_dnase_seq = []
        while len(batch_ref_seq) < self.batch_size:
            is_valid, ref_seq, dnase_seq = self._region_is_valid_()
            if is_valid:
                batch_ref_seq.append(ref_seq)
                batch_dnase_seq.append(dnase_seq)
                self.inner_idx += 1

            self._update_sample_idx_()

            if self.curr_chrom == 0 and self.curr_idx == 0:
                # Wrapped around to the start: this epoch is over.
                break

        if not batch_ref_seq:
            # Wrapped around before finding any valid region.
            raise StopIteration

        return torch.stack(batch_ref_seq, dim=0), torch.stack(batch_dnase_seq, dim=0)

    # For easier interfacing with my other classes.
    def _fetch_one_(self, idx):
        '''
        idx = tuple(chrom, start_idx (in bp), region length (in bp)).
        chrom may omit the 'chr' prefix. Both positions must be multiples of 1000.
        '''
        c, s, l = idx

        assert s % 1000 == 0
        s //= 1000
        assert l % 1000 == 0
        l //= 1000
        if c not in self.chroms:
            c = f'chr{c}'
        assert c in self.chroms

        return self._get_sample_(idx=(c, s, l))

    def is_valid(self, idx):
        '''
        idx = same as described in self._fetch_one_.
        Returns (is_valid, ref_seq, dnase_seq).
        '''
        ref_seq, dnase_seq = self._fetch_one_(idx)
        return self._is_valid_region_seq(ref_seq), ref_seq, dnase_seq

    def fetch(self, idxs, return_invalid=False):
        '''
        idxs is a list of tuples shaped like idx in self._fetch_one_.
        Returns (kept idxs, stacked ref_seqs, stacked dnase_seqs). Invalid regions are
        dropped unless return_invalid is True. torch.stack raises if nothing is kept.
        '''
        ref_seqs = []
        dnase_seqs = []
        idxs_return = []
        for idx in idxs:
            is_valid, ref_seq, dnase_seq = self.is_valid(idx)
            if return_invalid or is_valid:
                idxs_return.append(idx)
                ref_seqs.append(ref_seq)
                dnase_seqs.append(dnase_seq)

        return idxs_return, torch.stack(ref_seqs, dim=0), torch.stack(dnase_seqs, dim=0)
        

In [None]:
# Build the dataset with default settings; __init__ eagerly loads every chromosome's
# reference genome and DNase track.
seq_ds = SequencesDataset()

Loading sequencing data


In [None]:
# Sanity-check batch shapes over one full epoch of seq_ds.
for i,(ref_seq,dnase_seq) in enumerate(seq_ds): 
    # After the epoch's final batch the iterator has wrapped to (chrom 0, idx 0). That
    # batch may be smaller than batch_size, so it is reported and skipped, not checked.
    if (seq_ds.curr_chrom == 0) and (seq_ds.curr_idx == 0) and seq_ds.inner_idx > 0:
        print(i) 
        continue
    
    # Expected dims: batch_size=64, region_length//1000=1300 windows, channels, and
    # padded window width (1600 = 1000 + 2 * 300 with the padding helpers' default).
    assert ref_seq.shape == torch.Size([64,1300,4,1600])
    assert dnase_seq.shape == torch.Size([64,1300,1,1600])
# Notebook display of the last batch index.
i

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import MultiheadAttention
from torch import Tensor
from einops.layers.torch import Rearrange
from typing import Optional
import copy

class CNN(nn.Module):
    """Convolutional feature extractor for 5-channel inputs (4 sequence + 1 DNase channel).

    Three conv stages widen the channels 5 -> 256 -> 360 -> 512. Two max-pools (x5,
    then x4) shorten the length axis. Convolutions are unpadded, so each one also
    trims ``kernel_size - 1`` positions.

    The layer order inside ``conv_net`` is load-bearing. Checkpoints address parameters
    as ``conv_net.<position>.*``, so inserting, removing or reordering layers would
    break loading of pretrained weights.
    """

    def __init__(self):
        super().__init__()
        wide_k, narrow_k = 10, 8   # conv kernel sizes: stage 1 / stages 2-3
        pool_a, pool_b = 5, 4      # max-pool factors after stage 1 and stage 2

        layers = [
            # Stage 1: 5 -> 256 channels, then pool by pool_a.
            nn.Conv1d(5, 256, kernel_size=wide_k),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.1),
            nn.Conv1d(256, 256, kernel_size=wide_k),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(kernel_size=pool_a, stride=pool_a),
            nn.BatchNorm1d(256),
            nn.Dropout(p=0.1),
            # Stage 2: 256 -> 360 channels, then pool by pool_b.
            nn.Conv1d(256, 360, kernel_size=narrow_k),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.1),
            nn.Conv1d(360, 360, kernel_size=narrow_k),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(kernel_size=pool_b, stride=pool_b),
            nn.BatchNorm1d(360),
            nn.Dropout(p=0.1),
            # Stage 3: 360 -> 512 channels, no pooling.
            nn.Conv1d(360, 512, kernel_size=narrow_k),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.2),
            nn.Conv1d(512, 512, kernel_size=narrow_k),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(512),
            nn.Dropout(p=0.2),
        ]
        self.conv_net = nn.Sequential(*layers)
        self.num_channels = 512  # output channel count; read by Tranmodel's input projection

    def forward(self, x):
        """Run ``conv_net`` on ``x`` (Conv1d layout, presumably (batch, 5, length))."""
        return self.conv_net(x)

def _get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(F"activation should be relu/gelu, not {activation}.")

class TransformerEncoderLayer(nn.Module):
    """Pre-norm transformer encoder layer (LayerNorm before attention and before the FFN).

    ``MultiheadAttention`` is built without ``batch_first``, so inputs use its default
    (sequence, batch, d_model) layout.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu"):
        super().__init__()
        # Submodule names are part of the state_dict keys; keep them stable.
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = _get_activation_fn(activation)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return ``tensor + pos``, or *tensor* unchanged when *pos* is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward(self, src,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        # Self-attention sub-block; the positional embedding is added to queries/keys only.
        normed = self.norm1(src)
        query_key = self.with_pos_embed(normed, pos)
        attended, _ = self.self_attn(query_key, query_key, value=normed,
                                     attn_mask=src_mask,
                                     key_padding_mask=src_key_padding_mask)
        src = src + self.dropout1(attended)
        # Position-wise feed-forward sub-block.
        hidden = self.activation(self.linear1(self.norm2(src)))
        src = src + self.dropout2(self.linear2(self.dropout(hidden)))
        return src

def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

class TransformerEncoder(nn.Module):
    """Stack of ``num_layers`` deep-copied encoder layers followed by an optional norm."""

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src,
                mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        hidden = src
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, src_mask=mask,
                                   src_key_padding_mask=src_key_padding_mask,
                                   pos=pos)
        return hidden if self.norm is None else self.norm(hidden)

class Transformer(nn.Module):
    """Encoder-only transformer applied to convolutional feature maps.

    ``forward`` expects ``src`` as (batch, channels, length), the layout produced by a
    Conv1d, with ``channels == d_model``. It returns (batch, length, d_model).

    NOTE(review): the encoder is only constructed when ``num_decoder_layers > 0``, even
    though no decoder exists. With ``num_decoder_layers=0`` there is no ``self.encoder``
    and ``forward`` raises AttributeError.
    """

    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu"):
        super().__init__()
        self.num_encoder_layers = num_encoder_layers
        if num_decoder_layers > 0:
            self.encoder = TransformerEncoder(
                TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation),
                num_encoder_layers,
                nn.LayerNorm(d_model),
            )
        self._reset_parameters()
        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        """Xavier-uniform initialise every parameter that has more than one dimension."""
        for weight in (p for p in self.parameters() if p.dim() > 1):
            nn.init.xavier_uniform_(weight)

    def forward(self, src, pos_embed=None, mask=None):
        # (batch, channels, length) -> (length, batch, channels) for MultiheadAttention.
        seq_first = src.permute(2, 0, 1)
        padding_mask = None if mask is None else mask.flatten(1)
        encoded = self.encoder(seq_first, src_key_padding_mask=padding_mask, pos=pos_embed)
        return encoded.transpose(0, 1)

class AttentionPool(nn.Module):
    """Attention-weighted pooling over the sequence axis of a (batch, seq, dim) tensor.

    A learned (dim, dim) matrix, initialised to the identity, maps each position to
    per-feature logits. These are softmaxed over the sequence axis and used to weight
    ``x``. Returns (batch, dim).
    """

    def __init__(self, dim):
        super().__init__()
        # Adds a singleton group axis: (b, seq, d) -> (b, 1, seq, d).
        self.pool_fn = Rearrange('b (n p) d-> b n p d', n=1)
        self.to_attn_logits = nn.Parameter(torch.eye(dim))

    def forward(self, x):
        # torch.einsum: the bare name `einsum` is not imported in this module.
        attn_logits = torch.einsum('bnd,de->bne', x, self.to_attn_logits)
        x = self.pool_fn(x)
        logits = self.pool_fn(attn_logits)

        attn = logits.softmax(dim=-2)
        # Squeeze only the singleton group axis so a batch of size 1 keeps its batch dim.
        return (x * attn).sum(dim=-2).squeeze(1)

class Tranmodel(nn.Module):
    """CNN backbone + transformer encoder + attention pooling + convolutional head.

    ``forward`` takes ``input`` of shape (b, n, c, l) and works in stages:
      1. Folds the ``n`` regions into the batch.
      2. Encodes each region into a single ``hidden_dim`` vector.
      3. Regroups ``bins * 5`` consecutive region vectors as one 1-D sequence.
      4. Max-pools that sequence by 5.
    It returns (num_groups, bins, embed_dim). ``b * n`` must be a multiple of
    ``bins * 5``.

    Args:
        backbone: module exposing ``num_channels`` (e.g. CNN).
        transformer: module exposing ``d_model`` (e.g. Transformer).
        bins: output positions per group.
        max_bin: stored as ``self.max_bin``; not otherwise used here.
        in_dim: accepted for backward compatibility; unused.
        embed_dim: channel width of the projection head and of the output.
    """

    def __init__(self, backbone, transformer, bins=200, max_bin=10, in_dim=64, embed_dim=256):
        super().__init__()
        self.backbone = backbone
        self.transformer = transformer
        hidden_dim = transformer.d_model
        self.input_proj = nn.Conv1d(backbone.num_channels, hidden_dim, kernel_size=1)
        self.bins = bins
        self.max_bin = max_bin
        self.attention_pool = AttentionPool(hidden_dim)
        self.project = nn.Sequential(
            # Regroup bins*5 consecutive region vectors into one sequence per sample.
            Rearrange('(b n) c -> b c n', n=bins * 5),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=15, padding=7, groups=hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.Conv1d(hidden_dim, embed_dim, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2)
        )

        self.cnn = nn.Sequential(
            nn.Conv1d(embed_dim, embed_dim, kernel_size=15, padding=7),
            nn.BatchNorm1d(embed_dim),
            nn.MaxPool1d(kernel_size=5, stride=5),
            nn.ReLU(inplace=True),
            nn.Conv1d(embed_dim, embed_dim, kernel_size=1),
            nn.Dropout(0.2),
            Rearrange('b c n -> b n c')
        )

    def forward(self, input):
        # (b, n, c, l) -> (b*n, c, l). This matches einops rearrange('b n c l -> (b n) c l');
        # the bare name `rearrange` is not imported in this module.
        input = input.flatten(0, 1)
        src = self.backbone(input)
        src = self.input_proj(src)
        src = self.transformer(src)
        src = self.attention_pool(src)
        src = self.project(src)
        src = self.cnn(src)
        return src

    ###########################
    # For ease, I'm not generalizing this for now. Stick with the defaults used to this point in the CNN
    @staticmethod
    def get_pretrained_model(
        bins=200,
        nheads=4,
        hidden_dim=512,
        embed_dim=256,
        dim_feedforward=1024,
        enc_layers=1,
        dec_layers=2,
        dropout=0.2,
        param_filepath="../../data/models/hic_GM12878_transformer.pt"
    ):
        """Build a Tranmodel from the given hyper-parameters and strictly load a checkpoint.

        A ``staticmethod``, so it works from the class (``Tranmodel.get_pretrained_model()``)
        and from an instance. Checkpoint tensors are loaded with ``map_location='cpu'``.
        ``bins`` is now forwarded to the model; previously it was ignored.
        """
        backbone = CNN()
        transformer = Transformer(
            d_model=hidden_dim,
            dropout=dropout,
            nhead=nheads,
            dim_feedforward=dim_feedforward,
            num_encoder_layers=enc_layers,
            num_decoder_layers=dec_layers
        )

        model = Tranmodel(
            backbone=backbone,
            transformer=transformer,
            bins=bins,
            embed_dim=embed_dim
        )

        model.load_state_dict(torch.load(param_filepath, map_location='cpu'))

        return model

        


In [3]:
# Build the model with get_pretrained_model's default hyper-parameters and load its
# default checkpoint path (pass param_filepath=... to load a different file).
model = Tranmodel.get_pretrained_model()#param_filepath='./test.pt')

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import MultiheadAttention
from torch import Tensor
from einops.layers.torch import Rearrange
from typing import Optional
import copy

class CNN(nn.Module):
    """Convolutional feature extractor for 5-channel inputs (4 sequence + 1 DNase channel).

    Three conv stages widen the channels 5 -> 256 -> 360 -> 512. Two max-pools (x5,
    then x4) shorten the length axis. Convolutions are unpadded, so each one also
    trims ``kernel_size - 1`` positions.

    The layer order inside ``conv_net`` is load-bearing. Checkpoints address parameters
    as ``conv_net.<position>.*``, so inserting, removing or reordering layers would
    break loading of pretrained weights.
    """

    def __init__(self):
        super().__init__()
        wide_k, narrow_k = 10, 8   # conv kernel sizes: stage 1 / stages 2-3
        pool_a, pool_b = 5, 4      # max-pool factors after stage 1 and stage 2

        layers = [
            # Stage 1: 5 -> 256 channels, then pool by pool_a.
            nn.Conv1d(5, 256, kernel_size=wide_k),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.1),
            nn.Conv1d(256, 256, kernel_size=wide_k),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(kernel_size=pool_a, stride=pool_a),
            nn.BatchNorm1d(256),
            nn.Dropout(p=0.1),
            # Stage 2: 256 -> 360 channels, then pool by pool_b.
            nn.Conv1d(256, 360, kernel_size=narrow_k),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.1),
            nn.Conv1d(360, 360, kernel_size=narrow_k),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(kernel_size=pool_b, stride=pool_b),
            nn.BatchNorm1d(360),
            nn.Dropout(p=0.1),
            # Stage 3: 360 -> 512 channels, no pooling.
            nn.Conv1d(360, 512, kernel_size=narrow_k),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.2),
            nn.Conv1d(512, 512, kernel_size=narrow_k),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(512),
            nn.Dropout(p=0.2),
        ]
        self.conv_net = nn.Sequential(*layers)
        self.num_channels = 512  # output channel count; read by Tranmodel's input projection

    def forward(self, x):
        """Run ``conv_net`` on ``x`` (Conv1d layout, presumably (batch, 5, length))."""
        return self.conv_net(x)

def _get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(F"activation should be relu/gelu, not {activation}.")

class TransformerEncoderLayer(nn.Module):
    """Pre-norm transformer encoder layer (LayerNorm before attention and before the FFN).

    ``MultiheadAttention`` is built without ``batch_first``, so inputs use its default
    (sequence, batch, d_model) layout.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu"):
        super().__init__()
        # Submodule names are part of the state_dict keys; keep them stable.
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = _get_activation_fn(activation)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return ``tensor + pos``, or *tensor* unchanged when *pos* is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward(self, src,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        # Self-attention sub-block; the positional embedding is added to queries/keys only.
        normed = self.norm1(src)
        query_key = self.with_pos_embed(normed, pos)
        attended, _ = self.self_attn(query_key, query_key, value=normed,
                                     attn_mask=src_mask,
                                     key_padding_mask=src_key_padding_mask)
        src = src + self.dropout1(attended)
        # Position-wise feed-forward sub-block.
        hidden = self.activation(self.linear1(self.norm2(src)))
        src = src + self.dropout2(self.linear2(self.dropout(hidden)))
        return src

def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

class TransformerEncoder(nn.Module):
    """Stack of ``num_layers`` deep-copied encoder layers followed by an optional norm."""

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src,
                mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        hidden = src
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, src_mask=mask,
                                   src_key_padding_mask=src_key_padding_mask,
                                   pos=pos)
        return hidden if self.norm is None else self.norm(hidden)

class Transformer(nn.Module):
    """Encoder-only transformer over a batch of 1-D feature maps.

    Only an encoder is built. ``num_decoder_layers`` is accepted for
    signature compatibility with existing callers but constructs nothing.

    Args:
        d_model: feature width expected by every encoder layer.
        nhead: number of attention heads.
        num_encoder_layers: number of stacked encoder layers.
        num_decoder_layers: unused (kept for backward compatibility).
        dim_feedforward: hidden width of each layer's feed-forward block.
        dropout: dropout probability passed to each layer.
        activation: activation name passed to each layer.
    """
    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu"
                 ):
        super().__init__()
        self.num_encoder_layers = num_encoder_layers
        # Built unconditionally because forward() always needs self.encoder.
        # Gating it on num_decoder_layers left the attribute undefined (and
        # forward() raising AttributeError) whenever num_decoder_layers <= 0.
        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation)
        encoder_norm = nn.LayerNorm(d_model)
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        self._reset_parameters()
        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        """Xavier-uniform initialise every parameter with more than one dimension."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, pos_embed=None, mask=None):
        """Encode ``src`` and return the result with axes 0 and 1 swapped.

        ``src`` is permuted with ``(2, 0, 1)`` into the sequence-first layout
        that ``MultiheadAttention`` uses by default. Presumably ``src`` arrives
        as (batch, d_model, length), e.g. from a Conv1d (TODO confirm); the
        result is then (batch, length, d_model).

        Args:
            pos_embed: optional positional embedding added to queries/keys in
                every encoder layer; must broadcast against the permuted input.
            mask: optional key-padding mask, flattened from dim 1 onward.
        """
        src = src.permute(2, 0, 1)
        if mask is not None:
            mask = mask.flatten(1)
        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        return memory.transpose(0, 1)

class AttentionPool(nn.Module):
    """Attention-weighted pooling over the sequence axis.

    Each position is scored with a learned ``dim x dim`` matrix (initialised
    to the identity). The scores are softmaxed over the sequence axis
    independently for every feature channel, and the input is summed with
    those weights.
    """
    def __init__(self, dim):
        super().__init__()
        # Inserts a singleton group axis: (b, L, d) -> (b, 1, L, d).
        self.pool_fn = Rearrange('b (n p) d-> b n p d', n=1)
        self.to_attn_logits = nn.Parameter(torch.eye(dim))

    def forward(self, x):
        """Pool ``x`` of shape (batch, length, dim) down to (batch, dim)."""
        # torch.einsum explicitly, so this does not rely on a bare ``einsum``
        # name that the file header does not import.
        attn_logits = torch.einsum('b n d, d e -> b n e', x, self.to_attn_logits)
        x = self.pool_fn(x)
        logits = self.pool_fn(attn_logits)

        # Softmax over the sequence axis (-2), separately per feature channel.
        attn = logits.softmax(dim=-2)
        # Squeeze only the singleton group axis. A bare squeeze() would also
        # drop the batch axis when batch == 1 (and the feature axis when
        # dim == 1), breaking downstream reshapes.
        return (x * attn).sum(dim=-2).squeeze(-2)

class Tranmodel(nn.Module):
    """Backbone -> transformer -> attention pooling -> convolutional head.

    ``forward`` takes a 4-D input ``(b, n, c, l)`` and flattens it to
    ``(b*n, c, l)``. It then embeds each of the ``b*n`` windows independently
    with ``backbone`` -> ``input_proj`` -> ``transformer`` ->
    ``attention_pool``. The pooled embeddings are regrouped ``bins * 5`` at a
    time, so the total window count must be a multiple of ``bins * 5``
    (presumably ``n == bins * 5``). The groups are mixed by depthwise and
    pointwise convolutions and max-pooled by 5. The result is returned as
    ``(groups, bins, embed_dim)``.

    Args:
        backbone: feature extractor; must expose ``num_channels`` (its output
            channel count, used to size ``input_proj``).
        transformer: encoder module; must expose ``d_model``.
        bins: number of output positions per group.
        max_bin: stored as ``self.max_bin``; not used inside this class.
        in_dim: unused; kept so existing call sites keep working.
        embed_dim: channel width of the returned features.
    """
    def __init__(self, backbone, transformer, bins=200, max_bin=10, in_dim=64,embed_dim=256):
        super().__init__()
        self.backbone = backbone
        self.transformer = transformer
        hidden_dim = transformer.d_model
        self.input_proj = nn.Conv1d(backbone.num_channels, hidden_dim, kernel_size=1)
        self.bins=bins
        self.max_bin=max_bin
        self.attention_pool = AttentionPool(hidden_dim)
        self.project=nn.Sequential(
            # (b*n, hidden) -> (b, hidden, bins*5): regroup pooled windows.
            Rearrange('(b n) c -> b c n', n=bins*5),
            # Depthwise conv over neighbouring windows, then pointwise to embed_dim.
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=15, padding=7,groups=hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.Conv1d(hidden_dim, embed_dim, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2)
        )

        self.cnn=nn.Sequential(
            nn.Conv1d(embed_dim, embed_dim, kernel_size=15, padding=7),
            nn.BatchNorm1d(embed_dim),
            # bins*5 positions -> bins positions.
            nn.MaxPool1d(kernel_size=5, stride=5),
            nn.ReLU(inplace=True),
            nn.Conv1d(embed_dim, embed_dim, kernel_size=1),
            nn.Dropout(0.2),
            Rearrange('b c n -> b n c')
        )

    def forward(self, input):
        """Map ``input`` of shape (b, n, c, l) to features of shape (groups, bins, embed_dim)."""
        # Local import: ``rearrange`` is not imported in the file header, but
        # einops is already a dependency (einops.layers.torch.Rearrange).
        from einops import rearrange

        input=rearrange(input,'b n c l -> (b n) c l')
        src = self.backbone(input)
        src = self.input_proj(src)
        src = self.transformer(src)
        src = self.attention_pool(src)   # one hidden_dim vector per window
        src = self.project(src)
        src = self.cnn(src)
        return src

    ###########################
    # For ease, this is not generalised: the defaults match the architecture
    # used by build_pretrain_model_hic.
    @staticmethod
    def get_pretrained_model(
        bins=200,
        nheads=4,
        hidden_dim=512,
        embed_dim=256,
        dim_feedforward=1024,
        enc_layers=1,
        dec_layers=2,
        dropout=0.2,
        param_filepath="../../data/models/hic_GM12878_transformer.pt"
    ):
        """Build a ``Tranmodel`` and strictly load its weights from ``param_filepath``.

        Declared ``@staticmethod`` so it works both as
        ``Tranmodel.get_pretrained_model()`` and when called on an instance.
        Without it, an instance call would bind the instance to ``bins``.

        ``load_state_dict`` is strict, so the file's keys must match this
        model exactly. ``torch.load`` unpickles the file; only load trusted
        files.
        """
        backbone = CNN()
        transformer = Transformer(
            d_model=hidden_dim,
            dropout=dropout,
            nhead=nheads,
            dim_feedforward=dim_feedforward,
            num_encoder_layers=enc_layers,
            num_decoder_layers=dec_layers
        )

        model = Tranmodel(
            backbone=backbone,
            transformer=transformer,
            bins=bins,
            embed_dim=embed_dim
        )

        # map_location='cpu' so the file loads even without the device it was saved from.
        model.load_state_dict(torch.load(param_filepath,map_location='cpu'))

        return model

def build_backbone():
    """Return a freshly constructed ``CNN`` feature extractor."""
    return CNN()

def build_transformer(hidden_dim, dropout, nheads, dim_feedforward, enc_layers, dec_layers):
    """Construct a ``Transformer`` from the flat hyper-parameters used by the builders.

    Maps ``hidden_dim`` -> ``d_model``, ``nheads`` -> ``nhead``,
    ``enc_layers`` -> ``num_encoder_layers`` and ``dec_layers`` ->
    ``num_decoder_layers``. ``dropout`` and ``dim_feedforward`` are passed
    through unchanged.
    """
    config = dict(
        d_model=hidden_dim,
        dropout=dropout,
        nhead=nheads,
        dim_feedforward=dim_feedforward,
        num_encoder_layers=enc_layers,
        num_decoder_layers=dec_layers,
    )
    return Transformer(**config)
    
def build_pretrain_model_hic(device, bins=200, nheads=4, hidden_dim=512, embed_dim=256, dim_feedforward=1024, enc_layers=1, dec_layers=2, dropout=0.2,
                             checkpoint_path="../../data/models/hic_GM12878_transformer_og.pt"):
    """Build a ``Tranmodel`` and initialise it from a saved state dict.

    Only checkpoint entries whose names also appear in the new model's
    ``state_dict`` are copied. All other model parameters keep their fresh
    initialisation, and extra checkpoint entries are ignored. A name match
    with a different tensor shape still makes ``load_state_dict`` raise.

    Args:
        device: device the returned model is moved to (e.g. ``'cpu'``).
        bins, embed_dim: forwarded to ``Tranmodel``.
        nheads, hidden_dim, dim_feedforward, enc_layers, dec_layers, dropout:
            forwarded to ``build_transformer``.
        checkpoint_path: file read with ``torch.load``. It is unpickled, so
            only use trusted files.

    Returns:
        The model, on ``device``.

    Warns:
        UserWarning: if no checkpoint entry matched a model parameter name,
            i.e. nothing was actually loaded.
    """
    import warnings

    backbone = build_backbone()
    transformer = build_transformer(hidden_dim, dropout, nheads, dim_feedforward, enc_layers, dec_layers)
    pretrain_model = Tranmodel(
            backbone=backbone,
            transformer=transformer,
            bins=bins,
            embed_dim=embed_dim
        )

    model_dict = pretrain_model.state_dict()
    # map_location='cpu' so the checkpoint loads even without the device it was saved from.
    pretrain_dict = torch.load(checkpoint_path, map_location='cpu')
    pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict}
    if not pretrain_dict:
        # Otherwise a checkpoint with differently named keys would silently
        # leave the model at its random initialisation.
        warnings.warn(
            f"No entries in {checkpoint_path!r} matched the model's parameter names; "
            "the model keeps its random initialisation."
        )
    model_dict.update(pretrain_dict)
    pretrain_model.load_state_dict(model_dict)
    return pretrain_model.to(device)


# Rebuild the model from the "_og" checkpoint and save its state_dict to the
# path that Tranmodel.get_pretrained_model loads by default.
model = build_pretrain_model_hic(device='cpu')
torch.save(obj=model.state_dict(), f='../../data/models/hic_GM12878_transformer.pt')


In [44]:
def build_backbone():
    """Return a freshly constructed ``CNN`` feature extractor."""
    return CNN()

def build_transformer(hidden_dim, dropout, nheads, dim_feedforward, enc_layers, dec_layers):
    """Construct a ``Transformer`` from the flat hyper-parameters used by the builders.

    Maps ``hidden_dim`` -> ``d_model``, ``nheads`` -> ``nhead``,
    ``enc_layers`` -> ``num_encoder_layers`` and ``dec_layers`` ->
    ``num_decoder_layers``. ``dropout`` and ``dim_feedforward`` are passed
    through unchanged.
    """
    config = dict(
        d_model=hidden_dim,
        dropout=dropout,
        nhead=nheads,
        dim_feedforward=dim_feedforward,
        num_encoder_layers=enc_layers,
        num_decoder_layers=dec_layers,
    )
    return Transformer(**config)
    
def build_pretrain_model_hic(device, bins=200, nheads=4, hidden_dim=512, embed_dim=256, dim_feedforward=1024, enc_layers=1, dec_layers=2, dropout=0.2,
                             checkpoint_path="../../data/models/hic_GM12878_transformer.pt"):
    """Build a ``Tranmodel`` and initialise it from a saved state dict.

    Only checkpoint entries whose names also appear in the new model's
    ``state_dict`` are copied. All other model parameters keep their fresh
    initialisation, and extra checkpoint entries are ignored. A name match
    with a different tensor shape still makes ``load_state_dict`` raise.

    Args:
        device: device the returned model is moved to (e.g. ``'cpu'``).
        bins, embed_dim: forwarded to ``Tranmodel``. Previously ignored; the
            defaults equal ``Tranmodel``'s own defaults, so default calls are
            unchanged.
        nheads, hidden_dim, dim_feedforward, enc_layers, dec_layers, dropout:
            forwarded to ``build_transformer``.
        checkpoint_path: file read with ``torch.load``. It is unpickled, so
            only use trusted files.

    Returns:
        The model, on ``device``.

    Warns:
        UserWarning: if no checkpoint entry matched a model parameter name,
            i.e. nothing was actually loaded.
    """
    import warnings

    backbone = build_backbone()
    transformer = build_transformer(hidden_dim, dropout, nheads, dim_feedforward, enc_layers, dec_layers)
    pretrain_model = Tranmodel(
            backbone=backbone,
            transformer=transformer,
            bins=bins,
            embed_dim=embed_dim
        )

    model_dict = pretrain_model.state_dict()
    # map_location='cpu' so the checkpoint loads even without the device it was saved from.
    pretrain_dict = torch.load(checkpoint_path, map_location='cpu')
    pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict}
    if not pretrain_dict:
        # Otherwise a checkpoint with differently named keys would silently
        # leave the model at its random initialisation.
        warnings.warn(
            f"No entries in {checkpoint_path!r} matched the model's parameter names; "
            "the model keeps its random initialisation."
        )
    model_dict.update(pretrain_dict)
    pretrain_model.load_state_dict(model_dict)
    return pretrain_model.to(device)

In [45]:
# Build the model on CPU using build_pretrain_model_hic's default checkpoint.
model = build_pretrain_model_hic(device='cpu')

In [49]:
# Write the current model's weights to a scratch file in the working directory.
torch.save(obj=model.state_dict(), f='test.pt')

In [1]:
import torch

In [3]:
# Reload the saved checkpoint for inspection; map_location='cpu' remaps any
# device tensors to CPU. torch.load unpickles the file -- trusted files only.
a = torch.load(f="../../data/models/hic_GM12878_transformer.pt", map_location='cpu')

In [6]:
# Display the loaded checkpoint (an OrderedDict of parameter name -> tensor).
# NOTE(review): the keys shown carry a 'pretrain_model.' prefix, which does not
# match Tranmodel's own state_dict names (e.g. 'backbone.…'); a name-filtered
# load as in build_pretrain_model_hic would then copy nothing -- confirm which
# checkpoint this file actually holds.
a

OrderedDict([('pretrain_model.feature_pos_encoding',
              tensor([[[-0.6822, -1.6249,  1.1352,  ...,  1.3716, -0.7261,  0.2279],
                       [ 1.5814,  0.9985,  0.1154,  ...,  0.5484, -0.1557,  0.3486],
                       [ 0.0825, -0.9254,  0.3727,  ...,  0.3910,  1.0148,  0.8607],
                       ...,
                       [ 0.9018, -0.4964, -1.2406,  ...,  0.7593,  0.2301, -0.9925],
                       [-0.8628,  0.3690, -0.8889,  ...,  0.3373,  0.9038,  0.9924],
                       [ 0.6486,  1.2263,  0.5698,  ...,  0.9522, -0.4027,  1.6026]]])),
             ('pretrain_model.backbone.conv_net.0.weight',
              tensor([[[ 0.8392, -4.0774, -4.8250,  ..., -1.9907,  1.2839, -5.0225],
                       [-4.0130, -2.2366, -0.3131,  ...,  0.6559, -2.7559,  0.1112],
                       [ 0.1499, -4.2560,  0.5277,  ..., -0.2887, -2.6812, -4.2274],
                       [-5.4168,  1.2491, -1.5418,  ..., -4.4101, -5.6384,  0.7563],
      