In [1]:

import matplotlib.pyplot as plt
import numpy as np
import helper

import torch.nn as nn
import torchvision.models
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets, models
import torchvision.utils
import torch
import pandas as pd
from torchinfo import summary
from PIL import Image
from torchvision.transforms import ToTensor
from glob import glob
from torch.utils.data import Dataset, DataLoader, random_split
from copy import copy
from collections import defaultdict
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import time
from sklearn.metrics import classification_report
from tqdm import tqdm
import math
from torcheval.metrics import BinaryAccuracy
import os
import torchmetrics
import timm
import time
import datetime
import random
# Notebook-wide configuration.
# Use the first CUDA device when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
batch_size = 4      # bags per mini-batch handed to the DataLoader
image_count = 50    # frames kept per bag (one bag per frame folder)
img_size = 256      # every frame is resized to img_size x img_size
tf = ToTensor()     # converter applied to each PIL frame before it is stored

In [2]:
class CustomDataset(Dataset):
    """In-memory dataset pairing each stored sample with its label.

    Args:
        id: Indexable collection of sample identifiers. It is kept on
            ``self.id`` for inspection only; ``__getitem__`` does not return it.
            (The parameter name shadows the builtin ``id`` but is part of the
            constructor signature, so it is left unchanged.)
        image_list: Indexable collection of samples; exposed as ``self.img_path``.
        label_list: Indexable collection of labels. Its length defines the
            length of the dataset.
    """

    def __init__(self, id,image_list, label_list):
        self.img_path = image_list
        self.label = label_list
        self.id = id

    def __len__(self):
        """Return the number of labelled samples."""
        return len(self.label)

    def __getitem__(self, idx):
        """Return the ``(sample, label)`` pair stored at position ``idx``."""
        # The identifier is deliberately not looked up here: the previous code
        # indexed self.id[idx] into an unused local, which added work and made
        # item access fail whenever the id collection was shorter than the labels.
        return self.img_path[idx], self.label[idx]


def _load_frame(frame_path):
    """Read one frame, force RGB, resize to img_size x img_size and return ``1 - tensor``."""
    # The context manager closes the file handle as soon as the pixels are read
    # (the previous code left thousands of handles for the garbage collector).
    with Image.open(frame_path) as frame:
        # convert('RGB') guarantees the 3 channels that train_image_tensor expects.
        return 1 - tf(frame.convert('RGB').resize((img_size, img_size)))


def _select_frames(frame_paths, wanted):
    """Pick exactly ``wanted`` paths from ``frame_paths``.

    * More frames than needed: a random contiguous window of ``wanted`` frames.
    * Otherwise: every frame once, then cycle from the start until ``wanted``
      entries are collected.

    Raises:
        FileNotFoundError: if ``frame_paths`` is empty.
    """
    if not frame_paths:
        raise FileNotFoundError('no .jpg frames found for this sample')
    if len(frame_paths) > wanted:
        # torch.randint's upper bound is exclusive, hence the +1 so that the
        # last valid window start can be drawn as well.
        first = int(torch.randint(low=0, high=len(frame_paths) - wanted + 1, size=(1,)))
        return frame_paths[first:first + wanted]
    # Cycling with a modulo keeps the old "pad with the leading frames"
    # behaviour but no longer runs off the end of short folders.
    return [frame_paths[k % len(frame_paths)] for k in range(wanted)]


train_data = pd.read_csv('../../data/train.csv', encoding='cp949')
file_path = '../../data/frame/'

# One frame folder per training row: the folder name is the part of
# 'FileName' that precedes the first underscore.
train_image_list = []
for file_name in train_data['FileName']:
    sample_id = file_name.split('_', 1)[0]
    train_image_list.append(file_path + sample_id)

label_data = pd.read_csv('../../data/label_data.csv', encoding='cp949')
train_label_list = []
train_id_list = []
train_image_tensor = torch.empty((len(train_image_list), image_count, 3, img_size, img_size))
for i in tqdm(range(len(train_image_list))):
    folder_name = os.path.basename(train_image_list[i])
    # The folder name is parsed as <'일련번호'><single-digit '구분값'>.
    dst_label = label_data.loc[label_data['일련번호'] == int(folder_name[:-1])]
    dst_label = dst_label.loc[dst_label['구분값'] == int(folder_name[-1])].reset_index()
    label = int(dst_label.loc[0]['OTE 원인'])
    train_id_list.append(folder_name)
    train_label_list.append(label - 1)  # shift labels down by one (0-based classes)

    # glob returns files in arbitrary order; sorting makes runs reproducible and
    # makes the sampled window contiguous in name order.
    # NOTE(review): this is a lexicographic sort -- if frame names are not
    # zero-padded a natural-sort key is needed for true temporal order.
    frame_paths = sorted(glob(train_image_list[i] + '/*.jpg'))
    for slot, frame_path in enumerate(_select_frames(frame_paths, image_count)):
        train_image_tensor[i, slot] = _load_frame(frame_path)

