# Preprocessing routine

In [None]:
#Numeric
import numpy as np
import pandas as pd
#DL
import keras
import tensorflow as tf
#System
from pymongo import MongoClient
#Tokenizers
import sentencepiece as spm

## 📁 Define Paths and Database Parameters

We define variables for:

- MongoDB database and collection names.
- BPE tokenizer model.
- A csv with the required train, tune, test partition IDs.

In [None]:
# Placeholder settings: replace every "------" with a real value before running.

# Path prefix of the SentencePiece model; DataGenerator appends ".model" to it.
tokenizer_model_path = "------"
# CSV read with pandas; its 'partition' and 'ID' columns are used further down.
partitions_path = "------"

# MongoDB database and collection that DataGenerator reads gene documents from.
db_name = "------"
collection_name = "------"

## Data Generator

This section shows:

- Implementing a custom `DataGenerator` class that inherits from `keras.utils.Sequence` to efficiently load and preprocess data in batches for model training.
- Instantiating the data generator for the training set and iterating through batches to generate input-output pairs.
- Saving the `DataGenerator` class as a utility Python file for reuse.

This approach enables scalable and reproducible data preprocessing for downstream machine learning tasks.

DataGenerator class to yield training batches for embedding models with the following features:

- **`db_name`**: Name of the MongoDB database holding genotype data.  
- **`collection_name`**: Name of the collection with gene sequence records.  
- **`organism_IDs`**: List of organism IDs sampled randomly per batch to pair with gene IDs.  
- **`tokenizer_path`**: Path prefix to the SentencePiece tokenizer model (without `.model` extension).  
- **`shuffle`**: Boolean flag to shuffle gene IDs at the end of each epoch to improve training randomness.  
- **`batch_size`**: Number of gene samples processed per batch.  
- **`context_size`**: Number of tokens in each non-overlapping window used to extract context words around an anchor word in sequences.  
- **`negative_samples`**: Number of negative context samples generated per positive anchor-context pair.  
- **`vocab_size`**: Total vocabulary size used to sample negative examples uniformly at random.  
- **`max_pair`**: Number of anchor-context pairs randomly drawn (with replacement) from each gene's sequence per batch.

Key methods and behaviors:

- **`__len__`**: Calculates how many batches fit in one epoch given dataset size and batch size.  
- **`__getitem__`**: Retrieves a batch of data by fetching gene sequences from MongoDB, tokenizing, and generating training pairs.  
- **`on_epoch_end`**: Shuffles gene IDs if enabled, to ensure different ordering across epochs.  
- **`__data_generation`**:  
  - Randomly pairs each gene with an organism ID for batch diversity.  
  - Queries the database for haplotype sequences of these pairs.  
  - Tokenizes haplotype sequences using the previously trained BPE model.  
  - Splits sequences into fixed-length context windows.  
  - Extracts anchor words (middle token of each context window) and corresponding context words.  
  - Randomly draws `max_pair` of these pairs, sampling with replacement.  
  - Generates negative samples by sampling random tokens from the vocabulary.  
  - Combines positive and negative samples to form the final input (`X`) and label (`Y`) batches.

This class is tailored to efficiently produce training data for the GeneticPieces2Vec model.


