In [28]:
"""This is to split the data into train and test sets, respectively"""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import shutil

# Load the label table: one whitespace-separated row per image, with no header
# line, giving the image path, its label value and the cell type.
# `sep=r"\s+"` is the documented equivalent of `delim_whitespace=True`, which
# is deprecated in recent pandas releases.
labels = pd.read_csv(
    "labels.csv",
    header=None,
    sep=r"\s+",
    names=["image", "value", "type"],
)

print(labels)
print(labels.shape)

                    image  value  type
0     images/cell0001.png    1.0  mono
1     images/cell0002.png    1.0  mono
2     images/cell0003.png    1.0  mono
3     images/cell0004.png    0.0  mono
4     images/cell0005.png    1.0  mono
...                   ...    ...   ...
2619  images/cell2620.png    0.0  poly
2620  images/cell2621.png    0.0  poly
2621  images/cell2622.png    0.0  poly
2622  images/cell2623.png    0.0  poly
2623  images/cell2624.png    0.0  poly

[2624 rows x 3 columns]
(2624, 3)


In [30]:
"""Make binary classification and assign 0.3 to 0.0 and assign 0.6 to 1.0"""

# Work on a deep copy so the original `labels` table keeps its raw values.
labels_binary = labels.copy(deep=True)

# Build both masks from the untouched column before writing anything back.
# np.isclose is used instead of `==` so the match does not depend on the exact
# decimal expansion of 1/3 and 2/3 that ended up in the CSV.
raw_values = labels_binary['value'].to_numpy()
is_one_third = np.isclose(raw_values, 1 / 3)
is_two_thirds = np.isclose(raw_values, 2 / 3)

# Boolean-mask assignment is independent of the index labels, unlike the
# previous mix of positional `.iloc[row]` reads and label-based `.at[row]` writes.
labels_binary.loc[is_one_third, 'value'] = 0.0
labels_binary.loc[is_two_thirds, 'value'] = 1.0

# Printed after the conversion so the displayed table is the binarised one.
print(labels_binary)


                    image  value  type
0     images/cell0001.png    1.0  mono
1     images/cell0002.png    1.0  mono
2     images/cell0003.png    1.0  mono
3     images/cell0004.png    0.0  mono
4     images/cell0005.png    1.0  mono
...                   ...    ...   ...
2619  images/cell2620.png    0.0  poly
2620  images/cell2621.png    0.0  poly
2621  images/cell2622.png    0.0  poly
2622  images/cell2623.png    0.0  poly
2623  images/cell2624.png    0.0  poly

[2624 rows x 3 columns]


In [31]:
""" stratified sampling """

# Draw 75% of the rows from every (type, value) group so each group is
# represented in the training set in proportion to its size.
# `random_state` pins the draw: without it every kernel restart produced a
# different train/test split and different image folders.
train_set = labels_binary.groupby(['type', 'value']).sample(frac=0.75, random_state=42)

print(train_set)

                    image  value  type
448   images/cell0449.png    0.0  mono
1157  images/cell1158.png    0.0  mono
2170  images/cell2171.png    0.0  mono
2211  images/cell2212.png    0.0  mono
1205  images/cell1206.png    0.0  mono
...                   ...    ...   ...
833   images/cell0834.png    1.0  poly
2073  images/cell2074.png    1.0  poly
1841  images/cell1842.png    1.0  poly
643   images/cell0644.png    1.0  poly
790   images/cell0791.png    1.0  poly

[1969 rows x 3 columns]


In [32]:
""" create a pandas dataframe for the test_set """

# The test set is every labelled image that was not sampled into train_set.
# A single vectorised membership test replaces the per-row loop, which
# recomputed train_set['image'].unique() on every iteration and grew the
# frame one row at a time.
in_train = labels_binary['image'].isin(train_set['image'])
test_set = (
    labels_binary.loc[~in_train, ["image", "value", "type"]]
    .reset_index(drop=True)  # 0..n-1 index, as the row-by-row construction produced
)

# Number of rows placed in the test set (kept under its original name).
count = len(test_set)
print(test_set)

                   image  value  type
0    images/cell0002.png    1.0  mono
1    images/cell0004.png    0.0  mono
2    images/cell0006.png    1.0  mono
3    images/cell0009.png    0.0  mono
4    images/cell0010.png    1.0  mono
..                   ...    ...   ...
650  images/cell2611.png    0.0  poly
651  images/cell2612.png    0.0  poly
652  images/cell2617.png    0.0  poly
653  images/cell2619.png    0.0  poly
654  images/cell2624.png    0.0  poly

[655 rows x 3 columns]


In [33]:
"""ensure train and test folders are empty"""

test_def_dir = 'data/test/defective'
test_func_dir = 'data/test/functional'

train_def_dir = 'data/train/defective'
train_func_dir = 'data/train/functional'

# Empty each of the four output folders, creating any that do not exist yet.
for split_dir in (test_def_dir, test_func_dir, train_def_dir, train_func_dir):
    # os.listdir raises FileNotFoundError on a missing folder, so make sure it
    # exists first; exist_ok keeps this a no-op on later runs.
    os.makedirs(split_dir, exist_ok=True)
    for entry in os.listdir(split_dir):
        entry_path = os.path.join(split_dir, entry)
        if os.path.isdir(entry_path) and not os.path.islink(entry_path):
            # os.remove cannot delete a directory; remove the whole subtree.
            shutil.rmtree(entry_path)
        else:
            os.remove(entry_path)

In [34]:
""" copy images to the respective folders """

src = './data/all_data'


