# 🛠 Install Libraries

## For PC

In [54]:
#!pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
#!pip install --user numpy 
#!pip install --user pandas 
#!pip install  segmentation-models-pytorch
# !python -m pip install opencv-python
# !pip install tensorflow
# !pip install -q scikit-learn==1.0
#!pip install plotly
# !pip install --user albumentations
# import sys  
# !{sys.executable} -m pip install --user matplotlib
#!pip install ipywidgets --user
#!pip install -U albumentations[imgaug]

## For Kaggle

In [55]:
!pip install  segmentation-models-pytorch

# 📚 Import Libraries  


In [56]:
%load_ext autoreload
%autoreload 2

In [57]:
import numpy as np
import pandas as pd
pd.options.plotting.backend = "plotly"
import segmentation_models_pytorch as smp
import random
from glob import glob
import os, shutil
from tqdm import tqdm
tqdm.pandas()
import time
import copy
#import joblib
#from collections import defaultdict
from IPython import display as ipd
from PIL import Image
# visualization
import cv2
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

# Sklearn
import sklearn
from sklearn.model_selection import train_test_split

# PyTorch 
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torch.cuda import amp

import timm

# Albumentations for augmentations
import albumentations as A

# For colored terminal text
from colorama import Fore, Back, Style
c_  = Fore.GREEN
sr_ = Style.RESET_ALL

import warnings
warnings.filterwarnings("ignore")

# For descriptive error messages
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

# gc
import gc

## Versions

In [58]:
import platform

# Record library versions so results can be tied to a specific environment.
# (Previously the torch line printed e.g. "Torch version1.11.0" with no separator.)
print(f"Torch version: {torch.__version__}")
print(f"The scikit-learn version is {sklearn.__version__}.")
print(f"Python version: {platform.python_version()}")


# ⚙️ Configuration 

In [59]:
class CFG:
    """Global settings for this notebook, used as a plain attribute namespace.

    Values are read as ``CFG.<name>``; in the visible code the class is never
    instantiated. Several attributes are computed from attributes defined above
    them (``number_imgs``/``num_test``/``print_every``/``epochs`` from ``DEBUG``,
    ``num_workers`` and every path from ``Kaggle``, ``T_max`` and
    ``n_accumulate`` from ``train_bs``/``epochs``), so definition order matters.
    ``build_model`` later assigns ``CFG.backbone`` and ``CFG.models`` at runtime.
    """
    # --- run-mode switches ---
    # True: inference-only run; the model-loading cell builds `preprocessing_fn` only when this is set.
    JUST_PREDICT  = True
    # Selects Kaggle vs local-PC paths and the DataLoader worker count below.
    Kaggle        = True 
    # Shrinks image counts, logging interval and epoch count for quick runs.
    DEBUG         = False
    # Weights & Biases logging toggle (not read in the visible code).
    wandb_on      = False
    # Passed to set_seed().
    seed          = 101
    # Presumably enables the multi-model ensemble (not read in the visible code).
    MULTIMODEL    = True
    #exp_name      = 'Baselinev2'
    #comment       = 'unet-efficientnet_b1-224x224-aug2-split2'
    # --- models ---
    # Display names for models 1..3 ("u-" presumably stands for Unet, the architecture build_model uses).
    model_name_1    = 'u-efficientnet-b1'
    model_name_2    = 'u-efficientnet-b2'
    model_name_3    = 'u-timm-mobilenetv3_small_minimal_100'
    # Pre-training weights name passed to smp.encoders.get_preprocessing_fn.
    weights       = 'imagenet'
    # Encoder names for models 1..3; they match the encoders hard-coded in build_model.
    backbone_1    = 'efficientnet-b1'
    backbone_2    = 'efficientnet-b2' 
    backbone_3    = 'timm-mobilenetv3_small_minimal_100'
    # Filled at runtime by build_model / load_model.
    models        = []
    # Not read in the visible code.
    optimizers    = []
