In [2]:
import os
os.chdir("..")

import numpy as np
import pandas as pd
from PIL import Image

from glob import glob

from lib import *

%matplotlib inline

In [4]:
# Load one random training image and its mask as a quick look at the data layout.
path = "data/train"
# Only .jpg entries are candidates, so a stray file/dir (e.g. .ipynb_checkpoints)
# cannot be picked; sorting + a seeded generator make the pick reproducible.
images = sorted(name for name in os.listdir(path) if name.endswith(".jpg"))
rng = np.random.default_rng(42)
# splitext keeps dotted stems intact ("a.b.jpg" -> "a.b"); split(".")[0] would give "a".
ind = os.path.splitext(rng.choice(images))[0]

with Image.open(f"{path}/{ind}.jpg") as im:
    img = np.array(im)
with Image.open(f"{path}_mask/{ind}.png") as im:
    mask = np.array(im)

In [5]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torch
from torch.utils.data import Dataset, DataLoader

In [6]:
def _load_split(prefix, with_masks=True):
    """Load every image of one data split (and optionally its masks) into memory.

    Entries of ``os.listdir(prefix)`` are processed in sorted order, so the result
    is reproducible across machines (raw ``os.listdir`` order is arbitrary). For
    each entry stem, ``{prefix}/{stem}.jpg`` is read as the image and, when
    ``with_masks`` is true, ``{prefix}_mask/{stem}.png`` as the mask.

    Args:
        prefix: directory of the split's images.
        with_masks: whether to also read the matching masks.

    Returns:
        (images, masks): lists of numpy arrays; ``masks`` is empty when
        ``with_masks`` is false.
    """
    images, masks = [], []
    for file_name in sorted(os.listdir(prefix)):
        # splitext keeps dotted stems intact ("a.b.jpg" -> "a.b").
        stem = os.path.splitext(file_name)[0]
        with Image.open(f"{prefix}/{stem}.jpg") as im:
            images.append(np.array(im))
        if with_masks:
            with Image.open(f"{prefix}_mask/{stem}.png") as im:
                masks.append(np.array(im))
    return images, masks


train_prefix = "data/train"
val_prefix = "data/valid"
# The test listing is sorted: the submission cell re-lists this directory in
# sorted order to recover the ids, so the two orders must match.
test_prefix = "data/test"

X_train, y_train = _load_split(train_prefix)
X_val, y_val = _load_split(val_prefix)
X_test = _load_split(test_prefix, with_masks=False)[0]

In [7]:
class MILDataset(Dataset):
    """Image/mask pairs for segmentation training and validation.

    Args:
        X: indexable sequence of images.
        y: indexable sequence of masks aligned with ``X``.
        transform: optional albumentations-style callable, invoked as
            ``transform(image=..., mask=...)`` and returning a dict with
            ``'image'`` and ``'mask'`` tensors.
    """

    def __init__(self, X, y, transform=None):
        super().__init__()
        self.images = X
        self.masks = y
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for ``idx``.

        With a transform, the mask is cast to float64, divided by its own maximum
        (so e.g. a 0/255 mask becomes 0/1) and given a leading channel dim of size 1.
        Without a transform, the stored pair is returned unchanged.
        """
        if self.transform is None:
            return self.images[idx], self.masks[idx]
        tr = self.transform(image=self.images[idx], mask=self.masks[idx])
        image, mask = tr['image'], tr['mask'].double()
        # An all-background mask has max 0; dividing would produce 0/0 = NaN targets.
        # The division is out-of-place so the transform's output tensor is never mutated.
        mask_max = mask.max()
        if mask_max > 0:
            mask = mask / mask_max
        return image, mask.unsqueeze(0)


class MILDatasetTest(Dataset):
    """Unlabelled images (no masks) for inference.

    Args:
        X: indexable sequence of images.
        transform: optional albumentations-style callable, invoked as
            ``transform(image=...)`` and returning a dict with key ``'image'``.
    """

    def __init__(self, X, transform=None):
        super().__init__()
        self.transform = transform
        self.images = X

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return the (optionally transformed) image at ``idx``."""
        raw_image = self.images[idx]
        if self.transform is None:
            return raw_image
        return self.transform(image=raw_image)['image']

In [8]:
def _to_model_input():
    """Return fresh final steps shared by every pipeline: A.Normalize, then ToTensorV2."""
    return [A.Normalize(always_apply=True), ToTensorV2()]


