# Load data for final version of thymus ageing atlas

In [None]:
import os
import sys
import session_info
from datetime import datetime
# Date stamp (YYYY-MM-DD) used in the version tags of saved objects
today = datetime.today().strftime('%Y-%m-%d')

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import scanpy as sc
import anndata as ad
# hdf5plugin provides the zstd compression filter used when writing objects below
import hdf5plugin

# Add repo path to sys path (allows to access scripts and metadata from repo)
#repo_path,_ = os.path.split(os.path.split(os.getcwd())[0])
repo_path = '/lustre/scratch126/cellgen/team205/lm25/thymus_projects/thymus_ageing_atlas/General_analysis'
sys.path.insert(1, repo_path) 
sys.path.insert(2, '/lustre/scratch126/cellgen/team205/lm25/thymus_projects/thymus_ageing_atlas/General_analysis/scripts')

# Add R libs path
#os.environ['LD_LIBRARY_PATH'] = '' # Uncomment on jhub
#os.environ['R_HOME'] = '/nfs/team205/lm25/condaEnvs/thymusAgeing/lib/R' # Uncomment on jhub
os.environ['R_LIBS_USER'] = '/nfs/team205/lm25/condaEnvs/thymusAgeing/lib/R/library'

%load_ext rpy2.ipython
%load_ext autoreload
%autoreload 2

In [None]:
%%capture
%%R

# R packages used by the %%R plotting cells
library(tidyverse)
library(patchwork)
library(magrittr)

# User-specific plotting helpers (presumably a shared ggplot theme — see the sourced file)
source('/nfs/team205/lm25/customScripts/visualisation/customTheme.R')

# Cap how many entries print() shows in R output
options(max.print=150)

In [None]:
# Define paths. Analysis data and general (metadata/objects) data currently
# live in the same repo directory, so both names point at one location.
data_path = f'{repo_path}/data'
general_data_path = data_path
plots_path = f'{repo_path}/plots/preprocessing'

# Inspect metadata

In [None]:
# Load latest metadata
# get_latest_version (repo utils) presumably returns the newest file in `dir`
# whose name starts with `file_prefix` — confirm in utils.
from utils import get_latest_version,update_obs

latest_meta_path = get_latest_version(dir = f'{general_data_path}/metadata', file_prefix='Thymus_ageing_metadata')
latest_meta = pd.read_excel(latest_meta_path)

latest_meta.head()

## Demultiplex Notarangelo2024

In [None]:
# Metadata rows belonging to the Notarangelo2024 study
notarangelo_meta = latest_meta[latest_meta['study'] == 'Notarangelo2024']

notarangelo_meta.head()

In [None]:
# Identify libraries for demuxing:
# library IDs with exactly two underscores are treated as pooled libraries
# that need HTO demultiplexing.
demux_lib = notarangelo_meta[notarangelo_meta['library'].str.count('_') == 2]['library'].unique()

demux_lib

In [None]:
from scipy.io import mmread
from scipy.sparse import csr_matrix
from typing import List


def import_hto_matrix(path):
    """
    Import an HTO (Hashtag Oligo) count matrix from a directory.

    Reads a sparse matrix from a Matrix Market file together with the matching
    features and barcodes TSV files and returns an AnnData object.

    Parameters:
    path (str): Directory containing matrix.mtx.gz, features.tsv.gz and
                barcodes.tsv.gz. A trailing path separator is optional.

    Returns:
    ad.AnnData: AnnData with barcodes as observations and features as variables.
                The loaded matrix is transposed, so the file is expected to be
                stored features x barcodes.
    """
    # os.path.join inserts a separator only when needed, so `path` works with
    # or without a trailing '/' for all three files.
    def _in_dir(name):
        return os.path.join(path, name)

    sparse_matrix = mmread(_in_dir('matrix.mtx.gz')).tocsr()
    features = pd.read_csv(_in_dir('features.tsv.gz'), sep='\t', header=None)[0]
    barcodes = pd.read_csv(_in_dir('barcodes.tsv.gz'), sep='\t', header=None)[0]

    adata = ad.AnnData(X=sparse_matrix.T)
    adata.obs_names = barcodes
    adata.var_names = features

    return adata

