In [None]:
import numpy as np
import re

In [None]:
from pathlib import Path

import cv2
import numpy as np
from collections import defaultdict
from PIL import Image
import torch
from torch.utils.data import Dataset


class PairedDatImageDataset(Dataset):
    """Dataset of ``.dat`` / ``.jpg`` file pairs that share the same file stem.

    Files are looked up in ``<data_dir>/v2_train/v2_train``. Every stem that has
    both a ``.dat`` and a ``.jpg`` file becomes one sample; stems with only one
    of the two are reported on stdout and skipped.

    Each item is a ``(dat, img)`` tuple:

    * ``dat`` - 1-D ``torch.Tensor`` read from the ``.dat`` file with
      ``np.fromfile`` using ``dat_dtype``.
    * ``img`` - the image as returned by ``cv2.imread`` (BGR channel order),
      passed through ``transform`` when one is given.

    NOTE(review): ``np.fromfile`` without ``sep`` treats the ``.dat`` file as a
    raw binary dump of ``dat_dtype`` values. If the files are actually text,
    they need a text parser instead - confirm against the dataset.
    """

    def __init__(self, data_dir="data", transform=None, dat_dtype=np.float32):
        """
        Args:
            data_dir (str or Path): directory that contains ``v2_train/v2_train``
                with the .dat and .jpg files
            transform (callable, optional): applied to the loaded image
            dat_dtype: numpy dtype used to decode the .dat files
        """
        self.data_dir = Path(data_dir)
        self.transform = transform
        self.dat_dtype = dat_dtype

        self.samples = self._index_files()

    def _index_files(self):
        """Scan the data directory and return the list of complete samples.

        Returns:
            list[dict]: one ``{"key", "dat", "img"}`` dict per stem that has
            both files, sorted by key so that indices are reproducible
            (``Path.iterdir`` yields entries in arbitrary order).

        Raises:
            FileNotFoundError: if ``<data_dir>/v2_train/v2_train`` does not exist.
        """
        # stem -> {suffix: path}
        grouped = defaultdict(dict)

        for file in (self.data_dir / "v2_train" / "v2_train").iterdir():
            if not file.is_file():
                continue

            if file.suffix not in {".dat", ".jpg"}:
                continue

            grouped[file.stem][file.suffix] = file

        samples = []
        for key in sorted(grouped):
            files = grouped[key]
            if ".dat" in files and ".jpg" in files:
                samples.append({
                    "key": key,
                    "dat": files[".dat"],
                    "img": files[".jpg"],
                })
            else:
                # Report the stem that is actually unmatched, not the last
                # file seen by the indexing loop above.
                print(f"File {key} did not find a match")

        return samples

    def __len__(self):
        """Number of complete (.dat, .jpg) pairs."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load and return the ``(dat, img)`` pair at position ``idx``.

        Raises:
            OSError: if the image file cannot be decoded by OpenCV.
        """
        sample = self.samples[idx]

        # load .dat
        dat = np.fromfile(sample["dat"], dtype=self.dat_dtype)
        dat = torch.from_numpy(dat)

        # load image; pass a plain string because not every OpenCV build
        # accepts pathlib.Path objects as file names
        img = cv2.imread(str(sample["img"]))
        if img is None:
            # cv2.imread signals failure by returning None instead of raising
            raise OSError(f"Could not read image file: {sample['img']}")

        if self.transform is not None:
            img = self.transform(img)

        return dat, img

In [None]:
# Index all (.dat, .jpg) pairs; the dataset class looks inside
# <data_dir>/v2_train/v2_train, i.e. ../data/v2_train/v2_train here.
ds = PairedDatImageDataset(data_dir="../data/")

In [None]:
# Peek at the first few indexed pairs instead of dumping the whole list
# into the notebook output.
ds.samples[:5]

In [None]:
import matplotlib.pyplot as plt
# The dataset loads images with cv2.imread, which yields BGR channel order;
# matplotlib expects RGB, so convert before displaying.
plt.imshow(cv2.cvtColor(ds[11][1], cv2.COLOR_BGR2RGB));

#### Imports

In [None]:
import matplotlib.pyplot as plt
import cv2
import imutils

### 1. Read in Image, Grayscale and Blur

In [None]:
# Array shape of the image returned for sample 1.
ds[1][1].shape

In [None]:
# Array shape of the image returned for sample 51, for comparison with sample 1.
ds[51][1].shape

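**TODO:** The images in the dataset do not all have the same size (compare the two shapes above). They must be brought to a common size before they can be batched or fed to a model with a fixed input size.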

In [None]:
# Sample image (BGR, as loaded by cv2.imread in the dataset) and its grayscale version.
img = ds[12][1]
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# `gray` is single-channel: show it with a gray colormap rather than pushing it
# through a 3-channel BGR->RGB conversion. vmin/vmax pin the full 8-bit range so
# matplotlib does not stretch the contrast.
plt.imshow(gray, cmap="gray", vmin=0, vmax=255);

In [None]:
# The grayscale image has no channel axis, only the two spatial dimensions.
gray.shape

### 2. Apply filter and find edges for localization

In [None]:
# Noise reduction with a bilateral filter (neighbourhood diameter 11,
# sigmaColor 17, sigmaSpace 17).
bfilter = cv2.bilateralFilter(gray, 11, 17, 17)
# Edge detection with Canny, hysteresis thresholds 30 and 200.
edged = cv2.Canny(bfilter, 30, 200)
# The edge map is single-channel: display it directly with a gray colormap
# instead of routing it through a 3-channel BGR->RGB conversion.
plt.imshow(edged, cmap="gray", vmin=0, vmax=255);

### 3. Find Contours and Apply Mask

In [None]:
# Find contours on a copy of the edge map. The return value of cv2.findContours
# differs between OpenCV versions; imutils.grab_contours picks the contour list
# out of whichever tuple comes back.
keypoints = cv2.findContours(edged.copy(), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
contours = imutils.grab_contours(keypoints)
# Keep only the 10 contours with the largest area, largest first.
contours = sorted(contours, key=cv2.contourArea, reverse=True)[:10]

In [None]:
# Search the largest contours (in descending area order) for the first one whose
# polygon approximation has exactly four vertices; `location` stays None when
# no such contour exists.
location = None
for contour in contours:
    # Closed-polygon approximation with an absolute tolerance of 10 pixels.
    # NOTE(review): a tolerance relative to cv2.arcLength(contour, True) would
    # behave more consistently across image resolutions - confirm before changing.
    approx = cv2.approxPolyDP(contour, 10, True)
    if len(approx) == 4:
        location = approx
        break


In [None]:
# Vertices of the detected quadrilateral, or None if none was found.
location

In [None]:
# Fail with a clear message when the contour search found no quadrilateral,
# instead of an opaque OpenCV error from drawContours.
if location is None:
    raise ValueError(
        "No 4-corner contour found among the largest contours; cannot build the mask."
    )

# Binary mask: 255 inside the detected quadrilateral, 0 elsewhere.
# drawContours draws into `mask` in place (thickness -1 fills the contour).
mask = np.zeros(gray.shape, np.uint8)
cv2.drawContours(mask, [location], 0, 255, -1)
# Keep only the pixels of the original image that lie inside the mask.
new_image = cv2.bitwise_and(img, img, mask=mask)

In [None]:
# new_image is derived from the BGR image, so convert to RGB for matplotlib.
plt.imshow(cv2.cvtColor(new_image, cv2.COLOR_BGR2RGB))

----

### Model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


In [None]:
class CNN_Sudoku_Recog(nn.Module):
    """Convolutional regressor that maps a single-channel image to a 9 x 9 grid.

    Five stride-2 convolutions halve the spatial resolution each time, so a
    480 x 640 input becomes a 256-channel 15 x 20 feature map, which a single
    linear layer maps to 81 values.

    Inputs of other sizes are supported as well: the feature map is adaptively
    average-pooled to 15 x 20 before the linear head. For 480 x 640 inputs this
    pooling is an identity, and because it has no parameters the ``state_dict``
    layout is the same as without it.

    Input:  float tensor of shape (B, 1, H, W)
    Output: float tensor of shape (B, 9, 9)
    """

    def __init__(self):
        super().__init__()

        # ---- Feature extractor ----
        # Spatial sizes in the comments assume a 480 x 640 input.
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),  # 240 x 320
            nn.ReLU(),

            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1), # 120 x 160
            nn.ReLU(),

            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # 60 x 80
            nn.ReLU(),

            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), # 30 x 40
            nn.ReLU(),

            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1), # 15 x 20
            nn.ReLU(),
        )

        # Forces the 15 x 20 feature map the linear head is sized for, whatever
        # the input resolution. Kept outside `classifier` so the parameter names
        # (classifier.1.weight / classifier.1.bias) do not shift.
        self.pool = nn.AdaptiveAvgPool2d((15, 20))

        # ---- Head ----
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256 * 15 * 20, 9 * 9)
        )

    def forward(self, x):
        """Return a (B, 9, 9) tensor of predictions for a (B, 1, H, W) batch."""
        x = self.features(x)
        x = self.pool(x)
        x = self.classifier(x)
        return x.view(-1, 9, 9)


In [None]:
# Instantiate once so the notebook displays the module's layer structure (its repr).
CNN_Sudoku_Recog()

In [None]:
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.nn import MSELoss
from tqdm import tqdm

In [None]:
def train_model(
    model,
    dataset,
    *,
    batch_size=8,
    epochs=10,
    lr=1e-3,
    num_workers=4,
    device=None,
):
    """
    Train a PyTorch model on PairedDatImageDataset.

    The dataset must yield ``(dat, img)`` pairs. ``img`` is fed to the model
    as-is, and ``dat`` is reshaped to ``(B, 9, 9)`` and used as the regression
    target for an MSE loss optimised with Adam.

    Args:
        model: ``torch.nn.Module`` whose output has shape ``(B, 9, 9)``.
        dataset: map-style dataset yielding ``(dat, img)`` tensors; every
            ``dat`` must hold 81 values.
        batch_size: number of samples per batch.
        epochs: number of passes over the dataset.
        lr: Adam learning rate.
        num_workers: number of DataLoader worker processes.
        device: device to train on; defaults to CUDA when available, else CPU.

    Returns:
        The trained model (moved to ``device``).

    Raises:
        ValueError: if ``dataset`` is empty.
    """

    # An empty dataset would otherwise fail later with an unrelated-looking
    # error (sampler construction or division by zero in the average loss).
    if len(dataset) == 0:
        raise ValueError("Cannot train on an empty dataset.")

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device = torch.device(device)

    model = model.to(device)

    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        # Pinned host memory only helps host-to-GPU copies.
        pin_memory=device.type == "cuda",
    )

    optimizer = Adam(model.parameters(), lr=lr)
    criterion = MSELoss()

    model.train()

    for epoch in range(epochs):
        epoch_loss = 0.0

        progress = tqdm(loader, desc=f"Epoch {epoch+1}/{epochs}")

        for dat, img in progress:
            # Move to device
            img = img.to(device)
            dat = dat.to(device)

            # Reshape target -> (B, 9, 9); reshape also copes with
            # non-contiguous tensors, unlike view
            target = dat.reshape(-1, 9, 9)

            # Forward
            output = model(img)

            # Loss
            loss = criterion(output, target)

            # Backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            epoch_loss += batch_loss
            progress.set_postfix(loss=batch_loss)

        avg_loss = epoch_loss / len(loader)
        print(f"Epoch {epoch+1} | Avg Loss: {avg_loss:.6f}")

    return model


In [None]:
# Input resolution the model is sized for: five stride-2 convolutions followed by
# a linear layer that expects a 15 x 20 feature map imply 480 x 640 inputs.
MODEL_INPUT_HEIGHT = 480
MODEL_INPUT_WIDTH = 640


def to_model_input(img_bgr):
    """Convert a BGR uint8 image (as loaded by cv2.imread) to a model input.

    The image is converted to grayscale, resized to the fixed model resolution
    (the dataset images do not all share one size, so they could not be batched
    otherwise) and scaled to [0, 1].

    Returns:
        float32 tensor of shape (1, MODEL_INPUT_HEIGHT, MODEL_INPUT_WIDTH).
    """
    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    # cv2.resize takes the target size as (width, height).
    # NOTE(review): this ignores the aspect ratio of the source image.
    resized = cv2.resize(
        gray, (MODEL_INPUT_WIDTH, MODEL_INPUT_HEIGHT), interpolation=cv2.INTER_AREA
    )
    return torch.from_numpy(resized).float().div(255.0).unsqueeze(0)


# The untransformed `ds` yields HxWx3 uint8 arrays of varying size, which can be
# neither collated into a batch nor fed to the single-channel model, so training
# uses a dataset with the preprocessing transform attached.
# NOTE(review): training also requires every .dat file to decode to exactly
# 81 values - confirm against the data.
train_ds = PairedDatImageDataset(data_dir="../data/", transform=to_model_input)

model = CNN_Sudoku_Recog()

trained_model = train_model(
    model,
    train_ds,
    epochs=20,
    batch_size=16,
    lr=1e-4,
)
