In [1]:
import os
import random
import argparse
import numpy as np

import torch
import torch.backends.cudnn as cudnn

In [2]:
parser = argparse.ArgumentParser()
# General Settings
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--device', type=str, default='0')

parser.add_argument('--root', type=str, default='../data/inference')

parser.add_argument('--model', type=str, default='DeepLabv3', choices=['DeepLabv3', 'ESPNet', 'STDC'])

parser.add_argument('--batch_size', type=int, default=8)
parser.add_argument('--input_size', nargs='+', type=int, default=[128, 128])

args = parser.parse_args(args=[])

In [3]:
device = 'cuda:' + args.device
args.device = torch.device(device)
torch.cuda.set_device(args.device)

random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
cudnn.deterministic = True
cudnn.benchmark = False

In [4]:
from seg_models.models import load_model

model_path = 'output/2024-05-27/23_20_33/best_model.pt'

model = load_model(args.model)
model.load_state_dict(torch.load(model_path))
model.to(device)

DeepLabV3Plus(
  (encoder): ResNetEncoder(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequentia

In [5]:
from torchvision import transforms
from torch.utils.data import DataLoader

from seg_utils.dataset_utils import load_data, InferenceDataset

transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(args.input_size),
    ])

inf_data = load_data(args.root)

inf_dataset = InferenceDataset(inf_data, transform)
inf_loader = DataLoader(inf_dataset, batch_size=args.batch_size, shuffle=False)

In [6]:
import torchvision.transforms.functional as TF

model.eval()
with torch.no_grad():
    for batch_idx, (data, file_name) in enumerate(inf_loader):
        data = data.to(device)

        output = model(data)
        output = (output > 0.5).float()
        
        for i in range(data.size()[0]):
                file_idx = file_name[i].split('/')
                folder_name = file_idx[3].split('~')[0]
                file_idx = '/'.join(file_idx[-2:])

                # ori_img = data[i].cpu().numpy().transpose(1, 2, 0) * 255.0
                # ori_img = TF.to_pil_image(ori_img.astype(np.uint8))
                
                out_img = output[i].cpu().numpy().squeeze() * 255.0
                out_img = TF.to_pil_image(out_img.astype(np.uint8))

                # combined_img = ori_img.copy()
                # combined_img[out_img == 255] = [255, 255, 255]  # White color for mask
                # combined_img[out_img != 255] = [0, 0, 0]  # White color for mask
                # combined_img = TF.to_pil_image(combined_img.astype(np.uint8))

                save_path = os.path.join('results', folder_name, file_idx)
                os.makedirs(os.path.dirname(save_path), exist_ok=True)

                out_img.save(save_path)
                
print("Finished")

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Finished
