1. 预处理（preprocess）：按照 rank 的方式对蛋白进行排序（参考 Geneformer）
2. 构建 tokenizer：直接基于前面的语料库构建即可
3. 构建 BERT 架构（或其他架构），训练目标为 MLM
4. 预训练

## preprocessing

In [1]:
# TODO: the eid information is missing, and some proteins were filled with the
# mean. Filling may not be needed at all later on, because proteins that were
# not measured can simply be left out of the input sequence.
import pandas as pd

train_data = pd.read_pickle("result/part1/train_data.pkl").set_index("eid")
test_data = (
    pd.read_pickle("result/part1/test_data.pkl")
    .set_index("eid")
    .head(300)  # only the first 300 test rows are kept
)


# Columns from "C3" through the last column are treated as protein columns.
protein_cols = test_data.columns.tolist()[test_data.columns.tolist().index("C3") :]

In [2]:
from datasets import Dataset
from collections import defaultdict

res = defaultdict(list)

# Geneformer-style rank encoding: for each participant, order the measured
# proteins from highest to lowest value. Only `protein_cols` are ranked, so any
# non-protein columns stored before "C3" never become tokens. Missing (NaN)
# measurements are dropped rather than ranked.
for eid, row in test_data[protein_cols].iterrows():
    ranked_row = row.dropna().sort_values(ascending=False)

    res["eid"].append(eid)
    # The space-separated protein names form the "sentence" given to the tokenizer.
    res["proteins"].append(" ".join(ranked_row.index.tolist()))
    res["values"].append(ranked_row.values.tolist())

test_dataset = Dataset.from_dict(res)
test_dataset

Dataset({
    features: ['eid', 'proteins', 'values'],
    num_rows: 300
})

In [3]:
from transformers import (
    BertTokenizer,
    BertTokenizerFast,
    PreTrainedTokenizerFast,
    AutoTokenizer,
)


# Protein-vocabulary tokenizer saved under transtab/tokenizer.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="transtab/tokenizer"
)

In [4]:
from transformers import AutoTokenizer
import multiprocessing
from transformers import BertTokenizer, BertTokenizerFast, PreTrainedTokenizerFast


def group_texts(examples, max_length=2048, tok=None):
    """Tokenize a batch of ranked protein strings into fixed-length inputs.

    Despite the name, no concatenation or chunking happens. Each example is
    tokenized independently, then padded and truncated to ``max_length``.

    Args:
        examples: batch mapping (as passed by ``Dataset.map(batched=True)``)
            with a ``"proteins"`` column of space-separated protein names.
        max_length: target sequence length used for padding and truncation.
        tok: tokenizer to call. Defaults to the module-level ``tokenizer``.

    Returns:
        Whatever the tokenizer returns for the batch. Because
        ``return_special_tokens_mask=True`` is requested, this includes
        ``special_tokens_mask``, which the MLM collator uses.
    """
    if tok is None:
        tok = tokenizer
    # ``truncation=True`` already selects the truncation strategy. The legacy
    # ``truncation_strategy`` kwarg is ignored whenever ``truncation`` is set,
    # and "only_last" is not a valid strategy name, so it is not passed.
    return tok(
        examples["proteins"],
        return_special_tokens_mask=True,
        add_special_tokens=True,
        padding="max_length",
        max_length=max_length,
        truncation=True,
    )


# Tokenize every ranked protein string in parallel. The raw "proteins" text
# column is dropped once it has been encoded.
tokenized_datasets = test_dataset.map(
    function=group_texts,
    batched=True,
    num_proc=8,
    remove_columns=["proteins"],
)

Map (num_proc=8):   0%|          | 0/300 [00:00<?, ? examples/s]

In [5]:
tokenized_datasets[0]["input_ids"]

