<a href="https://colab.research.google.com/github/AndyYu25/AImpression/blob/main/AImpression.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive at /content/gdrive; DATA_DIR and sample_dir below both
# point inside this mount, so this cell has to run before anything else.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
import os
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.datasets.folder import default_loader
import torchvision.transforms as tt
import torch
import torch.nn as nn
from tqdm.notebook import tqdm
import torch.nn.functional as F
from torchvision.utils import save_image
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
from PIL import Image

# Root folder handed to ImageFolder for the training images.
DATA_DIR = "/content/gdrive/MyDrive/Colab Notebooks/Vincent_van_Gogh"
# Output folder: sample grids are written here, checkpoints under <sample_dir>/models/.
sample_dir = '/content/gdrive/MyDrive/Colab Notebooks/AImpression'








# Side length (pixels) used for both Resize and CenterCrop.
imageSize = 128
# Batch size for the DataLoader and for the latent batches drawn during training.
batchSize = 64
# denorm() adds stats[0][0] and multiplies by stats[1][0];
# presumably per-channel (mean, std) for RGB -- confirm against the transform pipeline.
stats = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
# Channel count of the latent input fed to the generator (shape: latentSize x 1 x 1).
latentSize = 256
# fixedLatent holds imageRows * imageCols samples for the preview grid.
imageRows = 8
imageCols = 8
# Base channel width; generator/discriminator layers use multiples of this value.
featureSize = 64

# Preprocessing applied to every training image.
# The generator ends in Tanh (outputs in [-1, 1]) and denorm() undoes a
# stats-based scaling, so real images must be normalised with the same stats;
# otherwise the discriminator sees real data in [0, 1] and fakes in [-1, 1].
# NOTE(review): checkpoints trained before this fix saw un-normalised data.
imageTransforms = tt.Compose([
    tt.Resize(imageSize),       # shorter side -> imageSize
    tt.CenterCrop(imageSize),   # square imageSize x imageSize crop
    tt.ToTensor(),              # PIL image -> float tensor in [0, 1]
    tt.Normalize(*stats),       # (x - mean) / std per channel
])


def get_default_device():
    """Return the CUDA device when CUDA is available, otherwise the CPU device."""
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(device_name)

# Device on which latents and target tensors are created.
device = get_default_device()

# One fixed batch of latent vectors, reused for every saved sample grid so
# that grids from different epochs are drawn from the same inputs.
fixedLatent = torch.randn(
    (imageRows * imageCols, latentSize, 1, 1),
    device=device,
)
print(fixedLatent.shape)

# Training images are read from DATA_DIR and preprocessed by imageTransforms.
trainDataset = ImageFolder(DATA_DIR, transform=imageTransforms)
trainDataLoader = DataLoader(
    trainDataset,
    batch_size=batchSize,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)
# Checkpoint number to resume from; set to 0 to train the model from scratch.
modelVersion = 112
if modelVersion == 0:
    # Generator: latent vector -> 3 x 128 x 128 image with values in [-1, 1].
    generator = nn.Sequential(
        # input is latentSize x 1 x 1
        nn.ConvTranspose2d(latentSize, featureSize * 16, 4, 1, 0, bias=False),
        nn.BatchNorm2d(featureSize * 16),
        nn.ReLU(True),
        # state size. (featureSize*16) x 4 x 4
        nn.ConvTranspose2d(featureSize * 16, featureSize * 8, 4, 2, 1, bias=False),
        nn.BatchNorm2d(featureSize * 8),
        nn.ReLU(True),
        # state size. (featureSize*8) x 8 x 8
        nn.ConvTranspose2d(featureSize * 8, featureSize * 4, 4, 2, 1, bias=False),
        nn.BatchNorm2d(featureSize * 4),
        nn.ReLU(True),
        # state size. (featureSize*4) x 16 x 16
        nn.ConvTranspose2d(featureSize * 4, featureSize * 2, 4, 2, 1, bias=False),
        nn.BatchNorm2d(featureSize * 2),
        nn.ReLU(True),
        # state size. (featureSize*2) x 32 x 32
        nn.ConvTranspose2d(featureSize * 2, featureSize, 4, 2, 1, bias=False),
        nn.BatchNorm2d(featureSize),
        nn.ReLU(True),
        # state size. (featureSize) x 64 x 64
        nn.ConvTranspose2d(featureSize, 3, 4, 2, 1, bias=False),
        nn.Tanh(),
        # state size. 3 x 128 x 128
    )

    # Discriminator: 3 x 128 x 128 image -> one probability per image.
    discriminator = nn.Sequential(
        # input is 3 x 128 x 128
        nn.Conv2d(3, featureSize, 4, 2, 1, bias=False),
        nn.LeakyReLU(0.2, inplace=True),
        # state size. (featureSize) x 64 x 64
        nn.Conv2d(featureSize, featureSize * 2, 4, 2, 1, bias=False),
        nn.BatchNorm2d(featureSize * 2),
        nn.LeakyReLU(0.2, inplace=True),
        # state size. (featureSize*2) x 32 x 32
        nn.Conv2d(featureSize * 2, featureSize * 4, 4, 2, 1, bias=False),
        nn.BatchNorm2d(featureSize * 4),
        nn.LeakyReLU(0.2, inplace=True),
        # state size. (featureSize*4) x 16 x 16
        nn.Conv2d(featureSize * 4, featureSize * 8, 4, 2, 1, bias=False),
        nn.BatchNorm2d(featureSize * 8),
        nn.LeakyReLU(0.2, inplace=True),
        # state size. (featureSize*8) x 8 x 8
        nn.Conv2d(featureSize * 8, featureSize * 16, 4, 2, 1, bias=False),
        nn.BatchNorm2d(featureSize * 16),
        nn.LeakyReLU(0.2, inplace=True),
        # state size. (featureSize*16) x 4 x 4
        nn.Conv2d(featureSize * 16, 1, 4, 1, 0, bias=False),
        nn.Flatten(),
        nn.Sigmoid(),
    )
