In [1]:
import os
import time
import argparse
import numpy as np
from glob import glob

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import imageio

def get_args(argv=None):
    """
    Build the training/testing configuration.

    Args:
        argv: list of command-line tokens to parse. ``None`` (the default)
            parses an empty list, so every option takes its default value.
            This is presumably so the function works inside a notebook, where
            ``sys.argv`` holds kernel arguments. Pass ``sys.argv[1:]`` to
            honour a real command line.

    Returns:
        argparse.Namespace with one attribute per option declared below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpuIDs', type=str, default='0', help='GPU IDs, e.g. "0" or "0,1"')
    parser.add_argument('--scale', type=int, default=4)
    parser.add_argument('--ngsrf', type=int, default=64, help='Base #channels for generatorSR')
    parser.add_argument('--numResBlocks', type=int, default=8, help='Number of residual blocks in EDSR')
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--fine_size', type=int, default=100, help='Crop size for sub-volumes')
    parser.add_argument('--itersPerEpoch', type=int, default=300)
    parser.add_argument('--iterCyclesPerEpoch', type=int, default=1)
    parser.add_argument('--valNum', type=int, default=10)
    parser.add_argument('--valTest', action='store_true', help='If True, load separate test data for validation')
    parser.add_argument('--epoch', type=int, default=1200)
    parser.add_argument('--epoch_step', type=int, default=10)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--print_freq', type=int, default=10)
    parser.add_argument('--save_freq', type=int, default=10)
    parser.add_argument('--dataset_dir', type=str, default='./Dataset/Bentheimer_mixed_fw90/Train')
    parser.add_argument('--test_dir', type=str, default='./test_images/')
    parser.add_argument('--test_save_dir', type=str, default='./results_final/')
    parser.add_argument('--test_temp_save_dir', type=str, default='./results_temp/')
    parser.add_argument('--modelName', type=str, default='DoubleSRNet')
    parser.add_argument('--checkpoint_dir', type=str, default='checkpoints_PyTorch')

    # resume / mode options
    parser.add_argument('--continue_train', action='store_true')
    parser.add_argument('--phase', type=str, default='train', choices=['train','test'])
    parser.add_argument('--continueEpoch', type=int, default=0)

    # augmentation / multi-GPU switches
    parser.add_argument('--augFlag', action='store_true')
    parser.add_argument('--distributed', action='store_true')

    # None -> [] keeps the historical "defaults only" behaviour.
    return parser.parse_args([] if argv is None else argv)

class DoubleSRDataset(Dataset):
    """
    Random-crop dataset built from a paired HR/LR volume.

    Loads HR.npy and LR.npy in full, then pre-generates ``iters`` mini-batches
    of single Z-slices, cropped at matching positions:
      (batch_size, 1, fine_size, fine_size) for LR
      (batch_size, 1, fine_size*scale, fine_size*scale) for HR
    Voxel values are mapped with ``v / 127.5 - 1`` (presumably 8-bit data,
    giving [-1, 1]; TODO confirm the dtype of the .npy files).
    Each dataset item is already a mini-batch => DataLoader batch_size=1.
    """

    def __init__(self, hr_path, lr_path,
                 scale=4, fine_size=100,
                 iters=300, batch_size=4,
                 aug_flag=False):
        super(DoubleSRDataset, self).__init__()
        # Whole volumes are held in memory. HR is assumed to be `scale` times
        # larger than LR along every axis, because crop offsets (including Z)
        # are multiplied by `scale`.
        self.hr = np.load(hr_path)
        self.lr = np.load(lr_path)

        self.scale = scale
        self.fine_size = fine_size
        self.iters = iters
        self.batch_size = batch_size
        self.aug_flag = aug_flag

        # list of (lr_batch, hr_batch) numpy pairs, filled by generate_subcubes()
        self.data_pairs = []
        self.generate_subcubes()

    def generate_subcubes(self):
        """
        (Re)build ``self.data_pairs`` with ``iters`` random mini-batches.

        Calling it again replaces the previous crops instead of appending to
        them, so it can be used to draw fresh samples between epochs.

        Raises:
            ValueError: if ``fine_size`` exceeds the LR volume's X or Y extent.
        """
        max_x = self.lr.shape[0] - self.fine_size
        max_y = self.lr.shape[1] - self.fine_size
        if max_x < 0 or max_y < 0:
            raise ValueError(
                f"fine_size={self.fine_size} exceeds the LR volume XY extent "
                f"{tuple(self.lr.shape[:2])}")

        pairs = []
        for _ in range(self.iters):
            lr_batch = []
            hr_batch = []
            for _ in range(self.batch_size):
                # randint's upper bound is exclusive: +1 makes the crop flush
                # with the far edge reachable (and fine_size == extent legal).
                x = np.random.randint(0, max_x + 1)
                y = np.random.randint(0, max_y + 1)
                z = np.random.randint(0, self.lr.shape[2])  # pick 1 slice in Z

                # LR block -> (1, fine_size, fine_size)
                lr_block = self.lr[x:x + self.fine_size, y:y + self.fine_size, z]
                lr_block = lr_block[None, ...].astype(np.float32) / 127.5 - 1.0

                # matching HR block -> (1, fine_size*scale, fine_size*scale)
                X, Y, Z = x * self.scale, y * self.scale, z * self.scale
                hr_size = self.fine_size * self.scale
                hr_block = self.hr[X:X + hr_size, Y:Y + hr_size, Z]
                hr_block = hr_block[None, ...].astype(np.float32) / 127.5 - 1.0

                # optional augmentation: identical random flips on both blocks
                if self.aug_flag:
                    if np.random.rand() < 0.5:
                        lr_block = lr_block[..., ::-1].copy()
                        hr_block = hr_block[..., ::-1].copy()
                    if np.random.rand() < 0.5:
                        lr_block = lr_block[..., ::-1, :].copy()
                        hr_block = hr_block[..., ::-1, :].copy()

                lr_batch.append(lr_block)
                hr_batch.append(hr_block)

            # stack into one mini-batch: (batch_size, 1, H, W)
            pairs.append((np.stack(lr_batch, axis=0), np.stack(hr_batch, axis=0)))

        self.data_pairs = pairs

    def __len__(self):
        """Number of pre-generated mini-batches."""
        return len(self.data_pairs)

    def __getitem__(self, idx):
        """Return the ``idx``-th (lr_batch, hr_batch) pair as float32 tensors."""
        lr_batch, hr_batch = self.data_pairs[idx]
        return torch.from_numpy(lr_batch), torch.from_numpy(hr_batch)


class ResidualBlock(nn.Module):
    """
    Plain EDSR-style residual unit: conv -> ReLU -> conv, plus a skip link.

    Both convolutions keep ``n_feats`` channels and use "same" padding
    (``kernel_size // 2``), so for odd kernel sizes the output has the
    input's shape.
    """
    def __init__(self, n_feats=64, kernel_size=3):
        super().__init__()
        pad = kernel_size // 2
        self.conv1 = nn.Conv2d(n_feats, n_feats, kernel_size, padding=pad)
        self.conv2 = nn.Conv2d(n_feats, n_feats, kernel_size, padding=pad)
        # in-place is safe: it only overwrites conv1's fresh output
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        branch = self.conv2(self.relu(self.conv1(x)))
        return branch + x

def upsample_edsr_2d(x, scale):
    """
    Nearest-neighbour upsample both spatial axes of ``x`` by ``scale``.

    Supported scales are 2, 3 and 4 (4 is done as two 2x passes).

    Raises:
        NotImplementedError: for any other scale.
    """
    if scale == 2:
        passes = (2,)
    elif scale == 3:
        passes = (3,)
    elif scale == 4:
        passes = (2, 2)
    else:
        raise NotImplementedError("Unsupported scale in upsample_edsr_2d()")

    for factor in passes:
        x = F.interpolate(x, scale_factor=factor, mode='nearest')
    return x

def upsample_edsr_1d(x, scale):
    """
    Nearest-neighbour upsample only the first spatial axis (dim 2) of ``x``.

    The two-entry ``scale_factor`` means ``x`` is expected to be 4D
    (N, C, H, W). H grows by ``scale`` and W is unchanged. Supported scales
    are 2, 3 and 4 (4 is done as two (2, 1) passes).

    Raises:
        NotImplementedError: for any other scale.
    """
    if scale == 2:
        passes = ((2, 1),)
    elif scale == 3:
        passes = ((3, 1),)
    elif scale == 4:
        passes = ((2, 1), (2, 1))
    else:
        raise NotImplementedError("Unsupported scale in upsample_edsr_1d()")

    for factors in passes:
        x = F.interpolate(x, scale_factor=factors, mode='nearest')
    return x

class EDSR(nn.Module):
    """
    Single-channel EDSR-like 2D super-resolution network.

    Pipeline: 3x3 head conv (1 -> n_feats), ``n_resblocks`` residual blocks
    with a global skip connection, nearest-neighbour upsampling of both
    spatial axes by ``scale`` (via ``upsample_edsr_2d``), a 3x3 tail conv
    (n_feats -> 1), and ``tanh`` so outputs lie in (-1, 1).
    """
    def __init__(self, scale=4, n_resblocks=8, n_feats=64):
        super().__init__()
        self.scale = scale
        self.n_resblocks = n_resblocks

        # attribute names form the state_dict keys used by saved checkpoints
        self.head = nn.Conv2d(1, n_feats, kernel_size=3, padding=1)
        self.body = nn.Sequential(
            *[ResidualBlock(n_feats=n_feats, kernel_size=3) for _ in range(n_resblocks)]
        )
        # final conv after upsampling
        self.tail_conv = nn.Conv2d(n_feats, 1, kernel_size=3, padding=1)

    def forward(self, x):
        shallow = self.head(x)
        deep = shallow + self.body(shallow)  # global residual
        upsampled = upsample_edsr_2d(deep, self.scale)
        return torch.tanh(self.tail_conv(upsampled))

class EDSR1D(nn.Module):
    """
    Lighter EDSR-like network that upsamples only the first spatial axis.

    Same layout as ``EDSR`` (head conv, residual body with a global skip,
    tail conv, ``tanh``). The upsampling step uses ``upsample_edsr_1d``, so
    for an (N, 1, H, W) input the output is (N, 1, H*scale, W).
    """
    def __init__(self, scale=4, n_resblocks=4, n_feats=32):
        super().__init__()
        self.scale = scale

        # attribute names form the state_dict keys used by saved checkpoints
        self.head = nn.Conv2d(1, n_feats, kernel_size=3, padding=1)
        self.body = nn.Sequential(
            *[ResidualBlock(n_feats=n_feats, kernel_size=3) for _ in range(n_resblocks)]
        )
        self.tail_conv = nn.Conv2d(n_feats, 1, kernel_size=3, padding=1)

    def forward(self, x):
        shallow = self.head(x)
        deep = shallow + self.body(shallow)  # global residual
        stretched = upsample_edsr_1d(deep, self.scale)  # only dim 2 grows
        return torch.tanh(self.tail_conv(stretched))


def mean_absolute_error(pred, target):
    """Return the mean of ``|pred - target|`` over all (broadcast) elements as a 0-d tensor."""
    diff = pred - target
    return diff.abs().mean()

def train_one_step(generatorSR, generatorSRC,
                   optimizerSR, optimizerSRC,
                   lr_batch, hr_batch,
                   device='cuda'):
    """
    Run one joint optimisation step of both generators.

    Pass 1: ``generatorSR`` maps ``lr_batch`` (B, 1, H, W) to
    (B, 1, H*scale, W*scale) and is scored against ``hr_batch`` (loss1).

    Pass 2: the pass-1 output is transposed, then every ``scale``-th row of
    dim 2 is kept. ``generatorSRC`` (which upsamples only dim 2) should
    restore the full size there, and its output is scored against the
    transposed ``hr_batch`` (loss2). Feeding the un-decimated tensor, as
    before, made ``generatorSRC`` output scale**2 rows, and the loss crashed
    on the shape mismatch. NOTE(review): decimation is a proxy for a slice
    that is low-res along one axis; confirm it matches the intended
    inference pipeline.

    Both losses are summed and back-propagated together, so loss2 also
    updates ``generatorSR`` through the pass-1 output.

    Args:
        generatorSR, generatorSRC: the two generator modules.
        optimizerSR, optimizerSRC: their optimizers; both are stepped.
        lr_batch: (B, 1, H, W) tensor.
        hr_batch: (B, 1, H*scale, W*scale) tensor; ``scale`` is inferred
            from these shapes.
        device: device both batches are moved to.

    Returns:
        (loss1, loss2) as Python floats.

    Raises:
        ValueError: if hr_batch's spatial size is not an integer multiple
            (the same for both axes) of lr_batch's.
    """
    lr_batch = lr_batch.to(device)
    hr_batch = hr_batch.to(device)

    lr_h, lr_w = lr_batch.shape[-2:]
    hr_h, hr_w = hr_batch.shape[-2:]
    scale = hr_w // lr_w if lr_w else 0
    if scale < 1 or (hr_h, hr_w) != (lr_h * scale, lr_w * scale):
        raise ValueError(
            f"hr_batch spatial size {(hr_h, hr_w)} is not an integer multiple "
            f"of lr_batch spatial size {(lr_h, lr_w)}")

    # pass 1: 2D super-resolution => (B, 1, H*scale, W*scale)
    sr_xy = generatorSR(lr_batch)

    # pass 2: swap spatial axes, make dim 2 low-res again, upsample it back
    sr_xy_t = sr_xy.transpose(2, 3)        # (B, 1, W*scale, H*scale)
    hr_batch_t = hr_batch.transpose(2, 3)  # same shape
    src_input = sr_xy_t[:, :, ::scale, :]  # (B, 1, W, H*scale)
    sr_xyz = generatorSRC(src_input)       # expected (B, 1, W*scale, H*scale)

    # mean absolute error, same quantity as mean_absolute_error()
    loss1 = F.l1_loss(sr_xy, hr_batch)
    loss2 = F.l1_loss(sr_xyz, hr_batch_t)
    total_loss = loss1 + loss2

    optimizerSR.zero_grad()
    optimizerSRC.zero_grad()
    total_loss.backward()
    optimizerSR.step()
    optimizerSRC.step()

    return loss1.item(), loss2.item()


def _unwrap_model(model):
    """Return the module inside an ``nn.DataParallel`` wrapper, or ``model`` itself."""
    return model.module if isinstance(model, nn.DataParallel) else model


def _load_generator_checkpoint(model, path, device, label, epoch):
    """
    Load the state_dict at ``path`` into ``model``, or warn if it is missing.

    State dicts whose keys all start with ``"module."`` (written from an
    ``nn.DataParallel``-wrapped model) are accepted by stripping the prefix.
    """
    if not os.path.exists(path):
        print(f"Warning: checkpoint {path} not found; {label} starts from its initial weights")
        return
    state = torch.load(path, map_location=device)
    prefix = "module."
    if state and all(key.startswith(prefix) for key in state):
        state = {key[len(prefix):]: value for key, value in state.items()}
    model.load_state_dict(state)
    print(f"Loaded {label} from epoch {epoch}")


def main_train(args):
    """
    Train the two-stage model (``EDSR`` then ``EDSR1D``) configured by ``args``.

    Builds the random-crop dataset from ``args.dataset_dir``/{HR,LR}.npy,
    creates both generators with Adam optimizers, and optionally resumes from
    ``args.checkpoint_dir``. It then trains epochs
    ``args.continueEpoch .. args.epoch - 1`` and saves
    ``generatorSR_<epoch>.pth`` / ``generatorSRC_<epoch>.pth`` every
    ``args.save_freq`` epochs. Checkpoints are always saved without a
    DataParallel ``module.`` prefix, so they can be reloaded on resume.
    """
    # NOTE: CUDA_VISIBLE_DEVICES only takes effect if CUDA has not yet been
    # initialised in this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpuIDs
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # build dataset; every item is already a full mini-batch, hence batch_size=1
    hr_path = os.path.join(args.dataset_dir, "HR.npy")
    lr_path = os.path.join(args.dataset_dir, "LR.npy")
    train_dataset = DoubleSRDataset(hr_path=hr_path,
                                    lr_path=lr_path,
                                    scale=args.scale,
                                    fine_size=args.fine_size,
                                    iters=args.itersPerEpoch,
                                    batch_size=args.batch_size,
                                    aug_flag=args.augFlag)
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=0)

    # create models; the second-pass network is half the size of the first
    generatorSR = EDSR(scale=args.scale, n_resblocks=args.numResBlocks, n_feats=args.ngsrf).to(device)
    generatorSRC = EDSR1D(scale=args.scale, n_resblocks=args.numResBlocks//2, n_feats=args.ngsrf//2).to(device)

    optimizerSR = torch.optim.Adam(generatorSR.parameters(), lr=args.lr, betas=(0.9, 0.999))
    optimizerSRC = torch.optim.Adam(generatorSRC.parameters(), lr=args.lr, betas=(0.9, 0.999))

    # load checkpoints if continuing (before any DataParallel wrapping)
    if args.continue_train:
        _load_generator_checkpoint(
            generatorSR,
            os.path.join(args.checkpoint_dir, f"generatorSR_{args.continueEpoch}.pth"),
            device, "generatorSR", args.continueEpoch)
        _load_generator_checkpoint(
            generatorSRC,
            os.path.join(args.checkpoint_dir, f"generatorSRC_{args.continueEpoch}.pth"),
            device, "generatorSRC", args.continueEpoch)

    # optionally multi-gpu (single-process DataParallel)
    if args.distributed and torch.cuda.device_count() > 1:
        generatorSR = nn.DataParallel(generatorSR)
        generatorSRC = nn.DataParallel(generatorSRC)

    def adjust_lr(optimizer, epoch_idx):
        # lr = initial_lr * 0.5^(epoch / args.epoch_step), a continuous decay
        new_lr = args.lr * (0.5 ** (epoch_idx / args.epoch_step))
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lr

    t0 = time.time()

    for epoch in range(args.continueEpoch, args.epoch):
        adjust_lr(optimizerSR, epoch)
        adjust_lr(optimizerSRC, epoch)

        epoch_loss1 = 0.
        epoch_loss2 = 0.
        step_count = 0

        for i, (lr_batch, hr_batch) in enumerate(train_loader):
            # loader adds a leading dim of 1: (1, batch_size, 1, H, W)
            lr_batch = lr_batch.squeeze(0)
            hr_batch = hr_batch.squeeze(0)

            l1, l2 = train_one_step(generatorSR, generatorSRC,
                                    optimizerSR, optimizerSRC,
                                    lr_batch, hr_batch,
                                    device=device)
            epoch_loss1 += l1
            epoch_loss2 += l2
            step_count += 1

            if (i+1) % args.print_freq == 0:
                print(f"Epoch [{epoch+1}/{args.epoch}] Iter [{i+1}/{len(train_loader)}] "
                      f"Loss1: {l1:.4f}, Loss2: {l2:.4f}, Time: {time.time()-t0:.1f}s")

        # guard against an empty epoch (itersPerEpoch == 0)
        n_steps = max(step_count, 1)
        epoch_loss1 /= n_steps
        epoch_loss2 /= n_steps
        print(f"=== End Epoch {epoch+1} | Loss1: {epoch_loss1:.4f}, Loss2: {epoch_loss2:.4f} ===")

        if (epoch+1) % args.save_freq == 0:
            os.makedirs(args.checkpoint_dir, exist_ok=True)
            sr_path = os.path.join(args.checkpoint_dir, f"generatorSR_{epoch+1}.pth")
            src_path = os.path.join(args.checkpoint_dir, f"generatorSRC_{epoch+1}.pth")
            # save the unwrapped modules so keys carry no "module." prefix
            torch.save(_unwrap_model(generatorSR).state_dict(), sr_path)
            torch.save(_unwrap_model(generatorSRC).state_dict(), src_path)
            print(f"Saved checkpoint epoch {epoch+1}")
