In [82]:
import warnings
warnings.filterwarnings('ignore')
import os
import sys
from pathlib import Path
import yaml

from tqdm import tqdm
import pandas as pd
import numpy as np

import seaborn as sns
import matplotlib.pyplot as plt
plt.rcParams.update({
    "pgf.texsystem": "xelatex",
    'font.family': 'Arial',
    'text.usetex': False,
    'pgf.rcfonts': False,
    'figure.dpi': 300,
})

%load_ext autoreload
%autoreload 2
package_path = r'C:\Users\Mingchuan\Huanglab\PRISM\PRISM_Code\gene_calling'
if package_path not in sys.path: sys.path.append(package_path)

# workdir 
BASE_DIR = Path(r'G:\spatial_data\processed')
RUN_ID = '20250618_PJR_WSY_Huh7-18_AF_quench'
src_dir = BASE_DIR / f'{RUN_ID}'
read_dir = src_dir / 'readout'
figure_dir = read_dir / 'figures'
read_dir.mkdir(exist_ok=True)
figure_dir.mkdir(exist_ok=True)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [83]:
# parameters saved for this run's readout
# NOTE(review): yaml.UnsafeLoader can construct arbitrary Python objects from the file.
# That is acceptable for a trusted, locally generated params.yaml; prefer yaml.safe_load
# if the file only ever contains plain YAML types.
params_path = read_dir / 'params.yaml'
with open(params_path, 'r') as params_file:
    params = yaml.load(params_file, Loader=yaml.UnsafeLoader)

# basic
PRISM_PANEL, GLAYER, COLOR_GRADE, Q_CHNS, Q_NUM = [
    params[key] for key in ('PRISM_PANEL', 'GLAYER', 'COLOR_GRADE', 'Q_CHNS', 'Q_NUM')
]
# thresholds
thre_min, thre_max = params['thre_min'], params['thre_max']
# visualization
XRANGE, YRANGE, s, alpha, percentile_thre = [
    params[key] for key in ('XRANGE', 'YRANGE', 's', 'alpha', 'percentile_thre')
]
bins = tuple(params['bins'])
# GMM: projection vectors as numpy arrays, centroid keys normalised to int
CD_1_proj, CD_2_proj = [np.array(params[key]) for key in ('CD_1_proj', 'CD_2_proj')]
centroid_init_dict = {int(key): value for key, value in params['centroid_init_dict'].items()}
colormap = params['colormap']

# Load data
readout_tmp_dir = read_dir / 'tmp'
intensity_raw = pd.read_csv(readout_tmp_dir / 'intensity_raw.csv')
intensity = pd.read_csv(read_dir / 'intensity_labeled.csv')
intensity_G = pd.read_csv(read_dir / 'intensity_G.csv')

In [116]:
## manual check: draw / adjust one gating mask by hand over a saved figure
import cv2
from lib.manual_thre import draw_mask_manual

mask_check_dir = read_dir / 'mask_check'
mask_check_dir.mkdir(exist_ok=True)

# Label of the mask being edited; the mask path handed to draw_mask_manual is
# mask_check/mask_<MASK_LABEL>.png (the same naming the density-check cell reads).
MASK_LABEL = '30'
mask_image_path = figure_dir / '2-layer2.png'
# Fail early with a clear message rather than inside the interactive drawing helper.
if not mask_image_path.exists():
    raise FileNotFoundError(f'Background figure for mask drawing not found: {mask_image_path}')
draw_mask_manual(image_path=str(mask_image_path),
                 mask_path=str(mask_check_dir / f'mask_{MASK_LABEL}.png'))

In [117]:
# Ensure `intensity` carries a 'label' column; when absent every spot starts at 0
# (presumably "unassigned" — TODO confirm against relabel_mask).
if 'label' not in intensity.columns:
    intensity['label'] = 0
# Explicit copy so later writes to intensity_check can never reach `intensity`.
intensity_check = intensity[['Y', 'X', 'CD_1_blur', 'CD_2_blur', 'G/A', 'G_layer', 'label']].copy()

# Optional G-layer adjustment (disabled). Uncomment to copy G_layer into a separate
# 'check_G_layer' column and force spots with 0.5 < G/A < 1 into layer 1:
# intensity_check['check_G_layer'] = intensity_check['G_layer']
# adjusted_G = intensity_check[(intensity_check['G/A']>0.5)&(intensity_check['G/A']<1)]
# intensity_check.loc[adjusted_G.index, 'check_G_layer'] = 1

In [118]:
from lib.projection import generate_colormap, plot_params_generator, customize_axis, projection_gene
from lib.manual_thre import relabel_mask
import matplotlib.ticker  # also binds `matplotlib`; guarantees the ticker submodule is loaded

