In [1]:
import os
import sys

import itertools

import numpy as np
import pandas as pd
import scipy.stats as ss
import matplotlib.pyplot as plt
import seaborn as sns

from tqdm.notebook import tqdm

import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl

In [2]:
import pygad

In [3]:
sys.path.append("../../predictor/regression_multiple")
import dataset_regression as utrdata
from lit_regressor import RNARegressor

In [4]:
class FitnessModel:
    """Wrap a pretrained ``RNARegressor`` checkpoint as a sequence-scoring oracle.

    Sequences are expanded into one row per (sequence, cell type) pair, fed
    through ``utrdata.UTRData`` and a Lightning ``Trainer.predict`` call, and
    the predicted ``mass_center`` is returned per row.

    Args:
        model_path: Path to the Lightning checkpoint to load.
        features: Feature groups forwarded to ``utrdata.UTRData``.
        construct_type: ``"utr3"`` or ``"utr5"`` (case-insensitive); selects
            the default set of cell-type codes.
        batch_size: ``DataLoader`` batch size used for prediction.
        num_workers: ``DataLoader`` worker count used for prediction.

    Raises:
        ValueError: If ``construct_type`` is neither ``"utr3"`` nor ``"utr5"``.
    """

    def __init__(
        self,
        model_path: str,
        features: tuple,
        construct_type: str,
        batch_size: int = 1024,
        num_workers: int = 10,
    ):
        self.dataset_params = dict(
            features=features,
            construct_type=construct_type,
            # Inference only: augmentation switched off.
            augment=False,
            augment_test_time=False,
            augment_kws=dict(
                extend_left=0,
                extend_right=0,
                shift_left=0,
                shift_right=0,
                revcomp=False,
            ),
        )
        self.batch_size = batch_size
        self.num_workers = num_workers

        if construct_type.lower() == "utr3":
            self.celltype_codes = list(utrdata.CELLTYPE_CODES_UTR3.keys())
        elif construct_type.lower() == "utr5":
            self.celltype_codes = list(utrdata.CELLTYPE_CODES_UTR5.keys())
        else:
            raise ValueError(
                f"'construct_type' must be from ['utr3', 'utr5'], got {construct_type!r}"
            )

        self.load_model(model_path)

    def load_model(self, ckpt_path):
        """Load the regressor checkpoint and build a single-GPU prediction Trainer."""
        self.model = RNARegressor.load_from_checkpoint(ckpt_path)
        # TQDMProgressBar's refresh_rate is a batch count, i.e. an int.
        progressbar_callback = pl.callbacks.TQDMProgressBar(refresh_rate=2)
        self.trainer = pl.Trainer(
            callbacks=[progressbar_callback],
            logger=False,
            accelerator="gpu",
            devices=1,
            deterministic=True,
        )

    def fit(self, *args):
        """Not supported: this wrapper is inference-only."""
        raise NotImplementedError()

    def predict(self, seqs, celltype_codes=None):
        """Predict ``mass_center`` for every (sequence, cell type) pair.

        Args:
            seqs: Iterable of nucleotide strings.
            celltype_codes: Cell types to score; defaults to all codes for
                this model's construct type.

        Returns:
            Long-format DataFrame with columns ``num`` (0-based position of
            the sequence in ``seqs``), ``seq``, ``cell_type`` and
            ``pred_mass_center``.
        """
        df = self.create_seq_df(seqs, celltype_codes=celltype_codes)
        ds = utrdata.UTRData(df=df, **self.dataset_params)
        dl = DataLoader(
            ds,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,  # predictions must stay aligned with the rows of df
            drop_last=False,
        )
        prediction = self.trainer.predict(model=self.model, dataloaders=dl)

        # NOTE(review): assumes predict_step returns a 2-tuple per batch whose
        # first element is a (batch, >=2) prediction tensor — confirm in RNARegressor.
        val_pred, _ = zip(*prediction)
        # .cpu() is a no-op for tensors already on the CPU.
        val_pred = torch.concat(val_pred).cpu().numpy()

        # Replace the zero placeholder target columns with the prediction.
        # NOTE(review): presumably output column 1 is mass_center — confirm.
        return df.drop(columns=["mass_center", "diff"]).assign(
            pred_mass_center=val_pred[:, 1]
        )

    def predict_pairwise(self, seqs, celltype_codes=None):
        """Score each sequence by the predicted difference between two cell types.

        Args:
            seqs: Iterable of nucleotide strings.
            celltype_codes: Exactly two cell-type codes ``(a, b)``.

        Returns:
            1-D numpy array of ``pred(a) - pred(b)``, in the order of ``seqs``.

        Raises:
            ValueError: If ``celltype_codes`` is missing or not of length 2.
        """
        if celltype_codes is None or len(celltype_codes) != 2:
            raise ValueError(
                f"'celltype_codes' must contain exactly two cell types, got {celltype_codes!r}"
            )
        df = self.predict(seqs, celltype_codes=celltype_codes)
        # pivot sorts rows by (num, seq); num is the 0-based input position,
        # so the resulting order matches the order of `seqs`.
        df_pivot = df.pivot(columns="cell_type", index=["num", "seq"], values="pred_mass_center")
        diff = df_pivot[celltype_codes[0]] - df_pivot[celltype_codes[1]]
        return diff.to_numpy()

    def create_seq_df(self, seqs, celltype_codes=None):
        """Build the long-format input table passed to ``utrdata.UTRData``.

        Each sequence is numbered by its 0-based position in ``seqs`` (any
        index carried by a Series input is ignored, so downstream pivots
        preserve the input order). ``mass_center`` and ``diff`` are filled
        with 0.0 — presumably placeholders UTRData requires; ``predict``
        drops them again.
        """
        if celltype_codes is None:
            celltype_codes = self.celltype_codes
        celltype_codes = list(celltype_codes)
        df = pd.DataFrame(
            {"seq": list(seqs)} | {ct: 0.0 for ct in celltype_codes}
        ).reset_index(names="num")
        df_long = df.melt(
            id_vars=["num", "seq"],
            value_vars=celltype_codes,
            var_name="cell_type",
            value_name="mass_center",
        )
        df_long["diff"] = 0.0
        return df_long

