In [None]:
from google.colab import drive

# Mount Google Drive so the dataset and checkpoint paths under
# /content/drive/My Drive/ used below are reachable. A failure is reported
# (with its reason) instead of raised, so the notebook can still be inspected.
try:
    drive.mount('/content/drive')
    print("Google Drive is mounted.")
except Exception as exc:  # not a bare except: keep KeyboardInterrupt/SystemExit working
    print(f"Google Drive is not mounted: {exc}")


In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from IPython.display import clear_output
import torchvision.datasets as dset
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter
import os

# Hyperparameters etc.
device = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE = 1e-4
BATCH_SIZE = 32
IMAGE_WIDTH = 512
IMAGE_HEIGHT = 512
CHANNELS_IMG = 1
Z_DIM = 100
NUM_EPOCHS = 1000
FEATURES_CRITIC = 64
FEATURES_GEN = 64
CRITIC_ITERATIONS = 3  # critic updates per generator update
LAMBDA_GP = 10         # weight of the gradient penalty in the critic loss

# Directory on the mounted Drive where checkpoints are read and written.
checkpoint_dir = "/content/drive/My Drive/GAN - Data Projekt/"

# Training images (dental radiographs). ImageFolder expects this directory to
# contain one or more class subdirectories; the class labels are ignored.
dataroot = "/content/drive/My Drive/GAN - Data Projekt/Billeder(Trimmet)/Billeder (gode)"

dataset = dset.ImageFolder(
    root=dataroot,
    transform=transforms.Compose([
        transforms.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)),
        transforms.Grayscale(num_output_channels=CHANNELS_IMG),  # RGB -> greyscale
        transforms.ToTensor(),                                    # values in [0, 1]
        # One-element mean/std tuples broadcast over every channel and map
        # [0, 1] -> [-1, 1], the range of the generator's Tanh output.
        transforms.Normalize((0.5,), (0.5,)),
    ]),
)

loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=(device == "cuda"),  # pinned memory only speeds up host->GPU copies
    persistent_workers=True,
)

device


'cuda'

In [None]:
# Hardware diagnostics for the current runtime. Lines starting with `!` are
# IPython shell escapes, not Python.
# Check GPU information (NOTE(review): presumably fails on a CPU-only runtime
# where nvidia-smi is not installed)
print("GPU Information:")
!nvidia-smi

# Check CPU information
print("\nCPU Information:")
!lscpu

