In [None]:
import numpy as np
from datasets import load_from_disk, Dataset
import torch
from torch import Tensor
import os
import sys
import pytz
from tqdm.notebook import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import datetime
import time
import pickle
import random
import subprocess
from collections import defaultdict
import gc

from transformers import BertConfig, BertForMaskedLM, TrainingArguments, TrainerCallback, Trainer, BertModel, BertPreTrainedModel
from transformers.modeling_outputs import MaskedLMOutput
from transformers.models.bert.modeling_bert import BertLMPredictionHead, BertOnlyMLMHead, BertPredictionHeadTransform
from transformers.activations import ACT2FN
from typing import List, Optional, Tuple, Union
import torch.nn.functional as F
from geneformer.pretrainer import token_dictionary
from geneformer import GeneformerPretrainer

'''Import Customized Model Structures'''
from GF_CAB import CustomBertForMaskedLM

os.environ["NCCL_DEBUG"] = "INFO"
os.environ["OMPI_MCA_opal_cuda_support"] = "true"
os.environ["CONDA_OVERRIDE_GLIBC"] = "2.56"

seed_num = 42
random.seed(seed_num)
np.random.seed(seed_num)
seed_val = 42
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

# set local time/directories
timezone = pytz.timezone("Asia/Riyadh")
rootdir = os.getcwd() + "/Self_train"


corpus_dir = "Pretrain_data"
with open(corpus_dir + "/token_dictionary.pkl", "rb") as fp:
    token_dictionary = pickle.load(fp)

len_vocabulary = len(token_dictionary)


In [2]:
# set model parameters
# model type
model_type = "bert"
# max input size
max_input_size = 2**11  # 2048
# number of layers
num_layers = 6
# number of attention heads
num_attn_heads = 4
# number of embedding dimensions
num_embed_dim = 256
# intermediate size
intermed_size = num_embed_dim * 2
# activation function
activ_fn = "relu"
# initializer range, layer norm, dropout
initializer_range = 0.02
layer_norm_eps = 1e-12
attention_probs_dropout_prob = 0.02
hidden_dropout_prob = 0.02

# set training parameters
# total number of examples in Genecorpus-30M after QC filtering:
num_examples = 27_406_208
subset = 1_000_000
# number gpus
num_gpus = 8
# batch size for training and eval
geneformer_batch_size = 10
# max learning rate
max_lr = 1e-3
# learning schedule
lr_schedule_fn = "linear"
# warmup steps
warmup_steps = 10_000
# number of epochs
epochs = 3
# optimizer
optimizer = "adamw"
# weight_decay
weight_decay = 0.001


# output directories
current_date = datetime.datetime.now(tz=timezone)
datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}_{current_date.strftime('%X').replace(':','')}"
run_name = f"{datestamp}_geneformer_30M_L{num_layers}_emb{num_embed_dim}_SL{max_input_size}_E{epochs}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_O{optimizer}_DS{num_gpus}"
training_output_dir = f"{rootdir}/models/{run_name}/"
logging_dir = f"{rootdir}/runs/{run_name}/"
model_output_dir = os.path.join(training_output_dir, "models/")


model_output_file = os.path.join(model_output_dir, "pytorch_model.bin")
if os.path.isfile(model_output_file) is True:
    raise Exception("Model already saved to this directory.")

# make training and model output directories
os.makedirs(training_output_dir, exist_ok=True)
os.makedirs(model_output_dir, exist_ok=True)

In [None]:
training_path = 'model path'
cosine_val = str(5)
scale_val  = str(1)
model = CustomBertForMaskedLM.from_pretrained(training_path)

In [None]:
path_to_dataset = "Pretrain_data/subset_5M_genecorpus.dataset"
path_to_dataset = os.path.expanduser(path_to_dataset)

# Load the dataset from disk
genecorpus = load_from_disk(path_to_dataset)

sample_size = 50_000

test_dataset = genecorpus.shuffle(seed=666).select(list(range(sample_size)))
print(test_dataset)

batch_size = 100
batch_num = sample_size // batch_size

Dataset({
    features: ['input_ids'],
    num_rows: 50000
})


