## Dataset Preprocessing

In [None]:
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import os
from PIL import Image
import torch
import torch.nn as nn
import numpy as np

In [None]:
# Mount Google Drive so the dataset, generated samples and checkpoints under
# /content/drive/MyDrive/malaria_project are reachable from this runtime.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [None]:
class MalariaDataset(Dataset):
    """Malaria blood-smear image dataset with multi-label life-stage targets.

    Expected directory layout::

        root_dir/
            <species>/img/<name>-<stages>.jpg

    ``<stages>`` is one or more of the codes ``R``, ``T``, ``S``, ``G`` joined
    by ``_`` (e.g. ``...-R_T.jpg``); unknown codes are ignored. Each sample is
    ``(image, species, stages)`` where ``species`` is the folder name and
    ``stages`` is a 4-element multi-hot list ordered ``[R, T, S, G]``.

    Note: because ``stages`` is a plain Python list, PyTorch's default collate
    turns a batch of them into a list of 4 tensors (one per stage, each of
    length ``batch``) rather than a ``(batch, 4)`` tensor. Use
    ``torch.stack(stages, dim=1)`` on a batch to obtain the latter.
    """

    # Position of each stage code in the multi-hot stage vector.
    STAGE_MAPPING = {'R': 0, 'T': 1, 'S': 2, 'G': 3}

    def __init__(self, root_dir, transform=None, extensions=('.jpg',)):
        """
        Args:
            root_dir (str): Root directory containing one folder per species,
                each with an ``img`` sub-folder holding the images.
            transform (callable, optional): Transform applied to each PIL image
                in ``__getitem__``.
            extensions (tuple of str, optional): File extensions to load,
                matched case-insensitively so that e.g. ``.JPG`` files are not
                silently skipped. Defaults to ``('.jpg',)``.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.data = []

        exts = tuple(ext.lower() for ext in extensions)

        # os.listdir order is arbitrary and filesystem-dependent; sorting makes
        # dataset indices (and anything derived from them) reproducible.
        for species in sorted(os.listdir(self.root_dir)):
            species_dir = os.path.join(self.root_dir, species, 'img')
            if not os.path.isdir(species_dir):
                # Skip stray files / folders without an img/ sub-directory.
                continue

            for filename in sorted(os.listdir(species_dir)):
                if not filename.lower().endswith(exts):
                    continue
                self.data.append({
                    'filepath': os.path.join(species_dir, filename),
                    'species': species,
                    'stages': self._parse_stages(filename),
                })

    @classmethod
    def _parse_stages(cls, filename):
        """Return the ``[R, T, S, G]`` multi-hot vector encoded in ``filename``.

        The stage tag is the text after the last ``-`` up to the first ``.``
        that follows it; multiple stages are separated by ``_``.
        """
        stage_tag = filename.split('-')[-1].split('.')[0]
        stage_vector = [0] * len(cls.STAGE_MAPPING)
        for stage in stage_tag.split('_'):
            if stage in cls.STAGE_MAPPING:
                stage_vector[cls.STAGE_MAPPING[stage]] = 1
        return stage_vector

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx`` as ``(image, species, stages)``."""
        item = self.data[idx]
        # The context manager closes the file deterministically; convert()
        # returns a fully loaded copy, so the image stays usable after close.
        with Image.open(item['filepath']) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, item['species'], item['stages']

In [None]:
# Instantiate the dataset
root_dir = '/content/drive/MyDrive/malaria_project/malaria_dataset'
# Normalize with mean = std = 0.5 maps ToTensor's [0, 1] pixels to [-1, 1],
# the same range as the generator's final Tanh.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

dataset = MalariaDataset(root_dir, transform=transform)
print(f"Dataset size: {len(dataset)}")

# Fail with a clear message (instead of an IndexError on dataset[0]) when the
# directory exists but no images matched the expected layout.
if len(dataset) == 0:
    raise FileNotFoundError(
        f"No images found under {root_dir!r}; expected <species>/img/*.jpg"
    )

