In [1]:
import functools
import importlib
import math
import os
import random
import time
from collections import Counter
from datetime import timedelta

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm

import utils
import train_fns

In [2]:
def set_random_seed(seed):
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (plus the CUDA generators when a GPU is available), then sets cuDNN to
    deterministic mode with benchmarking switched off.

    Args:
        seed: Integer seed shared by all generators.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        # Current device first, then every visible device.
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_random_seed(112)

In [None]:
# Per-channel normalisation constants (three channels, 0.5 each).
norm_mean = [0.5] * 3
norm_std = [0.5] * 3
image_size = (32, 32)

# Random flip and colour jitter are applied first; tensor conversion and
# normalisation come last.
train_transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ]
)

# CIFAR-100 training split, stored under ./data and downloaded when missing.
train_dataset = torchvision.datasets.CIFAR100(
    root="./data",
    train=True,
    download=True,
    transform=train_transform,
)

# Shuffled loader; the trailing incomplete batch is dropped.
loader_options = dict(
    batch_size=256,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
    drop_last=True,
)
train_loader = DataLoader(train_dataset, **loader_options)

In [None]:
# Fine CIFAR-100 label (0-99) -> superclass label (0-19), five fine classes each.
superclass_mapping = {
      4: 0, 30: 0, 55: 0, 72: 0, 95: 0,           # aquatic mammals
      1: 1, 32: 1, 67: 1, 73: 1, 91: 1,           # fish
      54: 2, 62: 2, 70: 2, 82: 2, 92: 2,          # flowers
      9: 3, 10: 3, 16: 3, 28: 3, 61: 3,           # food containers
      0: 4, 51: 4, 53: 4, 57: 4, 83: 4,           # fruit and vegetables
      22: 5, 39: 5, 40: 5, 86: 5, 87: 5,          # household electrical devices
      5: 6, 20: 6, 25: 6, 84: 6, 94: 6,           # household furniture
      6: 7, 7: 7, 14: 7, 18: 7, 24: 7,            # insects
      3: 8, 42: 8, 43: 8, 88: 8, 97: 8,           # large carnivores
      12: 9, 17: 9, 37: 9, 68: 9, 76: 9,          # large man-made outdoor things
      23: 10, 33: 10, 49: 10, 60: 10, 71: 10,     # large natural outdoor scenes
      15: 11, 19: 11, 21: 11, 31: 11, 38: 11,     # large omnivores and herbivores
      34: 12, 63: 12, 64: 12, 66: 12, 75: 12,     # medium-sized mammals
      26: 13, 45: 13, 77: 13, 79: 13, 99: 13,     # non-insect invertebrates
      2: 14, 11: 14, 35: 14, 46: 14, 98: 14,      # people
      27: 15, 29: 15, 44: 15, 78: 15, 93: 15,     # reptiles
      36: 16, 50: 16, 65: 16, 74: 16, 80: 16,     # small mammals
      47: 17, 52: 17, 56: 17, 59: 17, 96: 17,     # trees
      8: 18, 13: 18, 48: 18, 58: 18, 90: 18,      # vehicles 1
      41: 19, 69: 19, 81: 19, 85: 19, 89: 19      # vehicles 2
      }
assert len(superclass_mapping) == 100, "every fine class must map to a superclass"

# Relabel the dataset in a re-runnable way. Superclass ids 0-19 are themselves
# valid keys of the mapping, so remapping already-remapped targets would
# silently scramble the labels if this cell were executed twice. The original
# fine labels are therefore stashed on the dataset the first time through and
# the superclass targets are always derived from that copy.
relabel_dataset = train_loader.dataset
if not hasattr(relabel_dataset, "fine_targets"):
    relabel_dataset.fine_targets = list(relabel_dataset.targets)
relabel_dataset.targets = [superclass_mapping[label] for label in relabel_dataset.fine_targets]

print(set(relabel_dataset.targets))
label_counts = Counter(relabel_dataset.targets)

In [None]:
# Report how many training samples carry each superclass label, in label order.
print("Superclass label counts:")
for label in sorted(label_counts):
    print(f"Label {label}: {label_counts[label]} samples")

