In [13]:

from config.args import get_arguments
from config.config import features,labels,duplex_labels,train_test_ratio,label_extension,threshold, train_val_ratio
from dataloader.load_data import SimplexDataset, DuplexDataset, InputSimplexDataset, InputDuplexDataset, load_dataloader
from utils.load_json import load_results
from utils.save_load_model import load
from utils.convert import convertBinary
from model.UNet import UNet
from model.ENet import ENet
from model.DeepLabV3 import CustomDeepLabV3
from train.train import Train
from train.test import Test

import logging
import os
import torch
import json
from torchvision.transforms.functional import to_pil_image
from torch.utils.data import random_split, DataLoader

In [14]:
def train(model,dataset):
    """Train ``model`` on part of ``dataset`` and evaluate on the held-out rest.

    ``dataset`` is split with ``random_split`` using ``train_test_ratio`` and a
    freshly created ``torch.Generator()``. The ``Train`` harness runs on the
    first subset. The model is then passed through ``load(model, args)``
    (presumably restoring the saved checkpoint — confirm in utils.save_load_model)
    and evaluated with ``Test`` on the second subset. Finally the training plot
    is saved.

    Relies on the notebook globals ``device`` and ``args``.
    """
    split_generator = torch.Generator()
    train_subset, test_subset = random_split(dataset, train_test_ratio, split_generator)

    trainer = Train(model, device, train_subset, args)
    trainer.run()

    restored_model = load(model, args)
    evaluator = Test(restored_model, device, test_subset, int(args.batch))
    evaluator.run()

    trainer.save_plot()

def test(model,dataset):
    """Evaluate ``model`` on ``dataset`` with the ``Test`` harness.

    The model is first passed through ``load(model, args)`` (presumably
    restoring checkpointed weights — confirm in utils.save_load_model).
    Relies on the notebook globals ``device`` and ``args``.
    """
    restored_model = load(model, args)
    batch_size = int(args.batch)
    Test(restored_model, device, dataset, batch_size).run()

def _page_dir_name(page_num):
    """Zero-pad a page number to four characters, e.g. ``7 -> "0007"``."""
    return str(page_num).zfill(4)


def _new_page_record(name, page_num):
    """Build the ``results.json`` entry for page ``page_num`` of document ``name``."""
    return {
        "pdf_filename": name + ".pdf",
        "page_num": page_num,
        "intermediate_dir": "intermediate_results/" + _page_dir_name(page_num),
    }


def _save_label_maps(maps, start, label_names, out_dir, record):
    """Save ``maps[start + i]`` as ``<label_names[i]><label_extension>`` in ``out_dir``.

    The path of every written file, relative to the document folder, is stored
    in ``record`` under its label name. Returns the index of the first map not
    consumed, so consecutive calls can walk through the output channels.
    """
    os.makedirs(out_dir, exist_ok=True)
    for offset, label in enumerate(label_names):
        filename = f"{label}{label_extension}"
        to_pil_image(maps[start + offset]).save(os.path.join(out_dir, filename))
        record[label] = record["intermediate_dir"] + "/" + filename
    return start + len(label_names)


