In [1]:
# !pip install ipywidgets
# !jupyter nbextension enable --py widgetsnbextension

In [2]:
# ==================== Cell 1: 라이브러리 import 및 경고 억제 ====================

import warnings
warnings.filterwarnings('ignore', message='.*UnsupportedFieldAttributeWarning.*')

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import numpy as np
import wandb
from tqdm.notebook import tqdm  # Jupyter용 tqdm
import random
import matplotlib.pyplot as plt

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")



PyTorch version: 2.8.0+cu126
CUDA available: True
CUDA version: 12.6
GPU: NVIDIA GeForce RTX 4050 Laptop GPU


In [3]:
# ==================== Cell 2: 시드 설정 ====================

def set_seed(seed=42):
    """재현성을 위한 시드 설정"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print(f"Random seed set to {seed}")

set_seed(42)

Random seed set to 42


In [4]:
# ==================== Cell 3: 데이터 증강 클래스 ====================

class MixupTransform:
    """Mixup 데이터 증강"""
    def __init__(self, alpha=1.0):
        self.alpha = alpha
    
    def __call__(self, batch_data, batch_labels):
        if self.alpha > 0:
            lam = np.random.beta(self.alpha, self.alpha)
        else:
            lam = 1
        
        batch_size = batch_data.size(0)
        index = torch.randperm(batch_size).to(batch_data.device)
        
        mixed_data = lam * batch_data + (1 - lam) * batch_data[index]
        labels_a, labels_b = batch_labels, batch_labels[index]
        
        return mixed_data, labels_a, labels_b, lam


class CutMixTransform:
    """CutMix 데이터 증강"""
    def __init__(self, alpha=1.0):
        self.alpha = alpha
    
    def __call__(self, batch_data, batch_labels):
        if self.alpha > 0:
            lam = np.random.beta(self.alpha, self.alpha)
        else:
            lam = 1
        
        batch_size = batch_data.size(0)
        index = torch.randperm(batch_size).to(batch_data.device)
        
        # CutMix 영역 계산
        _, _, H, W = batch_data.shape
        cut_rat = np.sqrt(1. - lam)
        cut_w = int(W * cut_rat)
        cut_h = int(H * cut_rat)
        
        cx = np.random.randint(W)
        cy = np.random.randint(H)
        
        bbx1 = np.clip(cx - cut_w // 2, 0, W)
        bby1 = np.clip(cy - cut_h // 2, 0, H)
        bbx2 = np.clip(cx + cut_w // 2, 0, W)
        bby2 = np.clip(cy + cut_h // 2, 0, H)
        
        mixed_data = batch_data.clone()
        mixed_data[:, :, bby1:bby2, bbx1:bbx2] = batch_data[index, :, bby1:bby2, bbx1:bbx2]
        
        # 실제 lambda 재계산
        lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (W * H))
        
        labels_a, labels_b = batch_labels, batch_labels[index]
        
        return mixed_data, labels_a, labels_b, lam

print("Data augmentation classes defined successfully!")

Data augmentation classes defined successfully!


In [5]:
# ==================== Cell 3: 데이터 증강 클래스 ====================

class MixupTransform:
    """Mixup 데이터 증강"""
    def __init__(self, alpha=1.0):
        self.alpha = alpha
    
    def __call__(self, batch_data, batch_labels):
        if self.alpha > 0:
            lam = np.random.beta(self.alpha, self.alpha)
        else:
            lam = 1
        
        batch_size = batch_data.size(0)
        index = torch.randperm(batch_size).to(batch_data.device)
        
        mixed_data = lam * batch_data + (1 - lam) * batch_data[index]
        labels_a, labels_b = batch_labels, batch_labels[index]
        
        return mixed_data, labels_a, labels_b, lam


class CutMixTransform:
    """CutMix 데이터 증강"""
    def __init__(self, alpha=1.0):
        self.alpha = alpha
    
    def __call__(self, batch_data, batch_labels):
        if self.alpha > 0:
            lam = np.random.beta(self.alpha, self.alpha)
        else:
            lam = 1
        
        batch_size = batch_data.size(0)
        index = torch.randperm(batch_size).to(batch_data.device)
        
        # CutMix 영역 계산
        _, _, H, W = batch_data.shape
        cut_rat = np.sqrt(1. - lam)
        cut_w = int(W * cut_rat)
        cut_h = int(H * cut_rat)
        
        cx = np.random.randint(W)
        cy = np.random.randint(H)
        
        bbx1 = np.clip(cx - cut_w // 2, 0, W)
        bby1 = np.clip(cy - cut_h // 2, 0, H)
        bbx2 = np.clip(cx + cut_w // 2, 0, W)
        bby2 = np.clip(cy + cut_h // 2, 0, H)
        
        mixed_data = batch_data.clone()
        mixed_data[:, :, bby1:bby2, bbx1:bbx2] = batch_data[index, :, bby1:bby2, bbx1:bbx2]
        
        # 실제 lambda 재계산
        lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (W * H))
        
        labels_a, labels_b = batch_labels, batch_labels[index]
        
        return mixed_data, labels_a, labels_b, lam

print("Data augmentation classes defined successfully!")

Data augmentation classes defined successfully!


In [6]:
# ==================== Cell 3: 데이터 증강 클래스 ====================

class MixupTransform:
    """Mixup 데이터 증강"""
    def __init__(self, alpha=1.0):
        self.alpha = alpha
    
    def __call__(self, batch_data, batch_labels):
        if self.alpha > 0:
            lam = np.random.beta(self.alpha, self.alpha)
        else:
            lam = 1
        
        batch_size = batch_data.size(0)
        index = torch.randperm(batch_size).to(batch_data.device)
        
        mixed_data = lam * batch_data + (1 - lam) * batch_data[index]
        labels_a, labels_b = batch_labels, batch_labels[index]
        
        return mixed_data, labels_a, labels_b, lam


class CutMixTransform:
    """CutMix 데이터 증강"""
    def __init__(self, alpha=1.0):
        self.alpha = alpha
    
    def __call__(self, batch_data, batch_labels):
        if self.alpha > 0:
            lam = np.random.beta(self.alpha, self.alpha)
        else:
            lam = 1
        
        batch_size = batch_data.size(0)
        index = torch.randperm(batch_size).to(batch_data.device)
        
        # CutMix 영역 계산
        _, _, H, W = batch_data.shape
        cut_rat = np.sqrt(1. - lam)
        cut_w = int(W * cut_rat)
        cut_h = int(H * cut_rat)
        
        cx = np.random.randint(W)
        cy = np.random.randint(H)
        
        bbx1 = np.clip(cx - cut_w // 2, 0, W)
        bby1 = np.clip(cy - cut_h // 2, 0, H)
        bbx2 = np.clip(cx + cut_w // 2, 0, W)
        bby2 = np.clip(cy + cut_h // 2, 0, H)
        
        mixed_data = batch_data.clone()
        mixed_data[:, :, bby1:bby2, bbx1:bbx2] = batch_data[index, :, bby1:bby2, bbx1:bbx2]
        
        # 실제 lambda 재계산
        lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (W * H))
        
        labels_a, labels_b = batch_labels, batch_labels[index]
        
        return mixed_data, labels_a, labels_b, lam

print("Data augmentation classes defined successfully!")

Data augmentation classes defined successfully!


In [7]:
# ==================== Cell 4: SE Block ====================

class SEBlock(nn.Module):
    """Squeeze-and-Excitation Block"""
    def __init__(self, channels, reduction=16):
        super(SEBlock, self).__init__()
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        self.excitation = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.squeeze(x).view(b, c)
        y = self.excitation(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

print("SE Block defined successfully!")

SE Block defined successfully!


In [8]:
# ==================== Cell 5: SENet Block ====================

class SENetBlock(nn.Module):
    """SENet의 기본 블록"""
    def __init__(self, in_channels, out_channels, stride=1, reduction=16):
        super(SENetBlock, self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, 
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, 
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        self.se = SEBlock(out_channels, reduction)
        
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, 
                         stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
    
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.se(out)
        out += self.shortcut(x)
        out = F.relu(out)
        return out

print("SENet Block defined successfully!")

SENet Block defined successfully!


In [9]:
# ==================== Cell 6: SENet 모델 ====================

class SENet(nn.Module):
    """FashionMNIST용 SENet"""
    def __init__(self, num_classes=10, reduction=16):
        super(SENet, self).__init__()
        
        self.in_channels = 32
        
        # 초기 Conv 레이어
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        
        # SE 블록들
        self.layer1 = self._make_layer(32, 2, stride=1, reduction=reduction)
        self.layer2 = self._make_layer(64, 2, stride=2, reduction=reduction)
        self.layer3 = self._make_layer(128, 2, stride=2, reduction=reduction)
        self.layer4 = self._make_layer(256, 2, stride=2, reduction=reduction)
        
        # Local Response Normalization
        self.lrn = nn.LocalResponseNorm(size=5, alpha=0.0001, beta=0.75, k=2)
        
        # 분류기
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(0.5)
        self.fc = nn.Linear(256, num_classes)
    
    def _make_layer(self, out_channels, num_blocks, stride, reduction):
        layers = []
        layers.append(SENetBlock(self.in_channels, out_channels, stride, reduction))
        self.in_channels = out_channels
        for _ in range(1, num_blocks):
            layers.append(SENetBlock(out_channels, out_channels, 1, reduction))
        return nn.Sequential(*layers)
    
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.lrn(out)
        
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        
        out = self.avgpool(out)
        out = out.view(out.size(0), -1)
        out = self.dropout(out)
        out = self.fc(out)
        
        return out

print("SENet model defined successfully!")

SENet model defined successfully!


In [10]:
# ==================== Cell 7: Early Stopping ====================

class EarlyStopping:
    """Early Stopping 구현"""
    def __init__(self, patience=10, min_delta=0.0, verbose=True):
        self.patience = patience
        self.min_delta = min_delta
        self.verbose = verbose
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
        self.best_model = None
    
    def __call__(self, val_loss, model):
        if self.best_loss is None:
            self.best_loss = val_loss
            self.best_model = model.state_dict().copy()
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.best_model = model.state_dict().copy()
            self.counter = 0

print("Early Stopping class defined successfully!")

Early Stopping class defined successfully!


In [11]:
# ==================== Cell 8: 학습 함수 ====================

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Mixup/CutMix용 손실 함수"""
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)


