In [171]:
import os
import pandas as pd
import numpy as np
import torch

from eugene.dataload._io import read_numpy, read

# Autoreload extension
if 'autoreload' not in get_ipython().extension_manager.loaded:
    %load_ext autoreload
%autoreload 2

# Basic import
import eugene as eu

eu.settings.dataset_dir = "/cellar/users/aklie/data/eugene/"

eu.__version__

'0.0.4'

In [186]:
import gzip
import shutil
from eugene.datasets._utils import try_download_urls
def killoran17(dataset="chr1", return_sdata=True, **kwargs):
    urls_list = [
        "https://hgdownload.soe.ucsc.edu/goldenPath/hg38/chromosomes/chr1.fa.gz"
    ]
    if dataset == "chr1":
        dataset = [0]
    paths = try_download_urls(dataset, urls_list, "killoran17")
    if dataset == [0]:
        paths = paths[0]
        print("Unzipping...")
        with gzip.open(paths, 'rb') as f_in:
            with open(paths[:-3], 'wb') as f_out:
                shutil.copyfileobj(f_in, f_out)
                paths = paths[:-3]
    if return_sdata:
        return eu.dl.read_fasta(paths, **kwargs)
    else:
        return paths

In [187]:
sdata = killoran17()

Dataset killoran17 chr1.fa.gz has already been downloaded.
Unzipping...


In [188]:
sdata

SeqData object with = 2489565 seqs
seqs = (2489565,)
names = (2489565,)
rev_seqs = None
ohe_seqs = None
ohe_rev_seqs = None
pos_annot: None
seqsm: None
uns: None

In [189]:
def downsample_sdata(sdata, n=None, frac=None, copy=False):
    sdata = sdata.copy() if copy else sdata
    if n is None and frac is None:
        raise ValueError("Must specify either n or frac")
    if n is not None and frac is not None:
        raise ValueError("Must specify either n or frac, not both")
    num_seqs = sdata.n_obs
    if n is not None:
        if n > num_seqs:
            raise ValueError("n must be less than or equal to the number of sequences")
        rand_idx = np.random.choice(num_seqs, n, replace=False)
        sdata = sdata[rand_idx]
    elif frac is not None:
        if frac > 1:
            raise ValueError("frac must be less than or equal to 1")
        rand_idx = np.random.choice(num_seqs, int(num_seqs * frac), replace=False)
        sdata = sdata[rand_idx]
    return sdata
      
def remove_only_N_seqs(seqs):
    return [seq for seq in seqs if not all([x == "N" for x in seq])]  

def remove_only_N_seqs_sdata(sdata, copy=False):
    sdata = sdata.copy() if copy else sdata
    N_only_mask = np.array([all([x == "N" for x in seq]) for seq in sdata.seqs])
    sdata = sdata[~N_only_mask]
    return sdata

def seq_len_sdata(sdata, copy=False):
    sdata = sdata.copy() if copy else sdata
    sdata.seqs_annot["seq_len"] = [len(seq) for seq in sdata.seqs]
    return sdata

In [190]:
sdata_downsampled = downsample_sdata(sdata, frac=0.01, copy=True)

In [191]:
remove_only_N_seqs_sdata(sdata_downsampled, copy=False)
eu.pp.sanitize_seqs_sdata(sdata_downsampled, copy=False)

SeqData object modified:
	seqs: ['ATTAGCATACTATATACTAATAGAATTAGCATACTATATACTAATAGAAT'
 'tttcactggcctagagagctcccctctggaggaccctacaactgcagggt'
 'ccagttccagaacagttaagctgaaacctgaaaagatgactaggattagc' ...
 'CCACGGGTCACATCCTGAGTTGTGCCGCATCCGCTTAGTGCAGCGTGTGC'
 'tgataataaatgagtttctctaggaatttttctttttgttcaggcactgt'
 'CAAGCGGCCTTAGTAAAAAAGAGAAGAAAAAATTTATAGAAAATGTTGCT'] -> 24895 seqs added


In [192]:
sdata_downsampled.seqs[:5]

