In [200]:
import os
import shutil
import numpy as np
import cv2

### Load functions

In [201]:
path = os.getcwd()
names = ['anabaena', 'aphanizomenon', 'detritus', 'dolichospermum', 'microcystis', 'oscillatoria', 'synechococcus', 'water bubble', 'woronichinia']

In [202]:
def move_to_combined(exclude_classes = []):
    splits = ['test', 'train', 'valid']

    if os.path.exists(path + f"/combined_ds"):
        shutil.rmtree(path + f"/combined_ds")
    os.mkdir(path + f"/combined_ds")
    os.mkdir(path + f"/combined_ds/images")
    os.mkdir(path + f"/combined_ds/labels")

    
    for split in splits:
        for img in os.listdir(f'{path}/b_ds/{split}/images'):
            if len(exclude_classes) > 0:
                for exclude_class in exclude_classes:
                    if not exclude_class in img:
                        shutil.copy(f'{path}/b_ds/{split}/images/{img}', f'{path}/combined_ds/images/{img}')
            else:
                shutil.copy(f'{path}/b_ds/{split}/images/{img}', f'{path}/combined_ds/images/{img}')
        for label in os.listdir(f'{path}/b_ds/{split}/labels'):
            if len(exclude_classes) > 0:
                for exclude_class in exclude_classes:            
                    if not exclude_class in label:
                        shutil.copy(f'{path}/b_ds/{split}/labels/{label}', f'{path}/combined_ds/labels/{label}')
            else:
                shutil.copy(f'{path}/b_ds/{split}/labels/{label}', f'{path}/combined_ds/labels/{label}')

In [203]:
def apply_CLAHE():
    # images are already resized to 512x512
    # apply CLAHE to images

    for img in os.listdir(f'{path}/combined_ds/images'):
        image = cv2.imread(f'{path}/combined_ds/images/{img}', cv2.IMREAD_GRAYSCALE) # need to convert to grayscale
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
        image = clahe.apply(image)
        cv2.imwrite(f'{path}/combined_ds/images/{img}', image)

In [204]:
def augment_image(image_path):
    image = cv2.imread(image_path)
    image = cv2.flip(image, 1)
    # cv2.imwrite(, image)

    # augment label
    # with open(path + f"/org_ds/labels/{f'{real_name}_{names_with_freqs[numval]}_aug'}.txt", "w") as file:
    #     file.write(f"{asdf} {1-float(file.readline().split(' ')[1])} {file.readline().split(' ')[2]} {file.readline().split(' ')[3]} {file.readline().split(' ')[4]}")

In [205]:
def organize_to_names(include_backgrounds=100, augment=False):
    
    names_with_freqs = [0 for i in range(len(names))]
    x=0

    for name in names:
        if os.path.exists(path + f"/org_ds/{name}"):
            shutil.rmtree(path + f"/org_ds/{name}")
        os.makedirs(path + f"/org_ds/{name}")
    
    if os.path.exists(path + f"/org_ds/labels"):
            shutil.rmtree(path + f"/org_ds/labels")
    os.makedirs(path + f"/org_ds/labels")

    if include_backgrounds:
        if os.path.exists(path + f"/org_ds/backgrounds"):
            shutil.rmtree(path + f"/org_ds/backgrounds")
        os.makedirs(path + f"/org_ds/backgrounds")
    
    for label in os.listdir(path + "/combined_ds/labels"):
        # organize everything into folders based on
        
        with open(path + f"/combined_ds/labels/{label}") as file:
            # read first line
            asdf = file.readline().split(" ")[0]
            # print(asdf)
            try:
                numval = int(asdf)
                real_name = names[numval]
                
                # move image and label to folder
                names_with_freqs[numval] += 1

                shutil.copy(path + f"/combined_ds/images/{label[:-4]}.jpg", path + f"/org_ds/{real_name}/{f'{real_name}_{names_with_freqs[numval]}'}.jpg")
                shutil.copy(path + f"/combined_ds/labels/{label}", path + f"/org_ds/labels/{f'{real_name}_{names_with_freqs[numval]}'}.txt")

                curr_name = f'{real_name}_{names_with_freqs[numval]}'
                if augment:
                    augment_image(curr_name)
                    
                x+=1
            except:
                # label does not exist
                if include_backgrounds > 0:
                    shutil.copy(path + f"/combined_ds/images/{label[:-4]}.jpg", path + f"/org_ds/backgrounds/{label[:-4]}_background.jpg")
                    include_backgrounds-=1
    print(x)

In [206]:
globftrain = 0
globfvalid = 0
globftest = 0

