In [1]:
import yfinance as yf
import pandas as pd
import numpy as np
import random
import torch
import datetime


import sys
sys.path.append('../src/utils') 
from common import get_accuracy
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

SAVE_PATH = "/home/s/python_progs/DL_homeworks/deep-metric-ts-clustering/notebooks/model"

In [2]:
# Select diverse stocks from different sectors
tickers = ['AAPL', 'MSFT', 'GOOGL', 'TSLA', 'AMZN', 'JPM', 'JNJ', 'XOM', 'NVDA', 'WMT']
ticker_to_class = {ticker: i for i, ticker in enumerate(tickers)}

data = yf.download(tickers, start="2020-01-01", end="2024-01-01")['Close']
data.dropna(inplace=True)
print(data)

YF.download() has changed argument auto_adjust default to True


[*********************100%***********************]  10 of 10 completed

Ticker            AAPL        AMZN       GOOGL         JNJ         JPM  \
Date                                                                     
2020-01-02   72.620827   94.900497   68.108376  126.055145  120.733566   
2020-01-03   71.914803   93.748497   67.752075  124.595749  119.140327   
2020-01-06   72.487846   95.143997   69.557945  124.440338  119.045586   
2020-01-07   72.146942   95.343002   69.423584  125.200233  117.021744   
2020-01-08   73.307503   94.598503   69.917732  125.182945  117.934608   
...                ...         ...         ...         ...         ...   
2023-12-22  192.192551  153.419998  140.816757  149.492691  161.660004   
2023-12-26  191.646561  153.410004  140.846634  150.146591  162.616074   
2023-12-27  191.745834  153.339996  139.702087  150.348541  163.591431   
2023-12-28  192.172714  153.380005  139.562744  150.569702  164.460587   
2023-12-29  191.130325  151.940002  139.025330  150.723572  164.267441   

Ticker            MSFT       NVDA    




In [3]:
def create_windows(series, window_size=30, step=1):
    windows = []
    for i in range(0, len(series) - window_size + 1, step):
        window = series[i:i + window_size].values
        windows.append(window)
    return windows


def create_dataset(windows):
    dataset = []
    for window in windows:

        dataset.append((window['window'], ticker_to_class[window['ticker']]))

    return dataset


In [4]:
window_size = 30
step = 5  # reduces overlap between samples
stock_windows = []  # Dictionary of ticker → list of windows

for ticker in tickers:
    for window in create_windows(data[ticker], window_size, step):
        stock_windows.append({
                'ticker': ticker,
                'window': window
            })
print(stock_windows[0])

{'ticker': 'AAPL', 'window': array([72.62082672, 71.91480255, 72.48784637, 72.14694214, 73.30750275,
       74.86462402, 75.03387451, 76.63691711, 75.60206604, 75.27807617,
       76.22105408, 77.06489563, 76.54263306, 76.8158493 , 77.18579102,
       76.96333313, 74.70021057, 76.81343842, 78.4213028 , 78.30767822,
       74.83560181, 74.63008881, 77.09391022, 77.72257233, 78.63166809,
       77.56285858, 77.9312439 , 77.4610672 , 79.30059814, 78.73588562])}


In [5]:
dataset = create_dataset(stock_windows)

In [6]:
import torch
from torch.utils.data import Dataset
from torch.utils.data import random_split

class StockDataset(Dataset):
    def __init__(self, stock_windows):
        self.stock_windows = stock_windows
        self.labels = [label for _, label in stock_windows]  # Extract and store labels

    def __len__(self):
        return len(self.stock_windows)

    def __getitem__(self, idx):
        x, label = self.stock_windows[idx]
        x = torch.tensor(x, dtype=torch.float32).unsqueeze(0)  # shape [1, 30]
        label = torch.tensor(label, dtype=torch.float32)
        return x, label

dataset = StockDataset(dataset)


train_ratio = 0.8
train_size = int(train_ratio * len(dataset))
val_size = len(dataset) - train_size


train_dataset, val_dataset = random_split(dataset, [train_size, val_size])



In [7]:
import torch.nn as nn


