In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
!unzip /content/drive/MyDrive/data.zip

Archive:  /content/drive/MyDrive/data.zip
  inflating: HR_all.pkl              
  inflating: label_all.pkl           
  inflating: OX_all.pkl              
  inflating: SaO2_all.pkl            


In [5]:
import torch
import random
import pandas as pd
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset , random_split , DataLoader

In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [9]:

class EmbeddingEncoder:
    def __call__(
        self,
        data : torch.Tensor ,
        format : str = "sum",
        stride : int = 2 ,
        padding_format : str = "concat"
    ) -> torch.Tensor :
        # {
        #    data : [ batch , max_length ]
        # }

        encoder_all = torch.zeros(data.size() , dtype=torch.float)

        for i in range(data.size(1)):

            _ = [
                data[: , i : i + stride ]
            ]

            if _[0].size(1) < stride :
                _.append(data[: , 0 : stride - _[0].size(1)])

            if format == "tanh": # => stride => 10
                encoder_all[: , i] = torch.cat(_ , dim=-1).float().tanh().mean(dim=-1)

            elif format == "sinh": # => stride => 10
                encoder_all[: , i] = torch.cat(_ , dim=-1).float().sinh().mean(dim=-1)

            elif format == "softmax": # => stride => 2
                encoder_all[: , i] = torch.cat(_ , dim=-1).float().softmax(dim=-1).sum(dim=-1).softmax(dim=-1)

        return encoder_all

class Datasets(Dataset):
    def __init__(self ,
        datasets_HR : pd.DataFrame ,
        datasets_OX : pd.DataFrame ,
        datasets_SaO2 : pd.DataFrame ,
        datasets_labels : pd.DataFrame
    ) -> None:
        super().__init__()

        self._HR , self._OX , self._SaO2 , self._labels = datasets_HR , datasets_OX , datasets_SaO2 , datasets_labels

        self._ModelEncoder = EmbeddingEncoder()

    def __len__(self):
        return len(self._HR)

    def __getitem__(self, index : int ):

        return {
            "HR" : torch.from_numpy(self._HR.iloc[index].values).float().flatten() ,
            "OX" : torch.from_numpy(self._OX.iloc[index].values).float().flatten() ,
            "SaO2" : torch.from_numpy(self._SaO2.iloc[index].values).float().flatten() ,
            "labels" : torch.from_numpy(self._labels.iloc[index].values).float().flatten() ,
        }

data = Datasets(
    datasets_HR= pd.read_pickle('/content/HR_all.pkl')  ,
    datasets_OX= pd.read_pickle('/content/OX_all.pkl') ,
    datasets_SaO2= pd.read_pickle('/content/SaO2_all.pkl') ,
    datasets_labels= pd.read_pickle('/content/label_all.pkl')
)

train_dataloder =  DataLoader(data, batch_size= 50 )

In [10]:
class SpatialAttentionBlock(nn.Module):
    def __init__(self , in_channel : int = 1024):
        super().__init__()

        self._model = nn.Sequential(
            nn.Conv1d(in_channels=in_channel, out_channels=1, kernel_size=3, stride=1, padding=1, bias=False),
            nn.Sigmoid()
        )

    def forward(self , input_ids : torch.Tensor ):
        return input_ids * self._model(input_ids)

class ChannelAttention(nn.Module):
    def __init__(self,  new_channels : int , last_channels : int):
        super().__init__()

        self._avgpool = nn.AdaptiveAvgPool1d(1)

        self._attn = nn.Sequential(
            nn.Linear(new_channels , last_channels , bias=False),
            nn.ReLU(),
            nn.Linear(last_channels , new_channels , bias=False),
            nn.Sigmoid()
        )

    def forward(self, input_ids : torch.Tensor ):

        _avg = self._avgpool(input_ids)

        return input_ids * self._attn(_avg.flatten(1)).view_as(_avg)

class DualAttention(nn.Module):
    def __init__(self, new_channels : int , last_channels : int ):
        super().__init__()

        self.channel_attn = ChannelAttention(new_channels=new_channels , last_channels=last_channels)
        self.spatial_attn = SpatialAttentionBlock(in_channel=new_channels)

    def forward(self, input_ids : torch.Tensor):
        return  self.spatial_attn(
            self.channel_attn(
                input_ids
            )
        )

class BlockEncoderCnn(nn.Module):
    def __init__(self , in_channels , out_channels , kernel_size : int = 3 , stride : int = 1 , padding : int = 1 , level : int = 0):
        super().__init__()

        if level == 0 :

            self._model = nn.Sequential(
                nn.Conv1d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size , stride=stride , padding=padding , bias=False ),
                nn.BatchNorm1d( num_features=out_channels ),
                nn.LeakyReLU( inplace = True ),
                nn.Conv1d( out_channels, out_channels, kernel_size=1 ) ,
                DualAttention( new_channels= out_channels , last_channels= in_channels )
            )

            self._activation = nn.Sequential(
                nn.BatchNorm1d(num_features=out_channels) ,
                nn.ReLU() ,
                nn.MaxPool1d(kernel_size=2, stride=2)
            )

        elif level == 1 :

            self._model = nn.Sequential(
                nn.Conv1d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size , stride=stride , padding=padding , bias=False ),
                nn.BatchNorm1d( num_features=out_channels ),
                nn.LeakyReLU( inplace = True )
            )

            self._activation = nn.MaxPool1d(kernel_size=2, stride=2)

        else :

            self._model = nn.Sequential(
                nn.Conv1d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size , stride=stride , padding=padding , bias=False ),
                DualAttention( new_channels= out_channels , last_channels= in_channels )
            )

            self._activation = nn.Sequential(
                nn.BatchNorm1d(num_features=out_channels) ,
                nn.ReLU() ,
                nn.MaxPool1d(kernel_size=2, stride=2)
            )

    def forward(self , input_ids : torch.Tensor ):
        return self._activation(self._model(input_ids))

