In [1]:
import torch
import torchvision
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader, Dataset

from torchvision import transforms as T
import torchvision.transforms.functional as F
from torchvision.utils import make_grid

import albumentations as A
from albumentations.pytorch import ToTensorV2

import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp

import os
import random
import cv2
from datetime import datetime

In [None]:
# Device setup: prefer CUDA when PyTorch can see a GPU, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Data configuration and train/test split
IMAGES_DIR = './TlkWaterMeters/images'
MASKS_DIR = './TlkWaterMeters/masks'
HEIGHT = 256
WIDTH = 256
EPOCHS = 20

# Keep only the files that have both a photo and a mask.
# The names are sorted because iteration order of a set of strings changes between
# interpreter runs (string hash randomization); without a stable input order,
# random_state alone does not make the split reproducible.
intersect = sorted(set(os.listdir(IMAGES_DIR)) & set(os.listdir(MASKS_DIR)))
train, test = train_test_split(intersect, random_state=57)


class WaterMeter(Dataset):
    """Segmentation dataset pairing water-meter photos with their two-class masks.

    Every item is an ``(image, mask)`` pair. The mask carries two channels:
    channel 0 is the "digits" class and channel 1 is the "not digits" class.
    """

    def __init__(self, objects, image_dir, target_dir, transform=None):
        """
        Args:
            objects: file names that exist in both ``image_dir`` and ``target_dir``.
            image_dir: directory with the photos.
            target_dir: directory with the masks.
            transform: optional callable invoked as ``transform(image=..., mask=...)``
                that returns a mapping with ``'image'`` and ``'mask'`` entries.
        """
        # Both lists are built from the same names and sorted the same way,
        # so index i addresses the same file in each of them.
        self.image_paths = sorted([os.path.join(image_dir, file) for file in objects])
        self.target_paths = sorted([os.path.join(target_dir, file) for file in objects])
        self.transform = transform

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, clean up and (optionally) transform the pair at position ``idx``.

        Returns:
            ``(image, mask)``. With a transform, whatever the transform produced, with
            the mask cut to its first two leading channels. Without a transform, the
            RGB image as an (H, W, 3) array and the mask as an (H, W, 2) float array
            holding 0.0 / 1.0.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image or the mask file.
        """
        image_path = self.image_paths[idx]
        target_path = self.target_paths[idx]

        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising
            raise FileNotFoundError(f'Could not read image: {image_path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(target_path)
        if mask is None:
            raise FileNotFoundError(f'Could not read mask: {target_path}')

        # Clean up annotation artifacts: force every value to pure black or white
        mask[mask <= 127] = 0
        mask[mask > 127] = 255

        # Channel 0 of the mask is the "digits" class.
        # Channel 1 is inverted so that it becomes the "not digits" class.
        mask[..., 1] = 255 - mask[..., 1]

        mask = mask / 255.

        if self.transform:
            transformed = self.transform(image=image, mask=mask)
            # The transform is expected to hand back a channel-first mask
            # (as ToTensorV2(transpose_mask=True) does), so the classes lead.
            return transformed['image'], transformed['mask'][:2, ...]

        # Without a transform the mask is still channel-last (H, W, C): the two
        # classes live on the last axis. Slicing the first axis would return the
        # first two pixel rows instead of the two class channels.
        return image, mask[..., :2]
        
def _make_transform(*augmentations):
    """Build a pipeline: the given augmentations, then resize, normalize and tensor conversion."""
    return A.Compose([
        *augmentations,
        A.geometric.Resize(height=HEIGHT, width=WIDTH),
        A.Normalize(always_apply=True),
        ToTensorV2(transpose_mask=True),
    ])


# Training: colour jitter and rotation are applied ahead of the shared steps
train_transform = _make_transform(A.ColorJitter(), A.Rotate())

# Evaluation: only the shared resize / normalize / to-tensor steps
test_transform = _make_transform()

# Both splits read from the same directories; only the file list and the transform differ
_dataset_dirs = {'image_dir': IMAGES_DIR, 'target_dir': MASKS_DIR}

train_dataset = WaterMeter(objects=train, transform=train_transform, **_dataset_dirs)
test_dataset = WaterMeter(objects=test, transform=test_transform, **_dataset_dirs)

# Training batches are reshuffled every epoch; evaluation keeps a fixed order
train_dataloader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=32)
test_dataloader = DataLoader(dataset=test_dataset, shuffle=False, batch_size=32)

In [None]:
# Sanity check: show one training sample (augmented photo plus both mask channels)
image, mask = train_dataset[0]

fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(16, 9))

# The photo went through A.Normalize with its library-default ImageNet statistics;
# undo that so the values are back in [0, 1] and imshow does not clip them.
norm_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
norm_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
image_for_display = (image * norm_std + norm_mean).clamp(0, 1)

to_show = [image_for_display, mask[:1, ...], mask[1:, ...]]
titles = ['image', 'mask: digits', 'mask: not digits']

for i, img in enumerate(to_show):
    # Tensors are channel-first (C, H, W) while imshow wants (H, W, C).
    # A bare transpose() reverses every axis, which swaps height with width and
    # mirrors the picture along its diagonal; permute keeps H and W in place.
    # squeeze(-1) turns the single-channel masks into plain 2-D arrays.
    ax[i].imshow(img.permute(1, 2, 0).squeeze(-1).numpy())
    ax[i].set_title(titles[i])

In [None]:
# Tensorboard setup
time = datetime.now()
# Run directory looks like log_20220712_175356.
# %m / %d are month / day; %M means minutes and %D expands to mm/dd/yy, whose
# slashes would split the run name into nested directories.
log_dir = f'log_{time.strftime("%Y%m%d_%H%M%S")}'
writer = SummaryWriter(log_dir=log_dir)

In [None]:
# Model setup
model = smp.Unet(
    encoder_name="efficientnet-b7",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
    in_channels=3,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=2,                      # model output channels (number of classes in your dataset)
    activation='softmax'            # the network therefore outputs class probabilities, not logits
)
model = model.float()
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)


class ProbabilityCrossEntropyLoss(nn.Module):
    """Cross-entropy for a model that already outputs class probabilities.

    nn.CrossEntropyLoss applies log-softmax to its input on its own, so feeding it
    the output of a softmax layer normalizes twice and flattens the gradients.
    This loss takes the log of the probabilities directly and applies NLL, which is
    the actual cross-entropy of the model's predicted distribution.

    Args:
        eps: lower bound applied to the probabilities before the log, so that an
            exact zero does not produce -inf.
    """

    def __init__(self, eps=1e-7):
        super().__init__()
        self.eps = eps

    def forward(self, probabilities, target):
        """Return the mean negative log-likelihood.

        Args:
            probabilities: class probabilities with the class axis at dim 1.
            target: integer class indices (same shape without the class axis).
        """
        log_probabilities = torch.log(probabilities.clamp_min(self.eps))
        return nn.functional.nll_loss(log_probabilities, target)


# Called as loss(pred, target), exactly like nn.CrossEntropyLoss would be
loss = ProbabilityCrossEntropyLoss()

In [None]:
# Training
steps_per_epoch = len(train_dataloader)

for i in range(EPOCHS):
    print('EPOCH: ', i)
    model.train()
    for batch, (X, y) in enumerate(train_dataloader):
        X, y = X.to(device), y.to(device)

        # The dataset yields one channel per class; the loss wants a class index
        # per pixel, so collapse the class axis (dim 1) with argmax.
        target = y.long().argmax(dim=1)

        pred = model(X.float())
        output = loss(pred, target)

        # Backpropagation
        optimizer.zero_grad()
        output.backward()
        optimizer.step()

        if batch % 10 == 0:
            train_loss, current = output.item(), batch * len(X)
            # The global step has to advance by the number of batches per epoch.
            # Multiplying the epoch by the batch size (len(X)) makes steps from
            # different epochs collide on the TensorBoard x-axis.
            writer.add_scalar('train loss ', train_loss, i * steps_per_epoch + batch)
            print(f"Train loss: {train_loss:>7f}  [{current:>5d}/{len(train_dataloader.dataset):>5d}]")

    model.eval()
    test_loss, correct_pixels, total_pixels = 0, 0, 0
    with torch.no_grad():
        for X, y in test_dataloader:
            X, y = X.to(device), y.to(device)
            target = y.long().argmax(dim=1)
            pred = model(X.float())
            test_loss += loss(pred, target).item()
            # Pixel accuracy: count matching pixels and the pixels seen, so the
            # ratio is a real fraction regardless of batch or image size.
            correct_pixels += (pred.argmax(1) == target).sum().item()
            total_pixels += target.numel()
    test_loss /= len(test_dataloader)
    correct = correct_pixels / total_pixels
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

# Make sure every logged scalar reaches the event file
writer.flush()

In [None]:
def inference_file(model, image_transform, image_path):
    """Predict a binarized per-class mask for a single image file.

    Args:
        model: network called on a float batch of shape (1, C_in, H, W) that returns
            per-class scores in [0, 1] (e.g. softmax probabilities).
        image_transform: callable invoked as ``image_transform(image=rgb_array)`` that
            returns a mapping whose ``'image'`` entry is a channel-first tensor.
        image_path: path of the image to segment.

    Returns:
        CPU tensor with the batch axis (and any other size-1 axis) squeezed out,
        holding 1.0 where the score is >= 0.5 and 0.0 elsewhere.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    model.eval()

    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising
        raise FileNotFoundError(f'Could not read image: {image_path}')
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = image_transform(image=image)['image']

    # Run on whatever device the model lives on instead of relying on a
    # notebook-level global; a parameter-free model leaves the input where it is.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        image = image.to(first_param.device)

    # Inference only: no autograd graph is needed (saves memory and time)
    with torch.no_grad():
        pred = model(image.unsqueeze(0).float())

    # Threshold into a fresh tensor rather than editing the model output in place
    pred = (pred >= 0.5).to(pred.dtype)
    return pred.squeeze().cpu()