In [None]:
def gradient_penalty(critic, real, fake, device="cpu"):
    """WGAN-GP gradient penalty.

    Draws one mixing coefficient per example, evaluates ``critic`` on the
    random interpolation between ``real`` and ``fake``, and returns the batch
    mean of ``(||d critic(x_hat) / d x_hat||_2 - 1) ** 2``.

    Args:
        critic: callable mapping a batch of images to scores (batch-first);
            every score is weighted by 1 when differentiating.
        real: image batch of shape (N, C, H, W).
        fake: tensor with the same shape as ``real``. It may be detached; no
            gradient is propagated back into whatever produced it.
        device: device on which the mixing coefficients are sampled (should
            match the device of ``real``/``fake``).

    Returns:
        Scalar tensor, differentiable w.r.t. the critic's parameters
        (``create_graph=True``), ready to be added to the critic loss.
    """
    batch_size = real.shape[0]
    # One alpha per example; broadcasting over (C, H, W) replaces an explicit
    # repeat and samples directly on the target device.
    alpha = torch.rand((batch_size, 1, 1, 1), device=device, dtype=real.dtype)
    # Build x_hat as a fresh leaf from detached inputs: autograd.grad needs it
    # to require grad even when neither real nor fake does, and the penalty's
    # backward pass must not reach the generator that produced `fake`.
    interpolated_images = (
        real.detach() * alpha + fake.detach() * (1 - alpha)
    ).requires_grad_(True)

    # Calculate critic scores
    mixed_scores = critic(interpolated_images)

    # Gradient of the scores w.r.t. the interpolated images; create_graph keeps
    # the penalty differentiable so it can train the critic.
    gradient = torch.autograd.grad(
        outputs=mixed_scores,
        inputs=interpolated_images,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient_norm = gradient.reshape(batch_size, -1).norm(2, dim=1)
    return torch.mean((gradient_norm - 1) ** 2)


def save_checkpoint(state, filename="celeba_wgan_gp.pth.tar"):
    """Announce the save, then serialize ``state`` to ``filename`` via ``torch.save``.

    In this notebook ``state`` is a dict holding the model/optimizer
    ``state_dict``s and the epoch index. The default filename is a leftover
    generic name; the training loop always passes an explicit path.
    """
    notice = "=> Saving checkpoint"
    print(notice)
    torch.save(obj=state, f=filename)





In [None]:

class Discriminator(nn.Module):
    """WGAN-GP critic for 512x512 images.

    Seven stride-2 4x4 convolutions halve the spatial size (512 -> 4) while
    the channel count grows from ``features_d`` to ``features_d * 64``. A
    final 4x4 convolution (no padding) collapses the 4x4 map to one unbounded
    score per image.

    Args:
        img_channels: number of channels in the input images.
        features_d: width of the first convolution; each later one doubles it.

    The layers are stored in ``self.disc`` in a fixed order, so ``state_dict``
    keys such as ``disc.2.weight`` match the original layout.
    """

    def __init__(self, img_channels, features_d):
        super(Discriminator, self).__init__()
        # 512x512 -> 256x256; the first layer has no normalization.
        stages = [
            nn.Conv2d(img_channels, features_d, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2),
        ]
        # 256 -> 128 -> 64 -> 32 -> 16 -> 8 -> 4, doubling channels each time.
        # Per-sample InstanceNorm instead of BatchNorm (the WGAN-GP paper
        # advises against batch normalization in the critic).
        for power in range(6):
            width_in = features_d * 2 ** power
            width_out = features_d * 2 ** (power + 1)
            stages.extend([
                nn.Conv2d(width_in, width_out, kernel_size=4, stride=2, padding=1, bias=False),
                nn.InstanceNorm2d(width_out, affine=True),
                nn.LeakyReLU(0.2),
            ])
        # 4x4 -> 1x1: raw score, no sigmoid.
        stages.append(
            nn.Conv2d(features_d * 64, 1, kernel_size=4, stride=2, padding=0, bias=False)
        )
        self.disc = nn.Sequential(*stages)

    def forward(self, x):
        """Return scores of shape (N, 1, 1, 1) for an (N, img_channels, 512, 512) batch."""
        return self.disc(x)


class Generator(nn.Module):
    """DCGAN-style generator for 512x512 images.

    Maps noise of shape (N, channels_noise, 1, 1) to images of shape
    (N, channels_img, 512, 512) with values in [-1, 1] (Tanh output).

    Args:
        channels_noise: size of the latent vector (input channels).
        channels_img: number of channels in the generated images.
        features_g: base width; the first block has ``features_g * 64``
            channels and each upsampling block halves it.

    The layers are stored in ``self.net`` in a fixed order, so ``state_dict``
    keys such as ``net.3.weight`` match the original layout.
    """

    def __init__(self, channels_noise, channels_img, features_g):
        super(Generator, self).__init__()
        # Project the 1x1 noise to a 4x4 map.
        blocks = [
            nn.ConvTranspose2d(channels_noise, features_g * 64, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(features_g * 64),
            nn.ReLU(),
        ]
        # 4 -> 8 -> 16 -> 32 -> 64 -> 128 -> 256, halving channels each time.
        for power in range(6, 0, -1):
            width_in = features_g * 2 ** power
            width_out = features_g * 2 ** (power - 1)
            blocks.extend([
                nn.ConvTranspose2d(width_in, width_out, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(width_out),
                nn.ReLU(),
            ])
        # 256 -> 512 and squash to [-1, 1].
        blocks.extend([
            nn.ConvTranspose2d(features_g, channels_img, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        ])
        self.net = nn.Sequential(*blocks)

    def forward(self, x):
        """Generate images from an (N, channels_noise, 1, 1) noise batch."""
        return self.net(x)



In [None]:
def initialize_weights(model):
    """Re-sample weights in place following the DCGAN paper's N(0, 0.02) rule.

    Every Conv2d, ConvTranspose2d and BatchNorm2d reachable through
    ``model.modules()`` (``model`` itself included) gets its ``weight`` drawn
    from a normal distribution with mean 0.0 and std 0.02. Biases and all other
    module types (e.g. InstanceNorm2d, Linear) keep their current values.

    NOTE(review): applying the rule literally also puts BatchNorm scales near 0;
    the PyTorch DCGAN tutorial uses mean 1.0 for those instead — confirm which
    is intended.
    """
    initialized_types = (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)
    selected = (layer for layer in model.modules() if isinstance(layer, initialized_types))
    for layer in selected:
        nn.init.normal_(layer.weight.data, mean=0.0, std=0.02)




In [None]:
from datetime import datetime

# Initialize the generator and critic on the target device with DCGAN-style weights.
gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
critic = Discriminator(CHANNELS_IMG, FEATURES_CRITIC).to(device)
initialize_weights(gen)
initialize_weights(critic)

# Initialize the optimizers (Adam with beta1 = 0.0, beta2 = 0.9 for both networks).
opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))
opt_critic = optim.Adam(critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))

# Initialize TensorBoard writers. Without an explicit log_dir, two writers
# created in the same second both default to the same timestamped directory
# under runs/, so the real and fake streams collapse into a single run. Give
# each its own subdirectory of one run; `--logdir=runs` still finds both.
run_dir = os.path.join("runs", datetime.now().strftime("%Y%m%d-%H%M%S"))
writer_real = SummaryWriter(log_dir=os.path.join(run_dir, "real"))
writer_fake = SummaryWriter(log_dir=os.path.join(run_dir, "fake"))

In [None]:
# (Re)load the TensorBoard notebook extension and embed a TensorBoard view of
# everything under runs/ (SummaryWriter's default root directory).
%reload_ext tensorboard
%tensorboard --logdir=runs

In [None]:
def load_checkpoint(filename, model, optimizer, model_key, optimizer_key, map_location=None):
    """Restore one model/optimizer pair from a checkpoint dict.

    Args:
        filename: path to a file written by ``torch.save`` whose payload is a
            dict containing ``model_key``, ``optimizer_key`` and ``'epoch'``.
        model: module whose parameters/buffers are replaced in place.
        optimizer: optimizer whose state is replaced in place.
        model_key: dict key holding the model ``state_dict``.
        optimizer_key: dict key holding the optimizer ``state_dict``.
        map_location: forwarded to ``torch.load``. When None (default), tensors
            are mapped onto the device of ``model``'s first parameter (CPU if
            it has none), so a checkpoint written on a GPU runtime can still be
            restored on a CPU runtime.

    Returns:
        The ``'epoch'`` value stored in the checkpoint.

    Raises:
        KeyError: if any of the three keys is missing from the checkpoint.
    """
    print("=> Loading checkpoint")
    if map_location is None:
        first_param = next(model.parameters(), None)
        map_location = first_param.device if first_param is not None else "cpu"
    checkpoint = torch.load(filename, map_location=map_location)
    model.load_state_dict(checkpoint[model_key])
    optimizer.load_state_dict(checkpoint[optimizer_key])
    return checkpoint['epoch']

In [None]:
# Resume from a saved checkpoint if it exists; otherwise train from scratch.
RESUME_FROM_EPOCH = 400  # 1-based epoch number that appears in the checkpoint filename
start_epoch = 0          # fresh models start at epoch 0, not at RESUME_FROM_EPOCH
checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{RESUME_FROM_EPOCH}_512x512.pth")
if os.path.exists(checkpoint_path):
    # The saved 'epoch' is the 0-based index of the last finished epoch.
    last_epoch = load_checkpoint(checkpoint_path, gen, opt_gen, 'gen_state_dict', 'gen_optimizer')
    load_checkpoint(checkpoint_path, critic, opt_critic, 'critic_state_dict', 'critic_optimizer')
    start_epoch = last_epoch + 1  # To continue from the next epoch
else:
    print(f"No checkpoint found at {checkpoint_path}; starting from epoch 0.")

In [None]:
gen.train()
critic.train()


# Training loop. start_epoch comes from the resume cell (0 for a fresh run), so
# epoch numbers, TensorBoard steps and checkpoint filenames continue after a
# resume instead of restarting at 0 and overwriting earlier checkpoint files.
for epoch in range(start_epoch, NUM_EPOCHS):
    loop = tqdm(loader, leave=True)
    for batch_idx, (real, _) in enumerate(loop):
        real = real.to(device)
        cur_batch_size = real.shape[0]
        global_step = epoch * len(loader) + batch_idx

        # Generate fake images
        noise = torch.randn(cur_batch_size, Z_DIM, 1, 1).to(device)
        fake = gen(noise)

        # Train Critic. The critic only needs the generated pixels, so it scores
        # a detached copy: its backward passes stop at the images instead of
        # running through the generator on every iteration, and the generator
        # graph no longer has to be retained between them.
        fake_detached = fake.detach()
        # gradient_penalty differentiates w.r.t. the interpolated images, so one
        # input must require grad; a separate detached leaf provides that
        # without linking the penalty back to the generator.
        fake_for_gp = fake.detach().requires_grad_(True)
        for _ in range(CRITIC_ITERATIONS):
            critic_real = critic(real).reshape(-1)
            critic_fake = critic(fake_detached).reshape(-1)
            gp = gradient_penalty(critic, real, fake_for_gp, device=device)
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake)) + LAMBDA_GP * gp
            )
            critic.zero_grad()
            loss_critic.backward()
            opt_critic.step()

        # Train Generator through the still-intact graph of `fake`.
        gen_fake = critic(fake).reshape(-1)
        loss_gen = -torch.mean(gen_fake)
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Log the losses and image grids to TensorBoard every 10 batches
        if batch_idx % 10 == 0:
            writer_real.add_scalar("Critic Loss", loss_critic.item(), global_step=global_step)
            writer_fake.add_scalar("Generator Loss", loss_gen.item(), global_step=global_step)

            with torch.no_grad():
                # Re-render the same noise with the just-updated generator.
                fake_preview = gen(noise).detach().cpu()
                img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
                img_grid_fake = torchvision.utils.make_grid(fake_preview[:32], normalize=True)

                writer_real.add_image("Real Images", img_grid_real, global_step=global_step)
                writer_fake.add_image("Fake Images", img_grid_fake, global_step=global_step)

        # Update the progress bar
        loop.set_postfix(loss_critic=loss_critic.item(), loss_gen=loss_gen.item())

    print(f"Completed Epoch {epoch+1}/{NUM_EPOCHS}.")  # Confirms the completion of the current epoch

    # Save a checkpoint every 50 epochs. 'epoch' stores the 0-based index;
    # the filename uses the 1-based epoch number.
    if (epoch + 1) % 50 == 0:
        checkpoint = {
            "epoch": epoch,
            "gen_state_dict": gen.state_dict(),
            "critic_state_dict": critic.state_dict(),
            "gen_optimizer": opt_gen.state_dict(),
            "critic_optimizer": opt_critic.state_dict(),
        }
        save_checkpoint(checkpoint, filename=os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch+1}_512x512.pth"))


# Close the TensorBoard writers
writer_real.close()
writer_fake.close()