## 전이학습

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from torchvision.models import resnet18, ResNet18_Weights

from torchinfo import summary
from torchmetrics.functional.classification import multiclass_accuracy

In [4]:
### Data loading

# NOTE(review): this points at the `test` split; if imgDS is later used for
# training, the train split is presumably intended — TODO confirm.
imgdir = '../../datas/Project Jellyfish/Train_Test_Valid/test'

# ResNet preprocessing pipeline, applied to every image in order:
#   1. resize so the shorter side is 256 px (bilinear interpolation)
#   2. center-crop to 224 x 224
#   3. convert the PIL image to a float tensor
#   4. normalize each RGB channel with the given mean / std
preprocessing = transforms.Compose([
    transforms.Resize(size=256, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

# Build the image dataset: each sub-directory of imgdir becomes one class.
imgDS = ImageFolder(root=imgdir, transform=preprocessing)
print(imgDS.classes, imgDS.targets, imgDS.imgs, end='/')

['Moon_jellyfish', 'barrel_jellyfish', 'blue_jellyfish', 'compass_jellyfish', 'lions_mane_jellyfish', 'mauve_stinger_jellyfish'] [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5] [('../../datas/Project Jellyfish/Train_Test_Valid/test\\Moon_jellyfish\\07.JPG', 0), ('../../datas/Project Jellyfish/Train_Test_Valid/test\\Moon_jellyfish\\11.jpg', 0), ('../../datas/Project Jellyfish/Train_Test_Valid/test\\Moon_jellyfish\\19.jpg', 0), ('../../datas/Project Jellyfish/Train_Test_Valid/test\\Moon_jellyfish\\23.jpg', 0), ('../../datas/Project Jellyfish/Train_Test_Valid/test\\Moon_jellyfish\\33.jpg', 0), ('../../datas/Project Jellyfish/Train_Test_Valid/test\\Moon_jellyfish\\48.jpg', 0), ('../../datas/Project Jellyfish/Train_Test_Valid/test\\barrel_jellyfish\\08.jpg', 1), ('../../datas/Project Jellyfish/Train_Test_Valid/test\\barrel_jellyfish\\16.jpg', 1), ('../../datas/Project Jellyfish/Train_Test_Valid/test\\barrel_jellyfish\\

In [5]:
# Data loader: batches of 3 images, reshuffled on every pass. With
# drop_last=True a trailing incomplete batch is discarded (for the 40-image
# folder shown in the output this yields 13 batches and skips 1 image).
imgDL = DataLoader(imgDS, batch_size=3, shuffle=True, drop_last=True)
for img, label in imgDL:
    print(img.shape, label)

torch.Size([3, 3, 224, 224]) tensor([4, 1, 4])
torch.Size([3, 3, 224, 224]) tensor([1, 0, 0])
torch.Size([3, 3, 224, 224]) tensor([0, 1, 5])
torch.Size([3, 3, 224, 224]) tensor([5, 1, 4])
torch.Size([3, 3, 224, 224]) tensor([0, 3, 4])
torch.Size([3, 3, 224, 224]) tensor([4, 3, 5])
torch.Size([3, 3, 224, 224]) tensor([5, 0, 2])
torch.Size([3, 3, 224, 224]) tensor([4, 3, 3])
torch.Size([3, 3, 224, 224]) tensor([2, 2, 5])
torch.Size([3, 3, 224, 224]) tensor([2, 4, 5])
torch.Size([3, 3, 224, 224]) tensor([3, 2, 1])
torch.Size([3, 3, 224, 224]) tensor([0, 3, 5])
torch.Size([3, 3, 224, 224]) tensor([2, 4, 2])


In [10]:
## Model design / configuration
# Instantiate ResNet-18 with torchvision's default pretrained weights.
# requires_grad is not changed here; the backbone is frozen in a later cell.
resmodel = resnet18(weights=ResNet18_Weights.DEFAULT)

# Replace the fully connected (classifier) head.
# in_features : size of the feature vector produced by the backbone. It is
#               read from the existing head instead of hard-coding 512, so it
#               stays correct if the backbone is swapped.
# out_features: number of classes to predict. It is taken from the dataset's
#               class folders so the head always matches the labels.
resmodel.fc = nn.Linear(
    in_features=resmodel.fc.in_features,
    out_features=len(imgDS.classes),
)


In [11]:
# summary: show the layer-by-layer structure, output shapes and parameter
# counts for a batch of 3 RGB images of 224 x 224, the same batch shape the
# DataLoader yields ([3, 3, 224, 224]).
summary(model=resmodel, input_size=(3, 3, 224, 224))

Layer (type:depth-idx)                   Output Shape              Param #
ResNet                                   [3, 6]                    --
├─Conv2d: 1-1                            [3, 64, 112, 112]         9,408
├─BatchNorm2d: 1-2                       [3, 64, 112, 112]         128
├─ReLU: 1-3                              [3, 64, 112, 112]         --
├─MaxPool2d: 1-4                         [3, 64, 56, 56]           --
├─Sequential: 1-5                        [3, 64, 56, 56]           --
│    └─BasicBlock: 2-1                   [3, 64, 56, 56]           --
│    │    └─Conv2d: 3-1                  [3, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-2             [3, 64, 56, 56]           128
│    │    └─ReLU: 3-3                    [3, 64, 56, 56]           --
│    │    └─Conv2d: 3-4                  [3, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-5             [3, 64, 56, 56]           128
│    │    └─ReLU: 3-6                    [3, 64, 56, 56]           --
│

In [13]:
## Freeze the whole ResNet-18: every parameter's requires_grad becomes False.
# Each line printed is "<name> <old flag>   =======>   <new flag>".
for name, param in resmodel.named_parameters():
    was_trainable = param.requires_grad
    param.requires_grad = False
    print(name, was_trainable, end='   =======>   ')
    print(param.requires_grad)

### Unfreeze only the fully connected head (requires_grad False -> True), so
# that only the new classifier's weight and bias receive gradients.
for name, param in resmodel.fc.named_parameters():
    was_trainable = param.requires_grad
    param.requires_grad = True
    print(name, was_trainable, end='   =======>   ')
    print(param.requires_grad)



In [14]:
# Training setup
EPOCHS = 3                        # number of passes over the data
loss_fn = nn.CrossEntropyLoss()   # multi-class classification loss
# Adam over the fc head's parameters only (the rest of the network is frozen);
# lr=0.001 is Adam's default, written out explicitly.
optimizer = optim.Adam(params=resmodel.fc.parameters(), lr=0.001)
