In [None]:
import pandas as pd
import os
import cv2
import numpy as np
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import Dataset,DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image

In [None]:
# Colab-only: mount Google Drive at /content/drive so the dataset folders
# referenced by the path variables in the next cell are reachable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Root folder of the SEM image dataset on the mounted Google Drive.
_dataset_root = '/content/drive/MyDrive/Dataset_SEM_Images'

# Folder holding the 'Ordered' / 'Unordered' training images.
input_path = f'{_dataset_root}/Denoised_Images'
# Prefix of the per-epoch folders the training cell writes generated samples to.
output_path = f'{_dataset_root}/Output_Images'

In [None]:
# True when PyTorch can use a CUDA GPU; later cells use this flag to decide
# whether to move the models and the loss to the GPU.
# `torch.cuda.is_available()` already returns a bool, so no ternary is needed.
cuda = torch.cuda.is_available()

In [None]:
class SEMImagesDataset(Dataset):
  """Binary-labelled image dataset built from two folders.

  Every regular file in ``ordered_path`` gets label ``1`` and every regular
  file in ``unordered_path`` gets label ``0``.

  Args:
    ordered_path: directory containing the "ordered" images (label 1).
    unordered_path: directory containing the "unordered" images (label 0).
    transforms: optional callable applied to each PIL image before it is
      returned (e.g. a ``torchvision.transforms.Compose``).

  Attributes:
    file_paths: image paths, ordered images first, each group sorted by name.
    labels: integer labels aligned with ``file_paths``.
  """

  def __init__(self,ordered_path,unordered_path,transforms=None):
    self.file_paths = []
    self.labels = []
    self.transforms = transforms

    for directory, label in ((ordered_path, 1), (unordered_path, 0)):
      # os.listdir returns entries in arbitrary order; sort them so the index
      # -> sample mapping is reproducible across runs and machines.
      for entry in sorted(os.listdir(directory)):
        candidate = os.path.join(directory, entry)
        # Skip sub-directories (e.g. a `.ipynb_checkpoints` folder created by
        # Jupyter); opening one as an image would raise at training time.
        if not os.path.isfile(candidate):
          continue
        self.file_paths.append(candidate)
        self.labels.append(label)

  def __len__(self):
    """Return the number of indexed image files."""
    return len(self.file_paths)

  def __getitem__(self,idx):
    """Return ``(image, label)`` for sample ``idx``.

    The image is loaded as single-channel grayscale and then passed through
    ``self.transforms`` when one was supplied.
    """
    image_path = self.file_paths[idx]
    label = self.labels[idx]

    # Image.open is lazy and keeps the file open; the context manager releases
    # the handle, while convert() forces the pixel data to be read first.
    # 'L' guarantees exactly one channel: the Discriminator concatenates the
    # image with label_dims (2) embedding channels into a 3-channel conv.
    with Image.open(image_path) as handle:
      image = handle.convert('L')

    if self.transforms:
      image = self.transforms(image)

    return image,label


In [None]:
# Preprocessing applied to every loaded image, in this order:
# convert to a tensor, resize to 256x256, then normalize with mean 0.5 / std 0.5
# (single-channel statistics).
preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Resize((256,256)),
    transforms.Normalize([0.5],[0.5]),
]
transform = transforms.Compose(preprocessing_steps)

# One sub-folder per class underneath the dataset input folder.
ordered_path, unordered_path = (
    os.path.join(input_path, class_dir) for class_dir in ('Ordered', 'Unordered')
)

In [None]:
# Hyper-parameters read by later cells.
label_dims = 2          # number of classes; also used as the label-embedding width
latent_dims = (32,32)   # NOTE(review): not referenced by the visible cells
batch_size = 32

# Full dataset (ordered + unordered images) with the preprocessing pipeline attached.
dataset = SEMImagesDataset(ordered_path, unordered_path, transforms=transform)

# Single loader over all images, reshuffled every epoch.
total_dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

