<a href="https://colab.research.google.com/github/Hailemicael/Covid-GAN/blob/master/covid_gan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
!unzip "/content/images_1.zip"

Archive:  /content/images_1.zip
  inflating: images/COVID-1.png      
  inflating: images/COVID-10.png     
  inflating: images/COVID-100.png    
  inflating: images/COVID-1000.png   
  inflating: images/COVID-1001.png   
  inflating: images/COVID-1002.png   
  inflating: images/COVID-1003.png   
  inflating: images/COVID-1004.png   
  inflating: images/COVID-1005.png   
  inflating: images/COVID-1006.png   
  inflating: images/COVID-1007.png   
  inflating: images/COVID-1008.png   
  inflating: images/COVID-1009.png   
  inflating: images/COVID-101.png    
  inflating: images/COVID-1010.png   
  inflating: images/COVID-1011.png   
  inflating: images/COVID-1012.png   
  inflating: images/COVID-1013.png   
  inflating: images/COVID-1014.png   
  inflating: images/COVID-1015.png   
  inflating: images/COVID-1016.png   
  inflating: images/COVID-1017.png   
  inflating: images/COVID-1018.png   
  inflating: images/COVID-1019.png   
  inflating: images/COVID-102.png    
  inflating: image

### **Importing Libraries**

In [6]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image, make_grid
from PIL import Image
from torch.utils.tensorboard import SummaryWriter

In [7]:
# Report whether CUDA is usable and, if so, the name of device 0.
_gpu_ok = torch.cuda.is_available()
print("GPU Available:", _gpu_ok)
_gpu_name = torch.cuda.get_device_name(0) if _gpu_ok else "No GPU"
print("GPU Device Name:", _gpu_name)

GPU Available: True
GPU Device Name: Tesla T4


### **Dataset Loading**


In [8]:
class CovidDataset(Dataset):
    """Unlabelled image dataset over the image files in a single directory.

    Every regular file in ``root_dir`` (non-recursive) whose extension is
    ``.png``, ``.jpg`` or ``.jpeg`` (case-insensitive) is one sample.
    ``__getitem__`` returns only the transformed image; there is no label.

    Args:
        root_dir: directory to scan for images.
        transform: callable applied to each PIL image. When ``None``, the
            default pipeline is used: convert to a single grayscale channel,
            resize to 128x128, ``ToTensor`` and ``Normalize([0.5], [0.5])``.
            The normalization maps pixel values from [0, 1] to [-1, 1], the
            output range of the generator's Tanh.
    """

    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir='/content/images', transform=None):
        self.root_dir = root_dir
        # Sorted so the index -> file mapping is deterministic across runs
        # (os.listdir returns entries in arbitrary order).
        self.image_files = sorted(
            f for f in os.listdir(root_dir)
            if f.lower().endswith(self._IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(root_dir, f))
        )
        # `is None` rather than `or`: a user transform that happens to be
        # falsy (e.g. an empty nn.Sequential) must still be respected.
        if transform is None:
            transform = transforms.Compose([
                transforms.Grayscale(1),
                transforms.Resize((128, 128)),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5])
            ])
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, self.image_files[idx])
        # Image.open is lazy and keeps the file open. copy() reads the pixels
        # into a detached image so the handle is closed right away. Otherwise
        # long runs with DataLoader workers can exhaust file descriptors.
        with Image.open(img_path) as img:
            image = img.copy()
        return self.transform(image)

In [None]:
class Discriminator(nn.Module):
    """DCGAN-style discriminator mapping an image batch to real/fake probabilities.

    Four stride-2 convolutions (kernel 4, padding 1) each halve the spatial
    size. An ``img_size x img_size`` input therefore ends as a 512-channel
    ``(img_size // 16) x (img_size // 16)`` feature map. A single linear layer
    plus sigmoid reduces that map to one probability per sample.

    Args:
        in_channels: channels of the input images (1 for grayscale).
        img_size: height/width of the square input; must be a positive
            multiple of 16. The default 128 matches the 128x128 resize in the
            dataset's default transform.

    Input:  ``(N, in_channels, img_size, img_size)``.
    Output: ``(N, 1)`` probabilities, suitable for ``nn.BCELoss``.

    Raises:
        ValueError: if ``img_size`` is not a positive multiple of 16.
    """

    def __init__(self, in_channels=1, img_size=128):
        super().__init__()
        if img_size < 16 or img_size % 16 != 0:
            raise ValueError(f"img_size must be a positive multiple of 16, got {img_size}")
        # Spatial size left after the four stride-2 convolutions.
        feat_size = img_size // 16

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Dropout2d(0.2),

            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),

            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 512, 4, 2, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
        )

        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512 * feat_size * feat_size, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.conv(x)
        return self.fc(x)

