In [1]:
import torch

def make_weights_for_balanced_classes_split(dataset):
    N = float(len(dataset))                                           
    weight_per_class = [N/len(dataset.label[c]) for c in range(len(dataset.label))]                                                                                                     
    weight = [0] * int(N)                                           
    for idx in range(len(dataset)):   
        y = dataset.getlabel(idx)                        
        weight[idx] = weight_per_class[y]                                  

    return torch.DoubleTensor(weight)

In [41]:
import pandas as pd
import numpy as np
import cv2
import os
import torch
from torch.utils.data import Dataset
from sklearn import preprocessing
import random
from pathlib import Path
import tifffile as tfl


class CellData(Dataset):
    """
    Dataset class for the single cell dataset
    """

    def __init__(
        self,
        h5_path=None,
        csv_path=None,
        state=None,
        shuffle=False,
        drug_label="Blebbistatin",
    ):
        # Set all input args as attributes
        self.__dict__.update(locals())
        self.h5_path = h5_path
        self.csv_path = csv_path
        self.shuffle = shuffle
        self.drug_label = drug_label

        self.slide_data = pd.read_csv(self.csv_path)
        labels = {
            "Binimetinib": 0,
            "Blebbistatin": 0,
            "CK666": 0,
            "DMSO": 0,
            "H1152": 0,
            "MK1775": 0,
            "No Treatment": 0,
            "Nocodazole": 0,
            "PF228": 0,
            "Palbociclib": 0,
        }
        self.label_dict = {k: 1 if k == self.drug_label else 0 for k in labels}

        # ---->split dataset
        self.state = state
        self.pwn = np.unique(
            self.slide_data[
                (self.slide_data["Splits"] == self.state)  # &
                #             (self.slide_data['Class']!='Trials') &
                #             (self.slide_data['Class']!='Successful')
            ]["PlateWellNum"]
            .reset_index(drop=True)
            .values
        )
        self.label = self.slide_data[
            (self.slide_data["Splits"] == self.state)  # &
            #             (self.slide_data['Class']!='Trials') &
            #             (self.slide_data['Class']!='Successful')
        ]["Treatment"].reset_index(drop=True)
        class_counts = self.label.map(self.label_dict).value_counts()
        class_weights = 1.0 / torch.tensor(class_counts, dtype=torch.float32)
        self.weights = class_weights[self.label.map(self.label_dict).values]
        
    def get_weights(self):
        return self.weights


    def __len__(self):
        return len(self.pwn)

    def __getitem__(self, idx):
        slide_id = self.pwn[idx]
        full_path = Path(self.h5_path) / f"{slide_id}.csv"
        hdf = pd.read_csv(full_path)

        features = torch.from_numpy(hdf.iloc[:, :100].values).type(torch.DoubleTensor)
        label = torch.tensor(int(self.label_dict[np.unique(hdf["Treatment"])[0]]))
        treatment = np.unique(hdf["Treatment"])[0]
        serial_number = hdf["serialNumber"].values

        # ----> shuffle
        if self.shuffle == True:
            index = [x for x in range(features.shape[0])]
            random.shuffle(index)
            features = features[index]

        return features, label, slide_id, treatment, serial_number.tolist()

In [42]:
dset = CellData(h5_path = "/mnt/nvme0n1/Datasets/SingleCellFromNathan_17122021/TransformerFeats/csv_files/", 
                       csv_path= "/mnt/nvme0n1/Datasets/SingleCellFromNathan_17122021/"\
                "folds_3DMIL/all_data_removedwrong_ori_removed"\
                "Two_train_test_50_20_30_fold00.csv", 
                        state='train'
                  )

In [43]:
len(dset)

73

In [45]:
dset.get_weights()

tensor([3.4249e-05, 3.4249e-05, 3.4249e-05,  ..., 3.4249e-05, 3.4249e-05,
        3.4249e-05])

In [37]:
from torch.utils.data import DataLoader
import numpy as np

dload = DataLoader(dset, batch_size=1, 
                  sampler=torch.utils.data.WeightedRandomSampler(
                              weights=torch.from_numpy(weights), num_samples=73)
                  )

In [40]:
labels = []
for d in dload:
    if d[1][0].numpy() == 1:
        ones.append(1)
    print(d[1][0].numpy())  
    labels.append(d[1][0].numpy())

0
0
0
1
0
1
1
1
1
0
1
1
1
0
1
0
0
1
0
0
0
0
1
0
0
0
1
1
0
0
0
1
1
0
1
1
0
1
0
1
1
1
0
1
1
1
1
1
1
0
1
1
0
0
0
1
0
0
0
1
0
1
0
1
0
0
0
1
1
1
1
0
0


