## **Import packages**

In [None]:
import torch
import torchvision
from torchvision.datasets import ImageFolder
from torchvision import transforms, datasets, models
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import os
import time
import numpy as np
from matplotlib import pyplot as plt
from PIL import Image

import cv2
#from google.colab.patches import cv2_imshow

import torchsummary
from torch.nn import init

from datetime import datetime

import json

## **Set hardware**

In [3]:
# Prefer the first CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

device(type='cuda', index=0)

## **Load CNN models (ResNet50)**

In [None]:
#model = models.resnet50(pretrained=True)

In [None]:
def load_model(model_path,model_name,device):
    """Build a ResNet-50 and load its weights from a checkpoint under `model_path`.

    Args:
        model_path: directory that is searched recursively for checkpoint files.
        model_name: architecture name. Only 'resnet50' is implemented; any
            other value falls back to 'resnet50'.
        device: device the checkpoint tensors are mapped to and the model is
            moved to.

    Returns:
        The model with the checkpoint's state dict loaded, placed on `device`.

    Raises:
        FileNotFoundError: if no file named '<model_name>*.pth' exists under
            `model_path`.
    """
    # Only ResNet-50 is implemented, so every unknown name falls back to it.
    if model_name != 'resnet50':
        model_name = 'resnet50'

    # Sorted so the selected checkpoint does not depend on the directory
    # listing order when several files match.
    model_files = sorted(
        os.path.join(root, file)
        for root, dirs, files in os.walk(model_path)
        for file in files
        if file.endswith('.pth') and file.startswith(model_name)
    )

    # Fail with an actionable message instead of an IndexError on model_files[0].
    if not model_files:
        raise FileNotFoundError(
            "No '%s*.pth' checkpoint found under '%s'" % (model_name, model_path)
        )

    model = models.resnet50()

    # map_location lets a checkpoint that was saved on a GPU load on a machine
    # without one. NOTE: torch.load unpickles the file, so only load
    # checkpoints that come from a trusted source.
    pretrained_model = torch.load(model_files[0], map_location=device)

    model.load_state_dict(pretrained_model)

    model.to(device)

    return model

In [None]:
# Architecture name and the folder that is searched for its '.pth' checkpoint.
model_name = 'resnet50'
model_path = './pretrained_models'

model = load_model(model_path=model_path, model_name=model_name, device=device)

# Dump the layer-by-layer architecture.
print(model)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

## **Load dataset**

In [132]:
def testset_loader(path,transform,batch_size):
    """Create an ImageFolder dataset rooted at `path` and an unshuffled loader over it.

    Args:
        path: root directory handed to ImageFolder.
        transform: transform handed to ImageFolder.
        batch_size: batch size of the returned DataLoader.

    Returns:
        The pair (dataset, loader).
    """
    image_folder = ImageFolder(path, transform=transform)
    loader = torch.utils.data.DataLoader(
        dataset=image_folder,
        batch_size=batch_size,
        shuffle=False,
    )
    return image_folder, loader

In [None]:
filepath = './ImageNet_20_class/'
val_path = os.path.join(filepath, 'val-100k-rgb')

