**요약**
- 학습된 SegFormerDANN으로부터 테스트 데이터에 대한 예측을 합니다.

<br>

**Inputs:**
- `dir_data`: 데이터가 있는 디렉토리
- `dir_save`: 각 테스트 이미지에 대한 logit 파일을 저장할 폴더
- `path_ckpt`: Inference에 사용할 SegFormerDANN 모델의 체크포인트 경로

<br>

**Outputs:**
- `{dir_save}/prediction_{idx}.pt`: 각 테스트 이미지(`idx`번째)에 대한 logit이 저장된 `pt` 파일

In [1]:
# Dataset root: test.csv is read from here and its img_path entries are joined onto it.
dir_data = '../data'
# Folder that receives one logit tensor file (.pt) per test image.
dir_save = '../outputs/SegFormer_DANN'
# Checkpoint file whose 'model_state_dict' entry is loaded into the model.
path_ckpt = '../ckpt/segformer_dann/best_ckpt_0048.bin'

In [2]:
import sys
sys.path.append('../')

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

import cv2
import pickle
import numpy as np
import pandas as pd
import albumentations as A
from tqdm import tqdm

import torch

from segformers.utils import rle_encode
from segformers.networks import SegFormer
from dann.DomainAdaptation import SegFormerDANN2


  from .autonotebook import tqdm as notebook_tqdm
Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/segformer-b5-finetuned-cityscapes-1024-1024 and are newly initialized because the shapes did not match:
- decode_head.classifier.weight: found shape torch.Size([19, 768, 1, 1]) in the checkpoint and torch.Size([13, 768, 1, 1]) in the model instantiated
- decode_head.classifier.bias: found shape torch.Size([19]) in the checkpoint and torch.Size([13]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of Mask2FormerForUniversalSegmentation were not initialized from the model checkpoint at facebook/mask2former-swin-large-cityscapes-semantic and are newly initialized because the shapes did not match:
- class_predictor.bias: found shape torch.Size([20]) in the checkpoint and torch.Size([14]) in the model instantiated
- class_predictor.weigh

In [5]:
# Run on the GPU when one is visible to this process, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# map_location remaps the stored tensors onto `device`, so loading does not
# depend on which device index the checkpoint was originally saved from.
state_dict = torch.load(path_ckpt, map_location=device)
model = SegFormerDANN2(segformer=SegFormer)
model.load_state_dict(state_dict['model_state_dict'])
# Trailing semicolon keeps the notebook from echoing the module repr.
model.to(device);

# Create the output folder (and any missing parents); no-op if it exists.
os.makedirs(dir_save, exist_ok=True)

In [4]:
import torch.nn.functional as F


@torch.no_grad()
def slide_inference(images, model, num_classes=13, crop_size=(1024, 1024), stride=(768, 768)):
    """Accumulate segmentation logits over overlapping sliding windows.

    The image batch is covered by windows of ``crop_size`` placed every
    ``stride`` pixels. Windows that would run past the border are shifted
    back inside the image. The first element of the model output for each
    window is bilinearly resized to the window size and added into a
    full-size accumulator.

    Args:
        images: 4-D tensor unpacked as ``(batch, channels, height, width)``.
        model: Callable invoked as ``model(crop)``; if that raises
            ``TypeError`` it is retried as ``model(pixel_values=crop)``.
            Element ``[0]`` of the result is used as the logits.
        num_classes: Number of channels in the logit accumulator.
        crop_size: ``(height, width)`` of each window.
        stride: ``(vertical, horizontal)`` step between windows.

    Returns:
        tuple: ``(preds, count_mat)``. ``preds`` has shape
        ``(batch, num_classes, height, width)`` and holds the SUM of the
        window logits at each pixel. ``count_mat`` has shape
        ``(batch, 1, height, width)`` and holds how many windows covered
        each pixel. The caller divides to obtain the average.
    """
    h_stride, w_stride = stride
    h_crop, w_crop = crop_size
    batch_size, _, h_img, w_img = images.size()

    # Number of window positions per axis; always at least one.
    h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
    w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
    preds = images.new_zeros((batch_size, num_classes, h_img, w_img))
    count_mat = images.new_zeros((batch_size, 1, h_img, w_img))
    for h_idx in range(h_grids):
        for w_idx in range(w_grids):
            # Clamp the far edge to the image, then pull the near edge back so
            # border windows keep the full crop size whenever the image allows.
            y2 = min(h_idx * h_stride + h_crop, h_img)
            x2 = min(w_idx * w_stride + w_crop, w_img)
            y1 = max(y2 - h_crop, 0)
            x1 = max(x2 - w_crop, 0)
            crop_img = images[:, :, y1:y2, x1:x2]

            try:
                crop_seg_logit = model(crop_img)[0]
            except TypeError:
                # Fallback for models that only accept the keyword form.
                crop_seg_logit = model(pixel_values=crop_img)[0]

            # Resize to the ACTUAL window size: it is smaller than crop_size
            # when the image itself is smaller than the crop.
            crop_seg_logit = F.interpolate(
                crop_seg_logit,
                size=(y2 - y1, x2 - x1),
                mode="bilinear",
                align_corners=False
            )
            # Add in place into the window region instead of zero-padding the
            # window up to a full-size temporary tensor.
            preds[:, :, y1:y2, x1:x2] += crop_seg_logit
            count_mat[:, :, y1:y2, x1:x2] += 1

    # Internal invariant: the window grid must cover every pixel.
    assert (count_mat == 0).sum() == 0

    return preds, count_mat

In [None]:
df = pd.read_csv(os.path.join(dir_data, 'test.csv'))

# (width, height) targets handed to cv2.resize. The first entry is the base
# resolution: logits from the other scales are resampled onto it.
scales = [(960, 540), (1200, 675), (1440, 810)]
normalize = A.Normalize()

result = []
model.eval()
for idx, rel_path in enumerate(tqdm(df['img_path'])):
    img_path = os.path.join(dir_data, rel_path)
    original_image = cv2.imread(img_path)
    if original_image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {img_path}")
    original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)

    # Multi-scale sliding-window inference: accumulate summed logits and
    # window counts from every scale at the base resolution.
    preds = None
    count_mat = None
    base_size = None
    for width, height in scales:
        image = cv2.resize(original_image, (width, height))
        image = normalize(image=image)['image']
        images = torch.as_tensor(image, dtype=torch.float, device=device).permute(2, 0, 1).unsqueeze(0)
        cur_preds, cur_count_mat = slide_inference(
            images, model, num_classes=13, stride=(50, 50), crop_size=(512, 512)
        )
        if preds is None:
            # The first scale is used as-is and fixes the output size.
            preds, count_mat = cur_preds, cur_count_mat
            base_size = (height, width)
        else:
            preds += F.interpolate(cur_preds, size=base_size, mode="bilinear", align_corners=False)
            count_mat += F.interpolate(cur_count_mat, size=base_size, mode="bilinear", align_corners=False)

    # Average over all windows and scales, then take the arg-max class.
    logits = preds / count_mat
    _, predictions = logits.max(1)

    # Save the logits tensor (before any background handling) as a .pt file.
    tensor_save_path = os.path.join(dir_save, f"prediction_{idx}.pt")
    torch.save(logits, tensor_save_path)

    predictions = predictions[0].cpu().numpy()
    predictions = predictions.astype(np.int32)

    # Build a mask for each of classes 0-11; class 12 (background) is skipped.
    for class_id in range(12):
        class_mask = (predictions == class_id).astype(np.int32)
        if np.sum(class_mask) > 0:
            # The class is present in the prediction: RLE-encode its mask.
            result.append(rle_encode(class_mask))
        else:
            # The class is absent: record -1.
            result.append(-1)

In [None]:
import pandas as pd

# Load the submission template from the configured data directory rather than
# a hard-coded path, so it follows `dir_data` like the rest of the notebook.
submit = pd.read_csv(os.path.join(dir_data, 'sample_submission.csv'))
# One entry per (image, class) pair, in the order the inference loop appended them.
submit['mask_rle'] = result
# Bare expression: the notebook displays the resulting table.
submit

In [21]:
# Write the submission table to disk without the pandas index column.
submission_path = './segformer-dannv2.csv'
submit.to_csv(submission_path, index=False)