In [1]:
import torch
import glob
import json
import os

import pandas as pd
import torch
import numpy as np
from sklearn import preprocessing, decomposition, metrics, impute
import torch, torchvision
from tqdm import tqdm
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

In [2]:
import torch.nn as nn


def train_cls(exp_folder, n_labels, train_X, train_Y, test_X, test_Y, mode, model_save_path):
    model = nn.Sequential(
        nn.Linear(672 if mode == "cp" else 2048, 256),
        nn.BatchNorm1d(256),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(256, n_labels),
    )
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters())
    train_dataset = torch.utils.data.TensorDataset(
        torch.from_numpy(train_X).float(), torch.from_numpy(train_Y).long()
    )
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=256, shuffle=True)

    test_dataset = torch.utils.data.TensorDataset(
        torch.from_numpy(test_X).float(), torch.from_numpy(test_Y).long()
    )
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=256, shuffle=False)

    best_accuracy = 0
    accuracies = []
    for i in range(500):
        model.train()
        running_total_loss = 0
        for _, sample in enumerate(train_loader):
            optimizer.zero_grad()
            op = model(sample[0])
            loss = criterion(op, sample[1])
            loss.backward()
            optimizer.step()
            running_total_loss += loss.item()

        if i % 50 == 0:
            model.eval()
            accuracy = 0
            for _, sample in enumerate(test_loader):
                op = model(sample[0])
                loss = criterion(op, sample[1])
                accuracy += (torch.argmax(op, dim=1) == sample[1]).sum().item()
            if accuracy > best_accuracy:
                best_accuracy = accuracy
                torch.save(model.state_dict(), os.path.join(exp_folder, model_save_path))
            print(
                "Epoch: {}, Loss: {}, Accuracy: {}".format(
                    i,
                    running_total_loss / len(test_loader.dataset),
                    accuracy / len(test_loader.dataset),
                )
            )
            accuracies.append(accuracy / len(test_loader.dataset))

    print("Best Accuracy:", np.max(accuracies))
    model.load_state_dict(torch.load(os.path.join(exp_folder, model_save_path)))
    with torch.no_grad():
        model.eval()
        outputs = []
        targets = []
        for _, sample in enumerate(test_loader):
            op = model(sample[0])
            outputs.append(torch.argmax(op, dim=1).cpu().numpy())
            targets.append(sample[1].cpu().numpy())
    return outputs, targets


In [3]:
bf_folders = sorted(glob.glob("/proj/haste_berzelius/exps/specs_non_grit_based/*bf*"))[:5]
exp_cat = "bf_11cls_basic_aug_dmsonorm_750e_sgd"

In [4]:
cp_df_cell = pd.read_csv("stats/non_grit_based/CP_features_cells.csv")
site_conversion = pd.DataFrame(
    {"bf_sites": ["s1", "s2", "s3", "s4", "s5"], "f_sites": ["s2", "s4", "s5", "s6", "s8"]}
)

feature_groups = [
    "AreaShape",
    "Correlation",
    "Granularity",
    "Intensity",
    "Neighbors",
    "RadialDistribution",
]
cp_feature_columns = [c for c in cp_df_cell.columns if c.startswith(tuple(feature_groups))]

for bf_folder in bf_folders:
    exp_folder = os.path.join(bf_folder, "bf_11cls_basic_aug_dmsonorm_750e_sgd", "ResNet_resnet50")
    print(exp_folder)
    if os.path.exists(exp_folder):
        model_train_df = pd.read_csv(os.path.join(exp_folder, "feature_data_train.csv"))
        model_train_df["site"] = model_train_df["site"].map(
            site_conversion.set_index("bf_sites")["f_sites"]
        )
        cp_bf_df = pd.merge(
            model_train_df, cp_df_cell, on=["plate", "well", "compound", "moa", "site"]
        )
        cp_train_df = cp_bf_df[cp_df_cell.columns]

        model_test_df = pd.read_csv(os.path.join(exp_folder, "feature_data_test.csv"))
        model_test_df["site"] = model_test_df["site"].map(
            site_conversion.set_index("bf_sites")["f_sites"]
        )
        cp_bf_df = pd.merge(
            model_test_df, cp_df_cell, on=["plate", "well", "compound", "moa", "site"]
        )
        cp_test_df = cp_bf_df[cp_df_cell.columns]

        test_unique_moas = np.unique(cp_test_df["moa"])
        n_labels = len(test_unique_moas)

        le = preprocessing.LabelEncoder()
        le.fit(cp_train_df["moa"])

        cp_train_df["moa_label"] = cp_train_df["moa"].apply(lambda x: le.transform([x])[0].item())
        cp_test_df["moa_label"] = cp_test_df["moa"].apply(lambda x: le.transform([x])[0].item())

        train_X = cp_train_df[cp_feature_columns].values
        imputer = impute.SimpleImputer(missing_values=np.nan, strategy="mean").fit(train_X)
        train_X = imputer.transform(train_X)
        normalize = preprocessing.StandardScaler().fit(train_X)
        train_X = normalize.transform(train_X)
        train_Y = cp_train_df["moa_label"].values

        test_X = cp_test_df[cp_feature_columns].values
        test_X = imputer.transform(test_X)
        test_X = normalize.transform(test_X)
        test_Y = cp_test_df["moa_label"].values

        outputs, targets = train_cls(
            exp_folder, n_labels, train_X, train_Y, test_X, test_Y, "cp", "cp_moa.pth"
        )

        pred_moa = le.inverse_transform(np.concatenate(outputs))
        moa = le.inverse_transform(np.concatenate(targets))

        report = metrics.classification_report(moa, pred_moa, output_dict=True)
        report_df = pd.DataFrame(report).transpose()
        report_df.to_csv(os.path.join(exp_folder, "cp_moa_report.csv"))

        df = pd.DataFrame(
            {
                "plate": cp_test_df["plate"].values,
                "well": cp_test_df["well"].values,
                "site": cp_test_df["site"].values,
                "compound": cp_test_df["compound"].values,
                "moa": moa,
                "pred_moa": pred_moa,
            }
        )
        df.to_csv(os.path.join(exp_folder, "cp_moa_analysis.csv"), index=False)


/proj/haste_berzelius/exps/specs_non_grit_based/bf_exps_1_split1/bf_11cls_basic_aug_dmsonorm_750e_sgd/ResNet_resnet50
Epoch: 0, Loss: 0.024468150028892532, Accuracy: 0.5632424877707897
Epoch: 50, Loss: 0.0012802896665827033, Accuracy: 0.7484276729559748
Epoch: 100, Loss: 0.0007091910232725384, Accuracy: 0.7540181691125087
Epoch: 150, Loss: 0.0003640659649385846, Accuracy: 0.756114605171209
Epoch: 200, Loss: 0.00023983435244524003, Accuracy: 0.7596086652690426
Epoch: 250, Loss: 0.0003011822628910353, Accuracy: 0.7686932215234102
Epoch: 300, Loss: 0.00018357050297280694, Accuracy: 0.7575122292103424
Epoch: 350, Loss: 0.00020603578313820685, Accuracy: 0.7672955974842768
Epoch: 400, Loss: 0.00016401730579337733, Accuracy: 0.76659678546471
Epoch: 450, Loss: 0.00026435094835414993, Accuracy: 0.7721872816212438
Best Accuracy: 0.7721872816212438
/proj/haste_berzelius/exps/specs_non_grit_based/bf_exps_1_split2/bf_11cls_basic_aug_dmsonorm_750e_sgd/ResNet_resnet50
Epoch: 0, Loss: 0.02372290397602