In [24]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

In [25]:
class PredictionDataset(Dataset):
    """Inference-only dataset built on a DataFrame of per-row input arrays.

    Every row must provide an array under the ``"s1"`` column and one under
    the ``"cs"`` column. No target is returned, only the stacked inputs.
    """

    # Column names in the order they become channels of the returned tensor.
    _CHANNELS = ("s1", "cs")

    def __init__(self, df_X):
        self.df_X = df_X

    def __len__(self):
        return len(self.df_X)

    def __getitem__(self, idx):
        """Return row ``idx`` as one float32 tensor, inputs stacked on dim 0.

        For 2-D (H, W) arrays the result has shape (2, H, W), with ``s1`` in
        channel 0 and ``cs`` in channel 1.
        """
        row = self.df_X.iloc[idx]
        channels = [
            torch.tensor(row[name], dtype=torch.float32)
            for name in self._CHANNELS
        ]
        return torch.stack(channels, dim=0)

In [26]:
# Preprocessing variant to run inference for. The same value selects the input
# directory here and is interpolated into the checkpoint path when the model
# weights are loaded.
normalization = "raw"
dataDir = f"../data/interim/normalization-{normalization}/"

# NOTE(review): unpickling can execute arbitrary code, so only point this at
# files produced by this project. dataDir already ends with "/", so the path
# below contains a doubled slash.
X_all = pd.read_pickle(f'{dataDir}/X_all.pkl')

# shuffle=False keeps batches in row order, which the later positional
# assignment of predictions back onto X_all relies on.
prediction_dataset = PredictionDataset(X_all)
prediction_loader = DataLoader(prediction_dataset, batch_size=8, shuffle=False)

In [27]:
# --- Conv Block ---
class ConvBlock(nn.Module):
    """Two 3x3 convolutions, each followed by optional BatchNorm and LeakyReLU.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by both convolutions.
        batchnorm: If True, insert ``BatchNorm2d`` after each convolution.
        dropout_rate: Channel-wise dropout probability applied to the block
            output. ``0`` (the default) disables dropout entirely.

    Padding of 1 with a 3x3 kernel keeps the spatial size unchanged.
    """

    def __init__(self, in_channels, out_channels, batchnorm=True, dropout_rate=0):
        super().__init__()

        layers = [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)]
        if batchnorm:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.LeakyReLU(inplace=True))

        layers.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))
        if batchnorm:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.LeakyReLU(inplace=True))

        self.conv = nn.Sequential(*layers)
        # Dropout has no parameters, so state_dict keys are the same whether
        # or not it is enabled.
        self.dropout = nn.Dropout2d(dropout_rate) if dropout_rate > 0 else nn.Identity()

    def forward(self, x):
        """Apply the convolution stack, then dropout (a no-op in eval mode)."""
        # The dropout module used to be constructed but never called, so
        # ``dropout_rate`` was silently ignored.
        return self.dropout(self.conv(x))
    

# --- Encoder Block ---
class EncoderBlock(nn.Module):
    """Encoder stage: a ``ConvBlock`` followed by 2x2 max pooling."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = ConvBlock(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Return ``(features, pooled)``.

        ``features`` is the full-resolution convolution output and ``pooled``
        is the same tensor after max pooling.
        """
        features = self.conv(x)
        return features, self.pool(features)
    

class UpConv(nn.Module):
    """Decoder upsampling step: 2x bilinear upsample, 3x3 conv, LeakyReLU."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Layer order fixes the Sequential indices, and so the checkpoint
        # keys ("up.1.weight", "up.1.bias").
        stages = [
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(inplace=True),
        ]
        self.up = nn.Sequential(*stages)

    def forward(self, x):
        """Double the spatial size of ``x`` and map it to ``out_channels``."""
        upsampled = self.up(x)
        return upsampled


class AttentionBlock(nn.Module):
    """Additive attention gate that rescales skip features ``x`` using ``g``.

    Args:
        F_g: Channels of the gating signal ``g``.
        F_l: Channels of the skip-connection features ``x``.
        F_int: Channels of the intermediate space both are projected into.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        # 1x1 projections; stride 1 and padding 0 are the Conv2d defaults.
        self.wg = nn.Conv2d(F_g, F_int, kernel_size=1)
        self.wx = nn.Conv2d(F_l, F_int, kernel_size=1)
        self.psi = nn.Conv2d(F_int, 1, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, g, x):
        """Return ``x`` multiplied by a single-channel attention map in (0, 1).

        ``g`` and ``x`` must share spatial dimensions, since their projections
        are summed element-wise.
        """
        combined = self.relu(self.wg(g) + self.wx(x))
        attention = self.sigmoid(self.psi(combined))
        # The one-channel map broadcasts across all channels of x.
        return x * attention

    


