In [152]:
import torch
from torch.utils.data import DataLoader
from cellshape_cloud.vendor.chamfer_distance import ChamferLoss
from cellshape_cloud.pointcloud_dataset import (
    PointCloudDataset,
    SingleCellDataset,
)
from cellshape_cloud.cloud_autoencoder import CloudAutoEncoder
from torch.utils.data import Dataset
import pandas as pd
import os
import torch
from torch.utils.data import Dataset
from sklearn.preprocessing import LabelEncoder
import pandas as pd
import os
from pyntcloud import PyntCloud

path = "/home/mvries/Documents/CellShape/UploadData/cellshapeOutput/Models/cloud_autoencoder/dgcnn_foldinget_50_004_predictions_NEW_09.pt"

In [153]:
model = CloudAutoEncoder(
    num_features=50, k=20, encoder_type="dgcnn", decoder_type="foldingnet"
)
checkpoint = torch.load(path)
model_dict = model.state_dict()  # load parameters from pre-trained FoldingNet

model.load_state_dict(checkpoint["model_state_dict"])
checkpoint["loss"]

47067.47621536255

In [154]:
class SingleCellDatasetNew(Dataset):
    def __init__(
        self,
        annotations_file,
        points_dir,
        img_size=400,
        transform=None,
        cell_component="cell",
    ):
        self.annot_df = pd.read_csv(annotations_file)
        self.img_dir = points_dir
        self.img_size = img_size
        self.transform = transform
        self.cell_component = cell_component

        self.new_df = self.annot_df[
            (self.annot_df.xDim <= self.img_size)
            & (self.annot_df.yDim <= self.img_size)
            & (self.annot_df.zDim <= self.img_size)
        ].reset_index(drop=True)

    def __len__(self):
        return len(self.new_df)

    def __getitem__(self, idx):
        # read the image
        treatment = self.new_df.loc[idx, "Treatment"]
        plate_num = "Plate" + str(self.new_df.loc[idx, "PlateNumber"])
        if self.cell_component == "cell":
            component_path = "stacked_pointcloud"
        else:
            component_path = "stacked_pointcloud_nucleus"

        img_path = os.path.join(
            self.img_dir,
            plate_num,
            component_path,
            treatment,
            self.new_df.loc[idx, "serialNumber"],
        )
        image = PyntCloud.from_file(img_path + ".ply")
        image = torch.tensor(image.points.values)
        mean = torch.tensor([[13.4828, 26.5144, 24.4187]])
        std = torch.tensor([[9.2821, 20.4512, 18.9049]])

        image = (image - mean) / std
        # return the classical features as torch tensor
        feats = self.new_df.iloc[idx, 16:-4]
        feats = torch.tensor(feats)

        serial_number = self.new_df.loc[idx, "serialNumber"]

        return image, treatment, feats, serial_number

In [155]:
root_dir = "/home/mvries/Documents/CellShape/UploadData/cellshapeData/"
df = "/home/mvries/Documents/CellShape/UploadData/cellshapeData/all_cell_data.csv"
dataset = SingleCellDatasetNew(df, root_dir, cell_component="cell")

dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

In [156]:
from tqdm import tqdm

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
all_feat_cell = []
all_outputs = []
all_inputs = []
labels_cell = []
serial_numbers_cell = []
for data in tqdm(dataloader):
    inputs = data[0]
    lab = data[1]
    ser_num = data[3]
    inputs = inputs.to(device)
    batch_size = inputs.shape[0]

    output, features = model(inputs)
    all_inputs.append(torch.squeeze(inputs).detach().cpu().numpy())
    all_outputs.append(torch.squeeze(output).detach().cpu().numpy())
    all_feat_cell.append(torch.squeeze(features).detach().cpu().numpy())
    labels_cell.append(lab[0])
    serial_numbers_cell.append(ser_num[0])

100%|█████████████████████████████████████| 70167/70167 [12:28<00:00, 93.74it/s]


In [157]:
all_feat_cell

