Environment setup


In [1]:
# Cell 1: Mount Drive & Install Dependencies

# 1. Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# 2. Install required libraries
!pip install --quiet torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
!pip install --quiet xarray h5py einops tqdm scikit-image opencv-python matplotlib pytorch-lightning


Mounted at /content/drive
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m116.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m78.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m47.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m17.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m7.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━

In [2]:
# Cell 2: Define Dataset Path and Configuration (for .h5-based dataset)

import os

# Path to the folder with .h5 files (on the mounted Google Drive)
DATASET_DIR = '/content/drive/MyDrive/reshaped_DS'

# Prototype configuration
# NOTE(review): IMG_SIZE is defined here but no resize is applied in this cell.
IMG_SIZE = (64, 64)         # Resized shape to downsample images
NUM_INPUT_FRAMES = 6        # Past 3 hours (30 min interval)
NUM_OUTPUT_FRAMES = 2       # Future 1 hour (30 min intervals)
CHANNELS = 1                # Single spectral band (e.g., IR)

# File format
FILE_EXTENSION = '.h5'      # We'll scan for HDF5 files

# Month abbreviations as they appear in file names such as
# "3SIMG_24JUN2025_0000_L1C_SGP_V01R00.h5".
_MONTH_NUMBER = {
    'JAN': 1, 'FEB': 2, 'MAR': 3, 'APR': 4, 'MAY': 5, 'JUN': 6,
    'JUL': 7, 'AUG': 8, 'SEP': 9, 'OCT': 10, 'NOV': 11, 'DEC': 12,
}


def _acquisition_sort_key(filename):
    """Return a sort key that orders file names by their embedded timestamp.

    Expects names of the form ``<prefix>_<DDMONYYYY>_<HHMM>_...``. A plain
    lexicographic sort is wrong for this pattern because the day comes before
    the month and the month is spelled out (e.g. "01JUL2025" < "24JUN2025").

    Names that do not match the pattern sort after all matching names, in
    alphabetical order, so nothing is dropped.
    """
    try:
        parts = filename.split('_')
        date_part, time_part = parts[1], parts[2]
        if len(date_part) != 9:
            raise ValueError(date_part)
        day = int(date_part[0:2])
        month = _MONTH_NUMBER[date_part[2:5].upper()]
        year = int(date_part[5:9])
        hhmm = int(time_part[:4])
    except (IndexError, KeyError, ValueError):
        return (1, 0, 0, 0, 0, filename)
    return (0, year, month, day, hhmm, filename)


# List files in chronological (acquisition-time) order
all_files = sorted(
    (f for f in os.listdir(DATASET_DIR) if f.endswith(FILE_EXTENSION)),
    key=_acquisition_sort_key,
)

# Quick check
print(f"Total HDF5 files found: {len(all_files)}")
print("Sample file names:", all_files[:3])


Total HDF5 files found: 9
Sample file names: ['3SIMG_24JUN2025_0000_L1C_SGP_V01R00.h5', '3SIMG_24JUN2025_0030_L1C_SGP_V01R00.h5', '3SIMG_24JUN2025_0100_L1C_SGP_V01R00.h5']


In [3]:
# Cell 3: Load and Prepare a Sample Sequence (Input: 6, Output: 2) from .h5 Files

import h5py
import numpy as np
import torch

def read_band_from_h5(filepath, band_key='IMG_TIR1'):
    """Read the first slice of one dataset from an HDF5 file.

    Opens ``filepath`` read-only, takes index 0 along the leading axis of the
    dataset stored under ``band_key`` and returns it. The file is closed
    before the function returns.

    Args:
        filepath: Path of the HDF5 file to open.
        band_key: Name of the dataset inside the file.

    Returns:
        The data at index 0 of the dataset (128 x 128 for the sample files
        loaded in this notebook).
    """
    with h5py.File(filepath, mode='r') as h5_file:
        first_slice = h5_file[band_key][0]
        return first_slice

def normalize_band(band_array, max_value=255.0):
    """Scale raw band values by a fixed full-scale value.

    Args:
        band_array: NumPy array of raw band values.
        max_value: Positive full-scale value to divide by. The default of
            255.0 keeps the original behaviour and is only appropriate for
            8-bit data.

    Returns:
        A float32 array equal to ``band_array / max_value``. The result lies
        in [0, 1] only when every input value lies in [0, max_value].

    Raises:
        ValueError: If ``max_value`` is not a positive number.

    NOTE(review): the first training loss logged in this notebook is well
    above 1, which suggests the band holds values larger than 255 -- confirm
    the data's bit depth and pass the matching ``max_value``.
    """
    if not max_value > 0:
        raise ValueError(f"max_value must be positive, got {max_value!r}")
    # Cast the divisor as well so the result stays float32 for any scalar type.
    return band_array.astype(np.float32) / np.float32(max_value)