else:
    # Resume from whole-module checkpoints written by fit().
    # map_location lets a checkpoint saved on GPU load on a CPU-only runtime.
    # NOTE(review): torch.load unpickles arbitrary objects -- only load
    # checkpoints you trust. PyTorch releases where `weights_only` defaults to
    # True need `weights_only=False` here to load a full pickled module.
    generator = torch.load(
        sample_dir + '/models/' + str(modelVersion) + 'generator.pth',
        map_location=device,
    )
    discriminator = torch.load(
        sample_dir + '/models/' + str(modelVersion) + 'discriminator.pth',
        map_location=device,
    )

# Latents and targets are created on `device`, so both networks must live
# there too (a no-op when they already do).
generator = generator.to(device)
discriminator = discriminator.to(device)


def denorm(img_tensors):
    """Rescale image tensors: multiply by stats[1][0], then add stats[0][0]."""
    offset = stats[0][0]
    scale = stats[1][0]
    return img_tensors * scale + offset

def saveSamples(index, latent_tensors, imageRows, imageCols, show=True):
    """Generate images from latent vectors and save them as one PNG grid.

    Args:
        index: Integer used in the output file name
            (``generated-images-<index, zero-padded to 4 digits>.png``).
        latent_tensors: Latent batch passed to the module-level ``generator``.
        imageRows: Number of grid rows.
        imageCols: Number of grid columns (images per row).
        show: When true, also draw the grid with matplotlib.

    The file is written into ``sample_dir``.
    """
    # Sampling only: no autograd graph is needed.
    with torch.no_grad():
        fake_images = generator(latent_tensors)
    print(fake_images.size())

    # Map generator output back to displayable pixel values once, and use the
    # same tensor for both the saved file and the on-screen preview.
    display_images = denorm(fake_images)

    os.makedirs(sample_dir, exist_ok=True)
    fake_fname = sample_dir + '/' + 'generated-images-{0:0=4d}.png'.format(index)
    # `nrow` is the number of images per grid row, i.e. the column count.
    save_image(display_images, fake_fname, nrow=imageCols)
    print('Saving', fake_fname)

    if show:
        # figsize is (width, height): one unit per column / row.
        fig, ax = plt.subplots(figsize=(imageCols, imageRows))
        ax.set_xticks([])
        ax.set_yticks([])
        grid = make_grid(display_images.cpu().detach(), nrow=imageCols)
        # make_grid returns C x H x W; imshow expects H x W x C.
        ax.imshow(grid.permute(1, 2, 0))

