In [1]:
import torch
import torch.optim as optim

In [None]:
import torch_mimicry as mmc

In [3]:
import models.ssd_sngan_32 as ssd_sngan

In [4]:
# Load one of the supported image datasets ('CIFAR10', 'CIFAR100', 'STL10', 'FashionMNIST'), choosing the sample size, the image size, and optional grayscale conversion.
import os
from torch.utils.data import random_split
from torchvision import transforms, datasets

def dataset_split_shape(name, n = 60000, size = 32, grayScale = False, convert_tensor=True, transform_data = True, root='./datasets/', download=True, seed=42):
  """Load a torchvision image dataset, optionally transform it, and draw a random subset.

  Args:
    name: one of 'CIFAR10', 'CIFAR100', 'STL10', 'FashionMNIST'.
    n: number of samples to keep. If None or larger than the dataset, the whole
      dataset is returned as-is; otherwise a random subset of exactly n samples
      is drawn with a generator seeded by ``seed``.
    size: argument passed to ``transforms.Resize``.
    grayScale: if True, append ``transforms.Grayscale()``.
    convert_tensor: if True, append ``Normalize((0.5,), (0.5,))``. Despite the
      name, ``ToTensor`` is always applied when ``transform_data`` is True.
    transform_data: if False, an empty ``Compose`` is used (no transforms).
    root: parent directory; files are stored under ``root/{name}_{n}_{size}``.
    download: forwarded to the torchvision dataset constructor.
    seed: seed for the ``random_split`` generator (42 keeps the old split).

  Returns:
    The torchvision dataset, or a ``torch.utils.data.Subset`` of n samples.

  Raises:
    ValueError: if ``name`` is not one of the supported datasets.
  """
  supported = ('CIFAR10', 'CIFAR100', 'STL10', 'FashionMNIST')
  # Validate before touching the filesystem, so an unknown name neither creates
  # a stray directory nor hands None back to the caller.
  if name not in supported:
    raise ValueError(f"invalid name: {name!r}; expected one of {supported}")

  dataset_dir = os.path.join(root, f"{name}_{n}_{size}")
  os.makedirs(dataset_dir, exist_ok=True)

  if transform_data:
    transform_list = [transforms.ToTensor(), transforms.Resize(size)]
    if grayScale:
      transform_list.append(transforms.Grayscale())
    if convert_tensor:
      transform_list.append(transforms.Normalize((0.5, ), (0.5, )))
  else:
    transform_list = []
  transformer = transforms.Compose(transform_list)

  # Constructor arguments shared by every supported dataset.
  common_kwargs = dict(root=dataset_dir, download=download, transform=transformer)
  if name == 'STL10':
    # Same as the original package: use the unlabeled split by default.
    dataset = datasets.STL10(split='unlabeled', **common_kwargs)
  else:
    dataset = getattr(datasets, name)(**common_kwargs)

  if n is None or n > len(dataset):
    return dataset
  generator = torch.Generator().manual_seed(seed)
  subset, _ = random_split(dataset, [n, len(dataset) - n], generator=generator)
  return subset

In [None]:
# Load the CIFAR-10 train split resized to 32x32. n=60000 is larger than the
# 50,000 training images, so dataset_split_shape returns the whole train set
# without subsampling. The "500" in the variable name appears to be a leftover
# from a 500-image experiment; the name is kept because later cells use it.
CIFAR10_500_32 = dataset_split_shape('CIFAR10', n = 60000, size = 32)

In [None]:
# Quick sanity check of the loaded dataset: how many samples it has, and the
# shape of the first (transformed) image.
print('check dataset')
num_images = len(CIFAR10_500_32)
print(f"number of images: {num_images}")
first_image = CIFAR10_500_32[0][0]
print(f"shape of images: {first_image.shape}")

In [7]:
# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device("cpu")

# Shuffled mini-batches of 64 samples, loaded by 4 worker processes.
CIFAR10_500_32_dataloader = torch.utils.data.DataLoader(
    dataset=CIFAR10_500_32,
    batch_size=64,
    shuffle=True,
    num_workers=4,
)

In [8]:
# Generator and discriminator, moved to the selected device (G is built first,
# then D, as before).
netG = ssd_sngan.SSD_SNGANGenerator32().to(device)
netD = ssd_sngan.SSD_SNGANDiscriminator32().to(device)

# Both networks use Adam with lr=2e-4 and betas=(0.0, 0.9).
optD = optim.Adam(netD.parameters(), lr=2e-4, betas=(0.0, 0.9))
optG = optim.Adam(netG.parameters(), lr=2e-4, betas=(0.0, 0.9))

