<a href="https://colab.research.google.com/github/HimanshuPant7/Crack_Generation/blob/main/WGAN_Crack_Generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive so the training images under /content/drive/MyDrive are readable.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import random
import numpy as np
import matplotlib.pyplot as plt
from torchvision.utils import save_image

In [None]:
class CrackDataset(Dataset):
    """Unlabelled image dataset read from a single flat directory.

    Args:
        root_dir: Directory containing the image files.
        transform: Optional callable applied to each PIL image after the fixed
            256x256 bicubic resize.

    Only regular files whose (case-insensitive) extension is listed in
    ``IMG_EXTENSIONS`` are used. Sub-folders such as ``.ipynb_checkpoints`` and
    stray non-image files in ``root_dir`` therefore cannot make ``__getitem__``
    fail. File names are sorted so the index -> file mapping is deterministic.
    """

    IMG_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".gif", ".webp")

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = sorted(
            name
            for name in os.listdir(root_dir)
            if name.lower().endswith(self.IMG_EXTENSIONS)
            and os.path.isfile(os.path.join(root_dir, name))
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.images[idx])
        # convert() returns a new, fully loaded image, so the file can be closed here.
        with Image.open(img_name) as img:
            image = img.convert('RGB')
        # NOTE(review): this forces every image to 256x256 (ignoring aspect ratio)
        # before `transform` runs; the notebook's transform then resizes to 64x64 again.
        image = image.resize((256, 256), resample=Image.BICUBIC)
        if self.transform:
            image = self.transform(image)
        return image


DATA PREPROCESSING (resize, center crop, normalize to [-1, 1]; no random augmentation)

In [None]:
# Preprocessing (no random augmentation): shrink the shorter side to 64 px, keep the
# central 64x64 patch, convert to a float tensor in [0, 1], then map every channel
# to [-1, 1] via (x - 0.5) / 0.5 so real images share the generator's Tanh range.
# 64x64 is also the input size that Discriminator.fc1 (512 * 4 * 4 features) assumes.
channel_stats = (0.5, 0.5, 0.5)
pipeline_steps = [
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize(channel_stats, channel_stats),
]
transform = transforms.Compose(pipeline_steps)


