**Model**

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    """Two consecutive 3x3 convolution stages.

    Each stage is ``Conv2d(kernel_size=3, padding=1, bias=False)`` followed by
    ``BatchNorm2d`` and an in-place ``ReLU``. With padding 1 the spatial size
    of the input is preserved; only the channel count changes
    (``in_ch -> out_ch -> out_ch``).

    Args:
        in_ch: number of input channels.
        out_ch: number of channels produced by both stages.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # The first stage consumes ``in_ch`` channels, the second ``out_ch``.
        # Layers are created in the same order as they are applied so the
        # Sequential indices (0..5) and the state_dict keys stay stable.
        for stage_in in (in_ch, out_ch):
            stages.extend(
                [
                    nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1, bias=False),
                    nn.BatchNorm2d(out_ch),
                    nn.ReLU(inplace=True),
                ]
            )
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv/BN/ReLU stages to ``x`` and return the result."""
        return self.double_conv(x)

class TransformerSkip(nn.Module):
    """Refine a feature map with one transformer encoder layer.

    Every spatial position of the ``(B, C, H, W)`` input is treated as one
    token of dimension ``C``; the ``H * W`` tokens of a sample attend to each
    other and are then folded back into a ``(B, C, H, W)`` map.

    Args:
        channels: channel count of the feature map (used as ``d_model``);
            ``nn.MultiheadAttention`` requires it to be divisible by ``nhead``.
        nhead: number of attention heads.
        dim_feedforward: hidden width of the encoder layer's feed-forward part.
    """

    def __init__(self, channels, nhead=4, dim_feedforward=512):
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=channels,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=1)

    def forward(self, x):
        """Return a tensor with the same shape as ``x`` after self-attention."""
        B, C, H, W = x.shape
        # flatten/reshape (instead of view) also accept non-contiguous inputs
        # such as transposed or sliced feature maps; view raised on those.
        tokens = x.flatten(2).transpose(1, 2)          # (B, H*W, C)
        tokens = self.transformer(tokens)              # (B, H*W, C)
        return tokens.transpose(1, 2).reshape(B, C, H, W)