# num_classes is pinned to 3 (the width of the model's output layer) so the
# one-hot targets keep the right shape even if a class is missing from the data.
train_dataset = CustomDataset(
    train_id_list,
    train_image_tensor,
    F.one_hot(torch.tensor(train_label_list).to(torch.int64), num_classes=3),
)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)


100%|██████████| 2170/2170 [04:08<00:00,  8.73it/s]


In [3]:
class FeatureExtractor(nn.Module):
    """Feature extractor block.

    Builds a timm ``efficientnet_b2`` (pretrained weights requested) and keeps
    all of its child modules except the last one, chained in an
    ``nn.Sequential`` stored as ``self.feature_ex``.
    """

    def __init__(self):
        super().__init__()
        backbone = timm.create_model('efficientnet_b2', pretrained=True)
        # Drop the final child module; everything before it is kept in order.
        kept_children = list(backbone.children())[:-1]
        self.feature_ex = nn.Sequential(*kept_children)

    def forward(self, inputs):
        """Run ``inputs`` through the truncated backbone and return its output."""
        return self.feature_ex(inputs)
    
class AttentionMILModel(nn.Module):
    """Attention-based multiple-instance classifier over a bag of image tiles.

    Every tile of a bag is embedded by ``feature_extractor_scale1``, the tile
    embeddings are combined into one bag embedding with learned attention
    weights, and a linear layer maps that embedding to class logits.

    Args:
        num_classes: Number of output logits.
        image_feature_dim: Size of the per-tile feature vector produced by the
            feature extractor (after flattening).
        feature_extractor_scale1: Module mapping a ``(N, C, H, W)`` batch of
            tiles to per-tile features with ``image_feature_dim`` values each.
    """

    def __init__(self, num_classes, image_feature_dim,feature_extractor_scale1: nn.Module):
        super().__init__()
        self.num_classes = num_classes
        self.image_feature_dim = image_feature_dim

        # Backbone used to embed individual tiles.
        self.feature_extractor = feature_extractor_scale1

        # Attention mechanism: one unnormalized score per tile.
        self.attention = nn.Sequential(
            nn.Linear(image_feature_dim, 128),
            nn.Tanh(),
            nn.Linear(128, 1)
        )

        # Classification layer applied to the attention-pooled bag embedding.
        self.classification_layer = nn.Linear(image_feature_dim, num_classes)
        self.dropout = torch.nn.Dropout(0.2)

    def forward(self, inputs):
        """Return raw class logits of shape ``(batch_size, num_classes)``.

        Args:
            inputs: Tensor of shape ``(batch_size, num_tiles, channels, height, width)``.
        """
        batch_size, num_tiles, channels, height, width = inputs.size()

        # Fold bags and tiles into one batch axis. reshape (unlike view) also
        # accepts inputs that are not contiguous across the first two dims.
        tiles = inputs.reshape(-1, channels, height, width)

        # Per-tile features: (batch_size * num_tiles, image_feature_dim[, 1, 1])
        features = self.feature_extractor(tiles)

        # Back to one row of tile features per bag:
        # (batch_size, num_tiles, image_feature_dim)
        features = features.reshape(batch_size, num_tiles, -1)

        # Attention scores normalized over the tiles of each bag:
        # (batch_size, num_tiles, 1)
        attention_weights = F.softmax(self.attention(features), dim=1)

        # Attention-weighted sum of tile features: (batch_size, image_feature_dim)
        attended_features = torch.sum(features * attention_weights, dim=1)
        attended_features = self.dropout(attended_features)
        attended_features = F.relu(attended_features)

        # Class logits: (batch_size, num_classes)
        logits = self.classification_layer(attended_features)

        return logits
# Assemble the network: backbone + 3-class attention-MIL head. 1408 is the
# per-tile feature width passed to the head (it matches the extractor output
# shape shown in the recorded summary).
Feature_Extractor = FeatureExtractor()
model = AttentionMILModel(
    num_classes=3,
    image_feature_dim=1408,
    feature_extractor_scale1=Feature_Extractor,
).to(device)

# Metric and optimizer used by the training cell.
accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=3).to(device)
optimizer = optim.Adam(model.parameters(), lr=2e-4)

# Last expression of the cell, so the layer summary is displayed.
summary(model, input_size=(batch_size, image_count, 3, img_size, img_size))


