In [None]:
import librosa
import librosa.display
import numpy as np
import pandas as pd
import os
from typing import List, Tuple, Dict
import matplotlib.pyplot as plt
import random
import cv2
import pickle
import torch.nn as nn
import torch
import datetime
import torch.nn.functional as F
import torchvision.models as models
from torchvision import transforms, utils, models
from torch.utils.data import Dataset, DataLoader
from facenet_pytorch import MTCNN
import logging
import torch
import torch.nn as nn
from einops import rearrange, reduce
import tqdm

In [None]:
from src.models import ASTModel
import os
import torch

# Create a new class that inherits the original ASTModel class
class ASTModelVis(ASTModel):
    """ASTModel subclass that exposes per-block self-attention maps for visualization."""

    def get_att_map(self, block, x):
        """Compute the softmax attention weights of one Transformer block.

        Args:
            block: Transformer block exposing ``attn.qkv``, ``attn.num_heads``
                and ``attn.scale``.
            x: token tensor of shape (B, N, C).

        Returns:
            Attention weights of shape (B, num_heads, N, N), normalized with a
            softmax over the last axis.
        """
        attention = block.attn
        heads = attention.num_heads
        batch, tokens, channels = x.shape
        # Project to q/k/v and split the heads: -> (3, B, heads, N, C // heads)
        projected = attention.qkv(x)
        projected = projected.reshape(batch, tokens, 3, heads, channels // heads)
        projected = projected.permute(2, 0, 3, 1, 4)
        # Only queries and keys are needed for the weights; the values (index 2) are not used.
        query, key = projected[0], projected[1]
        scores = torch.matmul(query, key.transpose(-2, -1)) * attention.scale
        return torch.softmax(scores, dim=-1)

    def forward_visualization(self, x):
        """Run the Transformer stack and collect the attention map of every block.

        Args:
            x: input of shape (batch_size, time_frame_num, frequency_bins),
                e.g. (12, 1024, 128).

        Returns:
            list: one attention tensor per block in ``self.v.blocks``, in order.
        """
        # (B, T, F) -> (B, 1, T, F) -> (B, 1, F, T)
        x = x.unsqueeze(1).transpose(2, 3)

        batch = x.shape[0]
        patches = self.v.patch_embed(x)
        cls = self.v.cls_token.expand(batch, -1, -1)
        dist = self.v.dist_token.expand(batch, -1, -1)
        # Prepend the two special tokens, then add positions and apply dropout.
        tokens = torch.cat((cls, dist, patches), dim=1)
        tokens = self.v.pos_drop(tokens + self.v.pos_embed)

        att_list = []
        for blk in self.v.blocks:
            # NOTE(review): the map is computed from the block's raw input. If the
            # block normalizes its input before attention, these weights differ from
            # the ones blk(tokens) actually uses - confirm this is intended.
            att_list.append(self.get_att_map(blk, tokens))
            tokens = blk(tokens)
        return att_list

def make_features(wav_name, mel_bins, target_length=1024,
                  norm_mean=-4.2677393, norm_std=4.5689974):
    """Load a 16 kHz wav file and turn it into a fixed-length, normalized fbank tensor.

    Args:
        wav_name: path of the audio file to load.
        mel_bins: number of mel bins requested from the fbank computation.
        target_length: number of frames in the result; shorter inputs are
            zero-padded at the end, longer ones are truncated.
        norm_mean: value subtracted from the features. The default is the
            constant previously hard-coded here (presumably a dataset mean -
            confirm before reusing on other data).
        norm_std: the features are divided by ``2 * norm_std``. Default as above.

    Returns:
        torch.Tensor: 2-D tensor with ``target_length`` rows (frames).

    Raises:
        AssertionError: if the file's sampling rate is not 16000 Hz.
    """
    # Imported locally: the notebook's import cell does not import torchaudio,
    # so the bare name used to raise NameError.
    import torchaudio

    waveform, sr = torchaudio.load(wav_name)
    assert sr == 16000, 'input audio sampling rate must be 16kHz'

    fbank = torchaudio.compliance.kaldi.fbank(
        waveform, htk_compat=True, sample_frequency=sr, use_energy=False,
        window_type='hanning', num_mel_bins=mel_bins, dither=0.0, frame_shift=10)

    n_frames = fbank.shape[0]

    # Positive -> pad that many frames at the end; negative -> cut the excess.
    p = target_length - n_frames
    if p > 0:
        m = torch.nn.ZeroPad2d((0, 0, 0, p))
        fbank = m(fbank)
    elif p < 0:
        fbank = fbank[0:target_length, :]

    fbank = (fbank - norm_mean) / (norm_std * 2)
    return fbank


def load_label(label_csv):
    """Read the label names from a comma-separated label file.

    The first row is treated as a header and skipped. For every following
    row the third column (index 2) is collected. The second column (the
    unique id such as "/m/068hy") is not returned.

    Args:
        label_csv: path of the CSV file.

    Returns:
        list of str: label names in file order.

    Raises:
        IndexError: if a non-empty data row has fewer than three columns.
    """
    # Imported locally: the notebook's import cell does not import csv,
    # so the bare name used to raise NameError.
    import csv

    # newline='' is what the csv module requires so that quoted fields
    # containing line breaks are parsed correctly.
    with open(label_csv, 'r', newline='') as f:
        rows = list(csv.reader(f, delimiter=','))

    # Skip the header and blank lines (csv.reader yields [] for those).
    return [row[2] for row in rows[1:] if row]

# Number of time frames per input spectrogram.
# NOTE(review): the previous comment said 1024 but the value is 1054 - presumably
# it matches the pickled spectrograms loaded later; confirm.
input_tdim = 1054
# Build the visualization-capable model with 5 outputs and no ImageNet/AudioSet pretrained weights
audio_model = ASTModelVis(label_dim=5, input_tdim=input_tdim, imagenet_pretrain=False, audioset_pretrain=False)

# Use the second GPU when CUDA is available, otherwise stay on the CPU
# (same rule as the `device` selection cell) instead of crashing.
audio_model = audio_model.to(torch.device("cuda:1" if torch.cuda.is_available() else "cpu"))

In [None]:
# NOTE: pickle.load can execute arbitrary code while unpickling - only load
# these files if they were produced by a trusted pipeline.
with open('/home/ssrlab/kw/개성형성/audio_spectrogram_transformer/training_set_audio_spectrogram', "rb") as training_file:
    train_set_data = pickle.load(training_file)
with open('/home/ssrlab/kw/개성형성/audio_spectrogram_transformer/validation_set_audio_spectrogram', "rb") as validation_file:
    valid_set_data = pickle.load(validation_file)

In [None]:
def reshape_to_expected_input(dataset: List[Tuple[np.ndarray,np.ndarray]]) -> Tuple[np.ndarray,np.ndarray]:
    """Split a list of samples into two stacked arrays.

    Args:
        dataset: sequence of samples; element 0 and element 1 of each sample
            are collected (any further elements are ignored).

    Returns:
        A tuple ``(x0, x1)`` where ``x0`` stacks every sample's element 0 and
        ``x1`` stacks every sample's element 1 along a new leading axis.

    Raises:
        ValueError: if ``dataset`` is empty, or (from ``np.stack``) if the
            per-sample arrays do not share a shape.
    """
    if len(dataset) == 0:
        # np.stack would also raise ValueError here, but with a message that
        # does not point at the real cause.
        raise ValueError("dataset is empty: nothing to stack")

    first_items = np.stack([sample[0] for sample in dataset])
    second_items = np.stack([sample[1] for sample in dataset])
    return (first_items, second_items)

In [None]:
# Turn each pickled split into a pair of stacked arrays, then release the
# reference to the raw data. Because of the `del`, re-running this cell
# requires re-running the cell that loads the pickles first.
train_input = reshape_to_expected_input(train_set_data)
del train_set_data

valid_input = reshape_to_expected_input(valid_set_data)
del valid_set_data

In [None]:
# Detect whether a CUDA runtime is usable and report it.
USE_CUDA = torch.cuda.is_available()
print(USE_CUDA)

# Second GPU (index 1) when CUDA is available, CPU otherwise.
if USE_CUDA:
    device = torch.device('cuda:1')
else:
    device = torch.device('cpu')
# The printed label is Korean for "device used for training".
print('학습을 진행하는 기기:', device)

In [None]:
# Training hyperparameters
batchsz = 8          # samples per mini-batch
num_workerssz = 5    # DataLoader worker processes
lr = 0.0001          # learning rate handed to the optimizer
epochs = 120         # number of passes over the training loader

class ChalearnDataset(Dataset):
    """Map-style dataset pairing each input array with its target score array.

    Args:
        imagedata: indexable collection of input arrays (one per sample).
        tagdata: indexable collection of target arrays, aligned with
            ``imagedata`` (called "big five" scores in this notebook).
        transform: optional callable applied to the input tensor of every
            sample. ``None`` (the default) returns the tensor unchanged.
    """

    def __init__(self,imagedata,tagdata,transform=None):
        self.imagedata = imagedata
        self.tagdata = tagdata
        self.transform = transform

    def __len__(self):
        """Number of samples, taken from ``imagedata``."""
        return len(self.imagedata)

    def __getitem__(self, idx):
        """Return ``(input_tensor, score_tensor)`` for ``idx`` as float32 tensors."""
        # Samplers may hand over tensor indices; convert them to plain Python.
        if torch.is_tensor(idx):
            idx = idx.tolist()

        image_data = torch.FloatTensor(self.imagedata[idx])
        big_five_scores = torch.FloatTensor(self.tagdata[idx])

        # The transform used to be stored but never applied.
        if self.transform is not None:
            image_data = self.transform(image_data)
        return image_data, big_five_scores

In [None]:
# Wrap the stacked (inputs, targets) arrays of each split in a Dataset.
train_dataset = ChalearnDataset(imagedata=train_input[0],tagdata=train_input[1])
valid_dataset = ChalearnDataset(imagedata=valid_input[0],tagdata=valid_input[1])
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batchsz, shuffle=True, num_workers=num_workerssz)
# Validation is not shuffled: the reported loss is a mean of per-batch means, so
# with a smaller last batch a random batch composition makes it vary between runs.
valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=batchsz, shuffle=False, num_workers=num_workerssz)