def load_sequence(file_list, band_key='IMG_TIR1'):
    """
    Load one (input, target) sample from the first files of ``file_list``.

    The first NUM_INPUT_FRAMES files become the input frames and the next
    NUM_OUTPUT_FRAMES files become the target frames. Each file is read from
    DATASET_DIR with read_band_from_h5 and scaled with normalize_band.

    Args:
        file_list: File names (relative to DATASET_DIR), expected to be in
            time order.
        band_key: Name of the HDF5 dataset to read from every file.

    Returns:
        (input_tensor, target_tensor): the stacked frames with two leading
        singleton axes added, i.e. shape (1, 1, T, H, W) for 2-D frames.

    Raises:
        ValueError: If ``file_list`` has fewer than
            NUM_INPUT_FRAMES + NUM_OUTPUT_FRAMES entries. Without this check
            a short list would silently produce too few target frames.
    """
    frames_needed = NUM_INPUT_FRAMES + NUM_OUTPUT_FRAMES
    if len(file_list) < frames_needed:
        raise ValueError(
            f"Need at least {frames_needed} files "
            f"({NUM_INPUT_FRAMES} input + {NUM_OUTPUT_FRAMES} output), "
            f"got {len(file_list)}"
        )

    frames = [
        normalize_band(
            read_band_from_h5(os.path.join(DATASET_DIR, filename), band_key)
        )
        for filename in file_list[:frames_needed]
    ]
    input_frames = frames[:NUM_INPUT_FRAMES]
    target_frames = frames[NUM_INPUT_FRAMES:]

    # Convert to tensors: shape (1, 1, T, H, W)
    input_tensor = torch.tensor(np.stack(input_frames)).unsqueeze(0).unsqueeze(0)
    target_tensor = torch.tensor(np.stack(target_frames)).unsqueeze(0).unsqueeze(0)

    print(f"Input tensor shape: {input_tensor.shape}")    # (1, 1, 6, 128, 128)
    print(f"Target tensor shape: {target_tensor.shape}")  # (1, 1, 2, 128, 128)

    return input_tensor, target_tensor

# Build one sample pair from the first NUM_INPUT_FRAMES + NUM_OUTPUT_FRAMES
# entries of all_files (6 input + 2 target frames with the current settings).
# load_sequence also prints the shapes of both tensors.
input_tensor, target_tensor = load_sequence(all_files, band_key='IMG_TIR1')


Input tensor shape: torch.Size([1, 1, 6, 128, 128])
Target tensor shape: torch.Size([1, 1, 2, 128, 128])


In [8]:
# Cell 4: Corrected 3D UNet that outputs only NUM_OUTPUT_FRAMES frames

import torch.nn as nn
import torch.nn.functional as F