In [None]:
class Generator(nn.Module):
  """Conditional generator: noise map + class label -> one-channel image.

  The label is embedded into ``label_dims`` values, broadcast over the spatial
  grid of the noise and concatenated to it along the channel axis. The first
  convolution takes 512 channels, so ``noise`` must carry ``512 - label_dims``
  channels (510 when ``label_dims`` is 2). Five stride-2 transposed
  convolutions upsample the grid 32x (an 8x8 noise map becomes 256x256) and
  the final ``Tanh`` keeps the output in [-1, 1].
  """

  def __init__(self):
    super(Generator,self).__init__()
    # Learned embedding: one vector of size label_dims per class.
    self.label_embedding = nn.Embedding(label_dims,label_dims)

    self.layers = nn.Sequential(
        # Input: noise channels + label-embedding channels = 512.
        nn.Conv2d(512,512,kernel_size=3,padding='same'),
        nn.Conv2d(512,512,kernel_size=3,padding='same'),
        nn.LeakyReLU(0.2,inplace=True),
        nn.BatchNorm2d(512),

        # Each ConvTranspose2d(kernel_size=2, stride=2) doubles height and width.
        nn.ConvTranspose2d(512,256,kernel_size=2,stride=2),
        nn.Conv2d(256,256,kernel_size=3,padding='same'),
        nn.LeakyReLU(0.2,inplace=True),
        nn.BatchNorm2d(256),

        nn.ConvTranspose2d(256,128,kernel_size=2,stride=2),
        nn.Conv2d(128,128,kernel_size=3,padding='same'),
        nn.LeakyReLU(0.2,inplace=True),
        nn.BatchNorm2d(128),

        nn.ConvTranspose2d(128,64,kernel_size=2,stride=2),
        nn.Conv2d(64,64,kernel_size=3,padding='same'),
        nn.LeakyReLU(0.2,inplace=True),
        nn.BatchNorm2d(64),

        nn.ConvTranspose2d(64,64,kernel_size=2,stride=2),
        nn.Conv2d(64,64,kernel_size=3,padding='same'),
        nn.LeakyReLU(0.2,inplace=True),
        nn.BatchNorm2d(64),

        nn.ConvTranspose2d(64,64,kernel_size=2,stride=2),
        nn.Conv2d(64,3,kernel_size=3, padding='same'),
        nn.Conv2d(3,1,kernel_size=3, padding='same'),

        # Squash to [-1, 1].
        nn.Tanh()
    )

  def forward(self,noise, labels):
    """Generate images for the given noise and class labels.

    Args:
      noise: float tensor of shape (B, 512 - label_dims, H, W).
      labels: integer class indices of shape (B,) or (B, 1).

    Returns:
      Tensor of shape (B, 1, 32*H, 32*W) with values in [-1, 1].
    """
    # Flatten so (B,) and (B, 1) label tensors are treated identically.
    flat_labels = labels.reshape(-1)
    # Read the width from the layer itself rather than the mutable global, so
    # the module keeps working if `label_dims` is reassigned after construction.
    embed_dim = self.label_embedding.embedding_dim

    label_embed = self.label_embedding(flat_labels)                      # (B, embed_dim)
    label_embed = label_embed.view(flat_labels.size(0), embed_dim, 1, 1)
    # Broadcast the per-sample embedding over the noise's spatial grid.
    label_embed = label_embed.expand(-1, -1, noise.size(2), noise.size(3))

    g_in = torch.cat((noise, label_embed), dim=1)

    return self.layers(g_in)

In [None]:
# label_dims = 2

# generator = Generator()

# if cuda:
#   generator = generator.cuda()

# noise = torch.rand(32,510,8,8).to(device='cuda')
# input_label = torch.randint(0,2,(32,)).to(device='cuda').unsqueeze(1)

# print(input_label.shape)

# output = generator(noise,input_label)

# print(output.shape)

torch.Size([32, 1])
torch.Size([32, 512, 8, 8])
torch.Size([32, 1, 256, 256])


