In [2]:

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor, Pad
import os
import pickle
import preprocessing as pre
import plotly.express as px
import plotly.graph_objects as go
from tqdm import tqdm
import seaborn as sns
from pathlib import Path


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Base directory and the dataset folder names that can be joined onto it.
data_path = "/data/cb/ruochiz/scHiC/"
data_names = [
    "m3c_mouse_brain_small",
    "m3c_Tian_et_al",
    "m3C_hg38_M1C_old",
    "m3c_human_pfc_old",
    "m3c_Heffel_et_al",
]

# The first entry of the list is the dataset used below.
data_name = data_names[0]

# Folder of the selected dataset and its "contact_pairs_filter" sub-directory.
dataset_path = os.path.join(data_path, data_name)
dataset_hic_path = os.path.join(data_path, data_name, "contact_pairs_filter")

In [22]:
# with open(os.path.join(dataset_path, "filelist.txt")) as f:
#     lines = f.readlines()
# for i, line in enumerate(lines):
#     filename = line.split("/")[-1]
#     lines[i] = dataset_path+"/contact_pairs_filter/"+filename
# with open("/data/cb/mihirb14/projects/Daifuku/data/mc3_mouse_brain_small/filelist.txt", 'w') as f2:
#     f2.writelines(lines)


In [3]:



# example_contact = np.load("/data/cb/mihirb14/projects/Daifuku/data/mc3_mouse_brain_small/contactmaps/raw/chr1_sparse_adj.npy",allow_pickle=True)
# print(example_contact.shape, example_contact[0].shape)
# contact_map = pre.spy_sparse2torch_sparse(example_contact[0]).to_dense()
# # px.imshow(contact_map,  color_continuous_scale='deep_r')

In [8]:
# import math

# pseudobulk_path = config["temp_dir"] + "bulk/"
# length = len(small_mouse_hic_dataloader)

# pseudobulk_map = np.zeros(shape=(196,196))
# for i, map in enumerate(small_mouse_hic_dataloader):
#     if i%20 == 0 and i!=0:
#         np.save(os.path.join(pseudobulk_path,str(math.ceil(i/20))), pseudobulk_map/20)
#         pseudobulk_map = np.zeros(shape=(196,196))
#     pseudobulk_map += map.numpy().squeeze()

In [28]:
def diagonal_normalize(map):
    """Rescale every diagonal of a square contact map so that it sums to 1.

    For each offset ``k >= 0`` the k-th upper diagonal is divided by its own
    sum; a diagonal whose sum is 0 is left as zeros. Only the main diagonal
    and the upper triangle of the input are read; the result is the
    normalised upper triangle mirrored across the main diagonal.

    Args:
        map: Square 2-D array-like (anything ``np.asarray`` accepts).

    Returns:
        ``np.ndarray`` of dtype float64 with the same shape as ``map``.

    Raises:
        ValueError: If ``map`` is not a square 2-D matrix.
    """
    contact = np.asarray(map)
    if contact.ndim != 2 or contact.shape[0] != contact.shape[1]:
        raise ValueError(
            f"diagonal_normalize expects a square 2-D map, got shape {contact.shape}"
        )

    size = contact.shape[0]
    positions = np.arange(size)
    upper = np.zeros((size, size))
    for k in range(size):
        diag = np.diag(contact, k=k)
        diag_sum = np.sum(diag)
        if diag_sum != 0:
            # Write the diagonal in place instead of materialising a full
            # size x size matrix (np.diagflat) for every offset.
            rows = positions[: size - k]
            upper[rows, rows + k] = diag / diag_sum

    # Mirror the strict upper triangle; the main diagonal is kept once.
    return upper + np.triu(upper, k=1).T