def train_epoch(model, train_loader, criterion, optimizer, device, mixup, cutmix, use_mixing=True):
    """한 에포크 학습"""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    pbar = tqdm(train_loader, desc='Training', leave=True)
    for batch_idx, (data, target) in enumerate(pbar):
        data, target = data.to(device), target.to(device)
        
        # Mixing 적용 (50% 확률)
        if use_mixing and random.random() > 0.5:
            if random.random() > 0.5:
                # Mixup
                data, targets_a, targets_b, lam = mixup(data, target)
            else:
                # CutMix
                data, targets_a, targets_b, lam = cutmix(data, target)
            
            optimizer.zero_grad()
            outputs = model(data)
            loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)
        else:
            optimizer.zero_grad()
            outputs = model(data)
            loss = criterion(outputs, target)
        
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()
        
        pbar.set_postfix({
            'loss': f'{running_loss / (batch_idx + 1):.4f}',
            'acc': f'{100. * correct / total:.2f}%'
        })
    
    return running_loss / len(train_loader), 100. * correct / total


def validate(model, val_loader, criterion, device):
    """검증"""
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        pbar = tqdm(val_loader, desc='Validation', leave=True)
        for data, target in pbar:
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            loss = criterion(outputs, target)
            
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
            
            pbar.set_postfix({
                'loss': f'{running_loss / len(val_loader):.4f}',
                'acc': f'{100. * correct / total:.2f}%'
            })
    
    return running_loss / len(val_loader), 100. * correct / total

