In [None]:
import pandas as pd
import os
import cv2
import numpy as np
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import Dataset,DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image

In [None]:
# Colab-only helper module; this cell will fail outside a Google Colab runtime.
from google.colab import drive
# Mount Google Drive at /content/drive. The dataset paths used in later cells
# are all located under this mount point.
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Root folder of the SEM image dataset on the mounted Google Drive.
dataset_root = '/content/drive/MyDrive/Dataset_SEM_Images'

# Folder holding the denoised input images.
input_path = f'{dataset_root}/Denoised_Images'
# Folder intended for generated output images.
output_path = f'{dataset_root}/Output_Images'

In [None]:
# Query CUDA availability once; derive both the boolean flag and the
# device string from that single answer.
cuda = torch.cuda.is_available()
device_use = 'cuda' if cuda else 'cpu'

In [None]:
class SEMImagesDataset(Dataset):
  """Two-class image dataset built from a pair of folders.

  Every regular file directly inside ``ordered_path`` is given label 1 and
  every regular file directly inside ``unordered_path`` is given label 0.

  Args:
    ordered_path: Folder containing the images labelled 1.
    unordered_path: Folder containing the images labelled 0.
    transforms: Optional callable applied to each opened PIL image before it
      is returned. When omitted, the raw PIL image is returned.

  Attributes (set in ``__init__``):
    file_paths: Full paths of all samples, ordered-folder files first.
    labels: Integer label (1 or 0) for the path at the same index.
  """

  def __init__(self,ordered_path,unordered_path,transforms=None):

    self.file_paths = []
    self.labels = []
    self.transforms = transforms

    for folder, label in ((ordered_path, 1), (unordered_path, 0)):
      # os.listdir returns entries in arbitrary order; sort so the sample
      # index -> file mapping is reproducible between runs.
      for name in sorted(os.listdir(folder)):
        full_path = os.path.join(folder, name)
        # Skip sub-directories (e.g. '.ipynb_checkpoints'), which would make
        # Image.open fail later in __getitem__.
        if not os.path.isfile(full_path):
          continue
        self.file_paths.append(full_path)
        self.labels.append(label)

  def __len__(self):
    """Return the number of image files found in both folders."""
    return len(self.file_paths)

  def __getitem__(self,idx):
    """Open the image at position ``idx`` and return ``(image, label)``.

    The image is passed through ``self.transforms`` when one was supplied.
    """
    image_path = self.file_paths[idx]
    label = self.labels[idx]

    image = Image.open(image_path)

    if self.transforms:
      image = self.transforms(image)

    return image,label


In [None]:
# Preprocessing applied to every image, in this order:
# convert to a tensor, resize to 256x256, then normalise with mean 0.5 / std 0.5.
preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Resize((256,256)),
    transforms.Normalize([0.5],[0.5]),
]
transform = transforms.Compose(preprocessing_steps)

# One sub-folder of the input directory per class.
ordered_path, unordered_path = (
    os.path.join(input_path, class_folder)
    for class_folder in ('Ordered', 'Unordered')
)

In [None]:
# Hyper-parameters read by later cells.
batch_size, label_dims = 32, 2
# NOTE(review): latent_dims is not referenced by any other visible cell.
latent_dims = (32,32)

# Full dataset (both classes) with the preprocessing pipeline attached.
dataset = SEMImagesDataset(
    ordered_path=ordered_path,
    unordered_path=unordered_path,
    transforms=transform,
)