In [207]:
def get_train_val_test_splits(include_backgrounds=True):
    path = os.getcwd()

    ftrain, fval, ftest = np.array([]), np.array([]), np.array([])
    # stratify splitting of data
    print("getting splits")
    for name in names:
        allFileNames = os.listdir(path + f"/org_ds/{name}")
        np.random.seed(42)
        np.random.shuffle(allFileNames)

        train, val, test = np.split(np.array(allFileNames),[int(len(allFileNames)*0.8), int(len(allFileNames)*0.9)])

        ftrain = np.concatenate((ftrain, train))
        fval = np.concatenate((fval, val))
        ftest = np.concatenate((ftest, test))

        print(name, len(train), len(val), len(test))
    print("final lengths after stratified split: ", len(ftrain), len(fval), len(ftest))
    
    

    if include_backgrounds:
        allFileNames = os.listdir(path + f"/org_ds/backgrounds")
        np.random.seed(42)
        np.random.shuffle(allFileNames)

        train, val, test = np.split(np.array(allFileNames),[int(len(allFileNames)*0.8), int(len(allFileNames)*0.9)])

        ftrain = np.concatenate((ftrain, train))
        fval = np.concatenate((fval, val))
        ftest = np.concatenate((ftest, test))
    global globftrain
    global globfval
    global globftest
    globftrain = ftrain
    globfval = fval
    globftest = ftest
    
    return ftrain, fval, ftest

In [222]:
def check_freqs(ftrain, fval, ftest):
    splits = [ftrain, fval, ftest]
    print("printing frequency information")
    print(len(ftrain), len(fval), len(ftest))
    print(len(ftrain)/(len(ftrain)+len(fval)+len(ftest)))
    splitarr = ['train', 'val', 'test']
    # dictionary storing percentage frequencies across splits
    freqs = {}
    for name in names:
        freqs[name] = [0,0,0]
    
    freqs["background"] = [0,0,0]

    for i in range(len(splits)):
        for name in names:
            for file in splits[i]:
                if name in file:
                    freqs[name][i] += 1
    
    for name in names:
        for i in range(len(freqs[name])):
            print(name, splitarr[i], freqs[name][i]/sum(freqs[name]), freqs[name][i])

    print("background frequencies")
    # check background frequencies
    for i in range(len(splits)):
        for file in splits[i]:
            if "background" in file:
                freqs["background"][i] += 1

    for i in range(len(freqs["background"])):
        print("background", splitarr[i], freqs["background"][i]/sum(freqs["background"]), freqs["background"][i])

In [209]:
def reorganize_to_final(ftrain, fval, ftest):
    splits = ['train', 'valid', 'test']
    
    # clear existing final ds
    if os.path.exists(path + "/final_ds"):
        shutil.rmtree(path + "/final_ds")
    os.mkdir(path + "/final_ds")
    
    for split in splits:
        os.mkdir(path + f"/final_ds/{split}")
        os.mkdir(path + f"/final_ds/{split}/images")
        os.mkdir(path + f"/final_ds/{split}/labels")
        
        if split == 'train':
            thing = ftrain
        elif split == 'valid':
            thing = fval
        elif split == 'test':
            thing = ftest
        
        for file in thing:
            c = file.split("_")[0] # class name
            
            if os.path.exists(path + f"/org_ds/labels/{file[:-4]}.txt"):
                shutil.copy(path + f"/org_ds/labels/{file[:-4]}.txt", path + f"/final_ds/{split}/labels/{file[:-4]}.txt")
                shutil.copy(path + f"/org_ds/{c}/{file}", path + f"/final_ds/{split}/images/{file}")
            else:
                shutil.copy(path + f"/org_ds/backgrounds/{file}", path + f"/final_ds/{split}/images/{file}")
            


    # add data.yaml file
    with open(path + "/final_ds/data.yaml", "w") as file:
        file.write("train: final_ds/train/images"+"\n")
        file.write("test: final_ds/test/images"+"\n")
        file.write("val: final_ds/valid/images"+"\n")
        file.write("nc: 9"+"\n")
        file.write("names: ['aphanizomenon', 'detritus', 'dolichospermum', 'microcystis', 'oscillatoria', 'synechococcus', 'water bubble', 'woronichinia']")

In [210]:
def rebalance_dataset(**kwargs):
    useCLAHE = kwargs.get('useCLAHE', False)

    move_to_combined()

    if useCLAHE:
        apply_CLAHE()
    
    organize_to_names()
    
    ftrain, fval, ftest = get_train_val_test_splits()
    # check_freqs(ftrain, fval, ftest)
    print(fval)
    print(ftest)
    
    reorganize_to_final(ftrain, fval, ftest)
    # send to zip
    shutil.make_archive("final_ds", 'zip', path + "/final_ds")

