In [1]:
import os
import fnmatch
import h5py
import pandas as pd
import torch
import numpy as np
import pickle

In [None]:
#fn = os.path.join("F:\\Uni\\Masterarbeit\\Daten", 'T007.h5')
#db = h5py.File(fn, 'r')

In [None]:
path = "F:\\Uni\\Masterarbeit\\Daten\\CollectedData\\T007"

AOA = pickle.load(open(os.path.join(path, "AOA_mat.p"), "rb"))
DVS = pickle.load(open(os.path.join(path, "dv_list.p"), "rb"))
MA = pickle.load(open(os.path.join(path, "Ma_mat.p"), "rb"))

In [None]:
def get_qois(surface_flow, Ma=.729, AOA=2.31, Re=6.5e6, T=288.15):
    x = surface_flow["x"][:]
    y = surface_flow["y"][:]
    xp = np.hstack((x[-1:], x))
    yp = np.hstack((y[-1:], y))
    dx = xp[1:] - xp[:-1]
    dy = yp[1:] - yp[:-1]
    ds = np.sqrt(np.power(dx, 2.) + np.power(dy, 2.))
    dxy = np.vstack((dx,dy))
    rot = np.array([[0, 1], [-1, 0]])
    nxy = rot @ dxy
    nxy[0,:] = nxy[0,:] / ds
    nxy[1,:] = nxy[1,:] / ds
    s = ds.cumsum()
    #sp = np.hstack((np.array([0.]), s))
    #spn = sp / sp[-1]
    sn = s / s[-1]
    
    Rs = 287.1
    l_ref = 1
    mu_ref = 1.716e-5
    T_ref = 273.15
    S = 110.4
    gamma = 1.4
    
    V = Ma*np.sqrt(gamma*Rs*T)
    mu = mu_ref*np.power(T/T_ref, 1.5)*(T_ref+S)/(T+S)
    rho = Re*mu/(V*l_ref)
    q = 0.5*rho*V*V
    p = rho*Rs*T
    
    #print(sp.shape)
    
    qoi = np.vstack((np.zeros(192), np.zeros(192), np.zeros(192), np.zeros(192), surface_flow["x"], surface_flow["y"], surface_flow["Density"]/rho, surface_flow["Energy"]/(q+p), surface_flow["Pressure_Coefficient"], surface_flow["Skin_Friction_Coefficient_x"], surface_flow["Skin_Friction_Coefficient_y"]))
    qoip = np.hstack((qoi[:, -1:], qoi))
    qoip[0,1:] = s
    qoip[1,1:] = sn
    #qoip[0,0] = 0.
    #qoip[1,0] = 0.
    qoip[2,1:] = nxy[0,:]
    qoip[2,:-1] += nxy[0,:]
    qoip[2,1:-1] *= 0.5
    qoip[3,1:] = nxy[1,:]
    qoip[3,:-1] += nxy[1,:]
    qoip[3,1:-1] *= 0.5
    qoip[5] *= 10.
    
    
    return qoip


data = []
fn = "T007.h5"
with h5py.File(fn, "r") as db:
    #print(db['DESIGNS'].keys())
    i = 0
    for dsn_name, dsn in db['DESIGNS'].items():
        print (dsn_name, dsn)
        #print(dsn.keys(), dsn["ADJOINT_DRAG/surface_adjoint.csv"].keys())
        print("DIRECT/surface_flow.csv" in dsn, "ADJOINT_DRAG/surface_adjoint.csv" in dsn)
        if "DIRECT/surface_flow.csv" in dsn and "ADJOINT_DRAG/surface_adjoint.csv" in dsn:
            surface_flow = dsn["DIRECT/surface_flow.csv"]        
            qoip = get_qois(surface_flow)

            #print(dsn.keys())
            adj_surface_flow = dsn["ADJOINT_DRAG/surface_adjoint.csv"]
            print(adj_surface_flow.keys())
            ### muss noch zu qoip hinzugefügt werden

            if qoip.shape == (11,193):
                data.append(qoip)
                if not int(dsn_name[-4:]) == i:
                    print("Mismtch", i, dsn_name)
                
                print(True)
                i += 1
        


#dataset = torch.utils.data.TensorDataset(torch.from_numpy(np.stack(data)))