# Single shuffled loader over the whole dataset.
total_dataloader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
class Generator(nn.Module):
  """Conditional generator: noise feature map + class label -> 1-channel image.

  The label is turned into a learned embedding vector, broadcast over the
  spatial size of the noise map and concatenated to it along the channel
  axis. The stack contains five stride-2 transposed convolutions, so the
  output is 32x larger than the noise map in each spatial dimension
  (e.g. 8x8 noise -> 256x256 image). The final ``Tanh`` bounds the output
  to [-1, 1].

  Args:
    num_labels: Number of classes, also used as the embedding width. When
      ``None`` (default) the module-level ``label_dims`` is used, which is
      what ``Generator()`` did before this parameter existed.
    noise_channels: Channel count of the noise map passed to ``forward``.
      With the defaults the first convolution receives 510 + 2 = 512
      channels, matching the previously hard-coded value.
  """

  def __init__(self, num_labels=None, noise_channels=510):
    super(Generator,self).__init__()

    if num_labels is None:
      num_labels = label_dims

    self.num_labels = num_labels
    self.noise_channels = noise_channels
    self.label_embedding = nn.Embedding(num_labels,num_labels)

    # Noise channels plus one channel per embedding component.
    in_channels = noise_channels + num_labels

    self.layers = nn.Sequential(
        nn.Conv2d(in_channels,512,kernel_size=3,padding='same'),
        nn.Conv2d(512,512,kernel_size=3,padding='same'),
        nn.BatchNorm2d(512),
        nn.ReLU(inplace=True),

        nn.ConvTranspose2d(512,512,kernel_size=2,stride=2),
        nn.Conv2d(512,256,kernel_size=3,padding='same'),
        nn.Conv2d(256,256,kernel_size=3,padding='same'),
        nn.BatchNorm2d(256),
        nn.ReLU(inplace=True),

        nn.ConvTranspose2d(256,256,kernel_size=2,stride=2),
        nn.Conv2d(256,128,kernel_size=3,padding='same'),
        nn.Conv2d(128,128,kernel_size=3,padding='same'),
        nn.BatchNorm2d(128),
        nn.ReLU(inplace=True),

        nn.ConvTranspose2d(128,128,kernel_size=2,stride=2),
        nn.Conv2d(128,64,kernel_size=3,padding='same'),
        nn.Conv2d(64,64, kernel_size=3,padding='same'),
        nn.BatchNorm2d(64),
        nn.ReLU(inplace=True),

        nn.ConvTranspose2d(64,64,kernel_size=2,stride=2),
        nn.Conv2d(64,32,kernel_size=3,padding='same'),
        nn.Conv2d(32,32,kernel_size=3,padding='same'),
        nn.BatchNorm2d(32),
        nn.ReLU(inplace=True),

        nn.ConvTranspose2d(32,32,kernel_size=2,stride=2),
        nn.Conv2d(32,16,kernel_size=3, padding='same'),
        nn.Conv2d(16,16, kernel_size=3, padding='same'),
        nn.BatchNorm2d(16),
        nn.ReLU(inplace=True),

        nn.Conv2d(16,1,kernel_size=3, padding='same'),
        nn.Tanh()
    )

  def forward(self,noise, labels):
    """Generate images for ``noise`` conditioned on ``labels``.

    Args:
      noise: Float tensor of shape (B, noise_channels, H, W).
      labels: Integer class indices, shape (B,) or (B, 1).

    Returns:
      Tensor of shape (B, 1, 32*H, 32*W) with values in [-1, 1].
    """
    # Flatten so (B,) and (B, 1) label tensors are handled identically.
    flat_labels = labels.reshape(-1)
    num_samples = flat_labels.size(0)

    # (B, num_labels) -> (B, num_labels, 1, 1) -> broadcast over H x W.
    label_embed = self.label_embedding(flat_labels)
    label_embed = label_embed.view(num_samples, self.num_labels, 1, 1)
    label_embed = label_embed.expand(num_samples, self.num_labels, noise.size(2), noise.size(3))

    g_in = torch.cat((noise, label_embed), dim=1)

    return self.layers(g_in)

In [None]:
# label_dims = 2

# generator = Generator()

# if cuda:
#   generator = generator.cuda()

# noise = torch.randn((32,510,8,8), device='cuda')
# input_label = torch.randint(0,2,(32,1), device='cuda')

# print(input_label.shape)

# output = generator(noise,input_label)

# print(output.shape)

torch.Size([32, 1])
torch.Size([32, 1, 256, 256])