class EncoderCnnExtra(nn.Module):
    def __init__(self, level : int = 0):
        super().__init__()

        """
        3   , 256 , 256
        32  , 128 , 128
        64  , 64  , 64
        128 , 32  , 32
        256 , 16  , 16
        512 , 8   , 8
        1024 , 4  , 4
        2048 , 2  , 2
        """

        self._model = nn.Sequential(
            BlockEncoderCnn( in_channels=1   , out_channels=32   , kernel_size=3  , stride=1 , padding=1 , level=level),
            BlockEncoderCnn( in_channels=32  , out_channels=64   , kernel_size=3  , stride=1 , padding=1 , level=level),
            BlockEncoderCnn( in_channels=64  , out_channels=128  , kernel_size=15 , stride=2 , padding=2 , level=level),
            BlockEncoderCnn( in_channels=128 , out_channels=256  , kernel_size=15 , stride=2 , padding=2 , level=level),
            BlockEncoderCnn( in_channels=256 , out_channels=512  , kernel_size=11 , stride=2 , padding=2 , level=level),
            BlockEncoderCnn( in_channels=512 , out_channels=1024 , kernel_size=3  , stride=1 , padding=1 , level=level),
            BlockEncoderCnn( in_channels=1024 ,out_channels=2048 , kernel_size=3  , stride=1 , padding=1 , level=level),
            nn.AdaptiveAvgPool1d(1),
        )

    def forward(self , input_ids : torch.Tensor ):
        return self._model(input_ids.unsqueeze(1)).permute(0 , 2 , 1)

In [11]:
class Model(nn.Module):
    def __init__(self , level : int = 0) -> None:
        super().__init__()

        self._cnn = EncoderCnnExtra(level)

        self._dropout = nn.Dropout(0.4)

        self._model = nn.Linear(2048 , 4)

    def forward(self , x : torch.Tensor ):

        return self._model(self._dropout(self._cnn(x))).squeeze(1)

