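# Fish experiment: whole dataset

Trains a Crossformer on the full fish02 paired-cluster recording (`fish_02_pairs.csv`), using every row for training; no validation or test split is made.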


## Hyperparameters

In [1]:


# torch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch import optim
# Libraries for calculation and processing
from einops import rearrange, repeat
import math
from math import sqrt
from math import ceil
import numpy as np
from sklearn.preprocessing import StandardScaler
# Libraries for data importing, formatting and handling
import pandas as pd
# For analysis and plotting
import matplotlib.pyplot as plt
import seaborn as sns
# Others
import os
import time
import json
import pickle
import warnings
# NOTE(review): this silences *all* warnings for the rest of the kernel session,
# which can hide deprecation or numerical warnings from torch/pandas -- consider
# narrowing it to specific categories/modules.
warnings.filterwarnings('ignore')


# Project-local modules.
# NOTE(review): the star import pulls in every public name from model.Crossformer
# (at least `Crossformer`, which the model-init cell uses) and could shadow names
# imported above; an explicit `from model.Crossformer import Crossformer` is clearer.
from model.Crossformer import *
from data.Dataset import Dataset_MTS,Dataset_MTS_simplified
from exp.ExpFish import Expfish

In [3]:

# ---- Dataset ----
ROOT_PATH = "data/DatasetClusters/fishes/fish02/"
DATA_PATH = "fish_02_pairs.csv"

TRAIN_FLAG = "train"
VAL_FLAG = "val"
TEST_FLAG = "test"
SIZE = [100, 25, 25]  # [seq_len, label_len, pred_len]
SCALE = True
SCALE_STATISTIC = True
# Train / val / test fractions. Only meaningful for Dataset_MTS; it is not
# passed to Dataset_MTS_simplified, which is what this notebook trains on.
DATA_SPLIT = [1, 0, 0.0]
STRIDE = 1

# ---- DataLoader ----
BATCH_SIZE = 30
SHUFFLE_FLAG = False
NUM_WORKERS = 0
NUM_WORKSES = NUM_WORKERS  # legacy misspelled name, kept so existing cells keep working
DROP_LAST = False

# ---- Model ----
# Number of clusters (input channels).
# NOTE(review): must equal the number of data columns the dataset yields from the CSV -- confirm.
DATA_DIM = 768
IN_LEN   = SIZE[0]  # seq_len
OUT_LEN  = SIZE[2]  # pred_len
SEG_LEN  = 20
WIN_SIZE = 1
FACTOR   = 2
D_MODEL  = 256
D_FF     = 512
N_HEADS  = 1
E_LAYERS = 1
DROPOUT  = 0.2
BASELINE = False

# ---- Device ----
DEVICE   = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- Optimizer ----
LR = 0.001

# ---- Training ----
NUM_EPOCHS = 1

# ---- Reproducibility ----
# Weight init and dropout are random; seed them so Restart & Run All reproduces results.
SEED = 42
torch.manual_seed(SEED)  # seeds the CPU and all CUDA devices
np.random.seed(SEED)



## Setting up pairs


In [3]:
# Load the paired-cluster traces. os.path.join works whether or not ROOT_PATH
# ends with a path separator (plain string concatenation only works if it does).
# The CSV repeats some column headers; pandas renames the repeats with ".1", ".2", ...
# suffixes (visible in the output below).
# NOTE(review): the first column, "Unnamed: 0", looks like a saved pandas index --
# confirm Dataset_MTS_simplified skips it; otherwise pass index_col=0 here.
df_data = pd.read_csv(os.path.join(ROOT_PATH, DATA_PATH))
df_data

Unnamed: 0,cluster_1_red_s1_paired,cluster_1_green_s1_paired,cluster_1_red_s1_paired.1,cluster_2_green_s1_paired,cluster_1_red_s1_paired.2,cluster_3_green_s1_paired,cluster_1_red_s1_paired.3,cluster_4_green_s1_paired,cluster_1_red_s1_paired.4,cluster_5_green_s1_paired,...,cluster_8_red_s2_paired.19,cluster_8_green_s2_paired.15,cluster_8_red_s2_paired.20,cluster_9_green_s2_paired.15,cluster_8_red_s2_paired.21,cluster_10_green_s2_paired.15,cluster_8_red_s2_paired.22,cluster_11_green_s2_paired.15,cluster_8_red_s2_paired.23,cluster_12_green_s2_paired.15
0,-0.210637,-0.850498,-0.210637,0.731234,-0.210637,0.005784,-0.210637,-0.331898,-0.210637,-1.087744,...,-1.099984,-0.453258,-1.099984,-0.319978,-1.099984,0.195136,-1.099984,0.592922,-1.099984,-1.239751
1,0.005327,-1.278706,0.005327,0.912576,0.005327,0.600644,0.005327,0.984535,0.005327,1.181213,...,0.290193,0.291925,0.290193,-0.154953,0.290193,0.357503,0.290193,0.708883,0.290193,-0.056505
2,0.326035,-0.607865,0.326035,0.378180,0.326035,0.520806,0.326035,0.974911,0.326035,1.374006,...,-1.647060,0.261231,-1.647060,-1.528252,-1.647060,0.367067,-1.647060,0.448457,-1.647060,0.029960
3,0.875854,-1.416154,0.875854,1.774807,0.875854,0.552595,0.875854,0.685495,0.875854,0.385875,...,-0.860632,0.064001,-0.860632,1.328282,-0.860632,0.765776,-0.860632,0.540496,-0.860632,-1.256331
4,0.065255,-1.120760,0.065255,1.436819,0.065255,0.699797,0.065255,0.906697,0.065255,0.018051,...,2.617447,-0.563792,2.617447,0.245650,2.617447,0.213745,2.617447,0.009599,2.617447,0.023406
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1026,-0.851769,0.260886,-0.851769,0.868904,-0.851769,-0.902760,-0.851769,-1.429915,-0.851769,-1.734120,...,-0.174966,-0.410315,-0.174966,-1.014679,-0.174966,-0.227071,-0.174966,-0.175836,-0.174966,1.269689
1027,-0.778003,0.902261,-0.778003,0.420471,-0.778003,-0.769632,-0.778003,-1.169930,-0.778003,-1.116090,...,0.118153,-0.557613,0.118153,-1.405983,0.118153,-0.443731,0.118153,-0.067522,0.118153,0.707905
1028,-1.413795,-0.966663,-1.413795,0.644278,-1.413795,-1.421847,-1.413795,-1.840589,-1.413795,-1.790913,...,0.107467,-1.487121,0.107467,-1.324683,0.107467,-0.901849,0.107467,-0.726689,0.107467,0.036255
1029,-1.187163,0.408066,-1.187163,0.967696,-1.187163,-0.585294,-1.187163,-0.573442,-1.187163,-0.639640,...,-0.226328,-1.020237,-0.226328,-1.398377,-0.226328,-0.173399,-0.226328,-0.187637,-0.226328,0.551938


## DataLoader

In [4]:

# Windowed dataset over the whole loaded DataFrame (no train/val/test split).
# Alternative: Dataset_MTS(root_path=ROOT_PATH, data_path=DATA_PATH, flag=TRAIN_FLAG, ...,
# data_split=DATA_SPLIT) reads the CSV itself and supports scaling and splits.
train_set = Dataset_MTS_simplified(df_data=df_data, size=SIZE, stride=STRIDE)

# Fail fast: with an empty dataset the training loop would silently do nothing
# (presumably this happens when the window is longer than the recording).
assert len(train_set) > 0, f"empty dataset: {len(df_data)} rows with SIZE={SIZE}"

data_loader_train = DataLoader(
    train_set,
    batch_size=BATCH_SIZE,
    shuffle=SHUFFLE_FLAG,
    num_workers=NUM_WORKSES,
    drop_last=DROP_LAST,
)


## Model Init


In [5]:
# Build the Crossformer entirely from the config-cell hyperparameters.
model = Crossformer(
    data_dim=DATA_DIM,
    in_len=IN_LEN,
    out_len=OUT_LEN,
    seg_len=SEG_LEN,
    win_size=WIN_SIZE,
    factor=FACTOR,
    d_model=D_MODEL,
    d_ff=D_FF,          # previously D_MODEL, which silently ignored the D_FF setting
    n_heads=N_HEADS,
    e_layers=E_LAYERS,
    dropout=DROPOUT,
    baseline=BASELINE,  # previously hard-coded False, ignoring the BASELINE setting
    device=DEVICE,
).float()

## Loss-function and Optimizer

In [6]:
# Mean-squared-error objective, optimised with Adam over all model parameters.
criterion = nn.MSELoss(reduction="mean")
optimizer = optim.Adam(params=model.parameters(), lr=LR)

## Training

In [7]:
# Wire model, data and optimisation objects into the experiment runner.
# Only a training loader is built; validation and test loaders are None.
expfish_kwargs = dict(
    model=model,
    data_loader_train=data_loader_train,
    data_loader_test=None,
    data_loader_val=None,
    optimizer=optimizer,
    criterion=criterion,
    num_epochs=NUM_EPOCHS,
    device=DEVICE,
)
exp = Expfish(**expfish_kwargs)

# train() returns a model (presumably the trained one), which replaces `model`.
model = exp.train()


## Weight extraction

# Fish experiment validation test