# Distilling multiple models into a unified framework

Here we train a distilled model on the predictions (or mean predictions) of an ensemble of models that were trained on different folds.
The following steps will be used:
1. Create Training data

    a. Make predictions for original sequences

    b. Generate new sequences from the genome to make predictions

    c. Generate new sequences with variants

2. Take the mean of the predictions across the models
3. Retrain the model, using the artificial sequences as the validation set

## 1. Create Training data from 10 fold models

### a. Make predictions for original sequences with all 10 models

In [None]:
import numpy as np
import os
import sys

# Directory holding the DRG command-line scripts (with `~` expanded);
# generate_ensemble_predictions falls back to it when no drgclis_path is passed.
drgclis = os.path.expanduser('~/Git/DRG/scripts/')

def generate_ensemble_predictions(modeldict, input_file, outpath, device='cpu', drgclis_path=None):
    """
    Generate predictions from multiple models and compute their mean.

    Every model is run through ``train_models/run_cnn_model.py`` in a
    subprocess. The per-model prediction files
    ``from<model basename>_predictions.npz`` are then read from ``outpath``
    and averaged element-wise. Files whose ``columns`` or ``names`` differ
    from the first file that was read are skipped, and so are models whose
    prediction run exited with a non-zero status (their file on disk may be
    left over from an earlier run on a different input).

    Parameters:
    -----------
    modeldict : dict
        Dictionary mapping fold names to model paths (without the
        ``_model_params.dat`` suffix, which is appended here)
    input_file : str
        Path to the input file for predictions
    outpath : str
        Output directory for saving predictions; passed unchanged (apart from
        ``~`` expansion) to ``--outname``
    device : str, optional
        Device to use for computations ('cpu' or 'cuda'). Default is 'cpu'
    drgclis_path : str, optional
        Path to DRG scripts. If None, uses the global drgclis variable

    Returns:
    --------
    tuple
        (mean_values, columns, names) - averaged predictions and metadata

    Raises:
    -------
    ValueError
        If no usable prediction file was found for any model
    """
    # Local imports keep this block self-contained; the file header is untouched.
    import subprocess
    import sys

    if drgclis_path is None:
        drgclis_path = drgclis

    # No shell is involved any more, so expand `~` explicitly.
    script = os.path.join(os.path.expanduser(drgclis_path), 'train_models', 'run_cnn_model.py')
    input_file = os.path.expanduser(input_file)
    outpath = os.path.expanduser(outpath)

    # Generate predictions for each model
    failed_folds = set()
    for fold, model in modeldict.items():
        print(f'Processing {input_file} with model {model}')
        # An argument list (instead of a shell string) keeps paths containing
        # spaces or shell metacharacters intact. sys.executable runs the script
        # with the interpreter of the current session.
        command = [
            sys.executable, script, input_file, 'None',
            '--predictnew',
            '--cnn', f'{os.path.expanduser(model)}_model_params.dat',
            f'device={device}',
            '--save_predictions',
            '--outname', outpath,
        ]
        completed = subprocess.run(command, check=False)
        if completed.returncode != 0:
            print(f'Prediction run for {fold} failed with exit code {completed.returncode}. Skipping.')
            failed_folds.add(fold)

    # Read in predictions from individual models and create new training set out of mean
    values_list = []
    prevcolumns = None
    prevnames = None

    for fold, model in modeldict.items():
        if fold in failed_folds:
            continue
        pred_file = os.path.join(outpath, f'from{os.path.basename(model)}_predictions.npz')
        if not os.path.exists(pred_file):
            print(f'Prediction file {pred_file} not found.')
            continue
        with np.load(pred_file) as data:
            values = data['values']
            columns = data['columns']
            names = data['names']
        if not values_list:
            # The first readable file defines the reference layout.
            prevcolumns = columns
            prevnames = names
        if np.array_equal(columns, prevcolumns) and np.array_equal(names, prevnames):
            values_list.append(values)
        else:
            print(f'Incompatible columns or names in {pred_file}. Skipping.')

    if not values_list:
        raise ValueError("No valid prediction files found.")
    return np.mean(values_list, axis=0), prevcolumns, prevnames

