In [1]:
# pytorch
import torch
import torchvision
from torchvision import transforms, datasets, models
from torchsummary import summary
from torchvision import transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torcheval.metrics.functional import multiclass_f1_score
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# other
import numpy as np
import matplotlib.pyplot as plt
import copy
import time
import glob
from tqdm import tqdm

In [2]:
class appleDataset(Dataset):
    def __init__(self, path, train, transform=None):
        self.path = path
        if train:
            self.AbNorm_path = path + '/train/AbNorm/'
            self.Mot_path = path + '/train/Mot/'
            self.Norm_path = path + '/train/Norm/'
        else:
            self.AbNorm_path = path + '/test/AbNorm/'
            self.Mot_path = path + '/test/Mot/'
            self.Norm_path = path + '/test/Norm/'
        
        self.AbNorm_img_list = glob.glob(self.AbNorm_path + '/*.jpg')
        self.Mot_img_list = glob.glob(self.Mot_path + '/*.jpg')
        self.Norm_img_list = glob.glob(self.Norm_path + '/*.jpg')

        self.transform = transform

        self.img_list = self.AbNorm_img_list + self.Mot_img_list + self.Norm_img_list
        self.class_list = [0] * len(self.AbNorm_img_list) + [1] * len(self.Mot_img_list) + [2] * len(self.Norm_img_list)
        
    def __len__(self):
        return len(self.img_list)
    
    def __getitem__(self, idx):
        img_path = self.img_list[idx]
        label = self.class_list[idx]
        img = Image.open(img_path)

        if self.transform is not None:
            img = self.transform(img)

        return img, label


In [3]:
transform = transforms.Compose([transforms.ToTensor()])

train_dataset = appleDataset(path='./input', train=True, transform=transform)
trainloader = DataLoader(dataset=train_dataset,
                        batch_size=512,
                        shuffle=True,
                        drop_last=False)
    
test_dataset = appleDataset(path='./input', train=False, transform=transform)
testloader = DataLoader(dataset=test_dataset,
                        batch_size=256,
                        shuffle=True,
                        drop_last=False
                        )

In [4]:
classes = ('Norm','Mot','AbNorm')

In [5]:
resnet_pt = models.resnet18(weights=True)



In [6]:
unfreeze = ['layer4.0.conv1.weight', 'layer4.0.bn1.weight', 'layer4.0.bn1.bias', 'layer4.0.bn1.running_mean', 'layer4.0.bn1.running_var', 'layer4.0.bn1.num_batches_tracked', 'layer4.0.conv2.weight', 'layer4.0.bn2.weight', 'layer4.0.bn2.bias', 'layer4.0.bn2.running_mean', 'layer4.0.bn2.running_var', 'layer4.0.bn2.num_batches_tracked', 'layer4.0.downsample.0.weight', 'layer4.0.downsample.1.weight', 'layer4.0.downsample.1.bias', 'layer4.0.downsample.1.running_mean', 'layer4.0.downsample.1.running_var', 'layer4.0.downsample.1.num_batches_tracked', 'layer4.1.conv1.weight', 'layer4.1.bn1.weight', 'layer4.1.bn1.bias', 'layer4.1.bn1.running_mean', 'layer4.1.bn1.running_var', 'layer4.1.bn1.num_batches_tracked', 'layer4.1.conv2.weight', 'layer4.1.bn2.weight', 'layer4.1.bn2.bias', 'layer4.1.bn2.running_mean', 'layer4.1.bn2.running_var', 'layer4.1.bn2.num_batches_tracked', 'fc.weight', 'fc.bias'] 

In [7]:
resnet_pt = models.resnet18(weights=True)
# freezing
for name, param in resnet_pt.named_parameters():
    if name in unfreeze:
        param.requires_grad = True
    else:
        param.requires_grad = False
    
# fc layer 수정
fc_in_features = resnet_pt.fc.in_features
resnet_pt.fc = nn.Linear(fc_in_features, len(classes))
resnet_pt = resnet_pt.to(device)