def assign_hto(adata: ad.AnnData, run_id : int, lib : List, hto_map: pd.DataFrame) -> ad.AnnData:
    """
    Assigns HTO (Hashtag Oligonucleotide) labels to the AnnData object based on the provided HTO mapping.

    Each barcode is labelled with the feature holding its highest count (ties
    resolve to the first such feature). Barcodes with no counts at all are left
    unassigned instead of defaulting to the first feature.

    Parameters:
    -----------
    adata : ad.AnnData
        HTO count matrix (barcodes x HTO features), e.g. from import_hto_matrix.
        Modified in place and also returned.
    run_id : int
        Identifier for the sequencing run.
    lib : List
        List of libraries to consider for HTO assignment.
    hto_map : pd.DataFrame
        DataFrame containing the HTO mapping information. It should have columns 'Sequencing_run_name',
        'Library', 'Sample1', 'Sample2', 'Sample1.HTO', and 'Sample2.HTO'.

    Returns:
    --------
    ad.AnnData
        The same object with obs columns 'barcode' (original barcode),
        'hto_assignment_orig' (top feature name, NA if the barcode has no counts),
        'index' (mapped sample name, NA if unmapped) and 'hto_assignment'
        ('<sample>-<barcode>', NA if unmapped). obs_names are replaced by
        'hto_assignment', so unassigned barcodes end up with NA names.
    """
    counts = adata.X
    # Per-barcode position of the highest-count feature; asarray().ravel()
    # flattens the (n, 1) np.matrix that sparse argmax/sum return.
    top_feature = np.asarray(counts.argmax(axis=1)).ravel()
    # argmax of an all-zero row is 0, which would attribute empty barcodes to
    # the first feature; keep row totals so those barcodes stay unassigned.
    total_counts = np.asarray(counts.sum(axis=1)).ravel()

    adata.obs['barcode'] = adata.obs_names
    adata.obs['hto_assignment_orig'] = [
        adata.var_names[i] if total > 0 else pd.NA
        for i, total in zip(top_feature, total_counts)
    ]

    # Mapping rows for this run and its libraries, selected once for both sample slots.
    run_map = hto_map.loc[(hto_map['Sequencing_run_name'] == run_id) & (hto_map['Library'].isin(lib))]
    # HTO label -> sample name; Sample2 entries are applied second and win on a shared label.
    hto_dict = run_map.set_index('Sample1.HTO')['Sample1'].to_dict()
    hto_dict.update(run_map.set_index('Sample2.HTO')['Sample2'].to_dict())

    print(hto_dict)
    print(adata.obs['hto_assignment_orig'].str.split('-').str[0].unique())
    # Only the part of the feature name before the first '-' is looked up.
    adata.obs['index'] = [
        pd.NA if pd.isna(feature) else hto_dict.get(feature.split('-')[0], pd.NA)
        for feature in adata.obs['hto_assignment_orig']
    ]
    adata.obs['hto_assignment'] = [f'{i}-{b}' if not pd.isna(i) else pd.NA for i,b in zip(adata.obs['index'], adata.obs_names)]

    adata.obs_names = adata.obs['hto_assignment']

    return adata

def hto_demux(path: str, hto_map: pd.DataFrame, run_id: int, lib: List) -> ad.AnnData:
    """
    Load one HTO count matrix and label its barcodes with sample names.

    Thin wrapper that reads the matrix with import_hto_matrix and then maps each
    barcode's dominant hashtag to a sample with assign_hto.

    Parameters
    ----------
    path : str
        Directory holding matrix.mtx.gz, features.tsv.gz and barcodes.tsv.gz.
    hto_map : pd.DataFrame
        HTO-to-sample table with columns 'Sequencing_run_name', 'Library',
        'Sample1', 'Sample2', 'Sample1.HTO' and 'Sample2.HTO'.
    run_id : int
        Sequencing run the matrix belongs to.
    lib : List
        Libraries of that run whose HTO mappings should be used.

    Returns
    -------
    ad.AnnData
        HTO AnnData whose obs carries the per-barcode sample assignment.
    """
    return assign_hto(import_hto_matrix(path), run_id, lib, hto_map)