def diagonal_unnormalize(original, new):
    """Scale each diagonal of ``new`` by the matching diagonal sum of ``original``.

    For each offset ``k >= 0`` the k-th upper diagonal of ``new`` is
    multiplied by the sum of the k-th upper diagonal of ``original``. Only the
    main diagonal and the upper triangle of both inputs are read; the result
    is mirrored across the main diagonal. Applied to the output of
    ``diagonal_normalize`` together with its input, this restores every
    diagonal whose sum was non-zero.

    Args:
        original: Square 2-D array-like providing the per-diagonal sums.
        new: Array-like of the same shape whose diagonals are rescaled.

    Returns:
        ``np.ndarray`` of dtype float64 with the same shape as ``new``.

    Raises:
        ValueError: If ``original`` is not a square 2-D matrix or ``new`` has
            a different shape.
    """
    reference = np.asarray(original)
    rescaled = np.asarray(new)
    if reference.ndim != 2 or reference.shape[0] != reference.shape[1]:
        raise ValueError(
            f"diagonal_unnormalize expects a square 2-D map, got shape {reference.shape}"
        )
    if rescaled.shape != reference.shape:
        raise ValueError(
            f"shape mismatch: original {reference.shape} vs new {rescaled.shape}"
        )

    size = reference.shape[0]
    positions = np.arange(size)
    upper = np.zeros((size, size))
    for k in range(size):
        original_diag_sum = np.sum(np.diag(reference, k=k))
        # Write the diagonal in place instead of materialising a full
        # size x size matrix (np.diagflat) for every offset.
        rows = positions[: size - k]
        upper[rows, rows + k] = np.diag(rescaled, k=k) * original_diag_sum

    # Mirror the strict upper triangle; the main diagonal is kept once.
    return upper + np.triu(upper, k=1).T
    

def negone_to_one_normalize(map):
    """Linearly rescale ``map`` so its minimum maps to -1 and its maximum to 1.

    Args:
        map: Array-like of numbers; ``np.min`` / ``np.max`` are taken over all
            entries.

    Returns:
        ``2 * (map - min) / (max - min) - 1``. When every entry is equal
        (``max == min``) the result is -1 everywhere instead of NaN.
    """
    lowest = np.min(map)
    highest = np.max(map)
    value_range = highest - lowest
    if value_range == 0:
        # Constant input: dividing by a zero range would fill the result with
        # NaN. With a unit range every entry (== min) maps to the lower bound.
        value_range = 1
    return 2 * (map - lowest) / value_range - 1
    

def visualize_contact_map(map, zmax):
    """Build a plotly heatmap figure for a contact map.

    Args:
        map: Object exposing ``.squeeze()``; the squeezed value is handed to
            ``px.imshow``.
        zmax: Forwarded to ``px.imshow`` as ``zmax``.

    Returns:
        The figure created by ``px.imshow``, 500 wide, using the "darkmint"
        colour scale and with all layout margins set to 0.
    """
    no_margins = dict(l=0, r=0, t=0, b=0)
    figure = px.imshow(
        map.squeeze(),
        zmax=zmax,
        color_continuous_scale="darkmint",
        width=500,
    )
    figure.update_layout(margin=no_margins)
    return figure

class BulkHiCDataset(Dataset):
    """Dataset of pseudobulk contact maps stored as ``.npy`` files.

    Maps are read from ``<temp_dir>/bulk/chr{N}_cell{M}_pseudobulk.npy``
    (``N`` and ``M`` are 1-based), optionally diagonal-normalised, and padded
    on the right and bottom up to ``map_size`` x ``map_size``.
    """

    def __init__(self, config):
        """Read the dataset settings from ``config``.

        Args:
            config: Mapping providing ``temp_dir``, ``data_dir``,
                ``file_list_path``, ``chrom_list``, ``num_cells_pseudobulk``,
                ``is_sparse`` (the string ``"True"`` enables it),
                ``selected_chrom`` (a name such as ``"chr1"`` or ``"all"``)
                and ``train_config`` with ``map_size`` and ``normalization``.
        """
        self.contact_map_path = config["temp_dir"]
        self.dataset_path = config["data_dir"]
        self.file_list_path = config["file_list_path"]
        # NOTE: pickle can execute arbitrary code on load; only point this at
        # trusted label files. The ``with`` block closes the handle promptly.
        with open(self.dataset_path + "/label_info.pickle", "rb") as label_file:
            self.dataset_info = pd.DataFrame(pickle.load(label_file))
        self.chrom_list = config['chrom_list']
        self.num_cells = len(self.dataset_info)
        self.num_chromosomes = len(self.chrom_list)
        self.num_cells_pseudobulk = config["num_cells_pseudobulk"]
        self.is_sparse = config["is_sparse"] == "True"
        self.map_size = config["train_config"]["map_size"]
        self.normalize = config["train_config"]["normalization"]

        # Utilized when just selecting one chromosome
        self.selected_chromosome = config["selected_chrom"]

    def __len__(self):
        """Return the number of pseudobulk files available for this dataset."""
        bulk_dir = Path(self.contact_map_path + "/bulk/")
        if self.selected_chromosome != "all":
            # Anchor the pattern on the file naming used by __getitem__ so that
            # e.g. "chr1" does not also count chr10 ... chr19 files.
            pattern = f"{self.selected_chromosome}_cell*_pseudobulk.npy"
            return sum(1 for _ in bulk_dir.glob(pattern))
        return len(os.listdir(bulk_dir))

    def __getitem__(self, idx):
        """Load, optionally normalise, and pad one pseudobulk contact map.

        chrom_idx in [0, num_chromosomes - 1]
        cell_idx  in [0, num_cells - 1]

        Args:
            idx: Integer sample index.

        Returns:
            ``torch.Tensor`` holding the map padded on the right/bottom.
        """
        if self.selected_chromosome == "all":
            # NOTE(review): this splits idx by the number of single cells,
            # while __len__ counts pseudobulk files -- confirm both agree.
            chrom_idx, cell_idx = divmod(idx, self.num_cells)
        else:
            chrom_idx = int(self.selected_chromosome.split("chr")[1]) - 1
            cell_idx = idx

        # load contact maps for chromosome at index chrom_idx
        contact_map = np.load(
            self.contact_map_path + f"/bulk/chr{chrom_idx+1}_cell{str(cell_idx+1)}_pseudobulk.npy",
            allow_pickle=True,
        )

        if self.normalize == "diagonal":
            contact_map = diagonal_normalize(contact_map)

        # Pad takes (left, top, right, bottom): the right pad must come from
        # the width and the bottom pad from the height.
        height, width = contact_map.shape[0], contact_map.shape[1]
        transform = Pad((0, 0, self.map_size - width, self.map_size - height))
        return transform(torch.from_numpy(contact_map))


