#### Import necessary packages


In [41]:
# Array libraries (third-party) and typing helpers
import cupy as cp
import numpy as np
# from numpy.lib.stride_tricks import sliding_window_view
from cupy.lib.stride_tricks import sliding_window_view
from typing import Tuple

# Progress bar
from tqdm.auto import tqdm

# Third-party imports
import torch
from torch import nn
from torch.utils.data import DataLoader
from einops import einsum, rearrange
from torchvision.datasets import ImageNet
from torchvision.models import AlexNet_Weights

#### Load the weights and biases of AlexNet


In [42]:
# Pretrained AlexNet parameters, kept at module level: the custom layers below
# look up their weight/bias tensors in this mapping by key.
# NOTE(review): presumably downloads the checkpoint on first use — confirm.
weights_and_biases = AlexNet_Weights.DEFAULT.get_state_dict()
# Show the available parameter keys (these are the `weight_loc` / `bias_loc` values used later).
print(weights_and_biases.keys())

odict_keys(['features.0.weight', 'features.0.bias', 'features.3.weight', 'features.3.bias', 'features.6.weight', 'features.6.bias', 'features.8.weight', 'features.8.bias', 'features.10.weight', 'features.10.bias', 'classifier.1.weight', 'classifier.1.bias', 'classifier.4.weight', 'classifier.4.bias', 'classifier.6.weight', 'classifier.6.bias'])


#### Load the data


In [None]:
def default_collate(batch):
    """
    Collate ``(image, label)`` samples into one pair of CuPy arrays.

    The images are stacked on the host with NumPy first, so the whole batch is
    handed to CuPy as a single array instead of one array per image followed
    by a device-side stack.

    Args:
        batch: Sequence of ``(image, label)`` pairs. Each image must expose
            ``.numpy()`` (a CPU torch tensor); each label is an int.

    Returns:
        Tuple ``(imgs, labels)``. ``imgs`` is a float32 CuPy array with the
        samples stacked along a new leading batch axis; ``labels`` is an
        int64 CuPy array of shape ``[B]``.
    """
    imgs, labels = zip(*batch)  # imgs: tuple[torch.Tensor], labels: tuple[int]
    # Assemble the batch in host memory, then convert once.
    host_imgs = np.stack([img.numpy() for img in imgs], axis=0)
    imgs = cp.asarray(host_imgs, dtype=cp.float32)     # [B, ...image dims]
    labels = cp.asarray(labels, dtype=cp.int64)        # [B]
    return imgs, labels


# ImageNet validation split read from data/ImageNet1k; every image goes through
# the preprocessing transforms bundled with the IMAGENET1K_V1 AlexNet weights.
imagenet_val = ImageNet(
    root="data/ImageNet1k",
    split="val",
    transform=AlexNet_Weights.IMAGENET1K_V1.transforms()
)

# Batches of 512 samples in dataset order (no shuffling). `collate_fn` is the
# `default_collate` defined above, so each batch arrives as an
# (images, labels) pair of CuPy arrays rather than torch tensors.
# NOTE(review): num_workers=0 is presumably required because the collate
# function creates CuPy arrays in the loading process — confirm.
val_dataloader = DataLoader(
    imagenet_val,
    batch_size=512,
    shuffle=False,
    num_workers=0,
    collate_fn=default_collate
)

#### Define the custom Conv2d

Here we inherit from `nn.Module` so that the layers compose cleanly with PyTorch containers, even though the computation itself is implemented with einops (`einsum`) and CuPy.


In [None]:
class PatchMixin:
    """Mixin that cuts the two trailing axes of an array into square, strided windows."""

    def __init__(self, kernel_size: int, stride: int) -> None:
        # Cooperative call: lets the remaining bases in the MRO initialise first.
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride

    def _patch_with_stride(self, x_pad: cp.ndarray) -> cp.ndarray:
        """
        Return every ``kernel_size x kernel_size`` window of ``x_pad``, keeping
        only the windows whose top-left corner lies on the stride grid.

        Args:
            x_pad (cp.ndarray): Input array of shape (b, c, h, w), already
                padded if padding is wanted.

        Returns:
            cp.ndarray: Array of shape (b, c, out_h, out_w, k, k) with
            ``out_h = (h - k) // stride + 1`` and ``out_w = (w - k) // stride + 1``.
        """
        side = self.kernel_size
        # All stride-1 windows over the last two axes; the two window axes are appended at the end.
        every_window = sliding_window_view(  # type: ignore
            x_pad,
            window_shape=(side, side),
            axis=(-2, -1)  # type: ignore
        )
        everything = slice(None)
        on_grid = slice(None, None, self.stride)
        # Subsample only the two window-position axes; batch, channel and window contents stay whole.
        return every_window[everything, everything, on_grid, on_grid, everything, everything]