In [None]:
training_args = {
    "learning_rate": max_lr,
    "do_train": False,
    "do_eval": False,
    "group_by_length": True,
    "length_column_name": "length",
    "disable_tqdm": False,
    "lr_scheduler_type": lr_schedule_fn,
    "warmup_steps": warmup_steps,
    "weight_decay": weight_decay,
    "per_device_train_batch_size": geneformer_batch_size,
    "num_train_epochs": epochs,
    "save_strategy": "steps",
    "save_steps": np.floor(subset / geneformer_batch_size / 8),  # 8 saves per epoch
    "logging_steps": 1000,
    "output_dir": training_output_dir,
    "logging_dir": logging_dir,
}

training_args = TrainingArguments(**training_args)

# define the trainer
trainer = GeneformerPretrainer(
    model=model,
    args=training_args,
    train_dataset=None,
    example_lengths_file="Pretrain_data/sub_5M_genecorpus_30M_2048_lengths.pkl",
    token_dictionary=token_dictionary,
)

In [None]:
# --- Avoid memory blow-up ---
all_labels_pre = []
all_labels_gt = []

for i in tqdm(range(batch_num), colour="purple", desc="Batch Prediction"):
    start_index = i * batch_size
    end_index = start_index + batch_size
    mini_batch = test_dataset.select(range(start_index, end_index))

    # Predict without keeping trainer caches
    mini_predictions = trainer.predict(mini_batch)

    predictions_batch = np.argmax(mini_predictions.predictions, axis=-1).astype("int32")
    all_labels_gt.append(mini_predictions.label_ids)
    all_labels_pre.append(predictions_batch)
    # Free memory
    del mini_predictions
    gc.collect()

In [7]:
def build_final_arrays(all_labels_pre, all_labels_gt, test_dataset):
    """
    Build final concatenated arrays for input_ids, labels_pre, and all_labels_gt
    Works directly on precomputed predictions (no logits kept).
    """
    print('Building final concatenated arrays...')
    
    max_dim2 = max(arr.shape[1] for arr in all_labels_pre)
    total_samples = sum(arr.shape[0] for arr in all_labels_pre)
    
    print(f'Max sequence length: {max_dim2}')
    print(f'Total samples: {total_samples}')
    print(f'Number of batches: {len(all_labels_pre)}')
    
    # Pre-allocate
    labels_pre_final = np.full((total_samples, max_dim2), -100, dtype=np.int32)
    all_labels_gt_final = np.full((total_samples, max_dim2), -100, dtype=all_labels_gt[0].dtype)
    
    # Fill arrays
    current_idx = 0
    for batch_idx in tqdm(range(len(all_labels_pre)), desc="Processing batches"):
        preds = all_labels_pre[batch_idx]
        labels = all_labels_gt[batch_idx]
        
        batch_size, seq_len = preds.shape
        labels_pre_final[current_idx:current_idx+batch_size, :seq_len] = preds
        all_labels_gt_final[current_idx:current_idx+batch_size, :seq_len] = labels
        current_idx += batch_size
        
        del preds, labels
        gc.collect()
    
    # Process input_ids
    input_ids_list = [np.array(seq) for seq in test_dataset["input_ids"]]
    max_input_dim = max(arr.shape[0] for arr in input_ids_list)
    input_ids_final = np.zeros((len(input_ids_list), max_input_dim), dtype=np.int32)
    for i, arr in enumerate(tqdm(input_ids_list, desc="Processing input_ids")):
        input_ids_final[i, :arr.shape[0]] = arr
    
    del input_ids_list
    gc.collect()
    
    print(f"\nFinal array shapes:")
    print(f"input_ids: {input_ids_final.shape}")
    print(f"labels_pre: {labels_pre_final.shape}")
    print(f"all_labels_gt: {all_labels_gt_final.shape}")
    
    return input_ids_final, labels_pre_final, all_labels_gt_final


input_ids, labels_pre, all_labels_gt = build_final_arrays(all_labels_pre, all_labels_gt, test_dataset)

Building final concatenated arrays...
Max sequence length: 2048
Total samples: 50000
Number of batches: 500


Processing batches:   0%|          | 0/500 [00:00<?, ?it/s]

Processing input_ids:   0%|          | 0/50000 [00:00<?, ?it/s]


Final array shapes:
input_ids: (50000, 2048)
labels_pre: (50000, 2048)
all_labels_gt: (50000, 2048)


