In [None]:
import torch
import torch.nn as nn
import time
import torch.nn.functional as F

In [20]:

def count_parameters(model):
    """Return the total number of scalar values in ``model``'s trainable parameters.

    Parameters with ``requires_grad=False`` (e.g. frozen layers) are not counted.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total

def measure_inference_times(model, dataloader, device='cpu', num_warmup=5):
    """Time the forward pass of ``model`` on every batch of ``dataloader``.

    The model is switched to eval mode and moved to ``device``, where it is left.
    Each batch is unpacked as ``x, y = batch``. Only ``x`` is moved to ``device``
    and fed to the model. Before timing, the first batch is run ``num_warmup``
    times (untimed) to absorb one-off start-up costs.

    Args:
        model: the ``nn.Module`` to time.
        dataloader: re-iterable collection of ``(x, y)`` pairs, e.g. a
            ``torch.utils.data.DataLoader`` or a list. It is iterated once for
            warm-up and once for timing.
        device: anything accepted by ``torch.device`` (``'cpu'``, ``'cuda:0'``, ...).
        num_warmup: number of untimed forward passes on the first batch.

    Returns:
        list[float]: seconds spent in the forward pass for each batch, in
        iteration order. Host-to-device copies are excluded. On CUDA the device
        is synchronized around the call, so kernel execution time is included.
    """
    # tqdm is only cosmetic here; import it locally so the function works on a
    # fresh kernel, and fall back to a plain iterable if it is not installed.
    try:
        from tqdm.auto import tqdm as progress
    except ImportError:
        def progress(iterable, **_kwargs):
            return iterable

    device = torch.device(device)
    is_cuda = device.type == 'cuda'

    def _sync():
        # CUDA kernels launch asynchronously; without a sync the timer would
        # only capture launch overhead instead of the actual computation.
        if is_cuda:
            torch.cuda.synchronize(device)

    model.eval()
    model.to(device)
    batch_times = []

    with torch.no_grad():
        # Warm up on the first batch using a single iterator, so DataLoader
        # workers are not restarted once per warm-up pass.
        first_batch = next(iter(dataloader), None)
        if first_batch is not None:
            x, _ = first_batch
            x = x.to(device)
            for _ in range(num_warmup):
                model(x)

        for batch in progress(dataloader, desc="Measuring Inference Time"):
            x, _ = batch
            x = x.to(device)
            _sync()  # ensure the copy and any earlier queued work have finished
            start = time.perf_counter()
            model(x)
            _sync()
            batch_times.append(time.perf_counter() - start)

    return batch_times


In [21]:
# Use the first GPU when CUDA is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
print(f"Using device: {device}")


Using device: cuda:0


In [42]:
class PeriodicPadding(nn.Module):
    """Wrap-around ("circular") padding of the last two dimensions of a tensor.

    ``pad_width`` elements are added on each side of the last dimension and
    ``pad_height`` on each side of the second-to-last one. The values are copied
    from the opposite edge (``F.pad(..., mode='circular')``).
    """

    def __init__(self, pad_width, pad_height):
        super().__init__()
        self.pad_width = pad_width
        self.pad_height = pad_height

    def forward(self, x):
        # F.pad lists pads starting from the last dim: (left, right, top, bottom).
        horizontal = (self.pad_width, self.pad_width)
        vertical = (self.pad_height, self.pad_height)
        return F.pad(x, horizontal + vertical, mode='circular')

class ResidualBlock(nn.Module):
    """Two circularly padded 3x3 conv + BatchNorm layers with a residual shortcut.

    Stage 1 is pad -> conv1 -> bn1 -> LeakyReLU(0.3) -> dropout. Stage 2 is
    pad -> conv2 -> bn2, then the shortcut branch is added, followed by
    LeakyReLU -> dropout. The same LeakyReLU (in-place) and Dropout modules are
    used in both places. With the default pads of 1 and unpadded 3x3 convs, the
    spatial size is unchanged.

    Args:
        in_channels: channels of the input tensor.
        hidden_dim: channels produced by both convolutions (and by the block).
        dropout: probability for the shared ``nn.Dropout``.
        pad_width: circular padding on each side of the last dimension.
        pad_height: circular padding on each side of the second-to-last dimension.
    """

    def __init__(self, in_channels, hidden_dim, dropout, pad_width=1, pad_height=1):
        super().__init__()
        # Submodules are created in the original order so parameter init (under
        # a fixed seed) and state_dict keys stay the same.
        self.padding = PeriodicPadding(pad_width, pad_height)
        self.conv1 = nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(hidden_dim)
        self.relu = nn.LeakyReLU(negative_slope=0.3, inplace=True)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(hidden_dim)
        if in_channels == hidden_dim:
            self.shortcut = nn.Identity()
        else:
            # 1x1 projection so the shortcut matches the main branch's channels.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, hidden_dim, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(hidden_dim)
            )

    def forward(self, x):
        # Stage 1: pad -> conv -> BN -> activation -> dropout.
        h = self.dropout(self.relu(self.bn1(self.conv1(self.padding(x)))))
        # Stage 2: pad -> conv -> BN, then merge with the shortcut branch.
        h = self.bn2(self.conv2(self.padding(h)))
        h += self.shortcut(x)
        return self.dropout(self.relu(h))


class SelfAttention(nn.Module):
    """Multi-head self-attention over every cell of a channels-last 2-D grid.

    The input of shape ``(b, lat, lon, f)`` is flattened to ``lat * lon``
    tokens of size ``f``. It is passed through ``nn.MultiheadAttention``
    (``batch_first=True``) and reshaped back to ``(b, lat, lon, f)``. ``f``
    must equal ``dim``, and ``dim`` must be divisible by ``num_heads``. No
    residual connection or normalisation is applied here.

    Args:
        dim: embedding size of each grid cell.
        num_heads: number of attention heads.
        dropout: dropout on the attention weights (inside ``nn.MultiheadAttention``).
    """

    def __init__(self, dim, num_heads=2, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True)

        # Xavier init for the packed Q/K/V projection and the output projection;
        # biases are zeroed.
        nn.init.xavier_uniform_(self.attn.in_proj_weight)
        if self.attn.in_proj_bias is not None:
            nn.init.zeros_(self.attn.in_proj_bias)
        nn.init.xavier_uniform_(self.attn.out_proj.weight)
        if self.attn.out_proj.bias is not None:
            nn.init.zeros_(self.attn.out_proj.bias)

    def forward(self, x):
        b, lat, lon, f = x.shape
        # reshape (not view) so non-contiguous inputs, e.g. permuted tensors, work.
        tokens = x.reshape(b, lat * lon, f)
        # The attention map is discarded, so skip computing and averaging it.
        out, _ = self.attn(tokens, tokens, tokens, need_weights=False)
        return out.reshape(b, lat, lon, f)

class UNet(nn.Module):
    """Residual U-Net with circular padding.

    The input is expected channels-last, ``(B, H, W, in_channels)``, and is
    permuted to channels-first internally. The output is channels-first,
    ``(B, final_channels, lat, lon)``; it is not permuted back.

    Args:
        in_channels: features on the last axis of the input. This is also the
            width produced by ``final_conv``.
        hidden_dim: base channel width; encoder level ``i`` has
            ``hidden_dim * channel_mults[i]`` channels.
        channel_mults: per-level channel multipliers, mirrored by the decoder.
        num_blocks: ``ResidualBlock``s per encoder/decoder level and in the bottleneck.
        dropout: dropout probability passed to every ``ResidualBlock``.
        num_heads: currently unused; kept so existing constructor calls still work.
        final_channels: output channels of ``pred_conv``.
        lat: output height the decoder result is bilinearly resized to.
        lon: output width the decoder result is bilinearly resized to.
    """

    def __init__(self, in_channels=141, hidden_dim=128, channel_mults=(1, 2, 2),
                 num_blocks=2, dropout=0.1, num_heads=3, final_channels=1, lat=32, lon=64):
        super().__init__()
        self.in_channels = in_channels
        # NOTE(review): padding 3 followed by an unpadded 3x3 conv grows each
        # spatial dim by 4 (32x64 -> 36x68 with the defaults). The encoder skips
        # then no longer line up with the upsampled decoder features and are
        # resized in forward(). Confirm this is intended.
        self.padding = PeriodicPadding(pad_width=3, pad_height=3)
        self.lat = lat
        self.lon = lon
        self.initial_conv = nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=1, padding=0, bias=False)

        # Encoder: one residual stage per multiplier. Submodules are built in
        # the original order so seeded initialisation and state_dict keys are unchanged.
        self.down_blocks = nn.ModuleList()
        current_channels = hidden_dim
        for mult in channel_mults:
            out_channels = hidden_dim * mult
            self.down_blocks.append(self._make_layer(current_channels, out_channels, num_blocks, dropout))
            current_channels = out_channels

        self.bottleneck = self._make_layer(current_channels, current_channels, num_blocks, dropout)

        # Decoder mirrors the encoder. Each stage consumes the upsampled
        # features concatenated with the matching skip.
        self.up_blocks = nn.ModuleList()
        for mult in reversed(channel_mults):
            out_channels = hidden_dim * mult
            self.up_blocks.append(self._make_layer(current_channels + out_channels, out_channels, num_blocks, dropout))
            current_channels = out_channels

        self.final_conv = nn.Conv2d(current_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.pred_conv = nn.Conv2d(in_channels, final_channels, kernel_size=1, stride=1, padding=0)

    def _make_layer(self, in_channels, out_channels, num_blocks, dropout):
        """Stack ``num_blocks`` ResidualBlocks; only the first changes the channel count."""
        layers = [ResidualBlock(in_channels, out_channels, dropout)]
        for _ in range(num_blocks - 1):
            layers.append(ResidualBlock(out_channels, out_channels, dropout))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map a ``(B, H, W, in_channels)`` tensor to ``(B, final_channels, lat, lon)``."""
        x = x.permute(0, 3, 1, 2)  # (B, H, W, C) -> (B, C, H, W)
        x = self.initial_conv(self.padding(x))

        skips = []
        for down in self.down_blocks:
            x = down(x)
            skips.append(x)
            x = F.avg_pool2d(x, kernel_size=2, stride=2)

        x = self.bottleneck(x)

        for up, skip in zip(self.up_blocks, reversed(skips)):
            x = F.interpolate(x, scale_factor=2, mode="nearest")
            # Pooling floors odd sizes, so 2x upsampling can undershoot the
            # skip's size; resize the skip to match before concatenating.
            if skip.shape[2:] != x.shape[2:]:
                skip = F.interpolate(skip, size=x.shape[2:], mode="nearest")
            x = up(torch.cat([x, skip], dim=1))

        x = self.final_conv(x)
        x = F.interpolate(x, size=(self.lat, self.lon), mode="bilinear", align_corners=False)
        return self.pred_conv(x)


In [43]:
# Build the model with its constructor defaults; `name` labels it for later cells.
model = UNet()
name = "UNet"

In [33]:
# Trainable-parameter count of `model`; left as the last expression so the notebook displays it.
count_parameters(model)

14694683