In [None]:
import umap.umap_ as umap
from sklearn.preprocessing import StandardScaler
import torch
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import pandas as pd

import numpy as np
import sklearn.metrics
import os
import sys

from tqdm import tqdm

sys.path.append("../profiling/")
import profiling

In [None]:
# Input: .npz archive whose 'features' array is read by the loading cell below.
feature_path = '/scr/data/LINCS-DINO/features/24_07_04_vitsmall_features.npz'
# NOTE(review): not referenced by any other cell visible in this notebook.
feature_add_path = ''

# Folder and file name used when writing the well-level profile CSVs.
output_folder = '/scr/lanfang/cell-painting-devbench/LINCS/data'
# The "1e-5" in this name is a plain string; it is not derived from REG_PARAM,
# so the two have to be kept in sync by hand.
output_file = "well_level_profiles_vits_LINCS_1e-5_final.csv"
# Passed as the second argument to profiling.WhiteningNormalizer
# (presumably the regularization strength - confirm in the profiling module).
REG_PARAM = 1e-5

In [None]:
# Read the single-cell metadata table. Later cells group it by its 'Key'
# column and read its 'Treatment' column.
meta_csv_path = "/scr/data/LINCS-DINO/max_concentration_set/sc-metadata.csv"
meta = pd.read_csv(meta_csv_path)

In [None]:
# Load the feature matrix from the .npz archive.
# `feature_path` is defined once in the configuration cell and intentionally
# not re-assigned here, so there is a single place to change it.

# np.load on an .npz returns a lazy archive object that keeps the file open;
# the context manager closes it as soon as the array has been read into memory.
with np.load(feature_path) as data:
    # Print the keys in the .npz file
    print("Keys in the .npz file:", data.files)

    # Get the array associated with the 'features' key
    features = data['features']

# The cells below index `features` as (row, feature column); fail with a clear
# message here rather than with an IndexError on shape[1].
if features.ndim != 2:
    raise ValueError(
        f"Expected a 2-D 'features' array, got shape {features.shape}"
    )

# Print the shape of the array
print("Array shape:", features.shape)

# Number of rows and number of feature columns in the array
total_images, features_per_image = features.shape

print("Total images:", total_images)
print("Number of features per image:", features_per_image)


In [None]:
# Standardize every feature column with scikit-learn's StandardScaler.
# NOTE(review): `scaled_features` is not referenced by any later cell visible
# in this notebook; the aggregation below works on the raw `features`.
feature_scaler = StandardScaler()
scaled_features = feature_scaler.fit_transform(features)

In [None]:
#cell_names = np.concatenate(([f[1] for f in open_files]))
#order, ordered_features = (np.array(t) for t in zip(*sorted(zip(cell_names, scaled_features))))

In [None]:
# Display the metadata table as the cell's output.
meta

# 2. Site-level profiles / Median Aggregation

In [None]:
# Map every distinct 'Key' value to the row labels of `meta` that carry it.
rows_by_key = meta.groupby('Key')
group_dict = rows_by_key.groups
print("Grouping finished.")

In [None]:
#all_data = pd.concat([meta, pd.DataFrame(features)], axis=1)

In [None]:
# Aggregate single-cell features into one profile per site: every distinct
# 'Key' in `meta` yields one row holding the per-feature median of its cells.

# `group_dict` stores row *labels* of `meta`, which are used below as row
# *positions* into `features`. That is only correct when `meta` has the
# default 0..n-1 index and exactly one row per feature vector, so check it
# up front instead of silently aggregating the wrong cells.
if len(meta) != features.shape[0]:
    raise ValueError(
        f"meta has {len(meta)} rows but features has {features.shape[0]}; "
        "they must be aligned row-for-row."
    )
if not meta.index.equals(pd.RangeIndex(len(meta))):
    raise ValueError(
        "meta must have a default 0..n-1 index so that its row labels can "
        "be used as positions into features."
    )

