In [None]:
import numpy as np
import torch
from torch.utils.data import random_split
import matplotlib.pyplot as plt
from model.unet import Unet
from model.diffusion import Diffusion_Models
from model.dataset import SpritesDataset
from train import train

# ---- Training hyperparameters ----
BATCH_SIZE = 128
EPOCHS = 100
LR = 1e-3  # Adam learning rate
SEED = 42  # seeds torch's global RNG (e.g. weight init) so runs are repeatable
torch.manual_seed(SEED)

# ---- Dataset ----
dataset_data_path = './sprites_1788_16x16.npy'
dataset_label_path = './sprite_labels_nc_1788_16x16.npy'
# Class names, indexed by position in a sample's label vector.
LABELS = ['hero', 'non-hero', 'food', 'spell', 'side-facing']

# ---- Network hyperparameters ----
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
N_FEAT = 64  # hidden feature dimension passed to the U-Net as n_feat
N_CFEAT = len(LABELS)  # label-vector size (5); derived so it can never drift from LABELS
HEIGHT = 16  # images are HEIGHT x HEIGHT (16x16 sprites)
SAVE_DIR = './weights/'  # checkpoint dir; keep the trailing slash, paths are built by concatenation

# ---- Diffusion hyperparameters ----
TIMESTEPS = 500

## Entrenamiento

In [None]:
model = Unet(in_channels=3, n_feat=N_FEAT, n_cfeat=N_CFEAT, height=HEIGHT).to(DEVICE)
# To resume from a saved checkpoint instead of training from scratch, uncomment:
# model.load_state_dict(torch.load(SAVE_DIR + "model_100.pth", map_location=DEVICE))
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
df = Diffusion_Models(TIMESTEPS)

dataset = SpritesDataset(dataset_data_path, dataset_label_path)

# Hold out a fixed-size validation set and train on the rest. The train size is
# derived from len(dataset) instead of hard-coding 79400, so the split stays valid
# if the sprite files change; for the 89,400-sample file the split is identical.
n_val = 10000
if len(dataset) <= n_val:
    raise ValueError(
        f"dataset has {len(dataset)} samples; need more than {n_val} to hold out a validation set"
    )
generator1 = torch.Generator().manual_seed(42)  # fixed seed -> same split on every run
train_dataset, validation_dataset = random_split(
    dataset, [len(dataset) - n_val, n_val], generator=generator1
)

train(model, df, optimizer, train_dataset, validation_dataset, BATCH_SIZE, EPOCHS, DEVICE)

## Generación

In [None]:
# Rebuild the network and load trained weights so this section can run on a fresh
# kernel without re-running the training cell.
model = Unet(in_channels=3, n_feat=N_FEAT, n_cfeat=N_CFEAT, height=HEIGHT).to(DEVICE)
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
# The file name presumably matches what train() writes after EPOCHS=100 epochs -- TODO confirm.
model.load_state_dict(torch.load(SAVE_DIR+"model_100.pth", map_location=DEVICE))
model.eval()
# draw_samples() calls df.unorm, and df is otherwise only created in the training
# cell; recreate it here (presumably it holds just the schedule for TIMESTEPS steps).
df = Diffusion_Models(TIMESTEPS)

def draw_samples(samples, ctx, filename, cols, labels, diffusion=None):
    """Plot generated samples in a grid and save the figure as ``./{filename}.png``.

    Each subplot is titled with the class name picked from ``LABELS`` by the
    argmax of ``labels[i]``, followed by ``ctx[i]`` shown as the "time" value.
    The figure is closed after saving, so nothing is displayed inline.

    Args:
        samples: tensor batch of images indexed along dim 0. Each image goes
            through ``unorm`` and ``permute(1, 2, 0)`` before ``imshow``, so it is
            presumably channel-first (C, H, W) -- TODO confirm.
        ctx: per-sample values printed after "time:" in each title.
        filename: output file stem (no extension).
        cols: number of grid columns; rows = ceil(len(samples) / cols).
        labels: per-sample label vectors/scores, indexed in ``LABELS`` order.
        diffusion: object providing ``unorm`` to un-normalise images. Defaults to
            the notebook-global ``df`` so existing calls keep working.
    """
    if diffusion is None:
        diffusion = df
    # matplotlib and np.argmax need host memory without an autograd graph, so
    # bring every tensor input to the CPU (not just the samples).
    samples = samples.detach().to("cpu")
    if isinstance(labels, torch.Tensor):
        labels = labels.detach().to("cpu")
    if isinstance(ctx, torch.Tensor):
        ctx = ctx.detach().to("cpu")

    n_samples = samples.shape[0]
    # Ceiling division; keep at least one row so the figure height is never 0.
    rows = max(1, -(-n_samples // cols))

    fig = plt.figure(figsize=(10, rows * 2))
    for i, img in enumerate(samples):
        step = ctx[i]
        if isinstance(step, torch.Tensor) and step.numel() == 1:
            step = step.item()  # show "0.5" rather than "tensor(0.5000)"
        ax = fig.add_subplot(rows, cols, i + 1)
        ax.axis('off')
        ax.set_title(f"{LABELS[np.argmax(labels[i])]}, time: {step}")
        ax.imshow(diffusion.unorm(img).permute(1, 2, 0))
    fig.savefig(f'./{filename}.png')
    plt.close(fig)