In [1]:
import yfinance as yf
import pandas as pd
import numpy as np
import random
import torch

import sys
sys.path.append('../src/utils') 
from common import get_accuracy
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every RNG the notebook uses (python `random`, numpy, torch) so model
# initialisation and DataLoader shuffling are reproducible after a kernel restart.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Checkpoint directory. Override it with the DMTC_SAVE_PATH environment variable
# rather than editing this machine-specific default.
SAVE_PATH = os.environ.get(
    "DMTC_SAVE_PATH",
    "/home/s/python_progs/DL_homeworks/deep-metric-ts-clustering/notebooks/model",
)

In [2]:
# Select diverse stocks from different sectors; each ticker is its own class label.
tickers = ['AAPL', 'MSFT', 'GOOGL', 'TSLA', 'AMZN', 'JPM', 'JNJ', 'XOM', 'NVDA', 'WMT']
ticker_to_class = {ticker: i for i, ticker in enumerate(tickers)}

# auto_adjust is pinned explicitly: yfinance changed its default to True, and
# pinning it keeps the 'Close' column's meaning stable across library versions.
# Dates where any ticker is missing are dropped so all series share one index.
data = yf.download(tickers, start="2020-01-01", end="2024-01-01", auto_adjust=True)['Close'].dropna()
print(data)

YF.download() has changed argument auto_adjust default to True


[*********************100%***********************]  10 of 10 completed

Ticker            AAPL        AMZN       GOOGL         JNJ         JPM  \
Date                                                                     
2020-01-02   72.620834   94.900497   68.108376  126.055138  120.733566   
2020-01-03   71.914810   93.748497   67.752075  124.595726  119.140312   
2020-01-06   72.487854   95.143997   69.557945  124.440300  119.045601   
2020-01-07   72.146927   95.343002   69.423592  125.200226  117.021736   
2020-01-08   73.307503   94.598503   69.917725  125.182991  117.934631   
...                ...         ...         ...         ...         ...   
2023-12-22  192.192551  153.419998  140.816757  149.492691  161.660004   
2023-12-26  191.646561  153.410004  140.846634  150.146591  162.616074   
2023-12-27  191.745819  153.339996  139.702087  150.348541  163.591431   
2023-12-28  192.172699  153.380005  139.562759  150.569702  164.460587   
2023-12-29  191.130310  151.940002  139.025330  150.723557  164.267441   

Ticker            MSFT       NVDA    




In [3]:
def create_windows(series, window_size=30, step=1):
    """Cut a 1-D sequence into fixed-length sliding windows.

    Args:
        series: 1-D sequence of values (pandas Series, numpy array, list, ...).
            Windows are taken by position, never by index label.
        window_size: number of consecutive values per window (must be > 0).
        step: offset between the starts of consecutive windows (must be > 0).

    Returns:
        list of numpy arrays of length ``window_size``. The list is empty when the
        series is shorter than ``window_size``. The arrays are slices (views) of
        one shared array, so copy them before mutating.

    Raises:
        ValueError: if ``window_size`` or ``step`` is not positive.
    """
    if window_size <= 0:
        raise ValueError(f"window_size must be positive, got {window_size}")
    if step <= 0:
        raise ValueError(f"step must be positive, got {step}")

    # Convert once, then slice positionally. This avoids per-window pandas
    # indexing, whose integer-slice semantics depend on the Series index type.
    values = np.asarray(series)
    last_start = len(values) - window_size
    return [values[start:start + window_size] for start in range(0, last_start + 1, step)]


def create_dataset(windows, class_map=None):
    """Turn window records into ``(window_values, class_index)`` pairs.

    Args:
        windows: iterable of dicts with keys ``'window'`` (the sample values) and
            ``'ticker'`` (the symbol the window was cut from).
        class_map: optional mapping ticker -> integer class. When omitted, the
            notebook-level ``ticker_to_class`` is used, so existing calls are unchanged.

    Returns:
        list of ``(window, label)`` tuples, in the same order as ``windows``.

    Raises:
        KeyError: if a record's ticker is not in ``class_map``.
    """
    if class_map is None:
        class_map = ticker_to_class
    return [(record['window'], class_map[record['ticker']]) for record in windows]


In [4]:
window_size = 30
step = 5  # window starts are 5 days apart, so neighbouring windows still share 25 of 30 days
# Flat list of {'ticker': ..., 'window': ...} records: all windows of the first
# ticker, then all windows of the next, and so on.
stock_windows = [
    {'ticker': ticker, 'window': window}
    for ticker in tickers
    for window in create_windows(data[ticker], window_size, step)
]
print(stock_windows[0])

