In [1]:
# ILSVRC2012 devkit archive; torchvision.datasets.ImageNet expects it in the dataset root:
# https://image-net.org/data/ILSVRC/2012/ILSVRC2012_devkit_t12.tar.gz

from typing import Callable, Optional

import torch
import torchvision
import torchvision.models as models

def _to_tensor_if_needed(img):
    """Convert a non-tensor image (e.g. a PIL image) to a tensor; pass tensors through unchanged.

    Defined at module level rather than nested inside ``get_imagenet12_ds``:
    DataLoader worker processes started with the ``spawn`` method must pickle
    the transform, and locally defined functions cannot be pickled.
    """
    if isinstance(img, torch.Tensor):
        return img
    return torchvision.transforms.ToTensor()(img)


def get_imagenet12_ds(transform: Optional[Callable] = None, batch_size: int = 32, shuffle: bool = True,
                      num_workers: int = 16, imagenet_path: str = '/home/ran/datasets/imagenet12',
                      split: str = 'val'):
    """Build a DataLoader over the ImageNet-2012 (ILSVRC2012) dataset.

    Args:
        transform: Per-image transform applied by the dataset. When ``None``,
            non-tensor images are only converted with ``ToTensor``.
            NOTE(review): that default does not resize, so with
            ``batch_size > 1`` the default collate will likely fail on
            variable-size images; pass a resizing transform (e.g. a torchvision
            model's ``weights.transforms()``) for batched use.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader reshuffles the samples on each pass.
        num_workers: Number of DataLoader worker processes.
        imagenet_path: Dataset root handed to ``torchvision.datasets.ImageNet``.
            The default is the original author's local path; override it on
            other machines instead of editing this function.
        split: Dataset split to load (``'val'`` by default).

    Returns:
        A ``torch.utils.data.DataLoader`` over the requested split.
    """
    if transform is None:
        transform = _to_tensor_if_needed

    imagenet_data = torchvision.datasets.ImageNet(imagenet_path, split=split, transform=transform)
    return torch.utils.data.DataLoader(imagenet_data,
                                       batch_size=batch_size,
                                       shuffle=shuffle,
                                       num_workers=num_workers)

def eval_torch_model(model, data, use_cuda: bool = True, transform: Optional[Callable] = None,
                     n_subset: Optional[int] = None):
    """Compute top-1 and top-5 accuracy (in percent) of a classifier over ``data``.

    Args:
        model: Module mapping a batch of inputs to per-class scores of shape
            ``(batch, num_classes)``. It is moved to the selected device and put
            in eval mode, both in place; the previous training mode is not restored.
        data: Iterable of ``(images, labels)`` batches (e.g. a DataLoader), where
            ``labels`` holds integer class indices.
        use_cuda: Use ``cuda:0`` when CUDA is available; otherwise run on the CPU.
        transform: Optional callable applied to each whole batch of images
            before it is moved to the device.
        n_subset: If given, evaluate only the first ``n_subset`` batches
            (``0`` evaluates none). ``None`` evaluates every batch.

    Returns:
        ``(accuracy_top1, accuracy_top5)`` as percentages in ``[0, 100]``.
        If the model outputs fewer than 5 classes, "top-5" uses all classes.

    Raises:
        ValueError: If no samples were evaluated (empty ``data`` or
            ``n_subset=0``), since accuracy is undefined in that case.
    """
    device = torch.device("cuda:0" if use_cuda and torch.cuda.is_available() else "cpu")

    model.to(device)
    model.eval()

    correct_top1 = 0
    correct_top5 = 0
    total = 0
    with torch.no_grad():
        for i, (images, labels) in enumerate(data):
            # `is not None` so that n_subset=0 means "no batches" rather than "all batches".
            if n_subset is not None and i >= n_subset:
                break

            if transform is not None:
                images = transform(images)

            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)

            # topk raises if k exceeds the number of classes, so clamp it.
            k = min(5, outputs.size(1))
            predicted_topk = outputs.topk(k, dim=1).indices
            predicted_top1 = outputs.argmax(dim=1)

            total += labels.size(0)
            correct_top1 += (predicted_top1 == labels).sum().item()
            # A sample is top-k correct when its label appears anywhere in its row of top-k indices.
            correct_top5 += (predicted_topk == labels.unsqueeze(1)).any(dim=1).sum().item()

    if total == 0:
        raise ValueError('no samples were evaluated (empty data or n_subset=0); accuracy is undefined')

    accuracy_top1 = 100 * correct_top1 / total
    accuracy_top5 = 100 * correct_top5 / total

    return accuracy_top1, accuracy_top5

