In [14]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as Models
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image

Device = torch.device("cuda")

class DenseNetSequenceMatch(nn.Module):
    def __init__(self, NumClasses=100):
        super(DenseNetSequenceMatch, self).__init__()
        Backbone = Models.densenet121(weights=Models.DenseNet121_Weights.DEFAULT)
        Backbone.features.conv0 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
        nn.init.kaiming_normal_(Backbone.features.conv0.weight, mode="fan_out", nonlinearity="relu")
        self.Features = Backbone.features
        self.Classifier = nn.Linear(Backbone.classifier.in_features, NumClasses)

    def forward(self, X):
        FeaturesOut = self.Features(X)
        ReluOut = F.relu(FeaturesOut, inplace=True)
        PooledOut = F.adaptive_avg_pool2d(ReluOut, (1, 1))
        Flattened = torch.flatten(PooledOut, 1)
        Logits = self.Classifier(Flattened)
        return Logits

def GetInferenceTransform():
    Mean = (0.5, 0.5, 0.5)
    Std = (0.5, 0.5, 0.5)
    Transform = A.Compose([
        A.Resize(32, 32),
        A.Normalize(mean=Mean, std=Std, max_pixel_value=255.0),
        ToTensorV2()
    ])
    return Transform

IdxToClass = {
    0: "beaver",
    1: "dolphin",
    2: "otter",
    3: "seal",
    4: "whale",
    5: "aquarium fish",
    6: "flatfish",
    7: "ray",
    8: "shark",
    9: "trout",
    10: "orchid",
    11: "poppy",
    12: "rose",
    13: "sunflower",
    14: "tulip",
    15: "bottle",
    16: "bowl",
    17: "can",
    18: "cup",
    19: "plate",
    20: "apple",
    21: "mushroom",
    22: "orange",
    23: "pear",
    24: "sweet pepper",
    25: "clock",
    26: "keyboard",
    27: "lamp",
    28: "telephone",
    29: "television",
    30: "bed",
    31: "chair",
    32: "couch",
    33: "table",
    34: "wardrobe",
    35: "bee",
    36: "beetle",
    37: "butterfly",
    38: "caterpillar",
    39: "cockroach",
    40: "bear",
    41: "leopard",
    42: "lion",
    43: "tiger",
    44: "wolf",
    45: "bridge",
    46: "castle",
    47: "house",
    48: "road",
    49: "skyscraper",
    50: "cloud",
    51: "forest",
    52: "mountain",
    53: "plain",
    54: "sea",
    55: "camel",
    56: "cattle",
    57: "chimpanzee",
    58: "elephant",
    59: "kangaroo",
    60: "fox",
    61: "porcupine",
    62: "possum",
    63: "raccoon",
    64: "skunk",
    65: "crab",
    66: "lobster",
    67: "snail",
    68: "spider",
    69: "worm",
    70: "baby",
    71: "boy",
    72: "girl",
    73: "man",
    74: "woman",
    75: "crocodile",
    76: "dinosaur",
    77: "lizard",
    78: "snake",
    79: "turtle",
    80: "hamster",
    81: "mouse",
    82: "rabbit",
    83: "shrew",
    84: "squirrel",
    85: "maple tree",
    86: "oak tree",
    87: "palm tree",
    88: "pine tree",
    89: "willow tree",
    90: "bicycle",
    91: "bus",
    92: "motorcycle",
    93: "pickup truck",
    94: "train",
    95: "lawn mower",
    96: "rocket",
    97: "streetcar",
    98: "tank",
    99: "tractor"
}

def ClassifyImage(ImagePath, WeightPath, IdxToClass):
    Transform = GetInferenceTransform()
    ImageObj = Image.open(ImagePath).convert("RGB")
    ImageNp = np.array(ImageObj)
    Augmented = Transform(image=ImageNp)
    ImgTensor = Augmented["image"].unsqueeze(0).float().to(Device)
    ModelObj = DenseNetSequenceMatch(NumClasses=100)
    ModelObj.load_state_dict(torch.load(WeightPath, map_location=Device))
    ModelObj.to(Device)
    ModelObj.eval()
    with torch.no_grad():
        Logits = ModelObj(ImgTensor)
        Probs = F.softmax(Logits, dim=1)
        PredIdx = torch.argmax(Probs, dim=1).item()
    PredictedClass = IdxToClass.get(PredIdx, f"Class #{PredIdx}")
    print(f"Predicted Class Index: {PredIdx} -> {PredictedClass}")
    return PredictedClass

WeightPath = r"C:/Users/clash/Downloads/mkc/Sequence Match/SequenceMatch_CIFAR100_260_labeled.pth"
ImagePath = r"C:/Users/clash/Downloads/mkc/3.png"
ClassifyImage(ImagePath, WeightPath, IdxToClass)

Predicted Class Index: 53 -> plain


'plain'