#### Import necessary packages


In [1]:
# Standard library imports
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view
from typing import Tuple

# Third-party imports
import torch
from torch import nn
from torch.utils.data import DataLoader
from einops import einsum, rearrange
from torchvision.datasets import ImageNet
from torchvision.models import AlexNet_Weights

#### Load the weights and biases of AlexNet


In [2]:
# Pin the checkpoint explicitly so it matches the preprocessing used for the
# validation set (AlexNet_Weights.IMAGENET1K_V1.transforms()). For AlexNet the
# DEFAULT member aliases IMAGENET1K_V1, but naming it keeps the two in sync.
weights_and_biases = AlexNet_Weights.IMAGENET1K_V1.get_state_dict()
print(weights_and_biases.keys())

odict_keys(['features.0.weight', 'features.0.bias', 'features.3.weight', 'features.3.bias', 'features.6.weight', 'features.6.bias', 'features.8.weight', 'features.8.bias', 'features.10.weight', 'features.10.bias', 'classifier.1.weight', 'classifier.1.bias', 'classifier.4.weight', 'classifier.4.bias', 'classifier.6.weight', 'classifier.6.bias'])


#### Load the data


In [3]:
def default_collate(batch):
    """
    Collate ``(image, label)`` samples into a pair of NumPy arrays.

    The dataset transform yields torch tensors; they are converted to NumPy
    because the custom layers rely on NumPy operations such as ``np.pad``.

    Args:
        batch: Sequence of ``(image, label)`` pairs, where ``image`` is a CPU
            tensor or array of shape (3, H, W) and ``label`` is an int.

    Returns:
        tuple[np.ndarray, np.ndarray]: ``images`` of shape (B, 3, H, W) with
        the samples' dtype preserved, and ``labels`` of shape (B,) as int64.

    Raises:
        ValueError: If ``batch`` is empty.
    """
    if not batch:
        raise ValueError("default_collate received an empty batch")
    imgs, labels = zip(*batch)
    # np.asarray converts CPU tensors through __array__ without changing dtype.
    images = np.stack([np.asarray(img) for img in imgs], axis=0)
    label_array = np.asarray(labels, dtype=np.int64)
    return images, label_array


# ImageNet-1k validation split, preprocessed with the transforms bundled with
# the IMAGENET1K_V1 AlexNet weights.
imagenet_val = ImageNet(
    root="data/ImageNet1k",
    split="val",
    transform=AlexNet_Weights.IMAGENET1K_V1.transforms(),
)

# Each iteration yields an (images, labels) pair of NumPy arrays built by
# default_collate. pin_memory is deliberately off: it only pins torch tensors
# (and only when an accelerator is present), while these batches are NumPy
# arrays consumed on the CPU.
val_dataloader = DataLoader(
    imagenet_val,
    batch_size=1024,
    shuffle=False,
    num_workers=0,
    collate_fn=default_collate,
)

In [4]:
# Peek at the first sample of the first batch; both names stay None when the
# loader yields nothing.
first_batch = next(iter(val_dataloader), None)
if first_batch is None:
    images, labels = None, None
else:
    images, labels = first_batch[0][0], first_batch[1][0]
del first_batch



#### Define the custom layers (Conv2d, ReLU, MaxPool2d, AdaptiveAvgPool2d, Linear) and the model

Each custom layer inherits from `nn.Module` so that it plugs into PyTorch containers such as `nn.Sequential`, even though the computation itself is done with NumPy, `einops`, and `einsum`.


In [5]:
class PatchMixin:
    """
    Mixin that extracts strided ``k x k`` windows from (b, c, h, w) NumPy arrays.

    Used by the convolution and max-pooling layers, which differ only in how
    they reduce each window.
    """

    def __init__(self, kernel_size: int, stride: int) -> None:
        # Cooperative init: forwards to the next class in the MRO (e.g. nn.Module)
        # before storing attributes.
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride

    def _patch_with_stride(self, x_pad: np.ndarray) -> np.ndarray:
        """
        Extract ``kernel_size x kernel_size`` windows from the last two axes of
        ``x_pad``, keeping every ``stride``-th window along height and width.

        Args:
            x_pad (np.ndarray): Input array of shape (b, c, h, w), already
                padded if the caller needs padding.

        Returns:
            np.ndarray: Read-only view of shape
            (b, c, out_h, out_w, kernel_size, kernel_size), where
            out_h = (h - kernel_size) // stride + 1 (likewise for out_w).
        """
        # Every stride-1 window as a zero-copy view: (b, c, h-k+1, w-k+1, k, k).
        windows = sliding_window_view(
            x_pad,
            window_shape=(self.kernel_size, self.kernel_size),
            axis=(-2, -1),  # type: ignore
        )
        # Subsample the window grid to apply the stride (still a view).
        return windows[:, :, ::self.stride, ::self.stride, :, :]


