In [1]:
%load_ext autoreload
%autoreload 2
import sys, os
sys.path.insert(0, os.path.join('.', 'code'))   
from utils import BatchIndex,get_mgrid,fast_random_choice,count_params,cleanup,seed_everything,dataset_selection,adjust_lr
from load_data_INR import LoadData
from model_INR import CoordNet,CoordNetBottleNeck,NGPNet,HashCoordNet

In [2]:
import torch.utils
import numpy as np
import torch
import os
import threading
import queue
import tqdm
from concurrent.futures import ThreadPoolExecutor

class ScalarDataSet(LoadData):
    """Sampler for a time-varying scalar volume, producing (t, x, y, z) -> value pairs.

    Training uses every ``temporal + 1``-th frame and a grid subsampled by ``spatial``
    along each axis; the inference coordinates cover the full-resolution grid.
    The loading helpers used here (``preload_with_multi_threads``, ``load_volume_data``,
    ``preload_data``) are not defined in this class and are expected to come from
    ``LoadData``.
    """

    def __init__(self,args, device='cuda:0'):
        """Resolve dataset geometry, preload training volumes and build inference inputs.

        Args:
            args: namespace providing ``dataset``, ``batch_size``, ``temporal``,
                ``spatial`` and ``mode``.
            device: torch device on which all tensors owned by this dataset are placed.
        """
        self.dataset, self.batch_size = args.dataset, args.batch_size
        self.temporal, self.spatial = args.temporal, args.spatial
        self.device = device
        (self.ori_dim, self.total_samples,
         self.data_path, self.downsample_factor) = dataset_selection(self.dataset, self.spatial, self.temporal)
        # Per-axis resolution of the spatially subsampled training grid.
        self.dim = [int(extent / self.spatial) for extent in self.ori_dim]

        self.num_workers = 16

        # 1-based indices of the frames kept for training (every (temporal+1)-th frame).
        self.samples = list(range(1, self.total_samples + 1, self.temporal + 1))
        # From here on total_samples is the index of the last *kept* frame; it is the
        # denominator used to normalise the time coordinate.
        self.total_samples = self.samples[-1]
        # Points drawn per frame and epoch, rounded down to a whole number of batches.
        voxels_per_frame = self.dim[0] * self.dim[1] * self.dim[2]
        self.num_samples_per_frame = (voxels_per_frame // self.downsample_factor) // self.batch_size * self.batch_size

        self.queue_size = 2
        self.loader_queue = queue.Queue(maxsize=self.queue_size)  # limit the queue size to 2
        # NOTE(review): the queue/executor are not used in this class; presumably consumed
        # by prefetching logic in LoadData -- confirm.
        self.executor = ThreadPoolExecutor(max_workers=self.queue_size)

        if args.mode == 'train':
            self.data = self.preload_with_multi_threads(self.load_volume_data, num_workers=self.num_workers, data_str='Volume Data')
            # [t timesteps, z, y, x]; needs to be changed to xyz order
            self.data = torch.as_tensor(np.asarray(self.data), device=self.device)

            self.len = self.num_samples_per_frame * len(self.samples)
            self._get_data = self._get_training_data

        # Full-resolution inference inputs: a zero time column followed by the grid
        # coordinates returned by get_mgrid.
        samples = self.ori_dim[2]*self.ori_dim[1]*self.ori_dim[0]
        self.coords = get_mgrid([self.ori_dim[0],self.ori_dim[1],self.ori_dim[2]],dim=3)
        self.time = np.zeros((samples,1))
        # Honour the requested device instead of a hard-coded 'cuda:0'.
        self.testing_data_inputs = torch.as_tensor(np.concatenate((self.time, self.coords),axis=1), dtype=torch.float, device=self.device)
        self.preload_data()

    @torch.no_grad()
    def _get_training_data(self):
        """Draw one epoch of shuffled training samples.

        For every kept frame, ``num_samples_per_frame`` voxel indices are drawn with
        ``fast_random_choice``; the matching values are read from ``self.data`` and the
        (t, x, y, z) coordinates are normalised to [-1, 1].

        Returns:
            tuple: ``(inputs, outputs, batch_index)`` where ``inputs`` has one
            (t, x, y, z) row per sample, ``outputs`` holds the sampled values and
            ``batch_index`` is a ``BatchIndex`` over ``self.len`` samples.
        """
        training_data_inputs = []
        training_data_outputs = []

        for i, frame in enumerate(self.samples):
            x,y,z = fast_random_choice(self.dim, self.num_samples_per_frame)
            t = torch.ones_like(x) * (frame - 1)

            # Sample xyz from the i-th volume; that volume corresponds to time step t.
            outputs = self.data[i, x, y, z]

            # Normalise to [0, 1]: subsampled indices are mapped back onto the original
            # grid (index * spatial) and divided by the original extent.
            x = x * self.spatial / (self.ori_dim[0] - 1)  #x / (self.dim[0] - 1)
            y = y * self.spatial / (self.ori_dim[1] - 1)  #y / (self.dim[1] - 1)
            z = z * self.spatial / (self.ori_dim[2] - 1)  #z / (self.dim[2] - 1)
            # max(..., 1) avoids dividing by zero when only one frame is kept.
            t = t / max((self.total_samples-1), 1)

            inputs = torch.stack([t, x, y, z], dim=-1)
            inputs = 2.0 * inputs - 1.0  # rescale to [-1, 1]
            training_data_inputs.append(inputs)
            training_data_outputs.append(outputs)

        # Keep the batches on the dataset's device rather than the current CUDA device.
        training_data_inputs = torch.cat(training_data_inputs, dim=0).to(self.device)
        training_data_outputs = torch.cat(training_data_outputs, dim=0).to(self.device)
        # One shared permutation keeps inputs and targets aligned after shuffling.
        idx = torch.randperm(training_data_inputs.shape[0], device='cpu')
        training_data_inputs = training_data_inputs[idx].contiguous()
        training_data_outputs = training_data_outputs[idx].contiguous()
        batchidxgenerator = BatchIndex(self.len, self.batch_size, shuffle=True)
        del idx
        cleanup()
        return training_data_inputs, training_data_outputs, batchidxgenerator

In [3]:
import torch
from torch import nn
import os
import numpy as np
import torch.optim as optim
import tqdm
from datetime import datetime
from shutil import copy, copytree
import json
import time
from torch.cuda.amp import autocast, GradScaler
from torch.profiler import profile, record_function, ProfilerActivity
import math

def trainNet(model,args,dataset):
    """Train ``model`` with an MSE objective and log/checkpoint under the result dir.

    Outputs are written to ``<args.result_dir>/<args.dataset>/CoordNet``:
    ``loss.txt`` (one line per epoch plus the total wall time) and
    ``checkpoints/-<spatial>-<temporal>-<epoch>.pth``.

    Args:
        model: network mapping a batch of coordinates to predicted values.
        args: namespace providing ``result_dir``, ``dataset``, ``lr``, ``fp16``,
            ``num_epochs``, ``checkpoint``, ``spatial`` and ``temporal`` (and whatever
            ``adjust_lr`` reads).
        dataset: object whose ``get_data()`` returns ``(inputs, targets, batch_index)``
            where ``batch_index`` yields ``(start, end)`` slice bounds.
    """
    result_dir = os.path.join(args.result_dir, f'{args.dataset}', 'CoordNet')

    checkpoints_dir = os.path.join(result_dir, 'checkpoints')
    outputs_dir = os.path.join(result_dir, 'outputs')
    os.makedirs(checkpoints_dir, exist_ok=True)
    os.makedirs(outputs_dir, exist_ok=True)

    loss_log_file = result_dir+'/'+'loss.txt'
    optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9,0.999), weight_decay=1e-6, fused=True)
    mse_loss = nn.MSELoss()
    # torch.cuda.amp.GradScaler/autocast are deprecated in recent PyTorch releases;
    # use the torch.amp equivalents and fall back to the old scaler where unavailable.
    if hasattr(torch.amp, 'GradScaler'):
        scaler = torch.amp.GradScaler('cuda', enabled=args.fp16)
    else:
        scaler = GradScaler(enabled=args.fp16)

    start_time = time.time()
    for epoch in range(1,args.num_epochs+1):
        model.train()
        # A fresh set of samples is requested from the dataset every epoch.
        training_data_inputs, training_data_outputs, batchIndexGenerator = dataset.get_data()
        # Sum (not mean) of the per-batch losses over the epoch.
        loss_mse = 0
        loop = tqdm.tqdm(batchIndexGenerator)

        for current_idx, next_idx in loop:
            coord = training_data_inputs[current_idx:next_idx].contiguous()
            v = training_data_outputs[current_idx:next_idx].contiguous()

            optimizer.zero_grad()
            with torch.autocast(device_type='cuda', enabled=args.fp16):
                v_pred = model(coord)
                loss = mse_loss(v_pred.view(-1),v.view(-1))

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            loss_mse += loss.item()

            loop.set_description(f'Epoch [{epoch}/{args.num_epochs}]')
            loop.set_postfix(loss=loss_mse)
        adjust_lr(args, optimizer, epoch)
        # scheduler.step()

        with open(loss_log_file,"a") as f:
            f.write(f"Epochs {epoch}: loss = {loss_mse}, lr = {optimizer.param_groups[0]['lr']}")
            f.write('\n')

        # Always save the final epoch as well: inference-mode loading looks for the
        # checkpoint named after num_epochs, which would otherwise be missing whenever
        # num_epochs is not a multiple of args.checkpoint.
        is_last_epoch = epoch == args.num_epochs
        if epoch%args.checkpoint == 0 or epoch == 1 or is_last_epoch:
            torch.save(model.state_dict(),checkpoints_dir+'/'+'-'+str(args.spatial)+'-'+str(args.temporal)+'-'+str(epoch)+'.pth')
    with open(loss_log_file,"a") as f:
        f.write(f"time:{time.time()-start_time}")
        f.write('\n')

