In [None]:
class MovingMNISTDataset(Dataset):
    """Moving MNIST Dataset.

    Only the published ``mnist_test_seq.npy`` array is used, so the train and
    test views are both carved out of it along axis 1 (the sequence axis):
    the first ``train_size`` sequences form the training split and the rest
    form the test split. Each item is a pair of float32 tensors holding the
    first 10 frames (input) and the remaining frames (target), each shaped
    (frames, 1, H, W) with pixel values scaled to [0, 1].

    Args:
        root: Directory holding (or receiving) ``mnist_test_seq.npy``.
        train: Return the training split if True, otherwise the test split.
        download: Fetch the file first when it is not already present.
        train_size: Number of sequences in the training split. The default
            of 8000 reproduces the original hard-coded split.
    """

    FILENAME = 'mnist_test_seq.npy'
    URL = 'http://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy'

    def __init__(self, root='./data', train=True, download=True, train_size=8000):
        self.root = root
        self.train = train

        if download:
            self.download()

        # Load the same file for both train and test, but split differently.
        data = np.load(os.path.join(root, self.FILENAME))

        # Axis 1 indexes sequences; axis 0 indexes frames within a sequence.
        if train:
            self.data = data[:, :train_size, :, :]
        else:
            self.data = data[:, train_size:, :, :]

    def download(self):
        """Download the dataset file into ``root`` unless it already exists.

        The transfer writes to a temporary ``.part`` file that is renamed only
        after it completes. An interrupted download therefore never leaves a
        truncated file that a later run would mistake for a finished one.
        """
        os.makedirs(self.root, exist_ok=True)
        filepath = os.path.join(self.root, self.FILENAME)

        if os.path.exists(filepath):
            return

        print('Downloading Moving MNIST dataset...')
        partial_path = filepath + '.part'
        try:
            urllib.request.urlretrieve(self.URL, partial_path)
            os.replace(partial_path, filepath)
        finally:
            # Remove any leftover partial file if the transfer or rename failed.
            if os.path.exists(partial_path):
                os.remove(partial_path)
        print('Download completed!')

    def __len__(self):
        return self.data.shape[1]

    def __getitem__(self, idx):
        """Return ``(input_frames, target_frames)`` for sequence ``idx``."""
        # uint8 pixels -> float32 in [0, 1]; this is a fresh array, not a view of self.data.
        sequence = self.data[:, idx, :, :].astype(np.float32) / 255.0
        sequence = np.expand_dims(sequence, axis=1)  # Add channel dim

        input_seq = sequence[:10]   # First 10 frames
        target_seq = sequence[10:]  # Remaining frames

        # from_numpy avoids a second copy; `sequence` is owned by this call.
        return torch.from_numpy(input_seq), torch.from_numpy(target_seq)


class ConvLSTMPredictor(pl.LightningModule):
    """Encoder-decoder ConvLSTM whose frames are conditioned on calendar dates.

    Every frame is concatenated channel-wise with a spatially broadcast date
    encoding produced by ``TemporalEncoder`` before it enters the ConvLSTMs.

    Args:
        input_dim: Image channels per frame.
        temporal_dim: ``embed_dim`` forwarded to ``TemporalEncoder``.
        hidden_dims: Hidden channels of each ConvLSTM layer.
        kernel_size: ConvLSTM kernel size.
        num_layers: Number of stacked ConvLSTM layers.
        learning_rate: Stored in ``hparams``.
        batch_size: Stored in ``hparams``.
        temporal_encoding: ``encoding_type`` forwarded to ``TemporalEncoder``.
        log_images: Stored on the instance. Accepted so that callers passing
            it (e.g. ``train_model``) do not fail with a TypeError.
        log_frequency: Stored on the instance, like ``log_images``.
    """

    def __init__(self,
                 input_dim=3,  # RGB channels
                 temporal_dim=5,  # Temporal feature dimension
                 hidden_dims=[64, 64, 64],
                 kernel_size=(3, 3),
                 num_layers=3,
                 learning_rate=1e-3,
                 batch_size=32,
                 temporal_encoding='sinusoidal',
                 log_images=True,
                 log_frequency=100):
        super().__init__()
        self.save_hyperparameters()
        self.log_images = log_images
        self.log_frequency = log_frequency

        # Temporal encoder
        self.temporal_encoder = TemporalEncoder(
            encoding_type=temporal_encoding,
            embed_dim=temporal_dim
        )

        # NOTE: not used by forward(); kept so existing checkpoints still load.
        self.temporal_projection = nn.Linear(temporal_dim, input_dim)

        # Some encodings do not emit exactly `temporal_dim` features (e.g. a
        # fixed-width sinusoidal encoding), so size the ConvLSTM input from
        # what the encoder actually produces.
        total_input_dim = input_dim + self._temporal_feature_width()

        # Encoder ConvLSTM
        self.encoder = ConvLSTM(
            input_dim=total_input_dim,
            hidden_dim=hidden_dims,
            kernel_size=kernel_size,
            num_layers=num_layers,
            batch_first=True,
            bias=True,
            return_all_layers=True
        )

        # Decoder ConvLSTM
        self.decoder = ConvLSTM(
            input_dim=total_input_dim,
            hidden_dim=hidden_dims,
            kernel_size=kernel_size,
            num_layers=num_layers,
            batch_first=True,
            bias=True,
            return_all_layers=True
        )

        # Output layer (back to image channels)
        self.output_conv = nn.Conv2d(
            in_channels=hidden_dims[-1],
            out_channels=input_dim,
            kernel_size=1
        )

        self.criterion = nn.MSELoss()

    def _temporal_feature_width(self):
        """Return the size of the last dimension the temporal encoder emits.

        Probes the encoder once with a single in-range date (2020-01-01).
        """
        with torch.no_grad():
            probe = self.temporal_encoder(
                torch.tensor([2020.0]), torch.tensor([1.0]), torch.tensor([1.0])
            )
        return probe.shape[-1]

    def _add_temporal_features(self, rgb_data, years, months, days):
        """Concatenate broadcast date features onto each frame.

        Args:
            rgb_data: Frames shaped (batch, seq_len, channels, height, width).
            years, months, days: Date tensors; assumed (batch, seq_len) so the
                encoder output is (batch, seq_len, features). TODO confirm.

        Returns:
            Tensor shaped (batch, seq_len, channels + features, height, width).
        """
        batch_size, seq_len, channels, height, width = rgb_data.shape

        # Get temporal encodings: [batch, seq_len, temporal_features]
        temporal_features = self.temporal_encoder(years, months, days)

        # Broadcast each date vector over every pixel of its frame.
        temporal_spatial = temporal_features.unsqueeze(-1).unsqueeze(-1)
        temporal_spatial = temporal_spatial.expand(-1, -1, -1, height, width)

        return torch.cat([rgb_data, temporal_spatial], dim=2)

    def forward(self, rgb_data, years, months, days, future_steps=10):
        """Predict ``future_steps`` frames after ``rgb_data``.

        Returns:
            Sigmoid-activated frames shaped (batch, future_steps, input_dim, H, W).
        """
        # Add temporal features to input
        input_with_temporal = self._add_temporal_features(rgb_data, years, months, days)

        # Encode input sequence; its final states seed the decoder.
        _, encoder_states = self.encoder(input_with_temporal)
        decoder_hidden = encoder_states
        predictions = []

        # The last input frame (with its date) is the decoder input at every step.
        # Future dates are not available here, so predictions are not fed back.
        last_temporal = self._add_temporal_features(
            rgb_data[:, -1:], years[:, -1:], months[:, -1:], days[:, -1:]
        )
        decoder_input = last_temporal

        for _ in range(future_steps):
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)

            # Project the last layer's hidden state back to image channels.
            pred_frame = torch.sigmoid(self.output_conv(decoder_output[-1][:, -1]))
            predictions.append(pred_frame.unsqueeze(1))

            decoder_input = last_temporal  # Simplified: no future date information

        return torch.cat(predictions, dim=1)


