In [None]:
!wget -q https://ftp.ncbi.nlm.nih.gov/geo/series/GSE289nnn/GSE289084/suppl/GSE289084_pbmc_barcodes.tsv.gz
!wget -q https://ftp.ncbi.nlm.nih.gov/geo/series/GSE289nnn/GSE289084/suppl/GSE289084_pbmc_features_adt.tsv.gz
!wget -q https://ftp.ncbi.nlm.nih.gov/geo/series/GSE289nnn/GSE289084/suppl/GSE289084_pbmc_matrix_adt.tsv.gz

In [None]:

# List the working directory (long format, human-readable sizes).
ls -lh

In [None]:
import pandas as pd

# Read the barcode list; the file has no header row, so the single column is
# named explicitly while loading.
barcodes = pd.read_csv(
    "GSE289084_pbmc_barcodes.tsv.gz",
    header=None,
    names=["cell_barcode"],
)

print("Number of cells:", len(barcodes))
barcodes.head()

In [None]:
# Duplicate-barcode check
# -------------------------------
# Counts barcodes that repeat an earlier entry. A non-zero count means the
# barcodes are not unique identifiers for cells.

duplicate_barcodes = barcodes.duplicated(subset="cell_barcode").sum()
print("Duplicate barcodes:", duplicate_barcodes)

In [None]:




# ----------------------------------------
# Load ADT (protein) count matrix as TSV
# Purpose:
# - Inspect dimensions
# - Verify protein × cell layout
# ----------------------------------------

import pandas as pd

# Read the tab-separated ADT table verbatim (no header inference), so that
# the header row stays in the data as row 0.
adt_matrix = pd.read_table(
    "GSE289084_pbmc_matrix_adt.tsv.gz",
    header=None,
)

print("ADT matrix shape (rows × columns):", adt_matrix.shape)

In [None]:
# Preview the first 15 rows of the raw ADT table (row 0 is the header row).
adt_matrix.head(15)

In [None]:
# Split the raw ADT table into its three components
# ----------------------------------------
# Row 0 holds the cell barcodes, column 0 holds the protein names, and the
# remaining rectangle holds the counts.

body_rows = adt_matrix.iloc[1:]

# Protein names: first column of every row below the header row
adt_proteins = body_rows.iloc[:, 0].values

# Cell barcodes: header row without its first cell
adt_cell_barcodes = adt_matrix.iloc[0, 1:].values

# Counts: everything below the header row and right of the name column
adt_counts = body_rows.iloc[:, 1:].astype(int)

print("Proteins:", adt_proteins)
print("ADT count matrix shape:", adt_counts.shape)

In [None]:
# ----------------------------------------
# Confirm ADT barcodes match barcode file
# ----------------------------------------
# The lengths are compared first: an element-wise `==` between arrays of
# different lengths does not produce a per-element result, so a length
# mismatch must be reported as a mismatch rather than reach the comparison.

reference_barcodes = barcodes["cell_barcode"].values

if (
    len(adt_cell_barcodes) == len(reference_barcodes)
    and (adt_cell_barcodes == reference_barcodes).all()
):
    print("✅ ADT cell barcodes match barcode file perfectly")
else:
    print("❌ Barcode mismatch detected")

In [None]:
# Preview the first 15 rows of the integer count block.
adt_counts.head(15)

In [None]:
# Build a clean ADT DataFrame
# Purpose:
# - Attach protein names to rows
# - Attach cell barcodes to columns
# - Use integers (raw counts)
# ----------------------------------------

import pandas as pd

# Label the bare count array: rows are proteins, columns are cell barcodes.
adt_df = pd.DataFrame(
    data=adt_counts.to_numpy(),
    index=adt_proteins,
    columns=adt_cell_barcodes,
)

print("ADT DataFrame shape (proteins × cells):", adt_df.shape)
adt_df.iloc[0:5, 0:5]

In [None]:


# ----------------------------------------
# Per-protein summary statistics
# Purpose:
# - Check dynamic range
# - Spot dead or saturated proteins
# ----------------------------------------

# Per-protein summaries computed across cells (axis=1 runs along columns).
is_zero = adt_df.eq(0)

adt_stats = pd.DataFrame({
    "mean": adt_df.mean(axis=1),
    "median": adt_df.median(axis=1),
    "max": adt_df.max(axis=1),
    # share of cells in which the protein has a count of exactly zero
    "pct_zero": is_zero.mean(axis=1) * 100,
})

adt_stats.sort_values(by="mean", ascending=False)

In [None]:
# Visualize sparsity per protein
# Purpose:
# - See which markers are rare vs ubiquitous
# ----------------------------------------

import matplotlib.pyplot as plt

# Horizontal bars: one per protein, length = percent of zero-count cells.
fig, ax = plt.subplots(figsize=(8, 4))
ax.barh(adt_stats.index, adt_stats["pct_zero"])
ax.set_xlabel("Percent of cells with zero counts")
ax.set_title("ADT sparsity per protein")
fig.tight_layout()
plt.show()

In [None]:


# ADT library size per cell
# Purpose:
# - Detect empty / overloaded droplets
# - Understand protein depth variation
# ----------------------------------------

import numpy as np

# Total ADT counts per cell (sum over proteins for each column).
adt_libsize = adt_df.sum(axis=0)

print("ADT library size per cell")
for stat_label, stat_value in (
    ("Min", adt_libsize.min()),
    ("Median", np.median(adt_libsize)),
    ("Max", adt_libsize.max()),
):
    print(f"{stat_label}:", int(stat_value))

In [None]:








# Create a cell-level metadata table
# ----------------------------------------
# Purpose:
# - One row per cell
# - Extract information hidden in column names
# ----------------------------------------

import pandas as pd

# ----------------------------------------
# Cell-level metadata parsed from the ADT barcodes
# Barcode layout, e.g. M13_S1-AAACCCACAGAACATA-1:
#   text before "_"             -> patient (M13)
#   text between "_" and "-"    -> sample  (S1)
# NOTE(review): the sample-to-week mapping defined further down uses the
# labels S1 and S6 — confirm those are the only sample labels in the data.
# ----------------------------------------
cell_meta = pd.DataFrame({"cell_barcode": adt_cell_barcodes})

underscore_parts = cell_meta["cell_barcode"].str.split("_")

# Patient ID: first underscore-delimited field
cell_meta["patient"] = underscore_parts.str[0]

# Sample ID: second underscore-delimited field, truncated at the first "-"
cell_meta["sample"] = underscore_parts.str[1].str.split("-").str[0]

# Show the first rows
cell_meta.head(10)

In [None]:




# How many cells per sample?
# Counts rows of cell_meta for each distinct value of the "sample" column.
cell_meta["sample"].value_counts()

In [None]:


# How many cells per patient and sample?
# One row per observed (patient, sample) combination with its cell count.
cell_meta.groupby(["patient", "sample"]).size()

In [None]:




# ----------------------------------------
# GSM-derived truth table (ADT only)
# Source: GEO GSE289084 sample descriptions
# Every listed patient appears once per timepoint (W0, W6).
# ----------------------------------------

gsm_patients = ["M13", "M16", "M17", "M25", "M26", "M28", "M34", "M42"]
gsm_weeks = ["W0", "W6"]

gsm_meta = pd.DataFrame(
    [(patient_id, week_label) for patient_id in gsm_patients for week_label in gsm_weeks],
    columns=["patient", "week"],
)

In [None]:




# Observed (patient, sample) combinations and their cell counts, shown again
# for comparison with the hard-coded tables.
cell_meta.groupby(["patient", "sample"]).size()

In [None]:
# Explicit sample-to-week mapping
# (constructed from GSM + observed samples)
# ----------------------------------------
# Each patient contributes two rows: sample S1 -> week W0, sample S6 -> week W6.

mapped_patients = ["M13", "M16", "M17", "M25", "M26", "M28", "M34", "M42"]
sample_week_pairs = [("S1", "W0"), ("S6", "W6")]

sample_meta = pd.DataFrame(
    [
        (patient_id, sample_label, week_label)
        for patient_id in mapped_patients
        for sample_label, week_label in sample_week_pairs
    ],
    columns=["patient", "sample", "week"],
)

In [None]:
# Attach the timepoint ("week") to every cell through its (patient, sample)
# pair.
# - Any "week" column left by a previous run is dropped first, so running
#   the cell twice in a row does not create week_x / week_y columns.
# - validate="many_to_one" makes pandas raise if sample_meta ever lists the
#   same (patient, sample) pair twice, which would duplicate cells.
cell_meta = (
    cell_meta
    .drop(columns=["week"], errors="ignore")
    .merge(
        sample_meta,
        on=["patient", "sample"],
        how="left",
        validate="many_to_one",
    )
)

# A left join silently leaves NaN for pairs missing from sample_meta;
# report them instead of letting them vanish from later groupbys.
n_without_week = int(cell_meta["week"].isna().sum())
if n_without_week:
    print(f"WARNING: {n_without_week} cells have no week assignment")

In [None]:





# ----------------------------------------
# Convert ADT matrix to long format
# Purpose:
# - Make plotting easy
# Result: one row per (cell, protein) with the cell's patient/sample/week.
# ----------------------------------------

meta_cols = ["cell_barcode", "patient", "sample", "week"]

# cell_meta carries the barcode as a column at first and as its index after
# the later set_index step; accept both so this cell can be re-run at any time.
cell_meta_flat = (
    cell_meta.reset_index()
    if cell_meta.index.name == "cell_barcode"
    else cell_meta
)

adt_long = (
    adt_df
    .T                              # cells × proteins
    .rename_axis("cell_barcode")
    .reset_index()
    # Only the four metadata columns are merged in, so flag columns added to
    # cell_meta later can never be melted as if they were proteins.
    .merge(cell_meta_flat[meta_cols], on="cell_barcode", how="inner")
    .melt(
        id_vars=meta_cols,
        var_name="protein",
        value_name="adt_count"
    )
)

adt_long.head()

In [None]:
# ----------------------------------------
# Distribution of total ADT per cell by week
# Purpose:
# - Check overall protein signal shift
# ----------------------------------------

import seaborn as sns
import matplotlib.pyplot as plt

