In [None]:
import os
import logging
import random 
import tqdm

import adios2 as ad2
import torch
import numpy as np

from sklearn.model_selection import train_test_split


In [None]:
conda install adios2

# Load Dataset

In [None]:
def read_f0(istep, expdir=None, iphi=None, inode=0, nnodes=None, average=False,
            randomread=0.0, nchunk=16, fieldline=False):
    """
    Read the ``i_f`` variable of an XGC f0 restart file and normalize it.

    The file read is ``<expdir>/restart_dir/xgc.f0.<istep:05d>.bp``. Its ``i_f``
    array is treated as 4-D with axes (nphi, nmu, nnodes, nvp).

    Args:
        istep: Restart step number used to build the file name.
        expdir: Experiment directory containing ``restart_dir/`` (and
            ``xgc.mesh.bp`` when ``fieldline`` is True).
        iphi: Index on axis 0 to read. ``None`` reads every index on axis 0.
        inode: First node (axis 2) to read.
        nnodes: Number of nodes to read. ``None`` reads from ``inode`` to the end
            (in field-line mode: to the end of the field-line node list).
        average: If True, average over axis 0 and return one sample per node;
            otherwise every (plane, node) pair becomes a separate sample.
        randomread: If > 0, read only this fraction of the ``nchunk``-sized node
            chunks, chosen with ``random.sample``.
        nchunk: Chunk size (in nodes) used by ``randomread``.
        fieldline: If True, select nodes ordered along the ``nextnode`` paths
            read from ``xgc.mesh.bp`` (requires ``networkx``).

    Returns:
        Tuple ``(Z0, Zif, zmu, zsig, zmin, zmax, zlb)``:
        Z0 -- raw samples, shape (nsamples, nmu, nvp) (after optional padding);
        Zif -- Z0 min-max scaled per sample to [0, 1] (constant samples map to 0);
        zmu, zsig, zmin, zmax -- per-sample mean, std, min and max of Z0;
        zlb -- per-sample label: node index, plus ``k * 100_000_000`` for the
        k-th plane read when ``average`` is False.
    """
    def adios2_get_shape(f, varname):
        # Returns (number of available steps, shape tuple) of `varname`.
        nstep = int(f.available_variables()[varname]['AvailableStepsCount'])
        shape = f.available_variables()[varname]['Shape']
        if shape == '':
            ## Accessing Adios1 file: no shape metadata, so read the data and inspect it
            lshape = f.read(varname).shape
        else:
            # Shape metadata is a comma-separated string of extents.
            lshape = tuple(int(x.strip(',')) for x in shape.strip().split())
        return (nstep, lshape)

    fname = os.path.join(expdir, 'restart_dir/xgc.f0.%05d.bp' % istep)
    if randomread > 0.0:
        ## prefetch to get metadata
        with ad2.open(fname, 'r') as f:
            _, nsize = adios2_get_shape(f, 'i_f')
        nmu = nsize[1]
        nvp = nsize[3]
        # Same default as the full read: everything from `inode` to the end.
        _nnodes = nsize[2] - inode if nnodes is None else nnodes
        assert _nnodes % nchunk == 0
        # Resolve the plane selection once, before the loop, so every chunk covers
        # the same planes and np.concatenate along the node axis sees equal shapes.
        nphi = nsize[0] if iphi is None else 1
        iphi = 0 if iphi is None else iphi

        _lnodes = list(range(inode, inode + _nnodes, nchunk))
        lnodes = np.sort(random.sample(_lnodes, k=int(len(_lnodes) * randomread)))

        # `tqdm` may be bound to the module (`import tqdm`) or to the class
        # (`from tqdm import tqdm`); resolve the callable either way.
        progress = getattr(tqdm, 'tqdm', tqdm)
        lf = list()
        li = list()
        with ad2.open(fname, 'r') as f:
            for i in progress(lnodes):
                li.append(np.array(range(i, i + nchunk), dtype=np.int32))
                start = (iphi, 0, i, 0)
                count = (nphi, nmu, nchunk, nvp)
                lf.append(f.read('i_f', start=start, count=count).astype('float64'))
        i_f = np.concatenate(lf, axis=2)
        lb = np.concatenate(li)
    elif fieldline is True:
        import networkx as nx

        fname2 = os.path.join(expdir, 'xgc.mesh.bp')
        with ad2.open(fname2, 'r') as f:
            _nnodes = int(f.read('n_n'))
            nextnode = f.read('nextnode')

        # Link every node to its `nextnode` and keep connected components with
        # at least 16 nodes.
        G = nx.Graph()
        for i in range(_nnodes):
            G.add_node(i)
        for i in range(_nnodes):
            G.add_edge(i, nextnode[i])
            G.add_edge(nextnode[i], i)
        cc = [x for x in list(nx.connected_components(G)) if len(x) >= 16]

        li = list()
        for components in cc:
            # Follow `nextnode` inside the component, break its cycle by removing
            # one edge, and take the longest path of the resulting DAG.
            DG = nx.DiGraph()
            for i in components:
                DG.add_node(i)
            for i in components:
                DG.add_edge(i, nextnode[i])

            cycle = list(nx.find_cycle(DG))
            DG.remove_edge(*cycle[-1])

            path = nx.dag_longest_path(DG)
            # Keep a multiple of 16 nodes from each path.
            li.extend(path[:len(path) - len(path) % 16])

        with ad2.open(fname, 'r') as f:
            _, nsize = adios2_get_shape(f, 'i_f')
            nphi = nsize[0] if iphi is None else 1
            iphi = 0 if iphi is None else iphi
            _nnodes = nsize[2]
            nmu = nsize[1]
            nvp = nsize[3]
            start = (iphi, 0, 0, 0)
            count = (nphi, nmu, _nnodes, nvp)
            logging.info(f"Reading: {start} {count}")
            i_f = f.read('i_f', start=start, count=count).astype('float64')

        _nnodes = len(li) - inode if nnodes is None else nnodes
        lb = np.array(li[inode:inode + _nnodes], dtype=np.int32)
        logging.info(f"Fieldline: {len(lb)}")
        logging.info(f"{lb}")
        i_f = i_f[:, :, lb, :]
    else:
        with ad2.open(fname, 'r') as f:
            _, nsize = adios2_get_shape(f, 'i_f')
            nphi = nsize[0] if iphi is None else 1
            iphi = 0 if iphi is None else iphi
            _nnodes = nsize[2] - inode if nnodes is None else nnodes
            nmu = nsize[1]
            nvp = nsize[3]
            start = (iphi, 0, inode, 0)
            count = (nphi, nmu, _nnodes, nvp)
            logging.info(f"Reading: {start} {count}")
            i_f = f.read('i_f', start=start, count=count).astype('float64')
        lb = np.array(range(inode, inode + _nnodes), dtype=np.int32)

    # Pad a 39-wide last axis to 40 by repeating its last entry; axis 1 is padded
    # with its index 38 as well.
    # NOTE(review): this assumes axis 1 also has 39 entries whenever axis 3 does — confirm.
    if i_f.shape[3] == 39:
        i_f = np.append(i_f, i_f[..., 38:39], axis=3)
        i_f = np.append(i_f, i_f[:, 38:39, :, :], axis=1)

    # (nphi, nmu, nodes, nvp) -> (nphi, nodes, nmu, nvp): one (nmu, nvp) image per node.
    Z0 = np.moveaxis(i_f, 1, 2)

    if average:
        Z0 = np.mean(Z0, axis=0)
        zlb = lb
    else:
        Z0 = Z0.reshape((-1, Z0.shape[2], Z0.shape[3]))
        # Offset labels per plane so (plane, node) pairs stay unique.
        zlb = np.concatenate([k * 100_000_000 + lb for k in range(nphi)])

    zmu = np.mean(Z0, axis=(1, 2))
    zsig = np.std(Z0, axis=(1, 2))
    zmin = np.min(Z0, axis=(1, 2))
    zmax = np.max(Z0, axis=(1, 2))
    zrange = zmax - zmin
    # Constant samples would give 0/0 = NaN; scale them to 0 instead.
    zrange = np.where(zrange == 0, 1.0, zrange)
    Zif = (Z0 - zmin[:, np.newaxis, np.newaxis]) / zrange[:, np.newaxis, np.newaxis]

    return (Z0, Zif, zmu, zsig, zmin, zmax, zlb)

