In [None]:
# Cell 1 (disabled): force the 'fork' multiprocessing start method.
# If re-enabled, this must run before any other cell.
#import multiprocessing as mp
#mp.set_start_method('fork', force=True)

In [None]:
# Cell 2: Install dependencies
!pip install segmentation-models-pytorch timm imagecodecs tifffile




In [None]:
# Cell 2 (alternative, disabled): same install, but with tifffile's optional codec extras
#!pip install segmentation-models-pytorch timm tifffile[all] imagecodecs


In [None]:
# Cell 3: Imports & Drive mount
import imagecodecs  # enable LZW TIFF support
import os
import pandas as pd
from google.colab import drive
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from tifffile import imread
import numpy as np
from tqdm.auto import tqdm
import torch.nn.functional as F

# Mount Google Drive: the CSV index, the TIFF patches and the saved checkpoint
# used later in this notebook all live under /content/drive/MyDrive.
# force_remount=True requests a fresh mount even if the drive is already mounted.
drive.mount('/content/drive', force_remount=True)


Mounted at /content/drive


In [None]:
# Cell 4 ▶ Dataset (3-band image and LST target are both resized to 224×224)
class LSTDataset(Dataset):
    """Patch dataset yielding ``(image, weather, lst)`` triples.

    Each row of ``df`` names one multi-band TIFF patch inside ``patches_dir``.
    Band 0 is used as the LST target and bands 1-3 as the 3-channel network
    input (the reads below assume a band-first ``(4, H, W)`` layout —
    TODO confirm against whatever wrote the patches).

    Args:
        df: Table with a ``"filename"`` column plus every name listed in
            ``weather_cols``. Its index is reset so rows are addressed 0..N-1.
        patches_dir: Directory that contains the TIFF patches.
        weather_cols: Column names packed, in order, into the weather vector.
    """

    def __init__(self, df, patches_dir, weather_cols):
        self.df           = df.reset_index(drop=True)
        self.patches_dir  = patches_dir
        self.weather_cols = weather_cols
        # uint8 HxWx3 array -> PIL -> 224x224 -> float tensor -> ImageNet
        # normalisation (the statistics expected by the pretrained encoder).
        self.transform    = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std =[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        """Return the number of patches (rows of the index table)."""
        return len(self.df)

    @staticmethod
    def _to_uint8_rgb(bands):
        """Convert band-first float data ``(3, H, W)`` to a ``(H, W, 3)`` uint8 image.

        Values are clipped to ``[0, 255]`` (NaN becomes 0) *before* the cast:
        a bare ``astype(np.uint8)`` on out-of-range floats wraps around or is
        undefined, which silently corrupts pixels. In-range values are
        truncated exactly as a plain cast would.
        NOTE(review): this assumes the image bands are already on a 0-255
        scale — confirm; data stored as 0-1 reflectance would need rescaling.
        """
        cleaned = np.nan_to_num(bands, nan=0.0, posinf=255.0, neginf=0.0)
        cleaned = np.clip(cleaned, 0.0, 255.0)
        return cleaned.transpose(1, 2, 0).astype(np.uint8)

    def __getitem__(self, idx):
        """Load one patch.

        Returns:
            img: float tensor ``[3, 224, 224]``, ImageNet-normalised.
            weather: 1-D float32 tensor with one entry per ``weather_cols`` name.
            lst: float32 tensor ``[1, 224, 224]``, band 0 bilinearly resized.
        """
        # Positional lookup: the index was reset in __init__, so position == label.
        row  = self.df.iloc[idx]
        path = os.path.join(self.patches_dir, row["filename"])
        arr  = imread(path).astype(np.float32)       # assumed (4,H,W)

        # --- inputs -------------------------------------------------
        img = self.transform(self._to_uint8_rgb(arr[[1, 2, 3]]))  # [3,224,224]

        # --- target (LST) resized to 224×224 ------------------------
        lst = torch.tensor(arr[0], dtype=torch.float32)           # (H,W)
        # F.interpolate wants a batch and a channel axis: (1,1,H,W).
        lst = F.interpolate(lst[None, None], size=(224, 224),
                            mode='bilinear', align_corners=False
                           ).squeeze(0)                           # [1,224,224]

        # --- meteorology vector -------------------------------------
        weather = torch.tensor(
            row[self.weather_cols].values.astype(np.float32)
        )
        return img, weather, lst


In [None]:
# New Cell (before Cell 3)
!pip install rasterio


Collecting rasterio
  Downloading rasterio-1.4.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.1 kB)
