In [None]:
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import Image
from IPython.core.debugger import set_trace

# HoneyDew
Here we present HoneyDew, a network for predicting genomic structural features from epigenetic data. In keeping with the "Muppet" naming convention for networks, HoneyDew incorporates self-attention mechanisms to effectively predict higher-order genome conformation structures such as A/B compartments and TADs.

# Experiments

## Data Assembly
First we gather some data for our experiments. We start by pulling the epigenetic data from Rozenwald et al., "A machine learning framework for the prediction of chromatin folding in Drosophila using epigenetic features": https://peerj.com/articles/cs-307/

This dataset contains 18 epigenetic track features for three cell lines (S2, KC and BG3), together with the corresponding Armatus gamma values.



In [None]:
!mkdir Data
!mkdir Data/Drosophilla
!wget -P Data/Drosophilla https://raw.githubusercontent.com/MichalRozenwald/Hi-ChIP-ML/master/data/epigenetics/s2_kc_bg_scaled_18_features_2901.csv
!wget -P Data/Drosophilla https://raw.githubusercontent.com/MichalRozenwald/Hi-ChIP-ML/master/data/target/s2_kc_bg_clean_gamma_2901.csv


We can take a brief look at this data using pandas.

In [None]:
import pandas as pd
# Model inputs (epigenetic feature table) and targets (the file name indicates
# the Armatus gamma values); both are plain comma-separated files.
df_feat = pd.read_csv("Data/Drosophilla/s2_kc_bg_scaled_18_features_2901.csv",
                      sep=',')
df_label = pd.read_csv("Data/Drosophilla/s2_kc_bg_clean_gamma_2901.csv",
                       sep=',')

# Each caption names the frame whose columns are printed next to it.
print("Features:", str(df_feat.columns))
print("Label:", str(df_label.columns))




We notice that the label dataframe (`df_label`) currently has a column called 'gamma';
this is the gamma value described in the Armatus paper:
*https://almob.biomedcentral.com/articles/10.1186/1748-7188-9-14*

While we could work with gamma alone, we are also interested in whether epigenetic data can predict other meaningful downstream metrics, such as the insulation score and the A/B compartment vector. Before running experiments, we therefore extend the dataset with some of these metrics. Computing their labels requires the Hi-C data itself, so we download the Hi-C contact maps next.

In [None]:
!mkdir Data/Drosophilla/HiC_Maps
!wget -P Data/Drosophilla/HiC_Maps ftp://ftp.ncbi.nlm.nih.gov/geo/series/GSE69nnn/GSE69013/suppl/GSE69013_BG3_merged_IC-heatmap-20K.txt.gz
!wget -P Data/Drosophilla/HiC_Maps ftp://ftp.ncbi.nlm.nih.gov/geo/series/GSE69nnn/GSE69013/suppl/GSE69013_KC_merged_IC-heatmap-20K.txt.gz
!wget -P Data/Drosophilla/HiC_Maps ftp://ftp.ncbi.nlm.nih.gov/geo/series/GSE69nnn/GSE69013/suppl/GSE69013_S2_merged_IC-heatmap-20K.txt.gz
!gunzip Data/Drosophilla/HiC_Maps/GSE69013_BG3_merged_IC-heatmap-20K.txt.gz
!gunzip Data/Drosophilla/HiC_Maps/GSE69013_KC_merged_IC-heatmap-20K.txt.gz
!gunzip Data/Drosophilla/HiC_Maps/GSE69013_S2_merged_IC-heatmap-20K.txt.gz

cell_line_hics = [
    "Data/Drosophilla/HiC_Maps/GSE69013_S2_merged_IC-heatmap-20K.txt",
    "Data/Drosophilla/HiC_Maps/GSE69013_KC_merged_IC-heatmap-20K.txt",
    "Data/Drosophilla/HiC_Maps/GSE69013_BG3_merged_IC-heatmap-20K.txt",
]

cell_lines = ['S2','KC','BG']

In [None]:
#!mkdir Other_Tools/Juicer
#!git clone https://github.com/aidenlab/juicer.git Other_Tools/Juicer
#!mkdir Data/Drosophilla/Fastq

#!prefetch SRR2032292
#!fastq-dump SRR2032292 --split-files -O Data/Drosophilla/Fastq
#!wget -P Data/Drosophilla http://igenomes.illumina.com.s3-website-us-east-1.amazonaws.com/Drosophila_melanogaster/UCSC/dm3/Drosophila_melanogaster_UCSC_dm3.tar.gz 
#!tar -zxvf Data/Drosophilla/Drosophila_melanogaster_UCSC_dm3.tar.gz -C Data/Drosophilla
!bwa mem -SP5M -t 8 Data/Drosophilla/Drosophila_melanogaster/UCSC/dm3/Sequence/BWAIndex/genome.fa Data/Drosophilla/Fastq/SRR2032292_1.fastq  Data/Drosophilla/Fastq/SRR2032292_2.fastq  > test.sam

