In [1]:
import gc
import os
import random
import time
import torch
import datetime
import numpy as np
import pandas as pd

import torch.nn as nn
from torch.nn import AvgPool1d 
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import LSTM, Conv1d, GRU, TransformerEncoder, TransformerEncoderLayer, BatchNorm1d, LayerNorm
from torch.utils.data import Dataset, DataLoader
from torch.nn.parallel import DistributedDataParallel as DDP
from torchmetrics.regression import R2Score
import gc
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from tqdm import tqdm
from glob import glob
import hickle as hkl

import random
import math

from matplotlib.pyplot import plot

In [None]:
# The low-res data and 1/15th of the high-res data were preprocessed into .hkl files,
# each file holding 384*6 data points.

DATA_PATH = 'autodl-tmp/'

# Feature ("x") fragment files: low-res set first, then the high-res subset.
file_list = glob(f'{DATA_PATH}CDFrag/c_data_frag_x*hkl')
file_list_mix = glob(f'{DATA_PATH}CDFrag L/c_data_frag_x_*_1e-15.hkl')

In [2]:
# Sanity check: how many low-res feature files the glob matched.
len(file_list)

34296

In [3]:
# Sanity check: how many high-res feature files the glob matched.
len(file_list_mix)

96672

In [10]:
# Input variables that come as a profile: each one is expanded into 60 columns below.
seq_fea_list = [
    'state_t', 'state_q0001', 'state_q0002', 'state_q0003', 'state_u', 'state_v',
    'pbuf_ozone', 'pbuf_CH4', 'pbuf_N2O',
]
# Input variables that contribute a single column each.
num_fea_list = [
    'state_ps', 'pbuf_SOLIN', 'pbuf_LHFLX', 'pbuf_SHFLX', 'pbuf_TAUX', 'pbuf_TAUY',
    'pbuf_COSZRS', 'cam_in_ALDIF', 'cam_in_ALDIR', 'cam_in_ASDIF', 'cam_in_ASDIR',
    'cam_in_LWUP', 'cam_in_ICEFRAC', 'cam_in_LANDFRAC', 'cam_in_OCNFRAC', 'cam_in_SNOWHLAND',
]

# Target variables, split the same way (60-column profiles vs. single columns).
seq_y_list = ['ptend_t', 'ptend_q0001', 'ptend_q0002', 'ptend_q0003', 'ptend_u', 'ptend_v']
num_y_list = [
    'cam_out_NETSW', 'cam_out_FLWDS', 'cam_out_PRECSC', 'cam_out_PRECC',
    'cam_out_SOLS', 'cam_out_SOLL', 'cam_out_SOLSD', 'cam_out_SOLLD',
]

# Column names "<variable>_<k>" for k = 0..59, variable-major.
seq_fea_expand_list = [f'{name}_{level}' for name in seq_fea_list for level in range(60)]
seq_y_expand_list = [f'{name}_{level}' for name in seq_y_list for level in range(60)]

norm_dict = {}
TARGET_COLS = seq_y_expand_list + num_y_list
FEAT_COLS = seq_fea_expand_list + num_fea_list

In [11]:
# Sanity check: 9 profile variables * 60 columns + 16 single-column variables.
len(FEAT_COLS)

556

In [None]:
#Define a numpy dataset for loading the validation data
class NumpyDataset(Dataset):
    """Map-style dataset backed by two row-aligned NumPy arrays (features, labels)."""

    def __init__(self, x, y):
        """Keep references to `x` and `y`; both must have the same length along axis 0."""
        assert x.shape[0] == y.shape[0], "Features and labels must have the same number of samples"
        self.x, self.y = x, y

    def __len__(self):
        """Number of samples, i.e. rows of `x`."""
        return self.x.shape[0]

    def __getitem__(self, index):
        """Return sample `index` as a (features, labels) pair of float32 tensors on `device`."""
        # Conversion is lazy: only the requested rows are turned into tensors.
        features = torch.from_numpy(self.x[index]).float()
        labels = torch.from_numpy(self.y[index]).float()
        return features.to(device), labels.to(device)

# BATCH_SIZE is also assigned in the training-configuration cell further down, but this
# cell runs earlier on a fresh kernel (Restart & Run All), so it must exist here already.
# Keep both assignments in sync.
BATCH_SIZE = 384*6*3