In [7]:
# Experiment configuration. The whole dict is unpacked into the Generator and
# Discriminator constructors and also handed to train_fns / utils.
# NOTE(review): train_loader is built with batch_size=256, which equals
# batch_size * num_D_steps * num_D_accumulations (64 * 4 * 1) from this dict;
# presumably train_fns splits each loaded batch accordingly - confirm before
# changing any of these values independently.
config = {
    'load_in_mem': False,
    'dict_decay': 0.9,
    'commitment': 10.0,
    'discrete_layer': '0123',
    'dict_size': 6,
    'model': 'BigGAN',
    'G_param': 'SN',
    'D_param': 'SN',
    'G_ch': 64,
    'D_ch': 64,
    'G_depth': 1,
    'D_depth': 1,
    'D_wide': True,
    'G_shared': False,
    'shared_dim': 0,
    'dim_z': 128,
    'z_var': 1.0,
    'hier': False,
    'cross_replica': False,
    'mybn': False,
    'G_nl': 'relu',
    'D_nl': 'relu',
    'G_attn': '0',
    'D_attn': '0',
    'norm_style': 'bn',
    'G_init': 'N02',
    'D_init': 'N02',
    'skip_init': False,
    'G_lr': 0.0002,
    'D_lr': 0.0001,
    'G_B1': 0.0,
    'D_B1': 0.0,
    'G_B2': 0.999,
    'D_B2': 0.999,
    'batch_size': 64,
    'G_batch_size': 0,
    'num_G_accumulations': 1,
    'num_D_steps': 4,
    'num_D_accumulations': 1,
    'split_D': False,
    'num_epochs': 700,
    'parallel': False,
    'D_mixed_precision': False,
    'G_mixed_precision': False,
    'accumulate_stats': False,
    'num_standing_accumulations': 16,
    'G_eval_mode': False,
    'save_every': 1000,
    'num_save_copies': 2,
    'num_best_copies': 5,
    'base_root': '',
    'data_root': 'data',
    'weights_root': 'weights',
    'samples_root': 'samples',
    'name_suffix': 'quant',
    'experiment_name': '',
    'ema': True,
    'ema_decay': 0.9999,
    'use_ema': True,
    'ema_start': 1000,
    'adam_eps': 1e-08,
    'BN_eps': 1e-05,
    'SN_eps': 1e-08,
    'num_G_SVs': 1,
    'num_D_SVs': 1,
    'num_G_SV_itrs': 1,
    'num_D_SV_itrs': 1,
    'G_ortho': 0.0,
    'D_ortho': 0.0,
    'toggle_grads': True,
    'load_weights': '',
    'resume': False,
}

# Dataset-specific entries: 32x32 images, 20 superclasses, and a separate
# (non in-place) ReLU instance for each network.
config.update(
    resolution=32,
    n_classes=20,
    G_activation=nn.ReLU(inplace=False),
    D_activation=nn.ReLU(inplace=False),
)

In [6]:
# Load IPython's autoreload extension and switch it to mode 2, so edits to
# imported project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [8]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [9]:
# Import the architecture module named by config['model'].
# importlib.import_module returns the named module itself even for a dotted
# path, whereas the built-in __import__ would hand back only the top-level
# package in that case.
model = importlib.import_module(config['model'])

In [None]:
# Generator / discriminator pair, both constructed from the full config dict.
G = model.Generator(**config).to(device)
D = model.Discriminator(**config).to(device)

# Second generator instance, built from the same config with skip_init and
# no_optim switched on; it is paired with G through utils.ema below
# (presumably to hold an exponential moving average of G's weights - confirm).
ema_generator_config = dict(config, skip_init=True, no_optim=True)
G_ema = model.Generator(**ema_generator_config).to(device)

ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])

In [None]:
# Combined generator+discriminator wrapper used by the training function.
GD = model.G_D(G, D)

# Show both architectures for reference.
for network in (G, D):
    print(network)

In [12]:
# Mutable training state: four counters starting at zero plus the config itself.
state_dict = dict.fromkeys(('itr', 'epoch', 'save_num', 'save_best_num'), 0)
state_dict['config'] = config