# TODO Remove chimeric reads
# TODO Sort
# TODO remove dups
# TODO juicer pre
# extract eigen





In [None]:
# List the downloaded FASTQ files. The directory only exists once the
# commented-out fastq-dump step of the earlier raw-read cell has been run.
!ls Data/Drosophilla/Fastq


# TAD 

Using the downloaded Hi-C data, we extract per-chromosome contact matrices from which TAD metrics can be computed.

In [None]:
cell_line_hics = [
    "Data/Drosophilla/HiC_Maps/GSE69013_S2_merged_IC-heatmap-20K.txt",
    "Data/Drosophilla/HiC_Maps/GSE69013_KC_merged_IC-heatmap-20K.txt",
    "Data/Drosophilla/HiC_Maps/GSE69013_BG3_merged_IC-heatmap-20K.txt",
]

cell_lines = ['S2','KC','BG']
# Chromosomes removed by name: the epigenetic data is not used for chromosome 4.
EXCLUDED_CHROS = ['4']
chro_mats = {}   # (cell line, chromosome) -> intra-chromosomal contact matrix


def getChro(x):
    """Return the chromosome part of a bin label of the form '<chro>:<...>'."""
    return x.split(":")[0]


for cell_line, fn in zip(cell_lines, cell_line_hics):
    # Skip the first line. Column 0 holds the bin label; the remaining
    # columns hold that row's contact values.
    hic  = np.loadtxt(fn, dtype=str, skiprows=1)
    mat  = hic[:, 1:].astype(float)
    bins = hic[:, 0]
    np.save("bins.npy", bins)   # overwritten for each cell line; the last one is kept
    chro_at_bins = np.array([getChro(label) for label in bins])
    chros = np.unique(chro_at_bins)
    # Filter by name. Deleting sorted position 4 only removes '4' when the set
    # of chromosome names happens to sort that way.
    chros = chros[~np.isin(chros, EXCLUDED_CHROS)]
    for chro in chros:
        # Square sub-block of the rows and columns whose bins lie on `chro`.
        # (np.delete with a boolean mask misbehaves on NumPy < 1.19.)
        on_chro = chro_at_bins == chro
        chro_mats[cell_line, chro] = mat[np.ix_(on_chro, on_chro)]


## Insulation Scores
We determine insulation scores from the Hi-C data. To do this, we first build a helper class, `computeInsulation`, that computes insulation.

In [None]:
import torch
class computeInsulation(torch.nn.Module):
    """Insulation-score style boundary caller for a batch of square contact maps.

    Args:
        window_radius: half-width of the square averaging window; the window is
            (2 * window_radius + 1) bins on a side.
        deriv_size: number of insulation values averaged on each side when forming
            the difference vector ``dv``.
    """
    def __init__(self, window_radius=10, deriv_size=10):
        super().__init__()
        self.window_radius = window_radius
        self.deriv_size  = deriv_size
        self.di_pool     = torch.nn.AvgPool2d(kernel_size=(2*window_radius+1), stride=1)
        self.top_pool    = torch.nn.AvgPool1d(kernel_size=deriv_size, stride=1)
        self.bottom_pool = torch.nn.AvgPool1d(kernel_size=deriv_size, stride=1)

    def forward(self, x):
        """
        Args:
            x: contact maps; the pooling and the diagonal over dims 2 and 3 imply
               shape (batch, channels, N, N).

        Returns:
            iv: (batch, channels, N - 2*window_radius) log2 of the windowed mean
                along the diagonal, relative to its overall mean.
            dv: (batch, channels, N - 2*window_radius - 2*deriv_size + 1) mean of
                the next ``deriv_size`` iv values minus the mean of the current ones.
            boundaries: one 1-D index tensor per batch item (channel 0 only),
                offset back into matrix coordinates.
        """
        iv     = self.di_pool(x)
        iv     = torch.diagonal(iv, dim1=2, dim2=3)
        iv     = torch.log2(iv / torch.mean(iv))
        top    = self.top_pool(iv[:, :, self.deriv_size:])
        bottom = self.bottom_pool(iv[:, :, :-self.deriv_size])
        dv     = top - bottom
        # Zero padding created with dv's own dtype and device, so GPU and
        # float64 inputs concatenate cleanly.
        pad    = dv.new_zeros(dv.shape[0], dv.shape[1], 2)
        left   = torch.cat([pad, dv], dim=2)
        right  = torch.cat([dv, pad], dim=2)
        # After trimming, band[k] is True when dv[k] < 0 and dv[k + 2] > 0.
        band   = ((left < 0) & (right > 0))[:, :, 2:-2]
        boundaries = [torch.where(band[i, 0])[0] + self.window_radius + self.deriv_size
                      for i in range(band.shape[0])]
        return iv, dv, boundaries