class Down(nn.Module):
    """Encoder stage: 2x max-pool, ``DoubleConv``, and a transformer skip.

    ``forward`` returns a pair ``(features, skip)``: ``features`` is the
    pooled-and-convolved map handed to the next encoder stage, ``skip`` is
    the same map after ``TransformerSkip`` and is meant for the decoder.

    Args:
        in_ch: channels of the incoming feature map.
        out_ch: channels produced by this stage (also the skip width).
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Sub-modules are registered in the order pool, conv, skip_tr so the
        # parameter / state_dict ordering does not change.
        self.pool = nn.MaxPool2d(2)
        self.conv = DoubleConv(in_ch, out_ch)
        self.skip_tr = TransformerSkip(out_ch)

    def forward(self, x):
        """Downsample ``x`` and return ``(features, transformer_skip)``."""
        features = self.conv(self.pool(x))
        return features, self.skip_tr(features)

class Up(nn.Module):
    """Decoder stage: upsample 2x, concatenate with a skip, then ``DoubleConv``.

    Args:
        dec_ch: channels of the decoder input that gets upsampled.
        skip_ch: channels of the skip connection.
        out_ch: channels produced by the stage.
        bilinear: if True use bilinear ``nn.Upsample`` (no parameters),
            otherwise a learned ``ConvTranspose2d`` that keeps ``dec_ch``.
    """

    def __init__(self, dec_ch, skip_ch, out_ch, bilinear=True):
        super().__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(dec_ch, dec_ch, kernel_size=2, stride=2)
        self.conv = DoubleConv(dec_ch + skip_ch, out_ch)

    def forward(self, x, skip):
        """Upsample ``x``, align it to ``skip`` and fuse both maps.

        The concatenation order is ``[skip, x]`` along the channel axis.
        """
        x = self.up(x)

        # MaxPool2d(2) floors odd sizes, so after upsampling ``x`` can be one
        # pixel smaller than ``skip``. Zero-pad it to the skip's size so the
        # concatenation works for inputs that are not a multiple of 8.
        # When the sizes already match nothing is done.
        diff_h = skip.size(2) - x.size(2)
        diff_w = skip.size(3) - x.size(3)
        if diff_h or diff_w:
            x = F.pad(
                x,
                [diff_w // 2, diff_w - diff_w // 2,
                 diff_h // 2, diff_h - diff_h // 2],
            )

        x = torch.cat([skip, x], dim=1)
        return self.conv(x)

class UNetWithTransformerSkips(nn.Module):
    """U-Net with transformer-refined skips and two output heads.

    The encoder widths are 32 / 64 / 128 / 256 with three 2x poolings, so the
    bottleneck works at 1/8 of the input resolution. Skips from the 32-, 64-
    and 128-channel levels pass through ``TransformerSkip`` before they are
    fused in the decoder.

    Args:
        in_channels: channels of the input tensor.
        n_classes: number of segmentation classes (channels of ``seg_logits``).
        bilinear: forwarded to every ``Up`` stage.

    ``forward`` returns ``(seg_logits, masked_reg)``:
        * ``seg_logits`` - ``(B, n_classes, H, W)`` raw class scores.
        * ``masked_reg`` - ``(B, 1, H, W)`` regression output, multiplied by
          zero wherever the arg-max class is not 1.
    """

    def __init__(self, in_channels=38, n_classes=3, bilinear=True):
        super().__init__()

        self.inc    = DoubleConv(in_channels, 32)
        self.skip0  = TransformerSkip(32)

        self.down1  = Down(32,  64)
        self.down2  = Down(64, 128)
        self.down3  = Down(128, 256)

        self.bottleneck = DoubleConv(256, 256)

        self.up1 = Up(dec_ch=256, skip_ch=128, out_ch=128, bilinear=bilinear)
        self.up2 = Up(dec_ch=128, skip_ch=64, out_ch=64, bilinear=bilinear)
        self.up3 = Up(dec_ch=64, skip_ch=32,  out_ch=32,  bilinear=bilinear)

        self.seg_head = nn.Conv2d(32, n_classes, kernel_size=1)
        self.reg_head = nn.Conv2d(32, 1,         kernel_size=1)

    def forward(self, x):
        """Run the network and return ``(seg_logits, masked_reg)``."""
        # Encoder
        x0 = self.inc(x)
        s0 = self.skip0(x0)

        x1, s1 = self.down1(x0)
        x2, s2 = self.down2(x1)
        # The deepest level has no matching decoder stage, so its transformer
        # skip is discarded. The sub-module is kept (and still called) so
        # existing checkpoints keep loading unchanged.
        x3, _  = self.down3(x2)

        b = self.bottleneck(x3)

        # Decoder
        u1 = self.up1(b,  s2)
        u2 = self.up2(u1, s1)
        u3 = self.up3(u2, s0)

        seg_logits = self.seg_head(u3)
        reg_map    = self.reg_head(u3)

        # Only the mask is built without gradient tracking. The multiplication
        # must stay outside no_grad: doing it inside detaches masked_reg from
        # the graph, so the regression loss could never update reg_head or
        # the backbone.
        with torch.no_grad():
            predicted_class = torch.argmax(seg_logits, dim=1, keepdim=True)
            class1_mask = (predicted_class == 1).float()
        masked_reg = reg_map * class1_mask

        return seg_logits, masked_reg


**FirstDataLoader**

In [None]:
import numpy as np
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader

# Load the stage-1 training array and split it into inputs and targets.
# Channel layout used below: 0-37 -> input features, 39 -> class flag,
# 40 -> rate. (Channel 38 is not read here.)
arr = np.load("ReadytoTrainDLNew2.npy")
features = arr[:, :38, :, :].astype(np.float32)
rates    = arr[:, 40, :, :].astype(np.float32)
flags    = arr[:, 39, :, :].astype(np.int64)

del arr

# Regress on log(rate + eps) so that zero rates stay finite.
eps       = 1e-4
log_rates = np.log(rates + eps)
del rates

# Split BEFORE computing normalisation statistics: fitting mean/std on the
# full array lets information from the held-out samples leak into training.
X_train, X_test, y_log_train, y_log_test, y_flag_train, y_flag_test = train_test_split(
    features, log_rates, flags,
    test_size=0.2,
    random_state=42
)
del log_rates,features,flags

# Per-channel statistics from the training split only, shape (1, 38, 1, 1).
mean = X_train.mean(axis=(0,2,3), keepdims=True)
std  = X_train.std(axis=(0,2,3), keepdims=True)

# In-place normalisation avoids holding a second copy of each split.
X_train -= mean
X_train /= (std + 1e-6)
X_test  -= mean
X_test  /= (std + 1e-6)

# Persisted so fine-tuning and inference apply the same normalisation.
np.save('std.npy',std)
np.save('mean.npy',mean)

class PrecipDataset(Dataset):
    """In-memory dataset yielding ``(features, flag, log_rate)`` per sample.

    The three NumPy arrays are wrapped with ``torch.from_numpy`` and indexed
    along their first axis.

    Args:
        feats: input feature array.
        log_r: regression target array.
        flags: class-label array.
    """

    def __init__(self, feats, log_r, flags):
        self.X, self.y_rate, self.y_flag = (
            torch.from_numpy(array) for array in (feats, log_r, flags)
        )

    def __len__(self):
        """Number of samples (length of the feature tensor's first axis)."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return ``(features, flag, log_rate)`` for sample ``idx``."""
        sample = self.X[idx]
        flag = self.y_flag[idx]
        log_rate = self.y_rate[idx]
        return sample, flag, log_rate

# Wrap both splits in datasets, then drop the array names: the datasets hold
# their own tensors, so the module-level references are no longer needed.
train_ds, test_ds = (
    PrecipDataset(X_train, y_log_train, y_flag_train),
    PrecipDataset(X_test,  y_log_test,  y_flag_test),
)

del X_train, y_log_train, y_flag_train, X_test, y_log_test, y_flag_test

batch_size = 128

# Only the training loader shuffles; evaluation order stays fixed.
train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)
test_loader = DataLoader(
    test_ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
)

