In [None]:
# In your setup cell (Cell 1), after mounting drive:
!pip install timm torchvision tifffile imagecodecs


Collecting imagecodecs
  Downloading imagecodecs-2025.3.30-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)
Downloading imagecodecs-2025.3.30-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (45.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m45.6/45.6 MB[0m [31m18.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: imagecodecs
Successfully installed imagecodecs-2025.3.30


In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Cell 1: Mount Drive & install
from google.colab import drive
drive.mount('/content/drive')

!pip install timm torchvision tifffile


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Cell 2 ▶ Consolidated imports
import os
import math
import time

import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split

import timm
from torchvision import transforms, models

from tifffile import imread

In [None]:
# Cell 1 ▶ Dataset (224×224 targets, no down-scaling)
class LSTDataset(Dataset):
    def __init__(self, df, patches_dir, weather_cols):
        self.df           = df.reset_index(drop=True)
        self.patches_dir  = patches_dir
        self.weather_cols = weather_cols
        self.transform    = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),          # resize image
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std =[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row  = self.df.loc[idx]
        arr  = imread(os.path.join(self.patches_dir, row["filename"])
                     ).astype(np.float32)            # (4,H,W)

        # --- inputs -------------------------------------------------
        img_np = arr[[1,2,3]].transpose(1,2,0).astype(np.uint8)
        img    = self.transform(img_np)              # [3,224,224]

        # --- target (LST) at *full* 224×224 --------------------------
        lst    = arr[0]                              # (H,W)
        lst    = torch.tensor(lst, dtype=torch.float32).unsqueeze(0)
        lst    = F.interpolate(lst.unsqueeze(0), size=(224,224),
                               mode='bilinear', align_corners=False
                              ).squeeze(0)           # [1,224,224]

        # --- meteorology vector -------------------------------------
        weather = torch.tensor(
            row[self.weather_cols].values.astype(np.float32)
        )
        return img, weather, lst


In [None]:
# define which meteorological columns to pull
weather_cols = [
    "air_temp_C",
    "dew_point_C",
    "relative_humidity_percent",
    "wind_speed_m_s",
    "precipitation_in",
]

In [None]:
df = pd.read_csv("/content/drive/MyDrive/PatchedOutput/tiff_with_meteo.csv")
for col in weather_cols:
    df[col] = pd.to_numeric(df[col], errors="coerce")
df = df.dropna(subset=weather_cols + ["filename"]).reset_index(drop=True)

patches_dir = "/content/drive/MyDrive/PatchedOutput_Cleaned"
dataset     = LSTDataset(df, patches_dir, weather_cols)
train_sz    = int(0.8 * len(dataset))
val_sz      = len(dataset) - train_sz
train_ds, val_ds = random_split(dataset, [train_sz, val_sz])

train_loader = DataLoader(train_ds, batch_size=4, shuffle=True,  num_workers=0, pin_memory=False)
val_loader   = DataLoader(val_ds,   batch_size=4, shuffle=False, num_workers=0, pin_memory=False)

In [None]:
# ─── Cell 4: VIT-UNET MODEL DEFINITION ───────────────────────────────────────────
class ViTUNet(nn.Module):
    def __init__(self, in_channels=3, weather_dim=5, base_channels=64,
                 num_heads=8, trans_layers=2, trans_dim=512):
        super().__init__()
        def conv_block(in_ch, out_ch):
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True)
            )
        # Encoder
        self.conv1 = conv_block(in_channels, base_channels)
        self.conv2 = conv_block(base_channels, base_channels*2)
        self.conv3 = conv_block(base_channels*2, base_channels*4)
        self.conv4 = conv_block(base_channels*4, base_channels*8)
        self.conv5 = conv_block(base_channels*8, base_channels*8)
        self.pool  = nn.MaxPool2d(2)
        # Weather projection
        self.weather_proj = nn.Linear(weather_dim, trans_dim)
        # Bottleneck + Transformer
        self.bottleneck_proj = nn.Conv2d(base_channels*8, trans_dim, 1)
        self.pos_embed       = nn.Parameter(torch.zeros(1, 14*14, trans_dim))
        self.weather_token   = nn.Parameter(torch.zeros(1,1,trans_dim))
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=trans_dim, nhead=num_heads,
            dim_feedforward=trans_dim*4, dropout=0.1, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=trans_layers)
        # Decoder
        self.up4 = nn.ConvTranspose2d(trans_dim,     base_channels*8, 2,2)
        self.up3 = nn.ConvTranspose2d(base_channels*8, base_channels*4, 2,2)
        self.up2 = nn.ConvTranspose2d(base_channels*4, base_channels*2, 2,2)
        self.up1 = nn.ConvTranspose2d(base_channels*2, base_channels,   2,2)
        self.dec4 = conv_block(base_channels*8 + base_channels*8, base_channels*8)
        self.dec3 = conv_block(base_channels*4 + base_channels*4, base_channels*4)
        self.dec2 = conv_block(base_channels*2 + base_channels*2, base_channels*2)
        self.dec1 = conv_block(base_channels   + base_channels,   base_channels)
        self.final_conv = nn.Conv2d(base_channels, 1, 1)

    def forward(self, x, weather):
        B,C,H,W = x.shape
        e1 = self.conv1(x)
        e2 = self.conv2(self.pool(e1))
        e3 = self.conv3(self.pool(e2))
        e4 = self.conv4(self.pool(e3))
        e5 = self.conv5(self.pool(e4))

        bt      = self.bottleneck_proj(e5)
        N       = bt.shape[2]*bt.shape[3]
        bt_flat = bt.view(B, bt.shape[1], N).permute(0,2,1)
        bt_pos  = bt_flat + self.pos_embed[:, :N, :]
        w_tok   = self.weather_proj(weather).unsqueeze(1) + self.weather_token
        trans   = torch.cat([bt_pos, w_tok], dim=1)
        out_t   = self.transformer(trans)
        spat    = out_t[:, :-1, :]
        fs      = int(math.sqrt(spat.size(1)))
        f_ts    = spat.permute(0,2,1).view(B, bt.shape[1], fs, fs)

        d4 = self.dec4(torch.cat([self.up4(f_ts), e4], dim=1))
        d3 = self.dec3(torch.cat([self.up3(d4),  e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3),  e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2),  e1], dim=1))
        return self.final_conv(d1)


