In [None]:
# Mount Google Drive at /content/drive; the CSV, TIFF patches and checkpoints
# used later in this notebook are read from / written to paths under it.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
# Cell 1 ▶ Install required packages
!pip install timm torchvision tifffile imagecodecs --quiet

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m122.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m27.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m30.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m19.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m8.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
# Cell 2 ▶ Consolidated imports
import os
import math
import time

import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split

import timm
from torchvision import transforms, models

from tifffile import imread

In [None]:
class LSTDataset(Dataset):
    """Multi-band TIFF patches paired with per-row weather covariates.

    Each item is ``(img, weather, target)``:

    * ``img``     - bands ``rgb_bands`` (default 1, 2, 3) of the TIFF, cast to uint8,
      resized to ``image_size`` and normalized with the ImageNet mean/std:
      shape [3, *image_size].
    * ``weather`` - the row's ``weather_cols`` values as float32: shape [len(weather_cols)].
    * ``target``  - band ``target_band`` (default 0), bilinearly resized to
      ``target_size``: shape [1, *target_size]. ``target_size`` must match the
      model's output resolution (56x56 for the ViT-B/16 model used here).

    NOTE(review): the uint8 cast assumes the RGB bands are already scaled to
    0-255; values outside that range are clipped - confirm against the data.

    Args:
        df: table with a ``filename`` column and the ``weather_cols`` columns.
        patches_dir: directory containing the TIFFs named in ``df["filename"]``.
        weather_cols: names of the numeric weather columns, in model-input order.
        image_size: (H, W) the RGB image is resized to. Default (224, 224).
        target_size: (H, W) the target band is resized to. Default (56, 56).
        rgb_bands: band indices used as the 3 image channels. Default (1, 2, 3).
        target_band: band index used as the regression target. Default 0.
    """

    def __init__(self, df, patches_dir, weather_cols,
                 image_size=(224, 224), target_size=(56, 56),
                 rgb_bands=(1, 2, 3), target_band=0):
        self.df           = df.reset_index(drop=True)
        self.patches_dir  = patches_dir
        self.weather_cols = weather_cols
        self.rgb_bands    = list(rgb_bands)
        self.target_band  = target_band
        self.target_size  = tuple(target_size)
        # Pull the weather covariates once as a float32 matrix (row i <-> index i)
        # instead of indexing a mixed-dtype pandas row on every __getitem__ call.
        self.weather_values = self.df[list(weather_cols)].to_numpy(dtype=np.float32)
        self.transform    = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(tuple(image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std =[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        path = os.path.join(self.patches_dir, self.df.at[idx, "filename"])
        arr  = imread(path).astype(np.float32)  # assumed (bands, H, W)

        img    = self.transform(self._to_uint8_hwc(arr[self.rgb_bands]))      # [3, *image_size]
        target = self._resize_target(arr[self.target_band], self.target_size)  # [1, *target_size]
        weather = torch.tensor(self.weather_values[idx])                       # [len(weather_cols)]

        return img, weather, target

    @staticmethod
    def _to_uint8_hwc(bands):
        """Convert a [C, H, W] float array to an [H, W, C] uint8 array.

        NaN becomes 0 and values are clipped to [0, 255] before casting, because
        NumPy's float->uint8 cast of NaN/out-of-range values is not well defined.
        In-range values are truncated exactly as a plain ``astype(np.uint8)`` would.
        """
        finite = np.nan_to_num(bands, nan=0.0, posinf=255.0, neginf=0.0)
        return np.clip(finite, 0.0, 255.0).transpose(1, 2, 0).astype(np.uint8)

    @staticmethod
    def _resize_target(band, size):
        """Bilinearly resize a 2-D [H, W] band to ``size``; returns a [1, *size] float32 tensor."""
        t = torch.tensor(band, dtype=torch.float32)[None, None]  # [1, 1, H, W] for interpolate
        return F.interpolate(t, size=size, mode="bilinear", align_corners=False)[0]

In [None]:
# Meteorological covariates read from the CSV for every patch (units are encoded
# in the column names). The order here fixes the layout of the weather vector
# that the dataset returns and the model's weather projection consumes.
weather_cols = ["air_temp_C", "dew_point_C", "relative_humidity_percent",
                "wind_speed_m_s", "precipitation_in"]

In [None]:
# Load the patch index + weather table; coerce weather columns to numbers and
# drop rows that are missing any weather value or a filename.
df = pd.read_csv("/content/drive/MyDrive/PatchedOutput/tiff_with_meteo.csv")
for col in weather_cols:
    df[col] = pd.to_numeric(df[col], errors="coerce")
df = df.dropna(subset=weather_cols + ["filename"]).reset_index(drop=True)

patches_dir = "/content/drive/MyDrive/PatchedOutput_Cleaned"
dataset     = LSTDataset(df, patches_dir, weather_cols)
if len(dataset) == 0:
    raise ValueError("No usable rows after dropping missing weather/filename values")

train_sz    = int(0.8 * len(dataset))
val_sz      = len(dataset) - train_sz

# Seed the split so train/val membership is identical after every kernel restart;
# otherwise resuming from a checkpoint would mix former validation samples into
# training. (Checkpoints made before this seed was fixed used an unrecorded split.)
SPLIT_SEED = 42
train_ds, val_ds = random_split(
    dataset, [train_sz, val_sz],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train_ds, batch_size=4, shuffle=True,  num_workers=0, pin_memory=False)
val_loader   = DataLoader(val_ds,   batch_size=4, shuffle=False, num_workers=0, pin_memory=False)


In [None]:
# 1) Define your model class (you only need to run this once per restart)
import math, torch.nn as nn, timm

class PretrainedViTLSTModel(nn.Module):
    """Frozen timm ViT encoder + weather token + transformer fusion + upsampling head.

    Pipeline:
      1. ``self.vit.forward_features`` yields prefix tokens (the CLS token, plus any
         register tokens the backbone has) followed by N patch tokens of width D.
      2. The weather vector is projected to D and appended as one extra token; a
         trainable ``nn.TransformerEncoder`` mixes all tokens.
      3. The N fused patch tokens are reshaped to a G x G grid (G = sqrt(N)) and
         upsampled 4x by two stride-2 transposed convolutions to a 1-channel map.

    Output shape is [B, 1, 4G, 4G]. For ``vit_base_patch16_224`` on 224x224 input,
    G = 224 / 16 = 14, so the output is [B, 1, 56, 56] (not 224x224).

    Args:
        weather_dim: length of the weather vector passed to ``forward``.
        hidden_dim: token width D; must equal the ViT's feature width.
        vit_name: timm model name, created with pretrained weights and no classifier head.
        num_layers: number of fusion-transformer layers.
        num_heads: attention heads in each fusion-transformer layer.

    Raises:
        ValueError: if ``hidden_dim`` differs from the backbone's feature width.
    """

    def __init__(self,
                 weather_dim: int = 5,
                 hidden_dim:  int = 768,
                 vit_name:    str = "vit_base_patch16_224",
                 num_layers:  int = 2,
                 num_heads:   int = 8):
        super().__init__()
        self.vit = timm.create_model(vit_name, pretrained=True, num_classes=0)
        # Backbone starts fully frozen; callers may re-enable selected parameters.
        for p in self.vit.parameters():
            p.requires_grad = False
        vit_dim = getattr(self.vit, "num_features", hidden_dim)
        if vit_dim != hidden_dim:
            raise ValueError(
                f"hidden_dim={hidden_dim} must equal the feature width ({vit_dim}) of {vit_name!r}"
            )

        self.weather_proj = nn.Linear(weather_dim, hidden_dim)
        # batch_first is left at its default (False): the encoder expects [S, B, D].
        enc = nn.TransformerEncoderLayer(d_model=hidden_dim,
                                         nhead=num_heads,
                                         dim_feedforward=hidden_dim * 4,
                                         dropout=0.1)
        self.transformer = nn.TransformerEncoder(enc, num_layers)
        # Two stride-2 transposed convs: G x G -> 2G x 2G -> 4G x 4G, then 1x1 conv to 1 channel.
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(hidden_dim, hidden_dim // 2, kernel_size=2, stride=2),
            nn.BatchNorm2d(hidden_dim // 2), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(hidden_dim // 2, hidden_dim // 4, kernel_size=2, stride=2),
            nn.BatchNorm2d(hidden_dim // 4), nn.ReLU(inplace=True),
            nn.Conv2d(hidden_dim // 4, 1, kernel_size=1)
        )

    def forward(self, images, weather):
        """Predict a single-channel map from an image batch and weather vectors.

        Args:
            images: image batch accepted by ``self.vit`` (e.g. [B, 3, 224, 224]).
            weather: [B, weather_dim] float tensor.

        Returns:
            Tensor of shape [B, 1, 4G, 4G], where G * G is the number of patch tokens.

        Raises:
            ValueError: if the number of patch tokens is not a perfect square.
        """
        feats    = self.vit.forward_features(images)                 # [B, P+N, D]
        # timm ViTs expose how many non-patch tokens lead the sequence; default to a lone CLS.
        n_prefix = int(getattr(self.vit, "num_prefix_tokens", 1))
        prefix   = feats[:, :n_prefix]                                # [B, P, D]
        patches  = feats[:, n_prefix:]                                # [B, N, D]
        w_emb    = self.weather_proj(weather).unsqueeze(1)            # [B, 1, D]

        # Patch tokens go first so they can be sliced back out after fusion.
        tokens = torch.cat([patches, w_emb, prefix], dim=1)           # [B, N+1+P, D]
        fused  = self.transformer(tokens.permute(1, 0, 2)).permute(1, 0, 2)

        B, N, D = patches.shape
        G = math.isqrt(N)
        if G * G != N:
            raise ValueError(f"expected a square grid of patch tokens, got N={N}")
        patch_out = fused[:, :N]                                      # drop weather + prefix tokens
        grid      = patch_out.transpose(1, 2).reshape(B, D, G, G)     # [B, D, G, G]
        return self.deconv(grid)                                      # [B, 1, 4G, 4G]


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model  = PretrainedViTLSTModel(
    weather_dim=len(weather_cols),
    hidden_dim=768,
    vit_name="vit_base_patch16_224",
    num_layers=2,
    num_heads=8
).to(device)

# Unfreeze only the last two ViT blocks and the final LayerNorm.
# Match on name *prefixes*: a bare substring test for "norm" also matches every
# block's norm1/norm2 (e.g. "blocks.0.norm1.weight") and would silently unfreeze
# all LayerNorms in the backbone.
UNFREEZE_PREFIXES = ("blocks.10.", "blocks.11.", "norm.")
unfrozen = []
for name, param in model.vit.named_parameters():
    if name.startswith(UNFREEZE_PREFIXES):
        param.requires_grad = True
        unfrozen.append(name)
if not unfrozen:
    raise ValueError(f"No ViT parameters match {UNFREEZE_PREFIXES}; check the backbone's parameter names")

n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Unfroze {len(unfrozen)} ViT tensors; trainable parameters: {n_trainable:,}")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]



In [None]:
# AdamW over only the parameters that currently require gradients: the new
# fusion/decoder layers plus whichever ViT parameters were unfrozen earlier.
opt = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr=1e-4,
    weight_decay=1e-2,
)

# SmoothL1 (Huber-style) regression loss on the per-pixel target map.
loss_fn = nn.SmoothL1Loss()

# Halve the LR once the monitored value has failed to improve for more than
# `patience`=3 consecutive scheduler steps.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    opt,
    mode='min',
    factor=0.5,
    patience=3,
    verbose=True,
)



In [None]:
from tqdm import tqdm
import os
from pathlib import Path
import torch

# Create a directory on your Drive for checkpoints
save_dir = Path("/content/drive/MyDrive/ModelCheckpoints")
save_dir.mkdir(parents=True, exist_ok=True)

num_epochs = 10
start_ep   = 0  # set > 0 only after restoring model/optimizer/scheduler state from a checkpoint


def _run_epoch(loader, desc, train):
    """Run one pass over `loader` and return (mean loss per sample, pixel RMSE).

    In train mode the optimizer is stepped after every batch (with grad-norm
    clipping at 1.0); otherwise the model is in eval mode and gradients are off.
    The loss is the training criterion `loss_fn` (SmoothL1), averaged per sample.
    RMSE is computed separately from squared errors over all target pixels,
    because sqrt(mean SmoothL1) is not an RMSE. In train mode the RMSE reflects
    train-mode outputs (dropout / batch-norm batch statistics active).

    Raises:
        ValueError: if the loader yields no batches.
    """
    model.train(train)
    loss_sum   = 0.0  # sum over batches of (batch-mean loss * batch size)
    sq_err_sum = 0.0  # sum of squared per-pixel errors
    n_samples  = 0
    n_pixels   = 0
    bar = tqdm(loader, desc=desc)
    with torch.set_grad_enabled(train):
        for imgs, weather, tgt in bar:
            imgs, weather, tgt = imgs.to(device), weather.to(device), tgt.to(device)

            if train:
                opt.zero_grad()
            out  = model(imgs, weather)
            loss = loss_fn(out, tgt)
            if train:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                opt.step()

            batch_val   = loss.item()
            n           = imgs.size(0)
            loss_sum   += batch_val * n
            n_samples  += n
            sq_err_sum += (out.detach() - tgt).pow(2).sum().item()
            n_pixels   += tgt.numel()

            bar.set_postfix(
                batch_loss=f"{batch_val:.4f}",
                avg_loss  =f"{loss_sum / n_samples:.4f}"
            )
    if n_samples == 0:
        raise ValueError(f"{desc.strip()}: loader yielded no batches")
    return loss_sum / n_samples, (sq_err_sum / n_pixels) ** 0.5


for epoch in range(start_ep, num_epochs):
    tag = f"Epoch {epoch+1:02d}"
    train_loss, train_rmse = _run_epoch(train_loader, f"{tag} Train", train=True)
    val_loss,   val_rmse   = _run_epoch(val_loader,   f"{tag}   Val ", train=False)

    # ReduceLROnPlateau monitors the mean validation loss (the training criterion).
    scheduler.step(val_loss)

    print(f"{tag} ▶ Train loss: {train_loss:.4f}, RMSE: {train_rmse:.3f} | "
          f"Val loss: {val_loss:.4f}, RMSE: {val_rmse:.3f}")

    # — Save checkpoint — (the "cnn_mlp_" prefix is kept as-is so existing files/loaders still match)
    ckpt_path = save_dir / f"cnn_mlp_epoch{epoch+1:02d}.pt"
    torch.save({
        'epoch': epoch+1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': opt.state_dict(),
        # Needed to resume with the same LR-reduction history.
        'scheduler_state_dict': scheduler.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
        'train_rmse': train_rmse,
        'val_rmse': val_rmse
    }, ckpt_path)
    print(f"✅ Saved checkpoint: {ckpt_path}")

print("✅ Training finished")


Epoch 01 Train: 100%|██████████| 2858/2858 [1:37:02<00:00,  2.04s/it, avg_loss=1.4960, batch_loss=1.1200]
Epoch 01   Val : 100%|██████████| 715/715 [22:24<00:00,  1.88s/it, avg_loss=0.6518, batch_loss=0.6657]


Epoch 01 ▶ Train RMSE: 1.223 | Val RMSE: 0.807
✅ Saved checkpoint: /content/drive/MyDrive/ModelCheckpoints/cnn_mlp_epoch01.pt


Epoch 02 Train: 100%|██████████| 2858/2858 [08:06<00:00,  5.87it/s, avg_loss=0.5169, batch_loss=0.2031]
Epoch 02   Val : 100%|██████████| 715/715 [00:54<00:00, 13.12it/s, avg_loss=0.3477, batch_loss=0.2042]


Epoch 02 ▶ Train RMSE: 0.719 | Val RMSE: 0.590
✅ Saved checkpoint: /content/drive/MyDrive/ModelCheckpoints/cnn_mlp_epoch02.pt


Epoch 03 Train: 100%|██████████| 2858/2858 [08:06<00:00,  5.88it/s, avg_loss=0.3360, batch_loss=0.2114]
Epoch 03   Val : 100%|██████████| 715/715 [00:55<00:00, 12.86it/s, avg_loss=0.1722, batch_loss=0.0814]


Epoch 03 ▶ Train RMSE: 0.580 | Val RMSE: 0.415
✅ Saved checkpoint: /content/drive/MyDrive/ModelCheckpoints/cnn_mlp_epoch03.pt


Epoch 04 Train: 100%|██████████| 2858/2858 [08:04<00:00,  5.90it/s, avg_loss=0.2636, batch_loss=0.2214]
Epoch 04   Val : 100%|██████████| 715/715 [00:54<00:00, 13.11it/s, avg_loss=0.2124, batch_loss=0.0626]


Epoch 04 ▶ Train RMSE: 0.513 | Val RMSE: 0.461
✅ Saved checkpoint: /content/drive/MyDrive/ModelCheckpoints/cnn_mlp_epoch04.pt


Epoch 05 Train: 100%|██████████| 2858/2858 [08:04<00:00,  5.90it/s, avg_loss=0.2217, batch_loss=0.0420]
Epoch 05   Val : 100%|██████████| 715/715 [00:54<00:00, 13.12it/s, avg_loss=0.2825, batch_loss=0.1147]


Epoch 05 ▶ Train RMSE: 0.471 | Val RMSE: 0.532
✅ Saved checkpoint: /content/drive/MyDrive/ModelCheckpoints/cnn_mlp_epoch05.pt


Epoch 06 Train: 100%|██████████| 2858/2858 [08:03<00:00,  5.91it/s, avg_loss=0.1837, batch_loss=0.0439]
Epoch 06   Val : 100%|██████████| 715/715 [00:54<00:00, 13.12it/s, avg_loss=0.1789, batch_loss=0.1157]


Epoch 06 ▶ Train RMSE: 0.429 | Val RMSE: 0.423
✅ Saved checkpoint: /content/drive/MyDrive/ModelCheckpoints/cnn_mlp_epoch06.pt


Epoch 07 Train: 100%|██████████| 2858/2858 [08:05<00:00,  5.88it/s, avg_loss=0.1625, batch_loss=0.1841]
Epoch 07   Val : 100%|██████████| 715/715 [00:54<00:00, 13.04it/s, avg_loss=0.1573, batch_loss=0.0425]


Epoch 07 ▶ Train RMSE: 0.403 | Val RMSE: 0.397
✅ Saved checkpoint: /content/drive/MyDrive/ModelCheckpoints/cnn_mlp_epoch07.pt


Epoch 08 Train: 100%|██████████| 2858/2858 [08:05<00:00,  5.89it/s, avg_loss=0.1478, batch_loss=0.2889]
Epoch 08   Val : 100%|██████████| 715/715 [00:54<00:00, 13.03it/s, avg_loss=0.1542, batch_loss=0.0649]


Epoch 08 ▶ Train RMSE: 0.384 | Val RMSE: 0.393
✅ Saved checkpoint: /content/drive/MyDrive/ModelCheckpoints/cnn_mlp_epoch08.pt


Epoch 09 Train: 100%|██████████| 2858/2858 [08:02<00:00,  5.92it/s, avg_loss=0.1350, batch_loss=0.2004]
Epoch 09   Val : 100%|██████████| 715/715 [00:54<00:00, 13.17it/s, avg_loss=0.1383, batch_loss=0.0717]


Epoch 09 ▶ Train RMSE: 0.367 | Val RMSE: 0.372
✅ Saved checkpoint: /content/drive/MyDrive/ModelCheckpoints/cnn_mlp_epoch09.pt


Epoch 10 Train: 100%|██████████| 2858/2858 [08:02<00:00,  5.92it/s, avg_loss=0.1504, batch_loss=0.3031]
Epoch 10   Val : 100%|██████████| 715/715 [00:54<00:00, 13.19it/s, avg_loss=0.2178, batch_loss=0.1156]


Epoch 10 ▶ Train RMSE: 0.388 | Val RMSE: 0.467
✅ Saved checkpoint: /content/drive/MyDrive/ModelCheckpoints/cnn_mlp_epoch10.pt
✅ Training finished


In [None]:
# List the checkpoint files written by the training loop.
# `find -name "*checkpoint*"` tests only file *basenames*, and the saved files
# (named cnn_mlp_epochNN.pt) do not contain "checkpoint", so it cannot list them;
# glob the checkpoint directory for .pt files instead.
from pathlib import Path

ckpt_dir   = Path("/content/drive/MyDrive/ModelCheckpoints")
ckpt_files = sorted(ckpt_dir.glob("*.pt"))
for ckpt_file in ckpt_files:
    print(ckpt_file)
