In [1]:
# 2018310064 문학준
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torchvision import models
from torch.utils.data import random_split

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Hyperparameters
batch_size = 256
learning_rate = 0.001
num_epoch = 110

In [3]:
transform_train = transforms.Compose([transforms.Resize((32,32)),
                                transforms.RandomHorizontalFlip(),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5,),(0.5,))])

In [4]:
transform_test = transforms.Compose([transforms.Resize((32,32)),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5,),(0.5,))])

In [5]:
download_path = './CIFAR'
dataset = dset.CIFAR10(download_path, transform=transform_train, train = True, download = True)
test_dataset = dset.CIFAR10(download_path, transform=transform_test, train = False, download = True)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
train_dataset, val_dataset = random_split(dataset, [40000, 10000])

In [7]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [8]:
# transfer learning : resnet50
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = models.resnet50(weights="IMAGENET1K_V2").to(device)

In [9]:
loss_funcion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100, eta_min=0.000001)

In [10]:
# train
start_time = time.time()

len_val_loader = len(val_loader)
len_val_dataset = len(val_dataset)

for i in range(num_epoch):
    
    val_loss = 0.0
    val_num_corrects = 0.0
    
    
    for image, label in train_loader:
        x = image.to(device)
        y_ = label.to(device)

        optimizer.zero_grad()
        output = model.forward(x)
        loss = loss_funcion(output, y_)
        loss.backward()
        optimizer.step()
        
    scheduler.step()
    
    with torch.no_grad():
        for val_inputs, val_labels in val_loader:
            val_inputs = val_inputs.to(device)
            val_labels = val_labels.to(device)
            val_outputs = model(val_inputs)
            loss = loss_funcion(val_outputs, val_labels)

            _, val_preds = torch.max(val_outputs, 1)
            
            val_loss += loss.item()
            val_num_corrects += torch.sum(val_preds == val_labels.data)
            
        
    train_loss = loss.cpu().detach().numpy()
    
    val_epoch_loss = val_loss / len_val_loader
    val_epoch_acc = 100 * val_num_corrects.float() / len_val_dataset
    
    if (i % 2 == 0):
        print(f"epoch : {i}, train_loss : {train_loss:.4f}, val_loss : {val_epoch_loss:.4f}, val_acc : {val_epoch_acc:.2f} %, lr : {optimizer.param_groups[0]['lr']:.6f}, elapsed time : {time.time() - start_time:.2f} sec")

epoch : 0, train_loss : 0.8608, val_loss : 0.6532, val_acc : 78.97 %, lr : 0.001000, elapsed time : 27.04 sec
epoch : 2, train_loss : 0.7138, val_loss : 0.4762, val_acc : 84.20 %, lr : 0.000998, elapsed time : 78.02 sec
epoch : 4, train_loss : 0.4998, val_loss : 0.4478, val_acc : 86.20 %, lr : 0.000994, elapsed time : 128.96 sec
epoch : 6, train_loss : 0.2029, val_loss : 0.4517, val_acc : 85.79 %, lr : 0.000988, elapsed time : 179.60 sec
epoch : 8, train_loss : 0.3949, val_loss : 0.4872, val_acc : 85.36 %, lr : 0.000980, elapsed time : 230.29 sec
epoch : 10, train_loss : 0.2665, val_loss : 0.4590, val_acc : 85.99 %, lr : 0.000970, elapsed time : 281.08 sec
epoch : 12, train_loss : 0.8856, val_loss : 0.4913, val_acc : 85.72 %, lr : 0.000959, elapsed time : 331.78 sec
epoch : 14, train_loss : 0.9420, val_loss : 0.5013, val_acc : 85.66 %, lr : 0.000946, elapsed time : 382.32 sec
epoch : 16, train_loss : 0.3197, val_loss : 0.4814, val_acc : 86.09 %, lr : 0.000930, elapsed time : 432.77 sec

In [11]:
# test
correct = 0
total = 0

with torch.no_grad():
    for image, label in test_loader:
        x = image.to(device)
        y_ = label.to(device)

        output = model.forward(x)
        _,output_index = torch.max(output,1)
        total += label.size(0)
        correct += (output_index == y_).sum().float()

    print("Accuracy : {}", format(100*correct/total))

Accuracy : {} 87.7699966430664


In [12]:
path = './CIFAR/model_resnet50.pt'
torch.save(model.state_dict(), path)