In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import glob
import os
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

In [2]:
image_size = 256

class ImageDataset(Dataset):
    def __init__(self, paths, transform=None, masked_size=100):
        self.paths = paths
        self.transform = transform
        self.masked_size = masked_size

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img_path = self.paths[idx]
        img = Image.open(img_path).convert("RGB")
        masked_img = img.copy()
        masked_img = np.array(masked_img)
        
        if self.transform:
            img = self.transform(img)

        if self.masked_size:
            h = np.random.randint(0, image_size - self.masked_size)
            w = np.random.randint(0, image_size - self.masked_size)
            masked_img[h:h+self.masked_size, w:w+self.masked_size] = 0
            masked_img = Image.fromarray(masked_img)
            masked_img = self.transform(masked_img)

        return img, masked_img
    
    @staticmethod
    def collate_fn(batch):
        image, masked_image = zip(*batch)
        image = torch.stack(image, dim=0)
        masked_image = torch.stack(masked_image, dim=0)
        return image, masked_image

    
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

image_dir = "testSet_resize"
image_paths = glob.glob(os.path.join(image_dir, "*.jpg"))  # Adjust the file extension if needed

train_dataset = ImageDataset(image_paths, transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, collate_fn=ImageDataset.collate_fn)


In [3]:
class UNetDown(nn.Module):
    def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
        super(UNetDown, self).__init__()
        layers = [nn.Conv2d(in_size, out_size, kernel_size=3, stride=2, padding=1, bias=False)]
        if normalize:
            layers.append(nn.BatchNorm2d(out_size, 0.8))
        layers.append(nn.LeakyReLU(0.2))
        if dropout:
            layers.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
    
class UNetUp(nn.Module):
    def __init__(self, in_size, out_size, dropout=0.0):
        super(UNetUp, self).__init__()
        layers = [
            nn.ConvTranspose2d(in_size, out_size, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
            nn.BatchNorm2d(out_size, 0.8),
            nn.ReLU(inplace=True),
        ]
        if dropout:
            layers.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*layers)

    def forward(self, x, skip_input):
        x = self.model(x)
        x = torch.cat((x, skip_input), 1)

        return x

    
class Generator(nn.Module):
    def __init__(self, in_channels=3, out_channels=3):
        super(Generator, self).__init__()

        self.down1 = UNetDown(in_channels, 64, normalize=False)
        self.down2 = UNetDown(64, 128)
        self.down3 = UNetDown(128, 256)
        self.down4 = UNetDown(256, 512, dropout=0.5)
        self.down5 = UNetDown(512, 512, dropout=0.5)
        self.down6 = UNetDown(512, 512, dropout=0.5)
        self.down7 = UNetDown(512, 512, dropout=0.5)
        self.down8 = UNetDown(512, 512, normalize=False, dropout=0.5)

        self.up1 = UNetUp(512, 512, dropout=0.5)
        self.up2 = UNetUp(1024, 512, dropout=0.5)
        self.up3 = UNetUp(1024, 512, dropout=0.5)
        self.up4 = UNetUp(1024, 512, dropout=0.5)
        self.up5 = UNetUp(1024, 256)
        self.up6 = UNetUp(512, 128)
        self.up7 = UNetUp(256, 64)

        self.final = nn.Sequential(
            nn.ConvTranspose2d(128, out_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)
        d8 = self.down8(d7)

        u1 = self.up1(d8, d7)
        u2 = self.up2(u1, d6)
        u3 = self.up3(u2, d5)
        u4 = self.up4(u3, d4)
        u5 = self.up5(u4, d3)
        u6 = self.up6(u5, d2)
        u7 = self.up7(u6, d1)

        return self.final(u7)

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0),
        )
        self.fc = nn.Linear(13 * 13, 1)

    def forward(self, img):
        x = self.model(img)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

generator = Generator().to(device)
discriminator = Discriminator().to(device)

criterion = nn.BCEWithLogitsLoss()

optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

num_epochs = 20
loss_G = []
loss_D = []

for epoch in tqdm(range(num_epochs)):
    for images, masked_images in train_loader:
        images = images.to(device)
        masked_images = masked_images.to(device)

        # Train the discriminator
        optimizer_D.zero_grad()

        real_labels = torch.ones(images.size(0), 1).to(device)
        fake_labels = torch.zeros(images.size(0), 1).to(device)

        outputs = discriminator(images)
        real_loss = criterion(outputs, real_labels)

        fake_images = generator(masked_images)

        outputs = discriminator(fake_images.detach())
        fake_loss = criterion(outputs, fake_labels)

        d_loss = real_loss + fake_loss
        d_loss.backward()
        optimizer_D.step()

        # Train the generator
        optimizer_G.zero_grad()

        outputs = discriminator(fake_images)
        g_loss = criterion(outputs, real_labels)
        g_loss.backward()
        optimizer_G.step()


    print(f"[{epoch}/{num_epochs}] Loss_D: {d_loss.item():.4f} Loss_G: {g_loss.item():.4f}")
    loss_G.append(g_loss.item())
    loss_D.append(d_loss.item())

generator._save_to_state_dict('generator.pth')
discriminator._save_to_state_dict('discriminator.pth')

plt.plot(loss_G, label='Generator')
plt.plot(loss_D, label='Discriminator')
plt.legend()
plt.show()


  5%|▌         | 1/20 [02:09<40:54, 129.16s/it]

[0/20] Loss_D: 0.0451 Loss_G: 22.1729
