# In [None]:
import os
import re
import tarfile
import urllib.request
from pathlib import Path
import numpy as np
import pandas as pd
import h5py
import anndata
from scipy.sparse import csr_matrix

# Dataset-level configuration: GEO series accession and its bulk-download URL.
GEO_ACCESSION = "GSE206107"
DOWNLOAD_URL = "https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSE206107&format=file"

# Guide name -> perturbation label; both "sgNC*" guides share one label.
GENE_TARGET_MAP = dict(
    sgNC1='Non-Targeting',
    sgNC2='Non-Targeting',
    sgPrmt1='Prmt1',
    sgRipk1='Ripk1',
    sgAxl='Axl',
)

# Known guide names, in the insertion order of GENE_TARGET_MAP.
GUIDE_NAMES = list(GENE_TARGET_MAP)

def _safe_extract(tar, dest):
    """Extract every member of *tar* into *dest* without escaping *dest*."""
    if hasattr(tarfile, "data_filter"):
        # Extraction filters are available: let the stdlib reject absolute
        # paths, ".." components and links that point outside *dest*.
        tar.extractall(path=dest, filter="data")
        return

    # Older interpreters: validate member names ourselves before extracting.
    dest_root = os.path.realpath(dest)
    for member in tar.getmembers():
        target = os.path.realpath(os.path.join(dest_root, member.name))
        if os.path.commonpath([dest_root, target]) != dest_root:
            raise ValueError(
                f"Refusing to extract {member.name!r} outside {dest_root}"
            )
    tar.extractall(path=dest)


def download_and_extract(data_dir):
    """Download and extract the dataset if not already present.

    Args:
        data_dir: Directory (str or Path) that holds the tar archive and the
            extracted files. It is created when missing.

    Returns:
        A sorted list of ``Path`` objects for every ``*.h5`` file found
        directly inside ``data_dir`` (empty if the archive contains none).

    Raises:
        ValueError: If an archive member would be written outside ``data_dir``
            (only on interpreters without tarfile extraction filters).
    """
    data_dir = Path(data_dir)
    tar_path = data_dir / f"{GEO_ACCESSION}_RAW.tar"

    data_dir.mkdir(parents=True, exist_ok=True)

    if not tar_path.exists():
        print(f"Downloading {GEO_ACCESSION} dataset...")
        # Download to a temporary name and rename on success, so that an
        # interrupted transfer never leaves a truncated archive that later
        # runs would mistake for a complete download.
        partial_path = tar_path.with_name(tar_path.name + ".part")
        try:
            urllib.request.urlretrieve(DOWNLOAD_URL, partial_path)
            os.replace(partial_path, tar_path)
        finally:
            if partial_path.exists():
                partial_path.unlink()

    h5_files = sorted(data_dir.glob("*.h5"))
    if not h5_files:
        print(f"Extracting {GEO_ACCESSION}_RAW.tar...")
        with tarfile.open(tar_path, 'r') as tar:
            _safe_extract(tar, data_dir)
        h5_files = sorted(data_dir.glob("*.h5"))

    return h5_files

def parse_filename(filename):
    """Extract sample metadata encoded in an h5 file name.

    The base name is expected to start with
    ``GSM<digits>_<number>-<seq_type>-<condition>[-v<N>]-<treatment>.h5``.
    When the optional ``-v<N>`` part is absent the version defaults to ``v1``.

    Args:
        filename: File name or path; only the base name is inspected.

    Returns:
        A dict with the string keys ``gsm_id``, ``version_num``, ``seq_type``,
        ``condition``, ``version`` and ``treatment``.

    Raises:
        ValueError: If the base name does not follow the expected pattern.
    """
    basename = os.path.basename(filename)
    parsed = re.match(r'GSM\d+_(\d+)-(\w+)-(\w+)(?:-v(\d+))?-(\w+)\.h5', basename)

    # Guard clause: bail out early on names we do not understand.
    if parsed is None:
        raise ValueError(f"Could not parse filename: {basename}")

    version_num, seq_type, condition, version, treatment = parsed.groups()
    version_label = "1" if version is None else version

    return {
        # Everything before the first underscore is the GSM accession.
        'gsm_id': basename.partition('_')[0],
        'version_num': version_num,
        'seq_type': seq_type,
        'condition': condition,
        'version': f"v{version_label}",
        'treatment': treatment,
    }

def _decode_strings(values):
    """Return *values* as a list of ``str``, decoding UTF-8 byte strings."""
    return [
        value.decode('utf-8') if isinstance(value, bytes) else str(value)
        for value in values
    ]


