In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.optim import lr_scheduler
from torchvision import datasets, transforms, utils, models


In [0]:
def initialize_model(model_name, num_classes):
    # Initialize these variables which will be set in this if statement. Each of these
    #   variables is model specific.
    model_ft = None
    input_size = 0

    if model_name == "alexnet":
        models_list = []
        model_ft = models.alexnet()
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes)
        models_list.append(model_ft)

    elif model_name == "resnet":

        models_list = []
        for model in (torchvision.models.resnet18, torchvision.models.resnet34, torchvision.models.resnet50, torchvision.models.resnet101, models.resnet152):
          model_ft = model()
          num_ftrs = model_ft.fc.in_features
          model_ft.fc = nn.Linear(num_ftrs, num_classes)
          models_list.append(model_ft)


    elif model_name == "vgg":
        models_list = []
        for model in (models.vgg11, models.vgg11_bn, models.vgg13_bn, models.vgg13, models.vgg16_bn, models.vgg16, models.vgg19_bn):
          model_ft = model()
          num_ftrs = model_ft.classifier[-1].in_features
          model_ft.classifier[-1] = nn.Linear(num_ftrs,num_classes)
          models_list.append(model_ft)
        
    elif model_name == "squeezenet":

        models_list = []
        for model in (torchvision.models.squeezenet1_0, torchvision.models.squeezenet1_1):
          model_ft = model()
          model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1,1), stride=(1,1))
          model_ft.num_classes = num_classes
          models_list.append(model_ft)

    elif model_name == "densenet":

        models_list = []
        for model in (torchvision.models.densenet121, torchvision.models.densenet169, torchvision.models.densenet161, models.densenet201):
          model_ft = model()
          num_ftrs = model_ft.classifier.in_features
          model_ft.classifier = nn.Linear(num_ftrs, num_classes)
          models_list.append(model_ft)

    elif model_name == "inception":
        models_list = []
        model_ft = models.inception_v3()
        num_ftrs = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs,num_classes)
        models_list.append(model_ft)

    elif model_name == "resnext":
      
      models_list = []
      for model in (models.resnext50_32x4d, models.resnext101_32x8d):
          model_ft = model()
          num_ftrs = model_ft.fc.in_features
          model_ft.fc = nn.Linear(num_ftrs,num_classes)
          models_list.append(model_ft)

    elif model_name == "wide_resnet":
      
      models_list = []
      for model in (torchvision.models.wide_resnet50_2, torchvision.models.wide_resnet101_2):
          model_ft = model()
          num_ftrs = model_ft.fc.in_features
          model_ft.fc = nn.Linear(num_ftrs,num_classes)
          models_list.append(model_ft)

    return models_list

In [0]:
model_list = initialize_model("wide_resnet", 9)


In [11]:
model_list[0].state_dict().keys()

odict_keys(['conv1.weight', 'bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var', 'bn1.num_batches_tracked', 'layer1.0.conv1.weight', 'layer1.0.bn1.weight', 'layer1.0.bn1.bias', 'layer1.0.bn1.running_mean', 'layer1.0.bn1.running_var', 'layer1.0.bn1.num_batches_tracked', 'layer1.0.conv2.weight', 'layer1.0.bn2.weight', 'layer1.0.bn2.bias', 'layer1.0.bn2.running_mean', 'layer1.0.bn2.running_var', 'layer1.0.bn2.num_batches_tracked', 'layer1.0.conv3.weight', 'layer1.0.bn3.weight', 'layer1.0.bn3.bias', 'layer1.0.bn3.running_mean', 'layer1.0.bn3.running_var', 'layer1.0.bn3.num_batches_tracked', 'layer1.0.downsample.0.weight', 'layer1.0.downsample.1.weight', 'layer1.0.downsample.1.bias', 'layer1.0.downsample.1.running_mean', 'layer1.0.downsample.1.running_var', 'layer1.0.downsample.1.num_batches_tracked', 'layer1.1.conv1.weight', 'layer1.1.bn1.weight', 'layer1.1.bn1.bias', 'layer1.1.bn1.running_mean', 'layer1.1.bn1.running_var', 'layer1.1.bn1.num_batches_tracked', 'layer1.1.conv2.we

In [14]:
for i, n in enumerate([3,2,6]):
  print(i)

0
1
2


In [18]:
for i, model in enumerate(model_list):
  print(i, model)

0 ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), str

In [12]:
print(model_list[0].__class__.__name__+str(3))

ResNet3


In [0]:
models.append(torchvision.models.alexnet(pretrained=False, progress=True))
models
torchvision.models.vgg11(pretrained=False, progress=True)

torchvision.models.vgg11_bn(pretrained=False, progress=True)

torchvision.models.vgg13(pretrained=False, progress=True)

torchvision.models.vgg13_bn(pretrained=False, progress=True)

torchvision.models.vgg16(pretrained=False, progress=True)

torchvision.models.vgg16_bn(pretrained=False, progress=True)

torchvision.models.vgg19(pretrained=False, progress=True)

torchvision.models.vgg19_bn(pretrained=False, progress=True)

torchvision.models.resnet18(pretrained=False, progress=True)

torchvision.models.resnet34(pretrained=False, progress=True)

torchvision.models.resnet50(pretrained=False, progress=True)

torchvision.models.resnet101(pretrained=False, progress=True)

torchvision.models.resnet152(pretrained=False, progress=True)

torchvision.models.squeezenet1_0(pretrained=False, progress=True)

torchvision.models.squeezenet1_1(pretrained=False, progress=True)

torchvision.models.densenet121(pretrained=False, progress=True)

torchvision.models.densenet169(pretrained=False, progress=True)

torchvision.models.densenet161(pretrained=False, progress=True)

torchvision.models.densenet201(pretrained=False, progress=True)

torchvision.models.inception_v3(pretrained=False, progress=True)

torchvision.models.googlenet(pretrained=False, progress=True)

torchvision.models.shufflenet_v2_x0_5(pretrained=False, progress=True)

torchvision.models.shufflenet_v2_x1_0(pretrained=False, progress=True)

torchvision.models.shufflenet_v2_x1_5(pretrained=False, progress=True)

torchvision.models.mobilenet_v2(pretrained=False, progress=True)

torchvision.models.resnext50_32x4d(pretrained=False, progress=True)

torchvision.models.resnext101_32x8d(pretrained=False, progress=True)

torchvision.models.wide_resnet50_2(pretrained=False, progress=True)

torchvision.models.wide_resnet101_2(pretrained=False, progress=True)



In [33]:
        models_list = []
        num_classes=9
        for model in (models.vgg11_bn, models.vgg13_bn,models.vgg16_bn,models.vgg19_bn):
          model_ft = model(pretrained=False)
          num_ftrs = model_ft.classifier[-1].in_features
          model_ft.classifier[-1] = nn.Linear(num_ftrs,num_classes)
          models_list.append(model_ft)

KeyboardInterrupt: ignored

In [24]:
print(models_list[1])

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace=True)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (14): Conv2d(128, 256