In [2]:
import wandb
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import timm
from torchvision import datasets, transforms
from typing import Dict, Any, Tuple, Optional
from pathlib import Path
import os

In [3]:
class AsymmetricLoss(nn.Module):
    """Asymmetric Loss for multi-label classification.

    Positive and negative targets get separate focusing exponents
    (``gamma_pos`` / ``gamma_neg``), and the negative-class probability
    ``1 - sigmoid(x)`` is shifted up by ``clip`` (capped at 1). When
    ``sigmoid(x) <= clip``, that shift reaches 1, so easy negatives contribute
    zero loss. The loss is summed (not averaged) over all elements.

    Args:
        gamma_neg: focusing exponent applied where the target is 0.
        gamma_pos: focusing exponent applied where the target is 1.
        clip: margin added to the negative probability; ``None`` or 0 disables it.
        eps: lower clamp on probabilities before ``log`` to avoid ``log(0)``.
        disable_torch_grad_focal_loss: if True, the focusing weight is treated as
            a constant, so no gradient flows through it.
    """
    def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False):
        super(AsymmetricLoss, self).__init__()
        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps

    def forward(self, x, y):
        """Compute the summed asymmetric loss.

        Args:
            x: input logits, same shape as ``y`` (presumably (batch, num_labels)).
            y: multi-label binarized targets (0/1), same shape as ``x``.

        Returns:
            Scalar tensor: the negated sum of the (optionally focused)
            log-likelihood terms.
        """
        # Per-element probabilities for the positive and negative outcome.
        x_sigmoid = torch.sigmoid(x)
        xs_pos = x_sigmoid
        xs_neg = 1 - x_sigmoid

        # Asymmetric probability shifting: easy negatives are pushed up to prob 1.
        if self.clip is not None and self.clip > 0:
            xs_neg = (xs_neg + self.clip).clamp(max=1)

        # Basic binary cross-entropy terms.
        los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
        los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))

        # Asymmetric focusing.
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            pt = xs_pos * y + xs_neg * (1 - y)
            one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
            one_sided_w = torch.pow(1 - pt, one_sided_gamma)
            if self.disable_torch_grad_focal_loss:
                # Treat the weight as a constant. Detaching (instead of toggling
                # the global grad mode off and unconditionally back on) leaves
                # the caller's grad mode untouched, e.g. inside torch.no_grad().
                one_sided_w = one_sided_w.detach()
            los_pos = los_pos * one_sided_w
            los_neg = los_neg * one_sided_w

        return -(los_pos + los_neg).sum()


In [4]:
class FocalLoss(nn.Module):
    """Multi-class focal loss for dealing with class imbalance.

    Each sample's cross-entropy ``CE_i`` is scaled by
    ``alpha * (1 - p_i) ** gamma``. Here ``p_i = exp(-CE_i)`` is the predicted
    probability of the true class, so confident, correct samples are
    down-weighted. The result is averaged over the batch.

    Args:
        alpha: constant scale applied to every sample's loss.
        gamma: focusing exponent; 0 reduces to ``alpha`` times plain cross-entropy.
    """

    def __init__(self, alpha=1, gamma=2):
        super().__init__()
        self.alpha, self.gamma = alpha, gamma

    def forward(self, inputs, targets):
        """Return the mean focal loss for logits ``inputs`` and class-index ``targets``."""
        per_sample_ce = F.cross_entropy(inputs, targets, reduction='none')
        true_class_prob = per_sample_ce.neg().exp()
        modulating_factor = (1 - true_class_prob).pow(self.gamma)
        weighted = self.alpha * modulating_factor * per_sample_ce
        return weighted.mean()