# Pretrained ResNet-18 checkpoint plus the preprocessing pipeline torchvision
# bundles with it, so evaluation inputs are prepared the way these weights expect.
weights = models.ResNet18_Weights.IMAGENET1K_V1
preprocess = weights.transforms()

# Validation-split loader. Despite its name, `ds` is a DataLoader yielding
# (images, labels) batches, not a Dataset. The loader settings are spelled out
# (they equal get_imagenet12_ds's defaults) so they are visible in this cell.
ds = get_imagenet12_ds(
    transform=preprocess,
    batch_size=32,
    shuffle=True,
    num_workers=16,
)

In [2]:
# Score the pretrained ResNet-18 on a seeded random subset of the validation
# split; set N_EVAL_BATCHES = None to evaluate the full split.
N_EVAL_BATCHES = 1
SEED = 0

model = models.resnet18(weights=weights)

# `ds` was built with shuffle=True, so seed torch's global RNG right before
# iterating it; otherwise every run scores a different random batch and the
# numbers printed below cannot be reproduced.
torch.manual_seed(SEED)
accuracy_top1, accuracy_top5 = eval_torch_model(model, ds, n_subset=N_EVAL_BATCHES)

eval_scope = 'the full' if N_EVAL_BATCHES is None else f'{N_EVAL_BATCHES} batch(es) of the'
print(f'Top-1 Accuracy on {eval_scope} ImageNet validation set: {accuracy_top1:.2f}%')
print(f'Top-5 Accuracy on {eval_scope} ImageNet validation set: {accuracy_top5:.2f}%')

Top-1 Accuracy on the ImageNet validation set: 75.00%
Top-5 Accuracy on the ImageNet validation set: 96.88%


In [3]:
from typing import Optional
import torch
import torchvision
import torchvision.models as models

# Print the builder name of every model registered with this torchvision install, one per line.
registered_models = models.list_models()
for name in registered_models:
    print(name)

alexnet
convnext_base
convnext_large
convnext_small
convnext_tiny
deeplabv3_mobilenet_v3_large
deeplabv3_resnet101
deeplabv3_resnet50
densenet121
densenet161
densenet169
densenet201
efficientnet_b0
efficientnet_b1
efficientnet_b2
efficientnet_b3
efficientnet_b4
efficientnet_b5
efficientnet_b6
efficientnet_b7
efficientnet_v2_l
efficientnet_v2_m
efficientnet_v2_s
fasterrcnn_mobilenet_v3_large_320_fpn
fasterrcnn_mobilenet_v3_large_fpn
fasterrcnn_resnet50_fpn
fasterrcnn_resnet50_fpn_v2
fcn_resnet101
fcn_resnet50
fcos_resnet50_fpn
googlenet
inception_v3
keypointrcnn_resnet50_fpn
lraspp_mobilenet_v3_large
maskrcnn_resnet50_fpn
maskrcnn_resnet50_fpn_v2
maxvit_t
mc3_18
mnasnet0_5
mnasnet0_75
mnasnet1_0
mnasnet1_3
mobilenet_v2
mobilenet_v3_large
mobilenet_v3_small
mvit_v1_b
mvit_v2_s
quantized_googlenet
quantized_inception_v3
quantized_mobilenet_v2
quantized_mobilenet_v3_large
quantized_resnet18
quantized_resnet50
quantized_resnext101_32x8d
quantized_resnext101_64x4d
quantized_shufflenet_v2_x0_