In [1]:
import os
import pandas as pd
import numpy as np
import torch

from eugene.dataload._io import read_numpy, read

# Autoreload extension
if 'autoreload' not in get_ipython().extension_manager.loaded:
    %load_ext autoreload
%autoreload 2

# Basic import
import eugene as eu

eu.settings.dataset_dir = "/cellar/users/aklie/data/eugene/killoran17"

eu.__version__

Global seed set to 13


GPU is available: True
Number of GPUs: 2
Current GPU: 0
GPUs: Quadro RTX 5000


'0.0.4'

In [2]:
sdata = eu.dl.read_fasta(os.path.join(eu.settings.dataset_dir, "chr1.fa"))

In [3]:
def downsample_sdata(sdata, n=None, frac=None, copy=False):
    sdata = sdata.copy() if copy else sdata
    if n is None and frac is None:
        raise ValueError("Must specify either n or frac")
    if n is not None and frac is not None:
        raise ValueError("Must specify either n or frac, not both")
    num_seqs = sdata.n_obs
    if n is not None:
        if n > num_seqs:
            raise ValueError("n must be less than or equal to the number of sequences")
        rand_idx = np.random.choice(num_seqs, n, replace=False)
        sdata = sdata[rand_idx]
    elif frac is not None:
        if frac > 1:
            raise ValueError("frac must be less than or equal to 1")
        rand_idx = np.random.choice(num_seqs, int(num_seqs * frac), replace=False)
        sdata = sdata[rand_idx]
    return sdata
      
def remove_only_N_seqs(seqs):
    return [seq for seq in seqs if not all([x == "N" for x in seq])]  

def remove_only_N_seqs_sdata(sdata, copy=False):
    sdata = sdata.copy() if copy else sdata
    N_only_mask = np.array([all([x == "N" for x in seq]) for seq in sdata.seqs])
    sdata = sdata[~N_only_mask]
    return sdata

def seq_len_sdata(sdata, copy=False):
    sdata = sdata.copy() if copy else sdata
    sdata.seqs_annot["seq_len"] = [len(seq) for seq in sdata.seqs]
    return sdata

In [4]:
sdata_downsampled = downsample_sdata(sdata, frac=0.1, copy=True)

In [5]:
remove_only_N_seqs_sdata(sdata_downsampled, copy=False)
eu.pp.sanitize_seqs_sdata(sdata_downsampled, copy=False)

SeqData object modified:
	seqs: ['ATTAGCATACTATATACTAATAGAATTAGCATACTATATACTAATAGAAT'
 'tttcactggcctagagagctcccctctggaggaccctacaactgcagggt'
 'ccagttccagaacagttaagctgaaacctgaaaagatgactaggattagc' ...
 'GTAAAGAAGTATCCCCTCCCAGAAACATTTACTTCAAGTGAGTTAGTCAA'
 'CACTCTGCTCAGAGTCAGATACAGATAGAGCTGTTTTTGTTTTTATTTTT'
 'ggttctctgagtatatacgtaaattttagtatccagaccttttattttga'] -> 248956 seqs added


In [6]:
sdata_downsampled.seqs[:5]

array(['ATTAGCATACTATATACTAATAGAATTAGCATACTATATACTAATAGAAT',
       'TTTCACTGGCCTAGAGAGCTCCCCTCTGGAGGACCCTACAACTGCAGGGT',
       'CCAGTTCCAGAACAGTTAAGCTGAAACCTGAAAAGATGACTAGGATTAGC',
       'GTCTGTTTCTGATAAAAAGAGATACATGGAATGTTTTTATTTTAGCATTT',
       'GTTTTCGCGGTAAGAGTTGTTTAGGTAGTACAGATGCAATTTTCTTTTAT'], dtype='<U50')

In [7]:
seq_len_sdata(sdata_downsampled, copy=False)

SeqData object with = 248956 seqs
seqs = (248956,)
names = (248956,)
rev_seqs = None
ohe_seqs = None
ohe_rev_seqs = None
seqs_annot: 'seq_len'
pos_annot: None
seqsm: None
uns: None

In [8]:
sdata_downsampled["seq_len"].value_counts()

50    248956
Name: seq_len, dtype: int64

In [9]:
eu.pp.ohe_seqs_sdata(sdata_downsampled, copy=False)

One-hot encoding sequences:   0%|          | 0/248956 [00:00<?, ?it/s]

SeqData object modified:
	ohe_seqs: None -> 248956 ohe_seqs added


In [10]:
eu.pp.train_test_split_sdata(sdata_downsampled)

SeqData object modified:
    seqs_annot:
        + train_val


In [11]:
seq_len = 50
latent_dim = seq_len * 4

In [12]:
generator = torch.nn.Sequential(
    eu.models.base.BasicFullyConnectedModule(
        input_dim=500,
        output_dim=latent_dim,
        hidden_dims=[300]
    ),
    torch.nn.Tanh()
)

In [13]:
disc_batchnorm = True
disc_conv_dropout = 0
disc_fc_dropout = 0
discriminator = torch.nn.Sequential(
    eu.models.DeepBind(
        input_len=latent_dim,
        output_dim=1,
        pool_width=8,
        conv_kwargs=dict(channels=[4, 128], conv_kernels=[16], dropout_rates=disc_conv_dropout, batchnorm=disc_batchnorm),
        fc_kwargs=dict(hidden_dims=[256, 64], dropout_rate=disc_fc_dropout, batchnorm=disc_batchnorm),
        mode="dna",
    ),
    torch.nn.Sigmoid()
)

In [14]:
model = eu.models.GAN(
    latent_dim=latent_dim, 
    generator=generator, 
    discriminator=discriminator,
    grad_clip=0.01,
    gen_lr=1e-3,
    disc_lr=1e-3,
)

In [21]:
!export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/opt/miniconda3/lib/
eu.settings.batch_size = 256
eu.settings.dl_num_workers = 3
eu.settings.dl_pin_memory_gpu_training = True
eu.train.fit(
    model=model,
    sdata=sdata_downsampled,
    model_checkpoint_callback=False,
    early_stopping_callback=False,
    gpus=1
)

Global seed set to 13


No transforms given, assuming just need to tensorize.
No transforms given, assuming just need to tensorize.


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Set SLURM handle signals.

  | Name          | Type       | Params
---------------------------------------------
0 | generator     | Sequential | 210 K 
1 | discriminator | Sequential | 157 K 
---------------------------------------------
367 K     Trainable params
0         Non-trainable params
367 K     Total params
1.470     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Global seed set to 13


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

MisconfigurationException: ReduceLROnPlateau conditioned on metric val_discriminator_loss which is not available. Available metrics are: ['train_generator_loss', 'train_discriminator_loss', 'val_generator_loss', 'val_generator_loss_epoch']. Condition can be set using `monitor` key in lr scheduler dict

---