In [1]:
import sys
sys.path.append(sys.path[0] + "/..")

from typing import Any
import caffe
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from functools import partial
from torchvision.models import googlenet
from torchvision.models.googlenet import Inception, InceptionAux
from torchvision.io import read_image
from torchvision.io import ImageReadMode
from torch import Tensor
from torchsummary import summary
from models.googlenet import googlenet as Gnet

  from .collection import imread_collection_wrapper


In [2]:
class BasicConv2d_(nn.Module):
    """A 2-D convolution (always with a bias term) followed by an in-place ReLU.

    Args:
        in_channels: number of input channels passed to ``nn.Conv2d``.
        out_channels: number of output channels passed to ``nn.Conv2d``.
        **kwargs: forwarded unchanged to ``nn.Conv2d`` (kernel size, stride,
            padding, ...).
    """

    def __init__(self, in_channels: int, out_channels: int, **kwargs: Any):
        super(BasicConv2d_, self).__init__()
        # The attribute names ``conv`` and ``relu`` determine the state_dict
        # keys ("conv.weight", "conv.bias"), so they are kept as-is.
        self.conv = nn.Conv2d(in_channels, out_channels, bias=True, **kwargs)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: Tensor):
        """Apply the convolution, then rectify its output in place."""
        conv_out = self.conv(x)
        return self.relu(conv_out)

# Inception blocks wired to use the bias-enabled BasicConv2d_ as their conv block.
Inception_ = partial(Inception, conv_block=BasicConv2d_)
InceptionAux_ = partial(InceptionAux, conv_block=BasicConv2d_)
model = googlenet(
    num_classes=23,
    blocks=[BasicConv2d_, Inception_, InceptionAux_],
    aux_logits=False,
)


def _ceil_pool():
    """3x3 / stride-2 max pooling with ceil rounding of the output size."""
    return nn.MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)


def _lrn():
    """Local response normalisation over 5 channels (alpha=1e-4, beta=0.75)."""
    return nn.LocalResponseNorm(5, alpha=0.0001, beta=0.75)


def _branch3(in_channels, reduce_channels, out_channels):
    """1x1 reduction followed by a 5x5 convolution with padding 2."""
    return nn.Sequential(
        BasicConv2d_(in_channels, reduce_channels, kernel_size=1, stride=(1, 1)),
        BasicConv2d_(reduce_channels, out_channels, kernel_size=5, stride=(1, 1), padding=(2, 2)),
    )


# First stage: pool then normalise; second stage: normalise then pool.
model.maxpool1 = nn.Sequential(_ceil_pool(), _lrn())
model.maxpool2 = nn.Sequential(_lrn(), _ceil_pool())

# (input channels, 1x1 reduction channels, 5x5 output channels) per inception block.
# Insertion order is preserved, so the convolutions are created in the same
# order as if each assignment were written out by hand.
_BRANCH3_CHANNELS = {
    "inception3a": (192, 16, 32),
    "inception3b": (256, 32, 96),
    "inception4a": (480, 16, 48),
    "inception4b": (512, 24, 64),
    "inception4c": (512, 24, 64),
    "inception4d": (512, 32, 64),
    "inception4e": (528, 32, 128),
    "inception5a": (832, 32, 128),
    "inception5b": (832, 48, 128),
}
for _block_name, _channels in _BRANCH3_CHANNELS.items():
    getattr(model, _block_name).branch3 = _branch3(*_channels)

model.dropout = nn.Dropout(p=0.4, inplace=False)

# Load the Caffe network whose parameters are copied into `model` further down.
prototext = '../weights/minc-model/deploy-googlenet.prototxt'
caffemodel = '../weights/minc-model/minc-googlenet.caffemodel'
net = caffe.Classifier(prototext, caffemodel)



In [3]:
# Flatten the Caffe parameters into a dict keyed "<layer>_weight" / "<layer>_bias".
# Blob 0 of each layer is stored as the weight and blob 1 as the bias.
caffe_state_dict = {}
for layer_name, blobs in net.params.items():
    weight_blob, bias_blob = blobs[0], blobs[1]
    caffe_state_dict[layer_name + '_weight'] = weight_blob.data
    caffe_state_dict[layer_name + '_bias'] = bias_blob.data

# Key lists in iteration order; the two lists are later paired by position.
caffe_keys = [key for key in caffe_state_dict]
torch_keys = [key for key in model.state_dict()]


In [4]:
# Copy the Caffe parameters into the PyTorch model, pairing entries by position.
#
# `model.state_dict()` builds a new dict on every call, so assigning
# `model.state_dict()[key] = w` only rebinds a key in a throw-away dict and the
# model keeps its random initialisation. The tensors inside the dict share
# storage with the model's parameters, so writing into them in place with
# `copy_` is what actually updates the model.
if len(caffe_keys) != len(torch_keys):
    # Positional pairing is only meaningful when both sides line up.
    print("Key count mismatch: caffe has", len(caffe_keys), "entries, torch has", len(torch_keys))

torch_state = model.state_dict()
with torch.no_grad():
    for caffe_key, torch_key in zip(caffe_keys, torch_keys):
        w = torch.Tensor(caffe_state_dict[caffe_key])
        target = torch_state[torch_key]
        if target.shape == w.shape:
            target.copy_(w)
        else:
            print("Error at ", torch_key)

In [16]:
# Load the converted weights and classify a single test image.
m = Gnet().cuda()
# strict=False tolerates key mismatches; report them so that a partially
# loaded model does not go unnoticed.
load_result = m.load_state_dict(torch.load("../weights/minc-googlenet.pth"), strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print("missing keys:", load_result.missing_keys)
    print("unexpected keys:", load_result.unexpected_keys)
# Switch to inference mode so that layers such as dropout behave
# deterministically; without this the prediction can change between runs.
m.eval()

img = read_image("test.png", mode=ImageReadMode.RGB).cuda() * 1.0  # uint8 -> float
img = torch.unsqueeze(img, 0)  # add the batch dimension
# Per-channel mean subtraction (channel 0: 104, channel 1: 117, channel 2: 124).
# NOTE(review): the image is read as RGB while Caffe-trained models commonly
# expect BGR input with these means — confirm the channel order against Gnet.
channel_means = torch.tensor([104.0, 117.0, 124.0], device=img.device).view(1, 3, 1, 1)
img = img - channel_means

with torch.no_grad():  # no gradients are needed for a forward-only pass
    y = m(img)
# Softmax is monotonic, so the argmax of the raw scores is the same class index.
y = torch.argmax(y)
y

tensor(14, device='cuda:0')