In [13]:
import pandas as pd
import sys
import glob
from torchvision import transforms
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torchvision import models
import torch.nn as nn
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

sys.path.append('finetuning/')
from utils import *

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
df=pd.DataFrame()
df['image']=glob.glob('./lemon_dataset/*/*.jpg')
df['label']=df.image.str.contains('good')

In [5]:
set_seed(1998,device)

X_train, X_test, y_train, y_test=train_test_split(
    df['image'],df['label'],test_size=0.2,random_state=1998,shuffle=True,
    stratify=df['label'])

X_train, X_val, y_train, y_val=train_test_split(
    X_train,y_train,test_size=0.3,random_state=1998,shuffle=True,
    stratify=y_train)

In [6]:
data_transforms = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomResizedCrop(300),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

In [7]:
train_ds=dataset(X_train,data_transforms)
val_ds=dataset(X_train,data_transforms)
test_ds=dataset(X_test,data_transforms)

train_loader=DataLoader(train_ds,batch_size=8,
shuffle=True,num_workers=0,drop_last=True)

val_loader=DataLoader(val_ds,batch_size=8,
shuffle=True,num_workers=0,drop_last=True)

test_loader=DataLoader(test_ds,batch_size=8,
shuffle=True,num_workers=0,drop_last=True)

In [8]:
pretrained_resnet50 = models.resnet50(pretrained=True).to(device)

for param in pretrained_resnet50.parameters():
    param.requires_grad=False
    
in_features=pretrained_resnet50.fc.in_features
pretrained_resnet50.fc = nn.Linear(in_features, 1).to(device)



In [None]:
#from torchsummary import summary
#summary(pretrained_resnet50, input_size=(3, 300, 300))

In [9]:
from train import *

In [10]:
criterion= nn.BCEWithLogitsLoss()

In [54]:
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary
import datetime

time_stamp = str(datetime.datetime.now().timestamp())
log_dir = 'runs/resnet' + time_stamp # runs 경로
#writer = SummaryWriter(log_dir)

#%load_ext tensorboard

# 텐보 로드
#%tensorboard --logdir logs/tensorboard


train(pretrained_resnet50,train_loader,
val_loader,criterion,1,0.00005,device,weight_name='output_trained_once')

100%|██████████| 1/1 [00:21<00:00, 21.58s/it]

Best Trial Renewed

        Train Accuracy is 0.822
        Validation Accuracy is 0.829
        Current best Accuracy is 0.829
        





In [55]:
resnet50 = models.resnet50(pretrained=False).to(device)
resnet50.fc = nn.Linear(resnet50.fc.in_features, 1).to(device)
model=torch.load('./runs/resnet/output_trained_once.pt')
resnet50.fc.weight=nn.Parameter(model['fc.weight']) #Resnet backbone에 출력층 학습된 것만 사용

In [56]:
train(resnet50,train_loader,val_loader,
criterion,50,0.00005,device,weight_name='resnet_fine_tuning')

  2%|▏         | 1/50 [00:35<28:46, 35.24s/it]

Best Trial Renewed

        Train Accuracy is 0.767
        Validation Accuracy is 0.914
        Current best Accuracy is 0.914
        


  4%|▍         | 2/50 [01:06<26:15, 32.83s/it]


        Train Accuracy is 0.783
        Validation Accuracy is 0.903
        Current best Accuracy is 0.914
        


  6%|▌         | 3/50 [01:37<25:12, 32.18s/it]

Best Trial Renewed

        Train Accuracy is 0.893
        Validation Accuracy is 0.939
        Current best Accuracy is 0.939
        


  8%|▊         | 4/50 [02:09<24:24, 31.83s/it]


        Train Accuracy is 0.9
        Validation Accuracy is 0.888
        Current best Accuracy is 0.939
        


 10%|█         | 5/50 [02:41<23:59, 31.98s/it]

Best Trial Renewed

        Train Accuracy is 0.925
        Validation Accuracy is 0.947
        Current best Accuracy is 0.947
        


 12%|█▏        | 6/50 [03:14<23:48, 32.46s/it]


        Train Accuracy is 0.935
        Validation Accuracy is 0.922
        Current best Accuracy is 0.947
        


 14%|█▍        | 7/50 [03:35<20:25, 28.49s/it]

Best Trial Renewed

        Train Accuracy is 0.941
        Validation Accuracy is 0.953
        Current best Accuracy is 0.953
        


 16%|█▌        | 8/50 [03:55<18:07, 25.90s/it]