def inference(model):
    """Run ``model`` over the unlabeled input pages and save per-label maps.

    Pages are read with ``InputSimplexDataset`` or ``InputDuplexDataset``
    (chosen by ``args.dataset``). The model output is binarized with
    ``convertBinary``, and each channel is written as an image to::

        <args.output_folder>/<name>/intermediate_results/<page>/<label><label_extension>

    where ``<page>`` is the page number zero-padded to four digits.

    Channel layout per sample:
      * simplex: channel ``i`` -> ``labels[i]`` of page ``pn``.
      * duplex: ``labels`` of page ``pn``, then ``duplex_labels`` (also stored
        with page ``pn``), then ``labels`` of page ``pn + 1``.

    For every document a ``results.json`` listing all written files is stored
    in ``<args.output_folder>/<name>/``.

    Relies on the notebook globals ``args`` and ``device``.

    Raises:
        ValueError: if ``args.dataset`` is neither ``"simplex"`` nor ``"duplex"``.
    """
    # Previously an unknown value left `data` unbound and crashed with NameError.
    if args.dataset not in ("simplex", "duplex"):
        raise ValueError(f"Unknown dataset type: {args.dataset!r}")

    os.makedirs(args.output_folder, exist_ok=True)
    model.eval()
    model.to(device)
    if args.dataset == "simplex":
        data = InputSimplexDataset(args)
    else:
        data = InputDuplexDataset(args)
    loader = DataLoader(data, batch_size=int(args.batch))

    # document name -> results.json entries, in the order pages were produced.
    # Entries are collected for the whole run, so every document gets its file.
    # The old streaming version never wrote the last document and overwrote
    # documents whose pages were not contiguous in the loader.
    results = {}
    with torch.no_grad():  # inference only: no autograd graph needed
        for batch in loader:
            # assumes batch = (document names, page-number tensor, model inputs), as indexed here
            names, page_nums, inputs = batch[0], batch[1], batch[2]
            outputs = convertBinary(model(inputs.to(device).float()))
            for name, pn, output in zip(names, page_nums, outputs):
                pn = pn.item()
                records = results.setdefault(name, [])
                pages_root = os.path.join(args.output_folder, name, "intermediate_results")

                record = _new_page_record(name, pn)
                page_dir = os.path.join(pages_root, _page_dir_name(pn))
                cur = _save_label_maps(output, 0, labels, page_dir, record)

                if args.dataset == "duplex":
                    # Duplex-only maps are stored with the first page; the
                    # remaining channels belong to the following page pn + 1.
                    cur = _save_label_maps(output, cur, duplex_labels, page_dir, record)
                    next_record = _new_page_record(name, pn + 1)
                    # Pad pn + 1 itself (the old code padded by len(str(pn)), giving "00010" for 10).
                    next_dir = os.path.join(pages_root, _page_dir_name(pn + 1))
                    _save_label_maps(output, cur, labels, next_dir, next_record)
                    records.extend([record, next_record])
                else:
                    records.append(record)

    for name, records in results.items():
        with open(os.path.join(args.output_folder, name, "results.json"), "w") as fp:
            json.dump({"intermediate_results": records}, fp)

In [15]:
class default:
    """Hard-coded run configuration used in place of command-line arguments.

    Presumably mirrors the options returned by ``config.args.get_arguments``
    (imported but not called in this notebook) — confirm against that parser.
    Any option can be overridden by keyword, e.g.
    ``default(mode="inference", batch=4)``. Unknown option names raise
    ``TypeError`` so typos are caught immediately.
    """

    def __init__(self, **overrides):
        # Data locations
        self.input_folder = "./data/cache/DP_a2200_xml_ff2c81d8ad6655f915cbaa558ee7bf9e878730a8"
        self.output_folder = "./output"
        self.label_folder = "./data/output/DP_a2200_xml_ff2c81d8ad6655f915cbaa558ee7bf9e878730a8"
        self.save_folder = "./checkpoints"
        # Model and run mode
        self.model = "unet"      # the notebook handles "unet", "enet" and "deeplabv3"
        self.mode = "train"      # anything but "inference" loads the labeled dataset
        # Optimisation
        self.epoch = 5
        self.lr = 1e-5
        self.batch = 10
        self.dataset = "simplex"  # "simplex" or "duplex"

        for key, value in overrides.items():
            if not hasattr(self, key):
                raise TypeError(f"default() got an unexpected option {key!r}")
            setattr(self, key, value)


args = default()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [16]:
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
logging.info(f'''Using device: {device}''')

# Fail fast on configuration typos. Otherwise an unknown dataset silently takes
# the duplex path, and an unknown model leaves `model` undefined for later cells.
if args.dataset not in ('simplex', 'duplex'):
    raise ValueError(f"Unknown dataset type: {args.dataset!r}")
