In [1]:
import random
import numpy as np
import torch
import gc
from torch import optim

class BatchIndex:
    """Re-iterable sequence of ``(start, stop)`` pairs tiling ``range(size)``.

    Every pair covers ``batch_size`` consecutive indices, except possibly the
    last one which is clipped to ``size``. With ``shuffle=True`` the order of
    the pairs (not their contents) is permuted once, at construction time.
    Iterating yields 2-element tensors that can be unpacked as
    ``for start, stop in batch_index``.
    """

    def __init__(self, size, batch_size, shuffle=True):
        starts = range(0, size, batch_size)
        bounds = [(begin, min(begin + batch_size, size)) for begin in starts]
        self.index_list = torch.as_tensor(bounds)

        if shuffle:
            order = torch.randperm(len(self.index_list))
            self.index_list = self.index_list[order]

        # Cursor of the last pair handed out; -1 means "nothing yet".
        self.pos = -1

    def __iter__(self):
        # Restart from the first pair every time a new loop begins.
        self.pos = -1
        return self

    def __next__(self):
        self.pos += 1
        if self.pos < len(self.index_list):
            return self.index_list[self.pos]
        raise StopIteration

    def __len__(self):
        return len(self.index_list)
    
def get_mgrid(sidelen, dim=2, s=1, t=0):
    '''Generate a flattened grid of coordinates normalised to the range [-1, 1].

    Args:
        sidelen: number of grid points per axis; an int is repeated ``dim`` times.
        dim: number of axes (2, 3 or 4).
        s: stride used along the spatial axes.
        t: for ``dim == 4`` the first axis is sampled with stride ``t + 1``.

    Returns:
        A float32 NumPy array of shape ``(num_points, dim)``. Rows are laid out
        in Fortran order, i.e. the first axis varies fastest.

    Raises:
        NotImplementedError: if ``dim`` is not 2, 3 or 4.
    '''
    if isinstance(sidelen, int):
        sidelen = dim * (sidelen,)

    if dim == 2:
        pixel_coords = np.stack(np.mgrid[:sidelen[0]:s, :sidelen[1]:s], axis=-1)[None, ...].astype(np.float32)
    elif dim == 3:
        # Build the (potentially large) 3-D grid on the GPU when there is one,
        # but fall back to the CPU so the function also works without CUDA.
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        ranges = [torch.arange(0, sidelen[axis], s, device=device) for axis in range(3)]
        grid = torch.meshgrid(*ranges, indexing='ij')
        pixel_coords = torch.stack(grid, dim=-1)[None, ...].float().cpu().numpy()
    elif dim == 4:
        pixel_coords = np.stack(np.mgrid[:sidelen[0]:(t+1), :sidelen[1]:s, :sidelen[2]:s, :sidelen[3]:s], axis=-1)[None, ...].astype(np.float32)
    else:
        raise NotImplementedError('Not implemented for dim=%d' % dim)

    for axis in range(dim):
        # max(..., 1) keeps a singleton axis at 0 instead of producing 0/0 = NaN.
        pixel_coords[..., axis] = pixel_coords[..., axis] / max(sidelen[axis] - 1, 1)
    pixel_coords = 2. * pixel_coords - 1.
    # One column per axis (previously hard-coded to 3, which broke dim 2 and 4).
    return pixel_coords.reshape(-1, dim, order='F')

def fast_random_choice(dim, num_samples_per_frame, unique=True, device='cuda:0'):
    """Draw random integer voxel indices inside a 3-D volume.

    Args:
        dim: sequence whose first three entries are the exclusive upper bounds
            for the x, y and z indices.
        num_samples_per_frame: number of index triples to return.
        unique: if True, return distinct triples whenever enough distinct ones
            were drawn; if False, triples are drawn independently.
        device: torch device on which the indices are generated and returned.

    Returns:
        Tuple ``(x, y, z)`` of 1-D integer tensors, each of length
        ``num_samples_per_frame``.
    """
    def _draw(count):
        # One independent randint call per axis, in x, y, z order.
        return torch.stack(
            [torch.randint(0, dim[axis], size=(count,), device=device) for axis in range(3)],
            dim=-1,
        )

    if not unique:
        xyz = _draw(num_samples_per_frame)
        return xyz[..., 0], xyz[..., 1], xyz[..., 2]

    # Oversample by 2x so that enough triples usually survive deduplication.
    drawn = _draw(num_samples_per_frame * 2)
    # torch.unique returns the distinct rows themselves; the previous code used
    # the *inverse* indices to index the raw draws, which did not deduplicate.
    distinct = torch.unique(drawn, dim=0)
    # The distinct rows come back sorted, so shuffle before truncating to avoid
    # biasing the selection towards small coordinates.
    distinct = distinct[torch.randperm(distinct.shape[0], device=device)]
    xyz = distinct[:num_samples_per_frame]

    missing = num_samples_per_frame - xyz.shape[0]
    if missing > 0:
        # Not enough distinct voxels were drawn: top up with raw draws
        # (duplicates allowed) so callers always get the requested length.
        xyz = torch.cat([xyz, drawn[:missing]], dim=0)
    return xyz[..., 0], xyz[..., 1], xyz[..., 2]
    