def _copy_split(split, defective_dir, functional_dir):
    """Copy the images listed in `split` from `src` into their class folder.

    Rows whose value is 1.0 are copied to `defective_dir`, rows whose value is
    0.0 to `functional_dir`; rows with any other value are not copied.
    """
    for image, value in zip(split['image'], split['value']):
        # basename drops whatever directory prefix the label file uses, instead
        # of assuming it is exactly the seven characters of "images/".
        source = os.path.join(src, os.path.basename(str(image)))
        if value == 1.0:
            shutil.copy(source, defective_dir)
        elif value == 0.0:
            shutil.copy(source, functional_dir)


_copy_split(train_set, train_def_dir, train_func_dir)
_copy_split(test_set, test_def_dir, test_func_dir)

In [35]:
# Number of entries in each output folder, in the order they are reported.
n_train_func, n_train_def, n_test_def, n_test_func = (
    len(os.listdir(folder))
    for folder in (train_func_dir, train_def_dir, test_def_dir, test_func_dir)
)

for n_files in (n_train_func, n_train_def, n_test_def, n_test_func):
    print(n_files)

# Sum over all four folders.
total = n_train_func + n_train_def + n_test_def + n_test_func

print("total: ", total)

1353
616
205
450
total:  2624


In [None]:
""" count the numbers of each group after stratified sampling """

# NOTE(review): this cell used to iterate over `group1`, which is not defined
# anywhere in the notebook, so it raised NameError on a fresh kernel. The
# stratified sample of labels_binary is `train_set`, so that is counted here;
# confirm this is the frame that was intended.
is_mono = train_set["type"] == "mono"
is_one = train_set["value"] == 1.0
is_zero = train_set["value"] == 0.0

# Any type other than "mono" is counted as poly, and rows whose value is
# neither 1.0 nor 0.0 are not counted at all.
count_1_mono = int((is_one & is_mono).sum())
count_1_poly = int((is_one & ~is_mono).sum())

count_0_mono = int((is_zero & is_mono).sum())
count_0_poly = int((is_zero & ~is_mono).sum())

print("count_0_poly: ", count_0_poly)
print("count_0_mono: ", count_0_mono)
print("count_1_poly: ", count_1_poly)
print("count_1_mono: ", count_1_mono, '\n')
print("total: ", count_0_poly+count_0_mono+count_1_poly+count_1_mono)

In [9]:
""" count the number of each group before stratified sampling to validate that the stratification went well"""

# Row masks over the full binarised table. A row counts as "poly" whenever its
# type is anything other than "mono"; rows whose value is neither 1.0 nor 0.0
# fall into no bucket.
mono_rows = labels_binary["type"] == "mono"
one_rows = labels_binary["value"] == 1.0
zero_rows = labels_binary["value"] == 0.0

count_1_mono_pre = int((one_rows & mono_rows).sum())
count_1_poly_pre = int((one_rows & ~mono_rows).sum())

count_0_mono_pre = int((zero_rows & mono_rows).sum())
count_0_poly_pre = int((zero_rows & ~mono_rows).sum())

print("count_0_poly_pre: ", count_0_poly_pre)
print("count_0_mono_pre: ", count_0_mono_pre)
print("count_1_poly_pre: ", count_1_poly_pre)
print("count_1_mono_pre: ", count_1_mono_pre, '\n')
print("total: ", count_0_poly_pre+count_0_mono_pre+count_1_poly_pre+count_1_mono_pre)

count_0_poly_pre:  1098
count_0_mono_pre:  705
count_1_poly_pre:  452
count_1_mono_pre:  369 

total:  2624


In [14]:
""" Stratified sampling for all labels, not converted to 0.0 and 1.0 """

# Draw 75% of the rows from every (type, value) group of the unconverted labels.
# `random_state` makes the draw repeatable across kernel restarts.
group2 = labels.groupby(['type', 'value']).sample(frac=0.75, random_state=42)

In [15]:
""" getting the counts of each class in the group2 stratified sample """

# Row masks over the group2 sample. The value buckets are mutually exclusive:
# 1.0, 0.0, 0.3333333333333333, and a catch-all for every remaining row, which
# is reported under the "06" counters. Any type other than "mono" counts as poly.
mono_rows = group2["type"] == "mono"
one_rows = group2["value"] == 1.0
zero_rows = group2["value"] == 0.0
third_rows = group2["value"] == 0.3333333333333333
other_rows = ~(one_rows | zero_rows | third_rows)

count_1_mono = int((one_rows & mono_rows).sum())
count_1_poly = int((one_rows & ~mono_rows).sum())

count_0_mono = int((zero_rows & mono_rows).sum())
count_0_poly = int((zero_rows & ~mono_rows).sum())

count_03_mono = int((third_rows & mono_rows).sum())
count_03_poly = int((third_rows & ~mono_rows).sum())

count_06_mono = int((other_rows & mono_rows).sum())
count_06_poly = int((other_rows & ~mono_rows).sum())

print("count_1_mono :", count_1_mono)
print("count_1_poly: ", count_1_poly)

print("count_0_mono: ", count_0_mono)
print("count_0_poly: ", count_0_poly)

print("count_03_mono: ", count_03_mono)
print("count_03_poly: ", count_03_poly)

print("count_06_mono: ", count_06_mono)
print("count_06_poly: ", count_06_poly)

count_1_mono : 235
count_1_poly:  302
count_0_mono:  441
count_0_poly:  690
count_03_mono:  88
count_03_poly:  134
count_06_mono:  42
count_06_poly:  38