In [None]:
N = len(data)

aoa = torch.from_numpy(AOA[:N])
dvs = torch.from_numpy(np.stack(DVS[:N]))
ma = torch.from_numpy(MA[:N])

dataset = torch.utils.data.TensorDataset(torch.from_numpy(np.stack(data)).float(), ma.float(), dvs.float(), aoa.float())

In [None]:
dataset[0]

In [None]:
"12.011614".split(", ")

In [None]:
with h5py.File(fn, "r") as db:
    dsn = db["DESIGNS/DSN_0023"]
    surface_flow = dsn["DIRECT/surface_flow.csv"]
    adj_surface_flow = dsn["ADJOINT_DRAG/surface_adjoint.csv"]
    for k in adj_surface_flow.keys():
        print(k, adj_surface_flow[k][:])

In [None]:
len(dataset)

In [None]:
dataset[0]

In [None]:
np.stack(data).shape

In [3]:
import torch
import torch.nn as nn
from torch.nn import functional as F
#from torch import nn
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.trainer.trainer import Trainer
from torch.utils.data import DataLoader

In [4]:
import matplotlib.pyplot as plt

In [5]:
### from https://stackoverflow.com/questions/61616810/how-to-do-cubic-spline-interpolation-and-integration-in-pytorch

#import torch

def h_poly(t):
    tt = t[None, :]**torch.arange(4, device=t.device)[:, None]
    A = torch.tensor([
        [1, 0, -3, 2],
        [0, 1, -2, 1],
        [0, 0, 3, -2],
        [0, 0, -1, 1]
    ], dtype=t.dtype, device=t.device)
    return A @ tt


def interp(x, y, xs):
    m = (y[1:] - y[:-1]) / (x[1:] - x[:-1])
    m = torch.cat([m[[0]], (m[1:] + m[:-1]) / 2, m[[-1]]])
    idxs = torch.searchsorted(x[1:], xs)
    dx = (x[idxs + 1] - x[idxs])
    hh = h_poly((xs - x[idxs]) / dx)
    return hh[0] * y[idxs] + hh[1] * m[idxs] * dx + hh[2] * y[idxs + 1] + hh[3] * m[idxs + 1] * dx

In [None]:
batch = dataset[:3]

qoip = batch[0]

x = qoip[:,:6]

s = x[:,1]
s.shape, qoip.shape

In [None]:
x = s
y = qoip[:,6:]
x.shape, y.shape

In [None]:
m = (y[:,:,1:] - y[:,:,:-1]) / (x[:,1:] - x[:,:-1])
m
#m = torch.cat([m[[0]], (m[1:] + m[:-1]) / 2, m[[-1]]])

In [None]:
x = torch.linspace(0, 6, 7)
y = x.sin()
xs = torch.linspace(0, 6, 101)

m = (y[1:] - y[:-1]) / (x[1:] - x[:-1])
m = torch.cat([m[[0]], (m[1:] + m[:-1]) / 2, m[[-1]]])
idxs = torch.searchsorted(x[1:], xs)
dx = (x[idxs + 1] - x[idxs])
hh = h_poly((xs - x[idxs]) / dx)

m, idxs, dx, hh

In [None]:
np.pi

In [None]:
torch.rand(10)

In [None]:
x = torch.linspace(0, 6, 7)
y = x.sin()
xs = torch.linspace(0, 6, 101)
ys = interp(x, y, xs)
#Ys = integ(x, y, xs)
plt.scatter(x, y, label='Samples', color='purple')
plt.plot(xs, ys, label='Interpolated curve')
plt.plot(xs, xs.sin(), '--', label='True Curve')
#plt.plot(xs, Ys, label='Spline Integral')
#plt.plot(xs, 1-xs.cos(), '--', label='True Integral')
plt.legend()

In [None]:
qoip

In [None]:
x = torch.from_numpy(qoip[1])
y = torch.from_numpy(qoip[5])
xs = torch.linspace(0, 1, 10001)
ys = interp(x, y, xs)