# Test single data retrieval from dataset
image, species, stages = dataset[0]
print(f"Species: {species}, Stages: {stages}")

Dataset size: 210
Species: Falciparum, Stages: [1, 0, 0, 0]


In [None]:
# Create DataLoader.
# drop_last=True: the generator uses BatchNorm1d, which raises in training mode
# on a batch containing a single sample. Dropping the ragged final batch rules
# that out no matter how the dataset size divides by the batch size.
data_loader = DataLoader(dataset, batch_size=16, shuffle=True, drop_last=True)

# Sanity-check a single batch.
images, species, stages = next(iter(data_loader))
# The default collate turns the per-sample 4-element stage lists into a list of
# 4 tensors (one per stage, each of length batch); stack them to (batch, 4) so
# each printed row is one sample's [R, T, S, G] vector.
stages = torch.stack(stages, dim=1)
print("Batch 1")
print(f"Images shape: {images.shape}")
print(f"Species: {species}")
print(f"Stages: {stages}")

Batch 1
Images shape: torch.Size([16, 3, 224, 224])
Species: ('Falciparum', 'Falciparum', 'Falciparum', 'Falciparum', 'Ovale', 'Falciparum', 'Malariae', 'Malariae', 'Falciparum', 'Falciparum', 'Falciparum', 'Falciparum', 'Falciparum', 'Falciparum', 'Malariae', 'Vivax')
Stages: [tensor([1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0]), tensor([0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1]), tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0]), tensor([0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0])]


## GAN Training

In [None]:
import torch.optim as optim
from torchvision.utils import save_image

