# In [None]:
import os
import csv
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from diffusers import StableDiffusionXLImg2ImgPipeline
import torch.optim as optim

# In [None]:

# Usage: locations of the GreenEarthNet metadata table and the NIR frames.
# Both paths share one data root; raw strings keep the backslashes literal.
_GREENEARTHNET_ROOT = r"C:\TjallingData\greenearthnet"
csv_path = _GREENEARTHNET_ROOT + r"\metadata_total_imputed_with_seasons_usable_5050.csv"
image_dir = _GREENEARTHNET_ROOT + r"\train\NIR_combined"
max_minicubes = None  # None (or any falsy value) loads every minicube



class SatelliteImageDataset(Dataset):
    """Pairs of consecutive NIR frames per minicube, with a climate-change prompt.

    For every minicube, the usable frames are ordered by their ``number``
    column and each frame is paired with the next one: the earlier frame is
    the input and the later frame is the target. The prompt describes how the
    climate metadata changes between the two frames.

    Args:
        csv_path: CSV with the columns ``minicube``, ``usable``, ``frame_date``,
            ``number``, ``region``, ``season`` and the climate columns listed
            in ``_CHANGE_LABELS``.
        image_dir: Directory holding the ``{minicube}_NIR_{frame_date}.png`` files.
        max_minicubes: Maximum number of distinct minicubes to load. ``None``
            (or any falsy value) loads all of them.
    """

    # Climate columns compared between input and target frames, mapped to the
    # (increase, decrease) wording used in the prompt. Order fixes prompt order.
    _CHANGE_LABELS = {
        'Temperature Avg': ('warmer', 'colder'),
        'Sea-level Pressure': ('higher pressure', 'lower pressure'),
        'Shortwave Downwelling Radiation': ('more solar radiation', 'less solar radiation'),
        'Relative Humidity': ('more humid', 'drier'),
    }
    # Only relative changes strictly larger than this many percent are mentioned.
    _CHANGE_THRESHOLD_PCT = 5

    def __init__(self, csv_path, image_dir, max_minicubes=None):
        self.image_dir = image_dir

        # Resize to 1024x1024 and map the single grey channel from [0, 1] to
        # [-1, 1] ((x - 0.5) / 0.5).
        self.transform = transforms.Compose([
            transforms.Resize((1024, 1024)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5]),
        ])

        self.metadata = self.load_metadata(csv_path)
        minicube_data = self._group_usable_rows(csv_path, max_minicubes)
        self.data = self._build_pairs(minicube_data)

        print(f"Total usable data points: {len(self.data)}")
        print(f"Number of minicubes loaded: {len(minicube_data)}")

    @staticmethod
    def _group_usable_rows(csv_path, max_minicubes=None):
        """Group usable CSV rows by minicube as ``(file_name, frame_date, number)`` tuples.

        Once ``max_minicubes`` distinct minicubes have been accepted, rows of
        any further minicube are skipped. Reading continues, so later rows of
        already-accepted minicubes are still collected even if the CSV is not
        grouped by minicube.
        """
        grouped = {}
        with open(csv_path, 'r', newline='') as csvfile:
            for row in csv.DictReader(csvfile):
                if row['usable'].lower() != 'yes':
                    continue
                minicube = row['minicube']
                if minicube not in grouped:
                    if max_minicubes and len(grouped) >= max_minicubes:
                        continue  # cap reached: ignore new minicubes
                    grouped[minicube] = []
                # Construct the filename with _NIR before the date
                file_name = f"{minicube}_NIR_{row['frame_date']}.png"
                grouped[minicube].append((file_name, row['frame_date'], int(row['number'])))
        return grouped

    @staticmethod
    def _build_pairs(minicube_data):
        """Turn grouped frames into consecutive (input, target) training tuples.

        Each tuple is ``(minicube, input_file, input_date, input_number,
        target_date, target_number)``. Frames are ordered by image number.
        """
        pairs = []
        for minicube, images in minicube_data.items():
            ordered = sorted(images, key=lambda item: item[2])
            for (in_file, in_date, in_num), (_, tgt_date, tgt_num) in zip(ordered, ordered[1:]):
                pairs.append((minicube, in_file, in_date, in_num, tgt_date, tgt_num))
        return pairs

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(input_image, target_image, input_number, target_number, prompt)``.

        Both images are 3-channel tensors (the grey channel repeated) produced
        by ``self.transform``. The numbers are the ints from the ``number`` column.
        """
        minicube, input_file, _input_date, input_number, target_date, target_number = self.data[idx]
        target_file = f"{minicube}_NIR_{target_date}.png"

        input_image = self.load_image(input_file)
        target_image = self.load_image(target_file)

        # Convert single-channel images to 3-channel
        input_image = input_image.repeat(3, 1, 1)
        target_image = target_image.repeat(3, 1, 1)
        # Deliberately no requires_grad_() here: data tensors need no gradient,
        # and a grad-requiring tensor makes default_collate fail inside
        # DataLoader worker processes (it stacks into an ``out=`` buffer).

        prompt, _, _ = self.generate_prompt(minicube, str(input_number), str(target_number))

        input_climate = self.get_metadata(minicube, str(input_number))
        target_climate = self.get_metadata(minicube, str(target_number))

        print(f"Input image: {input_file}, Target image: {target_file}")
        print(f"Input image number: {input_number}, Target image number: {target_number}")
        print(f"Input climate data: {input_climate}")
        print(f"Target climate data: {target_climate}")
        print(f"Full prompt: {prompt}")

        return input_image, target_image, input_number, target_number, prompt

    def load_image(self, filename):
        """Load ``filename`` from ``image_dir`` as greyscale and apply ``self.transform``.

        Raises:
            FileNotFoundError: if the file does not exist.
        """
        path = os.path.join(self.image_dir, filename)
        if not os.path.exists(path):
            raise FileNotFoundError(f"Image not found: {path}")
        # convert() loads the pixels into a new image, so the file handle can be
        # closed immediately instead of waiting for garbage collection.
        with Image.open(path) as img:
            image = img.convert("L")
        return self.transform(image)

    def load_metadata(self, metadata_csv_path):
        """Index every CSV row as ``{minicube: {number: record}}``.

        ``number`` stays a string. ``record`` holds the four climate columns as
        floats plus ``region``, ``season`` and ``number`` as strings. A later
        row with the same (minicube, number) overwrites an earlier one.
        """
        metadata = {}
        with open(metadata_csv_path, 'r', newline='') as csvfile:
            for row in csv.DictReader(csvfile):
                record = {key: float(row[key]) for key in self._CHANGE_LABELS}
                record['region'] = row['region']
                record['season'] = row['season']
                record['number'] = row['number']
                metadata.setdefault(row['minicube'], {})[row['number']] = record
        return metadata

    def get_metadata(self, minicube, number):
        """Return the metadata record for (minicube, number), or None after printing a warning."""
        record = self.metadata.get(minicube, {}).get(number)
        if record is None:
            print(f"Warning: Metadata not found for minicube {minicube}, number {number}")
        return record

    def generate_prompt(self, minicube, input_number, target_number):
        """Describe the climate change between two frames as a text prompt.

        Returns:
            ``(prompt, input_number, target_number)``, where the numbers are the
            string values stored in the metadata. If either frame has no
            metadata, returns a generic prompt with ``None`` for both numbers.
        """
        input_data = self.get_metadata(minicube, input_number)
        target_data = self.get_metadata(minicube, target_number)

        if input_data is None or target_data is None:
            return "Orthophoto taken by a satellite of a nature region", None, None

        differences = []
        for key, (up_label, down_label) in self._CHANGE_LABELS.items():
            if key not in input_data or key not in target_data:
                continue
            baseline = input_data[key]
            if baseline == 0:
                # A relative change from zero is undefined; skip instead of crashing.
                continue
            # Divide by |baseline| so the sign always follows target - input,
            # even when the baseline is negative (e.g. sub-zero temperatures).
            diff = (target_data[key] - baseline) / abs(baseline) * 100
            if abs(diff) > self._CHANGE_THRESHOLD_PCT:
                differences.append(f"{up_label if diff > 0 else down_label} by {abs(diff):.2f}%")

        region = input_data.get('region', 'unknown region')
        season = input_data.get('season', 'unknown season')
        prompt = f"Orthophoto taken by a satellite of {region} in {season}"
        if differences:
            # Only add the clause when there is something to describe, so the
            # prompt never ends in a dangling "that is ".
            prompt = f"{prompt} that is " + ", ".join(differences)
        return prompt, input_data['number'], target_data['number']

# Create the dataset (max_minicubes comes from the configuration above;
# None loads every minicube) and a sequential, one-pair-per-batch loader.
dataset = SatelliteImageDataset(csv_path, image_dir, max_minicubes=max_minicubes)
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

# Load the pretrained SDXL img2img pipeline in half precision on the GPU.
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
pipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipeline = pipeline.to("cuda")

# Training loop
num_epochs = 1
# NOTE(review): the UNet weights were loaded as float16; AdamW updates on
# fp16 weights are prone to underflow/NaN. Confirm before relying on training.
optimizer = torch.optim.AdamW(pipeline.unet.parameters(), lr=1e-5)

# Img2img sampling hyperparameters
strength = 0.3
guidance_scale = 3.0
num_inference_steps = 50

# Checkpointing parameters
checkpoint_interval = 10  # Save checkpoint every 10 batches
# Must differ from model_id: from_pretrained() loads an existing local
# directory with that name instead of the Hub repo, so saving checkpoints to
# "stabilityai/stable-diffusion-xl-base-1.0" would shadow the base model.
checkpoint_path = "checkpoints_satellite_NIR"

for epoch in range(num_epochs):
    for batch_idx, (input_image, target_image, input_number, target_number, prompt) in enumerate(dataloader):
        input_image = input_image.to("cuda")
        target_image = target_image.to("cuda")

        # Generate image; prompt is a length-1 list because batch_size=1.
        generated_images = pipeline(
            prompt=prompt[0],
            image=input_image,
            strength=strength,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
        ).images

        # ToTensor yields (3, H, W) in [0, 1]. The target was normalized to
        # [-1, 1] (mean=std=0.5) and batched to (1, 3, H, W), so rescale the
        # output and add the batch dimension before comparing.
        output_tensor = transforms.ToTensor()(generated_images[0]).to("cuda")
        output_tensor = (output_tensor * 2 - 1).unsqueeze(0)
        # NOTE(review): the pipeline's default output is PIL images, so this
        # tensor has no autograd connection to the UNet. requires_grad_ only
        # lets backward() run; the UNet parameters receive no gradients and
        # optimizer.step() does not change them. A real fine-tuning step needs
        # a differentiable UNet forward pass. TODO.
        output_tensor.requires_grad_(True)

        # Calculate loss
        loss = torch.nn.functional.mse_loss(output_tensor, target_image)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}, Input: {input_number.item()}, Target: {target_number.item()}, Loss: {loss.item()}")

        # Checkpointing
        if (batch_idx + 1) % checkpoint_interval == 0:
            pipeline.save_pretrained(checkpoint_path)
            torch.save(optimizer.state_dict(), os.path.join(checkpoint_path, "optimizer_state.pth"))
            print(f"Checkpoint saved at {checkpoint_path}")

# Save the final trained model and optimizer state
pipeline.save_pretrained("trained_satellite_model_with_prompt_normalized_5050_NIR")
torch.save(optimizer.state_dict(), "optimizer_state_NIR.pth")