In [0]:
#| default_exp synthetic_data.diffusion_model

# Synthetic data with diffusion model
> Fine-tune Stable Diffusion on chest X-ray images and generate synthetic X-ray images with corresponding masks

In [1]:
#| hide
%load_ext autoreload
%autoreload 2

In [2]:
#| export
from cv_tools.core import *
from cv_tools.imports import *


In [40]:
#| export
import torch
from torch.utils.data import Dataset, DataLoader
from diffusers import StableDiffusionPipeline, DDPMScheduler, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer
from accelerate import Accelerator
import albumentations as A
from PIL import Image
import numpy as np
import os
from fastcore.all import *
import time
import logging
import gc

In [4]:
#| export
class XRayDataset(Dataset):
    """Paired X-ray image / mask dataset with a fixed text prompt for Stable Diffusion fine-tuning.

    Each item is a dict with:
      - "pixel_values": float tensor (3, size, size) from an RGB image resized to `size`
        and normalized with mean=std=0.5 (roughly [-1, 1] with albumentations' default
        max_pixel_value=255).
      - "mask": float tensor (1, size, size) with the raw grayscale mask intensities (not normalized).
      - "text_ids": token ids of `prompt`, padded/truncated to `tokenizer.model_max_length`.

    Args:
        image_paths: sequence of image file paths.
        mask_paths: sequence of mask file paths, aligned index-by-index with `image_paths`.
        tokenizer: callable tokenizer (e.g. a CLIPTokenizer) exposing `model_max_length`.
        size: output height and width.
        prompt: text prompt attached to every sample.

    Raises:
        ValueError: if `image_paths` and `mask_paths` have different lengths.
    """
    def __init__(self, image_paths, mask_paths, tokenizer, size=512, prompt="X-ray image of chest"):
        if len(image_paths) != len(mask_paths):
            raise ValueError(
                f"Number of images ({len(image_paths)}) and masks ({len(mask_paths)}) must be the same"
            )
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.tokenizer = tokenizer
        self.size = size
        self.prompt = prompt

        self.transform = A.Compose([
            A.Resize(size, size),
            A.Normalize(mean=[0.5], std=[0.5]),
        ])

        # Every sample uses the same prompt, so tokenize it once here instead of per item.
        self.text_ids = tokenizer(
            prompt,
            padding="max_length",
            truncation=True,
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Context managers close the underlying file handles promptly.
        with Image.open(self.image_paths[idx]) as img:
            image = np.array(img.convert('RGB'))
        with Image.open(self.mask_paths[idx]) as msk:
            mask = np.array(msk.convert('L'))

        # Resize applies to both image and mask; Normalize only affects the image.
        transformed = self.transform(image=image, mask=mask)
        image = torch.from_numpy(transformed['image']).permute(2, 0, 1).float()
        mask = torch.from_numpy(transformed['mask']).unsqueeze(0).float()

        return {
            "pixel_values": image,
            "mask": mask,
            # Clone so in-place edits on one sample's ids cannot leak into other samples.
            "text_ids": self.text_ids.clone(),
        }


In [6]:
#| export
def fine_tune_stable_diffusion(
    train_dataset,
    output_dir="./sd-xray-model",
    num_epochs=100,
    batch_size=4,
    learning_rate=1e-5,
    pretrained_model="CompVis/stable-diffusion-v1-4",
    checkpoint_every=10,
    num_workers=4,
):
    """Fine-tune the Stable Diffusion UNet on `train_dataset` with the standard noise-prediction loss.

    Images are encoded into the VAE latent space (the UNet operates on latents, not
    pixels), noised at random DDPM timesteps, and the UNet learns to predict that noise
    conditioned on the frozen CLIP text encoder's embeddings. Only the UNet is trained.

    Args:
        train_dataset: dataset yielding dicts with "pixel_values" (3-channel images,
            assumed normalized to [-1, 1] as XRayDataset does) and "text_ids" (token ids).
            NOTE(review): any "mask" entry is currently ignored by this loop.
        output_dir: directory that receives `checkpoint-{epoch}` accelerator states.
        num_epochs: number of passes over the dataset.
        batch_size: DataLoader batch size.
        learning_rate: AdamW learning rate for the UNet.
        pretrained_model: Hugging Face model id / path providing the sub-models.
        checkpoint_every: save accelerator state every this many epochs (starting at epoch 0).
        num_workers: DataLoader worker processes.

    Returns:
        The trained UNet, unwrapped from any accelerator/distributed wrapper so it can be
        passed directly to a diffusers pipeline.
    """
    # Local import keeps this exported cell self-contained; AutoencoderKL maps pixels to latents.
    from diffusers import AutoencoderKL

    accelerator = Accelerator()

    text_encoder = CLIPTextModel.from_pretrained(pretrained_model, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(pretrained_model, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(pretrained_model, subfolder="unet")
    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model, subfolder="scheduler")

    # Text encoder and VAE are frozen; only the UNet receives gradients.
    text_encoder.requires_grad_(False)
    vae.requires_grad_(False)

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )

    optimizer = torch.optim.AdamW(unet.parameters(), lr=learning_rate)

    unet, optimizer, train_dataloader = accelerator.prepare(
        unet, optimizer, train_dataloader
    )
    # Frozen models are not passed to prepare(), so place them on the training device manually.
    text_encoder.to(accelerator.device)
    vae.to(accelerator.device)

    num_train_timesteps = noise_scheduler.config.num_train_timesteps
    # Latent scaling used by SD v1 VAEs; older diffusers configs may lack the attribute.
    latent_scale = getattr(vae.config, "scaling_factor", 0.18215)

    for epoch in range(num_epochs):
        unet.train()
        for step, batch in enumerate(train_dataloader):
            with torch.no_grad():
                latents = vae.encode(batch["pixel_values"]).latent_dist.sample() * latent_scale
                text_embeddings = text_encoder(batch["text_ids"])[0]

            noise = torch.randn_like(latents)
            # Timesteps must live on the same device as the latents for add_noise / the UNet.
            timesteps = torch.randint(
                0, num_train_timesteps, (latents.shape[0],), device=latents.device
            ).long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            noise_pred = unet(noisy_latents, timesteps, text_embeddings).sample

            loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float())

            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()

            if step % 100 == 0:
                # accelerator.print only prints on the main process.
                accelerator.print(f"Epoch {epoch}, Step {step}: Loss {loss.item():.4f}")

        if epoch % checkpoint_every == 0:
            accelerator.save_state(f"{output_dir}/checkpoint-{epoch}")

    return accelerator.unwrap_model(unet)