def train_model(config=None):
    """Train the ConvLSTM model using PyTorch Lightning with wandb logging.

    Args:
        config: Optional wandb config dict. When omitted, a default Moving
            MNIST configuration is used.

    Returns:
        ``(model, trainer)`` after fitting and testing.

    The wandb run is finished even if training or testing raises.
    """
    wandb.init(
        project="convlstm-moving-mnist",
        config=config or {
            "input_dim": 1,
            "hidden_dims": [64, 64, 64],
            "kernel_size": [3, 3],  # Changed to list format for wandb compatibility
            "num_layers": 3,
            "learning_rate": 1e-3,
            "batch_size": 64,
            "max_epochs": 50,
            "architecture": "ConvLSTM",
            "dataset": "Moving MNIST",
            "optimizer": "Adam",
            "scheduler": "StepLR"
        },
        tags=["convlstm", "video-prediction", "pytorch-lightning"]
    )

    try:
        model = ConvLSTMPredictor(
            input_dim=wandb.config.input_dim,
            hidden_dims=wandb.config.hidden_dims,
            # wandb stores the kernel as a list; the model expects an (h, w) tuple
            # (a list would be read as one kernel per layer).
            kernel_size=tuple(wandb.config.kernel_size),
            num_layers=wandb.config.num_layers,
            learning_rate=wandb.config.learning_rate,
            batch_size=wandb.config.batch_size,
            log_images=True,
            log_frequency=100
        )

        # Log model architecture
        wandb.watch(model, log_freq=100, log_graph=True)

        # The monitored metric is named 'val/loss'. The filename must reference
        # that exact key, and auto-insertion is disabled so the slash in the
        # name does not become a directory separator.
        checkpoint_callback = ModelCheckpoint(
            monitor='val/loss',
            dirpath='checkpoints/',
            filename='convlstm-epoch{epoch:02d}-val_loss{val/loss:.4f}',
            auto_insert_metric_name=False,
            save_top_k=3,
            mode='min',
            save_last=True
        )

        lr_monitor = LearningRateMonitor(logging_interval='epoch')

        wandb_logger = WandbLogger(
            project="convlstm-moving-mnist",
            log_model="all",  # Log model checkpoints
            save_dir="./wandb_logs"
        )

        trainer = pl.Trainer(
            max_epochs=wandb.config.max_epochs,
            accelerator='auto',
            devices='auto',
            callbacks=[checkpoint_callback, lr_monitor],
            logger=wandb_logger,
            log_every_n_steps=50,
            val_check_interval=1.0,
            enable_progress_bar=True,
            enable_model_summary=True
        )

        trainer.fit(model)
        trainer.test(model)

        # callback_metrics values are tensors; log plain floats.
        best_score = checkpoint_callback.best_model_score
        wandb.log({
            "final_train_loss": float(trainer.callback_metrics.get("train/loss_epoch", 0)),
            "final_val_loss": float(trainer.callback_metrics.get("val/loss", 0)),
            "best_val_loss": best_score.item() if best_score is not None else 0
        })
    finally:
        # Always close the run so a failed fit does not leave it dangling.
        wandb.finish()

    return model, trainer

# Example usage
class TemporalEncoder(nn.Module):
    """Encode (year, month, day) tensors into a feature vector per entry.

    Args:
        encoding_type: ``'sinusoidal'`` gives 5 features: (year - 2020) / 4,
            then sin/cos of the month angle and sin/cos of the day angle.
            ``'learned'`` gives concatenated embedding lookups totalling
            ``embed_dim`` features. Any other value gives 3 linearly scaled
            features.
        embed_dim: Total output width of the ``'learned'`` encoding. It is
            ignored by the other encodings.

    Attributes:
        output_dim: Size of the last dimension returned by ``forward``.
    """

    def __init__(self, encoding_type='sinusoidal', embed_dim=32):
        super().__init__()
        self.encoding_type = encoding_type
        if encoding_type == 'learned':
            # Split embed_dim so the three widths sum to exactly embed_dim.
            # Any remainder goes to the day embedding.
            base = embed_dim // 3
            self.embeddings = nn.ModuleList([
                nn.Embedding(10, base),                   # years 2020-2029
                nn.Embedding(12, base),                   # months 1-12
                nn.Embedding(32, embed_dim - 2 * base)    # days 0-31
            ])
            self.output_dim = embed_dim
        elif encoding_type == 'sinusoidal':
            self.output_dim = 5
        else:
            self.output_dim = 3

    def forward(self, years, months, days):
        """Encode same-shaped date tensors; the result gains a trailing feature dim."""
        if self.encoding_type == 'sinusoidal':
            year_norm = (years - 2020) / 4.0
            month_rad = 2 * math.pi * (months - 1) / 12.0
            day_rad = 2 * math.pi * days / 31.0
            return torch.stack([year_norm, torch.sin(month_rad), torch.cos(month_rad),
                                torch.sin(day_rad), torch.cos(day_rad)], dim=-1)
        if self.encoding_type == 'learned':
            # Out-of-range values (e.g. a year outside 2020-2029) raise an IndexError.
            embs = [emb((vals - offset).long()) for emb, vals, offset in
                    zip(self.embeddings, [years, months, days], [2020, 1, 0])]
            return torch.cat(embs, dim=-1)
        return torch.stack([(years - 2020) / 4.0, (months - 1) / 11.0, days / 31.0], dim=-1)