class Encoder(nn.Module):
    def __init__(self, input_dim, encoder_type, embedding_dim=128):
        super().__init__()

        self.encoder_type = encoder_type
        self.embedding_dim = embedding_dim


        if encoder_type=="CNN":
            self.conv_layers = nn.Sequential(
                nn.Conv1d(1, 64, kernel_size=3, padding=1),   # [30] → [30]
                nn.ReLU(),
                nn.Conv1d(64, 64, kernel_size=3, padding=1),  # [30] → [30]
                nn.ReLU(),
                nn.MaxPool1d(2),                              # [30] → [15]

                nn.Conv1d(64, 128, kernel_size=3, padding=1), # [15] → [15]
                nn.ReLU(),
                nn.Conv1d(128, 128, kernel_size=3, padding=1),# [15] → [15]
                nn.ReLU(),
                nn.MaxPool1d(2),                              # [15] → [7]

                nn.Conv1d(128, 256, kernel_size=3, padding=1),# [7] → [7]
                nn.ReLU(),
                nn.Conv1d(256, 256, kernel_size=3, padding=1),# [7] → [7]
                nn.ReLU(),
                nn.MaxPool1d(2),                              # [7] → [3]
            )

            self.fc = nn.Linear(256 * 3, embedding_dim)


        if encoder_type=="GRU":
            self.num_layers = 5
            self.hidden_size = 256
            self.gru = nn.GRU(1, self.hidden_size, self.num_layers, batch_first=True)
            self.fc = nn.Linear(self.hidden_size, self.embedding_dim)

        if encoder_type=="LSTM":    
            self.num_layers = 5
            self.hidden_size = 256
            self.lstm = nn.LSTM(1, self.hidden_size, self.num_layers, batch_first=True)
            self.fc = nn.Linear(self.hidden_size, self.embedding_dim)

    def forward(self, x):
        if self.encoder_type=="CNN":
            # x = x.unsqueeze(1)  # add channel dim: (batch_size, 1, seq_len)
            x = self.conv_layers(x)
            x = x.view(x.size(0), -1)  # flatten
            x = self.fc(x)
        
        if self.encoder_type=="GRU":
            x = x.reshape(-1, window_size, 1)

            h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)

            x, _ = self.gru(x, h0)
            x = x[:, -1, :]
            x = self.fc(x)
            x = nn.functional.normalize(x, p=2, dim=1)

        if self.encoder_type=="LSTM":
            x = x.reshape(-1, window_size, 1)

            h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
            c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)


            x, _ = self.lstm(x, (h0,c0))
            x = x[:, -1, :]
            x = self.fc(x)
            x = nn.functional.normalize(x, p=2, dim=1)



        return x




In [None]:

from torch.utils.data import DataLoader
from pytorch_metric_learning import distances, losses, miners, reducers, testers
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator

distance = distances.CosineSimilarity()
reducer = reducers.ThresholdReducer(low=0)
mining_func = miners.TripletMarginMiner(margin=0.2, distance=distance, type_of_triplets="semihard")
loss_func = losses.TripletMarginLoss(margin=0.2, distance=distance, reducer=reducer)

# model = Encoder(window_size, 'CNN')
# model = Encoder(window_size, 'LSTM')
model = Encoder(window_size, 'GRU')
# model = Encoder(window_size, 'TRANSFORMER')

model.to(device)

Encoder(
  (lstm): LSTM(1, 256, num_layers=5, batch_first=True)
  (fc): Linear(in_features=256, out_features=128, bias=True)
)

In [None]:
import torch
import torch.nn.functional as F
from torch.amp import autocast, GradScaler
import torch.optim as optim
from tqdm import tqdm
import json

batch_size = 512

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
history = {"train": [], "val": [], "best_accuracy": 0.0}

scaler = GradScaler()  # Helps manage fp16 gradients

optimizer = optim.Adam(model.parameters(), lr=1e-3)
epochs = 1000

for epoch in range(epochs):
    model.train()
    epoch_loss = 0

    for x, label in train_loader:
        x, label = x.to(device), label.to(device)

        optimizer.zero_grad()

        with autocast(dtype=torch.float16, device_type=device.type): 
            embeddings = model(x)
            indices_tuple = mining_func(embeddings, label)
            loss = loss_func(embeddings, label, indices_tuple)

        # Backpropagation with scaled gradients
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        epoch_loss += loss.item()




    if epoch %10 == 0:
        # model validation
        model.eval()
        avg_loss = epoch_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")
        
        with torch.no_grad():
            # as all embeddings need to be stored in memory
            # you can set DEVICE = torch.device('cpu') in case gpu memory overflow occurs
            accuracy = get_accuracy(val_dataset, train_dataset, model, device.type)
            
            history["val"].append({"epoch": epoch, "accuracy": accuracy})
            print(f"Val accuracy: {accuracy}")
            
            # save model
            torch.save(model.state_dict(), f"{SAVE_PATH}/model_latest_{model.encoder_type}.pth")
            
            if accuracy >= history["best_accuracy"]:
                history["best_accuracy"] = accuracy
                torch.save(model.state_dict(), f"{SAVE_PATH}/model_best_{model.encoder_type}.pth")

