In [12]:
import os

import anndata as ad
import joblib
import numpy as np
import torch
from sklearn.metrics import confusion_matrix, f1_score, precision_score, log_loss
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

In [None]:
# Full GSE155249 AnnData object with all genes; the next cell subsets it to
# highly variable genes.
# NOTE(review): this reads from "Data/" while the save cell writes to "DATA/".
# On case-sensitive filesystems these are different directories — confirm.
MAIN_H5AD_PATH = 'Data/GSE155249_main.h5ad'
adata_main = ad.read_h5ad(MAIN_H5AD_PATH)

### Restrict the matrix to highly variable genes to speed up computation

In [14]:
# Boolean gene mask: True only where the 'highly_variable' flag equals True
# (``.eq(True)`` is the same element-wise comparison as ``== True``).
hvg_mask = adata_main.var['highly_variable'].eq(True)
# Keep every cell, but only the flagged genes.
adata = adata_main[:, hvg_mask]
print(f"Adata shape: {adata.shape}")
print(f"Adata main shape: {adata_main.shape}")

Adata shape: (77146, 4488)
Adata main shape: (77146, 21819)


In [15]:
def prepare_data(adata, label_key="Cluster", random_state=42):
    """Densify an AnnData matrix and split it into stratified train/val/test sets.

    The split is 70% train / 20% validation / 10% test. Both stages stratify
    on the encoded labels, so each class keeps its overall proportion in
    every subset.

    Note: the whole expression matrix is materialised as a dense array in
    memory before splitting.

    Parameters
    ----------
    adata : anndata.AnnData
        Annotated matrix. ``adata.X`` may be a sparse matrix (anything with
        ``.toarray()``) or a dense array-like.
    label_key : str, default "Cluster"
        Column of ``adata.obs`` that holds the class labels.
    random_state : int, default 42
        Seed passed to both ``train_test_split`` calls so the splits are
        reproducible.

    Returns
    -------
    tuple
        ``(X_train, X_val, X_test, y_train, y_val, y_test, le)``. The ``y_*``
        arrays hold integer codes. ``le`` is the fitted ``LabelEncoder``;
        use ``le.inverse_transform`` to recover the original labels.

    Raises
    ------
    KeyError
        If ``label_key`` is not a column of ``adata.obs``.
    ValueError
        Raised by ``train_test_split`` when a class has too few members to
        be stratified.
    """
    # Read ``adata.X`` once: on AnnData views, every access to ``.X``
    # re-slices the parent matrix.
    raw_X = adata.X
    # Sparse matrices are densified via ``toarray``. Anything else goes
    # through ``np.asarray``, instead of assuming a ``toarray`` method exists.
    X = raw_X.toarray() if hasattr(raw_X, "toarray") else np.asarray(raw_X)

    labels = adata.obs[label_key].values
    le = LabelEncoder()
    y_encoded = le.fit_transform(labels)

    # First split: 70% train, 30% held out for validation + test.
    X_train, X_temp, y_train, y_temp = train_test_split(
        X, y_encoded, test_size=0.3, random_state=random_state, stratify=y_encoded
    )

    # Second split: one third of the 30% hold-out (10% of all samples) becomes
    # the test set. The remaining two thirds (20% of all samples) become
    # validation.
    X_val, X_test, y_val, y_test = train_test_split(
        X_temp, y_temp, test_size=1/3, random_state=random_state, stratify=y_temp
    )

    return X_train, X_val, X_test, y_train, y_val, y_test, le

In [16]:
# Build the stratified 70/20/10 splits once. ``le`` maps the integer labels
# back to the original ``Cluster`` values.
(X_train, X_val, X_test,
 y_train, y_val, y_test,
 le) = prepare_data(adata)

In [17]:
# One line per split: the feature and label row counts must match, and all
# splits must share the same gene dimension.
shape_report = "\n".join([
    f"X_train shape: {X_train.shape}, y_train shape: {y_train.shape}",
    f"X_val shape: {X_val.shape}, y_val shape: {y_val.shape}",
    f"X_test shape: {X_test.shape}, y_test shape: {y_test.shape}",
])
print(shape_report)

X_train shape: (54002, 4488), y_train shape: (54002,)
X_val shape: (15429, 4488), y_val shape: (15429,)
X_test shape: (7715, 4488), y_test shape: (7715,)


### Save the train/val/test splits and the fitted label encoder for later use

In [None]:
# Persist the three splits and the fitted label encoder so later notebooks can
# reload them without recomputing. Neither np.savez_compressed nor joblib.dump
# creates missing parent directories, so create the output directory first.
os.makedirs("DATA", exist_ok=True)
np.savez_compressed("DATA/train_data.npz", X=X_train, y=y_train)
np.savez_compressed("DATA/val_data.npz", X=X_val, y=y_val)
np.savez_compressed("DATA/test_data.npz", X=X_test, y=y_test)
# joblib.dump returns the list of written file names. Bind it so the cell does
# not echo a filesystem path into the saved notebook output.
saved_encoder_paths = joblib.dump(le, 'DATA/label_encoder.pkl')

['/raid/brunopsz/Moddeling/DATA/label_encoder.pkl']