def read_data(data_dir, num_channels=1, istep=420, iphi=0, test_size=0.10, random_state=42):
    """
    Load one XGC f0 step and split it into train/test sample windows.

    Consecutive rows of the normalized array returned by ``read_f0`` are grouped
    into non-overlapping windows of ``num_channels`` rows; each window becomes one
    input of shape (num_channels, H, W) labelled with the matching ``zlb`` entries.
    Trailing rows that do not fill a whole window are dropped.

    Args:
        data_dir: Experiment directory passed to ``read_f0`` as ``expdir``.
        num_channels: Number of consecutive rows stacked per input.
        istep: Restart step to read.
        iphi: Plane index passed to ``read_f0``.
        test_size: Fraction of windows held out for testing.
        random_state: Seed for the train/test shuffle.

    Returns:
        ``(X_train, X_test, y_train, y_test)`` from ``train_test_split``.
    """
    _, Zif, _, _, _, _, zlb = read_f0(istep, expdir=data_dir, iphi=iphi)
    # `+ 1` so the window ending exactly at the last row is included.
    starts = range(0, len(Zif) - num_channels + 1, num_channels)
    lx = [Zif[i:i + num_channels, :, :] for i in starts]
    ly = [zlb[i:i + num_channels] for i in starts]

    X_train, X_test, y_train, y_test = train_test_split(lx, ly, test_size=test_size,
                                                        random_state=random_state)

    return X_train, X_test, y_train, y_test


