In [1]:
# Conditional GAN (CGAN) implementation

In [2]:
!pip install -q torch torchvision datasets pillow scikit-learn matplotlib pandas tabulate

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from datasets import load_dataset
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
from collections import Counter
import os

#  Load dataset
dataset = load_dataset("yuighj123/covid-19-classification")
num_classes = len(set(dataset["train"]["label"]))
label_names = dataset["train"].features["label"].names

#  Preprocess Dataset for CGAN
class CustomCovidDataset(Dataset):
    def __init__(self, hf_dataset, split, transform=None):
        self.items = hf_dataset[split]
        self.transform = transform
        self.images = self.items['image']
        self.labels = self.items['label']
        self.label2idx = {l: i for i, l in enumerate(set(self.labels))}
        self.idx2label = {i: l for l, i in self.label2idx.items()}

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx]
        if not isinstance(img, Image.Image):
            img = Image.fromarray(np.array(img))
        img = img.convert('RGB')
        label = self.label2idx[self.labels[idx]]
        if self.transform:
            img = self.transform(img)
        return img, label

transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
])

train_ds = CustomCovidDataset(dataset, split="train", transform=transform)
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)

#  Define CGAN
latent_dim = 100
image_shape = (3, 64, 64)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.model = nn.Sequential(
            nn.Linear(latent_dim + num_classes, 128),
            nn.ReLU(True),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(True),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(True),
            nn.Linear(512, int(np.prod(image_shape))),
            nn.Tanh()
        )

    def forward(self, noise, labels):
        c = self.label_emb(labels)
        x = torch.cat([noise, c], 1)
        img = self.model(x)
        img = img.view(img.size(0), *image_shape)
        return img

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.model = nn.Sequential(
            nn.Linear(num_classes + int(np.prod(image_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img, labels):
        c = self.label_emb(labels)
        x = torch.cat([img.view(img.size(0), -1), c], 1)
        validity = self.model(x)
        return validity

generator = Generator().to(device)
discriminator = Discriminator().to(device)

adversarial_loss = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

#  Train CGAN
n_epochs = 10
os.makedirs("generated_samples", exist_ok=True)

for epoch in range(n_epochs):
    for imgs, labels in train_loader:
        batch_size = imgs.size(0)
        real_imgs = imgs.to(device)
        labels = labels.to(device)
        valid = torch.ones(batch_size, 1, device=device)
        fake = torch.zeros(batch_size, 1, device=device)
        optimizer_G.zero_grad()
        z = torch.randn(batch_size, latent_dim, device=device)
        gen_labels = torch.randint(0, num_classes, (batch_size,), device=device)
        gen_imgs = generator(z, gen_labels)
        validity = discriminator(gen_imgs, gen_labels)
        g_loss = adversarial_loss(validity, valid)
        g_loss.backward()
        optimizer_G.step()
        optimizer_D.zero_grad()
        real_pred = discriminator(real_imgs, labels)
        d_real_loss = adversarial_loss(real_pred, valid)
        fake_pred = discriminator(gen_imgs.detach(), gen_labels)
        d_fake_loss = adversarial_loss(fake_pred, fake)
        d_loss = (d_real_loss + d_fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()
    print(f"Epoch [{epoch+1}/{n_epochs}] | D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f}")
    if (epoch + 1) % 5 == 0:
        for class_idx in range(num_classes):
            z = torch.randn(10, latent_dim, device=device)
            labels = torch.full((10,), class_idx, dtype=torch.long, device=device)
            with torch.no_grad():
                fake_imgs = generator(z, labels)
            utils.save_image(
                fake_imgs,
                f"generated_samples/epoch_{epoch+1}_class_{class_idx}.png",
                nrow=5, normalize=True
            )

# Generate Synthetic Images for Balancing
counts = dict(Counter(dataset["train"]["label"]))
max_count = max(counts.values())
to_generate = {label: max_count - count for label, count in counts.items()}
label2idx = train_ds.label2idx
idx2label = train_ds.idx2label

balanced_dataset = list(dataset["train"])
for label, n_gen in to_generate.items():
    if n_gen == 0:
        continue
    print(f"Generating {n_gen} images for class '{idx2label[label]}'...")
    z = torch.randn(n_gen, latent_dim, device=device)
    gen_labels = torch.full((n_gen,), label, dtype=torch.long, device=device)
    with torch.no_grad():
        fake_imgs = generator(z, gen_labels)
    # Convert synthetic images to PIL for CNN
    fake_imgs = fake_imgs.cpu().numpy()  # Shape: (n_gen, 3, 64, 64)
    fake_imgs = (fake_imgs + 1) / 2  # Denormalize to [0, 1]
    fake_imgs = np.transpose(fake_imgs, (0, 2, 3, 1))  # Shape: (n_gen, 64, 64, 3)
    for i in range(n_gen):
        img_array = (fake_imgs[i] * 255).astype(np.uint8)
        img = Image.fromarray(img_array)
        balanced_dataset.append({"image": img, "label": label})


2025-05-25 14:02:41.239359: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-05-25 14:02:43.533015: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1748174564.367854    5928 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1748174564.625252    5928 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1748174566.144220    5928 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

Epoch [1/10] | D Loss: 0.1962 | G Loss: 1.1470
Epoch [2/10] | D Loss: 0.1198 | G Loss: 1.6021
Epoch [3/10] | D Loss: 0.3381 | G Loss: 1.4070
Epoch [4/10] | D Loss: 0.1582 | G Loss: 1.3190
Epoch [5/10] | D Loss: 0.3479 | G Loss: 0.8758
Epoch [6/10] | D Loss: 0.2178 | G Loss: 1.2153
Epoch [7/10] | D Loss: 0.1390 | G Loss: 1.6391
Epoch [8/10] | D Loss: 0.0823 | G Loss: 2.2249
Epoch [9/10] | D Loss: 0.0465 | G Loss: 2.9046
Epoch [10/10] | D Loss: 0.0402 | G Loss: 3.0568
Generating 41 images for class '1'...
Generating 41 images for class '2'...


In [2]:

#  Preprocess for CNN
def preprocess_image_for_cnn(image, target_size=(224, 224)):
    try:
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image.astype(np.uint8))
        image = image.convert('RGB').resize(target_size)
        image = np.array(image, dtype=np.float32) / 255.0
        if image.shape != (224, 224, 3):
            raise ValueError(f"Image has unexpected shape: {image.shape}")
        return image
    except Exception as e:
        print(f"Error processing image: {e}")
        return None

train_images = []
train_labels = []
for sample in balanced_dataset:
    img = preprocess_image_for_cnn(sample["image"], target_size=(224, 224))
    if img is not None:
        train_images.append(img)
        train_labels.append(sample["label"])
    else:
        print(f"Skipping invalid image for label {sample['label']}")

train_images = np.array(train_images, dtype=np.float32)
train_labels = np.array(train_labels, dtype=np.int32)
print(f"Train images shape: {train_images.shape}")

val_images = []
val_labels = []
for sample in dataset["test"]:
    img = preprocess_image_for_cnn(sample["image"], target_size=(224, 224))
    if img is not None:
        val_images.append(img)
        val_labels.append(sample["label"])
    else:
        print(f"Skipping invalid validation image for label {sample['label']}")

val_images = np.array(val_images, dtype=np.float32)
val_labels = np.array(val_labels, dtype=np.int32)
print(f"Validation images shape: {val_images.shape}")

#  Train CNN
train_datagen = ImageDataGenerator(
    rotation_range=10,
    zoom_range=0.1,
    horizontal_flip=True,
    fill_mode='nearest'
)
val_datagen = ImageDataGenerator()

train_generator = train_datagen.flow(train_images, train_labels, batch_size=32)
val_generator = val_datagen.flow(val_images, val_labels, batch_size=32)

model = Sequential([
    Conv2D(32, (3, 3), activation='relu', input_shape=(224, 224, 3)),
    MaxPooling2D((2, 2)),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),
    Conv2D(128, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),
    Flatten(),
    Dense(128, activation='relu'),
    Dense(num_classes, activation='softmax')
])

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

history = model.fit(
    train_generator,
    epochs=10,
    validation_data=val_generator
)

#  Evaluate
val_loss, val_accuracy = model.evaluate(val_generator)
print(f"Validation Loss: {val_loss:.4f}")
print(f"Validation Accuracy: {val_accuracy:.4f}")

val_predictions = model.predict(val_images)
val_pred_labels = np.argmax(val_predictions, axis=1)

print("\nClassification Report:")
print(classification_report(val_labels, val_pred_labels, target_names=label_names))

cm = confusion_matrix(val_labels, val_pred_labels)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=label_names, yticklabels=label_names)
plt.title('Confusion Matrix')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.savefig('confusion_matrix.png')
plt.close()

Train images shape: (333, 224, 224, 3)
Validation images shape: (66, 224, 224, 3)


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
2025-05-25 14:20:48.996031: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
  self._warn_if_super_not_called()


Epoch 1/10
[1m11/11[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m30s[0m 2s/step - accuracy: 0.4053 - loss: 2.4042 - val_accuracy: 0.3030 - val_loss: 1.0842
Epoch 2/10
[1m11/11[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m30s[0m 3s/step - accuracy: 0.5073 - loss: 1.0124 - val_accuracy: 0.6970 - val_loss: 0.7598
Epoch 3/10
[1m11/11[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m39s[0m 3s/step - accuracy: 0.6267 - loss: 0.8092 - val_accuracy: 0.6212 - val_loss: 0.7901
Epoch 4/10
[1m11/11[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m42s[0m 3s/step - accuracy: 0.7527 - loss: 0.6529 - val_accuracy: 0.6818 - val_loss: 0.5748
Epoch 5/10
[1m11/11[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m29s[0m 3s/step - accuracy: 0.7326 - loss: 0.5363 - val_accuracy: 0.7273 - val_loss: 0.5197
Epoch 6/10
[1m11/11[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m39s[0m 2s/step - accuracy: 0.7429 - loss: 0.4973 - val_accuracy: 0.6970 - val_loss: 0.8129
Epoch 7/10
[1m11/11[0m [32m━━━━━━━━━━