# In [3]:
import functools
from typing import Callable, Tuple, List

import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
from sklearn.preprocessing import OneHotEncoder

import auto_diff as ad
import torch
from torchvision import datasets, transforms

# Sequence length fed to the transformer: MNIST images are 28x28 and each
# image row is used as one token.
max_len: int = 28

def transformer(X: ad.Node, nodes: List[ad.Node],
                model_dim: int, seq_length: int, eps, batch_size, num_classes) -> ad.Node:
    """
    Assemble the graph of a one-layer, single-head transformer classifier.

    The input tokens are projected to queries, keys and values and mixed with
    scaled dot-product self-attention. The result is projected by ``W_O``,
    mean-pooled over the token axis and finally passed through a two-layer
    ReLU MLP that produces the logits.

    Parameters
    ----------
    X : ad.Node
        Input of shape (batch_size, seq_length, input_dim); input_dim is 28
        when MNIST image rows are used as tokens.
    nodes : List[ad.Node]
        Parameter nodes, consumed by position:
            nodes[0]: W_Q  [input_dim, model_dim]
            nodes[1]: W_K  [input_dim, model_dim]
            nodes[2]: W_V  [input_dim, model_dim]
            nodes[3]: W_O  [model_dim, model_dim]  (always applied)
            nodes[4]: W_1  [model_dim, model_dim]
            nodes[5]: W_2  [model_dim, num_classes]
            nodes[6]: b_1  [model_dim]
            nodes[7]: b_2  [num_classes]
    model_dim : int
        Projection width. Attention scores are divided by sqrt(model_dim).
    seq_length : int
        Number of tokens; used as the divisor of the mean-pool.
    eps : float
        Unused; accepted so all hyper-parameters can be passed uniformly.
    batch_size : int
        Unused by the graph construction.
    num_classes : int
        Unused by the graph construction; the output width comes from W_2/b_2.

    Returns
    -------
    ad.Node
        Logits of shape (batch_size, num_classes).
    """
    w_q, w_k, w_v, w_o = nodes[0], nodes[1], nodes[2], nodes[3]
    w_1, w_2, b_1, b_2 = nodes[4], nodes[5], nodes[6], nodes[7]

    # Scaled dot-product self-attention over the token axis.
    queries = ad.matmul(X, w_q)
    keys = ad.matmul(X, w_k)
    values = ad.matmul(X, w_v)
    keys_t = ad.transpose(keys, 1, 2)  # swap dims 1 and 2: [batch, model_dim, seq_length]
    attention = ad.softmax(
        ad.div_by_const(ad.matmul(queries, keys_t), model_dim ** 0.5),
        dim=-1,
    )  # [batch, seq_length, seq_length]
    mixed = ad.matmul(ad.matmul(attention, values), w_o)  # [batch, seq_length, model_dim]

    # Mean over the token axis -> [batch, model_dim].
    sequence_mean = ad.div_by_const(ad.sum_op(mixed, dim=1, keepdim=False), seq_length)

    # Two-layer MLP head: ReLU hidden layer, then linear logits.
    hidden = ad.relu(ad.add(ad.matmul(sequence_mean, w_1), b_1))
    return ad.add(ad.matmul(hidden, w_2), b_2)

def softmax_loss(Z: ad.Node, y_one_hot: ad.Node, batch_size: int) -> ad.Node:
    """
    Build the mean cross-entropy between softmax(Z) and one-hot targets.

        loss = -(1 / batch_size) * sum_i sum_c y[i, c] * log(softmax(Z, dim=1)[i, c])

    NOTE(review): the divisor is the ``batch_size`` fixed at graph-build time,
    not the runtime number of rows. A shorter final mini-batch would
    therefore presumably get a proportionally smaller loss.
    NOTE(review): ``log`` is applied to the softmax output directly; if a
    probability underflows to 0 this yields log(0).
    """
    log_likelihood = ad.mul(y_one_hot, ad.log(ad.softmax(Z, dim=1)))
    per_example = ad.sum_op(log_likelihood, dim=1, keepdim=False)   # sum over classes
    batch_total = ad.sum_op(per_example, dim=0, keepdim=False)      # sum over the batch
    return ad.div_by_const(ad.mul_by_const(batch_total, -1), batch_size)