# Per-cell ADT totals are attached positionally (.values), so this relies on
# cell_meta rows being in the same order as adt_df columns.
totals_by_cell = cell_meta.assign(total_adt=adt_df.sum(axis=0).values)

fig, ax = plt.subplots(figsize=(6,4))
sns.boxplot(data=totals_by_cell, x="week", y="total_adt", ax=ax)
ax.set_yscale("log")
ax.set_title("Total ADT counts per cell (Week 0 vs Week 6)")
plt.show()

In [None]:
# Violin plots for selected proteins
# ----------------------------------------
# ADT counts include zeros, which a log-scaled axis cannot display, and a
# density estimated on raw counts is distorted when drawn on a log axis.
# The counts are therefore transformed with log(1 + x) before plotting and
# the axis is left linear.

proteins_to_check = [
    "CD3", "CD4", "CD8",
    "CD14", "CD16",
    "HLA-DR", "CD19"
]

violin_data = (
    adt_long[adt_long["protein"].isin(proteins_to_check)]
    .assign(log1p_adt=lambda d: np.log1p(d["adt_count"].astype(float)))
)

plt.figure(figsize=(12,6))
sns.violinplot(
    data=violin_data,
    x="protein",
    y="log1p_adt",
    hue="week",
    split=True,
    scale="width",
    cut=0
)
plt.ylabel("log(1 + ADT count)")
plt.title("ADT distributions by Week (selected proteins)")
plt.legend(title="Week")
plt.show()

In [None]:
# Per-patient median ADT (example: CD8)
# ----------------------------------------

# Median CD8 ADT count for every (patient, week) pair.
cd8_rows = adt_long.loc[adt_long["protein"] == "CD8"]
cd8_summary = (
    cd8_rows
    .groupby(["patient", "week"], as_index=False)["adt_count"]
    .median()
)

# One line per patient connecting its timepoints.
fig, ax = plt.subplots(figsize=(6,4))
sns.lineplot(
    data=cd8_summary,
    x="week",
    y="adt_count",
    hue="patient",
    marker="o",
    ax=ax,
)
ax.set_yscale("log")
ax.set_title("Per-patient median CD8 (Week 0 → Week 6)")
plt.show()

In [None]:
# Per-patient median ADT per protein
# Purpose:
# - Robust summary per patient per timepoint
# ----------------------------------------

# Median ADT count per (patient, week, protein), returned as a flat table.
summary_keys = ["patient", "week", "protein"]
patient_week_summary = (
    adt_long
    .groupby(summary_keys, as_index=False)["adt_count"]
    .median()
)

patient_week_summary.head(20)


In [None]:
# Pivot so W0 and W6 are side by side
# ----------------------------------------

# Spread the week values into columns: one row per (patient, protein).
week_columns = patient_week_summary.pivot_table(
    values="adt_count",
    index=["patient", "protein"],
    columns="week",
)
paired = week_columns.reset_index()

paired.head(20)

In [None]:
# Compute log2 fold-change (W6 / W0)
# Purpose:
# - Stabilize variance
# - Make changes symmetric
# ----------------------------------------

import numpy as np

# log2 of the W6/W0 ratio; 1 is added to both medians so a zero median does
# not produce a division by zero or log of zero.
w6_shifted = paired["W6"] + 1
w0_shifted = paired["W0"] + 1
paired["log2FC_W6_vs_W0"] = np.log2(w6_shifted / w0_shifted)

paired.head(100)

In [None]:

# Per-protein log2FC across patients
# Purpose:
# - See patient heterogeneity per marker
# ----------------------------------------

import seaborn as sns
import matplotlib.pyplot as plt

# One point per (patient, protein); the dashed line marks "no change".
fig, ax = plt.subplots(figsize=(12,4))
sns.stripplot(
    data=paired,
    x="protein",
    y="log2FC_W6_vs_W0",
    jitter=True,
    ax=ax,
)
ax.axhline(0, color="red", linestyle="--")
ax.set_ylabel("log2FC (Week 6 vs Week 0)")
ax.set_title("Patient-paired protein changes")
plt.show()

In [None]:

# Per-patient protein trajectories
# Purpose:
# - See if patients show coordinated shifts
# ----------------------------------------

# One line per patient across proteins; the legend sits outside the axes.
fig, ax = plt.subplots(figsize=(12,5))
sns.lineplot(
    data=paired,
    x="protein",
    y="log2FC_W6_vs_W0",
    hue="patient",
    marker="o",
    ax=ax,
)
ax.axhline(0, color="black", linestyle="--")
ax.set_ylabel("log2FC (W6 vs W0)")
ax.set_title("Protein-wise changes per patient")
ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
plt.show()

In [None]:
# Heatmap of log2FC (patients × proteins)
# Purpose:
# - Compact overview of longitudinal changes
# ----------------------------------------

# Patients as rows, proteins as columns, log2FC as cell values.
heatmap_df = paired.pivot(
    index="patient",
    columns="protein",
    values="log2FC_W6_vs_W0",
)

fig, ax = plt.subplots(figsize=(10,4))
sns.heatmap(
    heatmap_df,
    center=0,           # diverging colormap centred on "no change"
    cmap="coolwarm",
    linewidths=0.5,
    ax=ax,
)
ax.set_title("Patient-paired log2FC heatmap (Week 6 vs Week 0)")
plt.show()

In [None]:
# Define HLA-DR^high cells
# Logic:
# - Cells above the 80th percentile of HLA-DR ADT counts (top 20%)
# - The percentile is computed once over ALL cells pooled together, so the
#   cutoff is a single global value and is not adjusted per patient or batch
# ----------------------------------------

# HLA-DR counts, one value per cell
hladr_values = adt_df.loc["HLA-DR", :]

# Global cutoff at the 80th percentile
hladr_thresh = hladr_values.quantile(q=0.8)

hladr_thresh

In [None]:
# Fix index alignment
# ----------------------------------------
# Index cell_meta by barcode so per-cell Series (indexed by barcode) align on
# assignment. The guard makes the cell safe to re-run: once the column has
# been moved into the index, a second set_index would raise KeyError.

if "cell_barcode" in cell_meta.columns:
    cell_meta = cell_meta.set_index("cell_barcode")

In [None]:
# Flag cells whose HLA-DR count is strictly above the cutoff
# ----------------------------------------
# hladr_values is indexed by barcode, so the assignment aligns on barcode.

cell_meta["is_HLADR_high"] = hladr_values.gt(hladr_thresh)

In [None]:
# Absolute counts, then fractions, of HLA-DR^high vs other cells.
print(cell_meta["is_HLADR_high"].value_counts())
print(cell_meta["is_HLADR_high"].value_counts(normalize=True))

In [None]:
# Restrict metadata and counts to the HLA-DR^high cells
# ----------------------------------------

# Metadata rows of flagged cells (copied so later edits stay local)
hladr_cells = cell_meta.loc[cell_meta["is_HLADR_high"]].copy()

# Matching columns of the proteins × cells matrix
adt_hladr = adt_df.loc[:, hladr_cells.index]

# Sanity check: (n_proteins, n_flagged_cells)
adt_hladr.shape

In [None]:
# Long-format ADT table for HLA-DR^high cells
# ----------------------------------------
# Only cell_barcode / patient / week are merged in. Merging the whole
# metadata table would make melt() treat the remaining columns ("sample",
# boolean flags) as proteins and turn adt_count into a mixed-type column.

hladr_meta = hladr_cells.reset_index()[["cell_barcode", "patient", "week"]]

adt_hladr_long = (
    adt_hladr
    .T
    .rename_axis("cell_barcode")   # name the index so it becomes a column
    .reset_index()
    .merge(
        hladr_meta,
        on="cell_barcode",
        how="inner"
    )
    .melt(
        id_vars=["cell_barcode", "patient", "week"],
        var_name="protein",
        value_name="adt_count"
    )
)

adt_hladr_long.head()

In [None]:
# (rows, columns) of the long-format HLA-DR^high table.
adt_hladr_long.shape


In [None]:
# Column dtypes of the long-format table (checks whether adt_count is numeric).
adt_hladr_long.dtypes

In [None]:
# Force ADT values to numeric
# ----------------------------------------

import pandas as pd

# Convert adt_count to a numeric dtype; entries that cannot be parsed as
# numbers become NaN instead of raising.
numeric_counts = pd.to_numeric(adt_hladr_long["adt_count"], errors="coerce")
adt_hladr_long["adt_count"] = numeric_counts

In [None]:
# Remove rows whose adt_count is missing (NaN)
# ----------------------------------------

has_count = adt_hladr_long["adt_count"].notna()
adt_hladr_long = adt_hladr_long[has_count]

In [None]:


# Re-inspect the column dtypes after the numeric conversion.
adt_hladr_long.dtypes

In [None]:

# Per-patient, per-week median ADT
# ----------------------------------------

# Median ADT count per (patient, week, protein) within HLA-DR^high cells.
hladr_summary = (
    adt_hladr_long
    .groupby(["patient", "week", "protein"], as_index=False)["adt_count"]
    .median()
)

hladr_summary.head(20)

In [None]:
# Discard summary rows whose "protein" label is not a row of the ADT matrix.
valid_proteins = list(adt_df.index)

is_real_protein = hladr_summary["protein"].isin(valid_proteins)
hladr_summary = hladr_summary[is_real_protein]

In [None]:
# Distinct protein labels remaining in the summary.
hladr_summary["protein"].unique()

In [None]:


# First 20 rows of the filtered summary.
hladr_summary.head(20)

In [None]:

# Pivot to paired format (HLA-DR^high cells)
# ----------------------------------------

# One row per (patient, protein) with the week medians side by side; pairs
# missing either timepoint are removed by dropna().
hladr_week_columns = hladr_summary.pivot_table(
    values="adt_count",
    index=["patient", "protein"],
    columns="week",
)
hladr_paired = hladr_week_columns.dropna().reset_index()

import numpy as np

# log2 ratio W6/W0 with 1 added to both medians
hladr_ratio = (hladr_paired["W6"] + 1) / (hladr_paired["W0"] + 1)
hladr_paired["log2FC_W6_vs_W0"] = np.log2(hladr_ratio)

hladr_paired.head(20)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Patients × proteins matrix of log2FC values for HLA-DR^high cells.
heatmap_hladr = hladr_paired.pivot(
    index="patient",
    columns="protein",
    values="log2FC_W6_vs_W0",
)