In [None]:
class Discriminator(nn.Module):
  """Conditional discriminator: image + class label -> probability of "real".

  The label is turned into a learned embedding vector, broadcast over the
  image's spatial size and concatenated to the image along the channel axis.
  Five unpadded stride-2 convolutions are followed by a flatten and three
  linear layers ending in a ``Sigmoid``.

  The first linear layer expects 36864 = 1024 * 6 * 6 features, which is
  what the convolution stack produces for 256x256 inputs; other input sizes
  will fail at that layer.

  Args:
    num_labels: Number of classes, also used as the embedding width. When
      ``None`` (default) the module-level ``label_dims`` is used, which is
      what ``Discriminator()`` did before this parameter existed.
    img_channels: Channel count of the input image. With the defaults the
      first convolution receives 1 + 2 = 3 channels, matching the previously
      hard-coded value.
  """

  def __init__(self, num_labels=None, img_channels=1):
    super(Discriminator,self).__init__()

    if num_labels is None:
      num_labels = label_dims

    self.num_labels = num_labels
    self.img_channels = img_channels
    self.label_embedding = nn.Embedding(num_labels,num_labels)

    # Image channels plus one channel per (learned) label-embedding component.
    in_channels = img_channels + num_labels

    self.layers = nn.Sequential(
        nn.Conv2d(in_channels,64,kernel_size=4,stride=2),
        nn.LeakyReLU(0.2,inplace=True),

        nn.Conv2d(64,128,kernel_size=4,stride=2),
        nn.BatchNorm2d(128),
        nn.LeakyReLU(0.2,inplace=True),

        nn.Conv2d(128,256,kernel_size=4,stride=2),
        nn.BatchNorm2d(256),
        nn.LeakyReLU(0.2,inplace=True),

        nn.Conv2d(256,512,kernel_size=4,stride=2),
        nn.BatchNorm2d(512),
        nn.LeakyReLU(0.2,inplace=True),

        nn.Conv2d(512,1024, kernel_size=4,stride=2),
        nn.Flatten(1),
        nn.Linear(36864, 1024),
        nn.ReLU(),
        nn.Linear(1024, 128),
        nn.ReLU(),
        nn.Linear(128, 1),
        nn.Sigmoid()
    )

  def forward(self,img, labels):
    """Score ``img`` as real or fake given ``labels``.

    Args:
      img: Float tensor of shape (B, img_channels, 256, 256).
      labels: Integer class indices, shape (B,) or (B, 1).

    Returns:
      Tensor of shape (B, 1) with sigmoid outputs in [0, 1].
    """
    # Flatten so (B,) and (B, 1) label tensors are handled identically.
    flat_labels = labels.reshape(-1)
    num_samples = flat_labels.size(0)

    # (B, num_labels) -> (B, num_labels, 1, 1) -> broadcast over H x W.
    label_embed = self.label_embedding(flat_labels)
    label_embed = label_embed.view(num_samples, self.num_labels, 1, 1)
    label_embed = label_embed.expand(num_samples, self.num_labels, img.size(2), img.size(3))

    d_in = torch.cat((img, label_embed), dim=1)

    return self.layers(d_in)

In [None]:
# label_dims = 2

# discriminator = Discriminator()

# if cuda:
#   discriminator = discriminator.cuda()

# input_img = torch.randn((32,1,256,256), device='cuda')
# input_label = torch.randint(0,2,(32,1), device='cuda')

# print(input_label.shape)

# output = discriminator(input_img,input_label)

# print(output.shape)
# print(output)

torch.Size([32, 1])
torch.Size([32, 1])
tensor([[0.4943],
        [0.4952],
        [0.4971],
        [0.4994],
        [0.4992],
        [0.4848],
        [0.4893],
        [0.4842],
        [0.4965],
        [0.5024],
        [0.4958],
        [0.4779],
        [0.4966],
        [0.4911],
        [0.4993],
        [0.4864],
        [0.4884],
        [0.4911],
        [0.4904],
        [0.4927],
        [0.4970],
        [0.4941],
        [0.4935],
        [0.4979],
        [0.4817],
        [0.5008],
        [0.4900],
        [0.4945],
        [0.5015],
        [0.4883],
        [0.4911],
        [0.4980]], device='cuda:0', grad_fn=<SigmoidBackward0>)


In [None]:
# Instantiate both networks and the binary cross-entropy criterion, placing
# each on the selected device (GPU when available, otherwise CPU).
generator = Generator().to(device_use)
discriminator = Discriminator().to(device_use)
adversial_loss = nn.BCELoss().to(device_use)

# Both optimisers share the same Adam hyper-parameters.
adam_settings = {'lr': 0.0002, 'betas': (0.5, 0.999)}
optimizer_G = optim.Adam(generator.parameters(), **adam_settings)
optimizer_D = optim.Adam(discriminator.parameters(), **adam_settings)

In [None]:
# Conditional GAN training loop.
#
# Per batch: one generator step, then one discriminator step.
# Per epoch: print the mean losses over all batches.
# Every 100 epochs: write 20 generated sample images to Drive.

epochs = 1500

