In [1]:
import torch
import torchvision.transforms as transforms
from d2l import torch as d2l
from collections import OrderedDict
from PIL import Image
import model.model as model

In [10]:
devices = d2l.try_all_gpus()
file_dir = r"/home/dell/Research/Disk/1_dataset/Classification_3/upload/Powder_SingleCrystal_Liquid_ResNet18/"
checkpoint_path = file_dir + r"/checkpoint/1retrain_resnet18_batchsize_50_epoch_150_lp_1_ld_0.99.pth"

In [7]:
net = model.get_net(pretrained = True).net
checkpoint = torch.load(checkpoint_path, map_location='cpu')
new_state_dict = OrderedDict()
for k, v in checkpoint.items():
    if k.startswith('module.'):
        new_state_dict[k[7:]] = v
    else:
        new_state_dict[k] = v
net.load_state_dict(new_state_dict)

net = net.to(devices[0])
net.eval()  

  checkpoint = torch.load(checkpoint_path, map_location='cpu')


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [8]:
classes = ['Liquid', 'Powder', 'SingleCrystal']

In [9]:
def preprocess_image(image_path):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 
    ])
    image = Image.open(image_path).convert('RGB')
    return transform(image).unsqueeze(0)  
def predict_image(image_path):
    image_tensor = preprocess_image(image_path)
    image_tensor = image_tensor.to(devices[0])
    
    with torch.no_grad():
        output = net(image_tensor)
    
    _, predicted = torch.max(output, 1)
    predicted_class = predicted.item()
    predicted_label = classes[predicted_class]
    
    return predicted_class, predicted_label

In [14]:
image_path = file_dir + r"/sample/Crystal4.JPG"  
predicted_class, predicted_label = predict_image(image_path)

print(f'Predicted Class ID: {predicted_class}')
print(f'Predicted Class Name: {predicted_label}')

Predicted Class ID: 2
Predicted Class Name: SingleCrystal
