In [1]:
import os

# Enter the task directory only if we are not already inside it, so re-running this
# cell (or "Run All" after a partial run) does not fail on the relative path.
TASK_DIR = os.path.join('course_intro_ocr', 'task1')
if not os.path.abspath(os.getcwd()).endswith(os.sep + TASK_DIR):
    os.chdir(TASK_DIR)
os.getcwd()

/home/silevichar/liza/course_intro_ocr/task1


In [2]:
# Install the course package (course_intro_ocr_t1) from the current directory in
# editable mode so that it can be imported in the cells below.
%pip install -e .

Obtaining file:///home/silevichar/liza/course_intro_ocr/task1
  Installing build dependencies ... [?25ldone
[?25h  Checking if build backend supports build_editable ... [?25ldone
[?25h  Getting requirements to build editable ... [?25ldone
[?25h  Preparing editable metadata (pyproject.toml) ... [?25ldone
Building wheels for collected packages: course_intro_ocr_t1
  Building editable for course_intro_ocr_t1 (pyproject.toml) ... [?25ldone
[?25h  Created wheel for course_intro_ocr_t1: filename=course_intro_ocr_t1-0.1.1-py3-none-any.whl size=1276 sha256=b91b87cc501f3f329455b9abfa763e2a2f0d3226b9c6b1999f17f1f5f1a0668e
  Stored in directory: /tmp/pip-ephem-wheel-cache-ymobz5ln/wheels/37/ad/4b/a41532ed0e2a03fe09e38e769dbe639574c2ad527611b043c4
Successfully built course_intro_ocr_t1
Installing collected packages: course_intro_ocr_t1
  Attempting uninstall: course_intro_ocr_t1
    Found existing installation: course_intro_ocr_t1 0.1.1
    Uninstalling course_intro_ocr_t1-0.1.1:
      Su

In [3]:
from pathlib import Path
from course_intro_ocr_t1.data import MidvPackage
from tqdm import tqdm
from matplotlib import pyplot as plt
from matplotlib import patches
import numpy as np

import cv2
from  torch.utils.data import Dataset, DataLoader

import torch
import torch.nn as nn
from tqdm.notebook import tqdm
from time import time

In [4]:
gpu_number = 4
# Use the requested GPU when it actually exists; otherwise fall back to CPU so the
# notebook still runs (slowly) instead of failing on the first tensor transfer.
if torch.cuda.is_available() and gpu_number < torch.cuda.device_count():
    device = torch.device(f'cuda:{gpu_number}')
else:
    print(f'cuda:{gpu_number} is not available; falling back to CPU')
    device = torch.device('cpu')
device

device(type='cuda', index=4)

In [5]:
# The compressed MIDV-500 dataset sits two directory levels above the task directory
# entered at the top of the notebook.
DATASET_PATH = Path.cwd().parent.parent.joinpath('midv500_compressed')
assert DATASET_PATH.exists(), DATASET_PATH.absolute()
DATASET_PATH

PosixPath('/home/silevichar/liza/midv500_compressed')

In [6]:
# Read the list of packages (MidvPackage objects) from the dataset directory.
data_packs = MidvPackage.read_midv500_dataset(DATASET_PATH)
# Display how many packages were found and the type of the first one.
(len(data_packs), type(data_packs[0]))

(50, course_intro_ocr_t1.data.MidvPackage)

In [7]:
class crop_dataset(Dataset):
    """Image / document-mask pairs taken from a list of MidvPackage objects.

    Every item is addressed by a ``(package index, item index)`` pair; only items whose
    ``is_test_split()`` matches the requested split are kept.

    Args:
        data_packs: sequence of packages; each supports ``len()`` and integer indexing
            and yields items exposing ``is_test_split()``, ``image`` (with
            ``.convert('RGB')``), ``gt_data['quad']`` and ``unique_key``.
        train: keep the non-test items when True, the test items when False.
        device: device the returned tensors are created on. When None (the default)
            the notebook-global ``device`` is used, as before.
    """

    def __init__(self, data_packs, train=True, device=None):
        self.data_packs = data_packs
        # The parameter shadows the global name, so the fallback goes through globals().
        self.device = globals()['device'] if device is None else device
        want_test = not train
        self.indices = [
            (i, j)
            for i, data_pack in enumerate(data_packs)
            for j in range(len(data_pack))
            if bool(data_pack[j].is_test_split()) == want_test
        ]

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` float tensors on ``self.device``.

        ``image`` is the RGB image transposed to channels-first and divided by 255;
        ``mask`` has a leading singleton channel and is 1 inside the ground-truth quad.
        """
        i, j = self.indices[idx]
        dp = self.data_packs[i][j]
        image = np.array(dp.image.convert('RGB')) / 255.
        # cv2.fillConvexPoly expects int32 vertex coordinates; round in case the
        # ground-truth quad is stored as floats.
        quad = np.rint(np.asarray(dp.gt_data['quad'], dtype=np.float64)).astype(np.int32)
        mask = cv2.fillConvexPoly(np.zeros(image.shape[:2]), quad, (1,))[np.newaxis, ...]

        return (
            torch.tensor(image.transpose(2, 0, 1), dtype=torch.float, device=self.device),
            torch.tensor(mask, dtype=torch.float, device=self.device),
        )

    def get_key(self, idx):
        """Return the ``unique_key`` of the dataset item at position ``idx``."""
        i, j = self.indices[idx]
        dp = self.data_packs[i][j]
        return dp.unique_key
    
class EncConvBlock(nn.Module):
    """Two stages of 3x3 Conv2d -> BatchNorm2d -> ReLU; padding=1 keeps H and W.

    The convolutions are created with bias=False; the BatchNorm2d after each one has
    its own learnable shift. Layer indices inside ``self.block`` (0..5) form the
    state_dict keys of saved checkpoints, so the layer order must stay as is.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        for stage_in in (in_channels, out_channels):
            stages.extend((
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ))
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        features = self.block(x)
        return features

class UpsampleBlock(nn.Module):
    """Decoder stage: 2x upsample, pad to the skip tensor's size, concatenate, two convs.

    ``in_channels`` must equal the skip tensor's channels plus the upsampled tensor's
    channels, because the two are concatenated (skip first) along dim 1 before the convs.
    Layer indices inside ``self.dec_conv`` form checkpoint state_dict keys.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2)
        layers = []
        for stage_in in (in_channels, out_channels):
            layers += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ]
        self.dec_conv = nn.Sequential(*layers)

    def forward(self, u, e):
        up = self.upsample(u)
        # Conv2d tensors are (N, C, H, W): dim 2 is rows (height), dim 3 is columns (width).
        extra_rows = e.shape[2] - up.shape[2]
        extra_cols = e.shape[3] - up.shape[3]
        # F.pad takes the last dimension first: (left, right, top, bottom); any odd
        # remainder goes to the right / bottom side.
        up = nn.functional.pad(up, [
            extra_cols // 2, extra_cols - extra_cols // 2,
            extra_rows // 2, extra_rows - extra_rows // 2,
        ])
        return self.dec_conv(torch.cat((e, up), dim=1))

class UNet(nn.Module):
    """Four-level U-Net mapping a 3-channel image to a 1-channel output map.

    The submodule names (enc_conv0..3, pool0..3, bottleneck_conv, up_0..3, out) are the
    state_dict keys of saved checkpoints, so they must not be renamed.
    """

    # Channel width of each encoder level; the bottleneck doubles the last one.
    _WIDTHS = (64, 128, 256, 512)

    def __init__(self):
        super().__init__()
        channels = 3
        # Registration order matches the explicit version: enc_conv0, pool0, enc_conv1, ...
        for level, width in enumerate(self._WIDTHS):
            setattr(self, f'enc_conv{level}', EncConvBlock(channels, width))
            setattr(self, f'pool{level}', nn.MaxPool2d(kernel_size=2, stride=2))
            channels = width
        self.bottleneck_conv = EncConvBlock(channels, 2 * channels)
        channels *= 2
        for level, skip_width in enumerate(reversed(self._WIDTHS)):
            # Decoder input = upsampled features concatenated with the matching skip.
            setattr(self, f'up_{level}', UpsampleBlock(channels + skip_width, skip_width))
            channels = skip_width
        self.out = nn.Sequential(
            nn.Conv2d(channels, 1, kernel_size=3, padding=1),
            nn.BatchNorm2d(1),
        )

    def forward(self, x):
        skips = []
        h = x
        for level in range(len(self._WIDTHS)):
            h = getattr(self, f'enc_conv{level}')(h)
            skips.append(h)  # consumed by the decoder stage at the same resolution
            h = getattr(self, f'pool{level}')(h)
        h = self.bottleneck_conv(h)
        for level, skip in enumerate(reversed(skips)):
            h = getattr(self, f'up_{level}')(h, skip)
        return self.out(h)

In [8]:
def train(model, opt, loss_fn, epochs, data_tr, data_val, checkpoint_pattern='epoch{}.pth'):
    """Train a segmentation model and record per-epoch train / validation loss.

    The model's raw output is passed through ``torch.sigmoid`` before ``loss_fn``, so
    ``loss_fn`` must accept probabilities (e.g. ``nn.BCELoss``).

    Args:
        model: torch module mapping an input batch to one output map per sample.
        opt: optimizer over ``model``'s parameters.
        loss_fn: callable ``(probabilities, targets) -> scalar tensor``.
        epochs: number of passes over ``data_tr``.
        data_tr, data_val: iterables supporting ``len()`` that yield
            ``(inputs, targets)`` batches.
        checkpoint_pattern: ``str.format`` pattern with one slot for the 0-based epoch;
            ``model.state_dict()`` is saved there after every epoch. None disables saving.

    Returns:
        ``(train_loss, val_loss)``: lists of per-epoch mean batch losses as floats
        (NaN for an epoch whose loader is empty).
    """
    first_param = next(model.parameters(), None)
    target_device = None if first_param is None else first_param.device

    def to_model_device(t):
        # No-op when the batch already lives next to the weights.
        return t if target_device is None else t.to(target_device)

    train_loss, val_loss = [], []

    for epoch in range(epochs):
        tic = time()
        print('* Epoch %d/%d' % (epoch + 1, epochs))

        model.train()
        running = 0.0
        for X_batch, Y_batch in tqdm(data_tr):
            opt.zero_grad()
            Y_pred = torch.sigmoid(model(to_model_device(X_batch)))
            loss = loss_fn(Y_pred, to_model_device(Y_batch))
            loss.backward()
            opt.step()
            # .item() yields a plain float; summing the loss tensor itself would keep
            # every batch's autograd history alive until the end of the epoch.
            running += loss.item()
        train_loss.append(running / len(data_tr) if len(data_tr) else float('nan'))

        model.eval()
        running = 0.0
        with torch.no_grad():
            for X_val_batch, Y_val_batch in tqdm(data_val):
                Y_pred_batch = torch.sigmoid(model(to_model_device(X_val_batch)))
                running += loss_fn(Y_pred_batch, to_model_device(Y_val_batch)).item()
        val_loss.append(running / len(data_val) if len(data_val) else float('nan'))

        print('loss: %f  val_loss: %f  (%.1fs)' % (train_loss[-1], val_loss[-1], time() - tic))

        if checkpoint_pattern is not None:
            torch.save(model.state_dict(), checkpoint_pattern.format(epoch))

    return train_loss, val_loss

In [None]:
batch_size = 4
train_data = crop_dataset(data_packs)
val_data = crop_dataset(data_packs, train=False)

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
# Validation needs no shuffling; a fixed order keeps the per-epoch val loss reproducible.
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)

model = UNet()
model.to(device)

# Resume from an earlier run only if its checkpoint exists, so a fresh
# "Restart & Run All" does not crash here. map_location places the weights on the
# current device even if the file was saved from a different GPU.
# Note: train() writes epoch{N}.pth checkpoints, so epoch5.pth is overwritten below.
resume_path = Path('epoch5.pth')
if resume_path.exists():
    model.load_state_dict(torch.load(resume_path, map_location=device))

train_losses, val_losses = train(model, torch.optim.Adam(model.parameters()), nn.BCELoss(), 10,
                                 train_loader, val_loader)

torch.save(model.state_dict(), 'checkpoint.pth')

In [None]:
from course_intro_ocr_t1.metrics import dump_results_dict, measure_crop_accuracy

# Maps each validation item's unique key to its predicted normalized corner array.
results_dict = {}

def get_vertices(prediction):
    """Approximate the four corners of the predicted foreground region.

    Args:
        prediction: tensor of raw model outputs; pixels with sigmoid > 0.5 are
            foreground. Expected to be 2-D (the caller passes ``model(...)[0][0]``).

    Returns:
        float64 numpy array of shape (4, 2) holding (row, col) pixel coordinates in the
        order top-left, top-right, bottom-right, bottom-left; all zeros when no pixel
        is foreground.
    """
    mask = torch.sigmoid(prediction.detach().cpu()) > 0.5
    points = mask.nonzero()  # one (row, col) row per foreground pixel
    if points.shape[0] == 0:
        # Nothing detected: there is no region to take extreme points of.
        return np.zeros((4, 2), dtype=float)

    sums = points.sum(dim=1)              # row + col
    diffs = points[:, 0] - points[:, 1]   # row - col
    corners = torch.stack([
        points[sums.argmin()],    # top-left: smallest row + col
        points[diffs.argmin()],   # top-right: smallest row - col
        points[sums.argmax()],    # bottom-right: largest row + col
        points[diffs.argmax()],   # bottom-left: largest row - col
    ])
    return corners.numpy().astype(float)

# Inference only: eval() makes BatchNorm use its running statistics, and no_grad()
# avoids building an autograd graph for every full-size image.
model.eval()
with torch.no_grad():
    for idx in tqdm(range(len(val_data))):
        image, _ = val_data[idx]
        result = model(image.unsqueeze(0))[0][0]
        vertices = get_vertices(result)
        # Normalize (row, col) by the map's height and width ...
        vertices[:, 0] /= result.shape[0]
        vertices[:, 1] /= result.shape[1]
        # ... then reorder each vertex to (x, y) = (col, row).
        vertices[:, [0, 1]] = vertices[:, [1, 0]]
        results_dict[val_data.get_key(idx)] = vertices

dump_results_dict(results_dict, Path() / 'pred.json')