In [8]:
# ------------------------------------------
# Vectorized repeat ratio computation
# ------------------------------------------
def compute_repeat_ratios(input_ids, labels_pre, all_labels_gt, return_gene_lists=False):
    """
    Compute repeat ratios with vectorized operations.
    Optionally also return lists of repeating genes and counts (slower).
    """

    num_samples, seq_len = labels_pre.shape
    vocab_size = input_ids.max() + 1  # assume contiguous ids

    # Masks
    mask = all_labels_gt != -100       # masked positions
    unmask = all_labels_gt == -100     # unmasked positions

    # Predictions only at masked positions
    seq_pre = np.where(mask, labels_pre, -100)
    seq_unmask = np.where(unmask, input_ids, -100)

    # Number of predictions per sequence
    masked_counts = mask.sum(axis=1)

    # --- Count predictions per sequence (dense [num_samples, vocab_size]) ---
    pred_counts = np.zeros((num_samples, vocab_size), dtype=np.int32)
    for i in range(num_samples):
        vals, counts = np.unique(seq_pre[i][seq_pre[i] != -100], return_counts=True)
        pred_counts[i, vals] = counts

    # --- Presence of unmasked genes ---
    unmask_presence = np.zeros((num_samples, vocab_size), dtype=bool)
    for i in range(num_samples):
        vals = np.unique(seq_unmask[i][seq_unmask[i] != -100])
        unmask_presence[i, vals] = True

    # --- Compute repeats ---
    # Repeats with unmasked
    repeat_with_unmasked_count = (pred_counts * unmask_presence).sum(axis=1)

    # Repeats within masked (counts > 1, subtract 1 per gene)
    repeat_within_masked_count = np.where(pred_counts > 1, pred_counts - 1, 0).sum(axis=1)

    # Total
    total_repeat_count = repeat_with_unmasked_count + repeat_within_masked_count

    # --- Ratios ---
    repeat_ratio_unmasked = np.divide(
        repeat_with_unmasked_count, masked_counts, out=np.zeros_like(repeat_with_unmasked_count, dtype=float), where=masked_counts > 0
    ) * 100

    repeat_ratio_masked = np.divide(
        repeat_within_masked_count, masked_counts, out=np.zeros_like(repeat_within_masked_count, dtype=float), where=masked_counts > 0
    ) * 100

    repeat_ratio_overall = np.divide(
        total_repeat_count, masked_counts, out=np.zeros_like(total_repeat_count, dtype=float), where=masked_counts > 0
    ) * 100

    if not return_gene_lists:
        return repeat_ratio_unmasked, repeat_ratio_masked, repeat_ratio_overall

    # ------------------------------------------
    # Slower part: actual repeating genes & counts
    # (requires Python loop because of variable lengths)
    # ------------------------------------------
    repeating_gene_list = []
    repeating_count_list = []

    for i in tqdm(range(num_samples), desc="Collecting gene lists"):
        vals, counts = np.nonzero(pred_counts[i])[0], pred_counts[i][pred_counts[i] > 0]

        # Genes repeated with unmasked
        genes_repeat_with_unmasked = vals[unmask_presence[i, vals]]
        repeat_with_unmasked_count = counts[unmask_presence[i, vals]]

        # Genes repeated within masked
        mask_repeats = (counts > 1) & (~unmask_presence[i, vals])
        genes_repeat_within_masked = vals[mask_repeats]
        repeat_within_masked_count = counts[mask_repeats] - 1

        # Concatenate
        seq_total_repeated_gene = np.concatenate((genes_repeat_with_unmasked, genes_repeat_within_masked), axis=0)
        seq_total_repeated_count = np.concatenate((repeat_with_unmasked_count, repeat_within_masked_count), axis=0)

        repeating_gene_list.append(seq_total_repeated_gene)
        repeating_count_list.append(seq_total_repeated_count)

    return (repeat_ratio_unmasked, repeat_ratio_masked, repeat_ratio_overall,
            repeating_gene_list, repeating_count_list)


# ------------------------------------------
# Example usage
# ------------------------------------------
# Suppose you already have numpy arrays:
# input_ids, labels_pre, all_labels_gt
# all with shape [num_samples, seq_len]

print("Computing repeat ratios (vectorized)...")
repeat_ratio_unmasked, repeat_ratio_masked, repeat_ratio_overall = compute_repeat_ratios(
    input_ids, labels_pre, all_labels_gt, return_gene_lists=False
)

