In [9]:
from Code._2DGS import *
device = torch.device("cuda")

In [10]:
data_dir = ''
data_fname = '2DGS/2DGS_IC'

# define the initial conditions    
ICs = np.arange(1, 21)
data_loader = GSDataset(data_dir, data_fname, ICs)
n_datasets = data_loader.__len__()

# get mean and std
data = data_loader[0][1]
total_hres = torch.zeros(len(data_loader), data_loader[0][1].shape[0],  data_loader[0][1].shape[1],  data_loader[0][1].shape[2],  data_loader[0][1].shape[3])
total_lres = torch.zeros(len(data_loader), data_loader[0][0].shape[0],  data_loader[0][0].shape[1],  data_loader[0][0].shape[2],  data_loader[0][0].shape[3]) # [b,t,c,h,w]

for i in range(len(data_loader)):
    total_hres[i,...] = data_loader[i][1]
    total_lres[i,...] = data_loader[i][0]

mean_hres = torch.mean(total_hres, axis = (0,1,3,4)) 
std_hres = torch.std(total_hres, axis = (0,1,3,4))

# split data
split_ratio = [int(n_datasets*0.7), int(n_datasets*0.2), n_datasets - int(n_datasets*0.7) - int(n_datasets*0.2)]

train_data, val_data, test_data = torch.utils.data.random_split(data_loader, split_ratio)

# change to pytorch data
# data in train_loader is [b, t, c, h, w] -> [1, 151, 2, 32, 32]
train_loader = torch.utils.data.DataLoader(train_data, batch_size = 2, 
    shuffle=True, num_workers=0) 

val_loader = torch.utils.data.DataLoader(val_data, batch_size = 2, 
    shuffle=False, num_workers=0)    

test_loader = torch.utils.data.DataLoader(test_data, batch_size = 2, 
    shuffle=False, num_workers=0)

In [32]:
######################### build model #############################
# training parameters
n_iters = 2000 
learning_rate = 1e-3
print_every = 2   
dt = 1.0*5*2 
dx = 1.0  
steps = 100 
effective_step = list(range(0, steps))

beta = 0.025 # for physics loss        
print('Super-Resolution for 2D GS equation...')

model = PhySR(
    n_feats = 64,
    n_layers = [1, 2], # [n_convlstm, n_resblock]
    upscale_factor = [4, 8], # [t_up, s_up]
    shift_mean_paras = [mean_hres, std_hres],  
    step = steps,
    effective_step = effective_step).to(device)

init_state = get_init_state(
        batch_size = [2], 
        hidden_channels = [64], 
        output_size = [[16, 16]],
        mode = 'random')

model, _, _ = load_checkpoint(model, optimizer=None, scheduler=None, 
            save_dir='checkpoint_full.pt') 

Super-Resolution for 2D GS equation...
Pretrained model loaded!


In [83]:
for idx, (lres, hres) in enumerate(train_loader):
    lres = lres.transpose(0,1)
    break

In [84]:
lres.shape

torch.Size([25, 2, 2, 16, 16])

In [90]:
x = total_lres[0:1].transpose(0,1).to(device)
y = model(x, init_state)
print(x.shape, y.shape)

torch.Size([25, 1, 2, 16, 16]) torch.Size([100, 2, 2, 128, 128])


In [85]:
from matplotlib import pyplot as plt

In [92]:
fig, ax = plt.subplots(1,2, figsize=(10,5))
ax[0].imshow(y[0][0][0].cpu().detach().numpy())
ax[1].imshow(y[0][0][1].cpu().detach().numpy())
plt.show()

In [6]:
plt.imshow(total_hres[0][0][0])
plt.show()