In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader
from torchvision import transforms, datasets, models
import torchvision

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.tensorboard import SummaryWriter

import os
from PIL import Image
from IPython.display import display

import warnings
warnings.filterwarnings('ignore')

import torch.optim as optim

In [2]:
# Preprocessing pipeline: grayscale -> 128x128 -> tensor in [0, 1] ->
# Normalize((0.5,), (0.5,)), which maps pixels to [-1, 1] (the range of the
# generator's tanh output).
# NOTE: the pipeline is bound to the name `transforms` (the ImageFolder cell
# reads it), which shadows the `torchvision.transforms` module imported at the
# top. Referencing the module through `torchvision.transforms` keeps this cell
# re-runnable; `transforms.Compose(...)` failed with AttributeError on a second run.
transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.Grayscale(num_output_channels=1),
        torchvision.transforms.Resize((128, 128)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ]
)

In [3]:
# Seed torch's global RNG (used below for weight init, noise and DataLoader
# shuffling) so runs are reproducible.
torch.manual_seed(42)

# Dataset location. Override with the BIO_MATERIALS_ROOT environment variable
# rather than editing this machine-specific default path.
root = os.environ.get(
    'BIO_MATERIALS_ROOT',
    '/Users/karthik/GANS/Resized-224X224-Bio_Materials/Rezied-224X224-Bio-Material-DATA/',
)
print(os.path.exists(root))

# Fail fast: the next cell creates directories under `root` and moves files,
# which would otherwise silently build an empty tree at a mistyped path.
if not os.path.isdir(root):
    raise FileNotFoundError(
        f"Dataset root {root!r} does not exist; set BIO_MATERIALS_ROOT to the image directory."
    )

True


In [4]:
import shutil

# ImageFolder expects root/<class_name>/<image>. The raw images sit directly in
# `root`, so move them into a single pseudo-class directory first. Only files at
# the top level of `root` are moved, so re-running this cell is a no-op once
# they are in place.
class_dir = os.path.join(root, 'bio_materials')
os.makedirs(class_dir, exist_ok=True)

# Match every extension ImageFolder itself accepts, case-insensitively (a plain
# `.endswith('.png')` skipped e.g. `.PNG`, and ImageFolder ignores files left in
# `root`). Collect paths first so the directory is not mutated while the
# scandir iterator, closed by the `with`, is still walking it.
image_extensions = tuple(datasets.folder.IMG_EXTENSIONS)
with os.scandir(root) as entries:
    to_move = [
        entry.path
        for entry in entries
        if entry.is_file() and entry.name.lower().endswith(image_extensions)
    ]
for path in to_move:
    shutil.move(path, class_dir)

train_data = datasets.ImageFolder(root, transform=transforms)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)

In [5]:
# Training hyperparameters: number of passes over the data, length of the
# latent noise vector, and image channel count (1 = grayscale).
num_epochs, z_dim, channels_img = 200, 112, 1

