In [39]:
import sys
import json
import torch
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Dict, List, Any, Generator
from tqdm import tqdm
from collections import Counter
import random
import torch.nn as nn
import torchmetrics as tm
main_path = Path('..').resolve()
sys.path.append(str(main_path))
from copy import deepcopy
from collections import OrderedDict
import torch
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from typing import Dict

## Data & Task generator

In [2]:
def flatten(li: List[Any]) -> Generator:
    """flatten nested list
    ```python
    x = [[[1], 2], [[[[3]], 4, 5], 6], 7, [[8]], [9], 10]
    print(type(flatten(x)))
    # <generator object flatten at 0x00000212BF603CC8>
    print(list(flatten(x)))
    # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    ```
    Args:
        li (List[Any]): any kinds of list
    Yields:
        Generator: flattened list generator
    """
    for ele in li:
        if isinstance(ele, list) or isinstance(ele, tuple):
            yield from flatten(ele)
        else:
            yield ele

In [3]:
class PanelDataDict(dict):
    def __init__(self, data, window_size):
        self.window_size = window_size
        self._set_state(f'numpy')
        for k, v in data.items():
            data[k] = np.array(v)
        
        self.n_stocks = len(v)
        super().__init__(data)
    
    def tensor_fn(self, value, key):
        return torch.FloatTensor(value)

    def _set_state(self, state: str):
        self.state = state

    def to(self, device: None | str=None):
        if device is None:
            device = torch.device('cpu')
        else:
            device = torch.device(device)
        self._set_state(f'tensor.{device}')
        for key in self.keys():
            value = self.__getitem__(key)
            tvalue = self.tensor_fn(value, key)
            self.__setitem__(key, tvalue.to(device)) 
        
    def numpy(self):
        self._set_state('numpy')
        for key in self.keys():
            tvalue = self.__getitem__(key)
            if not isinstance(tvalue, np.ndarray): 
                self.__setitem__(key, tvalue.detach().numpy())

    def __str__(self):
        s = f'PanelDataDict(T={self.window_size}, {self.state})\n'
        for i, key in enumerate(self.keys()):
            value = self.__getitem__(key)
            s += f'- {key}: {value.shape}'
            s += '' if i == len(self.keys())-1 else '\n'
        return s

class StockRegressionDataset(torch.utils.data.Dataset):
    def __init__(
            self, 
            meta_type: str ='train', 
            data_dir: Path | str ='', 
            dtype: str ='kdd17', 
            batch_size: int =64,
            n_support: int =5, 
            n_query: int = 3,
            window_sizes: List[int] =[5]
        ):    
        """dataset ref: https://arxiv.org/abs/1810.09936

        In this meta learning setting, we have 3 meta-test and 1 meta-train
        vertical = stocks, horizontal = time
                train      |    test
           A               |
           B   meta-train  |   meta-test
           C               |      (1)
           ----------------|-------------
           D   meta-test   |   meta-test
           E     (2)       |      (3)

        meta-test (1) same stock, different time
        meta-test (2) different stock, same time
        meta-test (3) different stock, different time
        use `valid_date` to split the train / test set

        the number of training stock was splitted with number of total stocks * 0.8.
        we have 5 stock universe

        Args:
            meta_type (str, optional): _description_. Defaults to 'train'.
            data_dir (Path | str, optional): _description_. Defaults to ''.
            dtype (str, optional): _description_. Defaults to 'kdd17'.
            stock_universe (int, optional): _description_. Defaults to 0.
            batch_size (int, optional): Batch size. Number of stock x Number of timestamp that is aviable for each window size. Defaults to 64.
            n_support (int, optional): Number of support. Defaults to 4.

        """
        super().__init__()

        # data config
        self.data_dir = Path(data_dir).resolve()
        ds_info = {
            # train: (Jan-01-2007 to Jan-01-2015)
            # val: (Jan-01-2015 to Jan-01-2016)
            # test: (Jan-01-2016 to Jan-01-2017)
            'kdd17': {
                'path': self.data_dir / 'kdd17/price_long_50',
                'date': self.data_dir / 'kdd17/trading_dates.csv',
                'universe': self.data_dir / 'kdd17/stock_universe.json', 
                'start_date': '2007-01-01',
                'train_date': '2015-01-01', 
                'valid_date': '2016-01-01', 
                'test_date': '2017-01-01',
            },
            # train: (Jan-01-2014 to Aug-01-2015)
            # val: (Aug-01-2015 to Oct-01-2015)
            # test: (Oct-01-2015 to Jan-01-2016)
            'acl18': {
                'path': self.data_dir / 'stocknet-dataset/price/raw',
                'date': self.data_dir / 'stocknet-dataset/price/trading_dates.csv',
                'universe': self.data_dir / 'stocknet-dataset/stock_universe.json',
                'start_date': '2014-01-01',
                'train_date': '2015-08-01', 
                'valid_date': '2015-10-01', 
                'test_date': '2016-01-01',
            }
        }
        ds_config = ds_info[dtype]
        
        self.meta_type = meta_type
        self.window_sizes = window_sizes
        self.batch_size = batch_size
        self.n_support = n_support
        self.n_query = n_query

        # get data
        self.data = {}
        self.all_tasks = {}
        ps = list((ds_config['path']).glob('*.csv'))
        with ds_config['universe'].open('r') as file:
            universe_dict = json.load(file)
        
        # meta_type: train / valid1: valid-time, valid2: valid-stock, valid3: valid-mix / test1, test2, test3
        if meta_type in ['train', 'valid-time', 'test-time']:
            universe = universe_dict['train']
        elif meta_type in ['valid-stock', 'valid-mix']:
            universe = universe_dict['valid']
        elif meta_type in ['test-stock', 'test-mix']:
            universe = universe_dict['test']
        else:
            raise KeyError('Error argument `meta_type`, should be in (train, valid-time, valid-stock, valid-mix, test-time, test-stock, test-mix)')

        if meta_type in ['train', 'valid-stock', 'test-stock']:
            date1 = ds_config['start_date']
            date2 = ds_config['train_date']
        elif meta_type in ['valid-time', 'valid-mix']:
            date1 = ds_config['train_date']
            date2 = ds_config['valid_date']
        elif meta_type in ['test-time', 'test-mix']:
            date1 = ds_config['valid_date']
            date2 = ds_config['test_date']
        else:
            raise KeyError('Error argument `meta_type`, should be in (train, valid-time, valid-stock, valid-mix, test-time, test-stock, test-mix)')

        iterator = [p for p in ps if p.name.strip('.csv') in universe]
        for p in tqdm(iterator, total=len(iterator), desc=f'Processing data for {self.meta_type}'):    
            stock_symbol = p.name.rstrip('.csv')
            df_single = self.load_single_stock(p)
            cond = df_single['date'].between(date1, date2)
            df_single = df_single.loc[cond].reset_index(drop=True)
            
            self.data[stock_symbol] = df_single


        self.n_stocks = len(universe)


    def load_single_stock(self, p: Path | str):
        def longterm_trend(x: pd.Series, k:int):
            return (x.rolling(k).sum().div(k*x) - 1) * 100

        df = pd.read_csv(p)
        df['Date'] = pd.to_datetime(df['Date'])
        df = df.sort_values('Date').reset_index(drop=True)
        if 'Unnamed' in df.columns:
            df.drop(columns=df.columns[7], inplace=True)
        if 'Original_Open' in df.columns:
            df.rename(columns={'Original_Open': 'Open', 'Open': 'Adj Open'}, inplace=True)

        # Open, High, Low
        z1 = (df.loc[:, ['Open', 'High', 'Low']].div(df['Close'], axis=0) - 1).rename(
            columns={'Open': 'open', 'High': 'high', 'Low': 'low'}) * 100
        # Close
        z2 = df[['Close']].pct_change().rename(columns={'Close': 'close'}) * 100
        # Adj Close
        z3 = df[['Adj Close']].pct_change().rename(columns={'Adj Close': 'adj_close'}) * 100

        z4 = []
        for k in [5, 10, 15, 20, 25, 30]:
            z4.append(df[['Adj Close']].apply(longterm_trend, k=k).rename(columns={'Adj Close': f'zd{k}'}))

        df_pct = pd.concat([df['Date'], z1, z2, z3] + z4, axis=1).rename(columns={'Date': 'date'})
        cols_max = df_pct.columns[df_pct.isnull().sum() == df_pct.isnull().sum().max()]
        df_pct = df_pct.loc[~df_pct[cols_max].isnull().values, :]

        return df_pct

    def sliding_window_idx(self, df_single, window_size):
    
        if len(df_single) >= window_size:
            x_spt_task = []
            y_spt_task = []
            x_qry_task = []
            y_qry_task = []

            for i in range(len(df_single)-window_size-self.n_support-self.n_query+1):
                x_spt = []
                y_spt = []
                x_qry = []
                y_qry = []

                for j in range(self.n_support+self.n_query):
                    if j < self.n_support:
                        spt_idx = [idx for idx in range(i+j, i+j+window_size)]
                        x_spt.append(spt_idx)
                        y_spt.append(i+j+window_size)

                    else:
                        qry_idx = [idx for idx in range(i+j, i+j+window_size)]
                        x_qry.append(qry_idx)
                        y_qry.append(i+j+window_size)

                x_spt_task.append(x_spt)
                y_spt_task.append(y_spt)
                x_qry_task.append(x_qry)
                y_qry_task.append(y_qry)
            return x_spt_task, y_spt_task, x_qry_task, y_qry_task
    
    def generate_data(self,df_single, x_spt_task, y_spt_task, x_qry_task, y_qry_task):
        num_task = len(x_spt_task)
        support_task = []
        support_labels = []
        query_task = []
        query_labels = []
        for i in range(num_task):
            support_inputs = []
            query_inputs = []
            for j in range(self.n_support):
                support_inputs.append(df_single.iloc[x_spt_task[i][j]].to_numpy()[:, 1:].astype(np.float64))

            support_labels.append(df_single['close'].iloc[y_spt_task[i]].to_numpy().astype(np.float64))
            support_task.append(np.array(support_inputs))
            for k in range(self.n_query):
                query_inputs.append(df_single.iloc[x_qry_task[i][k]].to_numpy()[:, 1:].astype(np.float64))
            query_labels.append(df_single['close'].iloc[y_qry_task[i]].to_numpy().astype(np.float64))
            query_task.append(np.array(query_inputs))   

        return support_task, support_labels, query_task, query_labels
    
    @property
    def symbols(self):
        return list(self.data.keys())
    
    def generate_all_task(self):
        all_tasks = dict()
        for window in self.window_sizes:
            all_tasks[window] = self.generate_all_task_per_window(window)
        self.all_tasks = all_tasks

    def generate_all_task_per_window(self,window_size):
        
        all_window_tasks = dict(
                query = [],
                query_labels = [],
                support = [],
                support_labels = [],
            )
        for symbol in self.symbols:
            df = self.data[symbol]
            x_spt_task, y_spt_task, x_qry_task, y_qry_task = self.sliding_window_idx(df, window_size)
            support_inputs, support_labels, query_inputs, query_labels = self.generate_data(df, x_spt_task, y_spt_task, x_qry_task, y_qry_task)
            all_window_tasks['query'].extend(query_inputs)
            all_window_tasks['query_labels'].extend(query_labels)
            all_window_tasks['support'].extend(support_inputs)
            all_window_tasks['support_labels'].extend(support_labels)
        
        all_window_tasks['query'] = np.array(all_window_tasks['query'])
        all_window_tasks['query_labels'] = np.array(all_window_tasks['query_labels'])
        all_window_tasks['support'] = np.array(all_window_tasks['support'])
        all_window_tasks['support_labels'] = np.array(all_window_tasks['support_labels'])
        return all_window_tasks
    
    def generate_batch_task(self, all_tasks):
        batch_tasks = dict(
                query = [],
                query_labels = [],
                support = [],
                support_labels = [],
            )

        
        if len(self.window_sizes) > 1:
            window_size = random.choice(self.window_sizes)
        else:
            window_size = self.window_sizes[0]
               
        num_task = len(all_tasks[window_size]['query'])
        batch_idx = random.sample(list(range(num_task)), self.batch_size)
        batch_tasks['query'] = all_tasks[window_size]['query'][batch_idx]
        batch_tasks['query_labels'] = all_tasks[window_size]['query_labels'][batch_idx]
        batch_tasks['support'] = all_tasks[window_size]['support'][batch_idx]
        batch_tasks['support_labels'] = all_tasks[window_size]['support_labels'][batch_idx]

        return batch_tasks, window_size

    def update_q_idx_dist(self, q_target):
        self.q_dist[q_target] += 1

    def reset_q_idx_dist(self):
        self.q_dist = Counter()