plt.figure(figsize=(12,4))
plt.scatter(x, y, label='Samples', color='purple')
plt.plot(xs, ys, label='Interpolated curve')
#plt.plot(xs, xs.sin(), '--', label='True Curve')
#plt.plot(xs, Ys, label='Spline Integral')
#plt.plot(xs, 1-xs.cos(), '--', label='True Integral')
plt.legend()



In [None]:
x.shape, y.shape

In [None]:
x = torch.linspace(0, 1, 101)
x

y = torch.exp(-2*(x-0.5)**2)
y = torch.sin(30*x/np.pi)

x2 = torch.rand(10)
y2 = torch.sin(30*x2/np.pi)
#ys2 = interp(x2,y2,x)

plt.plot(x,y)
plt.plot(x2,y2, 'o')
#plt.plot(x, ys2)



In [None]:
from prettytable import PrettyTable

def count_parameters(model):
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad: continue
        param = parameter.numel()
        table.add_row([name, param])
        total_params+=param
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params

In [None]:
train_size = int(0.85 * len(dataset))
test_size = len(dataset) - train_size

train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

In [8]:
class AirfoilModel(LightningModule):
    
    def __init__(self):
        super().__init__()
        
        self.batch_size = 128
        #self.hparams.batch_size = 64
        self.lr=1e-3
        
        self.train_variance = 1.0
        
        c1 = 16
        c2 = 8
        k2 = 8
        c3 = 32

        # mnist images are (1, 28, 28) (channels, width, height)
        self.layer_0 = nn.Conv1d(in_channels=6,out_channels=c1,kernel_size=5,stride=1, padding=2)

        self.layer_11 = nn.Conv1d(in_channels=c1,out_channels=c2,kernel_size=3,stride=1, padding=1)
        self.layer_12 = nn.Conv1d(in_channels=c2,out_channels=c2,kernel_size=3,stride=1, padding=1)
        self.layer_13 = nn.Conv1d(in_channels=c2,out_channels=c2,kernel_size=3,stride=1, padding=1)
        self.layer_14 = nn.Conv1d(in_channels=c2,out_channels=c2,kernel_size=3,stride=1, padding=1)
        self.relu_11 = nn.ReLU()
        self.relu_12 = nn.ReLU()
        self.relu_13 = nn.ReLU()
        self.relu_14 = nn.ReLU()

        self.layer_2 = nn.Conv1d(in_channels=8*4,out_channels=c3,kernel_size=5,stride=1, padding=2)
        self.relu_2 = nn.ReLU()
        
        encoder_layers = nn.TransformerEncoderLayer(d_model=32, nhead=8, dim_feedforward=64, dropout=0.001)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers=3)
        
        self.output = nn.Conv1d(in_channels=c3,out_channels=11,kernel_size=1, stride=1, padding=0)
        
        self.beta_dist = torch.distributions.beta.Beta(2,2)
        
        self.ma_aoa_lin1 = nn.Linear(2,16)
        self.ma_aoa_relu1 = nn.ReLU()
        self.ma_aoa_lin2 = nn.Linear(16,32)
        self.ma_aoa_relu2 = nn.ReLU()
        
        latenc_layers = nn.TransformerDecoderLayer(d_model=32, nhead=8, dim_feedforward=64, dropout=0.001)
        self.latenc_transformer = nn.TransformerDecoder(latenc_layers, num_layers=2)        
        
        decoder_layers = nn.TransformerDecoderLayer(d_model=32, nhead=8, dim_feedforward=64, dropout=0.001)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layers, num_layers=3)
        
    def data_augmentation(self, x):
        
        s = x[:,1]
        ds = s[:,1:] - s[:,:-1]
        
        rv = self.beta_dist.rsample([s.shape[0], s.shape[1]-2]).to(self.device)
        snew = s.clone()
        #print(snew.device, rv.device, ds.device)
        snew[:,1:-1] += 0.5 * ( (rv > 0)*ds[:,1:] - (rv < 0)* ds[:,:-1] )

        #xnew = interp(s, x, snew)
        
        xnew = torch.zeros_like(x, device=self.device)
        for i in range(x.shape[0]):
            for j in range(x.shape[1]):
                xnew[i,j] = interp(s[i], x[i,j], snew[i])
        
        indices = torch.rand(xnew.shape[1]) >= 0.25
        xnew = xnew[:,indices]
        
        return xnew
    
    def augment_output(self, x, y):
        
        s = x[:,1]
        ds = s[:,1:] - s[:,:-1]
        
        rv = self.beta_dist.rsample([s.shape[0], s.shape[1]-2]).to(self.device)
        snew = s.clone()
        #print(snew.device, rv.device, ds.device)
        snew[:,1:-1] += 0.5 * ( (rv > 0)*ds[:,1:] - (rv < 0)* ds[:,:-1] )

        #xnew = interp(s, x, snew)
        
        xnew = torch.zeros_like(x, device=self.device)
        for i in range(x.shape[0]):
            for j in range(x.shape[1]):
                xnew[i,j] = interp(s[i], x[i,j], snew[i])
                
        ynew = torch.zeros_like(y, device=self.device)
        for i in range(y.shape[0]):
            for j in range(y.shape[1]):
                ynew[i,j] = interp(s[i], y[i,j], snew[i])
        
        indices = torch.rand(xnew.shape[1]) >= 0.25
        xnew = xnew[:,indices]
        ynew = ynew[:,indices]
        
        return xnew, ynew
        
    def airfoil_encoder(self, x):
        
        #xxx = torch.cat((x, x, x), dim=2)
        #x_0 = self.layer_0(xxx)
        x_0 = self.layer_0(x)

        x_11 = self.relu_11(self.layer_11(x_0))
        x_12 = self.relu_12(self.layer_12(x_11))
        x_13 = self.relu_13(self.layer_13(x_12))
        x_14 = self.relu_14(self.layer_14(x_13))

        x_1 = torch.cat((x_11, x_12, x_13, x_14), dim=1)
        
        x_2 = self.relu_2(self.layer_2(x_1))
        
        ### Insert Transformer!
        
        x_3 = x_2
        
        return x_3
        
    def latent_encoder(self, x_3, ma, aoa):
        
        #ma = torch.ones(x_3.shape[0], 1,193).to(self.device)*ma.view(-1, 1, 1)
        #aoa = torch.ones(x_3.shape[0], 1, 193).to(self.device)*aoa.view(-1, 1, 1)
        
        #print("lat", x_3.device, ma.device, aoa.device)
        
        ma_aoa = torch.stack((ma,aoa)).to(self.device)
        ma_aoa1 = self.ma_aoa_relu1(self.ma_aoa_lin1(ma_aoa.T))
        ma_aoa2 = self.ma_aoa_relu2(self.ma_aoa_lin2(ma_aoa1)).view(-1,1,32)
        
        #relu(lin(cat((ma,aoa))))
        #x_4 = torch.cat((x_3, aoa, ma), dim=1)        
        
        ### Insert Transformer!
        
        #z = self.transformer_encoder(x_4.transpose(1,2))
        
        #print(ma_aoa2.shape, x_3.shape)
        #print(ma_aoa2.device, x_3.device)
        
        z = self.latenc_transformer(ma_aoa2, x_3.transpose(1,2))
        
        return z
    
    
    def airfoil_predictor(self, z, x_output):
        
        ### Insert Transformer!
        
        #y = self.output(z.transpose(1,2))      
        
        
        #x3_output = self.airfoil_encoder(x_output)
        
        #print(z.shape, x_output.shape, x3_output.shape)
        
        #y = self.transformer_decoder(x3_output, z.transpose(1,2))
        
        
        
        x3_output = self.airfoil_encoder(x_output).transpose(1,2).transpose(0,1)
        z = z.transpose(0,1)
        
        #print(z.shape, x_output.shape, x3_output.shape)

        y = self.transformer_decoder(x3_output, z).transpose(0,1).transpose(1,2)
        
        out = self.output(y)
        
        return out
    
    def forward(self, x, ma, aoa):
        
        xnew = self.data_augmentation(x)
        x3 = self.airfoil_encoder(xnew)
        z = self.latent_encoder(x3, ma, aoa)
        y = self.airfoil_predictor(z)
        
        return y
        
    def training_step(self, batch, batch_idx):
        qoip, ma, _, aoa = batch
        
        #print(qoip.device, ma.device, aoa.device)
        
        x = qoip[:,:6,:]
        y_target = qoip
        
        xout, y_target = self.augment_output(x.clone(), y_target.clone())
        
        xnew = self.data_augmentation(x.clone())
        x3 = self.airfoil_encoder(xnew)
        z = self.latent_encoder(x3, ma, aoa)
        #y = self.airfoil_predictor(z, self.data_augmentation(x))
        y = self.airfoil_predictor(z, xout)
        
        #y[:,5,:] = 10. * y[:,5,:]
        #y_target[:,5,:] = 10. * y_target[:,5,:]
        
        
        loss1 = F.mse_loss(y, y_target) * 1e6
        
        loss = loss1
        
        self.log('train_loss', loss1)
        
        return loss
        
    def validation_step(self, batch, batch_idx):
        qoip, ma, _, aoa = batch
        x = qoip[:,:6,:]
        y_target = qoip
        
        #print(qoip.device, ma.device, aoa.device)
        
        #xnew = self.data_augmentation(x)
        x3 = self.airfoil_encoder(x)
        z = self.latent_encoder(x3, ma, aoa)
        y = self.airfoil_predictor(z, x)
        
        #y[:,5,:] *= 10.
        #y_target[:,5,:] *= 10.
        
        loss1 = F.mse_loss(y, y_target) * 1e6
        
        loss = loss1
        
        self.log('val_loss', loss1)
        
        return loss
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=2000, eta_min=0.0001)

        return [optimizer], [scheduler]
    
    def train_dataloader(self):
        #print("get dataloader ", self.batch_size)
        return DataLoader(train_dataset, batch_size=self.batch_size, num_workers=0, shuffle=True, pin_memory=True)    
    
    def val_dataloader(self):
        return DataLoader(test_dataset, batch_size=len(test_dataset), num_workers=0, pin_memory=True)

