In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torchvision.io import read_image
import os
import h5py
import random

In [2]:
class FeatureExtractor(nn.Module):
    """Small CNN that maps an image stack to one scalar score per sample.

    Two ``conv3x3 -> ReLU -> maxpool2`` stages shrink the spatial size by 4,
    then a linear layer maps the flattened features to a single value.

    Args:
        in_channels: number of input channels.
        dim: number of convolutional feature maps.
        img_size: side length (pixels) of the square input images; the linear
            layer is sized for an encoded map of ``img_size // 4`` per side.
    """

    def __init__(self, in_channels=1, dim=64, img_size=100):
        super(FeatureExtractor, self).__init__()
        self.dim = dim
        self.img_size = img_size
        # Keep module names / Sequential indices stable: saved state_dicts
        # (e.g. the pretrained weightnet checkpoint) are loaded by key.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, dim, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(dim, dim, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.fc = nn.Linear(dim * (img_size // 4)**2, 1)

    def forward(self, x):
        """Score ``x`` of shape (batch, in_channels, H, W); returns (batch, 1, 1).

        Raises:
            ValueError: if the encoded map is not ``img_size // 4`` per side.
                Previously such inputs were silently folded into the batch
                dimension by ``view(-1, ...)``.
        """
        x = self.encoder(x)  # (batch, dim, H // 4, W // 4)
        side = self.img_size // 4
        encoded = tuple(x.shape[-2:])
        if encoded != (side, side):
            raise ValueError(
                f"encoded feature map must be {side}x{side} (img_size={self.img_size}), "
                f"got {encoded[0]}x{encoded[1]}")
        x = x.view(-1, self.dim * side ** 2)
        out = self.fc(x)
        return out.unsqueeze(-1)
        
        

class ImageWeights(nn.Module):
    """Blend an S2 time series into one image using learned per-date weights.

    Each date of ``s2_series`` is scored against the reference image ``l8co``
    by a ``FeatureExtractor`` that sees the channel stack ``[ref, ref - s2_t]``.
    Both inputs are first downsampled to ``img_size`` when they are larger.
    Masked dates get a score of ``-inf``. The scores are softmax-normalised
    over time and used as weights in a sum over the series.

    Args:
        in_dim: channels per image (the scorer receives ``2 * in_dim``).
        dim: feature width of the scorer CNN.
        img_size: side length (pixels) of the square images the scorer sees.
        device: device on which the score tensor is allocated.
    """

    def __init__(self, in_dim, dim, img_size, device):
        super(ImageWeights, self).__init__()
        self.device = device
        self.img_size = img_size
        # The attribute name is part of the state_dict keys of saved checkpoints.
        self.Fextractor = FeatureExtractor(in_channels=in_dim * 2, dim=dim, img_size=img_size)

    def forward(self, l8co, s2_series, mask):
        """Return ``(blended, weight_softmax, weight)``.

        Args:
            l8co: reference image, (batch, channel, height, width), assumed square.
            s2_series: (batch, time, channel, height, width), assumed square frames.
            mask: numpy bool array (batch, time); True marks dates to ignore.

        Returns:
            blended: softmax-weighted sum of ``s2_series`` over time.
            weight_softmax: normalised weights expanded to ``s2_series``'s shape.
            weight: raw scores (``-inf`` where masked), expanded the same way.

        NOTE(review): a sample whose dates are all masked gets NaN weights
        (softmax over all ``-inf``).
        """
        n_batch, n_time = s2_series.shape[0], s2_series.shape[1]
        frame_shape = s2_series.shape[-3:]
        ref_ratio = l8co.shape[-1] / self.img_size
        series_ratio = s2_series.shape[-1] / self.img_size

        # Only shrink inputs that are larger than img_size.
        ref = (torch.nn.functional.interpolate(l8co, scale_factor=1.0 / ref_ratio, mode='bicubic')
               if ref_ratio > 1 else l8co)
        if series_ratio > 1:
            dims = s2_series.shape
            # interpolate needs 4D input: fold time into the batch axis and back.
            folded = s2_series.reshape((dims[0] * dims[1], *dims[2:]))
            folded = torch.nn.functional.interpolate(folded, scale_factor=1.0 / series_ratio, mode='bicubic')
            series = folded.reshape((*dims[:3], int(dims[-2] / series_ratio), int(dims[-1] / series_ratio)))
        else:
            series = s2_series

        weight = torch.zeros((n_batch, n_time, 1, 1)).to(self.device)
        for t in range(n_time):
            skip = mask[:, t]
            if np.all(skip == True):
                # Every sample is masked at this date: no need to run the scorer.
                weight[:, t, ...].fill_(-torch.inf)
                continue
            pair = torch.cat((ref, ref - series[:, t, :, :, :]), dim=1)
            weight[:, t, ...] = self.Fextractor(pair)
            weight[:, t, ...][skip] = -torch.inf

        weight_softmax = torch.softmax(weight, dim=1)
        target = (-1, -1, *frame_shape)
        weight = weight.unsqueeze(-1).expand(*target)
        weight_softmax = weight_softmax.unsqueeze(-1).expand(*target)
        return (s2_series * weight_softmax).sum(1), weight_softmax, weight

In [3]:
class SpatiotemporalFusionNetwork(nn.Module):
    """DCSTFN-style fusion network predicting the fine image at date 0.

    The prediction is ``fusion(E(l8_0) + H(s2_k) - E(l8_k))``. Here ``E`` (the
    shared network) upsamples coarse single-band images 4x and ``H`` extracts
    features from the fine reference ``s2_k`` at its own resolution.

    Args:
        d0, d1, d2: channel widths of the successive feature stages.
    """

    @staticmethod
    def _stack(*layers):
        """Return a Sequential in which every given layer is followed by an in-place ReLU."""
        modules = []
        for layer in layers:
            modules.append(layer)
            modules.append(nn.ReLU(inplace=True))
        return nn.Sequential(*modules)

    def __init__(self, d0, d1, d2):
        super(SpatiotemporalFusionNetwork, self).__init__()
        # Coarse branch: two 3x3 convs, then two stride-2 transposed convs (x4 upsampling).
        self.shared_network = self._stack(
            nn.Conv2d(1, d0, kernel_size=3, padding=1),
            nn.Conv2d(d0, d1, kernel_size=3, padding=1),
            nn.ConvTranspose2d(d1, d1, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ConvTranspose2d(d1, d2, kernel_size=3, stride=2, padding=1, output_padding=1),
        )
        # Fine branch: four size-preserving 3x3 convs.
        self.high_freq_extract = self._stack(
            nn.Conv2d(1, d0, kernel_size=3, padding=1),
            nn.Conv2d(d0, d1, kernel_size=3, padding=1),
            nn.Conv2d(d1, d1, kernel_size=3, padding=1),
            nn.Conv2d(d1, d2, kernel_size=3, padding=1),
        )
        # Fusion head: 3x3 conv then two 1x1 convs down to one non-negative channel.
        self.fusion = self._stack(
            nn.Conv2d(d2, d1, kernel_size=3, padding=1),
            nn.Conv2d(d1, d0, kernel_size=1, stride=1, padding=0),
            nn.Conv2d(d0, 1, kernel_size=1, stride=1, padding=0),
        )

    def forward(self, l8_0, l8_k, s2_k):
        """Predict ``s2_0`` from coarse ``l8_0``/``l8_k`` and fine ``s2_k``.

        The coarse inputs must be a quarter of ``s2_k``'s spatial size so that
        the x4-upsampled features line up with the fine features.
        """
        fine_ref = self.high_freq_extract(s2_k)
        coarse_now = self.shared_network(l8_0)
        coarse_ref = self.shared_network(l8_k)
        return self.fusion(coarse_now + fine_ref - coarse_ref)


In [5]:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms

In [6]:
class cls_dataset(Dataset):
    """In-memory dataset of Landsat-8 / Sentinel-2 time-series patches.

    For every location ``iloc`` in ``range(num1, num2)``, the HDF5 file
    ``hdf5file_path_list[iloc - num1]`` is processed as follows:

    - The coincident images (``l8co``, ``l8panco``, ``s2co``) are read and
      inserted into their series (``l8``, ``l8pan``, ``s2``) at the position
      looked up from ``loc_l8`` / ``loc_s2``.
    - The first 12 S2 bands are scaled by ``S2_SCALE``.
    - Shorter series are zero-padded along time to ``L8_MIN_LEN`` / ``S2_MIN_LEN``.
    - NaNs in the L8 series are set to 0.
    - One sample dict is stored per patch.

    Args:
        hdf5file_path_list: one HDF5 path per location, in ``num1..num2-1`` order.
        loc_l8, loc_s2: optional insertion-position tables, indexed by
            ``iloc - 1``. When ``None``, the file's ``"l8_pos"`` / ``"s2_pos"``
            datasets are used instead.
        num1: first location number (required).
        num2: end (exclusive) of the location range; defaults to
            ``num1 + len(hdf5file_path_list)``.
    """

    L8_MIN_LEN = 6     # L8 series shorter than this are zero-padded along time
    S2_MIN_LEN = 12    # S2 series shorter than this are zero-padded along time
    S2_SCALE = 1e-4    # applied to the first 12 S2 bands (offset 0)

    def __init__(self, hdf5file_path_list, loc_l8=None, loc_s2=None, num1=None, num2=None) -> None:
        super().__init__()
        if num1 is None:
            raise TypeError("cls_dataset() missing required argument: 'num1'")
        if num2 is None:
            num2 = num1 + len(hdf5file_path_list)

        self.set = []

        for iloc in range(num1, num2):
            # The context manager closes the file; np.array copies every dataset into memory.
            with h5py.File(hdf5file_path_list[iloc - num1], "r") as hdf5_file:
                samples_l8co = torch.Tensor(np.array(hdf5_file["l8co"]))
                samples_l8panco = torch.Tensor(np.array(hdf5_file["l8panco"]))
                samples_s2co = torch.Tensor(np.array(hdf5_file["s2co"]))
                samples_l8 = torch.Tensor(np.array(hdf5_file["l8"]))
                samples_l8pan = torch.Tensor(np.array(hdf5_file["l8pan"]))
                samples_s2 = torch.Tensor(np.array(hdf5_file["s2"]))
                # Explicit tables take precedence; otherwise use this file's own
                # positions (L8 from "l8_pos", S2 from "s2_pos").
                pos_l8 = loc_l8 if loc_l8 is not None else np.array(hdf5_file["l8_pos"])
                pos_s2 = loc_s2 if loc_s2 is not None else np.array(hdf5_file["s2_pos"])

            ins_l8 = self._position(pos_l8, iloc)
            ins_s2 = self._position(pos_s2, iloc)
            patch_num = samples_l8co.shape[0]

            # Put the coincident image back into each series at its date position.
            l8 = self._insert_frame(samples_l8, samples_l8co, ins_l8)
            l8pan = self._insert_frame(samples_l8pan, samples_l8panco, ins_l8)
            s2 = self._insert_frame(samples_s2, samples_s2co, ins_s2)
            series_len_l8 = l8.shape[1]
            series_len_s2 = s2.shape[1]

            # L8 values are used as stored (offset 0, scale 1.0); S2 is rescaled.
            s2[:, :, :12] = s2[:, :, :12] * self.S2_SCALE

            l8_series = self._pad_time(l8, self.L8_MIN_LEN)
            l8pan_series = self._pad_time(l8pan, self.L8_MIN_LEN)
            s2_series = self._pad_time(s2, self.S2_MIN_LEN)

            # Replace NaN values with 0.
            # NOTE(review): only the L8 series is cleaned; l8pan and s2 keep NaNs — confirm intended.
            l8_series[torch.isnan(l8_series)] = 0

            for idata in range(patch_num):
                self.set.append({
                    'l8': l8_series[idata],
                    'l8pan': l8pan_series[idata],
                    's2': s2_series[idata],
                    'series_len_l8': series_len_l8,
                    'series_len_s2': series_len_s2,
                    'loc_l8': ins_l8,
                    'loc_s2': ins_s2
                })

    @staticmethod
    def _position(table, iloc):
        """Look up the insertion index for ``iloc`` (``table[iloc - 1]``); a scalar table is used as-is."""
        table = np.asarray(table)
        return table[()] if table.ndim == 0 else table[iloc - 1]

    @staticmethod
    def _insert_frame(series, frame, pos):
        """Insert ``frame`` (patch, ...) into ``series`` (patch, time, ...) at time index ``pos``."""
        return torch.cat((series[:, :pos], frame.unsqueeze(1), series[:, pos:]), dim=1)

    @staticmethod
    def _pad_time(series, min_len):
        """Zero-pad ``series`` (patch, time, ...) along time up to ``min_len`` frames."""
        missing = min_len - series.shape[1]
        if missing <= 0:
            return series
        zeros = torch.zeros((series.shape[0], missing, *series.shape[2:]), dtype=series.dtype)
        return torch.cat((series, zeros), dim=1)

    def __getitem__(self, index):
        return self.set[index]

    def __len__(self):
        return len(self.set)

In [7]:
def load_datasets(num1, num2, root='D:/SR/data/segmented100_v2/', loc_l8=None, loc_s2=None):
    """Build a ``cls_dataset`` from the files ``loc{i}.hdf5`` for i in ``range(num1, num2)``.

    Args:
        num1, num2: half-open range of location numbers to load.
        root: directory containing the ``loc{i}.hdf5`` files.
        loc_l8, loc_s2: insertion-position tables forwarded unchanged to
            ``cls_dataset``.

    Returns:
        The dataset; its length is printed as a progress aid.
    """
    path_list = [os.path.join(root, f"loc{iloc}.hdf5") for iloc in range(num1, num2)]

    data_set = cls_dataset(hdf5file_path_list=path_list, loc_l8=loc_l8, loc_s2=loc_s2,
                           num1=num1, num2=num2)

    print(len(data_set))

    return data_set

## Training variant: offset L8 (the L8 reference comes from an adjacent date instead of the coincident one)

In [None]:
# Train the fusion network with an *offset* L8 reference. The coarse input is
# the L8 frame from the date after the target (or before it, at the end of the
# series); the target is the coincident S2 frame.
# NOTE(review): `data_set` is not created in any visible cell; build it first
# (e.g. with `load_datasets(...)`) so this cell survives Restart & Run All.

# Hyperparameters
batch_size = 16
learning_rate = 0.0001
num_epochs = 50
weight_decay = 0.00001


data_loader = DataLoader(data_set, batch_size=batch_size, shuffle=True)

# Initialize the fusion network
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('using device: ', device)
network = SpatiotemporalFusionNetwork(d0=32, d1=64, d2=128)
network.to(device)

# Pretrained temporal-weighting network. It is only used for inference here:
# eval mode, no_grad at the call site, and its parameters are not in the optimizer.
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
weightnet = ImageWeights(in_dim=7, dim=32, img_size=100, device=device)
weightnet.load_state_dict(torch.load('D:/SR/baseline/DCSTFN/model/weightnet/indim7dim32_bs16lr2e-3wd3e-3_epoch100_v6tr1_falseimgs_zerorandommagnifywrongloc_falseimageloss_singleweight/Weightnet.pth', map_location=device))
weightnet.to(device)
weightnet.eval()

# Loss function and optimizer
criterion = nn.MSELoss()  # Mean squared error loss
optimizer = optim.Adam(network.parameters(), lr=learning_rate, weight_decay=weight_decay)

s2_seven_bands = [0, 1, 2, 3, 8, 10, 11]  # S2 bands fed to weightnet (in_dim=7)
target_band = [2]
target_band_s2 = [2]

# Training loop
for epoch in range(num_epochs):
    epoch_loss, n_batches = 0.0, 0
    for data_dict in data_loader:
        l8, s2 = data_dict['l8'][:, :, target_band], data_dict['s2'][:, :, target_band_s2]
        l8_fullband, s2_sevenband = data_dict['l8'], data_dict['s2'][:, :, s2_seven_bands]

        l8, s2 = l8.to(device), s2.to(device)
        l8_fullband, s2_sevenband = l8_fullband.to(device), s2_sevenband.to(device)

        l8_time_num, s2_time_num = data_dict['series_len_l8'], data_dict['series_len_s2']
        l8_loc, s2_loc = data_dict['loc_l8'], data_dict['loc_s2']

        shape_l8, shape_s2 = l8.shape, s2.shape
        # The last batch can be smaller than batch_size (drop_last=False),
        # so iterate over the actual number of samples.
        n_samples = shape_l8[0]

        # Flatten (batch, time) so frames can be gathered with flat indices.
        l8 = l8.reshape((shape_l8[0]*shape_l8[1], shape_l8[2], shape_l8[3], shape_l8[4]))
        s2 = s2.reshape((shape_s2[0]*shape_s2[1], shape_s2[2], shape_s2[3], shape_s2[4]))
        ind_l8 = []
        for i in range(n_samples):
            if l8_loc[i] < l8_time_num[i]-1:
                ind_l8.append(int(i*shape_l8[1] + l8_loc[i] + 1))
            else:
                ind_l8.append(int(i*shape_l8[1] + l8_loc[i] - 1))
        ind_s2 = [int(i*shape_s2[1] + s2_loc[i]) for i in range(n_samples)]
        l8_co = l8[ind_l8]
        s2_co = s2[ind_s2]

        l8_fullband_co = l8_fullband.reshape((shape_l8[0]*shape_l8[1], l8_fullband.shape[2], shape_l8[3], shape_l8[4]))[ind_l8]
        # Weightnet mask: padding frames and the ground-truth S2 frame are excluded.
        s2_series_mask = np.zeros((shape_s2[0], shape_s2[1]), dtype=bool)
        for ibatch in range(n_samples):
            s2_series_mask[ibatch, s2_time_num[ibatch]:] = True     # mask empty images
            s2_series_mask[ibatch, s2_loc[ibatch]] = True   # mask gt s2 image
        with torch.no_grad():
            s2_p, s2_weightsm, _ = weightnet(l8_fullband_co, s2_sevenband, s2_series_mask)
        s2_neighbor = s2_p[:, target_band, ...]

        # Bring both coarse inputs onto the same grid (S2 x 75/300, L8 x 75/100).
        l8_neighbor = torch.nn.functional.interpolate(s2_neighbor, scale_factor=75.0/300.0, mode='bicubic')
        l8_co = torch.nn.functional.interpolate(l8_co, scale_factor=75.0/100.0, mode='bicubic')

        s2_p = network(l8_co, l8_neighbor, s2_neighbor)
        loss = criterion(s2_co, s2_p)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        n_batches += 1

    # Mean loss over the epoch (the last batch alone is a noisy estimate).
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss / max(n_batches, 1):.7f}')

# Save the trained model
folder_path = './model/fullres/timeseriess2/l8offset_d32d64d128_bs16lr1e-4wd1e-5_epoch50'
os.makedirs(folder_path, exist_ok=True)
network_path = os.path.join(folder_path, 'band{}.pth'.format(target_band))
torch.save(network.state_dict(), network_path)

## Training variant: real L8 (the coincident L8 image is the coarse reference)

In [None]:
# Train the fusion network with the *coincident* L8 frame as the coarse input;
# the target is the coincident S2 frame.
# NOTE(review): `data_set` is not created in any visible cell; build it first
# (e.g. with `load_datasets(...)`) so this cell survives Restart & Run All.

# Hyperparameters
batch_size = 16
learning_rate = 0.0001
num_epochs = 50
weight_decay = 0.00001


data_loader = DataLoader(data_set, batch_size=batch_size, shuffle=True)

# Initialize the fusion network
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('using device: ', device)
network = SpatiotemporalFusionNetwork(d0=32, d1=64, d2=128)
network.to(device)

# Pretrained temporal-weighting network. It is only used for inference here:
# eval mode, no_grad at the call site, and its parameters are not in the optimizer.
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
weightnet = ImageWeights(in_dim=7, dim=32, img_size=100, device=device)
weightnet.load_state_dict(torch.load('D:/SR/baseline/DCSTFN/model/weightnet/indim7dim32_bs16lr2e-3wd3e-3_epoch100_v6tr1_falseimgs_zerorandommagnifywrongloc_falseimageloss_singleweight/Weightnet.pth', map_location=device))
weightnet.to(device)
weightnet.eval()

# Loss function and optimizer
criterion = nn.MSELoss()  # Mean squared error loss
optimizer = optim.Adam(network.parameters(), lr=learning_rate, weight_decay=weight_decay)

s2_seven_bands = [0, 1, 2, 3, 8, 10, 11]  # S2 bands fed to weightnet (in_dim=7)
target_band = [2]
target_band_s2 = [2]

# Training loop
for epoch in range(num_epochs):
    epoch_loss, n_batches = 0.0, 0
    for data_dict in data_loader:
        l8, s2 = data_dict['l8'][:, :, target_band], data_dict['s2'][:, :, target_band_s2]
        l8_fullband, s2_sevenband = data_dict['l8'], data_dict['s2'][:, :, s2_seven_bands]

        l8, s2 = l8.to(device), s2.to(device)
        l8_fullband, s2_sevenband = l8_fullband.to(device), s2_sevenband.to(device)

        l8_time_num, s2_time_num = data_dict['series_len_l8'], data_dict['series_len_s2']
        l8_loc, s2_loc = data_dict['loc_l8'], data_dict['loc_s2']

        shape_l8, shape_s2 = l8.shape, s2.shape
        # The last batch can be smaller than batch_size (drop_last=False),
        # so iterate over the actual number of samples.
        n_samples = shape_l8[0]

        # Flatten (batch, time) so frames can be gathered with flat indices.
        l8 = l8.reshape((shape_l8[0]*shape_l8[1], shape_l8[2], shape_l8[3], shape_l8[4]))
        s2 = s2.reshape((shape_s2[0]*shape_s2[1], shape_s2[2], shape_s2[3], shape_s2[4]))
        ind_l8 = [int(i*shape_l8[1] + l8_loc[i]) for i in range(n_samples)]
        ind_s2 = [int(i*shape_s2[1] + s2_loc[i]) for i in range(n_samples)]
        l8_co = l8[ind_l8]
        s2_co = s2[ind_s2]

        l8_fullband_co = l8_fullband.reshape((shape_l8[0]*shape_l8[1], l8_fullband.shape[2], shape_l8[3], shape_l8[4]))[ind_l8]
        # Weightnet mask: padding frames and the ground-truth S2 frame are excluded.
        s2_series_mask = np.zeros((shape_s2[0], shape_s2[1]), dtype=bool)
        for ibatch in range(n_samples):
            s2_series_mask[ibatch, s2_time_num[ibatch]:] = True     # mask empty images
            s2_series_mask[ibatch, s2_loc[ibatch]] = True   # mask gt s2 image
        with torch.no_grad():
            s2_p, s2_weightsm, _ = weightnet(l8_fullband_co, s2_sevenband, s2_series_mask)
        s2_neighbor = s2_p[:, target_band, ...]

        # Bring both coarse inputs onto the same grid (S2 x 75/300, L8 x 75/100).
        l8_neighbor = torch.nn.functional.interpolate(s2_neighbor, scale_factor=75.0/300.0, mode='bicubic')
        l8_co = torch.nn.functional.interpolate(l8_co, scale_factor=75.0/100.0, mode='bicubic')

        s2_p = network(l8_co, l8_neighbor, s2_neighbor)
        loss = criterion(s2_co, s2_p)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        n_batches += 1

    # Mean loss over the epoch (the last batch alone is a noisy estimate).
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss / max(n_batches, 1):.7f}')

# Save the trained model
folder_path = './model/fullres/timeseriess2/d32d64d128_bs16lr1e-4wd1e-5_epoch50'
os.makedirs(folder_path, exist_ok=True)
network_path = os.path.join(folder_path, 'band{}.pth'.format(target_band))
torch.save(network.state_dict(), network_path)

## Training variant: downsampled S2 (bicubically downsampled S2 images stand in for both coarse L8 inputs)

In [None]:
# Train the fusion network on downsampled S2 only. The coarse "current" input is
# the weightnet composite (ground-truth date masked) downsampled to the coarse
# grid; the coarse "neighbour" input is the neighbouring S2 frame downsampled
# the same way.
# NOTE(review): `data_set` is not created in any visible cell; build it first
# (e.g. with `load_datasets(...)`) so this cell survives Restart & Run All.

# Hyperparameters
batch_size = 16
learning_rate = 0.0001
num_epochs = 50
weight_decay = 0.00001


data_loader = DataLoader(data_set, batch_size=batch_size, shuffle=True)

# Initialize the fusion network
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('using device: ', device)
network = SpatiotemporalFusionNetwork(d0=32, d1=64, d2=128)
network.to(device)

# Pretrained temporal-weighting network. It is only used for inference here:
# eval mode, no_grad at the call site, and its parameters are not in the optimizer.
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
weightnet = ImageWeights(in_dim=7, dim=32, img_size=100, device=device)
weightnet.load_state_dict(torch.load('D:/SR/baseline/DCSTFN/model/weightnet/indim7dim32_bs16lr2e-3wd3e-3_epoch100_v6tr1_falseimgs_zerorandommagnifywrongloc_falseimageloss_singleweight/Weightnet.pth', map_location=device))
weightnet.to(device)
weightnet.eval()

# Loss function and optimizer
criterion = nn.MSELoss()  # Mean squared error loss
# criterion = nn.L1Loss()  # Mean absolute error loss
optimizer = optim.Adam(network.parameters(), lr=learning_rate, weight_decay=weight_decay)

s2_seven_bands = [0, 1, 2, 3, 8, 10, 11]  # S2 bands fed to weightnet (in_dim=7)
target_band = [2]
target_band_s2 = [2]

# Training loop
for epoch in range(num_epochs):
    epoch_loss, n_batches = 0.0, 0
    for data_dict in data_loader:
        l8, s2 = data_dict['l8'][:, :, target_band], data_dict['s2'][:, :, target_band_s2]
        l8_fullband, s2_sevenband = data_dict['l8'], data_dict['s2'][:, :, s2_seven_bands]

        l8, s2 = l8.to(device), s2.to(device)
        l8_fullband, s2_sevenband = l8_fullband.to(device), s2_sevenband.to(device)

        l8_time_num, s2_time_num = data_dict['series_len_l8'], data_dict['series_len_s2']
        l8_loc, s2_loc = data_dict['loc_l8'], data_dict['loc_s2']

        shape_l8, shape_s2 = l8.shape, s2.shape
        # The last batch can be smaller than batch_size (drop_last=False),
        # so iterate over the actual number of samples.
        n_samples = shape_s2[0]

        # Flatten (batch, time) so frames can be gathered with flat indices.
        s2 = s2.reshape((shape_s2[0]*shape_s2[1], shape_s2[2], shape_s2[3], shape_s2[4]))
        # Coincident L8 frame of each sample. It must be computed here: reusing
        # `ind_l8` from another cell would pick frames of unrelated samples.
        ind_l8 = [int(i*shape_l8[1] + l8_loc[i]) for i in range(n_samples)]
        ind_s2 = [int(i*shape_s2[1] + s2_loc[i]) for i in range(n_samples)]
        # Neighbouring S2 date: the next valid frame, or the previous one when the
        # target is the last valid frame (loc + 1 would hit zero padding or run
        # into the next sample).
        ind_s2_neighbor = []
        for i in range(n_samples):
            if s2_loc[i] < s2_time_num[i] - 1:
                ind_s2_neighbor.append(int(i*shape_s2[1] + s2_loc[i] + 1))
            else:
                ind_s2_neighbor.append(int(i*shape_s2[1] + s2_loc[i] - 1))
        s2_co = s2[ind_s2]

        l8_fullband_co = l8_fullband.reshape((shape_l8[0]*shape_l8[1], l8_fullband.shape[2], shape_l8[3], shape_l8[4]))[ind_l8]
        # Weightnet mask: padding frames and the ground-truth S2 frame are excluded.
        s2_series_mask = np.zeros((shape_s2[0], shape_s2[1]), dtype=bool)
        for ibatch in range(n_samples):
            s2_series_mask[ibatch, s2_time_num[ibatch]:] = True     # mask empty images
            s2_series_mask[ibatch, s2_loc[ibatch]] = True   # mask gt s2 image
        with torch.no_grad():
            s2_p, s2_weightsm, _ = weightnet(l8_fullband_co, s2_sevenband, s2_series_mask)
        s2_fakegt = s2_p[:, target_band, ...]

        # Both coarse inputs are S2 images downsampled by 75/300.
        l8_co = torch.nn.functional.interpolate(s2_fakegt, scale_factor=75.0/300.0, mode='bicubic')
        s2_neighbor = s2[ind_s2_neighbor]
        l8_neighbor = torch.nn.functional.interpolate(s2_neighbor, scale_factor=75.0/300.0, mode='bicubic')

        s2_p = network(l8_co, l8_neighbor, s2_neighbor)
        loss = criterion(s2_co, s2_p)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        n_batches += 1

    # Mean loss over the epoch (the last batch alone is a noisy estimate).
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss / max(n_batches, 1):.7f}')

# Save the trained model
folder_path = './model/fullres/timeseriess2/l8downsamp_d32d64d128_bs16lr1e-4wd1e-5_epoch50'
os.makedirs(folder_path, exist_ok=True)
network_path = os.path.join(folder_path, 'band{}.pth'.format(target_band))
torch.save(network.state_dict(), network_path)

In [50]:
# Load one trained fusion network per output band for inference.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Training variant to evaluate (folders written by the training cells above):
# folder_path = './model/fullres/timeseriess2/d32d64d128_bs16lr1e-4wd1e-5_epoch50'            # real L8
# folder_path = './model/fullres/timeseriess2/l8offset_d32d64d128_bs16lr1e-4wd1e-5_epoch50'   # offset L8
folder_path = './model/fullres/timeseriess2/l8downsamp_d32d64d128_bs16lr1e-4wd1e-5_epoch50'   # downsampled S2

# Alternative band set: [[6], [5], [4]]
# Order matters: the inference cell pairs testmodel_list[i] with RGB[i].
testmodel_list = []
for target_band in [[3], [2], [1]]:
    network_path = os.path.join(folder_path, 'band{}.pth'.format(target_band))

    testmodel = SpatiotemporalFusionNetwork(d0=32, d1=64, d2=128)
    # map_location lets GPU-saved checkpoints load on a CPU-only machine.
    testmodel.load_state_dict(torch.load(network_path, map_location=device))
    testmodel.to(device)
    testmodel.eval()
    testmodel_list.append(testmodel)

In [44]:
import numpy as np
import h5py
import os
import matplotlib.pyplot as plt
import cv2

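## Inference (original L8 inputs): predict the R, G, B bands with the loaded models and save one PNG per patch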

In [None]:
# For each test location, predict the R, G, B bands with the per-band models
# in testmodel_list (real coincident L8 + neighbouring S2) and save 8-bit PNGs.
s2_seven_bands = [0, 1, 2, 3, 8, 10, 11]
RGB = [[3], [2], [1]]      # L8 band predicted by testmodel_list[i], in output channel order
RGB_s2 = [[3], [2], [1]]   # matching S2 band

# output_folder_l8_p_SR = 'D:/SR/result/DCSTFN_fullres_timeseriess2/DCSTFN_l8s2_ifr'
# output_folder_l8_p_SR = 'D:/SR/result/DCSTFN_fullres_timeseriess2/DCSTFN_nl8s2_ifr'
# output_folder_l8_p_SR = 'D:/SR/result/DCSTFN_fullres_timeseriess2/DCSTFN_intrpl8s2_ifr'
output_folder_l8_p_SR = 'D:/SR/result/DCSTFN_fullres_timeseriess2/DCSTFN_ds2s2_ifr'
if not os.path.exists(output_folder_l8_p_SR):
    os.makedirs(output_folder_l8_p_SR)
    print(f"Folder '{output_folder_l8_p_SR}' created.")

with torch.no_grad():
    for iloc in range(70, 101):
        print(f'iloc{iloc}')
        test_dataset = load_datasets(iloc, iloc+1)
        test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)  # keep file order: iimg names the output
        for iimg, data_dict in enumerate(test_loader):
            l8_loc, s2_loc = data_dict['loc_l8'], data_dict['loc_s2']
            s2_time_num = data_dict['series_len_s2']
            # Neighbouring S2 date: the next valid frame, or the previous one when
            # the target is the last valid frame (loc + 1 would hit zero padding
            # or run past the series).
            s2_nb = int(s2_loc[0]) + 1
            if s2_nb >= int(s2_time_num[0]):
                s2_nb = int(s2_loc[0]) - 1

            img_channels = []
            for imodel, target_band in enumerate(RGB):
                target_band_s2 = RGB_s2[imodel]
                l8 = data_dict['l8'][:, :, target_band].to(device)
                s2 = data_dict['s2'][:, :, target_band_s2].to(device)

                # batch_size is 1, so flattening (batch, time) leaves (time, C, H, W).
                l8 = l8.reshape((-1, *l8.shape[2:]))
                s2 = s2.reshape((-1, *s2.shape[2:]))
                l8_co = l8[l8_loc]                     # (1, C, H, W)
                s2_neighbor = s2[s2_nb].unsqueeze(0)   # (1, C, H, W)
                # Bring both coarse inputs onto the same grid (S2 x 75/300, L8 x 75/100).
                l8_neighbor = torch.nn.functional.interpolate(s2_neighbor, scale_factor=75.0/300.0, mode='bicubic')
                l8_co = torch.nn.functional.interpolate(l8_co, scale_factor=75.0/100.0, mode='bicubic')

                l8_p_SR = testmodel_list[imodel](l8_co, l8_neighbor, s2_neighbor)
                l8_p_SR = l8_p_SR.cpu().numpy().squeeze(0)   # (1, H, W)
                img_channels.append((l8_p_SR * 255).clip(0, 255))

            # (H, W, 3), channels in RGB-list order.
            img_l8_p_SR = np.concatenate(img_channels, axis=0).transpose((1, 2, 0)).astype(np.uint8)
            # NOTE(review): cv2.imwrite treats the last axis as B, G, R, so with
            # RGB-ordered channels the saved PNG shows red and blue swapped. This
            # is left unchanged to stay comparable with earlier outputs; apply
            # cv2.cvtColor(img, cv2.COLOR_RGB2BGR) if true-colour files are wanted.
            output_path_l8_p_SR = os.path.join(output_folder_l8_p_SR, f'loc{iloc}_{iimg}.png')
            cv2.imwrite(output_path_l8_p_SR, img_l8_p_SR)