# CellSimul Conditional GAN Training in Google Colab

This notebook allows you to train the CellSimul conditional GAN for fluorescent microscopy image generation using Google Colab's free GPU resources.

## Features:
- 🚀 Free GPU training with Google Colab
- 📦 Automatic repository installation
- 🎯 Conditional GAN training with real fluorescent + synthetic distance masks
- 📊 Training progress bars, loss curves, and sample images
- 💾 Easy model and sample saving to Google Drive

## Requirements:
- Google account for Colab access
- Your fluorescent images and distance masks (can be uploaded or mounted from Drive)

## 🔧 Step 1: Setup Environment and Install Repository

In [None]:
# Check GPU availability
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("⚠️ No GPU detected. Training will be slow on CPU.")
    print("💡 Enable GPU: Runtime → Change runtime type → Hardware accelerator → GPU")

In [None]:
# Install required packages
!pip install -q torch torchvision torchaudio
!pip install -q numpy matplotlib pillow tifffile scikit-image
!pip install -q ipywidgets tqdm

print("✅ Required packages installed!")

In [None]:
# Clone the CellSimul repository
import os
import sys

# Work from /content so re-running this cell does not clone into a nested folder
os.chdir('/content')
REPO_DIR = '/content/CellSimul/CellSimul'

# Remove existing directory if it exists
if os.path.exists('CellSimul'):
    !rm -rf CellSimul

# Clone the repository
!git clone https://github.com/Cetus137/CellSimul.git

# Navigate to the repository
os.chdir(REPO_DIR)

# Add the package and model folders to the Python path
for path in (REPO_DIR, f'{REPO_DIR}/cond_models'):
    if path not in sys.path:
        sys.path.append(path)

print("✅ Repository cloned successfully!")
print(f"📁 Current directory: {os.getcwd()}")
print(f"📂 Contents: {os.listdir('.')}")
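
Before moving on, it can help to confirm that Python can locate the modules imported in Step 4, since a wrong `sys.path` entry otherwise only shows up later as an `ImportError`. A minimal sketch using the standard library's `importlib.util.find_spec`, which looks a top-level module up without importing it; `check_importable` is a hypothetical helper name, and the module names are the ones Step 4 imports.

```python
import importlib.util

def check_importable(module_names):
    """Return {module_name: True/False} depending on whether Python can locate each module."""
    return {name: importlib.util.find_spec(name) is not None for name in module_names}

# Module names taken from the imports in Step 4
status = check_importable([
    'conditional_generator',
    'conditional_discriminator',
    'unpaired_conditional_dataloader',
])
for name, found in status.items():
    print(f"{'✅' if found else '❌'} {name}")
```

A missing module here usually means the repository layout differs from the paths added to `sys.path` above.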

## 📂 Step 2: Data Setup

You have several options for getting your data into Colab:

### Option A: Mount Google Drive (Recommended)
Use this option if your data is already in Google Drive.

In [None]:
# Option A: Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Set paths to your data in Google Drive
# Adjust these paths to match your Google Drive structure
DRIVE_DATA_PATH = '/content/drive/MyDrive/CellSimul_Data'  # Adjust this path
FLUORESCENCE_DIR = f'{DRIVE_DATA_PATH}/fluorescence_rescaled'
DISTANCE_MASKS_DIR = f'{DRIVE_DATA_PATH}/distance_masks_rescaled'

print(f"📁 Drive mounted! Looking for data at:")
print(f"  Fluorescence: {FLUORESCENCE_DIR}")
print(f"  Distance masks: {DISTANCE_MASKS_DIR}")

# Check if directories exist
if os.path.exists(FLUORESCENCE_DIR):
    fluor_files = len([f for f in os.listdir(FLUORESCENCE_DIR) if f.endswith(('.tif', '.tiff'))])
    print(f"✅ Found {fluor_files} fluorescence images")
else:
    print(f"❌ Fluorescence directory not found: {FLUORESCENCE_DIR}")

if os.path.exists(DISTANCE_MASKS_DIR):
    mask_files = len([f for f in os.listdir(DISTANCE_MASKS_DIR) if f.endswith(('.tif', '.tiff'))])
    print(f"✅ Found {mask_files} distance mask images")
else:
    print(f"❌ Distance masks directory not found: {DISTANCE_MASKS_DIR}")
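
The checks above use a case-sensitive `f.endswith(('.tif', '.tiff'))`, so files named `*.TIF` are silently skipped. A minimal sketch of a case-insensitive listing; `list_tifs` and `summarize_data_dirs` are hypothetical helpers. Whether `UnpairedConditionalImageDataset` itself accepts uppercase extensions is not shown in this notebook, so renaming such files may still be needed. Because the dataset is unpaired, the two folders presumably need not hold the same number of files, so the sketch only warns when a folder is empty.

```python
import os

TIF_EXTENSIONS = ('.tif', '.tiff')

def list_tifs(directory):
    """Return sorted TIF/TIFF filenames in `directory`, matching the extension case-insensitively."""
    if not os.path.isdir(directory):
        return []
    return sorted(f for f in os.listdir(directory) if f.lower().endswith(TIF_EXTENSIONS))

def summarize_data_dirs(fluor_dir, mask_dir):
    """Return (num_fluorescence_files, num_mask_files) and warn if either folder is empty."""
    n_fluor, n_mask = len(list_tifs(fluor_dir)), len(list_tifs(mask_dir))
    if n_fluor == 0 or n_mask == 0:
        print("⚠️ At least one data folder contains no TIF files.")
    return n_fluor, n_mask
```

In the notebook this would be called as `summarize_data_dirs(FLUORESCENCE_DIR, DISTANCE_MASKS_DIR)`.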

### Option B: Upload Files Directly
Use this option to upload files directly to Colab.

In [None]:
# Option B: Upload files directly (uncomment if not using Google Drive)
# from google.colab import files
# import zipfile

# # Create data directories
# os.makedirs('data/fluorescence_rescaled', exist_ok=True)
# os.makedirs('data/distance_masks_rescaled', exist_ok=True)

# print("📤 Upload your data files:")
# print("1. Fluorescence images (TIF files)")
# print("2. Distance mask images (TIF files)")
# print("3. Or upload ZIP files containing the images")

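# uploaded = files.upload()  # uploaded files are saved to the current working directory

# # Extract ZIP files if uploaded. Each ZIP should contain top-level fluorescence_rescaled/
# # and distance_masks_rescaled/ folders; loose TIF uploads must be moved into the
# # matching data/ subfolder by hand.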
# for filename in uploaded.keys():
#     if filename.endswith('.zip'):
#         with zipfile.ZipFile(filename, 'r') as zip_ref:
#             zip_ref.extractall('data/')
#         print(f"✅ Extracted {filename}")

# FLUORESCENCE_DIR = 'data/fluorescence_rescaled'
# DISTANCE_MASKS_DIR = 'data/distance_masks_rescaled'
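
Option B extracts every ZIP into `data/`, so the paths above only resolve if the archive has top-level `fluorescence_rescaled/` and `distance_masks_rescaled/` folders (an extra wrapping folder is a common cause of "directory not found"). A minimal sketch that inspects the archive before extracting, using only `zipfile`; `missing_zip_folders` is a hypothetical helper.

```python
import zipfile

EXPECTED_FOLDERS = ('fluorescence_rescaled/', 'distance_masks_rescaled/')

def missing_zip_folders(zip_path, expected=EXPECTED_FOLDERS):
    """Return the expected top-level folders that no entry in the ZIP archive starts with."""
    with zipfile.ZipFile(zip_path) as zf:
        names = zf.namelist()
    return [folder for folder in expected
            if not any(name.startswith(folder) for name in names)]
```

For example, `missing = missing_zip_folders('my_data.zip')` returns an empty list when the layout matches.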

### Option C: Use Sample Data (for testing)
Generate synthetic sample data to test the pipeline.

In [None]:
# Option C: Generate sample data for testing (uncomment if needed)
# import numpy as np
# import tifffile
# from PIL import Image

# # Create sample data directories
# os.makedirs('data/fluorescence_rescaled', exist_ok=True)
# os.makedirs('data/distance_masks_rescaled', exist_ok=True)

# print("🧪 Generating synthetic sample data for testing...")

# # Generate 50 sample images
# for i in range(50):
#     # Synthetic fluorescent image
#     fluor_img = np.random.rand(256, 256) * 255
#     fluor_img = fluor_img.astype(np.uint8)
#     tifffile.imwrite(f'data/fluorescence_rescaled/sample_fluor_{i:03d}.tif', fluor_img)
    
#     # Synthetic distance mask
#     x, y = np.meshgrid(np.arange(256), np.arange(256))
#     center_x, center_y = np.random.randint(64, 192, 2)
#     distance = np.sqrt((x - center_x)**2 + (y - center_y)**2)
#     mask = np.exp(-distance / 30)  # Exponential decay
#     mask = (mask * 255).astype(np.uint8)
#     tifffile.imwrite(f'data/distance_masks_rescaled/sample_mask_{i:03d}.tif', mask)

# FLUORESCENCE_DIR = 'data/fluorescence_rescaled'
# DISTANCE_MASKS_DIR = 'data/distance_masks_rescaled'
# print("✅ Sample data generated!")
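
Option C's mask decays exponentially with the distance from a single random centre. To mimic images with several cells, the same formula can use the distance to the *nearest* of several centres. A NumPy sketch; `synthetic_distance_mask` and `num_cells` are illustrative, and the default `decay=30.0` is taken from Option C. This is not how CellSimul's real distance masks were produced.

```python
import numpy as np

def synthetic_distance_mask(image_size=256, num_cells=5, decay=30.0, rng=None):
    """uint8 mask of exp(-d / decay), where d is the distance to the nearest random 'cell centre'."""
    rng = np.random.default_rng() if rng is None else rng
    centres = rng.integers(0, image_size, size=(num_cells, 2))   # rows are (x, y)
    y, x = np.mgrid[0:image_size, 0:image_size]                   # pixel coordinate grids
    # Distance from every pixel to every centre: shape (num_cells, image_size, image_size)
    dists = np.sqrt((x[None] - centres[:, 0, None, None]) ** 2 +
                    (y[None] - centres[:, 1, None, None]) ** 2)
    nearest = dists.min(axis=0)                                   # distance to the closest centre
    return (np.exp(-nearest / decay) * 255).astype(np.uint8), centres
```

Inside the Option C loop, `mask, _ = synthetic_distance_mask()` could replace the single-centre mask lines.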

## 🔍 Step 3: Verify Data Setup

In [None]:
# Verify data setup
import matplotlib.pyplot as plt
import tifffile
import numpy as np

# Check data directories
if not os.path.exists(FLUORESCENCE_DIR):
    print(f"❌ Error: Fluorescence directory not found: {FLUORESCENCE_DIR}")
    print("Please set up your data using one of the options above.")
else:
    fluor_files = [f for f in os.listdir(FLUORESCENCE_DIR) if f.endswith(('.tif', '.tiff'))]
    print(f"✅ Found {len(fluor_files)} fluorescence images")

if not os.path.exists(DISTANCE_MASKS_DIR):
    print(f"❌ Error: Distance masks directory not found: {DISTANCE_MASKS_DIR}")
    print("Please set up your data using one of the options above.")
else:
    mask_files = [f for f in os.listdir(DISTANCE_MASKS_DIR) if f.endswith(('.tif', '.tiff'))]
    print(f"✅ Found {len(mask_files)} distance mask images")

# Visualize sample data
if os.path.exists(FLUORESCENCE_DIR) and os.path.exists(DISTANCE_MASKS_DIR):
    if fluor_files and mask_files:
        # Load sample images
        sample_fluor = tifffile.imread(os.path.join(FLUORESCENCE_DIR, fluor_files[0]))
        sample_mask = tifffile.imread(os.path.join(DISTANCE_MASKS_DIR, mask_files[0]))
        
        # Plot samples
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        
        axes[0].imshow(sample_fluor, cmap='Greens')
        axes[0].set_title('Sample Fluorescent Image')
        axes[0].axis('off')
        
        axes[1].imshow(sample_mask, cmap='hot')
        axes[1].set_title('Sample Distance Mask')
        axes[1].axis('off')
        
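        # Overlay: min-max scale each image to [0, 1] so any bit depth works (assumes equal 2D shapes)
        fluor_01 = sample_fluor.astype(np.float32)
        fluor_01 = (fluor_01 - fluor_01.min()) / (fluor_01.max() - fluor_01.min() + 1e-8)
        mask_01 = sample_mask.astype(np.float32)
        mask_01 = (mask_01 - mask_01.min()) / (mask_01.max() - mask_01.min() + 1e-8)
        overlay = np.stack([mask_01, fluor_01, np.zeros_like(mask_01)], axis=-1)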
        axes[2].imshow(overlay)
        axes[2].set_title('Overlay (Red=Mask, Green=Fluorescent)')
        axes[2].axis('off')
        
        plt.tight_layout()
        plt.show()
        
        print(f"📊 Image info:")
        print(f"  Fluorescent shape: {sample_fluor.shape}")
        print(f"  Distance mask shape: {sample_mask.shape}")
        print(f"  Fluorescent range: [{sample_fluor.min()}, {sample_fluor.max()}]")
        print(f"  Distance mask range: [{sample_mask.min()}, {sample_mask.max()}]")
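
Fluorescence images often contain a few very bright pixels, so plain min-max scaling can leave most of the image dark on screen. A percentile stretch is a common alternative for display. A minimal NumPy sketch; `to_unit_range` is a hypothetical helper, and the 1st/99th percentile limits are arbitrary defaults. It affects display only, not what the model is trained on.

```python
import numpy as np

def to_unit_range(img, low_pct=1.0, high_pct=99.0):
    """Percentile-stretch an image of any dtype to float32 values in [0, 1] for display."""
    img = np.asarray(img, dtype=np.float32)
    lo, hi = (float(v) for v in np.percentile(img, [low_pct, high_pct]))
    if hi <= lo:                                   # flat image: nothing to stretch
        return np.zeros_like(img)
    return np.clip((img - lo) / (hi - lo), 0.0, 1.0).astype(np.float32)
```

For example: `axes[0].imshow(to_unit_range(sample_fluor), cmap='Greens')`.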

## 🧠 Step 4: Import and Test Models

In [None]:
# Import the CellSimul models and training components
try:
    # Change to the models directory (absolute path, so re-running this cell works)
    os.chdir('/content/CellSimul/CellSimul/cond_models')
    
    # Import models
    from conditional_generator import ConditionalGenerator, SimpleConditionalGenerator
    from conditional_discriminator import ConditionalDiscriminator, SimpleConditionalDiscriminator
    from unpaired_conditional_dataloader import UnpairedConditionalImageDataset
    
    print("✅ Successfully imported CellSimul models!")
    
    # Test model initialization
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"🎯 Using device: {device}")
    
    # Initialize models
    generator = ConditionalGenerator().to(device)
    discriminator = ConditionalDiscriminator().to(device)
    
    print(f"🧠 Generator parameters: {sum(p.numel() for p in generator.parameters()):,}")
    print(f"🧠 Discriminator parameters: {sum(p.numel() for p in discriminator.parameters()):,}")
    
    # Test forward pass
    with torch.no_grad():
        test_noise = torch.randn(1, 100, device=device)  # assumes a 100-dim latent vector, as used throughout this notebook
        test_mask = torch.randn(1, 1, 256, 256, device=device)
        test_generated = generator(test_noise, test_mask)
        test_discriminator_out = discriminator(test_generated, test_mask)
        
        print(f"✅ Model test successful!")
        print(f"  Generated image shape: {test_generated.shape}")
        print(f"  Discriminator output shape: {test_discriminator_out.shape}")
        
except Exception as e:
    print(f"❌ Error importing models: {e}")
    print("\nTrying to debug...")
    print(f"Current directory: {os.getcwd()}")
    print(f"Directory contents: {os.listdir('.')}")
    import traceback
    traceback.print_exc()

## 📊 Step 5: Create Dataset and Test Loading

In [None]:
# Create dataset
try:
    # Go back to the main repository directory, where the relative data/ paths of Options B and C live
    os.chdir('/content/CellSimul/CellSimul')
    
    # Create dataset
    dataset = UnpairedConditionalImageDataset(
        fluorescent_dir=FLUORESCENCE_DIR,
        mask_dir=DISTANCE_MASKS_DIR,
        image_size=256,
        max_images=100  # Limit for faster testing
    )
    
    print(f"✅ Dataset created successfully!")
    print(f"📊 Dataset size: {len(dataset)}")
    
    # Test data loading
    sample_fluor, sample_mask = dataset[0]
    print(f"📏 Sample shapes - Fluorescent: {sample_fluor.shape}, Mask: {sample_mask.shape}")
    print(f"📈 Sample ranges - Fluorescent: [{sample_fluor.min():.3f}, {sample_fluor.max():.3f}], Mask: [{sample_mask.min():.3f}, {sample_mask.max():.3f}]")
    
    # Create dataloader
    from torch.utils.data import DataLoader
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)  # num_workers=0: load in the main process (simplest in a notebook)
    
    # Test batch loading
    batch_fluor, batch_mask = next(iter(dataloader))
    print(f"✅ Batch loading successful!")
    print(f"📦 Batch shapes - Fluorescent: {batch_fluor.shape}, Mask: {batch_mask.shape}")
    
except Exception as e:
    print(f"❌ Error creating dataset: {e}")
    import traceback
    traceback.print_exc()
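
`UnpairedConditionalImageDataset` is defined in the repository and its internals are not shown here. As an illustration of what "unpaired" usually means, here is a hypothetical, framework-free sketch in which each fluorescent image is matched with a randomly drawn mask on every access, so the two are never assumed to correspond. `UnpairedFilePairs` is an invented name and this is an assumption about the general idea, not CellSimul's implementation; a real PyTorch dataset would also load and normalize the images.

```python
import random

class UnpairedFilePairs:
    """Hypothetical sketch: pair every fluorescent file with a randomly drawn mask file."""

    def __init__(self, fluorescent_files, mask_files, seed=None):
        if not fluorescent_files or not mask_files:
            raise ValueError("Both file lists must be non-empty")
        self.fluorescent_files = list(fluorescent_files)
        self.mask_files = list(mask_files)
        self.rng = random.Random(seed)

    def __len__(self):
        # One epoch visits every fluorescent image once
        return len(self.fluorescent_files)

    def __getitem__(self, idx):
        # The mask is drawn independently of idx: the two images do not correspond
        return self.fluorescent_files[idx], self.rng.choice(self.mask_files)
```

This is also why a pixel-wise loss between a generated image and the batch's real fluorescent image can only act as a loose regularizer in unpaired training.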

## 🎯 Step 6: Training Configuration

In [None]:
# Training configuration
TRAINING_CONFIG = {
    'batch_size': 8,
    'num_epochs': 50,  # Reduced for Colab time limits
    'learning_rate': 0.0002,
    'save_frequency': 5,  # Save samples every 5 epochs
    'model_type': 'complex',  # or 'simple' for faster training
    'max_images': None,  # Use all available images, or set a number for faster training
    'device': device
}

print("🎯 Training Configuration:")
for key, value in TRAINING_CONFIG.items():
    print(f"  {key}: {value}")

# Create output directory
output_dir = 'colab_outputs'
os.makedirs(output_dir, exist_ok=True)
print(f"\n📁 Output directory: {output_dir}")

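# Rough estimate only: assumes ~2 s per batch (a guess, not a measurement). `dataset` is the
# Step 5 dataset created with max_images=100, so with max_images=None the real run may be longer.
import math
estimated_time_per_epoch = math.ceil(len(dataset) / TRAINING_CONFIG['batch_size']) * 2  # seconds
total_estimated_time = estimated_time_per_epoch * TRAINING_CONFIG['num_epochs'] / 60  # minutes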
print(f"\n⏱️ Estimated training time: {total_estimated_time:.1f} minutes")

if total_estimated_time > 60:
    print("⚠️ Training might take more than 1 hour. Consider reducing num_epochs or using max_images for faster training.")
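
The estimate above relies on a fixed guess of about 2 seconds per batch; timing a few real steps gives a better number. A minimal sketch with `time.perf_counter`; `seconds_per_call` is a hypothetical helper. When the timed function runs on the GPU it should end with `torch.cuda.synchronize()`, because CUDA kernels are launched asynchronously and the timer would otherwise stop before the work finishes.

```python
import time

def seconds_per_call(fn, repeats=5, warmup=1):
    """Average wall-clock seconds per call of fn(), after a few untimed warm-up calls."""
    for _ in range(warmup):
        fn()                      # first calls can include one-off setup (e.g. CUDA initialisation)
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    return (time.perf_counter() - start) / repeats
```

Multiplying the measured value by the number of batches per epoch and by `num_epochs` replaces the 2-second assumption.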

## 🚀 Step 7: Run Training

In [None]:
# Training function (simplified version of the training script)
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import matplotlib.pyplot as plt
from tqdm import tqdm
import time

def train_conditional_gan(config):
    """Train the conditional GAN with the given configuration."""
    
    # Create fresh dataset with config
    dataset = UnpairedConditionalImageDataset(
        fluorescent_dir=FLUORESCENCE_DIR,
        mask_dir=DISTANCE_MASKS_DIR,
        image_size=256,
        max_images=config['max_images']
    )
    
    dataloader = DataLoader(
        dataset, 
        batch_size=config['batch_size'], 
        shuffle=True, 
        num_workers=0  # load in the main process; Colab also supports a small number of workers
    )
    
    # Initialize models
    if config['model_type'] == 'simple':
        generator = SimpleConditionalGenerator().to(config['device'])
        discriminator = SimpleConditionalDiscriminator().to(config['device'])
    else:
        generator = ConditionalGenerator().to(config['device'])
        discriminator = ConditionalDiscriminator().to(config['device'])
    
    # Optimizers
    g_optimizer = torch.optim.Adam(generator.parameters(), lr=config['learning_rate'], betas=(0.5, 0.999))
    d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=config['learning_rate']*0.5, betas=(0.5, 0.999))
    
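    # Loss functions. BCELoss expects discriminator outputs in [0, 1] (a final sigmoid);
    # if the discriminator returns raw logits, nn.BCEWithLogitsLoss is the matching choice.
    adversarial_criterion = nn.BCELoss()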
    identity_criterion = nn.L1Loss()
    
    # Training tracking
    training_history = {
        'g_losses': [],
        'd_losses': [],
        'adv_losses': [],
        'identity_losses': []
    }
    
    print(f"🚀 Starting training on {len(dataset)} images...")
    start_time = time.time()
    
    # Training loop
    for epoch in range(config['num_epochs']):
        epoch_g_loss = 0.0
        epoch_d_loss = 0.0
        epoch_adv_loss = 0.0
        epoch_identity_loss = 0.0
        num_batches = 0
        
        # Progress bar for the epoch
        pbar = tqdm(dataloader, desc=f'Epoch {epoch+1}/{config["num_epochs"]}')
        
        for batch_idx, (real_fluorescent, condition_masks) in enumerate(pbar):
            real_fluorescent = real_fluorescent.to(config['device'])
            condition_masks = condition_masks.to(config['device'])
            batch_size = real_fluorescent.size(0)
            
            # Train Generator
            g_optimizer.zero_grad()
            
            z = torch.randn(batch_size, 100, device=config['device'])
            generated_fluorescent = generator(z, condition_masks)
            
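            # Identity loss (structure preservation). With an unpaired dataset, real_fluorescent is
            # not the ground truth for condition_masks, so this L1 term acts as a regularizer
            # rather than a paired reconstruction target.
            identity_loss = identity_criterion(generated_fluorescent, real_fluorescent) * 0.1
            
            # Adversarial loss: reduce (B,), (B, 1) or (B, 1, H, W) outputs to one score per
            # sample so the shape matches the (B,) labels, as BCELoss requires
            d_output_fake = discriminator(generated_fluorescent, condition_masks)
            d_output_fake = d_output_fake.reshape(batch_size, -1).mean(dim=1)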
            
            valid_labels = torch.ones(batch_size, device=config['device'])
            adversarial_loss = adversarial_criterion(d_output_fake, valid_labels)
            
            g_loss = adversarial_loss + identity_loss
            g_loss.backward()
            g_optimizer.step()
            
            # Train Discriminator
            d_optimizer.zero_grad()
            
            # Real samples
            d_output_real = discriminator(real_fluorescent, condition_masks)
            d_output_real = d_output_real.reshape(batch_size, -1).mean(dim=1)  # one score per sample
            
            real_labels = torch.ones(batch_size, device=config['device'])
            d_real_loss = adversarial_criterion(d_output_real, real_labels)
            
            # Fake samples
            d_output_fake_for_d = discriminator(generated_fluorescent.detach(), condition_masks)
            d_output_fake_for_d = d_output_fake_for_d.reshape(batch_size, -1).mean(dim=1)  # one score per sample
            
            fake_labels = torch.zeros(batch_size, device=config['device'])
            d_fake_loss = adversarial_criterion(d_output_fake_for_d, fake_labels)
            
            d_loss = (d_real_loss + d_fake_loss) / 2
            d_loss.backward()
            d_optimizer.step()
            
            # Update metrics
            epoch_g_loss += g_loss.item()
            epoch_d_loss += d_loss.item()
            epoch_adv_loss += adversarial_loss.item()
            epoch_identity_loss += identity_loss.item()
            num_batches += 1
            
            # Update progress bar
            pbar.set_postfix({
                'G_loss': f'{g_loss.item():.3f}',
                'D_loss': f'{d_loss.item():.3f}',
                'Adv': f'{adversarial_loss.item():.3f}',
                'Id': f'{identity_loss.item():.3f}'
            })
        
        # Epoch summary
        avg_g_loss = epoch_g_loss / num_batches
        avg_d_loss = epoch_d_loss / num_batches
        avg_adv_loss = epoch_adv_loss / num_batches
        avg_identity_loss = epoch_identity_loss / num_batches
        
        training_history['g_losses'].append(avg_g_loss)
        training_history['d_losses'].append(avg_d_loss)
        training_history['adv_losses'].append(avg_adv_loss)
        training_history['identity_losses'].append(avg_identity_loss)
        
        print(f"\n📊 Epoch {epoch+1} Summary:")
        print(f"  G_loss: {avg_g_loss:.4f}, D_loss: {avg_d_loss:.4f}")
        print(f"  Adv_loss: {avg_adv_loss:.4f}, Identity_loss: {avg_identity_loss:.4f}")
        
        # Save samples
        if (epoch + 1) % config['save_frequency'] == 0:
            generator.eval()
            with torch.no_grad():
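                num_samples = min(4, condition_masks.size(0))  # the last batch may hold fewer than 4
                sample_z = torch.randn(num_samples, 100, device=config['device'])
                sample_mask = condition_masks[:num_samples]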
                sample_generated = generator(sample_z, sample_mask)
                
                save_image(
                    sample_generated,
                    f'{output_dir}/samples_epoch_{epoch+1:03d}.png',
                    normalize=True,
                    nrow=2
                )
                print(f"💾 Saved sample images for epoch {epoch+1}")
            generator.train()
    
    # Save final models
    torch.save(generator.state_dict(), f'{output_dir}/final_generator.pth')
    torch.save(discriminator.state_dict(), f'{output_dir}/final_discriminator.pth')
    
    total_time = time.time() - start_time
    print(f"\n🎉 Training completed in {total_time/60:.1f} minutes!")
    
    return training_history, generator, discriminator

print("✅ Training function defined and ready!")
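
The training loop collapses a patch-style discriminator output of shape `(B, 1, H, W)` to one probability per sample before calling `nn.BCELoss`, which needs input and target of the same shape and probabilities in [0, 1]. The NumPy sketch below reproduces that arithmetic so shapes and values can be checked by hand; `per_sample_score` and `binary_cross_entropy` are hypothetical names. It clips probabilities to avoid `log(0)`, whereas PyTorch clamps the log terms at -100, so the two agree except at probabilities extremely close to 0 or 1.

```python
import numpy as np

def per_sample_score(d_out):
    """Collapse (B,), (B, 1) or (B, 1, H, W) discriminator outputs to shape (B,) by averaging."""
    d_out = np.asarray(d_out, dtype=np.float64)
    return d_out.reshape(d_out.shape[0], -1).mean(axis=1)

def binary_cross_entropy(probs, targets, eps=1e-12):
    """Mean binary cross-entropy, the quantity nn.BCELoss computes with reduction='mean'."""
    p = np.clip(np.asarray(probs, dtype=np.float64), eps, 1.0 - eps)  # avoid log(0)
    t = np.asarray(targets, dtype=np.float64)
    return float(np.mean(-(t * np.log(p) + (1.0 - t) * np.log(1.0 - p))))

# Two samples, each with a 4x4 patch map of 0.5 ("undecided") probabilities
patch_probs = np.full((2, 1, 4, 4), 0.5)
scores = per_sample_score(patch_probs)           # shape (2,), values [0.5, 0.5]
print(binary_cross_entropy(scores, np.ones(2)))  # ln 2, about 0.6931
```

A loss near ln 2 for both real and fake batches means the discriminator is effectively guessing.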

In [None]:
# Start training!
print("🚀 Starting Conditional GAN Training...")
print("⚠️  This may take a while. You can monitor progress with the progress bars.")
print("📱 Colab sessions are time-limited (up to ~12 hours, often less on the free tier or when idle). Models are only saved at the end, so consider shorter training runs.")

# Run training
training_history, trained_generator, trained_discriminator = train_conditional_gan(TRAINING_CONFIG)

## 📈 Step 8: Visualize Training Results

In [None]:
# Plot training history
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Generator and Discriminator losses
axes[0, 0].plot(training_history['g_losses'], label='Generator Loss', color='blue')
axes[0, 0].plot(training_history['d_losses'], label='Discriminator Loss', color='red')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].set_title('Generator vs Discriminator Loss')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# Adversarial loss
axes[0, 1].plot(training_history['adv_losses'], label='Adversarial Loss', color='green')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Loss')
axes[0, 1].set_title('Adversarial Loss')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# Identity loss
axes[1, 0].plot(training_history['identity_losses'], label='Identity Loss', color='orange')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Loss')
axes[1, 0].set_title('Identity Loss (Structure Preservation)')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

# Loss ratio (training balance)
if len(training_history['g_losses']) > 0 and len(training_history['d_losses']) > 0:
    loss_ratios = [g/(d+1e-8) for g, d in zip(training_history['g_losses'], training_history['d_losses'])]
    axes[1, 1].plot(loss_ratios, label='G_loss / D_loss', color='purple')
    axes[1, 1].axhline(y=1.0, color='red', linestyle='--', alpha=0.5, label='Balanced (ratio=1)')
    axes[1, 1].set_xlabel('Epoch')
    axes[1, 1].set_ylabel('Loss Ratio')
    axes[1, 1].set_title('Training Balance')
    axes[1, 1].legend()
    axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f'{output_dir}/training_history.png', dpi=150, bbox_inches='tight')
plt.show()

print("📈 Training history plotted and saved!")
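
Per-epoch GAN losses are often noisy, which makes trends hard to read from the raw curves. A trailing moving average is a simple way to smooth them. A minimal NumPy sketch; `moving_average` and the window size are illustrative choices. The smoothed curve is `window - 1` points shorter, so its x-values should start at epoch index `window - 1`.

```python
import numpy as np

def moving_average(values, window=5):
    """Trailing-window mean of a loss curve; the result has len(values) - window + 1 points."""
    values = np.asarray(values, dtype=np.float64)
    if window < 1 or window > len(values):
        raise ValueError("window must be between 1 and len(values)")
    return np.convolve(values, np.ones(window) / window, mode='valid')
```

With at least `w` epochs recorded, for example: `g = training_history['g_losses']; axes[0, 0].plot(range(w - 1, len(g)), moving_average(g, w), label='Generator (smoothed)')`.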

## 🎨 Step 9: Generate and Visualize Results

In [None]:
# Generate new samples with the trained model
trained_generator.eval()

with torch.no_grad():
    # Get some real samples for comparison
    real_fluor, real_masks = next(iter(dataloader))
    real_fluor = real_fluor[:4].to(device)
    real_masks = real_masks[:4].to(device)
    
    # Generate new samples using the real masks
    z = torch.randn(4, 100, device=device)
    generated_fluor = trained_generator(z, real_masks)
    
    # Convert to numpy for visualization
    real_fluor_np = real_fluor.cpu().numpy()
    real_masks_np = real_masks.cpu().numpy()
    generated_fluor_np = generated_fluor.cpu().numpy()
    
    # Create comparison plot
    fig, axes = plt.subplots(3, 4, figsize=(16, 12))
    
    for i in range(4):
        # Real masks
        axes[0, i].imshow(real_masks_np[i, 0], cmap='hot')
        axes[0, i].set_title(f'Distance Mask {i+1}')
        axes[0, i].axis('off')
        
        # Real fluorescent
        real_img = (real_fluor_np[i, 0] + 1) / 2  # Denormalize, assuming data scaled to [-1, 1]
        axes[1, i].imshow(real_img, cmap='Greens')
        axes[1, i].set_title(f'Real Fluorescent {i+1}')
        axes[1, i].axis('off')
        
        # Generated fluorescent
        gen_img = (generated_fluor_np[i, 0] + 1) / 2  # Denormalize, assuming generator output in [-1, 1]
        axes[2, i].imshow(gen_img, cmap='Greens')
        axes[2, i].set_title(f'Generated Fluorescent {i+1}')
        axes[2, i].axis('off')
    
    plt.suptitle('Conditional GAN Results: Distance Masks → Generated Fluorescent Images', fontsize=16)
    plt.tight_layout()
    plt.savefig(f'{output_dir}/final_results.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    print("🎨 Generated new fluorescent images conditioned on distance masks!")
    print(f"💾 Results saved to {output_dir}/final_results.png")
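
Step 9 maps images back with `(x + 1) / 2`, which assumes both the data and the generator output span [-1, 1] (for example through a final `tanh`); the ranges printed in Step 5 are a quick check on the data side. For single-channel arrays, `imshow` also scales each image to its own minimum and maximum unless `vmin` and `vmax` are given, so the real and generated rows are not on a common intensity scale. A minimal NumPy sketch that denormalizes and clips; `denormalize_minus1_1` is a hypothetical helper.

```python
import numpy as np

def denormalize_minus1_1(img):
    """Map an array from [-1, 1] to [0, 1], clipping any overshoot."""
    return np.clip((np.asarray(img, dtype=np.float32) + 1.0) / 2.0, 0.0, 1.0)
```

Used with fixed display limits, for example `axes[2, i].imshow(denormalize_minus1_1(generated_fluor_np[i, 0]), cmap='Greens', vmin=0, vmax=1)`, and likewise for the real row.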

## 💾 Step 10: Save Results to Google Drive (Optional)

In [None]:
# Copy results to Google Drive (if mounted)
try:
    if os.path.exists('/content/drive/MyDrive'):
        import shutil
        
        # Create output directory in Drive
        drive_output_dir = '/content/drive/MyDrive/CellSimul_Training_Results'
        os.makedirs(drive_output_dir, exist_ok=True)
        
        # Copy all output files
        for file in os.listdir(output_dir):
            src = os.path.join(output_dir, file)
            dst = os.path.join(drive_output_dir, file)
            shutil.copy2(src, dst)
        
        print(f"✅ Results copied to Google Drive: {drive_output_dir}")
        print("📁 Saved files:")
        for file in os.listdir(drive_output_dir):
            print(f"  - {file}")
    else:
        print("ℹ️ Google Drive not mounted. Results are saved locally in Colab.")
        print("⚠️ Remember that Colab files are temporary and will be lost when the session ends.")
        
except Exception as e:
    print(f"❌ Error saving to Drive: {e}")
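
Step 10 copies into the same Drive folder every time, so a second training run overwrites the previous `final_generator.pth`. One way to keep runs apart is a timestamped subfolder per run. A minimal sketch using `shutil` and `datetime`; `copy_run_outputs` and the `run_YYYYMMDD_HHMMSS` naming are hypothetical choices.

```python
import os
import shutil
from datetime import datetime

def copy_run_outputs(src_dir, drive_root):
    """Copy the files in src_dir into a new timestamped subfolder of drive_root; return its path."""
    run_dir = os.path.join(drive_root, datetime.now().strftime('run_%Y%m%d_%H%M%S'))
    os.makedirs(run_dir, exist_ok=False)   # fail rather than silently merge with another run
    for name in os.listdir(src_dir):
        src = os.path.join(src_dir, name)
        if os.path.isfile(src):            # skip subdirectories; copy2 copies single files
            shutil.copy2(src, os.path.join(run_dir, name))
    return run_dir
```

For example: `copy_run_outputs(output_dir, '/content/drive/MyDrive/CellSimul_Training_Results')`.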

## 📥 Step 11: Download Results (Alternative)

In [None]:
# Download results as ZIP file
from google.colab import files
import zipfile

# Create ZIP file with all results
zip_filename = 'cellsimul_training_results.zip'
with zipfile.ZipFile(zip_filename, 'w') as zipf:
    for root, dirs, files_list in os.walk(output_dir):
        for file in files_list:
            file_path = os.path.join(root, file)
            arcname = os.path.relpath(file_path, output_dir)
            zipf.write(file_path, arcname)

print(f"📦 Created {zip_filename} with all training results")
print(f"📊 ZIP file size: {os.path.getsize(zip_filename) / 1024 / 1024:.1f} MB")

# Download the ZIP file
print("⬇️ Downloading results...")
files.download(zip_filename)
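
Step 11 builds the archive with `zipfile.ZipFile(zip_filename, 'w')`, whose default compression is `ZIP_STORED` (no compression). The standard library's `shutil.make_archive` writes a deflate-compressed ZIP in one call. A minimal sketch; `zip_directory` is a hypothetical wrapper. PNG samples are already compressed, so any saving comes mostly from the `.pth` files and may be modest.

```python
import shutil

def zip_directory(src_dir, archive_base):
    """Create archive_base + '.zip' from the contents of src_dir and return the archive path."""
    # Entries are stored relative to src_dir; keep archive_base outside src_dir
    return shutil.make_archive(archive_base, 'zip', root_dir=src_dir)
```

For example: `files.download(zip_directory(output_dir, 'cellsimul_training_results'))`.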

## 🎯 Step 12: Using the Trained Model

In [None]:
# Example of how to use the trained model for inference
def generate_fluorescent_from_mask(generator, distance_mask, device, num_variations=4):
    """
    Generate fluorescent images from a distance mask.
    
    Args:
        generator: Trained generator model
        distance_mask: Input distance mask (numpy array or tensor)
        device: Device to run inference on
        num_variations: Number of different generations from the same mask
    
    Returns:
        Generated fluorescent images
    """
    generator.eval()
    
    with torch.no_grad():
        # Prepare mask tensor
        if isinstance(distance_mask, np.ndarray):
            mask_tensor = torch.from_numpy(distance_mask).float()
        else:
            mask_tensor = distance_mask.float()
        
        # Ensure correct shape
        if mask_tensor.dim() == 2:
            mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0)  # Add batch and channel dims
        elif mask_tensor.dim() == 3:
            mask_tensor = mask_tensor.unsqueeze(0)  # Add batch dim
        
        # Repeat for multiple variations
        mask_batch = mask_tensor.repeat(num_variations, 1, 1, 1).to(device)
        
        # Generate random noise
        noise = torch.randn(num_variations, 100, device=device)
        
        # Generate fluorescent images
        generated = generator(noise, mask_batch)
        
        return generated

# Example usage
print("🧪 Testing inference with trained model...")

# Get a sample mask
sample_mask = real_masks[0:1]  # Take first mask

# Generate variations
variations = generate_fluorescent_from_mask(trained_generator, sample_mask, device, num_variations=6)

# Visualize variations
fig, axes = plt.subplots(2, 4, figsize=(16, 8))

# Show original mask
axes[0, 0].imshow(sample_mask[0, 0].cpu().numpy(), cmap='hot')
axes[0, 0].set_title('Input Distance Mask')
axes[0, 0].axis('off')

# Show variations in the remaining panels: row 0 cols 1-3, then row 1 cols 1-3
for i in range(6):
    row = i // 3
    col = (i % 3) + 1
    
    gen_img = (variations[i, 0].cpu().numpy() + 1) / 2  # Denormalize
    axes[row, col].imshow(gen_img, cmap='Greens')
    axes[row, col].set_title(f'Generated Variation {i+1}')
    axes[row, col].axis('off')

# Hide the unused panel below the input mask
axes[1, 0].axis('off')

plt.suptitle('Single Distance Mask → Multiple Fluorescent Variations', fontsize=16)
plt.tight_layout()
plt.savefig(f'{output_dir}/inference_example.png', dpi=150, bbox_inches='tight')
plt.show()

print("✅ Inference example completed!")
print("🎯 The trained model can generate diverse fluorescent images from the same distance mask.")
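
To reuse generated variations as synthetic training data (one of the next steps below), they can be written to disk as TIF files. A minimal sketch that converts a generator output, assumed to lie in [-1, 1], into a 16-bit array; `to_uint16_image` is a hypothetical helper, and 16-bit output is an arbitrary choice (8-bit would match the Option C sample data).

```python
import numpy as np

def to_uint16_image(img):
    """Convert an array assumed to be in [-1, 1] to uint16 spanning 0..65535 (values clipped)."""
    img01 = np.clip((np.asarray(img, dtype=np.float64) + 1.0) / 2.0, 0.0, 1.0)
    return np.round(img01 * 65535).astype(np.uint16)
```

Each variation could then be saved with `tifffile.imwrite(f'{output_dir}/variation_{i:02d}.tif', to_uint16_image(variations[i, 0].cpu().numpy()))`.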

## 📋 Summary and Next Steps

🎉 **Congratulations!** You have successfully:

✅ Set up the CellSimul environment in Google Colab  
✅ Loaded your fluorescent microscopy data  
✅ Trained a conditional GAN model  
✅ Generated fluorescent images conditioned on distance masks  
✅ Saved and visualized your results  

### 🔄 **To run more training:**
- Adjust `TRAINING_CONFIG` parameters above
- Re-run the training cell
- Try different model types (`simple` vs `complex`)

### 📈 **To improve results:**
- Increase `num_epochs` for longer training
- Adjust `learning_rate` if training is unstable
- Use more training data if available
- Experiment with different loss weightings

### 💾 **Your trained models are saved as:**
- `final_generator.pth` - The trained generator
- `final_discriminator.pth` - The trained discriminator
- Sample images and training plots

### 🚀 **Next steps:**
- Use the trained model to generate synthetic training data
- Fine-tune on specific cell types or conditions
- Integrate into your research pipeline
- Share results with collaborators

---

**Need help?** Check the [CellSimul repository](https://github.com/Cetus137/CellSimul) for documentation and examples!