In [143]:
import torch
import torch.nn as nn
import numpy as np
from torchvision import transforms
from PIL import Image
import pandas as pd
from copy import deepcopy

# Device string shared by the cells below: "cuda" when PyTorch reports an
# available CUDA device, otherwise "cpu".
device = "cuda" if torch.cuda.is_available() else "cpu"

In [144]:
class CNN(nn.Module):
    """Single-output predictor made of one ``nn.Conv2d`` layer.

    The convolution reads ``num_feature`` input channels and produces one
    output channel, using a ``(1, time_lag)`` kernel and a
    ``(image_height, 1)`` stride.

    Args:
        num_feature: number of input channels of the convolution.
        image_height: vertical stride of the convolution.
        time_lag: width of the convolution kernel.
        device: device on which the convolution parameters are created.
    """

    def __init__(self, num_feature, image_height, time_lag, device):
        super().__init__()

        self.num_feature = num_feature
        self.image_height = image_height
        self.time_lag = time_lag

        self.conv1 = nn.Conv2d(
            self.num_feature,
            1,
            kernel_size=(1, self.time_lag),
            stride=(self.image_height, 1),
            device=device,
        )

    def forward(self, x):
        """Run the convolution on ``x`` and return the result flattened to 1-D."""
        # NOTE(review): the original comment described the flattened size as
        # "(target_num)" -- confirm against the input shapes used by callers.
        return self.conv1(x).flatten()

In [145]:
class RBFCNN(nn.Module):
    """Collection of ``num_feature`` independent ``CNN`` sub-networks.

    Every sub-network receives the same input; sub-network ``i`` produces
    row ``i`` of the stacked output.

    Args:
        num_feature: number of sub-networks, also forwarded to each ``CNN``.
        image_height: forwarded to each ``CNN``.
        time_lag: forwarded to each ``CNN``.
        device: forwarded to each ``CNN``.
    """

    def __init__(self, num_feature, image_height, time_lag, device):
        super().__init__()

        self.num_feature = num_feature
        self.image_height = image_height
        self.time_lag = time_lag

        sub_networks = []
        for _ in range(self.num_feature):
            sub_networks.append(
                CNN(self.num_feature, self.image_height, time_lag, device))
        self.networks = nn.ModuleList(sub_networks)

    def forward(self, X):
        """Return the flattened outputs of all sub-networks stacked along dim 0."""
        outputs = []
        for idx in range(self.num_feature):
            outputs.append(self.networks[idx](X).flatten())
        return torch.stack(outputs)

    def GC(self, threshold=True):
        """Build a ``(num_feature, num_feature)`` matrix from the conv weights.

        Row ``i`` holds the norms of sub-network ``i``'s ``conv1.weight``
        taken over its last dimension.

        Args:
            threshold: when truthy, return an int tensor that is 1 where the
                norm is positive and 0 elsewhere; otherwise return the norms.
        """
        per_network = []
        for net in self.networks:
            per_network.append(torch.norm(net.conv1.weight, dim=-1))

        matrix = torch.stack(per_network)
        matrix = matrix.view((self.num_feature, self.num_feature))

        if not threshold:
            return matrix
        return (matrix > 0).int()

In [146]:
def image_tensor(path, data_number):
    """Load ``data_number`` images and pack their first channels into one tensor.

    Args:
        path: format string; ``path.format(i)`` is opened for every ``i`` in
            ``range(data_number)``.
        data_number: number of images to load; must be at least 1.

    Returns:
        Tensor of shape ``(1, data_number, H, W)`` placed on the notebook-level
        ``device``, where ``(H, W)`` are the last two dimensions of the
        converted images.

    Raises:
        ValueError: if ``data_number`` is smaller than 1.
    """
    if data_number < 1:
        raise ValueError("data_number must be at least 1")

    # The converter is stateless, so build it once instead of per image.
    convert_tensor = transforms.ToTensor()

    x_list = []
    for i in range(data_number):
        # Close each file as soon as its pixels have been converted.
        with Image.open(path.format(i)) as img:
            tensor_image = convert_tensor(img)
        # Keep only the first channel of the converted image.
        x_list.append(tensor_image[0])

    # Add the leading batch dimension without hard-coding the channel count,
    # so any data_number works (the original assumed exactly 10 images).
    data = torch.stack(x_list).unsqueeze(0).to(device)

    return data

# Load the ten images ./data/0_value.jpg ... ./data/9_value.jpg as one tensor
# on the selected device and report its dtype.
path = './data/{}_value.jpg'
data = image_tensor(path, 10).to(device=device)
print(data.dtype)

torch.float32