In [None]:
model = AirfoilModel()
count_parameters(model)

In [None]:
qoip, ma, _, aoa = dataset[:3]
x = qoip[:,:6]

ma_aoa = torch.stack((ma, aoa))
w = torch.randn(8,2)

ma_aoa_lin = nn.Linear(2,8)
ma_aoa_relu = nn.ReLU()

ma_aoa = ma_aoa_lin(ma_aoa.T)

ma_aoa.shape

x3 = model.airfoil_encoder(x)

x3.shape, ma_aoa.shape

#torch.cat((x3, ma_aoa.view(-1,8,1)), dim=1)

In [None]:
batch = dataset[:3]

qoip, ma, _, aoa = batch

#print(qoip.device, ma.device, aoa.device)

x = qoip[:,:6,:]
y_target = qoip

xnew = model.data_augmentation(x)
x3 = model.airfoil_encoder(xnew)
z = model.latent_encoder(x3, ma, aoa)
print(z.shape)

#y = model.airfoil_predictor(z, x)

x_output = x

x3_output = model.airfoil_encoder(x_output).transpose(1,2).transpose(0,1)
z = z.transpose(0,1)
        
print(z.shape, x_output.shape, x3_output.shape)

y = model.transformer_decoder(x3_output, z).transpose(0,1) #.transpose(1,2)

