In [14]:
import torch
from torch import nn
import torchvision.transforms as T
import random
import matplotlib.pyplot as plt
import os
from PIL import Image
import torch.utils.data as dutils
from tqdm import tqdm
import torchvision.utils as vutils
import time

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Shortcut connection
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += self.shortcut(residual)
        out = self.relu(out)

        return out

class DownBlock(nn.Module):
    """Encoder stage: 3x3 stride-2 conv -> BatchNorm -> ReLU.

    Doubles the channel count and halves H and W (rounding up for odd sizes).

    Args:
        FEATURES_G: number of input channels; the block outputs ``2 * FEATURES_G``.
    """

    def __init__(self, FEATURES_G):
        super(DownBlock, self).__init__()
        doubled = 2 * FEATURES_G
        self.main = nn.Sequential(
            nn.Conv2d(FEATURES_G, doubled, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(doubled),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.main(x)
    

class UpBlock(nn.Module):
    """Decoder stage: 3x3 stride-2 transposed conv -> BatchNorm -> ReLU.

    Halves the channel count and exactly doubles H and W, mirroring DownBlock.

    Args:
        FEATURES_G: number of input channels; the block outputs ``FEATURES_G // 2``.
    """
    def __init__(self, FEATURES_G):
        super(UpBlock, self).__init__()

        self.main = nn.Sequential(
            # With kernel 3, stride 2, padding 1 a transposed conv yields 2*H - 1;
            # output_padding=1 makes it exactly 2*H so the decoder undoes the encoder.
            nn.ConvTranspose2d(FEATURES_G, FEATURES_G // 2, 3, 2, 1, output_padding=1),
            nn.BatchNorm2d(FEATURES_G // 2),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.main(x)

class ResNetGenerator(nn.Module):
    """Inpainting generator: stem -> 5x DownBlock -> residual bottleneck -> 5x UpBlock -> output conv.

    The noise vector is projected to one extra image-sized channel and concatenated
    with the 3-channel input, so the stem conv sees 4 channels. The output has the
    same height and width as the input, with values in [0, 1].

    Args:
        FEATURES_G: channel width after the stem; each DownBlock doubles it.
        num_blocks: number of residual blocks at the bottleneck.
        z_dim: length of the noise vectors passed to ``forward``.
        img_size: height/width of the square input images; must be divisible by 32
            (five stride-2 stages down and back up).
    """
    def __init__(self, FEATURES_G=32, num_blocks=5, z_dim=256, img_size=256):
        super(ResNetGenerator, self).__init__()
        self.project = nn.Linear(z_dim, img_size ** 2)
        self.FEATURES_G = FEATURES_G
        self.relu = nn.ReLU(inplace=True)  # not used by forward()
        self.initial = nn.Sequential(
            nn.Conv2d(4, self.FEATURES_G, 7, 1, 3),
            nn.BatchNorm2d(self.FEATURES_G),
            nn.ReLU(inplace=True))

        # Encoder: FEATURES_G -> 32*FEATURES_G channels, H -> H/32.
        self.down = nn.Sequential(
            *[DownBlock(self.FEATURES_G * (2 ** x)) for x in range(5)]
        )
        # Decoder: 32*FEATURES_G -> FEATURES_G channels, H/32 -> H.
        self.up = nn.Sequential(
            *[UpBlock(self.FEATURES_G * (2 ** (5 - x))) for x in range(5)]
        )
        # Residual blocks at the bottleneck resolution.
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(self.FEATURES_G * 32, self.FEATURES_G * 32) for _ in range(num_blocks)]
        )

        # Output layer: stride 1 because the decoder has already restored the full
        # resolution; the output must line up pixel-for-pixel with the input.
        self.out_conv = nn.ConvTranspose2d(self.FEATURES_G, 3, kernel_size=3, stride=1, padding=1, bias=False)
        self.tanh = nn.Tanh()

    def forward(self, noise, x):
        """Generate an image from ``x`` conditioned on ``noise``.

        Args:
            noise: (B, z_dim) noise vectors.
            x: (B, 3, H, W) images with H * W == img_size ** 2; the callers in this
                notebook zero out the hole before calling.

        Returns:
            (B, 3, H, W) tensor with values in [0, 1].
        """
        b_size = x.shape[0]
        # One conditioning channel with the same spatial size as the image.
        cond = torch.reshape(self.project(noise), (b_size, 1, *x.shape[-2:]))
        x = torch.cat((cond, x), dim=1)

        x = self.initial(x)
        x = self.down(x)
        x = self.res_blocks(x)
        x = self.up(x)

        x = self.out_conv(x)
        # Map tanh's (-1, 1) onto (0, 1), the range produced by T.ToTensor().
        return 0.5 * (self.tanh(x) + 1)
class Dataset(dutils.Dataset):
    """Image dataset read from ``images/train/`` or ``images/test/``.

    File names containing ``'cut'`` or ``'hole'`` are skipped.

    Args:
        transform: optional callable applied to each RGB PIL image.
        Test: if True read ``images/test/``, otherwise ``images/train/``.
    """
    def __init__(self, transform=None, Test=False):
        self.transform = transform
        self.folder_path = 'images/test/' if Test else 'images/train/'
        # sorted(): os.listdir order is arbitrary; keep a stable index -> file mapping.
        self.images = sorted(
            name for name in os.listdir(self.folder_path)
            if 'cut' not in name and 'hole' not in name
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        path = os.path.join(self.folder_path, self.images[index])
        with Image.open(path) as raw:
            # Image.open is lazy; convert() reads the pixels and returns an
            # independent image, so the result is still usable after the file is
            # closed (even when no transform is set). RGB matches the 3-channel models.
            img = raw.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img
# Random horizontal/vertical flips, then PIL image -> float tensor in [0, 1].
transform = T.Compose(
    [
        T.RandomHorizontalFlip(),
        T.RandomVerticalFlip(),
        T.ToTensor(),
    ]
)
train_ds = Dataset(transform=transform)
test_ds = Dataset(Test=True, transform=transform)
# Both loaders shuffle; the evaluation cell draws random batches via next(iter(test)).
train = dutils.DataLoader(dataset=train_ds, shuffle=True, batch_size=16)
test = dutils.DataLoader(dataset=test_ds, shuffle=True, batch_size=16)
class Discriminator(nn.Module):
    """Convolutional critic producing one probability per image.

    Six 4x4 stride-2 convs (with LeakyReLU) shrink a 256x256 input to 4x4, then a
    4x4 valid conv + Sigmoid gives a (B, 1, 1, 1) score in [0, 1].

    Args:
        features_d: base channel width, doubled at each stride-2 stage. Defaults to
            the notebook-level ``FEATURES_D`` constant (looked up at construction).
        in_channels: channels of the input images.
    """
    def __init__(self, features_d=None, in_channels=3):
        super(Discriminator, self).__init__()
        if features_d is None:
            features_d = FEATURES_D

        layers = []
        channels = in_channels
        for stage in range(6):
            out_ch = features_d * 2 ** stage
            layers += [nn.Conv2d(channels, out_ch, 4, 2, 1), nn.LeakyReLU(0.2, inplace=True)]
            channels = out_ch
        layers += [nn.Conv2d(channels, 1, 4, 1, 0), nn.Sigmoid()]
        # Same layer order/indices as before, so state_dict keys are unchanged.
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)





In [13]:
# Prefer a CUDA GPU, then Apple-silicon MPS, then CPU.
device = (
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)

# Reproducibility: seed Python's `random` and torch (weight init, noise, shuffling, flips).
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

BATCH_SIZE = 8  # NOTE(review): unused -- the DataLoaders are built with batch_size=16
Z_DIM = 256  # noise length; ResNetGenerator's default noise projection expects 256
EPOCHS = 300
BETA_1 = 0.5
BETA_2 = 0.999
LEARNING_RATE = 1e-4
FEATURES_D = 32  # base channel width of the Discriminator
IMG_SIZE = 256  # images are IMG_SIZE x IMG_SIZE
PATCH_SIZE = 32  # side of the square hole cut from the centre of each image

In [None]:

print(f'Started training using device: {device}')

generator = ResNetGenerator().to(device)
discriminator = Discriminator().to(device)

d_opt = torch.optim.Adam(discriminator.parameters(), lr=LEARNING_RATE, betas=(BETA_1, BETA_2))
g_opt = torch.optim.Adam(generator.parameters(), lr=LEARNING_RATE, betas=(BETA_1, BETA_2))

loss_fn = nn.BCELoss()

# Output folders are not guaranteed to exist on a fresh checkout.
os.makedirs('progress3', exist_ok=True)
os.makedirs('models', exist_ok=True)

fixed_noise = torch.randn(16, Z_DIM, device=device)
fixed_images = next(iter(train))[:16].to(device)

fixed_x_offset = max(0, (IMG_SIZE - PATCH_SIZE) // 2)
fixed_y_offset = fixed_x_offset
# Index of the square hole in (B, C, H, W) tensors: rows use the y offset, columns the x offset.
hole = (slice(None), slice(None),
        slice(fixed_y_offset, fixed_y_offset + PATCH_SIZE),
        slice(fixed_x_offset, fixed_x_offset + PATCH_SIZE))

start = time.time()
for epoch in range(EPOCHS):
    for image_batch in tqdm(train):
        image_batch = image_batch.to(device)
        b_size = image_batch.shape[0]

        # Discriminator on real images (target 1).
        discriminator.zero_grad()
        y_hat_real = discriminator(image_batch).view(-1)
        real_loss = loss_fn(y_hat_real, torch.ones_like(y_hat_real))
        real_loss.backward()

        # Cut the hole, let the generator fill the image, paste back only the hole.
        image_batch[hole] = 0
        noise = torch.randn(b_size, Z_DIM, device=device)
        predicted_patch = generator(noise, image_batch)
        image_batch[hole] = predicted_patch[hole]

        # Discriminator on inpainted images (target 0); detach so G gets no gradient here.
        y_hat_fake = discriminator(image_batch.detach()).view(-1)
        fake_loss = loss_fn(y_hat_fake, torch.zeros_like(y_hat_fake))
        fake_loss.backward()
        d_opt.step()

        # Generator: push D to label the inpainted images as real.
        generator.zero_grad()
        y_hat_fake = discriminator(image_batch).view(-1)
        g_loss = loss_fn(y_hat_fake, torch.ones_like(y_hat_fake))
        g_loss.backward()
        g_opt.step()

    # Per-epoch snapshot on a fixed batch. eval() so BatchNorm uses its running
    # statistics and this extra forward pass does not update them; back to train() after.
    generator.eval()
    fixed_images[hole] = 0
    with torch.no_grad():
        predicted_patches = generator(fixed_noise, fixed_images)
    generator.train()
    fixed_images[hole] = predicted_patches[hole]
    img = T.ToPILImage()(vutils.make_grid(fixed_images.to('cpu'), normalize=True, padding=2, nrow=4))
    img.save(f'progress3/epoch_{epoch}.jpg')
train_time = time.time() - start
print(f'Total training time: {train_time // 60} minutes')

generator = generator.to('cpu')

# Saves the whole pickled module (architecture + weights), not just a state_dict.
torch.save(generator, 'models/resnet_generator.pkl')

In [15]:
# The checkpoint is a whole pickled nn.Module, which weights_only=True (the default
# since PyTorch 2.6) refuses to load. Unpickling can execute arbitrary code, so only
# load checkpoints you produced yourself. map_location keeps it on CPU regardless of
# the device it was saved from.
generator = torch.load('models/resnet_generator.pkl', map_location='cpu', weights_only=False)

In [17]:
fix_x_offset = max(0, (IMG_SIZE - PATCH_SIZE) // 2)
fix_y_offset = fix_x_offset
# Same centred hole as in training: rows use the y offset, columns the x offset.
hole = (slice(None), slice(None),
        slice(fix_y_offset, fix_y_offset + PATCH_SIZE),
        slice(fix_x_offset, fix_x_offset + PATCH_SIZE))

os.makedirs('test', exist_ok=True)

# Inference: BatchNorm must use its running statistics, not per-batch ones.
generator.eval()
gen_device = next(generator.parameters()).device

for i in range(5):
    # Fresh iterator each time; the test loader shuffles, so each pass is a random batch.
    # No reshape to a fixed batch of 16: the batch may be smaller for small test sets.
    images = next(iter(test))
    img_og = images.clone()

    images[hole] = 0

    noise = torch.randn(images.shape[0], Z_DIM)
    with torch.no_grad():
        predicted_patches = generator(noise.to(gen_device), images.to(gen_device)).cpu()

    images[hole] = predicted_patches[hole]

    # Originals on the left, inpainted versions on the right.
    img = torch.cat([img_og, images], dim=3)
    img = T.ToPILImage()(vutils.make_grid(img, normalize=True, padding=2, nrow=4))
    img.save(f'test/test_resnet_{i+1}.jpg')