# 0. Import & Set Parameters

In [1]:
# 표준 라이브러리
import argparse
import datetime
import functools
import math
import os
import random
import time

# 서드파티 라이브러리
import numpy as np
from scipy.linalg import sqrtm
from tensorboardX import SummaryWriter
from tqdm import tqdm

# PyTorch 관련 라이브러리
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from torch.nn import init, utils
import torch.nn.functional as F
from torch.nn import Parameter as P, init, utils
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader
import torchvision
from torchvision import datasets, transforms
from torchvision.models import inception_v3
from torchvision.utils import save_image
import torchvision.transforms as transforms
import warnings

# 경고 메시지 비활성화
warnings.filterwarnings("ignore")

In [2]:
# Training information
# NOTE(review): TRAIN is not read by any code visible in this chunk;
# presumably consumed by an entry-point cell further down — confirm.
TRAIN = True
# Run name; Trainer joins it onto LOG_PATH, SAMPLE_PATH and MODEL_SAVE_PATH.
VERSION = 'cifar100_superclass'

# Model hyper-parameters
IMSIZE = 32
Z_DIM = 128
G_CONV_DIM = 128
D_CONV_DIM = 128
# 20 = number of superclasses produced by get_cifar100_superclass_loader.
N_CLASS = 20
# NOTE(review): GEN_DISTRIBUTION / GEN_BOTTOM_WIDTH are not read in the visible
# code; presumably forwarded to the generator when the model is built — confirm.
GEN_DISTRIBUTION = 'normal' 
GEN_BOTTOM_WIDTH = 4  
# SEED is only applied when FIX_SEED is True (see the seed_everything call below).
SEED = 46 
FIX_SEED = False # If you want to fix the seed, set this to True.

# Training setting
TOTAL_STEP = 1000000
BATCH_SIZE = 64
# NOTE(review): NUM_WORKERS and LR_DECAY are not read in the visible code — confirm usage.
NUM_WORKERS = 2
G_LR = 0.0002
D_LR = 0.0002
LR_DECAY = 0.95
BETA1 = 0.0
BETA2 = 0.999
# Trainer.train runs D_STEP discriminator updates, then G_STEP generator
# updates, and repeats that cycle.
D_STEP = 3
G_STEP = 1

# Misc
# CUDA device index; Trainer builds the device string f'cuda:{CUDA}'.
CUDA = 0

# Path
IMAGE_PATH = './data'
LOG_PATH = './logs'
MODEL_SAVE_PATH = './models'
SAMPLE_PATH = './samples'
# NOTE(review): FID_MEAN_COV is not read in the visible code — confirm usage.
FID_MEAN_COV = './datasetMoment/cifar100'

# Step size
# Intervals are expressed in training-loop iterations; Trainer.train checks
# them with `(step + 1) % <interval> == 0`.
LOG_STEP = 50
SAMPLE_STEP = 500
MODEL_SAVE_STEP = 5000
METRIC_CALCULATION_STEP = 2000
# NOTE(review): CALCULATE_FID / METRIC_IMAGES_NUM are not read in the visible code — confirm usage.
CALCULATE_FID = True
METRIC_IMAGES_NUM = 200

# Checkpoint
# When RESUME_TRAINING is True, Trainer.__init__ calls
# load_checkpoint(RESUME_CHECKPOINT); None means "use the latest checkpoint".
RESUME_TRAINING = False 
RESUME_CHECKPOINT = None

# ====================== Example ====================== #
# RESUME_TRAINING = True 
# RESUME_CHECKPOINT = 'models/cifar100_superclass/checkpoints/checkpoint_130000.pt'


def seed_everything(seed):
    """Seed every random number generator used here and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU, current CUDA device and
    all CUDA devices), exports ``PYTHONHASHSEED``, and switches cuDNN to
    deterministic mode with auto-tuning (benchmark) disabled.

    Args:
        seed: Integer seed applied to every generator.
    """
    # NOTE: the hash seed of the running interpreter is fixed at start-up, so
    # this export only influences processes launched afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

# Seed all RNGs only when reproducibility was requested through FIX_SEED;
# otherwise every run keeps its own random initialisation.
if FIX_SEED:
    seed_everything(SEED)

# 1. Utils

In [3]:
# Reporter
class Reporter:
    """Append-only plain-text run log with KST (UTC+9) timestamps.

    Each instance writes to ``<reportPath>.<YYYYmmddHHMMSS>`` so that every run
    gets its own file. Every written line is prefixed with a running 1-based
    entry index.

    Args:
        reportPath: Base path of the log file; the creation timestamp is
            appended to it. Missing parent directories are created.
    """

    def __init__(self, reportPath):
        self.withTimeStamp = False
        self.index = 1
        self.timeStrFormat = '%Y-%m-%d %H:%M:%S'

        # Fixed KST offset (UTC+9) used for every timestamp in the log.
        self.kst = datetime.timezone(datetime.timedelta(hours=9))

        # Suffix the file name with the creation time in KST.
        stamp = datetime.datetime.now(self.kst).strftime('%Y%m%d%H%M%S')
        self.path = "%s.%s" % (reportPath, stamp)

        # Make sure the target directory exists; otherwise the first open()
        # would raise FileNotFoundError on a fresh checkout.
        parent = os.path.dirname(self.path)
        if parent:
            os.makedirs(parent, exist_ok=True)

        # Create the file if it is missing without truncating an existing one.
        with open(self.path, 'a'):
            pass

    def _now(self):
        """Return the current KST time formatted with ``timeStrFormat``."""
        return datetime.datetime.now(self.kst).strftime(self.timeStrFormat)

    def _append(self, line):
        """Append one already-formatted line and advance the entry index."""
        with open(self.path, 'a+') as logf:
            logf.write(line)
        self.index += 1

    def writeInfo(self, strLine):
        """Write a timestamped ``[info]`` entry."""
        self._append("[%d]-[%s]-[info] %s\n" % (self.index, self._now(), strLine))

    def writeModel(self, modelText):
        """Write a ``[model]`` entry (no timestamp), e.g. a model's repr."""
        self._append("[%d]-[model] %s\n" % (self.index, modelText))

    def writeTrainLog(self, step, logText):
        """Write a timestamped ``[logInfo]`` entry tagged with the training step."""
        self._append("[%d]-[%s]-[logInfo]-[%d] %s\n"
                     % (self.index, self._now(), step, logText))

def denorm(x):
    """Map values from the [-1, 1] range to [0, 1], clamping anything outside."""
    rescaled = (x + 1) / 2
    # In-place clamp on the freshly created tensor; the input is left untouched.
    rescaled.clamp_(0, 1)
    return rescaled

def str2bool(v):
    """Return True only when *v* spells ``"true"`` (case-insensitive).

    The previous ``v.lower() in ('true')`` tested substring membership in the
    string ``'true'`` (the parentheses do not make a tuple), so inputs such as
    ``''``, ``'t'`` or ``'rue'`` were wrongly accepted.

    Args:
        v: String to interpret, e.g. a command-line flag value.

    Returns:
        bool: True for ``"true"`` in any letter case, False otherwise.
    """
    return v.lower() == 'true'
    
def makeFolder(path, version):
    """Create the directory ``path/version`` (and parents) if it is missing.

    Uses ``exist_ok=True`` instead of a separate existence check so that a
    directory created concurrently between the check and the creation does not
    raise.

    Args:
        path: Parent directory.
        version: Sub-directory name (run identifier) to create under *path*.
    """
    os.makedirs(os.path.join(path, version), exist_ok=True)

# Decorator
def time_it(fn):
    """Decorator that prints the wall-clock duration of every call to *fn*.

    The wrapper forwards positional and keyword arguments (the previous version
    rejected keyword arguments) and keeps the wrapped function's name and
    docstring via ``functools.wraps``.

    Args:
        fn: Callable to time.

    Returns:
        A wrapper with the same call signature that returns ``fn``'s result.
    """
    @functools.wraps(fn)
    def new_fn(*args, **kwargs):
        start = time.time()
        result = fn(*args, **kwargs)
        duration = time.time() - start
        print('%.4f seconds are consumed in executing function:%s'
              % (duration, fn.__name__))
        return result
    return new_fn

