In [1]:
from functools import reduce
from itertools import pairwise, accumulate

import torch
from torch import nn
from torch import Tensor
from torch.nn.functional import max_pool2d, interpolate

from config import CHANNELS_DIMENSION
from datasets import PreprocessedOpenFWI

In [19]:
class ConvBlock(nn.Sequential):
    """3x3+1Padding Conv, BN, ReLu"""
    def __init__(self, in_channels:int, out_channels:int):
        super().__init__(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

class ResidualBlock(nn.Module):
    """2 Convulution block + residual"""
    def __init__(self, in_channels:int, out_channels:int):
        super().__init__()
        self.blocks = nn.Sequential(
            ConvBlock(in_channels, out_channels),
            ConvBlock(out_channels, out_channels),
        )
        if in_channels == out_channels:
            self.skip_connection = nn.Identity() 
        else:
            # May want to set bias to False ?
            self.skip_connection = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x:Tensor) -> Tensor:
        return self.skip_connection(x) + self.blocks(x)

def encode(x:Tensor, module:nn.Module) -> Tensor:
    return max_pool2d(module(x), 2)

def decode(xs:tuple[Tensor, Tensor], module:nn.Module) -> Tensor:
    x = torch.concatenate(xs, CHANNELS_DIMENSION)
    out = module(x)
    return interpolate(out, scale_factor=2, mode="bilinear", align_corners=False)

class UNet(nn.Module):
    def __init__(self, in_channels:int, out_channels:int, start_features:int, depth:int):
        super().__init__()
        # Define the channels per depth of the Unet
        chs_per_depth = [start_features * 2 ** i for i in range(depth)]
        # Instantiate the Unet
        down_blocks_chns_it = pairwise([in_channels, *chs_per_depth])
        self.down_blocks = [ResidualBlock(in_chs, out_chs) for in_chs, out_chs in down_blocks_chns_it]
        # Instantiate the bottle neck
        self.bottle_neck_block = ResidualBlock(chs_per_depth[-1], chs_per_depth[-1])
        # Instantialte the downblocks
        up_blocks_chans_it = pairwise([*chs_per_depth[::-1], out_channels])
        self.up_blocks = [ResidualBlock(in_chs, out_chs) for in_chs, out_chs in up_blocks_chans_it]

    def forward(self, x:Tensor) -> Tensor:
        encoder_outputs = list(accumulate(self.down_blocks, encode, initial=x))
        bottleneck_output = self.bottle_neck_block(encoder_outputs[-1])
        out = bottleneck_output
        for up_block, encode_output in zip(self.up_blocks, encoder_outputs):
            print("out:", out.shape)
            print("out:", encode_output.shape)
            out = decode((out, encode_output), up_block)
        return out


model = UNet(5, 1, 32, 4)#.cuda()

next(model.parameters()).device

device(type='cpu')

In [20]:
test_input = torch.randn(1, 5, 72, 72)#.cuda()

model(test_input).shape

out: torch.Size([1, 256, 4, 4])
out: torch.Size([1, 5, 72, 72])


RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 4 but got size 72 for tensor number 1 in the list.