In [None]:
class Discriminator(nn.Module):
  """Conditional discriminator: (image, class label) -> probability in (0, 1).

  The label is embedded into ``label_dims`` values, broadcast over the image
  grid and concatenated to the image along the channel axis. The first
  convolution takes 3 channels, i.e. a one-channel image plus ``label_dims``
  (2) embedding channels. The final ``Linear(841, 1)`` fixes the input size to
  256x256: 256 -> 128 -> 64 -> 32 through the stride-2 convolutions, then 29
  after the 4x4 valid convolution, and 29 * 29 = 841.
  """

  def __init__(self):
    super(Discriminator,self).__init__()

    # Learned embedding: one vector of size label_dims per class.
    self.label_embedding = nn.Embedding(label_dims,label_dims)

    self.layers = nn.Sequential(
        # 3 input channels = 1 grayscale channel + 2 label-embedding channels.
        nn.Conv2d(3,64,kernel_size=4,stride=2,padding=1),
        nn.LeakyReLU(0.2,inplace=True),
        nn.Conv2d(64,128,kernel_size=4,stride=2,padding=1),
        nn.BatchNorm2d(128),
        nn.LeakyReLU(0.2,inplace=True),

        nn.Conv2d(128,256,kernel_size=4,stride=2,padding=1),
        nn.BatchNorm2d(256),
        nn.LeakyReLU(0.2,inplace=True),

        # 32x32 -> 29x29 single-channel score map, flattened to 841 features.
        nn.Conv2d(256,1,kernel_size=4,stride=1,padding=0),
        nn.Flatten(1),
        nn.Linear(841,1),
        nn.Sigmoid()
    )

  def forward(self,img, labels):
    """Score images conditioned on their class labels.

    Args:
      img: float tensor of shape (B, 1, 256, 256).
      labels: integer class indices of shape (B,) or (B, 1).

    Returns:
      Tensor of shape (B, 1) holding sigmoid outputs.
    """
    # Flatten so (B,) and (B, 1) label tensors are treated identically.
    flat_labels = labels.reshape(-1)
    # Read the width from the layer itself rather than the mutable global, so
    # the module keeps working if `label_dims` is reassigned after construction.
    embed_dim = self.label_embedding.embedding_dim

    label_embed = self.label_embedding(flat_labels)                      # (B, embed_dim)
    label_embed = label_embed.view(flat_labels.size(0), embed_dim, 1, 1)
    # Broadcast the per-sample embedding over the image's spatial grid.
    label_embed = label_embed.expand(-1, -1, img.size(2), img.size(3))

    d_in = torch.cat((img, label_embed), dim=1)

    return self.layers(d_in)

In [None]:
# label_dims = 2

# discriminator = Discriminator()

# if cuda:
#   discriminator = discriminator.cuda()

# input_img = torch.rand(32,1,256,256).to(device='cuda')
# input_label = torch.randint(0,2,(32,)).to(device='cuda').unsqueeze(1)

# print(input_label.shape)

# output = discriminator(input_img,input_label)

# print(output.shape)

torch.Size([32, 1])
torch.Size([32, 3, 256, 256])
torch.Size([32, 1])


In [None]:
# Build both networks and the binary cross-entropy criterion used for every
# adversarial loss term.
generator, discriminator = Generator(), Discriminator()
adversial_loss = nn.BCELoss()

# Move everything to the GPU when one is available.
if cuda:
  generator, discriminator, adversial_loss = (
      module.cuda() for module in (generator, discriminator, adversial_loss)
  )

# Both optimizers share the same Adam hyper-parameters.
adam_settings = dict(lr=0.0002, betas=(0.5,0.999))
optimizer_G = optim.Adam(generator.parameters(), **adam_settings)
optimizer_D = optim.Adam(discriminator.parameters(), **adam_settings)

In [None]:
epochs = 1500
noise_channels = 510    # 510 noise + 2 label-embedding channels = 512 generator input channels
noise_size = 8          # an 8x8 noise map is upsampled 32x by the generator -> 256x256
sample_interval = 100   # save sample images every this many epochs
test_size = 20          # number of sample images saved each time

# Follow whatever device the models were placed on instead of hard-coding
# 'cuda', so the cell also runs on a CPU-only runtime.
device = next(generator.parameters()).device
to_pil = transforms.ToPILImage()


def _sample_generator_inputs(n_samples):
  """Draw uniform noise and random class labels (shape (n_samples, 1)) on `device`."""
  noise = torch.rand((n_samples, noise_channels, noise_size, noise_size),
                     dtype=torch.float, device=device)
  labels = torch.randint(0, label_dims, (n_samples,), device=device).unsqueeze(1)
  return noise, labels