print("Training functions defined successfully!")

Training functions defined successfully!


In [12]:
# ==================== Cell 9: 하이퍼파라미터 설정 ====================

config = {
    'batch_size': 128,
    'epochs': 55,
    'learning_rate': 0.001,
    'weight_decay': 0.0001,
    'mixup_alpha': 1.0,
    'cutmix_alpha': 1.0,
    'se_reduction': 16,
    'early_stopping_patience': 15,
}

print("Configuration:")
for key, value in config.items():
    print(f"  {key}: {value}")

Configuration:
  batch_size: 128
  epochs: 55
  learning_rate: 0.001
  weight_decay: 0.0001
  mixup_alpha: 1.0
  cutmix_alpha: 1.0
  se_reduction: 16
  early_stopping_patience: 15


In [13]:
# ==================== Cell 10: Wandb 초기화 ====================

# Wandb 초기화
wandb.init(
    project='fashionmnist-senet',
    config=config,
    name='SENet-Mixup-CutMix'
)

print("Wandb initialized successfully!")

wandb: Currently logged in as: -ddj127 (-ddj127-korea-university-of-technology-and-education) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


Wandb initialized successfully!


In [14]:
# ==================== Cell 11: 디바이스 및 데이터 준비 ====================

# 디바이스 설정
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# 데이터 변환
train_transform = transforms.Compose([
    transforms.Pad(4, fill=0),  # 28x28 -> 36x36
    transforms.RandomCrop(28),  # 36x36 -> 28x28
    transforms.ToTensor(),
    transforms.Normalize((0.2860,), (0.3530,))
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.2860,), (0.3530,))
])

