In [33]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torchvision.models import resnet18, ResNet18_Weights, ResNet

from torchvision import transforms
from PIL import Image

In [11]:
class ShuffledLinear(nn.modules.Linear):
    def __init__(self, *args):
        super(ShuffledLinear, self).__init__(*args)
    def forward(self, input: Tensor) -> Tensor:
        indices = torch.argsort(torch.rand_like(self.weight.T), dim=-1)
        shuffled_W_T = torch.gather(self.weight.T, dim=-1, index=indices)
        shuffled_W = shuffled_W_T.T
        return F.linear(input, shuffled_W, self.bias)

In [66]:
class ShuffledResnet18(nn.Module):
    def __init__(self, weights):
        super(ShuffledResnet18, self).__init__()
        self.resnet_model = resnet18(weights=weights)
        self.layers = []
        for child in self.resnet_model.children():
            self.layers.append(child)
            pass
        #print(self.layers)
        
        
    def forward(self, x: Tensor) -> Tensor:
        for i, layer in enumerate(self.layers):
            print(i, layer)
            x = layer(x)
        return x

In [63]:
img = Image.open("../data/restnet18_set/image.jpg")
transform_pipe = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )])
img_tensor = transform_pipe(img)
img_tensor = torch.unsqueeze(img_tensor, 0)



In [65]:
res = resnet18(weights=ResNet18_Weights.DEFAULT)
res.eval()
pred = res.forward(img_tensor)

In [68]:
model = ShuffledResnet18(weights=ResNet18_Weights.DEFAULT)
model.eval()
pred = model.forward(img_tensor)

0 Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
1 BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
2 ReLU(inplace=True)
3 MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
4 Sequential(
  (0): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (1): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (512x1 and 512x1000)