Collecting affine (from rasterio)
  Downloading affine-2.4.0-py3-none-any.whl.metadata (4.0 kB)
Collecting cligj>=0.5 (from rasterio)
  Downloading cligj-0.7.2-py3-none-any.whl.metadata (5.0 kB)
Collecting click-plugins (from rasterio)
  Downloading click_plugins-1.1.1-py2.py3-none-any.whl.metadata (6.4 kB)
Downloading rasterio-1.4.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (22.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m22.2/22.2 MB[0m [31m90.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading cligj-0.7.2-py3-none-any.whl (7.1 kB)
Downloading affine-2.4.0-py3-none-any.whl (15 kB)
Downloading click_plugins-1.1.1-py2.py3-none-any.whl (7.5 kB)
Installing collected packages: cligj, click-plugins, affine, rasterio
Successfully installed affine-2.4.0 click-plugins-1.1.1 cligj-0.7.2 rasterio-1.4.3


In [None]:
# Meteorological CSV columns fed to the dataset; the order here is the order
# of the entries in each sample's weather vector.
weather_cols = [
    "air_temp_C", "dew_point_C",              # temperatures
    "relative_humidity_percent",              # moisture
    "wind_speed_m_s", "precipitation_in",     # wind and precipitation
]

In [None]:
import pandas as pd
import numpy as np

In [None]:
# Load the patch index, force the weather columns to numeric (unparseable
# entries become NaN) and drop every row that lacks a weather value or a filename.
df = pd.read_csv("/content/drive/MyDrive/PatchedOutput/tiff_with_meteo.csv")
for col in weather_cols:
    df[col] = pd.to_numeric(df[col], errors="coerce")
df = df.dropna(subset=weather_cols + ["filename"]).reset_index(drop=True)

# NOTE(review): the CSV lives in PatchedOutput but the patches are read from
# PatchedOutput_Cleaned — confirm every remaining filename exists there.
patches_dir = "/content/drive/MyDrive/PatchedOutput_Cleaned"
dataset     = LSTDataset(df, patches_dir, weather_cols)

# 80 / 20 train / validation split.
train_sz    = int(0.8 * len(dataset))
val_sz      = len(dataset) - train_sz
if train_sz == 0 or val_sz == 0:
    raise ValueError(
        f"Need at least 2 usable rows to build a train/val split, got {len(dataset)}"
    )

# A seeded generator makes the split identical across kernel restarts, so
# validation losses and the saved "best" checkpoint are comparable between runs.
SPLIT_SEED = 42
train_ds, val_ds = random_split(
    dataset, [train_sz, val_sz],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train_ds, batch_size=4, shuffle=True,  num_workers=0, pin_memory=False)
val_loader   = DataLoader(val_ds,   batch_size=4, shuffle=False, num_workers=0, pin_memory=False)

In [None]:
# Cell 5: TransUNet-like model (MiT-B0 backbone) + Focal-Tversky + optimizer
import segmentation_models_pytorch as smp

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

# U-Net from segmentation_models_pytorch with an ImageNet-pretrained MiT-B0
# encoder: 3 input channels, 1 output channel.
model = smp.Unet(
    classes=1,
    in_channels=3,
    encoder_weights='imagenet',
    encoder_name='mit_b0',
)
model = model.to(DEVICE)

def focal_tversky_loss(logits, targets,
                       alpha=0.7, beta=0.3, gamma=0.75, eps=1e-6):
    """Focal Tversky loss computed from raw logits.

    The Tversky index is evaluated per sample and per channel over the two
    trailing (spatial) axes, turned into ``(1 - index) ** gamma`` and averaged.

    Args:
        logits: Raw model outputs, indexed as ``[batch, channel, H, W]``;
            a sigmoid is applied here.
        targets: Soft or binary masks with the same shape as ``logits``,
            expected to lie in ``[0, 1]``. Bool/integer masks are accepted.
        alpha: Weight of the false-negative term.
        beta: Weight of the false-positive term.
        gamma: Focal exponent applied to ``1 - tversky``.
        eps: Small constant guarding the ratio and the focal base.

    Returns:
        Scalar tensor with the mean loss.
    """
    # Half-precision logits (e.g. produced under autocast) are promoted so the
    # reductions and the ratio below are carried out in float32.
    if logits.dtype in (torch.float16, torch.bfloat16):
        logits = logits.float()
    probs = torch.sigmoid(logits)
    # Bool/integer masks do not support ``1 - targets``; use the dtype of probs.
    targets = targets.to(probs.dtype)

    spatial = (2, 3)
    tp = (probs * targets).sum(dim=spatial)
    fn = ((1 - probs) * targets).sum(dim=spatial)
    fp = (probs * (1 - targets)).sum(dim=spatial)

    # The clamp keeps the denominator positive even if targets stray outside [0, 1].
    denom = torch.clamp(tp + alpha * fn + beta * fp + eps, min=eps)
    tversky = (tp + eps) / denom

    # With gamma < 1 the derivative of x ** gamma is unbounded at x == 0, so a
    # perfect (saturated) prediction would back-propagate inf/NaN. Clamping the
    # base to eps keeps the gradient finite; the loss floor becomes eps ** gamma.
    focal_base = torch.clamp(1 - tversky, min=eps)
    return (focal_base ** gamma).mean()

# Adam with a small learning rate for fine-tuning the pretrained encoder.
optimizer = optim.Adam(model.parameters(), lr=1e-5)

# torch.cuda.amp.GradScaler() is deprecated and emits a FutureWarning; the
# device-agnostic torch.amp.GradScaler replaces it. Loss scaling is switched
# off when training on the CPU, where the scaler calls become pass-throughs.
scaler    = torch.amp.GradScaler('cuda', enabled=(DEVICE.type == 'cuda'))

# Halve the learning rate after 3 epochs without validation improvement.
# The deprecated `verbose=` argument is dropped; read the current rate from
# optimizer.param_groups[0]['lr'] (or scheduler.get_last_lr()) when needed.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3
)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/135 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/14.3M [00:00<?, ?B/s]

  scaler    = torch.cuda.amp.GradScaler()


In [None]:
# Cell 6: Training loop with live tqdm bars, mixed precision and gradient clipping
NUM_EPOCHS = 10
best_val   = float('inf')

# Mixed precision is only enabled on CUDA; on the CPU autocast is a no-op.
USE_AMP = DEVICE.type == 'cuda'

# Fail early with a clear message instead of dividing by zero in the averages.
if len(train_loader) == 0 or len(val_loader) == 0:
    raise ValueError("train_loader and val_loader must each yield at least one batch")


def _align_masks(masks, preds):
    """Clamp targets to [0, 1] and match them to the prediction's spatial size.

    NOTE(review): clamping assumes the LST band is already scaled to [0, 1];
    anything outside that range is saturated — confirm the target scaling.
    """
    masks = torch.clamp(masks, 0.0, 1.0)
    if masks.shape[2:] != preds.shape[2:]:
        masks = F.interpolate(masks, size=preds.shape[2:], mode='nearest')
    return masks


for epoch in range(1, NUM_EPOCHS+1):
    print(f"\n=== Epoch {epoch}/{NUM_EPOCHS} ===")

    # — Training —
    train_losses = []
    model.train()
    train_bar = tqdm(train_loader, desc='Train', leave=False)
    # The weather vector (second item of each batch) is not used by this model.
    for imgs, _, masks in train_bar:
        imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)

        optimizer.zero_grad()
        # torch.amp.autocast replaces the deprecated torch.cuda.amp.autocast,
        # which emitted a FutureWarning on every iteration.
        with torch.amp.autocast(device_type=DEVICE.type, enabled=USE_AMP):
            preds = model(imgs)
            loss = focal_tversky_loss(preds, _align_masks(masks, preds))

        scaler.scale(loss).backward()
        # Gradients must be unscaled before clipping so the 1.0 threshold
        # applies to their true magnitude.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()

        batch_loss = loss.item()          # single host/device sync per batch
        train_losses.append(batch_loss)
        train_bar.set_postfix(loss=f"{batch_loss:.4f}")

    avg_train = sum(train_losses) / len(train_losses)

    # — Validation — (full precision, no gradients)
    val_losses = []
    model.eval()
    valid_bar = tqdm(val_loader, desc='Valid', leave=False)
    with torch.no_grad():
        for imgs, _, masks in valid_bar:
            imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)

            preds = model(imgs)
            loss = focal_tversky_loss(preds, _align_masks(masks, preds))

            batch_loss = loss.item()
            val_losses.append(batch_loss)
            valid_bar.set_postfix(loss=f"{batch_loss:.4f}")

    avg_val = sum(val_losses) / len(val_losses)
    print(f"Epoch {epoch}/{NUM_EPOCHS} → "
          f"Avg Train: {avg_train:.4f} | Avg Val: {avg_val:.4f}")

    # Step the plateau scheduler on the validation loss and report any change
    # of learning rate explicitly.
    lr_before = optimizer.param_groups[0]['lr']
    scheduler.step(avg_val)
    lr_after = optimizer.param_groups[0]['lr']
    if lr_after != lr_before:
        print(f"Learning rate reduced: {lr_before:.2e} → {lr_after:.2e}")

    # Keep only the checkpoint with the lowest validation loss so far.
    if avg_val < best_val:
        best_val = avg_val
        torch.save(
            model.state_dict(),
            "/content/drive/MyDrive/best_transunet_ftl.pth"
        )
        print("🔖 Saved new best model!")