Best Trial Renewed

        Train Accuracy is 0.953
        Validation Accuracy is 0.981
        Current best Accuracy is 0.981
        


 18%|█▊        | 9/50 [04:15<16:32, 24.20s/it]


        Train Accuracy is 0.958
        Validation Accuracy is 0.969
        Current best Accuracy is 0.981
        


 20%|██        | 10/50 [04:36<15:21, 23.04s/it]


        Train Accuracy is 0.97
        Validation Accuracy is 0.978
        Current best Accuracy is 0.981
        


 22%|██▏       | 11/50 [04:58<14:47, 22.75s/it]


        Train Accuracy is 0.958
        Validation Accuracy is 0.976
        Current best Accuracy is 0.981
        


 24%|██▍       | 12/50 [05:36<17:17, 27.31s/it]


        Train Accuracy is 0.973
        Validation Accuracy is 0.972
        Current best Accuracy is 0.981
        


 26%|██▌       | 13/50 [06:11<18:16, 29.64s/it]


        Train Accuracy is 0.966
        Validation Accuracy is 0.98
        Current best Accuracy is 0.981
        


 28%|██▊       | 14/50 [06:46<18:49, 31.38s/it]


        Train Accuracy is 0.959
        Validation Accuracy is 0.969
        Current best Accuracy is 0.981
        


 30%|███       | 15/50 [07:21<18:52, 32.35s/it]


        Train Accuracy is 0.978
        Validation Accuracy is 0.973
        Current best Accuracy is 0.981
        


 32%|███▏      | 16/50 [07:54<18:30, 32.66s/it]


        Train Accuracy is 0.944
        Validation Accuracy is 0.914
        Current best Accuracy is 0.981
        


 34%|███▍      | 17/50 [08:31<18:38, 33.89s/it]


        Train Accuracy is 0.953
        Validation Accuracy is 0.949
        Current best Accuracy is 0.981
        


 36%|███▌      | 18/50 [09:03<17:47, 33.35s/it]


        Train Accuracy is 0.959
        Validation Accuracy is 0.966
        Current best Accuracy is 0.981
        


 38%|███▊      | 19/50 [09:36<17:09, 33.20s/it]


        Train Accuracy is 0.969
        Validation Accuracy is 0.952
        Current best Accuracy is 0.981
        


 40%|████      | 20/50 [10:08<16:24, 32.81s/it]


        Train Accuracy is 0.954
        Validation Accuracy is 0.969
        Current best Accuracy is 0.981
        


 42%|████▏     | 21/50 [10:41<16:00, 33.11s/it]


        Train Accuracy is 0.97
        Validation Accuracy is 0.975
        Current best Accuracy is 0.981
        


 44%|████▍     | 22/50 [11:15<15:33, 33.34s/it]


        Train Accuracy is 0.969
        Validation Accuracy is 0.978
        Current best Accuracy is 0.981
        


 46%|████▌     | 23/50 [11:52<15:30, 34.46s/it]


        Train Accuracy is 0.976
        Validation Accuracy is 0.974
        Current best Accuracy is 0.981
        


 48%|████▊     | 24/50 [12:28<15:03, 34.75s/it]


        Train Accuracy is 0.959
        Validation Accuracy is 0.972
        Current best Accuracy is 0.981
        


 50%|█████     | 25/50 [13:03<14:31, 34.86s/it]

Best Trial Renewed

        Train Accuracy is 0.974
        Validation Accuracy is 0.983
        Current best Accuracy is 0.983
        


 52%|█████▏    | 26/50 [13:35<13:33, 33.91s/it]


        Train Accuracy is 0.97
        Validation Accuracy is 0.972
        Current best Accuracy is 0.983
        


 54%|█████▍    | 27/50 [14:07<12:50, 33.49s/it]


        Train Accuracy is 0.956
        Validation Accuracy is 0.943
        Current best Accuracy is 0.983
        


 56%|█████▌    | 28/50 [14:40<12:10, 33.20s/it]


        Train Accuracy is 0.967
        Validation Accuracy is 0.974
        Current best Accuracy is 0.983
        


 58%|█████▊    | 29/50 [15:14<11:47, 33.70s/it]


        Train Accuracy is 0.967
        Validation Accuracy is 0.981
        Current best Accuracy is 0.983
        


 60%|██████    | 30/50 [15:46<11:02, 33.13s/it]

Best Trial Renewed

        Train Accuracy is 0.978
        Validation Accuracy is 0.986
        Current best Accuracy is 0.986
        


 62%|██████▏   | 31/50 [16:19<10:25, 32.92s/it]


        Train Accuracy is 0.982
        Validation Accuracy is 0.949
        Current best Accuracy is 0.986
        


 64%|██████▍   | 32/50 [16:51<09:50, 32.80s/it]


        Train Accuracy is 0.982
        Validation Accuracy is 0.972
        Current best Accuracy is 0.986
        


 66%|██████▌   | 33/50 [17:23<09:12, 32.48s/it]

