In [10]:
# Root directory of the dataset; `test.csv` and the relative image paths it
# lists are resolved against this directory by the inference cell.
dir_data = '../data'
# Checkpoint file; its 'model_state_dict' entry is loaded into the model.
path_ckpt = '../ckpt/1695772427_pl/last_ckpt.bin'
# Pickled dictionaries of per-image masks that are merged into `outside_dict`
# and used to force pixels to class 12 after inference.
# NOTE(review): these are absolute, machine-specific paths (presumably the
# output of a separate OneFormer instance-segmentation run) — they must be
# adjusted before the notebook can run on another machine.
outside1_fname = '/home/eunwoo/experiment/PSSC/Oneformer(instance)/result/background6/background_fish.pickle'
outside2_fname = '/home/eunwoo/experiment/PSSC/Oneformer(instance)/result/background6/background_total.pickle'

In [11]:
import sys
sys.path.append('../')

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

import cv2
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import albumentations as A
from tqdm import tqdm

import torch

from segformers.utils import custom_cmap, rle_encode
from segformers.detectors import Backgroud_detector
from segformers.networks import SegFormer


In [12]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# `map_location` remaps tensors that were saved on a different device (e.g. a
# checkpoint written from another GPU index) onto `device`, so loading also
# works on CPU-only hosts instead of failing while deserializing CUDA tensors.
state_dict = torch.load(path_ckpt, map_location=device)
# NOTE(review): `SegFormer` is used exactly as imported, without being called.
# This only works if `segformers.networks.SegFormer` is already a module
# instance rather than a class — confirm against the package.
model = SegFormer
model.load_state_dict(state_dict['model_state_dict'])
model.to(device);

In [13]:
import torch.nn.functional as F

@torch.no_grad()
def slide_inference(images, model, num_classes=13, crop_size=(1024, 1024), stride=(768, 768)):
    """Run sliding-window inference and accumulate per-pixel logit sums.

    The image batch is covered by overlapping windows of at most `crop_size`,
    placed every `stride` pixels; the last window in each direction is shifted
    back so it ends at the image border. Each window is passed to `model`, the
    first element of the model output is bilinearly resized to the window's
    size, and the result is added into a full-resolution accumulator.

    Args:
        images: 4-D tensor (batch, channels, height, width).
        model: callable invoked as `model(pixel_values=crop)`; element `[0]`
            of its return value must be a (batch, num_classes, h, w) tensor.
        num_classes: number of channels in the accumulated logits.
        crop_size: (height, width) of each window.
        stride: (vertical, horizontal) step between consecutive windows.

    Returns:
        Tuple `(preds, count_mat)`:
        `preds` is the (batch, num_classes, height, width) SUM of logits over
        all windows covering each pixel, and `count_mat` is the
        (batch, 1, height, width) number of windows covering each pixel.
        The sum is deliberately not averaged here so that callers can merge
        several passes before dividing by the combined count.

    Raises:
        AssertionError: if some pixel was not covered by any window.
    """
    h_stride, w_stride = stride
    h_crop, w_crop = crop_size
    batch_size, _, h_img, w_img = images.size()

    # Number of window positions needed to cover the image in each direction.
    h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
    w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
    preds = images.new_zeros((batch_size, num_classes, h_img, w_img))
    count_mat = images.new_zeros((batch_size, 1, h_img, w_img))
    for h_idx in range(h_grids):
        for w_idx in range(w_grids):
            y1 = h_idx * h_stride
            x1 = w_idx * w_stride
            y2 = min(y1 + h_crop, h_img)
            x2 = min(x1 + w_crop, w_img)
            # Shift the window back so border windows keep the full crop size
            # whenever the image is large enough.
            y1 = max(y2 - h_crop, 0)
            x1 = max(x2 - w_crop, 0)
            crop_img = images[:, :, y1:y2, x1:x2]

            crop_seg_logit = model(pixel_values=crop_img)[0]
            # Resize to the ACTUAL window size: it equals `crop_size` except
            # when the image is smaller than the crop, where resizing to
            # `crop_size` would no longer match the accumulator region.
            crop_seg_logit = F.interpolate(
                crop_seg_logit,
                size=(y2 - y1, x2 - x1),
                mode="bilinear",
                align_corners=False
            )
            # Add in place into the covered region instead of zero-padding
            # every window up to a full-image tensor.
            preds[:, :, y1:y2, x1:x2] += crop_seg_logit
            count_mat[:, :, y1:y2, x1:x2] += 1
    assert (count_mat == 0).sum() == 0

    return preds, count_mat

In [14]:
# NOTE: pickle.load can execute arbitrary code embedded in the file — only
# load pickles that come from a trusted source.
with open(outside1_fname, 'rb') as f:
    outside1 = pickle.load(f)

with open(outside2_fname, 'rb') as f:
    outside2 = pickle.load(f)

# Keep `outside1` as boolean masks, as before.
outside1 = {k: v.astype(bool) for k, v in outside1.items()}

