In [1]:
import os
import cv2 
import json
import torch
import shutil
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

In [2]:
train_data_dir = '/mnt/researchteam/document_understanding/dataset/train'
train_csv_file = '/mnt/researchteam/document_understanding/hiertext/gt/train.jsonl'
val_data_dir = '/mnt/researchteam/document_understanding/dataset/validation'
val_csv_file = '/mnt/researchteam/document_understanding/hiertext/gt/validation.jsonl'

In [3]:
writer = SummaryWriter()

In [4]:
class HierText(Dataset):
    def __init__(self, csv_file, data_dir, transform=None):
        self.data = json.load(open(csv_file, 'r'))["annotations"][:1000]
        self.data_dir = data_dir
        self.transform = transform

    def __len__(self):
        return len(self.data)
    
    def draw_mask(self, vertices, w, h):
        mask = np.zeros((h, w, 3), dtype=np.float32)
        mask = cv2.fillPoly(mask, [vertices], [1.] * 3)[:, :, 0]
        return mask

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        data_annotations = self.data[idx]
        img_name = os.path.join(self.data_dir, data_annotations['image_id'])
        image = cv2.imread(f"{img_name}.jpg")
        w = data_annotations['image_width']
        h = data_annotations['image_height']

        gt_word_masks = []
        gt_word_weights = []

        for paragraph in data_annotations['paragraphs']:
            for line in paragraph['lines']:
                for word in line['words']:
                    gt_word_weights.append(1.0 if word['legible'] else 0.0)
                    vertices = np.array(word['vertices'])
                    gt_word_mask = self.draw_mask(vertices, w, h)
                    gt_word_masks.append(gt_word_mask)

        n_mask = len(gt_word_masks)

        gt_masks = (np.stack(gt_word_masks, -1) if n_mask else np.zeros(((h + 1) // 2, (w + 1) // 2, 0), np.float32))
        gt_weights = (np.array(gt_word_weights) if n_mask else np.zeros((0,), np.float32))
        
        palette = [[1]]*n_mask
        colored = np.reshape(np.matmul(np.reshape(gt_masks, (-1, n_mask)), palette), (h, w, 1))
        dont_care_mask = (np.reshape(np.matmul(np.reshape(gt_masks, (-1, n_mask)), np.reshape(1.- gt_weights, (-1, 1))), (h, w, 1)) > 0).astype(np.float32)

        binary_image = np.clip(dont_care_mask * 1. + (1. - dont_care_mask) * colored, 0., 1.)
        
        sample = {"image": image.astype(np.uint8), "binary_image": binary_image.astype(np.uint8)}

        if self.transform:
            sample["image"] = self.transform(sample["image"])
            sample["binary_image"] = self.transform(sample["binary_image"])

        return sample

In [5]:
transforms = transforms.Compose([transforms.ToPILImage(), transforms.Resize((400,400)), transforms.ToTensor()])

In [6]:
hiertext_train_dataset = HierText(csv_file=train_csv_file, data_dir=train_data_dir, transform=transforms)
hiertext_val_dataset = HierText(csv_file=val_csv_file, data_dir=val_data_dir, transform=transforms)

In [7]:
# fig = plt.figure()

# for i in range(len(hiertext_train_dataset)):
#     sample = hiertext_train_dataset[i]

#     print(i, sample['image'].shape, sample['binary_image'].shape)
    
#     fig, (ax1, ax2) = plt.subplots(1, 2)
#     ax1.imshow(sample['image'].permute(1,2,0))
#     ax2.imshow(sample['binary_image'].permute(1,2,0))
    
#     if i ==3:
#         break

In [8]:
bs = 1
train_dataloader = DataLoader(hiertext_train_dataset, batch_size=bs, shuffle=False)
val_dataloader = DataLoader(hiertext_val_dataset, batch_size=bs, shuffle=False)

In [9]:
len(val_dataloader.dataset)

1000

In [10]:
# for batch_idx, data in enumerate(val_dataloader):
#     print(batch_idx, data['image'].shape, data['binary_image'].shape)
    
#     fig, (ax1, ax2) = plt.subplots(1, 2)
#     ax1.imshow(data['image'][0].permute(1,2,0))
#     ax2.imshow(data['binary_image'][0].permute(1,2,0))
#     if batch_idx == 3:
#         break

### Model 

In [11]:
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU())
#         self.layer2 = nn.Sequential(
#             nn.Conv2d(64, 64, kernel_size=5, stride=1, padding=1),
#             nn.BatchNorm2d(64),
#             nn.ReLU()
#             )
        self.layer3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU())
#         self.layer4 = nn.Sequential(
#             nn.Conv2d(128, 128, kernel_size=5, stride=1, padding=1),
#             nn.BatchNorm2d(128),
#             nn.ReLU())

        self.layer5 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU())

        self.layer6 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU())