Now we compute insulation using a few different window radii to capture different ranges of TAD sizes.

In [None]:
cell_line_hics = [
    "Data/Drosophilla/HiC_Maps/GSE69013_S2_merged_IC-heatmap-20K.txt",
    "Data/Drosophilla/HiC_Maps/GSE69013_KC_merged_IC-heatmap-20K.txt",
    "Data/Drosophilla/HiC_Maps/GSE69013_BG3_merged_IC-heatmap-20K.txt",
]
cell_lines = ['S2','KC','BG']

# Window radii (in bins) of the insulation square; one label column per radius.
win_radii  = [3,5,10]
insul_vecs = {}   # (cell line, radius, chromosome) -> log2 insulation vector
dv_vecs    = {}   # (cell line, radius, chromosome) -> insulation difference vector

for radius in win_radii:
    caller = computeInsulation(window_radius=radius)
    for line in cell_lines:
        for chro_name in chros:
            # Add batch and channel axes, (N, N) -> (1, 1, N, N), as float32.
            contact = torch.from_numpy(chro_mats[line, chro_name])[None, None].to(dtype=torch.float32)
            iv_out, dv_out, _ = caller(contact)
            insul_vecs[line, radius, chro_name] = iv_out[0, 0, :].numpy()
            dv_vecs[line, radius, chro_name]    = dv_out[0, 0, :].numpy()


In [None]:
# Sanity-check the vector lengths for one chromosome, then peek at the first dv values.
dv_2L_r10 = dv_vecs[('S2', 10, '2L')]
iv_2L_r10 = insul_vecs[('S2', 10, '2L')]
print(dv_2L_r10.shape)
print(iv_2L_r10.shape)
dv_vecs[('S2', 5, '2L')][:10]


Now that we have computed insulation scores for each chromosome, we add them to our label dataframe.

In [None]:
# Mirrors computeInsulation's default deriv_size, which widens the dv padding
# beyond the iv padding.
DERIV_SIZE = 10

for win_radius in win_radii:
    insulation_column = []
    dv_column         = []
    # Columns are built cell line by cell line, chromosome by chromosome. Each
    # per-chromosome vector is framed with NaNs (rows with NaN are dropped in
    # the cleaning step).
    for cl, chro in [(line, name) for line in cell_lines for name in chros]:
        insulation_column += [np.nan] * win_radius
        insulation_column += list(insul_vecs[cl, win_radius, chro])
        insulation_column += [np.nan] * (win_radius - 1)

        dv_pad = win_radius + DERIV_SIZE - 1
        dv_column += [np.nan] * dv_pad
        dv_column += list(dv_vecs[cl, win_radius, chro])
        dv_column += [np.nan] * dv_pad

    # NOTE: the 'insulation_<r>' column is filled from the difference vector dv;
    # insulation_column is built but not stored.
    df_label['insulation_'+str(win_radius)] = dv_column
print(df_label.columns)


# Directionality Index
First we build a helper class for computing the directionality index, as described in Dixon et al. 2012: https://www.nature.com/articles/nature11082

In [None]:
import torch
from torch.nn import functional as F
class computeDirectionality(torch.nn.Module):
    """Directionality index (Dixon et al. 2012) along the diagonal of a contact map.

    For bin i + radius, A sums its contacts with the ``radius`` preceding bins and
    B with the ``radius`` following bins. The index is
        DI = sign(B - A) * ((A - E)^2 / E + (B - E)^2 / E),   E = (A + B) / 2.

    Args:
        radius: number of neighbouring bins summed on each side.
    """
    def __init__(self,
                 radius=2
                ):
        # Required so the module can be invoked as module(x), not only .forward(x).
        super().__init__()
        size = 2 * radius + 1
        up   = torch.zeros((size, size))
        down = torch.zeros((size, size))
        # Single-column kernels in the centre column: rows above the centre (up)
        # and rows below it (down).
        up[:radius, radius]       = 1
        down[radius + 1:, radius] = 1
        # Registered as buffers: not trained, but moved along by module.to(...).
        self.register_buffer("up", up[None, None])
        self.register_buffer("down", down[None, None])

    def forward(self, x):
        """
        Args:
            x: contact map(s) of shape (batch, 1, N, N); the single-channel kernels
               require exactly one input channel.

        Returns:
            A tensor of shape (N - 2*radius,) when batch == 1, else
            (batch, N - 2*radius). Entry k is the DI of bin k + radius. Bins with
            A == B == 0 give NaN (0 / 0).
        """
        up   = self.up.to(dtype=x.dtype, device=x.device)
        down = self.down.to(dtype=x.dtype, device=x.device)
        a    = F.conv2d(x, up)     # contacts with preceding bins
        b    = F.conv2d(x, down)   # contacts with following bins
        e    = (a + b) / 2
        sign = torch.sign(b - a)
        term = (((a - e) ** 2) / e) + (((b - e) ** 2) / e)
        di   = sign * term
        # Diagonal of every map in the batch: (batch, N - 2*radius).
        di_vec = torch.diagonal(di[:, 0], dim1=-2, dim2=-1)
        return di_vec[0] if di_vec.shape[0] == 1 else di_vec