with open(f"history_{model.encoder_type}_{datetime.datetime.now()}.json", "w") as f:
    f.write(json.dumps(history))

Epoch 1/1000, Loss: 0.1852


100%|██████████| 13/13 [00:00<00:00, 123.26it/s]
100%|██████████| 49/49 [00:00<00:00, 282.21it/s]


Val accuracy: 0.4897959183673469
Epoch 11/1000, Loss: 0.1139


100%|██████████| 13/13 [00:00<00:00, 151.22it/s]
100%|██████████| 49/49 [00:00<00:00, 383.04it/s]


Val accuracy: 0.4642857142857143
Epoch 21/1000, Loss: 0.1113


100%|██████████| 13/13 [00:00<00:00, 188.13it/s]
100%|██████████| 49/49 [00:00<00:00, 369.37it/s]


Val accuracy: 0.461734693877551
Epoch 31/1000, Loss: 0.1082


100%|██████████| 13/13 [00:00<00:00, 116.22it/s]
100%|██████████| 49/49 [00:00<00:00, 262.21it/s]


Val accuracy: 0.46938775510204084
Epoch 41/1000, Loss: 0.1102


100%|██████████| 13/13 [00:00<00:00, 129.66it/s]
100%|██████████| 49/49 [00:00<00:00, 258.15it/s]


Val accuracy: 0.4923469387755102
Epoch 51/1000, Loss: 0.1080


100%|██████████| 13/13 [00:00<00:00, 126.14it/s]
100%|██████████| 49/49 [00:00<00:00, 323.59it/s]


Val accuracy: 0.4770408163265306
Epoch 61/1000, Loss: 0.1095


100%|██████████| 13/13 [00:00<00:00, 184.03it/s]
100%|██████████| 49/49 [00:00<00:00, 300.38it/s]


Val accuracy: 0.4642857142857143
Epoch 71/1000, Loss: 0.1059


100%|██████████| 13/13 [00:00<00:00, 115.13it/s]
100%|██████████| 49/49 [00:00<00:00, 329.62it/s]


Val accuracy: 0.4642857142857143
Epoch 81/1000, Loss: 0.1102


100%|██████████| 13/13 [00:00<00:00, 112.68it/s]
100%|██████████| 49/49 [00:00<00:00, 274.74it/s]


Val accuracy: 0.4872448979591837
Epoch 91/1000, Loss: 0.1079


100%|██████████| 13/13 [00:00<00:00, 163.49it/s]
100%|██████████| 49/49 [00:00<00:00, 431.51it/s]

Val accuracy: 0.46683673469387754





Epoch 101/1000, Loss: 0.1056


100%|██████████| 13/13 [00:00<00:00, 158.72it/s]
100%|██████████| 49/49 [00:00<00:00, 265.82it/s]


Val accuracy: 0.4642857142857143
Epoch 111/1000, Loss: 0.1051


100%|██████████| 13/13 [00:00<00:00, 128.92it/s]
100%|██████████| 49/49 [00:00<00:00, 328.52it/s]


Val accuracy: 0.4489795918367347
Epoch 121/1000, Loss: 0.1098


100%|██████████| 13/13 [00:00<00:00, 153.69it/s]
100%|██████████| 49/49 [00:00<00:00, 312.55it/s]


Val accuracy: 0.47959183673469385
Epoch 131/1000, Loss: 0.1095


100%|██████████| 13/13 [00:00<00:00, 142.24it/s]
100%|██████████| 49/49 [00:00<00:00, 339.68it/s]


Val accuracy: 0.4770408163265306
Epoch 141/1000, Loss: 0.1055


100%|██████████| 13/13 [00:00<00:00, 123.72it/s]
100%|██████████| 49/49 [00:00<00:00, 291.43it/s]


Val accuracy: 0.44642857142857145
Epoch 151/1000, Loss: 0.1062


100%|██████████| 13/13 [00:00<00:00, 179.02it/s]
100%|██████████| 49/49 [00:00<00:00, 293.70it/s]


Val accuracy: 0.45408163265306123
Epoch 161/1000, Loss: 0.1102


100%|██████████| 13/13 [00:00<00:00, 126.52it/s]
100%|██████████| 49/49 [00:00<00:00, 235.28it/s]