def generate_xray_with_mask(
    prompt,
    pipeline,
    num_inference_steps=50,
    guidance_scale=7.5,
    num_images=1,
    threshold=128,
):
    """Generate images from `prompt` and derive a binary mask for each by intensity thresholding.

    Args:
        prompt: text prompt passed to the pipeline.
        pipeline: callable diffusers-style pipeline whose result has an `.images` list.
        num_inference_steps: denoising steps forwarded to the pipeline.
        guidance_scale: classifier-free guidance scale forwarded to the pipeline.
        num_images: forwarded as `num_images_per_prompt`.
        threshold: pixels whose (channel-mean) intensity is strictly above this become 255, others 0.

    Returns:
        (images, masks): the pipeline's images and one 8-bit PIL mask per image.
    """
    images = pipeline(
        prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_images
    ).images

    masks = []
    for image in images:
        pixels = np.asarray(image)
        # Collapse colour channels to intensity; single-channel images are already 2-D.
        intensity = pixels.mean(axis=2) if pixels.ndim == 3 else pixels
        # Simple thresholding for demonstration; swap in a real segmenter for production masks.
        mask = (intensity > threshold).astype(np.uint8) * 255
        masks.append(Image.fromarray(mask))

    return images, masks


In [None]:


# Example usage
if __name__ == "__main__":
    # Setup paths (placeholders: replace with lists of individual image / mask files)
    image_paths = ["path/to/xray/images"]
    mask_paths = ["path/to/mask/images"]

    # Initialize tokenizer
    tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer")

    # Create dataset
    dataset = XRayDataset(image_paths, mask_paths, tokenizer)

    # Fine-tune model
    fine_tuned_unet = fine_tune_stable_diffusion(dataset)

    # NOTE(review): `torch_dtype` does not appear to be applied to components passed in
    # explicitly, so cast the fp32-trained UNet to fp16 to match the rest of the pipeline.
    pipeline = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        unet=fine_tuned_unet.to(torch.float16),
        torch_dtype=torch.float16
    ).to("cuda")

    # Generate new images with masks
    images, masks = generate_xray_with_mask(
        "X-ray image of chest with clear lung fields",
        pipeline
    )