In [13]:
# The generator batch is never smaller than the discriminator batch.
G_batch_size = max(config['G_batch_size'], config['batch_size'])

# Both (z, y) pairs are created with identical arguments.
make_z_y = functools.partial(
    utils.prepare_z_y,
    G_batch_size,
    G.dim_z,
    config['n_classes'],
    device=device,
    fp16=False,
)

# Pair used during training.
z_, y_ = make_z_y()

# Separate pair, sampled once here and passed to save_and_sample during training.
fixed_z, fixed_y = make_z_y()
for fixed_input in (fixed_z, fixed_y):
    fixed_input.sample_()

In [14]:
# Run name passed to train_fns.save_and_sample and utils.load_weights.
experiment_name = 'BigGAN'

In [15]:
# Build the per-batch training step. The returned callable is invoked as
# train(x, y) in the training loop, and its result is indexed with the keys
# 'D_loss_real', 'D_loss_fake' and 'G_loss'.
train = train_fns.GAN_training_function(G, D, GD, z_, y_, 
                                            ema, state_dict, config)

In [None]:
# Wall-clock timer for the whole run.
start_time = time.time()

# Loss entries of the metrics mapping that are mirrored onto the progress bar.
progress_loss_keys = ("D_loss_real", "D_loss_fake", "G_loss")

# Resume from the epoch recorded in state_dict (0 for a fresh run).
for epoch in range(state_dict['epoch'], config['num_epochs']):
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{config['num_epochs']}")

    for x, y in progress_bar:
        # Advance the global iteration counter before the step runs.
        state_dict['itr'] += 1

        # Every step starts with all three networks in training mode.
        G.train()
        D.train()
        G_ema.train()

        x, y = x.to(device), y.to(device)

        # One training step on this batch.
        metrics = train(x, y)

        # Show the current losses on the progress bar.
        progress_bar.set_postfix({key: metrics[key] for key in progress_loss_keys})

        # Periodic checkpoint + sample dump.
        if state_dict['itr'] % config['save_every'] == 0:
            if config['G_eval_mode']:
                # NOTE(review): only G is switched to eval here; G_ema stays in
                # train mode - confirm that is intended.
                print('Switchin G to eval mode...')
                G.eval()
            train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y,
                                      state_dict, config, experiment_name)

    # Record the finished epoch.
    state_dict['epoch'] += 1

end_time = time.time()
elapsed_time = timedelta(seconds=end_time - start_time)

# Print the training time
print(f"Total training time: {elapsed_time}")


# Score Metrics

In [18]:
import torchvision.transforms as transforms
import torchvision
from torch.utils.data import DataLoader
import numpy as np
import inceptionID
import torch

In [None]:
# Same normalisation constants as the training pipeline.
norm_mean = [0.5] * 3
norm_std = [0.5] * 3
image_size = (32, 32)

# No augmentation for evaluation: tensor conversion, resize to 299x299,
# then normalisation.
train_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((299, 299)),
        transforms.Normalize(norm_mean, norm_std),
    ]
)

# NOTE(review): this rebinds train_transform / train_dataset / train_loader
# from the training section. The dataset created here is a fresh CIFAR100
# object, so its targets are the original fine labels (0-99); the superclass
# relabelling done earlier applied to the previous dataset object only.
# Confirm the label space the inceptionID helpers expect.
train_dataset = torchvision.datasets.CIFAR100(
    root="./data",
    train=True,
    download=True,
    transform=train_transform,
)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Inception network used for all metrics in this section.
net = inceptionID.load_inception_net()

In [None]:
# Inception outputs for the real images.
# NOTE(review): the device is hard-coded to "cuda:0" here instead of using the
# notebook-level `device` variable.
pool, logits, labels = inceptionID.get_net_output(
    device="cuda:0",
    train_loader=train_loader,
    net=net,
)

# Mean and covariance of the real-data pool features (rows are treated as
# observations, columns as variables), used as the FID reference statistics.
mu_data = np.mean(pool, axis=0)
sigma_data = np.cov(pool, rowvar=False)

