In [1]:
# !pip install gdown > /dev/null
# !pip install pytorch_lightning > /dev/null
# !pip install omegaconf > /dev/null
# !pip install einops > /dev/null
# !pip install torchmetrics > /dev/null
# !pip install wandb > /dev/null
# !pip install lightning-bolts > /dev/null
# !pip install albumentations > /dev/null
# !pip install timm > /dev/null

In [2]:
import os
import glob

import cv2
import gdown
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image

import ttach as tta
import segmentation_models_pytorch as smp
from torch.utils.data import DataLoader, Dataset

In [3]:
def show_image(image, figsize=(5, 5), cmap=None, title='', xlabel=None, ylabel=None, axis=False):
    """Display a single image in a new matplotlib figure.

    Args:
        image: Image data accepted by matplotlib's ``imshow``.
        figsize: Size handed to ``plt.figure``.
        cmap: Optional colormap forwarded to ``imshow``.
        title: Text shown above the axes.
        xlabel: Label for the x axis.
        ylabel: Label for the y axis.
        axis: Value forwarded to ``axis()``; ``False`` turns the axes off.
    """
    figure = plt.figure(figsize=figsize)
    axes = figure.gca()
    axes.imshow(image, cmap=cmap)
    axes.set_title(title)
    axes.set_xlabel(xlabel)
    axes.set_ylabel(ylabel)
    axes.axis(axis)
    plt.show()
    
def show_image_mask(image, mask, name='', figsize=(10, 15), axis=False):
    """Display an image and its mask side by side in one figure.

    Args:
        image: Data drawn in the left panel.
        mask: Data drawn in the right panel.
        name: Suffix appended to both panel titles.
        figsize: Size handed to ``plt.figure``.
        axis: Value forwarded to ``plt.axis`` for both panels.
    """
    plt.figure(figsize=figsize)

    # (subplot position, data, title) for the left and right panel.
    panels = (
        (1, image, f"Image {name}"),
        (2, mask, f"Mask {name}"),
    )
    for position, picture, caption in panels:
        plt.subplot(1, 2, position)
        plt.imshow(picture)
        plt.title(caption)
        plt.axis(axis)

    plt.show()
    
def get_img_names(folder, img_format='png'):
    """Return the sorted base names of the ``*.<img_format>`` files in ``folder``.

    Args:
        folder: Directory to scan (not recursive).
        img_format: File extension to match, without the leading dot.

    Returns:
        List of file names (no directory part) in lexicographic order. The
        list is empty when the folder does not exist or has no matching files.
    """
    # Escape the folder so characters such as '[' or '*' in the path are taken
    # literally instead of being interpreted as glob wildcards.
    pattern = os.path.join(glob.escape(os.fspath(folder)), f'*.{img_format}')
    # glob returns matches in arbitrary order; sort for a reproducible run.
    return sorted(os.path.basename(path) for path in glob.glob(pattern))

def preprocess_image(image):
    """Convert a loaded image into the tensor layout fed to the models.

    The image is resized to ``GLOBAL_CONFIG['IMG_W']`` x ``GLOBAL_CONFIG['IMG_H']``,
    converted with ``cv2.COLOR_BGR2RGB``, cast to float32, divided by 255 and
    permuted from (H, W, C) to (C, H, W).

    Args:
        image: Numpy image whose channels are treated as BGR.
            Presumably 8-bit, so the division maps values into [0, 1] -
            TODO confirm against the caller.

    Returns:
        A float32 ``torch.Tensor`` in channels-first order.
    """
    source = image.copy()
    # cv2.resize takes the target size as (width, height).
    target_size = (GLOBAL_CONFIG['IMG_W'], GLOBAL_CONFIG['IMG_H'])
    rgb = cv2.cvtColor(cv2.resize(source, target_size), cv2.COLOR_BGR2RGB)
    scaled = rgb.astype(np.float32) / 255.0
    return torch.from_numpy(scaled).permute(2, 0, 1)

def process_multimask2np(image, labels=None):
    """Collapse a per-class mask stack into a single 2-D label map.

    Args:
        image: Tensor indexed as (channel, height, width). Every non-zero
            entry marks the pixel as belonging to that channel.
        labels: Sequence giving the value written to the output for each
            channel index. Defaults to the module-level ``LABELS``. Values
            must fit in uint8 because the output array has that dtype.

    Returns:
        A uint8 numpy array of shape (height, width). Pixels claimed by no
        channel stay 0; when several channels claim the same pixel, the
        channel with the highest index wins.
    """
    if labels is None:
        labels = LABELS

    # detach() lets tensors that still require grad be converted to numpy;
    # astype(bool) already produces a fresh array, so no clone is needed.
    channel_masks = image.detach().cpu().permute(1, 2, 0).numpy().astype(bool)
    height, width, n_channels = channel_masks.shape
    mask = np.zeros((height, width), dtype=np.uint8)

    # Paint channels in ascending order so later channels override earlier ones.
    for c_index in range(n_channels):
        mask[channel_masks[:, :, c_index]] = labels[c_index]

    return mask