def sgd_epoch(
    f_run_model: Callable,
    X: torch.Tensor,
    y: torch.Tensor,
    model_weights: List[torch.Tensor],
    batch_size: int,
    lr: float,
) -> Tuple[List[torch.Tensor], float]:
    """
    Run one epoch of plain mini-batch SGD over ``X``/``y`` in their given order.

    Parameters
    ----------
    f_run_model : Callable
        Called as ``f_run_model(model_weights, X_batch, y_batch)``. It must
        return a sequence ``[logits, loss, grad_0, grad_1, ...]`` with one
        gradient per entry of ``model_weights``, in the same order.
    X, y : torch.Tensor
        Inputs and targets, sliced along dim 0 into consecutive mini-batches.
    model_weights : List[torch.Tensor]
        Current parameters. The list is updated in place and also returned.
    batch_size : int
        Examples per mini-batch; the final batch may be smaller.
    lr : float
        Learning rate.

    Returns
    -------
    (model_weights, average_loss)
        The updated weights and the per-example average loss, with each
        batch's loss weighted by its size. When ``X`` is empty, no update is
        made and ``average_loss`` is 0.0.

    Raises
    ------
    ValueError
        If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    num_examples = X.shape[0]
    total_loss = 0.0

    # Step through batch start offsets; the final slice is clamped by min().
    for start_idx in range(0, num_examples, batch_size):
        end_idx = min(start_idx + batch_size, num_examples)
        outputs = f_run_model(model_weights, X[start_idx:end_idx], y[start_idx:end_idx])
        loss_val, grads = outputs[1], outputs[2:]

        # Gradient step for every parameter (in place, as callers may rely on it).
        for j, weight in enumerate(model_weights):
            model_weights[j] = weight - lr * grads[j]

        # Weight each batch's loss by its size so the epoch average is
        # per-example even when the final batch is short. float() accepts
        # 0-d tensors, NumPy scalars and plain Python numbers alike.
        total_loss += float(loss_val) * (end_idx - start_idx)

    average_loss = total_loss / num_examples if num_examples else 0.0
    print('Avg_loss:', average_loss)
    return model_weights, average_loss

def _init_model_weights(input_dim: int, model_dim: int, num_classes: int) -> List[torch.Tensor]:
    """Sample initial parameters ordered [W_Q, W_K, W_V, W_O, W_1, W_2, b_1, b_2].

    Seeds NumPy's global RNG with 0, draws every entry from
    U(-1/sqrt(num_classes), 1/sqrt(num_classes)) and returns float32 tensors.
    """
    np.random.seed(0)
    stdv = 1.0 / np.sqrt(num_classes)
    # This order must match `param_nodes` in `train_model`. It also fixes the
    # sequence of draws from the seeded RNG, so reordering changes the init.
    shapes = [
        (input_dim, model_dim),    # W_Q
        (input_dim, model_dim),    # W_K
        (input_dim, model_dim),    # W_V
        (model_dim, model_dim),    # W_O
        (model_dim, model_dim),    # W_1
        (model_dim, num_classes),  # W_2
        (model_dim,),              # b_1
        (num_classes,),            # b_2
    ]
    return [
        torch.tensor(np.random.uniform(-stdv, stdv, shape), dtype=torch.float32)
        for shape in shapes
    ]


def _load_mnist(root: str = "./data"):
    """Download MNIST if needed and load it into tensors ready for training.

    The raw ``.data`` arrays are read directly, so no torchvision transform
    takes part; pixel values are divided by 255.

    Returns
    -------
    X_train, y_train_oh, X_test, y_test
        float32 image tensors of shape (N, 28, 28), and a float32 one-hot
        training-label tensor with one column per distinct training label.
        The test labels are returned as a NumPy array, compared directly
        against predictions.
    """
    train_dataset = datasets.MNIST(root=root, train=True, download=True)
    test_dataset = datasets.MNIST(root=root, train=False, download=True)

    X_train = train_dataset.data.numpy().reshape(-1, 28, 28) / 255.0
    y_train = train_dataset.targets.numpy()
    X_test = test_dataset.data.numpy().reshape(-1, 28, 28) / 255.0
    y_test = test_dataset.targets.numpy()

    encoder = OneHotEncoder(sparse_output=False)
    y_train_oh = encoder.fit_transform(y_train.reshape(-1, 1))

    return (
        torch.tensor(X_train, dtype=torch.float32),
        torch.tensor(y_train_oh, dtype=torch.float32),
        torch.tensor(X_test, dtype=torch.float32),
        y_test,
    )


def train_model():
    """
    Train a single-layer transformer (ViT-style) on MNIST.

    Each image is fed as a sequence of ``max_len`` rows of 28 pixels. The
    function runs ``num_epochs`` epochs of mini-batch SGD and prints the test
    accuracy after every epoch. It returns the test accuracy of the final
    weights.
    """
    # --- Hyperparameters ---
    input_dim = 28        # Each row of an MNIST image has 28 pixels.
    seq_length = max_len  # There are 28 rows per image.
    num_classes = 10
    model_dim = 128
    eps = 1e-5            # Passed through to `transformer`, which ignores it.

    num_epochs = 20
    batch_size = 50
    lr = 0.02

    # --- Build the Computational Graph ---
    X_node = ad.Variable("X")
    y_groundtruth = ad.Variable("y")

    W_Q = ad.Variable("W_Q")
    W_K = ad.Variable("W_K")
    W_V = ad.Variable("W_V")
    W_O = ad.Variable("W_O")
    W_1 = ad.Variable("W_1")
    W_2 = ad.Variable("W_2")
    b_1 = ad.Variable("b_1")
    b_2 = ad.Variable("b_2")
    param_nodes = [W_Q, W_K, W_V, W_O, W_1, W_2, b_1, b_2]

    y_predict = transformer(X_node, param_nodes, model_dim, seq_length, eps, batch_size, num_classes)
    # NOTE(review): the loss divides by the fixed `batch_size`, so a short final
    # training batch would be under-weighted. Confirm the training-set size is
    # a multiple of batch_size if either changes.
    loss = softmax_loss(y_predict, y_groundtruth, batch_size)

    # --- Initialize Model Weights ---
    model_weights = _init_model_weights(input_dim, model_dim, num_classes)

    def param_feed(weights):
        # Map each parameter node to its value; `weights` is ordered like `param_nodes`.
        return dict(zip(param_nodes, weights))

    # --- Dummy Forward Pass to Set Node Shapes ---
    # Evaluating the loss once on a batch of 1 sets each node's "shape"
    # attribute. ad.gradients presumably relies on those shapes, so keep this
    # pass before it.
    temp_eval = ad.Evaluator([loss])
    temp_eval.run({
        X_node: torch.zeros(1, seq_length, input_dim),
        y_groundtruth: torch.zeros(1, num_classes),
        **param_feed(model_weights),
    })

    # --- Compute Gradients and Evaluators ---
    grads = ad.gradients(loss, param_nodes)
    evaluator = ad.Evaluator([y_predict, loss] + grads)
    test_evaluator = ad.Evaluator([y_predict])

    # --- Load Data ---
    X_train, y_train_oh, X_test, y_test = _load_mnist()

    def f_run_model(model_weights, X_batch, y_batch):
        # Returns [logits, loss, *grads], with grads ordered like param_nodes.
        return evaluator.run({X_node: X_batch, y_groundtruth: y_batch, **param_feed(model_weights)})

    def f_eval_model(X_val, model_weights: List[torch.Tensor]):
        """Return argmax predictions for ``X_val``, evaluated ``batch_size`` rows at a time."""
        weights_feed = param_feed(model_weights)  # identical for every batch
        batch_logits = []
        for start_idx in range(0, X_val.shape[0], batch_size):
            logits = test_evaluator.run(
                {X_node: X_val[start_idx:start_idx + batch_size], **weights_feed}
            )[0]
            # The evaluator may return torch tensors or NumPy arrays.
            if isinstance(logits, torch.Tensor):
                logits = logits.detach().numpy()
            batch_logits.append(logits)
        return np.argmax(np.concatenate(batch_logits, axis=0), axis=1)

    # --- Training Loop ---
    final_accuracy = None
    for epoch in range(num_epochs):
        X_train, y_train_oh = shuffle(X_train, y_train_oh)
        model_weights, loss_val = sgd_epoch(f_run_model, X_train, y_train_oh, model_weights, batch_size, lr)
        pred_labels = f_eval_model(X_test, model_weights)
        final_accuracy = np.mean(pred_labels == y_test)
        print(f"Epoch {epoch}: test accuracy = {final_accuracy}, loss = {loss_val}")

    # The last epoch already evaluated the final weights on the test set. Only
    # evaluate here when no epoch ran.
    if final_accuracy is None:
        final_accuracy = np.mean(f_eval_model(X_test, model_weights) == y_test)
    return final_accuracy

if __name__ == "__main__":
    # Script entry point: announce start, train, then report the final accuracy.
    print("RUNNING")
    print(f"Final test accuracy: {train_model()}")


# Captured notebook cell output:
# RUNNING
#
# NameError: name 'ad' is not defined