def _fit_indptr(indptr, n_cells):
    """Pad or truncate *indptr* so it has exactly ``n_cells + 1`` entries."""
    expected = n_cells + 1
    if len(indptr) == expected:
        return indptr

    print(
        f"Warning: indptr has {len(indptr)} entries but {expected} were "
        "expected; padding/truncating to match the barcodes."
    )
    # Keep the stored dtype: forcing int32 can overflow on large matrices.
    fitted = np.zeros(expected, dtype=indptr.dtype)
    keep = min(len(indptr), expected)
    fitted[:keep] = indptr[:keep]
    if keep < expected:
        # Missing trailing cells become empty rows.
        fitted[keep:] = indptr[-1] if len(indptr) else 0
    return fitted


def read_h5_file(file_path):
    """Read an h5 file and return the gene expression matrix and metadata.

    The file must contain a ``matrix`` group with ``data``, ``indices``,
    ``indptr`` and ``shape`` datasets. ``indptr`` is treated as one pointer per
    barcode, so the matrix is assembled as cells x features. This assumes the
    10x Genomics feature-barcode layout, where ``shape`` is
    ``(n_features, n_barcodes)``; the fallback names below rely on that order.

    Args:
        file_path: Path to the ``.h5`` file.

    Returns:
        A dict with:
            ``gene_expression``: CSR matrix (cells x genes). When feature
                types are present only "Gene Expression" columns are kept.
            ``gene_ids`` / ``gene_names``: lists of str, one per gene column.
            ``barcodes``: list of str, one per cell.
            ``guide_counts``: dense array (cells x guides) of
                "CRISPR Guide Capture" features, or ``None`` if there are none.
            ``guide_names``: list of guide feature names, or ``None``.

    Raises:
        ValueError: If the file has no ``matrix`` group.
    """
    with h5py.File(file_path, 'r') as f:
        if 'matrix' not in f:
            raise ValueError(f"Could not find matrix in {file_path}")

        matrix_group = f['matrix']
        # ``[:]`` copies each dataset into memory, so nothing below needs the
        # file to stay open.
        data = matrix_group['data'][:]
        indices = matrix_group['indices'][:]
        indptr = matrix_group['indptr'][:]
        shape = matrix_group['shape'][:]

        if 'barcodes' in matrix_group:
            barcodes = _decode_strings(matrix_group['barcodes'][:])
        else:
            # Barcodes are the second axis of ``shape`` (features x barcodes).
            barcodes = [f"cell_{i}" for i in range(shape[1])]

        feature_types = None
        if 'features' in matrix_group:
            features_group = matrix_group['features']
            feature_ids = _decode_strings(features_group['id'][:])
            feature_names = _decode_strings(features_group['name'][:])
            if 'feature_type' in features_group:
                feature_types = _decode_strings(features_group['feature_type'][:])
        else:
            # Features are the first axis of ``shape``.
            feature_ids = [f"gene_{i}" for i in range(shape[0])]
            feature_names = feature_ids

    indptr = _fit_indptr(indptr, len(barcodes))
    full_matrix = csr_matrix(
        (data, indices, indptr),
        shape=(len(barcodes), len(feature_names))
    )

    if feature_types is None:
        # No type annotation: every feature is treated as a gene.
        gene_expression = full_matrix
        gene_ids = feature_ids
        gene_names = feature_names
        guide_counts = None
        guide_names = None
    else:
        gene_indices = [i for i, ft in enumerate(feature_types) if ft == 'Gene Expression']
        guide_indices = [i for i, ft in enumerate(feature_types) if ft == 'CRISPR Guide Capture']

        gene_ids = [feature_ids[i] for i in gene_indices]
        gene_names = [feature_names[i] for i in gene_indices]
        gene_expression = full_matrix[:, gene_indices]

        if guide_indices:
            guide_names = [feature_names[i] for i in guide_indices]
            guide_counts = full_matrix[:, guide_indices].toarray()
        else:
            guide_names = None
            guide_counts = None

    return {
        'gene_expression': gene_expression,
        'gene_ids': gene_ids,
        'gene_names': gene_names,
        'barcodes': barcodes,
        'guide_counts': guide_counts,
        'guide_names': guide_names
    }