@torch.no_grad()
def inf(model,dataset,args, result_dir=None):
    """Reconstruct every frame with ``model`` and dump each one as a raw float32 file.

    For each kept frame in ``dataset.samples`` and each offset in
    ``0..dataset.temporal`` the frame ``sample + offset`` is evaluated batch by batch
    and written to ``<result_dir>/outputs/<spatial>-<temporal>/<frame:04>-CoordNet.raw``.

    Args:
        model: trained network; switched to eval mode here.
        dataset: provides ``samples``, ``temporal`` and ``_get_testing_data(frame)``,
            which returns ``(inputs, batch_index)``.
        args: namespace providing ``dataset``, ``ckpt``, ``spatial``, ``temporal`` and
            ``num_epochs``.
        result_dir: output root; when omitted it is derived from the checkpoint path
            (two directory levels above the checkpoint file).
    """
    ckpt = ''.join([
        './result/', args.dataset, args.ckpt,
        '-', str(args.spatial), '-', str(args.temporal), '-', str(args.num_epochs), '.pth',
    ])
    if result_dir is None:
        result_dir = os.path.dirname(os.path.dirname(ckpt))
    run_tag = str(args.spatial)+'-'+str(args.temporal)
    outputs_dir = os.path.join(result_dir, 'outputs', run_tag)
    os.makedirs(outputs_dir, exist_ok=True)

    model.eval()
    for keyframe in dataset.samples:
        for offset in range(dataset.temporal + 1):
            frame_idx = keyframe + offset
            val_data_inputs, batchIndexGenerator = dataset._get_testing_data(frame_idx)
            # Gradients are already disabled by the decorator, so batches are evaluated directly.
            pieces = [
                model(val_data_inputs[start:stop]).view(-1)
                for start, stop in tqdm.tqdm(batchIndexGenerator)
            ]
            volume = torch.cat(pieces, dim=-1).float().detach().cpu().numpy()
            # Little-endian float32 on disk.
            volume = np.asarray(volume, dtype='<f')
            volume.tofile(f'{outputs_dir}/{frame_idx:04}-CoordNet.raw', format='<f')