In [6]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [7]:
class Generator(nn.Module):
    """Map a latent vector to an image with values in [-1, 1].

    A three-layer MLP expands the latent code into a ``base_channels`` x
    ``init_size`` x ``init_size`` feature map. Four stride-2 transposed
    convolutions then double the spatial size each time. With the defaults
    this is 8 -> 16 -> 32 -> 64 -> 128, giving ``channels_img`` x 128 x 128.

    Args:
        z_dim: Number of latent elements per sample. ``forward`` flattens
            everything after the batch dimension, so both (N, z_dim) and
            (N, z_dim, 1, 1) inputs work.
        channels_img: Number of output image channels.
        base_channels: Channels of the initial feature map; the first three
            upsampling layers halve it (2048 -> 1024 -> 512 -> 256 by default).
            Must be a positive multiple of 8.
        init_size: Height/width of the initial feature map; the output is
            ``16 * init_size`` pixels square.
    """

    def __init__(self, z_dim, channels_img, base_channels=2048, init_size=8):
        super().__init__()
        if base_channels < 8 or base_channels % 8:
            raise ValueError(f"base_channels must be a positive multiple of 8, got {base_channels}")
        # Stored so forward() reshapes with exactly the geometry fc3 was built for.
        self.base_channels = base_channels
        self.init_size = init_size

        # MLP stage with gradually increasing width.
        self.fc1 = nn.Linear(z_dim, 512)
        self.fc2 = nn.Linear(512, 1024)
        self.fc3 = nn.Linear(1024, base_channels * init_size * init_size)

        # Upsampling stage: kernel 4, stride 2, padding 1 exactly doubles H and W.
        c = base_channels
        self.conv1 = nn.ConvTranspose2d(c, c // 2, kernel_size=4, stride=2, padding=1)       # 16x16 by default
        self.conv2 = nn.ConvTranspose2d(c // 2, c // 4, kernel_size=4, stride=2, padding=1)  # 32x32
        self.conv3 = nn.ConvTranspose2d(c // 4, c // 8, kernel_size=4, stride=2, padding=1)  # 64x64
        self.conv4 = nn.ConvTranspose2d(c // 8, channels_img, kernel_size=4, stride=2, padding=1)  # 128x128

    def forward(self, x):
        """Generate images from a batch of latent vectors.

        Args:
            x: Tensor whose first dim is the batch; the remaining dims must
                hold ``z_dim`` elements per sample.

        Returns:
            Tensor of shape (N, channels_img, 16 * init_size, 16 * init_size)
            with values in [-1, 1] (tanh).
        """
        batch = x.size(0)
        x = x.reshape(batch, -1)  # reshape (not view) also accepts non-contiguous noise
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = x.view(batch, self.base_channels, self.init_size, self.init_size)
        x = F.leaky_relu(self.conv1(x), 0.2)
        x = F.leaky_relu(self.conv2(x), 0.2)
        x = F.leaky_relu(self.conv3(x), 0.2)
        return torch.tanh(self.conv4(x))

In [8]:
class Discriminator(nn.Module):
    """Score images as real (output near 1) or generated (output near 0).

    Four stride-2 convolutions (64 -> 128 -> 256 -> 512 channels) halve the
    spatial size each time. A three-layer MLP with dropout then maps the
    flattened features to one sigmoid probability per image.

    Args:
        channels_img: Number of input image channels.
        img_size: Expected input height and width (must be >= 16). The
            default of 128 reproduces the original ``512 * 8 * 8`` flatten size.
    """

    def __init__(self, channels_img, img_size=128):
        super().__init__()
        feat_size = img_size // 16  # spatial size after four halvings
        if feat_size < 1:
            raise ValueError(f"img_size must be at least 16, got {img_size}")
        self.img_size = img_size

        # CNN stage: kernel 4, stride 2, padding 1 maps H -> floor(H / 2).
        # Sizes in the comments are for the default 128x128 input.
        self.conv1 = nn.Conv2d(channels_img, 64, kernel_size=4, stride=2, padding=1)  # 64x64
        self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)           # 32x32
        self.conv3 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)          # 16x16
        self.conv4 = nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1)          # 8x8

        # Number of features per sample after flattening conv4's output.
        self.flatten_size = 512 * feat_size * feat_size

        # MLP stage with LeakyReLU and dropout.
        self.fc1 = nn.Linear(self.flatten_size, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 1)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        """Return D(x) with shape (N, 1) and values in (0, 1).

        Args:
            x: Image batch of shape (N, channels_img, img_size, img_size).

        Raises:
            ValueError: if the spatial size of ``x`` is not img_size x img_size.
        """
        if tuple(x.shape[-2:]) != (self.img_size, self.img_size):
            raise ValueError(
                f"Discriminator expects {self.img_size}x{self.img_size} inputs, "
                f"got {tuple(x.shape[-2:])}"
            )
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = F.leaky_relu(conv(x), 0.2)
        # Flatten per sample with the batch dim explicit; view(-1, flatten_size)
        # silently changed the batch size when the input had the wrong size.
        x = x.reshape(x.size(0), -1)
        x = self.dropout(F.leaky_relu(self.fc1(x), 0.2))
        x = self.dropout(F.leaky_relu(self.fc2(x), 0.2))
        return torch.sigmoid(self.fc3(x))

In [9]:
# Build both networks from the config constants and move them to `device`
# before the optimizers are created. The training loop puts real images and
# noise on `device`, so parameters left on the CPU would fail with a device
# mismatch whenever CUDA is available.
generator = Generator(z_dim, channels_img=channels_img).to(device)
discriminator = Discriminator(channels_img=channels_img).to(device)

In [10]:
# Binary cross-entropy on the discriminator's sigmoid output, used for both
# the discriminator and the generator objectives.
criterion = nn.BCELoss()

# Both networks use Adam with lr 2e-4 and beta1 = 0.5.
adam_settings = {"lr": 2e-4, "betas": (0.5, 0.999)}
opt_gen = torch.optim.Adam(generator.parameters(), **adam_settings)
opt_disc = torch.optim.Adam(discriminator.parameters(), **adam_settings)


In [11]:
# Peek at one batch to sanity-check tensor shapes before training.
# next(iter(...)) raises StopIteration right here if the loader is empty,
# instead of leaving `images` undefined for a later cell to trip over.
# Creating the iterator draws from torch's global RNG (shuffle seed), as the
# original for/break did.
images, labels = next(iter(train_loader))

In [12]:
# Display the batch's torch.Size: (batch, channels, height, width).
images.size()

torch.Size([64, 1, 128, 128])

In [13]:
# Separate TensorBoard runs for real and generated samples. SummaryWriter is
# already imported in the first cell. `step` is the global step shared by both
# writers and advanced by the training loop.
writer_real = SummaryWriter("logs/real")
writer_fake = SummaryWriter("logs/fake")
step = 0

In [None]:
# Fixed latent batch so the logged generator samples are comparable across steps.
fixed_noise = torch.randn(64, z_dim, 1, 1, device=device)

for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(train_loader):
        real = real.to(device)
        batch_size = real.size(0)

        # Targets sized to the actual batch (the last batch may be smaller).
        real_target = torch.ones(batch_size, 1, device=device)
        fake_target = torch.zeros(batch_size, 1, device=device)

        noise = torch.randn(batch_size, z_dim, 1, 1, device=device)
        fake = generator(noise)

        # Discriminator step: real -> 1, fake -> 0. `fake` is detached so this
        # backward pass does not reach the generator.
        discriminator.zero_grad()
        real_loss = criterion(discriminator(real), real_target)
        fake_loss = criterion(discriminator(fake.detach()), fake_target)
        disc_loss = real_loss + fake_loss
        disc_loss.backward()
        opt_disc.step()

        # Generator step: push D(G(z)) toward the "real" label. This backward
        # also writes gradients into the discriminator's parameters; only
        # opt_gen steps here, and discriminator.zero_grad() clears them next iteration.
        generator.zero_grad()
        gen_loss = criterion(discriminator(fake), real_target)
        gen_loss.backward()
        opt_gen.step()

        # Console + TensorBoard logging.
        if batch_idx % 100 == 0:
            print(
                f"Epoch [{epoch + 1}/{num_epochs}] Batch {batch_idx}/{len(train_loader)} "
                f"Loss D: {disc_loss.item():.4f}, loss G: {gen_loss.item():.4f}"
            )

            with torch.no_grad():
                samples = generator(fixed_noise)
                img_grid_real = torchvision.utils.make_grid(real[:16], normalize=True)
                img_grid_fake = torchvision.utils.make_grid(samples[:16], normalize=True)

            writer_real.add_image("Real Images", img_grid_real, global_step=step)
            writer_fake.add_image("Fake Images", img_grid_fake, global_step=step)
            step += 1

# Make sure the last logged images reach disk without waiting for the writers'
# periodic flush.
writer_real.flush()
writer_fake.flush()


Epoch [0/200] Batch 0/34                    Loss D: 1.3854, loss G: 0.6942
Epoch [1/200] Batch 0/34                    Loss D: 0.6953, loss G: 5.4329
