In [1]:
import os 
import cv2
import time
import timm
import copy 
import random
import argparse
import datetime
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from glob import glob
from tqdm import tqdm
from sklearn import cluster
from sklearn.metrics import f1_score, accuracy_score
from sklearn.model_selection import StratifiedKFold

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
warnings.filterwarnings(action='ignore')


# ---------------------------------------------------------------------------
# Run configuration: GPU selection, RNG seeding and an output-name suffix.
# ---------------------------------------------------------------------------
seed = 10
gpu_ids = '0,1,2,3'  # comma-separated physical GPU indices exposed to this process

# Select the visible GPUs *before* the first torch.cuda call so the choice is
# guaranteed to be in place whenever CUDA gets initialised.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # Arrange GPU devices starting from 0
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_ids

# Seed every random number generator used by the notebook.
# NOTE: assigning PYTHONHASHSEED here cannot change hash randomisation of the
# already-running interpreter; it is only inherited by child processes.
os.environ["PYTHONHASHSEED"] = str(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)  # if use multi-GPU
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# `device` ends up as a torch.device; the GPU id string lives in `gpu_ids` so
# the name is no longer reused for two different kinds of value.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('Device: %s' % device)
if device.type == 'cuda':
    print('GPU activate --> Count of using GPUs: %s' % torch.cuda.device_count())

# Timestamp suffix built from the local clock shifted by +9 hours.
# NOTE(review): presumably converts a UTC server clock to KST - confirm.
suffix = (datetime.datetime.now() + datetime.timedelta(hours=9)).strftime("%y%m%d_%H%M")


Device: cuda
GPU activate --> Count of using GPUs: 4


In [2]:
def get_train_data(data_dir, mode='train'):
    """Collect image paths, metadata paths and labels from every case folder.

    Each sub-directory of ``data_dir`` is treated as one case and is read as
    ``<case>/image/*.jpg|*.png``, ``<case>/meta/*.csv`` and ``<case>/label.csv``
    (column ``leaf_weight``). Plain files directly under ``data_dir`` are
    ignored.

    Cases and the files inside them are visited in sorted order: ``os.listdir``
    and ``glob`` make no ordering guarantee, so without sorting the three
    returned lists could not be relied on to line up with each other.

    Args:
        data_dir: Root directory that contains one folder per case.
        mode: Unused; kept so existing callers keep working.

    Returns:
        Tuple ``(img_path_list, meta_path_list, label_list)``.

    Raises:
        FileNotFoundError: If a case folder has no ``label.csv``.
        KeyError: If ``label.csv`` has no ``leaf_weight`` column.
    """
    img_path_list = []
    meta_path_list = []
    label_list = []

    for case_name in sorted(os.listdir(data_dir)):
        current_path = os.path.join(data_dir, case_name)
        if not os.path.isdir(current_path):
            continue

        # get image path (jpg and png merged, then ordered by file name)
        image_dir = os.path.join(current_path, 'image')
        case_images = glob(os.path.join(image_dir, '*.jpg'))
        case_images += glob(os.path.join(image_dir, '*.png'))
        img_path_list.extend(sorted(case_images))

        # get meta path
        meta_path_list.extend(sorted(glob(os.path.join(current_path, 'meta', '*.csv'))))

        # get label
        # NOTE(review): assumes label.csv rows follow the same sorted file-name
        # order as the images - confirm against the dataset.
        label_df = pd.read_csv(os.path.join(current_path, 'label.csv'))
        label_list.extend(label_df['leaf_weight'])

    return img_path_list, meta_path_list, label_list

def _numeric_stem_key(path):
    """Sort key ordering paths by the integer before the first '.' of the file name.

    File names whose stem is not an integer sort after all numeric ones, by
    stem text, instead of raising.
    """
    stem = os.path.basename(path).split('.')[0]
    try:
        return (0, int(stem), '')
    except ValueError:
        return (1, 0, stem)