#         self.layer7 = nn.Sequential(
#             nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
#             nn.BatchNorm2d(512),
#             nn.ReLU())

        self.layer8 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU())
        
        self.maxpool = nn.MaxPool2d(kernel_size = 3, stride = 2, ceil_mode=True)
        
        self.fc = nn.Linear(in_features=256*25*25, out_features=512)
        
    def forward(self, x):
        out = self.layer1(x)
#         out = self.layer2(out)
        out = self.maxpool(out)
        out = self.layer3(out)
#         out = self.layer4(out)
        out = self.maxpool(out)
        out = self.layer5(out)
        out = self.maxpool(out)
        out = self.layer6(out)
#         out = self.layer7(out)
        out = self.maxpool(out)
        out = self.layer8(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

In [12]:
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(in_features=512, out_features=256*25*25)
        
        self.layer7 = nn.Sequential(
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
            nn.ReLU()
            )
        self.layer6 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU())
#         self.layer5 = nn.Sequential(
#             nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1),
#             nn.BatchNorm2d(128),
#             nn.ReLU())
        self.layer4 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU())
        self.layer3 = nn.Sequential(
            nn.ConvTranspose2d(256, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU())
#         self.layer2 = nn.Sequential(
#             nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=0),
#             nn.BatchNorm2d(512),
#             nn.ReLU())
#         self.layer1 = nn.Sequential(
#             nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=0),
#             nn.BatchNorm2d(512),
#             nn.ReLU())
        
    def forward(self, x):
        out = self.fc(x)
        out = out.view(out.size(0), 256, 25, 25)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer6(out)
        out = self.layer7(out)
        return out

In [13]:
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')

In [14]:
class EncoderDecoder(nn.Module):
    def __init__(self, encoder, decoder):
        super(EncoderDecoder, self).__init__() 
        self.encoder = encoder 
        self.decoder = decoder 
    
    def forward(self, x):
        out = self.encoder(x)
        out = self.decoder(out)
        return out 

In [15]:
encoder_model = Encoder().to(device)
decoder_model = Decoder().to(device)
ip = torch.rand(1, 3, 400, 400).to(device)

In [16]:
model = EncoderDecoder(encoder_model, decoder_model).to(device)

### Training Loop

In [23]:
1e-4

0.0001

In [17]:
learning_rate = 0.0001 
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
loss_fn = nn.L1Loss()

In [18]:
len(train_dataloader.dataset)

1000

In [17]:
model.load_state_dict(torch.load("saved_models/model4.pth", map_location=torch.device(device)))

<All keys matched successfully>

In [19]:
epoch = 100

In [None]:
model.train()
for e in range(epoch): 
    total_loss = 0 
    for batch_idx, data in enumerate(train_dataloader):
        optimizer.zero_grad()
        image, binary_image = data["image"].to(device), data["binary_image"].to(device)
        pred_binary_image = model(image)
        loss = loss_fn(binary_image, pred_binary_image)
        total_loss += loss 
        loss.backward()
        optimizer.step()
        if batch_idx % 50 == 0:
            print(f"Epoch: {e}, batch_idx: {batch_idx}, num_data: {len(train_dataloader.dataset)}, Loss: {loss}")
    epoch_loss = (total_loss.item()*bs)/len(train_dataloader.dataset)
    print(f"Epoch: {e}, Epoch Loss: {epoch_loss}")
    writer.add_scalar('Loss/train', epoch_loss, e)
    if e % 10 == 0:
        if os.path.exists('/mnt/researchteam/.local/share/Trash/'):
            shutil.rmtree('/mnt/researchteam/.local/share/Trash/')            
        if os.path.exists(f"saved_models/model{e-10}.pth"):
            os.remove(f"saved_models/model{e-10}.pth")
        torch.save(model.state_dict(), f"saved_models/model{e}.pth")

Epoch: 0, batch_idx: 0, num_data: 1000, Loss: 0.5748628377914429


In [None]:
param_size = 0
for param in model.parameters():
    param_size += param.nelement() * param.element_size()
buffer_size = 0
for buffer in model.buffers():
    buffer_size += buffer.nelement() * buffer.element_size()

size_all_mb = (param_size + buffer_size) / 1024**2
print('model size: {:.3f}MB'.format(size_all_mb))

In [None]:
1000000000

In [None]:
param_size

In [None]:
torch.cuda.empty_cache()