In [None]:
# Load HTO mapping info
hto_map = pd.read_excel('/lustre/scratch126/cellgen/team205/lm25/raw_data/Notarangelo2024/Notarangelo2024_meta.xlsx', sheet_name='HTO.sample.assignments')
# Rewrite HTO labels as 'HTO_<last character>' so they can be matched against the
# tag part (text before the first '-') of the HTO matrix feature names.
# NOTE(review): x[-1] keeps only the final character, so a two-digit tag (e.g. ...10)
# would collapse onto a single digit — confirm all tags are numbered 1-9.
hto_map['Sample1.HTO'] = hto_map['Sample1.HTO'].apply(lambda x: f'HTO_{x[-1]}')
hto_map['Sample2.HTO'] = hto_map['Sample2.HTO'].apply(lambda x: f'HTO_{x[-1]}')

hto_map

In [None]:
# One row per (sequencing run, HTO run name), holding the list of libraries it covers.
# Aggregating the selected 'Library' column avoids DataFrameGroupBy.apply operating
# on the grouping columns, which pandas has deprecated.
all_runs = (
    hto_map[['Sequencing_run_name', 'Library', 'HTO.name']]
    .drop_duplicates()
    .groupby(['Sequencing_run_name', 'HTO.name'])['Library']
    .agg(list)
    .reset_index()
)

hto_counts_dir = '/lustre/scratch126/cellgen/team205/lm25/raw_data/Notarangelo2024/HTO_CITEseq_count_outputs'

barcode_assignments = []
for run_id, hto_name, lib in zip(all_runs['Sequencing_run_name'], all_runs['HTO.name'], all_runs['Library']):
    hto_path = f'{hto_counts_dir}/{run_id}/HTO_counts_{run_id}_{hto_name}/read_count/'
    # Use a dedicated name so the per-run HTO object does not overwrite `adata`.
    hto_adata = hto_demux(hto_path, hto_map, run_id, lib)

    barcode_assignments.append(hto_adata.obs)

In [None]:
# Stack the per-run barcode -> sample tables into one DataFrame
barcode_assignments = pd.concat(barcode_assignments)

barcode_assignments.head()

In [None]:
barcode_assignments['index'].value_counts()  # barcodes per assigned sample (NA, i.e. unassigned, is excluded)

In [None]:
barcode_assignments.to_csv(f'/lustre/scratch126/cellgen/team205/lm25/raw_data/Notarangelo2024/HTO_CITEseq_count_outputs/Notarangelo2024_HTO_barcode_assignments.csv')  # re-read when building the demultiplexed object

# Assemble cell h5ad object

## Select libraries

In [None]:
# Healthy 'cells' libraries from infant, paediatric and adult donors; Notarangelo2024 is selected separately below
cells_meta = latest_meta.loc[(latest_meta['health_status'] == 'healthy') & (latest_meta['type'] == 'cells') & (latest_meta['age_group'].isin(['infant', 'paed', 'adult'])) & (latest_meta['study'] != 'Notarangelo2024')]

# Select specific samples for Notarangelo2024 study:
# all healthy adult/paediatric samples, plus healthy infants aged 7d, 4m or 11m only
notarangelo_meta = latest_meta.loc[latest_meta['study'] == 'Notarangelo2024']
notarangelo_samples = notarangelo_meta.loc[(notarangelo_meta['age_group'].isin(['adult', 'paed'])) & (notarangelo_meta['health_status'] == 'healthy')]['index'].tolist()
notarangelo_samples.extend(notarangelo_meta.loc[(notarangelo_meta['age_group'].isin(['infant'])) & (notarangelo_meta['health_status'] == 'healthy') & (notarangelo_meta['age'].isin(['7d', '4m', '11m']))]['index'].tolist())  

