# import

In [2]:
############################ import

import scanpy as sc
import pandas as pd
import numpy as np
from glob import glob
import os
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
from torch_geometric.data import Data
from torch_geometric.data import DataLoader

import pickle
import sys
import requests
import json
import time
import copy
from pathlib import Path
from typing import Iterable, List, Tuple, Dict, Union, Optional
import warnings

from types import MethodType

import importlib
from scperturb import *
import anndata as ad

import argparse


from gears import PertData, GEARS
from gears.inference import compute_metrics, non_dropout_analysis

sys.path.insert(0, "/data1/lichen/code/single_cell_perturbation/others/scGPT/")
sys.path.append("/data1/lichen/code/single_cell_perturbation/scPerturb/Byte_Pert_Data/")

import scgpt as scg
# from scgpt.model import TransformerGenerator
from scgpt.loss import (
    masked_mse_loss,
    # criterion_neg_log_bernoulli,
    masked_relative_error,
)

import v1
from v1.utils import *
from v1.dataloader import *

# importlib.reload(v1)
# importlib.reload(v1.utils)  
# importlib.reload(v1.dataloader)


from config import prefix_list

from torch import nn
sys.path.append("/data1/lichen/code/single_cell_perturbation/others/scFoundation-main/model/") # path to this folder
from load import *

# function

In [3]:
############################ function

class scF_finetune_model(nn.Module):

    def __init__(self, ckpt_path,frozenmore=True):
        super().__init__()
        self.ckpt_path = ckpt_path
        self.frozenmore = frozenmore

    def build(self,
              pert_pad_id = 2):
        model,model_config = load_model_frommmf(self.ckpt_path)
        self.token_emb = model.token_emb
        self.pos_emb = model.pos_emb
        self.encoder = model.encoder

        self.pert_emb = nn.Embedding(3, model_config['encoder']['hidden_dim'], padding_idx=pert_pad_id)
        
        if self.frozenmore:
            for _,p in self.token_emb.named_parameters():
                p.requires_grad = False
            for _,p in self.pos_emb.named_parameters():
                p.requires_grad = False
            print('self.pos_emb and self.token_emb also frozen')
        
            for na, param in self.encoder.named_parameters():
                param.requires_grad = False
            for na, param in self.encoder.transformer_encoder[-2].named_parameters():
                print('self.encoder.transformer_encoder ',na,' have grad')
                param.requires_grad = True

        else:
            # - make all the layers able to train
            None

        self.fc = nn.Linear(model_config['encoder']['hidden_dim'], 1)
        self.fc1 = nn.Sequential(
        nn.Linear(model_config['encoder']['hidden_dim'], 256),
        nn.ReLU(),
        nn.Linear(256, 10)  # ['n_class']
        ) 
        self.norm = torch.nn.BatchNorm1d(model_config['encoder']['hidden_dim'], affine=False, eps=1e-6)
        self.model_config = model_config
        
    def forward(self, 
                ori_gene_values,
                pert_flags,
                position_gene_ids,
                *args, **kwargs):
        
        x = ori_gene_values
        value_labels = torch.ones_like(
            x, dtype=torch.bool
        )

        x_padding = x.eq(self.model_config['pad_token_id'])
        x = self.token_emb(torch.unsqueeze(x, 2).float(), output_weight = 0)

        # position_gene_ids = position_gene_ids.reshape(-1, 1).repeat(x.shape, 1)
        position_emb = self.pos_emb(position_gene_ids)
        x += position_emb

        x += self.pert_emb(pert_flags)

        logits = self.encoder(x, x_padding)

        logits = self.fc(logits)
        logits = logits.squeeze(-1)

        # # mlp
        # logits, _ = torch.max(logits, dim=1)  # b,dim

        # logits = self.norm(logits)
        # logits = self.fc1(logits)

        return logits

def pred_perturb_new(
    model,
    batch_data,
    # include_zero_gene="batch-wise",
    # gene_ids=None,
    amp=True,
):
    """
    Args:
        batch_data: a dictionary of input data with keys.

    Returns:
        output Tensor of shape [N, seq_len]
    """
    model.eval()
    device = next(model.parameters()).device
    batch_data.to(device)
    
    input_values = batch_data.x
    pert_flags = batch_data.pert_flags.long()
    mapped_input_gene_ids = batch_data.mapped_input_gene_ids


    with torch.cuda.amp.autocast(enabled=amp):
        output_values = model(
            input_values,
            pert_flags,
            mapped_input_gene_ids
        )
    return output_values

