In [82]:
from pathlib import Path
from course_intro_ocr_t1.data import MidvPackage
import numpy as np
import cv2 as cv
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import lightning

In [83]:
# Dataset root, resolved to an absolute path from the notebook's working directory.
DATASET_PATH = Path('midv500_compressed').resolve()
# Fail early (showing the resolved path) if the dataset directory is missing.
assert DATASET_PATH.exists(), DATASET_PATH.absolute()

# Read the dataset through the course helper; later cells index the result as packs[i][j].
packs = MidvPackage.read_midv500_dataset(DATASET_PATH)

In [90]:
class CropDataset(Dataset):
    """Segmentation dataset yielding ``(image, mask)`` pairs for document cropping.

    Each sample is addressed by a ``(package index, item index)`` pair; only the
    items whose ``is_test_split()`` equals ``is_test`` are kept.

    Args:
        datapacks: Indexable collection of packages; ``datapacks[i][j]`` must
            expose ``is_test_split()``, ``image`` (a PIL image) and
            ``gt_data['quad']`` (polygon vertices, presumably in pixel
            coordinates of the original image -- the mask is drawn at the
            original image size).
        is_test: Select the test split (``True``) or the train split (``False``).
        image_size: ``(height, width)`` both the image and the mask are resized to.

    Returns (per item):
        image: float32 tensor of shape ``(3, H, W)`` scaled to ``[0, 1]``.
        trg: float32 tensor of shape ``(1, H, W)`` with 1 inside the quad, 0 outside.
    """

    def __init__(self, datapacks, is_test, image_size=(256, 256)):
        self.datapacks = datapacks
        self.image_size = tuple(image_size)
        self.data_indexes = [
            (i, j)
            for i in range(len(datapacks))
            for j in range(len(datapacks[i]))
            if datapacks[i][j].is_test_split() == is_test
        ]

    def __len__(self):
        return len(self.data_indexes)

    def __getitem__(self, idx):
        i, j = self.data_indexes[idx]
        item = self.datapacks[i][j]

        # Decode the image once and reuse its shape for the mask instead of
        # reading the image a second time just to learn its size.
        pixels = np.array(item.image.convert('RGB'))
        height, width = pixels.shape[:2]

        image = torch.from_numpy(pixels).float() / 255
        image = image.permute([2, 0, 1])
        image = transforms.Resize(self.image_size)(image)

        # OpenCV drawing functions expect int32 vertex coordinates.
        quad = np.asarray(item.gt_data['quad'], dtype=np.int32)
        # float32 mask so that the target dtype matches the network logits.
        mask = cv.fillConvexPoly(np.zeros((height, width), dtype=np.float32), quad, 1)
        trg = torch.from_numpy(mask).unsqueeze(0)
        # Nearest-neighbour keeps the mask strictly binary after resizing.
        trg = transforms.Resize(
            self.image_size,
            interpolation=transforms.InterpolationMode.NEAREST,
        )(trg)

        return image, trg

In [91]:
# Loader settings shared by both splits.
batch_size = 6
num_workers = 4

# One dataset per split, selected by each item's is_test_split() flag.
train_set = CropDataset(packs, is_test=False)
val_set = CropDataset(packs, is_test=True)

# Only the training data is shuffled; validation keeps a fixed order.
train_loader = DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)
val_loader = DataLoader(
    val_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
)

