In [None]:
import pandas as pd
from grelu.data.preprocess import filter_blacklist, filter_chromosomes
from grelu.variant import filter_variants
import os
from typing import Callable, List, Optional, Sequence, Tuple, Union

import numpy as np
import pandas as pd
import scipy
from einops import rearrange
from torch import Tensor
from torch.utils.data import Dataset

from grelu.data.augment import Augmenter, _split_overall_idx
from grelu.data.utils import _check_multiclass, _create_task_data
from grelu.sequence.format import (
    INDEX_TO_BASE_HASH,
    check_intervals,
    convert_input_type,
    get_input_type,
    indices_to_one_hot,
    strings_to_indices,
)
from grelu.sequence.mutate import mutate
from grelu.sequence.utils import dinuc_shuffle, get_lengths, resize
from grelu.utils import get_aggfunc, get_transform_func

In [None]:
# Load the European-ancestry variant table. EUR.csv is not shipped with this
# notebook; it must be downloaded from the UK Biobank website beforehand.
# NOTE(review): the code below expects at least the columns "chrom", "pos",
# "ref", "alt" — confirm against the downloaded file.
df = pd.read_csv("EUR.csv")

# Convert chromosome names to "chr"-prefixed form. Only values that do not
# already start with "chr" are prefixed, so an input that is already
# UCSC-style does not turn into "chrchr1" (which filter_chromosomes would drop).
chrom_str = df["chrom"].astype(str)
df["chrom"] = chrom_str.where(chrom_str.str.startswith("chr"), "chr" + chrom_str)

# Keep autosomes plus X/Y, then drop variants in hg19 blacklist regions.
variants = filter_chromosomes(df, include="autosomesXY")
variants = filter_blacklist(variants, genome="hg19").reset_index(drop=True)