# Target size for RandomResizedCrop, taken from the first validation image.
h, w = X_val[0].shape[:2]
# Albumentations is used for its speed and because it transforms masks together with images.
train_aug = A.Compose(
    [A.HorizontalFlip(), A.Perspective(), A.RandomResizedCrop(h, w, scale=(0.9, 1))]
    + _to_model_input()
)
# Validation / test: no augmentation, only normalisation and tensor conversion.
val_aug = A.Compose(_to_model_input())

In [9]:
# Loader settings shared by all splits; only the training loader shuffles.
BATCH_SIZE = 48
_loader_kwargs = dict(batch_size=BATCH_SIZE, pin_memory=True)

train_dataset = MILDataset(X_train, y_train, transform=train_aug)
val_dataset = MILDataset(X_val, y_val, transform=val_aug)
test_dataset = MILDatasetTest(X_test, transform=val_aug)

train_dataloader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_dataloader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_dataloader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

In [10]:
# Use the GPU when one is available.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = PSPNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# The target metric is Dice, so train directly on the soft-Dice loss,
# a differentiable approximation of the Dice coefficient.
criterion = soft_dice_loss

In [None]:
# Fit on the training loader; the validation loader/dataset are only for evaluation.
# NOTE(review): assumes lib.train's 2nd positional argument is the training loader — confirm.
train(model, train_dataloader, val_dataloader, val_dataset, criterion, optimizer, device, n_epochs=2)

Примерно так будет выглядеть процесс обучения:

<img src="https://i.imgur.com/yTHErot.png" width="500">

Кажется, что модель можно ещё дообучить, но работу уже пора сдавать!

In [13]:
# Create the output directory. exist_ok makes the cell safe to re-run (`!mkdir` errors if it exists).
os.makedirs('results', exist_ok=True)

In [14]:
# Save only the model's state_dict (its parameters and buffers), not the model object.
weights = model.state_dict()
torch.save(weights, 'results/model_state_dict')

In [15]:
# predict_test yields a sequence of tensors (presumably one per batch); join them along dim 0.
preds = torch.cat(predict_test(model, test_dataloader, device), dim=0)

In [16]:
# Predictions follow X_test order, which was built from the sorted listing of
# test_prefix; re-list the directory the same way to recover the ids.
test_names = sorted(os.listdir(test_prefix))
# Ids and masks are matched purely by position, so a count mismatch would silently
# pair ids with the wrong masks.
assert len(test_names) == len(preds), "number of test files and predictions differ"
result_data = {
    'id': [os.path.splitext(name)[0] for name in test_names],
    'rle_mask': [encode_rle(pred) for pred in preds],
}

In [None]:
# Write the submission table. The default index=True stores the row index as the
# first, unnamed column.
submission_path = 'results/pred_test.csv'
pred_test = pd.DataFrame(data=result_data)
pred_test.to_csv(submission_path)

В задании было сказано сделать это для валидационного набора (хотя, кажется, нужен тренировочный), поэтому для валидационного я сделал аналогично.

In [19]:
# Render test predictions as 0/255 integer masks.
# squeeze(1) drops only a singleton channel dim (a no-op otherwise), so a batch of one
# image keeps its leading batch axis; a bare .squeeze() would drop that axis too.
# NOTE(review): .int() truncates, so preds are presumably already binary 0/1 — confirm.
html = get_html(sorted(glob(f'{test_prefix}/*')), preds.squeeze(1).int() * 255, path_to_save='results/test')

Я обучал модель на Kaggle, поэтому здесь соберу HTML не из предсказаний модели, а из сохранённого файла `results/pred_test.csv`.

In [28]:
# Reload the saved predictions; column 0 holds the pandas index written by to_csv.
pred_test = pd.read_csv(filepath_or_buffer='results/pred_test.csv', index_col=0)

In [31]:
# Decode each RLE string back into a mask array, keeping the CSV row order.
preds = [decode_rle(rle) for rle in pred_test.rle_mask]

In [36]:
# Stack the decoded masks into one array and scale by 255 (presumably 0/1 -> 0/255) for rendering.
mask_stack = np.stack(preds) * 255
html = get_html(sorted(glob('data/test/*')), mask_stack, path_to_save='results/test')