for epoch in range(epochs):
  # Sampling at the end of an epoch switches to eval mode; switch back here.
  generator.train()
  discriminator.train()

  for real_images, real_labels in total_dataloader:
    # Use a local name: the last batch may be smaller, and the configured
    # global `batch_size` must not be overwritten.
    n_samples = real_images.shape[0]

    real_targets = torch.ones((n_samples, 1), dtype=torch.float, device=device)
    fake_targets = torch.zeros((n_samples, 1), dtype=torch.float, device=device)
    real_labels = real_labels.to(device).unsqueeze(1)
    real_images = real_images.to(device)

    # --- Discriminator step on real images (target: 1) ---
    optimizer_D.zero_grad()
    output_real = discriminator(real_images, real_labels)
    d_loss_real = adversial_loss(output_real, real_targets)
    d_loss_real.backward()
    optimizer_D.step()

    # --- Generator step ---
    noise, gen_labels = _sample_generator_inputs(n_samples)

    optimizer_G.zero_grad()
    fake_images = generator(noise, gen_labels)
    output_fake = discriminator(fake_images, gen_labels)
    # The generator must push the discriminator towards "real" (1) on its
    # fakes. Using fake_targets here trained the generator to be detected,
    # which made g_loss identical to d_loss_fake and both collapse to 0.
    g_loss = adversial_loss(output_fake, real_targets)
    g_loss.backward()
    optimizer_G.step()

    # --- Discriminator step on fake images (target: 0) ---
    # zero_grad also discards the discriminator gradients accumulated by
    # g_loss.backward(); detach() keeps this step from touching the generator.
    optimizer_D.zero_grad()
    output_fake = discriminator(fake_images.detach(), gen_labels)
    d_loss_fake = adversial_loss(output_fake, fake_targets)
    d_loss_fake.backward()
    optimizer_D.step()

  # Losses of the last batch of the epoch.
  print('epoch [{}/{}], d_loss_real:{:.6f}'.format(epoch+1, epochs, d_loss_real.item()),
        ', g_loss:{:.6f}'.format(g_loss.item()),
        ', d_loss_fake:{:.6f}'.format(d_loss_fake.item()))

  if epoch % sample_interval == 0:
    generator.eval()
    discriminator.eval()

    with torch.no_grad():
      noise, gen_labels = _sample_generator_inputs(test_size)
      fake_images = generator(noise, gen_labels)

    # The generator ends in Tanh, so pixels are in [-1, 1]; ToPILImage expects
    # float tensors in [0, 1], so rescale (and clamp) before converting.
    fake_images = ((fake_images + 1) / 2).clamp(0, 1).cpu()

    # One folder per sampled epoch: <output_path>_<epoch>_
    output_dir = output_path + f"_{epoch}_"
    os.makedirs(output_dir, exist_ok=True)

    for idx in range(fake_images.size(0)):
      img = to_pil(fake_images[idx])
      img.save(f"{output_dir}/image_{idx}_{gen_labels[idx,0].item()}.png")


epoch [1/1500], d_loss_real:0.514175 , g_loss:0.024285 , d_loss_fake:0.024285
epoch [2/1500], d_loss_real:0.183472 , g_loss:0.010834 , d_loss_fake:0.010834
epoch [3/1500], d_loss_real:0.058758 , g_loss:0.000845 , d_loss_fake:0.000845
epoch [4/1500], d_loss_real:0.019087 , g_loss:0.000593 , d_loss_fake:0.000593
epoch [5/1500], d_loss_real:0.013865 , g_loss:0.000594 , d_loss_fake:0.000594
epoch [6/1500], d_loss_real:0.006577 , g_loss:0.000020 , d_loss_fake:0.000020
epoch [7/1500], d_loss_real:0.007122 , g_loss:0.000001 , d_loss_fake:0.000001
epoch [8/1500], d_loss_real:0.001504 , g_loss:0.000184 , d_loss_fake:0.000184
epoch [9/1500], d_loss_real:0.002019 , g_loss:0.000292 , d_loss_fake:0.000292
epoch [10/1500], d_loss_real:0.002775 , g_loss:0.000019 , d_loss_fake:0.000019
epoch [11/1500], d_loss_real:0.004158 , g_loss:0.000254 , d_loss_fake:0.000254
epoch [12/1500], d_loss_real:0.003837 , g_loss:0.000131 , d_loss_fake:0.000131
epoch [13/1500], d_loss_real:0.001380 , g_loss:0.000093 , d_l

KeyboardInterrupt: 