Now we apply this class to each cell line, using a few different window radii.

In [None]:
direction_vecs = {}   # (cell line, radius, chromosome) -> directionality vector

for radius in win_radii:
    di_computer = computeDirectionality(radius=radius)
    for line in cell_lines:
        for chro_name in chros:
            # (N, N) -> (1, 1, N, N) float32 input for the conv-based DI.
            contact = torch.from_numpy(chro_mats[line, chro_name])[None, None]
            di_vec = di_computer.forward(contact.to(dtype=torch.float32))
            direction_vecs[line, radius, chro_name] = di_vec.numpy()
        

Now we add these directionality vectors to the label dataframe.

In [None]:
for win_radius in win_radii:
    direction_column = []
    # Built cell line by cell line, chromosome by chromosome: win_radius leading
    # and win_radius - 1 trailing NaNs around each chromosome's vector.
    for cl, chro in [(line, name) for line in cell_lines for name in chros]:
        direction_column += [np.nan] * win_radius
        direction_column += list(direction_vecs[cl, win_radius, chro])
        direction_column += [np.nan] * (win_radius - 1)
    df_label['directionality_' + str(win_radius)] = direction_column
print(df_label.columns)

## Clean the TAD Labels

In [None]:
# Drop every bin whose label row contains NaN or +/-inf (e.g. the NaN framing
# of the TAD-metric columns, or log2(0) and 0/0 in their computation). The same
# mask is applied to both frames so features and labels stay row-aligned.
bad_rows = df_label.isin([np.nan, np.inf, -np.inf]).any(axis=1)
clean_features_df = df_feat[~bad_rows]
clean_labels_df   = df_label[~bad_rows]

assert len(clean_features_df) == len(clean_labels_df), "features/labels out of alignment"

print(clean_features_df.shape)
print(clean_labels_df.shape)
clean_labels_df.to_csv("Data/Drosophilla/clean_labels.csv")
clean_features_df.to_csv("Data/Drosophilla/clean_features.csv")

# Dataset Loader
Now that we have labels assembled for a variety of TAD-identifying metrics, we use a PyTorch Lightning `LightningDataModule` to create the dataloaders.


In [None]:
from pytorch_lightning import LightningDataModule
from torch.utils.data import Dataset, DataLoader
from sklearn import preprocessing

#TODO CHECK THIS
# Legacy positional index of a label column; only referenced by commented-out code.
INSULATION_30 = 9

