In [1]:
import os

# Work from the parent directory (presumably the project root -- confirm), so
# relative paths used later in the notebook resolve from there.
# Guarded with a sentinel: a bare os.chdir('..') is not idempotent, and re-running
# this cell in the same kernel would otherwise climb one more level each time.
if not globals().get("_CWD_MOVED_UP", False):
    os.chdir('..')
    _CWD_MOVED_UP = True

import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, cohen_kappa_score

import numpy as np

In [2]:
def count_all_parameters(model):
    """Return the total number of elements over every parameter of ``model``.

    Both trainable and frozen parameters are counted.
    """
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total

def count_learnable_parameters(model):
    """Return the number of elements over the trainable parameters of ``model``.

    Only parameters whose ``requires_grad`` flag is true are counted, so frozen
    parameters do not contribute to the total.
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# Backward-compatible alias: the original (misspelled) name is kept so existing
# callers continue to work.
count_learnabel_parameters = count_learnable_parameters

In [3]:
def build_model(backbone="resnet18", num_classes=3, pretrained=True):
    """Build a torchvision classifier with a fresh ``num_classes``-way linear head.

    Args:
        backbone: ``"resnet18"`` or ``"efficientnet"`` (EfficientNet-B0).
        num_classes: Output size of the replacement classification layer.
        pretrained: If true, the backbone is created with the
            ``"IMAGENET1K_V1"`` weights; if false, ``weights=None`` is passed
            so no pre-trained weights are loaded.

    Returns:
        The model, with its final linear layer replaced by a new ``nn.Linear``
        that has the same input width and ``num_classes`` outputs.

    Raises:
        ValueError: If ``backbone`` is not one of the supported names.
    """
    # Honour the ``pretrained`` flag instead of always loading ImageNet weights.
    weights = "IMAGENET1K_V1" if pretrained else None

    if backbone == "resnet18":
        model = models.resnet18(weights=weights)
        # ResNet exposes its classification head as ``fc``.
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif backbone == "efficientnet":
        model = models.efficientnet_b0(weights=weights)
        # The linear head is read from (and written back to) ``classifier[1]``.
        layer_fc: nn.Linear = model.classifier[1]  # type: ignore[assignment]
        model.classifier[1] = nn.Linear(layer_fc.in_features, num_classes)
    else:
        raise ValueError(f"Unsupported backbone: {backbone}")
    return model

In [7]:
# Instantiate the EfficientNet variant. To try the ResNet-18 variant instead,
# pass "resnet18" as the backbone.
model = build_model(backbone="efficientnet", pretrained=True)

In [12]:
# Print a layer-by-layer table (layer type, output shape, parameter count)
# for `model` given an input size of (3, 256, 256).
# NOTE(review): this import sits mid-notebook; consider moving it to the
# top import cell so all dependencies are visible in one place.
from torchsummary import summary
summary(model, (3, 256, 256))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 128, 128]             864
       BatchNorm2d-2         [-1, 32, 128, 128]              64
              SiLU-3         [-1, 32, 128, 128]               0
            Conv2d-4         [-1, 32, 128, 128]             288
       BatchNorm2d-5         [-1, 32, 128, 128]              64
              SiLU-6         [-1, 32, 128, 128]               0
 AdaptiveAvgPool2d-7             [-1, 32, 1, 1]               0
            Conv2d-8              [-1, 8, 1, 1]             264
              SiLU-9              [-1, 8, 1, 1]               0
           Conv2d-10             [-1, 32, 1, 1]             288
          Sigmoid-11             [-1, 32, 1, 1]               0
SqueezeExcitation-12         [-1, 32, 128, 128]               0
           Conv2d-13         [-1, 16, 128, 128]             512
      BatchNorm2d-14         [-1, 16, 1

In [11]:
# One row per parameter tensor: right-aligned name, trainable flag, shape.
for name, param in model.named_parameters():
    if param.requires_grad:
        flag = 'TRUE'
    else:
        flag = 'FALS'  # four characters, same width as 'TRUE'
    shape = tuple(param.shape)
    print('{:>40}  |  {}  |  {}'.format(name, flag, shape))

                     features.0.0.weight  |  TRUE  |  (32, 3, 3, 3)
                     features.0.1.weight  |  TRUE  |  (32,)
                       features.0.1.bias  |  TRUE  |  (32,)
           features.1.0.block.0.0.weight  |  TRUE  |  (32, 1, 3, 3)
           features.1.0.block.0.1.weight  |  TRUE  |  (32,)
             features.1.0.block.0.1.bias  |  TRUE  |  (32,)
         features.1.0.block.1.fc1.weight  |  TRUE  |  (8, 32, 1, 1)
           features.1.0.block.1.fc1.bias  |  TRUE  |  (8,)
         features.1.0.block.1.fc2.weight  |  TRUE  |  (32, 8, 1, 1)
           features.1.0.block.1.fc2.bias  |  TRUE  |  (32,)
           features.1.0.block.2.0.weight  |  TRUE  |  (16, 32, 1, 1)
           features.1.0.block.2.1.weight  |  TRUE  |  (16,)
             features.1.0.block.2.1.bias  |  TRUE  |  (16,)
           features.2.0.block.0.0.weight  |  TRUE  |  (96, 16, 1, 1)
           features.2.0.block.0.1.weight  |  TRUE  |  (96,)
             features.2.0.block.0.1.bias  |  TRUE  

In [63]:
def freeze_non_linear_layers(model):
    """Leave only the ``nn.Linear`` layers of ``model`` trainable.

    Every parameter is first marked as not requiring gradients; then the
    parameters of each ``nn.Linear`` module found in ``model.modules()`` are
    switched back on. The model is modified in place and also returned.
    """
    # Step 1: turn gradients off for the whole network.
    model.requires_grad_(False)

    # Step 2: turn them back on for the linear layers only.
    linear_layers = [layer for layer in model.modules() if isinstance(layer, nn.Linear)]
    for layer in linear_layers:
        layer.requires_grad_(True)

    return model

In [67]:
# Freeze everything except the nn.Linear layers. The helper modifies the model
# in place and returns the same object, so `model` is simply rebound to it.
model = freeze_non_linear_layers(model)

In [13]:
# Report total vs. trainable parameter counts, digits grouped with underscores.
# NOTE(review): the saved output shows identical counts, which suggests this cell
# was last run before the freeze cell (see the execution counts) -- re-run the
# notebook top to bottom to confirm the trainable count after freezing.
n_all = count_all_parameters(model)
print('all params:       {:_d}'.format(n_all))
n_learnable = count_learnabel_parameters(model)
print('learnable params: {:_d}'.format(n_learnable))

all params:       4_011_391
learnable params: 4_011_391


In [14]:
# Scratch check: os.path.dirname drops the final path component (the file name)
# and returns the containing directory path.
os.path.dirname("checkpoints/thing/file.pt")

'checkpoints/thing'