In [None]:
import sklearn

from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import log_loss
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split

from torchvision.transforms.functional import to_pil_image
from torchvision import models

from PIL import Image
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchsummary import summary

import torchaudio
from tqdm.notebook import trange,tqdm

import numpy as np
import pandas as pd
import os

import matplotlib.pyplot as plt
import matplotlib
%matplotlib inline
import soundfile as sf

import librosa
import librosa.display

import wave

from torchvision.io import read_image

import warnings
warnings.filterwarnings('ignore')


In [None]:
# Pick the compute backend in order of preference: CUDA, then Apple MPS,
# and finally the CPU when no accelerator is available.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using {device} device")

In [None]:
def _read_finger_csv(csv_name):
    """Load one per-finger dataset CSV, using its first column as the index."""
    return pd.read_csv(csv_name, index_col=0)


# One table per fretting finger.
index_df = _read_finger_csv("index_dataset.csv")
middle_df = _read_finger_csv("middle_dataset.csv")
ring_df = _read_finger_csv("ring_dataset.csv")
pinkie_df = _read_finger_csv("pinkie_dataset.csv")

In [None]:
# Number of distinct target labels in each finger's dataset.
# Each count is taken from that finger's own table; previously all four were
# computed from index_df.
# NOTE(review): using nunique() as the classifier output size assumes labels
# are contiguous integers 0..n-1 — TODO confirm against the CSVs.
index_nclass = index_df["target"].nunique()
middle_nclass = middle_df["target"].nunique()
ring_nclass = ring_df["target"].nunique()
pinkie_nclass = pinkie_df["target"].nunique()

# Backward-compatible alias for the original, misspelled variable name.
pinkie_ncalss = pinkie_nclass

In [None]:
# Display the number of target classes computed for the index-finger dataset.
index_nclass

In [None]:
# Display how many rows each target label has in the index-finger dataset.
index_df["target"].value_counts()

In [None]:
# Print the column names, dtypes and non-null counts of the index-finger dataset.
index_df.info()

In [None]:
# Channels 0-2 hold the image data; channel 3 relates to transparency (alpha).
# Show the size of the first row's image after keeping only channels 0-2.
read_image(index_df.iloc[0]["path"])[0:3].size()

In [None]:
# Fixed seed so the 70/30 splits are identical after every kernel restart;
# without it the train/test membership changes on each run.
SPLIT_SEED = 42

# Only the DataFrames are split: the "target" column travels with each row,
# so a separate label array is not needed.
train_index, test_index = train_test_split(index_df, test_size=0.3, random_state=SPLIT_SEED)
train_middle, test_middle = train_test_split(middle_df, test_size=0.3, random_state=SPLIT_SEED)
train_ring, test_ring = train_test_split(ring_df, test_size=0.3, random_state=SPLIT_SEED)
train_pinkie, test_pinkie = train_test_split(pinkie_df, test_size=0.3, random_state=SPLIT_SEED)