In [None]:
# ─── Cell 5: INITIALIZE MODEL, SmoothL1Loss, OPTIMIZER & LR SCHEDULER ─────────
device    = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model     = ViTUNet(in_channels=3, weather_dim=len(weather_cols)).to(device)

# Use SmoothL1 (Huber) loss instead of MSE
loss_fn = nn.SmoothL1Loss()

# Optimizer: only on parameters that require grad
opt = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=1e-4,
    weight_decay=1e-2
)

# Reduce LR on plateau of validation loss
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    opt,
    mode='min',
    factor=0.5,
    patience=3,
    verbose=True
)

scaler = GradScaler()

print("✅ Model, SmoothL1Loss, optimizer & scheduler ready")


✅ Model, SmoothL1Loss, optimizer & scheduler ready


  scaler = GradScaler()


In [None]:
# ─── Cell 6: TRAIN & VALIDATE (PRINT TRAIN/VAL SmoothL1) ──────────────────────
import warnings
# suppress the per-iteration AMP FutureWarning that was cluttering the console
warnings.filterwarnings(
    "ignore",
    category=FutureWarning,
    module="torch\\.cuda\\.amp"
)

num_epochs = 10
for epoch in range(1, num_epochs+1):
    # — Training —
    model.train()
    train_loss = 0.0
    # one clean tqdm line
    train_bar = tqdm(
        train_loader,
        desc=f"Epoch {epoch:02d} ▶ Train",
        leave=False,
        dynamic_ncols=True
    )
    for imgs, weather, masks in train_bar:
        imgs, weather, masks = imgs.to(device), weather.to(device), masks.to(device)
        opt.zero_grad()
        # use the new autocast API to avoid the warning
        with torch.amp.autocast(device_type="cuda"):
            preds = model(imgs, weather)
            loss  = loss_fn(preds, masks)
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()

        train_loss += loss.item() * imgs.size(0)
        train_bar.set_postfix(smoothl1=f"{loss.item():.4f}")

    avg_train = train_loss / len(train_loader.dataset)
    print(f"Epoch {epoch:02d} ▶ Train SmoothL1: {avg_train:.4f}")

    # — Validation —
    model.eval()
    val_loss = 0.0
    val_bar  = tqdm(
        val_loader,
        desc=f"Epoch {epoch:02d} ▶ Val  ",
        leave=False,
        dynamic_ncols=True
    )
    with torch.no_grad():
        for imgs, weather, masks in val_bar:
            imgs, weather, masks = imgs.to(device), weather.to(device), masks.to(device)
            with torch.amp.autocast(device_type="cuda"):
                preds = model(imgs, weather)
                loss  = loss_fn(preds, masks)

            val_loss += loss.item() * imgs.size(0)
            val_bar.set_postfix(smoothl1=f"{loss.item():.4f}")

    avg_val = val_loss / len(val_loader.dataset)
    print(f"           Val   SmoothL1: {avg_val:.4f}\n")

    # — Update LR scheduler on validation metric —
    scheduler.step(avg_val)




Epoch 01 ▶ Train SmoothL1: 2.1559




           Val   SmoothL1: 2.1402





Epoch 02 ▶ Train SmoothL1: 1.2872




           Val   SmoothL1: 1.9945





Epoch 03 ▶ Train SmoothL1: 0.9879




           Val   SmoothL1: 0.9034





Epoch 04 ▶ Train SmoothL1: 0.8334




           Val   SmoothL1: 1.1317





Epoch 05 ▶ Train SmoothL1: 0.7943




           Val   SmoothL1: 1.3011





Epoch 06 ▶ Train SmoothL1: 0.6681




           Val   SmoothL1: 1.4668





Epoch 07 ▶ Train SmoothL1: 0.5610




           Val   SmoothL1: 0.9393





Epoch 08 ▶ Train SmoothL1: 0.3853




           Val   SmoothL1: 2.0100





