In [1]:
import os

# Cap OpenMP at 20 threads. This runs in the very first cell, before numpy/torch
# are imported further down — presumably so those libraries pick the value up.
os.environ["OMP_NUM_THREADS"] = "20" 

In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib

from tqdm.auto import tqdm as tqdm_auto
from tqdm.notebook import tqdm

import matplotlib.pyplot as plt
import importlib
from scipy import stats 

import torch
import torch.cuda
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F 
import pytorch_lightning as pl

import matplotlib.colors as clr

In [3]:
import dataset_regressiondiffusion as dataset
# Reload the project dataset module so edits made to it on disk are picked up
# without restarting the kernel.
importlib.reload(dataset)


from lit_regressor import RNARegressor
import legnet_difgenerator
import legnet_classifier

In [4]:
def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and every CUDA device), exports ``PYTHONHASHSEED`` and switches cuDNN
    to deterministic, non-autotuned kernel selection.

    Args:
        seed: Integer seed shared by all generators.
    """
    import random  # local import: `random` is not imported at notebook level

    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all GPUs, not only the current one (the notebook runs on cuda:1).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick different algorithms from
    # run to run, which defeats the deterministic flag set just above.
    torch.backends.cudnn.benchmark = False

In [5]:
# Run configuration shared by the cells below.
# batch_size sets the DataLoader batch size and also scales the random-sequence pool.
batch_size = 1024
# Worker processes used by both DataLoaders.
num_workers = 16
# NOTE(review): batch_per_epoch and num are not read by any cell visible in this notebook.
batch_per_epoch = 1
# All tensors and the generator are placed on the second GPU.
device= torch.device('cuda:1')
# Cell type used to filter the measurement table and to pick the generator checkpoint.
CELL_TYPE_FILTER = 'c17'
seq_len = 240 # UTR3
# Generator checkpoint epoch; also part of the output CSV file name.
epoch = 267
num = 1
# Seed all RNGs once, before any data is generated.
seed_everything(42)

In [6]:
# Load the measurement table and keep only the rows of the selected cell type.
PATH_FROM = '../../../data/UTR3_zinb_norm_singleref_2023-05-23.csv'
df = pd.read_csv(PATH_FROM)

df = df[df.cell_type == CELL_TYPE_FILTER].reset_index(drop=True)
# mass_center = weighted mean of the bin index (1..4), weighted by the values in
# columns '1'..'4'; it therefore lies between 1 and 4.
scores = (df['1']*1+df['2']*2+df['3']*3+df['4']*4) / df[['1', '2', '3', '4']].sum(axis=1)
df['mass_center'] = scores

In [7]:
# Display the filtered table, now including the mass_center column.
df

Unnamed: 0,seq,cell_type,replicate,1,2,3,4,fold,mass_center
0,TGCAGTTTTGACCTCCCAGGCTCAAGCGATCCTCCTGCCTCAGCCT...,c17,1,21.945857,36.076924,62.723318,25.114635,val,2.623929
1,TGCAGTTTTGACCTCCCAGGCTCAAGCGATCCTCCTGCCTCAGCCT...,c17,2,19.848040,34.344834,35.704787,42.941895,val,2.765890
2,ATCAAAAAGCAGGCCAGATTCTAATCAAAATCAGGTAAATTTTAAT...,c17,1,24.996422,38.172301,39.674662,59.162970,train,2.820981
3,ATCAAAAAGCAGGCCAGATTCTAATCAAAATCAGGTAAATTTTAAT...,c17,2,28.542939,24.460446,42.991478,53.016374,train,2.808538
4,ATTTTAGTTTGCCCAAATAATATCTTGAAAATGCTCTGAATTTTAC...,c17,1,7.785771,4.782926,12.253139,28.715463,val,3.156171
...,...,...,...,...,...,...,...,...,...
56847,TGTGCTTCCTAAGAGTACAAACCTGAGCATATGTCCAGGCTTGCAA...,c17,2,23.535406,25.189248,31.332772,45.950563,train,2.791208
56848,TAGGTGGTGATCTTAAATGGGTGAGATGGAACGAGAGCACACATTA...,c17,1,19.259539,30.975137,20.634468,16.226516,train,2.388400
56849,TAGGTGGTGATCTTAAATGGGTGAGATGGAACGAGAGCACACATTA...,c17,2,18.755488,31.065221,28.372554,14.724238,train,2.420433
56850,AGGAGGCAACTGTGGCATTGCTTCCTTAACCAGCTCATGGTGTGTG...,c17,1,12.384384,33.298272,34.982940,33.136733,train,2.780933


In [8]:
# Per-sequence nucleotide fractions. Counts are divided by the constant seq_len,
# not by each sequence's own length.
df['A'] = df["seq"].str.count('A')/seq_len
df['C'] = df["seq"].str.count('C')/seq_len
df['G'] = df["seq"].str.count('G')/seq_len
df['T'] = df["seq"].str.count('T')/seq_len

# Reference table sorted by ascending mass_center; predict_float runs searchsorted
# on this column, which requires the ascending order.
subdf = df.sort_values(by='mass_center', ascending = True).reset_index(drop=True)
# The raw bin columns are not needed in the reference table.
subdf = subdf.drop(['1','2', '3','4'], axis=1)
subdf

Unnamed: 0,seq,cell_type,replicate,fold,mass_center,A,C,G,T
0,GGCTCCACACAGACTAACGTAGGCACTATAAGGACCAGCCCAACCC...,c17,1,train,1.277848,0.254167,0.295833,0.254167,0.195833
1,CACCTTCACTAGAAATGTCCCATCATCGTGGGAGGGGAGCAGGGCA...,c17,1,train,1.287805,0.262500,0.233333,0.329167,0.175000
2,CGGGGTGTCGGTAGCGTCTTAGCCAAGAGTCCAATTAAAGAACGAA...,c17,2,train,1.296866,0.233333,0.283333,0.337500,0.145833
3,AGAGACTTCTTTTCATTGAGGCTTCGTAAAGTTTTCCATTTTGATT...,c17,2,train,1.334953,0.341667,0.137500,0.145833,0.375000
4,CCAGCTACCATGGGAACGCAAGGCAGCAACTCTCTAATTAACCAGG...,c17,2,train,1.343799,0.350000,0.229167,0.112500,0.308333
...,...,...,...,...,...,...,...,...,...
56847,AGAAAATCATACAACTCAGCATCCAGTTGGCTTTTTAAGAATTCTG...,c17,1,train,3.634526,0.312500,0.133333,0.162500,0.391667
56848,CTCGGCCTCCCAAAGTGCTGGGATTACAGGCGTGAGCCACCGCGTC...,c17,1,train,3.647392,0.375000,0.158333,0.170833,0.295833
56849,AGAAAATCATACAACTCAGCATCCAGTTGGCTTTTTAAGAATTCTG...,c17,2,train,3.664347,0.312500,0.133333,0.162500,0.391667
56850,GGCACAGGGTTTATGTTTAGGATGTTGAAAAAGTTCTGCAGATAAA...,c17,1,train,3.666589,0.350000,0.145833,0.175000,0.329167


In [9]:
# Lookup table from channel index to nucleotide letter (index 4 maps to "N").
CODES = dict(enumerate("ACGTN"))


def id2n(n):
    """Translate a nucleotide index (0-4) into its single-letter code."""
    return CODES[n]


In [10]:
def gen_random_seqs(num, lengh, cell_type, seed=7):
    """Build a DataFrame of uniformly random nucleotide sequences.

    The draw uses a private ``RandomState`` so that calling this function no
    longer reseeds NumPy's global generator (which silently overrode the seed
    set by ``seed_everything``). With the default ``seed`` the generated
    sequences and scores are the same as those produced by the previous
    ``np.random.seed(7)`` implementation.

    Args:
        num: Number of sequences to generate.
        lengh: Length of every sequence.
        cell_type: Value written to the ``cell_type`` column of every row.
        seed: Seed of the private generator (default 7).

    Returns:
        DataFrame with columns ``seq``, ``cell_type``, ``fold`` (always
        ``'val'``), ``replicate`` and ``'1'``..``'4'`` (always 1), and
        ``mass_center`` (uniform in [1.5, 3.5)).
    """
    rng = np.random.RandomState(seed)
    # Keep the draw order (letters first, then scores) so results match the
    # legacy global-seed implementation.
    letters = rng.choice(['A', 'C', 'G', 'T'], size=(num, lengh))
    seqs = [''.join(row) for row in letters.tolist()]
    score = rng.uniform(1.5, 3.5, num)

    ones = [1] * num
    return pd.DataFrame({
        'seq': seqs,
        'cell_type': [cell_type] * num,
        'fold': ['val'] * num,
        'replicate': ones,
        '1': ones,
        '2': ones,
        '3': ones,
        '4': ones,
        'mass_center': score,
    })


In [11]:
# Pool of batch_size*100 random sequences; it is later wrapped in a DataLoader
# and fed to predict_float.
random_df = gen_random_seqs(batch_size*100, seq_len, CELL_TYPE_FILTER)
# NOTE: from here on `random_df` is a dataset.PromotersData object, not a DataFrame.
random_df = dataset.PromotersData(random_df)

In [16]:
# Checkpoint of the diffusion generator for the selected cell type and epoch.
legnet_generator_path = f'../saved_models/generator/3utr/ut3_{CELL_TYPE_FILTER}_KL1_model_{epoch}.pth'
# NOTE(review): the first argument (240) presumably is the sequence length and
# matches seq_len — confirm against LegNet_diffusion before tying them together.
difussion_model = legnet_difgenerator.LegNet_diffusion(240,
                ks=7,
                block_sizes=[256, 128, 128, 64, 64, 64, 64],
                final_ch=4).to(device)

difussion_model.load_state_dict(torch.load(legnet_generator_path, map_location=device)['model_state_dict'])
# requires_grad_ is a method: assigning False to it only shadowed the method with
# a bool and left every parameter trainable. Call it to actually freeze the weights.
difussion_model.requires_grad_(False)

In [None]:
# Lightning checkpoint of the RNARegressor predictor; weights are mapped to the
# CPU at load time.
PATH = '../saved_models/predictor/model-utr3-deltas-epoch=9-step=1330.ckpt'
predictor = RNARegressor.load_from_checkpoint(PATH, map_location='cpu')

In [None]:
all_size = random_df.data.shape[0]
# Never ask for more held-out samples than the pool contains. The pool holds
# batch_size*100 sequences, so the previous batch_size*1000 request produced a
# negative train_size and only worked through slice clamping inside random_split.
test_size = min(batch_size*1000, all_size)
train_size = all_size-test_size
train_set, val_set = torch.utils.data.random_split(random_df, [train_size, test_size])

# Sequences are served in a fixed order; each worker seeds NumPy with its own id.
dl_test = DataLoader(val_set,
                     batch_size=batch_size,
                     num_workers=num_workers,
                     shuffle=False,
                     worker_init_fn=lambda worker_id: np.random.seed(worker_id)
                    )

In [None]:
# Draw raw correlation plot
def correlation_plot(pred_df, mode_comment, set_seq, epoch):
    """Plot predicted vs. target scores and report their agreement.

    Draws a hexbin joint plot of ``pred_score`` (x) against ``mass_center`` (y),
    adds the x = y diagonal and annotates it with Spearman / Pearson correlation
    and the MSE. The figure is written to ``./tuning/utr3/<cell type>/`` only
    when Pearson r >= 0.5 and more than 98 % of the sequences are unique.

    Reads the notebook-level globals ``CELL_TYPE_FILTER`` and ``steps``.

    Args:
        pred_df: DataFrame with ``mass_center`` and ``pred_score`` columns.
        mode_comment: Free text added to the annotation beside the plot.
        set_seq: Number of unique sequences among the rows of ``pred_df``.
        epoch: Value embedded in the name of the saved SVG file.

    Returns:
        Tuple ``(spearman, pearson, unique_fraction)``; both correlations are
        rounded to three decimals.
    """
    target = pred_df.mass_center
    predicted = pred_df.pred_score
    mse = ((target - predicted) ** 2).mean()

    # jointplot creates its own figure, so no separate plt.figure() is opened
    # (the previous extra figure stayed empty and was shown as a blank canvas).
    g = sns.jointplot(data=pred_df, x="pred_score", y="mass_center", kind="hex", xlim = (1.5,3.5))
    g.fig.suptitle(f'{CELL_TYPE_FILTER}, 3UTR')

    # Draw a line of x=y over the range shared by both axes
    x0, x1 = g.ax_joint.get_xlim()
    y0, y1 = g.ax_joint.get_ylim()
    lims = [max(x0, y0), min(x1, y1)]
    g.ax_joint.plot(lims, lims, '-k')

    Sp_cor = round(stats.spearmanr(target, predicted)[0], 3)
    P_cor = round(stats.pearsonr(target, predicted)[0], 3)
    g.fig.text(0.2,0.6,
            f"$\\rho$ = {Sp_cor:.04f}\n" +
            f"r = {P_cor:.04f}\n"
            f"MSE = {mse:.04f}\n"
            )
    g.fig.text(1, 0.5,
            f"steps = {steps:d}\n" +
            mode_comment + f'\nuniq seq: {set_seq} from {len(pred_df.pred_score)}')
    g.ax_joint.set_xlabel('Pred')
    g.ax_joint.set_ylabel('Target')

    uniq = set_seq / len(pred_df.pred_score)
    if P_cor >= 0.5 and uniq > 0.98:
        out_dir = f'./tuning/utr3/{CELL_TYPE_FILTER}'
        # Create the target folder on first use instead of failing in savefig.
        os.makedirs(out_dir, exist_ok=True)
        # Save the joint plot's own figure rather than whatever is "current".
        g.fig.savefig(f'{out_dir}/P_{P_cor}_epoch{epoch}.svg', bbox_inches="tight")
    plt.tight_layout()
    plt.show()

    return Sp_cor, P_cor, uniq


In [None]:
from itertools import permutations
# All 24 orderings of the four channel indices, one ordering per row.
ALLPERM = torch.tensor(list(permutations((0, 1, 2, 3))))
ALLPERM

# diffusion-like sampling

def mutagenesisv2_(seqs, maxmut):
    """Shuffle, in place, the channel values at randomly chosen positions.

    For every sequence, ``maxmut`` positions are drawn with replacement (so at
    most ``maxmut`` distinct positions are hit). At each hit position the four
    channel values are reordered by a permutation drawn uniformly from
    ``ALLPERM``; the identity permutation is among them.

    Args:
        seqs: Tensor indexed as (sequence, channel, position) with four
            channels; modified in place.
        maxmut: Number of position draws per sequence.

    Returns:
        None.
    """
    n_seqs = seqs.shape[0]
    n_pos = seqs.shape[2]
    rows = torch.arange(n_seqs)

    # Boolean (sequence, position) map of the positions selected for shuffling.
    hit = torch.zeros(n_seqs, n_pos, dtype=bool)
    for _ in range(maxmut):
        cols = torch.randint(high=n_pos, size=(n_seqs,))
        hit[rows, cols] = True

    # Work in (position, sequence, channel) order so that the four channel
    # values of one hit site are adjacent after boolean-mask selection.
    mask_view = hit[:, None, :].broadcast_to(seqs.shape).permute(2, 0, 1)
    seqs_view = seqs.permute(2, 0, 1)
    picked = seqs_view[mask_view]
    n_sites = picked.shape[0] // 4

    # One random permutation per site, shifted to that site's block of four.
    chosen = torch.randint(high=ALLPERM.shape[0], size=(n_sites,))
    block_start = torch.arange(n_sites)[:, None] * 4
    gather_idx = (ALLPERM[chosen] + block_start).ravel()

    seqs_view[mask_view] = picked[gather_idx]
    


In [None]:
### generate with random scores

def _lookup_nucleotide_content(nucl_df, scores):
    """Return the A/C/G/T columns of the reference row just below each score."""
    reference = nucl_df.mass_center.to_numpy()
    rows = np.searchsorted(reference, scores) - 1
    # A score at or below the smallest reference value used to yield index -1,
    # which silently wrapped around to the last (highest-scoring) row.
    rows = np.clip(rows, 0, len(reference) - 1)
    return nucl_df[['A', 'C', 'G', 'T']].to_numpy(dtype=np.float32)[rows]


def predict_float(dl_test, mut_interval, intensities, start, end, nucl_df):
    """Generate sequences for target scores drawn uniformly from [start, end).

    For every batch a target score is drawn per sequence and the nucleotide
    composition of the matching reference row is looked up. The batch is then
    refined iteratively: at each step the generator receives 10 channels
    (4 sequence + 1 target score + 1 intensity + 4 composition), its output is
    soft-maxed over the channel axis and ``mutagenesisv2_`` perturbs the result.

    Uses the notebook-level ``difussion_model``, ``device`` and
    ``mutagenesisv2_``; the model is switched to eval mode.

    Args:
        dl_test: DataLoader yielding 3-D batches; only the first four channels
            are used (presumably the encoded sequence — confirm against the
            dataset class).
        mut_interval: Per-step number of position draws for ``mutagenesisv2_``.
        intensities: Per-step value of the intensity channel.
        start: Lower bound of the sampled target scores.
        end: Upper bound of the sampled target scores.
        nucl_df: Reference DataFrame sorted by ascending ``mass_center`` with
            ``A``, ``C``, ``G``, ``T`` columns.

    Returns:
        ``(seqs_batches, scores_batches)``: per-batch NumPy arrays holding the
        final model output and the 1-D target scores.
    """
    seqs_batches = []
    scores_batches = []
    difussion_model.eval()
    with torch.no_grad():
        for data in tqdm(dl_test):
            seq_batch = data.float().to(device)[:, :4, :]
            n_seqs, _, length = seq_batch.shape

            # Scores are drawn on the CPU (same RNG stream as before) and looked
            # up in one vectorised call instead of one device sync per sample.
            target_score = torch.FloatTensor(n_seqs, 1, 1).uniform_(start, end)
            composition = _lookup_nucleotide_content(
                nucl_df, target_score.reshape(-1).numpy())
            target_nucl = (torch.from_numpy(composition).to(device)[:, :, None]
                           .expand(n_seqs, 4, length))

            target_score = target_score.to(device)
            score_channel = target_score.expand(n_seqs, 1, length)

            for intens, muts in zip(intensities, mut_interval):
                intensity_channel = torch.full(
                    (n_seqs, 1, length), float(intens), device=seq_batch.device)
                model_input = torch.cat(
                    (seq_batch, score_channel, intensity_channel, target_nucl), dim=1)
                seq_batch = torch.softmax(difussion_model(model_input.float()), dim=1)
                mutagenesisv2_(seq_batch, int(muts))

            seqs_batches.append(seq_batch.cpu().numpy())
            # reshape(-1) keeps a 1-D array even for a batch of one; squeeze()
            # produced a 0-d array there, which np.concatenate cannot join.
            scores_batches.append(target_score.reshape(-1).cpu().numpy())
    return seqs_batches, scores_batches



In [None]:
# Range from which predict_float draws the target scores.
start = 1.8
end = 3.2

# Number of refinement steps; also read as a global by correlation_plot.
steps = 149 # (empirically chosen)
# Endpoints of the per-step mutation schedule (high at the first step, 0 at the last).
m_i_start, m_i_end = 149, 0 # number of introduced mutations (empirically chosen)
# Endpoints of the per-step value written to the intensity channel of the generator input.
in_start, in_end = 150, 1 # number of mutations in channel (empirically chosen)

In [None]:
# Linear, integer-rounded schedules with one entry per refinement step.
mut_interval = np.linspace(m_i_start, m_i_end, steps).round().astype("int")
intensities = np.linspace(in_start, in_end, steps).round().astype("int")

# Run the generator over the random pool; subdf is the mass_center-sorted reference table.
seqs, scores_batches = predict_float(dl_test, mut_interval, intensities, start, end, subdf)



In [None]:
# Merge the per-batch outputs of predict_float into single arrays.
given_expression = np.concatenate(scores_batches)
# argmax over the channel axis turns every position into a nucleotide index.
decoded_seq  = torch.tensor(np.concatenate(seqs)).argmax(axis=1).cpu().numpy()

# Nucleotide indices -> letter strings, one string per generated sequence.
encoded_seqs = [''.join(map(id2n, seq)) for seq in decoded_seq]
n_generated = len(encoded_seqs)

# Constant metadata columns, one entry per generated sequence.
cell_type = [CELL_TYPE_FILTER] * n_generated
fold = ['val'] * n_generated
count = [1] * n_generated
diff = [None] * n_generated
replicates = [1] * n_generated

pred_df = pd.DataFrame({
    'seq': encoded_seqs,
    'cell_type': cell_type,
    'fold': fold,
    'replicate': replicates,
    'mass_center': given_expression,
    'diff': diff,
})

In [None]:
# Wrap the generated sequences in the predictor's dataset class. All augmentation
# options are switched off.
pred_df_val = dataset.UTRData(
    df=pred_df,
    augment=False,
    augment_test_time=False,
    augment_kws=dict(
        extend_left=0,
        extend_right=0,
        shift_left=0,
        shift_right=0,
        revcomp=False,
                    ),
    features=("sequence", "positional", "conditions"),  # ("sequence", "conditions", "positional", "revcomp")
    construct_type="utr3"
            )
 

# Fixed order and no dropped batch, so predictions line up with the rows of pred_df.
pred_dl = DataLoader(
        pred_df_val,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
        drop_last=False
                            )

# GPU index 1 is hard-coded here; it is the same index as in `device` (cuda:1).
trainer = pl.Trainer(
        accelerator="gpu",
        devices=[1],)



# Run the predictor over the generated sequences.
pred_score = trainer.predict(model=predictor, dataloaders=pred_dl)


In [None]:
# Join the per-batch predictor outputs once; column 1 is stored as pred_score,
# column 0 as pred_diff.
pred_matrix = np.concatenate(pred_score)
pred_df['pred_score'] = pred_matrix[:, 1]
pred_df['pred_diff'] = pred_matrix[:, 0]
# Drop the placeholder columns that were only needed to build the dataset.
pred_df = pred_df.drop(columns=['fold', 'replicate', 'diff'])

In [None]:
# Output file for the generated sequences with their target and predicted scores.
PATH = f'./generated/cell_{CELL_TYPE_FILTER}_epoch_{epoch}_3UTR.csv'
# Create the output folder first so a long generation run is not lost to a
# missing directory.
os.makedirs(os.path.dirname(PATH), exist_ok=True)
pred_df.to_csv(PATH, index=False)

----------------