In [181]:
# Install every third-party package this notebook imports, in a single step.
# %pip (rather than !pip) installs into the environment of the running kernel.
# numpy and torchsummary are imported further down but were not installed before.
%pip install torch torchvision pillow matplotlib tqdm numpy torchsummary




In [182]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.optim import Adam
from tqdm.notebook import tqdm
from torch import optim


In [183]:
import os

# Dataset root: later cells read `train_labels.csv` and the `train/` image folder from it.
# Set the PCAM_DIRECTORY environment variable to point at another copy of the data;
# the previously hard-coded local path is kept as the default.
pcam_directory = os.environ.get(
    'PCAM_DIRECTORY',
    '/Users/costanzasiniscalchi/Documents/Senior/ACV/project/histopathologic-cancer-detection',
)

In [184]:
class CGanGenerator(nn.Module):
    """Label-conditioned image-to-image generator.

    ``forward`` receives a batch of 3-channel images and one integer class label
    per image and returns a batch with ``output_channels`` channels and the SAME
    spatial size as the input, squashed to [-1, 1] by a final Tanh.

    The previous version used four stride-2 transposed convolutions, each of
    which doubles height and width, so a 96x96 input came out as 1536x1536 and
    could not be compared with 96x96 targets or fed to the discriminator. The
    blocks now use kernel 3 / stride 1 / padding 1, which keeps the size.

    Args:
        n_classes: number of distinct labels accepted by the embedding.
        output_channels: channels of the generated image.
        base_channels: width of the first feature map; later blocks use
            ``base_channels // 2``, ``// 4`` and ``// 8``. Must be at least 8.
        image_size: height and width of the (square) input images.

    Raises:
        ValueError: if ``base_channels`` is smaller than 8.
    """

    def __init__(self, n_classes=2, output_channels=3, base_channels=1024, image_size=96):
        super().__init__()
        if base_channels < 8:
            raise ValueError("base_channels must be >= 8 so that every block has at least one channel")
        self.n_classes = n_classes
        self.output_channels = output_channels
        self.image_size = image_size
        # One learned image_size x image_size plane per class; it is appended to
        # the feature map as a single extra channel in forward().
        self.embedding = nn.Embedding(n_classes, image_size * image_size)
        self.conv1 = nn.Conv2d(3, base_channels, kernel_size=3, stride=1, padding=1)
        self.conv_blocks = nn.ModuleList([
            # +1 input channel for the label plane.
            self._get_conv_transpose_block(base_channels + 1, base_channels // 2),
            self._get_conv_transpose_block(base_channels // 2, base_channels // 4),
            self._get_conv_transpose_block(base_channels // 4, base_channels // 8),
            self._get_conv_transpose_block(base_channels // 8, output_channels, last_block=True)
        ])

    def _get_conv_transpose_block(self, in_channels, out_channels, last_block=False):
        """Build ConvTranspose2d -> BatchNorm2d -> ReLU (Tanh for the last block)."""
        layers = [
            # kernel 3, stride 1, padding 1: output height/width equal the input's.
            nn.ConvTranspose2d(in_channels, out_channels, 3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        ]
        if last_block:
            layers[-1] = nn.Tanh()
        return nn.Sequential(*layers)

    def forward(self, x, labels):
        """Generate images from ``x`` conditioned on ``labels``.

        Args:
            x: tensor of shape (batch, 3, image_size, image_size).
            labels: tensor of shape (batch,) with values in [0, n_classes).

        Returns:
            Tensor of shape (batch, output_channels, image_size, image_size).
        """
        labels = labels.long()
        x = F.relu(self.conv1(x))
        # Reshape each label embedding into a one-channel image plane.
        y = self.embedding(labels).view(-1, 1, self.image_size, self.image_size)
        x = torch.cat([x, y], 1)
        for block in self.conv_blocks:
            x = block(x)
        return x


In [185]:
class CGanDiscriminator(nn.Module):
    """Convolutional discriminator with a real/fake head and a class head.

    Five stride-2 convolution blocks reduce the image; the flattened features
    feed two linear heads:

    * ``classifier``  -> sigmoid probability that the image is real, shape (batch, 1)
    * ``class_label`` -> softmax over ``n_classes``, shape (batch, n_classes)

    Args:
        input_shape: (channels, height, width) of the images to be scored.
        n_classes: number of classes predicted by the auxiliary head.
    """

    def __init__(self, input_shape=(3, 96, 96), n_classes=2):
        super().__init__()
        self.n_classes = n_classes
        self.conv_blocks = nn.ModuleList([
            # Take the channel count from input_shape instead of assuming 3.
            self._get_conv_block(input_shape[0], 64, first_block=True),
            self._get_conv_block(64, 128),
            self._get_conv_block(128, 256),
            self._get_conv_block(256, 512),
            self._get_conv_block(512, 512)
        ])
        # Probe the flattened feature size with a dummy image. The probe runs in
        # eval mode and without autograd so that it neither updates the BatchNorm
        # running statistics nor fails when the last feature map is 1x1
        # (BatchNorm in training mode needs more than one value per channel).
        was_training = self.training
        self.eval()
        with torch.no_grad():
            output_size = self.forward_features(torch.zeros(1, *input_shape)).numel()
        self.train(was_training)
        print("Output size for linear layer:", output_size)  # Debug print

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(output_size, 1),
            nn.Sigmoid()
        )
        self.class_label = nn.Sequential(
            nn.Flatten(),
            nn.Linear(output_size, n_classes),
            nn.Softmax(dim=1)
        )

    def _get_conv_block(self, in_channels, out_channels, first_block=False):
        """Build Conv2d(stride 2) -> [BatchNorm2d] -> LeakyReLU -> Dropout.

        BatchNorm is omitted for the first block.
        """
        layers = [
            nn.Conv2d(in_channels, out_channels, 3, stride=2, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.Dropout(0.5)
        ]
        if not first_block:
            layers.insert(1, nn.BatchNorm2d(out_channels))
        return nn.Sequential(*layers)

    def forward_features(self, x):
        """Run ``x`` through all convolution blocks and return the feature map."""
        for block in self.conv_blocks:
            x = block(x)
        return x

    def forward(self, x):
        """Return ``(real_fake, class_label)`` predictions for the image batch ``x``."""
        x = self.forward_features(x)
        real_fake = self.classifier(x)
        class_label = self.class_label(x)
        return real_fake, class_label


In [186]:
class CGan(nn.Module):
    """Container that chains a generator and a discriminator.

    Args:
        generator: module called as ``generator(noise, labels)``.
        discriminator: module called as ``discriminator(images)`` and returning
            a ``(validity, class_prediction)`` pair.
    """

    def __init__(self, generator, discriminator):
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator

    def forward(self, noise, labels):
        """Generate images and return the discriminator's validity score for them.

        The discriminator is evaluated in eval mode. Its previous train/eval
        mode is restored afterwards, so calling this during training no longer
        leaves dropout and BatchNorm updates switched off for later steps.
        """
        was_training = self.discriminator.training
        self.discriminator.eval()
        try:
            gen_imgs = self.generator(noise, labels)
            validity, _ = self.discriminator(gen_imgs)
        finally:
            self.discriminator.train(was_training)
        return validity



In [187]:
import os
import torch
import numpy as np
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image


def get_transforms():
    """Return the preprocessing pipeline: resize to 96x96, convert to a tensor,
    then normalise every channel with mean 0.5 and std 0.5."""
    steps = []
    steps.append(transforms.Resize((96, 96)))
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]))
    return transforms.Compose(steps)

def generate_real_samples(batch, device='cpu'):
    """Unpack a three-element batch and attach "real" targets.

    The batch is expected as ``(images, <ignored>, labels)``; the middle element
    is not used. Images and labels are moved to ``device``.

    Returns:
        ``(images, labels, y)`` where ``y`` is a float32 column of ones with one
        row per image.
    """
    images, _ignored, labels = batch
    images, labels = images.to(device), labels.to(device)
    n_images = images.size(0)
    real_targets = torch.ones(n_images, 1, dtype=torch.float32)
    return images, labels, real_targets.to(device)

def generate_latent_points(latent_dim, n_samples, n_classes=2, device='cpu'):
    """Sample random generator inputs.

    Returns:
        ``(z_input, labels)``: standard-normal noise of shape
        ``(n_samples, latent_dim)`` and integer labels of shape ``(n_samples,)``
        drawn uniformly from ``[0, n_classes)``, both created on ``device``.
    """
    noise_shape = (n_samples, latent_dim)
    noise = torch.randn(*noise_shape, device=device)
    class_ids = torch.randint(low=0, high=n_classes, size=(n_samples,), device=device)
    return noise, class_ids

def generate_fake_samples(generator, latent_dim, n_samples, device='cpu'):
    """Produce a batch of generated samples together with "fake" targets.

    Latent vectors and labels come from ``generate_latent_points`` (default
    number of classes) and are passed to ``generator`` with gradient tracking
    disabled.

    NOTE(review): the generator is called with a ``(n_samples, latent_dim)``
    noise tensor; confirm that the generator in use accepts that input shape.

    Returns:
        ``(images, labels, y)`` where ``y`` is a float32 column of zeros with
        ``n_samples`` rows.
    """
    fake_targets = torch.zeros((n_samples, 1), dtype=torch.float32).to(device)
    noise, sampled_labels = generate_latent_points(latent_dim, n_samples, device=device)
    with torch.no_grad():
        fake_images = generator(noise, sampled_labels)
    return fake_images, sampled_labels, fake_targets


In [188]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from math import sqrt
from os.path import join
from dataclasses import dataclass

@dataclass
class TrainParam:
    # Settings bundle handed to `trainer`. In this notebook `trainer` reads
    # n_epochs, starting_epoch, epoch_checkpoint and model_path; the remaining
    # fields are stored but not read by it.
    n_epochs: int  # exclusive upper bound of the epoch loop
    batch_size: int  # not read by `trainer`; the DataLoader sets its own batch size
    latent_dim: int  # not read by `trainer`
    epoch_checkpoint: int  # model weights are saved every this many epochs
    n_summary_samples: int  # not read by `trainer`
    starting_epoch: int = 0  # first epoch index of the loop
    output_path: str = './'  # default output directory; not read by `trainer`
    model_path: str = './'   # directory the checkpoint files are written to

import torch
from torch import optim
import os

def trainer(gan, train_loader, train_param, device, lr=0.0002, betas=(0.5, 0.999), adv_weight=0.001):
    """Adversarially train ``gan.generator`` to restore degraded images.

    Every batch from ``train_loader`` must be a
    ``(blurred_images, real_images, labels)`` triple. The discriminator is
    trained with binary cross-entropy on real versus generated images. The
    generator minimises ``adv_weight * adversarial_loss + L1(generated, real)``.

    Args:
        gan: object exposing ``generator`` and ``discriminator`` modules.
        train_loader: iterable of batches (``len()`` is used for log messages).
        train_param: settings object; ``n_epochs``, ``starting_epoch``,
            ``epoch_checkpoint`` and ``model_path`` are read.
        device: device the models and batches are moved to.
        lr: Adam learning rate for both optimizers.
        betas: Adam betas for both optimizers.
        adv_weight: weight of the adversarial term in the generator loss.

    Returns:
        ``(gan, history)`` where ``history`` is a dict with the per-epoch mean
        losses under the keys ``'d_loss'`` and ``'g_loss'``.
    """
    generator, discriminator = gan.generator, gan.discriminator

    # Loss functions
    adversarial_loss = torch.nn.BCELoss()
    reconstruction_loss = torch.nn.L1Loss()

    generator.to(device)
    discriminator.to(device)

    # Optimizers
    optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=betas)
    optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=betas)

    # Make sure checkpoints can be written before spending time on training.
    if train_param.model_path:
        os.makedirs(train_param.model_path, exist_ok=True)

    history = {'d_loss': [], 'g_loss': []}

    for epoch in range(train_param.starting_epoch, train_param.n_epochs):
        generator.train()
        discriminator.train()
        epoch_loss_D, epoch_loss_G, n_batches = 0.0, 0.0, 0
        train_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{train_param.n_epochs}', leave=False)
        # Iterate over the progress bar itself so that it actually advances.
        for i, (blurred_images, real_images, labels) in enumerate(train_bar):
            batch_size = blurred_images.size(0)
            real_labels = torch.ones(batch_size, 1, device=device)
            fake_labels = torch.zeros(batch_size, 1, device=device)

            # Transfer data to device
            real_images = real_images.to(device)
            blurred_images = blurred_images.to(device)
            labels = labels.to(device)

            # -----------------
            # Train Discriminator
            # -----------------
            optimizer_D.zero_grad()

            # Real samples
            real_preds, _ = discriminator(real_images)
            d_real_loss = adversarial_loss(real_preds, real_labels)

            # Fake samples; detach so this step does not backpropagate into the generator.
            generated_images = generator(blurred_images, labels)
            fake_preds, _ = discriminator(generated_images.detach())
            d_fake_loss = adversarial_loss(fake_preds, fake_labels)

            # Total discriminator loss
            d_loss = (d_real_loss + d_fake_loss) / 2
            d_loss.backward()
            optimizer_D.step()

            # -----------------
            # Train Generator
            # -----------------
            optimizer_G.zero_grad()

            # Score the generated images again, this time keeping the graph to the generator.
            fake_preds, _ = discriminator(generated_images)
            g_adv_loss = adversarial_loss(fake_preds, real_labels)
            g_rec_loss = reconstruction_loss(generated_images, real_images)

            # Total generator loss
            g_loss = adv_weight * g_adv_loss + g_rec_loss
            g_loss.backward()
            optimizer_G.step()

            d_value, g_value = d_loss.item(), g_loss.item()

            # Update progress bar
            train_bar.set_postfix({
                'D Loss': f'{d_value:.4f}',
                'G Loss': f'{g_value:.4f}'
            })

            epoch_loss_D += d_value
            epoch_loss_G += g_value
            n_batches += 1

            if (i + 1) % 100 == 0:
                print(f"Epoch {epoch+1}/{train_param.n_epochs}, Batch {i+1}/{len(train_loader)}, Discriminator Loss: {d_value}, Generator Loss: {g_value}")

        # Record the mean losses of this epoch (guarding against an empty loader).
        history['d_loss'].append(epoch_loss_D / max(n_batches, 1))
        history['g_loss'].append(epoch_loss_G / max(n_batches, 1))

        # Save models periodically
        if (epoch + 1) % train_param.epoch_checkpoint == 0:
            torch.save(generator.state_dict(), os.path.join(train_param.model_path, f'generator_epoch_{epoch+1}.pth'))
            torch.save(discriminator.state_dict(), os.path.join(train_param.model_path, f'discriminator_epoch_{epoch+1}.pth'))

    return gan, history





