# PyTorch GPU version of COSMOS multi‑task nowcast
This notebook replicates the original TensorFlow implementation but runs fully on **PyTorch** with GPU support (CUDA, mixed‑precision). Functionality and results should be equivalent.

*The GPU is detected automatically: if CUDA is installed correctly, training runs on the available GPU (e.g. an RTX 3060); otherwise the notebook falls back to the CPU.*

In [13]:

# Imports, GPU & Mixed‑Precision Setup

import os, glob, h5py, numpy as np, pandas as pd, matplotlib.pyplot as plt
import torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import Dataset, DataLoader, IterableDataset
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
from sklearn.metrics import (confusion_matrix, precision_score, recall_score,
                             f1_score, accuracy_score, mean_squared_error, mean_absolute_error)

# Device selection: prefer CUDA when present, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Running on device: {device}')

torch.backends.cudnn.benchmark = True        # speed

# Gradient scaler for float16 mixed precision. It is only enabled on CUDA so
# a CPU-only run does not warn about an unusable scaler. `torch.amp.GradScaler`
# is the non-deprecated constructor; older torch releases lack it, in which
# case the `torch.cuda.amp` class imported above is used instead.
_amp_enabled = device.type == 'cuda'
try:
    scaler = torch.amp.GradScaler('cuda', enabled=_amp_enabled)
except AttributeError:
    scaler = GradScaler(enabled=_amp_enabled)

# Configuration (identical to original)
DATA_DIR      = r"C:\college\CV\COSMOS\6C_full"
SEQ_LEN       = 4        # number of input frames per sample
PATCH_SIZE    = 32       # side length of the random training crops
BATCH_SIZE    = 16
EPOCHS        = 20
THRESHOLD     = 265.0    # TIR1 value below which a pixel is labelled 'cloud'
CV_THRESHOLD  = 260.0    # TIR1 value below which a pixel is labelled 'convective'
FOG_THRESHOLD = 270.0    # MIR value below which a pixel is labelled 'fog'

MODEL_DIR     = "checkpoints_pytorch"
os.makedirs(MODEL_DIR, exist_ok=True)

# Build list of sliding-window sequences.
# Each window holds SEQ_LEN input files followed by one target file, so it is
# SEQ_LEN+1 long; fewer than SEQ_LEN+1 files yields an empty list.
all_files = sorted(glob.glob(os.path.join(DATA_DIR, "*.h5")))
sequences = [all_files[i:i+SEQ_LEN+1] for i in range(len(all_files)-SEQ_LEN)]
print(f"Total sequences: {len(sequences)}")


Running on device: cuda
Total sequences: 896


  scaler = GradScaler()


In [14]:

# Dataset utilities ----------------------------------------------------------
def _read_calibrated(fp, with_vis=False):
    """Read one HDF5 frame and map its raw counts through the lookup tables.

    Every image dataset stores raw counts that index a companion lookup
    table held in the same file.

    Args:
        fp: Path of the HDF5 file.
        with_vis: When true, also return the visible (albedo) channel.

    Returns:
        List of arrays in the order TIR1, TIR2, WV, MIR[, VIS].
    """
    pairs = [
        ('IMG_TIR1', 'IMG_TIR1_TEMP'),
        ('IMG_TIR2', 'IMG_TIR2_TEMP'),
        ('IMG_WV',   'IMG_WV_TEMP'),
        ('IMG_MIR',  'IMG_MIR_TEMP'),
    ]
    if with_vis:
        pairs.append(('IMG_VIS', 'IMG_VIS_ALBEDO'))
    with h5py.File(fp, 'r') as f:
        # lut[counts]: index the whole lookup table with the first image plane.
        return [f[lut][:][f[img][0][...]] for img, lut in pairs]