Layer (type:depth-idx)                                  Output Shape              Param #
AttentionMILModel                                       [4, 3]                    --
├─FeatureExtractor: 1-1                                 [200, 1408]               --
│    └─Sequential: 2-1                                  [200, 1408]               --
│    │    └─Conv2d: 3-1                                 [200, 32, 128, 128]       864
│    │    └─BatchNormAct2d: 3-2                         [200, 32, 128, 128]       64
│    │    └─Sequential: 3-3                             [200, 352, 8, 8]          7,201,634
│    │    └─Conv2d: 3-4                                 [200, 1408, 8, 8]         495,616
│    │    └─BatchNormAct2d: 3-5                         [200, 1408, 8, 8]         2,816
│    │    └─SelectAdaptivePool2d: 3-6                   [200, 1408]               --
├─Sequential: 1-2                                       [4, 50, 1]                --
│    └─Linear: 2-2                          

In [4]:
# Train the attention-MIL model on train_dataloader and save the final weights.
start = time.time()
d = datetime.datetime.now()
now_time = f"{d.year}-{d.month}-{d.day} {d.hour}:{d.minute}:{d.second}"
print(f'[deeplearning Start]')
print(f'deeplearning Start Time : {now_time}')

# Bookkeeping. MIN_loss, MIN_acc, sig and the val_* lists are not used in this
# cell; they are kept because other cells may read them (presumably a
# validation loop -- TODO confirm, then remove if dead).
MIN_loss=5000
train_loss_list=[]
val_loss_list=[]
train_acc_list=[]
sig=nn.Sigmoid()
val_acc_list=[]
MIN_acc=0

num_epochs = 50  # single source of truth for the loop and the progress text

for epoch in range(num_epochs):
    train=tqdm(train_dataloader)
    count=0
    running_loss = 0.0
    acc_loss=0
    model.train()
    for x, y in train:
        y = y.to(device).float()  # one-hot targets as float class probabilities
        count+=1
        x=x.to(device).float()
        optimizer.zero_grad()  # reset gradients accumulated by the previous step
        predict = model(x)  # raw logits, already on `device`
        # F.cross_entropy applies log-softmax internally, so it must receive the
        # raw logits. Passing softmax output (as before) applies softmax twice,
        # which flattens the gradients and puts a floor under the loss.
        cost = F.cross_entropy(predict, y)
        # argmax of the logits equals argmax of their softmax.
        acc=accuracy(predict.argmax(dim=1),y.argmax(dim=1))
        cost.backward()  # back-propagate the loss
        optimizer.step()
        running_loss += cost.item()
        acc_loss+=acc
        # `count` is already 1-based here, so it is shown as-is.
        train.set_description(f"epoch: {epoch+1}/{num_epochs} Step: {count} loss : {running_loss/count:.4f} accuracy: {acc_loss/count:.4f}")
    train_loss_list.append((running_loss/count))
    train_acc_list.append((acc_loss/count).cpu().detach().numpy())

torch.save(model.state_dict(), '../../model/attention_eff50_MIL.pt')
end = time.time()
d = datetime.datetime.now()
now_time = f"{d.year}-{d.month}-{d.day} {d.hour}:{d.minute}:{d.second}"
# Elapsed seconds are end - start (the reverse order printed a negative value).
print(f'deeplearning Time : {now_time} Time taken : {end-start}')
print(f'[deeplearning End]')


[deeplearning Start]
deeplearning Start Time : 2023-12-19 11:49:20


epoch: 1/50 Step: 70 loss : 1.0096 accuracy: 0.5362:  13%|█▎        | 69/542 [00:30<03:27,  2.28it/s]


KeyboardInterrupt: 

