In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from diffusers import DDPMScheduler, UNet2DModel, AutoencoderKL
from pathlib import Path
import logging
from tqdm import tqdm
import torchvision.utils as vutils
from PIL import Image
import torchvision.transforms as T
import numpy as np


In [None]:
!pip install -r /kaggle/input/codesa/requirements.txt

Preprocess Data

In [None]:
import os
import numpy as np
import cv2
from PIL import Image
from tqdm import tqdm
from pathlib import Path
import logging
import shutil

class DataPreprocessor:
    """Build train/val/test folders of clean and synthetically degraded images.

    The input directory is expected to contain one sub-folder per class, each
    holding ``.jpg``/``.png`` files. For every selected image four variants are
    written under ``<output_dir>/<split>/<variant>/``:

    * ``clean``   - the image resized to ``IMAGE_SIZE`` x ``IMAGE_SIZE``
    * ``noisy``   - clean image plus Gaussian noise (sigma 0.1 on a 0..1 scale)
    * ``low_res`` - clean image down-sampled to ``LOW_RES_SIZE`` and back up
    * ``masked``  - clean image with 1-3 random rectangles zeroed out

    All variants of one image share the same file name
    (``<class>_<original name>``), so they can be matched by name later.
    """

    # Sub-folders created inside every split directory.
    SUBDIRS = ('clean', 'noisy', 'low_res', 'masked')
    # Side length (pixels) of every saved image.
    IMAGE_SIZE = 256
    # Intermediate side length used to create the low-resolution variant.
    LOW_RES_SIZE = 64

    def __init__(self, input_dir, output_dir):
        """Store the input/output locations.

        Args:
            input_dir: Directory containing one sub-folder per class.
            output_dir: Directory under which ``train``/``val``/``test`` are created.

        Raises:
            FileNotFoundError: If ``input_dir`` does not exist.
        """
        self.input_dir = Path(input_dir)
        self.output_base = Path(output_dir)

        # Verify input directory exists
        if not self.input_dir.exists():
            raise FileNotFoundError(f"Input directory not found: {self.input_dir}")

        logging.info(f"Input directory: {self.input_dir}")
        logging.info(f"Output directory: {self.output_base}")

        # Create train/val/test splits
        self.splits = {
            'train': self.output_base / 'train',
            'val': self.output_base / 'val',
            'test': self.output_base / 'test'
        }

    def get_all_images(self):
        """Return every ``.jpg``/``.png`` path found in the class folders.

        Folders and files are visited in sorted order so the returned list does
        not depend on the filesystem's directory ordering.
        """
        all_images = []
        # Iterate through all class folders
        for class_dir in sorted(self.input_dir.iterdir()):
            if class_dir.is_dir():  # Make sure it's a directory
                logging.info(f"Found class directory: {class_dir.name}")
                # Get all images in this class directory
                class_images = sorted(
                    list(class_dir.glob('*.jpg')) + list(class_dir.glob('*.png'))
                )
                all_images.extend(class_images)
                logging.info(f"Found {len(class_images)} images in {class_dir.name}")

        return all_images

    def setup_directories(self):
        """Create every ``<split>/<variant>`` output directory (idempotent)."""
        created_dirs = []
        for split in self.splits.values():
            for subdir in self.SUBDIRS:
                dir_path = split / subdir
                dir_path.mkdir(parents=True, exist_ok=True)
                created_dirs.append(str(dir_path))

        logging.info("Created directories:")
        for dir_path in created_dirs:
            logging.info(f"  - {dir_path}")

    def split_data(self, train_ratio=0.8, val_ratio=0.1, use_fraction=0.5):
        """Select a random subset of the images and split it into train/val/test.

        Args:
            train_ratio: Fraction of the selected images assigned to ``train``.
            val_ratio: Fraction of the selected images assigned to ``val``;
                whatever remains goes to ``test``.
            use_fraction: Fraction of all available images to use.

        Returns:
            Dict mapping ``'train'``/``'val'``/``'test'`` to lists of image paths.

        Raises:
            RuntimeError: If no images are found under the input directory.
            ValueError: If ``use_fraction`` is not positive.
        """
        all_images = self.get_all_images()

        if not all_images:
            raise RuntimeError(f"No images found in {self.input_dir}")
        if use_fraction <= 0:
            raise ValueError(f"use_fraction must be positive, got {use_fraction}")

        total_available = len(all_images)

        # Shuffle BEFORE taking the subset. Images are listed class by class,
        # so truncating first would keep only the first few classes.
        np.random.seed(42)  # For reproducibility
        np.random.shuffle(all_images)

        # Reduce dataset size (at least one image, never more than available)
        num_samples = min(total_available, max(1, int(total_available * use_fraction)))
        all_images = all_images[:num_samples]

        logging.info(f"Using {num_samples} images out of {total_available} total images")

        n_images = len(all_images)
        n_train = int(n_images * train_ratio)
        n_val = int(n_images * val_ratio)

        splits = {
            'train': all_images[:n_train],
            'val': all_images[n_train:n_train + n_val],
            'test': all_images[n_train + n_val:]
        }

        for split_name, split_images in splits.items():
            logging.info(f"{split_name} split: {len(split_images)} images")

        return splits

    def process_image(self, img_path, split_type):
        """Write the clean/noisy/low_res/masked variants of one image.

        Args:
            img_path: Path of the source image; its parent folder name is used
                as the class prefix of the output file name.
            split_type: Key of ``self.splits`` selecting the output split.

        Raises:
            Exception: Any error while reading or writing is logged and re-raised.
        """
        try:
            size = self.IMAGE_SIZE

            # Create output filename (preserve class information in filename)
            class_name = img_path.parent.name
            new_filename = f"{class_name}_{img_path.name}"
            split_dir = self.splits[split_type]

            # Read and resize image. Converting to RGB guarantees a
            # (size, size, 3) array for grey/palette/RGBA inputs and keeps the
            # result JPEG-compatible; the context manager closes the file.
            with Image.open(img_path) as source:
                original = source.convert('RGB').resize(
                    (size, size), Image.Resampling.LANCZOS
                )
            img_array = np.array(original).astype(np.float32) / 255.0

            # Save clean image
            original.save(split_dir / 'clean' / new_filename)

            # Create noisy version
            noise = np.random.normal(0, 0.1, img_array.shape)
            noisy_img = np.clip(img_array + noise, 0, 1)
            noisy_img = (noisy_img * 255).astype(np.uint8)
            Image.fromarray(noisy_img).save(split_dir / 'noisy' / new_filename)

            # Create low-resolution version (down-sample, then up-sample back)
            low_res = cv2.resize(img_array, (self.LOW_RES_SIZE, self.LOW_RES_SIZE))
            low_res = cv2.resize(low_res, (size, size))
            low_res = (low_res * 255).astype(np.uint8)
            Image.fromarray(low_res).save(split_dir / 'low_res' / new_filename)

            # Create masked version: zero out 1-3 random axis-aligned rectangles
            masked = img_array.copy()
            num_masks = np.random.randint(1, 4)
            for _ in range(num_masks):
                x1, x2 = sorted(np.random.randint(0, size, 2))
                y1, y2 = sorted(np.random.randint(0, size, 2))
                masked[y1:y2, x1:x2] = 0
            masked = (masked * 255).astype(np.uint8)
            Image.fromarray(masked).save(split_dir / 'masked' / new_filename)

        except Exception as e:
            logging.error(f"Error processing {img_path}: {e}")
            raise

    def preprocess(self):
        """Run the whole pipeline: create folders, split, process, report counts.

        Raises:
            Exception: Any failure is logged and re-raised.
        """
        try:
            logging.info("Setting up directories...")
            self.setup_directories()

            logging.info("Splitting data...")
            splits = self.split_data()

            total_processed = 0
            # Process each split
            for split_type, images in splits.items():
                logging.info(f"Processing {split_type} split...")
                for img_path in tqdm(images, desc=f"Processing {split_type}"):
                    self.process_image(img_path, split_type)
                    total_processed += 1

            logging.info(f"Preprocessing complete. Processed {total_processed} images.")

            # Verify final counts
            for split_type in self.splits:
                for subdir in self.SUBDIRS:
                    dir_path = self.splits[split_type] / subdir
                    file_count = len(list(dir_path.glob('*')))
                    logging.info(f"{split_type}/{subdir}: {file_count} files")

        except Exception as e:
            logging.error(f"Preprocessing failed: {e}")
            raise

