In [48]:
import torch
from torch import nn
import torchvision.models as models

In [49]:
def get_classifier(in_features: int, out_features: int = 594) -> nn.Linear:
    """Create the linear classification head attached to every backbone here.

    Args:
        in_features: Input width of the head; callers pass the ``in_features``
            of the layer being replaced, so the new head fits the backbone.
        out_features: Output width of the head. Defaults to 594, the value
            used for every model in this notebook (presumably the number of
            target classes -- confirm against the dataset).

    Returns:
        A freshly initialised ``nn.Linear(in_features, out_features)``.
    """
    return nn.Linear(in_features=in_features, out_features=out_features)

In [50]:
def handle_densenet_model(model: nn.Module, weights: str | None = None, save_to: str | None = None) -> nn.Module:
    """Replace a DenseNet's ``classifier`` with the project head, then optionally load/save.

    Args:
        model: A model exposing a ``classifier`` attribute with ``in_features``
            (e.g. a torchvision DenseNet).
        weights: Optional path to a ``state_dict`` checkpoint of the *modified*
            model (new head included). Falsy values skip loading.
        save_to: Optional path the resulting ``state_dict`` is written to.
            Missing parent directories are created first.

    Returns:
        The same ``model`` object, modified in place.
    """
    from pathlib import Path

    # Swap the head before loading so its shape matches checkpoints that were
    # saved (below) from an already-modified model.
    model.classifier = get_classifier(in_features=model.classifier.in_features)

    if weights:
        # map_location="cpu" lets checkpoints written from a GPU load on CPU-only
        # hosts; load_state_dict then copies the tensors into the model's params.
        # NOTE(review): torch.load unpickles the file -- only load trusted
        # checkpoints (consider weights_only=True on newer torch versions).
        state_dict = torch.load(weights, map_location="cpu")
        model.load_state_dict(state_dict)

    if save_to is not None:
        # torch.save does not create directories; make sure e.g. ./weights exists.
        Path(save_to).parent.mkdir(parents=True, exist_ok=True)
        torch.save(model.state_dict(), save_to)

    return model

In [80]:
def handle_resnet_model(model: nn.Module, weights: str | None = None, save_to: str | None = None) -> nn.Module:
    """Replace a ResNet's ``fc`` layer with the project head, then optionally load/save.

    Args:
        model: A model exposing an ``fc`` attribute with ``in_features``
            (e.g. a torchvision ResNet).
        weights: Optional path to a ``state_dict`` checkpoint of the *modified*
            model (new head included). Falsy values skip loading.
        save_to: Optional path the resulting ``state_dict`` is written to.
            Missing parent directories are created first.

    Returns:
        The same ``model`` object, modified in place.
    """
    from pathlib import Path

    # Swap the head before loading so its shape matches checkpoints that were
    # saved (below) from an already-modified model.
    model.fc = get_classifier(in_features=model.fc.in_features)

    if weights:
        # map_location="cpu" lets checkpoints written from a GPU load on CPU-only
        # hosts; load_state_dict then copies the tensors into the model's params.
        # NOTE(review): torch.load unpickles the file -- only load trusted
        # checkpoints (consider weights_only=True on newer torch versions).
        state_dict = torch.load(weights, map_location="cpu")
        model.load_state_dict(state_dict)

    if save_to is not None:
        # torch.save does not create directories; make sure e.g. ./weights exists.
        Path(save_to).parent.mkdir(parents=True, exist_ok=True)
        torch.save(model.state_dict(), save_to)

    return model

In [51]:
def get_dense121(weights: str | None = None, save_to: str | None = None):
    """Build a DenseNet-121 with the project classifier head.

    Args:
        weights: Optional path to a checkpoint of the modified model. When
            given, torchvision's pretrained weights are not fetched, because a
            strict ``load_state_dict`` overwrites every tensor anyway.
        save_to: Optional path to write the resulting ``state_dict`` to.

    Returns:
        The configured model.
    """
    # Only start from torchvision's DEFAULT weights when no checkpoint will replace them.
    backbone_weights = None if weights else "DEFAULT"
    model = models.densenet121(weights=backbone_weights)
    return handle_densenet_model(model, weights, save_to)
    

In [52]:
def get_dense161(weights: str | None = None, save_to: str | None = None):
    """Build a DenseNet-161 with the project classifier head.

    Args:
        weights: Optional path to a checkpoint of the modified model. When
            given, torchvision's pretrained weights are not fetched, because a
            strict ``load_state_dict`` overwrites every tensor anyway.
        save_to: Optional path to write the resulting ``state_dict`` to.

    Returns:
        The configured model.
    """
    # Only start from torchvision's DEFAULT weights when no checkpoint will replace them.
    backbone_weights = None if weights else "DEFAULT"
    model = models.densenet161(weights=backbone_weights)
    return handle_densenet_model(model, weights, save_to)