def get_test_data(data_dir):
    """Collect test image and metadata paths, both ordered by numeric file name.

    Reads ``<data_dir>/image/*.jpg|*.png`` and ``<data_dir>/meta/*.csv``. Both
    lists are sorted with the same numeric key so that the i-th image and the
    i-th metadata file refer to the same sample (``glob`` alone returns files
    in arbitrary order).

    Args:
        data_dir: Directory containing the ``image`` and ``meta`` folders.

    Returns:
        Tuple ``(img_path_list, meta_path_list)``.
    """
    # get image path
    img_path_list = glob(os.path.join(data_dir, 'image', '*.jpg'))
    img_path_list.extend(glob(os.path.join(data_dir, 'image', '*.png')))
    img_path_list.sort(key=_numeric_stem_key)

    # get meta path (same ordering as the images)
    meta_path_list = glob(os.path.join(data_dir, 'meta', '*.csv'))
    meta_path_list.sort(key=_numeric_stem_key)

    return img_path_list, meta_path_list

In [3]:
# Build the file listings for both splits once; the cells below index into them.
train_dir = '/data/KIST_PLANT/train'
test_dir = '/data/KIST_PLANT/test'

train_listing = get_train_data(train_dir)
all_img_path, all_meta_path, all_label = train_listing

test_listing = get_test_data(test_dir)
test_img_path, test_meta_path = test_listing

In [None]:
# Visual sanity check of the threshold mask on a handful of test images.
row = 1
col = 3
n_clusters = 3

# Preview test images 10..19; the upper bound is clamped so a test set with
# fewer than 20 images does not raise IndexError.
for idx in range(10, min(20, len(test_img_path))):
    bgr = cv2.imread(test_img_path[idx])
    if bgr is None:
        # cv2.imread reports a missing/unreadable file by returning None.
        print('Skipping unreadable image: %s' % test_img_path[idx])
        continue

    # OpenCV loads BGR; convert to RGB for matplotlib and work at 512x512.
    img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (512, 512))

    plt.figure(figsize=(25,15))

    # -----------------------------------------------------------
    plt.subplot(row, col, 1)
    plt.imshow(img)
    plt.title('resized input')

    # Flatten to (pixels, channels) for scikit-learn.
    img_n = img.reshape(img.shape[0]*img.shape[1], img.shape[2])

    # -----------------------------------------------------------
    # NOTE(review): FeatureAgglomeration merges *features* (here the colour
    # channels), not pixels. With n_clusters equal to the channel count each
    # output column is presumably a single original channel - confirm which
    # channel ends up at index 0 before relying on the 120 threshold.
    agglo = cluster.FeatureAgglomeration(n_clusters=n_clusters).fit(img_n)
    X_reduced = agglo.transform(img_n)

    pic2show = X_reduced.reshape(img.shape[0], img.shape[1], n_clusters)
    pic2show = pic2show.astype('uint8')

    # Binary mask: 255 where component 0 is above 120, then inverted.
    _, mask = cv2.threshold(pic2show[:,:,0], 120, 255, cv2.THRESH_BINARY)
    mask_inv = cv2.bitwise_not(mask)

    plt.subplot(row, col, 2)
    plt.imshow(mask_inv, cmap='gray')
    plt.title('inverted threshold mask')

    # -----------------------------------------------------------
    # Keep only the pixels selected by the inverted mask.
    img_mask_inv = cv2.bitwise_and(img, img, mask = mask_inv)
    plt.subplot(row, col, 3)
    plt.imshow(img_mask_inv)
    plt.title('input under inverted mask')
    plt.show()

In [5]:
# Generate an inverted threshold mask for every train/test image, save it as a
# PNG and record the mean mask intensity per image in '<MODE>_foreground_sum_df.csv'.
modes = ['TRAIN', 'TEST']
datas = [all_img_path, test_img_path]