def count_params(model):
    """Return the total number of scalar parameters held by ``model``."""
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    return total

def cleanup():
    """Free memory: run Python's garbage collector, then empty PyTorch's CUDA cache."""
    # Collect first so tensors that are only kept alive by reference cycles
    # are released before the cache is emptied.
    gc.collect()
    torch.cuda.empty_cache()

def seed_everything(seed=0):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators with ``seed``."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        torch.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

In [2]:
import torch.utils
import numpy as np
import torch
import os
import threading
import queue
import tqdm
from concurrent.futures import ThreadPoolExecutor

class ScalarDataSet():
    """Scalar volume dataset that serves randomly sampled training batches.

    In ``train`` mode every selected frame is read from ``<data_path>/NNNN.raw``,
    normalised to [-1, 1] and kept in one tensor on ``device``. Each call to
    :meth:`get_data` returns a freshly sampled epoch of ``(t, x, y, z)`` inputs
    with the matching scalar values, and schedules the next epoch to be built
    in a background thread. In both ``train`` and ``inf`` mode a dense
    coordinate grid is prepared for :meth:`_get_testing_data`.
    """

    # Known datasets: name -> volume resolution [x, y, z]. Each one has a single
    # stored time step and lives under ./dataset/<name>/.
    _DATASET_DIMS = {
        'h2': [600, 248, 248],
        'fivejets': [128, 128, 128],
        'combustion': [480, 720, 120],
        'halfcy': [640, 240, 80],
        'tornado': [128, 128, 128],
        'vortex': [128, 128, 128],
    }

    def __init__(self, args, device='cuda:0'):
        """Configure the dataset from ``args`` and, depending on ``args.mode``, load it.

        Reads ``args.dataset``, ``args.batch_size``, ``args.interval``,
        ``args.downsample_factor`` and ``args.mode``.

        Raises:
            ValueError: if ``args.dataset`` is not a known dataset name.
        """
        self.dataset = args.dataset
        self.batch_size = args.batch_size
        self.interval = args.interval
        self.downsample_factor = args.downsample_factor
        self.device = device

        if self.dataset not in self._DATASET_DIMS:
            # Fail early with a clear message instead of an AttributeError on
            # self.dim further down.
            raise ValueError(
                f"Unknown dataset {self.dataset!r}; expected one of {sorted(self._DATASET_DIMS)}"
            )
        self.dim = list(self._DATASET_DIMS[self.dataset])
        self.total_samples = 1
        self.data_path = './dataset/' + self.dataset + '/'

        self.num_workers = 16

        # 1-based frame numbers used for training: every (interval + 1)-th frame.
        self.samples = [i for i in range(1, self.total_samples+1, self.interval+1)]
        self.total_samples = self.samples[-1]
        # Voxels sampled per frame and epoch, rounded down to whole batches.
        self.num_samples_per_frame = (self.dim[0]*self.dim[1]*self.dim[2]//self.downsample_factor)//self.batch_size * self.batch_size

        self.queue_size = 2
        self.loader_queue = queue.Queue(maxsize=self.queue_size)  # at most 2 prepared epochs
        self.executor = ThreadPoolExecutor(max_workers=self.queue_size)

        if args.mode == 'train':
            self.data = self.preload_with_multi_threads(self.load_volume_data, num_workers=self.num_workers, data_str='Volume Data')
            # Shape [num_frames, x, y, z]; load_volume_data already transposes to x, y, z.
            self.data = torch.as_tensor(np.asarray(self.data), device=self.device)

            self.len = self.num_samples_per_frame * len(self.samples)
            self._get_data = self._get_training_data

            self._build_testing_inputs()
            self.preload_data()

        elif args.mode == 'inf':
            self._build_testing_inputs()

    def _build_testing_inputs(self):
        """Prepare the dense ``(t, x, y, z)`` input grid (with t = 0) used for inference."""
        samples = self.dim[2]*self.dim[1]*self.dim[0]
        self.coords = get_mgrid([self.dim[0], self.dim[1], self.dim[2]], dim=3)
        self.time = np.zeros((samples, 1))
        self.testing_data_inputs = torch.as_tensor(
            np.concatenate((self.time, self.coords), axis=1),
            dtype=torch.float,
            device=self.device,
        )

    def preload_data(self):
        """Build one epoch of data and enqueue it, unless the queue is already full."""
        if self.loader_queue.full():
            return
        self.loader_queue.put(self._get_data())

    def get_data(self):
        """Return the next prepared epoch and start preparing the following one."""
        if self.loader_queue.empty():
            print("DataLoader is not ready yet! Waiting...")
        # Queue.get() blocks until an item is available, so no busy-wait loop
        # is needed while the background thread is still working.
        current_data = self.loader_queue.get()
        self.executor.submit(self.preload_data)
        return current_data

    @torch.no_grad()
    def _get_testing_data(self, idx):
        """Return the dense inputs for 1-based frame ``idx`` plus an ordered batch index."""
        t = idx - 1
        t = t / max((self.total_samples-1), 1)
        t = 2.0 * t - 1.0
        testing_data_inputs = self.testing_data_inputs.clone()
        testing_data_inputs[:, 0] = t
        batchidxgenerator = BatchIndex(testing_data_inputs.shape[0], self.batch_size, False)
        return testing_data_inputs, batchidxgenerator

    @torch.no_grad()
    def _get_training_data(self):
        """Sample one epoch: shuffled ``(t, x, y, z)`` inputs, target values and a batch index."""
        training_data_inputs = []
        training_data_outputs = []

        for i in range(0, len(self.samples)):
            x, y, z = fast_random_choice(self.dim, self.num_samples_per_frame, device=self.device)
            t = torch.ones_like(x) * (self.samples[i]-1)

            # Values of the i-th loaded volume at the sampled voxels.
            outputs = self.data[i, x, y, z]
            # Normalise indices to [0, 1].
            x = x / (self.dim[0] - 1)
            y = y / (self.dim[1] - 1)
            z = z / (self.dim[2] - 1)
            t = t / max((self.total_samples-1), 1)

            inputs = torch.stack([t, x, y, z], dim=-1)
            inputs = 2.0 * inputs - 1.0  # rescale to [-1, 1]
            training_data_inputs.append(inputs)
            training_data_outputs.append(outputs)

        training_data_inputs = torch.cat(training_data_inputs, dim=0).to(self.device)
        training_data_outputs = torch.cat(training_data_outputs, dim=0).to(self.device)
        idx = torch.randperm(training_data_inputs.shape[0], device='cpu')
        training_data_inputs = training_data_inputs[idx].contiguous()
        training_data_outputs = training_data_outputs[idx].contiguous()
        batchidxgenerator = BatchIndex(self.len, self.batch_size, shuffle=True)
        del idx
        cleanup()
        return training_data_inputs, training_data_outputs, batchidxgenerator

    def load_volume_data(self, idx):
        """Read frame ``self.samples[idx]`` from disk and return it as an ``[x, y, z]`` array in [-1, 1]."""
        d = np.fromfile(self.data_path+'{:04d}.raw'.format(self.samples[idx]), dtype='<f')
        # FIXME: every frame is normalised with its own min/max; does that help
        # temporal super-resolution?
        lowest = np.min(d)
        d = 2. * (d - lowest) / (np.max(d) - lowest) - 1.
        # The file stores x as the fastest-varying axis, so read it as (z, y, x)...
        d = d.reshape(self.dim[2], self.dim[1], self.dim[0])
        # ...and transpose to (x, y, z).
        d = d.transpose(2, 1, 0)
        return d

    def _preload_worker(self, data_list, load_func, q, lock, idx_tqdm):
        """Drain ``q``: load every queued index into ``data_list`` and tick the progress bar.

        Returns once the queue is empty so that worker threads terminate
        instead of blocking forever on an exhausted queue.
        """
        while True:
            try:
                idx = q.get_nowait()
            except queue.Empty:
                return
            try:
                data_list[idx] = load_func(idx)
            finally:
                # Always acknowledge the item, even if load_func raised, so a
                # caller waiting in q.join() cannot hang.
                q.task_done()
            with lock:
                idx_tqdm.update()

    def preload_with_multi_threads(self, load_func, num_workers, data_str='images'):
        """Load all selected frames in parallel and return them in sample order.

        Args:
            load_func: callable mapping a sample position to the loaded item.
            num_workers: upper bound on the number of loader threads.
            data_str: label shown on the progress bar.

        Raises:
            RuntimeError: if ``load_func`` raised for some item or returned None.
        """
        num_items = len(self.samples)
        data_list = [None] * num_items
        failures = []

        def guarded_load(idx):
            # Record failures instead of letting them kill the worker thread;
            # they are re-raised in the calling thread below.
            try:
                return load_func(idx)
            except Exception as exc:
                failures.append((idx, exc))
                return None

        q = queue.Queue(maxsize=num_items)
        idx_tqdm = tqdm.tqdm(range(num_items), desc=f"Loading {data_str}", leave=False)
        # The queue is filled completely before any worker starts, so an empty
        # queue reliably means "no work left" for the workers.
        for i in range(num_items):
            q.put(i)
        lock = threading.Lock()
        workers = []
        for _ in range(max(1, min(num_workers, num_items))):
            worker = threading.Thread(target=self._preload_worker,
                                      args=(data_list, guarded_load, q, lock, idx_tqdm), daemon=True)
            worker.start()
            workers.append(worker)
        for worker in workers:
            worker.join()
        idx_tqdm.close()

        if failures:
            failed_idx, first_error = failures[0]
            raise RuntimeError(
                f"Failed to load {len(failures)} item(s) of {data_str}; first failure at index {failed_idx}"
            ) from first_error
        if any(item is None for item in data_list):
            raise RuntimeError(f"Loading {data_str} left some entries empty")

        return data_list

In [3]:
import numpy as np
import torch
from torch import nn

class SineLayer(nn.Module):
    """Linear layer whose output is passed through ``sin(omega_0 * x)``.

    Weights are initialised uniformly in ``[-1/in_features, 1/in_features]``
    when ``is_first`` is set, otherwise in
    ``[-sqrt(6/in_features)/omega_0, sqrt(6/in_features)/omega_0]``.
    """

    def __init__(self, in_features, out_features, bias=True,
                 is_first=False, omega_0=30):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first

        self.in_features = in_features
        self.linear = nn.Linear(in_features, out_features, bias=bias)

        self.init_weights()

    def init_weights(self):
        """Overwrite the linear weights in place with the uniform initialisation."""
        if self.is_first:
            bound = 1 / self.in_features
        else:
            bound = np.sqrt(6 / self.in_features) / self.omega_0
        with torch.no_grad():
            self.linear.weight.uniform_(-bound, bound)

    def forward(self, input):
        pre_activation = self.linear(input)
        return torch.sin(self.omega_0 * pre_activation)

class ResBlock(nn.Module):
    """Two stacked sine layers averaged with a skip connection.

    When the input and output widths differ, the skip path goes through an
    extra sine layer (``transform``) so both branches have the same width.
    The ``nonlinearity`` argument is accepted but not used.
    """

    def __init__(self,in_features,out_features,nonlinearity='relu'):
        super(ResBlock,self).__init__()

        body = [
            SineLayer(in_features, out_features),
            SineLayer(out_features, out_features),
        ]

        # True when the skip path needs a projection to the output width.
        self.flag = (in_features != out_features)
        if self.flag:
            self.transform = SineLayer(in_features, out_features)

        self.net = nn.Sequential(*body)

    def forward(self,features):
        outputs = self.net(features)
        skip = self.transform(features) if self.flag else features
        return 0.5*(outputs+skip)

class CoordNet(nn.Module):
    """Coordinate network with a shared residual trunk and three output heads.

    The trunk widens the features to ``4 * init_features`` over three residual
    blocks and then applies ``num_res`` more blocks at that width. The three
    heads each predict the output; their mean and variance are returned
    together with the individual predictions.
    """

    def __init__(self, in_features, out_features, init_features=64,num_res = 10):
        super(CoordNet,self).__init__()

        self.num_res = num_res

        widths = [in_features, init_features, 2*init_features, 4*init_features]
        trunk = [ResBlock(w_in, w_out) for w_in, w_out in zip(widths[:-1], widths[1:])]
        trunk.extend(ResBlock(4*init_features, 4*init_features) for _ in range(self.num_res))
        self.net = nn.Sequential(*trunk)

        self.fc1 = ResBlock(4*init_features, out_features)
        self.fc2 = ResBlock(4*init_features, out_features)
        self.fc3 = ResBlock(4*init_features, out_features)

        self.n_output_dims = out_features

    def forward(self, coords):
        """Return ``(mean, variance, out1, out2, out3)`` of the three heads."""
        shared = self.net(coords)
        out1, out2, out3 = self.fc1(shared), self.fc2(shared), self.fc3(shared)
        out = (out1+out2+out3)/3
        heads = torch.stack([out1,out2,out3], axis=-1)
        var = torch.var(heads, axis=-1)
        return out, var, out1, out2, out3
    
class LinearLayer(nn.Module):
    """Plain linear layer with weights drawn uniformly from ``[-1/in_features, 1/in_features]``."""

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()

        self.in_features = in_features
        self.linear = nn.Linear(in_features, out_features, bias=bias)

        self.init_weights()

    def init_weights(self):
        """Overwrite the linear weights in place with the uniform initialisation."""
        bound = 1 / self.in_features
        with torch.no_grad():
            self.linear.weight.uniform_(-bound, bound)

    def forward(self, input):
        projected = self.linear(input)
        return projected

class BottleNeckBlock(nn.Module):
    """Residual block that squeezes to a quarter of the width and expands back.

    Three sine layers map ``in_features -> in_features//4 -> in_features//4 ->
    in_features``; the result is added to the block input.
    """

    def __init__(self,in_features):
        super(BottleNeckBlock,self).__init__()
        narrow = in_features//4
        self.net = nn.Sequential(
            SineLayer(in_features, narrow),
            SineLayer(narrow, narrow),
            SineLayer(narrow, in_features),
        )

    def forward(self, features):
        residual = self.net(features)
        return residual+features
    
class CoordNetBottleNeck(nn.Module):
    """Coordinate network with a bottleneck trunk and three output heads.

    Three sine layers widen the input to ``4 * init_features``, followed by
    ``num_res`` bottleneck blocks. The three heads each predict the output;
    their mean and variance are returned together with the individual
    predictions.
    """

    def __init__(self, in_features, out_features, init_features=64,num_res = 10):
        super(CoordNetBottleNeck,self).__init__()

        self.num_res = num_res

        hidden = 4*init_features
        stem = [
            SineLayer(in_features, init_features),
            SineLayer(init_features, 2*init_features),
            SineLayer(2*init_features, hidden),
        ]
        blocks = [BottleNeckBlock(hidden) for _ in range(self.num_res)]
        self.net = nn.Sequential(*stem, *blocks)

        self.fc1 = ResBlock(hidden, out_features)
        self.fc2 = ResBlock(hidden, out_features)
        self.fc3 = ResBlock(hidden, out_features)

    def forward(self, coords):
        """Return ``(mean, variance, out1, out2, out3)`` of the three heads."""
        shared = self.net(coords)
        out1, out2, out3 = self.fc1(shared), self.fc2(shared), self.fc3(shared)
        out = (out1+out2+out3)/3
        heads = torch.stack([out1,out2,out3], axis=-1)
        var = torch.var(heads, axis=-1)
        return out, var, out1, out2, out3
    
    
import torch.nn.functional as F
class FreqEmbedder:
    """Sin/cos frequency encoding of input coordinates.

    ``embed`` concatenates, along the last axis, the raw input (when
    ``include_input`` is set) followed by ``sin(x * f)`` and ``cos(x * f)`` for
    each of the ``multi_freq`` frequencies ``f``. Frequencies are powers of two
    from ``2**0`` to ``2**(multi_freq - 1)`` when ``log_sampling`` is set,
    otherwise they are spaced linearly over the same interval. ``out_dim`` is
    the resulting feature count for inputs with ``input_dims`` components.
    """

    def __init__(self, multi_freq, include_input=True, input_dims=3, log_sampling=True):
        self.multi_freq = multi_freq
        self.input_dims = input_dims
        self.include_input = include_input
        self.log_sampling = log_sampling
        self.periodic_fns = [torch.sin, torch.cos]

        self.embed_fns = None
        self.out_dim = None
        self.create_embedding_fn()

    def create_embedding_fn(self):
        """Build the list of per-term embedding callables and compute ``out_dim``."""
        max_freq = self.multi_freq - 1
        if self.log_sampling:
            freq_bands = 2. ** torch.linspace(0., max_freq, steps=self.multi_freq)
        else:
            freq_bands = torch.linspace(2. ** 0., 2. ** max_freq, steps=self.multi_freq)

        embed_fns = [lambda x: x] if self.include_input else []
        # Default arguments bind the current p_fn / freq to each lambda.
        embed_fns += [
            (lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
            for freq in freq_bands
            for p_fn in self.periodic_fns
        ]

        self.embed_fns = embed_fns
        # Every term contributes one feature per input component.
        self.out_dim = self.input_dims * len(embed_fns)

    def embed(self, inputs):
        pieces = [fn(inputs) for fn in self.embed_fns]
        return torch.cat(pieces, -1)

class FCLayer(nn.Module):
    """Fully connected layer followed by a ReLU activation."""

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()

        self.in_features = in_features
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.relu = nn.ReLU()

    def forward(self, input):
        hidden = self.linear(input)
        return self.relu(hidden)
    
class Nerf(nn.Module):
    """ReLU MLP on frequency-encoded coordinates with three output heads.

    Args:
        input_dim: number of components of each input coordinate.
        mi_dim: width of the hidden layers.
        output_dim: number of values predicted by each head.
        fourier_dim: kept for backward compatibility and not used; the width of
            the encoded input is taken from the embedder (84 for the default
            ``input_dim=4`` with 10 frequencies and the raw input included).
    """

    def __init__(self, input_dim=4, mi_dim=256, output_dim=1, fourier_dim=84):
        super(Nerf, self).__init__()
        # Frequency encoding of the raw coordinates.
        self.fourier_encoding = FreqEmbedder(multi_freq=10, include_input=True, input_dims=input_dim, log_sampling=True)
        # Derive the first layer's width from the embedder rather than
        # hard-coding 84, so non-default input_dim values work as well.
        encoded_dim = self.fourier_encoding.out_dim

        # Fully connected trunk: one input layer plus eight hidden layers.
        layers = [FCLayer(encoded_dim, mi_dim)]
        layers.extend(FCLayer(mi_dim, mi_dim) for _ in range(8))
        self.net = nn.Sequential(*layers)

        # Heads read the trunk output, whose width is mi_dim (previously
        # hard-coded to 256, which broke any other mi_dim).
        self.out1 = nn.Linear(mi_dim, output_dim)
        self.out2 = nn.Linear(mi_dim, output_dim)
        self.out3 = nn.Linear(mi_dim, output_dim)

    def forward(self, x):
        """Return ``(mean, variance, out1, out2, out3)`` of the three heads."""
        x = self.fourier_encoding.embed(x)
        x = self.net(x)

        out1 = self.out1(x)
        out2 = self.out2(x)
        out3 = self.out3(x)
        out = (out1+out2+out3)/3
        data = torch.stack([out1,out2,out3], axis=-1)
        var = torch.var(data, axis=-1)
        return out, var, out1, out2, out3


In [4]:
import torch
from torch import nn
import os
import numpy as np
import torch.optim as optim
import tqdm
from datetime import datetime
from shutil import copy, copytree
import json
import time
from torch.cuda.amp import autocast, GradScaler
from torch.profiler import profile, record_function, ProfilerActivity
import math

def kl_d(var, mse):
    """Mean pointwise KL term between the normalised ``mse`` and normalised ``var``.

    Both inputs are divided by their sums; the normaliser of ``mse`` is
    detached from the graph. The normalised ``mse`` is passed to ``kl_div`` as
    the (log-space) target and the normalised ``var`` as the (log-space) input.
    """
    target = mse / mse.sum().detach()
    pred = var / var.sum()
    # Small offset keeps the logarithms finite when an entry is exactly zero.
    log_pred = torch.log(pred + 1.e-16)
    log_target = torch.log(target + 1.e-16)
    pointwise = torch.nn.functional.kl_div(
        log_pred,
        log_target,
        reduction='none',
        log_target=True,
    )
    return pointwise.mean()

def trainNet(model,args,dataset):
    """Train ``model`` on epochs served by ``dataset`` and write logs and checkpoints.

    Each epoch fetches a freshly sampled set of inputs/targets via
    ``dataset.get_data()`` and optimises the average MSE of the three model
    heads plus a weighted KL term between the predicted variance and the
    squared error of the mean prediction. The KL weight grows from 0 in the
    first epoch to 30 in the last one. The learning rate is updated after every
    epoch with ``adjust_lr``.

    Outputs are written below ``<args.result_dir>/<args.dataset>/MDSRN``: a loss
    log, and a checkpoint after the first epoch and every ``args.checkpoint``
    epochs.
    """
    result_dir = os.path.join(args.result_dir, f'{args.dataset}', f'MDSRN')

    logs_dir = os.path.join(result_dir, 'logs')
    checkpoints_dir = os.path.join(result_dir, 'checkpoints')
    outputs_dir = os.path.join(result_dir, 'outputs')
    os.makedirs(logs_dir, exist_ok=True)
    os.makedirs(checkpoints_dir, exist_ok=True)
    os.makedirs(outputs_dir, exist_ok=True)

    loss_log_file = result_dir+'/'+'loss-'+'-'+str(args.interval)+'-'+str(args.init)+'-'+str(args.active)+'.txt'
    # TODO: training with FusedAdam was found to be very unstable.
    optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9,0.999), weight_decay=1e-6, fused=True)
    mse_loss = nn.MSELoss()
    # Gradient scaler for mixed-precision training (a no-op unless args.fp16).
    scaler = GradScaler(enabled=args.fp16)

    start_time = time.time()
    with open(loss_log_file,"a") as f:
        f.write(f"time:{time.time()}")
        f.write('\n')

    # Number of epoch-to-epoch steps of the KL-weight schedule; clamped to 1 so
    # that a single-epoch run does not divide by zero.
    schedule_span = max(args.num_epochs - 1, 1)
    for epoch in range(1,args.num_epochs+1):
        model.train()
        training_data_inputs, training_data_outputs, batchIndexGenerator = dataset.get_data()
        loss_mse = 0
        loss_kl = 0
        loop = tqdm.tqdm(batchIndexGenerator)
        # KL weight: exponential ramp from 0 (first epoch) to 30 (last epoch).
        l = 30*(500**((epoch-1) / schedule_span)-1)/499

        for current_idx, next_idx in loop:
            coord = training_data_inputs[current_idx:next_idx].contiguous()
            v = training_data_outputs[current_idx:next_idx].contiguous()

            optimizer.zero_grad()
            with autocast(enabled=args.fp16):
                mean, var, out1, out2, out3 = model(coord)
                # Each head is fitted to the target on its own.
                mse = (mse_loss(out1.view(-1),v.view(-1)) + mse_loss(out2.view(-1),v.view(-1)) + mse_loss(out3.view(-1),v.view(-1))) / 3
                kl = kl_d(var.view(-1), ((mean.view(-1)-v.view(-1))**2))
                loss = mse + l*kl

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            # Running sums over the batches of this epoch.
            loss_mse += mse.mean().item()
            loss_kl += kl.mean().item()

            loop.set_description(f'Epoch [{epoch}/{args.num_epochs}]')
            loop.set_postfix(mse=loss_mse, kl=loss_kl)
        adjust_lr(args, optimizer, epoch)

        with open(loss_log_file,"a") as f:
            f.write(f"Epochs {epoch}: loss = {loss_mse}, lr = {optimizer.param_groups[0]['lr']}")
            f.write('\n')

        if epoch%args.checkpoint == 0 or epoch==1:
            torch.save(model.state_dict(),checkpoints_dir+'/'+'-'+str(args.interval)+'-'+str(args.init)+'-'+str(epoch)+'.pth')

    with open(loss_log_file,"a") as f:
        f.write(f"time:{time.time()}")
        f.write(f"time:{time.time()-start_time}")
        f.write('\n')

