In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
import matplotlib.pyplot as plt
import numpy as np

%matplotlib inline
plt.rcParams['figure.figsize'] = (20.0, 16.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

%load_ext autoreload
%autoreload 2

from gan.train import train
from gan.artmodels import Discriminator, Generator
from gan.losses import w_gan_disloss, w_gan_genloss


In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")
print(cuda_available)

In [None]:
# Training hyperparameters.
NOISE_DIM = 100         # noise dimension passed to Generator(noise_dim=...)
NUM_EPOCHS = 50         # number of passes over the training set given to train()
learning_rate = 0.001   # Adam step size used for both the D and G optimizers

# Seed the RNGs so that random draws are repeatable on Restart & Run All.
# This covers model weight init, RandomCrop, and presumably the noise
# sampling inside train().
# NOTE: bit-exact GPU runs would additionally need deterministic cuDNN settings.
SEED = 0
torch.manual_seed(SEED)  # seeds the RNGs of all devices (CPU and CUDA)
np.random.seed(SEED)

In [None]:
batch_size = 1

root = './Abstract_gallery'

# Side length of the square training crops: 512 px for Abstract_gallery,
# 4096 px for any other image root.
imsize = 512 if root == './Abstract_gallery' else 4096

art_train = ImageFolder(root=root, transform=transforms.Compose([
    # Apply the geometric transforms to the PIL image first, then convert to a
    # tensor. PIL resizing always antialiases. Tensor resizing does not on
    # torchvision < 0.17 (and warns about it on 0.15/0.16), which aliases
    # large paintings when they are downscaled.
    #
    # Resize the shorter edge to 15% above imsize, so RandomCrop takes a random
    # imsize x imsize window (light augmentation) rather than the whole image.
    transforms.Resize(int(1.15 * imsize)),
    transforms.RandomCrop(imsize),
    transforms.ToTensor(),
]))

# shuffle=True: ImageFolder lists samples sorted by class folder and file name,
# so without shuffling every epoch would visit images in the same fixed order.
# drop_last=True keeps every batch at exactly batch_size images.
art_loader_train = DataLoader(art_train, batch_size=batch_size, shuffle=True, drop_last=True)

In [None]:
# Build the discriminator (D) and the generator (G), then move each one's
# parameters to the selected device. D is still constructed before G, so any
# seeded random weight initialisation consumes random numbers in the same order.
D = Discriminator()
D = D.to(device)
G = Generator(noise_dim=NOISE_DIM)
G = G.to(device)

In [None]:
# Both networks use Adam with the same learning rate and betas.
# beta1 = 0.5 (below Adam's default of 0.9) is a common GAN setting (e.g. DCGAN).
ADAM_BETAS = (0.5, 0.999)
D_optimizer, G_optimizer = (
    torch.optim.Adam(net.parameters(), lr=learning_rate, betas=ADAM_BETAS)
    for net in (D, G)
)

In [None]:
# Run the project's training loop (gan.train.train) with the Wasserstein-style
# discriminator/generator losses from gan.losses.
# NOTE(review): l_gp presumably weights a gradient-penalty term (WGAN-GP);
# confirm in gan/train.py.
train(
    D,
    G,
    D_optimizer,
    G_optimizer,
    w_gan_disloss,
    w_gan_genloss,
    num_epochs=NUM_EPOCHS,
    show_every=250,   # presumably a display interval in iterations; confirm in train()
    batch_size=batch_size,
    train_loader=art_loader_train,
    device=device,
    train_every=1,
    l_gp=10,
)

In [None]:
# Scratch calculation, presumably the number of scalar values in a batch of
# 8 RGB 512x512 images.
# NOTE(review): looks like leftover debugging; consider removing this cell.
num_values = 8 * 3 * 512 * 512
print(num_values)