In [5]:
from functools import reduce
from itertools import pairwise, accumulate

import torch
from torch import nn
from torch import Tensor
from torch.nn.functional import max_pool2d, interpolate

from config import CHANNELS_DIMENSION
# from datasets import PreprocessedOpenFWI

In [None]:
class ConvBlock(nn.Sequential):
    """3x3+1Padding Conv, BN, ReLu"""
    def __init__(self, in_channels:int, out_channels:int):
        super().__init__(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

class ResidualBlock(nn.Module):
    """2 Convulution block + residual"""
    def __init__(self, in_channels:int, out_channels:int):
        super().__init__()
        self.blocks = nn.Sequential(
            ConvBlock(in_channels, out_channels),
            ConvBlock(out_channels, out_channels),
        )
        if in_channels == out_channels:
            self.skip_connection = nn.Identity() 
        else:
            # May want to set bias to False ?
            self.skip_connection = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x:Tensor) -> Tensor:
        return self.skip_connection(x) + self.blocks(x)

def encode(x:Tensor, module:nn.Module) -> Tensor:
    return max_pool2d(module(x), 2)

def decode(prev_block_x:Tensor, skip_x:Tensor, module:nn.Module) -> Tensor:
    x_diff = skip_x.shape[2] - prev_block_x.shape[2]
    y_diff = skip_x.shape[3] - prev_block_x.shape[3]
    print("prev_block_x.shape:", prev_block_x.shape, "skip_x.shape:", skip_x.shape)
    # todo: Center padding?
    print("x_diff:", x_diff, "y_diff:", y_diff, "pad:", (x_diff, 0, y_diff, 0))
    padded_prev_block_x = nn.functional.pad(prev_block_x, (x_diff, 0, y_diff, 0))
    print("padded_prev_block_x.shape:", padded_prev_block_x.shape)
    
    x = torch.concatenate((padded_prev_block_x, skip_x), CHANNELS_DIMENSION)
    print("concat x.shape:", x.shape)
    out = module(x)
    return interpolate(out, scale_factor=2, mode="bilinear", align_corners=False)

class UNet(nn.Module):
    def __init__(self, in_channels:int, out_channels:int, start_features:int, depth:int):
        super().__init__()
        # Define the channels per depth of the Unet
        chs_per_depth = [start_features * 2 ** i for i in range(depth)]
        # Instantiate the Unet
        down_blocks_chns_it = pairwise([in_channels, *chs_per_depth])
        self.down_blocks = [ResidualBlock(in_chs, out_chs) for in_chs, out_chs in down_blocks_chns_it]
        self.down_blocks = nn.ModuleList(self.down_blocks)
        # Instantiate the bottle neck
        self.bottle_neck_block = ResidualBlock(chs_per_depth[-1], chs_per_depth[-1])
        # Instantialte the downblocks
        self.up_blocks = nn.ModuleList()
        for in_chs, out_chs in pairwise([*chs_per_depth[::-1], out_channels]):
            self.up_blocks.append(ResidualBlock(in_chs * 2, out_chs))
        # up_blocks_chans_it = pairwise([*chs_per_depth[::-1], out_channels])
        # self.up_blocks = [ResidualBlock(in_chs, out_chs) for in_chs, out_chs in up_blocks_chans_it]
        print(len(self.up_blocks))

    def forward(self, x:Tensor) -> Tensor:
        encoder_outputs = list(accumulate(self.down_blocks, encode, initial=x))
        for i, encoder_output in enumerate(encoder_outputs):
            print(i, encoder_output.shape)
        bottleneck_output = self.bottle_neck_block(encoder_outputs[-1])
        out = bottleneck_output
        print("up:", len(self.up_blocks))
        print("encoder_outputs:", len(encoder_outputs[::-1]))
        for up_block, encode_output in zip(self.up_blocks, encoder_outputs[::-1]):
            print("=======")
            print("out:", out.shape)
            print("encode_output:", encode_output.shape)
            out = decode(out, encode_output, up_block)
        return out


model = UNet(5, 1, 32, 4)#.cuda()

next(model.parameters()).device

4


device(type='cpu')

In [44]:
test_input = torch.randn(200, 5, 72, 72)#.cuda()

model(test_input).shape

0 torch.Size([200, 5, 72, 72])
1 torch.Size([200, 32, 36, 36])
2 torch.Size([200, 64, 18, 18])
3 torch.Size([200, 128, 9, 9])
4 torch.Size([200, 256, 4, 4])
up: 4
encoder_outputs: 5
out: torch.Size([200, 256, 4, 4])
encode_output: torch.Size([200, 256, 4, 4])
prev_block_x.shape: torch.Size([200, 256, 4, 4]) skip_x.shape: torch.Size([200, 256, 4, 4])
x_diff: 0 y_diff: 0 pad: (0, 0, 0, 0)
padded_prev_block_x.shape: torch.Size([200, 256, 4, 4])
concat x.shape: torch.Size([200, 512, 4, 4])
out: torch.Size([200, 128, 8, 8])
encode_output: torch.Size([200, 128, 9, 9])
prev_block_x.shape: torch.Size([200, 128, 8, 8]) skip_x.shape: torch.Size([200, 128, 9, 9])
x_diff: 1 y_diff: 1 pad: (1, 0, 1, 0)
padded_prev_block_x.shape: torch.Size([200, 128, 9, 9])
concat x.shape: torch.Size([200, 256, 9, 9])
out: torch.Size([200, 64, 18, 18])
encode_output: torch.Size([200, 64, 18, 18])
prev_block_x.shape: torch.Size([200, 64, 18, 18]) skip_x.shape: torch.Size([200, 64, 18, 18])
x_diff: 0 y_diff: 0 pad: (

torch.Size([200, 1, 72, 72])