In [None]:
from google.colab import drive

# Make Google Drive available under /content/drive; the dataset and the output folder
# used below both live under "/content/drive/My Drive".
drive.mount(mountpoint='/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [9]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm  # For progress tracking

# Prefer the GPU when CUDA is available, otherwise run everything on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Preprocessing applied to every training image:
#   resize to 224x224 -> float tensor -> per-channel (x - 0.5) / 0.5.
# The resulting [-1, 1] range matches the generator's final Tanh.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Lower-case extensions of files treated as images; every other folder entry is skipped.
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tif', '.tiff', '.webp')


# Load images from a folder
def load_images_from_folder(folder, transform, limit=None):
    """Load the image files in ``folder`` and stack them into a single tensor.

    Files are visited in sorted order so the resulting tensor is reproducible
    (``os.listdir`` order is arbitrary). Sub-directories and files whose extension
    is not in ``_IMAGE_EXTENSIONS`` are skipped. Every image is converted to RGB,
    so grayscale, RGBA and palette files all yield 3-channel images.

    Args:
        folder: directory containing the image files.
        transform: callable applied to each RGB PIL image. It must return
            equally-shaped tensors so they can be stacked. If falsy, the PIL
            images are appended as-is, and ``torch.stack`` will then fail.
        limit: maximum number of images to load. ``None`` (or 0) loads all of them.

    Returns:
        Tensor of shape ``(N, *transform_output_shape)``.

    Raises:
        RuntimeError: if the folder contains no loadable image files.
    """
    images = []
    for filename in sorted(os.listdir(folder)):
        # `limit` counts loaded images, not directory entries.
        if limit and len(images) >= limit:
            break
        img_path = os.path.join(folder, filename)
        # Skip sub-directories and non-image files (e.g. .DS_Store, desktop.ini).
        if not os.path.isfile(img_path) or not filename.lower().endswith(_IMAGE_EXTENSIONS):
            continue
        # Close the file handle deterministically. convert('RGB') loads the pixels and
        # forces 3 channels, which the 3-channel Normalize in the notebook's transform needs.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        if transform:
            img = transform(img)
        images.append(img)
    if not images:
        raise RuntimeError(f"No image files found in {folder!r}")
    return torch.stack(images)

# Seed torch's RNGs (CPU and all CUDA devices) so data shuffling, the weight
# initialisation in later cells and noise sampling repeat across runs.
# cuDNN kernels may still be nondeterministic on GPU.
seed = 42
torch.manual_seed(seed)

# Hyperparameters
batch_size = 32
latent_dim = 256
lr = 0.0002
beta1 = 0.5
beta2 = 0.999
num_epochs = 1500
save_interval = 50  # Generate and display images every 50 epochs
save_folder = '/content/drive/My Drive/Dataset/generate 1'

# Folder containing real images
image_folder = '/content/drive/My Drive/PRIE/train/Vascular lesion/vascular lesion 4'
real_images = load_images_from_folder(image_folder, transform)

# DataLoader over the in-memory image tensor; shuffle=True reshuffles every epoch
dataloader = torch.utils.data.DataLoader(real_images, batch_size=batch_size, shuffle=True)

# Define the generator: latent vector -> 1024x7x7 feature map -> five stride-2
# ConvTranspose2d upsamplings (7 -> 224) -> 3-channel image in [-1, 1].
class Generator(nn.Module):
    """Map latent noise vectors to 3x224x224 images.

    Args:
        latent_dim: length of the input noise vector.
    """

    def __init__(self, latent_dim):
        super().__init__()
        channels = [1024, 512, 256, 128, 64, 32]

        # Project the noise and reshape it into a small 7x7 feature map.
        layers = [
            nn.Linear(latent_dim, channels[0] * 7 * 7),
            nn.ReLU(),
            nn.Unflatten(1, (channels[0], 7, 7)),
        ]
        # Each stage doubles the spatial size: 7 -> 14 -> 28 -> 56 -> 112 -> 224.
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            layers += [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]
        # A 3x3 conv keeps 224x224 and maps to RGB; Tanh bounds the values to [-1, 1],
        # the same range the training images are normalized to.
        layers += [nn.Conv2d(channels[-1], 3, kernel_size=3, padding=1), nn.Tanh()]

        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Turn noise ``z`` of shape (N, latent_dim) into images of shape (N, 3, 224, 224)."""
        return self.model(z)


# Define the discriminator: five stride-2 convolutions (224 -> 7) followed by a
# linear layer and a sigmoid that scores each image.
class Discriminator(nn.Module):
    """Score 3x224x224 images with the probability that they are real."""

    def __init__(self):
        super().__init__()
        widths = [3, 64, 128, 256, 512, 1024]

        layers = []
        # Each stage halves the spatial size: 224 -> 112 -> 56 -> 28 -> 14 -> 7.
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            # The first stage has no BatchNorm; all later stages do.
            if stage > 0:
                layers.append(nn.BatchNorm2d(c_out))
            layers.extend([nn.LeakyReLU(0.2), nn.Dropout(0.25)])

        self.model = nn.Sequential(*layers)
        self.fc = nn.Linear(widths[-1] * 7 * 7, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, img):
        """Return an (N, 1) tensor of probabilities that each image in ``img`` is real."""
        features = self.model(img).flatten(start_dim=1)
        return self.sigmoid(self.fc(features))



# Define the generator and discriminator on the selected device
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)

# Loss function: binary cross-entropy on the discriminator's sigmoid outputs
adversarial_loss = nn.BCELoss()

# Optimizers: one Adam per network, sharing lr and (beta1, beta2)
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, beta2))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, beta2))

# Ensure the save folder exists. exist_ok=True replaces the exists()/makedirs() pair,
# which raised if the folder appeared between the check and the creation.
os.makedirs(save_folder, exist_ok=True)


# Real-time image generation function
def generate_image(generator, latent_dim, device):
    """Sample one image from ``generator``, display it, and return it.

    The generator is switched to eval mode while sampling, so BatchNorm layers use
    their running statistics. Without this, a single sample would be normalized
    with its own statistics and the running statistics would be updated as a side
    effect. The generator's previous train/eval mode is restored afterwards.

    Args:
        generator: model mapping a (1, latent_dim) noise tensor to an image batch.
            It is assumed to output values in [-1, 1] (Tanh), like ``Generator``.
        latent_dim: length of the noise vector.
        device: device on which to sample the noise; it should match the generator's.

    Returns:
        CPU tensor of shape (1, C, H, W) with values mapped from [-1, 1] to [0, 1].
    """
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            # Sample random noise for the generator
            z = torch.randn(1, latent_dim, device=device)
            generated_image = generator(z).cpu()
    finally:
        generator.train(was_training)

    # Denormalize the image from [-1, 1] to [0, 1]
    generated_image = (generated_image + 1) / 2.0

    # Display the image as HWC, which is the layout matplotlib expects
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(np.transpose(generated_image[0].numpy(), (1, 2, 0)))
    ax.axis('off')
    plt.show()
    plt.close(fig)

    return generated_image


# Training loop
for epoch in tqdm(range(num_epochs), desc="Training Epochs"):
    # Other cells may leave the networks in eval mode after sampling. Training needs
    # BatchNorm batch statistics and active Dropout.
    generator.train()
    discriminator.train()

    # Use `real_batch`, not `real_images`, so the loop does not overwrite the dataset tensor.
    for real_batch in dataloader:
        real_batch = real_batch.to(device)
        n_real = real_batch.size(0)  # the last batch of an epoch may be smaller

        # Adversarial ground truths: 1 = real, 0 = fake
        valid = torch.ones(n_real, 1, device=device)
        fake = torch.zeros(n_real, 1, device=device)

        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()

        # Sample noise as generator input
        z = torch.randn(n_real, latent_dim, device=device)
        fake_images = generator(z)

        # Average the real and fake BCE terms. The fakes are detached so this
        # update does not backpropagate into the generator.
        real_loss = adversarial_loss(discriminator(real_batch), valid)
        fake_loss = adversarial_loss(discriminator(fake_images.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # -----------------
        #  Train Generator
        # -----------------
        optimizer_G.zero_grad()

        # The generator loss is low when the (just updated) discriminator labels the fakes as real
        g_loss = adversarial_loss(discriminator(fake_images), valid)
        g_loss.backward()
        optimizer_G.step()

    # Print the last batch's losses every 10 epochs for monitoring
    if (epoch + 1) % 10 == 0:
        print(f"Epoch [{epoch+1}/{num_epochs}] - Generator Loss: {g_loss.item():.4f}, Discriminator Loss: {d_loss.item():.4f}")

    # Generate and display a 4x4 grid of samples every `save_interval` epochs
    if (epoch + 1) % save_interval == 0:
        # Sample in eval mode so the preview uses the BatchNorm running statistics
        # and does not update them. Training mode is restored right after.
        generator.eval()
        with torch.no_grad():
            z = torch.randn(16, latent_dim, device=device)
            generated_images = generator(z).cpu()
        generator.train()

        # Denormalize images from [-1, 1] to [0, 1]
        generated_images = (generated_images + 1) / 2.0

        # Display images
        grid_img = torchvision.utils.make_grid(generated_images, nrow=4)
        fig = plt.figure(figsize=(8, 8))
        plt.imshow(np.transpose(grid_img.numpy(), (1, 2, 0)))
        plt.axis('off')
        plt.show()
        plt.close(fig)  # release the figure so preview figures do not accumulate

# Save the final generated images after training completes
num_generated_images = 1000  # You can specify how many images you want to generate
generation_batch_size = 16

# Generate in eval mode, so BatchNorm uses its running statistics and no state is
# updated. Restore the previous mode afterwards so a later re-run of the training
# cell is unaffected.
was_training = generator.training
generator.eval()
try:
    with torch.no_grad():
        # The last batch is smaller when num_generated_images is not a multiple of the
        # batch size, so exactly num_generated_images files are written
        # (looping over 1000 // 16 batches alone would stop at 992).
        for start in range(0, num_generated_images, generation_batch_size):
            n_batch = min(generation_batch_size, num_generated_images - start)
            z = torch.randn(n_batch, latent_dim, device=device)
            # Map the Tanh output from [-1, 1] to [0, 1] with the same fixed mapping used
            # for display, rather than save_image's per-image min-max rescaling.
            generated_images = (generator(z).cpu() + 1) / 2.0

            for offset, image in enumerate(generated_images):
                img_name = os.path.join(save_folder, f"generated_{start + offset + 1}.png")
                torchvision.utils.save_image(image, img_name)
finally:
    generator.train(was_training)

print(f"All {num_generated_images} images saved to {save_folder}.")

# Call this function to generate an image in real-time. The result is assigned so
# the cell does not also print the returned tensor's repr.
final_sample = generate_image(generator, latent_dim, device)


Output hidden; open in https://colab.research.google.com to view.