def train_discriminator(real_images, opt_d):
    """Run one discriminator update on a real batch plus a fresh fake batch.

    Args:
        real_images: Batch of real images from the data loader.
        opt_d: Optimizer over the discriminator's parameters.

    Returns:
        Tuple ``(loss, real_score, fake_score)`` of Python floats: the summed
        BCE loss, and the mean discriminator output on the real and on the
        fake batch.
    """
    # Clear discriminator gradients
    opt_d.zero_grad()

    # The loader yields CPU tensors; the targets below are created on
    # `device`, so the images must be moved there as well.
    real_images = real_images.to(device, non_blocking=True)

    # Real images should be classified as 1
    real_preds = discriminator(real_images)
    real_targets = torch.ones(real_images.size(0), 1, device=device)
    real_loss = F.binary_cross_entropy(real_preds, real_targets)
    real_score = torch.mean(real_preds).item()

    # Generate fake images. They are detached because this step only updates
    # the discriminator; backpropagating into the generator would be wasted work.
    latent = torch.randn(batchSize, latentSize, 1, 1, device=device)
    fake_images = generator(latent).detach()

    # Fake images should be classified as 0
    fake_targets = torch.zeros(fake_images.size(0), 1, device=device)
    fake_preds = discriminator(fake_images)
    fake_loss = F.binary_cross_entropy(fake_preds, fake_targets)
    fake_score = torch.mean(fake_preds).item()

    # Update discriminator weights
    loss = real_loss + fake_loss
    loss.backward()
    opt_d.step()
    return loss.item(), real_score, fake_score

def train_generator(opt_g):
    """Run one generator update and return its BCE loss as a Python float.

    A fresh latent batch is pushed through the generator and then the
    discriminator; the loss rewards discriminator outputs close to 1.

    Args:
        opt_g: Optimizer over the generator's parameters.
    """
    opt_g.zero_grad()

    # Sample noise and score the resulting fakes with the discriminator.
    noise = torch.randn(batchSize, latentSize, 1, 1, device=device)
    verdicts = discriminator(generator(noise))

    # Target is all ones: the generator wants its fakes judged as real.
    wanted = torch.ones(batchSize, 1, device=device)
    generator_loss = F.binary_cross_entropy(verdicts, wanted)

    generator_loss.backward()
    opt_g.step()

    return generator_loss.item()


def fit(epochs, lr, start_idx=1):
    """Train the GAN from epoch ``modelVersion + 1`` through ``epochs``.

    After every epoch a sample grid is written from ``fixedLatent`` and both
    networks are saved under ``<sample_dir>/models/``.

    Args:
        epochs: Number of the last epoch to run (inclusive).
        lr: Learning rate for both Adam optimizers.
        start_idx: Offset added to the epoch number to form the sample-image
            index. With the default of 1 the sample file for epoch ``e`` is
            numbered ``e + 1``, while the checkpoint is numbered ``e``.

    Returns:
        Tuple ``(losses_g, losses_d, real_scores, fake_scores)`` of lists, one
        entry per epoch, each taken from the last batch of that epoch.
    """
    torch.cuda.empty_cache()

    # Losses & scores
    losses_g, losses_d = [], []
    real_scores, fake_scores = [], []

    # Create optimizers
    opt_d = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
    opt_g = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))

    # torch.save does not create missing directories.
    os.makedirs(sample_dir + '/models', exist_ok=True)

    for epoch in range(modelVersion + 1, epochs + 1):
        for real_images, _ in tqdm(trainDataLoader):
            # Train discriminator
            loss_d, real_score, fake_score = train_discriminator(real_images, opt_d)
            # Train generator
            loss_g = train_generator(opt_g)

        # Record losses & scores (values from the last batch of the epoch)
        losses_g.append(loss_g)
        losses_d.append(loss_d)
        real_scores.append(real_score)
        fake_scores.append(fake_score)

        # `epoch` is already 1-based here, so it is logged as-is.
        print("Epoch [{}/{}], loss_g: {:.4f}, loss_d: {:.4f}, real_score: {:.4f}, fake_score: {:.4f}".format(
            epoch, epochs, loss_g, loss_d, real_score, fake_score))

        # Save generated images
        saveSamples(epoch + start_idx, fixedLatent, imageRows, imageCols, show=False)

        # Save whole-module checkpoints (torch.save returns nothing to keep).
        torch.save(generator, sample_dir + '/models/' + str(epoch) + 'generator.pth')
        torch.save(discriminator, sample_dir + '/models/' + str(epoch) + 'discriminator.pth')

    return losses_g, losses_d, real_scores, fake_scores



if __name__ == '__main__':
    # Learning rate and final epoch number handed to fit().
    lr = 2e-4
    epochs = 200
    # Keep the full result tuple and its four per-epoch lists.
    history = (losses_g, losses_d, real_scores, fake_scores) = fit(epochs, lr)

torch.Size([64, 256, 1, 1])


  0%|          | 0/14 [00:00<?, ?it/s]

KeyboardInterrupt: ignored

# New Section