fig, ax = plt.subplots(figsize=(10,4))
sns.heatmap(
    heatmap_hladr,
    center=0,
    cmap="coolwarm",
    linewidths=0.5,
    ax=ax,
)
ax.set_title("HLA-DR^high cells: log2FC (Week 6 vs Week 0)")
plt.show()

In [None]:
# One point per (patient, protein) in HLA-DR^high cells; dashed line = no change.
fig, ax = plt.subplots(figsize=(12,4))
sns.stripplot(
    data=hladr_paired,
    x="protein",
    y="log2FC_W6_vs_W0",
    jitter=True,
    ax=ax,
)
ax.axhline(0, color="black", linestyle="--")
ax.set_ylabel("log2FC (W6 vs W0)")
ax.set_title("Protein changes in HLA-DR^high cells")
plt.show()

In [None]:
# CD8 cutoff from a global quantile
# ----------------------------------------
# Cells above the 70th percentile of CD8 counts (top 30%) will count as CD8-high.

cd8_values = adt_df.loc["CD8", :]

cd8_thresh = cd8_values.quantile(q=0.7)

cd8_thresh

In [None]:
# Flag CD8-high cells and the HLA-DR^high ∩ CD8-high subset
# ----------------------------------------

cell_meta["is_CD8_high"] = cd8_values.gt(cd8_thresh)

# A cell is in the double-positive subset only if both flags are True.
hladr_flag = cell_meta["is_HLADR_high"]
cd8_flag = cell_meta["is_CD8_high"]
cell_meta["is_HLADR_CD8"] = hladr_flag & cd8_flag

In [None]:
# How many cells?
# Counts of cells inside (True) and outside (False) the double-positive subset.
cell_meta["is_HLADR_CD8"].value_counts()

In [None]:
# Proportion
# Same breakdown expressed as fractions of all cells.
cell_meta["is_HLADR_CD8"].value_counts(normalize=True)

In [None]:
# Metadata rows of the HLA-DR^high CD8-high cells
# ----------------------------------------

in_subset = cell_meta["is_HLADR_CD8"]
hladr_cd8_cells = cell_meta.loc[in_subset].copy()

hladr_cd8_cells.shape

In [None]:
# Columns of the proteins × cells matrix that belong to the subset
# ----------------------------------------

adt_hladr_cd8 = adt_df.loc[:, hladr_cd8_cells.index]

adt_hladr_cd8.shape

In [None]:

# Long-format ADT for HLA-DR^high CD8+ cells
# ----------------------------------------
# Only cell_barcode / patient / week are merged in. Merging every metadata
# column would make melt() treat "sample" and the boolean flag columns as
# proteins and turn adt_count into a mixed-type column.

hladr_cd8_meta = hladr_cd8_cells.reset_index()[["cell_barcode", "patient", "week"]]

adt_hladr_cd8_long = (
    adt_hladr_cd8
    .T
    .rename_axis("cell_barcode")   # name the index so it becomes a column
    .reset_index()
    .merge(
        hladr_cd8_meta,
        on="cell_barcode",
        how="inner"
    )
    .melt(
        id_vars=["cell_barcode", "patient", "week"],
        var_name="protein",
        value_name="adt_count"
    )
)

# Ensure numeric ADT (kept as a safety net; unparseable values become NaN)
import pandas as pd
adt_hladr_cd8_long["adt_count"] = pd.to_numeric(
    adt_hladr_cd8_long["adt_count"], errors="coerce"
)

# Drop any accidental non-numeric rows
adt_hladr_cd8_long = adt_hladr_cd8_long.dropna(subset=["adt_count"])

adt_hladr_cd8_long.head(), adt_hladr_cd8_long.shape

In [None]:
# Per-patient median ADT (HLA-DR^high CD8+)
# ----------------------------------------

# Median ADT count per (patient, week, protein) in the double-positive subset.
hladr_cd8_summary = (
    adt_hladr_cd8_long
    .groupby(["patient", "week", "protein"], as_index=False)["adt_count"]
    .median()
)

# Keep only rows whose label is a row of the ADT matrix.
valid_proteins = list(adt_df.index)
is_real_protein = hladr_cd8_summary["protein"].isin(valid_proteins)
hladr_cd8_summary = hladr_cd8_summary[is_real_protein]

hladr_cd8_summary.head()

In [None]:
# Paired W0/W6 + log2FC (activated CD8+)
# ----------------------------------------

import numpy as np

# Week medians side by side per (patient, protein); incomplete pairs dropped.
hladr_cd8_week_columns = hladr_cd8_summary.pivot_table(
    values="adt_count",
    index=["patient", "protein"],
    columns="week",
)
hladr_cd8_paired = hladr_cd8_week_columns.dropna().reset_index()

# log2 ratio W6/W0 with 1 added to both medians
hladr_cd8_ratio = (hladr_cd8_paired["W6"] + 1) / (hladr_cd8_paired["W0"] + 1)
hladr_cd8_paired["log2FC_W6_vs_W0"] = np.log2(hladr_cd8_ratio)

hladr_cd8_paired.head()

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Patients × proteins matrix of log2FC values for the double-positive subset.
heatmap_cd8 = hladr_cd8_paired.pivot(
    index="patient",
    columns="protein",
    values="log2FC_W6_vs_W0",
)

fig, ax = plt.subplots(figsize=(10,4))
sns.heatmap(
    heatmap_cd8,
    center=0,
    cmap="coolwarm",
    linewidths=0.5,
    ax=ax,
)
ax.set_title("HLA-DR^high CD8+ T cells: log2FC (Week 6 vs Week 0)")
plt.show()


In [None]:

# One point per (patient, protein) in the double-positive subset.
fig, ax = plt.subplots(figsize=(12,4))
sns.stripplot(
    data=hladr_cd8_paired,
    x="protein",
    y="log2FC_W6_vs_W0",
    jitter=True,
    ax=ax,
)
ax.axhline(0, color="black", linestyle="--")
ax.set_ylabel("log2FC (W6 vs W0)")
ax.set_title("Protein changes in HLA-DR^high CD8+ T cells")
plt.show()

In [None]:
# Load RNA cell barcodes
# ----------------------------------------

import pandas as pd

# Read the barcode list that will label the RNA matrix columns.
# This is the same file ("GSE289084_pbmc_barcodes.tsv.gz") that was loaded
# into `barcodes` at the top of the notebook.
rna_barcodes = pd.read_csv(
    "GSE289084_pbmc_barcodes.tsv.gz",
    header=None,
    names=["cell_barcode"]
)

# Displays a (shape, preview) tuple.
rna_barcodes.shape, rna_barcodes.head()

In [None]:
!wget -q https://ftp.ncbi.nlm.nih.gov/geo/series/GSE289nnn/GSE289084/suppl/GSE289084_pbmc_features.tsv.gz

In [None]:
# Load gene names
# ----------------------------------------
# The feature file is read as a headerless tab-separated table.
# NOTE(review): rna_genes is not referenced again in the visible part of the
# notebook — confirm whether it is still needed.

rna_genes = pd.read_csv(
    "GSE289084_pbmc_features.tsv.gz",
    header=None,
    sep="\t"
)

# Displays a (preview, shape) tuple.
rna_genes.head(), rna_genes.shape

In [None]:
!wget -q https://ftp.ncbi.nlm.nih.gov/geo/series/GSE289nnn/GSE289084/suppl/GSE289084_pbmc_matrix.tsv.gz

In [None]:
# Number of cells flagged as HLA-DR^high and CD8-high (sum of True values).
cell_meta["is_HLADR_CD8"].sum()

In [None]:
# Barcodes of the HLA-DR^high CD8-high cells, as a set for fast membership tests.
cd8_cells = set(cell_meta.index[cell_meta["is_HLADR_CD8"]])

len(cd8_cells)

In [None]:
# Positions (0-based) within the RNA barcode list whose barcode belongs to
# the selected cells.
is_selected_barcode = rna_barcodes["cell_barcode"].isin(cd8_cells).to_numpy()
keep_cols = np.flatnonzero(is_selected_barcode).tolist()

len(keep_cols)

In [None]:

# File name reserved for an RNA subset.
# NOTE(review): out_file is not used anywhere else in the visible notebook.
out_file = "rna_cd8_subset.tsv"

In [None]:
# Genes of interest, grouped by rationale
# ----------------------------------------
# The flat list keeps the group order below; gene_set is the same content
# as a set for membership tests.

gene_groups = {
    # Genes encoding proteins that are also measured in the ADT panel
    "protein_anchored": [
        "CD3D", "CD3E",
        "CD4",
        "CD8A", "CD8B",
        "CD14",
        "FCGR3A",        # CD16
        "CD19",
        "CD27",
        "PTPRC",         # CD45
        "SELL",          # CD62L
        "HLA-DRA", "HLA-DRB1",
        "IGHM",
    ],
    # Cytotoxic T-cell program
    "cytotoxic": ["GZMB", "PRF1", "NKG7", "GNLY"],
    # TNF / IFN signaling (study-relevant)
    "tnf_ifn": ["TNF", "IFNG", "CXCL9", "CXCL10"],
    # Activation / exhaustion context
    "activation_exhaustion": ["PDCD1", "LAG3", "TIGIT", "TOX"],
}

genes_of_interest = [
    gene for group_members in gene_groups.values() for gene in group_members
]

gene_set = set(genes_of_interest)

In [None]:
# Gene list written to genes.txt, one symbol per line.
# NOTE(review): this redefinition of genes_of_interest omits "CD4", which is
# present in the list defined in the previous cell — confirm this is intended.
genes_of_interest = [
    "CD19", "CD27",
    "CD3D", "CD3E",
    "CD8A", "CD8B",
    "GZMB", "PRF1", "NKG7", "GNLY",
    "TNF", "IFNG",
    "CXCL9", "CXCL10",
    "PDCD1", "LAG3", "TIGIT", "TOX",
    "HLA-DRA", "HLA-DRB1",
    "CD14", "FCGR3A", "SELL", "PTPRC", "IGHM",
]

with open("genes.txt", "w") as gene_file:
    gene_file.writelines(f"{symbol}\n" for symbol in genes_of_interest)

In [None]:
!ls
!lcat gene.txt