class WeightsAndBiasMixin:
    """Mixin that loads a layer's parameters from the module-level ``weights_and_biases`` mapping."""

    def __init__(self, *args, **kwargs) -> None:
        # Pure pass-through so the next class in the MRO receives the arguments untouched.
        super().__init__(*args, **kwargs)

    def init_weights_and_bias(self, weight_loc: str, bias_loc: str) -> Tuple[cp.ndarray, cp.ndarray]:
        """
        Look up two entries of ``weights_and_biases`` and return them as CuPy arrays.

        Args:
            weight_loc: Key of the weight tensor in ``weights_and_biases``.
            bias_loc: Key of the bias tensor in ``weights_and_biases``.

        Returns:
            ``(weight, bias)`` as CuPy arrays.

        Raises:
            KeyError: If either key is missing from ``weights_and_biases``.
        """
        def as_cupy(key: str) -> cp.ndarray:
            # torch tensor -> detached CPU tensor -> NumPy array -> CuPy array
            host_array = weights_and_biases[key].detach().cpu().numpy()
            return cp.asarray(host_array)

        return as_cupy(weight_loc), as_cupy(bias_loc)


class CustomConv2d(WeightsAndBiasMixin, PatchMixin, nn.Module):
    """
    2D convolution for inference, computed with CuPy windows and einops ``einsum``.

    Parameters are not trained here: weight and bias are read from the
    pretrained state dict under ``weight_loc`` / ``bias_loc``. Padding is
    symmetric zero padding on the two spatial axes; only square kernels and a
    single stride value are supported.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        weight_loc: str = '',
        bias_loc: str = '',
    ) -> None:
        # The mixin chain forwards (kernel_size, stride) to PatchMixin, which
        # stores them as self.kernel_size / self.stride.
        super().__init__(kernel_size, stride)
        self.padding = padding
        # Kept for introspection only; the computation uses the loaded weight's own shape.
        self.in_channels = in_channels
        self.out_channels = out_channels

        loaded = self.init_weights_and_bias(weight_loc, bias_loc)
        self.weight, self.bias = loaded
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Intentionally a no-op: the layer only ever runs with pretrained parameters."""
        return None

    def _apply_padding(self, x: cp.ndarray) -> cp.ndarray:
        """Zero-pad the last two axes of ``x`` by ``self.padding`` on each side (``x`` itself is returned when padding is 0)."""
        amount = self.padding
        if amount == 0:
            return x
        untouched = (0, 0)          # batch and channel axes
        spatial = (amount, amount)  # before/after, for both spatial axes
        return cp.pad(
            x,
            pad_width=(untouched, untouched, spatial, spatial),
            mode='constant',
            constant_values=0,
        )

    def forward(self, x: cp.ndarray) -> cp.ndarray:
        """
        Convolve ``x`` (b, c, h, w) with the layer weight and add the bias.

        Returns:
            cp.ndarray: Pre-activation output of shape (b, o, out_h, out_w).
        """
        windows = self._patch_with_stride(self._apply_padding(x))
        # Axis labels in the pattern are positional: `w h` bind to the two
        # window-position axes and `kw kh` to the two in-window axes.
        # Channels and in-window positions are summed out.
        convolved = einsum(
            windows, self.weight, 'b c w h kw kh, o c kw kh -> b o w h')
        per_channel_bias = self.bias[None, :, None, None]
        return convolved + per_channel_bias  # type: ignore

