In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
from torch import optim
import torch.nn.init as init
from torch.autograd import Variable

import sys
import numpy as np

from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchvision import models
import os

from torchvision import utils
import matplotlib.pyplot as plt

import time
import copy
import wandb
import random
import math

import tarfile
import pickle
from torchvision import transforms
import torchvision

from PIL import Image
from torch.utils.data import random_split
from tqdm import tqdm

#### Seed setting

In [5]:
def set_random_seed(seed):
    """Seed the random number generators used in this notebook.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch.
    When CUDA is available the CUDA generators are seeded as well (both
    ``manual_seed`` and ``manual_seed_all``). Finally cuDNN is switched to
    deterministic mode and its benchmark mode is turned off.

    Args:
        seed: Seed value passed unchanged to every generator.
    """
    # CPU-side generators share the same one-argument calling convention.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_random_seed(327)

#### Data loading/setting

In [None]:
class CIFAR100Dataset(Dataset):
    """Map-style dataset over in-memory image rows and their labels.

    Every entry of ``images`` is reshaped to ``(3, 32, 32)``, transposed to
    channel-last order and converted to a PIL image before the optional
    ``transform`` is applied.
    # NOTE(review): presumably rows are flat uint8 arrays of length 3072
    # (the pickled CIFAR layout) - confirm against the caller.

    Args:
        images: Indexable collection of image arrays.
        labels: Indexable collection of labels, aligned with ``images``.
        transform: Optional callable applied to the PIL image.
    """

    def __init__(self, images, labels, transform=None):
        self.images, self.labels, self.transform = images, labels, transform

    def __len__(self):
        """Number of samples, taken from ``images``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for position ``idx``."""
        channel_last = self.images[idx].reshape(3, 32, 32).transpose(1, 2, 0)
        label = self.labels[idx]
        picture = Image.fromarray(channel_last)
        return (self.transform(picture) if self.transform else picture), label
    
# Per-channel normalisation statistics shared by both pipelines.
# NOTE(review): these look like the commonly quoted CIFAR-100 mean/std; they
# are not recomputed in this notebook - confirm if the dataset changes.
CIFAR100_MEAN = (0.5071, 0.4867, 0.4408)
CIFAR100_STD = (0.2675, 0.2565, 0.2761)
BATCH_SIZE = 128

# Training pipeline: random horizontal flip, then a 32x32 random crop taken
# from the image reflect-padded by 4 pixels, then tensor conversion and
# normalisation.
transformtrain = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_MEAN, CIFAR100_STD),
    ]
)

# Evaluation pipeline: no augmentation, only tensor conversion + normalisation.
transformtest = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_MEAN, CIFAR100_STD),
    ]
)

# Both splits are stored under ./data and downloaded on first use.
train_data = torchvision.datasets.CIFAR100(
    root='./data', train=True, download=True, transform=transformtrain
)
test_data = torchvision.datasets.CIFAR100(
    root='./data', train=False, download=True, transform=transformtest
)

# Only the training loader shuffles.
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

#### Model definition (ShakeDrop + ShakePyramidNet)