In [None]:
# Copy the matrix rows for the requested genes into rna_genes_subset.tsv.
# The gene name in the first tab-separated field must equal a requested
# symbol exactly. A word-based grep would also keep hyphenated names that
# merely start with a requested symbol (e.g. "CD27-AS1" for "CD27"), because
# "-" counts as a word boundary.
# NOTE(review): assumes the gene symbol is the first column of the matrix
# file, as the later column assignment ("gene" + barcodes) also does.
import gzip

with open("genes.txt") as gene_file:
    wanted_genes = {line.strip() for line in gene_file if line.strip()}

with gzip.open("GSE289084_pbmc_matrix.tsv.gz", "rt") as matrix_in, \
        open("rna_genes_subset.tsv", "w") as subset_out:
    for matrix_row in matrix_in:
        # Surrounding quotes are stripped in case the names are quoted.
        row_gene = matrix_row.split("\t", 1)[0].strip().strip('"')
        if row_gene in wanted_genes:
            subset_out.write(matrix_row)

In [None]:
import pandas as pd

# Read the extracted gene rows; the file has no header line.
rna_subset = pd.read_table(
    "rna_genes_subset.tsv",
    header=None,
)

rna_subset.shape

In [None]:
# Label the columns: first column is the gene, the rest follow the order of
# the RNA barcode list.
rna_subset.columns = ["gene", *rna_barcodes["cell_barcode"]]

# Use the gene names as row labels
rna_subset = rna_subset.set_index("gene")

# Restrict to columns whose barcode is in the HLA-DR^high CD8+ set
is_target_cell = rna_subset.columns.isin(cd8_cells)
rna_cd8_small = rna_subset.loc[:, is_target_cell]

rna_cd8_small.shape

In [None]:
# Total CD19 RNA counts summed over the selected cells.
rna_cd8_small.loc["CD19"].sum()

In [None]:
# Mean RNA count per selected cell for three cytotoxic genes.
rna_cd8_small.loc[["GZMB","PRF1","NKG7"]].mean(axis=1)

In [None]:
# Build RNA cell metadata
# expand=False makes Index.str.extract return an Index, which is assigned
# positionally. The default returns a DataFrame with a fresh RangeIndex;
# assigning that aligns on labels that never match the barcodes and fills
# the column with NaN.
rna_meta = pd.DataFrame(index=rna_cd8_small.columns)
rna_meta["patient"] = rna_meta.index.str.extract(r"(M\d+)", expand=False)
rna_meta["sample"]  = rna_meta.index.str.extract(r"_(S\d+)-", expand=False)
rna_meta["week"]    = rna_meta["sample"].map({"S1":"W0", "S6":"W6"})


In [None]:
# Long-format RNA table: one row per (cell, gene) with patient and week.
# Only "patient" and "week" are merged in; merging the "sample" column as
# well would make melt() treat it as a gene and mix strings into rna_count.
rna_long = (
    rna_cd8_small
    .T
    .merge(rna_meta[["patient", "week"]], left_index=True, right_index=True)
    .melt(
        id_vars=["patient","week"],
        var_name="gene",
        value_name="rna_count"
    )
)

In [None]:
# Median RNA count per (patient, week, gene).
rna_summary = (
    rna_long
    .groupby(["patient","week","gene"], as_index=False)["rna_count"]
    .median()
)

In [None]:
# Number of summary rows per week label.
print(rna_summary["week"].value_counts())

In [None]:
# First five column labels (cell barcodes) of the RNA subset.
print(rna_cd8_small.columns[:5])

In [None]:
# Type of the column index, followed by its first five labels.
print(type(rna_cd8_small.columns))
print(rna_cd8_small.columns[:5])

In [None]:
# Build RNA cell metadata from column names
# ----------------------------------------
# expand=False is required on both extract calls: the default returns a
# DataFrame with a RangeIndex, and assigning it aligns on index labels that
# never match the barcodes, leaving every value NaN.

rna_meta = pd.DataFrame(index=rna_cd8_small.columns)

# Patient ID: M13, M25, etc.
rna_meta["patient"] = rna_meta.index.str.extract(r"(M\d+)", expand=False)

# Sample ID: S1, S6
rna_meta["sample"] = rna_meta.index.str.extract(r"_(S\d+)-", expand=False)

# Map sample → week
rna_meta["week"] = rna_meta["sample"].map({
    "S1": "W0",
    "S6": "W6"
})

# Sanity checks
print(rna_meta.head())
print(rna_meta["week"].value_counts())

In [None]:
# Convert index to a proper pandas Series of strings
barcode_series = pd.Series(rna_cd8_small.columns.astype(str))

rna_meta = pd.DataFrame(index=barcode_series)

# barcode_series is labelled 0..n-1 while rna_meta is labelled by barcode, so
# the extracted values are converted to plain arrays and assigned by
# position; assigning the labelled result would align to all-NaN.

# Patient ID: starts with M + digits
rna_meta["patient"] = barcode_series.str.extract(r"^(M\d+)", expand=False).to_numpy()

# Sample ID: S1 or S6 between "_" and "-"
rna_meta["sample"] = barcode_series.str.extract(r"_(S\d+)-", expand=False).to_numpy()

# Map sample → week
rna_meta["week"] = rna_meta["sample"].map({
    "S1": "W0",
    "S6": "W6"
})

# Set proper index
rna_meta.index = barcode_series

# Check
print(rna_meta.head())
print(rna_meta["week"].value_counts())

In [None]:
# Build RNA cell metadata from column names (correct regex)
# expand=False returns an Index that is assigned positionally; the default
# DataFrame result would be aligned on its RangeIndex and yield only NaN.

rna_meta = pd.DataFrame(index=rna_cd8_small.columns)

# Patient: starts with M followed by digits
rna_meta["patient"] = rna_meta.index.str.extract(r"^(M\d+)", expand=False)

# Sample: after underscore, before dash
rna_meta["sample"] = rna_meta.index.str.extract(r"_(S\d+)-", expand=False)

# Map sample to week
rna_meta["week"] = rna_meta["sample"].map({
    "S1": "W0",
    "S6": "W6"
})

# Check
print(rna_meta.head())
print(rna_meta["week"].value_counts())

In [None]:
# Print the repr of the first five barcodes to expose hidden characters
# (quotes, whitespace) that a plain print would not show.
for x in barcode_series.head(5):
    print(repr(x))

In [None]:
# Use the raw values as index
barcode_series = pd.Series(rna_cd8_small.columns.astype(str))

rna_meta = pd.DataFrame(index=barcode_series.values)

# Extract metadata
# barcode_series is labelled 0..n-1 but rna_meta is labelled by barcode, so
# the results are assigned as plain arrays (by position) to avoid a label
# alignment that would produce only NaN.
rna_meta["patient"] = barcode_series.str.extract(r"^(M\d+)", expand=False).to_numpy()
rna_meta["sample"]  = barcode_series.str.extract(r"_(S\d+)-", expand=False).to_numpy()
rna_meta["week"]    = rna_meta["sample"].map({
    "S1": "W0",
    "S6": "W6"
})

# Check
print(rna_meta.head())
print(rna_meta["week"].value_counts())

In [None]:
# Build metadata directly from the column index
rna_meta = pd.DataFrame(index=rna_cd8_small.columns)

# Convert index to string explicitly (safety)
idx = rna_meta.index.astype(str)

# Extract patient and sample from index
# expand=False returns an Index (assigned by position) instead of a
# RangeIndex-ed DataFrame whose assignment would align to all-NaN.
rna_meta["patient"] = idx.str.extract(r"^(M\d+)", expand=False)
rna_meta["sample"]  = idx.str.extract(r"_(S\d+)-", expand=False)

# Map sample to week
rna_meta["week"] = rna_meta["sample"].map({
    "S1": "W0",
    "S6": "W6"
})

# Check
print(rna_meta.head())
print(rna_meta["week"].value_counts())

In [None]:
# Build metadata using pure string splitting (no regex)
# Start from an empty frame whose row labels are the RNA subset's barcodes.

rna_meta = pd.DataFrame(index=rna_cd8_small.columns)

def parse_barcode(bc):
    """Return ``(patient, sample)`` parsed from a barcode string.

    The patient is the text before the first underscore; the sample is the
    second underscore-delimited field cut at its first dash, e.g.
    ``"M13_S1-AAACCCACAGAACATA-1"`` gives ``("M13", "S1")``.
    Raises IndexError when the barcode contains no underscore.
    """
    fields = str(bc).split("_")
    return fields[0], fields[1].split("-")[0]

# Parse every barcode once, then split the pairs into two columns.
parsed = [parse_barcode(bc) for bc in rna_meta.index]

rna_meta["patient"] = [patient_id for patient_id, _ in parsed]
rna_meta["sample"]  = [sample_id for _, sample_id in parsed]

# Sample label → timepoint
sample_to_week = {"S1": "W0", "S6": "W6"}
rna_meta["week"] = rna_meta["sample"].map(sample_to_week)

# Check
print(rna_meta.head())
print(rna_meta["week"].value_counts())

In [None]:
# Build long-format RNA table
# Only "patient" and "week" are merged in. With the whole rna_meta table the
# "sample" column would be melted as if it were a gene, putting strings into
# rna_count and an extra "sample" entry into every per-gene summary.
rna_long = (
    rna_cd8_small
    .T
    .merge(rna_meta[["patient", "week"]], left_index=True, right_index=True)
    .melt(
        id_vars=["patient", "week"],
        var_name="gene",
        value_name="rna_count"
    )
)

print(rna_long.head())

In [None]:
# Column dtypes of the long RNA table (checks whether rna_count is numeric).
print(rna_long.dtypes)

In [None]:
# Convert rna_count to numeric; values that cannot be parsed become NaN.
rna_long["rna_count"] = pd.to_numeric(rna_long["rna_count"], errors="coerce")

In [None]:
# Re-inspect the dtypes after the numeric conversion.
print(rna_long.dtypes)

In [None]:
# Median RNA count per (patient, week, gene).
rna_summary = (
    rna_long
    .groupby(["patient", "week", "gene"], as_index=False)["rna_count"]
    .median()
)

print(rna_summary.head())
print(rna_summary["week"].value_counts())