# 데이터셋 로드
print("\nLoading datasets...")
train_dataset = datasets.FashionMNIST(
    root='./data',
    train=True,
    download=True,
    transform=train_transform
)

test_dataset = datasets.FashionMNIST(
    root='./data',
    train=False,
    download=True,
    transform=test_transform
)

# 데이터로더
train_loader = DataLoader(
    train_dataset,
    batch_size=config['batch_size'],
    shuffle=True,
    num_workers=2,  # Jupyter에서는 적은 worker 사용
    pin_memory=True
)

test_loader = DataLoader(
    test_dataset,
    batch_size=config['batch_size'],
    shuffle=False,
    num_workers=2,
    pin_memory=True
)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")

Using device: cuda

Loading datasets...
Train dataset size: 60000
Test dataset size: 10000


In [15]:
# ==================== Cell 12: 모델 초기화 ====================

# 모델 초기화
model = SENet(num_classes=10, reduction=config['se_reduction']).to(device)
print(f'Total parameters: {sum(p.numel() for p in model.parameters()):,}')

# Wandb로 모델 추적
wandb.watch(model, log='all', log_freq=100)

# 손실 함수 및 최적화
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(
    model.parameters(),
    lr=config['learning_rate'],
    weight_decay=config['weight_decay']
)

# Mixing 변환
mixup = MixupTransform(alpha=config['mixup_alpha'])
cutmix = CutMixTransform(alpha=config['cutmix_alpha'])

# Early Stopping
early_stopping = EarlyStopping(
    patience=config['early_stopping_patience'],
    verbose=True
)

print("Model and training components initialized!")

Total parameters: 2,818,794
Model and training components initialized!


In [16]:
# ==================== Cell 13: 학습 실행 ====================

# 학습 히스토리 저장
history = {
    'train_loss': [],
    'train_acc': [],
    'val_loss': [],
    'val_acc': []
}

best_acc = 0.0

for epoch in range(config['epochs']):
    print(f'\n{"="*70}')
    print(f'Epoch {epoch+1}/{config["epochs"]}')
    print(f'{"="*70}')
    
    # 학습
    train_loss, train_acc = train_epoch(
        model, train_loader, criterion, optimizer, device, mixup, cutmix
    )
    
    # 검증
    val_loss, val_acc = validate(model, test_loader, criterion, device)
    
    # 히스토리 저장
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    
    # Wandb 로깅
    wandb.log({
        'epoch': epoch + 1,
        'train_loss': train_loss,
        'train_acc': train_acc,
        'val_loss': val_loss,
        'val_acc': val_acc,
    })
    
    print(f'\nTrain Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
    print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
    
    # 최고 정확도 모델 저장
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_acc': best_acc,
        }, 'best_model.pth')
        print(f'✓ Best model saved with accuracy: {best_acc:.2f}%')
    
    # Early Stopping 체크
    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print('\n' + '='*70)
        print('Early stopping triggered!')
        print('='*70)
        model.load_state_dict(early_stopping.best_model)
        break

