In [None]:
# reload imported modules (notably the local tcrnet package) before each cell runs,
# so edits to tcrnet/*.py are picked up without restarting the kernel
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
import tcrdist
import matplotlib.pyplot as plt
from tcrdist.repertoire import TCRrep
from tcrdist.html_colors import get_html_colors
from tcrdist.public import _neighbors_fixed_radius
import pwseqdist as pw
import networkx as nx
import community as community_louvain
from itertools import combinations

In [None]:
import sys
# make the local `tcrnet` package (one directory up) importable; the membership check
# keeps sys.path from growing a duplicate entry every time this cell is re-run
if '../' not in sys.path:
    sys.path.append('../')

In [None]:
from tcrnet.process import (
    standardize_tcr_data, 
    preprocess_tcr_data, 
    compute_clonotype_abundances
)
from tcrnet.visualize import (
    chain_pairing_configurations,
    sequence_length_distributions,
    clonotype_abundances,
    top_n_clonotypes,
    generate_network_plot
)
from tcrnet.networks import similarity, graph, cluster, metrics

In [None]:
# path to your TCR data
# set the TCRNET_DATA_DIR environment variable to point at your local data root; the fallback
# keeps the original machine-specific location so existing runs are unchanged
import os
from pathlib import Path

SAMPLE_ID = 'COVID_01'
# which CDRs define a clonotype (passed to the preprocessing / abundance helpers)
clonotype_definition = ['cdr1', 'cdr2', 'cdr3']
DATA_DIR = Path(os.environ.get('TCRNET_DATA_DIR', '/Users/alaa/Documents/ucsf/data'))
tcr_filepath = str(DATA_DIR / 'tcrnet/10x/huati_gi_rr3692_01/vdj_t/filtered_contig_annotations.csv')
# tcr_filepath = str(DATA_DIR / 'sars2/tcr/SEVERE2_GSM4385994_C145_filtered_contig_annotations.csv.gz')

In [None]:
# load TCR data and standardize the format
tcr_df = standardize_tcr_data(tcr_filepath=tcr_filepath, 
                              technology_platform='10X',
#                               compression='gzip'
                             )

In [None]:
tcr_df.shape

In [None]:
tcr_df

In [None]:
# generate QC plot showing the different alpha-beta pairing configurations in the data
chain_pairing_configurations(tcr_df=tcr_df, clonotype_definition=clonotype_definition)

In [None]:
# preprocess TCR data (chain pairing, QC, and clonotype definition)
ptcr_df = preprocess_tcr_data(tcr_df=tcr_df, sample_id=SAMPLE_ID, clonotype_definition=clonotype_definition)

In [None]:
# ptcr_df

In [None]:
# compute clonotype abundances (absolute counts and relative frequencies)
qtcr_df = compute_clonotype_abundances(processed_tcr_df=ptcr_df, clonotype_definition=clonotype_definition)

In [None]:
# generate panel of bar plots showing sequence length distribution across complementarity determining regions
sequence_length_distributions(tcr_df=qtcr_df)

In [None]:
# generate histogram of clonotype abundances (most will likely have count = 1)
clonotype_abundances(tcr_df=qtcr_df)

In [None]:
# it is usually more helpful to look at clonotype abundances for clonotypes with counts > 1
clonotype_abundances(tcr_df=qtcr_df.loc[qtcr_df['num_records']>1])

In [None]:
# visualize the top clonotypes by relative abundance
top_n_clonotypes(tcr_df=qtcr_df, top_n=13)

## Networks

### Similarity matrix, graph construction, clustering, and metrics

In [None]:
# --- network analysis parameters ---
edge_threshold = 150            # tcrdist cut-off for connecting two clonotypes (passed as edge_threshold)
clonotype_count_threshold = 2   # forwarded to generate_graph_dataframe(count_threshold=...)
analysis_mode = 'private'       # forwarded to generate_graph_dataframe(analysis_mode=...)
top_k_clusters = 9              # how many of the largest clusters receive their own color / metrics

In [None]:
print(qtcr_df.shape)
qtcr_df.head()

In [None]:
# compute distance matrix
ntcr_df, distance_matrix = similarity.compute_tcrdist(qtcr_df)

In [None]:
type(distance_matrix)

In [None]:
network_df = graph.generate_graph_dataframe(ntcr_df=ntcr_df, 
                                            distance_matrix=distance_matrix,
                                            analysis_mode=analysis_mode,
                                            edge_threshold=edge_threshold,
                                            count_threshold=clonotype_count_threshold)
network_df

In [None]:
tcr_graph = graph.create_undirected_graph(net_df=network_df)

In [None]:
partition, cluster2color = cluster.cluster_lovain(net_df=network_df,
                                                  color_top_k_clusters=top_k_clusters)

In [None]:
# 
cluster2color

In [None]:
network_df = graph.update_df_with_cluster_information(net_df=network_df, partition=partition)

In [None]:
net_metrics = metrics.compute_network_metrics(net_df=network_df, 
                                              graph=tcr_graph, 
                                              top_k_clusters=top_k_clusters)

In [None]:
net_metrics

In [None]:
generate_network_plot(graph=tcr_graph, 
                      network_metrics=net_metrics, 
                      partition=partition, 
                      colors=cluster2color)

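# DEV: step-by-step reimplementation of the network analysis (tcrdist distances, edge list, Louvain clustering, connectivity metrics, plot)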

In [None]:
# qtcr_df

In [None]:
# rename the abundance-table columns to tcrdist-style names (cdr*_{a,b}_aa, count, clone_id),
# tag every clonotype with this sample, and keep only the columns the distance step needs
target_columns = [
    'count', 'frequency',
    'cdr1_b_aa', 'cdr2_b_aa', 'cdr3_b_aa',
    'cdr1_a_aa', 'cdr2_a_aa', 'cdr3_a_aa',
    'clone_id', 'sample_id',
]
ntcr_df = (
    qtcr_df
    .rename(columns={
        'num_records': 'count',
        'pct_records': 'frequency',
        'beta_cdr1': 'cdr1_b_aa',
        'beta_cdr2': 'cdr2_b_aa',
        'beta_cdr3': 'cdr3_b_aa',
        'alpha_cdr1': 'cdr1_a_aa',
        'alpha_cdr2': 'cdr2_a_aa',
        'alpha_cdr3': 'cdr3_a_aa',
        'clonotype_id': 'clone_id',
    })
    .assign(sample_id=SAMPLE_ID)
    .loc[:, target_columns]
    .copy()
)

In [None]:
# assign distance metrics and weighting for each TCR sequence
beta_metrics = {
    "cdr3_b_aa": pw.metrics.nb_vector_tcrdist,
    "cdr2_b_aa": pw.metrics.nb_vector_tcrdist,
    "cdr1_b_aa": pw.metrics.nb_vector_tcrdist,
    "cdr3_a_aa": pw.metrics.nb_vector_tcrdist,
    "cdr2_a_aa": pw.metrics.nb_vector_tcrdist,
    "cdr1_a_aa": pw.metrics.nb_vector_tcrdist
}

beta_weights = {
    "cdr3_b_aa": 3,
    "cdr2_b_aa": 1,
    "cdr1_b_aa": 1,
    "cdr3_a_aa": 3,
    "cdr2_a_aa": 1,
    "cdr1_a_aa": 1
}
beta_kargs = {
    "cdr3_b_aa": {"use_numba": True},
    "cdr2_b_aa": {"use_numba": True},
    "cdr1_b_aa": {"use_numba": True},
    "cdr3_a_aa": {"use_numba": True},
    "cdr2_a_aa": {"use_numba": True},
    "cdr1_a_aa": {"use_numba": True}
}

dist_mat = tcrdist.rep_funcs._pws(df = ntcr_df,
                       metrics = beta_metrics,
                       weights = beta_weights,
                       kargs = beta_kargs,
                       cpu = 5,
                       uniquify = True,
                       store = True)

In [None]:
dist_mat

In [None]:
ntcr_df['frequency'].min(), ntcr_df['frequency'].max(), ntcr_df['frequency'].median(), ntcr_df['frequency'].mean()

In [None]:
# SPECIFY distance threshold for creating edges between TCR clonotypes
edge_threshold = 150
# minimum clonotype *count* (compared against the 'count' column despite the name) for a
# clonotype without neighbors to be kept as an island
frequency_threshold = 2
# for each clonotype, the indices of clonotypes within edge_threshold tcrdist
# (assumed to include the clonotype itself; the island branch below relies on that — TODO confirm)
x = _neighbors_fixed_radius(dist_mat['tcrdist'], edge_threshold)

# pull per-clonotype attributes out of the dataframe once; repeated Series.iloc lookups
# inside the double loop below are slow
_tcr_distances = dist_mat['tcrdist']
_cdr_fields = ['cdr1_b_aa', 'cdr2_b_aa', 'cdr3_b_aa', 'cdr1_a_aa', 'cdr2_a_aa', 'cdr3_a_aa']
_id_fields = ['clone_id', 'sample_id']
_field_values = {field: ntcr_df[field].to_numpy() for field in _cdr_fields + _id_fields}
_counts = ntcr_df['count'].to_numpy()


def _edge_record(i, j, k_neighbors, is_island):
    """Build one network record for the clonotype pair (i, j).

    Field order: node ids, distance, the (i, j) value of each CDR column,
    the neighbor count of node i, the island flag, then the (i, j) clone and sample ids.
    """
    record = [i, j, _tcr_distances[i, j]]
    for field in _cdr_fields:
        record.extend((_field_values[field][i], _field_values[field][j]))
    record.extend((k_neighbors, is_island))
    for field in _id_fields:
        record.extend((_field_values[field][i], _field_values[field][j]))
    return tuple(record)


# populate our network with nodes as clonotypes, and edges as similarity-distance between them
network = list()
for n1_idx, n1_neighbors in enumerate(x):
    for n2_idx in n1_neighbors:
        if n1_idx != n2_idx:
            # regular edge to a distinct neighbor
            network.append(_edge_record(n1_idx, n2_idx, len(n1_neighbors), False))
        elif _counts[n1_idx] >= frequency_threshold and len(n1_neighbors) == 1:
            # self-pair of a clonotype whose only neighbor is itself: keep it as an island
            # when its count reaches frequency_threshold
            network.append(_edge_record(n1_idx, n2_idx, len(n1_neighbors), True))

In [None]:
# number of edge records plus a short preview; displaying the whole list dumps every record
len(network), network[:5]

In [None]:
# create a dataframe representation of our network graph
network_columns = ['node_1', 'node_2', 'distance', 
                   'cdr1_b_aa_1', 'cdr1_b_aa_2',
                   'cdr2_b_aa_1', 'cdr2_b_aa_2',
                   'cdr3_b_aa_1', 'cdr3_b_aa_2',
                   'cdr1_a_aa_1', 'cdr1_a_aa_2',
                   'cdr2_a_aa_1', 'cdr2_a_aa_2',
                   'cdr3_a_aa_1', 'cdr3_a_aa_2',
                   'k_neighbors', 'is_island',
                   'clone_id_1', 'clone_id_2',
                   'sample_id_1', 'sample_id_2']
network_df = pd.DataFrame(network, columns = network_columns)

In [None]:
network_df

In [None]:
# perform network analysis on a single subject (patient ID)
# every clonotype in ntcr_df was tagged with SAMPLE_ID, so the subject filter must use that ID;
# hard-coding an ID absent from sample_id (e.g. 'huati_06') yields an empty subnetwork
subject_id = SAMPLE_ID
top_k_clusters = 9
# edge weight: 1 for identical clonotypes (distance 0), falling linearly to 0 at edge_threshold
network_df['weight'] = (edge_threshold - network_df['distance']) / edge_threshold
# a connection is 'private' within one sample and 'public' between two different samples
network_df['relation'] = np.where(network_df['sample_id_1'] != network_df['sample_id_2'],
                                  'public', 'private')
# keep only this subject's within-sample connections
subnetwork_df = network_df.loc[(network_df['sample_id_1'] == subject_id)
                               & (network_df['relation'] == 'private')].copy()
assert not subnetwork_df.empty, f"no private connections found for subject {subject_id!r}"

In [None]:
# create an undirected, weighted graph from the subject's non-island connections
# NOTE: this rebinds `graph`, shadowing the `tcrnet.networks.graph` module imported at the top;
# re-run the import cell before re-running the tcrnet pipeline cells above
non_island_df = subnetwork_df.loc[~subnetwork_df['is_island'].astype(bool)]
graph = nx.from_pandas_edgelist(
    pd.DataFrame({
        'source': non_island_df['node_1'],
        'target': non_island_df['node_2'],
        'weight': non_island_df['weight'],
    }),
    # without edge_attr, from_pandas_edgelist drops the 'weight' column and Louvain /
    # spring_layout would treat every edge as having weight 1
    edge_attr='weight',
)
# perform unsupervised (Louvain) clustering on the weighted graph; fixed seed for reproducibility
partition = community_louvain.best_partition(graph, random_state=42)
# relabel clusters by size rank (0 = largest); ties are broken by the original label so the
# relabelling does not depend on how value_counts orders equal counts
size_by_label = pd.Series(list(partition.values())).value_counts()
partitions_by_cluster_size = [label for label, _ in sorted(size_by_label.items(),
                                                             key=lambda item: (-item[1], item[0]))]
partition_reorder = {label: rank for rank, label in enumerate(partitions_by_cluster_size)}
partition = {node: partition_reorder[label] for node, label in partition.items()}
# after relabelling, the top-K clusters are exactly ranks 0..K-1 (fewer if the graph has fewer clusters)
clusters = list(range(min(top_k_clusters, len(partition_reorder))))
# assign a color for each of the top K clusters
colors = get_html_colors(top_k_clusters)
cluster2color = {clust: color for clust, color in zip(clusters, colors)}

In [None]:
cols_of_interest = ['node_1', 'node_2', 'distance',
                    'clone_id_1', 'clone_id_2', 'cluster_1', 'cluster_2']
# label both endpoints of every connection with their cluster; nodes that were not clustered
# (islands were left out of the graph) get None
subnetwork_df['cluster_1'] = subnetwork_df['node_1'].apply(partition.get)
subnetwork_df['cluster_2'] = subnetwork_df['node_2'].apply(partition.get)
# number of distinct nodes per cluster, measured from each side of the edge list
cluster1_sizes = (subnetwork_df.groupby('cluster_1')['node_1']
                  .nunique()
                  .rename('cluster1_size')
                  .reset_index())
cluster2_sizes = (subnetwork_df.groupby('cluster_2')['node_2']
                  .nunique()
                  .rename('cluster2_size')
                  .reset_index())
# attach both cluster sizes to each connection (inner joins drop unclustered rows)
subnetwork_df2 = (subnetwork_df
                  .merge(cluster1_sizes, on='cluster_1')
                  .merge(cluster2_sizes, on='cluster_2'))
subnetwork_df2[cols_of_interest].head()

In [None]:
# subnetwork_df2

In [None]:
# add island nodes to the graph for visualization
for i, node_island in subnetwork_df.loc[subnetwork_df['is_island']==True].iterrows():
#     print(node_island['node_1'])
    graph.add_node(node_island['node_1'])

In [None]:
# nx_kwargs = {"edgecolors": "tab:gray", "node_size": 50}
node_positions = nx.spring_layout(graph, seed=42, k=.15)

In [None]:
subnetwork_df2.head()

In [None]:
network_df = subnetwork_df2.copy()

In [None]:
top_k_clusters

In [None]:
# preview every pair of top clusters compared by the inter-cluster connectivity below;
# built from top_k_clusters directly because `clusters_of_interest` is only defined in the
# next cell, so referencing it here raised NameError on a fresh kernel
for cluster_combination in combinations(range(top_k_clusters), 2):
    print(cluster_combination)

In [None]:
# --- connectivity metrics for the subject's network ---
# network_df holds the subject's clustered connections, one row per (node_1, node_2) record;
# neighbor lists are assumed symmetric, so each undirected edge appears in both directions
n = len(graph.nodes)
edge_rows = network_df.loc[network_df['node_1'] != network_df['node_2']]
# weighted density: total edge weight over the n*(n-1) ordered node pairs (0 with < 2 nodes)
network_density = edge_rows['weight'].sum() / (n * (n - 1)) if n > 1 else 0.0
clusters_of_interest = list(range(0, top_k_clusters))
intercon_df = (network_df
               .loc[(network_df['cluster_1'].isin(clusters_of_interest))
                   &(network_df['cluster_2'].isin(clusters_of_interest))
                   &(network_df['cluster_1']!=network_df['cluster_2'])]
               .copy())

# intra-cluster connectivity: one value per cluster. (Computing it inside the pair loop
# below measured only the first cluster of each pair, so the last cluster was never
# measured and cluster 0 was counted top_k_clusters - 1 times.)
intracluster_connectivity = []
for cluster_id in clusters_of_interest:
    unicluster_df = network_df.loc[(network_df['cluster_1'] == cluster_id)
                                   & (network_df['cluster_2'] == cluster_id)
                                   & (network_df['node_1'] != network_df['node_2'])]
    cluster_size_1 = unicluster_df['cluster1_size'].mean()
    cluster_size_2 = unicluster_df['cluster2_size'].mean()
    assert np.nan_to_num(cluster_size_1)==np.nan_to_num(cluster_size_2), "Error: unexpected non-equality"
    intra_numerator = unicluster_df['weight'].sum()
    # empty clusters give a NaN mean -> NaN ratio, zeroed by nan_to_num below
    intra_denominator = np.maximum(1, cluster_size_1 * (cluster_size_2 - 1))
    intracluster_connectivity.append(intra_numerator / intra_denominator)

# inter-cluster connectivity: one value per unordered pair of top clusters; one edge
# direction suffices because each undirected edge is stored in both directions
intercluster_connectivity = []
for cluster_a, cluster_b in combinations(clusters_of_interest, 2):
    bicluster_df = intercon_df.loc[(intercon_df['cluster_1'] == cluster_a)
                                   & (intercon_df['cluster_2'] == cluster_b)]
    size_a = bicluster_df['cluster1_size'].mean()
    size_b = bicluster_df['cluster2_size'].mean()
    inter_numerator = bicluster_df['weight'].sum() ** 2
    # NOTE(review): this denominator vanishes when both clusters have the same size and is
    # then floored to 1 — confirm this is the intended normalisation
    inter_denominator = size_a * size_b * np.abs(size_b - size_a) ** 2
    intercluster_connectivity.append(np.sqrt(inter_numerator / np.maximum(1., inter_denominator)))
intracluster_connectivity = np.nan_to_num(intracluster_connectivity)
intercluster_connectivity = np.nan_to_num(intercluster_connectivity)

In [None]:
intracluster_connectivity.shape

In [None]:
intracluster_connectivity

In [None]:
print(f"Network density for sample {subject_id}: {network_density:.6f}")
print(f"Network intra-cluster density for sample {subject_id}: {intracluster_connectivity.mean():.6f}")
print(f"Network inter-cluster density for sample {subject_id}: {intercluster_connectivity.mean():.6f}")

In [None]:
type(graph)

In [None]:
# partition

In [None]:
def generate_network_plot(graph: nx.classes.graph.Graph,
                          partition: dict,
                          colors: dict,
                          pos: dict = None,
                          title: str = None,
                          dpi: int = 550):
    """Draw the clonotype network with nodes colored by cluster and each colored cluster labelled.

    NOTE: this DEV helper shadows the `generate_network_plot` imported from
    `tcrnet.visualize` at the top of the notebook (that one takes `network_metrics`).

    Parameters
    ----------
    graph : networkx Graph to draw.
    partition : node -> cluster id; nodes missing from it (e.g. islands) are drawn grey.
    colors : cluster id -> color; every cluster in it is also labelled at the centroid
        of its nodes.
    pos : node -> (x, y) layout. Defaults to the notebook-level `node_positions`.
    title : figure title. Defaults to a summary built from the notebook-level `SAMPLE_ID`,
        `network_density`, `intracluster_connectivity` and `intercluster_connectivity`.
    dpi : resolution of this figure (no longer written into the global rcParams).
    """
    if pos is None:
        pos = node_positions
    # size and resolution apply to this figure only
    _, ax = plt.subplots(figsize=(12, 10), dpi=dpi)

    nodes = list(graph.nodes)
    nx.draw(graph,
            nodelist=nodes,
            pos=pos,
            ax=ax,
            node_color=[colors.get(partition.get(node), 'grey') for node in nodes],
            node_shape='o',
            node_size=50,
            with_labels=False,
            edgecolors='tab:gray')

    # label each colored cluster at the centroid of its nodes' positions
    for cluster_id in colors:
        cluster_nodes = [node for node, part in partition.items() if part == cluster_id and node in pos]
        if not cluster_nodes:
            continue  # a color with no nodes to label (cluster absent from this graph/layout)
        xs, ys = zip(*(pos[node] for node in cluster_nodes))
        ax.text(sum(xs) / len(xs), sum(ys) / len(ys), f'{cluster_id}', fontsize=11,
                color='black', ha='center', va='center', fontweight='bold')

    if title is None:
        # cluster index range shown in the metric symbols follows the number of colored clusters
        cluster_range = f'[0:{max(len(colors) - 1, 0)}]'
        density_symbol = r'$D_{total}$'
        intra_symbol = rf'$\overline{{S}}(C_{{x}})_{{x \in {cluster_range}}}$'
        inter_symbol = rf'$\overline{{S}}(C_{{x}}, C_{{y}})_{{x,y \in {cluster_range}, x \neq y}}$'
        title = (f"Clonotypes in {SAMPLE_ID}\n"
                 f"Colored by cluster (top {len(colors)} clusters)\n"
                 f"{density_symbol} = {network_density:.6f}\n"
                 f"{intra_symbol} = {np.mean(intracluster_connectivity):.4G}\n"
                 f"{inter_symbol} = {np.mean(intercluster_connectivity):.4G}")
    ax.set_title(title)
    plt.show()