In [3]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import cv2 as cv
from sklearn.model_selection import train_test_split

In [4]:
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

In [259]:
# merge print into one folder:

roots = [
    "data_digits/hand",
    "data_digits/print",
] 

for root in roots:
    root = Path(root)

    for d in range(1, 10):
        (root / str(d)).mkdir(parents=True, exist_ok=True)
    
    files = root.rglob('*.jpg')

    for file in files:
        parent = file.parent.name

        file.rename(root / parent / file.name)



In [260]:
# calculate mean and std

files = list(Path("data_digits/").rglob('*.jpg'))
print(len(files))

mean = 0
std = 0

for file in files:
    img = cv.imread(file, cv.IMREAD_GRAYSCALE)

    mean += img.mean()
    std += img.std()

mean /= len(files)
std /= len(files)

mean / 255, std / 255

3717


(0.08285867195867874, 0.2569625918691173)

In [261]:
# split on train and test:
roots = [
    "data_digits/hand",
    "data_digits/print",
]

def create_split_cls_folders(root: Path, splits, classes):
    for split in splits:
        for cls in classes:
            (root / split / cls).mkdir(parents=True, exist_ok=True)

splits = ['train', 'test']
classes = [str(x) for x in range(1, 10)]

for root in roots:
    root = Path(root)
    files = list(root.rglob('*.jpg'))
    print(f"Have {len(files)}")

    create_split_cls_folders(root, splits, classes)

    train_files, test_files = train_test_split(files, test_size=0.10, random_state=1)

    for file in train_files:
        cls = file.parent.name
        file.rename(root / 'train' / cls / file.name)

    for file in test_files:
        cls = file.parent.name
        file.rename(root / 'test' / cls / file.name)


Have 1790
Have 1927


In [262]:
folders = [x for x in Path('data_digits').rglob('*') if x.is_dir()]

for f in folders:
    if len(list(f.glob('*'))) == 0:
        f.rmdir()

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Define transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Grayscale(),
    transforms.Resize((50, 50)),
    transforms.Normalize((0.082), (0.25696)), # Normalize to mean=0.1307, std=0.3081 (MNIST stats)
    #transforms.RandomRotation(15),
    transforms.RandomAffine(degrees=10, translate=(0.15, 0.15), scale=(0.9, 1.1))
    #transforms.RandomPerspective(0.1, p=0.25),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Grayscale(),
    transforms.Resize((50, 50)),
    transforms.Normalize((0.082), (0.25696)), # Normalize to mean=0.1307, std=0.3081 (MNIST stats)
])

In [6]:
train_dataset_hand = torchvision.datasets.ImageFolder(root='./data_digits/hand/train', transform=transform)
test_dataset_hand = torchvision.datasets.ImageFolder(root='./data_digits/hand/test', transform=transform_test)
train_dataset_print = torchvision.datasets.ImageFolder(root='./data_digits/print/train', transform=transform)
test_dataset_print = torchvision.datasets.ImageFolder(root='./data_digits/print/test', transform=transform_test)

train_dataset = torch.utils.data.ConcatDataset((train_dataset_hand, train_dataset_print))
test_dataset = torch.utils.data.ConcatDataset((test_dataset_hand, test_dataset_print))

# train_dataset = train_dataset_hand
# test_dataset = test_dataset_hand

In [7]:
len(train_dataset), len(test_dataset)

(3345, 372)

In [8]:
# DataLoaders
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

In [9]:
def calculate_params(model: nn.Module):
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return params

In [10]:
import timm
timm.list_models('mobilenetv2*', pretrained=True)

  from .autonotebook import tqdm as notebook_tqdm


['mobilenetv2_050.lamb_in1k',
 'mobilenetv2_100.ra_in1k',
 'mobilenetv2_110d.ra_in1k',
 'mobilenetv2_120d.ra_in1k',
 'mobilenetv2_140.ra_in1k']

In [11]:
import timm
# model = timm.create_model('mobilenetv3_small_050.lamb_in1k', pretrained=True, num_classes=9, in_chans=1)
# print(calculate_params(model)) #577161

# model = timm.create_model('resnet10t.c3_in1k', pretrained=True, num_classes=9, in_chans=1)
# print(calculate_params(model))

model = timm.create_model('mobilenetv2_100.ra_in1k', pretrained=True, num_classes=9, in_chans=1)
print(calculate_params(model))


2234825


In [12]:
from timm.models.mobilenetv3 import default_cfgs
default_cfgs['mobilenetv3_small_050']

