In [60]:
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt



class BasicConv2d(nn.Module):
	"""Bias-free 2-D convolution followed by batch normalisation and ReLU.

	Args:
		in_channels: number of channels of the incoming feature map.
		out_channels: number of channels produced by the convolution.
		**kwargs: passed through unchanged to ``nn.Conv2d`` (for example
			``kernel_size`` and ``padding``).
	"""

	def __init__(self, in_channels, out_channels, **kwargs):
		super(BasicConv2d, self).__init__()
		# The convolution is created without a bias term; batch norm follows it directly.
		self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
		self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

	def forward(self, x):
		"""Return ``relu(bn(conv(x)))``."""
		normalised = self.bn(self.conv(x))
		# In-place ReLU overwrites the batch-norm output tensor, not the input ``x``.
		return F.relu(normalised, inplace=True)

class Inception(nn.Module):
	"""Four parallel convolutional branches concatenated along the channel axis.

	Args:
		in_channels: channels of the input feature map.
		ch1x1: output channels of the 1x1 branch.
		ch3x3red: channels of the 1x1 reduction in front of the 3x3 convolution.
		ch3x3: output channels of the 3x3 branch.
		ch5x5red: channels of the 1x1 reduction in the third branch.
		ch5x5: output channels of the third branch.
		pool_proj: output channels of the 1x1 projection after max pooling.

	The output has ``ch1x1 + ch3x3 + ch5x5 + pool_proj`` channels.
	"""

	def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
		super(Inception, self).__init__()

		# Branch 1: a single 1x1 convolution.
		self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)

		# Branch 2: 1x1 reduction, then a padded 3x3 convolution.
		reduce_3x3 = BasicConv2d(in_channels, ch3x3red, kernel_size=1)
		conv_3x3 = BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1)
		self.branch2 = nn.Sequential(reduce_3x3, conv_3x3)

		# Branch 3: 1x1 reduction, then a second convolution.
		# NOTE: despite the "5x5" parameter names this convolution uses a 3x3
		# kernel. Changing it would alter the parameter shapes, so it is kept.
		reduce_5x5 = BasicConv2d(in_channels, ch5x5red, kernel_size=1)
		conv_5x5 = BasicConv2d(ch5x5red, ch5x5, kernel_size=3, padding=1)
		self.branch3 = nn.Sequential(reduce_5x5, conv_5x5)

		# Branch 4: stride-1 3x3 max pooling, then a 1x1 projection.
		pooling = nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True)
		projection = BasicConv2d(in_channels, pool_proj, kernel_size=1)
		self.branch4 = nn.Sequential(pooling, projection)

	def forward(self, x):
		"""Run every branch on ``x`` and concatenate the results on dimension 1."""
		branches = (self.branch1, self.branch2, self.branch3, self.branch4)
		return torch.cat([branch(x) for branch in branches], 1)
    
def _conv_stage(in_channels, out_channels):
    """Build one reflection-padded 3x3 convolution stage.

    Returns the list ``[ReflectionPad2d, Conv2d, BatchNorm2d, ReLU]`` in that order.
    """
    return [
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(in_channels, out_channels, (3, 3)),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
    ]


def _downsample():
    """Return a 2x2 max pooling layer with stride 2 and ``ceil_mode=True``."""
    return nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True)


# The position of every layer determines its state_dict key (convs.0, convs.3, ...),
# so layers must not be inserted, removed or reordered without retraining.
_layers = [
    nn.Conv2d(3, 3, (1, 1)),
    nn.BatchNorm2d(3),
]

# Block 1: 3 -> 64 channels.
_layers += _conv_stage(3, 64)       # relu1-1
_layers += _conv_stage(64, 64)      # relu1-2
_layers.append(_downsample())

# Block 2: 64 -> 128 channels.
_layers += _conv_stage(64, 128)     # relu2-1
_layers += _conv_stage(128, 128)    # relu2-2
_layers.append(nn.Dropout(0.4))
_layers.append(_downsample())

# Block 3: 128 -> 256 channels.
_layers += _conv_stage(128, 256)    # relu3-1
_layers += _conv_stage(256, 256)    # relu3-2
_layers += _conv_stage(256, 256)    # relu3-3
_layers.append(nn.Dropout(0.4))
_layers.append(_downsample())

# A fourth block (256 -> 512 channels, Dropout(0.5), pooling) is intentionally disabled.

convs = nn.Sequential(*_layers)