In [10]:
data_dir = "C:/Users\/david/yjhwang/TEAP/data"
dtype = "kdd17"

In [11]:
data_dir = Path(data_dir).resolve()
ds_info = {
    # train: (Jan-01-2007 to Jan-01-2015)
    # val: (Jan-01-2015 to Jan-01-2016)
    # test: (Jan-01-2016 to Jan-01-2017)
    'kdd17': {
        'path': data_dir / 'kdd17/price_long_50',
        'date': data_dir / 'kdd17/trading_dates.csv',
        'universe': data_dir / 'kdd17/stock_universe.json', 
        'start_date': '2007-01-01',
        'train_date': '2015-01-01', 
        'valid_date': '2016-01-01', 
        'test_date': '2017-01-01',
    },
    # train: (Jan-01-2014 to Aug-01-2015)
    # val: (Aug-01-2015 to Oct-01-2015)
    # test: (Oct-01-2015 to Jan-01-2016)
    'acl18': {
        'path': data_dir / 'stocknet-dataset/price/raw',
        'date': data_dir / 'stocknet-dataset/price/trading_dates.csv',
        'universe': data_dir / 'stocknet-dataset/stock_universe.json',
        'start_date': '2014-01-01',
        'train_date': '2015-08-01', 
        'valid_date': '2015-10-01', 
        'test_date': '2016-01-01',
    }
}
ds_config = ds_info[dtype]

In [16]:
meta_train = StockRegressionDataset(data_dir = data_dir, n_query = 1)
meta_valid_time = StockRegressionDataset(meta_type='valid-time', data_dir = data_dir, n_query = 1)
meta_valid_entity = StockRegressionDataset(meta_type='valid-stock', data_dir = data_dir, n_query = 1)
meta_valid_mix = StockRegressionDataset(meta_type='valid-mix', data_dir = data_dir, n_query = 1)
meta_test_time = StockRegressionDataset(meta_type='test-time', data_dir = data_dir, n_query = 1)
meta_test_entity = StockRegressionDataset(meta_type='test-stock', data_dir = data_dir, n_query = 1)
meta_test_mix = StockRegressionDataset(meta_type='test-mix', data_dir = data_dir, n_query = 1)