class Generator(nn.Module):
    """DCGAN-style generator mapping latent vectors to images in [-1, 1].

    A linear layer projects each latent vector to a 512-channel
    ``(img_size // 16) x (img_size // 16)`` map. Four stride-2 transposed
    convolutions (kernel 4, padding 1) each double the spatial size, and a
    final Tanh bounds the output to [-1, 1].

    Args:
        z_dim: length of the latent vector.
        out_channels: channels of the generated images (1 for grayscale).
        img_size: height/width of the square output; must be a positive
            multiple of 16. The default 128 matches the 128x128 resize in the
            dataset's default transform.

    Input:  ``(N, z_dim)``.
    Output: ``(N, out_channels, img_size, img_size)``.

    Raises:
        ValueError: if ``img_size`` is not a positive multiple of 16.
    """

    def __init__(self, z_dim, out_channels=1, img_size=128):
        super().__init__()
        if img_size < 16 or img_size % 16 != 0:
            raise ValueError(f"img_size must be a positive multiple of 16, got {img_size}")
        # Spatial size of the projected map before the four upsampling layers.
        self.base_size = img_size // 16
        self.fc = nn.Linear(z_dim, 512 * self.base_size * self.base_size)

        self.conv = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),

            nn.ConvTranspose2d(256, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),

            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),

            nn.ConvTranspose2d(64, out_channels, 4, 2, 1),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.fc(x)
        # -1 infers the batch dimension (an unbatched latent becomes a batch of 1).
        x = x.view(-1, 512, self.base_size, self.base_size)
        return self.conv(x)

class CovidGAN:
    """Generator/Discriminator pair trained adversarially with a BCE loss.

    Args:
        z_dim: length of the latent noise vector fed to the generator.
        lr: Adam learning rate for the discriminator. The generator uses
            ``1.5 * lr``; both optimizers use ``betas=(0.5, 0.999)``.
        batch_size: images per training batch.

    Construction has side effects: it opens a TensorBoard ``SummaryWriter``
    on ``runs/covid_gan`` and creates ``generated_images/`` and
    ``model_checkpoints/`` in the current working directory.
    """

    def __init__(self, z_dim=100, lr=1e-4, batch_size=32):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.z_dim = z_dim
        self.batch_size = batch_size

        self.gen = Generator(z_dim).to(self.device)
        self.disc = Discriminator().to(self.device)

        self.opt_disc = optim.Adam(self.disc.parameters(), lr=lr, betas=(0.5, 0.999))
        # The generator gets a 1.5x higher learning rate than the discriminator.
        self.opt_gen = optim.Adam(self.gen.parameters(), lr=lr * 1.5, betas=(0.5, 0.999))

        # The discriminator ends in Sigmoid, so plain BCELoss on probabilities is the matching criterion.
        self.criterion = nn.BCELoss()
        self.writer = SummaryWriter('runs/covid_gan')

        os.makedirs('generated_images', exist_ok=True)
        os.makedirs('model_checkpoints', exist_ok=True)

    def train(self, num_epochs=200, dataset=None, sample_every=10, checkpoint_every=1):
        """Train both networks for ``num_epochs`` passes over ``dataset``.

        Args:
            num_epochs: number of passes over the data.
            dataset: dataset yielding image tensors shaped like the
                discriminator's input; defaults to ``CovidDataset()``.
            sample_every: every this many batches, print the losses, log them
                to TensorBoard, and save a sample grid generated from a fixed
                noise batch.
            checkpoint_every: save a checkpoint every this many epochs. Each
                checkpoint holds both models and both optimizer states, so the
                default (every epoch) writes ``num_epochs`` files. Raise it
                when disk space is limited.

        Raises:
            ValueError: if ``sample_every`` or ``checkpoint_every`` is < 1.
        """
        if sample_every < 1 or checkpoint_every < 1:
            raise ValueError("sample_every and checkpoint_every must be >= 1")

        if dataset is None:
            dataset = CovidDataset()
        dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True, num_workers=2)

        # Reused for every sample grid so progress is comparable across steps.
        fixed_noise = torch.randn(32, self.z_dim, device=self.device)

        print(f"Starting training on {self.device}")
        print(f"Dataset size: {len(dataset)} images")

        step = 0
        try:
            for epoch in range(num_epochs):
                for batch_idx, real_images in enumerate(dataloader):
                    d_loss, g_loss = self._train_step(real_images)

                    if batch_idx % sample_every == 0:
                        print(f"Epoch [{epoch}/{num_epochs}] Batch [{batch_idx}/{len(dataloader)}] "
                              f"D_loss: {d_loss:.4f}, G_loss: {g_loss:.4f}")
                        self.writer.add_scalar("D_loss", d_loss, step)
                        self.writer.add_scalar("G_loss", g_loss, step)
                        self._log_samples(fixed_noise, epoch, batch_idx, step)

                    step += 1

                if (epoch + 1) % checkpoint_every == 0:
                    self._save_checkpoint(epoch)
        finally:
            # Flush and close TensorBoard even when training is interrupted
            # (e.g. stopping the cell raises KeyboardInterrupt).
            self.writer.close()

    def _train_step(self, real_images):
        """Run one discriminator update followed by one generator update.

        Returns:
            ``(d_loss, g_loss)`` as Python floats.
        """
        batch_size = real_images.size(0)
        real_images = real_images.to(self.device)
        label_real = torch.ones(batch_size, 1, device=self.device)
        label_fake = torch.zeros(batch_size, 1, device=self.device)

        # Discriminator: push real -> 1 and fake -> 0. The fakes are detached
        # so this backward pass does not reach the generator.
        self.opt_disc.zero_grad()
        d_loss_real = self.criterion(self.disc(real_images), label_real)
        noise = torch.randn(batch_size, self.z_dim, device=self.device)
        fake_images = self.gen(noise)
        d_loss_fake = self.criterion(self.disc(fake_images.detach()), label_fake)
        d_loss = (d_loss_real + d_loss_fake) / 2
        d_loss.backward()
        self.opt_disc.step()

        # Generator: non-saturating loss, i.e. make the just-updated
        # discriminator label the fakes as real. This backward also leaves
        # gradients on the discriminator's parameters; they are cleared by
        # opt_disc.zero_grad() at the start of the next step.
        self.opt_gen.zero_grad()
        g_loss = self.criterion(self.disc(fake_images), label_real)
        g_loss.backward()
        self.opt_gen.step()

        return d_loss.item(), g_loss.item()

    def _log_samples(self, fixed_noise, epoch, batch_idx, step):
        """Generate images from ``fixed_noise``, save them as a PNG grid and log them to TensorBoard."""
        # NOTE(review): the generator stays in train mode here, so BatchNorm
        # uses batch statistics and this forward also updates its running
        # stats. Confirm that is intended.
        with torch.no_grad():
            fake_samples = self.gen(fixed_noise)
        save_image(fake_samples, f"generated_images/epoch_{epoch}_batch_{batch_idx}.png",
                   normalize=True, nrow=8)
        self.writer.add_image("Generated Images",
                              make_grid(fake_samples, normalize=True, nrow=8),
                              step)

    def _save_checkpoint(self, epoch):
        """Save both models and both optimizer states to ``model_checkpoints/covid_gan_epoch_{epoch}.pth``."""
        torch.save({
            'generator_state_dict': self.gen.state_dict(),
            'discriminator_state_dict': self.disc.state_dict(),
            'gen_optimizer_state_dict': self.opt_gen.state_dict(),
            'disc_optimizer_state_dict': self.opt_disc.state_dict(),
        }, f'model_checkpoints/covid_gan_epoch_{epoch}.pth')

    def generate_samples(self, num_samples=16, path="generated_samples.png"):
        """Generate ``num_samples`` images from fresh noise and save them as a grid.

        Args:
            num_samples: number of images to generate.
            path: output image file; the grid has 4 images per row.
        """
        with torch.no_grad():
            noise = torch.randn(num_samples, self.z_dim, device=self.device)
            fake_images = self.gen(noise)
        save_image(fake_images, path, normalize=True, nrow=4)