Best Trial Renewed

        Train Accuracy is 0.974
        Validation Accuracy is 0.987
        Current best Accuracy is 0.987
        


 68%|██████▊   | 34/50 [17:55<08:39, 32.47s/it]


        Train Accuracy is 0.97
        Validation Accuracy is 0.958
        Current best Accuracy is 0.987
        


 70%|███████   | 35/50 [18:28<08:06, 32.45s/it]


        Train Accuracy is 0.973
        Validation Accuracy is 0.975
        Current best Accuracy is 0.987
        


 72%|███████▏  | 36/50 [19:00<07:34, 32.47s/it]


        Train Accuracy is 0.978
        Validation Accuracy is 0.973
        Current best Accuracy is 0.987
        


 74%|███████▍  | 37/50 [19:33<07:02, 32.46s/it]


        Train Accuracy is 0.97
        Validation Accuracy is 0.977
        Current best Accuracy is 0.987
        


 76%|███████▌  | 38/50 [20:05<06:29, 32.48s/it]


        Train Accuracy is 0.971
        Validation Accuracy is 0.975
        Current best Accuracy is 0.987
        


 78%|███████▊  | 39/50 [20:37<05:54, 32.18s/it]


        Train Accuracy is 0.966
        Validation Accuracy is 0.978
        Current best Accuracy is 0.987
        


 80%|████████  | 40/50 [21:09<05:22, 32.25s/it]


        Train Accuracy is 0.98
        Validation Accuracy is 0.916
        Current best Accuracy is 0.987
        


 82%|████████▏ | 41/50 [21:42<04:50, 32.32s/it]


        Train Accuracy is 0.972
        Validation Accuracy is 0.981
        Current best Accuracy is 0.987
        


 84%|████████▍ | 42/50 [22:14<04:18, 32.36s/it]


        Train Accuracy is 0.978
        Validation Accuracy is 0.975
        Current best Accuracy is 0.987
        


 86%|████████▌ | 43/50 [22:46<03:44, 32.14s/it]


        Train Accuracy is 0.978
        Validation Accuracy is 0.978
        Current best Accuracy is 0.987
        


 88%|████████▊ | 44/50 [23:18<03:13, 32.25s/it]


        Train Accuracy is 0.976
        Validation Accuracy is 0.976
        Current best Accuracy is 0.987
        


 90%|█████████ | 45/50 [23:51<02:41, 32.32s/it]


        Train Accuracy is 0.983
        Validation Accuracy is 0.986
        Current best Accuracy is 0.987
        


 92%|█████████▏| 46/50 [24:22<02:08, 32.12s/it]


        Train Accuracy is 0.973
        Validation Accuracy is 0.963
        Current best Accuracy is 0.987
        


 94%|█████████▍| 47/50 [24:55<01:36, 32.20s/it]


        Train Accuracy is 0.975
        Validation Accuracy is 0.973
        Current best Accuracy is 0.987
        


 96%|█████████▌| 48/50 [25:23<01:02, 31.13s/it]


        Train Accuracy is 0.965
        Validation Accuracy is 0.973
        Current best Accuracy is 0.987
        


 98%|█████████▊| 49/50 [25:52<00:30, 30.37s/it]


        Train Accuracy is 0.958
        Validation Accuracy is 0.98
        Current best Accuracy is 0.987
        


100%|██████████| 50/50 [26:21<00:00, 31.62s/it]


        Train Accuracy is 0.978
        Validation Accuracy is 0.981
        Current best Accuracy is 0.987
        





# 출력층을 별도로 학습한 뒤 전체 레이어를 fine-Tuning하는 방안 

In [57]:
resnet50 = models.resnet50(pretrained=False).to(device)
weight=torch.load('./runs/resnet/output_trained_once.pt')

def inference(model,test_loader,weight,device):

    model.fc = nn.Linear(model.fc.in_features, 1).to(device)
    model.load_state_dict(weight)

    test_correct=0
    count=0

    model.eval()
    with torch.no_grad():
        for xx,yy in test_loader:
            xx=xx.to(device)
            yy=yy.to(device)
            output=model(xx)
            predicted=torch.round(torch.sigmoid(output))
            count+=yy.size(0)
            test_correct+=(predicted==yy.unsqueeze(1)).sum().item()

    print(round(test_correct/count,3))

inference(resnet50,test_loader,weight,device)

0.865


# 전체 레이어를 학습하는, resnet의 backbone만 사용하는 fine-Tuning 방안 

In [58]:
weight=torch.load('./runs/resnet/resnet_fine_tuning.pt')
inference(resnet50,test_loader,weight,device)

0.981