# The loaders keep the datasets alive; the extra names can go.
del train_ds, test_ds


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

# ---------------------------------------------------------------------------
# Stage-1 training: joint segmentation (cross-entropy) + regression (MSE).
# Stops as soon as the validation segmentation loss drops below 0.15.
# ---------------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = UNetWithTransformerSkips(in_channels=38, n_classes=2).to(device)


# Inverse-frequency class weights, rescaled to sum to 2. With freq = [1, 1]
# both weights are 1.0, i.e. the loss is effectively unweighted.
freq = torch.tensor([1, 1], device=device)
class_weights = (1.0 / freq)
class_weights = class_weights * (2.0 / class_weights.sum())

seg_criterion = nn.CrossEntropyLoss(weight=class_weights)
reg_criterion = nn.MSELoss()

optimizer = optim.Adam(model.parameters(), lr=1e-2 / 2)

scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=15, T_mult=1, eta_min=1e-6)

lambda_reg = 0.09   # weight of the regression term in the total loss
num_epochs = 30

# Per-epoch histories.
TRAIN = []
VAL = []
VAL_SEG = []
VAL_REG = []
num_batches = len(train_loader)
for epoch in range(1, num_epochs + 1):
    model.train()
    running_loss = 0.0
    for batch_idx, (X, flags, log_rates) in enumerate(train_loader):
        X, flags, log_rates = X.to(device), flags.to(device), log_rates.to(device)
        seg_logits, reg_map = model(X)

        loss_seg = seg_criterion(seg_logits, flags)
        loss_reg = reg_criterion(reg_map.squeeze(1), log_rates)
        loss = loss_seg + lambda_reg * loss_reg

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # The scheduler expects a 0-based fractional epoch. ``epoch`` is
        # 1-based here, so passing it directly shortened the first cosine
        # cycle to 14 epochs and restarted the learning rate on epoch 30.
        scheduler.step((epoch - 1) + batch_idx / num_batches)

        # Weight by the batch size so the epoch mean is per-sample.
        running_loss += loss.item() * X.size(0)

    train_loss = running_loss / len(train_loader.dataset)


    model.eval()
    val_loss = 0.0
    val_seg_loss = 0.0
    val_reg_loss = 0.0
    correct_pixels = 0
    total_pixels = 0

    with torch.no_grad():
        for X, flags, log_rates in test_loader:
            X, flags, log_rates = X.to(device), flags.to(device), log_rates.to(device)

            seg_logits, reg_map = model(X)
            loss_seg = seg_criterion(seg_logits, flags)
            loss_reg = reg_criterion(reg_map.squeeze(1), log_rates)
            loss = loss_seg + lambda_reg * loss_reg

            # Local name on purpose: reusing ``batch_size`` here would
            # overwrite the DataLoader hyper-parameter defined earlier.
            n_samples = X.size(0)
            val_loss += loss.item() * n_samples
            val_seg_loss += loss_seg.item() * n_samples
            val_reg_loss += loss_reg.item() * n_samples

            preds = seg_logits.argmax(dim=1)
            correct_pixels += (preds == flags).sum().item()
            total_pixels += flags.numel()

    val_loss /= len(test_loader.dataset)
    val_seg_loss /= len(test_loader.dataset)
    val_reg_loss /= len(test_loader.dataset)
    val_acc = correct_pixels / total_pixels
    TRAIN.append(train_loss)
    VAL.append(val_loss)
    VAL_SEG.append(val_seg_loss)
    VAL_REG.append(val_reg_loss)


    print(f"Epoch {epoch:02d} | "
          f"Train Loss: {train_loss:.4f} | "
          f" Val Loss: {val_loss:.4f} (Seg: {val_seg_loss:.4f}, Reg: {val_reg_loss:.4f}) | "
          f" Seg Acc: {val_acc:.4f}")

    # Early stop: save weights and histories once the target is reached.
    # NOTE(review): if the target is never reached nothing is written to
    # disk - confirm that this is intended.
    # NOTE(review): later cells load "model_weights_R22.pth" while this one
    # writes "model_weights_R2.pth" - verify which file is meant.
    if val_seg_loss < 0.15:
        torch.save(model.state_dict(), "model_weights_R2.pth")
        np.save('TRAIN4.npy', np.array(TRAIN))
        np.save('VAL4.npy', np.array(VAL))
        np.save('VAL_SEG4.npy', np.array(VAL_SEG))
        np.save('VAL_REG4.npy', np.array(VAL_REG))

        break