def plot_images(X, figsize, n_samples, epoch, output_path=None):
    """Plot the first ``n_samples`` images of ``X`` in a square grid.

    Args:
        X: indexable collection of channel-first images; each ``X[i]`` must
            support ``.transpose(1, 2, 0)`` (for example a NumPy array).
        figsize: matplotlib figure size.
        n_samples: number of images to draw.
        epoch: number used in the saved file name.
        output_path: if given, the figure is also written to
            ``<output_path>/generated_plot_<epoch>.png``.
    """
    fig = plt.figure(figsize=figsize)
    # Smallest square grid that holds n_samples images (rounds up for non-squares).
    grid = int(sqrt(n_samples))
    if grid * grid < n_samples:
        grid += 1
    plt.subplots_adjust(right=0.9, left=0.0, top=0.9, bottom=0.0, hspace=0.02, wspace=0.02)
    for i in range(n_samples):
        plt.subplot(grid, grid, 1 + i)
        plt.axis('off')
        plt.imshow(X[i].transpose(1, 2, 0))  # channel-first -> channel-last for imshow
    # Save BEFORE show(): once the figure has been shown it may be closed, and
    # a later savefig would then write an empty image.
    if output_path:
        plt.savefig(join(output_path, f'generated_plot_{epoch}.png'))
    plt.show()
    # Always release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


