# Base Trainer: ResNet-18 on the full CIFAR-10 training set

In [None]:
import os
from utils import *
from agents import *
import time
import torch
import torch.nn as nn
import copy
import torch.nn.functional as F
from copy import deepcopy
import argparse
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from models import resnet18
from agents.adv import FGSM
import random
import math
from ov_utils import *

# Seed everything before any model is built so weight init and data order are reproducible.
seed_everything(42)

os.makedirs('checkpoints_cifar10', exist_ok=True)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Per-channel normalisation statistics shared by both pipelines.
# (These are the ImageNet mean/std values, not statistics computed on CIFAR-10.)
_norm_mean = [0.485, 0.456, 0.406]
_norm_std = [0.229, 0.224, 0.225]

# Training pipeline: pad-and-random-crop plus horizontal flip for augmentation.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=_norm_mean, std=_norm_std),
])

# Evaluation pipeline: no augmentation, same normalisation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=_norm_mean, std=_norm_std),
])

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train,
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=4,
)
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test,
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=4,
)

model = resnet18(num_classes=10).to(device)

num_epochs = 200
criterion = nn.CrossEntropyLoss()
# NOTE(review): plain SGD here (momentum defaults to 0). The common CIFAR ResNet
# recipe with lr=0.1 also uses momentum=0.9 -- confirm the omission is intended.
optimizer = optim.SGD(model.parameters(), lr=0.1, weight_decay=5e-4)
# Multiply the learning rate by 0.2 at epochs 60, 120 and 160.
scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[60, 120, 160], gamma=0.2,
)

best_acc = 0.0  # best test accuracy seen so far; updated by test()

