In [1]:
print("hello world")

hello world


In [2]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
from torchvision import models, transforms
import torchvision
from typing import Tuple, List, Optional, Dict, Any
import numpy as np
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as sk_ssim
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import warnings
import numpy as np
import pandas as pd
print("done")

done


In [3]:
# Optional import: torchxrayvision (preferred for chest-xray pretrained models)
try:
    import torchxrayvision as xrv
    _HAS_TXRV = True
except Exception:
    _HAS_TXRV = False
    warnings.warn("torchxrayvision not available — falling back to VGG16 for perceptual features.")

In [4]:
# cvae_unet_gan_perceptual.py
# Full architecture + training/eval helpers for conditional VAE (U-Net) + PatchGAN + Perceptual Loss

In [5]:
# --------------------------------------
# Utility: basic conv / upconv blocks
# --------------------------------------
def conv_block(in_ch, out_ch, kernel=3, stride=1, padding=1, use_bn=True):
    """
    Build a Conv2d -> [InstanceNorm2d] -> ReLU stack.

    Despite its name, ``use_bn`` toggles an affine ``InstanceNorm2d`` layer. When the
    norm layer is present the conv bias is dropped (the norm's affine shift replaces it).

    Returns:
        nn.Sequential with 3 modules (use_bn=True) or 2 modules (use_bn=False).
    """
    conv = nn.Conv2d(in_ch, out_ch, kernel, stride, padding, bias=not use_bn)
    norm = (nn.InstanceNorm2d(out_ch, affine=True),) if use_bn else ()
    return nn.Sequential(conv, *norm, nn.ReLU(inplace=True))


def upconv_block(in_ch, out_ch, kernel=3, stride=1, padding=1, use_bn=True):
    """
    Build a x2 bilinear Upsample -> Conv2d -> [InstanceNorm2d] -> ReLU stack.

    ``use_bn`` toggles an affine ``InstanceNorm2d`` (and drops the conv bias when on),
    mirroring ``conv_block``.
    """
    upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
    conv = nn.Conv2d(in_ch, out_ch, kernel, stride, padding, bias=not use_bn)
    tail = [nn.InstanceNorm2d(out_ch, affine=True)] if use_bn else []
    tail.append(nn.ReLU(inplace=True))
    return nn.Sequential(upsample, conv, *tail)
print("done")

done


In [6]:
# --------------------------------------
# FiLM modulation layer (simple)
# Given an embedding for label, produce scale & shift per channel
# --------------------------------------
class FiLM(nn.Module):
    """
    Feature-wise Linear Modulation.

    A linear layer maps an embedding to a per-channel (gamma, beta) pair, and the feature
    map is transformed as ``x * (1 + gamma) + beta``. The ``1 +`` makes a zero projection
    an identity transform.
    """

    def __init__(self, embed_dim: int, num_features: int):
        super().__init__()
        # one linear head emits both gamma and beta, stacked along the feature axis
        self.fc = nn.Linear(embed_dim, 2 * num_features)
        nn.init.xavier_uniform_(self.fc.weight)

    def forward(self, x: torch.Tensor, emb: torch.Tensor):
        """
        x:   (B, C, H, W) feature map
        emb: (B, embed_dim) conditioning embedding
        returns: modulated tensor with the same shape as ``x``
        """
        gamma, beta = torch.chunk(self.fc(emb), 2, dim=1)
        # broadcast the (B, C) parameters over the spatial dimensions
        gamma = gamma[:, :, None, None]
        beta = beta[:, :, None, None]
        return x * (1 + gamma) + beta
print("done")

done


In [7]:
# --------------------------------------
# U-Net style Encoder
# --------------------------------------
class UNetEncoder(nn.Module):
    """
    Label-free U-Net encoder (the label is deliberately not fed here so that the decoder's
    label conditioning controls the swap).

    Structure: ``initial`` conv block, then ``num_down`` stride-2 blocks that each halve H/W
    and double the channel count, then a ``final`` same-width conv block.

    Attributes:
        skip_channels: channel count of every skip tensor returned by ``forward``, from
            shallowest (``base_channels``) to deepest.
        out_channels: channel count of the returned bottleneck feature map.
    """
    def __init__(self, in_channels=1, base_channels=32, num_down=4):
        super().__init__()
        width = base_channels
        self.initial = conv_block(in_channels, width)

        self.downs = nn.ModuleList()
        self.skip_channels = [width]  # the output of `initial` is the first skip
        for _ in range(num_down):
            # 4x4 kernel, stride 2, padding 1: halves H and W while doubling the width
            self.downs.append(nn.Sequential(
                conv_block(width, width * 2, kernel=4, stride=2, padding=1),
            ))
            width = width * 2
            self.skip_channels.append(width)

        # same-width conv producing the bottleneck map handed to the latent mapper
        self.final = conv_block(width, width)
        self.out_channels = width

    def forward(self, x):
        """Return ``(bottleneck, skips)`` with skips ordered shallow -> deep."""
        h = self.initial(x)
        skips = [h]
        for down in self.downs:
            h = down(h)
            skips.append(h)
        return self.final(h), skips
    
print("done")

done


In [8]:
# --------------------------------------
# Latent projection: map feature map -> latent vector (mu, logvar)
# and back (latent -> feature map)
# --------------------------------------
class LatentMapper(nn.Module):
    """
    Bridge between the encoder's bottleneck feature map and the VAE latent vector.

    * ``encode``: global-average-pool the (B, C, H, W) map to (B, C), then two linear heads
      produce ``mu`` and ``logvar`` of size ``latent_dim``.
    * ``reparameterize``: ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.
    * ``decode_latent_to_feat``: linear projection back to C channels, shaped (B, C, 1, 1)
      so the decoder can resize it spatially.
    """
    def __init__(self, feat_channels, latent_dim=256):
        super().__init__()
        self.feat_channels = feat_channels
        self.latent_dim = latent_dim
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc_mu = nn.Linear(feat_channels, latent_dim)
        self.fc_logvar = nn.Linear(feat_channels, latent_dim)
        self.fc_dec = nn.Linear(latent_dim, feat_channels)

    def encode(self, feat):
        """feat: (B, C, H, W) -> (mu, logvar), each (B, latent_dim)."""
        pooled = torch.flatten(self.pool(feat), start_dim=1)
        return self.fc_mu(pooled), self.fc_logvar(pooled)

    def reparameterize(self, mu, logvar):
        """Draw one sample of z from N(mu, exp(logvar)) with the reparameterization trick."""
        std = torch.exp(logvar * 0.5)
        noise = torch.randn_like(std)
        return mu + noise * std

    def decode_latent_to_feat(self, z):
        """z: (B, latent_dim) -> (B, feat_channels, 1, 1)."""
        return self.fc_dec(z)[:, :, None, None]