In [10]:
# Prints a hard-coded, multi-line string; nothing is computed in this cell.
# NOTE(review): the text looks like a train/validation log pasted from an
# earlier run (the training cell visible in this notebook has no validation
# loop) -- confirm its provenance before relying on these numbers.
print('[deeplearning Start]\n \
deeplearning Start Time : 2023-12-19 11:49:20\n \
epoch: 1/50 Step: 541 loss : 1.0194 accuracy: 0.4986: 100%|██████████| 540/540 [03:54<00:00,  2.30it/s]\n \
Validation epoch: 1/50 Step: 68 loss : 1.0451  accuracy: 0.4888: 100%|██████████| 67/67 [00:09<00:00,  7.20it/s]\n \
epoch: 2/50 Step: 541 loss : 0.9950 accuracy: 0.5347: 100%|██████████| 540/540 [03:55<00:00,  2.29it/s]\n \
Validation epoch: 2/50 Step: 68 loss : 0.9870  accuracy: 0.5373: 100%|██████████| 67/67 [00:09<00:00,  7.12it/s]\n \
epoch: 3/50 Step: 541 loss : 0.9734 accuracy: 0.5519: 100%|██████████| 540/540 [03:55<00:00,  2.29it/s]\n \
Validation epoch: 3/50 Step: 68 loss : 1.0056  accuracy: 0.5112: 100%|██████████| 67/67 [00:09<00:00,  7.07it/s]\n \
epoch: 4/50 Step: 541 loss : 0.9608 accuracy: 0.5644: 100%|██████████| 540/540 [03:56<00:00,  2.29it/s]\n \
Validation epoch: 4/50 Step: 68 loss : 0.9889  accuracy: 0.5410: 100%|██████████| 67/67 [00:09<00:00,  7.09it/s]\n \
epoch: 5/50 Step: 541 loss : 0.9481 accuracy: 0.5870: 100%|██████████| 540/540 [03:56<00:00,  2.29it/s]\n \
Validation epoch: 5/50 Step: 68 loss : 1.0009  accuracy: 0.5224: 100%|██████████| 67/67 [00:09<00:00,  7.14it/s]\n \
epoch: 6/50 Step: 541 loss : 0.9228 accuracy: 0.6153: 100%|██████████| 540/540 [03:55<00:00,  2.29it/s]\n \
Validation epoch: 6/50 Step: 68 loss : 0.9926  accuracy: 0.5261: 100%|██████████| 67/67 [00:09<00:00,  7.04it/s]\n \
epoch: 7/50 Step: 541 loss : 0.9114 accuracy: 0.6250: 100%|██████████| 540/540 [03:55<00:00,  2.29it/s]\n \
Validation epoch: 7/50 Step: 68 loss : 0.9841  accuracy: 0.5597: 100%|██████████| 67/67 [00:09<00:00,  7.21it/s]\n \
epoch: 8/50 Step: 541 loss : 0.8868 accuracy: 0.6500: 100%|██████████| 540/540 [03:55<00:00,  2.29it/s]\n \
Validation epoch: 8/50 Step: 68 loss : 0.9845  accuracy: 0.5448: 100%|██████████| 67/67 [00:09<00:00,  7.30it/s]\n \
epoch: 9/50 Step: 541 loss : 0.8699 accuracy: 0.6750: 100%|██████████| 540/540 [03:57<00:00,  2.28it/s]\n \
Validation epoch: 9/50 Step: 68 loss : 0.9889  accuracy: 0.5485: 100%|██████████| 67/67 [00:09<00:00,  6.87it/s]\n \
epoch: 10/50 Step: 541 loss : 0.8592 accuracy: 0.6801: 100%|██████████| 540/540 [03:55<00:00,  2.29it/s]\n \
Validation epoch: 10/50 Step: 68 loss : 0.9609  accuracy: 0.5634: 100%|██████████| 67/67 [00:09<00:00,  6.91it/s]\n \
epoch: 11/50 Step: 541 loss : 0.8371 accuracy: 0.7088: 100%|██████████| 540/540 [03:55<00:00,  2.29it/s]\n \
Validation epoch: 11/50 Step: 68 loss : 0.9256  accuracy: 0.6157: 100%|██████████| 67/67 [00:09<00:00,  6.97it/s]\n \
epoch: 12/50 Step: 541 loss : 0.8125 accuracy: 0.7315: 100%|██████████| 540/540 [03:55<00:00,  2.29it/s]\n \
Validation epoch: 12/50 Step: 68 loss : 0.9574  accuracy: 0.6314: 100%|██████████| 67/67 [00:09<00:00,  7.20it/s]\n \
epoch: 13/50 Step: 541 loss : 0.7927 accuracy: 0.7532: 100%|██████████| 540/540 [03:54<00:00,  2.30it/s]\n \
Validation epoch: 13/50 Step: 68 loss : 0.9104  accuracy: 0.6811: 100%|██████████| 67/67 [00:09<00:00,  7.23it/s]\n \
epoch: 14/50 Step: 541 loss : 0.7842 accuracy: 0.7634: 100%|██████████| 540/540 [03:54<00:00,  2.30it/s]\n \
Validation epoch: 14/50 Step: 68 loss : 0.8651  accuracy: 0.7311: 100%|██████████| 67/67 [00:09<00:00,  7.14it/s]\n \
epoch: 15/50 Step: 541 loss : 0.7696 accuracy: 0.7759: 100%|██████████| 540/540 [03:55<00:00,  2.30it/s]\n \
Validation epoch: 15/50 Step: 68 loss : 0.8251  accuracy: 0.7631: 100%|██████████| 67/67 [00:09<00:00,  7.16it/s]\n \
epoch: 16/50 Step: 541 loss : 0.7455 accuracy: 0.8046: 100%|██████████| 540/540 [03:54<00:00,  2.30it/s]\n \
Validation epoch: 16/50 Step: 68 loss : 0.9970  accuracy: 0.5485: 100%|██████████| 67/67 [00:09<00:00,  7.24it/s]\n \
epoch: 17/50 Step: 541 loss : 0.7392 accuracy: 0.8088: 100%|██████████| 540/540 [03:54<00:00,  2.30it/s]\n \
Validation epoch: 17/50 Step: 68 loss : 0.9877  accuracy: 0.5522: 100%|██████████| 67/67 [00:09<00:00,  7.24it/s]\n \
epoch: 18/50 Step: 541 loss : 0.7401 accuracy: 0.8065: 100%|██████████| 540/540 [03:54<00:00,  2.30it/s]\n \
Validation epoch: 18/50 Step: 68 loss : 1.0150  accuracy: 0.5261: 100%|██████████| 67/67 [00:09<00:00,  7.26it/s]\n \
epoch: 19/50 Step: 541 loss : 0.7204 accuracy: 0.8245: 100%|██████████| 540/540 [03:53<00:00,  2.31it/s]\n \
Validation epoch: 19/50 Step: 68 loss : 0.9656  accuracy: 0.5709: 100%|██████████| 67/67 [00:09<00:00,  7.40it/s]\n \
epoch: 20/50 Step: 541 loss : 0.7171 accuracy: 0.8347: 100%|██████████| 540/540 [03:53<00:00,  2.31it/s]\n \
Validation epoch: 20/50 Step: 68 loss : 0.9911  accuracy: 0.5485: 100%|██████████| 67/67 [00:09<00:00,  6.82it/s]\n \
epoch: 21/50 Step: 541 loss : 0.7015 accuracy: 0.8444: 100%|██████████| 540/540 [03:54<00:00,  2.30it/s]\n \
Validation epoch: 21/50 Step: 68 loss : 0.9824  accuracy: 0.5522: 100%|██████████| 67/67 [00:09<00:00,  6.91it/s]\n \
epoch: 22/50 Step: 541 loss : 0.7131 accuracy: 0.8338: 100%|██████████| 540/540 [03:53<00:00,  2.32it/s]\n \
Validation epoch: 22/50 Step: 68 loss : 0.9565  accuracy: 0.5858: 100%|██████████| 67/67 [00:09<00:00,  7.03it/s]\n \
epoch: 23/50 Step: 541 loss : 0.6749 accuracy: 0.8750: 100%|██████████| 540/540 [03:55<00:00,  2.30it/s]\n \
Validation epoch: 23/50 Step: 68 loss : 0.9797  accuracy: 0.5560: 100%|██████████| 67/67 [00:08<00:00,  7.54it/s]\n \
epoch: 24/50 Step: 541 loss : 0.6771 accuracy: 0.8727: 100%|██████████| 540/540 [03:53<00:00,  2.32it/s]\n \
Validation epoch: 24/50 Step: 68 loss : 1.0148  accuracy: 0.5299: 100%|██████████| 67/67 [00:09<00:00,  6.96it/s]\n \
epoch: 25/50 Step: 541 loss : 0.6908 accuracy: 0.8556: 100%|██████████| 540/540 [03:56<00:00,  2.28it/s]\n \
Validation epoch: 25/50 Step: 68 loss : 0.9950  accuracy: 0.5410: 100%|██████████| 67/67 [00:09<00:00,  7.02it/s]\n \
epoch: 26/50 Step: 541 loss : 0.6729 accuracy: 0.8755: 100%|██████████| 540/540 [03:57<00:00,  2.27it/s]\n \
Validation epoch: 26/50 Step: 68 loss : 0.9782  accuracy: 0.5634: 100%|██████████| 67/67 [00:09<00:00,  6.93it/s]\n \
epoch: 27/50 Step: 541 loss : 0.6877 accuracy: 0.8620: 100%|██████████| 540/540 [03:55<00:00,  2.29it/s]\n \
Validation epoch: 27/50 Step: 68 loss : 1.0036  accuracy: 0.5410: 100%|██████████| 67/67 [00:09<00:00,  6.73it/s]\n \
epoch: 28/50 Step: 541 loss : 0.6672 accuracy: 0.8824: 100%|██████████| 540/540 [03:58<00:00,  2.26it/s]\n \
Validation epoch: 28/50 Step: 68 loss : 0.9632  accuracy: 0.5784: 100%|██████████| 67/67 [00:09<00:00,  6.88it/s]\n \
epoch: 29/50 Step: 541 loss : 0.6572 accuracy: 0.8944: 100%|██████████| 540/540 [03:53<00:00,  2.31it/s]\n \
Validation epoch: 29/50 Step: 68 loss : 0.9418  accuracy: 0.6045: 100%|██████████| 67/67 [00:09<00:00,  7.44it/s]\n \
epoch: 30/50 Step: 541 loss : 0.6500 accuracy: 0.9019: 100%|██████████| 540/540 [03:53<00:00,  2.31it/s]\n \
Validation epoch: 30/50 Step: 68 loss : 0.9713  accuracy: 0.5597: 100%|██████████| 67/67 [00:09<00:00,  7.38it/s]\n \
epoch: 31/50 Step: 541 loss : 0.6631 accuracy: 0.8884: 100%|██████████| 540/540 [03:53<00:00,  2.32it/s]\n \
Validation epoch: 31/50 Step: 68 loss : 0.9853  accuracy: 0.5560: 100%|██████████| 67/67 [00:09<00:00,  6.77it/s]\n \
epoch: 32/50 Step: 541 loss : 0.6592 accuracy: 0.8903: 100%|██████████| 540/540 [03:57<00:00,  2.27it/s]\n \
Validation epoch: 32/50 Step: 68 loss : 0.9734  accuracy: 0.5746: 100%|██████████| 67/67 [00:09<00:00,  7.23it/s]\n \
epoch: 33/50 Step: 541 loss : 0.6482 accuracy: 0.9019: 100%|██████████| 540/540 [03:56<00:00,  2.28it/s]\n \
Validation epoch: 33/50 Step: 68 loss : 0.9545  accuracy: 0.5858: 100%|██████████| 67/67 [00:09<00:00,  7.24it/s]\n \
epoch: 34/50 Step: 541 loss : 0.6462 accuracy: 0.9032: 100%|██████████| 540/540 [03:58<00:00,  2.26it/s]\n \
Validation epoch: 34/50 Step: 68 loss : 1.0065  accuracy: 0.5299: 100%|██████████| 67/67 [00:09<00:00,  7.02it/s]\n \
epoch: 35/50 Step: 541 loss : 0.6382 accuracy: 0.9102: 100%|██████████| 540/540 [03:58<00:00,  2.26it/s]\n \
Validation epoch: 35/50 Step: 68 loss : 0.9724  accuracy: 0.5672: 100%|██████████| 67/67 [00:09<00:00,  7.29it/s]\n \
epoch: 36/50 Step: 541 loss : 0.6667 accuracy: 0.8829: 100%|██████████| 540/540 [03:57<00:00,  2.28it/s]\n \
Validation epoch: 36/50 Step: 68 loss : 1.0007  accuracy: 0.5336: 100%|██████████| 67/67 [00:09<00:00,  6.89it/s]\n \
epoch: 37/50 Step: 541 loss : 0.6635 accuracy: 0.8884: 100%|██████████| 540/540 [03:56<00:00,  2.28it/s]\n \
Validation epoch: 37/50 Step: 68 loss : 1.0118  accuracy: 0.5299: 100%|██████████| 67/67 [00:10<00:00,  6.60it/s]\n \
epoch: 38/50 Step: 541 loss : 0.6532 accuracy: 0.8977: 100%|██████████| 540/540 [03:56<00:00,  2.29it/s]\n \
Validation epoch: 38/50 Step: 68 loss : 0.9783  accuracy: 0.5709: 100%|██████████| 67/67 [00:09<00:00,  6.79it/s]\n \
epoch: 39/50 Step: 541 loss : 0.6474 accuracy: 0.9019: 100%|██████████| 540/540 [03:56<00:00,  2.28it/s]\n \
Validation epoch: 39/50 Step: 68 loss : 1.0232  accuracy: 0.5187: 100%|██████████| 67/67 [00:09<00:00,  6.90it/s]\n \
epoch: 40/50 Step: 541 loss : 0.6399 accuracy: 0.9106: 100%|██████████| 540/540 [03:56<00:00,  2.28it/s]\n \
Validation epoch: 40/50 Step: 68 loss : 0.9822  accuracy: 0.5522: 100%|██████████| 67/67 [00:09<00:00,  6.80it/s]\n \
epoch: 41/50 Step: 541 loss : 0.6173 accuracy: 0.9329: 100%|██████████| 540/540 [03:56<00:00,  2.28it/s]\n \
Validation epoch: 41/50 Step: 68 loss : 0.9794  accuracy: 0.5672: 100%|██████████| 67/67 [00:09<00:00,  6.88it/s]\n \
epoch: 42/50 Step: 541 loss : 0.6143 accuracy: 0.9384: 100%|██████████| 540/540 [03:55<00:00,  2.29it/s]\n \
Validation epoch: 42/50 Step: 68 loss : 1.0212  accuracy: 0.5075: 100%|██████████| 67/67 [00:08<00:00,  7.49it/s]\n \
epoch: 43/50 Step: 541 loss : 0.6568 accuracy: 0.8921: 100%|██████████| 540/540 [03:53<00:00,  2.31it/s]\n \
Validation epoch: 43/50 Step: 68 loss : 0.9951  accuracy: 0.5448: 100%|██████████| 67/67 [00:09<00:00,  6.80it/s]\n \
epoch: 44/50 Step: 541 loss : 0.6470 accuracy: 0.9032: 100%|██████████| 540/540 [03:58<00:00,  2.27it/s]\n \
Validation epoch: 44/50 Step: 68 loss : 1.0129  accuracy: 0.5224: 100%|██████████| 67/67 [00:09<00:00,  6.91it/s]\n \
epoch: 45/50 Step: 541 loss : 0.6378 accuracy: 0.9111: 100%|██████████| 540/540 [03:57<00:00,  2.28it/s]\n \
Validation epoch: 45/50 Step: 68 loss : 1.0036  accuracy: 0.5448: 100%|██████████| 67/67 [00:09<00:00,  6.85it/s]\n \
epoch: 46/50 Step: 541 loss : 0.6272 accuracy: 0.9231: 100%|██████████| 540/540 [03:57<00:00,  2.27it/s]\n \
Validation epoch: 46/50 Step: 68 loss : 1.0143  accuracy: 0.5224: 100%|██████████| 67/67 [00:09<00:00,  6.86it/s]\n \
epoch: 47/50 Step: 541 loss : 0.6255 accuracy: 0.9250: 100%|██████████| 540/540 [03:56<00:00,  2.28it/s]\n \
Validation epoch: 47/50 Step: 68 loss : 0.9707  accuracy: 0.5746: 100%|██████████| 67/67 [00:09<00:00,  6.86it/s]\n \
epoch: 48/50 Step: 541 loss : 0.6298 accuracy: 0.9194: 100%|██████████| 540/540 [03:57<00:00,  2.28it/s]\n \
Validation epoch: 48/50 Step: 68 loss : 1.0366  accuracy: 0.5000: 100%|██████████| 67/67 [00:09<00:00,  6.78it/s]\n \
epoch: 49/50 Step: 541 loss : 0.6384 accuracy: 0.9106: 100%|██████████| 540/540 [03:57<00:00,  2.27it/s]\n \
Validation epoch: 49/50 Step: 68 loss : 1.0164  accuracy: 0.5336: 100%|██████████| 67/67 [00:09<00:00,  6.88it/s]\n \
epoch: 50/50 Step: 541 loss : 0.6367 accuracy: 0.9120: 100%|██████████| 540/540 [03:56<00:00,  2.29it/s]\n \
Validation epoch: 50/50 Step: 68 loss : 1.0553  accuracy: 0.4776: 100%|██████████| 67/67 [00:09<00:00,  6.78it/s]\n \
deeplearning Time : 2023-12-19 15:14:38 Time taken : 12276.459565162659\n \
[deeplearning End]')