In [None]:
# One row per (patient, gene) with the week medians as columns.
rna_week_columns = rna_summary.pivot_table(
    values="rna_count",
    index=["patient", "gene"],
    columns="week",
)
rna_wide = rna_week_columns.reset_index()

print(rna_wide.head())
print(rna_wide.columns)

In [None]:
import numpy as np

# log2 ratio W6/W0 with 1 added to both medians.
rna_ratio = (rna_wide["W6"] + 1) / (rna_wide["W0"] + 1)
rna_wide["log2FC_W6_vs_W0"] = np.log2(rna_ratio)
print(rna_wide.head())


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Genes as rows, patients as columns, log2FC as cell values.
heatmap_data = rna_wide.pivot(
    index="gene",
    columns="patient",
    values="log2FC_W6_vs_W0",
)

fig, ax = plt.subplots(figsize=(10, 9))
sns.heatmap(
    heatmap_data,
    cmap="coolwarm",
    center=0,
    linewidths=0.4,
    linecolor="gray",
    cbar_kws={"label": "log2 Fold Change (W6 vs W0)"},
    ax=ax,
)

ax.set_title("RNA activation signature in HLA-DR⁺ CD8 T cells")
ax.set_ylabel("Gene")
ax.set_xlabel("Patient")
fig.tight_layout()
plt.show()

In [None]:
# Genes shown in the point plot.
genes_focus = ["GZMB", "PRF1", "NKG7", "CD27", "CD19"]

import seaborn as sns
import matplotlib.pyplot as plt

# Per-patient medians for the focus genes only.
is_focus_gene = rna_summary["gene"].isin(genes_focus)
subset = rna_summary[is_focus_gene]

fig, ax = plt.subplots(figsize=(12,5))
sns.pointplot(
    data=subset,
    x="gene",
    y="rna_count",
    hue="week",
    dodge=True,
    markers=["o","s"],
    capsize=0.1,
    ax=ax,
)
ax.set_title("Key gene expression in HLA-DR⁺ CD8 T cells")
ax.set_ylabel("Median RNA count")
ax.set_xlabel("")
plt.show()

In [None]:
# CD8+ flag from the 80th percentile of CD8 counts (top 20%)
# ----------------------------------------
# Note: this overwrites cd8_thresh, which was previously computed at the
# 70th percentile for the is_CD8_high flag.
cd8_values = adt_df.loc["CD8", :]
cd8_thresh = cd8_values.quantile(q=0.8)

cell_meta["is_CD8"] = cd8_values.gt(cd8_thresh)

In [None]:
# Number of cells with and without the CD8+ flag.
cell_meta["is_CD8"].value_counts()

In [None]:
# 2×2 contingency table: CD8+ flag (rows) against HLA-DR^high flag (columns).
pd.crosstab(cell_meta["is_CD8"], cell_meta["is_HLADR_high"])

In [None]:
# Three-way label per cell: CD8+ cells are split by HLA-DR status and
# everything else is "Other".
cd8_mask = cell_meta["is_CD8"]
hladr_mask = cell_meta["is_HLADR_high"]

cell_meta["group"] = np.select(
    [cd8_mask & hladr_mask, cd8_mask & (~hladr_mask)],
    ["HLADR+_CD8", "HLADR-_CD8"],
    default="Other",
)

In [None]:
# Number of cells in each group label.
cell_meta["group"].value_counts()

In [None]:
# Barcodes of the two CD8+ groups.
group_labels = cell_meta["group"]
hladr_pos_cells = cell_meta.index[group_labels == "HLADR+_CD8"]
hladr_neg_cells = cell_meta.index[group_labels == "HLADR-_CD8"]

# RNA columns belonging to each group (barcodes absent from the RNA table
# are simply not selected).
in_pos_group = rna_subset.columns.isin(hladr_pos_cells)
in_neg_group = rna_subset.columns.isin(hladr_neg_cells)
rna_pos = rna_subset.loc[:, in_pos_group]
rna_neg = rna_subset.loc[:, in_neg_group]

print("HLADR+ CD8 RNA:", rna_pos.shape)
print("HLADR- CD8 RNA:", rna_neg.shape)

In [None]:
# Per-cell patient and week, taken from the ADT-derived metadata table.
rna_meta = cell_meta[["patient", "week"]]

In [None]:
# Long-format RNA table for HLADR+ CD8 cells
# ----------------------------------------
# Cells become rows, patient/week are attached by barcode, and every gene
# column is unpivoted into (gene, rna_count) pairs.
pos_cells_by_gene = rna_pos.T.join(rna_meta, how="inner")
rna_pos_long = pos_cells_by_gene.melt(
    id_vars=["patient", "week"],
    var_name="gene",
    value_name="rna_count",
)

print(rna_pos_long.head(20))

In [None]:
# (rows, columns) of the long HLADR+ CD8 table.
rna_pos_long.shape

In [None]:
# Median RNA count per (patient, week, gene) in HLADR+ CD8 cells.
rna_pos_summary = (
    rna_pos_long
    .groupby(["patient", "week", "gene"], as_index=False)["rna_count"]
    .median()
)

print(rna_pos_summary.head())

In [None]:
# Wide layout for HLADR+ CD8 cells: week medians become columns
# ----------------------------------------
pos_week_columns = rna_pos_summary.pivot_table(
    values="rna_count",
    index=["patient", "gene"],
    columns="week",
)
rna_pos_wide = pos_week_columns.reset_index()

print(rna_pos_wide.head())

In [None]:
# log2 ratio W6/W0 (1 added to both medians) for HLADR+ CD8 cells
# ----------------------------------------
pos_ratio = (rna_pos_wide["W6"] + 1) / (rna_pos_wide["W0"] + 1)
rna_pos_wide["log2FC"] = np.log2(pos_ratio)

print(rna_pos_wide.head())

In [None]:
# Long-format RNA table for HLADR- CD8 cells: cells as rows, patient/week
# attached by barcode, genes unpivoted into (gene, rna_count) pairs.
neg_cells_by_gene = rna_neg.T.join(rna_meta, how="inner")
rna_neg_long = neg_cells_by_gene.melt(
    id_vars=["patient", "week"],
    var_name="gene",
    value_name="rna_count",
)

print(rna_neg_long.head())

In [None]:
# Median RNA count per (patient, week, gene) in HLADR- CD8 cells.
rna_neg_summary = (
    rna_neg_long
    .groupby(["patient", "week", "gene"], as_index=False)["rna_count"]
    .median()
)

print(rna_neg_summary.head())

In [None]:
# Wide layout for HLADR- CD8 cells: week medians become columns.
neg_week_columns = rna_neg_summary.pivot_table(
    values="rna_count",
    index=["patient", "gene"],
    columns="week",
)
rna_neg_wide = neg_week_columns.reset_index()

print(rna_neg_wide.head())

In [None]:
# log2 ratio W6/W0 (1 added to both medians) for HLADR- CD8 cells.
neg_ratio = (rna_neg_wide["W6"] + 1) / (rna_neg_wide["W0"] + 1)
rna_neg_wide["log2FC"] = np.log2(neg_ratio)

print(rna_neg_wide.head())

In [None]:
# Put both groups side by side per (patient, gene); shared column names get
# the _HLADRpos / _HLADRneg suffixes.
rna_compare = pd.merge(
    rna_pos_wide,
    rna_neg_wide,
    on=["patient", "gene"],
    suffixes=("_HLADRpos", "_HLADRneg"),
)

print(rna_compare.head())

In [None]:
import matplotlib.pyplot as plt

# Each point is one (patient, gene): x = change in HLADR− cells,
# y = change in HLADR+ cells. Dashed lines mark zero change on either axis.
fig, ax = plt.subplots(figsize=(6,6))
ax.scatter(
    rna_compare["log2FC_HLADRneg"],
    rna_compare["log2FC_HLADRpos"],
    alpha=0.6,
)
ax.axhline(0, color="grey", linestyle="--")
ax.axvline(0, color="grey", linestyle="--")
ax.set_xlabel("log2FC (HLADR− CD8)")
ax.set_ylabel("log2FC (HLADR+ CD8)")
ax.set_title("Therapy response: Activated vs Resting CD8 T cells")
plt.show()

In [None]:
# |log2FC in HLADR+ cells| minus |log2FC in HLADR- cells|:
# positive when the HLADR+ group changes more strongly than the HLADR- group.
abs_change_pos = rna_compare["log2FC_HLADRpos"].abs()
abs_change_neg = rna_compare["log2FC_HLADRneg"].abs()
rna_compare["activation_specificity"] = abs_change_pos - abs_change_neg

In [None]:
# Rows ordered from the most to the least activation-specific change.
ranked = rna_compare.sort_values(by="activation_specificity", ascending=False)

ranked.head(15)

In [None]:
# Candidate rows: more than a two-fold change (|log2FC| > 1) in HLADR+ cells
# together with a near-zero change (|log2FC| < 0.3) in HLADR- cells.
strong_in_pos = ranked["log2FC_HLADRpos"].abs() > 1
flat_in_neg = ranked["log2FC_HLADRneg"].abs() < 0.3
biomarkers = ranked[strong_in_pos & flat_in_neg]

report_cols = ["gene","log2FC_HLADRpos","log2FC_HLADRneg","activation_specificity"]
biomarkers[report_cols].head(20)

In [None]:
fig, ax = plt.subplots(figsize=(6,6))

# Background layer: every (patient, gene) pair
ax.scatter(
    rna_compare["log2FC_HLADRneg"],
    rna_compare["log2FC_HLADRpos"],
    alpha=0.3,
)

# Highlight layer: rows that passed the biomarker filter
ax.scatter(
    biomarkers["log2FC_HLADRneg"],
    biomarkers["log2FC_HLADRpos"],
    color="red",
    s=60,
    label="Candidate biomarkers",
)

ax.axhline(0, color="grey", linestyle="--")
ax.axvline(0, color="grey", linestyle="--")
ax.set_xlabel("log2FC (HLADR− CD8)")
ax.set_ylabel("log2FC (HLADR+ CD8)")
ax.set_title("Activation-specific therapy biomarkers")
ax.legend()
plt.show()

In [None]:
# The 15 genes with the highest median activation specificity across patients.
median_specificity = rna_compare.groupby("gene")["activation_specificity"].median()
top_genes = median_specificity.sort_values(ascending=False).head(15).index

