In [1]:
import pandas as pd
import torch

from maps.screens import ImageScreenMultiAntibody
from maps.multiantibody.data_loaders import create_multiantibody_dataloader
from maps.multiantibody.models import MultiAntibodyClassifier
from maps.multiantibody.training import train
from maps.multiantibody.config import TrainConfig, ModelConfig, DataLoaderConfig

import json
import wandb
from sklearn.metrics import roc_auc_score

# --- Load run parameters and create default configs ---
# `pdir` ends with a slash, so the file name is appended by plain string
# concatenation.
pdir = "/home/kkumbier/als/scripts/maps/template_analyses/params/"
params_path = pdir + "maps_multiantibody.json"
with open(params_path) as f:
    params = json.load(f)

# Configs are default-constructed; the only override made here is the number
# of classes on the model config.
dataloader_config = DataLoaderConfig()
train_config = TrainConfig()
model_config = ModelConfig()
model_config.n_classes = 2

# Echo the loaded parameters so the run settings are recorded in the output.
print(json.dumps(params, indent=4))


{
    "name": "maps_dmso",
    "screen": "20250216_AWALS37_Full_screen_n96",
    "root": "/awlab/projects/2024_ALS/Experiments",
    "data_file": "Objects_Population - Nuclei Selected.txt",
    "eval_dir": "Evaluation1",
    "preprocess": {
        "drop_na_features": {
            "na_prop": 0.1
        },
        "drop_sample_by_feature": {
            "drop_key": [
                {
                    "CellLines": [
                        "C9014",
                        "NS048"
                    ]
                }
            ]
        },
        "select_sample_by_feature": {
            "select_key": [
                {
                    "Drugs": [
                        "DMSO"
                    ],
                    "Mutations": [
                        "FUS",
                        "WT"
                    ]
                }
            ]
        },
        "drop_cells_by_feature_qt": {
            "feature_filters": {
                "Nucleus_Region_Area_[\u00b5m\

In [None]:
# Create and load screen
# The screen is constructed from `params`, only the "FUS/EEA1" antibody is
# loaded, and then the screen's own preprocess step is run.
# NOTE(review): preprocess() presumably applies the "preprocess" section of
# `params` — confirm against ImageScreenMultiAntibody.
screen = ImageScreenMultiAntibody(params)
screen.load(antibody=["FUS/EEA1"])
screen.preprocess()

In [3]:
# Build the dataloader from `screen`; every attribute of `dataloader_config`
# is forwarded to the factory as a keyword argument.
dataloader = create_multiantibody_dataloader(screen, **vars(dataloader_config))

In [4]:
# Initialize model and train
# The feature dimensions reported by the dataloader are stored on
# `model_config` before the classifier is built from the config's attributes;
# training then runs with `train_config`.
# NOTE(review): _get_feature_dims() is an underscore-prefixed (private) method
# of the dataloader; prefer a public accessor if one exists.
# NOTE(review): no RNG seed is set in the visible cells, so unless train() or
# the model seeds internally the loss trace may not be reproducible — confirm.
model_config.antibody_feature_dims = dataloader._get_feature_dims()
model = MultiAntibodyClassifier(**vars(model_config))
train(model, dataloader, train_config)

Epoch 1/10, Loss: 1.00300931930542
Epoch 2/10, Loss: 1.1400103569030762
Epoch 3/10, Loss: 0.9668740034103394
Epoch 4/10, Loss: 1.0532829761505127
Epoch 5/10, Loss: 1.078003168106079
Epoch 6/10, Loss: 0.9099025726318359
Epoch 7/10, Loss: 1.0770238637924194
Epoch 8/10, Loss: 1.065738558769226
Epoch 9/10, Loss: 0.8835322856903076
Epoch 10/10, Loss: 0.9150745272636414


In [5]:
# Summarise how the line head's linear layer distributes weight over its
# input columns: take absolute weights, softmax along the last dimension,
# then sum over dimension 0 to get one value per column.
line_head = model.line_head
feature_weight = (
    line_head.fc.weight
    .abs()
    .softmax(dim=-1)
    .sum(dim=0)
)
print(feature_weight.shape)

# Compare the total weight on the first 32 columns with the remainder.
# NOTE(review): the split at 32 presumably separates two feature groups of the
# line head's input — confirm against the model definition.
print(f"Group1: {feature_weight[:32].sum()}")
print(f"Group2: {feature_weight[32:].sum()}")

torch.Size([64])
Group1: 1.0110417604446411
Group2: 0.9889581799507141


In [6]:
# Score every batch with the trained model and tabulate the class
# probabilities next to the labels and cell-line identifiers.
model.eval()
dataloader.mode = "eval"  # Set dataloader to evaluation mode
device = next(model.parameters()).device  # Get device from model parameters
print(device)

all_probs, all_labels, all_lines = [], [], []

with torch.no_grad():
    for batch in dataloader:
        # The loader can yield None; such batches are skipped.
        if batch is None:
            continue

        # Inputs: element 0 of each antibody's entry, moved to the model device.
        x_dict = {ab: batch[ab][0].to(device) for ab in batch}

        # The label (element 1) and the cell-line id (last element) are read
        # from the first antibody's entry only.
        first_ab = list(batch.keys())[0]
        y_line = batch[first_ab][1].to(device)

        # The model returns a pair; only the second item (line logits) is used.
        _, line_logits = model(x_dict)
        probs = line_logits.softmax(dim=1)

        all_probs.append(probs.cpu())
        all_labels.append(y_line.cpu())
        all_lines.append(batch[first_ab][-1])

all_probs = torch.cat(all_probs, dim=0)
all_labels = torch.cat(all_labels, dim=0)

# One probability column per class, plus the cell line and true label.
preds = pd.DataFrame(
    all_probs.numpy(),
    columns=[f"Class_{i}" for i in range(model_config.n_classes)],
)
preds["CellLines"] = all_lines
preds["True"] = all_labels.numpy()

cuda:0


In [7]:
# List the predictions ordered from highest to lowest Class_0 probability.
print(preds.sort_values(by="Class_0", ascending=False))

     Class_0   Class_1 CellLines  True
11  0.069124  0.930876     FTD75     0
16  0.054935  0.945065   ND29563     1
13  0.046083  0.953917    GIH158     0
3   0.041696  0.958304     ALS18     0
7   0.038726  0.961274  CRL-2201     0
27  0.034986  0.965014     NS045     1
4   0.033842  0.966158     ALS20     0
8   0.032774  0.967226  CRL-2510     0
30  0.028189  0.971811    SS2032     1
0   0.026562  0.973438        11     0
25  0.025578  0.974422     NS036     1
1   0.021630  0.978370      ALS1     0
26  0.017453  0.982547     NS044     1
10  0.016219  0.983781     FTD73     0
12  0.015109  0.984891   GIH-130     1
5   0.014874  0.985126     ALS24     0
28  0.013457  0.986543     NS046     1
15  0.012260  0.987740     GIH91     0
23  0.011152  0.988848     NS030     1
22  0.011111  0.988889     NS020     1
14  0.011092  0.988908    GIH166     0
20  0.011073  0.988927     NS003     1
6   0.009591  0.990409     ALS45     0
2   0.009372  0.990628     ALS15     0
18  0.008753  0.991247   