In [1]:
import os

In [2]:
os.environ["CUDA_VISIBLE_DEVICES"]="6"

In [3]:
from pathlib import Path
from course_intro_ocr_t1.data import MidvPackage
from tqdm import tqdm
from matplotlib import pyplot as plt
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms
from torch import nn
import cv2
import torch.nn.functional as F

from pylab import imshow

In [4]:
DATASET_PATH = Path().absolute().parent.parent / 'midv500_compressed'
assert DATASET_PATH.exists(), DATASET_PATH.absolute()

In [5]:
# Собираем список пакетов (MidvPackage) 
data_packs = MidvPackage.read_midv500_dataset(DATASET_PATH)
len(data_packs), type(data_packs[0])

(50, course_intro_ocr_t1.data.MidvPackage)

### Data

In [9]:
class MIDVDataset(Dataset):
    def __init__(self, dp, validation_flag=False):
        self.dp = dp
        self.data_idx = []
        
        for i in range(len(dp)):
            for j in range(len(dp[i])):                
                if validation_flag and dp[i][j].is_test_split():
                    self.data_idx.append((i, j))
                elif not validation_flag and not dp[i][j].is_test_split():
                    self.data_idx.append((i, j))

    def __len__(self):
        return len(self.data_idx)

    def __getitem__(self, idx):
        i, j = self.data_idx[idx]

        img = transforms.ToTensor()(np.array(self.dp[i][j].image))
        
        corner_coordinates = np.array(self.dp[i][j].gt_data['quad']).reshape(4, 2)
        mask = cv2.fillConvexPoly(np.zeros(img.size()[1:]), corner_coordinates, 1)
        img = transforms.Resize((256, 256))(img)
        mask = transforms.ToTensor()(mask)
        mask = mask.float()
        mask = transforms.Resize((256, 256))(mask)
                
        return img, mask

In [10]:
train_dataset = MIDVDataset(data_packs, False)
valid_dataset = MIDVDataset(data_packs, True)
print(len(train_dataset), len(valid_dataset))

10750 4250


In [11]:
batch_size = 16

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True,  drop_last=True)
valid_loader = DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=False, drop_last=False)

### Crop

In [12]:
class DoubleConv(nn.Module):

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

In [13]:
class UNet(nn.Module):
    def __init__(self, n_channels=3, n_classes=1):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes

        self.inc = (DoubleConv(n_channels, 64))
        self.down1 = (Down(64, 128))
        self.down2 = (Down(128, 256))
        self.down3 = (Down(256, 512))
        self.down4 = (Down(512, 1024))
        self.up1 = (Up(1024, 512))
        self.up2 = (Up(512, 256))
        self.up3 = (Up(256, 128))
        self.up4 = (Up(128, 64))
        self.outc = (OutConv(64, n_classes))

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

In [14]:
class DiceLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        
        inputs = F.sigmoid(inputs)       
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        
        intersection = (inputs * targets).sum()                            
        dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)  
        
        return 1 - dice

In [15]:
import torch
import lightning
import torchmetrics