# Preprocessing applied to every validation image, in order: resize to 256,
# center-crop to 224, convert to a tensor, then normalize each channel with
# the mean/std listed below.
preprocessing_steps = [
    #transforms.Grayscale(num_output_channels=3),
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(preprocessing_steps)

# One image per batch.
batch_size = 1

val_dataset, val_loader = testset_loader(
    path=val_path,
    transform=transform,
    batch_size=batch_size,
)


In [None]:
# Echo the loader so the notebook displays it and confirms it was constructed.
val_loader

<torch.utils.data.dataloader.DataLoader at 0x7fd2ed0c1be0>

## **Generate mapped dictionary**

In [None]:
def list_direct_subfolders(path):
    """List the names of all immediate sub-folders of `path`, sorted as strings.

    Args:
        path: directory to inspect.

    Returns:
        Sub-folder names (not full paths) in ascending string order, so for
        example '12' comes before '2'.

    Raises:
        FileNotFoundError: if `path` does not exist.
        NotADirectoryError: if `path` exists but is not a directory.
    """
    # os.scandir raises a descriptive OSError for a bad path, whereas
    # next(os.walk(path)) would leak a bare StopIteration in that case.
    with os.scandir(path) as entries:
        folders = [entry.name for entry in entries if entry.is_dir()]

    folders.sort()

    return folders


def map_dict_generator(path):
    """Map positional index -> integer class id for the class folders under `path`.

    The folder names are taken in the string-sorted order returned by
    `list_direct_subfolders` and each name is parsed as an int, so the result
    looks like {0: 12, 1: 2, 2: 22, ...}.
    """
    class_ids = [int(folder_name) for folder_name in list_direct_subfolders(path)]

    return dict(enumerate(class_ids))

map_dict = map_dict_generator(val_path)
# The mapped class ids, in index order.
real_labels = [class_id for class_id in map_dict.values()]

print(map_dict)
print(real_labels)

{0: 12, 1: 2, 2: 22, 3: 72, 4: 82}
[12, 2, 22, 72, 82]


## **Single-image classification (sanity check only)**

In [None]:
# Pre-processing
preprocess = transforms.Compose([
    # transforms.Grayscale(num_output_channels=3),
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load the image for inference.
# Force 3-channel RGB: a PNG may be RGBA, palette or grayscale, which would not
# match the 3-channel Normalize above. The `with` block also closes the file.
with Image.open("test_rgb.png") as image_file:
    input_image = image_file.convert("RGB")
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0)  # Create a mini-batch as model input

# Model inference
# Ensure the model and the input live on the same hardware
model.eval()
input_batch = input_batch.to(device)
model.to(device)

with torch.no_grad():
    output = model(input_batch)

# Process the output
# `output` holds one score per model class. Keep only the scores of the
# selected classes, turn them into a probability distribution with softmax,
# and pick the most probable class.

# Filter for the selected classes
output_chosen = output[:, real_labels]
print(output_chosen)

prob_chosen = torch.nn.functional.softmax(output_chosen[0], dim=0)
print(prob_chosen)

# Get the index of the class with the highest probability and translate it
# back to its class id through map_dict
_, top_category = prob_chosen.max(0)
print("Predicted category:", map_dict[top_category.item()])

tensor([[15.1828, -2.5741,  0.1659, -2.2125,  6.1723]], device='cuda:0')
tensor([9.9988e-01, 1.9420e-08, 3.0075e-07, 2.7880e-08, 1.2212e-04],
       device='cuda:0')
Predicted category: 12


## **Dataset classification**

In [140]:
def testset_inference_with_chosen(device,net,test_loader,map_dict,dict_flag,chosen_flag):
    
    net.eval()

    correct = 0
    total = 0
    predict_list = []

    real_labels = list(map_dict.values())

    with torch.no_grad():

        for (inputs, labels) in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = net(inputs)

            # chosen_flag
            if chosen_flag == 1:
                outputs_chosen = outputs[:,real_labels]
            else:
                outputs_chosen = outputs
                
                
            prob = torch.nn.functional.softmax(outputs_chosen[0],dim=0)
            
            _, predicted = prob.max(0)
            
            # chosen_flag
            if chosen_flag == 1:
                predicted = map_dict[predicted.item()]
            else:
                predicted = predicted.item()
            
            label_index = labels.item()
            # print(label_index)
            if dict_flag == 1:
                real_label = map_dict[label_index]
                # print(label_index)
            else:
                real_label = label_index
    
            # print("Predicted category: %d, Ground truth: %d" %(predicted, real_label))
            
            total += labels.size(0)
            correct += (predicted == real_label)
    
            predict_list.append([predicted,real_label])

        
        test_accuracy=correct/total

        print("Total:%d, Correct:%d" %(total,correct))

        print("Accuracy: %.3f%%" %(test_accuracy*100))

    return predict_list,test_accuracy

In [None]:
# Both flags on: score only the classes listed in map_dict, and translate the
# ground-truth label indices through map_dict as well.
dict_flag = 1
chosen_flag = 1

val_list, val_acc = testset_inference_with_chosen(
    device=device,
    net=model,
    test_loader=val_loader,
    map_dict=map_dict,
    dict_flag=dict_flag,
    chosen_flag=chosen_flag,
)

Total:250, Correct:240
Accuracy: 96.000%
Total:250, Correct:242
Accuracy: 96.800%
Total:250, Correct:246
Accuracy: 98.400%