NUM_EPOCHS = 200  # full training run; lower this for a quick smoke test

# Build the GAN with default hyper-parameters, train it, then save a grid of fresh samples.
if __name__ == "__main__":
    gan = CovidGAN()
    gan.train(num_epochs=NUM_EPOCHS)
    gan.generate_samples()

Starting training on cuda
Dataset size: 3616 images
Epoch [0/200] Batch [0/113] D_loss: 0.6740, G_loss: 1.1175
Epoch [0/200] Batch [10/113] D_loss: 0.0611, G_loss: 3.5388
Epoch [0/200] Batch [20/113] D_loss: 0.0465, G_loss: 3.9597
Epoch [0/200] Batch [30/113] D_loss: 0.0582, G_loss: 5.5345
Epoch [0/200] Batch [40/113] D_loss: 0.0241, G_loss: 6.2572
Epoch [0/200] Batch [50/113] D_loss: 0.0181, G_loss: 6.2205
Epoch [0/200] Batch [60/113] D_loss: 0.0149, G_loss: 6.3180
Epoch [0/200] Batch [70/113] D_loss: 0.0102, G_loss: 6.3120
Epoch [0/200] Batch [80/113] D_loss: 0.0658, G_loss: 6.5752
Epoch [0/200] Batch [90/113] D_loss: 0.0011, G_loss: 7.3538
Epoch [0/200] Batch [100/113] D_loss: 0.0098, G_loss: 7.3150
Epoch [0/200] Batch [110/113] D_loss: 0.0212, G_loss: 6.4913
Epoch [1/200] Batch [0/113] D_loss: 0.0015, G_loss: 8.5346
Epoch [1/200] Batch [10/113] D_loss: 0.0014, G_loss: 7.3338
Epoch [1/200] Batch [20/113] D_loss: 0.0047, G_loss: 7.4501
Epoch [1/200] Batch [30/113] D_loss: 0.0076, G_l