In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
# %matplotlib widget
plt.style.use('classic')
import warnings

import seaborn as sns
sns.set_theme(style="whitegrid")

import torch
from torch.utils.data import Dataset, DataLoader, Subset
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets
from torchvision.transforms import ToTensor, transforms
from torchsummary import summary
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau


import netCDF4

import scipy.io
from scipy.signal import convolve2d
from sklearn.model_selection import train_test_split

from scipy.special import kl_div
import os
import time
import datetime
from scipy.interpolate import griddata
import random

from pytorch_msssim import SSIM, ms_ssim, ssim, MS_SSIM
import gc

print("Ready to Go!")

Ready to Go!


In [2]:
# Use the Apple-silicon GPU (MPS) when available, otherwise fall back to CPU so
# the notebook still runs on machines without MPS support.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
device

device(type='mps')

# Initial U-Net Model

In [4]:
# Base Model 3d

# Use MPS when available, otherwise fall back to CPU so that `device_3d`
# is always defined (previously it was left unset when MPS was missing).
device_3d = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

class UNet_3d(nn.Module):
    """Four-level U-Net mapping a multi-channel image to a per-pixel output map.

    Every level is a "double conv" block (Conv3x3 -> BatchNorm -> ReLU, twice,
    then Dropout). The encoder halves the spatial size four times with 2x2
    max-pooling; the decoder restores it with 2x2 transposed convolutions and
    concatenates the matching encoder features (skip connections) at each level.

    Args:
        in_channels: Number of input channels (default 3).
        out_channels: Number of output channels (default 1).
        base_channels: Width of the first level; level k uses
            ``base_channels * 2**k`` channels (default 64 -> 64..1024).
        dropout: Dropout probability at the end of every double conv (default 0.2).

    Input height and width must be divisible by 16 (four 2x pooling steps);
    otherwise the skip-connection concatenations fail with a size mismatch.
    """

    def __init__(self, in_channels=3, out_channels=1, base_channels=64, dropout=0.2):
        super().__init__()
        w1, w2, w3, w4, w5 = (base_channels * 2 ** k for k in range(5))

        # Contracting path
        self.enc_conv1 = self._double_conv(in_channels, w1, dropout)
        self.pool1 = nn.MaxPool2d(2)
        self.enc_conv2 = self._double_conv(w1, w2, dropout)
        self.pool2 = nn.MaxPool2d(2)
        self.enc_conv3 = self._double_conv(w2, w3, dropout)
        self.pool3 = nn.MaxPool2d(2)
        self.enc_conv4 = self._double_conv(w3, w4, dropout)
        self.pool4 = nn.MaxPool2d(2)

        # Bottom layer
        self.bottleneck = self._double_conv(w4, w5, dropout)

        # Expanding path: each decoder block receives upsampled + skip features,
        # so its input width is twice its output width.
        self.up1 = nn.ConvTranspose2d(w5, w4, kernel_size=2, stride=2)
        self.dec_conv1 = self._double_conv(2 * w4, w4, dropout)
        self.up2 = nn.ConvTranspose2d(w4, w3, kernel_size=2, stride=2)
        self.dec_conv2 = self._double_conv(2 * w3, w3, dropout)
        self.up3 = nn.ConvTranspose2d(w3, w2, kernel_size=2, stride=2)
        self.dec_conv3 = self._double_conv(2 * w2, w2, dropout)
        self.up4 = nn.ConvTranspose2d(w2, w1, kernel_size=2, stride=2)
        self.dec_conv4 = self._double_conv(2 * w1, w1, dropout)

        # Output layer (1x1 conv to the requested number of channels)
        self.output = nn.Conv2d(w1, out_channels, kernel_size=1)

    @staticmethod
    def _double_conv(in_ch, out_ch, dropout):
        """(Conv3x3 -> BatchNorm -> ReLU) x 2 -> Dropout; keeps H and W unchanged."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
        )

    def forward(self, x):
        """Map ``x`` of shape (B, in_channels, H, W) to (B, out_channels, H, W)."""
        # Contracting path
        c1 = self.enc_conv1(x)
        c2 = self.enc_conv2(self.pool1(c1))
        c3 = self.enc_conv3(self.pool2(c2))
        c4 = self.enc_conv4(self.pool3(c3))

        # Bottom layer
        b = self.bottleneck(self.pool4(c4))

        # Expanding path with skip connections
        dc1 = self.dec_conv1(torch.cat((self.up1(b), c4), dim=1))
        dc2 = self.dec_conv2(torch.cat((self.up2(dc1), c3), dim=1))
        dc3 = self.dec_conv3(torch.cat((self.up3(dc2), c2), dim=1))
        dc4 = self.dec_conv4(torch.cat((self.up4(dc3), c1), dim=1))

        # Output layer
        return self.output(dc4)


# Create the model and print its summary for a 3-channel 32x32 input
model_unet_3d = UNet_3d().to(device_3d)
_ = summary(model_unet_3d, (3, 32, 32))

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 64, 32, 32]          --
|    └─Conv2d: 2-1                       [-1, 64, 32, 32]          1,792
|    └─BatchNorm2d: 2-2                  [-1, 64, 32, 32]          128
|    └─ReLU: 2-3                         [-1, 64, 32, 32]          --
|    └─Conv2d: 2-4                       [-1, 64, 32, 32]          36,928
|    └─BatchNorm2d: 2-5                  [-1, 64, 32, 32]          128
|    └─ReLU: 2-6                         [-1, 64, 32, 32]          --
|    └─Dropout: 2-7                      [-1, 64, 32, 32]          --
├─MaxPool2d: 1-2                         [-1, 64, 16, 16]          --
├─Sequential: 1-3                        [-1, 128, 16, 16]         --
|    └─Conv2d: 2-8                       [-1, 128, 16, 16]         73,856
|    └─BatchNorm2d: 2-9                  [-1, 128, 16, 16]         256
|    └─ReLU: 2-10                        [-1, 128, 16, 16]         --
|

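# U-Net Model Ext (With Direct Equation)

The U-Net prediction $I_p$ is blended with the centre input channel $I_n$ using the factor $f = 0.9$: $I_{out} = I_n - f\,(I_n - I_p) = (1 - f)\,I_n + f\,I_p$.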

In [5]:
# Model 3D Extended

# Use MPS when available, otherwise fall back to CPU so that `device_3d_ext`
# is always defined (previously it was left unset when MPS was missing).
device_3d_ext = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    

class UNet_3d_ext(nn.Module):
    """U-Net whose single-channel prediction is blended with the centre input channel.

    The backbone is a four-level U-Net: double-conv blocks (Conv3x3 -> BatchNorm
    -> ReLU, twice, then Dropout), 2x2 max-pooling, 2x2 transposed-conv
    upsampling and skip connections. Its prediction ``I_p`` is combined with the
    centre input channel ``I_n`` as::

        out = I_n - (I_n - I_p) * f        # == (1 - f) * I_n + f * I_p

    Args:
        in_channels: Number of input channels (default 3). The centre channel is
            index ``in_channels // 2`` (index 1 for the default).
        base_channels: Width of the first level; level k uses
            ``base_channels * 2**k`` channels (default 64 -> 64..1024).
        dropout: Dropout probability at the end of every double conv (default 0.2).
        f: Weight given to the U-Net prediction in the blend (default 0.9).

    ``forward`` returns a tensor of shape (B, 1, H, W). H and W must be
    divisible by 16 so the skip-connection sizes line up.
    """

    def __init__(self, in_channels=3, base_channels=64, dropout=0.2, f=0.9):
        super().__init__()
        self.in_channels = in_channels
        self.f = f
        w1, w2, w3, w4, w5 = (base_channels * 2 ** k for k in range(5))

        # Contracting path
        self.enc_conv1 = self._double_conv(in_channels, w1, dropout)
        self.pool1 = nn.MaxPool2d(2)
        self.enc_conv2 = self._double_conv(w1, w2, dropout)
        self.pool2 = nn.MaxPool2d(2)
        self.enc_conv3 = self._double_conv(w2, w3, dropout)
        self.pool3 = nn.MaxPool2d(2)
        self.enc_conv4 = self._double_conv(w3, w4, dropout)
        self.pool4 = nn.MaxPool2d(2)

        # Bottom layer
        self.bottleneck = self._double_conv(w4, w5, dropout)

        # Expanding path: decoder inputs are upsampled + skip features (2x width)
        self.up1 = nn.ConvTranspose2d(w5, w4, kernel_size=2, stride=2)
        self.dec_conv1 = self._double_conv(2 * w4, w4, dropout)
        self.up2 = nn.ConvTranspose2d(w4, w3, kernel_size=2, stride=2)
        self.dec_conv2 = self._double_conv(2 * w3, w3, dropout)
        self.up3 = nn.ConvTranspose2d(w3, w2, kernel_size=2, stride=2)
        self.dec_conv3 = self._double_conv(2 * w2, w2, dropout)
        self.up4 = nn.ConvTranspose2d(w2, w1, kernel_size=2, stride=2)
        self.dec_conv4 = self._double_conv(2 * w1, w1, dropout)

        # Output layer: single-channel U-Net prediction I_p
        self.output = nn.Conv2d(w1, 1, kernel_size=1)

    @staticmethod
    def _double_conv(in_ch, out_ch, dropout):
        """(Conv3x3 -> BatchNorm -> ReLU) x 2 -> Dropout; keeps H and W unchanged."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
        )

    def _unet_prediction(self, x):
        """Run the encoder/decoder and return the single-channel map I_p (B, 1, H, W)."""
        c1 = self.enc_conv1(x)
        c2 = self.enc_conv2(self.pool1(c1))
        c3 = self.enc_conv3(self.pool2(c2))
        c4 = self.enc_conv4(self.pool3(c3))
        b = self.bottleneck(self.pool4(c4))

        dc1 = self.dec_conv1(torch.cat((self.up1(b), c4), dim=1))
        dc2 = self.dec_conv2(torch.cat((self.up2(dc1), c3), dim=1))
        dc3 = self.dec_conv3(torch.cat((self.up3(dc2), c2), dim=1))
        dc4 = self.dec_conv4(torch.cat((self.up4(dc3), c1), dim=1))
        return self.output(dc4)

    def forward(self, x):
        """Blend the U-Net prediction with the centre channel of ``x`` (B, C, H, W)."""
        I_p = self._unet_prediction(x)

        # Slicing keeps the channel dim, giving (B, 1, H, W) to match I_p
        center = self.in_channels // 2
        I_n_center = x[:, center:center + 1, :, :]
        return I_n_center - (I_n_center - I_p) * self.f


# Create the extended (direct-equation) model and print its summary for a
# 3-channel 32x32 input
model_unet_equation = UNet_3d_ext().to(device_3d_ext)
_ = summary(model_unet_equation, (3, 32, 32))


Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 64, 32, 32]          --
|    └─Conv2d: 2-1                       [-1, 64, 32, 32]          1,792
|    └─BatchNorm2d: 2-2                  [-1, 64, 32, 32]          128
|    └─ReLU: 2-3                         [-1, 64, 32, 32]          --
|    └─Conv2d: 2-4                       [-1, 64, 32, 32]          36,928
|    └─BatchNorm2d: 2-5                  [-1, 64, 32, 32]          128
|    └─ReLU: 2-6                         [-1, 64, 32, 32]          --
|    └─Dropout: 2-7                      [-1, 64, 32, 32]          --
├─MaxPool2d: 1-2                         [-1, 64, 16, 16]          --
├─Sequential: 1-3                        [-1, 128, 16, 16]         --
|    └─Conv2d: 2-8                       [-1, 128, 16, 16]         73,856
|    └─BatchNorm2d: 2-9                  [-1, 128, 16, 16]         256
|    └─ReLU: 2-10                        [-1, 128, 16, 16]         --
|

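# U-Net Model Ext (With Blending Layer)

Instead of a fixed equation, the U-Net prediction $I_p$ and the centre input channel $I_n$ are flattened, concatenated and passed through a small MLP that outputs the final $32 \times 32$ image.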

In [6]:
# Use MPS when available, otherwise fall back to CPU so that `device_3d_extMLP`
# is always defined (previously it was left unset when MPS was missing).
device_3d_extMLP = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

class UNet_3d_ext_with_MLP(nn.Module):
    """U-Net whose prediction is fused with the centre input channel by an MLP.

    The backbone is a four-level U-Net: double-conv blocks (Conv3x3 -> BatchNorm
    -> ReLU, twice, then Dropout), 2x2 max-pooling, 2x2 transposed-conv
    upsampling and skip connections. Its single-channel prediction ``I_p`` and
    the centre input channel ``I_n`` are flattened, concatenated and passed
    through ``Linear(2*H*W, hidden_dim) -> ReLU -> Linear(hidden_dim, H*W)``;
    the result is reshaped back to (B, 1, H, W).

    Because the MLP operates on flattened pixels, the model only accepts inputs
    of exactly ``img_size`` spatial size.

    Args:
        in_channels: Number of input channels (default 3). The centre channel is
            index ``in_channels // 2`` (index 1 for the default).
        base_channels: Width of the first level; level k uses
            ``base_channels * 2**k`` channels (default 64 -> 64..1024).
        dropout: Dropout probability at the end of every double conv (default 0.2).
        img_size: Input spatial size, an int (square) or an ``(H, W)`` pair
            (default 32). H and W must be divisible by 16.
        hidden_dim: Hidden width of the blending MLP (default 1024).
    """

    def __init__(self, in_channels=3, base_channels=64, dropout=0.2, img_size=32, hidden_dim=1024):
        super().__init__()
        h, w = (img_size, img_size) if isinstance(img_size, int) else img_size
        self.in_channels = in_channels
        self.img_size = (int(h), int(w))
        w1, w2, w3, w4, w5 = (base_channels * 2 ** k for k in range(5))

        # Contracting path
        self.enc_conv1 = self._double_conv(in_channels, w1, dropout)
        self.pool1 = nn.MaxPool2d(2)
        self.enc_conv2 = self._double_conv(w1, w2, dropout)
        self.pool2 = nn.MaxPool2d(2)
        self.enc_conv3 = self._double_conv(w2, w3, dropout)
        self.pool3 = nn.MaxPool2d(2)
        self.enc_conv4 = self._double_conv(w3, w4, dropout)
        self.pool4 = nn.MaxPool2d(2)

        # Bottom layer
        self.bottleneck = self._double_conv(w4, w5, dropout)

        # Expanding path: decoder inputs are upsampled + skip features (2x width)
        self.up1 = nn.ConvTranspose2d(w5, w4, kernel_size=2, stride=2)
        self.dec_conv1 = self._double_conv(2 * w4, w4, dropout)
        self.up2 = nn.ConvTranspose2d(w4, w3, kernel_size=2, stride=2)
        self.dec_conv2 = self._double_conv(2 * w3, w3, dropout)
        self.up3 = nn.ConvTranspose2d(w3, w2, kernel_size=2, stride=2)
        self.dec_conv3 = self._double_conv(2 * w2, w2, dropout)
        self.up4 = nn.ConvTranspose2d(w2, w1, kernel_size=2, stride=2)
        self.dec_conv4 = self._double_conv(2 * w1, w1, dropout)

        # Output layer (from U-Net): single-channel prediction I_p
        self.output = nn.Conv2d(w1, 1, kernel_size=1)

        # Blending MLP: input is [I_p, I_n_center] flattened, output is one image
        n_pixels = self.img_size[0] * self.img_size[1]
        self.blending_mlp = nn.Sequential(
            nn.Linear(2 * n_pixels, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_pixels),
        )

    @staticmethod
    def _double_conv(in_ch, out_ch, dropout):
        """(Conv3x3 -> BatchNorm -> ReLU) x 2 -> Dropout; keeps H and W unchanged."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
        )

    def _unet_prediction(self, x):
        """Run the encoder/decoder and return the single-channel map I_p (B, 1, H, W)."""
        c1 = self.enc_conv1(x)
        c2 = self.enc_conv2(self.pool1(c1))
        c3 = self.enc_conv3(self.pool2(c2))
        c4 = self.enc_conv4(self.pool3(c3))
        b = self.bottleneck(self.pool4(c4))

        dc1 = self.dec_conv1(torch.cat((self.up1(b), c4), dim=1))
        dc2 = self.dec_conv2(torch.cat((self.up2(dc1), c3), dim=1))
        dc3 = self.dec_conv3(torch.cat((self.up3(dc2), c2), dim=1))
        dc4 = self.dec_conv4(torch.cat((self.up4(dc3), c1), dim=1))
        return self.output(dc4)

    def forward(self, x):
        """Fuse the U-Net prediction with the centre channel of ``x`` via the MLP.

        Raises:
            ValueError: if the spatial size of ``x`` differs from ``img_size``
                (the MLP's input/output widths are fixed at construction).
        """
        if tuple(x.shape[-2:]) != self.img_size:
            raise ValueError(
                f"expected input spatial size {self.img_size}, got {tuple(x.shape[-2:])}"
            )
        batch = x.size(0)
        I_p = self._unet_prediction(x)

        # Slicing keeps the channel dim, giving (B, 1, H, W) to match I_p
        center = self.in_channels // 2
        I_n_center = x[:, center:center + 1, :, :]

        # torch.flatten (unlike .view) also works on non-contiguous tensors
        concat = torch.cat((torch.flatten(I_p, 1), torch.flatten(I_n_center, 1)), dim=1)
        blended_output = self.blending_mlp(concat)  # (B, H*W)

        h, w = self.img_size
        return blended_output.view(batch, 1, h, w)


# Create the extended (MLP-blending) model and print its summary for a
# 3-channel 32x32 input (the blending MLP is sized for 32x32 images)
UNet_3d_ext_MLP = UNet_3d_ext_with_MLP().to(device_3d_extMLP)
_ = summary(UNet_3d_ext_MLP, (3, 32, 32))

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 64, 32, 32]          --
|    └─Conv2d: 2-1                       [-1, 64, 32, 32]          1,792
|    └─BatchNorm2d: 2-2                  [-1, 64, 32, 32]          128
|    └─ReLU: 2-3                         [-1, 64, 32, 32]          --
|    └─Conv2d: 2-4                       [-1, 64, 32, 32]          36,928
|    └─BatchNorm2d: 2-5                  [-1, 64, 32, 32]          128
|    └─ReLU: 2-6                         [-1, 64, 32, 32]          --
|    └─Dropout: 2-7                      [-1, 64, 32, 32]          --
├─MaxPool2d: 1-2                         [-1, 64, 16, 16]          --
├─Sequential: 1-3                        [-1, 128, 16, 16]         --
|    └─Conv2d: 2-8                       [-1, 128, 16, 16]         73,856
|    └─BatchNorm2d: 2-9                  [-1, 128, 16, 16]         256
|    └─ReLU: 2-10                        [-1, 128, 16, 16]         --
|