In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch import Tensor
from torchvision.models import resnet50, ResNet50_Weights

In [44]:
# Build a ResNet-50 initialised with torchvision's default weight set
# (ResNet50_Weights.DEFAULT).
model = resnet50(weights=ResNet50_Weights.DEFAULT)

In [34]:
class ShuffledConv2d(nn.modules.Conv2d):
    """Conv2d that shuffles its kernel across output channels on every forward pass.

    For each position along the non-output axes of ``self.weight`` an
    independent random permutation of the output-channel axis (dim 0) is drawn,
    and a reordered copy of the weight is used for the convolution. The stored
    ``self.weight`` parameter itself is never modified.

    Accepts the same positional and keyword arguments as ``nn.Conv2d``.
    """

    def __init__(self, *args, **kwargs):
        # Forward keyword arguments as well, so that calls such as
        # ShuffledConv2d(3, 64, kernel_size=7, bias=False) behave like nn.Conv2d.
        super().__init__(*args, **kwargs)

    def forward(self, input: Tensor) -> Tensor:
        """Convolve ``input`` with a freshly shuffled copy of the weight.

        Args:
            input: Tensor accepted by ``nn.Conv2d``.

        Returns:
            The convolution of ``input`` with the shuffled weight and ``self.bias``.
        """
        # Arg-sorting i.i.d. uniform noise along dim 0 yields an independent
        # random permutation of the output channels for every other index.
        # Shuffling along dim 0 directly avoids Tensor.T, which is deprecated
        # for tensors that are not 2-D.
        indices = torch.argsort(torch.rand_like(self.weight), dim=0)
        shuffled_W = torch.gather(self.weight, dim=0, index=indices)
        return self._conv_forward(input, shuffled_W, self.bias)

In [49]:
# Materialise every entry of model.modules() and keep the one at index 1.
# NOTE(review): index 0 is presumably the model itself, making this its first
# submodule — confirm which layer this is before relying on it.
module = list(model.modules())[1]



In [None]:
class ModifiedResNet(nn.Module):
    """Wrapper that owns a ResNet-50 and invokes its sub-modules explicitly."""

    def __init__(self):
        super().__init__()
        # Constructed with no arguments (no `weights` passed).
        self.resnet = resnet50()

    def forward(self, x):
        """Run ``x`` through the wrapped network, one sub-module at a time.

        The stem (conv1, bn1, relu, maxpool) and the four residual stages are
        applied in order, followed by average pooling, flattening from dim 1
        onward, and the final fully connected layer.
        """
        backbone = self.resnet
        feature_layers = (
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
            backbone.layer1,
            backbone.layer2,
            backbone.layer3,
            backbone.layer4,
        )
        for layer in feature_layers:
            x = layer(x)

        pooled = backbone.avgpool(x)
        return backbone.fc(torch.flatten(pooled, 1))