In [5]:
class ModelManager:
    """Manage TIMM models and CIFAR-10 training for W&B sweeps.

    Holds the compute device, the checkpoint directory and the CIFAR-10 class
    names. Provides factories for models, losses, optimizers and data loaders,
    plus the per-run ``train`` entry point used by ``wandb.agent``.
    """

    # CIFAR-10 per-channel normalization statistics (shared by train/test transforms).
    _CIFAR10_MEAN = [0.4914, 0.4822, 0.4465]
    _CIFAR10_STD = [0.2470, 0.2435, 0.2616]

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.checkpoint_dir = Path("model_checkpoints")
        # exist_ok=True already makes this a no-op when the directory exists.
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        self.classes = ('plane', 'car', 'bird', 'cat', 'deer',
                        'dog', 'frog', 'horse', 'ship', 'truck')

    @staticmethod
    def get_sweep_config() -> Dict[str, Any]:
        """Return the W&B sweep configuration.

        The sweep mixes loss functions (cross-entropy vs. focal) whose values
        live on different scales. Comparing runs by validation loss would
        therefore favor the focal loss regardless of model quality, so runs are
        ranked by validation accuracy instead.
        """
        return {
            'method': 'random',
            'metric': {
                'name': 'val_accuracy',
                'goal': 'maximize'
            },
            'parameters': {
                'model_name': {
                    # Small/medium-sized models suited to CIFAR-10
                    'values': [
                        'resnet18',
                        'mobilenetv3_small_100',
                        'efficientnet_b0',
                        'vit_tiny_patch16_224',
                        'convnext_tiny',
                        'mobilevitv2_050'
                    ]
                },
                'optimizer': {
                    'values': ['adam', 'sgd', 'adamw']
                },
                'loss_function': {
                    'values': ['cross_entropy', 'focal']
                },
                'dropout': {
                    'values': [0.3, 0.4, 0.5]
                },
                'learning_rate': {
                    # Log-uniform between the actual values 1e-4 and 1e-2.
                    'distribution': 'log_uniform_values',
                    'min': 1e-4,
                    'max': 1e-2
                },
                'batch_size': {
                    'distribution': 'q_log_uniform_values',
                    'q': 8,
                    'min': 32,
                    'max': 256,
                },
                'epochs': {
                    'value': 5
                }
            }
        }

    def get_model(self, model_name: str, num_classes: int,
                  drop_rate: Optional[float] = None) -> nn.Module:
        """Create a pretrained TIMM model with a fresh ``num_classes`` head on ``self.device``.

        Args:
            model_name: TIMM model identifier.
            num_classes: size of the classification head.
            drop_rate: forwarded to ``timm.create_model`` as ``drop_rate`` when
                given (used by most timm architectures as head dropout). ``None``
                keeps the model's own default.
        """
        kwargs: Dict[str, Any] = {'pretrained': True, 'num_classes': num_classes}
        if drop_rate is not None:
            kwargs['drop_rate'] = drop_rate
        if 'vit_tiny_patch16_224' in model_name:
            # ViT only: rebuild for 32x32 inputs with 4x4 patches. timm is expected
            # to resample the pretrained patch/position embeddings on load;
            # verify for the installed timm version.
            kwargs.update(img_size=32, patch_size=4)
        else:
            kwargs['in_chans'] = 3
        model = timm.create_model(model_name, **kwargs)
        return model.to(self.device)

    def get_loss_function(self, loss_name: str) -> nn.Module:
        """Return the loss for ``loss_name``: ``'focal'`` -> FocalLoss, anything else -> CrossEntropyLoss.

        AsymmetricLoss is intentionally not offered: it expects multi-label
        binarized targets, while CIFAR-10 loaders yield class indices.
        """
        if loss_name == 'focal':
            return FocalLoss()
        return nn.CrossEntropyLoss()

    def get_optimizer(self, model: nn.Module, optimizer_name: str, learning_rate: float) -> torch.optim.Optimizer:
        """Return SGD (momentum 0.9), AdamW, or Adam (the fallback) over ``model``'s parameters."""
        if optimizer_name == "sgd":
            return optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)
        elif optimizer_name == "adamw":
            return optim.AdamW(model.parameters(), lr=learning_rate)
        return optim.Adam(model.parameters(), lr=learning_rate)

    def get_data_loaders(self, batch_size: int, num_workers: int = 2) -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
        """Build CIFAR-10 train/validation loaders (downloaded to ./data if missing).

        Training data uses random 32x32 crops with 4px padding plus horizontal
        flips. The validation loader uses the CIFAR-10 test split with
        normalization only.
        """
        normalize = transforms.Normalize(mean=self._CIFAR10_MEAN, std=self._CIFAR10_STD)
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        test_transform = transforms.Compose([transforms.ToTensor(), normalize])

        train_dataset = datasets.CIFAR10(root='./data', train=True, download=True,
                                         transform=train_transform)
        val_dataset = datasets.CIFAR10(root='./data', train=False, download=True,
                                       transform=test_transform)

        # Pinned host memory only speeds up host->GPU copies; skip it on CPU.
        loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers,
                             pin_memory=self.device.type == 'cuda')
        train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_kwargs)
        val_loader = torch.utils.data.DataLoader(val_dataset, shuffle=False, **loader_kwargs)
        return train_loader, val_loader

    @staticmethod
    def _loss_from_checkpoint_name(path: Path, prefix: str) -> Optional[float]:
        """Parse the validation loss encoded as ``{prefix}_loss{value}.pth``, or None."""
        tag = f"{prefix}_loss"
        if not path.stem.startswith(tag):
            return None
        try:
            return float(path.stem[len(tag):])
        except ValueError:
            return None

    def save_best_model(self, model: nn.Module, config: "wandb.Config", val_loss: float) -> bool:
        """Keep a single best checkpoint per (model, optimizer, loss function) combination.

        Files are named ``{model}_{optimizer}_{loss_fn}_loss{val_loss:.4f}.pth``.
        Several sweep runs can share a combination, so the weights are written
        only if ``val_loss`` beats the best checkpoint already stored for it.
        Stale checkpoints for the combination are removed only after the new one
        has been saved, so a failed save never leaves the combination without a
        checkpoint.

        Returns:
            True if a checkpoint was written, False if a better one already exists.
        """
        prefix = f"{config.model_name}_{config.optimizer}_{config.loss_function}"
        existing = list(self.checkpoint_dir.glob(f"{prefix}_*.pth"))

        stored_losses = [self._loss_from_checkpoint_name(p, prefix) for p in existing]
        known_losses = [loss for loss in stored_losses if loss is not None]
        if known_losses and min(known_losses) <= val_loss:
            return False

        checkpoint_path = self.checkpoint_dir / f"{prefix}_loss{val_loss:.4f}.pth"
        torch.save({
            'model_state_dict': model.state_dict(),
            'config': dict(config),
            'val_loss': val_loss
        }, checkpoint_path)

        for old_checkpoint in existing:
            if old_checkpoint != checkpoint_path:
                old_checkpoint.unlink(missing_ok=True)
        return True

    def train_epoch(self, model: nn.Module, loader: torch.utils.data.DataLoader,
                    criterion: nn.Module, optimizer: torch.optim.Optimizer) -> Tuple[float, float]:
        """Train ``model`` for one pass over ``loader``.

        Logs ``batch_loss`` and ``batch_accuracy`` to W&B after every batch.
        Note that ``batch_accuracy`` is the running accuracy accumulated since
        the start of the epoch, not the accuracy of the single batch.

        Returns:
            (mean loss per sample, accuracy in percent) over the epoch. The loss
            is weighted by batch size, which assumes ``criterion`` returns a
            per-batch mean.

        Raises:
            ValueError: if ``loader`` yields no samples.
        """
        model.train()
        loss_sum = 0.0
        correct = 0
        total = 0

        for data, target in loader:
            data, target = data.to(self.device), target.to(self.device)
            optimizer.zero_grad()

            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            batch_size = target.size(0)
            # Weight by batch size so a short final batch does not skew the mean.
            loss_sum += batch_loss * batch_size
            total += batch_size
            correct += output.argmax(dim=1).eq(target).sum().item()

            wandb.log({
                "batch_loss": batch_loss,
                "batch_accuracy": 100. * correct / total
            })

        if total == 0:
            raise ValueError("training loader yielded no samples")
        return loss_sum / total, 100. * correct / total

    def validate(self, model: nn.Module, loader: torch.utils.data.DataLoader,
                 criterion: nn.Module) -> Tuple[float, float]:
        """Evaluate ``model`` on ``loader`` in eval mode without gradient tracking.

        Returns:
            (mean loss per sample, accuracy in percent). The loss is weighted by
            batch size, which assumes ``criterion`` returns a per-batch mean.

        Raises:
            ValueError: if ``loader`` yields no samples.
        """
        model.eval()
        loss_sum = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for data, target in loader:
                data, target = data.to(self.device), target.to(self.device)
                output = model(data)
                batch_size = target.size(0)
                loss_sum += criterion(output, target).item() * batch_size
                correct += output.argmax(dim=1).eq(target).sum().item()
                total += batch_size

        if total == 0:
            raise ValueError("validation loader yielded no samples")
        return loss_sum / total, 100. * correct / total

    def train(self, config: "Optional[wandb.Config]" = None):
        """Run one full training job (one sweep trial) inside a W&B run.

        Reads hyperparameters from ``wandb.config``, trains for ``config.epochs``
        epochs, logs per-epoch metrics, and checkpoints whenever the run's
        validation loss improves.
        """
        with wandb.init(config=config):
            config = wandb.config

            # Model, data, loss and optimizer for this trial.
            model = self.get_model(config.model_name, num_classes=len(self.classes),
                                   drop_rate=config.get('dropout'))
            train_loader, val_loader = self.get_data_loaders(config.batch_size)
            criterion = self.get_loss_function(config.loss_function)
            optimizer = self.get_optimizer(model, config.optimizer, config.learning_rate)

            best_val_loss = float('inf')

            # Record class names in the run config.
            wandb.config.update({"classes": self.classes})

            for epoch in range(config.epochs):
                train_loss, train_acc = self.train_epoch(model, train_loader, criterion, optimizer)
                val_loss, val_acc = self.validate(model, val_loader, criterion)

                wandb.log({
                    "train_loss": train_loss,
                    "train_accuracy": train_acc,
                    "val_loss": val_loss,
                    "val_accuracy": val_acc,
                    "epoch": epoch
                })

                # Within a run the loss function is fixed, so val_loss is comparable here.
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    if self.save_best_model(model, config, val_loss):
                        print(f"Model saved with loss: {val_loss:.4f}")
                    else:
                        print(f"Model not saved, a better checkpoint exists for this combination "
                              f"(run best loss: {val_loss:.4f})")
                else:
                    print(f"Model not saved, best loss: {best_val_loss:.4f}")

                print(f'Epoch: {epoch+1}/{config.epochs}')
                print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
                print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
                print('-' * 50)

