In [4]:
import torch
from torch import nn
import torchvision.transforms as T
import random
import matplotlib.pyplot as plt
import os
from PIL import Image
import torch.utils.data as dutils
from tqdm import tqdm
import torchvision.utils as vutils
import time

## Defining the dataset and data loaders

In [5]:
class Dataset(dutils.Dataset):
    """Image dataset read from ``<root>/train`` or ``<root>/test``.

    Files whose names contain ``'cut'`` or ``'hole'`` are skipped (presumably
    pre-masked variants of the originals -- TODO confirm). File names are
    sorted so ``dataset[i]`` refers to the same file on every run regardless
    of the order ``os.listdir`` happens to return.

    Args:
        transform: optional callable applied to each RGB ``PIL.Image``
            (for example a ``torchvision.transforms.Compose``).
        Test: if True read the ``test`` folder, otherwise ``train``.
        root: directory that contains the ``train`` and ``test`` folders.
    """

    def __init__(self, transform=None, Test=False, root='images'):
        self.transform = transform
        self.folder_path = os.path.join(root, 'test' if Test else 'train')
        self.images = sorted(
            name
            for name in os.listdir(self.folder_path)
            if 'cut' not in name and 'hole' not in name
        )

    def __len__(self):
        """Number of images kept after filtering."""
        return len(self.images)

    def __getitem__(self, index):
        """Load image ``index`` as RGB and apply ``transform`` if one was given."""
        path = os.path.join(self.folder_path, self.images[index])
        with Image.open(path) as img:
            # convert() loads the pixels into a new image that does not depend
            # on the open file, so the result stays valid after the ``with``
            # closes it; it also guarantees 3 channels, which the models'
            # first convolutions are built for.
            img = img.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img

In [6]:
# Training-time augmentation: random mirroring along both axes, then
# conversion of the PIL image to a float tensor.
transform = T.Compose(
    [
        T.RandomHorizontalFlip(),  # random left-right mirror
        T.RandomVerticalFlip(),  # random top-bottom mirror
        T.ToTensor(),  # PIL image -> float tensor
    ]
)


In [7]:
# The training split gets the random-flip augmentation; the test split is read
# deterministically (no flips, fixed order) so evaluation results are repeatable.
train_ds = Dataset(transform=transform)
test_ds = Dataset(transform=T.ToTensor(), Test=True)
# batch_size must stay >= 16: the training cell takes its 16 fixed preview
# images from one training batch and pairs them with 16 fixed noise vectors.
train = dutils.DataLoader(train_ds, batch_size=16, shuffle=True)
test = dutils.DataLoader(test_ds, batch_size=16, shuffle=False)