{'ticker': 'AAPL', 'window': array([72.62083435, 71.91481018, 72.487854  , 72.14692688, 73.30750275,
       74.86463928, 75.03386688, 76.63692474, 75.6020813 , 75.27806091,
       76.22105408, 77.06490326, 76.5426178 , 76.81585693, 77.18579865,
       76.96334839, 74.70020294, 76.81342316, 78.42131042, 78.30766296,
       74.83559418, 74.63008881, 77.09391022, 77.72254181, 78.63166046,
       77.56287384, 77.93125153, 77.46105194, 79.30059052, 78.73588562])}


In [5]:
dataset = create_dataset(windows=stock_windows)  # (window, class index) pairs, one per window record

In [6]:
import torch
from torch.utils.data import Dataset
from torch.utils.data import random_split

class StockDataset(Dataset):
    """Map-style dataset over ``(window, label)`` pairs.

    Each item is ``(x, label)``:
      * ``x`` is a float32 tensor of shape ``[1, window_len]`` (a single channel).
      * ``label`` is an int64 class-index tensor.
    """

    def __init__(self, stock_windows):
        self.stock_windows = stock_windows
        # Plain-list copy of the labels for callers that need them without loading items.
        self.labels = [label for _, label in stock_windows]

    def __len__(self):
        return len(self.stock_windows)

    def __getitem__(self, idx):
        x, label = self.stock_windows[idx]
        x = torch.tensor(x, dtype=torch.float32).unsqueeze(0)  # shape [1, window_len]
        # Labels are categorical class indices, so store them as integers rather than floats.
        label = torch.tensor(label, dtype=torch.long)
        return x, label

# Bind to a new name so re-running this cell does not wrap an already-wrapped
# StockDataset. `dataset` keeps the list of pairs built by the previous cell.
stock_dataset = StockDataset(dataset)

train_ratio = 0.8
train_size = int(train_ratio * len(stock_dataset))
val_size = len(stock_dataset) - train_size

# Fixed-seed generator so the train/val partition is identical on every run.
# Otherwise a checkpoint saved in one run is later scored on windows it was
# trained on, which inflates validation accuracy.
# NOTE(review): windows overlap (step < window_size) and are split at random, so
# validation windows share most of their days with training windows of the same
# ticker. A chronological split would give a less optimistic estimate.
SPLIT_SEED = 42
train_dataset, val_dataset = random_split(
    stock_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)



In [7]:
import torch.nn as nn


