# MobileNetV2 CIFAR-10 Training (DDP)

In [None]:
!pip install accelerate -q
print("Installed")

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import os
from collections import OrderedDict
from tqdm.auto import tqdm
from accelerate import Accelerator, notebook_launcher

# ===================== MobileNetV2 =====================
class Block(nn.Module):
    """Inverted-residual unit: 1x1 expansion, 3x3 depthwise conv, 1x1 linear projection.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        expansion: Factor applied to ``in_planes`` to get the hidden channel count.
        stride: Stride of the depthwise 3x3 convolution.

    A residual connection is added only when ``stride == 1``. In that case the
    shortcut is the identity if the channel counts match, otherwise a 1x1
    convolution followed by batch norm.
    """

    def __init__(self, in_planes, out_planes, expansion, stride):
        super().__init__()
        self.stride = stride
        hidden = expansion * in_planes

        # Pointwise expansion to the hidden width.
        self.conv1 = nn.Conv2d(in_planes, hidden, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(hidden)
        # Depthwise 3x3: groups == channels, and this is the layer that carries the stride.
        self.conv2 = nn.Conv2d(hidden, hidden, kernel_size=3, stride=stride, padding=1,
                               groups=hidden, bias=False)
        self.bn2 = nn.BatchNorm2d(hidden)
        # Pointwise projection; forward() applies no activation after it.
        self.conv3 = nn.Conv2d(hidden, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        needs_projection = stride == 1 and in_planes != out_planes
        if needs_projection:
            # Channel counts differ, so the skip path needs its own 1x1 conv + BN.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            # An empty Sequential acts as the identity.
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Apply the block to ``x`` and return the resulting feature map."""
        h = F.relu(self.bn1(self.conv1(x)))
        h = F.relu(self.bn2(self.conv2(h)))
        h = self.bn3(self.conv3(h))
        if self.stride != 1:
            # Strided blocks have no residual connection.
            return h
        return h + self.shortcut(x)

class MobileNetV2(nn.Module):
    """MobileNetV2 classifier built from a stack of ``Block`` units.

    Args:
        num_classes: Size of the final linear layer's output.

    The stem convolution and the first two stages use stride 1, and three stages
    use stride 2. A 32x32 input therefore reaches the head as a 4x4 feature map,
    which the fixed 4x4 average pool in ``forward`` reduces to 1x1.
    """

    # One tuple per stage: (expansion, out_planes, num_blocks, stride of the first block).
    cfg = [(1,16,1,1), (6,24,2,1), (6,32,3,2), (6,64,4,2), (6,96,3,1), (6,160,3,2), (6,320,1,1)]

    def __init__(self, num_classes=10):
        super().__init__()
        # Stem: 3 -> 32 channels, stride 1.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.layers = self._make_layers(32)
        # Head: widen 320 -> 1280 channels before pooling and classification.
        self.conv2 = nn.Conv2d(320, 1280, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(1280)
        self.linear = nn.Linear(1280, num_classes)

    def _make_layers(self, in_planes):
        """Build the ``Block`` stack described by ``cfg``, starting from ``in_planes`` channels."""
        blocks = []
        channels = in_planes
        for expansion, out_planes, repeats, first_stride in self.cfg:
            # Only the first block of a stage uses the configured stride; the rest use 1.
            strides = [first_stride] + [1] * (repeats - 1)
            for block_stride in strides:
                blocks.append(Block(channels, out_planes, expansion, block_stride))
                channels = out_planes
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        feats = F.relu(self.bn1(self.conv1(x)))
        feats = self.layers(feats)
        feats = F.relu(self.bn2(self.conv2(feats)))
        pooled = F.avg_pool2d(feats, 4)
        # Flatten to (batch, features) for the linear classifier.
        return self.linear(pooled.view(feats.size(0), -1))

# ===================== Platform Detection & Config =====================
def detect_platform():
    """Return ``'kaggle'``, ``'colab'`` or ``'local'`` for the current runtime.

    Kaggle is recognised by the ``KAGGLE_KERNEL_RUN_TYPE`` environment variable,
    Colab by ``google.colab`` being importable; anything else counts as local.
    """
    if os.environ.get('KAGGLE_KERNEL_RUN_TYPE') is not None:
        return 'kaggle'
    try:
        import google.colab  # noqa: F401 - imported only to probe availability
    except ImportError:
        return 'local'
    return 'colab'

def find_checkpoint(search_dir, filename):
    """Search ``search_dir`` recursively for a file named ``filename``.

    Returns the path of the first match encountered by ``os.walk``, or ``None``
    when ``search_dir`` does not exist or contains no such file.
    """
    if not os.path.exists(search_dir):
        return None
    matches = (
        os.path.join(folder, filename)
        for folder, _subdirs, names in os.walk(search_dir)
        if filename in names
    )
    return next(matches, None)

PLATFORM = detect_platform()

if PLATFORM == 'colab':
    # Colab: mount Google Drive and keep checkpoints there.
    from google.colab import drive
    drive.mount('/content/drive')
    CKPT_DIR = '/content/drive/MyDrive/cifar10_checkpoints'
    os.makedirs(CKPT_DIR, exist_ok=True)
    CONFIG = dict(
        lr=0.1,
        batch_size=256,
        total_epochs=200,
        epochs_per_run=200,
        num_workers=2,
        import_ckpt=f'{CKPT_DIR}/mobilenetv2_ckpt.pth',
        save_ckpt=f'{CKPT_DIR}/mobilenetv2_ckpt.pth',
    )
elif PLATFORM == 'kaggle':
    # Kaggle: read from /kaggle/input, write to /kaggle/working.
    # Locate a checkpoint automatically (present when a previous version's
    # output was attached as an input).
    import_ckpt = find_checkpoint('/kaggle/input', 'mobilenetv2_ckpt.pth')

    CONFIG = dict(
        lr=0.1,
        batch_size=256,
        total_epochs=200,
        epochs_per_run=200,
        num_workers=2,
        import_ckpt=import_ckpt,  # None means training starts from scratch
        save_ckpt='/kaggle/working/checkpoint/mobilenetv2_ckpt.pth',
    )
else:
    # Local: fixed checkpoint directory next to the notebook.
    CKPT_DIR = './checkpoints'
    os.makedirs(CKPT_DIR, exist_ok=True)
    CONFIG = dict(
        lr=0.1,
        batch_size=128,
        total_epochs=200,
        epochs_per_run=200,
        num_workers=2,
        import_ckpt=f'{CKPT_DIR}/mobilenetv2_ckpt.pth',
        save_ckpt=f'{CKPT_DIR}/mobilenetv2_ckpt.pth',
    )

# Summarise the resolved settings for this run.
print(f"Platform: {PLATFORM}")
print(f"Import: {CONFIG['import_ckpt']}")
print(f"Save: {CONFIG['save_ckpt']}")

In [None]:
def training_function():
    """Train MobileNetV2 on CIFAR-10 under Hugging Face Accelerate (runs once per process).

    Hyper-parameters and checkpoint paths come from the module-level ``CONFIG``.
    If ``CONFIG['import_ckpt']`` names an existing file, training resumes from it.
    Training runs up to the next ``epochs_per_run`` boundary, capped at
    ``total_epochs``. After every epoch the main process writes a resumable
    checkpoint to ``CONFIG['save_ckpt']`` and, when test accuracy improves, a
    model-only ``*_best`` checkpoint next to it.
    """
    # The LR schedule here is stepped once per epoch. With Accelerate's default
    # (step_scheduler_with_optimizer=True) a prepared scheduler is advanced
    # num_processes times per step() call and skipped whenever the last fp16
    # optimizer step overflowed, which distorts an epoch-level cosine schedule.
    accelerator = Accelerator(mixed_precision='fp16', step_scheduler_with_optimizer=False)

    # Data: only the main process downloads; the others wait for it to finish.
    if accelerator.is_main_process:
        torchvision.datasets.CIFAR10(root='./data', train=True, download=True)
        torchvision.datasets.CIFAR10(root='./data', train=False, download=True)
    accelerator.wait_for_everyone()

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(),
        transforms.ToTensor(), transforms.Normalize((0.4914,0.4822,0.4465), (0.2023,0.1994,0.2010))])
    transform_test = transforms.Compose([
        transforms.ToTensor(), transforms.Normalize((0.4914,0.4822,0.4465), (0.2023,0.1994,0.2010))])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform_train)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=transform_test)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=CONFIG['batch_size'], shuffle=True, num_workers=CONFIG['num_workers'], pin_memory=True)
    testloader = torch.utils.data.DataLoader(testset, batch_size=CONFIG['batch_size'], shuffle=False, num_workers=CONFIG['num_workers'], pin_memory=True)

    # Model
    net = MobileNetV2()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=CONFIG['lr'], momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CONFIG['total_epochs'])

    best_acc, start_epoch = 0, 0

    # Load checkpoint (import_ckpt may be None, e.g. when no previous output was found).
    ckpt_path = CONFIG['import_ckpt']
    if ckpt_path and os.path.exists(ckpt_path):
        ckpt = torch.load(ckpt_path, map_location='cpu')
        state_dict = ckpt['net']
        # Weights saved from a wrapped model carry a leading 'module.' on every key.
        # Strip only that leading prefix, never occurrences elsewhere in a key.
        prefix = 'module.'
        if state_dict and next(iter(state_dict)).startswith(prefix):
            state_dict = OrderedDict(
                (k[len(prefix):] if k.startswith(prefix) else k, v)
                for k, v in state_dict.items())
        net.load_state_dict(state_dict)
        best_acc, start_epoch = ckpt.get('best_acc', 0), ckpt['epoch'] + 1
        # Model-only checkpoints (such as the *_best file) have neither entry.
        if 'optimizer' in ckpt:
            optimizer.load_state_dict(ckpt['optimizer'])
        if 'scheduler' in ckpt:
            scheduler.load_state_dict(ckpt['scheduler'])
        accelerator.print(f'Loaded checkpoint (epoch {ckpt["epoch"]}, best_acc {best_acc:.2f}%)')
    else:
        accelerator.print('No checkpoint found, starting from scratch')

    # DDP prepare
    net, optimizer, trainloader, testloader, scheduler = accelerator.prepare(
        net, optimizer, trainloader, testloader, scheduler)

    accelerator.print(f'Device: {accelerator.device}, Num GPUs: {accelerator.num_processes}')

    # Checkpoint paths. The best-model file sits next to the resumable one with a
    # '_best' suffix inserted before the extension.
    save_path = CONFIG['save_ckpt']
    save_root, save_ext = os.path.splitext(save_path)
    best_ckpt_path = f'{save_root}_best{save_ext}'
    if accelerator.is_main_process:
        save_dir = os.path.dirname(save_path)
        if save_dir:  # a bare filename has no directory to create
            os.makedirs(save_dir, exist_ok=True)

    # Training proceeds in chunks of epochs_per_run epochs.
    step = CONFIG['epochs_per_run']
    end_epoch = min(((start_epoch // step) + 1) * step, CONFIG['total_epochs'])
    accelerator.print(f'Training epochs {start_epoch} ~ {end_epoch-1}')

    for epoch in range(start_epoch, end_epoch):
        # Train. The accuracy counted here covers only this process's share of the data.
        net.train()
        correct, total = 0, 0
        for inputs, targets in tqdm(trainloader, disable=not accelerator.is_main_process):
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            accelerator.backward(loss)
            optimizer.step()
            correct += outputs.argmax(1).eq(targets).sum().item()
            total += targets.size(0)
        train_acc = 100. * correct / total

        # Test. Predictions are gathered from all processes so every rank scores the
        # whole test set and agrees on test_acc / best_acc. gather_for_metrics also
        # drops the samples duplicated to even out the last batch.
        net.eval()
        test_correct, test_total = 0, 0
        with torch.no_grad():
            for inputs, targets in testloader:
                preds = net(inputs).argmax(1)
                preds, targets = accelerator.gather_for_metrics((preds, targets))
                test_correct += preds.eq(targets).sum().item()
                test_total += targets.size(0)
        test_acc = 100. * test_correct / test_total

        scheduler.step()

        improved = test_acc > best_acc

        # Save checkpoints (main process only).
        if accelerator.is_main_process:
            model_state = accelerator.unwrap_model(net).state_dict()

            # Best model is stored separately, weights only.
            if improved:
                torch.save({'net': model_state, 'acc': test_acc, 'epoch': epoch}, best_ckpt_path)

            # Last checkpoint, used for resuming.
            torch.save({'net': model_state,
                        'optimizer': optimizer.state_dict(), 'scheduler': scheduler.state_dict(),
                        'acc': test_acc, 'best_acc': max(best_acc, test_acc), 'epoch': epoch}, save_path)

        if improved:
            best_acc = test_acc

        accelerator.print(f'Epoch {epoch}: Train {train_acc:.2f}%, Test {test_acc:.2f}%, Best {best_acc:.2f}%')

    accelerator.print(f'Done! Best: {best_acc:.2f}%')
print("Ready")

In [None]:
# Launch DDP. GPUs are counted with nvidia-smi instead of torch.cuda so that CUDA
# is not initialized in the notebook process before the workers are started.
import subprocess

try:
    # check=True turns a failing nvidia-smi (e.g. no driver) into an exception
    # instead of letting its error text be counted as devices.
    result = subprocess.run(['nvidia-smi', '-L'], capture_output=True, text=True, check=True)
except (OSError, subprocess.SubprocessError):
    # nvidia-smi is missing or failed: fall back to a single process.
    num_gpus = 1
else:
    # `nvidia-smi -L` prints one "GPU <index>: ..." line per device. Count only
    # those lines, and never launch with fewer than one process.
    num_gpus = sum(1 for line in result.stdout.splitlines() if line.startswith('GPU '))
    num_gpus = max(num_gpus, 1)
print(f'Detected GPUs: {num_gpus}')
notebook_launcher(training_function, num_processes=num_gpus)