class XGCDataset(torch.utils.data.Dataset):
    """
    Map-style dataset that tiles each image into non-overlapping square patches.

    Every element of ``base_X`` is an array of shape (channels, H, W); only
    channel 0 is tiled. Item ``i`` is patch ``sub_ids[i]`` (row-major order) of
    image ``ids[i]``. Trailing rows/columns that do not fill a whole patch are
    ignored, so every item has the same shape.

    Args:
        base_X: Sequence of image arrays, all with the same (H, W).
        base_Y: Sequence of labels, one per image.
        split: Split name (e.g. "train"/"test"); stored only.
        transform: Optional callable applied to each (1, p, p) patch.
        patch_size: Patch edge length p in pixels.
    """

    def __init__(self, base_X, base_Y, split, transform=None, patch_size=5):
        super().__init__()
        self.patch_size = patch_size
        self.split = split
        self.image_list = base_X
        self.label_list = base_Y
        self.transform = transform

        orig_h, orig_w = base_X[0].shape[1], base_X[0].shape[2]
        n_rows = orig_h // patch_size
        # Patches per row; maps a flat patch index to (row, col) in __getitem__.
        self.image_size = orig_w // patch_size
        # Count whole patches only: a partial patch would have a different shape
        # and break batching in a DataLoader.
        self.num_patches = n_rows * self.image_size

        self.ids = []
        self.sub_ids = []
        for i in range(len(self.image_list)):
            self.ids += self.num_patches * [i]
            self.sub_ids += range(self.num_patches)

        print('data init:', self.num_patches, len(self.image_list), len(self.ids), len(self.sub_ids))

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, i):
        """Return ``{'image': (1, p, p) float tensor, 'label', 'rsid', 'csid'}``.

        ``rsid``/``csid`` are the row/column of the patch's top-left pixel.
        """
        orig_image = self.image_list[self.ids[i]][0, :, :]

        ridx, cidx = divmod(self.sub_ids[i], self.image_size)
        rs = ridx * self.patch_size
        re = rs + self.patch_size
        cs = cidx * self.patch_size
        ce = cs + self.patch_size

        image = orig_image[rs:re, cs:ce][np.newaxis, :, :]
        if self.transform:
            image = self.transform(image)

        sample = {'image': torch.as_tensor(image.copy()).float(),
                  'label': self.label_list[self.ids[i]],
                  'rsid': rs, 'csid': cs}

        return sample

# Hypernetwork Module

In [None]:
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
import torch.nn.functional as F

