Extract consensus sequence patterns (motifs) from the stage-1 models

In [3]:
import numpy as np
from copy import deepcopy

import torch

from torch import nn
from torch.autograd import grad
from torch_fftconv import fft_conv1d, FFTConv1d

class SimpleNet(nn.Module):
    """Two-branch convolution / long-range deconvolution network with strand pooling.

    Both branches scan the input and its flipped copy (``x.flip([1, 2])``), which
    is the reverse complement when channels are one-hot A, C, G, T (assumed --
    confirm against the data loader).
    - long branch: ``conv`` (40 filters, width 51) -> 80 channels -> ``deconv``
    - short branch: ``conv_inr`` (10 filters, width 15) -> 20 channels -> ``deconv_inr``
    Each branch is gated as ``sigmoid(y) * y``. The two deconvolved outputs are
    summed and passed through softplus.
    """

    def __init__(self):
        super(SimpleNet, self).__init__()
        self.conv = nn.Conv1d(4, 40, kernel_size=51, padding=25)
        self.conv_inr = nn.Conv1d(4, 10, kernel_size=15, padding=7)
        self.activation = nn.Sigmoid()

        # 80 input channels = 40 forward + 40 reverse-strand outputs of ``conv``.
        self.deconv = FFTConv1d(80, 10, kernel_size=601, padding=300)
        # 20 input channels = 10 forward + 10 reverse-strand outputs of ``conv_inr``.
        self.deconv_inr = nn.ConvTranspose1d(20, 10, kernel_size=15, padding=7)
        self.softplus = nn.Softplus()

    @staticmethod
    def _both_strands(conv, x):
        """Apply ``conv`` to ``x`` and to its flipped copy; stack results on dim 1."""
        # Flip the reverse-strand output back along the length axis so positions line up.
        return torch.cat([conv(x), conv(x.flip([1, 2])).flip([2])], 1)

    def forward(self, x):
        """Predict 10 output tracks from ``x``.

        ``x`` is presumably (batch, 4, L) one-hot sequence -- confirm.
        """
        y = self._both_strands(self.conv, x)          # (batch, 80, L)
        # The short branch must also be strand-pooled: ``deconv_inr`` expects 20
        # input channels, whereas ``conv_inr`` alone yields only 10.
        y_inr = self._both_strands(self.conv_inr, x)  # (batch, 20, L)
        yact = self.activation(y)
        y_inr_act = self.activation(y_inr)
        y_pred = self.softplus(self.deconv(yact * y) + self.deconv_inr(y_inr_act * y_inr))
        return y_pred

class SimpleNetModified(nn.Module):
    """Single-strand conv/deconv network with a softmax label head and a sigmoid SSE head.

    Args:
        input_channels: channels of the input sequence tensor (default 4; presumably
            one-hot A, C, G, T -- confirm).
        output_channels: total output channels. The first ``output_channels - 1``
            form the softmax "labels" head; the last one is the sigmoid "SSE" head.
    """

    def __init__(self, input_channels=4, output_channels=4):
        super(SimpleNetModified, self).__init__()

        self.conv = nn.Conv1d(input_channels, 40, kernel_size=51, padding=25)
        self.activation = nn.Sigmoid()

        # Separate deconvolution heads: labels (output_channels - 1 channels) and SSE (1 channel).
        self.deconv_labels = FFTConv1d(40, output_channels - 1, kernel_size=601, padding=300)
        self.deconv_SSE = FFTConv1d(40, 1, kernel_size=601, padding=300)

    def forward(self, x):
        """Return per-position predictions with 500 positions cropped from each end."""
        y = self.conv(x)                 # (batch, 40, L): kernel 51 with padding 25 keeps L
        yact = self.activation(y) * y    # gated activation sigmoid(y) * y

        # torch.softmax rather than F.softmax: ``F`` (torch.nn.functional) is not
        # imported anywhere above, so the original raised NameError.
        y_pred_label = torch.softmax(self.deconv_labels(yact), dim=1)  # (batch, output_channels - 1, L)
        y_pred_SSE = torch.sigmoid(self.deconv_SSE(yact))               # (batch, 1, L)

        # Concatenate the heads along the channel dimension.
        y_pred = torch.cat([y_pred_label, y_pred_SSE], dim=1)           # (batch, output_channels, L)

        # Crop 500 positions from each side to match the label length (e.g. 5000 -> 4000).
        return y_pred[:, :, 500:-500]
    