# Memory-optimized training for an RTX 2080 Ti (11 GB VRAM)

In [10]:
#| export
import torch
from diffusers import StableDiffusionPipeline, DDPMScheduler
from accelerate import Accelerator
import torch.nn.functional as F
import psutil
import GPUtil
from threading import Thread

In [12]:
#| export
class GPUMonitor:
    """
    Real-time GPU memory monitoring.

    `start()` launches a background daemon thread running `self._monitor`, which is
    expected to append samples to `memory_history` every `delay` seconds while
    `monitoring` is True. `stop()` clears that flag so the loop exits.
    """
    def __init__(self, delay=1):
        self.delay = delay
        self.monitoring = False
        self.memory_history = []
        self._thread = None

    def start(self):
        if self._thread is not None and self._thread.is_alive():
            if self.monitoring:
                # Already sampling; a second thread would duplicate every sample.
                return
            # A previous stop() is still winding down; wait for it before relaunching.
            self._thread.join()
        self.monitoring = True
        # daemon=True so a forgotten stop() cannot keep the interpreter/kernel from exiting.
        self._thread = Thread(target=self._monitor, daemon=True)
        self._thread.start()

    def stop(self, wait=False):
        """Signal the sampling loop to exit; with `wait=True`, block until the thread ends."""
        self.monitoring = False
        if wait and self._thread is not None and self._thread.is_alive():
            self._thread.join()

In [None]:
#| export
@patch_to(GPUMonitor)
def _monitor(self, alert_fraction=0.9):
    """Sampling loop run on the thread started by `GPUMonitor.start`.

    While `self.monitoring` is True, every `self.delay` seconds this appends
    {'used': allocated GiB, 'reserved': reserved GiB, 'timestamp': time.time()}
    to `self.memory_history`. It prints a warning when allocated memory exceeds
    `alert_fraction` of the current CUDA device's total memory.
    """
    gib = 1024 ** 3
    total_bytes = None  # looked up lazily so nothing touches CUDA if monitoring is already off
    while self.monitoring:
        if total_bytes is None:
            total_bytes = torch.cuda.get_device_properties(torch.cuda.current_device()).total_memory

        allocated_bytes = torch.cuda.memory_allocated()
        reserved_bytes = torch.cuda.memory_reserved()

        self.memory_history.append(
            {
                'used': allocated_bytes / gib,
                'reserved': reserved_bytes / gib,
                'timestamp': time.time()
            }
        )

        # Compare bytes with bytes; dividing GiB by a byte count would make the ratio ~0.
        if allocated_bytes / total_bytes > alert_fraction:
            print(f"\n⚠️ WARNING: High GPU memory usage: {allocated_bytes / gib:.2f}GB")

        time.sleep(self.delay)


In [16]:
#| export
@patch_to(GPUMonitor)
def get_stats(self):
    """Summarize the recorded 'used' samples from `memory_history`.

    Returns None if nothing has been sampled yet; otherwise a dict with the peak
    ('max_used'), arithmetic mean ('avg_used') and most recent ('current') value.
    """
    history = self.memory_history
    if len(history) == 0:
        return None
    used = [sample['used'] for sample in history]
    return dict(
        max_used=np.max(used),
        avg_used=sum(used) / len(used),
        current=used[-1],
    )


