In [24]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision.models as models
import torchvision.datasets as datasets
import numpy as np
from PIL import Image

In [25]:
class Generator(nn.Module):
    """Fully connected generator mapping latent vectors to flat samples.

    The network is a stack of four linear layers
    (``latent_dim -> 256 -> 512 -> 1024 -> n * m``) with ReLU activations
    between them and a final ``Tanh``, so every output value lies in
    ``[-1, 1]``.

    Args:
        n: First factor of the output size.
        m: Second factor of the output size. The network emits ``n * m``
            values per latent vector; reshaping that flat vector is left to
            the caller.
        latent_dim: Size of the last dimension of the input noise. Defaults
            to 100, the value that used to be hard-coded.
    """

    def __init__(self, n, m, latent_dim=100):
        super().__init__()
        # Kept as an attribute so callers can sample noise of the right width.
        self.latent_dim = latent_dim
        # Sub-module names are unchanged so existing state_dict keys still load.
        self.l1 = nn.Sequential(nn.Linear(in_features=latent_dim, out_features=256), nn.ReLU())
        self.l2 = nn.Sequential(nn.Linear(in_features=256, out_features=512), nn.ReLU())
        self.l3 = nn.Sequential(nn.Linear(in_features=512, out_features=1024), nn.ReLU())
        self.output = nn.Sequential(nn.Linear(in_features=1024, out_features=n * m), nn.Tanh())

    def forward(self, x):
        """Map noise of shape ``(..., latent_dim)`` to samples of shape ``(..., n * m)``."""
        x = self.l1(x)
        x = self.l2(x)
        x = self.l3(x)
        return self.output(x)

