In [29]:
import os
import glob
import time
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
from tqdm import tqdm

# **CONFIGURATION**

In [30]:
# ==========================================
# CONFIGURATION
# ==========================================
# Paths are built from the kernel's working directory. Judging by the recorded
# output, the notebook is expected to run from the project's Scripts/ folder.
DATA_ROOT = os.path.join(os.getcwd(), os.pardir, "an2dl2526c2", "preprocessing_results")
IMAGE_DIR = os.path.join(DATA_ROOT, "train_patches")   # real tissue patches
MASK_DIR = os.path.join(IMAGE_DIR, "masks")             # binary masks, paired with IMAGE_DIR by file name
SYNTHETIC_OUT_DIR = os.path.join(os.getcwd(), "temp", "synthetic_patches")  # where new "dreamt" patches are written

# Hyperparameters (names kept lowercase because later cells read them as globals)
img_size = 256       # masks/images are resized to img_size x img_size; the U-Net needs a multiple of 256
batch_size = 16
epochs = 50          # 50-100 is usually enough for texture learning
lr = 0.0002
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed every RNG the notebook touches so a Restart & Run All is reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators

print(f"Using device: {device}")
print(f"Mask dir: {MASK_DIR}")
print(f"Image dir: {IMAGE_DIR}")
print(f"Synthetic output dir: {SYNTHETIC_OUT_DIR}")

Using device: cuda
Mask dir: d:\POLIMI\AN2DL\AN2DL_CH_2\Scripts\..\an2dl2526c2\preprocessing_results\train_patches\masks
Image dir: d:\POLIMI\AN2DL\AN2DL_CH_2\Scripts\..\an2dl2526c2\preprocessing_results\train_patches
Synthetic output dir: d:\POLIMI\AN2DL\AN2DL_CH_2\Scripts\temp\synthetic_patches


# 1. MODEL ARCHITECTURE (Pix2Pix U-Net generator)

In [31]:
# ==========================================
# U-Net building blocks and generator
# ==========================================

class UNetDown(nn.Module):
    """Encoder stage of the U-Net: a 4x4, stride-2 convolution that halves H and W.

    Layout of ``self.model`` (an ``nn.Sequential``)::

        Conv2d(bias=False) -> [InstanceNorm2d] -> LeakyReLU(0.2) -> [Dropout]

    The order is fixed on purpose: saved ``state_dict`` keys such as
    ``model.0.weight`` depend on each layer's index.

    Args:
        in_size: number of input channels.
        out_size: number of output channels.
        normalize: insert an ``InstanceNorm2d`` after the convolution.
        dropout: dropout probability; a falsy value (e.g. ``0.0``) adds no dropout layer.
    """

    def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
        super().__init__()
        conv = nn.Conv2d(in_size, out_size, kernel_size=4, stride=2, padding=1, bias=False)
        norm_layers = [nn.InstanceNorm2d(out_size)] if normalize else []
        drop_layers = [nn.Dropout(dropout)] if dropout else []
        self.model = nn.Sequential(conv, *norm_layers, nn.LeakyReLU(0.2), *drop_layers)

    def forward(self, x):
        """Run the stage on ``x`` and return the downsampled feature map."""
        features = self.model(x)
        return features

class UNetUp(nn.Module):
    """Decoder stage of the U-Net: upsample by 2, then concatenate the skip tensor.

    Layout of ``self.model``::

        ConvTranspose2d(4x4, stride 2, bias=False) -> InstanceNorm2d -> ReLU(inplace) -> [Dropout]

    ``forward`` returns ``cat([model(x), skip_input], dim=1)``. The output
    therefore has ``out_size + skip_input.shape[1]`` channels.

    Args:
        in_size: number of input channels of ``x``.
        out_size: channels produced by the transposed convolution.
        dropout: dropout probability; a falsy value adds no dropout layer.
    """

    def __init__(self, in_size, out_size, dropout=0.0):
        super().__init__()
        deconv = nn.ConvTranspose2d(in_size, out_size, kernel_size=4, stride=2, padding=1, bias=False)
        optional_dropout = [nn.Dropout(dropout)] if dropout else []
        self.model = nn.Sequential(
            deconv,
            nn.InstanceNorm2d(out_size),
            nn.ReLU(inplace=True),
            *optional_dropout,
        )

    def forward(self, x, skip_input):
        """Upsample ``x`` and stack ``skip_input`` after it along the channel axis."""
        upsampled = self.model(x)
        return torch.cat([upsampled, skip_input], dim=1)

