# Use a pretrained ResNet CNN to classify concentration plots over 6-minute sampling
### First, import the dependencies

In [1]:
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
import os
from PIL import Image

# Ask cuDNN to benchmark its convolution algorithms and reuse the fastest one
# (PyTorch's documented autotuner switch).
cudnn.benchmark = True
# Turn on matplotlib's interactive mode. The call's return value is what the
# notebook echoes as this cell's output.
# NOTE(review): no plotting happens in the cells visible here — confirm this is still needed.
plt.ion()   # interactive mode

<contextlib.ExitStack at 0x793f5b13a9d0>

### Then, load the fine-tuned model weights (a `.pth` file) and move the model to GPU memory.

In [2]:
# Choose the inference device up front so the checkpoint is deserialized straight
# onto it; an unconditional `.cuda()` call would crash on CPU-only machines.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Build the bare ResNet-18 architecture. ImageNet weights are deliberately not
# downloaded: the strict load_state_dict() below overwrites every parameter and
# buffer with the values stored in the fine-tuned checkpoint.
model = torchvision.models.resnet18()
num_ftrs = model.fc.in_features                                 # input width of the stock classification head
model.fc = nn.Linear(num_ftrs, 2)                               # replace the head with a 2-output layer (experiment, may be more classes)
model.load_state_dict(
    torch.load("concen_identifier.pth", map_location=device)   # remap tensors saved on another device
)
model.to(device)                                                # move parameters and buffers to the chosen device
model.eval()                                                    # switch to inference mode; as the last expression it also displays the architecture



ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

### Define the preprocessing transformations for the test images.

In [3]:
def _make_pipeline(augment):
    """Build one preprocessing pipeline.

    Steps: resize to 224x224 (bilinear) -> optional random horizontal flip ->
    convert to tensor -> per-channel normalization.

    Args:
        augment: when true, a RandomHorizontalFlip step is inserted after the resize.

    Returns:
        A transforms.Compose object applying the steps in order.
    """
    steps = [
        transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BILINEAR),
    ]
    if augment:
        steps.append(transforms.RandomHorizontalFlip())
    steps.append(transforms.ToTensor())
    steps.append(
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    )
    return transforms.Compose(steps)


# The whole image is resized to the 224x224 network input size (no cropping),
# so no part of the picture is discarded.
data_transforms = {
    'train': _make_pipeline(augment=True),
    'val': _make_pipeline(augment=False),
}

### Define device variable (optional)

In [4]:
# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

### Define directory and classes

In [5]:
_dir = "tests"                      # folder holding the images to classify
_classes = ['near', 'far']          # label for each model output index (0 -> 'near', 1 -> 'far')
# NOTE(review): this order must match the class order used during fine-tuning — confirm.

# os.listdir returns entries in arbitrary order and also returns sub-directories
# and hidden files (e.g. `.ipynb_checkpoints`, `.DS_Store`), which cannot be
# opened as images. Keep regular, non-hidden files only, and sort them so every
# run processes the same images in the same order.
pic_name_list = _1st_dir = sorted(
    entry for entry in os.listdir(_dir)
    if not entry.startswith('.') and os.path.isfile(os.path.join(_dir, entry))
)

### Traverse the directory, read each image into CPU memory, move it to GPU memory, and finally run the prediction with the model loaded in GPU memory

In [6]:
count = 0                                               # number of correct predictions
result_cmp_set = {}                                     # file name -> predicted class label
for name in pic_name_list:                              # visit every test image
    # Open inside a context manager so the file handle is released right away;
    # convert() decodes the pixels into a new image that outlives the closed file.
    with Image.open(os.path.join(_dir, name)) as raw_img:
        img = raw_img.convert('RGB')                    # force a 3-channel RGB image
    img = data_transforms['val'](img)                   # apply the validation preprocessing
    img = img.unsqueeze(0)                              # add a leading batch dimension
    img = img.to(device)                                # move the input to the compute device

    with torch.no_grad():                               # inference only, no gradient tracking
        outputs = model(img)                            # predict
        _, preds = torch.max(outputs, 1)                # index of the highest score along dim 1
    predicted = _classes[preds.item()]                  # plain int index -> class label
    result_cmp_set[name] = predicted                    # add prediction to the collection
    # The ground truth is read from the file name: its second character is
    # compared with the predicted label's initial.
    # NOTE(review): relies on names shaped like `_near3.png` / `_far6.png` — confirm.
    if predicted[0] == name[1]:                         # if correctly predicted, count+1
        count += 1

### Check the results

In [7]:
total = len(pic_name_list)                              # number of images that were evaluated
# Guard the division: an empty test folder would otherwise raise ZeroDivisionError.
accuracy = count / total if total else float('nan')

print(result_cmp_set)                                   # file name -> predicted label
print('accuracy is: ', accuracy)                        # fraction of correctly classified images
print("Correct count = ", count)
print("Incorrect count = ", total - count)

{'_far6.png': 'far', '_near3.png': 'near', '_far5.png': 'far', '_far2.png': 'far', '_far8.png': 'far', '_far3.png': 'far', '_near6.png': 'near', '_far9.png': 'far', '_near9.png': 'near', '_near7.png': 'near', '_far10.png': 'far', '_far7.png': 'far', '_near2.png': 'near', '_far4.png': 'far', '_near4.png': 'far', '_near5.png': 'near', '_near8.png': 'near', '_far1.png': 'far', '_near1.png': 'near', '_near10.png': 'near'}
accuracy is:  0.95
Correct count =  19
Incorrect count =  1
