In [2]:
import numpy as np
from itertools import combinations
import scanpy as sc
import random
import copy

In [3]:
file_path = "//Users/apple/Desktop/KB/data/LarryData/Larry_41201_2000.h5ad"
adata = sc.read_h5ad(file_path)
count_matrix = adata.X
count_matrix = count_matrix.toarray()
cell_lineage = adata.obs['clone_id'].values.reshape(-1, 1)

In [4]:
def generate_lineage_info(lineage):
    """
    return a dictionary with key of the unqiue element of the lineage and the elements a list contains the index.
    """
    lineage_info = {}
    unique_elements = np.unique(lineage)
    
    for element in unique_elements:
        lineage_info[element] = np.where(lineage == element)[0].tolist()
    
    return lineage_info

def generate_avail_lineage_pairs(lineage_info, size_factor):
    """
    return the combinations of the j-choose-2, with the lenghth of size_factor*j*(j-1)/2
    """
    avail_lineage_pairs = {}
    
    for key, indices in lineage_info.items():
        j = len(indices)
        pair_count = int(size_factor * (j * (j - 1) / 2))
        pairs = list(combinations(indices, 2))
        
        # Shuffle and select required number of pairs
        np.random.shuffle(pairs)
        avail_lineage_pairs[key] = pairs[:pair_count]
    
    return avail_lineage_pairs

def generate_avail_lineage_pairs(lineage_info, size_factor):
    """
    Return the combinations of the j-choose-2, with the length of size_factor*j*(j-1)/2,
    ensuring that all elements are used at least once.
    """
    avail_lineage_pairs = {}
    
    for key, indices in lineage_info.items():
        j = len(indices)
        all_pairs = list(combinations(indices, 2))
        pair_count = int(size_factor * (j * (j - 1) / 2))
        
        # Ensure all elements are used at least once
        used_indices = set()
        essential_pairs = []
        
        for pair in all_pairs:
            if pair[0] not in used_indices or pair[1] not in used_indices:
                essential_pairs.append(pair)
                used_indices.update(pair)
                print(used_indices)
        
        # Shuffle the remaining pairs and select the required number to fill up to pair_count
        remaining_pairs = [pair for pair in all_pairs if pair not in essential_pairs]
        random.shuffle(remaining_pairs)
        additional_pairs = remaining_pairs[:max(0, pair_count - len(essential_pairs))]
        
        avail_lineage_pairs[key] = essential_pairs + additional_pairs
    
    return avail_lineage_pairs


def get_min_max_length(lineage_info):
    lengths = [len(indices) for indices in lineage_info.values()]
    min_length = min(lengths)
    max_length = max(lengths)
    mean_length = round(sum(lengths)/len(lengths),2)

    print(f"the range of number of cells in a lineage: {min_length, max_length}, average of number of cells in a lineage {mean_length}")




In [5]:
def generate_batch_all_index(avail_lineage_pairs, batch_size):
    avail_lineage_pairs_cp = copy.deepcopy(avail_lineage_pairs)
    batch_all_index = {}
    i = 0
    
    while len(avail_lineage_pairs_cp.keys()) != 0:
        batch_all_index[i] = []
        if len(avail_lineage_pairs_cp.keys()) >= batch_size:
            selected_keys = random.sample(list(avail_lineage_pairs_cp.keys()), batch_size)
        else:
            selected_keys = list(avail_lineage_pairs_cp.keys())
        
        for key in selected_keys:
            selected_tuple = random.choice(avail_lineage_pairs_cp[key])
            batch_all_index[i].append(selected_tuple)
            avail_lineage_pairs_cp[key].remove(selected_tuple)
            if not avail_lineage_pairs_cp[key]:
                del avail_lineage_pairs_cp[key]
        
        if len(selected_keys) < batch_size:
            complement_keys = list(set(avail_lineage_pairs.keys()) - set(avail_lineage_pairs_cp.keys()))
            remaining_keys = random.sample(complement_keys, batch_size - len(selected_keys))
            for key in remaining_keys:
                selected_tuple = random.choice(avail_lineage_pairs[key])
                batch_all_index[i].append(selected_tuple)
        
        i += 1
    
    return batch_all_index



In [6]:
def generate_batch_all(batch_all_index, count_matrix):
    batch_all = {}
    for key, pairs in batch_all_index.items():
        batch_all[key] = [(count_matrix[pair[0]], count_matrix[pair[1]]) for pair in pairs]
    return batch_all

def generate_lineage_array(batch_all_index, lineage):
    batch_size = len(next(iter(batch_all_index.values())))
    m = len(batch_all_index) * batch_size
    lineage_array = np.zeros((m, 1), dtype=int)  # Specify dtype as int
    index = 0
    for pairs in batch_all_index.values():
        for pair in pairs:
            lineage_array[index] = lineage[pair[0]]
            index += 1
    return lineage_array

