In [15]:
import xarray as xr
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
from torch.nn import functional as F
import os
import json
import copy
from mpl_toolkits.axes_grid1 import make_axes_locatable
import numpy as np
import climate_train as ct
import sys
import scipy
import time
import argparse

In [17]:
def cg_options(string_input=None):
    """Build and parse the command-line options of the CG solver.

    Args:
        string_input: optional list of argument strings (handy inside a
            notebook). When it is None or empty, argparse reads ``sys.argv``.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser=argparse.ArgumentParser()
    parser.add_argument("--iter",type=int,default=1000)
    parser.add_argument("--tol",type=float,default=1e-3)
    parser.add_argument("-r","--rootdir",type=str,default='/scratch/cg3306/climate/runs/')
    parser.add_argument("-o","--outdir",type=str,default="default")
    parser.add_argument("--testrun",type=int,default=0)
    parser.add_argument("--action",type=str,default="optimize")
    parser.add_argument("--model_id",type=int,default=0)
    parser.add_argument("--data_address",type=str,default=\
                        '/scratch/ag7531/mlruns/19/bae994ef5b694fc49981a0ade317bf07/artifacts/forcing/')
    parser.add_argument("--relog",type=int,default=0)
    parser.add_argument("--nworkers",type=int,default=3)
    parser.add_argument("--rerun",type=int,default=0)
    parser.add_argument("--timing",type=int,default=0)
    parser.add_argument("-b","--batch",type=int,default=1)
    # None (the default) or an empty list both mean "parse sys.argv";
    # a None default avoids sharing one mutable list across calls.
    if not string_input:
        return parser.parse_args()
    return parser.parse_args(string_input)


def model_bank(mid):
    """Return the hyper-parameters of model ``mid``.

    ``m`` is the stencil width passed to the transforms and ``reg`` the
    weight of the ``reg*M`` term added in ``compute_moments`` (0 for every
    model defined here).

    Raises:
        ValueError: if ``mid`` is not one of the known model ids; without
            this, callers would fail later on ``None['m']``.
    """
    widths={0:3,1:9,2:21}
    if mid not in widths:
        raise ValueError('unknown model id '+str(mid)+'; expected one of '+str(sorted(widths)))
    return {'m':widths[mid],'reg':0}

def load_from_save(args):
    """Prepare the run directory and load (or create) the model and its log.

    File locations are derived from ``args.rootdir``, ``args.outdir`` and
    ``args.model_id``. The last saved model is reloaded unless ``args.rerun``
    is set or the file is absent/unreadable; the JSON log is reloaded unless
    ``args.relog`` is set or the file cannot be opened.

    Returns:
        (model, M, logs, PATH0, PATH1, LOG, root):
        ``model`` is the ``model_bank`` entry, ``M`` the model tensor,
        ``logs`` a dict with "epoch", "train-loss" and "test-loss" lists,
        PATH0 / PATH1 the paths of the last / best model files, LOG the log
        path and ``root`` the run directory.
    """
    root=os.path.join(args.rootdir,args.outdir)
    PATH0=os.path.join(root,"last-model-ID"+str(args.model_id)+".npy")
    PATH1=os.path.join(root,"best-model-ID"+str(args.model_id)+".npy")
    LOG=os.path.join(root,"log-ID"+str(args.model_id)+".json")
    model=model_bank(args.model_id)

    # makedirs creates rootdir and any missing parents in one go and does not
    # fail if the directory already exists.
    os.makedirs(root,exist_ok=True)

    if not args.rerun and os.path.isfile(PATH0):
        try:
            M=torch.tensor(np.load(PATH0),dtype=torch.float32)
            print("Loaded the existing model",flush=True)
        except IOError:
            M=initialize_model(model['m'])
            print("No existing model found - new beginnings",flush=True)
    else:
        M=initialize_model(model['m'])
        print("- new beginnings",flush=True)

    fresh_logs={"epoch":[],"train-loss":[],"test-loss":[]}
    if args.relog:
        logs=fresh_logs
    else:
        try:
            with open(LOG) as f:
                logs=json.load(f)
        except IOError:
            logs=fresh_logs
    return model,M,logs,PATH0,PATH1,LOG,root

def idct22(S4M_,IDCTB):
    """Map the first two blocks of ``S4M_`` through the basis ``IDCTB``.

    Computes ``IDCTB @ S4M_[j] @ IDCTB.T`` for j in {0, 1}.

    Args:
        S4M_: tensor indexable as S4M_[0], S4M_[1], each of shape (k, k).
        IDCTB: 2-D tensor of shape (n, k), e.g. the block-diagonal inverse
            basis returned by ``cosine_transform``.

    Returns:
        tensor of shape (2, n, n) with torch's default dtype, on IDCTB's device.
    """
    n=IDCTB.shape[0]
    S4M=torch.zeros(2,n,n,device=IDCTB.device)
    for j in range(2):
        # IDCTB.T stays in torch (no NumPy round-trip), so GPU tensors work too.
        S4M[j]=torch.matmul(IDCTB,torch.matmul(S4M_[j],IDCTB.T))
    return S4M

def compute_moments(M,datagen,m,DCT,S22flag=False,reg=1e-3,scale=1,time_stats=None,time_rec=False):
    """Accumulate the moment operators used by the CG solver.

    For every feature vector u (the 2*m**2 channels of ``mask*DCT(uv)`` at
    one spatial location) and each component j in {0, 1}, sums

        S4M[j] += (u^T M[j] u) * u u^T
        S22[j] += Sxy[j] * u u^T            (only when S22flag)

    over all locations of every sample of every batch. Both sums are then
    divided by the total mask weight, and ``reg*M`` is added to S4M.

    Args:
        M: (2, 2*m**2, 2*m**2) model tensor; the features are moved to its
            device and the accumulators are allocated there.
        datagen: iterable of (uv, mask, Sxy) batches. Assumed layout (TODO
            confirm against ct.Dataset/ct.default_collate): uv (B, 2, H', W'),
            mask (B, 1, H, W), Sxy (B, 2, H, W).
        m: stencil width; the feature dimension is 2*m**2.
        DCT: module mapping uv to presumably (B, 2*m**2, H, W) features.
        S22flag: also accumulate and return S22.
        reg: weight of the ``reg*M`` term.
        scale: uv and Sxy are divided by this value.
        time_stats: list receiving one [transform, accumulate] timing pair
            per batch when time_rec is set; a new list is used when None.
        time_rec: record timings with time.perf_counter.

    Returns:
        (S4M + reg*M, time_stats), or (S4M + reg*M, S22, time_stats) when
        S22flag is set.
    """
    if time_stats is None:
        time_stats=[]
    C=2*m**2
    S4M=torch.zeros(2,C,C,dtype=torch.float32,device=M.device)
    S22=torch.zeros(2,C,C,dtype=torch.float32,device=M.device) if S22flag else None
    device=ct.get_device()
    tot=0.
    for uv,mask,Sxy in datagen:
        if time_rec:
            time_stats.append([])
        uv=uv/scale
        Sxy=Sxy/scale
        if time_rec:
            tic=time.perf_counter()
        uv,mask=uv.to(device),mask.to(device)
        with torch.set_grad_enabled(False):
            # Bring the features to M's device so the products below never
            # mix devices when ct.get_device() is a GPU.
            features=(mask*DCT(uv)).to(M.device)
        if time_rec:
            time_stats[-1].append(time.perf_counter()-tic)
            tic=time.perf_counter()
        tot+=mask.sum().item()
        nbatch=features.shape[0]
        Sxy=Sxy.to(M.device).reshape(nbatch,2,-1)
        # Every sample of the batch contributes, matching `tot`, which counts
        # the mask of the whole batch.
        for b in range(nbatch):
            u=features[b].reshape(C,-1)                       # (C, n)
            for j in range(2):
                # u^T M[j] u at each of the n locations -> (1, n)
                q=(torch.matmul(M[j],u)*u).sum(0).view(1,-1)
                # sum over locations of q * u u^T -> (C, C)
                S4M[j]+=torch.matmul(u*q,u.T)
                if S22flag:
                    S22[j]+=torch.matmul(u*Sxy[b,j].reshape(1,-1),u.T)
        if time_rec:
            time_stats[-1].append(time.perf_counter()-tic)

    S4M=S4M/tot
    if not S22flag:
        return S4M+reg*M,time_stats
    return S4M+reg*M,S22/tot,time_stats


def cosine_transform(m,nfreq=-1):
    """Build a 2-D DCT-II filter bank as a fixed convolution plus its inverse basis.

    Each filter is the outer product of 1-D cosines
    cos(pi/m * (x + 1/2) * k) for a frequency pair (k0, k1), normalized to
    unit Frobenius norm. Filters are ordered by k0 + k1 (ties broken by the
    smaller k1). Only the first ``nfreq`` filters are kept: all m**2 when
    nfreq <= 0, and at most m**2 otherwise.

    Returns:
        DCT: bias-free ``nn.Conv2d(2, 2*nfreq, m)`` on ``ct.get_device()``;
            output channel i applies filter i to input channel 0, and channel
            nfreq + i applies it to input channel 1.
        IDCTB: (2*m**2, 2*nfreq) float32 block-diagonal matrix whose two
            diagonal blocks hold the flattened filters as columns.
    """
    device=ct.get_device()
    n_all=m**2
    freq=np.array(np.unravel_index(np.arange(n_all),[m,m]))
    order=np.argsort(freq[0]+freq[1]*(1+1e-3))
    freq=freq[:,order]
    # Asking for more filters than exist would index past `freq` below.
    nfreq=n_all if nfreq<=0 else min(nfreq,n_all)
    freq=freq[:,:nfreq]
    DCT=nn.Conv2d(2,2*nfreq,m,bias=False).to(device)
    DCT.weight.data=DCT.weight.data*0

    IDCT=torch.zeros(n_all,nfreq)
    grid=np.arange(0,m)+1/2
    for i in range(nfreq):
        i0,i1=freq[:,i]
        W0=np.cos(np.pi/m*grid*i0)
        W1=np.cos(np.pi/m*grid*i1)
        W=np.outer(W0,W1)
        W=W/np.sqrt(np.sum(W**2))
        W=torch.tensor(W,dtype=torch.float32)
        DCT.weight.data[i,0,:,:]=W
        DCT.weight.data[nfreq+i,1,:,:]=W
        IDCT[:,i]=W.view(-1)
    # torch.block_diag keeps IDCTB in torch and does not depend on the
    # scipy.linalg submodule being reachable through `import scipy`.
    IDCTB=torch.block_diag(IDCT,IDCT)
    return DCT,IDCTB

def shift_transform(m):
    """Build a fixed convolution that unrolls each m x m patch of a 2-channel field.

    Output channel k (0 <= k < m**2) copies input channel 0 at patch offset
    (k // m, k % m); output channel m**2 + k does the same for input channel 1.
    The layer has no bias and is placed on ``ct.get_device()``.
    """
    device=ct.get_device()
    npix=m*m
    SHT=nn.Conv2d(2,2*npix,m,bias=False).to(device)
    weight=torch.zeros_like(SHT.weight.data)
    for k in range(npix):
        row,col=divmod(k,m)
        weight[k,0,row,col]=1
        weight[npix+k,1,row,col]=1
    SHT.weight.data=weight
    return SHT

def initialize_model(m):
    """Return an all-zero float32 model tensor of shape (2, 2*m**2, 2*m**2)."""
    dim=2*m*m
    return torch.zeros((2,dim,dim),dtype=torch.float32)

def conjugate_gradient_algorithm(args):
    """Fit the quadratic model M by conjugate gradients on the training data.

    Solves A(M) = S22 independently for each of the two components, where
    A(P) = compute_moments(P, ...) (the quartic moment of P plus reg*P) and
    S22 is the Sxy-weighted second moment.

    After every iteration:
      * M is saved to PATH0;
      * M is saved to PATH1 when the logged residual norm is the lowest so far;
      * the log is written to LOG.

    Once more than 20 iterations have passed since the last restart, the
    residual is recomputed from scratch whenever the mean of the three
    previous logged errors is less than 1.01 times the latest one.

    Returns:
        (M, err, time_stats): the model tensor, the per-component residual
        norms of the last iteration and the timing records.
    """
    (_,datagen),_,_,_=load_data(args)
    print("loaded data",flush=True)
    model,M,logs,PATH0,PATH1,LOG,root=load_from_save(args)
    print("loaded model",flush=True)
    m=model['m']
    reg=model['reg']
    DCT=shift_transform(m)
    tol=args.tol
    max_iter=args.iter
    time_stats=[]
    if len(M)==0:
        M=initialize_model(m)
    S4p,S22,time_stats=compute_moments(M,datagen,m,DCT,S22flag=True,time_stats=time_stats,reg=reg,time_rec=args.timing)

    # Residual r = S22 - A(M) and search direction p start equal but must be
    # distinct tensors, because both are updated in place below.
    r1=-S4p+S22
    p1=-S4p+S22
    reset=0
    alpha=torch.zeros(2,dtype=torch.float32)
    beta=torch.zeros(2,dtype=torch.float32)
    err=torch.zeros(2,dtype=torch.float32)
    for i in range(max_iter):
        p0=p1*1
        r0=r1*1
        # A(p0): the same linear operator applied to the search direction.
        S4p0,time_stats=compute_moments(p0,datagen,m,DCT,S22flag=False,time_stats=time_stats,reg=reg,time_rec=args.timing)
        # Standard CG step, run independently for each component j.
        for j in range(2):
            rr0=torch.norm(r0[j].view(-1,1))**2
            alpha[j]=rr0/(torch.matmul(p0[j].view(1,-1),S4p0[j].view(-1,1)))
            r1[j]=r0[j]-alpha[j]*S4p0[j]
            beta[j]=torch.norm(r1[j].view(-1,1))**2/rr0
            p1[j]=r1[j]+beta[j]*p0[j]
            M[j]=M[j]+alpha[j]*p0[j]
            err[j]=torch.norm(r1[j].view(-1,1))
        serr=torch.norm(err).item()
        print('#'+str(i)+' err: '+str(err),flush=True)
        logs['train-loss'].append(serr)
        logs['test-loss'].append(serr)

        with open(PATH0,'wb') as outfile:
            np.save(outfile,M.numpy())
        if logs['test-loss'][-1]==np.min(logs['test-loss']):
            with open(PATH1,'wb') as outfile:
                np.save(outfile,M.numpy())
        with open(LOG,'w') as outfile:
            json.dump(logs,outfile)
        reset+=1
        # Converged only when the residual of every component is below tol.
        if err.max().item()<tol:
            return M,err,time_stats
        # Restart from the true residual when progress stalls.
        if reset>20 and np.mean(logs['train-loss'][-4:-1])/logs['train-loss'][-1]<1.01:
            S4p,time_stats=compute_moments(M,datagen,m,DCT,S22flag=False,time_stats=time_stats,reg=reg,time_rec=args.timing)
            r1=-S4p+S22
            p1=-S4p+S22
            reset=0
    return M,err,time_stats
   

def load_data(args):
    """Open the zarr dataset and build the train/test/earth datasets and loaders.

    When ``args.testrun`` > 0, only that many time steps are kept (capped at
    the number available). They are drawn without replacement with a fixed
    seed (0) and kept in time order.

    Returns:
        ((training_set, training_generator), (test_set, test_generator),
         (earth_set, earth_generator), ds_zarr)
    """
    model=model_bank(args.model_id)
    m=model['m']
    spread=(m-1)//2
    ds_zarr=xr.open_zarr(args.data_address)
    tot_time=ds_zarr.time.shape[0]
    params={'batch_size':args.batch,'shuffle':False,'num_workers':args.nworkers}
    partition=ct.physical_domains(1,validation=False)
    if args.testrun>0:
        # rng.choice(..., replace=False) raises when asked for more than tot_time.
        sub_tot_time=min(args.testrun,tot_time)
        rng=np.random.default_rng(0)
        indices=np.sort(rng.choice(tot_time,size=sub_tot_time,replace=False))
        ds_zarr=ds_zarr.sel(time=ds_zarr.time[indices].data)

    def make_split(name):
        # Dataset over one partition plus a DataLoader with the shared settings.
        dataset=ct.Dataset(ds_zarr,partition[name],spread=spread)
        loader=torch.utils.data.DataLoader(dataset,**params,collate_fn=ct.default_collate)
        return dataset,loader

    return make_split('train'),make_split('test'),make_split('earth'),ds_zarr

In [18]:
def analysis(args):
    """Write error statistics of the saved model on the third ('earth') dataset.

    Loads the earth dataset from ``load_data`` and the model from
    ``load_from_save``, then visits the samples in a random order. For each
    component j, the prediction at every location is u^T M[j] u, where u are
    the ``shift_transform`` features of uv.

    Two running means of shape (2, dimens[0]-2*spread, dimens[1]-2*spread)
    are saved under ``root``:
      * ``MSE-ID<model_id>.npy``: squared prediction error;
      * ``SC2-ID<model_id>.npy``: squared target.
    They are written every 10 samples and once more at the end.
    """
    _,_,(dataset,_),_=load_data(args)
    print("loaded data",flush=True)
    model,M,logs,PATH0,PATH1,LOG,root=load_from_save(args)
    device=ct.get_device()
    MSELOC=root+'/MSE-ID'+str(args.model_id)+'.npy'
    SC2LOC=root+'/SC2-ID'+str(args.model_id)+'.npy'
    shape=(2,dataset.dimens[0]-dataset.spread*2,dataset.dimens[1]-dataset.spread*2)
    MSE=torch.zeros(shape)
    SC2=torch.zeros(shape)
    print(MSELOC)
    m=model['m']
    DCT=shift_transform(m)

    def _save_means(mse,sc2,count):
        # Persist the running means over the `count` samples seen so far.
        with open(MSELOC,'wb') as f:
            np.save(f,mse.numpy()/count)
        with open(SC2LOC,'wb') as f:
            np.save(f,sc2.numpy()/count)

    order=np.arange(len(dataset))
    np.random.shuffle(order)
    count=0
    for i,idx in enumerate(order):
        uv,mask,Sxy=dataset[idx]
        uv=torch.stack([uv]).to(device)
        spatxdim=mask.shape[2]
        spatydim=mask.shape[1]
        with torch.set_grad_enabled(False):
            output=DCT(uv)
        # Move the features to M's device so the products below also work
        # when ct.get_device() returns a GPU.
        output=output[0].to(M.device).view(2*m**2,-1)
        # u^T M[j] u at every location, reshaped back onto the grid.
        pred=torch.stack([(torch.matmul(M[j],output)*output).sum(0).view(spatydim,spatxdim) for j in range(2)])
        SC2=SC2+Sxy**2
        MSE=MSE+(Sxy-pred)**2
        count=i+1
        if i%10==0:
            _save_means(MSE,SC2,count)
            print('\t #'+str(i),flush=True)
    # An empty dataset has nothing to average, so skip the final save.
    if count>0:
        _save_means(MSE,SC2,count)
    print('analysis is done',flush=True)

In [4]:
# Disabled cell (kept as a string so it does not run): example arguments for a
# one-sample test run (--testrun 1) of model 2 in output dir 2021-06-29-CG.
'''args=cg_options(string_input=['--outdir','2021-06-29-CG','--tol','1e-3',\
                              '--rerun','0','--relog','0','--testrun','1','--model_id','2','--timing','0'])'''

"args=cg_options(string_input=['--outdir','2021-06-29-CG','--tol','1e-3',                              '--rerun','0','--relog','0','--testrun','1','--model_id','2','--timing','0'])"

In [14]:
# Disabled cell: the first line inside the string records a SLURM array
# command line for --action analysis; the rest builds a one-sample test config.
'''# --outdir 2021-06-28-CG --iter 1000 --testrun 1000 --nworkers 5 --model_id $SLURM_ARRAY_TASK_ID --action analysis
args=cg_options(string_input=['--outdir','2021-06-29-CG','--tol','1e-3',\
                              '--rerun','0','--relog','0','--testrun','1','--model_id','2','--timing','0'])'''

In [15]:
# Disabled cell: would run analysis(args); needs `args` from a cg_options call.
'''analysis(args)'''

loaded data
Loaded the existing model
/scratch/cg3306/climate/runs/2021-06-29-CG/MSE-ID2.npy
	 #0
analysis is done


In [20]:
# Disabled scratch cell: a loop-based version of the prediction computed in
# analysis(), which moves features to CPU explicitly. It stops after the first
# sample (`break`) and never saves MSE/SC2.
'''_,_,(dataset,_),_=load_data(args)
print("loaded data",flush=True)
model,M, logs,PATH0,PATH1,LOG,root=load_from_save(args)
device=ct.get_device()
MSELOC=root+'/MSE-ID'+str(args.model_id)+'.npy'
SC2LOC=root+'/SC2-ID'+str(args.model_id)+'.npy'
MSE=torch.zeros(2,dataset.dimens[0]-dataset.spread*2, dataset.dimens[1]-dataset.spread*2)
SC2=torch.zeros(2,dataset.dimens[0]-dataset.spread*2, dataset.dimens[1]-dataset.spread*2)
print(MSELOC)
m=model['m']
DCT=shift_transform(m)
arr=np.arange(len(dataset))
np.random.shuffle(arr)
for i in range(len(dataset)):
    uv,mask,Sxy=dataset[arr[i]]
    uv=torch.stack([uv]).to(device)
    with torch.set_grad_enabled(False):
        output=DCT.forward(uv)
    output=output[0].to(torch.device("cpu"))
    Sxy_=Sxy*1
    for k in range(2):
        output_=output*1
        for j in range(M.shape[1]):
            output_[j]=(M[k,j,:].view(-1,1,1)*output).sum(0)
        Sxy_[k]=(output_*output).sum(0)
    SC2=SC2 + Sxy**2
    MSE=MSE + (Sxy-Sxy_)**2
    break'''

'_,_,(dataset,_),_=load_data(args)\nprint("loaded data",flush=True)\nmodel,M, logs,PATH0,PATH1,LOG,root=load_from_save(args)\ndevice=ct.get_device()\nMSELOC=root+\'/MSE-ID\'+str(args.model_id)+\'.npy\'\nSC2LOC=root+\'/SC2-ID\'+str(args.model_id)+\'.npy\'\nMSE=torch.zeros(2,dataset.dimens[0]-dataset.spread*2, dataset.dimens[1]-dataset.spread*2)\nSC2=torch.zeros(2,dataset.dimens[0]-dataset.spread*2, dataset.dimens[1]-dataset.spread*2)\nprint(MSELOC)\nm=model[\'m\']\nDCT=shift_transform(m)\narr=np.arange(len(dataset))\nnp.random.shuffle(arr)\nfor i in range(len(dataset)):\n    uv,mask,Sxy=dataset[arr[i]]\n    uv=torch.stack([uv]).to(device)\n    with torch.set_grad_enabled(False):\n        output=DCT.forward(uv)\n    output=output[0].to(torch.device("cpu"))\n    Sxy_=Sxy*1\n    for k in range(2):\n        output_=output*1\n        for j in range(M.shape[1]):\n            output_[j]=(M[k,j,:].view(-1,1,1)*output).sum(0)\n        Sxy_[k]=(output_*output).sum(0)\n    SC2=SC2 + Sxy**2\n    MS

In [21]:
# Disabled cell: 5-iteration timing run from scratch (--rerun 1 --relog 1).
# NOTE(review): model_bank() in this notebook defines only ids 0-2, so
# --model_id 3 would not resolve with the current code.
'''args=cg_options(string_input=['--iter','5','--tol','1e-3',\
                              '--rerun','1','--relog','1','--testrun','1','--model_id','3','--timing','1'])'''

In [22]:
# Disabled cell: would run the CG optimizer with the current `args`.
'''M,err,time_stats=conjugate_gradient_algorithm(args)'''

loaded data
- new beginnings
loaded model
#0 err: tensor([373.4859, 321.0286])
#1 err: tensor([185.8229, 161.9241])
#2 err: tensor([117.6426, 104.7633])
#3 err: tensor([81.4258, 70.7428])
#4 err: tensor([58.8485, 45.0566])


In [3]:
# Disabled cell: 30-iteration timing run of model 2 from scratch
# (--rerun 1 --relog 1) on a single time step (--testrun 1).
'''args=cg_options(string_input=['--iter','30','--tol','1e-3',\
                              '--rerun','1','--relog','1','--testrun','1','--model_id','2','--timing','1'])'''


# Disabled cell: would run the CG optimizer with the current `args`.
'''M,err,time_stats=conjugate_gradient_algorithm(args)'''

'M,err,time_stats=conjugate_gradient_algorithm(args)'

In [None]:
# Disabled cell. NOTE(review): carpet_quadratic_operator is not defined in the
# visible code; confirm where it comes from before re-enabling this cell.
'''carpet_quadratic_operator(M[0],n=9)'''

In [101]:
# Disabled cell: repeats the CG optimizer call used in earlier cells.
'''M,err,time_stats=conjugate_gradient_algorithm(args)'''

'M,err,time_stats=conjugate_gradient_algorithm(args)'

In [None]:
def main():
    """Command-line entry point: parse the options and run the chosen action.

    ``--action optimize`` runs conjugate_gradient_algorithm and
    ``--action analysis`` runs analysis; any other value does nothing.
    """
    args=cg_options()
    dispatch={"optimize":conjugate_gradient_algorithm,"analysis":analysis}
    action=dispatch.get(args.action)
    if action is not None:
        action(args)


# NOTE(review): in an interactive kernel __name__ is typically '__main__' too,
# so this would call cg_options() on the kernel's own argv -- confirm this
# guard is only meant for the exported script.
if __name__=='__main__':
    main()