class HyperNetwork(nn.Module):
    """
    Generates a conv kernel of shape (out_size, in_size, f_size, f_size) from an
    embedding vector ``z`` of length ``z_dim`` via two affine maps.

    Parameters are created on the CPU; move the owning module with
    ``.to(device)`` / ``.cuda()`` to place them on a GPU.
    """

    def __init__(self, f_size=3, z_dim=64, out_size=16, in_size=16):
        super(HyperNetwork, self).__init__()
        self.z_dim = z_dim
        self.f_size = f_size
        self.out_size = out_size
        self.in_size = in_size

        # fmod(randn, 2) keeps initial values inside (-2, 2).
        self.w1 = Parameter(torch.fmod(torch.randn((self.z_dim, self.out_size * self.f_size * self.f_size)), 2))
        self.b1 = Parameter(torch.fmod(torch.randn((self.out_size * self.f_size * self.f_size)), 2))

        self.w2 = Parameter(torch.fmod(torch.randn((self.z_dim, self.in_size * self.z_dim)), 2))
        self.b2 = Parameter(torch.fmod(torch.randn((self.in_size * self.z_dim)), 2))

    def forward(self, z):
        """Map ``z`` (length z_dim) to a kernel of shape (out_size, in_size, f_size, f_size)."""
        # First map: z -> in_size hidden vectors of length z_dim.
        h_in = torch.matmul(z, self.w2) + self.b2
        h_in = h_in.view(self.in_size, self.z_dim)

        # Second map: each hidden vector -> out_size * f_size * f_size values,
        # then reinterpreted (view, not transpose) as the kernel.
        h_final = torch.matmul(h_in, self.w1) + self.b1
        kernel = h_final.view(self.out_size, self.in_size, self.f_size, self.f_size)

        return kernel







In [None]:
#ResNet Block

class IdentityLayer(nn.Module):
    """No-op module: ``forward`` hands back its input object unchanged.

    Serves as the shortcut branch of a residual block when no projection is needed.
    """

    def forward(self, x):
        passthrough = x
        return passthrough


