Skip to content

Commit

Permalink
code upload
Browse files Browse the repository at this point in the history
  • Loading branch information
quattrinifabio committed Feb 13, 2024
1 parent 157bc49 commit 227131c
Show file tree
Hide file tree
Showing 21 changed files with 1,830 additions and 0 deletions.
30 changes: 30 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,32 @@
# FourBi_7

## Setup
To run this project, we used `python 3.11.7` and `pytorch 2.2.0`
```bash
conda create -n fourbi python=3.11.7
conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
pip3 install opencv-python wandb pytorch-ignite
```

## Inference
To run the model on a folder of images, use the following command:
```
python binarize.py <path to checkpoint> --src <path to the test images folder>
--dst <path to the output folder>
```

## Training
The model is trained on patches, then evaluated and tested on complete documents. We provide the code to create the patches and train the model.
For example, to train on H-DIBCO12, first download the dataset from http://utopia.duth.gr/~ipratika/HDIBCO2012/benchmark/. Create a folder, then place the images in a sub-folder named "imgs" and the ground truth in a sub-folder named "gt_imgs". Then run the following command:
```
python create_patches.py --path_src <path to the dataset folder>
--path_dst <path to the folder where the patches will be saved>
--patch_size <size of the patches> --overlap_size <size of the overlap>
```
To launch the training, run the following command:
```
python train.py --datasets_paths <all datasets paths>
--eval_dataset_name <name of the validation dataset>
--test_dataset_name <name of the test dataset>
```

40 changes: 40 additions & 0 deletions binarize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import argparse
import torch
from pathlib import Path
from trainer.fourbi_trainer import FourbiTrainingModule
from data.test_dataset import FolderDataset
from torchvision import transforms

if __name__ == '__main__':
    # CLI entry point: binarize every image in --src with a trained FourBi
    # checkpoint and write the predictions as PNGs into --dst.
    parser = argparse.ArgumentParser(description='Binarize a folder of images')
    parser.add_argument('model', type=str, metavar='PATH', help='path to the model file')
    parser.add_argument('--src', type=str, required=True, help='path to the folder of input images')
    parser.add_argument('--dst', type=str, required=True, help='path to the folder of output images')
    parser.add_argument('--patch_size', type=int, default=512, help='patch size')
    parser.add_argument('--batch_size', type=int, default=8, help='batch size when processing patches')
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dst = Path(args.dst)
    dst.mkdir(parents=True, exist_ok=True)

    # make_loaders=False: only the checkpoint named in 'resume' is loaded;
    # the test loader is wired up manually below.
    fourbi = FourbiTrainingModule(config={'resume': args.model}, device=device, make_loaders=False)

    # Whole documents go through the loader one at a time (batch_size=1);
    # patch-level batching is controlled by 'eval_batch_size' below.
    dataset = FolderDataset(Path(args.src), patch_size=args.patch_size, overlap=True, transform=transforms.ToTensor())
    fourbi.test_data_loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)

    fourbi.config['test_patch_size'] = args.patch_size
    fourbi.config['test_stride'] = args.patch_size // 2  # 50% patch overlap at inference
    fourbi.config['eval_batch_size'] = args.batch_size

    for i, sample in enumerate(fourbi.folder_test()):
        # Each yielded sample maps one source-image path to an (img, pred, gt) triple.
        key = list(sample.keys())[0]
        img, pred, gt = sample[key]
        src_img_path = Path(key)

        # Predictions are always saved as PNG, regardless of the input format.
        dst_img_path = dst / (src_img_path.stem + '.png')
        pred.save(str(dst_img_path))
        # NOTE(review): presumably len(dataset) counts source images here; if
        # FolderDataset counts patches instead, this progress total is off — confirm.
        print(f'({i + 1}/{len(dataset)}) Saving {dst_img_path}')

    print('Done.')

18 changes: 18 additions & 0 deletions create_patches.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import argparse
from data.process_image import PatchImage


