<a href="https://colab.research.google.com/github/ajw1587/Pytorch_Study/blob/main/28_VGG.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
from torchvision import transforms
from torchvision import models

if torch.cuda.is_available():
  device = torch.device('cuda')
  print(device)
else:
  device = torch.device('cpu')
  print(device)

cuda


## Extract Pretrained Weight Value
학습된 가중치 값 뽑아내기

In [3]:
model = models.vgg11_bn().to(device)
# summary(model, input_size = (3, 224, 224))

# print(weight_dict)
weight_list = [(c_key, c_weight) for (c_key, c_weight) in model.state_dict().items()]

w_idx = [0, 7, 14, 21, 56, 58, 60]
vgg11_w = []
for i, weight in enumerate(weight_list):
  name, value = weight_list[i]
  # print(i, "     ", name, "             ", value.shape) # weight 값 확인
  # 0, 7, 14, 21, 56, 58, 60 weight 추출 후 Custom Layer에 적용

for i, weight in enumerate(weight_list):
  name, value = weight_list[i]
  for j in w_idx:
    if i == j:
      vgg11_w.append(value)

for i in range(len(vgg11_w)):
  print(vgg11_w[i].shape)

torch.Size([64, 3, 3, 3])
torch.Size([128, 64, 3, 3])
torch.Size([256, 128, 3, 3])
torch.Size([256, 256, 3, 3])
torch.Size([4096, 25088])
torch.Size([4096, 4096])
torch.Size([1000, 4096])


## VGG Model
VGG 모델 구축해주기

In [14]:
class CUSTOM_VGG(nn.Module):
  def __init__(self, cfg, vgg11_w, num_classes, init_weights=True):
    super(CUSTOM_VGG, self).__init__()

    # Convolution Layers
    self.vgg_model = self.make_layers(cfg)
    
    # Classifier
    self.classifier = nn.Sequential(
        nn.Linear(512 * 7 * 7, 4096),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(4096, 4096),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(4096,num_classes)
    )

    # Initialization Weights
    if init_weights == True:
      self.init_weights(vgg11_w)

  def make_layers(self, layer_list):
    layers = []
    in_channels = 3
    for v in layer_list:
      if v == 'M':
        layers += [nn.MaxPool2d(kernel_size=2, stride=2, padding=2)]
      else:
        conv2d = nn.Conv2d(in_channels, v, kernel_size=3, stride=1, padding=1)
        layers += [conv2d, nn.ReLU()]
        in_channels = v
    return nn.Sequential(*layers)

  def init_weights(self, vgg11_w):
    count = 0
    for layer in self.vgg_model:
      if isinstance(layer, nn.Conv2d) and count < 4:
        print(count)
        layer.weight.data = vgg11_w[count]      # Weight는 VGG11에서 학습한 Weight를 사용한다.
        nn.init.zeros_(layer.bias)         # Bias는 0으로 초기화
        count += 1
      elif isinstance(layer, nn.Conv2d) and count >= 4:
        torch.nn.init.normal_(layer.weight, mean=0, std=0.1) # 평균 0, 분산 0.01
        nn.init.zeros_(layer.bias)
      elif isinstance(layer, nn.Linear):
        print(layer)
        layer.weight.data = vgg11_w[count]
        nn.init.zeros_(layer.bias.data)
        count += 1


cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']
model = CUSTOM_VGG(cfg, vgg11_w, num_classes=10)
print(model)

0
1
2
3
CUSTOM_VGG(
  (vgg_model): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=2, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU()
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU()
    (9): MaxPool2d(kernel_size=2, stride=2, padding=2, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU()
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU()
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU()
    (16): MaxPool2d(kernel_size=2, stride=2, padding=2, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), st