y.shape

In [None]:
batch = dataset[:3]

model.training_step(batch,0)

In [None]:
model = AirfoilModel()

In [None]:
beta = torch.distributions.beta.Beta(2,2)
beta.rsample([2,32])

In [9]:
from pytorch_lightning.callbacks import LearningRateMonitor

In [10]:
model = AirfoilModel()

In [11]:

lr_monitor = LearningRateMonitor(logging_interval='step')
trainer = Trainer(gpus=1, weights_summary='full', precision=16, check_val_every_n_epoch=2, max_epochs=10_000,
                 limit_train_batches=0.5, auto_lr_find=False, callbacks=[lr_monitor]) #, auto_scale_batch_size=None
#train_loader = DataLoader(train_dataset, batch_size=1468, shuffle=True, pin_memory=True)
#val_loader = DataLoader(test_dataset, batch_size=64, shuffle=True, pin_memory=True)
#trainer.fit(model, train_loader)
#trainer.tune(model)
model.batch_size = 8
model.train_variance = 0.01 #1.0
model.lr = 0.001

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
Using native 16bit precision.


In [12]:
trainer.fit(model)


    | Name                                                 | Type                    | Params
---------------------------------------------------------------------------------------------------
0   | layer_0                                              | Conv1d                  | 496   
1   | layer_11                                             | Conv1d                  | 392   
2   | layer_12                                             | Conv1d                  | 200   
3   | layer_13                                             | Conv1d                  | 200   
4   | layer_14                                             | Conv1d                  | 200   
5   | relu_11                                              | ReLU                    | 0     
6   | relu_12                                              | ReLU                    | 0     
7   | relu_13                                              | ReLU                    | 0     
8   | relu_14                                        

