<a href="https://colab.research.google.com/github/Yeongseok-Kim/PytorchClassReview/blob/master/gender_classificater.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import torch
import os
import torchvision
import torchvision.transforms as transforms
from torch.utils.data.dataset import Dataset
from PIL import Image



class MyDataset(Dataset):
    """Binary image dataset: female images get label 0, male images get label 1.

    Samples are indexed female-first: indices ``[0, len(female_list))`` map to
    files in ``female_dir`` and the remaining indices map to ``male_dir``.

    Each image is center-cropped to 128x128, resized to 64x64, converted to a
    tensor and normalized per channel with mean 0.5 / std 0.5.

    Args:
        female_dir: directory whose files are all loaded with label 0.0.
        male_dir: directory whose files are all loaded with label 1.0.
    """

    def __init__(self, female_dir, male_dir):
        # os.path.join works whether or not the directory has a trailing
        # separator; sorting makes the index -> file mapping deterministic
        # (os.listdir returns entries in arbitrary order).
        self.female_list = [os.path.join(female_dir, file_name)
                            for file_name in sorted(os.listdir(female_dir))]
        self.male_list = [os.path.join(male_dir, file_name)
                          for file_name in sorted(os.listdir(male_dir))]
        self.transforms = transforms.Compose([transforms.CenterCrop((128, 128)),
                                              transforms.Resize((64, 64)),
                                              transforms.ToTensor(),
                                              transforms.Normalize((.5, .5, .5), (.5, .5, .5))])

    def __len__(self):
        # Sum the lengths instead of concatenating the lists, which would
        # allocate a full copy on every call.
        return len(self.female_list) + len(self.male_list)

    def __getitem__(self, index):
        """Return ``(image_tensor, label)`` where ``label`` is a float tensor of shape (1,)."""
        female_len = len(self.female_list)
        if index < female_len:
            path = self.female_list[index]
            label = torch.zeros(1)
        else:
            path = self.male_list[index - female_len]
            label = torch.ones(1)
        # Close the file handle deterministically. Force 3 channels so that
        # grayscale / RGBA / palette images match the 3-channel Normalize.
        # convert() returns a fully loaded copy that outlives the `with`.
        with Image.open(path) as img:
            img = img.convert('RGB')
        img = self.transforms(img)
        return img, label

In [0]:
import torch.nn as nn



class Classificater(nn.Module):
    """Small CNN that maps a 3-channel image to a single probability.

    Five identical stages -- 3x3 conv (stride 1, padding 1), ReLU, 2x2 max-pool --
    take the channel count 3 -> 64 -> 128 -> 256 -> 512 -> 1024 while halving the
    spatial size each time. The classifier head is hard-wired to 1024 * 2 * 2
    input features, so the input must shrink to 2x2 after five poolings
    (e.g. 64x64, the size MyDataset produces). Output has shape (batch, 1),
    squashed by a sigmoid.
    """

    def __init__(self):
        super().__init__()
        widths = [3, 64, 128, 256, 512, 1024]

        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages += [nn.Conv2d(c_in, c_out, 3, 1, 1),
                       nn.ReLU(),
                       nn.MaxPool2d(2, 2)]
        # Same module order as an explicit listing, so state_dict keys stay
        # layer.0, layer.3, ..., layer.12 and saved checkpoints remain loadable.
        self.layer = nn.Sequential(*stages)

        flat_features = widths[-1] * 2 * 2
        self.fc = nn.Sequential(nn.Linear(flat_features, 1), nn.Sigmoid())

    def forward(self, x):
        features = self.layer(x)
        flattened = features.view(features.shape[0], -1)
        return self.fc(flattened)

In [0]:
# Colab-only: mount Google Drive at /content/drive so the dataset archive
# '/content/drive/My Drive/gender_classification.zip' can be read.
from google.colab import drive
drive.mount('/content/drive')

In [0]:
# IPython shell escape (not plain Python): extract the dataset archive into the
# current working directory -- presumably /content, where the training cell
# expects train_set/ and test_set/; confirm the archive layout.
!unzip '/content/drive/My Drive/gender_classification.zip'

In [0]:
import torch.optim as optim
from torch.utils.data import DataLoader

if __name__ == '__main__':
    # Train Classificater on the extracted train_set/, periodically reporting
    # test-set loss/accuracy and saving a checkpoint.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    batch_size = 64
    train_data = MyDataset('/content/train_set/female/', '/content/train_set/male/')
    train_set = DataLoader(train_data, batch_size, shuffle=True)

    test_batch = 64
    test_data = MyDataset('/content/test_set/female/', '/content/test_set/male/')
    # Evaluation metrics do not depend on order, so there is no need to shuffle.
    test_set = DataLoader(test_data, test_batch, shuffle=False)

    learning_rate = 0.0002
    training_epochs = 50
    eval_interval = 10  # evaluate on the test set and checkpoint every N epochs

    net = Classificater().to(device)
    # The network already ends in Sigmoid, so plain BCELoss on probabilities.
    criterion = nn.BCELoss().to(device)
    optimizer = optim.Adam(net.parameters(), learning_rate)

    print('Learning started. it takes sometime.')

    for epoch in range(training_epochs):
        net.train()
        for img, label in train_set:
            img = img.to(device)
            label = label.to(device)

            optimizer.zero_grad()
            hypothesis = net(img)
            cost = criterion(hypothesis, label)

            cost.backward()
            optimizer.step()

            print("train cost : %f" % cost.item())

        print('***** epoch %d is over *****' % epoch)

        # Parenthesised on purpose: `epoch + 1 % 10 == 0` parses as
        # `epoch + (1 % 10) == 0`, which is never true, so the evaluation and
        # checkpointing below previously never ran.
        if (epoch + 1) % eval_interval == 0:
            net.eval()
            total_cost = 0.0
            correct = 0
            seen = 0
            with torch.no_grad():
                for img, label in test_set:
                    img = img.to(device)
                    label = label.to(device)

                    prediction = net(img)
                    cost = criterion(prediction, label)

                    print("test cost : %f" % cost.item())

                    # Weight by batch size so a short final batch doesn't skew the mean.
                    n_samples = img.size(0)
                    total_cost += cost.item() * n_samples
                    # Threshold probabilities at 0.5: 0 = female, 1 = male.
                    correct += ((prediction > 0.5).float() == label).sum().item()
                    seen += n_samples

            if seen:
                print("test mean cost : %f, accuracy : %.4f"
                      % (total_cost / seen, correct / seen))

            torch.save(net.state_dict(), './gender_classificater_epoch_%d.pth' % epoch)

    print('Learning Finished.')