In [48]:
from pathlib import Path

import torch
from torch import nn
import torchvision.models as models

In [49]:
def get_classifier(in_features: int, out_features: int = 594):
    """Build the replacement classification head used for every backbone.

    Args:
        in_features: Width of the feature vector produced by the backbone
            (read from the layer being replaced).
        out_features: Number of output logits. Defaults to 594, the output
            width every model in this notebook is checked against.

    Returns:
        A freshly initialised ``nn.Linear(in_features, out_features)``.
    """
    return nn.Linear(in_features=in_features, out_features=out_features)

In [50]:
def handle_densenet_model(model: nn.Module, weights: str | None = None, save_to: str | None = None):
    """Install the notebook's classification head on a DenseNet; optionally load/save weights.

    Args:
        model: A model exposing ``model.classifier`` with an ``in_features``
            attribute (e.g. a torchvision DenseNet). Modified in place.
        weights: Optional path to a state-dict checkpoint matching the
            modified architecture. Ignored when falsy (``None`` or ``""``).
        save_to: Optional path; when given, the resulting state dict is
            written there. Missing parent directories are created.

    Returns:
        The same ``model`` object, with the new head installed.
    """
    model.classifier = get_classifier(in_features=model.classifier.in_features)

    if weights:
        # map_location="cpu" lets checkpoints saved from a GPU load on CPU-only
        # hosts; load_state_dict then copies the values into the model's tensors.
        # NOTE(review): torch.load unpickles its input -- only load trusted files
        # (consider weights_only=True on torch >= 1.13).
        model.load_state_dict(torch.load(weights, map_location="cpu"))

    if save_to is not None:
        # torch.save does not create directories; ensure e.g. ./weights/ exists.
        Path(save_to).parent.mkdir(parents=True, exist_ok=True)
        torch.save(model.state_dict(), save_to)

    return model

In [80]:
def handle_resnet_model(model: nn.Module, weights: str | None = None, save_to: str | None = None):
    """Install the notebook's classification head on a ResNet; optionally load/save weights.

    Args:
        model: A model exposing ``model.fc`` with an ``in_features`` attribute
            (e.g. a torchvision ResNet). Modified in place.
        weights: Optional path to a state-dict checkpoint matching the
            modified architecture. Ignored when falsy (``None`` or ``""``).
        save_to: Optional path; when given, the resulting state dict is
            written there. Missing parent directories are created.

    Returns:
        The same ``model`` object, with the new head installed.
    """
    model.fc = get_classifier(in_features=model.fc.in_features)

    if weights:
        # map_location="cpu" lets checkpoints saved from a GPU load on CPU-only
        # hosts; load_state_dict then copies the values into the model's tensors.
        # NOTE(review): torch.load unpickles its input -- only load trusted files
        # (consider weights_only=True on torch >= 1.13).
        model.load_state_dict(torch.load(weights, map_location="cpu"))

    if save_to is not None:
        # torch.save does not create directories; ensure e.g. ./weights/ exists.
        Path(save_to).parent.mkdir(parents=True, exist_ok=True)
        torch.save(model.state_dict(), save_to)

    return model

In [51]:
def get_dense121(weights: str | None = None, save_to: str | None = None):
    """Build DenseNet-121 with the notebook's classification head.

    Args:
        weights: Optional checkpoint path for the full modified model. When
            given, the torchvision pretrained download is skipped, because the
            strict ``load_state_dict`` overwrites every parameter anyway.
        save_to: Optional path to write the resulting state dict.

    Returns:
        The configured model.
    """
    # Only fetch the pretrained backbone weights when they will actually be kept.
    backbone_weights = None if weights else 'DEFAULT'
    model = models.densenet121(weights=backbone_weights)
    return handle_densenet_model(model, weights, save_to)
    

In [52]:
def get_dense161(weights: str | None = None, save_to: str | None = None):
    """Build DenseNet-161 with the notebook's classification head.

    Args:
        weights: Optional checkpoint path for the full modified model. When
            given, the torchvision pretrained download is skipped, because the
            strict ``load_state_dict`` overwrites every parameter anyway.
        save_to: Optional path to write the resulting state dict.

    Returns:
        The configured model.
    """
    # Only fetch the pretrained backbone weights when they will actually be kept.
    backbone_weights = None if weights else 'DEFAULT'
    model = models.densenet161(weights=backbone_weights)
    return handle_densenet_model(model, weights, save_to)