# NOTE(review): the Notarangelo rows are not filtered on type == 'cells' like the other studies — confirm intended
cells_meta = pd.concat([cells_meta, latest_meta.loc[latest_meta['index'].isin(notarangelo_samples)]])

cells_meta

In [None]:
# Check if any library is missing cellbender paths (True = at least one path is missing)
any(pd.isna(cells_meta['path_cellbender_gex']))

## Check which sort each TabSap (TabulaSapiens2022) dataset belongs to

In [None]:
# Take an explicit copy: adding a column to a .loc slice of cells_meta would raise
# SettingWithCopyWarning and may not behave as a standalone table.
tabsap_meta = cells_meta.loc[cells_meta['study'] == 'TabulaSapiens2022'].copy()
tabsap_meta['sample'] = tabsap_meta['index']

In [None]:
tabsap_meta[['sample', 'sort']] # Sorts assigned by Veronika (metadata v9)

In [None]:
# cellbender_to_anndata is otherwise only imported further down the notebook;
# import it here so the cells run top-to-bottom without a NameError.
from scripts.utils import cellbender_to_anndata

tabsap_adata = cellbender_to_anndata(tabsap_meta)

In [None]:
# Per sample: CD3E counts summed over cells and divided by the number of cells,
# printed next to the sort label recorded in the metadata.
for sample_id in tabsap_adata.obs['sample'].unique():
    in_sample = tabsap_adata.obs['sample'] == sample_id
    sample_view = tabsap_adata[in_sample]
    n_cells_in_sample = sample_view.shape[0]
    mean_cd3e = tabsap_adata[in_sample, 'CD3E'].X.sum(0) / n_cells_in_sample
    recorded_sort = sample_view.obs['sort'].unique()[0]
    print(f'{sample_id}:{mean_cd3e} -> {recorded_sort}')

NOTE: I reassigned sorts based on whether the mean CD3E expression per cell is > 0 (CD3P). Surprisingly, this means that both 5' samples of TSP2 are CD3P...

## Create object for demultiplexing

In [None]:
from scripts.utils import cellbender_to_anndata, add_cell_metadata

In [None]:
# Selected Notarangelo2024 rows whose library ID has exactly two underscores (the libraries needing HTO demultiplexing)
demux_meta = cells_meta.loc[(cells_meta['study'] == 'Notarangelo2024') & (cells_meta['library'].str.count('_') == 2)]

demux_meta

In [None]:
%%capture output
# Load CellBender output keyed by library (not sample); add_meta=False because sample metadata is merged in after demultiplexing
demux_adata = cellbender_to_anndata(demux_meta, col_library='library', col_prefix='library', add_meta=False)

demux_adata

In [None]:
# Remove sample from obs (needs to be added through demuxing): the HTO assignment supplies the real sample
demux_adata.obs.drop(columns='sample', inplace=True)

demux_adata.obs.head()

In [None]:
# Reload the saved HTO assignments and attach each sample's library, so they can be matched to cells on (barcode, library).
# The merge is inner, so unassigned barcodes and samples not in demux_meta are dropped here.
barcode_assignments = pd.read_csv('/lustre/scratch126/cellgen/team205/lm25/raw_data/Notarangelo2024/HTO_CITEseq_count_outputs/Notarangelo2024_HTO_barcode_assignments.csv')
barcode_assignments = barcode_assignments.merge(demux_meta[['index', 'library']].drop_duplicates(), on = 'index')

barcode_assignments.head()