In [29]:
#| export
class MemoryOptimizedDataset(Dataset):
	"Memory efficient dataset"
	def __init__(
			self, 
			image_paths,
			mask_paths,
			cache_dir='./cache',
			max_images=None,
			resolution=512
		):
		self.image_paths = [Path(p) for p in image_paths]
		self.mask_paths = [Path(p) for p in mask_paths]

		if len(self.image_paths) != len(self.mask_paths):
			raise ValueError(
				f"Number of images ({len(self.image_paths)}) and masks ({len(self.mask_paths)}) must be the same"
			)
		
		if max_images is not None:
			if not isinstance(max_images, int):
				raise ValueError(
					f"max_images must be an integer, got {type(max_images)}"
				)
			
			if max_images < 0:
				raise ValueError(
					f"max_images must be a non-negative integer, got {max_images}"
				)

			if max_images > len(image_paths):
				logging.warning(
                    f"max_images ({max_images}) is greater than available images "
                    f"({len(self.image_paths)}). Using all available images."
                )
				max_images = len(self.image_paths)

			self.image_paths = sorted(image_paths)[:max_images]
			self.mask_paths = sorted(mask_paths)[:max_images]
		
		# set up cache directory
		self.cache_dir = Path(cache_dir)

		Path(self.cache_dir).mkdir(parents=True, exist_ok=True)
		self.resolution = resolution

		self.transform = self._setup_transforms()
		self._validate_files()


		logging.info(
			f"Initialized MemoryOptimizedDataset with {len(self)} images"
		)

	def _setup_transforms(self)->A.Compose:
		"""Setup memory-efficient transforms"""

		# Memory-efficient transforms
		return 	A.Compose([
			A.SmallestMaxSize(max_size=self.resolution),
			A.CenterCrop(self.resolution, self.resolution),
			A.Normalize(mean=[0.5], std=[0.5]),
		])

	def _validate_files(self):
		"""Validate all files exist and are readable"""
		invalid_pairs = []
		for img_path, mask_path in zip(self.image_paths, self.mask_paths):
			if not img_path.exists():
				invalid_pairs.append((str(img_path), "Image file missing"))
			if not mask_path.exists():
				invalid_pairs.append((str(mask_path), "Mask file missing"))
                
		if invalid_pairs:
			error_msg = "\nInvalid files found:"
			for path, reason in invalid_pairs:
				error_msg += f"\n{path}: {reason}"
			raise FileNotFoundError(error_msg)
        
        # Pre-compute and cache image statistics
		self._compute_dataset_stats()

	def _compute_dataset_stats(self):
		"""Compute dataset statistics for memory-efficient normalization"""
		print("Computing dataset statistics...")
		means, stds = [], []
		
		for img_path in self.image_paths[:min(100, len(self.image_paths))]:
			img = np.array(Image.open(img_path).convert('RGB'))
			means.append(img.mean() / 255.0)
			stds.append(img.std() / 255.0)
        
		self.dataset_mean = np.mean(means)
		self.dataset_std = np.max(stds)  # Use max std for better normalization





In [30]:
#| export
@patch_to(MemoryOptimizedDataset)
def __len__(self):
    """Number of (image, mask) pairs in the dataset."""
    pair_count = len(self.image_paths)
    return pair_count



In [36]:
#| export
@patch_to(MemoryOptimizedDataset)
def get_cache_path(self, idx):
    """Return the cache file path for the processed sample at `idx`.

    The file name includes the index and a short hash of (image path, resolution).
    A cache directory reused with a different file list or resolution therefore
    misses and recomputes, rather than returning stale tensors for the same index.
    """
    import hashlib  # local import keeps this exported cell self-contained

    key = f"{self.image_paths[idx]}|{self.resolution}"
    digest = hashlib.sha256(key.encode("utf-8")).hexdigest()[:12]
    return self.cache_dir / f"processed_{idx}_{digest}.pt"