class TestDataset(Dataset):
    """Dataset that loads test images from a folder and prepares them for inference.

    Each item is a dict with the preprocessed image tensor, the file name and
    the height/width the image had on disk, so predictions can be resized back
    to the original resolution.
    """

    def __init__(self, images_names, images_folder, augmentations=None):
        """Store the file list and the optional augmentation callable.

        Args:
            images_names: Sequence of image file names inside ``images_folder``.
            images_folder: Directory containing the images.
            augmentations: Optional callable invoked as
                ``augmentations(image=image)['image']`` before preprocessing.
        """
        self.images_folder = images_folder
        self.images_names = images_names
        self.augmentations = augmentations

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.images_names)

    def __getitem__(self, index):
        """Load, optionally augment and preprocess the image at ``index``.

        Returns:
            Dict with keys ``'image'`` (output of ``preprocess_image``),
            ``'image_name'``, ``'orig_h'`` and ``'orig_w'``.

        Raises:
            OSError: If ``cv2.imread`` cannot read the file.
        """
        img_name = self.images_names[index]
        img_path = os.path.join(self.images_folder, img_name)

        image = cv2.imread(img_path)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise OSError(f"cv2.imread could not read image (missing or undecodable file): {img_path}")
        # Remember the on-disk size before any augmentation or resizing.
        h, w = image.shape[:2]

        if self.augmentations is not None:
            image = self.augmentations(image=image)['image']

        image = preprocess_image(image)
        return {'image': image, 'image_name': img_name, 'orig_h': h, 'orig_w': w}

In [4]:
# Folder with the test images. Set the TEST_IMAGES_FOLDER environment variable
# to point the notebook at another location without editing this cell.
TEST_IMAGES_FOLDER = os.environ.get(
    'TEST_IMAGES_FOLDER',
    '/home/raid/datasets/niias_rzd/data/test/images/',
)
TEST_IMG_NAMES = get_img_names(TEST_IMAGES_FOLDER)
# Fail early: with no images the inference loop would silently write nothing.
if not TEST_IMG_NAMES:
    raise FileNotFoundError(f'No test images found in {TEST_IMAGES_FOLDER!r}')

# Value written to the output mask for each model output channel, in channel
# order; its length is also used as the number of classes of every model.
LABELS = [0, 6, 7, 10]

In [5]:
GLOBAL_CONFIG = {
    # Run on the first GPU when CUDA is available, otherwise fall back to CPU
    # so the notebook still executes on machines without a GPU.
    'device': 'cuda:0' if torch.cuda.is_available() else 'cpu',

    # Resolution every image is resized to before it is fed to the models.
    'IMG_H': 512,
    'IMG_W': 512 * 2,

    # Folder the predicted masks are written to.
    'predict_folder': '../output/submission_images',
}

In [10]:
## MODEL DOWNLOAD (OPTIONAL)
# Fetches the ensemble checkpoints from Google Drive into ../models.
# Checkpoints that are already present on disk are not downloaded again.

models_storage = [
    {'name': 'unetplus_resnet34_0.pt', 'id': '1S-UeCavyHUtX7GBT1Uijdb2SXYYpRSEh'},
    {'name': 'unetplus_resnet34_1.pt', 'id': '1CWk7h0gUDUu4abCiWY7nW16sz4tb_nC1'},
    {'name': 'unetplus_resnet34_2.pt', 'id': '17S7t8W2gDTt1IQhck8Jqxh-Muzy3lDAb'},

    {'name': 'unet_resnet34_0.pt', 'id': '11FnssFXa2f6zZEV9COcMJ61rSsjMIITQ'},
    {'name': 'unet_resnet34_1.pt', 'id': '1AGOyCzjFxivAxJAkxFY3NJVCYNm19RAe'},
    {'name': 'unet_resnet34_2.pt', 'id': '1Aa6z6zG1SjPwhZWZ-od27yvO4a-uWOml'},

    {'name': 'unet_resnext50_32x4d_0.pt', 'id': '1tAmGGeXlOU7LR5umo3-L8P8ws7qQYSpE'},
    {'name': 'unet_resnext50_32x4d_1.pt', 'id': '1_xY0ZgXaR4dELnsAiRDbXZJAkeZffdNO'},
    {'name': 'unet_resnext50_32x4d_2.pt', 'id': '1fhW2dx8mCrBF1TL-PCqY5pqTleacnEB8'},
]