In [45]:
class CustomReLU(nn.Module):
    """Rectified linear unit for CuPy arrays: negative entries are clamped to zero."""

    def forward(self, x: cp.ndarray) -> cp.ndarray:
        """Return the elementwise maximum of ``x`` and ``0.0`` as a new array."""
        floor_value = 0.0
        return cp.maximum(x, floor_value)

In [46]:
class CustomMaxPool2d(PatchMixin, nn.Module):
    """Max pooling over square, strided windows of a CuPy array (no padding)."""

    def __init__(self, kernel_size: int, stride: int) -> None:
        # PatchMixin records kernel_size and stride and continues the cooperative init chain.
        super().__init__(kernel_size, stride)

    def forward(self, x: cp.ndarray) -> cp.ndarray:
        """Return the maximum of every pooling window of ``x`` (b, c, h, w)."""
        windows = self._patch_with_stride(x)
        # The two trailing axes hold each window's contents; reduce them away.
        in_window_axes = (-2, -1)
        return cp.max(windows, axis=in_window_axes)

In [47]:
class CustomAdaptiveAvgPool2d(nn.Module):
    """
    Custom Adaptive Average Pooling layer (CuPy, inference only).

    The input's two spatial axes are divided into ``out_h x out_w`` bins and
    each bin is averaged. Bin ``k`` along an axis of length ``n`` spans
    ``[floor(k * n / out), ceil((k + 1) * n / out))``, so neighbouring bins
    may overlap when ``n`` is not a multiple of ``out``.
    """

    def __init__(self, output_size: tuple[int, int]) -> None:
        """
        Args:
            output_size: Target ``(out_h, out_w)``. A single int ``n`` is also
                accepted and is treated as ``(n, n)``.
        """
        super().__init__()
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            self.output_size = output_size

    @staticmethod
    def _bin_edges(in_size: int, out_size: int) -> list[tuple[int, int]]:
        """Return the half-open ``(start, end)`` index range of every bin along one axis."""
        # Plain integer floor/ceil division: exact, and it stays on the host
        # instead of pushing Python scalars through array ufuncs.
        return [
            (
                (k * in_size) // out_size,
                ((k + 1) * in_size + out_size - 1) // out_size,
            )
            for k in range(out_size)
        ]

    def forward(self, x: cp.ndarray) -> cp.ndarray:
        """
        Average-pool ``x`` of shape (b, c, h, w) down to (b, c, out_h, out_w).

        Returns:
            cp.ndarray: Pooled array with the same dtype as ``x``.
        """
        b, c, h, w = x.shape
        out_h, out_w = self.output_size
        # Edges depend only on the sizes, so compute them once per axis.
        row_bins = self._bin_edges(h, out_h)
        col_bins = self._bin_edges(w, out_w)
        out = cp.zeros((b, c, out_h, out_w), dtype=x.dtype)
        for i, (h_start, h_end) in enumerate(row_bins):
            for j, (w_start, w_end) in enumerate(col_bins):
                region = x[:, :, h_start:h_end, w_start:w_end]
                out[:, :, i, j] = region.mean(axis=(-2, -1))
        return out

In [48]:
class EinopsLinear(WeightsAndBiasMixin, nn.Module):
    """
    Fully connected layer for inference, computed with einops ``einsum`` on CuPy arrays.

    Weight and bias are read from the pretrained state dict under
    ``weight_loc`` / ``bias_loc``; ``in_features`` and ``out_features`` are
    stored for reference only.
    """

    def __init__(self, in_features: int, out_features: int, weight_loc: str, bias_loc: str):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        loaded = self.init_weights_and_bias(weight_loc, bias_loc)
        self.weight, self.bias = loaded

    def forward(self, x: cp.ndarray) -> cp.ndarray:
        """Map ``x`` of shape (b, i) to (b, o) using the (o, i) weight, then add the bias if there is one."""
        projected = einsum(x, self.weight, "b i, o i -> b o")
        if self.bias is None:
            return projected
        return projected + self.bias