In [12]:
class FocalLoss(nn.Module):
    def __init__(self, alpha=1.0, gamma=2.0, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """
        inputs: logits با shape [batch_size, num_classes]
        targets: اندیس کلاس صحیح با shape [batch_size]
        """
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

In [13]:
level = 0
model = Model(level).to(device)
optimizer = torch.optim.AdamW(model.parameters() , lr=0.00001 ) # lr/epoch
loss_func = FocalLoss().to(device)

In [14]:
from tqdm import tqdm

_lossed = []
_acc_train = []

best_acc = 0

for epoch in range(100):

    _list = []
    acc_train = []

    model.train()

    for data in tqdm(train_dataloder):

        # 0 , 1 , 2
        # data["HR"].to(device) , data["OX"].to(device) , data["SaO2"].to(device)

        predictions = model(data["HR"].to(device))

        loss = loss_func(predictions, data["labels"].squeeze(1).long().to(device))

        predict_classes_val = torch.argmax(predictions, dim=1)

        acc_train.append(torch.sum(predict_classes_val.cpu() == data["labels"].squeeze(1).long()) / predict_classes_val.size(0))

        _list.append(loss.item())

        optimizer.zero_grad()

        loss.backward()
        _ = nn.utils.clip_grad_norm_(model.parameters(), 50.0)

        optimizer.step()

    if sum(acc_train) / len(acc_train) > best_acc:

        torch.save({
            "model": model._cnn.state_dict()
        } , f"./data_best_{level}_HR.pt")

        best_acc = sum(acc_train) / len(acc_train)

        print("best acc test " , best_acc)

    print("epoch", epoch , "train_loss => " , sum(_list) / len(_list) , "acc Train => " , sum(acc_train) / len(acc_train) )

    _lossed.append(_list)
    _acc_train.append(acc_train)

    if sum(_list) / len(_list) < 0.15 :
        break


100%|██████████| 105/105 [00:36<00:00,  2.89it/s]


best acc test  tensor(0.3132)
epoch 0 train_loss =>  0.7821266100520179 acc Train =>  tensor(0.3132)


100%|██████████| 105/105 [00:33<00:00,  3.10it/s]


best acc test  tensor(0.3786)
epoch 1 train_loss =>  0.709903952053615 acc Train =>  tensor(0.3786)


100%|██████████| 105/105 [00:34<00:00,  3.01it/s]


best acc test  tensor(0.4221)
epoch 2 train_loss =>  0.660879447346642 acc Train =>  tensor(0.4221)


100%|██████████| 105/105 [00:35<00:00,  2.95it/s]


best acc test  tensor(0.4569)
epoch 3 train_loss =>  0.6243083661510831 acc Train =>  tensor(0.4569)


100%|██████████| 105/105 [00:35<00:00,  2.92it/s]


best acc test  tensor(0.5177)
epoch 4 train_loss =>  0.5766573494388944 acc Train =>  tensor(0.5177)


100%|██████████| 105/105 [00:36<00:00,  2.91it/s]


best acc test  tensor(0.5907)
epoch 5 train_loss =>  0.5105611060346876 acc Train =>  tensor(0.5907)


100%|██████████| 105/105 [00:36<00:00,  2.89it/s]


best acc test  tensor(0.6649)
epoch 6 train_loss =>  0.43942512784685406 acc Train =>  tensor(0.6649)


100%|██████████| 105/105 [00:36<00:00,  2.87it/s]


best acc test  tensor(0.7525)
epoch 7 train_loss =>  0.3449726973261152 acc Train =>  tensor(0.7525)


100%|██████████| 105/105 [00:36<00:00,  2.88it/s]


best acc test  tensor(0.8301)
epoch 8 train_loss =>  0.257892050913402 acc Train =>  tensor(0.8301)


100%|██████████| 105/105 [00:36<00:00,  2.88it/s]


best acc test  tensor(0.8983)
epoch 9 train_loss =>  0.17300128564238548 acc Train =>  tensor(0.8983)


100%|██████████| 105/105 [00:36<00:00,  2.88it/s]


best acc test  tensor(0.9371)
epoch 10 train_loss =>  0.11295241106833731 acc Train =>  tensor(0.9371)


In [15]:
from tqdm import tqdm

_lossed = []
_acc_train = []

best_acc = 0

for epoch in range(100):

    _list = []
    acc_train = []

    model.train()

    for data in tqdm(train_dataloder):

        # 0 , 1 , 2
        # data["HR"].to(device) , data["OX"].to(device) , data["SaO2"].to(device)

        predictions = model(data["OX"].to(device))

        loss = loss_func(predictions, data["labels"].squeeze(1).long().to(device))

        predict_classes_val = torch.argmax(predictions, dim=1)

        acc_train.append(torch.sum(predict_classes_val.cpu() == data["labels"].squeeze(1).long()) / predict_classes_val.size(0))

        _list.append(loss.item())

        optimizer.zero_grad()

        loss.backward()
        _ = nn.utils.clip_grad_norm_(model.parameters(), 50.0)

        optimizer.step()

    if sum(acc_train) / len(acc_train) > best_acc:

        torch.save({
            "model": model._cnn.state_dict()
        } , f"./data_best_{level}_OX.pt")

        best_acc = sum(acc_train) / len(acc_train)

        print("best acc test " , best_acc)

    print("epoch", epoch , "train_loss => " , sum(_list) / len(_list) , "acc Train => " , sum(acc_train) / len(acc_train) )

    _lossed.append(_list)
    _acc_train.append(acc_train)

    if sum(_list) / len(_list) < 0.15 :
        break


100%|██████████| 105/105 [00:34<00:00,  3.06it/s]


best acc test  tensor(0.3391)
epoch 0 train_loss =>  0.8362609420503889 acc Train =>  tensor(0.3391)


100%|██████████| 105/105 [00:35<00:00,  2.98it/s]


best acc test  tensor(0.5574)
epoch 1 train_loss =>  0.5252310514450074 acc Train =>  tensor(0.5574)


100%|██████████| 105/105 [00:35<00:00,  2.92it/s]


best acc test  tensor(0.7598)
epoch 2 train_loss =>  0.30923901817628313 acc Train =>  tensor(0.7598)


100%|██████████| 105/105 [00:35<00:00,  2.92it/s]


best acc test  tensor(0.9057)
epoch 3 train_loss =>  0.15087763862240883 acc Train =>  tensor(0.9057)


100%|██████████| 105/105 [00:36<00:00,  2.91it/s]


best acc test  tensor(0.9663)
epoch 4 train_loss =>  0.07000280977005051 acc Train =>  tensor(0.9663)


In [16]:
from tqdm import tqdm

_lossed = []
_acc_train = []

best_acc = 0

for epoch in range(100):

    _list = []
    acc_train = []

    model.train()

    for data in tqdm(train_dataloder):

        # 0 , 1 , 2
        # data["HR"].to(device) , data["OX"].to(device) , data["SaO2"].to(device)

        predictions = model(data["SaO2"].to(device))

        loss = loss_func(predictions, data["labels"].squeeze(1).long().to(device))

        predict_classes_val = torch.argmax(predictions, dim=1)

        acc_train.append(torch.sum(predict_classes_val.cpu() == data["labels"].squeeze(1).long()) / predict_classes_val.size(0))

        _list.append(loss.item())

        optimizer.zero_grad()

        loss.backward()
        _ = nn.utils.clip_grad_norm_(model.parameters(), 50.0)

        optimizer.step()

    if sum(acc_train) / len(acc_train) > best_acc:

        torch.save({
            "model": model._cnn.state_dict()
        } , f"./data_best_{level}_SaO2.pt")

        best_acc = sum(acc_train) / len(acc_train)

        print("best acc test " , best_acc)

    print("epoch", epoch , "train_loss => " , sum(_list) / len(_list) , "acc Train => " , sum(acc_train) / len(acc_train) )

    _lossed.append(_list)
    _acc_train.append(acc_train)

    if sum(_list) / len(_list) < 0.15 :
        break


100%|██████████| 105/105 [00:35<00:00,  2.94it/s]


best acc test  tensor(0.5955)
epoch 0 train_loss =>  0.454664951137134 acc Train =>  tensor(0.5955)


100%|██████████| 105/105 [00:36<00:00,  2.90it/s]


best acc test  tensor(0.8061)
epoch 1 train_loss =>  0.18890386553747313 acc Train =>  tensor(0.8061)


100%|██████████| 105/105 [00:36<00:00,  2.86it/s]


best acc test  tensor(0.8960)
epoch 2 train_loss =>  0.10457750072791464 acc Train =>  tensor(0.8960)


In [18]:
class IndexingData:
    def __call__(self , data_labels : pd.DataFrame ):
        index = {
            0 : [] ,
            1 : [] ,
            2 : [] ,
            3 : [] ,
        }

        for i in range(len(data_labels)):

            label = int(data_labels.iloc[i].label)

            index[label].append(i)

        minimum = min([len(index[0]) , len(index[1]) , len(index[2]) , len(index[3]) ])

        data = []

        for _class in index:

            labels = index[_class]

            random.shuffle(labels)

            data = data + labels[0:minimum]

        return data

class EmbeddingEncoder:
    def __call__(
        self,
        data : torch.Tensor ,
        format : str = "sum",
        stride : int = 2 ,
        padding_format : str = "concat"
    ) -> torch.Tensor :
        # {
        #    data : [ batch , max_length ]
        # }

        encoder_all = torch.zeros(data.size() , dtype=torch.float)

        for i in range(data.size(1)):

            _ = [
                data[: , i : i + stride ]
            ]

            if _[0].size(1) < stride :
                _.append(data[: , 0 : stride - _[0].size(1)])

            if format == "tanh": # => stride => 10
                encoder_all[: , i] = torch.cat(_ , dim=-1).float().tanh().mean(dim=-1)

            elif format == "sinh": # => stride => 10
                encoder_all[: , i] = torch.cat(_ , dim=-1).float().sinh().mean(dim=-1)

            elif format == "softmax": # => stride => 2
                encoder_all[: , i] = torch.cat(_ , dim=-1).float().softmax(dim=-1).sum(dim=-1).softmax(dim=-1)

        return encoder_all

class Datasets(Dataset):
    def __init__(self ,
        datasets_HR : pd.DataFrame ,
        datasets_OX : pd.DataFrame ,
        datasets_SaO2 : pd.DataFrame ,
        datasets_labels : pd.DataFrame ,
        indexing : list ,
    ) -> None:
        super().__init__()

        self._HR , self._OX , self._SaO2 , self._labels = datasets_HR , datasets_OX , datasets_SaO2 , datasets_labels

        self._indexing = indexing

        self._ModelEncoder = EmbeddingEncoder()

    def __len__(self):
        return len(self._indexing)

    def __getitem__(self, index : int ):

        index = self._indexing[index]

        return {
            "HR" : torch.from_numpy(self._HR.iloc[index].values).float().flatten() ,
            "OX" : torch.from_numpy(self._OX.iloc[index].values).float().flatten() ,
            "SaO2" : torch.from_numpy(self._SaO2.iloc[index].values).float().flatten() ,
            "labels" : torch.from_numpy(self._labels.iloc[index].values).float().flatten() ,
        }

datasets_labels = pd.read_pickle('/content/label_all.pkl')
indexing = IndexingData()

data = Datasets(
    datasets_HR= pd.read_pickle('/content/HR_all.pkl')  ,
    datasets_OX= pd.read_pickle('/content/OX_all.pkl') ,
    datasets_SaO2= pd.read_pickle('/content/SaO2_all.pkl') ,
    datasets_labels= datasets_labels ,
    indexing=indexing(datasets_labels.rename(columns={0 : 'label'}))
)

train_data, valid_data = random_split(data, [0.8 , 0.2] , generator=torch.Generator().manual_seed(42))

train_dataloder =  DataLoader(train_data, batch_size= 32 , shuffle= True)
valid_dataloder =  DataLoader(valid_data, batch_size= 32 , shuffle= True)

In [19]:
class SpatialAttentionBlock(nn.Module):
    def __init__(self , in_channel : int = 1024):
        super().__init__()

        self._model = nn.Sequential(
            nn.Conv1d(in_channels=in_channel, out_channels=1, kernel_size=3, stride=1, padding=1, bias=False),
            nn.Sigmoid()
        )

    def forward(self , input_ids : torch.Tensor ):
        return input_ids * self._model(input_ids)

class ChannelAttention(nn.Module):
    def __init__(self,  new_channels : int , last_channels : int):
        super().__init__()

        self._avgpool = nn.AdaptiveAvgPool1d(1)

        self._attn = nn.Sequential(
            nn.Linear(new_channels , last_channels , bias=False),
            nn.ReLU(),
            nn.Linear(last_channels , new_channels , bias=False),
            nn.Sigmoid()
        )

    def forward(self, input_ids : torch.Tensor ):

        _avg = self._avgpool(input_ids)

        return input_ids * self._attn(_avg.flatten(1)).view_as(_avg)

class DualAttention(nn.Module):
    def __init__(self, new_channels : int , last_channels : int ):
        super().__init__()

        self.channel_attn = ChannelAttention(new_channels=new_channels , last_channels=last_channels)
        self.spatial_attn = SpatialAttentionBlock(in_channel=new_channels)

    def forward(self, input_ids : torch.Tensor):
        return  self.spatial_attn(
            self.channel_attn(
                input_ids
            )
        )

class BlockEncoderCnn(nn.Module):
    def __init__(self , in_channels , out_channels , kernel_size : int = 3 , stride : int = 1 , padding : int = 1 , level : int = 0):
        super().__init__()

        if level == 0 :
            self._model = nn.Sequential(
                nn.Conv1d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size , stride=stride , padding=padding , bias=False ),
                nn.BatchNorm1d( num_features=out_channels ),
                nn.LeakyReLU( inplace = True ),
                nn.Conv1d( out_channels, out_channels, kernel_size=1 ) ,
                DualAttention( new_channels= out_channels , last_channels= in_channels )
            )

            self._activation = nn.Sequential(
                nn.BatchNorm1d(num_features=out_channels) ,
                nn.ReLU() ,
                nn.MaxPool1d(kernel_size=2, stride=2)
            )

        elif level == 1 :

            self._model = nn.Sequential(
                nn.Conv1d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size , stride=stride , padding=padding , bias=False ),
                nn.BatchNorm1d( num_features=out_channels ),
                nn.LeakyReLU( inplace = True )
            )

            self._activation = nn.MaxPool1d(kernel_size=2, stride=2)

        else :

            self._model = nn.Sequential(
                nn.Conv1d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size , stride=stride , padding=padding , bias=False ),
                DualAttention( new_channels= out_channels , last_channels= in_channels )
            )

            self._activation = nn.Sequential(
                nn.BatchNorm1d(num_features=out_channels) ,
                nn.ReLU() ,
                nn.MaxPool1d(kernel_size=2, stride=2)
            )

    def forward(self , input_ids : torch.Tensor ):
        return self._activation(self._model(input_ids))