In [53]:
def get_dense169(weights: str | None = None, save_to: str | None = None):
    """Build DenseNet-169 with the notebook's classification head.

    Args:
        weights: Optional checkpoint path for the full modified model. When
            given, the torchvision pretrained download is skipped, because the
            strict ``load_state_dict`` overwrites every parameter anyway.
        save_to: Optional path to write the resulting state dict.

    Returns:
        The configured model.
    """
    # Only fetch the pretrained backbone weights when they will actually be kept.
    backbone_weights = None if weights else 'DEFAULT'
    model = models.densenet169(weights=backbone_weights)
    return handle_densenet_model(model, weights, save_to)

In [54]:
def get_dense201(weights: str | None = None, save_to: str | None = None):
    """Build DenseNet-201 with the notebook's classification head.

    Args:
        weights: Optional checkpoint path for the full modified model. When
            given, the torchvision pretrained download is skipped, because the
            strict ``load_state_dict`` overwrites every parameter anyway.
        save_to: Optional path to write the resulting state dict.

    Returns:
        The configured model.
    """
    # Only fetch the pretrained backbone weights when they will actually be kept.
    backbone_weights = None if weights else 'DEFAULT'
    model = models.densenet201(weights=backbone_weights)
    return handle_densenet_model(model, weights, save_to)

In [81]:
def get_resnet50(weights: str | None = None, save_to: str | None = None):
    """Build ResNet-50 with the notebook's classification head.

    Args:
        weights: Optional checkpoint path for the full modified model. When
            given, the torchvision pretrained download is skipped, because the
            strict ``load_state_dict`` overwrites every parameter anyway.
        save_to: Optional path to write the resulting state dict.

    Returns:
        The configured model.
    """
    # Only fetch the pretrained backbone weights when they will actually be kept.
    backbone_weights = None if weights else 'DEFAULT'
    model = models.resnet50(weights=backbone_weights)
    return handle_resnet_model(model, weights, save_to)

In [55]:
# All-zero dummy batch laid out as (N=1, C=3, H=224, W=224), the NCHW layout
# torchvision's image models take; used only to check output shapes.
img = torch.zeros((1, 3, 224, 224))

In [56]:
# Fresh DenseNet-121 (pretrained backbone + new head), checkpointed to disk.
dense121 = get_dense121(weights=None, save_to="./weights/densenet121.pt")

In [68]:
# Fresh DenseNet-161 (pretrained backbone + new head), checkpointed to disk.
dense161 = get_dense161(weights=None, save_to="./weights/densenet161.pt")

In [60]:
# Fresh DenseNet-169 (pretrained backbone + new head), checkpointed to disk.
dense169 = get_dense169(weights=None, save_to="./weights/densenet169.pt")

In [62]:
# Fresh DenseNet-201 (pretrained backbone + new head), checkpointed to disk.
dense201 = get_dense201(weights=None, save_to="./weights/densenet201.pt")

In [63]:
# Shape sanity check. eval() + no_grad() stop the forward pass from updating
# BatchNorm running statistics with the dummy input or building an autograd graph.
dense121.eval()
with torch.no_grad():
    res = dense121(img)

assert res.shape == (1, 594), res.shape
print(res.shape)

torch.Size([1, 594])


In [64]:
# Shape sanity check. eval() + no_grad() stop the forward pass from updating
# BatchNorm running statistics with the dummy input or building an autograd graph.
dense161.eval()
with torch.no_grad():
    res = dense161(img)

assert res.shape == (1, 594), res.shape
print(res.shape)

torch.Size([1, 594])


In [65]:
# Shape sanity check. eval() + no_grad() stop the forward pass from updating
# BatchNorm running statistics with the dummy input or building an autograd graph.
dense169.eval()
with torch.no_grad():
    res = dense169(img)

assert res.shape == (1, 594), res.shape
print(res.shape)

torch.Size([1, 594])


In [66]:
# Shape sanity check. eval() + no_grad() stop the forward pass from updating
# BatchNorm running statistics with the dummy input or building an autograd graph.
dense201.eval()
with torch.no_grad():
    res = dense201(img)

assert res.shape == (1, 594), res.shape
print(res.shape)

torch.Size([1, 594])


In [72]:
from torchsummary import summary

