# Multi-Model U-Net Training for Methane Emission Segmentation
### Annabel Simpson

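This notebook loads the pretrained STARCOP baselines (HyperSTARCOP and MultiSTARCOP) and explores the training data. It then implements and trains three U-Net models whose inputs are RGB, RGB + Mag1c, and multispectral bands, respectively.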

In [None]:
# Install Python dependencies (kornia and torchmetrics are pinned to specific versions)
!pip install --quiet rasterio gdown fsspec omegaconf torch torchvision torchtext \
    pytorch-lightning segmentation_models_pytorch hydra-core geopandas \
    ipykernel matplotlib scikit-image scikit-learn wandb kornia==0.6.7 torchmetrics==0.10.0


# Install georeader directly from GitHub
!pip install git+https://github.com/spaceml-org/georeader.git --quiet


# Clone the STARCOP repo into the current directory (presumably /content on Colab);
# `|| echo` keeps the cell going when the clone fails, e.g. because the folder exists
!git clone https://github.com/spaceml-org/STARCOP.git || echo "Repository already exists"

%cd /content/

# Download the data subset and the two pretrained model bundles from Google Drive
# (contents inferred from the archive names -- TODO confirm)
!gdown https://drive.google.com/uc?id=1Qw96Drmk2jzBYSED0YPEUyuc2DnBechl -O STARCOP_mini.zip
!gdown https://drive.google.com/uc?id=1TXFlAHO_eRdfbJGLNNt3KY0lJqjm3fdX -O multistarcop_varon.zip
!gdown https://drive.google.com/uc?id=1Kvnc_lOBn4z-xO1HFRyLZOMEldXWQvql -O hyperstarcop_magic_rgb.zip

# Unzip quietly (-q), overwriting existing files without prompting (-o)
!unzip -qo STARCOP_mini.zip
!unzip -qo multistarcop_varon.zip
!unzip -qo hyperstarcop_magic_rgb.zip


# Delete the archives once they have been extracted
!rm -f *.zip


# Show what the two model bundles contain
!ls /content/hyperstarcop_magic_rgb
!ls /content/multistarcop_varon


# Work from the repo root so relative paths such as scripts/configs/... resolve
%cd /content/STARCOP


import omegaconf
import pylab as plt
import torch
import os
import pandas as pd
import numpy as np
import rasterio
import ast
import pkgutil
from torchvision import transforms
from PIL import Image
from mpl_toolkits.axes_grid1 import make_axes_locatable

from starcop.data.datamodule import Permian2019DataModule
from starcop.models.model_module import ModelModule

# Use the GPU when one is available, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


# Check that georeader imports and exposes a `slices` submodule. A missing submodule is
# raised as ModuleNotFoundError so that both failure cases take the same reinstall path.
# NOTE(review): if georeader was already imported, the final `import georeader` returns
# the module cached in sys.modules. The reinstalled code is therefore probably not picked
# up until the runtime is restarted -- confirm.
try:
    import georeader
    print("Spaceman Georeader found at:", georeader.__file__)
    modules = [module.name for module in pkgutil.iter_modules(georeader.__path__)]
    print("Available Spaceman Georeader Modules:", modules)

    if 'slices' not in modules:
        raise ModuleNotFoundError("Missing 'slices' module in Spaceman Georeader")
except ModuleNotFoundError:
    print("Spaceman Georeader module not found or missing 'slices'. Reinstalling...")
    !pip install --upgrade --force-reinstall git+https://github.com/spaceml-org/georeader.git --quiet
    import georeader


# Global STARCOP config (relative path: assumes the cwd is the STARCOP repo root),
# plus the local data root used when building the data modules
config_general = omegaconf.OmegaConf.load("scripts/configs/config.yaml")
root_folder = "/content/STARCOP_mini"

# Load a pretrained STARCOP checkpoint together with a matching data module
def load_model_with_datamodule(
    model_path,
    config_path,
    train_csv="/content/STARCOP_mini/train_mini10.csv",
    test_csv="/content/STARCOP_mini/test_mini10.csv",
):
    """Load a pretrained STARCOP model and build a data module configured for it.

    The model YAML at ``config_path`` is merged over the global ``config_general``.
    Dataset settings are read from the model config's ``_content.value.dataset``
    entry, which may be a mapping or a string holding a Python dict literal.
    ``root_folder`` and ``train_csv`` in those settings are then overridden.

    Args:
        model_path: path to the ``.ckpt`` checkpoint.
        config_path: path to the model's ``config.yaml``.
        train_csv: CSV stored as ``dataset["train_csv"]``.
        test_csv: CSV assigned to ``data_module.test_csv``.

    Returns:
        ``(model, data_module, config)``; the model is moved to ``device`` and put in
        eval mode, and ``prepare_data()`` has been called on the data module.

    Raises:
        ValueError: if the dataset entry is neither a mapping nor a valid dict literal.
    """
    config_model = omegaconf.OmegaConf.load(config_path)
    config = omegaconf.OmegaConf.merge(config_general, config_model)

    config_dict = omegaconf.OmegaConf.to_container(config_model, resolve=True)
    print("Config Model Structure:", config_dict)

    dataset_entry = config_dict.get("_content", {}).get("value", {}).get("dataset", "")
    if isinstance(dataset_entry, dict):
        # Already structured: ast.literal_eval would reject a dict object.
        dataset_dict = dict(dataset_entry)
    else:
        try:
            dataset_dict = ast.literal_eval(dataset_entry)
        except (SyntaxError, ValueError) as err:
            raise ValueError("Dataset configuration is not properly formatted as a dictionary") from err
        if not isinstance(dataset_dict, dict):
            raise ValueError("Dataset configuration is not properly formatted as a dictionary")

    dataset_dict['root_folder'] = root_folder
    dataset_dict['train_csv'] = train_csv
    config.dataset = dataset_dict
    config.products_plot = config_dict.get("products_plot", {})

    data_module = Permian2019DataModule(config)
    data_module.test_csv = test_csv
    data_module.settings['dataset'] = dataset_dict

    data_module.prepare_data()

    model = ModelModule.load_from_checkpoint(model_path, settings=config)
    model.to(device)
    model.eval()

    print(f"Loaded {os.path.basename(model_path)} with {model.num_channels} input channels")
    return model, data_module, config

# Pretrained HyperSTARCOP model (RGB + Mag1c inputs)
hsi_model_path, hsi_config_path = (
    "/content/hyperstarcop_magic_rgb/final_checkpoint_model.ckpt",
    "/content/hyperstarcop_magic_rgb/config.yaml",
)
hsi_model, hsi_dm, hsi_config = load_model_with_datamodule(hsi_model_path, hsi_config_path)
print("Successfully loaded HyperSTARCOP model!")

# Pretrained MultiSTARCOP model (TOA_WV3_* or TOA_AVIRIS_* inputs)
msi_model_path, msi_config_path = (
    "/content/multistarcop_varon/final_checkpoint_model.ckpt",
    "/content/multistarcop_varon/config.yaml",
)
msi_model, msi_dm, msi_config = load_model_with_datamodule(msi_model_path, msi_config_path)
print("Successfully loaded MultiSTARCOP model!")

In [None]:
import os
import numpy as np
import rasterio
import matplotlib.pyplot as plt


def load_raster(image_path):
    """Return band 1 of the raster at ``image_path`` as float32, or None if it is absent."""
    if os.path.exists(image_path):
        with rasterio.open(image_path) as dataset:
            band = dataset.read(1)
        return band.astype(np.float32)
    return None


def normalize(img):
    """Min-max scale ``img`` to [0, 1] for display.

    NaN pixels are ignored when computing the range and stay NaN in the output.
    A constant image (including an all-zero mask) or an all-NaN image becomes an
    all-zero float32 array instead of None, so it is still plotted. Only a None
    input returns None.
    """
    if img is None:
        return None
    lo, hi = np.nanmin(img), np.nanmax(img)
    # `not hi > lo` is also True when the range is NaN (all-NaN input).
    if not hi > lo:
        return np.zeros_like(img, dtype=np.float32)
    return (img - lo) / (hi - lo)


def get_image_paths(folder_path):
    """Map display names to raster paths inside ``folder_path``.

    "RGB Composite" maps to a list of the 640/550/460nm AVIRIS bands (R, G, B order).
    Every other key maps to a single path: ground truth, Mag1c, five SWIR AVIRIS
    wavelengths and five WorldView-3 SWIR bands, in that insertion order.
    """
    def tif(stem):
        return os.path.join(folder_path, f"{stem}.tif")

    paths = {
        "RGB Composite": [tif(f"TOA_AVIRIS_{wl}nm") for wl in ("640", "550", "460")],
        "Ground Truth": tif("labelbinary"),
        "Mag1c": tif("mag1c"),
    }
    for wl in ("2004", "2109", "2310", "2350", "2360"):
        paths[f"AVIRIS_{wl}"] = tif(f"TOA_AVIRIS_{wl}nm")
    for band_no in range(1, 6):
        paths[f"SWIR{band_no}"] = tif(f"TOA_WV3_SWIR{band_no}")
    return paths


def visualize_more_bands(folder_path):
    """Plot every available band of one tile side by side.

    The RGB composite is included only if all three of its bands load. Each other
    band from ``get_image_paths`` is included if its file exists. Panels whose
    normalisation returns None are dropped. The ground truth uses a gray colormap
    and everything else uses viridis. Prints a message when nothing can be shown.
    """
    band_paths = get_image_paths(folder_path)
    panels = {}

    rgb_layers = [load_raster(path) for path in band_paths["RGB Composite"]]
    if not any(layer is None for layer in rgb_layers):
        panels["RGB Composite"] = normalize(np.stack(rgb_layers, axis=-1))

    for name, path in band_paths.items():
        if name == "RGB Composite":
            continue
        data = load_raster(path)
        if data is not None:
            panels[name] = normalize(data)

    # Drop panels that normalize() could not scale
    panels = {name: data for name, data in panels.items() if data is not None}
    if not panels:
        print("No valid bands found for visualization.")
        return

    panel_count = len(panels)
    _, axes = plt.subplots(nrows=1, ncols=panel_count, figsize=(20, 5))
    if panel_count == 1:
        # plt.subplots returns a bare Axes (not an array) for a single column
        axes = [axes]

    for ax, (name, data) in zip(axes, panels.items()):
        colormap = "gray" if "Ground Truth" in name else "viridis"
        ax.imshow(data, cmap=colormap)
        ax.set_title(name)
        ax.axis("off")

    plt.tight_layout()
    plt.show()


# Preview every available band of one STARCOP_mini tile
sample_folder = "/content/STARCOP_mini/ang20191018t165503_r2660_c460_w151_h151"
visualize_more_bands(sample_folder)




In [None]:
# Dependencies for this section; folium is required for the interactive plume map
!pip install rasterio matplotlib pandas numpy folium

import os
import zipfile
import rasterio
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.models as models
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
import folium
from rasterio.warp import transform_bounds
from mpl_toolkits.axes_grid1 import make_axes_locatable
from google.colab import drive

In [None]:
# Mount Google Drive (forcing a fresh mount) to reach the zipped training set
drive.mount('/content/drive', force_remount=True)

# Dataset locations
dataset_folder = "/content/drive/MyDrive"
zip_path = "/content/drive/MyDrive/STARCOPtrain.zip"
extract_path = "/content/starcop_data"

# Extract only once: an existing extraction folder short-circuits the unzip
if os.path.exists(extract_path):
    print("Dataset already extracted!")
elif os.path.exists(zip_path):
    print("Dataset found. Extracting...")
    with zipfile.ZipFile(zip_path, "r") as archive:
        archive.extractall(extract_path)
    print("Extraction complete!")
else:
    print("Dataset not found! Check the path.")

In [None]:
import os
import pandas as pd

# Paths
train_csv_path = "/content/drive/MyDrive/AL4CC/AL4CC/train.csv"
test_csv_path = "/content/drive/MyDrive/AL4CC/AL4CC/test.csv"
train_data_dir = "/content/starcop_data/STARCOP_train_easy"
test_data_dir = "/content/drive/MyDrive/STARCOP_test"

# Load CSVs
df_train = pd.read_csv(train_csv_path)
df_test = pd.read_csv(test_csv_path)

# Event folders actually present on disk
available_train_folders = set(os.listdir(train_data_dir))
available_test_folders = set(os.listdir(test_data_dir))


def _event_folder_name(path):
    """Return the last component of ``path``, ignoring trailing separators.

    ``os.path.basename("a/b/")`` is ``""``, which would silently drop that row from
    the filtered CSV. Normalising first maps both ``"a/b/"`` and ``"a/b"`` to ``"b"``.
    """
    return os.path.basename(os.path.normpath(str(path)))


# Derive the event folder name from the CSV's "folder" column
df_train["folder_name"] = df_train["folder"].map(_event_folder_name)
df_test["folder_name"] = df_test["folder"].map(_event_folder_name)

# Filter to only include available folders
df_train_filtered = df_train[df_train["folder_name"].isin(available_train_folders)].copy().reset_index(drop=True)
df_test_filtered = df_test[df_test["folder_name"].isin(available_test_folders)].copy().reset_index(drop=True)

# Save updated CSVs
filtered_train_path = "/content/drive/MyDrive/AL4CC/AL4CC/train_filtered.csv"
filtered_test_path = "/content/drive/MyDrive/AL4CC/AL4CC/test_filtered.csv"

df_train_filtered.to_csv(filtered_train_path, index=False)
df_test_filtered.to_csv(filtered_test_path, index=False)

print(f"Updated train CSV: {df_train_filtered.shape}")
print(f"Updated test CSV: {df_test_filtered.shape}")
print(f"Dropped {len(df_train) - len(df_train_filtered)} train / "
      f"{len(df_test) - len(df_test_filtered)} test rows with no folder on disk")



In [None]:
import os
import numpy as np
import rasterio
import matplotlib.pyplot as plt
import cv2
from mpl_toolkits.axes_grid1 import make_axes_locatable


def load_raster(image_path):
    """Return band 1 of ``image_path`` in its native dtype; print and return None if missing."""
    if os.path.exists(image_path):
        with rasterio.open(image_path) as dataset:
            return dataset.read(1)
    print(f"Missing file: {image_path}")
    return None


def normalize_image(img):
    """Min-max scale ``img`` to the full 0-255 range as uint8 for display.

    NaN pixels are ignored when finding the range and are written as 0. A constant
    image (including an all-zero band) or an all-NaN image becomes all zeros instead
    of None, so callers can still stack and plot it. Only a None input returns None.
    The maximum maps to exactly 255; the previous ``+1e-8`` denominator made it
    truncate to 254.
    """
    if img is None:
        return None
    img = np.asarray(img, dtype=np.float64)
    lo, hi = np.nanmin(img), np.nanmax(img)
    # `not hi > lo` is also True when the range is NaN (all-NaN input).
    if not hi > lo:
        return np.zeros(img.shape, dtype=np.uint8)
    scaled = (img - lo) / (hi - lo)
    # Casting NaN to uint8 is undefined, so map it to 0 explicitly.
    return (np.nan_to_num(scaled, nan=0.0) * 255).astype(np.uint8)

def enhance_contrast(img):
    """Return a histogram-equalised uint8 copy of ``img`` for display.

    ``img`` is first rescaled to 0-255 with ``normalize_image``. Returns None when
    ``img`` is None or when that rescaling returns None, rather than handing None
    to ``cv2.equalizeHist``, which cannot process it.
    """
    if img is None:
        return None
    scaled = normalize_image(img)
    if scaled is None:
        return None
    return cv2.equalizeHist(scaled)


def get_image_paths(event_id, train_data_dir):
    """Return the R/G/B/Mag1c/GT raster paths for ``event_id``.

    Prints a message and returns None when the event folder does not exist.
    """
    event_dir = os.path.join(train_data_dir, event_id)
    if os.path.exists(event_dir):
        file_names = (
            ("R", "TOA_AVIRIS_640nm.tif"),
            ("G", "TOA_AVIRIS_550nm.tif"),
            ("B", "TOA_AVIRIS_460nm.tif"),
            ("Mag1c", "mag1c.tif"),
            ("GT", "labelbinary.tif"),
        )
        return {key: os.path.join(event_dir, name) for key, name in file_names}
    print(f"Skipping {event_id} (Folder missing)")
    return None


def visualize_more_bands(event_id, train_data_dir):
    """Plot the RGB composite, contrast-enhanced Mag1c and binary ground truth of one event.

    Args:
        event_id: event folder name inside ``train_data_dir``.
        train_data_dir: directory containing the event folders.

    Prints a message and returns without plotting when the folder or any of the
    R/G/B/Mag1c/GT rasters is missing, or when a band could not be scaled for display.
    """
    paths = get_image_paths(event_id, train_data_dir)
    if paths is None:
        return

    # Load only the rasters that exist on disk
    loaded_bands = {key: load_raster(path) for key, path in paths.items() if os.path.exists(path)}

    required_bands = ["R", "G", "B", "Mag1c", "GT"]
    if any(b not in loaded_bands for b in required_bands):
        print(f"Skipping {event_id} (missing essential bands)")
        return

    # Scale each RGB band to uint8 independently; equalise Mag1c for visibility
    r, g, b = map(normalize_image, [loaded_bands["R"], loaded_bands["G"], loaded_bands["B"]])
    mag1c = enhance_contrast(loaded_bands["Mag1c"])
    # The scaling helpers can return None, which np.stack / imshow cannot handle
    if any(band is None for band in (r, g, b, mag1c)):
        print(f"Skipping {event_id} (band could not be normalized)")
        return

    # Binary ground-truth mask: any positive label becomes 1
    gt = (loaded_bands["GT"] > 0).astype(np.uint8)

    # RGB composite only (Mag1c is shown in its own panel)
    rgb_image = np.stack([r, g, b], axis=-1)

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    axes[0].imshow(rgb_image)
    axes[0].set_title("RGB Composite")
    axes[0].axis("off")

    im = axes[1].imshow(mag1c, cmap="magma")
    axes[1].set_title("Enhanced Mag1c")
    axes[1].axis("off")
    plt.colorbar(im, ax=axes[1])

    im = axes[2].imshow(gt, cmap="gray")
    axes[2].set_title("Ground Truth (Binary Mask)")
    axes[2].axis("off")
    plt.colorbar(im, ax=axes[2])

    plt.tight_layout()
    plt.show(block=True)
    print("Display successful!")

# Smoke-test the viewer on one event from the "easy" training split
visualize_more_bands("ang20191010t192008_r8444_c339_w151_h151", "/content/starcop_data/STARCOP_train_easy")



In [None]:
# Folium map centred on the Permian Basin, with one red circle per training tile
map_plumes = folium.Map(location=[31.7, -103.6], zoom_start=8)

for _, tile in df_train_filtered.iterrows():
    tile_name = tile["folder_name"]
    label_path = os.path.join(train_data_dir, tile_name, "labelbinary.tif")

    # Tiles without a label raster cannot be georeferenced here
    if not os.path.exists(label_path):
        print(f"Skipping {tile_name} (no labelbinary.tif)")
        continue

    # Reproject the tile bounds to WGS84 and place the marker at the box centre
    with rasterio.open(label_path) as src:
        west, south, east, north = transform_bounds(src.crs, "EPSG:4326", *src.bounds)
    lat, lon = (south + north) / 2, (west + east) / 2

    folium.Circle(
        location=[lat, lon],
        radius=500,
        color='red',
        fill=True,
        popup=f"Event {tile_name}",
    ).add_to(map_plumes)

map_plumes

In [None]:
import torch
import os
import numpy as np
import rasterio
from torch.utils.data import Dataset
import torchvision.transforms as transforms

# Function to load raster images
def load_raster(image_path):
    """Return band 1 of the raster at ``image_path`` (native dtype), or None if absent."""
    found = os.path.exists(image_path)
    if not found:
        return None
    with rasterio.open(image_path) as dataset:
        band = dataset.read(1)
    return band

# Methane Emission Dataset with Multi-Spectral Inputs
class MethaneEmissionDataset(Dataset):
    """Segmentation dataset of ``(image, mask)`` tensors read from per-event folders.

    Each row of ``df`` needs a ``folder_name`` column naming a sub-folder of
    ``train_data_dir`` that holds the GeoTIFFs.

    Modes and their channels, in this order:
        * ``"rgb"``: TOA_AVIRIS 640/550/460nm.
        * ``"rgb+mag1c"`` (default): RGB + ``mag1c.tif``.
        * ``"multispectral"``: RGB + Mag1c + ``TOA_AVIRIS_2004nm`` + ``TOA_WV3_SWIR1``.

    Each band is divided by its own maximum (left unchanged when the maximum is
    <= 0). The mask is 1 where ``labelbinary.tif`` > 0 and 0 elsewhere, with shape
    ``(1, H, W)``.

    ``transform`` (optional) is an augmentation. If it has a ``transforms`` attribute
    (e.g. ``transforms.Compose``), its steps are applied one by one; otherwise it is
    a single step. Flip/rotation/affine steps are applied jointly to all image
    channels and the mask so they stay aligned. Any other step (e.g. ``ColorJitter``)
    is applied to the three RGB channels only.

    A row whose required files are missing is replaced by the next row that has them.
    """

    def __init__(self, df, train_data_dir, mode="rgb+mag1c", transform=None):
        self.df = df.reset_index(drop=True)
        self.train_data_dir = train_data_dir
        self.mode = mode
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for ``idx``, skipping forward past incomplete rows.

        Raises:
            IndexError: if ``idx`` is out of range.
            ValueError: for an unknown ``mode``.
            RuntimeError: if no row has all the files the mode requires.
        """
        n_rows = len(self.df)
        if not -n_rows <= idx < n_rows:
            raise IndexError(f"index {idx} out of range for dataset of size {n_rows}")
        # Bounded scan replaces the old unbounded recursion on missing files.
        for offset in range(n_rows):
            sample = self._load_sample((idx + offset) % n_rows)
            if sample is not None:
                return sample
        raise RuntimeError(f"No row has all files required for mode '{self.mode}'")

    @staticmethod
    def _normalize(img):
        """Divide ``img`` by its maximum; return it unchanged if the maximum is <= 0."""
        peak = np.max(img)
        return img / peak if peak > 0 else img

    def _load_sample(self, idx):
        """Build ``(image, mask)`` for row ``idx``; return None if a needed file is missing."""
        band_files = ["TOA_AVIRIS_640nm.tif", "TOA_AVIRIS_550nm.tif", "TOA_AVIRIS_460nm.tif"]
        if self.mode == "rgb+mag1c":
            band_files.append("mag1c.tif")
        elif self.mode == "multispectral":
            band_files += ["mag1c.tif", "TOA_AVIRIS_2004nm.tif", "TOA_WV3_SWIR1.tif"]
        elif self.mode != "rgb":
            raise ValueError(f"Unknown mode: {self.mode}")

        folder_path = os.path.join(self.train_data_dir, self.df.iloc[idx]["folder_name"])

        def load(name):
            img = load_raster(os.path.join(folder_path, name))
            return None if img is None else img.astype(np.float32)

        gt = load("labelbinary.tif")
        if gt is None:
            return None
        bands = [load(name) for name in band_files]
        # Covers the Mag1c file in "rgb+mag1c" mode too, which previously crashed
        if any(band is None for band in bands):
            return None

        image = torch.from_numpy(np.stack([self._normalize(b) for b in bands])).float()
        mask = torch.from_numpy((gt > 0).astype(np.float32)).unsqueeze(0)

        if self.transform is not None:
            image, mask = self._augment(image, mask)
        return image, mask

    def _augment(self, image, mask):
        """Apply ``self.transform`` while keeping image channels and mask aligned."""
        from torchvision import transforms as tv_transforms

        geometric = (
            tv_transforms.RandomHorizontalFlip,
            tv_transforms.RandomVerticalFlip,
            tv_transforms.RandomRotation,
            tv_transforms.RandomAffine,
        )
        steps = getattr(self.transform, "transforms", [self.transform])
        n_channels = image.shape[0]
        for step in steps:
            if isinstance(step, geometric):
                # One call on the stacked tensor => one random draw shared by image and mask.
                stacked = step(torch.cat([image, mask], dim=0))
                image, mask = stacked[:n_channels], stacked[n_channels:]
            else:
                # Photometric / unknown steps act on the RGB channels only.
                image = torch.cat([step(image[:3]), image[3:]], dim=0)
        # Re-binarise in case a step interpolated the mask.
        return image, (mask > 0.5).float()





# Training-time augmentations passed to the dataset's `transform` argument
data_transforms = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(degrees=30),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    ]
)

# Deterministic 80/20 split in CSV row order (rows are not shuffled before splitting)
n_rows = len(df_train_filtered)
train_size = int(0.8 * n_rows)
val_size = n_rows - train_size
train_rows = df_train_filtered.iloc[:train_size]
val_rows = df_train_filtered.iloc[train_size:]

train_dataset = MethaneEmissionDataset(train_rows, train_data_dir, transform=data_transforms)
val_dataset = MethaneEmissionDataset(val_rows, train_data_dir, transform=None)

# DataLoaders: only the training loader is shuffled
loader_options = {"batch_size": 8, "num_workers": 2}
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = torch.utils.data.DataLoader(val_dataset, shuffle=False, **loader_options)


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import rasterio
import os
from mpl_toolkits.axes_grid1 import make_axes_locatable


def visualize_multispectral(folder_path):
    """Plot RGB, Mag1c, AVIRIS 2004nm, SWIR1 and the ground-truth mask of one tile.

    Each band except the mask is divided by its own maximum (unchanged if the
    maximum is <= 0). Prints a message and returns if any of the seven rasters
    is missing.
    """
    file_names = {
        "r": "TOA_AVIRIS_640nm.tif",
        "g": "TOA_AVIRIS_550nm.tif",
        "b": "TOA_AVIRIS_460nm.tif",
        "mag1c": "mag1c.tif",
        "aviris_2004": "TOA_AVIRIS_2004nm.tif",
        "swir1": "TOA_WV3_SWIR1.tif",
        "gt": "labelbinary.tif",
    }
    layers = {key: load_raster(os.path.join(folder_path, name)) for key, name in file_names.items()}

    if any(layer is None for layer in layers.values()):
        print("Skipping due to missing data.")
        return

    def scale_by_max(img):
        peak = np.max(img)
        return img / peak if peak > 0 else img

    for key in ("r", "g", "b", "mag1c", "aviris_2004", "swir1"):
        layers[key] = scale_by_max(layers[key])
    rgb = np.stack([layers["r"], layers["g"], layers["b"]], axis=-1)

    _, axes = plt.subplots(2, 3, figsize=(18, 12))

    # (axis, image, colormap, title); colormap None keeps imshow's default
    panels = [
        (axes[0, 0], rgb, None, "RGB Composite"),
        (axes[0, 1], layers["mag1c"], "magma", "Mag1c"),
        (axes[0, 2], layers["aviris_2004"], "viridis", "AVIRIS 2004nm"),
        (axes[1, 0], layers["swir1"], "inferno", "SWIR1"),
    ]
    for ax, data, colormap, title in panels:
        ax.imshow(data, cmap=colormap)
        ax.set_title(title)
        ax.axis("off")

    # Ground-truth mask with a colorbar that does not shrink the panel
    gt_ax = axes[1, 1]
    gt_im = gt_ax.imshow(layers["gt"], cmap="gray")
    gt_ax.set_title("Ground Truth Mask")
    gt_ax.axis("off")
    colorbar_ax = make_axes_locatable(gt_ax).append_axes("right", size="5%", pad=0.05)
    plt.colorbar(gt_im, cax=colorbar_ax)

    plt.tight_layout()
    plt.show()
  
# Example usage with a valid folder path
sample_folder = "/content/starcop_data/STARCOP_train_easy/ang20190923t174142_r4096_c0_w512_h512"
visualize_multispectral(sample_folder)  # 512x512 tile from the "easy" training split



In [None]:
# Sanity check: print the shape of the first training batch only
for batch_images, _ in train_loader:
    print(f"Input batch shape: {batch_images.shape}")
    break

In [None]:
import torch
import os
import numpy as np
import rasterio
from torch.utils.data import Dataset

class MethaneEmissionDataset(Dataset):
    """Segmentation dataset of ``(image, mask)`` tensors read from per-event folders.

    Each row of ``df`` needs a ``folder_name`` column naming a sub-folder of
    ``train_data_dir`` that holds the GeoTIFFs. Rasters are read as float32.

    Modes and their channels, in this order:
        * ``"rgb"`` (default): TOA_AVIRIS 640/550/460nm.
        * ``"rgb+mag1c"``: RGB + ``mag1c.tif``.
        * ``"multispectral"``: RGB + Mag1c + ``TOA_AVIRIS_2004nm`` + ``TOA_WV3_SWIR1``.

    Each band is divided by its own maximum (left unchanged when the maximum is
    <= 0). The mask is 1 where ``labelbinary.tif`` > 0 and 0 elsewhere, with shape
    ``(1, H, W)``. Only the files required by ``mode`` are read.

    ``transform`` (optional) is an augmentation. If it has a ``transforms`` attribute
    (e.g. ``transforms.Compose``), its steps are applied one by one; otherwise it is
    a single step. Flip/rotation/affine steps are applied jointly to all image
    channels and the mask so they stay aligned. Any other step (e.g. ``ColorJitter``)
    is applied to the three RGB channels only.

    A row whose required files are missing is replaced by the next row that has them.
    """

    def __init__(self, df, train_data_dir, mode="rgb", transform=None):
        self.df = df.reset_index(drop=True)
        self.train_data_dir = train_data_dir
        self.mode = mode
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for ``idx``, skipping forward past incomplete rows.

        Raises:
            IndexError: if ``idx`` is out of range.
            ValueError: for an invalid ``mode``.
            RuntimeError: if no row has all the files the mode requires.
        """
        n_rows = len(self.df)
        if not -n_rows <= idx < n_rows:
            raise IndexError(f"index {idx} out of range for dataset of size {n_rows}")
        # Bounded scan replaces the old unbounded recursion on missing files.
        for offset in range(n_rows):
            sample = self._load_sample((idx + offset) % n_rows)
            if sample is not None:
                return sample
        raise RuntimeError(f"No row has all files required for mode '{self.mode}'")

    @staticmethod
    def _normalize(img):
        """Divide ``img`` by its maximum; return it unchanged if the maximum is <= 0."""
        peak = np.max(img)
        return img / peak if peak > 0 else img

    def _band_files(self):
        """Return the input GeoTIFF names for ``self.mode`` in channel order."""
        rgb_files = ["TOA_AVIRIS_640nm.tif", "TOA_AVIRIS_550nm.tif", "TOA_AVIRIS_460nm.tif"]
        if self.mode == "rgb":
            return rgb_files
        if self.mode == "rgb+mag1c":
            return rgb_files + ["mag1c.tif"]
        if self.mode == "multispectral":
            return rgb_files + ["mag1c.tif", "TOA_AVIRIS_2004nm.tif", "TOA_WV3_SWIR1.tif"]
        raise ValueError(f"Invalid mode: {self.mode}")

    def _load_sample(self, idx):
        """Build ``(image, mask)`` for row ``idx``; return None if a needed file is missing."""
        band_files = self._band_files()
        folder_path = os.path.join(self.train_data_dir, self.df.iloc[idx]["folder_name"])

        def load_image(name):
            path = os.path.join(folder_path, name)
            if not os.path.exists(path):
                return None
            with rasterio.open(path) as src:
                return src.read(1).astype(np.float32)

        gt = load_image("labelbinary.tif")
        if gt is None:
            return None
        bands = [load_image(name) for name in band_files]
        # Covers the Mag1c file in "rgb+mag1c" mode too, which previously crashed
        if any(band is None for band in bands):
            return None

        image = torch.from_numpy(np.stack([self._normalize(b) for b in bands])).float()
        mask = torch.from_numpy((gt > 0).astype(np.float32)).unsqueeze(0)

        if self.transform is not None:
            image, mask = self._augment(image, mask)
        return image, mask

    def _augment(self, image, mask):
        """Apply ``self.transform`` while keeping image channels and mask aligned."""
        from torchvision import transforms as tv_transforms

        geometric = (
            tv_transforms.RandomHorizontalFlip,
            tv_transforms.RandomVerticalFlip,
            tv_transforms.RandomRotation,
            tv_transforms.RandomAffine,
        )
        steps = getattr(self.transform, "transforms", [self.transform])
        n_channels = image.shape[0]
        for step in steps:
            if isinstance(step, geometric):
                # One call on the stacked tensor => one random draw shared by image and mask.
                stacked = step(torch.cat([image, mask], dim=0))
                image, mask = stacked[:n_channels], stacked[n_channels:]
            else:
                # Photometric / unknown steps act on the RGB channels only.
                image = torch.cat([step(image[:3]), image[3:]], dim=0)
        # Re-binarise in case a step interpolated the mask.
        return image, (mask > 0.5).float()


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

class UNet(nn.Module):
    """Segmentation network: ResNet-18 encoder with a three-stage U-Net decoder.

    Args:
        num_classes: number of output channels.
        in_channels: number of input channels. The ResNet stem conv is replaced by a
            new ``Conv2d(in_channels, 64, 7, stride=2, padding=3)``, so the pretrained
            stem weights are not used.

    ``forward`` returns sigmoid probabilities at the resolution of the stem output
    (about half the input height/width). The ResNet max-pool is not used, and the
    decoder upsamples three times from the layer4 features. Before each
    concatenation, the upsampled map is resized to its skip connection, so inputs
    whose height/width are not multiples of 16 (e.g. 151x151 tiles) are supported.
    """

    def __init__(self, num_classes=1, in_channels=3):
        super().__init__()
        self.encoder = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

        # Rebuild the stem for the requested number of input channels
        self.encoder.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)

        # Encoder stages (ResNet residual layers)
        self.enc1 = self.encoder.layer1
        self.enc2 = self.encoder.layer2
        self.enc3 = self.encoder.layer3
        self.enc4 = self.encoder.layer4

        # Decoder: transposed-conv upsampling, concat with skip, 3x3 conv + BN
        self.upconv1 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.conv1 = nn.Conv2d(512, 256, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(256)

        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)

        self.upconv3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(64)

        self.final_conv = nn.Conv2d(64, num_classes, kernel_size=1)

    @staticmethod
    def _resize_to(x, ref):
        """Bilinearly resize ``x`` to ``ref``'s spatial size when the two differ."""
        if x.shape[-2:] == ref.shape[-2:]:
            return x
        return F.interpolate(x, size=ref.shape[-2:], mode="bilinear", align_corners=False)

    def _up_block(self, x, skip, upconv, conv, bn):
        """Upsample ``x``, align it with ``skip``, concatenate, then conv -> ReLU -> BN."""
        x = self._resize_to(upconv(x), skip)
        x = torch.cat([x, skip], dim=1)
        return bn(F.relu(conv(x)))

    def forward(self, x):
        # Encoder (max-pool intentionally skipped)
        x1 = self.encoder.relu(self.encoder.bn1(self.encoder.conv1(x)))
        x2 = self.enc1(x1)
        x3 = self.enc2(x2)
        x4 = self.enc3(x3)
        x5 = self.enc4(x4)

        # Decoder
        x = self._up_block(x5, x4, self.upconv1, self.conv1, self.bn1)
        x = self._up_block(x, x3, self.upconv2, self.conv2, self.bn2)
        x = self._up_block(x, x2, self.upconv3, self.conv3, self.bn3)

        return torch.sigmoid(self.final_conv(x))


In [None]:
from tqdm import tqdm
import torch.optim as optim

def train_model(model, train_loader, val_loader, num_epochs=5, lr=0.001, criterion=None):
    """Train ``model`` with Adam and report the mean train/validation loss per epoch.

    Masks are resized (nearest neighbour) to the model output's spatial size
    before the loss is computed.

    Args:
        model: network mapping image batches to per-pixel outputs.
        train_loader: loader yielding ``(images, masks)`` batches for training.
        val_loader: loader yielding ``(images, masks)`` batches for validation.
        num_epochs: number of passes over ``train_loader``.
        lr: Adam learning rate.
        criterion: loss taking ``(outputs, targets)``. Defaults to ``nn.BCELoss()``,
            which expects probabilities (e.g. a network ending in ``torch.sigmoid``).
            Pass ``nn.BCEWithLogitsLoss()`` for models that return raw logits;
            combining that loss with a sigmoid output would apply sigmoid twice.

    Returns:
        The trained model (the same object, moved to the selected device).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    if criterion is None:
        criterion = nn.BCELoss()

    def batch_loss(images, masks):
        images, masks = images.to(device), masks.to(device)
        outputs = model(images)
        # The model may output a lower resolution than the input; match the targets to it
        masks = F.interpolate(masks, size=outputs.shape[2:], mode="nearest")
        return criterion(outputs, masks)

    for epoch in range(1, num_epochs + 1):
        model.train()
        train_loss = 0.0

        loop = tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs} - Training")
        for step, (images, masks) in enumerate(loop, start=1):
            optimizer.zero_grad()
            loss = batch_loss(images, masks)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            # Running mean over the batches seen so far in this epoch
            loop.set_postfix(loss=train_loss / step)

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for images, masks in val_loader:
                val_loss += batch_loss(images, masks).item()

        # max(..., 1) avoids ZeroDivisionError for an empty loader
        mean_train = train_loss / max(len(train_loader), 1)
        mean_val = val_loss / max(len(val_loader), 1)
        print(f"Epoch [{epoch}/{num_epochs}] | Train Loss: {mean_train:.4f} | Val Loss: {mean_val:.4f}")

    return model


In [None]:
# Train one U-Net per input configuration.
# Each configuration gets its own training AND validation loader built with the same
# `mode`, so validation batches have the channel count the model expects (3 / 4 / 6).
# Rows [0, train_size) are used for training and rows [train_size, end) for validation,
# using `train_size` from the earlier 80/20 split, so no validation row is trained on.
# NOTE(review): default batch collation needs every tile in a batch to share H x W.
# Both *_w151_h151 and *_w512_h512 folders appear under STARCOP_train_easy above; if the
# split mixes sizes, crop/resize the tiles or use batch_size=1.
train_rows_split = df_train_filtered.iloc[:train_size]
val_rows_split = df_train_filtered.iloc[train_size:]


def _make_loader(df, mode, shuffle):
    """Return ``(dataset, loader)`` for ``df`` read in ``mode`` (batch 8, 2 workers)."""
    dataset = MethaneEmissionDataset(df, train_data_dir, mode=mode)
    loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=shuffle, num_workers=2)
    return dataset, loader


# RGB Model (in_channels=3)
train_dataset_rgb, train_loader_rgb = _make_loader(train_rows_split, "rgb", shuffle=True)
_, val_loader_rgb = _make_loader(val_rows_split, "rgb", shuffle=False)
model_rgb = UNet(in_channels=3)
model_rgb = train_model(model_rgb, train_loader_rgb, val_loader_rgb, num_epochs=5)

# RGB + Mag1c Model (in_channels=4)
train_dataset_rgb_mag1c, train_loader_rgb_mag1c = _make_loader(train_rows_split, "rgb+mag1c", shuffle=True)
_, val_loader_rgb_mag1c = _make_loader(val_rows_split, "rgb+mag1c", shuffle=False)
model_rgb_mag1c = UNet(in_channels=4)
model_rgb_mag1c = train_model(model_rgb_mag1c, train_loader_rgb_mag1c, val_loader_rgb_mag1c, num_epochs=5)

# Multispectral Model (in_channels=6)
train_dataset_multispectral, train_loader_multispectral = _make_loader(train_rows_split, "multispectral", shuffle=True)
_, val_loader_multispectral = _make_loader(val_rows_split, "multispectral", shuffle=False)
model_multispectral = UNet(in_channels=6)
model_multispectral = train_model(model_multispectral, train_loader_multispectral, val_loader_multispectral, num_epochs=5)
