In [1]:
import glob
import os.path as osp
import random
import numpy as np
import json
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
from torchvision import models, transforms
from torchvision.datasets import ImageFolder
from torchvision.models import resnet18

# Fix the starting state of every RNG the notebook may draw from (Python
# `random`, NumPy, PyTorch) so torch-driven randomness such as new-layer
# initialization and DataLoader shuffling is repeatable across restarts.
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
datapath = "./data/"
resize = 224
# ImageNet channel statistics, matching the pretrained backbone's training data.
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)


def _make_transform(*augmentations):
    """Build a pipeline: square resize -> optional augmentations -> tensor -> normalize."""
    return transforms.Compose([
        # Squash every image to resize x resize first (aspect ratio is not kept).
        transforms.Resize((resize, resize)),
        *augmentations,
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])


# Training images get light geometric augmentation (a random crop covering
# 80-100% of the area, plus horizontal and vertical flips); test images are
# only resized and normalized so evaluation is deterministic.
train_dataset = ImageFolder(
    datapath + "train",
    transform=_make_transform(
        transforms.RandomResizedCrop(resize, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
    ),
)
test_dataset = ImageFolder(datapath + "test", transform=_make_transform())

In [4]:
batch_size = 64
# Settings shared by both splits; only shuffling differs (train is shuffled
# every epoch, test keeps a fixed order).
loader_kwargs = dict(batch_size=batch_size, num_workers=8)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

In [5]:
# ImageNet-pretrained ResNet-18 used as a fixed feature extractor.
# NOTE(review): `pretrained=True` is deprecated in torchvision >= 0.13 in favour
# of `weights=...`; it is kept here for compatibility with older versions.
net = resnet18(pretrained=True)

# Freeze every pretrained parameter so autograd neither computes nor stores
# gradients for the backbone. (BatchNorm running statistics are buffers, not
# parameters, so this flag does not stop them updating in train mode.)
for p in net.parameters():
    p.requires_grad = False

# Replace the ImageNet classifier with a fresh head. The new nn.Linear is
# created after the freeze loop, so its parameters keep the default
# requires_grad=True and are the only trainable weights in the model.
# NOTE(review): num_classes is assumed to equal len(train_dataset.classes).
num_classes = 2
fc_input_dim = net.fc.in_features
net.fc = nn.Linear(fc_input_dim, num_classes)
# print(net)

In [8]:
def eval_net(net, data_loader, device="cpu"):
    """Return the top-1 accuracy of `net` over every sample in `data_loader`.

    The network is switched to eval mode (and left there) and run without
    gradient tracking. Predicted labels are the arg-max of the outputs along
    dim 1.

    Args:
        net: model to evaluate; assumed to return (batch, num_classes) scores.
        data_loader: iterable of `(inputs, labels)` batches.
        device: device each batch is moved to before the forward pass.

    Returns:
        Fraction of correctly classified samples, as a Python float.

    Raises:
        ValueError: if `data_loader` yields no samples (accuracy is undefined).
    """
    net.eval()
    n_correct = 0
    n_total = 0
    with torch.no_grad():
        for x, y in data_loader:
            x = x.to(device)
            y = y.to(device)
            y_pred = net(x).argmax(1)
            # Count matches per batch instead of keeping every label/prediction
            # tensor alive just to concatenate them at the end.
            n_correct += (y == y_pred).sum().item()
            n_total += y.size(0)
    if n_total == 0:
        raise ValueError("eval_net: data_loader yielded no samples")
    return n_correct / n_total

# Per-epoch history filled in by train_net: mean training loss, training
# accuracy and validation accuracy (one entry appended to each per epoch).
train_losses, train_acc, val_acc = [], [], []

def train_net(net, train_loader, test_loader, only_fc=True,
              optimizer_cls=optim.Adam, loss_fn=nn.CrossEntropyLoss(),
              num_epoch=10, device="cpu"):
    """Train `net` for `num_epoch` epochs, logging metrics after each epoch.

    After every epoch the mean training loss, the training accuracy and the
    validation accuracy (from `eval_net` on `test_loader`) are appended to the
    module-level lists `train_losses`, `train_acc` and `val_acc` and printed.

    Args:
        net: model to train; its outputs go to `loss_fn` and are arg-maxed
            along dim 1 to obtain predicted labels. This function does not
            move `net`; callers must place it on `device` beforehand.
        train_loader: iterable of `(inputs, labels)` batches used for training.
        test_loader: loader passed to `eval_net` after each epoch.
        only_fc: if True, only `net.fc`'s parameters are given to the
            optimizer; otherwise all of `net.parameters()` are.
        optimizer_cls: optimizer class, constructed with its default
            hyperparameters.
        loss_fn: loss callable `(outputs, labels) -> scalar tensor`. The
            default instance is created once, when this function is defined.
        num_epoch: number of passes over `train_loader`.
        device: device each batch is moved to.

    Raises:
        ValueError: if `train_loader` yields no samples in an epoch.
    """
    print("使用デバイス:", device)  # "Device in use:"
    params = net.fc.parameters() if only_fc else net.parameters()
    optimizer = optimizer_cls(params)
    for epoch in range(num_epoch):
        # NOTE(review): net.train() also puts BatchNorm layers of a frozen
        # backbone into training mode, so their running statistics keep
        # updating even when only_fc=True.
        net.train()
        running_loss = 0.0
        n = 0
        n_acc = 0
        for xx, yy in tqdm(train_loader):
            xx = xx.to(device)
            yy = yy.to(device)
            h = net(xx)
            loss = loss_fn(h, yy)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            bsz = xx.size(0)
            # Assumes loss_fn returns the batch mean (CrossEntropyLoss default);
            # re-weight by batch size so the epoch average is per-sample.
            running_loss += loss.item() * bsz
            n += bsz
            n_acc += (yy == h.argmax(1)).sum().item()
        if n == 0:
            raise ValueError("train_net: train_loader yielded no samples")
        # Normalize by the samples actually seen (not len(dataset)) so the loss
        # stays correct with drop_last=True or subset samplers, and is averaged
        # over the same population as the accuracy.
        train_losses.append(running_loss / n)
        train_acc.append(n_acc / n)
        val_acc.append(eval_net(net, test_loader, device))
        print("epoch: {}\ttrain_loss: {:.3f}\ttrain_acc: {:.3f}\tval_acc: {:.3f}".format(
             epoch+1, train_losses[-1], train_acc[-1], val_acc[-1]), flush=True)

In [7]:
# Prefer the first GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# The model must be on `device` before train_net builds its optimizer and
# moves each batch there.
net = net.to(device)
train_net(net, train_loader, test_loader, device=device)

使用デバイス: cpu


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:18<00:00, 18.63s/it]


