In [None]:
import umap.umap_ as umap
from sklearn.preprocessing import StandardScaler
import torch
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import pandas as pd

import numpy as np
import sklearn.metrics
import os
import sys

from tqdm import tqdm

sys.path.append("../profiling/")
import profiling

In [None]:
# Destination CSV for the final well-level profiles; written by the last cell
# after the feature columns have been replaced with the normalized values.
OUTPUT_FILE = "well_level_profiles_vits_LINCS_1e-5_final.csv"
# Second argument passed to profiling.WhiteningNormalizer in the whitening cell
# (presumably the regularization strength -- confirm in the profiling module).
REG_PARAM = 1e-5

In [None]:
# Metadata tables for the main set and for the separately processed plate.
# Later cells read the 'Key' column (grouping) and the 'Treatment' column
# (annotation) from both tables.
meta_path = "/scr/data/LINCS/DP-project/outputs/max_concentration_set/sc-metadata.csv"
meta_add_path = "/scr/data/LINCS/DP-project/outputs/SQ00015147_maxconc/sc-metadata.csv"

meta = pd.read_csv(meta_path)
meta_add = pd.read_csv(meta_add_path)

In [None]:
# Collect the feature dumps for the main LINCS run. Each loaded file is used
# as an indexable pair: item 0 is a feature tensor, item 1 holds the matching
# cell names (consumed by the sorting cell further down).
feature_dir = Path('/scr/zitong/cp_dino/experiments/DINO01/LINCS/')
files = list(map(str, feature_dir.glob('features*.pth')))

open_files = []
for shard_path in files:
    open_files.append(torch.load(shard_path))

# Stack every shard's tensor along the first axis into one matrix.
features = torch.cat([shard[0] for shard in open_files])

print("Total images:",features.size(0))
print("Number of features per image:", features.size(1))

# Disabled: optional export of the combined features and names.
# cell_names = np.concatenate(([f[1] for f in open_files]))
# torch.save((features, cell_names), '/scr/zitong/cp_dino/experiments/DINOBASE/Combined/features.pth')

In [None]:
# Same loading procedure for the dumps stored under the `additional/` folder.
additional_dir = Path('/scr/zitong/cp_dino/experiments/DINO01/LINCS/additional/')
files_add = list(map(str, additional_dir.glob('features*.pth')))

open_files_add = []
for shard_path in files_add:
    open_files_add.append(torch.load(shard_path))

# Item 0 of every loaded file is the feature tensor; join them row-wise.
features_add = torch.cat([shard[0] for shard in open_files_add])

print("Total images:",features_add.size(0))
print("Number of features per image:", features_add.size(1))

# Disabled export snippet. NOTE(review): it refers to `open_files` and
# `features` (the main set), not to the `_add` variables built in this cell.
# cell_names = np.concatenate(([f[1] for f in open_files]))
# torch.save((features, cell_names), '/scr/zitong/cp_dino/experiments/DINOBASE/Combined/features.pth')

In [None]:
# Standardize each feature matrix with its own, independently fitted scaler.
# NOTE(review): the two sets are therefore scaled with different statistics --
# confirm that this per-set normalization is intended.
main_scaler = StandardScaler()
scaled_features = main_scaler.fit_transform(features.numpy())

additional_scaler = StandardScaler()
scaled_features_add = additional_scaler.fit_transform(features_add.numpy())


In [None]:
# Put the feature rows into ascending cell-name order.
#
# A stable argsort replaces `sorted(zip(names, rows))`:
#   * ties between identical names no longer fall through to comparing numpy
#     rows (which raises "truth value of an array is ambiguous");
#   * no per-cell Python tuples are materialised, so it scales to large sets;
#   * a name/row count mismatch is reported instead of being silently
#     truncated by zip().
cell_names = np.concatenate([f[1] for f in open_files])
assert len(cell_names) == len(scaled_features), "cell names and feature rows differ in length"
sort_idx = np.argsort(cell_names, kind="stable")
order = cell_names[sort_idx]
ordered_features = scaled_features[sort_idx]

# Same reordering for the additional set.
cell_names_add = np.concatenate([f[1] for f in open_files_add])
assert len(cell_names_add) == len(scaled_features_add), "additional cell names and feature rows differ in length"
sort_idx_add = np.argsort(cell_names_add, kind="stable")
order_add = cell_names_add[sort_idx_add]
ordered_features_add = scaled_features_add[sort_idx_add]

# 2. Site-level profiles / Median Aggregation

In [None]:
# Map every distinct 'Key' value to the row labels of `meta` that carry it.
key_grouping = meta.groupby(by='Key')
group_dict = key_grouping.groups
print("Grouping finished.")

In [None]:
# Accumulators for the site-level table: one annotation dict and one feature
# vector per site. The next cell appends the additional plate to the same lists.
site_level_data = []
site_level_features = []