# Configure the trainer only; training itself is started in the next cell.
trainer = mmc.training.Trainer(
    netG=netG,
    netD=netD,
    optG=optG,
    optD=optD,
    dataloader=CIFAR10_500_32_dataloader,
    num_steps=100000,  # number of training iterations
    n_dis=5,
    lr_decay='linear',
    log_dir='./log/CIFAR10_500_32',  # user-chosen output directory
    device=device,
)

In [None]:
# Run the training loop configured above; wall-clock time depends on the batch size.
trainer.train()

In [10]:

def _dataset_tag(dataset):
  """Return a short, filesystem-safe label for ``dataset`` to embed in a file name.

  Strings are used as given (with unsafe characters replaced by '_'). For
  dataset objects, the class name plus the sample count is used: formatting the
  object itself would embed its repr, which for torchvision datasets spans
  several lines and includes the root path (path separators), giving an
  invalid file path.
  """
  if isinstance(dataset, str):
    label = dataset
  else:
    base = getattr(dataset, 'dataset', dataset)  # unwrap torch.utils.data.Subset
    label = type(base).__name__
    try:
      label = f"{label}-{len(dataset)}"
    except TypeError:  # objects without __len__ keep just the class name
      pass
  return ''.join(ch if ch.isalnum() or ch in '-_.' else '_' for ch in label)


def _sample_count_tag(num_samples):
  """Format a sample count as '<k>k' for whole thousands, otherwise verbatim.

  With plain ``num_samples // 1000`` every count below 1000 collapsed to '0k',
  so e.g. 20 and 500 samples mapped to the same file.
  """
  if num_samples % 1000 == 0:
    return f"{num_samples // 1000}k"
  return str(num_samples)


def _metric_file(log_dir, metric, prefix, dataset, num_samples, seed):
  """Return '<log_dir>/metrics/<metric>/statistics/<prefix>_<dataset>_<count>_run_<seed>.npz'.

  The containing directory is created if it does not exist yet.
  """
  stats_dir = os.path.join(log_dir, 'metrics', metric, 'statistics')
  os.makedirs(stats_dir, exist_ok=True)
  file_name = "{}_{}_{}_run_{}.npz".format(
      prefix, _dataset_tag(dataset), _sample_count_tag(num_samples), seed)
  return os.path.join(stats_dir, file_name)


def create_stats_file(log_dir, num_real_samples, seed, dataset, metric):
  """Return the path of the real-data statistics file for ``metric``.

  Layout: ``<log_dir>/metrics/<metric>/statistics/<metric>_stats_<dataset>_<count>_run_<seed>.npz``,
  where <count> is e.g. '50k' for 50000 samples. ``dataset`` may be a name
  string or a dataset object. The directory is created if needed.
  """
  return _metric_file(log_dir, metric, f"{metric}_stats", dataset, num_real_samples, seed)

# For the KID metric: path of the cached feature file.
def create_feat_file(log_dir, num_samples, seed, dataset, metric):
  """Return the path of the feature file for ``metric`` (intended for KID).

  Layout: ``<log_dir>/metrics/<metric>/statistics/<metric>_feat_<dataset>_<count>_run_<seed>.npz``.
  Previously the file was mislabeled with the FID prefix 'fid_stats'. The
  directory is created if needed.
  """
  return _metric_file(log_dir, metric, f"{metric}_feat", dataset, num_samples, seed)



In [None]:
# Compute FID for the checkpoint at step 100000, one run with seed 0.
# The real-statistics file name uses the same sample count and seed as the
# evaluation. It also uses a plain string tag for the dataset: formatting the
# dataset object into the name would embed its multi-line repr (including its
# root path) and produce an invalid file path.
fid_log_dir = './log/CIFAR10_500_32'
fid_num_samples = 50000
fid_seed = 0
mmc.metrics.evaluate(
    metric='fid',
    log_dir=fid_log_dir,
    netG=netG,
    num_real_samples=fid_num_samples,
    num_fake_samples=fid_num_samples,
    dataset=CIFAR10_500_32,
    evaluate_step=100000,
    start_seed=fid_seed,
    num_runs=1,
    device=device,
    stats_file=create_stats_file(fid_log_dir, fid_num_samples, fid_seed, 'CIFAR10_500_32', 'fid'))

In [15]:
# Stand-alone logger for the same run directory, used to visualise generator samples.
# dataset_size is taken from the loaded dataset instead of the hard-coded 60000:
# CIFAR10_500_32 (built with n=60000) holds the whole 50,000-image CIFAR-10 train split.
Log = mmc.training.Logger(log_dir='./log/CIFAR10_500_32', num_steps=100000, dataset_size=len(CIFAR10_500_32), device=device)

In [16]:
# Visualise samples from the generator, logged at the final training step.
Log.vis_images(global_step=100000, netG=netG)