Commit 227131c (1 parent: 157bc49)
Showing 21 changed files with 1,830 additions and 0 deletions.
README.md
@@ -1,2 +1,32 @@
# FourBi_7

## Setup
This project was developed with `python 3.11.7` and `pytorch 2.2.0`:
```bash
conda create -n fourbi python=3.11.7
conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
pip3 install opencv-python wandb pytorch-ignite
```
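As an optional sanity check that the environment works (this one-liner is not part of the project):
```bash
python -c "import torch, cv2; print(torch.__version__, cv2.__version__, torch.cuda.is_available())"
```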

## Inference
To run the model on a folder of images, use the following command:
```
python binarize.py <path to checkpoint> --src <path to the test images folder>
--dst <path to the output folder>
```
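For example, assuming a checkpoint at `checkpoints/fourbi.pth` and input images in `datasets/hdibco12/imgs` (both paths are illustrative):
```
python binarize.py checkpoints/fourbi.pth --src datasets/hdibco12/imgs --dst results/hdibco12
```
The predictions are saved in the destination folder as PNG files, one per input image.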

## Training
The model is trained on patches, then evaluated and tested on complete documents. We provide the code to create the patches and train the model.
For example, to train on H-DIBCO12, first download the dataset from http://utopia.duth.gr/~ipratika/HDIBCO2012/benchmark/. Create a folder, then place the images in a sub-folder named "imgs" and the ground truth in a sub-folder named "gt_imgs".
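The expected layout is, for instance (the top-level folder name is arbitrary):
```
hdibco12/
├── imgs/       # original document images
└── gt_imgs/    # ground truth images with matching file names
```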
Then run the following command:
```
python create_patches.py --path_src <path to the dataset folder>
--path_dst <path to the folder where the patches will be saved>
--patch_size <size of the patches> --overlap_size <size of the overlap>
```
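For instance, with the default sizes (patches of 384 pixels extracted with a stride of 192, so adjacent patches overlap by half a patch; paths are illustrative):
```
python create_patches.py --path_src datasets/hdibco12 --path_dst patches/hdibco12 --patch_size 384 --overlap_size 192
```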
To launch the training, run the following command:
```
python train.py --datasets_paths <all datasets paths>
--eval_dataset_name <name of the validation dataset>
--test_dataset_name <name of the test dataset>
```
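For example (dataset names and paths below are purely illustrative, and `train.py` may accept further options not shown in this commit):
```
python train.py --datasets_paths patches/hdibco12 datasets/dibco13 --eval_dataset_name dibco13 --test_dataset_name hdibco12
```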
binarize.py
@@ -0,0 +1,40 @@
import argparse
import torch
from pathlib import Path
from trainer.fourbi_trainer import FourbiTrainingModule
from data.test_dataset import FolderDataset
from torchvision import transforms

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Binarize a folder of images')
    parser.add_argument('model', type=str, metavar='PATH', help='path to the model file')
    parser.add_argument('--src', type=str, required=True, help='path to the folder of input images')
    parser.add_argument('--dst', type=str, required=True, help='path to the folder of output images')
    parser.add_argument('--patch_size', type=int, default=512, help='patch size')
    parser.add_argument('--batch_size', type=int, default=8, help='batch size when processing patches')
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dst = Path(args.dst)
    dst.mkdir(parents=True, exist_ok=True)

    # Restore the model from the checkpoint; no train/eval loaders are needed for inference.
    fourbi = FourbiTrainingModule(config={'resume': args.model}, device=device, make_loaders=False)

    dataset = FolderDataset(Path(args.src), patch_size=args.patch_size, overlap=True, transform=transforms.ToTensor())
    fourbi.test_data_loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)

    # Patches are extracted with a stride of half the patch size (50% overlap).
    fourbi.config['test_patch_size'] = args.patch_size
    fourbi.config['test_stride'] = args.patch_size // 2
    fourbi.config['eval_batch_size'] = args.batch_size

    for i, sample in enumerate(fourbi.folder_test()):
        key = list(sample.keys())[0]
        img, pred, gt = sample[key]
        src_img_path = Path(key)

        dst_img_path = dst / (src_img_path.stem + '.png')
        pred.save(str(dst_img_path))
        print(f'({i + 1}/{len(dataset)}) Saving {dst_img_path}')

    print('Done.')
create_patches.py
@@ -0,0 +1,18 @@
import argparse
from data.process_image import PatchImage


def main():
    parser = argparse.ArgumentParser(description='create patches')
    parser.add_argument('--path_dst', type=str, help="Destination folder path")
    parser.add_argument('--path_src', type=str, help="The path which contains the images")
    parser.add_argument('--patch_size', type=int, help="Patch size", default=384)
    parser.add_argument('--overlap_size', type=int, help='Overlap size', default=192)
    args = parser.parse_args()

    patcher = PatchImage(patch_size=args.patch_size, overlap_size=args.overlap_size, destination_root=args.path_dst)
    patcher.create_patches(root_original=args.path_src)


if __name__ == '__main__':
    main()
@@ -0,0 +1,47 @@
from torchvision import transforms
from torchvision.transforms import functional


class ToTensor(transforms.ToTensor):
    # Converts both the image and its ground truth to tensors.

    def __call__(self, sample):
        image, gt = sample['image'], sample['gt']
        image = super().__call__(image)
        gt = super().__call__(gt)
        return {'image': image, 'gt': gt}


class ColorJitter(transforms.ColorJitter):
    # Jitter is applied to the input image only; the ground truth is left untouched.

    def __call__(self, sample):
        image, gt = sample['image'], sample['gt']
        image = super().__call__(image)
        return {'image': image, 'gt': gt}


class RandomCrop(transforms.RandomCrop):

    def __init__(self, size):
        super(RandomCrop, self).__init__(size=size)
        self.size = size

    def __call__(self, sample):
        # Sample one set of crop parameters and apply it to both image and ground truth.
        image, gt = sample['image'], sample['gt']
        i, j, h, w = self.get_params(image, output_size=(self.size, self.size))
        image = functional.crop(image, i, j, h, w)
        gt = functional.crop(gt, i, j, h, w)
        return {'image': image, 'gt': gt}


class RandomRotation(transforms.RandomRotation):

    def __call__(self, sample):
        image, gt = sample['image'], sample['gt']
        angle = self.get_params(self.degrees)

        image = functional.rotate(image, angle, fill=[255, 255, 255])

        # rotate() fills the exposed corners with black; inverting the ground truth
        # before and after the rotation makes the fill white (background) instead.
        gt = functional.invert(gt)
        gt = functional.rotate(gt, angle)
        gt = functional.invert(gt)
        return {'image': image, 'gt': gt}
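A minimal sketch of how these paired transforms compose; the import path is hypothetical (the project builds its actual pipeline in `data.utils.get_transform`):
```python
from PIL import Image
from torchvision import transforms

from data.custom_transforms import ColorJitter, RandomCrop, RandomRotation, ToTensor  # hypothetical module path

pair_transform = transforms.Compose([
    RandomCrop(256),                            # same crop window for image and gt
    ColorJitter(brightness=0.2, contrast=0.2),  # photometric noise on the image only
    RandomRotation(degrees=10),                 # same angle for image and gt
    ToTensor(),
])

sample = {'image': Image.open('page.png').convert('RGB'),
          'gt': Image.open('page_gt.png').convert('RGB')}
sample = pair_transform(sample)  # two aligned tensors
```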
@@ -0,0 +1,24 @@
import torch
from torch.utils.data import Dataset

from utils.htr_logging import get_logger

logger = get_logger(__file__)


def make_train_dataloader(train_dataset: Dataset, config: dict):
    train_dataloader_config = config['train_kwargs']
    train_data_loader = torch.utils.data.DataLoader(train_dataset, **train_dataloader_config)
    return train_data_loader


def make_valid_dataloader(valid_dataset: Dataset, config: dict):
    valid_dataloader_config = config['eval_kwargs']
    valid_data_loader = torch.utils.data.DataLoader(valid_dataset, **valid_dataloader_config)
    return valid_data_loader


def make_test_dataloader(test_dataset: Dataset, config: dict):
    test_dataloader_config = config['test_kwargs']
    test_data_loader = torch.utils.data.DataLoader(test_dataset, **test_dataloader_config)
    return test_data_loader
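These factories simply forward keyword arguments from the config to `torch.utils.data.DataLoader`. A minimal sketch of the expected config shape (the key names come from the code above; the values are assumptions):
```python
config = {
    'train_kwargs': {'batch_size': 8, 'shuffle': True, 'num_workers': 4, 'pin_memory': True},
    'eval_kwargs': {'batch_size': 1, 'shuffle': False, 'num_workers': 2},
    'test_kwargs': {'batch_size': 1, 'shuffle': False, 'num_workers': 2},
}

train_loader = make_train_dataloader(train_dataset, config)
valid_loader = make_valid_dataloader(valid_dataset, config)
test_loader = make_test_dataloader(test_dataset, config)
```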
@@ -0,0 +1,100 @@
from torchvision.transforms import transforms
import time
from data.training_dataset import TrainingDataset
from data.test_dataset import TestDataset
from data.utils import get_transform
from utils.htr_logging import get_logger
from torch.utils.data import ConcatDataset
from pathlib import Path

logger = get_logger(__file__)


def make_train_dataset(config: dict):
    train_data_path = config['train_data_path']
    patch_size = config['train_patch_size']
    load_data = config['load_data']

    logger.info(f"Train path: \"{train_data_path}\" with patch size {patch_size} and load_data={load_data}")

    transform = get_transform(output_size=patch_size)

    logger.info("Loading train datasets...")
    time_start = time.time()
    datasets = []
    for i, path in enumerate(train_data_path):
        logger.info(f"[{i + 1}/{len(train_data_path)}] Loading train dataset from \"{path}\"")
        # Use the 'train' sub-folder when the dataset provides one.
        data_path = Path(path) / 'train' if (Path(path) / 'train').exists() else Path(path)
        datasets.append(
            TrainingDataset(
                data_path=data_path,
                split_size=patch_size,
                patch_size=config['train_patch_size_raw'],
                transform=transform,
                load_data=load_data))

    logger.info(f"Loading train datasets took {time.time() - time_start:.2f} seconds")
    train_dataset = ConcatDataset(datasets)
    logger.info(f"Training set has {len(train_dataset)} instances")

    return train_dataset


def make_val_dataset(config: dict):
    val_data_path = config['eval_data_path']
    stride = config['test_stride']
    patch_size = config['eval_patch_size']
    load_data = config['load_data']

    transform = transforms.Compose([transforms.ToTensor()])

    logger.info("Loading validation datasets...")
    time_start = time.time()
    datasets = []
    for i, path in enumerate(val_data_path):
        logger.info(f"[{i + 1}/{len(val_data_path)}] Loading validation dataset from \"{path}\"")
        datasets.append(
            TestDataset(
                data_path=Path(path),
                patch_size=patch_size,
                stride=stride,
                transform=transform,
                load_data=load_data
            )
        )

    logger.info(f"Loading validation datasets took {time.time() - time_start:.2f} seconds")
    validation_dataset = ConcatDataset(datasets)
    logger.info(f"Validation set has {len(validation_dataset)} instances")

    return validation_dataset


def make_test_dataset(config: dict):
    test_data_path = config['test_data_path']
    patch_size = config['test_patch_size']
    stride = config['test_stride']
    load_data = config['load_data']

    transform = transforms.Compose([transforms.ToTensor()])

    logger.info("Loading test datasets...")
    time_start = time.time()
    datasets = []

    for path in test_data_path:
        datasets.append(
            TestDataset(
                data_path=Path(path),
                patch_size=patch_size,
                stride=stride,
                transform=transform,
                load_data=load_data))
        logger.info(f'Loaded test dataset from {path} with {len(datasets[-1])} instances.')

    logger.info(f"Loading test datasets took {time.time() - time_start:.2f} seconds")

    test_dataset = ConcatDataset(datasets)

    logger.info(f"Test set has {len(test_dataset)} instances")
    return test_dataset
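Taken together, the three factories above read the following configuration keys; a minimal sketch (only the key names come from the code, all values are illustrative):
```python
config = {
    'train_data_path': ['patches/hdibco12'],  # list of training patch folders
    'train_patch_size': 256,                  # split_size passed to TrainingDataset
    'train_patch_size_raw': 384,              # patch_size passed to TrainingDataset
    'eval_data_path': ['datasets/dibco13'],
    'eval_patch_size': 512,
    'test_data_path': ['datasets/hdibco12'],
    'test_patch_size': 512,
    'test_stride': 256,                       # stride shared by validation and test
    'load_data': True,                        # forwarded to the dataset classes
}
```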
data/process_image.py
@@ -0,0 +1,85 @@
import logging
import cv2
import numpy as np
from pathlib import Path


class PatchImage:

    def __init__(self, patch_size: int, overlap_size: int, destination_root: str):
        logging.basicConfig(format='%(levelname)s: %(message)s', level=logging.INFO)
        destination_root = Path(destination_root)
        self.train_folder = destination_root / f'imgs_{patch_size}/'
        self.train_gt_folder = destination_root / f'gt_imgs_{patch_size}/'
        self.train_folder.mkdir(parents=True, exist_ok=True)
        self.train_gt_folder.mkdir(parents=True, exist_ok=True)

        self.patch_size = patch_size
        self.overlap_size = overlap_size
        self.number_image = 1
        self.image_name = ""

        logging.info(f"Using Patch size: {self.patch_size} - Overlapping: {self.overlap_size}")

    def create_patches(self, root_original: str):
        logging.info("Start process ...")
        root_original = Path(root_original)
        gt = root_original / 'gt_imgs'
        imgs = root_original / 'imgs'

        path_imgs = [path_img for path_img in imgs.glob('*') if path_img.suffix in {".png", ".jpg", ".bmp", ".tif"}]
        for i, img in enumerate(path_imgs):
            or_img = cv2.imread(str(img))
            # The ground truth is looked up by the same file name, falling back to a .png extension.
            gt_img = gt / img.name
            gt_img = gt_img if gt_img.exists() else gt / (img.stem + '.png')
            gt_img = cv2.imread(str(gt_img))
            try:
                self._split_train_images(or_img, gt_img)
            except Exception as e:
                print(f'Error: {e} - {img}')

    def _split_train_images(self, or_img: np.ndarray, gt_img: np.ndarray):
        # The stride between patches equals the overlap size; patches that run past
        # the image border are padded with white (255).
        runtime_size = self.overlap_size
        patch_size = self.patch_size
        for i in range(0, or_img.shape[0], runtime_size):
            for j in range(0, or_img.shape[1], runtime_size):

                if i + patch_size <= or_img.shape[0] and j + patch_size <= or_img.shape[1]:
                    # Patch fully inside the image.
                    dg_patch = or_img[i:i + patch_size, j:j + patch_size, :]
                    gt_patch = gt_img[i:i + patch_size, j:j + patch_size, :]

                elif i + patch_size > or_img.shape[0] and j + patch_size <= or_img.shape[1]:
                    # Patch overruns the bottom edge.
                    dg_patch = np.ones((patch_size, patch_size, 3)) * 255
                    gt_patch = np.ones((patch_size, patch_size, 3)) * 255

                    dg_patch[0:or_img.shape[0] - i, :, :] = or_img[i:or_img.shape[0], j:j + patch_size, :]
                    gt_patch[0:or_img.shape[0] - i, :, :] = gt_img[i:or_img.shape[0], j:j + patch_size, :]

                elif i + patch_size <= or_img.shape[0] and j + patch_size > or_img.shape[1]:
                    # Patch overruns the right edge.
                    dg_patch = np.ones((patch_size, patch_size, 3)) * 255
                    gt_patch = np.ones((patch_size, patch_size, 3)) * 255

                    dg_patch[:, 0:or_img.shape[1] - j, :] = or_img[i:i + patch_size, j:or_img.shape[1], :]
                    gt_patch[:, 0:or_img.shape[1] - j, :] = gt_img[i:i + patch_size, j:or_img.shape[1], :]

                else:
                    # Patch overruns both the bottom and the right edge.
                    dg_patch = np.ones((patch_size, patch_size, 3)) * 255
                    gt_patch = np.ones((patch_size, patch_size, 3)) * 255

                    dg_patch[0:or_img.shape[0] - i, 0:or_img.shape[1] - j, :] = \
                        or_img[i:or_img.shape[0], j:or_img.shape[1], :]
                    gt_patch[0:or_img.shape[0] - i, 0:or_img.shape[1] - j, :] = \
                        gt_img[i:or_img.shape[0], j:or_img.shape[1], :]

                cv2.imwrite(str(self.train_folder / (str(self.number_image) + '.png')), dg_patch)
                cv2.imwrite(str(self.train_gt_folder / (str(self.number_image) + '.png')), gt_patch)
                self.number_image += 1
                print(self.number_image, end='\r')

    def _create_name(self, folder: str, i: int, j: int):
        return folder + self.image_name.split('.')[0] + '_' + str(i) + '_' + str(j) + '.png'
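A minimal usage sketch, mirroring what `create_patches.py` does (paths are illustrative):
```python
patcher = PatchImage(patch_size=384, overlap_size=192, destination_root='patches/hdibco12')
patcher.create_patches(root_original='datasets/hdibco12')
# Numbered patch PNGs end up in patches/hdibco12/imgs_384/ and patches/hdibco12/gt_imgs_384/.
```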