# init

In [22]:
################################ init para
# init para

# - init dataloader para
data_dir = '/nfs/public/lichen/data/single_cell/perturb_data/scPerturb/raw/scPerturb_rna/statistic_20240520'
pert_cell_filter = 100 # this is used to filter perts, cell number less than this will be filtered
seed = 2024 # this is the random seed
split_type = 1 # 1 for unseen perts; 0 for unseen celltypes
split_ratio = [0.7, 0.2, 0.1] # train:test:val; val is used to choose data, test is for final validation
var_num = 5000 # selecting hvg number
num_de_genes = 20 # number of de genes
# bs_train = 2 # batch size of trainloader
bs_test =  2 # batch size of testloader
lr = 1e-4

# - multi gpu para
n_gpu = 1
device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
# device_ids = list(range(n_gpu))
device_ids = [3]

# - training epoch
epochs = 20
max_seq_len = 6000
early_stop = 20
save_flag = True

# - intial the loss para
amp = True
schedule_interval = 1
lr = 1e-4
include_zero_gene = "all"
log_interval = 100
mask_value = -1


gene_mode = 'whole'

In [21]:
dataset_pert_dict = {
    'CAR_T': ['PDCD1'],
    'blood': ['GATA1', 'SPI1'],
    'OSKM': [['SOX2',
         'POU5F1',
         'KLF4',
         'MYC']],
    'ADM': ['PTF1A']
}
dataset_celltype_dict = {
    'CAR_T': 'Tex', 
    'blood': 'LMPP',
    'OSKM': 'Fibroblast-like',
    'ADM': 'Acinar'
}


dataset_dire_dict = {
    'CAR_T': 'down', 
    'blood': 'down',
    'OSKM': 'up',
    'ADM': 'down'
}


datasets = list(dataset_pert_dict.keys())

# dataset = datasets[0]

# main

In [25]:
parser = argparse.ArgumentParser(description="scF apply L1000")
parser.add_argument('--cell_line_bulk', type=str, default=None)
parser.add_argument('--model_mode', type=str, default=None) # pretrain, init
args = parser.parse_args([])

args.model_mode = 'pretrain'
model_mode = args.model_mode

adata_mode = 'non_minus'
# minus: save adata as the minus delta
# non_minus: save adata, add the minus delta on the original gene exp

