1. preprocess:按照rank的方式进行排序（参考geneformer）
2. tokenizer构建，直接基于前面的预料库构建即可
3. BERT架构构建（或者其他的）,MLM
4. 预训练

## preprocessing

In [1]:
# TODO: eid信息缺失了，并且部分蛋白是用mean填充的，后续可以根本不需要填充，因为可以传入的时候没有他们
import pandas as pd

train_data = pd.read_pickle("result/part1/train_data.pkl").set_index("eid")
test_data = pd.read_pickle("result/part1/test_data.pkl").set_index("eid")


protein_cols = test_data.columns[test_data.columns.tolist().index("C3") :].tolist()

mean_protein: C3        0.037608
KLK7      0.008941
GCHFR     0.050778
NHLRC3   -0.003544
APOD      0.028029
            ...   
VWF       0.001637
NOTCH3    0.003569
CNTN1     0.003767
ENG       0.002524
ICAM2    -0.026433
Length: 2911, dtype: float64, std_protein: C3        0.487526
KLK7      0.410312
GCHFR     0.546738
NHLRC3    0.346298
APOD      0.451526
            ...   
VWF       0.840448
NOTCH3    0.411491
CNTN1     0.331694
ENG       0.195196
ICAM2     0.357197
Length: 2911, dtype: float64


## build tokenizer

In [6]:
# get the token vocab from geneformers
TOKEN_DICTIONARY_FILE = "/home/xutingfeng/ukb/project/ppp_prediction/ppp_prediction/geneformer/token_dictionary.pkl"
GENE_NAME_ID_DICTIONARY_FILE = "/home/xutingfeng/ukb/project/ppp_prediction/ppp_prediction/geneformer/gene_name_id_dict.pkl"
import pickle

with open(TOKEN_DICTIONARY_FILE, "rb") as f:
    gene_token_dict = pickle.load(f)

with open(GENE_NAME_ID_DICTIONARY_FILE, "rb") as f:
    gene_name_id_dict = pickle.load(f)


from collections import OrderedDict

PAD = "<pad>"
MASK = "<mask>"
gene_name_token_dict = OrderedDict(
    {
        PAD: 0,
        MASK: 1,
        **{
            k: token
            for k, v in gene_name_id_dict.items()
            if (token := gene_token_dict.get(v)) != None
        },
    }
)


gene_name_token_dict = {
    i[0]: i[1] for i in sorted(gene_name_token_dict.items(), key=lambda x: x[1])
}
gene_name_token_dict

