# 1: Install Required Packages

Seaborn is used to plot the confusion matrix, so we make sure it is installed (`%pip install` simply reports "Requirement already satisfied" if it is already present). The other libraries this notebook imports — torch, torchvision, matplotlib, numpy, scikit-learn, ipywidgets and Pillow (PIL) — are typically preinstalled in Colab.

In [None]:
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install seaborn

# 2: Import Libraries

Import all required libraries. `tkinter` is no longer imported, because the desktop GUI has been replaced by an ipywidgets-based interface that works in Colab.

In [11]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import widgets, Layout
from IPython.display import display
from PIL import Image
from datetime import datetime
import time
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import confusion_matrix
import seaborn as sns


AttributeError: partially initialized module 'torch._inductor' has no attribute 'custom_graph_pass' (most likely due to a circular import)

# 3: Define the ImprovedCNN Class

The ImprovedCNN class remains unchanged as it defines the model architecture and doesn't rely on Tkinter.

In [5]:
class ImprovedCNN(nn.Module):
    """Small residual CNN mapping 1-channel images to 10 class logits.

    Layout:
      stem    : conv(1->32) + BN + ReLU, then 2x2 max-pool
      stage 1 : conv(32->64)+BN+ReLU -> conv(64->64)+BN, plus a shortcut whose
                channels are zero-padded to match, ReLU, then 2x2 max-pool
      stage 2 : conv(64->128)+BN+ReLU -> conv(128->128)+BN, same kind of
                zero-padded shortcut, ReLU
      head    : global average pool -> fc(128->64)+BN+ReLU -> dropout(0.25)
                -> fc(64->10)

    The output is raw logits (no softmax is applied).
    """

    def __init__(self):
        super().__init__()
        # NOTE: attribute names become the state_dict keys of saved checkpoints,
        # and creation order fixes the parameter-initialisation order. Keep both.

        # Stem
        self.conv1 = nn.Conv2d(1, 32, 3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(32)

        # Stage 1
        self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 64, 3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(64)

        # Stage 2
        self.conv4 = nn.Conv2d(64, 128, 3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(128)
        self.conv5 = nn.Conv2d(128, 128, 3, stride=1, padding=1)
        self.bn5 = nn.BatchNorm2d(128)

        # Head
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(128, 64)
        self.bn6 = nn.BatchNorm1d(64)
        self.fc2 = nn.Linear(64, 10)
        self.dropout = nn.Dropout(0.25)

    @staticmethod
    def _widen(shortcut, target_channels):
        """Zero-pad the channel dimension of ``shortcut`` to ``target_channels``."""
        missing = target_channels - shortcut.size(1)
        if missing == 0:
            return shortcut
        # F.pad lists trailing dims first: (W_left, W_right, H_top, H_bottom, C_front, C_back).
        return F.pad(shortcut, (0, 0, 0, 0, 0, missing))

    def _residual(self, x, conv_a, bn_a, conv_b, bn_b):
        """Apply conv-BN-ReLU then conv-BN, add the (widened) input, then ReLU."""
        branch = F.relu(bn_a(conv_a(x)))
        branch = bn_b(conv_b(branch))
        return F.relu(branch + self._widen(x, branch.size(1)))

    def forward(self, x):
        stem = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), 2)
        stage1 = F.max_pool2d(
            self._residual(stem, self.conv2, self.bn2, self.conv3, self.bn3), 2
        )
        stage2 = self._residual(stage1, self.conv4, self.bn4, self.conv5, self.bn5)

        pooled = self.gap(stage2).view(-1, 128)
        hidden = self.dropout(F.relu(self.bn6(self.fc1(pooled))))
        return self.fc2(hidden)

# 4: Define the DigitRecognizer Class

We'll modify the DigitRecognizer class by removing the create_drawing_interface method (which uses Tkinter) and adding a new create_colab_drawing_interface method for the Colab-compatible interface. The rest of the class remains largely the same.

In [6]:
class DigitRecognizer:
    """Train, evaluate, save/load and query an ``ImprovedCNN`` MNIST digit classifier.

    Data loaders are created lazily: any method that needs them calls
    ``prepare_data()`` with its default arguments if it has not run yet.
    """

    # Normalisation constants shared by the train and test transforms.
    NORMALIZE_MEAN = (0.1307,)
    NORMALIZE_STD = (0.3081,)

    def __init__(self, model_path=None):
        """Set up the device, the data transforms and the model.

        Args:
            model_path: optional path to a checkpoint written by ``save_model``.
                It is loaded only if the file exists; otherwise the freshly
                initialised model is kept.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Training-only augmentation: small random rotations and shifts.
        self.train_transform = transforms.Compose([
            transforms.RandomRotation(10),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
            transforms.ToTensor(),
            transforms.Normalize(self.NORMALIZE_MEAN, self.NORMALIZE_STD),
        ])

        self.test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(self.NORMALIZE_MEAN, self.NORMALIZE_STD),
        ])

        self.model = self.create_model()

        if model_path and os.path.exists(model_path):
            self.load_model(model_path)

    def create_model(self):
        """Return a new ``ImprovedCNN`` placed on ``self.device``."""
        return ImprovedCNN().to(self.device)

    def prepare_data(self, batch_size=128, num_workers=2):
        """Download MNIST into ``./data`` and build the train/test loaders.

        Args:
            batch_size: samples per batch for both loaders.
            num_workers: loader worker processes. The default of 2 was chosen
                for Colab compatibility.

        Returns:
            dict with ``'train'`` and ``'test'`` DataLoaders. The dict is also
            stored on ``self.loaders``, and the raw test dataset on
            ``self.test_data``.
        """
        train_data = datasets.MNIST(
            root="data",
            train=True,
            download=True,
            transform=self.train_transform,
        )
        test_data = datasets.MNIST(
            root="data",
            train=False,
            download=True,
            transform=self.test_transform,
        )

        # Pinned host memory only speeds up host-to-GPU copies; there is no
        # point requesting it when running on the CPU.
        loader_kwargs = dict(
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=self.device.type == "cuda",
        )

        self.loaders = {
            'train': DataLoader(train_data, shuffle=True, **loader_kwargs),
            'test': DataLoader(test_data, shuffle=False, **loader_kwargs),
        }
        self.test_data = test_data
        return self.loaders

    def _ensure_data(self):
        """Build the data loaders with default settings if ``prepare_data`` has not run yet."""
        if not hasattr(self, 'loaders'):
            self.prepare_data()

    def train_model(self, epochs=10, lr=0.001, weight_decay=1e-5, patience=3):
        """Train with Adam, LR-on-plateau scheduling and early stopping.

        Whenever test accuracy improves, the model is saved with
        ``save_model(best=True)``. After training, the final weights are saved
        under a timestamped name, and the loss/accuracy curves and the
        confusion matrix are plotted.

        Args:
            epochs: maximum number of epochs.
            lr: initial Adam learning rate.
            weight_decay: Adam weight decay.
            patience: stop after this many consecutive epochs without a new
                best test accuracy.

        Returns:
            (train_losses, test_accuracies): per-epoch mean training loss and
            test accuracy in percent.
        """
        self._ensure_data()

        optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay=weight_decay)
        # ``verbose=`` is deprecated (and rejected by newer PyTorch releases), so
        # learning-rate reductions are reported manually below.
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)
        criterion = nn.CrossEntropyLoss()

        train_loader = self.loaders['train']
        num_batches = len(train_loader)
        num_samples = len(train_loader.dataset)

        best_accuracy = 0
        patience_counter = 0
        train_losses = []
        test_accuracies = []

        start_time = time.time()

        for epoch in range(1, epochs + 1):
            epoch_start = time.time()

            self.model.train()
            total_loss = 0.0
            correct = 0
            total = 0

            for batch_idx, (data, target) in enumerate(train_loader):
                data, target = data.to(self.device), target.to(self.device)
                optimizer.zero_grad()
                output = self.model(data)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                total += target.size(0)
                correct += (output.argmax(dim=1) == target).sum().item()

                if batch_idx % 50 == 0:
                    print(f"Train Epoch: {epoch} [{batch_idx * len(data)}/{num_samples} "
                          f"({100. * batch_idx / num_batches:.0f}%)]\tLoss: {loss.item():.6f}")

            avg_loss = total_loss / num_batches
            train_accuracy = 100. * correct / total
            train_losses.append(avg_loss)

            test_accuracy = self.evaluate_model()
            test_accuracies.append(test_accuracy)

            # The scheduler monitors the mean training loss (mode='min').
            previous_lr = optimizer.param_groups[0]['lr']
            scheduler.step(avg_loss)
            current_lr = optimizer.param_groups[0]['lr']
            if current_lr < previous_lr:
                print(f"Learning rate reduced from {previous_lr:.2e} to {current_lr:.2e}")

            # Report the epoch before a possible early stop, so the last epoch is logged too.
            epoch_time = time.time() - epoch_start
            print(f"Epoch {epoch} completed in {epoch_time:.2f}s. Train Accuracy: {train_accuracy:.2f}%, Test Accuracy: {test_accuracy:.2f}%")

            if test_accuracy > best_accuracy:
                best_accuracy = test_accuracy
                patience_counter = 0
                self.save_model(best=True)
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print(f"Early stopping triggered after {epoch} epochs!")
                    break

        total_time = time.time() - start_time
        print(f"Training completed in {total_time:.2f} seconds.")
        print(f"Best test accuracy: {best_accuracy:.2f}%")

        self.save_model()
        self.visualize_training(train_losses, test_accuracies)
        self.generate_confusion_matrix()

        return train_losses, test_accuracies

    def evaluate_model(self):
        """Return the test-set accuracy in percent; also prints it."""
        self._ensure_data()
        self.model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for data, target in self.loaders['test']:
                data, target = data.to(self.device), target.to(self.device)
                predicted = self.model(data).argmax(dim=1)
                total += target.size(0)
                correct += (predicted == target).sum().item()

        accuracy = 100. * correct / total
        print(f"Test set: Accuracy: {correct}/{total} ({accuracy:.2f}%)")
        return accuracy

    def save_model(self, best=False):
        """Save the weights plus architecture name and timestamp under ``saved_models/``.

        Args:
            best: if True, write ``digit_model_best.pth`` (overwriting any
                previous one); otherwise use a timestamped file name.

        Returns:
            The path that was written.
        """
        directory = "saved_models"
        os.makedirs(directory, exist_ok=True)

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = "digit_model_best.pth" if best else f"digit_model_{timestamp}.pth"
        path = os.path.join(directory, filename)

        torch.save({
            'model_state_dict': self.model.state_dict(),
            'model_architecture': 'ImprovedCNN',
            'timestamp': timestamp,
        }, path)

        print(f"Model saved to {path}")
        return path

    def load_model(self, model_path):
        """Load weights written by ``save_model`` into a fresh ``ImprovedCNN``.

        Returns:
            True on success, False if ``model_path`` does not exist.
        """
        if not os.path.exists(model_path):
            print(f"Model file not found: {model_path}")
            return False

        # NOTE: torch.load is pickle-based; only load checkpoints from trusted sources.
        checkpoint = torch.load(model_path, map_location=self.device)
        if checkpoint.get('model_architecture') != 'ImprovedCNN':
            print("Warning: Unknown model architecture, using default ImprovedCNN")
        self.model = self.create_model()
        self.model.load_state_dict(checkpoint['model_state_dict'])
        print(f"Model loaded from {model_path}")
        return True

    def visualize_training(self, train_losses, test_accuracies):
        """Plot per-epoch training loss and test accuracy; save to plots/training_progress.png."""
        os.makedirs("plots", exist_ok=True)
        plt.figure(figsize=(12, 5))

        # Epochs are numbered from 1 in the training log, so label the x-axis the same way.
        plt.subplot(1, 2, 1)
        plt.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss')
        plt.title('Training Loss Over Epochs')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()

        plt.subplot(1, 2, 2)
        plt.plot(range(1, len(test_accuracies) + 1), test_accuracies, label='Test Accuracy')
        plt.title('Test Accuracy Over Epochs')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy (%)')
        plt.legend()

        plt.tight_layout()
        plt.savefig("plots/training_progress.png")
        plt.show()

    def generate_confusion_matrix(self):
        """Plot the test-set confusion matrix and save it to plots/confusion_matrix.png.

        Returns:
            The 10x10 confusion matrix as a NumPy array (rows: actual digit,
            columns: predicted digit).
        """
        self._ensure_data()
        self.model.eval()
        all_preds = []
        all_targets = []

        with torch.no_grad():
            for data, target in self.loaders['test']:
                output = self.model(data.to(self.device))
                all_preds.extend(output.argmax(dim=1).cpu().tolist())
                all_targets.extend(target.tolist())

        # Fixed labels keep the matrix 10x10 even if some digit never occurs or is never predicted.
        cm = confusion_matrix(all_targets, all_preds, labels=list(range(10)))

        # This method may run without visualize_training having created the folder.
        os.makedirs("plots", exist_ok=True)
        plt.figure(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
        plt.xlabel('Predicted')
        plt.ylabel('Actual')
        plt.title('Confusion Matrix')
        plt.savefig("plots/confusion_matrix.png")
        plt.show()
        return cm

    def predict_digit(self, image_tensor):
        """Classify a single, already-normalised image.

        Args:
            image_tensor: one image as an ``(H, W)``, ``(1, H, W)`` or
                ``(1, 1, H, W)`` tensor, normalised like ``test_transform``
                output. The model was trained on 28x28 images.

        Returns:
            ``(digit, confidence)``: the predicted class index and its softmax
            probability.
        """
        self.model.eval()

        # Add any missing channel/batch dimensions so the model always sees (N, C, H, W).
        while image_tensor.dim() < 4:
            image_tensor = image_tensor.unsqueeze(0)

        with torch.no_grad():
            output = self.model(image_tensor.to(self.device))
            probabilities = F.softmax(output, dim=1)
            probability, prediction = torch.max(probabilities, 1)

        return prediction.item(), probability.item()

    def visualize_prediction(self, index=0):
        """Show test-set sample ``index`` with its predicted and true label."""
        self._ensure_data()
        image, label = self.test_data[index]

        prediction, probability = self.predict_digit(image)

        plt.figure(figsize=(6, 6))
        plt.imshow(image.squeeze(0), cmap='gray')
        plt.title(f"Prediction: {prediction} (Confidence: {probability:.2f})\nActual: {label}")
        plt.axis('off')
        plt.show()

    def create_colab_drawing_interface(self):
        """Open the interactive drawing canvas and return the interface object."""
        # Keep a strong reference: matplotlib holds bound-method callbacks
        # registered with mpl_connect only weakly.
        self.drawing_interface = ColabDrawingInterface(self)
        return self.drawing_interface

# 5: Define the ColabDrawingInterface Class

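This new class replaces the Tkinter-based DrawingInterface with a Colab-compatible version that uses matplotlib for the drawing canvas and ipywidgets for the buttons. Users draw a digit on a 280x280 canvas, which is resized to 28x28 (the MNIST image size) for prediction. Note that matplotlib only delivers mouse events with an interactive backend (for example ipympl's `%matplotlib widget`). With the default static inline backend, the canvas cannot be drawn on.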

In [7]:
class ColabDrawingInterface:
    """Notebook canvas for drawing a digit with the mouse and classifying it.

    The drawing is kept in a 280x280 float32 array: 0 is the black background
    and 1 is a white stroke. "Predict" downsamples it to 28x28, normalises it
    with the same constants as ``DigitRecognizer.test_transform`` and shows
    the result in the plot title. "Clear" resets the canvas.

    NOTE: matplotlib only delivers mouse events with an interactive backend
    (for example ipympl's ``%matplotlib widget``). With a static inline
    backend the figure is just an image and nothing can be drawn.
    """

    def __init__(self, digit_recognizer, brush_radius=10):
        """Create the figure, connect the mouse handlers and display the buttons.

        Args:
            digit_recognizer: object with a ``predict_digit(tensor)`` method
                returning ``(digit, confidence)``.
            brush_radius: radius of the round brush, in canvas pixels. Strokes
                are about ``2 * brush_radius + 1`` px wide, i.e. about a tenth
                of that after the 10x downsampling to 28x28. A very thin brush
                shrinks to a faint, sub-pixel line.
        """
        self.digit_recognizer = digit_recognizer
        self.brush_radius = brush_radius
        self.img = np.zeros((280, 280), dtype=np.float32)  # Black background
        self.drawing = False
        self._last_point = None  # previous (x, y) of the current stroke; None between strokes

        # Set up matplotlib figure
        self.fig, self.ax = plt.subplots(figsize=(5, 5))
        self.im = self.ax.imshow(self.img, cmap='gray', vmin=0, vmax=1)
        self.ax.set_title('Draw a digit')
        self.ax.axis('off')  # Hide axes for a cleaner look

        # Connect mouse events
        self.cid_press = self.fig.canvas.mpl_connect('button_press_event', self.on_press)
        self.cid_release = self.fig.canvas.mpl_connect('button_release_event', self.on_release)
        self.cid_motion = self.fig.canvas.mpl_connect('motion_notify_event', self.on_motion)

        # Create interactive buttons
        self.predict_button = widgets.Button(description='Predict')
        self.clear_button = widgets.Button(description='Clear')
        self.predict_button.on_click(self.predict)
        self.clear_button.on_click(self.clear)

        display(widgets.HBox([self.predict_button, self.clear_button]))
        plt.show()

    def _event_to_pixel(self, event):
        """Return the canvas (x, y) under a mouse event, or None if it is outside the axes."""
        if event.xdata is None or event.ydata is None:
            return None
        height, width = self.img.shape
        # Clamp to the array, since imshow's data extent reaches half a pixel past each edge.
        x = min(max(int(round(event.xdata)), 0), width - 1)
        y = min(max(int(round(event.ydata)), 0), height - 1)
        return x, y

    def _stamp(self, x, y):
        """Paint a filled disc of radius ``brush_radius`` centred on (x, y)."""
        r = self.brush_radius
        height, width = self.img.shape
        top, bottom = max(y - r, 0), min(y + r + 1, height)
        left, right = max(x - r, 0), min(x + r + 1, width)
        rows, cols = np.ogrid[top:bottom, left:right]
        disc = (cols - x) ** 2 + (rows - y) ** 2 <= r * r
        # The slice is a view, so the masked assignment writes into self.img.
        self.img[top:bottom, left:right][disc] = 1.0

    def _paint_segment(self, start, end):
        """Stamp the brush along the line start -> end, so fast strokes leave no gaps."""
        (x0, y0), (x1, y1) = start, end
        steps = max(abs(x1 - x0), abs(y1 - y0), 1)
        for t in np.linspace(0.0, 1.0, steps + 1):
            self._stamp(int(round(x0 + (x1 - x0) * t)), int(round(y0 + (y1 - y0) * t)))

    def _refresh(self):
        """Push the canvas array to the image artist and schedule a redraw."""
        self.im.set_data(self.img)
        self.fig.canvas.draw_idle()

    def on_press(self, event):
        """Start a stroke and paint a dot at the press position if it is on the canvas."""
        self.drawing = True
        self._last_point = self._event_to_pixel(event)
        if self._last_point is not None:
            self._stamp(*self._last_point)
            self._refresh()

    def on_release(self, event):
        """End the current stroke."""
        self.drawing = False
        self._last_point = None

    def on_motion(self, event):
        """While the button is held, paint from the previous position to the current one."""
        if not self.drawing:
            return
        point = self._event_to_pixel(event)
        if point is None:
            # The pointer left the canvas: don't join the exit and re-entry points.
            self._last_point = None
            return
        self._paint_segment(self._last_point or point, point)
        self._last_point = point
        self._refresh()

    def clear(self, b):
        """Button callback: erase the canvas and reset the title."""
        self.img.fill(0)
        self._last_point = None
        self.ax.set_title('Draw a digit')
        self._refresh()

    def predict(self, b):
        """Button callback: classify the drawing and show the result in the title."""
        if not self.img.any():
            self.ax.set_title('Canvas is empty - draw a digit first')
            self.fig.canvas.draw_idle()
            return

        # Downsample the 280x280 drawing to 28x28 and rescale to [0, 1].
        pil_img = Image.fromarray((self.img * 255).astype(np.uint8))
        # Image.Resampling exists in Pillow >= 9.1; older versions expose the constant on Image.
        resample = getattr(Image, 'Resampling', Image).BILINEAR
        small_img = np.asarray(pil_img.resize((28, 28), resample), dtype=np.float32) / 255.0

        # Shape (1, 28, 28), like a test-set sample. predict_digit adds the batch dimension.
        image_tensor = torch.from_numpy(small_img).unsqueeze(0)
        # Apply the same normalisation as DigitRecognizer.test_transform.
        image_tensor = transforms.Normalize((0.1307,), (0.3081,))(image_tensor)

        prediction, confidence = self.digit_recognizer.predict_digit(image_tensor)

        self.ax.set_title(f'Prediction: {prediction} ({confidence*100:.2f}%)')
        self.fig.canvas.draw_idle()

# 6: Define and Run the main Function

The main function is updated to use the new Colab interface. We'll run it here to execute the entire program.

In [8]:
import os

def main(epochs=5, num_examples=5):
    """Load (or train) a digit model, show some test predictions, then open the drawing canvas.

    Args:
        epochs: number of training epochs, used only when no saved best model exists.
        num_examples: number of test-set predictions to visualise.

    Returns:
        The ``DigitRecognizer`` that was used, so the notebook can keep working with it.
    """
    best_model_path = os.path.join("saved_models", "digit_model_best.pth")
    model_path = best_model_path if os.path.exists(best_model_path) else None
    recognizer = DigitRecognizer(model_path)

    if model_path is None:
        print("No saved model found. Training a new model...")
        recognizer.prepare_data()
        recognizer.train_model(epochs=epochs)  # Adjust epochs as needed
    else:
        print(f"Using saved model: {model_path}")
        recognizer.prepare_data()

    print("\nVisualizing Predictions:")
    for i in range(num_examples):
        recognizer.visualize_prediction(i)

    print("\nLaunching Drawing Interface...")
    recognizer.create_colab_drawing_interface()
    return recognizer


# In a notebook kernel __name__ is "__main__", so this still runs when the cell executes.
if __name__ == "__main__":
    recognizer = main()

Using device: cuda


NameError: name 'transforms' is not defined