# --- Colormap ----------------------------------------------------------------
# Thermal LUT from the gene_calling package. The row with index 0 is overwritten with
# black, presumably so the lowest bins blend into the black figure background.
thermal_LUT = pd.read_csv(os.path.join(package_path, 'thermal_LUT.csv'), index_col=0)
thermal_LUT.loc[0] = [0, 0, 0]
cmap_thermal = generate_colormap(thermal_LUT.values)

# --- Projection parameters ---------------------------------------------------
# Built from `intensity` itself (not the loop-rebound intensity_check), so re-running
# this cell always yields the same plot_params.
plot_params = plot_params_generator(x=intensity['X'], y=intensity['Y'], downsample_factor=100, edge=0.05)
plot_params['cmap'] = cmap_thermal
plot_params['percentile_max'] = 10
plot_params['percentile_min'] = 2

# --- Canvas layout (inches) --------------------------------------------------
# Left: 2x2 block of square panels (3D view, CD_1 vs G/A, CD_2 vs G/A, CD_1 vs CD_2).
# Right: spatial density projection plus its colorbar. Panels are separated by gaps.
fig_height_inch = plot_params['fig_height_inch']
main_plot_width_inch = plot_params['main_plot_width_inch']
cbar_width_inch = plot_params['cbar_width_inch']
gap_width_inch = 1
full_height_inch = fig_height_inch + 2 * gap_width_inch
full_width_inch = main_plot_width_inch + fig_height_inch - gap_width_inch + cbar_width_inch + 3 * gap_width_inch
grid_resolution = 100  # gridspec cells per inch

# Gridspec sizes in cells; loop-invariant, so computed once here.
one_cubic_num = int((fig_height_inch - gap_width_inch) * grid_resolution / 2)
r_width_num = int(main_plot_width_inch * grid_resolution)
cbar_num = int(cbar_width_inch * grid_resolution)
gap_width_num = int(gap_width_inch * grid_resolution)

spots_color = 'lightblue'
select_spot_color = 'red'
select_alpha = min(1, 2 * alpha)      # highlighted spots in the 2D panels
select_alpha_3d = min(1, 5 * alpha)   # highlighted spots in the 3D panel

# --- Output ------------------------------------------------------------------
density_dir = read_dir / 'density_check'
density_dir.mkdir(exist_ok=True)

# Only files named mask_<label>.png produce labels. This is exactly the name _load_mask
# reads back, so a stray PNG in mask_check_dir can no longer yield a label without a mask.
mask_prefix = 'mask_'
label_list = sorted(
    (path.stem[len(mask_prefix):] for path in mask_check_dir.glob(f'{mask_prefix}*.png')),
    # Sort by the leading integer; the full label breaks ties so the order never depends on the filesystem.
    key=lambda lbl: (int(lbl.split('-')[0]), lbl),
)
sample_frac = 0.01
sample_seed = 0  # fixed seed: the same 1% subsample is drawn on every re-run


def _load_mask(label):
    """Return mask_check_dir/mask_<label>.png as a boolean array.

    Raises FileNotFoundError when OpenCV cannot read the file. cv2.imread returns None
    instead of raising, which would otherwise surface as an opaque AttributeError.
    """
    mask_path = mask_check_dir / f'{mask_prefix}{label}.png'
    mask_img = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
    if mask_img is None:
        raise FileNotFoundError(f'Could not read mask image: {mask_path}')
    return mask_img.astype(bool)


def _build_axes(fig):
    """Create the six panels on a fine gridspec of `fig` and run customize_axis on each.

    Returns (ax_lu, ax_lb, ax_ru, ax_rb, ax_r, cax): 3D view, CD_1 vs G/A, CD_2 vs G/A,
    CD_1 vs CD_2, spatial projection, colorbar axis.
    """
    gs = fig.add_gridspec(2 * one_cubic_num + 3 * gap_width_num,
                          2 * one_cubic_num + 4 * gap_width_num + r_width_num + cbar_num)
    # Row/column spans (grid cells) of the 2x2 block and of the right-hand column.
    first = slice(gap_width_num, one_cubic_num + gap_width_num)
    second = slice(one_cubic_num + 2 * gap_width_num, 2 * one_cubic_num + 2 * gap_width_num)
    full_rows = slice(gap_width_num, 2 * one_cubic_num + 2 * gap_width_num)
    proj_start = 2 * one_cubic_num + 4 * gap_width_num
    proj_cols = slice(proj_start, proj_start + r_width_num)
    cbar_cols = slice(proj_start + r_width_num, proj_start + r_width_num + cbar_num)

    ax_lu = fig.add_subplot(gs[first, first], projection='3d')
    ax_lu.set_axis_off()
    ax_lb = fig.add_subplot(gs[second, first])
    ax_ru = fig.add_subplot(gs[first, second])
    ax_rb = fig.add_subplot(gs[second, second])
    ax_r = fig.add_subplot(gs[full_rows, proj_cols])
    cax = fig.add_subplot(gs[full_rows, cbar_cols])
    axes = (ax_lu, ax_lb, ax_ru, ax_rb, ax_r, cax)
    for ax in axes:
        customize_axis(ax)
    return axes


