In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
import cv2
import matplotlib.pyplot as plt
import numpy as np
import os
import multiprocessing
from PIL import Image


In [None]:
# define the CNN architecture
class CNN(nn.Module):
    """Three-stage convolutional classifier for 28x28 single-channel images.

    Tensor shapes per sample (channels, height, width):
        conv1 + bn1 + ReLU + pool: (1, 28, 28)  -> (32, 14, 14)
        conv2 + bn2 + ReLU + pool: (32, 14, 14) -> (64, 7, 7)
        conv3 + bn3 + ReLU:        (64, 7, 7)   -> (128, 7, 7)
        fc1 + ReLU + dropout:      128 * 7 * 7  -> 128
        fc2:                       128          -> num_classes (raw logits)

    Args:
        num_classes: width of the output layer. Defaults to 62, matching the
            62-character class table used elsewhere in this notebook.
        dropout: drop probability applied after fc1. Defaults to 0.5.
    """

    def __init__(self, num_classes=62, dropout=0.5):
        super().__init__()

        # 1. conv layer 1: (1, 28, 28) -> (32, 28, 28), pooled to (32, 14, 14)
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)  # batch normalization
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)  # shared 2x2 pooling layer

        # 2. conv layer 2: (32, 14, 14) -> (64, 14, 14), pooled to (64, 7, 7)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)

        # 3. conv layer 3: (64, 7, 7) -> (128, 7, 7), no pooling afterwards
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)

        # 4. fully connected head: (128*7*7) -> 128 -> num_classes
        self.fc1 = nn.Linear(128 * 7 * 7, 128)
        self.dropout = nn.Dropout(dropout)  # dropout to limit overfitting
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes) for a (batch, 1, 28, 28) input."""
        x = self.pool(F.relu(self.bn1(self.conv1(x))))  # Conv1 + BN + ReLU + Pooling
        x = self.pool(F.relu(self.bn2(self.conv2(x))))  # Conv2 + BN + ReLU + Pooling
        x = F.relu(self.bn3(self.conv3(x)))  # Conv3 + BN + ReLU

        # Flatten everything except the batch dimension. Unlike view(-1, 128*7*7),
        # this never regroups samples when the spatial size is wrong: fc1 then
        # raises a shape error instead of silently changing the batch size.
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)



In [None]:

import matplotlib.pyplot as plt

def plot_training_curves(train_losses, val_losses, train_accs, val_accs):
    """Plot per-epoch loss and accuracy histories side by side.

    Args:
        train_losses: training loss per epoch; its length sets the x axis.
        val_losses: validation loss per epoch.
        train_accs: training accuracy per epoch.
        val_accs: validation accuracy per epoch.
    """
    epochs = range(1, len(train_losses) + 1)

    # (subplot position, metric name, training series, validation series)
    panels = (
        (1, "Loss", train_losses, val_losses),
        (2, "Accuracy", train_accs, val_accs),
    )

    plt.figure(figsize=(12, 5))

    for position, metric, train_series, val_series in panels:
        plt.subplot(1, 2, position)
        plt.plot(epochs, train_series, marker='o', linestyle='-',
                 label=f"Train {metric}", color='blue')
        plt.plot(epochs, val_series, marker='s', linestyle='--',
                 label=f"Validation {metric}", color='red', alpha=0.8)
        plt.xlabel("Epoch")
        plt.ylabel(metric)
        plt.title(f"Training & Validation {metric}")
        plt.legend()
        plt.grid(True)

    plt.tight_layout()
    plt.show()



def display_emnist_sample_images(images, labels, preds=None):
    """Show up to 12 samples in a 2x6 grid, titled with their class characters.

    Args:
        images: batch of images as a torch.Tensor (on any device) or a numpy
            array; singleton dimensions of each item are squeezed away before
            plotting.
        labels: per-image class indices into the 62-character class table.
        preds: optional per-image predicted class indices; when given, each
            title shows the prediction above the label.
    """
    emnist_classes = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
    fig, axes = plt.subplots(2, 6, figsize=(10, 5))

    # Tensor.cpu() returns a copy instead of moving the tensor in place, so the
    # result has to be kept; numpy inputs are passed through untouched.
    if isinstance(images, torch.Tensor):
        images = images.detach().cpu().numpy()
    else:
        images = np.asarray(images)

    n_shown = min(len(images), axes.size)

    for i, ax in enumerate(axes.flat):
        ax.axis("off")
        if i >= n_shown:
            continue  # batch smaller than the grid: leave the remaining cells blank

        img = images[i].squeeze()
        # rot90(k=3) followed by fliplr is a transpose; presumably this undoes the
        # transposed orientation in which EMNIST samples are stored.
        img = np.rot90(img, k=3)  # 270° rotation
        img = np.fliplr(img)  # flip left-right

        ax.imshow(img, cmap="gray")  # grey scale
        label_char = emnist_classes[int(labels[i])]
        if preds is None:
            title = f"Label: {label_char}"
        else:
            title = f"Pred: {emnist_classes[int(preds[i])]}\nLabel: {label_char}"
        ax.set_title(title)

    plt.show()



def predict(model, image_or_path):
    """Classify a single character image and return the predicted character.

    Args:
        model: network producing 62 logits per sample.
        image_or_path: either a file path (str), which is run through
            preprocess_image, or an already preprocessed torch.Tensor of shape
            (1, 1, H, W). Tensors of shape (H, W) or (1, H, W) are accepted too
            and get the missing leading dimensions added.

    Returns:
        The predicted character from the 62-character class table.

    Raises:
        ValueError: if image_or_path is neither a str nor a torch.Tensor.
    """
    if isinstance(image_or_path, str):
        image = preprocess_image(image_or_path)
    elif isinstance(image_or_path, torch.Tensor):
        image = image_or_path
    else:
        raise ValueError("file path must be (str) or PyTorch Tensor")

    # Add missing channel / batch dimensions so unbatched images work as well.
    if image.dim() == 2:
        image = image.unsqueeze(0)
    if image.dim() == 3:
        image = image.unsqueeze(0)

    # Run on the device the model actually lives on; guessing from CUDA
    # availability breaks when the model was left on the CPU.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    image = image.to(device)
    emnist_classes = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"

    # Inference must run in eval mode (dropout off, batch-norm running stats);
    # the caller's train/eval mode is restored afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            output = model(image)
            _, predicted = output.max(1)
            return emnist_classes[predicted.item()]
    finally:
        model.train(was_training)



def preprocess_image(image_path, mode="test", show=True):
    """Load an image file and turn it into a model-ready tensor.

    Pipeline: greyscale -> force a dark background -> crop to the strokes ->
    pad to a square -> transpose -> resize to 28x28 -> ToTensor + Normalize.

    Args:
        image_path: path of the image file to read.
        mode: "train" adds a random rotation and translation before
            normalization; any other value only converts and normalizes.
        show: when True (default), plot the greyscale input next to the
            processed 28x28 image.

    Returns:
        torch.Tensor of shape (1, 1, 28, 28).
    """
    # Read once and close the file handle; keep the greyscale original for the plot.
    with Image.open(image_path) as source:
        original = np.array(source.convert("L"))  # greyscale numpy array
    image = original

    # Polarity: a mostly bright image is inverted, so from here on the
    # background is dark and the strokes are bright.
    if np.mean(image) > 127:
        image = 255 - image

    # Crop to the strokes. The strokes are the bright pixels now, so they are
    # separated from the dark background with Otsu's threshold; looking for
    # non-zero pixels of the inverted image would select the background.
    _, stroke_mask = cv2.threshold(image, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    coords = cv2.findNonZero(stroke_mask)
    if coords is not None:  # None means a blank image: nothing to crop to
        x, y, w, h = cv2.boundingRect(coords)
        image = image[y:y+h, x:x+w]

    # Pad to a square with the (dark) background value. The remainder of an
    # odd difference goes to the bottom/right so the result is exactly square.
    h, w = image.shape
    max_dim = max(h, w)
    pad_top = (max_dim - h) // 2
    pad_bottom = max_dim - h - pad_top
    pad_left = (max_dim - w) // 2
    pad_right = max_dim - w - pad_left
    image = np.pad(image, ((pad_top, pad_bottom), (pad_left, pad_right)),
                   mode='constant', constant_values=0)

    # EMNIST dataset adjustments: rot90(k=3) followed by fliplr is a transpose,
    # the inverse of the un-transpose the display helpers apply to dataset samples.
    image = np.rot90(image, k=3)  # 270° rotation
    image = np.fliplr(image)  # flip left-right

    image = Image.fromarray(image).resize((28, 28), resample=Image.LANCZOS)  # resize
    image = np.array(image, dtype=np.uint8).copy()  # numpy array copy

    # augmentation (train mode only)
    if mode == "train":
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.RandomRotation(15),
            transforms.RandomAffine(0, translate=(0.1, 0.1)),
            transforms.Normalize((0.5,), (0.5,))
        ])
    else:
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])

    image_tensor = transform(image).unsqueeze(0)  # to tensor (1, 1, 28, 28)

    if show:
        fig, axes = plt.subplots(1, 2, figsize=(8, 4))
        axes[0].imshow(original, cmap="gray")
        axes[0].set_title("Before Processing")
        axes[0].axis("off")

        axes[1].imshow(image.squeeze(), cmap="gray")
        axes[1].set_title("After Processing")
        axes[1].axis("off")

        plt.show()

    return image_tensor




In [None]:
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over loader: a training pass if optimizer is given, else evaluation.

    Returns:
        (mean_loss, accuracy), both averaged over samples rather than batches,
        so a shorter final batch is not over-weighted.

    Raises:
        ValueError: if the loader yields no samples.
    """
    training = optimizer is not None
    model.train(training)
    loss_sum, correct, total = 0.0, 0, 0

    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            n_samples = labels.size(0)
            # criterion returns the batch mean; scale back to a per-sample sum
            loss_sum += loss.item() * n_samples
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()
            total += n_samples

    if total == 0:
        raise ValueError("data loader produced no samples")
    return loss_sum / total, correct / total


def train_CNN(model, model_path, train_loader, val_loader, num_epochs=10, patience=3):
    """Train model with Adam + cross-entropy, checkpointing and early stopping.

    A checkpoint (model/optimizer state, metric histories, epoch counter, best
    validation loss, early-stopping counter) is written to model_path each time
    the validation loss improves. If model_path already exists, training
    resumes from it; raising num_epochs on a later call trains the extra epochs.

    Args:
        model: network to train; it is moved to CUDA when available, else CPU.
        model_path: checkpoint file path; its parent directory is created if missing.
        train_loader: iterable of (images, labels) training batches.
        val_loader: iterable of (images, labels) validation batches.
        num_epochs: total number of epochs, including epochs already recorded
            in a loaded checkpoint.
        patience: number of consecutive epochs without validation-loss
            improvement after which training stops.

    Returns:
        (train_losses, train_accs, val_losses, val_accs), one entry per epoch.

    Raises:
        ValueError: if a loader yields no samples.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Batches are sent to `device`, so the model has to live there too; move it
    # before the optimizer is built so the optimizer tracks the final parameters.
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    best_val_loss = float('inf')
    counter = 0  # epochs since the last validation-loss improvement
    current_epoch = 0
    train_losses, train_accs, val_losses, val_accs = [], [], [], []

    if os.path.exists(model_path):
        print(f"✅ found {model_path}, loading model ...")
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        pretrained_model = torch.load(model_path, map_location=device)
        model.load_state_dict(pretrained_model['model_state_dict'])
        optimizer.load_state_dict(pretrained_model['optimizer_state_dict'])

        train_losses = pretrained_model.get('train_losses', [])
        train_accs = pretrained_model.get('train_accs', [])
        val_losses = pretrained_model.get('val_losses', [])
        val_accs = pretrained_model.get('val_accs', [])
        current_epoch = pretrained_model.get('epoch', 0)
        best_val_loss = pretrained_model.get('best_val_loss', float('inf'))
        counter = pretrained_model.get('counter', 0)

        if current_epoch >= num_epochs:
            print(f"✅ finished training {model_path} with {current_epoch} epoches. No action needed.")
            return train_losses, train_accs, val_losses, val_accs
        else:
            print(f"🔄 continue traiing {num_epochs - current_epoch} epoches...")

    else:
        print(f"⚠️ {model_path} not found, {num_epochs} epoches left...")

    # Create the checkpoint directory up front, so a missing folder does not
    # make torch.save fail only after a full epoch has been trained.
    checkpoint_dir = os.path.dirname(model_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch in range(current_epoch, num_epochs):
        epoch_loss, epoch_acc = _run_epoch(model, train_loader, criterion, device, optimizer)
        train_losses.append(epoch_loss)
        train_accs.append(epoch_acc)

        val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)
        val_losses.append(val_loss)
        val_accs.append(val_acc)

        print(f"Epoch [{epoch+1}/{num_epochs}], "
              f"Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

        if val_loss < best_val_loss:
            # new best: reset the early-stopping counter and checkpoint everything
            best_val_loss = val_loss
            counter = 0
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_losses': train_losses,
                'train_accs': train_accs,
                'val_losses': val_losses,
                'val_accs': val_accs,
                'epoch': epoch + 1,
                'best_val_loss': best_val_loss,
                'counter': counter
            }, model_path)
            print(f"model updated to {model_path}...")
        else:
            counter += 1
            print(f"early stopping counter: {counter}/{patience}")

        if counter >= patience:
            print("early stopped!")
            break

    print("training completed 🚀")
    return train_losses, train_accs, val_losses, val_accs


In [None]:
# train data augmentation
train_transform = transforms.Compose([
    transforms.RandomRotation(5),
    transforms.RandomAffine(degrees=0, translate=(0.02, 0.02)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
    transforms.Resize((28, 28)),
])

# deterministic pipeline for validation and test data (no augmentation)
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
    transforms.Resize((28, 28)),
])

# Fixed seed for the train/validation split: the same samples stay in each
# subset across kernel restarts, which matters when training resumes from a
# checkpoint.
SPLIT_SEED = 42

full_train_dataset = torchvision.datasets.EMNIST(root="./data", split="byclass", train=True, download=True, transform=train_transform)
# Second view of the same training split, without augmentation, used only for
# validation. This keeps a second in-memory copy of the split; the trade-off is
# a validation metric that is not distorted by random augmentation.
full_val_dataset = torchvision.datasets.EMNIST(root="./data", split="byclass", train=True, download=True, transform=test_transform)
test_dataset = torchvision.datasets.EMNIST(root="./data", split="byclass", train=False, download=True, transform=test_transform)

train_size = int(0.8 * len(full_train_dataset))
val_size = len(full_train_dataset) - train_size
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, val_split = random_split(full_train_dataset, [train_size, val_size], generator=split_generator)
# same held-out indices, but read through the non-augmented dataset
val_dataset = torch.utils.data.Subset(full_val_dataset, val_split.indices)

batch_size = 64
num_workers = min(4, multiprocessing.cpu_count())
print(f"num_workers: {num_workers}")

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

print(f"training size: {len(train_dataset)}")
print(f"validation size: {len(val_dataset)}")
print(f"testing size: {len(test_dataset)}")
print(f"classification size: {len(full_train_dataset.class_to_idx)}")
print(f"classifications: {full_train_dataset.class_to_idx}")



In [None]:
# visualize the data: pull one batch from the training loader and show its
# first samples together with their labels
dataiter = iter(train_loader)
images, labels = next(dataiter)
display_emnist_sample_images(images=images, labels=labels)

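This design makes it easy to add another training round (epoch) at any time — for example, extending a run from 4 epochs to 5.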

In [None]:
# Build the baseline model on the GPU when one is available (CPU otherwise),
# then train it, or resume from the checkpoint if one already exists.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = CNN().to(device)
model_path = "../models/emnist_baseline.pth"
num_epochs = 5

train_losses, train_accs, val_losses, val_accs = train_CNN(
    model=model,
    model_path=model_path,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=num_epochs,
)

In [None]:
# model evaluation: accuracy on the test set, plus the misclassified samples
model.eval()
correct = 0
total = 0
misclassified = []  # (image, predicted, label) tuples, stored on the CPU

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)

        outputs = model(images)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        # wrong classifications: one mask and one device-to-host copy per batch
        # instead of a Python-level comparison per sample; CPU copies also keep
        # the collected images from accumulating in GPU memory
        wrong_idx = predicted.ne(labels).nonzero(as_tuple=True)[0]
        if wrong_idx.numel() > 0:
            misclassified.extend(zip(
                images[wrong_idx].cpu(),
                predicted[wrong_idx].cpu(),
                labels[wrong_idx].cpu(),
            ))

test_acc = correct / total
print(f"Test set accuracy: {test_acc:.4f}")



In [None]:
# loss / accuracy history returned by train_CNN
plot_training_curves(
    train_losses=train_losses,
    val_losses=val_losses,
    train_accs=train_accs,
    val_accs=val_accs,
)

In [None]:
# visualize predictions for the first test batch (correct and incorrect alike)
dataiter = iter(test_loader)
images, labels = next(dataiter)
images, labels = images.to(device), labels.to(device)

model.eval()
with torch.no_grad():
    outputs = model(images)
    _, predicted = outputs.max(1)

# Reuse the shared plotting helper instead of repeating its grid code here;
# CPU copies are passed because the helper converts the images to numpy.
display_emnist_sample_images(images.cpu(), labels.cpu(), predicted.cpu())

In [None]:
# visualize misclassified images (up to 12 of those collected during evaluation)
fig, axes = plt.subplots(2, 6, figsize=(10, 5))
emnist_classes = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"

for i, ax in enumerate(axes.flat):
    ax.axis("off")
    if i >= len(misclassified):
        continue  # fewer than 12 mistakes: leave the remaining cells blank

    img, pred, true = misclassified[i]
    img = img.cpu().numpy().squeeze()
    # rot90(k=3) followed by fliplr is a transpose, as in the other display cells
    img = np.rot90(img, k=3)
    img = np.fliplr(img)

    ax.imshow(img, cmap="gray")
    ax.set_title(f"Pred: {emnist_classes[int(pred)]}\nTrue: {emnist_classes[int(true)]}")

plt.show()

In [None]:
# test on hand-written images from the internet: show the raw file first,
# then classify it straight from its path (predict preprocesses it internally)
image_path = "../images/S.png"
img = Image.open(image_path)

ax = plt.gca()
ax.imshow(img, cmap="gray")
ax.axis("off")
plt.show()

prediction = predict(model, image_path)
print(f"prediction result: {prediction}")

In [None]:
# process the image and predict again: run the preprocessing explicitly,
# inspect the resulting tensor, then classify that tensor
processed_image = image_tensor = preprocess_image(image_path, mode="test")
print("tensor shape: ", image_tensor.shape)

prediction = predict(model, processed_image)
print(f"prediction result: {prediction}")