In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Normalize

# Pick the compute backend in order of preference: CUDA, then MPS, then CPU.
# `torch.backends.mps` is looked up defensively because PyTorch builds that
# predate the MPS backend do not expose that submodule at all.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print("Using device:", device)

In [None]:
def _scale_channel(channel):
    """Min-max scale one channel to [-1, 1] and return (scaled, max, min).

    The max and min are taken over the whole tensor passed in. A constant
    channel (max == min) has no range to divide by; instead of producing NaN
    it is mapped to all zeros, which `denormalize` maps back to the constant.
    """
    c_max, c_min = channel.max(), channel.min()
    span = c_max - c_min
    # Replace a zero span by 1 so the division below is always well-defined.
    safe_span = torch.where(span == 0, torch.ones_like(span), span)
    scaled = 2 * (channel - c_min) / safe_span - 1
    # For a constant channel the line above yields -1 everywhere; use the
    # midpoint 0 instead.
    scaled = torch.where(span == 0, torch.zeros_like(scaled), scaled)
    return scaled, c_max, c_min


def normalize(x):
    """Scale the first three channels of `x` independently to the range [-1, 1].

    `x` is indexed as x[:, c, :, :], so it must be 4-D with at least three
    entries along axis 1. Channel 0 is treated as eta, 1 as u and 2 as v.
    Each channel uses a single global min/max computed over all samples and
    grid points.

    Returns:
        x_norm: tensor with the three scaled channels stacked on axis 1
            in the order (eta, u, v).
        stats: list [u_max, u_min, v_max, v_min, eta_max, eta_min] -- note
            the order, which is the one `denormalize` unpacks.
    """
    eta_norm, eta_max, eta_min = _scale_channel(x[:, 0, :, :])
    u_norm, u_max, u_min = _scale_channel(x[:, 1, :, :])
    v_norm, v_max, v_min = _scale_channel(x[:, 2, :, :])

    # Re-insert the channel axis that the per-channel slicing removed.
    x_norm = torch.stack((eta_norm, u_norm, v_norm), dim=1)

    return x_norm, [u_max, u_min, v_max, v_min, eta_max, eta_min]

def denormalize(x_norm, max_min_vals):
    """Map channels scaled to [-1, 1] back to their original value ranges.

    `x_norm` is indexed as x_norm[:, c, :, :] with channel 0 = eta, 1 = u,
    2 = v. `max_min_vals` must be ordered
    [u_max, u_min, v_max, v_min, eta_max, eta_min], i.e. the second value
    returned by `normalize`.

    Returns a tensor with the restored channels concatenated on axis 1 in the
    order (eta, u, v).
    """
    u_max, u_min, v_max, v_min, eta_max, eta_min = max_min_vals

    def rescale(channel, hi, lo):
        # Inverse of 2 * (x - lo) / (hi - lo) - 1
        return ((channel + 1) / 2) * (hi - lo) + lo

    restored = (
        rescale(x_norm[:, 0, :, :], eta_max, eta_min),
        rescale(x_norm[:, 1, :, :], u_max, u_min),
        rescale(x_norm[:, 2, :, :], v_max, v_min),
    )

    # Slicing dropped the channel axis; add it back before concatenating.
    return torch.cat([channel.unsqueeze(1) for channel in restored], dim=1)
    
class dataset(Dataset):
    """Pairs each entry of `data` with the entry of `ic` at the same index.

    Both containers are indexed as container[idx, :, :, :], so they must be
    4-D. The reported length is the caller-supplied `length`; it is not
    checked against the size of either container.
    """

    def __init__(self, data, ic, length):
        self.data, self.ic, self.len = data, ic, length

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        # Returned as (data sample, matching initial condition).
        return self.data[idx, :, :, :], self.ic[idx, :, :, :]

def training_loop(dataLoader, model, device, optimizer, loss_fn, epochs, path, batchSize, Mse_loss_fn):
    """Train `model` for `epochs` passes over `dataLoader`, then save its weights.

    Args:
        dataLoader: iterable yielding (data, ic) tensor pairs.
        model: module whose output for a batch can be viewed as (B, 3, 64, 64).
        device: device the batches are moved to before the forward pass.
        optimizer: optimizer already bound to `model`'s parameters.
        loss_fn: callable invoked as loss_fn(pred, ic, Mse_loss_fn).
        epochs: number of passes over the loader.
        path: destination passed to torch.save for model.state_dict().
        batchSize: kept for interface compatibility; the actual size of each
            batch is used instead so a smaller final batch does not break
            the reshape.
        Mse_loss_fn: forwarded unchanged to `loss_fn`.

    After each epoch the loss of the last processed batch is printed.
    """
    for epoch in range(epochs):
        loss = None
        for data, ic in dataLoader:
            data = data.to(device)
            ic = ic.to(device)
            optimizer.zero_grad()
            # Reshape with the batch's own size: the last batch is smaller
            # whenever the dataset length is not a multiple of the batch size.
            pred = model(data).view(data.shape[0], 3, 64, 64)
            loss = loss_fn(pred, ic, Mse_loss_fn)
            loss.backward()
            optimizer.step()
        # Nothing to report if the loader produced no batches this epoch.
        if loss is not None:
            print(f'epoch {epoch} : loss {loss.item()}')
    torch.save(model.state_dict(), path)

