In [13]:
import torchvision.models as models
import inspect

def get_classification_models():
    """Get all image classification models from torchvision.models"""
    
    # Get all callable objects from torchvision.models
    all_models = {name: obj for name, obj in inspect.getmembers(models, callable)}
    
    # Filter out non-classification models (detection, segmentation, etc.)
    # Classification models take 'weights' parameter in newer PyTorch versions
    classification_models = {
        name: obj for name, obj in all_models.items() 
        if 'weights' in inspect.signature(obj).parameters
    }
    
    return classification_models

def get_model_transforms(model_name: str):
    """Get the inference transforms for a given model name"""
    # Get the weights enum class (e.g., ResNet50_Weights)
    weights_enum = getattr(models, f"{model_name}_Weights")
    
    # Get the default weights
    weights = weights_enum.DEFAULT
    
    # Get the transforms
    transforms = weights.transforms()
    
    return transforms


if __name__ == "__main__":
    models_dict = get_classification_models()
    print("\nAvailable classification models:")
    for name in sorted(models_dict.keys()):
        print(f"- {name}")


Available classification models:
- alexnet
- convnext_base
- convnext_large
- convnext_small
- convnext_tiny
- densenet121
- densenet161
- densenet169
- densenet201
- efficientnet_b0
- efficientnet_b1
- efficientnet_b2
- efficientnet_b3
- efficientnet_b4
- efficientnet_b5
- efficientnet_b6
- efficientnet_b7
- efficientnet_v2_l
- efficientnet_v2_m
- efficientnet_v2_s
- googlenet
- inception_v3
- maxvit_t
- mnasnet0_5
- mnasnet0_75
- mnasnet1_0
- mnasnet1_3
- mobilenet_v2
- mobilenet_v3_large
- mobilenet_v3_small
- regnet_x_16gf
- regnet_x_1_6gf
- regnet_x_32gf
- regnet_x_3_2gf
- regnet_x_400mf
- regnet_x_800mf
- regnet_x_8gf
- regnet_y_128gf
- regnet_y_16gf
- regnet_y_1_6gf
- regnet_y_32gf
- regnet_y_3_2gf
- regnet_y_400mf
- regnet_y_800mf
- regnet_y_8gf
- resnet101
- resnet152
- resnet18
- resnet34
- resnet50
- resnext101_32x8d
- resnext101_64x4d
- resnext50_32x4d
- shufflenet_v2_x0_5
- shufflenet_v2_x1_0
- shufflenet_v2_x1_5
- shufflenet_v2_x2_0
- squeezenet1_0
- squeezenet1_1
- swin

In [None]:
getattr(models, f"{model_name}_Weights")

In [15]:
if __name__ == "__main__":
    # Example usage with a few models
    model_names = ["ResNet50", "VGG16", "DenseNet121"]
    
    for name in model_names:
        print(f"\nTransforms for {name}:")
        transforms = get_model_transforms(name)
        for t in transforms.transforms:
            print(f"- {t.__class__.__name__}: {t}")


Transforms for ResNet50:


AttributeError: 'ImageClassification' object has no attribute 'transforms'