print("done")

done


In [9]:
# --------------------------------------
# U-Net Decoder conditioned on label via FiLM and concatenation
# --------------------------------------
class UNetDecoder(nn.Module):
    """
    Decoder: takes the latent feature map and a label, and produces an image.

    Skip connections are concatenated *before* each ``upconv_block`` and every block is
    followed by FiLM modulation from a label embedding. Skips are consumed deepest-first:
    stage ``i`` concatenates the skip at the current resolution and then upsamples x2, so
    after ``num_up`` stages the feature map is back at the shallowest skip's resolution.
    That shallowest (full-resolution) skip is concatenated once more in ``final_conv``.
    With ``num_up`` equal to the encoder's ``num_down``, the output has the same H x W as
    the encoder input.

    Args:
        out_channels: channels of the generated image (``final_conv`` ends in Tanh).
        base_channels: unused; kept for API compatibility.
        num_up: number of upsampling stages.
        latent_feat_channels: channel count of ``feat_from_latent`` (the encoder bottleneck width).
        label_embed_dim: size of the label embedding fed to every FiLM layer.
        encoder_channels: encoder skip channel counts ordered shallow -> deep
            (``UNetEncoder.skip_channels``). Required.
    """
    def __init__(self, out_channels=1, base_channels=32, num_up=4, latent_feat_channels=512,
                 label_embed_dim=32, encoder_channels: list = None):
        super().__init__()
        assert encoder_channels is not None, "You must pass encoder skip channels"
        self.num_up = num_up
        self.encoder_channels = encoder_channels[::-1]  # reverse to match decoder order (deep -> shallow)

        # label embedding: a scalar label per sample -> label_embed_dim
        self.label_embed = nn.Sequential(
            nn.Linear(1, label_embed_dim),
            nn.ReLU(inplace=True)
        )

        self.up_convs = nn.ModuleList()
        self.films = nn.ModuleList()
        ch = latent_feat_channels

        # All skips except the shallowest are consumed inside the loop; the shallowest one is
        # reserved for final_conv so it is not concatenated twice.
        n_loop_skips = len(self.encoder_channels) - 1
        for i in range(num_up):
            skip_ch = self.encoder_channels[i] if i < n_loop_skips else 0
            in_ch = ch + skip_ch   # concat happens before the upconv
            out_ch = ch // 2
            self.up_convs.append(upconv_block(in_ch, out_ch))
            self.films.append(FiLM(label_embed_dim, out_ch))
            ch = out_ch

        # final conv: last stage output concatenated with the full-resolution (shallowest) skip
        final_in_ch = ch + self.encoder_channels[-1]
        self.final_conv = nn.Sequential(
            conv_block(final_in_ch, ch),
            nn.Conv2d(ch, out_channels, kernel_size=1),
            nn.Tanh()
        )

    def forward(self, feat_from_latent, skips, label):
        """
        feat_from_latent: (B, C, h, w) map (``LatentMapper.decode_latent_to_feat`` gives 1x1)
        skips: encoder skips ordered shallow -> deep, as returned by ``UNetEncoder.forward``
        label: B label values, shape (B,) or (B, 1); integer labels are accepted
        returns: generated image at the resolution of ``skips[0]``
        """
        B = feat_from_latent.shape[0]
        x = feat_from_latent
        # nn.Linear needs a floating tensor; DataLoader labels typically arrive as int64.
        cond = label.reshape(B, -1).to(device=x.device, dtype=x.dtype)
        emb = self.label_embed(cond)
        skips_rev = skips[::-1]
        n_loop_skips = len(skips_rev) - 1

        for i, (up, film) in enumerate(zip(self.up_convs, self.films)):
            if i < n_loop_skips:
                skip = skips_rev[i]
                # the first stage broadcasts the 1x1 latent map to the bottleneck resolution;
                # later stages only resize when H/W were not divisible by 2
                if x.shape[-2:] != skip.shape[-2:]:
                    x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=False)
                x = torch.cat([x, skip], dim=1)   # concat first, then upsample
            x = film(up(x), emb)

        full_res_skip = skips_rev[-1]
        # resize the decoder features (not the skip) so the output matches the input resolution
        if x.shape[-2:] != full_res_skip.shape[-2:]:
            x = F.interpolate(x, size=full_res_skip.shape[-2:], mode="bilinear", align_corners=False)

        return self.final_conv(torch.cat([x, full_res_skip], dim=1))
print("done")

done


In [10]:
# --------------------------------------
# Conditional VAE model (Encoder + LatentMapper + Decoder)
# --------------------------------------
class ConditionalVAE_UNet(nn.Module):
    """
    Conditional VAE built from a label-free ``UNetEncoder``, a ``LatentMapper`` and a
    label-conditioned ``UNetDecoder`` that also receives the encoder skips.

    ``device`` is stored on the instance but not used to move any submodule here.
    """
    def __init__(self,
                 img_channels=1,
                 base_channels=32,
                 num_scales=4,
                 latent_dim=256,
                 label_embed_dim=32,
                 device: torch.device = torch.device('cpu')):
        super().__init__()
        self.device = device
        self.encoder = UNetEncoder(in_channels=img_channels,
                                   base_channels=base_channels,
                                   num_down=num_scales)
        bottleneck_channels = self.encoder.out_channels
        self.latent_map = LatentMapper(bottleneck_channels, latent_dim=latent_dim)
        # the decoder starts from a map with the same width as the encoder bottleneck
        self.decoder = UNetDecoder(out_channels=img_channels,
                                   base_channels=base_channels,
                                   num_up=num_scales,
                                   latent_feat_channels=bottleneck_channels,
                                   label_embed_dim=label_embed_dim,
                                   encoder_channels=self.encoder.skip_channels)

    def forward(self, x: torch.Tensor, label: torch.Tensor, sample: bool = True):
        """
        Reconstruct ``x`` under ``label``.

        x: (B, C, H, W) images
        label: per-sample condition passed straight to the decoder
        sample: draw z with the reparameterization trick when True; use ``mu`` otherwise
        returns: (x_recon, mu, logvar, z)
        """
        bottleneck, skips = self.encoder(x)
        mu, logvar = self.latent_map.encode(bottleneck)
        z = self.latent_map.reparameterize(mu, logvar) if sample else mu
        latent_map_feat = self.latent_map.decode_latent_to_feat(z)  # (B, C, 1, 1)
        x_recon = self.decoder(latent_map_feat, skips, label)
        return x_recon, mu, logvar, z