def train(epoch):
    """Run one optimisation pass of the global ``model`` over ``trainloader``.

    Uses the module-level ``optimizer`` and ``criterion``. Every 100 batches,
    prints the running mean loss and the running top-1 accuracy for this epoch.

    Args:
        epoch: Epoch index; used only in the progress message.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    n_batches = len(trainloader)
    for step, (images, labels) in enumerate(trainloader):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(images)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()
        predictions = logits.max(dim=1).indices
        n_seen += labels.size(0)
        n_correct += (predictions == labels).sum().item()

        if step % 100 != 0:
            continue
        mean_loss = loss_sum / (step + 1)
        running_acc = 100. * n_correct / n_seen
        print(f'Epoch: {epoch} | Batch: {step}/{n_batches} | Loss: {mean_loss:.3f} | Acc: {running_acc:.3f}% ({n_correct}/{n_seen})')

def test(epoch):
    """Evaluate the global ``model`` on ``testloader`` and checkpoint improvements.

    Gradients are disabled during evaluation. When the accuracy beats the
    module-level ``best_acc``, ``best_acc`` is updated and the model's
    ``state_dict`` is written to
    ``checkpoints_cifar10/resnet18_cifar10_best.pth``.

    Args:
        epoch: Epoch index; used only in the printed summary.

    Returns:
        Top-1 test accuracy, in percent.
    """
    global best_acc
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            logits = model(images)
            loss_sum += criterion(logits, labels).item()
            n_correct += logits.max(1)[1].eq(labels).sum().item()
            n_seen += labels.size(0)

    acc = 100. * n_correct / n_seen
    mean_loss = loss_sum / len(testloader)
    print(f'Test Epoch: {epoch} | Loss: {mean_loss:.3f} | Acc: {acc:.3f}% ({n_correct}/{n_seen})')

    if acc <= best_acc:
        return acc

    # Strict improvement: record it and overwrite the best-so-far checkpoint.
    print(f'New best accuracy: {acc:.3f}% (previous best: {best_acc:.3f}%), saving the model...')
    best_acc = acc
    torch.save(model.state_dict(), 'checkpoints_cifar10/resnet18_cifar10_best.pth')
    return acc
    
# Main schedule: train one epoch, evaluate (test() saves the best-so-far
# checkpoint), then advance the epoch-based LR scheduler.
for epoch in range(num_epochs):
    train(epoch)
    test(epoch)
    scheduler.step()

# Also keep the last-epoch weights, independent of test accuracy.
torch.save(
    model.state_dict(),
    'checkpoints_cifar10/resnet18_cifar10_final.pth',
)

# Exclude Trainer: retrain ResNet-18 on CIFAR-10 with the excluded class(es) removed from the training set

In [None]:
# How many of the 10 CIFAR-10 classes to hold out of the retraining set.
num_exclude = 1
all_classes = list(range(10))
excluded_classes = random.sample(all_classes, num_exclude)
print(f"Excluded Labels: {excluded_classes}")
# NOTE(review): the excluded class is only printed, not saved next to the
# retrain checkpoints -- consider persisting it so the weights can be matched
# to the class that was held out.

# `trainset` is rebound to a `Subset` below. A `Subset` has no `.targets`, so a
# re-run of this cell would otherwise raise AttributeError. Always filter the
# underlying full training set instead.
full_trainset = trainset.dataset if isinstance(trainset, Subset) else trainset
excluded_lookup = set(excluded_classes)
train_indices = [
    idx for idx, label in enumerate(full_trainset.targets)
    if label not in excluded_lookup
]

# Keep only samples whose label is not excluded; indices refer to full_trainset.
trainset = torch.utils.data.Subset(full_trainset, train_indices)

trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=128,
    shuffle=True,
    num_workers=4
)
# Evaluation uses the unfiltered test set, so the excluded class is included.
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=100,
    shuffle=False,
    num_workers=4
)

# A new model instance (the base model's weights are not loaded here). The
# 10-way head is kept even though excluded classes never appear in training.
model = resnet18(num_classes=10)
model = model.to(device)

num_epochs = 200
criterion = nn.CrossEntropyLoss()
# NOTE(review): plain SGD (momentum defaults to 0), same as the base run; the
# common CIFAR ResNet recipe with lr=0.1 uses momentum=0.9 -- confirm intended.
optimizer = optim.SGD(model.parameters(), lr=0.1, weight_decay=5e-4)
# Multiply the learning rate by 0.2 at epochs 60, 120 and 160.
scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                           milestones=[60, 120, 160],
                                           gamma=0.2)

best_acc = 0.0  # best test accuracy seen so far; updated by test()

def train(epoch):
    """Train the global ``model`` for one epoch on the class-filtered ``trainloader``.

    Uses the module-level ``optimizer`` and ``criterion``. Every 100 batches,
    prints the running mean loss and running top-1 accuracy for this epoch.

    Args:
        epoch: Epoch index; used only in the progress line.
    """
    log_every = 100
    model.train()
    cumulative_loss, hits, count = 0.0, 0, 0
    for step, (x, y) in enumerate(trainloader):
        x = x.to(device)
        y = y.to(device)

        optimizer.zero_grad()
        scores = model(x)
        objective = criterion(scores, y)
        objective.backward()
        optimizer.step()

        cumulative_loss += objective.item()
        count += y.size(0)
        hits += scores.max(1)[1].eq(y).sum().item()

        if step % log_every == 0:
            mean_loss = cumulative_loss / (step + 1)
            running_acc = 100. * hits / count
            print(f'Epoch: {epoch} | Batch: {step}/{len(trainloader)} '
                  f'| Loss: {mean_loss:.3f} '
                  f'| Acc: {running_acc:.3f}% ({hits}/{count})')

def test(epoch):
    """Evaluate the retrained ``model`` on ``testloader`` and checkpoint improvements.

    Gradients are disabled during evaluation. When the accuracy beats the
    module-level ``best_acc``, ``best_acc`` is updated and the model's
    ``state_dict`` is saved to
    ``checkpoints_cifar10/resnet18_cifar10_retrain_best.pth``.

    Args:
        epoch: Epoch index; used only in the printed summary.

    Returns:
        Top-1 test accuracy, in percent.
    """
    global best_acc
    model.eval()
    summed_loss = 0.0
    hits = 0
    count = 0
    with torch.no_grad():
        for x, y in testloader:
            x, y = x.to(device), y.to(device)
            scores = model(x)
            summed_loss += criterion(scores, y).item()
            hits += scores.max(1)[1].eq(y).sum().item()
            count += y.size(0)

    acc = 100. * hits / count
    mean_loss = summed_loss / len(testloader)
    print(f'Test Epoch: {epoch} | Loss: {mean_loss:.3f} '
          f'| Acc: {acc:.3f}% ({hits}/{count})')

    # Save the model whenever the best test accuracy so far is improved.
    improved = acc > best_acc
    if improved:
        print(f'New best accuracy: {acc:.3f}% '
              f'(previous best: {best_acc:.3f}%), saving the model...')
        best_acc = acc
        torch.save(model.state_dict(), 'checkpoints_cifar10/resnet18_cifar10_retrain_best.pth')

    return acc

# Retraining schedule: train one epoch on the filtered data, evaluate (test()
# saves the best-so-far retrain checkpoint), then advance the LR scheduler.
for epoch in range(num_epochs):
    train(epoch)
    test(epoch)
    scheduler.step()

# Also keep the last-epoch retrained weights, independent of test accuracy.
torch.save(
    model.state_dict(),
    'checkpoints_cifar10/resnet18_cifar10_retrain_final.pth',
)