In [92]:
class ResNetBlock(nn.Module):
    """Residual block: two 3x3 conv+BN stages plus a 1x1-projected shortcut.

    The shortcut always goes through a 1x1 convolution and batch norm, so the
    block works for any ``in_channels`` / ``out_channels`` combination. Spatial
    size is preserved (3x3 convolutions use ``padding=1``).

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels of the output feature map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Main path, stage 1: conv -> BN -> ReLU.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
        # Main path, stage 2: conv -> BN (activation is applied after the sum).
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
        )
        # Shortcut path: 1x1 projection to the output channel count.
        self.downsample = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        """Return ``relu(conv2(conv1(x)) + downsample(x))``."""
        shortcut = self.downsample(x)
        features = self.conv2(self.conv1(x))
        return torch.relu(features + shortcut)
    

class UNet(nn.Module):
    """U-Net built from ``ResNetBlock`` stages with four down/up-sampling levels.

    Encoder widths are 32, 64, 128 and 256 channels, the bridge has 512, and the
    decoder mirrors the encoder. Each decoder level upsamples by 2 (nearest),
    halves the channels with a 1x1 convolution, concatenates the matching
    encoder output and refines the result with a ``ResNetBlock``.

    Because of the four 2x poolings, input height and width should be multiples
    of 16; otherwise the skip concatenations will not match in size.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the output map (raw values, no activation).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        def _pool():
            return nn.MaxPool2d(kernel_size=2, stride=2)

        def _up(channels_in, channels_out):
            return nn.Sequential(
                nn.Upsample(scale_factor=2, mode='nearest'),
                nn.Conv2d(channels_in, channels_out, kernel_size=1),
            )

        # Encoder: residual block followed by 2x max-pooling at every level.
        self.enc_conv1 = ResNetBlock(in_channels, 32)
        self.enc_pool1 = _pool()

        self.enc_conv2 = ResNetBlock(32, 64)
        self.enc_pool2 = _pool()

        self.enc_conv3 = ResNetBlock(64, 128)
        self.enc_pool3 = _pool()

        self.enc_conv4 = ResNetBlock(128, 256)
        self.enc_pool4 = _pool()

        # Bottleneck at 1/16 of the input resolution.
        self.bridge = ResNetBlock(256, 512)

        # Decoder: each refine block takes upsampled + skip channels (2x width).
        self.dec_upsample1 = _up(512, 256)
        self.dec_conv1 = ResNetBlock(512, 256)

        self.dec_upsample2 = _up(256, 128)
        self.dec_conv2 = ResNetBlock(256, 128)

        self.dec_upsample3 = _up(128, 64)
        self.dec_conv3 = ResNetBlock(128, 64)

        self.dec_upsample4 = _up(64, 32)
        self.dec_conv4 = ResNetBlock(64, 32)

        # Per-pixel projection to the requested number of output channels.
        self.out = nn.Conv2d(32, out_channels, kernel_size=1)

    def forward(self, x):
        """Run the encoder, bridge and decoder; return the raw output map."""
        encoder_levels = (
            (self.enc_conv1, self.enc_pool1),
            (self.enc_conv2, self.enc_pool2),
            (self.enc_conv3, self.enc_pool3),
            (self.enc_conv4, self.enc_pool4),
        )
        decoder_levels = (
            (self.dec_upsample1, self.dec_conv1),
            (self.dec_upsample2, self.dec_conv2),
            (self.dec_upsample3, self.dec_conv3),
            (self.dec_upsample4, self.dec_conv4),
        )

        # Keep every pre-pooling encoder output for the skip connections.
        skips = []
        features = x
        for block, pool in encoder_levels:
            features = block(features)
            skips.append(features)
            features = pool(features)

        features = self.bridge(features)

        # Deepest skip is consumed first.
        for (upsample, block), skip in zip(decoder_levels, reversed(skips)):
            features = block(torch.cat([upsample(features), skip], dim=1))

        return self.out(features)

