In [678]:
#!g1.1
from pathlib import Path
# Root folder of the 2D "no_obs" simulation dataset (presumably obstacle-free scenes).
no_obs_path = Path('/home/jupyter/mnt/datasets/den2vel/datasets') / '2D_no_obs'

In [682]:
#!g1.1
import numpy as np
from matplotlib import pyplot as plt

In [683]:
#!g1.1

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from typing import Union, Callable, Optional, List
import os
from tqdm import tqdm
import json

In [684]:
#!g1.1
# Validation hold-out: simulations 1001, 1030 and 1058 from each of the v01..v04 groups.
VAL_SIM_IDS = (1001, 1030, 1058)
val_sims = [f'{no_obs_path}/v0{group}/sim_{sim_id}'
            for group in range(1, 5)
            for sim_id in VAL_SIM_IDS]

# Every other simulation directory is used for training. Match on the directory's own
# name (not a substring anywhere in the path) and sort, because os.walk's traversal
# order depends on the filesystem and would otherwise make the split order irreproducible.
_val_set = set(val_sims)
train_sims = sorted(
    root
    for root, _dirs, _files in tqdm(os.walk(no_obs_path))
    if os.path.basename(root).startswith('sim') and root not in _val_set
)

245it [00:00, 2099.77it/s]


In [685]:
#!g1.1
class NoObsDataset(torch.utils.data.Dataset):
    """Frame-level dataset of (density, velocity, scene parameters) read from simulation folders.

    Only the immediate files of each directory in ``sims_pth`` are scanned:
    files ending in ``npz`` whose name contains ``'density'`` are density frames,
    every other ``npz`` file is treated as a velocity frame, and a file whose name
    contains ``'json'`` provides the ``bnds`` and ``buoyFac`` scene parameters.
    Density and velocity paths are sorted independently, so the i-th density
    file is paired with the i-th velocity file.

    Args:
        sims_pth: simulation directories.
        transforms: optional callable applied separately to the density and the
            velocity arrays in ``__getitem__``.
    """

    def __init__(self, sims_pth: List[Union[str, Path]], transforms: Optional[Callable] = None):
        super().__init__()
        self._transforms = transforms
        self.sims_pth = sims_pth
        self.density = []
        self.velocity = []
        # Simulation directory -> float32 array [bnds, buoyFac]. Keys are built with
        # os.path.dirname so __getitem__ can derive the same key from any frame path.
        self.s_dict = {}

        for sim in tqdm(sims_pth):
            root, _dirs, files = next(os.walk(sim))
            for fname in files:
                fpath = f'{root}/{fname}'
                if fname[-3:] == 'npz':
                    if 'density' in fname:
                        self.density.append(fpath)
                    else:
                        self.velocity.append(fpath)
                elif 'json' in fname:
                    # Separate handle name: the original rebound the loop variable here.
                    with open(fpath, 'r') as fh:
                        loaded = json.load(fh)
                    self.s_dict[os.path.dirname(fpath)] = np.array(
                        [float(loaded['bnds']), float(loaded['buoyFac'])], dtype=np.float32)

        assert len(self.density) == len(self.velocity), (
            f'{len(self.density)} density files vs {len(self.velocity)} velocity files')

        self.density.sort()
        self.velocity.sort()

    @property
    def transforms(self):
        return self._transforms

    @transforms.setter
    def transforms(self, transforms: Callable):
        self._transforms = transforms

    def __len__(self):
        return len(self.density)

    def __getitem__(self, index: int):
        """Return ``(density, velocity, scene_params)`` for frame ``index``.

        The arrays are ``arr_0[0, ::-1]`` (velocity additionally drops the last
        entry of its trailing axis), passed through ``transforms`` when set;
        ``scene_params`` is a float32 tensor ``[bnds, buoyFac]``.
        """
        den_pth = self.density[index]
        # np.load on an .npz returns an NpzFile holding an open file handle; the
        # context manager closes it instead of leaking one handle per sample.
        # ascontiguousarray copies the reversed (negatively strided) view, which
        # is likely needed because torch cannot wrap negative strides.
        with np.load(den_pth) as data:
            den = np.ascontiguousarray(data['arr_0'][0, ::-1])
        with np.load(self.velocity[index]) as data:
            vel = np.ascontiguousarray(data['arr_0'][0, ::-1, :, :-1])
        # NOTE(review): axis 1 is reversed for both arrays; if that axis is a spatial
        # y axis, the sign of the matching velocity component is left unchanged — confirm intent.
        # Scene parameters are keyed by the frame's directory (no fixed path layout assumed).
        s = self.s_dict[os.path.dirname(den_pth)]

        if self._transforms is not None:
            den = self._transforms(den)
            vel = self._transforms(vel)
        return den, vel, torch.from_numpy(s)