Processing data for train: 100%|██████████████████████████████████████████████████████████| 35/35 [00:02<00:00, 16.45it/s]
Processing data for valid-time: 100%|█████████████████████████████████████████████████████| 35/35 [00:02<00:00, 15.84it/s]
Processing data for valid-stock: 100%|████████████████████████████████████████████████████| 10/10 [00:00<00:00, 29.70it/s]
Processing data for valid-mix: 100%|██████████████████████████████████████████████████████| 10/10 [00:00<00:00, 40.84it/s]
Processing data for test-time: 100%|██████████████████████████████████████████████████████| 35/35 [00:02<00:00, 15.71it/s]
Processing data for test-stock: 100%|███████████████████████████████████████████████████████| 5/5 [00:00<00:00, 11.80it/s]
Processing data for test-mix: 100%|█████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 13.14it/s]


In [43]:
meta_train.generate_all_task()
meta_valid_time.generate_all_task()
meta_valid_entity.generate_all_task()
meta_valid_mix.generate_all_task()
meta_test_time.generate_all_task()
meta_test_entity.generate_all_task()
meta_test_mix.generate_all_task()

## MetricRecorder

In [44]:
class RegressionMetricRecorder(nn.Module):
    def __init__(self):
        super().__init__()
        cs = tm.MetricCollection({
            'Support_MSE': tm.MeanMetric(), 
            'Support_MAE': tm.MeanMetric(),
            'Support_MAPE': tm.MeanMetric(),
            'Query_MSE': tm.MeanMetric(), 
            'Query_MAE': tm.MeanMetric(),
            'Query_MAPE': tm.MeanMetric(),
        })

        self.metrics = cs.clone()
    @property
    def keys(self):
        return list(self.metrics.keys())
    
    def update(self, key, scores):
        self.metrics[key].update(scores)
            
    def compute(self, prefix: str):
        results = {}
        for k in self.keys:
            m = self.metrics[k].compute()
            if isinstance(m, torch.Tensor):
                m = m.cpu().detach().numpy()
            results[f'{prefix}-{k}'] = m
        return results
    
    def reset(self):
        for k in self.keys:
            self.metrics[k].reset()
            
class RegressionMetricTaskRecorder(nn.Module):
    def __init__(self):
        super().__init__()
        cs = tm.MetricCollection({
            'Support_MSE': tm.MeanSquaredError(), 
            'Support_MAE': tm.MeanAbsoluteError(),
            'Support_MAPE': tm.MeanAbsolutePercentageError(),
            'Query_MSE': tm.MeanSquaredError(), 
            'Query_MAE': tm.MeanAbsoluteError(),
            'Query_MAPE': tm.MeanAbsolutePercentageError(),
        })
        
        self.metrics = cs.clone()
        
    @property
    def keys(self):
        return list(self.metrics.keys())

## Model

## Single Step ALSTM

In [45]:
class LSTM(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, num_layers: int):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=False)
        self.lnorm = nn.LayerNorm(hidden_size)
    
    def forward(self, x: torch.Tensor):
        # x: (B, T, I)
        o, (h, _) = self.lstm(x) # o: (B, T, H) / h: (1, B, H)
        normed_context = self.lnorm(h[-1, :, :])
        return normed_context

In [46]:
class LSTMAttention(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, num_layers: int):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=False)
        self.lnorm = nn.LayerNorm(hidden_size)
        
    def forward(self, x: torch.Tensor, rt_attn: bool=False):
        # x: (B, T, I)
        o, (h, _) = self.lstm(x) # o: (B, T, H) / h: (1, B, H)
        h = h[-1, :, :]  # (B, H)
        score = torch.bmm(o, h.unsqueeze(-1)) # (B, T, H) x (B, H, 1)
        attn = torch.softmax(score, 1).squeeze(-1)  # (B, T)
        context = torch.bmm(attn.unsqueeze(1), o).squeeze(1)  # (B, 1, T) x (B, T, H)
        normed_context = self.lnorm(context)  # (B, H)
        if rt_attn:
            return normed_context, attn
        else:
            return normed_context, None

In [47]:
class PanelRegressionModel(nn.Module):
    def __init__(
        self, 
        feature_size: int, 
        embed_size: int,
        output_size: int,
        num_layers: int, 
        drop_rate: float, 
        device: str
    ):
        super().__init__()
        self.embed_size = embed_size
        self.output_size = output_size

        # Network
        self.dropout = nn.Dropout(drop_rate)
        self.lstm_encoder = LSTMAttention(input_size=feature_size, hidden_size=embed_size, num_layers=num_layers)
        self.layer_norm = nn.LayerNorm(embed_size)

        self.encoder = nn.Sequential(
            nn.Linear(embed_size, output_size),
            nn.ReLU(inplace=False)
        )


        # Loss
        self.loss_fn = nn.MSELoss()

        # Meta Mode Support / Query
        self._mode_query(False)

    # Recoder
    # self.recorder = MetricRecorder().to(device)

    def meta_train(self):
        self.train()

    def meta_eval(self):
        self.manual_model_eval()

    def _mode_query(self, mode: bool=True):
        # check if is support of query
        self.is_query = mode

    def manual_model_eval(self, mode: bool=False):
        """
        [PyTorch Issue] RuntimeError: cudnn RNN backward can only be called in training mode
        cannot use `model.eval()`. 
        see https://stackoverflow.com/questions/51433378/what-does-model-train-do-in-pytorch
        """
        for module in self.children():
            self.training = mode
            if isinstance(module, nn.Dropout) or isinstance(module, nn.LayerNorm):
                module.train(mode)

    def encode_lstm(self, inputs: torch.Tensor, rt_attn: bool=False):
        """forward data by each stock to avoid trained by other stocks
        - B: number of samples (n_support if meta-learning)
        - T: window size
        - I: input size
        - E: embedding size
        - M: M = N * K

        Args:
            inputs: (B, T, I).
            - support: (B, T, I) B: n_support
            - query: (B, T, I) B: n_query

        Returns:
            encoded: (B, E)
            attn: (B, T)
        """
        B, T, I = inputs.size() # B = n_support
        inputs = self.dropout(inputs)
        encoded, attn = self.lstm_encoder(inputs, rt_attn)  # encoded: (B, E), attn: (B, T)
        encoded = self.layer_norm(encoded)
        return encoded, attn

    def forward_encoder(self, inputs: torch.Tensor, rt_attn: bool=True):
        """Forward Encoder: from `inputs` to `z`
        - B: number of n_support
        - T: window size
        - E: embedding size
        - H: hidden size

        Returns:
            l: (B, O) # O: output feature dim
            attn: (B, T). attention weights for each inputs.

        """
        # support l: (B, T, E), attn: (B, T)
        # query l: (B, T, E), attn: (B, T)

        l, attn = self.encode_lstm(inputs, rt_attn=rt_attn)
        e = self.encoder(l)  # e: (B, N, K, 2)

        return e, attn

    def forward(
            self, data , # data = torch.tensor
            rt_attn: bool=False
        ):

        e , attn = self.forward_encoder(data) # e: (B, O)
        e = e.squeeze(dim=-1)
        return e

In [48]:
feature_size: 11 
embed_size: 32
output_size: 1
num_layers: 1 
drop_rate: 0.1
device: 'cuda'

In [49]:
model = PanelRegressionModel(feature_size=11, embed_size=32, output_size=1, num_layers=1, drop_rate=0, device='cuda')

## MAML Model

