In [None]:
import pickle

# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load these datasets from a source you trust.
DATASET_DIR = "save/T007_Transformer"

# `with` closes each file handle as soon as it has been read.
with open(f"{DATASET_DIR}/train_dataset.p", "rb") as fh:
    train_dataset = pickle.load(fh)
with open(f"{DATASET_DIR}/test_dataset.p", "rb") as fh:
    test_dataset = pickle.load(fh)

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
#from torch import nn
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.trainer.trainer import Trainer
from torch.utils.data import DataLoader

In [None]:
# Cubic Hermite spline interpolation (h_poly, interp), adapted from:
# https://stackoverflow.com/questions/61616810/how-to-do-cubic-spline-interpolation-and-integration-in-pytorch

#import torch

def h_poly(t):
    """Evaluate the four cubic Hermite basis polynomials elementwise at ``t``.

    Row k of the result is the k-th basis polynomial:
        h00 = 1 - 3t^2 + 2t^3
        h10 = t - 2t^2 + t^3
        h01 = 3t^2 - 2t^3
        h11 = t^3 - t^2

    Args:
        t: 1-D tensor of local interval coordinates (as passed by ``interp``;
           values in [0, 1] for in-range queries).

    Returns:
        Tensor of shape (4, len(t)) with the dtype and device of ``t``.
    """
    # Monomials: row p holds t**p for p = 0..3.
    exponents = torch.arange(4, device=t.device).unsqueeze(1)
    monomials = t.unsqueeze(0) ** exponents
    # Each row expresses one basis polynomial in the monomial basis (1, t, t^2, t^3).
    coeffs = torch.tensor(
        [[1, 0, -3, 2],
         [0, 1, -2, 1],
         [0, 0, 3, -2],
         [0, 0, -1, 1]],
        dtype=t.dtype,
        device=t.device,
    )
    return torch.matmul(coeffs, monomials)


def interp(x, y, xs):
    """Cubic Hermite spline interpolation of samples ``(x, y)`` at query points ``xs``.

    Knot tangents are the average of the two adjacent secant slopes; the end knots
    reuse their single adjacent secant slope.

    Args:
        x: 1-D tensor of knot positions, sorted strictly increasing, length >= 2.
        y: 1-D tensor of values at the knots (same length as ``x``).
        xs: 1-D tensor of query positions.

    Returns:
        1-D tensor with the interpolated value at each entry of ``xs``. Queries
        outside ``[x[0], x[-1]]`` are extrapolated with the cubic of the nearest
        end interval.
    """
    # Secant slope of every interval, then one tangent per knot.
    m = (y[1:] - y[:-1]) / (x[1:] - x[:-1])
    m = torch.cat([m[[0]], (m[1:] + m[:-1]) / 2, m[[-1]]])
    # Interval [x[i], x[i+1]] used for each query. searchsorted returns len(x) - 1
    # for queries beyond x[-1], which would index x[i + 1] out of bounds; clamp to
    # the last interval (queries below x[1] already map to interval 0).
    idxs = torch.searchsorted(x[1:], xs).clamp(max=x.shape[0] - 2)
    dx = x[idxs + 1] - x[idxs]
    hh = h_poly((xs - x[idxs]) / dx)
    return hh[0] * y[idxs] + hh[1] * m[idxs] * dx + hh[2] * y[idxs + 1] + hh[3] * m[idxs + 1] * dx

In [None]:
from prettytable import PrettyTable

def count_parameters(model):
    """Print a table of trainable parameter tensors and return their total size.

    Parameters with ``requires_grad == False`` are skipped.

    Args:
        model: any ``torch.nn.Module``.

    Returns:
        Total number of trainable scalar parameters (int).
    """
    trainable = [
        (name, tensor.numel())
        for name, tensor in model.named_parameters()
        if tensor.requires_grad
    ]
    table = PrettyTable(["Modules", "Parameters"])
    for name, count in trainable:
        table.add_row([name, count])
    total = sum(count for _, count in trainable)
    print(table)
    print(f"Total Trainable Params: {total}")
    return total