print("\nRepeat ratio stats:")
print(f"Unmasked overlap mean: {repeat_ratio_unmasked.mean():.2f}%")
print(f"Masked repetition mean: {repeat_ratio_masked.mean():.2f}%")
print(f"Overall repetition mean: {repeat_ratio_overall.mean():.2f}%")

# If you also want the actual gene lists per sequence:
# repeat_ratio_unmasked, repeat_ratio_masked, repeat_ratio_overall, \
#     repeating_gene_list, repeating_count_list = compute_repeat_ratios(
#         input_ids, labels_pre, all_labels_gt, return_gene_lists=True
#     )


Computing repeat ratios (vectorized)...

Repeat ratio stats:
Unmasked overlap mean: 4.49%
Masked repetition mean: 6.09%
Overall repetition mean: 10.58%


In [9]:
uniqueness_list = []
for idx in tqdm(range(len(labels_pre))):
    masked_indices = (all_labels_gt[idx] != -100).nonzero()
    if len(masked_indices[0]) != 0:
        masked_pres = labels_pre[idx][masked_indices]
        unique_genes, counts = np.unique(masked_pres, return_counts=True)
        unique_pre_counts = np.sum(counts == 1)
        uniqueness = unique_pre_counts / counts.sum()
    else:
        uniqueness = 0
    uniqueness_list.append(uniqueness)


  0%|          | 0/50000 [00:00<?, ?it/s]

In [10]:
def compute_sequence_accuracy(input_ids, labels_pre, all_labels_gt):
    """
    Vectorized computation of per-sequence accuracy (% correct at masked positions).
    """
    # Masked positions
    mask = all_labels_gt != -100  # True where labels are valid
    
    # True labels at masked positions
    seq_mask = np.where(mask, all_labels_gt, -100)
    
    # Predictions at masked positions
    seq_pre = np.where(mask, labels_pre, -100)
    
    # Correct predictions (boolean array)
    correct = (seq_pre == seq_mask) & mask
    
    # Per-sequence counts
    correct_counts = correct.sum(axis=1)
    total_counts = mask.sum(axis=1)
    
    # Accuracy: correct / total (handle division by zero → 1.0 when no masked positions)
    seq_acc = np.divide(
        correct_counts, total_counts,
        out=np.ones_like(correct_counts, dtype=float), where=total_counts > 0
    ) * 100
    
    return seq_acc

print("Computing per-sequence accuracy (vectorized)...")
seq_acc_list = compute_sequence_accuracy(input_ids, labels_pre, all_labels_gt)
print(f"Mean accuracy: {seq_acc_list.mean():.2f}%")
print(f"Std. accuracy: {seq_acc_list.std():.2f}%")

Computing per-sequence accuracy (vectorized)...
Mean accuracy: 24.95%
Std. accuracy: 15.26%


In [11]:
df = pd.DataFrame([seq_acc_list, repeat_ratio_overall, repeat_ratio_masked, repeat_ratio_unmasked, uniqueness_list], index = ['Accuracy', 'Repeat Ratio', 'Repeat Mask', 'Repeat Unmask', 'Unique']).T
file_path = "/ibex/user/chenj0i/Geneformer/" + "GF_Sim_Cosine_" + cosine_val + "_" + scale_val + "Mmodel_50K.csv"
df.to_csv(file_path)
(df)

Unnamed: 0,Accuracy,Repeat Ratio,Repeat Mask,Repeat Unmask,Unique
0,29.568106,20.265781,7.641196,12.624585,0.860465
1,20.000000,0.000000,0.000000,0.000000,1.000000
2,10.000000,10.000000,10.000000,0.000000,0.800000
3,46.666667,13.333333,13.333333,0.000000,0.800000
4,14.285714,0.000000,0.000000,0.000000,1.000000
...,...,...,...,...,...
49995,23.076923,10.256410,10.256410,0.000000,0.820513
49996,17.525773,21.649485,8.247423,13.402062,0.835052
49997,28.571429,0.000000,0.000000,0.000000,1.000000
49998,50.000000,0.000000,0.000000,0.000000,1.000000


In [12]:
df.mean()

Accuracy         24.945316
Repeat Ratio     10.582862
Repeat Mask       6.091604
Repeat Unmask     4.491259
Unique            0.883583
dtype: float64