In [1]:
import os
import tempfile
import shutil
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, utils

from medmnist import ChestMNIST
from pytorch_fid import fid_score
from skimage.metrics import structural_similarity as ssim

In [2]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels):
        super(ResidualBlock, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_channels)
        )
        self.relu = nn.ReLU(inplace=True)
        
    def forward(self, x):
        out = self.block(x)
        out += x  # Skip connection
        return self.relu(out)

In [3]:
class GeneratorIncreasedChannels(nn.Module):
    def __init__(self, latent_dim, embedding_dim, num_classes, img_channels, img_size):
        super(GeneratorIncreasedChannels, self).__init__()
        self.label_embedding = nn.Embedding(num_classes, embedding_dim)
        # Keep the same scaling factor: downsampled size will be img_size//4
        self.init_size = img_size // 4  
        # Increase the number of channels in the first fully connected layer to 256 instead of 128
        self.fc = nn.Sequential(
            nn.Linear(latent_dim + embedding_dim, 256 * self.init_size * self.init_size)
        )
        
        # Initial upsampling block – now working with 256 channels
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(256),
            nn.Upsample(scale_factor=2),  # Upsample from init_size to init_size*2
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256, 0.8),
            nn.ReLU(inplace=True)
        )
        
        # Add a couple of Residual Blocks; these work on 256 channels
        self.residual_blocks = nn.Sequential(
            ResidualBlock(256)
        )
        
        # Final blocks: upsample and reduce channels to produce the final image
        self.final_blocks = nn.Sequential(
            nn.Upsample(scale_factor=2),  # Upsample to full image resolution
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, img_channels, kernel_size=3, stride=1, padding=1),
            nn.Tanh()  # Tanh activation yields outputs in [-1, 1]
        )
        
    def forward(self, noise, labels):
        # Embed the labels
        label_input = self.label_embedding(labels)
        # Concatenate noise and label embedding along the last dimension
        gen_input = torch.cat((noise, label_input), dim=-1)
        out = self.fc(gen_input)
        # Reshape to a feature map of shape (batch, 256, init_size, init_size)
        out = out.view(out.shape[0], 256, self.init_size, self.init_size)
        out = self.conv_blocks(out)
        out = self.residual_blocks(out)
        img = self.final_blocks(out)
        return img

In [4]:

PNEUMONIA_LABEL = 1
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(DEVICE)
latent_dim = 100
embedding_dim = 10
num_classes = 2  # 0: non-pneumonia, 1: pneumonia
img_size = 64
img_channels = 1  # Grayscale images



# Load generator
generator = GeneratorIncreasedChannels(
    latent_dim=latent_dim,
    embedding_dim=embedding_dim,
    num_classes=num_classes,
    img_channels=img_channels,
    img_size=img_size
).to(DEVICE)
generator.load_state_dict(torch.load("best_rlrgenerator.pth", map_location=DEVICE))
generator.eval()


cuda


