In [None]:
!pip install scanpy
!pip install catboost
!pip install lightgbm



In [None]:
import scanpy as sc
import numpy as np
import pandas as pd

from catboost import CatBoostClassifier
from sklearn.metrics import f1_score, roc_auc_score, roc_curve, precision_score, recall_score
from sklearn.metrics import precision_recall_curve, accuracy_score
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
from sklearn.utils import shuffle
from sklearn.compose import ColumnTransformer
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from tqdm import tqdm

In [None]:
! wget https://download.brainimagelibrary.org/cf/1c/cf1c1a431ef8d021/processed_data/counts.h5ad

cell_labels = pd.read_csv('https://download.brainimagelibrary.org/cf/1c/cf1c1a431ef8d021/processed_data/cell_labels.csv')
adata = sc.read_h5ad("counts.h5ad")

--2023-10-08 18:53:18--  https://download.brainimagelibrary.org/cf/1c/cf1c1a431ef8d021/processed_data/counts.h5ad
Resolving download.brainimagelibrary.org (download.brainimagelibrary.org)... 192.231.243.61, 192.231.243.62
Connecting to download.brainimagelibrary.org (download.brainimagelibrary.org)|192.231.243.61|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 304862128 (291M) [application/octet-stream]
Saving to: ‘counts.h5ad.1’


2023-10-08 18:53:31 (25.1 MB/s) - ‘counts.h5ad.1’ saved [304862128/304862128]



In [None]:
# Rename the column 'Unnamed: 0' to 'ID' (it is used later as the merge key)
cell_labels = cell_labels.rename(columns={'Unnamed: 0': 'ID'})

## EDA

In [None]:
# Fraction of missing values in each column of cell_labels
cell_labels.isna().mean()

ID             0.0
sample_id      0.0
slice_id       0.0
class_label    0.0
subclass       0.0
label          0.0
dtype: float64

Пропуски отсутствуют

In [None]:
# Number of fully duplicated rows in cell_labels
cell_labels.duplicated().sum()

0

Дубликаты в данных не обнаружены

In [None]:
# Inspect the dtype of each column
cell_labels.dtypes

ID             object
sample_id      object
slice_id       object
class_label    object
subclass       object
label          object
dtype: object

In [None]:
# Number of rows per value of the target column 'class_label'
cell_labels['class_label'].value_counts()

Glutamatergic    155831
Other            106055
GABAergic         18300
Name: class_label, dtype: int64

In [None]:
# Share of each target class (value_counts normalised to fractions)
print(cell_labels['class_label'].value_counts(normalize=True))

Glutamatergic    0.556170
Other            0.378516
GABAergic        0.065314
Name: class_label, dtype: float64


Имеется явный дисбаланс классов: на класс GABAergic приходится лишь около 6,5 % клеток

In [None]:
# Convert adata to a pandas DataFrame
adata_df = adata.to_df()

In [None]:
# Fraction of missing values in each column of adata_df
adata_df.isna().mean()

index
1700022I11Rik    0.0
1810046K07Rik    0.0
5031425F14Rik    0.0
5730522E02Rik    0.0
Acta2            0.0
                ... 
Sst              0.0
Rab3b            0.0
Slc17a7          0.0
Penk             0.0
Gad1             0.0
Length: 254, dtype: float64

Пропуски отсутствуют

In [None]:
# Number of fully duplicated rows in adata_df
adata_df.duplicated().sum()

0

Дубликаты в данных не обнаружены

In [None]:
# Count the non-zero entries in each column of adata_df
non_zero = (adata_df != 0).sum()
non_zero

index
1700022I11Rik      9574
1810046K07Rik     32859
5031425F14Rik     12726
5730522E02Rik     21690
Acta2            119588
                  ...  
Sst              214621
Rab3b            279420
Slc17a7          280148
Penk             264217
Gad1             187781
Length: 254, dtype: int64

In [None]:
# Order the columns from fewest to most non-zero entries; a minimum above zero
# means that no column consists of zeros only.
sorted_counts = non_zero.sort_values(ascending=True)
sorted_counts

index
Dnase1l3           7232
Clrn1              8251
Mrgprx2            8875
1700022I11Rik      9574
C1qtnf7           11463
                  ...  
Tac2             267056
Vip              278972
Rab3b            279420
Gad2             279691
Slc17a7          280148
Length: 254, dtype: int64

Признаки, состоящие только из нулевых значений, отсутствуют: у каждого гена есть как минимум 7232 ненулевых значения

In [None]:
# Attach the values from adata_df to the cell annotations:
# cell_labels['ID'] is matched against 'index' on the adata_df side.
# No `how` is passed, so pandas performs an inner join and rows without a match
# on either side are dropped.
cell_labels_merged = cell_labels.merge(adata_df, left_on='ID', right_on='index')
cell_labels_merged

Unnamed: 0,ID,sample_id,slice_id,class_label,subclass,label,1700022I11Rik,1810046K07Rik,5031425F14Rik,5730522E02Rik,...,Gad2,Tac2,Lamp5,Cnr1,Pvalb,Sst,Rab3b,Slc17a7,Penk,Gad1
0,10000143038275111136124942858811168393,mouse2_sample4,mouse2_slice31,Other,Astro,Astro_1,0.0,0.000000,0.0,0.000000,...,0.261892,0.262109,0.037601,0.010633,0.192950,0.079685,0.063194,0.613834,0.038716,0.000000
1,100001798412490480358118871918100400402,mouse2_sample5,mouse2_slice160,Other,Endo,Endo,0.0,0.000000,0.0,0.000000,...,0.512115,0.225832,0.093828,0.000000,0.160687,0.171461,0.208027,0.618113,0.051407,0.000000
2,100006878605830627922364612565348097824,mouse2_sample6,mouse2_slice109,Other,SMC,SMC,0.0,0.000000,0.0,0.000000,...,0.199059,0.114653,0.000000,0.070231,0.144294,0.051351,0.148232,0.448829,0.041903,0.000000
3,100007228202835962319771548915451072492,mouse1_sample2,mouse1_slice71,Other,Endo,Endo,0.0,0.000000,0.0,0.000000,...,0.240500,0.093413,0.033100,0.151837,0.197471,0.011399,0.110675,1.117577,0.016462,0.467554
4,100009332472089331948140672873134747603,mouse2_sample5,mouse2_slice219,Glutamatergic,L2/3 IT,L23_IT_3,0.0,0.491629,0.0,0.983257,...,1.726676,0.533404,3.691514,0.000000,0.296567,0.748264,1.423427,11.386119,0.387408,0.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
280181,99987465505639073211021560543065098772,mouse1_sample5,mouse1_slice251,GABAergic,Pvalb,Pvalb_1,0.0,0.000000,0.0,0.557620,...,4.711612,1.145445,0.000000,0.362137,1.549598,0.195306,4.132728,1.586415,0.152228,8.515803
280182,99989592830367590092304100078674096866,mouse2_sample3,mouse2_slice261,Glutamatergic,L5 ET,L5_ET_5,0.0,0.000000,0.0,0.000000,...,0.354238,0.160562,0.630489,0.030634,0.073509,0.000000,0.210928,4.954309,0.140572,0.107833
280183,99991756591196613545069880666241120777,mouse1_sample3,mouse1_slice112,Glutamatergic,L4/5 IT,L45_IT_3,0.0,0.000000,0.0,0.000000,...,0.944207,0.221704,0.139209,0.854429,0.296027,0.110999,0.432686,5.654802,0.042431,0.004518
280184,99997421766159526763299676887100858104,mouse2_sample3,mouse2_slice261,Glutamatergic,L5 ET,L5_ET_3,0.0,0.000000,0.0,0.000000,...,0.581056,0.782009,1.167211,0.008588,0.000000,0.000000,1.607264,11.598111,0.279262,0.072066


In [None]:
# Drop 'ID', 'sample_id', 'slice_id', 'subclass' and 'label': they are not used as model features.
# NOTE(review): the result overwrites cell_labels_merged, so running this cell a second
# time fails because the columns no longer exist.
cell_labels_merged = cell_labels_merged.drop(columns=['ID', 'sample_id', 'slice_id', 'subclass', 'label'])


In [None]:
# Check the column dtypes after the merge and the column drop
cell_labels_merged.dtypes

class_label       object
1700022I11Rik    float32
1810046K07Rik    float32
5031425F14Rik    float32
5730522E02Rik    float32
                  ...   
Sst              float32
Rab3b            float32
Slc17a7          float32
Penk             float32
Gad1             float32
Length: 255, dtype: object

In [None]:
# Preview the modelling table (target column 'class_label' plus the feature columns)
cell_labels_merged

Unnamed: 0,class_label,1700022I11Rik,1810046K07Rik,5031425F14Rik,5730522E02Rik,Acta2,Adam2,Adamts2,Adamts4,Adra1b,...,Gad2,Tac2,Lamp5,Cnr1,Pvalb,Sst,Rab3b,Slc17a7,Penk,Gad1
0,Other,0.0,0.000000,0.0,0.000000,0.000000,0.0,0.000000,0.000000,0.659448,...,0.261892,0.262109,0.037601,0.010633,0.192950,0.079685,0.063194,0.613834,0.038716,0.000000
1,Other,0.0,0.000000,0.0,0.000000,0.000000,0.0,0.000000,0.000000,0.000000,...,0.512115,0.225832,0.093828,0.000000,0.160687,0.171461,0.208027,0.618113,0.051407,0.000000
2,Other,0.0,0.000000,0.0,0.000000,25.567039,0.0,0.000000,0.000000,0.000000,...,0.199059,0.114653,0.000000,0.070231,0.144294,0.051351,0.148232,0.448829,0.041903,0.000000
3,Other,0.0,0.000000,0.0,0.000000,0.000000,0.0,0.000000,0.000000,0.000000,...,0.240500,0.093413,0.033100,0.151837,0.197471,0.011399,0.110675,1.117577,0.016462,0.467554
4,Glutamatergic,0.0,0.491629,0.0,0.983257,0.491629,0.0,0.000000,0.000000,0.983257,...,1.726676,0.533404,3.691514,0.000000,0.296567,0.748264,1.423427,11.386119,0.387408,0.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
280181,GABAergic,0.0,0.000000,0.0,0.557620,0.000000,0.0,0.000000,0.000000,0.557620,...,4.711612,1.145445,0.000000,0.362137,1.549598,0.195306,4.132728,1.586415,0.152228,8.515803
280182,Glutamatergic,0.0,0.000000,0.0,0.000000,0.000000,0.0,0.000000,1.829112,14.632893,...,0.354238,0.160562,0.630489,0.030634,0.073509,0.000000,0.210928,4.954309,0.140572,0.107833
280183,Glutamatergic,0.0,0.000000,0.0,0.000000,0.000000,0.0,0.000000,0.000000,0.000000,...,0.944207,0.221704,0.139209,0.854429,0.296027,0.110999,0.432686,5.654802,0.042431,0.004518
280184,Glutamatergic,0.0,0.000000,0.0,0.000000,0.000000,0.0,0.000000,0.000000,2.795264,...,0.581056,0.782009,1.167211,0.008588,0.000000,0.000000,1.607264,11.598111,0.279262,0.072066


## Подготовка данных для классификации:
Выделяем признаки (гены) и целевую переменную (метки классов).
Разделим данные на обучающую (60%), валидационную (20%) и тестовую (20%) выборки:

#### Разделение данных на выборки

In [None]:
# Split the data into training, validation and test sets:
# the first call keeps 60% for training and holds out 40%,
# the second call halves that hold-out into test and validation parts.
# NOTE(review): no `stratify` is passed, so the class shares of each part are left to chance.
cell_labels_train, cell_labels_test = train_test_split(cell_labels_merged, train_size=0.6, test_size=0.4, random_state=12345)
cell_labels_test, cell_labels_valid = train_test_split(cell_labels_test, test_size=0.5, random_state=12345)

In [None]:
# Report which share of the merged table ended up in each split.
split_sizes = (
    ("Процент данных в обучающей выборке:", cell_labels_train),
    ("Процент данных в валидационной выборке", cell_labels_valid),
    ("Процент данных в тестовой выборке", cell_labels_test),
)
total_rows = len(cell_labels_merged)
for caption, subset in split_sizes:
    print(caption, f"{len(subset) / total_rows:.0%}")

Процент данных в обучающей выборке: 60%
Процент данных в валидационной выборке 20%
Процент данных в тестовой выборке 20%


In [None]:
# Helper that separates a table into model features and the target column
def splitting_data(data, target_column):
    """Split ``data`` into a feature table and a target column.

    Parameters:
        data: DataFrame that contains ``target_column``.
        target_column: label of the column to predict.

    Returns:
        A ``(features, target)`` tuple: ``data`` without ``target_column``
        and the ``target_column`` itself. ``data`` is not modified.
    """
    features = data.drop(columns=[target_column])
    target = data[target_column]
    return features, target

In [None]:
# Build the feature table and the target column for each of the three splits
features_train, target_train = splitting_data(cell_labels_train, 'class_label')
features_valid, target_valid = splitting_data(cell_labels_valid, 'class_label')
features_test, target_test = splitting_data(cell_labels_test, 'class_label')

In [None]:
# Sanity check: the three splits together should contain as many rows as the merged table
print("Сумма значений новых выборок:", len(features_train) + len(features_valid) + len(features_test))
print("Количество значений исходной выборки:", len(cell_labels_merged))

Сумма значений новых выборок: 280186
Количество значений исходной выборки: 280186


## Борьба с дисбалансом

In [None]:
# Inspect the training target before any resampling
target_train

250680    Glutamatergic
141167    Glutamatergic
125704        GABAergic
197527            Other
8111              Other
              ...      
158838    Glutamatergic
47873             Other
86398     Glutamatergic
77285             Other
217570            Other
Name: class_label, Length: 168111, dtype: object

In [None]:
# Number of training rows before any resampling
len(features_train)

168111

In [None]:
# Upsampling helper
def upsample(features, target, repeat):
    """Oversample the "GABAergic" class by repeating its rows.

    Parameters:
        features: feature table whose rows correspond to ``target``.
        target: class labels; rows labelled "Other", "Glutamatergic" and
            "GABAergic" are kept, any other label is left out of the result.
        repeat: how many copies of the "GABAergic" rows to include.

    Returns:
        ``(features_upsampled, target_upsampled)`` shuffled together with a
        fixed ``random_state`` of 12345.
    """
    is_other = target == "Other"
    is_glutamatergic = target == "Glutamatergic"
    is_gabaergic = target == "GABAergic"

    # "Other" and "Glutamatergic" appear once, "GABAergic" appears `repeat` times.
    feature_parts = [features[is_other], features[is_glutamatergic]] + [features[is_gabaergic]] * repeat
    target_parts = [target[is_other], target[is_glutamatergic]] + [target[is_gabaergic]] * repeat

    # Shuffling both objects in one call keeps every feature row paired with its label.
    shuffled_features, shuffled_target = shuffle(
        pd.concat(feature_parts),
        pd.concat(target_parts),
        random_state=12345,
    )
    return shuffled_features, shuffled_target

In [None]:
# Oversample the training set: the "GABAergic" rows are included 3 times.
# NOTE(review): the result overwrites features_train/target_train, so running this
# cell again upsamples the already upsampled data.
features_train, target_train = upsample(features_train, target_train, 3)

print(features_train.shape)
print(target_train.shape)

(190079, 254)
(190079,)


In [None]:
# Downsampling helper
def downsample(features, target, fraction_other, fraction_glutamatergic):
    """Undersample the "Other" and "Glutamatergic" classes.

    Parameters:
        features: feature table whose rows correspond to ``target``.
        target: class labels; rows labelled "Other", "Glutamatergic" and
            "GABAergic" are kept, any other label is left out of the result.
        fraction_other: fraction of the "Other" rows to keep.
        fraction_glutamatergic: fraction of the "Glutamatergic" rows to keep.

    Returns:
        ``(features_downsampled, target_downsampled)`` containing the sampled
        "Other" and "Glutamatergic" rows plus all "GABAergic" rows, shuffled
        together with a fixed ``random_state`` of 12345.
    """
    seed = 12345
    is_other = target == "Other"
    is_glutamatergic = target == "Glutamatergic"
    is_gabaergic = target == "GABAergic"

    # Features and labels are sampled in separate calls with the same seed and
    # fraction; NOTE(review): this relies on pandas picking the same row positions
    # for both equally long inputs so that rows and labels stay paired.
    kept_features = [
        features[is_other].sample(frac=fraction_other, random_state=seed),
        features[is_glutamatergic].sample(frac=fraction_glutamatergic, random_state=seed),
        features[is_gabaergic],
    ]
    kept_target = [
        target[is_other].sample(frac=fraction_other, random_state=seed),
        target[is_glutamatergic].sample(frac=fraction_glutamatergic, random_state=seed),
        target[is_gabaergic],
    ]

    shuffled_features, shuffled_target = shuffle(
        pd.concat(kept_features),
        pd.concat(kept_target),
        random_state=seed,
    )
    return shuffled_features, shuffled_target

In [None]:
# Undersample the training set: keep 50% of "Other" and 35% of "Glutamatergic" rows.
# NOTE(review): the result overwrites features_train/target_train, so running this
# cell again shrinks the already reduced data further.
features_train, target_train = downsample(features_train, target_train, 0.5, 0.35)

print(features_train.shape)
print(target_train.shape)

(97485, 254)
(97485,)


In [None]:
# Class counts in the training target after resampling
target_train.value_counts()

GABAergic        32952
Glutamatergic    32738
Other            31795
Name: class_label, dtype: int64

### Обучение моделей

In [None]:
# Hyperparameter grid for the search: tree depth 1..9 and 5..50 trees in steps of 5.
parameters_CB = dict(
    max_depth=list(range(1, 10)),
    n_estimators=list(range(5, 51, 5)),
)

# Base CatBoost estimator for the grid search: GPU task type, silent training, fixed seed.
model_CB = CatBoostClassifier(
    task_type="GPU",
    devices='0:1',
    random_state=12345,
    verbose=False,
)

In [None]:
def grid_search_cv(model, features_train, target_train, features_valid, parameters, target_valid=None):
    """Run a 5-fold grid search and print per-class F1 on the validation set.

    The search optimises the weighted F1 score (``scoring='f1_weighted'``) over
    ``parameters``; the refitted best estimator then predicts ``features_valid``
    and the per-class F1 scores (``average=None``) and the best parameters are
    printed.

    Parameters:
        model: estimator passed to ``GridSearchCV``.
        features_train: training features used for the cross-validated search.
        target_train: training labels used for the cross-validated search.
        features_valid: validation features scored with the best estimator.
        parameters: parameter grid passed to ``GridSearchCV`` as ``param_grid``.
        target_valid: validation labels matching ``features_valid``. When
            omitted, the notebook-level variable ``target_valid`` is used, which
            is what this function silently relied on before the parameter existed.

    Returns:
        None. Results are printed.

    Raises:
        NameError: if ``target_valid`` is neither passed nor defined globally.
        Any error raised while fitting a candidate (``error_score='raise'``).
    """
    if target_valid is None:
        # Backward compatibility for existing five-argument calls: fall back to the
        # notebook global instead of failing. The parameter shadows the global name,
        # hence the explicit globals() lookup.
        try:
            target_valid = globals()["target_valid"]
        except KeyError:
            raise NameError("name 'target_valid' is not defined") from None

    grid = GridSearchCV(estimator=model, param_grid=parameters, cv=5, scoring='f1_weighted',
                        n_jobs=-1, error_score='raise')
    grid.fit(features_train, target_train)
    predictions = grid.best_estimator_.predict(features_valid)

    # average=None yields one F1 value per class instead of a single aggregate.
    f1 = f1_score(target_valid, predictions, average=None)

    print("F1:", f1)

    print("Параметры лучшей модели:", grid.best_params_)
    print()

In [None]:
%%time
# Run the grid search for model_CB and report the validation F1 scores (the cell is timed)
grid_search_cv(model_CB, features_train, target_train, features_valid, parameters_CB)

F1: [0.98663029 0.99603418 0.99405375]
Параметры лучшей модели: {'max_depth': 9, 'n_estimators': 50}

CPU times: user 17.1 s, sys: 3.09 s, total: 20.2 s
Wall time: 17min 47s


## Тестирование модели

In [None]:
# Preview the test features
features_test

Unnamed: 0,1700022I11Rik,1810046K07Rik,5031425F14Rik,5730522E02Rik,Acta2,Adam2,Adamts2,Adamts4,Adra1b,Alk,...,Gad2,Tac2,Lamp5,Cnr1,Pvalb,Sst,Rab3b,Slc17a7,Penk,Gad1
83768,0.000000,0.000000,0.0,0.000000,0.000000,0.000000,0.000000,2.810577,3.747436,0.000000,...,0.479865,0.570115,0.000000,0.151052,0.275068,0.114951,0.776496,8.645861,0.397368,0.000000
46645,0.000000,0.720141,0.0,0.000000,0.720141,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.562591,0.228866,0.141687,0.000000,0.225190,0.106453,0.334048,0.856092,0.160667,0.000000
261729,0.000000,0.000000,0.0,0.000000,1.966111,0.000000,0.000000,0.000000,4.587593,1.966111,...,3.523723,1.057427,0.000000,0.622802,5.430205,0.204675,4.639832,2.393476,0.204461,16.070328
27294,0.000000,0.000000,0.0,0.000000,0.000000,0.618067,0.000000,1.236134,8.652939,0.000000,...,1.295445,0.471979,0.025454,0.000000,0.324796,0.467346,1.173903,12.411381,0.408038,0.000000
174061,0.000000,0.000000,0.0,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.387153,0.307346,0.158279,0.000000,0.135557,0.312946,0.262788,0.291806,0.069098,0.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
140885,0.439498,0.000000,0.0,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.550987,0.320670,2.167154,0.000000,0.052372,0.122104,0.387479,6.069474,0.304430,0.000000
188843,0.000000,0.000000,0.0,0.000000,8.594641,0.000000,0.000000,8.594641,0.000000,0.000000,...,0.937101,0.169633,0.069543,0.053683,0.000000,0.271513,0.550246,4.388434,0.153084,0.123911
121605,0.000000,0.000000,0.0,0.000000,0.640585,0.000000,0.640585,7.046440,0.000000,0.000000,...,0.792083,0.979309,0.000000,0.127742,0.433559,0.082353,0.282034,0.502107,0.187190,0.000000
111774,0.475141,0.000000,0.0,0.475141,2.850846,0.000000,0.475141,0.000000,4.751410,3.801129,...,1.798313,0.064934,2.860857,0.541593,0.618424,0.390358,1.989951,13.992662,0.281518,0.000000


In [None]:
# Final model: the hyperparameters are hard-coded from the best parameters
# printed by the grid search.
# NOTE(review): unlike model_CB, no task_type/devices and no verbose=False are
# passed here — confirm that this difference is intended.
model = CatBoostClassifier(max_depth=9, n_estimators=50, random_state=12345)

print("Время обучения:")
# Fit on the resampled training set; %time reports how long the call takes.
%time model.fit(features_train, target_train)

print()
print("Время предсказания:")
# Predict class labels for the held-out test features.
%time predictions = model.predict(features_test)

print()
print(predictions)

Время обучения:
Learning rate set to 0.5
0:	learn: 0.4271078	total: 3.59s	remaining: 2m 56s
1:	learn: 0.2518020	total: 6.64s	remaining: 2m 39s
2:	learn: 0.1641483	total: 8.83s	remaining: 2m 18s
3:	learn: 0.1124926	total: 10.8s	remaining: 2m 3s
4:	learn: 0.0808365	total: 12.7s	remaining: 1m 54s
5:	learn: 0.0622058	total: 14.6s	remaining: 1m 47s
6:	learn: 0.0481802	total: 16.6s	remaining: 1m 41s
7:	learn: 0.0404568	total: 19.4s	remaining: 1m 42s
8:	learn: 0.0331603	total: 21.8s	remaining: 1m 39s
9:	learn: 0.0290909	total: 23.7s	remaining: 1m 34s
10:	learn: 0.0261327	total: 25.6s	remaining: 1m 30s
11:	learn: 0.0250797	total: 27.5s	remaining: 1m 27s
12:	learn: 0.0238068	total: 29.4s	remaining: 1m 23s
13:	learn: 0.0227766	total: 32.2s	remaining: 1m 22s
14:	learn: 0.0212573	total: 34.9s	remaining: 1m 21s
15:	learn: 0.0204874	total: 36.8s	remaining: 1m 18s
16:	learn: 0.0195024	total: 38.7s	remaining: 1m 15s
17:	learn: 0.0189670	total: 40.7s	remaining: 1m 12s
18:	learn: 0.0182709	total: 42.6s	

In [None]:
# Put the true test labels and the predictions side by side; the target index is
# reset so that both columns share the same 0..n-1 row index when concatenated.
results = pd.concat([target_test.reset_index(drop=True), pd.DataFrame(predictions, columns=["predictions"])], axis=1)
results

Unnamed: 0,class_label,predictions
0,Glutamatergic,Glutamatergic
1,Other,Other
2,GABAergic,GABAergic
3,Glutamatergic,Glutamatergic
4,Other,Other
...,...,...
56032,Glutamatergic,Glutamatergic
56033,Other,Other
56034,Other,Other
56035,Glutamatergic,Glutamatergic


In [None]:
# Per-class precision / recall / F1 on the test set.
# `target_names` only renames the rows; scikit-learn assigns the names to the labels
# in sorted order unless `labels` is given. Passing the same list as `labels` fixes
# the row order and guarantees that every row carries the name of its own class
# (previously the "Other" and "GABAergic" rows were swapped).
class_names = ['Other', 'Glutamatergic', 'GABAergic']
report = classification_report(target_test, predictions, labels=class_names, target_names=class_names)
print(report)

               precision    recall  f1-score   support

        Other       0.98      1.00      0.99      3680
Glutamatergic       1.00      1.00      1.00     31241
    GABAergic       0.99      1.00      0.99     21116

     accuracy                           1.00     56037
    macro avg       0.99      1.00      0.99     56037
 weighted avg       1.00      1.00      1.00     56037



## Выводы

В ходе работы было выполнено:

* Загружены, изучены и подготовлены данные. Выполнена проверка наличия нулевых и пропущенных значений, удалены лишние колонки;
* Был выявлен дисбаланс данных (Glutamatergic 155831, Other 106055, GABAergic 18300). В качестве борьбы с дисбалансом была увеличена выборка для класса "GABAergic", а также уменьшены выборки
для классов "Glutamatergic" и "Other";
* Была обучена модель с градиентным бустингом (CatBoost). Данные были разделены на train, validation и test выборки;
* В ходе обучения были получены параметры лучшей модели: {'max_depth': 9, 'n_estimators': 50};
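* В финальной проверке на тестовой выборке модель CatBoost показала precision на уровнях: 0.98 для класса "GABAergic",
 0.99 для класса "Other",
 1.00 для класса "Glutamatergic" (классы сопоставлены по значениям support в отчёте: 3680 объектов — это GABAergic, 21116 — Other; подписи этих двух строк в отчёте были переставлены, так как target_names назначаются меткам в алфавитном порядке).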