In [None]:
class AirfoilModel(LightningModule):
    """Transformer-based model predicting 11 point-wise output channels for an airfoil.

    Pipeline (see ``training_step``):
      1. ``airfoil_encoder``: 1-D conv stack, (B, 6, N) input -> (B, c3, N) features.
      2. ``latent_encoder``: an MLP embedding of the per-sample scalars (ma, aoa)
         cross-attends over those features, giving one latent token per sample,
         shape (B, 1, c3).
      3. ``airfoil_predictor``: features of the query points cross-attend to the
         latent token; a 1x1 conv maps them to 11 channels per query point.

    All transformer modules use PyTorch's default sequence-first layout
    (batch_first=False), i.e. tensors are (S, B, E).
    """

    def __init__(self):
        super().__init__()

        # Defaults; may be overridden on the instance before fitting
        # (read by train_dataloader and configure_optimizers).
        self.batch_size = 128
        self.lr = 1e-3

        # NOTE(review): not read anywhere in this class; kept for code that sets it.
        self.train_variance = 1.0

        c1 = 16  # channels after the input conv
        c2 = 8   # channels of each of the four stacked 3-tap convs
        c3 = 32  # model width: layer_2 output and d_model of every transformer

        # Input is (B, 6, N); all convs preserve the length N (stride 1, "same" padding).
        self.layer_0 = nn.Conv1d(in_channels=6, out_channels=c1, kernel_size=5, stride=1, padding=2)

        self.layer_11 = nn.Conv1d(in_channels=c1, out_channels=c2, kernel_size=3, stride=1, padding=1)
        self.layer_12 = nn.Conv1d(in_channels=c2, out_channels=c2, kernel_size=3, stride=1, padding=1)
        self.layer_13 = nn.Conv1d(in_channels=c2, out_channels=c2, kernel_size=3, stride=1, padding=1)
        self.layer_14 = nn.Conv1d(in_channels=c2, out_channels=c2, kernel_size=3, stride=1, padding=1)
        self.relu_11 = nn.ReLU()
        self.relu_12 = nn.ReLU()
        self.relu_13 = nn.ReLU()
        self.relu_14 = nn.ReLU()

        # The four intermediate maps are concatenated, hence 4 * c2 input channels.
        self.layer_2 = nn.Conv1d(in_channels=4 * c2, out_channels=c3, kernel_size=5, stride=1, padding=2)
        self.relu_2 = nn.ReLU()

        # NOTE(review): transformer_encoder is never called; kept so parameter names
        # (and existing checkpoints) stay unchanged.
        encoder_layers = nn.TransformerEncoderLayer(d_model=c3, nhead=8, dim_feedforward=64, dropout=0.001)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers=3)

        # Point-wise head: c3 features -> 11 output channels.
        self.output = nn.Conv1d(in_channels=c3, out_channels=11, kernel_size=1, stride=1, padding=0)

        # Symmetric bell-shaped distribution on (0, 1), used to jitter sample positions.
        self.beta_dist = torch.distributions.beta.Beta(2, 2)

        # MLP embedding of the two per-sample scalars (ma, aoa) into one c3-dim token.
        self.ma_aoa_lin1 = nn.Linear(2, 16)
        self.ma_aoa_relu1 = nn.ReLU()
        self.ma_aoa_lin2 = nn.Linear(16, c3)
        self.ma_aoa_relu2 = nn.ReLU()

        latenc_layers = nn.TransformerDecoderLayer(d_model=c3, nhead=8, dim_feedforward=64, dropout=0.001)
        self.latenc_transformer = nn.TransformerDecoder(latenc_layers, num_layers=2)

        decoder_layers = nn.TransformerDecoderLayer(d_model=c3, nhead=8, dim_feedforward=64, dropout=0.001)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layers, num_layers=3)

    def _jittered_positions(self, s):
        """Return a randomly perturbed copy of the sample coordinates ``s`` (B, N).

        ``s`` is assumed strictly increasing along dim 1 (``interp`` needs sorted
        knots). The first and last entries are kept. Each interior entry moves
        toward its right neighbour (r > 0) or left neighbour (r < 0) by |r| times
        half the gap, with r = 2 * Beta(2, 2) - 1 in (-1, 1). No point moves more
        than half-way to a neighbour, so order is preserved and all positions stay
        within [s[:, 0], s[:, -1]].
        """
        ds = s[:, 1:] - s[:, :-1]
        # Raw Beta samples are always in (0, 1): testing them with `> 0` / `< 0`
        # is constant, so centre them to get a random direction and magnitude.
        rv = 2.0 * self.beta_dist.rsample([s.shape[0], s.shape[1] - 2]).to(s) - 1.0
        snew = s.clone()
        snew[:, 1:-1] += 0.5 * (rv.clamp(min=0) * ds[:, 1:] + rv.clamp(max=0) * ds[:, :-1])
        return snew

    @staticmethod
    def _resample(s, snew, t):
        """Re-evaluate every channel of ``t`` (B, C, N), sampled at ``s`` (B, N), at ``snew``."""
        out = torch.zeros_like(t)
        for i in range(t.shape[0]):
            for j in range(t.shape[1]):
                out[i, j] = interp(s[i], t[i, j], snew[i])
        return out

    @staticmethod
    def _keep_mask(n_points, device):
        """Random boolean mask over ``n_points`` positions; each is kept with probability 0.75."""
        return torch.rand(n_points, device=device) >= 0.25

    def data_augmentation(self, x):
        """Randomly re-sample an input tensor along its sampling coordinate.

        Args:
            x: (B, C, N) tensor; channel 1 is treated as the monotone sampling
               coordinate ``s``. ``x`` is not modified.

        Returns:
            (B, C, M) tensor: ``x`` interpolated at jittered positions, with a
            random ~25% of the N points dropped (M varies between calls).
        """
        s = x[:, 1]
        snew = self._jittered_positions(s)
        xnew = self._resample(s, snew, x)
        # Drop points (last dim), never channels: the encoder needs all 6 inputs.
        keep = self._keep_mask(xnew.shape[-1], xnew.device)
        return xnew[..., keep]

    def augment_output(self, x, y):
        """Apply one shared random re-sampling to an input ``x`` and its target ``y``.

        Both are interpolated at the same jittered positions (derived from
        ``x[:, 1]``) and the same random subset of points is kept, so ``x`` and
        ``y`` stay aligned point-for-point. Neither input is modified.

        Returns:
            (xnew, ynew) with shapes (B, Cx, M) and (B, Cy, M).
        """
        s = x[:, 1]
        snew = self._jittered_positions(s)
        xnew = self._resample(s, snew, x)
        ynew = self._resample(s, snew, y)
        keep = self._keep_mask(xnew.shape[-1], xnew.device)
        return xnew[..., keep], ynew[..., keep]

    def airfoil_encoder(self, x):
        """Encode a (B, 6, N) input into (B, 32, N) point-wise features."""
        x_0 = self.layer_0(x)

        x_11 = self.relu_11(self.layer_11(x_0))
        x_12 = self.relu_12(self.layer_12(x_11))
        x_13 = self.relu_13(self.layer_13(x_12))
        x_14 = self.relu_14(self.layer_14(x_13))

        # Concatenate all four intermediate maps along channels -> (B, 4 * c2, N).
        x_1 = torch.cat((x_11, x_12, x_13, x_14), dim=1)
        return self.relu_2(self.layer_2(x_1))

    def latent_encoder(self, x_3, ma, aoa):
        """Condense encoder features into one latent token per sample, conditioned on (ma, aoa).

        Args:
            x_3: (B, E, N) features from ``airfoil_encoder``.
            ma, aoa: per-sample scalars; each is flattened to shape (B,).

        Returns:
            (B, 1, E) latent token.
        """
        cond = torch.stack((ma.reshape(-1), aoa.reshape(-1)), dim=1).to(self.device)  # (B, 2)
        cond = self.ma_aoa_relu1(self.ma_aoa_lin1(cond))
        cond = self.ma_aoa_relu2(self.ma_aoa_lin2(cond))  # (B, E)

        # The decoder is sequence-first, so the batch must sit on dim 1. Feeding
        # (B, 1, E) / (B, N, E) directly would make it treat the batch as the
        # sequence, letting different samples attend to each other.
        tgt = cond.unsqueeze(0)        # (1, B, E): one query token per sample
        memory = x_3.permute(2, 0, 1)  # (N, B, E)
        z = self.latenc_transformer(tgt, memory)
        return z.transpose(0, 1)       # (B, 1, E)

    def airfoil_predictor(self, z, x_output):
        """Predict the 11 output channels at the sample points of ``x_output``.

        Args:
            z: (B, 1, E) latent token from ``latent_encoder``.
            x_output: (B, 6, M) input evaluated at the M query points.

        Returns:
            (B, 11, M) prediction.
        """
        # Query-point features are the decoder target, sequence-first: (M, B, E).
        tgt = self.airfoil_encoder(x_output).permute(2, 0, 1)
        # The latent token is the decoder memory: (1, B, E).
        memory = z.transpose(0, 1)
        # (M, B, E) -> (B, E, M) for the 1x1 conv head.
        y = self.transformer_decoder(tgt, memory).permute(1, 2, 0)
        return self.output(y)

    def forward(self, x, ma, aoa):
        """Deterministic prediction at the input's own sample points.

        Args:
            x: (B, 6, N) input.
            ma, aoa: per-sample scalars (flattened to shape (B,)).

        Returns:
            (B, 11, N) prediction.
        """
        # No augmentation at inference; the input points double as query points.
        x3 = self.airfoil_encoder(x)
        z = self.latent_encoder(x3, ma, aoa)
        return self.airfoil_predictor(z, x)

    def training_step(self, batch, batch_idx):
        """One optimisation step on randomly re-sampled encoder inputs and query points."""
        qoip, ma, _, aoa = batch

        x = qoip[:, :6, :]  # model input channels
        y_target = qoip     # full tensor is the target (presumably 11 channels, matching self.output)

        # Two independent augmentations: one shared by query points and target,
        # one for the encoder input. Neither modifies its arguments.
        xout, y_target = self.augment_output(x, y_target)
        xnew = self.data_augmentation(x)

        x3 = self.airfoil_encoder(xnew)
        z = self.latent_encoder(x3, ma, aoa)
        y = self.airfoil_predictor(z, xout)

        loss = F.mse_loss(y, y_target) * 1e6
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Un-augmented evaluation at the original sample points."""
        qoip, ma, _, aoa = batch
        y = self(qoip[:, :6, :], ma, aoa)

        loss = F.mse_loss(y, qoip) * 1e6
        self.log('val_loss', loss)
        return loss

    def configure_optimizers(self):
        """Adam at ``self.lr``, cosine-annealed toward 1e-4 over T_max=2000 scheduler steps."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=2000, eta_min=0.0001)

        return [optimizer], [scheduler]

    def train_dataloader(self):
        """Shuffled loader over the module-level ``train_dataset`` with ``self.batch_size``."""
        return DataLoader(train_dataset, batch_size=self.batch_size, num_workers=0, shuffle=True, pin_memory=True)

    def val_dataloader(self):
        """Loader yielding the whole module-level ``test_dataset`` as a single batch."""
        return DataLoader(test_dataset, batch_size=len(test_dataset), num_workers=0, pin_memory=True)

In [None]:
from pytorch_lightning.callbacks import LearningRateMonitor

In [None]:
# Seed torch's RNGs (CPU and CUDA) before building the model so that weight
# initialisation, DataLoader shuffling and the random augmentations repeat across
# kernel restarts (up to nondeterministic CUDA kernels).
SEED = 0
torch.manual_seed(SEED)
model = AirfoilModel()

In [None]:

# Training hyperparameters: these override the defaults set in AirfoilModel.__init__
# and are read when trainer.fit() builds the dataloaders and the optimizer.
model.batch_size = 8
model.train_variance = 0.01
model.lr = 0.001

# Record the learning rate at every optimisation step.
lr_monitor = LearningRateMonitor(logging_interval='step')
trainer = Trainer(
    gpus=1,
    weights_summary='full',
    precision=16,
    check_val_every_n_epoch=2,
    max_epochs=10_000,
    limit_train_batches=0.5,
    auto_lr_find=False,
    callbacks=[lr_monitor],
)

In [None]:
# Launch training; Lightning obtains dataloaders and optimizer from the model's hooks.
trainer.fit(model=model)