In [75]:
# summary() prints the table and also returns an object whose repr Jupyter
# rendered a second time; the trailing ';' suppresses that duplicate.
summary(dense161, (3, 224, 224));

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 2208, 7, 7]          --
|    └─Conv2d: 2-1                       [-1, 96, 112, 112]        14,112
|    └─BatchNorm2d: 2-2                  [-1, 96, 112, 112]        192
|    └─ReLU: 2-3                         [-1, 96, 112, 112]        --
|    └─MaxPool2d: 2-4                    [-1, 96, 56, 56]          --
|    └─_DenseBlock: 2-5                  [-1, 384, 56, 56]         --
|    |    └─_DenseLayer: 3-1             [-1, 48, 56, 56]          101,952
|    |    └─_DenseLayer: 3-2             [-1, 48, 56, 56]          111,264
|    |    └─_DenseLayer: 3-3             [-1, 48, 56, 56]          120,576
|    |    └─_DenseLayer: 3-4             [-1, 48, 56, 56]          129,888
|    |    └─_DenseLayer: 3-5             [-1, 48, 56, 56]          139,200
|    |    └─_DenseLayer: 3-6             [-1, 48, 56, 56]          148,512
|    └─_Transition: 2-6                  [-1, 192,

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 2208, 7, 7]          --
|    └─Conv2d: 2-1                       [-1, 96, 112, 112]        14,112
|    └─BatchNorm2d: 2-2                  [-1, 96, 112, 112]        192
|    └─ReLU: 2-3                         [-1, 96, 112, 112]        --
|    └─MaxPool2d: 2-4                    [-1, 96, 56, 56]          --
|    └─_DenseBlock: 2-5                  [-1, 384, 56, 56]         --
|    |    └─_DenseLayer: 3-1             [-1, 48, 56, 56]          101,952
|    |    └─_DenseLayer: 3-2             [-1, 48, 56, 56]          111,264
|    |    └─_DenseLayer: 3-3             [-1, 48, 56, 56]          120,576
|    |    └─_DenseLayer: 3-4             [-1, 48, 56, 56]          129,888
|    |    └─_DenseLayer: 3-5             [-1, 48, 56, 56]          139,200
|    |    └─_DenseLayer: 3-6             [-1, 48, 56, 56]          148,512
|    └─_Transition: 2-6                  [-1, 192,

In [82]:
# Fresh ResNet-50 (pretrained backbone + new head), checkpointed to disk.
resnet50 = get_resnet50(weights=None, save_to="./weights/resnet50.pt")

In [83]:
# summary() prints the table and also returns an object whose repr Jupyter
# rendered a second time; the trailing ';' suppresses that duplicate.
summary(resnet50, (3, 224, 224));

Layer (type:depth-idx)                   Output Shape              Param #
├─Conv2d: 1-1                            [-1, 64, 112, 112]        9,408
├─BatchNorm2d: 1-2                       [-1, 64, 112, 112]        128
├─ReLU: 1-3                              [-1, 64, 112, 112]        --
├─MaxPool2d: 1-4                         [-1, 64, 56, 56]          --
├─Sequential: 1-5                        [-1, 256, 56, 56]         --
|    └─Bottleneck: 2-1                   [-1, 256, 56, 56]         --
|    |    └─Conv2d: 3-1                  [-1, 64, 56, 56]          4,096
|    |    └─BatchNorm2d: 3-2             [-1, 64, 56, 56]          128
|    |    └─ReLU: 3-3                    [-1, 64, 56, 56]          --
|    |    └─Conv2d: 3-4                  [-1, 64, 56, 56]          36,864
|    |    └─BatchNorm2d: 3-5             [-1, 64, 56, 56]          128
|    |    └─ReLU: 3-6                    [-1, 64, 56, 56]          --
|    |    └─Conv2d: 3-7                  [-1, 256, 56, 56]         16,38

Layer (type:depth-idx)                   Output Shape              Param #
├─Conv2d: 1-1                            [-1, 64, 112, 112]        9,408
├─BatchNorm2d: 1-2                       [-1, 64, 112, 112]        128
├─ReLU: 1-3                              [-1, 64, 112, 112]        --
├─MaxPool2d: 1-4                         [-1, 64, 56, 56]          --
├─Sequential: 1-5                        [-1, 256, 56, 56]         --
|    └─Bottleneck: 2-1                   [-1, 256, 56, 56]         --
|    |    └─Conv2d: 3-1                  [-1, 64, 56, 56]          4,096
|    |    └─BatchNorm2d: 3-2             [-1, 64, 56, 56]          128
|    |    └─ReLU: 3-3                    [-1, 64, 56, 56]          --
|    |    └─Conv2d: 3-4                  [-1, 64, 56, 56]          36,864
|    |    └─BatchNorm2d: 3-5             [-1, 64, 56, 56]          128
|    |    └─ReLU: 3-6                    [-1, 64, 56, 56]          --
|    |    └─Conv2d: 3-7                  [-1, 256, 56, 56]         16,38