print("done")

done


In [11]:

# --------------------------------------
# PatchGAN Discriminator (conditional)
# --------------------------------------
class PatchDiscriminator(nn.Module):
    """
    PatchGAN discriminator that optionally conditions on a per-sample label, which is
    broadcast to an extra constant input channel.

    Layout: one plain stride-2 conv, ``n_layers - 1`` spectrally-normalized stride-2 convs
    (each doubling the width), then a 4x4 stride-1 conv producing a 1-channel map of
    real/fake logits.
    """
    def __init__(self, in_channels=1, base_channels=32, n_layers=4, label_condition=True):
        super().__init__()
        self.label_condition = label_condition
        ch = base_channels
        first_in = in_channels + (1 if label_condition else 0)  # +1 for the label channel
        layers = [nn.Conv2d(first_in, ch, kernel_size=4, stride=2, padding=1),
                  nn.LeakyReLU(0.2, inplace=True)]
        for _ in range(n_layers - 1):
            layers += [spectral_norm(nn.Conv2d(ch, ch * 2, kernel_size=4, stride=2, padding=1)),
                       nn.LeakyReLU(0.2, inplace=True)]
            ch *= 2
        # final conv to produce 1-channel patch score
        layers += [nn.Conv2d(ch, 1, kernel_size=4, stride=1, padding=1)]
        self.model = nn.Sequential(*layers)

    def forward(self, img: torch.Tensor, label: torch.Tensor = None):
        """
        img: (B, C, H, W)
        label: B label values, shape (B,) or (B, 1); required when ``label_condition``.
            Integer labels are cast to ``img``'s dtype/device before concatenation.
        returns: (B, 1, h, w) patch logits
        """
        if not self.label_condition:
            return self.model(img)
        if label is None:
            # explicit error instead of `assert`, which is stripped under `python -O`
            raise ValueError("PatchDiscriminator was built with label_condition=True; pass `label`.")
        B, _, H, W = img.shape
        lab_map = (label.to(device=img.device, dtype=img.dtype)
                        .reshape(B, 1, 1, 1)
                        .expand(-1, 1, H, W))
        return self.model(torch.cat([img, lab_map], dim=1))

print("done")

done


In [12]:

# --------------------------------------
# Perceptual feature extractor and loss
# Prefers torchxrayvision DenseNet if available; else uses VGG16 features
# --------------------------------------
class PerceptualFeatureExtractor(nn.Module):
    """
    Frozen feature extractor used by ``PerceptualLoss``.

    Backends:
      * torchxrayvision DenseNet (``densenet121-res224-all``) when the package imported and
        ``prefer_txrv`` is True; ``forward`` returns a single-element list.
      * torchvision VGG16 ``features`` truncated after index ``max(layers)``; ``forward``
        returns the activations at each index listed in ``layers``.

    Backbone parameters have ``requires_grad=False``, but ``forward`` is deliberately NOT
    wrapped in ``torch.no_grad()``: gradients must still flow back to the input (the
    generator output), otherwise the perceptual loss cannot train the generator.
    """
    def __init__(self, device: torch.device, layers: List[int] = [3, 8, 15, 22], prefer_txrv=True):
        super().__init__()
        self.device = device
        self.preferred = False
        # copy so later edits to self.layers never touch the shared default list
        self.layers = list(layers)
        if prefer_txrv and _HAS_TXRV:
            # NOTE: model architecture selection may need to be adapted by the user
            self.net = xrv.models.DenseNet(weights="densenet121-res224-all").to(device).eval()
            for p in self.net.parameters():
                p.requires_grad = False
            # torchxrayvision expects its own input normalization; the caller must provide it
            self.preferred = True
        else:
            # Fallback: VGG16 convolutional trunk from torchvision
            vgg = models.vgg16(pretrained=True).features.to(device).eval()
            self.net = vgg[:max(self.layers) + 1]
            for p in self.net.parameters():
                p.requires_grad = False
            self.preferred = False

    def train(self, mode: bool = True):
        """
        Set this wrapper's training flag but keep the frozen backbone in eval mode, so any
        mode-dependent layers inside the pretrained network keep their inference behaviour
        even if a parent module calls ``.train()``.
        """
        super().train(mode)
        self.net.eval()
        return self

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """
        Return a list of feature maps.

        ``x`` must already be normalized for the active backbone (ImageNet statistics for
        VGG; torchxrayvision's own scaling otherwise — see the torchxrayvision docs).
        """
        if self.preferred:
            # torchxrayvision path: use the `.features` trunk when the model exposes one,
            # otherwise fall back to the model's full forward output.
            features_fn = getattr(self.net, "features", None)
            if callable(features_fn):
                return [features_fn(x)]
            return [self.net(x)]

        # VGG path: run layer by layer and keep the requested activations
        collected = []
        curr = x
        for i, layer in enumerate(self.net):
            curr = layer(curr)
            if i in self.layers:
                collected.append(curr)
        return collected