In [121]:
class Maml_Regression_Trainer(nn.Module):
    def __init__(
        self, exp_name, log_dir, task_type,  model, batch_size,
        n_inner_step, total_steps, 
        n_valid_step, every_valid_step, print_step,
        inner_lr, outer_lr, device, clip_value, test_window_size):
        
        super(Maml_Regression_Trainer, self).__init__()
        self.exp_name = exp_name
        self.log_dir = Path(log_dir).resolve()
        self.device = device
        self.model = model.to(self.device)
        if task_type == "classification":
            self.loss_fn = nn.NLLLoss() 
        elif task_type == "regression":
            self.loss_fn = nn.MSELoss()
            
        self.n_inner_step = n_inner_step
        self.total_steps = total_steps
        self.n_valid_step = n_valid_step
        self.every_valid_step = every_valid_step
        self.print_step = print_step
        self.inner_lr = inner_lr
        self.outer_lr = outer_lr
        self.batch_size = batch_size
        self.test_window_size = test_window_size
        self.keep_weights = self.clone_weight(self.model)
        self.meta_optim = torch.optim.Adam(
            self.model.parameters(), 
            lr=self.outer_lr
        )
        if self.device == 'cuda':
            self.cuda()
        
        # Recoder
        self.train_recorder = RegressionMetricRecorder().to(device)
        self.valid_recorder = RegressionMetricRecorder().to(device)
        self.test_recorder = RegressionMetricRecorder().to(device)
        
        self.task_recorder = RegressionMetricTaskRecorder().to(device)
 

        
    def init_experiments(self, exp_num=None, record_tensorboard: bool=True):
        # check if exp exists
        exp_dirs = sorted(list(self.log_dir.glob(f'{self.exp_name}_*')))
        if exp_num is None:
            exp_num = int(exp_dirs[-1].name[len(self.exp_name)+1:]) if exp_dirs else 0
            self.exp_num = exp_num + 1
        else:
            self.exp_num = exp_num
        self.exp_dir = self.log_dir / f'{self.exp_name}_{self.exp_num}'
        if record_tensorboard:
            self.writer = SummaryWriter(str(self.exp_dir))
        else:
            self.writer = None
        self.ckpt_path = self.exp_dir / 'checkpoints'
        self.ckpt_step_train_path =  self.ckpt_path / 'step' / 'train'
        self.ckpt_step_valid_path =  self.ckpt_path / 'step' / 'valid'
        for p in [self.ckpt_path, self.ckpt_step_train_path, self.ckpt_step_valid_path]:
            if not p.exists():
                p.mkdir(parents=True)    
    
    def get_acc(self,y_true, y_pred):
        correct = torch.eq(y_pred, y_true).sum().item()
        acc = correct/ len(y_true)
        return acc

    def clone_weight(self, model):
        return {k: v.clone() for k, v in model.state_dict().items()}

    def meta_update(self, dummy_loss, sum_grads):
        # Update theta_parameter by sum_gradients
        hooks = []
        for k,v in enumerate(self.model.parameters()):
            def closure():
                key = k
                return lambda grad: sum_grads[key]
            hooks.append(v.register_hook(closure()))

        self.meta_optim.zero_grad()
        dummy_loss.backward()
        self.meta_optim.step()

        for h in hooks:
            h.remove()
                
    # inner loop per 1 task
    def inner_loop(self, support_x, support_y, query_x, query_y, is_meta_train):
        updated_state_dict = {k: v.clone() for k, v in self.keep_weights.items()}
        for i in range(self.n_inner_step):
            if i > 0:
                self.model.load_state_dict(updated_state_dict)
            support_e = self.model(support_x)
            s_loss = self.loss_fn(support_e, support_y)
            grad = torch.autograd.grad(
                    s_loss, 
                    self.model.parameters(),
                    create_graph=True,
                )
            for i, (k,w) in enumerate(updated_state_dict.items()):
                updated_state_dict[k] = updated_state_dict[k] - self.inner_lr * grad[i].data
        
        s_mse = self.task_recorder.metrics['Support_MSE'](support_e, support_y)
        s_mae = self.task_recorder.metrics['Support_MAE'](support_e, support_y)
        s_mape = self.task_recorder.metrics['Support_MAPE'](support_e, support_y)
       
        self.model.load_state_dict(updated_state_dict)
        query_e = self.model(query_x)
        q_loss = self.loss_fn(query_e, query_y)
        
        q_mse = self.task_recorder.metrics['Query_MSE'](query_e, query_y)
        q_mae = self.task_recorder.metrics['Query_MAE'](query_e, query_y)
        q_mape = self.task_recorder.metrics['Query_MAPE'](query_e, query_y)
        
        
        if is_meta_train:
            q_grad = torch.autograd.grad(q_loss, self.model.parameters(), create_graph=True)
        else:
            q_grad = None
        
        
        return s_mse, s_mae, s_mape, q_mse, q_mae, q_mape, q_grad, query_e

    # outer loop per batch
    def outer_loop(self, meta_dataset):
        self.model.meta_train()
        batch_task, window_size = meta_dataset.generate_batch_task(all_tasks=meta_dataset.all_tasks) # PanelDataDict
        train_tasks  = PanelDataDict(batch_task,window_size = window_size)
        train_tasks.to(self.device)
        all_q_grads = []

        self.keep_weights = self.clone_weight(self.model)
        
        for i  in range(self.batch_size):
            x_spt = train_tasks['support'][i]
            y_spt = train_tasks['support_labels'][i]
            x_qry = train_tasks['query'][i]
            y_qry = train_tasks['query_labels'][i]
            s_mse, s_mae, s_mape, q_mse, q_mae, q_mape, q_grad, query_e = self.inner_loop(x_spt, y_spt, x_qry, y_qry, is_meta_train=True)
            self.train_recorder.update('Support_MSE', s_mse)
            self.train_recorder.update('Support_MAE', s_mae)
            self.train_recorder.update('Support_MAPE', s_mape)
            self.train_recorder.update('Query_MSE', q_mse)
            self.train_recorder.update('Query_MAE', q_mae)
            self.train_recorder.update('Query_MAPE', q_mape)
            
            
            all_q_grads.append(q_grad)
            self.model.load_state_dict(self.keep_weights)
            
        
        sum_q_grads = [torch.stack(grads).sum(dim=0) for grads in list(zip(*all_q_grads))]
        
        x_spt = train_tasks['support'][0]
        y_spt = train_tasks['support_labels'][0]
        
        dummy_e = self.model(x_spt)
        dummy_loss = self.loss_fn(dummy_e, y_spt)
        
        self.meta_update(dummy_loss, sum_q_grads)
        return 


    def meta_train(self, meta_trainset,
                meta_validset_time,
                meta_validset_entity,
                meta_validset_mix, 
                print_log: bool=True):
        
        best_eval_mse = 10000.0
        for step in range(self.total_steps):
            self.train_recorder.reset()
            # Meta-Train per epoch
            self.outer_loop(meta_trainset)
            if ( step % self.print_step == 0) or (step == self.total_steps-1):
                prefix = 'Train'
                train_logs = self.train_recorder.compute(prefix)
                cur_eval_mse = train_logs[f'{prefix}-Query_MSE']
                cur_eval_mae = train_logs[f'{prefix}-Query_MAE']
                cur_eval_mape = train_logs[f'{prefix}-Query_MAPE']
                
                self.log_results(train_logs, prefix, step=step, total_steps=self.total_steps, print_log=True)
                torch.save(self.model.state_dict(), str(self.ckpt_step_train_path / f'{step}-{cur_eval_mse:.4f}.ckpt'))

                
            # Meta-Valid
            if (self.every_valid_step != 0):
                if (step % self.every_valid_step == 0) or (step == self.total_steps-1):
                    ref_step = step
                    
                    prefix = 'Valid-time'
                    valid_time_logs, cur_eval_mse_time, cur_eval_mae_time, cur_eval_mape_time = self.meta_valid(self.model, meta_validset_time, prefix, ref_step, self.n_valid_step)
                    
                    prefix = 'Valid-entity'
                    valid_entity_logs, cur_eval_mse_entity, cur_eval_mae_entity, cur_eval_mape_entity = self.meta_valid(self.model, meta_validset_entity, prefix, ref_step, self.n_valid_step)
                    
                    prefix = 'Valid-mix'
                    valid_mix_logs, cur_eval_mse_mix, cur_eval_mae_mix, cur_eval_mape_mix = self.meta_valid(self.model, meta_validset_mix, prefix, ref_step, self.n_valid_step)
                    
                    prefix = 'Valid'
                    cur_eval_mse = (cur_eval_mse_time + cur_eval_mse_entity + cur_eval_mse_mix) / 3
                    cur_eval_mae = (cur_eval_mae_time + cur_eval_mae_entity + cur_eval_mae_mix) / 3
                    cur_eval_mape = (cur_eval_mape_time + cur_eval_mape_entity + cur_eval_mape_mix) / 3
                    valid_final_log = {f'{prefix}-AvgMSE': cur_eval_mse, f'{prefix}-AvgMAE': cur_eval_mae, f'{prefix}-AvgMAPE': cur_eval_mape}
                    self.log_results(valid_final_log, prefix, step=ref_step, total_steps=self.total_steps, print_log=print_log)
                    
                    # save best
                    if (cur_eval_mse < best_eval_mse):
                        best_eval_mse = cur_eval_mse 
                        torch.save(self.model.state_dict(), str(self.ckpt_step_valid_path / f'{ref_step:06d}-{cur_eval_mse:.4f}.ckpt'))
                    
    def meta_valid(self, model, meta_dataset, prefix, ref_step, n_valid, print_log=True):
        self.valid_recorder.reset()
        valid_logs = self.run_valid(model, meta_dataset, n_valid, prefix)
        self.log_results(valid_logs, prefix, step=ref_step, total_steps=self.total_steps, print_log=print_log)
        cur_eval_mse = valid_logs[f'{prefix}-Query_MSE']
        cur_eval_mae = valid_logs[f'{prefix}-Query_MAE']
        cur_eval_mape = valid_logs[f'{prefix}-Query_MAPE']
        return valid_logs, cur_eval_mse, cur_eval_mae, cur_eval_mape
        
    def meta_test(self, model, meta_dataset,  print_log: bool=True):
        self.test_recorder.reset()
        prefix = meta_dataset.meta_type.capitalize()
        test_logs = self.run_test(model, meta_dataset, prefix)
        self.log_results(test_logs, prefix, step=0, total_steps=0, print_log=print_log)
        eval_mse = test_logs[f'{prefix}-Query_MSE']
        eval_mae = test_logs[f'{prefix}-Query_MAE']
        eval_mape = test_logs[f'{prefix}-Query_MAPE']
        return prefix, eval_mse, eval_mae, eval_mape
    
    def run_valid(self, model, meta_dataset, n_valid, prefix):
        model = model.to(self.device)
        model.meta_eval()
        pregress = tqdm(range(n_valid), total= n_valid, desc=f'Running {prefix}')
     
        for val_idx in pregress:
            batch_task, window_size = meta_dataset.generate_batch_task(all_tasks=meta_dataset.all_tasks) # PanelDataDict
            valid_tasks  = PanelDataDict(batch_task,window_size = window_size)
            valid_tasks.to(self.device)
            for i  in range(self.batch_size):
                x_spt = valid_tasks['support'][i]
                y_spt = valid_tasks['support_labels'][i]
                x_qry = valid_tasks['query'][i]
                y_qry = valid_tasks['query_labels'][i]
                s_mse, s_mae, s_mape, q_mse, q_mae, q_mape, q_grad, query_e = self.inner_loop(x_spt, y_spt, x_qry, y_qry, is_meta_train=False)
                self.valid_recorder.update('Support_MSE', s_mse)
                self.valid_recorder.update('Support_MAE', s_mae)
                self.valid_recorder.update('Support_MAPE', s_mape)
                self.valid_recorder.update('Query_MSE', q_mse)
                self.valid_recorder.update('Query_MAE', q_mae)
                self.valid_recorder.update('Query_MAPE', q_mape)
        
        
        pregress.close()
        valid_logs = self.valid_recorder.compute(prefix)       
        return valid_logs
    
    def run_test(self, model, meta_dataset, prefix):
        model = model.to(self.device)
        model.meta_eval()
        test_all_tasks = meta_dataset.all_tasks[self.test_window_size]
        test_tasks = PanelDataDict(test_all_tasks, window_size = self.test_window_size)
        test_tasks.to(self.device)
        pregress = tqdm(range(len(test_tasks['query'])), total= len(test_tasks['query']), desc=f'Running {prefix}')
        for test_idx in pregress:
            x_spt = test_tasks['support'][test_idx]
            y_spt = test_tasks['support_labels'][test_idx]
            x_qry = test_tasks['query'][test_idx]
            y_qry = test_tasks['query_labels'][test_idx]
            s_mse, s_mae, s_mape, q_mse, q_mae, q_mape, q_grad, query_e = self.inner_loop(x_spt, y_spt, x_qry, y_qry, is_meta_train=False)
            self.test_recorder.update('Support_MSE', s_mse)
            self.test_recorder.update('Support_MAE', s_mae)
            self.test_recorder.update('Support_MAPE', s_mape)
            self.test_recorder.update('Query_MSE', q_mse)
            self.test_recorder.update('Query_MAE', q_mae)
            self.test_recorder.update('Query_MAPE', q_mape)
        
        pregress.close()
        test_logs = self.test_recorder.compute(prefix)       
        return test_logs
    
    def log_results(self, logs, prefix, step, total_steps, print_log=False):
        
        for log_string, value in logs.items():       
            if self.writer is not None:
                self.writer.add_scalar(log_string, value, step)
                
        def extract(prefix, key, logs):
            mean = logs[f'{prefix}-{key}']
            s = ''
            s += f'{mean:.4f}'
            return s

        if print_log:
            only_one_to_print = True if prefix in ['Valid', 'Test'] else False

            if only_one_to_print:
                avgmse = extract(prefix, 'AvgMSE', logs)
                avgmae = extract(prefix, 'AvgMAE', logs)
                avgmape = extract(prefix, 'AvgMAPE', logs)
         
                print(f'[Meta {prefix}] Result - AvgMSE: {avgmse}, AvgMAE: {avgmae}, AvgMAPE: {avgmape} ')
                print()

            else:
                s_mse = extract(prefix, 'Support_MSE', logs)
                s_mae = extract(prefix, 'Support_MAE', logs)
                s_mape = extract(prefix, 'Support_MAPE', logs)
                q_mse = extract(prefix, 'Query_MSE', logs)
                q_mae = extract(prefix, 'Query_MAE', logs)
                q_mape = extract(prefix, 'Query_MAPE', logs)

                print(f'[Meta {prefix}]({step+1}/{total_steps})')
                print(f'  - [Support] MSE: {s_mse}, MAE: {s_mae}, MAPE: {s_mape}')
                print(f'  - [Query] MSE: {q_mse}, MAE: {q_mae}, MAPE: {q_mape}')
                print()
                
    def get_best_results(self, exp_num, record_tensorboard: bool=True):
        self.init_experiments(exp_num=exp_num, record_tensorboard=record_tensorboard)
        best_ckpt = sorted(
            (self.ckpt_step_valid_path).glob('*.ckpt'),
            key=lambda x: x.name.split('-')[1], 
            reverse=True
        )[0]
        
        best_step, train_loss = best_ckpt.name.rstrip('.ckpt').split('-')
        state_dict = torch.load(best_ckpt)
        return int(best_step), float(train_loss), state_dict