In [14]:
class ShakeDropFunction(torch.autograd.Function):
    """Custom autograd function implementing the ShakeDrop perturbation.

    Training forward pass: a single Bernoulli gate is drawn per call. With
    probability ``1 - p_drop`` the input passes through unchanged; otherwise
    every sample in the batch is multiplied by its own ``alpha`` drawn
    uniformly from ``alpha_range``.

    Training backward pass: when the forward pass perturbed the input, the
    incoming gradient is multiplied per sample by a fresh ``beta`` drawn
    uniformly from ``[0, 1)``; otherwise the gradient passes through.

    Evaluation: the input (and, on backward, the gradient) is scaled by
    ``1 - p_drop``.

    All random tensors are created with the device and dtype of the tensor
    they multiply, so the function works on CPU as well as CUDA and for
    inputs of any rank >= 1 (the first dimension is treated as the batch).
    """

    @staticmethod
    def forward(ctx, x, training=True, p_drop=0.5, alpha_range=[-1, 1]):
        """Apply ShakeDrop to ``x``.

        Args:
            ctx: Autograd context; stores what ``backward`` needs.
            x: Input tensor whose first dimension is the batch.
            training: ``True`` for the stochastic training behaviour,
                ``False`` for deterministic scaling.
            p_drop: Probability of perturbing the input during training.
            alpha_range: ``(low, high)`` bounds for the forward scaling
                factor. The default list is only read, never mutated.

        Returns:
            Tensor with the same shape as ``x``.
        """
        if not training:
            # Remember the scale so that backward also works after an
            # eval-mode forward (nothing was recorded for it before).
            ctx.shaken = False
            ctx.grad_scale = 1 - p_drop
            return (1 - p_drop) * x

        ctx.grad_scale = None
        # The gate is a single scalar decision, so it is drawn on the CPU:
        # no device synchronisation is needed to branch on it.
        passes_through = torch.empty(1).bernoulli_(1 - p_drop).item() != 0
        ctx.shaken = not passes_through
        if passes_through:
            return x

        # One alpha per sample, broadcast over all remaining dimensions.
        alpha = x.new_empty(x.size(0)).uniform_(*alpha_range)
        alpha = alpha.view(-1, *([1] * (x.dim() - 1)))
        return alpha * x

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient w.r.t. ``x``; the other inputs get ``None``."""
        if ctx.shaken:
            # One beta per sample, independent of the forward alpha.
            beta = grad_output.new_empty(grad_output.size(0)).uniform_(0, 1)
            beta = beta.view(-1, *([1] * (grad_output.dim() - 1)))
            grad_input = beta * grad_output
        elif ctx.grad_scale is not None:
            # Eval-mode forward was a plain scaling; so is its gradient.
            grad_input = ctx.grad_scale * grad_output
        else:
            grad_input = grad_output
        return grad_input, None, None, None

class ShakeDrop(nn.Module):
    """``nn.Module`` wrapper that calls ``ShakeDropFunction`` with stored settings.

    Args:
        p_drop: Forwarded to ``ShakeDropFunction`` as ``p_drop``.
        alpha_range: Forwarded to ``ShakeDropFunction`` as ``alpha_range``.
    """

    def __init__(self, p_drop=0.5, alpha_range=[-1, 1]):
        super().__init__()
        self.alpha_range = alpha_range
        self.p_drop = p_drop

    def forward(self, x):
        """Run ``ShakeDropFunction`` on ``x`` using the module's train/eval flag."""
        call_args = (x, self.training, self.p_drop, self.alpha_range)
        return ShakeDropFunction.apply(*call_args)