class ScHiCDataset(Dataset):
    """Dataset of per-cell, per-chromosome contact maps stored as ``.npy`` files.

    Maps are read from ``<temp_dir>/sparse/`` or ``<temp_dir>/dense/``
    (depending on ``is_sparse``) using the file name
    ``chr{N}_cell{M}.npy`` with 1-based ``N`` and ``M``.
    """

    def __init__(self, config):
        """Read the dataset settings from ``config``.

        Args:
            config: Mapping providing ``temp_dir``, ``data_dir``,
                ``file_list_path``, ``chrom_list``, ``is_sparse`` (the string
                ``"True"`` enables it) and ``train_config`` with ``map_size``.
        """
        self.contact_map_path = config["temp_dir"]
        self.dataset_path = config["data_dir"]
        self.file_list_path = config["file_list_path"]
        # NOTE: pickle can execute arbitrary code on load; only point this at
        # trusted label files. The ``with`` block closes the handle promptly.
        with open(self.dataset_path + "/label_info.pickle", "rb") as label_file:
            self.dataset_info = pd.DataFrame(pickle.load(label_file))
        self.chrom_list = config['chrom_list']
        self.num_cells = len(self.dataset_info)
        self.num_chromosomes = len(self.chrom_list)
        self.is_sparse = config["is_sparse"] == "True"
        self.map_size = config["train_config"]["map_size"]

    def __len__(self):
        """Return one sample per (chromosome, cell) pair."""
        return self.num_cells * self.num_chromosomes

    def __getitem__(self, idx):
        """Load the contact map addressed by ``idx``.

        chrom_idx in [0, num_chromosomes - 1]
        cell_idx  in [0, num_cells - 1]

        Args:
            idx: Integer sample index; ``idx // num_cells`` selects the
                chromosome and the remainder selects the cell.

        Returns:
            In sparse mode, the value produced by
            ``pre.spy_sparse2torch_sparse`` for the loaded map. In dense mode,
            a ``torch.Tensor`` padded on the right/bottom up to ``map_size``.
        """
        chrom_idx, cell_idx = divmod(idx, self.num_cells)
        layout = "sparse" if self.is_sparse else "dense"
        contact_path = f"{self.contact_map_path}/{layout}/chr{chrom_idx+1}_cell{str(cell_idx+1)}.npy"

        # load contact maps for chromosome at index chrom_idx
        contact_map = np.load(contact_path, allow_pickle=True)

        if self.is_sparse:
            # Hand the converted map back to the caller; previously a bare
            # ``return`` discarded it and every sparse sample was None.
            return pre.spy_sparse2torch_sparse(contact_map)

        # Pad takes (left, top, right, bottom): the right pad must come from
        # the width and the bottom pad from the height.
        height, width = contact_map.shape[0], contact_map.shape[1]
        transform = Pad((0, 0, self.map_size - width, self.map_size - height))
        return transform(torch.from_numpy(contact_map))

    def get_map_info(self):
        """Return the first-axis length of the dense chr1 pseudobulk map.

        Reads ``<temp_dir>/dense/chr1_pseudobulk.npy``.
        """
        return np.load(f"{self.contact_map_path}/dense/chr1_pseudobulk.npy").shape[0]
    
    