@torch.no_grad()
def inf(model,dataset,args, result_dir=None):
    """Run ``model`` over the dense grid of every frame and dump the results.

    For each base frame in ``dataset.samples`` and each of the
    ``dataset.interval`` following frames, the mean prediction is written to
    ``<result_dir>/outputs/inference/NNNN.dat`` and the predicted variance to
    ``<result_dir>/outputs/var/NNNN.dat`` as raw float32 values. When
    ``result_dir`` is None it is derived from the checkpoint path (two
    directory levels up).
    """
    ckpt = './'+args.dataset+args.ckpt
    if result_dir is None:
        result_dir = os.path.dirname(os.path.dirname(ckpt))
    outputs_dir = os.path.join(result_dir, 'outputs', 'inference')
    var_dir = os.path.join(result_dir, 'outputs', 'var')
    for folder in (outputs_dir, var_dir):
        os.makedirs(folder, exist_ok=True)

    model.eval()
    samples = dataset.samples
    for base_frame in samples:
        for offset in range(0,dataset.interval+1):
            frame_idx = base_frame + offset
            val_data_inputs, batchIndexGenerator = dataset._get_testing_data(frame_idx)
            values = []
            variances = []
            loop = tqdm.tqdm(batchIndexGenerator)
            for current_idx, next_idx in loop:
                coord = val_data_inputs[current_idx:next_idx]
                # Gradients are already disabled by the decorator.
                dat, var, out1, out2, out3 = model(coord)
                values.append(dat.view(-1))
                variances.append(var.view(-1))

            # Same export path for both quantities: concatenate the batches,
            # move to host memory and write little-endian float32.
            for chunks, folder in ((values, outputs_dir), (variances, var_dir)):
                flat = torch.cat(chunks,dim=-1).float()
                flat = flat.detach().cpu().numpy()
                flat = np.asarray(flat,dtype='<f')
                flat.tofile(f'{folder}/{frame_idx:04}.dat', format='<f')
    