**SecondDataLoader**

In [None]:
import numpy as np
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader

# Load the stage-2 (fine-tuning) array. Channel layout used below:
# 0-37 -> input features, 39 -> class flag, 40 -> rate.
arr = np.load("FinalDATANew2.npy")

features = arr[:, :38, :, :].astype(np.float32)
flags    = arr[:, 39, :, :].astype(np.int64)
rates    = arr[:, 40, :, :].astype(np.float32)

# Same log transform as in stage 1.
eps       = 1e-4
log_rates = np.log(rates + eps)

# Reuse the normalisation statistics written by the first data-loading cell
# so both stages feed the network identically scaled inputs.
std  = np.load('std.npy')
mean = np.load('mean.npy')
features = (features - mean) / (std + 1e-6)

# 80/20 split with a fixed seed; outputs come back in (train, test) pairs
# in the order the arrays were passed.
(
    X_train, X_test,
    y_log_train, y_log_test,
    y_flag_train, y_flag_test,
) = train_test_split(
    features,
    log_rates,
    flags,
    test_size=0.2,
    random_state=42,
)

class PrecipDataset(Dataset):
    """In-memory dataset yielding ``(features, flag, log_rate)`` per sample.

    The three NumPy arrays are wrapped with ``torch.from_numpy`` and indexed
    along their first axis.

    Args:
        feats: input feature array.
        log_r: regression target array.
        flags: class-label array.
    """

    def __init__(self, feats, log_r, flags):
        self.X, self.y_rate, self.y_flag = (
            torch.from_numpy(array) for array in (feats, log_r, flags)
        )

    def __len__(self):
        """Number of samples (length of the feature tensor's first axis)."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return ``(features, flag, log_rate)`` for sample ``idx``."""
        sample = self.X[idx]
        flag = self.y_flag[idx]
        log_rate = self.y_rate[idx]
        return sample, flag, log_rate

