In [1]:
import os
import numpy as np

import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
from torch.utils.data import DataLoader

from torchvision import datasets, models, transforms

In [2]:
import warnings
warnings.filterwarnings("ignore")

In [3]:
def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic=True
    torch.backends.cudnn.benchmark=False
set_seed(42)

In [4]:
from source import cifar_dataloader
path = './../data/'
loader_dict = cifar_dataloader.get_cifar100_dataloader(path = path, batch_size=128)
train_loader = loader_dict['train']
val_loader = loader_dict['val']

Files already downloaded and verified
Files already downloaded and verified


# 일반 resenet50을 학습
- loss : Cross-Entropy

In [5]:
model = models.resnet18(pretrained=False)

In [6]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
criterion = nn.CrossEntropyLoss()
epochs = 500

model = model.to(device)

In [7]:
# early stopping
from source import EarlyStopping
from source import functions

early_stopping = EarlyStopping.EarlyStopping(patience=20, verbose=False)
model_ft, hist = functions.train_model(
    model = model, 
    dataloaders = loader_dict, 
    criterion = criterion, 
    optimizer = optimizer, 
    device=device, 
    num_epochs=epochs, 
    early_stop = early_stopping
    )
    

Epoch 0/499
----------


train Loss: 4.7830 Acc: 0.0601
val Loss: 4.0622 Acc: 0.0955

Epoch 1/499
----------
train Loss: 3.9684 Acc: 0.1043
val Loss: 3.8294 Acc: 0.1263

Epoch 2/499
----------
train Loss: 3.8180 Acc: 0.1242
val Loss: 3.6893 Acc: 0.1419

Epoch 3/499
----------
train Loss: 3.7269 Acc: 0.1354
val Loss: 3.6308 Acc: 0.1497

Epoch 4/499
----------
train Loss: 3.6545 Acc: 0.1489
val Loss: 3.5383 Acc: 0.1662

Epoch 5/499
----------
train Loss: 3.5927 Acc: 0.1593
val Loss: 3.4460 Acc: 0.1866

Epoch 6/499
----------
train Loss: 3.5335 Acc: 0.1716
val Loss: 3.3843 Acc: 0.2031

Epoch 7/499
----------
train Loss: 3.4802 Acc: 0.1791
val Loss: 3.3208 Acc: 0.2105

Epoch 8/499
----------
train Loss: 3.4183 Acc: 0.1907
val Loss: 3.2849 Acc: 0.2138

Epoch 9/499
----------
train Loss: 3.3679 Acc: 0.1992
val Loss: 3.2079 Acc: 0.2281

Epoch 10/499
----------
train Loss: 3.3331 Acc: 0.2052
val Loss: 3.1790 Acc: 0.2301

Epoch 11/499
----------
train Loss: 3.2742 Acc: 0.2149
val Loss: 3.1108 Acc: 0.2446

Epoch 12/499


KeyboardInterrupt: 