In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Cell 1 ▶ Install required packages
!pip install timm torchvision tifffile imagecodecs --quiet

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m127.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m100.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m63.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m11.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m41.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m9.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
# Cell 2 ▶ Consolidated imports
import os
import math
import time

import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split

import timm
from torchvision import transforms, models

from tifffile import imread

In [None]:
# Cell 1 ▶ Dataset (224×224 targets, no down-scaling)
class LSTDataset(Dataset):
    def __init__(self, df, patches_dir, weather_cols):
        self.df           = df.reset_index(drop=True)
        self.patches_dir  = patches_dir
        self.weather_cols = weather_cols
        self.transform    = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),          # resize image
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std =[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row  = self.df.loc[idx]
        arr  = imread(os.path.join(self.patches_dir, row["filename"])
                     ).astype(np.float32)            # (4,H,W)

        # --- inputs -------------------------------------------------
        img_np = arr[[1,2,3]].transpose(1,2,0).astype(np.uint8)
        img    = self.transform(img_np)              # [3,224,224]

        # --- target (LST) at *full* 224×224 --------------------------
        lst    = arr[0]                              # (H,W)
        lst    = torch.tensor(lst, dtype=torch.float32).unsqueeze(0)
        lst    = F.interpolate(lst.unsqueeze(0), size=(224,224),
                               mode='bilinear', align_corners=False
                              ).squeeze(0)           # [1,224,224]

        # --- meteorology vector -------------------------------------
        weather = torch.tensor(
            row[self.weather_cols].values.astype(np.float32)
        )
        return img, weather, lst


In [None]:
# define which meteorological columns to pull
weather_cols = [
    "air_temp_C",
    "dew_point_C",
    "relative_humidity_percent",
    "wind_speed_m_s",
    "precipitation_in",
]

In [None]:
df = pd.read_csv("/content/drive/MyDrive/PatchedOutput/patch_with_meteo.csv")
for col in weather_cols:
    df[col] = pd.to_numeric(df[col], errors="coerce")
df = df.dropna(subset=weather_cols + ["patch_filename"]).reset_index(drop=True)

patches_dir = "/content/drive/MyDrive/PatchedOutput_Cleaned"
dataset     = LSTDataset(df, patches_dir, weather_cols)
train_sz    = int(0.8 * len(dataset))
val_sz      = len(dataset) - train_sz
train_ds, val_ds = random_split(dataset, [train_sz, val_sz])

train_loader = DataLoader(train_ds, batch_size=4, shuffle=True,  num_workers=0, pin_memory=False)
val_loader   = DataLoader(val_ds,   batch_size=4, shuffle=False, num_workers=0, pin_memory=False)

In [None]:
import re

# 1) Read in the master CSV
df = pd.read_csv("/content/drive/MyDrive/PatchedOutput/patch_with_meteo.csv")

# 2) Convert weather columns to numeric & drop NaNs
for col in weather_cols:
    df[col] = pd.to_numeric(df[col], errors="coerce")
df = df.dropna(subset=weather_cols + ["patch_filename"]).reset_index(drop=True)

# 3) Extract patch_id (“r1152_c672”) and date (“2022-07-12”) from filename
def parse_patch_id(fn):
    m = re.search(r'(r\d+_c\d+)', fn)
    return m.group(1) if m else None

def parse_date(fn):
    m = re.search(r'_(\d{4}-\d{2}-\d{2})_', fn)
    return pd.to_datetime(m.group(1)) if m else pd.NaT

df['patch_id'] = df['patch_filename'].map(parse_patch_id)
df['date']     = df['patch_filename'].map(parse_date)

# 4) Drop any rows where parsing failed
df = df.dropna(subset=["patch_id", "date"]).reset_index(drop=True)

print(f"Loaded {len(df)} rows; found {df['patch_id'].nunique()} unique patches")

Loaded 18813 rows; found 174 unique patches


In [None]:
class LSTTimeSeriesDataset(torch.utils.data.Dataset):
    def __init__(self, df, patches_dir, weather_cols, Δt_days=16):
        self.df = df.copy()
        self.patches_dir = patches_dir
        self.weather_cols = weather_cols
        self.Δt = np.timedelta64(int(Δt_days), 'D')

        # Group by patch_id, sorted in time
        self.groups = {
            pid: g.sort_values('date').reset_index(drop=True)
            for pid, g in self.df.groupby('patch_id')
        }

        # Build index of valid (pid, i) pairs where date[i+1] - date[i] == Δt
        self.index = []
        for pid, g in self.groups.items():
            for i in range(len(g)-1):
                if (g.loc[i+1,'date'] - g.loc[i,'date']) == self.Δt:
                    self.index.append((pid, i))

        # Reuse your standard image transform
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224,224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485,0.456,0.406],
                                 std=[0.229,0.224,0.225])
        ])

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx):
        pid, i = self.index[idx]
        g       = self.groups[pid]

        # --- t0: image + weather + lst
        row0 = g.loc[i]
        arr = imread(os.path.join(self.patches_dir, row0.patch_filename))

        lst0    = arr[0]                              # (H,W)
        lst0    = torch.tensor(lst0, dtype=torch.float32).unsqueeze(0)
        lst0    = F.interpolate(lst0.unsqueeze(0), size=(224,224),
                               mode='bilinear', align_corners=False
                              ).squeeze(0)           # [1,224,224]                                          # [1,224,224]

        # --- t1: only LST target
        row1 = g.loc[i+1]
        arr1 = imread(os.path.join(
            self.patches_dir,
            row1.patch_filename
        ))
        img_np = arr1[[1,2,3]].transpose(1,2,0).astype(np.uint8)
        img    = self.transform(img_np)              # [3,224,224]
        w1 = torch.tensor(
            row1[self.weather_cols].values.astype(np.float32),
            dtype=torch.float32
        )

        lst1_np = arr1[0]                                    # (H,W)
        lst1 = torch.tensor(lst1_np, dtype=torch.float32).unsqueeze(0)
        lst1 = F.interpolate(
            lst1.unsqueeze(0),
            size=(224,224),
            mode='bilinear', align_corners=False
        ).squeeze(0)                                          # [1,224,224]

        return img, w1, lst0, lst1


