In [None]:
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
from torch.utils.data import Dataset
import json
import matplotlib.pyplot as plt
import glob
from torchvision.utils import draw_segmentation_masks
import torch
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import segmentation_models_pytorch as smp
import os

In [None]:
# Indices of bad images to exclude (EyeDataset skips these positions in its sorted list of PNG files)
bad_ims = [6,12,25,27,29,30,32,44,49,53,56,66,67,68,76,77,78,79,88,92,93,94,98,101,105,106,109,112,113,114,122,124,136,143,144,146,150,
151,154,157,159,160,161,163,166,168,169,171,175,179,180,219,227,229,235,252,253,255,256,257,264,265,272,275,282,283,290,294,310,
312,318,335,341,346,355,357,358,359,360,362,369,375,381,388,390,394,396,402,407,408,411,413,414,416,420,427,429,430,431,436,437,
438,441,442,446,447,449,453,454,455,457,458,460,461,466,467,468,476,480,485,486,487,493,496,499,502,504,505,508,509,510,512,515,
516,518,526,528,531,532,535,537,539,541,544,546,553,555,556,559,565,567,568,569,574,578,579,586,589,591,595,598,601,603,604,607,
619,622,625,629,640,645,648,
1,4,7,8,9,11,13,14,16,18,19,20,21,23,26,28,31,33,34,35,36,41,47,48,51,54,57,59,61,62,63,64,69,70,71,72,73,74,80,84,85,89,90,91,
           96,97,104,115,119,121,125,126,129,132,134,135,139,153,156,158,162,164,173,176,177,181,182,183,185,186,187,189,192,
           194,196,198,199,204,209,213,215,222,226,231,233,234,236,237,239,240,243,244,250,251,254,261,262,266,270,276,284,287,
           291,292,293,297,301,309,311,313,314,316,323,326,328,329,330,331,333,338,340,342,344,345,349,350,352,353,363,365,366,
           367,368,370,371,372,374,380,384,385,399,401,404,406,410,415,417,421,424,428,434,452,463,469,475,477,482,492,494,497,
           501,519,529,540,543,549,575,580,585,587,590,594,597,600,611,612,613,615,620,624,628,630,635,638,641,643,644,646,650]

# Парсер данных

In [None]:
class EyeDataset(Dataset):
    """Dataset of PNG images paired with GeoJSON polygon annotations.

    Every ``*.png`` file directly under ``path`` is expected to have a sibling
    file with the same base name and a ``.geojson`` extension.  Each item is a
    dict with:

    * ``'image'`` - sharpened RGB image as float32, pixel values divided by 255;
    * ``'mask'``  - float32 array of shape ``(H, W, 2)``; channel 1 is the
      rasterised annotation, channel 0 is its complement.
    """

    def __init__(self, path, bad_ims):
        """Collect the image files, skipping the positions listed in ``bad_ims``.

        Args:
            path: Directory searched (non-recursively) for ``*.png`` files.
            bad_ims: Iterable of positions in the lexicographically sorted
                file list that must be left out.
        """
        excluded = set(bad_ims)  # constant-time membership test
        img_files = sorted(glob.glob(f"{path}/*.png"))
        self.img_files = [
            file_name for i, file_name in enumerate(img_files) if i not in excluded
        ]

    @staticmethod
    def _geojson_path(image_path):
        """Return the annotation path: same base name, ``.geojson`` extension.

        Only the file extension is swapped, so a "png" substring elsewhere in
        the path (e.g. in a directory name) is left alone.
        """
        return os.path.splitext(image_path)[0] + ".geojson"

    def fill_polig(self, polig, image_size):
        """Rasterise one polygon (a list of coordinate rings) into a 0/1 mask.

        Every ring in ``polig`` is filled with 1; rings are not treated as
        holes.  Coordinates are converted with ``np.int32``.
        """
        mask = np.zeros(image_size)
        for ring in polig:
            # Rings may differ in length, so each one is drawn separately.
            cv2.fillPoly(mask, np.int32([ring]), 1)
        return mask

    def fill_mask(self, features, image_size):
        """Rasterise all features into a single binary mask.

        Args:
            features: Iterable of GeoJSON-like dicts with a ``'geometry'`` entry
                holding ``'type'`` and ``'coordinates'``.
            image_size: ``(height, width)`` of the mask to create.

        Returns:
            Array of ``image_size`` with 1 inside any polygon and 0 elsewhere.
        """
        mask = np.zeros(image_size)
        for feature in features:
            geometry = feature['geometry']
            if geometry['type'] == 'MultiPolygon':
                polygons = geometry['coordinates']
            else:
                polygons = [geometry['coordinates']]
            for polig in polygons:
                # Union instead of sum: overlapping polygons must not push the
                # mask above 1 (which would make the background channel negative).
                mask = np.maximum(mask, self.fill_polig(polig, image_size))
        return mask

    def cv2unsharp(self, image):
        """Sharpen ``image``: ``7 * image - 6 * GaussianBlur(image, sigma=2)``."""
        gaussian_3 = cv2.GaussianBlur(image, (0, 0), 2.0)
        unsharp_image = cv2.addWeighted(image, 7.0, gaussian_3, -6.0, 0)
        return unsharp_image

    def __getitem__(self, idx):
        """Load image ``idx`` together with its two-channel mask.

        Raises:
            OSError: If the image file cannot be read by OpenCV.
        """
        image_path = self.img_files[idx]
        with open(self._geojson_path(image_path)) as f:
            json_contents = json.load(f)

        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.cv2unsharp(image)
        image = np.array(image / 255, dtype=np.float32)
        image_size = image.shape[:2]

        # The annotation may be a FeatureCollection, a bare list of features,
        # or a single feature.
        if isinstance(json_contents, dict) and json_contents['type'] == 'FeatureCollection':
            features = json_contents['features']
        elif isinstance(json_contents, list):
            features = json_contents
        else:
            features = [json_contents]

        mask = self.fill_mask(features, image_size)
        background = np.ones(image_size) - mask
        return {'image': image, 'mask': np.float32(np.stack([background, mask], axis=-1))}

    def __len__(self):
        """Number of images kept after filtering."""
        return len(self.img_files)