def load_inception_net(device=None):
    """Load the pre-trained Inception-v3 network wrapped for feature extraction.

    The torchvision model is put in eval mode and wrapped in ``WrapInception``.

    Args:
        device: Target device. Defaults to CUDA when available and to the CPU
            otherwise, so that the function no longer crashes on CPU-only
            hosts (the previous version called ``.cuda()`` unconditionally).

    Returns:
        WrapInception: The wrapped network placed on *device*.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    inception_model = inception_v3(pretrained=True, transform_input=False)
    return WrapInception(inception_model.eval()).to(device)

class WrapInception(nn.Module):
    """Wrap an Inception-v3 network to expose pooled features and logits.

    Inputs are shifted from [-1, 1] to [0, 1], normalised with the fixed
    per-channel mean/std stored on the module, and resized to 299x299 when
    needed before being run through the wrapped network's layers.
    """

    def __init__(self, net):
        super().__init__()
        self.net = net
        # Frozen per-channel statistics, shaped (1, C, 1, 1) for broadcasting.
        self.mean = P(torch.tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1),
                      requires_grad=False)
        self.std = P(torch.tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1),
                     requires_grad=False)

    def forward(self, x):
        """Return ``(pool, logits)`` for a batch of images in [-1, 1]."""
        net = self.net

        # [-1, 1] -> [0, 1], then channel-wise normalisation.
        x = ((x + 1.) / 2.0 - self.mean) / self.std

        # Resize unless the spatial size is already 299x299.
        if tuple(x.shape[2:4]) != (299, 299):
            x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=True)

        # Two convolutional stems, each followed by a 3x3 / stride-2 max-pool.
        stems = (
            (net.Conv2d_1a_3x3, net.Conv2d_2a_3x3, net.Conv2d_2b_3x3),
            (net.Conv2d_3b_1x1, net.Conv2d_4a_3x3),
        )
        for stem in stems:
            for layer in stem:
                x = layer(x)
            x = F.max_pool2d(x, kernel_size=3, stride=2)

        mixed_layers = (
            net.Mixed_5b, net.Mixed_5c, net.Mixed_5d,
            net.Mixed_6a, net.Mixed_6b, net.Mixed_6c, net.Mixed_6d, net.Mixed_6e,
            net.Mixed_7a, net.Mixed_7b, net.Mixed_7c,
        )
        for layer in mixed_layers:
            x = layer(x)

        # Spatial mean over the flattened H*W axis gives the pooled features.
        pool = x.view(x.size(0), x.size(1), -1).mean(2)
        flat = F.dropout(pool, training=False).view(pool.size(0), -1)
        logits = net.fc(flat)
        return pool, logits

# Sampler
class Distribution(torch.Tensor):
    """Tensor subclass that remembers how to re-sample itself in place.

    After ``init_distribution`` has been called, ``sample_`` overwrites the
    tensor's contents with fresh draws from the configured distribution.
    """

    # Init the params of the distribution
    def init_distribution(self, dist_type, **kwargs):
        """Store the distribution type and its parameters on the tensor.

        Args:
            dist_type: ``'normal'`` (requires ``mean`` and ``var`` keyword
                arguments) or ``'categorical'`` (requires ``num_categories``).
                Any other value is stored but no parameters are extracted.
            **kwargs: Distribution parameters; kept verbatim in ``dist_kwargs``
                so that ``to()`` can re-apply them.
        """
        self.dist_type = dist_type
        self.dist_kwargs = kwargs
        if self.dist_type == 'normal':
            self.mean, self.var = kwargs['mean'], kwargs['var']
        elif self.dist_type == 'categorical':
            self.num_categories = kwargs['num_categories']

    def sample_(self):
        """Refill the tensor in place according to the configured distribution.

        Does nothing for an unrecognised ``dist_type``; returns ``None``.
        """
        if self.dist_type == 'normal':
            # NOTE(review): ``self.var`` is passed as the second argument of
            # Tensor.normal_, which torch documents as the standard deviation,
            # not the variance. Identical for the value 1.0; confirm intent
            # before using any other value.
            self.normal_(self.mean, self.var)
        elif self.dist_type == 'categorical':
            # Integer draws from 0 up to num_categories - 1.
            self.random_(0, self.num_categories)

    # Silly hack: overwrite the to() method to wrap the new object
    # in a distribution as well
    def to(self, *args, **kwargs):
        """Like ``Tensor.to`` but returns a ``Distribution`` with the same settings.

        A new ``Distribution`` is created, re-initialised with the stored
        ``dist_type``/``dist_kwargs``, and its ``data`` is replaced by the
        converted tensor.
        """
        new_obj = Distribution(self)
        new_obj.init_distribution(self.dist_type, **self.dist_kwargs)
        new_obj.data = super().to(*args, **kwargs)
        return new_obj

# Sampling functions
def prepare_z_c(G_batch_size, dim_z, nclasses, device='cuda', z_var=1.0):
    """Build re-sampleable latent and label tensors for the generator.

    Args:
        G_batch_size: Number of rows in both tensors.
        dim_z: Width of the latent tensor.
        nclasses: Number of categories for the label tensor.
        device: Device both tensors are moved to.
        z_var: Value stored as the ``var`` parameter of the normal latent.

    Returns:
        tuple: ``(z_, c_)`` — a float32 ``Distribution`` of shape
        ``(G_batch_size, dim_z)`` configured as ``'normal'`` and an int64
        ``Distribution`` of shape ``(G_batch_size,)`` configured as
        ``'categorical'``.
    """
    latent = Distribution(torch.randn(G_batch_size, dim_z, requires_grad=False))
    latent.init_distribution('normal', mean=0, var=z_var)

    labels = Distribution(torch.zeros(G_batch_size, requires_grad=False))
    labels.init_distribution('categorical', num_categories=nclasses)

    return latent.to(device, torch.float32), labels.to(device, torch.int64)

def prepareSampleZ(G_batch_size, dim_z, device='cuda', z_var=1.0):
    """Build a re-sampleable float32 latent tensor on *device*.

    Args:
        G_batch_size: Number of latent vectors.
        dim_z: Size of each latent vector.
        device: Device the tensor is moved to.
        z_var: Value stored as the ``var`` parameter of the normal latent.

    Returns:
        Distribution: ``(G_batch_size, dim_z)`` tensor configured as ``'normal'``.
    """
    initial = torch.randn(G_batch_size, dim_z, requires_grad=False)
    latent = Distribution(initial)
    latent.init_distribution('normal', mean=0, var=z_var)
    return latent.to(device, torch.float32)

def sampleG(G, z_, c_, parallel=False):
    """Re-sample the latent/label tensors and run the generator without gradients.

    Args:
        G: Generator called as ``G(z_, c_)``.
        z_: Latent tensor exposing ``sample_()``; refreshed in place.
        c_: Label tensor exposing ``sample_()``; refreshed in place.
        parallel: When True, dispatch through ``nn.parallel.data_parallel``.

    Returns:
        Whatever the generator returns for the freshly sampled inputs.
    """
    with torch.no_grad():
        for sampler in (z_, c_):
            sampler.sample_()
        if not parallel:
            return G(z_, c_)
        return nn.parallel.data_parallel(G, [z_, c_])

def sampleFixedLabels(numClasses,batchSize,device):
    """Return class labels 0..numClasses-1, each repeated ``batchSize`` times.

    Args:
        numClasses: Number of distinct labels.
        batchSize: How many consecutive copies of each label to emit.
        device: Device for the resulting tensor.

    Returns:
        torch.Tensor: int64 tensor ``[0]*batchSize + [1]*batchSize + ...``.
    """
    repeated = [label for label in range(numClasses) for _ in range(batchSize)]
    return torch.tensor(repeated).long().to(device)

# 2. Datasets

In [4]:
def get_cifar100_superclass_loader(image_path, image_size, batch_size, num_workers=2):
    """Create a shuffled CIFAR-100 training loader that yields superclass labels.

    The dataset's fine labels (0-99) are replaced by the index (0-19) of the
    official CIFAR-100 superclass each fine class belongs to. The previous
    table assigned eight superclasses the wrong fine classes (for example
    fine label 35 "girl" was mapped to "medium-sized mammals" and 8 "bicycle"
    to "small mammals"); it now follows the official grouping.

    NOTE: label semantics change for the affected classes, so class-conditional
    checkpoints trained with the old table are not label-compatible.

    Args:
        image_path: Root directory for the dataset (downloaded when missing).
        image_size: Target size passed to ``transforms.Resize``.
        batch_size: Batch size of the returned loader.
        num_workers: Number of DataLoader worker processes.

    Returns:
        DataLoader: Yields ``(images, superclass_labels)`` with images
        normalised to [-1, 1].

    Raises:
        ValueError: If the superclass table does not cover every fine label
            0-99 exactly once.
    """
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    dataset = datasets.CIFAR100(root=image_path, train=True, download=True, transform=transform)

    # Official CIFAR-100 superclasses in coarse-label order; fine labels are the
    # alphabetical class indices used by torchvision's CIFAR100 targets.
    superclass_map = {
        'aquatic mammals': [4, 30, 55, 72, 95],
        'fish': [1, 32, 67, 73, 91],
        'flowers': [54, 62, 70, 82, 92],
        'food containers': [9, 10, 16, 28, 61],
        'fruit and vegetables': [0, 51, 53, 57, 83],
        'household electrical devices': [22, 39, 40, 86, 87],
        'household furniture': [5, 20, 25, 84, 94],
        'insects': [6, 7, 14, 18, 24],
        'large carnivores': [3, 42, 43, 88, 97],
        'large man-made outdoor things': [12, 17, 37, 68, 76],
        'large natural outdoor scenes': [23, 33, 49, 60, 71],
        'large omnivores and herbivores': [15, 19, 21, 31, 38],
        'medium-sized mammals': [34, 63, 64, 66, 75],
        'non-insect invertebrates': [26, 45, 77, 79, 99],
        'people': [2, 11, 35, 46, 98],
        'reptiles': [27, 29, 44, 78, 93],
        'small mammals': [36, 50, 65, 74, 80],
        'trees': [47, 52, 56, 59, 96],
        'vehicles 1': [8, 13, 48, 58, 90],
        'vehicles 2': [41, 69, 81, 85, 89]
    }

    # Reverse mapping: fine label -> superclass index (dict insertion order).
    fine_to_super = {
        fine_label: super_idx
        for super_idx, fine_labels in enumerate(superclass_map.values())
        for fine_label in fine_labels
    }

    # Guard against typos in the table: every fine label must appear once.
    total_entries = sum(len(labels) for labels in superclass_map.values())
    if total_entries != 100 or sorted(fine_to_super) != list(range(100)):
        raise ValueError("superclass_map must cover fine labels 0-99 exactly once")

    # Replace the dataset's fine targets with superclass indices.
    dataset.targets = [fine_to_super[target] for target in dataset.targets]

    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers
    )

# 3. Define Model

## 3.1. Generator

In [None]:
class ConditionalBatchNorm2d(nn.BatchNorm2d):
    """Batch normalisation whose scale and shift are supplied per call.

    The input is normalised with ``F.batch_norm`` using this module's running
    statistics and its own (by default absent, ``affine=False``) affine
    parameters; the caller-provided ``weight`` and ``bias`` are then applied
    channel-wise.
    """

    def __init__(self, num_features, eps=1e-05, momentum=0.1,
                 affine=False, track_running_stats=True):
        super().__init__(num_features, eps, momentum, affine, track_running_stats)

    def forward(self, input, weight, bias, **kwargs):
        """Normalise *input* and apply the external scale/shift.

        Args:
            input: 4-D batch, validated by ``_check_input_dim``.
            weight: Scale of shape ``(C,)`` or ``(N, C)``.
            bias: Shift of shape ``(C,)`` or ``(N, C)``.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            ``weight * normalised + bias`` with the same size as the
            normalised input.
        """
        self._check_input_dim(input)

        tracking = self.track_running_stats
        average_factor = 0.0
        if self.training and tracking:
            self.num_batches_tracked += 1
            if self.momentum is None:
                # Cumulative moving average.
                average_factor = 1.0 / self.num_batches_tracked.item()
            else:
                # Exponential moving average.
                average_factor = self.momentum

        use_batch_stats = self.training or not tracking
        normalised = F.batch_norm(input, self.running_mean, self.running_var,
                                  self.weight, self.bias,
                                  use_batch_stats, average_factor, self.eps)

        # Promote 1-D parameters to a leading batch axis, then broadcast over H, W.
        scale = weight.unsqueeze(0) if weight.dim() == 1 else weight
        shift = bias.unsqueeze(0) if bias.dim() == 1 else bias
        target_size = normalised.size()
        scale = scale[..., None, None].expand(target_size)
        shift = shift[..., None, None].expand(target_size)
        return scale * normalised + shift

class CategoricalConditionalBatchNorm2d(ConditionalBatchNorm2d):
    """Conditional batch norm whose scale/shift are looked up per class label.

    Two embedding tables map a class index to a per-channel scale and shift;
    they start at ones and zeros respectively.
    """

    def __init__(self, num_classes, num_features, eps=1e-5, momentum=0.1,
                 affine=False, track_running_stats=True):
        super().__init__(num_features, eps, momentum, affine, track_running_stats)
        self.weights = nn.Embedding(num_classes, num_features)
        self.biases = nn.Embedding(num_classes, num_features)
        self._initialize()

    def _initialize(self):
        """Set every class's scale to one and shift to zero."""
        init.ones_(self.weights.weight.data)
        init.zeros_(self.biases.weight.data)

    def forward(self, input, c, **kwargs):
        """Normalise *input* using the scale/shift embedded for labels *c*."""
        scale = self.weights(c)
        shift = self.biases(c)
        return super().forward(input, scale, shift)