print(top_genes)

In [None]:
# Patients × top genes matrix of activation specificity (median per cell of
# the table, should a patient/gene pair occur more than once).
top_gene_rows = rna_compare[rna_compare["gene"].isin(top_genes)]
patient_gene_matrix = top_gene_rows.pivot_table(
    index="patient",
    columns="gene",
    values="activation_specificity",
    aggfunc="median",
)

patient_gene_matrix

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Heatmap of the patient × gene activation_specificity table.
# center=0 puts the neutral colour of the diverging map at a score of zero.
plt.figure(figsize=(10,6))
sns.heatmap(
    patient_gene_matrix,
    cmap="coolwarm",
    center=0,
    linewidths=0.5
)
plt.title("Patient-wise Activation-Specific Response (HLADR+ CD8 vs HLADR- CD8)")
plt.xlabel("Genes")
plt.ylabel("Patients")
plt.show()

In [None]:
# Per-patient summary: mean absolute value across the gene columns of
# patient_gene_matrix (one number per patient).
patient_strength = patient_gene_matrix.abs().mean(axis=1)
# Display only; the sorted result is not stored.
patient_strength.sort_values(ascending=False)

In [None]:
# Horizontal bar chart of patient_strength, sorted ascending so the largest
# bar is drawn at the top.
plt.figure(figsize=(6,4))
patient_strength.sort_values().plot(kind="barh")
plt.title("Strength of Activation-Specific Immune Response per Patient")
plt.xlabel("Mean |activation specificity|")
plt.show()

In [None]:
# Per-gene summary: mean absolute value down the patient rows of
# patient_gene_matrix (one number per gene).
# NOTE(review): this is a mean magnitude; it does not measure how much the
# value varies between patients, despite the variable name.
gene_consistency = patient_gene_matrix.abs().mean(axis=0)
# Display only; the sorted result is not stored.
gene_consistency.sort_values(ascending=False)

In [None]:
# Add protein values into cell_meta
# The rows are copied as bare arrays (.values), so they are matched to
# cell_meta rows by position, not by barcode.
# NOTE(review): assumes adt_df columns are in the same cell order as
# cell_meta's index — confirm where adt_df and cell_meta are built.
cell_meta["HLA-DR"] = adt_df.loc["HLA-DR"].values
cell_meta["CD8"]    = adt_df.loc["CD8"].values

In [None]:
# Preview the two protein columns just added to cell_meta.
cell_meta[["HLA-DR", "CD8"]].head()

In [None]:
# Define HLA-DR high per patient (top 20% inside each patient):
# every cell is compared against the 80th percentile of HLA-DR computed over
# the cells of its own patient, so the cut-off adapts to each patient.
cell_meta["is_HLADR_high_patient"] = cell_meta["HLA-DR"].gt(
    cell_meta
    .groupby("patient")["HLA-DR"]
    .transform(lambda values: values.quantile(0.8))
)

In [None]:
# CD8 gate: a single cut-off at the 80th percentile of CD8 over ALL cells
# (unlike the HLA-DR gate, this one is not computed per patient).
cd8_thresh = cell_meta["CD8"].quantile(0.8)
cell_meta["is_CD8"] = cell_meta["CD8"] > cd8_thresh

In [None]:
# Cell counts for each combination of the two boolean gates.
cell_meta.groupby(["is_HLADR_high_patient", "is_CD8"]).size()

In [None]:
# Index labels of CD8-gated cells, split by the per-patient HLA-DR gate:
# hladr_pos_cd8 -> HLA-DR high and CD8; hladr_neg_cd8 -> not HLA-DR high and CD8.
hladr_pos_cd8 = cell_meta.loc[
    cell_meta["is_HLADR_high_patient"] & cell_meta["is_CD8"]
].index
hladr_neg_cd8 = cell_meta.loc[
    ~cell_meta["is_HLADR_high_patient"] & cell_meta["is_CD8"]
].index

In [None]:
import pandas as pd

# Load the RNA table from a tab-separated file; the first file column becomes
# the row index.
# NOTE(review): later cells index rows by gene name and columns by cell, so the
# layout is presumably genes × cells — confirm against the printed shape.
rna_genes_subset = pd.read_csv(
    "rna_genes_subset.tsv",
    sep="\t",
    index_col=0
)

print(rna_genes_subset.shape)
# Peek at the top-left 5 × 5 corner only.
rna_genes_subset.iloc[:5, :5]

In [None]:
# How many RNA column labels also appear in cell_meta's index?
# If the labels already agree, "Common barcodes" equals both other counts.
common = set(rna_genes_subset.columns).intersection(set(cell_meta.index))
print("Common barcodes:", len(common))
print("RNA columns:", len(rna_genes_subset.columns))
print("Metadata rows:", len(cell_meta))

In [None]:

# Assign correct barcode names to RNA matrix columns
# This overwrites the column labels positionally: column i receives the i-th
# label of cell_meta.index. pandas raises if the two lengths differ, but it
# cannot detect a different cell order.
# NOTE(review): assumes the RNA columns are already in cell_meta's row order —
# verify against the file that produced rna_genes_subset.tsv.
rna_genes_subset.columns = cell_meta.index.tolist()

# Sanity check
# (always agrees after the assignment above, so it cannot reveal a mismatch)
print(rna_genes_subset.columns[:5])
print(cell_meta.index[:5])

In [None]:
# Repeat the overlap count now that the RNA columns carry cell_meta's labels.
common = set(rna_genes_subset.columns).intersection(set(cell_meta.index))
print("Common barcodes:", len(common))
print("RNA columns:", len(rna_genes_subset.columns))
print("Metadata rows:", len(cell_meta))

In [None]:
# Activated CD8
# Columns (cells) whose labels are in hladr_pos_cd8; all rows are kept.
rna_pos = rna_genes_subset.loc[:, hladr_pos_cd8]

# Resting CD8
# Columns (cells) whose labels are in hladr_neg_cd8; all rows are kept.
rna_neg = rna_genes_subset.loc[:, hladr_neg_cd8]

print("RNA HLADR+ CD8:", rna_pos.shape)
print("RNA HLADR- CD8:", rna_neg.shape)

In [None]:
# Median expression per gene in each group (median across the cell columns).
rna_pos_median = rna_pos.median(axis=1)
rna_neg_median = rna_neg.median(axis=1)

# NOTE: this rebinds `rna_compare` to a new two-column table indexed like
# rna_pos / rna_neg; the table previously held under this name is replaced.
rna_compare = pd.DataFrame({
    "HLADRpos_CD8": rna_pos_median,
    "HLADRneg_CD8": rna_neg_median
})

# log2 fold change HLADR+ over HLADR-, with a pseudocount of 1 on both medians
# so that a median of zero does not divide by zero.
rna_compare["log2FC_activation"] = np.log2(
    rna_compare["HLADRpos_CD8"].add(1).div(rna_compare["HLADRneg_CD8"].add(1))
)

rna_compare.sort_values(by="log2FC_activation", ascending=False)

In [None]:
# De-duplicated list of gene names that appear in the `biomarkers` table.
biomarker_genes = biomarkers["gene"].unique().tolist()
biomarker_genes

In [None]:
# Biomarker gene list with IGHM removed; the order of the remaining genes is
# unchanged.
clean_biomarkers = [gene for gene in biomarker_genes if gene != "IGHM"]

In [None]:
# Metadata for activated CD8 cells
# (patient and week columns only, rows selected by the hladr_pos_cd8 labels)
meta_pos = cell_meta.loc[hladr_pos_cd8, ["patient", "week"]]

# Metadata for resting CD8 cells
# (patient and week columns only, rows selected by the hladr_neg_cd8 labels)
meta_neg = cell_meta.loc[hladr_neg_cd8, ["patient", "week"]]

print(meta_pos.head())
print(meta_neg.head())

In [None]:
# Eyeball check: the first five RNA column labels should match the first five
# metadata index labels within each group.
print(rna_pos.columns[:5])
print(meta_pos.index[:5])

print(rna_neg.columns[:5])
print(meta_neg.index[:5])

In [None]:
import numpy as np
import pandas as pd


def _to_long_with_patient(rna, meta):
    """Return one row per (cell, gene) with columns patient, gene, rna_count.

    `rna` is transposed so cells become rows, joined to `meta["patient"]` on
    the cell index, then melted so every gene column becomes a row.
    """
    cells_by_genes = rna.T.merge(
        meta[["patient"]], left_index=True, right_index=True
    )
    return cells_by_genes.melt(
        id_vars="patient", var_name="gene", value_name="rna_count"
    )


def _median_per_patient_gene(long_counts, value_name):
    """Return the median rna_count per (patient, gene), named `value_name`."""
    medians = long_counts.groupby(["patient", "gene"])["rna_count"].median()
    return medians.reset_index().rename(columns={"rna_count": value_name})


# Step 1: Long format for HLADR+ CD8
rna_pos_long = _to_long_with_patient(rna_pos, meta_pos)

# Step 2: Long format for HLADR- CD8
rna_neg_long = _to_long_with_patient(rna_neg, meta_neg)

# Step 3: Median per patient per gene
pos_patient = _median_per_patient_gene(rna_pos_long, "rna_HLADRpos")
neg_patient = _median_per_patient_gene(rna_neg_long, "rna_HLADRneg")

# Step 4: Merge and compute activation score
# The outer join keeps (patient, gene) pairs present in only one group and the
# missing median is filled with 0.
# NOTE(review): a patient with no cells in one group therefore gets a log2FC
# computed against 0 rather than a missing value — confirm this is intended.
rna_patient = pd.merge(
    pos_patient, neg_patient, on=["patient", "gene"], how="outer"
).fillna(0)

# log2 fold change HLADR+ over HLADR-, pseudocount 1 on both medians.
rna_patient["log2FC_activation"] = np.log2(
    rna_patient["rna_HLADRpos"].add(1).div(rna_patient["rna_HLADRneg"].add(1))
)

print(rna_patient.head())

In [None]:
# Keep only the per-patient rows whose gene is in biomarker_genes
# (the unfiltered list, i.e. IGHM is still included here).
rna_biomarkers = rna_patient[
    rna_patient["gene"].isin(biomarker_genes)
]