In [49]:
class AlexNet(nn.Module):
    """
    AlexNet assembled from the custom CuPy layers in this file (inference only).

    Every convolutional and linear layer loads its parameters from the
    pretrained state dict via the ``weight_loc`` / ``bias_loc`` keys given below.
    """

    def __init__(self, num_classes: int = 1000) -> None:
        """
        Args:
            num_classes: Passed as ``out_features`` of the final linear layer.
                The layer's actual weights still come from the pretrained
                state dict.
        """
        super().__init__()

        # Convolutional feature extractor: five convolutions, three max-pools.
        feature_layers = [
            CustomConv2d(3, 64, kernel_size=11, stride=4, padding=2,
                         weight_loc='features.0.weight', bias_loc='features.0.bias'),
            CustomReLU(),
            CustomMaxPool2d(kernel_size=3, stride=2),
            CustomConv2d(64, 192, kernel_size=5, padding=2,
                         weight_loc='features.3.weight', bias_loc='features.3.bias'),
            CustomReLU(),
            CustomMaxPool2d(kernel_size=3, stride=2),
            CustomConv2d(192, 384, kernel_size=3, padding=1,
                         weight_loc='features.6.weight', bias_loc='features.6.bias'),
            CustomReLU(),
            CustomConv2d(384, 256, kernel_size=3, padding=1,
                         weight_loc='features.8.weight', bias_loc='features.8.bias'),
            CustomReLU(),
            CustomConv2d(256, 256, kernel_size=3, padding=1,
                         weight_loc='features.10.weight', bias_loc='features.10.bias'),
            CustomReLU(),
            CustomMaxPool2d(kernel_size=3, stride=2),
        ]
        self.features = nn.Sequential(*feature_layers)

        # Fixes the spatial size at 6 x 6 so the flattened width matches the first linear layer.
        self.avgpool = CustomAdaptiveAvgPool2d((6, 6))

        # Three-layer classifier head.
        classifier_layers = [
            EinopsLinear(256 * 6 * 6, 4096, weight_loc='classifier.1.weight',
                         bias_loc='classifier.1.bias'),
            CustomReLU(),
            EinopsLinear(4096, 4096, weight_loc='classifier.4.weight',
                         bias_loc='classifier.4.bias'),
            CustomReLU(),
            EinopsLinear(4096, num_classes, weight_loc='classifier.6.weight',
                         bias_loc='classifier.6.bias'),
        ]
        self.classifier = nn.Sequential(*classifier_layers)

    def forward(self, x: cp.ndarray) -> cp.ndarray:
        """Run features -> average pool -> flatten -> classifier and return the classifier output, one row per sample."""
        feature_maps = self.features(x)
        pooled = self.avgpool(feature_maps)
        # Collapse everything except the batch axis into one feature vector per sample.
        flattened = pooled.reshape(pooled.shape[0], -1)
        return self.classifier(flattened)

In [50]:
# Build the network (each layer loads its pretrained parameters on construction)
# and put it in evaluation mode.
model = AlexNet()
model.eval()

# Running counters over the whole validation set.
total = 0
correct = 0

# Estimate total batches for progress bar length
try:
    total_batches = len(val_dataloader)
except TypeError:
    # Loader without a defined length: tqdm falls back to an open-ended bar.
    total_batches = None

for images, labels in tqdm(val_dataloader, total=total_batches, desc="Evaluating", leave=True):
    # ensure CuPy arrays; only relevant when a loader yields torch tensors.
    # detach()/cpu() make the conversion work for CUDA or grad-tracking tensors too.
    if isinstance(images, torch.Tensor):
        images = cp.asarray(images.detach().cpu().numpy())
    if isinstance(labels, torch.Tensor):
        labels = cp.asarray(labels.detach().cpu().numpy())

    # Call the module itself rather than .forward() so registered hooks are honoured.
    outputs = model(images)  # (b, num_classes) cp.ndarray
    predicted = cp.argmax(outputs, axis=1)  # top-1 class index per sample
    total += int(labels.shape[0])
    correct += int((predicted == labels).sum())
    running_acc = correct / total if total else 0.0
    # tqdm.write prints without breaking the progress bar.
    tqdm.write(f"Running Acc: {running_acc:.4f}")

# Top-1 accuracy over everything seen; guarded against an empty loader.
accuracy = correct / total if total else 0.0
print(f"Validation Accuracy: {accuracy:.4f}")

