# 문제

1. 얼리스탑핑 추가(3번이상 valid loss가 높으면 정지)

2. 스케쥴러 추가

3. paperswithcode 에서 cifar-10 데이터셋에서 좋은 성능을 보이는 모델을 timm에서 찾아 적용해보세요.

In [1]:
#!pip install albumentations timm

In [2]:
import timm

In [3]:
avail_pretrained_models = timm.list_models(pretrained=True)
len(avail_pretrained_models), avail_pretrained_models[:5]

(1298,
 ['bat_resnext26ts.ch_in1k',
  'beit_base_patch16_224.in22k_ft_in22k',
  'beit_base_patch16_224.in22k_ft_in22k_in1k',
  'beit_base_patch16_384.in22k_ft_in22k_in1k',
  'beit_large_patch16_224.in22k_ft_in22k'])

In [4]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models

# Define data augmentation transforms
train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load CIFAR-10 dataset
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)

# Reduce the total amount of data by half
trainset, _ = torch.utils.data.random_split(trainset, [len(trainset) // 4 * 3, len(trainset) // 4])

train_size = int(0.8 * len(trainset))
valid_size = len(trainset) - train_size
train_dataset, valid_dataset = torch.utils.data.random_split(trainset, [train_size, valid_size])

trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
validloader = torch.utils.data.DataLoader(valid_dataset, batch_size=32, shuffle=False, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False, num_workers=2)


Files already downloaded and verified
Files already downloaded and verified


In [5]:
from cifar10_models.vgg import vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn

# Untrained model
model = vgg11_bn()

# Pretrained model
model = vgg11_bn(pretrained=True)



# Load pre-trained ResNet-34 model
#model = models.resnet34(pretrained=False)

model.fc = nn.Linear(512, 10)

# loss function, optimizer 설정
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# gpu 설정
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

ModuleNotFoundError: No module named 'cifar10_models'

In [None]:
from tqdm import tqdm

In [None]:
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np

# Define scheduler
scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=2, verbose=True)

# Define early stopping variables
best_valid_loss = float('inf')
early_stopping_patience = 5
early_stopping_counter = 0

for epoch in range(20):  # Train for 20 epochs
    print(epoch + 1,'에포크 학습 시작')
    train_loss = 0.0
    valid_loss = 0.0
    model.train()
    print('training...')
    for img, label in tqdm(iter(trainloader)):
        inputs, labels = img.to(device), label.to(device)
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    model.eval()
    print('validating...')
    with torch.no_grad():
        for img, label in tqdm(iter(validloader)):
            inputs, labels = img.to(device), label.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            valid_loss += loss.item()

    # Adjust learning rate using scheduler
    scheduler.step(valid_loss / len(validloader))

    # Early stopping
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'best_model.pt')
        early_stopping_counter = 0
    else:
        early_stopping_counter += 1

    print(f'Epoch [{epoch + 1}/5], Train Loss: {train_loss / len(trainloader):.3f}, Valid Loss: {valid_loss / len(validloader):.3f}')
    print('\n')

    # Check if early stopping criteria are met
    if early_stopping_counter >= early_stopping_patience:
        
        print('Early stopping triggered. Training halted.')
        break


print('학습 완료')
print('\n')

# 모델 예측
correct = 0
total = 0
print('testing...')

with torch.no_grad():
    for img, label in tqdm(iter(testloader)):
        images, labels = img.to(device), label.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('test 데이터에 대한 정확도: %d %%' % (
    100 * correct / total))

In [None]:
print('test 데이터에 대한 정확도: %d %%' % (
    100 * correct / total))