class EncoderCnnExtra(nn.Module):
    def __init__(self, level : int = 0):
        super().__init__()
        self._model = nn.Sequential(
            BlockEncoderCnn( in_channels=1   , out_channels=32   , kernel_size=3  , stride=1 , padding=1 , level=level),
            BlockEncoderCnn( in_channels=32  , out_channels=64   , kernel_size=3  , stride=1 , padding=1 , level=level),
            BlockEncoderCnn( in_channels=64  , out_channels=128  , kernel_size=15 , stride=2 , padding=2 , level=level),
            BlockEncoderCnn( in_channels=128 , out_channels=256  , kernel_size=15 , stride=2 , padding=2 , level=level),
            BlockEncoderCnn( in_channels=256 , out_channels=512  , kernel_size=11 , stride=2 , padding=2 , level=level),
            BlockEncoderCnn( in_channels=512 , out_channels=1024 , kernel_size=3  , stride=1 , padding=1 , level=level),
            BlockEncoderCnn( in_channels=1024 ,out_channels=2048 , kernel_size=3  , stride=1 , padding=1 , level=level),
            nn.AdaptiveAvgPool1d(1),
        )

    def forward(self , input_ids : torch.Tensor ):
        return self._model(input_ids.unsqueeze(1)).permute(0 , 2 , 1)

