In [None]:
import torch
import torch.nn as nn


class UNet(nn.Module):
    """Minimal one-level encoder/decoder segmentation network.

    For an input of shape (N, in_channels, H, W):
      encoder: two 3x3 conv + ReLU (64 ch), then 2x2 max-pool   -> (N, 64, H//2, W//2)
      middle:  two 3x3 conv + ReLU (128 ch)                     -> (N, 128, H//2, W//2)
      decoder: concat(middle, encoder) = 192 ch, 2x2 stride-2 transposed conv (128 ch)
               back to full resolution, then 3x3 conv + ReLU (64 ch) twice
      output:  1x1 conv to ``out_channels`` per-pixel logits    -> (N, out_channels, 2*(H//2), 2*(W//2))

    H and W should be even; otherwise pooling floors the size and the output is one
    pixel smaller than the input along that axis.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.middle = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

        # 192 input channels = 128 (middle) + 64 (encoder skip), concatenated in forward().
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(192, 128, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

        self.output = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        """Return per-pixel class logits of shape (N, out_channels, H', W')."""
        x1 = self.encoder(x)
        x2 = self.middle(x1)
        # Skip connection: both tensors are at half resolution, so they stack on the channel axis.
        x3 = self.decoder(torch.cat([x2, x1], dim=1))
        return self.output(x3)


# The dataset builds 3-channel masks (background, table, other), reduced with argmax to
# class ids 0..2, so the model needs 3 output classes; with 2, cross_entropy targets of 2
# would be out of range.
unet_model = UNet(3, 3)
print(unet_model)


In [None]:
import os
import json
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset

class SyntheticTableDataset(Dataset):
    """Image/mask pairs for table segmentation, read from ``<root_dir>/annotations.json``.

    ``annotations.json`` is loaded as a list whose entries are read as
    ``{"filename": ..., "annotations": [{"label": ..., "coordinates": {"x", "y", "width", "height"}}]}``
    (schema inferred from the keys accessed below — TODO confirm against the data generator).

    Each item is ``(image, mask)``; before any transform the mask is a (3, H, W) uint8 array:
      channel 0 = background, 1 = boxes labelled ``'table'``, 2 = boxes with any other label.
    Channels 1 and 2 are filled independently, so a pixel inside both a table box and another
    box is 1 in both channels.
    NOTE(review): the training loop reduces masks with ``argmax(dim=1)``, which picks the first
    maximal channel, so such overlapping pixels become class 1 ('table'). If the non-table boxes
    lie inside table boxes, class 2 may rarely or never appear as a target — confirm.
    """

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir: directory holding ``annotations.json`` and the image files it names.
            transform: optional callable ``(PIL image, numpy mask) -> (image, mask)``.
        """
        self.root_dir = root_dir
        self.transform = transform
        with open(os.path.join(root_dir, 'annotations.json')) as f:
            self.annotations = json.load(f)

    def __len__(self):
        return len(self.annotations)

    @staticmethod
    def _build_mask(annotation, height, width):
        """Rasterize one entry's boxes into a (3, height, width) uint8 multi-hot mask."""
        mask = np.zeros((3, height, width), dtype=np.uint8)
        # Everything starts as background; each box clears background where it lies.
        mask[0] = 1
        for item in annotation['annotations']:
            channel = 1 if item['label'] == 'table' else 2
            coords = item['coordinates']
            # int() so float coordinates in the JSON are usable as slice bounds.
            x, y = int(coords['x']), int(coords['y'])
            w, h = int(coords['width']), int(coords['height'])
            mask[channel, y:y + h, x:x + w] = 1
            mask[0, y:y + h, x:x + w] = 0
        return mask

    def __getitem__(self, idx):
        annotation = self.annotations[idx]
        img_path = os.path.join(self.root_dir, annotation['filename'])
        # The context manager closes the file; convert() returns a separate, fully loaded image.
        with Image.open(img_path) as source:
            img = source.convert('RGB')

        mask = self._build_mask(annotation, img.height, img.width)

        if self.transform:
            img, mask = self.transform(img, mask)

        # Without a transform the mask is still a numpy array; always hand back a tensor.
        if not torch.is_tensor(mask):
            mask = torch.from_numpy(mask)
        return img, mask.detach()

In [None]:
from torchvision import transforms