In [122]:
class RegressionMetricRecorder(nn.Module):
    def __init__(self):
        super().__init__()
        cs = tm.MetricCollection({
            'Support_MSE': tm.MeanMetric(), 
            'Support_MAE': tm.MeanMetric(),
            'Support_MAPE': tm.MeanMetric(),
            'Query_MSE': tm.MeanMetric(), 
            'Query_MAE': tm.MeanMetric(),
            'Query_MAPE': tm.MeanMetric(),
        })

        self.metrics = cs.clone()
    @property
    def keys(self):
        return list(self.metrics.keys())
    
    def update(self, key, scores):
        self.metrics[key].update(scores)
            
    def compute(self, prefix: str):
        results = {}
        for k in self.keys:
            m = self.metrics[k].compute()
            if isinstance(m, torch.Tensor):
                m = m.cpu().detach().numpy()
            results[f'{prefix}-{k}'] = m
        return results
    
    def reset(self):
        for k in self.keys:
            self.metrics[k].reset()

In [123]:
class RegressionMetricTaskRecorder(nn.Module):
    def __init__(self):
        super().__init__()
        cs = tm.MetricCollection({
            'Support_MSE': tm.MeanSquaredError(), 
            'Support_MAE': tm.MeanAbsoluteError(),
            'Support_MAPE': tm.MeanAbsolutePercentageError(),
            'Query_MSE': tm.MeanSquaredError(), 
            'Query_MAE': tm.MeanAbsoluteError(),
            'Query_MAPE': tm.MeanAbsolutePercentageError(),
        })
        
        self.metrics = cs.clone()
        
    @property
    def keys(self):
        return list(self.metrics.keys())