Epoch 09 ▶ Train SmoothL1: 0.3270




           Val   SmoothL1: 1.5663





Epoch 10 ▶ Train SmoothL1: 0.2831


                                                                                    

           Val   SmoothL1: 1.2631





In [None]:
# ─── Cell 3: DATASET WITH AGGRESSIVE AUGMENTATION ───────────────────────────────
class LSTDataset(Dataset):
    def __init__(self, df, patches_dir, weather_cols):
        self.df           = df.reset_index(drop=True)
        self.patches_dir  = patches_dir
        self.weather_cols = weather_cols
        self.transform    = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224,224)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomRotation(20),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.ToTensor(),
            transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
        ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row  = self.df.loc[idx]
        arr  = imread(os.path.join(self.patches_dir, row["filename"])).astype(np.float32)

        # Image input
        img_np = arr[[1,2,3]].transpose(1,2,0).astype(np.uint8)
        img    = self.transform(img_np)

        # Target LST
        lst    = torch.tensor(arr[0], dtype=torch.float32).unsqueeze(0)
        lst    = F.interpolate(lst.unsqueeze(0), size=(224,224),
                               mode='bilinear', align_corners=False
                              ).squeeze(0)

        # Weather vector
        weather = torch.tensor(
            row[self.weather_cols].values.astype(np.float32)
        )
        return img, weather, lst


In [None]:
# define which meteorological columns to pull
weather_cols = [
    "air_temp_C",
    "dew_point_C",
    "relative_humidity_percent",
    "wind_speed_m_s",
    "precipitation_in",
]
df = pd.read_csv("/content/drive/MyDrive/PatchedOutput/tiff_with_meteo.csv")
for col in weather_cols:
    df[col] = pd.to_numeric(df[col], errors="coerce")
df = df.dropna(subset=weather_cols + ["filename"]).reset_index(drop=True)

patches_dir = "/content/drive/MyDrive/PatchedOutput_Cleaned"
dataset     = LSTDataset(df, patches_dir, weather_cols)
train_sz    = int(0.8 * len(dataset))
val_sz      = len(dataset) - train_sz
train_ds, val_ds = random_split(dataset, [train_sz, val_sz])

train_loader = DataLoader(train_ds, batch_size=4, shuffle=True,  num_workers=0, pin_memory=False)
val_loader   = DataLoader(val_ds,   batch_size=4, shuffle=False, num_workers=0, pin_memory=False)

In [None]:
# ─── Cell 4: ViT-UNet WITH SPATIAL DROPOUT ─────────────────────────────────────
class ViTUNet(nn.Module):
    def __init__(self, in_channels=3, weather_dim=5, base_channels=64,
                 num_heads=8, trans_layers=2, trans_dim=512):
        super().__init__()
        def conv_block(in_ch, out_ch, p_drop=0.2):
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.Dropout2d(p_drop),
                nn.Conv2d(out_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            )
        # Encoder with dropout
        self.conv1 = conv_block(in_channels,    base_channels,    p_drop=0.2)
        self.conv2 = conv_block(base_channels,  base_channels*2,  p_drop=0.2)
        self.conv3 = conv_block(base_channels*2,base_channels*4,  p_drop=0.2)
        self.conv4 = conv_block(base_channels*4,base_channels*8,  p_drop=0.2)
        self.conv5 = conv_block(base_channels*8,base_channels*8,  p_drop=0.2)
        self.pool  = nn.MaxPool2d(2)

        # Weather projection & transformer (unchanged)…
        self.weather_proj   = nn.Linear(weather_dim, trans_dim)
        self.bottleneck_proj = nn.Conv2d(base_channels*8, trans_dim, 1)
        self.pos_embed       = nn.Parameter(torch.zeros(1, 14*14, trans_dim))
        self.weather_token   = nn.Parameter(torch.zeros(1,1,trans_dim))
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=trans_dim, nhead=num_heads,
            dim_feedforward=trans_dim*4, dropout=0.1, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=trans_layers)

        # Decoder upsample & convs (unchanged)…
        self.up4 = nn.ConvTranspose2d(trans_dim,     base_channels*8, 2,2)
        self.up3 = nn.ConvTranspose2d(base_channels*8, base_channels*4, 2,2)
        self.up2 = nn.ConvTranspose2d(base_channels*4, base_channels*2, 2,2)
        self.up1 = nn.ConvTranspose2d(base_channels*2, base_channels,   2,2)
        self.dec4 = conv_block(base_channels*8 + base_channels*8, base_channels*8, p_drop=0.2)
        self.dec3 = conv_block(base_channels*4 + base_channels*4, base_channels*4, p_drop=0.2)
        self.dec2 = conv_block(base_channels*2 + base_channels*2, base_channels*2, p_drop=0.2)
        self.dec1 = conv_block(base_channels   + base_channels,   base_channels,   p_drop=0.2)
        self.final_conv = nn.Conv2d(base_channels, 1, 1)

    def forward(self, x, weather):
        # same forward as before…
        B,C,H,W = x.shape
        e1 = self.conv1(x)
        e2 = self.conv2(self.pool(e1))
        e3 = self.conv3(self.pool(e2))
        e4 = self.conv4(self.pool(e3))
        e5 = self.conv5(self.pool(e4))

        bt      = self.bottleneck_proj(e5)
        N       = bt.shape[2]*bt.shape[3]
        bt_flat = bt.view(B, bt.shape[1], N).permute(0,2,1)
        bt_pos  = bt_flat + self.pos_embed[:, :N, :]
        w_tok   = self.weather_proj(weather).unsqueeze(1) + self.weather_token
        trans   = torch.cat([bt_pos, w_tok], dim=1)
        out_t   = self.transformer(trans)
        spat    = out_t[:, :-1, :]
        fs      = int(math.sqrt(spat.size(1)))
        f_ts    = spat.permute(0,2,1).view(B, bt.shape[1], fs, fs)

        d4 = self.dec4(torch.cat([self.up4(f_ts), e4], dim=1))
        d3 = self.dec3(torch.cat([self.up3(d4),  e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3),  e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2),  e1], dim=1))
        return self.final_conv(d1)


