In [3]:
from pathlib import Path
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import numpy as np
from tqdm import tqdm 
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('whitegrid')

In [65]:
def load_and_plot_pca(path, outdir, hue, output=False, label_type="class"):
    """Plot the first two principal components of one feature matrix.

    The array stored at ``path`` is loaded with ``np.load``. Every column
    except the last is standardized and projected onto two PCA components;
    the last column is kept as the label used to color the scatter plot.
    The figure title is ``"<grandparent dir name>-<parent dir name>"`` of
    ``path``.

    Parameters
    ----------
    path : pathlib.Path
        Location of the ``.npy`` file (rows are samples, label in the last
        column).
    outdir : str or path-like
        Directory the PNG is written to when ``output`` is true. It is
        created, including missing parents, if necessary.
    hue : str
        Column name given to the label values; also the column the points
        are grouped / colored by.
    output : bool, optional
        When true, save the figure to ``<outdir>/<title>.png`` at 300 dpi
        and close it. When false the figure is left open.
    label_type : str, optional
        ``"class"`` draws rows labeled 0 as large opaque red points and rows
        labeled 1 as small translucent blue points (rows with any other
        label are not drawn). Any other value draws a single scatter plot
        colored by ``hue``.
    """
    title = f"{path.parent.parent.name}-{path.parent.name}"
    data = np.load(path)

    features = StandardScaler().fit_transform(data[:, :-1])
    labels = data[:, -1]
    components = PCA(n_components=2).fit_transform(features)

    df = pd.DataFrame(components, columns=["comp_0", "comp_1"])
    df[hue] = labels
    fig, ax = plt.subplots(figsize=(7, 7))

    if label_type == "class":
        # Marker style per class label. Label 1 is drawn small and faint,
        # presumably because it is the much larger class -- TODO confirm.
        class_styles = {
            0: {"s": 5, "alpha": 1, "color": "red"},
            1: {"s": 1, "alpha": 0.25, "color": "blue"},
        }
        # Group by the caller-supplied label column rather than a fixed
        # column name, so the function works for any ``hue``.
        for group_label, group_df in df.groupby(hue):
            style = class_styles.get(group_label)
            if style is None:
                continue
            sns.scatterplot(x="comp_0", y="comp_1", data=group_df, ax=ax, **style)
    else:
        sns.scatterplot(x="comp_0", y="comp_1", data=df, hue=hue, ax=ax, s=5)

    ax.set_title(title)

    if output:
        output_path = Path(outdir) / f"{title}.png"
        # parents=True lets ``outdir`` be a nested directory that does not
        # exist yet; exist_ok=True makes repeated runs harmless.
        output_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(output_path, dpi=300)
        # Close explicitly so looping over many files does not pile up
        # open figures.
        plt.close(fig)

## Label By Class

In [70]:
from tqdm import tqdm 

def main_class(
    output,
    root="/usr/WS1/jones289/hd-cuda-master/datasets/dude/deepchem_feats",
    outdir="dude_figs",
):
    """Draw a class-labeled PCA plot for every ``data.npy`` found under ``root``.

    Parameters
    ----------
    output : bool
        Forwarded to ``load_and_plot_pca``; when true each figure is saved.
    root : str or path-like, optional
        Directory searched recursively for files named ``data.npy``.
    outdir : str or path-like, optional
        Directory passed to ``load_and_plot_pca`` as the figure destination.
    """
    # Sort so the processing order does not depend on directory listing order.
    path_list = sorted(Path(root).glob("**/data.npy"))

    for path in tqdm(path_list):
        load_and_plot_pca(path, outdir=outdir, hue="decoy", output=output, label_type="class")

def main_reg(
    output,
    root="/usr/WS1/jones289/hd-cuda-master/datasets/dude/deepchem_feats_labeled_by_gbsa/",
    outdir="dude_gbsa_figs",
):
    """Draw a score-colored PCA plot for every ``data.npy`` found under ``root``.

    Parameters
    ----------
    output : bool
        Forwarded to ``load_and_plot_pca``; when true each figure is saved.
    root : str or path-like, optional
        Directory searched recursively for files named ``data.npy``.
    outdir : str or path-like, optional
        Directory passed to ``load_and_plot_pca`` as the figure destination.
    """
    # Sort so the processing order does not depend on directory listing order.
    path_list = sorted(Path(root).glob("**/data.npy"))

    for path in tqdm(path_list):
        # The plotting function's keyword is ``outdir``; passing
        # ``output_dir`` raises TypeError.
        load_and_plot_pca(path, outdir=outdir, hue="best_gbsa_score", output=output, label_type="reg")


In [71]:
main_class(output=True)

100%|██████████| 242/242 [13:30<00:00,  3.35s/it]