def main():
    """Parse CLI arguments and cut a dataset into training patches.

    Expects --path_src to contain 'imgs' and 'gt_imgs' sub-folders; patch
    pairs are written under --path_dst (see data.process_image.PatchImage).
    """
    parser = argparse.ArgumentParser(description='create patches')
    # Fixed: dropped a pointless f-string prefix and corrected "witch" -> "which".
    parser.add_argument('--path_dst', type=str, help="Destination folder path")
    parser.add_argument('--path_src', type=str, help="The path which contains the images")
    parser.add_argument('--patch_size', type=int, help="Patch size", default=384)
    parser.add_argument('--overlap_size', type=int, help='Overlap size', default=192)
    args = parser.parse_args()

    patcher = PatchImage(patch_size=args.patch_size, overlap_size=args.overlap_size, destination_root=args.path_dst)
    patcher.create_patches(root_original=args.path_src)


if __name__ == '__main__':
    main()
47 changes: 47 additions & 0 deletions data/custom_transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from torchvision import transforms
from torchvision.transforms import functional


class ToTensor(transforms.ToTensor):
    """Tensor conversion for paired samples: applies the standard torchvision
    ToTensor to both the input image and its ground truth."""

    def __call__(self, sample):
        to_tensor = super().__call__
        return {'image': to_tensor(sample['image']), 'gt': to_tensor(sample['gt'])}


class ColorJitter(transforms.ColorJitter):
    """Photometric jitter for paired samples: perturbs only the input image,
    leaving the ground truth untouched."""

    def __call__(self, sample):
        jittered = super().__call__(sample['image'])
        return {'image': jittered, 'gt': sample['gt']}


class RandomCrop(transforms.RandomCrop):
    """Spatially aligned random crop for paired samples: the same crop window
    is applied to both the image and its ground truth."""

    def __init__(self, size):
        super().__init__(size=size)
        # Keep the raw int: get_params below is fed an explicit (h, w) pair.
        self.size = size

    def __call__(self, sample):
        image, gt = sample['image'], sample['gt']
        top, left, height, width = self.get_params(image, output_size=(self.size, self.size))
        return {
            'image': functional.crop(image, top, left, height, width),
            'gt': functional.crop(gt, top, left, height, width),
        }


class RandomRotation(transforms.RandomRotation):
    """Random rotation applied with one shared angle to image and ground truth.

    The input image is rotated with a white fill. The ground truth is inverted
    before and after rotation so that the default black fill of `rotate` ends
    up white in the final ground truth as well.
    """

    def __call__(self, sample):
        angle = self.get_params(self.degrees)

        image = functional.rotate(sample['image'], angle, fill=[255, 255, 255])
        # Rotate the gt on its negative so the black fill becomes white.
        gt = functional.invert(functional.rotate(functional.invert(sample['gt']), angle))
        return {'image': image, 'gt': gt}
24 changes: 24 additions & 0 deletions data/dataloaders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import torch
from torch.utils.data import Dataset

from utils.htr_logging import get_logger

logger = get_logger(__file__)


def make_train_dataloader(train_dataset: Dataset, config: dict):
    """Wrap train_dataset in a DataLoader configured by config['train_kwargs']."""
    return torch.utils.data.DataLoader(train_dataset, **config['train_kwargs'])


def make_valid_dataloader(valid_dataset: Dataset, config: dict):
    """Wrap valid_dataset in a DataLoader configured by config['eval_kwargs']."""
    return torch.utils.data.DataLoader(valid_dataset, **config['eval_kwargs'])


def make_test_dataloader(test_dataset: Dataset, config: dict):
    """Wrap test_dataset in a DataLoader configured by config['test_kwargs']."""
    return torch.utils.data.DataLoader(test_dataset, **config['test_kwargs'])
100 changes: 100 additions & 0 deletions data/datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from torchvision.transforms import transforms
import time
from data.training_dataset import TrainingDataset
from data.test_dataset import TestDataset
from data.utils import get_transform
from utils.htr_logging import get_logger
from torch.utils.data import ConcatDataset
from pathlib import Path

