In [1]:
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt
from diffusers import UNet2DModel, DDPMScheduler
import os
import matplotlib.pyplot as plt
from typing import Optional, Tuple, Union
import traceback


In [2]:
class LandcoverDiffusionModel:
    """
    A diffusion model for generating land cover maps using a UNet architecture.

    The UNet works on ``num_classes`` channels, one per row of the class CSV.
    Generation starts from Gaussian noise and is iteratively denoised with a DDPM
    scheduler. The final sample is converted to per-class probabilities with a
    softmax over the channel axis.
    """

    def __init__(self, class_data_path: Union[str, Path] = 'E:/Research/RP1/archive/class_dict.csv'):
        """
        Initialize the land cover diffusion model.

        Args:
            class_data_path: Path (or file-like object accepted by ``pd.read_csv``)
                to a CSV with at least the columns ``name``, ``r``, ``g``, ``b``.
                Row order defines the class index. RGB values are divided by 255
                for plotting. NOTE(review): the default is an absolute local path;
                prefer passing it explicitly.

        Raises:
            ValueError: If any required column is missing from the CSV.
        """
        self.class_df = pd.read_csv(class_data_path)
        required_columns = {'name', 'r', 'g', 'b'}
        missing = required_columns - set(self.class_df.columns)
        if missing:
            raise ValueError(
                f"Class data CSV must contain columns: {required_columns} "
                f"(missing: {sorted(missing)})"
            )

        self.num_classes = len(self.class_df)

        # Class index (CSV row order) -> RGB color scaled to [0, 1] for matplotlib.
        self.class_colors = {
            idx: [r / 255.0, g / 255.0, b / 255.0]
            for idx, (r, g, b) in enumerate(
                self.class_df[['r', 'g', 'b']].itertuples(index=False, name=None)
            )
        }

        self.class_names = self.class_df['name'].tolist()

        self._initialize_unet()
        self._initialize_scheduler()

    def _initialize_unet(self):
        """Build the UNet; input and output channel counts both equal ``num_classes``."""
        self.unet = UNet2DModel(
            sample_size=256,
            in_channels=self.num_classes,
            out_channels=self.num_classes,
            layers_per_block=2,
            block_out_channels=(128, 128, 256, 256, 512, 512),
            down_block_types=(
                "DownBlock2D",
                "DownBlock2D",
                "DownBlock2D",
                "DownBlock2D",
                "AttnDownBlock2D",
                "DownBlock2D",
            ),
            up_block_types=(
                "UpBlock2D",
                "AttnUpBlock2D",
                "UpBlock2D",
                "UpBlock2D",
                "UpBlock2D",
                "UpBlock2D",
            ),
        )

    def _initialize_scheduler(self):
        """Create the DDPM noise scheduler (1000 training timesteps)."""
        self.noise_scheduler = DDPMScheduler(
            num_train_timesteps=1000,
            beta_start=0.00085,
            beta_end=0.012,
        )

    def generate_landcover_map(
        self,
        initial_noise: Optional[torch.Tensor] = None,
        num_inference_steps: int = 50,
        batch_size: int = 1,
        image_size: Tuple[int, int] = (256, 256),
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Generate land cover maps by running the reverse diffusion process.

        Args:
            initial_noise: Optional starting tensor of shape
                (batch_size, num_classes, height, width). It is moved to ``device``.
                If None, standard normal noise is drawn.
            num_inference_steps: Number of denoising steps passed to the scheduler.
            batch_size: Number of maps to generate when ``initial_noise`` is None.
            image_size: (height, width) of generated maps when ``initial_noise`` is None.
            device: Device to run on. If None, CUDA is used when available, otherwise CPU.

        Returns:
            Tensor of shape (batch, num_classes, height, width) holding a softmax
            over the class channel.
        """
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            # Only query the GPU name when a GPU is actually in use; calling it on a
            # CPU-only machine raises.
            if device.type == "cuda":
                print(f"Using CUDA device: {torch.cuda.get_device_name(device)}")
            else:
                print("CUDA not available; using CPU")

        self.unet.to(device)
        self.unet.eval()

        if initial_noise is None:
            sample = torch.randn(
                (batch_size, self.num_classes, *image_size),
                device=device,
            )
        else:
            # Caller-supplied noise may live on a different device than the model.
            sample = initial_noise.to(device)

        self.noise_scheduler.set_timesteps(num_inference_steps)

        with torch.no_grad():
            for t in self.noise_scheduler.timesteps:
                noise_pred = self.unet(sample, t).sample
                sample = self.noise_scheduler.step(noise_pred, t, sample).prev_sample

        return F.softmax(sample, dim=1)

    def visualize_landcover(
        self,
        generated_map: Union[torch.Tensor, np.ndarray],
        save_path: Optional[Union[str, Path]] = None,
        figure_size: Tuple[int, int] = (10, 10),
        dpi: int = 300
    ) -> None:
        """
        Render a generated map as a color-coded image with a class legend.

        Args:
            generated_map: Array/tensor of shape (batch, num_classes, H, W); only the
                first map is drawn. A single (num_classes, H, W) map is also accepted.
            save_path: If given, the figure is saved there and closed; otherwise it is shown.
            figure_size: Figure size in inches.
            dpi: Resolution used when saving.
        """
        if torch.is_tensor(generated_map):
            generated_map = generated_map.detach().cpu().numpy()
        generated_map = np.asarray(generated_map)

        first_map = generated_map if generated_map.ndim == 3 else generated_map[0]
        class_predictions = np.argmax(first_map, axis=0)
        height, width = class_predictions.shape
        rgb_image = np.zeros((height, width, 3))

        # Paint every pixel with the color of its most probable class.
        for class_idx, color in self.class_colors.items():
            rgb_image[class_predictions == class_idx] = color

        fig, ax = plt.subplots(figsize=figure_size)
        ax.imshow(rgb_image)
        ax.axis('off')

        legend_elements = [
            plt.Rectangle((0, 0), 1, 1, fc=color)
            for color in self.class_colors.values()
        ]
        ax.legend(
            legend_elements,
            self.class_names,
            loc='center left',
            bbox_to_anchor=(1, 0.5)
        )

        if save_path:
            fig.savefig(save_path, bbox_inches='tight', dpi=dpi)
            plt.close(fig)
        else:
            plt.show()

    def save_model(self, save_dir: Union[str, Path]) -> None:
        """
        Save UNet weights, scheduler configuration and class definitions.

        Args:
            save_dir: Directory to write into (created if needed). It receives
                ``unet.pt``, ``scheduler_config.json`` and ``class_dict.csv``.
        """
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)

        torch.save(self.unet.state_dict(), save_dir / 'unet.pt')

        # save_config() takes a *directory* and writes scheduler_config.json inside it.
        # Older versions of this method passed the file path, which created a
        # directory with that name; keep writing into it if it already exists so
        # re-saving over an old checkpoint does not fail.
        legacy_config_dir = save_dir / 'scheduler_config.json'
        self.noise_scheduler.save_config(
            legacy_config_dir if legacy_config_dir.is_dir() else save_dir
        )

        self.class_df.to_csv(save_dir / 'class_dict.csv', index=False)

    @classmethod
    def load_model(cls, load_dir: Union[str, Path]) -> 'LandcoverDiffusionModel':
        """
        Load a model previously written by ``save_model``.

        Args:
            load_dir: Directory containing the saved model files.

        Returns:
            A LandcoverDiffusionModel with restored weights (on CPU) and scheduler.

        Raises:
            ValueError: If ``load_dir`` does not exist.
        """
        load_dir = Path(load_dir)
        if not load_dir.exists():
            raise ValueError(f"Model directory {load_dir} does not exist")

        model = cls(class_data_path=load_dir / 'class_dict.csv')

        # map_location='cpu' lets GPU-trained checkpoints load on CPU-only machines;
        # weights_only=True refuses to unpickle arbitrary objects from the file.
        state_dict = torch.load(load_dir / 'unet.pt', map_location='cpu', weights_only=True)
        model.unet.load_state_dict(state_dict)

        # load_config accepts either the JSON file itself or a directory containing
        # scheduler_config.json, which covers both the current and the legacy layout.
        scheduler_config = DDPMScheduler.load_config(str(load_dir / 'scheduler_config.json'))
        model.noise_scheduler = DDPMScheduler.from_config(scheduler_config)

        return model

In [3]:
class LandcoverDataset(Dataset):
    """
    Dataset of RGB land cover mask images, returned as one-hot class tensors.

    Each ``*.png`` in ``image_dir`` is loaded as RGB and resized with nearest-neighbour
    sampling, so no blended colors are introduced. It is then converted into a float
    tensor of shape ``(num_classes, height, width)``. Channel ``i`` is 1 wherever the
    pixel color equals row ``i`` of the class CSV. Pixels matching no class are left
    all-zero and reported with a warning.
    """

    def __init__(
        self,
        image_dir: str,
        class_dict_path: str,
        target_size: tuple = (256, 256),
        expected_size: Optional[Tuple[int, int]] = (2448, 2448),
    ):
        """
        Args:
            image_dir: Directory containing the ``*.png`` mask images.
            class_dict_path: CSV with columns ``name``, ``r``, ``g``, ``b``; row order
                defines the class index.
            target_size: Output ``(height, width)`` of every sample.
            expected_size: PIL ``(width, height)`` the source images are expected to
                have. A warning is printed on mismatch; ``None`` disables the check.

        Raises:
            ValueError: If the directory is missing, the class CSV cannot be loaded or
                lacks required columns, or no PNG file can be opened. Any exception
                raised while validating the first sample is propagated.
        """
        self.image_dir = Path(image_dir)
        self.target_size = tuple(target_size)
        self.expected_size = tuple(expected_size) if expected_size is not None else None

        if not self.image_dir.exists():
            raise ValueError(f"Directory not found: {self.image_dir}")

        # Load and validate class definitions before touching any image.
        try:
            self.class_df = pd.read_csv(class_dict_path)
        except Exception as e:
            raise ValueError(f"Error loading class definitions: {e}") from e
        required_columns = {'name', 'r', 'g', 'b'}
        missing = required_columns - set(self.class_df.columns)
        if missing:
            raise ValueError(
                f"Class CSV must contain columns: {required_columns} (missing: {sorted(missing)})"
            )

        # Exact RGB color -> class index (CSV row order).
        self.color_to_class = {
            (int(r), int(g), int(b)): idx
            for idx, (r, g, b) in enumerate(
                self.class_df[['r', 'g', 'b']].itertuples(index=False, name=None)
            )
        }
        self.num_classes = len(self.class_df)
        if len(self.color_to_class) != self.num_classes:
            # Duplicate colors collapse to the last class; earlier ones never receive pixels.
            print("Warning: class CSV contains duplicate colors; some classes will never be assigned")

        # Keep only PNG files that PIL can open.
        self.image_files = []
        for file in self.image_dir.glob('*.png'):
            try:
                with Image.open(file) as img:
                    if img.mode != 'RGB':
                        print(f"Warning: {file.name} is not in RGB mode. Will convert during loading.")
                self.image_files.append(file)
            except Exception as e:
                print(f"Warning: Could not open {file.name}: {e}")

        if not self.image_files:
            raise ValueError(f"No valid PNG files found in {image_dir}")

        self.image_files.sort()
        print(f"Successfully loaded {len(self.image_files)} PNG files")
        print(f"Number of classes: {self.num_classes}")
        print(f"Target size: {self.target_size}")

        self._validate_first_image()

    def _validate_first_image(self):
        """Fully load sample 0 so format problems surface at construction time."""
        try:
            sample = self[0]
            if not isinstance(sample, torch.Tensor):
                raise ValueError("Dataset output is not a tensor")
            if sample.shape[0] != self.num_classes:
                raise ValueError(f"Expected {self.num_classes} channels, got {sample.shape[0]}")
            print("First image validated successfully")
        except Exception as e:
            print(f"Error validating first image: {e}")
            print("Full traceback:")
            traceback.print_exc()
            raise

    def __len__(self):
        return len(self.image_files)

    def _rgb_to_one_hot(self, image_array: np.ndarray) -> Tuple[torch.Tensor, np.ndarray]:
        """
        Convert an (H, W, 3) RGB array into a (num_classes, H, W) one-hot float tensor.

        Returns:
            The one-hot tensor, and a (K, 3) array of the distinct colors that
            matched no class (K == 0 when every pixel was classified).
        """
        height, width = image_array.shape[:2]
        target = torch.zeros((self.num_classes, height, width))
        classified = np.zeros((height, width), dtype=bool)

        for color, class_idx in self.color_to_class.items():
            mask = np.all(image_array == color, axis=2)
            target[class_idx][torch.from_numpy(mask)] = 1
            classified |= mask

        if classified.all():
            unknown_colors = np.empty((0, 3), dtype=image_array.dtype)
        else:
            unknown_colors = np.unique(image_array[~classified].reshape(-1, 3), axis=0)
        return target, unknown_colors

    def __getitem__(self, idx):
        img_path = self.image_files[idx]
        height, width = self.target_size
        try:
            with Image.open(img_path) as image:
                if image.mode != 'RGB':
                    image = image.convert('RGB')

                if self.expected_size is not None and image.size != self.expected_size:
                    print(f"Warning: Image {img_path.name} size is {image.size}, expected {self.expected_size}")

                # PIL sizes are (width, height). NEAREST keeps every pixel an exact class color.
                image = image.resize((width, height), Image.Resampling.NEAREST)
                image_array = np.array(image)

            target, unknown_colors = self._rgb_to_one_hot(image_array)

            if len(unknown_colors):
                preview = [tuple(int(v) for v in c) for c in unknown_colors[:10]]
                more = f" (+{len(unknown_colors) - 10} more)" if len(unknown_colors) > 10 else ""
                print(f"Warning: Image {img_path.name} contains unclassified colors: {preview}{more}")

            return target

        except Exception as e:
            print(f"\nError processing image {img_path}:")
            print(f"Error type: {type(e).__name__}")
            print(f"Error message: {str(e)}")
            print("Full traceback:")
            traceback.print_exc()
            raise


In [4]:
def _resolve_device(device: Optional[Union[str, torch.device]] = None) -> torch.device:
    """Return ``device`` as a torch.device, defaulting to CUDA when available, else CPU."""
    if device is not None:
        return torch.device(device)
    if torch.cuda.is_available():
        resolved = torch.device("cuda")
        print(f"Using CUDA device: {torch.cuda.get_device_name(resolved)}")
        return resolved
    print("CUDA not available; using CPU")
    return torch.device("cpu")


def _train_one_epoch(model, dataloader, optimizer, device, desc: str) -> float:
    """Run one epoch of noise-prediction training and return the mean batch loss."""
    model.unet.train()
    # Read from .config: direct attribute access on the scheduler is deprecated in diffusers.
    num_train_timesteps = model.noise_scheduler.config.num_train_timesteps
    batch_losses = []

    progress_bar = tqdm(dataloader, desc=desc)
    for clean in progress_bar:
        clean = clean.to(device)

        # Corrupt the clean one-hot maps at one random timestep per sample and
        # train the UNet to predict the noise that was added.
        noise = torch.randn_like(clean)
        timesteps = torch.randint(
            0, num_train_timesteps, (clean.shape[0],), device=device
        ).long()
        noisy = model.noise_scheduler.add_noise(clean, noise, timesteps)
        noise_pred = model.unet(noisy, timesteps).sample
        loss = F.mse_loss(noise_pred, noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())
        progress_bar.set_postfix({'loss': sum(batch_losses) / len(batch_losses)})

    if not batch_losses:
        raise ValueError("DataLoader yielded no batches")
    return sum(batch_losses) / len(batch_losses)


def _save_loss_curve(losses, path: Path) -> None:
    """Plot the per-epoch average losses (epochs numbered from 1) and save to ``path``."""
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(range(1, len(losses) + 1), losses)
    ax.set_title('Training Loss')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    fig.savefig(path)
    plt.close(fig)


def train_landcover_model(
    data_dir: str,
    class_dict_path: str,
    output_dir: str,
    num_epochs: int = 10,
    batch_size: int = 4,
    learning_rate: float = 1e-4,
    device: Optional[Union[str, torch.device]] = None,
    save_interval: int = 10
):
    """
    Train a LandcoverDiffusionModel on the mask images in ``data_dir``.

    Args:
        data_dir: Directory of ``*.png`` mask images.
        class_dict_path: Class CSV used by both the dataset and the model, so their
            channel counts agree.
        output_dir: Directory for checkpoints (``checkpoint_epoch_N``) and
            ``loss_curve.png``. It is created if missing; a trailing slash is optional.
        num_epochs: Number of passes over the dataset.
        batch_size: DataLoader batch size.
        learning_rate: AdamW learning rate.
        device: Device to train on. If None, CUDA is used when available, else CPU.
        save_interval: Save a checkpoint every this many epochs. The final epoch is
            always saved.

    Returns:
        List of per-epoch average losses.

    Raises:
        ValueError: If ``save_interval`` < 1 or the dataset/model class counts differ.
        Any exception raised during setup or training is reported and re-raised.
    """
    if save_interval < 1:
        raise ValueError(f"save_interval must be >= 1, got {save_interval}")
    output_dir = Path(output_dir)

    try:
        output_dir.mkdir(parents=True, exist_ok=True)

        print("Initializing dataset...")
        dataset = LandcoverDataset(data_dir, class_dict_path)

        print("Creating DataLoader...")
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=0,      # 0 keeps errors in the main process (easier debugging)
            pin_memory=False,
        )

        # Fail fast on collation/shape problems before building the model.
        print("Testing first batch...")
        first_batch = next(iter(dataloader))
        print(f"First batch shape: {first_batch.shape}")

        device = _resolve_device(device)

        model = LandcoverDiffusionModel(class_data_path=class_dict_path)
        if model.num_classes != dataset.num_classes:
            raise ValueError(
                f"Model has {model.num_classes} classes but dataset has {dataset.num_classes}"
            )
        model.unet.to(device)

        optimizer = torch.optim.AdamW(model.unet.parameters(), lr=learning_rate)

        losses = []
        for epoch in range(1, num_epochs + 1):
            avg_loss = _train_one_epoch(
                model, dataloader, optimizer, device, desc=f'Epoch {epoch}/{num_epochs}'
            )
            losses.append(avg_loss)

            if epoch % save_interval == 0 or epoch == num_epochs:
                model.save_model(output_dir / f'checkpoint_epoch_{epoch}')

            # Rewritten every epoch so a partial curve exists if training is interrupted.
            _save_loss_curve(losses, output_dir / 'loss_curve.png')

            print(f'Epoch {epoch}/{num_epochs}, Average Loss: {avg_loss:.6f}')

        return losses

    except Exception as e:
        print("\nError during training:")
        print(f"Error type: {type(e).__name__}")
        print(f"Error message: {str(e)}")
        print("Full traceback:")
        traceback.print_exc()
        raise


In [5]:
# Paths are absolute for the author's machine; adjust them before running elsewhere.
# Changing the training resolution would first require adding a target_size
# parameter to train_landcover_model and forwarding it to LandcoverDataset.
train_landcover_model(
    output_dir='E:/Research/RP1/archive/landcover_model_output/',
    class_dict_path='E:/Research/RP1/archive/class_dict.csv',
    data_dir='E:/Research/RP1/archive/train/',
);

Initializing dataset...
Successfully loaded 803 PNG files
Number of classes: 7
Target size: (256, 256)
First image validated successfully
Creating DataLoader...
Testing first batch...
First batch shape: torch.Size([4, 7, 256, 256])
Using CUDA device: NVIDIA GeForce RTX 3070


  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
  hidden_states = F.scaled_dot_product_attention(
Epoch 1/10: 100%|██████████| 201/201 [36:14<00:00, 10.82s/it, loss=0.118]


Epoch 1/10, Average Loss: 0.117926


Epoch 2/10: 100%|██████████| 201/201 [35:21<00:00, 10.56s/it, loss=0.04]  


Epoch 2/10, Average Loss: 0.039968


Epoch 3/10: 100%|██████████| 201/201 [32:34<00:00,  9.72s/it, loss=0.0269]


Epoch 3/10, Average Loss: 0.026862


Epoch 4/10: 100%|██████████| 201/201 [32:21<00:00,  9.66s/it, loss=0.0232]


Epoch 4/10, Average Loss: 0.023182


Epoch 5/10: 100%|██████████| 201/201 [32:31<00:00,  9.71s/it, loss=0.018] 


Epoch 5/10, Average Loss: 0.018047


Epoch 6/10: 100%|██████████| 201/201 [32:10<00:00,  9.61s/it, loss=0.018] 


Epoch 6/10, Average Loss: 0.017953


Epoch 7/10: 100%|██████████| 201/201 [32:31<00:00,  9.71s/it, loss=0.013] 


Epoch 7/10, Average Loss: 0.013024


Epoch 8/10: 100%|██████████| 201/201 [32:23<00:00,  9.67s/it, loss=0.0136]


Epoch 8/10, Average Loss: 0.013565


Epoch 9/10: 100%|██████████| 201/201 [29:45<00:00,  8.88s/it, loss=0.0136]


Epoch 9/10, Average Loss: 0.013628


Epoch 10/10: 100%|██████████| 201/201 [31:36<00:00,  9.44s/it, loss=0.0105]


Epoch 10/10, Average Loss: 0.010543


In [7]:
# Inference: load a checkpoint, sample one land cover map, and save a color-coded rendering.
SEED = 0  # change to draw a different map
OUTPUT_ROOT = Path("E:/Research/RP1/archive/landcover_model_output")

model_dir = OUTPUT_ROOT / "checkpoint_epoch_10"  # adjust to the checkpoint you want
model = LandcoverDiffusionModel.load_model(model_dir)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the global RNG so the initial noise drawn inside generate_landcover_map is reproducible.
torch.manual_seed(SEED)

# generate_landcover_map moves the UNet to `device` itself.
generated_map = model.generate_landcover_map(
    batch_size=1,
    num_inference_steps=50,  # more steps: slower, potentially higher quality
    image_size=(256, 256),   # match the training resolution
    device=device,
)

output_path = OUTPUT_ROOT / "generated_landcover_map.png"
output_path.parent.mkdir(parents=True, exist_ok=True)
model.visualize_landcover(
    generated_map,
    save_path=output_path,
    figure_size=(12, 12),
    dpi=300,
)

print(f"Generated landcover map saved to {output_path.absolute()}")

  torch.load(load_dir / 'unet.pt')
  deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)


Generated landcover map saved to E:\Research\RP1\archive\landcover_model_output\generated_landcover_map.png