In [37]:
#| export
@patch_to(MemoryOptimizedDataset)
def load_image(self, path: Path, is_mask=None) -> "Tuple[Optional[np.ndarray], bool]":
    """Load an image file as a numpy array, reporting failure instead of raising.

    Args:
        path: file to open.
        is_mask: True to decode as single-channel 'L', False for 'RGB'. When None
            (the default), the mode is inferred by checking membership in
            `self.mask_paths` (a linear scan).

    Returns:
        (array, True) on success, or (None, False) if the file cannot be opened or
        decoded. Failures are logged rather than raised.
    """
    if is_mask is None:
        is_mask = path in self.mask_paths
    mode = 'L' if is_mask else 'RGB'
    try:
        with Image.open(path) as img:
            return np.array(img.convert(mode)), True
    except Exception as e:
        # Deliberate best-effort: callers check the success flag.
        logging.error(f"Error loading image {path}: {str(e)}")
        return None, False

In [38]:
#| export
@patch_to(MemoryOptimizedDataset)
def __getitem__(self, idx):
    """Return {'image', 'mask', 'path'} for `idx`, served from the on-disk cache when present.

    'image' is the transformed image permuted to channels-first, 'mask' is the
    transformed mask with a leading channel dim, and 'path' is the image path string.
    Freshly processed samples are written to `get_cache_path(idx)`.

    Raises:
        FileNotFoundError: if the image or mask cannot be loaded.
        Any exception from loading/transforming/caching is logged and re-raised. Returning
        None instead would only fail later inside the DataLoader's collate step, with a
        less useful error.
    """
    try:
        cache_path = self.get_cache_path(idx)
        if cache_path.exists():
            # NOTE(review): torch.load unpickles data; only point cache_dir at trusted files.
            return torch.load(cache_path)

        img_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]

        img, img_success = self.load_image(img_path)
        mask, mask_success = self.load_image(mask_path)

        if not (img_success and mask_success):
            raise FileNotFoundError(
                f"Error loading image or mask file at index {idx}"
            )

        transformed = self.transform(image=img, mask=mask)
        img_tensor = torch.from_numpy(transformed['image']).permute(2, 0, 1)
        mask_tensor = torch.from_numpy(transformed['mask']).unsqueeze(0)

        data = {
            'image': img_tensor,
            'mask': mask_tensor,
            'path': str(self.image_paths[idx])
        }

        # Write to a per-process temp file, then atomically rename. That way concurrent
        # DataLoader workers never torch.load a partially written cache file.
        tmp_path = cache_path.with_name(f"{cache_path.name}.{os.getpid()}.tmp")
        torch.save(data, tmp_path)
        os.replace(tmp_path, cache_path)

        return data

    except Exception as e:
        logging.error(f"Error processing image {self.image_paths[idx]}: {str(e)}")
        raise

In [33]:
#| export
def handle_oom_error():
    """
    Handle out-of-memory errors: free cached GPU memory, print the GPU's memory status,
    and print mitigation tips when usage is above 90%.

    Returns:
        True (always).
    """
    print("\n🚨 Out of Memory Error detected! Taking corrective actions...")

    # Collect unreachable Python objects first so tensors they hold are freed;
    # empty_cache can then return those now-unused cached blocks to the driver.
    gc.collect()
    torch.cuda.empty_cache()

    gpus = GPUtil.getGPUs()
    if not gpus:
        print("\nNo GPU reported by GPUtil; skipping memory status.")
        return True

    # NOTE(review): GPUtil's index 0 may not match torch's current device on multi-GPU hosts.
    gpu = gpus[0]
    memory_used = gpu.memoryUsed
    memory_total = gpu.memoryTotal

    print(f"\nGPU Memory Status:")
    print(f"Used: {memory_used}MB / {memory_total}MB")

    if memory_total and memory_used / memory_total > 0.9:
        print("\nRecommendations:")
        print("1. Reduce batch size")
        print("2. Enable gradient checkpointing")
        print("3. Use mixed precision training")
        print("4. Reduce image resolution")
        print("5. Enable attention slicing")

    return True

