In [17]:
import json
import sys
from pathlib import Path

import pandas as pd
import torch

# Make the local `maps` checkout importable BEFORE any `maps` import. Previously
# the append ran after the first `from maps...` lines. Those imports then only
# succeeded if `maps` was already installed, or if an earlier run of this cell
# had modified sys.path in the live kernel. The guard keeps re-runs from piling
# up duplicate entries.
MAPS_SRC = "/home/kkumbier/maps/"
if MAPS_SRC not in sys.path:
    sys.path.append(MAPS_SRC)

from maps.screens import ImageScreenMultiAntibody
from maps.multiantibody.config import TrainConfig, ModelConfig, DataLoaderConfig
from maps.multiantibody.data_loaders import create_multiantibody_dataloader
import maps.multiantibody.models as models
import maps.multiantibody.training as training
import maps.multiantibody.evaluate as evaluate

# NOTE(review): absolute, user-specific paths; consider making these configurable.
root = Path("/home/kkumbier/als/scripts/maps/")
param_dir = root / "template_analyses/params/"
result_dir = root / "template_analyses/multimarker_multiclass/4-class"


def load_params(name: str) -> dict:
    """Parse the JSON parameter file `name` located in `param_dir`."""
    with open(param_dir / name, "r") as f:
        return json.load(f)


# --- Initialize parameters ---
train_params = load_params("multiclass-split-train.json")
test_params = load_params("multiclass-split-test.json")
sporadic_params = load_params("multiclass-sporadic.json")

# Mutation label -> integer class index used as the classification target.
response_map = {"WT": 0, "FUS": 1, "SOD1": 2, "C9orf72": 3}
antibodies = ["HSP70/SOD1", "FUS/EEA1", "COX IV/Galectin3/atubulin"]
print(json.dumps(train_params, indent=4))