# Datasets for the fine-tuning stage (kept as module-level names).
train_ds, test_ds = (
    PrecipDataset(X_train, y_log_train, y_flag_train),
    PrecipDataset(X_test,  y_log_test,  y_flag_test),
)

batch_size = 128

# Only the training loader shuffles; evaluation order stays fixed.
train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)
test_loader = DataLoader(
    test_ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
)


In [None]:
# ---------------------------------------------------------------------------
# Fine-tuning setup: load the stage-1 weights, freeze the encoder's
# convolution blocks, and optimise only what is left trainable.
# ---------------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNetWithTransformerSkips(in_channels=38, n_classes=2).to(device)
# map_location makes the checkpoint loadable on a CPU-only machine even when
# it was saved from a CUDA model.
model.load_state_dict(torch.load('model_weights_R22.pth', map_location=device))

# Freeze the DoubleConv of the stem and of every encoder stage. The
# transformer skips, bottleneck, decoder and heads stay trainable.
# NOTE(review): requires_grad=False does not stop BatchNorm running statistics
# from updating while the model is in train() mode - confirm whether those
# layers should also be put in eval() during fine-tuning.
for p in model.inc.double_conv.parameters():
    p.requires_grad = False

for down in (model.down1, model.down2, model.down3):
    for p in down.conv.double_conv.parameters():
        p.requires_grad = False


# Sanity check: every encoder double_conv parameter must now be frozen.
for name, p in model.named_parameters():
    if "double_conv" in name and any(x in name for x in ["inc", "down1", "down2", "down3"]):
        assert not p.requires_grad

# Hand only the trainable parameters to the optimiser.
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=1e-4
)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

# ---------------------------------------------------------------------------
# Stage-2 (fine-tuning) loop. Uses the ``model`` / ``optimizer`` prepared in
# the previous cell, a fixed learning rate (no scheduler) and always runs
# the full ``num_epochs``.
# ---------------------------------------------------------------------------
# Inverse-frequency class weights, rescaled to sum to 2. With freq = [1, 1]
# both weights are 1.0, i.e. the loss is effectively unweighted.
freq = torch.tensor([1, 1], device=device)
class_weights = (1.0 / freq)
class_weights = class_weights * (2.0 / class_weights.sum())

seg_criterion = nn.CrossEntropyLoss(weight=class_weights)
reg_criterion = nn.MSELoss()


lambda_reg = 0.2    # weight of the regression term in the total loss
num_epochs = 30