def load_multi(fp_seq):
    """Load one sliding-window sample: input frames plus multi-task labels.

    Args:
        fp_seq: Sequence of HDF5 paths. The first ``SEQ_LEN`` files are the
            inputs; the last file supplies the labels.

    Returns:
        ``(X, labels)`` where ``X`` is a float32 array of shape
        ``(SEQ_LEN, H, W, 5)`` (channels TIR1, TIR2, WV, MIR, VIS, each
        divided by 300) and ``labels`` is a dict with the float32 maps
        ``cloud``, ``convective``, ``fog``, ``moisture`` and
        ``thermo_contrast`` of shape ``(H, W, 1)`` plus ``temp_trend`` of
        shape ``(1,)``.
    """
    frames = []
    for fp in fp_seq[:SEQ_LEN]:
        channels = _read_calibrated(fp, with_vis=True)
        frame = np.stack(channels, axis=-1) / 300.0      # Normalise
        # Cast per frame so the final stack is already float32; casting the
        # stacked array instead would allocate a second full-size copy.
        frames.append(frame.astype(np.float32, copy=False))
        del channels, frame

    X = np.stack(frames, axis=0)
    del frames

    # Labels come from final frame in sequence
    bt1_t, bt2_t, wv_t, mir_t = _read_calibrated(fp_seq[-1])

    # Scene-mean TIR1 change between the first input frame and the label
    # frame, both on the /300 normalised scale.
    last_mean  = bt1_t.mean()/300.0
    first_mean = X[0,...,0].mean()
    temp_trend = np.array([last_mean - first_mean],dtype=np.float32)

    labels = {
        'cloud'          : (bt1_t<THRESHOLD).astype(np.float32)[...,None],
        'convective'     : (bt1_t<CV_THRESHOLD).astype(np.float32)[...,None],
        'fog'            : (mir_t<FOG_THRESHOLD).astype(np.float32)[...,None],
        'moisture'       : (wv_t/300.0).astype(np.float32)[...,None],
        'thermo_contrast': ((bt2_t-bt1_t)/100.0).astype(np.float32)[...,None],
        'temp_trend'     : temp_trend
    }
    return X, labels

def random_crop(X, y, patch_size=None):
    """Cut the same random square window out of the inputs and label maps.

    Args:
        X: Array of shape ``(T, H, W, C)``.
        y: Dict of label arrays. 3-D entries are cropped on their first two
            axes with the same window; all other entries are passed through.
        patch_size: Side length of the window. Defaults to the module-level
            ``PATCH_SIZE``.

    Returns:
        ``(Xc, yc)``: the cropped inputs and a new dict of cropped labels.
        Crops are basic slices of the inputs, not copies.

    Raises:
        ValueError: If the frame is smaller than the requested patch.
    """
    if patch_size is None:
        patch_size = PATCH_SIZE
    H, W = X.shape[1], X.shape[2]
    if H < patch_size or W < patch_size:
        raise ValueError(
            f"frame of size {H}x{W} is smaller than patch size {patch_size}")
    # randint's upper bound is exclusive: +1 makes the last valid offset
    # reachable and lets a frame of exactly patch_size pass through.
    i = np.random.randint(0, H - patch_size + 1)
    j = np.random.randint(0, W - patch_size + 1)
    Xc = X[:, i:i+patch_size, j:j+patch_size, :]
    yc = {
        k: (v[i:i+patch_size, j:j+patch_size] if v.ndim == 3 else v)
        for k, v in y.items()
    }
    return Xc, yc

class PatchDataset(IterableDataset):
    """Yields infinite random crops exactly like the TF generator.

    Each item is ``(Xc, targets)``: ``Xc`` is a ``(T, C, H, W)`` tensor and
    ``targets`` maps task name to a ``(1, H, W)`` tensor for label maps or
    the unchanged 1-D tensor for scalar targets.

    Args:
        seqs: List of file-path windows accepted by ``load_multi``.
    """
    def __init__(self, seqs):
        super().__init__()
        # Private copy: the list is shuffled in place on every pass and the
        # caller's list must not be reordered as a side effect.
        self.seqs = list(seqs)

    def __iter__(self):
        # Nothing to sample from: stop instead of spinning forever below.
        if not self.seqs:
            return
        while True:
            np.random.shuffle(self.seqs)
            for seq in self.seqs:
                X, y = load_multi(seq)
                Xc, yc = random_crop(X, y)
                # convert to tensors; clone() so the small patch owns its
                # memory instead of being a view of the full-frame array
                patch = torch.from_numpy(Xc).permute(0,3,1,2).clone()    # (T,C,H,W)
                targets = {
                    k: (torch.from_numpy(v).permute(2,0,1) if v.ndim==3
                        else torch.from_numpy(v)).clone()
                    for k,v in yc.items()
                }
                # Release the full frames before suspending at `yield`;
                # otherwise every sample waiting to be batched would keep a
                # whole frame stack alive.
                del X, y, Xc, yc
                yield patch, targets