[2775,
 2537,
 1374,
 391,
 957,
 2323,
 996,
 1177,
 2028,
 1499,
 449,
 1542,
 2632,
 1242,
 512,
 1651,
 447,
 1198,
 191,
 1435,
 105,
 1413,
 1392,
 2084,
 1376,
 1433,
 26,
 2178,
 1032,
 2796,
 1516,
 940,
 78,
 1340,
 2173,
 2713,
 1753,
 1372,
 2829,
 102,
 2422,
 2077,
 749,
 199,
 390,
 2472,
 2233,
 1232,
 814,
 1020,
 1382,
 2161,
 1326,
 1156,
 1699,
 1128,
 1498,
 1782,
 737,
 1656,
 242,
 815,
 2093,
 748,
 332,
 2544,
 621,
 2528,
 1248,
 2492,
 381,
 727,
 421,
 1803,
 1200,
 1500,
 1621,
 1998,
 1727,
 399,
 2640,
 2788,
 2083,
 1949,
 2754,
 425,
 2892,
 884,
 106,
 1601,
 1997,
 397,
 653,
 1924,
 794,
 728,
 877,
 63,
 666,
 333,
 74,
 1021,
 406,
 2826,
 2648,
 1497,
 1954,
 256,
 1700,
 1306,
 1237,
 1219,
 1169,
 64,
 2035,
 444,
 70,
 13,
 1321,
 2367,
 1831,
 2441,
 658,
 711,
 2740,
 1019,
 526,
 162,
 850,
 633,
 2148,
 1332,
 591,
 1152,
 405,
 383,
 382,
 2020,
 1860,
 120,
 2141,
 1085,
 1314,
 1579,
 782,
 2032,
 971,
 602,
 123,
 200,
 473,
 1397,
 100

In [6]:
from transformers import DataCollatorForLanguageModeling


# Masked-language-modelling collator. It masks tokens with probability 0.15 and
# fills `labels` with -100 at unmasked positions.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15,
)

## albert

In [9]:
from transformers import AlbertConfig, AlbertForMaskedLM

albertconfig = AlbertConfig()

albertconfig

AlbertConfig {
  "attention_probs_dropout_prob": 0,
  "bos_token_id": 2,
  "classifier_dropout_prob": 0.1,
  "embedding_size": 128,
  "eos_token_id": 3,
  "hidden_act": "gelu_new",
  "hidden_dropout_prob": 0,
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "inner_group_num": 1,
  "intermediate_size": 16384,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "albert",
  "num_attention_heads": 64,
  "num_hidden_groups": 1,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.39.3",
  "type_vocab_size": 2,
  "vocab_size": 30000
}

In [11]:
albertconfig = AlbertConfig(
    # len(tokenizer) also counts tokens added on top of the base vocabulary,
    # which `tokenizer.vocab_size` leaves out. Sizing the embedding table from
    # it keeps every id the tokenizer can emit in range.
    vocab_size=len(tokenizer),
    hidden_size=256,
    num_attention_heads=8,
    intermediate_size=512,
    max_position_embeddings=tokenizer.model_max_length,
    num_hidden_layers=6,
)


AlbertForMaskedLM(albertconfig)