class WeightsAndBiasMixin:
    """
    Mixin that fetches a layer's pretrained parameters from the module-level
    ``weights_and_biases`` state dict and hands them back as NumPy arrays.
    """

    def __init__(self, *args, **kwargs) -> None:
        # Forward every argument unchanged to the next class in the MRO.
        super().__init__(*args, **kwargs)

    def init_weights_and_bias(self, weight_loc: str, bias_loc: str) -> Tuple[np.ndarray, np.ndarray]:
        """
        Look up ``weight_loc`` and ``bias_loc`` in the state dict.

        Returns:
            Tuple[np.ndarray, np.ndarray]: ``(weight, bias)``, each detached
            from autograd and converted with ``Tensor.numpy()``.
        """
        def _to_numpy(key: str) -> np.ndarray:
            return weights_and_biases[key].detach().numpy()

        return _to_numpy(weight_loc), _to_numpy(bias_loc)


class CustomConv2d(WeightsAndBiasMixin, PatchMixin, nn.Module):
    """
    Inference-only 2D convolution implemented with NumPy, einops and einsum.

    Parameters are loaded from the pretrained state dict rather than learned.
    Only square kernels, symmetric zero padding and equal strides along height
    and width are supported.

    Args:
        in_channels: Number of input channels ``c``.
        out_channels: Number of output channels ``o``.
        kernel_size: Side length ``k`` of the square kernel.
        stride: Step between neighbouring windows (same for both axes).
        padding: Zeros added on each side of both spatial axes.
        weight_loc: State-dict key of the (o, c, k, k) weight tensor.
        bias_loc: State-dict key of the (o,) bias tensor.

    Raises:
        ValueError: If the loaded weight or bias shape does not match
            ``out_channels``, ``in_channels`` and ``kernel_size``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        weight_loc: str = '',
        bias_loc: str = '',
    ) -> None:
        # Cooperative init through the MRO: WeightsAndBiasMixin -> PatchMixin
        # (which stores kernel_size and stride) -> nn.Module.
        super().__init__(kernel_size, stride)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.padding = padding

        self.weight, self.bias = self.init_weights_and_bias(
            weight_loc, bias_loc)
        self._check_parameter_shapes()

        self.reset_parameters()

    def _check_parameter_shapes(self) -> None:
        """Fail fast if the loaded parameters disagree with the declared layer."""
        expected_weight = (self.out_channels, self.in_channels,
                           self.kernel_size, self.kernel_size)
        if tuple(self.weight.shape) != expected_weight:
            raise ValueError(
                f"weight has shape {tuple(self.weight.shape)}, "
                f"expected {expected_weight}"
            )
        if tuple(self.bias.shape) != (self.out_channels,):
            raise ValueError(
                f"bias has shape {tuple(self.bias.shape)}, "
                f"expected {(self.out_channels,)}"
            )

    def reset_parameters(self) -> None:
        """
        No-op: the layer is inference-only and its parameters come from the
        pretrained state dict.
        """
        pass

    def _apply_padding(self, x: np.ndarray) -> np.ndarray:
        """Zero-pad the two spatial axes of a (b, c, h, w) array by ``padding``."""
        return np.pad(
            x,
            pad_width=((0, 0), (0, 0), (self.padding, self.padding),
                       (self.padding, self.padding)),
            mode='constant',
            constant_values=0
        )

    def forward(self, x: np.ndarray) -> np.ndarray:
        """
        Args:
            x (np.ndarray): Input of shape (b, in_channels, h, w).

        Returns:
            np.ndarray: Output of shape (b, out_channels, out_h, out_w) with
            out_h = (h + 2 * padding - kernel_size) // stride + 1 (same for w).
        """
        x_pad = self._apply_padding(x)
        # (b, c, out_h, out_w, kh, kw): one k x k window per output position.
        patched_windows = self._patch_with_stride(x_pad)
        # Contract over input channels and both kernel axes.
        pre_activation = einsum(patched_windows, self.weight,
                                'b c h w kh kw, o c kh kw -> b o h w')
        # Broadcast the per-output-channel bias over batch and spatial axes.
        return pre_activation + self.bias[None, :, None, None]  # type: ignore

In [6]:
class CustomReLU(nn.Module):
    """
    Element-wise rectified linear unit on NumPy arrays.

    Negative entries become zero; everything else passes through unchanged.
    """

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Return ``max(x, 0)`` computed element-wise."""
        floor = 0.0
        rectified = np.maximum(x, floor)
        return rectified

