Let's import the tools we need.

In [6]:
import torch
import torch.nn as nn
from torchvision.models import resnet18

Let's create a resnet18 model and inspect its layers.

In [7]:
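from torchvision.models import ResNet18_Weights

# pretrained=True is deprecated since torchvision 0.13; the weights enum is the current way to load ImageNet weights
model = resnet18(weights=ResNet18_Weights.DEFAULT)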

In [8]:
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

To replace the fc layer, for example so the model outputs 10 classes, reassign it as follows.

In [9]:
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, 10, bias=True)

In [10]:
model.fc

Linear(in_features=512, out_features=10, bias=True)

If you don't want to train the pretrained layers that come before the fc layer, you can freeze them.
First, let's look at how to access each layer.

Iterating over **model.children()**, as below, gives access to each top-level layer (the model's immediate submodules).

In [14]:
for index, child in enumerate(model.children()):
  print(f'{index}. {child}')
  print(type(child))
  print()

0. Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
<class 'torch.nn.modules.conv.Conv2d'>

1. BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
<class 'torch.nn.modules.batchnorm.BatchNorm2d'>

2. ReLU(inplace=True)
<class 'torch.nn.modules.activation.ReLU'>

3. MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
<class 'torch.nn.modules.pooling.MaxPool2d'>

4. Sequential(
  (0): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (1): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): Bat

To freeze a layer, **set requires_grad of its parameters to False**. The loop below does this for every layer that comes before fc.

In [16]:
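for child in model.children():
    # Stop at the fc layer so only the new classifier stays trainable
    if isinstance(child, nn.Linear):
        break

    # Freeze every parameter of this child module
    for param in child.parameters():
        param.requires_grad = False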

Printing requires_grad for each layer, as below, confirms that the settings were applied correctly.

In [17]:
for index, child in enumerate(model.children()):
    print(f'{index}. {child}')

    for param in child.parameters():
        print(param.requires_grad)
    print()

0. Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
False

1. BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
False
False

2. ReLU(inplace=True)

3. MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)

4. Sequential(
  (0): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (1): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), st