for dataset in datasets[:]:
    # dataset = 'OSKM'
    # dataset = 'blood'
    # dataset = 'CAR_T'
    print('='*20, dataset, '='*20)

    ckpt_path = '/data1/lichen/code/single_cell_perturbation/others/scFoundation-main/model/models/models.ckpt'
    model,model_config = load_model_frommmf(ckpt_path)

    save_dir = '/nfs/public/lichen/results/single_cell_perturbation/perturbation_benchmark/real_case/data'
    adata_rna = sc.read(os.path.join(save_dir, dataset, 'adata_ctrl_v2.h5ad'))
    if not isinstance(adata_rna.X, np.ndarray):
        adata_rna.X = adata_rna.X.toarray()

    #====================get scF common, and mapped gene ids
    adata = adata_rna.copy()
    X_df= pd.DataFrame(adata.X, index=adata.obs.index.tolist(), columns=adata.var.index.tolist()) # read from csv file
    gene_list_df = pd.read_csv('/data1/lichen/code/single_cell_perturbation/others/scFoundation-main/OS_scRNA_gene_index.19264.tsv', header=0, delimiter='\t')
    scF_gene_list = list(gene_list_df['gene_name'])
    print('common genes are: ', len(np.intersect1d(scF_gene_list, adata.var_names)))

    # - transform our adata to scF adata
    from scRNA_workflow import *
    X_df, to_fill_columns, var = main_gene_selection(X_df, scF_gene_list)
    adata_uni = sc.AnnData(X_df)
    adata_uni.obs = adata.obs
    adata_uni.uns = adata.uns

    if gene_mode == 'common':
        # - get common genes
        common_genes = np.intersect1d(adata_rna.var_names, adata_uni.var_names)
        print('common genes to scF gene list is: ', f'{len(common_genes)}/{len(adata_rna.var_names)}')
        # - get the gene_id [common_gene_ids are the input to scF, gene positions]
        common_gene_ids = np.array([list(adata_uni.var_names).index(gene) for gene in common_genes])
        common_gene_ids = torch.tensor(common_gene_ids)

        adata_ctrl = adata_rna[:, common_genes].copy()
    elif gene_mode == 'whole':
        common_genes = np.array(adata_rna.var_names)
        # - get the gene_id [common_gene_ids are the input to scF, gene positions]
        common_gene_ids = np.array([list(adata_uni.var_names).index(gene) if gene in adata_uni.var_names else model_config['pad_token_id'] for gene in common_genes])
        common_gene_ids = torch.tensor(common_gene_ids)
        adata_ctrl = adata_rna.copy()
    else:
        raise ValueError()


    #====================initial the model
    # - initial model
    scF_model = scF_finetune_model(ckpt_path=ckpt_path,
                                    frozenmore = False)
    scF_model.build()

    if model_mode == 'pretrain':
        # -- load out weight
        # model_file = '/hpc-cache-pfs/home/lichen/result/single_cell_perturbation/scFoundation_pretrain/down/model_best.pt'

        if dataset_dire_dict[dataset] == 'down':
            model_file = '/nfs/public/lichen/results/single_cell_perturbation/perturbation_benchmark/volc_result/scFoundation_pretrain_v5/down/model_best.pt'
        if dataset_dire_dict[dataset] == 'up':
            model_file = '/nfs/public/lichen/results/single_cell_perturbation/perturbation_benchmark/volc_result/scFoundation_pretrain_v5/up/model_best.pt'
        # 加载保存的模型权重
        saved_state_dict = torch.load(model_file)
        from collections import OrderedDict
        # 创建一个新的字典来存储修改后的权重
        new_state_dict = OrderedDict()
        # 修改键名称
        for key, value in saved_state_dict.items():
            new_key = key.replace('module.', '')  # 移除'module.'前缀
            new_state_dict[new_key] = value
        # 加载修改后的权重
        scF_model.load_state_dict(new_state_dict)

    elif model_mode == 'init':
        None
    else:
        raise ValueError()

    # - add the parallel
    model = torch.nn.DataParallel(scF_model, device_ids=device_ids)
    # - put model on device
    model.to(device)
    best_model = copy.deepcopy(model)

    for pert in tqdm(dataset_pert_dict[dataset]):

        # pert = 'SPI1'
        print('*'*20, pert)

        if isinstance(pert, str):
            pert_combo = [pert]
        else:
            pert_combo = pert

        # - prepare the test_loader
        Xs = adata_ctrl.X
        var_names = list(adata_ctrl.var_names)

        if not isinstance(Xs, np.ndarray):
            Xs = Xs.toarray()

        #================================= get ctrl output
        cell_graphs = []
        for X in Xs:
            # - pert_flags for multi perts
            pert_flags = torch.zeros(Xs.shape[1])
            for tmp_pert in pert_combo:
                if tmp_pert not in var_names:
                    raise ValueError(f'{pert} not in var_names')
                else:
                    pert_flags[var_names.index(tmp_pert)] = 0
            tmp_Data = Data(x = torch.Tensor(X.reshape(1, -1)),
                        pert_flags = pert_flags.reshape(1, -1),
                        mapped_input_gene_ids = common_gene_ids.reshape(1, -1))
            cell_graphs.append(tmp_Data)

        test_loader = DataLoader(cell_graphs,
                                batch_size=bs_test * n_gpu, 
                                shuffle=False)

        # - infer the data
        pred = []
        for itr, batch in enumerate(test_loader):
            batch.to(device)
            with torch.no_grad():
                p = pred_perturb_new(best_model, batch)
                pred.extend(p.cpu())
        pred = torch.stack(pred)
        pred = pred.detach().cpu().numpy()
        pred_ctrl = pred.copy()

        #================================= get pert output
        cell_graphs = []
        for X in Xs:
            # - pert_flags for multi perts
            pert_flags = torch.zeros(Xs.shape[1])
            for tmp_pert in pert_combo:
                if tmp_pert not in var_names:
                    raise ValueError(f'{pert} not in var_names')
                else:
                    pert_flags[var_names.index(tmp_pert)] = 1
            tmp_Data = Data(x = torch.Tensor(X.reshape(1, -1)),
                        pert_flags = pert_flags.reshape(1, -1),
                        mapped_input_gene_ids = common_gene_ids.reshape(1, -1))
            cell_graphs.append(tmp_Data)

        test_loader = DataLoader(cell_graphs,
                                batch_size=bs_test * n_gpu, 
                                shuffle=False)

        # - infer the data
        pred = []
        for itr, batch in enumerate(test_loader):
            batch.to(device)
            with torch.no_grad():
                p = pred_perturb_new(best_model, batch)
                pred.extend(p.cpu())
        pred = torch.stack(pred)
        pred = pred.detach().cpu().numpy()
        pred_pert = pred.copy()

        pert_prefix = '_'.join(pert_combo)
        tmp_dir = f'/nfs/public/lichen/results/single_cell_perturbation/perturbation_benchmark_202410/real_case/result/{dataset}'
        save_prefix = f'scFoundation/{pert_prefix}' # use result of K562 to do the direct transfer
        os.makedirs(os.path.join(tmp_dir, save_prefix), exist_ok=True)

        adata_pert = adata_rna.copy()

        if adata_mode == 'minus':
            adata_pert.X = pred_pert - pred_ctrl
            adata_pert.obs_names = [i+f'_{pert_prefix}' for i in adata_pert.obs_names]
            adata_pert.write(os.path.join(tmp_dir, save_prefix, 'adata_pert_minus.h5ad'))

        elif adata_mode == 'non_minus':
            adata_pert.X += pred_pert - pred_ctrl
            adata_pert.obs_names = [i+f'_{pert_prefix}' for i in adata_pert.obs_names]
            adata_pert.write(os.path.join(tmp_dir, save_prefix, 'adata_pert.h5ad'))
        else:
            raise ValueError()

    #     break
    # break