def _upsample(x):
    h, w = x.size()[2:]
    return F.interpolate(x, size=(h * 2, w * 2))

class GenResBlock(nn.Module):
    """Generator residual block with optional 2x upsampling and class conditioning.

    Residual path: norm -> activation -> (upsample) -> conv -> norm ->
    activation -> conv. Shortcut path: identity, or (upsample) -> 1x1 conv when
    the channel count changes or upsampling is requested.

    Args:
        in_ch: Input channels.
        out_ch: Output channels.
        h_ch: Hidden channels between the two convolutions; defaults to ``out_ch``.
        ksize: Kernel size of both residual convolutions.
        pad: Padding of both residual convolutions.
        activation: Activation function applied before each convolution.
        upsample: Whether to double the spatial resolution.
        num_classes: If > 0, use class-conditional batch norm; otherwise plain
            ``nn.BatchNorm2d``.
    """

    def __init__(self, in_ch, out_ch, h_ch=None, ksize=3, pad=1,
                 activation=F.relu, upsample=False, num_classes=0):
        super(GenResBlock, self).__init__()

        self.activation     = activation
        self.upsample       = upsample
        self.learnable_sc   = in_ch != out_ch or upsample
        if h_ch is None:
            h_ch = out_ch
        self.num_classes = num_classes
        # Register layers
        self.c1 = nn.Conv2d(in_ch, h_ch, ksize, 1, pad)
        self.c2 = nn.Conv2d(h_ch, out_ch, ksize, 1, pad)
        if self.num_classes > 0:
            self.b1 = CategoricalConditionalBatchNorm2d(num_classes, in_ch)
            self.b2 = CategoricalConditionalBatchNorm2d(num_classes, h_ch)
        else:
            self.b1 = nn.BatchNorm2d(in_ch)
            self.b2 = nn.BatchNorm2d(h_ch)
        if self.learnable_sc:
            self.c_sc = nn.Conv2d(in_ch, out_ch, 1)
        self._initialize()

    def _initialize(self):
        """Xavier-initialise the convolutions (gain sqrt(2) residual, 1 shortcut)."""
        init.xavier_uniform_(self.c1.weight.data, gain=math.sqrt(2))
        init.xavier_uniform_(self.c2.weight.data, gain=math.sqrt(2))
        if self.learnable_sc:
            init.xavier_uniform_(self.c_sc.weight.data, gain=1)

    def forward(self, x, y=None, z=None, **kwargs):
        """Return ``shortcut(x) + residual(x, y, z)``.

        Args:
            x: Input feature map.
            y: Optional class labels forwarded to the conditional batch norms.
            z: Accepted and forwarded to ``residual``, where it is unused.
        """
        return self.shortcut(x) + self.residual(x, y, z)

    def shortcut(self, x, **kwargs):
        """Skip connection: identity, or (upsample ->) 1x1 convolution."""
        if not self.learnable_sc:
            return x
        # Previously `h` was only assigned when upsampling, so a block that
        # merely changes the channel count raised UnboundLocalError.
        h = _upsample(x) if self.upsample else x
        return self.c_sc(h)

    def residual(self, x, y=None, z=None, **kwargs):
        """Main branch; the norms receive *y* whenever it is provided."""
        if y is not None:
            h = self.b1(x, y, **kwargs)
        else:
            h = self.b1(x)
        h = self.activation(h)
        if self.upsample:
            h = _upsample(h)
        h = self.c1(h)
        if y is not None:
            h = self.b2(h, y, **kwargs)
        else:
            h = self.b2(h)
        return self.c2(self.activation(h))