In [None]:
# Folder holding the PNG images and their GeoJSON annotations.
path = 'path/to/data'
dataset = EyeDataset(path=path, bad_ims=bad_ims)

In [None]:
# Augmentation
# Final step shared by all datasets: convert the sample to torch tensors.
to_tensor = ToTensorV2(transpose_mask=True)
transforms = A.Compose([to_tensor])

# Training-time augmentation: central 900x900 region, random 320x320 patch,
# then a vertical flip and a 90-degree rotation, each applied with p=0.5.
aug_transform = A.Compose(
    [
        A.CenterCrop(height=900, width=900, p=1),
        A.RandomCrop(height=320, width=320, p=1),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
    ]
)

# Создание даталоадера для обучения

In [None]:
class Dataseter(Dataset):
    """View over a subset of ``dataset`` that augments each sample on access.

    Every selected sample is exposed ``mult`` times in a row; because the
    augmentation is applied on each access, repeated entries can differ.
    """

    def __init__(self, dataset, indices, transform, aug_transform, mult):
        """Store the wrapped dataset and the processing callables.

        Args:
            dataset: Indexable returning dicts with ``'image'`` and ``'mask'``.
            indices: Positions in ``dataset`` that belong to this subset.
            transform: Optional callable applied after augmentation, called as
                ``transform(image=..., mask=...)``; ``None`` to skip it.
            aug_transform: Callable invoked as
                ``aug_transform(image=..., mask=...)`` returning a dict.
            mult: How many times each selected sample is repeated.
        """
        self.dataset = dataset
        self.indices = indices
        self.transform = transform
        self.aug_transform = aug_transform
        self.mult = mult

    def __getitem__(self, idx):
        """Return the augmented (and optionally transformed) sample ``idx``."""
        # Consecutive blocks of `mult` positions map onto the same source sample.
        source = self.dataset[self.indices[int(idx // self.mult)]]

        sample = self.aug_transform(image=source['image'], mask=source['mask'])
        if self.transform is not None:
            sample = self.transform(**sample)

        # Without a transform the augmented sample is returned, not the raw one.
        return sample

    def __len__(self):
        """Number of selected samples multiplied by ``mult``."""
        return len(self.indices) * self.mult

In [None]:
# 75/25 split over dataset positions. The seed is fixed so that
# Restart & Run All always produces the same train/validation partition.
train_indices, test_indices = train_test_split(range(len(dataset)), test_size=0.25, random_state=42)

#datasets
train_dataset = Dataseter(dataset, train_indices, transform=transforms, aug_transform=aug_transform, mult=2)
# NOTE(review): `ValidDataseter` is not defined in the visible part of this
# notebook; on a fresh kernel this line raises NameError unless the class is
# defined or imported in an earlier cell - confirm and restore its definition.
valid_dataset = ValidDataseter(dataset, test_indices, transform=transforms)

In [None]:
# Data loaders: one sample per batch, order reshuffled on every pass.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=1,
    shuffle=True,
    drop_last=True,
)

valid_loader = torch.utils.data.DataLoader(
    dataset=valid_dataset,
    batch_size=1,
    shuffle=True,
    drop_last=True,
)

# Проверка того, как это работает

In [None]:
# Visual sanity check: overlay the annotation channel on one training sample.
batch = next(iter(train_loader))
image = batch['image']
mask = batch['mask']

# The image is rescaled to 0..255 uint8 and mask channel 1 is cast to bool
# before being handed to draw_segmentation_masks.
image_with_mask = draw_segmentation_masks((image[0] * 255).type(torch.uint8), (mask[0][1]).type(torch.bool))
# Channels-first -> channels-last for matplotlib. permute(1, 2, 0) keeps the
# orientation; permute(2, 1, 0) would additionally swap height and width.
plt.imshow(image_with_mask.permute(1, 2, 0))
plt.show()

# Модель и обучение

In [None]:
# Compute device: CUDA when it is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

In [None]:
# U-Net with a ResNet-50 encoder and two output classes.
model = smp.Unet('resnet50', activation='softmax', classes=2)
model = model.to(device)

# map_location remaps the stored tensors onto the active device, so a
# checkpoint saved on a GPU can also be restored on a CPU-only machine.
model.load_state_dict(torch.load('path/to/model.pth', map_location=device))

In [None]:
class SoftDice:
    """Soft Dice loss: ``1 - sum(2 * pred * target) / (sum(pred + target) + eps)``."""

    def __init__(self, eps=1e-8):
        """Keep ``eps``, the constant added to the denominator."""
        self.eps = eps

    def __call__(self, pred, target):
        """Return the loss for ``pred`` and ``target`` as a scalar tensor."""
        overlap = (2 * pred * target).sum()
        total = (pred + target).sum()
        dice = overlap / (total + self.eps)
        return 1 - dice

In [None]:
# Training objective and optimiser (Adam, learning rate 1e-4).
loss = SoftDice()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

In [None]:
num_epochs = 30
for epoch in range(num_epochs):
    # A later cell calls model.eval(); switch back explicitly so that
    # re-running this cell never trains the network in evaluation mode.
    model.train()
    losses = []
    for batch in tqdm(train_loader):
        optimizer.zero_grad()
        image = batch['image'].to(device)
        mask = batch['mask'].to(device)
        pred = model(image)

        # Dice on channel 1 over the whole batch (identical to indexing
        # element 0 when the batch size is 1, as configured for train_loader).
        batch_loss = loss(pred[:, 1], mask[:, 1])
        batch_loss.backward()

        optimizer.step()
        losses.append(batch_loss.item())

    # Checkpoint after every epoch; the file is overwritten each time.
    torch.save(model.state_dict(), 'model.pth')
    print(f'epoch: {epoch}, loss: {np.mean(losses)}')

In [None]:
# Put the model into evaluation mode before the inference cells below.
model.eval()

In [None]:
class TestDataseter(Dataset):
    """Inference dataset: every entry of a directory is read as an image.

    Items are dicts with the transformed image under ``'image'`` and the
    original file name under ``'idx'``.
    """

    def __init__(self, path, transforms):
        """Remember the directory, its (sorted) entries and the transform.

        Args:
            path: Directory whose entries are all treated as image files.
            transforms: Callable invoked as ``transforms(image=...)`` that
                returns a dict containing ``'image'``.
        """
        self.path = path
        # os.listdir returns entries in arbitrary order; sort for determinism.
        self.images = sorted(os.listdir(path))
        self.transform = transforms

    def cv2unsharp(self, image):
        """Sharpen ``image``: ``7 * image - 6 * GaussianBlur(image, sigma=2)``."""
        gaussian_3 = cv2.GaussianBlur(image, (0, 0), 2.0)
        unsharp_image = cv2.addWeighted(image, 7.0, gaussian_3, -6.0, 0)
        return unsharp_image

    def __getitem__(self, idx):
        """Load, sharpen, rescale and transform image number ``idx``.

        Raises:
            OSError: If the file cannot be read by OpenCV.
        """
        image_path = os.path.join(self.path, self.images[idx])
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f"Could not read image: {image_path}")
        image = self.cv2unsharp(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        image = np.array(image / 255, dtype=np.float32)

        # Only the image is passed on: there is no mask at inference time, and
        # a placeholder mask of a different size fails the image/mask shape
        # validation performed by albumentations' Compose.
        return {'image': self.transform(image=image)['image'], 'idx': self.images[idx]}

    def __len__(self):
        """Number of entries found in the directory."""
        return len(self.images)
    
def _window_starts(size, window_size, stride):
    """Offsets along one axis such that the windows reach the far border."""
    last = size - window_size
    if last <= 0:
        # Axis not larger than the window: a single window at the origin.
        return [0]
    starts = list(range(0, last + 1, stride))
    if starts[-1] != last:
        # Extra window flush with the border so no strip is left unpredicted.
        starts.append(last)
    return starts


def sliding_window(x, window_size=320, stride=50, net=None):
    """Accumulate per-pixel class votes from overlapping window predictions.

    The network is run on every ``window_size`` x ``window_size`` crop of
    ``x``; in each crop a pixel votes for class 1 when the score of channel 1
    exceeds that of channel 0, otherwise for class 0.  Votes of overlapping
    windows are summed.

    Args:
        x: Tensor indexed as ``(batch, channels, height, width)``; only the
            first element of the batch is evaluated.
        window_size: Side length of the square window.
        stride: Step between consecutive windows.
        net: Callable producing per-class scores for a crop.  Defaults to the
            notebook-level ``model``.

    Returns:
        Float tensor of shape ``(2, height, width)`` on the device of ``x``
        holding the vote counts for class 0 and class 1.
    """
    net = model if net is None else net
    h = x.shape[2]
    w = x.shape[3]
    final = torch.zeros((2, h, w), device=x.device)

    # Inference only: do not build autograd graphs for every crop.
    with torch.no_grad():
        for i in _window_starts(h, window_size, stride):
            for j in _window_starts(w, window_size, stride):
                crop = x[:, :, i:i+window_size, j:j+window_size]
                scores = net(crop)

                foreground = (scores[0][0] < scores[0][1]).to(final.dtype)
                votes = torch.stack((1 - foreground, foreground))

                final[:, i:i+window_size, j:j+window_size] += votes
    return final

In [None]:
# Inference data: images are served one at a time in a fixed order.
test_dataset = TestDataseter(path='path/to/test_dataset', transforms=transforms)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=1,
    shuffle=False,
    drop_last=True,
)

In [None]:
save_dir = 'path/to/save/images'
# cv2.imwrite only returns False when the target folder is missing, so make
# sure it exists before the loop starts.
os.makedirs(save_dir, exist_ok=True)

# Inference only: no gradients are needed.
with torch.no_grad():
    for batch in tqdm(test_loader):
        image = batch['image']
        idx = batch['idx'][0]
        a = sliding_window(image.to(device))

        # Class-1 vote counts; a pixel is kept when it received more votes
        # than the average over all pixels that received at least one vote.
        votes = a[1]
        voted_pixels = torch.count_nonzero(votes).clamp(min=1)  # avoid 0/0
        mean = torch.sum(votes) / voted_pixels
        pred_mask = votes > mean

        # 8-bit image with three identical channels: 255 = predicted class 1.
        im = (pred_mask.to(torch.uint8) * 255).unsqueeze(-1).repeat(1, 1, 3).cpu().numpy()
        out_path = os.path.join(save_dir, idx)
        if not cv2.imwrite(out_path, im):
            raise OSError(f"Could not write prediction to {out_path}")