In [53]:
def get_dense169(weights: str | None = None, save_to: str | None = None):
    """Build a DenseNet-169 with the project classifier head.

    Args:
        weights: Optional path to a checkpoint of the modified model. When
            given, torchvision's pretrained weights are not fetched, because a
            strict ``load_state_dict`` overwrites every tensor anyway.
        save_to: Optional path to write the resulting ``state_dict`` to.

    Returns:
        The configured model.
    """
    # Only start from torchvision's DEFAULT weights when no checkpoint will replace them.
    backbone_weights = None if weights else "DEFAULT"
    model = models.densenet169(weights=backbone_weights)
    return handle_densenet_model(model, weights, save_to)

In [54]:
def get_dense201(weights: str | None = None, save_to: str | None = None):
    """Build a DenseNet-201 with the project classifier head.

    Args:
        weights: Optional path to a checkpoint of the modified model. When
            given, torchvision's pretrained weights are not fetched, because a
            strict ``load_state_dict`` overwrites every tensor anyway.
        save_to: Optional path to write the resulting ``state_dict`` to.

    Returns:
        The configured model.
    """
    # Only start from torchvision's DEFAULT weights when no checkpoint will replace them.
    backbone_weights = None if weights else "DEFAULT"
    model = models.densenet201(weights=backbone_weights)
    return handle_densenet_model(model, weights, save_to)

In [81]:
def get_resnet50(weights: str | None = None, save_to: str | None = None):
    """Build a ResNet-50 with the project classifier head.

    Args:
        weights: Optional path to a checkpoint of the modified model. When
            given, torchvision's pretrained weights are not fetched, because a
            strict ``load_state_dict`` overwrites every tensor anyway.
        save_to: Optional path to write the resulting ``state_dict`` to.

    Returns:
        The configured model.
    """
    # Only start from torchvision's DEFAULT weights when no checkpoint will replace them.
    backbone_weights = None if weights else "DEFAULT"
    model = models.resnet50(weights=backbone_weights)
    return handle_resnet_model(model, weights, save_to)

In [55]:
# All-zeros dummy batch of shape (1, 3, 224, 224), used below to smoke-test
# each model's forward pass.
img = torch.zeros(size=(1, 3, 224, 224))

In [56]:
# DenseNet-121 with the replaced head; no checkpoint is loaded, and the
# resulting state_dict is written to ./weights/densenet121.pt.
dense121 = get_dense121(weights=None, save_to="./weights/densenet121.pt")

In [68]:
# DenseNet-161 with the replaced head; no checkpoint is loaded, and the
# resulting state_dict is written to ./weights/densenet161.pt.
dense161 = get_dense161(weights=None, save_to="./weights/densenet161.pt")

In [60]:
# DenseNet-169 with the replaced head; no checkpoint is loaded, and the
# resulting state_dict is written to ./weights/densenet169.pt.
dense169 = get_dense169(weights=None, save_to="./weights/densenet169.pt")

In [62]:
# DenseNet-201 with the replaced head; no checkpoint is loaded, and the
# resulting state_dict is written to ./weights/densenet201.pt.
dense201 = get_dense201(weights=None, save_to="./weights/densenet201.pt")

In [63]:
# Smoke-test the forward pass on the dummy batch. eval() keeps normalization
# layers (e.g. BatchNorm) from updating their running statistics with this
# input, and no_grad() skips building an autograd graph.
# Note: the model is left in eval mode afterwards.
dense121.eval()
with torch.no_grad():
    res = dense121(img)

print(res.shape)

In [64]:
# Smoke-test the forward pass on the dummy batch. eval() keeps normalization
# layers (e.g. BatchNorm) from updating their running statistics with this
# input, and no_grad() skips building an autograd graph.
# Note: the model is left in eval mode afterwards.
dense161.eval()
with torch.no_grad():
    res = dense161(img)

print(res.shape)

In [65]:
# Smoke-test the forward pass on the dummy batch. eval() keeps normalization
# layers (e.g. BatchNorm) from updating their running statistics with this
# input, and no_grad() skips building an autograd graph.
# Note: the model is left in eval mode afterwards.
dense169.eval()
with torch.no_grad():
    res = dense169(img)

print(res.shape)

In [66]:
# Smoke-test the forward pass on the dummy batch. eval() keeps normalization
# layers (e.g. BatchNorm) from updating their running statistics with this
# input, and no_grad() skips building an autograd graph.
# Note: the model is left in eval mode afterwards.
dense201.eval()
with torch.no_grad():
    res = dense201(img)

print(res.shape)

In [72]:
from torchsummary import summary

In [75]:
# Pass device="cpu" explicitly: torchsummary defaults to device="cuda" and,
# when CUDA is available, feeds CUDA tensors -- which fails because dense161
# was never moved off the CPU in this notebook.
summary(dense161, (3, 224, 224), device="cpu")

In [82]:
# ResNet-50 with the replaced head; no checkpoint is loaded, and the
# resulting state_dict is written to ./weights/resnet50.pt.
resnet50 = get_resnet50(weights=None, save_to="./weights/resnet50.pt")

In [83]:
# Pass device="cpu" explicitly: torchsummary defaults to device="cuda" and,
# when CUDA is available, feeds CUDA tensors -- which fails because resnet50
# was never moved off the CPU in this notebook.
summary(resnet50, (3, 224, 224), device="cpu")