[deeplearning Start]
 deeplearning Start Time : 2023-12-19 11:49:20
 epoch: 1/50 Step: 541 loss : 1.0194 accuracy: 0.4986: 100%|██████████| 540/540 [03:54<00:00,  2.30it/s]
 Validation epoch: 1/50 Step: 68 loss : 1.0451  accuracy: 0.4888: 100%|██████████| 67/67 [00:09<00:00,  7.20it/s]
 epoch: 2/50 Step: 541 loss : 0.9950 accuracy: 0.5347: 100%|██████████| 540/540 [03:55<00:00,  2.29it/s]
 Validation epoch: 2/50 Step: 68 loss : 0.9870  accuracy: 0.5373: 100%|██████████| 67/67 [00:09<00:00,  7.12it/s]
 epoch: 3/50 Step: 541 loss : 0.9734 accuracy: 0.5519: 100%|██████████| 540/540 [03:55<00:00,  2.29it/s]
 Validation epoch: 3/50 Step: 68 loss : 1.0056  accuracy: 0.5112: 100%|██████████| 67/67 [00:09<00:00,  7.07it/s]
 epoch: 4/50 Step: 541 loss : 0.9608 accuracy: 0.5644: 100%|██████████| 540/540 [03:56<00:00,  2.29it/s]
 Validation epoch: 4/50 Step: 68 loss : 0.9889  accuracy: 0.5410: 100%|██████████| 67/67 [00:09<00:00,  7.09it/s]
 epoch: 5/50 Step: 541 loss : 0.9481 accuracy: 0.5870: 1