In [5]:
# Pretrained 5'UTR regressor; fitness_func scores GA candidates with it.
ckpt_path = "../../regression_multiple/model_validation/model-utr5-deltas-epoch=9-step=840.ckpt"
fitness_model = FitnessModel(ckpt_path, ("sequence", "positional", "conditions"), "utr5")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [6]:
# Reverse lookup of utrdata.CODES (code -> character), used to turn GA genes back into sequences.
INVERSE_CODES = dict(zip(utrdata.CODES.values(), utrdata.CODES.keys()))

In [7]:
def encode_seq_onehot(seq):
    """Map every character of ``seq`` through ``utrdata.CODES`` and return the codes as a numpy array.

    NOTE(review): despite the name, this appears to yield one code per position
    (the GA uses codes from ``np.arange(0, 4)``) rather than a one-hot matrix — confirm.
    """
    lookup = utrdata.CODES
    return np.array([lookup[base] for base in seq])


def decode_seq_onehot(arr):
    """Inverse of ``encode_seq_onehot``: map each code in ``arr`` back to its character and join them."""
    return "".join(INVERSE_CODES[code] for code in arr)

In [8]:
# Target (cell_type_a, cell_type_b) pair for fitness_func; reassigned before each GA run.
# A functools.partial cannot be passed as pygad's fitness_func (presumably because
# pygad validates the callback via its __code__ argument count, which partial
# objects lack), so the pair is supplied through this module-level global instead.
CELL_TYPE_CODES = None

# Fitness penalty added to any candidate containing the substring "ATG"
# (presumably to discourage upstream start codons in the 5'UTR).
ATG_PENALTY = -10


def fitness_func(ga, batch, batch_idx):
    """Batched pygad fitness: pairwise predicted difference minus an ATG penalty.

    Args:
        ga: The pygad GA instance (unused).
        batch: Encoded candidate solutions, one row per candidate.
        batch_idx: Indices of the candidates in the population (unused).

    Returns:
        numpy array with one fitness value per candidate: the model's
        ``pred(a) - pred(b)`` for the ``CELL_TYPE_CODES`` pair, plus
        ``ATG_PENALTY`` for candidates containing "ATG".
    """
    seqs = pd.Series([decode_seq_onehot(arr) for arr in batch])
    has_atg = seqs.str.contains("ATG", regex=False).to_numpy()
    pred = fitness_model.predict_pairwise(seqs, celltype_codes=CELL_TYPE_CODES)
    return pred + ATG_PENALTY * has_atg

In [9]:
# Pool of unique source 5'UTR sequences used to seed the GA's initial population.
src_data = (
    pd.read_csv("../../data/UTR5_zinb_norm_singleref_2023-05-23.csv")["seq"]
    .drop_duplicates()
    .reset_index(drop=True)
)
src_data

0        ATTGCTGCAGACGCTCACCCCAGACACTCACTGCACCGGAGTGAGC...
1        TGGAAGGGCCGTGTTCGTGTTGGCAAAGAAGGTCGGCTGCTGAGCC...
2        ACTTCCGTTGAGTTCCGCCTCGCCGTTTGTCCCTTGCGGTACCCGT...
3        TTTGTCCCTTGCGGTACCCGTCCGCATACGAATCTAGCCCGGGAAC...
4        ATACGAATCTAGCCCGGGAACCGAGTTGCGGGAGTGCGGTCTGTGC...
                               ...                        