def _deduplicate_names(names):
    """Return *names* with repeats suffixed ``_1``, ``_2``, ... so all differ."""
    reserved = set(names)
    if len(reserved) == len(names):
        return list(names)

    repeats = {}
    unique_names = []
    for name in names:
        if name not in repeats:
            # First occurrence keeps its original spelling.
            repeats[name] = 0
            unique_names.append(name)
            continue
        # Skip suffixes that are already taken, either by a real name in the
        # input or by a previously generated one.
        candidate = None
        while candidate is None or candidate in reserved:
            repeats[name] += 1
            candidate = f"{name}_{repeats[name]}"
        reserved.add(candidate)
        unique_names.append(candidate)
    return unique_names


def create_anndata(data):
    """Create an AnnData object from the data.

    Args:
        data: Dict as returned by ``read_h5_file`` with the keys
            ``gene_expression``, ``gene_ids``, ``gene_names``, ``barcodes``,
            ``guide_counts`` and ``guide_names``.

    Returns:
        An AnnData whose ``X`` is the gene expression matrix, whose obs index
        is the barcodes and whose var index is the de-duplicated gene names.
        ``var['ensembl_id']`` holds ``gene_ids``. When guide data is present,
        ``obsm['guide_counts']`` and ``uns['guide_names']`` are set. A copy of
        the object is stored in ``.raw``.
    """
    var_index = pd.Index(_deduplicate_names(data['gene_names']))

    adata = anndata.AnnData(
        X=data['gene_expression'],
        obs=pd.DataFrame(index=data['barcodes']),
        var=pd.DataFrame(index=var_index)
    )

    adata.var['ensembl_id'] = data['gene_ids']

    if data['guide_counts'] is not None and data['guide_names'] is not None:
        adata.obsm['guide_counts'] = data['guide_counts']
        adata.uns['guide_names'] = np.array(data['guide_names'])

    # Snapshot the unmodified object before any downstream processing.
    adata.raw = adata.copy()

    return adata

def identify_perturbations(adata):
    """Identify perturbations for each cell based on guide counts.

    A guide counts as detected in a cell when its count is strictly greater
    than 1. Detected guides are translated through ``GENE_TARGET_MAP`` (names
    not in the map are used as-is), and the distinct targets are joined with
    ``" + "``. Guides that share a target therefore contribute a single label,
    e.g. two non-targeting guides yield ``"Non-Targeting"``. Cells with no
    detected guide are labelled ``"Unknown"``.

    ``perturbation_type`` is ``"Non-Targeting"`` when the cell is ``"Unknown"``
    or every detected target is ``"Non-Targeting"``, and ``"Targeting"`` as
    soon as any other target is present.

    Args:
        adata: AnnData-like object. ``obsm['guide_counts']`` (dense,
            cells x guides) and ``uns['guide_names']`` are read when present.

    Returns:
        The same ``adata`` with the categorical ``obs`` columns
        ``perturbation_name`` and ``perturbation_type`` added. If guide data
        is missing, both columns are filled with ``"Unknown"``.
    """
    if 'guide_counts' not in adata.obsm or 'guide_names' not in adata.uns:
        print("Warning: Guide counts or names not found. Cannot identify perturbations.")
        adata.obs['perturbation_name'] = pd.Categorical(['Unknown'] * adata.n_obs)
        adata.obs['perturbation_type'] = pd.Categorical(['Unknown'] * adata.n_obs)
        return adata

    guide_counts = adata.obsm['guide_counts']
    guide_names = adata.uns['guide_names']
    n_guides = len(guide_names)
    threshold = 1  # counts must exceed this value for a guide to be called

    perturbations = []
    perturbation_types = []
    for cell_guides in guide_counts:
        targets = []
        for idx in np.where(cell_guides > threshold)[0]:
            if idx >= n_guides:
                # More count columns than names: nothing to label this with.
                continue
            guide = str(guide_names[idx])
            target = GENE_TARGET_MAP.get(guide, guide)
            # Collapse guides that hit the same target into one label.
            if target not in targets:
                targets.append(target)

        if not targets:
            perturbations.append("Unknown")
            perturbation_types.append('Non-Targeting')
            continue

        perturbations.append(" + ".join(targets))
        only_controls = all(target == 'Non-Targeting' for target in targets)
        perturbation_types.append('Non-Targeting' if only_controls else 'Targeting')

    adata.obs['perturbation_name'] = pd.Categorical(perturbations)
    adata.obs['perturbation_type'] = pd.Categorical(perturbation_types)

    return adata