In [33]:
# Scratch check: sigmoid(1) displays as tensor(0.7311). The trailing "." of the
# previous version was a syntax error; torch.sigmoid is the non-deprecated
# spelling of F.sigmoid, and a float input avoids relying on int promotion.
torch.sigmoid(torch.tensor(1.0))

tensor(0.7311)

In [None]:
# Confusion matrix and macro F1 for the collected predictions.
# NOTE(review): total_y and total_prob are not defined in any cell visible in
# this notebook -- they must come from an evaluation cell that is missing here.
from sklearn.metrics import confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay
from sklearn.metrics import f1_score

classes = ['Oropharynx','Tonguebase','Epiglottis']

# Collapse the per-class rows to class indices once and reuse them below.
y_true = total_y.cpu().argmax(axis=1)
y_pred = total_prob.cpu().argmax(axis=1)

# Passing `labels` pins the matrix to one row/column per entry of `classes`,
# so it still lines up with display_labels when a class never occurs.
cm = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))

cm_display = ConfusionMatrixDisplay(cm,
                              display_labels=classes).plot()
f1 = f1_score(y_true, y_pred, average='macro')

print(f'total f1-score= {f1}') 

In [None]:
def getIntersection_Method2(a, b):
    """Return the elements of ``a`` whose value also occurs in ``b``.

    The result keeps the order (and any duplicates) of ``a``. An empty ``a``
    or ``b`` yields an empty tensor.

    Args:
        a: Tensor whose elements are filtered.
        b: Tensor of values to keep.

    Returns:
        1-D tensor with the entries of ``a`` that are present in ``b``.
    """
    # torch.isin builds the membership mask in one vectorized call, replacing
    # the Python loop that OR-ed one `a == elem` comparison per element of b.
    return a[torch.isin(a, b)]