In [None]:
import argparse

# Notebook configuration expressed as an argparse namespace. Parsing an empty
# argument list yields the defaults below, so edit the defaults to change a run.
p = argparse.ArgumentParser()
p.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
p.add_argument('--gpu', type=str,default='0')
p.add_argument('--seed', type=int, default=42)
p.add_argument('--fp16', action="store_true")
# General training options
p.add_argument('--batch_size', type=int, default=8000)
# The help text is derived from the default so the two cannot drift apart.
p.add_argument('--lr', type=float, default=5e-5, help='learning rate. default=%(default)s')
p.add_argument('--num_epochs', type=int, default=200,
               help='Number of epochs to train for.')
p.add_argument('--checkpoint', type=int, default=100,
               help='Save a checkpoint every N epochs.')
p.add_argument('--ckpt', type=str,default="/CoordNet/checkpoints/",help='checkpoint path.')
p.add_argument('--result_dir', type=str, default='./result/', metavar='N',
                    help='the path where we stored the synthesized data')
p.add_argument('--temporal', type=int, default=0, metavar='N')
p.add_argument('--lr_s', type=str, default='cosine', help='learning rate scheduler')

p.add_argument('--dataset', type=str, default='vortex')
p.add_argument('--spatial', type=int, default=2, metavar='N')
p.add_argument('--mode', type=str, default='train', metavar='N')
# opt = p.parse_known_args()[0]
# args=[] ignores the kernel's own sys.argv, which argparse would otherwise reject.
opt = p.parse_args(args=[])