In [None]:
# Reshape to a patient × gene table of log2FC_activation.
# NOTE: this rebinds `patient_gene_matrix`, replacing the activation_specificity
# table that was built under the same name earlier in the notebook.
patient_gene_matrix = rna_biomarkers.pivot(
    index="patient",
    columns="gene",
    values="log2FC_activation"
)

patient_gene_matrix

In [None]:
# Activation score per patient: mean log2FC_activation across the biomarker
# gene columns, ordered from the highest-scoring patient to the lowest.
patient_activation_score = (
    patient_gene_matrix
    .mean(axis=1)
    .sort_values(ascending=False)
)

patient_activation_score

In [None]:
import matplotlib.pyplot as plt

# Bar chart of the per-patient activation score, in the order the Series is
# already sorted (highest first).
plt.figure(figsize=(8,4))
patient_activation_score.plot(kind="bar")
plt.ylabel("Activation score")
plt.title("Patient immune activation ranking (HLADR⁺ CD8⁺ signature)")
plt.xticks(rotation=45)
plt.tight_layout()

# Save to Colab filesystem
# (savefig is called before show so the saved figure still holds the drawing)
plt.savefig("patient_activation_ranking.png", dpi=300)
plt.show()

print("Saved as: patient_activation_ranking.png")

In [None]:
import seaborn as sns

# Heatmap of the patient × biomarker-gene log2FC_activation table;
# center=0 puts the neutral colour at no change.
plt.figure(figsize=(10,6))
sns.heatmap(
    patient_gene_matrix,
    cmap="coolwarm",
    center=0
)
plt.title("Activated CD8 biomarker signature per patient")
plt.xlabel("Biomarker genes")
plt.ylabel("Patients")
plt.tight_layout()

# Save
# (savefig is called before show so the saved figure still holds the drawing)
plt.savefig("patient_biomarker_heatmap.png", dpi=300)
plt.show()

print("Saved as: patient_biomarker_heatmap.png")

In [None]:
# Write the current score Series and patient × gene table to CSV in the
# working directory. These files reflect the tables as they are at this point
# (built from biomarker_genes); a later cell rebuilds both names.
patient_activation_score.to_csv("patient_activation_score.csv")
patient_gene_matrix.to_csv("patient_biomarker_matrix.csv")

print("Saved:")
print("- patient_activation_score.csv")
print("- patient_biomarker_matrix.csv")

In [None]:
# Rebuild the patient × gene table using clean_biomarkers (IGHM excluded).
# NOTE: this rebinds patient_gene_matrix and patient_activation_score; the CSV
# and PNG files written earlier are not regenerated by this cell.
patient_gene_matrix = rna_patient[
    rna_patient["gene"].isin(clean_biomarkers)
].pivot(
    index="patient",
    columns="gene",
    values="log2FC_activation"
)

# Unlike the earlier version, the stored Series is left unsorted; only the
# displayed copy on the last line is sorted.
patient_activation_score = patient_gene_matrix.mean(axis=1)
patient_activation_score.sort_values(ascending=False)

In [None]:
# #1 Filter CD8 cells only
# cd8_cells: index labels of the cells that pass the is_CD8 gate.
cd8_cells = cell_meta[cell_meta["is_CD8"]].index

# RNA and ADT keep their own column order (boolean mask on the columns), while
# the metadata rows follow cd8_cells' order (label lookup).
# NOTE(review): the three tables line up positionally only if RNA, ADT and
# cell_meta already share one cell order — confirm.
rna_cd8 = rna_genes_subset.loc[:, rna_genes_subset.columns.isin(cd8_cells)]
adt_cd8 = adt_df.loc[:, adt_df.columns.isin(cd8_cells)]
meta_cd8 = cell_meta.loc[cd8_cells]

In [None]:
!pip install scvi-tools

In [None]:
import anndata as ad
import scvi
import numpy as np

# RNA must be cells × genes
# (rna_cd8 is transposed so that its columns become AnnData observations)
X = rna_cd8.T.values

adata = ad.AnnData(X=X)
adata.var_names = rna_cd8.index.astype(str)
adata.obs_names = rna_cd8.columns.astype(str)

# Add metadata
# Both columns are copied as bare arrays, i.e. matched to cells by position.
# NOTE(review): assumes meta_cd8 rows are in the same order as rna_cd8 columns.
adata.obs["patient"] = meta_cd8["patient"].values
# Stored as strings ("True"/"False") rather than booleans.
# NOTE(review): the nearby gating cell creates "is_HLADR_high_patient"; confirm
# that an "is_HLADR_high" column is defined elsewhere in the notebook.
adata.obs["HLADR"] = meta_cd8["is_HLADR_high"].astype(str).values

In [None]:
# proteins must be cells × proteins
# Select the ADT columns by barcode, in adata's own cell order, instead of
# trusting that adt_cd8 and rna_cd8 happen to share a column order. A barcode
# missing from adt_cd8 raises KeyError rather than silently pairing RNA and
# protein measurements that belong to different cells.
adt_X = adt_cd8.loc[:, adata.obs_names].T.values

# Row i of the protein matrix now describes the same cell as row i of adata.X.
adata.obsm["protein_expression"] = adt_X
# Column order of the protein matrix, kept so proteins can be looked up by name.
adata.uns["protein_names"] = adt_cd8.index.astype(str).tolist()

In [None]:
# Report the shapes of the three CD8-restricted tables; the cell dimension
# (columns for RNA/ADT, rows for metadata) should be the same in all three.
print("RNA matrix shape (genes x cells):", rna_cd8.shape)
print("ADT matrix shape (proteins x cells):", adt_cd8.shape)
print("Metadata shape (cells x variables):", meta_cd8.shape)

In [None]:
# Compare the cell labels of the three tables as sets (order is ignored here).
rna_cells = set(rna_cd8.columns)
adt_cells = set(adt_cd8.columns)
meta_cells = set(meta_cd8.index)

# Pairwise overlaps: all three equal the table sizes when labels fully agree.
print("RNA ∩ ADT:", len(rna_cells & adt_cells))
print("RNA ∩ META:", len(rna_cells & meta_cells))
print("ADT ∩ META:", len(adt_cells & meta_cells))

# Labels found in one table but not the other: all zero when labels agree.
print("RNA only:", len(rna_cells - adt_cells))
print("ADT only:", len(adt_cells - rna_cells))
print("META only:", len(meta_cells - rna_cells))

In [None]:
# Cells present in all three tables, kept in rna_cd8's existing column order.
# Iterating a set would give an arbitrary order that can change between kernel
# sessions, so the order is taken from rna_cd8 to keep the result reproducible
# and consistent with objects already built from rna_cd8.
common_cells = [
    cell for cell in rna_cd8.columns
    if cell in adt_cells and cell in meta_cells
]

# Reorder everything identically
rna_cd8 = rna_cd8[common_cells]
adt_cd8 = adt_cd8[common_cells]
meta_cd8 = meta_cd8.loc[common_cells]

In [None]:
# Print a small corner of each CD8-restricted table for a visual check.
print("RNA example:")
print(rna_cd8.iloc[:5, :5])

print("\nADT example:")
print(adt_cd8.iloc[:5, :5])

print("\nMetadata example:")
print(meta_cd8.head())

In [None]:
# Number of CD8-gated cells per value of the is_HLADR_high flag.
# NOTE(review): the column used here is "is_HLADR_high", not the
# "is_HLADR_high_patient" column created by the per-patient gate — confirm.
print(meta_cd8["is_HLADR_high"].value_counts())

In [None]:
# List the row labels (protein names) of the CD8-restricted ADT table.
print(adt_cd8.index.tolist())

In [None]:
import sys

# Size of the underlying arrays in megabytes (bytes / 1e6).
# Note: `sys` is imported above but not used in this cell.
print("RNA memory (MB):", rna_cd8.values.nbytes / 1e6)
print("ADT memory (MB):", adt_cd8.values.nbytes / 1e6)

In [None]:
# Summary of the AnnData object (dimensions and the obs/obsm/uns keys set so far).
print(adata)

In [None]:
import scvi

# Register adata with totalVI: protein counts are read from
# adata.obsm["protein_expression"] and the obs column "patient" is declared as
# the batch covariate.
scvi.model.TOTALVI.setup_anndata(
    adata,
    protein_expression_obsm_key="protein_expression",
    batch_key="patient"   # important: patient is your batch variable
)

In [None]:
# Fix the global scvi-tools seed before the model is constructed so that
# weight initialisation and training are repeatable across kernel restarts.
scvi.settings.seed = 0

# totalVI model on the registered AnnData: 10-dimensional latent space,
# 128 units per hidden layer.
model = scvi.model.TOTALVI(
    adata,
    n_latent=10,     # latent dimensions
    n_hidden=128
)

# Train for at most 150 epochs with minibatches of 256 cells, on CPU.
model.train(
    max_epochs=150,
    batch_size=256,
    accelerator="cpu"   # Colab free usually CPU
)

In [None]:
# Latent representation of the cells from the trained model, stored on adata
# under obsm["X_totalVI"] so plotting functions can use it as a basis.
latent = model.get_latent_representation()
adata.obsm["X_totalVI"] = latent

# Display the array shape.
latent.shape

In [None]:
# Store the latent array on adata.
# (Repeats the assignment already made in the cell that computes `latent`.)
adata.obsm["X_totalVI"] = latent

In [None]:
!pip install scanpy

In [None]:
import scanpy as sc

In [None]:
# Plot the totalVI latent coordinates, coloured by the obs column "HLADR".
# That column holds the string form of the is_HLADR_high flag, i.e. a
# two-category label rather than a continuous protein value.
# NOTE(review): X_totalVI was trained with n_latent=10; presumably only its
# first two components are drawn here — confirm, or embed it (e.g. UMAP) first.
sc.pl.embedding(
    adata,
    basis="X_totalVI",
    color="HLADR",
    title="TotalVI latent space colored by HLA-DR protein",
    size=8
)

In [None]:
# Same latent-space plot, coloured by patient (the batch covariate given to
# totalVI) to see whether cells still separate by patient.
sc.pl.embedding(
    adata,
    basis="X_totalVI",
    color="patient",
    title="TotalVI latent space colored by patient",
    size=8
)