class ResNetGenerator(nn.Module):
    """ResNet generator: linear projection, three upsampling blocks, RGB head.

    With the default ``bottom_width`` of 4, the three 2x upsampling blocks
    yield 32x32 images; the output passes through ``tanh``.
    """

    def __init__(self, num_features=64, dim_z=128, bottom_width=4, num_classes=0,
                 activation=F.relu):
        super().__init__()
        self.num_features = num_features
        self.dim_z = dim_z
        self.bottom_width = bottom_width
        self.activation = activation
        self.num_classes = num_classes

        # Projects z onto an (8 * num_features, bottom_width, bottom_width) map.
        self.l1 = nn.Linear(dim_z, 8 * num_features * bottom_width ** 2)

        block_options = dict(activation=activation, upsample=True,
                             num_classes=num_classes)
        self.block2 = GenResBlock(num_features * 8, num_features * 4, **block_options)
        self.block3 = GenResBlock(num_features * 4, num_features * 2, **block_options)
        self.block4 = GenResBlock(num_features * 2, num_features, **block_options)
        self.b5 = nn.BatchNorm2d(num_features)
        self.conv5 = nn.Conv2d(num_features, 3, 1, 1)
        self._initialize()

    def _initialize(self):
        """Xavier-initialise the input projection and the output convolution."""
        for layer in (self.l1, self.conv5):
            init.xavier_uniform_(layer.weight.data)

    def forward(self, z, y=None, **kwargs):
        """Generate images from latents *z*, optionally conditioned on labels *y*."""
        side = self.bottom_width
        h = self.l1(z).view(z.size(0), -1, side, side)
        for block in (self.block2, self.block3, self.block4):
            h = block(h, y, **kwargs)
        h = self.activation(self.b5(h))
        return torch.tanh(self.conv5(h))

## 3.2. Discriminator

In [6]:
class OptimizedBlock(nn.Module):
    """First discriminator block: spectrally-normalised convs with 2x downsampling.

    Residual path: conv -> activation -> conv -> average-pool.
    Shortcut path: average-pool -> 1x1 conv.
    """

    def __init__(self, in_ch, out_ch, ksize=3, pad=1, activation=F.relu):
        super().__init__()
        self.activation = activation

        spectral = utils.spectral_norm
        self.c1 = spectral(nn.Conv2d(in_ch, out_ch, ksize, 1, pad))
        self.c2 = spectral(nn.Conv2d(out_ch, out_ch, ksize, 1, pad))
        self.c_sc = spectral(nn.Conv2d(in_ch, out_ch, 1, 1, 0))

        self._initialize()

    def _initialize(self):
        """Xavier-initialise the convolutions (gain sqrt(2) residual, 1 shortcut)."""
        for conv in (self.c1, self.c2):
            init.xavier_uniform_(conv.weight.data, math.sqrt(2))
        init.xavier_uniform_(self.c_sc.weight.data)

    def forward(self, x):
        """Return the sum of the shortcut and residual branches."""
        skip = self.shortcut(x)
        return skip + self.residual(x)

    def shortcut(self, x):
        """Downsample first, then match channels with the 1x1 convolution."""
        pooled = F.avg_pool2d(x, 2)
        return self.c_sc(pooled)

    def residual(self, x):
        """Two convolutions with an activation in between, then downsample."""
        hidden = self.activation(self.c1(x))
        hidden = self.c2(hidden)
        return F.avg_pool2d(hidden, 2)

In [7]:
class DisResBlock(nn.Module):
    """Discriminator residual block with spectral norm and optional downsampling.

    Residual path: activation -> conv -> activation -> conv -> (average-pool).
    Shortcut path: (1x1 conv when channels change or downsampling) ->
    (average-pool).

    Args:
        in_ch: Input channels.
        out_ch: Output channels.
        h_ch: Hidden channels between the two convolutions. Defaults to
            ``in_ch``. An explicitly supplied value is now honoured; the
            previous code overwrote it with ``out_ch``.
        ksize: Kernel size of both residual convolutions.
        pad: Padding of both residual convolutions.
        activation: Activation applied before each convolution.
        downsample: Whether to halve the spatial resolution.
    """

    def __init__(self, in_ch, out_ch, h_ch=None, ksize=3, pad=1,
                 activation=F.relu, downsample=False):
        super(DisResBlock, self).__init__()

        self.activation = activation
        self.downsample = downsample

        self.learnable_sc = (in_ch != out_ch) or downsample
        if h_ch is None:
            h_ch = in_ch

        self.c1 = utils.spectral_norm(nn.Conv2d(in_ch, h_ch, ksize, 1, pad))
        self.c2 = utils.spectral_norm(nn.Conv2d(h_ch, out_ch, ksize, 1, pad))
        if self.learnable_sc:
            self.c_sc = utils.spectral_norm(nn.Conv2d(in_ch, out_ch, 1, 1, 0))

        self._initialize()

    def _initialize(self):
        """Xavier-initialise the convolutions (gain sqrt(2) residual, 1 shortcut)."""
        init.xavier_uniform_(self.c1.weight.data, math.sqrt(2))
        init.xavier_uniform_(self.c2.weight.data, math.sqrt(2))
        if self.learnable_sc:
            init.xavier_uniform_(self.c_sc.weight.data)

    def forward(self, x):
        """Return the sum of the shortcut and residual branches."""
        return self.shortcut(x) + self.residual(x)

    def shortcut(self, x):
        """Skip connection: optional 1x1 conv followed by optional pooling."""
        if self.learnable_sc:
            x = self.c_sc(x)
        if self.downsample:
            return F.avg_pool2d(x, 2)
        return x

    def residual(self, x):
        """Main branch: pre-activated convolutions, then optional pooling."""
        h = self.c1(self.activation(x))
        h = self.c2(self.activation(h))
        if self.downsample:
            h = F.avg_pool2d(h, 2)
        return h