In [None]:
image_path = './photo_2022-07-12 17.53.56.jpeg'

# OpenCV loads images as BGR. Convert to RGB before the transform, the same way
# inference_file does, so the displayed photo matches what the model was given.
bgr_image = cv2.imread(image_path)
if bgr_image is None:
    # cv2.imread signals failure by returning None instead of raising
    raise FileNotFoundError(f'Could not read image: {image_path}')
image = test_transform(image=cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB))['image']
mask = inference_file(model, test_transform, image_path)

# test_transform normalizes with A.Normalize's library-default ImageNet statistics;
# undo that so the values are back in [0, 1] and imshow does not clip them.
norm_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
norm_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
image_for_display = (image * norm_std + norm_mean).clamp(0, 1)

fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(16, 9))

to_show = [image_for_display, mask[:1, ...], mask[1:, ...]]
titles = ['image', 'predicted: digits', 'predicted: not digits']

for i, img in enumerate(to_show):
    # Tensors are channel-first (C, H, W) while imshow wants (H, W, C).
    # A bare transpose() reverses every axis, which swaps height with width and
    # mirrors the picture along its diagonal; permute keeps H and W in place.
    # squeeze(-1) turns the single-channel masks into plain 2-D arrays.
    ax[i].imshow(img.permute(1, 2, 0).squeeze(-1).numpy())
    ax[i].set_title(titles[i])