if __name__ == "__main__":
    # Timestamped INFO-level logging for the whole preprocessing run.
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(message)s',
        level=logging.INFO,
    )

    try:
        # Kaggle locations: read-only dataset mount in, working directory out.
        input_dir = '/kaggle/input/eurosat/EuroSAT_RGB'
        output_dir = '/kaggle/working/data/processed'

        logging.info("Starting preprocessing...")
        preprocessor = DataPreprocessor(input_dir, output_dir)
        preprocessor.preprocess()

    except Exception as e:
        # Log the top-level failure, then let the cell fail visibly.
        logging.error(f"Failed to complete preprocessing: {e}")
        raise

In [None]:
pwd

In [None]:
# -p makes this cell safe to re-run and creates the parent `data/` folder if needed.
# NOTE(review): every other cell reads and writes `data/processed`; confirm that
# `data/preprocessed` is really wanted and not a typo.
!mkdir -p data/preprocessed

In [None]:
ls

In [5]:
class RestorationDataset(Dataset):
    """Dataset of (degraded, clean) image pairs for image restoration.

    ``root_dir`` must contain a ``clean`` folder and a folder named after
    ``task`` (e.g. ``noisy``). A degraded image is paired with the clean image
    that has the same file name.

    Each item is a dict with keys ``'degraded'``, ``'clean'`` (both transformed
    images) and ``'path'`` (clean image path as a string).
    """

    def __init__(self, root_dir, transform=None, task="noisy", image_size=256):
        """Index the image pairs found under ``root_dir``.

        Args:
            root_dir: Split directory holding ``clean`` and ``<task>`` folders.
            transform: Callable applied to both images of a pair; defaults to
                the pipeline from ``_get_default_transforms``.
            task: Name of the degraded-image sub-folder.
            image_size: Target size used by the default transform.

        Raises:
            FileNotFoundError: If ``root_dir`` or a required sub-folder is missing.
            ValueError: If a folder has no images or no file names match.
        """
        self.root_dir = Path(root_dir)
        self.image_size = image_size
        self.transform = transform or self._get_default_transforms()
        self.task = task

        # Verify directory exists
        if not self.root_dir.exists():
            raise FileNotFoundError(f"Directory not found: {self.root_dir}")

        logging.info(f"Initializing dataset from {self.root_dir}")
        self.image_pairs = self._get_image_pairs()
        logging.info(f"Found {len(self.image_pairs)} image pairs")

    def _get_default_transforms(self):
        """Resize, random-crop, random-flip, then scale pixels to [-1, 1]."""
        return T.Compose([
            T.Resize(self.image_size),
            T.RandomCrop(self.image_size),
            T.RandomHorizontalFlip(p=0.5),
            T.ToTensor(),
            T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

    def _get_image_pairs(self):
        """Return ``(degraded_path, clean_path)`` tuples matched by file name.

        Matching by name (not by position in two sorted lists) guarantees a
        degraded image is never paired with the wrong clean target when a file
        is missing from either folder. Unmatched files are skipped with a
        warning.
        """
        clean_dir = self.root_dir / 'clean'
        degraded_dir = self.root_dir / self.task

        # Check if directories exist
        if not clean_dir.exists():
            raise FileNotFoundError(f"Clean images directory not found: {clean_dir}")
        if not degraded_dir.exists():
            raise FileNotFoundError(f"Degraded images directory not found: {degraded_dir}")

        # Get all image files
        clean_images = list(clean_dir.glob('*.jpg')) + list(clean_dir.glob('*.png'))
        degraded_images = list(degraded_dir.glob('*.jpg')) + list(degraded_dir.glob('*.png'))

        logging.info(f"Found {len(clean_images)} clean images and {len(degraded_images)} degraded images")

        if not clean_images:
            raise ValueError(f"No images found in clean directory: {clean_dir}")
        if not degraded_images:
            raise ValueError(f"No images found in degraded directory: {degraded_dir}")

        # Match pairs by file name
        clean_by_name = {path.name: path for path in clean_images}
        pairs = [
            (degraded_path, clean_by_name[degraded_path.name])
            for degraded_path in sorted(degraded_images)
            if degraded_path.name in clean_by_name
        ]

        unmatched_degraded = len(degraded_images) - len(pairs)
        unmatched_clean = len(clean_images) - len(pairs)
        if unmatched_degraded or unmatched_clean:
            logging.warning(
                f"Skipping {unmatched_degraded} degraded and {unmatched_clean} clean "
                f"images without a same-named counterpart"
            )

        if not pairs:
            raise ValueError("No matching image pairs found")

        return pairs

    def __len__(self):
        """Return the total number of image pairs"""
        return len(self.image_pairs)

    def __getitem__(self, idx):
        """Load, transform and return the pair at position ``idx``."""
        degraded_path, clean_path = self.image_pairs[idx]

        # Load images
        degraded_img = Image.open(degraded_path).convert('RGB')
        clean_img = Image.open(clean_path).convert('RGB')

        # Apply transforms
        if self.transform:
            # Replay the same random draws for both images so they receive the
            # same augmentation. Snapshotting and restoring the CPU RNG state
            # avoids re-seeding the global generators on every sample.
            # This only synchronises transforms that draw from torch's global
            # CPU RNG.
            rng_state = torch.get_rng_state()
            degraded_img = self.transform(degraded_img)

            torch.set_rng_state(rng_state)
            clean_img = self.transform(clean_img)

        return {
            'degraded': degraded_img,
            'clean': clean_img,
            'path': str(clean_path)
        }

In [6]:
# 2. The Latent Diffusion Model
class LatentDiffusionModel(torch.nn.Module):
    """Frozen pretrained VAE plus a trainable UNet that operates on VAE latents.

    ``config.model.image_size`` and ``config.model.num_res_blocks`` are the only
    configuration values read here.
    """

    def __init__(self, config):
        super().__init__()

        # Pretrained VAE; its weights are excluded from gradient computation.
        self.vae = AutoencoderKL.from_pretrained(
            "stabilityai/sd-vae-ft-mse",
            torch_dtype=torch.float32
        )
        self.vae.requires_grad_(False)

        # UNet working in latent space: the latent grid is 1/8 of the image
        # side (VAE downsamples by 8) and has 4 channels in and out.
        num_levels = 4
        latent_channels = 4
        self.unet = UNet2DModel(
            sample_size=config.model.image_size // 8,
            in_channels=latent_channels,
            out_channels=latent_channels,
            layers_per_block=config.model.num_res_blocks,
            block_out_channels=(128, 128, 256, 256),
            down_block_types=["DownBlock2D"] * num_levels,
            up_block_types=["UpBlock2D"] * num_levels,
            norm_num_groups=8,
            norm_eps=1e-6
        )

    def encode(self, x):
        """Encode images with the VAE and draw one sample from the latent distribution."""
        posterior = self.vae.encode(x).latent_dist
        return posterior.sample()

    def decode(self, z):
        """Decode latents back to image space with the VAE."""
        decoded = self.vae.decode(z)
        return decoded.sample

    def forward(self, x, t):
        """Run the UNet on latents ``x`` at timestep(s) ``t`` and return its ``sample`` output."""
        unet_output = self.unet(x, t)
        return unet_output.sample

In [7]:
import torch
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler
import hydra
from omegaconf import DictConfig
from pathlib import Path
import logging
from tqdm import tqdm
import torchvision.utils as vutils

class LatentDiffusionTrainer:
    """Train the UNet of a ``LatentDiffusionModel`` to predict added noise.

    Clean images are encoded by the frozen VAE, scaled by the VAE latent
    scaling factor, noised at a random timestep with the DDPM scheduler, and
    the UNet is optimised with an MSE loss against the true noise.

    Checkpoints and sample images are written under ``results/``.
    """

    # Used only when the VAE config does not expose ``scaling_factor``.
    DEFAULT_LATENT_SCALE = 0.18215

    def __init__(self, config):
        """Build model, noise scheduler, optimizer and LR scheduler from ``config``.

        Reads ``config.diffusion.*`` and ``config.training.learning_rate`` here;
        ``train`` additionally reads ``config.data.*``, ``config.training.*``
        and ``config.logging.sample_interval``.
        """
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logging.info(f"Using device: {self.device}")

        # Initialize model and move to device
        self.model = LatentDiffusionModel(config).to(self.device)

        # Initialize noise scheduler
        self.noise_scheduler = DDPMScheduler(
            num_train_timesteps=config.diffusion.num_diffusion_timesteps,
            beta_start=config.diffusion.beta_start,
            beta_end=config.diffusion.beta_end,
            beta_schedule=config.diffusion.beta_schedule
        )

        # Initialize optimizer (only the UNet is trained; the VAE is frozen)
        self.optimizer = torch.optim.AdamW(
            self.model.unet.parameters(),
            lr=config.training.learning_rate
        )

        # Initialize learning rate scheduler. The deprecated `verbose` flag is
        # not used; learning-rate changes are logged explicitly in `train`.
        self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', factor=0.5, patience=10
        )

        # Setup logging directory
        self.save_dir = Path("results")
        self.save_dir.mkdir(exist_ok=True)

    def _latent_scale(self):
        """Return the VAE latent scaling factor (config value, else the default)."""
        vae_config = getattr(self.model.vae, "config", None)
        return getattr(vae_config, "scaling_factor", self.DEFAULT_LATENT_SCALE)

    def _diffusion_loss(self, clean_images):
        """Return the noise-prediction MSE loss tensor for one batch of clean images."""
        # The VAE is frozen, so encoding never needs gradients.
        with torch.no_grad():
            clean_latents = self.model.encode(clean_images) * self._latent_scale()

        # Sample noise and one timestep per batch element. The timestep count
        # is read from `scheduler.config`; direct attribute access on the
        # scheduler is deprecated in diffusers.
        noise = torch.randn_like(clean_latents)
        timesteps = torch.randint(
            0, self.noise_scheduler.config.num_train_timesteps,
            (clean_latents.shape[0],), device=clean_latents.device
        ).long()

        # Add noise to latents
        noisy_latents = self.noise_scheduler.add_noise(
            clean_latents, noise, timesteps
        )

        # Predict noise and compare with the noise that was actually added
        noise_pred = self.model.unet(noisy_latents, timesteps).sample
        return torch.nn.functional.mse_loss(noise_pred, noise)

    def train_step(self, clean_images):
        """Run one optimisation step and return the loss as a Python float."""
        loss = self._diffusion_loss(clean_images)

        # Optimize
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()  # Return scalar value

    def validation_step(self, clean_images):
        """Return the loss for one batch as a Python float, without gradients."""
        with torch.no_grad():
            return self._diffusion_loss(clean_images).item()

    def validate(self, val_loader):
        """Return the mean validation loss over ``val_loader``.

        Leaves the model in eval mode; ``train`` switches it back each epoch.
        """
        self.model.eval()  # Set model to evaluation mode
        val_loss = 0

        # Create progress bar for validation
        val_pbar = tqdm(val_loader,
                       desc="Validating",
                       leave=False,
                       position=1,
                       bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')

        for batch in val_pbar:
            clean_images = batch['clean'].to(self.device)
            loss = self.validation_step(clean_images)
            val_loss += loss

            # Update validation progress bar
            val_pbar.set_postfix({'val_loss': f"{loss:.4f}"})

        # max(..., 1) avoids a ZeroDivisionError on an empty loader
        avg_val_loss = val_loss / max(len(val_loader), 1)
        return avg_val_loss

    def train(self):
        """Run the full training loop with validation, LR scheduling and checkpoints."""
        # Initialize datasets
        train_dataset = RestorationDataset(self.config.data.train_path, task=self.config.data.task)
        val_dataset = RestorationDataset(self.config.data.val_path, task=self.config.data.task)

        train_loader = DataLoader(
            train_dataset,
            batch_size=self.config.training.batch_size,
            shuffle=True,
            num_workers=self.config.training.num_workers
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=self.config.training.batch_size,
            shuffle=False,
            num_workers=self.config.training.num_workers
        )

        logging.info("Starting training...")
        best_val_loss = float('inf')

        for epoch in range(self.config.training.num_epochs):
            # Training phase
            self.model.train()
            train_loss = 0
            progress_bar = tqdm(train_loader,
                              desc=f"Epoch {epoch}",
                              bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')

            for step, batch in enumerate(progress_bar):
                clean_images = batch['clean'].to(self.device)
                loss = self.train_step(clean_images)
                train_loss += loss

                # Update progress bar
                progress_bar.set_postfix({
                    'loss': f"{loss:.4f}",
                    'lr': f"{self.optimizer.param_groups[0]['lr']:.6f}"
                })

                # Save sample restoration periodically
                if step % self.config.logging.sample_interval == 0:
                    self.save_sample_restoration(
                        batch['degraded'][:1],
                        epoch,
                        step
                    )

            avg_train_loss = train_loss / max(len(train_loader), 1)
            logging.info(f"Epoch {epoch} - Train Loss: {avg_train_loss:.4f}")

            # Validation phase
            val_loss = self.validate(val_loader)
            logging.info(f"Epoch {epoch} - Validation Loss: {val_loss:.4f}")

            # Update learning rate and report any change
            lr_before = self.optimizer.param_groups[0]['lr']
            self.lr_scheduler.step(val_loss)
            lr_after = self.optimizer.param_groups[0]['lr']
            if lr_after != lr_before:
                logging.info(f"Reducing learning rate from {lr_before:.6g} to {lr_after:.6g}")

            # Decide which checkpoint files to write for this epoch. An
            # improving epoch only refreshes best_model.pt; per-epoch files are
            # written only on the save interval, so disk use stays bounded.
            improved = val_loss < best_val_loss
            if improved:
                best_val_loss = val_loss
            periodic_due = epoch % self.config.training.save_interval == 0

            if improved or periodic_due:
                self.save_checkpoint(
                    epoch, is_best=improved, write_epoch_file=periodic_due
                )

    def save_sample_restoration(self, degraded_image, epoch, step):
        """Denoise one degraded image and save the decoded result as a PNG.

        The model is put in eval mode for sampling and returned to train mode
        afterwards, even if sampling fails.
        """
        self.model.eval()
        try:
            with torch.no_grad():
                # Work in the same scaled latent space the UNet is trained in,
                # and undo the scaling before decoding with the VAE.
                scale = self._latent_scale()
                latents = self.model.encode(degraded_image.to(self.device)) * scale
                # NOTE(review): the encoded degraded latent is fed to the reverse
                # chain as if it were the fully noised starting point - confirm
                # this is intended rather than noising it to an intermediate
                # timestep first.
                denoised_latents = self.denoise_latents(latents)
                restored_image = self.model.decode(denoised_latents / scale)

                # Save image
                sample_dir = self.save_dir / "samples"
                sample_dir.mkdir(parents=True, exist_ok=True)
                vutils.save_image(
                    restored_image,
                    sample_dir / f"restored_e{epoch}_s{step}.png",
                    normalize=True
                )
        finally:
            self.model.train()

    @torch.no_grad()
    def denoise_latents(self, latents):
        """Run the scheduler's reverse process over all of its timesteps.

        Gradients are disabled so a direct call cannot build an autograd graph
        across the whole sampling chain.
        """
        for t in self.noise_scheduler.timesteps:
            noise_pred = self.model.unet(latents, t).sample
            latents = self.noise_scheduler.step(
                noise_pred, t, latents
            ).prev_sample
        return latents

    def save_checkpoint(self, epoch, is_best=False, write_epoch_file=True):
        """Save model, optimizer and LR-scheduler state.

        Args:
            epoch: Epoch index stored in the checkpoint and used in the file name.
            is_best: Also write the checkpoint to ``best_model.pt``.
            write_epoch_file: Write ``checkpoint_epoch_<epoch>.pt``. Defaults to
                ``True``, which matches the previous behaviour of this method.
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.lr_scheduler.state_dict(),
        }

        # Save regular checkpoint
        if write_epoch_file:
            checkpoint_path = self.save_dir / f"checkpoint_epoch_{epoch}.pt"
            torch.save(checkpoint, checkpoint_path)

        # Save best model
        if is_best:
            best_path = self.save_dir / "best_model.pt"
            torch.save(checkpoint, best_path)


In [None]:
# 4. Configuration
# Nested settings grouped by concern; wrapped in `Config` below for dot access.
config = dict(
    # Model settings
    model=dict(
        image_size=256,
        in_channels=3,
        model_channels=128,
        out_channels=3,
        num_res_blocks=2,
        attention_resolutions=[16, 8],
        dropout=0.0,
        channel_mult=[1, 2, 3, 4],
        conv_resample=True,
        num_heads=4,
    ),
    # Optimisation loop settings
    training=dict(
        batch_size=8,
        num_epochs=1000,
        learning_rate=2e-4,
        num_workers=2,
        save_interval=1000,
        eval_interval=100,
    ),
    # Noise schedule settings
    diffusion=dict(
        beta_schedule='linear',
        beta_start=0.0001,
        beta_end=0.02,
        num_diffusion_timesteps=1000,
    ),
    # Dataset locations and degradation type
    data=dict(
        train_path="data/processed/train",
        val_path="data/processed/val",
        test_path="data/processed/test",
        task="noisy",  # can be "noisy", "low_res", or "masked"
    ),
    # How often (in steps) a sample restoration is saved
    logging=dict(
        sample_interval=500,
    ),
)

# Convert dict to an object for dot notation access
class Config:
    """Attribute-style view of a (possibly nested) dict.

    Every key becomes an attribute; dict values are wrapped recursively in
    ``Config``, all other values are stored unchanged.
    """

    def __init__(self, config_dict):
        for name in config_dict:
            entry = config_dict[name]
            wrapped = Config(entry) if isinstance(entry, dict) else entry
            setattr(self, name, wrapped)

# Wrap the nested dict so settings can be read as attributes
# (e.g. config.training.batch_size).
config = Config(config)

# 5. Initialize and train
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Build the trainer from the configuration and start the training loop.
trainer = LatentDiffusionTrainer(config)
trainer.train()

  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
Epoch 0: 100%|██████████| 1350/1350 [06:25<00:00,  3.50it/s, loss=0.2305, lr=0.000200] 

Validating:   0%|          | 0/169 [00:00<?, ?it/s][A
Validating:   0%|          | 0/169 [00:00<?, ?it/s, val_loss=0.0887][A
Validating:   1%|          | 1/169 [00:00<00:59,  2.80it/s, val_loss=0.0887][A
Validating:   1%|          | 1/169 [00:00<00:59,  2.80it/s, val_loss=0.0477][A
Validating:   1%|          | 2/169 [00:00<00:47,  3.55it/s, val_loss=0.0477][A
Validating:   1%|          | 2/169 [00:00<00:47,  3.55it/s, val_loss=0.0483][A
Validating:   2%|▏         | 3/169 [00:00<00:42,  3.88it/s, val_loss=0.0483][A
Validating:   2%|▏         | 3/169 [00:01<00:42,  3.88it/s, val_loss=0.0748][A
Validating:   2%|▏         | 4/169 [00:01<00:40,  4.07it/s, val_loss=0.0748][A
Validating:   2%|▏         | 4/169 [00:01<00:40,  4.07it/s, val_loss=0.0494][A
Validating:   3%|▎         | 5/169 [00:01<00:39,  4.

In [1]:
ls data/processed/train/

ls: cannot access 'data/processed/train/': No such file or directory