In [None]:
# Attach the HTO-derived sample ('index') to every cell; the left join keeps unassigned cells with a missing 'index'.
# NOTE(review): duplicate (barcode, library) rows in barcode_assignments would add rows and break the obs/X length match.
demux_adata.obs = demux_adata.obs.merge(barcode_assignments, on = ['barcode', 'library'], how = 'left')
# Rename cells to '<sample>-<barcode>' (missing for unassigned cells)
demux_adata.obs_names = demux_adata.obs['index'] + '-' + demux_adata.obs['barcode']

demux_adata.obs

In [None]:
# Filter adata to only contain barcodes from samples of interest and add metadata.
# .copy() materialises the boolean subset so the in-place obs edits below act on
# a real AnnData instead of implicitly converting a view (ImplicitModificationWarning).
demux_adata = demux_adata[~demux_adata.obs['index'].isna()].copy()
# Drop HTO bookkeeping columns. 'hto_assignment.1' presumably comes from pandas
# de-duplicating a repeated 'hto_assignment' header when the CSV was re-read — confirm.
demux_adata.obs.drop(columns=['hto_assignment', 'hto_assignment_orig', 'hto_assignment.1'], inplace=True)

# Add sample-level metadata, keeping the current cell names as the index
demux_adata.obs = pd.merge(left = demux_adata.obs.reset_index(names = 'names'), right = demux_meta, how = "left", on=['index', 'library']).set_index('names')

demux_adata.obs

## Load non-demux data

In [None]:
# All selected rows that are not part of the demultiplexed set
non_demux_meta = cells_meta.loc[~cells_meta['index'].isin(demux_meta['index'])]

non_demux_meta

In [None]:
# Check if any library is missing cellbender paths (count of missing paths)
# NOTE(review): re-checks the full cells_meta, same as the earlier check; non_demux_meta may have been intended.
cells_meta.path_cellbender_gex.isna().sum()

In [None]:
# Check how many samples per study (rows of non_demux_meta per study)
non_demux_meta.study.value_counts()

In [None]:
%%capture output
# Load the CellBender output of every non-demultiplexed sample into one AnnData
adata = cellbender_to_anndata(non_demux_meta)

adata

In [None]:
adata['A16_TH_TOT_1-AAACCTGTCGAGAGCA', 'TCF7'].X.todense()  # spot-check one cell/gene; repeated after later steps, presumably to confirm counts are unchanged

In [None]:
adata.obs.study.value_counts()  # cells per study

In [None]:
var = adata.var.copy()  # gene annotation before concatenation; NOTE(review): not used again in this section

In [None]:
adata = ad.concat([adata, demux_adata], merge = 'same', index_unique = None)  # index_unique=None keeps cell names as-is (duplicates possible); merge='same' keeps only identical var annotations

In [None]:
adata['A16_TH_TOT_1-AAACCTGTCGAGAGCA', 'TCF7'].X.todense()  # same spot-check after concatenation

In [None]:
# Donors and cells per chemistry/type/study; n_cells counts cells with a non-missing 'age'
adata.obs.groupby(['chemistry_simple', 'type', 'study']).agg(n_donors = ('donor', 'nunique'),
                                                             n_cells = ('age', 'count'))

In [None]:
# Donors and cells per study; n_cells counts cells with a non-missing 'age'
adata.obs.groupby(['study']).agg(n_donors = ('donor', 'nunique'),
                                                             n_cells = ('age', 'count'))

In [None]:
# Donors and cells per study and age group (observed=True: only combinations present in the data)
adata.obs.groupby(['study', 'age_group'], observed = True).agg(n_donors = ('donor', 'nunique'),
                                                             n_cells = ('age', 'count'))

In [None]:
# Donors and cells per study and sort (observed=True: only combinations present in the data)
adata.obs.groupby(['study', 'sort'], observed = True).agg(n_donors = ('donor', 'nunique'),
                                                             n_cells = ('age', 'count'))

In [None]:
import hdf5plugin

# Version tag for this save: 'v3_' + today's date
object_version = f'v3_{today}'