In [19]:
# Load the dataset config and the training config, then nest the training
# config under the "train_config" key, which both dataset classes read.
config = pre.get_config("/data/cb/mihirb14/projects/Daifuku/configs/config_m3c_mouse_brain_small_1M.json")
train_config = pre.get_config("/data/cb/mihirb14/projects/Daifuku/configs/train_configs/config_pseudobulk_smallmouse.json")
config["train_config"] = train_config
# Single-cell dataset and a loader that yields one map per batch.
small_mouse_hic_dataset = ScHiCDataset(config)
small_mouse_hic_dataloader = DataLoader(small_mouse_hic_dataset, batch_size=1)
# Pseudobulk dataset built from the same config, also one map per batch.
bulk_small_mouse_hic_dataset = BulkHiCDataset(config=config)
bulk_small_mouse_hic_dataloader = DataLoader(bulk_small_mouse_hic_dataset, batch_size=1)

# Number of pseudobulk samples (displayed as the cell output).
len(bulk_small_mouse_hic_dataset)

5398

In [25]:
from torch.utils.data import random_split

# 80/20 train/test split over every sample of the single-cell dataset.
small_mouse_hic_dataset = ScHiCDataset(config)
train_size = int(0.8 * len(small_mouse_hic_dataset))
test_size = len(small_mouse_hic_dataset) - train_size

# Seed the split so the same samples land in train/test on every
# Restart & Run All; an unseeded random_split changes between runs.
train_dataset, test_dataset = random_split(
    small_mouse_hic_dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(0),
)
train_dataloader = DataLoader(train_dataset, batch_size=64)

# Shape of the first training batch (displayed as the cell output).
next(iter(train_dataloader)).shape

torch.Size([64, 136, 136])

In [21]:
# Take the first batch from the single-cell loader (batch_size=1), print its
# shape, and plot it with the colour scale capped at 1.
chrom1_cell1_map = next(iter(small_mouse_hic_dataloader))
print(chrom1_cell1_map.shape)
visualize_contact_map(chrom1_cell1_map.squeeze(),zmax=1)

torch.Size([1, 136, 136])


In [22]:
# Take the first batch from the pseudobulk loader (batch_size=1), print its
# shape, and plot it with the colour scale capped at 1.
pseudobulk_1_map = next(iter(bulk_small_mouse_hic_dataloader))
print(pseudobulk_1_map.shape)
visualize_contact_map(pseudobulk_1_map.squeeze(), zmax=1)

torch.Size([1, 136, 136])


In [30]:
# Round trip: diagonal-normalise the pseudobulk map, rescale the result with
# the diagonal sums of the same map, and plot what comes back.
# NOTE(review): BulkHiCDataset already applies diagonal_normalize when the
# train config's "normalization" is "diagonal" -- confirm whether this input
# is raw or already normalised.
normalized = diagonal_normalize(pseudobulk_1_map.squeeze())
unnormalized = diagonal_unnormalize(pseudobulk_1_map.squeeze(), normalized)
visualize_contact_map(unnormalized, zmax=1)

In [31]:
from models import LightningDiffusion as ld

# Build the model from the training config loaded earlier in the notebook.
daifuku = ld.LightningDiffusion(train_config=train_config)

# Other checkpoints that were tried:
# checkpoint = torch.load("../out/trained/daifuku_bulksmallmouse1M_500epochs.ckpt")
# checkpoint = torch.load("../Daifuku/lmbg9mze/checkpoints/epoch=499-step=8500.ckpt")
# checkpoint = torch.load("../Daifuku/g9hx58dz/checkpoints/epoch=1999-step=90000.ckpt")

# map_location="cpu" lets a checkpoint saved on a GPU load on any machine;
# load_state_dict then copies the values into the model's own parameters.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load(
    "../out/trained/daifuku_bulksmallmouse1M_2000epochs.ckpt",
    map_location="cpu",
)
daifuku.load_state_dict(checkpoint['state_dict'])
daifuku.eval()

# Draw one sample from the model and show its shape as the cell output.
sampled_map = daifuku.sample(batch_size=1)
sampled_map.shape

sampling loop time step: 100%|██████████| 250/250 [01:49<00:00,  2.29it/s]


torch.Size([1, 1, 136, 136])

In [33]:
# Plot the sampled map with the colour scale capped at 1.
visualize_contact_map(sampled_map.squeeze(), zmax=1)