In [125]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torchvision.models import resnet18, ResNet18_Weights, ResNet
from pprint import pprint
import gzip

from torchvision import transforms
from PIL import Image

In [11]:
class ShuffledLinear(nn.Linear):
    """``nn.Linear`` whose weight columns are randomly permuted on every forward pass.

    For each input feature ``j``, the column ``W[:, j]`` (the weights linking
    input ``j`` to every output unit) is independently shuffled across the
    output units before the affine map is applied. The stored parameter itself
    is never modified. Gradients still flow back to ``self.weight`` through the
    gather.

    A new permutation is drawn from the global torch RNG on *every* call, so
    outputs are non-deterministic unless the RNG is seeded.

    Accepts the same positional and keyword arguments as ``nn.Linear``
    (``in_features, out_features, bias=True, device=None, dtype=None``).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input: Tensor) -> Tensor:
        # nn.Linear stores weight as (out_features, in_features). Transposing it
        # makes each row hold one input feature's weights, which are then
        # permuted along the output axis.
        weight_T = self.weight.T
        # argsort of i.i.d. uniform noise gives an independent random
        # permutation for every row.
        indices = torch.argsort(torch.rand_like(weight_T), dim=-1)
        shuffled_W = torch.gather(weight_T, dim=-1, index=indices).T
        return F.linear(input, shuffled_W, self.bias)

In [90]:
class ShuffledResnet18(nn.Module):
    """ResNet-18 whose final fully connected layer is a ``ShuffledLinear``.

    The backbone is built with ``resnet18(weights=weights, **kwargs)``. Its
    classifier head (``fc``) is swapped for a ``ShuffledLinear`` that reuses the
    original weight and bias Parameters, so only the head's wiring changes.

    Args:
        weights: passed straight to ``torchvision.models.resnet18``
            (e.g. ``ResNet18_Weights.DEFAULT`` or ``None``).
        **kwargs: extra keyword arguments for ``resnet18``
            (e.g. ``num_classes`` when ``weights`` is ``None``).
    """

    def __init__(self, weights, **kwargs):
        super().__init__()
        self.resnet_model = resnet18(weights=weights, **kwargs)

        original_fc = self.resnet_model.fc
        head = ShuffledLinear(
            original_fc.in_features,
            original_fc.out_features,
            bias=original_fc.bias is not None,
        )
        # Reuse the existing Parameters (not copies) so the head stays tied to
        # the pretrained values.
        head.weight = original_fc.weight
        head.bias = original_fc.bias
        # Install the head on the backbone so it is a registered submodule:
        # .eval()/.train()/.to() and state_dict() reach it, and the state_dict
        # keys stay ``resnet_model.fc.*``. Note that ``self.resnet_model(x)``
        # now also uses the shuffled head.
        self.resnet_model.fc = head

        # Ordered view of the backbone's children, with the shuffled head last.
        # This is kept as a plain list on purpose: the modules are already
        # registered through ``self.resnet_model``, and an nn.ModuleList would
        # add duplicate state_dict keys.
        self.layers = list(self.resnet_model.children())

    def forward(self, x: Tensor) -> Tensor:
        """Apply the backbone children in order, flattening before the linear head."""
        *body, head = self.layers
        for layer in body:
            x = layer(x)
        # Collapse every dimension after the batch dim so the linear head gets
        # a 2-D (batch, features) input.
        return head(torch.flatten(x, 1))

In [137]:
def print_topk(pred, k, categories):
    """Print the ``k`` most probable classes for the first sample in ``pred``.

    Args:
        pred: logits with a leading batch dimension; only ``pred[0]`` is used.
        k: number of classes to print, clamped to the number of classes
            available.
        categories: sequence of class names indexed by class id.

    Prints one ``"<name> <probability>"`` line per class, in descending
    probability order.
    """
    probabilities = F.softmax(pred[0], dim=0)
    # torch.topk raises if k exceeds the dimension size, so clamp it.
    k = min(k, probabilities.size(0))
    topk_prob, topk_catid = torch.topk(probabilities, k)
    for prob, cat_id in zip(topk_prob.tolist(), topk_catid.tolist()):
        print(categories[cat_id], prob)

In [105]:
# Load the ImageNet class names, one per line, with surrounding whitespace trimmed.
with open("../data/restnet18_set/imagenet_classes.txt", "r") as class_file:
    categories = [line.strip() for line in class_file]

In [97]:
# Load the sample image and apply the standard ImageNet preprocessing.
# The context manager releases the file handle. convert("RGB") makes
# grayscale/CMYK/RGBA inputs yield the 3 channels that Normalize below expects.
with Image.open("../data/restnet18_set/image.jpg") as raw_img:
    img = raw_img.convert("RGB")

transform_pipe = transforms.Compose([
    transforms.Resize(256),       # resize so the shorter side is 256 px
    transforms.CenterCrop(224),   # take the central 224x224 patch
    transforms.ToTensor(),        # PIL image -> float tensor (C, H, W) in [0, 1]
    # Per-channel ImageNet mean/std used by torchvision's pretrained models.
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
# Add a leading batch dimension: (C, H, W) -> (1, C, H, W).
img_tensor = torch.unsqueeze(transform_pipe(img), 0)



In [139]:
# Baseline: the unmodified pretrained ResNet-18.
res = resnet18(weights=ResNet18_Weights.DEFAULT)
res.eval()
# Call the module (not .forward) so registered hooks run. no_grad avoids
# building an autograd graph for pure inference.
with torch.no_grad():
    pred = res(img_tensor)
print_topk(pred=pred, k=5, categories=categories)

comic book 0.9730576276779175
book jacket 0.023433461785316467
packet 0.0005408110446296632
jigsaw puzzle 0.00031631538877263665
wig 0.00021806337463203818


In [138]:
# Same backbone, but with the shuffled classifier head.
model = ShuffledResnet18(weights=ResNet18_Weights.DEFAULT)
model.eval()
# ShuffledLinear draws a new permutation from the global RNG on every forward
# pass, so seed it here to make this cell's output reproducible.
torch.manual_seed(0)
with torch.no_grad():
    pred = model(img_tensor)
print_topk(pred=pred, k=5, categories=categories)

coffeepot 0.12048495560884476
traffic light 0.10936012864112854
yawl 0.05186710134148598
scuba diver 0.04510520398616791
siamang 0.027079632505774498