# ShakePyramidNet implementation
class ShakeBasicBlock(nn.Module):
    """Residual block: BN-Conv-BN-ReLU-Conv-BN branch with ShakeDrop.

    The shortcut is the identity, or a 2x2 average pool when ``stride == 2``.
    Because the branch may output more channels than the shortcut carries,
    the shortcut is zero-padded along the channel dimension before the sum.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of channels produced by the branch.
        stride: Stride of the first convolution; ``2`` enables downsampling.
        p_shakedrop: ``p_drop`` passed to the block's ``ShakeDrop`` module.
    """

    def __init__(self, in_ch, out_ch, stride=1, p_shakedrop=1.0):
        super(ShakeBasicBlock, self).__init__()
        self.downsampled = stride == 2
        self.branch = self._make_branch(in_ch, out_ch, stride=stride)
        self.shortcut = None if not self.downsampled else nn.AvgPool2d(2)
        self.shake_drop = ShakeDrop(p_shakedrop)

    def forward(self, x):
        """Return ``shake_drop(branch(x))`` plus the channel-padded shortcut."""
        h = self.branch(x)
        h = self.shake_drop(h)
        h0 = x if not self.downsampled else self.shortcut(x)
        # Allocate the padding directly with the shortcut's device and dtype,
        # so the block runs on CPU or any GPU and stays consistent under
        # reduced precision (no CPU allocation + .cuda() copy per call).
        pad_zero = h0.new_zeros(h0.size(0), h.size(1) - h0.size(1), h0.size(2), h0.size(3))
        h0 = torch.cat([h0, pad_zero], dim=1)
        return h + h0

    def _make_branch(self, in_ch, out_ch, stride=1):
        """Build the residual branch; only the first convolution is strided."""
        return nn.Sequential(
            nn.BatchNorm2d(in_ch),
            nn.Conv2d(in_ch, out_ch, 3, padding=1, stride=stride, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(out_ch))

class ShakePyramidNet(nn.Module):
    """PyramidNet-style classifier built from ``ShakeBasicBlock`` units.

    The network has three stages of ``(depth - 2) // 6`` blocks each. The
    channel count starts at 16 and grows by a total of ``alpha`` spread
    linearly (rounded up) over all blocks. Each block's ShakeDrop ``p_drop``
    grows linearly with block index up to 0.5 for the last block. The second
    and third stage start with a stride-2 block.

    Args:
        depth: Nominal depth; determines the number of blocks per stage.
        alpha: Total number of channels added across all blocks.
        label: Number of output classes.
    """

    def __init__(self, depth=110, alpha=270, label=100):
        super(ShakePyramidNet, self).__init__()
        in_ch = 16
        n_units = (depth - 2) // 6
        # Channel count before the first block and after each block.
        in_chs = [in_ch] + [in_ch + math.ceil((alpha / (3 * n_units)) * (i + 1)) for i in range(3 * n_units)]
        block = ShakeBasicBlock

        # u_idx is the running block index consumed by _make_layer.
        self.in_chs, self.u_idx = in_chs, 0
        # Kept in the original 1 - (1 - x) form so the floating-point values
        # (and therefore the random gates) are unchanged.
        self.ps_shakedrop = [1 - (1.0 - (0.5 / (3 * n_units)) * (i + 1)) for i in range(3 * n_units)]

        self.c_in = nn.Conv2d(3, in_chs[0], 3, padding=1)
        self.bn_in = nn.BatchNorm2d(in_chs[0])
        self.layer1 = self._make_layer(n_units, block, 1)
        self.layer2 = self._make_layer(n_units, block, 2)
        self.layer3 = self._make_layer(n_units, block, 2)
        self.bn_out = nn.BatchNorm2d(in_chs[-1])
        self.fc_out = nn.Linear(in_chs[-1], label)

        # Parameter initialisation.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Zero-mean normal with variance 2 / (k_h * k_w * out_channels).
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                nn.init.normal_(m.weight, mean=0.0, std=math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                # Only the bias is reset; the weight keeps its default init.
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Return class logits of shape ``(batch, label)`` for image batch ``x``."""
        h = self.bn_in(self.c_in(x))
        h = self.layer1(h)
        h = self.layer2(h)
        h = self.layer3(h)
        h = F.relu(self.bn_out(h))
        # Global average pooling. For 32x32 inputs the feature map here is
        # 8x8, so this equals the former fixed avg_pool2d(h, 8); unlike the
        # fixed window it also yields one value per channel for other sizes.
        h = F.adaptive_avg_pool2d(h, 1)
        h = h.view(h.size(0), -1)
        h = self.fc_out(h)
        return h

    def _make_layer(self, n_units, block, stride=1):
        """Build one stage of ``n_units`` blocks; only the first uses ``stride``."""
        layers = []
        for i in range(int(n_units)):
            layers.append(block(self.in_chs[self.u_idx], self.in_chs[self.u_idx+1], stride, self.ps_shakedrop[self.u_idx]))
            self.u_idx, stride = self.u_idx + 1, 1
        return nn.Sequential(*layers)

In [15]:
def accuracy_topk(output, target, topk=(1,)):
    """Compute top-k accuracy percentages for a batch.

    Args:
        output: Score tensor of shape ``(batch, classes)``.
        target: Tensor of ``batch`` ground-truth class indices.
        topk: Iterable of ``k`` values to evaluate.

    Returns:
        A list with one single-element float tensor per ``k``, holding the
        percentage of samples whose target is among the ``k`` highest scores.
    """
    largest_k = max(topk)
    n_samples = target.size(0)

    # (largest_k, batch) matrix of predicted class indices, best guess first.
    ranked = output.topk(largest_k, dim=1, largest=True, sorted=True)[1].t()
    hits = ranked.eq(target.view(1, -1).expand_as(ranked))

    return [
        hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / n_samples)
        for k in topk
    ]

def super_class_accuracy(output, target):
    """Percentage of samples whose predicted class falls in the target's super-class.

    The predicted class is the arg-max over dimension 1 of ``output``; both
    it and ``target`` are mapped through ``get_superclass`` before comparison.

    Args:
        output: Score tensor of shape ``(batch, classes)``.
        target: Tensor of ``batch`` ground-truth fine-class indices.

    Returns:
        Accuracy as a Python float in ``[0, 100]``.
    """
    _, predicted = torch.max(output, 1)

    # Convert each tensor to a Python list in one go; calling .item() on
    # every element costs one host synchronisation per sample on GPU tensors.
    pred_superclass = torch.tensor([get_superclass(p) for p in predicted.tolist()], dtype=torch.long)
    target_superclass = torch.tensor([get_superclass(t) for t in target.tolist()], dtype=torch.long)

    correct = (pred_superclass == target_superclass).sum().item()
    total = target.size(0)

    return 100.0 * correct / total

# Super-class index (0-19) -> list of the five fine-class indices it groups.
# The trailing comment on each row names the super-class.
superclass_mapping = {
    0: [4, 30, 55, 72, 95],   # aquatic mammals
    1: [1, 32, 67, 73, 91],   # fish
    2: [54, 62, 70, 82, 92],  # flowers
    3: [9, 10, 16, 28, 61],   # food containers
    4: [0, 51, 53, 57, 83],   # fruit and vegetables
    5: [22, 39, 40, 86, 87],  # household electrical devices
    6: [5, 20, 25, 84, 94],   # household furniture
    7: [6, 7, 14, 18, 24],    # insects
    8: [3, 42, 43, 88, 97],   # large carnivores
    9: [12, 17, 37, 68, 76],  # large man-made outdoor things
    10: [23, 33, 49, 60, 71], # large natural outdoor scenes
    11: [15, 19, 21, 31, 38], # large omnivores and herbivores
    12: [34, 63, 64, 66, 75], # medium-sized mammals
    13: [26, 45, 77, 79, 99], # non-insect invertebrates
    14: [2, 11, 35, 46, 98],  # people
    15: [27, 29, 44, 78, 93], # reptiles
    16: [36, 50, 65, 74, 80], # small mammals
    17: [47, 52, 56, 59, 96], # trees
    18: [8, 13, 48, 58, 90],  # vehicles 1
    19: [41, 69, 81, 85, 89], # vehicles 2
}

def get_superclass(label):
    """Return the super-class index whose member list contains ``label``.

    ``superclass_mapping`` is scanned in insertion order and the first match
    wins. ``None`` is returned when no member list contains ``label``.
    """
    candidates = (
        coarse
        for coarse, members in superclass_mapping.items()
        if label in members
    )
    return next(candidates, None)

In [16]:
def cutmix_data(x, y, alpha=1.0):
    """Apply CutMix to a batch.

    A rectangle chosen by ``rand_bbox`` is overwritten, in every sample, with
    the same rectangle taken from another sample of the batch (selected by a
    random permutation).

    Note: ``x`` is modified IN PLACE and also returned.

    Args:
        x: Image batch; dimensions 2 and 3 are the spatial dimensions.
        y: Labels aligned with ``x``.
        alpha: Both shape parameters of the Beta distribution that proposes
            the mixing ratio.

    Returns:
        ``(x, target_a, target_b, lam)`` where ``target_a`` are the original
        labels, ``target_b`` the labels of the pasted samples and ``lam`` (a
        Python float) the fraction of each image that still belongs to
        ``target_a``.
    """
    lam = np.random.beta(alpha, alpha)
    rand_index = torch.randperm(x.size()[0]).to(x.device)
    target_a = y
    target_b = y[rand_index]
    bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam)
    # The right-hand side is materialised before assignment (advanced
    # indexing copies), so reading from and writing to x is safe here.
    x[:, :, bbx1:bbx2, bby1:bby2] = x[rand_index, :, bbx1:bbx2, bby1:bby2]
    # The box is truncated to integers and clipped at the image border, so
    # its real area generally differs from the proposed (1 - lam). Recompute
    # lam from the pixels actually replaced so the loss weights match.
    pasted_area = int(bbx2 - bbx1) * int(bby2 - bby1)
    lam = 1.0 - pasted_area / float(x.size()[-1] * x.size()[-2])
    return x, target_a, target_b, lam

def rand_bbox(size, lam):
    """Pick a random box covering roughly ``1 - lam`` of the spatial area.

    Args:
        size: Shape sequence; entries 2 and 3 are the two spatial extents.
        lam: Mixing ratio; the box side ratio is ``sqrt(1 - lam)``.

    Returns:
        ``(start_2, start_3, stop_2, stop_3)``: start/stop indices along
        dimension 2 and dimension 3, each clipped to ``[0, extent]``.
    """
    extent_a, extent_b = size[2], size[3]
    side_ratio = np.sqrt(1. - lam)
    half_a = int(extent_a * side_ratio) // 2
    half_b = int(extent_b * side_ratio) // 2

    # Draw order (dimension 2 first, then dimension 3) is kept so seeded
    # runs produce the same boxes.
    centre_a = np.random.randint(extent_a)
    centre_b = np.random.randint(extent_b)

    start_a, stop_a = (np.clip(centre_a + shift, 0, extent_a) for shift in (-half_a, half_a))
    start_b, stop_b = (np.clip(centre_b + shift, 0, extent_b) for shift in (-half_b, half_b))
    return start_a, start_b, stop_a, stop_b

def mixup_data(x, y, alpha=1.0):
    """Blend every sample with another sample of the same batch (mixup).

    Args:
        x: Input batch with at least two dimensions.
        y: Labels aligned with ``x``.
        alpha: Both shape parameters of the Beta distribution giving ``lam``.

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` with
        ``mixed_x = lam * x + (1 - lam) * x[perm]``, ``y_a = y`` and
        ``y_b = y[perm]`` for a random permutation ``perm`` of the batch.
    """
    lam = np.random.beta(alpha, alpha)
    perm = torch.randperm(x.size()[0]).to(x.device)
    partner = x[perm, :]
    blended = lam * x + (1 - lam) * partner
    return blended, y, y[perm], lam

In [17]:
def train(model, train_loader, optimizer, criterion, device, use_cutmix=False, use_mixup=False):
    """Run one training epoch.

    Args:
        model: Network to optimise; put into training mode here.
        train_loader: Iterable of ``(inputs, labels)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss callable ``criterion(outputs, targets)``.
        device: Device the batches are moved to.
        use_cutmix: Apply ``cutmix_data`` to every batch (takes precedence
            over ``use_mixup``).
        use_mixup: Apply ``mixup_data`` to every batch when CutMix is off.

    Returns:
        ``(epoch_loss, epoch_acc)``: the mean of the per-batch losses and the
        accuracy in percent. Accuracy is always measured against the
        original ``labels``, also when the inputs were mixed.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for inputs, labels in tqdm(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        if use_cutmix or use_mixup:
            # Both augmentations share the same return contract, so the
            # mixed loss is computed in one place.
            mix_fn = cutmix_data if use_cutmix else mixup_data
            inputs, target_a, target_b, lam = mix_fn(inputs, labels)
            outputs = model(inputs)
            loss = lam * criterion(outputs, target_a) + (1 - lam) * criterion(outputs, target_b)
        else:
            outputs = model(inputs)
            loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        predicted = torch.max(outputs, 1)[1]
        n_seen += labels.size(0)
        n_correct += (predicted == labels).sum().item()

    return loss_sum / len(train_loader), 100. * n_correct / n_seen

In [18]:
def evaluate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Args:
        model: Network to evaluate; put into eval mode here.
        loader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss callable ``criterion(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        ``(loss, top1, top5, super_class)``: the mean of the per-batch
        losses, followed by top-1, top-5 and super-class accuracy, each in
        percent over all samples.
    """
    model.eval()
    loss_sum = 0.0
    n_seen = 0
    top1_hits = 0
    top5_hits = 0
    super_weighted = 0

    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)

            loss_sum += criterion(outputs, labels).item()
            n_seen += labels.size(0)
            batch_n = inputs.size(0)

            # The helpers return per-batch percentages; convert them back to
            # sample-weighted sums so the last, smaller batch is weighted
            # correctly.
            top1_pct, top5_pct = accuracy_topk(outputs, labels, topk=(1, 5))
            top1_hits += (top1_pct.item() * batch_n) / 100
            top5_hits += (top5_pct.item() * batch_n) / 100
            super_weighted += super_class_accuracy(outputs, labels) * batch_n

    return (
        loss_sum / len(loader),
        100.0 * top1_hits / n_seen,
        100.0 * top5_hits / n_seen,
        super_weighted / n_seen,
    )

In [None]:
def train_and_evaluate(model, criterion, optimizer, scheduler, num_epochs, train_loader, test_loader, device,
                       min_combined_accuracy=270):
    """Train with CutMix for ``num_epochs`` epochs, checkpointing the best model.

    After every epoch the model is evaluated on ``test_loader``. The
    "combined accuracy" is the sum of top-1, top-5 and super-class accuracy
    (each in percent, so at most 300). A checkpoint is written to the
    current working directory whenever the combined accuracy exceeds the
    best value so far; only the most recent best checkpoint is kept.

    # NOTE(review): ``test_loader`` drives both checkpoint selection and the
    # LR scheduler here - confirm this is intended rather than a held-out
    # validation split.

    Args:
        model: Network to train.
        criterion: Loss callable.
        optimizer: Optimizer used by ``train``.
        scheduler: Scheduler stepped with the top-1 accuracy each epoch.
        num_epochs: Number of epochs to run.
        train_loader: Training batches.
        test_loader: Evaluation batches.
        device: Device used for training and evaluation.
        min_combined_accuracy: Combined accuracy a model must exceed before
            the first checkpoint is written. Defaults to the previously
            hard-coded 270.
    """
    best_combined_accuracy = min_combined_accuracy
    best_model_path = None

    for epoch in range(num_epochs):

        # train() also returns the epoch's training loss/accuracy; they are
        # not reported by this loop.
        train(model, train_loader, optimizer, criterion, device, use_cutmix=True)
        val_loss, val_top1_acc, val_top5_acc, val_super_class_acc = evaluate(model, test_loader, criterion, device)
        combined_accuracy = val_top1_acc + val_top5_acc + val_super_class_acc

        print(f'Epoch {epoch + 1}/{num_epochs}')
        print(f'Test Loss: {val_loss:.4f}, Top-1 Accuracy: {val_top1_acc:.2f}%, Top-5 Accuracy: {val_top5_acc:.2f}%, Super-Class Accuracy: {val_super_class_acc:.2f}%')

        scheduler.step(val_top1_acc)

        if combined_accuracy > best_combined_accuracy:
            best_combined_accuracy = combined_accuracy

            previous_best_path = best_model_path
            best_model_path = f"pyramidnet_best_model_epoch_{epoch + 1}.pth"
            # Write the new checkpoint BEFORE deleting the old one, so a
            # failed save never leaves the run without any checkpoint.
            torch.save(model.state_dict(), best_model_path)

            if (previous_best_path
                    and previous_best_path != best_model_path
                    and os.path.exists(previous_best_path)):
                os.remove(previous_best_path)

            print(f"New best model found at Epoch {epoch + 1} with combined accuracy: {combined_accuracy:.2f}. Model saved.")


# Hyper-parameters for the run; keys are read below by name.
config = dict(
    epoch=200,
    lr=0.1,
    weight_decay=5e-4,
    momentum=0.9,
    nesterov=True,
    patience=10,
    factor=0.1,
    min_lr=1e-6,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ShakePyramidNet().to(device)
criterion = nn.CrossEntropyLoss()

# SGD settings are taken straight from the config.
sgd_kwargs = {key: config[key] for key in ('lr', 'weight_decay', 'momentum', 'nesterov')}
optimizer = optim.SGD(model.parameters(), **sgd_kwargs)

# mode='max' because the scheduler is stepped with an accuracy.
plateau_kwargs = {key: config[key] for key in ('factor', 'patience', 'min_lr')}
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', **plateau_kwargs)

train_and_evaluate(model, criterion, optimizer, scheduler, config['epoch'], train_loader, test_loader, device)