logger = get_logger(__file__)


def make_train_dataset(config: dict):
    """Assemble the training set as a ConcatDataset of per-path TrainingDatasets.

    Reads 'train_data_path' (list of dataset roots), 'train_patch_size',
    'train_patch_size_raw' and 'load_data' from config. When a root contains
    a 'train' sub-folder, that sub-folder is used instead of the root.
    """
    train_data_path = config['train_data_path']
    patch_size = config['train_patch_size']
    load_data = config['load_data']

    logger.info(f"Train path: \"{train_data_path}\" with patch size {patch_size} and load_data={load_data}")

    transform = get_transform(output_size=patch_size)

    logger.info("Loading train datasets...")
    started = time.time()
    datasets = []
    for idx, raw_path in enumerate(train_data_path):
        logger.info(f"[{idx+1}/{len(train_data_path)}] Loading train dataset from \"{raw_path}\"")
        root = Path(raw_path)
        train_sub = root / 'train'
        if train_sub.exists():
            root = train_sub
        datasets.append(TrainingDataset(
            data_path=root,
            split_size=patch_size,
            patch_size=config['train_patch_size_raw'],
            transform=transform,
            load_data=load_data,
        ))

    logger.info(f"Loading train datasets took {time.time() - started:.2f} seconds")
    train_dataset = ConcatDataset(datasets)
    logger.info(f"Training set has {len(train_dataset)} instances")

    return train_dataset


def make_val_dataset(config: dict):
    """Assemble the validation set as a ConcatDataset of per-path TestDatasets.

    Reads 'eval_data_path' (list of dataset roots), 'eval_patch_size',
    'test_stride' and 'load_data' from config.
    """
    val_data_path = config['eval_data_path']
    stride = config['test_stride']
    patch_size = config['eval_patch_size']
    load_data = config['load_data']

    transform = transforms.Compose([transforms.ToTensor()])

    logger.info("Loading validation datasets...")
    time_start = time.time()
    datasets = []
    for i, path in enumerate(val_data_path):
        # i + 1: 1-based progress, consistent with make_train_dataset
        # (previously printed "[0/N]" for the first dataset).
        logger.info(f"[{i + 1}/{len(val_data_path)}] Loading validation dataset from \"{path}\"")
        datasets.append(
            TestDataset(
                data_path=Path(path),
                patch_size=patch_size,
                stride=stride,
                transform=transform,
                load_data=load_data
            )
        )

    logger.info(f"Loading validation datasets took {time.time() - time_start:.2f} seconds")
    validation_dataset = ConcatDataset(datasets)
    logger.info(f"Validation set has {len(validation_dataset)} instances")

    return validation_dataset


def make_test_dataset(config: dict):
    """Assemble the test set as a ConcatDataset of per-path TestDatasets.

    Reads 'test_data_path' (list of dataset roots), 'test_patch_size',
    'test_stride' and 'load_data' from config.
    """
    test_data_path = config['test_data_path']
    patch_size = config['test_patch_size']
    stride = config['test_stride']
    load_data = config['load_data']

    transform = transforms.Compose([transforms.ToTensor()])

    logger.info("Loading test datasets...")
    time_start = time.time()
    datasets = []

    for path in test_data_path:
        datasets.append(
            TestDataset(
                # Path(path): normalize to a pathlib path, consistent with
                # make_val_dataset (previously the raw string was passed).
                data_path=Path(path),
                patch_size=patch_size,
                stride=stride,
                transform=transform,
                load_data=load_data))
        logger.info(f'Loaded test dataset from {path} with {len(datasets[-1])} instances.')

    logger.info(f"Loading test datasets took {time.time() - time_start:.2f} seconds")

    test_dataset = ConcatDataset(datasets)

    logger.info(f"Test set has {len(test_dataset)} instances")
    return test_dataset
85 changes: 85 additions & 0 deletions data/process_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import logging
import cv2
import numpy as np
from pathlib import Path