Val accuracy: 0.4719387755102041
Epoch 171/1000, Loss: 0.1058


100%|██████████| 13/13 [00:00<00:00, 148.00it/s]
100%|██████████| 49/49 [00:00<00:00, 306.64it/s]


Val accuracy: 0.461734693877551
Epoch 181/1000, Loss: 0.1084


100%|██████████| 13/13 [00:00<00:00, 140.92it/s]
100%|██████████| 49/49 [00:00<00:00, 304.50it/s]


Val accuracy: 0.4872448979591837
Epoch 191/1000, Loss: 0.1057


100%|██████████| 13/13 [00:00<00:00, 146.75it/s]
100%|██████████| 49/49 [00:00<00:00, 241.66it/s]


Val accuracy: 0.45918367346938777
Epoch 201/1000, Loss: 0.1060


100%|██████████| 13/13 [00:00<00:00, 149.15it/s]
100%|██████████| 49/49 [00:00<00:00, 433.60it/s]


Val accuracy: 0.4744897959183674
Epoch 211/1000, Loss: 0.1067


100%|██████████| 13/13 [00:00<00:00, 127.75it/s]
100%|██████████| 49/49 [00:00<00:00, 421.29it/s]


Val accuracy: 0.4897959183673469
Epoch 221/1000, Loss: 0.1112


100%|██████████| 13/13 [00:00<00:00, 152.30it/s]
100%|██████████| 49/49 [00:00<00:00, 325.30it/s]


Val accuracy: 0.45408163265306123
Epoch 231/1000, Loss: 0.1069


100%|██████████| 13/13 [00:00<00:00, 152.58it/s]
100%|██████████| 49/49 [00:00<00:00, 252.06it/s]


Val accuracy: 0.4770408163265306
Epoch 241/1000, Loss: 0.1091


100%|██████████| 13/13 [00:00<00:00, 158.63it/s]
100%|██████████| 49/49 [00:00<00:00, 405.85it/s]


Val accuracy: 0.4846938775510204
Epoch 251/1000, Loss: 0.1050


100%|██████████| 13/13 [00:00<00:00, 141.54it/s]
100%|██████████| 49/49 [00:00<00:00, 303.52it/s]


Val accuracy: 0.5076530612244898
Epoch 261/1000, Loss: 0.1094


100%|██████████| 13/13 [00:00<00:00, 149.27it/s]
100%|██████████| 49/49 [00:00<00:00, 240.49it/s]


Val accuracy: 0.5255102040816326
Epoch 271/1000, Loss: 0.1107


100%|██████████| 13/13 [00:00<00:00, 136.91it/s]
100%|██████████| 49/49 [00:00<00:00, 361.50it/s]


Val accuracy: 0.4897959183673469
Epoch 281/1000, Loss: 0.1007


100%|██████████| 13/13 [00:00<00:00, 198.23it/s]
100%|██████████| 49/49 [00:00<00:00, 250.48it/s]


Val accuracy: 0.5076530612244898
Epoch 291/1000, Loss: 0.1115


100%|██████████| 13/13 [00:00<00:00, 197.62it/s]
100%|██████████| 49/49 [00:00<00:00, 273.46it/s]


Val accuracy: 0.5306122448979592
Epoch 301/1000, Loss: 0.1048


100%|██████████| 13/13 [00:00<00:00, 122.32it/s]
100%|██████████| 49/49 [00:00<00:00, 303.34it/s]


Val accuracy: 0.5280612244897959
Epoch 311/1000, Loss: 0.1074


100%|██████████| 13/13 [00:00<00:00, 155.41it/s]
100%|██████████| 49/49 [00:00<00:00, 301.10it/s]


Val accuracy: 0.5076530612244898
Epoch 321/1000, Loss: 0.1127


100%|██████████| 13/13 [00:00<00:00, 200.13it/s]
100%|██████████| 49/49 [00:00<00:00, 420.03it/s]

Val accuracy: 0.5255102040816326





Epoch 331/1000, Loss: 0.1112


100%|██████████| 13/13 [00:00<00:00, 147.70it/s]
100%|██████████| 49/49 [00:00<00:00, 348.27it/s]


Val accuracy: 0.5459183673469388
Epoch 341/1000, Loss: 0.1074


100%|██████████| 13/13 [00:00<00:00, 157.96it/s]
100%|██████████| 49/49 [00:00<00:00, 315.19it/s]