In [154]:
def regularize(network, lam):
    """Return ``lam`` times the sum of the dim-0 norms of ``network.conv1.weight``.

    Author's note (translated): not applied yet.

    NOTE(review): for the ``CNN`` defined in this notebook ``conv1`` has a
    single output channel, so the dim-0 norm is an element-wise absolute
    value -- confirm this is the intended grouping.
    """
    weight = network.conv1.weight
    norms = torch.norm(weight, dim=0)
    return lam * norms.sum()

def restore_parameters(model, best_model):
    """Copy the parameter values of ``best_model`` into ``model`` in place.

    Parameters are paired in the order returned by ``parameters()``.

    The values are copied rather than re-pointed: assigning a parameter to
    ``.data`` makes both models share one storage, so any later in-place
    update of ``model`` would silently modify ``best_model`` as well.

    Args:
        model: module whose parameters are overwritten.
        best_model: module providing the values to restore.
    """
    with torch.no_grad():
        for params, best_params in zip(model.parameters(), best_model.parameters()):
            params.copy_(best_params)
        
def prox_update(network, lam, lr):
    """Shrink ``network.conv1.weight`` in place by soft-thresholding its dim-0 norms.

    Every slice along dim 0 is rescaled so that its norm is reduced by
    ``lam * lr``; slices whose norm is not larger than that become zero.

    Args:
        network: object exposing ``conv1.weight``.
        lam: penalty weight.
        lr: step size of the preceding gradient step.
    """
    W = network.conv1.weight
    threshold = lam * lr

    # A non-positive threshold means no shrinkage. Returning early also avoids
    # the 0 / 0 = NaN the formula below produces for weights that are exactly zero.
    if threshold <= 0:
        return

    # The update is not part of the loss, so keep it out of the autograd graph.
    with torch.no_grad():
        norm = torch.norm(W, dim=0, keepdim=True)
        W.data = ((W / torch.clamp(norm, min=threshold))
                  * torch.clamp(norm - threshold, min=0.0))

def ridge_regularize(network, lam):
    """Return ``lam`` times the sum of squared entries of ``network.conv1.weight``.

    The sum is accumulated in float64, so the result is a float64 tensor.

    Author's note (translated): applied -> is this right?
    """
    squared = network.conv1.weight ** 2
    total = torch.sum(squared, dtype=torch.float64)
    return lam * total

def train_model_ista(rbfcnn, X, Y, lr, max_iter, lam=0.0, lam_ridge=0.0,
                     lookback=5, check_every=50, verbose=1):
    '''Train ``rbfcnn`` with proximal gradient descent (ISTA).

    Every iteration takes a plain gradient step on the smooth objective
    (summed MSE of all sub-networks plus the ridge penalty) and, when
    ``lam > 0``, applies ``prox_update`` to every sub-network. Every
    ``check_every`` iterations the full objective is recorded and used for
    early stopping; the best checkpoint is restored before returning.

    Args:
        rbfcnn: model exposing ``networks`` (indexable sub-networks) and ``GC()``.
        X: input tensor; ``X.size()[-3]`` is used as the number of features.
        Y: targets; ``Y[i]`` is compared with the output of ``rbfcnn.networks[i]``.
        lr: step size of the gradient step (also passed to ``prox_update``).
        max_iter: maximum number of iterations.
        lam: weight passed to ``prox_update`` and ``regularize``.
        lam_ridge: weight passed to ``ridge_regularize``.
        lookback: number of checks without improvement after which training stops.
        check_every: number of iterations between two progress checks.
        verbose: progress is printed when this is greater than 0.

    Returns:
        List with one detached 0-dim loss tensor per progress check.
    '''
    # Move the model that was passed in (not a notebook-level global) onto the
    # device of the inputs.
    rbfcnn.to(X.device)
    loss_fn = nn.MSELoss(reduction='mean')
    train_loss_list = []
    feature_num = X.size()[-3]

    def smooth_objective():
        # Differentiable part of the objective: prediction error plus ridge.
        pred = [rbfcnn.networks[i](X) for i in range(feature_num)]
        loss = sum(loss_fn(pred[i], Y[i]) for i in range(feature_num))
        ridge = sum(ridge_regularize(net, lam_ridge) for net in rbfcnn.networks)
        return loss + ridge

    # For early stopping.
    best_it = None
    best_loss = np.inf
    best_model = None

    # Calculate smooth error.
    smooth = smooth_objective()

    for it in range(max_iter):
        # Take gradient step.
        smooth.backward()
        for param in rbfcnn.parameters():
            if param.grad is not None:
                param.data -= lr * param.grad

        # Take prox step.
        if lam > 0:
            for net in rbfcnn.networks:
                prox_update(net, lam, lr)

        rbfcnn.zero_grad()

        # Calculate loss for next iteration.
        smooth = smooth_objective()

        # Check progress.
        if (it + 1) % check_every == 0:
            # Add nonsmooth penalty. Detach so the stored values do not keep
            # the autograd graph of this iteration alive.
            nonsmooth = sum(regularize(net, lam) for net in rbfcnn.networks)
            mean_loss = ((smooth + nonsmooth) / feature_num).detach()
            train_loss_list.append(mean_loss)

            if verbose > 0:
                print(('-' * 10 + 'Iter = %d' + '-' * 10) % (it + 1))
                print('Loss = %f' % mean_loss)
                print('Variable usage = %.2f%%'
                      % (100 * torch.mean(rbfcnn.GC().float())))

            # Check for early stopping.
            if mean_loss < best_loss:
                best_loss = mean_loss
                best_it = it
                best_model = deepcopy(rbfcnn)

            # best_it stays None while no finite improvement has been seen
            # (e.g. a NaN loss), so guard the subtraction.
            elif best_it is not None and (it - best_it) >= lookback * check_every:
                if verbose:
                    print('Stopping early')
                break

    # Restore best model; there is none when no progress check was reached.
    if best_model is not None:
        restore_parameters(rbfcnn, best_model)

    return train_loss_list