# Per-key union of both masks as a 0/1 uint8 array.
# Both operands are reduced to booleans before combining: adding a boolean
# mask to a non-boolean one would give 2 (or more) where the masks overlap,
# and those pixels would then be missed by the later `outside == 1` test.
outside_dict = {
    k: np.logical_or(mask1, np.asarray(outside2[k]).astype(bool)).astype(np.uint8)
    for k, mask1 in outside1.items()
}

In [None]:
df = pd.read_csv(os.path.join(dir_data, 'test.csv'))

# (width, height) of the grid on which predictions are merged and RLE-encoded.
OUT_W, OUT_H = 960, 540
# (width, height) of every test-time scale; the first one equals the output grid.
SCALE_SIZES = [(960, 540), (1200, 675), (1440, 810)]
NUM_CLASSES = 13
# Class 12 is the background: it is never encoded in the submission.
BACKGROUND_ID = 12

# The normalisation transform is loop-invariant, so build it once.
normalize = A.Normalize()

result = []
model.eval()
for idx in tqdm(range(len(df))):
    img_path = os.path.join(dir_data, df.loc[idx, 'img_path'])
    original_image = cv2.imread(img_path)
    # cv2.imread returns None instead of raising when the file is missing or
    # unreadable; fail here with the offending path rather than in cvtColor.
    if original_image is None:
        raise FileNotFoundError(f"cv2.imread could not read image: {img_path}")
    original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)

    # Resize the precomputed mask for this image to the output grid
    # (nearest-neighbour keeps it a 0/1 mask).
    outside = outside_dict[idx][np.newaxis, np.newaxis]
    outside = torch.as_tensor(outside).float().to(device)
    outside = F.interpolate(outside, size=(OUT_H, OUT_W), mode='nearest')
    outside = outside[0][0].cpu().numpy()

    # Multi-scale inference: slide_inference returns logit SUMS and window
    # COUNTS, so the scales are merged by summing both on the output grid and
    # dividing once at the end.
    preds, count_mat = None, None
    for width, height in SCALE_SIZES:
        image = cv2.resize(original_image, (width, height))
        image = normalize(image=image)['image']
        images = torch.as_tensor(image, dtype=torch.float, device=device).permute(2, 0, 1).unsqueeze(0)
        cur_preds, cur_count_mat = slide_inference(
            images, model, num_classes=NUM_CLASSES, stride=(50, 50), crop_size=(512, 512)
        )
        # Scales that are not already on the output grid are resampled to it.
        if (height, width) != (OUT_H, OUT_W):
            cur_preds = F.interpolate(cur_preds, size=(OUT_H, OUT_W), mode="bilinear", align_corners=False)
            cur_count_mat = F.interpolate(cur_count_mat, size=(OUT_H, OUT_W), mode="bilinear", align_corners=False)
        if preds is None:
            preds, count_mat = cur_preds, cur_count_mat
        else:
            preds += cur_preds
            count_mat += cur_count_mat

    logits = preds / count_mat
    _, predictions = logits.max(1)

    predictions = predictions[0].cpu().numpy()
    # Pixels flagged by the precomputed mask are forced to the background class.
    predictions[outside == 1] = BACKGROUND_ID
    predictions = predictions.astype(np.int32)
    # Build one mask per class 0..11; class 12 (background) is skipped.
    for class_id in range(BACKGROUND_ID):
        class_mask = (predictions == class_id).astype(np.int32)
        if np.sum(class_mask) > 0:  # the class is present: store its RLE
            mask_rle = rle_encode(class_mask)
            result.append(mask_rle)
        else:  # the class is absent: store -1
            result.append(-1)

  0%|          | 0/1898 [00:00<?, ?it/s]

100%|██████████| 1898/1898 [5:43:31<00:00, 10.86s/it]  


In [None]:
import pandas as pd

# Read the sample submission from the same data directory as `test.csv`
# so the two cannot drift apart.
submit = pd.read_csv(os.path.join(dir_data, 'sample_submission.csv'))
# `result` holds one entry per (image, class) pair in inference order; check
# the length explicitly so a partial or repeated inference run is reported
# clearly before the column is overwritten.
if len(result) != len(submit):
    raise ValueError(
        f"result has {len(result)} entries but the submission expects {len(submit)} rows"
    )
submit['mask_rle'] = result
submit

Unnamed: 0,id,mask_rle
0,TEST_0000_class_0,212629 3 212637 4 213588 22 214532 3 214545 32...
1,TEST_0000_class_1,-1
2,TEST_0000_class_2,597 281 1557 281 2517 281 3476 282 4436 282 53...
3,TEST_0000_class_3,207753 8 208709 25 208777 9 209663 85 210618 9...
4,TEST_0000_class_4,-1
...,...,...
22771,TEST_1897_class_7,152250 8 153208 14 153224 12 154166 35 155124 ...
22772,TEST_1897_class_8,95 539 678 125 1055 539 1639 124 2015 539 2599...
22773,TEST_1897_class_9,-1
22774,TEST_1897_class_10,-1


In [None]:
# Write the submission table to a CSV file in the notebook's working
# directory; index=False keeps the DataFrame index out of the file.
submit.to_csv('./segformer-b5_multiscale_inference_pl2.csv', index=False)