In [1]:
import torch
import torch.nn as nn

from torch import optim
from torchvision import models
from torch.utils.data import Dataset,DataLoader
from torchvision import datasets,transforms

In [2]:
dataset = datasets.ImageFolder(root ='train/',
                               transform=transforms.Compose([transforms.Resize(256),
                                                             transforms.ColorJitter(),
                                                             transforms.RandomCrop(224),
                                                             transforms.RandomHorizontalFlip(),
                                                             transforms.Resize(128),
                                                             transforms.ToTensor(),
                                                             transforms.Normalize([0.485, 0.456, 0.406],
                                                            [0.229, 0.224, 0.225])
                               ]))

train_set, val_set = torch.utils.data.random_split(dataset,[int(len(dataset)*0.8),int(len(dataset)*0.2)])

train_dataloader = DataLoader(train_set,batch_size=512,
                        shuffle=True,num_workers=8,
                        drop_last=True,pin_memory=True)

val_dataloader = DataLoader(val_set)

In [3]:
vgg16 = models.vgg16(pretrained=True)
for p in vgg16.parameters():
    p.requires_grad = False

In [4]:
num_featrues = num_features = vgg16.classifier[6].in_features
features = list(vgg16.classifier.children())[:-1]
features.extend([nn.Linear(num_features, 2)])
vgg16.classifier = nn.Sequential(*features)

In [5]:
print(vgg16)

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

In [6]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = vgg16.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(),lr=0.001)

epochs = 10

In [7]:
plot_list = {'train':[],'val':[],'accuracy':[]}

for epoch in range(epochs):
    
    model.train()
    train_losses = []
    for i,data in enumerate(train_dataloader):
        inputs,targets = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()
        
        outputs = model(inputs.detach())
        train_loss = criterion(outputs,targets)
        train_loss.backward()
        optimizer.step()
        
        train_losses.append(train_loss.detach()) ## train_loss를 detach하지않으면 gpu에 계속 남아서 메모리를 잡아먹음...
                                                 ## 진짜 애많이먹었습니다....
                                                 ## data와 inputs도 혹시모르니...삭제
        del data
        del inputs
        
    
    val_losses = []
    correct = 0
    model.eval()
    for data in val_dataloader:
        inputs,targets = data[0].to(device), data[1].to(device)

        outputs = model(inputs)
        val_loss = criterion(outputs,targets)
        val_losses.append(val_loss.detach())

        prob,label = torch.exp(outputs).topk(1,dim=1)
        if targets==label.view(1):
            correct+=1
        del data
        del inputs
    val_accuracy = correct/len(val_set)
    
    print(f"{epoch+1} epoch train loss = {sum(train_losses)/len(train_losses)}")
    print(f"{epoch+1} epoch val loss = {sum(val_losses)/len(val_losses)}")
    print(f"{epoch+1} epoch accuracy = {val_accuracy}")
    print('--------------------------------------------------')
    plot_list['train'].append(sum(train_losses)/len(train_losses))
    plot_list['val'].append(sum(val_losses)/len(val_losses))
    plot_list['accuracy'].append(val_accuracy)
    
#     if epoch < 11:continue
#     if sum(plot_list['val'][-11:-6])/5 < sum(val_losses)/len(val_losses):
#         print(f'over_fitting is occured at {epoch} epoch')
#         break
        

1 epoch train loss = 0.1267513930797577
1 epoch val loss = 0.06851354241371155
1 epoch accuracy = 0.9744
--------------------------------------------------
2 epoch train loss = 0.07510467618703842
2 epoch val loss = 0.0647452101111412
2 epoch accuracy = 0.9748
--------------------------------------------------
3 epoch train loss = 0.0728272944688797
3 epoch val loss = 0.058975204825401306
3 epoch accuracy = 0.976
--------------------------------------------------
4 epoch train loss = 0.06776028871536255
4 epoch val loss = 0.06174173951148987
4 epoch accuracy = 0.9764
--------------------------------------------------
5 epoch train loss = 0.06574101746082306
5 epoch val loss = 0.06877242773771286
5 epoch accuracy = 0.9714
--------------------------------------------------
6 epoch train loss = 0.06959791481494904
6 epoch val loss = 0.05596082657575607
6 epoch accuracy = 0.9778
--------------------------------------------------
7 epoch train loss = 0.06382209062576294
7 epoch val loss = 0