for site_name, indices in tqdm(list(group_dict.items())):
    # Keys are '/'-separated; the first two fields are used as plate and well.
    parts = site_name.split('/')
    plate = parts[0]
    if plate == 'SQ00015147':
        # This plate is left out of the main pass (presumably because it is
        # covered by the separate `meta_add` table -- confirm).
        continue

    # Per-feature median over all rows of this site. The row labels of `meta`
    # are used as positions into `ordered_features` -- assumes `meta` has a
    # default 0..N-1 index in the same order as the sorted cells; TODO confirm.
    median_profile = np.median(ordered_features[indices], axis=0)

    site_row = {
        "Plate": plate,
        "Well": parts[1],
        # First distinct treatment label found among the site's rows.
        "Treatment": meta.loc[indices, "Treatment"].unique()[0],
    }
    site_level_data.append(site_row)
    site_level_features.append(median_profile)


In [None]:
# Site-level median aggregation for the additional plate.
#
# Group by the plain column name rather than a one-element list: this matches
# the main-set grouping and guarantees that the `.groups` keys are plain
# strings, which the '/'-split below relies on (list groupers are moving to
# tuple keys in newer pandas).
#
# NOTE: this cell appends to `site_level_data` / `site_level_features`, which
# are created in the previous cell; re-running it on its own duplicates rows.
site_dict = meta_add.groupby('Key').groups
for site_name, indices in tqdm(list(site_dict.items())):
    # Keys are '/'-separated; the first two fields are used as plate and well.
    metadata = site_name.split('/')

    # Per-feature median over all rows of this site. The row labels of
    # `meta_add` are used as positions into `ordered_features_add` -- assumes a
    # default 0..N-1 index aligned with the sorted cells; TODO confirm.
    median_profile = np.median(ordered_features_add[indices], axis=0)

    site_level_data.append(
        {
            "Plate": metadata[0],
            "Well": metadata[1],
            # First distinct treatment label found among the site's rows.
            "Treatment": meta_add.loc[indices, "Treatment"].unique()[0],
        }
    )
    site_level_features.append(median_profile)

In [None]:
# Assemble the site-level table: annotation columns followed by one
# integer-named column per feature dimension.
#
# The feature width is read from the loaded feature tensor instead of being
# hard-coded, so the cell keeps working when the embedding size changes.
num_features = features.shape[1]
columns1 = ["Plate", "Well", "Treatment"]
columns2 = list(range(num_features))

sites1 = pd.DataFrame(columns=columns1, data=site_level_data)
sites2 = pd.DataFrame(columns=columns2, data=site_level_features)
# Both frames are built from lists filled in lock-step, so a column-wise
# concat lines the annotations up with their feature vectors.
sites = pd.concat([sites1, sites2], axis=1)

In [None]:
# Shorten each treatment label to its first two "-"-separated tokens
# (labels with fewer than two tokens are kept unchanged).
sites["Treatment_Clean"] = sites["Treatment"].map(
    lambda label: "-".join(label.split("-")[:2])
)

# 3. Well-level profiles / Mean Aggregation

In [None]:
# Collapse sites into wells: average every feature column over the rows that
# share plate, well and both treatment labels.
well_keys = ["Plate", "Well", "Treatment", "Treatment_Clean"]
wells = sites.groupby(well_keys).mean().reset_index()
wells.head(10)

In [None]:
# Snapshot of the well-level profiles before whitening. `index=True` is the
# pandas default, spelled out here: the row index is written as the first,
# unnamed CSV column (unlike the final export, which drops it).
wells.to_csv("Wells_Prewhitened_ViT_small_LINCS.csv", index=True)

# 4. Whitening

In [None]:
# Number of wells labelled "DMSO@NA" -- the rows the whitening step is fitted on.
int(wells["Treatment"].eq("DMSO@NA").sum())

In [None]:
# Build the normalizer from the "DMSO@NA" wells only, then apply it to the
# feature columns of every well. The normalizer itself lives in the project's
# `profiling` module and is not shown in this notebook.
is_dmso = wells["Treatment"].isin(["DMSO@NA"])
control_profiles = wells.loc[is_dmso, columns2]

whN = profiling.WhiteningNormalizer(control_profiles, REG_PARAM)
whD = whN.normalize(wells[columns2])

In [None]:
# Save whitened profiles.
# The feature columns of `wells` are overwritten in place with the normalized
# values, so after this cell `wells` no longer holds the pre-whitened profiles.
# Re-running the whitening cell afterwards would normalize already-normalized
# data; rebuild `wells` from `sites` first.
wells[columns2] = whD
# index=False: the row index is not written to the CSV.
wells.to_csv(OUTPUT_FILE, index=False)