# Held-out inputs to predict on, and the validation split (features + labels).
x_test = hkl.load(DATA_PATH+'x_test_v1_1e-15.hkl')
x_valid = hkl.load(DATA_PATH+'c_data_x_8_1_v1_1e-15.hkl')
y_valid = hkl.load(DATA_PATH+'c_data_y_8_1_v1_1e-15.hkl')

# Validation order does not matter for the metrics, so no shuffling.
val_dataset = NumpyDataset(x_valid, y_valid)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
#Define some helper functions
def format_time(elapsed):
    """Format a duration given in seconds as an `h:mm:ss` string (rounded to whole seconds)."""
    whole_seconds = int(round(elapsed))
    duration = datetime.timedelta(seconds=whole_seconds)
    return str(duration)

# Definitions of the 3 Best Models

In [None]:
# Output columns 145 and 146 are forced to a constant 0 output.
zeroout_index = torch.tensor(list(range(145, 147)))

In [None]:
class FFNN_LSTM_888_AVG_ATT(nn.Module):
    """Per-level encoder + 6-layer bidirectional LSTM, with attention pooling for the scalar targets.

    The flat input row is split into 60-level profiles and single-value features; the
    single values are tiled over the 60 levels so each level sees the full feature set.
    Profile targets are predicted per level; scalar targets are predicted from a
    softmax-weighted sum over the levels. A learned per-column scale and offset are
    applied to the concatenated output.
    """

    def __init__(self, input_size, output_size):
        """`input_size` is stored for reference; `output_size` is accepted but not used."""
        super().__init__()

        self.encode_dim = 256
        self.hidden_dim = 256
        self.iter_dim = 800

        n_level_inputs = len(seq_fea_list) + len(num_fea_list)
        n_outputs = len(seq_y_list)*60 + len(num_y_list)

        self.LSTM_1 = LSTM(self.encode_dim, self.hidden_dim, 6,
                           batch_first=True, dropout=0.01, bidirectional=True)
        self.input_size = input_size
        self.Linear_1 = nn.Linear(n_level_inputs, self.encode_dim)
        # Input is [LSTM out, its level-mean, its 3-level average] (2*hidden each) plus the encoding.
        self.Linear_2 = nn.Linear(6*self.hidden_dim + self.encode_dim, self.iter_dim)
        self.Linear_3 = nn.Linear(self.iter_dim, len(seq_y_list))
        self.Linear_3_0 = nn.Linear(self.iter_dim, 1)

        self.Linear_4_0 = nn.Linear(self.iter_dim, self.iter_dim*2)

        self.Linear_4 = nn.Linear(self.iter_dim*2, len(num_y_list))
        # Only the `.weight` tensors (shape (1, n_outputs)) of these two layers are used in forward().
        self.bias = nn.Linear(n_outputs, 1)
        self.weight = nn.Linear(n_outputs, 1)
        self.avg_pool_1 = AvgPool1d(kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Map a (batch, n_features) input to a (batch, 60*len(seq_y_list)+len(num_y_list)) output."""
        n_seq = len(seq_fea_list)
        n_num = len(num_fea_list)
        split_at = 60*n_seq

        # (batch, 60*n_seq) -> (batch, n_seq, 60) -> (batch, 60, n_seq)
        profiles = x[:, :split_at].reshape(-1, n_seq, 60).transpose(1, 2)
        # (batch, n_num) -> (batch, 60, n_num): the same values repeated at every level
        tiled_scalars = x[:, split_at:].reshape(-1, 1, n_num).repeat(1, 60, 1)

        encoded = F.elu(self.Linear_1(torch.cat((profiles, tiled_scalars), dim=-1)/5))

        lstm_out, _ = self.LSTM_1(encoded/5)

        # Mean over levels, broadcast back to all 60 levels.
        lstm_mean = lstm_out.mean(dim=1, keepdim=True).repeat(1, 60, 1)
        # 3-level moving average along the level axis (pooling acts on the last dim).
        lstm_smooth = self.avg_pool_1(lstm_out.transpose(1, 2)).transpose(1, 2)

        hidden = F.elu(self.Linear_2(
            torch.cat((lstm_out, lstm_mean, encoded, lstm_smooth), dim=-1)/5))

        # One attention weight per level, normalised over the level axis.
        att_weight = F.softmax(self.Linear_3_0(hidden) - 10, dim=1)

        # (batch, 60, n_seq_y) -> (batch, n_seq_y, 60) -> flat, variable-major like the targets.
        seq_out = self.Linear_3(hidden).transpose(1, 2).reshape(-1, 60*len(seq_y_list))

        pooled = (att_weight*hidden).sum(dim=1)
        num_out = self.Linear_4(F.elu(self.Linear_4_0(pooled)))

        combined = torch.cat((seq_out, num_out), dim=-1)
        return self.weight.weight*combined/3 + self.bias.weight/3

In [None]:
class FFNN_LSTM_6_AVG(nn.Module):
    """Per-level encoder + 6-layer bidirectional LSTM, with mean pooling for the scalar targets.

    The flat input row is split into 60-level profiles and single-value features; the
    single values are tiled over the 60 levels. Profile targets are predicted per level;
    scalar targets are predicted from the mean over levels. A learned per-column scale
    and offset are applied, and the columns listed in the module-level `zeroout_index`
    are multiplied by zero.
    """

    def __init__(self, input_size, output_size):
        """`input_size` is stored for reference; `output_size` is accepted but not used."""
        super().__init__()

        self.encode_dim = 300
        self.hidden_dim = 280
        self.iter_dim = 800

        n_level_inputs = len(seq_fea_list) + len(num_fea_list)
        n_outputs = len(seq_y_list)*60 + len(num_y_list)

        self.LSTM_1 = LSTM(self.encode_dim, self.hidden_dim, 6,
                           batch_first=True, dropout=0.01, bidirectional=True)
        self.input_size = input_size
        self.Linear_1 = nn.Linear(n_level_inputs, self.encode_dim)
        # Input is [LSTM out, its level-mean, its 3-level average] (2*hidden each) plus the encoding.
        self.Linear_2 = nn.Linear(6*self.hidden_dim + self.encode_dim, self.iter_dim)
        self.Linear_3 = nn.Linear(self.iter_dim, len(seq_y_list))
        self.Linear_4_0 = nn.Linear(self.iter_dim, self.iter_dim*2)

        self.Linear_4 = nn.Linear(self.iter_dim*2, len(num_y_list))
        # Only the `.weight` tensors (shape (1, n_outputs)) of these two layers are used in forward().
        self.bias = nn.Linear(n_outputs, 1)
        self.weight = nn.Linear(n_outputs, 1)
        self.avg_pool_1 = AvgPool1d(kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Map a (batch, n_features) input to a (batch, 60*len(seq_y_list)+len(num_y_list)) output."""
        n_seq = len(seq_fea_list)
        n_num = len(num_fea_list)
        split_at = 60*n_seq

        # (batch, 60*n_seq) -> (batch, n_seq, 60) -> (batch, 60, n_seq)
        profiles = x[:, :split_at].reshape(-1, n_seq, 60).transpose(1, 2)
        # (batch, n_num) -> (batch, 60, n_num): the same values repeated at every level
        tiled_scalars = x[:, split_at:].reshape(-1, 1, n_num).repeat(1, 60, 1)

        encoded = F.elu(self.Linear_1(torch.cat((profiles, tiled_scalars), dim=-1)/5))

        lstm_out, _ = self.LSTM_1(encoded/5)

        # Mean over levels, broadcast back to all 60 levels.
        lstm_mean = lstm_out.mean(dim=1, keepdim=True).repeat(1, 60, 1)
        # 3-level moving average along the level axis (pooling acts on the last dim).
        lstm_smooth = self.avg_pool_1(lstm_out.transpose(1, 2)).transpose(1, 2)

        hidden = F.elu(self.Linear_2(
            torch.cat((lstm_out, lstm_mean, encoded, lstm_smooth), dim=-1)/5))

        # (batch, 60, n_seq_y) -> (batch, n_seq_y, 60) -> flat, variable-major like the targets.
        seq_out = self.Linear_3(hidden).transpose(1, 2).reshape(-1, 60*len(seq_y_list))

        # Scalar targets come from the plain mean over the 60 levels.
        num_out = self.Linear_4(F.elu(self.Linear_4_0(hidden.mean(dim=1))))

        combined = torch.cat((seq_out, num_out), dim=-1)
        output = self.weight.weight*combined/3 + self.bias.weight/3

        # Zero the columns listed in the module-level `zeroout_index`.
        output[:, zeroout_index] = output[:, zeroout_index]*0.0

        return output

In [None]:
class FFNN_LSTM_749_AVG_ATT(nn.Module):
    """Wider variant of the attention-pooled LSTM model (hidden 320, iter 1024, dropout 0.05).

    The flat input row is split into 60-level profiles and single-value features; the
    single values are tiled over the 60 levels. Profile targets are predicted per level;
    scalar targets are predicted from a softmax-weighted sum over the levels. A learned
    per-column scale and offset are applied to the concatenated output.
    """

    def __init__(self, input_size, output_size):
        """`input_size` is stored for reference; `output_size` is accepted but not used."""
        super().__init__()

        self.encode_dim = 256
        self.hidden_dim = 320
        self.iter_dim = 1024

        n_level_inputs = len(seq_fea_list) + len(num_fea_list)
        n_outputs = len(seq_y_list)*60 + len(num_y_list)

        self.LSTM_1 = LSTM(self.encode_dim, self.hidden_dim, 6,
                           batch_first=True, dropout=0.05, bidirectional=True)
        self.input_size = input_size
        self.Linear_1 = nn.Linear(n_level_inputs, self.encode_dim)
        # Input is [LSTM out, its level-mean, its 3-level average] (2*hidden each) plus the encoding.
        self.Linear_2 = nn.Linear(6*self.hidden_dim + self.encode_dim, self.iter_dim)
        self.Linear_3 = nn.Linear(self.iter_dim, len(seq_y_list))
        self.Linear_3_0 = nn.Linear(self.iter_dim, 1)

        self.Linear_4_0 = nn.Linear(self.iter_dim, self.iter_dim*2)

        self.Linear_4 = nn.Linear(self.iter_dim*2, len(num_y_list))
        # Only the `.weight` tensors (shape (1, n_outputs)) of these two layers are used in forward().
        self.bias = nn.Linear(n_outputs, 1)
        self.weight = nn.Linear(n_outputs, 1)
        self.avg_pool_1 = AvgPool1d(kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Map a (batch, n_features) input to a (batch, 60*len(seq_y_list)+len(num_y_list)) output."""
        n_seq = len(seq_fea_list)
        n_num = len(num_fea_list)
        split_at = 60*n_seq

        # (batch, 60*n_seq) -> (batch, n_seq, 60) -> (batch, 60, n_seq)
        profiles = x[:, :split_at].reshape(-1, n_seq, 60).transpose(1, 2)
        # (batch, n_num) -> (batch, 60, n_num): the same values repeated at every level
        tiled_scalars = x[:, split_at:].reshape(-1, 1, n_num).repeat(1, 60, 1)

        encoded = F.elu(self.Linear_1(torch.cat((profiles, tiled_scalars), dim=-1)/5))

        lstm_out, _ = self.LSTM_1(encoded/5)

        # Mean over levels, broadcast back to all 60 levels.
        lstm_mean = lstm_out.mean(dim=1, keepdim=True).repeat(1, 60, 1)
        # 3-level moving average along the level axis (pooling acts on the last dim).
        lstm_smooth = self.avg_pool_1(lstm_out.transpose(1, 2)).transpose(1, 2)

        hidden = F.elu(self.Linear_2(
            torch.cat((lstm_out, lstm_mean, encoded, lstm_smooth), dim=-1)/5))

        # One attention weight per level, normalised over the level axis.
        att_weight = F.softmax(self.Linear_3_0(hidden) - 10, dim=1)

        # (batch, 60, n_seq_y) -> (batch, n_seq_y, 60) -> flat, variable-major like the targets.
        seq_out = self.Linear_3(hidden).transpose(1, 2).reshape(-1, 60*len(seq_y_list))

        pooled = (att_weight*hidden).sum(dim=1)
        num_out = self.Linear_4(F.elu(self.Linear_4_0(pooled)))

        combined = torch.cat((seq_out, num_out), dim=-1)
        return self.weight.weight*combined/3 + self.bias.weight/3

# Train Models

In [None]:
# Training configuration.
BATCH_SIZE = 3 * (384*6)            # rows per mini-batch: three files' worth of 384*6 rows
MIN_STD = 1e-10
SCHEDULER_PATIENCE = 6              # `patience` passed to ReduceLROnPlateau
SCHEDULER_FACTOR = 10**(-0.2)       # `factor` passed to ReduceLROnPlateau
EPOCHS = 70
PATIENCE = 6                        # epochs without validation improvement before early stop
PRINT_FREQ = 300                    # training batches between progress log lines
BIN_NUM = 10

In [14]:
# Wall-clock reference for every elapsed-time message printed from here on.
ts = time.time()

gc.collect()

print("Time after processing data:", format_time(time.time()-ts), flush=True)    

# Model dimensions are taken from the validation arrays.
input_size = x_valid.shape[1]
output_size = y_valid.shape[1]
hidden_size = input_size + output_size

model_single = FFNN_LSTM_6_AVG(input_size, output_size)
device_ids = list(range(torch.cuda.device_count()))
# DataParallel is left on its default device selection.
model = torch.nn.DataParallel(model_single)

# NOTE: `zeroout_index` deliberately stays on the CPU. The previous bare
# `zeroout_index.to(device)` call was a no-op (Tensor.to returns a new tensor and the
# result was discarded), and actually moving the tensor would break the training
# loop, which uses it to index NumPy arrays.
model.to(device)

criterion = nn.MSELoss()
criterion_l1 = nn.L1Loss()
criterion_huber = nn.HuberLoss(delta=0.5)

# Exponential moving average of the model weights.
from ema_pytorch import EMA
ema = EMA(
    model,
    beta = 0.99,              # exponential moving average factor
    update_after_step = 50,    # only after this number of .update() calls will it start updating
    update_every = 8,          # how often to actually update, to save on compute (updates every 8th .update() call)
)

optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.0002)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=SCHEDULER_FACTOR, patience=SCHEDULER_PATIENCE)

print("Time after all preparations:", format_time(time.time()-ts), flush=True)

Time after processing data: 0:00:00
Time after all preparations: 0:00:01


In [None]:
time_gap = 0.00000005  # seconds slept at the top of every batch iteration below

best_val_loss = float('inf')
best_model_state = None
patience_count = 0
r2score = R2Score(num_outputs=y_valid.shape[1]).to(device)

rows_per_file = 384*6                  # rows stored in every fragment file
batch_num = 105                        # files staged in memory before training on them
group_rows = batch_num*rows_per_file

# NumPy copy of the zero-out columns, used to index the NumPy label buffers
# regardless of which device `zeroout_index` lives on.
zero_cols = zeroout_index.cpu().numpy()

# Staging buffers. Every trained-on group overwrites all rows, so they are allocated once.
x_train = np.zeros((group_rows,input_size),dtype=np.float32)
y_train = np.zeros((group_rows,output_size),dtype=np.float32)
x_train_mix = np.zeros((group_rows,input_size),dtype=np.float32)
y_train_mix = np.zeros((group_rows,output_size),dtype=np.float32)


def _target_path(x_path):
    """Return the label file paired with a feature file (`c_data_frag_x...` -> `c_data_frag_y...`)."""
    folder, name = os.path.split(x_path)
    return os.path.join(folder, name.replace('c_data_frag_x', 'c_data_frag_y', 1))


def _evaluate(net):
    """Run `net` over `val_loader` and score it.

    Returns (avg_val_loss, r2, r2_broken, r2_broken_names, y_true, all_outputs):
    the mean per-batch MSE, the R2 averaged over all target columns (columns with
    R2 <= 1e-6 contribute 0 and are listed in `r2_broken` / `r2_broken_names`),
    and the concatenated labels and predictions.
    """
    loss_sum = 0
    label_chunks, pred_chunks = [], []
    with torch.no_grad():
        for inputs, labels in tqdm(val_loader):
            time.sleep(time_gap)

            outputs = net(inputs)
            loss_sum += criterion(outputs, labels).item()
            label_chunks.append(labels)
            pred_chunks.append(outputs)
    # Concatenate once instead of growing a tensor inside the loop.
    y_true = torch.cat(label_chunks, 0)
    all_outputs = torch.cat(pred_chunks, 0)

    n_targets = all_outputs.shape[1]
    r2 = 0
    r2_broken = []
    r2_broken_names = []
    for col in range(n_targets):
        r2_col = r2score(all_outputs[:, col], y_true[:, col])
        if r2_col > 1e-6:
            r2 += r2_col
        else:
            r2_broken.append(col)
            # `col` indexes an output column, so the name comes from TARGET_COLS.
            r2_broken_names.append(TARGET_COLS[col])
    r2 /= n_targets

    return loss_sum / len(val_loader), r2, r2_broken, r2_broken_names, y_true, all_outputs


# NOTE(review): log lines print `epoch+1` while checkpoint file names use `epoch`,
# so the first pass is logged as "Epoch: 2" but saved with suffix 1. Left unchanged
# to keep existing logs comparable; confirm which numbering is intended.
for epoch in range(1,400):
    gc.collect()
    torch.cuda.empty_cache()

    random.shuffle(file_list)
    random.shuffle(file_list_mix)

    if epoch > 9:
        # From epoch 10 on the learning rate is pinned to 1e-4, overriding whatever
        # ReduceLROnPlateau set after the previous epoch.
        for g in optimizer.param_groups:
            #g['lr'] = max(g['lr']*0.8,1e-4)
            g['lr']  = 1e-4

    print("")
    model.train()
    ema.train()
    total_loss = 0
    steps = 0
    batch_idx = -1

    # Low-res and high-res files are consumed pairwise; only complete groups of
    # `batch_num` files are ever trained on, so the incomplete tail is not loaded.
    n_pairs = min(len(file_list), len(file_list_mix))
    n_usable = (n_pairs // batch_num) * batch_num

    for file_i in range(n_usable):
        slot = file_i % batch_num
        rows = slice(slot*rows_per_file, (slot + 1)*rows_per_file)

        x_train[rows,:] = hkl.load(file_list[file_i])
        y_train[rows,:] = hkl.load(_target_path(file_list[file_i]))

        x_train_mix[rows,:] = hkl.load(file_list_mix[file_i])
        y_train_mix[rows,:] = hkl.load(_target_path(file_list_mix[file_i]))

        if (file_i+1)%batch_num == 0:
            gc.collect()

            # The zero-out columns are trained against 0.
            y_train[:,zero_cols] = y_train[:,zero_cols]*0.0
            y_train_mix[:,zero_cols] = y_train_mix[:,zero_cols]*0.0

            # One permutation is shared by the low-res and high-res buffers.
            random_index = np.random.permutation(x_train.shape[0])
            for i1 in range(0, x_train.shape[0], BATCH_SIZE):
                time.sleep(time_gap)
                i2 = min(i1 + BATCH_SIZE, x_train.shape[0])
                picked = random_index[i1:i2]

                inputs = torch.from_numpy(x_train[picked, :]).to(device)
                batch_idx = batch_idx + 1

                outputs = model(inputs)
                outputs_y = torch.from_numpy(y_train[picked, :]).to(device)

                inputs_mix = torch.from_numpy(x_train_mix[picked, :]).to(device)

                outputs_mix = model(inputs_mix)
                outputs_y_mix = torch.from_numpy(y_train_mix[picked, :]).to(device)

                # Loss: L1 + 0.08*MSE on low-res, plus 0.3*L1 + 0.0005*MSE on high-res.
                loss = 0.08*criterion(outputs,outputs_y)+criterion_l1(outputs,outputs_y)
                loss += 0.3*criterion_l1(outputs_mix,outputs_y_mix)
                loss += 0.0005*criterion(outputs_mix,outputs_y_mix)

                loss.backward()

                optimizer.step()
                optimizer.zero_grad()

                ema.update()

                total_loss += loss.item()
                steps += 1

                if (batch_idx + 1) % PRINT_FREQ == 0:
                    current_lr = optimizer.param_groups[0]["lr"]
                    elapsed_time = format_time(time.time() - ts)
                    print(f'  Epoch: {epoch+1}',
                          f'  Train Loss: {total_loss / steps:.4f}',
                          f'  LR: {current_lr:.1e}',
                          f'  Time: {elapsed_time}', flush=True)
                    with open('log0.txt', 'a') as file:
                        file.write(f'  Epoch: {epoch+1}  Train Loss: {total_loss / steps:.4f}  LR: {current_lr:.1e} Time: {elapsed_time}' + '\n')
                    # The running average restarts after every log line.
                    total_loss = 0
                    steps = 0

    # ---- Validation of the raw model ----
    model.eval()
    avg_val_loss, r2, r2_broken, r2_broken_names, y_true, all_outputs = _evaluate(model)
    print(f'\nEpoch: {epoch+1}  Val Loss: {avg_val_loss:.4f}  R2 score: {r2:.4f}')
    print(f'{len(r2_broken)} targets were excluded during evaluation of R2 score.')
    with open('log0.txt', 'a') as file:
        file.write(f'\nEpoch: {epoch+1}  Val Loss: {avg_val_loss:.4f}  R2 score: {r2:.4f}' + '\n')

    scheduler.step(round(avg_val_loss*2,4))

    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        # state_dict() returns references to the live parameters, so take a real
        # copy on the CPU; otherwise the "best" state keeps changing with training.
        best_model_state = {name: tensor.detach().to('cpu', copy=True)
                            for name, tensor in model.state_dict().items()}
        patience_count = 0
        print("Validation loss decreased, saving new best model and resetting patience counter.")
    else:
        patience_count += 1
        print(f"No improvement in validation loss for {patience_count} epochs.")

    if patience_count >= PATIENCE:
        print("Stopping early due to no improvement in validation loss.")
        break

    # ---- Validation of the EMA weights ----
    ema.eval()
    avg_val_loss, r2, r2_broken, r2_broken_names, y_true, all_outputs = _evaluate(ema)
    valid_results = all_outputs.to('cpu').numpy()

    print(f'\nEpoch: {epoch+1}  Val Loss: {avg_val_loss:.4f}  R2 score: {r2:.4f}')
    with open('log0.txt', 'a') as file:
        file.write(f'\nEpoch: {epoch+1}  Val Loss: {avg_val_loss:.4f}  R2 score: {r2:.4f}' + '\n')
    print(f'{len(r2_broken)} targets were excluded during evaluation of R2 score.')

    # ---- Test-set predictions with the EMA weights, plus checkpoints (every epoch) ----
    predt = np.zeros([x_test.shape[0], output_size], dtype=np.float32)

    with torch.no_grad():
        for i1 in tqdm(range(0, x_test.shape[0], BATCH_SIZE)):
            time.sleep(time_gap)

            i2 = min(i1 + BATCH_SIZE, x_test.shape[0])
            inputs = torch.from_numpy(x_test[i1:i2, :]).float().to(device)
            predt[i1:i2, :] = ema(inputs).cpu().numpy()

    # NOTE(review): these file names say FFNN_LSTM_999_AVG_ATT although the setup cell
    # builds FFNN_LSTM_6_AVG; the names are kept as-is so existing outputs still match.
    hkl.dump(predt, DATA_PATH+'FFNN_LSTM_999_AVG_ATT_'+str(epoch)+'_mix_1e-15.hkl', compression='gzip')
    torch.save(model.state_dict(), DATA_PATH+'FFNN_LSTM_999_AVG_ATT_1e-15_'+str(epoch)+'.pt')
    torch.save(ema.state_dict(), DATA_PATH+'FFNN_LSTM_999_AVG_ATT_1e-15_ema_'+str(epoch)+'.pt')


  Epoch: 2   Train Loss: 0.5134   LR: 1.0e-03   Time: 0:02:45
  Epoch: 2   Train Loss: 0.5003   LR: 1.0e-03   Time: 0:05:16
  Epoch: 2   Train Loss: 0.4746   LR: 1.0e-03   Time: 0:07:44
  Epoch: 2   Train Loss: 0.4365   LR: 1.0e-03   Time: 0:10:15
  Epoch: 2   Train Loss: 0.4026   LR: 1.0e-03   Time: 0:12:42
  Epoch: 2   Train Loss: 0.3955   LR: 1.0e-03   Time: 0:15:12
  Epoch: 2   Train Loss: 0.3608   LR: 1.0e-03   Time: 0:17:40
  Epoch: 2   Train Loss: 0.3438   LR: 1.0e-03   Time: 0:20:11
  Epoch: 2   Train Loss: 0.3138   LR: 1.0e-03   Time: 0:22:58
  Epoch: 2   Train Loss: 0.3123   LR: 1.0e-03   Time: 0:25:42
  Epoch: 2   Train Loss: 0.3019   LR: 1.0e-03   Time: 0:28:28
  Epoch: 2   Train Loss: 0.3076   LR: 1.0e-03   Time: 0:31:16


In [None]:
# Write the EMA weights and the raw model weights as they stand after the last executed epoch.
torch.save(ema.state_dict(), f'{DATA_PATH}FFNN_LSTM_6_AVG_1e-15_ema_{epoch}.pt')
torch.save(model.state_dict(), f'{DATA_PATH}FFNN_LSTM_6_AVG_1e-15_{epoch}.pt')