In [33]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torchvision.models import resnet18, ResNet18_Weights, ResNet

from torchvision import transforms
from PIL import Image

In [11]:
class ShuffledLinear(nn.modules.Linear):
    def __init__(self, *args):
        super(ShuffledLinear, self).__init__(*args)
    def forward(self, input: Tensor) -> Tensor:
        indices = torch.argsort(torch.rand_like(self.weight.T), dim=-1)
        shuffled_W_T = torch.gather(self.weight.T, dim=-1, index=indices)
        shuffled_W = shuffled_W_T.T
        return F.linear(input, shuffled_W, self.bias)

In [90]:
class ShuffledResnet18(nn.Module):
    def __init__(self, weights):
        super(ShuffledResnet18, self).__init__()
        self.resnet_model = resnet18(weights=weights)
        self.layers = []
        for i, child in enumerate(self.resnet_model.children()):
            if i == len(list(self.resnet_model.children())) - 1:
                modified_layer = ShuffledLinear(512, 1000, True)
                modified_layer.weight = child.weight
                modified_layer.bias = child.bias
                self.layers.append(modified_layer)
            else:
                self.layers.append(child)
            pass
        
    def forward(self, x: Tensor) -> Tensor:
        for i, layer in enumerate(self.layers):
            if i == len(self.layers) - 1:
                x = torch.flatten(x, 1)
            x = layer(x)
        return x

In [105]:
with open("../data/restnet18_set/imagenet_classes.txt", "r") as f:
    categories = [s.strip() for s in f.readlines()]

In [97]:
img = Image.open("../data/restnet18_set/image.jpg")
transform_pipe = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )])
img_tensor = transform_pipe(img)
img_tensor = torch.unsqueeze(img_tensor, 0)



In [104]:
res = resnet18(weights=ResNet18_Weights.DEFAULT)
res.eval()
pred = res.forward(img_tensor)
_, index = torch.max(pred, 1)
percentage = torch.nn.functional.softmax(pred, dim=1)[0] * 100
percentage[index[0]].item()

97.3057632446289

In [100]:
model = ShuffledResnet18(weights=ResNet18_Weights.DEFAULT)
model.eval()
pred = model.forward(img_tensor)
pred

tensor([[-2.0875e+00,  2.9714e+00, -1.3753e+00,  1.1601e+00,  2.1175e+00,
         -1.8867e-01,  7.8912e-01, -1.8231e-01,  2.0535e+00,  2.4593e+00,
         -2.4077e+00,  9.3080e-01,  2.5925e-01,  4.7606e-01,  9.8624e-01,
         -2.9124e+00, -9.7233e-01, -1.7213e+00, -1.6867e+00, -3.5803e+00,
         -3.4935e-01,  1.0233e+00, -2.1248e+00,  9.7970e-01,  1.0139e-01,
          9.7792e-01,  3.6971e+00,  1.8176e-01,  1.4576e-01,  4.1115e+00,
         -1.1358e+00, -2.0476e+00, -6.3897e-03,  2.1294e+00, -1.5466e-01,
         -3.3418e+00, -3.5221e-01,  1.9195e+00,  2.2568e+00,  8.9942e-01,
         -2.2759e+00,  1.0832e+00, -1.9765e+00,  8.8084e-01,  6.6893e-01,
         -5.2979e-01,  4.1197e+00, -3.5377e-01, -3.7315e+00,  1.1392e+00,
          2.6379e-01,  1.2907e+00, -3.5604e+00, -3.0378e+00, -7.3833e-01,
         -9.6334e-01,  6.3746e-03, -6.4084e-01,  7.6603e-01, -6.1063e-01,
          1.2614e-01, -1.5727e+00,  2.7495e-01,  4.5044e+00,  1.4814e+00,
          1.0057e+00,  3.9749e+00, -1.