In [None]:
class Guitar_Dataset(Dataset):
    """Dataset that pairs a tablature-frame image with its audio image.

    Every row of ``df`` must provide three columns:

    * ``path`` - file path of the audio image; only channels 0-2 are used.
    * ``file_name`` - file name of the frame image. The text before the first
      ``_`` names the sub-directory of ``./2canny_crop_per_tablature_frames``
      that contains the file.
    * ``target`` - the label.

    NOTE(review): this class is defined again in a later cell of this
    notebook; consider keeping a single definition.

    Args:
        df: DataFrame with the columns described above.
        transform: Optional callable applied to the frame image and then to
            the audio image (each converted to a PIL image first).
        target_transform: Optional callable applied to the label.
    """

    def __init__(self, df, transform=None, target_transform=None):
        self.df = df
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of rows (samples) in the backing DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(data, label)`` for the row at position ``idx``.

        ``data`` is the frame image followed by the audio image, concatenated
        along the first (channel) dimension.
        """
        # Look the row up once instead of once per column.
        row = self.df.iloc[idx]
        file_name = row["file_name"]
        name = file_name.split("_")[0]

        audio = read_image(row["path"])[0:3]
        image = read_image("./2canny_crop_per_tablature_frames/%s/%s" % (name, file_name))

        label = row["target"]

        if self.transform:
            # The transform pipeline receives PIL images, as before.
            image = self.transform(to_pil_image(image))
            audio = self.transform(to_pil_image(audio))
        # Without a transform the raw tensors from read_image are kept, so the
        # concatenation below operates on tensors rather than PIL images.
        # It still requires both images to share the same height and width.

        if self.target_transform:
            label = self.target_transform(label)

        data = torch.cat((image, audio))

        return data, label

In [None]:
# Shared preprocessing for both the frame image and the audio image:
# convert to a float tensor, resize to 224x224, then normalise with the
# standard ImageNet per-channel mean / std values.
# NOTE(review): Normalize is given three-channel statistics — confirm every
# image passed through this pipeline really has three channels.
_preprocess_steps = [
    transforms.ToTensor(),
    transforms.Resize(size=(224, 224)),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
]
transform = transforms.Compose(_preprocess_steps)

In [None]:
# Smoke-test the dataset by fetching its first sample.
# __getitem__ returns a (data, label) pair, so exactly two names are unpacked.
# The transform is passed because it resizes both images to 224x224 tensors,
# which the dataset needs in order to concatenate them.
sample_data, sample_label = next(iter(Guitar_Dataset(train_index, transform)))
sample_data.shape, sample_label

In [None]:
def _build_loaders(train_df, test_df, batch_size=4):
    """Return (train_loader, test_loader) for one finger's train/test split.

    The training loader shuffles its samples; the test loader keeps them in
    DataFrame order. Both wrap the rows in Guitar_Dataset with the shared
    `transform` pipeline.
    """
    train_loader = DataLoader(Guitar_Dataset(train_df, transform), batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(Guitar_Dataset(test_df, transform), batch_size=batch_size, shuffle=False)
    return train_loader, test_loader


index_train_dataloader, index_test_dataloader = _build_loaders(train_index, test_index)

middle_train_dataloader, middle_test_dataloader = _build_loaders(train_middle, test_middle)

ring_train_dataloader, ring_test_dataloader = _build_loaders(train_ring, test_ring)

pinkie_train_dataloader, pinkie_test_dataloader = _build_loaders(train_pinkie, test_pinkie)

In [None]:
def train_loop(dataloader, model, loss_fn, optimizer):
    """Train ``model`` for one pass over ``dataloader`` and return the accuracy.

    Args:
        dataloader: Iterable of ``(X, y)`` batches; ``len(dataloader.dataset)``
            is used as the total sample count.
        model: Module being optimised. It is switched to training mode.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar loss tensor.
        optimizer: Optimizer that updates the parameters of ``model``.

    Returns:
        Fraction of samples whose arg-max prediction matched the label,
        measured on the predictions made during this pass.
    """
    size = len(dataloader.dataset)

    # Run on whatever device the model lives on instead of assuming CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    # Make sure layers such as Dropout behave in training mode, even if an
    # earlier evaluation pass switched the model to eval mode.
    model.train()

    correct_train = 0
    for batch, (X, y) in enumerate(dataloader):
        # Compute the prediction and the loss.
        X = X.to(device)
        y = y.to(device)

        pred = model(X)
        loss = loss_fn(pred, y)

        correct_train += (pred.argmax(1) == y).type(torch.float).sum().item()

        # Backpropagation.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Report progress every third batch.
        if batch % 3 == 0:
            loss_value, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss_value:>7f}  [{current:>5d}/{size:>5d}]")

    correct_train /= size
    print(f"Train Error: \n Accuracy: {(100*correct_train):>0.1f}%\n")
    return correct_train


def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss = 0
    correct_test = 0

    with torch.no_grad():
        for X, y in dataloader:
            X = X.to("cuda")
            y = y.to("cuda")
            pred = model(X)
            #print(pred)
            test_loss += loss_fn(pred, y).item()
            correct_test += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct_test /= size
    print(f"Test Error: \n Accuracy: {(100*correct_test):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    
    return correct_test

In [None]:
class ResNet(nn.Module):
    """Two-branch ResNet-18 classifier for a stacked frame/audio input.

    One pretrained ResNet-18 processes the single-channel frame image and a
    second one processes the three-channel audio image. Each branch ends in a
    256-feature layer; the two feature vectors are concatenated (512 values)
    and classified by a small fully connected head.

    Args:
        nclass: Number of output classes.

    NOTE(review): ``forward`` expects input shaped (batch, 4, H, W) with the
    frame image in channel 0 and the audio image in channels 1-3 — confirm the
    frame images really are single-channel.
    """

    def __init__(self, nclass):
        super(ResNet, self).__init__()

        self.resnet_image = models.resnet18(pretrained=True)
        self.resnet_audio = models.resnet18(pretrained=True)

        # The frame branch receives one channel, so its stem convolution is
        # replaced by a freshly initialised one-channel convolution.
        self.resnet_image.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)

        num_ftrs = self.resnet_image.fc.in_features

        # Replace only the final classification layer of each backbone.
        # Assigning to the backbone attribute itself would discard the whole
        # network and leave a single Linear layer.
        self.resnet_image.fc = nn.Linear(num_ftrs, 256)
        self.resnet_audio.fc = nn.Linear(num_ftrs, 256)

        # Classification head over the concatenated 256 + 256 features.
        self.FClayer = nn.Sequential(nn.Linear(512, 256), nn.ReLU(), nn.Dropout(p=0.3),
            nn.Linear(256, 128), nn.ReLU(), nn.Dropout(p=0.3),
            nn.Linear(128, nclass))

    def forward(self, x):
        """Return class logits of shape (batch, nclass) for a batch ``x``."""
        # Split along the channel axis (dim 1), not the batch axis; the 0:1
        # slice keeps the channel dimension the convolution needs.
        img = self.resnet_image(x[:, 0:1])
        audio = self.resnet_audio(x[:, 1:4])

        # Join the two feature vectors per sample: (batch, 256) x 2 -> (batch, 512).
        x = torch.cat((img, audio), dim=1)
        x = self.FClayer(x)
        return x

In [None]:
# Build the model on the device detected earlier instead of assuming CUDA.
# NOTE(review): the variable is named `resnet_index`, but the model is sized
# with ring_nclass and trained on the ring-finger loaders — confirm which
# finger this run is meant for.
resnet_index = ResNet(ring_nclass).to(device)

epochs = 100
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(resnet_index.parameters(), lr=1e-4)

# Per-epoch accuracy history.
acc_train_list = []
acc_test_list = []
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    acc_train = train_loop(ring_train_dataloader, resnet_index, loss_fn, optimizer)
    acc_test = test_loop(ring_test_dataloader, resnet_index, loss_fn)
    acc_train_list.append(acc_train)
    acc_test_list.append(acc_test)

print("Done!")


In [None]:
class Guitar_Dataset(Dataset):
    """Dataset that pairs a tablature-frame image with its audio image.

    Every row of ``df`` must provide three columns:

    * ``path`` - file path of the audio image; only channels 0-2 are used.
    * ``file_name`` - file name of the frame image. The text before the first
      ``_`` names the sub-directory of ``./2canny_crop_per_tablature_frames``
      that contains the file.
    * ``target`` - the label.

    NOTE(review): Guitar_Dataset is also defined in an earlier cell of this
    notebook and this later definition shadows it; consider keeping only one.

    Args:
        df: DataFrame with the columns described above.
        transform: Optional callable applied to the frame image and then to
            the audio image (each converted to a PIL image first).
        target_transform: Optional callable applied to the label.
    """

    def __init__(self, df, transform=None, target_transform=None):
        self.df = df
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of rows (samples) in the backing DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(data, label)`` for the row at position ``idx``.

        ``data`` is the frame image followed by the audio image, concatenated
        along the first (channel) dimension.
        """
        # Look the row up once instead of once per column.
        row = self.df.iloc[idx]
        file_name = row["file_name"]
        name = file_name.split("_")[0]

        audio = read_image(row["path"])[0:3]
        image = read_image("./2canny_crop_per_tablature_frames/%s/%s" % (name, file_name))

        label = row["target"]

        if self.transform:
            # The transform pipeline receives PIL images, as before.
            image = self.transform(to_pil_image(image))
            audio = self.transform(to_pil_image(audio))
        # Without a transform the raw tensors from read_image are kept, so the
        # concatenation below operates on tensors rather than PIL images.
        # It still requires both images to share the same height and width.

        if self.target_transform:
            label = self.target_transform(label)

        data = torch.cat((image, audio))

        return data, label