In [35]:
%matplotlib inline

In [36]:
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.models import resnet50
from tqdm import tqdm

In [32]:
# Helpers to optionally add BlurPool to a PyTorch model and to load a training checkpoint into it.

class BlurPoolConv2d(torch.nn.Module):
    """Wrap a Conv2d so that its *input* is low-pass filtered before the convolution.

    A fixed, normalised 3x3 binomial kernel ``[[1, 2, 1], [2, 4, 2], [1, 2, 1]] / 16``
    is applied depthwise: one copy per input channel, ``groups=in_channels``, with
    stride 1 and padding 1, so the spatial size is unchanged. The wrapped conv then
    runs on the blurred tensor.

    Why: blurring before a strided (downsampling) conv reduces aliasing, i.e.
    high-frequency content being sampled too sparsely. It smooths transitions
    between neighbouring pixels and may also make training less sensitive to
    high-frequency noise. ``apply_blurpool`` only wraps ``torch.nn.Conv2d`` layers
    with stride > 1 and at least 16 input channels.

    Args:
        conv: the ``torch.nn.Conv2d`` to wrap. It is stored as ``self.conv``, so its
            weights appear under ``<name>.conv.*`` in the state dict.
    """

    def __init__(self, conv):
        super().__init__()
        default_filter = torch.tensor([[[[1, 2, 1], [2, 4, 2], [1, 2, 1]]]]) / 16.0
        # One (1, 3, 3) kernel per input channel -> shape (in_channels, 1, 3, 3),
        # as required by a depthwise conv with groups=in_channels.
        filt = default_filter.repeat(conv.in_channels, 1, 1, 1)
        self.conv = conv
        # Buffer rather than Parameter: the filter is fixed (not trained), but it
        # still follows .to()/.cuda()/.half() and is saved in the state dict.
        self.register_buffer("blur_filter", filt)

    def forward(self, x):
        """Blur ``x`` channel-wise, then apply the wrapped conv."""
        blurred = F.conv2d(
            x,
            self.blur_filter,
            stride=1,
            padding=(1, 1),
            groups=self.conv.in_channels,
            bias=None,
        )
        # Call the module (not .forward) so hooks registered on the wrapped conv run.
        return self.conv(blurred)

def apply_blurpool(mod: torch.nn.Module):
    """Recursively replace eligible ``Conv2d`` children of ``mod`` with ``BlurPoolConv2d``, in place.

    A child is wrapped when it is a ``torch.nn.Conv2d`` whose largest stride
    component exceeds 1 and which has at least 16 input channels. Every other
    child is searched recursively. Newly wrapped convs are not descended into.
    """
    for child_name, submodule in mod.named_children():
        # Short-circuits on the isinstance check, so stride/in_channels are only read from convs.
        is_strided_conv = isinstance(submodule, torch.nn.Conv2d) and (
            np.max(submodule.stride) > 1 and submodule.in_channels >= 16
        )
        if not is_strided_conv:
            apply_blurpool(submodule)
            continue
        # Replacing an existing key's value does not resize the child dict, so this is safe mid-iteration.
        setattr(mod, child_name, BlurPoolConv2d(submodule))



def _load_checkpoint(model, checkpoint_path: str):
    checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    checkpoint = checkpoint["model_state_dict"]

    new_state_dict = {}
    for key, value in checkpoint.items():
        if key.startswith("module."):
            new_state_dict[key.replace("module.", "")] = value

    model.load_state_dict(new_state_dict)


# Build ResNet-50 (optionally with BlurPool), load trained weights into it, and
# check that loading actually changed the randomly initialised weights.
# NOTE(review): `checkpoint_path` is not defined anywhere in this notebook, so this
# cell fails on a fresh kernel. Define it in a config cell before running.
use_blurpool = True
resnet = resnet50(weights=None)
if use_blurpool:
    # Applied before loading so the module structure (and state-dict keys) match
    # a checkpoint presumably trained with BlurPool. TODO confirm.
    apply_blurpool(resnet)

# Snapshot the random initial weights; detach() keeps the copies out of autograd.
initial_params = [param.detach().clone() for param in resnet.parameters()]

_load_checkpoint(resnet, checkpoint_path)

# Sanity check: loading should overwrite at least one parameter tensor.
weights_changed = any(
    not torch.equal(initial, loaded)
    for initial, loaded in zip(initial_params, resnet.parameters())
)
if weights_changed:
    print("Not same...")

Not same...


In [37]:
def check_initial_weights(model_fn=None):
    """Return True if two independently constructed models have different parameters.

    Builds two models with ``model_fn`` and compares their parameters tensor by
    tensor, stopping at the first mismatch. When no RNG seed is set between
    constructions, the two random initialisations are expected to differ.

    Args:
        model_fn: optional zero-argument callable returning an ``nn.Module``.
            Defaults to ``resnet50(weights=None)``. Pass a small model to make the
            check cheap.

    Returns:
        True if any pair of corresponding parameters differs, otherwise False.
    """
    build = model_fn if model_fn is not None else (lambda: resnet50(weights=None))
    model1 = build()
    model2 = build()

    return any(
        not torch.equal(w1, w2)
        for w1, w2 in zip(model1.parameters(), model2.parameters())
    )

# Repeat the fresh-model comparison 100 times and count the trials whose weights differed.
counter = sum(check_initial_weights() for _ in tqdm(range(100)))

assert counter==100

100%|██████████| 100/100 [00:40<00:00,  2.46it/s]


In [38]:
# Number of trials (out of 100) in which two freshly built models had different weights.
counter

100