import glob
# Load the 12 replicate stage-1 checkpoints onto the CPU, keeping each path
# alongside its model (same index in ``nets`` and ``filenames``).
nets = []
filenames = []
for rep in range(12):
    path = './models/model.40000.rep' + str(rep) + '.pth'
    model = SimpleNetModified()
    # map_location='cpu' lets checkpoints saved from a GPU load on a CPU-only machine.
    state = torch.load(path, map_location='cpu')
    model.load_state_dict(state)
    model.cpu()
    nets.append(model)
    filenames.append(path)
          

In [4]:
# First-layer conv filters of every replicate, the same filters centred over the
# input-channel axis (axis 1), and the label-head deconvolution kernels.
mats = [model.conv.weight.detach().numpy() for model in nets]
mats_norm = [w - w.mean(axis=1, keepdims=True) for w in mats]
demats = [model.deconv_labels.weight.detach().numpy() for model in nets]

In [5]:
from scipy.stats import pearsonr
from scipy.signal import correlate2d
from matplotlib import pyplot as plt
%matplotlib inline

from numba import njit
import random

@njit( )
def cross_corr(x, y):
    """Best-strand Pearson correlation of ``x`` against circular shifts of ``y``.

    For each shift ``j`` in ``range(y.shape[1] - 5)``, ``y`` is rolled left by
    ``j`` columns along axis 1 (wrapping around). It is correlated with ``x``
    (both flattened) twice: once as-is and once with both axes reversed. The
    larger of the two correlations is kept.

    ``x`` and ``y`` must be 2-D arrays with the same number of elements.
    They are presumably (4 nucleotides, kernel width) filters, in which case
    reversing both axes gives the reverse complement for A, C, G, T channel
    order -- TODO confirm.

    Returns a 1-D array with one correlation per shift.
    """
    xflat = x.flatten()  # loop-invariant; flattened once
    cors = []
    for j in range(y.shape[1] - 5):
        # Circular shift: columns j.. followed by columns ..j.
        shifted = np.concatenate((y[:, j:], y[:, :j]), axis=1)
        forward = np.corrcoef(xflat, shifted.flatten())[0, 1]
        reverse = np.corrcoef(xflat, shifted[::-1, ::-1].flatten())[0, 1]
        # fmax prefers the non-NaN value if one correlation is undefined.
        cors.append(np.fmax(forward, reverse))
    return np.array(cors)

def comparemats(mats, min_abs_weight=0.1):
    """Compare first-layer filters between every pair of replicate models.

    Parameters
    ----------
    mats : list of ndarray
        One weight array per model. ``mats[k][f]`` is filter ``f`` of model ``k``
        as a 2-D (channels, width) array; presumably each array is
        (n_filters, 4, kernel_size) -- confirm. Each array is centred over
        axis 1 here. Centring is idempotent, so raw or already-centred
        weights give the same result.
    min_abs_weight : float, default 0.1
        A filter is "valid" only if its largest absolute centred weight exceeds this.

    Returns
    -------
    crossmats : dict
        ``crossmats[(ii, jj)]`` (for ii < jj) is an (n_filters_ii, n_filters_jj)
        array. Entry [i, j] is ``cross_corr(filter_i, filter_j).max()``.
    validmats : dict
        ``validmats[(ii, jj)]`` is a boolean array of the same shape, True where
        both filters are valid.
    """
    # Centre the argument itself. The old code silently read the global
    # ``mats_norm`` and ignored ``mats``; for the existing call
    # ``comparemats(mats)`` the result is identical.
    normed = [np.asarray(m) - np.asarray(m).mean(axis=1, keepdims=True) for m in mats]
    # Per-model validity mask, computed once instead of once per pair.
    strong = [np.abs(m).max(axis=2).max(axis=1) > min_abs_weight for m in normed]

    crossmats = {}
    validmats = {}
    for ii in range(len(normed)):
        for jj in range(ii + 1, len(normed)):
            # Iterate over the actual filters rather than a hard-coded 40.
            crossmats[(ii, jj)] = np.array([
                [cross_corr(fi, fj).max() for fj in normed[jj]]
                for fi in normed[ii]
            ])
            validmats[(ii, jj)] = strong[ii][:, None] & strong[jj][None, :]

    return crossmats, validmats