In [8]:
class SNResNetProjectionDiscriminator(nn.Module):
    """Spectrally-normalised ResNet discriminator with a label projection term.

    Four downsampling blocks are followed by an activation, a spatial sum, and
    a linear head. When labels are passed, the inner product between the label
    embedding and the pooled features is added to the output.
    """

    def __init__(self, num_features, num_classes=0, activation=F.relu):
        super().__init__()
        self.num_features = num_features
        self.num_classes = num_classes
        self.activation = activation

        # Stem block followed by three downsampling residual blocks.
        self.block1 = OptimizedBlock(3, num_features)
        self.block2 = DisResBlock(num_features, num_features * 2,
                                  activation=activation, downsample=True)
        self.block3 = DisResBlock(num_features * 2, num_features * 4,
                                  activation=activation, downsample=True)
        self.block4 = DisResBlock(num_features * 4, num_features * 8,
                                  activation=activation, downsample=True)

        # Linear head sized to the channel count produced by block4.
        final_width = num_features * 8
        self.l7 = utils.spectral_norm(nn.Linear(final_width, 1))
        if num_classes > 0:
            # Label embedding used by the projection term.
            self.l_y = utils.spectral_norm(nn.Embedding(num_classes, final_width))

        self._initialize()

    def _initialize(self):
        """Xavier-initialise the linear head and, if present, the label embedding."""
        init.xavier_uniform_(self.l7.weight.data)
        if getattr(self, 'l_y', None) is not None:
            init.xavier_uniform_(self.l_y.weight.data)

    def forward(self, x, y=None):
        """Score images *x*; add the projection term when labels *y* are given."""
        features = x
        for block in (self.block1, self.block2, self.block3, self.block4):
            features = block(features)
        features = self.activation(features)
        # Global sum pooling over the spatial axes.
        features = torch.sum(features, dim=(2, 3))
        output = self.l7(features)
        if y is not None:
            projection = torch.sum(self.l_y(y) * features, dim=1, keepdim=True)
            output += projection
        return output


# 4. Trainer