class LitClassifier(lightning.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.criterion = DiceLoss()

    def training_step(self, batch, batch_idx):
        inputs, targets = batch

        outputs = self.model(inputs)
        loss = self.criterion(outputs, targets)

        return loss

    def validation_step(self, batch, batch_idx):
        inputs, targets = batch

        outputs = self.model(inputs)
        loss = self.criterion(outputs, targets)

        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(params=self.model.parameters(), lr=0.001)
        return optimizer

In [16]:
torch.cuda.empty_cache()

In [17]:
model = UNet()
lit_model = LitClassifier(model)
trainer = lightning.Trainer(max_epochs=10, accelerator="gpu")
trainer.fit(model=lit_model, train_dataloaders=train_loader, val_dataloaders=valid_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
2024-10-11 18:23:27.952121: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [6]

  | Name      | Type     | Params
---------------------------------------
0 | model     | UNet     | 31.0 M
1 | criterion | DiceLoss | 0     
---------------------------------------
31.0 M    Trainable params
0         Non-trainable params
31.0 M    Total params
124.151   Total estimated model params size (MB)


Sanity Checking: |                                                                                            …

/home/krotovan/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.
/home/krotovan/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.


Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_epochs=10` reached.


### Estimation

In [137]:
from course_intro_ocr_t1.metrics import dump_results_dict, measure_crop_accuracy, iou_relative_quads
import cv2 as cv

In [138]:
from course_intro_ocr_t1.metrics import dump_results_dict, measure_crop_accuracy, iou_relative_quads
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device);

In [240]:
def crop(img, return_corners=True):

    img = transforms.ToTensor()(img)
    img_size = img.size()[1:]
    img = img.unsqueeze(0)
    img = transforms.Resize((256, 256))(img).to(device)

    with torch.no_grad():
        mask_ = model(img)
    
    mask = (mask_ > 0)
    mask = transforms.Resize((img_size[0], img_size[1]))(mask).cpu().squeeze().numpy().astype(np.uint8)
    contours, _ = cv.findContours(mask, cv.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    contour = max(contours, key=cv.contourArea)

    if contours:
        contour = max(contours, key=cv.contourArea)
        x, y = contour[:, 0, 0], contour[:, 0, 1]

        corners_i = np.array([np.argmin(x + y), np.argmin(-x + y), np.argmax(x + y), np.argmax(-x + y)])
        corners = np.array([x[corners_i], y[corners_i]]).T.astype(float)
    
    y, x = img_size[0], img_size[1]

    normalization = np.array([x, y] * 4, dtype=np.float64).reshape(4, 2)
    

    for i in range(len(corners)):
        corners[i][0] /= normalization[i][0]
        corners[i][1] /= normalization[i][1] 
    
    if return_corners:
        return corners
    else:
        return corners, mask


In [241]:
results_dict = {}
for dp in tqdm(data_packs):
    for i in range(len(dp)):
        if dp[i].is_test_split():
            try:
                results_dict[dp[i].unique_key] = crop(np.array(dp[i].image))
            except Exception as exc:
                # Для пропущенных в словаре ключей в метриках автоаматически засчитается IoU=0
                print(exc)

  6%|██████████                                                                                                                                                              | 3/50 [00:01<00:27,  1.68it/s]

max() arg is an empty sequence


 44%|█████████████████████████████████████████████████████████████████████████▍                                                                                             | 22/50 [00:21<00:33,  1.21s/it]

max() arg is an empty sequence
max() arg is an empty sequence
max() arg is an empty sequence
max() arg is an empty sequence
max() arg is an empty sequence
max() arg is an empty sequence
max() arg is an empty sequence
max() arg is an empty sequence
max() arg is an empty sequence
max() arg is an empty sequence
max() arg is an empty sequence


 58%|████████████████████████████████████████████████████████████████████████████████████████████████▊                                                                      | 29/50 [00:31<00:28,  1.36s/it]

max() arg is an empty sequence
max() arg is an empty sequence
max() arg is an empty sequence
max() arg is an empty sequence
max() arg is an empty sequence
max() arg is an empty sequence
max() arg is an empty sequence
max() arg is an empty sequence
max() arg is an empty sequence


 82%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                              | 41/50 [00:41<00:05,  1.61it/s]

max() arg is an empty sequence
max() arg is an empty sequence
max() arg is an empty sequence


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:49<00:00,  1.00it/s]


# Сохраним результаты и измерим точность
Результаты - словарь с ключем DataItem.unique_key() и значением предсказанным quadrangle в относительных единицах.

In [242]:
dump_results_dict(results_dict, Path() / 'pred.json')

In [243]:
acc = measure_crop_accuracy(
    Path() / 'pred.json',
    Path() / 'gt.json'
)

In [244]:
print("Точность кропа: {:1.4f}".format(acc))

Точность кропа: 0.8365