In [6]:
# Pairwise filter-similarity matrices and validity masks for all replicate pairs.
crossmats, validmats = comparemats(mats=mats)

In [7]:
# For every replicate pair (a, b), keep filter pairs that are mutual best matches,
# both valid, and correlated above 0.95. Matches are recorded as node labels
# '<filter>_<replicate>'; matchscores holds the scores in the same order.
matchlist = []
matchscores = []
for a in range(len(nets)):
    for b in range(a + 1, len(nets)):
        sim = crossmats[(a, b)].copy()
        # Rows are filters of model a, columns are filters of model b.
        sim[sim < sim.max(axis=1, keepdims=True)] = 0   # best in its row
        sim[sim < sim.max(axis=0, keepdims=True)] = 0   # ...and best in its column
        sim[~validmats[(a, b)]] = 0
        hits = sim > 0.95
        # Plain ``object`` dtype (np.object was removed from NumPy).
        suffix = np.array(["_" + str(a), "_" + str(b)], dtype=object)[None, :]
        matchlist.append(np.argwhere(hits).astype(str) + suffix)
        matchscores.append(sim[hits])

In [8]:
import seaborn as sns
# Render figures and saved files at 300 dpi on a plain white seaborn style.
sns.set(rc=dict([("figure.dpi", 300), ('savefig.dpi', 300)]))
sns.set_style("white")

import logomaker
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
%matplotlib inline
prop_cycle = plt.rcParams['axes.prop_cycle']
# Iterator that keeps cycling through the default matplotlib colour cycle.
itercolor = prop_cycle()


def plotfun(motifpwm, title=None, ax=None, font_name='Arial Rounded MT Bold'):
    """Draw a sequence logo for a position-by-nucleotide matrix.

    Parameters
    ----------
    motifpwm : array-like, shape (length, 4)
        Per-position values in column order A, C, G, T.
    title : str, optional
        Axes title; omitted when None.
    ax : matplotlib Axes, optional
        Axes to draw into; passed straight to ``logomaker.Logo``.
    font_name : str, default 'Arial Rounded MT Bold'
        Font used for the letters. The default is the previously hard-coded
        font. Pass an installed font (e.g. 'DejaVu Sans') to avoid a flood of
        matplotlib "findfont" warnings on machines that lack it.

    Returns
    -------
    logomaker.Logo
        The styled logo object.
    """
    logo_df = pd.DataFrame(motifpwm, columns=['A', 'C', 'G', 'T'])
    logo = logomaker.Logo(logo_df,
                          shade_below=.5,
                          fade_below=.5,
                          font_name=font_name,
                          ax=ax)

    # Logo styling: only left/bottom spines; integer x ticks rotated 90 degrees.
    logo.style_spines(visible=False)
    logo.style_spines(spines=['left', 'bottom'], visible=True)
    logo.style_xticks(rotation=90, fmt='%d', anchor=0)
    if title is not None:
        logo.ax.set_title(title)
    # Axes styling: no y label, no x tick marks, labels pulled in tight.
    logo.ax.set_ylabel("", labelpad=-1)
    logo.ax.xaxis.set_ticks_position('none')
    logo.ax.xaxis.set_tick_params(pad=-1)
    return logo



In [9]:
import networkx as nx
from collections import defaultdict
# Undirected graph: nodes are '<filter>_<replicate>' labels; weighted edges are the
# mutual-best filter matches, weighted by their correlation score.
g = nx.Graph()
g.add_weighted_edges_from(
    np.column_stack([np.concatenate(matchlist, axis=0), np.concatenate(matchscores)])
)


