In [15]:
import numpy as np
import cv2
import os
import random
import time
import torch
from torch.utils.data.dataset import Dataset
import torch.nn as nn
import torch.nn.functional as F
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")



In [2]:
class ImageDataset(Dataset):
    def __init__(self, imgs, labels):
        super().__init__()
        self.imgList = []
        self.labelList = []
        for i in range(len(imgs)):
            img = cv2.resize(imgs[i], (112,112))/255.0
            img = img.astype(np.float32)
            imgTensor = torch.from_numpy(img.transpose((2, 0, 1)))
            self.imgList.append(imgTensor)
            self.labelList.append(labels[i])

    def __getitem__(self, index):
        return self.imgList[index], self.labelList[index]
    def __len__(self):
        return len(self.labelList)

In [3]:
class CNNBlock(nn.Module):
    def __init__(self, in_channels, out_channel, k_size, pad=1, s=1, dilation=1):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channel, kernel_size=k_size, padding=pad, stride = s, dilation=dilation)
        self.batchNorm = nn.BatchNorm2d(out_channel)
        self.actfunction = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.batchNorm(x)
        x = self.actfunction(x)
        return x

class Fullyconnect(nn.Module):
    def __init__(self, in_channels, out_channel):
        super().__init__()
        self.fc = nn.Linear(in_channels, out_channel)
        self.actfunction = nn.Sigmoid()

    def forward(self, x):
        x = self.fc(x)
        x = self.actfunction(x)
        return x


In [4]:
class Architecture(nn.Module):
    def __init__(self, numclass = 10):
        super().__init__()
        self.convs = nn.Sequential(  CNNBlock(in_channels = 3, out_channel = 32, k_size = 3),
                                    nn.MaxPool2d(kernel_size = 2,stride= 2),
                                    CNNBlock(in_channels =32, out_channel =64, k_size =3),
                                    nn.MaxPool2d(kernel_size = 2,stride= 2),
                                    CNNBlock(in_channels =64, out_channel =128, k_size =3),
                                    nn.AdaptiveAvgPool2d((3,3)),)
        self.fc = Fullyconnect(3*3*128, numclass)
    def forward(self, x):
        x = self.convs(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return F.softmax(x, dim=1)


In [5]:
class ce_loss(nn.Module):
    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def forward(self, output, target):
        return self.loss(output, target.long())


In [10]:
imgPath = './train/train/'
label_dict = {'Bald': 0, 'Black_Hair': 1, 'Blond_Hair': 2,  
            'Brown_Hair': 3,'Gray_Hair': 4, 'Receding_Hairline': 5}

imgPathList = os.listdir(imgPath)

allImgList = []
allLabelList = []
randomList = []

trainImgList = []
trainLabelList = []

valImgList = []
valLabelList = []

for f in imgPathList:
    for imgName in os.listdir( imgPath + f ):
        img = cv2.imread(imgPath + f + '/' + imgName)
        allImgList.append(img)
        allLabelList.append(label_dict[f])

while(1):
    rint = random.randint(0,len(allImgList))
    if rint not in randomList:
        randomList.append(rint)
    if len(randomList)> len(allImgList)*0.2:
        break

for i in range(len(allImgList)):
    if i in randomList:
        valImgList.append(allImgList[i])
        valLabelList.append(allLabelList[i])
    else:
        trainImgList.append(allImgList[i])
        trainLabelList.append(allLabelList[i])

trainDataset = ImageDataset(trainImgList, trainLabelList)
valDataset = ImageDataset(valImgList, valLabelList)

dataLoaderTrain = torch.utils.data.DataLoader(  trainDataset,
                                                batch_size = 64,
                                                shuffle = True,
                                                num_workers = 0,
                                                drop_last = True)

dataLoaderVal = torch.utils.data.DataLoader(valDataset,
                                            batch_size = 64,
                                            shuffle = True,
                                            num_workers = 0,
                                            drop_last = True)


In [11]:
net = Architecture(numclass = 6).to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=0.0001)
ceLoss = ce_loss().to(device)

In [17]:
# begin train
epochs = 100
torch.backends.cudnn.benchmark = True
for i in range(epochs):
    net.train()
    beginTime = time.time()
    for trainImage, trainLabel in dataLoaderTrain:
        optimizer.zero_grad()
        trainImage = trainImage.to(device)
        trainLabel = trainLabel.to(device, dtype=torch.long)
        output = net(trainImage)
        loss = ceLoss(output, trainLabel)
        loss.backward()
        optimizer.step()

    net.eval()
    for valImage, valLabel in dataLoaderVal:
        with torch.no_grad():
            valImage = valImage.to(device)
            valLabel = valLabel.to(device, dtype=torch.long)
            output = net(valImage)
            loss = ceLoss(output, trainLabel)
    endTime = time.time()

    print(i , "epoch time : ", (endTime - beginTime))
    torch.save(net, "./model/hair.pth")

0 epoch time :  1.8037354946136475
1 epoch time :  1.1136300563812256
2 epoch time :  0.8131828308105469
3 epoch time :  0.839240312576294
4 epoch time :  0.8883223533630371
5 epoch time :  0.7839264869689941
6 epoch time :  0.8766829967498779
7 epoch time :  0.821087121963501
8 epoch time :  0.7821059226989746
9 epoch time :  0.8038580417633057
10 epoch time :  0.8471462726593018
11 epoch time :  0.9104475975036621
12 epoch time :  0.777777910232544
13 epoch time :  0.8453929424285889
14 epoch time :  1.0638651847839355
15 epoch time :  0.8559796810150146
16 epoch time :  0.9814233779907227
17 epoch time :  0.8490517139434814
18 epoch time :  0.8708357810974121
19 epoch time :  0.7762899398803711
20 epoch time :  0.8401312828063965
21 epoch time :  0.9246985912322998
22 epoch time :  0.9828791618347168
23 epoch time :  0.909233570098877
24 epoch time :  0.8909540176391602
25 epoch time :  1.3194572925567627
26 epoch time :  0.8496527671813965
27 epoch time :  0.9592077732086182
28 epo