# In [None]:
# dataloader : image_generator

class image_generator(Dataset):
    """Dataset of side-by-side image pairs reduced to their pixel difference.

    Each input image holds two pictures placed side by side. The model input is
    ``left_half - right_half``, resized to ``input_size``, passed through
    ``scaler`` and transposed channel-first. In ``'train'``/``'valid'`` mode the
    matching label image is read from the sibling ``y`` directory and its right
    half is returned as the segmentation target. In ``'test'`` mode the shape of
    the difference image is returned instead, so predictions can be resized back.

    Args:
        paths: Paths to the input images, expected to live in a directory named
            ``x`` (labels are looked up in a sibling ``y`` directory).
        input_size: ``(width, height)`` passed to ``cv2.resize`` as ``dsize``.
        scaler: Callable applied to the resized difference image.
        mode: One of ``'train'``, ``'valid'`` or ``'test'``.
        logger: Stored on the instance; not used by this class.
        verbose: Stored on the instance; not used by this class.
    """

    _LABEL_MODES = ('train', 'valid')
    _TEST_MODES = ('test',)

    def __init__(self, paths, input_size, scaler, mode='train', logger=None, verbose=False):
        self.x_paths = paths
        self.y_paths = [self._to_label_path(p) for p in self.x_paths]
        self.input_size = input_size
        self.scaler = scaler
        self.logger = logger
        self.verbose = verbose
        self.mode = mode

    @staticmethod
    def _to_label_path(x_path):
        """Map ``<dir>/x/<name>`` to ``<dir>/y/<name>``.

        Only the parent directory named ``x`` is swapped. A plain
        ``str.replace('x', 'y')`` would also rewrite every other 'x' in the
        path (e.g. ``box.png`` -> ``boy.png``). Paths that are not directly
        inside an ``x`` directory fall back to that legacy replacement.
        """
        head, name = os.path.split(x_path)
        parent, leaf = os.path.split(head)
        if leaf == 'x':
            return os.path.join(parent, 'y', name)
        return x_path.replace('x', 'y')

    @staticmethod
    def _read(path, flags):
        """Read an image with OpenCV, raising instead of returning ``None``."""
        img = cv2.imread(path, flags)
        if img is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        return img

    def __len__(self):
        return len(self.x_paths)

    def __getitem__(self, id_: int):
        """Return ``(x, y_seg, filename)`` for train/valid, ``(x, difference_size, filename)`` for test.

        Raises:
            ValueError: if ``self.mode`` is not a supported mode.
            FileNotFoundError: if an input or label image cannot be read.
        """
        # Validate before doing any I/O; `assert` would be stripped under -O.
        if self.mode not in self._LABEL_MODES + self._TEST_MODES:
            raise ValueError(f"Invalid mode : {self.mode}")

        x_path = self.x_paths[id_]
        filename = os.path.basename(x_path)
        image = self._read(x_path, cv2.IMREAD_COLOR)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Split the side-by-side pair. Ending the right half at 2*half keeps
        # both halves the same width when the image width is odd; otherwise
        # the subtraction below fails to broadcast.
        half = image.shape[1] // 2
        img1 = image[:, :half, :]
        img2 = image[:, half:2 * half, :]

        # NOTE(review): imread without IMREAD_ANYDEPTH yields uint8, so this
        # subtraction wraps modulo 256 rather than giving a signed/absolute
        # difference. Kept as-is so inputs presumably match what existing
        # checkpoints were trained on -- confirm before changing.
        difference = img1 - img2
        difference_size = difference.shape

        x = cv2.resize(difference, self.input_size)
        x = self.scaler(x)
        # HWC -> CHW (assumes the scaler preserves the HWC layout).
        x = np.transpose(x, (2, 0, 1))

        if self.mode in self._LABEL_MODES:
            y = self._read(self.y_paths[id_], cv2.IMREAD_GRAYSCALE)
            # The target mask is the right half of the label image.
            y_seg = y[:, y.shape[1] // 2:]
            # Nearest-neighbour keeps class ids intact (no interpolated labels).
            y_seg = cv2.resize(y_seg, self.input_size, interpolation=cv2.INTER_NEAREST)
            return x, y_seg, filename

        return x, difference_size, filename

# Output: NameError: ignored  (likely because `Dataset`, `os`, `cv2` and `np` are only imported in the next cell)

# In [None]:
"""
Predict
"""
from datetime import datetime
from tqdm import tqdm
import numpy as np
import random, os, sys, torch, cv2, warnings
from glob import glob
from torch.utils.data import DataLoader

# Resolve the project root so project packages can be imported. `__file__` is
# undefined in an interactive/notebook session (NameError), so fall back to the
# current working directory there.
try:
    prj_dir = os.path.dirname(os.path.abspath(__file__))
except NameError:
    prj_dir = os.getcwd()
# Avoid piling up duplicate entries when the cell/script is re-run.
if prj_dir not in sys.path:
    sys.path.append(prj_dir)

from baseline_modules.utils import load_yaml, save_yaml, get_logger
from baseline_modules.scalers import get_image_scaler
from baseline_modules.datasets import SegDataset, image_generator
from models.utils import get_model
# Silence all Python warnings for the rest of the process.
# NOTE(review): this is global and also hides potentially useful warnings
# (e.g. from torch/cv2); consider narrowing by category or module.
warnings.filterwarnings('ignore')

if __name__ == '__main__':
    # Run segmentation inference on data/test/x/*.png with the checkpoint of
    # `config['train_serial']` and save one mask PNG per input under
    # results/pred/<serial>/mask/.

    #! Load config
    config = load_yaml(os.path.join(prj_dir, 'config', 'predict.yaml'))
    train_serial = config['train_serial']
    train_result_dir = os.path.join(prj_dir, 'results', 'train', train_serial)
    train_config = load_yaml(os.path.join(train_result_dir, 'train.yaml'))

    #! Set predict serial
    pred_serial = train_serial + '_' + datetime.now().strftime("%Y%m%d_%H%M%S")

    # Set device (GPU/CPU). CUDA_VISIBLE_DEVICES only takes effect if it is set
    # before CUDA is initialised, so do it before any torch.cuda call.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(config['gpu_num'])
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Set random seed, deterministic
    seed = train_config['seed']
    torch.cuda.manual_seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Create prediction result directories (makedirs also creates the parent).
    pred_result_dir = os.path.join(prj_dir, 'results', 'pred', pred_serial)
    pred_result_dir_mask = os.path.join(pred_result_dir, 'mask')
    os.makedirs(pred_result_dir_mask, exist_ok=True)

    # Set logger
    logging_level = 'debug' if config['verbose'] else 'info'
    logger = get_logger(name='train',
                        file_path=os.path.join(pred_result_dir, 'pred.log'),
                        level=logging_level)

    # Set data directory; sort so the processing order is reproducible.
    test_dirs = os.path.join(prj_dir, 'data', 'test')
    test_img_paths = sorted(glob(os.path.join(test_dirs, 'x', '*.png')))
    if not test_img_paths:
        logger.warning(f"No test images found in {os.path.join(test_dirs, 'x')}")

    #! Load data & create dataset for test
    test_dataset = image_generator(paths=test_img_paths,
                                   input_size=[train_config['input_width'], train_config['input_height']],
                                   scaler=get_image_scaler(train_config['scaler']),
                                   mode='test',
                                   logger=logger)

    # Create data loader
    test_dataloader = DataLoader(dataset=test_dataset,
                                 batch_size=config['batch_size'],
                                 num_workers=config['num_workers'],
                                 shuffle=False,
                                 drop_last=False)
    logger.info(f"Load test dataset: {len(test_dataset)}")

    # Load architecture
    model = get_model(model_str=train_config['architecture'])
    model = model(
                classes=train_config['n_classes'],
                encoder_name=train_config['encoder'],
                encoder_weights=train_config['encoder_weight'],
                activation=train_config['activation']).to(device)
    logger.info(f"Load model architecture: {train_config['architecture']}")

    #! Load weight. map_location lets a checkpoint saved on GPU load on a
    # CPU-only host (and vice versa).
    check_point_path = os.path.join(train_result_dir, 'model.pt')
    check_point = torch.load(check_point_path, map_location=device)
    model.load_state_dict(check_point['model'])
    logger.info(f"Load model weight, {check_point_path}")

    # Save config
    save_yaml(os.path.join(pred_result_dir, 'train_config.yml'), train_config)
    save_yaml(os.path.join(pred_result_dir, 'predict_config.yml'), config)

    # Predict
    logger.info("START PREDICTION")

    model.eval()

    with torch.no_grad():

        for x, difference_size, filename in tqdm(test_dataloader):

            x = x.to(device, dtype=torch.float)
            y_pred = model(x)
            y_pred_argmax = y_pred.argmax(1).cpu().numpy().astype(np.uint8)

            # The default collate turns each sample's (h, w, c) shape tuple into
            # one tensor per dimension: [heights, widths, channels].
            heights = difference_size[0].tolist()
            widths = difference_size[1].tolist()

            # Save predict result: a blank left half next to the predicted mask,
            # presumably mirroring the label layout (mask in the right half).
            for filename_, height, width, y_pred_ in zip(filename, heights, widths, y_pred_argmax):
                resized_img = cv2.resize(y_pred_, (width, height), interpolation=cv2.INTER_NEAREST)
                blank_img = np.zeros((height, width), dtype=np.uint8)
                concated_img = cv2.hconcat([blank_img, resized_img])
                out_path = os.path.join(pred_result_dir_mask, filename_)
                # cv2.imwrite reports failure via its return value, not an exception.
                if not cv2.imwrite(out_path, concated_img):
                    logger.warning(f"Failed to write mask: {out_path}")

    logger.info("END PREDICTION")

# Output: NameError: ignored  (likely `__file__`, which is undefined inside a notebook)