In [None]:
import numpy as np
import random
import os
import pandas as pd
import pickle
import torch 
from glob import glob
import cv2 

In [None]:
# Experiment configuration. Of these entries only 'SEED' is read in the
# cells shown here (it is passed to seed_everything).
CFG = {
    'IMG_SIZE': 224,
    'EPOCHS': 80,
    'LEARNING_RATE': 1e-3,
    'BATCH_SIZE': 32,
    'log': True,
    'SEED': 41,
}

def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and every CUDA device), exports ``PYTHONHASHSEED`` and switches cuDNN
    to deterministic behaviour.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    # Exported for child processes; the hash seed of the already running
    # interpreter is fixed at start-up and is not changed by this assignment.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds all CUDA devices; silently ignored when CUDA is not available.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark mode lets cuDNN auto-tune and pick convolution algorithms at
    # run time, which may differ between runs and would undo the determinism
    # requested on the line above, so it has to be disabled.
    torch.backends.cudnn.benchmark = False

seed_everything(seed=CFG['SEED'])  # fix all random seeds

In [7]:
# Build the training data
def get_train_data_file(data_dir):
    """Load every training image and its label from the case folders in ``data_dir``.

    Each sub-directory of ``data_dir`` is treated as one case. A case is
    expected to hold an ``image`` folder with ``*.jpg`` / ``*.png`` files and a
    ``label.csv`` with a ``leaf_weight`` column. Images and labels are paired
    purely by position: the i-th image path of a case (sorted by name, jpg
    files before png files) belongs to the i-th row of that case's
    ``label.csv``.

    The noisy sample ``CASE45_17.png`` is excluded from both the labels and the
    images, so the pairing stays intact even if the file is still on disk.

    Args:
        data_dir: Root directory containing one sub-directory per case.

    Returns:
        Tuple ``(img_list, label_list)``: the images as read by ``cv2.imread``
        and resized to 256x256, and the ``leaf_weight`` values in the same
        order.

    Raises:
        ValueError: If a case has a different number of images and labels.
        OSError: If ``cv2.imread`` cannot read an image file.
    """
    noisy_case, noisy_img = 'CASE45', 'CASE45_17.png'
    img_path_list = []
    label_list = []

    # os.listdir() and glob() return entries in arbitrary order; sorting makes
    # the result (and the positional image/label pairing) deterministic.
    for case_name in sorted(os.listdir(data_dir)):
        current_path = os.path.join(data_dir, case_name)
        if not os.path.isdir(current_path):
            continue

        # get image path
        image_dir = os.path.join(current_path, 'image')
        case_img_paths = sorted(glob(os.path.join(image_dir, '*.jpg')))
        case_img_paths += sorted(glob(os.path.join(image_dir, '*.png')))

        # get label
        label_df = pd.read_csv(os.path.join(current_path, 'label.csv'))
        if case_name == noisy_case:  # remove the noisy sample
            label_df = label_df[label_df.img_name != noisy_img]
            case_img_paths = [
                path for path in case_img_paths
                if os.path.basename(path) != noisy_img
            ]
        case_labels = list(label_df['leaf_weight'])

        # A silent length mismatch would shift every following image/label pair.
        if len(case_img_paths) != len(case_labels):
            raise ValueError(
                f"{case_name}: found {len(case_img_paths)} images "
                f"but {len(case_labels)} labels"
            )
        img_path_list.extend(case_img_paths)
        label_list.extend(case_labels)

    img_list = []
    for img_path in img_path_list:
        image = cv2.imread(img_path)
        if image is None:  # cv2.imread signals failure by returning None
            raise OSError(f"could not read image: {img_path}")
        img_list.append(cv2.resize(image, (256, 256), interpolation=cv2.INTER_AREA))
    return img_list, label_list

