In [1]:
# # This Python 3 environment comes with many helpful analytics libraries installed
# # It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# # For example, here's several helpful packages to load

# import numpy as np # linear algebra
# import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# # Input data files are available in the read-only "../input/" directory
# # For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

# import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# # You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# # You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import Dataset
from torchvision.datasets import ImageFolder
import pytorch_lightning as pl

class MyDataModule(pl.LightningDataModule):
    def __init__(self, root_dir, batch_size, image_size):
        super().__init__()
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.image_size = image_size

    def setup(self, stage=None):
        # transforms for images
        transform=pl.transforms.Compose([
            pl.transforms.Resize(self.image_size),
            pl.transforms.CenterCrop(self.image_size),
            pl.transforms.ToTensor(),
            pl.transforms.Normalize((0,0,0), (1,1,1))
        ])

        # load the dataset
        self.dataset = ImageFolder(root=self.root_dir, transform=transform)

    def train_dataloader(self):
        return pl.DataLoader(self.dataset, batch_size=self.batch_size,
                              shuffle=True, num_workers=2)

    def val_dataloader(self):
        return pl.DataLoader(self.dataset, batch_size=self.batch_size,
                              shuffle=False, num_workers=2)

    def test_dataloader(self):
        return pl.DataLoader(self.dataset, batch_size=self.batch_size,
                              shuffle=False, num_workers=2)



ModuleNotFoundError: No module named 'pytorch_lightning.metrics'

In [None]:
dm = MyDataModule(root_dir='/kaggle/input/intel-image-classification/seg_train/seg_train/',
                  batch_size=batch_size, image_size=image_size)

In [None]:
class Generator(nn.Module):
    def __init__(self, nz, ngf, nc):
        super().__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, input):
        return self.main(input)

class Discriminator(nn.Module):
    def __init__(self, ndf, nc):
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )
    def forward(self, input):
        return self.main(input)

In [None]:
class GAN(pl.LightningModule):
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        self.generator = Generator(self.hparams.nz, self.hparams.ngf, self.hparams.nc)
        self.discriminator = Discriminator(self.hparams.ndf, self.hparams.nc)
        self.example_input_array = torch.zeros(2, self.hparams.nz, 1, 1)
        self.criterion = nn.BCELoss()

    def forward(self, z):
        return self.generator(z)

    def adversarial_loss(self, y_hat, y):
        return self.criterion(y_hat, y)

    def training_step(self, batch, batch_idx, optimizer_idx):
        real, _ = batch
        real = real.to(self.device)
        batch_size = real.size(0)
        valid = torch.ones(batch_size, 1, device=self.device)
        fake = torch.zeros(batch_size, 1, device=self.device)

        if optimizer_idx == 0:
            z = torch.randn(batch_size, self.hparams.nz, 1, 1, device=self.device)
            generated_imgs = self(z)
            g_loss = self.adversarial_loss(self.discriminator(generated_imgs), valid)
            self.log('g_loss', g_loss)
            return g_loss

        if optimizer_idx == 1:
            z = torch.randn(batch_size, self.hparams.nz, 1, 1, device=self.device)
            generated_imgs = self(z)
            real_loss = self.adversarial_loss(self.discriminator(real), valid)
            fake_loss = self.adversarial_loss(self.discriminator(generated_imgs.detach()), fake)
            d_loss = (real_loss + fake_loss) / 2
            self.log('d_loss', d_loss)
            return d_loss

    def configure_optimizers(self):
        opt_g = optim.Adam(self.generator.parameters(), lr=self.hparams.lr, betas=(0.5, 0.999))
        opt_d = optim.Adam(self.discriminator.parameters(), lr=self.hparams.lr, betas=(0.5, 0.999))
        return [opt_g, opt_d], []

    def on_epoch_end(self):
        z = torch.randn(1, self.hparams.nz, 1, 1, device=self.device)
        sample_image = self(z)
        self.logger.experiment.add_image('generated_image', sample_image[0], self.current_epoch)

In [None]:
import matplotlib.pyplot as plt
from PIL import Image

# Set the hyperparameters
hparams = argparse.Namespace()
hparams.root_dir = 'cifar10'
hparams.batch_size = 64
hparams.image_size = 64
hparams.nz = 100
hparams.ngf = 64
hparams.ndf = 64
hparams.nc = 3
hparams.lr = 0.0002
hparams.beta1 = 0.5
hparams.beta2 = 0.999
hparams.n_epochs = 10

# Instantiate the data module
dm = MyDataModule(hparams.root_dir, hparams.batch_size, hparams.image_size)

# Instantiate the model
model = GAN(hparams)

# Instantiate a checkpoint callback to save the best model
checkpoint_callback = ModelCheckpoint(monitor='d_loss', mode='min')

# Train the model
trainer = pl.Trainer(gpus=1, max_epochs=hparams.n_epochs, checkpoint_callback=checkpoint_callback)
trainer.fit(model, dm)

# Generate some sample images
noise = torch.randn(64, hparams.nz, 1, 1)
fake_images = model.generator(noise).detach().cpu()
grid = vutils.make_grid(fake_images, nrow=8, padding=2, normalize=True)

# Save the generated images to a file
vutils.save_image(grid, 'generated_images.png')

# Load the generated image
img = Image.open('generated_images.png')

# Display the image using matplotlib
plt.imshow(img)
plt.axis('off')
plt.show()