class UNet3D(nn.Module):
    """Small 3D convolutional encoder-decoder that predicts future frames.

    The network runs two conv + max-pool encoder stages, a bottleneck, and two
    decoder stages that upsample with trilinear interpolation back to the
    encoder resolutions. The encoder activations are used only for their
    sizes when upsampling; they are not concatenated into the decoder (there
    are no skip connections).

    The final 1x1x1 convolution produces one frame per input time step; only
    the last ``num_output_frames`` of them are returned.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the predicted frames.
        base_channels: Width of the first encoder stage; deeper stages use
            2x and 4x this value.
        num_output_frames: How many trailing frames ``forward`` returns. When
            left as None, the module-level NUM_OUTPUT_FRAMES is read at call
            time (the original behaviour).

    Raises:
        ValueError: If ``num_output_frames`` is given and is smaller than 1.
    """

    def __init__(self, in_channels=1, out_channels=1, base_channels=16,
                 num_output_frames=None):
        super().__init__()

        if num_output_frames is not None and num_output_frames < 1:
            raise ValueError(
                f"num_output_frames must be >= 1, got {num_output_frames}"
            )
        self.num_output_frames = num_output_frames

        self.encoder1 = nn.Sequential(
            nn.Conv3d(in_channels, base_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv3d(base_channels, base_channels, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.pool1 = nn.MaxPool3d(kernel_size=2)

        self.encoder2 = nn.Sequential(
            nn.Conv3d(base_channels, base_channels * 2, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv3d(base_channels * 2, base_channels * 2, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.pool2 = nn.MaxPool3d(kernel_size=2)

        self.bottleneck = nn.Sequential(
            nn.Conv3d(base_channels * 2, base_channels * 4, kernel_size=3, padding=1),
            nn.ReLU()
        )

        self.decoder2 = nn.Sequential(
            nn.Conv3d(base_channels * 4, base_channels * 2, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.decoder1 = nn.Sequential(
            nn.Conv3d(base_channels * 2, base_channels, kernel_size=3, padding=1),
            nn.ReLU()
        )

        self.final_conv = nn.Conv3d(base_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Predict the trailing frames for a batch of input sequences.

        Args:
            x: Tensor of shape (B, in_channels, T, H, W).

        Returns:
            Tensor of shape (B, out_channels, N, H, W), where N is the
            configured number of output frames.

        Raises:
            ValueError: If N is not in the range 1..T.
        """
        enc1 = self.encoder1(x)                     # -> (B, C, T, H, W)
        enc2 = self.encoder2(self.pool1(enc1))      # -> (B, C2, T//2, H//2, W//2)
        bottleneck = self.bottleneck(self.pool2(enc2))  # -> (B, C4, T//4, H//4, W//4)

        # Upsample to enc2 shape
        up2 = F.interpolate(bottleneck, size=enc2.shape[2:], mode='trilinear', align_corners=False)
        dec2 = self.decoder2(up2)

        # Upsample to enc1 shape
        up1 = F.interpolate(dec2, size=enc1.shape[2:], mode='trilinear', align_corners=False)
        dec1 = self.decoder1(up1)

        out = self.final_conv(dec1)  # -> (B, out_channels, T, H, W)

        num_frames = self.num_output_frames
        if num_frames is None:
            # Backward-compatible fallback to the notebook-level constant.
            num_frames = NUM_OUTPUT_FRAMES

        total_frames = out.shape[2]
        # A plain `-num_frames:` slice would return every frame for 0 and
        # silently return fewer frames than requested when num_frames > T.
        if not 1 <= num_frames <= total_frames:
            raise ValueError(
                f"Cannot return {num_frames} output frames from a sequence "
                f"of {total_frames} frames"
            )

        # Keep only the last num_frames positions of the temporal axis.
        return out[:, :, total_frames - num_frames:, :, :]

# Instantiate the model
# Seed PyTorch's random number generator first so the randomly initialised
# weights are the same on every "Restart & Run All".
SEED = 42
torch.manual_seed(SEED)
model = UNet3D(in_channels=1, out_channels=1)
print(f"Model initialized. Total parameters: {sum(p.numel() for p in model.parameters())}")


Model initialized. Total parameters: 173457


In [9]:
# Cell 5: Training Setup and Mini Training Loop (Fixed for Updated UNet3D)

import torch.optim as optim
import torch.nn as nn

# Move model and data to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
input_tensor = input_tensor.to(device)
target_tensor = target_tensor.to(device)

# Define loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Number of training epochs for prototype
EPOCHS = 10

# Training mode only needs to be set once; nothing in the loop switches it.
model.train()

print("Starting training...\n")
for epoch in range(EPOCHS):
    optimizer.zero_grad()

    # Forward pass
    output = model(input_tensor)

    # nn.MSELoss broadcasts mismatched shapes and only emits a warning, which
    # would yield a meaningless loss; fail loudly instead.
    if output.shape != target_tensor.shape:
        raise ValueError(
            f"Model output shape {tuple(output.shape)} does not match "
            f"target shape {tuple(target_tensor.shape)}"
        )

    # Compute loss
    loss = criterion(output, target_tensor)

    # Backward pass and optimization
    loss.backward()
    optimizer.step()

    print(f"Epoch [{epoch+1}/{EPOCHS}] - Loss: {loss.item():.6f}")


Starting training...

Epoch [1/10] - Loss: 6.867555
Epoch [2/10] - Loss: 6.777252
Epoch [3/10] - Loss: 6.505719
Epoch [4/10] - Loss: 5.526911
Epoch [5/10] - Loss: 3.057830
Epoch [6/10] - Loss: 0.598686
Epoch [7/10] - Loss: 5.179178
Epoch [8/10] - Loss: 0.521489
Epoch [9/10] - Loss: 1.336145
Epoch [10/10] - Loss: 2.498595
