In [1]:
import argparse
import numpy as np
from tqdm import tqdm
from earlystopping import EarlyStopping
import sys
import torch.optim as optim
from torch.optim import lr_scheduler
from torch import nn
import torch
from data.seqgen import ClimateData
from data.seqgen_multi_channel import ClimateData_MC
from net_params import convlstm_encoder_params, convlstm_decoder_params
from net_params_large import convlstm_encoder_params_large, convlstm_decoder_params_large
from model import ED
from decoder import Decoder
from encoder import Encoder
import os
import datetime
from torch.utils.tensorboard import SummaryWriter
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

sys.argv = ['main.py']
# TIMESTAMP = str(datetime.datetime.now()).replace(" ", "")
parser = argparse.ArgumentParser()
parser.add_argument('-clstm',
                    '--convlstm',
                    help='use convlstm as base cell',
                    action='store_true')
parser.add_argument('--batch_size',
                    default=1,
                    type=int,
                    help='mini-batch size')
parser.add_argument('-lr', default=1e-4, type=float, help='G learning rate')
parser.add_argument('-frames_input',
                    default=7,
                    type=int,
                    help='sum of input frames')
parser.add_argument('-frames_output',
                    default=1,
                    type=int,
                    help='sum of predict frames')
parser.add_argument('-epochs', default=1, type=int, help='sum of epochs')
args = parser.parse_args()

random_seed = 1996
np.random.seed(random_seed)
torch.manual_seed(random_seed)
if torch.cuda.device_count() > 1:
    torch.cuda.manual_seed_all(random_seed)