In [686]:
#!g1.1
import torchvision
# Shared per-array pipeline: only ToTensor (HWC ndarray -> CHW tensor); wrapped in
# Compose so further transforms can be appended later.
transform = torchvision.transforms.Compose(
    transforms=[torchvision.transforms.ToTensor()],
)

In [687]:
#!g1.1
# One dataset per split, both using the shared ToTensor pipeline (train is built first).
no_obs_train_dataset, no_obs_val_dataset = (
    NoObsDataset(split_sims, transforms=transform) for split_sims in (train_sims, val_sims)
)

100%|██████████| 228/228 [00:00<00:00, 1273.49it/s]
100%|██████████| 12/12 [00:00<00:00, 1100.72it/s]


In [688]:
#!g1.1
BATCH_SIZE = 2
# drop_last=True keeps every training batch exactly BATCH_SIZE long; the training code
# reshapes tensors to (BATCH_SIZE, -1) and draws masks with BATCH_SIZE rows, so a short
# final batch would crash or be silently mis-shaped.
train_loader = torch.utils.data.DataLoader(no_obs_train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                                           num_workers=8, drop_last=True)
# Validation is consumed one sample at a time.
val_loader = torch.utils.data.DataLoader(no_obs_val_dataset, batch_size=1, shuffle=False, num_workers=8)

In [689]:
#!g1.1
from typing import Optional, Tuple