def dfaa(index,name_list,label_clip,path_list):
    """Append the 'File Name' of every selected sample to ``name_list``.

    For each position in ``index`` the sample id is read from
    ``path_list[position][0]``. Its last character is matched (as an int)
    against the 'wake' column of ``label_clip`` and the remaining characters
    against 'Serial Number'; the 'File Name' of the first matching row is
    appended.

    Args:
        index: Sequence of positions into ``path_list``.
        name_list: List that receives the file names (modified in place).
        label_clip: DataFrame with 'wake', 'Serial Number' and 'File Name' columns.
        path_list: Sequence whose items carry the sample id as first element.

    Returns:
        The same ``name_list`` object, after the appends.
    """
    for k in range(len(index)):
        sample_id = path_list[index[k]][0]
        wake_rows = label_clip.loc[label_clip['wake'] == int(sample_id[-1])]
        serial_rows = wake_rows.loc[wake_rows['Serial Number'] == int(sample_id[:-1])]
        first_match = serial_rows.reset_index().loc[0]
        name_list.append(first_match['File Name'])
    return name_list
    
# Collect file names per (true class, predicted class) cell, each capped at a
# fixed number of samples.
# NOTE(review): total_y, total_prob and path_list are not defined in any cell
# visible here; the caps (76, 18, 3, ...) are hard-coded -- presumably the cell
# counts of one particular confusion matrix, confirm before reuse.
y_label = total_y.cpu().argmax(axis=1)
prob_label = total_prob.cpu().argmax(axis=1)