In [None]:
class Generator(nn.Module):
    """DCGAN-style generator mapping a latent vector to a square image in [-1, 1].

    The network starts from a 512 x 4 x 4 feature map. It then stacks stride-2
    ``ConvTranspose2d`` blocks, each doubling height/width and halving the channel
    count, until a final transposed conv + Tanh reaches ``img_size``.

    The default ``img_size=64`` matches the 64x64 images produced by ``transform``
    and the 512 x 4 x 4 feature map that ``Discriminator.fc1`` expects. The former
    hard-coded stack had seven doubling layers and emitted 512x512 images.
    ``img_size=512`` rebuilds exactly that layer stack.

    Args:
        nz: Latent size; ``forward`` expects input of shape (N, nz, 1, 1).
        img_size: Side length of the output image; a power of two >= 8.
        nc: Number of output channels.

    Raises:
        ValueError: If ``img_size`` is not a power of two >= 8.
    """

    def __init__(self, nz=100, img_size=64, nc=3):
        super(Generator, self).__init__()
        if img_size < 8 or img_size & (img_size - 1):
            raise ValueError(f"img_size must be a power of two >= 8, got {img_size}")

        # input is Z (nz x 1 x 1) -> state size (512) x 4 x 4
        channels, size = 512, 4
        layers = [
            nn.ConvTranspose2d(nz, channels, 4, 1, 0, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(True),
        ]
        # Each block doubles H/W and halves the channels. Stop one doubling short
        # of img_size because the output layer performs the last one.
        while size * 2 < img_size:
            out_channels = max(channels // 2, 1)
            layers += [
                nn.ConvTranspose2d(channels, out_channels, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(True),
            ]
            channels, size = out_channels, size * 2
        # (channels) x img_size/2 x img_size/2 -> (nc) x img_size x img_size, squashed to [-1, 1]
        layers += [
            nn.ConvTranspose2d(channels, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Map latents (N, nz, 1, 1) to images (N, nc, img_size, img_size)."""
        return self.main(input)


In [None]:
class Discriminator(nn.Module):
    """WGAN critic for 3 x 64 x 64 images; returns one raw (unbounded) score per image.

    Four stride-2 convolutions (LeakyReLU 0.2, BatchNorm on layers 2-4) reduce a
    64 x 64 input to a 512 x 4 x 4 feature map. ``fc1`` maps it to a single score,
    with no sigmoid applied.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, 4, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.conv4 = nn.Conv2d(256, 512, 4, stride=2, padding=1)
        self.bn4 = nn.BatchNorm2d(512)
        self.fc1 = nn.Linear(512 * 4 * 4, 1)

    def forward(self, x):
        """Score a batch ``x`` of shape (N, 3, H, W); returns a tensor of shape (N, 1).

        Raises:
            ValueError: If the input does not reduce to a 512 x 4 x 4 feature map
                (a 64 x 64 input does). The previous ``view(-1, 512 * 4 * 4)``
                silently split larger images into several rows. A generator/data
                size mismatch therefore yielded wrong scores instead of an error.
        """
        in_h, in_w = x.shape[-2], x.shape[-1]
        x = nn.functional.leaky_relu(self.conv1(x), 0.2)
        x = nn.functional.leaky_relu(self.bn2(self.conv2(x)), 0.2)
        x = nn.functional.leaky_relu(self.bn3(self.conv3(x)), 0.2)
        x = nn.functional.leaky_relu(self.bn4(self.conv4(x)), 0.2)
        if tuple(x.shape[1:]) != (512, 4, 4):
            raise ValueError(
                f"Discriminator expects 64x64 inputs, got {in_h}x{in_w} "
                f"(feature map {tuple(x.shape[1:])}, expected (512, 4, 4))"
            )
        # Flatten per sample so the batch dimension is never mixed with spatial ones.
        x = x.flatten(1)
        x = self.fc1(x)
        return x



In [None]:
def weights_init(m):
    """Apply DCGAN-style initialisation to one module; use via ``model.apply(weights_init)``.

    - class name contains "Conv": weight ~ N(0, 0.02); bias left unchanged.
    - class name contains "BatchNorm": weight ~ N(1, 0.02), bias = 0.
    - class name contains "Linear": weight ~ N(0, 0.02), bias = 0.

    Missing or ``None`` parameters are skipped instead of raising
    AttributeError. This covers ``bias=False`` layers, ``affine=False``
    BatchNorm, and container modules whose class name merely contains "Conv".
    """
    classname = m.__class__.__name__
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)

    if "Conv" in classname:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif "BatchNorm" in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias, 0)
    elif "Linear" in classname:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias, 0)


In [None]:
# Build the generator G and the critic D, then overwrite their default weights
# with weights_init (nn.Module.apply calls it on every submodule recursively).
G, D = Generator(), Discriminator()
G.apply(weights_init)
D.apply(weights_init)  # last expression, so the notebook echoes D's module summary


Discriminator(
  (conv1): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (conv2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv4): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (bn4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc1): Linear(in_features=8192, out_features=1, bias=True)
)

In [None]:
# Same Adam configuration for both networks (DCGAN-style lr=2e-4, beta1=0.5).
# NOTE(review): the WGAN paper reports instability with momentum-based optimizers
# such as Adam under weight clipping and uses RMSProp (lr=5e-5); consider switching.
lr = 0.0002
beta1 = 0.5
beta2 = 0.999

G_optimizer = optim.Adam(params=G.parameters(), lr=lr, betas=[beta1, beta2])
D_optimizer = optim.Adam(params=D.parameters(), lr=lr, betas=[beta1, beta2])


In [None]:
# Prefer the GPU when CUDA is available, otherwise run on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

# nn.Module.to moves parameters and buffers in place (and returns the module).
G.to(device)
D.to(device)

Discriminator(
  (conv1): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (conv2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv4): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (bn4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc1): Linear(in_features=8192, out_features=1, bias=True)
)

In [None]:
# WGAN training: n_critic critic updates (with weight clipping) per generator update.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch_size = 64
num_epochs = 2000
clip_value = 0.01
n_critic = 5  # critic updates per generator update
nz = 100      # latent size; must match the Generator's input channels

dataset = CrackDataset("/content/drive/MyDrive/LiuLabels", transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

for epoch in range(num_epochs):
    for i, data in enumerate(dataloader):
        real_images = data.to(device)
        # The last batch may be smaller than batch_size. Use a separate name so
        # the batch_size setting above is not overwritten by the loop.
        cur_batch = real_images.size(0)

        # Train discriminator (critic)
        # NOTE(review): the same real batch is reused for all n_critic steps;
        # the WGAN algorithm samples a fresh real batch for each critic step.
        for _ in range(n_critic):
            D.zero_grad()
            z = torch.randn(cur_batch, nz, 1, 1, device=device)
            # G is not updated here, so don't build its autograd graph only to
            # discard it (same values as G(z).detach(), less memory).
            with torch.no_grad():
                fake_images = G(z)
            D_real = D(real_images)
            D_fake = D(fake_images)
            # Minimising E[D(fake)] - E[D(real)] pushes real scores above fake ones.
            D_loss = D_fake.mean() - D_real.mean()
            D_loss.backward()
            D_optimizer.step()

            # Weight clipping: keep every critic parameter in [-clip_value, clip_value].
            for p in D.parameters():
                p.data.clamp_(-clip_value, clip_value)

        # Train generator: raise the critic's score on generated images.
        G.zero_grad()
        z = torch.randn(cur_batch, nz, 1, 1, device=device)
        fake_images = G(z)
        D_fake = D(fake_images)
        G_loss = -D_fake.mean()
        G_loss.backward()
        G_optimizer.step()

        if i % 10 == 0:
            print(f"Epoch [{epoch}/{num_epochs}], Batch [{i}/{len(dataloader)}], D_loss: {D_loss.item():.4f}, G_loss: {G_loss.item():.4f}")

    # Save one generated sample per epoch into the current working directory.
    with torch.no_grad():
        # z = torch.randn(16, nz, 1, 1, device=device)
        z = torch.randn(1, nz, 1, 1, device=device)
        fake_images = G(z).cpu()
        save_image(fake_images, f"generated_images_{epoch+1}.png", nrow=4, normalize=True)

    # # Save models
    # torch.save(G.state_dict(), f"G_{epoch+1}.pt")
    # torch.save(D.state_dict(), f"D_{epoch+1}.pt")


Epoch [0/2000], Batch [0/4], D_loss: -0.0323, G_loss: 0.0011
Epoch [1/2000], Batch [0/4], D_loss: -0.1694, G_loss: -0.0007
Epoch [2/2000], Batch [0/4], D_loss: -0.3520, G_loss: 0.0060
Epoch [3/2000], Batch [0/4], D_loss: -0.4452, G_loss: 0.0169
Epoch [4/2000], Batch [0/4], D_loss: -0.5097, G_loss: 0.0298
Epoch [5/2000], Batch [0/4], D_loss: -0.5595, G_loss: 0.0433
Epoch [6/2000], Batch [0/4], D_loss: -0.6024, G_loss: 0.0572
Epoch [7/2000], Batch [0/4], D_loss: -0.6384, G_loss: 0.0709
Epoch [8/2000], Batch [0/4], D_loss: -0.6691, G_loss: 0.0846
Epoch [9/2000], Batch [0/4], D_loss: -0.6961, G_loss: 0.0986
Epoch [10/2000], Batch [0/4], D_loss: -0.7205, G_loss: 0.1122
Epoch [11/2000], Batch [0/4], D_loss: -0.7437, G_loss: 0.1262
Epoch [12/2000], Batch [0/4], D_loss: -0.7666, G_loss: 0.1413
Epoch [13/2000], Batch [0/4], D_loss: -0.7875, G_loss: 0.1574
Epoch [14/2000], Batch [0/4], D_loss: -0.8077, G_loss: 0.1748
Epoch [15/2000], Batch [0/4], D_loss: -0.8300, G_loss: 0.1923
Epoch [16/2000], 