In [1]:
import os
import random
import argparse
import numpy as np

import torch
import torch.backends.cudnn as cudnn

In [2]:
parser = argparse.ArgumentParser()
# General Settings
parser.add_argument('--seed', type=int, default=42)
parser.add_argument('--device', type=str, default='0')

parser.add_argument('--root', type=str, default='./dataset/inference')

parser.add_argument('--model', type=str, default='DeepLabv3', choices=['DeepLabv3', 'ESPNet', 'STDC'])

parser.add_argument('--batch_size', type=int, default=8)
parser.add_argument('--input_size', nargs='+', type=int, default=[128, 128])

args = parser.parse_args(args=[])

In [6]:
device = 'cuda:' + args.device
args.device = torch.device(device)
torch.cuda.set_device(args.device)

random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
cudnn.deterministic = True
cudnn.benchmark = False

TypeError: can only concatenate str (not "torch.device") to str

In [None]:
from deeplabv3plus.seg_models.models import load_model

model_path = 'deeplabv3/output/2024-06-08/23_59_15/best_model.pt'

model = load_model(args.model)
model.load_state_dict(torch.load(model_path))
model.to(device)

DeepLabV3Plus(
  (encoder): ResNetEncoder(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequentia

In [7]:
from torchvision import transforms
from torch.utils.data import DataLoader

from deeplabv3.seg_utils.dataset_utils import load_data, InferenceDataset

transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(args.input_size),
    ])

inf_data = load_data(args.root)

inf_dataset = InferenceDataset(inf_data, transform)
inf_loader = DataLoader(inf_dataset, batch_size=args.batch_size, shuffle=False)

inf_data

['./dataset/inference/GN7 파노라마~07_08_45',
 './dataset/inference/GN7 파노라마~06_24_31',
 './dataset/inference/GN7 파노라마~08_02_05',
 './dataset/inference/GN7 파노라마~08_20_09',
 './dataset/inference/GN7 파노라마~07_57_13',
 './dataset/inference/GN7 파노라마~07_39_39',
 './dataset/inference/GN7 파노라마~07_48_07',
 './dataset/inference/GN7 파노라마~07_43_34',
 './dataset/inference/GN7 파노라마~06_54_04',
 './dataset/inference/GN7 파노라마~08_20_12',
 './dataset/inference/GN7 파노라마~07_26_34',
 './dataset/inference/GN7 파노라마~07_13_58',
 './dataset/inference/GN7 파노라마~06_21_05',
 './dataset/inference/GN7 파노라마~08_12_20',
 './dataset/inference/GN7 파노라마~07_43_55',
 './dataset/inference/GN7 파노라마~06_51_15',
 './dataset/inference/GN7 파노라마~08_05_55',
 './dataset/inference/GN7 파노라마~07_33_53',
 './dataset/inference/GN7 파노라마~07_40_15',
 './dataset/inference/GN7 파노라마~07_14_33',
 './dataset/inference/GN7 파노라마~08_08_12',
 './dataset/inference/GN7 파노라마~08_03_45',
 './dataset/inference/GN7 파노라마~08_20_44',
 './dataset/inference/GN7 파노라마~07_

In [8]:
import torchvision.transforms.functional as TF

model.eval()
with torch.no_grad():
    for batch_idx, (data, file_name) in enumerate(inf_loader):
        data = data.to(device)

        output = model(data)
        output = (output > 0.5).float()
        
        for i in range(data.size()[0]):
            file_idx = file_name[i].split('/')
            folder_name = file_idx[3].split('~')[0]
            file_idx = '/'.join(file_idx[-2:])

            # ori_img = data[i].cpu().numpy().transpose(1, 2, 0) * 255.0
            # ori_img = TF.to_pil_image(ori_img.astype(np.uint8))
            
            out_img = output[i].cpu().numpy().squeeze() * 255.0
            out_img = TF.to_pil_image(out_img.astype(np.uint8))

            # combined_img = ori_img.copy()
            # combined_img[out_img == 255] = [255, 255, 255]  # White color for mask
            # combined_img[out_img != 255] = [0, 0, 0]  # White color for mask
            # combined_img = TF.to_pil_image(combined_img.astype(np.uint8))

            save_path = os.path.join('results', folder_name, file_idx)
            print(save_path)
            os.makedirs(os.path.dirname(save_path), exist_ok=True)

            out_img.save(save_path)
                
print("Finished")

Infernece
results/GN7 파노라마/GN7 파노라마~07_08_45/F001-0018.jpg
results/GN7 파노라마/GN7 파노라마~07_08_45/F001-0023.jpg
results/GN7 파노라마/GN7 파노라마~07_08_45/F002-0040.jpg
results/GN7 파노라마/GN7 파노라마~07_08_45/F002-0047.jpg
results/GN7 파노라마/GN7 파노라마~07_08_45/F003-0016.jpg
results/GN7 파노라마/GN7 파노라마~07_08_45/F002-0019.jpg
results/GN7 파노라마/GN7 파노라마~07_08_45/F004-0002.jpg
results/GN7 파노라마/GN7 파노라마~07_08_45/F001-0024.jpg
Infernece
results/GN7 파노라마/GN7 파노라마~07_08_45/F003-0014.jpg
results/GN7 파노라마/GN7 파노라마~07_08_45/F001-0029.jpg
results/GN7 파노라마/GN7 파노라마~07_08_45/F002-0009.jpg
results/GN7 파노라마/GN7 파노라마~07_08_45/F002-0006.jpg
results/GN7 파노라마/GN7 파노라마~07_08_45/F002-0049.jpg
results/GN7 파노라마/GN7 파노라마~07_08_45/F002-0048.jpg
results/GN7 파노라마/GN7 파노라마~07_08_45/F002-0046.jpg
results/GN7 파노라마/GN7 파노라마~07_08_45/F002-0024.jpg
Infernece
results/GN7 파노라마/GN7 파노라마~07_08_45/F001-0034.jpg
results/GN7 파노라마/GN7 파노라마~07_08_45/F004-0001.jpg
results/GN7 파노라마/GN7 파노라마~07_08_45/F002-0013.jpg
results/GN7 파노라마/GN7 파노라마~07_08_45/F003