In [7]:
class CustomMaxPool2d(PatchMixin, nn.Module):
    """
    Inference-only 2D max pooling over square, unpadded windows (NumPy).

    Args:
        kernel_size: Side length of each pooling window.
        stride: Step between neighbouring windows (same for both axes).
    """

    def __init__(self, kernel_size: int, stride: int) -> None:
        # PatchMixin records kernel_size/stride, then initialises nn.Module.
        super().__init__(kernel_size=kernel_size, stride=stride)

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Reduce every window to its maximum: (b, c, h, w) -> (b, c, out_h, out_w)."""
        windows = self._patch_with_stride(x)
        # Collapse the two trailing kernel axes.
        return windows.max(axis=(-2, -1))

In [8]:
class CustomAdaptiveAvgPool2d(nn.Module):
    """
    Custom Adaptive Average Pooling layer (NumPy-based, inference only).

    Follows ``torch.nn.AdaptiveAvgPool2d``. Output cell (i, j) averages input
    rows floor(i*h/out_h) .. ceil((i+1)*h/out_h) - 1 and the matching columns,
    so neighbouring regions overlap when h is not a multiple of out_h.

    Args:
        output_size: Desired (out_h, out_w). A single int gives a square
            output, and a ``None`` entry keeps that input dimension unchanged.
    """

    def __init__(self, output_size: "int | tuple[int | None, int | None]") -> None:
        super().__init__()
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            self.output_size = output_size

    @staticmethod
    def _region_bounds(index: int, in_size: int, out_size: int) -> tuple[int, int]:
        """Half-open [start, end) input range pooled into output position ``index``.

        Uses exact integer floor/ceil division rather than float floor/ceil.
        """
        start = (index * in_size) // out_size
        end = -((-(index + 1) * in_size) // out_size)
        return start, end

    def forward(self, x: np.ndarray) -> np.ndarray:
        """
        Args:
            x (np.ndarray): input of shape (b, c, h, w)
        Returns:
            np.ndarray: pooled output of shape (b, c, out_h, out_w)
        """
        b, c, h, w = x.shape
        out_h, out_w = self.output_size
        # None means "keep this input dimension", as in torch.nn.AdaptiveAvgPool2d.
        out_h = h if out_h is None else out_h
        out_w = w if out_w is None else out_w

        # Bounds depend only on the output index, so compute each axis once.
        row_bounds = [self._region_bounds(i, h, out_h) for i in range(out_h)]
        col_bounds = [self._region_bounds(j, w, out_w) for j in range(out_w)]

        out = np.zeros((b, c, out_h, out_w), dtype=x.dtype)
        for i, (h_start, h_end) in enumerate(row_bounds):
            for j, (w_start, w_end) in enumerate(col_bounds):
                # (b, c, h_slice, w_slice)
                region = x[:, :, h_start:h_end, w_start:w_end]
                out[:, :, i, j] = region.mean(axis=(-2, -1))
        return out

In [9]:
class EinopsLinear(WeightsAndBiasMixin, nn.Module):
    """
    Inference-only fully connected layer computing ``x @ W.T + b`` with einsum.

    Args:
        in_features: Size ``i`` of each input row.
        out_features: Size ``o`` of each output row.
        weight_loc: State-dict key of the (out_features, in_features) weight.
        bias_loc: State-dict key of the (out_features,) bias.

    Raises:
        ValueError: If the loaded parameter shapes disagree with
            ``in_features`` / ``out_features``.
    """

    def __init__(self, in_features: int, out_features: int, weight_loc: str, bias_loc: str):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight, self.bias = self.init_weights_and_bias(
            weight_loc, bias_loc)

        expected_weight = (out_features, in_features)
        if tuple(self.weight.shape) != expected_weight:
            raise ValueError(
                f"weight has shape {tuple(self.weight.shape)}, "
                f"expected {expected_weight}"
            )
        if self.bias is not None and tuple(self.bias.shape) != (out_features,):
            raise ValueError(
                f"bias has shape {tuple(self.bias.shape)}, "
                f"expected {(out_features,)}"
            )

    def forward(self, x: np.ndarray) -> np.ndarray:
        """
        Args:
            x (np.ndarray): Input array of shape (batch, in_features).

        Returns:
            np.ndarray: Output of shape (batch, out_features), equal to
            ``x @ self.weight.T + self.bias``.
        """
        y = einsum(x, self.weight, "b i, o i -> b o")
        # Defensive: the mixin always loads a bias, but tolerate None.
        if self.bias is not None:
            y = y + self.bias
        return y

In [10]:
class AlexNet(nn.Module):
    """
    Inference-only AlexNet assembled from the NumPy-based custom layers.

    The layer layout and state-dict keys follow torchvision's AlexNet. Its
    Dropout layers are omitted because dropout is the identity at inference.

    Args:
        num_classes: Output width of the final linear layer. The pretrained
            'classifier.6' parameters are always loaded, so this should match
            their output size.
    """

    def __init__(self, num_classes: int = 1000) -> None:
        super().__init__()
        self.features = nn.Sequential(
            CustomConv2d(3, 64, kernel_size=11, stride=4, padding=2,
                         weight_loc='features.0.weight', bias_loc='features.0.bias'),
            CustomReLU(),
            CustomMaxPool2d(kernel_size=3, stride=2),
            CustomConv2d(64, 192, kernel_size=5, padding=2,
                         weight_loc='features.3.weight', bias_loc='features.3.bias'),
            CustomReLU(),
            CustomMaxPool2d(kernel_size=3, stride=2),
            CustomConv2d(192, 384, kernel_size=3, padding=1,
                         weight_loc='features.6.weight', bias_loc='features.6.bias'),
            CustomReLU(),
            CustomConv2d(384, 256, kernel_size=3, padding=1,
                         weight_loc='features.8.weight', bias_loc='features.8.bias'),
            CustomReLU(),
            CustomConv2d(256, 256, kernel_size=3, padding=1,
                         weight_loc='features.10.weight', bias_loc='features.10.bias'),
            CustomReLU(),
            CustomMaxPool2d(kernel_size=3, stride=2),
        )
        self.avgpool = CustomAdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            EinopsLinear(256 * 6 * 6, 4096,
                         weight_loc='classifier.1.weight', bias_loc='classifier.1.bias'),
            CustomReLU(),
            EinopsLinear(4096, 4096,
                         weight_loc='classifier.4.weight', bias_loc='classifier.4.bias'),
            CustomReLU(),
            EinopsLinear(4096, num_classes,
                         weight_loc='classifier.6.weight', bias_loc='classifier.6.bias'),
        )

    def forward(self, x: np.ndarray) -> np.ndarray:
        """
        Args:
            x (np.ndarray): Image batch of shape (b, 3, h, w); presumably
                224x224 crops from the AlexNet preprocessing (TODO confirm).

        Returns:
            np.ndarray: Class logits of shape (b, num_classes).
        """
        x = self.features(x)
        x = self.avgpool(x)  # (b, 256, 6, 6)
        # Flatten everything but the batch axis to (b, 256 * 6 * 6) directly in
        # NumPy instead of round-tripping through torch.
        x = x.reshape(x.shape[0], -1)
        return self.classifier(x)

In [None]:
import numpy as np

model = AlexNet()
model.eval()  # kept for parity with torch models; no custom layer reads self.training

total = 0
correct = 0

for batch_idx, (images, labels) in enumerate(val_dataloader):
    # default_collate already yields NumPy arrays; convert defensively in case
    # another collate_fn hands back torch tensors.
    if isinstance(images, torch.Tensor):
        images = images.numpy()
    if isinstance(labels, torch.Tensor):
        labels = labels.numpy()

    # Logits of shape (b, num_classes).
    outputs = model(images)

    # Predicted class indices.
    predicted = np.argmax(outputs, axis=1)

    # Accumulate as Python ints so the running totals stay exact.
    total += int(labels.shape[0])
    correct += int((predicted == labels).sum())
    print(f"batch {batch_idx + 1}: running accuracy {correct / total:.4f}")

if total == 0:
    raise RuntimeError("val_dataloader yielded no samples; cannot compute accuracy")

accuracy = correct / total
print(f"Test Accuracy: {accuracy:.4f}")