# Convert columns of type object: a column whose FIRST value is a Python bool is
# cast to bool, every other object column to str (presumably so all obs columns
# serialise cleanly; only row 0 is inspected).
for col in adata.obs.columns:
    if adata.obs[col].dtypes == 'object':
        if isinstance(adata.obs[col].iloc[0], (bool)):
            adata.obs[col] = adata.obs[col].astype(bool)
        else:
            adata.obs[col] = adata.obs[col].astype(str)

# Save the combined, unfiltered object with zstd compression (filter from hdf5plugin).
# NOTE(review): write_h5ad writes an HDF5 file even though the name ends in '.zarr';
# read it back with read_h5ad, not read_zarr.
adata.write_h5ad(
            f'{general_data_path}/objects/rna/thyAgeing_all_unfiltered_{object_version}.zarr',
            compression=hdf5plugin.FILTERS["zstd"],
            compression_opts=hdf5plugin.Zstd(clevel=5).filter_options
            )

## QC & filtering

### Mito, gene and read counts

In [None]:
# Remove empty cells and genes with no counts in any cell.
# Flatten the sums to 1-D arrays once: a sparse X returns (n, 1) / (1, m) np.matrix
# sums, and the builtin sum() over the (1, m) gene matrix would yield the first
# gene's flag rather than the number of non-expressed genes.
cell_totals = np.asarray(adata.X.sum(1)).ravel()
gene_totals = np.asarray(adata.X.sum(0)).ravel()

print('Removing {} empty cells'.format(int((cell_totals == 0).sum())))
print('Removing {} non-expressed genes'.format(int((gene_totals == 0).sum())))

adata = adata[cell_totals > 0, gene_totals > 0]

In [None]:
# Add metadata: repo helper, called for its effect on adata (return value unused).
# Presumably adds the n_genes/n_counts/percent_* QC columns used below — confirm in scripts.utils.
add_cell_metadata(adata, velocyto = False, cellbender = False)

In [None]:
import warnings
# Silence FutureWarnings for the rest of the session (global filter)
warnings.filterwarnings('ignore', category=FutureWarning)

# Violin plots of the per-cell QC metrics
sc.pl.violin(adata, ['n_genes', 'n_counts', 'percent_mito','percent_ribo', 'percent_hb'],
             jitter=0.1, multi_panel=True, size = 0.1)

In [None]:
plt.hist(adata.obs['n_counts'], range = (0, 1000), bins = 50)  # low end of the counts-per-cell distribution (0-1000)

In [None]:
plt.hist(adata.obs['n_genes'], range = (0, 1000), bins = 50)  # low end of the genes-per-cell distribution (0-1000)

In [None]:
adata.obs[['n_genes','n_counts','percent_mito', 'percent_ribo', 'percent_hb']].describe()  # QC metric summary over all cells

In [None]:
adata['A16_TH_TOT_1-AAACCTGTCGAGAGCA', 'TCF7'].X.todense()  # same spot-check after empty-cell/gene removal

In [None]:
adata.obs.groupby('study')[['n_genes', 'n_counts', 'percent_mito', 'percent_ribo', 'percent_hb']].describe().to_csv(f'{data_path}/analysis/preprocessing/qc_summary_stats_by_study.csv')  # per-study QC summary saved to CSV

In [None]:
# Filter raw cells according to identified QC thresholds:
print('Total number of cells: {:d}'.format(adata.n_obs))

# At least 300 total counts per cell
sc.pp.filter_cells(adata, min_counts = 300)
print('Number of cells after min count filter: {:d}'.format(adata.n_obs))

# At least 400 detected genes per cell
sc.pp.filter_cells(adata, min_genes = 400)
print('Number of cells after min genes filter: {:d}'.format(adata.n_obs))

# At most 10000 detected genes per cell
sc.pp.filter_cells(adata, max_genes = 10000)
print('Number of cells after max genes filter: {:d}'.format(adata.n_obs))

# Mitochondrial share below 0.15 (compared as a fraction — confirm percent_mito's scale)
adata = adata[adata.obs['percent_mito'] < 0.15]
print('Number of cells after mito filter: {:d}'.format(adata.n_obs))