In [6]:
def main(count: int = 5, project: str = "timm-cifar10-sweeps-5"):
    """Log in to W&B, register the sweep, and run ``count`` agent trials.

    Args:
        count: number of sweep trials the local agent runs.
        project: W&B project that owns the sweep.

    The API key is taken from the ``WANDB_API_KEY`` environment variable. When
    it is unset, ``key=None`` lets ``wandb.login`` fall back to its normal
    credential lookup. Never hardcode the key in the notebook: a key that was
    ever committed in source should be revoked and rotated.
    """
    wandb.login(key=os.environ.get("WANDB_API_KEY"))

    # Initialize the model manager.
    manager = ModelManager()

    # Register the sweep.
    sweep_config = manager.get_sweep_config()
    sweep_id = wandb.sweep(sweep_config, project=project)

    # Run the sweep agent; each trial calls manager.train.
    wandb.agent(sweep_id, function=manager.train, count=count)

In [7]:
# Entry point: launches the sweep. In a notebook kernel __name__ is also
# "__main__", so running this cell starts the agent.
if __name__ == "__main__":
    main()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mxyztomas[0m ([33mxyztomas-xyz[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/jun/.netrc


Create sweep with ID: 8428yi1m
Sweep URL: https://wandb.ai/juno95/timm-cifar10-sweeps-5/sweeps/8428yi1m


[34m[1mwandb[0m: Agent Starting Run: h73hggbd with config:
[34m[1mwandb[0m: 	batch_size: 144
[34m[1mwandb[0m: 	dropout: 0.4
[34m[1mwandb[0m: 	epochs: 5
[34m[1mwandb[0m: 	learning_rate: 0.0036139350384962743
[34m[1mwandb[0m: 	loss_function: focal
[34m[1mwandb[0m: 	model_name: mobilevitv2_050
[34m[1mwandb[0m: 	optimizer: adamw
[34m[1mwandb[0m: Currently logged in as: [33mjuno95[0m. Use [1m`wandb login --relogin`[0m to force relogin


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [04:43<00:00, 601kB/s] 


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Model saved with loss: 0.3996
Epoch: 1/5
Train Loss: 0.5830, Train Acc: 65.95%
Val Loss: 0.3996, Val Acc: 73.17%
--------------------------------------------------
Model saved with loss: 0.3306
Epoch: 2/5
Train Loss: 0.3486, Train Acc: 77.33%
Val Loss: 0.3306, Val Acc: 78.60%
--------------------------------------------------
Model saved with loss: 0.3000
Epoch: 3/5
Train Loss: 0.2963, Train Acc: 80.16%
Val Loss: 0.3000, Val Acc: 80.57%
--------------------------------------------------
Model saved with loss: 0.2871
Epoch: 4/5
Train Loss: 0.2710, Train Acc: 81.37%
Val Loss: 0.2871, Val Acc: 81.30%
--------------------------------------------------
Model saved with loss: 0.2646
Epoch: 5/5
Train Loss: 0.2488, Train Acc: 82.69%
Val Loss: 0.2646, Val Acc: 82.38%
--------------------------------------------------


0,1
batch_accuracy,▁▂▂▃▄▅▅▆▆▆▇▇▇▇▇▇▇▇██████████████████████
batch_loss,▇▆█▆▅▅█▄▅▇▆▅▅▅▅▄▃▂▄▄▄▅▂▅▄▃▃▃▄▃▄▄▃▄▂▄▃▄▂▁
epoch,▁▃▅▆█
train_accuracy,▁▆▇▇█
train_loss,█▃▂▁▁
val_accuracy,▁▅▇▇█
val_loss,█▄▃▂▁

0,1
batch_accuracy,82.686
batch_loss,0.22207
epoch,4.0
train_accuracy,82.686
train_loss,0.24878
val_accuracy,82.38
val_loss,0.26457


[34m[1mwandb[0m: Agent Starting Run: 82i29xp0 with config:
[34m[1mwandb[0m: 	batch_size: 136
[34m[1mwandb[0m: 	dropout: 0.5
[34m[1mwandb[0m: 	epochs: 5
[34m[1mwandb[0m: 	learning_rate: 0.001895754175963262
[34m[1mwandb[0m: 	loss_function: focal
[34m[1mwandb[0m: 	model_name: resnet18
[34m[1mwandb[0m: 	optimizer: adam


Files already downloaded and verified
Files already downloaded and verified
Model saved with loss: 0.4712
Epoch: 1/5
Train Loss: 0.7308, Train Acc: 59.92%
Val Loss: 0.4712, Val Acc: 71.53%
--------------------------------------------------
Model saved with loss: 0.3709
Epoch: 2/5
Train Loss: 0.4208, Train Acc: 73.89%
Val Loss: 0.3709, Val Acc: 76.25%
--------------------------------------------------
Model saved with loss: 0.3218
Epoch: 3/5
Train Loss: 0.3488, Train Acc: 77.09%
Val Loss: 0.3218, Val Acc: 79.02%
--------------------------------------------------
Model not saved, best loss: 0.3218
Epoch: 4/5
Train Loss: 0.3084, Train Acc: 79.50%
Val Loss: 0.3440, Val Acc: 78.43%
--------------------------------------------------
Model saved with loss: 0.3073
Epoch: 5/5
Train Loss: 0.2818, Train Acc: 80.73%
Val Loss: 0.3073, Val Acc: 79.90%
--------------------------------------------------


0,1
batch_accuracy,▁▃▃▄▄▅▅▅▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇███████████████
batch_loss,▇█▅▅▅▄▄▃▂▃▂▃▃▃▃▃▃▂▂▂▂▂▂▂▃▂▂▃▃▂▂▁▂▂▂▃▃▁▁▂
epoch,▁▃▅▆█
train_accuracy,▁▆▇██
train_loss,█▃▂▁▁
val_accuracy,▁▅▇▇█
val_loss,█▄▂▃▁

0,1
batch_accuracy,80.732
batch_loss,0.29803
epoch,4.0
train_accuracy,80.732
train_loss,0.28183
val_accuracy,79.9
val_loss,0.30726


[34m[1mwandb[0m: Agent Starting Run: 5600gpn4 with config:
[34m[1mwandb[0m: 	batch_size: 248
[34m[1mwandb[0m: 	dropout: 0.5
[34m[1mwandb[0m: 	epochs: 5
[34m[1mwandb[0m: 	learning_rate: 0.0036500270066907262
[34m[1mwandb[0m: 	loss_function: cross_entropy
[34m[1mwandb[0m: 	model_name: vit_tiny_patch16_224
[34m[1mwandb[0m: 	optimizer: adamw
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011114527700313678, max=1.0…

Files already downloaded and verified
Files already downloaded and verified
Model saved with loss: 1.7473
Epoch: 1/5
Train Loss: 2.0911, Train Acc: 22.25%
Val Loss: 1.7473, Val Acc: 32.07%
--------------------------------------------------
Model saved with loss: 1.5896
Epoch: 2/5
Train Loss: 1.6957, Train Acc: 36.81%
Val Loss: 1.5896, Val Acc: 41.82%
--------------------------------------------------
Model saved with loss: 1.5207
Epoch: 3/5
Train Loss: 1.5853, Train Acc: 41.54%
Val Loss: 1.5207, Val Acc: 44.14%
--------------------------------------------------
Model saved with loss: 1.4995
Epoch: 4/5
Train Loss: 1.5365, Train Acc: 43.83%
Val Loss: 1.4995, Val Acc: 44.39%
--------------------------------------------------
Model saved with loss: 1.4267
Epoch: 5/5
Train Loss: 1.5207, Train Acc: 44.23%
Val Loss: 1.4267, Val Acc: 47.07%
--------------------------------------------------


0,1
batch_accuracy,▁▁▁▂▂▃▃▅▆▆▆▆▇▇▇▇▇▇▇▇████████████████████
batch_loss,█▇▆▄▄▅▄▄▄▃▃▃▃▃▃▂▁▂▁▂▂▂▂▂▂▂▂▂▁▁▂▂▂▁▁▂▁▂▂▁
epoch,▁▃▅▆█
train_accuracy,▁▆▇██
train_loss,█▃▂▁▁
val_accuracy,▁▆▇▇█
val_loss,█▅▃▃▁

0,1
batch_accuracy,44.234
batch_loss,1.62486
epoch,4.0
train_accuracy,44.234
train_loss,1.52065
val_accuracy,47.07
val_loss,1.42666


[34m[1mwandb[0m: Agent Starting Run: nyqeq6hd with config:
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	dropout: 0.3
[34m[1mwandb[0m: 	epochs: 5
[34m[1mwandb[0m: 	learning_rate: 0.00017879501327875178
[34m[1mwandb[0m: 	loss_function: cross_entropy
[34m[1mwandb[0m: 	model_name: mobilenetv3_small_100
[34m[1mwandb[0m: 	optimizer: adam


Files already downloaded and verified
Files already downloaded and verified
Model saved with loss: 1.6866
Epoch: 1/5
Train Loss: 4.0242, Train Acc: 28.18%
Val Loss: 1.6866, Val Acc: 39.56%
--------------------------------------------------
Model saved with loss: 1.4826
Epoch: 2/5
Train Loss: 1.6050, Train Acc: 42.07%
Val Loss: 1.4826, Val Acc: 46.73%
--------------------------------------------------
Model saved with loss: 1.3095
Epoch: 3/5
Train Loss: 1.4250, Train Acc: 48.58%
Val Loss: 1.3095, Val Acc: 53.02%
--------------------------------------------------
Model saved with loss: 1.1994
Epoch: 4/5
Train Loss: 1.3258, Train Acc: 52.21%
Val Loss: 1.1994, Val Acc: 56.60%
--------------------------------------------------
Model saved with loss: 1.1215
Epoch: 5/5
Train Loss: 1.2217, Train Acc: 56.28%
Val Loss: 1.1215, Val Acc: 61.02%
--------------------------------------------------


0,1
batch_accuracy,▁▂▂▂▂▃▃▃▃▃▅▅▅▅▅▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇████████
batch_loss,█▂▃▃▃▃▂▃▂▂▂▂▂▂▂▂▂▂▂▁▁▂▂▂▁▂▁▂▂▁▁▂▁▁▂▂▁▁▂▂
epoch,▁▃▅▆█
train_accuracy,▁▄▆▇█
train_loss,█▂▂▁▁
val_accuracy,▁▃▅▇█
val_loss,█▅▃▂▁

0,1
batch_accuracy,56.278
batch_loss,1.1372
epoch,4.0
train_accuracy,56.278
train_loss,1.22168
val_accuracy,61.02
val_loss,1.12152


[34m[1mwandb[0m: Agent Starting Run: bvdstu0i with config:
[34m[1mwandb[0m: 	batch_size: 184
[34m[1mwandb[0m: 	dropout: 0.4
[34m[1mwandb[0m: 	epochs: 5
[34m[1mwandb[0m: 	learning_rate: 0.00024948928609067273
[34m[1mwandb[0m: 	loss_function: cross_entropy
[34m[1mwandb[0m: 	model_name: resnet18
[34m[1mwandb[0m: 	optimizer: adamw


Files already downloaded and verified
Files already downloaded and verified
Model saved with loss: 1.2148
Epoch: 1/5
Train Loss: 1.7240, Train Acc: 40.36%
Val Loss: 1.2148, Val Acc: 57.65%
--------------------------------------------------
Model saved with loss: 0.9461
Epoch: 2/5
Train Loss: 1.0850, Train Acc: 62.01%
Val Loss: 0.9461, Val Acc: 66.85%
--------------------------------------------------
Model saved with loss: 0.8021
Epoch: 3/5
Train Loss: 0.8947, Train Acc: 68.73%
Val Loss: 0.8021, Val Acc: 72.07%
--------------------------------------------------
Model saved with loss: 0.7198
Epoch: 4/5
Train Loss: 0.7883, Train Acc: 72.41%
Val Loss: 0.7198, Val Acc: 75.16%
--------------------------------------------------
Model saved with loss: 0.6652
Epoch: 5/5
Train Loss: 0.7198, Train Acc: 74.73%
Val Loss: 0.6652, Val Acc: 76.52%
--------------------------------------------------


0,1
batch_accuracy,▁▁▂▂▂▄▄▄▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇███████████████
batch_loss,██▇▇▆▅▅▄▄▄▄▄▃▃▃▃▃▃▂▃▃▂▃▂▃▂▂▂▂▂▂▁▁▂▂▂▁▂▁▁
epoch,▁▃▅▆█
train_accuracy,▁▅▇██
train_loss,█▄▂▁▁
val_accuracy,▁▄▆▇█
val_loss,█▅▃▂▁

0,1
batch_accuracy,74.726
batch_loss,0.82659
epoch,4.0
train_accuracy,74.726
train_loss,0.71983
val_accuracy,76.52
val_loss,0.66521