In [8]:
summary(resnet_pt, (3,32,32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 16, 16]           9,408
       BatchNorm2d-2           [-1, 64, 16, 16]             128
              ReLU-3           [-1, 64, 16, 16]               0
         MaxPool2d-4             [-1, 64, 8, 8]               0
            Conv2d-5             [-1, 64, 8, 8]          36,864
       BatchNorm2d-6             [-1, 64, 8, 8]             128
              ReLU-7             [-1, 64, 8, 8]               0
            Conv2d-8             [-1, 64, 8, 8]          36,864
       BatchNorm2d-9             [-1, 64, 8, 8]             128
             ReLU-10             [-1, 64, 8, 8]               0
       BasicBlock-11             [-1, 64, 8, 8]               0
           Conv2d-12             [-1, 64, 8, 8]          36,864
      BatchNorm2d-13             [-1, 64, 8, 8]             128
             ReLU-14             [-1, 6

In [9]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(resnet_pt.parameters(), lr=0.001,
                      momentum=0.9)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

In [10]:
# Training
def train(epoch, model, criterion, optimizer):
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    predicted_list = 0
    labels_list = 0
    Norm_count = {"correct":0,"total":0}
    Mot_count = {"correct":0,"total":0}
    AbNorm_count = {"correct":0,"total":0}
    
    for batch_idx, (inputs, labels) in enumerate(tqdm(trainloader)):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()*inputs.size(0)
        _, predicted = outputs.max(1)
        if batch_idx == 0:
            predicted_list=predicted
        else:
            predicted_list=torch.cat((predicted_list,predicted), dim=0)
            
        if batch_idx == 0:
            labels_list=labels
        else:
            labels_list=torch.cat((labels_list, labels), dim=0)
            
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        correct_compare = predicted.eq(labels)
        for i in range(3):
            if i==0:
                Norm_count['correct'] += predicted[correct_compare].tolist().count(i)
                Norm_count['total'] += predicted.tolist().count(i)
            elif i==1:
                Mot_count['correct'] += predicted[correct_compare].tolist().count(i)
                Mot_count['total'] += predicted.tolist().count(i)
            else:
                AbNorm_count['correct'] += predicted[correct_compare].tolist().count(i)
                AbNorm_count['total'] += predicted.tolist().count(i)
            
    epoch_loss = train_loss/total
    epoch_acc = correct/total*100
    f1_score = multiclass_f1_score(predicted_list, labels_list, num_classes=3, average="macro")
    Norm = Norm_count['correct']/Norm_count['total']
    Mot = Mot_count['correct']/Mot_count['total']
    AbNorm = AbNorm_count['correct']/AbNorm_count['total']
    print(f"Test | Loss: {epoch_loss: .4f} Acc: {epoch_acc: .2f}  F1-Score: {f1_score:.2f} ({correct}/{total}) Norm:{Norm:.2f} Mot:{Mot:.2f} AbNorm:{AbNorm:.2f}")
    return epoch_loss, epoch_acc, f1_score

def test(epoch, model, criterion, optimizer):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    predicted_list = 0
    labels_list = 0
    Norm_count = {"correct":0,"total":0}
    Mot_count = {"correct":0,"total":0}
    AbNorm_count = {"correct":0,"total":0}
    with torch.no_grad():
        for batch_idx, (inputs, labels) in enumerate(tqdm(testloader)):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            test_loss += loss.item()*inputs.size(0)
            _, predicted = outputs.max(1)
            if batch_idx == 0:
                predicted_list=predicted
            else:
                predicted_list=torch.cat((predicted_list,predicted), dim=0)
            
            if batch_idx == 0:
                labels_list=labels
            else:
                labels_list=torch.cat((labels_list, labels), dim=0)
                
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            correct_compare = predicted.eq(labels)
            
            for i in range(3):
                if i==0:
                    Norm_count['correct'] += predicted[correct_compare].tolist().count(i)
                    Norm_count['total'] += predicted.tolist().count(i)
                elif i==1:
                    Mot_count['correct'] += predicted[correct_compare].tolist().count(i)
                    Mot_count['total'] += predicted.tolist().count(i)
                else:
                    AbNorm_count['correct'] += predicted[correct_compare].tolist().count(i)
                    AbNorm_count['total'] += predicted.tolist().count(i)

            
        epoch_loss = test_loss/total
        epoch_acc = correct/total*100
        f1_score = multiclass_f1_score(predicted_list, labels_list, num_classes=3, average="macro")
        Norm = Norm_count['correct']/Norm_count['total']
        Mot = Mot_count['correct']/Mot_count['total']
        AbNorm = AbNorm_count['correct']/AbNorm_count['total']
        print(f"Test | Loss: {epoch_loss: .4f} Acc: {epoch_acc: .2f}  F1-Score: {f1_score:.2f} ({correct}/{total}) Norm:{Norm:.2f} Mot:{Mot:.2f} AbNorm:{AbNorm:.2f}")
    return epoch_loss, epoch_acc, f1_score

In [11]:
tart_time = time.time()
best_acc = 0
epoch_length = 20
save_loss = {"train":[],
             "test":[]}
save_acc = {"train":[],
             "test":[]}
save_f1 = {"train":[],
            "test":[]}
start_time= time.time()

for epoch in range(epoch_length):
    print("Epoch %s" % epoch)
    train_loss, train_acc, train_f1 = train(epoch, resnet_pt, criterion, optimizer)
    save_loss['train'].append(train_loss)
    save_acc['train'].append(train_acc)
    save_f1['train'].append(train_f1)
    
    test_loss, test_acc, test_f1 = test(epoch, resnet_pt, criterion, optimizer)
    save_loss['test'].append(test_loss)
    save_acc['test'].append(test_acc)
    save_f1['test'].append(test_f1)

    scheduler.step()

    # Save model
    if test_acc > best_acc:
        best_acc = test_acc
        best_model_wts = copy.deepcopy(resnet_pt.state_dict())
    resnet_pt.load_state_dict(best_model_wts)

learning_time = time.time() - start_time
print(f'**Learning time: {learning_time // 60:.0f}m {learning_time % 60:.0f}s')

Epoch 0


100%|██████████████████████████████████████████████████████████████████████████████████| 83/83 [12:56<00:00,  9.35s/it]


Test | Loss:  0.3830 Acc:  86.55  F1-Score: 0.78 (36720/42426) Norm:0.85 Mot:0.94 AbNorm:0.90


100%|██████████████████████████████████████████████████████████████████████████████████| 21/21 [01:15<00:00,  3.58s/it]


Test | Loss:  0.1659 Acc:  95.53  F1-Score: 0.94 (5069/5306) Norm:0.96 Mot:0.98 AbNorm:0.93
Epoch 1


100%|██████████████████████████████████████████████████████████████████████████████████| 83/83 [12:43<00:00,  9.20s/it]


Test | Loss:  0.1228 Acc:  96.44  F1-Score: 0.95 (40917/42426) Norm:0.97 Mot:0.98 AbNorm:0.94


100%|██████████████████████████████████████████████████████████████████████████████████| 21/21 [01:14<00:00,  3.57s/it]


Test | Loss:  0.1042 Acc:  96.80  F1-Score: 0.96 (5136/5306) Norm:0.97 Mot:0.98 AbNorm:0.94
Epoch 2


100%|██████████████████████████████████████████████████████████████████████████████████| 83/83 [12:48<00:00,  9.26s/it]


Test | Loss:  0.0851 Acc:  97.42  F1-Score: 0.97 (41330/42426) Norm:0.98 Mot:0.99 AbNorm:0.95


100%|██████████████████████████████████████████████████████████████████████████████████| 21/21 [01:14<00:00,  3.53s/it]


Test | Loss:  0.0809 Acc:  97.32  F1-Score: 0.97 (5164/5306) Norm:0.98 Mot:0.98 AbNorm:0.95
Epoch 3


100%|██████████████████████████████████████████████████████████████████████████████████| 83/83 [12:42<00:00,  9.18s/it]


Test | Loss:  0.0664 Acc:  97.97  F1-Score: 0.97 (41564/42426) Norm:0.98 Mot:0.99 AbNorm:0.96


100%|██████████████████████████████████████████████████████████████████████████████████| 21/21 [01:14<00:00,  3.56s/it]


Test | Loss:  0.0681 Acc:  97.78  F1-Score: 0.97 (5188/5306) Norm:0.98 Mot:0.99 AbNorm:0.96
Epoch 4


100%|██████████████████████████████████████████████████████████████████████████████████| 83/83 [12:51<00:00,  9.30s/it]


Test | Loss:  0.0554 Acc:  98.29  F1-Score: 0.98 (41700/42426) Norm:0.99 Mot:0.99 AbNorm:0.97


100%|██████████████████████████████████████████████████████████████████████████████████| 21/21 [01:14<00:00,  3.54s/it]


Test | Loss:  0.0596 Acc:  98.02  F1-Score: 0.98 (5201/5306) Norm:0.99 Mot:0.99 AbNorm:0.96
Epoch 5


100%|██████████████████████████████████████████████████████████████████████████████████| 83/83 [12:38<00:00,  9.14s/it]


Test | Loss:  0.0476 Acc:  98.51  F1-Score: 0.98 (41794/42426) Norm:0.99 Mot:0.99 AbNorm:0.97


100%|██████████████████████████████████████████████████████████████████████████████████| 21/21 [01:13<00:00,  3.51s/it]


Test | Loss:  0.0532 Acc:  98.13  F1-Score: 0.98 (5207/5306) Norm:0.99 Mot:0.99 AbNorm:0.96
Epoch 6


100%|██████████████████████████████████████████████████████████████████████████████████| 83/83 [12:39<00:00,  9.15s/it]


Test | Loss:  0.0413 Acc:  98.75  F1-Score: 0.98 (41897/42426) Norm:0.99 Mot:0.99 AbNorm:0.97


100%|██████████████████████████████████████████████████████████████████████████████████| 21/21 [01:13<00:00,  3.50s/it]


Test | Loss:  0.0483 Acc:  98.40  F1-Score: 0.98 (5221/5306) Norm:0.99 Mot:0.99 AbNorm:0.97
Epoch 7


100%|██████████████████████████████████████████████████████████████████████████████████| 83/83 [12:36<00:00,  9.12s/it]


Test | Loss:  0.0372 Acc:  98.91  F1-Score: 0.99 (41963/42426) Norm:0.99 Mot:0.99 AbNorm:0.98


100%|██████████████████████████████████████████████████████████████████████████████████| 21/21 [01:13<00:00,  3.51s/it]


Test | Loss:  0.0444 Acc:  98.55  F1-Score: 0.98 (5229/5306) Norm:0.99 Mot:0.99 AbNorm:0.97
Epoch 8


100%|██████████████████████████████████████████████████████████████████████████████████| 83/83 [12:38<00:00,  9.14s/it]


Test | Loss:  0.0333 Acc:  99.01  F1-Score: 0.99 (42008/42426) Norm:0.99 Mot:0.99 AbNorm:0.98


100%|██████████████████████████████████████████████████████████████████████████████████| 21/21 [01:13<00:00,  3.50s/it]


Test | Loss:  0.0413 Acc:  98.62  F1-Score: 0.98 (5233/5306) Norm:0.99 Mot:0.99 AbNorm:0.97
Epoch 9


100%|██████████████████████████████████████████████████████████████████████████████████| 83/83 [12:40<00:00,  9.16s/it]


Test | Loss:  0.0302 Acc:  99.12  F1-Score: 0.99 (42052/42426) Norm:0.99 Mot:1.00 AbNorm:0.98


100%|██████████████████████████████████████████████████████████████████████████████████| 21/21 [01:14<00:00,  3.56s/it]


Test | Loss:  0.0389 Acc:  98.74  F1-Score: 0.98 (5239/5306) Norm:0.99 Mot:0.99 AbNorm:0.98
Epoch 10


100%|██████████████████████████████████████████████████████████████████████████████████| 83/83 [12:46<00:00,  9.23s/it]


Test | Loss:  0.0269 Acc:  99.22  F1-Score: 0.99 (42093/42426) Norm:0.99 Mot:1.00 AbNorm:0.98


100%|██████████████████████████████████████████████████████████████████████████████████| 21/21 [01:13<00:00,  3.49s/it]


Test | Loss:  0.0368 Acc:  98.76  F1-Score: 0.98 (5240/5306) Norm:0.99 Mot:0.99 AbNorm:0.98
Epoch 11


100%|██████████████████████████████████████████████████████████████████████████████████| 83/83 [12:36<00:00,  9.11s/it]


Test | Loss:  0.0248 Acc:  99.28  F1-Score: 0.99 (42121/42426) Norm:0.99 Mot:1.00 AbNorm:0.98


100%|██████████████████████████████████████████████████████████████████████████████████| 21/21 [01:13<00:00,  3.50s/it]


Test | Loss:  0.0353 Acc:  98.79  F1-Score: 0.99 (5242/5306) Norm:0.99 Mot:0.99 AbNorm:0.98
Epoch 12


100%|██████████████████████████████████████████████████████████████████████████████████| 83/83 [12:37<00:00,  9.12s/it]


Test | Loss:  0.0224 Acc:  99.36  F1-Score: 0.99 (42155/42426) Norm:1.00 Mot:1.00 AbNorm:0.99


100%|██████████████████████████████████████████████████████████████████████████████████| 21/21 [01:13<00:00,  3.50s/it]


Test | Loss:  0.0334 Acc:  98.85  F1-Score: 0.99 (5245/5306) Norm:0.99 Mot:0.99 AbNorm:0.98
Epoch 13


100%|██████████████████████████████████████████████████████████████████████████████████| 83/83 [12:39<00:00,  9.15s/it]


Test | Loss:  0.0209 Acc:  99.41  F1-Score: 0.99 (42177/42426) Norm:1.00 Mot:1.00 AbNorm:0.99


100%|██████████████████████████████████████████████████████████████████████████████████| 21/21 [01:16<00:00,  3.62s/it]


Test | Loss:  0.0322 Acc:  98.98  F1-Score: 0.99 (5252/5306) Norm:0.99 Mot:0.99 AbNorm:0.98
Epoch 14


100%|██████████████████████████████████████████████████████████████████████████████████| 83/83 [12:47<00:00,  9.25s/it]


Test | Loss:  0.0195 Acc:  99.46  F1-Score: 0.99 (42196/42426) Norm:1.00 Mot:1.00 AbNorm:0.99


100%|██████████████████████████████████████████████████████████████████████████████████| 21/21 [01:14<00:00,  3.53s/it]


Test | Loss:  0.0309 Acc:  98.98  F1-Score: 0.99 (5252/5306) Norm:0.99 Mot:0.99 AbNorm:0.98
Epoch 15


100%|██████████████████████████████████████████████████████████████████████████████████| 83/83 [12:52<00:00,  9.30s/it]


Test | Loss:  0.0193 Acc:  99.50  F1-Score: 0.99 (42214/42426) Norm:1.00 Mot:1.00 AbNorm:0.99


100%|██████████████████████████████████████████████████████████████████████████████████| 21/21 [01:17<00:00,  3.68s/it]


Test | Loss:  0.0308 Acc:  99.02  F1-Score: 0.99 (5254/5306) Norm:0.99 Mot:0.99 AbNorm:0.98
Epoch 16


100%|██████████████████████████████████████████████████████████████████████████████████| 83/83 [13:27<00:00,  9.73s/it]


Test | Loss:  0.0182 Acc:  99.51  F1-Score: 0.99 (42218/42426) Norm:1.00 Mot:1.00 AbNorm:0.99


100%|██████████████████████████████████████████████████████████████████████████████████| 21/21 [01:16<00:00,  3.63s/it]


Test | Loss:  0.0299 Acc:  98.98  F1-Score: 0.99 (5252/5306) Norm:0.99 Mot:0.99 AbNorm:0.98
Epoch 17


100%|██████████████████████████████████████████████████████████████████████████████████| 83/83 [13:09<00:00,  9.51s/it]


Test | Loss:  0.0180 Acc:  99.54  F1-Score: 0.99 (42231/42426) Norm:1.00 Mot:1.00 AbNorm:0.99


100%|██████████████████████████████████████████████████████████████████████████████████| 21/21 [01:18<00:00,  3.73s/it]


Test | Loss:  0.0301 Acc:  98.91  F1-Score: 0.99 (5248/5306) Norm:0.99 Mot:0.99 AbNorm:0.98
Epoch 18


100%|██████████████████████████████████████████████████████████████████████████████████| 83/83 [13:26<00:00,  9.71s/it]


Test | Loss:  0.0180 Acc:  99.53  F1-Score: 0.99 (42227/42426) Norm:1.00 Mot:1.00 AbNorm:0.99


100%|██████████████████████████████████████████████████████████████████████████████████| 21/21 [01:17<00:00,  3.71s/it]


Test | Loss:  0.0300 Acc:  99.02  F1-Score: 0.99 (5254/5306) Norm:0.99 Mot:0.99 AbNorm:0.98
Epoch 19


100%|██████████████████████████████████████████████████████████████████████████████████| 83/83 [13:25<00:00,  9.71s/it]


Test | Loss:  0.0180 Acc:  99.53  F1-Score: 0.99 (42225/42426) Norm:1.00 Mot:1.00 AbNorm:0.99


100%|██████████████████████████████████████████████████████████████████████████████████| 21/21 [01:18<00:00,  3.74s/it]

Test | Loss:  0.0297 Acc:  99.02  F1-Score: 0.99 (5254/5306) Norm:0.99 Mot:0.99 AbNorm:0.98
**Learning time: 282m 7s



