In [20]:
import torch
import torchvision.transforms as transforms
from PIL import Image
import torchvision.models as models


In [21]:
# Preprocessing applied to every input image before it reaches the network:
# resize to 96x96, convert to a float tensor, then normalise each channel.
transform = transforms.Compose(
    [
        transforms.Resize((96, 96)),
        transforms.ToTensor(),
        # Per-channel mean/std; these values match the commonly used ImageNet
        # statistics. Presumably the same ones used at training time -- TODO confirm.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [22]:
# Pick the compute device: use the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cpu


In [23]:
# Step 1: Rebuild the architecture: a ResNet-18 whose final fully connected
# layer is replaced by a 10-way classifier head.
rmodel = models.resnet18()
rmodel.fc = torch.nn.Linear(rmodel.fc.in_features, 10)  # STL10 = 10 classes

# Step 2: Load the checkpoint. map_location='cpu' remaps the stored tensors to
# the CPU; the model is moved to `device` afterwards.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints from a trusted source; consider passing
# weights_only=True if the installed PyTorch version supports it.
checkpoint = torch.load('model.pt', map_location='cpu')

# Accept both common layouts: a dict that wraps the weights under
# 'model_state_dict', or a bare state_dict saved directly.
if 'model_state_dict' in checkpoint:
    state_dict = checkpoint['model_state_dict']
else:
    state_dict = checkpoint
rmodel.load_state_dict(state_dict)

rmodel.to(device)
# Switch to evaluation mode for inference. As the last expression of the cell,
# this also displays the model's layer summary.
rmodel.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [24]:
# Load the test image. The context manager closes the underlying file once the
# pixel data has been copied by convert().
# convert('RGB') guarantees three channels: grayscale, CMYK or RGBA files would
# otherwise break the 3-channel Normalize step in `transform`.
with Image.open('cat.jpeg') as cat_file:
    cat = cat_file.convert('RGB')

# Preprocess, add a leading batch dimension, and move to the model's device.
cat_tensor = transform(cat).unsqueeze(0).to(device)

In [25]:
# Human-readable label for each class index; the list position is the class id.
class_names = [
    'airplane', 'bird', 'car', 'cat', 'deer',    # 0-4
    'dog', 'horse', 'monkey', 'ship', 'truck',   # 5-9
]

In [26]:
# Forward pass with gradient tracking disabled (inference only).
with torch.no_grad():
    ypred = rmodel(cat_tensor)

# Index of the highest score along dim 1, converted to a plain Python int so it
# can index into class_names.
pred_class = ypred.argmax(dim=1).item()
print("Predicted class:", class_names[pred_class])

Predicted class: cat