# Example usage. This is placed after TemporalEncoder so that every class
# train_model() needs is defined before the run starts (the cell executes top to bottom).
if __name__ == "__main__":
    print("Starting ConvLSTM training on Moving MNIST...")
    model, trainer = train_model()
    print("Training completed!")
    print("View logs at: https://wandb.ai/")
    print("Best model saved in: checkpoints/")

In [None]:
import os
import glob
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import rioxarray as rxr
import xarray as xr
import torchvision.transforms as transforms
from datetime import datetime
import re
from typing import List, Tuple, Optional
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger
import matplotlib.pyplot as plt
import urllib.request
import wandb
import math

class TemporalEncoder(nn.Module):
    """Sinusoidal calendar encoder with five output features.

    The last output dimension holds, in order: the year scaled as
    (year - 1985) / 55, sin and cos of the month angle, and sin and cos of the
    day-of-month angle.

    Args:
        embed_dim: Stored on the instance. The output width is always 5
            regardless of this value.
    """

    def __init__(self, embed_dim=5):
        super().__init__()
        self.embed_dim = embed_dim

    def forward(self, years, months, days):
        """Encode same-shaped year/month/day tensors into shape ``(*, 5)``."""
        return self._sinusoidal_encoding(years, months, days)

    def _sinusoidal_encoding(self, years, months, days):
        two_pi = 2 * math.pi
        # Months and days are 1-based, so shift to 0 before mapping onto a circle.
        month_angle = two_pi * (months - 1) / 12.0
        day_angle = two_pi * (days - 1) / 31.0
        features = (
            (years - 1985) / 55.0,   # linear trend across years
            month_angle.sin(),
            month_angle.cos(),
            day_angle.sin(),
            day_angle.cos(),
        )
        return torch.stack(features, dim=-1)

class ConvLSTMCell(nn.Module):
    """A single convolutional LSTM cell.

    One convolution over the channel-wise concatenation of the input and the
    previous hidden state yields all four gates (input, forget, output,
    candidate), each with ``hidden_dim`` channels.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, bias=True):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        # Half-kernel padding keeps H and W unchanged for odd kernel sizes.
        self.padding = (kernel_size[0] // 2, kernel_size[1] // 2)

        self.conv = nn.Conv2d(
            in_channels=input_dim + hidden_dim,
            out_channels=4 * hidden_dim,
            kernel_size=kernel_size,
            padding=self.padding,
            bias=bias,
        )

    def forward(self, input_tensor, cur_state):
        """Advance one step; ``cur_state`` is ``(h, c)``. Returns ``(h_next, c_next)``."""
        prev_h, prev_c = cur_state
        gates = self.conv(torch.cat((input_tensor, prev_h), dim=1))
        in_gate, forget_gate, out_gate, candidate = gates.chunk(4, dim=1)

        next_c = torch.sigmoid(forget_gate) * prev_c + torch.sigmoid(in_gate) * torch.tanh(candidate)
        next_h = torch.sigmoid(out_gate) * torch.tanh(next_c)
        return next_h, next_c

    def init_hidden(self, batch_size, image_size):
        """Return zero ``(h, c)`` tensors on the same device as the cell's weights."""
        height, width = image_size
        shape = (batch_size, self.hidden_dim, height, width)
        device = self.conv.weight.device
        return torch.zeros(shape, device=device), torch.zeros(shape, device=device)