Val accuracy: 0.5204081632653061
Epoch 351/1000, Loss: 0.1113


100%|██████████| 13/13 [00:00<00:00, 170.60it/s]
100%|██████████| 49/49 [00:00<00:00, 306.47it/s]


Val accuracy: 0.5918367346938775
Epoch 361/1000, Loss: 0.1119


100%|██████████| 13/13 [00:00<00:00, 143.77it/s]
100%|██████████| 49/49 [00:00<00:00, 278.46it/s]


Val accuracy: 0.5204081632653061
Epoch 371/1000, Loss: 0.1068


100%|██████████| 13/13 [00:00<00:00, 161.31it/s]
100%|██████████| 49/49 [00:00<00:00, 335.81it/s]


Val accuracy: 0.5
Epoch 381/1000, Loss: 0.1057


100%|██████████| 13/13 [00:00<00:00, 152.24it/s]
100%|██████████| 49/49 [00:00<00:00, 283.59it/s]


Val accuracy: 0.548469387755102
Epoch 391/1000, Loss: 0.1068


100%|██████████| 13/13 [00:00<00:00, 136.62it/s]
100%|██████████| 49/49 [00:00<00:00, 298.60it/s]


Val accuracy: 0.4897959183673469
Epoch 401/1000, Loss: 0.1035


100%|██████████| 13/13 [00:00<00:00, 140.55it/s]
100%|██████████| 49/49 [00:00<00:00, 338.46it/s]


Val accuracy: 0.5076530612244898
Epoch 411/1000, Loss: 0.1013


100%|██████████| 13/13 [00:00<00:00, 157.68it/s]
100%|██████████| 49/49 [00:00<00:00, 291.39it/s]


Val accuracy: 0.548469387755102
Epoch 421/1000, Loss: 0.1040


100%|██████████| 13/13 [00:00<00:00, 199.00it/s]
100%|██████████| 49/49 [00:00<00:00, 340.65it/s]


Val accuracy: 0.5178571428571429
Epoch 431/1000, Loss: 0.1015


100%|██████████| 13/13 [00:00<00:00, 145.88it/s]
100%|██████████| 49/49 [00:00<00:00, 303.74it/s]


Val accuracy: 0.5229591836734694
Epoch 441/1000, Loss: 0.1078


100%|██████████| 13/13 [00:00<00:00, 164.70it/s]
100%|██████████| 49/49 [00:00<00:00, 284.35it/s]


Val accuracy: 0.5
Epoch 451/1000, Loss: 0.1084


100%|██████████| 13/13 [00:00<00:00, 145.83it/s]
100%|██████████| 49/49 [00:00<00:00, 285.53it/s]


Val accuracy: 0.4770408163265306
Epoch 461/1000, Loss: 0.1052


100%|██████████| 13/13 [00:00<00:00, 129.14it/s]
100%|██████████| 49/49 [00:00<00:00, 359.28it/s]


Val accuracy: 0.5688775510204082
Epoch 471/1000, Loss: 0.1036


100%|██████████| 13/13 [00:00<00:00, 162.33it/s]
100%|██████████| 49/49 [00:00<00:00, 320.03it/s]


Val accuracy: 0.5357142857142857
Epoch 481/1000, Loss: 0.1053


100%|██████████| 13/13 [00:00<00:00, 122.40it/s]
100%|██████████| 49/49 [00:00<00:00, 338.28it/s]


Val accuracy: 0.5127551020408163
Epoch 491/1000, Loss: 0.1055


100%|██████████| 13/13 [00:00<00:00, 167.22it/s]
100%|██████████| 49/49 [00:00<00:00, 275.19it/s]


Val accuracy: 0.49744897959183676
Epoch 501/1000, Loss: 0.1074


100%|██████████| 13/13 [00:00<00:00, 150.74it/s]
100%|██████████| 49/49 [00:00<00:00, 306.11it/s]


Val accuracy: 0.576530612244898
Epoch 511/1000, Loss: 0.1026


100%|██████████| 13/13 [00:00<00:00, 127.69it/s]
100%|██████████| 49/49 [00:00<00:00, 317.33it/s]


Val accuracy: 0.5204081632653061
Epoch 521/1000, Loss: 0.1016


100%|██████████| 13/13 [00:00<00:00, 155.97it/s]
100%|██████████| 49/49 [00:00<00:00, 359.06it/s]


Val accuracy: 0.5433673469387755
Epoch 531/1000, Loss: 0.1038