print("\nTraining completed!")


Epoch 1/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 1.0295, Train Acc: 63.59%
Val Loss: 0.5846, Val Acc: 78.43%
✓ Best model saved with accuracy: 78.43%

Epoch 2/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.8128, Train Acc: 70.37%
Val Loss: 0.6133, Val Acc: 77.76%
EarlyStopping counter: 1 out of 15

Epoch 3/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.7854, Train Acc: 69.11%
Val Loss: 0.3416, Val Acc: 89.50%
✓ Best model saved with accuracy: 89.50%

Epoch 4/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.7038, Train Acc: 73.33%
Val Loss: 0.3240, Val Acc: 88.51%

Epoch 5/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.6834, Train Acc: 74.21%
Val Loss: 0.2558, Val Acc: 91.89%
✓ Best model saved with accuracy: 91.89%

Epoch 6/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.6571, Train Acc: 75.03%
Val Loss: 0.2722, Val Acc: 90.94%
EarlyStopping counter: 1 out of 15

Epoch 7/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.6342, Train Acc: 73.71%
Val Loss: 0.2420, Val Acc: 92.31%
✓ Best model saved with accuracy: 92.31%

Epoch 8/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.6095, Train Acc: 74.23%
Val Loss: 0.2301, Val Acc: 92.45%
✓ Best model saved with accuracy: 92.45%

Epoch 9/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.6171, Train Acc: 76.98%
Val Loss: 0.2620, Val Acc: 90.74%
EarlyStopping counter: 1 out of 15

Epoch 10/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.5985, Train Acc: 74.40%
Val Loss: 0.2198, Val Acc: 93.01%
✓ Best model saved with accuracy: 93.01%

Epoch 11/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.5501, Train Acc: 77.55%
Val Loss: 0.2198, Val Acc: 92.90%
EarlyStopping counter: 1 out of 15

Epoch 12/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.6035, Train Acc: 74.50%
Val Loss: 0.2236, Val Acc: 92.73%
EarlyStopping counter: 2 out of 15

Epoch 13/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.5424, Train Acc: 78.44%
Val Loss: 0.2223, Val Acc: 92.38%
EarlyStopping counter: 3 out of 15

Epoch 14/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.5057, Train Acc: 78.93%
Val Loss: 0.1911, Val Acc: 93.78%
✓ Best model saved with accuracy: 93.78%

Epoch 15/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.5472, Train Acc: 77.99%
Val Loss: 0.1994, Val Acc: 93.23%
EarlyStopping counter: 1 out of 15

Epoch 16/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.5042, Train Acc: 79.26%
Val Loss: 0.2163, Val Acc: 93.13%
EarlyStopping counter: 2 out of 15

Epoch 17/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.5394, Train Acc: 79.25%
Val Loss: 0.2102, Val Acc: 92.78%
EarlyStopping counter: 3 out of 15

Epoch 18/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.5514, Train Acc: 76.46%
Val Loss: 0.1997, Val Acc: 93.30%
EarlyStopping counter: 4 out of 15

Epoch 19/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.5072, Train Acc: 79.35%
Val Loss: 0.1930, Val Acc: 93.48%
EarlyStopping counter: 5 out of 15

Epoch 20/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.5148, Train Acc: 80.04%
Val Loss: 0.2066, Val Acc: 93.28%
EarlyStopping counter: 6 out of 15

Epoch 21/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.5162, Train Acc: 77.97%
Val Loss: 0.2184, Val Acc: 93.37%
EarlyStopping counter: 7 out of 15

Epoch 22/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.5356, Train Acc: 75.11%
Val Loss: 0.1795, Val Acc: 93.78%

Epoch 23/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.4858, Train Acc: 78.89%
Val Loss: 0.1938, Val Acc: 93.74%
EarlyStopping counter: 1 out of 15

Epoch 24/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.4983, Train Acc: 77.21%
Val Loss: 0.1875, Val Acc: 93.63%
EarlyStopping counter: 2 out of 15

