In [1]:
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
import torch.nn as nn
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms 
from PIL import Image
import cv2
import numpy as np
import matplotlib.pyplot as plt
import os

In [2]:
# Reynolds number; PhysicsLoss divides the viscous (Laplacian) terms by it.
Re = 400
# Mini-batch size and number of training epochs.
bs, epochs = 16, 100

In [13]:
class CustomDatasetUV(Dataset):
    """Pairs each ``Mesh/*.png`` image with a U and a V label image.

    Label file names are built by replacing the first four characters of the
    mesh file name with a prefix, e.g. ``Mesh_001.png`` -> ``U_001.png`` and
    ``V_001.png``. Each label is flipped top-to-bottom and left-to-right,
    transformed, and the two results are concatenated on dim 0.

    Args:
        root_dir: Folder containing ``Mesh`` and the label sub-directories.
        data_transform: Callable applied to the grayscale mesh image;
            ``transforms.ToTensor()`` when ``None``.
        label_transform: Callable applied to each label image (must return a
            tensor); ``transforms.ToTensor()`` when ``None``.
        u_dir, v_dir: Sub-directories of ``root_dir`` holding U / V labels.
        u_prefix, v_prefix: Prefixes that replace ``file[:4]``.

    NOTE(review): the V label used to be built with the prefix 'U' (so U and V
    were the same file). ``v_prefix`` now defaults to 'V'; ``v_dir`` keeps the
    original 'W_Data' -- confirm where the V files actually live.
    """
    def __init__(self, root_dir, data_transform=None, label_transform=None,
                 u_dir='W_Data', v_dir='W_Data', u_prefix='U', v_prefix='V'):
        self.root_dir = root_dir
        # Without a fallback, the default None would be called in __getitem__.
        self.data_transform = transforms.ToTensor() if data_transform is None else data_transform
        self.label_transform = transforms.ToTensor() if label_transform is None else label_transform
        self.image_paths = []
        self.labels_path = []
        image_dir = os.path.join(root_dir, 'Mesh')
        u_labels_dir = os.path.join(root_dir, u_dir)
        v_labels_dir = os.path.join(root_dir, v_dir)
        # Sort so the index -> file mapping (and hence random_split) does not
        # depend on the filesystem's arbitrary listdir order.
        for file in sorted(os.listdir(image_dir)):
            if not file.endswith('.png'):
                continue
            stem = file[4:]  # drop the first 4 characters (presumably 'Mesh')
            self.image_paths.append(os.path.join(image_dir, file))
            self.labels_path.append((os.path.join(u_labels_dir, u_prefix + stem),
                                     os.path.join(v_labels_dir, v_prefix + stem)))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(mesh image, label)`` where label stacks U and V on dim 0."""
        u_path, v_path = self.labels_path[idx]
        image = Image.open(self.image_paths[idx]).convert('L')
        u_label = self._load_flipped(u_path)
        v_label = self._load_flipped(v_path)

        image = self.data_transform(image)
        u_label = self.label_transform(u_label)
        v_label = self.label_transform(v_label)
        # Channel-wise stack; assumes label_transform yields (C, H, W) tensors
        # as transforms.ToTensor does.
        label = torch.concat((u_label, v_label), dim=0)
        return image, label

    @staticmethod
    def _load_flipped(path):
        """Open a label image and flip it top-to-bottom, then left-to-right."""
        return Image.open(path).transpose(Image.FLIP_TOP_BOTTOM).transpose(Image.FLIP_LEFT_RIGHT)
    
class CustomDatasetMag(Dataset):
    """Pairs each ``Mesh/*.png`` image with a single label image.

    The label file name replaces the first four characters of the mesh file
    name with ``label_prefix`` (e.g. ``Mesh_001.png`` -> ``U_001.png``) and is
    looked up in ``root_dir/label_dir``.

    Args:
        root_dir: Folder containing ``Mesh`` and the label sub-directory.
        data_transform: Callable applied to the grayscale mesh image;
            ``transforms.ToTensor()`` when ``None``.
        label_transform: Callable applied to the label image (opened without
            any mode conversion); ``transforms.ToTensor()`` when ``None``.
        label_dir: Label sub-directory (default ``'W_Data'``).
        label_prefix: Prefix replacing ``file[:4]`` (default ``'U'``).
    """
    def __init__(self, root_dir, data_transform=None, label_transform=None,
                 label_dir='W_Data', label_prefix='U'):
        self.root_dir = root_dir
        # Without a fallback, the default None would be called in __getitem__.
        self.data_transform = transforms.ToTensor() if data_transform is None else data_transform
        self.label_transform = transforms.ToTensor() if label_transform is None else label_transform
        self.image_paths = []
        self.labels_path = []
        image_dir = os.path.join(root_dir, 'Mesh')
        labels_dir = os.path.join(root_dir, label_dir)
        # Sort so the index -> file mapping (and hence random_split) does not
        # depend on the filesystem's arbitrary listdir order.
        for file in sorted(os.listdir(image_dir)):
            if not file.endswith('.png'):
                continue
            self.image_paths.append(os.path.join(image_dir, file))
            # Replace the first 4 characters (presumably 'Mesh') with the prefix.
            self.labels_path.append(os.path.join(labels_dir, label_prefix + file[4:]))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(transformed mesh image, transformed label)`` for ``idx``."""
        image = Image.open(self.image_paths[idx]).convert('L')
        label = Image.open(self.labels_path[idx])
        return self.data_transform(image), self.label_transform(label)
    
# Define transformations: resize images and labels to 128 x 256 (H x W), then convert to tensors
img_transform = transforms.Compose([
    transforms.Resize((128, 256)),
    transforms.ToTensor()
])

label_transform = transforms.Compose([
    transforms.Resize((128, 256)),
    transforms.ToTensor()
])
# Path to your dataset folder. os.curdir is portable; the former ".\\" only
# meant "current directory" on Windows.
dataset_root = os.curdir

# Create dataset
dataset = CustomDatasetMag(dataset_root, data_transform=img_transform, label_transform=label_transform)

# Split dataset into train and validation (80 / 20). A seeded generator makes
# the split reproducible across kernel restarts.
SPLIT_SEED = 42
dataset_size = len(dataset)
val_split = 0.2
val_size = int(val_split * dataset_size)
train_size = dataset_size - val_size

train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=torch.Generator().manual_seed(SPLIT_SEED)
)

# Create dataloaders; the batch size comes from the config constant `bs`.
train_loader = DataLoader(train_dataset, batch_size=bs, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=bs, shuffle=False)

In [14]:
class ChannelAttention(nn.Module):
    """CBAM-style channel attention.

    Each requested global pooling (average and/or max) reduces the input to
    one value per channel; every pooled vector goes through a shared two-layer
    MLP, the MLP outputs are summed, and their sigmoid gates the input
    channel-wise.

    Args:
        gate_channels: Channel count C of the (B, C, H, W) input.
        reduction_ratio: MLP hidden width is ``C // reduction_ratio``
            (clamped to at least 1).
        pool_types: Poolings to use, any non-empty subset of ``'avg'``/``'max'``.

    Raises:
        ValueError: if ``pool_types`` is empty or contains an unknown name.
    """
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=('avg', 'max')):
        super(ChannelAttention, self).__init__()
        self.gate_channels = gate_channels
        # Tuple instead of a shared mutable default list.
        self.pool_types = tuple(pool_types)
        unknown = set(self.pool_types) - {'avg', 'max'}
        if not self.pool_types or unknown:
            raise ValueError(f"pool_types must be a non-empty subset of ('avg', 'max'), got {pool_types!r}")
        # max(1, ...) avoids a zero-width hidden layer for small channel counts.
        hidden = max(1, gate_channels // reduction_ratio)
        self.mlp = nn.Sequential(
            nn.Flatten(),
            nn.Linear(gate_channels, hidden),
            nn.ReLU(),
            nn.Linear(hidden, gate_channels)
            )
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Gate ``x`` (B, C, H, W) per channel; the result has the shape of ``x``."""
        pools = {'avg': self.avg_pool, 'max': self.max_pool}
        channel_att_sum = 0
        for pool_type in self.pool_types:
            # Accumulate inside the loop so every pooling branch contributes.
            channel_att_sum = channel_att_sum + self.mlp(pools[pool_type](x))
        # Gate from the MLP logits (not from x itself): (B, C) -> (B, C, 1, 1)
        # so it broadcasts over H and W.
        scale = self.sigmoid(channel_att_sum)[:, :, None, None].expand_as(x)
        return x * scale

In [15]:
class SpatialAttention(nn.Module):
    """CBAM-style spatial attention.

    The input is summarised across channels by a mean map and a max map; a
    single convolution turns the two maps into one, whose sigmoid multiplies
    the input pixel-wise.

    Args:
        kernel_size: Convolution kernel size; only 3 or 7 are accepted.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        # Half-kernel padding keeps H and W unchanged for these odd sizes.
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input):
        """Return ``input`` scaled by a per-pixel gate in (0, 1)."""
        mean_map = input.mean(dim=1, keepdim=True)
        max_map = input.max(dim=1, keepdim=True).values
        gate = self.sigmoid(self.conv1(torch.cat([mean_map, max_map], dim=1)))
        return input * gate

In [16]:
class Down(nn.Module):
    '''
    Encoder block: 4x4 convolution with stride 2 (halving H and W, rounding
    down), followed by ReLU and then batch normalisation.
    '''
    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = [
            nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_ch),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply conv -> ReLU -> BatchNorm."""
        return self.net(x)

In [17]:
class Up(nn.Module):
    """Decoder block: upsample the deeper feature map 2x and concatenate it with
    the (optionally attention-weighted) skip tensor along channels.

    Args:
        out_size: Channels of the concatenated output, i.e. skip channels
            plus ``upsize``.
        insize: Channels of the deeper input (``inputs2``).
        upsize: Channels produced by the transposed convolution.
        att_type: ``'channel'`` or ``'spatial'`` attention on the skip input;
            any other value disables attention.
    """
    def __init__(self, out_size, insize, upsize, att_type='channel'):  # channel or spatial
        super(Up, self).__init__()
        self.up = nn.Sequential(
            nn.ConvTranspose2d(insize, upsize, 4, 2, 1, bias=False),
            nn.BatchNorm2d(upsize),
            nn.ReLU(True)
        )

        # Attention gates the skip tensor, so size it for the skip channels
        # rather than a hard-coded 256 (which only fit 256-channel skips).
        skip_channels = out_size - upsize
        if att_type == 'channel':
            if skip_channels <= 0:
                raise ValueError(f"out_size ({out_size}) must exceed upsize ({upsize}) for channel attention")
            self.att = ChannelAttention(skip_channels)
        elif att_type == 'spatial':
            self.att = SpatialAttention()
        else:
            self.att = nn.Identity()

    def forward(self, inputs1, inputs2):
        """Return ``cat([att(inputs1), up(inputs2)], dim=1)``.

        ``inputs1`` is the skip tensor. ``inputs2`` is doubled in H and W by
        the transposed convolution and must then match ``inputs1`` spatially.
        """
        inputs1_att = self.att(inputs1)
        return torch.cat([inputs1_att, self.up(inputs2)], 1)

In [18]:
class UnetAM(nn.Module):
    """U-Net with attention.

    Seven ``Down`` blocks each halve H and W. Residual channel attention is
    applied at the bottleneck. Six ``Up`` blocks use spatial attention on the
    skip connections, and a final transposed convolution + ReLU restores the
    input resolution.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the prediction; defaults to
            ``in_channels * 3`` (the previous hard-wired value).

    The input H and W must be multiples of 2**7 = 128 so every skip tensor
    lines up with its upsampled partner.
    """
    def __init__(self, in_channels=1, out_channels=None):
        super(UnetAM, self).__init__()
        if out_channels is None:
            out_channels = in_channels * 3

        self.down1 = Down(in_channels, 16)
        self.down2 = Down(16, 32)
        self.down3 = Down(32, 32)
        self.down4 = Down(32, 64)
        self.down5 = Down(64, 64)
        self.down6 = Down(64, 256)
        self.down7 = Down(256, 256)

        self.c_att = ChannelAttention(256)

        # NOTE(review): `self.up` is never used in forward, and its
        # BatchNorm2d(2) does not match the 256 channels produced by the
        # ConvTranspose2d, so it would fail if called. It is kept only so that
        # existing checkpoints (which contain its weights) load with
        # strict=True; drop it the next time the model is retrained.
        self.up = nn.Sequential(
            nn.ConvTranspose2d(256, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(2),
            nn.ReLU(True)
        )

        # Decoder: Up(out_size, insize, upsize); out_size = skip ch + upsize.
        self.up_concat6 = Up(512, 256, 256, att_type='spatial')
        self.up_concat5 = Up(128, 512, 64, att_type='spatial')
        self.up_concat4 = Up(128, 128, 64, att_type='spatial')
        self.up_concat3 = Up(64, 128, 32, att_type='spatial')
        self.up_concat2 = Up(64, 64, 32, att_type='spatial')
        self.up_concat1 = Up(32, 64, 16, att_type='spatial')
        self.final = nn.Sequential(
            nn.ConvTranspose2d(32, out_channels, 4, 2, 1, bias=False),
            nn.ReLU(True)
        )

    def forward(self, inputs):
        """Map a (B, in_channels, H, W) batch to (B, out_channels, H, W)."""
        h, w = inputs.shape[-2:]
        if h % 128 or w % 128:
            raise ValueError(f"UnetAM needs H and W divisible by 128, got {h}x{w}")
        # Shape comments are (C, H, W) for a (1, 128, 256) input.
        feat1 = self.down1(inputs)  # (16, 64, 128)
        feat2 = self.down2(feat1)   # (32, 32, 64)
        feat3 = self.down3(feat2)   # (32, 16, 32)
        feat4 = self.down4(feat3)   # (64, 8, 16)
        feat5 = self.down5(feat4)   # (64, 4, 8)
        feat6 = self.down6(feat5)   # (256, 2, 4)
        feat7 = self.down7(feat6)   # (256, 1, 2)
        feat7 = self.c_att(feat7) + feat7  # residual channel attention

        up6 = self.up_concat6(feat6, feat7)  # (512, 2, 4)
        up5 = self.up_concat5(feat5, up6)    # (128, 4, 8)
        up4 = self.up_concat4(feat4, up5)    # (128, 8, 16)
        up3 = self.up_concat3(feat3, up4)    # (64, 16, 32)
        up2 = self.up_concat2(feat2, up3)    # (64, 32, 64)
        up1 = self.up_concat1(feat1, up2)    # (32, 64, 128)
        return self.final(up1)               # (out_channels, 128, 256)

In [19]:
class DataLoss(nn.Module):
    """Sum of absolute errors scaled by ``1 / (6 * B * H * W)``.

    NOTE(review): the constant 6 is a fixed normaliser whose origin is not
    documented here -- confirm it against the reference loss definition.
    """

    def forward(self, true, pred):
        """Return the scaled L1 error between ``true`` and ``pred`` (dims 0, 2, 3 = B, H, W)."""
        n_batch, height, width = true.shape[0], true.shape[2], true.shape[3]
        abs_sum = (true - pred).abs().sum()
        return abs_sum / (6 * n_batch * height * width)
    
class PhysicsLoss(nn.Module):
    """Differences in continuity and momentum residuals between two velocity fields.

    Channel 0 is treated as u and channel 1 as v. Dim 2 is used as x and dim 3
    as y, with unit grid spacing; derivatives are central differences on the
    interior points. The momentum residuals use conservative convection terms
    plus a viscous Laplacian term divided by the Reynolds number (no pressure
    or time terms).

    Args:
        reynolds: Reynolds number. ``None`` (default) reads the notebook-level
            ``Re`` at call time, which is the previous behaviour.

    ``forward(true, pred)`` returns ``(mass_loss, momentum_loss)``. Each is
    the summed absolute residual difference divided by
    B * (H - 2) * (W - 2).
    """
    def __init__(self, reynolds=None):
        super(PhysicsLoss, self).__init__()
        self.reynolds = reynolds

    @staticmethod
    def _residuals(field, reynolds):
        """Return (continuity, x-momentum, y-momentum) residuals on interior points."""
        u = field[:, 0]
        v = field[:, 1]
        uc = u[:, 1:-1, 1:-1]
        vc = v[:, 1:-1, 1:-1]

        du_dx = (u[:, 2:, 1:-1] - u[:, :-2, 1:-1]) / 2
        du_dy = (u[:, 1:-1, 2:] - u[:, 1:-1, :-2]) / 2
        dv_dx = (v[:, 2:, 1:-1] - v[:, :-2, 1:-1]) / 2
        dv_dy = (v[:, 1:-1, 2:] - v[:, 1:-1, :-2]) / 2

        # Laplacians (d2/dx2 + d2/dy2) with unit spacing.
        lap_u = (u[:, 2:, 1:-1] - 2 * uc + u[:, :-2, 1:-1]) + (u[:, 1:-1, 2:] - 2 * uc + u[:, 1:-1, :-2])
        lap_v = (v[:, 2:, 1:-1] - 2 * vc + v[:, :-2, 1:-1]) + (v[:, 1:-1, 2:] - 2 * vc + v[:, 1:-1, :-2])

        continuity = du_dx + dv_dy
        # Conservative convection: d(u^2)/dx + d(uv)/dy = 2u u_x + u v_y + v u_y
        # and d(uv)/dx + d(v^2)/dy = u v_x + v u_x + 2v v_y. The former code
        # wrote 2*u_x / 2*v_y, dropping the velocity factor.
        momentum_x = (2 * uc * du_dx + uc * dv_dy + vc * du_dy) - lap_u / reynolds
        momentum_y = (2 * vc * dv_dy + uc * dv_dx + vc * du_dx) - lap_v / reynolds
        return continuity, momentum_x, momentum_y

    def forward(self, true, pred):
        reynolds = Re if self.reynolds is None else self.reynolds
        cont_t, mx_t, my_t = self._residuals(true, reynolds)
        cont_p, mx_p, my_p = self._residuals(pred, reynolds)

        # Normalise by batch size times the number of interior points.
        n_points = true.shape[0] * (true.shape[2] - 2) * (true.shape[3] - 2)
        mass_loss = torch.abs(cont_t - cont_p).sum() / n_points
        total_momentum_loss = (torch.abs(mx_t - mx_p) + torch.abs(my_t - my_p)).sum() / n_points
        return mass_loss, total_momentum_loss
    
class Loss(nn.Module):
    """Weighted combination of the data loss and the two physics losses.

    ``total = alpha_1 * data + alpha_2 * mass + alpha_3 * momentum``
    """
    def __init__(self, alpha_1 = 0.33, alpha_2 = 1.67, alpha_3 = 8.33) -> None:
        super().__init__()
        self.a1, self.a2, self.a3 = alpha_1, alpha_2, alpha_3
        self.data_loss = DataLoss()
        self.physics_loss = PhysicsLoss()

    def forward(self, true, pred):
        """Return the weighted sum of (data, mass, momentum) losses."""
        terms = (self.data_loss(true, pred), *self.physics_loss(true, pred))
        weights = (self.a1, self.a2, self.a3)
        return sum(weight * term for weight, term in zip(weights, terms))

In [20]:
def MRE(true, pred):
    """Mean relative error over the batch.

    For each sample i the error is ``sum|true_i - pred_i| / sum|true_i|``; the
    result is the mean of these per-sample ratios. The previous version
    multiplied a batch-pooled ratio by ``1 / B``, so the same predictions
    scored differently depending on batch size, and it used a signed
    denominator.

    Args:
        true, pred: Tensors of equal shape with a leading batch dimension
            (at least 2-D).

    Returns:
        A 0-d tensor.
    """
    abs_err = torch.abs(true - pred).flatten(1).sum(dim=1)
    magnitude = torch.abs(true).flatten(1).sum(dim=1)
    return (abs_err / magnitude).mean()

In [21]:
class UnetAM_Module(L.LightningModule):
    """Lightning wrapper that trains ``UnetAM`` with ``DataLoss`` and logs MRE.

    Args:
        lr: AdamW learning rate.
        step_size: Epochs between StepLR learning-rate decays.
        gamma: Multiplicative StepLR decay factor.
    """
    def __init__(self, lr=1e-4, step_size=25, gamma=0.1) -> None:
        super().__init__()
        # Stores lr / step_size / gamma in self.hparams and in checkpoints.
        self.save_hyperparameters()
        self.model = UnetAM()
        self.loss = DataLoss()

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr)
        scheduler = optim.lr_scheduler.StepLR(
            optimizer, step_size=self.hparams.step_size, gamma=self.hparams.gamma
        )
        return [optimizer], [scheduler]

    def _shared_step(self, batch, stage):
        """Run the model on ``batch``, log ``{stage}/loss`` and ``{stage}/mre``, return the loss."""
        imgs, labels = batch
        preds = self.model(imgs)
        loss = self.loss(labels, preds)
        mre = MRE(labels, preds)

        # Explicit batch_size so epoch-level averages weight the final
        # (possibly smaller) batch correctly.
        n = imgs.shape[0]
        self.log(f"{stage}/loss", loss, batch_size=n)
        self.log(f"{stage}/mre", mre, batch_size=n)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "val")

In [22]:
# Seed Python, NumPy and torch RNGs (weight init, shuffling, dataloader
# workers) so the training run is reproducible.
L.seed_everything(42, workers=True)
model = UnetAM_Module()
# Keep the checkpoint with the lowest validation MRE.
checkpoint_callback = ModelCheckpoint(monitor="val/mre", mode="min")
trainer = L.Trainer(max_epochs=epochs, callbacks=[checkpoint_callback])
trainer.fit(model, train_loader, val_loader)


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type     | Params
-----------------------------------
0 | model | UnetAM   | 4.3 M 
1 | loss  | DataLoss | 0     
-----------------------------------
4.3 M     Trainable params
0         Non-trainable params
4.3 M     Total params
17.259    Total estimated model params size (MB)


Epoch 99: 100%|██████████| 181/181 [01:26<00:00,  2.08it/s, v_num=14]      

`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 99: 100%|██████████| 181/181 [01:26<00:00,  2.08it/s, v_num=14]