else:
    torch.cuda.manual_seed(random_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

day = 137
runtype = "mc-5-pred-" + str(day)
target = 0
small = True
data_loc = "./data/fin21data5.npz"


TIMESTAMP = runtype + str(target)

save_dir = './save_model/' + TIMESTAMP

if small:
    trainFolder = ClimateData(is_train=True,
                                root=data_loc,
                                n_frames_input=args.frames_input,
                                n_frames_output=args.frames_output, target=target)
    validFolder = ClimateData(is_train=False,
                                root=data_loc,
                                n_frames_input=args.frames_input,
                                n_frames_output=args.frames_output, target=target)
else:
    trainFolder = ClimateData_MC(is_train=True,
                                    root=data_loc,
                                    n_frames_input=args.frames_input,
                                    n_frames_output=args.frames_output, target=target)
    validFolder = ClimateData_MC(is_train=False,
                                    root=data_loc,
                                    n_frames_input=args.frames_input,
                                    n_frames_output=args.frames_output, target=target)
trainLoader = torch.utils.data.DataLoader(trainFolder,
                                            batch_size=args.batch_size,
                                            shuffle=False)
validLoader = torch.utils.data.DataLoader(validFolder,
                                            batch_size=args.batch_size,
                                            shuffle=False)

if small:
    encoder_params = convlstm_encoder_params
    decoder_params = convlstm_decoder_params
else:
    encoder_params = convlstm_encoder_params_large
    decoder_params = convlstm_decoder_params_large

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
encoder = Encoder(encoder_params[0], encoder_params[1]).cuda()
decoder = Decoder(decoder_params[0], decoder_params[1]).cuda()
model0 = ED(encoder, decoder).cuda()
model1 = ED(encoder, decoder).cuda()
model2 = ED(encoder, decoder).cuda()
model3 = ED(encoder, decoder).cuda()
model4 = ED(encoder, decoder).cuda()
model5 = ED(encoder, decoder).cuda()
model6 = ED(encoder, decoder).cuda()




In [2]:
# checkpoint0 = torch.load(
#     './save_model/mc_wowind_adamw0/checkpoint_2_0.002941.pth.tar')
# model0.load_state_dict(checkpoint0['state_dict'])
# ### now you can evaluate it
# model0.eval()


In [4]:
sum(p.numel() for p in model0.parameters())
sum(p.numel() for p in model1.parameters())


14736833

In [52]:
# checkpoint0 = torch.load(
#     './save_model/wind_adamw0/checkpoint_1_0.000017.pth.tar')
# model0.load_state_dict(checkpoint0['state_dict'])
# ### now you can evaluate it
# model0.eval()

# checkpoint1 = torch.load(
#     './save_model/wind_adamw1/checkpoint_1_0.000024.pth.tar')
# model1.load_state_dict(checkpoint1['state_dict'])
# ### now you can evaluate it
# model1.eval()

# checkpoint2 = torch.load(
#     './save_model/wind_adamw2/checkpoint_1_0.000857.pth.tar')
# model2.load_state_dict(checkpoint2['state_dict'])
# ### now you can evaluate it
# model2.eval()

# checkpoint3 = torch.load(
#     './save_model/wind_adamw3/checkpoint_1_0.000029.pth.tar')
# model3.load_state_dict(checkpoint3['state_dict'])
# ### now you can evaluate it
# model3.eval()

# checkpoint4 = torch.load(
#     './save_model/wind_adamw4/checkpoint_1_0.001435.pth.tar')
# model4.load_state_dict(checkpoint4['state_dict'])
# ### now you can evaluate it
# model4.eval()

# checkpoint5 = torch.load(
#     './save_model/wind_adamw5/checkpoint_1_0.000018.pth.tar')
# model5.load_state_dict(checkpoint5['state_dict'])
# ### now you can evaluate it
# model5.eval()

# checkpoint6 = torch.load(
#     './save_model/wind_adamw6/checkpoint_1_0.000021.pth.tar')
# model6.load_state_dict(checkpoint6['state_dict'])
# ### now you can evaluate it
# model6.eval()





ED(
  (encoder): Encoder(
    (stage1): Sequential(
      (conv1_leaky_1): Conv2d(7, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (leaky_conv1_leaky_1): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (rnn1): CLSTM_cell(
      (conv): Sequential(
        (0): Conv2d(96, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
        (1): GroupNorm(8, 256, eps=1e-05, affine=True)
      )
    )
    (stage2): Sequential(
      (conv2_leaky_1): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (leaky_conv2_leaky_1): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (rnn2): CLSTM_cell(
      (conv): Sequential(
        (0): Conv2d(192, 512, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
        (1): GroupNorm(16, 512, eps=1e-05, affine=True)
      )
    )
    (stage3): Sequential(
      (conv3_leaky_1): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (leaky_conv3_leaky_1): LeakyReLU(negative_slope=0.2, inplace=Tru

In [3]:
checkpoint0 = torch.load(
    './save_model/wowind_adamw0/checkpoint_1_0.000017.pth.tar')
model0.load_state_dict(checkpoint0['state_dict'])
### now you can evaluate it
model0.eval()

checkpoint1 = torch.load(
    './save_model/wowind_adamw1/checkpoint_1_0.000024.pth.tar')
model1.load_state_dict(checkpoint1['state_dict'])
### now you can evaluate it
model1.eval()

checkpoint2 = torch.load(
    './save_model/wowind_adamw2/checkpoint_1_0.000835.pth.tar')
model2.load_state_dict(checkpoint2['state_dict'])
### now you can evaluate it
model2.eval()

checkpoint3 = torch.load(
    './save_model/wowind_adamw3/checkpoint_1_0.000023.pth.tar')
model3.load_state_dict(checkpoint3['state_dict'])
### now you can evaluate it
model3.eval()

checkpoint4 = torch.load(
    './save_model/wowind_adamw4/checkpoint_1_0.001358.pth.tar')
model4.load_state_dict(checkpoint4['state_dict'])
### now you can evaluate it
model4.eval()


ED(
  (encoder): Encoder(
    (stage1): Sequential(
      (conv1_leaky_1): Conv2d(5, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (leaky_conv1_leaky_1): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (rnn1): CLSTM_cell(
      (conv): Sequential(
        (0): Conv2d(96, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
        (1): GroupNorm(8, 256, eps=1e-05, affine=True)
      )
    )
    (stage2): Sequential(
      (conv2_leaky_1): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (leaky_conv2_leaky_1): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (rnn2): CLSTM_cell(
      (conv): Sequential(
        (0): Conv2d(192, 512, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
        (1): GroupNorm(16, 512, eps=1e-05, affine=True)
      )
    )
    (stage3): Sequential(
      (conv3_leaky_1): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (leaky_conv3_leaky_1): LeakyReLU(negative_slope=0.2, inplace=Tru

In [38]:
labels = []

t = tqdm(trainLoader, leave=False, total=len(trainLoader))
for i, (idx, targetVar, inputVar, _, _) in enumerate(t):
    if i >= 0 + day:
        inputs = inputVar.to(device)  # B,S,C,H,W
        label = targetVar.to(device)  # B,S,C,H,W
        labels.append(np.array(label.cpu()))
    if i==13 + day : break
    if i==0 + day:
        inn = inputs




In [39]:
inputs = inn
inputs.shape

torch.Size([1, 7, 5, 73, 144])

In [40]:
# preds = []
# for i in range(14):
#     model0.load_state_dict(checkpoint0['state_dict'])
#     model0.eval()
#     pred0 = model0(inputs)
#     model1.load_state_dict(checkpoint1['state_dict'])
#     model1.eval()
#     pred1 = model1(inputs)
#     model2.load_state_dict(checkpoint2['state_dict'])
#     model2.eval()
#     pred2 = model2(inputs)
#     model3.load_state_dict(checkpoint3['state_dict'])
#     model3.eval()
#     pred3 = model3(inputs)
#     model4.load_state_dict(checkpoint4['state_dict'])
#     model4.eval()
#     pred4 = model4(inputs)
#     model5.load_state_dict(checkpoint5['state_dict'])
#     model5.eval()
#     pred5 = model5(inputs)
#     model6.load_state_dict(checkpoint6['state_dict'])
#     model6.eval()
#     pred6 = model6(inputs)

#     outputs = np.concatenate([pred0.detach().cpu(), pred1.detach().cpu(),
#                               pred2.detach().cpu(), pred3.detach().cpu(), pred4.detach().cpu(), pred5.detach().cpu(), pred6.detach().cpu()], axis=2)
#     preds.append(outputs)

#     inputs = inputs[:, 1:, ...]
#     inputs = torch.from_numpy(np.concatenate([inputs.detach().cpu(), outputs], axis=1))
#     inputs = inputs.cuda()


In [41]:
preds = []
for i in range(14):
    print(i)
    outputs = model0(inputs)
    preds.append(outputs.detach().cpu().numpy())
    inputs = inputs[:, 1:, ...]
    inputs = torch.from_numpy(np.concatenate(
        [inputs.detach().cpu(), outputs.detach().cpu()], axis=1))
    inputs = inputs.cuda()


0
1
2
3
4
5
6
7
8
9
10
11
12
13


In [42]:
preds

[array([[[[[ 1.27471343e-01,  1.75888449e-01,  1.90088987e-01, ...,
             1.94444716e-01,  1.75078332e-01,  1.35235980e-01],
           [ 1.59792379e-01,  2.29017392e-01,  2.45492429e-01, ...,
             2.45577320e-01,  2.32329786e-01,  1.77894652e-01],
           [ 1.82368934e-01,  2.52979785e-01,  2.77233183e-01, ...,
             2.64578223e-01,  2.46349439e-01,  1.97641119e-01],
           ...,
           [ 1.66147724e-01,  2.10224211e-01,  2.30907753e-01, ...,
             2.16914952e-01,  2.00512558e-01,  1.67314112e-01],
           [ 1.56144276e-01,  2.02388793e-01,  2.26537004e-01, ...,
             2.15300232e-01,  1.96582019e-01,  1.58885434e-01],
           [ 1.21437214e-01,  1.53791606e-01,  1.76470220e-01, ...,
             1.62543014e-01,  1.49437606e-01,  1.19032785e-01]],
 
          [[-7.87538651e-04,  1.04826381e-02,  7.04808440e-03, ...,
             2.69989856e-03,  3.93834431e-03,  9.01165232e-03],
           [-1.73772580e-03,  1.14180110e-02,  7.43648130

In [43]:
labels[0][:,:,0, ...]

array([[[[0.27165   , 0.27165   , 0.27165   , ..., 0.27165   ,
          0.27165   , 0.27165   ],
         [0.27195   , 0.27195   , 0.27198   , ..., 0.27191   ,
          0.27195   , 0.27196   ],
         [0.271     , 0.27087   , 0.27075   , ..., 0.27170002,
          0.27135   , 0.27118   ],
         ...,
         [0.2252    , 0.22473   , 0.22435   , ..., 0.22735001,
          0.22655   , 0.22575   ],
         [0.22473   , 0.22465   , 0.22452   , ..., 0.22496   ,
          0.22487   , 0.22477001],
         [0.22425   , 0.22425   , 0.22425   , ..., 0.22425   ,
          0.22425   , 0.22425   ]]]], dtype=float32)

In [44]:
os.mkdir(save_dir)
np.save(save_dir + "/inputs", inn.detach().cpu())
np.save(save_dir + "/labels", np.array(labels))
np.save(save_dir + "/preds", np.array(preds))