class FlyDataModule(LightningDataModule):
    """Sliding-window epigenetic features and one label track for one cell line.

    Windows of ``2 * data_win_radius + 1`` consecutive rows are cut from
    ``Data/Drosophilla/clean_features.csv`` and ``clean_labels.csv``. They are
    split in order (no shuffling) into train (first 70 %), val (next 15 %) and
    test (last 15 %).

    Args:
        cell_line: "S2", "KC" or "BG".
        data_win_radius: window half-width in rows (bins).
        batch_size: batch size of every DataLoader.
        label_type: "gamma", "insulation" or "directionality".
        label_val: radius selecting the "<label_type>_<label_val>" column
            (e.g. 10 -> "insulation_10"); ignored for "gamma".
        num_workers: DataLoader worker processes (default 8).
    """

    class FlyDataset(Dataset):
        """Windowed (features, labels) pairs for one split.

        ``tvt`` is "train", "val", "test" or "mini" (the first 100 windows); any
        other value keeps all windows.
        """

        # Column of the 0/1 cell-line membership flag in both cleaned CSVs.
        _CELL_LINE_COLS = {"S2": 2, "KC": 3, "BG": 4}
        # Fixed label positions, used only when a column cannot be found by name
        # in the CSV header. NOTE(review): these appear to assume radii 10/20/30,
        # while the label-building cells use win_radii = [3, 5, 10].
        _LEGACY_LABEL_COLS = {
            ("insulation", 10): 8, ("insulation", 20): 9, ("insulation", 30): 10,
            ("directionality", 10): 11, ("directionality", 20): 12,
            ("directionality", 30): 13,
        }
        _LEGACY_GAMMA_COL = 7
        # Columns before this one are index/metadata, not features.
        _FIRST_FEATURE_COL = 7

        def _get_cell_line_idx(self, strin):
            """Return the CSV column holding the membership flag of cell line `strin`.

            Raises:
                ValueError: for an unknown cell line name.
            """
            if strin not in self._CELL_LINE_COLS:
                raise ValueError("Unknown cell line: {!r}".format(strin))
            return self._CELL_LINE_COLS[strin]

        def _get_label_col(self, label_type, label_val):
            """Return the CSV column index of the requested label track.

            The column is looked up by name in ``self.label_head`` ("gamma" or
            "<label_type>_<label_val>"); the legacy fixed positions are used only
            when that name is absent.

            Raises:
                ValueError: if the label cannot be resolved either way.
            """
            header = [str(h).strip() for h in getattr(self, "label_head", ())]
            if label_type == "gamma":
                name = "gamma"
            else:
                name = str(label_type) + "_" + str(label_val)
            if name in header:
                return header.index(name)
            if label_type == "gamma":
                return self._LEGACY_GAMMA_COL
            legacy = self._LEGACY_LABEL_COLS.get((label_type, label_val))
            if legacy is None:
                raise ValueError("Unknown label {!r} (label_val={!r})".format(label_type, label_val))
            return legacy

        def __init__(self,
                    cell_line,
                    tvt,
                    data_win_radius,
                    label_type,
                    label_val):
            self.cell_line       = cell_line
            self.data_win_radius = data_win_radius
            self.tvt             = tvt
            self.label_type      = label_type
            self.label_val       = label_val

            FEATURE_STRING = "Data/Drosophilla/clean_features.csv"
            LABEL_STRING   = "Data/Drosophilla/clean_labels.csv"

            cell_line_idx = self._get_cell_line_idx(cell_line)

            # Read both CSVs as text: row 0 is the header and column 0 the index
            # written by DataFrame.to_csv.
            features = np.loadtxt(FEATURE_STRING, delimiter=',', dtype=str)
            labels   = np.loadtxt(LABEL_STRING, delimiter=',', dtype=str)

            self.feature_head = features[0]
            self.label_head   = labels[0]
            features          = features[1:]
            labels            = labels[1:]

            # Keep only the rows flagged (== 1) for this cell line.
            features = features[features[:, cell_line_idx].astype(int) == 1]
            labels   = labels[labels[:, cell_line_idx].astype(int) == 1]
            features = features[:, self._FIRST_FEATURE_COL:].astype(float)

            label_idx   = self._get_label_col(self.label_type, self.label_val)
            self.labels = labels[:, label_idx].astype(float)

            # Standardise each feature column to zero mean and unit variance.
            # NOTE(review): the statistics use every row of the cell line,
            # including rows that end up in val/test.
            self.features = preprocessing.scale(features,
                                                axis=0,
                                                with_mean=True,
                                                with_std=True)

            self.feature_vecs, self.label_vecs = self._make_windows()
            self._apply_split()

        def _make_windows(self):
            """Cut every complete window of 2*r+1 rows (windows that would run past
            either end are skipped). Labels get a trailing axis of size 1.
            NOTE(review): windows may span the junction between chromosomes."""
            r = self.data_win_radius
            centres = range(r, self.features.shape[0] - r)
            feature_vecs = [self.features[c - r:c + r + 1] for c in centres]
            label_vecs   = [self.labels[c - r:c + r + 1] for c in centres]
            if not feature_vecs:
                raise ValueError("Cell line {!r} has too few rows for data_win_radius={}"
                                 .format(self.cell_line, r))
            return np.array(feature_vecs), np.expand_dims(np.array(label_vecs), axis=2)

        def _apply_split(self):
            """Restrict the windows to the contiguous slice selected by ``self.tvt``."""
            n_windows = self.feature_vecs.shape[0]
            splits = {
                'mini':  slice(0, 100),
                'train': slice(None, int(.7 * n_windows)),
                'val':   slice(int(.7 * n_windows), int(.85 * n_windows)),
                'test':  slice(int(.85 * n_windows), None),
            }
            part = splits.get(self.tvt, slice(None))
            self.feature_vecs = self.feature_vecs[part]
            self.label_vecs   = self.label_vecs[part]

        def __len__(self):
            return self.feature_vecs.shape[0]

        def __getitem__(self, idx):
            return self.feature_vecs[idx], self.label_vecs[idx]

    def __init__(self,
                 cell_line,
                 data_win_radius,
                 batch_size,
                 label_type,
                 label_val='na',
                 num_workers=8):
        super().__init__()
        self.batch_size      = batch_size
        self.cell_line       = cell_line
        self.data_win_radius = data_win_radius
        self.label_type      = label_type
        self.label_val       = label_val
        self.num_workers     = num_workers

    def _make_dataset(self, tvt):
        """Build the FlyDataset for split `tvt` with this module's settings."""
        return self.FlyDataset(cell_line=self.cell_line,
                               tvt=tvt,
                               data_win_radius=self.data_win_radius,
                               label_type=self.label_type,
                               label_val=self.label_val)

    def setup(self, stage=None):
        """Build the train/val/test datasets.

        ``stage`` is accepted because Lightning's Trainer calls ``setup(stage=...)``;
        it is ignored and all three splits are always built.
        """
        self.train = self._make_dataset("train")
        self.val   = self._make_dataset("val")
        self.test  = self._make_dataset("test")
        print("Everything set")

    def train_dataloader(self):
        return DataLoader(self.train,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers)
    
    
        
            

