In [1]:
!pip install pandas numpy pyarrow requests tqdm pyranges --quiet

# Download GDC client
!wget https://gdc.cancer.gov/files/public/file/gdc-client_v1.6.1_Ubuntu_x64.zip -O gdc.zip -q
!unzip -o gdc.zip > /dev/null
!chmod +x gdc-client

print("Installed Python deps and downloaded gdc-client.")

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.6 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━[0m [32m1.2/1.6 MB[0m [31m38.9 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m30.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m78.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.7/2.7 MB[0m [31m102.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for sorted_nearest (pyproject.toml) ... [?25l[?25hdone
Installed Python deps and downloaded gdc-client.


In [2]:
from google.colab import drive
drive.mount('/content/drive')

print(" Google Drive mounted at /content/drive")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
 Google Drive mounted at /content/drive


In [3]:
from pathlib import Path
import os
import pandas as pd
from transformers import (
    BertConfig,
    BertForMaskedLM,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
    TrainerCallback
)

# Main project root on Google Drive
PROJECT_ROOT = Path("/content/drive/MyDrive/bdh_challenge_2025_data")
PROJECT_ROOT.mkdir(exist_ok=True)

# Data directories

# Generic data directory
DATA_DIR = PROJECT_ROOT / "data"
DATA_DIR.mkdir(exist_ok=True)

# Directory where raw TCGA STAR count files will be stored
RNA_DIR = PROJECT_ROOT / "tcga_rna"
RNA_DIR.mkdir(exist_ok=True)

# Directory where processed matrices, tokenized data, embeddings
PROCESSED_DIR = PROJECT_ROOT / "processed"
PROCESSED_DIR.mkdir(exist_ok=True)

print("Project root :", PROJECT_ROOT)
print("DATA_DIR     :", DATA_DIR)
print("RNA_DIR      :", RNA_DIR)
print("PROCESSED_DIR:", PROCESSED_DIR)


Project root : /content/drive/MyDrive/bdh_challenge_2025_data
DATA_DIR     : /content/drive/MyDrive/bdh_challenge_2025_data/data
RNA_DIR      : /content/drive/MyDrive/bdh_challenge_2025_data/tcga_rna
PROCESSED_DIR: /content/drive/MyDrive/bdh_challenge_2025_data/processed


In [4]:
ALIGNED_PATH = PROCESSED_DIR / "tcga_tpm_like_log10_bulkrnabert_aligned.parquet"
expr_aligned = pd.read_parquet(ALIGNED_PATH)

In [5]:
expr_aligned.head()

gene_id,ENSG00000000003,ENSG00000000005,ENSG00000000419,ENSG00000000457,ENSG00000000460,ENSG00000000938,ENSG00000000971,ENSG00000001036,ENSG00000001084,ENSG00000001167,...,ENSG00000284519,ENSG00000284532,ENSG00000284535,ENSG00000284543,ENSG00000284557,ENSG00000284564,ENSG00000284574,ENSG00000284587,ENSG00000284595,ENSG00000284596
sample_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
TCGA-P5-A5EY,1.681527,0.063817,1.141174,0.877901,0.408766,0.778577,0.989684,1.165323,1.534292,1.396228,...,0.0,0.0,0.0,0.008509,0.0,0.0,0.0,0.0,0.0,0.0
TCGA-AR-A1AW,1.605939,0.088174,1.567791,1.419358,1.132621,1.467036,2.035625,1.47768,1.519647,1.80023,...,0.0,0.0,0.0,0.254434,0.0,0.0,0.0,0.0,0.0,0.0
TCGA-E9-A1N5,1.848499,0.367775,1.398195,1.432945,0.981104,1.023865,1.629135,1.540959,1.53839,1.612078,...,0.0,0.0,0.0,0.189368,0.0,0.0,0.0,0.0,0.0,0.0
TCGA-97-7941,2.254924,0.079794,1.397967,1.311801,0.662706,1.375285,1.92909,1.767373,1.938894,1.489701,...,0.0,0.0,0.0,0.393385,0.0,0.0,0.0,0.0,0.0,0.0
TCGA-93-7347,1.395531,0.035987,1.26112,1.125617,0.640414,1.382043,1.897127,1.634034,1.067374,1.466133,...,0.0,0.0,0.0,0.241292,0.0,0.0,0.0,0.0,0.0,0.0


In [6]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from transformers import BertConfig, BertForMaskedLM, PreTrainedModel, BertModel


CONFIG = {
    "N_genes": 19062,
    "n_embedding": 256,
    "n_layers": 4,
    "n_heads": 8,
    "n_bins": 64,
    "vocab_size": 70,
    "mask_prob": 0.15,
    "batch_size_eff": 3e6,
    "batch_size": 16,
    "hv_genes_count": 12403,
    "grad_accum": 4,
    "lr": 1e-4,
    "max_steps": 10000,
}

In [7]:
class RNAPreprocessor:
    def __init__(self, num_bins=64):
        self.num_bins = num_bins
        self.global_max = None
        self.bin_edges = None

    def fit_transform(self, data):

        transformed = data # data is already log10(1+x) transformed

        # Max Normalization
        self.global_max = np.max(transformed)
        normalized = transformed / self.global_max

        # Binning
        self.bin_edges = np.linspace(0, 1, self.num_bins + 1)
        binned = np.digitize(normalized, self.bin_edges) - 1
        binned = np.clip(binned, 0, self.num_bins - 1)

        # Shift +5 for special tokens (0-4 reserved)
        return binned + 5

class BulkRNADataset(Dataset):
    def __init__(self, tokenized_data, survival_time=None, event_indicator=None):
        self.data = torch.tensor(tokenized_data, dtype=torch.long)
        self.survival_time = torch.tensor(survival_time, dtype=torch.float) if survival_time is not None else None
        self.event = torch.tensor(event_indicator, dtype=torch.float) if event_indicator is not None else None

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = {"input_ids": self.data[idx]}
        if self.survival_time is not None:
            item["time"] = self.survival_time[idx]
            item["event"] = self.event[idx]
        return item

In [8]:

class MockTokenizer:
    """
    A dummy tokenizer that satisfies all DataCollator requirements:
    1. Provides special token attributes.
    2. Implements padding (stacking tensors).
    3. Identifies special tokens so they aren't masked.
    """
    def __init__(self, vocab_size, mask_token_id=4, pad_token_id=0):
        self.vocab_size = vocab_size
        self.mask_token_id = mask_token_id
        self.pad_token_id = pad_token_id

        # Attributes required by DataCollator checks
        self.mask_token = "[MASK]"
        self.pad_token = "[PAD]"

        # Prevents truncation warnings
        self.model_max_length = 100_000

    def __len__(self):
        return self.vocab_size

    def pad(self, encoded_inputs, return_tensors="pt", **kwargs):
        """
        Called by DataCollator to stack a list of samples into a batch.
        Since our gene expression data is fixed-length (N_genes), we just stack them.
        """
        import torch

        # encoded_inputs is a list of dicts: [{'input_ids': tensor}, {'input_ids': tensor}]
        batch = {}

        # We take the first sample to know the keys (usually just 'input_ids')
        for key in encoded_inputs[0].keys():
            # Extract values for this key from all samples
            values = [d[key] for d in encoded_inputs]

            # If they are already tensors, stack them. If lists, convert to tensor.
            if isinstance(values[0], torch.Tensor):
                batch[key] = torch.stack(values)
            else:
                batch[key] = torch.tensor(values, dtype=torch.long)

        return batch

    def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
        """
        Called by DataCollator to decide what NOT to mask.
        We treat IDs 0 to 4 as special (PAD, UNK, CLS, SEP, MASK).
        Returns: list of 0s (can mask) and 1s (cannot mask).
        """
        # token_ids_0 is a list of integers
        return [1 if token < 5 else 0 for token in token_ids_0]

    def convert_tokens_to_ids(self, token):
        return self.mask_token_id if token == self.mask_token else 0

    def save_pretrained(self, save_directory):
        pass

def train_pretraining_model(train_dataset, eval_dataset):
    """
    Sets up the BERT MLM training loop using Hugging Face Trainer.
    Matches the "Pre-training" section of the paper.
    """
    tokenizer = MockTokenizer(
        vocab_size=CONFIG['vocab_size'],
        mask_token_id=4,
        pad_token_id=0
    )
    # 1. Architecture
    config = BertConfig(
        vocab_size=CONFIG['vocab_size'],
        hidden_size=CONFIG['n_embedding'],
        num_hidden_layers=CONFIG['n_layers'],
        num_attention_heads=CONFIG['n_heads'],
        max_position_embeddings=CONFIG['N_genes'] + 512, # Buffer
        pad_token_id=0,
        mask_token_id=4
    )

    model = BertForMaskedLM(config)
    print(f"Model Parameters: {model.num_parameters()}")

    # 2. Collator (Handles the 15% masking automatically)
    # Paper: "decided not to consider only non-zero... but all of them"
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, # Not needed for custom IDs
        mlm=True,
        mlm_probability=CONFIG['mask_prob']
    )

    # 3. Training Arguments
    # Paper uses huge batch size (3e6 tokens). We simulate via accumulation.
    training_args = TrainingArguments(
        output_dir="./results_pretrain",
        overwrite_output_dir=True,
        num_train_epochs=15,
        per_device_train_batch_size=CONFIG['batch_size'],
        gradient_accumulation_steps=CONFIG['grad_accum'],
        learning_rate=CONFIG['lr'],
        logging_steps=50,
        save_steps=500,
        eval_strategy="steps",
        eval_steps=100,
        fp16=torch.cuda.is_available(), # Use mixed precision if GPU available
        report_to="none"
    )

    # 4. Metrics (Reconstruction Accuracy)
    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        predictions = np.argmax(logits, axis=-1)
        # Masked tokens are those where label != -100
        mask = labels != -100
        accuracy = (predictions[mask] == labels[mask]).mean()
        return {"accuracy": accuracy}

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        compute_metrics=compute_metrics
    )

    return trainer

In [9]:
from sklearn.model_selection import train_test_split


print("\nPreprocessing Data...")
preprocessor = RNAPreprocessor(num_bins=CONFIG['n_bins'])
tokenized_data = preprocessor.fit_transform(expr_aligned)

# Split Train/Test (95/5 split as per paper)
train_data, test_data = train_test_split(tokenized_data, test_size=0.05)

train_dataset = BulkRNADataset(train_data)
eval_dataset = BulkRNADataset(test_data)

print(f"Train size: {len(train_dataset)}, Eval size: {len(eval_dataset)}")

# --- PART C: Pre-train (MLM) ---
print("\nStarting Pre-training...")
trainer = train_pretraining_model(train_dataset, eval_dataset)

# Run Training
trainer.train() # Uncomment to execute
print("Training setup complete. Call trainer.train() to start.")


Preprocessing Data...
Train size: 3215, Eval size: 170

Starting Pre-training...
Model Parameters: 12457798


Step,Training Loss,Validation Loss,Accuracy
100,2.4812,2.243885,0.356784
200,2.1111,2.038057,0.370813
300,2.0464,1.996385,0.373528
400,2.0232,1.977154,0.375163
500,2.0073,1.969604,0.374517
600,2.0014,1.95897,0.376751
700,1.9975,1.959643,0.376523


We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


Training setup complete. Call trainer.train() to start.


In [11]:
from transformers import BertModel, BertConfig


checkpoint_path = "./results_pretrain/checkpoint-765"


model_mlm = BertForMaskedLM.from_pretrained(checkpoint_path)


bert_encoder = BertModel.from_pretrained(checkpoint_path)

print("Encoder loaded successfully. Ready for fine-tuning.")

Some weights of BertModel were not initialized from the model checkpoint at ./results_pretrain/checkpoint-765 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Encoder loaded successfully. Ready for fine-tuning.


In [12]:
bert_encoder

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(70, 256, padding_idx=0)
    (position_embeddings): Embedding(19574, 256)
    (token_type_embeddings): Embedding(2, 256)
    (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-3): 4 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=256, out_features=256, bias=True)
            (key): Linear(in_features=256, out_features=256, bias=True)
            (value): Linear(in_features=256, out_features=256, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=256, out_features=256, bias=True)
            (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
 

In [10]:

source_folder = "./results_pretrain/checkpoint-765"

destination_folder = PROJECT_ROOT / "BulkRNABert_Models"
import subprocess


os.makedirs(destination_folder, exist_ok=True)

print(f"Copying files to {destination_folder}...")

!cp -r -v "$source_folder" "$destination_folder"

print("Copy complete.")

Copying files to /content/drive/MyDrive/bdh_challenge_2025_data/BulkRNABert_Models...
'./results_pretrain/checkpoint-765/config.json' -> '/content/drive/MyDrive/bdh_challenge_2025_data/BulkRNABert_Models/checkpoint-765/config.json'
'./results_pretrain/checkpoint-765/model.safetensors' -> '/content/drive/MyDrive/bdh_challenge_2025_data/BulkRNABert_Models/checkpoint-765/model.safetensors'
'./results_pretrain/checkpoint-765/training_args.bin' -> '/content/drive/MyDrive/bdh_challenge_2025_data/BulkRNABert_Models/checkpoint-765/training_args.bin'
'./results_pretrain/checkpoint-765/optimizer.pt' -> '/content/drive/MyDrive/bdh_challenge_2025_data/BulkRNABert_Models/checkpoint-765/optimizer.pt'
'./results_pretrain/checkpoint-765/scheduler.pt' -> '/content/drive/MyDrive/bdh_challenge_2025_data/BulkRNABert_Models/checkpoint-765/scheduler.pt'
'./results_pretrain/checkpoint-765/scaler.pt' -> '/content/drive/MyDrive/bdh_challenge_2025_data/BulkRNABert_Models/checkpoint-765/scaler.pt'
'./results_pre