Evaluating:   1%|          | 1/98 [00:01<02:40,  1.65s/it]

Running Acc: 0.7109


Evaluating:   2%|▏         | 2/98 [00:02<02:18,  1.45s/it]

Running Acc: 0.7500


Evaluating:   3%|▎         | 3/98 [00:04<02:14,  1.42s/it]

Running Acc: 0.7428


Evaluating:   4%|▍         | 4/98 [00:05<02:09,  1.38s/it]

Running Acc: 0.6816


Evaluating:   5%|▌         | 5/98 [00:06<02:06,  1.36s/it]

Running Acc: 0.6559


Evaluating:   6%|▌         | 6/98 [00:08<02:02,  1.33s/it]

Running Acc: 0.6283


Evaluating:   7%|▋         | 7/98 [00:09<02:00,  1.32s/it]

Running Acc: 0.6119


Evaluating:   8%|▊         | 8/98 [00:10<01:59,  1.33s/it]

Running Acc: 0.6155


Evaluating:   9%|▉         | 9/98 [00:12<01:57,  1.32s/it]

Running Acc: 0.6369


Evaluating:  10%|█         | 10/98 [00:13<01:57,  1.33s/it]

Running Acc: 0.6520


Evaluating:  11%|█         | 11/98 [00:14<01:58,  1.36s/it]

Running Acc: 0.6529


Evaluating:  12%|█▏        | 12/98 [00:16<02:03,  1.44s/it]

Running Acc: 0.6481


Evaluating:  13%|█▎        | 13/98 [00:18<02:01,  1.44s/it]

Running Acc: 0.6487


Evaluating:  14%|█▍        | 14/98 [00:19<02:00,  1.43s/it]

Running Acc: 0.6575


Evaluating:  15%|█▌        | 15/98 [00:20<01:59,  1.44s/it]

Running Acc: 0.6629


Evaluating:  16%|█▋        | 16/98 [00:22<01:57,  1.44s/it]

Running Acc: 0.6591


Evaluating:  17%|█▋        | 17/98 [00:23<01:55,  1.42s/it]

Running Acc: 0.6495


Evaluating:  18%|█▊        | 18/98 [00:25<01:54,  1.43s/it]

Running Acc: 0.6452


Evaluating:  19%|█▉        | 19/98 [00:26<01:55,  1.46s/it]

Running Acc: 0.6405


Evaluating:  20%|██        | 20/98 [00:28<01:51,  1.43s/it]

Running Acc: 0.6388


Evaluating:  21%|██▏       | 21/98 [00:29<01:51,  1.44s/it]

Running Acc: 0.6356


Evaluating:  22%|██▏       | 22/98 [00:30<01:48,  1.43s/it]

Running Acc: 0.6355


Evaluating:  23%|██▎       | 23/98 [00:32<01:46,  1.42s/it]

Running Acc: 0.6332


Evaluating:  24%|██▍       | 24/98 [00:33<01:44,  1.41s/it]

Running Acc: 0.6315


Evaluating:  26%|██▌       | 25/98 [00:35<01:42,  1.40s/it]

Running Acc: 0.6312


Evaluating:  27%|██▋       | 26/98 [00:36<01:40,  1.40s/it]

Running Acc: 0.6323


Evaluating:  28%|██▊       | 27/98 [00:38<01:42,  1.44s/it]

Running Acc: 0.6314


Evaluating:  29%|██▊       | 28/98 [00:39<01:44,  1.49s/it]

Running Acc: 0.6309


Evaluating:  30%|██▉       | 29/98 [00:41<01:42,  1.49s/it]

Running Acc: 0.6342


Evaluating:  31%|███       | 30/98 [00:42<01:40,  1.48s/it]

Running Acc: 0.6348


Evaluating:  32%|███▏      | 31/98 [00:44<01:39,  1.49s/it]

Running Acc: 0.6323


Evaluating:  33%|███▎      | 32/98 [00:45<01:37,  1.48s/it]

Running Acc: 0.6370


Evaluating:  34%|███▎      | 33/98 [00:47<01:36,  1.49s/it]

Running Acc: 0.6387