100%|██████████| 13/13 [00:00<00:00, 199.31it/s]
100%|██████████| 49/49 [00:00<00:00, 366.23it/s]


Val accuracy: 0.5255102040816326
Epoch 541/1000, Loss: 0.1010


100%|██████████| 13/13 [00:00<00:00, 147.55it/s]
100%|██████████| 49/49 [00:00<00:00, 280.82it/s]


Val accuracy: 0.548469387755102
Epoch 551/1000, Loss: 0.1054


100%|██████████| 13/13 [00:00<00:00, 98.60it/s] 
100%|██████████| 49/49 [00:00<00:00, 275.78it/s]


Val accuracy: 0.5510204081632653
Epoch 561/1000, Loss: 0.0989


100%|██████████| 13/13 [00:00<00:00, 143.95it/s]
100%|██████████| 49/49 [00:00<00:00, 245.89it/s]


Val accuracy: 0.576530612244898
Epoch 571/1000, Loss: 0.1026


100%|██████████| 13/13 [00:00<00:00, 175.18it/s]
100%|██████████| 49/49 [00:00<00:00, 382.45it/s]


Val accuracy: 0.5841836734693877
Epoch 581/1000, Loss: 0.1032


100%|██████████| 13/13 [00:00<00:00, 173.82it/s]
100%|██████████| 49/49 [00:00<00:00, 282.56it/s]


Val accuracy: 0.5969387755102041
Epoch 591/1000, Loss: 0.1005


100%|██████████| 13/13 [00:00<00:00, 123.87it/s]
100%|██████████| 49/49 [00:00<00:00, 229.06it/s]


Val accuracy: 0.5816326530612245
Epoch 601/1000, Loss: 0.1016


100%|██████████| 13/13 [00:00<00:00, 154.80it/s]
100%|██████████| 49/49 [00:00<00:00, 284.98it/s]


Val accuracy: 0.6505102040816326
Epoch 611/1000, Loss: 0.1051


100%|██████████| 13/13 [00:00<00:00, 143.68it/s]
100%|██████████| 49/49 [00:00<00:00, 288.51it/s]


Val accuracy: 0.5790816326530612
Epoch 621/1000, Loss: 0.1030


100%|██████████| 13/13 [00:00<00:00, 177.65it/s]
100%|██████████| 49/49 [00:00<00:00, 282.98it/s]


Val accuracy: 0.5994897959183674
Epoch 631/1000, Loss: 0.0975


100%|██████████| 13/13 [00:00<00:00, 139.52it/s]
100%|██████████| 49/49 [00:00<00:00, 303.99it/s]


Val accuracy: 0.5969387755102041
Epoch 641/1000, Loss: 0.1026


100%|██████████| 13/13 [00:00<00:00, 150.35it/s]
100%|██████████| 49/49 [00:00<00:00, 224.57it/s]


Val accuracy: 0.6352040816326531
Epoch 651/1000, Loss: 0.1040


100%|██████████| 13/13 [00:00<00:00, 168.00it/s]
100%|██████████| 49/49 [00:00<00:00, 282.03it/s]


Val accuracy: 0.5943877551020408
Epoch 661/1000, Loss: 0.1013


100%|██████████| 13/13 [00:00<00:00, 151.84it/s]
100%|██████████| 49/49 [00:00<00:00, 258.61it/s]


Val accuracy: 0.6301020408163265
Epoch 671/1000, Loss: 0.1028


100%|██████████| 13/13 [00:00<00:00, 188.03it/s]
100%|██████████| 49/49 [00:00<00:00, 266.88it/s]


Val accuracy: 0.6301020408163265
Epoch 681/1000, Loss: 0.1039


100%|██████████| 13/13 [00:00<00:00, 184.79it/s]
100%|██████████| 49/49 [00:00<00:00, 240.34it/s]


Val accuracy: 0.6173469387755102
Epoch 691/1000, Loss: 0.1030


100%|██████████| 13/13 [00:00<00:00, 140.28it/s]
100%|██████████| 49/49 [00:00<00:00, 287.65it/s]


Val accuracy: 0.6301020408163265
Epoch 701/1000, Loss: 0.0993


100%|██████████| 13/13 [00:00<00:00, 146.69it/s]
100%|██████████| 49/49 [00:00<00:00, 295.34it/s]


Val accuracy: 0.6556122448979592
Epoch 711/1000, Loss: 0.1011


100%|██████████| 13/13 [00:00<00:00, 167.06it/s]
100%|██████████| 49/49 [00:00<00:00, 326.30it/s]