In [None]:
class VariantLenDataset(Dataset):
    """
    Dataset class to perform inference on sequence variants.

    Items are indexed over (variant, augmentation, allele): every variant yields
    ``n_augmented`` augmented sequences with its REF allele inserted and the same
    number with its ALT allele inserted, each returned one-hot encoded. The total
    length is therefore ``n_seqs * n_augmented * n_alleles``.

    Args:
        variants: pd.DataFrame with columns "chrom", "pos", "ref", "alt".
        seq_len: Uniform expected length (in base pairs) for output sequences
        genome: The name of the genome from which to read sequences.
        rc: If True, sequences will be augmented by reverse complementation. If
            False, they will not be reverse complemented.
        max_seq_shift: Maximum number of bases to shift the sequence for augmentation.
            This is normally a small value (< 10). If 0, sequences will not
            be augmented by shifting.
        frac_mutation: Fraction of bases to randomly mutate for data augmentation.
        n_mutated_seqs: Number of mutated sequences to generate from each input
            sequence for data augmentation.
        protect: A list of positions to protect from mutation. Defaults to
            ``[seq_len // 2]``, the central position.
        seed: Random seed forwarded to the Augmenter.
        augment_mode: Augmentation mode forwarded to the Augmenter (default "serial").
    """

    def __init__(
        self,
        variants: pd.DataFrame,
        seq_len: int,
        genome: Optional[str] = None,
        rc: bool = False,
        max_seq_shift: int = 0,
        frac_mutation: float = 0.0,
        n_mutated_seqs: int = 1,
        protect: Optional[List[int]] = None,
        seed: Optional[int] = None,
        augment_mode: str = "serial",
    ) -> None:
        # Save params
        self.genome = genome
        self.seq_len = seq_len

        # Save augmentation params
        self.rc = rc
        self.max_seq_shift = max_seq_shift
        self.frac_mutated_bases = frac_mutation
        # Convert the mutation fraction into a whole number of bases (truncated).
        self.n_mutated_bases = int(self.frac_mutated_bases * self.seq_len)
        self.n_mutated_seqs = n_mutated_seqs

        # Ingest alleles; each variant contributes exactly two alleles (REF, ALT).
        self._load_alleles(variants)
        self.n_alleles = 2

        # Ingest sequences
        self._load_seqs(variants)
        self.n_seqs = self.seqs.shape[0]

        # Protect central positions for mutation
        if protect is None:
            self.protect = [seq_len // 2]
        else:
            self.protect = protect

        # Create augmenter
        self.augmenter = Augmenter(
            rc=self.rc,
            max_seq_shift=self.max_seq_shift,
            n_mutated_seqs=self.n_mutated_seqs,
            n_mutated_bases=self.n_mutated_bases,
            protect=self.protect,
            seq_len=self.seq_len,
            seed=seed,
            mode=augment_mode,
        )
        self.n_augmented = len(self.augmenter)

    def _load_alleles(self, variants: pd.DataFrame) -> None:
        """Encode the REF and ALT allele strings as indices.

        Accepts either a DataFrame (one variant per row) or a single variant row
        (e.g. a pd.Series). For a single row, ``variants.ref`` / ``variants.alt``
        are scalars, so they are wrapped in a one-element list.
        """
        if isinstance(variants, pd.DataFrame):
            self.ref = strings_to_indices(variants["ref"].tolist())
            self.alt = strings_to_indices(variants["alt"].tolist())
        else:
            # Explicit single-variant path. This replaces a bare `except:`, which
            # also hid genuine errors such as invalid allele characters.
            self.ref = strings_to_indices([variants.ref])
            self.alt = strings_to_indices([variants.alt])

    def _load_seqs(self, variants: pd.DataFrame) -> None:
        """Extract index-encoded genomic sequences centred on each variant."""
        from grelu.variant import variants_to_intervals

        # Pad by max_seq_shift on each side, presumably so that shift
        # augmentation can crop back to seq_len — TODO confirm with Augmenter.
        self.padded_seq_len = self.seq_len + (2 * self.max_seq_shift)
        self.intervals = variants_to_intervals(variants, seq_len=self.padded_seq_len)
        self.seqs = convert_input_type(self.intervals, "indices", genome=self.genome)

    def __len__(self) -> int:
        # Must match the shape passed to _split_overall_idx in __getitem__.
        return self.n_seqs * self.n_augmented * self.n_alleles

    def __getitem__(self, idx: int) -> Tensor:
        """Return the one-hot encoded sequence for flat item index ``idx``."""
        # Get indices
        seq_idx, augment_idx, allele_idx = _split_overall_idx(
            idx, (self.n_seqs, self.n_augmented, self.n_alleles)
        )

        # Extract current sequence and alleles
        seq = self.seqs[seq_idx]

        # Insert the allele: allele_idx 0 -> REF, otherwise -> ALT.
        # NOTE(review): mutate is called without a position; presumably it inserts
        # at the sequence centre — confirm against grelu.sequence.mutate.
        if allele_idx:
            alt = self.alt[seq_idx]
            seq = mutate(seq, alt, input_type="indices")
        else:
            ref = self.ref[seq_idx]
            seq = mutate(seq, ref, input_type="indices")

        # Augment current sequence
        seq = self.augmenter(seq=seq, idx=augment_idx)

        # One-hot encode
        return indices_to_one_hot(seq)

# Build one allele-inserted sequence per filtered variant.
# `grelu` itself is never bound by the `from grelu... import` lines above, so
# import the submodule explicitly; otherwise this cell raises NameError.
import grelu.variant

GENOME = "hg19"
VARIANT_SEQ_LEN = 512

# NOTE(review): variant_to_seqs returns a sequence of which only element [1] is
# kept — presumably (ref_seq, alt_seq), i.e. only ALT-allele sequences are
# collected here. Confirm against the grelu docs.
seq_all = []
for chrom, pos, ref, alt in zip(
    variants["chrom"], variants["pos"], variants["ref"], variants["alt"]
):
    # Read the four fields straight from the columns instead of building a
    # one-row DataFrame per variant. The per-variant print is dropped because it
    # floods the output on large inputs.
    variant_seqs = grelu.variant.variant_to_seqs(
        chrom, pos, ref, alt, GENOME, VARIANT_SEQ_LEN
    )
    seq_all.append(variant_seqs[1])

In [None]:
import pickle

# Cache the raw sequence list so the expensive extraction cell need not be re-run.
with open("variant_list_part1.pickle", "wb") as handle:
    pickle.dump(seq_all, handle, protocol=pickle.HIGHEST_PROTOCOL)

# One sequence per row, in a single "text" column.
df = pd.DataFrame({"text": seq_all})

# Positional split (no shuffling): the first TRAIN_FRAC of rows, in the order of
# `variants`, go to train and the remainder to test.
TRAIN_FRAC = 0.8
split_at = int(len(df) * TRAIN_FRAC)
df_train = df.iloc[:split_at]
df_test = df.iloc[split_at:]

# `index` is a boolean flag in DataFrame.to_csv; False omits the row index.
df_train.to_csv("./ukbb_allinfo_13000000_train.csv", index=False)
df_test.to_csv("./ukbb_allinfo_13000000_test.csv", index=False)