In [None]:
# Mount Google Drive at /content/drive/ so the course files stored on Drive are accessible.
from google.colab import drive

drive.mount(mountpoint='/content/drive/')

In [None]:
# Work from the course project folder on the mounted Drive so the relative paths used in
# this notebook (models/, weights/, data/, save_images) resolve against it.
%cd /content/drive/MyDrive/image_classification_course

**Import Required Libraries**

In [None]:
import os
import torch

import numpy as np
import cv2

from models.resnet import resnet_50

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


**Define Parameters**

In [None]:

# Location of the trained weights, relative to the project folder.
model_checkpoint_path = "weights/resnet/resnet_model_checkpoint.pth"

# Network / input configuration.
num_classes = 7  # size of the classifier output; indexes into the sorted class-folder list
in_channel = 3   # input channels; not referenced by the cells shown in this notebook
# (width, height) handed to cv2.resize. NOTE(review): 244 is unusual - standard ResNet
# inputs are 224x224; confirm this matches the resolution used during training.
image_shape = (244, 244)

**Load Model**

In [None]:
# Build the network and load the trained weights *before* compiling it.
# `resnet_50` is the constructor imported at the top of the notebook.
# NOTE(review): assumes resnet_50 accepts a `num_classes` keyword - confirm against models/resnet.py.
model = resnet_50(num_classes=num_classes)

# map_location lets a checkpoint saved on a GPU be restored on a CPU-only runtime.
checkpoint = torch.load(model_checkpoint_path, map_location=device)

# A state_dict saved from a torch.compile()'d model has every key prefixed with
# "_orig_mod."; strip that prefix so the weights load into the plain module whether
# or not the training run compiled the model.
compiled_prefix = "_orig_mod."
state_dict = {
    (key[len(compiled_prefix):] if key.startswith(compiled_prefix) else key): value
    for key, value in checkpoint.items()
}
model.load_state_dict(state_dict)

model.to(device)
model.eval()  # evaluation mode (affects layers such as dropout / batch-norm)

# Compile only once the weights are in place.
model = torch.compile(model)

**Create Output Folder for Annotated Images**

In [None]:
# Annotated prediction images are written under this folder (relative to the working directory).
save_image_path = "save_images"

# exist_ok=True makes the cell idempotent on re-run and avoids the check-then-create
# race of a separate os.path.exists() test.
os.makedirs(save_image_path, exist_ok=True)

**Collect All Test Image Paths**

In [None]:
# Expected layout: <test_dataset_path>/<class_name>/<image_file>.
test_dataset_path = "data/version1/test"

# Only sub-directories are classes. Stray files (e.g. .DS_Store) would otherwise become
# bogus class names, shift the index -> class mapping, and crash os.listdir below.
# Sorting gives a stable mapping. NOTE(review): this must reproduce the class order used
# at training time (torchvision's ImageFolder, for example, also sorts directory names) - confirm.
classes = sorted(
    entry
    for entry in os.listdir(test_dataset_path)
    if os.path.isdir(os.path.join(test_dataset_path, entry))
)

images_path_list = []
for clc in classes:
    class_dir = os.path.join(test_dataset_path, clc)
    # Sorted so the processing order is deterministic across runs / filesystems.
    for img_name in sorted(os.listdir(class_dir)):
        img_path = os.path.join(class_dir, img_name)
        # Skip nested directories or other non-file entries.
        if os.path.isfile(img_path):
            images_path_list.append(img_path)


In [None]:
# Spot-check the first ten collected image paths.
print(images_path_list[:10])

In [None]:

# Classify every collected test image and save a copy annotated with the ground-truth
# class (first text line) and the predicted class (second text line).

# The model's output index is looked up in `classes`, so the two must agree in size.
assert len(classes) == num_classes, (
    f"found {len(classes)} class folders but the model has {num_classes} outputs"
)


def preprocess_image(image_bgr, size):
    """Convert an OpenCV BGR image into a single-image batch tensor for the model.

    Args:
        image_bgr: image array as returned by cv2.imread, laid out (height, width, channels).
        size: (width, height) target size passed to cv2.resize.

    Returns:
        Tensor of shape (1, channels, height, width), scaled by 1/255, placed on `device`.
    """
    img = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, size)
    # (height, width, channels) -> (channels, height, width), then add a leading batch axis.
    img = np.transpose(img, (2, 0, 1))
    img = np.expand_dims(img, axis=0)
    img = img / 255
    return torch.Tensor(img).to(device)


skipped = 0
for img_path in images_path_list:
    img_name = os.path.basename(img_path)
    # The parent folder name is the ground-truth class.
    actual_class = os.path.basename(os.path.dirname(img_path))

    image = cv2.imread(img_path)
    if image is None:
        # cv2.imread signals unreadable / non-image files by returning None, not raising.
        print(f"skipping unreadable image: {img_path}")
        skipped += 1
        continue

    with torch.no_grad():
        pred_score = model(preprocess_image(image, image_shape))

    class_index = pred_score.argmax(dim=1).item()
    pred_class = classes[class_index]

    # Draw on the original-resolution image; (255, 0, 0) is blue in OpenCV's BGR order.
    cv2.putText(image, actual_class, (10, 10), cv2.FONT_HERSHEY_PLAIN, 1, (255, 0, 0), 1, cv2.LINE_AA)
    cv2.putText(image, pred_class, (10, 20), cv2.FONT_HERSHEY_PLAIN, 1, (255, 0, 0), 1, cv2.LINE_AA)

    # Mirror the class sub-folders so identically named files from different
    # classes do not overwrite each other.
    out_dir = os.path.join(save_image_path, actual_class)
    os.makedirs(out_dir, exist_ok=True)
    # cv2.imwrite reports failure via its return value rather than raising.
    if not cv2.imwrite(os.path.join(out_dir, img_name), image):
        print(f"failed to write annotated image for: {img_path}")

if skipped:
    print(f"skipped {skipped} unreadable image(s)")
print("inferencing complete")