print("✅ Training complete.")



=== Epoch 1/10 ===


Train:   0%|          | 0/2858 [00:00<?, ?it/s]

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda

Valid:   0%|          | 0/715 [00:00<?, ?it/s]

Epoch 1/10 → Avg Train: 0.4391 | Avg Val: 0.3627
🔖 Saved new best model!

=== Epoch 2/10 ===


Train:   0%|          | 0/2858 [00:00<?, ?it/s]

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda

Valid:   0%|          | 0/715 [00:00<?, ?it/s]

Epoch 2/10 → Avg Train: 0.3558 | Avg Val: 0.3287
🔖 Saved new best model!

=== Epoch 3/10 ===


Train:   0%|          | 0/2858 [00:00<?, ?it/s]

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda

Valid:   0%|          | 0/715 [00:00<?, ?it/s]

Epoch 3/10 → Avg Train: 0.3329 | Avg Val: 0.3142
🔖 Saved new best model!

=== Epoch 4/10 ===


Train:   0%|          | 0/2858 [00:00<?, ?it/s]

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda

Valid:   0%|          | 0/715 [00:00<?, ?it/s]

Epoch 4/10 → Avg Train: 0.3244 | Avg Val: 0.3096
🔖 Saved new best model!

=== Epoch 5/10 ===