class ResizeAndNormalize:
    """Joint transform: resize image and mask to ``output_size``, normalize the image.

    ``output_size`` is passed directly to PIL ``Image.resize``, i.e. it is ``(width, height)``.
    Returns ``(image, mask)``: the image as a float tensor normalized with the ImageNet
    mean/std, the mask as a uint8 tensor of shape (C, height, width).
    """

    def __init__(self, output_size):
        self.output_size = output_size

    def __call__(self, image, mask):
        """
        Args:
            image: PIL image.
            mask: numpy uint8 array of shape (C, H, W). It round-trips through PIL as (H, W, C),
                so C must be a channel count PIL can represent (e.g. 3 -> 'RGB').
        """
        image = image.resize(self.output_size)
        # Nearest-neighbour keeps label values exact; PIL's default resampling filter for
        # multi-channel images (bicubic in recent Pillow) interpolates and blurs class borders.
        mask_image = Image.fromarray(mask.transpose(1, 2, 0))
        mask_image = mask_image.resize(self.output_size, resample=Image.NEAREST)
        mask = np.array(mask_image).transpose(2, 0, 1)

        image = transforms.functional.to_tensor(image)
        image = transforms.functional.normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        return image, torch.tensor(mask)


input_size = (256, 256)
transform = ResizeAndNormalize(input_size)

In [None]:
from torch.utils.data import DataLoader, random_split

# Fixed seed so the train/validation split (and therefore validation loss) is reproducible
# across kernel restarts.
SPLIT_SEED = 42
BATCH_SIZE = 128
# NOTE(review): num_workers > 0 pickles the notebook-defined dataset/transform into worker
# processes; where the default start method is 'spawn' (macOS, Windows) this can fail for
# classes defined in a notebook — set to 0 there if loading errors appear.
NUM_WORKERS = 8

dataset = SyntheticTableDataset('synthetic_tables', transform=transform)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=torch.Generator().manual_seed(SPLIT_SEED)
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

In [None]:
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
%matplotlib inline

