In [1]:
import os

import numpy as np
import torch
import torch.optim as optim
import torchvision
from skimage import io
from torch.utils.data import Dataset
from torchvision import transforms

from progan.loss import GANLoss
from progan.modeling import Discriminator, Generator
from progan.trainer import TrainingArguments, GANTrainer

In [2]:
class PokemonDataset(Dataset):
    """In-memory dataset of sprites read from ``<dir_path>/1.png`` ... ``<dir_path>/<n_pokemons>.png``.

    Every image is decoded eagerly in ``__init__`` and kept in ``self.images``;
    ``__getitem__`` only applies the optional transform.

    Args:
        dir_path: Directory holding the numbered PNG files.
        n_pokemons: Number of files to load, starting from ``1.png``.
        keep_alpha: When False, the alpha channel of RGBA images is dropped.
        transform: Optional callable applied to an image on every ``__getitem__``.
    """

    def __init__(self, dir_path, n_pokemons, keep_alpha = True, transform = None):
        self.dir_path = dir_path
        self.n_pokemons = n_pokemons
        self.transform = transform

        self.images = []

        # Files are numbered from 1, not 0.
        for number in range(1, n_pokemons + 1):
            path = os.path.join(self.dir_path, f'{number}.png')
            image = io.imread(path)

            # Alpha handling only makes sense for RGBA arrays; images without
            # an alpha channel are stored as loaded.
            if image.ndim == 3 and image.shape[2] == 4:
                # A top-left red value of 255 is used as the cue for a white
                # backdrop, which is then darkened according to transparency.
                if image[0, 0, 0] == 255:
                    image = self._darken_by_alpha(image)

                if not keep_alpha:
                    image = image[:, :, :-1]

            self.images.append(image)

    @staticmethod
    def _darken_by_alpha(image):
        """Return a copy of ``image`` with ``255 - alpha`` subtracted from each colour channel.

        The subtraction saturates at 0: it is carried out in int16 and clipped,
        because doing it directly on uint8 data wraps around for
        semi-transparent pixels (e.g. 100 - 155 would become 201).
        Assumes 8-bit channels with alpha stored last.
        """
        transparency = 255 - image[:, :, -1:].astype(np.int16)
        colours = image[:, :, :-1].astype(np.int16) - transparency
        darkened = image.copy()
        darkened[:, :, :-1] = np.clip(colours, 0, 255).astype(image.dtype)
        return darkened

    def __len__(self):
        """Number of images held in memory."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return image ``idx``, passed through ``self.transform`` when one is set."""
        image = self.images[idx]
        if self.transform is not None:
            return self.transform(image)
        return image

In [3]:
class ToTensor:
    """Callable that turns a 3-D NumPy array into a float32 tensor, last axis moved to the front."""

    def __call__(self, x):
        # Reorder axes (0, 1, 2) -> (2, 0, 1), i.e. channels-last to
        # channels-first when x is an image laid out as rows x cols x channels.
        channels_first = x.transpose(2, 0, 1)
        tensor = torch.from_numpy(channels_first)
        return tensor.to(torch.float32)


In [4]:
toTensor = ToTensor()

# Per-sample pipeline: array -> float tensor, random horizontal flip, then
# (x - 128) / 128 scaling.
# Resize((H, W)) and RandomAffine(degrees=0, translate=(0.05, 0.05),
# scale=(0.9, 1.1)) steps are currently left out of the pipeline.
_pipeline_steps = [
    toTensor,
    transforms.RandomHorizontalFlip(),
    transforms.Normalize(128, 128),
]
compose = transforms.Compose(_pipeline_steps)

In [5]:
# --- Paths ---
DATA_DIR = './ultra-sun-ultra-moon'   # directory handed to PokemonDataset
LOG_DIR = './logs_w'

# --- Network shape (passed to both Discriminator and Generator) ---
LATENT_DIM = 128
FILTERS_MULTPL = 32
MAX_FILTERS = 256
GAN_DEPTH = 5
IMAGE_DIM = 3

# --- Optimisation ---
LEARNING_RATE = 1e-3                  # shared by both Adam optimizers
LAMBDA_GP = 10                        # forwarded to GANLoss
DISCRIMINATOR_ITERATIONS = 1

# --- Schedule / loader settings ---
# NOTE(review): these are not passed to any call in this notebook's cells —
# TODO confirm they are still needed.
NUM_WORKERS = 0
N_EPOCHS = 8_192
TRANSITION_STEPS = [8_000, 16_000, 32_000, 64_000, 128_000]
RANK_STEPS = [4_000, 8_000, 16_000, 32_000, 64_000, 128_000]
H = W = 64


In [6]:
# Load sprites 1.png .. 807.png from DATA_DIR, drop their alpha channel and
# attach the transform pipeline.
dataset = PokemonDataset(
    dir_path=DATA_DIR,
    n_pokemons=807,
    transform=compose,
    keep_alpha=False,
)

In [7]:
# Both networks are built from the same five configuration values, in the same
# positional order. The discriminator is constructed first.
_network_args = (LATENT_DIM, FILTERS_MULTPL, MAX_FILTERS, GAN_DEPTH, IMAGE_DIM)

discriminator = Discriminator(*_network_args)
generator = Generator(*_network_args)

In [8]:
# One Adam optimizer per network, with identical hyper-parameters.
_adam_settings = {"lr": LEARNING_RATE, "betas": (0, 0.99)}

d_optimizer = optim.Adam(discriminator.parameters(), **_adam_settings)
g_optimizer = optim.Adam(generator.parameters(), **_adam_settings)

In [9]:
# Build the project's GANLoss helper from both optimizers and LAMBDA_GP.
# Note the argument order: generator optimizer first, then discriminator optimizer.
# LAMBDA_GP is presumably a gradient-penalty weight — TODO confirm in progan.loss.
gan_loss = GANLoss(g_optimizer, d_optimizer, LAMBDA_GP)

In [10]:
# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Move both networks onto the selected device.
generator.to(device)
discriminator.to(device)

print(f"Device : {device}")

Device : cuda:0


In [11]:
# Run settings: a list of six batch sizes (presumably one per resolution rank —
# TODO confirm in progan.trainer), a checkpoint name to resume from, and
# save_steps = 1000.
training_args = TrainingArguments(
    batch_size=[32] * 2 + [16] * 4,
    save_steps=1_000,
    resume_from="1617481270_iter_68000",
)

In [12]:
# Assemble the trainer from the two networks, the loss helper, the dataset and
# the run settings (discriminator is passed before generator).
trainer = GANTrainer(discriminator, generator, gan_loss, dataset, training_args)

In [13]:
# Start training. The saved output below shows this run being stopped by hand
# (KeyboardInterrupt) while at "Rank 4 = 64x64".
trainer.train()

Training : Rank 4 = 64x64:  14%|█▎        | 8680/64000 [27:35<5:27:00,  2.82it/s]

KeyboardInterrupt: 