class ResNetBlock(nn.Module):
    """
    Two-conv residual block whose conv weights are supplied at call time
    (here: generated by a hypernetwork) rather than owned by the module.

    Args:
        in_size: Input channels.
        out_size: Output channels.
        downsample: If True, the first conv and the shortcut use stride 2.

    The shortcut is a 1x1 convolution when downsampling or when
    ``in_size != out_size`` (so the residual matches the main path's shape);
    otherwise it is the identity.
    """

    def __init__(self, in_size=16, out_size=16, downsample=False):
        super(ResNetBlock, self).__init__()
        self.out_size = out_size
        self.in_size = in_size
        self.stride1 = 2 if downsample else 1
        if downsample or in_size != out_size:
            self.reslayer = nn.Conv2d(in_channels=self.in_size, out_channels=self.out_size,
                                      stride=self.stride1, kernel_size=1)
        else:
            self.reslayer = IdentityLayer()

        self.bn1 = nn.BatchNorm2d(out_size)
        self.bn2 = nn.BatchNorm2d(out_size)

    def forward(self, x, conv1_w, conv2_w):
        """
        Args:
            x: Input of shape (N, in_size, H, W).
            conv1_w: Weight (out_size, in_size, k, k) of the first conv.
            conv2_w: Weight (out_size, out_size, k, k) of the second conv.

        Padding is ``k // 2`` ("same" padding), which assumes odd k.
        """
        residual = self.reslayer(x)

        out = F.relu(self.bn1(F.conv2d(x, conv1_w, stride=self.stride1,
                                       padding=conv1_w.shape[-1] // 2)), inplace=True)
        out = self.bn2(F.conv2d(out, conv2_w, padding=conv2_w.shape[-1] // 2))

        out += residual

        return F.relu(out)


In [None]:

class Embedding(nn.Module):
    """
    Learned z-vectors for one hypernetwork-generated weight tensor.

    ``z_num = (h, k)`` gives an h x k grid of z-vectors, each of length ``z_dim``.
    ``forward`` runs every vector through ``hyper_net`` and tiles the resulting
    kernels: the kernels of grid row i are concatenated along dim 1, and the rows
    are then concatenated along dim 0.

    Parameters are created on the CPU; move the owning module with
    ``.to(device)`` / ``.cuda()`` to place them on a GPU.
    """

    def __init__(self, z_num, z_dim):
        super(Embedding, self).__init__()

        self.z_list = nn.ParameterList()
        self.z_num = z_num
        self.z_dim = z_dim

        h, k = self.z_num

        # Row-major: z_list[i * k + j] is grid cell (i, j).
        for _ in range(h * k):
            self.z_list.append(Parameter(torch.fmod(torch.randn(self.z_dim), 2)))

    def forward(self, hyper_net):
        """Return the tiled kernel built from all z-vectors via ``hyper_net``."""
        h, k = self.z_num
        rows = []
        for i in range(h):
            row = [hyper_net(self.z_list[i * k + j]) for j in range(k)]
            rows.append(torch.cat(row, dim=1))
        return torch.cat(rows, dim=0)


class PrimaryNetwork(nn.Module):
    """
    Patch auto-encoder whose 18 residual blocks receive their conv weights from
    one shared HyperNetwork (``self.hope``) driven by per-conv z-embeddings.

    Input: (N, 1, patch_size, patch_size). Output: same shape, produced by a final
    linear layer applied to the flattened feature map.
    """

    def __init__(self, z_dim=64, patch_size=5):
        super(PrimaryNetwork, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.patch_size = patch_size
        self.z_dim = z_dim
        self.hope = HyperNetwork(z_dim=self.z_dim)

        # Grid of 16x16 kernel tiles per conv weight: [out_channels/16, in_channels/16].
        # Entries 2*i and 2*i+1 belong to the two convs of residual block i.
        self.zs_size = [[1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1],
                        [2, 1], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2],
                        [4, 2], [4, 4], [4, 4], [4, 4], [4, 4], [4, 4], [4, 4], [4, 4], [4, 4], [4, 4], [4, 4], [4, 4]]

        # [in_channels, out_channels] of each residual block.
        self.filter_size = [[16,16], [16,16], [16,16], [16,16], [16,16], [16,16], [16,32], [32,32], [32,32], [32,32],
                            [32,32], [32,32], [32,64], [64,64], [64,64], [64,64], [64,64], [64,64]]

        self.res_net = nn.ModuleList()

        # Blocks 6 and 12 downsample (3x3 stride-2 conv, padding 1: s -> ceil(s / 2));
        # track the spatial size to size the final linear layer.
        spatial = patch_size
        for i in range(len(self.filter_size)):
            down_sample = i > 5 and i % 6 == 0
            if down_sample:
                spatial = (spatial + 1) // 2
            self.res_net.append(ResNetBlock(self.filter_size[i][0], self.filter_size[i][1], downsample=down_sample))

        self.zs = nn.ModuleList()
        for size in self.zs_size:
            self.zs.append(Embedding(size, self.z_dim))

        # Not used in forward(); kept so existing references to the attribute still work.
        self.global_avg = nn.AvgPool2d(8)
        # 64 * 2 * 2 = 256 features for patch sizes 5..8; scales with other sizes.
        self.final = nn.Linear(self.filter_size[-1][1] * spatial * spatial,
                               self.patch_size * self.patch_size)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        for i, block in enumerate(self.res_net):
            # Generate this block's two conv weights from its z-embeddings.
            w1 = self.zs[2 * i](self.hope)
            w2 = self.zs[2 * i + 1](self.hope)
            x = block(x, w1, w2)

        x = self.final(torch.flatten(x, 1))
        x = x.view(-1, 1, self.patch_size, self.patch_size)

        return x


# Train Model

In [None]:
import os
from torch.autograd import Variable
import time

import torch.optim as optim


In [None]:
def set_data(patch_size, batch_size, data_dir=None, num_workers=4, shuffle_train=True):
    """
    Build the train and test DataLoaders of XGC patches.

    Args:
        patch_size: Patch edge length passed to XGCDataset.
        batch_size: Batch size of both loaders.
        data_dir: Data directory; defaults to the notebook-level ``dir_data``.
        num_workers: DataLoader worker processes.
        shuffle_train: Shuffle the training loader every epoch. The test loader
            is never shuffled, so its order follows the test set.

    Returns:
        ``(trainloader, testloader)``.
    """
    if data_dir is None:
        data_dir = dir_data
    X_train, X_test, y_train, y_test = read_data(data_dir)

    trainset = XGCDataset(X_train, y_train, split="train",
                          transform=None, patch_size=patch_size)
    # XGCDataset stores all patches of one image consecutively; shuffling keeps
    # training batches from being dominated by a single image.
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              shuffle=shuffle_train, num_workers=num_workers)

    testset = XGCDataset(X_test, y_test, split="test",
                         transform=None, patch_size=patch_size)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False, num_workers=num_workers)

    return trainloader, testloader



In [None]:
def validate(device, net, testloader, criterion, batch_size):
    """
    Compute the reconstruction loss of ``net`` over ``testloader``.

    Runs in eval mode under ``torch.no_grad()``, so BatchNorm uses (and does not
    update) its running statistics. The previous train/eval mode is restored
    afterwards.

    Args:
        device: Device the input batches are moved to.
        net: Model mapping an image batch to a same-shaped reconstruction.
        testloader: Iterable of dicts with an ``'image'`` tensor.
        criterion: Loss called as ``criterion(output, target)``; assumed to be
            mean-reduced (like ``nn.MSELoss()``).
        batch_size: Unused; kept for call compatibility. The sample count comes
            from the batches themselves, so partial batches are counted correctly.

    Returns:
        ``(loss_sum, total)``: the loss summed over samples (each batch's mean
        weighted by its size) and the number of samples, so ``loss_sum / total``
        is the per-sample mean loss.
    """
    was_training = net.training
    net.eval()
    loss_sum = 0.0
    total = 0
    try:
        with torch.no_grad():
            for tdata in testloader:
                timages = tdata['image'].to(device)
                toutputs = net(timages)
                n = timages.size(0)
                loss_sum += criterion(toutputs, timages).item() * n
                total += n
    finally:
        net.train(was_training)

    return loss_sum, total


In [None]:
# Training configuration (read as notebook globals by set_data/train).
epochs = 100
batch_size = 256
save_freq = 20
patch_size = 8
print_freq = 20
dir_out = 'checkpoint/'
dir_resume = 'checkpoint/hypernetworks_plasma.pth'  # checkpoint file path (no trailing slash)
# Override with the XGC_DATA_DIR environment variable.
# NOTE(review): this default is relative to the working directory; the shared
# dataset presumably lives under an absolute '/gpfs/...' path — confirm.
dir_data = os.environ.get('XGC_DATA_DIR', 'gpfs/alpine/world-shared/csc143/jyc/summit/d3d_coarse_small_v2')
resume = False

# Seed every RNG the notebook uses so runs are reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

os.makedirs(dir_out, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Calculate Model Parameters

In [None]:
# Count all parameters (trainable or not) of a freshly constructed PrimaryNetwork.
net = PrimaryNetwork(patch_size=patch_size)
model_total_params = sum(p.numel() for p in net.parameters())
print(model_total_params)

In [None]:
conda install pytorch-model-summary


In [None]:
from pytorch_model_summary import summary

# Count the trainable parameters that live inside the residual blocks (`res_net.*`:
# BatchNorm layers and shortcut convs). Their 3x3 conv weights are generated by the
# hypernetwork, so those are counted under `hope.*` / `zs.*`, not here.
net = PrimaryNetwork(patch_size=8)
total = 0
for name, param in net.named_parameters():
    if param.requires_grad:
        print(name, param.numel())
        # str.find returns 0 (falsy) for a prefix match, so test membership explicitly.
        if name.startswith('res_net'):
            total += param.numel()

print('total:', total)
    

# Train the Model

In [None]:
def train(device):
    """
    Train a PrimaryNetwork patch auto-encoder and checkpoint it.

    Reads notebook-level configuration: ``patch_size``, ``batch_size``,
    ``epochs``, ``print_freq``, ``dir_out``, ``resume`` and ``dir_resume``.

    Each epoch trains on every batch, evaluates with ``validate``, appends
    ``epoch,loss,val_loss`` to ``<dir_out>loss.txt`` and saves
    ``<dir_out>hypernetworks_plasma_<epoch>.pth`` whenever the validation metric
    improves (lower is better; stored under the ``'acc'`` key). The final weights
    are saved to ``<dir_out>last.pth``.

    Args:
        device: torch device the model and batches are moved to.
    """
    trainloader, testloader = set_data(patch_size, batch_size)

    net = PrimaryNetwork(patch_size=patch_size)
    best_accuracy = 10000.  # best validation metric so far; lower is better

    if resume:
        # NOTE: torch.load unpickles the file; only resume from trusted checkpoints.
        ckpt = torch.load(dir_resume, map_location=device)
        net.load_state_dict(ckpt['net'])
        best_accuracy = ckpt['acc']

    net.to(device)

    learning_rate = 0.002
    weight_decay = 0.0005
    # The scheduler is stepped once per batch, so milestones count iterations.
    milestones = [168000, 336000, 400000, 450000, 550000, 600000]

    optimizer = optim.Adam(net.parameters(), lr=learning_rate, weight_decay=weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=milestones, gamma=0.5)
    criterion = nn.MSELoss()

    num_batches = len(trainloader)
    global_step = 0
    accuracy = best_accuracy  # defined even if no epoch runs
    print('data loader:', len(trainloader), len(testloader))

    with open(dir_out + 'loss.txt', 'w') as loss_file:
        loss_file.write('epoch,loss,val_loss\n')
        for epoch in range(epochs):
            net.train()
            running_loss = 0.0
            epoch_loss = 0.0
            for i, data in enumerate(trainloader):
                inputs = data['image'].to(device)

                optimizer.zero_grad()
                outputs = net(inputs)
                # Auto-encoder objective: reconstruct the input patch.
                loss = criterion(outputs, inputs)
                loss.backward()
                optimizer.step()
                lr_scheduler.step()
                global_step += 1

                running_loss += loss.item()
                epoch_loss += loss.item()
                if i % print_freq == (print_freq - 1):
                    print("[Epoch %d, Total Iterations %4d] Loss: %.4f" % (epoch + 1,
                                                                           global_step,
                                                                           running_loss / print_freq))
                    running_loss = 0.0

            correct, total = validate(device, net, testloader, criterion, batch_size)
            epoch_loss /= max(num_batches, 1)
            val_loss = correct / total
            loss_file.write(str(epoch + 1) + ',' + str(epoch_loss) + ',' + str(val_loss) + '\n')
            loss_file.flush()  # keep loss.txt current during long runs
            accuracy = (100. * correct) / total
            print('After epoch %d, accuracy: %.4f %%' % (epoch + 1, accuracy))

            if accuracy < best_accuracy:
                print('Saving model...')
                state = {
                    'net': net.state_dict(),
                    'acc': accuracy
                }
                torch.save(state, dir_out + 'hypernetworks_plasma_' + str(epoch + 1) + '.pth')
                best_accuracy = accuracy

    print('Finished Training')
    state = {
        'net': net.state_dict(),
        'acc': accuracy
    }
    torch.save(state, dir_out + 'last.pth')
    


In [None]:
# Run training with the configuration cell's globals; checkpoints and loss.txt go to dir_out.
train(device)

print('Finished Training!!')   

# Evaluate on Test Data

In [None]:
def load_data(batch_size, dir_data, patch_size):
    """
    Build an unshuffled test DataLoader for ``dir_data``.

    Returns:
        ``(testloader, X_test, y_test)``; the loader yields the patches of the
        test images in the order of ``X_test``.
    """
    _, X_test, _, y_test = read_data(dir_data)
    test_patches = XGCDataset(X_test, y_test, split="test",
                              transform=None, patch_size=patch_size)
    testloader = torch.utils.data.DataLoader(
        test_patches,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
    )
    return testloader, X_test, y_test


In [None]:
def load_model(device, dir_model, patch_size=None):
    """
    Rebuild a PrimaryNetwork from a checkpoint written by ``train``.

    Args:
        device: Target device; also used as ``map_location`` when loading.
        dir_model: Path to a ``.pth`` checkpoint holding a ``'net'`` state dict.
        patch_size: Patch size of the network; defaults to the notebook-level
            ``patch_size``.

    Returns:
        The model with its weights loaded, moved to ``device``.
    """
    if patch_size is None:
        patch_size = globals()['patch_size']
    net = PrimaryNetwork(patch_size=patch_size)
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(dir_model, map_location=device)

    # Drop a leading "module." (e.g. from checkpoints saved through an
    # nn.DataParallel wrapper) without touching the rest of the key.
    prefix = "module."
    state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in ckpt['net'].items()
    }

    net.load_state_dict(state_dict)
    return net.to(device)


In [None]:
def aggregate(map_pred_img, predictions, rid, cid, labels, patch_size):
    """
    Paste predicted patches back into their full-size images.

    Patch ``n`` (``predictions[n]``, patch_size x patch_size) is written into
    ``map_pred_img[labels[n][0]]`` with its top-left corner at row ``rid[n]``,
    column ``cid[n]``. ``map_pred_img`` is modified in place and also returned.
    """
    for n, label in enumerate(labels):
        top, left = rid[n], cid[n]
        canvas = map_pred_img[label[0]]
        canvas[top:top + patch_size, left:left + patch_size] = predictions[n, :, :]
    return map_pred_img


In [None]:
def rmse(predictions, targets):
    """Root-mean-square error over all elements of two equally shaped arrays."""
    residual = predictions - targets
    mean_square = (residual ** 2).mean()
    return np.sqrt(mean_square)

In [None]:
from pytorch_model_summary import summary

# Layer-by-layer summary of a trained checkpoint for one (1, patch_size, patch_size) patch.
model = load_model(device, 'checkpoint/hypernetworks_plasma_51.pth')
# Eval mode so the dummy forward pass does not update BatchNorm running stats.
model.eval()
# The dummy input must live on the same device as the model's parameters.
model_device = next(model.parameters()).device
print(summary(model, torch.zeros((1, 1, patch_size, patch_size), device=model_device),
              show_input=False, print_summary=True))

In [None]:
def test(device, batch_size, patch_size, dir_out, dir_model, dir_data):
    """
    Reconstruct every test image patch-by-patch with a saved model and report RMSE.

    Predicted patches are pasted back into full-size canvases (keyed by the first
    label of each test sample) and compared with channel 0 of the original image
    via ``rmse``. Per-image errors and their sum are written to
    ``<dir_out>mse_xgc_testrmse_<batch_size>_<num_batches>.txt``.

    Returns:
        ``(Xtest, ytest, X_pred)`` where ``X_pred[l]`` reconstructs ``Xtest[l][0]``.
    """
    model = load_model(device, dir_model)
    model.to(device)
    testloader, Xtest, ytest = load_data(batch_size, dir_data, patch_size)
    model.eval()

    # One blank canvas per test image, sized like the image itself.
    map_pred_img_ensemble = {tid[0]: np.zeros(x[0].shape) for x, tid in zip(Xtest, ytest)}

    print('testloader:', len(testloader))

    with torch.no_grad():
        for data in testloader:
            timages = data['image'].to(device)
            # Drop only the channel axis so a final batch of size 1 keeps its batch axis.
            predicted = model(timages)[:, 0].cpu().numpy()
            map_pred_img_ensemble = aggregate(map_pred_img_ensemble, predicted,
                                              list(data['rsid'].numpy()),
                                              list(data['csid'].numpy()),
                                              list(data['label'].numpy()), patch_size)

    fname = 'mse_xgc_testrmse_' + str(batch_size) + '_' + str(len(testloader)) + '.txt'
    total_err = 0
    X_pred = []
    with open(dir_out + fname, 'w') as error_file:
        for x, tid_row in zip(Xtest, ytest):
            tid = tid_row[0]
            reconstruction = map_pred_img_ensemble[tid]
            loss = rmse(reconstruction, x[0])
            total_err += loss
            error_file.write(str(tid) + ',' + str(loss.item()) + '\n')
            X_pred.append(reconstruction)

        # total_err sums one RMSE per image, so average over images.
        print('total:', total_err / max(len(ytest), 1))
        error_file.write('total,' + str(total_err) + '\n')

    return Xtest, ytest, X_pred


In [None]:
# Evaluation settings.
batch_size, patch_size = 256, 8
dir_out = 'results/'
dir_model = 'checkpoint/hypernetworks_plasma_51.pth'
dir_data = 'gpfs/alpine/world-shared/csc143/jyc/summit/d3d_coarse_small_v2'

results_missing = not os.path.exists(dir_out)
if results_missing:
    os.makedirs(dir_out)


In [None]:
# Use the GPU when available and run the evaluation; returns the test images,
# their labels and the reconstructed images.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
xtest, ytest, xpred = test(device,batch_size, patch_size, dir_out, dir_model, dir_data)

In [None]:
# Save each reconstruction as <dir_out>f_pred_<label>.npy (one image per file).
for label_row, pred in zip(ytest, xpred):
    np.save(dir_out + 'f_pred_' + str(label_row[0]), pred)


In [None]:
import matplotlib.pyplot as plt

# Side-by-side original vs. reconstruction for every test image, with 5 white
# contour levels overlaid on each panel.
_, ny, nx = xtest[0].shape
ix = np.linspace(0, nx-1, nx)
iy = np.linspace(0, ny-1, ny)
Mx, My = np.meshgrid(ix, iy)

for l in range(len(ytest)):
    print('idx:', l)
    fig, axes = plt.subplots(1, 2, figsize=(10, 6))
    panels = [(xtest[l][0], 'original %d' % (ytest[l][0])),
              (xpred[l], 'predicted %d' % (ytest[l][0]))]
    for ax, (img, title) in zip(axes, panels):
        im = ax.imshow(img, origin='lower')
        fig.colorbar(im, ax=ax)
        ax.contour(Mx, My, img, levels=5, origin='image', colors='white', alpha=0.5)
        ax.axis('scaled')
        ax.axis('off')
        ax.set_title(title)
    fig.tight_layout()
    plt.show()
    # Release the figure; one is created per test image.
    plt.close(fig)
    