In [None]:
# ─── Cell 5: INITIALIZE MODEL, SmoothL1Loss, OPTIMIZER & LR SCHEDULER ─────────
device    = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model     = ViTUNet(in_channels=3, weather_dim=len(weather_cols)).to(device)

# Use SmoothL1 (Huber) loss instead of MSE
loss_fn = nn.SmoothL1Loss()

# Optimizer: only on parameters that require grad
opt = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=1e-4,
    weight_decay=1e-2
)

# Reduce LR on plateau of validation loss
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    opt,
    mode='min',
    factor=0.5,
    patience=3,
    verbose=True
)

scaler = GradScaler()

print("✅ Model, SmoothL1Loss, optimizer & scheduler ready")


✅ Model, SmoothL1Loss, optimizer & scheduler ready


  scaler = GradScaler()


In [None]:
# ─── Cell 6: TRAIN & VALIDATE WITH EARLY STOPPING ──────────────────────────────
import warnings
warnings.filterwarnings(
    "ignore",
    category=FutureWarning,
    module="torch\\.cuda\\.amp"
)

patience           = 5
best_val           = float('inf')
epochs_no_improve  = 0
num_epochs         = 10

for epoch in range(1, num_epochs+1):
    # —— Training ——
    model.train()
    train_loss = 0.0
    train_bar  = tqdm(train_loader, desc=f"Epoch {epoch:02d} ▶ Train", leave=False)
    for imgs, weather, masks in train_bar:
        imgs, weather, masks = imgs.to(device), weather.to(device), masks.to(device)
        opt.zero_grad()
        with torch.amp.autocast(device_type="cuda"):
            preds = model(imgs, weather)
            loss  = loss_fn(preds, masks)
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()

        train_loss += loss.item() * imgs.size(0)
        train_bar.set_postfix(smoothl1=f"{loss.item():.4f}")

    avg_train = train_loss / len(train_loader.dataset)
    print(f"Epoch {epoch:02d} ▶ Train SmoothL1: {avg_train:.4f}")

    # —— Validation ——
    model.eval()
    val_loss = 0.0
    val_bar  = tqdm(val_loader, desc=f"Epoch {epoch:02d} ▶ Val  ", leave=False)
    with torch.no_grad():
        for imgs, weather, masks in val_bar:
            imgs, weather, masks = imgs.to(device), weather.to(device), masks.to(device)
            with torch.amp.autocast(device_type="cuda"):
                preds = model(imgs, weather)
                loss  = loss_fn(preds, masks)

            val_loss += loss.item() * imgs.size(0)
            val_bar.set_postfix(smoothl1=f"{loss.item():.4f}")

    avg_val = val_loss / len(val_loader.dataset)
    print(f"           Val   SmoothL1: {avg_val:.4f}\n")

    # Scheduler step and early stopping
    scheduler.step(avg_val)
    if avg_val < best_val:
        best_val          = avg_val
        epochs_no_improve = 0
        torch.save(model.state_dict(), 'best_model.pth')
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print(f"↳ Early stopping at epoch {epoch}")
            break




Epoch 01 ▶ Train SmoothL1: 2.9314




           Val   SmoothL1: 2.2961





Epoch 02 ▶ Train SmoothL1: 2.5191




           Val   SmoothL1: 4.2387





Epoch 03 ▶ Train SmoothL1: 2.5855




           Val   SmoothL1: 2.3298





Epoch 04 ▶ Train SmoothL1: 2.3638




           Val   SmoothL1: 2.3189





Epoch 05 ▶ Train SmoothL1: 2.2857




           Val   SmoothL1: 2.2814





Epoch 06 ▶ Train SmoothL1: 2.2344




           Val   SmoothL1: 2.3143





Epoch 07 ▶ Train SmoothL1: 2.1900




           Val   SmoothL1: 2.2961





Epoch 08 ▶ Train SmoothL1: 2.2054




           Val   SmoothL1: 2.1888





Epoch 09 ▶ Train SmoothL1: 2.1270




           Val   SmoothL1: 2.3965