In [34]:
class GAN():
    """Small GAN: a fully connected generator against a ResNet-50 discriminator.

    Images are RGB with height ``n`` and width ``m``. The generator emits a
    flat vector of ``3 * n * m`` values in ``[-1, 1]`` that is reshaped to
    ``(batch, 3, n, m)`` before being shown to the discriminator.

    Label convention (kept from the original code): the discriminator's
    target is 1 for generated images and 0 for real ones. The generator is
    therefore trained towards target 0 on its own samples.

    Args:
        n: Image height in pixels.
        m: Image width in pixels.
        device: Device the networks and batches live on. Defaults to
            ``'cpu'``.
    """

    # Width of the noise vectors; matches the Generator's default input size.
    LATENT_DIM = 100
    # Number of fixed noise vectors rendered after every epoch.
    EVOLUTION_SAMPLES = 8

    def __init__(self, n, m, device='cpu'):
        self.n = n
        self.m = m
        self.device = torch.device(device)

        # Generator(a, b) emits a * b values, so (3 * n, m) gives the
        # 3 * n * m values needed for one RGB image.
        self.generator = Generator(3 * n, m).to(self.device)

        # NOTE(review): newer torchvision releases prefer the `weights=`
        # argument over `pretrained=True`; confirm against the installed version.
        self.discriminator = models.resnet50(pretrained=True)
        # Read the head's input width from the model instead of hard-coding
        # it, and emit a single logit for the binary real/fake decision.
        self.discriminator.fc = nn.Linear(in_features=self.discriminator.fc.in_features, out_features=1)
        self.discriminator = self.discriminator.to(self.device)

        # BCEWithLogitsLoss applies the sigmoid itself, so raw logits are valid input.
        self.discr_criterion = torch.nn.BCEWithLogitsLoss()
        # Only the freshly created head is handed to this optimizer.
        self.discr_optimizer = torch.optim.SGD(self.discriminator.fc.parameters(), lr=0.001, momentum=0.9)
        self.gen_criterion = torch.nn.BCEWithLogitsLoss()
        # The generator optimizer must update the generator's own weights.
        self.gen_optimizer = torch.optim.SGD(self.generator.parameters(), lr=0.001, momentum=0.9)

        # Fixed noise, rendered after each epoch into `evolution`.
        self.test_noises = torch.randn(self.EVOLUTION_SAMPLES, self.LATENT_DIM, device=self.device)
        self.evolution = []

    def _generate_images(self, z):
        """Run the generator on ``z`` and reshape the output to ``(batch, 3, n, m)``."""
        return self.generator(z).reshape(z.shape[0], 3, self.n, self.m)

    def learning_step(self, x):
        """Run one discriminator update followed by one generator update.

        Args:
            x: Batch of real images, shape ``(batch, 3, n, m)``.

        Returns:
            Tuple ``(gen_loss, full_loss)`` of detached scalar tensors: the
            generator loss and the summed discriminator loss (fake + real).
        """
        x = x.to(self.device)
        batch_size = x.shape[0]
        # Targets have shape (batch, 1) to match the single-logit head.
        fake_targets = torch.ones(batch_size, 1, device=self.device)
        real_targets = torch.zeros(batch_size, 1, device=self.device)

        z = torch.randn(batch_size, self.LATENT_DIM, device=self.device)
        fakes = self._generate_images(z)

        # Discriminator update. zero_grad also clears the gradients that the
        # previous generator step left on the discriminator's parameters.
        self.discriminator.zero_grad()
        # detach(): this update must not back-propagate into the generator.
        fake_loss = self.discr_criterion(self.discriminator(fakes.detach()), fake_targets)
        real_loss = self.discr_criterion(self.discriminator(x), real_targets)
        full_loss = fake_loss + real_loss
        full_loss.backward()
        self.discr_optimizer.step()

        # Generator update. The fakes are scored again so the graph is fresh
        # and reflects the discriminator weights that were just updated.
        self.generator.zero_grad()
        gen_loss = self.gen_criterion(self.discriminator(fakes), real_targets)
        gen_loss.backward()
        self.gen_optimizer.step()

        return gen_loss.detach(), full_loss.detach()

    def fit(self, file_path, epochs_number):
        """Train both networks on the images found under ``file_path``.

        Args:
            file_path: Root directory passed to ``datasets.ImageFolder``.
            epochs_number: Number of passes over the dataset.

        Returns:
            Dict with keys ``'generator'`` and ``'discriminator'``, each a
            list holding one loss value per training batch.
        """
        # Crop to the generator's image size so real and fake batches match.
        # Cropping runs before ToTensor, i.e. on the loaded image.
        trans = transforms.Compose([transforms.RandomResizedCrop((self.n, self.m)),
                                    transforms.ToTensor(),
                                    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

        trans_data = datasets.ImageFolder(file_path, trans)
        dataloader = DataLoader(trans_data, batch_size=128, shuffle=True, num_workers=4)
        history = {'generator': [], 'discriminator': []}
        for epoch in range(epochs_number):
            losses = {'generator': [], 'discriminator': []}
            for x, _ in dataloader:
                gen_loss, full_loss = self.learning_step(x)
                losses['generator'].append(float(gen_loss))
                losses['discriminator'].append(float(full_loss))
                history['generator'].append(float(gen_loss))
                history['discriminator'].append(float(full_loss))
            # Snapshot once per epoch, without autograd, to bound memory use.
            with torch.no_grad():
                self.evolution.append(self._generate_images(self.test_noises).cpu())
            avg_gen_loss = np.mean(losses['generator'])
            avg_discr_loss = np.mean(losses['discriminator'])
            print("Epoch {} / {}: Generator Loss = {:.3f}, Discriminator Loss = {:.3f}".format(epoch+1, epochs_number, avg_gen_loss, avg_discr_loss))
        return history

    def predict(self, n):
        """Generate ``n`` images and save them as ``image_1.png`` ... ``image_<n>.png``.

        Files are written to the current working directory.
        """
        z = torch.randn(n, self.LATENT_DIM, device=self.device)
        with torch.no_grad():
            fakes = self._generate_images(z)
        # Tanh output lies in [-1, 1]; rescale to 8-bit [0, 255].
        pixels = ((fakes + 1) * 127.5).round().clamp(0, 255).to(torch.uint8)
        # PIL expects channels-last (H, W, 3) arrays.
        pixels = pixels.permute(0, 2, 3, 1).contiguous().cpu().numpy()
        for i, image in enumerate(pixels):
            Image.fromarray(image).save(f'image_{i+1}.png')


In [35]:
# Build the GAN with both size arguments set to 100.
model = GAN(n=100, m=100)

In [None]:
# Train for 200 epochs on the images under ./dataset. The returned per-batch
# loss history is bound to a name so it can be inspected or plotted later,
# and so the notebook does not echo the whole dict as this cell's output.
history = model.fit('dataset', 200)

In [32]:
# Generate ten samples; predict saves them as image_1.png ... image_10.png.
model.predict(n=10)