for epoch in range(epochs):
  # Running sums so the epoch log reports the mean over all batches rather
  # than only the value of the last batch.
  d_loss_real_sum = 0.0
  d_loss_total_sum = 0.0
  g_loss_sum = 0.0
  num_batches = 0

  # Back to training mode (the sampling step below switches to eval mode).
  generator.train()
  discriminator.train()

  for real_images, real_labels in total_dataloader:

    # Size of this batch (the last one may be smaller). A separate name is
    # used so the global `batch_size` that configured the DataLoader is not
    # overwritten.
    current_batch_size = real_images.shape[0]

    real_targets = torch.ones((current_batch_size, 1), dtype=torch.float, device=device_use)
    fake_targets = torch.zeros((current_batch_size, 1), dtype=torch.float, device=device_use)
    real_labels = real_labels.to(device=device_use).unsqueeze(1)

    real_images = real_images.to(device=device_use)

    # ---- Train generator ----
    optimizer_G.zero_grad()

    noise = torch.randn((current_batch_size, 510, 8,8), dtype=torch.float, device=device_use)
    gen_labels = torch.randint(0,2,(current_batch_size, 1) , device=device_use)

    fake_images = generator(noise, gen_labels)

    output_fake = discriminator(fake_images, gen_labels)
    # The generator is rewarded when the discriminator scores its samples as
    # REAL, so its loss is measured against the real targets. Using the fake
    # targets here would train the generator to be detected.
    g_loss = adversial_loss(output_fake, real_targets)
    g_loss.backward()
    optimizer_G.step()

    # ---- Train discriminator ----
    # zero_grad also discards the discriminator gradients accumulated by the
    # generator's backward pass above.
    optimizer_D.zero_grad()

    output_real = discriminator(real_images, real_labels)
    d_loss_real = adversial_loss(output_real, real_targets)

    # detach(): the discriminator update must not back-propagate into the generator.
    output_fake = discriminator(fake_images.detach(), gen_labels)
    d_loss_fake = adversial_loss(output_fake, fake_targets)

    d_loss = (d_loss_real + d_loss_fake)/2
    d_loss.backward()
    optimizer_D.step()

    d_loss_real_sum += d_loss_real.item()
    d_loss_total_sum += d_loss.item()
    g_loss_sum += g_loss.item()
    num_batches += 1

  # Mean per-batch losses for this epoch (0 if the loader yielded nothing).
  batch_count = max(num_batches, 1)
  d_loss_real_print = d_loss_real_sum / batch_count
  d_loss_total_print = d_loss_total_sum / batch_count
  g_loss_print = g_loss_sum / batch_count

  print('epoch [{}/{}], d_loss_real:{:.9f}'.format(epoch+1, epochs, d_loss_real_print),
        ', g_loss:{:.9f}'.format(g_loss_print),
        ', d_loss_total:{:.9f}'.format(d_loss_total_print))

  if epoch%100 == 0:
    generator.eval()
    discriminator.eval()

    with torch.no_grad():
      test_size = 20

      noise = torch.randn((test_size, 510, 8,8), dtype=torch.float, device=device_use)
      gen_labels = torch.randint(0,2,(test_size, 1), device=device_use)

      fake_images = generator(noise, gen_labels)

      # The generator ends in Tanh, so pixels are in [-1, 1]. ToPILImage
      # expects float tensors in [0, 1]; rescale before converting so negative
      # values are not corrupted by the uint8 cast.
      fake_images = ((fake_images + 1) / 2).clamp(0, 1).cpu()

      to_pil = transforms.ToPILImage()

      # Same location as before ('<output_path>/epoch_<n>_'), now derived
      # from the configured output_path instead of a duplicated literal.
      output_dir = os.path.join(output_path, f"epoch_{epoch}_")

      os.makedirs(output_dir, exist_ok=True)

      for idx in range(fake_images.size(0)):
        img = to_pil(fake_images[idx])
        # File name encodes the sample index and the class label it was conditioned on.
        img.save(f"{output_dir}/image_{idx}_{gen_labels[idx,0].item()}.png")


epoch [1/1500], d_loss_real:0.014551490 , g_loss:0.000000467 , d_loss_total:0.007275978
epoch [2/1500], d_loss_real:0.001017987 , g_loss:0.000085839 , d_loss_total:0.000551913
epoch [3/1500], d_loss_real:5.481132507 , g_loss:0.000000012 , d_loss_total:2.740566254
epoch [4/1500], d_loss_real:0.000203200 , g_loss:0.000000000 , d_loss_total:0.000101600
