In [4]:
import lightning.pytorch as pl
import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder, StandardScaler
from torch.utils.data import DataLoader, Dataset

In [24]:
# Make the CPU the default device for newly created tensors whenever the Apple
# MPS backend is available on this machine.
# NOTE(review): presumably meant to keep tensors off MPS (PyTorch's default is
# normally already CPU, so this likely undoes an earlier set_default_device call
# made in the same kernel) — confirm intent.
if torch.backends.mps.is_available():
    # Set the default device for new tensors to CPU
    torch.set_default_device("cpu")

In [None]:
def preprocess_data(
    data: str | pd.DataFrame,
    drop_columns: list = None,
    ordinal_encode_columns: list = None,
    one_hot_encode_columns: list = None,
    drop_na: bool = True,
):
    if isinstance(data, str):
        df = pd.read_csv(data)
    elif isinstance(data, pd.DataFrame):
        df = data
    else:
        raise ValueError("Input data must be a file path (str) or a pandas DataFrame.")
    ordinal_mappings = {}

    if drop_columns:
        df = df.drop(columns=drop_columns)
    if ordinal_encode_columns:
        ordinal_encoder = OrdinalEncoder()
        df[ordinal_encode_columns] = ordinal_encoder.fit_transform(df[ordinal_encode_columns])
        for i, col in enumerate(ordinal_encode_columns):
            ordinal_mappings[col] = {category: index for index, category in enumerate(ordinal_encoder.categories_[i])}
    if one_hot_encode_columns:
        one_hot_encoder = OneHotEncoder(sparse=False)
        one_hot_encoded = one_hot_encoder.fit_transform(df[one_hot_encode_columns])
        one_hot_encoded_df = pd.DataFrame(
            one_hot_encoded, columns=one_hot_encoder.get_feature_names_out(one_hot_encode_columns)
        )
        df = pd.concat([df.drop(columns=one_hot_encode_columns), one_hot_encoded_df], axis=1)

    if drop_na:
        df = df.dropna()

    return df, ordinal_mappings

In [None]:
class SCClass(Dataset):
    """Map-style dataset over a scaled, float32 view of a numeric DataFrame.

    Every item is one row of the scaled data, returned as a ``torch.float32``
    tensor.

    Args:
        df: DataFrame whose columns can all be cast to ``np.float32``.
        scaler: An already-fitted scaler. When given, it is only applied via
            ``transform`` (the validation/test case, so those splits reuse the
            training statistics). When ``None``, a fresh scaler is built from
            ``scale_method_cls`` and fitted on ``df``.
        scale_method_cls: Scaler class instantiated with no arguments when
            ``scaler`` is ``None``.
    """

    def __init__(self, df, scaler=None, scale_method_cls=StandardScaler):
        super().__init__()
        self.df = df
        features = df.astype(np.float32)
        fit_new_scaler = scaler is None
        self.scaler = scale_method_cls() if fit_new_scaler else scaler
        if fit_new_scaler:
            # Training split: learn the scaling statistics from this data.
            transformed = self.scaler.fit_transform(features)
        else:
            # Validation/test split: reuse statistics learned elsewhere.
            transformed = self.scaler.transform(features)
        self.data_tensor = torch.tensor(transformed, dtype=torch.float32)

    def __len__(self):
        """Number of rows held by the dataset."""
        return len(self.data_tensor)

    def __getitem__(self, idx):
        """Return the scaled row(s) selected by ``idx``."""
        row = self.data_tensor[idx]
        return row

    def get_scaler(self):
        """Return the scaler used by this dataset (fitted here or passed in)."""
        return self.scaler