In [140]:
class LitIdentifier(lightning.LightningModule):
    """Lightning wrapper that trains a segmentation model with BCE-with-logits loss.

    Args:
        model: Network mapping an image batch to per-pixel logits with the same
            shape as the target masks.
        lr: Learning rate for the Adam optimizer.
    """

    def __init__(self, model, lr=1e-3):
        super().__init__()
        self.model = model
        self.lr = lr
        self.criterion = nn.BCEWithLogitsLoss()

    def _shared_step(self, batch):
        """Compute the loss for one ``(inputs, targets)`` batch."""
        inputs, targets = batch
        outputs = self.model(inputs)
        # Align target dtype/device with the logits so the loss is not computed
        # with mismatched floating-point types.
        return self.criterion(outputs, targets.type_as(outputs))

    def training_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        # Without logging, the validation loss would be computed and discarded.
        self.log('val_loss', loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(params=self.model.parameters(), lr=self.lr)

In [146]:
# Use the first CUDA device when available, otherwise fall back to the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# RGB input, single-channel output: one logit per pixel.
model = UNet(in_channels=3, out_channels=1).to(device)

# Wrap the network in the Lightning module and train for 5 epochs,
# running validation on val_loader.
lit_model = LitIdentifier(model)
trainer = lightning.Trainer(max_epochs=5)
trainer.fit(lit_model, train_loader, val_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type              | Params
------------------------------------------------
0 | model     | UNet              | 7.6 M 
1 | criterion | BCEWithLogitsLoss | 0     
------------------------------------------------
7.6 M     Trainable params
0         Non-trainable params
7.6 M     Total params
30.386    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.


In [165]:
def get_corners(mask, threshold=0.2):
    """Extract four corner points from a predicted mask of logits.

    The logits are passed through a sigmoid and binarized at ``threshold``; the
    external contour with the largest area is taken, and its extreme points
    along the two diagonals are returned as corners.

    Args:
        mask: Tensor of logits. Assumed to be 2-D ``(H, W)`` -- the caller
            squeezes the batch and channel dimensions before passing it in.
        threshold: Probability above which a pixel counts as foreground.

    Returns:
        ``float`` ndarray of shape ``(4, 2)`` with ``(x, y)`` pixel coordinates
        ordered left-top, right-top, right-bottom, left-bottom.

    Raises:
        ValueError: If no foreground region is found in the mask.
    """
    binary = (torch.sigmoid(mask.detach()) > threshold).cpu().numpy().astype(np.uint8)
    contours, _ = cv.findContours(binary, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)
    if len(contours) == 0:
        raise ValueError('no foreground region found in the predicted mask')

    # Contour points come as (N, 1, 2); split them into x and y coordinate arrays.
    contour = max(contours, key=cv.contourArea)
    x, y = contour[:, 0].T

    # With the y axis pointing down, x + y is extreme at the top-left /
    # bottom-right corners and y - x at the top-right / bottom-left corners.
    sums = y + x
    diffs = y - x
    order = [
        np.argmax(-sums),   # left-top
        np.argmax(-diffs),  # right-top
        np.argmax(sums),    # right-bottom
        np.argmax(diffs),   # left-bottom
    ]

    return np.stack([x[order], y[order]], axis=1).astype(float)

def to_relative_scale(corners, height, width):
    """Scale ``(x, y)`` points in place: x is divided by ``width``, y by ``height``.

    Args:
        corners: Mutable sequence of ``[x, y]`` points (e.g. a float ndarray of
            shape ``(N, 2)`` or a list of two-element lists). Modified in place.
        height: Divisor for the y coordinates.
        width: Divisor for the x coordinates.

    Returns:
        The same ``corners`` object, after scaling.
    """
    for point in corners:
        point[0] /= width
        point[1] /= height
    return corners

In [None]:
# Predict document corners for every test-split item and collect them by key.
results_dict = dict()
model.to(device)
model.eval()

# Same input size the model was trained on.
resize = transforms.Resize((256, 256))

# Inference only: skip autograd bookkeeping to save memory and time.
with torch.no_grad():
    for dp in packs:
        for i in range(len(dp)):
            item = dp[i]
            if not item.is_test_split():
                continue
            try:
                image = torch.FloatTensor(np.array(item.image.convert('RGB'))) / 255
                image = image.permute([2, 0, 1])
                image = resize(image)
                mask = model(image.unsqueeze(0).to(device))
                corners = get_corners(mask.squeeze()).reshape(-1, 2)
                # The mask is (N, C, H, W); corners are normalized by its height/width.
                height, width = mask.shape[2:]
                results_dict[item.unique_key] = to_relative_scale(
                    np.array(corners, dtype=np.float32), height, width
                )
            except Exception as exc:
                # Keys missing from the dictionary are automatically scored as
                # IoU=0 by the metric, so a failed item is reported and skipped.
                print(f'{item.unique_key}: {exc}')

In [167]:
from course_intro_ocr_t1.metrics import dump_results_dict, measure_crop_accuracy

# Predictions are written next to the notebook and compared with the ground truth file.
PRED_PATH = Path() / 'pred.json'
GT_PATH = Path() / 'gt.json'

dump_results_dict(results_dict, PRED_PATH)
acc = measure_crop_accuracy(PRED_PATH, GT_PATH)

print("Точность кропа: {:1.4f}".format(acc))

Точность кропа: 0.8125


In [164]:
# Persist only the network weights (state_dict), not the whole module object.
torch.save(model.state_dict(), 'UNet.pt')