Epoch 25/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.4916, Train Acc: 78.17%
Val Loss: 0.1918, Val Acc: 94.17%
✓ Best model saved with accuracy: 94.17%
EarlyStopping counter: 3 out of 15

Epoch 26/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.4835, Train Acc: 79.21%
Val Loss: 0.1932, Val Acc: 93.79%
EarlyStopping counter: 4 out of 15

Epoch 27/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.5157, Train Acc: 78.11%
Val Loss: 0.2025, Val Acc: 94.02%
EarlyStopping counter: 5 out of 15

Epoch 28/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.4780, Train Acc: 79.67%
Val Loss: 0.2250, Val Acc: 94.28%
✓ Best model saved with accuracy: 94.28%
EarlyStopping counter: 6 out of 15

Epoch 29/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.4623, Train Acc: 81.43%
Val Loss: 0.1839, Val Acc: 93.70%
EarlyStopping counter: 7 out of 15

Epoch 30/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.4706, Train Acc: 77.88%
Val Loss: 0.1813, Val Acc: 94.34%
✓ Best model saved with accuracy: 94.34%
EarlyStopping counter: 8 out of 15

Epoch 31/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.4602, Train Acc: 79.31%
Val Loss: 0.1942, Val Acc: 94.04%
EarlyStopping counter: 9 out of 15

Epoch 32/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.4820, Train Acc: 80.59%
Val Loss: 0.2127, Val Acc: 94.02%
EarlyStopping counter: 10 out of 15

Epoch 33/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.4419, Train Acc: 81.05%
Val Loss: 0.2085, Val Acc: 93.75%
EarlyStopping counter: 11 out of 15

Epoch 34/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.4824, Train Acc: 79.30%
Val Loss: 0.1951, Val Acc: 93.77%
EarlyStopping counter: 12 out of 15

Epoch 35/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.4394, Train Acc: 81.41%
Val Loss: 0.1996, Val Acc: 94.53%
✓ Best model saved with accuracy: 94.53%
EarlyStopping counter: 13 out of 15

Epoch 36/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.4386, Train Acc: 81.44%
Val Loss: 0.1989, Val Acc: 94.31%
EarlyStopping counter: 14 out of 15

Epoch 37/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.4588, Train Acc: 79.14%
Val Loss: 0.1730, Val Acc: 94.43%

Epoch 38/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.4458, Train Acc: 80.86%
Val Loss: 0.1650, Val Acc: 94.35%

Epoch 39/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.4525, Train Acc: 79.60%
Val Loss: 0.1717, Val Acc: 94.46%
EarlyStopping counter: 1 out of 15

Epoch 40/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.4606, Train Acc: 80.97%
Val Loss: 0.2000, Val Acc: 94.20%
EarlyStopping counter: 2 out of 15

Epoch 41/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.4571, Train Acc: 79.60%
Val Loss: 0.1864, Val Acc: 94.64%
✓ Best model saved with accuracy: 94.64%
EarlyStopping counter: 3 out of 15

Epoch 42/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.4346, Train Acc: 81.13%
Val Loss: 0.1771, Val Acc: 94.39%
EarlyStopping counter: 4 out of 15

Epoch 43/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.4321, Train Acc: 81.40%
Val Loss: 0.1807, Val Acc: 94.22%
EarlyStopping counter: 5 out of 15

Epoch 44/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.4367, Train Acc: 80.93%
Val Loss: 0.1832, Val Acc: 94.24%
EarlyStopping counter: 6 out of 15

Epoch 45/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.4400, Train Acc: 80.41%
Val Loss: 0.1856, Val Acc: 94.58%
EarlyStopping counter: 7 out of 15

Epoch 46/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.4446, Train Acc: 80.77%
Val Loss: 0.1831, Val Acc: 94.34%
EarlyStopping counter: 8 out of 15

Epoch 47/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.4674, Train Acc: 79.24%
Val Loss: 0.1814, Val Acc: 94.57%
EarlyStopping counter: 9 out of 15

Epoch 48/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.4418, Train Acc: 79.16%
Val Loss: 0.1877, Val Acc: 94.41%
EarlyStopping counter: 10 out of 15