In [162]:
def target_data(dataflame, time_lag, device):
    """Build the float32 target tensor from a DataFrame.

    The first ``time_lag - 1`` rows are dropped and the remaining values are
    transposed, so each row of the result corresponds to one DataFrame column.

    Args:
        dataflame: source ``pandas.DataFrame``.
        time_lag: rows from position ``time_lag - 1`` onward are kept.
        device: device on which the tensor is created.
    """
    trimmed = dataflame.iloc[time_lag - 1:]
    values = trimmed.values.transpose()
    return torch.tensor(values, dtype=torch.float32, device=device)
# Read the simulated series and build the training targets with a time lag of 10.
df = pd.read_csv('./data/simulate_lorenz_96_100.csv')
target = target_data(dataflame=df, time_lag=10, device=device)

In [163]:
# 10 features, vertical stride (image height) 33, kernel width (time lag) 10.
model = RBFCNN(num_feature=10, image_height=33, time_lag=10, device=device)

In [164]:
# Run ISTA training; the returned list of per-check losses is the cell output.
train_model_ista(
    model,
    data,
    target,
    lr=1e-3,
    max_iter=20000,
    lam=10.0,
    lam_ridge=1e-2,
    check_every=50,
)

----------Iter = 50----------
Loss = 22.831360
Variable usage = 0.00%
----------Iter = 100----------
Loss = 21.862729
Variable usage = 0.00%
----------Iter = 150----------
Loss = 21.069839
Variable usage = 0.00%
----------Iter = 200----------
Loss = 20.420808
Variable usage = 0.00%
----------Iter = 250----------
Loss = 19.889532
Variable usage = 0.00%
----------Iter = 300----------
Loss = 19.454646
Variable usage = 0.00%
----------Iter = 350----------
Loss = 19.098666
Variable usage = 0.00%
----------Iter = 400----------
Loss = 18.807269
Variable usage = 0.00%
----------Iter = 450----------
Loss = 18.568744
Variable usage = 0.00%
----------Iter = 500----------
Loss = 18.373492
Variable usage = 0.00%
----------Iter = 550----------
Loss = 18.213669
Variable usage = 0.00%
----------Iter = 600----------
Loss = 18.082840
Variable usage = 0.00%
----------Iter = 650----------
Loss = 17.975754
Variable usage = 0.00%
----------Iter = 700----------
Loss = 17.888089
Variable usage = 0.00%
-------

[tensor(22.8314, device='cuda:0', dtype=torch.float64),
 tensor(21.8627, device='cuda:0', dtype=torch.float64),
 tensor(21.0698, device='cuda:0', dtype=torch.float64),
 tensor(20.4208, device='cuda:0', dtype=torch.float64),
 tensor(19.8895, device='cuda:0', dtype=torch.float64),
 tensor(19.4546, device='cuda:0', dtype=torch.float64),
 tensor(19.0987, device='cuda:0', dtype=torch.float64),
 tensor(18.8073, device='cuda:0', dtype=torch.float64),
 tensor(18.5687, device='cuda:0', dtype=torch.float64),
 tensor(18.3735, device='cuda:0', dtype=torch.float64),
 tensor(18.2137, device='cuda:0', dtype=torch.float64),
 tensor(18.0828, device='cuda:0', dtype=torch.float64),
 tensor(17.9758, device='cuda:0', dtype=torch.float64),
 tensor(17.8881, device='cuda:0', dtype=torch.float64),
 tensor(17.8163, device='cuda:0', dtype=torch.float64),
 tensor(17.7576, device='cuda:0', dtype=torch.float64),
 tensor(17.7095, device='cuda:0', dtype=torch.float64),
 tensor(17.6702, device='cuda:0', dtype=torch.fl