In [211]:
rebalance_dataset(useCLAHE=True)
# print(len(os.listdir(path + "/final_ds/train/images")))
# print(len(os.listdir(path + "/final_ds/valid/images")))
# print(len(os.listdir(path + "/final_ds/test/images")))

1799
getting splits
anabaena 94 12 12
aphanizomenon 262 33 33
detritus 29 4 4
dolichospermum 258 32 33
microcystis 352 44 44
oscillatoria 144 18 18
synechococcus 10 1 2
water bubble 78 10 10
woronichinia 209 26 27
final lengths after stratified split:  1436 180 183
['anabaena_87.jpg' 'anabaena_13.jpg' 'anabaena_11.jpg' 'anabaena_68.jpg'
 'anabaena_67.jpg' 'anabaena_15.jpg' 'anabaena_77.jpg' 'anabaena_73.jpg'
 'anabaena_88.jpg' 'anabaena_93.jpg' 'anabaena_47.jpg' 'anabaena_57.jpg'
 'aphanizomenon_109.jpg' 'aphanizomenon_40.jpg' 'aphanizomenon_141.jpg'
 'aphanizomenon_182.jpg' 'aphanizomenon_256.jpg' 'aphanizomenon_252.jpg'
 'aphanizomenon_137.jpg' 'aphanizomenon_119.jpg' 'aphanizomenon_209.jpg'
 'aphanizomenon_145.jpg' 'aphanizomenon_272.jpg' 'aphanizomenon_168.jpg'
 'aphanizomenon_24.jpg' 'aphanizomenon_130.jpg' 'aphanizomenon_210.jpg'
 'aphanizomenon_99.jpg' 'aphanizomenon_258.jpg' 'aphanizomenon_302.jpg'
 'aphanizomenon_58.jpg' 'aphanizomenon_237.jpg' 'aphanizomenon_62.jpg'
 'aphaniz

In [212]:
def delete_all_folders():
    shutil.rmtree(path + "/combined_ds")
    shutil.rmtree(path + "/org_ds")
    shutil.rmtree(path + "/final_ds")

In [213]:
# delete_all_folders()

In [214]:
# print(len(os.listdir(path + "/final_ds/train/labels")))
# print(len(os.listdir(path + "/final_ds/valid/labels")))
# print(len(os.listdir(path + "/final_ds/test/labels")))

# check if all labels have corresponding images
for label in os.listdir(path + "/final_ds/train/labels"):
    with open(path + f"/final_ds/train/labels/{label}") as file:
        if len(file.readlines()) == 0:
            print(label)
    if not os.path.exists(path + f"/final_ds/train/images/{label[:-4]}.jpg"):
        print(label)

for label in os.listdir(path + "/final_ds/valid/labels"):
    with open(path + f"/final_ds/valid/labels/{label}") as file:
        if len(file.readlines()) == 0:
            print(label)
    if not os.path.exists(path + f"/final_ds/valid/images/{label[:-4]}.jpg"):
        print(label)

for label in os.listdir(path + "/final_ds/test/labels"):
    with open(path + f"/final_ds/test/labels/{label}") as file:
        if len(file.readlines()) == 0:
            print(label)
    if not os.path.exists(path + f"/final_ds/test/images/{label[:-4]}.jpg"):
        print(label)

In [223]:
check_freqs(globftrain, globfval, globftest)

printing frequency information
1516 190 193
0.7983149025803055
anabaena train 0.8 96
anabaena val 0.1 12
anabaena test 0.1 12
aphanizomenon train 0.7987804878048781 262
aphanizomenon val 0.10060975609756098 33
aphanizomenon test 0.10060975609756098 33
detritus train 0.7837837837837838 29
detritus val 0.10810810810810811 4
detritus test 0.10810810810810811 4
dolichospermum train 0.7987616099071208 258
dolichospermum val 0.09907120743034056 32
dolichospermum test 0.1021671826625387 33
microcystis train 0.7981859410430839 352
microcystis val 0.10204081632653061 45
microcystis test 0.09977324263038549 44
oscillatoria train 0.8 144
oscillatoria val 0.1 18
oscillatoria test 0.1 18
synechococcus train 0.7692307692307693 10
synechococcus val 0.07692307692307693 1
synechococcus test 0.15384615384615385 2
water bubble train 0.7959183673469388 78
water bubble val 0.10204081632653061 10
water bubble test 0.10204081632653061 10
woronichinia train 0.7977099236641222 209
woronichinia val 0.0992366412

In [216]:
print(len(os.listdir(path + "/final_ds/train/images")))
print(len(os.listdir(path + "/final_ds/valid/images")))
print(len(os.listdir(path + "/final_ds/test/images")))

1516
190
193
