**With VS Code Jupyter you don't need to add the kernel to Jupyter Notebook like before to have it as an environment**

In [1]:
import polars as pl
import json, gc
from tqdm.auto import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
with open("./params.json", mode = "r", encoding = "utf-8") as f:
    data = json.load(f)
    model_path = data["model_path"]
    num_single_sample_timesteps = data["num_single_sample_timesteps"]
    window_stride = data["window_stride"]
    input_window_length = data["input_window_length"]
    label_window_length = data["label_window_length"]
    input_features = data["input_features"]
    label_features = data["label_features"]
    positional_encoding_max_len = data["positional_encoding_max_len"]
    embedding_dim = data["embedding_dim"]
    num_attention_head = data["num_attention_head"]
    num_encoder_layers = data["num_encoder_layers"]
    num_decoder_layers = data["num_decoder_layers"]
    position_wise_nn_dim = data["position_wise_nn_dim"]
    dropout = data["dropout"]
    batch_size = data["batch_size"]
    epochs = data["epochs"]
    learning_rate = data["learning_rate"]

In [3]:
df = pl.read_csv("./reversalData_minor.csv")
df

id,eps,n_0_squared,psi_e,b_e,psi_plus,b_plus,u_list,r_list,k_e_psi_e_list,k_e_b_e_list,k_e_psi_plus_list,k_e_b_plus_list,heat_flux_psi_e_b_e_list,heat_flux_psi_e_b_plus_list,b_e_psi_plus_list,b_e_b_plus_list,psi_plus_b_plus_list,eta_list
i64,f64,f64,str,str,str,str,str,str,str,str,str,str,str,str,str,str,str,str
1500,0.123943,318.864022,"""[-0.011744609035812982, -0.011…","""[-5.517162102724528, -3.192070…","""[-0.0016475367466768115, -0.00…","""[10.73076649947905, 10.7758112…","""[0.4615413218450946, 0.4686702…","""[0.6551578220909112, 0.8945764…","""[0.00013793584140409996, 0.000…","""[30.43907766773974, 10.1893158…","""[2.7143773316504123e-06, 5.549…","""[115.14934966634185, 116.11810…","""[0.06479691188370344, 0.035800…","""[-0.1260286571909809, -0.12085…","""[0.009089727301611367, 0.00751…","""[-59.203378264111755, -34.3971…","""[-0.01767933212790023, -0.0253…","""[[1.0369206541904592], [-1.681…"
1501,0.123943,318.864022,"""[-0.008147157474100065, -0.008…","""[-0.832904729882089, 0.1407146…","""[0.0006807256129005129, 0.0007…","""[-0.7190970906032216, -0.60951…","""[0.23703131704960073, 0.234794…","""[-0.18778048603775238, -0.2053…","""[6.637617490778455e-05, 6.8758…","""[0.6937302890599556, 0.0198006…","""[4.6338736005877887e-07, 5.349…","""[0.5171006257140178, 0.3715125…","""[0.006785805995272157, -0.0011…","""[0.005858597236311648, 0.00505…","""[-0.0005669795827367211, 0.000…","""[0.5989393680078723, -0.085768…","""[-0.0004895078077358536, -0.00…","""[[1.418213787060706], [0.88278…"
1502,0.123943,318.864022,"""[-0.003781434990007615, -0.001…","""[7.429632504232679, 7.61617440…","""[0.0017201097552821953, 0.0016…","""[0.26192796713213556, -0.01619…","""[0.3398928494975913, 0.3384786…","""[-0.22023434840890527, -0.0752…","""[1.4299250583653892e-05, 1.712…","""[55.199439147950756, 58.006112…","""[2.958777570216974e-06, 2.8808…","""[0.06860625996597308, 0.000262…","""[-0.028094672314403355, -0.009…","""[-0.000990463579775022, 2.1190…","""[0.012779783348692317, 0.01292…","""[1.9460285383725033, -0.123315…","""[0.00045054485144522063, -2.74…","""[[-0.6679438799496836], [-0.80…"
1503,0.123943,318.864022,"""[-0.008102021680746106, -0.007…","""[-13.405475487140006, -12.1429…","""[-0.007065355563569785, -0.007…","""[3.432058579711073, 2.92797364…","""[0.15021463519080008, 0.170829…","""[1.9382048681079675, 1.8501503…","""[6.564275531527996e-05, 5.7415…","""[179.70677303631157, 147.45183…","""[4.991924923966652e-05, 5.2004…","""[11.779026094568387, 8.5730296…","""[0.10861145303751879, 0.092010…","""[-0.0278066130224098, -0.02218…","""[0.09471445081536302, 0.087568…","""[-46.00837716074533, -35.55430…","""[-0.024248714180659044, -0.021…","""[[0.20668630436664126], [1.034…"
1504,0.123943,318.864022,"""[-0.0007052290130828054, -0.00…","""[7.188510348955567, 7.25531527…","""[0.01155266108186768, 0.011468…","""[-0.9665805895951766, 0.236189…","""[0.1388431307750241, 0.1327890…","""[-0.2758572865676108, -1.34969…","""[4.973479608937477e-07, 1.2082…","""[51.67468103704128, 52.6395997…","""[0.0001334639780725001, 0.0001…","""[0.9342780361821591, 0.0557856…","""[-0.0050695460589294675, -0.02…","""[0.0006816606752652025, -0.000…","""[0.08304642374498203, 0.083205…","""[-6.9482745714045, 1.713632080…","""[-0.011166577959904913, 0.0027…","""[[-0.2944195173353409], [0.846…"
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
2995,0.123943,318.864022,"""[0.0027304602552334, -0.002678…","""[-15.816264644897496, -15.5200…","""[0.0012142394002116878, 0.0009…","""[2.764721975700982, 4.05941343…","""[0.4027182522813478, 0.4025100…","""[0.11225674279144691, -0.08758…","""[7.455413205409244e-06, 7.1762…","""[250.1542273174345, 240.873016…","""[1.4743773210264393e-06, 9.324…","""[7.643687602923941, 16.4788374…","""[-0.04318568199914582, 0.04157…","""[0.007548963471421893, -0.0108…","""[-0.01920473169600966, -0.0149…","""[-43.727574437250595, -63.0024…","""[0.0033570343535272327, 0.0039…","""[[-1.3447105080916295], [-0.78…"
2996,0.123943,318.864022,"""[0.0038198007840869012, 0.0010…","""[-6.111776214911955, -6.408017…","""[0.002738456236028401, 0.00267…","""[0.1260447413606365, 0.8997889…","""[0.3815429695058497, 0.3839093…","""[0.35417571126001496, 0.096388…","""[1.4590878030110906e-05, 1.135…","""[37.3538085011635, 41.06268254…","""[7.4991425566428376e-06, 7.135…","""[0.01588727682466975, 0.809620…","""[-0.02334576757788436, -0.0068…","""[0.00048146580187939, 0.000958…","""[-0.0167368316889357, -0.01711…","""[-0.7703572522626673, -5.76586…","""[0.00034516800799762196, 0.002…","""[[-0.08993232103131506], [-1.1…"
2997,0.123943,318.864022,"""[-0.007867764494778443, -0.010…","""[4.615399916306391, 5.52828332…","""[-0.0027620958841561424, -0.00…","""[-3.0540710897004004, -3.64880…","""[0.32570331855224116, 0.334151…","""[0.7358043633897454, 0.8418703…","""[6.19017181452963e-05, 0.00010…","""[21.30191638744104, 30.5619165…","""[7.629173673272302e-06, 6.1743…","""[9.32735022094379, 13.31377581…","""[-0.03631287959071882, -0.0553…","""[0.02402871208407412, 0.036511…","""[-0.012748177112564487, -0.013…","""[-14.095759451796996, -20.1716…","""[0.00843563718678174, 0.009066…","""[[-1.748145068158444], [0.0966…"
2998,0.123943,318.864022,"""[0.004137980019488926, 0.00952…","""[20.850303575856653, 19.417738…","""[-0.006889452280651537, -0.006…","""[-2.251806224644361, -5.487241…","""[0.6490342140491646, 0.6324470…","""[-0.9652622949842546, -2.12360…","""[1.7122878641689577e-05, 9.077…","""[434.73515920538074, 377.04856…","""[4.746455272737466e-05, 4.3332…","""[5.07063127334709, 30.10982431…","""[0.08627813959717334, 0.185009…","""[-0.00931792916533916, -0.0522…","""[-0.1436471715229625, -0.12782…","""[-46.950843377838595, -106.549…","""[0.01551371152996142, 0.036121…","""[[-1.702330860987556], [-0.836…"


In [4]:
df.shape

(1500, 19)

In [5]:
df = df.drop(["id", "eps", "n_0_squared"])
df.head()

psi_e,b_e,psi_plus,b_plus,u_list,r_list,k_e_psi_e_list,k_e_b_e_list,k_e_psi_plus_list,k_e_b_plus_list,heat_flux_psi_e_b_e_list,heat_flux_psi_e_b_plus_list,b_e_psi_plus_list,b_e_b_plus_list,psi_plus_b_plus_list,eta_list
str,str,str,str,str,str,str,str,str,str,str,str,str,str,str,str
"""[-0.011744609035812982, -0.011…","""[-5.517162102724528, -3.192070…","""[-0.0016475367466768115, -0.00…","""[10.73076649947905, 10.7758112…","""[0.4615413218450946, 0.4686702…","""[0.6551578220909112, 0.8945764…","""[0.00013793584140409996, 0.000…","""[30.43907766773974, 10.1893158…","""[2.7143773316504123e-06, 5.549…","""[115.14934966634185, 116.11810…","""[0.06479691188370344, 0.035800…","""[-0.1260286571909809, -0.12085…","""[0.009089727301611367, 0.00751…","""[-59.203378264111755, -34.3971…","""[-0.01767933212790023, -0.0253…","""[[1.0369206541904592], [-1.681…"
"""[-0.008147157474100065, -0.008…","""[-0.832904729882089, 0.1407146…","""[0.0006807256129005129, 0.0007…","""[-0.7190970906032216, -0.60951…","""[0.23703131704960073, 0.234794…","""[-0.18778048603775238, -0.2053…","""[6.637617490778455e-05, 6.8758…","""[0.6937302890599556, 0.0198006…","""[4.6338736005877887e-07, 5.349…","""[0.5171006257140178, 0.3715125…","""[0.006785805995272157, -0.0011…","""[0.005858597236311648, 0.00505…","""[-0.0005669795827367211, 0.000…","""[0.5989393680078723, -0.085768…","""[-0.0004895078077358536, -0.00…","""[[1.418213787060706], [0.88278…"
"""[-0.003781434990007615, -0.001…","""[7.429632504232679, 7.61617440…","""[0.0017201097552821953, 0.0016…","""[0.26192796713213556, -0.01619…","""[0.3398928494975913, 0.3384786…","""[-0.22023434840890527, -0.0752…","""[1.4299250583653892e-05, 1.712…","""[55.199439147950756, 58.006112…","""[2.958777570216974e-06, 2.8808…","""[0.06860625996597308, 0.000262…","""[-0.028094672314403355, -0.009…","""[-0.000990463579775022, 2.1190…","""[0.012779783348692317, 0.01292…","""[1.9460285383725033, -0.123315…","""[0.00045054485144522063, -2.74…","""[[-0.6679438799496836], [-0.80…"
"""[-0.008102021680746106, -0.007…","""[-13.405475487140006, -12.1429…","""[-0.007065355563569785, -0.007…","""[3.432058579711073, 2.92797364…","""[0.15021463519080008, 0.170829…","""[1.9382048681079675, 1.8501503…","""[6.564275531527996e-05, 5.7415…","""[179.70677303631157, 147.45183…","""[4.991924923966652e-05, 5.2004…","""[11.779026094568387, 8.5730296…","""[0.10861145303751879, 0.092010…","""[-0.0278066130224098, -0.02218…","""[0.09471445081536302, 0.087568…","""[-46.00837716074533, -35.55430…","""[-0.024248714180659044, -0.021…","""[[0.20668630436664126], [1.034…"
"""[-0.0007052290130828054, -0.00…","""[7.188510348955567, 7.25531527…","""[0.01155266108186768, 0.011468…","""[-0.9665805895951766, 0.236189…","""[0.1388431307750241, 0.1327890…","""[-0.2758572865676108, -1.34969…","""[4.973479608937477e-07, 1.2082…","""[51.67468103704128, 52.6395997…","""[0.0001334639780725001, 0.0001…","""[0.9342780361821591, 0.0557856…","""[-0.0050695460589294675, -0.02…","""[0.0006816606752652025, -0.000…","""[0.08304642374498203, 0.083205…","""[-6.9482745714045, 1.713632080…","""[-0.011166577959904913, 0.0027…","""[[-0.2944195173353409], [0.846…"


In [6]:
df = df.select(
    pl.col("*").str.json_decode()
)
df.head()

psi_e,b_e,psi_plus,b_plus,u_list,r_list,k_e_psi_e_list,k_e_b_e_list,k_e_psi_plus_list,k_e_b_plus_list,heat_flux_psi_e_b_e_list,heat_flux_psi_e_b_plus_list,b_e_psi_plus_list,b_e_b_plus_list,psi_plus_b_plus_list,eta_list
list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[list[f64]]
"[-0.011745, -0.011215, … -0.011728]","[-5.517162, -3.192071, … 6.209561]","[-0.001648, -0.002356, … 0.000093]","[10.730766, 10.775811, … 2.362047]","[0.461541, 0.46867, … -0.278246]","[0.655158, 0.894576, … -0.036835]","[0.000138, 0.000126, … 0.000138]","[30.439078, 10.189316, … 38.558648]","[0.000003, 0.000006, … 8.6049e-9]","[115.14935, 116.118109, … 5.579266]","[0.064797, 0.0358, … -0.072824]","[-0.126029, -0.120854, … -0.027701]","[0.00909, 0.00752, … 0.000576]","[-59.203378, -34.397152, … 14.667275]","[-0.017679, -0.025385, … 0.000219]","[[1.036921], [-1.681305], … [-0.782573]]"
"[-0.008147, -0.008292, … -0.004894]","[-0.832905, 0.140715, … -8.502144]","[0.000681, 0.000731, … 0.006186]","[-0.719097, -0.609518, … -5.884718]","[0.237031, 0.234794, … -0.347979]","[-0.18778, -0.20534, … -1.025089]","[0.000066, 0.000069, … 0.000024]","[0.69373, 0.019801, … 72.28645]","[4.6339e-7, 5.3490e-7, … 0.000038]","[0.517101, 0.371513, … 34.629909]","[0.006786, -0.001167, … 0.041612]","[0.005859, 0.005054, … 0.028802]","[-0.000567, 0.000103, … -0.052593]","[0.598939, -0.085768, … 50.032721]","[-0.00049, -0.000446, … -0.036402]","[[1.418214], [0.882783], … [-0.412105]]"
"[-0.003781, -0.001309, … 0.004499]","[7.429633, 7.616174, … 0.77588]","[0.00172, 0.001697, … 0.001258]","[0.261928, -0.016191, … -5.434772]","[0.339893, 0.338479, … -0.232213]","[-0.220234, -0.075214, … 0.191679]","[0.000014, 0.000002, … 0.00002]","[55.199439, 58.006113, … 0.60199]","[0.000003, 0.000003, … 0.000002]","[0.068606, 0.000262, … 29.53675]","[-0.028095, -0.009968, … 0.003491]","[-0.00099, 0.000021, … -0.024451]","[0.01278, 0.012927, … 0.000976]","[1.946029, -0.123316, … -4.216733]","[0.000451, -0.000027, … -0.006839]","[[-0.667944], [-0.804122], … [0.820629]]"
"[-0.008102, -0.007577, … 0.016667]","[-13.405475, -12.142975, … 11.825626]","[-0.007065, -0.007211, … -0.00118]","[3.432059, 2.927974, … 3.577291]","[0.150215, 0.17083, … -0.121867]","[1.938205, 1.85015, … -0.665716]","[0.000066, 0.000057, … 0.000278]","[179.706773, 147.451835, … 139.845423]","[0.00005, 0.000052, … 0.000001]","[11.779026, 8.57303, … 12.797011]","[0.108611, 0.092011, … 0.1971]","[-0.027807, -0.022186, … 0.059623]","[0.094714, 0.087568, … -0.01395]","[-46.008377, -35.55431, … 42.303705]","[-0.024249, -0.021115, … -0.00422]","[[0.206686], [1.034311], … [-2.368951]]"
"[-0.000705, -0.003476, … 0.004359]","[7.18851, 7.255315, … -1.118738]","[0.011553, 0.011468, … 0.001921]","[-0.966581, 0.23619, … 0.077109]","[0.138843, 0.132789, … -0.237516]","[-0.275857, -1.349696, … 0.283456]","[4.9735e-7, 0.000012, … 0.000019]","[51.674681, 52.6396, … 1.251575]","[0.000133, 0.000132, … 0.000004]","[0.934278, 0.055786, … 0.005946]","[-0.00507, -0.025219, … -0.004877]","[0.000682, -0.000821, … 0.000336]","[0.083046, 0.083205, … -0.002149]","[-6.948275, 1.713632, … -0.086265]","[-0.011167, 0.002709, … 0.000148]","[[-0.29442], [0.846611], … [0.437807]]"


In [7]:
df = df.with_columns(
    eta_list = pl.col("eta_list").list.eval(pl.element().flatten(), parallel = True)
)
df.head()

psi_e,b_e,psi_plus,b_plus,u_list,r_list,k_e_psi_e_list,k_e_b_e_list,k_e_psi_plus_list,k_e_b_plus_list,heat_flux_psi_e_b_e_list,heat_flux_psi_e_b_plus_list,b_e_psi_plus_list,b_e_b_plus_list,psi_plus_b_plus_list,eta_list
list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64]
"[-0.011745, -0.011215, … -0.011728]","[-5.517162, -3.192071, … 6.209561]","[-0.001648, -0.002356, … 0.000093]","[10.730766, 10.775811, … 2.362047]","[0.461541, 0.46867, … -0.278246]","[0.655158, 0.894576, … -0.036835]","[0.000138, 0.000126, … 0.000138]","[30.439078, 10.189316, … 38.558648]","[0.000003, 0.000006, … 8.6049e-9]","[115.14935, 116.118109, … 5.579266]","[0.064797, 0.0358, … -0.072824]","[-0.126029, -0.120854, … -0.027701]","[0.00909, 0.00752, … 0.000576]","[-59.203378, -34.397152, … 14.667275]","[-0.017679, -0.025385, … 0.000219]","[1.036921, -1.681305, … -0.782573]"
"[-0.008147, -0.008292, … -0.004894]","[-0.832905, 0.140715, … -8.502144]","[0.000681, 0.000731, … 0.006186]","[-0.719097, -0.609518, … -5.884718]","[0.237031, 0.234794, … -0.347979]","[-0.18778, -0.20534, … -1.025089]","[0.000066, 0.000069, … 0.000024]","[0.69373, 0.019801, … 72.28645]","[4.6339e-7, 5.3490e-7, … 0.000038]","[0.517101, 0.371513, … 34.629909]","[0.006786, -0.001167, … 0.041612]","[0.005859, 0.005054, … 0.028802]","[-0.000567, 0.000103, … -0.052593]","[0.598939, -0.085768, … 50.032721]","[-0.00049, -0.000446, … -0.036402]","[1.418214, 0.882783, … -0.412105]"
"[-0.003781, -0.001309, … 0.004499]","[7.429633, 7.616174, … 0.77588]","[0.00172, 0.001697, … 0.001258]","[0.261928, -0.016191, … -5.434772]","[0.339893, 0.338479, … -0.232213]","[-0.220234, -0.075214, … 0.191679]","[0.000014, 0.000002, … 0.00002]","[55.199439, 58.006113, … 0.60199]","[0.000003, 0.000003, … 0.000002]","[0.068606, 0.000262, … 29.53675]","[-0.028095, -0.009968, … 0.003491]","[-0.00099, 0.000021, … -0.024451]","[0.01278, 0.012927, … 0.000976]","[1.946029, -0.123316, … -4.216733]","[0.000451, -0.000027, … -0.006839]","[-0.667944, -0.804122, … 0.820629]"
"[-0.008102, -0.007577, … 0.016667]","[-13.405475, -12.142975, … 11.825626]","[-0.007065, -0.007211, … -0.00118]","[3.432059, 2.927974, … 3.577291]","[0.150215, 0.17083, … -0.121867]","[1.938205, 1.85015, … -0.665716]","[0.000066, 0.000057, … 0.000278]","[179.706773, 147.451835, … 139.845423]","[0.00005, 0.000052, … 0.000001]","[11.779026, 8.57303, … 12.797011]","[0.108611, 0.092011, … 0.1971]","[-0.027807, -0.022186, … 0.059623]","[0.094714, 0.087568, … -0.01395]","[-46.008377, -35.55431, … 42.303705]","[-0.024249, -0.021115, … -0.00422]","[0.206686, 1.034311, … -2.368951]"
"[-0.000705, -0.003476, … 0.004359]","[7.18851, 7.255315, … -1.118738]","[0.011553, 0.011468, … 0.001921]","[-0.966581, 0.23619, … 0.077109]","[0.138843, 0.132789, … -0.237516]","[-0.275857, -1.349696, … 0.283456]","[4.9735e-7, 0.000012, … 0.000019]","[51.674681, 52.6396, … 1.251575]","[0.000133, 0.000132, … 0.000004]","[0.934278, 0.055786, … 0.005946]","[-0.00507, -0.025219, … -0.004877]","[0.000682, -0.000821, … 0.000336]","[0.083046, 0.083205, … -0.002149]","[-6.948275, 1.713632, … -0.086265]","[-0.011167, 0.002709, … 0.000148]","[-0.29442, 0.846611, … 0.437807]"


In [8]:
import random
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from torchmetrics.regression import R2Score

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
seed_val = 0

torch.manual_seed(seed_val)
random.seed(seed_val)
np.random.seed(seed_val)

In [9]:
df = df.sample(n = len(df), with_replacement = False, shuffle = True, seed = seed_val)
df.head()

psi_e,b_e,psi_plus,b_plus,u_list,r_list,k_e_psi_e_list,k_e_b_e_list,k_e_psi_plus_list,k_e_b_plus_list,heat_flux_psi_e_b_e_list,heat_flux_psi_e_b_plus_list,b_e_psi_plus_list,b_e_b_plus_list,psi_plus_b_plus_list,eta_list
list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64]
"[-0.005793, -0.004183, … -0.002956]","[-5.107394, -4.225599, … 5.978745]","[-0.003674, -0.003845, … 0.003528]","[3.18727, 2.998706, … -2.897284]","[0.328583, 0.33484, … -0.445908]","[0.720642, 0.544571, … -0.353192]","[0.000034, 0.000017, … 0.000009]","[26.085471, 17.855691, … 35.745387]","[0.000013, 0.000015, … 0.000012]","[10.158692, 8.992235, … 8.394254]","[0.029587, 0.017676, … -0.017676]","[-0.018464, -0.012544, … 0.008566]","[0.018765, 0.016247, … 0.021095]","[-16.278645, -12.671329, … -17.32212]","[-0.01171, -0.01153, … -0.010223]","[0.511849, -0.243375, … 0.508325]"
"[0.021404, 0.021093, … 0.000991]","[-9.237605, -11.516815, … 5.765189]","[-0.000833, -0.001172, … 0.001094]","[4.428155, 4.733523, … -3.332383]","[0.245995, 0.238335, … -0.401837]","[-0.603414, -0.837105, … 0.036702]","[0.000458, 0.000445, … 9.8238e-7]","[85.333351, 132.637025, … 33.237406]","[6.9326e-7, 0.000001, … 0.000001]","[19.60856, 22.406236, … 11.104776]","[-0.197722, -0.24293, … 0.005714]","[0.094781, 0.099846, … -0.003303]","[0.007691, 0.013499, … 0.006305]","[-40.905552, -54.515103, … -19.211818]","[-0.003687, -0.005548, … -0.003644]","[0.063621, 0.334498, … 0.674818]"
"[-0.015647, -0.013076, … -0.00386]","[-2.369924, -0.49071, … 7.107985]","[-0.003052, -0.003181, … -0.003636]","[2.772283, 2.464133, … -6.293166]","[0.320847, 0.335211, … -0.332921]","[1.616712, 1.408141, … 0.475131]","[0.000245, 0.000171, … 0.000015]","[5.61654, 0.240796, … 50.523454]","[0.000009, 0.00001, … 0.000013]","[7.685551, 6.07195, … 39.603939]","[0.037081, 0.006417, … -0.027434]","[-0.043377, -0.032221, … 0.024289]","[0.007232, 0.001561, … -0.025843]","[-6.570099, -1.209175, … -44.731731]","[-0.00846, -0.007837, … 0.02288]","[-0.20652, 0.336207, … -1.636342]"
"[-0.011174, -0.009445, … 0.001876]","[13.150369, 14.2002, … 16.964786]","[-0.000778, -0.000786, … -0.003702]","[0.488479, 0.083618, … -1.323443]","[0.118517, 0.120958, … -0.26322]","[0.294501, 0.251263, … -0.235176]","[0.000125, 0.000089, … 0.000004]","[172.932208, 201.645674, … 287.803958]","[6.0592e-7, 6.1736e-7, … 0.000014]","[0.238612, 0.006992, … 1.751502]","[-0.146942, -0.134116, … 0.031832]","[-0.005458, -0.00079, … -0.002483]","[-0.010236, -0.011157, … -0.062799]","[6.423679, 1.187386, … -22.451931]","[-0.00038, -0.000066, … 0.004899]","[-1.079961, 0.697676, … 0.485542]"
"[-0.002717, -0.00317, … -0.019602]","[3.219618, 3.510212, … 11.002655]","[0.001369, 0.001422, … -0.002804]","[-0.827359, -0.89963, … 3.683177]","[0.396805, 0.394857, … -0.212469]","[-0.125937, -0.152674, … 1.861283]","[0.000007, 0.00001, … 0.000384]","[10.365939, 12.32159, … 121.058415]","[0.000002, 0.000002, … 0.000008]","[0.684523, 0.809334, … 13.565793]","[-0.008748, -0.011128, … -0.215678]","[0.002248, 0.002852, … -0.072199]","[0.004407, 0.004993, … -0.030855]","[-2.663781, -3.157892, … 40.524725]","[-0.001133, -0.00128, … -0.010329]","[-0.240056, 2.203021, … 0.820105]"


### Train-Val-Test Split (70-20-10)

1. Split based on 1500 time-series samples (Vertically)
2. Split each sample to train-val-test (Horizontally)

Vertically for now!

#### How about 5-Fold Cross-Validation? Stratified 5F?

Just a single split for now!

In [10]:
df_train = df[:int(len(df) * 0.7)]
df_val = df[int(len(df) * 0.7):int(len(df) * 0.9)]
df_test = df[int(len(df) * 0.9):]

In [11]:
print(f"df_train shape: {df_train.shape}")
print(f"df_val shape: {df_val.shape}")
print(f"df_test shape: {df_test.shape}")

df_train shape: (1050, 16)
df_val shape: (300, 16)
df_test shape: (150, 16)


## On-Demand Data Loading (No RAM Issue)

In [12]:
class WindowedDataset(Dataset):
    def __init__(self, input_df, label_df, num_single_sample_timesteps, stride, input_window_length, label_window_length):
        super().__init__()
        
        self.input_df = input_df  # Type: Numpy, Shape: Number of time-series, Number of time-steps, Number of input features
        self.label_df = label_df  # Type: Numpy, Shape: Number of time-series, Number of time-steps, Number of label features
        self.num_single_sample_timesteps = num_single_sample_timesteps
        self.stride = stride
        self.input_window_length = input_window_length
        self.label_window_length = label_window_length

        self.valid_length = self.input_window_length + self.label_window_length
        
        self.window_indices = []
        for time_series_idx in range(self.input_df.shape[0]):
            for input_window_start_idx in range(0, self.input_df.shape[1] - self.valid_length + 1, self.stride):
                self.window_indices.append((time_series_idx, input_window_start_idx))

    def __len__(self):
        return len(self.window_indices)
    
    def __getitem__(self, index):
        time_series_idx, input_window_start_idx = self.window_indices[index]

        label_window_start_idx = input_window_start_idx + self.input_window_length
        input_window = self.input_df[time_series_idx, input_window_start_idx: label_window_start_idx, :]
        label_window = self.label_df[time_series_idx, label_window_start_idx: label_window_start_idx + self.label_window_length, :]

        input_window_mean = input_window.mean(axis = 0)
        input_window_std = input_window.std(axis = 0)
        input_window_std[input_window_std == 0] = 10 ** -8
        input_window = (input_window - input_window_mean) / input_window_std

        return torch.tensor(input_window, dtype = torch.float), torch.tensor(label_window, dtype = torch.float)

In [13]:
# df_train.explode("*")

# import torch
# from torch.utils.data import Dataset
# import polars as pl
# import numpy as np


# class CausallyNormalizedTimeSeriesDataset(Dataset):
#     def __init__(self, csv_path, input_len=200, label_len=100, stride=1):
#         self.input_len = input_len
#         self.label_len = label_len
#         self.total_len = input_len + label_len
#         self.stride = stride

#         # Load dataset from CSV (assume shape: [num_samples, time_steps * features])
#         df = pl.read_csv(csv_path)
#         raw_np = df.to_numpy()

#         # Infer shape
#         num_samples, total_cols = raw_np.shape
#         if total_cols % 1000 != 0:
#             raise ValueError("Expected time series length of 1000 per sample")
#         self.sequence_len = 1000
#         self.num_features = total_cols // self.sequence_len

#         # Reshape to (N, T, F)
#         self.data = raw_np.reshape(num_samples, self.sequence_len, self.num_features)
#         self.num_samples = num_samples

#         assert self.sequence_len >= self.total_len, "Time series too short for given input + label length"

#         # Compute mean/std using only the first input_len steps of each event
#         first_input = self.data[:, :self.input_len, :]  # shape: (N, input_len, F)
#         self.row_means = first_input.mean(axis=1, keepdims=True)  # shape: (N, 1, F)
#         self.row_stds = first_input.std(axis=1, keepdims=True) + 1e-8

#         # Precompute valid window positions
#         self.window_indices = []
#         max_start = self.sequence_len - self.total_len
#         for sample_idx in range(self.num_samples):
#             for t_start in range(0, max_start + 1, self.stride):
#                 self.window_indices.append((sample_idx, t_start))

#     def __len__(self):
#         return len(self.window_indices)

#     def __getitem__(self, idx):
#         sample_idx, t_start = self.window_indices[idx]

#         full_sequence = self.data[sample_idx]  # (T, F)
#         mean = self.row_means[sample_idx]      # (1, F)
#         std = self.row_stds[sample_idx]        # (1, F)

#         norm_sequence = (full_sequence - mean) / std  # (T, F)

#         x = norm_sequence[t_start: t_start + self.input_len]  # (input_len, F)
#         y = norm_sequence[t_start + self.input_len: t_start + self.total_len]  # (label_len, F)

#         return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.float32)


# dataset = CausallyNormalizedTimeSeriesDataset(
#     csv_path="your_multivariate_dataset.csv",
#     input_len=200,
#     label_len=100,
#     stride=1
# )

# loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

# for x, y in loader:
#     print("Input shape:", x.shape)   # (32, 200, F)
#     print("Label shape:", y.shape)   # (32, 100, F)
#     break


## All-in-one-go Data Windowing (Possible RAM overflow)

Must be super modular since we will run with diverse set of windows!

There is a helper function in Keras but I'm going full PyTorch! (`timeseries_dataset_from_array`)

### Normalization

Important! Moving average must be implemented!

During training you shouldn't have access to future tokens for normalization! --> Data Leakage

For now just z-score on the whole train set for each row separately

In [14]:
# # All-in-one-go Data Windowing
#
# df_train = df_train.select(
#     (pl.col("*") - pl.col("*").list.mean()) / pl.col("*").list.std()
# )

# df_val = df_val.select(
#     (pl.col("*") - pl.col("*").list.mean()) / pl.col("*").list.std()
# )

# df_test = df_test.select(
#     (pl.col("*") - pl.col("*").list.mean()) / pl.col("*").list.std()
# )

What I am doing above is global normalization where I calculate mean and std across all time-series. There are other options:

1. Calculate mean and std per series (num_single_sample_timesteps) and apply to windows --> Temporal leakage! --> Knowing about future distribution
2. Calculate mean and std per window (input_window_length) and apply to each window --> Locally aware might lose global context?
3. Calculate mean and std based on pervious steps of the current window, normalize the current window with that mean and std

I am implementing step 2 for now in the custom dataset!

Probably a violin plot like the one here (https://www.tensorflow.org/tutorials/structured_data/time_series#normalize_the_data) is a good idea!

#### Reading the dataset whole!

In [15]:
# All-in-one-go Data Windowing
#
# def createWindowedDataframe(dataframe, stride, input_features, input_window_length, label_features, label_window_length):
#     """
#         Creates a windowed time-series dataset suitable for neural network (transformer) input

#         dataframe: In time-series format where columns are features and rows are time-series steps (Polars exploded format!)
#         stride: How many timesteps will each window move by. Equal for both input and label windows
#         input_features: Names of the input features in a list
#         input_window_length: Input window size in timesteps
#         label_features: Names of the label features in a list
#         label_window_length: label window size in timesteps
#     """

#     input_df = dataframe.select(
#         pl.col(input_features)
#     )
#     input_df = input_df.select(
#         numbers = pl.concat_list("*")   # List of all features
#     )
#     input_df = input_df[:-label_window_length].with_row_index("index").with_columns(    # Exluding the ending label window
#         pl.col("index").cast(pl.Int64)
#     )
#     input_df = input_df.group_by_dynamic(   # index is considered for grouping
#         index_column = "index",
#         period = f"{input_window_length}i",
#         every = f"{stride}i",
#         closed = "left"
#     ).agg(
#         pl.col("numbers").alias("X"),
#         pl.len().alias("seq_len")
#     )
#     input_df = input_df.filter(
#         pl.col("seq_len") == input_window_length
#     )
#     input_df = input_df.select(
#         pl.exclude(["index", "seq_len"])
#     )


#     label_df = dataframe.select(
#         pl.col(label_features)
#     )
#     label_df = label_df.select(
#         numbers = pl.concat_list("*")
#     )
#     label_df = label_df[input_window_length:].with_row_index("index").with_columns(
#         pl.col("index").cast(pl.Int64)
#     )
#     label_df = label_df.group_by_dynamic(
#         index_column = "index",
#         period = f"{label_window_length}i",
#         every = f"{stride}i",
#         closed = "left"
#     ).agg(
#         pl.col("numbers").alias("Y"),
#         pl.len().alias("seq_len")
#     )
#     label_df = label_df.filter(
#         (pl.col("seq_len") == label_window_length)
#     )
#     label_df = label_df.select(
#         pl.exclude(["index", "seq_len"])
#     )

#     return pl.concat([input_df, label_df], how = "horizontal")    

In [16]:
# All-in-one-go Data Windowing
#
# df_train = df_train.explode("*")

# timeseries_df_train = pl.DataFrame()

# for i in range(0, len(df_train), num_single_sample_timesteps):
#     temp_df = createWindowedDataframe(
#         dataframe = df_train[i:i + num_single_sample_timesteps],
#         stride = window_stride,
#         input_features = input_features,
#         input_window_length = input_window_length,
#         label_features = label_features,
#         label_window_length = label_window_length
#     )

#     if(temp_df.is_empty()):
#         timeseries_df_train = temp_df
#     else:
#         timeseries_df_train = pl.concat([timeseries_df_train, temp_df], how = "vertical")

# timeseries_df_train

In [17]:
# All-in-one-go Data Windowing
#
# df_val = df_val.explode("*")

# timeseries_df_val = pl.DataFrame()

# for i in range(0, len(df_val), num_single_sample_timesteps):
#     temp_df = createWindowedDataframe(
#         dataframe = df_val[i:i + num_single_sample_timesteps],
#         stride = window_stride,
#         input_features = input_features,
#         input_window_length = input_window_length,
#         label_features = label_features,
#         label_window_length = label_window_length
#     )

#     if(temp_df.is_empty()):
#         timeseries_df_val = temp_df
#     else:
#         timeseries_df_val = pl.concat([timeseries_df_val, temp_df], how = "vertical")

# timeseries_df_val

In [18]:
# All-in-one-go Data Windowing
#
# df_test = df_test.explode("*")

# timeseries_df_test = pl.DataFrame()

# for i in range(0, len(df_test), num_single_sample_timesteps):
#     temp_df = createWindowedDataframe(
#         dataframe = df_test[i:i + num_single_sample_timesteps],
#         stride = window_stride,
#         input_features = input_features,
#         input_window_length = input_window_length,
#         label_features = label_features,
#         label_window_length = label_window_length
#     )

#     if(temp_df.is_empty()):
#         timeseries_df_test = temp_df
#     else:
#         timeseries_df_test = pl.concat([timeseries_df_test, temp_df], how = "vertical")

# timeseries_df_test

**Some visualization here is good to do a sanity-check on the dataset before sending it to the model!!!!!!!!!!**

In [19]:
class PositionalEncoding(torch.nn.Module):
    def __init__(self, d_model, dropout, max_len):
        super().__init__()
        self.dropout = torch.nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype = torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

In [20]:
class TimeSeriesTransformer(torch.nn.Module):
    def __init__(self, input_dim, output_dim, d_model, num_head, num_encoder_layers, num_decoder_layers, positional_encoding_max_len, position_wise_ffn_dim, dropout):
        super().__init__()
        self.input_proj = torch.nn.Linear(input_dim, d_model)
        self.output_proj = torch.nn.Linear(output_dim, d_model)
        
        self.pos_encoder = PositionalEncoding(
            d_model = d_model,
            dropout = dropout,
            max_len = positional_encoding_max_len
        )

        self.transformer = torch.nn.Transformer(
            d_model = d_model,
            nhead = num_head,
            num_encoder_layers = num_encoder_layers,
            num_decoder_layers = num_decoder_layers,
            dropout = dropout,
            dim_feedforward = position_wise_ffn_dim,
            batch_first = True
        )

        self.final_proj = torch.nn.Linear(d_model, output_dim)

    def forward(self, src, tgt, tgt_mask = None):
        # encode and decode methods can be used here

        src = self.input_proj(src)
        src = self.pos_encoder(src)
        
        tgt = self.output_proj(tgt)
        tgt = self.pos_encoder(tgt)
        
        if(tgt_mask is None):
            tgt_mask = self.transformer.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)

        decoder_out = self.transformer(src, tgt, tgt_mask = tgt_mask)

        return self.final_proj(decoder_out)
    
    def encode(self, src):
        # memory is encoder output

        src = self.input_proj(src)
        src = self.pos_encoder(src)
        memory = self.transformer.encoder(src)
        return memory

    def decode(self, tgt, memory, tgt_mask = None):
        tgt = self.output_proj(tgt)
        tgt = self.pos_encoder(tgt)

        if(tgt_mask is None):
            tgt_mask = self.transformer.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)

        decoder_out = self.transformer.decoder(tgt, memory, tgt_mask = tgt_mask)

        return self.final_proj(decoder_out)

In [21]:
# All-in-one-go Data Windowing

# df_train = timeseries_df_train
# df_val = timeseries_df_val
# df_test = timeseries_df_test

# # %reset_selective -f "^timeseries_df_train$"
# # %reset_selective -f "^timeseries_df_val$"
# # %reset_selective -f "^timeseries_df_test$"
# # %reset_selective -f "^df$"
# del timeseries_df_train
# del timeseries_df_val
# del timeseries_df_test
# del df
# gc.collect()

In [22]:
# # All-in-one-go Data Windowing
# 
# df_train = TensorDataset(
#     torch.Tensor(df_train["X"]),
#     torch.Tensor(df_train["Y"])
# )
# data_loader_train = DataLoader(
#     df_train,
#     batch_size = batch_size,
#     shuffle = True,
#     num_workers = 10
# )

# df_val = TensorDataset(
#     torch.Tensor(df_val["X"]),
#     torch.Tensor(df_val["Y"])
# )
# data_loader_val = DataLoader(
#     df_val,
#     batch_size = batch_size,
#     shuffle = True,
#     num_workers = 10
# )

# df_test = TensorDataset(
#     torch.Tensor(df_test["X"]),
#     torch.Tensor(df_test["Y"])
# )
# data_loader_test = DataLoader(
#     df_test,
#     batch_size = batch_size,
#     shuffle = True,
#     num_workers = 10
# )

In [23]:
##### TRAIN #####

num_train_samples = df_train.shape[0]

input_df = df_train.select(
    pl.col(input_features)
)
label_df = df_train.select(
    pl.col(label_features)
)
input_df = input_df.explode("*").to_numpy()
input_df = input_df.reshape(num_train_samples, num_single_sample_timesteps, len(input_features))
label_df = label_df.explode("*").to_numpy()
label_df = label_df.reshape(num_train_samples, num_single_sample_timesteps, len(label_features))

df_train = WindowedDataset(
    input_df = input_df,
    label_df = label_df,
    num_single_sample_timesteps = num_single_sample_timesteps,
    stride = window_stride,
    input_window_length = input_window_length,
    label_window_length = label_window_length
)

data_loader_train = DataLoader(
    df_train,
    batch_size = batch_size,
    shuffle = True,
    num_workers = 10,
    prefetch_factor = 8,
    persistent_workers = True,
    pin_memory = True
)



##### VALIDATION #####
num_val_samples = df_val.shape[0]

input_df = df_val.select(
    pl.col(input_features)
)
label_df = df_val.select(
    pl.col(label_features)
)
input_df = input_df.explode("*").to_numpy()
input_df = input_df.reshape(num_val_samples, num_single_sample_timesteps, len(input_features))
label_df = label_df.explode("*").to_numpy()
label_df = label_df.reshape(num_val_samples, num_single_sample_timesteps, len(label_features))

df_val = WindowedDataset(
    input_df = input_df,
    label_df = label_df,
    num_single_sample_timesteps = num_single_sample_timesteps,
    stride = window_stride,
    input_window_length = input_window_length,
    label_window_length = label_window_length
)

data_loader_val = DataLoader(
    df_val,
    batch_size = batch_size,
    shuffle = True,
    num_workers = 10,
    prefetch_factor = 8,
    persistent_workers = True,
    pin_memory = True
)



##### TEST #####
num_test_samples = df_test.shape[0]

input_df = df_test.select(
    pl.col(input_features)
)
label_df = df_test.select(
    pl.col(label_features)
)
input_df = input_df.explode("*").to_numpy()
input_df = input_df.reshape(num_test_samples, num_single_sample_timesteps, len(input_features))
label_df = label_df.explode("*").to_numpy()
label_df = label_df.reshape(num_test_samples, num_single_sample_timesteps, len(label_features))

df_test = WindowedDataset(
    input_df = input_df,
    label_df = label_df,
    num_single_sample_timesteps = num_single_sample_timesteps,
    stride = window_stride,
    input_window_length = input_window_length,
    label_window_length = label_window_length
)

data_loader_test = DataLoader(
    df_test,
    batch_size = batch_size,
    shuffle = True,
    num_workers = 10,
    prefetch_factor = 8,
    persistent_workers = True,
    pin_memory = True
)


##### CLEANING UP! #####

del input_df
del label_df
del df
gc.collect()

0

In [None]:
model = TimeSeriesTransformer(
    input_dim = len(input_features),
    output_dim = len(label_features),
    d_model = embedding_dim,
    num_head = num_attention_head,
    num_encoder_layers = num_encoder_layers,
    num_decoder_layers = num_encoder_layers,
    positional_encoding_max_len = positional_encoding_max_len,
    position_wise_ffn_dim = position_wise_nn_dim,
    dropout = dropout
).to(device)

print(f"Number of trainable parameters in the model: {sum(p.numel() for p in model.parameters() if p.requires_grad)}\n")

criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr = learning_rate
)

train_r2 = R2Score(multioutput = "uniform_average").to(device)
val_r2 = R2Score(multioutput = "uniform_average").to(device)

overfit_count = 0

model.forward = torch.compile(model.forward)    # Faster for training

for epoch in range(epochs):
    ################################################## TRAINING ##################################################
    ##### No Scheduled Sampling #####
    # model.train()
    # epoch_train_loss = 0.0
    # train_progress_bar = tqdm(
    #     data_loader_train,
    #     desc = f"Epoch {epoch + 1}/{epochs}"
    # )
    
    # for batch_x, batch_y in train_progress_bar:
    #     batch_x = batch_x.to(device)
    #     batch_y = batch_y.to(device)

    #     decoder_input = torch.zeros_like(batch_y)
    #     decoder_input[:, 1:] = batch_y[:, :-1]
    #     decoder_input[:, 0] = 0    # Adding bos token

    #     optimizer.zero_grad()
    #     output = model(batch_x, decoder_input)
    #     loss = criterion(output, batch_y)
    #     loss.backward()
    #     optimizer.step()

    #     epoch_train_loss += loss.item()
    #     train_r2.update(
    #         output.view(output.shape[0], -1),    # Flatten (batch_size, timestep * num_feature)
    #         batch_y.view(batch_y.shape[0], -1)
    #     )
    #     train_progress_bar.set_postfix({
    #         "train_loss": f"{loss.item():.6f}"
    #     })

    # avg_train_loss = epoch_train_loss / len(data_loader_train)
    # epoch_train_r2 = train_r2.compute()
    # train_r2.reset()
    # print(f"Epoch [{epoch + 1}/{epochs}], Train Loss: {avg_train_loss:.6f}, Train R2: {epoch_train_r2:.6f}")
    #################################

    model.train()
    epoch_train_loss = 0.0
    train_progress_bar = tqdm(
        data_loader_train,
        desc = f"Epoch {epoch + 1}/{epochs}"
    )

    epsilon = max(0, 1.0 - 10 * epoch / epochs)  # Linear decay for choosing between label or pred as decoder input

    for batch_x, batch_y in train_progress_bar:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        num_label_batch_samples, num_label_timesteps, num_label_features = batch_y.shape

        # First pass: Just get preds --> No grads!
        decoder_input_first_pass = torch.zeros_like(batch_y)
        decoder_input_first_pass[:, 1:, :] = batch_y[:, :-1, :] # [:, 0, :] BOS = 0
        
        with torch.no_grad():
            output_first_pass = model(batch_x, decoder_input_first_pass)
            output_first_pass = output_first_pass.detach()   # detach: Not involved in model computation graph

        # Second pass: pred and label decision for decoder input and back prop
        decoder_input_second_pass = torch.zeros_like(batch_y)   # [;, 0, :] BOS = 0
        teacher_force_mask = (
            torch.rand(num_label_batch_samples, num_label_timesteps - 1, device = device) < epsilon    # Not including BOS
        ).unsqueeze(2)  # Boolean mat of size: (num_label_batch_samples, num_label_timesteps, 1) --> Bool for selecting either pred or label
        decoder_input_second_pass[:, 1:, :] = torch.where(
            teacher_force_mask,
            batch_y[:, :-1, :],
            output_first_pass[:, :-1, :]
        )
        optimizer.zero_grad()
        output_second_pass = model(batch_x, decoder_input_second_pass)
        loss = criterion(output_second_pass, batch_y)
        loss.backward()
        optimizer.step()

        epoch_train_loss += loss.item()
        train_r2.update(
            output_second_pass.view(output_second_pass.shape[0], -1),    # Flatten (batch_size, timestep * num_feature)
            batch_y.view(batch_y.shape[0], -1)
        )
        train_progress_bar.set_postfix({
            "train_loss": f"{loss.item():.6f}"
        })

    avg_train_loss = epoch_train_loss / len(data_loader_train)
    epoch_train_r2 = train_r2.compute()
    train_r2.reset()
    print(f"Epoch [{epoch + 1}/{epochs}], Train Loss: {avg_train_loss:.6f}, Train R2: {epoch_train_r2:.6f}")

    ################################################# VALIDATION #################################################
    model.eval()
    epoch_val_loss = 0.0
    val_progress_bar = tqdm(
        data_loader_val,
        desc = f"Epoch {epoch + 1}/{epochs}"
    )

    with torch.no_grad():
        for batch_x, batch_y in val_progress_bar:
            # No teacher forcing in inference: Output at each timestep is feeded back to the decoder as input in the next timestep
            # Attention mask grows as timesteps pass
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            num_label_batch_samples, num_label_timesteps, num_label_features = batch_y.shape    # num_label_features == len(label_features)

            outputs = torch.zeros(num_label_batch_samples, num_label_timesteps, num_label_features).to(device)
            decoder_input = torch.zeros(num_label_batch_samples, 1 + num_label_timesteps, num_label_features).to(device)    # +1 for BOS

            full_mask = model.transformer.generate_square_subsequent_mask(1 + num_label_timesteps).to(device)
            encoder_output = model.encode(batch_x)

            for t in range(num_label_timesteps):
                # tgt_mask = model.transformer.generate_square_subsequent_mask(decoder_input.shape[1]).to(device) --> Alternative to full mask! (Indexing is faster than mask generation!)
                tgt_mask = full_mask[:t+1, :t+1]
                out = model.decode(
                    tgt = decoder_input[:, :t+1, :],
                    memory = encoder_output,
                    tgt_mask = tgt_mask
                )
                next_step = out[:, -1, :]
                outputs[:, t, :] = next_step
                decoder_input[:, t+1, :] = next_step

            loss = criterion(outputs, batch_y)
            epoch_val_loss += loss.item()
            val_r2.update(
                outputs.view(outputs.shape[0], -1),    # Flatten (batch_size, timestep * num_feature)
                batch_y.view(batch_y.shape[0], -1)
            )
            val_progress_bar.set_postfix({
                "val_loss": f"{loss.item():.6f}"
            })

        avg_val_loss = epoch_val_loss / len(data_loader_val)
        epoch_val_r2 = val_r2.compute()
        val_r2.reset()
        print(f"Epoch [{epoch + 1}/{epochs}], Val Loss: {avg_val_loss:.6f}, Val R2: {epoch_val_r2:.6f}\n")
    

    if(epoch >= 10 and avg_val_loss - avg_train_loss > 0.01):
        overfit_count += 1
        print(f"Possible Overfitting!!! {overfit_count}/3\n")
        if(overfit_count == 3):
            print("Training Stopped!!!")
            break

Number of trainable parameters in the model: 135361



Epoch 1/100:   2%|▏         | 32/1880 [00:28<27:25,  1.12it/s, val_loss=0.284753] 


KeyboardInterrupt: 

In [None]:
################################################## TESTING ##################################################

model.eval()
test_loss = 0.0
test_progress_bar = tqdm(
    data_loader_test
)

test_r2 = R2Score(multioutput = "uniform_average")

with torch.no_grad():
    for batch_x, batch_y in test_progress_bar:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        num_label_batch_samples, num_label_timesteps, num_label_features = batch_y.shape    # num_label_features == len(label_features)

        outputs = torch.zeros(num_label_batch_samples, num_label_timesteps, num_label_features).to(device)
        decoder_input = torch.zeros(num_label_batch_samples, 1 + num_label_timesteps, num_label_features).to(device)    # +1 for BOS

        full_mask = model.transformer.generate_square_subsequent_mask(1 + num_label_timesteps).to(device)
        encoder_output = model.encode(batch_x)

        for t in range(num_label_timesteps):
            tgt_mask = full_mask[:t+1, :t+1]
            out = model.decode(
                tgt = decoder_input[:, :t+1, :],
                memory = encoder_output,
                tgt_mask = tgt_mask
            )
            next_step = out[:, -1, :]
            outputs[:, t, :] = next_step
            decoder_input[:, t+1, :] = next_step

        loss = criterion(outputs, batch_y)
        test_loss += loss.item()
        test_progress_bar.set_postfix({
            "batch_test_loss": f"{loss.item():.6f}"
        })
        test_r2.update(
            outputs.view(outputs.shape[0], -1),
            batch_y.view(batch_y.shape[0], -1)
        )

    final_test_loss = test_loss / len(data_loader_test)
    final_test_r2 = test_r2.compute()
    test_r2.reset()
    print(f"Final Test Loss: {final_test_loss:.6f}, Final Test R2: {final_test_r2:.6f}")

100%|██████████| 940/940 [12:44<00:00,  1.23it/s, batch_test_loss=0.042307]

Final Test Loss: 0.098021, Final Test R2: -2.319430





In [None]:
torch.save(model, model_path)