In [None]:
# Generator
class Generator(nn.Module):
    """Fully connected (MLP) generator mapping a noise vector to an image.

    Args:
        latent_dim: Length of the input noise vector ``z``.
        img_size: Height and width of the square output image.
        channels: Number of output image channels (default 3, RGB).

    ``forward(z)`` takes ``z`` of shape ``(batch, latent_dim)`` and returns a
    tensor of shape ``(batch, channels, img_size, img_size)`` with values in
    ``[-1, 1]`` (final Tanh). BatchNorm1d layers mean training-mode batches
    must contain more than one sample.
    """

    def __init__(self, latent_dim: int, img_size: int, channels: int = 3):
        super().__init__()

        self.latent_dim = latent_dim
        self.img_size = img_size
        self.channels = channels
        self.img_shape = (channels, img_size, img_size)

        def layer(in_neurons: int, out_neurons: int, normalize_: bool = True, dropout_: bool = False):
            """Linear -> [BatchNorm1d] -> [Dropout(0.3)] -> LeakyReLU(0.2)."""
            block = [nn.Linear(in_neurons, out_neurons)]
            if normalize_:
                block.append(nn.BatchNorm1d(out_neurons))
            if dropout_:
                block.append(nn.Dropout(0.3))
            block.append(nn.LeakyReLU(0.2))
            return block

        # The layer sequence (and therefore state_dict keys) is unchanged, so
        # previously saved checkpoints still load.
        self.model = nn.Sequential(
            *layer(self.latent_dim, 128, normalize_=False, dropout_=True),
            *layer(128, 256, dropout_=True),
            *layer(256, 512, dropout_=True),
            *layer(512, 1024),
            *layer(1024, 2048),
            nn.Linear(2048, int(np.prod(self.img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        """Map ``z`` of shape (batch, latent_dim) to images of shape (batch, *img_shape)."""
        flat = self.model(z)
        # Reshape via img_shape (not a literal 3) so it follows `channels`.
        return flat.view(flat.size(0), *self.img_shape)

# Discriminator
class Discriminator(nn.Module):
    """Fully connected (MLP) discriminator scoring images as real vs. fake.

    Args:
        img_size: Height and width of the square input image.
        channels: Number of input image channels (default 3, RGB).

    ``forward(z)`` flattens each sample to ``channels * img_size * img_size``
    features and returns a ``(batch, 1)`` tensor of probabilities from a final
    Sigmoid (so it pairs with ``nn.BCELoss``, not ``BCEWithLogitsLoss``).
    """

    def __init__(self, img_size: int, channels: int = 3):
        super().__init__()

        self.img_size = img_size
        self.channels = channels
        self.img_shape = (channels, img_size, img_size)

        def layer(in_neurons, out_neurons, dropout_=False):
            """Linear -> LeakyReLU(0.2) -> [Dropout(0.3)]."""
            block = [nn.Linear(in_neurons, out_neurons), nn.LeakyReLU(0.2)]
            if dropout_:
                block.append(nn.Dropout(0.3))
            return block

        # The layer sequence (and therefore state_dict keys) is unchanged, so
        # previously saved checkpoints still load.
        self.model = nn.Sequential(
            *layer(int(np.prod(self.img_shape)), 1024, dropout_=True),
            *layer(1024, 512, dropout_=True),
            *layer(512, 256),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, z):
        """Score a batch of images; returns probabilities of shape (batch, 1)."""
        # torch.flatten copies when necessary, so unlike .view() it also accepts
        # non-contiguous inputs (e.g. permuted or sliced image tensors).
        z = torch.flatten(z, start_dim=1)
        return self.model(z)

In [None]:
# Number of GPUs to use; setting this to 0 forces the CPU even if CUDA exists.
ngpu = 1

# Pick CUDA only when it is both available and requested.
use_cuda = torch.cuda.is_available() and ngpu > 0
device = torch.device("cuda" if use_cuda else "cpu")

In [None]:
# Image side length; must match the Resize((224, 224)) in `transform`.
img_size = 224

# Hyperparameters
latent_dim = 128   # length of the generator's input noise vector
epochs = 400
# NOTE: documents the batch size `data_loader` is actually built with (16).
# The previous value, 32, was never used: the DataLoader hard-codes 16 and the
# training loop reads each batch's real size from the data.
batch_size = 16
lr = 0.0002        # Adam learning rate shared by generator and discriminator

# Seed torch's global RNG so weight initialisation, noise sampling, dropout and
# the DataLoader's shuffle order in the cells below are reproducible
# (CUDA kernels may still introduce some nondeterminism).
seed = 42
torch.manual_seed(seed)

In [None]:
# Build both networks for the configured image size and place them on the
# selected device (nn.Module.to moves parameters and returns the module).
generator = Generator(latent_dim=latent_dim, img_size=img_size).to(device)
discriminator = Discriminator(img_size=img_size).to(device)

In [None]:
# Loss function: BCELoss expects probabilities, which the discriminator's
# final Sigmoid provides.
criterion = nn.BCELoss()

# Both networks use Adam with the same learning rate and beta1 = 0.5.
adam_kwargs = dict(lr=lr, betas=(0.5, 0.999))
optimizer_G = optim.Adam(generator.parameters(), **adam_kwargs)
optimizer_D = optim.Adam(discriminator.parameters(), **adam_kwargs)

In [None]:
# Training loop
# Output locations on Drive, created up front so save_image / torch.save do not
# fail at the end of a long run because a folder is missing.
sample_dir = "/content/drive/MyDrive/malaria_project/generated_images/sixth_gan"
checkpoint_dir = "/content/drive/MyDrive/malaria_project/checkpoints"
os.makedirs(sample_dir, exist_ok=True)
os.makedirs(checkpoint_dir, exist_ok=True)

# One fixed noise batch reused every epoch, so the saved sample grids show the
# same 16 latent points evolving and are directly comparable across epochs.
fixed_z = torch.randn(16, latent_dim, device=device)

# Log the first batch of every epoch. This replaces a hard-coded `% 14`, which
# only meant "once per epoch" for a loader with exactly 14 batches.
log_every = max(1, len(data_loader))

for epoch in range(epochs):
    for batch_idx, (real_images, _, _) in enumerate(data_loader):
        # Move real images to device
        real_images = real_images.to(device)
        # Actual size of this batch (a final batch may be smaller). Kept in its
        # own name so it does not overwrite the `batch_size` hyperparameter.
        cur_batch_size = real_images.size(0)

        # Labels
        real_labels = torch.ones(cur_batch_size, 1, device=device)
        fake_labels = torch.zeros(cur_batch_size, 1, device=device)

        # --------------------------------------------------- #
        # --------------- Train Discriminator --------------- #
        # --------------------------------------------------- #

        # Generate fake images from random noise
        z = torch.randn(cur_batch_size, latent_dim, device=device)
        fake_images = generator(z)

        real_loss = criterion(discriminator(real_images), real_labels)
        # detach(): the discriminator step must not backpropagate into G.
        fake_loss = criterion(discriminator(fake_images.detach()), fake_labels)
        d_loss = real_loss + fake_loss

        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()

        # --------------------------------------------------- #
        # ----------------- Train Generator ----------------- #
        # --------------------------------------------------- #

        z = torch.randn(cur_batch_size, latent_dim, device=device)
        fake_images = generator(z)
        # Non-saturating objective: G is rewarded when D labels its samples real.
        g_loss = criterion(discriminator(fake_images), real_labels)

        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()

        # Print progress
        if batch_idx % log_every == 0:
            print(f"Epoch [{epoch+1}/{epochs}]\n[{batch_idx}/{len(data_loader)}] --------------- "
                  f"Discriminator Loss: {d_loss.item():.4f} - Generator Loss: {g_loss.item():.4f}")

        # TODO: Implement logging...

    # Save generator output after every epoch. eval() disables Dropout and makes
    # BatchNorm use (and stop updating) its running statistics; no_grad() skips
    # building an autograd graph for this pure inference pass.
    generator.eval()
    with torch.no_grad():
        generated_images = generator(fixed_z)
    generator.train()
    save_image(generated_images,
               os.path.join(sample_dir, f"generated_images_epoch_{(epoch+1):03}.png"),
               normalize=True)

# Save models
torch.save(generator.state_dict(), os.path.join(checkpoint_dir, "sixth_generator.pth"))
torch.save(discriminator.state_dict(), os.path.join(checkpoint_dir, "sixth_discriminator.pth"))

Epoch [1/400]
[0/14] --------------- Discriminator Loss: 1.3886 - Generator Loss: 1.7959
Epoch [2/400]
[0/14] --------------- Discriminator Loss: 1.4597 - Generator Loss: 7.6888
Epoch [3/400]
[0/14] --------------- Discriminator Loss: 2.1736 - Generator Loss: 4.3686
Epoch [4/400]
[0/14] --------------- Discriminator Loss: 1.2875 - Generator Loss: 4.9161
Epoch [5/400]
[0/14] --------------- Discriminator Loss: 3.6989 - Generator Loss: 0.6239
Epoch [6/400]
[0/14] --------------- Discriminator Loss: 5.5686 - Generator Loss: 7.1128
Epoch [7/400]
[0/14] --------------- Discriminator Loss: 1.4421 - Generator Loss: 2.4969
Epoch [8/400]
[0/14] --------------- Discriminator Loss: 3.1587 - Generator Loss: 0.3512
Epoch [9/400]
[0/14] --------------- Discriminator Loss: 1.9770 - Generator Loss: 3.2177
Epoch [10/400]
[0/14] --------------- Discriminator Loss: 2.9886 - Generator Loss: 1.1634
Epoch [11/400]
[0/14] --------------- Discriminator Loss: 4.9721 - Generator Loss: 0.2262
Epoch [12/400]
[0/1