array(['ATTAGCATACTATATACTAATAGAATTAGCATACTATATACTAATAGAAT',
       'TTTCACTGGCCTAGAGAGCTCCCCTCTGGAGGACCCTACAACTGCAGGGT',
       'CCAGTTCCAGAACAGTTAAGCTGAAACCTGAAAAGATGACTAGGATTAGC',
       'GTCTGTTTCTGATAAAAAGAGATACATGGAATGTTTTTATTTTAGCATTT',
       'GTTTTCGCGGTAAGAGTTGTTTAGGTAGTACAGATGCAATTTTCTTTTAT'], dtype='<U50')

In [193]:
seq_len_sdata(sdata_downsampled, copy=False)

SeqData object with = 24895 seqs
seqs = (24895,)
names = (24895,)
rev_seqs = None
ohe_seqs = None
ohe_rev_seqs = None
seqs_annot: 'seq_len'
pos_annot: None
seqsm: None
uns: None

In [194]:
sdata_downsampled["seq_len"].value_counts()

50    24895
Name: seq_len, dtype: int64

In [195]:
eu.pp.ohe_seqs_sdata(sdata_downsampled, copy=False)

One-hot encoding sequences:   0%|          | 0/24895 [00:00<?, ?it/s]

SeqData object modified:
	ohe_seqs: None -> 24895 ohe_seqs added


In [196]:
eu.pp.train_test_split_sdata(sdata_downsampled)

SeqData object modified:
    seqs_annot:
        + train_val


In [197]:
seq_len = 50
latent_dim = 128

In [198]:
import torch.nn as nn

In [199]:
class Flatten(torch.nn.Module):
    def forward(self, x):
        batch_size = x.shape[0]
        return x.view(batch_size, -1)
        
class View(nn.Module):
    def __init__(self, shape):
        super().__init__()
        self.shape = shape

    def __repr__(self):
        return f'View{self.shape}'

    def forward(self, input):
        '''
        Reshapes the input according to the shape saved in the view data structure.
        '''
        batch_size = input.size(0)
        shape = (batch_size, *self.shape)
        out = input.view(shape)
        return out

In [200]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, residual_multiplier=1):
        super(ResidualBlock, self).__init__()
        self.residual_multiplier = residual_multiplier
        self.relu = nn.ReLU()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding)
        
    def forward(self, x):
        residual = x
        out = self.relu(x)
        out = self.conv(out)
        out = out*self.residual_multiplier + residual
        return out

In [201]:
class Generator(nn.Module):
    def __init__(self, latent_dim, seq_len):
        super(Generator, self).__init__()
        self.latent_dim = latent_dim
        self.seq_len = seq_len
        self.linear = nn.Linear(latent_dim, seq_len*100)
        self.elu = nn.ELU()
        self.view = View((100, seq_len))
        self.res_blocks = nn.Sequential(*[ResidualBlock(100, 100, 3, 1, 1)]*5)
        self.conv = nn.Conv1d(100, 4, 1)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        out = self.linear(x) 
        out = self.elu(out) 
        out = self.view(out) 
        out = self.res_blocks(out) 
        out = self.conv(out)
        out = self.softmax(out)
        return out

In [202]:
gen = Generator(latent_dim, seq_len)

In [203]:
z = torch.Tensor(np.random.normal(0, 1, (10, latent_dim)))
z.shape

torch.Size([10, 128])

In [204]:
fake = gen(z)

In [205]:
fake_tokens = np.argmax(fake.detach().numpy(), axis=1).reshape(10, 50)
eu.pp.decode_seq(eu.pp._utils._token2one_hot(fake_tokens[0]))

'ATTTTCCATATAGGTCGTACCCGTTTTCTCCTATTTCATGTCCCTTTCTC'