def loss_function(output, target, Mse_loss_fn):
    """Average of the per-channel losses between `target` and `output`.

    Both tensors are indexed as t[:, c, :, :] (the training loop passes
    [batchSize, 3, 64, 64]). `Mse_loss_fn` is called once per channel as
    Mse_loss_fn(target_channel, output_channel); it is intended to be an
    nn.MSELoss instance. The three results are stacked and averaged, so every
    channel contributes with equal weight.

    An L1 penalty on u and v (to encourage sparsity) was tried previously and
    is currently disabled; all three channels use `Mse_loss_fn`.
    """
    per_channel = [
        Mse_loss_fn(target[:, channel, :, :], output[:, channel, :, :])
        for channel in range(3)  # 0: eta, 1: u, 2: v
    ]
    return torch.stack(per_channel).mean()

def plot_figs(idx, X1, X2):
    """Plot sample `idx` of X1 (top row, "Actual") and X2 (bottom row, "Predicted").

    Both inputs are indexed as X[idx, c, :, :] with c = 0, 1, 2 shown as
    eta, u and v. Each field is drawn as a 3D surface over a 64x64 grid on
    the unit square, so each slice is expected to be 64x64. The u and v
    panels get a fixed z-range of +/-1e-4; eta is auto-scaled.
    """
    # Coordinates shared by all six surfaces.
    axis_pts = np.linspace(0, 1, 64)
    X, Y = np.meshgrid(axis_pts, axis_pts)

    # Order: eta, u, v of X1 followed by eta, u, v of X2 (moved to CPU for plotting).
    surfaces = [source[idx, ch, :, :].cpu() for source in (X1, X2) for ch in range(3)]
    titles = [
        "Actual η", "Actual u", "Actual v",
        "Predicted η", "Predicted u", "Predicted v"
    ]
    cmaps = ['viridis', 'viridis', 'plasma'] * 2

    fig, axes = plt.subplots(2, 3, figsize=(18, 8), subplot_kw={'projection': '3d'})

    for i, (ax, Z, title, cmap) in enumerate(zip(axes.flat, surfaces, titles, cmaps)):
        ax.plot_surface(X, Y, Z, cmap=cmap)
        if i % 3 != 0:
            # columns 1 and 2 are the velocity panels (u, v)
            ax.set_zlim(-1e-4, 1e-4)
        ax.set_title(title, fontsize=18, pad=10)
        ax.set_xlabel('x', fontsize=12, labelpad=5)
        ax.set_ylabel('y', fontsize=12, labelpad=5)
        ax.tick_params(axis='both', labelsize=10)

    # Spacing is tuned by hand; tight_layout has limited effect on 3D axes.
    plt.subplots_adjust(hspace=0.3, wspace=0.1)

    plt.show()

