In [1]:
from tqdm import tqdm

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, f1_score, precision_score, recall_score
import cytoflow as flow
import torch
from collections import Counter

In [2]:
basepath = "./data/"
tube1 = flow.Tube(file = basepath + '494.csv')
ex = flow.ImportOp(tubes = [tube1]).apply()

cd19_thresh = flow.ThresholdOp(name = "CD19pos",channel = 'CD19', threshold = 1.5)
ex2 = cd19_thresh.apply(ex)

human_conditions = flow.BulkConditionOp(conditions_csv_path =  basepath + '494_labels.csv',
                                      combine_order = ["syto", "singlets", "intact","cd19", "blast"],
                                      combined_conditions_name="human_gt",
                                      combined_condition_default ="other")

ex3 = human_conditions.apply(ex2)

markers_to_use = ex3.channels[1:-2]
type(markers_to_use)


list

In [3]:
print(markers_to_use)
print(ex3.data)

['CD10', 'CD19', 'CD20', 'CD34', 'CD38', 'CD45', 'FSC-A', 'FSC-W', 'SSC-A']
           TIME     FSC-A     FSC-W     SSC-A      CD20      CD10      CD45  \
0      0.256109  0.724819  0.000000  0.993575  0.511647  0.737258  3.116999   
1      0.256109  2.111998  1.095594  0.789835  2.883300  4.009430  3.395909   
2      0.256109  2.071740  1.094057  4.499983  1.406279  1.910993  2.866903   
3      0.256150  1.923883  1.095217  0.648713  1.031393  4.020207  2.840137   
4      0.256190  1.580198  1.005982  0.781159  1.264268  1.307244  3.775684   
...         ...       ...       ...       ...       ...       ...       ...   
83406  4.499838  4.229612  1.110860  4.499983  1.581865  1.840961  2.920752   
83407  4.499919  1.721509  1.086372  0.589498  2.136454  3.867376  3.045547   
83408  4.499959  1.888425  1.035353  1.412427  3.039960  4.131422  3.501659   
83409  4.500000  1.479483  1.069826  0.567807  1.333666  4.141739  3.046796   
83410  4.500000  1.235837  1.022533  0.782086  2.860944

In [4]:
df = ex3.data 
# Initialize LabelEncoder
label_encoder = LabelEncoder()

# Encode the categorical column
df['human_gt'] = label_encoder.fit_transform(df['human_gt'])

# Separate features and labels
X = df.drop('human_gt', axis=1).values
y = df['human_gt'].values

label_mapping = dict(zip(label_encoder.classes_, range(len(label_encoder.classes_))))

print(label_mapping)

{'blast': 0, 'cd19': 1, 'intact': 2, 'other': 3, 'singlets': 4, 'syto': 5}


In [5]:
events_tensor = torch.Tensor(ex3.data[markers_to_use].to_numpy()).unsqueeze(dim=0)
print(events_tensor)
print(events_tensor.shape)

print(len(markers_to_use))

tensor([[[0.7373, 2.8390, 0.5116,  ..., 0.7248, 0.0000, 0.9936],
         [4.0094, 2.4233, 2.8833,  ..., 2.1120, 1.0956, 0.7898],
         [1.9110, 1.8300, 1.4063,  ..., 2.0717, 1.0941, 4.5000],
         ...,
         [4.1314, 3.5735, 3.0400,  ..., 1.8884, 1.0354, 1.4124],
         [4.1417, 2.5612, 1.3337,  ..., 1.4795, 1.0698, 0.5678],
         [3.8888, 3.3905, 2.8609,  ..., 1.2358, 1.0225, 0.7821]]])
torch.Size([1, 83411, 9])
9


In [6]:
from transformers import AutoModel
flowformer = AutoModel.from_pretrained("matth/flowformer", trust_remote_code=True)

In [7]:
output = flowformer(events_tensor, markers= markers_to_use)
print(output)

{'logits': tensor([[ -2.1165,   1.3967, -10.9895,  ...,   5.8799,   6.8554,   7.1269]],
       grad_fn=<SelectBackward0>), 'prediction': tensor([[0, 1, 0,  ..., 1, 1, 1]])}


In [8]:
output["prediction"] = output["prediction"].squeeze().numpy()

In [9]:
print(output)

{'logits': tensor([[ -2.1165,   1.3967, -10.9895,  ...,   5.8799,   6.8554,   7.1269]],
       grad_fn=<SelectBackward0>), 'prediction': array([0, 1, 0, ..., 1, 1, 1], dtype=int64)}


In [10]:
ex3.add_condition("prediction",dtype="bool", data =output["prediction"])

In [11]:
from sklearn.metrics import classification_report
print(classification_report(ex3["prediction"], df["blast"]))

              precision    recall  f1-score   support

       False       0.99      0.98      0.99     49296
        True       0.98      0.98      0.98     34115

    accuracy                           0.98     83411
   macro avg       0.98      0.98      0.98     83411
weighted avg       0.98      0.98      0.98     83411



In [12]:
from sklearn.metrics import classification_report
print(classification_report(ex3["prediction"], df["cd19"]))

              precision    recall  f1-score   support

       False       0.99      0.90      0.94     49296
        True       0.87      0.98      0.92     34115

    accuracy                           0.93     83411
   macro avg       0.93      0.94      0.93     83411
weighted avg       0.94      0.93      0.93     83411



In [13]:
from sklearn.metrics import classification_report
print(classification_report(ex3["prediction"], df["intact"]))

              precision    recall  f1-score   support

       False       0.93      0.11      0.20     49296
        True       0.43      0.99      0.60     34115

    accuracy                           0.47     83411
   macro avg       0.68      0.55      0.40     83411
weighted avg       0.72      0.47      0.36     83411



In [14]:
from sklearn.metrics import classification_report
print(classification_report(ex3["prediction"], df["singlets"]))

              precision    recall  f1-score   support

       False       0.93      0.09      0.17     49296
        True       0.43      0.99      0.60     34115

    accuracy                           0.46     83411
   macro avg       0.68      0.54      0.39     83411
weighted avg       0.73      0.46      0.35     83411



In [15]:
from sklearn.metrics import classification_report
print(classification_report(ex3["prediction"], df["syto"]))

              precision    recall  f1-score   support

       False       0.99      0.06      0.11     49296
        True       0.42      1.00      0.60     34115

    accuracy                           0.44     83411
   macro avg       0.71      0.53      0.35     83411
weighted avg       0.76      0.44      0.31     83411