epoch: 0	train_loss: 0.677	train_acc: 0.532	val_acc: 0.667


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:20<00:00, 20.56s/it]


epoch: 1	train_loss: 0.639	train_acc: 0.660	val_acc: 0.833


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:23<00:00, 23.39s/it]


epoch: 2	train_loss: 0.639	train_acc: 0.702	val_acc: 0.833


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:26<00:00, 26.49s/it]


epoch: 3	train_loss: 0.605	train_acc: 0.766	val_acc: 0.833


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:26<00:00, 26.19s/it]


epoch: 4	train_loss: 0.518	train_acc: 0.787	val_acc: 0.833


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:24<00:00, 24.23s/it]


epoch: 5	train_loss: 0.536	train_acc: 0.702	val_acc: 0.833


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:26<00:00, 26.67s/it]


epoch: 6	train_loss: 0.503	train_acc: 0.702	val_acc: 0.833


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:24<00:00, 24.22s/it]


epoch: 7	train_loss: 0.436	train_acc: 0.809	val_acc: 0.833


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:23<00:00, 23.99s/it]


epoch: 8	train_loss: 0.456	train_acc: 0.787	val_acc: 0.833


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:27<00:00, 27.78s/it]


epoch: 9	train_loss: 0.414	train_acc: 0.830	val_acc: 0.833
