In [1]:
import os
os.chdir('..')

import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, cohen_kappa_score

import numpy as np

In [2]:
def count_all_parameters(model):
    """Return the total number of scalar elements over all parameters of ``model``.

    Every tensor yielded by ``model.parameters()`` is counted, regardless of
    whether it currently requires gradients.
    """
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total

def count_learnabel_parameters(model):
    """Return the number of scalar elements in parameters that require gradients.

    Parameters of ``model`` whose ``requires_grad`` flag is false are skipped.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total

In [10]:
def build_model(backbone="resnet18", num_classes=3, pretrained=True):
    """Create a torchvision classifier whose final layer outputs ``num_classes`` logits.

    Args:
        backbone: Architecture to build; ``"resnet18"`` or ``"efficientnet"``
            (the latter builds ``efficientnet_b0``).
        num_classes: Output size of the replacement final ``nn.Linear`` layer.
        pretrained: If true, load the ``"IMAGENET1K_V1"`` weights; if false,
            pass ``weights=None`` so the backbone is built without pretrained
            weights.

    Returns:
        The model with its classification layer replaced by a freshly
        initialised ``nn.Linear``.

    Raises:
        ValueError: If ``backbone`` is not one of the supported names.
    """
    # Honour the `pretrained` flag; previously it was accepted but ignored and
    # the ImageNet weights were always loaded.
    weights = "IMAGENET1K_V1" if pretrained else None

    if backbone == "resnet18":
        model = models.resnet18(weights=weights)
        # Swap the final fully connected layer for one sized to our classes.
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif backbone == "efficientnet":
        model = models.efficientnet_b0(weights=weights)
        # classifier[1] is read as the existing Linear layer and replaced.
        layer_fc: nn.Linear = model.classifier[1]  # type: ignore[assignment]
        model.classifier[1] = nn.Linear(layer_fc.in_features, num_classes)
    else:
        raise ValueError(f"Unsupported backbone: {backbone}")
    return model

In [66]:
# Alternative backbone -- uncomment to use EfficientNet instead of ResNet-18:
# model = build_model("efficientnet", pretrained=True)
# Build the ResNet-18 classifier; num_classes is left at build_model's default.
model = build_model("resnet18", pretrained=True)

In [None]:
# Print a layer-by-layer summary of the model.
# NOTE(review): this import lives outside the top import cell, so the
# torchsummary dependency is easy to miss -- consider moving it there.
from torchsummary import summary
# (3, 256, 256) is presumably the channels-first input size -- TODO confirm it
# matches the transforms used for training.
summary(model, (3, 256, 256))

In [None]:
# One row per parameter tensor: right-aligned name, trainable flag, shape.
for name, x in model.named_parameters():
    if x.requires_grad:
        req = 'TRUE'
    else:
        req = 'FALS'
    print('{:>40}  |  {}  |  {}'.format(name, req, tuple(x.shape)))

In [63]:
def freeze_non_linear_layers(model):
    """Make only the ``nn.Linear`` layers of ``model`` trainable.

    All parameters are first marked as not requiring gradients; then every
    ``nn.Linear`` module found anywhere in ``model.modules()`` has its
    parameters re-enabled. The model is modified in place and also returned.
    """
    # Start from a fully frozen network.
    model.requires_grad_(False)

    # Re-enable gradients on each Linear module in the module tree.
    linear_layers = [layer for layer in model.modules() if isinstance(layer, nn.Linear)]
    for layer in linear_layers:
        layer.requires_grad_(True)
    return model

In [67]:
# Freeze every parameter except those of nn.Linear layers. The function
# modifies the model in place and returns it, so `model` is rebound to the
# same object.
model = freeze_non_linear_layers(model)

In [68]:
# Report total vs. trainable parameter counts; the `_d` format spec prints
# integers with "_" as the thousands separator.
print('all params:       {:_d}'.format(count_all_parameters(model)))
print('learnable params: {:_d}'.format(count_learnabel_parameters(model)))

all params:       11_178_051
learnable params: 1_539
