In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.utils import save_image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import glob
import random

from google.colab import drive
# Mount Google Drive at /content/drive; the train/test image folders used later
# in this notebook are read from paths under this mount point.
drive.mount('/content/drive')



Mounted at /content/drive


In [None]:
# Define the Generator model
class InpaintGenerator(nn.Module):
    """Encoder-decoder generator that predicts the content of a masked region.

    The encoder applies four stride-2 convolutions (3 -> 64 -> 128 -> 256 -> 512
    channels) and the decoder applies three stride-2 transposed convolutions
    (512 -> 256 -> 128 -> 64) followed by a 3x3 convolution to 3 channels, so
    the output has half the spatial size of the input (a 128x128 input gives a
    64x64 patch). The final ``Tanh`` keeps the output in [-1, 1].

    The order of layers inside ``encoder`` and ``decoder`` is part of the
    checkpoint format (state-dict keys are ``encoder.<idx>...``), so it must
    not be changed.

    Args:
        bn_eps: ``eps`` passed to every ``BatchNorm2d`` layer. The default of
            0.8 reproduces the original code, which wrote
            ``nn.BatchNorm2d(C, 0.8)``; the second positional argument of
            ``BatchNorm2d`` is ``eps``, not ``momentum``.
            NOTE(review): 0.8 is an unusually large epsilon and was probably
            meant as a momentum. ``eps`` is not stored in the state dict, so
            switching it for already-trained weights changes their output;
            pass ``bn_eps=1e-5`` explicitly when training from scratch.
        bn_momentum: ``momentum`` passed to every ``BatchNorm2d`` layer
            (0.1 is the PyTorch default, i.e. the original behaviour).
    """

    def __init__(self, bn_eps=0.8, bn_momentum=0.1):
        super().__init__()

        def batch_norm(channels):
            # Keyword arguments make it explicit which knob each value sets.
            return nn.BatchNorm2d(channels, eps=bn_eps, momentum=bn_momentum)

        # Downsampling path: each 4x4 / stride-2 convolution halves H and W.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            batch_norm(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            batch_norm(256),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 512, 4, stride=2, padding=1),
            batch_norm(512),
            nn.LeakyReLU(0.2),
        )

        # Upsampling path: each 4x4 / stride-2 transposed convolution doubles
        # H and W; the last 3x3 convolution maps back to RGB.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1),
            batch_norm(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
            batch_norm(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            batch_norm(64),
            nn.ReLU(),
            nn.Conv2d(64, 3, 3, 1, 1),
            nn.Tanh()
        )

    def forward(self, data):
        """Encode the masked image batch and decode the predicted patch."""
        encoded = self.encoder(data)
        decoded = self.decoder(encoded)
        return decoded

# Define the Discriminator model
class InpaintDiscriminator(nn.Module):
    """Convolutional discriminator that scores an RGB patch.

    Four 3x3 convolution stages (3 -> 64 -> 128 -> 256 -> 512 channels; the
    first three use stride 2, the fourth stride 1) are each followed by
    ``LeakyReLU(0.2)``; every stage except the first also applies
    ``InstanceNorm2d``. A final 3x3 convolution reduces to a single channel,
    so a 64x64 input produces a 1x8x8 map of scores.
    """

    def __init__(self):
        super().__init__()

        # (in_channels, out_channels, stride, followed_by_instance_norm)
        stage_specs = (
            (3, 64, 2, False),
            (64, 128, 2, True),
            (128, 256, 2, True),
            (256, 512, 1, True),
        )

        stack = []
        for in_ch, out_ch, step, with_norm in stage_specs:
            stack.append(nn.Conv2d(in_ch, out_ch, 3, step, 1))
            if with_norm:
                stack.append(nn.InstanceNorm2d(out_ch))
            stack.append(nn.LeakyReLU(0.2, inplace=True))
        # Single-channel score map.
        stack.append(nn.Conv2d(512, 1, 3, 1, 1))

        self.model = nn.Sequential(*stack)

    def forward(self, image_data):
        """Return the score map for a batch of patches."""
        scores = self.model(image_data)
        return scores


In [None]:
# Ensure the folder for sampled reconstructed output images is present
# (an already-existing folder is not treated as an error).
os.makedirs(name='inpaint_output', exist_ok=True)

# Initialize weights for the model
def init_weights(m):
    """Re-initialize one module in place; meant to be used with ``Module.apply``.

    Modules whose class name contains ``Conv`` get weights drawn from
    N(0, 0.02). Modules whose class name contains ``BatchNorm`` get weights
    drawn from N(1, 0.02) and a zero bias. Every other module is left as is.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        m.weight.data.normal_(0.0, 0.02)
        return
    if 'BatchNorm' in layer_kind:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

# Select the compute device: use the GPU when one is available and fall back
# to the CPU otherwise, so this cell does not crash on CPU-only runtimes.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize the loss functions (the training loop uses MSE for the
# adversarial terms and L1 for the reconstruction term).
mse_loss = torch.nn.MSELoss()
l1_loss = torch.nn.L1Loss()
mse_loss.to(device)
l1_loss.to(device)

# Initialize the Generator and Discriminator
generator = InpaintGenerator().to(device)
discriminator = InpaintDiscriminator().to(device)

# Apply initial weights to Generator and Discriminator
generator.apply(init_weights)
discriminator.apply(init_weights)

class CustomDataset(Dataset):
    """Folder of ``*.jpg`` images served as (image, masked image, mask info).

    Each item is a triple:

    * the transformed image,
    * a copy of it in which a square of side ``inp_masksize`` is filled with 1,
    * mask information whose type depends on ``is_trainable``:
      - ``True``: the original pixels of the masked square (the target the
        generator should reconstruct); the square is placed at a random
        position,
      - ``False``: the integer row/column offset of the square, which is
        always centered in the image.

    Args:
        dataset_path: directory that is searched (non-recursively) for
            ``*.jpg`` files; the file list is sorted.
        image_transforms: list of torchvision transforms. It must produce a
            tensor of spatial size ``inp_imgsize`` x ``inp_imgsize``. When
            ``None``, a resize to that size followed by ``ToTensor`` is used.
        inp_imgsize: side length of the transformed image.
        inp_masksize: side length of the masked square.
        is_trainable: selects random-mask (training) or centered-mask mode.
    """

    def __init__(self, dataset_path, image_transforms=None, inp_imgsize=128, inp_masksize=64, is_trainable=True):
        if image_transforms is None:
            # The declared default used to crash on first use
            # (``Compose(None)``); fall back to a minimal working pipeline.
            image_transforms = [
                transforms.Resize((inp_imgsize, inp_imgsize)),
                transforms.ToTensor(),
            ]
        self.image_transforms = transforms.Compose(image_transforms)
        self.inp_imgsize = inp_imgsize
        self.inp_masksize = inp_masksize
        self.is_trainable = is_trainable
        self.input_data = sorted(glob.glob(os.path.join(dataset_path, "*.jpg")))

    def __getitem__(self, index):
        path = self.input_data[index % len(self.input_data)]
        # Force three channels: grayscale / RGBA / CMYK JPEGs would otherwise
        # produce tensors that do not match a 3-channel Normalize or the
        # network's 3 input channels.
        inp_image = self.image_transforms(Image.open(path).convert("RGB"))
        image_masked = inp_image.clone()

        if self.is_trainable:
            # ``randint``'s upper bound is exclusive, so ``+ 1`` keeps the last
            # valid offset (imgsize - masksize) reachable.
            y, x = np.random.randint(0, self.inp_imgsize - self.inp_masksize + 1, 2)
            range_x = int(x + self.inp_masksize)
            range_y = int(y + self.inp_masksize)
            # Slice the ground-truth patch from the unmasked image.
            mask_region = inp_image[:, y:range_y, x:range_x]
            image_masked[:, y:range_y, x:range_x] = 1
        else:
            # Use ONE offset for both the painted square and the value that is
            # reported to the caller. Previously the square was painted at
            # imgsize / 4 while (imgsize - masksize) // 2 was returned; those
            # only coincide when masksize == imgsize / 2.
            offset = (self.inp_imgsize - self.inp_masksize) // 2
            end = offset + self.inp_masksize
            image_masked[:, offset:end, offset:end] = 1
            mask_region = offset

        return inp_image, image_masked, mask_region

    def __len__(self):
        return len(self.input_data)

# Train and test data
# Shared preprocessing: bicubic resize to 128x128, conversion to a tensor, then
# per-channel normalization with mean 0.5 and std 0.5.
_half_rgb = (0.5, 0.5, 0.5)
image_transforms = [
    transforms.Resize((128, 128), Image.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(_half_rgb, _half_rgb),
]

# Training split: random mask position per sample, batches of 8.
_train_set = CustomDataset(
    "/content/drive/MyDrive/censor_content/train",
    image_transforms=image_transforms,
    is_trainable=True,
)
training_data = DataLoader(_train_set, batch_size=8, shuffle=True, num_workers=4)

# Test split: fixed mask position, batches of 12 (used for sample images).
_test_set = CustomDataset(
    "/content/drive/MyDrive/censor_content/test",
    image_transforms=image_transforms,
    is_trainable=False,
)
testing_data = DataLoader(_test_set, batch_size=12, shuffle=True, num_workers=1)

# Visualizing Sample Output Images
def store_sample_image():
    """Inpaint one batch from ``testing_data`` and save a comparison image.

    Takes the first batch from a fresh iterator over ``testing_data``, runs
    the generator on the masked images, pastes the generated patch into the
    masked region and writes ``inpaint_output.png``. Each grid cell stacks,
    top to bottom: masked input, inpainted result, original image.

    The generator is evaluated in eval mode under ``torch.no_grad()`` so that
    sampling neither builds an autograd graph nor updates BatchNorm running
    statistics with test data; its previous train/eval mode is restored.

    NOTE(review): the image is written to the working directory, not into the
    ``inpaint_output`` folder created earlier in the notebook - confirm which
    location is intended.
    """
    inp_image, image_masked, mask_coord = next(iter(testing_data))

    # Follow the generator's device instead of assuming CUDA is present.
    target_device = next(generator.parameters()).device
    inp_image = inp_image.to(target_device, dtype=torch.float32)
    image_masked = image_masked.to(target_device, dtype=torch.float32)
    # Every sample of the test split uses the same mask offset; read the first.
    mask_coord = mask_coord[0].item()

    # Output of reconstructed generated image
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            generator_output = generator(image_masked)
    finally:
        generator.train(was_training)

    # Size the paste window from the generator output rather than a fixed 64.
    patch_h, patch_w = generator_output.shape[-2:]
    inpainted_image = image_masked.clone()
    inpainted_image[:, :, mask_coord: mask_coord + patch_h, mask_coord: mask_coord + patch_w] = generator_output

    # Store the output data to disk
    sample = torch.cat((image_masked, inpainted_image, inp_image), -2)
    save_image(sample, "inpaint_output.png" , nrow=6, normalize=True)





In [None]:

# Initialize Optimizer for Generator and Discriminator
image_generator_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
image_discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Training configuration
NUM_EPOCHS = 200
# Weights of the adversarial and reconstruction terms in the generator loss.
ADVERSARIAL_WEIGHT = 0.001
RECONSTRUCTION_WEIGHT = 0.999

# Training the inpainting network
for epoch in range(NUM_EPOCHS):
    for i, (inp_image, image_masked, mask_region) in enumerate(training_data):

        # Move the batch to the training device. ``mask_region`` holds the
        # ground-truth pixels of the masked square (the generator's target).
        image_masked = image_masked.to(device, dtype=torch.float32)
        mask_region = mask_region.to(device, dtype=torch.float32)

        # ---- Training Generator ----
        image_generator_optimizer.zero_grad()

        # Generate the missing patch from the masked image
        generator_output = generator(image_masked)
        fake_prediction = discriminator(generator_output)

        # Real / fake targets, sized from the discriminator output instead of
        # a hard-coded (N, 1, 8, 8) legacy ``torch.cuda.FloatTensor``.
        real_labels = torch.ones_like(fake_prediction)
        fake_labels = torch.zeros_like(fake_prediction)

        # Generator loss: fool the discriminator + match the true patch
        gen_entropy = mse_loss(fake_prediction, real_labels)
        gen_reconstruct = l1_loss(generator_output, mask_region)
        generator_loss = ADVERSARIAL_WEIGHT * gen_entropy + RECONSTRUCTION_WEIGHT * gen_reconstruct

        generator_loss.backward()
        image_generator_optimizer.step()

        # ---- Training Discriminator ----
        # zero_grad also clears the discriminator gradients that the generator
        # backward pass accumulated above.
        image_discriminator_optimizer.zero_grad()

        # Real patches should score 1, generated patches 0. ``detach`` keeps
        # this step from back-propagating into the generator.
        real_loss = mse_loss(discriminator(mask_region), real_labels)
        fake_loss = mse_loss(discriminator(generator_output.detach()), fake_labels)
        discriminator_loss = (real_loss + fake_loss) / 2

        discriminator_loss.backward()
        image_discriminator_optimizer.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, NUM_EPOCHS, i, len(training_data), discriminator_loss.item(), generator_loss.item())
        )


store_sample_image()
# Save the model here
model = {
    'state_dict': generator.state_dict(),
    # Store the optimizer's own state; previously the generator weights were
    # saved a second time under this key, so training could not be resumed.
    'optimizer': image_generator_optimizer.state_dict(),
}

torch.save(model, "model.pth")

  real_labels = Variable(torch.cuda.FloatTensor(inp_image.shape[0], 1, 8, 8).fill_(1.0), requires_grad=False)


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
[Epoch 166/200] [Batch 100/150] [D loss: 0.001427] [G loss: 0.349273]
[Epoch 166/200] [Batch 101/150] [D loss: 0.001619] [G loss: 0.372919]
[Epoch 166/200] [Batch 102/150] [D loss: 0.001797] [G loss: 0.329402]
[Epoch 166/200] [Batch 103/150] [D loss: 0.001946] [G loss: 0.324341]
[Epoch 166/200] [Batch 104/150] [D loss: 0.002168] [G loss: 0.262266]
[Epoch 166/200] [Batch 105/150] [D loss: 0.002351] [G loss: 0.372712]
[Epoch 166/200] [Batch 106/150] [D loss: 0.002600] [G loss: 0.300687]
[Epoch 166/200] [Batch 107/150] [D loss: 0.002880] [G loss: 0.327597]
[Epoch 166/200] [Batch 108/150] [D loss: 0.003030] [G loss: 0.300408]
[Epoch 166/200] [Batch 109/150] [D loss: 0.003053] [G loss: 0.307274]
[Epoch 166/200] [Batch 110/150] [D loss: 0.002821] [G loss: 0.356773]
[Epoch 166/200] [Batch 111/150] [D loss: 0.002553] [G loss: 0.270208]
[Epoch 166/200] [Batch 112/150] [D loss: 0.002295] [G loss: 0.314854]
[Epoch 166/200] [Batch 11