DefaultCfg(tags=deque(['lamb_in1k']), cfgs={'lamb_in1k': PretrainedCfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_050_lambc-4b7bbe87.pth', file=None, state_dict=None, hf_hub_id='timm/', hf_hub_filename=None, source=None, architecture=None, tag=None, custom_load=False, input_size=(3, 224, 224), test_input_size=None, min_input_size=None, fixed_input_size=False, interpolation='bicubic', crop_pct=0.875, test_crop_pct=None, crop_mode='center', mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), num_classes=1000, label_offset=None, label_names=None, label_descriptions=None, pool_size=(7, 7), test_pool_size=None, first_conv='conv_stem', classifier='classifier', license=None, description=None, origin_url=None, paper_name=None, paper_ids=None, notes=None)}, is_pretrained=True)

In [13]:

criterion = nn.CrossEntropyLoss()  # Standard loss for classification
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam optimizer
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, factor=0.5, patience=5)


In [14]:
def test_model(model, test_loader):

    model.eval()  # Set to evaluation mode
    correct = 0
    total = 0
    device = 'cuda'

    with torch.no_grad():  # No need to track gradients in evaluation
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f"Test Accuracy: {100 * correct / total:.2f}%")
    return 100 * correct / total


In [15]:
num_epochs = 50  # You can increase this for better accuracy
device = 'cuda'
model.to(device)

best_score = 0

for epoch in range(num_epochs):
    model.train()  # Set model to training mode
    running_loss = 0.0
    
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()  # Zero gradients from previous step
        outputs = model(images)  # Forward pass
        loss = criterion(outputs, labels)  # Compute loss
        loss.backward()  # Backpropagation
        optimizer.step()  # Update weights
        
        running_loss += loss.item()
    
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")

    test_score = test_model(model, test_loader)

    if test_score > best_score:
        best_score = test_score
        torch.save(model.state_dict(), "best_model.pt")

    scheduler.step(running_loss, epoch=epoch)

print("Training complete!")


Epoch [1/50], Loss: 1.9571
Test Accuracy: 73.12%




Epoch [2/50], Loss: 0.3311
Test Accuracy: 95.43%
Epoch [3/50], Loss: 0.1992
Test Accuracy: 98.39%
Epoch [4/50], Loss: 0.1226
Test Accuracy: 99.46%
Epoch [5/50], Loss: 0.1266
Test Accuracy: 98.66%
Epoch [6/50], Loss: 0.0680
Test Accuracy: 99.19%
Epoch [7/50], Loss: 0.0926
Test Accuracy: 99.46%
Epoch [8/50], Loss: 0.1118
Test Accuracy: 99.19%
Epoch [9/50], Loss: 0.0603
Test Accuracy: 98.12%
Epoch [10/50], Loss: 0.0714
Test Accuracy: 99.46%
Epoch [11/50], Loss: 0.1150
Test Accuracy: 99.46%
Epoch [12/50], Loss: 0.0907
Test Accuracy: 99.46%
Epoch [13/50], Loss: 0.1189
Test Accuracy: 97.85%
Epoch [14/50], Loss: 0.2032
Test Accuracy: 97.31%
Epoch [15/50], Loss: 0.0971
Test Accuracy: 99.19%
Epoch [16/50], Loss: 0.0837
Test Accuracy: 99.73%
Epoch [17/50], Loss: 0.0489
Test Accuracy: 99.46%
Epoch [18/50], Loss: 0.0465
Test Accuracy: 99.19%
Epoch [19/50], Loss: 0.0399
Test Accuracy: 99.73%
Epoch [20/50], Loss: 0.0371
Test Accuracy: 99.73%
Epoch [21/50], Loss: 0.0407
Test Accuracy: 100.00%
Epoch [

KeyboardInterrupt: 

In [16]:
model.load_state_dict(torch.load('best_model.pt', weights_only=True))
model.eval()

EfficientNet(
  (conv_stem): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bn1): BatchNormAct2d(
    32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
    (drop): Identity()
    (act): ReLU6(inplace=True)
  )
  (blocks): Sequential(
    (0): Sequential(
      (0): DepthwiseSeparableConv(
        (conv_dw): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (bn1): BatchNormAct2d(
          32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): ReLU6(inplace=True)
        )
        (aa): Identity()
        (se): Identity()
        (conv_pw): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn2): BatchNormAct2d(
          16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): Identity()
        )
        (drop_path): Identity()
      )
    )
    (1): Sequent

In [17]:
test_model(model, test_loader)

Test Accuracy: 100.00%


100.0

In [None]:
root = Path("data_digits/hand/test")

for digit in range(1, 10):
    (root / str(digit)).mkdir(parents=True, exist_ok=True)

In [None]:
def viz(img):
    plt.figure()
    plt.imshow(img, cmap='gray')
    plt.show()



model.eval()
root = Path("data_digits/hand")
digits_files = list(root.glob('*.jpg'))


for file in digits_files:
    digit = cv.imread(file, cv.IMREAD_GRAYSCALE)

    inp = transform_test(digit).to(device)

    with torch.no_grad():
        output = model(inp.unsqueeze(0))
        _, pred = torch.max(output, 1)

    #viz(digit)
    #print(test_dataset.classes[pred])
    file.rename(file.parent / "test" / train_dataset.classes[pred] / file.name)

In [None]:
root = Path("data_digits/hand")

for digit in range(1, 10):
    (root / str(digit)).mkdir()