In [124]:
exp_name = 'kdd17_0'
log_dir = './logging'
task_type = "regression"
model = PanelRegressionModel(feature_size=11, embed_size=32, output_size=1, num_layers=1, drop_rate=0, device='cuda')
batch_size = 64
n_inner_step = 2
total_steps = 2
n_valid_step = 2
every_valid_step = 2
print_step = 1
inner_lr = 0.01
outer_lr = 0.001
device = 'cuda'
clip_value = 0
test_window_size = 5

In [125]:
maml_train = Maml_Regression_Trainer(exp_name,log_dir, task_type, model, batch_size, n_inner_step, total_steps, n_valid_step, every_valid_step, print_step, inner_lr, outer_lr, device, clip_value, test_window_size )

In [126]:
maml_train.init_experiments()

In [127]:
maml_train.meta_train(meta_train, meta_valid_time, meta_valid_entity, meta_valid_mix)

[Meta Train](1/2)
  - [Support] MSE: 2.9967, MAE: 1.1789, MAPE: 3500.3667
  - [Query] MSE: 2.2061, MAE: 1.0476, MAPE: 1.0913



Running Valid-time: 100%|███████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.77it/s]


[Meta Valid-time](1/2)
  - [Support] MSE: 2.1760, MAE: 1.0574, MAPE: 1795.5468
  - [Query] MSE: 3.0426, MAE: 1.2335, MAPE: 1.2080



Running Valid-entity: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.88it/s]


[Meta Valid-entity](1/2)
  - [Support] MSE: 4.1395, MAE: 1.2132, MAPE: 1.0871
  - [Query] MSE: 3.8890, MAE: 1.2704, MAPE: 1.0731



Running Valid-mix: 100%|████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.82it/s]


[Meta Valid-mix](1/2)
  - [Support] MSE: 3.1395, MAE: 1.1972, MAPE: 1.1013
  - [Query] MSE: 3.6928, MAE: 1.2906, MAPE: 1.0104

[Meta Valid] Result - AvgMSE: 3.5415, AvgMAE: 1.2648, AvgMAPE: 1.0972 

[Meta Train](2/2)
  - [Support] MSE: 5.5980, MAE: 1.3855, MAPE: 1.0214
  - [Query] MSE: 5.4727, MAE: 1.3126, MAPE: 1.0275



Running Valid-time: 100%|███████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.84it/s]


[Meta Valid-time](2/2)
  - [Support] MSE: 2.1751, MAE: 1.0351, MAPE: 26.6098
  - [Query] MSE: 1.8135, MAE: 1.0152, MAPE: 1.4320



Running Valid-entity: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.98it/s]


[Meta Valid-entity](2/2)
  - [Support] MSE: 3.1640, MAE: 1.2556, MAPE: 1.2167
  - [Query] MSE: 5.8260, MAE: 1.4548, MAPE: 1.1026



Running Valid-mix: 100%|████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.71it/s]

[Meta Valid-mix](2/2)
  - [Support] MSE: 2.9875, MAE: 1.1575, MAPE: 415.5214
  - [Query] MSE: 2.2653, MAE: 1.1087, MAPE: 1.2722

[Meta Valid] Result - AvgMSE: 3.3016, AvgMAE: 1.1929, AvgMAPE: 1.2689 






In [128]:
maml_train.meta_test(maml_train.model, meta_test_time)

Running Test-time: 100%|█████████████████████████████████████████████████████████████| 8470/8470 [01:21<00:00, 103.61it/s]

[Meta Test-time](1/0)
  - [Support] MSE: 2.4698, MAE: 1.0361, MAPE: 920.6209
  - [Query] MSE: 2.3654, MAE: 1.0025, MAPE: 164.9027






('Test-time',
 array(2.3653708, dtype=float32),
 array(1.0024823, dtype=float32),
 array(164.90266, dtype=float32))

In [114]:
def load_single_stock(p: Path | str):
        def longterm_trend(x: pd.Series, k:int):
            return (x.rolling(k).sum().div(k*x) - 1) * 100

        df = pd.read_csv(p)
        df['Date'] = pd.to_datetime(df['Date'])
        df = df.sort_values('Date').reset_index(drop=True)
        if 'Unnamed' in df.columns:
            df.drop(columns=df.columns[7], inplace=True)
        if 'Original_Open' in df.columns:
            df.rename(columns={'Original_Open': 'Open', 'Open': 'Adj Open'}, inplace=True)

        # Open, High, Low
        z1 = (df.loc[:, ['Open', 'High', 'Low']].div(df['Close'], axis=0) - 1).rename(
            columns={'Open': 'open', 'High': 'high', 'Low': 'low'}) * 100
        # Close
        z2 = df[['Close']].pct_change().rename(columns={'Close': 'close'}) * 100
        # Adj Close
        z3 = df[['Adj Close']].pct_change().rename(columns={'Adj Close': 'adj_close'}) * 100

        z4 = []
        for k in [5, 10, 15, 20, 25, 30]:
            z4.append(df[['Adj Close']].apply(longterm_trend, k=k).rename(columns={'Adj Close': f'zd{k}'}))

        df_pct = pd.concat([df['Date'], z1, z2, z3] + z4, axis=1).rename(columns={'Date': 'date'})
        cols_max = df_pct.columns[df_pct.isnull().sum() == df_pct.isnull().sum().max()]
        df_pct = df_pct.loc[~df_pct[cols_max].isnull().values, :]

        return df_pct

In [17]:
ds_config['path']

WindowsPath('C:/Users/david/yjhwang/TEAP/data/kdd17/price_long_50')

In [18]:
ps = list((ds_config['path']).glob('*.csv'))

In [115]:
p = ps[0]

In [112]:
p_2 = ps[1]

In [20]:
stock_symbol = p.name.rstrip('.csv')

In [113]:
stock_symbol_2 = p_2.name.rstrip('.csv')