class Encoder(nn.Module):
    """Embed a univariate window into a fixed-size vector for metric learning.

    Args:
        input_dim: window length (number of time steps). It sizes the CNN's final
            linear layer; the recurrent encoders accept any sequence length.
        encoder_type: one of ``"CNN"``, ``"GRU"`` or ``"LSTM"``.
        embedding_dim: size of the output embedding.

    Raises:
        ValueError: if ``encoder_type`` is not supported. Previously such a model
            was built with no layers and ``forward`` returned its input unchanged.

    ``forward`` expects ``x`` shaped ``[batch, 1, seq_len]`` (a single channel) and
    returns ``[batch, embedding_dim]``.
    """

    SUPPORTED_TYPES = ("CNN", "GRU", "LSTM")

    def __init__(self, input_dim, encoder_type, embedding_dim=128):
        super().__init__()
        if encoder_type not in self.SUPPORTED_TYPES:
            raise ValueError(
                f"Unsupported encoder_type {encoder_type!r}; expected one of {self.SUPPORTED_TYPES}"
            )

        self.input_dim = input_dim
        self.encoder_type = encoder_type
        self.embedding_dim = embedding_dim

        if encoder_type == "CNN":
            self.conv_layers = nn.Sequential(
                nn.Conv1d(in_channels=1, out_channels=16, kernel_size=5, padding=2),
                nn.ReLU(),
                nn.MaxPool1d(2),
                nn.Conv1d(16, 32, 5, padding=2),
                nn.ReLU(),
                nn.MaxPool1d(2),
            )
            # kernel 5 with padding 2 keeps the length. Each MaxPool1d(2) floor-halves
            # it, so 32 channels x (input_dim // 4) steps reach the linear layer.
            self.fc = nn.Linear(32 * (input_dim // 4), embedding_dim)
        elif encoder_type == "GRU":
            self.num_layers = 5
            self.hidden_size = 256
            self.gru = nn.GRU(1, self.hidden_size, self.num_layers, batch_first=True)
            self.fc = nn.Linear(self.hidden_size, self.embedding_dim)
        else:  # "LSTM" (validated above)
            self.num_layers = 2
            self.hidden_size = 256
            self.lstm = nn.LSTM(1, self.hidden_size, self.num_layers, batch_first=True)
            self.fc = nn.Linear(self.hidden_size, self.embedding_dim)

    def forward(self, x):
        """Map ``x`` of shape ``[batch, 1, seq_len]`` to ``[batch, embedding_dim]``."""
        if self.encoder_type == "CNN":
            x = self.conv_layers(x)
            x = x.flatten(start_dim=1)
            return self.fc(x)

        # Recurrent layers take [batch, seq_len, features=1]. With a single input
        # channel, this reshape equals transposing [batch, 1, seq_len]. The batch
        # size comes from x itself, not from a global window length.
        x = x.reshape(x.size(0), -1, 1)
        # Zero initial states on x's own device/dtype, instead of a notebook-global device.
        h0 = x.new_zeros(self.num_layers, x.size(0), self.hidden_size)
        if self.encoder_type == "GRU":
            out, _ = self.gru(x, h0)
        else:
            c0 = x.new_zeros(self.num_layers, x.size(0), self.hidden_size)
            out, _ = self.lstm(x, (h0, c0))
        # Embed from the hidden output at the last time step.
        return self.fc(out[:, -1, :])




In [8]:

from torch.utils.data import DataLoader
from pytorch_metric_learning import distances, losses, miners, reducers, testers
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator

# Triplet loss on cosine similarity, trained on semi-hard triplets from the miner.
MARGIN = 0.2  # shared by the miner and the loss so the two cannot drift apart
distance = distances.CosineSimilarity()
reducer = reducers.ThresholdReducer(low=0)
mining_func = miners.TripletMarginMiner(margin=MARGIN, distance=distance, type_of_triplets="semihard")
loss_func = losses.TripletMarginLoss(margin=MARGIN, distance=distance, reducer=reducer)

# Architecture to train: "CNN", "GRU" or "LSTM".
ENCODER_TYPE = "GRU"
model = Encoder(window_size, ENCODER_TYPE)

model.to(device)

Encoder(
  (gru): GRU(1, 256, num_layers=5, batch_first=True)
  (fc): Linear(in_features=256, out_features=128, bias=True)
)

In [10]:
import torch
import torch.nn.functional as F
from torch.amp import autocast, GradScaler
import torch.optim as optim
from tqdm import tqdm
import json

import os

batch_size = 512
epochs = 1000
eval_every = 10  # validate/checkpoint every `eval_every` epochs, and after the final epoch

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
history = {"train": [], "val": [], "best_accuracy": 0.0}

# fp16 autocast and gradient scaling are a CUDA feature. On CPU both are disabled,
# so training runs in plain fp32 instead of warning or misbehaving.
use_amp = device.type == "cuda"
scaler = GradScaler(device.type, enabled=use_amp)

optimizer = optim.Adam(model.parameters(), lr=1e-3)
os.makedirs(SAVE_PATH, exist_ok=True)  # torch.save does not create missing directories

for epoch in range(epochs):
    model.train()
    epoch_loss = 0.0

    for x, label in train_loader:
        x, label = x.to(device), label.to(device)

        optimizer.zero_grad()

        with autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp):
            embeddings = model(x)
            indices_tuple = mining_func(embeddings, label)
            loss = loss_func(embeddings, label, indices_tuple)

        # Backpropagation with scaled gradients (pass-through when AMP is disabled)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        epoch_loss += loss.item()

    avg_loss = epoch_loss / len(train_loader)
    history["train"].append({"epoch": epoch, "loss": avg_loss})

    if epoch % eval_every == 0 or epoch == epochs - 1:
        # model validation
        model.eval()
        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")

        with torch.no_grad():
            # All embeddings are held in memory during evaluation; pass 'cpu'
            # as the device argument if GPU memory overflows.
            accuracy = get_accuracy(val_dataset, train_dataset, model, device.type)

        history["val"].append({"epoch": epoch, "accuracy": accuracy})
        print(f"Val accuracy: {accuracy}")

        # save model
        torch.save(model.state_dict(), f"{SAVE_PATH}/model_latest.pth")

        if accuracy >= history["best_accuracy"]:
            history["best_accuracy"] = accuracy
            torch.save(model.state_dict(), f"{SAVE_PATH}/model_best.pth")

        with open("history.json", "w") as f:
            json.dump(history, f)

Epoch 1/1000, Loss: 0.0956


100%|██████████| 13/13 [00:00<00:00, 120.48it/s]
100%|██████████| 49/49 [00:00<00:00, 290.14it/s]


Val accuracy: 0.6632653061224489
Epoch 11/1000, Loss: 0.0844


100%|██████████| 13/13 [00:00<00:00, 145.20it/s]
100%|██████████| 49/49 [00:00<00:00, 235.16it/s]


Val accuracy: 0.8341836734693877
Epoch 21/1000, Loss: 0.0945


100%|██████████| 13/13 [00:00<00:00, 136.08it/s]
100%|██████████| 49/49 [00:00<00:00, 319.65it/s]


Val accuracy: 0.826530612244898
Epoch 31/1000, Loss: 0.0772


100%|██████████| 13/13 [00:00<00:00, 144.11it/s]
100%|██████████| 49/49 [00:00<00:00, 354.84it/s]


Val accuracy: 0.8494897959183674
Epoch 41/1000, Loss: 0.0789


100%|██████████| 13/13 [00:00<00:00, 175.10it/s]
100%|██████████| 49/49 [00:00<00:00, 280.49it/s]


Val accuracy: 0.8698979591836735
Epoch 51/1000, Loss: 0.0664


100%|██████████| 13/13 [00:00<00:00, 113.35it/s]
100%|██████████| 49/49 [00:00<00:00, 307.10it/s]


Val accuracy: 0.8443877551020408
Epoch 61/1000, Loss: 0.0555


100%|██████████| 13/13 [00:00<00:00, 120.93it/s]
100%|██████████| 49/49 [00:00<00:00, 282.21it/s]


Val accuracy: 0.875
Epoch 71/1000, Loss: 0.0812


100%|██████████| 13/13 [00:00<00:00, 125.97it/s]
100%|██████████| 49/49 [00:00<00:00, 352.53it/s]


Val accuracy: 0.8392857142857143
Epoch 81/1000, Loss: 0.0716


100%|██████████| 13/13 [00:00<00:00, 127.84it/s]
100%|██████████| 49/49 [00:00<00:00, 294.96it/s]


Val accuracy: 0.8571428571428571
Epoch 91/1000, Loss: 0.0572


100%|██████████| 13/13 [00:00<00:00, 164.29it/s]
100%|██████████| 49/49 [00:00<00:00, 342.90it/s]


Val accuracy: 0.8698979591836735
Epoch 101/1000, Loss: 0.0533


100%|██████████| 13/13 [00:00<00:00, 177.19it/s]
100%|██████████| 49/49 [00:00<00:00, 243.43it/s]


Val accuracy: 0.8494897959183674
Epoch 111/1000, Loss: 0.0523


100%|██████████| 13/13 [00:00<00:00, 121.70it/s]
100%|██████████| 49/49 [00:00<00:00, 380.15it/s]


Val accuracy: 0.8596938775510204
Epoch 121/1000, Loss: 0.0734


100%|██████████| 13/13 [00:00<00:00, 121.90it/s]
100%|██████████| 49/49 [00:00<00:00, 297.59it/s]


Val accuracy: 0.8469387755102041
Epoch 131/1000, Loss: 0.0874


100%|██████████| 13/13 [00:00<00:00, 128.47it/s]
100%|██████████| 49/49 [00:00<00:00, 277.04it/s]


Val accuracy: 0.8545918367346939
Epoch 141/1000, Loss: 0.0872


100%|██████████| 13/13 [00:00<00:00, 165.02it/s]
100%|██████████| 49/49 [00:00<00:00, 286.03it/s]


Val accuracy: 0.875
Epoch 151/1000, Loss: 0.0780


100%|██████████| 13/13 [00:00<00:00, 136.77it/s]
100%|██████████| 49/49 [00:00<00:00, 351.13it/s]


Val accuracy: 0.8698979591836735
Epoch 161/1000, Loss: 0.0504


100%|██████████| 13/13 [00:00<00:00, 134.48it/s]
100%|██████████| 49/49 [00:00<00:00, 311.41it/s]


Val accuracy: 0.8596938775510204
Epoch 171/1000, Loss: 0.0589


100%|██████████| 13/13 [00:00<00:00, 127.72it/s]
100%|██████████| 49/49 [00:00<00:00, 300.05it/s]


Val accuracy: 0.8698979591836735
Epoch 181/1000, Loss: 0.0496


100%|██████████| 13/13 [00:00<00:00, 117.89it/s]
100%|██████████| 49/49 [00:00<00:00, 294.34it/s]


Val accuracy: 0.8494897959183674
Epoch 191/1000, Loss: 0.0364


100%|██████████| 13/13 [00:00<00:00, 130.39it/s]
100%|██████████| 49/49 [00:00<00:00, 295.75it/s]


Val accuracy: 0.8367346938775511
Epoch 201/1000, Loss: 0.0491


100%|██████████| 13/13 [00:00<00:00, 169.79it/s]
100%|██████████| 49/49 [00:00<00:00, 239.94it/s]


Val accuracy: 0.8673469387755102
Epoch 211/1000, Loss: 0.0817


100%|██████████| 13/13 [00:00<00:00, 127.92it/s]
100%|██████████| 49/49 [00:00<00:00, 372.05it/s]


Val accuracy: 0.798469387755102
Epoch 221/1000, Loss: 0.0569


100%|██████████| 13/13 [00:00<00:00, 161.12it/s]
100%|██████████| 49/49 [00:00<00:00, 265.84it/s]


Val accuracy: 0.8571428571428571
Epoch 231/1000, Loss: 0.0619


100%|██████████| 13/13 [00:00<00:00, 170.26it/s]
100%|██████████| 49/49 [00:00<00:00, 422.83it/s]

Val accuracy: 0.8877551020408163





Epoch 241/1000, Loss: 0.0729


100%|██████████| 13/13 [00:00<00:00, 132.55it/s]
100%|██████████| 49/49 [00:00<00:00, 397.46it/s]


Val accuracy: 0.826530612244898
Epoch 251/1000, Loss: 0.0925


100%|██████████| 13/13 [00:00<00:00, 136.23it/s]
100%|██████████| 49/49 [00:00<00:00, 325.65it/s]


Val accuracy: 0.8316326530612245
Epoch 261/1000, Loss: 0.0648


100%|██████████| 13/13 [00:00<00:00, 128.54it/s]
100%|██████████| 49/49 [00:00<00:00, 226.18it/s]


Val accuracy: 0.8647959183673469
Epoch 271/1000, Loss: 0.0677


100%|██████████| 13/13 [00:00<00:00, 117.34it/s]
100%|██████████| 49/49 [00:00<00:00, 286.50it/s]


Val accuracy: 0.8392857142857143
Epoch 281/1000, Loss: 0.0546


100%|██████████| 13/13 [00:00<00:00, 146.70it/s]
100%|██████████| 49/49 [00:00<00:00, 286.58it/s]


Val accuracy: 0.8494897959183674
Epoch 291/1000, Loss: 0.0973


100%|██████████| 13/13 [00:00<00:00, 136.59it/s]
100%|██████████| 49/49 [00:00<00:00, 272.12it/s]


Val accuracy: 0.8163265306122449
Epoch 301/1000, Loss: 0.0594


100%|██████████| 13/13 [00:00<00:00, 131.28it/s]
100%|██████████| 49/49 [00:00<00:00, 291.54it/s]


Val accuracy: 0.8852040816326531
Epoch 311/1000, Loss: 0.0578


100%|██████████| 13/13 [00:00<00:00, 148.54it/s]
100%|██████████| 49/49 [00:00<00:00, 314.77it/s]


Val accuracy: 0.8545918367346939
Epoch 321/1000, Loss: 0.0604


100%|██████████| 13/13 [00:00<00:00, 138.99it/s]
100%|██████████| 49/49 [00:00<00:00, 310.71it/s]


Val accuracy: 0.8520408163265306
Epoch 331/1000, Loss: 0.0697


100%|██████████| 13/13 [00:00<00:00, 145.72it/s]
100%|██████████| 49/49 [00:00<00:00, 395.42it/s]


Val accuracy: 0.8622448979591837
Epoch 341/1000, Loss: 0.0369


100%|██████████| 13/13 [00:00<00:00, 161.15it/s]
100%|██████████| 49/49 [00:00<00:00, 292.24it/s]


Val accuracy: 0.875
Epoch 351/1000, Loss: 0.0737


100%|██████████| 13/13 [00:00<00:00, 125.88it/s]
100%|██████████| 49/49 [00:00<00:00, 340.22it/s]


Val accuracy: 0.8775510204081632
Epoch 361/1000, Loss: 0.0398


100%|██████████| 13/13 [00:00<00:00, 162.61it/s]
100%|██████████| 49/49 [00:00<00:00, 316.59it/s]


Val accuracy: 0.8826530612244898
Epoch 371/1000, Loss: 0.0498


100%|██████████| 13/13 [00:00<00:00, 104.11it/s]
100%|██████████| 49/49 [00:00<00:00, 286.69it/s]


Val accuracy: 0.8571428571428571
Epoch 381/1000, Loss: 0.0468


100%|██████████| 13/13 [00:00<00:00, 138.67it/s]
100%|██████████| 49/49 [00:00<00:00, 349.78it/s]


Val accuracy: 0.8877551020408163
Epoch 391/1000, Loss: 0.0558


100%|██████████| 13/13 [00:00<00:00, 136.11it/s]
100%|██████████| 49/49 [00:00<00:00, 424.69it/s]


Val accuracy: 0.826530612244898
Epoch 401/1000, Loss: 0.0443


100%|██████████| 13/13 [00:00<00:00, 96.35it/s] 
100%|██████████| 49/49 [00:00<00:00, 302.61it/s]


Val accuracy: 0.8826530612244898
Epoch 411/1000, Loss: 0.0300


100%|██████████| 13/13 [00:00<00:00, 164.08it/s]
100%|██████████| 49/49 [00:00<00:00, 272.49it/s]


Val accuracy: 0.8826530612244898
Epoch 421/1000, Loss: 0.0487


100%|██████████| 13/13 [00:00<00:00, 135.51it/s]
100%|██████████| 49/49 [00:00<00:00, 386.31it/s]


Val accuracy: 0.9005102040816326
Epoch 431/1000, Loss: 0.0557


100%|██████████| 13/13 [00:00<00:00, 142.02it/s]
100%|██████████| 49/49 [00:00<00:00, 317.76it/s]


Val accuracy: 0.8622448979591837
Epoch 441/1000, Loss: 0.0564


100%|██████████| 13/13 [00:00<00:00, 127.92it/s]
100%|██████████| 49/49 [00:00<00:00, 396.87it/s]


Val accuracy: 0.8545918367346939
Epoch 451/1000, Loss: 0.0373


100%|██████████| 13/13 [00:00<00:00, 106.80it/s]
100%|██████████| 49/49 [00:00<00:00, 270.87it/s]


Val accuracy: 0.8596938775510204
Epoch 461/1000, Loss: 0.0575


100%|██████████| 13/13 [00:00<00:00, 158.06it/s]
100%|██████████| 49/49 [00:00<00:00, 331.71it/s]


Val accuracy: 0.8877551020408163
Epoch 471/1000, Loss: 0.0591


100%|██████████| 13/13 [00:00<00:00, 143.68it/s]
100%|██████████| 49/49 [00:00<00:00, 305.26it/s]


Val accuracy: 0.8775510204081632
Epoch 481/1000, Loss: 0.0582


100%|██████████| 13/13 [00:00<00:00, 121.93it/s]
100%|██████████| 49/49 [00:00<00:00, 261.00it/s]


Val accuracy: 0.8647959183673469
Epoch 491/1000, Loss: 0.0541


100%|██████████| 13/13 [00:00<00:00, 116.40it/s]
100%|██████████| 49/49 [00:00<00:00, 321.03it/s]


Val accuracy: 0.8928571428571429
Epoch 501/1000, Loss: 0.0637


100%|██████████| 13/13 [00:00<00:00, 147.40it/s]
100%|██████████| 49/49 [00:00<00:00, 302.03it/s]


Val accuracy: 0.8801020408163265
Epoch 511/1000, Loss: 0.0672


100%|██████████| 13/13 [00:00<00:00, 149.61it/s]
100%|██████████| 49/49 [00:00<00:00, 318.63it/s]


Val accuracy: 0.8469387755102041
Epoch 521/1000, Loss: 0.0623


100%|██████████| 13/13 [00:00<00:00, 130.81it/s]
100%|██████████| 49/49 [00:00<00:00, 277.50it/s]


Val accuracy: 0.8979591836734694
Epoch 531/1000, Loss: 0.0477


100%|██████████| 13/13 [00:00<00:00, 127.29it/s]
100%|██████████| 49/49 [00:00<00:00, 273.98it/s]


Val accuracy: 0.8698979591836735
Epoch 541/1000, Loss: 0.0046


100%|██████████| 13/13 [00:00<00:00, 144.04it/s]
100%|██████████| 49/49 [00:00<00:00, 424.47it/s]


Val accuracy: 0.9056122448979592
Epoch 551/1000, Loss: 0.0594


100%|██████████| 13/13 [00:00<00:00, 137.07it/s]
100%|██████████| 49/49 [00:00<00:00, 283.34it/s]


Val accuracy: 0.8877551020408163
Epoch 561/1000, Loss: 0.0505


100%|██████████| 13/13 [00:00<00:00, 126.18it/s]
100%|██████████| 49/49 [00:00<00:00, 386.03it/s]


Val accuracy: 0.8979591836734694
Epoch 571/1000, Loss: 0.0504


100%|██████████| 13/13 [00:00<00:00, 155.17it/s]
100%|██████████| 49/49 [00:00<00:00, 250.24it/s]


Val accuracy: 0.8903061224489796
Epoch 581/1000, Loss: 0.0427


100%|██████████| 13/13 [00:00<00:00, 136.83it/s]
100%|██████████| 49/49 [00:00<00:00, 299.03it/s]


Val accuracy: 0.8877551020408163
Epoch 591/1000, Loss: 0.0542


100%|██████████| 13/13 [00:00<00:00, 151.39it/s]
100%|██████████| 49/49 [00:00<00:00, 319.59it/s]


Val accuracy: 0.8979591836734694
Epoch 601/1000, Loss: 0.0445


100%|██████████| 13/13 [00:00<00:00, 137.38it/s]
100%|██████████| 49/49 [00:00<00:00, 254.49it/s]


Val accuracy: 0.9132653061224489
Epoch 611/1000, Loss: 0.0157


100%|██████████| 13/13 [00:00<00:00, 140.77it/s]
100%|██████████| 49/49 [00:00<00:00, 298.23it/s]


Val accuracy: 0.8954081632653061
Epoch 621/1000, Loss: 0.0648


100%|██████████| 13/13 [00:00<00:00, 131.39it/s]
100%|██████████| 49/49 [00:00<00:00, 289.21it/s]


Val accuracy: 0.9132653061224489
Epoch 631/1000, Loss: 0.0362


100%|██████████| 13/13 [00:00<00:00, 117.15it/s]
100%|██████████| 49/49 [00:00<00:00, 265.21it/s]


Val accuracy: 0.8801020408163265
Epoch 641/1000, Loss: 0.1028


100%|██████████| 13/13 [00:00<00:00, 133.99it/s]
100%|██████████| 49/49 [00:00<00:00, 302.63it/s]


Val accuracy: 0.8724489795918368
Epoch 651/1000, Loss: 0.0513


100%|██████████| 13/13 [00:00<00:00, 144.53it/s]
100%|██████████| 49/49 [00:00<00:00, 287.01it/s]


Val accuracy: 0.8673469387755102
Epoch 661/1000, Loss: 0.0447


100%|██████████| 13/13 [00:00<00:00, 121.99it/s]
100%|██████████| 49/49 [00:00<00:00, 409.49it/s]


Val accuracy: 0.8826530612244898
Epoch 671/1000, Loss: 0.0206


100%|██████████| 13/13 [00:00<00:00, 131.75it/s]
100%|██████████| 49/49 [00:00<00:00, 255.20it/s]


Val accuracy: 0.8801020408163265
Epoch 681/1000, Loss: 0.0363


100%|██████████| 13/13 [00:00<00:00, 172.25it/s]
100%|██████████| 49/49 [00:00<00:00, 335.36it/s]


Val accuracy: 0.9056122448979592
Epoch 691/1000, Loss: 0.0456


100%|██████████| 13/13 [00:00<00:00, 131.22it/s]
100%|██████████| 49/49 [00:00<00:00, 303.18it/s]


Val accuracy: 0.8852040816326531
Epoch 701/1000, Loss: 0.0291


100%|██████████| 13/13 [00:00<00:00, 128.07it/s]
100%|██████████| 49/49 [00:00<00:00, 250.14it/s]


Val accuracy: 0.9183673469387755
Epoch 711/1000, Loss: 0.0653


100%|██████████| 13/13 [00:00<00:00, 141.87it/s]
100%|██████████| 49/49 [00:00<00:00, 287.17it/s]


Val accuracy: 0.8928571428571429
Epoch 721/1000, Loss: 0.0826


100%|██████████| 13/13 [00:00<00:00, 129.04it/s]
100%|██████████| 49/49 [00:00<00:00, 258.04it/s]


Val accuracy: 0.8214285714285714
Epoch 731/1000, Loss: 0.0512


100%|██████████| 13/13 [00:00<00:00, 160.05it/s]
100%|██████████| 49/49 [00:00<00:00, 285.24it/s]


Val accuracy: 0.8877551020408163
Epoch 741/1000, Loss: 0.0571


100%|██████████| 13/13 [00:00<00:00, 118.77it/s]
100%|██████████| 49/49 [00:00<00:00, 288.08it/s]


Val accuracy: 0.9158163265306123
Epoch 751/1000, Loss: 0.0443


100%|██████████| 13/13 [00:00<00:00, 126.77it/s]
100%|██████████| 49/49 [00:00<00:00, 309.91it/s]


Val accuracy: 0.8852040816326531
Epoch 761/1000, Loss: 0.0419


100%|██████████| 13/13 [00:00<00:00, 143.94it/s]
100%|██████████| 49/49 [00:00<00:00, 293.99it/s]


Val accuracy: 0.8622448979591837
Epoch 771/1000, Loss: 0.0368


100%|██████████| 13/13 [00:00<00:00, 124.38it/s]
100%|██████████| 49/49 [00:00<00:00, 332.84it/s]


Val accuracy: 0.8903061224489796
Epoch 781/1000, Loss: 0.0707


100%|██████████| 13/13 [00:00<00:00, 125.15it/s]
100%|██████████| 49/49 [00:00<00:00, 254.47it/s]


Val accuracy: 0.875
Epoch 791/1000, Loss: 0.0369


100%|██████████| 13/13 [00:00<00:00, 169.70it/s]
100%|██████████| 49/49 [00:00<00:00, 302.13it/s]


Val accuracy: 0.9209183673469388
Epoch 801/1000, Loss: 0.0458


100%|██████████| 13/13 [00:00<00:00, 149.43it/s]
100%|██████████| 49/49 [00:00<00:00, 380.76it/s]


Val accuracy: 0.9056122448979592
Epoch 811/1000, Loss: 0.0467


100%|██████████| 13/13 [00:00<00:00, 124.50it/s]
100%|██████████| 49/49 [00:00<00:00, 301.12it/s]


Val accuracy: 0.8852040816326531
Epoch 821/1000, Loss: 0.0267


100%|██████████| 13/13 [00:00<00:00, 134.43it/s]
100%|██████████| 49/49 [00:00<00:00, 247.51it/s]


Val accuracy: 0.9081632653061225
Epoch 831/1000, Loss: 0.0474


100%|██████████| 13/13 [00:00<00:00, 130.90it/s]
100%|██████████| 49/49 [00:00<00:00, 295.63it/s]


Val accuracy: 0.8877551020408163
Epoch 841/1000, Loss: 0.0148


100%|██████████| 13/13 [00:00<00:00, 165.05it/s]
100%|██████████| 49/49 [00:00<00:00, 332.90it/s]


Val accuracy: 0.8954081632653061
Epoch 851/1000, Loss: 0.0648


100%|██████████| 13/13 [00:00<00:00, 149.01it/s]
100%|██████████| 49/49 [00:00<00:00, 314.83it/s]


Val accuracy: 0.8571428571428571
Epoch 861/1000, Loss: 0.0774


100%|██████████| 13/13 [00:00<00:00, 125.28it/s]
100%|██████████| 49/49 [00:00<00:00, 277.02it/s]


Val accuracy: 0.8341836734693877
Epoch 871/1000, Loss: 0.0594


100%|██████████| 13/13 [00:00<00:00, 127.57it/s]
100%|██████████| 49/49 [00:00<00:00, 273.55it/s]


Val accuracy: 0.8673469387755102
Epoch 881/1000, Loss: 0.0408


100%|██████████| 13/13 [00:00<00:00, 170.63it/s]
100%|██████████| 49/49 [00:00<00:00, 303.32it/s]


Val accuracy: 0.9056122448979592
Epoch 891/1000, Loss: 0.0951


100%|██████████| 13/13 [00:00<00:00, 129.08it/s]
100%|██████████| 49/49 [00:00<00:00, 368.74it/s]


Val accuracy: 0.8979591836734694
Epoch 901/1000, Loss: 0.1096


100%|██████████| 13/13 [00:00<00:00, 135.32it/s]
100%|██████████| 49/49 [00:00<00:00, 288.81it/s]


Val accuracy: 0.5382653061224489
Epoch 911/1000, Loss: 0.1056


100%|██████████| 13/13 [00:00<00:00, 142.48it/s]
100%|██████████| 49/49 [00:00<00:00, 351.99it/s]


Val accuracy: 0.6275510204081632
Epoch 921/1000, Loss: 0.0971


100%|██████████| 13/13 [00:00<00:00, 166.01it/s]
100%|██████████| 49/49 [00:00<00:00, 274.19it/s]


Val accuracy: 0.7525510204081632
Epoch 931/1000, Loss: 0.0644


100%|██████████| 13/13 [00:00<00:00, 130.16it/s]
100%|██████████| 49/49 [00:00<00:00, 333.51it/s]


Val accuracy: 0.8826530612244898
Epoch 941/1000, Loss: 0.0488


100%|██████████| 13/13 [00:00<00:00, 132.05it/s]
100%|██████████| 49/49 [00:00<00:00, 260.83it/s]


Val accuracy: 0.8903061224489796
Epoch 951/1000, Loss: 0.0376


100%|██████████| 13/13 [00:00<00:00, 139.55it/s]
100%|██████████| 49/49 [00:00<00:00, 383.43it/s]


Val accuracy: 0.9081632653061225
Epoch 961/1000, Loss: 0.0563


100%|██████████| 13/13 [00:00<00:00, 156.33it/s]
100%|██████████| 49/49 [00:00<00:00, 292.35it/s]


Val accuracy: 0.9030612244897959
Epoch 971/1000, Loss: 0.0415


100%|██████████| 13/13 [00:00<00:00, 144.15it/s]
100%|██████████| 49/49 [00:00<00:00, 319.71it/s]


Val accuracy: 0.8979591836734694
Epoch 981/1000, Loss: 0.0353


100%|██████████| 13/13 [00:00<00:00, 156.73it/s]
100%|██████████| 49/49 [00:00<00:00, 385.67it/s]


Val accuracy: 0.8698979591836735
Epoch 991/1000, Loss: 0.0488


100%|██████████| 13/13 [00:00<00:00, 104.44it/s]
100%|██████████| 49/49 [00:00<00:00, 311.30it/s]


Val accuracy: 0.875


In [9]:

# NOTE(review): this checkpoint comes from an earlier run. Unless that run used the
# same (seeded) train/val split, part of the current `val_dataset` was training data
# for it, and the accuracy printed below is optimistic.
checkpoint_path = f"{SAVE_PATH}/model_GRU_absolute_best.pth"
# weights_only=True restricts unpickling to tensors/containers (no arbitrary code).
state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)  # follow the notebook-wide device instead of assuming CUDA

model.eval()

with torch.no_grad():
    # All embeddings are held in memory during evaluation; pass 'cpu' as the
    # device argument if GPU memory overflows.
    accuracy = get_accuracy(val_dataset, train_dataset, model, device.type)

    print(f"Val accuracy: {accuracy}")

100%|██████████| 13/13 [00:00<00:00, 114.85it/s]
100%|██████████| 49/49 [00:00<00:00, 380.05it/s]

Val accuracy: 0.9821428571428571