In [None]:
class InverseCNNUpsampled(nn.Module):
    """Convolutional encoder-decoder that maps an image to a 3-channel image.

    Encoder: three Conv(3x3) -> MaxPool(2) -> LeakyReLU stages
    (in_channels -> 15 -> 30 -> 60 channels, each halving the spatial size),
    then a Conv(3x3) to 128 channels.
    Decoder: Conv(3x3) back to 60 channels, then three Upsample(x2) stages
    (60 -> 30 -> 15 -> 3 channels); the last stage uses ConvTranspose2d and a
    Tanh, so outputs lie in [-1, 1].

    For a 64x64 input the bottleneck is 128x8x8 and the output is 3x64x64.
    All layers live in one nn.Sequential named `net`, in a fixed order, so
    state_dict keys are `net.<index>.*`.
    """

    def __init__(self, in_channels):
        super().__init__()
        layers = []

        # ---- encoder: each stage halves the spatial resolution ----
        widths = (in_channels, 15, 30, 60)
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
                nn.LeakyReLU(),
            ]
        # bottleneck: widen to 128 channels without changing resolution
        layers += [
            nn.Conv2d(60, 128, kernel_size=3, padding=1),
            nn.LeakyReLU(),
        ]

        # ---- decoder: each Upsample doubles the spatial resolution ----
        layers += [
            nn.Conv2d(in_channels=128, out_channels=60, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(),
        ]
        for c_in, c_out in ((60, 30), (30, 15)):
            layers += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, stride=1, padding=1),
                nn.LeakyReLU(),
            ]
        # output head: back to full resolution, 3 channels squashed by tanh
        layers += [
            nn.Upsample(scale_factor=2),
            nn.ConvTranspose2d(in_channels=15, out_channels=3, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        ]

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

In [None]:
# The per-initial-condition files produced by createSWEdata.py are merged into
# two combined tensors (observed data and initial conditions) and cached on disk.
# The merge runs only when a cached file is missing; afterwards the cell just
# loads the cached files. Delete the cached files to force a rebuild.

ic_names = ["Gaussian Bump", "2 Gaussian Bumps", "Sinusoidal Wave Pattern", "Flat Conditions"]
batches = 2500  # samples taken from each initial-condition family
num_dataPoints = len(ic_names) * batches
case = 64  # not used in this cell; kept in case other cells reference it
data_path = "new_data/training_data_10000.pt"
ic_path = "new_data/ic_training_data_10000.pt"

# --- creating single files of data ---
if not (os.path.exists(data_path) and os.path.exists(ic_path)):
    # each datapoint has shape 3x64x64 (channels eta, u, v over a 64x64 grid)
    ic_set = torch.zeros(size=(num_dataPoints, 3, 64, 64))
    data_set = torch.zeros(size=(num_dataPoints, 3, 64, 64))

    for i, ic_name in enumerate(ic_names, start=1):
        x = torch.load(f'new_data/train_x1_{ic_name}_allBatch_Close.pt')[:batches, :, :, :]
        print(x.shape)
        ic = torch.load(f'new_data/train_ic_{ic_name}_allBatch.pt')[:batches, :, :, :]
        print(ic.shape)
        # family number i occupies rows [batches*(i-1), batches*i)
        data_set[batches*(i-1): batches*i] = x
        ic_set[batches*(i-1): batches*i] = ic

    # save full datasets
    torch.save(data_set, data_path)
    torch.save(ic_set, ic_path)

# --- loading the data into variables ---
data_set = torch.load(data_path).clone()
ic_set = torch.load(ic_path).clone()

In [None]:
batchSize = 125
# --- normalization of each channel using min and max, with output range [-1,1]
# The min/max statistics are kept so predictions can be mapped back later.
dataset_normalized, min_max_vals = normalize(data_set)
ic_normalized, min_max_vals_ic = normalize(ic_set)

# Take the sample count from the data itself so it cannot drift from the
# tensors that were actually loaded.
size = dataset_normalized.shape[0]

DataSet = dataset(dataset_normalized, ic_normalized, size)
dataLoader = DataLoader(DataSet, shuffle=True, batch_size=batchSize, num_workers=0)
loss_fn = nn.MSELoss(reduction='none')

model_inverse = InverseCNNUpsampled(in_channels=3)
model_inverse = model_inverse.to(device)

# If a trained model has already been stored, uncomment the two lines below to
# load it and skip training. Point the file name at the checkpoint you saved;
# map_location=device loads the weights onto whichever backend is in use.
#
# state_dict = torch.load("inv_model_upsampled.pth", map_location=device)
# model_inverse.load_state_dict(state_dict)

optimizer_inverse = torch.optim.Adam(model_inverse.parameters(), lr=1e-3)


In [None]:
# --- training loop ---
# Trains model_inverse with the channel-averaged MSE defined in loss_function
# and writes the resulting weights to `path`.

epochs = 15
path = 'inv_model.pth' # define path to store model states
Mse_loss_fn = nn.MSELoss()

training_loop(
    dataLoader=dataLoader,
    model=model_inverse,
    device=device,
    optimizer=optimizer_inverse,
    loss_fn=loss_function,
    epochs=epochs,
    path=path,
    batchSize=batchSize,
    Mse_loss_fn=Mse_loss_fn,
)

In [None]:
# --- compute reconstructions and errors
# The loader is not shuffled, so row k of every buffer below refers to the same
# sample k of DataSet.

test_dataLoader = DataLoader(DataSet, shuffle=False, batch_size=batchSize)
model = model_inverse
model.eval()
ics = torch.zeros(size=(size, 3, 64, 64)) # initial conditions, in the same order as the reconstructions
reconstructions = torch.zeros(size=(size, 3, 64, 64))
data_stored = torch.zeros(size=(size, 3, 64, 64)) # model inputs, in the same order as the reconstructions
i = 0
with torch.no_grad():
    for data, ic in test_dataLoader:
        data, ic = data.to(device), ic.to(device)
        # Use the real batch length so a smaller final batch is handled too.
        n_batch = data.shape[0]
        pred = model(data).view(n_batch, 3, 64, 64)
        reconstructions[i:i + n_batch, :, :, :] = pred
        ics[i:i + n_batch, :, :, :] = ic
        data_stored[i:i + n_batch, :, :, :] = data
        i += n_batch

# Predictions live in the normalized initial-condition space, so they are mapped
# back with the initial-condition statistics; inputs use the data statistics.
reconstruction_denormalized = denormalize(reconstructions, min_max_vals_ic)
ics_denormalized = denormalize(ics, min_max_vals_ic)
data_denormalized = denormalize(data_stored, min_max_vals)

# Relative errors: norm of the difference divided by the norm of the true
# field, over all channels and then per channel.
# NOTE(review): a channel whose true field is identically zero makes the
# denominator zero -- confirm whether that can happen for u and v.
total_err = torch.norm(reconstruction_denormalized - ics_denormalized) / torch.norm(ics_denormalized)
eta_err = torch.norm(reconstruction_denormalized[:,0,:,:] - ics_denormalized[:,0,:,:] ) / torch.norm(ics_denormalized[:,0,:,:] )
u_err = torch.norm(reconstruction_denormalized[:,1,:,:]  - ics_denormalized[:,1,:,:] ) / torch.norm(ics_denormalized[:,1,:,:] )
v_err = torch.norm(reconstruction_denormalized[:,2,:,:]  - ics_denormalized[:,2,:,:] ) / torch.norm(ics_denormalized[:,2,:,:] )

print(f'Total Error : {total_err}')
print(f'eta error : {eta_err} , u error : {u_err} , v error : {v_err}')

In [None]:
#
# --- plot initial conditions only, for 4 different data points ---
#
# Layout: 2 rows x 6 columns = 12 panels, i.e. (eta, u, v) for each of the four
# samples, two samples per row.

fig, axes = plt.subplots(2, 6, figsize=(15, 10), subplot_kw={'projection': '3d'})
x = np.linspace(0, 1, 64)
y = np.linspace(0, 1, 64)
X, Y = np.meshgrid(x, y)

# One sample index per initial-condition family (families are stored in
# consecutive blocks of 2500 samples).
indices = [0,3000,6000,9000]
gaussian = ics_denormalized[indices[0], :, :, :]
two_gaussian = ics_denormalized[indices[1], :, :, :]
sinusoidal = ics_denormalized[indices[2], :, :, :]
flat = ics_denormalized[indices[3], :, :, :]

# Flatten to [eta, u, v] per sample, in the order the panels are filled.
Zs = [sample[ch, :, :] for sample in (gaussian, two_gaussian, sinusoidal, flat) for ch in range(3)]

# Zs must exist before it is inspected.
print(Zs[1].shape)

titles = [ "η", "u", "v",  "η", "u", "v",  "η", "u", "v",  "η", "u", "v" ]
# Headings for the middle (u) panel of each sample. Stored under their own name
# so the `ics` tensor computed in the evaluation cell is not overwritten.
ic_titles = ["IC: Gaussian Bump \nu","IC: 2 Gaussian Bumps \nu", "IC: Sinusoidal Wave \nu","IC: Flat Conditions \nu"]
cmaps = ['viridis', 'viridis', 'plasma', 'viridis', 'viridis', 'plasma',
        'viridis', 'viridis', 'plasma', 'viridis', 'viridis', 'plasma']

k=0
for i, ax in enumerate(axes.flat):
    ax.plot_surface(X, Y, Zs[i], cmap=cmaps[i])
    ax.set_title(titles[i], fontsize=20, pad=5)
    if ((i-1)%3==0):
        # middle panel of each triple carries the family name above the "u" label
        ax.set_title(ic_titles[k], fontsize=20, pad=10)
        k += 1
    ax.set_xlabel('x', fontsize=12, labelpad=5)
    ax.set_ylabel('y', fontsize=12, labelpad=5)
    ax.tick_params(axis='both', labelsize=5)
    ax.axis('tight')
    
plt.subplots_adjust(hspace=0.1, wspace=0.1)  # manual tuning

plt.tight_layout()  # Adjusts spacing to prevent overlap
plt.show()

In [None]:
#
# --- plot predictions vs actual initial conditions for 1 datapoint ---
#
# Top row of the figure: true initial condition; bottom row: model reconstruction.
idx = 3000  # sample to inspect
original, predicted = ics_denormalized, reconstruction_denormalized

plot_figs(idx, X1=original, X2=predicted)