Epoch 10 ▶ Train SmoothL1: 2.1667


                                                                                    

           Val   SmoothL1: 2.2493





In [None]:
# Cell 2 ▶ Consolidated imports
import os
import math
import time

import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split

import timm
from torchvision import transforms, models

from tifffile import imread

In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm

def conv_block(in_ch, out_ch):
    return nn.Sequential(
        nn.Conv2d(in_ch,  out_ch, 3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, 3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    )

class ViTUNet(nn.Module):
    def __init__(self, weather_dim=5, trans_dim=512, num_heads=8, trans_layers=2):
        super().__init__()
        # Encoder: ResNet34 backbone
        self.backbone = timm.create_model(
            'resnet34', pretrained=True, features_only=True,
            out_indices=[0,1,2,3,4], in_chans=3
        )
        feats = self.backbone.feature_info.channels()  # [64, 64, 128, 256, 512]

        # Weather token & pos-embed
        self.weather_proj   = nn.Linear(weather_dim, trans_dim)
        self.weather_token  = nn.Parameter(torch.zeros(1, 1, trans_dim))
        self.pos_embed      = nn.Parameter(torch.zeros(1, 14*14, trans_dim))

        # Bottleneck projection & transformer
        self.bottleneck_proj = nn.Conv2d(feats[-1], trans_dim, 1)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=trans_dim, nhead=num_heads,
            dim_feedforward=trans_dim*4, dropout=0.1, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=trans_layers)

        # Decoder upsample & convs
        self.up4 = nn.ConvTranspose2d(trans_dim,   feats[3], 2, 2)
        self.dec4 = conv_block(feats[3]*2, feats[3])
        self.up3 = nn.ConvTranspose2d(feats[3],    feats[2], 2, 2)
        self.dec3 = conv_block(feats[2]*2, feats[2])
        self.up2 = nn.ConvTranspose2d(feats[2],    feats[1], 2, 2)
        self.dec2 = conv_block(feats[1]*2, feats[1])
        self.up1 = nn.ConvTranspose2d(feats[1],    feats[0], 2, 2)
        self.dec1 = conv_block(feats[0]*2, feats[0])
        self.final_conv = nn.Conv2d(feats[0], 1, 1)

    def forward(self, x, weather):
        # 1) remember input size
        orig_h, orig_w = x.shape[2], x.shape[3]

        # 2) Encoder
        e1, e2, e3, e4, e5 = self.backbone(x)   # 224→112→56→28→14

        # 3) Bottleneck + Transformer prep
        bt = self.bottleneck_proj(e5)           # [B, C, 14, 14]
        B, C, H, W = bt.shape
        N = H * W
        feat = bt.view(B, C, N).permute(0, 2, 1)      # [B, N, C]
        feat = feat + self.pos_embed[:, :N, :]       # add positional
        w_tok = self.weather_proj(weather).unsqueeze(1) + self.weather_token
        trans = torch.cat([feat, w_tok], dim=1)       # [B, N+1, C]
        out_t = self.transformer(trans)               # [B, N+1, C]

        # 4) reshape back to 2D feature map
        feat = out_t[:, :-1, :]                       # drop weather token → [B, N, C]
        fs = int(math.sqrt(feat.size(1)))             # should be 14
        f_ts = feat.permute(0, 2, 1).view(B, C, fs, fs)  # [B, C, 14, 14]

        # 5) Decoder
        d4 = self.dec4(torch.cat([self.up4(f_ts), e4], dim=1))
        d3 = self.dec3(torch.cat([self.up3(d4),  e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3),  e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2),  e1], dim=1))

        # 6) Final conv + upsample back to 224×224
        out = self.final_conv(d1)
        out = F.interpolate(
            out,
            size=(orig_h, orig_w),
            mode='bilinear',
            align_corners=False
        )
        return out


In [None]:
# ─── Cell 3: DATASET WITH AGGRESSIVE AUGMENTATION ───────────────────────────────
class LSTDataset(Dataset):
    def __init__(self, df, patches_dir, weather_cols):
        self.df           = df.reset_index(drop=True)
        self.patches_dir  = patches_dir
        self.weather_cols = weather_cols
        self.transform    = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224,224)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomRotation(20),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.ToTensor(),
            transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
        ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row  = self.df.loc[idx]
        arr  = imread(os.path.join(self.patches_dir, row["filename"])).astype(np.float32)

        # Image input
        img_np = arr[[1,2,3]].transpose(1,2,0).astype(np.uint8)
        img    = self.transform(img_np)

        # Target LST
        lst    = torch.tensor(arr[0], dtype=torch.float32).unsqueeze(0)
        lst    = F.interpolate(lst.unsqueeze(0), size=(224,224),
                               mode='bilinear', align_corners=False
                              ).squeeze(0)

        # Weather vector
        weather = torch.tensor(
            row[self.weather_cols].values.astype(np.float32)
        )
        return img, weather, lst


In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# define which meteorological columns to pull
weather_cols = [
    "air_temp_C",
    "dew_point_C",
    "relative_humidity_percent",
    "wind_speed_m_s",
    "precipitation_in",
]
df = pd.read_csv("/content/drive/MyDrive/PatchedOutput/tiff_with_meteo.csv")
for col in weather_cols:
    df[col] = pd.to_numeric(df[col], errors="coerce")