In [34]:
#| export
def setup_memory_efficient_training(
    dataset,
    batch_size=1,
    max_memory_usage=0.9
):
    """
    Setup training with memory limits and monitoring.

    Loads an fp16 Stable Diffusion pipeline with attention slicing, VAE slicing and
    sequential CPU offload enabled, and wraps `dataset` in a shuffled DataLoader.

    Args:
        dataset: map-style dataset for the DataLoader.
        batch_size: DataLoader batch size (incomplete final batch dropped).
        max_memory_usage: fraction in [0, 1] of the current GPU's memory this process may
            allocate. It is applied via torch.cuda.set_per_process_memory_fraction when CUDA
            is available; pass None to leave memory uncapped.

    Returns:
        (pipeline, dataloader, monitor); the GPUMonitor is created but not started.

    Raises:
        RuntimeError: from setup; out-of-memory errors first run handle_oom_error().
    """
    monitor = GPUMonitor()

    try:
        # Enforce the caller's cap on this process's share of GPU memory.
        if max_memory_usage is not None and torch.cuda.is_available():
            torch.cuda.set_per_process_memory_fraction(max_memory_usage)

        pipeline = StableDiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            torch_dtype=torch.float16,
            safety_checker=None,
            requires_safety_checker=False
        )

        pipeline.enable_attention_slicing(slice_size="max")
        pipeline.enable_vae_slicing()
        pipeline.enable_sequential_cpu_offload()

        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=1,
            pin_memory=False,  # Disable pin_memory to save RAM
            drop_last=True
        )

        return pipeline, dataloader, monitor

    except RuntimeError as e:
        if "out of memory" in str(e):
            handle_oom_error()
        raise

In [35]:
#| export
def monitor_training(monitor, pipeline, dataloader, num_epochs=1, memory_threshold_gb=10):
    """
    Monitor training process with memory tracking.

    Starts `monitor`, iterates `dataloader` for `num_epochs`, and frees memory whenever
    the monitor's latest 'current' reading exceeds `memory_threshold_gb`. The monitor
    is always stopped on exit.

    Args:
        monitor: object with start(), stop() and get_stats() (e.g. GPUMonitor).
        pipeline: model/pipeline the training step would use (placeholder here).
        dataloader: iterable of batches.
        num_epochs: passes over `dataloader`.
        memory_threshold_gb: GiB level that triggers cache clearing (10 suits an 11 GB 2080 Ti).

    Raises:
        RuntimeError: every RuntimeError is re-raised; out-of-memory errors first run
        handle_oom_error().
    """
    monitor.start()

    try:
        for _epoch in range(num_epochs):
            for batch in dataloader:
                stats = monitor.get_stats()
                if stats and stats['current'] > memory_threshold_gb:
                    print(f"\n⚠️ High memory usage detected: {stats['current']:.2f}GB")

                    # Collect garbage first so freed tensors can be released by empty_cache.
                    gc.collect()
                    torch.cuda.empty_cache()

                # Training step here (placeholder): run `pipeline` on `batch`.

    except RuntimeError as e:
        if "out of memory" in str(e):
            handle_oom_error()
        # Re-raise unconditionally so non-OOM failures are not silently swallowed.
        raise
    finally:
        monitor.stop()

In [None]:
# Example usage
if __name__ == "__main__":
    # Setup logging
    logging.basicConfig(level=logging.INFO)

    # Initialize dataset (placeholder paths: replace with real image / mask files)
    dataset = MemoryOptimizedDataset(
        image_paths=["path/to/images"],
        mask_paths=["path/to/masks"],
        resolution=512
    )

    try:
        pipeline, dataloader, monitor = setup_memory_efficient_training(dataset)

        print("\nInitial GPU Memory Status:")
        print(f"Total: {torch.cuda.get_device_properties(0).total_memory/1024**3:.1f}GB")
        print(f"Reserved: {torch.cuda.memory_reserved()/1024**3:.1f}GB")
        print(f"Allocated: {torch.cuda.memory_allocated()/1024**3:.1f}GB")

        # Pass num_epochs explicitly so the call works with any monitor_training signature.
        monitor_training(monitor, pipeline, dataloader, num_epochs=1)

    except Exception as e:
        logging.error(f"Training failed: {str(e)}")
        # Only run OOM diagnostics for genuine out-of-memory failures.
        if "out of memory" in str(e):
            handle_oom_error()