In [163]:
symbol = [p,p_2]

In [116]:
df_1 = load_single_stock(p)

In [117]:
df_2 = load_single_stock(p_2)

In [118]:
df_1.shape

(2489, 12)

In [119]:
df_2.shape

(2489, 12)

In [29]:
def sliding_window_idx(df_single, window_size, n_support, n_query):
    if len(df_single) >= window_size:
        x_spt_task = []
        y_spt_task = []
        x_qry_task = []
        y_qry_task = []

        for i in range(len_df-window_size-n_support-n_query+1):
            x_spt = []
            y_spt = []
            x_qry = []
            y_qry = []

            for j in range(n_support+n_query):
                if j < n_support:
                    spt_idx = [idx for idx in range(i+j, i+j+window_size)]
                    x_spt.append(spt_idx)
                    y_spt.append(i+j+window_size)

                else:
                    qry_idx = [idx for idx in range(i+j, i+j+window_size)]
                    x_qry.append(qry_idx)
                    y_qry.append(i+j+window_size)

            x_spt_task.append(x_spt)
            y_spt_task.append(y_spt)
            x_qry_task.append(x_qry)
            y_qry_task.append(y_qry)
        return x_spt_task, y_spt_task, x_qry_task, y_qry_task

In [153]:
def generate_data(df_single, x_spt_task, y_spt_task, x_qry_task, y_qry_task, n_support, n_query):
    num_task = len(x_spt_task)
    support_task = []
    support_labels = []
    query_task = []
    query_labels = []
    for i in range(len(x_spt_task)):
        support_inputs = []
        query_inputs = []
        for j in range(n_support):
            support_inputs.append(df_single.iloc[x_spt_task[i][j]].to_numpy()[:, 1:].astype(np.float64))
        
        support_labels.append(df_single['close'].iloc[y_spt_task[i]].to_numpy().astype(np.float64))
        support_task.append(np.array(support_inputs))
        for k in range(n_query):
            query_inputs.append(df_single.iloc[x_qry_task[i][k]].to_numpy()[:, 1:].astype(np.float64))
        query_labels.append(df_single['close'].iloc[y_qry_task[i]].to_numpy().astype(np.float64))
        query_task.append(np.array(query_inputs))   
        
    return support_task, support_labels, query_task, query_labels

In [154]:
window_size, n_support, n_query = 3, 5, 1

In [167]:
x_spt_task, y_spt_task, x_qry_task, y_qry_task = sliding_window_idx(df_2, window_size, n_support, n_query)

In [168]:
support_inputs, support_labels, query_inputs, query_labels = generate_data(df_2, x_spt_task, y_spt_task, x_qry_task, y_qry_task, n_support, n_query)

In [169]:
print(f'support input: {len(support_inputs)}') # (n_task, n_support, widow_size, feature_dim)
print(f'support label: {len(support_labels)}') # (n_task, n_support)
print(f'query input: {len(query_inputs)}')
print(f'query label: {len(query_labels)}')

support input: 12
support label: 12
query input: 12
query label: 12


In [158]:
support_inputs[0].shape

(5, 3, 11)

In [159]:
support_labels[0].shape

(5,)

In [160]:
support_labels[0]

array([ 1.26134388,  3.84167883,  0.34753364, -0.49156072, -0.62871896])

In [128]:
all_tasks = dict(
            query = [],
            query_labels = [],
            support = [],
            support_labels = [],
        )

In [193]:
def generate_all_task(symbol):
    all_tasks = dict(
            query = [],
            query_labels = [],
            support = [],
            support_labels = [],
        )
    for s in symbol:
        df = load_single_stock(s)
        x_spt_task, y_spt_task, x_qry_task, y_qry_task = sliding_window_idx(df, window_size, n_support, n_query)
        support_inputs, support_labels, query_inputs, query_labels = generate_data(df, x_spt_task, y_spt_task, x_qry_task, y_qry_task, n_support, n_query)
        all_tasks['query'].extend(query_inputs)
        all_tasks['query_labels'].extend(query_labels)
        all_tasks['support'].extend(support_inputs)
        all_tasks['support_labels'].extend(support_labels)
        num_task = len(all_tasks['query'])
    return num_task, all_tasks

In [194]:
def generate_batch_task(batch_size, all_tasks, num_task):
    batch_tasks = dict(
            query = [],
            query_labels = [],
            support = [],
            support_labels = [],
        )
    for k, v in all_tasks.items():
            all_tasks[k] = np.array(v)
            
    batch_idx = random.sample(list(range(num_task)), batch_size)
    batch_tasks['query'] = all_tasks['query'][batch_idx]
    batch_tasks['query_labels'] = all_tasks['query_labels'][batch_idx]
    batch_tasks['support'] = all_tasks['support'][batch_idx]
    batch_tasks['support_labels'] = all_tasks['support_labels'][batch_idx]
    
    return batch_tasks

In [196]:
num_task, all_tasks = generate_all_task(symbol)

In [206]:
batch_tasks = generate_batch_task(batch_size=2, all_tasks=all_tasks, num_task=num_task)

In [213]:
class StockRegressionDataDict(dict):
    def __init__(self, data, window_size):
        self.window_size = window_size
        self._set_state(f'numpy')
        for k, v in data.items():
            data[k] = np.array(v)
        
        self.n_stocks = len(v)
        super().__init__(data)
    
    def tensor_fn(self, value, key):
        return torch.FloatTensor(value)

    def _set_state(self, state: str):
        self.state = state

    def to(self, device: None | str=None):
        if device is None:
            device = torch.device('cpu')
        else:
            device = torch.device(device)
        self._set_state(f'tensor.{device}')
        for key in self.keys():
            value = self.__getitem__(key)
            tvalue = self.tensor_fn(value, key)
            self.__setitem__(key, tvalue.to(device)) 
        
    def numpy(self):
        self._set_state('numpy')
        for key in self.keys():
            tvalue = self.__getitem__(key)
            if not isinstance(tvalue, np.ndarray): 
                self.__setitem__(key, tvalue.detach().numpy())

    def __str__(self):
        s = f'StockDataDict(T={self.window_size}, {self.state})\n'
        for i, key in enumerate(self.keys()):
            value = self.__getitem__(key)
            s += f'- {key}: {value.shape}'
            s += '' if i == len(self.keys())-1 else '\n'
        return s

In [217]:
data = StockRegressionDataDict(batch_tasks, window_size = 3)

In [218]:
data.to('cuda')

In [219]:
data