In [206]:
class Discriminator(nn.Module):
    def __init__(self, seq_len):
        super(Discriminator, self).__init__()
        self.seq_len = seq_len
        self.conv = nn.Conv1d(4, 100, 1)
        self.res_blocks = nn.Sequential(*[ResidualBlock(100, 100, 3, 1, 1)]*5)
        self.view = View((100*seq_len,))
        self.linear = nn.Linear(100*seq_len, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        out = self.conv(x)
        out = self.res_blocks(out)
        out = self.view(out)
        out = self.linear(out)
        out = self.sigmoid(out)
        return out

In [207]:
disc = Discriminator(seq_len) 

In [208]:
disc(fake)

tensor([[0.2476],
        [0.2639],
        [0.2641],
        [0.2184],
        [0.2509],
        [0.2672],
        [0.2206],
        [0.2475],
        [0.2486],
        [0.2420]], grad_fn=<SigmoidBackward0>)

In [164]:
model = eu.models.GAN(
    seq_len=50,
    latent_dim=128, 
    generator=gen, 
    discriminator=disc,
    mode="gan",
    grad_clip=0.01,
    gen_lr=1e-3,
    disc_lr=1e-3
)

In [209]:
z = torch.Tensor(np.random.normal(0, 1, (10, latent_dim)))
z.shape

torch.Size([10, 128])

In [210]:
model.discriminator(model(z))

tensor([[0.5013],
        [0.5014],
        [0.5012],
        [0.5019],
        [0.5018],
        [0.5015],
        [0.5016],
        [0.5011],
        [0.5016],
        [0.5018]])

In [211]:
eu.settings.batch_size = 128
eu.settings.dl_num_workers = 0
eu.settings.dl_pin_memory_gpu_training = True
eu.train.fit(
    model=model,
    sdata=sdata_downsampled,
    mode="wgan"
    epochs=5,
    gpus=1,
    model_checkpoint_callback=False,
    early_stopping_callback=False,
    log_dir="/cellar/users/aklie/projects/EUGENe/tests/notebooks/models/GAN",
    name="test"
)

Global seed set to 13
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


No transforms given, assuming just need to tensorize.
No transforms given, assuming just need to tensorize.


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Set SLURM handle signals.

  | Name          | Type          | Params
------------------------------------------------
0 | generator     | Generator     | 675 K 
1 | discriminator | Discriminator | 35.6 K
------------------------------------------------
0         Trainable params
711 K     Non-trainable params
711 K     Total params
2.844     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  f"The dataloader, {name}, does not have many workers which may be a bottleneck."
Global seed set to 13
  f"The dataloader, {name}, does not have many workers which may be a bottleneck."


Training: 0it [00:00, ?it/s]

  f"One of the returned values {set(extra.keys())} has a `grad_fn`. We will detach it automatically"


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

In [214]:
from eugene.plot._utils import tflog2pandas

In [216]:
logs = tflog2pandas("/cellar/users/aklie/projects/EUGENe/tests/notebooks/models/GAN/test/version_1")

In [220]:
logs[logs["metric"] == "train_discriminator_loss"]

Unnamed: 0,metric,value,step
55,train_discriminator_loss,0.695262,49.0
56,train_discriminator_loss,0.695445,99.0
57,train_discriminator_loss,0.695528,149.0
58,train_discriminator_loss,0.695282,199.0
59,train_discriminator_loss,0.695296,249.0
60,train_discriminator_loss,0.695336,299.0
61,train_discriminator_loss,0.695307,349.0
62,train_discriminator_loss,0.695351,399.0
63,train_discriminator_loss,0.695502,449.0
64,train_discriminator_loss,0.695396,499.0


In [212]:
def generate_seqs(model, num_seqs):
    z = torch.Tensor(np.random.normal(0, 1, (num_seqs, model.latent_dim)))
    fake = model(z)
    print(fake.shape)
    fake_tokens = np.argmax(fake.detach().numpy(), axis=1).reshape(num_seqs, 50)
    print(fake_tokens.shape)
    return np.array([eu.pp.decode_seq(eu.pp._utils._token2one_hot(tokens)) for tokens in fake_tokens])

In [213]:
generate_seqs(model, 10)

torch.Size([10, 4, 50])
(10, 50)


array(['TTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTT',
       'TTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTT',
       'TTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTATTTTTTTTTTTTTTTTTT',
       'TTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTT',
       'TTTTATTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTT',
       'TTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTT',
       'TTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTT',
       'TTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTT',
       'TTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTT',
       'TTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTATTTTTTTTTTTTT'], dtype='<U50')

---