# Build the training data for the current weight
def get_train_current_data_file(data_dir):
    """Load training images paired with the previous row's ``leaf_weight``.

    Works like a per-case loader over the sub-directories of ``data_dir``
    (each with an ``image`` folder and a ``label.csv``), but the label column
    is shifted down by one row, so the i-th image is paired with the
    ``leaf_weight`` of row i-1. The original notebook comments describe this
    as the "current weight" data set. Positions that have no valid predecessor
    are removed from both the images and the labels: position 0 for ordinary
    cases, and the positions listed in ``not_single_series`` for cases that
    hold several series.

    Image paths are sorted by name (jpg files before png files) and the drop
    positions are applied once to that combined list, matching the single
    positional drop applied to the label rows.

    Args:
        data_dir: Root directory containing one sub-directory per case.

    Returns:
        Tuple ``(img_list, label_list)``: the images as read by ``cv2.imread``
        and resized to 256x256, and the shifted ``leaf_weight`` values in the
        same order.

    Raises:
        ValueError: If a case ends up with a different number of images and
            labels.
        OSError: If ``cv2.imread`` cannot read an image file.
    """
    # Positions to discard for cases that are not a single series; presumably
    # these are the positions at which a new series starts (TODO confirm).
    not_single_series = {
        'CASE73': [0, 10, 14],
    }
    noisy_case, noisy_img = 'CASE45', 'CASE45_17.png'
    img_path_list = []
    label_list = []

    # os.listdir() and glob() return entries in arbitrary order; sorting makes
    # the result (and the positional image/label pairing) deterministic.
    for case_name in sorted(os.listdir(data_dir)):
        current_path = os.path.join(data_dir, case_name)
        if not os.path.isdir(current_path):
            continue
        drop_positions = set(not_single_series.get(case_name, [0]))

        # get image path
        image_dir = os.path.join(current_path, 'image')
        case_img_paths = sorted(glob(os.path.join(image_dir, '*.jpg')))
        case_img_paths += sorted(glob(os.path.join(image_dir, '*.png')))

        # get label
        label_df = pd.read_csv(os.path.join(current_path, 'label.csv'))
        if case_name == noisy_case:  # remove the noisy sample
            label_df = label_df[label_df.img_name != noisy_img]
            case_img_paths = [
                path for path in case_img_paths
                if os.path.basename(path) != noisy_img
            ]
        # NOTE(review): the noisy row is removed before the shift, so the row
        # after CASE45_17 receives the weight of the row before it - confirm
        # that this is intended.
        shifted_weights = label_df['leaf_weight'].shift(1)

        # Drop by position (not by index label) so images and labels are
        # filtered identically; out-of-range positions are simply ignored
        # instead of silently disabling the whole deletion.
        case_labels = [
            weight for pos, weight in enumerate(shifted_weights)
            if pos not in drop_positions
        ]
        case_img_paths = [
            path for pos, path in enumerate(case_img_paths)
            if pos not in drop_positions
        ]

        # A silent length mismatch would shift every following image/label pair.
        if len(case_img_paths) != len(case_labels):
            raise ValueError(
                f"{case_name}: found {len(case_img_paths)} images "
                f"but {len(case_labels)} labels"
            )
        img_path_list.extend(case_img_paths)
        label_list.extend(case_labels)

    img_list = []
    for img_path in img_path_list:
        image = cv2.imread(img_path)
        if image is None:  # cv2.imread signals failure by returning None
            raise OSError(f"could not read image: {img_path}")
        img_list.append(cv2.resize(image, (256, 256), interpolation=cv2.INTER_AREA))

    return img_list, label_list

# Build the test data
def get_test_data_file(data_dir):
    """Load the test images under ``data_dir/image`` in numeric file-name order.

    Collects every ``*.jpg`` and ``*.png`` file, sorts them by the integer
    value of the file name without its extension (so ``2.png`` comes before
    ``10.jpg``) and returns the images resized to 256x256.

    Args:
        data_dir: Directory that contains the ``image`` folder.

    Returns:
        List of images as read by ``cv2.imread`` and resized to 256x256.

    Raises:
        ValueError: If a file name (without extension) is not an integer.
        OSError: If ``cv2.imread`` cannot read an image file.
    """
    # get image path
    image_dir = os.path.join(data_dir, 'image')
    img_path_list = glob(os.path.join(image_dir, '*.jpg'))
    img_path_list.extend(glob(os.path.join(image_dir, '*.png')))
    # os.path handles the separator of the running platform; splitting on a
    # literal backslash only worked for Windows-style paths.
    img_path_list.sort(
        key=lambda path: int(os.path.splitext(os.path.basename(path))[0])
    )

    img_list = []
    for img_path in img_path_list:
        image = cv2.imread(img_path)
        if image is None:  # cv2.imread signals failure by returning None
            raise OSError(f"could not read image: {img_path}")
        img_list.append(cv2.resize(image, (256, 256), interpolation=cv2.INTER_AREA))
    return img_list

In [None]:
# Delete the CASE45_17 image together with its metadata file
rm_paths = [
    './train/CASE45/image/CASE45_17.png',
    './train/CASE45/meta/CASE45_17.csv',
]
for rm_path in rm_paths:
    if not os.path.exists(rm_path):
        continue  # nothing to delete for this path
    os.remove(rm_path)
    print(rm_path, ' removed..')

In [None]:
# Build the datasets
all_img_list, all_label = get_train_data_file('./train')  # weight one day later
all_img_curr_list, all_curr_label = get_train_current_data_file('./train')  # current weight
test_img_list = get_test_data_file('./test')


# Save the datasets: one pickle file per object, written in this order
for pkl_name, payload in (
    ('all_img_256.pkl', all_img_list),
    ('all_label_rm_CASE45-17.pkl', all_label),
    ('test_img_256.pkl', test_img_list),
    ('all_img_curr_256.pkl', all_img_curr_list),
    ('all_label_curr_rm_CASE45-17.pkl', all_curr_label),
):
    with open(pkl_name, 'wb') as f:
        pickle.dump(payload, f)