class PerceptualLoss(nn.Module):
    """
    Sum of MSE distances between feature maps of ``gen`` and ``tgt`` extracted by a
    ``PerceptualFeatureExtractor``.

    For the VGG backend the inputs are ImageNet-normalized here:
      * if ``input_range=(lo, hi)`` is given, inputs are first mapped linearly from
        ``[lo, hi]`` to ``[0, 1]`` (e.g. pass ``(-1.0, 1.0)`` for Tanh generator outputs);
        the default ``None`` assumes inputs are already in ``[0, 1]``;
      * single-channel inputs are replicated to 3 channels, as VGG expects RGB.
    For the torchxrayvision backend inputs are passed through untouched.
    """
    _IMAGENET_MEAN = (0.485, 0.456, 0.406)
    _IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(self, device: torch.device, layers: List[int] = [3, 8, 15, 22], pref_txrv=True,
                 input_range: Optional[Tuple[float, float]] = None):
        super().__init__()
        self.fe = PerceptualFeatureExtractor(device, layers=layers, prefer_txrv=pref_txrv)
        self.criterion = nn.MSELoss()
        self.input_range = input_range

    def _imagenet_normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Map a (B, 1|3, H, W) batch to ImageNet-normalized 3-channel input (vectorized)."""
        if self.input_range is not None:
            lo, hi = self.input_range
            x = (x - lo) / (hi - lo)
        if x.shape[1] == 1:
            # grayscale X-rays: replicate to 3 channels; a 3-value Normalize cannot broadcast onto 1 channel
            x = x.expand(-1, 3, -1, -1)
        mean = x.new_tensor(self._IMAGENET_MEAN).view(1, 3, 1, 1)
        std = x.new_tensor(self._IMAGENET_STD).view(1, 3, 1, 1)
        return (x - mean) / std

    def forward(self, gen: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
        """Return the perceptual distance as a 0-dim tensor (zero if no features are produced)."""
        if not self.fe.preferred:  # VGG path
            gen = self._imagenet_normalize(gen)
            tgt = self._imagenet_normalize(tgt)
        gen_feats = self.fe(gen)
        tgt_feats = self.fe(tgt)
        # start from a tensor so `.item()` / `.backward()` work even with no feature maps
        loss = gen.new_zeros(())
        for a, b in zip(gen_feats, tgt_feats):
            loss = loss + self.criterion(a, b)
        return loss
print("done")

done


In [13]:
# --------------------------------------
# Loss helpers for VAE
# --------------------------------------
def kl_divergence(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """
    KL( N(mu, exp(logvar)) || N(0, I) ) for a diagonal Gaussian posterior.

    The KL is summed over the latent dimension (dim=1) and averaged over the batch.
    Returns a 0-dim tensor.
    """
    inner = 1 + logvar - mu.pow(2) - logvar.exp()
    per_sample_kld = inner.sum(dim=1).mul(-0.5)
    return per_sample_kld.mean()
print("done")

done


In [14]:
# --------------------------------------
# SSIM utilities
# --------------------------------------
def compute_ssim_pair(img_a: np.ndarray, img_b: np.ndarray, multichannel=False,
                      data_range: Optional[float] = None) -> float:
    """
    Compute SSIM for a single image pair with ``skimage.metrics.structural_similarity``.

    Args:
        img_a, img_b: float arrays of identical shape, (H, W) or (H, W, 1). A trailing
            singleton channel is squeezed away.
        multichannel: currently unused; kept for API compatibility.
        data_range: value range of the data (max possible - min possible). Pass it explicitly
            whenever it is known (e.g. 1.0 for images in [0, 1]). When None it is inferred from
            BOTH images jointly, so the score does not depend on argument order. If both images
            are flat and identical (inferred range 0), 1.0 is used: that avoids a 0/0 NaN, and
            SSIM of two identical images is 1 for any positive range.

    Returns:
        The SSIM score as returned by skimage.
    """
    # If grayscale, sk_ssim expects 2D arrays
    if img_a.ndim == 3 and img_a.shape[2] == 1:
        img_a = img_a.squeeze(2)
        img_b = img_b.squeeze(2)
    if data_range is None:
        lo = min(float(img_a.min()), float(img_b.min()))
        hi = max(float(img_a.max()), float(img_b.max()))
        data_range = hi - lo
        if data_range <= 0:
            data_range = 1.0
    return sk_ssim(img_a, img_b, data_range=data_range)


def compute_ssim_batch_numpy(origs: np.ndarray, gens: np.ndarray, labels: Optional[np.ndarray] = None):
    """
    Score every (original, generated) pair with ``compute_ssim_pair`` and split by class.

    Args:
        origs, gens: arrays indexed along axis 0, each item (H, W) or (H, W, C).
        labels: optional per-item labels; items equal to 0 go to "normal", every other
            value goes to "pneumonia".

    Returns:
        dict with numpy arrays under "all", "normal" and "pneumonia". The class arrays are
        empty when ``labels`` is None.
    """
    n_images = origs.shape[0]
    scores = [compute_ssim_pair(origs[k], gens[k]) for k in range(n_images)]

    if labels is None:
        normal_idx, pneumonia_idx = [], []
    else:
        normal_idx = [k for k in range(n_images) if labels[k] == 0]
        pneumonia_idx = [k for k in range(n_images) if not labels[k] == 0]

    return {
        "all": np.array(scores),
        "normal": np.array([scores[k] for k in normal_idx]),
        "pneumonia": np.array([scores[k] for k in pneumonia_idx]),
    }


def plot_ssim_distributions(ssim_dict: Dict[str, np.ndarray], bins=50, figsize=(12, 4)):
    """
    Show three side-by-side SSIM histograms (all / normal / pneumonia).

    ``ssim_dict`` must contain the keys "all", "normal" and "pneumonia", as produced by
    ``compute_ssim_batch_numpy``.
    """
    panels = (
        ('all', 'SSIM - All'),
        ('normal', 'SSIM - Normal'),
        ('pneumonia', 'SSIM - Pneumonia'),
    )
    plt.figure(figsize=figsize)
    for position, (key, title) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.hist(ssim_dict[key], bins=bins)
        plt.title(title)
    plt.tight_layout()
    plt.show()

print("done")

done


In [15]:

# --------------------------------------
# Utility: load a pretrained pneumonia classifier
# Prefer torchxrayvision DenseNet finetuned; else user supplies path
# --------------------------------------
class _XRVPneumoniaBinary(nn.Module):
    """
    Adapt a multi-label torchxrayvision classifier to the 2-column ``[normal, pneumonia]``
    layout that ``classify_batch`` expects (softmax over dim 1, column 1 = pneumonia).

    The returned "logits" are ``log([1 - p, p])``, so a softmax recovers exactly
    ``[1 - p, p]``, where ``p`` is the backbone's score for the selected pathology.
    """

    def __init__(self, base: nn.Module, pathology: str = "Pneumonia"):
        super().__init__()
        self.base = base
        pathologies = list(getattr(base, "pathologies", []))
        if pathology not in pathologies:
            raise ValueError(f"{pathology!r} is not among the model's pathologies: {pathologies}")
        self.index = pathologies.index(pathology)
        # NOTE(review): torchxrayvision is expected to apply a sigmoid (plus operating-point
        # calibration) inside forward when `apply_sigmoid` is set or `op_threshs` is present,
        # and to return raw logits otherwise. Confirm against the installed version.
        self.outputs_are_probs = (bool(getattr(base, "apply_sigmoid", False))
                                  or getattr(base, "op_threshs", None) is not None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        score = self.base(x)[:, self.index]
        p = score if self.outputs_are_probs else torch.sigmoid(score)
        p = p.clamp(1e-6, 1.0 - 1e-6)  # keep log() finite
        return torch.log(torch.stack([1.0 - p, p], dim=1))


def load_pretrained_pneumonia_classifier(device: torch.device, model_name: str = "densenet121",
                                         checkpoint_path: Optional[str] = None,
                                         input_range: Tuple[float, float] = (-1.0, 1.0)):
    """
    Load a binary pneumonia classifier plus its preprocessing function.

    * If torchxrayvision is available: the multi-label DenseNet ``densenet121-res224-all`` is
      wrapped so it returns ``(B, 2)`` logits ``[normal, pneumonia]``.
    * Otherwise a ResNet18 (1-channel input, 2-class head) is loaded from ``checkpoint_path``.

    Args:
        device: device the model is moved to.
        model_name: currently unused; kept for API compatibility.
        checkpoint_path: state_dict for the ResNet18 fallback (required without torchxrayvision).
        input_range: (min, max) of images given to the returned ``preprocess`` (torchxrayvision
            path only). The default matches the generator's Tanh output.

    Returns:
        (model, preprocess): ``model`` is in eval mode on ``device`` and returns ``(B, 2)``
        logits with column 1 = pneumonia. ``preprocess`` maps raw image batches to model input.

    Raises:
        RuntimeError: torchxrayvision is unavailable and no ``checkpoint_path`` was given.
        ValueError: invalid ``input_range`` or no "Pneumonia" output on the xrv model.
    """
    if _HAS_TXRV:
        lo, hi = input_range
        if hi <= lo:
            raise ValueError(f"input_range must satisfy min < max, got {input_range}")
        span = hi - lo

        base = xrv.models.DenseNet(weights="densenet121-res224-all").to(device).eval()
        model = _XRVPneumoniaBinary(base).to(device).eval()

        def preprocess(x: torch.Tensor):
            # NOTE(review): torchxrayvision models are documented to expect pixels scaled to
            # [-1024, 1024] (the output of xrv.datasets.normalize) — confirm for your version.
            return (x - lo) / span * 2048.0 - 1024.0

        return model, preprocess

    # Fallback: user must supply checkpoint_path. We create a ResNet18 skeleton as example.
    if checkpoint_path is None:
        raise RuntimeError("No torchxrayvision available and no checkpoint_path provided. Please provide a classifier checkpoint.")
    model = torchvision.models.resnet18(pretrained=False)
    # single-channel input and a binary head
    model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
    model.fc = nn.Linear(model.fc.in_features, 2)
    # SECURITY: torch.load unpickles the file; only load trusted checkpoints
    # (on PyTorch >= 1.13 consider weights_only=True).
    ckpt = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(ckpt)
    model = model.to(device).eval()

    def preprocess(x: torch.Tensor):
        return x  # user should implement the required normalization

    return model, preprocess


def classify_batch(model, preprocess_fn, imgs: torch.Tensor, device: torch.device):
    """
    Run a 2-class classifier on one batch.

    Args:
        model: module returning (B, num_classes) logits; switched to eval mode here.
        preprocess_fn: callable applied to ``imgs`` after moving them to ``device``.
        imgs: raw image batch in whatever range ``preprocess_fn`` expects.
        device: device to run on.

    Returns:
        (preds, scores): numpy arrays of argmax class ids and of the softmax probability of
        class 1 (pneumonia, by this module's convention).
    """
    model.eval()
    batch = preprocess_fn(imgs.to(device))
    with torch.no_grad():
        probabilities = F.softmax(model(batch), dim=1)
    predicted = torch.argmax(probabilities, dim=1)
    class1_prob = probabilities[:, 1]
    return predicted.cpu().numpy(), class1_prob.cpu().numpy()

print("done")

done


In [16]:

# --------------------------------------
# Trainer skeleton: single-step functions for discriminator & generator update calculations
# (Does not include data loader loops)
# --------------------------------------
class CVAEGANTrainer:
    """
    Single-step trainer for the conditional VAE generator and the PatchGAN discriminator.

    It does not own a data-loading loop: callers iterate their DataLoader and call
    ``train_step`` per batch. ``generate_opposite_label`` performs label-swap inference.

    Generator objective:
        lambda_l1 * L1(recon, x) + lambda_kl * KL + lambda_adv * BCE(D(recon), real)
        + lambda_perc * perceptual(recon, x)
    Discriminator objective: mean of BCE on real (target 1) and detached fakes (target 0).
    """
    def __init__(self,
                 device: torch.device,
                 model: ConditionalVAE_UNet,
                 discriminator: PatchDiscriminator,
                 perceptual_loss_fn: PerceptualLoss,
                 lr_g: float = 2e-4,
                 lr_d: float = 2e-4,
                 lambda_adv: float = 0.5,
                 lambda_perc: float = 1.0,
                 lambda_l1: float = 10.0,
                 lambda_kl: float = 1.0):
        self.device = device
        self.model = model.to(device)
        self.discriminator = discriminator.to(device)
        self.perc = perceptual_loss_fn
        self.lambda_adv = lambda_adv
        self.lambda_perc = lambda_perc
        self.lambda_l1 = lambda_l1
        self.lambda_kl = lambda_kl

        self.optim_g = torch.optim.Adam(self.model.parameters(), lr=lr_g, betas=(0.5, 0.999))
        self.optim_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr_d, betas=(0.5, 0.999))
        # adversarial criterion
        self.adv_loss_fn = nn.BCEWithLogitsLoss()
        # pixel loss
        self.l1 = nn.L1Loss()

    def _as_condition(self, labels: torch.Tensor, batch_size: int, dtype: torch.dtype) -> torch.Tensor:
        """
        Return ``labels`` on ``self.device`` as a ``(B, 1)`` floating tensor.

        DataLoaders over datasets that return Python ints collate labels to int64 ``(B,)``,
        while the decoder's label embedding is an ``nn.Linear`` that needs floats.
        """
        return labels.to(device=self.device, dtype=dtype).reshape(batch_size, 1)

    def train_step(self, imgs: torch.Tensor, labels: torch.Tensor):
        """
        Single training step (one discriminator update, then one generator update).

        Args:
            imgs: (B, 1, H, W) normalized to [-1, 1] (the decoder ends in Tanh).
            labels: (B,) or (B, 1), 0/1; integer labels are cast to float.

        Returns:
            dict of scalar losses: d_loss, g_loss, recon_loss, adv_loss, perc_loss, kl.
        """
        B = imgs.shape[0]
        imgs = imgs.to(self.device)
        labels = self._as_condition(labels, B, imgs.dtype)

        # ---- generator forward (VAE) ----
        self.model.train()
        recon_imgs, mu, logvar, _ = self.model(imgs, labels, sample=True)

        # ---- discriminator update ----
        self.discriminator.train()
        real_logits = self.discriminator(imgs, labels)
        d_loss_real = self.adv_loss_fn(real_logits, torch.ones_like(real_logits))

        # detach: the discriminator update must not backpropagate into the generator
        fake_logits = self.discriminator(recon_imgs.detach(), labels)
        d_loss_fake = self.adv_loss_fn(fake_logits, torch.zeros_like(fake_logits))

        d_loss = (d_loss_real + d_loss_fake) * 0.5
        self.optim_d.zero_grad()
        d_loss.backward()
        self.optim_d.step()

        # ---- generator update ----
        # Non-saturating GAN loss: the generator wants D to call its samples real. The backward
        # below also fills D's .grad; those are cleared by optim_d.zero_grad() next step.
        adv_logits = self.discriminator(recon_imgs, labels)
        adv_loss = self.adv_loss_fn(adv_logits, torch.ones_like(adv_logits))

        # pixel reconstruction (L1) — imgs and recon_imgs expected in [-1, 1]
        recon_loss = self.l1(recon_imgs, imgs)

        # NOTE: the perceptual extractor must receive inputs normalized the way it expects
        perc_loss = self.perc(recon_imgs, imgs)

        kld = kl_divergence(mu, logvar)

        g_loss = (self.lambda_l1 * recon_loss) + (self.lambda_kl * kld) + (self.lambda_adv * adv_loss) + (self.lambda_perc * perc_loss)

        self.optim_g.zero_grad()
        g_loss.backward()
        self.optim_g.step()

        return {
            "d_loss": d_loss.item(),
            "g_loss": g_loss.item(),
            "recon_loss": recon_loss.item(),
            "adv_loss": adv_loss.item(),
            "perc_loss": perc_loss.item(),
            "kl": kld.item()
        }

    def generate_opposite_label(self, imgs: torch.Tensor, labels: torch.Tensor, deterministic=True):
        """
        Encode ``imgs`` and decode them with the opposite (1 - label) condition.

        Args:
            imgs: (B, 1, H, W) in [-1, 1].
            labels: (B,) or (B, 1) binary 0/1 labels.
            deterministic: use ``mu`` as the latent (reproducible) instead of sampling z.

        Returns:
            (B, 1, H, W) generated images in [-1, 1] (Tanh output).
        """
        imgs = imgs.to(self.device)
        labels = self._as_condition(labels, imgs.shape[0], imgs.dtype)
        swapped = 1.0 - labels  # assumes binary 0/1 labels
        with torch.no_grad():
            feat, skips = self.model.encoder(imgs)
            mu, logvar = self.model.latent_map.encode(feat)
            z = mu if deterministic else self.model.latent_map.reparameterize(mu, logvar)
            feat_z = self.model.latent_map.decode_latent_to_feat(z)
            gen_swap = self.model.decoder(feat_z, skips, swapped)
        return gen_swap
print("done")

done


In [17]:

# --------------------------------------
# Evaluation: generate opposite label images for a full dataset and compute confusion matrix using classifier
# --------------------------------------
def evaluate_translations_and_confusion(classifier_model, preprocess_fn, cvae_model: ConditionalVAE_UNet,
                                         dataloader, device: torch.device,
                                         label_map: Optional[Dict[int, str]] = None):
    """
    Translate every example to the opposite label and classify the result.

    Each batch is encoded deterministically (z = mu) and decoded with ``1 - label``. The
    classifier then runs on the generated images via ``classify_batch``.

    Args:
        dataloader: yields ``(imgs, labels, ...)``; only the first two items are used.
        label_map: display names for classes 0 and 1 (default normal / pneumonia).

    Returns:
        dict with
          - "y_true": (N,) original labels
          - "y_target_swapped": (N,) labels the generated images were conditioned on (1 - y_true)
          - "y_pred_swapped": (N,) classifier predictions on the swapped images
          - "probs_swapped": (N,) classifier class-1 probability on the swapped images
          - "confusion_matrix" / "confusion_display": y_true vs. y_pred_swapped, labels [0, 1]
          - "images_orig" / "images_swapped": stacked numpy images
    """
    if label_map is None:
        label_map = {0: "normal", 1: "pneumonia"}
    cvae_model.eval()
    ys_true = []
    ys_pred_swapped = []
    probs_swapped = []
    images_orig = []
    images_swapped = []
    for batch in dataloader:
        imgs = batch[0].to(device)
        labels = batch[1].to(device)
        with torch.no_grad():
            # encode -> deterministic latent -> decode with the swapped label
            feat, skips = cvae_model.encoder(imgs)
            mu, _ = cvae_model.latent_map.encode(feat)
            feat_z = cvae_model.latent_map.decode_latent_to_feat(mu)
            # float (B, 1) condition; DataLoader labels are commonly int64 (B,)
            swapped = 1.0 - labels.to(imgs.dtype).reshape(imgs.shape[0], 1)
            gen_swapped = cvae_model.decoder(feat_z, skips, swapped)
        # classify_batch applies preprocess_fn itself; no extra CPU round-trip needed
        preds, probs = classify_batch(classifier_model, preprocess_fn, gen_swapped, device)
        ys_true.append(labels.cpu().numpy().ravel())
        ys_pred_swapped.append(preds)
        probs_swapped.append(probs)
        images_orig.append(imgs.cpu().numpy())
        images_swapped.append(gen_swapped.cpu().numpy())

    y_true = np.concatenate(ys_true, axis=0)
    y_pred = np.concatenate(ys_pred_swapped, axis=0)
    probs = np.concatenate(probs_swapped, axis=0)
    images_orig = np.concatenate(images_orig, axis=0)
    images_swapped = np.concatenate(images_swapped, axis=0)

    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=[label_map[0], label_map[1]])
    return {
        "y_true": y_true,
        "y_target_swapped": 1 - y_true,
        "y_pred_swapped": y_pred,
        "probs_swapped": probs,
        "confusion_matrix": cm,
        "confusion_display": disp,
        "images_orig": images_orig,
        "images_swapped": images_swapped
    }

print("done")

done


In [18]:

# --------------------------------------
# SSIM evaluation wrapper
# --------------------------------------
def evaluate_self_reconstruction_ssim(cvae_model: ConditionalVAE_UNet, dataloader, device: torch.device):
    """
    Reconstruct every image under its own label (z = mu) and compute SSIM vs. the original.

    Images are assumed to be in [-1, 1] and are rescaled to [0, 1] before scoring; the
    singleton channel axis is dropped.

    Returns:
        (ssim_dict, origs, gens, labs): the dict from ``compute_ssim_batch_numpy``, the
        rescaled originals and reconstructions, and the flattened original labels.
    """
    cvae_model.eval()
    origs = []
    gens = []
    labs = []
    for batch in dataloader:
        imgs = batch[0].to(device)
        labels = batch[1].to(device)
        # float (B, 1) condition; DataLoader labels are commonly int64 (B,)
        cond = labels.to(imgs.dtype).reshape(imgs.shape[0], 1)
        with torch.no_grad():
            recon, _, _, _ = cvae_model(imgs, cond, sample=False)
        # move to cpu numpy and scale [-1, 1] -> [0, 1]
        origs.append(((imgs.cpu().numpy() + 1.0) / 2.0).squeeze(1))
        gens.append(((recon.cpu().numpy() + 1.0) / 2.0).squeeze(1))
        labs.append(labels.cpu().numpy().ravel())
    origs = np.concatenate(origs, axis=0)
    gens = np.concatenate(gens, axis=0)
    labs = np.concatenate(labs, axis=0)
    ssim_dict = compute_ssim_batch_numpy(origs, gens, labels=labs)
    return ssim_dict, origs, gens, labs


# --------------------------------------
# Example of how to plot confusion matrix (user calls .plot())
# --------------------------------------
def plot_confusion_matrix_from_eval(eval_res):
    """
    Render the ``ConfusionMatrixDisplay`` stored under ``eval_res["confusion_display"]``
    (as returned by ``evaluate_translations_and_confusion``) on a new 6x6 figure.
    """
    display_obj = eval_res["confusion_display"]
    _, axis = plt.subplots(figsize=(6, 6))
    display_obj.plot(ax=axis)
    plt.title("Confusion Matrix for Classifier on Swapped-label Generated Images")
    plt.show()


# -------------------------------------------------------------------------------------------
# End of module
# -------------------------------------------------------------------------------------------
print("done")

done


In [19]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from google.cloud import storage
from PIL import Image
import io
import torch
import torch.nn as nn
import torch.optim as optim

# -------------------------------
# Dataset class (Option B style with checks)
# -------------------------------
class GCSChestXrayDataset(Dataset):
    """Grayscale chest X-ray images streamed from a Google Cloud Storage bucket.

    At construction, every object under ``gs://<bucket_name>/<prefix>`` is listed.
    An object is indexed if it has an image extension (``.jpeg``/``.jpg``/``.png``,
    any case) and its path has a ``NORMAL`` or ``PNEUMONIA`` directory component.
    Image bytes are downloaded lazily in ``__getitem__``.

    Labels: 0 = NORMAL, 1 = PNEUMONIA. Objects matching neither are skipped.
    """

    _IMAGE_EXTENSIONS = (".jpeg", ".jpg", ".png")
    # (upper-cased directory name, class index); checked in this order.
    _CLASS_DIRS = (("NORMAL", 0), ("PNEUMONIA", 1))

    def __init__(self, project_id, bucket_name, prefix, transform=None):
        """
        Args:
            project_id: GCP project used to create the storage client.
            bucket_name: bucket holding the images.
            prefix: object-name prefix selecting one split,
                e.g. "chest_xray/chest_xray/train/".
            transform: optional callable applied to each PIL image (mode "L").
        """
        self.transform = transform
        self.samples = []

        client = storage.Client(project=project_id)
        bucket = client.bucket(bucket_name)

        print(f"Listing objects from gs://{bucket_name}/{prefix} ...")

        for blob in bucket.list_blobs(prefix=prefix):
            if not self._is_image(blob.name):
                continue
            label = self._infer_label(blob.name)
            if label is None:
                continue
            self.samples.append((blob.name, label))

        print(f"✅ Found {len(self.samples)} samples under {prefix}")

        # NOTE(review): this bucket handle (and its client) is shared with DataLoader
        # worker processes when num_workers > 0; that relies on the client surviving
        # fork/pickling - confirm, or create the client lazily per worker.
        self.bucket = bucket

    @classmethod
    def _is_image(cls, blob_name):
        """True if ``blob_name`` ends with a supported image extension (case-insensitive)."""
        return blob_name.lower().endswith(cls._IMAGE_EXTENSIONS)

    @classmethod
    def _infer_label(cls, blob_name):
        """Return the class index from the object's directory components, or None.

        Only directory components (not the filename) are compared, exactly and
        case-insensitively, so e.g. an "ABNORMAL" folder or a file named
        "NORMAL_scan.jpeg" in an unrelated folder does not produce a label.
        """
        dirs = blob_name.upper().split("/")[:-1]
        for dir_name, label in cls._CLASS_DIRS:
            if dir_name in dirs:
                return label
        return None

    def __len__(self):
        """Number of indexed image objects."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Download sample ``idx`` and return ``(image, label)``.

        The image is converted to single-channel "L" mode and passed through
        ``self.transform`` when one was given.
        """
        blob_name, label = self.samples[idx]
        img_bytes = self.bucket.blob(blob_name).download_as_bytes()
        # The context manager closes the source image; convert() returns a new,
        # fully loaded image, so it remains valid after the with-block.
        with Image.open(io.BytesIO(img_bytes)) as raw:
            img = raw.convert("L")

        if self.transform:
            img = self.transform(img)
        return img, label

# -------------------------------
# Transforms
# -------------------------------
# Each grayscale PIL image becomes a 1x224x224 float tensor. ToTensor() yields
# values in [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1], the range that
# evaluate_self_reconstruction_ssim assumes by default when rescaling with (x + 1) / 2.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# -------------------------------
# Load datasets: one GCSChestXrayDataset per split (objects are listed eagerly)
# -------------------------------
project_id = "pneumonia-generator"
bucket_name = "rsna-pneumonia-x-rays"

# Splits live under chest_xray/chest_xray/<split>/ in the bucket. The generator is
# consumed left to right by the unpacking, so listing order is train, val, test.
train_dataset, val_dataset, test_dataset = (
    GCSChestXrayDataset(project_id, bucket_name, prefix=f"chest_xray/chest_xray/{split}/", transform=transform)
    for split in ("train", "val", "test")
)

# -------------------------------
# Debug: report split sizes
# -------------------------------
# NOTE: in the recorded run the val split holds only 16 images, so metrics
# computed on it are very noisy.
print(f"Train size: {len(train_dataset)}")
print(f"Val size: {len(val_dataset)}")
print(f"Test size: {len(test_dataset)}")

# -------------------------------
# Create loaders only if non-empty
# -------------------------------
def safe_loader(dataset, batch_size, shuffle, num_workers=2, **loader_kwargs):
    """Build a DataLoader for ``dataset``, or return None if it is empty.

    Args:
        dataset: a sized dataset (``len()`` must work).
        batch_size: samples per batch.
        shuffle: reshuffle the samples every epoch.
        num_workers: worker processes used for loading (default 2, as before).
        **loader_kwargs: forwarded unchanged to ``DataLoader``
            (e.g. ``pin_memory``, ``drop_last``).

    Returns:
        A ``DataLoader``, or None when ``len(dataset) == 0``; callers must check.
    """
    if len(dataset) == 0:
        print("⚠️ WARNING: Empty dataset, loader will be None")
        return None
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                      num_workers=num_workers, **loader_kwargs)

# Any of these is None when its split is empty (see safe_loader).
train_loader, val_loader, test_loader = (
    safe_loader(train_dataset, batch_size=32, shuffle=True),   # reshuffled each epoch
    safe_loader(val_dataset, batch_size=32, shuffle=False),    # fixed order for evaluation
    safe_loader(test_dataset, batch_size=32, shuffle=False),
)
print("done")

Listing objects from gs://rsna-pneumonia-x-rays/chest_xray/chest_xray/train/ ...
✅ Found 5216 samples under chest_xray/chest_xray/train/
Listing objects from gs://rsna-pneumonia-x-rays/chest_xray/chest_xray/val/ ...
✅ Found 16 samples under chest_xray/chest_xray/val/
Listing objects from gs://rsna-pneumonia-x-rays/chest_xray/chest_xray/test/ ...
✅ Found 624 samples under chest_xray/chest_xray/test/
Train size: 5216
Val size: 16
Test size: 624
done


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate model, discriminator, perceptual loss
cvae_model = ConditionalVAE_UNet(device=device)
# NOTE(review): unlike the other two, the discriminator is not given `device`
# here - presumably CVAEGANTrainer moves it; confirm.
discriminator = PatchDiscriminator()
perc_loss_fn = PerceptualLoss(device=device)

trainer = CVAEGANTrainer(device, cvae_model, discriminator, perc_loss_fn)

num_epochs = 5

if train_loader is None:
    raise RuntimeError("train_loader is None (empty training split) - nothing to train on.")

print("Starting Training...")
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    # Pre-seed the expected keys so they print in a stable order; any extra key
    # returned by train_step is collected too instead of raising KeyError.
    epoch_losses = {k: [] for k in ("d_loss", "g_loss", "recon_loss", "adv_loss", "perc_loss", "kl")}

    cvae_model.train()
    for imgs, labels in train_loader:
        labels = labels.unsqueeze(1).float()  # (B,) -> (B, 1) float, the form fed to train_step
        loss_dict = trainer.train_step(imgs, labels)
        for k, v in loss_dict.items():
            epoch_losses.setdefault(k, []).append(v)

    # float() keeps the printed dict readable (numpy >= 2 reprs np.float64 as
    # "np.float64(...)"); keys never reported this epoch show nan without a warning.
    avg_losses = {k: (float(np.mean(v)) if v else float("nan")) for k, v in epoch_losses.items()}
    print(f"Avg losses: {avg_losses}")

print("Training Done")



Starting Training...
Epoch 1/5


In [None]:
# 2a. SSIM self-reconstruction: rebuild each validation image with its own label
ssim_res, origs, gens, labs = evaluate_self_reconstruction_ssim(cvae_model, val_loader, device)
print("SSIM Results:")
for metric_name, scores in ssim_res.items():
    score_mean, score_std = np.mean(scores), np.std(scores)
    print(f"{metric_name}: mean={score_mean:.4f}, std={score_std:.4f}")

# 2b. Classifier-based check: translate images to the swapped label and see
#     how a pretrained pneumonia classifier labels the results
classifier_model, preprocess_fn = load_pretrained_pneumonia_classifier(device)
eval_res = evaluate_translations_and_confusion(classifier_model, preprocess_fn, cvae_model, val_loader, device)
plot_confusion_matrix_from_eval(eval_res);