class ResBlock(nn.Module):
    """Residual block: two reflect-padded 3x3 conv + InstanceNorm stages added back to the input.

    The identity skip means ``in_c`` must equal ``out_c`` for ``forward`` to work.
    """

    def __init__(self, in_c: int, out_c: int):
        super().__init__()
        conv3x3 = dict(kernel_size=3, padding=1, padding_mode='reflect')
        self.conv_1 = nn.Sequential(
            nn.Conv2d(in_c, out_c, **conv3x3),
            nn.InstanceNorm2d(out_c),
            nn.ReLU(),
        )
        self.conv_2 = nn.Sequential(
            nn.Conv2d(out_c, out_c, **conv3x3),
            nn.InstanceNorm2d(out_c),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = self.conv_2(self.conv_1(x))
        return x + residual
    
    
class SubNetS(nn.Module):
    """Branch that predicts a 2-value scene vector and re-embeds it as a feature map.

    The layer sizes (8x pool, conv + 2x pool, Linear over 16 * 4 * 4) only fit a
    (B, in_c, 64, 64) input. ``forward`` returns a (B, 16, 64, 64) feature map and
    the (B, 2) prediction ``s_out``. An optional known ``s`` of shape (B, 2) is
    concatenated with the prediction; when omitted it is filled with -1.
    """

    def __init__(self, in_c: int):
        super().__init__()
        self.pool = nn.AvgPool2d(8, 8)
        # Honour in_c (previously a fixed 16, so any other input width failed).
        self.flatten_conv = nn.Sequential(nn.Conv2d(in_c, 16, 3, padding=1, padding_mode='reflect'),
                                          nn.AvgPool2d(2, 2))
        self.s_out = nn.Sequential(nn.Flatten(), nn.Linear(16 * 4 * 4, 2))
        self.s_in = nn.Linear(4, 4 * 8 * 8)
        self.conv_out = nn.Sequential(nn.Conv2d(4, 8, 3, padding=1, padding_mode='reflect'),
                                      nn.LeakyReLU(0.2),
                                      nn.Conv2d(8, 16, 3, padding=1, padding_mode='reflect'),
                                      nn.LeakyReLU(0.2))
        self.unpool = nn.Sequential(nn.Upsample((64, 64), mode='bilinear', align_corners=False),
                                    nn.Conv2d(16, 16, 1))

    def forward(self, x: torch.Tensor, s: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        x = self.pool(x)
        x = self.flatten_conv(x)
        s_out = self.s_out(x)
        if s is None:
            # -1 marks "unknown"; the training loop masks s with the same value.
            # full_like does not require grad, so no explicit flag is needed.
            s = torch.full_like(s_out, -1.0)
        x = torch.cat([s_out, s], dim=1)
        x = self.s_in(x)
        x = x.reshape(-1, 4, 8, 8)
        x = self.conv_out(x)
        x = self.unpool(x)
        return x, s_out
            
        
class SubNetKE(nn.Module):
    """Branch predicting a 1-channel energy map at 1/4 of the input resolution.

    ``forward(x, ke=None)`` returns ``(features, ke_out)``. ``ke_out`` is predicted
    after two 2x pools; it is concatenated with the known ``ke`` (or a -1 fill when
    ``ke`` is None), passed through two 3x3 convs, and upsampled 4x into a 1-channel
    feature map. The training loop supplies ``energy(vel)`` as ``ke``.
    """

    def __init__(self, in_c: int):
        super().__init__()
        conv3x3 = dict(kernel_size=3, padding=1, padding_mode='reflect')
        self.conv_1 = nn.Sequential(nn.Conv2d(in_c, 8, **conv3x3),
                                    nn.AvgPool2d(2, 2),
                                    nn.InstanceNorm2d(8),
                                    nn.ReLU())
        self.conv_2 = nn.Sequential(nn.Conv2d(8, 4, **conv3x3),
                                    nn.AvgPool2d(2, 2),
                                    nn.InstanceNorm2d(4))
        self.ke_out_conv = nn.Conv2d(4, 1, 1)
        self.ke_in_conv = nn.Conv2d(2, 2, **conv3x3)
        self.conv_5 = nn.Conv2d(2, 2, **conv3x3)
        self.unpool = nn.Sequential(
            nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False),
            nn.Conv2d(2, 1, 1),
        )

    def forward(self, x: torch.Tensor, ke: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        encoded = self.conv_2(self.conv_1(x))
        ke_out = self.ke_out_conv(encoded)
        if ke is None:
            ke = torch.full_like(ke_out, -1.0)
        # Channel axis is 1 for batched input, 0 for an unbatched (C, H, W) input.
        channel_dim = 1 if encoded.ndim == 4 else 0
        mixed = self.conv_5(self.ke_in_conv(torch.cat([ke_out, ke], dim=channel_dim)))
        return self.unpool(mixed), ke_out
        
        
class SubNetW(nn.Module):
    """Branch that predicts a full-resolution 1-channel map ``w_out`` and re-encodes it.

    ``forward(x, w=None)``:
      * x (B, in_c, H, W) is upsampled 4x to (B, 16, 4H, 4W) and ``w_out`` (B, 1, 4H, 4W)
        is predicted from it;
      * ``w_out`` is concatenated with the known ``w`` (or a -1 fill when None) and
        encoded back down 4x to a (B, 32, H, W) feature map.
    The training loop supplies the finite-difference vorticity as ``w``.
    Returns ``(features, w_out)``.
    """

    def __init__(self, in_c: int):
        super().__init__()
        self.up1 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                                 nn.Conv2d(in_c, 32, 3, padding=1, padding_mode='reflect'),
                                 nn.InstanceNorm2d(32),
                                 nn.ReLU())
        self.up2 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                                 nn.Conv2d(32, 16, 3, padding=1, padding_mode='reflect'),
                                 nn.InstanceNorm2d(16),
                                 nn.ReLU())
        self.w_out = nn.Conv2d(16, 1, 7, padding='same')
        # padding=2 makes the 5x5 convs size-preserving, so the two 2x pools land exactly
        # back on the input resolution (padding=1 shrank 64 -> 62 and needed a later resize).
        self.w_in = nn.Sequential(nn.Conv2d(2, 16, 5, padding=2, padding_mode='reflect'),
                                  nn.AvgPool2d(2, 2),
                                  nn.InstanceNorm2d(16),
                                  nn.ReLU())

        self.down = nn.Sequential(nn.Conv2d(16, 24, 5, padding=2, padding_mode='reflect'),
                                  nn.AvgPool2d(2, 2),
                                  nn.InstanceNorm2d(24),
                                  nn.ReLU())

        self.conv_out = nn.Conv2d(24, 32, 3, padding=1, padding_mode='reflect')

    def forward(self, x: torch.Tensor, w: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        x = self.up1(x)
        x = self.up2(x)
        w_out = self.w_out(x)
        if w is None:
            # -1 marks "unknown"; the training loop masks w with the same value.
            w = torch.full_like(w_out, -1.0)
        # Channel axis is 1 for batched input, 0 for an unbatched (C, H, W) input.
        x = torch.cat([w_out, w], dim=1 if x.ndim == 4 else 0)
        x = self.w_in(x)
        x = self.down(x)
        x = self.conv_out(x)
        return x, w_out
    
    
    
class UNet(nn.Module):
    """Encoder / residual bottleneck / decoder network with three auxiliary branches.

    The encoder halves the resolution twice (down_2, down_3). At the bottleneck the
    s- and ke-subnets run on a 16-channel projection of the features; their outputs are
    interpolated to 64x64 and concatenated with the 128 bottleneck channels
    (16 + 1 + 128 = 145, hence the 145-channel ResBlocks). The w-subnet output
    (32 channels in the SubNetW of this file) is concatenated before decoding
    (32 + 145 = 177 input channels for up_1). The 64x64 concatenations and the fixed
    Upsample sizes (128x128, 256x256) tie the network to 256x256 inputs.
    """
    def __init__(self, in_c: int, out_c: int, s_subnet: nn.Module, 
                 w_subnet: nn.Module, ke_subnet: nn.Module, obstacles: bool = False):
        """
        Args:
            in_c: input channels (1 for the density generator, 2 for the velocity discriminator here).
            out_c: output channels of the final conv.
            s_subnet / w_subnet / ke_subnet: auxiliary branch modules; each must return
                ``(features, prediction)`` with feature channel counts matching the
                hard-coded 145/177 widths below.
            obstacles: when True adds one extra input channel (presumably an obstacle
                mask that the caller stacks into ``x`` — forward has no separate argument; TODO confirm).
        """
        super().__init__()
        self.obstacles = obstacles
        # NOTE(review): the second conv below uses default zero padding, unlike the
        # reflect padding used everywhere else — confirm whether that is intended.
        self.down_1 = nn.Sequential(nn.Conv2d(in_c + obstacles, 16, 3, padding='same', padding_mode='reflect'),
                                    nn.InstanceNorm2d(16),
                                    nn.ReLU(),
                                    nn.Conv2d(16, 32, 3, padding='same'),
                                    nn.InstanceNorm2d(32),
                                    nn.ReLU())
        self.down_2 = nn.Sequential(nn.Conv2d(32, 64, 3, padding='same', padding_mode='reflect'),
                                    nn.AvgPool2d(2, 2),
                                    nn.InstanceNorm2d(64),
                                    nn.ReLU())
        
        self.down_3 = nn.Sequential(nn.Conv2d(64, 128, 3, padding='same', padding_mode='reflect'),
                                    nn.AvgPool2d(2, 2),
                                    nn.InstanceNorm2d(128),
                                    nn.ReLU())
        self.resblock_1 = ResBlock(128, 128)
        self.resblock_2 = ResBlock(128, 128)
        self.resblock_3 = ResBlock(128, 128)
        # 145 = 16 (s-branch features) + 1 (ke-branch features) + 128 (bottleneck).
        self.resblock_4 = ResBlock(145, 145)
        self.resblock_5 = ResBlock(145, 145)
        self.resblock_6 = ResBlock(145, 145)
        
        # 16-channel projection fed to the s- and ke-subnets.
        self.subnet_conv = nn.Sequential(nn.Conv2d(128, 16, 3, padding='same', padding_mode='reflect'), 
                                         nn.ReLU())
        
        # 177 = 32 (w-branch features) + 145.
        self.up_1 = nn.Sequential(nn.Conv2d(177, 128, 3, padding='same', padding_mode='reflect'),
                                  nn.InstanceNorm2d(128),
                                  nn.ReLU(),
                                  nn.Conv2d(128, 128, 3, padding='same', padding_mode='reflect'),
                                  nn.InstanceNorm2d(128),
                                  nn.ReLU(),
                                  nn.Upsample((128, 128), mode='bilinear', align_corners=False))
        self.up_2 = nn.Sequential(nn.Conv2d(128, 64, 3, padding='same', padding_mode='reflect'),
                                  nn.InstanceNorm2d(64),
                                  nn.ReLU(),
                                  nn.Upsample((256, 256), mode='bilinear', align_corners=False))
            
        self.up_3 = nn.Sequential(nn.Conv2d(64, 32, 3, padding='same', padding_mode='reflect'),
                                  nn.ReLU())
                        
        self.conv_out = nn.Conv2d(32, out_c, 3, padding='same', padding_mode='reflect')
        
        self.s_subnet = s_subnet
        self.ke_subnet = ke_subnet
        self.w_subnet = w_subnet
        
    def forward(self, x: torch.Tensor, s: Optional[torch.Tensor] = None, 
                w: Optional[torch.Tensor] = None, ke: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, 
                                                                                              torch.Tensor, torch.Tensor]:
        """Run the network on x (B, in_c + obstacles, 256, 256).

        ``s``, ``w`` and ``ke`` are optional known values passed straight to the
        matching subnet (the subnets in this file substitute -1 when given None).

        Returns:
            (out, s_out, w_out, ke_out): the (B, out_c, 256, 256) output and the three
            subnet predictions.
        """
        x = self.down_1(x)
        x = self.down_2(x)
        x = self.down_3(x)
        
        x = self.resblock_1(x) 
        x = self.resblock_2(x)
        x = self.resblock_3(x)
        
        y = self.subnet_conv(x)
        
        
        y_s, s_out = self.s_subnet(y, s) # features 16 x 64 x 64 (SubNetS upsamples to a fixed 64x64)
        y_ke, ke_out = self.ke_subnet(y, ke) # features 1 x 64 x 64 for a 64x64 bottleneck
        
        # Interpolation to 64x64 is a no-op for the shapes above but guards against other subnets.
        x = torch.cat([F.interpolate(y_s, size=(64, 64), mode='bilinear', align_corners=False), 
                       F.interpolate(y_ke, size=(64, 64), mode='bilinear', align_corners=False), x], dim=1)
        
        x = self.resblock_4(x)
        x = self.resblock_5(x)
        x = self.resblock_6(x)

        y, w_out = self.w_subnet(x, w)
        
        
        # Resize the w-branch features back onto the 64x64 bottleneck grid before concatenating.
        x = torch.cat([F.interpolate(y, size=(64, 64), mode='bilinear', align_corners=False), x], dim=1)
        
        x = self.up_1(x)
        x = self.up_2(x)
        x = self.up_3(x)
        
        x = self.conv_out(x)
        
        return x, s_out, w_out, ke_out

In [690]:
#!g1.1
from typing import Sequence
def flatten(*tensors: torch.Tensor, BATCH_SIZE: Optional[int] = None) -> torch.Tensor:
    """Reshape each tensor to (BATCH_SIZE, -1) and concatenate them along dim 1.

    Args:
        *tensors: tensors whose element counts are divisible by the batch size.
        BATCH_SIZE: number of rows of the result. When omitted, the leading
            dimension of the first tensor is used, which also handles a short
            final batch correctly.

    Raises:
        ValueError: if BATCH_SIZE is omitted and no tensors are given.
    """
    if BATCH_SIZE is None:
        if not tensors:
            raise ValueError('flatten() needs at least one tensor to infer BATCH_SIZE')
        BATCH_SIZE = tensors[0].shape[0]
    return torch.cat([t.reshape(BATCH_SIZE, -1) for t in tensors], dim=1)

In [691]:
#!g1.1
def curl(x):
    """Finite-difference 'curl' of a 4-D field batch ``x`` of shape (B, C, H, W).

    Returns ``cat([d/dH x, -d/dW x], dim=1)`` with shape (B, 2C, H, W): forward
    differences along axis 2 and (negated) along axis 3, each padded back to full
    size by repeating its last row / column.
    """
    d_rows = torch.diff(x, dim=2)
    d_cols = x[..., :-1] - x[..., 1:]
    d_rows = torch.cat([d_rows, d_rows[:, :, -1:]], dim=2)
    d_cols = torch.cat([d_cols, d_cols[..., -1:]], dim=3)
    return torch.cat([d_rows, d_cols], dim=1)

In [692]:
#!g1.1

from typing import Mapping

def train_step(G: nn.Module, D: nn.Module, d: torch.Tensor, u: torch.Tensor,
               s: torch.Tensor, w: torch.Tensor, ke: torch.Tensor,
               optG: torch.optim.Optimizer, optD: torch.optim.Optimizer,
               BATCH_SIZE: int, l_adv: float = 0.2, l_l1: float = 1.0) -> Mapping[str, float]:
    """One discriminator update followed by one generator update.

    G maps density ``d`` to a field whose ``curl`` is the predicted velocity, plus
    side predictions (s, w, ke). D maps a velocity (and optional s, w, ke) back to
    density and side outputs; its losses are MSEs against ``flatten(d, s, w, ke)``.
    D minimises ``loss(real) - k * loss(fake)`` with a fixed k; G minimises
    ``l_adv`` * D's loss on the fake sample plus ``l_l1`` * L1 between its own
    outputs and ``flatten(u, s, w, ke)``.

    Args:
        d, u: density and ground-truth velocity batches.
        s, w, ke: target side quantities (may already contain -1 "unknown" entries).
        optG, optD: optimizers for G and D.
        BATCH_SIZE: number of rows used when flattening targets and outputs.
        l_adv, l_l1: weights of the adversarial and L1 generator terms.

    Returns:
        dict with float values for 'G_loss', 'D_p_loss' and 'D_n_loss'.
    """
    # Conditioning inputs for G; currently empty, so G runs unconditioned.
    opt_inputs = []
    optG.zero_grad(True)
    g_out, g_s, g_w, g_ke = G(d, *opt_inputs)
    g_u = curl(g_out)

    # --- Discriminator update (fake inputs detached so G receives no gradient here).
    optD.zero_grad(True)
    dt_d, dt_s, dt_w, dt_ke = D(u)
    df_d, df_s, df_w, df_ke = D(g_u.detach(), g_s.detach(), g_w.detach(), g_ke.detach())

    s_r = s.reshape(BATCH_SIZE, -1)
    w_r = w.reshape(BATCH_SIZE, -1)
    ke_r = ke.reshape(BATCH_SIZE, -1)

    flatten_den = flatten(d, s_r, w_r, ke_r, BATCH_SIZE=BATCH_SIZE)
    flatten_dt_out = flatten(dt_d, dt_s, dt_w, dt_ke, BATCH_SIZE=BATCH_SIZE)
    flatten_df_out = flatten(df_d, df_s, df_w, df_ke, BATCH_SIZE=BATCH_SIZE)

    D_p_loss = F.mse_loss(flatten_dt_out, flatten_den)
    D_n_loss = F.mse_loss(flatten_df_out, flatten_den)

    # Fixed balance coefficient between the real and fake terms.
    k = 0.5
    D_loss = D_p_loss - k * D_n_loss
    D_loss.backward()
    optD.step()

    # --- Generator update.
    flatten_gd_out = flatten(*D(g_u, g_s, g_w, g_ke), BATCH_SIZE=BATCH_SIZE)
    flatten_vel = flatten(u, s_r, w_r, ke_r, BATCH_SIZE=BATCH_SIZE)

    # Reuse the first forward pass: G's parameters have not changed since it ran and
    # its layers are deterministic, so a second G(d) only doubled the compute.
    l1_out = flatten(g_u, g_s, g_w, g_ke, BATCH_SIZE=BATCH_SIZE)

    G_loss = l_adv * F.mse_loss(flatten_gd_out, flatten_den) + l_l1 * F.l1_loss(l1_out, flatten_vel)

    G_loss.backward()
    optG.step()

    return {'G_loss': G_loss.item(),
            'D_p_loss': D_p_loss.item(),
            'D_n_loss': D_n_loss.item()}
    
    

In [693]:
#!g1.1
from tqdm.notebook import tqdm

def energy(vel: torch.Tensor) -> torch.Tensor:
    """Channel-summed square of ``vel``, averaged over non-overlapping 16x16 tiles.

    For a (B, C, H, W) input the result is (B, 1, H // 16, W // 16).
    """
    squared_magnitude = vel.pow(2).sum(dim=1, keepdim=True)
    return F.avg_pool2d(squared_magnitude, kernel_size=16, stride=16)

def augment_energy(ke, den):
    """Perturb the energy map ``ke`` with tile-averaged Gaussian noise.

    Noise is drawn at the resolution of ``den``, divided by max(10, largest
    per-map std of ``ke``), and averaged over 16x16 tiles before being added.
    """
    per_map_std = ke.std((-2, -1), keepdim=True)
    noise_scale = max(10, per_map_std.max())
    noise = torch.randn_like(den) / noise_scale
    return ke + F.avg_pool2d(noise, 16, 16)

def augment_velocity(vel):
    """Add smooth, low-amplitude noise to a (B, C, H, W) velocity batch.

    Gaussian noise divided by max(10, std of ``vel``) is averaged over 16x16 tiles
    and bilinearly upsampled back to the spatial size of ``vel`` (previously a fixed
    256x256, which only matched 256x256 inputs).
    """
    eps = torch.randn_like(vel) / max(10, vel.std())
    coarse = F.avg_pool2d(eps, 16, 16)
    return vel + F.interpolate(coarse, size=tuple(vel.shape[-2:]), mode='bilinear', align_corners=True)
    

def _vorticity(vel: torch.Tensor) -> torch.Tensor:
    """Finite-difference ``dv/dx - du/dy`` of a (B, 2+, H, W) velocity batch.

    Channel 1 is differenced along the last axis and channel 0 along axis 2; each
    difference is reflect-padded by one sample, giving a (B, 1, H, W) result.
    """
    vdx = F.pad((vel[:, 1, :, 1:] - vel[:, 1, :, :-1])[:, None], (0, 1, 0, 0), mode='reflect')
    udy = F.pad((vel[:, 0, 1:] - vel[:, 0, :-1])[:, None], (0, 0, 0, 1), mode='reflect')
    return vdx - udy


def _random_mask(batch: int, ndim: int, device) -> torch.Tensor:
    """Per-sample Bernoulli(0.5) 0/1 tensor of shape (batch, 1, ..., 1) with ``ndim`` dims."""
    return torch.tensor(np.random.binomial(1, 0.5, size=(batch,) + (1,) * (ndim - 1)), device=device)


def _mask_with_sentinel(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Replace the masked samples of ``x`` with -1 (the "unknown" value the subnets default to)."""
    return x * (1.0 - mask) - mask * torch.ones_like(x)


def train(G, D, optG, optD, train_loader, val_loader, num_epochs, device, BATCH_SIZE):
    """Adversarial training loop with a quick qualitative validation pass per epoch.

    For every training batch, the energy map ``ke`` and vorticity ``w`` are derived
    from the ground-truth velocity; ``ke``, ``w`` and ``s`` are each replaced by -1
    for an independent random half of the samples before ``train_step`` runs. After
    each epoch the mean losses are printed, and G is run on ``val_loader``; about 5%
    of validation samples get their predicted squared speed plotted.

    Args:
        G, D: generator and discriminator modules.
        optG, optD: their optimizers.
        train_loader, val_loader: loaders yielding ``(density, velocity, s)``.
        num_epochs: number of epochs.
        device: device the batches are moved to.
        BATCH_SIZE: nominal batch size, kept for interface compatibility; each
            batch's actual size is used so a short final batch is handled.

    Returns:
        dict mapping 'G_loss', 'D_p_loss' and 'D_n_loss' to lists of per-epoch means.
    """
    losses = {'G_loss': [],
              'D_p_loss': [],
              'D_n_loss': []}
    for epoch in range(num_epochs):
        for name in losses:
            losses[name].append([])
        G.train()
        D.train()
        for den, vel, s in tqdm(train_loader):
            den, vel, s = den.to(device), vel.to(device), s.to(device)
            batch = den.shape[0]
            ke = energy(vel)
            w = _vorticity(vel)

            # Draw order (ke, w, s) matches the original RNG consumption.
            ke_mask = _random_mask(batch, 4, device)
            w_mask = _random_mask(batch, 4, device)
            s_mask = _random_mask(batch, 2, device)

            ke = _mask_with_sentinel(ke, ke_mask)
            w = _mask_with_sentinel(w, w_mask)
            s = _mask_with_sentinel(s, s_mask)

            step_losses = train_step(G, D, den, vel, s, w, ke, optG, optD, batch)

            for name, value in step_losses.items():
                losses[name][-1].append(value)
        for name in losses:
            losses[name][-1] = np.mean(losses[name][-1])

        print(f'Epoch:{epoch}')

        # Iterate `losses` (not the last step's dict, which is unbound for an empty loader).
        for name in losses:
            print(f'{name}: {losses[name][-1]}')

        G.eval()
        with torch.no_grad():
            for den, _vel, _s in tqdm(val_loader):
                g_out, *_ = G(den.to(device))
                u = curl(g_out).cpu().numpy()
                if np.random.binomial(1, p=0.05):
                    # u is a NumPy array here: square with ** (ndarray has no .pow()).
                    plt.figure()
                    plt.imshow((u ** 2).sum(1)[0])
                    plt.title(f'Epoch {epoch}: predicted squared speed')
                    plt.show()
    return losses

In [694]:
#!g1.1
# Prefer the GPU; fall back to CPU so the notebook still runs (slowly) without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [695]:
#!g1.1
# Generator: 1-channel density in, 1-channel field out (train_step turns it into velocity via curl).
# Subnet widths: s/ke branches read the 16-channel bottleneck projection; the w branch reads
# the 145-channel concatenation inside UNet.
G_S_net = SubNetS(16)
G_W_net = SubNetW(145)
G_KE_net = SubNetKE(16)
G = UNet(in_c=1, out_c=1, s_subnet=G_S_net, w_subnet=G_W_net, ke_subnet=G_KE_net)

# Discriminator: 2-channel velocity in, 1-channel density reconstruction out.
D_S_net = SubNetS(16)
D_W_net = SubNetW(145)
D_KE_net = SubNetKE(16)
D = UNet(in_c=2, out_c=1, s_subnet=D_S_net, w_subnet=D_W_net, ke_subnet=D_KE_net)

for net in (G, D):
    net.to(device)
print('G and D initialized')

G and D initialized


In [696]:
#!g1.1
# Identical Adam settings (lr=2e-4, default betas) for generator and discriminator.
optG, optD = (torch.optim.Adam(net.parameters(), lr=2e-4) for net in (G, D))

In [None]:
#!g1.1
train(G, D, optG, optD, train_loader, val_loader,
      num_epochs=100, device=device, BATCH_SIZE=BATCH_SIZE)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=22800.0), HTML(value='')))