if args.model not in ('unet', 'enet', 'deeplabv3'):
    raise ValueError(f"Unknown model type: {args.model!r}")

# Duplex samples stack two pages: twice the input features, and per-page labels
# for both pages plus the duplex-only labels.
n_input = len(features) if args.dataset == 'simplex' else 2*len(features)
n_output = len(labels) if args.dataset == 'simplex' else 2*len(labels)+len(duplex_labels)

# A labeled dataset is only needed for training/testing, not for inference.
if args.mode != 'inference':
    pdf,algorithm,intermediate = load_results(args.label_folder)
    if args.dataset == 'simplex':
        dataset = SimplexDataset(args.input_folder,args.label_folder,intermediate)
    else:
        dataset = DuplexDataset(args.input_folder,args.label_folder,intermediate)

if args.model == 'unet':
    model = UNet(n_input,n_output)
elif args.model == 'enet':
    model = ENet(n_input,n_output)
else:
    # NOTE(review): n_input is not passed here — confirm CustomDeepLabV3 handles
    # the 2x-feature duplex input.
    model = CustomDeepLabV3(n_output)
    


INFO: Using device: cuda
INFO: Preparing SimplexDataset...
INFO: Finished preparing SimplexDataset
INFO: Total 53 samples
INFO: Initializing UNet Model...
INFO: Done initialize UNet Model


In [17]:
# Smoke test: build the training harness and push one batch through the model
# without running the training loop (call `trainer.run()` to actually train).
# The names avoid rebinding the global `train()` helper and the config `labels`
# list, which `inference()` iterates over to name the output maps.
trainer = Train(model,device,dataset,args)
batch = next(iter(trainer.train_dataloader))
inputs, batch_labels = batch[0].to(trainer.device), batch[1].to(trainer.device)
preds = trainer.model(inputs.float())

INFO: Initializing training script...
INFO: Preparing Dataloader...
INFO: Done preparing Dataloader
INFO: Done initialize training script


In [12]:
# Inspect the prediction for the first sample of the smoke-test batch.
# NOTE(review): this cell's execution count (In [12]) is lower than that of the
# cell defining `preds` (In [17]), so the output below came from an earlier
# kernel state and may be stale — re-run after the smoke-test cell.
preds[0]

tensor([0.5536, 0.4413, 0.4860, 0.4658, 0.5008, 0.4639, 0.5059, 0.4479, 0.5090,
        0.4538, 0.4963, 0.4350, 0.4957, 0.4545, 0.4925, 0.4307, 0.4943, 0.4519,
        0.4955, 0.4308, 0.4859, 0.4369, 0.4914, 0.4493, 0.4809, 0.4434, 0.4941,
        0.4451, 0.4810, 0.4446, 0.4932, 0.4433, 0.4907, 0.4472, 0.4998, 0.4454,
        0.4878, 0.4502, 0.4927, 0.4495, 0.4906, 0.4456, 0.4943, 0.4381, 0.4842,
        0.4482, 0.4854, 0.4421, 0.4881, 0.4480, 0.5026, 0.4435, 0.4866, 0.4451,
        0.4894, 0.4495, 0.4909, 0.4505, 0.4938, 0.4440, 0.4813, 0.4459, 0.4888,
        0.4439, 0.4884, 0.4490, 0.5009, 0.4425, 0.4844, 0.4474, 0.4918, 0.4496,
        0.4909, 0.4482, 0.4931, 0.4438, 0.4847, 0.4483, 0.4882, 0.4436, 0.4900,
        0.4486, 0.5000, 0.4422, 0.4843, 0.4486, 0.4915, 0.4489, 0.4901, 0.4498,
        0.4929, 0.4429, 0.4835, 0.4499, 0.4883, 0.4447, 0.4886, 0.4472, 0.4967,
        0.4421, 0.4792, 0.4512, 0.4942, 0.4529, 0.4823, 0.4563, 0.4841, 0.4442,
        0.4816, 0.4401, 0.4839, 0.4474, 