class GeneratorUNet(nn.Module):
    """Eight-level U-Net generator (the Pix2Pix generator architecture).

    - Encoder: ``down1`` .. ``down8`` (``UNetDown``), each halving H and W.
    - Decoder: ``up1`` .. ``up7`` (``UNetUp``). Each stage doubles H and W and
      concatenates the mirrored encoder output (``up_k`` pairs with
      ``down_{8-k}``).
    - ``final``: upsample -> pad -> 4x4 conv -> tanh. It returns to full
      resolution with values in [-1, 1].

    The submodule attribute names are part of the saved ``state_dict`` format
    (``down1.model.0.weight``, ...), so they are registered explicitly, in the
    same order, via ``setattr``.

    The eight stride-2 stages mean the input height/width must be a multiple of
    256 for the skip connections to line up.

    Args:
        in_channels: channels of the conditioning input (in this notebook, an
            RGB-converted mask).
        out_channels: channels of the generated image.
    """

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()

        # (in, out, normalize, dropout) per encoder stage, outermost first.
        encoder_specs = [
            (in_channels, 64, False, 0.0),
            (64, 128, True, 0.0),
            (128, 256, True, 0.0),
            (256, 512, True, 0.5),
            (512, 512, True, 0.5),
            (512, 512, True, 0.5),
            (512, 512, True, 0.5),
            (512, 512, False, 0.5),
        ]
        for level, (c_in, c_out, use_norm, p_drop) in enumerate(encoder_specs, start=1):
            setattr(self, f"down{level}", UNetDown(c_in, c_out, normalize=use_norm, dropout=p_drop))

        # (in, out, dropout) per decoder stage, innermost first. Inputs after the
        # first are doubled by the skip concatenation of the previous stage.
        decoder_specs = [
            (512, 512, 0.5),
            (1024, 512, 0.5),
            (1024, 512, 0.5),
            (1024, 512, 0.5),
            (1024, 256, 0.0),
            (512, 128, 0.0),
            (256, 64, 0.0),
        ]
        for level, (c_in, c_out, p_drop) in enumerate(decoder_specs, start=1):
            setattr(self, f"up{level}", UNetUp(c_in, c_out, dropout=p_drop))

        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(128, out_channels, 4, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Map the conditioning tensor ``x`` to a generated image in [-1, 1]."""
        # Collect every encoder output; encoded[k] is the output of down{k+1}.
        encoded = []
        hidden = x
        for level in range(1, 9):
            hidden = getattr(self, f"down{level}")(hidden)
            encoded.append(hidden)

        # up{k} consumes the previous decoder output plus the skip from down{8-k}.
        for level in range(1, 8):
            hidden = getattr(self, f"up{level}")(hidden, encoded[-1 - level])

        return self.final(hidden)

# 2. DATASET

In [32]:
# ==========================================
# 2. DATASET (ROBUST VERSION)
# ==========================================
class PairedMaskDataset(Dataset):
    """Dataset of (mask, tissue image) pairs matched by file name.

    Each ``.png`` in ``mask_dir`` is paired with a file in ``img_dir`` using the
    first of these candidate names that exists:

    1. the identical file name;
    2. the name with every ``"mask_"``, then every ``"mask"``, removed
       (e.g. ``mask_img_01.png`` -> ``img_01.png``);
    3. the name with every ``"mask"`` replaced by ``"img"``
       (e.g. ``mask_01.png`` -> ``img_01.png``).

    Masks with no matching image are skipped.

    Items are dicts ``{"mask": ..., "image": ..., "filename": mask_name}``. The
    images are loaded as RGB PIL images and passed through ``transforms_`` when
    it is given.

    Args:
        mask_dir: directory containing the mask PNGs.
        img_dir: directory containing the tissue images.
        transforms_: optional callable applied to both the mask and the image.
            It is called once per image with the same random seed, so random
            geometric transforms line up between the two.

    Raises:
        ValueError: if either directory does not exist or no pair is found.
    """

    def __init__(self, mask_dir, img_dir, transforms_=None):
        self.mask_dir = mask_dir
        self.img_dir = img_dir
        self.transforms = transforms_

        # 1. Check that the directories exist
        if not os.path.isdir(mask_dir):
            raise ValueError(f"Mask directory not found: {mask_dir}")
        if not os.path.isdir(img_dir):
            raise ValueError(f"Image directory not found: {img_dir}")

        # 2. Get file lists (sorted so pair order is deterministic)
        mask_files = sorted(f for f in os.listdir(mask_dir) if f.endswith('.png'))
        img_files_set = set(os.listdir(img_dir))  # set for O(1) membership tests

        # 3. Match each mask to an image
        self.pairs = []
        for m_file in mask_files:
            img_name = self._match_image(m_file, img_files_set)
            if img_name is not None:
                self.pairs.append((m_file, img_name))

        # 4. Debugging output if empty
        if not self.pairs:
            print("\n!!! ERROR: No pairs found. !!!")
            print(f"Mask Dir: {mask_dir}")
            print(f"First 5 masks: {mask_files[:5]}")
            print(f"First 5 images in target dir: {sorted(img_files_set)[:5]}")
            raise ValueError("Dataset is empty. Please check the filenames printed above.")

        print(f"Found {len(self.pairs)} paired training examples.")

    @staticmethod
    def _match_image(mask_name, available):
        """Return the first candidate image name for ``mask_name`` found in ``available``, else None."""
        candidates = (
            mask_name,                                           # exact match
            mask_name.replace("mask_", "").replace("mask", ""),  # prefix stripped
            mask_name.replace("mask", "img"),                    # "mask" -> "img"
        )
        for candidate in candidates:
            if candidate in available:
                return candidate
        return None

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` with PIL, convert it to RGB and close the file handle."""
        with Image.open(path) as handle:
            # convert() returns a new, fully loaded image, so it stays valid after close.
            return handle.convert("RGB")

    def _paired_transform(self, mask, img):
        """Apply ``self.transforms`` to ``mask`` and ``img`` under one shared seed.

        The global Python ``random`` state and torch's CPU RNG state are saved
        and restored around the reseeding. That way per-sample seeding does not
        leak into other users of the global RNGs (DataLoader shuffling, dropout).
        Only the CPU generator is reseeded; ``torch.manual_seed`` would also
        reset every CUDA generator.
        """
        seed = np.random.randint(2147483647)
        py_state = random.getstate()
        torch_state = torch.get_rng_state()
        try:
            outputs = []
            for pil_image in (mask, img):
                random.seed(seed)
                torch.default_generator.manual_seed(seed)
                outputs.append(self.transforms(pil_image))
        finally:
            random.setstate(py_state)
            torch.set_rng_state(torch_state)
        return outputs[0], outputs[1]

    def __getitem__(self, index):
        mask_name, img_name = self.pairs[index]

        mask = self._load_rgb(os.path.join(self.mask_dir, mask_name))
        img = self._load_rgb(os.path.join(self.img_dir, img_name))

        if self.transforms:
            mask, img = self._paired_transform(mask, img)

        return {"mask": mask, "image": img, "filename": mask_name}

    def __len__(self):
        return len(self.pairs)



# 3. TRAINING ROUTINE


In [33]:

def train_pix2pix(save_path="saved_models/generator_final.pth"):
    """Train ``GeneratorUNet`` to map masks to tissue patches.

    NOTE: despite the name, only the Pix2Pix *generator* is trained. There is no
    discriminator or adversarial term; the objective is pure pixel-wise L1.

    Reads the notebook-level config: ``MASK_DIR``, ``IMAGE_DIR``, ``img_size``,
    ``batch_size``, ``epochs``, ``lr``, ``device``.

    Args:
        save_path: where the final ``state_dict`` is written. The default is the
            path the RUN cell looks for.

    Returns:
        The trained generator (still in train mode).

    Raises:
        KeyboardInterrupt: re-raised after the partial weights have been saved
            next to ``save_path`` with an ``_interrupted`` suffix.
    """
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Transforms: images end up in [-1, 1] to match the generator's Tanh output.
    # TODO(review): the same bilinear Resize is applied to the binary masks; a
    # NEAREST resize would keep them binary if patches are not already img_size.
    transforms_ = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # Loaders
    dataset = PairedMaskDataset(MASK_DIR, IMAGE_DIR, transforms_=transforms_)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    # Model, optimizer (standard Pix2Pix Adam betas) and loss
    generator = GeneratorUNet().to(device)
    optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
    criterion_pixel = torch.nn.L1Loss()

    print("Starting Training...")
    generator.train()

    try:
        for epoch in range(1, epochs + 1):  # 1-based so the display reads 1/N .. N/N
            running_loss = 0.0
            loop = tqdm(dataloader, desc=f"Epoch [{epoch}/{epochs}]", leave=False)
            for batch in loop:
                real_mask = batch["mask"].to(device)
                real_img = batch["image"].to(device)

                optimizer_G.zero_grad()
                fake_img = generator(real_mask)
                # Loss: L1 distance between generated tissue and real tissue
                loss_G = criterion_pixel(fake_img, real_img)
                loss_G.backward()
                optimizer_G.step()

                batch_loss = loss_G.item()
                running_loss += batch_loss
                loop.set_postfix(loss=batch_loss)

            # leave=False erases the bar, so keep a one-line summary per epoch.
            print(f"Epoch [{epoch}/{epochs}] mean L1: {running_loss / max(len(dataloader), 1):.4f}")
    except KeyboardInterrupt:
        partial_path = os.path.splitext(save_path)[0] + "_interrupted.pth"
        torch.save(generator.state_dict(), partial_path)
        print(f"Training interrupted; partial weights saved to {partial_path}")
        raise

    torch.save(generator.state_dict(), save_path)
    print("Training Complete. Model saved.")
    return generator


# 4. SYNTHESIS ROUTINE ("Dreaming")

In [34]:
def generate_synthetic_data(generator, num_variations=1, seed=None):
    """Write ``num_variations`` synthetic patches per mask in ``MASK_DIR``.

    Dropout layers are kept active during inference (as in Pix2Pix), so repeated
    forward passes on the same mask produce different textures. Outputs are saved
    to ``SYNTHETIC_OUT_DIR`` as ``<mask_stem>_syn_v<i>.png`` at
    ``img_size`` x ``img_size``.

    Args:
        generator: a trained ``GeneratorUNet``. Its train/eval mode is restored
            on return.
        num_variations: synthetic images per mask; values < 1 write nothing.
        seed: if given, ``torch.manual_seed(seed)`` is called first so the
            dropout-driven variations are reproducible.
    """
    os.makedirs(SYNTHETIC_OUT_DIR, exist_ok=True)

    was_training = generator.training
    generator.eval()
    # CRITICAL: re-enable Dropout during inference to generate diversity;
    # this turns the deterministic U-Net into a stochastic one.
    for m in generator.modules():
        if isinstance(m, nn.Dropout):
            m.train()

    if seed is not None:
        torch.manual_seed(seed)

    transforms_ = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # Iterate over ALL masks (training set), in a deterministic order.
    mask_files = sorted(f for f in os.listdir(MASK_DIR) if f.endswith('.png'))
    print(f"Generating {num_variations} variations for {len(mask_files)} masks...")

    try:
        with torch.no_grad():
            for filename in tqdm(mask_files):
                with Image.open(os.path.join(MASK_DIR, filename)) as handle:
                    mask_img = handle.convert("RGB")
                mask_tensor = transforms_(mask_img).unsqueeze(0).to(device)
                stem = os.path.splitext(filename)[0]

                for i in range(1, num_variations + 1):
                    # Forward pass (active Dropout yields a different texture each time),
                    # then map the Tanh output from [-1, 1] back to [0, 1].
                    fake_img = generator(mask_tensor) * 0.5 + 0.5
                    # Save as "<stem>_syn_v<i>.png"
                    save_path = os.path.join(SYNTHETIC_OUT_DIR, f"{stem}_syn_v{i}.png")
                    save_image(fake_img, save_path)
    finally:
        # Don't leave the caller's model in a mixed eval/dropout-train state.
        generator.train(was_training)

# RUN

In [35]:
# ==========================================
# RUN: train (or load cached weights), then synthesize
# ==========================================
GENERATOR_PATH = "saved_models/generator_final.pth"

# Step 1: Train — skipped when cached weights exist (delete the file to retrain)
if os.path.exists(GENERATOR_PATH):
    print("Loading pre-trained generator...")
    gen = GeneratorUNet().to(device)
    # map_location lets CUDA-trained weights load on a CPU-only machine;
    # weights_only restricts unpickling to tensors/containers.
    gen.load_state_dict(torch.load(GENERATOR_PATH, map_location=device, weights_only=True))
else:
    gen = train_pix2pix()

# Step 2: Generate
# num_variations=N writes N synthetic patches per mask (1 doubles the dataset, 2 triples it)
generate_synthetic_data(gen, num_variations=1)

Found 4071 paired training examples.
Starting Training...


                                                                            

KeyboardInterrupt: 