for mode, dset in zip(modes, datas):

    n_clusters = 3
    # NOTE: reset for every mode, so after the outer loop this list only holds
    # the values of the last mode; per-mode results are persisted to CSV below.
    foreground_sum_list = []

    # cv2.imwrite does not create missing directories (it just fails), so make
    # sure the target folder exists before writing any mask.
    out_dir = "/home/SY_LEE/KIST_PLANT/DATA_BACKGROUND_{}/".format(mode)
    os.makedirs(out_dir, exist_ok=True)

    for img_path in tqdm(dset, total=len(dset)):

        bgr = cv2.imread(img_path)
        if bgr is None:
            # Skipping would shift every later row relative to the labels, so
            # fail loudly instead.
            raise IOError('Could not read image: %s' % img_path)

        # OpenCV loads BGR; convert to RGB and work at 512x512.
        img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (512, 512))

        # Flatten to (pixels, channels) and run feature agglomeration.
        img_n = img.reshape(img.shape[0]*img.shape[1], img.shape[2])
        agglo = cluster.FeatureAgglomeration(n_clusters=n_clusters).fit(img_n)
        X_reduced = agglo.transform(img_n)

        pic2show = X_reduced.reshape(img.shape[0], img.shape[1], n_clusters)
        pic2show = pic2show.astype('uint8')

        # Binary mask: 255 where component 0 is above 120, then inverted.
        _, mask = cv2.threshold(pic2show[:,:,0], 120, 255, cv2.THRESH_BINARY)
        mask_inv = cv2.bitwise_not(mask)

        # NOTE(review): the output name uses only the file stem; images from
        # different case folders with the same file name would overwrite each
        # other - confirm names are unique across cases.
        out_path = out_dir + os.path.basename(img_path).split('.')[0] + "_background.png"
        if not cv2.imwrite(out_path, mask_inv):
            raise IOError('Could not write mask: %s' % out_path)

        # Mask pixels are 0 or 255, so this is 255 * (fraction of selected pixels).
        foreground_sum_list.append(mask_inv.sum() / ( img.shape[0] * img.shape[1]))

    foreground_sum_df = pd.DataFrame(foreground_sum_list, columns=['image_sum'])
    foreground_sum_df.to_csv('{}_foreground_sum_df.csv'.format(mode), index=False)


100%|██████████| 460/460 [08:59<00:00,  1.17s/it]


In [None]:
# Scatter the per-image mask statistic against the training labels.
#
# The in-memory `foreground_sum_list` left behind by the mask-generation loop
# belongs to the last processed mode (TEST), so it must not be paired with the
# TRAIN labels. Reload the TRAIN values that the loop saved to CSV instead.
train_foreground_sum = pd.read_csv('TRAIN_foreground_sum_df.csv')['image_sum'].to_numpy()
if len(train_foreground_sum) != len(all_label):
    raise ValueError(
        'TRAIN_foreground_sum_df.csv has %d rows but there are %d training labels'
        % (len(train_foreground_sum), len(all_label))
    )

plt.figure(figsize=(30,12))

row = 1
col = 3

# Reference lines drawn over the scatter plots.
# NOTE(review): the 0.43 slope looks hand-picked - confirm its origin.
t = np.arange(0, 512)
t_func = 0.43 * t

# -----------------------------------------------------------
# Panel 1: raw values with the identity line.
plt.subplot(row, col, 1)
plt.scatter(all_label, train_foreground_sum)
plt.plot(t, t, '--')
plt.xlabel('leaf_weight')
plt.ylabel('segment_sum')

# -----------------------------------------------------------
# Panel 2: same scatter, zoomed to y in [0, 200], with the 0.43 * x line.
plt.subplot(row, col, 2)
plt.scatter(all_label, train_foreground_sum)
plt.plot(t, t_func, '--')
plt.xlabel('leaf_weight')
plt.ylabel('segment_sum')
plt.ylim(0, 200)

# -----------------------------------------------------------
# Panel 3: log1p-transformed values against the log1p of the 0.43 * x line.
plt.subplot(row, col, 3)
plt.scatter(all_label, np.log1p(train_foreground_sum))
plt.plot(t, np.log1p(t_func), '--')
plt.xlabel('leaf_weight')
plt.ylabel('log1p(segment_sum)')
plt.show()