[array([ 1.3359473 , -0.06429   ,  0.22371387,  0.5712923 ,  0.33000588,
         0.30607045, -0.5742329 , -0.35096377, -0.12962812, -1.132867  ,
         1.1851256 , -0.07516593, -3.9351785 , -0.2788843 ,  0.64109325,
        -0.5488504 , -1.0730788 ,  0.31352624,  4.76736   , -0.93880916,
         1.2289758 , -0.30547446,  0.03359385, -0.50677574,  0.47813702,
         3.6397557 , -1.6570972 ,  0.4893663 , -0.9289645 , -0.3153311 ,
        -0.20028986,  0.5588279 , -0.15236028, -0.18528132,  0.32523888,
         0.26481003,  3.6079023 ,  0.04792903,  2.6113026 , -0.22071916,
        -1.6375008 , -1.5316108 ,  0.20495823,  0.18895158,  0.20828205,
         0.28209603,  0.76343197,  0.20474301,  3.3204126 , -0.24691856],
       dtype=float32),
 array([ 1.1425676 ,  0.2783515 ,  0.22087081, -0.11612812, -0.00518867,
         0.29018375,  0.15568402,  0.48670694,  0.5190521 ,  0.26828158,
         0.15065035, -0.7548012 , -3.5793135 , -0.23989885, -0.27207708,
         0.04442501, -0.714

In [158]:
path = "/home/mvries/Documents/CellShape/UploadData/cellshapeOutput/Models/cloud_autoencoder/dgcnn_foldinget_50_004_predictions_NEW_nucleus_continue_001.pt"

model = CloudAutoEncoder(
    num_features=50, k=20, encoder_type="dgcnn", decoder_type="foldingnet"
)
checkpoint = torch.load(path)
model_dict = model.state_dict()  # load parameters from pre-trained FoldingNet

model.load_state_dict(checkpoint["model_state_dict"])
checkpoint["loss"]

26.536249787585078

In [159]:
root_dir = "/home/mvries/Documents/Datasets/OPM/SingleCellFromNathan_17122021"
df = "/home/mvries/Documents/CellShape/UploadData/cellshapeData/all_cell_data.csv"
dataset = SingleCellDatasetNew(
    df, root_dir, transform=None, img_size=400, cell_component="nucleus"
)

dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
all_feat_nuc = []
all_outputs = []
all_inputs = []
labels_nuc = []
serial_numbers_nuc = []
for data in tqdm(dataloader):
    inputs = data[0]
    lab = data[1]
    ser_num = data[3]
    inputs = inputs.to(device)
    batch_size = inputs.shape[0]

    output, features = model(inputs)
    all_inputs.append(torch.squeeze(inputs).detach().cpu().numpy())
    all_outputs.append(torch.squeeze(output).detach().cpu().numpy())
    all_feat_nuc.append(torch.squeeze(features).detach().cpu().numpy())
    labels_nuc.append(lab[0])
    serial_numbers_nuc.append(ser_num[0])

100%|█████████████████████████████████████| 70167/70167 [11:44<00:00, 99.56it/s]


In [160]:
all_feat_nuc

[array([ 2.3637576 ,  0.28022018,  0.28023568,  0.61973786,  0.56297684,
         0.2781246 ,  0.31925946,  0.8148311 ,  0.74741375,  0.7462992 ,
        -0.33187196,  0.08679578, -3.9605432 ,  0.35818487,  0.3016377 ,
         0.31826806, -0.49990255,  0.25168878,  5.9917855 , -2.2914166 ,
         2.1238441 ,  0.4318273 , -0.6944669 ,  0.7228191 ,  0.29171187,
         3.3658185 , -2.3338315 , -0.21978149, -0.25569957,  0.08710444,
         0.24497548,  0.50384355, -0.9187745 ,  0.50921583,  0.36475134,
         0.6545639 ,  4.1638684 , -0.23605359,  1.7886207 , -0.3041696 ,
        -2.393606  , -0.49816036,  2.019983  , -0.9356307 , -0.09467918,
        -0.7483746 ,  0.43893555, -0.51032984,  3.343973  ,  0.32918862],
       dtype=float32),
 array([ 1.6500531 ,  0.2441522 ,  1.0898392 ,  1.3927075 ,  0.23122251,
         0.26053086,  0.6840242 ,  0.6685614 ,  0.28710136,  0.43073237,
        -0.09589708,  0.31438297, -3.8451557 , -0.06846784,  0.36615312,
         0.26210892, -0.563

In [161]:
# Need to join cell and nuc features and then do predictions on those

In [162]:
all_cell_df = pd.DataFrame(all_feat_cell)
all_cell_df["serialNumber"] = serial_numbers_cell

In [163]:
all_nuc_df = pd.DataFrame(all_feat_nuc)
all_nuc_df["serialNumber"] = serial_numbers_nuc

In [164]:
nucCols = []
for i in range(50):
    nucCols.append(f"nuc{i}")

nucCols.append("serialNumber")

In [165]:
all_nuc_df.columns = nucCols
all_nuc_df

Unnamed: 0,nuc0,nuc1,nuc2,nuc3,nuc4,nuc5,nuc6,nuc7,nuc8,nuc9,...,nuc41,nuc42,nuc43,nuc44,nuc45,nuc46,nuc47,nuc48,nuc49,serialNumber
0,2.363758,0.280220,0.280236,0.619738,0.562977,0.278125,0.319259,0.814831,0.747414,0.746299,...,-0.498160,2.019983,-0.935631,-0.094679,-0.748375,0.438936,-0.510330,3.343973,0.329189,0001_0001_accelerator_20210315_bakal01_erk_mai...
1,1.650053,0.244152,1.089839,1.392707,0.231223,0.260531,0.684024,0.668561,0.287101,0.430732,...,-0.782393,1.795784,-0.057592,-0.026909,-0.825139,1.147998,-0.309477,2.948443,0.255674,0001_0002_accelerator_20210315_bakal01_erk_mai...
2,2.433488,0.177738,0.050791,-0.240343,0.242407,0.081840,0.085129,0.096272,0.666577,0.459039,...,0.571963,1.053843,-0.469541,-0.055329,-0.105974,0.262963,-0.421813,2.255110,-0.136864,0001_0003_accelerator_20210315_bakal01_erk_mai...
3,1.633378,-0.142931,0.842870,0.991155,0.015454,0.023807,0.803087,0.157479,0.066049,0.536879,...,-0.296226,1.653841,-0.572729,0.055531,-0.670236,0.234754,-0.499244,2.744647,0.019480,0001_0004_accelerator_20210315_bakal01_erk_mai...
4,1.400861,-0.204258,0.212922,0.799014,0.616353,0.354219,0.579634,0.325082,-0.375879,-0.155254,...,-0.968143,1.080469,-2.237258,0.433220,-1.322754,-0.862850,-0.588138,3.848569,1.258221,0001_0005_accelerator_20210315_bakal01_erk_mai...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
70162,2.674400,0.028563,1.648065,0.967968,0.020426,0.407016,0.763635,0.560705,-0.847494,1.732055,...,-0.605893,1.123100,0.451426,0.365825,-0.298832,1.051046,-0.029756,4.062128,-0.566216,0148_0149_accelerator_20210318_bakal03_erk_mai...
70163,0.744649,-0.605178,1.848607,0.310854,-0.707326,-0.239575,0.771276,-0.279614,-1.759205,0.610468,...,-1.850919,0.002862,-0.401695,0.213863,-1.002397,0.795037,0.717941,4.103014,-0.060406,0148_0150_accelerator_20210318_bakal03_erk_mai...
70164,4.053229,0.728206,-0.805661,-0.205370,-0.278708,0.422023,-0.205914,0.023630,-0.376416,0.028567,...,-0.291311,-0.800661,-0.785478,0.126535,0.468476,-0.842588,-0.447983,1.196290,0.554765,0148_0151_accelerator_20210318_bakal03_erk_mai...
70165,-0.924818,-0.445705,0.926993,0.639158,0.062220,0.243191,0.576030,1.101720,-1.079200,0.284493,...,-0.488579,0.221143,-0.345841,0.233913,-0.728334,0.223090,-0.078949,2.601056,0.474541,0148_0152_accelerator_20210318_bakal03_erk_mai...


In [166]:
all_df = all_cell_df.join(
    all_nuc_df.set_index("serialNumber"), on="serialNumber"
)
all_df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,nuc40,nuc41,nuc42,nuc43,nuc44,nuc45,nuc46,nuc47,nuc48,nuc49
0,1.335947,-0.064290,0.223714,0.571292,0.330006,0.306070,-0.574233,-0.350964,-0.129628,-1.132867,...,-2.393606,-0.498160,2.019983,-0.935631,-0.094679,-0.748375,0.438936,-0.510330,3.343973,0.329189
1,1.142568,0.278351,0.220871,-0.116128,-0.005189,0.290184,0.155684,0.486707,0.519052,0.268282,...,-3.046510,-0.782393,1.795784,-0.057592,-0.026909,-0.825139,1.147998,-0.309477,2.948443,0.255674
2,0.536568,0.535183,0.476276,-0.625985,-0.303703,-0.208633,0.814239,-0.092690,1.185713,0.323764,...,-0.521098,0.571963,1.053843,-0.469541,-0.055329,-0.105974,0.262963,-0.421813,2.255110,-0.136864
3,2.690795,-0.553936,0.131338,0.886677,0.220248,0.359097,0.475790,-0.600473,0.271425,-0.689685,...,-1.872417,-0.296226,1.653841,-0.572729,0.055531,-0.670236,0.234754,-0.499244,2.744647,0.019480
4,1.297160,0.063407,-0.139035,0.669093,-0.159644,0.063092,-0.345914,0.657981,-0.243066,0.321983,...,-3.626510,-0.968143,1.080469,-2.237258,0.433220,-1.322754,-0.862850,-0.588138,3.848569,1.258221
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
70162,1.531698,-0.224410,0.296207,1.026213,-0.136934,0.462338,0.164994,0.577112,0.006295,0.223636,...,-3.756241,-0.605893,1.123100,0.451426,0.365825,-0.298832,1.051046,-0.029756,4.062128,-0.566216
70163,1.799083,0.071390,0.831331,0.628335,-0.710662,0.633382,-0.802400,-0.328795,-0.627607,0.074306,...,-2.884280,-1.850919,0.002862,-0.401695,0.213863,-1.002397,0.795037,0.717941,4.103014,-0.060406
70164,-0.578693,0.127953,-0.730597,0.431282,-0.296105,0.256308,-0.973248,-1.452100,-1.401660,-0.818603,...,-1.742182,-0.291311,-0.800661,-0.785478,0.126535,0.468476,-0.842588,-0.447983,1.196290,0.554765
70165,-1.318052,-0.361149,0.414458,0.904969,0.341739,0.875744,0.069388,0.337239,0.104456,0.465414,...,-2.473555,-0.488579,0.221143,-0.345841,0.233913,-0.728334,0.223090,-0.078949,2.601056,0.474541


In [148]:
all_df["Treatment"] = labels_cell

In [150]:
all_df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,nuc41,nuc42,nuc43,nuc44,nuc45,nuc46,nuc47,nuc48,nuc49,Treatment
0,2.376747,0.273255,1.773725,-0.202388,-0.363287,-0.558336,-0.363920,-0.610801,-1.700994,0.022208,...,-0.498160,2.019983,-0.935631,-0.094679,-0.748375,0.438936,-0.510330,3.343973,0.329189,Palbociclib
1,1.948957,0.290972,1.646176,-0.697013,-0.083165,-0.216205,-0.482494,1.184853,-0.599467,1.105442,...,-0.782393,1.795784,-0.057592,-0.026909,-0.825139,1.147998,-0.309477,2.948443,0.255674,Palbociclib
2,2.184104,0.860687,1.196558,-2.020930,-1.010935,-0.641888,0.834171,-0.195172,-0.633919,2.608234,...,0.571963,1.053843,-0.469541,-0.055329,-0.105974,0.262963,-0.421813,2.255110,-0.136864,Palbociclib
3,5.035021,-1.049980,1.713561,0.652896,-1.705227,0.133653,1.583452,0.278779,-1.961247,-0.468504,...,-0.296226,1.653841,-0.572729,0.055531,-0.670236,0.234754,-0.499244,2.744647,0.019480,Palbociclib
4,1.954101,0.739590,0.237623,-0.080215,-1.205235,-0.640888,-0.347276,0.858890,-1.054497,0.387034,...,-0.968143,1.080469,-2.237258,0.433220,-1.322754,-0.862850,-0.588138,3.848569,1.258221,Palbociclib
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
275,4.318625,1.286473,0.585905,-0.986787,-1.134113,-0.637686,1.540218,-0.986506,-0.697261,-0.065699,...,-0.491059,0.953977,-1.564872,-0.089829,-0.726465,-0.633800,-0.560001,3.131103,0.935712,Palbociclib
276,-0.578203,-0.520079,1.854106,2.148174,-0.834212,-0.473241,0.199134,1.398717,-2.806431,1.851045,...,-0.623942,1.810594,0.463510,-0.310281,-0.588479,1.580210,-0.429272,2.518702,-0.368702,Palbociclib
277,2.433918,0.537198,1.816412,0.270785,-0.490664,-0.043441,0.125383,-0.590728,-1.189746,0.713449,...,-0.820707,1.339987,-2.112126,0.353678,-1.549345,-0.911284,-0.655037,3.427647,1.323627,Palbociclib
278,-2.692635,-2.221163,1.320979,-1.171398,0.592593,-0.314039,-1.392639,2.506296,-0.981529,-0.760213,...,-0.853007,0.676412,-0.007248,0.121417,-0.747347,0.145947,-0.280917,0.695532,0.769288,Palbociclib


In [87]:
cellCols = [i for i in range(50)]
for i in range(50):
    cellCols.append(nucCols[i])

cellCols.append("serialNumber")
cellCols.append("Treatment")

In [89]:
all_df = all_df[cellCols]

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,nuc42,nuc43,nuc44,nuc45,nuc46,nuc47,nuc48,nuc49,serialNumber,Treatment
0,1.176457,0.036573,0.144208,0.974561,0.190976,-0.051801,-0.456669,-0.280339,0.313647,-1.101814,...,1.049302,-2.218610,0.090517,-1.944057,-0.844360,-0.533889,2.106393,1.966742,0001_0001_accelerator_20210315_bakal01_erk_mai...,Palbociclib
1,1.042657,0.279525,0.216832,0.511881,-0.057861,-0.024996,0.237815,0.601991,0.617132,0.073759,...,1.040648,-2.163831,0.183851,-1.928476,-0.782567,-0.487678,2.048687,1.972477,0001_0002_accelerator_20210315_bakal01_erk_mai...,Palbociclib
2,0.244490,0.691841,0.482207,-0.103967,-0.420336,-0.277397,0.836043,0.092463,1.350604,-0.120912,...,0.861794,-2.304255,0.161807,-1.877121,-1.007296,-0.632867,2.227048,1.996529,0001_0003_accelerator_20210315_bakal01_erk_mai...,Palbociclib
3,2.582871,-0.598855,0.085096,1.322022,0.045077,0.016280,0.213114,-0.410044,0.413402,-1.211520,...,0.974015,-2.377264,0.131908,-1.858805,-0.906936,-0.571006,2.235425,1.940970,0001_0004_accelerator_20210315_bakal01_erk_mai...,Palbociclib
4,1.209503,0.096136,-0.107168,0.919034,-0.340047,-0.163352,-0.212832,0.680159,0.039457,0.288753,...,0.981275,-2.222528,0.060697,-1.877372,-0.851041,-0.462084,2.023623,1.987442,0001_0005_accelerator_20210315_bakal01_erk_mai...,Palbociclib
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
70162,1.549778,-0.172052,0.335738,1.545157,-0.158937,0.176326,0.356563,0.686507,0.429516,0.078597,...,1.079316,-2.118956,0.031175,-1.856185,-1.017741,-0.583320,2.290260,2.004452,0148_0149_accelerator_20210318_bakal03_erk_mai...,No Treatment
70163,1.635872,-0.007548,0.787211,0.811775,-0.734099,0.242518,-0.593403,-0.195167,-0.291073,0.016119,...,0.947219,-2.429381,0.066181,-1.854842,-0.940333,-0.465450,2.158909,1.881860,0148_0150_accelerator_20210318_bakal03_erk_mai...,No Treatment
70164,-0.819511,0.094918,-0.401514,0.119053,-0.233743,0.002145,-0.899988,-1.436586,-1.284881,-0.952961,...,1.045491,-2.328006,0.154216,-1.868412,-0.919118,-0.462811,2.245556,1.888711,0148_0151_accelerator_20210318_bakal03_erk_mai...,No Treatment
70165,-0.844722,-0.556179,0.511987,1.080888,0.354778,0.645295,0.325304,0.512023,0.410016,0.396847,...,1.152191,-2.113703,0.075276,-1.816274,-0.984393,-0.636000,2.231858,1.968197,0148_0152_accelerator_20210318_bakal03_erk_mai...,No Treatment


In [97]:
df = pd.read_csv(df)

In [98]:
df["Proximal"]

0        0
1        0
2        0
3        0
4        0
        ..
70363    0
70364    1
70365    0
70366    1
70367    0
Name: Proximal, Length: 70368, dtype: int64

In [100]:
all_df_new = all_df.join(
    df[["Proximal", "serialNumber"]].set_index("serialNumber"),
    on="serialNumber",
)

In [101]:
all_df_new

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,nuc43,nuc44,nuc45,nuc46,nuc47,nuc48,nuc49,serialNumber,Treatment,Proximal
0,1.176457,0.036573,0.144208,0.974561,0.190976,-0.051801,-0.456669,-0.280339,0.313647,-1.101814,...,-2.218610,0.090517,-1.944057,-0.844360,-0.533889,2.106393,1.966742,0001_0001_accelerator_20210315_bakal01_erk_mai...,Palbociclib,0
1,1.042657,0.279525,0.216832,0.511881,-0.057861,-0.024996,0.237815,0.601991,0.617132,0.073759,...,-2.163831,0.183851,-1.928476,-0.782567,-0.487678,2.048687,1.972477,0001_0002_accelerator_20210315_bakal01_erk_mai...,Palbociclib,0
2,0.244490,0.691841,0.482207,-0.103967,-0.420336,-0.277397,0.836043,0.092463,1.350604,-0.120912,...,-2.304255,0.161807,-1.877121,-1.007296,-0.632867,2.227048,1.996529,0001_0003_accelerator_20210315_bakal01_erk_mai...,Palbociclib,0
3,2.582871,-0.598855,0.085096,1.322022,0.045077,0.016280,0.213114,-0.410044,0.413402,-1.211520,...,-2.377264,0.131908,-1.858805,-0.906936,-0.571006,2.235425,1.940970,0001_0004_accelerator_20210315_bakal01_erk_mai...,Palbociclib,0
4,1.209503,0.096136,-0.107168,0.919034,-0.340047,-0.163352,-0.212832,0.680159,0.039457,0.288753,...,-2.222528,0.060697,-1.877372,-0.851041,-0.462084,2.023623,1.987442,0001_0005_accelerator_20210315_bakal01_erk_mai...,Palbociclib,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
70162,1.549778,-0.172052,0.335738,1.545157,-0.158937,0.176326,0.356563,0.686507,0.429516,0.078597,...,-2.118956,0.031175,-1.856185,-1.017741,-0.583320,2.290260,2.004452,0148_0149_accelerator_20210318_bakal03_erk_mai...,No Treatment,0
70163,1.635872,-0.007548,0.787211,0.811775,-0.734099,0.242518,-0.593403,-0.195167,-0.291073,0.016119,...,-2.429381,0.066181,-1.854842,-0.940333,-0.465450,2.158909,1.881860,0148_0150_accelerator_20210318_bakal03_erk_mai...,No Treatment,1
70164,-0.819511,0.094918,-0.401514,0.119053,-0.233743,0.002145,-0.899988,-1.436586,-1.284881,-0.952961,...,-2.328006,0.154216,-1.868412,-0.919118,-0.462811,2.245556,1.888711,0148_0151_accelerator_20210318_bakal03_erk_mai...,No Treatment,0
70165,-0.844722,-0.556179,0.511987,1.080888,0.354778,0.645295,0.325304,0.512023,0.410016,0.396847,...,-2.113703,0.075276,-1.816274,-0.984393,-0.636000,2.231858,1.968197,0148_0152_accelerator_20210318_bakal03_erk_mai...,No Treatment,1


In [102]:
import pandas as pd
from sklearn import svm
from sklearn.model_selection import cross_val_score, KFold, train_test_split
from sklearn.metrics import plot_confusion_matrix, f1_score
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
import seaborn as sns
import collections
import os

%matplotlib inline
scalar = StandardScaler()
feats_not_centred = scalar.fit_transform(all_df_new.iloc[:, :-3])
xy = pd.DataFrame(feats_not_centred)
xy["Treatment"] = all_df_new["Treatment"]
xy["Proximal"] = all_df_new["Proximal"].values

blebb_noc = xy[
    (xy["Treatment"] == "Blebbistatin") | (xy["Treatment"] == "Nocodazole")
]


X = blebb_noc.iloc[:, :100]
y = blebb_noc["Treatment"]
clf = svm.SVC(kernel="linear", C=1, random_state=0, class_weight="balanced")
kf = KFold(10, shuffle=True, random_state=0)
scores = cross_val_score(clf, X, y, cv=kf)
print("Using cell and nucleus features in the proximal and distal environment")
print(f"{scores.mean()} accuracy with a standard deviation of {scores.std()}")
print("======================================================================")



Using cell and nucleus features in the proximal and distal environment
0.7940595345247448 accuracy with a standard deviation of 0.013507155368720476