{'<pad>': 0,
 '<mask>': 1,
 'TSPAN6': 2,
 'TNMD': 3,
 'DPM1': 4,
 'SCYL3': 5,
 'C1orf112': 6,
 'FGR': 7,
 'CFH': 8,
 'FUCA2': 9,
 'GCLC': 10,
 'NFYA': 11,
 'STPG1': 12,
 'NIPAL3': 13,
 'LAS1L': 14,
 'ENPP4': 15,
 'SEMA3F': 16,
 'CFTR': 17,
 'ANKIB1': 18,
 'CYP51A1': 19,
 'KRIT1': 20,
 'RAD52': 21,
 'BAD': 22,
 'LAP3': 23,
 'CD99': 24,
 'HS3ST1': 25,
 'AOC1': 26,
 'WNT16': 27,
 'HECW1': 28,
 'MAD1L1': 29,
 'LASP1': 30,
 'SNX11': 31,
 'TMEM176A': 32,
 'M6PR': 33,
 'KLHL13': 34,
 'CYP26B1': 35,
 'ICA1': 36,
 'DBNDD1': 37,
 'ALS2': 38,
 'CASP10': 39,
 'CFLAR': 40,
 'TFPI': 41,
 'NDUFAF7': 42,
 'RBM5': 43,
 'MTMR7': 44,
 'SLC7A2': 45,
 'ARF5': 46,
 'SARM1': 47,
 'POLDIP2': 48,
 'PLXND1': 49,
 'AK2': 50,
 'CD38': 51,
 'FKBP4': 52,
 'KDM1A': 53,
 'RBM6': 54,
 'CAMKK1': 55,
 'RECQL': 56,
 'VPS50': 57,
 'HSPB6': 58,
 'ARHGAP33': 59,
 'NDUFAB1': 60,
 'PDK4': 61,
 'SLC22A16': 62,
 'ZMYND10': 63,
 'ABCB5': 64,
 'ARX': 65,
 'SLC25A13': 66,
 'ST7': 67,
 'CDC27': 68,
 'SLC4A1': 69,
 'CALCR': 70,
 'HC

In [7]:
from tokenizers import (
    Tokenizer,
    decoders,
    models,
    normalizers,
    pre_tokenizers,
    trainers,
    processors,
)
from transformers import BertTokenizer, BertTokenizerFast, PreTrainedTokenizerFast

# build transformer tokenizer
tokenizer = Tokenizer(models.WordLevel(vocab=gene_name_token_dict, unk_token=None))

tokenizer.normalizer = normalizers.BertNormalizer(
    lowercase=False,
    clean_text=False,
    handle_chinese_chars=False,
    strip_accents=False,
)
tokenizer.pre_tokenizer = pre_tokenizers.WhitespaceSplit()


assert tokenizer.get_vocab() == gene_name_token_dict


bert_tokenizer = BertTokenizerFast(
    tokenizer_object=tokenizer,
    unk_token=None,
    pad_token=PAD,
    cls_token=None,
    sep_token=None,
    mask_token=MASK,
    lowercase=False,
    tokenize_chinese_chars=False,
    do_lower_case=False,
)
bert_tokenizer.model_max_length = 2048
tokenizer.get_vocab()

The OrderedVocab you are attempting to save contains holes for indices [112, 123, 124, 125, 375, 378, 766, 1054, 1084, 1195, 1219, 1402, 1758, 1903, 2186, 2617, 2620, 3420, 3421, 3495, 3496, 3497, 3498, 3499, 3500, 3501, 3571, 3943, 3978, 3992, 4060, 4630, 4638, 4682, 5052, 5147, 5186, 5256, 5288, 5517, 5529, 5532, 5547, 5667, 6041, 6096, 6097, 6098, 6194, 6245, 6292, 6379, 6464, 6502, 6503, 6505, 6565, 6694, 6757, 6785, 6990, 7121, 7269, 7399, 7551, 7598, 7599, 7600, 8030, 8031, 8032, 8126, 8151, 8207, 8281, 8365, 8575, 8613, 8735, 8909, 8912, 8914, 8918, 8991, 9174, 9234, 9236, 9258, 9431, 9433, 9525, 9620, 9621, 9663, 9710, 9726, 9754, 9794, 9805, 9836, 9878, 10040, 10090, 10106, 10244, 10270, 10481, 10537, 10551, 10552, 10553, 10554, 10641, 10829, 10843, 10857, 10928, 10957, 11380, 11381, 11383, 11515, 11517, 11582, 11669, 11703, 11849, 11930, 11975, 11977, 12005, 12108, 12116, 12127, 12153, 12160, 12225, 12228, 12231, 12238, 12317, 12353, 12666, 12767, 12776, 12785, 12802, 12815, 

{'MIR151B': 22340,
 'OXT': 2513,
 'CXXC1': 9761,
 'SEMA6D': 7519,
 'CRNN': 8381,
 'PGAM2': 11354,
 'GRIN1': 13986,
 'FBXL14': 13049,
 'THUMPD3': 6795,
 'ZNF766': 16465,
 'CKAP2': 7175,
 'LDHAL6B': 13089,
 'MIR8077': 24465,
 'MIR4477A': 24342,
 'ATP6V1G2-DDX39B': 20848,
 'RPL17': 22562,
 'DPM2': 7330,
 'MIR211': 18156,
 'PWWP2A': 12669,
 'RAX2': 13496,
 'MIR8085': 25054,
 'DNAJC5': 2445,
 'POGK': 8287,
 'SSU72P7': 25129,
 'MIR2355': 20742,
 'DRGX': 11566,
 'EIF2B3': 1134,
 'TAX1BP1': 3241,
 'PPP1R10': 17693,
 'TFF1': 10374,
 'C2CD4D': 19683,
 'PIGB': 1089,
 'PRPF38B': 6809,
 'GAMT': 6167,
 'CTLA4': 10998,
 'EPHB3': 14919,
 'EFTUD2': 3600,
 'GOLGA6D': 7935,
 'ATP12A': 1322,
 'TJP3': 3063,
 'NAA40': 3781,
 'PLPPR5': 4729,
 'UBE3D': 4799,
 'PBX4': 3171,
 'FASTKD3': 5484,
 'UBE2Q1': 10461,
 'TEX55': 10937,
 'MT1E': 12562,
 'SPC25': 9464,
 'FOXC2': 13957,
 'EPB41L1': 1774,
 'KRT3': 15795,
 'IBA57-DT': 17435,
 'TIGD1': 19459,
 'MIR3657': 22840,
 'MIR3910-1': 23032,
 'MIR6784': 24441,
 'TAGLN3

, 18012, 18013, 18014, 18016, 18017, 18018, 18023, 18024, 18025, 18026, 18027, 18028, 18029, 18031, 18032, 18033, 18035, 18036, 18037, 18038, 18042, 18046, 18060, 18062, 18063, 18071, 18094, 18095, 18096, 18125, 18131, 18136, 18146, 18162, 18167, 18171, 18181, 18193, 18196, 18198, 18208, 18209, 18214, 18236, 18246, 18255, 18256, 18257, 18258, 18263, 18270, 18271, 18272, 18278, 18285, 18287, 18293, 18297, 18308, 18319, 18350, 18361, 18362, 18374, 18375, 18377, 18378, 18379, 18380, 18383, 18384, 18386, 18387, 18389, 18390, 18392, 18394, 18395, 18396, 18399, 18400, 18403, 18411, 18414, 18415, 18417, 18418, 18420, 18424, 18427, 18428, 18429, 18431, 18433, 18434, 18436, 18437, 18438, 18439, 18440, 18441, 18442, 18443, 18444, 18445, 18446, 18447, 18448, 18449, 18450, 18451, 18452, 18453, 18455, 18457, 18458, 18459, 18470, 18471, 18472, 18478, 18481, 18484, 18485, 18491, 18495, 18497, 18498, 18499, 18500, 18506, 18512, 18517, 18519, 18529, 18536, 18563, 18600, 18603, 18614, 18623, 18648, 1865

In [8]:
pd.Series(tokenizer.get_vocab()).sort_values()

<pad>               0
<mask>              1
TSPAN6              2
TNMD                3
DPM1                4
                ...  
SMIM40          25381
H3C3            25382
GUCA1ANB        25383
SMIM42          25395
DUS4L-BCAP29    25400
Length: 21247, dtype: int64

In [9]:
len(set(list(tokenizer.get_vocab().keys())))

21247

In [10]:
test_data

Unnamed: 0_level_0,C3,KLK7,GCHFR,NHLRC3,APOD,GAPDH,TP53I3,CPA4,ANXA2,GRSF1,...,EGFR,TGFBR3,CRTAC1,IGFBP7,SELE,VWF,NOTCH3,CNTN1,ENG,ICAM2
eid,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1000127,-0.1648,0.41915,0.0758,-0.1043,0.6142,-0.28950,-0.2959,0.10865,-0.44730,-0.6219,...,0.12050,-0.3048,-0.18615,0.2473,0.43020,0.48260,-0.20860,-0.0618,0.03980,0.2881
1000258,-0.1921,0.66140,-0.1820,-0.6058,0.0311,-0.07430,0.4035,-0.83670,0.63780,0.8637,...,0.26405,-0.4350,-0.10835,-0.2557,-0.73230,0.95105,-0.66800,0.0159,-0.13375,-0.4492
1000634,-0.1844,-0.28910,-0.0165,0.2138,-0.1784,0.07015,0.1277,-0.11510,0.00035,0.5811,...,0.26030,-0.1450,-0.01185,0.1794,0.55690,-0.02970,-0.03360,-0.0161,0.08860,0.1529
1000822,,,,,,,,,,,...,0.10915,0.1599,0.39385,0.0362,-0.20270,0.91280,0.22265,0.4424,0.08500,0.3129
1001060,-0.2934,0.18505,-0.6861,-0.5999,0.0605,-0.23240,-1.1029,-0.12005,-0.49870,-0.4336,...,0.24110,-0.2590,-0.20755,-0.1704,-0.50870,-1.07130,0.03240,0.0044,0.13590,-0.5621
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
6020950,0.5542,0.08000,-0.2687,0.2264,-0.0384,-0.03845,-0.0643,0.10820,0.62835,0.1903,...,0.29150,0.1256,-0.02835,-0.0443,0.00000,0.99500,0.18220,0.5549,0.27490,0.2183
6022134,-0.4530,,,,,,,,,0.1759,...,-0.45625,-0.1405,-0.11695,0.1953,-0.52805,0.61300,0.07330,-0.0778,0.05950,0.1643
6022493,0.4204,0.45260,0.4319,0.0025,-0.0471,-0.13020,0.5091,0.37500,0.51440,-0.7438,...,0.00655,0.0000,0.43925,0.4017,0.67530,-0.29545,0.10710,-0.2575,0.06085,-0.1425
6023226,-0.2089,0.29860,-0.6432,-0.2969,0.1563,-0.07735,0.0353,-0.14020,0.49395,0.3366,...,-0.02870,-0.1172,0.56175,-0.1310,-1.14050,-0.51630,0.36120,0.6218,-0.23470,-0.1722


In [11]:
bert_tokenizer(["C3 KLK7"])

{'input_ids': [[5670, 12414]], 'token_type_ids': [[0, 0]], 'attention_mask': [[1, 1]]}

In [12]:
bert_tokenizer.save_pretrained(f"result/dl/ProteomicsGeneFormer/tokenizer")

('result/dl/ProteomicsGeneFormer/tokenizer/tokenizer_config.json',
 'result/dl/ProteomicsGeneFormer/tokenizer/special_tokens_map.json',
 'result/dl/ProteomicsGeneFormer/tokenizer/vocab.json',
 'result/dl/ProteomicsGeneFormer/tokenizer/added_tokens.json',
 'result/dl/ProteomicsGeneFormer/tokenizer/tokenizer.json')

The OrderedVocab you are attempting to save contains holes for indices [112, 123, 124, 125, 375, 378, 766, 1054, 1084, 1195, 1219, 1402, 1758, 1903, 2186, 2617, 2620, 3420, 3421, 3495, 3496, 3497, 3498, 3499, 3500, 3501, 3571, 3943, 3978, 3992, 4060, 4630, 4638, 4682, 5052, 5147, 5186, 5256, 5288, 5517, 5529, 5532, 5547, 5667, 6041, 6096, 6097, 6098, 6194, 6245, 6292, 6379, 6464, 6502, 6503, 6505, 6565, 6694, 6757, 6785, 6990, 7121, 7269, 7399, 7551, 7598, 7599, 7600, 8030, 8031, 8032, 8126, 8151, 8207, 8281, 8365, 8575, 8613, 8735, 8909, 8912, 8914, 8918, 8991, 9174, 9234, 9236, 9258, 9431, 9433, 9525, 9620, 9621, 9663, 9710, 9726, 9754, 9794, 9805, 9836, 9878, 10040, 10090, 10106, 10244, 10270, 10481, 10537, 10551, 10552, 10553, 10554, 10641, 10829, 10843, 10857, 10928, 10957, 11380, 11381, 11383, 11515, 11517, 11582, 11669, 11703, 11849, 11930, 11975, 11977, 12005, 12108, 12116, 12127, 12153, 12160, 12225, 12228, 12231, 12238, 12317, 12353, 12666, 12767, 12776, 12785, 12802, 12815, 

## build dataset

In [13]:
tokenizer_dir = "result/dl/ProteomicsGeneFormer/tokenizer"
from transformers import AutoTokenizer


tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)

In [14]:
# filter data and drop proteins not in tokenizer

tokens = list(tokenizer.vocab.keys())
not_in_list = []
for col in protein_cols:
    if col not in tokens:
        not_in_list.append(col)
print(
    f"Total {len(not_in_list)} proteins not in geneformer tokens will drop them , part of them are : {not_in_list[:10]}"
)

test_data.drop(columns=not_in_list, inplace=True)
train_data.drop(columns=not_in_list, inplace=True)

Total 41 proteins not in geneformer tokens will drop them , part of them are : ['DEFB103A_DEFB103B', 'BTNL10', 'ARNTL', 'MAP1LC3B2', 'SKIV2L', 'HCG22', 'DEFB104A_DEFB104B', 'AKR7L', 'GBA', 'MYLPF']


In [16]:
# to_cal = pd.concat([train_data, test_data], axis=0)
# mean_protein = to_cal.mean()
# std_protein = to_cal.std()

# print(f"mean_protein: {mean_protein}, std_protein: {std_protein}")

mean_protein: C3        0.037608
KLK7      0.008941
GCHFR     0.050778
NHLRC3   -0.003544
APOD      0.028029
            ...   
VWF       0.001637
NOTCH3    0.003569
CNTN1     0.003767
ENG       0.002524
ICAM2    -0.026433
Length: 2870, dtype: float64, std_protein: C3        0.487526
KLK7      0.410312
GCHFR     0.546738
NHLRC3    0.346298
APOD      0.451526
            ...   
VWF       0.840448
NOTCH3    0.411491
CNTN1     0.331694
ENG       0.195196
ICAM2     0.357197
Length: 2870, dtype: float64


In [22]:
from datasets import Dataset
from collections import defaultdict

res = defaultdict(list)

for idx, row in test_data.iterrows():
    # row = (row - mean_protein) / std_protein
    ranked_row = row.sort_values(ascending=False).dropna()

    res["eid"].append(ranked_row.name)
    res["proteins"].append(" ".join(ranked_row.index.tolist()))
    res["raw_proteins"].append(ranked_row.index.tolist())
    res["values"].append(ranked_row.values.tolist())
    res["length"].append(len(ranked_row))

test_dataset = Dataset.from_dict(res)
test_dataset

Dataset({
    features: ['eid', 'proteins', 'raw_proteins', 'values', 'length'],
    num_rows: 15432
})

In [24]:
test = pd.DataFrame(test_dataset["raw_proteins"])
test

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,2860,2861,2862,2863,2864,2865,2866,2867,2868,2869
0,TRIM26,CD3D,PGLYRP1,S100A14,ARTN,GLYR1,MRPL24,ERC2,IL20RA,IL17F,...,,,,,,,,,,
1,CA8,RAC3,CALCB,ECE1,IL1A,DNLZ,NRN1,CD28,SLIT2,MCEE,...,BST1,,,,,,,,,
2,TBR1,CD4,NEK7,REXO2,B2M,LAMA1,TCL1B,CLEC6A,SUMF1,YY1,...,,,,,,,,,,
3,IFNLR1,PRKAB1,LAP3,CDH17,NINJ1,CD200R1,ODAM,ADAMTS8,SIGLEC9,EIF5A,...,,,,,,,,,,
4,ATP5PO,TRAF3IP2,SIRT5,PGLYRP4,SOWAHA,ANKRA2,LIPF,ANXA11,CASP4,EDAR,...,TAGLN3,MAGEA3,FCGR3B,C1QA,JCHAIN,LILRA2,SIGLEC1,SRP14,EDA2R,C2
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
15427,RTKN2,SDK2,CRNN,SOD3,MAGED1,LRTM1,CCDC134,AIDA,ARHGAP5,LYPLA2,...,GIPC2,CD300LF,RBKS,AKR1B10,MAP2K1,CRHR1,MFAP5,FCN2,,
15428,IL10RA,CRYBB2,NPTN,DUSP3,MAEA,IFNL1,MMP10,ICAM5,HLA-DRA,CXCL10,...,,,,,,,,,,
15429,UPK3BL1,LMNB1,PADI4,RAB44,LPCAT2,SCGB2A2,FMNL1,ZHX2,SNRPB2,MFAP3L,...,SAFB2,SPTBN2,VGF,MYO6,CSNK2A1,RANBP2,MAP1LC3A,TUBB3,IL13RA2,
15430,GPR158,DDX25,SCN2B,INSL4,USP28,WASL,NUDT10,PCDH1,GLP1R,C1QTNF6,...,,,,,,,,,,


In [28]:
test[3].value_counts()

SOD3       53
CAPS       53
HRAS       42
LRPAP1     41
RGS8       37
           ..
CLEC4C      1
MVK         1
SIGLEC6     1
PPM1F       1
FCRL3       1
Name: 3, Length: 2106, dtype: int64

In [10]:
def group_texts(examples, max_length=2048):

    tokenized_inputs = tokenizer(
        examples["proteins"],
        return_special_tokens_mask=True,
        add_special_tokens=True,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        truncation_strategy="only_last",
    )

    return tokenized_inputs


# preprocess dataset
test_dataset = test_dataset.map(
    group_texts,
    batched=True,
    remove_columns=["proteins"],
    num_proc=4,
    batch_size = 256
)

The OrderedVocab you are attempting to save contains holes for indices [112, 123, 124, 125, 375, 378, 766, 1054, 1084, 1195, 1219, 1402, 1758, 1903, 2186, 2617, 2620, 3420, 3421, 3495, 3496, 3497, 3498, 3499, 3500, 3501, 3571, 3943, 3978, 3992, 4060, 4630, 4638, 4682, 5052, 5147, 5186, 5256, 5288, 5517, 5529, 5532, 5547, 5667, 6041, 6096, 6097, 6098, 6194, 6245, 6292, 6379, 6464, 6502, 6503, 6505, 6565, 6694, 6757, 6785, 6990, 7121, 7269, 7399, 7551, 7598, 7599, 7600, 8030, 8031, 8032, 8126, 8151, 8207, 8281, 8365, 8575, 8613, 8735, 8909, 8912, 8914, 8918, 8991, 9174, 9234, 9236, 9258, 9431, 9433, 9525, 9620, 9621, 9663, 9710, 9726, 9754, 9794, 9805, 9836, 9878, 10040, 10090, 10106, 10244, 10270, 10481, 10537, 10551, 10552, 10553, 10554, 10641, 10829, 10843, 10857, 10928, 10957, 11380, 11381, 11383, 11515, 11517, 11582, 11669, 11703, 11849, 11930, 11975, 11977, 12005, 12108, 12116, 12127, 12153, 12160, 12225, 12228, 12231, 12238, 12317, 12353, 12666, 12767, 12776, 12785, 12802, 12815, 

Map (num_proc=4):   0%|          | 0/15432 [00:00<?, ? examples/s]

In [12]:
rank_df = pd.DataFrame(test_dataset["input_ids"])

In [23]:
rank_df[200].value_counts()

7216     22
2053     22
11441    22
5738     21
11268    21
         ..
1524      1
5119      1
1000      1
758       1
7835      1
Name: 200, Length: 2718, dtype: int64

In [161]:
from transformers import (
    DataCollatorForLanguageModeling,
    BertForMaskedLM,
    Trainer,
    TrainingArguments,
)


data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm_probability=0.15
)

In [162]:
model = BertForMaskedLM.from_pretrained(
    "/home/xutingfeng/github_code/others/Geneformer"
)

In [163]:
# set training parameters
# number gpus

# batch size for training and eval
geneformer_batch_size = 4
# max learning rate
max_lr = 1e-3
# learning schedule
lr_schedule_fn = "linear"
# warmup steps
warmup_steps = 10_000
# number of epochs
epochs = 3
# optimizer
optimizer = "adamw"
# weight_decay
weight_decay = 0.001

# define the training arguments
training_args = {
    "learning_rate": max_lr,
    "do_train": True,
    "do_eval": False,
    "group_by_length": True,
    "length_column_name": "length",
    "disable_tqdm": False,
    "lr_scheduler_type": lr_schedule_fn,
    "warmup_steps": warmup_steps,
    "weight_decay": weight_decay,
    "per_device_train_batch_size": geneformer_batch_size,
    "num_train_epochs": epochs,
    "save_strategy": "steps",
    "logging_steps": 1000,
    "output_dir": "test/",
    "logging_dir": "test/",
}
training_args = TrainingArguments(**training_args)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=test_dataset,
    eval_dataset=test_dataset,
    data_collator=data_collator,
)

print("Starting training.")

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Starting training.


In [164]:
for batch in trainer.get_train_dataloader():

    break

In [166]:
o = model(**batch)

In [181]:
tokenizer.decode([13462, 13462, 18746, 11900, 18746, 18746, 11900, 11900, 234, 11900])

'CNP CNP NPIPA7 GREM1 NPIPA7 NPIPA7 GREM1 GREM1 CSDE1 GREM1'

In [177]:
import torch

torch.argmax(o.logits, dim=-1)[:, :10]

tensor([[13462, 13462, 18746, 11900, 18746, 18746, 11900, 11900,   234, 11900],
        [ 6852,  3170,  6852, 18746,  1868, 11900,  6852,  2848, 18746, 11900],
        [11900, 11900, 11900, 11900, 11900, 11900, 11900, 11900,  6391, 11900],
        [ 9681,  2951, 11900, 11900, 18746, 11900, 18746,  3170, 11900, 11900]],
       device='cuda:0')

In [176]:
batch["input_ids"][:, :10]

tensor([[15349,  2351,  1471,     1,  1643,   611,   392,   389,   234, 14396],
        [ 9542,  2529, 16031, 12137,  3709,  4454,  3124,  2848,  5840, 11697],
        [12479,  2513, 13809,     1,  2529, 17275,  8827,   637,  6391,  6721],
        [ 2529,  2951,     1,  6975, 20520,     1,  8637,   659,     1, 10788]],
       device='cuda:0')

In [174]:
batch["labels"][:, :10]

tensor([[ -100,  -100,  -100,  3719,  -100,  -100,  -100,  -100,  -100,  -100],
        [ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
        [ -100,  -100,  -100, 15349,  -100,  -100,  -100,  -100,  -100,  -100],
        [ -100,  -100,  6154,  -100,  -100,  2492,  -100,  -100,  2181,  -100]],
       device='cuda:0')

In [76]:
from ppp_prediction.geneformer.pretrainer import GeneformerPreCollator
from transformers import (
    DataCollatorForLanguageModeling,
    BertConfig,
    BertForMaskedLM,
    TrainingArguments,
    Trainer,
)

token_dictionary = (
    "/home/xutingfeng/github_code/others/Geneformer/geneformer/token_dictionary.pkl"
)
precollator = GeneformerPreCollator(token_dictionary=token_dictionary)

# # Data Collator Functions
data_collator = DataCollatorForLanguageModeling(
    tokenizer=precollator, mlm=True, mlm_probability=0.15
)
with open(token_dictionary, "rb") as f:
    tokenizer = pickle.load(f)
tokenizer

{'<pad>': 0,
 '<mask>': 1,
 'ENSG00000000003': 2,
 'ENSG00000000005': 3,
 'ENSG00000000419': 4,
 'ENSG00000000457': 5,
 'ENSG00000000460': 6,
 'ENSG00000000938': 7,
 'ENSG00000000971': 8,
 'ENSG00000001036': 9,
 'ENSG00000001084': 10,
 'ENSG00000001167': 11,
 'ENSG00000001460': 12,
 'ENSG00000001461': 13,
 'ENSG00000001497': 14,
 'ENSG00000001561': 15,
 'ENSG00000001617': 16,
 'ENSG00000001626': 17,
 'ENSG00000001629': 18,
 'ENSG00000001630': 19,
 'ENSG00000001631': 20,
 'ENSG00000002016': 21,
 'ENSG00000002330': 22,
 'ENSG00000002549': 23,
 'ENSG00000002586': 24,
 'ENSG00000002587': 25,
 'ENSG00000002726': 26,
 'ENSG00000002745': 27,
 'ENSG00000002746': 28,
 'ENSG00000002822': 29,
 'ENSG00000002834': 30,
 'ENSG00000002919': 31,
 'ENSG00000002933': 32,
 'ENSG00000003056': 33,
 'ENSG00000003096': 34,
 'ENSG00000003137': 35,
 'ENSG00000003147': 36,
 'ENSG00000003249': 37,
 'ENSG00000003393': 38,
 'ENSG00000003400': 39,
 'ENSG00000003402': 40,
 'ENSG00000003436': 41,
 'ENSG00000003509': 4

In [77]:
# set model parameters
# model type
model_type = "bert"
# max input size
max_input_size = 2**11  # 2048
# number of layers
num_layers = 6
# number of attention heads
num_attn_heads = 4
# number of embedding dimensions
num_embed_dim = 256
# intermediate size
intermed_size = num_embed_dim * 2
# activation function
activ_fn = "relu"
# initializer range, layer norm, dropout
initializer_range = 0.02
layer_norm_eps = 1e-12
attention_probs_dropout_prob = 0.02
hidden_dropout_prob = 0.02


# set training parameters
# total number of examples in Genecorpus-30M after QC filtering:
num_examples = 27_406_208
# number gpus
# num_gpus = 12
# batch size for training and eval
geneformer_batch_size = 12
# max learning rate
max_lr = 1e-3
# learning schedule
lr_schedule_fn = "linear"
# warmup steps
warmup_steps = 10_000
# number of epochs
epochs = 3
# optimizer
optimizer = "adamw"
# weight_decay
weight_decay = 0.001

config = {
    "hidden_size": num_embed_dim,
    "num_hidden_layers": num_layers,
    "initializer_range": initializer_range,
    "layer_norm_eps": layer_norm_eps,
    "attention_probs_dropout_prob": attention_probs_dropout_prob,
    "hidden_dropout_prob": hidden_dropout_prob,
    "intermediate_size": intermed_size,
    "hidden_act": activ_fn,
    "max_position_embeddings": max_input_size,
    "model_type": model_type,
    "num_attention_heads": num_attn_heads,
    "pad_token_id": tokenizer.get("<pad>"),
    "vocab_size": len(tokenizer),  # genes+2 for <mask> and <pad> tokens
}

config = BertConfig(**config)
model = BertForMaskedLM(config).from_pretrained(
    "/home/xutingfeng/github_code/others/Geneformer"
)

In [78]:
# define the training arguments
training_args = {
    "learning_rate": max_lr,
    "do_train": True,
    "do_eval": False,
    "group_by_length": True,
    "length_column_name": "length",
    "disable_tqdm": False,
    "lr_scheduler_type": lr_schedule_fn,
    "warmup_steps": warmup_steps,
    "weight_decay": weight_decay,
    "per_device_train_batch_size": geneformer_batch_size,
    "num_train_epochs": epochs,
    "save_strategy": "steps",
    "save_steps": np.floor(
        num_examples / geneformer_batch_size / 8
    ),  # 8 saves per epoch
    "logging_steps": 1000,
    "output_dir": "test/",
    "logging_dir": "test/",
}
training_args = TrainingArguments(**training_args)

print("Starting training.")

Starting training.


In [79]:
trainer = Trainer(
    model=model,
    args=training_args,
    # pretraining corpus (e.g. https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/genecorpus_30M_2048.dataset)
    train_dataset=test_dataset,
    # file of lengths of each example cell (e.g. https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/blob/main/genecorpus_30M_2048_lengths.pkl)
    # example_lengths_file="genecorpus_30M_2048_lengths.pkl",
    # token_dictionary=token_dictionary,
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


In [80]:
for batch in trainer.get_train_dataloader():
    print(batch)
    break

ValueError: expected sequence of length 2870 at dim 1 (got 2869)

In [84]:
test_dataset[:2]["input_ids"]

[[20005,
  2529,
  3995,
  4346,
  16031,
  16399,
  10572,
  7964,
  222,
  10450,
  12059,
  20328,
  2492,
  170,
  19436,
  11987,
  8291,
  4696,
  13189,
  10904,
  4185,
  370,
  767,
  20520,
  6947,
  3358,
  4262,
  1285,
  8202,
  4959,
  13207,
  8826,
  5658,
  134,
  8319,
  5497,
  12671,
  3554,
  6740,
  10788,
  12504,
  6286,
  3567,
  15746,
  4563,
  1096,
  2905,
  6611,
  4345,
  12072,
  3897,
  9999,
  4106,
  15088,
  20449,
  1712,
  9735,
  328,
  9895,
  2990,
  4361,
  5577,
  13669,
  12511,
  15262,
  21088,
  6275,
  11746,
  21465,
  16459,
  7685,
  14416,
  12063,
  13732,
  2254,
  16598,
  9419,
  6373,
  5875,
  15917,
  2513,
  11755,
  11002,
  5219,
  14725,
  10898,
  12012,
  16834,
  145,
  1943,
  12870,
  13510,
  6870,
  16599,
  1797,
  6443,
  13402,
  10482,
  4817,
  7695,
  3190,
  5222,
  13959,
  8069,
  7005,
  2313,
  3667,
  12678,
  17031,
  8583,
  13509,
  1523,
  825,
  5887,
  5440,
  3766,
  3113,
  3663,
  12702,
  17304,

In [3]:
# tokenizer 采用geneformer
from geneformer.tokenizer import TOKEN_DICTIONARY_FILE

from __future__ import annotations

import logging
import pickle
import warnings
from pathlib import Path
from typing import Literal

import anndata as ad
import numpy as np
import scipy.sparse as sp
from datasets import Dataset

warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*")  # noqa
import loompy as lp  # noqa

logger = logging.getLogger(__name__)

import loompy as lp
import numpy as np


def rank_genes(gene_vector, gene_tokens):
    """
    Rank gene expression vector.
    """
    # sort by median-scaled gene values
    sorted_indices = np.argsort(-gene_vector)
    return gene_tokens[sorted_indices]


def tokenize_ind(gene_vector, gene_tokens):
    """
    Convert normalized gene expression vector to tokenized rank value encoding.
    """
    # create array of gene vector with token indices
    # mask undetected genes
    nonzero_mask = np.nonzero(gene_vector)[0]
    # rank by median-scaled gene values
    return rank_genes(gene_vector[nonzero_mask], gene_tokens[nonzero_mask])


class ProteomicsTokenizer:
    def __init__(
        self,
        custom_attr_name_dict=None,
        nproc=1,
        chunk_size=512,
        model_input_size=2048,
        special_token=False,
        # gene_median_file=GENE_MEDIAN_FILE,
        token_dictionary_file=TOKEN_DICTIONARY_FILE,
    ):
        """
        Initialize tokenizer.

        **Parameters:**

        custom_attr_name_dict : None, dict
            | Dictionary of custom attributes to be added to the dataset.
            | Keys are the names of the attributes in the loom file.
            | Values are the names of the attributes in the dataset.
        nproc : int
            | Number of processes to use for dataset mapping.
        chunk_size : int = 512
            | Chunk size for anndata tokenizer.
        model_input_size : int = 2048
            | Max input size of model to truncate input to.
        special_token : bool = False
            | Adds CLS token before and SEP token after rank value encoding.
        # gene_median_file : Path
        #     | Path to pickle file containing dictionary of non-zero median
        #     | gene expression values across Genecorpus-30M.
        token_dictionary_file : Path
            | Path to pickle file containing token dictionary (Ensembl IDs:token).

        """
        # dictionary of custom attributes {output dataset column name: input .loom column name}
        self.custom_attr_name_dict = custom_attr_name_dict

        # number of processes for dataset mapping
        self.nproc = nproc

        # chunk size for anndata tokenizer
        self.chunk_size = chunk_size

        # input size for tokenization
        self.model_input_size = model_input_size

        # add CLS and SEP tokens
        self.special_token = special_token

        # load dictionary of gene normalization factors
        # (non-zero median value of expression across Genecorpus-30M)
        # with open(gene_median_file, "rb") as f:
        #     self.gene_median_dict = pickle.load(f)

        # load token dictionary (Ensembl IDs:token)
        with open(token_dictionary_file, "rb") as f:
            self.gene_token_dict = pickle.load(f)

        # gene keys for full vocabulary
        self.gene_keys = list(self.gene_token_dict.keys())

        # protein-coding and miRNA gene list dictionary for selecting .loom rows for tokenization
        self.genelist_dict = dict(zip(self.gene_keys, [True] * len(self.gene_keys)))

    def tokenize_loom(self, loom_file_path, target_sum=10_000):
        if self.custom_attr_name_dict is not None:
            file_ind_metadata = {
                attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
            }

        with lp.connect(str(loom_file_path)) as data:
            # define coordinates of detected protein-coding or miRNA genes and vector of their normalization factors

            coding_miRNA_loc = np.where(
                [self.genelist_dict.get(i, False) for i in data.ra["ensembl_id"]]
            )[0]

            # norm_factor_vector = np.array(
            #     [
            #         self.gene_median_dict[i]
            #         for i in data.ra["ensembl_id"][coding_miRNA_loc]
            #     ]
            # )
            coding_miRNA_ids = data.ra["ensembl_id"][coding_miRNA_loc]

            not_in_gene_ids = set(data.ra["ensembl_id"]) - set(self.gene_keys)
            print(
                f"{len(not_in_gene_ids)} genes not in gene token dictionary, skipping them, some are: {list(not_in_gene_ids)[:5]}"
            )

            coding_miRNA_tokens = np.array(
                [self.gene_token_dict[i] for i in coding_miRNA_ids]
            )

            # define coordinates of individual passing filters for inclusion (e.g. QC)
            try:
                data.ca["filter_pass"]
            except AttributeError:
                var_exists = False
            else:
                var_exists = True

            if var_exists:
                filter_pass_loc = np.where([i == 1 for i in data.ca["filter_pass"]])[0]
            elif not var_exists:
                print(
                    f"{loom_file_path} has no column attribute 'filter_pass'; tokenizing all inds."
                )
                filter_pass_loc = np.array([i for i in range(data.shape[1])])

            # scan through .loom files and tokenize inds
            tokenized_ind = []
            for _ix, _selection, view in data.scan(
                items=filter_pass_loc, axis=1, batch_size=self.chunk_size
            ):
                # select subview with protein-coding and miRNA genes
                subview = view.view[coding_miRNA_loc, :]
                # Currently do not norm ,as the values is NPX by UKB

                # tokenize subview gene vectors
                tokenized_ind += [
                    tokenize_ind(subview[:, i], coding_miRNA_tokens)
                    for i in range(subview.shape[1])
                ]

                # add custom attributes for subview to dict
                if self.custom_attr_name_dict is not None:
                    for k in file_ind_metadata.keys():
                        file_ind_metadata[k] += subview.ca[k].tolist()
                else:
                    file_ind_metadata = None

        return tokenized_ind, file_ind_metadata

    def create_dataset(
        self,
        tokenized_inds,
        ind_metadata,
        use_generator=False,
        keep_uncropped_input_ids=False,
    ):
        print("Creating dataset.")
        # create dict for dataset creation
        dataset_dict = {"input_ids": tokenized_inds}
        if self.custom_attr_name_dict is not None:
            dataset_dict.update(ind_metadata)

        # create dataset
        if use_generator:

            def dict_generator():
                for i in range(len(tokenized_inds)):
                    yield {k: dataset_dict[k][i] for k in dataset_dict.keys()}

            output_dataset = Dataset.from_generator(dict_generator, num_proc=self.nproc)
        else:
            output_dataset = Dataset.from_dict(dataset_dict)

        def format_ind_features(example):
            # Store original uncropped input_ids in separate feature
            if keep_uncropped_input_ids:
                example["input_ids_uncropped"] = example["input_ids"]
                example["length_uncropped"] = len(example["input_ids"])

            # Truncate/Crop input_ids to input size
            if self.special_token:
                example["input_ids"] = example["input_ids"][
                    0 : self.model_input_size - 2
                ]  # truncate to leave space for CLS and SEP token
                example["input_ids"] = np.insert(
                    example["input_ids"], 0, self.gene_token_dict.get("<cls>")
                )
                example["input_ids"] = np.insert(
                    example["input_ids"],
                    len(example["input_ids"]),
                    self.gene_token_dict.get("<sep>"),
                )
            else:
                # Truncate/Crop input_ids to input size
                example["input_ids"] = example["input_ids"][0 : self.model_input_size]
            example["length"] = len(example["input_ids"])

            return example

        output_dataset_truncated = output_dataset.map(
            format_ind_features, num_proc=self.nproc
        )
        return output_dataset_truncated

In [4]:
from transformers import AutoTokenizer
import multiprocessing
from transformers import BertTokenizer, BertTokenizerFast, PreTrainedTokenizerFast


def group_texts(examples, max_length=2048):

    tokenized_inputs = tokenizer(
        examples["proteins"],
        return_special_tokens_mask=True,
        add_special_tokens=True,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        truncation_strategy="only_last",
    )

    return tokenized_inputs


# preprocess dataset
tokenized_datasets = test_dataset.map(
    group_texts,
    batched=True,
    remove_columns=["proteins"],
    num_proc=8,
)

Map (num_proc=8):   0%|          | 0/300 [00:00<?, ? examples/s]

In [5]:
tokenized_datasets[0]["input_ids"]

[2775,
 2537,
 1374,
 391,
 957,
 2323,
 996,
 1177,
 2028,
 1499,
 449,
 1542,
 2632,
 1242,
 512,
 1651,
 447,
 1198,
 191,
 1435,
 105,
 1413,
 1392,
 2084,
 1376,
 1433,
 26,
 2178,
 1032,
 2796,
 1516,
 940,
 78,
 1340,
 2173,
 2713,
 1753,
 1372,
 2829,
 102,
 2422,
 2077,
 749,
 199,
 390,
 2472,
 2233,
 1232,
 814,
 1020,
 1382,
 2161,
 1326,
 1156,
 1699,
 1128,
 1498,
 1782,
 737,
 1656,
 242,
 815,
 2093,
 748,
 332,
 2544,
 621,
 2528,
 1248,
 2492,
 381,
 727,
 421,
 1803,
 1200,
 1500,
 1621,
 1998,
 1727,
 399,
 2640,
 2788,
 2083,
 1949,
 2754,
 425,
 2892,
 884,
 106,
 1601,
 1997,
 397,
 653,
 1924,
 794,
 728,
 877,
 63,
 666,
 333,
 74,
 1021,
 406,
 2826,
 2648,
 1497,
 1954,
 256,
 1700,
 1306,
 1237,
 1219,
 1169,
 64,
 2035,
 444,
 70,
 13,
 1321,
 2367,
 1831,
 2441,
 658,
 711,
 2740,
 1019,
 526,
 162,
 850,
 633,
 2148,
 1332,
 591,
 1152,
 405,
 383,
 382,
 2020,
 1860,
 120,
 2141,
 1085,
 1314,
 1579,
 782,
 2032,
 971,
 602,
 123,
 200,
 473,
 1397,
 100

In [6]:
from transformers import DataCollatorForLanguageModeling


data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm_probability=0.15
)

## albert

In [9]:
from transformers import AlbertConfig, AlbertForMaskedLM

albertconfig = AlbertConfig()

albertconfig

AlbertConfig {
  "attention_probs_dropout_prob": 0,
  "bos_token_id": 2,
  "classifier_dropout_prob": 0.1,
  "embedding_size": 128,
  "eos_token_id": 3,
  "hidden_act": "gelu_new",
  "hidden_dropout_prob": 0,
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "inner_group_num": 1,
  "intermediate_size": 16384,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "albert",
  "num_attention_heads": 64,
  "num_hidden_groups": 1,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.39.3",
  "type_vocab_size": 2,
  "vocab_size": 30000
}

In [11]:
albertconfig = AlbertConfig(
    vocab_size=tokenizer.vocab_size,
    hidden_size=256,
    num_attention_heads=8,
    intermediate_size=512,
    max_position_embeddings=tokenizer.model_max_length,
    num_hidden_layers=6,
)


AlbertForMaskedLM(albertconfig)

AlbertForMaskedLM(
  (albert): AlbertModel(
    (embeddings): AlbertEmbeddings(
      (word_embeddings): Embedding(2916, 128, padding_idx=0)
      (position_embeddings): Embedding(2911, 128)
      (token_type_embeddings): Embedding(2, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0, inplace=False)
    )
    (encoder): AlbertTransformer(
      (embedding_hidden_mapping_in): Linear(in_features=128, out_features=256, bias=True)
      (albert_layer_groups): ModuleList(
        (0): AlbertLayerGroup(
          (albert_layers): ModuleList(
            (0): AlbertLayer(
              (full_layer_layer_norm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
              (attention): AlbertAttention(
                (query): Linear(in_features=256, out_features=256, bias=True)
                (key): Linear(in_features=256, out_features=256, bias=True)
                (value): Linear(in_features=256, out_features=256, bias=True)
  

## bert

In [9]:
from transformers import BertConfig

config = BertConfig()

config

BertConfig {
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.39.3",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

In [10]:
tokenizer.vocab_size

2916

In [11]:
bertconfig = BertConfig(
    vocab_size=tokenizer.vocab_size,
    hidden_size=256,
    num_attention_heads=8,
    intermediate_size=512,
    max_position_embeddings=tokenizer.model_max_length,
    num_hidden_layers=6,
)
from transformers import AutoModelForMaskedLM, BertForMaskedLM

BertForMaskedLM(bertconfig)

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(2916, 256, padding_idx=0)
      (position_embeddings): Embedding(2911, 256)
      (token_type_embeddings): Embedding(2, 256)
      (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-5): 6 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=256, out_features=256, bias=True)
              (key): Linear(in_features=256, out_features=256, bias=True)
              (value): Linear(in_features=256, out_features=256, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=256, out_features=256, bias=True)
              (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_aff

In [21]:
# from transformers import AutoModelForMaskedLM, BertForMaskedLM

# model = BertForMaskedLM(bertconfig)
from transformers import AutoModelForMaskedLM, BertForMaskedLM, AutoModelForPreTraining

model = BertForMaskedLM.from_pretrained("result/dl/ProteomicsBERT/checkpoint-500")

In [22]:
# from torch.
model.bert.embeddings.word_embeddings

Embedding(2916, 256, padding_idx=0)

In [23]:
from transformers import Trainer, TrainingArguments

In [24]:
training_args = TrainingArguments(
    output_dir="test",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=3,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    weight_decay=0.01,
    push_to_hub=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets,
    eval_dataset=tokenized_datasets,
    data_collator=data_collator,
)

# trainer.train()

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


In [25]:
trainer.train_dataset

Dataset({
    features: ['eid', 'values', 'input_ids', 'token_type_ids', 'attention_mask', 'special_tokens_mask'],
    num_rows: 300
})

In [26]:
for batch in trainer.get_train_dataloader():

    break

In [27]:
import torch


# with torch.no_grad():
o = model(**batch)

CrossEntropyLoss()


In [28]:
o.loss

tensor(8.0208, device='cuda:0', grad_fn=<NllLossBackward0>)

In [20]:
model

BertForPreTraining(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(2916, 256, padding_idx=0)
      (position_embeddings): Embedding(2916, 256)
      (token_type_embeddings): Embedding(2, 256)
      (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-5): 6 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=256, out_features=256, bias=True)
              (key): Linear(in_features=256, out_features=256, bias=True)
              (value): Linear(in_features=256, out_features=256, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=256, out_features=256, bias=True)
              (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_

In [37]:
torch.argmax(o.prediction_logits, dim=-1)

tensor([[1643, 2178,  574,  ...,  178, 2199, 1496],
        [ 105, 1243, 2477,  ..., 2012, 1496, 2007]], device='cuda:0')

In [38]:
batch["input_ids"]

tensor([[2775, 2537, 1374,  ..., 2359, 2214,  948],
        [1153,  835, 2225,  ..., 1497,    3, 1304]], device='cuda:0')

In [39]:
batch["labels"]

tensor([[-100, -100, -100,  ..., -100, -100, -100],
        [-100, -100, -100,  ..., -100, 1606, -100]], device='cuda:0')

In [40]:
torch.nn.CrossEntropyLoss()(
    o.prediction_logits.view(-1, tokenizer.vocab_size), batch["labels"].view(-1)
)

tensor(8.0299, device='cuda:0', grad_fn=<NllLossBackward0>)

In [42]:
 o.prediction_logits.view(-1, tokenizer.vocab_size).shape

torch.Size([4096, 2916])

In [42]:
batch["token_type_ids"]

tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]], device='cuda:0')

In [43]:
batch["input_ids"]

tensor([[2775, 2537, 1374,  ..., 2359, 2214,  948],
        [1153,  835, 2225,  ..., 1497, 1453, 1304],
        [1837, 1500, 2624,  ..., 1265,    3, 1918],
        ...,
        [ 371, 2825,  943,  ...,  524, 2886, 1549],
        [2537, 2829,    3,  ..., 1380, 1278,  739],
        [ 389, 2537, 2488,  ..., 1663, 1429, 1442]], device='cuda:0')

In [48]:
batch.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'labels'])

In [46]:
tokenizer.decode(batch["input_ids"][0][:10])

'trim26 spint3 il17f ccl20 erc2 s100a14 fam171a2 [MASK] pglyrp1 kir2ds4'

In [44]:
batch["input_ids"][0][:10]

tensor([2775, 2537, 1374,  391,  957, 2323,  996,    3, 2028, 1499],
       device='cuda:0')

In [45]:
batch["labels"][0][:10]

tensor([-100, -100, -100, -100, -100, -100, -100, 1177, -100, -100],
       device='cuda:0')

In [None]:
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))


In [28]:
o.loss

In [20]:
model = model.from_pretrained("result/dl/ProteomicsBERT/checkpoint-99500")

In [21]:
batch

{'input_ids': tensor([[2778, 2540, 1377,  ...,  374,  856, 1287],
        [1159,  839, 2228,  ..., 1552,  565, 1183],
        [1840, 1503, 2627,  ..., 1319,    3,  731],
        ...,
        [ 375, 2828,  949,  ..., 2871, 1697,  394],
        [2540, 2832,    3,  ..., 1667,   56, 1649],
        [ 393, 2540, 2491,  ...,  121,  984,  182]], device='cuda:0'), 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]], device='cuda:0'), 'labels': tensor([[-100, -100, -100,  ..., -100, -100, -100],
        [-100, -100, -100,  ..., -100, 2641, -100],
        [-100, -100,

In [25]:
model.to("cpu")
o = model(batch)
o

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [24]:
# model(batch)
batch

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
from transformers import AutoConfig


config = AutoConfig.from_pretrained(
    "distilbert/distilroberta-base",
    vocab_size=len(tokenizer),
    n_ctx=context_length,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)

In [None]:
from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained("distilbert/distilroberta-base")

In [None]:
data_collator

In [None]:
from transformers import BertTokenizer, BertTokenizerFast, PreTrainedTokenizerFast


protein_tokenizer = BertTokenizerFast.from_pretrained("transtab/tokenizer")


def tokenize_function(examples):
    return protein_tokenizer(examples["proteins"])


tokenized_datasets = test_dataset.map(
    tokenize_function, batched=True, num_proc=4, batch_size=32
)

In [None]:
protein_tokenizer(test_dataset[0]["proteins"])

In [None]:
# res_list = {}


# for idx, row in test_data.set_index("eid").iterrows():
#     ranked_row = row.sort_values(ascending=False).dropna()
#     res = {}
#     res["eid"] = ranked_row.name
#     res["proteins"] = ranked_row.index.tolist()
#     res["values"] = ranked_row.values.tolist()
#     res_list.append(res)


# def dict_generator():
#     for i in range(len(res_list)):
#         yield res_list[i]

# Dataset.from_generator(dict_generator, num_proc=4)

In [None]:
Dataset.from_generator(res_list)

In [None]:
import numpy as np


def rank_genes(gene_vector, gene_tokens):
    """
    Rank gene expression vector.
    """
    # sort by median-scaled gene values
    sorted_indices = np.argsort(-gene_vector)
    return gene_tokens[sorted_indices], gene_vector[sorted_indices]


def rank_sorted(examples, protein_cols=None):
    """
    Rank protein expression vectors.
    """

    protein_expression_vectors = np.array([examples[col] for col in protein_cols])
    protein_tokens = protein_cols
    ranked_protein_tokens, ranked_protein_expression_vectors = rank_genes(
        protein_expression_vectors, protein_tokens
    )

    other_cols = [col for col in examples.columns if col not in protein_cols]
    return_res = {}
    for i, col in enumerate(other_cols):
        return_res[col] = examples[col].values
    return_res["protein_tokens"] = ranked_protein_tokens
    return_res["protein_expression_vectors"] = ranked_protein_expression_vectors
    return return_res


test_dataset.map(lambda x: rank_sorted(x, protein_cols=protein_cols))

In [None]:
# normalize = True 
# if normalize:
#     # pd.concat([train_data, test_data])[protein_cols]