label_clip = pd.read_csv('../../data/label.csv', encoding='cp949')
file_name_list = []

# Sample positions grouped by true class and by predicted class (0, 1, 2).
true_positions = [torch.where(y_label == cls)[0] for cls in range(3)]
pred_positions = [torch.where(prob_label == cls)[0] for cls in range(3)]

# index_<true>_<pred>, 1-based class numbers in the variable names.
index_1_1 = getIntersection_Method2(true_positions[0], pred_positions[0])[:76]
index_1_2 = getIntersection_Method2(true_positions[0], pred_positions[1])[:18]
index_1_3 = getIntersection_Method2(true_positions[0], pred_positions[2])[:3]
index_2_1 = getIntersection_Method2(true_positions[1], pred_positions[0])[:24]
index_2_2 = getIntersection_Method2(true_positions[1], pred_positions[1])[:96]
index_2_3 = getIntersection_Method2(true_positions[1], pred_positions[2])[:6]
index_3_1 = getIntersection_Method2(true_positions[2], pred_positions[0])[:7]
index_3_2 = getIntersection_Method2(true_positions[2], pred_positions[1])[:10]
index_3_3 = getIntersection_Method2(true_positions[2], pred_positions[2])[:31]

# Append the names cell by cell, in row-major order of the confusion matrix.
for cell_index in (
    index_1_1, index_1_2, index_1_3,
    index_2_1, index_2_2, index_2_3,
    index_3_1, index_3_2, index_3_3,
):
    file_name_list = dfaa(cell_index, file_name_list, label_clip, path_list)


In [None]:
# Write the collected file names with the same cp949 encoding that every other
# CSV in this notebook is read with (including the read-back of this very
# file); the default UTF-8 output would be mis-decoded for non-ASCII names.
pd.DataFrame(file_name_list).to_csv('../../data/test.csv', encoding='cp949')

In [None]:
# Read the exported list back (decoded as cp949) and display it as the cell output.
pd.read_csv(
    '../../data/test.csv',
    encoding='cp949',
)



In [None]:
# NOTE(review): bare reference to the builtin `list`; it only makes the cell
# display the type and looks like a leftover scratch cell -- consider deleting.
list