# Set up model dictionary: one entry per cross-validation fold, mapping
# 'fold<i>' to the model path prefix (the fold index sits between name and suffix)
modelpath = os.path.abspath('./models/')
modelname = 'CTCFaH3K27acaH3K36me3aH33aH3K27me3aH3K4me1aATAConseq2krcomp_mh'
modelsuffix = '-cv10-1_Cormsek512l19TfEXPGELUmax10rcTvlCota_tc2dNoned1s1r1l7ma5nfc3s1024cbnoTfdo0.1tr1e-05SGD0.9bs64-F'

modeldict = {f'fold{f}': f'{modelpath}/{modelname}{f}{modelsuffix}' for f in range(10)}

# Set up paths and parameters
input_path = os.path.expanduser('./')
input_file = f'{input_path}seq2k.npz'
device = 'cpu'
outpath = os.path.expanduser('./output/')
# The mean predictions are saved into outpath below, so make sure it exists.
os.makedirs(outpath, exist_ok=True)

# Generate ensemble predictions
mean_values, columns, names = generate_ensemble_predictions(
    modeldict, input_file, outpath, device=device, drgclis_path=drgclis
)

# Save the averaged predictions as <input file stem>.model.mean_predictions.npz
input_stem = os.path.splitext(os.path.basename(input_file))[0]
np.savez_compressed(
    os.path.join(outpath, f'{input_stem}.model.mean_predictions.npz'),
    counts=mean_values,
    celltypes=columns,
    names=names,
)

### b. Generate new sequences from the genome and make predictions

In [None]:
# Path of the original BED file (windows for signal)
data_path = '/home/sasse/UW/CutandRun/'
bed_file = data_path + 'ImmGen_ATACpeak.final.bed6'

# Read the BED file and collect the gaps between neighbouring peaks that are
# wider than one signal window, so that new regions can be placed inside them
def determine_regions_between_peaks(bed_file, signal_window = None):
    """
    Find the intervals between consecutive peaks of a BED file.

    Peaks are grouped by chromosome and ordered by start coordinate. A gap
    runs from the furthest end seen so far to the start of the next peak, so
    unsorted or overlapping peaks do not produce spurious gaps. Only gaps
    strictly wider than ``signal_window`` are returned.

    Parameters:
    -----------
    bed_file : str
        Path to a tab-separated BED file; only the first three columns
        (chromosome, start, end) are used. Lines with fewer than three
        fields and lines starting with '#' are ignored.
    signal_window : int, optional
        Minimum gap width (exclusive). If None, the width of the first peak
        in the file is used, assuming all peaks share one window size.

    Returns:
    --------
    numpy.ndarray
        Object array of shape (n_gaps, 4) with the columns chromosome (str),
        gap start (int), gap end (int) and name ('<chrom>:<start>-<end>').
        The shape is (0, 4) when no gap qualifies.
    """
    chroms, starts, ends = [], [], []
    with open(bed_file, 'r') as f:
        for line in f:
            fields = line.rstrip('\n').split('\t')
            if line.startswith('#') or len(fields) < 3:
                continue
            chroms.append(fields[0])
            # Coordinates must be integers: comparing or subtracting the raw
            # strings would either fail or order them lexicographically.
            starts.append(int(fields[1]))
            ends.append(int(fields[2]))

    if not chroms:
        return np.empty((0, 4), dtype=object)

    chroms = np.array(chroms)
    starts = np.array(starts, dtype=np.int64)
    ends = np.array(ends, dtype=np.int64)

    if signal_window is None:
        signal_window = int(ends[0] - starts[0])  # Assuming uniform window size from the first peak

    # Handle each chromosome separately
    inbetween_peaks = []
    for chrom in np.unique(chroms):
        on_chrom = chroms == chrom
        order = np.argsort(starts[on_chrom], kind='stable')
        chrom_starts = starts[on_chrom][order]
        # Running maximum of the ends: a peak nested inside an earlier,
        # longer peak must not open a gap.
        chrom_ends = np.maximum.accumulate(ends[on_chrom][order])
        for gap_start, gap_end in zip(chrom_ends[:-1], chrom_starts[1:]):
            gap_start, gap_end = int(gap_start), int(gap_end)
            if gap_end - gap_start > signal_window:
                inbetween_peaks.append(
                    [str(chrom), gap_start, gap_end, f'{chrom}:{gap_start}-{gap_end}']
                )

    # reshape keeps the result two-dimensional even when no gap was found
    return np.array(inbetween_peaks, dtype=object).reshape(-1, 4)