################################################### 
    # Not read in the visible code.
    num_of_models = 1
    model_number  = 8
    # Batch sizes for the train / validation loaders.
    train_bs      = 12
    valid_bs      = 12
    # Cap on the number of training images used.
    number_imgs   = 100 if DEBUG else 8203
    # Cap on the number of test images to predict.
    num_test      = 10 if DEBUG else 1000
    # Logging interval (not read in the visible code); previously 500.
    print_every   = 8  if DEBUG else 100
    # Not read in the visible code; alternative tried: [540, 960].
    img_size      = [256, 256]
    # Resize target used by the "valid" transform.
    start_width   = 512
    start_height  = 512
    # Not read in the visible code.
    final_width   = 512
    final_height  = 512
    # Training epochs; previously 35.
    epochs        = 4  if DEBUG else 28
    ###############################################
    # Not read in the visible code.
    crop_koef     = 1
    # --- optimisation ---
    lr            = 2e-3
    # DataLoader workers: 4 on Kaggle, 0 (load in the main process) locally.
    num_workers   = 4 if Kaggle else 0
    scheduler     = 'CosineAnnealingLR'
    min_lr        = 1e-6
    # Cosine-annealing period in scheduler steps.
    # NOTE(review): 30000 looks like an assumed training-set size (number_imgs is 8203) — confirm.
    T_max         = int(30000/train_bs*epochs)+50
    # Presumably the restart period for a warm-restart scheduler; unused with 'CosineAnnealingLR'.
    T_0           = 25
    warmup_epochs = 0
    # Weight decay; alternative tried: 1e-6.
    wd            = 0
    # Gradient-accumulation steps, presumably targeting an effective batch of ~32 samples.
    n_accumulate  = max(1, 32//train_bs)
    n_fold        = 5
    # --- segmentation targets ---
    # Number of output channels of the model.
    num_classes   = 4
    # Raw mask pixel values; load_msk one-hot encodes them into channels in this order.
    classes       = [0,6,7,10]
    # Passed to smp.Unet; None means no output activation. Alternative: 'softmax'.
    activation    = None
    device        = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # --- paths (Kaggle dataset vs local PC) ---
    # Local alternative: "../Цифровой прорыв 2022_Лето\train\images" (folder name = "Digital Breakthrough 2022 Summer").
    images_path   = "../input/russian-railways-2/images/images/" if Kaggle else "./train/images/"
    masks_path    = "../input/russian-railways-2/mask/mask/" if Kaggle else  "./train/mask/"
    test_path     = "../input/russian-railways-2/test/test/" if Kaggle else "./test/"
    save_path     = '../working/result/' if Kaggle else "./result/"
    # Trained weights for models 1..3.
    # NOTE(review): the local paths for models 1 and 2 point to last_epoch_* files, while model 3
    # and all Kaggle paths use best_epoch_* — confirm this is intended.
    best_model_w_1= '../input/russian-railways-2/best_epoch_ofu-efficientnet-b1_v2.bin' if Kaggle else './last_epoch_ofu-efficientnet-b1_v2.bin'
    best_model_w_2= '../input/russian-railways-2/best_epoch_ofu-efficientnet-b2_v2.bin' if Kaggle else './last_epoch_ofu-efficientnet-b2_v2.bin'
    best_model_w_3= '../input/russian-railways-2/best_epoch_ofu-timm-mobilenetv3_small_minimal_100_v2.bin' if Kaggle else './best_epoch_ofu-timm-mobilenetv3_small_minimal_100_v2.bin'
    # Unused slots for additional models.
    best_model_w_4= None
    best_model_w_5= None
    best_model_w_6= None
    
#'../input/russian-railways-2/best_epoch.bin'

In [60]:
# Show the device the rest of the notebook will actually use. Reading CFG.device
# (instead of re-deriving it) keeps this output in sync if the config is changed.
CFG.device

# ❗ Reproducibility

In [61]:
def set_seed(seed = 42):
    '''Seed every random-number source the notebook relies on, for reproducibility.

    Seeds NumPy's global RNG, Python's ``random`` module and PyTorch (CPU and
    the current CUDA device), and forces cuDNN to use deterministic,
    non-benchmarked algorithms. It also exports ``PYTHONHASHSEED``. That variable
    only affects Python processes started afterwards; the running interpreter's
    hash randomisation was fixed at start-up.
    '''
    for seeder in (np.random.seed, random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    # cuDNN otherwise may pick non-deterministic / auto-tuned kernels.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)
    print('> SEEDING DONE')
    
set_seed(seed=CFG.seed)  # seed everything with the configured value

# 📈 Visualization

In [62]:
def visualize(**images):
    """Show every keyword-argument image side by side in a single row.

    Each keyword becomes its panel title: underscores are turned into spaces
    and the result is title-cased (``gt_mask`` -> ``Gt Mask``). Axis ticks are
    hidden, the figure is 16x5 inches, and it is shown immediately.
    """
    n_panels = len(images)
    plt.figure(figsize=(16, 5))
    for position, name in enumerate(images, start=1):
        plt.subplot(1, n_panels, position)
        plt.xticks([])
        plt.yticks([])
        plt.title(name.replace('_', ' ').title())
        plt.imshow(images[name])
    plt.show()

In [63]:
gc.collect() # Force a garbage-collection pass; the returned number of unreachable objects found is shown as the cell output.

# 📦 Model


In [64]:

import segmentation_models_pytorch as smp

# def build_models(number):
#     if number == 1:
#         model_Unet = smp.Unet(
#             encoder_name=CFG.backbone,      # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
#             encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
#             in_channels=3,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
#             classes=CFG.num_classes,        # model output channels (number of classes in your dataset)
#             activation=CFG.activation)
#         model_Unet.to(CFG.device)
#         CFG.models = [model_Unet]
#     else:                  
#         model_Unet = smp.Unet(
#                 encoder_name=CFG.backbone,      # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
#                 encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
#                 in_channels=3,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
#                 classes=CFG.num_classes,        # model output channels (number of classes in your dataset)
#                 activation=CFG.activation )
#         model_Unet.to(CFG.device)
#         CFG.models = [model_Unet] 

#         model_UnetPP = smp.UnetPlusPlus(
#             encoder_name='timm-efficientnet-b7',
#             encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
#             in_channels=3,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
#             classes=CFG.num_classes,        # model output channels (number of classes in your dataset)
#             activation=CFG.activation)
#         model_UnetPP.to(CFG.device)  

#         model_inceptionresnetv2 = smp.Unet(
#             encoder_name='timm-res2net50_26w_4s',      # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
#             encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
#             in_channels=3,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
#             classes=CFG.num_classes,        # model output channels (number of classes in your dataset)
#             activation=CFG.activation )
#         model_inceptionresnetv2.to(CFG.device)   

#         model_Deep_lab = smp.UnetPlusPlus(
#             encoder_name ='tu-gluon_xception65',      # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
#             encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
#             in_channels=3,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
#             classes=CFG.num_classes,        # model output channels (number of classes in your dataset)
#             activation=CFG.activation)
#         model_Deep_lab.to(CFG.device) 


#         model_dl2 = smp.UnetPlusPlus( 
#             encoder_name ='timm-efficientnet-b2',      # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
#             encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
#             in_channels=3,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
#             classes=CFG.num_classes,        # model output channels (number of classes in your dataset)
#             activation=CFG.activation)
#         model_dl2.to(CFG.device)   

#         model_pan = smp.Unet( 
#             encoder_name ='timm-efficientnet-b2',      # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
#             encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
#             in_channels=3,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
#             classes=CFG.num_classes,        # model output channels (number of classes in your dataset)
#             activation=CFG.activation)
#         model_pan.to(CFG.device)    


#         model_pan2 = smp.UnetPlusPlus( 
#             encoder_name ='timm-regnetx_064',      # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
#             encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
#             in_channels=3,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
#             classes=CFG.num_classes,        # model output channels (number of classes in your dataset)
#             activation=CFG.activation)
#         model_pan2.to(CFG.device) 

#         model_mobile = smp.Unet(
#             encoder_name='mobilenet_v2',      # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
#             encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
#             in_channels=3,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
#             classes=CFG.num_classes,        # model output channels (number of classes in your dataset)
#             activation=CFG.activation)
#         model_mobile.to(CFG.device)

#         CFG.models = [model_Unet,model_UnetPP,model_inceptionresnetv2, model_Deep_lab, model_dl2, model_pan, model_pan2, model_mobile]
#     return CFG.models

    
##################################################################################################################################################################    
    
##################################################################################################################################################################    
def build_model(indx, encoder_weights="imagenet"):
    """Build Unet model number ``indx`` and move it to ``CFG.device``.

    Index -> encoder mapping (every model is an ``smp.Unet`` with 3 input
    channels, ``CFG.num_classes`` outputs and ``CFG.activation``):
        1 -> 'efficientnet-b1'
        2 -> 'efficientnet-b2'
        3 -> 'timm-mobilenetv3_small_minimal_100'

    Side effects (kept for existing callers):
        * ``CFG.backbone`` is set to the chosen encoder name.
        * ``CFG.models`` is *replaced* by ``[model]``.

    Args:
        indx: model index (1, 2 or 3).
        encoder_weights: pre-trained encoder weights passed to ``smp.Unet``
            (default ``"imagenet"``, as before). Pass ``None`` to skip the
            download when a full checkpoint is loaded afterwards anyway.

    Returns:
        The newly built model, already on ``CFG.device``.

    Raises:
        ValueError: if ``indx`` is not a known model index (previously this
            surfaced as an ``UnboundLocalError``).
    """
    encoders = {
        1: 'efficientnet-b1',  # ~7.7 million (presumably parameters)
        2: 'efficientnet-b2',
        3: 'timm-mobilenetv3_small_minimal_100',
    }
    if indx not in encoders:
        raise ValueError(f"Unknown model index {indx!r}; expected one of {sorted(encoders)}")

    encoder_name = encoders[indx]
    model = smp.Unet(
        encoder_name=encoder_name,
        encoder_weights=encoder_weights,
        in_channels=3,                  # RGB input
        classes=CFG.num_classes,        # one output channel per class
        activation=CFG.activation)
    CFG.backbone = encoder_name

    model.to(CFG.device)
    CFG.models = [model]
    return model


def load_models(pash):
    """Load several trained models at once.

    The previous version iterated over ``CFG.models``. It passed each model
    *object* to ``build_model``, which expects an integer index, so it always
    crashed. It also ignored its argument and discarded the rebuilt models.

    Args:
        pash: sequence of weight paths indexed by model number, laid out like
            ``best_model_w`` (index 0 is a placeholder). Falsy entries such as
            ``0`` or ``None`` are skipped. The parameter name is kept as-is for
            backward compatibility.

    Returns:
        list of the loaded models, in index order.
    """
    loaded = []
    for indx, path in enumerate(pash):
        if not path:
            continue
        loaded.append(load_model(path, indx))
    return loaded


def load_model(path,indx):
    """Build model number ``indx``, load its weights from ``path`` and register it.

    The checkpoint is read onto the CPU first (``map_location``), so a GPU-saved
    file also loads on CPU-only machines. The model is then switched to eval
    mode and moved to ``CFG.device``.

    Returns:
        The loaded model.

    NOTE(review): ``build_model`` replaces ``CFG.models`` with ``[model]``, so
    after this call ``CFG.models`` holds only the most recently loaded model. If
    an ensemble across all loaded models was intended, collect the returned
    models instead.
    """
    model = build_model(indx)
    state_dict = torch.load(path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()
    model.to(CFG.device)
    # build_model already put this model in CFG.models; appending it
    # unconditionally (as before) left the same model in the list twice.
    if not any(registered is model for registered in CFG.models):
        CFG.models.append(model)
    return model

## Loading models

In [65]:
# Per-model lookup tables indexed by model number: entry i belongs to
# build_model(i). Index 0 is a placeholder so the indices line up.
preprocessing_fn = []

if CFG.JUST_PREDICT:
    preprocessing_fn.append(None)
    # Encoder-specific input normalisation for models 1..3.
    preprocessing_fn.extend(
        smp.encoders.get_preprocessing_fn(backbone, CFG.weights)
        for backbone in (CFG.backbone_1, CFG.backbone_2, CFG.backbone_3)
    )

# Names and weight paths come from CFG, so they cannot drift from the configuration.
model_name = ['Nothing', CFG.model_name_1, CFG.model_name_2, CFG.model_name_3]
best_model_w = [0, CFG.best_model_w_1, CFG.best_model_w_2, CFG.best_model_w_3]

#PATH = f"best_epoch.bin"
#torch.save(model.state_dict(), PATH)

In [66]:


def load_img(path):
    """Read the image at ``path`` and return it as an RGB array (H, W, 3).

    OpenCV loads images as BGR, so the channels are reordered to RGB.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file. ``cv2.imread``
            returns ``None`` instead of raising, which otherwise surfaces as an
            opaque ``cv2.error`` inside ``cvtColor``.
    """
    image = cv2.imread(path)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

def load_msk(path):
    """Read a grayscale label mask and one-hot encode it over ``CFG.classes``.

    Returns:
        float64 array of shape (H, W, len(CFG.classes)). Channel ``i`` is 1.0
        where the mask pixel equals ``CFG.classes[i]`` and 0.0 elsewhere;
        pixels whose value is not listed in ``CFG.classes`` are 0 in every channel.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file (``cv2.imread``
            returns ``None`` rather than raising).
    """
    raw = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if raw is None:
        raise FileNotFoundError(f"Could not read mask: {path}")
    masks = [(raw == v) for v in CFG.classes]
    return np.stack(masks, axis=-1).astype('float')
        
    
def to_tensor(x, **kwargs):
    """Reorder an HWC array to CHW and cast it to float32 (PyTorch's layout).

    Extra keyword arguments are accepted and ignored, so the function can be
    used directly as an albumentations ``Lambda`` target.
    """
    channels_first = x.transpose(2, 0, 1)
    return channels_first.astype('float32')


def get_preprocessing(preprocessing_fn):
    """Build the image+mask preprocessing pipeline used for training/validation.

    ``preprocessing_fn`` (e.g. from ``smp.encoders.get_preprocessing_fn``) is
    applied to the image only. Both image and mask are then converted from HWC
    to CHW float32 with ``to_tensor``.
    """
    normalize = A.Lambda(image=preprocessing_fn)
    channels_first = A.Lambda(image=to_tensor, mask=to_tensor)
    return A.Compose([normalize, channels_first])

def get_preprocessing_test(preprocessing_fn):
    """Build the image-only preprocessing pipeline used at prediction time.

    Applies ``preprocessing_fn`` to the image, then converts it from HWC to
    CHW float32 with ``to_tensor``. No mask target is involved.
    """
    normalize = A.Lambda(image=preprocessing_fn)
    channels_first = A.Lambda(image=to_tensor)
    return A.Compose([normalize, channels_first])

# 🍚 Dataset class


In [67]:
class BuildDataset(torch.utils.data.Dataset):
    """Read images (and, in training mode, masks) and apply transforms/preprocessing.

    Args:
        images_paths: sequence of image file paths.
        masks_paths: sequence of mask file paths aligned with ``images_paths``;
            required when ``label`` is True.
        label: True -> ``__getitem__`` returns ``(image, mask)`` (training /
            validation). False -> it returns only the image (prediction).
        transforms: albumentations pipeline (e.g. resize/flip) applied first.
        preprocessing: pipeline applied to image *and* mask when ``label`` is True.
        preprocessing_img: image-only pipeline applied when ``label`` is False.
            If it is None, ``preprocessing`` is used in its place.
    """

    def __init__(self, images_paths, masks_paths=None, label=True, transforms=None, preprocessing=None, preprocessing_img=None):
        self.label = label
        self.img_paths = images_paths
        self.msk_paths = masks_paths
        self.transforms = transforms
        self.preprocessing = preprocessing
        self.preprocessing_img = preprocessing_img

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, index):
        img = load_img(self.img_paths[index])

        if self.label:  # training / validation: return (image, mask)
            msk = load_msk(self.msk_paths[index])
            if self.transforms is not None:
                data = self.transforms(image=img, mask=msk)
                img, msk = data['image'], data['mask']
            if self.preprocessing is not None:
                data = self.preprocessing(image=img, mask=msk)
                img, msk = data['image'], data['mask']
            return img, msk

        # Prediction: image only.
        if self.transforms is not None:
            img = self.transforms(image=img)['image']
        # Previously this branch checked `self.preprocessing` but called
        # `self.preprocessing_img`. A dataset given only `preprocessing_img`
        # silently skipped it, and one given only `preprocessing` crashed.
        prep = self.preprocessing_img if self.preprocessing_img is not None else self.preprocessing
        if prep is not None:
            img = prep(image=img)['image']
        return img

# 🍰 DataLoader

In [68]:
# def prepare_loaders():
    
#     img_names= [ os.path.join(CFG.images_path,img_name) for img_name in os.listdir(CFG.images_path)]
#     masks_names = [ os.path.join(CFG.masks_path,mask_name) for mask_name in os.listdir(CFG.masks_path)]
#     img_names = img_names[0:CFG.number_imgs]
#     masks_names=masks_names[0:CFG.number_imgs]
#     image_train, image_valid, mask_train, mask_valid = train_test_split(img_names, masks_names, test_size=0.2, random_state=CFG.seed)

    
    
#     train_dataset = BuildDataset(image_train, mask_train, transforms=data_transforms['train'],preprocessing=get_preprocessing(preprocessing_fn))
#     valid_dataset = BuildDataset(image_valid, mask_valid, transforms=data_transforms['valid'],preprocessing=get_preprocessing(preprocessing_fn))

#     train_loader = DataLoader(train_dataset, batch_size=CFG.train_bs, 
#                               num_workers=CFG.num_workers, shuffle=True, pin_memory=True, drop_last=False)
#     valid_loader = DataLoader(valid_dataset, batch_size=CFG.valid_bs, 
#                               num_workers=CFG.num_workers, shuffle=False, pin_memory=True)
    
#     return train_loader, valid_loader

In [69]:
# train_loader, valid_loader = prepare_loaders()

In [70]:
# imgs, msks = next(iter(train_loader))
# imgs.size(), msks.size()

In [71]:
# Release unoccupied cached GPU memory held by PyTorch's caching allocator.
torch.cuda.empty_cache()

In [72]:
# Only an inference ("valid") transform is needed here: resize every image to
# CFG.start_height x CFG.start_width.
# NOTE(review): nearest-neighbour interpolation on RGB images can alias badly when
# downscaling large frames; keep it only if training used the same setting — confirm.
data_transforms = dict(
    valid=A.Compose(
        [A.Resize(height=CFG.start_height, width=CFG.start_width, interpolation=cv2.INTER_NEAREST)],
        p=1.0,
    ),
)

# Create the results directory

In [73]:
import shutil, os

# Create the directory that receives the prediction results.
#  - Kaggle: create '../working/result' (an existing directory is kept, so re-runs succeed).
#  - PC: delete './result' left over from a previous run, then recreate it empty.
try:
    if CFG.Kaggle:
        os.makedirs('../working/result', exist_ok=True)
        print('KAGGLE DIR CREATED')
    else:
        # ignore_errors: './result' does not exist on the first run; without it rmtree raises
        # and the directory is never created.
        shutil.rmtree('./result', ignore_errors=True)
        os.makedirs('./result', exist_ok=True)
        print('PC DIR CREATED')
except Exception as err:  # best-effort, but report the reason instead of swallowing it
    print("DIR NOT CREATED")
    print(err)
    

# 🔭 Prediction

In [74]:
# Test images to predict. os.listdir returns entries in arbitrary order, so sort to make the
# selection (and the first CFG.num_test subset) deterministic.
img_names_test = sorted(os.path.join(CFG.test_path, img_name) for img_name in os.listdir(CFG.test_path))
img_names_test = img_names_test[0:CFG.num_test]
print(len(img_names_test))

# Original (width, height) of every test image, aligned index-for-index with img_names_test;
# predicted masks are resized back to these sizes. Reading them from the files (instead of a
# hard-coded list) keeps them correct for any file order or count. Image.open is lazy and only
# parses the header here, so this stays cheap even for 4K frames.
sizes = []
for image_file in img_names_test:
    with Image.open(image_file) as im:
        sizes.append(im.size)

save_path = CFG.save_path
preds = []


            

In [75]:
# load_model(best_model_w_1,1) 
# preprocessing_fn= smp.encoders.get_preprocessing_fn('efficientnet-b1', CFG.weights)
# test_dataset = BuildDataset(img_names_test,None, label=False, transforms=data_transforms['valid'],preprocessing=get_preprocessing(preprocessing_fn))
# test_loader  = DataLoader(test_dataset, batch_size=1,  num_workers=CFG.num_workers, shuffle=False, pin_memory=True)
# data_loader_iter = iter(test_loader)

## Color Check: multi-model inference, ensembling and mask export

In [76]:
###########################################################################################################################
###########################################################################################################################
################################################### MULTI MODEL SYSTEM ####################################################
###########################################################################################################################
###########################################################################################################################
# Predict every test image with each selected model, average the per-model sigmoid scores
# (ensemble), convert the result to a label mask, resize it back to the image's original size
# and save it under save_path.

# Class index (argmax over output channels) -> grey value written into the output mask.
# NOTE(review): the original legend comment read "train - 10 - CLASS 2 | main rail - 7 - CLASS 3"
# (i.e. 2 -> 10, 3 -> 7) but the executed code mapped 2 -> 7 and 3 -> 10. The executed mapping is
# kept here; confirm which one matches the training masks.
CLASS_TO_LABEL = {0: 0, 1: 6, 2: 7, 3: 10}


def _probs_to_mask_image(probs, height, width, size):
    """Turn per-class scores for ONE image into a resized PIL label mask.

    Args:
        probs: numpy array laid out as (batch, n_classes, ...) with batch == 1 (the test loader
            uses batch_size=1); the class is chosen by argmax over axis 1.
        height, width: network input size the argmax map is reshaped to.
        size: (width, height) of the original image, as expected by ``PIL.Image.resize``.

    Returns:
        A PIL image of uint8 labels, resized with nearest-neighbour so no new values appear.
    """
    class_ids = np.argmax(probs, axis=1)
    labels = class_ids.copy()
    # Masks come from the untouched argmax, so an already remapped value is never remapped again.
    for class_id, label in CLASS_TO_LABEL.items():
        labels[class_ids == class_id] = label
    mask = np.reshape(labels.astype(np.uint8), (height, width))
    return Image.fromarray(mask).resize(size, Image.NEAREST)


if CFG.MULTIMODEL:
    print("MULTI_MODEL_MODE")
    # Indices (into preprocessing_fn / model_name) of the models to ensemble. The averaging below
    # uses exactly these models, so extending this range needs no other change.
    model_indices = range(2, 3)
    n_imgs = len(img_names_test)
    # Start empty so re-running this cell does not append a second copy of every prediction.
    preds = []
    for indx in model_indices:
        test_dataset = BuildDataset(img_names_test, None, label=False,
                                    transforms=data_transforms['valid'],
                                    preprocessing=get_preprocessing(preprocessing_fn[indx]),
                                    preprocessing_img=get_preprocessing_test(preprocessing_fn[indx]))
        test_loader = DataLoader(test_dataset, batch_size=1, num_workers=CFG.num_workers,
                                 shuffle=False, pin_memory=True)

        print(model_name[indx])
        model_pred = load_model(f"../input/russian-railways-2/best_epoch_of{model_name[indx]}_v2.bin", indx)
        print(f'WE USE TRAINED MODEL № {indx}: {model_pred.name}!!!')
        model_pred = model_pred.to(CFG.device)
        model_pred.eval()  # inference behaviour for BatchNorm/Dropout; harmless if already in eval mode

        with torch.no_grad():
            for index, imgs in enumerate(test_loader):
                imgs = imgs.to(CFG.device, dtype=torch.float)
                # Store scores on the CPU: keeping every prediction on the GPU grows with the test set.
                pred = torch.sigmoid(model_pred(imgs)).cpu()
                preds.append(pred)
                if index % CFG.print_every == 0:
                    plt.imshow(_probs_to_mask_image(pred.numpy(), CFG.start_height, CFG.start_width, sizes[index]))
                    plt.show()

    print("ANSAMBLE OF MODELS")
    n_models = len(model_indices)
    for i in range(n_imgs):
        # preds is model-major: model j's prediction for image i is stored at j * n_imgs + i.
        per_model = [preds[j * n_imgs + i] for j in range(n_models)]
        mean = torch.stack(per_model, dim=0).mean(dim=0).numpy()
        img = _probs_to_mask_image(mean, CFG.start_height, CFG.start_width, sizes[i])
        if i % CFG.print_every == 0:
            plt.imshow(img)
            plt.show()
        # basename/join work whether or not save_path ends with a separator, on any OS.
        img.save(os.path.join(save_path, os.path.basename(img_names_test[i])))
gc.collect()
print("PREDICTONS DONE !")

In [77]:
# str = 'img_0.004817998680835878.png'
# path_2 = CFG.images_path + str
# path = CFG.masks_path + str
# msk = cv2.imread(path, cv2.IMREAD_GRAYSCALE) 
# msk = Image.fromarray(msk)
# imgplot=plt.imshow(msk)
# plt.show()
# #####################################################################
# image = cv2.imread(path_2)
# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# data = data_transforms['valid'](image=image)
# image  = data['image']
# data =  get_preprocessing_test(preprocessing_fn[3])(image=image)
# image = data['image'] 
# image = np.expand_dims(image, axis = 0)
# image = torch.tensor(image)
# image = image.to(CFG.device)
# pred = model_pred(image)
# pred.to("cpu",dtype=torch.float)
# pred = pred.cpu().detach().numpy()#("cpu")

# pred_arg_max = np.argmax(pred, axis = 1)
# # yellow - CLASS 2
# # yellow - train - 10 - CLASS 2 | more light main rail - 7 - CLASS 3 | SIDE RAIL - 6 - MORE DARK CLASS 1 \ BACKGROUND - 0 PINK CLASS 0
# pred_arg_max[pred_arg_max == 0 ] = 0
# pred_arg_max[pred_arg_max == 1 ] = 6
# pred_arg_max[pred_arg_max == 2 ] = 7
# pred_arg_max[pred_arg_max == 3 ] = 10
# res = np.array(pred_arg_max).astype(np.uint8)
# res = np.reshape(res, (CFG.start_width,CFG.start_height))#(CFG.img_size[0], CFG.img_size[1]))
# img = Image.fromarray(res)
# img = img.resize(sizes[index],Image.NEAREST)
# if index % CFG.print_every == 0:
#     imgplot=plt.imshow(img)
#     plt.show()

In [78]:
 print(len(preds))

In [79]:

# model_pred = load_model(f"../input/rails-hackaton/best_epoch.bin",1)
# # try:
# #     model_pred = load_model(f"best_epoch.bin",1)
# #     print(f'WE USE TRAINED MODEL № : !!!')
# # except Exception:
# #     model_pred = build_model(1)
# #     print(f'WE USE NEW MODEL 1  !!!')
# img_1 = load_img("../input/russian-railways-2/test/test/img_0.007661808580294749.png")
# img_2 = load_img('../input/russian-railways-2/test/test/img_0.027471593748947032.png')      
# data = data_transforms['train'](image=img_1, image2 = img_2)
# img_1  = data['image']
# img_2  = data['image2']
# img_1 = np.transpose(img_1, (2, 0, 1))   
# img_2 = np.transpose(img_2, (2, 0, 1)) 
# print(img_1)
# print('')
# print('')
# print(img_2)                  
# print(img_1.shape)
# #imgplot=plt.imshow(img_1)
# #plt.show()
# #imgplot=plt.imshow(img_2)
# #plt.show()
# pred_1 = model_pred(img_1)
# pred_2 = model_pred(img_2)
# print(pred_1)
# print(pred_2)
# # CORREKT - MIN MAX

In [80]:
# path = "../input/russian-railways-2/mask/mask/img_0.008320160156388479.png"
# msk = Image.open(path).convert("L")
# print(msk.size)
# imgplot=plt.imshow(msk)
# plt.show()

# ✂️ Archive Results / Remove Files

In [81]:
import shutil

# Zip the saved results into 'name.zip' in the current working directory.
# Archive the folder the results-directory cell creates for this environment: the Kaggle-only
# '../working/result/' is not created on PC, where './result/' is used instead.
result_dir = '../working/result/' if CFG.Kaggle else './result/'
shutil.make_archive('name', 'zip', result_dir)

#####################################
#######         DELETE !!!  #########
#####################################

# test_path = "../input/russian-railways-2/test/test/"
# img_names_test= [ os.path.join(test_path,img_name) for img_name in os.listdir(test_path)]
# for image_file in img_names_test:
#     try:
#         os.remove('./result'+ img_names_test[index].split("/")[-1])
#     except Exception:
#         pass
# print("DONE!")



#for image_file in img_names_test:
 #   os.remove("/kaggle/working/" + img_names_test[index].split("/")[-1])

In [82]:
if CFG.wandb_on:
    # Delete the local wandb run directory. shutil.rmtree instead of `!rm -r` also works on
    # Windows PC runs; ignore_errors keeps a missing directory non-fatal, as `rm -r` was.
    shutil.rmtree('./wandb', ignore_errors=True)