Val accuracy: 0.6760204081632653
Epoch 721/1000, Loss: 0.0997


100%|██████████| 13/13 [00:00<00:00, 151.87it/s]
100%|██████████| 49/49 [00:00<00:00, 262.14it/s]


Val accuracy: 0.6301020408163265
Epoch 731/1000, Loss: 0.1091


100%|██████████| 13/13 [00:00<00:00, 168.50it/s]
100%|██████████| 49/49 [00:00<00:00, 299.20it/s]


Val accuracy: 0.6479591836734694
Epoch 741/1000, Loss: 0.1138


100%|██████████| 13/13 [00:00<00:00, 119.70it/s]
100%|██████████| 49/49 [00:00<00:00, 238.11it/s]


Val accuracy: 0.6454081632653061
Epoch 751/1000, Loss: 0.0997


100%|██████████| 13/13 [00:00<00:00, 142.89it/s]
100%|██████████| 49/49 [00:00<00:00, 337.28it/s]


Val accuracy: 0.6556122448979592
Epoch 761/1000, Loss: 0.0996


100%|██████████| 13/13 [00:00<00:00, 128.95it/s]
100%|██████████| 49/49 [00:00<00:00, 280.76it/s]


Val accuracy: 0.6658163265306123
Epoch 771/1000, Loss: 0.1041


100%|██████████| 13/13 [00:00<00:00, 169.84it/s]
100%|██████████| 49/49 [00:00<00:00, 278.06it/s]


Val accuracy: 0.7040816326530612
Epoch 781/1000, Loss: 0.1011


100%|██████████| 13/13 [00:00<00:00, 157.45it/s]
100%|██████████| 49/49 [00:00<00:00, 252.12it/s]


Val accuracy: 0.5255102040816326
Epoch 791/1000, Loss: 0.1047


100%|██████████| 13/13 [00:00<00:00, 184.09it/s]
100%|██████████| 49/49 [00:00<00:00, 285.76it/s]


Val accuracy: 0.5612244897959183
Epoch 801/1000, Loss: 0.1006


100%|██████████| 13/13 [00:00<00:00, 146.91it/s]
100%|██████████| 49/49 [00:00<00:00, 398.97it/s]


Val accuracy: 0.5739795918367347
Epoch 811/1000, Loss: 0.0984


100%|██████████| 13/13 [00:00<00:00, 114.54it/s]
100%|██████████| 49/49 [00:00<00:00, 281.52it/s]


Val accuracy: 0.5892857142857143
Epoch 821/1000, Loss: 0.1014


100%|██████████| 13/13 [00:00<00:00, 158.79it/s]
100%|██████████| 49/49 [00:00<00:00, 267.89it/s]


Val accuracy: 0.6147959183673469
Epoch 831/1000, Loss: 0.1018


100%|██████████| 13/13 [00:00<00:00, 193.40it/s]
100%|██████████| 49/49 [00:00<00:00, 375.60it/s]


Val accuracy: 0.6862244897959183
Epoch 841/1000, Loss: 0.0999


100%|██████████| 13/13 [00:00<00:00, 142.70it/s]
100%|██████████| 49/49 [00:00<00:00, 293.19it/s]


Val accuracy: 0.6454081632653061
Epoch 851/1000, Loss: 0.1044


100%|██████████| 13/13 [00:00<00:00, 142.85it/s]
100%|██████████| 49/49 [00:00<00:00, 276.40it/s]


Val accuracy: 0.6581632653061225
Epoch 861/1000, Loss: 0.1054


100%|██████████| 13/13 [00:00<00:00, 143.64it/s]
100%|██████████| 49/49 [00:00<00:00, 328.31it/s]


Val accuracy: 0.6301020408163265
Epoch 871/1000, Loss: 0.1044


100%|██████████| 13/13 [00:00<00:00, 139.31it/s]
100%|██████████| 49/49 [00:00<00:00, 265.96it/s]


Val accuracy: 0.6377551020408163
Epoch 881/1000, Loss: 0.0907


100%|██████████| 13/13 [00:00<00:00, 139.82it/s]
100%|██████████| 49/49 [00:00<00:00, 299.10it/s]


Val accuracy: 0.6709183673469388
Epoch 891/1000, Loss: 0.0948


100%|██████████| 13/13 [00:00<00:00, 188.11it/s]
100%|██████████| 49/49 [00:00<00:00, 334.73it/s]


Val accuracy: 0.6964285714285714
Epoch 901/1000, Loss: 0.0944