{
    "name": "maps_dmso",
    "screen": "20250216_AWALS37_Full_screen_n96",
    "root": "/awlab/projects/2024_ALS/Experiments",
    "data_file": "Objects_Population - Nuclei Selected.txt",
    "eval_dir": "Evaluation1",
    "preprocess": {
        "drop_na_features": {
            "na_prop": 0.1
        },
        "drop_sample_by_feature": {
            "drop_key": [
                {
                    "CellLines": [
                        "C9014",
                        "NS048"
                    ]
                }
            ]
        },
        "select_sample_by_feature": {
            "select_key": [
                {
                    "Drugs": [
                        "DMSO"
                    ],
                    "Mutations": [
                        "FUS",
                        "WT",
                        "SOD1",
                        "C9orf72"
                    ],
                    "CellLines": [
                        "NS007",
                        

In [8]:
# Load the train / test / sporadic screens for every marker set, then preprocess them.
train_screen = ImageScreenMultiAntibody(train_params)
train_screen.load(antibody=antibodies)

test_screen = ImageScreenMultiAntibody(test_params)
test_screen.load(antibody=antibodies)

sporadic_screen = ImageScreenMultiAntibody(sporadic_params)
sporadic_screen.load(antibody=antibodies)

for split_name, split_screen in (
    ("training", train_screen),
    ("test", test_screen),
    ("sporadic", sporadic_screen),
):
    print(f"Processing {split_name} set...")
    split_screen.preprocess()

for ab in antibodies:
    print(f"Marker set: {ab}")
    print(f"Train data: {train_screen.data[ab].shape}")
    print(f"Test data: {test_screen.data[ab].shape}")
    print(f"Sporadic data: {sporadic_screen.data[ab].shape}")
    # Mutations present in the training metadata for *this* marker set
    # (previously hard-coded to "FUS/EEA1" for every iteration).
    print(train_screen.metadata[ab]["Mutations"].unique())

Processing training set...
Preprocessing complete
Processing test set...
Preprocessing complete
Processing sporadic set...
Preprocessing complete
Marker set: HSP70/SOD1
Train data: (46472, 311)
Test data: (53047, 311)
Sporadic data: (13782, 311)
shape: (4,)
Series: 'Mutations' [str]
[
	"C9orf72"
	"SOD1"
	"WT"
	"FUS"
]
Marker set: FUS/EEA1
Train data: (47053, 315)
Test data: (55199, 315)
Sporadic data: (14465, 315)
shape: (4,)
Series: 'Mutations' [str]
[
	"WT"
	"SOD1"
	"FUS"
	"C9orf72"
]
Marker set: COX IV/Galectin3/atubulin
Train data: (44234, 312)
Test data: (51160, 312)
Sporadic data: (14337, 312)
shape: (4,)
Series: 'Mutations' [str]
[
	"SOD1"
	"WT"
	"C9orf72"
	"FUS"
]


In [18]:
# Training-time loader settings: 9 samples per batch, 250 cells per sample,
# with targets taken from the "Mutations" column via `response_map`.
dataloader_config = DataLoaderConfig(
    batch_size=9,
    n_cells=250,
    response_map={"Mutations": response_map},
)

train_dataloader = create_multiantibody_dataloader(train_screen, **vars(dataloader_config))

# Reuse the training loader's scalers for the test loader, so both splits are
# scaled the same way (presumably with statistics from the training data).
scalers = train_dataloader._get_scalers()
test_dataloader = create_multiantibody_dataloader(
    test_screen, scalers=scalers, **vars(dataloader_config)
)

In [10]:
# Short aliases for the project training / evaluation entry points used below.
eval_model = evaluate.eval
train = training.train
MultiAntibodyClassifier = models.MultiAntibodyClassifier

# Initialize training hyper-parameters.
train_config = TrainConfig(n_epochs=100, patience=10, lr=5e-3, use_contrastive_loss=False)

# One output class per entry in `response_map`.
model_config = ModelConfig(d_model=16, n_layers=1, n_classes=len(response_map))
# Per-antibody input feature dimensions are read from the training loader.
model_config.antibody_feature_dims = train_dataloader._get_feature_dims()

# Two classifiers with identical configuration; model_train is built first, then model_test.
model_train, model_test = (MultiAntibodyClassifier(**vars(model_config)) for _ in range(2))

In [11]:
import os

def load_or_train(model, dataloader, weights_path: Path, name: str, train_fn, config) -> None:
    """Load cached weights into `model` if `weights_path` exists; otherwise train and cache.

    Args:
        model: module whose state_dict is loaded or saved.
        dataloader: loader passed to `train_fn` when no cached weights exist.
        weights_path: location of the cached state_dict.
        name: label used in the status message.
        train_fn: callable invoked as `train_fn(model, dataloader, config)`.
        config: training configuration forwarded to `train_fn`.
    """
    if weights_path.exists():
        # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts;
        # load_state_dict then copies the tensors into the model's own parameters.
        model.load_state_dict(torch.load(weights_path, map_location="cpu"))
        print(f"Loaded existing model weights for {name}.")
    else:
        train_fn(model, dataloader, config)
        torch.save(model.state_dict(), weights_path)


# Ensure the output directory exists so torch.save cannot fail after a long training run.
result_dir.mkdir(parents=True, exist_ok=True)

train_wt_path = result_dir / "model_train_weights.pth"
load_or_train(model_train, train_dataloader, train_wt_path, "model_train", train, train_config)

test_wt_path = result_dir / "model_test_weights.pth"
load_or_train(model_test, test_dataloader, test_wt_path, "model_test", train, train_config)

Loaded existing model weights for model_train.
Loaded existing model weights for model_test.


In [19]:
N_EVAL_CELLS = 250
n_reps = 10

# Build per-loader kwargs instead of mutating `dataloader_config`. The old
# in-place update left the sporadic response_map on the shared config, so
# re-running this cell built the train/test loaders with the wrong labels.
eval_kwargs = {**vars(dataloader_config), "n_cells": N_EVAL_CELLS, "mode": "eval"}
# Sporadic lines get label index len(response_map): one past the last trained
# class (classes are 0..len(response_map)-1).
sporadic_kwargs = {
    **eval_kwargs,
    "response_map": {"Mutations": {"sporadic": len(response_map)}},
}

train_dataloader = create_multiantibody_dataloader(train_screen, scalers=scalers, **eval_kwargs)
test_dataloader = create_multiantibody_dataloader(test_screen, scalers=scalers, **eval_kwargs)
sporadic_dataloader = create_multiantibody_dataloader(
    sporadic_screen, scalers=scalers, **sporadic_kwargs
)

prob_columns = {f"class_{i}_agg": f"prob_{i}" for i in range(len(response_map))}

# Each model scores the split it was not trained on (model_train -> test,
# model_test -> train); both models score the sporadic lines.
# TODO(review): no RNG seed is set, so repeated reps are not reproducible if the
# loaders sample cells randomly — confirm which RNG the loaders use and seed it.
df = []
for rep in range(n_reps):
    # Same call order as before, so any sampling RNG is consumed identically.
    df_test, _ = eval_model(model_train, test_dataloader)
    df_train, _ = eval_model(model_test, train_dataloader)
    df_sporadic_test, _ = eval_model(model_test, sporadic_dataloader)
    df_sporadic_train, _ = eval_model(model_train, sporadic_dataloader)

    # Tag every row with its repetition and the model that produced it, so the
    # two models' sporadic predictions stay distinguishable in df_full.
    df_rep = pd.concat([
        df_train.assign(Rep=rep, Model="model_test"),
        df_test.assign(Rep=rep, Model="model_train"),
        df_sporadic_test.assign(Rep=rep, Model="model_test"),
        df_sporadic_train.assign(Rep=rep, Model="model_train"),
    ]).rename(columns=prob_columns)

    df.append(df_rep)

In [None]:
import pickle

# Stack all repetitions. `df` (the per-rep list) is left untouched: rebinding it
# to a DataFrame made a second run of `pd.concat(df)` raise a TypeError.
df_full = pd.concat(df)

# Mean of every numeric column (e.g. prob_i, Rep) per true label and cell line.
df_summary = (
    df_full
    .rename(columns={"True": "Label"})
    .groupby(["Label", "CellLines"])
    .mean(numeric_only=True)
    .reset_index()
)

with open(result_dir / "train_test_predictions.pkl", "wb") as f:
    pickle.dump({"df": df_summary, "df_full": df_full}, f)