# Per-epoch histories.
TRAIN = []
VAL = []
VAL_SEG = []
VAL_REG = []
for epoch in range(1, num_epochs + 1):

    # NOTE(review): train() also switches the frozen encoder's BatchNorm
    # layers to batch statistics - confirm that this is intended.
    model.train()
    running_loss = 0.0
    for batch_idx, (X, flags, log_rates) in enumerate(train_loader):
        X, flags, log_rates = X.to(device), flags.to(device), log_rates.to(device)
        seg_logits, reg_map = model(X)

        loss_seg = seg_criterion(seg_logits, flags)
        loss_reg = reg_criterion(reg_map.squeeze(1), log_rates)
        loss = loss_seg + lambda_reg * loss_reg

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Weight by the batch size so the epoch mean is per-sample.
        running_loss += loss.item() * X.size(0)

    train_loss = running_loss / len(train_loader.dataset)

    model.eval()
    val_loss = 0.0
    val_seg_loss = 0.0
    val_reg_loss = 0.0
    correct_pixels = 0
    total_pixels = 0

    with torch.no_grad():
        for X, flags, log_rates in test_loader:
            X, flags, log_rates = X.to(device), flags.to(device), log_rates.to(device)

            seg_logits, reg_map = model(X)
            loss_seg = seg_criterion(seg_logits, flags)
            loss_reg = reg_criterion(reg_map.squeeze(1), log_rates)
            loss = loss_seg + lambda_reg * loss_reg

            # Local name on purpose: reusing ``batch_size`` here would
            # overwrite the DataLoader hyper-parameter defined earlier.
            n_samples = X.size(0)
            val_loss += loss.item() * n_samples
            val_seg_loss += loss_seg.item() * n_samples
            val_reg_loss += loss_reg.item() * n_samples

            preds = seg_logits.argmax(dim=1)
            correct_pixels += (preds == flags).sum().item()
            total_pixels += flags.numel()

    val_loss /= len(test_loader.dataset)
    val_seg_loss /= len(test_loader.dataset)
    val_reg_loss /= len(test_loader.dataset)
    val_acc = correct_pixels / total_pixels
    TRAIN.append(train_loss)
    VAL.append(val_loss)
    VAL_SEG.append(val_seg_loss)
    VAL_REG.append(val_reg_loss)



    print(f"Epoch {epoch:02d} | "
          f"Train Loss: {train_loss:.4f} | "
          f" Val Loss: {val_loss:.4f} (Seg: {val_seg_loss:.4f}, Reg: {val_reg_loss:.4f}) | "
          f" Seg Acc: {val_acc:.4f}")

# Persist the fine-tuned weights and the loss histories after the last epoch.
torch.save(model.state_dict(), "model_weights_JustRAINTT.pth")
np.save('TRAIN4T.npy', np.array(TRAIN))
np.save('VAL4T.npy', np.array(VAL))
np.save('VAL_SEG4T.npy', np.array(VAL_SEG))
np.save('VAL_REG4T.npy', np.array(VAL_REG))
        

# Retrieving Rainfall on a Real Orbit

In [None]:
import h5py
import numpy as np
def CollectingData(filename):
    """Read one HDF5 orbit file and assemble a channel-first array.

    The fields are concatenated along the last axis in the order
    ``Tc``, first slice of ``incidenceAngle``, ``Latitude``, ``Longitude``
    and the ERA5 variables listed below, then moved to channel-first layout.
    The last two ERA5 channels (``mtpr``, ``msr``) are not returned as
    inputs; they are turned into a rate channel and a flag channel instead.

    Args:
        filename: path of the HDF5 file containing the ``/ERA5`` and
            ``/NewGrids`` groups.

    Returns:
        Array of shape ``(C, H, W)`` whose channels are every assembled
        channel except the last two, followed by ``rate`` and ``flag``
        (flag: 0 = none, 1 = rain, 2 = snow).
    """
    era5_features = ['tciw', 'tclw', 'tcwv', 't2m', 'cape', 'u10', 'v10', 'skt', 'asn','rsn', 'tcslw', 'tcw', 'lsm', 'mtpr', 'msr']

    # Read everything eagerly inside the context manager so the file handle
    # is always released, even if one of the reads fails.
    with h5py.File(filename, 'r') as dataset:
        ERA5_Data = np.stack(
            [dataset[f'/ERA5/{variables}'][:] for variables in era5_features],
            axis=2,
        )

        Latitude = dataset['/NewGrids/Latitude'][:]
        Longitude = dataset['/NewGrids/Longitude'][:]

        Tc = dataset['/NewGrids/Tc'][:]
        # Only the first incidence-angle slice is kept, as a single channel.
        incidenceAngle = np.expand_dims(dataset['/NewGrids/incidenceAngle'][:][:,:,0], axis = 2)

    SP_Data = np.stack([Latitude,Longitude], axis=2)
    ATMS_data = np.concatenate((Tc,incidenceAngle), axis=2)

    Final_Data = np.concatenate((ATMS_data,SP_Data,ERA5_Data), axis=2)

    # (H, W, C) -> (C, H, W)
    Final_Data = np.transpose(Final_Data, (2, 0, 1))

    # Last two channels are mtpr and msr. Scaling by 3600 presumably converts
    # per-second rates to per-hour - TODO confirm the units.
    Rain = (Final_Data[-2]-Final_Data[-1])*3600
    Snow = (Final_Data[-1])*3600
    R = ((Rain>0.05)&(Snow<0.025))
    S = (Snow>0.025)
    flag = np.zeros_like(Final_Data[-1])
    rate = np.zeros_like(Final_Data[-1])
    # R and S are mutually exclusive (Snow < 0.025 vs Snow > 0.025).
    flag[S]=2
    flag[R]=1
    rate[S]=Snow[S]
    rate[R]=Rain[R]

    return np.concatenate((Final_Data[:-2],np.expand_dims(rate, axis=0),np.expand_dims(flag, axis=0)), axis=0)