In [None]:
# Loss function: mean absolute error (L1), placed on the training device.
criterion = torch.nn.L1Loss()
criterion = criterion.to(device)
# AdamW over all model parameters with the configured learning rate.
optimizer = torch.optim.AdamW(params=audio_model.parameters(), lr=lr)

In [None]:
from contextlib import nullcontext

# Make the GPU chosen in `device` the current CUDA device for the loop.
# On CPU there is nothing to select, so use a no-op context instead of failing.
device_ctx = torch.cuda.device(device) if device.type == 'cuda' else nullcontext()

with device_ctx:

    # Per-epoch mean losses, stored as plain Python floats.
    trainingEpoch_loss = []
    validationEpoch_loss = []

    for i in range(epochs):
        train_avg_loss = 0.0
        val_avg_loss = 0.0

        audio_model.train()
        for image_data, big_five_data in train_dataloader:

            image_data = image_data.to(device)
            # Collapse the trailing target axis with a max: (b, c, d) -> (b, c)
            big_five_data = reduce(big_five_data,'b c d -> b c', 'max')
            big_five_data = big_five_data.to(device)

            optimizer.zero_grad()  # reset the accumulated gradients

            hypothesis = audio_model(image_data)  # model predictions

            loss = criterion(hypothesis, big_five_data)  # loss between predictions and targets

            loss.backward()  # backpropagation
            optimizer.step()  # update the parameters

            # .item() detaches the value from the autograd graph; summing the
            # tensor itself keeps every batch's graph alive for the whole epoch.
            train_avg_loss += loss.item()
        train_avg_loss = train_avg_loss / len(train_dataloader)  # mean of the per-batch losses
        trainingEpoch_loss.append(train_avg_loss)
        print('Epoch = {}, loss = {}'.format(i+1,train_avg_loss))

        with torch.no_grad():  # validation: no gradients needed
            audio_model.eval()
            for image_data, big_five_data in valid_dataloader:

                image_data = image_data.to(device)

                big_five_data = reduce(big_five_data,'b c d -> b c', 'max')
                big_five_data = big_five_data.to(device)

                hypothesis = audio_model(image_data)

                val_loss = criterion(hypothesis, big_five_data)
                val_avg_loss += val_loss.item()

            val_avg_loss = val_avg_loss / len(valid_dataloader)
            validationEpoch_loss.append(val_avg_loss)
            print('Epoch = {}, val_loss = {}, 1 - MAE = {}'.format(i+1,val_avg_loss, 1 - val_avg_loss))
            print('\n')