21604    CTCCGGCTCGACGCCGGCTCTCTTTTTGACGCCCCGCCGCCGGGGT...
21605    CGGCTGCGGCTGCGGCTGCGGCTGCTACTGCTACGCTCCTAGCTTG...
21606    CCTGGAGCCTCCGCGCCGGCTCAGCCTGGGGGCGGGCTCCGGTCCG...
21607    GCAGAGTCTGCGGACCCGGCGCCGAGGCGGCCACCCGAGACGCGGC...
21608    CCGTCGTCTCCTCCGCGTCCCCGCCCGCCAGCTGCTGTCGGAGGTT...
Name: seq, Length: 21609, dtype: object

In [10]:
from shuffle_dinucl import shuffle_seq_dinucl

In [11]:
def generate_initial_population(n: int, seqs=src_data):
    """Build an encoded GA starting population of ``n`` shuffled source sequences.

    Draws ``n`` sequences from ``seqs`` (with replacement only when the pool is
    smaller than ``n``). Each draw is passed through ``shuffle_seq_dinucl``
    (presumably a dinucleotide-preserving shuffle) and encoded with
    ``encode_seq_onehot``. The results are stacked into one row per sequence.
    ``np.stack`` requires all encoded sequences to have equal length.

    The sampling uses numpy's global RNG, since no ``random_state`` is passed.
    """
    need_replacement = n > seqs.shape[0]
    picked = seqs.sample(n, replace=need_replacement)
    shuffled = picked.apply(shuffle_seq_dinucl)
    encoded_rows = [encode_seq_onehot(s) for s in shuffled]
    return np.stack(encoded_rows)

In [None]:
def on_gen(ga_instance):
    """pygad ``on_generation`` callback: print how many generations have completed."""
    print("Finished generation:", ga_instance.generations_completed)


n_obj = 10000  # population size; also used as the fitness batch size
num_genes = 50
# NOTE(review): pygad ignores num_genes/sol_per_pop when initial_population is
# given, so the real sequence length comes from the source sequences; the
# "genes=50" in the file name assumes they are 50 nt long — confirm.

out_dir = "saved_utr5_v4"
os.makedirs(out_dir, exist_ok=True)  # to_csv below fails if the directory is missing

# Keyed by (seed, "<ct_a>-<ct_b>"). Keying by the cell-type pair alone made
# later seeds silently overwrite the results of earlier ones.
generated = dict()
for seed, cell_types in itertools.product(range(1, 11), itertools.permutations(fitness_model.celltype_codes, 2)):
    # fitness_func reads the target cell-type pair from this global.
    CELL_TYPE_CODES = cell_types
    cell_type_tag = '-'.join(cell_types)
    file_name = os.path.join(out_dir, f"utr5-seed={seed}-genes={num_genes}-{cell_type_tag}.csv")

    # Seeds numpy's global RNG, which the pandas sampling in
    # generate_initial_population draws from.
    np.random.seed(seed)
    initial_population = generate_initial_population(n=n_obj)

    ga_instance = pygad.GA(
        # Basic parameters
        random_seed=seed,
        num_generations=25,
        stop_criteria="saturate_10",
        initial_population=initial_population,
        sol_per_pop=n_obj,
        fitness_batch_size=n_obj,
        on_generation=on_gen,
        # Genes and fitness
        fitness_func=fitness_func,
        gene_space=np.arange(0, 4),
        num_genes=num_genes,
        # Mutation process
        mutation_type="adaptive",
        mutation_probability=(0.2, 0.05),
        num_parents_mating=n_obj // 2,
        parent_selection_type="sss",
        # Crossover process
        crossover_type="two_points",  # "single_point", "two_points"
        crossover_probability=0.1,
        # Selection process
        # NOTE(review): per the pygad docs, keep_parents has no effect when
        # keep_elitism != 0, so keep_parents=100 is currently ignored.
        keep_parents=100,
        keep_elitism=1,
    )

    ga_instance.run()
    ga_instance.plot_fitness()
    population = ga_instance.population
    seqs = pd.Series([decode_seq_onehot(arr) for arr in population])

    # Score the final population in every cell type (not just the optimized pair).
    gen_df = fitness_model.predict(seqs)
    gen_df.to_csv(file_name, index=False)
    generated[(seed, cell_type_tag)] = gen_df

You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: 0it [00:00, ?it/s]

You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: 0it [00:00, ?it/s]

You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: 0it [00:00, ?it/s]

Finished generation: 1


You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: 0it [00:00, ?it/s]

You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: 0it [00:00, ?it/s]

Finished generation: 2


You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: 0it [00:00, ?it/s]

You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: 0it [00:00, ?it/s]

Finished generation: 3


You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: 0it [00:00, ?it/s]

You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: 0it [00:00, ?it/s]

Finished generation: 4


In [None]:
# List the keys of the GA results collected by the generation loop.
generated.keys()