In [9]:
class Trainer(object):
    def __init__(self, data_loader):
        """Set up device, hyper-parameters, output folders, models and FID statistics.

        Reads its configuration from the module-level constants. Compared with
        the previous version:
          * ``n_classes`` follows ``N_CLASS`` instead of a hard-coded 20;
          * checkpoints live under ``MODEL_SAVE_PATH/VERSION/checkpoints`` so
            that runs with different ``VERSION`` values no longer overwrite
            each other (this matches the documented resume path
            ``models/<VERSION>/checkpoints/...``);
          * the log, sample and model folders are created up front, because
            the reporter and the sample writer need them to exist.

        Args:
            data_loader: Iterable yielding ``(images, labels)`` batches.
        """
        # Set device
        self.device = torch.device(f'cuda:{CUDA}' if torch.cuda.is_available() else 'cpu')

        # Basic configurations
        self.n_classes = N_CLASS
        self.data_loader = data_loader

        # Model hyperparameters
        self.imsize = IMSIZE
        self.z_dim = Z_DIM
        self.g_conv_dim = G_CONV_DIM
        self.d_conv_dim = D_CONV_DIM

        # Training settings
        self.total_step = TOTAL_STEP
        self.batch_size = BATCH_SIZE
        self.g_lr = G_LR
        self.d_lr = D_LR
        self.beta1 = BETA1
        self.beta2 = BETA2
        self.DStep = D_STEP
        self.GStep = G_STEP

        # Paths (all scoped by VERSION)
        self.log_path = os.path.join(LOG_PATH, VERSION)
        self.sample_path = os.path.join(SAMPLE_PATH, VERSION)
        self.model_save_path = os.path.join(MODEL_SAVE_PATH, VERSION)
        self.checkpoint_dir = os.path.join(self.model_save_path, 'checkpoints')
        for directory in (self.log_path, self.sample_path,
                          self.model_save_path, self.checkpoint_dir):
            os.makedirs(directory, exist_ok=True)

        # Step settings
        self.start_step = 0
        self.log_step = LOG_STEP
        self.sample_step = SAMPLE_STEP
        self.model_save_step = MODEL_SAVE_STEP
        self.metric_calculation_step = METRIC_CALCULATION_STEP

        # Initialize reporter
        self.report_file = os.path.join(self.log_path, VERSION + "_report.log")
        self.reporter = Reporter(self.report_file)

        # Build model (must precede load_checkpoint, which restores into it)
        self.build_model()

        # Initialize tensorboard writer
        self.writer = SummaryWriter(log_dir=self.log_path)

        # Initialize Inception network
        self.inception_net = load_inception_net()

        # Get real data statistics for FID
        self.real_pool, self.real_logits, self.real_labels = self.get_data_statistics()
        self.real_mu = np.mean(self.real_pool, axis=0)
        self.real_sigma = np.cov(self.real_pool, rowvar=False)

        # Load checkpoint if resuming
        if RESUME_TRAINING:
            self.load_checkpoint(RESUME_CHECKPOINT)

        # Write model
        self.reporter.writeModel(str(self.G))
        self.reporter.writeModel(str(self.D))

    def save_checkpoint(self, step, is_best=False):
        """Save complete training state"""
        checkpoint = {
            'step': step,
            'G_state_dict': self.G.state_dict(),
            'D_state_dict': self.D.state_dict(),
            'g_optimizer_state_dict': self.g_optimizer.state_dict(),
            'd_optimizer_state_dict': self.d_optimizer.state_dict(),
            'g_scheduler_state_dict': self.g_scheduler.state_dict() if self.g_scheduler else None,
            'd_scheduler_state_dict': self.d_scheduler.state_dict() if self.d_scheduler else None,
            'random_state': {
                'python': random.getstate(),
                'numpy': np.random.get_state(),
                'torch': torch.get_rng_state(),
                'cuda': torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
            },
            'metrics': {
                'best_fid': getattr(self, 'best_fid', float('inf')),
                'best_is': getattr(self, 'best_is', 0)
            }
        }
        
        checkpoint_path = os.path.join(self.checkpoint_dir, f'checkpoint_{step}.pt')
        torch.save(checkpoint, checkpoint_path)
        
        if is_best:
            best_path = os.path.join(self.checkpoint_dir, 'best_model.pt')
            torch.save(checkpoint, best_path)
        
        # Save latest checkpoint reference
        latest_path = os.path.join(self.checkpoint_dir, 'latest.txt')
        with open(latest_path, 'w') as f:
            f.write(str(step))
            
        self.reporter.writeInfo(f"Saved checkpoint at step {step}")

    def load_checkpoint(self, checkpoint_path=None):
        """Load training state from checkpoint"""
        if checkpoint_path is None:
            # Try to load latest checkpoint
            latest_path = os.path.join(self.checkpoint_dir, 'latest.txt')
            if not os.path.exists(latest_path):
                self.reporter.writeInfo("No checkpoint found, starting from scratch")
                return False
                    
            with open(latest_path, 'r') as f:
                step = int(f.read().strip())
            checkpoint_path = os.path.join(self.checkpoint_dir, f'checkpoint_{step}.pt')
    
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"No checkpoint found at {checkpoint_path}")
    
        self.reporter.writeInfo(f"Loading checkpoint from {checkpoint_path}")
        
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        
        # Restore model and optimizer states
        self.G.load_state_dict(checkpoint['G_state_dict'])
        self.D.load_state_dict(checkpoint['D_state_dict'])
        self.g_optimizer.load_state_dict(checkpoint['g_optimizer_state_dict'])
        self.d_optimizer.load_state_dict(checkpoint['d_optimizer_state_dict'])
        
        # Restore scheduler states if they exist
        if checkpoint['g_scheduler_state_dict'] and hasattr(self, 'g_scheduler'):
            self.g_scheduler.load_state_dict(checkpoint['g_scheduler_state_dict'])
        if checkpoint['d_scheduler_state_dict'] and hasattr(self, 'd_scheduler'):
            self.d_scheduler.load_state_dict(checkpoint['d_scheduler_state_dict'])
        
        # Safely restore random states
        try:
            # Restore Python's random state
            random.setstate(checkpoint['random_state']['python'])
            
            # Restore NumPy's random state
            np.random.set_state(checkpoint['random_state']['numpy'])
            
            # Restore PyTorch's random state
            torch_state = checkpoint['random_state']['torch']
            if isinstance(torch_state, torch.Tensor):
                if torch_state.dtype != torch.uint8:
                    torch_state = torch_state.byte()
                torch.set_rng_state(torch_state)
            
            # Restore CUDA random state if available
            if torch.cuda.is_available() and checkpoint['random_state']['cuda'] is not None:
                cuda_state = checkpoint['random_state']['cuda']
                if isinstance(cuda_state, list):
                    cuda_state = [state.byte() if state.dtype != torch.uint8 else state 
                                for state in cuda_state]
                torch.cuda.set_rng_state_all(cuda_state)
        except Exception as e:
            self.reporter.writeInfo(f"Warning: Could not restore random states: {str(e)}")
            self.reporter.writeInfo("Continuing with current random states")
        
        # Restore metrics
        self.best_fid = checkpoint['metrics']['best_fid']
        self.best_is = checkpoint['metrics']['best_is']
        
        # Set starting step
        self.start_step = checkpoint['step']
        
        self.reporter.writeInfo(f"Resumed training from step {self.start_step}")
        return True


    def train(self):
        # Data iterator
        data_iter = iter(self.data_loader)
        
        # Fixed input for debugging
        sampleBatch = 10
        fixed_z = torch.randn(self.n_classes*sampleBatch, self.z_dim)
        fixed_z = fixed_z.to(self.device)
        fixed_c = sampleFixedLabels(self.n_classes,sampleBatch,self.device)

        runingZ,runingLabel = prepare_z_c(self.batch_size, self.z_dim, self.n_classes, device=self.device)

        # Start time
        start_time = time.time()
        self.reporter.writeInfo(f"Start/Resume training from step {self.start_step}")
        dstepCounter = 0
        gstepCounter = 0

        # Time limit in seconds (3 days = 72 hours = 259200 seconds)
        time_limit = 259200
    
        from tqdm import tqdm
        
        # Create GradScaler for mixed precision training
        scaler = GradScaler()

        # Main training progress bar
        pbar = tqdm(range(self.start_step, self.total_step),
                    desc='Training',
                    total=self.total_step,
                    initial=self.start_step)

        # Initialize metrics for progress bar
        metrics = {'d_loss_real': 0, 'd_loss_fake': 0, 'g_loss': 0}
        
        try:
            for step in pbar:
                current_time = time.time()
                elapsed_time = current_time - start_time
                if elapsed_time >= time_limit:
                    self.reporter.writeInfo(f"Training stopped: Time limit (3 days) exceeded")
                    self.save_checkpoint(step, is_best=False)
                    break
                # ================== Train D ================== #
                self.D.train()
                self.G.train()

                if dstepCounter < self.DStep:
                    try:
                        realImages, realLabel = next(data_iter)
                    except StopIteration:
                        data_iter = iter(self.data_loader)
                        realImages, realLabel = next(data_iter)

                    # Move data to device
                    realImages = realImages.to(self.device)
                    realLabel = realLabel.to(self.device).long()

                    # Train with real images
                    with autocast():
                        d_out_real = self.D(realImages, realLabel)
                        d_loss_real = torch.nn.ReLU()(1.0 - d_out_real).mean()

                        # Generate fake images
                        runingZ.sample_()
                        runingLabel.sample_()
                        with torch.no_grad():
                            fake_images = self.G(runingZ, runingLabel)
                        
                        d_out_fake = self.D(fake_images.detach(), runingLabel)
                        d_loss_fake = torch.nn.ReLU()(1.0 + d_out_fake).mean()
                        
                        d_loss = d_loss_real + d_loss_fake

                    # Backward pass
                    self.reset_grad()
                    scaler.scale(d_loss).backward()
                    scaler.step(self.d_optimizer)
                    scaler.update()

                    dstepCounter += 1
                    
                    # Update metrics
                    metrics['d_loss_real'] = d_loss_real.item()
                    metrics['d_loss_fake'] = d_loss_fake.item()

                # ================== Train G ================== #
                else:
                    with autocast():
                        runingZ.sample_()
                        runingLabel.sample_()
                        fake_images = self.G(runingZ, runingLabel)
                        g_out_fake = self.D(fake_images, runingLabel)
                        g_loss = -g_out_fake.mean()

                    # Backward pass
                    self.reset_grad()
                    scaler.scale(g_loss).backward()
                    scaler.step(self.g_optimizer)
                    scaler.update()

                    gstepCounter += 1
                    
                    # Update metric
                    metrics['g_loss'] = g_loss.item()

                # Reset counters if necessary
                if gstepCounter == self.GStep:
                    dstepCounter = 0
                    gstepCounter = 0

                # Step schedulers if they exist
                if hasattr(self, 'g_scheduler') and self.g_scheduler is not None:
                    self.g_scheduler.step()
                if hasattr(self, 'd_scheduler') and self.d_scheduler is not None:
                    self.d_scheduler.step()
                    
                # Update progress bar with time information
                remaining_time = time_limit - elapsed_time if elapsed_time < time_limit else 0
                hours_remaining = remaining_time // 3600
                minutes_remaining = (remaining_time % 3600) // 60
                
                # Update progress bar
                pbar.set_postfix({
                    'd_real': f"{metrics['d_loss_real']:.4f}",
                    'd_fake': f"{metrics['d_loss_fake']:.4f}",
                    'g_loss': f"{metrics['g_loss']:.4f}"
                })

                # Log to tensorboard
                if (step + 1) % self.log_step == 0:
                    elapsed = time.time() - start_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))
                    
                    self.writer.add_scalar('loss/d_real', metrics['d_loss_real'], step + 1)
                    self.writer.add_scalar('loss/d_fake', metrics['d_loss_fake'], step + 1)
                    self.writer.add_scalar('loss/g_loss', metrics['g_loss'], step + 1)
                    self.writer.add_scalar('time/hours_remaining', hours_remaining, step + 1)

                # Save generated samples
                if (step + 1) % self.sample_step == 0:
                    with torch.no_grad():
                        fake_images = self.G(fixed_z, fixed_c)
                        save_image(denorm(fake_images.data),
                                os.path.join(self.sample_path, f'{step + 1}_fake.png'),
                                nrow=self.n_classes)

                # Save checkpoint
                if (step + 1) % self.model_save_step == 0:
                    self.save_checkpoint(step + 1)

                # Calculate metrics
                if (step + 1) % self.metric_calculation_step == 0:
                    with tqdm(total=1, desc="Calculating FID and IS") as metric_pbar:
                        fid, inception_score = self.get_inception_metrics()
                        
                        # Check if this is the best model
                        is_best = False
                        if not hasattr(self, 'best_fid') or fid < self.best_fid:
                            self.best_fid = fid
                            is_best = True
                            
                        if not hasattr(self, 'best_is') or inception_score > self.best_is:
                            self.best_is = inception_score
                            is_best = True

                        if is_best:
                            self.save_checkpoint(step + 1, is_best=True)
                        
                        # Log metrics
                        self.writer.add_scalar('metrics/FID', fid, step + 1)
                        self.writer.add_scalar('metrics/IS', inception_score, step + 1)
                        self.reporter.writeTrainLog(step + 1, 
                            f"FID: {fid:.4f}, IS: {inception_score:.4f}")
                        
                        metric_pbar.update(1)

        except Exception as e:
            # Save checkpoint on error
            self.reporter.writeInfo(f"Training interrupted at step {step}: {str(e)}")
            self.save_checkpoint(step)
            raise e

        # Calculate and log total training time
        total_time = time.time() - start_time
        print(f"Total time is {total_time}s")
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        self.reporter.writeInfo(f"Total training time: {total_time_str}")
        self.writer.add_text('training/total_time', total_time_str)
        
        # Save final checkpoint
        self.save_checkpoint(self.total_step)
        pbar.close()
        
    def get_data_statistics(self):
        """Run every batch of the real-data loader through the inception net.

        Returns:
            Three NumPy arrays, each concatenated along axis 0 over all
            batches of ``self.data_loader``: the first inception output
            ("pool"), the softmax (dim 1) of the second inception output,
            and the labels yielded by the loader.
        """
        # Inputs are moved to whichever device holds the inception weights.
        net_device = next(self.inception_net.parameters()).device

        pool_batches = []
        prob_batches = []
        label_batches = []
        for images, targets in self.data_loader:
            images = images.to(net_device)
            with torch.no_grad():
                pool_out, logit_out = self.inception_net(images)
                pool_batches.append(np.asarray(pool_out.cpu()))
                prob_batches.append(np.asarray(F.softmax(logit_out, 1).cpu()))
                label_batches.append(np.asarray(targets.cpu()))

        pool = np.concatenate(pool_batches, 0)
        logits = np.concatenate(prob_batches, 0)
        labels = np.concatenate(label_batches, 0)
        return pool, logits, labels

    def accumulate_inception_activations(self, num_inception_images=50000):
        """Collect inception outputs for images sampled from the generator.

        Batches of ``self.batch_size`` latent/label pairs are drawn, passed
        through ``self.G`` and then ``self.inception_net`` until at least
        ``num_inception_images`` samples have been gathered. Whole batches are
        kept, so the result may contain more rows than requested.

        Args:
            num_inception_images: Minimum number of generated samples to
                accumulate.

        Returns:
            Tuple ``(pool, logits, labels)`` of tensors concatenated along
            dim 0: the first inception output, the softmax (dim 1) of the
            second inception output, and the sampled labels.
        """
        pool, logits, labels = [], [], []
        # Keep a running sample count rather than re-concatenating every stored
        # batch on each loop check just to read its length.
        num_collected = 0
        while num_collected < num_inception_images:
            with torch.no_grad():
                # A fresh z/label pair is created per batch because the label
                # object itself is stored in ``labels`` below.
                # NOTE(review): sample_() presumably refills in place - confirm
                # against prepare_z_c before hoisting this out of the loop.
                z_, c_ = prepare_z_c(self.batch_size, self.z_dim, self.n_classes, device=self.device)
                z_.sample_()
                c_.sample_()
                images = self.G(z_, c_)
                pool_val, logits_val = self.inception_net(images)
                probs = F.softmax(logits_val, 1)
                pool.append(pool_val)
                logits.append(probs)
                labels.append(c_)
                num_collected += probs.shape[0]

        return (torch.cat(pool, 0),
                torch.cat(logits, 0),
                torch.cat(labels, 0))

    def calculate_inception_score(self, pred, num_splits=10):
        """Compute the Inception Score over equally sized splits.

        For each split the score is ``exp(mean_i KL(p(y|x_i) || p(y)))`` where
        ``p(y)`` is the mean prediction of that split.

        Args:
            pred: 2-D NumPy array of class probabilities, one row per sample.
            num_splits: Number of consecutive chunks to score independently.
                Rows beyond ``num_splits * (len(pred) // num_splits)`` are
                ignored.

        Returns:
            Tuple ``(mean, std)`` of the per-split scores.
        """
        split_size = pred.shape[0] // num_splits
        scores = []
        for index in range(num_splits):
            pred_chunk = pred[index * split_size:(index + 1) * split_size, :]
            marginal = np.expand_dims(np.mean(pred_chunk, 0), 0)
            # Probabilities that are exactly zero would give 0 * log(0) = nan
            # and poison the whole score; by the usual KL convention such
            # terms contribute 0, so compute quietly and mask them out.
            with np.errstate(divide='ignore', invalid='ignore'):
                kl_terms = pred_chunk * (np.log(pred_chunk) - np.log(marginal))
            kl_terms = np.where(pred_chunk > 0, kl_terms, np.zeros_like(kl_terms))
            scores.append(np.exp(np.mean(np.sum(kl_terms, 1))))
        return np.mean(scores), np.std(scores)

    def calculate_fid(self, mu1, sigma1, mu2, sigma2):
        """Compute the Frechet Inception Distance between two Gaussians.

        Evaluates ``|mu1 - mu2|^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 . sigma2))``.

        Args:
            mu1: Mean vector of the first distribution.
            sigma1: Covariance matrix of the first distribution.
            mu2: Mean vector of the second distribution.
            sigma2: Covariance matrix of the second distribution.

        Returns:
            The FID value.
        """
        mean_gap = mu1 - mu2
        cov_product = sigma1.dot(sigma2)
        # With disp=False sqrtm returns the root together with an error
        # estimate; the estimate is not used here.
        cov_root, _ = sqrtm(cov_product, disp=False)
        # The matrix square root can come back complex; only the real part
        # enters the trace.
        cov_root = cov_root.real if np.iscomplexobj(cov_root) else cov_root
        squared_mean_dist = mean_gap.dot(mean_gap)
        trace_term = np.trace(sigma1 + sigma2 - 2 * cov_root)
        return squared_mean_dist + trace_term

    def get_inception_metrics(self):
        """Compute FID and Inception Score for the current generator.

        Requests 50000 generated samples from
        ``accumulate_inception_activations``, derives their mean/covariance
        and compares them with ``self.real_mu`` / ``self.real_sigma``.

        Returns:
            Tuple ``(fid, is_mean)``. The Inception Score standard deviation
            is computed but not returned.
        """
        # Gather inception outputs for generated images; labels are unused here.
        gen_pool, gen_probs, _ = self.accumulate_inception_activations(num_inception_images=50000)

        # Move to host memory as NumPy arrays for the statistics below.
        pool_np = gen_pool.cpu().numpy()
        probs_np = gen_probs.cpu().numpy()

        # FID between the stored real-data moments and the generated moments.
        fid = self.calculate_fid(
            self.real_mu,
            self.real_sigma,
            np.mean(pool_np, axis=0),
            np.cov(pool_np, rowvar=False),
        )

        # Inception Score from the softmax outputs.
        is_mean, _ = self.calculate_inception_score(probs_np)

        return fid, is_mean

    def sample_for_inception(self):
        """Draw one batch of generator images together with their labels.

        Returns:
            Tuple ``(imgs, labels)``: the output of ``self.G`` computed
            without gradient tracking, and the label object it was
            conditioned on.
        """
        latent, class_labels = prepare_z_c(self.batch_size, self.z_dim, self.n_classes, device=self.device)
        # Refresh the latent first, then the labels.
        for holder in (latent, class_labels):
            holder.sample_()
        with torch.no_grad():
            generated = self.G(latent, class_labels)
        return generated, class_labels

    def build_model(self):
        """Create G and D along with their Adam optimizers and LR schedulers.

        Sets ``self.G``, ``self.D``, ``self.g_optimizer``, ``self.d_optimizer``,
        ``self.g_scheduler`` and ``self.d_scheduler``, then writes the model
        architectures and trainable-parameter counts through ``self.reporter``.
        """
        def linear_decay(step):
            # LR multiplier: falls linearly from 1 and is clamped at 0 once the
            # scheduler's step counter reaches self.total_step.
            return max(1 - step / self.total_step, 0)

        def make_adam(network, learning_rate):
            # Only parameters that require gradients are handed to Adam.
            trainable = [p for p in network.parameters() if p.requires_grad]
            return torch.optim.Adam(trainable, lr=learning_rate,
                                    betas=[self.beta1, self.beta2])

        # Networks, placed on the target device (third generator argument is 4).
        self.G = ResNetGenerator(
            self.g_conv_dim, self.z_dim, 4, num_classes=self.n_classes
        ).to(self.device)
        self.D = SNResNetProjectionDiscriminator(
            self.d_conv_dim, self.n_classes
        ).to(self.device)

        # Optimizers
        self.g_optimizer = make_adam(self.G, self.g_lr)
        self.d_optimizer = make_adam(self.D, self.d_lr)

        # Learning-rate schedulers share the same decay rule
        self.g_scheduler = torch.optim.lr_scheduler.LambdaLR(self.g_optimizer, lr_lambda=linear_decay)
        self.d_scheduler = torch.optim.lr_scheduler.LambdaLR(self.d_optimizer, lr_lambda=linear_decay)

        # Turn on gradient checkpointing only for networks that expose the hook
        for network in (self.G, self.D):
            if hasattr(network, 'gradient_checkpointing_enable'):
                network.gradient_checkpointing_enable()

        # Log model architectures
        for heading, network in (("Generator Architecture:", self.G),
                                 ("Discriminator Architecture:", self.D)):
            self.reporter.writeInfo(heading)
            self.reporter.writeInfo(str(network))

        # Log number of trainable parameters
        g_params = sum(p.numel() for p in self.G.parameters() if p.requires_grad)
        d_params = sum(p.numel() for p in self.D.parameters() if p.requires_grad)
        self.reporter.writeInfo(f"Generator parameters: {g_params:,}")
        self.reporter.writeInfo(f"Discriminator parameters: {d_params:,}")

        # Only a log line is written when no resume step is set; no explicit
        # weight initialisation is performed here.
        if getattr(self, 'start_step', 0) == 0:
            self.reporter.writeInfo("Initializing model weights from scratch")

    def load_pretrained_model(self):
        """Restore G and D weights from ``<step>_G.pth`` / ``<step>_D.pth``.

        Both files are read from ``self.model_save_path``; ``<step>`` is
        ``self.chechpoint_step`` (attribute spelling as used by this class).
        """
        step = self.chechpoint_step
        for network, name_template in ((self.G, '{}_G.pth'), (self.D, '{}_D.pth')):
            weights_file = os.path.join(self.model_save_path, name_template.format(step))
            network.load_state_dict(torch.load(weights_file))
        print('loaded trained models (step: {})..!'.format(step))

    def reset_grad(self):
        """Clear gradients on both optimizers (discriminator first, then generator)."""
        for optimizer in (self.d_optimizer, self.g_optimizer):
            optimizer.zero_grad()

