In [3]:
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from Custom_dataset import CustomDataset


In [4]:
adult_train = pd.read_csv('./dataset/ECG_adult_age_train.csv')
child_train = pd.read_csv('./dataset/ECG_child_age_train.csv')

In [5]:
adult_file_path = './dataset/ECG_adult_numpy_train/'
child_file_path = './dataset/ECG_child_numpy_train/'

In [6]:
adult_train_data, adult_valid_data = train_test_split(adult_train,shuffle=True, random_state=42, test_size=0.1)
adult_train_data.reset_index(inplace=True,drop=True)
adult_valid_data.reset_index(inplace=True,drop=True)

In [7]:
adult_train_dataset = CustomDataset(adult_file_path,adult_train_data)
adult_valid_dataset = CustomDataset(adult_file_path,adult_valid_data)
child_dataset = CustomDataset(child_file_path,child_train)

In [8]:
train_adult_loader = torch.utils.data.DataLoader(adult_train_dataset, batch_size = 16)
valid_adult_loader = torch.utils.data.DataLoader(adult_valid_dataset, batch_size = 16)

In [104]:
import torch
import torch.nn as nn

class Temporalblock(torch.nn.Module):
    def __init__(self, N, K, MP_factor):
        super().__init__()

        self.conv_layers = torch.nn.LazyConv1d(N, K, stride=1, padding = 'same')
        self.batchnorm_layers = torch.nn.BatchNorm1d(N)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool1d(kernel_size=MP_factor, stride=MP_factor)

    def forward(self, x):
        x = self.conv_layers(x)
        x = self.batchnorm_layers(x)
        x = self.relu(x)
        x = self.maxpool(x)
        return x

class FCblock(torch.nn.Module):
    def __init__(self, N, dropout):
        super().__init__()
        self.linear = torch.nn.LazyLinear(N)
        self.batchnorm_layers = torch.nn.BatchNorm1d(N)
        self.relu = torch.nn.ReLU()
        self.dropout = torch.nn.Dropout(dropout)
    
    def forward(self, x):
        x = self.linear(x)
        x = self.batchnorm_layers(x)
        x = self.relu(x)
        x = self.dropout(x)
        return x

class Custom_model(torch.nn.Module):
    def __init__(self,N, K, MP_factor):
        super().__init__()
        self.first_block = Temporalblock(16, 7, 2)
        self.temporal_block = self._make_layer(N, K, MP_factor)
        self.spatial_block = Temporalblock(128, 1, 2)
        self.flatten = torch.nn.Flatten()
        self.fc1 = FCblock(128, 0.2)
        self.fc2 = FCblock(64, 0.2)

        self.linear_out = torch.nn.Linear(64,1)
    def _make_layer(self, N, K, MP_factor):
        layer = []
        for n, k, mp in zip(N, K, MP_factor):
            layer.append(Temporalblock(n, k, mp))
        return torch.nn.Sequential(*layer)
    
    def forward(self,x):
        x = self.first_block(x)
        x = self.temporal_block(x)
        x = self.spatial_block(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.linear_out(x)

        return x

N = (16,32,32,64,64,64,64)
K = (5,5,5,5,3,3,3)
MP_factor = (4,2,4,2,2,2,2)
    
model = Custom_model(N, K, MP_factor)



In [105]:
batch = next(iter(train_adult_loader))[0]
batch = torch.tensor(batch.reshape((16,12,5000)),dtype=torch.float32)
batch = torch.nn.functional.pad(batch,(0,120,0,0),'constant',0)
batch.size()

  batch = torch.tensor(batch.reshape((16,12,5000)),dtype=torch.float32)


In [107]:
model(batch).size()

torch.Size([16, 1])

In [None]:
print('train_run')
for epochs in range(5):
    for batch in tqdm(train_adult_loader):
        inputs = batch[1]['I'].to('cuda')
        label = torch.tensor(batch[2],dtype=torch.float32).to('cuda')
        output = model(torch.tensor(inputs, dtype = torch.float32))
        loss = criterion(output, label)
        
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    valid_loss = 0
    print('valid_run')
    for batch in tqdm(valid_adult_loader):
        inputs = batch[1]['I'].to('cuda')
        label = torch.tensor(batch[2],dtype=torch.float32).to('cuda')
        with torch.no_grad():
            output = model(torch.tensor(inputs, dtype = torch.float32))
        loss = criterion(output, label)
        valid_loss += loss
    print(valid_loss/len(valid_adult_loader))