# 🚀 LAION-2B Progressive Training
**Launched from Cursor**

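This notebook trains a LoRA adapter for Stable Diffusion v1.5 on LAION-2B using a progressive approach: four phases of increasing size that run back to back (a non-final phase may end early once the recent training loss drops below the quality threshold):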
- Phase 1: 1,000 samples (5 minutes)
- Phase 2: 5,000 samples (15 minutes)
- Phase 3: 10,000 samples (30 minutes)
- Phase 4: 50,000 samples (2-3 hours)

**Total time: roughly 3–4 hours on Colab Free** (the phase estimates above add up to about 2 h 50 min – 3 h 50 min)

In [None]:
# Report which accelerator this runtime provides before starting a long training run.
import torch

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name()}")
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"Memory: {total_gb:.1f} GB")
else:
    print("GPU: CPU")
    print("No GPU")

In [None]:
# Install required packages
!pip install -q diffusers transformers accelerate peft datasets xformers webdataset torchvision pillow requests
print("✅ Packages installed")

In [None]:
# Write the training script to disk.
# The script is stored as a raw string with real newlines. Escape sequences used
# inside the script itself (e.g. "\n" in log messages) are then written verbatim.
# In a normal string they would be expanded into literal line breaks, which
# leaves unterminated string literals in the generated file.
from pathlib import Path

SCRIPT_PATH = Path("train_laion2b_colab_progressive.py")

script_content = r'''#!/usr/bin/env python3
"""
Progressive LAION-2B Training Script for Google Colab
Smart training approach: Start small, test quality, scale up
Optimized for Colab Free with practical sample sizes
"""

import os
import torch
import logging
from pathlib import Path
import time
import json
from datetime import datetime

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Base model. The original "runwayml/stable-diffusion-v1-5" repo is no longer
# served by the Hugging Face Hub; this repo mirrors the same SD v1.5 weights.
MODEL_ID = "stable-diffusion-v1-5/stable-diffusion-v1-5"

# NOTE(review): access to LAION-2B on the Hub has been restricted in the past;
# confirm this dataset id is reachable before starting a long run.
DATASET_ID = "laion/laion2B-en"

# Abort a phase after this many back-to-back failed batches (exceptions, or
# batches where no image could be downloaded) instead of retrying forever.
MAX_CONSECUTIVE_FAILURES = 20


class ProgressiveLAION2BTrainer:
    def __init__(self):
        """Initialize progressive trainer for Colab Free"""
        self.setup_device()

        # Progressive training phases
        self.training_phases = [
            {"name": "Phase 1 - Quick Test", "samples": 1000, "time": "5 minutes", "purpose": "Basic testing"},
            {"name": "Phase 2 - Quality Training", "samples": 5000, "time": "15 minutes", "purpose": "Good quality"},
            {"name": "Phase 3 - Enhanced Training", "samples": 10000, "time": "30 minutes", "purpose": "Very good quality"},
            {"name": "Phase 4 - Professional", "samples": 50000, "time": "2-3 hours", "purpose": "Professional quality"}
        ]

        # Current phase settings
        self.current_phase = 0
        self.batch_size = 8
        self.image_size = 512
        self.max_text_length = 77

        # Training state
        self.global_step = 0
        self.total_samples_processed = 0
        self.start_time = time.time()
        self.phase_start_time = time.time()

        # Quality metrics
        self.loss_history = []
        self.quality_threshold = 0.02  # Good loss threshold

        # Create output directories
        Path("progressive_training_outputs").mkdir(exist_ok=True)
        Path("progressive_training_outputs/checkpoints").mkdir(exist_ok=True)
        Path("progressive_training_outputs/phases").mkdir(exist_ok=True)

        logger.info(f"🚀 Progressive LAION-2B Trainer initialized on {self.device}")
        logger.info(f"📊 Training Phases: {len(self.training_phases)} phases planned")

    def setup_device(self):
        """Setup device for Colab GPU"""
        if torch.cuda.is_available():
            self.device = "cuda"
            # Enable memory efficient attention
            torch.backends.cuda.enable_flash_sdp(True)
            torch.backends.cuda.enable_mem_efficient_sdp(True)
            logger.info(f"✅ CUDA GPU available: {torch.cuda.get_device_name()}")
            logger.info(f"✅ GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
        else:
            self.device = "cpu"
            logger.warning("⚠️ No CUDA GPU found, using CPU (very slow!)")

    def load_models(self):
        """Load models optimized for Colab GPUs"""
        logger.info("📦 Loading models for progressive training...")

        from diffusers import StableDiffusionPipeline, DDPMScheduler
        from peft import LoraConfig, get_peft_model
        from torch.optim import AdamW
        from torch.cuda.amp import GradScaler

        # Load Stable Diffusion pipeline with optimizations
        pipe = StableDiffusionPipeline.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.float16,
            safety_checker=None,
            requires_safety_checker=False,
            use_safetensors=True
        )

        # Extract components
        self.unet = pipe.unet
        self.vae = pipe.vae
        self.tokenizer = pipe.tokenizer
        self.text_encoder = pipe.text_encoder

        # Load noise scheduler
        self.noise_scheduler = DDPMScheduler.from_pretrained(
            MODEL_ID, subfolder="scheduler"
        )

        # Apply LoRA with progressive settings
        lora_config = LoraConfig(
            r=16,  # Start with moderate rank
            lora_alpha=32,
            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
            lora_dropout=0.1,
            bias="none",
        )

        # Apply LoRA to UNet
        self.unet = get_peft_model(self.unet, lora_config)
        self.unet.print_trainable_parameters()

        # Move models to device
        self.unet.to(self.device)
        self.text_encoder.to(self.device)
        self.vae.to(self.device)

        # Enable gradient checkpointing for memory efficiency
        self.unet.enable_gradient_checkpointing()

        # Setup optimizer
        self.optimizer = AdamW(self.unet.parameters(), lr=2e-4, weight_decay=0.01)

        # Setup mixed precision training
        self.scaler = GradScaler()

        logger.info("✅ Models loaded and optimized for progressive training")

    def load_dataset(self):
        """Load LAION-2B dataset with streaming"""
        logger.info("📚 Loading LAION-2B dataset...")

        from datasets import load_dataset

        self.dataset = load_dataset(
            DATASET_ID,
            split="train",
            streaming=True
        ).shuffle(buffer_size=10000)
        # One iterator shared by all phases, so each phase continues through the
        # stream instead of re-reading the same first samples.
        self.dataset_iter = iter(self.dataset)

        logger.info("✅ LAION-2B dataset loaded")

    def next_sample(self):
        """Return the next raw sample, restarting the stream when it is exhausted."""
        try:
            return next(self.dataset_iter)
        except StopIteration:
            self.dataset_iter = iter(self.dataset)
            return next(self.dataset_iter)

    def is_valid_sample(self, sample):
        """Return True if the sample has a non-empty string caption and an http(s) URL."""
        try:
            sample_lower = {k.lower(): v for k, v in sample.items()}
            text = sample_lower.get("text") or sample_lower.get("caption")
            url = sample_lower.get("url") or sample_lower.get("image_url")

            # The caption must be a string: it is sliced and tokenized later.
            return bool(
                isinstance(text, str) and text
                and isinstance(url, str) and url.startswith(('http://', 'https://'))
            )
        except Exception:
            return False

    def normalize_sample(self, sample):
        """Normalize sample keys"""
        sample_lower = {k.lower(): v for k, v in sample.items()}
        return {
            "text": sample_lower.get("text") or sample_lower.get("caption"),
            "url": sample_lower.get("url") or sample_lower.get("image_url")
        }

    def preprocess_image(self, image_url):
        """Download one image and return a [C, H, W] tensor in [-1, 1], or None on failure."""
        try:
            import requests
            from PIL import Image
            from io import BytesIO
            import numpy as np

            response = requests.get(image_url, timeout=10)
            # HTTP error pages are not images; treat them as failed downloads.
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert("RGB")
            image = image.resize((self.image_size, self.image_size))

            # Convert to tensor [C, H, W]; the SD VAE expects pixel values in [-1, 1].
            image_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0
            return image_tensor * 2.0 - 1.0
        except Exception:
            # Dead links, timeouts and undecodable images are skipped.
            return None

    def preprocess_images_parallel(self, image_urls):
        """Download and preprocess images in parallel.

        Returns a list aligned with image_urls: element i is the tensor for
        image_urls[i], or None if that download/decode failed.
        """
        import concurrent.futures

        with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
            # executor.map preserves input order, so images stay paired with their captions.
            return list(executor.map(self.preprocess_image, image_urls))

    def training_step(self, images, captions):
        """Single training step with mixed precision.

        images: batch tensor [B, C, H, W] with pixel values in [-1, 1].
        captions: list of B caption strings, one per image.
        """
        import torch.nn.functional as F
        from torch.cuda.amp import autocast

        # Move images to device
        images = images.to(self.device, dtype=torch.float16)

        # Tokenize captions
        text_inputs = self.tokenizer(
            captions,
            padding="max_length",
            truncation=True,
            max_length=self.max_text_length,
            return_tensors="pt",
        )
        input_ids = text_inputs.input_ids.to(self.device)

        # Get text embeddings
        with torch.no_grad():
            encoder_hidden_states = self.text_encoder(input_ids)[0]

        # Encode images to latent space
        with torch.no_grad():
            latents = self.vae.encode(images).latent_dist.sample()
            latents = latents * 0.18215

        # Add noise to latents
        noise = torch.randn_like(latents)
        timesteps = torch.randint(
            0, self.noise_scheduler.config.num_train_timesteps,
            (latents.shape[0],), device=self.device
        ).long()
        noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)

        # UNet forward pass with mixed precision
        with autocast():
            noise_pred = self.unet(
                noisy_latents,
                timesteps,
                encoder_hidden_states=encoder_hidden_states
            ).sample
            loss = F.mse_loss(noise_pred, noise, reduction="mean")

        return loss

    def evaluate_quality(self):
        """Evaluate model quality and decide next phase"""
        if len(self.loss_history) < 10:
            return "continue"  # Need more data

        recent_losses = self.loss_history[-10:]  # Last 10 losses
        avg_loss = sum(recent_losses) / len(recent_losses)

        logger.info(f"📊 Quality Assessment:")
        logger.info(f"   Recent average loss: {avg_loss:.4f}")
        logger.info(f"   Quality threshold: {self.quality_threshold:.4f}")

        if avg_loss < self.quality_threshold:
            logger.info("✅ Quality threshold met! Ready for next phase.")
            return "proceed"
        else:
            logger.info("⚠️ Quality threshold not met. Continue current phase.")
            return "continue"

    def save_phase_checkpoint(self):
        """Save checkpoint for current phase"""
        phase = self.training_phases[self.current_phase]
        phase_dir = f"progressive_training_outputs/phases/{phase['name'].replace(' ', '_').replace('-', '_')}"
        Path(phase_dir).mkdir(exist_ok=True)

        # Save LoRA weights
        self.unet.save_pretrained(phase_dir)

        # Save phase info
        phase_info = {
            "phase_name": phase["name"],
            "samples_processed": self.total_samples_processed,
            "global_step": self.global_step,
            "phase_time": time.time() - self.phase_start_time,
            "average_loss": sum(self.loss_history[-10:]) / len(self.loss_history[-10:]) if self.loss_history else 0,
            "timestamp": datetime.now().isoformat()
        }

        with open(f"{phase_dir}/phase_info.json", 'w') as f:
            json.dump(phase_info, f, indent=2)

        logger.info(f"💾 Phase checkpoint saved: {phase['name']}")

    def train_phase(self):
        """Train for current phase.

        Returns True if the phase reached its sample target or met the quality
        threshold early, False if it was aborted after repeated failures.
        """
        phase = self.training_phases[self.current_phase]
        target_samples = phase["samples"]

        logger.info(f"🚀 Starting {phase['name']}")
        logger.info(f"📊 Target: {target_samples:,} samples")
        logger.info(f"⏱️ Expected time: {phase['time']}")
        logger.info(f"🎯 Purpose: {phase['purpose']}")

        # Reset phase timer
        self.phase_start_time = time.time()

        # Load models and dataset if first phase
        if self.current_phase == 0:
            self.load_models()
            self.load_dataset()

        phase_samples_processed = 0
        quality_reached = False
        consecutive_failures = 0

        while phase_samples_processed < target_samples:
            try:
                # Collect valid samples for batch
                batch_urls = []
                batch_texts = []

                # Scan up to 3x the batch size to find enough valid samples
                for _ in range(self.batch_size * 3):
                    sample = self.next_sample()

                    # Validate sample
                    if self.is_valid_sample(sample):
                        normalized = self.normalize_sample(sample)
                        batch_urls.append(normalized["url"])
                        batch_texts.append(normalized["text"][:self.max_text_length])

                    # Stop if we have enough samples
                    if len(batch_urls) >= self.batch_size:
                        break

                # Download images in parallel; results are aligned with batch_urls
                images = self.preprocess_images_parallel(batch_urls)
                # Drop failed downloads together with their captions so pairs stay matched
                pairs = [(img, txt) for img, txt in zip(images, batch_texts) if img is not None]

                # Skip if no valid images in batch
                if not pairs:
                    consecutive_failures += 1
                    if consecutive_failures >= MAX_CONSECUTIVE_FAILURES:
                        logger.error(f"❌ {consecutive_failures} consecutive failed batches; aborting {phase['name']}")
                        break
                    continue

                batch_images = [img for img, _ in pairs]
                batch_texts = [txt for _, txt in pairs]

                # Stack images into batch tensor
                batch_images_tensor = torch.stack(batch_images)

                # Training step
                loss = self.training_step(batch_images_tensor, batch_texts)

                # Backward pass with mixed precision
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.optimizer.zero_grad()
                consecutive_failures = 0

                # Update counters
                samples_in_batch = len(batch_images)
                self.total_samples_processed += samples_in_batch
                phase_samples_processed += samples_in_batch
                self.global_step += 1

                # Track loss history
                self.loss_history.append(loss.item())
                if len(self.loss_history) > 100:  # Keep last 100 losses
                    self.loss_history.pop(0)

                # Logging
                if self.global_step % 5 == 0:
                    elapsed_time = time.time() - self.phase_start_time
                    samples_per_sec = phase_samples_processed / elapsed_time if elapsed_time > 0 else 0
                    recent_loss = sum(self.loss_history[-5:]) / len(self.loss_history[-5:])

                    logger.info(
                        f"📊 Phase {self.current_phase + 1} | "
                        f"Step {self.global_step} | "
                        f"Loss: {recent_loss:.4f} | "
                        f"Phase Samples: {phase_samples_processed:,}/{target_samples:,} | "
                        f"Speed: {samples_per_sec:.1f} samples/sec"
                    )

                # Save checkpoint every 50 steps
                if self.global_step % 50 == 0:
                    self.save_phase_checkpoint()

                # Quality check every 100 steps
                if self.global_step % 100 == 0:
                    quality_decision = self.evaluate_quality()
                    if quality_decision == "proceed" and self.current_phase < len(self.training_phases) - 1:
                        logger.info("🎯 Quality target reached! Moving to next phase...")
                        # Reaching the threshold counts as completing this phase
                        quality_reached = True
                        break

            except Exception as e:
                logger.error(f"Training error: {e}")
                import traceback
                traceback.print_exc()
                # Discard gradients from the failed step so they are not applied on the next one
                self.optimizer.zero_grad()
                consecutive_failures += 1
                if consecutive_failures >= MAX_CONSECUTIVE_FAILURES:
                    logger.error(f"❌ {consecutive_failures} consecutive failed batches; aborting {phase['name']}")
                    break
                continue

        # Save final phase checkpoint
        self.save_phase_checkpoint()

        # Phase completion summary
        phase_time = time.time() - self.phase_start_time
        logger.info(f"✅ {phase['name']} completed!")
        logger.info(f"⏱️ Phase time: {phase_time/60:.1f} minutes")
        logger.info(f"📊 Total samples processed: {self.total_samples_processed:,}")

        return quality_reached or phase_samples_processed >= target_samples

    def train_progressive(self):
        """Main progressive training loop"""
        logger.info("🚀 Starting Progressive LAION-2B Training...")
        logger.info("📋 Training Strategy:")
        for i, phase in enumerate(self.training_phases):
            logger.info(f"   {i+1}. {phase['name']}: {phase['samples']:,} samples ({phase['time']})")

        for phase_idx in range(len(self.training_phases)):
            self.current_phase = phase_idx
            phase = self.training_phases[phase_idx]

            logger.info(f"\n{'='*60}")
            logger.info(f"🎯 PHASE {phase_idx + 1}/{len(self.training_phases)}: {phase['name']}")
            logger.info(f"{'='*60}")

            # Train current phase
            phase_completed = self.train_phase()

            if not phase_completed:
                logger.warning(f"⚠️ Phase {phase_idx + 1} did not complete. Stopping training.")
                break

            # Announce the next phase (training continues automatically)
            if phase_idx < len(self.training_phases) - 1:
                logger.info(f"\n💡 Phase {phase_idx + 1} completed successfully!")
                logger.info(f"📊 Next phase: {self.training_phases[phase_idx + 1]['name']}")
                logger.info(f"⏱️ Expected time: {self.training_phases[phase_idx + 1]['time']}")

                # In Colab, we'll auto-continue, but log the decision
                logger.info("🔄 Auto-continuing to next phase...")
                logger.info("💡 You can stop training anytime by interrupting the kernel")

        logger.info("\n🎉 Progressive training completed!")
        self.save_final_model()

    def save_final_model(self):
        """Save final trained model"""
        final_dir = "progressive_training_outputs/final_model"
        Path(final_dir).mkdir(exist_ok=True)

        # Save final LoRA weights
        self.unet.save_pretrained(final_dir)

        # Save comprehensive training info
        training_info = {
            "total_phases_completed": self.current_phase + 1,
            "total_steps": self.global_step,
            "total_samples": self.total_samples_processed,
            "total_training_time": time.time() - self.start_time,
            "final_average_loss": sum(self.loss_history[-10:]) / len(self.loss_history[-10:]) if self.loss_history else 0,
            "device": self.device,
            "batch_size": self.batch_size,
            "image_size": self.image_size,
            "phases": self.training_phases[:self.current_phase + 1],
            "timestamp": datetime.now().isoformat()
        }

        with open(f"{final_dir}/training_info.json", 'w') as f:
            json.dump(training_info, f, indent=2)

        logger.info("✅ Final model saved to progressive_training_outputs/final_model")
        logger.info(f"📊 Training Summary:")
        logger.info(f"   Total phases: {self.current_phase + 1}")
        logger.info(f"   Total samples: {self.total_samples_processed:,}")
        logger.info(f"   Total time: {(time.time() - self.start_time)/60:.1f} minutes")
        logger.info(f"   Final loss: {training_info['final_average_loss']:.4f}")


def main():
    """Main function"""
    trainer = ProgressiveLAION2BTrainer()
    trainer.train_progressive()


if __name__ == "__main__":
    main()
'''

# Explicit UTF-8: the script contains emoji in its log messages.
SCRIPT_PATH.write_text(script_content, encoding="utf-8")

print(f"✅ Training script written to {SCRIPT_PATH}")

In [None]:
# Run progressive training
!python train_laion2b_colab_progressive.py

In [None]:
# Package everything under progressive_training_outputs/ into one zip and download it
from google.colab import files
import zipfile
import os

OUTPUT_DIR = "progressive_training_outputs"
ARCHIVE_NAME = "progressive_trained_model.zip"

if not os.path.isdir(OUTPUT_DIR):
    raise FileNotFoundError(f"{OUTPUT_DIR}/ not found - run the training cell first")

# Mode "w" rebuilds the archive from scratch, so files from an earlier run don't
# linger (`zip -r` updates an existing archive in place). Member names keep the
# progressive_training_outputs/ prefix, as `zip -r` did.
with zipfile.ZipFile(ARCHIVE_NAME, "w", compression=zipfile.ZIP_DEFLATED) as archive:
    for dirpath, dirnames, filenames in os.walk(OUTPUT_DIR):
        dirnames.sort()  # deterministic traversal order
        for filename in sorted(filenames):
            archive.write(os.path.join(dirpath, filename))

# Download to local machine
files.download(ARCHIVE_NAME)

print("✅ Model downloaded successfully!")
print("📁 Check your Downloads folder for 'progressive_trained_model.zip'")