class PatchImage:
    """Cuts aligned (image, ground-truth) pairs into fixed-size square patches.

    Patches are sampled on a regular grid with step `overlap_size`; patches
    extending past the image border are padded with white. Output pairs are
    written as sequentially numbered PNGs under `imgs_<patch_size>/` and
    `gt_imgs_<patch_size>/` inside `destination_root`.
    """

    def __init__(self, patch_size: int, overlap_size: int, destination_root: str):
        logging.basicConfig(format='%(levelname)s: %(message)s', level=logging.INFO)
        destination_root = Path(destination_root)
        self.train_folder = destination_root / f'imgs_{patch_size}/'
        self.train_gt_folder = destination_root / f'gt_imgs_{patch_size}/'
        self.train_folder.mkdir(parents=True, exist_ok=True)
        self.train_gt_folder.mkdir(parents=True, exist_ok=True)

        self.patch_size = patch_size
        self.overlap_size = overlap_size
        self.number_image = 1  # running index used to name output patch files
        self.image_name = ""   # only read by the legacy _create_name helper

        logging.info(f"Using Patch size: {self.patch_size} - Overlapping: {self.overlap_size}")

    def create_patches(self, root_original: str):
        """Patch every image in <root>/imgs against its ground truth in <root>/gt_imgs.

        The gt file is looked up first under the identical name, then as
        '<stem>.png'. Failures on individual images are reported and skipped.
        """
        logging.info("Start process ...")
        root_original = Path(root_original)
        gt_dir = root_original / 'gt_imgs'
        img_dir = root_original / 'imgs'

        extensions = {".png", ".jpg", ".bmp", ".tif"}
        img_paths = [p for p in img_dir.glob('*') if p.suffix in extensions]
        for img_path in img_paths:
            or_img = cv2.imread(str(img_path))
            gt_path = gt_dir / img_path.name
            if not gt_path.exists():
                gt_path = gt_dir / (img_path.stem + '.png')
            gt_img = cv2.imread(str(gt_path))
            try:
                self._split_train_images(or_img, gt_img)
            except Exception as e:
                # Best-effort batch processing: report and move on.
                print(f'Error: {e} - {img_path}')

    def _split_train_images(self, or_img: np.ndarray, gt_img: np.ndarray):
        """Slide a patch_size window over the pair with step overlap_size and
        write each (image, gt) patch pair to disk, white-padding at the borders.

        Fixed: the original duplicated the gt_patch border copy in the
        bottom-right case; the three border branches are now one padded path.
        """
        step = self.overlap_size
        patch_size = self.patch_size
        height, width = or_img.shape[:2]
        for i in range(0, height, step):
            for j in range(0, width, step):
                if i + patch_size <= height and j + patch_size <= width:
                    # Fully interior patch: plain slices, no padding needed.
                    dg_patch = or_img[i:i + patch_size, j:j + patch_size, :]
                    gt_patch = gt_img[i:i + patch_size, j:j + patch_size, :]
                else:
                    # Border patch: start from white canvases and copy the valid region.
                    dg_patch = np.full((patch_size, patch_size, 3), 255.0)
                    gt_patch = np.full((patch_size, patch_size, 3), 255.0)
                    rows = min(patch_size, height - i)
                    cols = min(patch_size, width - j)
                    dg_patch[:rows, :cols, :] = or_img[i:i + rows, j:j + cols, :]
                    gt_patch[:rows, :cols, :] = gt_img[i:i + rows, j:j + cols, :]

                cv2.imwrite(str(self.train_folder / (str(self.number_image) + '.png')), dg_patch)
                cv2.imwrite(str(self.train_gt_folder / (str(self.number_image) + '.png')), gt_patch)
                self.number_image += 1
                print(self.number_image, end='\r')

    def _create_name(self, folder: str, i: int, j: int):
        """Build '<folder><image stem>_<i>_<j>.png' (legacy naming helper;
        unused by the numbered-output path above)."""
        return folder + self.image_name.split('.')[0] + '_' + str(i) + '_' + str(j) + '.png'
Loading

0 comments on commit 227131c

Please sign in to comment.