class NetModel(nn.Module):
	"""Classifier: ``convs`` feature extractor, three Inception blocks, three linear layers.

	The shape comments in ``forward`` assume a 3 x 32 x 32 input, which is what
	the fully connected head (``512 * 2 * 2`` input features) is sized for.
	The network returns 10 raw scores per sample (no softmax is applied).
	"""

	def __init__(self):
		super(NetModel, self).__init__()

		# NOTE: this registers the module-level ``convs`` object itself, so every
		# NetModel instance created in the same session shares these parameters.
		self.convs = convs
		self.inception3a = Inception(256, 128, 128, 192, 32, 96, 64)
		self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
		self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
		self.pool = nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True)
		self.fc1 = nn.Linear(512 * 2 * 2, 4096)
		self.fc2 = nn.Linear(4096, 512)
		self.fc3 = nn.Linear(512, 10)

	def forward(self, x):
		"""Compute class scores for a batch of images.

		Args:
			x: image batch of shape (N, 3, H, W); H = W = 32 is expected.

		Returns:
			Tensor of shape (N, 10).

		Raises:
			RuntimeError: from ``fc1`` when the flattened feature size is not
				``512 * 2 * 2`` (for example when the input is not 32 x 32).
		"""
		x = self.convs(x)
		# 256 x 4 x 4
		x = self.inception3a(x)
		# 480 x 4 x 4
		x = self.inception4a(x)
		# 512 x 4 x 4
		x = self.inception4b(x)
		# 512 x 4 x 4
		x = self.pool(x)
		# 512 x 2 x 2

		# Flatten per sample. Keeping the batch dimension explicit means a wrong
		# input size fails loudly in fc1 instead of being folded into the batch
		# dimension, which is what view(-1, 512 * 2 * 2) would silently do.
		x = x.view(x.size(0), -1)
		# NOTE: there is no activation between the linear layers; this is kept
		# as-is so that existing trained checkpoints give the same predictions.
		x = self.fc1(x)
		x = self.fc2(x)
		x = self.fc3(x)
		return x
# Build the network and restore its trained weights.
net1 = NetModel()
net1.eval()  # inference mode: Dropout is disabled and BatchNorm uses its running statistics

# Pick the GPU when one is present, otherwise fall back to the CPU, so the
# checkpoint also loads on CPU-only machines.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Read the checkpoint from disk once and reuse it for both the key listing and
# the weight restore.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
lays1 = torch.load('./models/inception1-0.pth', map_location=device)
for lay in lays1:
    print(lay)
net1.load_state_dict(lays1)
net1.to(device)

convs.0.weight
convs.0.bias
convs.1.weight
convs.1.bias
convs.1.running_mean
convs.1.running_var
convs.1.num_batches_tracked
convs.3.weight
convs.3.bias
convs.4.weight
convs.4.bias
convs.4.running_mean
convs.4.running_var
convs.4.num_batches_tracked
convs.7.weight
convs.7.bias
convs.8.weight
convs.8.bias
convs.8.running_mean
convs.8.running_var
convs.8.num_batches_tracked
convs.12.weight
convs.12.bias
convs.13.weight
convs.13.bias
convs.13.running_mean
convs.13.running_var
convs.13.num_batches_tracked
convs.16.weight
convs.16.bias
convs.17.weight
convs.17.bias
convs.17.running_mean
convs.17.running_var
convs.17.num_batches_tracked
convs.22.weight
convs.22.bias
convs.23.weight
convs.23.bias
convs.23.running_mean
convs.23.running_var
convs.23.num_batches_tracked
convs.26.weight
convs.26.bias
convs.27.weight
convs.27.bias
convs.27.running_mean
convs.27.running_var
convs.27.num_batches_tracked
convs.30.weight
convs.30.bias
convs.31.weight
convs.31.bias
convs.31.running_mean
convs.31.runnin

RuntimeError: Error(s) in loading state_dict for NetModel:
	size mismatch for convs.0.weight: copying a param with shape torch.Size([3, 3, 1, 1]) from checkpoint, the shape in current model is torch.Size([3, 3, 3, 3]).

In [59]:
# Preprocessing: convert each image to a tensor, then normalise every channel
# with mean 0.5 and standard deviation 0.5.
_to_tensor = transforms.ToTensor()
_normalise = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
transform = transforms.Compose([_to_tensor, _normalise])

# CIFAR-10 test split, downloaded to ./data when it is not already there.
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Fixed order (no shuffling), 20 images per batch, 2 loader worker processes.
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=20,
    shuffle=False,
    num_workers=2,
)

# Run on the first GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

Files already downloaded and verified


In [56]:
# Top-1 accuracy of net1 over every batch produced by testloader.
correct = 0
total = 0
with torch.no_grad():  # gradients are not needed for evaluation
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = net1(images)
        # Index of the highest score per sample is the predicted class.
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Computed once after the loop. The guard keeps an empty loader from leaving
# `accu` undefined (or, in a live kernel, stale from an earlier run).
accu = 100 * correct / total if total else 0.0
print('accuracy: %d %%' % (accu))

accuracy: 52 %