In [None]:


def setup_2080ti_optimized_training(
    dataset,
    model_id="CompVis/stable-diffusion-v1-4",
    output_dir="./sd-xray-model"
):
    """
    Optimized setup for RTX 2080 Ti (11GB VRAM).

    Loads an fp16 pipeline with attention slicing, VAE slicing, sequential CPU offload
    and UNet gradient checkpointing. It then builds an AdamW optimizer (8-bit via
    bitsandbytes when `config["use_8bit_adam"]` is set and the package is installed)
    and a batch-size-1 DataLoader.

    Returns:
        (pipeline, train_dataloader, optimizer, accelerator, config). `config` also
        carries `output_dir` so training code can place checkpoints there.

    NOTE(review): the UNet weights are fp16 here; optimizing fp16 master weights directly
    is numerically fragile and may be rejected by accelerate's fp16 grad scaler. Consider
    loading the UNet in fp32 when actually training.
    """
    config = {
        "batch_size": 1,
        "image_size": 512,
        "gradient_accumulation_steps": 4,
        "mixed_precision": "fp16",
        "cache_latents": True,
        "use_8bit_adam": True,
        "output_dir": output_dir,
    }

    accelerator = Accelerator(
        mixed_precision=config["mixed_precision"],
        gradient_accumulation_steps=config["gradient_accumulation_steps"],
    )

    pipeline = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        safety_checker=None,  # Disable safety checker to save memory
        requires_safety_checker=False
    )

    pipeline.enable_attention_slicing(slice_size="max")
    pipeline.enable_vae_slicing()
    # Sequential and model CPU offload are alternative strategies that install conflicting
    # hooks; use only one (sequential offload gives the lowest VRAM footprint).
    pipeline.enable_sequential_cpu_offload()

    pipeline.unet.enable_gradient_checkpointing()

    optimizer = None
    if config["use_8bit_adam"]:
        try:
            import bitsandbytes as bnb
            optimizer = bnb.optim.AdamW8bit(
                pipeline.unet.parameters(),
                lr=1e-5,
                betas=(0.9, 0.999)
            )
        except ImportError:
            print("bitsandbytes not installed. Falling back to regular AdamW")
    if optimizer is None:
        optimizer = torch.optim.AdamW(
            pipeline.unet.parameters(),
            lr=1e-5
        )

    train_dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=config["batch_size"],
        shuffle=True,
        num_workers=1,  # Reduced workers to save memory
        pin_memory=True
    )

    return pipeline, train_dataloader, optimizer, accelerator, config