# Total sequence length and the width of a single region, in bp
sequence_length = 2000
region_size = 250

# Gaps between peaks; region_size is used as the signal window
inbetween_peaks = determine_regions_between_peaks(bed_file, region_size)

# Split the regions between peaks into chunks of region_size; step_size sets
# the distance between chunk starts (a step below region_size overlaps chunks)
def split_regions_into_chunks(regions, region_size, step_size=None, can_overlap = False):
    """
    Split regions into chunks of a given size with optional overlap.

    Parameters:
    -----------
    regions : iterable
        Rows of (chromosome, start, end, name). Start and end may be
        integers or integer strings; they are converted with ``int``.
    region_size : int
        Width of every chunk
    step_size : int, optional
        Distance between consecutive chunk starts. Defaults to
        ``region_size``, i.e. adjacent, non-overlapping chunks.
    can_overlap : bool, optional
        If False (default), a chunk that would extend past the end of its
        region is dropped. If True, one final chunk that extends past the
        region end is kept.

    Returns:
    --------
    numpy.ndarray
        String array of shape (n_chunks, 4) with the columns chromosome,
        start, end and name ('<region name>_<chunk index>'). The shape is
        (0, 4) when no chunk fits.

    Raises:
    -------
    ValueError
        If ``region_size`` or ``step_size`` is not positive (a non-positive
        step would never advance through the region)
    """
    if step_size is None:
        step_size = region_size
    if region_size <= 0 or step_size <= 0:
        raise ValueError('region_size and step_size must be positive.')

    chunks = []
    for region in regions:
        chrom, name = region[0], region[3]
        # Rows of a string array hold the coordinates as text; arithmetic
        # and ordering need real integers.
        start, end = int(region[1]), int(region[2])
        ct = 0
        while start < end:
            chunk_end = start + region_size
            overshoots = chunk_end > end
            if overshoots and not can_overlap:
                break
            chunks.append([chrom, start, chunk_end, f'{name}_{ct}'])
            start += step_size
            ct += 1
            if overshoots:
                # Only one chunk may reach past the region end.
                break

    if not chunks:
        return np.empty((0, 4), dtype=str)
    return np.array(chunks, dtype=str)

inbetween_chunks = split_regions_into_chunks(inbetween_peaks, region_size, step_size=None, can_overlap=False)
print(f'Created {len(inbetween_chunks)} chunks')

if len(inbetween_chunks) > 1000000:
    # Randomly downsample to 1000000 regions
    np.random.seed(42)  # For reproducibility
    inbetween_chunks = inbetween_chunks[np.random.choice(inbetween_chunks.shape[0], 1000000, replace=False)]

# extend the chunks to sequence_length
extension = (sequence_length - region_size) // 2

# Write the chunks to their own BED file. Sequences must be extracted for the
# new between-peak chunks, not for the original peaks in bed_file.
# NOTE(review): written as six columns (name, score 0, strand '+') to mirror
# the .bed6 input; confirm that extract_sequences_from_bed accepts this layout.
chunk_bed_file = os.path.splitext(bed_file)[0] + '.inbetween.bed6'
with open(chunk_bed_file, 'w') as bed_out:
    bed_out.writelines(
        f'{chrom}\t{start}\t{end}\t{name}\t0\t+\n'
        for chrom, start, end, name in inbetween_chunks
    )