check_columns = ['Y', 'X', 'CD_1_blur', 'CD_2_blur', 'G/A', 'G_layer', 'label']
for label in tqdm(label_list):
    mask = _load_mask(label)
    # G layer of this label; presumably the leading number is a 1-based channel index
    # with Q_NUM channels per layer — TODO confirm.
    glayer = (int(label.split('-')[0]) - 1) // Q_NUM
    # Fresh explicit copy each iteration: relabel_mask cannot write into `intensity`,
    # and relabels from earlier labels never carry over.
    intensity_check = relabel_mask(intensity=intensity[check_columns].copy(),
                                   plot_column=['CD_1_blur', 'CD_2_blur'], mask=mask,
                                   xlim=XRANGE, ylim=YRANGE, ch_label=label,
                                   mode='replace', G_layer=glayer)

    intensity_select = intensity_check[intensity_check['label'] == label]
    intensity_check_sample = intensity_check.sample(frac=sample_frac, random_state=sample_seed)
    intensity_select_sample = intensity_select.sample(frac=sample_frac, random_state=sample_seed)

    fig = plt.figure(figsize=(full_width_inch, full_height_inch))
    try:
        fig.set_facecolor('black')
        ax_lu, ax_lb, ax_ru, ax_rb, ax_r, cax = _build_axes(fig)

        # left upper: 3D CD_1 / CD_2 / G/A, selected spots highlighted
        ax_lu.scatter(intensity_check_sample['CD_1_blur'], intensity_check_sample['CD_2_blur'],
                      intensity_check_sample['G/A'], c=spots_color, marker='.', s=s, alpha=alpha)
        ax_lu.scatter(intensity_select_sample['CD_1_blur'], intensity_select_sample['CD_2_blur'],
                      intensity_select_sample['G/A'], c=select_spot_color, marker='.', s=s,
                      alpha=select_alpha_3d)
        ax_lu.set_xlim(XRANGE)
        ax_lu.set_ylim(YRANGE)
        ax_lu.set_zlim([-0.2, 1])
        ax_lu.view_init(elev=20, azim=-70)

        # left bottom: CD_1 vs G/A
        ax_lb.scatter(intensity_check_sample['CD_1_blur'], intensity_check_sample['G/A'],
                      s=s, alpha=alpha, color=spots_color)
        ax_lb.scatter(intensity_select_sample['CD_1_blur'], intensity_select_sample['G/A'],
                      s=s, alpha=select_alpha, color=select_spot_color)
        ax_lb.set_ylim([-0.2, 1.4])

        # right upper: CD_2 vs G/A
        ax_ru.scatter(intensity_check_sample['CD_2_blur'], intensity_check_sample['G/A'],
                      s=s, alpha=alpha, color=spots_color)
        ax_ru.scatter(intensity_select_sample['CD_2_blur'], intensity_select_sample['G/A'],
                      s=s, alpha=select_alpha, color=select_spot_color)
        ax_ru.set_ylim([-0.2, 1.4])

        # right bottom: CD_1 vs CD_2, background restricted to this label's G layer
        intensity_layer = intensity_check_sample[intensity_check_sample['G_layer'] == glayer]
        ax_rb.scatter(intensity_layer['CD_1_blur'], intensity_layer['CD_2_blur'],
                      s=s, alpha=alpha, color=spots_color)
        ax_rb.scatter(intensity_select_sample['CD_1_blur'], intensity_select_sample['CD_2_blur'],
                      s=s, alpha=select_alpha, color=select_spot_color)
        ax_rb.set_xlim(XRANGE)
        ax_rb.set_ylim(YRANGE)
        ax_rb.set_title(f'G={glayer}')

        # right: spatial density projection of the selected spots (full set, not sampled)
        plot_params['ax'] = ax_r
        hist_im = projection_gene(x=intensity_select['X'], y=intensity_select['Y'],
                                  gene_name=label, plot_params_update=plot_params)
        # right: colorbar, white text for the black background
        cbar = fig.colorbar(hist_im, cax=cax)
        cbar.set_label('Counts', color='white')
        cbar.ax.yaxis.set_tick_params(color='white')
        cbar.ax.yaxis.set_tick_params(labelcolor='white')
        cbar.formatter = matplotlib.ticker.FuncFormatter(lambda value, _pos: f'{round(value, 1)}')
        cbar.update_ticks()

        fig.savefig(density_dir / f'{label}.png', dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if a plotting call fails mid-loop.
        plt.close(fig)

100%|██████████| 30/30 [02:42<00:00,  5.42s/it]
