In [22]:
import torch
import torch.nn as nn
from torchvision.transforms import transforms
import numpy as np
from torch.autograd import Variable
from torchvision.models import squeezenet1_1
import torch.functional as F
from io import open
import os
from PIL import Image
import pathlib
import glob
#import cv2

In [23]:
train_path = 'train'
pred_path = 'Check'

In [24]:
root=pathlib.Path(train_path)
classes=sorted([j.name.split('/')[-1] for j in root.iterdir()])

In [25]:
#CNN Network

class ConvNet(nn.Module):
    def __init__(self, num_pokemon = 151):
        super(ConvNet,self).__init__()
    
        #Input shape = (128, 3, 150, 150)
    
        #Output size after convolution = 150
    
        self.conv1 = nn.Conv2d(in_channels = 3, out_channels = 12, kernel_size = 3, stride = 1, padding = 1)
        #Shape = (128, 12, 150, 150)
        self.bn1 = nn.BatchNorm2d(num_features = 12)
        #Shape = (128,12,150,150)
        self.relu1 = nn.ReLU()
    
        self.pool = nn.MaxPool2d(kernel_size = 2)
        #Reduce the image size by factor of 2
        #Shape = (128, 12, 75,75)
    
        self.conv2 = nn.Conv2d(in_channels = 12, out_channels = 20, kernel_size = 3, stride = 1, padding = 1)
        #Shape = (128, 20, 75, 75)
        self.relu2 = nn.ReLU()
        #Shape = (128,20,75,75)
    
        self.conv3 = nn.Conv2d(in_channels = 20, out_channels = 32, kernel_size = 3, stride = 1, padding = 1)
        #Shape = (128, 32, 75, 75)
        self.bn3 = nn.BatchNorm2d(num_features = 32)
        #Shape = (128,32,75,75)
        self.relu3 = nn.ReLU()
        #Shape = (128, 32, 75, 75)
    
        self.fc = nn.Linear(in_features = 32*75*75, out_features = num_pokemon)
    
        #Feed forward function
    
    def forward(self, input):
        output = self.conv1(input)
        output = self.bn1(output)
        output = self.relu1(output)
        
        output = self.pool(output)
        
        output = self.conv2(output)
        output = self.relu2(output)
        
        output = self.conv3(output)
        output = self.bn3(output)
        output = self.relu3(output)
        
        #Above outpput will be in matrix form, with shape (128, 32, 75, 75)
        
        output = output.view(-1, 32*75*75)
        
        output = self.fc(output)
        
        return output

In [26]:
checkpoint = torch.load('best_checkpoint.model')
model = ConvNet(num_pokemon = 151)
model.load_state_dict(checkpoint)
model.eval()

ConvNet(
  (conv1): Conv2d(3, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn1): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu1): ReLU()
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(12, 20, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu2): ReLU()
  (conv3): Conv2d(20, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu3): ReLU()
  (fc): Linear(in_features=180000, out_features=151, bias=True)
)

In [27]:
transformer = transforms.Compose([
    transforms.Resize((150,150)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], 
                         [0.5, 0.5, 0.5])
])

In [28]:
#prediction function

def prediction(img_path, transformer):
    
    image = Image.open(img_path)
    
    image_tensor = transformer(image).float()
    
    image_tensor = image_tensor.unsqueeze_(0)
    
    if torch.cuda.is_available():
        image_tensor.cuda()
    
    input = Variable(image_tensor)
    
    output = model(input)
    
    index = output.data.numpy().argmax()
    
    pred = classes[index]
    
    return pred

In [29]:
images_path = glob.glob(pred_path + '/*.jpg')

In [30]:
pred_dict = {}

for i in images_path:
    pred_dict[i[i.rfind('/') + 1:]] = prediction(i,transformer)

In [31]:
pred_dict

{'Check\\1.jpg': 'Bulbasaur',
 'Check\\10.jpg': 'Caterpie',
 'Check\\100.jpg': 'Voltorb',
 'Check\\101.jpg': 'Electrode',
 'Check\\102.jpg': 'Exeggcute',
 'Check\\103.jpg': 'Exeggutor',
 'Check\\104.jpg': 'Cubone',
 'Check\\105.jpg': 'Marowak',
 'Check\\106.jpg': 'Hitmonlee',
 'Check\\107.jpg': 'Hitmonchan',
 'Check\\108.jpg': 'Lickitung',
 'Check\\109.jpg': 'Koffing',
 'Check\\11.jpg': 'Metapod',
 'Check\\110.jpg': 'Weezing',
 'Check\\111.jpg': 'Rhyhorn',
 'Check\\112.jpg': 'Rhydon',
 'Check\\113.jpg': 'Chansey',
 'Check\\114.jpg': 'Tangela',
 'Check\\115.jpg': 'Kangaskhan',
 'Check\\116.jpg': 'Horsea',
 'Check\\117.jpg': 'Seadra',
 'Check\\118.jpg': 'Goldeen',
 'Check\\119.jpg': 'Seaking',
 'Check\\12.jpg': 'Butterfree',
 'Check\\120.jpg': 'Staryu',
 'Check\\121.jpg': 'Starmie',
 'Check\\122.jpg': 'MrMime',
 'Check\\123.jpg': 'Scyther',
 'Check\\124.jpg': 'Jynx',
 'Check\\125.jpg': 'Mankey',
 'Check\\126.jpg': 'Magmar',
 'Check\\127.jpg': 'Pinsir',
 'Check\\128.jpg': 'Tauros',
 'Chec