In [189]:
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image
import os
import random
import matplotlib.pyplot as plt
import torch.nn.functional as F

class BlurredPCamDataset(Dataset):
    """Dataset yielding ``(degraded_image, original_image, label)`` triples.

    Each image is loaded from ``root_dir``, optionally transformed, and a
    randomly placed square patch of a COPY of it is blurred and/or perturbed
    with Gaussian noise. The untouched image is returned as the target.

    Args:
        root_dir: directory containing the image files.
        label_mapping: dict mapping file name -> label; its keys define the
            dataset's samples and their order.
        transform: optional callable applied to the loaded PIL image.
    """

    def __init__(self, root_dir, label_mapping, transform=None):
        self.root_dir = root_dir
        self.label_mapping = label_mapping
        self.transform = transform
        self.all_fps = [os.path.join(root_dir, fp) for fp in label_mapping.keys()]
        # Retained for backward compatibility only. Kernels are no longer cached:
        # sigma is sampled from a continuous range, so entries were never reused
        # and the dict grew by one tensor per blurred sample.
        self.kernel_cache = {}

    def __getitem__(self, index):
        image_fp = self.all_fps[index]
        slide_id = os.path.basename(image_fp)
        # Close the file handle as soon as the pixel data has been converted.
        with Image.open(image_fp) as handle:
            image = handle.convert('RGB')

        if self.transform:
            image = self.transform(image)

        # Convert image to tensor (if not already done by transform)
        if not isinstance(image, torch.Tensor):
            image = transforms.ToTensor()(image)

        # degrade_image works on a copy, so `image` stays the clean target.
        degraded_image, _ = self.degrade_image(image)

        label = self.label_mapping[slide_id]

        return degraded_image, image, label

    def __len__(self):
        return len(self.all_fps)

    def create_gaussian_kernel(self, size, sigma):
        """Return a normalised ``size`` x ``size`` Gaussian kernel of shape (1, 1, size, size)."""
        ax = torch.linspace(-(size - 1) / 2., (size - 1) / 2., size)
        # exp(-(x^2 + y^2) / 2s^2) factorises, so the 2-D kernel is an outer product.
        profile = torch.exp(-ax ** 2 / (2 * sigma ** 2))
        kernel = profile[:, None] * profile[None, :]
        kernel = kernel / torch.sum(kernel)
        return kernel.view(1, 1, size, size)

    def degrade_image(self, image):
        """Return a degraded copy of ``image`` and the degraded region.

        A square patch (side 10-32 pixels, clamped to the image size) is chosen
        at random. With probability 0.5 it is Gaussian-blurred, and independently
        with probability 0.5 Gaussian noise (std 0.05) is added to it. The input
        tensor is not modified.

        Args:
            image: tensor of shape (channels, height, width).

        Returns:
            ``(degraded, (x, y, patch_size, patch_size))``.
        """
        degraded = image.clone()
        c, h, w = degraded.shape
        patch_size = min(random.randint(10, 32), h, w)  # never larger than the image
        x = random.randint(0, w - patch_size)
        y = random.randint(0, h - patch_size)

        # Extract patch (a view into `degraded`)
        patch = degraded[:, y:y + patch_size, x:x + patch_size]

        # Apply Gaussian blur
        if random.random() > 0.5:
            size = random.choice([3, 5, 7])  # Kernel size choice
            sigma = random.uniform(0.5, 1.5)  # Sigma choice
            pad_size = size // 2
            # Reflect padding requires the pad to be smaller than the patch side.
            if pad_size < patch_size:
                blur_kernel = self.create_gaussian_kernel(size, sigma).repeat(c, 1, 1, 1)
                blur_kernel = blur_kernel.to(device=image.device, dtype=image.dtype)
                padded_patch = F.pad(patch.unsqueeze(0), (pad_size, pad_size, pad_size, pad_size), mode='reflect')
                # Padding by size // 2 followed by an unpadded convolution already
                # yields a patch_size x patch_size result; no further cropping or
                # resizing is needed.
                blurred_patch = F.conv2d(padded_patch, blur_kernel, padding=0, stride=1, groups=c).squeeze(0)
                degraded[:, y:y + patch_size, x:x + patch_size] = blurred_patch

        # Optionally add Gaussian noise
        if random.random() > 0.5:
            noise = torch.randn_like(patch) * 0.05  # Noise addition
            degraded[:, y:y + patch_size, x:x + patch_size] += noise

        return degraded, (x, y, patch_size, patch_size)

    def show_dataset_images(self, indices, ncols=3):
        """Display the degraded version of the samples at ``indices``, ``ncols`` per row."""
        indices = list(indices)
        # Ceiling division: enough rows for every requested image.
        nrows = max(1, -(-len(indices) // ncols))
        plt.figure(figsize=(15, 5 * nrows))
        for i, idx in enumerate(indices):
            degraded_image, _, _ = self[idx]
            plt.subplot(nrows, ncols, i + 1)
            plt.imshow(degraded_image.permute(1, 2, 0).numpy())
            plt.title(f"Index: {idx}")
            plt.axis('off')
        plt.show()


In [190]:
import csv

# Map every image file name ('<id>.tif') to its integer label from train_labels.csv.
# newline='' is the file mode the csv module documents for reader input.
with open(os.path.join(pcam_directory, 'train_labels.csv'), 'r', newline='') as f:
    reader = csv.reader(f)
    next(reader)  # To skip the header
    label_mapping = {slide_id + '.tif': int(label) for slide_id, label in reader}

# os.listdir returns entries in arbitrary order; sorting makes the file order
# (and any positional split derived from it) the same on every run and machine.
all_fps = sorted(os.listdir(os.path.join(pcam_directory, 'train')))
for fp in all_fps:
    assert fp.endswith('.tif'), fp[-4:]


In [191]:

# The generator's last layer is Tanh, so its outputs lie in [-1, 1]. Normalising with
# mean = std = 0.5 maps ToTensor's [0, 1] range onto that same interval, so the L1
# reconstruction target is reachable. ImageNet statistics put pixel values well
# outside [-1, 1]. This matches the normalisation used by get_transforms().
# NOTE(review): ColorJitter() is constructed with its default arguments; per the
# torchvision documentation those apply no jitter - confirm whether augmentation
# strengths were meant to be passed.
train_transforms = transforms.Compose([
    transforms.ColorJitter(),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # [0, 1] -> [-1, 1]
])
val_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # [0, 1] -> [-1, 1]
])
test_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # [0, 1] -> [-1, 1]
])