# --- U-Net ---
class UNet(nn.Module):
    """Attention U-Net with four encoder stages, a bottleneck and four decoder stages.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output map.
        int_filters: Width of the first stage; it doubles at every depth.

    Each decoder stage upsamples, gates the matching encoder skip tensor with
    an ``AttentionBlock``, concatenates the two and refines them with a
    ``ConvBlock``. A final 1x1 convolution maps to ``out_channels``, with no
    output activation.
    """

    def __init__(self, in_channels=2, out_channels=1, int_filters = 32):
        super().__init__()

        # Channel widths per depth; encoder and decoder use the same widths.
        w1, w2, w4, w8, w16 = (int_filters * m for m in (1, 2, 4, 8, 16))

        # Attribute names and creation order are kept stable because they
        # define the checkpoint keys.
        self.e1 = EncoderBlock(in_channels, w1)
        self.e2 = EncoderBlock(w1, w2)
        self.e3 = EncoderBlock(w2, w4)
        self.e4 = EncoderBlock(w4, w8)

        self.bottleneck = ConvBlock(w8, w16)

        self.up4 = UpConv(w16, w8)
        self.att4 = AttentionBlock(w8, w8, w8)
        self.conv4 = ConvBlock(2 * w8, w8)

        self.up3 = UpConv(w8, w4)
        self.att3 = AttentionBlock(w4, w4, w4)
        self.conv3 = ConvBlock(2 * w4, w4)

        self.up2 = UpConv(w4, w2)
        self.att2 = AttentionBlock(w2, w2, w2)
        self.conv2 = ConvBlock(2 * w2, w2)

        self.up1 = UpConv(w2, w1)
        self.att1 = AttentionBlock(w1, w1, w1)
        self.conv1 = ConvBlock(2 * w1, w1)

        self.final_conv = nn.Conv2d(w1, out_channels, kernel_size=1)

    @staticmethod
    def _decode(up, att, conv, x, skip):
        """Run one decoder stage: upsample, gate the skip, concatenate, convolve."""
        upsampled = up(x)
        gated_skip = att(upsampled, skip)
        return conv(torch.cat([upsampled, gated_skip], dim=1))

    def forward(self, x):
        """Map a ``(B, in_channels, H, W)`` batch to ``(B, out_channels, H, W)``.

        NOTE(review): the four 2x poolings followed by 2x upsamplings appear to
        require H and W divisible by 16 for the skip concatenation to line up.
        """
        # Encoder: keep each full-resolution output as a skip connection.
        skip1, down = self.e1(x)
        skip2, down = self.e2(down)
        skip3, down = self.e3(down)
        skip4, down = self.e4(down)

        features = self.bottleneck(down)

        # Decoder: deepest skip first.
        features = self._decode(self.up4, self.att4, self.conv4, features, skip4)
        features = self._decode(self.up3, self.att3, self.conv3, features, skip3)
        features = self._decode(self.up2, self.att2, self.conv2, features, skip2)
        features = self._decode(self.up1, self.att1, self.conv1, features, skip1)

        return self.final_conv(features)


In [28]:
# Use the Apple-GPU (MPS) backend when available, otherwise the CPU.
# torch.backends.mps.is_available() is the documented availability check.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

model = UNet(in_channels=2, out_channels=1, int_filters=32)
# map_location="cpu" lets a checkpoint saved on an accelerator load on a
# machine that lacks it; the weights are moved to `device` right after.
# weights_only=True restricts unpickling to tensors and plain containers.
model.load_state_dict(
    torch.load(
        f"../models/unet_normalization_{normalization}.pth",
        map_location="cpu",
        weights_only=True,
    )
)
model = model.to(device)
# eval() puts BatchNorm (and any dropout) into inference behaviour. It stays
# the last expression so the cell still displays the module summary.
model.eval()



