In [1]:
import os
import random
import argparse
import numpy as np

import torch
from param import CONFIG

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
parser = argparse.ArgumentParser()
# General Settings
parser.add_argument('--seed', type=int, default=42)
parser.add_argument('--device', type=str, default='0')

parser.add_argument('--root', type=str, default='../dataset/image')

parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--input_size', nargs='+', type=int, default=[128, 128])

args = parser.parse_args(args=[])

In [3]:
device = 'cuda:' + args.device
args.device = torch.device(device)
torch.cuda.set_device(args.device)

random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)

In [14]:
from supernet_dense import SuperNet, SampledNetwork

CONFIG["GPU"] = [0]

sampled_net_path = 'darts_based_seg/output/2024-05-28/16_13_56/best_model.pt'

checkpoint = torch.load(sampled_net_path)
model = checkpoint['model']
model.load_state_dict(checkpoint['model_state_dict'])

SampledNetwork(
  (pretrained_net): DenseNet(
    (features): Sequential(
      (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu0): ReLU(inplace=True)
      (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (denseblock1): _DenseBlock(
        (denselayer1): _DenseLayer(
          (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu1): ReLU(inplace=True)
          (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu2): ReLU(inplace=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (denselayer2): _DenseLayer(
          (norm1): BatchNorm2d(96, eps=1e-05

In [16]:
import cv2
from torch.utils.data import Dataset

class InferenceDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.data = []
        self.data_dir = data_dir
        self.transform = transform

        for folder in self.data_dir:
            for file in os.listdir(folder):
                if file.endswith(".jpg"):
                    self.data.append(os.path.join(folder, file))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # check if data and label are the same name after ../(dir)/
        image = cv2.imread(self.data[idx])
        file_name = self.data[idx]

        if self.transform:
            image = self.transform(image)
        
        return image, file_name

In [20]:
from torchvision import transforms
from torch.utils.data import DataLoader

from dataloaders import load_data, train_val_test_split

transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(args.input_size),
    ])

data = load_data(args.root)
_, _, test_data = train_val_test_split(data)
test_dataset = InferenceDataset(test_data, transform)
test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

### Hot-Stamping

In [31]:
def expand_range(range_str):
    start, end = range_str.split('~')
    prefix = start[:-4]
    start_num = int(start[-4:])
    end_num = int(end[-4:])
    return [f"{prefix}{i:04d}" for i in range(start_num, end_num + 1)]

hot_stamping = {
    "CE": (
        expand_range("F003-0003~0008") +
        expand_range("F003-0013~0025") +
        expand_range("F005-0001~0049")
    ),
    "DF": expand_range("F001-0012~0023"),
    "GN7 일반": (
        expand_range("F001-0029~0041") +
        expand_range("F002-0013~0031") +
        expand_range("F003-0010~0018")
    ),
    "GN7 파노라마": (
        expand_range("F001-0029~0041") +
        expand_range("F002-0013~0031") +
        expand_range("F003-0010~0019")
    )
}

In [32]:
# sampled_net.eval()
with torch.no_grad():
    for batch_idx, (data, file_path) in enumerate(test_loader):
        
        for i in range(data.size()[0]):
            path_list = file_path[i].split('/')
            category = path_list[3].split('~')[0]
            file_name = os.path.splitext(path_list[4])[0]

            hot_stamping_list = hot_stamping[category]

            if file_name in hot_stamping_list:
                print(category + "/" + file_name)


CE/F003-0024
CE/F005-0010
CE/F005-0034
CE/F003-0016
CE/F005-0045
CE/F005-0031
CE/F003-0014
CE/F005-0041
CE/F005-0037
CE/F005-0035
CE/F005-0008
CE/F005-0026
CE/F005-0048
CE/F003-0020
CE/F003-0006
CE/F005-0030
CE/F003-0004
CE/F005-0038
CE/F005-0043
CE/F005-0049
CE/F003-0023
CE/F005-0023
CE/F003-0022
CE/F005-0003
CE/F003-0025
CE/F005-0029
CE/F003-0013
CE/F003-0017
CE/F005-0007
CE/F005-0020
CE/F005-0009
CE/F005-0019
CE/F005-0024
CE/F003-0005
CE/F005-0022
CE/F005-0039
CE/F005-0028
CE/F005-0033
CE/F003-0007
CE/F005-0005
CE/F005-0040
CE/F003-0018
CE/F005-0044
CE/F005-0032
CE/F003-0008
CE/F005-0027
CE/F005-0018
CE/F005-0006
CE/F005-0013
CE/F005-0002
CE/F005-0017
CE/F005-0046
CE/F005-0036
CE/F005-0016
CE/F003-0021
CE/F005-0042
CE/F005-0015
CE/F003-0019
CE/F005-0011
CE/F005-0021
CE/F005-0004
CE/F005-0012
CE/F005-0025
CE/F003-0015
CE/F005-0047
DF/F001-0018
DF/F001-0023
DF/F001-0013
DF/F001-0019
DF/F001-0022
DF/F001-0014
DF/F001-0017
DF/F001-0021
DF/F001-0015
DF/F001-0012
DF/F001-0016
DF/F001-0020

### Inference

In [None]:
output_dir = 'prediction/'
if not os.path.exists(output_dir):
    os.makedirs(output_dir)