# Ribosomal share above 0.05 (same scale caveat as above)
adata = adata[adata.obs['percent_ribo'] > 0.05]
print('Number of cells after ribo filter: {:d}'.format(adata.n_obs))

In [None]:
# Convert columns of type object: an object column whose first entry is a
# Python bool becomes bool; any other object column becomes str.
for obs_col in adata.obs.columns:
    column = adata.obs[obs_col]
    if column.dtypes != 'object':
        continue
    target_type = bool if isinstance(column.iloc[0], bool) else str
    adata.obs[obs_col] = column.astype(target_type)

In [None]:
import hdf5plugin

# Version tag for this save: 'v3_' + today's date
object_version = f'v3_{today}'

# Save the QC-filtered object (before doublet removal), zstd-compressed.
# NOTE(review): HDF5 content despite the '.zarr' suffix; read back with read_h5ad.
adata.write_h5ad(
            f'{general_data_path}/objects/rna/thyAgeing_all_filtered_{object_version}.zarr',
            compression=hdf5plugin.FILTERS["zstd"],
            compression_opts=hdf5plugin.Zstd(clevel=5).filter_options
            )

### Doublet removal

In [None]:
import scrublet as scr
import anndata as ad
import pandas as pd
import scanpy as sc
from concurrent.futures import ThreadPoolExecutor

def doublet_detection(sample_object, sample_col='sample', n_cpu=4, min_cells=100):
    """
    Run Scrublet separately on each sample and collect per-cell doublet calls.

    Parameters
    ----------
    sample_object : AnnData
        Cells x genes object; its X is passed unchanged to scr.Scrublet.
    sample_col : str
        obs column identifying the sample of each cell.
    n_cpu : int
        Number of worker threads; samples are processed concurrently.
    min_cells : int
        Samples with fewer than this many cells are skipped and get no scores.
        The bound is inclusive: a sample with exactly `min_cells` cells is scored.
        This matches a "remove samples with < 100 cells" pre-filter, which keeps
        samples of exactly 100 cells.

    Returns
    -------
    pd.DataFrame
        Indexed by cell name, with columns 'scrublet_score' and
        'predicted_doublet'. Cells from skipped samples are absent. If no sample
        is scored, an empty frame with those two columns is returned.

    Notes
    -----
    Sets scanpy's global verbosity to 1 (warnings) as a side effect.
    """
    sc.settings.verbosity = 1  # verbosity: errors (0), warnings (1), info (2), hints (3)

    def _score_sample(sample):
        # Scrublet scores for one sample, or None when the sample is too small.
        print('Doublet detection for sample {}'.format(sample))
        in_sample = (sample_object.obs[sample_col] == sample).to_numpy()
        # Check the size before copying so tiny samples are not materialised.
        if in_sample.sum() < min_cells:
            return None
        adata_sample = sample_object[in_sample, :].copy()
        scrub = scr.Scrublet(adata_sample.X)
        doublet_scores, predicted_doublets = scrub.scrub_doublets(verbose=False)
        adata_sample.obs['scrublet_score'] = doublet_scores
        adata_sample.obs['predicted_doublet'] = predicted_doublets
        return adata_sample.obs[['scrublet_score', 'predicted_doublet']]

    with ThreadPoolExecutor(max_workers=n_cpu) as executor:
        results = list(executor.map(_score_sample, sample_object.obs[sample_col].unique()))

    # Concatenate only scored samples: pd.concat([]) raises, and mixing in empty
    # frames triggers pandas' empty-entry concat deprecation warning.
    scored = [frame for frame in results if frame is not None]
    if not scored:
        return pd.DataFrame(columns=['scrublet_score', 'predicted_doublet'])
    return pd.concat(scored)

In [None]:
# Remove any sample with less than 100 cells
cells_per_sample = adata.obs['sample'].value_counts()
samples_to_remove = cells_per_sample[cells_per_sample < 100].index

