In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import rasterio
import numpy as np
import matplotlib.pyplot as plt
import os
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
from tqdm import tqdm
import warnings
import time

# Keep console output readable by hiding UserWarning-category warnings
warnings.filterwarnings("ignore", category=UserWarning)

# Optional aid for CUDA error debugging
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Pick the compute device, preferring a GPU when one is available
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# On GPU, also report the device name and current memory usage in MB
if device.type == "cuda":
    print(torch.cuda.get_device_name(0))
    print(
        f"Memory Allocated: {torch.cuda.memory_allocated(0) / 1024**2:.2f} MB"
    )
    print(
        f"Memory Reserved: {torch.cuda.memory_reserved(0) / 1024**2:.2f} MB"
    )

class GeoTIFFDataset(Dataset):
    """Patch-level classification dataset over a folder of multi-band GeoTIFFs.

    Every ``.tif`` file in ``root_dir`` is tiled into square patches. For each
    patch the last band is treated as the label raster and all preceding bands
    as features; the sample label is the most frequent label value inside the
    patch.

    Args:
        root_dir: Directory containing the ``.tif`` files.
        patch_size: Side length of the square patches, in pixels.
        stride: Step between patch origins; a stride smaller than
            ``patch_size`` produces overlapping patches.

    Attributes:
        samples: List of ``(file_path, row, col)`` tuples, one per patch.
        num_feature_bands: Feature band count (total bands minus the label
            band) of the first usable file, or ``None`` if no file was usable.
            Used to shape the fallback sample returned on read errors.
    """

    def __init__(self, root_dir, patch_size=64, stride=32):
        self.root_dir = root_dir
        self.patch_size = patch_size
        self.stride = stride  # Stride for overlapping patches
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # sample order (and therefore any seeded split) reproducible.
        self.image_files = sorted(
            f for f in os.listdir(root_dir) if f.endswith('.tif')
        )
        self.samples = []
        self.num_feature_bands = None
        self._create_sample_list()

    def _create_sample_list(self):
        """Create a list of (file_path, row, col) tuples for patch extraction with validation"""
        print("Creating sample list...")
        for img_file in self.image_files:
            img_path = os.path.join(self.root_dir, img_file)
            try:
                # Only metadata is needed here; pixel data is read lazily
                # in __getitem__.
                with rasterio.open(img_path) as src:
                    height, width, band_count = src.height, src.width, src.count
            except Exception as e:
                print(f"Error processing {img_file}: {e}")
                continue

            # A file needs at least one feature band plus the label band;
            # anything less could only ever produce fallback samples.
            if band_count < 2:
                print(f"Skipping {img_file} - insufficient bands: {band_count}")
                continue

            # Skip files that are too small
            if height < self.patch_size or width < self.patch_size:
                print(f"Skipping {img_file} - too small for patch size {self.patch_size}")
                continue

            if self.num_feature_bands is None:
                self.num_feature_bands = band_count - 1

            # Create patches with stride
            self.samples.extend(
                (img_path, row, col)
                for row in range(0, height - self.patch_size + 1, self.stride)
                for col in range(0, width - self.patch_size + 1, self.stride)
            )

        print(f"Created {len(self.samples)} samples from {len(self.image_files)} files")

    def __len__(self):
        """Return the number of indexed patches."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load one patch and return ``(features, label)``.

        Returns:
            features: float32 tensor of shape ``(bands - 1, patch_size,
                patch_size)``; each band is standardised with its own
                per-patch mean and standard deviation.
            label: scalar int64 tensor holding the most frequent label value
                in the patch (negative values are mapped to 0).

        Any error while reading or processing is reported and replaced by an
        all-zero patch with label 0, so that one bad file does not abort a
        training run.
        """
        img_path, row, col = self.samples[idx]

        try:
            window = rasterio.windows.Window(col, row, self.patch_size, self.patch_size)
            with rasterio.open(img_path) as src:
                # Read a patch; the file is closed before any processing
                image = src.read(window=window).astype(np.float32)

            if image.shape[0] < 2:  # Need at least 2 bands (1 for data, 1 for label)
                raise ValueError(f"File {img_path} has insufficient bands: {image.shape[0]}")

            # Extract label from last band and features from others
            label_band = image[-1].astype(np.int64)
            feature_bands = image[:-1]

            # Handle potential NaN/Inf values
            feature_bands = np.nan_to_num(feature_bands, nan=0.0, posinf=0.0, neginf=0.0)

            # Normalize feature bands (with safety checks)
            for i in range(feature_bands.shape[0]):
                band_std = np.std(feature_bands[i])
                if band_std > 1e-6:  # Only normalize if std dev is non-zero
                    feature_bands[i] = (feature_bands[i] - np.mean(feature_bands[i])) / band_std

            # Majority vote over the label raster
            unique_labels, counts = np.unique(label_band, return_counts=True)
            if len(unique_labels) > 0:
                # CrossEntropyLoss requires non-negative class indices
                most_common_label = max(int(unique_labels[np.argmax(counts)]), 0)
            else:
                most_common_label = 0  # Default if no labels found

            # Convert to torch tensor
            feature_tensor = torch.tensor(feature_bands, dtype=torch.float32)
            label_tensor = torch.tensor(most_common_label, dtype=torch.long)

            # Final validation (normalisation of extreme values can still
            # overflow float32)
            if torch.isnan(feature_tensor).any() or torch.isinf(feature_tensor).any():
                print(f"Warning: NaN or Inf values found in sample {idx} after processing")
                feature_tensor = torch.nan_to_num(feature_tensor)

            return feature_tensor, label_tensor

        except Exception as e:
            print(f"Error loading sample {idx} from {img_path}: {e}")
            # Return a dummy sample as fallback (better than crashing). It
            # must have the same channel count as real samples, otherwise
            # batching it together with real samples fails.
            channels = self.num_feature_bands or 1
            dummy_features = torch.zeros(
                (channels, self.patch_size, self.patch_size), dtype=torch.float32
            )
            dummy_label = torch.tensor(0, dtype=torch.long)
            return dummy_features, dummy_label