def add_harmonized_metadata(adata, file_info):
    """Attach constant and file-derived annotation columns to ``adata.obs``.

    Every column is a categorical holding one value repeated for all cells.

    Args:
        adata: Object exposing ``obs`` (DataFrame-like) and ``n_obs``.
        file_info: Mapping with the keys ``treatment``, ``seq_type``,
            ``condition``, ``version`` and ``gsm_id``.

    Returns:
        The same ``adata`` object, modified in place.
    """
    # Values that are the same for every file handled by this script.
    fixed_columns = (
        ('organism', 'Mus musculus'),
        ('cell_type', 'MC38 colon cancer cells'),
        ('crispr_type', 'CRISPR KO'),
        ('cancer_type', 'Colon Cancer'),
    )
    for column, value in fixed_columns:
        adata.obs[column] = pd.Categorical([value] * adata.n_obs)

    # obs column -> key in file_info. 'condition' starts out as the treatment
    # parsed from the file name; it is overridden later based on perturbation.
    file_columns = (
        ('condition', 'treatment'),
        ('seq_type', 'seq_type'),
        ('environment', 'condition'),
        ('version', 'version'),
        ('sample_id', 'gsm_id'),
    )
    for column, info_key in file_columns:
        adata.obs[column] = pd.Categorical([file_info[info_key]] * adata.n_obs)

    return adata

def process_file(file_path):
    """Turn one h5 file into an annotated AnnData object.

    The file name is parsed first, so an unrecognised name fails before the
    file is read. The matrix is then loaded, wrapped in an AnnData, labelled
    with per-cell perturbations and finally given the harmonized metadata.
    """
    metadata = parse_filename(file_path)
    annotated = create_anndata(read_h5_file(file_path))
    annotated = identify_perturbations(annotated)
    return add_harmonized_metadata(annotated, metadata)

def _condition_for(perturbation_name):
    """Return "Control" if every component is non-targeting, else "Test"."""
    components = str(perturbation_name).split(" + ")
    if all(part.strip().lower() == "non-targeting" for part in components):
        return "Control"
    return "Test"


def main(data_dir):
    """Process all files, combine them into one AnnData, filter and adjust metadata, then save.

    Steps:
        1. Download/extract the dataset into ``data_dir`` and process every
           ``.h5`` file, in sorted order. A file that fails is reported and
           skipped.
        2. Concatenate the per-file objects. Each obs name gets the file stem
           appended (``<barcode>_<stem>``), because the same barcode can occur
           in several samples and would otherwise make the obs index
           non-unique.
        3. Drop cells whose ``perturbation_name`` is ``"Unknown"``.
        4. Set ``obs['condition']`` to ``"Control"`` for cells that carry only
           non-targeting guides and ``"Test"`` for all others.
        5. Write ``combined_harmonized.h5ad`` into ``data_dir``.

    Args:
        data_dir: Directory used for the download and for the output file.

    Returns:
        None. If no file could be processed, nothing is written.
    """
    h5_files = sorted(download_and_extract(data_dir))
    print(f"Found {len(h5_files)} h5 files to process")

    adata_list = []
    sample_keys = []
    for file_path in h5_files:
        print(f"Processing {os.path.basename(file_path)}...")
        try:
            adata = process_file(file_path)
        except Exception as e:
            # Best effort: one unreadable file must not abort the whole run.
            print(f"Error processing {file_path}: {e}")
        else:
            adata_list.append(adata)
            sample_keys.append(Path(file_path).stem)

    if not adata_list:
        print("No valid AnnData objects processed.")
        return

    # Combine all AnnData objects, keeping obs names unique across samples.
    combined_adata = anndata.concat(
        adata_list,
        join="outer",
        merge="same",
        keys=sample_keys,
        index_unique="_",
    )

    # Exclude cells with "Unknown" perturbation
    combined_adata = combined_adata[combined_adata.obs['perturbation_name'] != "Unknown", :].copy()

    # Update condition: cells with only non-targeting guides are "Control"
    # (this includes combinations such as "Non-Targeting + Non-Targeting");
    # every other cell is "Test". Stored as categorical like the other
    # metadata columns.
    combined_adata.obs['condition'] = pd.Categorical(
        [_condition_for(name) for name in combined_adata.obs['perturbation_name']]
    )

    output_path = os.path.join(data_dir, "combined_harmonized.h5ad")
    combined_adata.write_h5ad(output_path)
    print(f"Combined harmonized data saved to {output_path}")

# Notebook entry point: set the working directory (used for the download, the
# extracted files and the output .h5ad) and run the whole pipeline.
# NOTE(review): "/content/..." looks like a Google Colab path — adjust it when
# running elsewhere.
data_dir = "/content/GSE206107"
main(data_dir)