def Expand(Data):
    """Append per-class rate channels to a channel-first array.

    The last channel of ``Data`` is read as a flag and the second-to-last as
    a rate. Two channels are appended: the rate where the flag equals 1
    (rain) and the rate where the flag equals 2 (snow); both are zero
    elsewhere.

    Args:
        Data: array of shape ``(C, H, W)`` or ``(N, C, H, W)``.

    Returns:
        ``Data`` with two extra channels, passed through ``squeeze()`` so a
        single sample comes back as ``(C + 2, H, W)``.
    """
    batch = Data if Data.ndim >= 4 else np.expand_dims(Data, axis=0)

    flag_channel = batch[:, -1]
    rate_channel = batch[:, -2]

    # float64 zeros, filled for all samples at once instead of per sample.
    per_class = np.zeros((batch.shape[0], 2, batch.shape[-2], batch.shape[-1]))
    per_class[:, 0] = np.where(flag_channel == 1, rate_channel, 0.0)
    per_class[:, 1] = np.where(flag_channel == 2, rate_channel, 0.0)

    return np.concatenate((batch, per_class), axis=1).squeeze()


def RerievalDL(model,X,mean,std):
    """Run the network on patches and return predicted flag and rate maps.

    Args:
        model: network returning ``(seg_logits, reg_map)``; it is switched to
            eval mode here.
        X: un-normalised input of shape ``(C, H, W)`` or ``(N, C, H, W)``.
        mean, std: normalisation statistics broadcastable against ``X``.

    Returns:
        float64 array of shape ``(N, 2, H, W)``: channel 0 is the arg-max
        class, channel 1 the predicted rate ``exp(log_rate) - eps`` clamped
        at zero, and set to zero wherever the predicted class is not 1.
    """
    if X.ndim < 4:
        X = np.expand_dims(X, axis=0)

    X = (X - mean) / (std + 1e-6)

    eps = 1e-4  # must match what you used in training
    model.eval()

    # Run on whatever device the model lives on (CPU if it has no parameters).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    # Output maps follow the patch size of the input instead of a fixed 16x16.
    n_patches, height, width = X.shape[0], X.shape[-2], X.shape[-1]
    Flag = np.zeros((n_patches, height, width))
    Rain_rate = np.zeros((n_patches, height, width))

    with torch.no_grad():
        for i in range(n_patches):
            sample = torch.from_numpy(X[i]).float().unsqueeze(dim=0).to(device)
            seg_logits, reg_map = model(sample)

            seg_preds = seg_logits.argmax(dim=1)
            log_pred = reg_map.squeeze(1)
            rate_pred = torch.clamp(torch.exp(log_pred) - eps, min=0.0)
            # The network zeroes its log-rate output outside class 1, and
            # exp(0) - eps is ~1, not 0. Force those pixels to a zero rate.
            rate_pred = torch.where(seg_preds == 1, rate_pred, torch.zeros_like(rate_pred))

            Flag[i] = seg_preds[0].cpu().numpy()
            Rain_rate[i] = rate_pred[0].cpu().numpy()

    return np.concatenate((np.expand_dims(Flag, axis=1),np.expand_dims(Rain_rate, axis=1)), axis=1)