Epoch 49/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.3858, Train Acc: 82.27%
Val Loss: 0.1818, Val Acc: 94.50%
EarlyStopping counter: 11 out of 15

Epoch 50/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.4133, Train Acc: 80.18%
Val Loss: 0.1828, Val Acc: 93.95%
EarlyStopping counter: 12 out of 15

Epoch 51/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.4309, Train Acc: 80.79%
Val Loss: 0.1949, Val Acc: 94.91%
✓ Best model saved with accuracy: 94.91%
EarlyStopping counter: 13 out of 15

Epoch 52/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.4402, Train Acc: 80.36%
Val Loss: 0.1903, Val Acc: 94.72%
EarlyStopping counter: 14 out of 15

Epoch 53/55


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Train Loss: 0.4143, Train Acc: 81.11%
Val Loss: 0.1725, Val Acc: 94.59%
EarlyStopping counter: 15 out of 15

Early stopping triggered!

Training completed!


In [19]:
# ==================== Cell 14: 최종 테스트 ====================

print('\n' + '='*70)
print('Final Test Evaluation')
print('='*70)

test_loss, test_acc = validate(model, test_loader, criterion, device)
print(f'\nFinal Test Loss: {test_loss:.4f}')
print(f'Final Test Accuracy: {test_acc:.2f}%')

wandb.log({
    'final_test_loss': test_loss,
    'final_test_acc': test_acc
})


Final Test Evaluation


Validation:   0%|          | 0/79 [00:00<?, ?it/s]


Final Test Loss: 0.1725
Final Test Accuracy: 94.59%


In [20]:
!pip install torchinfo

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [21]:
# ==================== Cell: 모델 구조 및 정보 확인 ====================

from torchinfo import summary

print("=" * 80)
print("MODEL ARCHITECTURE SUMMARY")
print("=" * 80)

# 모델 요약
model_summary = summary(
    model,
    input_size=(1, 1, 28, 28),
    col_names=["input_size", "output_size", "num_params", "kernel_size"],
    row_settings=["var_names"],
    verbose=1,
    device=device
)

print("\n" + "=" * 80)
print("MODEL STATISTICS")
print("=" * 80)

# 파라미터 통계
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
non_trainable_params = total_params - trainable_params

print(f"Total Parameters:        {total_params:,}")
print(f"Trainable Parameters:    {trainable_params:,}")
print(f"Non-trainable Parameters: {non_trainable_params:,}")
print(f"Model Size (MB):         {total_params * 4 / (1024**2):.2f}")

print("=" * 80)

# 레이어별 파라미터 수
print("\n" + "=" * 80)
print("PARAMETERS BY LAYER TYPE")
print("=" * 80)

layer_params = {}
for name, module in model.named_modules():
    if len(list(module.children())) == 0:  # 최하위 레이어만
        num_params = sum(p.numel() for p in module.parameters())
        if num_params > 0:
            layer_type = type(module).__name__
            layer_params[layer_type] = layer_params.get(layer_type, 0) + num_params

for layer_type, num_params in sorted(layer_params.items(), key=lambda x: x[1], reverse=True):
    print(f"{layer_type:.<40} {num_params:>15,} ({num_params/total_params*100:>5.2f}%)")

print("=" * 80)

MODEL ARCHITECTURE SUMMARY
Layer (type (var_name))                            Input Shape               Output Shape              Param #                   Kernel Shape
SENet (SENet)                                      [1, 1, 28, 28]            [1, 10]                   --                        --
├─Conv2d (conv1)                                   [1, 1, 28, 28]            [1, 32, 28, 28]           288                       [3, 3]
├─BatchNorm2d (bn1)                                [1, 32, 28, 28]           [1, 32, 28, 28]           64                        --
├─LocalResponseNorm (lrn)                          [1, 32, 28, 28]           [1, 32, 28, 28]           --                        --
├─Sequential (layer1)                              [1, 32, 28, 28]           [1, 32, 28, 28]           --                        --
│    └─SENetBlock (0)                              [1, 32, 28, 28]           [1, 32, 28, 28]           --                        --
│    │    └─Conv2d (conv1)         