df = df.dropna(subset=weather_cols + ["filename"]).reset_index(drop=True)

patches_dir = "/content/drive/MyDrive/PatchedOutput_Cleaned"
dataset     = LSTDataset(df, patches_dir, weather_cols)
train_sz    = int(0.8 * len(dataset))
val_sz      = len(dataset) - train_sz
train_ds, val_ds = random_split(dataset, [train_sz, val_sz])

train_loader = DataLoader(train_ds, batch_size=4, shuffle=True,  num_workers=0, pin_memory=False)
val_loader   = DataLoader(val_ds,   batch_size=4, shuffle=False, num_workers=0, pin_memory=False)

In [None]:
# ─── Cell 0: Install TIFF LZW support ─────────────────────────────────────────
!pip install imagecodecs


Collecting imagecodecs
  Downloading imagecodecs-2025.3.30-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)
Downloading imagecodecs-2025.3.30-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (45.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m45.6/45.6 MB[0m [31m15.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: imagecodecs
Successfully installed imagecodecs-2025.3.30


In [None]:
# ─── Cell 0: install the SSIM library ─────────────────────────────────────────
!pip install piq


Collecting piq
  Downloading piq-0.8.0-py3-none-any.whl.metadata (17 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch==2.6.0->torchvision>=0.10.0->piq)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch==2.6.0->torchvision>=0.10.0->piq)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch==2.6.0->torchvision>=0.10.0->piq)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch==2.6.0->torchvision>=0.10.0->piq)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch==2.6.0->torchvision>=0.10.0->piq)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
C

In [None]:
# ─── Cell 5: INITIALIZE MODEL, SmoothL1+SSIM LOSS & One-Cycle LR ─────────────
import torch
from torch.cuda.amp import GradScaler
from piq import ssim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model  = ViTUNet(weather_dim=len(weather_cols)).to(device)

# Losses
loss_fn = nn.SmoothL1Loss()

# Optimizer
opt = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=1e-3, weight_decay=1e-2
)

# OneCycleLR scheduler
num_epochs     = 10
steps_per_epoch = len(train_loader)
scheduler      = torch.optim.lr_scheduler.OneCycleLR(
    opt,
    max_lr=1e-3,
    total_steps=steps_per_epoch * num_epochs,
    pct_start=0.3,
    anneal_strategy='cos',
    div_factor=25.0,
    final_div_factor=1e4
)

scaler = GradScaler()
print("✅ Model, SmoothL1+SSIM, OneCycleLR ready")


✅ Model, SmoothL1+SSIM, OneCycleLR ready


  scaler = GradScaler()


In [None]:
from tqdm import tqdm

In [None]:
# ─── Cell 4: ViT-UNet WITH DEEP SUPERVISION ────────────────────────────────────
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class ViTUNetDS(nn.Module):
    def __init__(self, in_channels=3, weather_dim=5, base_channels=64,
                 num_heads=8, trans_layers=2, trans_dim=512):
        super().__init__()
        def conv_block(in_ch, out_ch):
            return nn.Sequential(
                nn.Conv2d(in_ch,  out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            )

        # Encoder
        self.conv1 = conv_block(in_channels,    base_channels)
        self.conv2 = conv_block(base_channels,  base_channels*2)
        self.conv3 = conv_block(base_channels*2,base_channels*4)
        self.conv4 = conv_block(base_channels*4,base_channels*8)
        self.conv5 = conv_block(base_channels*8,base_channels*8)
        self.pool  = nn.MaxPool2d(2)

        # Weather + Transformer
        self.weather_proj    = nn.Linear(weather_dim, trans_dim)
        self.weather_token   = nn.Parameter(torch.zeros(1,1,trans_dim))
        self.pos_embed       = nn.Parameter(torch.zeros(1,14*14,trans_dim))
        self.bottleneck_proj = nn.Conv2d(base_channels*8, trans_dim, 1)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=trans_dim, nhead=num_heads,
            dim_feedforward=trans_dim*4, dropout=0.1, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=trans_layers)

        # Decoder
        self.up4 = nn.ConvTranspose2d(trans_dim,     base_channels*8, 2,2)
        self.dec4 = conv_block(base_channels*8*2,   base_channels*8)
        self.up3 = nn.ConvTranspose2d(base_channels*8, base_channels*4, 2,2)
        self.dec3 = conv_block(base_channels*4*2,   base_channels*4)
        self.up2 = nn.ConvTranspose2d(base_channels*4, base_channels*2, 2,2)
        self.dec2 = conv_block(base_channels*2*2,   base_channels*2)
        self.up1 = nn.ConvTranspose2d(base_channels*2, base_channels,   2,2)
        self.dec1 = conv_block(base_channels*2,     base_channels)
        self.final_conv = nn.Conv2d(base_channels, 1, 1)

        # Deep-supervision heads
        self.aux_conv3 = nn.Conv2d(base_channels*4, 1, 1)  # from d3
        self.aux_conv2 = nn.Conv2d(base_channels*2, 1, 1)  # from d2

    def forward(self, x, weather):
        e1 = self.conv1(x)
        e2 = self.conv2(self.pool(e1))
        e3 = self.conv3(self.pool(e2))
        e4 = self.conv4(self.pool(e3))
        e5 = self.conv5(self.pool(e4))

        bt = self.bottleneck_proj(e5)
        B,C,H,W = bt.shape
        N = H*W
        bt_flat = bt.view(B,C,N).permute(0,2,1)
        bt_pos  = bt_flat + self.pos_embed[:, :N, :]
        w_tok   = self.weather_proj(weather).unsqueeze(1) + self.weather_token
        trans   = torch.cat([bt_pos, w_tok], dim=1)
        out_t   = self.transformer(trans)
        spat    = out_t[:, :-1, :]
        fs      = int(math.sqrt(spat.size(1)))
        f_ts    = spat.permute(0,2,1).view(B, C, fs, fs)

        d4 = self.dec4(torch.cat([self.up4(f_ts), e4], dim=1))
        d3 = self.dec3(torch.cat([self.up3(d4),  e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3),  e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2),  e1], dim=1))

        main_out = self.final_conv(d1)
        aux3     = F.interpolate(self.aux_conv3(d3), size=(224,224),
                                 mode='bilinear', align_corners=False)
        aux2     = F.interpolate(self.aux_conv2(d2), size=(224,224),
                                 mode='bilinear', align_corners=False)
        return main_out, aux3, aux2