url_template = 'https://drive.google.com/uc?id={}'

# The target folder is not guaranteed to exist on a fresh checkout.
models_dir = '../models'
os.makedirs(models_dir, exist_ok=True)

for item in tqdm(models_storage):
    out_name = os.path.join(models_dir, item['name'])
    # Skip files fetched by a previous run so re-running the cell is cheap.
    if os.path.exists(out_name):
        continue
    url = url_template.format(item['id'])
    # Treat a None result as a failed download instead of silently moving on.
    if gdown.download(url, out_name, quiet=True) is None:
        raise RuntimeError(f"Failed to download {item['name']} from {url}")

In [11]:
# Ensemble layout: (model class, encoder name, entry name). Every combination
# has three checkpoints stored as ../models/<entry name>_<0..2>.pt.
_ENSEMBLE_SPECS = [
    (smp.Unet, 'resnet34', 'unet_resnet34'),
    # The 'name' of these entries used to read 'fpn_resnext50_32x4d' although
    # they are smp.Unet models loaded from unet_resnext50_32x4d_*.pt.
    (smp.Unet, 'resnext50_32x4d', 'unet_resnext50_32x4d'),
    (smp.UnetPlusPlus, 'resnet34', 'unetplus_resnet34'),
]

# NOTE(review): the nine weights sum to 0.9, not 1.0. The inference cell
# thresholds the weighted sum with round(), so this affects which pixels are
# assigned a class - confirm this is intended before changing it.
models = [
    {
        # encoder_weights=None: the checkpoint loaded below overwrites every
        # parameter, so fetching pretrained encoder weights would be wasted.
        'model': arch(encoder_name=encoder_name, encoder_weights=None, classes=len(LABELS)),
        'ckpt_path': f'../models/{name}_{ckpt_index}.pt',
        'name': name,
        'use_tta': False,
        'weight': 0.1,
    }
    for arch, encoder_name, name in _ENSEMBLE_SPECS
    for ckpt_index in range(3)
]


# Test-time augmentations; only applied to entries with 'use_tta' set to True.
transforms = tta.Compose(
    [
        tta.HorizontalFlip(),
#         tta.Rotate90(angles=[0, 90, 180]),
        tta.Scale(scales=[1, 2, 4]),
        tta.Multiply(factors=[0.9, 1.1]),
    ]
)

for item in models:
    # map_location='cpu' makes loading independent of the device the
    # checkpoint was saved from; the model is moved to the target device below.
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    ckpt = torch.load(item['ckpt_path'], map_location='cpu')
    item['model'].load_state_dict(ckpt['state_dict'])
    item['model'] = item['model'].eval().to(GLOBAL_CONFIG['device'])

    if item['use_tta']:
        item['model'] = tta.SegmentationTTAWrapper(item['model'], transforms, merge_mode='mean')

In [12]:
# Serve the test images one at a time, in list order, without dropping any.
test_dataset = TestDataset(images_names=TEST_IMG_NAMES, images_folder=TEST_IMAGES_FOLDER)
test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    drop_last=False,
    num_workers=1,
)

In [13]:
os.makedirs(GLOBAL_CONFIG['predict_folder'], exist_ok=True)

for batch in tqdm(test_loader):
    # Move the input to the device once per image instead of once per model.
    image = batch['image'].to(GLOBAL_CONFIG['device'])

    # Weighted sum of the per-model softmax outputs, accumulated on the CPU.
    probs = torch.zeros(len(LABELS), GLOBAL_CONFIG['IMG_H'], GLOBAL_CONFIG['IMG_W'], dtype=torch.float32, device='cpu')

    with torch.no_grad():
        for item in models:
            pr = item['model'](image)
            # batch_size is 1, so [0] selects the only sample in the batch.
            pr = F.softmax(pr.cpu(), dim=1)[0]
            probs += pr * item['weight']

    # NOTE(review): round() keeps a class only where its weighted probability
    # rounds up to 1; every other pixel stays 0 in the label map. The weights
    # defined with the models sum to 0.9, so the effective threshold is
    # stricter than 0.5 of a normalised average. An argmax over a normalised
    # sum would be the usual alternative - confirm which is intended.
    mask = process_multimask2np(probs.round())

    # The collated sizes are tensors; convert to plain ints before handing to PIL.
    orig_size = (int(batch['orig_w'][0]), int(batch['orig_h'][0]))
    # NEAREST keeps label values intact when resizing back to the original size.
    predict = Image.fromarray(mask).resize(orig_size, Image.NEAREST)
    img_path = os.path.join(GLOBAL_CONFIG["predict_folder"], batch["image_name"][0])
    predict.save(img_path)

100%|██████████| 1000/1000 [12:25<00:00,  1.34it/s]