def adjust_lr(args, optimizer, epoch):
    """Set the learning rate of every parameter group for the given epoch.

    The schedule is selected by ``args.lr_s``:

    * ``'exp'``: ``args.lr * exp(-0.02 * epoch)``
    * ``'step'``: ``args.lr`` halved every 50 epochs
    * ``'cosine'``: cosine annealing from ``args.lr`` to 0 over ``args.num_epochs``

    Raises:
        ValueError: if ``args.lr_s`` names none of the schedules above.
    """
    if args.lr_s=='exp':
        lr = args.lr * math.exp(-0.02 * epoch)
    elif args.lr_s=='step':
        lr = args.lr * (0.5 ** (epoch // 50))
    elif args.lr_s == 'cosine':
        T_max = args.num_epochs
        eta_min = 0
        lr = eta_min + (args.lr - eta_min) * (1 + math.cos(math.pi * epoch / T_max)) / 2
    else:
        # Previously `lr` stayed unbound here and the loop below failed with
        # an UnboundLocalError.
        raise ValueError(
            f"Unknown lr schedule {args.lr_s!r}; expected 'exp', 'step' or 'cosine'"
        )
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

In [None]:
import argparse

# Command-line style configuration. parse_known_args() is used so that extra
# arguments passed by the notebook kernel are ignored instead of rejected.
p = argparse.ArgumentParser()
p.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
p.add_argument('--gpu', type=str,default='0')
p.add_argument('--seed', type=int, default=42)
p.add_argument('--fp16', action="store_true")
p.add_argument('--compile', action="store_true")
# General training options
p.add_argument('--downsample_factor', type=int, default=4, metavar='N',
                    help='downsample factor')
p.add_argument('--batch_size', type=int, default=8000)
p.add_argument('--lr', type=float, default=1e-5, help='learning rate. default=1e-5')
p.add_argument('--num_epochs', type=int, default=200,
               help='Number of epochs to train for.')
p.add_argument('--checkpoint', type=int, default=50,
               help='save a checkpoint every N epochs.')
p.add_argument('--ckpt', type=str,default="/MDSRN/checkpoints/-0-64-200.pth",
               help="checkpoint path, relative to './<dataset>'.")
p.add_argument('--dataset', type=str, default='fivejets',
               help='Scalar dataset; one of (h2, fivejets, combustion, halfcy, tornado, vortex)')
p.add_argument('--result_dir', type=str, default='./', metavar='N',
                    help='the path where we stored the synthesized data')
p.add_argument('--interval', type=int, default=0, metavar='N',
                    help='temporal upscaling factor')
p.add_argument('--active', type=str, default='sine', metavar='N',
                    help='active function')
p.add_argument('--init', type=int, default=64, metavar='N',
                    help='init features')
p.add_argument('--num_res', type=int, default=10, metavar='N',
                    help='number of residual blocks')
p.add_argument('--lr_s', type=str, default='cosine',
               help="learning-rate schedule: 'exp', 'step' or 'cosine'")
p.add_argument('--mode', type=str, default='inf', metavar='N',
                    help="run mode: 'train' or 'inf' (inference)")
opt = p.parse_known_args()[0]

import torch
import os
# Order devices by PCI bus id and expose only the GPU(s) selected with --gpu.
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu
# NOTE(review): presumably a workaround for duplicate OpenMP runtime errors — confirm it is still needed.
os.environ["KMP_DUPLICATE_LIB_OK"]  =  "TRUE"

# Record whether CUDA should be used: only when --no-cuda was not given and a device is available.
# NOTE(review): opt.cuda is not read by any other code visible in this notebook.
opt.cuda = not opt.no_cuda and torch.cuda.is_available()
# Seed the Python, NumPy and PyTorch random generators from --seed.
seed_everything(opt.seed)

# NOTE(review): 'high' presumably trades some float32 matmul precision for speed — see the PyTorch docs.
torch.set_float32_matmul_precision('high')

def main():
    """Build the dataset and model from ``opt``, then train or run inference.

    In the modes ``'inf'`` and ``'ue'`` the weights are restored from
    ``./<dataset><ckpt>`` before the model is moved to the GPU. Only
    ``'train'`` and ``'inf'`` trigger further work.
    """
    print('FP16 enbled: ', opt.fp16)
    print('Compile enbled: ', opt.compile)
    dataset = ScalarDataSet(opt)
    # NOTE(review): the architecture is fixed here; --init and --num_res are not
    # forwarded to the model — confirm this is intended.
    model = CoordNet(4,1, init_features=64, num_res=7)
    # Alternative architectures:
    #   model = CoordNetBottleNeck(4,1, init_features=288, num_res=1)
    #   model = Nerf()
    if opt.mode in ('inf', 'ue'):
        checkpoint_path = './'+opt.dataset+opt.ckpt
        model.load_state_dict(torch.load(checkpoint_path))
    if opt.compile:
        model.compile()
    model.cuda()

    if opt.mode == 'train':
        print('Initalize Model Successfully using Sine Function!')
        trainNet(model,opt,dataset)
    elif opt.mode == 'inf':
        inf(model, dataset,opt)

if __name__== "__main__":
    main()

In [7]:
#得到误差
import os
import numpy as np
# Dataset to evaluate (other options: opt.dataset, 'h2', ...).
data_name = "combustion"
# Directory layout: ground truth, reconstruction, squared error, predicted variance.
origin_dir = f'./dataset/{data_name}/'
recons_dir = f'./{data_name}/MDSRN/outputs/inference/'
error_dir = f'./{data_name}/MDSRN/outputs/error/'
var_dir = f'./{data_name}/MDSRN/outputs/var/'
os.makedirs(error_dir, exist_ok=True)

# Write, for every evaluated frame, the per-voxel squared error between the
# normalised ground truth and the reconstruction.
for i in range(1, 2):
    d = np.fromfile(recons_dir + '{:04d}.dat'.format(i), dtype='<f')
    real = np.fromfile(origin_dir + '{:04d}.raw'.format(i), dtype='<f')
    # Same per-frame min/max normalisation to [-1, 1] as used when loading the training data.
    real = 2*(real-np.min(real))/(np.max(real)-np.min(real))-1
    error = np.square(real - d)
    error.tofile(error_dir+'{:04d}.dat'.format(i), format='<f')

In [None]:
#计算PSNR,norm
import numpy as np
import torch
psnr = 0
k = 0


def psnr_fn_paper(gt, pred, diff):
    """PSNR in dB of ``pred`` against ``gt`` for a signal whose value range is ``diff``."""
    return 10. * torch.log10(diff**2 / torch.mean((gt-pred)**2))


# Accumulate the PSNR of every evaluated frame (psnr) and the frame count (k).
for i in range(1, 2):
    gt = np.fromfile(origin_dir + '{:04d}.raw'.format(i), dtype=np.float32)
    d = np.fromfile(recons_dir + "{:04d}.dat".format(i), dtype=np.float32)
    # Normalise the ground truth to [-1, 1] per frame.
    gt = 2*(gt-np.min(gt))/(np.max(gt)-np.min(gt))-1
    d = torch.from_numpy(d)
    gt = torch.from_numpy(gt)
    diff = gt.max() - gt.min()

    psnr_volume = psnr_fn_paper(gt, d, diff)
    print(f"{i}:{psnr_volume.item()}")
    psnr += psnr_volume.item()
    k += 1
# The average over all frames would be psnr/k:
# print(psnr/k)
def compute_nll(variance, error, epsilon=1e-8):
    """Mean of ``0.5 * (error**2 / variance + log(variance))`` over all elements.

    ``variance`` is clamped from below to ``epsilon`` before it is used.
    """
    clamped = np.maximum(variance, epsilon)
    quadratic = (error**2) / clamped
    log_term = np.log(clamped)
    return np.mean(0.5 * (quadratic + log_term))

#计算corr,norm 
# Correlation between predicted variance and squared error, plus the NLL,
# for every evaluated frame.
k = 0
t_corr = 0
for i in range(1, 2):
    k += 1
    v = np.fromfile(var_dir + "{:04d}.dat".format(i), dtype='<f')
    e = np.fromfile(error_dir + "{:04d}.dat".format(i), dtype='<f')
    corr = np.corrcoef(v, e)
    print(str(i)+":"+str(corr[0,1]))
    t_corr += corr[0,1]
    # The files in error_dir hold *squared* errors, while compute_nll squares
    # its `error` argument itself. Pass the residual magnitude so the NLL uses
    # error**2 rather than error**4.
    nll = compute_nll(v, np.sqrt(e))
    print(str(i)+":"+str(nll))