In [None]:
# ─── Cell 5: INIT MODEL, LOSS, OPTIMIZER & ONE-CYCLE LR ───────────────────────
from torch.cuda.amp import GradScaler

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model  = ViTUNetDS(in_channels=3, weather_dim=len(weather_cols)).to(device)

# Loss
loss_fn = nn.MSELoss()

# Optimizer
opt = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=1e-4, weight_decay=1e-2
)

# One-Cycle LR scheduler
num_epochs      = 10
steps_per_epoch = len(train_loader)
scheduler       = torch.optim.lr_scheduler.OneCycleLR(
    opt,
    max_lr=1e-3,
    total_steps=steps_per_epoch * num_epochs,
    pct_start=0.3,
    anneal_strategy='cos',
    div_factor=25.0,
    final_div_factor=1e4
)

scaler = GradScaler()
print("✅ DS Model, optimizer & OneCycleLR ready")


✅ DS Model, optimizer & OneCycleLR ready


  scaler = GradScaler()


In [None]:
# ─── Cell 6: TRAIN/VAL with MixUp + Deep Supervision (fixed warnings & one‐line tqdm) ───
import numpy as np
import warnings
from tqdm import tqdm
from torch.cuda.amp import GradScaler
import torch.nn.functional as F

# suppress all FutureWarnings (including the autocast deprecation)
warnings.filterwarnings("ignore", category=FutureWarning)

def mixup_data(x, y, w, alpha=0.4):
    lam = np.random.beta(alpha, alpha) if alpha>0 else 1.0
    idx = torch.randperm(x.size(0), device=x.device)
    return lam*x + (1-lam)*x[idx], lam*w + (1-lam)*w[idx], lam*y + (1-lam)*y[idx]

patience, best_val, no_imp = 5, float('inf'), 0
scaler = GradScaler()

for epoch in range(1, num_epochs+1):
    # — Training —
    model.train()
    t_loss = 0
    train_bar = tqdm(
        train_loader,
        desc=f"Epoch {epoch:02d} ▶ Train",
        leave=False,
        dynamic_ncols=True
    )
    for imgs, weather, masks in train_bar:
        imgs, weather, masks = imgs.to(device), weather.to(device), masks.to(device)
        xim, wim, yim        = mixup_data(imgs, masks, weather, alpha=0.4)

        opt.zero_grad()
        # use the new autocast API to avoid deprecation warnings
        with torch.amp.autocast("cuda"):
            pm, pa3, pa2 = model(xim, wim)
            l0 = loss_fn(pm, yim)
            l3 = loss_fn(pa3, yim)
            l2 = loss_fn(pa2, yim)
            loss = l0 + 0.5*(l3 + l2)

        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        scheduler.step()

        t_loss += loss.item() * imgs.size(0)
        train_bar.set_postfix(loss=f"{loss.item():.4f}")

    avg_t = t_loss / len(train_loader.dataset)
    print(f"Epoch {epoch:02d} ▶ Train Loss: {avg_t:.4f}")

    # — Validation —
    model.eval()
    v_loss = 0
    val_bar = tqdm(
        val_loader,
        desc=f"Epoch {epoch:02d} ▶ Val  ",
        leave=False,
        dynamic_ncols=True
    )
    with torch.no_grad():
        for imgs, weather, masks in val_bar:
            imgs, weather, masks = imgs.to(device), weather.to(device), masks.to(device)
            with torch.amp.autocast("cuda"):
                pm, pa3, pa2 = model(imgs, weather)
                l0 = loss_fn(pm, masks)
                l3 = loss_fn(pa3, masks)
                l2 = loss_fn(pa2, masks)
                loss = l0 + 0.5*(l3 + l2)

            v_loss += loss.item() * imgs.size(0)
            val_bar.set_postfix(loss=f"{loss.item():.4f}")

    avg_v = v_loss / len(val_loader.dataset)
    print(f"           Val Loss:   {avg_v:.4f}\n")

    # — Early stopping —
    if avg_v < best_val:
        best_val, no_imp = avg_v, 0
        torch.save(model.state_dict(), 'best_ds.pth')
    else:
        no_imp += 1
        if no_imp >= patience:
            print(f"↳ Early stopping at epoch {epoch}")
            break




