# Native TIMM Inference

In [None]:
# Path to your checkpoint file
checkpoint_path = "../model/train_ResNet18_2024-12-04_18-54-29/timm/timm_image/pytorch_model.bin"
image_path = "e:\\Current_Workdir\\palm-fruit-classification\\data\\intermediate\\valid\\empty_bunch\\IMG_20220803_112710_crop_0_jpg.rf.bfef2ca25d24fefe9a8c64c68c5bb66f.jpg"

# Configuration from your JSON
config = {
    "architecture": "resnet18",  # ResNet-18 architecture
    "num_classes": 6,            # Number of classes in your custom dataset
    "num_features": 512,         # Features size for ResNet-18
    "pretrained_cfg": {
        "tag": "a1_in1k",               # Pretraining tag (using ImageNet weights)
        "custom_load": False,
        "input_size": [3, 224, 224],    # Input size for training
        "test_input_size": [3, 288, 288],  # Test input size (larger than training)
        "fixed_input_size": False,
        "interpolation": "bicubic",      # Interpolation method
        "crop_pct": 0.95,                # Crop percentage for training
        "test_crop_pct": 1.0,            # Crop percentage for testing
        "crop_mode": "center",           # Center crop
        "mean": [0.485, 0.456, 0.406],   # Mean for ImageNet normalization
        "std": [0.229, 0.224, 0.225],    # Standard deviation for ImageNet normalization
        "num_classes": 1000,             # Default number of classes for ImageNet
        "pool_size": [7, 7],             # Pooling size after convolution
        "first_conv": "conv1",           # First convolutional layer
        "classifier": "fc",              # Final classifier layer
        "origin_url": "https://github.com/huggingface/pytorch-image-models",
        "paper_ids": "arXiv:2110.00476"
    }
}

In [11]:
import timm
import torch
from torchvision import transforms
from PIL import Image

# Instantiate the model
model = timm.create_model(
    config["architecture"],  # ResNet-18 model
    pretrained=False,        # Skip loading pretrained weights from timm
    num_classes=config["num_classes"],  # Adjust final layer for 6 classes
    global_pool="avg",       # Set global pooling (default is "avg" for ResNet)
)

# Load the checkpoint
checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))  # Adjust map_location as needed
if "state_dict" in checkpoint:
    state_dict = checkpoint["state_dict"]  # For structured checkpoint files
else:
    state_dict = checkpoint

# Strip prefixes if necessary (e.g., 'module.' when using DataParallel)
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}

# Load weights into the model
model.load_state_dict(state_dict, strict=False)  # Use strict=True for strict matching

# Example preprocessing pipeline
input_size = config["pretrained_cfg"]["input_size"][1:]  # (224, 224) for training
mean = config["pretrained_cfg"]["mean"]
std = config["pretrained_cfg"]["std"]

transform = transforms.Compose([
    transforms.Resize(input_size, interpolation=transforms.InterpolationMode.BICUBIC),  # Resize the image to 224x224
    transforms.CenterCrop(int(input_size[0] * config["pretrained_cfg"]["crop_pct"])),   # Apply center crop
    transforms.ToTensor(),                                                              # Convert image to tensor
    transforms.Normalize(mean=mean, std=std),                                            # Normalize using ImageNet mean and std
])

# Example image (replace with your own image file path)
image = Image.open(image_path).convert("RGB")
image_tensor = transform(image).unsqueeze(0)  # Add batch dimension

# Inference
model.eval()  # Set model to evaluation mode
with torch.no_grad():
    outputs = model(image_tensor)  # Forward pass
    predictions = torch.softmax(outputs, dim=1)  # Convert logits to probabilities

print("Predictions:", predictions)


Predictions: tensor([[0.1811, 0.2997, 0.0941, 0.2725, 0.0779, 0.0748]])


  checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))  # Adjust map_location as needed