# 5. Train

In [None]:
def main():
    """Training entry point: prepare folders, build the loader, run the trainer."""
    # cuDNN benchmark mode, enabled for faster training.
    cudnn.benchmark = True

    # One folder per output kind (models, samples, logs), each keyed by VERSION.
    for root_dir in (MODEL_SAVE_PATH, SAMPLE_PATH, LOG_PATH):
        makeFolder(root_dir, VERSION)

    # CIFAR-100 superclass data loader built from the notebook-level settings.
    loader = get_cifar100_superclass_loader(IMAGE_PATH, IMSIZE, BATCH_SIZE, NUM_WORKERS)

    Trainer(loader).train()


if __name__ == '__main__':
    main()

In [None]:
# 훈련이 끝나면 총 학습 시간이 출력됩니다.
# The total training time will be displayed once the training is complete.

# 6. Evaluate

In [None]:
import inceptionID

In [None]:
# Load model and get samples
def sample_from_model(G, batch_size, z_dim, device, num_classes=20):
    """Sample one class-balanced batch from a conditional generator.

    Args:
        G: Callable invoked as ``G(z, labels)``.
        batch_size: Number of samples to draw.
        z_dim: Width of each standard-normal latent vector.
        device: Device on which the latents and labels are created.
        num_classes: Number of class labels to spread over the batch.
            Defaults to 20, the value that was previously hard-coded.

    Returns:
        Tuple ``(samples, labels)``: the output of ``G`` computed without
        gradient tracking, and the int64 label tensor of length
        ``batch_size``.
    """
    z = torch.randn(batch_size, z_dim, device=device)

    # Sorted labels in evenly sized runs. When batch_size is a multiple of
    # num_classes this equals np.repeat(np.arange(num_classes),
    # batch_size // num_classes); otherwise it still yields exactly
    # batch_size labels, so z and labels never disagree in length.
    class_ids = np.arange(batch_size) * num_classes // batch_size
    labels = torch.tensor(class_ids, dtype=torch.long, device=device)

    with torch.no_grad():
        samples = G(z, labels)

    return samples, labels