# Simple CNN model with robust input handling
class SimpleCNN(nn.Module):
    """Small two-stage convolutional classifier for image patches.

    Two conv/ReLU/max-pool stages are followed by adaptive average pooling to
    a 4x4 grid, so the classifier head always receives 32 * 4 * 4 features
    whatever the spatial size of the input.

    Args:
        input_channels: Number of channels in the input tensor.
        num_classes: Number of output logits.
    """

    def __init__(self, input_channels, num_classes):
        super().__init__()

        # Convolutional feature extractor
        conv_stack = [
            nn.Conv2d(input_channels, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d((4, 4)),  # Output is 4x4 for any input size
        ]
        self.features = nn.Sequential(*conv_stack)

        # 32 channels on a 4x4 grid after the adaptive pooling
        flat_features = 32 * 4 * 4
        head = [
            nn.Flatten(),
            nn.Linear(flat_features, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return class logits for ``x``.

        Accepts ``[batch, channels, height, width]`` input, or a single
        unbatched ``[channels, height, width]`` sample (a batch axis is
        added). Non-finite values are replaced before the forward pass.

        Raises:
            ValueError: If ``x`` is neither 3- nor 4-dimensional.
        """
        rank = x.dim()
        if rank == 3:
            # Single sample without a batch axis
            x = x.unsqueeze(0)
        elif rank != 4:
            raise ValueError(f"Unexpected input shape: {x.shape}")

        # Replace NaN/Inf values only when some are present
        if not torch.isfinite(x).all():
            x = torch.nan_to_num(x)

        return self.classifier(self.features(x))

def _detect_num_classes(dataset, max_samples=100):
    """Estimate the class count from an evenly spaced probe of the dataset.

    Args:
        dataset: Indexable dataset whose items are ``(features, label)`` with
            ``label`` a scalar tensor.
        max_samples: Upper bound on the number of items inspected.

    Returns:
        ``(num_classes, unique_labels)``. ``num_classes`` is the largest
        observed label plus one (and at least 2), so every observed label is
        a valid class index for ``CrossEntropyLoss``.

    Raises:
        ValueError: If the dataset is empty.
    """
    total = len(dataset)
    # Spread the probe over the whole dataset; consecutive indices are
    # neighbouring patches of one file and tend to share the same label.
    probe_indices = np.unique(
        np.linspace(0, total - 1, num=min(max_samples, total), dtype=int)
    )
    observed = [dataset[int(i)][1].item() for i in probe_indices]
    unique_labels = np.unique(observed)
    num_classes = max(2, int(unique_labels.max()) + 1)
    return num_classes, unique_labels


def _train_one_epoch(model, loader, criterion, optimizer, num_classes, desc):
    """Run one training epoch and return the mean loss over processed batches.

    Batches that fail (device transfer error, invalid loss, any other
    exception) are reported and skipped rather than aborting the epoch.
    """
    model.train()
    running_loss = 0.0
    batch_count = 0

    try:
        train_progress = tqdm(loader, desc=desc)

        for batch_idx, (inputs, labels) in enumerate(train_progress):
            try:
                # Move to device with error handling
                try:
                    inputs, labels = inputs.to(device), labels.to(device)
                except RuntimeError as e:
                    print(f"Error moving batch {batch_idx} to device: {e}")
                    continue  # Skip this batch

                # Check for invalid input values
                if torch.isnan(inputs).any() or torch.isinf(inputs).any():
                    print(f"Warning: NaN/Inf in inputs at batch {batch_idx}")
                    inputs = torch.nan_to_num(inputs)

                # Check input range
                if inputs.max() > 1e6 or inputs.min() < -1e6:
                    print(f"Warning: Extreme values in inputs at batch {batch_idx}")
                    inputs = torch.clamp(inputs, -1e6, 1e6)

                # Safeguard batch size
                if inputs.size(0) == 1:
                    print(f"Skipping single-item batch {batch_idx}")
                    continue  # Skip very small batches

                # Forward pass (without mixed precision for stability)
                outputs = model(inputs)

                # Out-of-range labels would make CrossEntropyLoss fail, so
                # they are clamped into [0, num_classes - 1] as a last resort.
                if torch.min(labels) < 0:
                    print(f"Warning: Negative labels found in batch {batch_idx}")
                    labels = torch.clamp(labels, min=0)

                if torch.max(labels) >= num_classes:
                    print(f"Warning: Labels >= num_classes ({num_classes}) found in batch {batch_idx}")
                    labels = torch.clamp(labels, max=num_classes-1)

                # Compute loss
                loss = criterion(outputs, labels)

                # Check if loss is valid
                if torch.isnan(loss) or torch.isinf(loss):
                    print(f"Skipping batch {batch_idx} due to invalid loss: {loss.item()}")
                    continue  # Skip this batch

                # Backpropagation and optimization
                optimizer.zero_grad()
                loss.backward()

                # Gradient clipping to prevent exploding gradients
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

                # Step optimizer
                optimizer.step()

                # Update statistics and progress bar
                running_loss += loss.item()
                batch_count += 1
                avg_loss = running_loss / batch_count
                train_progress.set_postfix(loss=f"{avg_loss:.4f}")

            except Exception as e:
                print(f"Error in batch {batch_idx}: {e}")
                # Continue with next batch

    except Exception as e:
        print(f"Error during training: {e}")

    # Avoid division by zero when every batch was skipped
    return running_loss / max(1, batch_count)


def _validate(model, loader, criterion, num_classes):
    """Evaluate the model and return ``(mean_loss, preds, labels)``.

    ``preds`` and ``labels`` are lists of per-batch numpy arrays, collected
    only for batches whose loss was finite.
    """
    model.eval()
    val_loss = 0.0
    val_batch_count = 0
    all_preds = []
    all_labels = []

    try:
        with torch.no_grad():
            for batch_idx, (inputs, labels) in enumerate(tqdm(loader, desc="Validation")):
                try:
                    inputs, labels = inputs.to(device), labels.to(device)

                    # Check for invalid values
                    if torch.isnan(inputs).any() or torch.isinf(inputs).any():
                        inputs = torch.nan_to_num(inputs)

                    # Forward pass
                    outputs = model(inputs)

                    # Validate labels
                    if torch.min(labels) < 0 or torch.max(labels) >= num_classes:
                        labels = torch.clamp(labels, 0, num_classes-1)

                    # Compute loss
                    loss = criterion(outputs, labels)

                    if not (torch.isnan(loss) or torch.isinf(loss)):
                        val_loss += loss.item()
                        val_batch_count += 1

                        # Collect predictions for metrics
                        preds = outputs.argmax(dim=1)
                        all_preds.append(preds.cpu().numpy())
                        all_labels.append(labels.cpu().numpy())

                except Exception as e:
                    print(f"Error in validation batch {batch_idx}: {e}")
    except Exception as e:
        print(f"Error during validation: {e}")

    return val_loss / max(1, val_batch_count), all_preds, all_labels


def main():
    """Train a SimpleCNN patch classifier on the GeoTIFFs in ``Raster_Train``.

    Steps: build the patch dataset, split it 80/20 into train and validation
    subsets, train with SGD, LR reduction on plateau and early stopping, plot
    the loss curves, report validation metrics of the last epoch, and save
    both the best weights and a final checkpoint under ``Export_Model``.

    Any exception is caught and reported as a "Critical error" instead of
    propagating.
    """
    # Create output directory
    os.makedirs("Export_Model", exist_ok=True)

    # Load Data
    root_dir = r"Raster_Train"  # Change to your actual root directory path

    # Try to create dataset with smaller patch size for better stability
    patch_size = 64  # Smaller patches are easier to process

    try:
        # Create dataset
        dataset = GeoTIFFDataset(root_dir, patch_size=patch_size)

        # Check if the dataset is empty
        if len(dataset) == 0:
            raise ValueError(f"No valid samples found in {root_dir}. Please check the data.")

        # Try to fetch a single sample to verify everything works
        print("Validating dataset by fetching a sample...")
        try:
            sample_input, sample_label = dataset[0]
            print(f"Sample input shape: {sample_input.shape}, Sample label: {sample_label.item()}")
            print("Sample validation successful!")
        except Exception as e:
            print(f"Error during sample validation: {e}")
            raise

        # Split dataset indices
        # NOTE(review): with the dataset's default stride (32) smaller than
        # the patch size (64), patches overlap, so a random split can put
        # overlapping patches in both subsets and make the validation loss
        # optimistic.
        print("Splitting dataset into train/val sets...")
        indices = list(range(len(dataset)))
        train_indices, val_indices = train_test_split(indices, test_size=0.2, random_state=42)

        # Create subset datasets
        train_dataset = torch.utils.data.Subset(dataset, train_indices)
        val_dataset = torch.utils.data.Subset(dataset, val_indices)

        # Start with very small batch size and 0 workers for safety
        batch_size = 8  # Small batch size to avoid CUDA issues

        # Create DataLoaders with safety settings
        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=0,  # No multiprocessing to avoid issues
            drop_last=True  # Drop last incomplete batch
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=0
        )

        # Verify loader works by fetching a batch
        print("Verifying DataLoader by fetching a batch...")
        try:
            inputs, labels = next(iter(train_loader))
            print(f"Batch input shape: {inputs.shape}, Batch labels shape: {labels.shape}")
            print("DataLoader verification successful!")
        except Exception as e:
            print(f"Error during DataLoader verification: {e}")
            raise

        # Determine input channels and number of classes
        num_bands = sample_input.shape[0]  # Number of feature bands

        # Safely determine number of classes from a probe of the dataset
        try:
            num_classes, unique_labels = _detect_num_classes(dataset)
            print(f"Detected classes: {unique_labels}")
            if len(unique_labels) != int(unique_labels.max()) + 1:
                print("Warning: detected labels are not consecutive from 0; "
                      "sizing the output layer by the largest label")
            print(f"Using num_classes = {num_classes}")
        except Exception as e:
            print(f"Error detecting classes: {e}")
            num_classes = 2  # Default to binary classification
            print(f"Using default num_classes = {num_classes}")

        # Model setup
        model = SimpleCNN(input_channels=num_bands, num_classes=num_classes).to(device)

        # Print model summary and parameter count
        print(model)
        param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"Trainable parameters: {param_count:,}")

        # Loss & Optimizer with safer defaults
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)  # SGD instead of Adam for stability
        # The scheduler's `verbose` argument is deprecated in recent PyTorch
        # releases, so learning-rate changes are reported manually below.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=3
        )

        # Training Loop with extensive error handling
        num_epochs = 30  # Reduced epochs for initial testing
        train_losses = []
        val_losses = []
        best_val_loss = float('inf')
        patience_counter = 0
        patience_limit = 5  # Early stopping patience
        epochs_run = 0
        all_preds = []
        all_labels = []

        print("Starting training...")
        for epoch in range(num_epochs):
            epochs_run = epoch + 1
            start_time = time.time()

            # Training phase
            train_loss = _train_one_epoch(
                model, train_loader, criterion, optimizer, num_classes,
                desc=f"Epoch {epoch+1}/{num_epochs}",
            )
            train_losses.append(train_loss)

            # Validation phase; predictions of the latest epoch are kept for
            # the final report
            epoch_val_loss, all_preds, all_labels = _validate(
                model, val_loader, criterion, num_classes
            )
            val_losses.append(epoch_val_loss)

            # Calculate epoch time
            epoch_time = time.time() - start_time

            # Print epoch summary
            print(f"Epoch [{epoch+1}/{num_epochs}], "
                  f"Train Loss: {train_loss:.4f}, "
                  f"Val Loss: {epoch_val_loss:.4f}, "
                  f"Time: {epoch_time:.2f}s")

            # Update learning rate and report reductions
            previous_lr = optimizer.param_groups[0]['lr']
            scheduler.step(epoch_val_loss)
            current_lr = optimizer.param_groups[0]['lr']
            if current_lr < previous_lr:
                print(f"Reducing learning rate from {previous_lr:.6f} to {current_lr:.6f}")

            # Save the best model
            if epoch_val_loss < best_val_loss:
                best_val_loss = epoch_val_loss
                print(f"Saving best model with validation loss: {best_val_loss:.4f}")
                torch.save(model.state_dict(), "Export_Model/CNN_model_best.pth")
                patience_counter = 0
            else:
                patience_counter += 1

            # Early stopping
            if patience_counter >= patience_limit:
                print(f"Early stopping triggered after {epoch+1} epochs")
                break

            # Free memory
            if device.type == 'cuda':
                torch.cuda.empty_cache()

        print("Training completed!")

        # Plotting Training and Validation Losses
        plt.figure(figsize=(10, 5))
        plt.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss')
        plt.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()
        plt.title('Training and Validation Loss')
        plt.show()

        # Final evaluation (if we have predictions); these come from the
        # last epoch that ran, not from the best checkpoint
        if all_preds and all_labels:
            try:
                all_preds = np.concatenate(all_preds)
                all_labels = np.concatenate(all_labels)

                # Classification Report
                print("Classification Report:")
                print(classification_report(all_labels, all_preds))

                # Confusion Matrix
                cm = confusion_matrix(all_labels, all_preds)
                plt.figure(figsize=(10, 8))
                sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
                plt.xlabel('Predicted Labels')
                plt.ylabel('True Labels')
                plt.title('Confusion Matrix')
                plt.show()
            except Exception as e:
                print(f"Error generating final metrics: {e}")

        # Save the final model
        print("Saving final model...")
        torch.save({
            'epoch': epochs_run,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_losses,
            'val_loss': val_losses
        }, "Export_Model/CNN_model_final.pth")

    except Exception as e:
        print(f"Critical error: {e}")
        
# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()