In [1]:
from pathlib import Path
import json

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

import torch 
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms

# Device configuration: use CUDA when PyTorch reports it as available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
# Dataset root directory.
# NOTE(review): this is an absolute path on one specific machine; anyone else running
# the notebook has to edit data_dir. The commented-out line is an alternative location.
# data_dir = Path(r"C:\datasets\pubtabnet")
data_dir = Path(r"C:\Users\stans\Documents\Projects\Datasets\pubtabnet.tar\pubtabnet\ptnLite")
# Sub-directory that is searched for the *.png training images.
train_dir = data_dir / "Images"

In [4]:
# Collect the image paths in a stable, sorted order. Path.glob yields entries in an
# arbitrary, filesystem-dependent order, and a later cell keeps only a leading slice
# of this list, so sorting makes that subset identical from run to run.
publaynet_images = sorted(train_dir.glob("*.png"))

# Targets are looked up by image file name (see get_data). The JSON file is read as
# UTF-8 explicitly instead of relying on the platform's default text encoding.
with open(data_dir / 'targets.json', 'r', encoding='utf-8') as f:
    publaynet_targets = json.load(f)

In [5]:
# Lookup table from the integer keys 1 and 2 to a float value each.
# NOTE(review): nothing else visible in this notebook reads COLORS; presumably these
# are display intensities per label -- confirm before relying on it.
COLORS = {1: 0.5, 2: 0.7}

def pad_input(img, shape):
    """Place an image in the top-left corner of a zero-filled array of ``shape``.

    Pixel values are divided by 255. An image smaller than ``shape`` is
    zero-padded on the right and bottom. An image larger than ``shape`` is
    cropped to fit (it used to raise a broadcasting error), which matches how
    ``boxes_2_mask`` silently clips boxes that extend past the mask.

    Args:
        img: Image-like object that ``np.asarray`` converts to an array whose
            first two axes are (rows, cols), e.g. a single-channel PIL image.
        shape: Shape of the returned array; the first two entries are
            (rows, cols).

    Returns:
        numpy.ndarray of ``shape`` (NumPy's default float dtype) holding the
        scaled pixels, with zeros everywhere the image does not reach.
    """
    canvas = np.zeros(shape)
    pixels = np.asarray(img) / 255
    # Copy only the region that exists in both the image and the canvas.
    rows = min(pixels.shape[0], shape[0])
    cols = min(pixels.shape[1], shape[1])
    canvas[:rows, :cols] = pixels[:rows, :cols]
    return canvas  # torch.as_tensor(canvas).type(torch.cuda.FloatTensor)

def boxes_2_mask(labels, boxes, shape):
    """Rasterize labelled boxes into an array of ``shape``.

    Each box ``(x1, y1, x2, y2)`` is shrunk by one pixel on every side and the
    remaining interior is filled with its label; later boxes overwrite earlier
    ones where they overlap. Slice bounds are clamped to the mask so that
    degenerate or out-of-range boxes paint nothing outside their interior.
    Without clamping, a bound such as ``y2 - 1 == -1`` is read by NumPy as
    "one before the end" and fills almost the whole mask.

    Args:
        labels: Iterable of label values, one per box.
        boxes: Iterable of ``(x1, y1, x2, y2)`` pixel coordinates.
        shape: Shape of the returned array; the first two entries are
            (rows, cols).

    Returns:
        numpy.ndarray of ``shape``, zero where no box interior was drawn.
    """
    mask = np.zeros(shape)
    height, width = shape[0], shape[1]
    for label, (x1, y1, x2, y2) in zip(labels, boxes):
        # Shrink by one pixel, then clamp into [0, dim] to avoid negative-index wraparound.
        top = min(max(int(y1) + 1, 0), height)
        bottom = min(max(int(y2) - 1, 0), height)
        left = min(max(int(x1) + 1, 0), width)
        right = min(max(int(x2) - 1, 0), width)
        mask[top:bottom, left:right] = label
    return mask
    
def get_data(img_path, shape=(256, 256)):
    """Load one image and build its padded input array and label mask.

    Args:
        img_path: ``pathlib.Path`` of the image file. Its ``name`` is the key
            used to look up the ``(labels, boxes)`` pair in
            ``publaynet_targets``.
        shape: (rows, cols) of both returned arrays. Defaults to the
            ``(256, 256)`` that was previously hard-coded.

    Returns:
        Tuple ``(image_array, mask_array)`` as produced by ``pad_input`` and
        ``boxes_2_mask``.

    Raises:
        KeyError: If ``publaynet_targets`` has no entry for ``img_path.name``.
    """
    labels, boxes = publaynet_targets[img_path.name]
    # The context manager releases the file handle as soon as the pixels are copied.
    with Image.open(img_path) as image:
        padded = pad_input(image, shape)
    return padded, boxes_2_mask(labels, boxes, shape)
    

def test_resizer(i):
    """Visual check: show the padded image, then the mask, for ``publaynet_images[i]``."""
    sample_path = publaynet_images[i]
    padded_image, label_mask = get_data(sample_path)

    # First figure: the padded input image.
    plt.imshow(padded_image)
    plt.show()

    # Second image: the label mask (left for the notebook to render).
    plt.imshow(label_mask)

In [6]:
# test_resizer(12)

