# Preprocessing Code

In [None]:
import ast
import os
import pickle

import pandas as pd
import scanpy as sc
import scvi
from imblearn.over_sampling import RandomOverSampler, SMOTE
from imblearn.under_sampling import RandomUnderSampler, TomekLinks

## Oversampling code

In [None]:
adata = sc.read("data/ms_train.h5ad")
# Map each distinct cell-type name to a contiguous integer label (in the order
# returned by `unique()`), plus the inverse mapping.
type_to_index = {celltype: idx for idx, celltype in enumerate(adata.obs['celltype'].unique())}
index_to_type = {idx: celltype for celltype, idx in type_to_index.items()}

In [None]:
# Balance cell types by randomly duplicating minority-class cells with
# imblearn's RandomOverSampler. Labels are the integer indices from
# `type_to_index` (built in the previous cell).
X = adata.X
vr = adata.var
y = [type_to_index[celltype] for celltype in adata.obs['celltype']]

sampler = RandomOverSampler(random_state=42)
X_resampled, y_resampled = sampler.fit_resample(X, y)

# scanpy re-exports AnnData, so no separate anndata import is needed here.
adata_resampled = sc.AnnData(X_resampled, obs={'celltype': y_resampled}, var=vr)

## Undersampling code

In [None]:
adata = sc.read("data/ms_train.h5ad")
# Integer label -> cell-type name, numbered in `unique()` order; then invert it.
index_to_type = dict(enumerate(adata.obs['celltype'].unique()))
type_to_index = {celltype: idx for idx, celltype in index_to_type.items()}

In [None]:
def undersample_with_threshold(adata, threshold, random_state=42):
    """Randomly reduce every cell type with at least ``threshold`` cells to ``threshold`` cells.

    Cell types with fewer than ``threshold`` cells are passed through unchanged.

    Relies on the notebook-level ``type_to_index`` dict (built in an earlier
    cell) to turn cell-type names into integer labels.

    Args:
        adata: AnnData whose ``.obs`` has a ``'celltype'`` column.
        threshold: number of cells to keep for each over-represented cell type.
        random_state: seed for the initial ``DataFrame.sample`` call and for
            ``RandomUnderSampler``.

    Returns:
        A new AnnData in which ``obs['celltype']`` holds integer labels from
        ``type_to_index`` and ``obs['str_batch']`` is 0 for every cell. The
        undersampled cell types come first, followed by the kept ones.

    Raises:
        ValueError: if no cell type has at least ``threshold`` cells.
    """
    celltype_col = adata.obs['celltype']
    counts = celltype_col.value_counts()
    to_undersample = [celltype for celltype, n in counts.items() if n >= threshold]
    to_keep = [celltype for celltype, n in counts.items() if n < threshold]
    if not to_undersample:
        raise ValueError(f"no cell type has at least {threshold} cells; nothing to undersample")

    # First shrink the smallest over-threshold class to exactly `threshold`
    # cells. RandomUnderSampler then equalizes all classes to the minority count,
    # which brings every over-threshold class down to `threshold`.
    min_class = min(to_undersample, key=counts.get)
    min_class_obs = adata[celltype_col == min_class].obs
    threshold_class = adata[min_class_obs.sample(threshold, random_state=random_state).index]
    undersample_adata = adata[celltype_col != min_class].concatenate(threshold_class)

    # Copy so the label rewrite below modifies a real object, not a view.
    keep_data = adata[celltype_col.isin(to_keep)].copy()
    undersample_data = undersample_adata[undersample_adata.obs['celltype'].isin(to_undersample)]

    X = undersample_data.X
    y = [type_to_index[celltype] for celltype in undersample_data.obs['celltype']]

    under_sampler = RandomUnderSampler(random_state=random_state)
    X_resampled, y_resampled = under_sampler.fit_resample(X, y)

    adata_resampled = sc.AnnData(
        X_resampled,
        obs={'celltype': y_resampled, 'str_batch': 0},
        var=undersample_data.var,
    )
    keep_data.obs['celltype'] = keep_data.obs['celltype'].apply(lambda x: type_to_index[x])
    adata_resampled = adata_resampled.concatenate(keep_data)
    adata_resampled.obs['str_batch'] = 0

    return adata_resampled

## Imputation code (augment rare cell types with scVI-normalized expression profiles)

In [None]:
# Load the training set that will be augmented with scVI-generated cells.
new_adata = scvi.data.read_h5ad('data/ms_train.h5ad')
counts_dict = dict(new_adata.obs['celltype'].value_counts())

# Cell types with at most `threshold` cells are the ones to augment.
threshold = None  # TODO: set an integer threshold before running the cells below.
if threshold is None:
    raise ValueError("set `threshold` to an integer cell count before running the imputation cells")

# Keep only the cells of rare types; scVI is trained on these in the next cell.
adata = new_adata[new_adata.obs['celltype'].map(counts_dict).astype(int) <= threshold]
adata = adata.copy()

