In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms
from torch.cuda.amp import autocast
import numpy as np
from scipy.linalg import sqrtm
import pandas as pd
from typing import Tuple, List
import os
from PIL import Image

In [2]:
# Pick the compute device once; the rest of the notebook reads `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    torch.cuda.set_device(0)  # make GPU 0 the current CUDA device
    print("Using GPU")
else:
    print("Using CPU")
class VGG16_MRI(nn.Module):
    """VGG16-BN classifier adapted to single-channel (grayscale) MRI images.

    torchvision's ``vgg16_bn`` convolutional trunk is reused, with its first
    convolution replaced by a freshly initialised 1-input-channel layer (so the
    ImageNet weights of that one layer are discarded). The flattened trunk
    output is batch-normalised and mapped to class scores by a linear layer.

    Args:
        num_classes: Number of outputs of the final linear layer.
    """

    def __init__(self, num_classes=2):
        super(VGG16_MRI, self).__init__()
        model = torchvision.models.vgg16_bn(pretrained=True)
        # Grayscale stem: 1 input channel instead of RGB.
        model.features[0] = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.feature = model.features
        # Flattened trunk size. VGG16's trunk has five 2x max-pools, so a
        # 240x240 (or 224x224) input ends as a 512 x 7 x 7 map; inputs that
        # produce a different map size will not fit fc_layer.
        self.feat_dim = 512 * 7 * 7
        self.num_classes = num_classes

        # Batch normalization over the flattened features; the bias (shift)
        # is frozen so only the per-feature scale can be trained.
        self.bn = nn.BatchNorm1d(self.feat_dim)
        self.bn.bias.requires_grad_(False)  # no shift

        # Fully connected layer to map features to the number of classes
        self.fc_layer = nn.Linear(self.feat_dim, self.num_classes)

        # Registered as a submodule, so its parameters also appear under
        # ``model.*`` in state_dict(); removing it would change the key set
        # that a strict load_state_dict of existing checkpoints expects.
        self.model = model

    def forward(self, x):
        """Return ``(features, logits)`` for a batch of images.

        ``features`` is the flattened, batch-normalised trunk output of shape
        ``(B, feat_dim)``; ``logits`` has shape ``(B, num_classes)``.
        """
        feature = self.feature(x)
        feature = feature.view(feature.size(0), -1)  # Flatten the feature map
        feature = self.bn(feature)  # Apply batch normalization
        res = self.fc_layer(feature)  # Output class scores
        return feature, res

    def predict(self, x):
        """Return only the class scores (logits) for ``x``."""
        # Delegate to forward() so the two code paths cannot drift apart.
        _, res = self.forward(x)
        return res

Using GPU


