In [None]:
from customdataset import *
from SSD import *
from train_step import *
from eval_step import test_step
import os
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
import time
import warnings 
# Silences every warning (all categories, all modules) for the rest of the kernel
# session. NOTE(review): this also hides deprecation and numerical warnings from
# torch / sklearn; consider narrowing it to the category or module that was noisy,
# e.g. warnings.filterwarnings('ignore', category=UserWarning).
warnings.filterwarnings('ignore')

In [None]:
# Seed torch's global RNG so DataLoader shuffling (and, presumably, model weight
# initialisation) is repeatable across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# --- Audio / spectrogram parameters ------------------------------------------
SAMPLE_RATE = 8000
# Hop length = ceil(SAMPLE_RATE * 0.01) samples, i.e. 10 ms (time resolution).
# Integer ceil-division avoids float rounding in `SAMPLE_RATE * 0.01` and removes
# this cell's dependence on `np` leaking in through a star import.
HOP_LENGTH = int(-(-SAMPLE_RATE // 100))
# Window length = ceil(SAMPLE_RATE / 15) samples, i.e. 1/15 s (frequency resolution).
WIN_LENGTH = int(-(-SAMPLE_RATE // 15))
# FFT size: the smallest power of two >= WIN_LENGTH (never below 2). This is
# exactly what the former manual halving loop over WIN_LENGTH - 1 computed.
N_FFT = 2 ** max((WIN_LENGTH - 1).bit_length(), 1)
# Fixed mel-bin count. Previously considered alternative, kept for reference:
#   300 if N_FFT // 2 + 1 > 300 else N_FFT // 2 + 1
N_MELS = 128

print(SAMPLE_RATE, HOP_LENGTH, WIN_LENGTH, N_FFT, N_MELS)

# Candidate biquad filters. NOTE(review): the tuple layout looks like
# (filter type, frequency [Hz], sample rate, Q[, gain]) — confirm against the
# Biquad helper provided by the star imports. args["filter_params"] picks
# filters[1] (HIGHPASS) and filters[0] (LOWPASS).
filters = [
    (Biquad.LOWPASS, 400, SAMPLE_RATE, 1.0),
    (Biquad.HIGHPASS, 25, SAMPLE_RATE, 1.0),
    (Biquad.BANDPASS, 125, SAMPLE_RATE, 1.0),
    (Biquad.PEAK, 125, SAMPLE_RATE, 1.0),
    (Biquad.NOTCH, 125, SAMPLE_RATE, 1.0, 1.0),
    (Biquad.LOWSHELF, 200, SAMPLE_RATE, 1.0, 1.0),
    (Biquad.HIGHSHELF, 125, SAMPLE_RATE, 1.0, 1.0)
]

In [None]:
# Hyper-parameters shared by dataset construction, training and evaluation.
args = {
    # --- spectrogram / feature extraction ---
    "SR": SAMPLE_RATE,
    "HL": HOP_LENGTH,
    "WL": WIN_LENGTH,
    "n_FFT": N_FFT,
    "n_MELS": N_MELS,

    # --- preprocessing / augmentation ---
    # NOTE(review): "augmentation", "freq_mask", "time_mask", "multi_channels" and
    # "cutting" are not passed to CustomDataset in this notebook; presumably they
    # are read (or only logged) by train_step — confirm.
    "augmentation": False,
    "filter_params": [filters[1], filters[0]],  # HIGHPASS then LOWPASS
    "padding_type": 0,
    "freq_mask": False,
    "time_mask": False,
    "multi_channels": False,
    "clipping": True,
    "target_size": (300, 300),  # spectrogram image size given to CustomDataset
    "th": 5,  # NOTE(review): semantics defined inside CustomDataset — document there
    "cutting": True,

    # --- model / optimisation ---
    "MODEL_NAME": "MnetSSD",
    # A real bool. The former string "False" is truthy, so any `if args["is_freeze"]:`
    # check would have frozen the backbone. f-string / str() formatting (e.g. in
    # checkpoint names) still renders this as "False".
    "is_freeze": False,
    "epoch_num": 100,
    "batch_size": 8,
    "min_lr": 1e-4,
    "max_lr": 2e-3,
    "optim_type": "Adam",

    # --- detection post-processing / evaluation thresholds ---
    "conf_thresh": 0.6,
    "nms_thresh": 0.5,
    "iou_thresh": 0.7,
}



In [None]:
# Dataset root under the user's home directory. expanduser("~") works even when
# $HOME is unset; os.getenv("HOME") would return None there and `None + str`
# raises TypeError.
PATH = os.path.join(
    os.path.expanduser("~"),
    "aiffel/ECG_data/physionet.org/files/circor-heart-sound/1.0.3/training_data",
)

# os.listdir returns entries in filesystem-dependent order. Sort them so the seeded
# split below is reproducible across machines.
# NOTE: this may assign patients to different splits than runs made before sorting
# was added. Re-train / re-evaluate instead of mixing old checkpoints with the new
# split, to avoid test-set leakage.
txt_list = sorted(
    os.path.join(PATH, file_name)
    for file_name in os.listdir(PATH)
    if file_name.endswith(".txt")
)

# Split the per-patient .txt records into train / valid / test (60 / 20 / 20).
train_patient_txt, extra_patient_txt = train_test_split(txt_list, test_size=0.4, random_state=42)
valid_patient_txt, test_patient_txt = train_test_split(extra_patient_txt, test_size=0.5, random_state=42)

print(len(txt_list), len(train_patient_txt),
      len(valid_patient_txt), len(test_patient_txt))

In [None]:
#print(test_patient_txt[0]) 
#test_patient_txt=[test_patient_txt[0]]

In [None]:
def my_collate_fn(batch):
    """Collate ``(image, annotation)`` samples into one detection batch.

    Images are stacked along a new leading dimension, so they must all share
    one shape. Annotations are returned as a list of per-sample float tensors
    rather than stacked; presumably they vary in length (e.g. number of boxes).

    Args:
        batch: sequence of samples where ``sample[0]`` is the image tensor and
            ``sample[1]`` is its annotation array.

    Returns:
        ``(images, targets)``:
        - ``images``: the stacked image tensor.
        - ``targets``: a list with one ``torch.FloatTensor`` per sample.
    """
    targets = [torch.FloatTensor(sample[1]) for sample in batch]
    images = torch.stack([sample[0] for sample in batch], dim=0)
    return images, targets
# Single source of truth for the batch size. Reading it from args keeps the
# DataLoaders in sync with anything that reads args["batch_size"] (e.g. train_step,
# presumably for logging / checkpoint names).
BATCHSIZE = args["batch_size"]

In [None]:

def make_dataloader(patient_txt, shuffle, batch_size=None):
    """Build a DataLoader over the recordings listed in ``patient_txt``.

    All three splits share the same CustomDataset configuration (taken from
    ``args``). It is defined once here instead of in three copy-pasted cells.

    Args:
        patient_txt: list of patient ``.txt`` paths for this split.
        shuffle: whether the DataLoader reshuffles each epoch. Use True for
            training only, so evaluation order stays fixed and repeatable.
        batch_size: overrides the notebook-level ``BATCHSIZE`` when given.

    Returns:
        A ``DataLoader`` whose batches are produced by ``my_collate_fn``.
    """
    dataset = CustomDataset(PATH, patient_txt,
                            sample_rate=args['SR'],
                            hop_length=args['HL'],
                            n_mels=args['n_MELS'],
                            n_fft=args['n_FFT'],
                            win_length=args['WL'],
                            filter_params=args["filter_params"],
                            padding_type=args["padding_type"],
                            clipping=args["clipping"],
                            target_size=args["target_size"],
                            th=args["th"])
    return DataLoader(dataset,
                      batch_size=BATCHSIZE if batch_size is None else batch_size,
                      shuffle=shuffle,
                      collate_fn=my_collate_fn)


# Construction time is still reported per split, as before.
s_t = time.perf_counter()
train_dataloader = make_dataloader(train_patient_txt, shuffle=True)
print(time.perf_counter() - s_t)

s_t = time.perf_counter()
# Validation is no longer shuffled: order presumably does not affect aggregate
# metrics, and a fixed order keeps per-sample logs comparable across epochs.
valid_dataloader = make_dataloader(valid_patient_txt, shuffle=False)
print(time.perf_counter() - s_t)

s_t = time.perf_counter()
test_dataloader = make_dataloader(test_patient_txt, shuffle=False)
print(time.perf_counter() - s_t)

# train_step

In [None]:
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Honour the flag recorded in args instead of hard-coding False, so the model
# actually matches what args reports. Both a real bool and the legacy string form
# are accepted. The string "False" is truthy, so it must not be passed through
# unconverted. With is_freeze False / "False" this is False, exactly as before.
IS_FREEZE = args["is_freeze"] in (True, "True")

model = build_model_SSD("Train", input_channels=1, is_freeze=IS_FREEZE)  # network being trained
# Also handed to train_step (presumably for validation-time inference).
test_model = build_model_SSD("Test", input_channels=1)

# Optional: resume training from a saved checkpoint.
# model_weight_path = './objectdetection_model/ssd300_weight_100.pth'
# model.load_state_dict(torch.load(model_weight_path, map_location=DEVICE))

In [None]:
# Train `model` on the training split and validate on the validation split.
# `test_model` is passed as well (presumably for validation-time inference).
# is_wandb=True enables Weights & Biases logging, which presumably needs a
# configured wandb login.
train_step(
    model,
    test_model,
    train_dataloader,
    valid_dataloader,
    args,
    is_wandb=True,
    device=DEVICE,
)

# test_step

In [None]:
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Rebuild `model` in "Test" mode; this replaces the "Train"-mode instance.
model = build_model_SSD("Test", input_channels=1)
# NOTE(review): the checkpoint name is hard-coded. It presumably encodes
# MODEL_NAME / epoch / batch_size / optimiser / is_freeze — keep it in sync with args.
model_weight_path = './objectdetection_model/MnetSSD_weight_101_8_Adam_False.pth'
# map_location remaps stored tensors onto DEVICE. A checkpoint saved from a GPU run
# therefore also loads on a CPU-only machine instead of failing inside torch.load.
# (torch.load unpickles the file — only load checkpoints you trust.)
weight = torch.load(model_weight_path, map_location=DEVICE)

In [None]:
# Load the checkpoint into the evaluation model and switch it to eval mode
# (dropout off, BatchNorm uses running statistics). DEVICE is defined, with the
# same expression, in the cell that loads the checkpoint, so it is not redefined.
# The trailing ';' keeps the notebook from printing the model repr.
model.load_state_dict(weight)
model.eval();

In [None]:
# Evaluate on the held-out test split and report the wall-clock time.
# - No torch.autograd.set_detect_anomaly(True): it is a debugging aid that slows
#   autograd considerably and stays enabled globally for the rest of the session.
# - Inference needs no gradients, so no_grad saves memory and time.
# - image_size is taken from args so it cannot drift from the dataset's target_size.
s = time.perf_counter()
with torch.no_grad():
    result = test_step(model, test_dataloader, image_size=args["target_size"], device=DEVICE)
e = time.perf_counter()
print(e - s)


In [None]:
# Presumed order of the metrics returned by test_step (the earlier note listed
# S2_Recall twice; the 5th entry is presumably S1_Precision — confirm against
# test_step's return statement):
#   (total_Recall, S1_Recall, S2_Recall, total_Precision, S1_Precision, S2_Precision, mAP)
print(result)

In [None]:
# Presumably plots the model's predictions on test-split samples.
# NOTE(review): visualization_step is not imported by name. eval_step contributes
# only test_step explicitly, so this name must come from one of the
# `from ... import *` lines — consider importing it explicitly.
visualization_step(model, test_dataloader, device = DEVICE)

In [None]:
!pip freeze | grep -E "torch" >> requirements.txt
!pip freeze | grep -E "skimage" >> requirements.txt
!pip freeze | grep -E "numpy" >> requirements.txt
!pip freeze | grep -E "librosa" >> requirements.txt
!pip freeze | grep -E "wandb" >> requirements.txt
!pip freeze | grep -E "scipy" >> requirements.txt
!pip freeze | grep -E "time" >> requirements.txt
!pip freeze | grep -E "pandas" >> requirements.txt
!pip freeze | grep -E "matplotlib" >> requirements.txt
!pip freeze | grep -E "cmapy" >> requirements.txt
!pip freeze | grep -E "nlpaug" >> requirements.txt
!pip freeze | grep -E "wandb" >> requirements.txt