Train:   0%|          | 0/2858 [00:00<?, ?it/s]

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda

Valid:   0%|          | 0/715 [00:00<?, ?it/s]

Epoch 5/10 → Avg Train: 0.3212 | Avg Val: 0.3104

=== Epoch 6/10 ===


Train:   0%|          | 0/2858 [00:00<?, ?it/s]

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda

Valid:   0%|          | 0/715 [00:00<?, ?it/s]

Epoch 6/10 → Avg Train: 0.3199 | Avg Val: 0.3082
🔖 Saved new best model!

=== Epoch 7/10 ===


Train:   0%|          | 0/2858 [00:00<?, ?it/s]

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda

Valid:   0%|          | 0/715 [00:00<?, ?it/s]

Epoch 7/10 → Avg Train: 0.3187 | Avg Val: 0.3089

=== Epoch 8/10 ===


Train:   0%|          | 0/2858 [00:00<?, ?it/s]

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda

Valid:   0%|          | 0/715 [00:00<?, ?it/s]

Epoch 8/10 → Avg Train: 0.3182 | Avg Val: 0.3093

=== Epoch 9/10 ===


Train:   0%|          | 0/2858 [00:00<?, ?it/s]

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda

Valid:   0%|          | 0/715 [00:00<?, ?it/s]

Epoch 9/10 → Avg Train: 0.3181 | Avg Val: 0.3087

=== Epoch 10/10 ===


Train:   0%|          | 0/2858 [00:00<?, ?it/s]

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda

Valid:   0%|          | 0/715 [00:00<?, ?it/s]

Epoch 10/10 → Avg Train: 0.3177 | Avg Val: 0.3079
🔖 Saved new best model!
✅ Training complete.