import numpy as np

def extract_patches_with_padding(data, patch_size=16):
    """Cut a ``(C, H, W)`` array into non-overlapping square patches.

    Height and width are zero-padded at the bottom / right up to the next
    multiple of ``patch_size``, so every patch has the same shape.

    Args:
        data: array of shape ``(C, H, W)``.
        patch_size: side length of each square patch.

    Returns:
        Array of shape ``(n_patches, C, patch_size, patch_size)``, ordered
        row by row (all patches of the first patch-row, then the second, ...).
    """
    C, H, W = data.shape
    pad_h = -H % patch_size
    # Width is padded as well: without it, a width that is not a multiple of
    # patch_size produced ragged patches that could not be stacked.
    pad_w = -W % patch_size
    padded = np.pad(data, ((0, 0), (0, pad_h), (0, pad_w)), mode='constant')

    rows = padded.shape[1] // patch_size
    cols = padded.shape[2] // patch_size

    # Split both spatial axes into (patch index, offset inside the patch) and
    # move the two patch indices to the front; the final reshape enumerates
    # them row-major.
    tiles = padded.reshape(C, rows, patch_size, cols, patch_size)
    tiles = tiles.transpose(1, 3, 0, 2, 4)
    return tiles.reshape(rows * cols, C, patch_size, patch_size)

def reconstruct_from_patches(patches, original_height=2236, width=176, patch_size=16):
    """Reassemble row-major patches into a ``(C, original_height, width)`` array.

    Args:
        patches: array of shape ``(n_patches, C, patch_size, patch_size)``,
            ordered row by row.
        original_height: height of the un-padded image.
        width: width of the un-padded image.
        patch_size: side length of each square patch.

    Returns:
        Array of shape ``(C, original_height, width)`` with the dtype of
        ``patches``; any padding rows / columns are cropped away.

    Raises:
        IndexError: if ``patches`` holds fewer patches than the grid needs.
    """
    C = patches.shape[1]
    # Ceil division in both directions: a partially filled last patch-row or
    # patch-column still occupies a full patch in the padded canvas.
    h_patches = (original_height + patch_size - 1) // patch_size
    w_patches = (width + patch_size - 1) // patch_size

    full_height = h_patches * patch_size
    full_width = w_patches * patch_size
    reconstructed = np.zeros((C, full_height, full_width), dtype=patches.dtype)

    idx = 0
    for top in range(0, full_height, patch_size):
        for left in range(0, full_width, patch_size):
            reconstructed[:, top:top+patch_size, left:left+patch_size] = patches[idx]
            idx += 1

    return reconstructed[:, :original_height, :width]



# ---------------------------------------------------------------------------
# Inference on one orbit file: read -> patch -> predict -> stitch back.
# ---------------------------------------------------------------------------
Filename = '1C.NPP.ATMS.XCAL2019-V.20230712-S071701-E085830.060650.V07A.ERA5.HDF5'


model = UNetWithTransformerSkips(in_channels=38, n_classes=2)
model.load_state_dict(torch.load("model_weights_R22.pth", map_location=torch.device('cpu')))
std = np.load('std.npy')
mean = np.load('mean.npy')

# Keep only the 38 input channels; the trailing channels are targets.
orbit_features = Expand(CollectingData(Filename))[:38]

# Use the orbit's real size when stitching the patches back together instead
# of the defaults (2236 x 176), which only fit one particular orbit length.
_, orbit_height, orbit_width = orbit_features.shape

orbit_patches = extract_patches_with_padding(orbit_features)
orbit_prediction = RerievalDL(model, orbit_patches, mean, std)

# Channel 0 of the prediction is the class flag, channel 1 the rate.
Flag, Rate = reconstruct_from_patches(
    orbit_prediction,
    original_height=orbit_height,
    width=orbit_width,
)

