In [1]:
import torch
from torch.utils.data import Dataset
import os



#### Define the Dataset for Multi-digit-MNIST Images ####
<br>
**TODO:** we still need to generate separate train and test sets — or can we instead split a single generated dataset into train and test subsets (e.g. with `torch.utils.data.random_split`)?

In [28]:
class MultiDigitMNISTDataset(Dataset):
    """Dataset of multi-digit MNIST samples, one ``.pt`` file per sample.

    Every regular file in ``source_dir`` whose name ends in ``.pt`` is one sample.
    Each file is loaded with ``torch.load`` and must contain a dict with the keys
    ``"multi_img"``, ``"multi_img_label"`` and ``"merge_points"``.

    Samples are indexed in file-name order, so ``ds[i]`` refers to the same file on
    every run and every machine.

    Args:
        source_dir: Directory holding the ``.pt`` records. A relative path is
            resolved against the current working directory at construction time.
        img_transform: Optional callable applied to ``multi_img``.
        label_transform: Optional callable applied to ``multi_img_label``.
        merge_point_transform: Optional callable applied to ``merge_points``.
    """

    def __init__(self, source_dir, img_transform=None, label_transform=None, merge_point_transform=None):
        self.source_dir = source_dir
        self.img_transform = img_transform
        self.label_transform = label_transform
        self.merge_point_transform = merge_point_transform

        # os.path.join returns source_dir unchanged when it is already absolute.
        self.source_path = os.path.join(os.getcwd(), self.source_dir)
        with os.scandir(self.source_path) as it:
            entries = [entry for entry in it if entry.is_file() and entry.name.endswith(".pt")]
        # os.scandir yields entries in arbitrary, filesystem-dependent order. Sort by
        # file name so that index -> sample is reproducible (required for any
        # index-based train/test split).
        self.data_record_entries = sorted(entries, key=lambda entry: entry.name)

        # Currently unused; kept so code referencing this attribute keeps working.
        self.data_records = []

    def __len__(self):
        """Number of ``.pt`` records found in ``source_dir``."""
        return len(self.data_record_entries)

    def __getitem__(self, idx):
        """Load record ``idx`` and return ``(multi_img, multi_img_label, merge_points)``.

        Each element is passed through its transform when one was supplied.
        """
        # NOTE(review): torch.load unpickles the file; only load records from trusted sources.
        data_record = torch.load(os.path.join(self.source_path, self.data_record_entries[idx].name))
        multi_img = data_record["multi_img"]
        multi_img_label = data_record["multi_img_label"]
        merge_points = data_record["merge_points"]
        if self.img_transform is not None:
            multi_img = self.img_transform(multi_img)
        if self.label_transform is not None:
            multi_img_label = self.label_transform(multi_img_label)
        if self.merge_point_transform is not None:
            merge_points = self.merge_point_transform(merge_points)

        return multi_img, multi_img_label, merge_points




In [None]:
# The class is `DataLoader` (capital L); `dataloader` is the lowercase module.
from torch.utils.data import DataLoader

# Directory (relative to the current working directory) holding the generated .pt records.
MMNIST_DIR = "Mnist4"

mmnist_ds = MultiDigitMNISTDataset(source_dir=MMNIST_DIR)
# NOTE(review): batch_size > 1 requires the default collate to stack samples, i.e.
# equally sized images and label/merge-point structures — confirm before raising it.
train_dl = DataLoader(mmnist_ds, batch_size=1)

### Build the Neural Net to learn finding the merge-points ###
<br>
As a first step, we assume that the number of merge points is known and ask the network only for their x-coordinates.
<br>
These coordinates will be used to split the image of the multi-digit number into single-digit images, <br>
leaving us with a standard MNIST-like problem for which good solutions are readily available.


In [None]:
# relu / max_pool2d live in torch.nn.functional (torch.functional does not provide them).
import torch.nn.functional as F

class MultiDigitMNISTNet(torch.nn.Module):
    """LeNet-style CNN mapping a single-channel image to ``nof_digits`` output values.

    Per the surrounding notes, the outputs are meant to be merge-point
    x-coordinates of a multi-digit MNIST image.

    Args:
        nof_digits: Number of values produced per image (size of the last layer).
        img_height: Input image height in pixels (default 28, plain MNIST).
        img_width: Input image width in pixels (default 28). Set this to the width
            of the concatenated multi-digit image.

    Raises:
        ValueError: If the image is too small to survive both conv/pool stages.
    """

    def __init__(self, nof_digits, img_height=28, img_width=28):
        super().__init__()
        # input: 1 x H x W
        self.conv1 = torch.nn.Conv2d(1, 6, 5)
        # out: 6 x (H-4) x (W-4); 2x2 max-pooling halves each spatial dim
        self.conv2 = torch.nn.Conv2d(6, 16, 3)
        # out: 16 x (h1-2) x (w1-2); 2x2 max-pooling halves again
        # (for 28 x 28 MNIST input this ends at 16 x 5 x 5)
        feat_h = self._feature_size(img_height)
        feat_w = self._feature_size(img_width)
        if feat_h < 1 or feat_w < 1:
            raise ValueError(f"image size {img_height}x{img_width} is too small for this network")
        self.fc1 = torch.nn.Linear(16 * feat_h * feat_w, 120)
        self.fc2 = torch.nn.Linear(120, nof_digits)

    @staticmethod
    def _feature_size(size):
        """Spatial size after conv1 (k=5), 2x2 pool, conv2 (k=3), 2x2 pool."""
        return ((size - 4) // 2 - 2) // 2

    def forward(self, x):
        """Map ``(N, 1, H, W)`` (or unbatched ``(1, H, W)``) input to ``(N, nof_digits)`` (or ``(nof_digits,)``)."""
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
        # Flatten only channel + spatial dims so the batch dimension is preserved.
        x = torch.flatten(x, start_dim=-3)
        x = F.relu(self.fc1(x))
        # NOTE(review): ReLU on the output layer clamps predictions to >= 0 and can
        # zero gradients for negative pre-activations — confirm this is intended.
        x = F.relu(self.fc2(x))
        return x