### CNN 해보기

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision import transforms
from torchvision.datasets import CIFAR10

In [None]:
# Use the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Both splits use the same preprocessing: convert PIL images to tensors, then
# shift/scale each of the three channels with mean 0.5 and std 0.5.
# Two separate Compose objects are built so the splits stay independent.
transform_train = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)
transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Only the training split requests a download; the test split is read from the
# same "./data" root afterwards.
train_dataset = CIFAR10(
    root="./data",
    train=True,
    transform=transform_train,
    download=True,
)
test_dataset = CIFAR10(
    root="./data",
    train=False,
    transform=transform_test,
    download=False,
)

# Test batches are served in a fixed order.
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

In [None]:
# Hold out 20% of the training set for validation, stratified by class label so
# both parts keep the same class proportions.
indices = np.arange(len(train_dataset.targets))
train_idx, valid_idx = train_test_split(
    indices,
    test_size=0.2,
    shuffle=True,
    stratify=train_dataset.targets,
)

# Both loaders read from the same underlying dataset; each sampler restricts
# its loader to one disjoint index subset, visited in random order.
train_sampler, valid_sampler = SubsetRandomSampler(train_idx), SubsetRandomSampler(valid_idx)

train_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=128)
valid_loader = DataLoader(train_dataset, sampler=valid_sampler, batch_size=128)

In [None]:
class CNN(nn.Module):
    """Four-stage conv/batch-norm network followed by global average pooling.

    Input: a batch of 3-channel images ``(N, 3, H, W)``. Channel widths grow
    3 -> 64 -> 128 -> 256 -> 512. Each ``max_pool2d(..., 2)`` after stages 1
    and 3 halves the spatial size. Global average pooling then reduces the
    512 feature maps to one value each, so ``fc1`` maps a 512-vector to 10
    logits.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(128)
        self.bn3 = nn.BatchNorm2d(256)
        self.bn4 = nn.BatchNorm2d(512)
        self.fc1 = nn.Linear(512, 10)

    def forward(self, x):
        """Return unnormalised class logits of shape ``(N, 10)``."""
        x = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), 2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(F.relu(self.bn3(self.conv3(x))), 2)
        # ReLU here is required: without it, conv4 -> bn4 -> avg-pool -> fc1 is a
        # composition of affine maps, so the last conv stage adds no
        # nonlinearity. ReLU has no parameters, so saved state_dicts still load.
        x = F.relu(self.bn4(self.conv4(x)))
        x = F.adaptive_avg_pool2d(x, 1)
        x = torch.flatten(x, 1)
        return self.fc1(x)

In [None]:
def train(model, dataloader, optimizer, device):
    """Run one training epoch and return the mean per-sample loss.

    Args:
        model: network to optimise; it is put into training mode here.
        dataloader: iterable of ``(data, target)`` batches.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        device: device each batch is moved to before the forward pass.

    Returns:
        The average cross-entropy loss over every sample seen this epoch, as a
        plain Python float. Returns 0.0 if ``dataloader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_samples = 0
    for data, target in dataloader:
        data, target = data.to(device), target.to(device)
        # Clear gradients *before* backward so gradients left over from earlier
        # code never leak into this step.
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
        # .item() extracts a plain number. Accumulating the loss tensor itself
        # would chain every batch's autograd graph together and return a
        # graph-tracking tensor. Weighting by batch size keeps a short final
        # batch from skewing the mean; `eval` averages per sample the same way.
        batch_size = target.size(0)
        total_loss += loss.item() * batch_size
        num_samples += batch_size
    if num_samples == 0:
        return 0.0
    return total_loss / num_samples

In [None]:
def eval(model, dataloader, device):
    """Score ``model`` on ``dataloader`` with gradient tracking disabled.

    The name shadows Python's builtin ``eval`` inside this module; it is kept
    for existing callers.

    Args:
        model: network to evaluate; it is put into evaluation mode here.
        dataloader: iterable of ``(data, target)`` batches.
        device: device each batch is moved to before the forward pass.

    Returns:
        A tuple ``(mean per-sample cross-entropy loss, accuracy in percent)``.
    """
    model.eval()
    loss_sum, hits, seen = 0, 0, 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            logits = model(inputs)
            # Sum (not mean) per batch so the final division gives a true
            # per-sample average even when batch sizes differ.
            loss_sum += F.cross_entropy(logits, labels, reduction="sum").item()
            predicted = logits.max(dim=1).indices
            hits += (predicted == labels).sum().item()
            seen += len(labels)
    return loss_sum / seen, 100 * hits / seen

In [None]:
model = CNN().to(device)

# Adam's documented default learning rate is 1e-3. The previous lr=0.1 was 100x
# larger, which tends to make Adam's updates unstable. weight_decay adds an
# L2 penalty to the gradients.
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.0005)

In [None]:
max_epochs = 100   # upper bound on epochs; early stopping below may end sooner
count = 0          # consecutive epochs without a new best validation accuracy
best_acc = 0.0     # best validation accuracy (percent) seen so far
patience = 3       # stop after this many epochs without improvement
# Minimum gain, in percentage points, that counts as an improvement. valid_acc
# is a percentage, so one validation sample is worth 100 / N points. For any
# validation set with fewer than 100,000 samples, 0.001 is below that step and
# effectively means "strictly better".
epsilon = 0.001

train_losses = []
valid_losses = []
valid_accs = []

for i in range(max_epochs):
    # float() works whether train() returns a float or a 0-dim tensor. Storing a
    # plain number avoids keeping device tensors (and any autograd graph) alive
    # in the history list.
    train_loss = float(train(model, train_loader, optimizer, device))
    train_losses.append(train_loss)

    valid_loss, valid_acc = eval(model, valid_loader, device)
    valid_losses.append(valid_loss)
    valid_accs.append(valid_acc)

    print(f"[{i+1}] Train Loss : {train_loss:.4f}")
    print(f"        Valid > Valid Loss : {valid_loss:.4f},  Valid Acc : {valid_acc:.2f}%")

    if valid_acc > best_acc + epsilon:
        print(f"  !!New best performance!! {valid_acc:.2f}% <- {best_acc:.2f}%")
        best_acc = valid_acc
        count = 0
        # Checkpoint only the best model; it is reloaded after training.
        torch.save(model.state_dict(), "./model.pt")
    else:
        count += 1
        print(f"  > Stop count : {count} / {patience}")
        if count >= patience:
            break

In [None]:
# Restore the best checkpoint saved by the training loop. map_location places
# the tensors on the active `device`, so a checkpoint written on a GPU can also
# be loaded on a CPU-only machine.
state_dict = torch.load("./model.pt", map_location=device)
model.load_state_dict(state_dict)