AlbertForMaskedLM(
  (albert): AlbertModel(
    (embeddings): AlbertEmbeddings(
      (word_embeddings): Embedding(2916, 128, padding_idx=0)
      (position_embeddings): Embedding(2911, 128)
      (token_type_embeddings): Embedding(2, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0, inplace=False)
    )
    (encoder): AlbertTransformer(
      (embedding_hidden_mapping_in): Linear(in_features=128, out_features=256, bias=True)
      (albert_layer_groups): ModuleList(
        (0): AlbertLayerGroup(
          (albert_layers): ModuleList(
            (0): AlbertLayer(
              (full_layer_layer_norm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
              (attention): AlbertAttention(
                (query): Linear(in_features=256, out_features=256, bias=True)
                (key): Linear(in_features=256, out_features=256, bias=True)
                (value): Linear(in_features=256, out_features=256, bias=True)
  

## bert

In [9]:
from transformers import BertConfig

config = BertConfig()

config

BertConfig {
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.39.3",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

In [10]:
tokenizer.vocab_size

2916

In [11]:
from transformers import AutoModelForMaskedLM, BertForMaskedLM

# Small BERT (6 layers, hidden size 256) for masked-LM pre-training on
# protein-rank sequences.
bertconfig = BertConfig(
    # len(tokenizer) also counts tokens added on top of the base vocabulary,
    # which `tokenizer.vocab_size` leaves out. Sizing the embedding table from
    # it keeps every id the tokenizer can emit in range.
    vocab_size=len(tokenizer),
    hidden_size=256,
    num_attention_heads=8,
    intermediate_size=512,
    max_position_embeddings=tokenizer.model_max_length,
    num_hidden_layers=6,
)

BertForMaskedLM(bertconfig)

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(2916, 256, padding_idx=0)
      (position_embeddings): Embedding(2911, 256)
      (token_type_embeddings): Embedding(2, 256)
      (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-5): 6 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=256, out_features=256, bias=True)
              (key): Linear(in_features=256, out_features=256, bias=True)
              (value): Linear(in_features=256, out_features=256, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=256, out_features=256, bias=True)
              (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_aff

In [21]:
from transformers import AutoModelForMaskedLM, AutoModelForPreTraining, BertForMaskedLM

# Load weights from a saved checkpoint instead of building a fresh
# BertForMaskedLM(bertconfig).
model = BertForMaskedLM.from_pretrained(
    "result/dl/ProteomicsBERT/checkpoint-500"
)

In [22]:
# from torch.
model.bert.embeddings.word_embeddings

Embedding(2916, 256, padding_idx=0)

In [23]:
from transformers import Trainer, TrainingArguments

In [24]:
# NOTE(review): the same dataset is passed as both train and eval set, so the
# eval loss does not measure generalisation. Use a held-out split for real runs.
training_args = TrainingArguments(
    output_dir="test",
    # transformers 4.39 name; newer releases rename it to `eval_strategy`.
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=3,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    weight_decay=0.01,
    push_to_hub=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets,
    eval_dataset=tokenized_datasets,
    data_collator=data_collator,
)

# trainer.train()


In [25]:
trainer.train_dataset

Dataset({
    features: ['eid', 'values', 'input_ids', 'token_type_ids', 'attention_mask', 'special_tokens_mask'],
    num_rows: 300
})

In [26]:
for batch in trainer.get_train_dataloader():

    break

In [27]:
import torch


# with torch.no_grad():
o = model(**batch)

CrossEntropyLoss()


In [28]:
o.loss

tensor(8.0208, device='cuda:0', grad_fn=<NllLossBackward0>)

In [20]:
model

BertForPreTraining(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(2916, 256, padding_idx=0)
      (position_embeddings): Embedding(2916, 256)
      (token_type_embeddings): Embedding(2, 256)
      (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-5): 6 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=256, out_features=256, bias=True)
              (key): Linear(in_features=256, out_features=256, bias=True)
              (value): Linear(in_features=256, out_features=256, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=256, out_features=256, bias=True)
              (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_

In [37]:
# Most likely token id at every position. BertForMaskedLM returns its scores as
# `.logits` (`prediction_logits` belongs to BertForPreTraining's output).
torch.argmax(o.logits, dim=-1)

tensor([[1643, 2178,  574,  ...,  178, 2199, 1496],
        [ 105, 1243, 2477,  ..., 2012, 1496, 2007]], device='cuda:0')

In [38]:
batch["input_ids"]

tensor([[2775, 2537, 1374,  ..., 2359, 2214,  948],
        [1153,  835, 2225,  ..., 1497,    3, 1304]], device='cuda:0')

In [39]:
batch["labels"]

tensor([[-100, -100, -100,  ..., -100, -100, -100],
        [-100, -100, -100,  ..., -100, 1606, -100]], device='cuda:0')

In [40]:
# Recompute the masked-LM loss by hand (compare with `o.loss`). Positions
# labelled -100 are skipped by CrossEntropyLoss's default ignore_index. The
# vocabulary dimension is taken from the logits themselves, so the reshape stays
# correct even if the checkpoint's vocab differs from `tokenizer.vocab_size`.
torch.nn.CrossEntropyLoss()(
    o.logits.view(-1, o.logits.size(-1)), batch["labels"].view(-1)
)

tensor(8.0299, device='cuda:0', grad_fn=<NllLossBackward0>)

In [42]:
 o.prediction_logits.view(-1, tokenizer.vocab_size).shape

torch.Size([4096, 2916])

In [42]:
batch["token_type_ids"]

tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]], device='cuda:0')

In [43]:
batch["input_ids"]

tensor([[2775, 2537, 1374,  ..., 2359, 2214,  948],
        [1153,  835, 2225,  ..., 1497, 1453, 1304],
        [1837, 1500, 2624,  ..., 1265,    3, 1918],
        ...,
        [ 371, 2825,  943,  ...,  524, 2886, 1549],
        [2537, 2829,    3,  ..., 1380, 1278,  739],
        [ 389, 2537, 2488,  ..., 1663, 1429, 1442]], device='cuda:0')

In [48]:
batch.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'labels'])

In [46]:
tokenizer.decode(batch["input_ids"][0][:10])

'trim26 spint3 il17f ccl20 erc2 s100a14 fam171a2 [MASK] pglyrp1 kir2ds4'

In [44]:
batch["input_ids"][0][:10]

tensor([2775, 2537, 1374,  391,  957, 2323,  996,    3, 2028, 1499],
       device='cuda:0')

In [45]:
batch["labels"][0][:10]

tensor([-100, -100, -100, -100, -100, -100, -100, 1177, -100, -100],
       device='cuda:0')

In [None]:
# Same computation as the copied library line, rewritten with this notebook's
# variables. The original referenced `self`, `loss_fct`, `prediction_scores` and
# `labels`, none of which exist here.
loss_fct = torch.nn.CrossEntropyLoss()  # default ignore_index=-100 skips unmasked positions
prediction_scores = o.logits
masked_lm_loss = loss_fct(
    prediction_scores.view(-1, model.config.vocab_size), batch["labels"].view(-1)
)


In [28]:
o.loss

In [20]:
model = model.from_pretrained("result/dl/ProteomicsBERT/checkpoint-99500")

In [21]:
batch

{'input_ids': tensor([[2778, 2540, 1377,  ...,  374,  856, 1287],
        [1159,  839, 2228,  ..., 1552,  565, 1183],
        [1840, 1503, 2627,  ..., 1319,    3,  731],
        ...,
        [ 375, 2828,  949,  ..., 2871, 1697,  394],
        [2540, 2832,    3,  ..., 1667,   56, 1649],
        [ 393, 2540, 2491,  ...,  121,  984,  182]], device='cuda:0'), 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]], device='cuda:0'), 'labels': tensor([[-100, -100, -100,  ..., -100, -100, -100],
        [-100, -100, -100,  ..., -100, 2641, -100],
        [-100, -100,

In [25]:
# Forward pass on CPU (presumably to get a readable traceback instead of an
# asynchronous CUDA device-side assert). The inputs must live on the model's
# device and be passed as keyword arguments. A dict passed positionally would
# be taken as `input_ids`.
# NOTE(review): once a device-side assert has fired, the CUDA context is
# unusable and these `.to(...)` calls may fail too until the kernel restarts.
model.to("cpu")
cpu_batch = {name: tensor.to("cpu") for name, tensor in batch.items()}
o = model(**cpu_batch)
o

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [24]:
# model(batch)
batch

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
from transformers import AutoConfig


# Tokenized sequence length. This matches the padding/truncation length used by
# `group_texts` (default max_length=2048).
context_length = 2048

config = AutoConfig.from_pretrained(
    "distilbert/distilroberta-base",
    vocab_size=len(tokenizer),
    # `n_ctx` is GPT-2's name for the context size and is ignored by RoBERTa,
    # which reads `max_position_embeddings`. RoBERTa numbers positions starting
    # after padding_idx, so it needs context_length + padding_idx + 1 slots.
    # +2 covers a pad id of 0 or 1.
    max_position_embeddings=context_length + 2,
    # Align the embedding padding row with the protein tokenizer's [PAD] id.
    # Otherwise distilroberta's default pad id would mark a real protein token
    # as padding.
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)

In [None]:
from transformers import AutoModelForMaskedLM

# Build a randomly initialised masked-LM from `config`, which is sized for the
# protein tokenizer. Loading distilroberta's pretrained weights would ignore
# `config` and keep an embedding matrix sized for its own vocabulary.
model = AutoModelForMaskedLM.from_config(config)

In [None]:
data_collator

In [None]:
from transformers import BertTokenizer, BertTokenizerFast, PreTrainedTokenizerFast


protein_tokenizer = BertTokenizerFast.from_pretrained("transtab/tokenizer")


def tokenize_function(examples):
    """Tokenize a batch of space-separated protein strings (no padding or truncation)."""
    texts = examples["proteins"]
    return protein_tokenizer(texts)


tokenized_datasets = test_dataset.map(
    tokenize_function,
    batched=True,
    batch_size=32,
    num_proc=4,
)

In [None]:
protein_tokenizer(test_dataset[0]["proteins"])

In [None]:
# res_list = {}


# for idx, row in test_data.set_index("eid").iterrows():
#     ranked_row = row.sort_values(ascending=False).dropna()
#     res = {}
#     res["eid"] = ranked_row.name
#     res["proteins"] = ranked_row.index.tolist()
#     res["values"] = ranked_row.values.tolist()
#     res_list.append(res)


# def dict_generator():
#     for i in range(len(res_list)):
#         yield res_list[i]

# Dataset.from_generator(dict_generator, num_proc=4)

In [None]:
def _ranked_protein_records():
    """Yield one rank-encoded record per participant in `test_data`.

    The schema matches `test_dataset`: ``eid``, space-joined ``proteins``
    ordered by descending value, and the matching ``values``. NaNs are dropped.
    """
    for eid, row in test_data[protein_cols].iterrows():
        ranked = row.dropna().sort_values(ascending=False)
        yield {
            "eid": eid,
            "proteins": " ".join(ranked.index.tolist()),
            "values": ranked.values.tolist(),
        }


# `from_generator` takes a callable that returns a generator, not a list.
Dataset.from_generator(_ranked_protein_records)

In [None]:
import numpy as np


def rank_genes(gene_vector, gene_tokens):
    """
    Rank gene/protein tokens by their expression value, highest first.

    Args:
        gene_vector: 1-D sequence of expression values.
        gene_tokens: token names aligned element-wise with ``gene_vector``.
            A plain Python list is accepted.

    Returns:
        ``(ranked_tokens, ranked_values)`` as NumPy arrays in descending order
        of value. NaN values end up last, and tied values keep their original
        relative order.
    """
    values = np.asarray(gene_vector)
    # Converting to an array lets a list of names be fancy-indexed by the sort
    # order. Indexing a Python list with an index array raises TypeError.
    tokens = np.asarray(gene_tokens)
    # Negate for a descending sort. A stable sort makes tie order deterministic.
    sorted_indices = np.argsort(-values, kind="stable")
    return tokens[sorted_indices], values[sorted_indices]


def rank_sorted(examples, protein_cols=None):
    """
    Rank one example's protein measurements from highest to lowest.

    Intended for non-batched ``Dataset.map`` over a *wide* dataset. Each
    ``examples`` is a mapping from column name to that example's scalar value,
    with one column per protein plus any other columns (e.g. ``eid``).

    Args:
        examples: mapping from column name to this example's value.
        protein_cols: names of the protein columns to rank. Required.

    Returns:
        dict containing every non-protein column copied through unchanged, plus
        ``protein_tokens`` (protein names, highest value first) and
        ``protein_expression_vectors`` (the matching values). Missing values
        are kept as NaN and sorted last.

    Raises:
        TypeError: if ``protein_cols`` is not given.
    """
    if protein_cols is None:
        raise TypeError("rank_sorted() requires protein_cols")
    protein_cols = list(protein_cols)
    protein_set = set(protein_cols)

    # dtype=float turns missing (None) entries into NaN instead of an object array.
    values = np.array([examples[col] for col in protein_cols], dtype=float)
    # Pass the names as an array so they can be reordered by index array.
    ranked_tokens, ranked_values = rank_genes(values, np.asarray(protein_cols))

    return_res = {col: examples[col] for col in examples if col not in protein_set}
    # Plain lists so the result can be stored as dataset (Arrow) columns.
    return_res["protein_tokens"] = np.asarray(ranked_tokens).tolist()
    return_res["protein_expression_vectors"] = np.asarray(ranked_values).tolist()
    return return_res


# `test_dataset` only holds the already-ranked "proteins"/"values" columns and
# has no per-protein columns. rank_sorted therefore runs on a wide dataset built
# from test_data, with eid restored as a regular column.
wide_dataset = Dataset.from_pandas(test_data[protein_cols].reset_index())
wide_dataset.map(rank_sorted, fn_kwargs={"protein_cols": protein_cols})

In [None]:
# normalize = True 
# if normalize:
#     # pd.concat([train_data, test_data])[protein_cols]

In [3]:
import pandas as pd

metabolism = pd.read_pickle(
    filepath_or_buffer="/home/xutingfeng/ukb/ukbData/omics/metabolomics/parsed/2024/init_visit.pkl"
)

# Summary statistics (count / mean / std / quantiles) of the numeric columns.
metabolism.describe()

Unnamed: 0,eid,3-Hydroxybutyrate,Acetate,Acetoacetate,Acetone,Alanine,Albumin,Apolipoprotein A1,Apolipoprotein B,Apolipoprotein B to Apolipoprotein A1 ratio,...,Triglycerides to Total Lipids in Medium VLDL percentage,Triglycerides to Total Lipids in Small HDL percentage,Triglycerides to Total Lipids in Small LDL percentage,Triglycerides to Total Lipids in Small VLDL percentage,Triglycerides to Total Lipids in Very Large HDL percentage,Triglycerides to Total Lipids in Very Large VLDL percentage,Triglycerides to Total Lipids in Very Small VLDL percentage,Tyrosine,VLDL Cholesterol,Valine
count,274298.0,269219.0,274092.0,274286.0,274294.0,274188.0,274252.0,274296.0,274296.0,274296.0,...,274295.0,274295.0,274295.0,274295.0,274185.0,271667.0,274295.0,273971.0,274295.0,274070.0
mean,3506315.0,0.060765,0.017903,0.012999,0.014232,0.296304,39.347088,1.462377,0.849427,0.595225,...,47.854559,4.546763,5.516848,38.422912,5.307593,52.255282,19.635848,0.063042,0.720872,0.210312
std,1449998.0,0.062247,0.033286,0.012315,0.005623,0.078481,3.387807,0.24633,0.201663,0.166096,...,8.524622,1.329243,1.829603,6.697692,3.906198,8.177621,4.196483,0.01456,0.248205,0.043535
min,1000025.0,0.0,0.0,0.0,0.002328,0.061282,0.0,0.30399,0.15432,0.064523,...,0.000158,0.10023,1.6825,4.3184,0.001715,0.000184,3.6654,0.005511,0.063553,0.067384
25%,2251038.0,0.029318,0.011571,0.006238,0.011095,0.23917,37.263,1.288,0.70599,0.47421,...,41.65,3.62655,4.2882,33.813,3.0168,48.215,16.731,0.052981,0.53982,0.18006
50%,3504206.0,0.043116,0.014985,0.009833,0.012905,0.28739,39.344,1.4393,0.833995,0.57833,...,47.173,4.454,5.1234,38.109,4.3861,53.365,19.002,0.061228,0.69758,0.20537
75%,4760605.0,0.068369,0.01896,0.015646,0.015644,0.344213,41.427,1.6105,0.97493,0.698723,...,53.271,5.3524,6.3048,42.7045,6.472,57.635,21.91,0.071066,0.875325,0.2348
max,6024110.0,3.8434,1.8006,0.7582,0.42117,1.2726,72.014,3.3998,2.4621,3.5867,...,99.301,35.914,77.593,81.04,99.989,97.041,48.023,0.38941,2.7676,0.84766


In [2]:
pd.set

Unnamed: 0,eid,age,sex,PC1,PC2,PC3,PC4,PC5,PC6,PC7,PC8,PC9,PC10
0,1000017,56.0,1.0,-11.3690,3.56718,-1.975530,0.213937,-12.43420,-1.698380,-0.090687,-3.498190,4.762600,3.153210
1,1000025,62.0,1.0,-12.1620,2.77470,0.175048,2.554930,8.75958,-0.044124,-1.497300,0.052680,0.276735,2.118800
2,1000038,60.0,1.0,-12.8698,6.41566,-5.106100,-1.296310,-6.34291,-2.935870,1.690630,-1.932100,3.712410,-0.063338
3,1000042,60.0,1.0,72.9437,-109.21600,74.692200,17.863400,-1.44577,-0.571180,-2.228180,1.646810,1.608430,5.003350
4,1000056,65.0,0.0,-10.7174,5.77507,0.620341,0.505251,-2.49160,1.052860,0.290698,1.672830,-1.928450,-0.712658
...,...,...,...,...,...,...,...,...,...,...,...,...,...
502404,6024086,66.0,0.0,-11.1845,4.08367,-0.006942,-0.325017,-5.32889,2.483810,-1.063800,-3.733520,3.016760,-0.309265
502405,6024098,68.0,1.0,-13.3426,2.56658,-0.076882,6.048100,11.09400,1.417840,2.647870,1.042270,-2.291610,-1.093340
502406,6024103,61.0,1.0,-12.2113,4.22902,-2.629170,4.489250,-2.29320,0.573617,1.350590,-1.911610,1.115080,0.535197
502407,6024110,66.0,1.0,-10.5527,6.84118,-2.149580,-0.825010,-2.83187,-1.727010,1.742680,0.109792,-0.305446,-0.589371