UNet(
  (e1): EncoderBlock(
    (conv): ConvBlock(
      (conv): Sequential(
        (0): Conv2d(2, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.01, inplace=True)
        (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): LeakyReLU(negative_slope=0.01, inplace=True)
      )
      (dropout): Identity()
    )
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (e2): EncoderBlock(
    (conv): ConvBlock(
      (conv): Sequential(
        (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.01, inplace=True)
        (3): Conv2d(64, 64, ke

In [29]:
# One numpy array per batch, collected in loader order.
all_preds = []

# Inference only: skip autograd bookkeeping to save memory and time.
with torch.no_grad():
    for X in prediction_loader:
        X = X.to(device)

        # Forward pass gives (B, 1, H, W); move it to host memory as numpy.
        preds = model(X).cpu().numpy()

        all_preds.append(preds)

In [30]:
# Sanity check: at this point all_preds holds one array per loader batch, so
# this is the number of batches processed, not the number of samples.
len(all_preds)

25384

In [31]:
# Merge the per-batch outputs into one (N, 1, H, W) array.
# The isinstance guard keeps the cell idempotent. Without it, re-running would
# feed the 4-D array back into np.concatenate, which joins its (1, H, W)
# slices into (N, H, W) and silently breaks the later channel indexing.
if isinstance(all_preds, list):
    all_preds = np.concatenate(all_preds, axis=0)

In [32]:
# Shape after selecting channel 0: one 2-D prediction map per sample.
all_preds[:,0,:,:].shape

(203070, 64, 64)

In [33]:
# Select channel 0 and store one 2-D prediction per row. The positional
# match to X_all relies on the loader having kept row order (shuffle=False).
X_all['pred'] = list(all_preds[:, 0])

In [34]:
# Display the frame, now including the 'pred' column (pandas truncates the repr).
X_all

Unnamed: 0,cs,s1,tile,region,time,pred
0,"[[1321.5018, 1321.5018, 1321.5018, 1321.5018, ...","[[-3.0440319, -3.9007003, -4.656244, -6.187236...",tile_10,01,2014_10,"[[1227.6663, 1277.6422, 1297.5355, 1279.4984, ..."
1,"[[0.0, 0.0, 0.0, 0.0, 888.5931, 888.5931, 888....","[[0.0, 0.0, 0.0, 0.0, -0.46023303, -7.0914674,...",tile_101,01,2014_10,"[[9.129043, 8.786194, 9.184349, 8.947771, 885...."
2,"[[1294.3636, 1294.3636, 1294.3636, 1350.764, 1...","[[-0.48107538, -1.604826, -2.2076383, -2.89677...",tile_108,01,2014_10,"[[1164.9523, 1239.4497, 1302.0836, 1303.2202, ..."
3,"[[186.76949, 186.76949, 186.76949, 186.76949, ...","[[-8.337175, -8.040663, -9.163354, -9.161297, ...",tile_109,01,2014_10,"[[176.29524, 201.40758, 183.69368, 186.56265, ..."
4,"[[895.30536, 895.30536, 895.30536, 895.30536, ...","[[-5.05783, -5.5084124, -5.58629, -5.8013773, ...",tile_11,01,2014_10,"[[860.23895, 870.6348, 859.527, 862.0249, 861...."
...,...,...,...,...,...,...
203065,"[[1105.7786, 1105.7786, 1105.7786, 1105.7786, ...","[[-4.825405, -4.721585, -4.73615, -4.424411, -...",tile_95,06,2024_10,"[[1045.689, 1051.1731, 1045.0349, 1053.5786, 1..."
203066,"[[1252.2408, 1252.2408, 1257.4795, 1257.4795, ...","[[-3.1178699, -3.0993083, -3.2134385, -3.48768...",tile_96,06,2024_10,"[[1171.3232, 1186.3425, 1179.0221, 1182.1045, ..."
203067,"[[1324.8103, 1324.8103, 1324.8103, 1324.8103, ...","[[-2.9030392, -2.9227698, -3.2303104, -3.27568...",tile_97,06,2024_10,"[[1227.8917, 1243.9807, 1236.8665, 1234.3662, ..."
203068,"[[1148.0347, 1148.0347, 1156.8484, 1156.8484, ...","[[-6.92509, -6.6605544, -6.6482167, -6.5667467...",tile_98,06,2024_10,"[[1082.5204, 1081.5443, 1078.659, 1087.716, 10..."
