In [2]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import mne
import pywt
from torch.utils.data import DataLoader, Dataset
from torch import nn
import torch
import pickle
import os
from datetime import datetime
import sqlite3

In [3]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f'Training on device: {device}')

Training on device: cuda


In [4]:
def read_cwt_data(db_path):
    """Load every stored CWT array from the ``wavelet_transforms`` table.

    Args:
        db_path: Path of the SQLite database file to read.

    Returns:
        list: One unpickled object per table row, ordered by the ``id`` column.

    Raises:
        sqlite3.Error: If the query fails (for example, the table is missing).

    Note:
        ``pickle.loads`` can execute arbitrary code while deserializing, so
        only point this function at database files you created yourself.
    """
    conn = sqlite3.connect(db_path)
    try:
        # Explicit ordering: without ORDER BY, SQL does not guarantee row order.
        rows = conn.execute(
            "SELECT cwt_data FROM wavelet_transforms ORDER BY id"
        ).fetchall()
    finally:
        # Release the connection even when the query raises.
        conn.close()

    # Deserialize the pickled blobs.
    return [pickle.loads(blob) for (blob,) in rows]


# Load every stored CWT array from the database
cwt_data_list = read_cwt_data("cwt_data.db")


# Example: print the shape of the first CWT array
print(cwt_data_list[0].shape)

(64, 10)


In [9]:

class CWTDataset(Dataset):
    """Sliding-window dataset over CWT arrays stored in a SQLite table.

    Item ``idx`` is the window of ``sequence_length`` consecutive rows of
    ``wavelet_transforms`` starting at row id ``idx + 1``, paired with the
    ``target`` value of the window's last row. Windows overlap: consecutive
    items are shifted by one row.

    Args:
        db_path: Path of the SQLite database file.
        sequence_length: Number of consecutive rows per item (must be >= 1).

    Raises:
        ValueError: If ``sequence_length`` is smaller than 1.
        sqlite3.Error: If the row count cannot be read from the table.

    Note:
        Row ids are assumed to be contiguous and to start at 1; a window that
        comes back with a different number of rows raises ``RuntimeError``.
        Blobs are deserialized with ``pickle.loads``, which can execute
        arbitrary code, so only open database files you trust.
    """

    def __init__(self, db_path, sequence_length=4000):
        if sequence_length < 1:
            raise ValueError(f"sequence_length must be >= 1, got {sequence_length}")
        self.db_path = db_path
        self.sequence_length = sequence_length
        self.conn = sqlite3.connect(db_path)
        try:
            self.cursor = self.conn.cursor()
            self.cursor.execute("SELECT COUNT(*) FROM wavelet_transforms")
            self.total_samples = self.cursor.fetchone()[0]
        except Exception:
            # Do not leak the connection when the table cannot be read.
            self.conn.close()
            raise

    def __len__(self):
        """Return the number of overlapping windows (never negative)."""
        # rows - window + 1 windows exist; clamp at 0 when the table holds
        # fewer rows than one window, because len() must not be negative.
        return max(0, self.total_samples - self.sequence_length + 1)

    def __getitem__(self, idx):
        """Return ``(cwt_tensor, target_tensor)`` for window ``idx``.

        ``cwt_tensor`` is a float32 tensor whose first axis has length
        ``sequence_length``; ``target_tensor`` is an int64 scalar holding the
        target of the window's last row. Negative indices count from the end.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
            RuntimeError: If the id range does not contain exactly
                ``sequence_length`` rows (ids with gaps or not starting at 1).
        """
        n_windows = len(self)
        idx = int(idx)
        if idx < 0:
            idx += n_windows
        if not 0 <= idx < n_windows:
            raise IndexError(
                f"index {idx} is out of range for a dataset of {n_windows} windows"
            )

        # Row ids start at 1, so window idx covers ids idx+1 .. idx+sequence_length.
        # ORDER BY makes sure rows[-1] really is the last row of the window.
        query = (
            "SELECT cwt_data, target FROM wavelet_transforms "
            "WHERE id BETWEEN ? AND ? ORDER BY id"
        )
        self.cursor.execute(query, (idx + 1, idx + self.sequence_length))
        rows = self.cursor.fetchall()
        if len(rows) != self.sequence_length:
            raise RuntimeError(
                f"expected {self.sequence_length} rows for ids {idx + 1}.."
                f"{idx + self.sequence_length}, got {len(rows)}; "
                "row ids must be contiguous and start at 1"
            )

        # Stack into one ndarray first: building a tensor directly from a list
        # of ndarrays is the slow path that PyTorch warns about.
        cwt_sequence = np.stack([np.asarray(pickle.loads(blob)) for blob, _ in rows])
        # Target of the last row in the window.
        target = rows[-1][1]

        cwt_tensor = torch.as_tensor(cwt_sequence, dtype=torch.float32)
        target_tensor = torch.tensor(target, dtype=torch.int64)
        return cwt_tensor, target_tensor

    def __del__(self):
        # __init__ may have failed before self.conn was assigned.
        conn = getattr(self, "conn", None)
        if conn is not None:
            conn.close()


In [10]:
# Build the sliding-window dataset. DataLoader is already imported in the
# notebook's import cell, so it is not re-imported here.
dataset = CWTDataset('cwt_data.db', sequence_length=4000)

batch_size = 10  # number of sequences per batch
train_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)


In [11]:
class LSTMModel(nn.Module):
    """LSTM followed by a linear head applied to the final time step.

    Args:
        input_dim: Feature size of each time step.
        hidden_dim: Size of the LSTM hidden state.
        layer_dim: Number of stacked LSTM layers.
        output_dim: Size of the linear head's output.
    """

    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.layer_dim = layer_dim
        # batch_first=True: inputs and outputs are laid out (batch, seq, feature).
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=layer_dim,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_dim, out_features=output_dim)

    def forward(self, x):
        """Return the linear head's output for the last time step of ``x``."""
        batch_size = x.size(0)
        state_shape = (self.layer_dim, batch_size, self.hidden_dim)

        # Start from zeroed hidden and cell states on the input's device.
        initial_hidden = torch.zeros(state_shape, device=x.device)
        initial_cell = torch.zeros(state_shape, device=x.device)

        sequence_out, _ = self.lstm(x, (initial_hidden, initial_cell))

        # Only the final time step feeds the linear head.
        final_step = sequence_out[:, -1, :]
        return self.fc(final_step)

In [12]:
# Hyperparameters
input_dim = 640
hidden_dim = 100
layer_dim = 1
output_dim = 1
num_epochs = 20
learning_rate = 0.01

model = LSTMModel(input_dim, hidden_dim, layer_dim, output_dim).to(device)
# NOTE(review): the dataset yields int64 targets but they are trained with a
# regression loss (MSE) - confirm this is intended rather than a classification loss.
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    for batch_idx, (data, targets) in enumerate(train_loader):
        # Merge all per-step feature axes into one, (batch, seq, ...) ->
        # (batch, seq, features), then move the batch to the training device.
        # flatten() takes the batch and sequence sizes from the tensor itself,
        # so no sequence length or feature size is hard-coded here.
        data = data.flatten(start_dim=2).to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = model(data)

        # Targets become float with shape (batch, 1) to match the model output.
        loss = criterion(outputs, targets.float().unsqueeze(1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Report the mean per-batch loss for the epoch.
    print(f'Epoch: {epoch+1}, Loss: {total_loss / len(train_loader)}')