100%|██████████| 13/13 [00:00<00:00, 165.67it/s]
100%|██████████| 49/49 [00:00<00:00, 285.96it/s]


Val accuracy: 0.6683673469387755
Epoch 911/1000, Loss: 0.0898


100%|██████████| 13/13 [00:00<00:00, 144.18it/s]
100%|██████████| 49/49 [00:00<00:00, 296.01it/s]


Val accuracy: 0.7831632653061225
Epoch 921/1000, Loss: 0.0907


100%|██████████| 13/13 [00:00<00:00, 192.54it/s]
100%|██████████| 49/49 [00:00<00:00, 388.09it/s]

Val accuracy: 0.7678571428571429





Epoch 931/1000, Loss: 0.0887


100%|██████████| 13/13 [00:00<00:00, 142.53it/s]
100%|██████████| 49/49 [00:00<00:00, 279.56it/s]


Val accuracy: 0.7627551020408163
Epoch 941/1000, Loss: 0.0878


100%|██████████| 13/13 [00:00<00:00, 181.15it/s]
100%|██████████| 49/49 [00:00<00:00, 307.41it/s]


Val accuracy: 0.8137755102040817
Epoch 951/1000, Loss: 0.0882


100%|██████████| 13/13 [00:00<00:00, 179.87it/s]
100%|██████████| 49/49 [00:00<00:00, 377.54it/s]


Val accuracy: 0.7755102040816326
Epoch 961/1000, Loss: 0.0801


100%|██████████| 13/13 [00:00<00:00, 139.55it/s]
100%|██████████| 49/49 [00:00<00:00, 244.10it/s]


Val accuracy: 0.8112244897959183
Epoch 971/1000, Loss: 0.0926


100%|██████████| 13/13 [00:00<00:00, 157.85it/s]
100%|██████████| 49/49 [00:00<00:00, 263.97it/s]


Val accuracy: 0.7295918367346939
Epoch 981/1000, Loss: 0.0805


100%|██████████| 13/13 [00:00<00:00, 151.08it/s]
100%|██████████| 49/49 [00:00<00:00, 268.21it/s]


Val accuracy: 0.8112244897959183
Epoch 991/1000, Loss: 0.0899


100%|██████████| 13/13 [00:00<00:00, 154.83it/s]
100%|██████████| 49/49 [00:00<00:00, 259.30it/s]


Val accuracy: 0.8163265306122449


In [None]:
model.load_state_dict(torch.load("/home/s/python_progs/DL_homeworks/deep-metric-ts-clustering/notebooks/model/model_GRU_absolute_best.pth", map_location=torch.device("cuda")))
model.to("cuda")

model.eval()

with torch.no_grad():
    # as all embeddings need to be stored in memory
    # you can set DEVICE = torch.device('cpu') in case gpu memory overflow occurs
    accuracy = get_accuracy(val_dataset, train_dataset, model, device.type)
    
    print(f"Val accuracy: {accuracy}")

RuntimeError: Error(s) in loading state_dict for Encoder:
	Missing key(s) in state_dict: "lstm.weight_ih_l0", "lstm.weight_hh_l0", "lstm.bias_ih_l0", "lstm.bias_hh_l0", "lstm.weight_ih_l1", "lstm.weight_hh_l1", "lstm.bias_ih_l1", "lstm.bias_hh_l1", "lstm.weight_ih_l2", "lstm.weight_hh_l2", "lstm.bias_ih_l2", "lstm.bias_hh_l2", "lstm.weight_ih_l3", "lstm.weight_hh_l3", "lstm.bias_ih_l3", "lstm.bias_hh_l3", "lstm.weight_ih_l4", "lstm.weight_hh_l4", "lstm.bias_ih_l4", "lstm.bias_hh_l4". 
	Unexpected key(s) in state_dict: "gru.weight_ih_l0", "gru.weight_hh_l0", "gru.bias_ih_l0", "gru.bias_hh_l0", "gru.weight_ih_l1", "gru.weight_hh_l1", "gru.bias_ih_l1", "gru.bias_hh_l1", "gru.weight_ih_l2", "gru.weight_hh_l2", "gru.bias_ih_l2", "gru.bias_hh_l2", "gru.weight_ih_l3", "gru.weight_hh_l3", "gru.bias_ih_l3", "gru.bias_hh_l3", "gru.weight_ih_l4", "gru.weight_hh_l4", "gru.bias_ih_l4", "gru.bias_hh_l4". 

2025-05-27 22:01:46.628948