In [67]:
# Hyper-parameters
sequence_length = 256  # time steps per sample; used when reshaping each batch before the forward pass
input_size = 256  # features per time step; also the LSTM input size
hidden_size = 128  # LSTM hidden state size
num_layers = 2  # number of stacked LSTM layers
num_classes = 256  # size of the final linear layer's output (one value per entry of the 256-long row target)
batch_size = 32  # samples per mini-batch in the DataLoader
num_epochs = 2  # full passes over the training loader
learning_rate = 0.01  # learning rate passed to the Adam optimizer

In [68]:
# MNIST dataset
# train_dataset = torchvision.datasets.MNIST(root='../../data/',
#                                            train=True, 
#                                            transform=transforms.ToTensor(),
#                                            download=True)

# test_dataset = torchvision.datasets.MNIST(root='../../data/',
#                                           train=False, 
#                                           transform=transforms.ToTensor())

In [74]:
class PubTabNetDataset(Dataset):
    """Map-style dataset yielding ``(image, row_target)`` pairs for image paths.

    Both tensors are returned as CPU ``float32`` tensors. Moving them to the
    training device is left to the training loop (which already calls
    ``.to(device)``), so the dataset also works on machines without CUDA and
    with DataLoader worker processes.
    """

    def __init__(self, files):
        """Store the sequence of image paths to serve.

        Args:
            files: Indexable sequence of image paths accepted by ``get_data``.
        """
        super().__init__()
        self.files = files

    def __len__(self):
        """Return the number of image paths in the dataset."""
        return len(self.files)

    def __getitem__(self, index):
        """Load sample ``index``.

        Returns:
            Tuple ``(image, row_target)``: ``image`` is the array returned by
            ``get_data`` as a float32 tensor; ``row_target`` has one entry per
            mask row, 1.0 where that row contains any value >= 1, else 0.0.
        """
        image, mask = get_data(self.files[index])
        # Collapse the 2-D mask to a per-row flag: does this row hold any labelled pixel?
        row_has_label = mask.max(axis=1) >= 1
        return (
            torch.as_tensor(image, dtype=torch.float32),
            torch.as_tensor(row_has_label, dtype=torch.float32),
        )

In [75]:
# a = np.arange(100).reshape((10,10))
# a.max(axis=1)

In [76]:
# Train on at most the first 300,000 image paths (the slice simply returns fewer if the list is shorter).
train_dataset = PubTabNetDataset(publaynet_images[:300000])

In [77]:
# Data loader: shuffled mini-batches of `batch_size` samples drawn from train_dataset.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
#                                           batch_size=batch_size, 
#                                           shuffle=False)

In [78]:
# Recurrent neural network (many-to-one)
class RNN(nn.Module):
    """Stacked LSTM followed by a linear layer applied to the last time step.

    Args:
        input_size: Number of features per time step.
        hidden_size: Size of the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
        num_classes: Size of the output vector produced per sample.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # batch_first=True: inputs and outputs are (batch, seq, feature).
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Run the LSTM over ``x`` and decode its final time step.

        Args:
            x: Tensor of shape (batch_size, seq_length, input_size).

        Returns:
            Tensor of shape (batch_size, num_classes).
        """
        # Zero initial hidden and cell states, created directly on the input's
        # device and dtype so the module does not depend on a global `device`.
        state_shape = (self.num_layers, x.size(0), self.hidden_size)
        h0 = torch.zeros(state_shape, device=x.device, dtype=x.dtype)
        c0 = torch.zeros(state_shape, device=x.device, dtype=x.dtype)

        # out: tensor of shape (batch_size, seq_length, hidden_size)
        out, _ = self.lstm(x, (h0, c0))

        # Decode the hidden state of the last time step only.
        return self.fc(out[:, -1, :])

# Build the network from the hyper-parameters and move its parameters to the selected device.
model = RNN(input_size, hidden_size, num_layers, num_classes).to(device)


# Loss and optimizer
# reduction='none' keeps the element-wise squared errors; they are averaged explicitly below.
criterion = nn.MSELoss(reduction='none')  # nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Train the model
total_step = len(train_loader)  # number of mini-batches per epoch
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Present each sample as sequence_length time steps of input_size features.
        images = images.reshape(-1, sequence_length, input_size).to(device)
        labels = labels.to(device)
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)  # element-wise loss, same shape as outputs
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.mean().backward()  # reduce the element-wise loss to a scalar before backpropagating
        optimizer.step()
        
        # Report the mean loss of the current batch every 100 steps.
        if (i+1) % 100 == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
                   .format(epoch+1, num_epochs, i+1, total_step, loss.mean()))

# Test the model
# model.eval()
# with torch.no_grad():
#     correct = 0
#     total = 0
#     for images, labels in test_loader:
#         images = images.reshape(-1, sequence_length, input_size).to(device)
#         labels = labels.to(device)
#         outputs = model(images)
#         _, predicted = torch.max(outputs.data, 1)
#         total += labels.size(0)
#         correct += (predicted == labels).sum().item()

#     print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total)) 

# Save the model checkpoint
# torch.save(model.state_dict(), 'model.ckpt')

Epoch [1/2], Step [100/9375], Loss: 0.1282
Epoch [1/2], Step [200/9375], Loss: 0.1238
Epoch [1/2], Step [300/9375], Loss: 0.1098
Epoch [1/2], Step [400/9375], Loss: 0.1468
Epoch [1/2], Step [500/9375], Loss: 0.1013
Epoch [1/2], Step [600/9375], Loss: 0.1208
Epoch [1/2], Step [700/9375], Loss: 0.1196
Epoch [1/2], Step [800/9375], Loss: 0.1232
Epoch [1/2], Step [900/9375], Loss: 0.1343


KeyboardInterrupt: 