In [15]:

# DataLoader construction ----------------------------------------------------
import torch.multiprocessing as mp

# Force the 'spawn' start method for any worker processes.
mp.set_start_method('spawn', force=True)

# First 90% of the windows train the model, the remaining 10% validate it.
split = int(0.9*len(sequences))
train_seqs, val_seqs = sequences[:split], sequences[split:]

train_dataset, val_dataset = PatchDataset(train_seqs), PatchDataset(val_seqs)

# Both loaders share the same settings. num_workers=0 keeps loading in the
# main process (single process avoids HDF5 deadlocks).
train_loader, val_loader = (
    DataLoader(ds, batch_size=BATCH_SIZE, num_workers=0, pin_memory=True)
    for ds in (train_dataset, val_dataset)
)


In [16]:

# Model definition -----------------------------------------------------------
class ConvLSTMCell(nn.Module):
    """Single convolutional LSTM step.

    One convolution over the concatenated input and previous hidden state
    produces all four gates at once.

    Args:
        input_channels: Channels of the input ``x``.
        hidden_channels: Channels of the hidden and cell states.
        kernel_size: Convolution kernel size; padding is ``kernel_size // 2``.
    """
    def __init__(self, input_channels, hidden_channels, kernel_size=3):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.conv = nn.Conv2d(
            in_channels=input_channels + hidden_channels,
            out_channels=4 * hidden_channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            bias=True,
        )

    def forward(self, x, h_prev, c_prev):
        """Advance the state by one step and return the new ``(h, c)``."""
        stacked = torch.cat((x, h_prev), dim=1)
        # Channel order of the fused convolution output: input, forget,
        # output, candidate.
        in_gate, forget_gate, out_gate, candidate = torch.chunk(
            self.conv(stacked), 4, dim=1)
        keep = torch.sigmoid(forget_gate) * c_prev
        write = torch.sigmoid(in_gate) * torch.tanh(candidate)
        c = keep + write
        h = torch.sigmoid(out_gate) * torch.tanh(c)
        return h, c

class ConvLSTM(nn.Module):
    """Multi-step ConvLSTM with optional sequence output.

    Args:
        input_channels: Channels of each input frame.
        hidden_channels: Channels of the hidden and cell states.
        kernel_size: Kernel size passed to the underlying cell.
        return_sequences: When true, ``forward`` returns the hidden state of
            every time step; otherwise only the last one.
    """
    def __init__(self, input_channels, hidden_channels, kernel_size=3, return_sequences=False):
        super().__init__()
        self.cell = ConvLSTMCell(input_channels, hidden_channels, kernel_size)
        self.return_sequences = return_sequences

    def forward(self, x):
        """Run the cell over the time axis of ``x`` of shape (B, T, C, H, W).

        Returns:
            ``(B, T, hidden, H, W)`` when ``return_sequences`` is set,
            otherwise the final hidden state ``(B, hidden, H, W)``.
        """
        B, T, _, H, W = x.shape
        # new_zeros inherits both device and dtype from x, so the initial
        # state matches the input instead of always being float32.
        h = x.new_zeros(B, self.cell.hidden_channels, H, W)
        c = torch.zeros_like(h)
        seq_out = []
        for t in range(T):
            h, c = self.cell(x[:, t], h, c)
            if self.return_sequences:
                seq_out.append(h)
        if self.return_sequences:
            return torch.stack(seq_out, dim=1)  # (B, T, hidden, H, W)
        return h  # (B, hidden, H, W)