In [None]:
# Load saved weights into G. G is handed over either as the first positional
# argument (when config['use_ema'] is false) or as the seventh (when both
# config['ema'] and config['use_ema'] are true); the second argument is None
# and load_optim=False.
# NOTE(review): presumably the seventh slot is the EMA-generator slot, so with
# the current config the EMA weights are loaded into G - confirm against
# utils.load_weights. state_dict is passed in as well and may be overwritten.
utils.load_weights(G if not (config['use_ema']) else None, None, state_dict, 
                     config['weights_root'], experiment_name, config['load_weights'],
                     G if config['ema'] and config['use_ema'] else None,
                     strict=False, load_optim=False)

In [22]:
# Batch size used when sampling from G for the metrics below.
G_batch_size = 200
z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'],
                           device=device, fp16=False,
                           z_var=config['z_var'])

# Replace the label buffer returned above with a fixed, class-balanced tensor:
# 0, 1, ..., n_classes-1 repeated until the batch is full. The class count and
# repeat factor come from the config and batch size instead of the literals
# 20 and 10, and the tensor is created on the notebook-wide `device` rather
# than through an unconditional .cuda() call.
# NOTE(review): y_ is now a plain tensor, not the object prepare_z_y returned;
# presumably utils.sample does not resample labels - confirm.
assert G_batch_size % config['n_classes'] == 0, \
    "G_batch_size must be a multiple of n_classes for balanced labels"
y_ = torch.arange(config['n_classes'], dtype=torch.long, device=device).repeat(
    G_batch_size // config['n_classes'])

# Zero-argument sampler handed to the inception accumulation helper.
sample = functools.partial(utils.sample, G=G, z_=z_, y_=y_, config=config)

In [23]:
# Number of generated samples requested from the accumulator and kept afterwards.
num_generated_samples = 50000

g_pool, g_logits, g_labels = inceptionID.accumulate_inception_activations(
    sample, net, num_generated_samples
)

# Trim every output to the requested sample count.
g_pool = g_pool[:num_generated_samples]
g_logits = g_logits[:num_generated_samples]
g_labels = g_labels[:num_generated_samples]

## FID

In [24]:
# Convert the generated pool features to NumPy once, then take the mean and
# covariance over samples (rows are observations).
g_pool_np = g_pool.cpu().numpy()
mu = np.mean(g_pool_np, axis=0)
sigma = np.cov(g_pool_np, rowvar=False)

In [None]:
# FID between the real-data statistics and the generated-sample statistics.
fid_value = inceptionID.calculate_fid(mu_data, sigma_data, mu, sigma)
print("FID : ", fid_value)

## Inception Score

In [None]:
# Inception score computed from the generated logits.
# NOTE(review): the second argument (10) is presumably the number of splits,
# and the second return value is printed as "cov"; confirm whether
# calculate_inception_score actually returns a standard deviation rather than
# a covariance.
m, cov = inceptionID.calculate_inception_score(g_logits.cpu().numpy(), 10)
print("mean : ", m)
print("cov : ", cov)

## Intra-FID

In [27]:
# Per-class (intra) FID from the real and generated pool features / labels.
# The keyword is spelled 'chage_superclass' in this call and must match the
# helper's signature.
# NOTE(review): `labels` comes from a CIFAR100 dataset that was not relabelled
# to superclasses in this section, while the generated labels span
# config['n_classes'] = 20 classes; confirm that chage_superclass=False is
# correct for this combination.
intra_fids_mean, intra_fids = inceptionID.calculate_intra_fid(pool, logits, labels, g_pool, g_logits, g_labels, chage_superclass=False)

In [None]:
# Report the first value returned by calculate_intra_fid (the mean over classes).
print("intra-FID : ", intra_fids_mean)

In [None]:
# One line per class index with its intra-FID value. Indexing is kept (rather
# than direct iteration) because the container type of intra_fids is not
# visible here.
num_intra_scores = len(intra_fids)
for class_index in range(num_intra_scores):
    class_score = intra_fids[class_index]
    print(f"superclass intra-score {class_index}: {class_score}")