site_level_data = []
site_level_features = []

for site_name, indices in tqdm(group_dict.items()):
    # The key is split on '/'; its first two parts are stored as plate and well.
    key_parts = site_name.split('/')
    # Per-feature median over all cells of this site.
    median_profile = np.median(features[indices], axis=0)

    site_level_data.append(
        {
            "Plate": key_parts[0],
            "Well": key_parts[1],
            # First treatment label found in the group (the same value that
            # `.unique()[0]` returns, without building the unique array).
            "Treatment": meta.loc[indices, "Treatment"].iloc[0],
        }
    )
    site_level_features.append(median_profile)


In [None]:
# Assemble the site-level table: the metadata columns followed by one
# integer-named column per feature.

# Take the feature count from the loaded array instead of hard-coding 384, so
# the notebook also works with feature files of a different width (a mismatch
# would otherwise make the DataFrame construction below fail).
num_features = features.shape[1]
columns1 = ["Plate", "Well", "Treatment"] # dataset
columns2 = list(range(num_features))

sites1 = pd.DataFrame(columns=columns1, data=site_level_data)
sites2 = pd.DataFrame(columns=columns2, data=site_level_features)
# Both frames have one row per site in the same order, so a column-wise
# concatenation lines them up.
sites = pd.concat([sites1, sites2], axis=1)

In [None]:
def _first_two_dash_parts(label):
    """Return the first two '-'-separated parts of `label`, re-joined with '-'."""
    parts = label.split("-")
    return "-".join(parts[:2])


# Shortened treatment label: everything after the second '-' is dropped.
sites["Treatment_Clean"] = sites["Treatment"].apply(_first_two_dash_parts)

# 3. Well-level profiles / Mean Aggregation

In [None]:
# Collapse site-level rows into one row per well by averaging every remaining
# (feature) column within each plate / well / treatment group.
well_keys = ["Plate", "Well", "Treatment", "Treatment_Clean"]
sites_by_well = sites.groupby(well_keys)
wells = sites_by_well.mean().reset_index()
wells.head(10)

In [None]:
# Save the well-level profiles before whitening.
# Create the output folder first so the write does not fail on a fresh machine.
os.makedirs(output_folder, exist_ok=True)
# NOTE: unlike the final export, this call does not pass index=False, so the
# row index is written as the first (unnamed) column. Left unchanged in case
# existing readers of this file rely on that layout.
wells.to_csv(f"{output_folder}/Wells_Prewhitened_ViT_small_LINCS.csv")

# 4. Whitening

In [None]:
# Number of wells whose treatment is the "DMSO@NA" control label.
is_dmso_well = wells["Treatment"].isin(["DMSO@NA"])
int(is_dmso_well.sum())

In [None]:
# Build the normalizer from the "DMSO@NA" control wells and apply it to the
# feature columns of every well.
control_mask = wells["Treatment"].isin(["DMSO@NA"])
# Without any control rows the normalizer has nothing to be built from; stop
# with an explicit message instead of an obscure downstream failure or NaNs.
if not control_mask.any():
    raise ValueError(
        "No 'DMSO@NA' control wells found; cannot fit the whitening normalizer."
    )

# REG_PARAM is forwarded unchanged as the second constructor argument
# (presumably the regularization strength - confirm in the profiling module).
whN = profiling.WhiteningNormalizer(wells.loc[control_mask, columns2], REG_PARAM)
whD = whN.normalize(wells[columns2])

In [None]:
# Save whitened profiles
# Create the output folder first so the write does not fail on a fresh machine.
os.makedirs(output_folder, exist_ok=True)
# NOTE: this overwrites the feature columns of `wells` in place. Re-running
# the whitening cell or the pre-whitening export after this cell would operate
# on already-whitened values; re-run from the well aggregation cell instead.
wells[columns2] = whD
wells.to_csv(f'{output_folder}/{output_file}', index=False)