In [3]:
class Generator(nn.Module):
    """DCGAN-style generator mapping a noise vector to a 1 x 240 x 240 image.

    A bias-free linear layer (+ BatchNorm + ReLU) projects ``in_dim`` noise to a
    ``dim * 8``-channel 15x15 map. Four stride-2 transposed convolutions then
    double the spatial size each time (15 -> 30 -> 60 -> 120 -> 240). A final
    sigmoid keeps pixel values in [0, 1].

    The attribute names ``l1`` / ``l2_5`` and the layer order inside them define
    the state_dict keys of saved checkpoints, so they must stay as they are.
    """

    def __init__(self, in_dim=100, dim=64):
        super(Generator, self).__init__()

        def up_block(c_in, c_out):
            # kernel 5, stride 2, padding 2, output_padding 1 -> exactly 2x upsampling.
            layers = [
                nn.ConvTranspose2d(c_in, c_out, 5, 2, padding=2, output_padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]
            return nn.Sequential(*layers)

        projected = dim * 8 * 15 * 15
        self.l1 = nn.Sequential(
            nn.Linear(in_dim, projected, bias=False),
            nn.BatchNorm1d(projected),
            nn.ReLU(),
        )

        widths = [dim * 8, dim * 4, dim * 2, dim]
        stages = [up_block(c_in, c_out) for c_in, c_out in zip(widths[:-1], widths[1:])]
        # Last upsampling step to a single channel (with bias, no BatchNorm).
        stages.append(nn.ConvTranspose2d(dim, 1, 5, 2, padding=2, output_padding=1))
        stages.append(nn.Sigmoid())
        self.l2_5 = nn.Sequential(*stages)

    def forward(self, x):
        projected = self.l1(x)
        grid = projected.view(projected.size(0), -1, 15, 15)  # (B, dim*8, 15, 15)
        return self.l2_5(grid)

In [4]:
def load_pretrained_classifier(path="/kaggle/input/brats23-classifier/pytorch/default/1/classifier.pt"):
    """Build a ``VGG16_MRI`` and load the trained classifier weights.

    Args:
        path: Checkpoint file holding the model's state_dict.

    Returns:
        The classifier in eval mode, with parameters on the CPU (callers move
        it to the target device).
    """
    model = VGG16_MRI(num_classes=2)
    # map_location="cpu": a checkpoint saved from a GPU still loads on a
    # CPU-only host; load_state_dict copies into the CPU parameters anyway.
    state_dict = torch.load(path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    return model
    
def get_augmodel():
    """Return the pretrained target classifier wrapped in ``DataParallel``.

    The model is placed on CUDA when available, otherwise on the CPU.
    """
    model = load_pretrained_classifier()
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # .to(target) instead of .cuda(): same placement on GPU hosts, but does
    # not crash when no GPU is present.
    model = torch.nn.DataParallel(model).to(target)
    return model
    
def get_GAN(
    root_path="/kaggle/input/brats23-gan-epoch75/pytorch/default/1/attack_results",
    dataset_name="BraTS23",
    model_name_T="VGG16_MRI",
):
    """Build the generator and load its epoch-75 checkpoint.

    The checkpoint is read from
    ``<root_path>/<dataset_name>/<model_name_T>/ep75_improved_<dataset_name>_G.pt``
    and is expected to be a dict whose ``'state_dict'`` entry matches a
    ``DataParallel``-wrapped ``Generator`` (loaded with ``strict=True``).

    Returns:
        The ``DataParallel``-wrapped generator on CUDA if available, else CPU.
    """
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    G = torch.nn.DataParallel(Generator(100)).to(target)

    path_G = os.path.join(
        root_path, dataset_name, model_name_T,
        "ep75_improved_{}_G.pt".format(dataset_name),
    )
    # map_location lets a GPU-saved checkpoint load on a CPU-only host.
    ckp_G = torch.load(path_G, map_location=target)
    G.load_state_dict(ckp_G['state_dict'], strict=True)
    print("Loaded Pretrained Model (Specific GAN)")
    return G

In [5]:
class MRIDataset(Dataset):
    """Dataset for loading grayscale MRI images with labels.

    Args:
        df: Table with a ``filename`` column (path relative to ``data_dir``)
            and a ``label`` column (castable to ``int``).
        data_dir: Directory the filenames are resolved against.
        transform: Optional callable applied to the grayscale PIL image
            (e.g. a torchvision pipeline that returns a tensor).
    """

    def __init__(self, df: pd.DataFrame, data_dir: str, transform=None):
        self.df = df
        self.data_dir = data_dir
        self.transform = transform

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        # Look the row up once instead of indexing the frame twice.
        row = self.df.iloc[idx]
        img_path = os.path.join(self.data_dir, row['filename'])
        # Context manager releases the file handle deterministically; Pillow
        # does not auto-close it for every format after load().
        with Image.open(img_path) as img:
            image = img.convert('L')  # Convert to grayscale
        label = int(row['label'])

        if self.transform:
            image = self.transform(image)

        return image, label

class FakeDataset(Dataset):
    """Dataset wrapper for generated grayscale images.

    The tensor is detached from the autograd graph and moved to the CPU once,
    at construction, so each item is a plain CPU tensor slice.
    """

    def __init__(self, images: torch.Tensor):
        self.images = images.detach().cpu()

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, idx: int) -> torch.Tensor:
        return self.images[idx]

def extract_features(
    data_loader: DataLoader,
    model: nn.Module,
    device: torch.device,
    is_fake: bool = False
) -> np.ndarray:
    """Extract features from grayscale images using the provided model.

    Args:
        data_loader: Yields bare image tensors when ``is_fake`` is True,
            otherwise ``(images, labels)`` batches.
        model: Module whose forward returns ``(features, logits)``; only the
            features are kept.
        device: Device the images are moved to before the forward pass.
        is_fake: Whether batches are bare tensors rather than tuples.

    Returns:
        ``float32`` array obtained by concatenating the per-batch features
        along axis 0 (one row per image, assuming the model returns 2-D
        ``(B, D)`` features).
    """
    device = torch.device(device)
    # Mixed precision only on CUDA: the deprecated torch.cuda.amp.autocast
    # just warns and disables itself on CPU.
    use_amp = device.type == "cuda"
    model.eval()
    features = []

    with torch.no_grad():
        for batch in data_loader:
            images = batch if is_fake else batch[0]

            # Ensure images are in the correct format (B, 1, H, W)
            if images.dim() == 3:
                images = images.unsqueeze(1)

            images = images.to(device, non_blocking=True)

            with torch.autocast(device_type=device.type, enabled=use_amp):
                feature, _ = model(images)
            # Autocast can yield fp16; upcast so downstream statistics are float32.
            features.append(feature.float().cpu().numpy())

    return np.concatenate(features, axis=0)

def calculate_fid(real_features: np.ndarray, fake_features: np.ndarray, eps: float = 1e-6) -> float:
    """Calculate the Fréchet Inception Distance between real and fake features.

    FID = ||mu_r - mu_f||^2 + Tr(S_r + S_f - 2 (S_r S_f)^(1/2))

    Args:
        real_features: ``(N_real, D)`` matrix, one sample per row (a 1-D array
            is treated as a single feature).
        fake_features: ``(N_fake, D)`` matrix with the same ``D``.
        eps: Diagonal offset added to both covariances when the matrix square
            root is not finite (near-singular covariances, e.g. N < D).

    Returns:
        The FID as a Python float.

    Raises:
        ValueError: If either set has fewer than two samples (no covariance can
            be estimated) or the feature dimensions differ.
    """
    # float64 for the statistics: inputs may be fp16 features.
    real = np.asarray(real_features, dtype=np.float64)
    fake = np.asarray(fake_features, dtype=np.float64)
    if real.ndim == 1:
        real = real[:, None]
    if fake.ndim == 1:
        fake = fake[:, None]
    if real.shape[0] < 2 or fake.shape[0] < 2:
        raise ValueError("FID needs at least two samples per set to estimate a covariance")
    if real.shape[1] != fake.shape[1]:
        raise ValueError(
            f"feature dimensions differ: {real.shape[1]} (real) vs {fake.shape[1]} (fake)"
        )

    mu_real = real.mean(axis=0)
    mu_fake = fake.mean(axis=0)
    # atleast_2d: np.cov returns a 0-d array when there is a single feature.
    sigma_real = np.atleast_2d(np.cov(real, rowvar=False))
    sigma_fake = np.atleast_2d(np.cov(fake, rowvar=False))

    diff = mu_real - mu_fake
    covmean = sqrtm(sigma_real.dot(sigma_fake))
    if not np.isfinite(covmean).all():
        # Singular product: regularise both covariances and retry.
        offset = np.eye(sigma_real.shape[0]) * eps
        covmean = sqrtm((sigma_real + offset).dot(sigma_fake + offset))

    # Numerical error can leave a spurious imaginary component; keep the real part.
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = diff.dot(diff) + np.trace(sigma_real) + np.trace(sigma_fake) - 2.0 * np.trace(covmean)
    return float(fid)

def generate_images(
    generator: nn.Module,
    n_generated: int,
    z_dim: int,
    device: torch.device,
    batch_size: int = 32
) -> torch.Tensor:
    """Generate fake grayscale images in batches.

    Args:
        generator: Module mapping ``(B, z_dim)`` standard-normal noise to images
            of shape ``(B, C, H, W)``.
        n_generated: Total number of images to produce (> 0).
        z_dim: Noise dimensionality.
        device: Device the noise is sampled on and the generator runs on.
        batch_size: Images generated per forward pass (> 0).

    Returns:
        CPU ``float32`` tensor ``(n_generated, 1, H, W)``; multi-channel
        outputs are averaged into a single channel.

    Raises:
        ValueError: If ``n_generated`` or ``batch_size`` is not positive.
    """
    if n_generated <= 0:
        raise ValueError(f"n_generated must be positive, got {n_generated}")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    device = torch.device(device)
    use_amp = device.type == "cuda"  # mixed precision only where it is supported
    generator.eval()
    images = []

    with torch.no_grad():
        for start in range(0, n_generated, batch_size):
            n_batch = min(batch_size, n_generated - start)
            noise = torch.randn(n_batch, z_dim, device=device)

            with torch.autocast(device_type=device.type, enabled=use_amp):
                fake_batch = generator(noise)
            # Upcast (autocast may yield fp16) and move off the device so that
            # accumulated batches do not hold accelerator memory.
            fake_batch = fake_batch.float().cpu()
            # Ensure output is single-channel grayscale
            if fake_batch.size(1) != 1:
                fake_batch = fake_batch.mean(dim=1, keepdim=True)
            images.append(fake_batch)

    return torch.cat(images)

def compute_fid_score(
    classifier: nn.Module,
    generator: nn.Module,
    real_dataset: Dataset,
    batch_size: int,
    n_generated: int,
    z_dim: int,
    device: torch.device
) -> float:
    """Compute FID score between real and generated grayscale images.

    Features for the whole ``real_dataset`` are extracted once. Fake images
    are generated and featurised chunk by chunk (chunks of at most 1000 images,
    or half of ``n_generated`` if smaller), so only one chunk of images exists
    at a time.

    Args:
        classifier: Feature extractor whose forward returns ``(features, logits)``.
        generator: Maps ``(B, z_dim)`` noise to images.
        real_dataset: Yields ``(image, label)`` pairs.
        batch_size: Batch size used for feature extraction.
        n_generated: Number of fake images to generate (at least 2).
        z_dim: Noise dimensionality.
        device: Device used for generation and feature extraction.

    Returns:
        The FID between real and generated feature statistics.

    Raises:
        ValueError: If ``n_generated < 2`` (a covariance needs two samples).
    """
    if n_generated < 2:
        # Also guards chunk_size below: n_generated == 1 used to give a
        # chunk size of 0 and crash range() with a zero step.
        raise ValueError(f"n_generated must be at least 2, got {n_generated}")

    device = torch.device(device)
    # Page-locked host memory only helps host -> GPU copies.
    pin_memory = device.type == "cuda"

    # Mean/covariance do not depend on sample order, so no shuffling.
    real_loader = DataLoader(
        real_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=pin_memory
    )
    real_features = extract_features(real_loader, classifier, device)

    chunk_size = min(n_generated // 2, 1000)  # >= 1 thanks to the guard above
    fake_features_chunks = []

    for start in range(0, n_generated, chunk_size):
        chunk_size_i = min(chunk_size, n_generated - start)
        fake_images = generate_images(generator, chunk_size_i, z_dim, device)

        fake_loader = DataLoader(
            FakeDataset(fake_images),
            batch_size=batch_size,
            num_workers=2,
            pin_memory=pin_memory
        )
        fake_features_chunks.append(
            extract_features(fake_loader, classifier, device, is_fake=True)
        )

    fake_features = np.concatenate(fake_features_chunks, axis=0)
    return calculate_fid(real_features, fake_features)

def main():
    """Evaluate the pretrained generator by its FID against the real images.

    Loads the real images listed in ``csv_path``, the pretrained classifier
    (used as the feature extractor) and the pretrained generator, then prints
    the FID score.
    """
    # Configuration
    config = {
        'batch_size': 1,
        'n_generated': 10,
        'z_dim': 100,
        'data_dir': "/kaggle/input/preprocessed-brats23/Images",
        'csv_path': "/kaggle/input/preprocessed-brats23/labels.csv",
        'image_size': 240,
        'seed': 0,
    }
    # Seed the noise sampler so repeated runs score the same generated images.
    torch.manual_seed(config['seed'])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Data preprocessing for grayscale images: resize, then [0, 1] tensor
    # (same value range as the generator's sigmoid output).
    transform = transforms.Compose([
        transforms.Resize((config['image_size'], config['image_size'])),
        transforms.ToTensor()
    ])

    # Load dataset
    df = pd.read_csv(config['csv_path'])
    real_dataset = MRIDataset(df=df, data_dir=config['data_dir'], transform=transform)

    # Load models (the .to() calls are no-ops if the loaders already placed them).
    classifier = get_augmodel()
    generator = get_GAN()
    classifier.to(device)
    generator.to(device)

    # NOTE(review): if the classifier is VGG16_MRI, its features are
    # 512*7*7 = 25088-dimensional. With only n_generated=10 samples the fake
    # covariance is rank-deficient (the FID is not statistically meaningful),
    # and calculate_fid must take the square root of a 25088 x 25088 matrix,
    # which is very slow and memory-hungry. Consider many more samples and/or a
    # lower-dimensional feature layer.
    fid_score = compute_fid_score(
        classifier=classifier,
        generator=generator,
        real_dataset=real_dataset,
        batch_size=config['batch_size'],
        n_generated=config['n_generated'],
        z_dim=config['z_dim'],
        device=device
    )

    print(f"FID Score: {fid_score:.4f}")

In [None]:
# Entry point: run the full FID evaluation defined in main() (loads data and models, prints the score).
main()

Using device: cuda


Downloading: "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth" to /root/.cache/torch/hub/checkpoints/vgg16_bn-6c64b313.pth
100%|██████████| 528M/528M [00:02<00:00, 211MB/s] 
  model.load_state_dict(torch.load(path))
  ckp_G = torch.load(path_G)


Loaded Pretrained Model (Specific GAN)


  with autocast():
  with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):
