### Load Test Dataset

In [42]:
import os 
import pandas as pd
import torch
import numpy as np
import torchvision
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from PIL import Image
from torch.utils.data import DataLoader



class customDataset(Dataset):
    def __init__(self, csv_file, data_dir, transform=None):
        self.annotations = pd.read_csv(csv_file)
        self.data_dir = data_dir
        self.transform = transform
        
    def __len__(self):
        return len(self.annotations)
    
    def __getitem__(self, index):
        img_path = os.path.join(self.data_dir, self.annotations.iloc[index,1])
        image = Image.open(img_path)
        y_label = torch.tensor(int(self.annotations.iloc[index,2]))
        
        if self.transform:
            image = self.transform(image)
            
        return (image, y_label)
    

transform = transforms.Compose([
                                       transforms.Resize((224, 224)),
                                       transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

# test_set = customDataset(csv_file='./data/test_CapsuleEndos.csv', data_dir='./data/test_CapsuleEndos', \
#                       transform=transform)    # 10197, 2 patients
train_set = customDataset(csv_file='./data/train_CapsuleEndos.csv', data_dir='./data/train_CapsuleEndos', \
                      transform=transform) 
stomach = 0
small = 0
large = 0

for data, label in train_set:
    if label==0:
        stomach += 1
    elif label==1:
        small += 1
    elif label==2:
        large += 1

print(stomach, small, large)



3560 27282 6349


### Load Model

In [39]:
from efficientnet_pytorch import EfficientNet
model = EfficientNet.from_pretrained('efficientnet-b0', num_classes=3)
model

Loaded pretrained weights for efficientnet-b0


EfficientNet(
  (_conv_stem): Conv2dStaticSamePadding(
    3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False
    (static_padding): ZeroPad2d(padding=(0, 1, 0, 1), value=0.0)
  )
  (_bn0): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
  (_blocks): ModuleList(
    (0): MBConvBlock(
      (_depthwise_conv): Conv2dStaticSamePadding(
        32, 32, kernel_size=(3, 3), stride=[1, 1], groups=32, bias=False
        (static_padding): ZeroPad2d(padding=(1, 1, 1, 1), value=0.0)
      )
      (_bn1): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
      (_se_reduce): Conv2dStaticSamePadding(
        32, 8, kernel_size=(1, 1), stride=(1, 1)
        (static_padding): Identity()
      )
      (_se_expand): Conv2dStaticSamePadding(
        8, 32, kernel_size=(1, 1), stride=(1, 1)
        (static_padding): Identity()
      )
      (_project_conv): Conv2dStaticSamePadding(
        32, 16, kernel_size=

### Load Weights

In [40]:
model.load_state_dict(torch.load('model_effi-b0_weights.pth'))

RuntimeError: Error(s) in loading state_dict for EfficientNet:
	Missing key(s) in state_dict: "_conv_stem.weight", "_bn0.weight", "_bn0.bias", "_bn0.running_mean", "_bn0.running_var", "_blocks.0._depthwise_conv.weight", "_blocks.0._bn1.weight", "_blocks.0._bn1.bias", "_blocks.0._bn1.running_mean", "_blocks.0._bn1.running_var", "_blocks.0._se_reduce.weight", "_blocks.0._se_reduce.bias", "_blocks.0._se_expand.weight", "_blocks.0._se_expand.bias", "_blocks.0._project_conv.weight", "_blocks.0._bn2.weight", "_blocks.0._bn2.bias", "_blocks.0._bn2.running_mean", "_blocks.0._bn2.running_var", "_blocks.1._expand_conv.weight", "_blocks.1._bn0.weight", "_blocks.1._bn0.bias", "_blocks.1._bn0.running_mean", "_blocks.1._bn0.running_var", "_blocks.1._depthwise_conv.weight", "_blocks.1._bn1.weight", "_blocks.1._bn1.bias", "_blocks.1._bn1.running_mean", "_blocks.1._bn1.running_var", "_blocks.1._se_reduce.weight", "_blocks.1._se_reduce.bias", "_blocks.1._se_expand.weight", "_blocks.1._se_expand.bias", "_blocks.1._project_conv.weight", "_blocks.1._bn2.weight", "_blocks.1._bn2.bias", "_blocks.1._bn2.running_mean", "_blocks.1._bn2.running_var", "_blocks.2._expand_conv.weight", "_blocks.2._bn0.weight", "_blocks.2._bn0.bias", "_blocks.2._bn0.running_mean", "_blocks.2._bn0.running_var", "_blocks.2._depthwise_conv.weight", "_blocks.2._bn1.weight", "_blocks.2._bn1.bias", "_blocks.2._bn1.running_mean", "_blocks.2._bn1.running_var", "_blocks.2._se_reduce.weight", "_blocks.2._se_reduce.bias", "_blocks.2._se_expand.weight", "_blocks.2._se_expand.bias", "_blocks.2._project_conv.weight", "_blocks.2._bn2.weight", "_blocks.2._bn2.bias", "_blocks.2._bn2.running_mean", "_blocks.2._bn2.running_var", "_blocks.3._expand_conv.weight", "_blocks.3._bn0.weight", "_blocks.3._bn0.bias", "_blocks.3._bn0.running_mean", "_blocks.3._bn0.running_var", "_blocks.3._depthwise_conv.weight", "_blocks.3._bn1.weight", "_blocks.3._bn1.bias", "_blocks.3._bn1.running_mean", "_blocks.3._bn1.running_var", "_blocks.3._se_reduce.weight", "_blocks.3._se_reduce.bias", "_blocks.3._se_expand.weight", "_blocks.3._se_expand.bias", "_blocks.3._project_conv.weight", "_blocks.3._bn2.weight", "_blocks.3._bn2.bias", "_blocks.3._bn2.running_mean", "_blocks.3._bn2.running_var", "_blocks.4._expand_conv.weight", "_blocks.4._bn0.weight", "_blocks.4._bn0.bias", "_blocks.4._bn0.running_mean", "_blocks.4._bn0.running_var", "_blocks.4._depthwise_conv.weight", "_blocks.4._bn1.weight", "_blocks.4._bn1.bias", "_blocks.4._bn1.running_mean", "_blocks.4._bn1.running_var", "_blocks.4._se_reduce.weight", "_blocks.4._se_reduce.bias", "_blocks.4._se_expand.weight", "_blocks.4._se_expand.bias", "_blocks.4._project_conv.weight", "_blocks.4._bn2.weight", "_blocks.4._bn2.bias", "_blocks.4._bn2.running_mean", "_blocks.4._bn2.running_var", "_blocks.5._expand_conv.weight", "_blocks.5._bn0.weight", "_blocks.5._bn0.bias", "_blocks.5._bn0.running_mean", "_blocks.5._bn0.running_var", "_blocks.5._depthwise_conv.weight", "_blocks.5._bn1.weight", "_blocks.5._bn1.bias", "_blocks.5._bn1.running_mean", "_blocks.5._bn1.running_var", "_blocks.5._se_reduce.weight", "_blocks.5._se_reduce.bias", "_blocks.5._se_expand.weight", "_blocks.5._se_expand.bias", "_blocks.5._project_conv.weight", "_blocks.5._bn2.weight", "_blocks.5._bn2.bias", "_blocks.5._bn2.running_mean", "_blocks.5._bn2.running_var", "_blocks.6._expand_conv.weight", "_blocks.6._bn0.weight", "_blocks.6._bn0.bias", "_blocks.6._bn0.running_mean", "_blocks.6._bn0.running_var", "_blocks.6._depthwise_conv.weight", "_blocks.6._bn1.weight", "_blocks.6._bn1.bias", "_blocks.6._bn1.running_mean", "_blocks.6._bn1.running_var", "_blocks.6._se_reduce.weight", "_blocks.6._se_reduce.bias", "_blocks.6._se_expand.weight", "_blocks.6._se_expand.bias", "_blocks.6._project_conv.weight", "_blocks.6._bn2.weight", "_blocks.6._bn2.bias", "_blocks.6._bn2.running_mean", "_blocks.6._bn2.running_var", "_blocks.7._expand_conv.weight", "_blocks.7._bn0.weight", "_blocks.7._bn0.bias", "_blocks.7._bn0.running_mean", "_blocks.7._bn0.running_var", "_blocks.7._depthwise_conv.weight", "_blocks.7._bn1.weight", "_blocks.7._bn1.bias", "_blocks.7._bn1.running_mean", "_blocks.7._bn1.running_var", "_blocks.7._se_reduce.weight", "_blocks.7._se_reduce.bias", "_blocks.7._se_expand.weight", "_blocks.7._se_expand.bias", "_blocks.7._project_conv.weight", "_blocks.7._bn2.weight", "_blocks.7._bn2.bias", "_blocks.7._bn2.running_mean", "_blocks.7._bn2.running_var", "_blocks.8._expand_conv.weight", "_blocks.8._bn0.weight", "_blocks.8._bn0.bias", "_blocks.8._bn0.running_mean", "_blocks.8._bn0.running_var", "_blocks.8._depthwise_conv.weight", "_blocks.8._bn1.weight", "_blocks.8._bn1.bias", "_blocks.8._bn1.running_mean", "_blocks.8._bn1.running_var", "_blocks.8._se_reduce.weight", "_blocks.8._se_reduce.bias", "_blocks.8._se_expand.weight", "_blocks.8._se_expand.bias", "_blocks.8._project_conv.weight", "_blocks.8._bn2.weight", "_blocks.8._bn2.bias", "_blocks.8._bn2.running_mean", "_blocks.8._bn2.running_var", "_blocks.9._expand_conv.weight", "_blocks.9._bn0.weight", "_blocks.9._bn0.bias", "_blocks.9._bn0.running_mean", "_blocks.9._bn0.running_var", "_blocks.9._depthwise_conv.weight", "_blocks.9._bn1.weight", "_blocks.9._bn1.bias", "_blocks.9._bn1.running_mean", "_blocks.9._bn1.running_var", "_blocks.9._se_reduce.weight", "_blocks.9._se_reduce.bias", "_blocks.9._se_expand.weight", "_blocks.9._se_expand.bias", "_blocks.9._project_conv.weight", "_blocks.9._bn2.weight", "_blocks.9._bn2.bias", "_blocks.9._bn2.running_mean", "_blocks.9._bn2.running_var", "_blocks.10._expand_conv.weight", "_blocks.10._bn0.weight", "_blocks.10._bn0.bias", "_blocks.10._bn0.running_mean", "_blocks.10._bn0.running_var", "_blocks.10._depthwise_conv.weight", "_blocks.10._bn1.weight", "_blocks.10._bn1.bias", "_blocks.10._bn1.running_mean", "_blocks.10._bn1.running_var", "_blocks.10._se_reduce.weight", "_blocks.10._se_reduce.bias", "_blocks.10._se_expand.weight", "_blocks.10._se_expand.bias", "_blocks.10._project_conv.weight", "_blocks.10._bn2.weight", "_blocks.10._bn2.bias", "_blocks.10._bn2.running_mean", "_blocks.10._bn2.running_var", "_blocks.11._expand_conv.weight", "_blocks.11._bn0.weight", "_blocks.11._bn0.bias", "_blocks.11._bn0.running_mean", "_blocks.11._bn0.running_var", "_blocks.11._depthwise_conv.weight", "_blocks.11._bn1.weight", "_blocks.11._bn1.bias", "_blocks.11._bn1.running_mean", "_blocks.11._bn1.running_var", "_blocks.11._se_reduce.weight", "_blocks.11._se_reduce.bias", "_blocks.11._se_expand.weight", "_blocks.11._se_expand.bias", "_blocks.11._project_conv.weight", "_blocks.11._bn2.weight", "_blocks.11._bn2.bias", "_blocks.11._bn2.running_mean", "_blocks.11._bn2.running_var", "_blocks.12._expand_conv.weight", "_blocks.12._bn0.weight", "_blocks.12._bn0.bias", "_blocks.12._bn0.running_mean", "_blocks.12._bn0.running_var", "_blocks.12._depthwise_conv.weight", "_blocks.12._bn1.weight", "_blocks.12._bn1.bias", "_blocks.12._bn1.running_mean", "_blocks.12._bn1.running_var", "_blocks.12._se_reduce.weight", "_blocks.12._se_reduce.bias", "_blocks.12._se_expand.weight", "_blocks.12._se_expand.bias", "_blocks.12._project_conv.weight", "_blocks.12._bn2.weight", "_blocks.12._bn2.bias", "_blocks.12._bn2.running_mean", "_blocks.12._bn2.running_var", "_blocks.13._expand_conv.weight", "_blocks.13._bn0.weight", "_blocks.13._bn0.bias", "_blocks.13._bn0.running_mean", "_blocks.13._bn0.running_var", "_blocks.13._depthwise_conv.weight", "_blocks.13._bn1.weight", "_blocks.13._bn1.bias", "_blocks.13._bn1.running_mean", "_blocks.13._bn1.running_var", "_blocks.13._se_reduce.weight", "_blocks.13._se_reduce.bias", "_blocks.13._se_expand.weight", "_blocks.13._se_expand.bias", "_blocks.13._project_conv.weight", "_blocks.13._bn2.weight", "_blocks.13._bn2.bias", "_blocks.13._bn2.running_mean", "_blocks.13._bn2.running_var", "_blocks.14._expand_conv.weight", "_blocks.14._bn0.weight", "_blocks.14._bn0.bias", "_blocks.14._bn0.running_mean", "_blocks.14._bn0.running_var", "_blocks.14._depthwise_conv.weight", "_blocks.14._bn1.weight", "_blocks.14._bn1.bias", "_blocks.14._bn1.running_mean", "_blocks.14._bn1.running_var", "_blocks.14._se_reduce.weight", "_blocks.14._se_reduce.bias", "_blocks.14._se_expand.weight", "_blocks.14._se_expand.bias", "_blocks.14._project_conv.weight", "_blocks.14._bn2.weight", "_blocks.14._bn2.bias", "_blocks.14._bn2.running_mean", "_blocks.14._bn2.running_var", "_blocks.15._expand_conv.weight", "_blocks.15._bn0.weight", "_blocks.15._bn0.bias", "_blocks.15._bn0.running_mean", "_blocks.15._bn0.running_var", "_blocks.15._depthwise_conv.weight", "_blocks.15._bn1.weight", "_blocks.15._bn1.bias", "_blocks.15._bn1.running_mean", "_blocks.15._bn1.running_var", "_blocks.15._se_reduce.weight", "_blocks.15._se_reduce.bias", "_blocks.15._se_expand.weight", "_blocks.15._se_expand.bias", "_blocks.15._project_conv.weight", "_blocks.15._bn2.weight", "_blocks.15._bn2.bias", "_blocks.15._bn2.running_mean", "_blocks.15._bn2.running_var", "_conv_head.weight", "_bn1.weight", "_bn1.bias", "_bn1.running_mean", "_bn1.running_var", "_fc.weight", "_fc.bias". 
	Unexpected key(s) in state_dict: "conv1.weight", "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "bn1.num_batches_tracked", "layer1.0.conv1.weight", "layer1.0.bn1.weight", "layer1.0.bn1.bias", "layer1.0.bn1.running_mean", "layer1.0.bn1.running_var", "layer1.0.bn1.num_batches_tracked", "layer1.0.conv2.weight", "layer1.0.bn2.weight", "layer1.0.bn2.bias", "layer1.0.bn2.running_mean", "layer1.0.bn2.running_var", "layer1.0.bn2.num_batches_tracked", "layer1.1.conv1.weight", "layer1.1.bn1.weight", "layer1.1.bn1.bias", "layer1.1.bn1.running_mean", "layer1.1.bn1.running_var", "layer1.1.bn1.num_batches_tracked", "layer1.1.conv2.weight", "layer1.1.bn2.weight", "layer1.1.bn2.bias", "layer1.1.bn2.running_mean", "layer1.1.bn2.running_var", "layer1.1.bn2.num_batches_tracked", "layer1.2.conv1.weight", "layer1.2.bn1.weight", "layer1.2.bn1.bias", "layer1.2.bn1.running_mean", "layer1.2.bn1.running_var", "layer1.2.bn1.num_batches_tracked", "layer1.2.conv2.weight", "layer1.2.bn2.weight", "layer1.2.bn2.bias", "layer1.2.bn2.running_mean", "layer1.2.bn2.running_var", "layer1.2.bn2.num_batches_tracked", "layer2.0.conv1.weight", "layer2.0.bn1.weight", "layer2.0.bn1.bias", "layer2.0.bn1.running_mean", "layer2.0.bn1.running_var", "layer2.0.bn1.num_batches_tracked", "layer2.0.conv2.weight", "layer2.0.bn2.weight", "layer2.0.bn2.bias", "layer2.0.bn2.running_mean", "layer2.0.bn2.running_var", "layer2.0.bn2.num_batches_tracked", "layer2.0.downsample.0.weight", "layer2.0.downsample.1.weight", "layer2.0.downsample.1.bias", "layer2.0.downsample.1.running_mean", "layer2.0.downsample.1.running_var", "layer2.0.downsample.1.num_batches_tracked", "layer2.1.conv1.weight", "layer2.1.bn1.weight", "layer2.1.bn1.bias", "layer2.1.bn1.running_mean", "layer2.1.bn1.running_var", "layer2.1.bn1.num_batches_tracked", "layer2.1.conv2.weight", "layer2.1.bn2.weight", "layer2.1.bn2.bias", "layer2.1.bn2.running_mean", "layer2.1.bn2.running_var", "layer2.1.bn2.num_batches_tracked", "layer2.2.conv1.weight", "layer2.2.bn1.weight", "layer2.2.bn1.bias", "layer2.2.bn1.running_mean", "layer2.2.bn1.running_var", "layer2.2.bn1.num_batches_tracked", "layer2.2.conv2.weight", "layer2.2.bn2.weight", "layer2.2.bn2.bias", "layer2.2.bn2.running_mean", "layer2.2.bn2.running_var", "layer2.2.bn2.num_batches_tracked", "layer2.3.conv1.weight", "layer2.3.bn1.weight", "layer2.3.bn1.bias", "layer2.3.bn1.running_mean", "layer2.3.bn1.running_var", "layer2.3.bn1.num_batches_tracked", "layer2.3.conv2.weight", "layer2.3.bn2.weight", "layer2.3.bn2.bias", "layer2.3.bn2.running_mean", "layer2.3.bn2.running_var", "layer2.3.bn2.num_batches_tracked", "layer3.0.conv1.weight", "layer3.0.bn1.weight", "layer3.0.bn1.bias", "layer3.0.bn1.running_mean", "layer3.0.bn1.running_var", "layer3.0.bn1.num_batches_tracked", "layer3.0.conv2.weight", "layer3.0.bn2.weight", "layer3.0.bn2.bias", "layer3.0.bn2.running_mean", "layer3.0.bn2.running_var", "layer3.0.bn2.num_batches_tracked", "layer3.0.downsample.0.weight", "layer3.0.downsample.1.weight", "layer3.0.downsample.1.bias", "layer3.0.downsample.1.running_mean", "layer3.0.downsample.1.running_var", "layer3.0.downsample.1.num_batches_tracked", "layer3.1.conv1.weight", "layer3.1.bn1.weight", "layer3.1.bn1.bias", "layer3.1.bn1.running_mean", "layer3.1.bn1.running_var", "layer3.1.bn1.num_batches_tracked", "layer3.1.conv2.weight", "layer3.1.bn2.weight", "layer3.1.bn2.bias", "layer3.1.bn2.running_mean", "layer3.1.bn2.running_var", "layer3.1.bn2.num_batches_tracked", "layer3.2.conv1.weight", "layer3.2.bn1.weight", "layer3.2.bn1.bias", "layer3.2.bn1.running_mean", "layer3.2.bn1.running_var", "layer3.2.bn1.num_batches_tracked", "layer3.2.conv2.weight", "layer3.2.bn2.weight", "layer3.2.bn2.bias", "layer3.2.bn2.running_mean", "layer3.2.bn2.running_var", "layer3.2.bn2.num_batches_tracked", "layer3.3.conv1.weight", "layer3.3.bn1.weight", "layer3.3.bn1.bias", "layer3.3.bn1.running_mean", "layer3.3.bn1.running_var", "layer3.3.bn1.num_batches_tracked", "layer3.3.conv2.weight", "layer3.3.bn2.weight", "layer3.3.bn2.bias", "layer3.3.bn2.running_mean", "layer3.3.bn2.running_var", "layer3.3.bn2.num_batches_tracked", "layer3.4.conv1.weight", "layer3.4.bn1.weight", "layer3.4.bn1.bias", "layer3.4.bn1.running_mean", "layer3.4.bn1.running_var", "layer3.4.bn1.num_batches_tracked", "layer3.4.conv2.weight", "layer3.4.bn2.weight", "layer3.4.bn2.bias", "layer3.4.bn2.running_mean", "layer3.4.bn2.running_var", "layer3.4.bn2.num_batches_tracked", "layer3.5.conv1.weight", "layer3.5.bn1.weight", "layer3.5.bn1.bias", "layer3.5.bn1.running_mean", "layer3.5.bn1.running_var", "layer3.5.bn1.num_batches_tracked", "layer3.5.conv2.weight", "layer3.5.bn2.weight", "layer3.5.bn2.bias", "layer3.5.bn2.running_mean", "layer3.5.bn2.running_var", "layer3.5.bn2.num_batches_tracked", "layer4.0.conv1.weight", "layer4.0.bn1.weight", "layer4.0.bn1.bias", "layer4.0.bn1.running_mean", "layer4.0.bn1.running_var", "layer4.0.bn1.num_batches_tracked", "layer4.0.conv2.weight", "layer4.0.bn2.weight", "layer4.0.bn2.bias", "layer4.0.bn2.running_mean", "layer4.0.bn2.running_var", "layer4.0.bn2.num_batches_tracked", "layer4.0.downsample.0.weight", "layer4.0.downsample.1.weight", "layer4.0.downsample.1.bias", "layer4.0.downsample.1.running_mean", "layer4.0.downsample.1.running_var", "layer4.0.downsample.1.num_batches_tracked", "layer4.1.conv1.weight", "layer4.1.bn1.weight", "layer4.1.bn1.bias", "layer4.1.bn1.running_mean", "layer4.1.bn1.running_var", "layer4.1.bn1.num_batches_tracked", "layer4.1.conv2.weight", "layer4.1.bn2.weight", "layer4.1.bn2.bias", "layer4.1.bn2.running_mean", "layer4.1.bn2.running_var", "layer4.1.bn2.num_batches_tracked", "layer4.2.conv1.weight", "layer4.2.bn1.weight", "layer4.2.bn1.bias", "layer4.2.bn1.running_mean", "layer4.2.bn1.running_var", "layer4.2.bn1.num_batches_tracked", "layer4.2.conv2.weight", "layer4.2.bn2.weight", "layer4.2.bn2.bias", "layer4.2.bn2.running_mean", "layer4.2.bn2.running_var", "layer4.2.bn2.num_batches_tracked", "fc.weight", "fc.bias". 

### Check Function

In [37]:
# model.eval()
import numpy as np


def check_accuracy(loader, model):
    num_correct = 0
    num_samples = 0
    
    Conf_matrix = np.zeros((3, 3))
    
    model.eval()
    
    with torch.no_grad():
        for images, labels in loader:
            images = images.to('cuda')
            labels = labels.to('cuda')
            results = model(images)
            _, predictions = results.max(1)
#             print(predictions)
            if labels==0:
                if predictions==labels:
                    Conf_matrix[0][0] += 1
                elif predictions==1:
                    Conf_matrix[0][1] += 1
                elif predictions==2:
                    Conf_matrix[0][2] += 1
            elif labels==1:
                if predictions==labels:
                    Conf_matrix[1][1] += 1
                elif predictions==0:
                    Conf_matrix[1][0] += 1
                elif predictions==2:
                    Conf_matrix[1][2] += 1
            elif labels==2:
                if predictions==labels:
                    Conf_matrix[2][2] += 1
                elif predictions==0:
                    Conf_matrix[2][0] += 1
                elif predictions==1:
                    Conf_matrix[2][1] += 1
                    
            num_correct += (predictions==labels).sum()
            num_samples += predictions.size(0)
        
        print(f'Got{num_correct}/{num_samples} with accuracy \
        {float(num_correct)/float(num_samples)*100}%')
        np.set_printoptions(suppress=True)
        print(Conf_matrix)
        
test_loader = DataLoader(dataset=test_set, batch_size=1, shuffle=True)        
check_accuracy(test_loader, model)

Got9549/10197 with accuracy         93.64518976169461%
[[ 510.    5.   83.]
 [  14. 9030.  155.]
 [   0.  391.    9.]]


In [26]:
import numpy as np

tesfsd =np.zeros((3,3))
tesfsd.shape


(3, 3)