In [None]:
from google.colab import drive
# Mount Google Drive so the CSV, patch TIFFs and checkpoints under MyDrive are readable.
drive.mount(mountpoint='/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# === 1) Mount Drive & imports ===
from google.colab import drive
drive.mount('/content/drive')

import os, math, time
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split

# install timm and torchvision if missing
!pip install timm torchvision

import timm
from torchvision import transforms
!pip install timm torchvision tifffile --quiet



Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# 0) Install dependencies (including LZW support for tifffile)
!pip install timm torchvision tifffile imagecodecs --quiet


In [None]:
# 0) Install dependencies
!pip install timm torchvision rasterio --quiet


In [None]:
import os
import math
import time
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn import MSELoss

import timm
from tifffile import imread        # ← lightweight TIFF reader
from torchvision import transforms
from PIL import Image


In [None]:
class LSTDataset(Dataset):
    """Patch dataset returning ``(image, weather, target)`` triples.

    Each row of ``df`` names a TIFF under ``patches_dir``:
      * bands 2-4 (indices 1-3) form the RGB input, resized to 224x224 and
        normalized with ImageNet mean/std;
      * band 1 (index 0) is the LST target, returned as ``[1, H, W]`` at the
        patch's native size;
      * the row's ``weather_cols`` form the float32 weather vector.
    """

    def __init__(self, df, patches_dir, weather_cols):
        self.df           = df.reset_index(drop=True)
        self.patches_dir  = patches_dir
        self.weather_cols = weather_cols
        self.transform    = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224,224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485,0.456,0.406],
                std =[0.229,0.224,0.225],
            ),
        ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row  = self.df.loc[idx]
        path = os.path.join(self.patches_dir, row["filename"])

        # read all bands via tifffile.
        # NOTE(review): the indexing below assumes (bands, H, W); pixel-interleaved
        # TIFFs come back as (H, W, bands) — confirm on a sample patch.
        arr = imread(path).astype(np.float32)

        # ── bands 2,3,4 for RGB input ──
        # Sanitize before the uint8 cast: NumPy's float->uint8 conversion of NaN
        # or out-of-range values is undefined/platform-dependent, so unclipped
        # values would silently turn into garbage pixels.
        rgb    = np.nan_to_num(arr[[1, 2, 3], :, :], nan=0.0)
        img_np = np.clip(rgb, 0, 255).transpose(1, 2, 0).astype(np.uint8)
        img    = self.transform(img_np)

        # ── band 1 as LST target ──
        tar_np = arr[0, :, :]
        target = torch.tensor(tar_np, dtype=torch.float32).unsqueeze(0)   # [1,H,W]

        # ── weather as float32 tensor ──
        weather = torch.tensor(
            row[self.weather_cols].values.astype(np.float32)
        )

        return img, weather, target


In [None]:
class PretrainedViTLSTModel(nn.Module):
    """Frozen pretrained ViT + weather-conditioned token fusion + deconv head.

    The ViT's patch tokens and one projected weather token are fused by a small
    Transformer encoder. The fused patch tokens are then folded back into a
    G x G grid, and a single ConvTranspose2d (kernel = stride = patch size)
    maps that grid to a 1-channel output map.

    Args:
        weather_dim: length of the weather vector passed to ``forward``.
        hidden_dim:  token width; must equal the ViT's embedding width.
        vit_name:    timm model name of the backbone.
        num_layers:  number of fusion Transformer encoder layers.
        num_heads:   attention heads per fusion layer.

    Raises:
        ValueError: if ``hidden_dim`` differs from the ViT's ``num_features``.
    """

    def __init__(self,
                 weather_dim: int = 5,
                 hidden_dim:  int = 768,
                 vit_name:    str = "vit_base_patch16_224",
                 num_layers:  int = 2,
                 num_heads:   int = 8):
        super().__init__()
        # pretrained ViT (no head), kept frozen
        self.vit = timm.create_model(
            vit_name,
            pretrained=True,
            num_classes=0
        )
        for p in self.vit.parameters():
            p.requires_grad = False

        # Fail here rather than with an opaque shape error inside forward().
        vit_dim = getattr(self.vit, "num_features", hidden_dim)
        if vit_dim != hidden_dim:
            raise ValueError(
                f"hidden_dim={hidden_dim} does not match the ViT embedding width {vit_dim}"
            )

        # Number of leading non-patch tokens (CLS and, for some backbones,
        # distillation/register tokens) to strip from forward_features().
        self.num_prefix_tokens = getattr(self.vit, "num_prefix_tokens", 1)

        # weather → embedding
        self.weather_proj = nn.Linear(weather_dim, hidden_dim)

        # small Transformer to fuse tokens + weather (sequence-first layout)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim*4,
            dropout=0.1
        )
        self.transformer = nn.TransformerEncoder(enc_layer, num_layers)

        # deconv head back to 1‑channel map
        p = getattr(self.vit.patch_embed, "patch_size", 16)
        self.deconv = nn.ConvTranspose2d(hidden_dim, 1,
                                         kernel_size=p, stride=p)

    def forward(self, images, weather):
        """Map ``images`` [B,3,H,W] and ``weather`` [B,weather_dim] to [B,1,G*p,G*p].

        Raises:
            ValueError: if the number of patch tokens is not a perfect square.
        """
        feats  = self.vit.forward_features(images)            # [B,P+N,D]
        tokens = feats[:, self.num_prefix_tokens:, :]         # [B,N,D] patch tokens only
        w_emb  = self.weather_proj(weather).unsqueeze(1)      # [B,1,D]
        tkns   = torch.cat([tokens, w_emb], dim=1)            # [B,N+1,D]

        t = self.transformer(tkns.permute(1, 0, 2))           # encoder takes [seq,B,D]
        t = t.permute(1, 0, 2)                                # [B,N+1,D]

        patch_out = t[:, :-1, :]                              # drop weather token → [B,N,D]
        B, N, D   = patch_out.shape
        G         = math.isqrt(N)
        if G * G != N:
            raise ValueError(f"expected a square patch grid, got {N} patch tokens")
        x = patch_out.transpose(1, 2).reshape(B, D, G, G)     # [B,D,G,G]

        return self.deconv(x)                                 # [B,1,G*p,G*p]


In [None]:
# Use the GPU when the runtime has one, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Meteorological features; this order is the order of the weather vector.
weather_cols = [
    "air_temp_C",
    "dew_point_C",
    "relative_humidity_percent",
    "wind_speed_m_s",
    "precipitation_in",
]

model = PretrainedViTLSTModel(
    vit_name="vit_base_patch16_224",
    weather_dim=len(weather_cols),
    hidden_dim=768,
    num_heads=8,
    num_layers=2,
)
model = model.to(device)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [None]:
# — run once and check the output in Colab —
df          = pd.read_csv("/content/drive/MyDrive/PatchedOutput/tiff_with_meteo.csv")
print(df.columns)

# Fail fast if the CSV lacks a column the dataset will index.
missing = [c for c in ["filename", *weather_cols] if c not in df.columns]
assert not missing, f"CSV is missing expected columns: {missing}"


Index(['filename', 'date', 'datetime', 'air_temp_C', 'dew_point_C',
       'relative_humidity_percent', 'wind_speed_m_s', 'precipitation_in'],
      dtype='object')


In [None]:
# — read your patched CSV —
df = pd.read_csv("/content/drive/MyDrive/PatchedOutput/tiff_with_meteo.csv")

# — ensure weather columns are floats (unparseable entries become NaN) —
for col in weather_cols:
    df[col] = pd.to_numeric(df[col], errors="coerce")

# — drop any rows with missing weather or filename —
df = df.dropna(subset=weather_cols + ["filename"]).reset_index(drop=True)

patches_dir = "/content/drive/MyDrive/PatchedOutput_Cleaned"

# — dataset & split —
# Seeded split: without it every runtime restart draws a new 80/20 split, so a
# run resumed from a checkpoint would validate on patches it already trained on.
SPLIT_SEED = 42
dataset  = LSTDataset(df, patches_dir, weather_cols)
train_sz = int(0.8 * len(dataset))
val_sz   = len(dataset) - train_sz
train_ds, val_ds = random_split(
    dataset, [train_sz, val_sz],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# — DataLoaders —
train_loader = DataLoader(
    train_ds,
    batch_size=8,      # smaller batch
    shuffle=True,
    num_workers=0,     # safest: no background workers
    pin_memory=False   # also turn off pin_memory
)
val_loader = DataLoader(
    val_ds,
    batch_size=8,
    shuffle=False,
    num_workers=0,
    pin_memory=False
)



In [None]:
# Adam over the trainable parameters only (the ViT backbone is frozen).
opt = optim.Adam(
    (p for p in model.parameters() if p.requires_grad),
    lr=1e-4,
    weight_decay=1e-5,
)
loss_fn = MSELoss()
# Halve the learning rate after 3 epochs without validation-loss improvement.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    opt,
    mode="min",
    factor=0.5,
    patience=3,
    verbose=True,
)




In [None]:
ckpt_dir  = "/content/drive/MyDrive/checkpointsViT"
os.makedirs(ckpt_dir, exist_ok=True)
ckpt_path = os.path.join(ckpt_dir, "vit_lstm_ckpt.pth")

def save_ckpt(epoch):
    """Write model, optimizer and scheduler state for ``epoch`` to ``ckpt_path``."""
    torch.save({
        "epoch":       epoch,
        "model_state": model.state_dict(),
        "opt_state":   opt.state_dict(),
        "sched_state": scheduler.state_dict()
    }, ckpt_path)
    print(f"✔ Saved checkpoint at epoch {epoch}")

def load_ckpt():
    """Restore model/optimizer/scheduler state from ``ckpt_path`` if it exists.

    Returns:
        The epoch to resume from (last saved epoch + 1), or 0 when there is no
        checkpoint.
    """
    if os.path.isfile(ckpt_path):
        # map_location remaps tensors onto the current device, so a checkpoint
        # written on a GPU runtime also loads on a CPU-only one (and vice versa).
        ck = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(ck["model_state"])
        opt.load_state_dict(ck["opt_state"])
        scheduler.load_state_dict(ck["sched_state"])
        print(f"→ Resuming from epoch {ck['epoch']}")
        return ck["epoch"] + 1
    return 0


In [None]:
# Train/validate for `num_epochs`, resuming from the Drive checkpoint when one exists,
# and save a checkpoint after every epoch.
from tqdm import tqdm
import torch.nn.functional as F  # for F.interpolate

num_epochs = 5
# 0 on a fresh run; otherwise the last saved epoch + 1 (see load_ckpt).
start_ep   = load_ckpt()

for epoch in range(start_ep, num_epochs):
    # — Train —
    model.train()
    train_loss = 0.0
    train_bar  = tqdm(train_loader, desc=f"Epoch {epoch:02d} Train", ascii=True)
    for imgs, weather, tgt in train_bar:
        imgs, weather, tgt = imgs.to(device), weather.to(device), tgt.to(device)

        opt.zero_grad()
        out = model(imgs, weather)
        # ─── resize output to match tgt H×W ─────────────────────────
        # The head emits G*p x G*p maps; the target may keep the patch's own size.
        if out.shape[2:] != tgt.shape[2:]:
            out = F.interpolate(out, size=tgt.shape[2:], mode='bilinear', align_corners=False)
        # ──────────────────────────────────────────────────────────────

        loss = loss_fn(out, tgt)
        loss.backward()
        opt.step()

        # Sample-weighted sum, so dividing by the dataset size gives a per-sample mean.
        train_loss += loss.item() * imgs.size(0)
        train_bar.set_postfix(batch_loss=f"{loss.item():.4f}")

    train_loss /= len(train_loader.dataset)

    # — Validate —
    model.eval()
    val_loss = 0.0
    val_bar  = tqdm(val_loader, desc=f"Epoch {epoch:02d}  Val ", ascii=True)
    with torch.no_grad():
        for imgs, weather, tgt in val_bar:
            imgs, weather, tgt = imgs.to(device), weather.to(device), tgt.to(device)

            out = model(imgs, weather)
            if out.shape[2:] != tgt.shape[2:]:
                out = F.interpolate(out, size=tgt.shape[2:], mode='bilinear', align_corners=False)

            batch_loss = loss_fn(out, tgt).item()
            val_loss   += batch_loss * imgs.size(0)
            val_bar.set_postfix(batch_loss=f"{batch_loss:.4f}")

    val_loss /= len(val_loader.dataset)

    # — Scheduler & Logging —
    # ReduceLROnPlateau monitors the epoch's validation loss.
    scheduler.step(val_loss)
    print(f"Epoch {epoch:02d} | Train: {train_loss:.4f} | Val: {val_loss:.4f}")
    save_ckpt(epoch)

print("✅ Training finished")


Epoch 00 Train: 100%|##########| 1429/1429 [39:35<00:00,  1.66s/it, batch_loss=2.8788]
Epoch 00  Val : 100%|##########| 358/358 [09:50<00:00,  1.65s/it, batch_loss=1.3212]


Epoch 00 | Train: 25.3160 | Val: 9.8544
✔ Saved checkpoint at epoch 0


Epoch 01 Train: 100%|##########| 1429/1429 [04:44<00:00,  5.02it/s, batch_loss=1.1887]
Epoch 01  Val : 100%|##########| 358/358 [00:55<00:00,  6.44it/s, batch_loss=1.8620]


Epoch 01 | Train: 6.7949 | Val: 5.5388
✔ Saved checkpoint at epoch 1


Epoch 02 Train: 100%|##########| 1429/1429 [04:49<00:00,  4.93it/s, batch_loss=0.8701]
Epoch 02  Val : 100%|##########| 358/358 [00:55<00:00,  6.48it/s, batch_loss=1.5197]


Epoch 02 | Train: 3.6407 | Val: 3.0290
✔ Saved checkpoint at epoch 2


Epoch 03 Train: 100%|##########| 1429/1429 [04:48<00:00,  4.96it/s, batch_loss=0.6319]
Epoch 03  Val : 100%|##########| 358/358 [00:55<00:00,  6.48it/s, batch_loss=2.0159]


Epoch 03 | Train: 3.7118 | Val: 2.4946
✔ Saved checkpoint at epoch 3


Epoch 04 Train: 100%|##########| 1429/1429 [04:49<00:00,  4.94it/s, batch_loss=0.8953]
Epoch 04  Val : 100%|##########| 358/358 [00:55<00:00,  6.45it/s, batch_loss=1.0777]


Epoch 04 | Train: 1.8198 | Val: 1.2943
✔ Saved checkpoint at epoch 4
✅ Training finished


Here’s what the two loss metrics in the tqdm bars of the next cell mean:

- **b-loss**: the **batch loss** of the *current* mini-batch, i.e. the value of `loss.item()` computed just before backpropagation.
- **avg**: the **running average loss** over all batches seen so far in the epoch, weighted by batch size:
  $$
    \text{avg} = \frac{\sum_{\text{batches so far}} \text{batch loss} \times \text{batch size}}{\text{number of samples seen so far}}
  $$
  It shows how the training (or validation) loss is trending as the epoch progresses.

In [None]:
import os
import torch
from tqdm import tqdm
import torch.nn.functional as F  # for interpolation

CKPT_PATH    = "checkpoint.pth"
TOTAL_EPOCHS = 10  # total epochs you want to train
# NOTE(review): CKPT_PATH is relative to the working directory (typically /content
# on Colab), which does not survive a runtime reset — consider a path on Drive.


def save_ckpt(epoch):
    """Save model, optimizer and scheduler state for ``epoch`` to ``CKPT_PATH``."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': opt.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
    }, CKPT_PATH)


def load_ckpt():
    """Restore state from ``CKPT_PATH`` if present.

    Returns:
        The epoch to resume from (last saved epoch + 1), or 0 without a checkpoint.
    """
    if os.path.exists(CKPT_PATH):
        ckpt = torch.load(CKPT_PATH, map_location=device)
        model.load_state_dict(ckpt['model_state_dict'])
        opt.load_state_dict(ckpt['optimizer_state_dict'])
        scheduler.load_state_dict(ckpt['scheduler_state_dict'])
        # resume at the next epoch
        return ckpt['epoch'] + 1
    return 0


def _match_target_size(out, tgt):
    """Bilinearly resize ``out`` to ``tgt``'s spatial size when they differ."""
    if out.shape[2:] != tgt.shape[2:]:
        out = F.interpolate(out,
                            size=tgt.shape[2:],
                            mode='bilinear',
                            align_corners=False)
    return out


start_ep = load_ckpt()

for epoch in range(start_ep, TOTAL_EPOCHS):
    # — Train —
    model.train()
    train_loss, seen = 0.0, 0
    train_bar = tqdm(train_loader, desc=f"Epoch {epoch:02d} ▶ Train", ascii=True)
    for batch_idx, (imgs, weather, tgt) in enumerate(train_bar, 1):
        imgs, weather, tgt = imgs.to(device), weather.to(device), tgt.to(device)

        opt.zero_grad()
        out  = _match_target_size(model(imgs, weather), tgt)
        loss = loss_fn(out, tgt)
        loss.backward()
        opt.step()

        b_loss      = loss.item()
        train_loss += b_loss * imgs.size(0)
        seen       += imgs.size(0)
        # Divide by samples actually seen: batch_idx * batch_size overcounts
        # when the final batch is short, skewing the running average.
        avg_loss    = train_loss / seen

        train_bar.set_postfix({
            'batch': f"{batch_idx}/{len(train_loader)}",
            'b-loss': f"{b_loss:.4f}",
            'avg':    f"{avg_loss:.4f}"
        })

    train_loss /= len(train_loader.dataset)

    # — Validate —
    model.eval()
    val_loss, seen_val = 0.0, 0
    val_bar  = tqdm(val_loader, desc=f"Epoch {epoch:02d} ◀ Val  ", ascii=True)
    with torch.no_grad():
        for batch_idx, (imgs, weather, tgt) in enumerate(val_bar, 1):
            imgs, weather, tgt = imgs.to(device), weather.to(device), tgt.to(device)

            out = _match_target_size(model(imgs, weather), tgt)

            batch_loss = loss_fn(out, tgt).item()
            val_loss   += batch_loss * imgs.size(0)
            seen_val   += imgs.size(0)
            avg_val    = val_loss / seen_val

            val_bar.set_postfix({
                'batch': f"{batch_idx}/{len(val_loader)}",
                'b-loss': f"{batch_loss:.4f}",
                'avg':    f"{avg_val:.4f}"
            })

    val_loss /= len(val_loader.dataset)

    # — Scheduler & Checkpoint —
    scheduler.step(val_loss)
    print(f"Epoch {epoch:02d} Summary → Train: {train_loss:.4f} | Val: {val_loss:.4f}")
    save_ckpt(epoch)

print("✅ Training finished")


Epoch 00 ▶ Train: 100%|##########| 1429/1429 [04:42<00:00,  5.05it/s, batch=1429/1429, b-loss=3.4007, avg=1.7620]
Epoch 00 ◀ Val  : 100%|##########| 358/358 [00:55<00:00,  6.45it/s, batch=358/358, b-loss=1.3964, avg=3.8357]


Epoch 00 Summary → Train: 1.7620 | Val: 3.8438


Epoch 01 ▶ Train: 100%|##########| 1429/1429 [04:44<00:00,  5.02it/s, batch=1429/1429, b-loss=0.6000, avg=1.5009]
Epoch 01 ◀ Val  : 100%|##########| 358/358 [00:56<00:00,  6.37it/s, batch=358/358, b-loss=1.2923, avg=1.5352]


Epoch 01 Summary → Train: 1.5009 | Val: 1.5384


Epoch 02 ▶ Train: 100%|##########| 1429/1429 [04:43<00:00,  5.04it/s, batch=1429/1429, b-loss=0.5604, avg=1.4002]
Epoch 02 ◀ Val  : 100%|##########| 358/358 [00:55<00:00,  6.43it/s, batch=358/358, b-loss=1.5912, avg=1.3513]


Epoch 02 Summary → Train: 1.4002 | Val: 1.3542


Epoch 03 ▶ Train: 100%|##########| 1429/1429 [04:42<00:00,  5.05it/s, batch=1429/1429, b-loss=0.9012, avg=0.8043]
Epoch 03 ◀ Val  : 100%|##########| 358/358 [00:56<00:00,  6.35it/s, batch=358/358, b-loss=1.3241, avg=1.3961]


Epoch 03 Summary → Train: 0.8043 | Val: 1.3990


Epoch 04 ▶ Train: 100%|##########| 1429/1429 [04:42<00:00,  5.06it/s, batch=1429/1429, b-loss=0.2675, avg=0.6305]
Epoch 04 ◀ Val  : 100%|##########| 358/358 [00:55<00:00,  6.45it/s, batch=358/358, b-loss=1.3185, avg=0.9852]


Epoch 04 Summary → Train: 0.6305 | Val: 0.9873


Epoch 05 ▶ Train: 100%|##########| 1429/1429 [04:41<00:00,  5.07it/s, batch=1429/1429, b-loss=0.4374, avg=0.5946]
Epoch 05 ◀ Val  : 100%|##########| 358/358 [00:55<00:00,  6.48it/s, batch=358/358, b-loss=1.4934, avg=1.2413]


Epoch 05 Summary → Train: 0.5946 | Val: 1.2439


Epoch 06 ▶ Train: 100%|##########| 1429/1429 [04:43<00:00,  5.04it/s, batch=1429/1429, b-loss=0.2204, avg=0.6138]
Epoch 06 ◀ Val  : 100%|##########| 358/358 [00:55<00:00,  6.44it/s, batch=358/358, b-loss=1.0285, avg=0.7208]


Epoch 06 Summary → Train: 0.6138 | Val: 0.7223


Epoch 07 ▶ Train: 100%|##########| 1429/1429 [04:42<00:00,  5.07it/s, batch=1429/1429, b-loss=0.4487, avg=0.5894]
Epoch 07 ◀ Val  : 100%|##########| 358/358 [00:55<00:00,  6.47it/s, batch=358/358, b-loss=1.9437, avg=1.4219]


Epoch 07 Summary → Train: 0.5894 | Val: 1.4249


Epoch 08 ▶ Train: 100%|##########| 1429/1429 [04:41<00:00,  5.08it/s, batch=1429/1429, b-loss=0.4082, avg=0.5761]
Epoch 08 ◀ Val  : 100%|##########| 358/358 [00:55<00:00,  6.50it/s, batch=358/358, b-loss=1.0682, avg=0.8165]


Epoch 08 Summary → Train: 0.5761 | Val: 0.8182


Epoch 09 ▶ Train: 100%|##########| 1429/1429 [04:42<00:00,  5.06it/s, batch=1429/1429, b-loss=25.0908, avg=0.5526]
Epoch 09 ◀ Val  : 100%|##########| 358/358 [00:55<00:00,  6.50it/s, batch=358/358, b-loss=2.0534, avg=6.8430]


Epoch 09 Summary → Train: 0.5526 | Val: 6.8574
✅ Training finished


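We add random crops, flips, rotations and color jitter to expose the model to more varied inputs, which helps it generalize rather than memorize the exact patches. This augmentation reduces overfitting and makes the ViT-based model more robust to real-world variability. Because this is dense (per-pixel) regression, the geometric transforms (crops, flips, rotations) must be applied identically to the LST target; otherwise image and target pixels no longer line up and the model is trained on mismatched labels.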

In [None]:
from torchvision import transforms

class LSTDataset(Dataset):
    """Patch dataset returning ``(image, weather, target)`` triples.

    * image:   bands 2-4 of the TIFF as RGB, ``[3, out_size, out_size]``,
               normalized with ImageNet mean/std.
    * target:  band 1 (LST), ``[1, out_size, out_size]``.
    * weather: the row's ``weather_cols`` as a float32 vector.

    Geometric augmentation (random resized crop, horizontal flip, small
    rotation) is applied to image and target *together*, so every target pixel
    stays aligned with the image pixel it labels. Color jitter is photometric
    and touches the image only.

    Args:
        df: table with a ``filename`` column plus ``weather_cols``.
        patches_dir: directory holding the TIFF patches.
        weather_cols: columns forming the weather vector.
        augment: apply random augmentation; pass ``False`` for validation/test
            data so evaluation is deterministic. Defaults to ``True``.
        out_size: side length of the returned image and target. Defaults to 224.
    """

    def __init__(self, df, patches_dir, weather_cols, augment=True, out_size=224):
        self.df           = df.reset_index(drop=True)
        self.patches_dir  = patches_dir
        self.weather_cols = weather_cols
        self.augment      = augment
        self.out_size     = out_size
        # geometric augmentation parameters (shared by image and target)
        self.crop_scale   = (0.8, 1.0)
        self.crop_ratio   = (3 / 4, 4 / 3)
        self.max_degrees  = 15.0
        # photometric steps, image only
        self.color_jitter = transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2)
        self.normalize    = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std =[0.229, 0.224, 0.225],
        )

    def __len__(self):
        return len(self.df)

    def _geometric(self, stack):
        """Resize a ``[C, H, W]`` stack to ``out_size``; when augmenting, first
        apply one random crop/flip/rotation shared by all channels."""
        TF   = transforms.functional
        mode = transforms.InterpolationMode.BILINEAR
        size = [self.out_size, self.out_size]
        if not self.augment:
            return TF.resize(stack, size, interpolation=mode, antialias=True)

        i, j, h, w = transforms.RandomResizedCrop.get_params(stack, self.crop_scale, self.crop_ratio)
        stack = TF.resized_crop(stack, i, j, h, w, size, interpolation=mode, antialias=True)
        if torch.rand(1).item() < 0.5:
            stack = TF.hflip(stack)

        angle = transforms.RandomRotation.get_params([-self.max_degrees, self.max_degrees])
        stack = TF.rotate(stack, angle, interpolation=mode)
        # Rotation zero-fills the corners, which would become fake LST=0 labels.
        # Keep the largest centred square lying fully inside the rotated content
        # (side = S / (cos θ + sin θ)) and scale it back up.
        theta = math.radians(abs(angle))
        inner = max(1, int(self.out_size / (math.cos(theta) + math.sin(theta))))
        stack = TF.center_crop(stack, [inner, inner])
        return TF.resize(stack, size, interpolation=mode, antialias=True)

    def __getitem__(self, idx):
        row  = self.df.loc[idx]
        path = os.path.join(self.patches_dir, row["filename"])

        # NOTE(review): indexing assumes imread yields (bands, H, W); pixel-interleaved
        # TIFFs come back as (H, W, bands) — confirm on a sample patch.
        arr = imread(path).astype(np.float32)

        # RGB from bands 2-4, scaled to [0, 1] as ToTensor() would. Sanitize before
        # the uint8 cast: NaN/out-of-range float->uint8 casts are undefined in NumPy.
        rgb    = np.clip(np.nan_to_num(arr[[1, 2, 3], :, :], nan=0.0), 0, 255).astype(np.uint8)
        img    = torch.from_numpy(rgb).float().div_(255.0)          # [3,H,W]
        target = torch.from_numpy(arr[0:1, :, :])                   # [1,H,W], band 1 = LST

        # One geometric transform for both tensors keeps them pixel-aligned.
        stack       = self._geometric(torch.cat([img, target], dim=0))   # [4,S,S]
        img, target = stack[:3], stack[3:]

        if self.augment:
            img = self.color_jitter(img)
        img = self.normalize(img)

        weather = torch.tensor(
            row[self.weather_cols].values.astype(np.float32)
        )

        return img, weather, target


In [None]:
# Cell 1: Imports & setup
from google.colab import drive
drive.mount('/content/drive')

import os, math, time
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F           # ← added
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
import timm
from tifffile import imread

!pip install timm torchvision


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Cell 5: Updated Dataset (now resizes target as well)
class LSTDataset(Dataset):
    def __init__(self, df, patches_dir, weather_cols):
        self.df           = df.reset_index(drop=True)
        self.patches_dir  = patches_dir
        self.weather_cols = weather_cols
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomResizedCrop(224, scale=(0.8,1.0)),
            transforms.RandomHorizontalFlip(0.5),
            transforms.RandomRotation(15),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485,0.456,0.406],
                                 std =[0.229,0.224,0.225]),
        ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row  = self.df.loc[idx]
        path = os.path.join(self.patches_dir, row["filename"])

        arr    = imread(path).astype(np.float32)            # (bands,H,W)
        img_np = arr[[1,2,3],:,:].transpose(1,2,0).astype(np.uint8)
        img    = self.transform(img_np)                     # → [3,224,224]

        # original target at its native size
        tar_np = arr[0,:,:]
        target = torch.tensor(tar_np, dtype=torch.float32).unsqueeze(0)   # [1,H,W]
        # resize to 224×224 so it matches your head’s output
        target = F.interpolate(
            target.unsqueeze(0), size=(224,224),
            mode='bilinear', align_corners=False
        ).squeeze(0)                                          # [1,224,224]

        weather = torch.tensor(row[self.weather_cols].values.astype(np.float32))
        return img, weather, target


In [None]:
# Cell 9: Read CSV, split, and create DataLoaders
df = pd.read_csv("/content/drive/MyDrive/PatchedOutput/tiff_with_meteo.csv")
for col in weather_cols:
    df[col] = pd.to_numeric(df[col], errors="coerce")
df = df.dropna(subset=weather_cols + ["filename"]).reset_index(drop=True)

patches_dir = "/content/drive/MyDrive/PatchedOutput_Cleaned"
dataset     = LSTDataset(df, patches_dir, weather_cols)
train_sz    = int(0.8 * len(dataset))
val_sz      = len(dataset) - train_sz
# Seeded so the split is identical across sessions; otherwise a run resumed
# from a checkpoint would validate on patches it previously trained on.
train_ds, val_ds = random_split(
    dataset, [train_sz, val_sz],
    generator=torch.Generator().manual_seed(42),
)
# NOTE(review): val_ds is a view of the same dataset object, so any random
# augmentation it applies is also applied during validation, making val loss
# noisy. Consider a separate, non-augmented dataset instance for validation.

train_loader = DataLoader(train_ds, batch_size=8, shuffle=True,  num_workers=0, pin_memory=False)
val_loader   = DataLoader(val_ds,   batch_size=8, shuffle=False, num_workers=0, pin_memory=False)


In [None]:
# Cell 10: Optimizer, loss function, scheduler
# NOTE(review): the optimizer binds to the parameters of the `model` that exists
# *when this cell runs*; re-run it whenever the model is re-instantiated.
opt = optim.AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr=1e-4,
    weight_decay=1e-2,
)
# Smooth L1 loss: per PyTorch docs, less sensitive to outliers than MSE.
loss_fn = nn.SmoothL1Loss()
# Halve the learning rate after 3 epochs without validation-loss improvement.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    opt,
    mode="min",
    factor=0.5,
    patience=3,
    verbose=True,
)




In [None]:
# Cell 7: Device setup & model instantiation
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model  = PretrainedViTLSTModel(
    weather_dim=len(weather_cols),
    hidden_dim=768,
    vit_name="vit_base_patch16_224",
    num_layers=2,
    num_heads=8
).to(device)

# A new model has new parameter tensors. An optimizer created earlier (the
# "Cell 10" cell runs before this one) still references the previous model's
# parameters, so opt.step() would never update this model and the loss would
# stay flat. Re-create the optimizer and scheduler against this instance.
opt       = optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=1e-4, weight_decay=1e-2
)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    opt, mode='min', factor=0.5, patience=3, verbose=True
)




In [None]:
# Cell 12: Training & validation loop WITH real‑time avg loss (fresh start)
from tqdm import tqdm

num_epochs = 20
start_ep   = 0          # ← always start at 0

# Guard against a stale optimizer. If `model` was re-created after `opt` was
# built, opt.step() updates the old model's tensors and the loss never moves.
_trainable_ids = {id(p) for p in model.parameters() if p.requires_grad}
_opt_ids       = {id(p) for group in opt.param_groups for p in group["params"]}
if not _opt_ids or not _opt_ids <= _trainable_ids:
    raise RuntimeError(
        "`opt` does not hold the current model's trainable parameters — "
        "re-create the optimizer and scheduler after instantiating the model."
    )

for epoch in range(start_ep, num_epochs):
    # — Train —
    model.train()
    train_loss    = 0.0
    seen_samples  = 0
    train_bar     = tqdm(train_loader, desc=f"Epoch {epoch:02d} Train")
    for batch_idx, (imgs, weather, tgt) in enumerate(train_bar):
        imgs, weather, tgt = imgs.to(device), weather.to(device), tgt.to(device)

        opt.zero_grad()
        out   = model(imgs, weather)
        # Match the target's H×W if a dataset version returns native-size targets.
        if out.shape[2:] != tgt.shape[2:]:
            out = F.interpolate(out, size=tgt.shape[2:], mode='bilinear', align_corners=False)
        loss  = loss_fn(out, tgt)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()

        # update running sums
        batch_val     = loss.item()
        n             = imgs.size(0)
        train_loss   += batch_val * n
        seen_samples += n
        avg_train     = train_loss / seen_samples

        train_bar.set_postfix(
            batch_loss=f"{batch_val:.4f}",
            avg_loss  =f"{avg_train:.4f}"
        )

    train_loss /= len(train_loader.dataset)

    # — Validate —
    model.eval()
    val_loss    = 0.0
    seen_val    = 0
    val_bar     = tqdm(val_loader, desc=f"Epoch {epoch:02d}   Val ")
    with torch.no_grad():
        for imgs, weather, tgt in val_bar:
            imgs, weather, tgt = imgs.to(device), weather.to(device), tgt.to(device)
            out        = model(imgs, weather)
            if out.shape[2:] != tgt.shape[2:]:
                out = F.interpolate(out, size=tgt.shape[2:], mode='bilinear', align_corners=False)
            batch_val  = loss_fn(out, tgt).item()
            n          = imgs.size(0)
            val_loss  += batch_val * n
            seen_val  += n
            avg_val    = val_loss / seen_val

            val_bar.set_postfix(
                batch_loss=f"{batch_val:.4f}",
                avg_loss  =f"{avg_val:.4f}"
            )

    val_loss /= len(val_loader.dataset)

    scheduler.step(val_loss)
    print(f"Epoch {epoch:02d} | Train: {train_loss:.4f} | Val: {val_loss:.4f}")
    save_ckpt(epoch)

print("✅ Training finished")


Epoch 00 Train: 100%|██████████| 1429/1429 [05:14<00:00,  4.54it/s, avg_loss=4.6777, batch_loss=4.9709]
Epoch 00   Val : 100%|██████████| 358/358 [01:04<00:00,  5.51it/s, avg_loss=4.7516, batch_loss=3.0185]


Epoch 00 | Train: 4.6777 | Val: 4.7516


Epoch 01 Train: 100%|██████████| 1429/1429 [05:13<00:00,  4.55it/s, avg_loss=4.6767, batch_loss=3.6956]
Epoch 01   Val : 100%|██████████| 358/358 [01:05<00:00,  5.46it/s, avg_loss=4.7568, batch_loss=3.3660]


Epoch 01 | Train: 4.6767 | Val: 4.7568


Epoch 02 Train: 100%|██████████| 1429/1429 [05:13<00:00,  4.56it/s, avg_loss=4.6833, batch_loss=4.4240]
Epoch 02   Val : 100%|██████████| 358/358 [01:05<00:00,  5.49it/s, avg_loss=4.7533, batch_loss=3.6487]


Epoch 02 | Train: 4.6833 | Val: 4.7533


Epoch 03 Train: 100%|██████████| 1429/1429 [05:13<00:00,  4.56it/s, avg_loss=4.6897, batch_loss=3.5375]
Epoch 03   Val : 100%|██████████| 358/358 [01:04<00:00,  5.52it/s, avg_loss=4.7309, batch_loss=3.1766]


Epoch 03 | Train: 4.6897 | Val: 4.7309


Epoch 04 Train: 100%|██████████| 1429/1429 [05:13<00:00,  4.55it/s, avg_loss=4.6853, batch_loss=6.0677]
Epoch 04   Val : 100%|██████████| 358/358 [01:04<00:00,  5.51it/s, avg_loss=4.7438, batch_loss=3.6821]


Epoch 04 | Train: 4.6853 | Val: 4.7438


Epoch 05 Train: 100%|██████████| 1429/1429 [05:14<00:00,  4.54it/s, avg_loss=4.6773, batch_loss=5.7809]
Epoch 05   Val : 100%|██████████| 358/358 [01:04<00:00,  5.54it/s, avg_loss=4.7562, batch_loss=3.5703]


Epoch 05 | Train: 4.6773 | Val: 4.7562


Epoch 06 Train: 100%|██████████| 1429/1429 [05:13<00:00,  4.55it/s, avg_loss=4.6789, batch_loss=4.0838]
Epoch 06   Val : 100%|██████████| 358/358 [01:04<00:00,  5.52it/s, avg_loss=4.7504, batch_loss=2.8590]


Epoch 06 | Train: 4.6789 | Val: 4.7504


Epoch 07 Train: 100%|██████████| 1429/1429 [05:13<00:00,  4.55it/s, avg_loss=4.6822, batch_loss=3.8228]
Epoch 07   Val : 100%|██████████| 358/358 [01:04<00:00,  5.55it/s, avg_loss=4.7257, batch_loss=3.7118]


Epoch 07 | Train: 4.6822 | Val: 4.7257


Epoch 08 Train: 100%|██████████| 1429/1429 [05:13<00:00,  4.55it/s, avg_loss=4.6753, batch_loss=6.2062]
Epoch 08   Val : 100%|██████████| 358/358 [01:04<00:00,  5.55it/s, avg_loss=4.7334, batch_loss=3.4046]


Epoch 08 | Train: 4.6753 | Val: 4.7334


Epoch 09 Train: 100%|██████████| 1429/1429 [05:14<00:00,  4.55it/s, avg_loss=4.6748, batch_loss=8.2943]
Epoch 09   Val : 100%|██████████| 358/358 [01:04<00:00,  5.55it/s, avg_loss=4.7353, batch_loss=3.5355]


Epoch 09 | Train: 4.6748 | Val: 4.7353


Epoch 10 Train: 100%|██████████| 1429/1429 [05:11<00:00,  4.58it/s, avg_loss=4.6699, batch_loss=3.9576]
Epoch 10   Val : 100%|██████████| 358/358 [01:04<00:00,  5.57it/s, avg_loss=4.7527, batch_loss=2.8998]


Epoch 10 | Train: 4.6699 | Val: 4.7527


Epoch 11 Train: 100%|██████████| 1429/1429 [05:12<00:00,  4.57it/s, avg_loss=4.6747, batch_loss=3.3215]
Epoch 11   Val : 100%|██████████| 358/358 [01:04<00:00,  5.52it/s, avg_loss=4.7523, batch_loss=2.8771]


Epoch 11 | Train: 4.6747 | Val: 4.7523


Epoch 12 Train: 100%|██████████| 1429/1429 [05:13<00:00,  4.56it/s, avg_loss=4.6735, batch_loss=3.1250]
Epoch 12   Val : 100%|██████████| 358/358 [01:04<00:00,  5.57it/s, avg_loss=4.7473, batch_loss=2.8887]


Epoch 12 | Train: 4.6735 | Val: 4.7473


Epoch 13 Train: 100%|██████████| 1429/1429 [05:13<00:00,  4.56it/s, avg_loss=4.6704, batch_loss=4.3612]
Epoch 13   Val : 100%|██████████| 358/358 [01:04<00:00,  5.55it/s, avg_loss=4.7352, batch_loss=2.8172]


Epoch 13 | Train: 4.6704 | Val: 4.7352


Epoch 14 Train: 100%|██████████| 1429/1429 [05:12<00:00,  4.57it/s, avg_loss=4.6704, batch_loss=4.4716]
Epoch 14   Val : 100%|██████████| 358/358 [01:04<00:00,  5.54it/s, avg_loss=4.7438, batch_loss=2.9095]


Epoch 14 | Train: 4.6704 | Val: 4.7438


Epoch 15 Train: 100%|██████████| 1429/1429 [05:12<00:00,  4.58it/s, avg_loss=4.6747, batch_loss=3.3528]
Epoch 15   Val : 100%|██████████| 358/358 [01:04<00:00,  5.58it/s, avg_loss=4.7635, batch_loss=2.8907]


Epoch 15 | Train: 4.6747 | Val: 4.7635


Epoch 16 Train: 100%|██████████| 1429/1429 [05:12<00:00,  4.57it/s, avg_loss=4.6795, batch_loss=4.6411]
Epoch 16   Val : 100%|██████████| 358/358 [01:04<00:00,  5.56it/s, avg_loss=4.7341, batch_loss=2.9712]


Epoch 16 | Train: 4.6795 | Val: 4.7341


Epoch 17 Train: 100%|██████████| 1429/1429 [05:12<00:00,  4.57it/s, avg_loss=4.6723, batch_loss=4.0108]
Epoch 17   Val : 100%|██████████| 358/358 [01:04<00:00,  5.55it/s, avg_loss=4.7540, batch_loss=3.0447]


Epoch 17 | Train: 4.6723 | Val: 4.7540


Epoch 18 Train: 100%|██████████| 1429/1429 [05:11<00:00,  4.58it/s, avg_loss=4.6745, batch_loss=3.9277]
Epoch 18   Val : 100%|██████████| 358/358 [01:04<00:00,  5.58it/s, avg_loss=4.7448, batch_loss=2.8143]


Epoch 18 | Train: 4.6745 | Val: 4.7448


Epoch 19 Train: 100%|██████████| 1429/1429 [05:11<00:00,  4.59it/s, avg_loss=4.6865, batch_loss=6.6359]
Epoch 19   Val : 100%|██████████| 358/358 [01:04<00:00,  5.56it/s, avg_loss=4.7551, batch_loss=3.7441]


Epoch 19 | Train: 4.6865 | Val: 4.7551
✅ Training finished


Next, we try unfreezing the last part of the ViT.

https://colab.research.google.com/drive/1ht-CYyX1jwkL6UaMAZNqiDrBIPZx8M21

Here’s a quick breakdown of **why** we make each of those three tweaks and **what** they’re doing under the hood:

1. **Unfreezing late ViT blocks**  
   - **Why:** The ViT backbone was pretrained on ImageNet RGB images, which look very different from your multi‑band LST patches. By unfreezing just the last two transformer blocks and the final norm layer, you let the model adapt its high‑level feature detectors to your domain without discarding its useful low‑level filters.  
   - **What:** The `blocks.10`, `blocks.11`, and `norm` layers become trainable, so during backpropagation their weights shift to better capture patterns (e.g., thermal gradients, texture) in your satellite patches. Because the optimizer is built only from parameters with `requires_grad=True`, it must be re‑created *after* unfreezing; otherwise the newly unfrozen weights are never updated.

2. **Richer upsampling head**  
   - **Why:** A single `ConvTranspose2d` upsamples the feature map in one large jump (×16, the patch size), which often leads to blurry or blocky outputs. A deeper sequence of smaller upsampling steps with intermediate nonlinearities gives the network more capacity to reconstruct fine spatial details.  
   - **What:** You replace  
     ```python
     ConvTranspose2d(hidden_dim → 1)
     ```  
     with  
     ```python
     ConvTranspose2d(hidden_dim → hidden_dim // 2) → ReLU → ConvTranspose2d(hidden_dim // 2 → hidden_dim // 4) → ReLU → Conv2d(hidden_dim // 4 → 1)
     ```  
     so the model upsamples gradually and refines features at each stage.

3. **Injecting the [CLS] token back in**  
   - **Why:** The `[CLS]` token in ViT holds a global summary of the entire image (all patches). If you discard it, your head only sees local patch embeddings plus weather. Re‑injecting it lets every patch token “know” the global context, which helps coordinate outputs across the whole map.  
   - **What:** Instead of dropping `feats[:,0]`, you concatenate it with your patch tokens and the weather embedding—so your transformer fusion layer sees `(patches + weather + CLS)` as one sequence, letting global information flow back into each spatial location before you decode.

Altogether, these changes let your pretrained ViT adapt its highest‑level concepts, decode richer spatial structure, and leverage both local and global context to produce cleaner, more accurate LST maps.