{'query': tensor([[[[-0.5876,  1.1242, -2.7082,  0.7983,  0.7983,  2.5192,  3.2013,
             2.0082,  0.6745, -0.4854, -1.2936],
           [ 1.2098,  1.2098, -2.0592, -0.7409, -0.7409,  2.1776,  3.6396,
             2.7473,  1.5779,  0.4211, -0.4736],
           [ 1.6715,  3.1308,  0.0000, -2.9858, -2.9858,  3.6827,  6.2006,
             5.6602,  4.5702,  3.5765,  2.6523]]],
 
 
         [[[-3.4804,  1.4358, -3.8020,  2.8956,  2.8956, -0.4870, -0.3182,
            -1.0522, -1.4536, -1.4592, -1.1601],
           [ 1.5923,  2.4939, -0.2342, -1.8952, -1.8953,  0.5784,  1.6310,
             0.7993,  0.4894,  0.4051,  0.6065],
           [-0.4981,  2.6993, -0.6487,  1.0655,  1.0655, -0.9893,  0.7322,
            -0.0278, -0.4790, -0.6098, -0.5383]]]], device='cuda:0'),
 'query_labels': tensor([[-1.6981],
         [ 2.1664]], device='cuda:0'),
 'support': tensor([[[[-0.1697,  0.1454, -0.8240, -0.6023, -0.6023, -1.4542, -3.4755,
            -5.0816, -6.3112, -7.0887, -7.2895],
          

In [220]:
print(data)

StockDataDict(T=3, tensor.cuda)
- query: torch.Size([2, 1, 3, 11])
- query_labels: torch.Size([2, 1])
- support: torch.Size([2, 5, 3, 11])
- support_labels: torch.Size([2, 5])


In [207]:
batch_tasks['query'].shape

(2, 1, 3, 11)

In [209]:
batch_tasks['query_labels'].shape

(2, 1)

In [211]:
type(batch_tasks['query'])

numpy.ndarray

In [212]:
type(batch_tasks['query_labels'])

numpy.ndarray

In [208]:
len(all_tasks['query'])

24

In [187]:
all_tasks['support'][[1,2,3]].shape

(3, 5, 3, 11)

In [180]:
print(type(all_tasks['support']))

<class 'numpy.ndarray'>


In [181]:
def tensor_fn(value, key):
    return torch.FloatTensor(value)
    

In [184]:
torch.FloatTensor(all_tasks['support']).shape

torch.Size([24, 5, 3, 11])

In [203]:
device = torch.device('cuda')
for key,value in batch_tasks.items():
            tvalue = tensor_fn(value, key)
            batch_tasks[key] = tvalue.to(device)

In [205]:
batch_tasks

{'query': tensor([[[[ 1.2098,  1.2098, -2.0592, -0.7409, -0.7409,  2.1776,  3.6396,
             2.7473,  1.5779,  0.4211, -0.4736],
           [ 1.6715,  3.1308,  0.0000, -2.9858, -2.9858,  3.6827,  6.2006,
             5.6602,  4.5702,  3.5765,  2.6523],
           [ 0.2699,  3.4278, -0.0270, -1.6981, -1.6981,  3.4062,  7.1498,
             7.1849,  6.3306,  5.3873,  4.4283]]],
 
 
         [[[ 0.9756,  2.4390, -0.2683, -0.6301, -0.6301, -0.4098, -2.3707,
            -3.9382, -5.2610, -6.1951, -6.4398],
           [ 0.5395,  1.0299, -0.0981, -0.5366, -0.5366,  0.4806, -1.4321,
            -3.0799, -4.2962, -5.3163, -5.6384],
           [-0.0489,  0.7828, -1.1742,  0.2452,  0.2452,  0.5039, -1.1448,
            -2.7479, -4.0374, -5.1703, -5.5855]]]], device='cuda:0'),
 'query_labels': tensor([[ 4.1296],
         [-5.0147]], device='cuda:0'),
 'support': tensor([[[[ 0.9756,  2.4390, -0.2683, -0.6301, -0.6301, -0.4098, -2.3707,
            -3.9382, -5.2610, -6.1951, -6.4398],
          

In [None]:
def to(device: None | str=None):
        if device is None:
            device = torch.device('cpu')
        else:
            device = torch.device(device)

        for key in all_tasks.keys():
            tvalue = tensor_fn(value, key)
            self.__setitem__(key, tvalue.to(device)) 

In [172]:
all_tasks['support'][0].shape

(5, 3, 11)

In [170]:
all_tasks['query'][0].shape

(1, 3, 11)

In [36]:
print(f'x_spt_task:{x_spt_task[:2]}')
print(f'y_spt_task: {y_spt_task[:2]}')

x_spt_task:[[[0, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, 5], [4, 5, 6]], [[1, 2, 3], [2, 3, 4], [3, 4, 5], [4, 5, 6], [5, 6, 7]]]
y_spt_task: [[3, 4, 5, 6, 7], [4, 5, 6, 7, 8]]


In [41]:
x_spt_task[0][0]

[0, 1, 2]

In [48]:
support_inputs = []

In [49]:
support_inputs.append(df_single.iloc[x_spt_task[0][0]].to_numpy()[:, 1:])

In [52]:
np.array(support_inputs).shape

(1, 3, 11)

In [67]:
support_labels = []

In [68]:
support_labels.append(df_single['close'].iloc[y_spt_task[0]].to_numpy())

In [71]:
np.array(support_labels)

array([[ 1.26134388,  3.84167883,  0.34753364, -0.49156072, -0.62871896]])

In [57]:
np.array(support_labels).shape

(1,)

In [37]:
print(f'x_qry_task:{x_qry_task[:2]}')
print(f'y_qry_task: {y_qry_task[:2]}')

x_qry_task:[[[5, 6, 7]], [[6, 7, 8]]]
y_qry_task: [[8], [9]]


In [33]:
i,window_size = 0,5

In [36]:
spt_idx = [idx for idx in range(i, i+window_size)]
spt_label_idx =  [spt_idx[-1] + 1]

In [37]:
spt_idx, spt_label_idx

([0, 1, 2, 3, 4], [5])

In [2]:
len_df = 20
window_size = 2
n_support = 3
n_query = 2

In [7]:
x_spt_task = []
y_spt_task = []
x_qry_task = []
y_qry_task = []

for i in range(len_df-window_size-n_support-n_query+1):
    x_spt = []
    y_spt = []
    x_qry = []
    y_qry = []
    
    for j in range(n_support+n_query):
        if j < n_support:
            spt_idx = [idx for idx in range(i+j, i+j+window_size)]
            x_spt.append(spt_idx)
            y_spt.append(i+j+window_size)
           
        else:
            qry_idx = [idx for idx in range(i+j, i+j+window_size)]
            x_qry.append(qry_idx)
            y_qry.append(i+j+window_size)

    x_spt_task.append(x_spt)
    y_spt_task.append(y_spt)
    x_qry_task.append(x_qry)
    y_qry_task.append(y_qry)

In [8]:
x_spt_task

[[[0, 1], [1, 2], [2, 3]],
 [[1, 2], [2, 3], [3, 4]],
 [[2, 3], [3, 4], [4, 5]],
 [[3, 4], [4, 5], [5, 6]],
 [[4, 5], [5, 6], [6, 7]],
 [[5, 6], [6, 7], [7, 8]],
 [[6, 7], [7, 8], [8, 9]],
 [[7, 8], [8, 9], [9, 10]],
 [[8, 9], [9, 10], [10, 11]],
 [[9, 10], [10, 11], [11, 12]],
 [[10, 11], [11, 12], [12, 13]],
 [[11, 12], [12, 13], [13, 14]],
 [[12, 13], [13, 14], [14, 15]],
 [[13, 14], [14, 15], [15, 16]]]

In [9]:
y_spt_task

[[2, 3, 4],
 [3, 4, 5],
 [4, 5, 6],
 [5, 6, 7],
 [6, 7, 8],
 [7, 8, 9],
 [8, 9, 10],
 [9, 10, 11],
 [10, 11, 12],
 [11, 12, 13],
 [12, 13, 14],
 [13, 14, 15],
 [14, 15, 16],
 [15, 16, 17]]

In [10]:
x_qry_task

[[[3, 4], [4, 5]],
 [[4, 5], [5, 6]],
 [[5, 6], [6, 7]],
 [[6, 7], [7, 8]],
 [[7, 8], [8, 9]],
 [[8, 9], [9, 10]],
 [[9, 10], [10, 11]],
 [[10, 11], [11, 12]],
 [[11, 12], [12, 13]],
 [[12, 13], [13, 14]],
 [[13, 14], [14, 15]],
 [[14, 15], [15, 16]],
 [[15, 16], [16, 17]],
 [[16, 17], [17, 18]]]