# Deep Learning in 3D Genome (Tutorial)

## Explore model

In [None]:
import json
import os
import subprocess
# os.environ["CUDA_VISIBLE_DEVICES"] = '-1' ### run on CPU

from cooltools.lib.numutils import set_diag
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import pysam
import tensorflow as tf
from basenji import dataset, dna_io, seqnn

In [None]:
import tensorflow as tf
print(tf.__version__)

### Understanding the Model Architecture
Akita (and other Basenji-derived models) stores the information needed to specify its architecture in a json file. Architectures are specified in terms of blocks and layers, where each block can have multiple layers.

The weights for Akita, which were learned via training, are stored separately in a standard HDF5 file generated by TensorFlow.

Before diving into the Keras model summary, let's examine the architecture of the model and how the shapes of tensors change as they pass through it. Below we print the part of the json file that specifies the parameters of the first two blocks.

Note that the first block is a convolution followed by pooling. The second block is a tower of repeated convolutions followed by pooling. Together, these pooling layers take us from an input sequence of 1,310,720 bp to a set of latent profiles binned at $2^{11} = 2048$ bp resolution (640 bins). This is also the resolution of the target maps, as no further pooling is performed in this network.


In [None]:
import json

# Open and read the JSON file
model_dir = './tutorial_materials/akita_v2/'
params_file = model_dir+'params.json' # architecture
with open(params_file, 'r') as file:
    params = json.load(file)
    model_architecture = params['model'] # Retrieve model's architecture from params.json

model_architecture['trunk'][:2]
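As a quick check of the arithmetic above, the sketch below multiplies pool sizes across trunk blocks to get the bin size. It assumes each block dict may carry a `pool_size` key and, for towers, a `repeat` key, and that a tower pools once per repeat. The `example_trunk` list is hypothetical and only mirrors the description above; the real values are in the printed `model_architecture['trunk'][:2]`.

```python
import math


def total_pooling(trunk_blocks):
    """Return the overall downsampling factor of a list of trunk blocks.

    Assumption: each block dict may have a 'pool_size' key (default 1, no pooling)
    and, for towers, a 'repeat' key; a tower is assumed to pool once per repeat.
    """
    factor = 1
    for block in trunk_blocks:
        pool_size = block.get("pool_size", 1)
        repeat = block.get("repeat", 1)
        factor *= pool_size ** repeat
    return factor


# Hypothetical trunk mirroring the text: a conv block with pooling, then a tower of
# 10 repeated convolutions, each followed by pooling by 2 (values are illustrative).
example_trunk = [
    {"name": "conv_block", "pool_size": 2},
    {"name": "conv_tower", "pool_size": 2, "repeat": 10},
]

seq_length = 1310720
bin_size = total_pooling(example_trunk)
print(bin_size, int(math.log2(bin_size)), seq_length // bin_size)  # 2048 11 640
```

Applied to the real `model_architecture['trunk']`, this helper may need extra cases for block types that downsample differently.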

Below we initialize the model architecture and print a partial Keras summary of the first seven layers.

These include:
1. **Input Layer**: The input layer receives sequences in a fixed shape (e.g., `(batch_size, sequence_length, channels)`).
2. **First Convolutional Layer**: Applies a set of convolutional filters with a specified kernel size; the number of filters determines the size of the output's last (channel) dimension.
3. **Pooling Layer**: Reduces the sequence-length dimension by aggregating neighboring positions.
4. **ReLU Layer**: An activation layer that adds non-linearity to the network; it outputs zero for negative inputs and passes positive inputs through unchanged.

Note that the first dimension is always None, because models can process a flexible number of sequences (determined by the batch size), both during training and when making predictions. See the [tensorflow docs](https://www.tensorflow.org/js/guide/models_and_layers#model_summary) for more information.

The stochastic shift and stochastic reverse-complement layers are used by Akita (and Basenji) for data augmentation during training.
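The two augmentations can be illustrated on a single one-hot sequence with plain NumPy. This is a simplified sketch, not Akita's layers: `reverse_complement_1hot`, `shift_1hot`, `augment`, zero-padding of vacated positions, and the default `max_shift=11` are all assumptions chosen for illustration. With columns ordered A, C, G, T, complementing a base is the same as reversing the column order. For 2D targets, reverse complementing the input also reverses the bin order, so the contact map has to be flipped along both axes.

```python
import numpy as np


def reverse_complement_1hot(seq_1hot):
    """Reverse-complement a (length, 4) one-hot array with columns ordered A, C, G, T.

    Complementing swaps A<->T and C<->G, which reverses the column order;
    reversing the rows reverses the sequence.
    """
    return seq_1hot[::-1, ::-1]


def shift_1hot(seq_1hot, shift):
    """Shift a one-hot sequence by `shift` positions, filling vacated rows with zeros (assumption)."""
    out = np.zeros_like(seq_1hot)
    if shift > 0:
        out[shift:] = seq_1hot[:-shift]
    elif shift < 0:
        out[:shift] = seq_1hot[-shift:]
    else:
        out[:] = seq_1hot
    return out


def augment(seq_1hot, rng, max_shift=11):
    """Randomly shift by up to `max_shift` bp, then reverse-complement with probability 0.5."""
    shift = int(rng.integers(-max_shift, max_shift + 1))
    augmented = shift_1hot(seq_1hot, shift)
    reverse = bool(rng.integers(0, 2))
    if reverse:
        augmented = reverse_complement_1hot(augmented)
    return augmented, shift, reverse
```

For example, the one-hot encoding of `AACG` reverse-complements to that of `CGTT`, and applying `reverse_complement_1hot` twice returns the original array.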


In [None]:
from akita_utils import print_partial_model_summary

human_model = seqnn.SeqNN(model_architecture)
print_partial_model_summary(human_model.model, num_layers=7)

- 🌟 Once we have initialized the Akita architecture, we need to restore the pretrained model weights. Because the weights correspond to specific layers, they can only be restored into the same architecture that was used when the model was originally trained. Here we use the architecture specified in the params.json file.

In [None]:
head = 0 # index 0 indicates human model
weights_file  = model_dir+f'model{head}_best.h5' # model_weights
human_model.restore(weights_file)
print('successfully loaded')

### Make a prediction from sequence (for human)


- 🌟 A FASTA file is a text-based format for representing nucleotide or peptide sequences. Each sequence in a FASTA file is preceded by a single-line description starting with a `>` character, followed by lines of sequence data. \
\
Example:

> \>sequence1 \
ATGCGTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTA \
\>sequence2 \
CGTAGCTAGCTAGCTAGCTAACGATCGTAGCTAGCTAGCTAGCTAGCTAGCTAGC


- 🌟 We use pysam to read the sequence from the FASTA file, and then one-hot encode it: each nucleotide becomes a binary vector of length four, so the whole sequence becomes a binary matrix with one row per position.

<center>

| Nucleotide | One-Hot Encoding  |
|------------|-------------------|
| A          | [1, 0, 0, 0]      |
| C          | [0, 1, 0, 0]      |
| G          | [0, 0, 1, 0]      |
| T          | [0, 0, 0, 1]      |
</center>
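A minimal NumPy sketch of the mapping in the table is below, using a lookup table indexed by ASCII code. The name `one_hot_encode` is hypothetical; the tutorial itself uses `basenji.dna_io.dna_1hot`, whose treatment of ambiguous bases such as N may differ from the all-zeros rows chosen here.

```python
import numpy as np

# Lookup from ASCII code to column index (A, C, G, T); -1 marks any other character.
_BASE_TO_COL = np.full(256, -1, dtype=np.int64)
for _col, _base in enumerate("ACGT"):
    _BASE_TO_COL[ord(_base)] = _col


def one_hot_encode(seq):
    """Encode a DNA string as a (len(seq), 4) float32 array following the table above.

    Bases outside ACGT (for example N) become all-zero rows in this sketch.
    """
    codes = np.frombuffer(seq.upper().encode("ascii"), dtype=np.uint8)
    cols = _BASE_TO_COL[codes]
    seq_1hot = np.zeros((len(codes), 4), dtype=np.float32)
    rows = np.nonzero(cols >= 0)[0]
    seq_1hot[rows, cols[rows]] = 1.0
    return seq_1hot


print(one_hot_encode("ACGTN"))
```

The vectorized lookup avoids a Python-level loop, which matters for inputs of more than a million base pairs.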

In [None]:
fasta_file = pysam.FastaFile('/content/akita_tutorial/tutorial_materials/data/chr11_99928064_101238784.fasta')
seq = fasta_file.fetch(fasta_file.references[0])
seq_1hot = dna_io.dna_1hot(seq)

# add a batch dimension: the model expects arrays of shape [num_regions, seq_length, 4] (seq_length = 1,310,720 bp here)
test_pred_from_seq = human_model.model.predict(np.expand_dims(seq_1hot,0))

- 🌟 The last step is plotting the Hi-C map! Because the model predicts several cell types (targets), we visualize only the first one, selected by the `target_index` variable. \
\
Note: Because model outputs are flattened upper-triangular vectors, we also need the amount of cropping and the diagonal offset to convert each output back into the square matrix used for visualization.
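To see what the flattening means, here is a toy round trip on a $6 \times 6$ matrix. It uses the same `np.triu_indices(matrix_len, num_diags)` indexing and NaN-filled near-diagonal band as the `from_upper_triu` function in the next cell, but the helper names `to_upper_triu` and `from_upper_triu_sketch` are hypothetical. For an $n \times n$ map with diagonal offset $d$, the flattened vector has $(n-d)(n-d+1)/2$ entries.

```python
import numpy as np


def to_upper_triu(matrix, num_diags):
    """Flatten the upper triangle, skipping the main diagonal and the next num_diags - 1 diagonals."""
    return matrix[np.triu_indices(matrix.shape[0], num_diags)]


def from_upper_triu_sketch(vector, matrix_len, num_diags):
    """Rebuild a symmetric matrix from the flattened vector; the skipped band near the diagonal is NaN."""
    matrix = np.full((matrix_len, matrix_len), np.nan)
    rows, cols = np.triu_indices(matrix_len, num_diags)
    matrix[rows, cols] = vector
    matrix[cols, rows] = vector  # mirror into the lower triangle
    return matrix


n, diag_offset = 6, 2
example = np.arange(n * n, dtype=float).reshape(n, n)
example = (example + example.T) / 2  # a symmetric toy "contact map"

vec = to_upper_triu(example, diag_offset)
print(vec.shape)  # (10,) since (6 - 2) * (6 - 2 + 1) / 2 = 10
rebuilt = from_upper_triu_sketch(vec, n, diag_offset)
```

Every entry at least `diag_offset` away from the diagonal survives the round trip; the band closer to the diagonal is not stored, so it comes back as NaN.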

In [None]:
# read data parameters
data_stats_file = './tutorial_materials/akita_v2/data_params.json'
with open(data_stats_file) as data_stats_open:
    data_stats = json.load(data_stats_open)
seq_length = data_stats['seq_length']
target_length = data_stats['target_length']
hic_diags =  data_stats['diagonal_offset']
target_crop = data_stats['crop_bp'] // data_stats['pool_width']
target_length1 = data_stats['seq_length'] // data_stats['pool_width']

target_length1_cropped = target_length1 - 2*target_crop
def from_upper_triu(vector_repr, matrix_len, num_diags):
    z = np.zeros((matrix_len,matrix_len))
    triu_tup = np.triu_indices(matrix_len,num_diags)
    z[triu_tup] = vector_repr
    for i in range(-num_diags+1,num_diags):
        set_diag(z, np.nan, i)
    return z + z.T

#transform from flattened representation to symmetric matrix representation
target_index = 0
mat = from_upper_triu(test_pred_from_seq[:,:,target_index], target_length1_cropped, hic_diags)

In [None]:
plt.figure(figsize=(8,4))
vmin=-2; vmax=2

plt.subplot(121)
im = plt.matshow(mat, fignum=False, cmap= 'RdBu_r', vmax=vmax, vmin=vmin)
plt.colorbar(im, fraction=.04, pad = 0.05, ticks=[-2,-1, 0, 1,2]);
plt.title('Akita prediction',y=1.15 )
plt.ylabel('chr11:99928064-101238784')

# plot target
test_target = np.load('/content/akita_tutorial/tutorial_materials/data/chr11_99928064_101238784_target.npy')
plt.subplot(122)
mat = from_upper_triu(test_target[:,:,target_index], target_length1_cropped, hic_diags)
im = plt.matshow(mat, fignum=False, cmap= 'RdBu_r', vmax=vmax, vmin=vmin)
plt.colorbar(im, fraction=.04, pad = 0.05, ticks=[-2,-1, 0, 1,2]);
plt.title('target',y=1.15)

plt.tight_layout()
plt.show()

### Predictions for disrupted genomic sequences (for mouse)

- 🌟 Next, let's explore further uses of the Akita model. With Akita, we can simulate how sequence perturbations affect 3D genome organization. To start, we'll download the mouse reference genome.

In [None]:
import os  # needed for os.path below

if not os.path.isfile('./tutorial_materials/data/mm10.ml.fa'):
    print('downloading mm10.ml.fa')
    # check=True stops the cell if the download or decompression fails
    subprocess.run('curl -o ./tutorial_materials/data/mm10.ml.fa.gz ftp://ftp.ebi.ac.uk/pub/databases/gencode/Gencode_mouse/release_M22/GRCm38.p6.genome.fa.gz', shell=True, check=True)
    subprocess.run('gunzip ./tutorial_materials/data/mm10.ml.fa.gz', shell=True, check=True)

fasta_open = pysam.FastaFile('./tutorial_materials/data/mm10.ml.fa')

- 🌟 Load a table of genomic locations that contain CTCF motifs, and a table of chromosome sizes for the mouse genome.

In [None]:
# input tables: CTCF motif sites to disrupt, and mouse chromosome sizes
CTCT_table = "./tutorial_materials/data/insertion_disruption_tsvs/disruption_examples.tsv"
chrom_sizes = "./tutorial_materials/data/mm10.fa.sizes"
chrom_sizes_table = pd.read_csv(chrom_sizes, sep="\t", names=["chrom", "size"])

- 🌟 This time we load the pretrained weights for mouse instead of human; the model architecture is the same for the mouse and human models.

In [None]:
head = 1 # index 1 indicates mouse model
shifts = "0"
shifts = [int(shift) for shift in shifts.split(",")]
rc = False
weights_file  = model_dir+f'model{head}_best.h5' # model_weights

mouse_model = seqnn.SeqNN(model_architecture)
mouse_model.restore(weights_file, head_i=head)
mouse_model.build_ensemble(rc, shifts)

- 🌟 Next, we use the function `central_permutation_seqs_gen` to generate, for each CTCF site, the reference sequence followed by a copy in which the central region containing the motif is permuted.

In [None]:
from basenji import stream
from akita_utils import central_permutation_seqs_gen, ut_dense

batch_size=8
seq_coords_df = pd.read_csv(CTCT_table, sep="\t")

preds_stream = stream.PredStreamGen(
        mouse_model,
        central_permutation_seqs_gen(seq_coords_df, fasta_open, chrom_sizes_table),
        batch_size,
    )

- 🌟 Now let's see how permuting each sequence affects 3D genome organization, for examples ranging from strong to weak effects.

In [None]:
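# For reference: the definition of central_permutation_seqs_gen. The generator passed to
# PredStreamGen above was already created from the akita_utils version, so redefining it
# here does not change those predictions. Helpers such as expand_and_check_window,
# get_relative_window_coordinates and permute_seq_k are assumed to come from akita_utils.
from akita_utils import dna_1hot, hot1_rc, ut_dense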
def central_permutation_seqs_gen(
    seq_coords_df,
    genome_open,
    chrom_sizes_table,
    permutation_window_shift=0,
    revcomp=False,
    seq_length=1310720,
):
    """
    Generates sequences for a set of genomic coordinates, applying central permutations and optionally
    operating on reverse complements, with an additional option for shifting the permutation window.

    This generator function takes a DataFrame `seq_coords_df` containing genomic coordinates
    (chromosome, start, end, strand), a genome file handler `genome_open` to fetch sequences, and
    a table of chromosome sizes `chrom_sizes_table`. It yields sequences with central permutations
    around the coordinates specified in `seq_coords_df`, considering an optional shift for the
    permutation window. If `revcomp` is True, the reverse complement of these sequences is generated.

    Parameters:
    - seq_coords_df (pandas.DataFrame): DataFrame with columns 'chrom', 'start', 'end', 'strand',
                                        representing genomic coordinates of interest.
    - genome_open (GenomeFileHandler): A file handler for the genome to fetch sequences.
    - chrom_sizes_table (pandas.DataFrame): DataFrame with columns 'chrom' and 'size', representing
                                            the sizes of chromosomes in the genome.
    - permutation_window_shift (int, optional): The number of base pairs to shift the center of the
                                                 permutation window. Default is 0.
    - revcomp (bool, optional): If True, operates on reverse complement of the sequences. Default is False.
    - seq_length (int, optional): The total length of the sequence to be generated. Default is 1310720.

    Yields:
    numpy.ndarray: One-hot encoded DNA sequences. Each sequence is either the original or its central
                   permutation, with or without reverse complement as specified by `revcomp`.

    Raises:
    Exception: If the prediction window for a given span cannot be centered within the chromosome.
    """

    for s in seq_coords_df.itertuples():
        list_1hot = []

        chrom, window_start, window_end = expand_and_check_window(
            s, chrom_sizes_table, shift=permutation_window_shift
        )
        permutation_start, permutation_end = get_relative_window_coordinates(
            s, shift=permutation_window_shift
        )

        wt_seq_1hot = dna_1hot(
            genome_open.fetch(chrom, window_start, window_end).upper()
        )
        if revcomp:
            rc_wt_seq_1hot = hot1_rc(wt_seq_1hot)
            list_1hot.append(rc_wt_seq_1hot.copy())
        else:
            list_1hot.append(wt_seq_1hot.copy())

        ### MODIFY HERE ###
        alt_seq_1hot = wt_seq_1hot.copy()
        permuted_span = permute_seq_k(
            alt_seq_1hot[permutation_start:permutation_end], k=1
        )
        alt_seq_1hot[permutation_start:permutation_end] = permuted_span
        ### MODIFY HERE ###

        if revcomp:
            rc_alt_seq_1hot = hot1_rc(alt_seq_1hot.copy())
            list_1hot.append(rc_alt_seq_1hot)
        else:
            list_1hot.append(alt_seq_1hot)

        # yielding first the reference, then the permuted sequence
        for sequence in list_1hot:
            yield sequence
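The `### MODIFY HERE ###` block is where the disruption happens. Assuming `permute_seq_k(..., k=1)` performs a mononucleotide shuffle (base composition within the span is preserved while the motif is scrambled), a minimal NumPy stand-in looks like this. `shuffle_span_1hot` is a hypothetical name, and `permute_seq_k` may treat k > 1 differently (for example, by preserving k-mer counts).

```python
import numpy as np


def shuffle_span_1hot(seq_1hot, start, end, rng):
    """Return a copy of seq_1hot whose rows in [start, end) are randomly reordered.

    Reordering rows is a k=1 (mononucleotide) shuffle: base composition within
    the span is preserved, while motifs such as CTCF sites are very likely destroyed.
    """
    shuffled = seq_1hot.copy()
    order = rng.permutation(end - start)
    shuffled[start:end] = seq_1hot[start:end][order]
    return shuffled
```

Inside the generator, it could replace the marked lines as `alt_seq_1hot = shuffle_span_1hot(wt_seq_1hot, permutation_start, permutation_end, np.random.default_rng(0))`; fixing the seed keeps the permuted sequence reproducible.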

In [None]:
target_index = 1 # cell type
num_experiments = len(seq_coords_df)

# predictions arrive in pairs: even indices are references, odd indices are the permuted sequences
for ref_index in range(0, num_experiments*2, 2):

    ref_preds_matrix = preds_stream[ref_index]
    permut_index = ref_index + 1
    permuted_preds_matrix = preds_stream[permut_index]
    exp_index = ref_index//2

    ref_maps = ut_dense(ref_preds_matrix)
    perm_maps = ut_dense(permuted_preds_matrix)

    fig, axs = plt.subplots(1, 3, figsize=(10, 5))
    sns.heatmap(
        ref_maps[:,:,target_index],
        vmin=-0.6,
        vmax=0.6,
        cbar=False,
        cmap="RdBu_r",
        square=True,
        xticklabels=False,
        yticklabels=False,
        ax=axs[0]
    )
    axs[0].set_title('Reference')

    sns.heatmap(
        perm_maps[:,:,target_index],
        vmin=-0.6,
        vmax=0.6,
        cbar=False,
        cmap="RdBu_r",
        square=True,
        xticklabels=False,
        yticklabels=False,
        ax=axs[1]
    )
    axs[1].set_title('Permuted')

    sns.heatmap(
        perm_maps[:,:,target_index]-ref_maps[:,:,target_index],
        vmin=-0.6,
        vmax=0.6,
        cbar=True,
        cmap="PiYG_r",
        square=True,
        xticklabels=False,
        yticklabels=False,
        ax=axs[2]
    )
    axs[2].set_title('Difference')

    plt.tight_layout()
    plt.show()
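The Difference panels can also be condensed into one number per example, which makes it easier to sort examples from strong to weak effects. The sketch below takes the square root of the summed squared difference, ignoring NaN entries; both this particular score and the name `contact_difference_score` are illustrative choices, not necessarily what akita_utils provides.

```python
import numpy as np


def contact_difference_score(ref_map, alt_map):
    """Square root of the summed squared difference between two maps, ignoring NaNs."""
    diff = alt_map - ref_map
    return float(np.sqrt(np.nansum(diff ** 2)))
```

Inside the loop above, `contact_difference_score(ref_maps[:, :, target_index], perm_maps[:, :, target_index])` gives a score for each example that can then be sorted.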

### Predictions for inserted CTCF sequences (for mouse)

In [None]:
from akita_utils import dna_1hot, hot1_rc, _insert_casette

def symmertic_insertion_seqs_gen(seq_coords_df, background_seqs, genome_open, nproc=1, map=map):
    """
    Generate sequences with symmetric insertions for a given set of coordinates.

    This generator function takes a DataFrame `seq_coords_df` containing genomic
    coordinates, a list of background sequences `background_seqs`, and a genome file
    handler `genome_open`. It yields one-hot encoded DNA sequences with symmetric
    insertions based on the specified coordinates.

    Parameters:
    - seq_coords_df (pandas.DataFrame): DataFrame with columns 'chrom', 'start', 'end',
                                         'strand', 'flank_bp', 'spacer_bp', 'orientation'.
                                         Represents genomic coordinates and insertion parameters.
    - background_seqs (List[numpy.ndarray]): List of background sequences to be modified.
    - genome_open (GenomeFileHandler): A file handler for the genome to fetch sequences.

    Yields:
    numpy.ndarray: One-hot encoded DNA sequence with symmetric insertions.
    """

    for s in seq_coords_df.itertuples():

        flank_bp = s.flank_bp
        spacer_bp = s.spacer_bp
        orientation_string = s.orientation

        seq_1hot_insertion = dna_1hot(
            genome_open.fetch(
                s.chrom, s.start - flank_bp, s.end + flank_bp
            ).upper()
        )

        if s.strand == "-":
            seq_1hot_insertion = hot1_rc(seq_1hot_insertion)
            # now all motifs are standardized to the ">" orientation

        seq_1hot = background_seqs[s.background_index].copy()

        ### MODIFY HERE ###
        seq_1hot = _insert_casette(
            seq_1hot, seq_1hot_insertion, spacer_bp, orientation_string
        )
        ### MODIFY HERE ###
        # Parameters of _insert_casette:
        # - seq_1hot (numpy.ndarray): One-hot encoded DNA sequence to be modified.
        # - seq_1hot_insertion (numpy.ndarray): One-hot encoded DNA sequence to be inserted.
        # - spacer_bp (int): Number of base pairs between consecutive inserts (inter-insert spacer).
        # - orientation_string (str): String specifying the orientation and number of insertions.
        #                         '>' denotes forward orientation, and '<' denotes reverse.

        yield seq_1hot
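The comments above describe `_insert_casette` only by its parameters. The sketch below shows one plausible reading under stated assumptions: copies of the insert are placed back to back around the center of the background, separated by `spacer_bp`, and each `<` copy is reverse complemented. `insert_cassette_sketch` is hypothetical and may not match `_insert_casette` exactly (for example, in how the cassette is centered).

```python
import numpy as np


def reverse_complement_1hot(seq_1hot):
    # With columns ordered A, C, G, T, reversing both axes gives the reverse complement.
    return seq_1hot[::-1, ::-1]


def insert_cassette_sketch(seq_1hot, insert_1hot, spacer_bp, orientation_string):
    """Insert one copy of insert_1hot per character of orientation_string, centered in seq_1hot.

    '>' keeps the insert as given, '<' inserts its reverse complement, and
    consecutive copies are separated by spacer_bp positions of background.
    """
    insert_len = insert_1hot.shape[0]
    n_inserts = len(orientation_string)
    cassette_len = n_inserts * insert_len + (n_inserts - 1) * spacer_bp
    if cassette_len > seq_1hot.shape[0]:
        raise ValueError("cassette is longer than the background sequence")

    out = seq_1hot.copy()  # leave the background sequence untouched
    pos = seq_1hot.shape[0] // 2 - cassette_len // 2
    for orientation in orientation_string:
        if orientation == ">":
            piece = insert_1hot
        elif orientation == "<":
            piece = reverse_complement_1hot(insert_1hot)
        else:
            raise ValueError(f"unknown orientation {orientation!r}")
        out[pos:pos + insert_len] = piece
        pos += insert_len + spacer_bp
    return out
```

With `orientation_string="><"`, for example, this places a forward and a reverse copy of the motif next to each other in convergent-style order.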

In [None]:
CTCT_table = "./tutorial_materials/data/insertion_disruption_tsvs/insertion_examples.tsv"
background_file = "./tutorial_materials/data/background_sequences_model_0.fa"

In [None]:
seq_coords_df = pd.read_csv(CTCT_table, sep="\t")

In [None]:
from akita_utils import dna_1hot

background_seqs = []

with open(background_file, "r") as f:
    for line in f.readlines():
        if ">" in line:
            continue
        background_seqs.append(dna_1hot(line.strip()))
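The loop above assumes that each background sequence sits on a single line; a FASTA file that wraps sequences across lines would be split into fragments. If that is a concern, a small parser that joins wrapped lines can be used instead. `read_fasta_records` is a hypothetical helper; `pysam.FastaFile`, already used in this tutorial, is another option via its `references` and `fetch` methods.

```python
def read_fasta_records(lines):
    """Parse FASTA lines into (name, sequence) tuples, joining wrapped sequence lines.

    Sequence lines appearing before the first header are ignored in this sketch.
    """
    records = []
    name, chunks = None, []
    for raw_line in lines:
        line = raw_line.strip()
        if not line:
            continue
        if line.startswith(">"):
            if name is not None:
                records.append((name, "".join(chunks)))
            name, chunks = line[1:].strip(), []
        elif name is not None:
            chunks.append(line)
    if name is not None:
        records.append((name, "".join(chunks)))
    return records


fasta_text = """>sequence1
ATGCGTAGCT
AGCTAGCTAG
>sequence2
CGTAGCTAGC
"""
for rec_name, rec_seq in read_fasta_records(fasta_text.splitlines()):
    print(rec_name, len(rec_seq))
```

With a file on disk, `with open(background_file) as f: background_seqs = [dna_1hot(s) for _, s in read_fasta_records(f)]` would build the same list as above.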

In [None]:
# predictions for references
backgrounds_predictions = mouse_model.predict(np.array(background_seqs), batch_size=batch_size)

In [None]:
from akita_utils import symmertic_insertion_seqs_gen

preds_stream = stream.PredStreamGen(
        mouse_model,
        symmertic_insertion_seqs_gen(seq_coords_df, background_seqs, fasta_open),
        batch_size,
    )

In [None]:
target_index = 1 #cell type
num_experiments = len(seq_coords_df)  # recount: seq_coords_df now holds the insertion table

for exp_index in range(num_experiments):

    bg_index = seq_coords_df.iloc[exp_index].background_index

    prediction_matrix = preds_stream[exp_index]
    reference_prediction_matrix = backgrounds_predictions[bg_index, :, :]

    ref_maps = ut_dense(reference_prediction_matrix)
    alt_maps = ut_dense(prediction_matrix)

    fig, axs = plt.subplots(1, 3, figsize=(10, 4))
    sns.heatmap(
        ref_maps[:,:,target_index],
        vmin=-0.6,
        vmax=0.6,
        cbar=False,
        cmap="RdBu_r",
        square=True,
        xticklabels=False,
        yticklabels=False,
        ax=axs[0]
    )
    axs[0].set_title('Reference')

    sns.heatmap(
        alt_maps[:,:,target_index],
        vmin=-0.6,
        vmax=0.6,
        cbar=False,
        cmap="RdBu_r",
        square=True,
        xticklabels=False,
        yticklabels=False,
        ax=axs[1]
    )
    axs[1].set_title('Inserted')

    sns.heatmap(
        alt_maps[:,:,target_index]-ref_maps[:,:,target_index],
        vmin=-0.6,
        vmax=0.6,
        cbar=True,
        cmap="PiYG_r",
        square=True,
        xticklabels=False,
        yticklabels=False,
        ax=axs[2]
    )
    axs[2].set_title('Difference')

    plt.tight_layout()
    plt.show()

# Deep Learning in 3D Genome (Training Part Tutorial)

## Understanding the procedure of model training



In [None]:
from akita_utils import show_targets, show_prediction

- 🌟 In the interest of time, we use synthetic data to demonstrate how an Akita model is trained.\
\
With synthetic data, we are free to choose both the size of the input sequences and the rules used to generate the target outputs. This lets us design datasets that train much faster than real genomic sequences paired with experimental Hi-C data, by using shorter sequences (here ~32 kb instead of ~1.3 Mb) with simpler features to learn.\
\
More generally, synthetic data is a useful strategy for model development, including debugging data preprocessing, a model's loss function, or its architecture. \
\
To create this synthetic data, we:
- (i) generate random DNA sequences of 32,768 bp;
- (ii) insert 4-8 motifs at random positions in each sequence;
- (iii) create a corresponding target map by placing boundaries where the motifs were inserted, so that square domains form between inserted CTCF motifs. \
\
Let's generate the training data with `generate_training_data.py`, located in `tutorial_materials/training_materials/`.
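The three steps above can be sketched in a few lines of NumPy. This is a hypothetical illustration, not the contents of `generate_training_data.py`: the motif string, placing each motif at the start of a 64-bp bin, and marking each domain between consecutive motifs with 1.0 are all assumptions chosen for clarity.

```python
import numpy as np

SEQ_LEN = 32768
BIN_SIZE = 64  # target map resolution, giving 32768 / 64 = 512 bins
MOTIF = "CCACCAGGGGGCGC"  # hypothetical CTCF-like motif, for illustration only


def make_example(rng, min_motifs=4, max_motifs=8):
    """Return (sequence, sorted motif bins, target map) for one synthetic example."""
    n_bins = SEQ_LEN // BIN_SIZE
    # (i) random DNA sequence
    seq = np.array(list("ACGT"))[rng.integers(0, 4, size=SEQ_LEN)]
    # (ii) insert 4-8 motifs, each at the start of a distinct randomly chosen bin
    n_motifs = int(rng.integers(min_motifs, max_motifs + 1))
    motif_bins = np.sort(rng.choice(n_bins, size=n_motifs, replace=False))
    for b in motif_bins:
        start = b * BIN_SIZE
        seq[start:start + len(MOTIF)] = list(MOTIF)
    # (iii) square domains between consecutive motifs
    target = np.zeros((n_bins, n_bins), dtype=np.float32)
    for left, right in zip(motif_bins[:-1], motif_bins[1:]):
        target[left:right + 1, left:right + 1] = 1.0
    return "".join(seq), motif_bins, target


demo_seq, demo_bins, demo_target = make_example(np.random.default_rng(0))
print(len(demo_seq), demo_target.shape, demo_bins)
```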

In [None]:
%cd /content/akita_tutorial
!/usr/local/bin/python3 ./tutorial_materials/training_materials/generate_training_data.py

In [None]:
data_dir = './tutorial_materials/training_materials'

# plotting utility to show three synthetic target maps
show_targets(data_dir, split_label='train')

- 🌟 Next, let's look at the predictions of a naive (untrained) model. As expected, the predictions are far from the targets.

In [None]:
model_dir = './tutorial_materials/training_materials/'
show_prediction(data_dir, model_dir, split_label='train')

- 🌟 Let's look at the layers of this model. Combining the pooling layers takes us from an input sequence of 32,768 bp to a set of latent profiles binned at 64 bp resolution (512 bins), which is also the resolution of the target maps.

In [None]:
params_file = model_dir+'params.json' # architecture
with open(params_file, 'r') as file:
    params = json.load(file)
    model_architecture = params['model']
seqnn_model = seqnn.SeqNN(model_architecture)
seqnn_model.model.summary()


- 🌟 Before training the model, let's briefly go through `params.json` to understand the parameters that control training. `params.json` has two parts: "train", with parameters that control the training procedure, and "model", with parameters that define the model architecture. Here we focus on the "train" part.

In [None]:
model_dir = './tutorial_materials/training_materials/'
params_file = model_dir+'params.json' # architecture
with open(params_file, 'r') as file:
    params = json.load(file)
    train_params = params['train'] # Retrieve training parameters from params.json

train_params

- 🌟 By tuning these parameters, we can control how a model trains, improving performance and avoiding common problems such as overfitting and exploding gradients:
> 1. **batch_size**:
  The batch size is the number of samples passed through the network at one time. Here we use a batch size of 8. Smaller batches give noisier gradient estimates but require less memory. For the long input sequences used by Akita, total GPU memory is an important limit on batch size.
  2. **optimizer**:
  The optimizer used to train the model. Here we use Adam, an adaptive learning-rate optimization algorithm that is widely used for training deep neural networks.
  3. **learning_rate**:
  The learning rate controls how much the model weights change in response to the estimated error at each update. A higher learning rate can converge faster but may overshoot the minimum, while a lower learning rate converges more slowly.
  4. **momentum**:
  Momentum controls how heavily previous updates are weighted when computing the current update; it smooths noisy gradient steps and can speed up convergence.
  5. **loss**:
  The loss function measures how far predictions are from the targets. Here we use mean squared error (MSE), the average of the squared differences between predicted and actual values, which is commonly used for regression tasks.
  6. **patience**:  
  Patience is used for early stopping, which halts training when the validation loss has not improved for a given number of consecutive epochs (full passes through the training dataset). Here, training stops if the validation loss does not improve for 8 consecutive epochs.
  7. **clip_norm**:
  Gradient clipping guards against exploding gradients in very deep networks. The clip norm is the maximum allowed gradient norm: if a gradient's norm exceeds it (here 10.0), the gradient is rescaled so that its norm equals 10.0.
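
To make the loss, clip_norm and patience entries concrete, here is a plain-NumPy sketch of each; it is illustrative only and not the Basenji/Akita training loop, and the names `mse_loss`, `clip_by_norm` and `EarlyStopping` are hypothetical. One caveat: Keras optimizers accept both `clipnorm` (each weight's gradient clipped individually) and `global_clipnorm` (all gradients rescaled together); which one the training script uses is not shown here, and `clip_by_norm` below handles a single gradient array.

```python
import numpy as np


def mse_loss(pred, target):
    """Mean of the squared differences between predictions and targets."""
    return float(np.mean((pred - target) ** 2))


def clip_by_norm(grad, clip_norm):
    """Rescale grad so that its L2 (Frobenius) norm is at most clip_norm."""
    norm = np.linalg.norm(grad)
    if norm > clip_norm:
        grad = grad * (clip_norm / norm)
    return grad


class EarlyStopping:
    """Signal a stop once the validation loss fails to improve for `patience` epochs in a row."""

    def __init__(self, patience):
        self.patience = patience
        self.best = np.inf
        self.epochs_without_improvement = 0

    def should_stop(self, valid_loss):
        if valid_loss < self.best:
            self.best = valid_loss
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience
```

For example, with `patience=2` and validation losses 1.0, 0.9, 0.95, 0.97, training would stop after the fourth epoch, and a gradient `[30.0, 40.0]` (norm 50) clipped at 10.0 becomes `[6.0, 8.0]`.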


- 🌟 Now let's run the training script! \
\
This script iterates over all of the `train.tfr` TFRecord files of synthetic data contained in a folder under `tutorial_materials/training_materials/`. \
\
The model architecture and training parameters are specified in `params.json`. After each batch, the model weights are updated based on the loss and the training parameters. Training stops once the validation loss (computed on the `valid.tfr` files) has not improved for `patience` consecutive epochs. \
\
Model checkpoints, including the weights, are stored in the folder `train_out/`. \
\
We use a custom plotting function to visualize how model performance improves over time on both the training and validation sets.


In [None]:
%run bin/akita_train.py -k -o ./train_out/  tutorial_materials/training_materials/params.json tutorial_materials/training_materials/

- 🌟 Finally, let's examine the model's performance after training.

In [None]:
show_targets(data_dir, split_label='test')

In [None]:
model_dir = './train_out/'
show_prediction(data_dir, model_dir, restore_weights=True, split_label='test')