class MultiTaskNowcast(nn.Module):
    """Two stacked ConvLSTM layers feeding six task heads.

    ``forward`` returns a dict with:
      * ``cloud``, ``convective``, ``fog``: logits of shape (B, 1, H, W)
      * ``moisture``, ``thermo_contrast``: maps of shape (B, H, W)
      * ``temp_trend``: one value per sample, shape (B, 1)
    """
    def __init__(self):
        super().__init__()
        # Recurrent trunk: 5 input channels -> 32 -> 16 feature channels.
        self.convlstm1 = ConvLSTM(5, 32, return_sequences=True)
        self.bn1 = nn.BatchNorm2d(32)
        self.convlstm2 = ConvLSTM(32, 16, return_sequences=False)
        self.bn2 = nn.BatchNorm2d(16)

        # One 1x1 convolution per dense (per-pixel) task.
        dense_tasks = ('cloud', 'convective', 'fog', 'moisture', 'thermo_contrast')
        self.heads = nn.ModuleDict({task: nn.Conv2d(16, 1, 1) for task in dense_tasks})

        # Scalar head: global average pool followed by a linear layer.
        self.temp_pool = nn.AdaptiveAvgPool2d(1)
        self.temp_fc   = nn.Linear(16, 1)

    def forward(self, x):
        # x: (B, T, C, H, W)
        seq = self.convlstm1(x)                          # (B, T, 32, H, W)
        # BatchNorm2d needs 4-D input: fold time into the batch axis,
        # normalise, then restore the original layout.
        B, T, C, H, W = seq.shape
        seq = self.bn1(seq.reshape(B*T, C, H, W)).reshape(B, T, C, H, W)
        feats = self.bn2(self.convlstm2(seq))            # (B, 16, H, W)

        squeezed_tasks = ('moisture', 'thermo_contrast')
        out = {}
        for task, head in self.heads.items():
            y = head(feats)
            # The two regression maps drop their singleton channel axis.
            out[task] = y.squeeze(1) if task in squeezed_tasks else y
        pooled = self.temp_pool(feats).flatten(1)        # (B, 16)
        out['temp_trend'] = self.temp_fc(pooled)
        return out

# Build the network, move it to the selected device and report its size.
model = MultiTaskNowcast()
model = model.to(device)
print(f"Model parameters: {sum(map(torch.numel, model.parameters()))/1e6:.2f}M")


Model parameters: 0.07M


In [17]:

# Losses, optimizer, metrics -----------------------------------------------
# Binary cross-entropy on raw logits for the mask tasks, mean-squared error
# for the regression tasks.
bce = nn.BCEWithLogitsLoss()
mse = nn.MSELoss()

# Relative contribution of each task to the total loss.
loss_weights = {
    'cloud': 1,
    'convective': 1,
    'fog': 1,
    'moisture': 0.5,
    'thermo_contrast': 0.5,
    'temp_trend': 0.1,
}

def compute_loss(preds, targets):
    """Return the weighted sum of the per-task losses.

    Args:
        preds: Dict of model outputs keyed by task name.
        targets: Dict of ground-truth tensors with the same keys.

    Both sides are cast to float32 first. Regression pairs are squeezed so
    a prediction without a channel axis lines up with a target that has one.
    """
    mask_tasks = ('cloud', 'convective', 'fog')
    total = 0.0
    for task, weight in loss_weights.items():
        pred = preds[task].float()
        target = targets[task].float()
        if task not in mask_tasks:
            total = total + weight * mse(pred.squeeze(), target.squeeze())
            continue
        total = total + weight * bce(pred, target)
    return total

# Adam over every model parameter with a learning rate of 1e-3.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)


In [18]:
from torch.amp import autocast

# Training loop --------------------------------------------------------------
# The datasets are infinite, so an "epoch" is a fixed number of batches:
# len(seqs)//BATCH_SIZE draws for training and for validation.
best_acc = 0.0                       # best validation cloud accuracy so far
train_history, val_history = [], []  # mean loss per epoch

# Mixed precision is used on CUDA only; on CPU autocast is switched off
# explicitly instead of requesting a CUDA context that does not exist.
amp_device  = device.type
amp_enabled = amp_device == 'cuda'