import torch
import os
# Process-level runtime setup: GPU enumeration/visibility, the OpenMP duplicate
# library workaround, the CUDA flag on the options namespace, RNG seeding and
# float32 matmul precision.
os.environ.update({
    'CUDA_DEVICE_ORDER': 'PCI_BUS_ID',
    'CUDA_VISIBLE_DEVICES': opt.gpu,
    "KMP_DUPLICATE_LIB_OK": "TRUE",
})

# CUDA is used only when it was not disabled explicitly and a device is available.
opt.cuda = False if opt.no_cuda else torch.cuda.is_available()
seed_everything(opt.seed)

torch.set_float32_matmul_precision('high')

def main():
    """Build the dataset and model, then train and/or run inference per ``opt.mode``.

    * ``train``: train from scratch, then reconstruct all frames.
    * ``inf``: load the checkpoint named after ``opt.num_epochs`` and reconstruct.
    * ``ue``: the checkpoint is loaded, but no further action is taken here.
    """
    print('FP16 enbled: ', opt.fp16)
    Data = ScalarDataSet(opt)
    Model = CoordNet(4,1,init_features=64, num_res=5)

    if opt.mode in ['inf', 'ue']:
        ckpt = './result/'+opt.dataset+opt.ckpt+'-'+str(opt.spatial)+'-'+str(opt.temporal)+'-'+str(opt.num_epochs)+'.pth'
        # The model is still on the CPU here, so map the weights there: loading then
        # does not depend on the GPU the checkpoint was saved from. weights_only
        # restricts unpickling to tensors/containers, which is all a state_dict needs.
        state_dict = torch.load(ckpt, map_location='cpu', weights_only=True)
        Model.load_state_dict(state_dict)
    Model.cuda()

    if opt.mode == 'train':
        print('Initalize Model Successfully using Sine Function!')
        trainNet(Model,opt,Data)
        inf(Model, Data,opt)
    elif opt.mode == 'inf':
        inf(Model, Data,opt)

if __name__== "__main__":
    main()


FP16 enbled:  False


                                                          

Initalize Model Successfully using Sine Function!


  scaler = GradScaler(enabled=args.fp16)
  with autocast(enabled=args.fp16):