class SCDataLoader(pl.LightningDataModule):
    """Lightning data module that preprocesses a table, splits it and scales it.

    ``setup`` loads ``data`` (CSV path or DataFrame), runs ``preprocess_data``,
    splits the rows into train/val/test with ``train_test_split`` (seeded by
    ``random_seed``), fits a scaler on the training split only and applies that
    same fitted scaler to the validation and test splits.

    Args:
        data: Path to a CSV file, or a pandas DataFrame.
        batch_size: Batch size used by all three dataloaders.
        num_workers: DataLoader worker processes; workers are kept alive between
            epochs (``persistent_workers``) whenever this is > 0.
        val_split: Fraction of all rows used for validation (must be > 0).
        test_split: Fraction of all rows used for testing (must be > 0).
        drop_columns, ordinal_encode_columns, one_hot_encode_columns, drop_na:
            Forwarded unchanged to ``preprocess_data``.
        scale_method_cls: Scaler class fitted on the training split.
        random_seed: ``random_state`` for both ``train_test_split`` calls.
    """

    def __init__(
        self,
        data: str | pd.DataFrame,
        batch_size: int = 32,
        num_workers: int = 0,
        val_split: float = 0.15,
        test_split: float = 0.15,
        drop_columns: list = None,
        ordinal_encode_columns: list = None,
        one_hot_encode_columns: list = None,
        drop_na: bool = True,
        scale_method_cls=StandardScaler,
        random_seed: int = 42,
    ):
        super().__init__()
        # Every constructor argument becomes available as self.hparams.<name>.
        self.save_hyperparameters()
        self.data: pd.DataFrame = None
        self.processed_data: pd.DataFrame = None
        self.ordinal_mappings: dict = None
        self.train_dataset: SCClass = None
        self.val_dataset: SCClass = None
        self.test_dataset: SCClass = None
        self.scaler: StandardScaler = None

    @staticmethod
    def _validate_splits(val_split: float, test_split: float) -> None:
        """Raise ValueError unless both fractions are positive and leave rows for training."""
        if test_split <= 0:
            raise ValueError("test_split must be greater than 0.")
        if val_split <= 0:
            raise ValueError("val_split must be greater than 0.")
        if val_split + test_split >= 1:
            raise ValueError("val_split + test_split must be less than 1 so that training data remains.")

    def setup(self, stage=None):
        """Build the train/val/test datasets; a no-op once all three exist.

        Args:
            stage: Lightning stage name; unused because all splits are always
                built together.

        Raises:
            ValueError: If the split fractions are invalid, or if ``data`` is
                neither a ``str`` nor a ``pandas.DataFrame``.
        """
        if all(ds is not None for ds in (self.train_dataset, self.val_dataset, self.test_dataset)):
            # setup() can run more than once (e.g. per stage); rebuilding would
            # re-read the data and refit the scaler for no benefit.
            return

        # Validate cheaply before any (possibly expensive) CSV read.
        self._validate_splits(self.hparams.val_split, self.hparams.test_split)

        data = self.hparams.data
        if isinstance(data, str):
            self.data = pd.read_csv(data)
        elif isinstance(data, pd.DataFrame):
            self.data = data
        else:
            raise ValueError("Input data must be a file path (str) or a pandas DataFrame.")

        self.processed_data, self.ordinal_mappings = preprocess_data(
            self.data,
            drop_columns=self.hparams.drop_columns,
            ordinal_encode_columns=self.hparams.ordinal_encode_columns,
            one_hot_encode_columns=self.hparams.one_hot_encode_columns,
            drop_na=self.hparams.drop_na,
        )

        # Carve off the test split first, then take the validation split from
        # the remainder; rescale val_split so it stays a fraction of ALL rows.
        train_df, test_df = train_test_split(
            self.processed_data, test_size=self.hparams.test_split, random_state=self.hparams.random_seed
        )
        train_df, val_df = train_test_split(
            train_df,
            test_size=self.hparams.val_split / (1 - self.hparams.test_split),
            random_state=self.hparams.random_seed,
        )
        # The scaler is fitted on training rows only to avoid leaking val/test statistics.
        self.train_dataset = SCClass(train_df, scale_method_cls=self.hparams.scale_method_cls)
        self.scaler = self.train_dataset.get_scaler()
        self.val_dataset = SCClass(val_df, scaler=self.scaler)
        self.test_dataset = SCClass(test_df, scaler=self.scaler)

    def _make_dataloader(self, dataset, shuffle: bool) -> DataLoader:
        """Build a DataLoader for ``dataset`` with the module's shared settings."""
        num_workers = self.hparams.num_workers
        return DataLoader(
            dataset,
            batch_size=self.hparams.batch_size,
            num_workers=num_workers,
            # persistent_workers may only be enabled when worker processes exist.
            persistent_workers=num_workers > 0,
            # Pinned host memory only helps host->CUDA transfers; on CPU-only or
            # MPS machines it has no benefit (and recent PyTorch versions warn).
            pin_memory=torch.cuda.is_available(),
            shuffle=shuffle,
        )

    def train_dataloader(self):
        """Shuffled DataLoader over the training split."""
        return self._make_dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Unshuffled DataLoader over the validation split."""
        return self._make_dataloader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        """Unshuffled DataLoader over the test split."""
        return self._make_dataloader(self.test_dataset, shuffle=False)

In [None]:
# adata = ad.read_h5ad('/Users/lukashat/Documents/PhD_Schapiro/Projects/Myeloma_Standal/results/standard/adatas/cells_final.h5ad')
# adata.layers['raw'] = adata.raw.X
# adata.X = adata.layers['arcsinh']
# adata.X = adata.X.astype('float32')
# adata.obs = adata.obs.drop(columns=['Phenotype4','disease2', 'disease3', 'image_ID', 'disease', 'cellcharter_CN', 'HistoneH3', 'ROI'])
# dataloader = AnnLoader(adata,
#    batch_size=32,
#    shuffle=True,
#    )