for epoch in range(1, EPOCHS+1):
    model.train()
    train_loss = 0.0
    pbar = tqdm(range(len(train_seqs)//BATCH_SIZE), desc=f'Epoch {epoch}/{EPOCHS} [train]')
    train_iter = iter(train_loader)
    for step in pbar:
        Xb, yb = next(train_iter)
        Xb = Xb.to(device, non_blocking=True)
        yb = {k: v.to(device, non_blocking=True) for k,v in yb.items()}

        optimizer.zero_grad()
        with autocast(device_type=amp_device, enabled=amp_enabled):
            preds = model(Xb)
            loss = compute_loss(preds, yb)
        # Scale the loss before backward, then let the scaler drive the
        # optimizer step and update its scale factor.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        train_loss += loss.item()
        pbar.set_postfix({'loss': train_loss/(step+1)})
    train_loss /= max(1, len(pbar))      # guard: zero steps when data < one batch
    train_history.append(train_loss)

    # Validation ------------------------------------------------------------
    model.eval()
    val_loss = 0.0
    correct_cloud, total_cloud = 0,0
    with torch.no_grad():
        pbar_val = tqdm(range(len(val_seqs)//BATCH_SIZE), desc=f'Epoch {epoch} [val]')
        val_iter = iter(val_loader)
        for step in pbar_val:
            Xb, yb = next(val_iter)
            Xb = Xb.to(device, non_blocking=True)
            yb = {k: v.to(device, non_blocking=True) for k,v in yb.items()}
            with autocast(device_type=amp_device, enabled=amp_enabled):
                preds = model(Xb)
                loss = compute_loss(preds, yb)
            val_loss += loss.item()

            # Pixel accuracy of the cloud mask at a 0.5 probability cut-off.
            preds_cloud = torch.sigmoid(preds['cloud'])>0.5
            correct_cloud += (preds_cloud == yb['cloud'].bool()).sum().item()
            total_cloud += yb['cloud'].numel()
            pbar_val.set_postfix({'loss': val_loss/(step+1), 'cloud_acc': correct_cloud/total_cloud})
    val_loss /= max(1, len(pbar_val))
    val_history.append(val_loss)
    cloud_acc = correct_cloud/total_cloud if total_cloud else 0.0

    # Save checkpoint -------------------------------------------------------
    checkpoint = {'epoch':epoch,
                  'model_state_dict':model.state_dict(),
                  'optimizer_state_dict':optimizer.state_dict()}
    torch.save(checkpoint, os.path.join(MODEL_DIR, f"model_epoch_{epoch:02d}.pt"))

    # Keep a separate copy of the best model by validation cloud accuracy.
    if cloud_acc > best_acc:
        best_acc = cloud_acc
        torch.save(checkpoint, os.path.join(MODEL_DIR, "best_model.pt"))


Epoch 1/20 [train]:   2%|▏         | 1/50 [02:15<1:50:30, 135.31s/it, loss=2.88]


MemoryError: Unable to allocate 749. MiB for an array with shape (4, 3207, 3062, 5) and data type float32

In [None]:

# Final Evaluation on validation set ----------------------------------------
from itertools import islice

seg_keys = ["cloud","convective","fog"]
reg_keys = ["moisture","thermo_contrast","temp_trend"]
max_eval_batches = 100   # limited batches for speed

y_true_seg, y_pred_seg = {k:[] for k in seg_keys}, {k:[] for k in seg_keys}
y_true_reg, y_pred_reg = {k:[] for k in reg_keys}, {k:[] for k in reg_keys}

model.eval()
with torch.no_grad():
    # The validation dataset yields crops forever and tqdm's `total` only
    # sizes the progress bar, so the iteration must be capped explicitly.
    eval_batches = islice(val_loader, max_eval_batches)
    for Xb, yb in tqdm(eval_batches, total=max_eval_batches, desc="Evaluation"):
        Xb = Xb.to(device, non_blocking=True)
        preds = model(Xb)
        # collect predictions (flattened, on the CPU) next to their targets
        for k in seg_keys:
            y_true_seg[k].append(yb[k].flatten().cpu().numpy())
            y_pred_seg[k].append(torch.sigmoid(preds[k]).flatten().cpu().numpy())
        for k in reg_keys:
            y_true_reg[k].append(yb[k].flatten().cpu().numpy())
            y_pred_reg[k].append(preds[k].flatten().cpu().numpy())

# Compute simple metrics
# Mask tasks: accuracy after thresholding the probability at 0.5.
for k in seg_keys:
    yt = np.concatenate(y_true_seg[k])
    yp = np.concatenate(y_pred_seg[k])>0.5
    acc = accuracy_score(yt, yp)
    print(f"{k} accuracy: {acc:.3f}")

# Regression tasks: mean squared error over all collected values.
for k in reg_keys:
    yt = np.concatenate(y_true_reg[k])
    yp = np.concatenate(y_pred_reg[k])
    mse_val = mean_squared_error(yt, yp)
    print(f"{k} MSE: {mse_val:.4f}")