In [None]:
def main():
    """Evaluate the best saved generator with FID, Inception Score and Intra-FID.

    Steps:
        1. Load the CIFAR-100 training split resized to 299x299 and collect
           its inception outputs via ``inceptionID.get_net_output``.
        2. Restore the generator from
           ``models/cifar100_superclass/checkpoints/best_model.pt``.
        3. Generate 50000 samples, then print FID, Inception Score and the
           per-superclass Intra-FID.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Settings named once so generator construction, sampling and the
    # truncation below cannot drift apart.
    z_dim = 128               # z_dim from the training config
    num_eval_samples = 50000

    # Data loading
    norm_mean = [0.5, 0.5, 0.5]
    norm_std = [0.5, 0.5, 0.5]

    # Images are resized to 299x299 (presumably the input size the inception
    # network expects - confirm against inceptionID) and normalized per channel.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((299, 299)),
        transforms.Normalize(norm_mean, norm_std)
    ])

    # Load CIFAR100
    train_dataset = torchvision.datasets.CIFAR100(
        root="./data",
        train=True,
        download=True,
        transform=transform
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=64,
        shuffle=True
    )

    # Load inception network
    net = inceptionID.load_inception_net()

    # Get real data statistics
    print("Calculating real data statistics...")
    pool, logits, labels = inceptionID.get_net_output(
        device=device,
        train_loader=train_loader,
        net=net
    )
    mu_data, sigma_data = np.mean(pool, axis=0), np.cov(pool, rowvar=False)

    # Load generator from best_model.pt
    print("Loading best model...")
    checkpoint = torch.load('models/cifar100_superclass/checkpoints/best_model.pt', map_location=device)

    G = ResNetGenerator(
        num_features=128,  # g_conv_dim from the training config
        dim_z=z_dim,
        bottom_width=4,
        num_classes=20     # CIFAR-100 superclasses
    ).to(device)

    G.load_state_dict(checkpoint['G_state_dict'])
    G.eval()

    # Zero-argument sampling function handed to the accumulator
    def sample():
        return sample_from_model(G, batch_size=400, z_dim=z_dim, device=device)

    # Get generator statistics
    print("Calculating generator statistics...")
    g_pool, g_logits, g_labels = inceptionID.accumulate_inception_activations(
        sample, net, num_eval_samples
    )

    # Ensure exactly num_eval_samples samples are used
    g_pool = g_pool[:num_eval_samples]
    g_logits = g_logits[:num_eval_samples]
    g_labels = g_labels[:num_eval_samples]

    # Convert the pooled features to NumPy once; both moments reuse the copy
    # instead of transferring the full tensor to the host twice.
    g_pool_np = g_pool.cpu().numpy()
    mu, sigma = np.mean(g_pool_np, axis=0), np.cov(g_pool_np, rowvar=False)

    # Calculate FID
    fid = inceptionID.calculate_fid(mu_data, sigma_data, mu, sigma)
    print(f"FID Score: {fid:.4f}")

    # Calculate Inception Score
    is_mean, is_std = inceptionID.calculate_inception_score(g_logits.cpu().numpy(), 10)
    print(f"Inception Score: {is_mean:.4f} ± {is_std:.4f}")

    # Calculate Intra-FID
    print("Calculating Intra-FID...")
    # The keyword spelling `chage_superclass` is the callee's parameter name.
    intra_fids_mean, intra_fids = inceptionID.calculate_intra_fid(
        pool, logits, labels,
        g_pool, g_logits, g_labels,
        chage_superclass=False
    )

    print(f"Mean Intra-FID: {intra_fids_mean:.4f}")
    print("\nIntra-FID scores per superclass:")
    for i, score in enumerate(intra_fids):
        print(f"Superclass {i}: {score:.4f}")

if __name__ == "__main__":
    main()