In [20]:
class Model(nn.Module):
    def __init__( self , num_class : int = 4 ):
        super().__init__()

        self._cnn = nn.ModuleList([EncoderCnnExtra(level=level) for level in range(3)])

        self._embedding_position = nn.ModuleList([nn.LSTM(2048 , 256 , batch_first=True ) for _ in range(3)])

        self._attn = nn.MultiheadAttention(embed_dim=256 * 3 , num_heads=3 , dropout=0.2 , batch_first=True )

        self._model = nn.Sequential(
            nn.Linear(256 * 3 , num_class) ,
            nn.Dropout(0.1) ,
            nn.Softmax(dim=-1)
        )

    def forward(self ,
        HR : torch.Tensor ,
        OX : torch.Tensor ,
        saO2 : torch.Tensor
    ):
        hidden = []

        for i , data in enumerate([HR , OX , saO2]):

            hidden.append(
                self._embedding_position[i](
                    self._cnn[i](data)
                )[0]
            )

        hidden = torch.cat(hidden , dim=-1).squeeze(1)

        return self._model(
            self._attn(
                hidden , hidden , hidden
            )[0]
        )


In [21]:
model = Model(num_class=4).to(device)

In [22]:
(
    model._cnn[0].load_state_dict(torch.load("/content/data_best_0_HR.pt")["model"] , strict=False ) ,
    model._cnn[1].load_state_dict(torch.load("/content/data_best_0_OX.pt")["model"] , strict=False ) ,
    model._cnn[2].load_state_dict(torch.load("/content/data_best_0_SaO2.pt")["model"] , strict=False ) ,
)