{'mask_gene_name': False, 'gene_num': 19266, 'seq_len': 19266, 'encoder': {'hidden_dim': 768, 'depth': 12, 'heads': 12, 'dim_head': 64, 'seq_len': 19266, 'module_type': 'transformer', 'norm_first': False}, 'decoder': {'hidden_dim': 512, 'depth': 6, 'heads': 8, 'dim_head': 64, 'module_type': 'performer', 'seq_len': 19266, 'norm_first': False}, 'n_class': 104, 'pad_token_id': 103, 'mask_token_id': 102, 'bin_num': 100, 'bin_alpha': 1.0, 'rawcount': True, 'model': 'mae_autobin', 'test_valid_train_idx_dict': '/nfs_beijing/minsheng/data/os10000w-new/global_shuffle/meta.csv.train_set_idx_dict.pt', 'valid_data_path': '/nfs_beijing/minsheng/data/valid_count_10w.npz', 'num_tokens': 13, 'train_data_path': None, 'isPanA': False, 'isPlanA1': False, 'max_files_to_load': 5, 'bin_type': 'auto_bin', 'value_mask_prob': 0.3, 'zero_mask_prob': 0.03, 'replace_prob': 0.8, 'random_token_prob': 0.1, 'mask_ignore_token_ids': [0], 'decoder_add_zero': True, 'mae_encoder_max_seq_len': 15000, 'isPlanA': False, 'ma



common genes are:  2376
{'mask_gene_name': False, 'gene_num': 19266, 'seq_len': 19266, 'encoder': {'hidden_dim': 768, 'depth': 12, 'heads': 12, 'dim_head': 64, 'seq_len': 19266, 'module_type': 'transformer', 'norm_first': False}, 'decoder': {'hidden_dim': 512, 'depth': 6, 'heads': 8, 'dim_head': 64, 'module_type': 'performer', 'seq_len': 19266, 'norm_first': False}, 'n_class': 104, 'pad_token_id': 103, 'mask_token_id': 102, 'bin_num': 100, 'bin_alpha': 1.0, 'rawcount': True, 'model': 'mae_autobin', 'test_valid_train_idx_dict': '/nfs_beijing/minsheng/data/os10000w-new/global_shuffle/meta.csv.train_set_idx_dict.pt', 'valid_data_path': '/nfs_beijing/minsheng/data/valid_count_10w.npz', 'num_tokens': 13, 'train_data_path': None, 'isPanA': False, 'isPlanA1': False, 'max_files_to_load': 5, 'bin_type': 'auto_bin', 'value_mask_prob': 0.3, 'zero_mask_prob': 0.03, 'replace_prob': 0.8, 'random_token_prob': 0.1, 'mask_ignore_token_ids': [0], 'decoder_add_zero': True, 'mae_encoder_max_seq_len': 1500

  0%|          | 0/1 [00:00<?, ?it/s]

******************** PDCD1


100%|██████████| 1/1 [01:38<00:00, 98.44s/it]


{'mask_gene_name': False, 'gene_num': 19266, 'seq_len': 19266, 'encoder': {'hidden_dim': 768, 'depth': 12, 'heads': 12, 'dim_head': 64, 'seq_len': 19266, 'module_type': 'transformer', 'norm_first': False}, 'decoder': {'hidden_dim': 512, 'depth': 6, 'heads': 8, 'dim_head': 64, 'module_type': 'performer', 'seq_len': 19266, 'norm_first': False}, 'n_class': 104, 'pad_token_id': 103, 'mask_token_id': 102, 'bin_num': 100, 'bin_alpha': 1.0, 'rawcount': True, 'model': 'mae_autobin', 'test_valid_train_idx_dict': '/nfs_beijing/minsheng/data/os10000w-new/global_shuffle/meta.csv.train_set_idx_dict.pt', 'valid_data_path': '/nfs_beijing/minsheng/data/valid_count_10w.npz', 'num_tokens': 13, 'train_data_path': None, 'isPanA': False, 'isPlanA1': False, 'max_files_to_load': 5, 'bin_type': 'auto_bin', 'value_mask_prob': 0.3, 'zero_mask_prob': 0.03, 'replace_prob': 0.8, 'random_token_prob': 0.1, 'mask_ignore_token_ids': [0], 'decoder_add_zero': True, 'mae_encoder_max_seq_len': 15000, 'isPlanA': False, 'ma



common genes are:  1143
{'mask_gene_name': False, 'gene_num': 19266, 'seq_len': 19266, 'encoder': {'hidden_dim': 768, 'depth': 12, 'heads': 12, 'dim_head': 64, 'seq_len': 19266, 'module_type': 'transformer', 'norm_first': False}, 'decoder': {'hidden_dim': 512, 'depth': 6, 'heads': 8, 'dim_head': 64, 'module_type': 'performer', 'seq_len': 19266, 'norm_first': False}, 'n_class': 104, 'pad_token_id': 103, 'mask_token_id': 102, 'bin_num': 100, 'bin_alpha': 1.0, 'rawcount': True, 'model': 'mae_autobin', 'test_valid_train_idx_dict': '/nfs_beijing/minsheng/data/os10000w-new/global_shuffle/meta.csv.train_set_idx_dict.pt', 'valid_data_path': '/nfs_beijing/minsheng/data/valid_count_10w.npz', 'num_tokens': 13, 'train_data_path': None, 'isPanA': False, 'isPlanA1': False, 'max_files_to_load': 5, 'bin_type': 'auto_bin', 'value_mask_prob': 0.3, 'zero_mask_prob': 0.03, 'replace_prob': 0.8, 'random_token_prob': 0.1, 'mask_ignore_token_ids': [0], 'decoder_add_zero': True, 'mae_encoder_max_seq_len': 1500

  0%|          | 0/2 [00:00<?, ?it/s]

******************** GATA1


 50%|█████     | 1/2 [02:29<02:29, 149.59s/it]

******************** SPI1


100%|██████████| 2/2 [05:01<00:00, 150.93s/it]


{'mask_gene_name': False, 'gene_num': 19266, 'seq_len': 19266, 'encoder': {'hidden_dim': 768, 'depth': 12, 'heads': 12, 'dim_head': 64, 'seq_len': 19266, 'module_type': 'transformer', 'norm_first': False}, 'decoder': {'hidden_dim': 512, 'depth': 6, 'heads': 8, 'dim_head': 64, 'module_type': 'performer', 'seq_len': 19266, 'norm_first': False}, 'n_class': 104, 'pad_token_id': 103, 'mask_token_id': 102, 'bin_num': 100, 'bin_alpha': 1.0, 'rawcount': True, 'model': 'mae_autobin', 'test_valid_train_idx_dict': '/nfs_beijing/minsheng/data/os10000w-new/global_shuffle/meta.csv.train_set_idx_dict.pt', 'valid_data_path': '/nfs_beijing/minsheng/data/valid_count_10w.npz', 'num_tokens': 13, 'train_data_path': None, 'isPanA': False, 'isPlanA1': False, 'max_files_to_load': 5, 'bin_type': 'auto_bin', 'value_mask_prob': 0.3, 'zero_mask_prob': 0.03, 'replace_prob': 0.8, 'random_token_prob': 0.1, 'mask_ignore_token_ids': [0], 'decoder_add_zero': True, 'mae_encoder_max_seq_len': 15000, 'isPlanA': False, 'ma



common genes are:  2812
{'mask_gene_name': False, 'gene_num': 19266, 'seq_len': 19266, 'encoder': {'hidden_dim': 768, 'depth': 12, 'heads': 12, 'dim_head': 64, 'seq_len': 19266, 'module_type': 'transformer', 'norm_first': False}, 'decoder': {'hidden_dim': 512, 'depth': 6, 'heads': 8, 'dim_head': 64, 'module_type': 'performer', 'seq_len': 19266, 'norm_first': False}, 'n_class': 104, 'pad_token_id': 103, 'mask_token_id': 102, 'bin_num': 100, 'bin_alpha': 1.0, 'rawcount': True, 'model': 'mae_autobin', 'test_valid_train_idx_dict': '/nfs_beijing/minsheng/data/os10000w-new/global_shuffle/meta.csv.train_set_idx_dict.pt', 'valid_data_path': '/nfs_beijing/minsheng/data/valid_count_10w.npz', 'num_tokens': 13, 'train_data_path': None, 'isPanA': False, 'isPlanA1': False, 'max_files_to_load': 5, 'bin_type': 'auto_bin', 'value_mask_prob': 0.3, 'zero_mask_prob': 0.03, 'replace_prob': 0.8, 'random_token_prob': 0.1, 'mask_ignore_token_ids': [0], 'decoder_add_zero': True, 'mae_encoder_max_seq_len': 1500

  0%|          | 0/1 [00:00<?, ?it/s]

******************** ['SOX2', 'POU5F1', 'KLF4', 'MYC']


100%|██████████| 1/1 [13:27<00:00, 807.37s/it]


{'mask_gene_name': False, 'gene_num': 19266, 'seq_len': 19266, 'encoder': {'hidden_dim': 768, 'depth': 12, 'heads': 12, 'dim_head': 64, 'seq_len': 19266, 'module_type': 'transformer', 'norm_first': False}, 'decoder': {'hidden_dim': 512, 'depth': 6, 'heads': 8, 'dim_head': 64, 'module_type': 'performer', 'seq_len': 19266, 'norm_first': False}, 'n_class': 104, 'pad_token_id': 103, 'mask_token_id': 102, 'bin_num': 100, 'bin_alpha': 1.0, 'rawcount': True, 'model': 'mae_autobin', 'test_valid_train_idx_dict': '/nfs_beijing/minsheng/data/os10000w-new/global_shuffle/meta.csv.train_set_idx_dict.pt', 'valid_data_path': '/nfs_beijing/minsheng/data/valid_count_10w.npz', 'num_tokens': 13, 'train_data_path': None, 'isPanA': False, 'isPlanA1': False, 'max_files_to_load': 5, 'bin_type': 'auto_bin', 'value_mask_prob': 0.3, 'zero_mask_prob': 0.03, 'replace_prob': 0.8, 'random_token_prob': 0.1, 'mask_ignore_token_ids': [0], 'decoder_add_zero': True, 'mae_encoder_max_seq_len': 15000, 'isPlanA': False, 'ma



common genes are:  4323
{'mask_gene_name': False, 'gene_num': 19266, 'seq_len': 19266, 'encoder': {'hidden_dim': 768, 'depth': 12, 'heads': 12, 'dim_head': 64, 'seq_len': 19266, 'module_type': 'transformer', 'norm_first': False}, 'decoder': {'hidden_dim': 512, 'depth': 6, 'heads': 8, 'dim_head': 64, 'module_type': 'performer', 'seq_len': 19266, 'norm_first': False}, 'n_class': 104, 'pad_token_id': 103, 'mask_token_id': 102, 'bin_num': 100, 'bin_alpha': 1.0, 'rawcount': True, 'model': 'mae_autobin', 'test_valid_train_idx_dict': '/nfs_beijing/minsheng/data/os10000w-new/global_shuffle/meta.csv.train_set_idx_dict.pt', 'valid_data_path': '/nfs_beijing/minsheng/data/valid_count_10w.npz', 'num_tokens': 13, 'train_data_path': None, 'isPanA': False, 'isPlanA1': False, 'max_files_to_load': 5, 'bin_type': 'auto_bin', 'value_mask_prob': 0.3, 'zero_mask_prob': 0.03, 'replace_prob': 0.8, 'random_token_prob': 0.1, 'mask_ignore_token_ids': [0], 'decoder_add_zero': True, 'mae_encoder_max_seq_len': 1500

  0%|          | 0/1 [00:00<?, ?it/s]

******************** PTF1A


100%|██████████| 1/1 [02:28<00:00, 148.47s/it]


In [11]:
dataset

'CAR_T'