Evaluating:  35%|███▍      | 34/98 [00:48<01:33,  1.47s/it]

Running Acc: 0.6396


Evaluating:  36%|███▌      | 35/98 [00:49<01:31,  1.45s/it]

Running Acc: 0.6385


Evaluating:  37%|███▋      | 36/98 [00:51<01:28,  1.43s/it]

Running Acc: 0.6390


Evaluating:  38%|███▊      | 37/98 [00:52<01:28,  1.45s/it]

Running Acc: 0.6388


Evaluating:  39%|███▉      | 38/98 [00:54<01:31,  1.52s/it]

Running Acc: 0.6383


Evaluating:  40%|███▉      | 39/98 [00:55<01:30,  1.53s/it]

Running Acc: 0.6382


Evaluating:  41%|████      | 40/98 [00:57<01:25,  1.48s/it]

Running Acc: 0.6372


Evaluating:  42%|████▏     | 41/98 [00:58<01:22,  1.45s/it]

Running Acc: 0.6326


Evaluating:  43%|████▎     | 42/98 [01:00<01:21,  1.45s/it]

Running Acc: 0.6305


Evaluating:  44%|████▍     | 43/98 [01:01<01:18,  1.43s/it]

Running Acc: 0.6269


Evaluating:  45%|████▍     | 44/98 [01:02<01:16,  1.43s/it]

Running Acc: 0.6245


Evaluating:  46%|████▌     | 45/98 [01:04<01:16,  1.45s/it]

Running Acc: 0.6212


Evaluating:  47%|████▋     | 46/98 [01:06<01:17,  1.49s/it]

Running Acc: 0.6174


Evaluating:  48%|████▊     | 47/98 [01:07<01:14,  1.46s/it]

Running Acc: 0.6157


Evaluating:  49%|████▉     | 48/98 [01:08<01:12,  1.45s/it]

Running Acc: 0.6128


Evaluating:  50%|█████     | 49/98 [01:10<01:09,  1.42s/it]

Running Acc: 0.6097


Evaluating:  51%|█████     | 50/98 [01:11<01:08,  1.43s/it]

Running Acc: 0.6058


Evaluating:  52%|█████▏    | 51/98 [01:13<01:06,  1.41s/it]

Running Acc: 0.6032


Evaluating:  53%|█████▎    | 52/98 [01:14<01:05,  1.43s/it]

Running Acc: 0.6006


Evaluating:  54%|█████▍    | 53/98 [01:15<01:04,  1.43s/it]

Running Acc: 0.5994


Evaluating:  55%|█████▌    | 54/98 [01:17<01:01,  1.39s/it]

Running Acc: 0.5987


Evaluating:  56%|█████▌    | 55/98 [01:18<01:01,  1.44s/it]

Running Acc: 0.5980


Evaluating:  57%|█████▋    | 56/98 [01:20<01:01,  1.46s/it]

Running Acc: 0.5984


Evaluating:  58%|█████▊    | 57/98 [01:21<00:58,  1.43s/it]

Running Acc: 0.5986


Evaluating:  59%|█████▉    | 58/98 [01:23<00:59,  1.49s/it]

Running Acc: 0.5952


Evaluating:  60%|██████    | 59/98 [01:24<00:58,  1.49s/it]

Running Acc: 0.5932


Evaluating:  61%|██████    | 60/98 [01:26<00:55,  1.45s/it]

Running Acc: 0.5942


Evaluating:  62%|██████▏   | 61/98 [01:27<00:52,  1.42s/it]

Running Acc: 0.5908


Evaluating:  63%|██████▎   | 62/98 [01:28<00:50,  1.40s/it]

Running Acc: 0.5900


Evaluating:  64%|██████▍   | 63/98 [01:30<00:49,  1.40s/it]

Running Acc: 0.5883


Evaluating:  65%|██████▌   | 64/98 [01:31<00:48,  1.42s/it]

Running Acc: 0.5869


Evaluating:  66%|██████▋   | 65/98 [01:33<00:47,  1.44s/it]

Running Acc: 0.5849


Evaluating:  67%|██████▋   | 66/98 [01:34<00:46,  1.44s/it]