In [None]:
# Repeatedly augment rare cell types until every type in `new_adata` has at
# least `threshold` cells. Each round:
# - retrains scVI on the original cells of the still-rare types;
# - appends their normalized-expression profiles to `new_adata`;
# - subsamples them so no type overshoots `threshold`.
while True:
    counts_dict = dict(new_adata.obs['celltype'].value_counts())
    print(counts_dict)
    # Cell types still below the threshold, i.e. the ones that need more cells.
    add_dict = {key: value for key, value in counts_dict.items() if value < threshold}

    # Skip the (expensive) training when nothing needs augmenting.
    if add_dict:
        adata = adata[adata.obs['celltype'].map(counts_dict).astype(int) <= threshold]
        adata = adata.copy()
        scvi.model.SCVI.setup_anndata(adata)
        model = scvi.model.SCVI(adata)
        model.train()
        imputed_adata = model.get_normalized_expression()
        imputed_adata = sc.AnnData(
            imputed_adata,
            obs=adata.obs,
            var=adata.var,
            uns=adata.uns,
            obsm=adata.obsm,
            varm=adata.varm,
        )

        for celltype in add_dict:
            concat_adata = imputed_adata[imputed_adata.obs['celltype'] == celltype]

            concat_adata_size = concat_adata.shape[0]
            new_adata_size = counts_dict[celltype]

            if new_adata_size + concat_adata_size > threshold:
                # copy=True returns the subsample, so the type lands exactly on `threshold`.
                concat_adata = sc.pp.subsample(
                    concat_adata, n_obs=threshold - new_adata_size, copy=True
                )
            new_adata = new_adata.concatenate(concat_adata)

    # Save after every round; this also writes the file when nothing needed augmenting.
    try:
        new_adata.write_h5ad('ms_train.h5ad')
    except Exception:
        new_adata.write('ms_train.h5ad')

    if not add_dict:
        break

## Convert the gene expression matrix into Geneformer-style ordinal gene tokens (with optional oversampling and undersampling)

In [None]:
def _rank_expressed_genes(row):
    """Return the labels of ``row``'s positive entries, ordered from highest to lowest value."""
    n_expressed = int(row.gt(0).sum())
    return row.sort_values(ascending=False).index.tolist()[:n_expressed]


def _resample_celltypes(adata, sampler):
    """Resample ``adata`` with an imblearn sampler, balancing on ``obs['celltype']``.

    The result carries only ``X`` and ``obs['celltype']``; ``var`` and other
    annotations are not kept.
    """
    X_resampled, y_resampled = sampler.fit_resample(adata.X, adata.obs['celltype'])
    return sc.AnnData(X_resampled, obs={'celltype': y_resampled.values})


def load_h5ad_dataset(path, saving_name='data', is_save=False, is_oversample=False, is_undersample=False):
    """Read an .h5ad file and turn each cell into a Geneformer-style rank-value token list.

    A cell's tokens are the column positions (0..n_genes-1) of its genes with a
    positive value, ordered from highest to lowest value.

    Args:
        path: path to an .h5ad file whose ``obs`` has a ``'celltype'`` column.
        saving_name: CSV output path when ``is_save`` is true. It is also used to
            name the pickled label dictionaries under ``dictionaries/``.
        is_save: write the dataset to ``saving_name`` and pickle both label maps.
        is_oversample: balance cell types with ``RandomOverSampler(random_state=42)``.
        is_undersample: balance cell types with ``RandomUnderSampler(random_state=42)``.
            This is applied after oversampling when both flags are set.

    Returns:
        ``(dataset, index_to_type, type_to_index)``. ``dataset`` has these columns:
        ``label`` (integer cell-type index, in every mode), ``input_ids`` (list of
        gene positions) and ``length`` (``len(input_ids)``).
    """
    dataset_base = sc.read_h5ad(path)
    celltypes = list(dataset_base.obs['celltype'].unique())
    index_to_type = dict(enumerate(celltypes))
    type_to_index = {celltype: idx for idx, celltype in enumerate(celltypes)}

    if is_oversample:
        dataset_base = _resample_celltypes(dataset_base, RandomOverSampler(random_state=42))
    if is_undersample:
        dataset_base = _resample_celltypes(dataset_base, RandomUnderSampler(random_state=42))

    expression = dataset_base.to_df()
    # Tokens are column positions rather than gene names, whatever the gene count is.
    expression.columns = range(expression.shape[1])
    input_ids = [_rank_expressed_genes(row) for _, row in expression.iterrows()]

    # Labels are mapped to integers in every mode so that the saved CSV can be
    # read back with integer labels. They are built positionally, so the
    # obs index name and uniqueness do not matter.
    labels = [type_to_index[celltype] for celltype in dataset_base.obs['celltype']]

    dataset = pd.DataFrame({
        'label': labels,
        'input_ids': input_ids,
        'length': [len(ids) for ids in input_ids],
    })

    if is_save:
        dataset.to_csv(saving_name)
        os.makedirs('dictionaries', exist_ok=True)
        with open(f'dictionaries/index_to_type_{saving_name}.pkl', 'wb') as fp:
            pickle.dump(index_to_type, fp)
        with open(f'dictionaries/type_to_index_{saving_name}.pkl', 'wb') as fp:
            pickle.dump(type_to_index, fp)

    return dataset, index_to_type, type_to_index
    

In [None]:
def load_csv_dataset(path):
    """Load a token-dataset CSV such as the one written by ``load_h5ad_dataset(..., is_save=True)``.

    Args:
        path: CSV path. An ``'Unnamed: 0'`` column (the index written by
            ``DataFrame.to_csv``) is dropped if present.

    Returns:
        DataFrame with integer ``label`` and ``length`` columns and an
        ``input_ids`` column holding Python lists of ints.
    """
    dataset = pd.read_csv(path)
    dataset = dataset.drop(columns='Unnamed: 0', errors='ignore')
    dataset['label'] = dataset['label'].astype(int)
    # literal_eval parses "[]" into an empty list. A strip/split parse would
    # yield [''] and then fail in int('').
    dataset['input_ids'] = dataset['input_ids'].apply(
        lambda text: [int(token) for token in ast.literal_eval(text)]
    )
    dataset['length'] = dataset['length'].astype(int)

    return dataset