In [None]:
class DataGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` that builds anchor/context training batches from MongoDB.

    For every batch a slice of gene IDs is paired with randomly drawn organism
    IDs, the matching documents are fetched, both haplotypes of each document
    are tokenized with SentencePiece, and (anchor, context, negatives) rows are
    generated from fixed-length token windows.

    Args:
        db_name: Name of the MongoDB database.
        collection_name: Name of the collection. Documents are accessed through
            their ``gene_ID``, ``organism_ID``, ``haplotype_1`` and
            ``haplotype_2`` fields.
        organism_IDs: Organism IDs to draw from; each drawn ID is converted to
            ``str`` before it is used in the query.
        tokenizer_path: Path of the SentencePiece model without the ``.model``
            suffix.
        shuffle: When truthy, gene IDs are shuffled in place on construction
            and on every ``on_epoch_end`` call.
        batch_size: Number of gene IDs per batch.
        context_size: Number of tokens per (non-overlapping) window.
        negative_samples: Number of random token IDs appended to each pair.
        vocab_size: Exclusive upper bound of the random token IDs, which are
            drawn uniformly from ``[1, vocab_size)``.
        max_pair: Number of pairs drawn, with replacement, per fetched document.
        mongo_uri: MongoDB connection URI. Defaults to the local server that
            was previously hard-coded.
        **kwargs: Forwarded to ``keras.utils.Sequence.__init__``.
    """

    def __init__(self, db_name, collection_name, organism_IDs, tokenizer_path, shuffle, batch_size, context_size, negative_samples, vocab_size, max_pair, mongo_uri="mongodb://localhost:27017/", **kwargs):
        super().__init__(**kwargs)
        self.organism_IDs = organism_IDs
        self.shuffle = shuffle
        self.batch_size = batch_size
        self.context_size = context_size
        self.negative_samples = negative_samples
        self.vocab_size = vocab_size
        self.max_pair = max_pair

        self.tokenizer = spm.SentencePieceProcessor()
        self.tokenizer.load(f'{tokenizer_path}.model')

        self.client = MongoClient(mongo_uri)
        self.db = self.client[db_name]
        self.collection = self.db[collection_name]
        # Every distinct gene ID in the collection defines the epoch.
        self.genes_IDs = self.collection.distinct('gene_ID')

        # Apply the initial shuffle (does nothing when shuffle is falsy).
        self.on_epoch_end()

    def __len__(self):
        """Return the number of full batches per epoch; a partial tail is dropped."""
        return len(self.genes_IDs) // self.batch_size

    def __getitem__(self, index):
        """Return the ``(X, Y)`` batch built from the ``index``-th slice of gene IDs.

        Raises:
            IndexError: If the slice selected by ``index`` contains no gene IDs.
        """
        start = index * self.batch_size
        batch_genes_IDs = self.genes_IDs[start:start + self.batch_size]
        # An empty slice would produce an empty "$or" clause, which MongoDB
        # rejects; fail early with the conventional sequence error instead.
        if len(batch_genes_IDs) == 0:
            raise IndexError(f'batch index {index} is out of range')
        return self.__data_generation(batch_genes_IDs)

    def on_epoch_end(self):
        """Shuffle the gene IDs in place when shuffling is enabled."""
        if self.shuffle:
            np.random.shuffle(self.genes_IDs)

    def _gene_pairs(self, gene):
        """Build the training rows for one document.

        Returns:
            An array of shape ``(max_pair, 2 + negative_samples)`` whose columns
            are anchor, true context and the negative token IDs, or ``None``
            when the document yields no anchor/context pair.
        """
        seq_1 = np.array(self.tokenizer.encode_as_ids(gene['haplotype_1'].upper()))
        seq_2 = np.array(self.tokenizer.encode_as_ids(gene['haplotype_2'].upper()))

        # Both haplotypes are truncated to the same whole number of windows.
        n_contexts = min(len(seq_1), len(seq_2)) // self.context_size
        if n_contexts == 0 or self.context_size < 2:
            # No complete window, or windows without any context token: there
            # is nothing to sample from (randint would raise on an empty range).
            return None

        usable = n_contexts * self.context_size
        windows = np.concatenate(
            (
                seq_1[:usable].reshape(n_contexts, self.context_size),
                seq_2[:usable].reshape(n_contexts, self.context_size),
            ),
            axis=0,
        )

        # The anchor is the middle token of each window; every other token of
        # that window is one of its context words.
        centre = self.context_size // 2
        anchor_words = np.repeat(windows[:, centre], self.context_size - 1)
        context_words = np.delete(windows, centre, axis=1).flatten()
        pairs = np.column_stack((anchor_words, context_words))

        # Exactly max_pair rows are drawn, with replacement.
        chosen = np.random.randint(0, high=len(pairs), size=self.max_pair)
        reduced_pairs = pairs[chosen, :]

        negative_context = np.random.randint(
            1, high=self.vocab_size, size=(len(reduced_pairs), self.negative_samples)
        )
        return np.column_stack((reduced_pairs, negative_context))

    def __data_generation(self, IDs):
        """Fetch the documents for ``IDs`` and assemble the batch.

        Returns:
            ``(X, Y)`` where ``X`` is ``(anchors, candidates)``: ``anchors`` has
            shape ``(N,)`` and ``candidates`` has shape
            ``(N, 1 + negative_samples)`` with the true context in column 0.
            ``Y`` has the same shape as ``candidates`` and is 1 in column 0 and
            0 elsewhere.
        """
        n_columns = 2 + self.negative_samples

        # One random organism per requested gene.
        organisms = [str(organism) for organism in np.random.choice(self.organism_IDs, size=len(IDs))]
        query = {"$or": [
            {'gene_ID': gene, 'organism_ID': organism}
            for gene, organism in zip(IDs, organisms)
        ]}

        # Collect per-document rows and concatenate once, rather than
        # re-copying a growing array on every iteration.
        rows = []
        for gene in self.collection.find(query):
            gene_pairs = self._gene_pairs(gene)
            if gene_pairs is not None:
                rows.append(gene_pairs)

        # Batches were float64 before (the rows used to be appended to a
        # float zeros array); keep that dtype for existing consumers.
        if rows:
            total_pairs = np.concatenate(rows, axis=0).astype(np.float64)
        else:
            total_pairs = np.zeros((0, n_columns))

        X = (total_pairs[:, 0], total_pairs[:, 1:])
        Y = np.zeros([len(X[0]), 1 + self.negative_samples])
        Y[:, 0] = 1
        return X, Y

In [None]:
# Load the partition table and keep the IDs of the rows labelled 'Train'.
partitions = pd.read_csv(partitions_path)
train_rows = partitions[partitions['partition'] == 'Train']
training_IDs = list(train_rows['ID'])

In [None]:
# Generator over the training organisms; gene order is kept fixed (no shuffling).
train_gen = DataGenerator(
    db_name=db_name,
    collection_name=collection_name,
    organism_IDs=training_IDs,
    tokenizer_path=tokenizer_model_path,
    shuffle=False,
    batch_size=32,
    context_size=5,
    negative_samples=5,
    vocab_size=9000,
    max_pair=1000,
)

In [None]:
# Build every batch once as a smoke test, overwriting a one-line progress counter.
n_batches = len(train_gen)
for i in range(n_batches):
    X, Y = train_gen[i]
    print(f'{i+1} of {n_batches}', end='\r')

## Save data Generator

In [None]:
%%writefile /home/jmalagont/Documentos/GWord2Vec/algorithms/utils/DataGenerator.py

#Numeric
import numpy as np
import pandas as pd
#DL
import keras
import tensorflow as tf
#System
from pymongo import MongoClient
#Tokenizers
import sentencepiece as spm


class DataGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` that builds anchor/context training batches from MongoDB.

    For every batch a slice of gene IDs is paired with randomly drawn organism
    IDs, the matching documents are fetched, both haplotypes of each document
    are tokenized with SentencePiece, and (anchor, context, negatives) rows are
    generated from fixed-length token windows.

    Args:
        db_name: Name of the MongoDB database.
        collection_name: Name of the collection. Documents are accessed through
            their ``gene_ID``, ``organism_ID``, ``haplotype_1`` and
            ``haplotype_2`` fields.
        organism_IDs: Organism IDs to draw from; each drawn ID is converted to
            ``str`` before it is used in the query.
        tokenizer_path: Path of the SentencePiece model without the ``.model``
            suffix.
        shuffle: When truthy, gene IDs are shuffled in place on construction
            and on every ``on_epoch_end`` call.
        batch_size: Number of gene IDs per batch.
        context_size: Number of tokens per (non-overlapping) window.
        negative_samples: Number of random token IDs appended to each pair.
        vocab_size: Exclusive upper bound of the random token IDs, which are
            drawn uniformly from ``[1, vocab_size)``.
        max_pair: Number of pairs drawn, with replacement, per fetched document.
        mongo_uri: MongoDB connection URI. Defaults to the local server that
            was previously hard-coded.
        **kwargs: Forwarded to ``keras.utils.Sequence.__init__``.
    """

    def __init__(self, db_name, collection_name, organism_IDs, tokenizer_path, shuffle, batch_size, context_size, negative_samples, vocab_size, max_pair, mongo_uri="mongodb://localhost:27017/", **kwargs):
        super().__init__(**kwargs)
        self.organism_IDs = organism_IDs
        self.shuffle = shuffle
        self.batch_size = batch_size
        self.context_size = context_size
        self.negative_samples = negative_samples
        self.vocab_size = vocab_size
        self.max_pair = max_pair

        self.tokenizer = spm.SentencePieceProcessor()
        self.tokenizer.load(f'{tokenizer_path}.model')

        self.client = MongoClient(mongo_uri)
        self.db = self.client[db_name]
        self.collection = self.db[collection_name]
        # Every distinct gene ID in the collection defines the epoch.
        self.genes_IDs = self.collection.distinct('gene_ID')

        # Apply the initial shuffle (does nothing when shuffle is falsy).
        self.on_epoch_end()

    def __len__(self):
        """Return the number of full batches per epoch; a partial tail is dropped."""
        return len(self.genes_IDs) // self.batch_size

    def __getitem__(self, index):
        """Return the ``(X, Y)`` batch built from the ``index``-th slice of gene IDs.

        Raises:
            IndexError: If the slice selected by ``index`` contains no gene IDs.
        """
        start = index * self.batch_size
        batch_genes_IDs = self.genes_IDs[start:start + self.batch_size]
        # An empty slice would produce an empty "$or" clause, which MongoDB
        # rejects; fail early with the conventional sequence error instead.
        if len(batch_genes_IDs) == 0:
            raise IndexError(f'batch index {index} is out of range')
        return self.__data_generation(batch_genes_IDs)

    def on_epoch_end(self):
        """Shuffle the gene IDs in place when shuffling is enabled."""
        if self.shuffle:
            np.random.shuffle(self.genes_IDs)

    def _gene_pairs(self, gene):
        """Build the training rows for one document.

        Returns:
            An array of shape ``(max_pair, 2 + negative_samples)`` whose columns
            are anchor, true context and the negative token IDs, or ``None``
            when the document yields no anchor/context pair.
        """
        seq_1 = np.array(self.tokenizer.encode_as_ids(gene['haplotype_1'].upper()))
        seq_2 = np.array(self.tokenizer.encode_as_ids(gene['haplotype_2'].upper()))

        # Both haplotypes are truncated to the same whole number of windows.
        n_contexts = min(len(seq_1), len(seq_2)) // self.context_size
        if n_contexts == 0 or self.context_size < 2:
            # No complete window, or windows without any context token: there
            # is nothing to sample from (randint would raise on an empty range).
            return None

        usable = n_contexts * self.context_size
        windows = np.concatenate(
            (
                seq_1[:usable].reshape(n_contexts, self.context_size),
                seq_2[:usable].reshape(n_contexts, self.context_size),
            ),
            axis=0,
        )

        # The anchor is the middle token of each window; every other token of
        # that window is one of its context words.
        centre = self.context_size // 2
        anchor_words = np.repeat(windows[:, centre], self.context_size - 1)
        context_words = np.delete(windows, centre, axis=1).flatten()
        pairs = np.column_stack((anchor_words, context_words))

        # Exactly max_pair rows are drawn, with replacement.
        chosen = np.random.randint(0, high=len(pairs), size=self.max_pair)
        reduced_pairs = pairs[chosen, :]

        negative_context = np.random.randint(
            1, high=self.vocab_size, size=(len(reduced_pairs), self.negative_samples)
        )
        return np.column_stack((reduced_pairs, negative_context))

    def __data_generation(self, IDs):
        """Fetch the documents for ``IDs`` and assemble the batch.

        Returns:
            ``(X, Y)`` where ``X`` is ``(anchors, candidates)``: ``anchors`` has
            shape ``(N,)`` and ``candidates`` has shape
            ``(N, 1 + negative_samples)`` with the true context in column 0.
            ``Y`` has the same shape as ``candidates`` and is 1 in column 0 and
            0 elsewhere.
        """
        n_columns = 2 + self.negative_samples

        # One random organism per requested gene.
        organisms = [str(organism) for organism in np.random.choice(self.organism_IDs, size=len(IDs))]
        query = {"$or": [
            {'gene_ID': gene, 'organism_ID': organism}
            for gene, organism in zip(IDs, organisms)
        ]}

        # Collect per-document rows and concatenate once, rather than
        # re-copying a growing array on every iteration.
        rows = []
        for gene in self.collection.find(query):
            gene_pairs = self._gene_pairs(gene)
            if gene_pairs is not None:
                rows.append(gene_pairs)

        # Batches were float64 before (the rows used to be appended to a
        # float zeros array); keep that dtype for existing consumers.
        if rows:
            total_pairs = np.concatenate(rows, axis=0).astype(np.float64)
        else:
            total_pairs = np.zeros((0, n_columns))

        X = (total_pairs[:, 0], total_pairs[:, 1:])
        Y = np.zeros([len(X[0]), 1 + self.negative_samples])
        Y[:, 0] = 1
        return X, Y