class ConvLSTM(nn.Module):
    """Multi-layer ConvLSTM run over a whole sequence.

    Args:
        input_dim: Channels of each input frame.
        hidden_dim: Hidden channels, either one int for every layer or a
            list/tuple with (at least) one entry per layer.
        kernel_size: Either a single (h, w) kernel shared by all layers (a
            tuple, or a flat list of ints such as the ``[3, 3]`` stored by
            wandb configs), or a list with one kernel per layer.
        num_layers: Number of stacked layers.
        batch_first: If True, input and outputs are (batch, time, C, H, W);
            otherwise (time, batch, C, H, W).
        bias: Whether the gate convolutions use a bias.
        return_all_layers: If True (default), return outputs and states for
            every layer. Otherwise return only the last layer's, still
            wrapped in one-element lists.

    ``forward`` returns ``(layer_output_list, last_state_list)``. Each
    ``last_state_list`` entry is a mutable ``[h, c]`` list.

    Raises:
        ValueError: If a per-layer ``hidden_dim``/``kernel_size`` list has
            fewer than ``num_layers`` entries.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, num_layers, batch_first=True, bias=True,
                 return_all_layers=True):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = self._per_layer_hidden(hidden_dim, num_layers)
        self.kernel_size = self._per_layer_kernel(kernel_size, num_layers)
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.return_all_layers = return_all_layers

        self.cell_list = nn.ModuleList([
            ConvLSTMCell(input_dim if i == 0 else self.hidden_dim[i - 1],
                         self.hidden_dim[i], self.kernel_size[i], bias)
            for i in range(num_layers)
        ])

    @staticmethod
    def _per_layer_hidden(hidden_dim, num_layers):
        """Expand ``hidden_dim`` into a list with one entry per layer."""
        if isinstance(hidden_dim, (list, tuple)):
            dims = list(hidden_dim)
        else:
            dims = [hidden_dim] * num_layers
        if len(dims) < num_layers:
            raise ValueError(f"hidden_dim has {len(dims)} entries but num_layers={num_layers}")
        return dims

    @staticmethod
    def _per_layer_kernel(kernel_size, num_layers):
        """Expand ``kernel_size`` into a list with one (h, w) kernel per layer."""
        if isinstance(kernel_size, list):
            if kernel_size and all(isinstance(k, int) for k in kernel_size):
                # A flat list of ints (e.g. [3, 3] from a wandb config) is one kernel.
                return [tuple(kernel_size)] * num_layers
            kernels = list(kernel_size)
        else:
            kernels = [kernel_size] * num_layers
        if len(kernels) < num_layers:
            raise ValueError(f"kernel_size has {len(kernels)} entries but num_layers={num_layers}")
        return kernels

    def forward(self, input_tensor, hidden_state=None):
        """Run the sequence through every layer.

        Args:
            input_tensor: 5-D frames in the layout chosen by ``batch_first``.
            hidden_state: Optional per-layer ``(h, c)`` pairs; zeros when None.
        """
        if not self.batch_first:
            input_tensor = input_tensor.permute(1, 0, 2, 3, 4)

        batch, seq_len, _, height, width = input_tensor.size()

        if hidden_state is None:
            hidden_state = [cell.init_hidden(batch, (height, width)) for cell in self.cell_list]

        layer_output_list, last_state_list = [], []
        cur_layer_input = input_tensor

        for layer_idx, cell in enumerate(self.cell_list):
            h, c = hidden_state[layer_idx]
            output_inner = []

            for t in range(seq_len):
                h, c = cell(cur_layer_input[:, t], [h, c])
                output_inner.append(h)

            layer_output = torch.stack(output_inner, dim=1)
            cur_layer_input = layer_output
            layer_output_list.append(layer_output)
            last_state_list.append([h, c])

        # Return outputs in the same layout the caller used for the input.
        if not self.batch_first:
            layer_output_list = [out.permute(1, 0, 2, 3, 4) for out in layer_output_list]

        if not self.return_all_layers:
            layer_output_list = layer_output_list[-1:]
            last_state_list = last_state_list[-1:]

        return layer_output_list, last_state_list

class TemporalFusionModule(nn.Module):
    """Inject a per-sample temporal vector into a spatial feature map.

    The temporal vector passes through a two-layer MLP (Linear, ReLU, Linear)
    to ``spatial_channels`` values. Those values are tiled over every pixel,
    concatenated with the spatial features, and mixed down to
    ``fused_channels`` by a 1x1 convolution.

    Args:
        temporal_dim: Length of the temporal feature vector.
        spatial_channels: Channels of the incoming feature map.
        fused_channels: Channels of the returned feature map.
    """

    def __init__(self, temporal_dim=5, spatial_channels=64, fused_channels=64):
        super().__init__()
        self.temporal_dim = temporal_dim
        self.spatial_channels = spatial_channels
        self.fused_channels = fused_channels

        self.temporal_proj = nn.Sequential(
            nn.Linear(temporal_dim, spatial_channels),
            nn.ReLU(),
            nn.Linear(spatial_channels, spatial_channels),
        )
        self.fusion_conv = nn.Conv2d(2 * spatial_channels, fused_channels, kernel_size=1)

    def forward(self, spatial_features, temporal_features):
        """Fuse the inputs.

        Args:
            spatial_features: (batch, spatial_channels, H, W) feature map.
            temporal_features: (batch, temporal_dim) vectors. A size-1 batch
                is broadcast to the spatial batch by ``expand``.

        Returns:
            (batch, fused_channels, H, W) tensor.
        """
        batch, _, height, width = spatial_features.shape

        projected = self.temporal_proj(temporal_features)
        tiled = projected[..., None, None].expand(batch, self.spatial_channels, height, width)

        return self.fusion_conv(torch.cat((spatial_features, tiled), dim=1))

class SanAntonioSatelliteDataset(Dataset):
    """Sliding-window dataset of dated GeoTIFFs with per-frame date features.

    ``*.tif`` files in ``data_dir`` are ordered by the YYYY-MM-DD date in their
    file name. Item ``i`` is a window of ``sequence_length`` input frames
    followed by ``target_length`` target frames, taking every
    ``temporal_stride``-th file. It is returned as
    ``(input_seq, target_seq, input_temporal, target_temporal)``:
    frames are (T, 3, image_size, image_size) float tensors, and the temporal
    tensors are (T, 3) float32 rows of (year, month, day).

    Args:
        data_dir: Directory containing the ``.tif`` files.
        sequence_length: Number of input frames per item.
        target_length: Number of target frames per item.
        image_size: Side length of the centre crop (zero-padded if the raster is smaller).
        normalize: Scale pixel values to [0, 1] (see ``_normalize_array``).
        temporal_stride: Step between consecutive frames, in files.

    Raises:
        ValueError: If a ``.tif`` file name has no valid YYYY-MM-DD date.
    """

    _DATE_PATTERN = re.compile(r'(\d{4})-(\d{2})-(\d{2})')

    def __init__(self, data_dir: str, sequence_length: int = 5, target_length: int = 3,
                 image_size: int = 512, normalize: bool = True, temporal_stride: int = 1):
        self.data_dir = data_dir
        self.sequence_length = sequence_length
        self.target_length = target_length
        self.image_size = image_size
        self.normalize = normalize
        self.temporal_stride = temporal_stride

        # glob.escape stops characters such as '[' in data_dir from acting as glob syntax.
        pattern = os.path.join(glob.escape(data_dir), "*.tif")
        self.tif_files = sorted(glob.glob(pattern), key=self._date_sort_key)

        # A window starting at file index s touches files s .. s + total_needed - 1.
        total_needed = (sequence_length + target_length - 1) * temporal_stride + 1
        self.valid_sequences = list(range(len(self.tif_files) - total_needed + 1))

        print(f"Found {len(self.tif_files)} .tif files, {len(self.valid_sequences)} sequences")

    def _date_sort_key(self, path: str) -> datetime:
        """Sort key: the calendar date in the file name (invalid dates raise ValueError)."""
        return datetime(*self._extract_date_from_filename(path))

    def _extract_date_from_filename(self, filename: str) -> Tuple[int, int, int]:
        """Return (year, month, day) from the first YYYY-MM-DD in the file's base name.

        Only the base name is searched, so a date appearing in a parent
        directory name cannot be picked up by mistake.

        Raises:
            ValueError: If the base name contains no YYYY-MM-DD pattern.
        """
        date_match = self._DATE_PATTERN.search(os.path.basename(filename))
        if date_match is None:
            raise ValueError(f"Could not extract date from filename: {filename}")
        year, month, day = map(int, date_match.groups())
        return year, month, day

    def _load_and_crop_tif(self, tif_path: str) -> torch.Tensor:
        """Load the first three bands of a .tif, centre-crop/pad and optionally normalize."""
        # The context manager closes the underlying raster once pixels are read.
        # Without it, every __getitem__ call would leak open file handles.
        with rxr.open_rasterio(tif_path, chunks={'band': 1, 'x': 512, 'y': 512}) as da:
            rgb = da.isel(band=slice(0, 3))
            height, width = rgb.sizes['y'], rgb.sizes['x']

            center_y, center_x = height // 2, width // 2
            half_size = self.image_size // 2
            cropped = rgb.isel(
                y=slice(max(0, center_y - half_size), center_y + half_size),
                x=slice(max(0, center_x - half_size), center_x + half_size),
            )
            # .values materializes the (lazy) data while the file is still open.
            data = cropped.values.astype(np.float32)

        data = self._pad_to_size(data)

        if self.normalize:
            data = self._normalize_array(data)

        return torch.from_numpy(data)

    def _pad_to_size(self, data: np.ndarray) -> np.ndarray:
        """Zero-pad a (3, h, w) array, centred, up to (3, image_size, image_size)."""
        if data.shape[1] >= self.image_size and data.shape[2] >= self.image_size:
            return data
        padded = np.zeros((3, self.image_size, self.image_size), dtype=np.float32)
        h, w = data.shape[1], data.shape[2]
        start_h, start_w = (self.image_size - h) // 2, (self.image_size - w) // 2
        padded[:, start_h:start_h + h, start_w:start_w + w] = data
        return padded

    @staticmethod
    def _normalize_array(data: np.ndarray) -> np.ndarray:
        """Scale pixel values to [0, 1] and return float32.

        If the largest finite value exceeds 1, the data is assumed to be
        integer-valued imagery and is divided by 65535 (max > 255) or 255.
        Otherwise it is stretched between its 1st and 99th percentiles and
        clipped. Non-finite (no-data) pixels are ignored when computing the
        statistics and are set to 0, so they cannot turn the loss into NaN.
        """
        finite = np.isfinite(data)
        if not finite.any():
            return np.zeros_like(data, dtype=np.float32)

        peak = data[finite].max()
        if peak > 1:  # Assume uint8/uint16 if values > 1
            data = data / (65535.0 if peak > 255 else 255.0)
        else:
            # Percentile normalization for float data
            p1, p99 = np.percentile(data[finite], [1, 99])
            if p99 > p1:
                data = np.clip((data - p1) / (p99 - p1), 0, 1)

        data = np.nan_to_num(data, nan=0.0, posinf=0.0, neginf=0.0)
        # Percentile arithmetic can promote to float64 (NumPy 2); keep float32.
        return data.astype(np.float32, copy=False)

    def __len__(self) -> int:
        return len(self.valid_sequences)

    def _dates_tensor(self, indices) -> torch.Tensor:
        """Return a (len(indices), 3) float32 tensor of (year, month, day) rows."""
        dates = [list(self._extract_date_from_filename(self.tif_files[i])) for i in indices]
        return torch.tensor(dates, dtype=torch.float32)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        start_idx = self.valid_sequences[idx]
        stride = self.temporal_stride

        # File indices for the input window and the target window that follows it.
        input_indices = [start_idx + i * stride for i in range(self.sequence_length)]
        target_indices = [start_idx + (self.sequence_length + i) * stride
                          for i in range(self.target_length)]

        input_seq = torch.stack([self._load_and_crop_tif(self.tif_files[i]) for i in input_indices])
        target_seq = torch.stack([self._load_and_crop_tif(self.tif_files[i]) for i in target_indices])

        input_temporal = self._dates_tensor(input_indices)    # [seq_len, 3]
        target_temporal = self._dates_tensor(target_indices)  # [target_len, 3]

        return input_seq, target_seq, input_temporal, target_temporal

class SanAntonioDataModule(pl.LightningDataModule):
    """Lightning DataModule that splits SanAntonioSatelliteDataset chronologically.

    Windows are assigned to train / val / test in index order (no shuffling
    before the split): the first ``train_split`` fraction goes to training,
    the next ``val_split`` fraction to validation, and the remainder to test.
    Any extra keyword arguments are forwarded to the dataset.

    NOTE(review): adjacent windows share frames. The last training windows
    therefore overlap the first validation windows (and likewise val/test),
    so the split boundaries leak a few frames. Consider leaving a gap.
    """

    def __init__(self, data_dir: str, sequence_length: int = 10, target_length: int = 5,
                 image_size: int = 512, batch_size: int = 4, num_workers: int = 4,
                 train_split: float = 0.8, val_split: float = 0.1, **kwargs):
        super().__init__()
        self.save_hyperparameters()

    def setup(self, stage: Optional[str] = None):
        hp = self.hparams
        # Hyperparameters consumed here or passed positionally; everything
        # else is treated as an extra dataset keyword argument.
        consumed = ('data_dir', 'batch_size', 'num_workers', 'train_split', 'val_split',
                    'sequence_length', 'target_length', 'image_size')
        dataset_kwargs = {key: value for key, value in hp.items() if key not in consumed}

        dataset = SanAntonioSatelliteDataset(
            hp.data_dir,
            hp.sequence_length,
            hp.target_length,
            hp.image_size,
            **dataset_kwargs
        )

        n_total = len(dataset)
        train_end = int(hp.train_split * n_total)
        val_end = train_end + int(hp.val_split * n_total)

        Subset = torch.utils.data.Subset
        self.train_dataset = Subset(dataset, range(0, train_end))
        self.val_dataset = Subset(dataset, range(train_end, val_end))
        self.test_dataset = Subset(dataset, range(val_end, n_total))

        print(f"Splits - Train: {len(self.train_dataset)}, Val: {len(self.val_dataset)}, Test: {len(self.test_dataset)}")

    def _dataloader(self, dataset, shuffle=False):
        """Build a DataLoader; workers persist between epochs whenever any are used."""
        workers = self.hparams.num_workers
        return DataLoader(
            dataset,
            batch_size=self.hparams.batch_size,
            shuffle=shuffle,
            num_workers=workers,
            pin_memory=True,
            persistent_workers=workers > 0,
        )

    def train_dataloader(self):
        return self._dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._dataloader(self.val_dataset)

    def test_dataloader(self):
        return self._dataloader(self.test_dataset)

class SatelliteConvLSTMPredictor(pl.LightningModule):
    """ConvLSTM encoder-decoder for satellite frames, optionally date-conditioned.

    The encoder consumes the input frames. The decoder then rolls out one
    frame per row of ``target_temporal``: it starts from the last input frame
    and feeds each prediction back in. When ``use_temporal_fusion`` is set,
    date features from ``TemporalEncoder`` are fused into the encoder's
    last-layer hidden state and into every decoder output before projection.

    Args:
        input_dim: Channels per frame.
        hidden_dims: Hidden channels per ConvLSTM layer.
        kernel_size: ConvLSTM kernel size.
        num_layers: Number of ConvLSTM layers.
        learning_rate: Adam learning rate.
        target_length: Bounds how many frames are drawn in logged figures.
        batch_size: Stored in ``hparams``.
        temporal_dim: Width given to ``TemporalEncoder`` and the fusion modules.
            It presumably must equal the encoder's output width (5). TODO confirm.
        use_temporal_fusion: Enable the date-fusion modules.
        log_images: Log prediction figures during validation.
        log_frequency: Log a figure every ``log_frequency`` validation batches
            (values <= 0 disable figure logging).
    """

    def __init__(self, input_dim=3, hidden_dims=[64, 64, 64], kernel_size=(3, 3),
                 num_layers=3, learning_rate=1e-3, target_length=5, batch_size=4,
                 temporal_dim=5, use_temporal_fusion=True,
                 log_images=True, log_frequency=100):
        super().__init__()
        self.save_hyperparameters()

        # Temporal encoder
        self.temporal_encoder = TemporalEncoder(embed_dim=temporal_dim)

        # Both take raw frames: the decoder is fed its own predictions.
        self.encoder = ConvLSTM(input_dim, hidden_dims, kernel_size, num_layers, True, True)
        self.decoder = ConvLSTM(input_dim, hidden_dims, kernel_size, num_layers, True, True)

        # Temporal fusion modules
        if use_temporal_fusion:
            self.encoder_temporal_fusion = TemporalFusionModule(
                temporal_dim, hidden_dims[-1], hidden_dims[-1]
            )
            self.decoder_temporal_fusion = TemporalFusionModule(
                temporal_dim, hidden_dims[-1], hidden_dims[-1]
            )

        # Output projection
        self.output_conv = nn.Conv2d(hidden_dims[-1], input_dim, 1)
        self.criterion = nn.MSELoss()

        # For logging
        self.log_images = log_images
        self.log_frequency = log_frequency
        self.step_count = 0          # steps seen across train/val/test
        self._val_batches_seen = 0   # drives the image-logging cadence

    def _encode_dates(self, dates):
        """Encode a (batch, 3) tensor of (year, month, day) rows."""
        return self.temporal_encoder(dates[:, 0], dates[:, 1], dates[:, 2])

    def forward(self, x, input_temporal, target_temporal):
        """Predict one frame per row of ``target_temporal``.

        Args:
            x: Input frames, [batch, seq_len, channels, height, width].
            input_temporal: [batch, seq_len, 3] (year, month, day).
            target_temporal: [batch, target_len, 3].

        Returns:
            Sigmoid-activated predictions, [batch, target_len, input_dim, H, W].
        """
        fuse = self.hparams.use_temporal_fusion

        encoder_outputs, encoder_states = self.encoder(x)

        if fuse:
            # Replace the last layer's hidden state with its fusion with the
            # date of the last input frame.
            encoder_features = encoder_outputs[-1][:, -1]  # [batch, hidden_dim, H, W]
            last_encoding = self._encode_dates(input_temporal[:, -1])
            encoder_states[-1][0] = self.encoder_temporal_fusion(encoder_features, last_encoding)

        predictions = []
        decoder_input = x[:, -1:]  # Start with last input frame
        decoder_hidden = encoder_states

        for t in range(target_temporal.shape[1]):
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
            features = decoder_output[-1][:, -1]  # [batch, hidden_dim, H, W]

            if fuse:
                # Date encoding is only needed (and only computed) when fusing.
                features = self.decoder_temporal_fusion(
                    features, self._encode_dates(target_temporal[:, t])
                )

            pred_frame = torch.sigmoid(self.output_conv(features))
            predictions.append(pred_frame.unsqueeze(1))

            # Autoregressive rollout: the prediction is the next input.
            decoder_input = pred_frame.unsqueeze(1)

        return torch.cat(predictions, dim=1)

    def _step(self, batch, stage):
        input_seq, target_seq, input_temporal, target_temporal = batch
        target = target_seq.float()
        predictions = self(input_seq.float(), input_temporal, target_temporal)
        loss = self.criterion(predictions, target)

        self.log(f'{stage}/loss', loss, prog_bar=True, sync_dist=True, on_epoch=True, on_step=True)

        if stage in ['val', 'test']:
            with torch.no_grad():
                error = predictions - target
                mae = error.abs().mean()
                mse = error.pow(2).mean()
                psnr = 20 * torch.log10(1.0 / torch.sqrt(mse + 1e-8))

            self.log(f'{stage}/mae', mae, sync_dist=True)
            self.log(f'{stage}/mse', mse, sync_dist=True)
            self.log(f'{stage}/psnr', psnr, sync_dist=True)

        if stage == 'val':
            # Count validation batches separately. Gating on the shared
            # step_count (train + val + test) made the cadence depend on how
            # many training steps happened to precede each validation run.
            if (self.log_images and self.log_frequency > 0
                    and self._val_batches_seen % self.log_frequency == 0):
                self._log_prediction_images(input_seq, target_seq, predictions,
                                            input_temporal, target_temporal, stage)
            self._val_batches_seen += 1

        self.step_count += 1
        return loss

    @staticmethod
    def _format_date(date):
        """Format a (year, month, day) row as YYYY-MM-DD."""
        return f'{int(date[0])}-{int(date[1]):02d}-{int(date[2]):02d}'

    @staticmethod
    def _show_frame(ax, frame):
        """Draw a (C, H, W) frame on ``ax`` after clipping it to [0, 1]."""
        image = np.clip(np.transpose(frame, (1, 2, 0)).astype(np.float32), 0, 1)
        if image.shape[-1] == 1:
            # imshow rejects (H, W, 1); show single-channel frames as grayscale.
            ax.imshow(image[..., 0], cmap='gray')
        else:
            ax.imshow(image)

    def _log_prediction_images(self, input_seq, target_seq, predictions, input_temporal, target_temporal, stage):
        """Log prediction visualizations with temporal information to wandb.

        Best-effort: failures are printed rather than raised. The figure is
        always closed, so errors do not accumulate open matplotlib figures.
        """
        fig = None
        try:
            # Take first sample from batch
            input_sample = input_seq[0].detach().cpu().float().numpy()
            target_sample = target_seq[0].detach().cpu().float().numpy()
            pred_sample = predictions[0].detach().cpu().float().numpy()
            input_temp = input_temporal[0].detach().cpu().numpy()
            target_temp = target_temporal[0].detach().cpu().numpy()

            num_frames = min(5, self.hparams.target_length,
                             pred_sample.shape[0], target_sample.shape[0])
            # squeeze=False keeps `axes` 2-D even when num_frames == 1.
            fig, axes = plt.subplots(3, num_frames, figsize=(num_frames * 4, 12), squeeze=False)

            for t in range(num_frames):
                # Last input frame (only shown in the first column)
                if t == 0:
                    self._show_frame(axes[0, t], input_sample[-1])
                    axes[0, t].set_title(f'Last Input\n{self._format_date(input_temp[-1])}',
                                         fontsize=10)
                else:
                    axes[0, t].axis('off')

                target_date = self._format_date(target_temp[t])

                self._show_frame(axes[1, t], target_sample[t])
                axes[1, t].set_title(f'Target {t+1}\n{target_date}', fontsize=10)

                self._show_frame(axes[2, t], pred_sample[t])
                axes[2, t].set_title(f'Predicted {t+1}\n{target_date}', fontsize=10)

                for row in range(3):
                    axes[row, t].set_xticks([])
                    axes[row, t].set_yticks([])

            # Add row labels
            axes[0, 0].set_ylabel('Input', fontsize=14, rotation=90, labelpad=20)
            axes[1, 0].set_ylabel('Target', fontsize=14, rotation=90, labelpad=20)
            axes[2, 0].set_ylabel('Predicted', fontsize=14, rotation=90, labelpad=20)

            plt.tight_layout()

            if hasattr(self.logger, 'experiment'):
                self.logger.experiment.log({
                    f'{stage}/predictions_with_dates': wandb.Image(fig),
                    'epoch': self.current_epoch,
                    'step': self.step_count
                })

        except Exception as e:
            print(f"Error logging images: {e}")
            import traceback
            print(f"Full traceback: {traceback.format_exc()}")
        finally:
            if fig is not None:
                plt.close(fig)

    def training_step(self, batch, batch_idx):
        return self._step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        return self._step(batch, 'val')

    def test_step(self, batch, batch_idx):
        return self._step(batch, 'test')

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        # Halve the learning rate every 10 scheduler steps (epochs by default).
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
        return [optimizer], [scheduler]

# Default hyper-parameters used when ``train_satellite_model`` is called without a config.
# wandb copies this into ``wandb.config``; it is never mutated here.
_DEFAULT_SATELLITE_TRAIN_CONFIG = {
    "input_dim": 3,
    "hidden_dims": [32, 64, 64],
    "kernel_size": [3, 3],
    "num_layers": 3,
    "learning_rate": 1e-4,
    "batch_size": 2,
    "max_epochs": 50,
    "sequence_length": 10,
    "target_length": 5,
    "image_size": 512,
    "temporal_dim": 5,
    "use_temporal_fusion": True,
    "architecture": "ConvLSTM-Temporal",
    "dataset": "San Antonio Satellite",
    "optimizer": "Adam",
    "scheduler": "StepLR",
    "precision": "16-mixed",
    "gradient_clip_val": 1.0,
    "data_dir": "./San_Antonio",
    "train_split": 0.8,
    "val_split": 0.1,
    "num_workers": 2,
    "log_images": True,
    "log_frequency": 50,
}


def _to_python_scalar(value):
    """Return ``value.item()`` for tensor/NumPy scalars; pass plain Python values through.

    ``trainer.callback_metrics`` stores 0-dim tensors, which are converted so
    wandb receives ordinary numbers.
    """
    item = getattr(value, "item", None)
    return item() if callable(item) else value


def train_satellite_model(config=None):
    """Train the Satellite ConvLSTM model with temporal encoding using PyTorch Lightning.

    Starts a wandb run, builds ``SanAntonioDataModule`` and
    ``SatelliteConvLSTMPredictor`` from ``wandb.config``, fits and then tests
    the model, and finally logs a summary of metrics (as scalars and as a
    ``wandb.Table``). The wandb run is always finished, even if training fails.

    Args:
        config: Passed to ``wandb.init`` as ``config``. When falsy,
            ``_DEFAULT_SATELLITE_TRAIN_CONFIG`` is used instead.

    Returns:
        ``(model, trainer)`` after ``fit`` and ``test`` have completed.
    """
    # Initialize wandb
    wandb.init(
        project="convlstm-satellite-temporal",
        config=config or _DEFAULT_SATELLITE_TRAIN_CONFIG,
        tags=["convlstm", "temporal-encoding", "satellite-prediction", "pytorch-lightning", "san-antonio"]
    )

    try:
        cfg = wandb.config

        # Data module
        data_module = SanAntonioDataModule(
            data_dir=cfg.data_dir,
            sequence_length=cfg.sequence_length,
            target_length=cfg.target_length,
            image_size=cfg.image_size,
            batch_size=cfg.batch_size,
            num_workers=cfg.num_workers,
            train_split=cfg.train_split,
            val_split=cfg.val_split
        )

        # Initialize model with wandb config
        model = SatelliteConvLSTMPredictor(
            input_dim=cfg.input_dim,
            hidden_dims=cfg.hidden_dims,
            kernel_size=tuple(cfg.kernel_size),
            num_layers=cfg.num_layers,
            learning_rate=cfg.learning_rate,
            target_length=cfg.target_length,
            batch_size=cfg.batch_size,
            temporal_dim=cfg.temporal_dim,
            use_temporal_fusion=cfg.use_temporal_fusion,
            log_images=cfg.log_images,
            log_frequency=cfg.log_frequency
        )

        # Log model architecture
        wandb.watch(model, log_freq=100, log_graph=True)

        # Callbacks. The monitored metric is named 'val/loss'; the filename template
        # must reference that exact key (a '{val_loss}' placeholder is never logged,
        # so Lightning would fill it with 0). auto_insert_metric_name=False keeps
        # the '/' from being written into the file name.
        checkpoint_callback = ModelCheckpoint(
            monitor='val/loss',
            dirpath='checkpoints/',
            filename='satellite-convlstm-temporal-epoch={epoch:02d}-val_loss={val/loss:.4f}',
            auto_insert_metric_name=False,
            save_top_k=3,
            mode='min',
            save_last=True
        )

        lr_monitor = LearningRateMonitor(logging_interval='epoch')

        # Wandb Logger (reuses the run started by wandb.init above)
        wandb_logger = WandbLogger(
            project="convlstm-satellite-temporal",
            log_model="all",
            save_dir="./wandb_logs"
        )

        # Trainer
        trainer = pl.Trainer(
            max_epochs=cfg.max_epochs,
            accelerator='auto',
            devices=1,
            precision=cfg.precision,
            gradient_clip_val=cfg.gradient_clip_val,
            callbacks=[checkpoint_callback, lr_monitor],
            logger=wandb_logger,
            log_every_n_steps=10,
            val_check_interval=0.5,
            limit_val_batches=10,
            enable_progress_bar=True,
            enable_model_summary=True
        )

        # Train
        trainer.fit(model, data_module)

        # Snapshot train/val metrics now: Lightning resets callback_metrics at the
        # start of every trainer run, so after trainer.test() they would be gone
        # and the summary below would silently report 0.
        fit_metrics = dict(trainer.callback_metrics)

        # Test
        trainer.test(model, data_module)

        # Log final metrics (tensors converted to plain Python scalars)
        best_score = checkpoint_callback.best_model_score
        final_metrics = {
            "final_train_loss": _to_python_scalar(fit_metrics.get("train/loss_epoch", 0)),
            "final_val_loss": _to_python_scalar(fit_metrics.get("val/loss", 0)),
            "final_val_mae": _to_python_scalar(fit_metrics.get("val/mae", 0)),
            "final_val_psnr": _to_python_scalar(fit_metrics.get("val/psnr", 0)),
            # `is not None`: a best score of exactly 0.0 is a valid value, not "missing".
            "best_val_loss": best_score.item() if best_score is not None else 0,
            "total_parameters": sum(p.numel() for p in model.parameters()),
            "trainable_parameters": sum(p.numel() for p in model.parameters() if p.requires_grad),
            "temporal_encoding_enabled": cfg.use_temporal_fusion,
            "temporal_dimensions": cfg.temporal_dim
        }

        wandb.log(final_metrics)

        # Summary table. The "Value" column mixes numbers and a bool; without
        # allow_mixed_types wandb infers Number from the first rows and rejects
        # the bool row with a TypeError.
        table = wandb.Table(
            data=[[key, value] for key, value in final_metrics.items()],
            columns=["Metric", "Value"],
            allow_mixed_types=True,
        )
        wandb.log({"final_metrics_table": table})
    finally:
        # Always close the run so a failure doesn't leave it open for the next call.
        wandb.finish()

    return model, trainer

if __name__ == "__main__":
    # Script entry point: announce the run, train, then say where results live.
    # `model` and `trainer` stay module-level names so later notebook cells can use them.
    print("Starting Satellite ConvLSTM training with temporal encoding...")
    model, trainer = train_satellite_model()
    for message in (
        "Training completed!",
        "View logs at: https://wandb.ai/",
        "Best model saved in: checkpoints/",
    ):
        print(message)

Enhanced Satellite ConvLSTM with Temporal Encoding
Testing temporal encoding...
Temporal encoding shape: torch.Size([3, 5])
Sample encoding for 2023-01-15: tensor([ 0.6909,  0.0000,  1.0000,  0.2994, -0.9541])
Sample encoding for 2023-06-20: tensor([ 0.6909,  0.5000, -0.8660, -0.6514, -0.7588])
Sample encoding for 2024-12-25: tensor([ 0.7091, -0.5000,  0.8660, -0.9885,  0.1514])

Testing dataset with temporal encoding...
Found 48 .tif files, 44 sequences
Input sequence shape: torch.Size([3, 3, 256, 256])
Target sequence shape: torch.Size([2, 3, 256, 256])
Input temporal shape: torch.Size([3, 3])
Target temporal shape: torch.Size([2, 3])
Input dates: tensor([[2.0200e+03, 1.0000e+00, 7.0000e+00],
        [2.0200e+03, 1.0000e+00, 2.3000e+01],
        [2.0200e+03, 4.0000e+00, 1.2000e+01]])
Target dates: tensor([[2.0200e+03, 8.0000e+00, 2.0000e+00],
        [2.0200e+03, 1.0000e+01, 5.0000e+00]])

Analyzing temporal patterns...

Seasonal temporal encodings:
Winter: tensor([ 0.6909, -0.5000, 

[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/root/projects/LSTM/.venv/lib/python3.10/site-packages/pytorch_lightning/loggers/wandb.py:397: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
/root/projects/LSTM/.venv/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /root/projects/LSTM/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                    | Type                 | Params | Mode 
-------------------------------------------------------------------------
0 | temporal_encoder        | TemporalEncoder      | 0      | train
1 | encoder                 | ConvLS

Found 48 .tif files, 34 sequences
Splits - Train: 27, Val: 3, Test: 4


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=50` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Found 48 .tif files, 34 sequences
Splits - Train: 27, Val: 3, Test: 4


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test/loss_epoch       0.0025221684481948614
        test/mae           0.033692747354507446
        test/mse           0.0025221684481948614
        test/psnr           25.982460021972656
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


TypeError: Data row contained incompatible types:
{'Metric': 'temporal_encoding_enabled', 'Value': True} of type {'Metric': String, 'Value': Boolean} is not assignable to {'Metric': None or String, 'Value': None or Number}
Key 'Value':
	Boolean not assignable to None or Number
		Boolean not assignable to None
	and
		Boolean not assignable to Number

In [None]:
#Add wandb functionality
#Int LST to Int LST
#MSE RMSE MAE
#Add time encoding