In [1]:
import torch 
import torchmetrics
import torch.nn as nn
from pytorch_lightning import LightningModule,Trainer
from torch.utils.data import TensorDataset,DataLoader

In [None]:
from dataset_ import EEG_inception
from model import Conv1D_v2

### chrononet model

In [3]:
import torch
import torch.nn as nn

class Block(nn.Module):
    def __init__(self, inplace):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels=inplace, out_channels=16, kernel_size=2, stride=2, padding=0)
        self.conv2 = nn.Conv1d(in_channels=inplace, out_channels=16, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv1d(in_channels=inplace, out_channels=16, kernel_size=8, stride=2, padding=3)
        self.relu = nn.ReLU()
    
    def forward(self, x):
        x1 = self.relu(self.conv1(x))
        x2 = self.relu(self.conv2(x))
        x3 = self.relu(self.conv3(x))
        x = torch.cat([x1, x2, x3], dim=1)
        return x

class ChronoNet(nn.Module):
    def __init__(self, channel):
        super().__init__()
        # Use fewer blocks and channels
        self.block1 = Block(channel)
        self.block2 = Block(48)  # Adjust input to match output channels from Block
        
        # Use a single GRU layer to simplify the model
        self.gru = nn.GRU(input_size=48, hidden_size=32, batch_first=True)
        
        # Linear layer for output prediction
        self.fc = nn.Sequential(
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Linear(16, 1)
        )
    
    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = x.permute(0, 2, 1)  # Adjust for GRU input
        
        # GRU forward pass
        gru_out, _ = self.gru(x)
        # Use the last output from GRU
        x = gru_out[:, -1, :]  
        
        # Fully connected layer
        x = self.fc(x)
        return x


In [4]:
input=torch.randn(3,8,1000)
input.shape
model=ChronoNet(8)
out=model(input)
out.shape
print(torch.sigmoid(out))


tensor([[0.4306],
        [0.4347],
        [0.4354]], grad_fn=<SigmoidBackward0>)


### instantiating the lighting module for the model 

In [5]:
from pytorch_lightning import LightningModule
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torchmetrics

class ChronoModel(LightningModule):
    def __init__(self, attribute):
        super(ChronoModel, self).__init__()
        self.attribute = attribute
        self.model = attribute["model"] # initialize the model
        self.lr = 1e-4
        self.bs = 64
        self.worker = 2
        self.acc = torchmetrics.Accuracy(task="binary")
     
        self.criterion = nn.BCEWithLogitsLoss()

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def train_dataloader(self):
        dataset = self.attribute["train_dataset"]
        return DataLoader(dataset, batch_size=self.bs, num_workers=self.worker, shuffle=True)

    def val_dataloader(self):
        dataset = self.attribute["val_dataset"]
        return DataLoader(dataset, batch_size=self.bs, num_workers=self.worker, shuffle=False)

    def training_step(self, batch, batch_idx):
        signal, label = batch
        out = self(signal.float())
        loss = self.criterion(out.flatten(), label.float().flatten())
        preds = (torch.sigmoid(out.flatten()) > 0.5).long()
        # print(preds.)
        acc = self.acc(preds, label.long().flatten())
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        signal, label = batch
        out = self(signal.float())
        loss = self.criterion(out.flatten(), label.float().flatten())
        preds = (torch.sigmoid(out.flatten()) > 0.5).long()
        acc = self.acc(preds, label.long().flatten())
        self.log('val_loss', loss, on_epoch=True, prog_bar=True, logger=True)
        self.log('val_acc', acc, on_epoch=True, prog_bar=True, logger=True)
        return loss
    
    def on_train_epoch_end(self):
        # Retrieve the average training loss
        train_loss = self.trainer.callback_metrics.get("train_loss_epoch", None)
        if train_loss is not None:
            self.print(f"Epoch {self.current_epoch} - Training Loss: {train_loss:.4f}")

    def on_validation_epoch_end(self):
        # Retrieve the average validation loss
        val_loss = self.trainer.callback_metrics.get("val_loss", None)
        if val_loss is not None:
            self.print(f"Epoch {self.current_epoch} - Validation Loss: {val_loss:.4f}")


In [6]:
attributes = {1 : {"model" : ChronoNet(channel=8), "train_dataset" : EEG_inception(kind="train", normalize= 1, balancing="equal_samples"), "val_dataset" : EEG_inception(kind = "val", normalize= 1)}}

model=ChronoModel(attribute = attributes[1])

should be here 


100%|██████████| 4988/4988 [00:32<00:00, 153.97it/s]


(4988, 8, 750) shap[e]


100%|██████████| 7650/7650 [00:11<00:00, 661.10it/s] 


(4988, 8, 750) in here dataset
train main_job done 10790 4988 4988
should be here 


100%|██████████| 856/856 [00:07<00:00, 107.94it/s]


(856, 8, 750) shap[e]


100%|██████████| 7650/7650 [00:04<00:00, 1850.53it/s]


(856, 8, 750) in here dataset
val main_job done 1350 856 856


In [7]:
trainer=Trainer(max_epochs=150)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [8]:
trainer.fit(model)

You are using a CUDA device ('NVIDIA GeForce RTX 4070 Ti') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type              | Params | Mode 
--------------------------------------------------------
0 | model     | ChronoNet         | 21.1 K | train
1 | acc       | BinaryAccuracy    | 0      | train
2 | criterion | BCEWithLogitsLoss | 0      | train
--------------------------------------------------------
21.1 K    Trainable params
0         Non-trainable params
21.1 K    Total params
0.084     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\Users\zokov\.conda\envs\py3\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:419: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.


Epoch 0 - Validation Loss: 0.6934


c:\Users\zokov\.conda\envs\py3\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:419: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0 - Validation Loss: 0.6932
Epoch 0 - Training Loss: 0.6925


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 1 - Validation Loss: 0.6933
Epoch 1 - Training Loss: 0.6915


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 2 - Validation Loss: 0.6935
Epoch 2 - Training Loss: 0.6906


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 3 - Validation Loss: 0.6940
Epoch 3 - Training Loss: 0.6900


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 4 - Validation Loss: 0.6941
Epoch 4 - Training Loss: 0.6892


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 5 - Validation Loss: 0.6946
Epoch 5 - Training Loss: 0.6886


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 6 - Validation Loss: 0.6951
Epoch 6 - Training Loss: 0.6881


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 7 - Validation Loss: 0.6955
Epoch 7 - Training Loss: 0.6874


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 8 - Validation Loss: 0.6958
Epoch 8 - Training Loss: 0.6868


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 9 - Validation Loss: 0.6967
Epoch 9 - Training Loss: 0.6865


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 10 - Validation Loss: 0.6972
Epoch 10 - Training Loss: 0.6857


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 11 - Validation Loss: 0.6979
Epoch 11 - Training Loss: 0.6850


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 12 - Validation Loss: 0.6988
Epoch 12 - Training Loss: 0.6844


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 13 - Validation Loss: 0.6996
Epoch 13 - Training Loss: 0.6838


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 14 - Validation Loss: 0.7010
Epoch 14 - Training Loss: 0.6832


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 15 - Validation Loss: 0.7015
Epoch 15 - Training Loss: 0.6824


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 16 - Validation Loss: 0.7024
Epoch 16 - Training Loss: 0.6820


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 17 - Validation Loss: 0.7037
Epoch 17 - Training Loss: 0.6815


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 18 - Validation Loss: 0.7043
Epoch 18 - Training Loss: 0.6806


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 19 - Validation Loss: 0.7076
Epoch 19 - Training Loss: 0.6802


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 20 - Validation Loss: 0.7082
Epoch 20 - Training Loss: 0.6793


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 21 - Validation Loss: 0.7075
Epoch 21 - Training Loss: 0.6788


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 22 - Validation Loss: 0.7071
Epoch 22 - Training Loss: 0.6793


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 23 - Validation Loss: 0.7093
Epoch 23 - Training Loss: 0.6775


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 24 - Validation Loss: 0.7109
Epoch 24 - Training Loss: 0.6768


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 25 - Validation Loss: 0.7115
Epoch 25 - Training Loss: 0.6758


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 26 - Validation Loss: 0.7139
Epoch 26 - Training Loss: 0.6753


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 27 - Validation Loss: 0.7155
Epoch 27 - Training Loss: 0.6738


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 28 - Validation Loss: 0.7165
Epoch 28 - Training Loss: 0.6731


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 29 - Validation Loss: 0.7179
Epoch 29 - Training Loss: 0.6727


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 30 - Validation Loss: 0.7166
Epoch 30 - Training Loss: 0.6714


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 31 - Validation Loss: 0.7188
Epoch 31 - Training Loss: 0.6714


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 32 - Validation Loss: 0.7199
Epoch 32 - Training Loss: 0.6708


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 33 - Validation Loss: 0.7198
Epoch 33 - Training Loss: 0.6695


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 34 - Validation Loss: 0.7227
Epoch 34 - Training Loss: 0.6693


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 35 - Validation Loss: 0.7267
Epoch 35 - Training Loss: 0.6679


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 36 - Validation Loss: 0.7246
Epoch 36 - Training Loss: 0.6679


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 37 - Validation Loss: 0.7240
Epoch 37 - Training Loss: 0.6661


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 38 - Validation Loss: 0.7244
Epoch 38 - Training Loss: 0.6657


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 39 - Validation Loss: 0.7327
Epoch 39 - Training Loss: 0.6643


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 40 - Validation Loss: 0.7282
Epoch 40 - Training Loss: 0.6645


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 41 - Validation Loss: 0.7307
Epoch 41 - Training Loss: 0.6629


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 42 - Validation Loss: 0.7317
Epoch 42 - Training Loss: 0.6619


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 43 - Validation Loss: 0.7332
Epoch 43 - Training Loss: 0.6621


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 44 - Validation Loss: 0.7311
Epoch 44 - Training Loss: 0.6624


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 45 - Validation Loss: 0.7327
Epoch 45 - Training Loss: 0.6611


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 46 - Validation Loss: 0.7341
Epoch 46 - Training Loss: 0.6584


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 47 - Validation Loss: 0.7364
Epoch 47 - Training Loss: 0.6584


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 48 - Validation Loss: 0.7344
Epoch 48 - Training Loss: 0.6583


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 49 - Validation Loss: 0.7396
Epoch 49 - Training Loss: 0.6565


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 50 - Validation Loss: 0.7392
Epoch 50 - Training Loss: 0.6565


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 51 - Validation Loss: 0.7410
Epoch 51 - Training Loss: 0.6561


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 52 - Validation Loss: 0.7399
Epoch 52 - Training Loss: 0.6546


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 53 - Validation Loss: 0.7418
Epoch 53 - Training Loss: 0.6540


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 54 - Validation Loss: 0.7433
Epoch 54 - Training Loss: 0.6554


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 55 - Validation Loss: 0.7418
Epoch 55 - Training Loss: 0.6532


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 56 - Validation Loss: 0.7473
Epoch 56 - Training Loss: 0.6512


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 57 - Validation Loss: 0.7444
Epoch 57 - Training Loss: 0.6522


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 58 - Validation Loss: 0.7452
Epoch 58 - Training Loss: 0.6496


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 59 - Validation Loss: 0.7466
Epoch 59 - Training Loss: 0.6504


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 60 - Validation Loss: 0.7540
Epoch 60 - Training Loss: 0.6494


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 61 - Validation Loss: 0.7499
Epoch 61 - Training Loss: 0.6483


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 62 - Validation Loss: 0.7490
Epoch 62 - Training Loss: 0.6487


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 63 - Validation Loss: 0.7531
Epoch 63 - Training Loss: 0.6492


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 64 - Validation Loss: 0.7516
Epoch 64 - Training Loss: 0.6458


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 65 - Validation Loss: 0.7538
Epoch 65 - Training Loss: 0.6455


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 66 - Validation Loss: 0.7534
Epoch 66 - Training Loss: 0.6437


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 67 - Validation Loss: 0.7583
Epoch 67 - Training Loss: 0.6445


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 68 - Validation Loss: 0.7621
Epoch 68 - Training Loss: 0.6440


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 69 - Validation Loss: 0.7547
Epoch 69 - Training Loss: 0.6442


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 70 - Validation Loss: 0.7604
Epoch 70 - Training Loss: 0.6435


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 71 - Validation Loss: 0.7552
Epoch 71 - Training Loss: 0.6417


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 72 - Validation Loss: 0.7573
Epoch 72 - Training Loss: 0.6406


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 73 - Validation Loss: 0.7630
Epoch 73 - Training Loss: 0.6404


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 74 - Validation Loss: 0.7589
Epoch 74 - Training Loss: 0.6391


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 75 - Validation Loss: 0.7602
Epoch 75 - Training Loss: 0.6390


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 76 - Validation Loss: 0.7685
Epoch 76 - Training Loss: 0.6378


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 77 - Validation Loss: 0.7668
Epoch 77 - Training Loss: 0.6378


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 78 - Validation Loss: 0.7624
Epoch 78 - Training Loss: 0.6365


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 79 - Validation Loss: 0.7620
Epoch 79 - Training Loss: 0.6379


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 80 - Validation Loss: 0.7652
Epoch 80 - Training Loss: 0.6373


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 81 - Validation Loss: 0.7688
Epoch 81 - Training Loss: 0.6341


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 82 - Validation Loss: 0.7626
Epoch 82 - Training Loss: 0.6335


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 83 - Validation Loss: 0.7712
Epoch 83 - Training Loss: 0.6326


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 84 - Validation Loss: 0.7658
Epoch 84 - Training Loss: 0.6322


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 85 - Validation Loss: 0.7674
Epoch 85 - Training Loss: 0.6320


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 86 - Validation Loss: 0.7745
Epoch 86 - Training Loss: 0.6313


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 87 - Validation Loss: 0.7730
Epoch 87 - Training Loss: 0.6299


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 88 - Validation Loss: 0.7789
Epoch 88 - Training Loss: 0.6297


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 89 - Validation Loss: 0.7760
Epoch 89 - Training Loss: 0.6293


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 90 - Validation Loss: 0.7720
Epoch 90 - Training Loss: 0.6263


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 91 - Validation Loss: 0.7730
Epoch 91 - Training Loss: 0.6288


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 92 - Validation Loss: 0.7763
Epoch 92 - Training Loss: 0.6265


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 93 - Validation Loss: 0.7741
Epoch 93 - Training Loss: 0.6242


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 94 - Validation Loss: 0.7866
Epoch 94 - Training Loss: 0.6259


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 95 - Validation Loss: 0.7790
Epoch 95 - Training Loss: 0.6237


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 96 - Validation Loss: 0.7957
Epoch 96 - Training Loss: 0.6227


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 97 - Validation Loss: 0.7698
Epoch 97 - Training Loss: 0.6234


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 98 - Validation Loss: 0.7876
Epoch 98 - Training Loss: 0.6207


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 99 - Validation Loss: 0.7818
Epoch 99 - Training Loss: 0.6199


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 100 - Validation Loss: 0.7970
Epoch 100 - Training Loss: 0.6204


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 101 - Validation Loss: 0.7873
Epoch 101 - Training Loss: 0.6172


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 102 - Validation Loss: 0.7931
Epoch 102 - Training Loss: 0.6196


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 103 - Validation Loss: 0.7961
Epoch 103 - Training Loss: 0.6175


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 104 - Validation Loss: 0.7949
Epoch 104 - Training Loss: 0.6157


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 105 - Validation Loss: 0.7871
Epoch 105 - Training Loss: 0.6171


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 106 - Validation Loss: 0.7920
Epoch 106 - Training Loss: 0.6137


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 107 - Validation Loss: 0.7897
Epoch 107 - Training Loss: 0.6138


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 108 - Validation Loss: 0.7990
Epoch 108 - Training Loss: 0.6139


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 109 - Validation Loss: 0.7964
Epoch 109 - Training Loss: 0.6133


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 110 - Validation Loss: 0.8055
Epoch 110 - Training Loss: 0.6113


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 111 - Validation Loss: 0.7951
Epoch 111 - Training Loss: 0.6113


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 112 - Validation Loss: 0.8044
Epoch 112 - Training Loss: 0.6105


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 113 - Validation Loss: 0.8033
Epoch 113 - Training Loss: 0.6082


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 114 - Validation Loss: 0.7962
Epoch 114 - Training Loss: 0.6107


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 115 - Validation Loss: 0.7994
Epoch 115 - Training Loss: 0.6089


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 116 - Validation Loss: 0.8120
Epoch 116 - Training Loss: 0.6066


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 117 - Validation Loss: 0.8011
Epoch 117 - Training Loss: 0.6058


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 118 - Validation Loss: 0.8108
Epoch 118 - Training Loss: 0.6037


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 119 - Validation Loss: 0.8088
Epoch 119 - Training Loss: 0.6037


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 120 - Validation Loss: 0.8123
Epoch 120 - Training Loss: 0.6036


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 121 - Validation Loss: 0.8156
Epoch 121 - Training Loss: 0.6018


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 122 - Validation Loss: 0.8108
Epoch 122 - Training Loss: 0.6035


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 123 - Validation Loss: 0.8296
Epoch 123 - Training Loss: 0.5984


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 124 - Validation Loss: 0.8220
Epoch 124 - Training Loss: 0.6011


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 125 - Validation Loss: 0.8344
Epoch 125 - Training Loss: 0.5993


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 126 - Validation Loss: 0.8106
Epoch 126 - Training Loss: 0.5992


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 127 - Validation Loss: 0.8164
Epoch 127 - Training Loss: 0.5974


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 128 - Validation Loss: 0.8289
Epoch 128 - Training Loss: 0.5946


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 129 - Validation Loss: 0.8388
Epoch 129 - Training Loss: 0.5941


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 130 - Validation Loss: 0.8262
Epoch 130 - Training Loss: 0.5923


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 131 - Validation Loss: 0.8184
Epoch 131 - Training Loss: 0.5968


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 132 - Validation Loss: 0.8377
Epoch 132 - Training Loss: 0.5940


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 133 - Validation Loss: 0.8579
Epoch 133 - Training Loss: 0.5923


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 134 - Validation Loss: 0.8204
Epoch 134 - Training Loss: 0.5951


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 135 - Validation Loss: 0.8348
Epoch 135 - Training Loss: 0.5902


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 136 - Validation Loss: 0.8457
Epoch 136 - Training Loss: 0.5892


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 137 - Validation Loss: 0.8615
Epoch 137 - Training Loss: 0.5923


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 138 - Validation Loss: 0.8346
Epoch 138 - Training Loss: 0.5872


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 139 - Validation Loss: 0.8487
Epoch 139 - Training Loss: 0.5890


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 140 - Validation Loss: 0.8363
Epoch 140 - Training Loss: 0.5894


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 141 - Validation Loss: 0.8382
Epoch 141 - Training Loss: 0.5878


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 142 - Validation Loss: 0.8432
Epoch 142 - Training Loss: 0.5829


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 143 - Validation Loss: 0.8772
Epoch 143 - Training Loss: 0.5863


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 144 - Validation Loss: 0.8465
Epoch 144 - Training Loss: 0.5843


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 145 - Validation Loss: 0.8592
Epoch 145 - Training Loss: 0.5804


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 146 - Validation Loss: 0.8539
Epoch 146 - Training Loss: 0.5800


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 147 - Validation Loss: 0.8688
Epoch 147 - Training Loss: 0.5813


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 148 - Validation Loss: 0.8694
Epoch 148 - Training Loss: 0.5805


Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=150` reached.


Epoch 149 - Validation Loss: 0.8499
Epoch 149 - Training Loss: 0.5803


In [None]:
print(trainer.callback_metrics)

{'train_loss': tensor(0.6919), 'train_loss_step': tensor(0.7212), 'train_acc': tensor(0.5230), 'train_acc_step': tensor(0.3750), 'val_loss': tensor(0.6909), 'val_acc': tensor(0.5275), 'train_loss_epoch': tensor(0.6919), 'train_acc_epoch': tensor(0.5230)}


In [None]:
train_data = EEG_inception(kind="train", normalize = False)
dl = DataLoader(train_data, batch_size=64, num_workers=2, shuffle=True)

(1000, 8, 1000) in here dataset
train kind 1000 1000 1000
main_job done


In [None]:
pred_label = []
true_label = []

device = torch.device("cpu")
for ii, (data, label) in enumerate(dl):
    input = data.float().to(device)
    label = label.to(device)

    pred = model(input).float()
    pred = torch.sigmoid(pred)
    pred = (pred >= 0.5).float().to(device).data

    pred = pred.view(-1)
    
    pred_label.append(pred)
    true_label.append(label)
    # print(pred, "val")
    # print(label)
pred_label = torch.cat(pred_label, 0)
true_label = torch.cat(true_label, 0)

val_accuracy = torch.sum(pred_label == true_label).type(torch.FloatTensor) / true_label.size(0)

print(val_accuracy)

tensor(0.5210)


In [None]:
num_ones = torch.sum(pred_label == 0).item()

In [None]:
num_ones

795

In [None]:
torch.sum(pred_label == 1).item()

2559

In [None]:
########################################################################

In [None]:
import torch
import torch.nn as nn

class Block(nn.Module):
    def __init__(self, inplace):
        super().__init__()
        # Add batch normalization and proper initialization
        self.conv1 = nn.Conv1d(in_channels=inplace, out_channels=32, kernel_size=2, stride=2, padding=0)
        self.bn1 = nn.BatchNorm1d(32)
        self.conv2 = nn.Conv1d(in_channels=inplace, out_channels=32, kernel_size=4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm1d(32)
        self.conv3 = nn.Conv1d(in_channels=inplace, out_channels=32, kernel_size=8, stride=2, padding=3)
        self.bn3 = nn.BatchNorm1d(32)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)
        
        # Initialize weights properly
        nn.init.kaiming_normal_(self.conv1.weight)
        nn.init.kaiming_normal_(self.conv2.weight)
        nn.init.kaiming_normal_(self.conv3.weight)
    
    def forward(self, x):
        x1 = self.dropout(self.relu(self.bn1(self.conv1(x))))
        x2 = self.dropout(self.relu(self.bn2(self.conv2(x))))
        x3 = self.dropout(self.relu(self.bn3(self.conv3(x))))
        x = torch.cat([x1, x2, x3], dim=1)
        return x

class ChronoNet(nn.Module):
    def __init__(self, channel):
        super().__init__()
        self.block1 = Block(channel)
        self.block2 = Block(96)
        self.block3 = Block(96)
        
        # Add dropout and batch normalization to GRU layers
        self.gru1 = nn.GRU(input_size=96, hidden_size=32, batch_first=True, dropout=0.2)
        self.gru2 = nn.GRU(input_size=32, hidden_size=32, batch_first=True, dropout=0.2)
        self.gru3 = nn.GRU(input_size=64, hidden_size=32, batch_first=True, dropout=0.2)
        self.gru4 = nn.GRU(input_size=96, hidden_size=32, batch_first=True, dropout=0.2)
        
        self.gru_linear = nn.Linear(225, 1)
        self.bn_linear = nn.BatchNorm1d(32)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(32, 1)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)
        
        # Initialize the linear layers
        nn.init.xavier_normal_(self.gru_linear.weight)
        nn.init.xavier_normal_(self.fc1.weight)
    
    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = x.permute(0, 2, 1)
        
        gru_out1, _ = self.gru1(x)
        gru_out2, _ = self.gru2(gru_out1)
        gru_out = torch.cat([gru_out1, gru_out2], dim=2)
        
        gru_out3, _ = self.gru3(gru_out)
        gru_out = torch.cat([gru_out1, gru_out2, gru_out3], dim=2)
        
        linear_out = self.dropout(self.relu(self.gru_linear(gru_out.permute(0, 2, 1))))
        gru_out4, _ = self.gru4(linear_out.permute(0, 2, 1))
        
        x = self.flatten(gru_out4)
        x = self.bn_linear(x)
        x = self.dropout(x)
        x = self.fc1(x)
        return x

class ChronoModel(LightningModule):
    def __init__(self):
        super(ChronoModel, self).__init__()
        self.model = ChronoNet(8)
        self.lr = 1e-3  # Increased learning rate
        self.bs = 32    # Reduced batch size
        self.worker = 2
        self.acc = torchmetrics.Accuracy(task="binary")
        self.criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([1.0]))  # Balanced weight
        
        # Add early stopping metrics
        self.best_val_acc = 0.0
        self.patience_counter = 0
        
    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=1e-5)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='max', factor=0.1, patience=3, verbose=True
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_acc"
            }
        }

    def training_step(self, batch, batch_idx):
        signal, label = batch
        out = self(signal.float())
        loss = self.criterion(out.flatten(), label.float().flatten())
        
        # Add gradient clipping
        torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
        
        preds = torch.sigmoid(out.flatten())
        acc = self.acc(preds, label.long().flatten())
        
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        
        return loss

    def validation_step(self, batch, batch_idx):
        signal, label = batch
        out = self(signal.float())
        loss = self.criterion(out.flatten(), label.float().flatten())
        
        preds = torch.sigmoid(out.flatten())
        acc = self.acc(preds, label.long().flatten())
        
        self.log('val_loss', loss, on_epoch=True, prog_bar=True, logger=True)
        self.log('val_acc', acc, on_epoch=True, prog_bar=True, logger=True)
        
        return loss

SyntaxError: invalid syntax (818234999.py, line 12)

In [None]:
# Initialize and train
model = ChronoModel()
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
trainer = Trainer(
    max_epochs=50,
    callbacks=[
        EarlyStopping(monitor='val_loss', patience=5),
        ModelCheckpoint(monitor='val_acc', mode='max')
    ]
)
trainer.fit(model)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
C:\Users\zokov\.conda\envs\py3\Lib\site-packages\pytorch_lightning\trainer\configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type              | Params | Mode 
--------------------------------------------------------
0 | model     | ChronoNet         | 131 K  | train
1 | acc       | BinaryAccuracy    | 0      | train
2 | criterion | BCEWithLogitsLoss | 0      | train
--------------------------------------------------------
131 K     Trainable params
0         Non-trainable params
131 K     Total params
0.526     Total estimated model params size (MB)


MisconfigurationException: `train_dataloader` must be implemented to be used with the Lightning Trainer