In [None]:
# Data module for S2 windows with the insulation label (label_val=10).
dm_kwargs = dict(cell_line="S2",
                 data_win_radius=10,
                 batch_size=4,
                 label_type="insulation",
                 label_val=10)
dm = FlyDataModule(**dm_kwargs)
dm.setup()

In [None]:
# Preview a few training windows. The top panel shows the feature window as an
# image (rows = positions in the window), the bottom panel the label track.
# Iterating the whole loader would draw one figure per batch, so stop early.
N_PREVIEW_BATCHES = 3

for b, batch in enumerate(dm.train_dataloader()):
    if b >= N_PREVIEW_BATCHES:
        break
    feature, label = batch
    feature = feature.float()
    label   = label.float()
    if b == 0:
        print(feature.shape)
    fig, ax = plt.subplots(2)
    ax[0].imshow(feature[0], vmin=0, vmax=5)
    ax[0].set_title("features, batch {} sample 0".format(b))
    ax[1].plot(label[0])
    ax[1].set_ylim(-.05,0.5)
    ax[1].set_xlabel("position in window")
    plt.show()

# Models
Now we show a few of the models we are interested in. The paper from which we derived our epigenetic dataset
(https://peerj.com/articles/cs-307/) used a bidirectional LSTM model, as well as regression and gradient boosting. We expand the list of models to test to include:
- Gated Recurrent Units (GRUs)
- Elman Recurrent Neural Networks

We also develop HoneyDew, a transformer-based network.


In [None]:
import sys
sys.path.append("../")
sys.path.append(".")
import pytorch_lightning as pl
import torch
import pdb
from torch import nn
from torch.nn import RNN
from torch.nn import functional as F

class RNNModule(pl.LightningModule):
    """Elman RNN (ReLU non-linearity) that regresses a value for every step of
    an input window.

    Args:
        lr: optimizer learning rate.
        input_size: features per step.
        hidden_size: RNN hidden size. The raw RNN output (last dimension
            hidden_size, or 2 * hidden_size when bidirectional) is compared with
            the label by MSE, so it should match the label's last dimension.
        num_layers: stacked RNN layers.
        dropout: dropout between RNN layers.
        bidirectional: run the RNN in both directions.
        optimi: "Adam", "SGD", "RMSprop", or an optimizer class
            (default: torch.optim.Adam).
    """

    # Optimizer names accepted for `optimi` (e.g. from an Optuna search space).
    _OPTIMIZERS = {
        "Adam": torch.optim.Adam,
        "SGD": torch.optim.SGD,
        "RMSprop": torch.optim.RMSprop,
    }

    def __init__(self,
            lr=1e-5,
            input_size=29,
            hidden_size=1,
            num_layers=1,
            dropout=0,
            bidirectional=False,
            optimi=torch.optim.Adam
            ):
        super().__init__()
        self.lr = lr
        self.input_size=input_size
        self.hidden_size=hidden_size
        self.num_layers=num_layers
        self.dropout=dropout
        self.bidirectional=bidirectional
        self.optimi=optimi
        self.rnn = nn.RNN(input_size=self.input_size,
                        hidden_size=self.hidden_size,
                        num_layers=self.num_layers,
                        nonlinearity='relu',
                        bias=True,
                        batch_first=True,
                        dropout=self.dropout,
                        bidirectional=self.bidirectional)
        self.save_hyperparameters()

    def forward(self, x):
        """Return the per-step RNN outputs, shape (batch, seq, hidden * directions)."""
        output, _ = self.rnn(x)
        return output

    def training_step(self, batch, batch_idx):
        feature, label = batch
        output = self(feature.float())
        loss   = F.mse_loss(output, label.float())
        self.log("mse_loss", loss)
        return loss

    def configure_optimizers(self):
        """Create the optimizer selected by ``optimi``.

        Raises:
            ValueError: if ``optimi`` is a string naming an unsupported optimizer.
        """
        if isinstance(self.optimi, str):
            if self.optimi not in self._OPTIMIZERS:
                raise ValueError("Unknown optimizer name: {!r}".format(self.optimi))
            optimizer_cls = self._OPTIMIZERS[self.optimi]
        else:
            # An optimizer class, such as the default torch.optim.Adam, is used directly.
            optimizer_cls = self.optimi
        return optimizer_cls(self.parameters(), lr=self.lr)


In [None]:
# TODO Transformer
from torch import nn

class  TransformerModule(pl.LightningModule):
    """Placeholder for the transformer-based HoneyDew model (not implemented yet).

    Construction builds an nn.Transformer (nhead=16, 12 encoder layers).
    forward, and therefore training_step, raise NotImplementedError until the
    input/output projections and the loss are written.
    """
    def __init__(self):
        # nn.Module must be initialised before submodules can be assigned.
        super().__init__()
        print("init")
        self.model = nn.Transformer(nhead=16,
                                    num_encoder_layers=12)

    def forward(self, x):
        print("forward")
        raise NotImplementedError("TransformerModule.forward is not implemented yet")

    def training_step(self, batch, batch_idx):
        feature, label = batch
        # Routed through forward (which currently raises NotImplementedError).
        return self(feature.float())
    

# TODO: this is a debugging run — adjust its settings before real experiments

In [None]:
# Short debugging run: train the RNN baseline on S2 gamma windows.
model   = RNNModule(
    lr=.0001)

# Use a GPU when one is available; requesting gpus=1 on a CPU-only machine
# makes the Trainer raise.
trainer = pl.Trainer(gpus=1 if torch.cuda.is_available() else 0,
                     max_epochs=50)

dm      = FlyDataModule(cell_line="S2",
                        batch_size=4,
                        data_win_radius=5,
                        label_type='gamma',
                        label_val=10)
dm.setup()
trainer.fit(model, dm)
                        

# Experiments
## Evaluation
Using all features to predict gamma


In [None]:
import glob
import yaml

# Load a trained checkpoint and plot label vs. prediction for one test window.
VERSION    = 0
PLOT_BATCH = 50   # index of the test window to plot

dm = FlyDataModule(cell_line="S2",
                 data_win_radius=5,
                 batch_size=1,
                 label_type='gamma',
                 label_val=10)
dm.setup()
dataloader = dm.test_dataloader()

ckpt_dir   = "lightning_logs/version_"+str(VERSION)+"/checkpoints/"
ckpt_paths = sorted(glob.glob(ckpt_dir+"*"))
if not ckpt_paths:
    raise FileNotFoundError("No checkpoint found in " + ckpt_dir)
PATH = ckpt_paths[0]

# load_from_checkpoint is a classmethod and restores the saved hyperparameters
# itself, so hparams.yaml is not parsed here (it may contain python-specific
# YAML tags that safe loaders reject).
pretrained_model = RNNModule.load_from_checkpoint(PATH)
pretrained_model.freeze()

for b, batch in enumerate(dataloader):
    if b != PLOT_BATCH:
        continue
    feature, label = batch
    output = pretrained_model(feature.float())
    fig, ax = plt.subplots(1)
    ax.plot(label[0, :, 0].float().numpy(),
            c="blue",
            label="label")
    ax.plot(output[0, :, 0].detach().numpy(),
            c="green",
            label="prediction")
    ax.set_xlabel("position in window")
    ax.set_ylabel("gamma")
    ax.legend()
    plt.show()
    break
else:
    print("Test set has fewer than {} windows; nothing plotted.".format(PLOT_BATCH + 1))

## First we need a custom logging mechanism

In [None]:
#pulled from https://github.com/optuna/optuna/issues/1186

import yaml
import os
from pytorch_lightning.loggers import LightningLoggerBase

class DictLogger(LightningLoggerBase):
    """PyTorch Lightning `dict` logger.

    Keeps every metrics dict passed to ``log_metrics`` in ``self.metrics`` and
    writes hyperparameters to ``<root_dir>/optuna/version_<version>/hparams.yaml``.
    """

    def __init__(self, version, root_dir):
        super(DictLogger, self).__init__()
        self.metrics = []
        self._version = version
        self.root_dir = root_dir

    def log_metrics(self, metrics, step=None):
        """Append one metrics dict; ``step`` is part of the logger API and is ignored."""
        self.metrics.append(metrics)

    @property
    def version(self):
        return self._version

    @property
    def experiment(self):
        """Return the experiment object associated with this logger (none here)."""
        return None

    def log_hyperparams(self, params):
        """
        Record hyperparameters as YAML.
        Args:
            params: :class:`~argparse.Namespace` containing the hyperparameters
        """
        dirr = os.path.join(self.root_dir, "optuna", "version_"+str(self._version))
        # makedirs also creates missing parents of root_dir, which a chain of
        # os.mkdir calls cannot.
        os.makedirs(dirr, exist_ok=True)
        with open(os.path.join(dirr, "hparams.yaml"), 'w') as outfile:
            print("logging them hyperparams:"+str(self.root_dir))
            yaml.dump(params, outfile)

    @property
    def name(self):
        """Return the experiment name."""
        return 'optuna'

# Experiment 1
Train each model architecture to predict gamma.
### TODO: for gamma prediction, follow the suggestion of https://peerj.com/articles/cs-307/ and use the loss function proposed there (the models below currently train with plain MSE).

In [None]:
import glob
import yaml
import os
import optuna
import pytorch_lightning as pl


dm = FlyDataModule(cell_line="S2",
                 data_win_radius=5,
                 batch_size=1,
                 label_type='gamma')
dm.setup()

EXPERIMENT_DIR = "Experiments/Experiment_1_Gamma"
N_TRIALS       = 2


def _last_logged(metrics, key):
    """Return the most recent value of `key` among the logged metric dicts.

    Raises:
        KeyError: if `key` was never logged.
    """
    for entry in reversed(metrics):
        if key in entry:
            return entry[key]
    raise KeyError("metric {!r} was never logged".format(key))


def make_objective(model_class):
    """Build an Optuna objective that trains `model_class` on `dm` and returns
    its last logged training MSE."""
    def objective(trial):
        os.makedirs(EXPERIMENT_DIR, exist_ok=True)
        logger  = DictLogger(trial.number, EXPERIMENT_DIR)
        trainer = pl.Trainer(
            logger=logger,
            gpus=1 if torch.cuda.is_available() else 0,
            max_epochs=10,
            default_root_dir=EXPERIMENT_DIR
        )

        # hyperparameters
        input_size = 29
        lr         = trial.suggest_categorical("lr", [1e-5, 1e-4, 1e-3, 1e-2, 1e-1])
        num_layers = trial.suggest_categorical("num_layers", [1, 4, 8, 16, 32, 64, 128, 256, 512])
        optimi     = trial.suggest_categorical("optimi", ["Adam", "SGD"])
        model = model_class(lr=lr,
                            input_size=input_size,
                            num_layers=num_layers,
                            optimi=optimi)
        trainer.fit(model, dm)
        # The final logged entry need not contain 'mse_loss', so search backwards.
        # NOTE(review): this optimises training loss; no validation metric is logged.
        return _last_logged(logger.metrics, 'mse_loss')
    return objective


model_classes = [RNNModule]
for model_class in model_classes:
    study = optuna.create_study(direction="minimize")
    study.optimize(make_objective(model_class), n_trials=N_TRIALS)
    print("Number of finished trials: {}".format(len(study.trials)))
    print("Best trial")
    trial = study.best_trial
    print(trial)
    print("value:{}".format(trial.value))
    print(" params:")
    for key, value in trial.params.items():
        print("  {}: {}".format(key, value))
        

We visualize results of Experiment 1:

In [None]:
DATA_WIN_RADIUS = 5
N_VERSIONS      = 2   # one checkpoint directory per Optuna trial of Experiment 1
EXP1_VERSION_PREFIX = "Experiments/Experiment_1_Gamma/optuna/version_"


dm = FlyDataModule(cell_line="S2",
                 data_win_radius=DATA_WIN_RADIUS,
                 batch_size=1,
                 label_type='gamma',
                 label_val=10)
dm.setup()
dataloader = dm.test_dataloader()

fig, ax = plt.subplots(1, figsize=(20,20))

for VERSION in range(N_VERSIONS):
    ckpt_dir   = EXP1_VERSION_PREFIX+str(VERSION)+"/checkpoints/"
    ckpt_paths = sorted(glob.glob(ckpt_dir+"*"))
    if not ckpt_paths:
        raise FileNotFoundError("No checkpoint found in " + ckpt_dir)
    PATH = ckpt_paths[0]
    # load_from_checkpoint is a classmethod and restores the saved hyperparameters.
    pretrained_model = RNNModule.load_from_checkpoint(PATH)
    pretrained_model.freeze()

    full_output = []
    full_label  = []
    # Consecutive dataset items are windows shifted by one row, so taking every
    # (2r+1)-th window tiles the test split without overlap.
    for b, batch in enumerate(dataloader):
        if b % (2*DATA_WIN_RADIUS+1) != 0:
            continue
        feature, label = batch
        output = pretrained_model(feature.float())
        full_output.extend(output[0, :, 0].detach().numpy())
        full_label.extend(label[0, :, 0].float().numpy())

    if full_output:
        ax.plot(full_output,
                label="prediction"+str(VERSION),
                linewidth=5)

    if VERSION == 0:
        ax.plot(full_label,
                label="label",
                linewidth=5)
ax.set_title("Experiment 1: S2 test-split gamma, label vs. predictions")
ax.set_xlabel("bin (test split)")
ax.set_ylabel("gamma")
plt.legend()
plt.show()
        