Running Acc: 0.5840


Evaluating:  68%|██████▊   | 67/98 [01:35<00:43,  1.40s/it]

Running Acc: 0.5831


Evaluating:  69%|██████▉   | 68/98 [01:37<00:42,  1.43s/it]

Running Acc: 0.5819


Evaluating:  70%|███████   | 69/98 [01:38<00:41,  1.42s/it]

Running Acc: 0.5812


Evaluating:  71%|███████▏  | 70/98 [01:40<00:40,  1.44s/it]

Running Acc: 0.5804


Evaluating:  72%|███████▏  | 71/98 [01:41<00:38,  1.44s/it]

Running Acc: 0.5798


Evaluating:  73%|███████▎  | 72/98 [01:43<00:37,  1.45s/it]

Running Acc: 0.5786


Evaluating:  74%|███████▍  | 73/98 [01:44<00:35,  1.41s/it]

Running Acc: 0.5773


Evaluating:  76%|███████▌  | 74/98 [01:45<00:33,  1.40s/it]

Running Acc: 0.5762


Evaluating:  77%|███████▋  | 75/98 [01:47<00:32,  1.40s/it]

Running Acc: 0.5751


Evaluating:  78%|███████▊  | 76/98 [01:48<00:30,  1.41s/it]

Running Acc: 0.5737


Evaluating:  79%|███████▊  | 77/98 [01:50<00:30,  1.43s/it]

Running Acc: 0.5726


Evaluating:  80%|███████▉  | 78/98 [01:51<00:28,  1.43s/it]

Running Acc: 0.5712


Evaluating:  81%|████████  | 79/98 [01:53<00:27,  1.45s/it]

Running Acc: 0.5714


Evaluating:  82%|████████▏ | 80/98 [01:54<00:26,  1.46s/it]

Running Acc: 0.5697


Evaluating:  83%|████████▎ | 81/98 [01:56<00:24,  1.44s/it]

Running Acc: 0.5687


Evaluating:  84%|████████▎ | 82/98 [01:57<00:22,  1.42s/it]

Running Acc: 0.5665


Evaluating:  85%|████████▍ | 83/98 [01:58<00:21,  1.41s/it]

Running Acc: 0.5649


Evaluating:  86%|████████▌ | 84/98 [02:00<00:19,  1.40s/it]

Running Acc: 0.5652


Evaluating:  87%|████████▋ | 85/98 [02:01<00:18,  1.40s/it]

Running Acc: 0.5643


Evaluating:  88%|████████▊ | 86/98 [02:03<00:16,  1.41s/it]

Running Acc: 0.5642


Evaluating:  89%|████████▉ | 87/98 [02:04<00:15,  1.41s/it]

Running Acc: 0.5636


Evaluating:  90%|████████▉ | 88/98 [02:05<00:14,  1.41s/it]

Running Acc: 0.5627


Evaluating:  91%|█████████ | 89/98 [02:07<00:12,  1.44s/it]

Running Acc: 0.5609


Evaluating:  92%|█████████▏| 90/98 [02:08<00:11,  1.45s/it]

Running Acc: 0.5618


Evaluating:  93%|█████████▎| 91/98 [02:10<00:10,  1.46s/it]

Running Acc: 0.5612


Evaluating:  94%|█████████▍| 92/98 [02:11<00:09,  1.50s/it]

Running Acc: 0.5622


Evaluating:  95%|█████████▍| 93/98 [02:13<00:07,  1.52s/it]

Running Acc: 0.5629


Evaluating:  96%|█████████▌| 94/98 [02:14<00:05,  1.49s/it]

Running Acc: 0.5638


Evaluating:  97%|█████████▋| 95/98 [02:16<00:04,  1.45s/it]

Running Acc: 0.5627


Evaluating:  98%|█████████▊| 96/98 [02:17<00:02,  1.46s/it]

Running Acc: 0.5628


Evaluating:  99%|█████████▉| 97/98 [02:19<00:01,  1.46s/it]

Running Acc: 0.5653


Evaluating: 100%|██████████| 98/98 [02:20<00:00,  1.43s/it]

Running Acc: 0.5656
Validation Accuracy: 0.5656