Epoch 01 ▶ Train Loss: 122.4519




           Val Loss:   112.8348





Epoch 02 ▶ Train Loss: nan




           Val Loss:   nan





Epoch 03 ▶ Train Loss: nan




           Val Loss:   nan





Epoch 04 ▶ Train Loss: nan




           Val Loss:   nan





Epoch 05 ▶ Train Loss: nan




           Val Loss:   nan





Epoch 06 ▶ Train Loss: nan


                                                                             

           Val Loss:   nan

↳ Early stopping at epoch 6




In [None]:
# ─── Cell 0: install Optuna ─────────────────────────────────────────────────
!pip install optuna


Collecting optuna
  Downloading optuna-4.3.0-py3-none-any.whl.metadata (17 kB)
Collecting alembic>=1.5.0 (from optuna)
  Downloading alembic-1.15.2-py3-none-any.whl.metadata (7.3 kB)
Collecting colorlog (from optuna)
  Downloading colorlog-6.9.0-py3-none-any.whl.metadata (10 kB)
Downloading optuna-4.3.0-py3-none-any.whl (386 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m386.6/386.6 kB[0m [31m10.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading alembic-1.15.2-py3-none-any.whl (231 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m231.9/231.9 kB[0m [31m20.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading colorlog-6.9.0-py3-none-any.whl (11 kB)
Installing collected packages: colorlog, alembic, optuna
Successfully installed alembic-1.15.2 colorlog-6.9.0 optuna-4.3.0


In [None]:
# ─── Cell 7: HYPERPARAMETER SWEEP (Optuna) ─────────────────────────────────────
import optuna

def objective(trial):
    lr    = trial.suggest_loguniform("lr", 1e-5, 1e-3)
    wd    = trial.suggest_loguniform("wd", 1e-6, 1e-2)
    alpha = trial.suggest_uniform("mixup_alpha", 0.0, 1.0)
    aw    = trial.suggest_uniform("aux_weight", 0.0, 1.0)

    model = ViTUNetDS().to(device)
    opt   = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="min", factor=0.5, patience=2)
    loss_fn = nn.MSELoss()
    scaler  = GradScaler()

    for _ in range(3):
        model.train()
        for imgs, weather, masks in train_loader:
            imgs, weather, masks = imgs.to(device), weather.to(device), masks.to(device)
            xim, wim, yim, _     = mixup_data(imgs, masks, weather, alpha)
            opt.zero_grad()
            with autocast(device_type="cuda"):
                pm, pa3, pa2 = model(xim, wim)
                l0 = loss_fn(pm, yim)
                l3 = loss_fn(pa3, yim)
                l2 = loss_fn(pa2, yim)
                loss = l0 + aw*(l3 + l2)
            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()
        # val
        model.eval()
        v_loss = 0
        with torch.no_grad(), autocast(device_type="cuda"):
            for imgs, weather, masks in val_loader:
                imgs, weather, masks = imgs.to(device), weather.to(device), masks.to(device)
                pm, pa3, pa2 = model(imgs, weather)
                l0 = loss_fn(pm, masks)
                l3 = loss_fn(pa3, masks)
                l2 = loss_fn(pa2, masks)
                v_loss += (l0 + aw*(l3 + l2)).item() * imgs.size(0)
        avg_v = v_loss / len(val_loader.dataset)
        sched.step(avg_v)

    return avg_v

study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=20)
print("Best params:", study.best_params)
print("Best val:", study.best_value)


[I 2025-04-26 13:35:53,693] A new study created in memory with name: no-name-63810d19-d221-4a8f-9f54-7aa424ce2cef
[W 2025-04-26 13:35:54,203] Trial 0 failed with parameters: {'lr': 2.5672848672791938e-05, 'wd': 1.0595219092334e-06, 'mixup_alpha': 0.16335933798614355, 'aux_weight': 0.18221419820502227} because of the following error: ValueError('not enough values to unpack (expected 4, got 3)').
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/optuna/study/_optimize.py", line 197, in _run_trial
    value_or_values = func(trial)
                      ^^^^^^^^^^^
  File "<ipython-input-24-f4aceaff4921>", line 20, in objective
    xim, wim, yim, _     = mixup_data(imgs, masks, weather, alpha)
    ^^^^^^^^^^^^^^^^
ValueError: not enough values to unpack (expected 4, got 3)
[W 2025-04-26 13:35:54,215] Trial 0 failed with value None.


ValueError: not enough values to unpack (expected 4, got 3)