In [None]:
ts_dataset = LSTTimeSeriesDataset(df, patches_dir, weather_cols)
train_sz = int(0.8 * len(ts_dataset))
train_ds, val_ds = random_split(ts_dataset, [train_sz, len(ts_dataset)-train_sz])

train_loader = DataLoader(train_ds, batch_size=4, shuffle=True)
val_loader   = DataLoader(val_ds,   batch_size=4, shuffle=False)

In [None]:
# Cell 4 ▶ ViT + weather → 224×224 decoder (with α)
import math, torch.nn as nn, timm

class PretrainedViTLSTModel(nn.Module):
    def __init__(self,
                 weather_dim=5,
                 hidden_dim=768,
                 vit_name="vit_base_patch16_224",
                 num_layers=2,
                 num_heads=8):
        super().__init__()
        # --- frozen ViT backbone ---
        self.vit = timm.create_model(vit_name, pretrained=True, num_classes=0)
        for p in self.vit.parameters():
            p.requires_grad = False

        # --- project weather → token ---
        self.weather_proj = nn.Linear(weather_dim, hidden_dim)

        # --- tiny transformer for fusion ---
        enc = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=0.1
        )
        self.transformer = nn.TransformerEncoder(enc, num_layers)

        # --- decoder: 14→224 via conv-transpose ---
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(hidden_dim, hidden_dim//2, 2, 2),
            nn.BatchNorm2d(hidden_dim//2), nn.ReLU(inplace=True),

            nn.ConvTranspose2d(hidden_dim//2, hidden_dim//4, 2, 2),
            nn.BatchNorm2d(hidden_dim//4), nn.ReLU(inplace=True),

            nn.ConvTranspose2d(hidden_dim//4, hidden_dim//8, 2, 2),
            nn.BatchNorm2d(hidden_dim//8), nn.ReLU(inplace=True),

            nn.ConvTranspose2d(hidden_dim//8, 1, 2, 2)
        )

        # --- learnable Newton-cooling coefficient α ---
        self.alpha = nn.Parameter(torch.tensor(0.01, dtype=torch.float32))

    def forward(self, images, weather):
        feats   = self.vit.forward_features(images)  # [B,197,768]
        cls_tok = feats[:, :1]                       # [B,1,768]
        patch_t = feats[:, 1:]                       # [B,196,768]

        w_tok   = self.weather_proj(weather).unsqueeze(1)  # [B,1,768]
        tokens  = torch.cat([patch_t, w_tok, cls_tok], 1)  # [B,198,768]

        t = self.transformer(tokens.permute(1,0,2)).permute(1,0,2)
        patch_out = t[:, :-2, :]                       # drop weather+CLS

        B, N, D = patch_out.shape  # N=196
        G = int(math.sqrt(N))      # =14
        x = patch_out.transpose(1,2).view(B, D, G, G)   # [B,768,14,14]
        return self.deconv(x)                          # [B,1,224,224]

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model  = PretrainedViTLSTModel(
    weather_dim=len(weather_cols),
    hidden_dim=768,
    vit_name="vit_base_patch16_224",
    num_layers=2,
    num_heads=8
).to(device)

# unfreeze last ViT blocks:
for name, param in model.vit.named_parameters():
    if any(layer in name for layer in ["blocks.10", "blocks.11", "norm"]):
        param.requires_grad = True

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]



In [None]:
opt       = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=1e-4, weight_decay=1e-2
)
loss_fn   = nn.SmoothL1Loss()
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    opt, mode='min', factor=0.5, patience=3, verbose=True
)



In [None]:
# Cell 0 ▶ install LZW support
!pip install imagecodecs --quiet


In [None]:
from tifffile import imread


In [None]:
from tqdm import tqdm
import os
from pathlib import Path
import torch

save_dir = Path("/content/drive/MyDrive/Model_vit_PINN_transformer_30m_Checkpoints")
save_dir.mkdir(parents=True, exist_ok=True)

num_epochs = 20
start_ep   = 0

λ_phys   = 0.1     # physics‐loss weight
Δt_days  = 16.0    # 16‐day revisit

for epoch in range(start_ep, num_epochs):
    model.train()
    total_data, total_phys, seen = 0.0, 0.0, 0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1:02d} Train", unit="batch")

    for imgs, weather, lst0, lst1 in pbar:
        imgs, weather, lst0, lst1 = (
            imgs.to(device), weather.to(device),
            lst0.to(device), lst1.to(device)
        )
        B = imgs.size(0)

        opt.zero_grad()
        pred1     = model(imgs, weather)              # [B,1,224,224]
        data_loss = loss_fn(pred1, lst1)

        # pixel-wise physics term
        T_air1    = weather[:,0].view(-1,1,1,1)
        resid     = (pred1 - lst0)/Δt_days \
                    - model.alpha * (T_air1 - pred1)
        phys_loss = resid.pow(2).mean()

        loss = data_loss + λ_phys * phys_loss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()

        # accumulate
        total_data  += data_loss.item() * B
        total_phys  += phys_loss.item() * B
        seen        += B

        # compute running averages
        avg_data  = total_data / seen
        avg_phys  = total_phys / seen
        avg_total = avg_data + λ_phys * avg_phys

        # update tqdm postfix
        pbar.set_postfix({
            'batch':    f"{loss.item():.4f}",
            'avg_data': f"{avg_data:.4f}",
            'avg_phys': f"{avg_phys:.4f}",
            'avg_tot':  f"{avg_total:.4f}"
        })

    # end of epoch: compute epoch RMSE
    train_rmse = math.sqrt(avg_total)

    # — VALIDATION (data‐only) —
    model.eval()
    val_sum, val_n = 0.0, 0
    with torch.no_grad():
        for imgs, weather, lst0, lst1 in val_loader:
            imgs, weather, lst0, lst1 = (
                imgs.to(device), weather.to(device),
                lst0.to(device), lst1.to(device)
            )
            batch_loss = loss_fn(model(imgs, weather), lst1).item()
            val_sum   += batch_loss * imgs.size(0)
            val_n     += imgs.size(0)
    val_rmse = math.sqrt(val_sum / val_n)
    scheduler.step(val_sum)

    print(f"Epoch {epoch+1:02d} ▶ Train RMSE: {train_rmse:.3f} | Val RMSE: {val_rmse:.3f}")
    ckpt = save_dir / f"pinn_epoch{epoch+1:02d}.pt"
    torch.save({
        'epoch': epoch+1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': opt.state_dict(),
        'train_rmse': train_rmse,
        'val_rmse': val_rmse,
        'alpha': model.alpha.item()
    }, ckpt)
    print(f"✅ Saved checkpoint: {ckpt}")

print("✅ Training finished")


Epoch 01 Train: 100%|██████████| 3421/3421 [1:22:15<00:00,  1.44s/batch, batch=3.5710, avg_data=1.3416, avg_phys=0.5443, avg_tot=1.3960]


Epoch 01 ▶ Train RMSE: 1.182 | Val RMSE: 0.651
✅ Saved checkpoint: /content/drive/MyDrive/Model_vit_PINN_transformer_30m_Checkpoints/pinn_epoch01.pt


Epoch 02 Train: 100%|██████████| 3421/3421 [04:59<00:00, 11.42batch/s, batch=0.2212, avg_data=0.5478, avg_phys=0.6093, avg_tot=0.6087]


Epoch 02 ▶ Train RMSE: 0.780 | Val RMSE: 0.666
✅ Saved checkpoint: /content/drive/MyDrive/Model_vit_PINN_transformer_30m_Checkpoints/pinn_epoch02.pt


Epoch 03 Train: 100%|██████████| 3421/3421 [04:46<00:00, 11.93batch/s, batch=0.4323, avg_data=0.3765, avg_phys=0.6320, avg_tot=0.4397]


Epoch 03 ▶ Train RMSE: 0.663 | Val RMSE: 0.441
✅ Saved checkpoint: /content/drive/MyDrive/Model_vit_PINN_transformer_30m_Checkpoints/pinn_epoch03.pt


Epoch 04 Train: 100%|██████████| 3421/3421 [04:38<00:00, 12.28batch/s, batch=0.2369, avg_data=0.2921, avg_phys=0.6464, avg_tot=0.3568]


Epoch 04 ▶ Train RMSE: 0.597 | Val RMSE: 0.498
✅ Saved checkpoint: /content/drive/MyDrive/Model_vit_PINN_transformer_30m_Checkpoints/pinn_epoch04.pt


Epoch 05 Train: 100%|██████████| 3421/3421 [04:38<00:00, 12.28batch/s, batch=0.0697, avg_data=0.2444, avg_phys=0.6553, avg_tot=0.3099]


Epoch 05 ▶ Train RMSE: 0.557 | Val RMSE: 0.382
✅ Saved checkpoint: /content/drive/MyDrive/Model_vit_PINN_transformer_30m_Checkpoints/pinn_epoch05.pt


Epoch 06 Train: 100%|██████████| 3421/3421 [04:38<00:00, 12.27batch/s, batch=0.1403, avg_data=0.2091, avg_phys=0.6638, avg_tot=0.2755]


Epoch 06 ▶ Train RMSE: 0.525 | Val RMSE: 0.394
✅ Saved checkpoint: /content/drive/MyDrive/Model_vit_PINN_transformer_30m_Checkpoints/pinn_epoch06.pt


Epoch 07 Train: 100%|██████████| 3421/3421 [04:38<00:00, 12.29batch/s, batch=0.2298, avg_data=0.1802, avg_phys=0.6697, avg_tot=0.2472]


Epoch 07 ▶ Train RMSE: 0.497 | Val RMSE: 0.477
✅ Saved checkpoint: /content/drive/MyDrive/Model_vit_PINN_transformer_30m_Checkpoints/pinn_epoch07.pt


Epoch 08 Train: 100%|██████████| 3421/3421 [04:38<00:00, 12.28batch/s, batch=0.0783, avg_data=0.1615, avg_phys=0.6743, avg_tot=0.2289]


Epoch 08 ▶ Train RMSE: 0.478 | Val RMSE: 0.305
✅ Saved checkpoint: /content/drive/MyDrive/Model_vit_PINN_transformer_30m_Checkpoints/pinn_epoch08.pt


Epoch 09 Train: 100%|██████████| 3421/3421 [04:48<00:00, 11.87batch/s, batch=0.0920, avg_data=0.1445, avg_phys=0.6811, avg_tot=0.2126]


Epoch 09 ▶ Train RMSE: 0.461 | Val RMSE: 0.374
✅ Saved checkpoint: /content/drive/MyDrive/Model_vit_PINN_transformer_30m_Checkpoints/pinn_epoch09.pt


Epoch 10 Train: 100%|██████████| 3421/3421 [04:52<00:00, 11.68batch/s, batch=0.0598, avg_data=0.1324, avg_phys=0.6816, avg_tot=0.2005]


Epoch 10 ▶ Train RMSE: 0.448 | Val RMSE: 0.363
✅ Saved checkpoint: /content/drive/MyDrive/Model_vit_PINN_transformer_30m_Checkpoints/pinn_epoch10.pt


Epoch 11 Train: 100%|██████████| 3421/3421 [04:45<00:00, 12.00batch/s, batch=0.1342, avg_data=0.1214, avg_phys=0.6852, avg_tot=0.1899]


Epoch 11 ▶ Train RMSE: 0.436 | Val RMSE: 0.389
✅ Saved checkpoint: /content/drive/MyDrive/Model_vit_PINN_transformer_30m_Checkpoints/pinn_epoch11.pt


Epoch 12 Train: 100%|██████████| 3421/3421 [04:38<00:00, 12.27batch/s, batch=0.1376, avg_data=0.1166, avg_phys=0.6909, avg_tot=0.1856]


Epoch 12 ▶ Train RMSE: 0.431 | Val RMSE: 0.314
✅ Saved checkpoint: /content/drive/MyDrive/Model_vit_PINN_transformer_30m_Checkpoints/pinn_epoch12.pt


Epoch 13 Train: 100%|██████████| 3421/3421 [04:40<00:00, 12.20batch/s, batch=0.0721, avg_data=0.0767, avg_phys=0.6918, avg_tot=0.1459]


Epoch 13 ▶ Train RMSE: 0.382 | Val RMSE: 0.234
✅ Saved checkpoint: /content/drive/MyDrive/Model_vit_PINN_transformer_30m_Checkpoints/pinn_epoch13.pt


Epoch 14 Train: 100%|██████████| 3421/3421 [04:38<00:00, 12.30batch/s, batch=0.0685, avg_data=0.0698, avg_phys=0.6957, avg_tot=0.1394]


Epoch 14 ▶ Train RMSE: 0.373 | Val RMSE: 0.257
✅ Saved checkpoint: /content/drive/MyDrive/Model_vit_PINN_transformer_30m_Checkpoints/pinn_epoch14.pt


Epoch 15 Train: 100%|██████████| 3421/3421 [04:37<00:00, 12.31batch/s, batch=0.3438, avg_data=0.0688, avg_phys=0.6955, avg_tot=0.1384]


Epoch 15 ▶ Train RMSE: 0.372 | Val RMSE: 0.278
✅ Saved checkpoint: /content/drive/MyDrive/Model_vit_PINN_transformer_30m_Checkpoints/pinn_epoch15.pt


Epoch 16 Train: 100%|██████████| 3421/3421 [04:38<00:00, 12.29batch/s, batch=0.0583, avg_data=0.0651, avg_phys=0.6972, avg_tot=0.1348]


Epoch 16 ▶ Train RMSE: 0.367 | Val RMSE: 0.250
✅ Saved checkpoint: /content/drive/MyDrive/Model_vit_PINN_transformer_30m_Checkpoints/pinn_epoch16.pt


Epoch 17 Train: 100%|██████████| 3421/3421 [04:38<00:00, 12.30batch/s, batch=0.1155, avg_data=0.0588, avg_phys=0.7014, avg_tot=0.1289]


Epoch 17 ▶ Train RMSE: 0.359 | Val RMSE: 0.239
✅ Saved checkpoint: /content/drive/MyDrive/Model_vit_PINN_transformer_30m_Checkpoints/pinn_epoch17.pt


Epoch 18 Train: 100%|██████████| 3421/3421 [04:38<00:00, 12.30batch/s, batch=0.1142, avg_data=0.0478, avg_phys=0.6993, avg_tot=0.1177]


Epoch 18 ▶ Train RMSE: 0.343 | Val RMSE: 0.248
✅ Saved checkpoint: /content/drive/MyDrive/Model_vit_PINN_transformer_30m_Checkpoints/pinn_epoch18.pt


Epoch 19 Train: 100%|██████████| 3421/3421 [04:38<00:00, 12.30batch/s, batch=0.1447, avg_data=0.0456, avg_phys=0.7004, avg_tot=0.1157]


Epoch 19 ▶ Train RMSE: 0.340 | Val RMSE: 0.233
✅ Saved checkpoint: /content/drive/MyDrive/Model_vit_PINN_transformer_30m_Checkpoints/pinn_epoch19.pt


Epoch 20 Train: 100%|██████████| 3421/3421 [04:38<00:00, 12.29batch/s, batch=0.0373, avg_data=0.0458, avg_phys=0.7005, avg_tot=0.1159]


Epoch 20 ▶ Train RMSE: 0.340 | Val RMSE: 0.212
✅ Saved checkpoint: /content/drive/MyDrive/Model_vit_PINN_transformer_30m_Checkpoints/pinn_epoch20.pt
✅ Training finished