(<All keys matched successfully>,
 _IncompatibleKeys(missing_keys=[], unexpected_keys=['_model.0._model.3.weight', '_model.0._model.3.bias', '_model.0._model.4.channel_attn._attn.0.weight', '_model.0._model.4.channel_attn._attn.2.weight', '_model.0._model.4.spatial_attn._model.0.weight', '_model.0._activation.0.weight', '_model.0._activation.0.bias', '_model.0._activation.0.running_mean', '_model.0._activation.0.running_var', '_model.0._activation.0.num_batches_tracked', '_model.1._model.3.weight', '_model.1._model.3.bias', '_model.1._model.4.channel_attn._attn.0.weight', '_model.1._model.4.channel_attn._attn.2.weight', '_model.1._model.4.spatial_attn._model.0.weight', '_model.1._activation.0.weight', '_model.1._activation.0.bias', '_model.1._activation.0.running_mean', '_model.1._activation.0.running_var', '_model.1._activation.0.num_batches_tracked', '_model.2._model.3.weight', '_model.2._model.3.bias', '_model.2._model.4.channel_attn._attn.0.weight', '_model.2._model.4.channel_attn.

In [23]:
model._cnn.requires_grad_(False)

ModuleList(
  (0): EncoderCnnExtra(
    (_model): Sequential(
      (0): BlockEncoderCnn(
        (_model): Sequential(
          (0): Conv1d(1, 32, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
          (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): LeakyReLU(negative_slope=0.01, inplace=True)
          (3): Conv1d(32, 32, kernel_size=(1,), stride=(1,))
          (4): DualAttention(
            (channel_attn): ChannelAttention(
              (_avgpool): AdaptiveAvgPool1d(output_size=1)
              (_attn): Sequential(
                (0): Linear(in_features=32, out_features=1, bias=False)
                (1): ReLU()
                (2): Linear(in_features=1, out_features=32, bias=False)
                (3): Sigmoid()
              )
            )
            (spatial_attn): SpatialAttentionBlock(
              (_model): Sequential(
                (0): Conv1d(32, 1, kernel_size=(3,), stride=(1,), padding=(1,), bias=Fal

In [24]:
class FocalLoss(nn.Module):
    def __init__(self, alpha=1.0, gamma=2.0, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """
        inputs: logits با shape [batch_size, num_classes]
        targets: اندیس کلاس صحیح با shape [batch_size]
        """
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss


In [25]:
optimizer = torch.optim.AdamW(model.parameters() , lr=0.00001 ) # lr/epoch
loss_func = FocalLoss().to(device)

In [27]:
from tqdm import tqdm

_lossed = []
_acc_train = []
_acc_test = []

best_acc = 0

for epoch in range(30):

    _list = []
    acc = []
    acc_train = []

    model.train()

    for data in tqdm(train_dataloder):

        predictions = model(data["HR"].to(device) , data["OX"].to(device) , data["SaO2"].to(device))

        loss = loss_func(predictions, data["labels"].squeeze(1).long().to(device))

        predict_classes_val = torch.argmax(predictions, dim=1)

        acc_train.append(torch.sum(predict_classes_val.cpu() == data["labels"].squeeze(1).long()) / predict_classes_val.size(0))

        _list.append(loss.item())

        optimizer.zero_grad()

        loss.backward()
        _ = nn.utils.clip_grad_norm_(model.parameters(), 50.0)

        optimizer.step()

    model.eval()

    for data in tqdm(valid_dataloder):

        with torch.no_grad():

            predictions = model(data["HR"].to(device) , data["OX"].to(device) , data["SaO2"].to(device))

            predict_classes_val = torch.argmax(predictions, dim=1)

        acc.append(torch.sum(predict_classes_val.cpu() == data["labels"].squeeze(1).long()) / predict_classes_val.size(0))

    print("epoch", epoch , "train_loss => " , sum(_list) / len(_list) , "acc Train => " , sum(acc_train) / len(acc_train) , "acc Test => " , sum(acc) / len(acc))

    if sum(acc) / len(acc) > best_acc:

        torch.save({
            "model": model.state_dict()
        } , "./data_best.pt")

        best_acc = sum(acc) / len(acc)

        print("best acc test " , best_acc)

    _lossed.append(_list)
    _acc_train.append(acc_train)
    _acc_test.append(acc)


100%|██████████| 68/68 [00:13<00:00,  5.07it/s]
100%|██████████| 17/17 [00:03<00:00,  5.56it/s]


epoch 0 train_loss =>  0.7131936655325049 acc Train =>  tensor(0.3973) acc Test =>  tensor(0.4703)
best acc test  tensor(0.4703)


100%|██████████| 68/68 [00:13<00:00,  5.07it/s]
100%|██████████| 17/17 [00:05<00:00,  2.96it/s]


epoch 1 train_loss =>  0.644093403044869 acc Train =>  tensor(0.4831) acc Test =>  tensor(0.6466)
best acc test  tensor(0.6466)


100%|██████████| 68/68 [00:16<00:00,  4.22it/s]
100%|██████████| 17/17 [00:03<00:00,  5.51it/s]


epoch 2 train_loss =>  0.5746541487820008 acc Train =>  tensor(0.5998) acc Test =>  tensor(0.6931)
best acc test  tensor(0.6931)


100%|██████████| 68/68 [00:13<00:00,  5.06it/s]
100%|██████████| 17/17 [00:03<00:00,  5.23it/s]


epoch 3 train_loss =>  0.52372186850099 acc Train =>  tensor(0.6512) acc Test =>  tensor(0.7489)
best acc test  tensor(0.7489)


100%|██████████| 68/68 [00:13<00:00,  5.05it/s]
100%|██████████| 17/17 [00:03<00:00,  5.54it/s]


epoch 4 train_loss =>  0.47157972744282556 acc Train =>  tensor(0.7150) acc Test =>  tensor(0.7779)
best acc test  tensor(0.7779)


100%|██████████| 68/68 [00:13<00:00,  5.04it/s]
100%|██████████| 17/17 [00:03<00:00,  5.52it/s]


epoch 5 train_loss =>  0.45351137659128976 acc Train =>  tensor(0.7325) acc Test =>  tensor(0.8303)
best acc test  tensor(0.8303)


100%|██████████| 68/68 [00:14<00:00,  4.54it/s]
100%|██████████| 17/17 [00:03<00:00,  5.23it/s]


epoch 6 train_loss =>  0.43413127169889565 acc Train =>  tensor(0.7491) acc Test =>  tensor(0.8465)
best acc test  tensor(0.8465)


100%|██████████| 68/68 [00:13<00:00,  5.05it/s]
100%|██████████| 17/17 [00:03<00:00,  5.55it/s]


epoch 7 train_loss =>  0.42448671688051787 acc Train =>  tensor(0.7628) acc Test =>  tensor(0.8444)


100%|██████████| 68/68 [00:14<00:00,  4.85it/s]
100%|██████████| 17/17 [00:03<00:00,  5.08it/s]


epoch 8 train_loss =>  0.41095390609082055 acc Train =>  tensor(0.7752) acc Test =>  tensor(0.8698)
best acc test  tensor(0.8698)


100%|██████████| 68/68 [00:13<00:00,  5.04it/s]
100%|██████████| 17/17 [00:03<00:00,  5.24it/s]


epoch 9 train_loss =>  0.3910312135429943 acc Train =>  tensor(0.8016) acc Test =>  tensor(0.9097)
best acc test  tensor(0.9097)


100%|██████████| 68/68 [00:13<00:00,  5.05it/s]
100%|██████████| 17/17 [00:03<00:00,  5.52it/s]


epoch 10 train_loss =>  0.3813592782791923 acc Train =>  tensor(0.8080) acc Test =>  tensor(0.8978)


100%|██████████| 68/68 [00:13<00:00,  5.06it/s]
100%|██████████| 17/17 [00:03<00:00,  5.51it/s]


epoch 11 train_loss =>  0.3815799478222342 acc Train =>  tensor(0.8117) acc Test =>  tensor(0.9134)
best acc test  tensor(0.9134)


100%|██████████| 68/68 [00:13<00:00,  5.05it/s]
100%|██████████| 17/17 [00:03<00:00,  5.39it/s]


epoch 12 train_loss =>  0.3595693050061955 acc Train =>  tensor(0.8369) acc Test =>  tensor(0.9061)


100%|██████████| 68/68 [00:13<00:00,  5.03it/s]
100%|██████████| 17/17 [00:03<00:00,  5.36it/s]


epoch 13 train_loss =>  0.3644108382218024 acc Train =>  tensor(0.8271) acc Test =>  tensor(0.9281)
best acc test  tensor(0.9281)


100%|██████████| 68/68 [00:13<00:00,  5.04it/s]
100%|██████████| 17/17 [00:03<00:00,  5.51it/s]


epoch 14 train_loss =>  0.36348677777192173 acc Train =>  tensor(0.8325) acc Test =>  tensor(0.9059)


100%|██████████| 68/68 [00:13<00:00,  5.06it/s]
100%|██████████| 17/17 [00:03<00:00,  5.55it/s]


epoch 15 train_loss =>  0.3508837759933051 acc Train =>  tensor(0.8485) acc Test =>  tensor(0.9167)


100%|██████████| 68/68 [00:13<00:00,  5.07it/s]
100%|██████████| 17/17 [00:03<00:00,  5.22it/s]


epoch 16 train_loss =>  0.3560873164850123 acc Train =>  tensor(0.8369) acc Test =>  tensor(0.9257)


100%|██████████| 68/68 [00:13<00:00,  5.05it/s]
100%|██████████| 17/17 [00:03<00:00,  5.53it/s]


epoch 17 train_loss =>  0.35067971576662627 acc Train =>  tensor(0.8388) acc Test =>  tensor(0.8726)


100%|██████████| 68/68 [00:13<00:00,  5.07it/s]
100%|██████████| 17/17 [00:03<00:00,  5.58it/s]


epoch 18 train_loss =>  0.3557818987790276 acc Train =>  tensor(0.8339) acc Test =>  tensor(0.9204)


100%|██████████| 68/68 [00:13<00:00,  5.05it/s]
100%|██████████| 17/17 [00:03<00:00,  5.47it/s]


epoch 19 train_loss =>  0.3464997866574456 acc Train =>  tensor(0.8441) acc Test =>  tensor(0.9132)


100%|██████████| 68/68 [00:13<00:00,  5.05it/s]
100%|██████████| 17/17 [00:03<00:00,  5.30it/s]


epoch 20 train_loss =>  0.3451533435898669 acc Train =>  tensor(0.8458) acc Test =>  tensor(0.9210)


100%|██████████| 68/68 [00:13<00:00,  5.06it/s]
100%|██████████| 17/17 [00:03<00:00,  5.49it/s]


epoch 21 train_loss =>  0.34018261599190097 acc Train =>  tensor(0.8603) acc Test =>  tensor(0.9224)


100%|██████████| 68/68 [00:13<00:00,  5.05it/s]
100%|██████████| 17/17 [00:03<00:00,  5.51it/s]


epoch 22 train_loss =>  0.3473982959985733 acc Train =>  tensor(0.8476) acc Test =>  tensor(0.9130)


100%|██████████| 68/68 [00:13<00:00,  5.05it/s]
100%|██████████| 17/17 [00:03<00:00,  5.36it/s]


epoch 23 train_loss =>  0.3347255758502904 acc Train =>  tensor(0.8587) acc Test =>  tensor(0.9022)


100%|██████████| 68/68 [00:13<00:00,  5.01it/s]
100%|██████████| 17/17 [00:03<00:00,  5.44it/s]


epoch 24 train_loss =>  0.3394908694659962 acc Train =>  tensor(0.8542) acc Test =>  tensor(0.9165)


100%|██████████| 68/68 [00:13<00:00,  5.02it/s]
100%|██████████| 17/17 [00:03<00:00,  5.53it/s]


epoch 25 train_loss =>  0.3394211079267895 acc Train =>  tensor(0.8499) acc Test =>  tensor(0.9318)
best acc test  tensor(0.9318)


100%|██████████| 68/68 [00:13<00:00,  5.05it/s]
100%|██████████| 17/17 [00:03<00:00,  5.51it/s]


epoch 26 train_loss =>  0.33695056661963463 acc Train =>  tensor(0.8538) acc Test =>  tensor(0.8979)


100%|██████████| 68/68 [00:15<00:00,  4.50it/s]
100%|██████████| 17/17 [00:03<00:00,  5.35it/s]


epoch 27 train_loss =>  0.3433745665585293 acc Train =>  tensor(0.8470) acc Test =>  tensor(0.9204)


100%|██████████| 68/68 [00:13<00:00,  5.07it/s]
100%|██████████| 17/17 [00:03<00:00,  5.52it/s]


epoch 28 train_loss =>  0.34949028294752627 acc Train =>  tensor(0.8367) acc Test =>  tensor(0.9071)


100%|██████████| 68/68 [00:13<00:00,  5.04it/s]
100%|██████████| 17/17 [00:03<00:00,  5.51it/s]

epoch 29 train_loss =>  0.34590739844476476 acc Train =>  tensor(0.8450) acc Test =>  tensor(0.9241)





In [34]:
import torch
import numpy as np
from tqdm import tqdm
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    roc_auc_score,
    roc_curve,
    precision_recall_curve,
    auc,
    cohen_kappa_score,
    f1_score
)
import matplotlib.pyplot as plt
import seaborn as sns
import os

# ==== Configuration ====
save_dir = "results_plots"
os.makedirs(save_dir, exist_ok=True)

# --- Data collection ---
all_real = []
all_predict = []
all_probs = []

model.eval()
with torch.no_grad():
    for data in tqdm(valid_dataloder):
        inputs = [data["HR"].to(device), data["OX"].to(device), data["SaO2"].to(device)]
        labels = data["labels"].to(device)

        outputs = model(*inputs)  # logits
        probs = torch.softmax(outputs, dim=1)
        preds = torch.argmax(probs, dim=1)

        all_real.append(labels.cpu())
        all_predict.append(preds.cpu())
        all_probs.append(probs.cpu())

# --- Convert to numpy arrays ---
all_real = torch.cat(all_real).numpy()
all_predict = torch.cat(all_predict).numpy()
all_probs = torch.cat(all_probs).numpy()

num_classes = all_probs.shape[1]
class_names = [f"Class {i}" for i in range(num_classes)]  # update if you have real names

# --- 1️⃣ Classification report ---
print("\n--- Classification Report ---")
print(classification_report(all_real, all_predict, target_names=class_names, digits=4))

# --- 2️⃣ Confusion Matrix ---
cm = confusion_matrix(all_real, all_predict)
plt.figure(figsize=(6, 5))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("Confusion Matrix")
plt.tight_layout()
plt.savefig(os.path.join(save_dir, "confusion_matrix.png"), dpi=300)
plt.close()

# --- 3️⃣ Extra metrics ---
kappa = cohen_kappa_score(all_real, all_predict)

# ✅ Specificity per class
specificity = []
for i in range(num_classes):
    tn = np.sum(np.delete(np.delete(cm, i, axis=0), i, axis=1))
    fp = np.sum(np.delete(cm, i, axis=0)[:, i])
    spec = tn / (tn + fp) if (tn + fp) > 0 else 0
    specificity.append(spec)

avg_spec = np.mean(specificity)  # Average specificity (macro)

# --- 4️⃣ F1 per class and MF1 ---
f1_per_class = f1_score(all_real, all_predict, average=None)
avg_mf1 = f1_score(all_real, all_predict, average='macro')

# --- 5️⃣ ROC & PR Curves ---
auc_per_class = []
auprc_per_class = []

plt.figure(figsize=(7, 6))
for i, cname in enumerate(class_names):
    y_true = (all_real == i).astype(int)
    y_score = all_probs[:, i]
    fpr, tpr, _ = roc_curve(y_true, y_score)
    roc_auc = roc_auc_score(y_true, y_score)
    auc_per_class.append(roc_auc)
    plt.plot(fpr, tpr, label=f"{cname} (AUC={roc_auc:.3f})")

plt.plot([0, 1], [0, 1], 'k--', lw=1)
plt.title("ROC Curves (All Classes)")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.legend()
plt.tight_layout()
plt.savefig(os.path.join(save_dir, "roc_curves_all_classes.png"), dpi=300)
plt.close()

# --- 6️⃣ AUPRC Curves ---
plt.figure(figsize=(7, 6))
for i, cname in enumerate(class_names):
    y_true = (all_real == i).astype(int)
    y_score = all_probs[:, i]
    precision, recall, _ = precision_recall_curve(y_true, y_score)
    pr_auc = auc(recall, precision)
    auprc_per_class.append(pr_auc)
    plt.plot(recall, precision, label=f"{cname} (AUPRC={pr_auc:.3f})")

plt.title("Precision-Recall Curves (All Classes)")
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.legend()
plt.tight_layout()
plt.savefig(os.path.join(save_dir, "precision_recall_curves_all_classes.png"), dpi=300)
plt.close()

# --- 7️⃣ Print Summary ---
avg_auc = np.mean(auc_per_class)
avg_auprc = np.mean(auprc_per_class)

print(f"\nCohen's Kappa: {kappa:.4f}")
print(f"Average AUC: {avg_auc:.4f}")
print(f"Average AUPRC: {avg_auprc:.4f}")
print(f"Average Macro-F1 (MF1): {avg_mf1:.4f}")
print(f"Average Specificity (Spec): {avg_spec:.4f}\n")

print("Per-class metrics:")
for i, cname in enumerate(class_names):
    print(f"{cname}: F1={f1_per_class[i]:.3f}, Spec={specificity[i]:.3f}, AUC={auc_per_class[i]:.3f}, AUPRC={auprc_per_class[i]:.3f}")

print(f"\n✅ All figures saved to: {os.path.abspath(save_dir)}")


100%|██████████| 17/17 [00:03<00:00,  5.53it/s]



--- Classification Report ---
              precision    recall  f1-score   support

     Class 0     0.9489    0.9091    0.9286       143
     Class 1     0.9015    0.9225    0.9119       129
     Class 2     0.8831    0.9444    0.9128       144
     Class 3     0.9746    0.9200    0.9465       125

    accuracy                         0.9242       541
   macro avg     0.9270    0.9240    0.9249       541
weighted avg     0.9260    0.9242    0.9245       541


Cohen's Kappa: 0.8988
Average AUC: 0.9882
Average AUPRC: 0.9739
Average Macro-F1 (MF1): 0.9249
Average Specificity (Spec): 0.9746

Per-class metrics:
Class 0: F1=0.929, Spec=0.982, AUC=0.990, AUPRC=0.979
Class 1: F1=0.912, Spec=0.968, AUC=0.984, AUPRC=0.963
Class 2: F1=0.913, Spec=0.955, AUC=0.988, AUPRC=0.975
Class 3: F1=0.947, Spec=0.993, AUC=0.990, AUPRC=0.979

✅ All figures saved to: /content/results_plots