In [10]:
import matplotlib.pyplot as plt
# Default font family for matplotlib text; alternatives include 'Arial' or 'Helvetica'.
plt.rcParams.update({'font.family': 'DejaVu Sans'})

In [11]:
from matplotlib.backends.backend_pdf import PdfPages
# For every connected component of matched filters, score each member by its
# agreement with all other replicates and keep the highest-scoring filter.
bestmats = []
bestscores = []
filescores = defaultdict(list)
# Score against every loaded replicate (previously a hard-coded range(4),
# although len(mats) replicates are compared in crossmats).
n_reps = len(mats)

for cc in nx.connected_components(g):
    # Start from -inf so the first node always wins. Then ``bestmat`` is never
    # undefined (first component) or stale (left over from a previous one).
    bestscore = -np.inf
    bestmat = None
    matinds = []
    motifinds = []
    scores = []

    for node in cc:
        # Node labels are '<filter index>_<replicate index>' (see the matchlist cell).
        motifind, matind = map(int, node.split('_'))
        # Consensus score: for each other replicate, this filter's best similarity
        # to any of that replicate's filters, summed over replicates.
        score = np.sum([
            crossmats[other, matind][:, motifind].max() if other < matind
            else crossmats[matind, other][motifind, :].max()
            for other in np.setdiff1d(range(n_reps), [matind])
        ])

        if score > bestscore:
            # The original never updated bestscore, so "best" was really "last".
            bestscore = score
            bestmat = mats[matind][motifind]

        matinds.append(matind)
        motifinds.append(motifind)
        scores.append(score)

        filescores[filenames[matind]].append(score)

    # Append the score itself (the original appended the list to itself).
    bestscores.append(bestscore)
    bestmats.append(bestmat)



In [12]:
# Collect representative filters: from each component with at least 7 nodes,
# take the member(s) belonging to replicate 7. The deconvolution lists are kept
# for downstream cells but are not filled here.
selectmats = []
selectdemats = []
selectdemats_rc = []

for component in nx.connected_components(g):
    if len(component) < 7:
        continue
    for node in component:
        motifind, matind = map(int, node.split('_'))
        if matind != 7:
            continue
        selectmats.append(mats[matind][motifind])


In [13]:
# Number of representative filters selected in the previous cell.
len(selectmats) 

24

In [14]:
from matplotlib.backends.backend_pdf import PdfPages
with PdfPages('../figures/motif_replicates_selected.pdf') as pdf:
    # One PDF page per selected filter. The (channels, width) filter is
    # transposed to (width, channels) for the logo.
    for mat in selectmats:
        fig, ax = plt.subplots(figsize=(10, 3), nrows=1, ncols=1, dpi=300)
        plotfun(mat.T, ax=ax)
        pdf.savefig(fig)
        plt.close(fig)


findfont: Font family 'Arial Rounded MT Bold' not found.
findfont: Font family 'Arial Rounded MT Bold' not found.
findfont: Font family 'Arial Rounded MT Bold' not found.
findfont: Font family 'Arial Rounded MT Bold' not found.
findfont: Font family 'Arial Rounded MT Bold' not found.
findfont: Font family 'Arial Rounded MT Bold' not found.
findfont: Font family 'Arial Rounded MT Bold' not found.
findfont: Font family 'Arial Rounded MT Bold' not found.
findfont: Font family 'Arial Rounded MT Bold' not found.
findfont: Font family 'Arial Rounded MT Bold' not found.
findfont: Font family 'Arial Rounded MT Bold' not found.
findfont: Font family 'Arial Rounded MT Bold' not found.
findfont: Font family 'Arial Rounded MT Bold' not found.
findfont: Font family 'Arial Rounded MT Bold' not found.
findfont: Font family 'Arial Rounded MT Bold' not found.
findfont: Font family 'Arial Rounded MT Bold' not found.
findfont: Font family 'Arial Rounded MT Bold' not found.
findfont: Font family 'Arial Ro