GeneratorIncreasedChannels(
  (label_embedding): Embedding(2, 10)
  (fc): Sequential(
    (0): Linear(in_features=110, out_features=65536, bias=True)
  )
  (conv_blocks): Sequential(
    (0): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): Upsample(scale_factor=2.0, mode='nearest')
    (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): BatchNorm2d(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (4): ReLU(inplace=True)
  )
  (residual_blocks): Sequential(
    (0): ResidualBlock(
      (block): Sequential(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    

In [5]:
import torch, gc
torch.cuda.empty_cache()
gc.collect()


8

In [None]:
# import numpy as np

# NUM_IMAGES = 20000  # adjust as needed
# labels = torch.full((NUM_IMAGES,), PNEUMONIA_LABEL, dtype=torch.long).to(DEVICE)
# noise = torch.randn(NUM_IMAGES, latent_dim).to(DEVICE)

# with torch.no_grad():
#     gen_imgs = generator(noise, labels).cpu().squeeze(1).numpy()  # (B, H, W)

# generated_images = [(img, PNEUMONIA_LABEL) for img in gen_imgs]


In [7]:
import torch
import numpy as np

def generate_pneumonia_images(generator, num_images, batch_size, latent_dim, label_idx, device):
    generator.eval()
    generated_images = []

    with torch.no_grad():
        for start in range(0, num_images, batch_size):
            curr_batch = min(batch_size, num_images - start)
            noise = torch.randn(curr_batch, latent_dim).to(device)
            labels = torch.full((curr_batch,), label_idx, dtype=torch.long).to(device)

            gen_imgs = generator(noise, labels).cpu().squeeze(1).numpy()  # (B, H, W)
            generated_images.extend([(img, label_idx) for img in gen_imgs])

    return generated_images


In [8]:
PNEUMONIA_LABEL = 1
latent_dim = 100
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

generated_images = generate_pneumonia_images(
    generator=generator,
    num_images=20000,
    batch_size=64,       # adjust if you want
    latent_dim=latent_dim,
    label_idx=PNEUMONIA_LABEL,
    device=DEVICE
)

print(f"✅ Generated {len(generated_images)} pneumonia images.")


✅ Generated 20000 pneumonia images.


In [9]:
class PneumoniaDataset(Dataset):
    def __init__(self, base_dataset, pneumonia_index=6):
        self.base_dataset = base_dataset
        self.indices = [i for i in range(len(self.base_dataset)) if int(self.base_dataset[i][1][pneumonia_index]) == 1]
    def __len__(self):
        return len(self.indices)
    def __getitem__(self, idx):
        real_idx = self.indices[idx]
        img, _ = self.base_dataset[real_idx]
        # Convert image range from [0,1] to [-1,1] to match generator output (using Tanh)
        img = img * 2 - 1
        return img, 1  # pneumonia label is 1

transform = transforms.Compose([
    transforms.ToTensor()
])
batch_size = 16  # or 8, 32 – small and efficient
full_train_dataset = ChestMNIST(split='train', download=True, transform=transform, size=img_size)
pneumonia_dataset = PneumoniaDataset(full_train_dataset, pneumonia_index=6)
pneumonia_loader = DataLoader(pneumonia_dataset, batch_size=batch_size, shuffle=True)

In [11]:
real_pneumonia_images = []

for img_batch, _ in pneumonia_loader:
    img_batch = (img_batch + 1) / 2  # convert [-1, 1] → [0, 1]

    for i in range(img_batch.size(0)):
        img_tensor = img_batch[i]  # shape: (1, 64, 64)
        img_2d = img_tensor.squeeze(0)  # remove channel dimension → (64, 64)
        img_np = img_2d.numpy()  # convert to NumPy
        assert img_np.shape == (64, 64), f"Got shape {img_np.shape}"  # ✅ safety check
        real_pneumonia_images.append(img_np)

    if len(real_pneumonia_images) >= 500:
        break



In [12]:
from skimage.metrics import structural_similarity as ssim
import numpy as np
import random

def avg_ssim(gen_img, real_imgs):
    return np.mean([
        ssim(gen_img, real_img, data_range=1.0) for real_img in real_imgs
    ])

# Ensure real pneumonia images are a list of (64, 64) arrays
real_pneumonia = list(real_pneumonia_images)  # shape: (64, 64), values in [0, 1]

ssim_scores = []

# ⬅️ FIXED: Unpack (img, label) tuple
for img, _ in generated_images:
    sample_real = random.sample(real_pneumonia, k=min(5, len(real_pneumonia)))
    score = avg_ssim(img, sample_real)
    ssim_scores.append(score)

# Keep top 30% most SSIM-similar
threshold = np.percentile(ssim_scores, 70)
filtered_images = [
    img for (img, _), score in zip(generated_images, ssim_scores) if score >= threshold
]

print(f"✅ Filtered {len(filtered_images)} high-quality images (top 30% by SSIM)")


✅ Filtered 6000 high-quality images (top 30% by SSIM)


In [None]:
# from skimage.metrics import structural_similarity as ssim
# import random
# def avg_ssim(gen_img, real_imgs):
#     return np.mean([
#         ssim(gen_img, real_img, data_range=1.0) for real_img in real_imgs
#     ])

# real_pneumonia = real_pneumonia_images
# real_pneumonia = list(real_pneumonia)  # forces conversion

# ssim_scores = []
# for img in generated_images:
#     sample_real = random.sample(real_pneumonia, k=min(5, len(real_pneumonia)))

#     score = avg_ssim(img, sample_real)
#     ssim_scores.append(score)

# # Keep top 30% most similar
# threshold = np.percentile(ssim_scores, 70)
# filtered_images = [img for img, score in zip(generated_images, ssim_scores) if score >= threshold]


AttributeError: 'tuple' object has no attribute 'shape'

In [None]:
# print("Generated:", img.shape)
# print("Real example:", real_pneumonia[0].shape)


Generated: (64, 64)
Real example: (64, 64)


In [None]:
# ssim_scores = []
# for img in gen_imgs:
#     sample_real = random.sample(real_pneumonia, k=min(5, len(real_pneumonia)))
#     score = avg_ssim(img, sample_real)
#     ssim_scores.append(score)


In [None]:
# import numpy as np

# # Keep top 30% most similar images
# threshold = np.percentile(ssim_scores, 70)

# filtered_images = [img for img, score in zip(gen_imgs, ssim_scores) if score >= threshold]
# print(f"Selected {len(filtered_images)} high-quality images.")


Selected 19 high-quality images.


In [13]:
import os
import imageio
import numpy as np

save_dir = "augmented/pneumonia"
os.makedirs(save_dir, exist_ok=True)

for i, img in enumerate(filtered_images):
    # Step 1: Normalize from [-1, 1] to [0, 1]
    img = (img + 1) / 2 if img.min() < 0 else img

    # Step 2: Clip to avoid out-of-range values
    img = np.clip(img, 0, 1)

    # Step 3: Convert to uint8
    img_uint8 = (img * 255).astype(np.uint8)

    # Step 4: Save as PNG
    imageio.imwrite(os.path.join(save_dir, f"gen_{i}.png"), img_uint8)