RuntimeError: CUDA out of memory. Tried to allocate 670.00 MiB (GPU 0; 2.00 GiB total capacity; 696.28 MiB already allocated; 430.30 MiB free; 716.00 MiB reserved in total by PyTorch)

In [None]:
dataset[0]

In [None]:
model.device, dataset

In [None]:
model.cpu()

In [None]:
i = 21
qoip, ma, _, aoa = dataset[i]
qoip = qoip.view(-1,11,193)

x = qoip[:,:6,:]
y = model(x, ma, aoa)

In [None]:
y

In [None]:
x.grad_fn

In [None]:
model.eval()

In [None]:
model = model.eval()
model.cpu()

with torch.no_grad():

    i = 23
    qoip, ma, _, aoa = test_dataset[i]
    qoip = qoip.view(-1,11,193)
    print(ma, aoa)

    x = qoip[:,:6,:]
    #y = model(x, ma, aoa)


    x3 = model.airfoil_encoder(x)
    z = model.latent_encoder(x3, ma, aoa)
    y = model.airfoil_predictor(z, x)



    k = qoip.shape[1]
    y = y.detach()

    fig, ax = plt.subplots(k, figsize=(12,32))
    fields = ["s", "sn", "nx", "ny", "x", "y", "Density", "Energy", "Pressure_Coefficient", "Skin_Friction_Coefficient_x", "Skin_Friction_Coefficient_y"]

    for i in range(k):
        ax[i].plot(qoip[0,1], qoip[0,i], "o")
        ax[i].plot(qoip[0,1], y[0,i], "x")
        ax[i].set_title(fields[i])
        #print(i, fields[i], F.mse_loss(qoip[0,i], y[0,i])*1e6)
        print(i, fields[i], torch.pow(y[0,i] - qoip[0,i], 2).sum() / y.shape[2]*1e6)

    #y[0,5] = y[0,5]*10.
    #qoip[0,5] = qoip[0,5]*10.

    #print(F.mse_loss(y, qoip[:,:])*1e6)
    print(torch.pow(y[0,:] - qoip[0,:], 2).sum() / y.shape[1] / y.shape[2]*1e6)

    model.zero_grad()

In [None]:
plt.plot(qoip[0,4,:], qoip[0,5,:])
plt.plot(y[0,4,:], y[0,5,:])

In [None]:
torch.pow(y - qoip[:,:], 2).sum(), y.shape[2]

In [None]:
(y[0,2]*y[0,8]).sum()

In [None]:
(qoip[0,2]*qoip[0,8]).sum()

In [None]:
plt.plot(qoip[0,0],y[0,2]*y[0,8])
plt.plot(qoip[0,0],qoip[0,2]*qoip[0,8])

In [None]:
plt.plot(qoip[0,0],y[0,3]*y[0,8])
plt.plot(qoip[0,0],qoip[0,3]*qoip[0,8])

In [None]:
ma, aoa

In [None]:
plt.plot(MA, AOA, 'o')

In [None]:
#pickle.dump(dataset, open("save/T007_Transformer/full_dataset.p", "wb"))
pickle.dump(train_dataset, open("save/T007_Transformer/train_dataset.p", "wb"))
pickle.dump(test_dataset, open("save/T007_Transformer/test_dataset.p", "wb"))

In [None]:
pickle.dump(dataset, open("save/T007_Transformer/full_dataset.p", "wb"))


In [6]:
train_dataset = pickle.load(open("save/T007_Transformer/train_dataset.p", "rb"))
test_dataset = pickle.load(open("save/T007_Transformer/test_dataset.p", "rb"))

In [5]:
len(train_dataset), len(test_dataset)

(1905, 337)