In [192]:
# Random 60/40 split of the image files. The permutation is now actually applied:
# previously it was computed but the files were sliced in their listing order.
# NOTE(review): no random seed is set, so the split differs between runs.
permutation = np.random.permutation(len(all_fps))
num_train = int(len(permutation) * 0.6)
train_files = [all_fps[i] for i in permutation[:num_train]]
test_files = [all_fps[i] for i in permutation[num_train:]]

# Files without an entry in label_mapping are skipped.
train_dataset = BlurredPCamDataset(os.path.join(pcam_directory, 'train'), {fp: label_mapping[fp] for fp in train_files if fp in label_mapping}, transform=train_transforms)
test_dataset = BlurredPCamDataset(os.path.join(pcam_directory, 'train'), {fp: label_mapping[fp] for fp in test_files if fp in label_mapping}, transform=test_transforms)



In [193]:

# Batches of 128 samples; only the training loader reshuffles its samples.
train_dataloader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=128)
test_dataloader = DataLoader(dataset=test_dataset, shuffle=False, batch_size=128)

In [194]:
import torchsummary

In [195]:
# Model and training configuration, followed by construction of both networks.
latent_dimension = 100  # Defined here but not passed to any model in this cell
n_classes = 2          # Number of classes (for conditional GAN)
output_channels = 3    # Output channels for the generated images (typically 3 for RGB images)
input_shape = (3, 96, 96)  # Input shape for the discriminator (channels, height, width)