def train_with_memory_monitoring(
    pipeline,
    train_dataloader,
    optimizer,
    accelerator,
    num_epochs=100,
    output_dir="./sd-xray-model",
    training_step=None,
    checkpoint_every=5,
    log_every=10,
):
    """
    Training loop with memory monitoring for RTX 2080 Ti.

    Args:
        pipeline: object exposing `.unet` (the model being trained).
        train_dataloader: iterable of batches.
        optimizer: optimizer over the UNet parameters.
        accelerator: accelerate.Accelerator driving accumulation, autocast and backward.
        num_epochs: passes over `train_dataloader`.
        output_dir: checkpoints go to `{output_dir}/checkpoint-{epoch}`.
        training_step: callable `(pipeline, batch) -> loss` computing the training loss.
        checkpoint_every: checkpoint every this many epochs (starting at epoch 0).
        log_every: print CUDA memory usage every this many steps (CUDA devices only).

    Raises:
        NotImplementedError: if `training_step` is not provided.
    """
    if training_step is None:
        raise NotImplementedError(
            "train_with_memory_monitoring needs a `training_step(pipeline, batch) -> loss` callable"
        )

    device = accelerator.device
    # CUDA memory queries fail for non-CUDA devices, so only report them on a CUDA device.
    on_cuda = torch.cuda.is_available() and getattr(device, "type", None) == "cuda"

    for epoch in range(num_epochs):
        for step, batch in enumerate(train_dataloader):
            if on_cuda and step % log_every == 0:
                torch.cuda.empty_cache()
                gpu_memory = torch.cuda.memory_allocated(device) / 1024**3
                gpu_memory_reserved = torch.cuda.memory_reserved(device) / 1024**3
                print(f"\nGPU Memory in use: {gpu_memory:.2f}GB")
                print(f"GPU Memory reserved: {gpu_memory_reserved:.2f}GB")
                if gpu_memory > 10:  # 10GB threshold for 11GB card
                    print("WARNING: High memory usage detected!")

            with accelerator.accumulate(pipeline.unet):
                # Use the accelerator's autocast so it follows its mixed_precision setting.
                with accelerator.autocast():
                    loss = training_step(pipeline, batch)

                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(pipeline.unet.parameters(), 1.0)

                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

            # Drop the reference so the batch can be reclaimed before the next one loads.
            del batch

        if epoch % checkpoint_every == 0:
            print("Saving checkpoint...")
            checkpoint_dir = f"{output_dir}/checkpoint-{epoch}"
            # No pipeline.to(...) here: device moves conflict with CPU-offload hooks.
            accelerator.save_state(checkpoint_dir)
            accelerator.wait_for_everyone()
            if accelerator.is_main_process:
                # save_state only covers prepared objects; persist the UNet weights explicitly.
                accelerator.unwrap_model(pipeline.unet).save_pretrained(f"{checkpoint_dir}/unet")
            print("Checkpoint saved!")

def check_system_compatibility(min_gpu_memory_gb=10, min_ram_gb=16):
    """
    Check if system meets minimum requirements.

    Prints the GPU name, GPU memory and system RAM (device 0), plus warnings for
    anything below the thresholds.

    Args:
        min_gpu_memory_gb: minimum GPU memory in GiB.
        min_ram_gb: minimum system RAM in GiB.

    Returns:
        True only if a CUDA device is available and both thresholds are met.
    """
    if not torch.cuda.is_available():
        print("\nSystem Check:")
        print("GPU: none (CUDA not available)")
        print("\nWARNING: Stable Diffusion training requires a CUDA GPU")
        return False

    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
    ram = psutil.virtual_memory().total / 1024**3

    print(f"\nSystem Check:")
    print(f"GPU: {gpu_name}")
    print(f"GPU Memory: {gpu_memory:.1f}GB")
    print(f"RAM: {ram:.1f}GB")

    if gpu_memory < min_gpu_memory_gb:
        print("\nWARNING: GPU memory might be insufficient for Stable Diffusion training")
    if ram < min_ram_gb:
        print("\nWARNING: More RAM recommended for optimal performance")

    return gpu_memory >= min_gpu_memory_gb and ram >= min_ram_gb


In [None]:

# Example usage
if __name__ == "__main__":
    # Placeholder: assign a real dataset (e.g. a MemoryOptimizedDataset) before running.
    your_dataset = None

    if your_dataset is None:
        print("Set `your_dataset` to a dataset before running the RTX 2080 Ti example.")
    elif check_system_compatibility():
        pipeline, dataloader, optimizer, accelerator, config = setup_2080ti_optimized_training(
            dataset=your_dataset
        )

        print("\nConfiguration for RTX 2080 Ti:")
        for key, value in config.items():
            print(f"{key}: {value}")

        train_with_memory_monitoring(pipeline, dataloader, optimizer, accelerator)

In [0]:
#| hide
# Export this notebook's `#| export` cells with nbdev.
import nbdev
nbdev.nbdev_export('20_synthetic_data.diffusion_model.ipynb')