adata = adata[~adata.obs['sample'].isin(samples_to_remove)]

In [None]:
doublet_scores = doublet_detection(adata, sample_col = 'sample', n_cpu = 4)  # per-sample Scrublet, 4 worker threads

In [None]:
doublet_scores  # per-cell scrublet_score / predicted_doublet

In [None]:
doublet_scores.to_csv(f'{data_path}/analysis/preprocessing/thyAgeing_all_filtered_{object_version}_doubletScores.csv')  # object_version is whatever the most recent save cell set

In [None]:
# Identify duplicated barcodes (mostly from Notarangelo2024 multiplexed samples):
# True if doublet_scores has the same number of repeated cell names as adata
doublet_scores.index.duplicated().sum() == adata.obs_names.duplicated().sum()

In [None]:
# Names that occur more than once; isin() then shows every copy of them
dup_barcodes = adata.obs[adata.obs_names.duplicated()].index

adata[adata.obs_names.isin(dup_barcodes)].obs

In [None]:
# Remove duplicated barcodes from adata and scrublet_scores.
# NOTE(review): duplicated() keeps the FIRST occurrence, so one copy of each repeated
# name stays; the score kept is the first copy in doublet_scores — confirm both copies
# refer to the same sample, otherwise scores and cells can be mismatched.
adata = adata[~adata.obs_names.duplicated()]
scrublet_scores = doublet_scores[~doublet_scores.index.duplicated()]

# Expect True: one score row per remaining cell
adata.shape[0] == scrublet_scores.shape[0]

In [None]:
# Add scrublet scores to adata (left join on cell name; cells without a score get missing values)
adata.obs = adata.obs.join(scrublet_scores, how = 'left')

In [None]:
adata['A16_TH_TOT_1-AAACCTGTCGAGAGCA', 'TCF7'].X.todense()  # same spot-check after de-duplication

In [None]:
adata.obs.head()  # inspect obs with the joined scrublet columns

In [None]:
adata.obs['predicted_doublet'].value_counts(normalize=True)  # fraction of cells called doublets

In [None]:
(adata.obs['scrublet_score'] < .3).value_counts(normalize=True)  # fraction of cells with scrublet_score below 0.3

In [None]:
# Inspect doublet scores by sample (per-sample fraction of predicted doublets)
adata.obs.groupby('sample')['predicted_doublet'].value_counts(normalize=True)

In [None]:
# Remove cells predicted to be doublets.
# `!= True` keeps cells whose prediction is missing (e.g. no scrublet score).
adata = adata[adata.obs['predicted_doublet'] != True,:]
print('Number of cells after doublet filter: {:d}'.format(adata.n_obs))

In [None]:
adata['A16_TH_TOT_1-AAACCTGTCGAGAGCA', 'TCF7'].X.todense()  # same spot-check after doublet removal

In [None]:
import hdf5plugin

# NOTE(review): hard-coded date, unlike the earlier f'v3_{today}' tags. The output
# name below has the same pattern as the QC-filtered save, so on 2024-11-04 this
# overwrites that pre-doublet-removal file. On any other day it no longer matches
# the doublet-score CSV written with the current object_version. Confirm intent.
object_version = 'v3_2024-11-04'

# Convert columns of type object: first value a Python bool -> bool, otherwise -> str
for col in adata.obs.columns:
    if adata.obs[col].dtypes == 'object':
        if isinstance(adata.obs[col].iloc[0], (bool)):
            adata.obs[col] = adata.obs[col].astype(bool)
        else:
            adata.obs[col] = adata.obs[col].astype(str)

# Save the QC- and doublet-filtered object (HDF5 content despite the '.zarr' suffix)
adata.write_h5ad(
            f'{general_data_path}/objects/rna/thyAgeing_all_filtered_{object_version}.zarr',
            compression=hdf5plugin.FILTERS["zstd"],
            compression_opts=hdf5plugin.Zstd(clevel=5).filter_options
            )