# In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import collections

class Corr_pyTorch(nn.Module):
    """Correlation (cost-volume) layer built from plain PyTorch ops.

    Every position of ``in1`` is multiplied element-wise with ``in2`` shifted
    by each displacement of a ``(2 * pad_size + 1) ** 2`` search window
    (zero padded at the borders) and the products are averaged over the
    unfolded channel dimension.  Note that this is much slower than a
    dedicated CUDA correlation kernel.
    """

    def __init__(self, pad_size=4, kernel_size=1, max_displacement=4, stride1=1, stride2=1):
        # Only the dense, unit-stride configuration is implemented.
        assert pad_size == max_displacement
        assert stride1 == stride2 == 1
        super().__init__()
        self.pad_size = pad_size
        self.kernel_size = kernel_size
        self.stride1 = stride1
        self.stride2 = stride2
        self.max_hdisp = max_displacement
        # Kept as a public attribute; forward() pads through F.unfold instead.
        self.padlayer = nn.ConstantPad2d(pad_size, 0)

    def forward(self, in1, in2):
        """Return the correlation volume of shape (B, n_shifts, H, W).

        Channel ``i * (2 * pad_size + 1) + j`` of the result holds the
        correlation for the vertical shift ``i - pad_size`` and horizontal
        shift ``j - pad_size`` applied to ``in2``.
        """
        batch, _, height, width = in1.shape
        half_kernel = self.kernel_size // 2
        patches1 = F.unfold(in1, kernel_size=self.kernel_size, padding=half_kernel, stride=self.stride1)
        patches2 = F.unfold(in2, kernel_size=self.kernel_size, padding=half_kernel, stride=self.stride2)
        patch_dim = patches2.shape[1]

        # Treat every (sample, patch element) pair as its own one-channel image
        # so a single unfold with an image-sized kernel enumerates all shifts.
        planes = patches2.reshape(batch * patch_dim, 1, height, width)
        shifted = F.unfold(planes, kernel_size=(height, width), padding=self.pad_size, stride=self.stride2)
        _, n_pixels, n_shifts = shifted.shape

        # (B, patch_dim, n_pixels, n_shifts) -> (B, n_shifts, patch_dim, n_pixels)
        shifted = shifted.reshape(batch, patch_dim, n_pixels, n_shifts).permute(0, 3, 1, 2)

        product = shifted * patches1.unsqueeze(1)
        corr = product.mean(dim=2)
        return corr.reshape(batch, n_shifts, height, width)

class WarpingLayer(nn.Module):
    """Backward-warp a feature map with a dense displacement field.

    Positions whose bilinear sampling footprint is not fully inside the
    image are zeroed in the output.
    """

    def __init__(self):
        super(WarpingLayer, self).__init__()

    def forward(self, x, deformation):
        """Sample ``x`` at ``pixel + deformation``.

        Args:
            x: tensor of shape (B, C, H, W).
            deformation: tensor of shape (B, 2, H, W); channel 0 is added to
                the column (x) index and channel 1 to the row (y) index, both
                in pixel units.

        Returns:
            Tensor shaped like ``x`` holding the warped, masked features.
        """
        B, _, H, W = x.shape

        # Build the pixel grid directly on the target device; broadcasting
        # against the deformation replaces the explicit per-batch repeat.
        xs = torch.arange(W, device=x.device, dtype=torch.float32).view(1, 1, W)
        ys = torch.arange(H, device=x.device, dtype=torch.float32).view(1, H, 1)
        vx = xs + deformation[:, 0, :, :]
        vy = ys + deformation[:, 1, :, :]

        # Map pixel coordinates to the [-1, 1] range used by grid_sample
        # (align_corners=True); max(.., 1) avoids dividing by zero for 1-pixel axes.
        vx = 2.0 * vx / max(W - 1, 1) - 1.0
        vy = 2.0 * vy / max(H - 1, 1) - 1.0

        # grid_sample requires input and grid to share a dtype, so cast the
        # grid to x's dtype (a no-op for the usual float32 case).
        vgrid = torch.stack((vx, vy), dim=-1).to(x.dtype)  # B,H,W,2
        vgrid = vgrid.expand(B, -1, -1, -1)

        x_warp = F.grid_sample(x, vgrid, padding_mode='zeros', align_corners=True)

        # Warping an all-ones image yields < 1 wherever part of the footprint
        # fell outside the image; those positions are masked out.
        mask = F.grid_sample(torch.ones_like(x), vgrid, align_corners=True)
        mask = (mask >= 1.0).to(x.dtype)
        return x_warp * mask