Epoch [1/100]: 100%|██████████| 32/32 [00:00<00:00, 44.81it/s, loss=10.2]
Epoch [2/100]: 100%|██████████| 32/32 [00:00<00:00, 63.13it/s, loss=5.25]
Epoch [3/100]: 100%|██████████| 32/32 [00:00<00:00, 63.24it/s, loss=3.75]
Epoch [4/100]: 100%|██████████| 32/32 [00:00<00:00, 62.70it/s, loss=2.79]
Epoch [5/100]: 100%|██████████| 32/32 [00:00<00:00, 64.06it/s, loss=2.21] 
Epoch [6/100]: 100%|██████████| 32/32 [00:00<00:00, 62.58it/s, loss=1.94] 
Epoch [7/100]: 100%|██████████| 32/32 [00:00<00:00, 63.81it/s, loss=1.83] 
Epoch [8/100]: 100%|██████████| 32/32 [00:00<00:00, 63.68it/s, loss=1.76] 
Epoch [9/100]: 100%|██████████| 32/32 [00:00<00:00, 63.04it/s, loss=1.71] 
Epoch [10/100]: 100%|██████████| 32/32 [00:00<00:00, 63.40it/s, loss=1.56] 
Epoch [11/100]: 100%|██████████| 32/32 [00:00<00:00, 64.05it/s, loss=1.42] 
Epoch [12/100]: 100%|██████████| 32/32 [00:00<00:00, 63.66it/s, loss=1.34] 
Epoch [13/100]: 100%|███

In [6]:
import numpy as np
import torch
import matplotlib.pyplot as plt
# Reconstruction quality: PSNR between each ground-truth frame (rescaled to [-1, 1])
# and the corresponding network output, followed by the mean over the frames.
data_name = opt.dataset
origin_dir = f'./dataset/{data_name}/'
recons_dir = f'./result/{data_name}/CoordNet/outputs/{opt.spatial}-{opt.temporal}/'
psnr,k = 0,0
line = []


def psnr_fn_paper(gt, pred, diff):
    """PSNR in dB, with ``diff`` as the peak-to-peak range of the reference."""
    return 10. * torch.log10(diff**2 / torch.mean((gt-pred)**2))


for i in range(1,2):
    gt = np.fromfile(origin_dir + f'{i:04d}.raw', dtype=np.float32)

    filename = f"{i:04d}-CoordNet.raw"
    file_path = os.path.join(recons_dir, filename)
    d = np.fromfile(file_path, dtype=np.float32)

    # Min-max rescale the ground truth to [-1, 1] before comparing.
    gt = 2*(gt-np.min(gt))/(np.max(gt)-np.min(gt))-1
    d = torch.from_numpy(d)
    gt = torch.from_numpy(gt)
    diff = gt.max() - gt.min()

    psnr_volume = psnr_fn_paper(gt, d, diff)
    frame_psnr = psnr_volume.item()
    print(f"{i}:{frame_psnr}")
    line.append(frame_psnr)
    psnr += frame_psnr
    k += 1
print(psnr/k)

1:41.776336669921875
41.776336669921875


In [25]:
# Re-export a reconstructed frame as "CoordNet.raw" with its first two values forced
# to -1 and 1.
# NOTE(review): presumably done to pin the value range seen by a viewer -- confirm.
data_name = opt.dataset
origin_dir = f'./dataset/{data_name}/'
recons_dir = f'./result/{data_name}/CoordNet/outputs/{opt.spatial}-{opt.temporal}/'

for i in range(1,2):
    filename = f"{i:04d}-CoordNet.raw"
    file_path = os.path.join(recons_dir, filename)
    d = np.fromfile(file_path, dtype=np.float32)
    d[0], d[1] = -1.0, 1.0
    savename = "CoordNet.raw"
    d_path = os.path.join(recons_dir, savename)
    d.astype(np.float32).tofile(d_path)

In [7]:
import torch
# Report the CUDA/cuDNN configuration PyTorch sees.
cuda_available = torch.cuda.is_available()
print(f"CUDA available: {cuda_available}")
cuda_version = torch.version.cuda
print(f"CUDA version: {cuda_version}")
cudnn_version = torch.backends.cudnn.version()
print(f"cuDNN version: {cudnn_version}")

CUDA available: True
CUDA version: 11.8
cuDNN version: 90100