In [7]:
# Example usage
lineage = np.array([1, 2, 1, 2, 3, 1, 3, 2,1, 2, 1, 2,1, 2, 1, 2,3, 1, 3,3, 1, 3])
size_factor = 0.5
batch_size = 2

lineage_info = generate_lineage_info(lineage)
print("Lineage Info:", lineage_info)


Lineage Info: {1: [0, 2, 5, 8, 10, 12, 14, 17, 20], 2: [1, 3, 7, 9, 11, 13, 15], 3: [4, 6, 16, 18, 19, 21]}


In [8]:
avail_lineage_pairs = generate_avail_lineage_pairs(lineage_info, size_factor)
print("Available Lineage Pairs:", avail_lineage_pairs)

{0, 2}
{0, 2, 5}
{0, 8, 2, 5}
{0, 2, 5, 8, 10}
{0, 2, 5, 8, 10, 12}
{0, 2, 5, 8, 10, 12, 14}
{0, 2, 5, 8, 10, 12, 14, 17}
{0, 2, 5, 8, 10, 12, 14, 17, 20}
{1, 3}
{1, 3, 7}
{1, 3, 9, 7}
{1, 3, 7, 9, 11}
{1, 3, 7, 9, 11, 13}
{1, 3, 7, 9, 11, 13, 15}
{4, 6}
{16, 4, 6}
{16, 18, 4, 6}
{4, 6, 16, 18, 19}
{4, 6, 16, 18, 19, 21}
Available Lineage Pairs: {1: [(0, 2), (0, 5), (0, 8), (0, 10), (0, 12), (0, 14), (0, 17), (0, 20), (2, 5), (8, 12), (2, 12), (2, 8), (10, 17), (2, 14), (14, 17), (12, 14), (12, 20), (14, 20)], 2: [(1, 3), (1, 7), (1, 9), (1, 11), (1, 13), (1, 15), (11, 13), (9, 13), (9, 11), (3, 11)], 3: [(4, 6), (4, 16), (4, 18), (4, 19), (4, 21), (6, 21), (6, 19)]}


In [9]:
lineage[8]

1

In [10]:
batch_all_index = generate_batch_all_index(avail_lineage_pairs, batch_size)
print("Batch All Index:", batch_all_index)

Batch All Index: {0: [(1, 15), (4, 18)], 1: [(9, 11), (0, 17)], 2: [(10, 17), (1, 11)], 3: [(11, 13), (4, 6)], 4: [(0, 12), (1, 7)], 5: [(6, 19), (0, 5)], 6: [(6, 21), (9, 13)], 7: [(2, 12), (4, 19)], 8: [(0, 8), (4, 16)], 9: [(2, 14), (4, 21)], 10: [(0, 20), (1, 13)], 11: [(12, 14), (1, 3)], 12: [(3, 11), (14, 17)], 13: [(1, 9), (2, 5)], 14: [(0, 14), (1, 15)], 15: [(0, 2), (9, 11)], 16: [(14, 20), (1, 13)], 17: [(12, 20), (4, 16)], 18: [(0, 10), (1, 3)], 19: [(8, 12), (1, 15)], 20: [(2, 8), (10, 17)]}


In [11]:
len(batch_all_index.keys())

21

In [12]:
x = generate_lineage_array(batch_all_index, lineage)

array([[2],
       [3],
       [2],
       [1],
       [1],
       [2],
       [2],
       [3],
       [1],
       [2],
       [3],
       [1],
       [3],
       [2],
       [1],
       [3],
       [1],
       [3],
       [1],
       [3],
       [1],
       [2],
       [1],
       [2],
       [2],
       [1],
       [2],
       [1],
       [1],
       [2],
       [1],
       [2],
       [1],
       [2],
       [1],
       [3],
       [1],
       [2],
       [1],
       [2],
       [1],
       [1]])

In [35]:
lineage_info = generate_lineage_info(cell_lineage)

In [36]:
len(lineage_info.keys())

2817

In [17]:
get_min_max_length(lineage_info)

the range of number of cells in a lineage: (5, 177), average of number of cells in a lineage 14.63


In [18]:
avail_lineage_pairs = generate_avail_lineage_pairs(lineage_info, size_factor)

In [19]:
get_min_max_length(avail_lineage_pairs)

the range of number of cells in a lineage: (5, 7788), average of number of cells in a lineage 105.07


In [24]:
batch_all_index = generate_batch_all_index(avail_lineage_pairs, 30)

In [25]:
len(batch_all_index.keys())

14840

In [29]:
adata.obs.shape

(41201, 9)

In [60]:
(33*32/4)*15

3960.0