In [39]:
class_counts = np.bincount(labels)

# Calculate the inverse of each class frequency
class_weights = 1. / class_counts

# Now, you can create a weight for each instance in the dataset
weights = class_weights[labels]

In [35]:
weights

array([0.01538462, 0.01538462, 0.01538462, 0.01538462, 0.01538462,
       0.01538462, 0.01538462, 0.01538462, 0.01538462, 0.01538462,
       0.01538462, 0.01538462, 0.125     , 0.01538462, 0.01538462,
       0.01538462, 0.01538462, 0.01538462, 0.01538462, 0.01538462,
       0.125     , 0.01538462, 0.01538462, 0.01538462, 0.125     ,
       0.125     , 0.01538462, 0.01538462, 0.01538462, 0.01538462,
       0.01538462, 0.01538462, 0.01538462, 0.01538462, 0.125     ,
       0.01538462, 0.01538462, 0.01538462, 0.01538462, 0.01538462,
       0.01538462, 0.125     , 0.01538462, 0.01538462, 0.01538462,
       0.01538462, 0.01538462, 0.01538462, 0.01538462, 0.01538462,
       0.01538462, 0.01538462, 0.01538462, 0.01538462, 0.01538462,
       0.125     , 0.01538462, 0.01538462, 0.01538462, 0.01538462,
       0.01538462, 0.01538462, 0.01538462, 0.01538462, 0.01538462,
       0.01538462, 0.01538462, 0.125     , 0.01538462, 0.01538462,
       0.01538462, 0.01538462, 0.01538462])

In [47]:
df = pd.read_csv("/mnt/nvme0n1/Datasets/SingleCellFromNathan_17122021/"\
                "folds_3DMIL/all_data_removedwrong_ori_removed"\
                "Two_train_test_50_20_30_fold00.csv")
df

Unnamed: 0.2,Unnamed: 0.1,Unnamed: 0,runNumber,fieldNumber,Row,Column,Treatment,serialNumber,Proximal,nucleusCoverslipDistance,...,erkIntensityCell,zDim,xDim,yDim,PlateNumber,roll,pitch,yaw,PlateWellNum,Splits
0,0,0,1,1,B,2,Palbociclib,0001_0001_accelerator_20210315_bakal01_erk_mai...,0,94.392097,...,15.150823,27.0,29.0,27.0,1,-47.420894,28.757543,64.409348,1B2,test
1,1,1,1,1,B,2,Palbociclib,0001_0002_accelerator_20210315_bakal01_erk_mai...,0,95.745798,...,22.502644,38.0,38.0,24.0,1,3.957440,-26.918521,95.859157,1B2,test
2,2,2,1,1,B,2,Palbociclib,0001_0003_accelerator_20210315_bakal01_erk_mai...,0,78.871111,...,1.508700,40.0,37.0,44.0,1,60.852063,-34.099907,36.650979,1B2,test
3,3,3,1,1,B,2,Palbociclib,0001_0004_accelerator_20210315_bakal01_erk_mai...,0,94.070423,...,23.218808,34.0,24.0,63.0,1,-86.973231,-18.720156,82.882932,1B2,test
4,4,4,1,1,B,2,Palbociclib,0001_0005_accelerator_20210315_bakal01_erk_mai...,0,101.902062,...,3.281942,23.0,40.0,24.0,1,-7.784532,9.477303,172.396074,1B2,test
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
65495,65495,70106,1,148,F,11,No Treatment,0148_0148_accelerator_20210318_bakal03_erk_mai...,0,9.259455,...,12.293090,18.0,44.0,85.0,3,73.433925,0.740002,-0.038130,3F11,train
65496,65496,70107,1,148,F,11,No Treatment,0148_0149_accelerator_20210318_bakal03_erk_mai...,0,8.753131,...,17.760000,18.0,16.0,17.0,3,44.417517,23.726035,66.810003,3F11,train
65497,65497,70108,1,148,F,11,No Treatment,0148_0150_accelerator_20210318_bakal03_erk_mai...,1,6.748609,...,23.811151,16.0,24.0,22.0,3,-51.445433,-5.610570,114.151912,3F11,train
65498,65498,70109,1,148,F,11,No Treatment,0148_0151_accelerator_20210318_bakal03_erk_mai...,0,7.933269,...,13.921818,18.0,89.0,62.0,3,-26.995276,-2.383103,-1.066633,3F11,train


In [None]:
df.groupby