In [None]:
# Random subsample of up to 5k cells for visualization.
# The seed is fixed so the same cells are drawn on every run. The sample size
# is capped at adata.n_obs because sampling without replacement raises when
# more items are requested than exist.
np.random.seed(0)
plot_cells = np.random.choice(
    adata.n_obs,
    size=min(5000, adata.n_obs),
    replace=False
)

# Independent copy, so columns added for plotting do not modify `adata`.
adata_plot = adata[plot_cells].copy()

In [None]:
# Extract HLA-DR protein values
# protein_names lists the columns of obsm["protein_expression"] in order, so
# the position of "HLA-DR" in that list is its column index in the matrix.
hladr_idx = adata_plot.uns["protein_names"].index("HLA-DR")
adata_plot.obs["HLA-DR_protein"] = adata_plot.obsm["protein_expression"][:, hladr_idx]

In [None]:
# Latent-space plot of the subsample, coloured by the HLA-DR protein value
# stored in obs["HLA-DR_protein"].
sc.pl.embedding(
    adata_plot,
    basis="X_totalVI",
    color="HLA-DR_protein",
    size=20,
    alpha=0.8,
    title="HLA-DR protein (subsampled)"
)

In [None]:
# Column index of CD8 in the protein matrix, then the CD8 values as an obs column.
cd8_idx = adata_plot.uns["protein_names"].index("CD8")
adata_plot.obs["CD8_protein"] = adata_plot.obsm["protein_expression"][:, cd8_idx]

# Latent-space plot of the subsample, coloured by the CD8 protein value.
sc.pl.embedding(
    adata_plot,
    basis="X_totalVI",
    color="CD8_protein",
    size=20,
    alpha=0.8,
    title="CD8 protein (subsampled)"
)

In [None]:
# List the obs columns now available on the subsample.
print(adata_plot.obs.columns)

In [None]:
# 80th-percentile cut-offs computed over the subsampled cells only.
# NOTE: `cd8_thresh` rebinds the name used for the all-cell CD8 gate earlier in
# the notebook; after this cell it holds the subsample's value instead.
# NOTE(review): adata_plot is drawn from cells that already passed is_CD8, so
# these are percentiles within the CD8-gated population — confirm intended.
hladr_thresh = adata_plot.obs["HLA-DR_protein"].quantile(0.8)
cd8_thresh   = adata_plot.obs["CD8_protein"].quantile(0.8)

In [None]:
# Label every subsampled cell by which side of the two cut-offs it falls on.
# The three conditions are mutually exclusive; cells at or below both cut-offs
# (or with missing values) keep the default label "Other".
adata_plot.obs["group"] = np.select(
    [
        # HLADR+ CD8+
        (adata_plot.obs["HLA-DR_protein"] > hladr_thresh)
        & (adata_plot.obs["CD8_protein"] > cd8_thresh),
        # HLADR- CD8+
        (adata_plot.obs["HLA-DR_protein"] <= hladr_thresh)
        & (adata_plot.obs["CD8_protein"] > cd8_thresh),
        # HLADR+ CD8-
        (adata_plot.obs["HLA-DR_protein"] > hladr_thresh)
        & (adata_plot.obs["CD8_protein"] <= cd8_thresh),
    ],
    ["HLADR+ CD8+", "HLADR- CD8+", "HLADR+ CD8-"],
    default="Other",
)

In [None]:
# Number of subsampled cells in each group label.
adata_plot.obs["group"].value_counts()

In [None]:
# Latent-space plot of the subsample, coloured by the four group labels.
sc.pl.embedding(
    adata_plot,
    basis="X_totalVI",
    color="group",
    size=20,
    alpha=0.75,
    title="TotalVI latent space: immune activation states"
)

In [None]:
# The three named groups; cells labelled "Other" are left out.
groups_to_show = ["HLADR+ CD8+", "HLADR- CD8+", "HLADR+ CD8-"]

# Independent copy holding only the cells in those groups.
adata_focus = adata_plot[adata_plot.obs["group"].isin(groups_to_show)].copy()

# Same latent-space plot, restricted to the three groups.
sc.pl.embedding(
    adata_focus,
    basis="X_totalVI",
    color="group",
    size=30,
    alpha=0.9,
    title="TotalVI latent space (activation-relevant groups only)"
)

In [None]:
# Second name for the same DataFrame object (no copy is made), so changes made
# through either name are visible through both.
rna_matrix = rna_genes_subset

In [None]:
# Show the current row labels of the raw ADT table.
print(adt_matrix.index.tolist())

In [None]:
# Show the protein names recorded on adata, for comparison with the ADT table.
print(adata.uns["protein_names"])

In [None]:
# Shape and first rows of the raw ADT table before it is cleaned.
print(adt_matrix.shape)
adt_matrix.head()

In [None]:
# Row labels to assign to the ADT table, in row order.
# NOTE(review): the names are hard-coded; presumably they match the protein
# names stored in the table's first column — verify the order against it.
adt_protein_names = [
    'CD14', 'CD16', 'CD19', 'CD27', 'CD3', 'CD4',
    'CD45RA', 'CD62L', 'CD8', 'HLA-DR', 'IgM'
]

# Remove the first row (it contains barcodes, not protein counts).
# The drop only happens while the table still has exactly one more row than
# there are protein names, so re-running this cell cannot strip a real protein
# row after the header has already been removed.
if adt_matrix.shape[0] == len(adt_protein_names) + 1:
    adt_matrix = adt_matrix.iloc[1:, :]

# Now assign protein names
adt_matrix.index = adt_protein_names

In [None]:
# Confirm the shape and the new row labels after cleaning.
print(adt_matrix.shape)
print(adt_matrix.index)

In [None]:
# Make sure RNA and ADT are numeric
# errors="coerce" turns any entry that cannot be parsed as a number into NaN
# instead of raising.
rna_hladr = pd.to_numeric(rna_matrix.loc["HLA-DRA"], errors="coerce")
adt_hladr = pd.to_numeric(adt_matrix.loc["HLA-DR"], errors="coerce")

# Sanity check
# dtypes of both rows, then how many entries became NaN in each.
print(rna_hladr.dtype, adt_hladr.dtype)
print(rna_hladr.isna().sum(), adt_hladr.isna().sum())

In [None]:
# Compare the number of entries in the two rows before aligning them.
print("Length ADT:", len(adt_hladr))
print("Length RNA:", len(rna_hladr))

In [None]:
# Drop the first ADT value to align with RNA.
# The row is rebuilt from adt_matrix on every run instead of slicing the
# previous adt_hladr, so re-running this cell always yields the same series
# rather than dropping one more value each time.
adt_hladr = pd.to_numeric(adt_matrix.loc["HLA-DR"], errors="coerce").iloc[1:]

In [None]:
# Re-check the lengths; they should now be equal.
print("Length ADT:", len(adt_hladr))
print("Length RNA:", len(rna_hladr))

In [None]:
# Bare numeric arrays of both rows (labels are discarded by .values).
adt_vals = pd.to_numeric(adt_hladr, errors="coerce").values
rna_vals = pd.to_numeric(rna_hladr, errors="coerce").values

# Pair the two arrays position by position and drop pairs with a missing value.
# NOTE(review): pairing is positional, so it assumes the ADT and RNA columns
# list the cells in the same order — confirm.
hladr_df = pd.DataFrame({
    "ADT_HLADR": adt_vals,
    "RNA_HLADRA": rna_vals
}).dropna()

print(hladr_df.shape)

In [None]:
# One point per cell: HLA-DR protein value (x) against HLA-DRA RNA value (y).
plt.figure(figsize=(5,5))
plt.scatter(
    hladr_df["ADT_HLADR"],
    hladr_df["RNA_HLADRA"],
    alpha=0.4
)
plt.xlabel("HLA-DR protein (ADT)")
plt.ylabel("HLA-DRA RNA")
plt.title("RNA–Protein concordance for HLA-DR")
plt.show()

In [None]:
from scipy.stats import spearmanr

# Rank correlation between the paired protein and RNA values, with its p-value.
rho, p = spearmanr(hladr_df["ADT_HLADR"], hladr_df["RNA_HLADRA"])
print("Spearman correlation:", rho, "p-value:", p)

In [None]:
# Assign the protein names as row labels of adt_matrix.
# (Repeats the labelling done when the barcode row was removed; pandas raises
# if adt_matrix does not have exactly 11 rows at this point.)
adt_matrix.index = [
    'CD14', 'CD16', 'CD19', 'CD27', 'CD3', 'CD4',
    'CD45RA', 'CD62L', 'CD8', 'HLA-DR', 'IgM'
]

In [None]:
# RNA vs ADT correlation for HLA-DR
# Both rows are converted to numbers, and the ADT row has its first entry
# dropped. The raw adt_matrix row carries one leading entry that is not a cell
# measurement (the notebook drops it before comparing with RNA), which would
# otherwise leave the x and y series with different lengths.
rna_hladr = pd.to_numeric(rna_matrix.loc["HLA-DRA"], errors="coerce")
adt_hladr = pd.to_numeric(adt_matrix.loc["HLA-DR"], errors="coerce").iloc[1:]

# One point per cell, paired by position: protein value (x) vs RNA value (y).
plt.figure(figsize=(5,5))
plt.scatter(adt_hladr.values, rna_hladr.values, alpha=0.4)
plt.xlabel("HLA-DR protein (ADT)")
plt.ylabel("HLA-DRA RNA")
plt.title("RNA–Protein concordance for HLA-DR")
plt.show()

In [None]:
# scanpy's `color` argument takes names of obs columns / variables, not an
# array of values, so the CD8 protein values are stored as an obs column
# (re-assigned here so the cell works on its own) and referenced by name.
adata_plot.obs["CD8_protein"] = adata_plot.obsm["protein_expression"][:, cd8_idx]

# Latent-space plot of the subsample, coloured by the CD8 protein value.
sc.pl.embedding(
    adata_plot,
    basis="X_totalVI",
    color="CD8_protein",
    size=20,
    alpha=0.8,
    title="CD8 protein (subsampled)"
)

In [None]:
# Latent-space plot of the subsample, coloured by the group labels stored in
# obs["group"].
sc.pl.embedding(
    adata_plot,
    basis="X_totalVI",
    color="group",
    size=20,
    alpha=0.8,
    title="Activation groups (subsampled)"
)