In [8]:
class Generator(nn.Module):
    """Conditional encoder-decoder that produces an image from an image plus noise.

    ``forward(x, z)`` projects the noise ``z`` (shape ``(B, z_dim)``) with a
    linear layer to one ``img_size x img_size`` plane and concatenates it in
    front of the image ``x`` (shape ``(B, in_channels, img_size, img_size)``) as
    an extra channel. The result goes through ``depth`` encoder stages
    (conv -> ReLU -> 2x2 max-pool, ``features * 2**stage`` channels), then
    ``depth`` stride-2 transposed convolutions back to full resolution. A final
    sigmoid keeps every output value in ``[0, 1]``.

    The defaults reproduce the original fixed architecture (256-d noise,
    256x256 RGB images, 32 base channels, 6 stages) with the same layer order,
    so ``state_dict`` keys are unchanged.

    Args:
        z_dim: length of the noise vector.
        img_size: height and width of the input image; must be divisible by
            ``2 ** depth``.
        features: channel count of the first encoder stage (doubled per stage).
        depth: number of down-/up-sampling stages (>= 1).
        in_channels: channels of the input image and of the output.

    Raises:
        ValueError: if ``depth < 1`` or ``img_size`` is not divisible by
            ``2 ** depth``.
    """

    def __init__(self, z_dim=256, img_size=256, features=32, depth=6, in_channels=3):
        super().__init__()
        if depth < 1:
            raise ValueError(f'depth must be >= 1, got {depth}')
        if img_size % 2 ** depth:
            raise ValueError(
                f'img_size={img_size} must be divisible by 2**depth={2 ** depth}'
            )

        self.project = nn.Linear(z_dim, img_size * img_size)

        layers = []
        channels = in_channels + 1  # image channels + one projected-noise plane
        for stage in range(depth):
            out_channels = features * 2 ** stage
            # The first two stages use 5x5 kernels, deeper ones 3x3; padding
            # k // 2 keeps the size unchanged so only the pool halves it.
            kernel = 5 if stage < 2 else 3
            layers += [
                nn.Conv2d(channels, out_channels, kernel, 1, kernel // 2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2, 2),
            ]
            channels = out_channels

        # Mirror the encoder: each 4x4 / stride-2 / pad-1 transposed conv
        # doubles the spatial size while the channel count steps back down.
        for stage in reversed(range(depth - 1)):
            out_channels = features * 2 ** stage
            layers += [
                nn.ConvTranspose2d(channels, out_channels, 4, 2, 1),
                nn.ReLU(inplace=True),
            ]
            channels = out_channels
        layers += [
            nn.ConvTranspose2d(channels, in_channels, 4, 2, 1),
            nn.Sigmoid(),
        ]
        self.stack = nn.Sequential(*layers)

    def forward(self, x, z):
        """Return the generated image, with the same shape as ``x``.

        The projected noise is reshaped to ``x``'s own height and width, so
        ``x`` must be ``img_size x img_size`` (the projection has exactly
        ``img_size**2`` outputs).
        """
        b_size, _, height, width = x.shape
        cond = self.project(z).view(b_size, 1, height, width)
        return self.stack(torch.cat((cond, x), dim=1))

In [9]:
# Smoke test: one forward pass on random data to confirm the generator maps
# (16, 3, 256, 256) images + (16, 256) noise back to (16, 3, 256, 256).
model = Generator()

x = torch.randn(16, 3, 256, 256)
z = torch.randn(16, 256)

# Only the output shape matters here, so skip building the autograd graph
# (it would otherwise keep every intermediate activation alive).
with torch.no_grad():
    images = model(x, z)

assert images.shape == x.shape, images.shape
print(images.shape)

torch.Size([16, 3, 256, 256])


In [10]:
class Discriminator(nn.Module):
    """Convolutional discriminator that scores images as real (1) or fake (0).

    ``depth`` stride-2 4x4 convolutions with LeakyReLU(0.2) each halve the
    spatial size while the channel count doubles from ``features``. A final
    4x4 unpadded convolution and a sigmoid produce the score. The output is
    ``(B, 1, 1, 1)`` only when the input side equals ``4 * 2 ** depth``
    (256 for the defaults).

    The defaults reproduce the original fixed architecture with the same layer
    order, so ``state_dict`` keys are unchanged.

    Args:
        in_channels: channels of the input image.
        features: output channels of the first convolution.
        depth: number of stride-2 downsampling stages.
    """

    def __init__(self, in_channels=3, features=32, depth=6):
        super().__init__()
        layers = []
        channels = in_channels
        for stage in range(depth):
            out_channels = features * 2 ** stage
            layers += [
                nn.Conv2d(channels, out_channels, 4, 2, 1),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            channels = out_channels
        # Collapse the remaining 4x4 map to one value per image.
        layers += [nn.Conv2d(channels, 1, 4, 1, 0), nn.Sigmoid()]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Return sigmoid scores; shape ``(B, 1, 1, 1)`` for a matching input size."""
        return self.main(input)

In [11]:
# Pick the compute backend: CUDA first, then Apple's MPS, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

In [12]:
# Hyper-parameters and sizes used by the training cell.
SEED = 42  # fixed so weight init, shuffling and flips are reproducible
BATCH_SIZE = 8  # NOTE(review): unused -- the DataLoaders hard-code batch_size=16
Z_DIM = 256  # noise length; must match the Generator's noise input (256)
EPOCHS = 300
BETA_1 = 0.5  # Adam beta1 for both optimisers
BETA_2 = 0.999  # Adam beta2 for both optimisers
LEARNING_RATE = 1e-4  # shared by both Adam optimisers
FEATURES_D = 32  # NOTE(review): not passed to Discriminator() in the training cell
FEATURES_G = 32  # NOTE(review): not passed to Generator() in the training cell
IMG_SIZE = 256  # side of the square images the models are built for
PATCH_SIZE = 32  # side of the square hole the generator has to fill

# Seed every RNG this notebook draws from before models are built.
random.seed(SEED)
torch.manual_seed(SEED)


In [13]:
# Report the CUDA version string exposed by this PyTorch build.
print(str(torch.version.cuda))

12.1


In [14]:
print(f'Started training using device: {device}')

generator = Generator().to(device)
discriminator = Discriminator().to(device)

d_opt = torch.optim.Adam(discriminator.parameters(), lr=LEARNING_RATE, betas=(BETA_1, BETA_2))
g_opt = torch.optim.Adam(generator.parameters(), lr=LEARNING_RATE, betas=(BETA_1, BETA_2))

loss_fn = nn.BCELoss()

# 16 fixed images + noise vectors, re-used every epoch to visualise progress.
fixed_noise = torch.randn(16, Z_DIM, device=device)
fixed_images = next(iter(train))[:16].to(device)

# Top-left corner of the square hole, centred in the image. Rows and columns
# share the same start coordinate; using `x_offset + 32` as the column start
# would shift the hole PATCH_SIZE pixels to the right of centre.
fixed_x_offset = max(0, (IMG_SIZE - PATCH_SIZE) // 2)  # first row of the hole (dim 2)
fixed_y_offset = fixed_x_offset  # first column of the hole (dim 3)

# 1 inside the hole, 0 elsewhere; broadcasts over the batch and channel dims.
patch_mask = torch.zeros(1, 1, IMG_SIZE, IMG_SIZE, device=device)
patch_mask[:, :, fixed_x_offset:fixed_x_offset + PATCH_SIZE,
           fixed_y_offset:fixed_y_offset + PATCH_SIZE] = 1


def _mask_patch(images, mask):
    """Return a new tensor equal to ``images`` with the ``mask == 1`` region set to 0."""
    return images * (1 - mask)


def _fill_patch(masked, generated, mask):
    """Paste the ``mask == 1`` region of ``generated`` into ``masked``.

    ``masked`` must already be zero inside the hole (see ``_mask_patch``).
    The result is differentiable w.r.t. ``generated`` inside the hole only.
    """
    return masked + generated * mask


# Create the output folders up front so the first save cannot fail.
os.makedirs('progress', exist_ok=True)
os.makedirs('models', exist_ok=True)

start = time.time()
for epoch in range(EPOCHS):
    for image_batch in tqdm(train):
        image_batch = image_batch.to(device)
        b_size = image_batch.shape[0]

        # Discriminator on real images: target 1.
        discriminator.zero_grad()
        y_hat_real = discriminator(image_batch).view(-1)
        real_loss = loss_fn(y_hat_real, torch.ones_like(y_hat_real))
        real_loss.backward()

        # Black out the hole, let the generator predict, paste the prediction back.
        masked_batch = _mask_patch(image_batch, patch_mask)
        noise = torch.randn(b_size, Z_DIM, device=device)
        predicted = generator(masked_batch, noise)
        fake_batch = _fill_patch(masked_batch, predicted, patch_mask)

        # Discriminator on in-painted images: target 0. detach() keeps this
        # loss from producing generator gradients.
        y_hat_fake = discriminator(fake_batch.detach()).view(-1)
        fake_loss = loss_fn(y_hat_fake, torch.zeros_like(y_hat_fake))
        fake_loss.backward()
        d_opt.step()

        # Generator: make the updated discriminator score the fakes as real.
        # (Gradients this also leaves on the discriminator are cleared by its
        # zero_grad() at the start of the next iteration.)
        generator.zero_grad()
        y_hat_fake = discriminator(fake_batch).view(-1)
        g_loss = loss_fn(y_hat_fake, torch.ones_like(y_hat_fake))
        g_loss.backward()
        g_opt.step()

    # In-paint the same fixed images each epoch without modifying them; only
    # the hole region of the generator output is pasted in.
    with torch.no_grad():
        fixed_masked = _mask_patch(fixed_images, patch_mask)
        samples = _fill_patch(fixed_masked, generator(fixed_masked, fixed_noise), patch_mask)
    img = T.ToPILImage()(vutils.make_grid(samples.to('cpu'), normalize=True, padding=2, nrow=4))
    img.save(f'progress/epoch_{epoch}.jpg')
train_time = time.time() - start
print(f'Total training time: {train_time // 60} minutes')

generator = generator.to('cpu')

# NOTE(review): this pickles the whole module, so loading it requires the
# Generator class to be importable; saving generator.state_dict() is more portable.
torch.save(generator, 'models/patch_generator.pkl')

Started training using device: cuda


  8%|▊         | 70/845 [00:15<02:55,  4.42it/s]


KeyboardInterrupt: 