def visualize_predictions(image, predictions, classes):
    """Show an image next to its predicted segmentation map.

    Args:
        image: anything ``plt.imshow`` accepts (e.g. an (H, W, 3) array).
        predictions: (H, W) array of integer class ids.
        classes: sequence of colors, one per class id (index = class id). RGB tuples with
            0-255 components are rescaled to matplotlib's 0-1 range; other matplotlib color
            specs are passed through. If None or empty, matplotlib's default colormap is used.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 15))

    ax1.imshow(image)
    ax1.set_title('Original Image')

    if classes is not None and len(classes) > 0:
        colors = [
            tuple(c / 255 for c in color)
            if isinstance(color, (tuple, list)) and any(c > 1 for c in color)
            else color
            for color in classes
        ]
        # Pin vmin/vmax so class id k always maps to colors[k] regardless of which ids are
        # present; nearest interpolation avoids blending neighbouring class colors.
        ax2.imshow(predictions, cmap=ListedColormap(colors), vmin=0, vmax=len(colors) - 1,
                   interpolation='nearest')
    else:
        ax2.imshow(predictions)
    ax2.set_title('Segmented Image')

    plt.show()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.optim import Adam

device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.has_mps else 'cpu')
print(f"Using device: {device}")
model = UNet(3, 2).to(device)

optimizer = Adam(model.parameters(), lr=0.0001)
train_losses = []
val_losses = []
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    train_loss = []
    for images, masks in train_loader:
        images, masks = images.to(device), masks.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = F.cross_entropy(outputs, masks.argmax(dim=1))
        train_loss.append(loss.item())
        loss.backward()
        optimizer.step()

    train_losses.append(np.mean(train_loss))

    train_losses.append(loss)
    # Validation loop
    model.eval()
    val_loss = []
    with torch.no_grad():
        for images, masks in val_loader:
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            loss = F.cross_entropy(outputs, masks.argmax(dim=1))
            val_loss.append(loss.item())

    val_losses.append(np.mean(val_loss))
    print(f"Epoch: {epoch + 1}/{num_epochs}, Loss: {train_losses[-1]}, Validation Loss: {val_losses[-1]}")

torch.save(model.state_dict(), "table_segmentation_model.pth")

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

plt.plot([l.item() for l in train_losses], label='Train Loss')
plt.plot([l.item() for l in val_losses], label='Validation Loss')
plt.legend()
plt.show()

In [None]:
model = UNet(3, 2)
model.load_state_dict(torch.load('table_segmentation_model.pth'))
model = model.to(device)
model.eval()

In [None]:
# Running example: the sample at position `idx` of the first validation batch.
idx = 0
example_imgs, example_masks = next(iter(val_loader))
example_img, example_mask = example_imgs[idx], example_masks[idx]
example_img.shape, example_mask.shape

In [None]:
# Batch of one on `device`; `input` is reused by the visualization and quantization cells.
input = example_img.unsqueeze(0)
input = input.to(device)

# Inference only: no_grad avoids building an autograd graph.
with torch.no_grad():
    output = model(input)
# Drop only the batch axis -> (num_classes, H, W); a bare squeeze() would also drop a
# size-1 class axis.
output = output[0].cpu().numpy()
result = np.argmax(output, axis=0)
input.shape, result.shape

In [None]:
# One RGB color per class id (list index = class id), starting with background.
classes = [(0, 0, 0), (255, 0, 0), (0, 255, 0), (0, 0, 255)]

# `input` was normalized with the ImageNet mean/std in ResizeAndNormalize; undo that so
# imshow receives values in [0, 1] instead of clipping out-of-range values.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])
image = input.cpu().detach()[0].permute((1, 2, 0)).numpy()
image = np.clip(image * IMAGENET_STD + IMAGENET_MEAN, 0, 1)

visualize_predictions(image, result, classes)

In [None]:
import numpy as np
from scipy.ndimage import label

def extract_coordinates_from_mask(mask, threshold=0.5, class_names=None):
    """Turn a predicted class map into per-class bounding boxes.

    Each connected region of a single class (scipy.ndimage.label's default, 4-connected in
    2-D) becomes one entry ``{'label': name, 'coordinates': {'x', 'y', 'width', 'height'}}``,
    the same layout the dataset annotations use. ``width``/``height`` count pixels, so
    ``mask[y:y + height, x:x + width]`` covers the region exactly.

    Args:
        mask: 2-D array of class ids (e.g. argmax over the model's class axis). A float score
            map also works: values are rounded to the nearest class id.
        threshold: pixels with ``mask <= threshold`` are background (with the default 0.5,
            class 0).
        class_names: mapping ``class id -> label``. Defaults to ``{1: 'table', 2: 'cell'}``,
            following the dataset's channel order; ids absent from the mapping are ignored.

    Returns:
        List of box dicts, grouped by class in ``class_names`` order, then in scipy label
        order within each class. Empty if there is no foreground.
    """
    from scipy import ndimage

    if class_names is None:
        class_names = {1: 'table', 2: 'cell'}

    mask = np.asarray(mask)
    foreground = mask > threshold
    class_ids = np.rint(mask).astype(np.int64)

    coordinates = []
    for class_id, class_name in class_names.items():
        # Label each class separately so touching regions of different classes stay apart
        # and every region is named by its class, not by its component index.
        labeled_mask, _ = ndimage.label(foreground & (class_ids == class_id))
        # find_objects yields one (row_slice, col_slice) bounding box per label in one pass.
        for bbox in ndimage.find_objects(labeled_mask):
            if bbox is None:
                continue
            rows, cols = bbox
            coordinates.append({
                'label': class_name,
                'coordinates': {
                    'x': int(cols.start),
                    'y': int(rows.start),
                    'width': int(cols.stop - cols.start),
                    'height': int(rows.stop - rows.start),
                },
            })

    return coordinates

In [None]:
coordinates = extract_coordinates_from_mask(result)

# One line per detected region: "<label> at {x, y, width, height}".
for region in coordinates:
    label_name, box = region['label'], region['coordinates']
    print(f"{label_name} at {box}")

In [None]:
import copy

import torch.quantization

# Quantize a CPU copy of the *trained* `model`, not the untrained `unet_model` from the first
# cell. PyTorch's quantized kernels run on CPU, and copying leaves `model` on `device`.
# NOTE(review): torch's default dynamic-quantization mapping covers nn.Linear / recurrent
# layers but not nn.Conv2d, so for this all-convolutional UNet the result is expected to stay
# float32. Quantizing the convolutions would need static post-training quantization
# (QuantStub/DeQuantStub + calibration).
quantized_model = torch.quantization.quantize_dynamic(
    copy.deepcopy(model).cpu().eval(), {torch.nn.Conv2d}, dtype=torch.qint8
)
quantized_model

In [None]:
import time

# Time a single forward pass of the quantized model (one run, so it includes any first-call
# overhead). The model lives on the CPU while `input` may be on `device`, so use a CPU copy.
input_cpu = input.detach().cpu()
start = time.perf_counter()
with torch.no_grad():
    output = quantized_model(input_cpu)
end = time.perf_counter()
print(f'Time taken: {end - start} seconds')

# Drop the batch axis -> (num_classes, H, W), then take the per-pixel class id.
output_np = output[0].cpu().numpy()
mask = np.argmax(output_np, axis=0)

coordinates = extract_coordinates_from_mask(mask)
for item in coordinates:
    print(f"{item['label']} at {item['coordinates']}")

In [None]:
# `input` is a normalized (1, 3, H, W) tensor that may live on `device`; imshow needs an
# (H, W, 3) array in [0, 1], so move it to CPU, reorder axes and undo the ImageNet
# normalization applied in ResizeAndNormalize.
display_image = input.detach().cpu()[0].permute(1, 2, 0).numpy()
display_image = np.clip(
    display_image * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406]), 0, 1
)
visualize_predictions(display_image, mask, classes)