def conv(in_planes, out_planes, kernel_size=3, stride=1, dilation=1, isReLU=True, if_IN=False, IN_affine=False, if_BN=False):
    """Build a ``Conv2d`` block with optional activation and normalisation.

    The returned ``nn.Sequential`` always starts with a biased convolution
    whose padding is ``((kernel_size - 1) * dilation) // 2``.  It is followed
    by ``LeakyReLU(0.1)`` when ``isReLU`` is set, and then by an instance norm
    (``if_IN``) or, failing that, a batch norm (``if_BN``).  Note that the
    normalisation comes after the activation and that ``IN_affine`` controls
    the ``affine`` flag of both norm types.
    """
    layers = [
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation,
                  padding=((kernel_size - 1) * dilation) // 2, bias=True)
    ]
    if isReLU:
        layers.append(nn.LeakyReLU(0.1, inplace=True))
    # Instance norm takes precedence when both norm flags are set.
    if if_IN:
        layers.append(nn.InstanceNorm2d(out_planes, affine=IN_affine))
    elif if_BN:
        layers.append(nn.BatchNorm2d(out_planes, affine=IN_affine))
    return nn.Sequential(*layers)


class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with a skip connection.

    ``norm_fn`` picks the normalisation after each convolution: ``'group'``
    (``planes // 8`` groups), ``'batch'``, ``'instance'`` or ``'none'``.
    When ``stride`` is not 1 the skip path is a strided 1x1 convolution
    followed by its own norm layer (also reachable as ``self.norm3``).
    """

    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.relu = nn.LeakyReLU(inplace=False, negative_slope=0.1)

        strided = not stride == 1

        # One factory per supported norm type; each call returns a fresh layer.
        norm_factories = {
            'group': lambda: nn.GroupNorm(num_groups=planes // 8, num_channels=planes),
            'batch': lambda: nn.BatchNorm2d(planes),
            'instance': lambda: nn.InstanceNorm2d(planes),
            'none': lambda: nn.Sequential(),
        }
        make_norm = norm_factories.get(norm_fn)
        if make_norm is not None:
            self.norm1 = make_norm()
            self.norm2 = make_norm()
            if strided:
                self.norm3 = make_norm()

        if strided:
            # norm3 is shared between the attribute and the downsample branch.
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
        else:
            self.downsample = None

    def forward(self, x):
        """Return ``relu(skip(x) + residual(x))``."""
        residual = self.relu(self.norm1(self.conv1(x)))
        residual = self.relu(self.norm2(self.conv2(residual)))

        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(shortcut + residual)
class Encoder(nn.Module):
    """Residual feature-pyramid encoder.

    A strided 7x7 stem (16 channels) is followed by six stages of two
    ``ResidualBlock`` s with 16, 32, 64, 96, 128 and 196 channels; every
    stage but the first halves the resolution.  ``forward`` returns the
    outputs of the last five stages, coarsest first.
    """

    def __init__(self, input_dim=1 ,norm_fn='batch', dropout=0.0):
        super(Encoder, self).__init__()
        self.norm_fn = norm_fn

        # Normalisation applied right after the stem convolution.
        stem_norms = {
            'group': lambda: nn.GroupNorm(num_groups=8, num_channels=16),
            'batch': lambda: nn.BatchNorm2d(16),
            'instance': lambda: nn.InstanceNorm2d(16),
            'none': lambda: nn.Sequential(),
        }
        make_stem_norm = stem_norms.get(self.norm_fn)
        if make_stem_norm is not None:
            self.norm1 = make_stem_norm()

        self.conv1 = nn.Conv2d(input_dim, 16, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.LeakyReLU(inplace=False, negative_slope=0.1)
        self.in_planes = 16

        # (output channels, stride of the first block) for layer1 .. layer6.
        stage_specs = ((16, 1), (32, 2), (64, 2), (96, 2), (128, 2), (196, 2))
        for index, (width, stride) in enumerate(stage_specs, start=1):
            setattr(self, f'layer{index}', self._make_layer(width, stride=stride))

        # Stored for callers; forward() does not apply it.
        self.dropout = nn.Dropout2d(p=dropout) if dropout > 0 else None

        # Kaiming init for convolution weights, identity init for norm layers.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if isinstance(module, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if module.weight is not None:
                    nn.init.constant_(module.weight, 1)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    def _make_layer(self, dim, stride=1):
        """Create one stage (two residual blocks) and update ``self.in_planes``."""
        first = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        second = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        self.in_planes = dim
        return nn.Sequential(first, second)

    def forward(self, x):
        """Return ``(layer6, layer5, layer4, layer3, layer2)`` feature maps."""
        x = self.relu1(self.norm1(self.conv1(x)))

        stage_outputs = []
        for stage in (self.layer1, self.layer2, self.layer3,
                      self.layer4, self.layer5, self.layer6):
            x = stage(x)
            stage_outputs.append(x)

        # Drop the layer1 output and order the rest from coarse to fine.
        return tuple(reversed(stage_outputs[1:]))


class Denseblock(nn.Module):
    """Densely connected estimator followed by a dilated refinement stack.

    Five ``conv`` layers are chained with dense (concatenating) connections,
    ``conv_last`` maps the accumulated features to ``out_channel`` channels,
    and ``convs`` refines that estimate with dilated convolutions.
    """

    def __init__(self, ch_in, f_channels=(128, 128, 96, 64, 32), out_channel=2):
        super(Denseblock, self).__init__()

        # conv1 .. conv5: each layer sees the input plus all earlier outputs.
        running = ch_in
        for index in range(5):
            setattr(self, f'conv{index + 1}', conv(running, f_channels[index]))
            running += f_channels[index]
        self.n_channels = running

        self.conv_last = conv(running, out_channel, isReLU=False)

        self.channels = (128, 128, 128, 96, 64, 32, 2)
        # The "+ 2" accounts for the displacement concatenated in forward()
        # (assumes a 2-channel displacement -- TODO confirm with callers).
        widths = (self.n_channels + 2,) + self.channels
        dilations = (1, 2, 4, 8, 16, 1)
        refinement = [
            conv(widths[i], widths[i + 1], 3, 1, dilations[i]) for i in range(6)
        ]
        refinement.append(conv(self.channels[5], self.channels[6], isReLU=False))
        self.convs = nn.Sequential(*refinement)

    def forward(self,disp, x):
        """Return ``(disp, update)``.

        ``disp`` is passed through unchanged; ``update`` is the sum of the
        dense estimate and the refinement computed from the dense features
        concatenated with ``disp`` plus that estimate.
        """
        features = x
        for layer in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            features = torch.cat([layer(features), features], dim=1)

        estimate = self.conv_last(features)
        refined = self.convs(torch.cat([features, disp + estimate], dim=1))
        return disp, refined + estimate


class DICnet(nn.Module):
    """Coarse-to-fine displacement estimation network.

    Both images are encoded by a shared ``Encoder``.  Starting at the
    coarsest pyramid level, the running displacement is upsampled, the second
    feature map is warped with it, a correlation volume is computed and a
    shared ``Denseblock`` predicts an update.  The last estimate is upsampled
    to the resolution of the input image.
    """

    def __init__(self):
        super(DICnet, self).__init__()
        self.num_chs = [1, 16, 32, 64, 96, 128, 196]  # 1/2 1/4 1/8 1/16 1/32 1/64
        self.feature_pyramid_extractor = Encoder()
        # One 1x1 projection to 32 channels per pyramid level, in the order the
        # encoder returns the levels (196 channels first, 32 channels last).
        self.conv_1x1 = nn.ModuleList([conv(196, 32, kernel_size=1, stride=1, dilation=1),
                                       conv(128, 32, kernel_size=1, stride=1, dilation=1),
                                       conv(96, 32, kernel_size=1, stride=1, dilation=1),
                                       conv(64, 32, kernel_size=1, stride=1, dilation=1),
                                       conv(32, 32, kernel_size=1, stride=1, dilation=1)])

        self.leakyRELU = nn.LeakyReLU(0.1, inplace=True)
        self.warping_layer = WarpingLayer()
        self.search_range = 4
        self.output_level = 4
        self.dim_corr = (self.search_range * 2 + 1) ** 2
        # correlation channels + projected features (32) + displacement (2)
        self.num_ch_in = self.dim_corr + 32 + 2
        self.d_channels = (128, 128, 96, 64, 32)
        self.correlation_pytorch = Corr_pyTorch(pad_size=self.search_range, kernel_size=1,max_displacement=self.search_range, stride1=1,stride2=1)  # correlation layer using pytorch
        self.denseblock = Denseblock(self.num_ch_in, f_channels=self.d_channels)

    def upsample2d_as(self,inputs, target_as, mode="bilinear", if_rate=False):
        """Resize ``inputs`` to the spatial size of ``target_as``.

        With ``if_rate=True`` the two channels of the result are additionally
        multiplied by the width and height ratios respectively, so that a
        displacement expressed in pixels stays consistent with the new size.
        """
        _, _, h, w = target_as.size()
        res = F.interpolate(inputs, [h, w], mode=mode, align_corners=True)
        if if_rate:
            _, _, h_, w_ = inputs.size()
            u, v = res.chunk(2, dim=1)
            res = torch.cat([u * (w / w_), v * (h / h_)], dim=1)
        return res

    def normalize_features(self, feature_list, normalize, center, moments_across_channels=False,
                           moments_across_images=False):
        """Normalizes feature tensors (e.g., before computing the cost volume).
        Args:
          feature_list: list of torch tensors, each with dimensions [b, c, h, w]
          normalize: bool flag, divide features by their standard deviation
          center: bool flag, subtract feature mean
          moments_across_channels: bool flag, compute mean and std across channels
            (UFlow's default appears to be True)
          moments_across_images: bool flag, compute mean and std across images
            (UFlow's default appears to be True)

        Returns:
          list, normalized feature_list
        """
        axes = [1, 2, 3] if moments_across_channels else [2, 3]  # [b, c, h, w]
        means = []
        variances = []
        for feature_image in feature_list:
            means.append(torch.mean(feature_image, dim=axes, keepdim=True))  # [b,1,1,1] or [b,c,1,1]
            variances.append(torch.var(feature_image, dim=axes, keepdim=True))  # [b,1,1,1] or [b,c,1,1]

        if moments_across_images:
            n_images = len(feature_list)
            # Share the moments between images by averaging them.  The variances
            # must be averaged too: taking the variance *of* the variances gives
            # a value near zero for similar images and makes the division below
            # blow up.
            shared_mean = torch.mean(torch.stack(means, dim=0), dim=0)
            shared_var = torch.mean(torch.stack(variances, dim=0), dim=0)
            means = [shared_mean] * n_images
            variances = [shared_var] * n_images

        # Small epsilon keeps the square root finite for constant features.
        stds = [torch.sqrt(v + 1e-16) for v in variances]
        if center:
            feature_list = [f - mean for f, mean in zip(feature_list, means)]
        if normalize:
            feature_list = [f / std for f, std in zip(feature_list, stds)]

        return feature_list

    def Update(self, level, flow_1,  feature_1, feature_1_1x1, feature_2):
        """Refine the displacement at one pyramid level.

        Args:
          level: pyramid index; at level 0 the second feature map is used unwarped.
          flow_1: displacement from the previous (coarser) level.
          feature_1: features of the first image at this level.
          feature_1_1x1: 1x1-projected version of ``feature_1``.
          feature_2: features of the second image at this level.

        Returns:
          The upsampled displacement plus the update predicted by the dense block.
        """
        up_bilinear = self.upsample2d_as(flow_1, feature_1, mode="bilinear", if_rate=True)
        if level == 0:
            feature_2_warp = feature_2
        else:
            feature_2_warp = self.warping_layer(feature_2, up_bilinear)
        feature_1, feature_2_warp = self.normalize_features((feature_1, feature_2_warp), normalize=True,center=True)
        out_corr_1 = self.correlation_pytorch(feature_1, feature_2_warp)
        out_corr_relu_1 = self.leakyRELU(out_corr_1)
        up_bilinear, res = self.denseblock(up_bilinear ,torch.cat([out_corr_relu_1, feature_1_1x1, up_bilinear], dim=1))
        return  up_bilinear+res

    def forward(self,img1,img2):
        """Estimate the displacement between ``img1`` and ``img2`` at the resolution of ``img1``."""
        pyramid_1 = self.feature_pyramid_extractor(img1)
        pyramid_2 = self.feature_pyramid_extractor(img2)

        # Start from a zero displacement at the first (coarsest) returned level.
        coarsest = pyramid_1[0]
        b_size, _, h_x1, w_x1 = coarsest.size()
        disp_f = torch.zeros(b_size, 2, h_x1, w_x1, dtype=coarsest.dtype, device=coarsest.device).float()

        # Distinct loop names avoid shadowing the pyramids being iterated.
        for level, (feat_1, feat_2) in enumerate(zip(pyramid_1, pyramid_2)):
            feat_1_1by1 = self.conv_1x1[level](feat_1)
            disp_f = self.Update(level=level, flow_1=disp_f, feature_1=feat_1,
                                 feature_1_1x1=feat_1_1by1, feature_2=feat_2)

        disp = self.upsample2d_as(disp_f, img1, mode="bilinear", if_rate=True)
        return disp

class UnDICnet_d(nn.Module):
    """Bidirectional wrapper: runs one ``DICnet`` in both image orders."""

    def __init__(self,args):
        super(UnDICnet_d, self).__init__()
        self.DICnet = DICnet()
        self.args = args

    def forward(self, img1,img2):
        """Return a dict with the forward and backward displacement fields."""
        forward_flow = self.DICnet(img1, img2)   # img1 -> img2
        backward_flow = self.DICnet(img2, img1)  # img2 -> img1
        return {
            'flow_f_out': forward_flow,
            'flow_b_out': backward_flow,
        }


class UnDICnet_s(nn.Module):
    """Single-direction wrapper around ``DICnet``."""

    def __init__(self, args):
        super(UnDICnet_s, self).__init__()
        self.DICnet = DICnet()
        self.args = args

    def forward(self, img1, img2):
        """Return the displacement field estimated from ``img1`` to ``img2``."""
        return self.DICnet(img1, img2)



# In [2]:
def get_patches(x, x_wind=143):
    """Extract every sliding window of ``2 * x_wind + 1`` rows along dim 2.

    For an input of shape (B, 1, H, W) the result has shape
    (B, H - 2 * x_wind, 2 * x_wind + 1, W): entry ``[b, i, k, w]`` equals
    ``x[b, 0, i + k, w]``.
    """
    window = 2 * x_wind + 1
    step = 1
    windows = x.unfold(2, window, step)   # window samples end up in the last dim
    windows = windows.squeeze(1)          # drop the singleton channel dim
    return windows.permute(0, 1, 3, 2)    # put the window dim before the width

def get_strain(disp, x_wind=143):
    """Estimate strain as the local least-squares slope of a displacement map.

    A straight line is fitted to every sliding window of ``2 * x_wind + 1``
    samples along dim 2 of ``disp``; the slope of that line is the strain at
    the window centre.  The first and last ``x_wind`` rows, where no full
    window exists, are zero padded.

    Args:
        disp: displacement tensor; assumed to be (B, 1, H, W) with
            ``H >= 2 * x_wind + 1`` -- TODO confirm with callers.
        x_wind: half-width of the fitting window.

    Returns:
        Tensor of shape (B, 1, H, W) on the same device as ``disp``.
    """
    window = 2 * x_wind + 1
    patches = get_patches(disp, x_wind=x_wind)  # (B, H - 2 * x_wind, window, W)

    # Design matrix [depth, 1] of the line fit, created on the device and in
    # the dtype of the data instead of unconditionally on CUDA.
    depth = torch.linspace(1, window, window, dtype=patches.dtype, device=patches.device)
    design = torch.stack([depth, torch.ones_like(depth)], dim=1)  # (window, 2)
    design_t = design.t()
    XtX = design_t @ design  # (2, 2), identical for every window

    strain_maps = []
    for sample in patches:
        # Normal equations (X^T X) beta = X^T y, solved for all windows at once.
        XtY = torch.matmul(design_t, sample)  # (n_windows, 2, W)
        betas = torch.linalg.solve(XtX.expand(sample.shape[0], 2, 2), XtY)
        # Row 0 holds the slope.  Index it explicitly (no bare squeeze) so
        # size-1 height/width dimensions survive.
        slope = betas[:, 0, :]
        # pad to original size
        strain_maps.append(F.pad(slope, (0, 0, x_wind, x_wind)))
    return torch.stack(strain_maps).unsqueeze(1)

# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt  # ✅ FIXED
import random


# -----------------------------------------------------------------------------
# Warping function: using grid_sample to warp an image given a displacement field.
# -----------------------------------------------------------------------------
def warp(img, flow):
    """Warp ``img`` with a displacement field using bilinear sampling.

    Args:
        img: tensor of shape [B, C, H, W].
        flow: tensor of shape [B, 2, H, W] with displacements in pixel units;
            channel 0 shifts along the width, channel 1 along the height.

    Returns:
        Tensor of shape [B, C, H, W]; samples falling outside the image take
        the value of the nearest border pixel (``padding_mode='border'``).
    """
    B, _, H, W = img.shape

    # Guard against a zero denominator for one-pixel-wide/tall images.
    x_den = max(W - 1, 1)
    y_den = max(H - 1, 1)

    xs = torch.arange(W, device=img.device, dtype=torch.float32).view(1, 1, W)
    ys = torch.arange(H, device=img.device, dtype=torch.float32).view(1, H, 1)

    # Pixel grid and flow are both mapped to grid_sample's [-1, 1] convention
    # (align_corners=True) and then added.
    grid_x = (2.0 * xs / x_den - 1.0) + flow[:, 0, :, :] * (2.0 / x_den)
    grid_y = (2.0 * ys / y_den - 1.0) + flow[:, 1, :, :] * (2.0 / y_den)

    # grid_sample needs the grid in the dtype of the image.
    warped_grid = torch.stack((grid_x, grid_y), dim=-1).to(img.dtype)  # [B, H, W, 2]
    warped_grid = warped_grid.expand(B, -1, -1, -1)
    warped_img = F.grid_sample(img, warped_grid, align_corners=True, padding_mode='border')
    return warped_img

# -----------------------------------------------------------------------------
# Loss function implementations
# -----------------------------------------------------------------------------

def patch_znssd_loss(I, I_warp, patch_size=32, stride=16, epsilon=1e-8):
    """Patch-wise zero-normalised sum of squared differences (ZNSSD).

    Both images are cut into ``patch_size`` x ``patch_size`` patches taken
    every ``stride`` pixels.  Each patch is shifted to zero mean and divided
    by its standard deviation (plus ``epsilon``); the loss is the mean squared
    difference between corresponding normalised patches.

    I and I_warp are assumed to be of shape [B, 1, H, W].
    """

    def _normalised_patches(image):
        # unfold output: [B, patch_size * patch_size, number of patches]
        patches = F.unfold(image, kernel_size=patch_size, stride=stride)
        centred = patches - patches.mean(dim=1, keepdim=True)
        scale = patches.std(dim=1, keepdim=True) + epsilon
        return centred / scale

    residual = _normalised_patches(I) - _normalised_patches(I_warp)
    return torch.mean(residual ** 2)

def smoothness_loss(flow, img):
    """Edge-aware first-order smoothness penalty on the flow.

    Absolute forward differences of ``flow`` along width and height are
    weighted by ``exp(-|image gradient|)`` (image gradient averaged over
    channels), so flow changes are penalised less where the image itself
    changes.

    flow: [B, 2, H, W]
    img: [B, 1, H, W] used to weight the gradients.
    """

    def _abs_dx(tensor):
        return (tensor[:, :, :, 1:] - tensor[:, :, :, :-1]).abs()

    def _abs_dy(tensor):
        return (tensor[:, :, 1:, :] - tensor[:, :, :-1, :]).abs()

    weight_x = torch.exp(-_abs_dx(img).mean(dim=1, keepdim=True))
    weight_y = torch.exp(-_abs_dy(img).mean(dim=1, keepdim=True))

    penalty_x = (_abs_dx(flow) * weight_x).mean()
    penalty_y = (_abs_dy(flow) * weight_y).mean()
    return penalty_x + penalty_y

def census_loss(I, I_warp, kernel_size=7):
    """Simplified census loss between ``I`` and ``I_warp``.

    Around every pixel a ``kernel_size`` x ``kernel_size`` neighbourhood
    (zero padded at the borders) is turned into a ternary descriptor holding
    the sign of each neighbour minus the centre pixel.  The loss is the mean
    of half the absolute descriptor difference, i.e. a normalised Hamming-like
    distance.

    NOTE(review): ``torch.sign`` has a zero gradient almost everywhere, so this
    term presumably contributes no gradient during training -- confirm whether
    a soft census transform was intended.
    NOTE(review): the centre index is computed for one channel; inputs are
    presumably [B, 1, H, W] -- confirm.
    """
    radius = kernel_size // 2
    centre = (kernel_size * kernel_size) // 2

    def _descriptor(image):
        # [B, kernel_size * kernel_size, H * W] for a single-channel image
        windows = F.unfold(image, kernel_size=kernel_size, padding=radius)
        return torch.sign(windows - windows[:, centre:centre + 1, :])

    mismatch = (_descriptor(I) - _descriptor(I_warp)).abs() / 2.0
    return mismatch.mean()

# -----------------------------------------------------------------------------
# Paired-image dataset: loads pre.png / post.png from every sub-folder of a root directory
# -----------------------------------------------------------------------------
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

class BModeDataset(Dataset):
    """Dataset of image pairs stored as ``<root_dir>/<folder>/pre.png`` and ``post.png``.

    Args:
        root_dir: directory whose sub-folders each hold one image pair.
        transform: optional callable applied to both PIL images; when omitted
            the images are converted with ``transforms.ToTensor()``.
        folder_list: optional iterable of sub-folder names to use.  When given,
            only these folders are used, in the given order, and names that
            are not directories under ``root_dir`` are skipped.  When omitted,
            every sub-folder of ``root_dir`` is used.
    """

    def __init__(self, root_dir, transform=None, folder_list=None):

        self.root_dir = root_dir
        self.transform = transform

        if folder_list is None:
            # Use every sub-folder; numeric names are ordered numerically.
            folders = [
                folder for folder in os.listdir(root_dir)
                if os.path.isdir(os.path.join(root_dir, folder))
            ]
            folders.sort(key=self._folder_sort_key)
        else:
            # Honour the caller's selection and keep its order.
            folders = [
                folder for folder in folder_list
                if os.path.isdir(os.path.join(root_dir, folder))
            ]
        self.folder_list = folders

    @staticmethod
    def _folder_sort_key(name):
        """Sort numeric names numerically and before non-numeric ones.

        Returning tuples of one fixed shape keeps the keys comparable when a
        directory mixes numeric and non-numeric folder names.
        """
        if name.isdigit():
            return (0, int(name), '')
        return (1, 0, name)

    def __len__(self):
        return len(self.folder_list)

    def __getitem__(self, idx):
        """Return ``{'img1': pre, 'img2': post}`` for the ``idx``-th folder."""
        folder_path = os.path.join(self.root_dir, self.folder_list[idx])

        # Load as grayscale images
        pre_img = Image.open(os.path.join(folder_path, "pre.png")).convert("L")
        post_img = Image.open(os.path.join(folder_path, "post.png")).convert("L")

        to_sample = self.transform if self.transform else transforms.ToTensor()
        return {'img1': to_sample(pre_img), 'img2': to_sample(post_img)}


# -----------------------------------------------------------------------------
# Data Loader
# -----------------------------------------------------------------------------

# Define transformation: convert to tensor, then normalize with mean 0.5 / std 0.5.
# (No resizing is applied; images keep their original size.)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

root_dir = "/teamspace/studios/this_studio/final_faulty_denoised_png"
dataset = BModeDataset(root_dir=root_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

# Instantiate the model; here args is set to None for simplicity
model = UnDICnet_s(args=None)
# Use the GPU when one is available instead of failing on CPU-only hosts.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model.to(device)
optimizer = AdamW(model.parameters(), lr=0.0002, weight_decay=0.5e-4)
# Loss weighting parameters as selected in the paper
omega1 = 5  # weight of the smoothness term
omega2 = 1  # weight of the census term
model.train()

from tqdm import tqdm
import torch
import os

# Training setup
num_epochs = 300
best_loss = float('inf')
checkpoint_dir = "checkpoints_results_faulty_insampled_2sample_smothness3"
os.makedirs(checkpoint_dir, exist_ok=True)

# Learning rate scheduler: halve the learning rate every 50 epochs
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    loop = tqdm(dataloader, desc=f"Epoch [{epoch+1}/{num_epochs}]", leave=False)

    for batch in loop:
        img1 = batch['img1'].to(device)
        img2 = batch['img2'].to(device)
        optimizer.zero_grad()

        outputs = model(img1, img2)
        flow_f = outputs
        img2_warp = warp(img2, flow_f)

        # Similarity, smoothness and census terms of the unsupervised loss.
        l_sim = patch_znssd_loss(img1, img2_warp)
        l_s = smoothness_loss(flow_f, img1)
        l_c = census_loss(img1, img2_warp)

        loss = l_sim + omega1 * l_s + omega2 * l_c
        loss.backward()
        optimizer.step()

        # Read the scalar once; every .item() call forces a device sync.
        batch_loss = loss.item()
        epoch_loss += batch_loss
        loop.set_postfix(loss=batch_loss)

    scheduler.step()

    # Every fifth epoch, visualise the flow of the last batch of the epoch.
    if epoch % 5 == 0:
        # Detach from autograd so plotting and the strain fit build no graph.
        with torch.no_grad():
            flow_y = flow_f[:, 1:2, :, :].detach()  # y-displacement, channel dim kept

            plt.imshow(flow_y[0, 0].cpu().numpy(), cmap='jet')
            plt.colorbar()
            plt.title(f"Flow Y Component at Epoch {epoch+1}")
            plt.show()

            # The strain fit uses windows of 2 * 143 + 1 = 287 rows.
            strain_map = get_strain(flow_y, 143)
            plt.figure(figsize=(8, 6))
            plt.imshow(strain_map[0, 0].cpu().numpy(), cmap='jet')
            plt.title("Strain Map (Uxx)")
            # plt.colorbar(label='Strain')
            plt.axis('off')
            plt.show()

    avg_loss = epoch_loss / len(dataloader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Avg Loss: {avg_loss:.6f}")

    # Keep only the checkpoint with the lowest average training loss so far.
    if avg_loss < best_loss:
        best_loss = avg_loss
        checkpoint_path = os.path.join(checkpoint_dir, "best_model.pth")
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': best_loss
        }, checkpoint_path)
        print(f"✅ Saved best model at epoch {epoch+1} with loss {best_loss:.6f}")