train_param = TrainParam(
    n_epochs=2000,
    batch_size=128,
    latent_dim=100,
    epoch_checkpoint=20,
    n_summary_samples=36
)

# Disabled earlier call: CGanGenerator.__init__ has no `latent_dimension` parameter.
# generator = CGanGenerator(latent_dimension=latent_dimension, n_classes=n_classes, output_channels=output_channels)

generator = CGanGenerator(n_classes=n_classes, output_channels=output_channels)

# Constructing the discriminator prints "Output size for linear layer: ...".
discriminator = CGanDiscriminator(input_shape=input_shape, n_classes=n_classes)

# Configure device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move models to the correct device
generator.to(device)
discriminator.to(device)

# Printing summaries
# CGanGenerator.forward takes an image batch and class labels as input
# The discriminator expects an image of shape input_shape
# Both torchsummary calls below are disabled, so only the two headings are printed.

print("Generator Summary:")
image_size = (3, 96, 96)
# torchsummary.summary(generator, input_size=image_size, device=str(device))

print("\nDiscriminator Summary:")
# torchsummary.summary(discriminator, input_shape, device=str(device))
# Bundle both networks; `trainer` reads them back via gan.generator / gan.discriminator.
gan = CGan(generator=generator, discriminator=discriminator)


Output size for linear layer: 4608
Generator Summary:

Discriminator Summary:


In [196]:
# Re-selects the compute device (same expression as in the model-construction cell above).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [197]:
# Launch training. NOTE(review): the tuple unpacking requires `trainer` to return a
# (model, history) pair - confirm against the `trainer` definition.
gan_model, history = trainer(gan, train_dataloader, train_param, device=device)

Epoch 1/2000:   0%|          | 0/1032 [00:00<?, ?it/s]

: 