# Create sequences for the chunks with the genome chromosome files
genome_path = '/home/sasse/data/genomes/mm10/'
import drg_tools as drg
from drg_tools import io_utils as utils
from drg_tools import sequence_utils as sutils

seqnames, seqs = utils.extract_sequences_from_bed(chunk_bed_file, genome_path, extend_before = extension, extend_after = extension)

# One-hot encode every extracted sequence and stack the results into one array
onehot = np.array([sutils.seq_onehot(seq) for seq in seqs])
seqnames = np.array(seqnames)
# Print the shape of the one-hot encoded sequences
print(onehot.shape)
# Save the one-hot encoded sequences to a numpy file
output_path = os.path.splitext(bed_file)[0] + 'inbetween.oh.npz'
# Save the one-hot encoded sequences and names to a numpy file
# call the arrays seqfeatures and genenames for compatibility with the rest of the code
np.savez(output_path, seqfeatures = onehot, genenames = seqnames)
print(f"One-hot encoded sequences saved to {output_path}")

In [None]:
# Create mean prediction training data for new sequences
# Set up paths and parameters
input_path = os.path.expanduser('./')
input_file = output_path  # Use the one-hot encoded sequences file
device = 'cpu'
outpath = os.path.expanduser('./output/')
# The mean predictions are saved into outpath below, so make sure it exists.
os.makedirs(outpath, exist_ok=True)

# Generate ensemble predictions
mean_values, columns, names = generate_ensemble_predictions(
    modeldict, input_file, outpath, device=device, drgclis_path=drgclis
)

# Save the averaged predictions as <input file stem>.model.mean_predictions.npz
input_stem = os.path.splitext(os.path.basename(input_file))[0]
np.savez_compressed(
    os.path.join(outpath, f'{input_stem}.model.mean_predictions.npz'),
    counts=mean_values,
    celltypes=columns,
    names=names,
)

### c. Generate new sequences with variants in original sequences

In [None]:
# Set up paths and parameters
input_path = os.path.expanduser('./')
input_file = f'{input_path}seq2k.npz'
# Load one-hot encoded sequences. np.load on an .npz returns an archive, so
# the arrays have to be pulled out by key.
# NOTE(review): assumes seq2k.npz stores its arrays as 'seqfeatures' and
# 'genenames', the key names this notebook uses when it writes .npz files - confirm.
with np.load(input_file) as data:
    onehot = np.array(data['seqfeatures'])
    # Take the names from the same file so they match the mutated sequences
    seqnames = np.array(data['genenames'])

# Randomly insert variants with a certain probability
variant_probability = 0.1
rng = np.random.default_rng(42)  # For reproducibility
# The indexing below treats axis 1 as sequence position and axis 2 as base channel
n_positions = onehot.shape[1]
n_bases = onehot.shape[2]
n_variants = int(variant_probability * n_positions)
for i in range(onehot.shape[0]):
    # Pick distinct positions, clear them and set one randomly drawn base each.
    # NOTE: the drawn base can equal the original one, so the realised
    # mutation rate is somewhat below variant_probability.
    positions = rng.choice(n_positions, size=n_variants, replace=False)
    new_bases = rng.integers(0, n_bases, size=n_variants)
    onehot[i, positions, :] = 0
    onehot[i, positions, new_bases] = 1

output_path = os.path.splitext(input_file)[0] + f'mutated{variant_probability}.oh.npz'
# Save the one-hot encoded sequences and names to a numpy file
# call the arrays seqfeatures and genenames for compatibility with the rest of the code
np.savez(output_path, seqfeatures = onehot, genenames = seqnames)
print(f"One-hot encoded sequences saved to {output_path}")

## 2. Train distillation model on generated mean predictions

### a. Train on different combinations of generated data

### b. Assess the performance on real data 