# Data

In [1]:
import os
import warnings
from tqdm import tqdm

from PIL import Image
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.utils.data as data
from torchvision import transforms

from sklearn.metrics import balanced_accuracy_score

class FERPlusDataset(data.Dataset):
    def __init__(self, data_csv, phase, transform=None):
        self.phase = phase
        self.transform = transform

        # Read the dataset CSV file
        self.data = pd.read_csv(data_csv)
        self.data.iloc[:, 2:12] = self.data.iloc[:, 2:12].replace(1, 0)
        # Get file paths and labels
        self.file_paths = self.data.iloc[:, 0].values
        self.counts = self.data.iloc[:, 2:12].values  # 감정 점수들

        # Apply constraints to filter valid samples
        self._apply_constraints()

        # Use argmax to determine the emotion class
        self.labels = np.argmax(self.counts, axis=1)

        # Debugging: Check label range
        print("Unique labels in dataset after filtering:", np.unique(self.labels))

    def _apply_constraints(self):
        # Constraint : 'unknown-face' 또는 'not-face' 레이블 제거
        max_counts = self.counts.max(axis=1)
        counts_eq_max = (self.counts == max_counts[:, None])
        constraint1_violation = counts_eq_max[:, [8, 9]].any(axis=1)

        # Constraint : 1인 라벨 0으로 만들기


        # Constraint : 최대 투표 수를 가진 레이블이 3개 초과 제거
        num_max_labels = counts_eq_max.sum(axis=1)
        constraint2_violation = num_max_labels > 3

        # Constraint : 최대 투표 수가 전체 투표 수의 절반 이하인 경우 제거
        total_votes = self.counts.sum(axis=1)
        constraint3_violation = max_counts <= (total_votes / 2)

        # Combine constraints
        valid_samples = ~(
            constraint1_violation | constraint2_violation | constraint3_violation
        )

        # Apply valid samples filter
        self.file_paths = self.file_paths[valid_samples]
        self.counts = self.counts[valid_samples]

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        if self.phase == 'train':
            path = '/data/FER2013/FER2013Train/' + self.file_paths[idx]
        elif self.phase == 'val':
            path = '/data/FER2013/FER2013Valid/' + self.file_paths[idx]
        elif self.phase == 'test':
            path = '/data/FER2013/FER2013Test/' + self.file_paths[idx]
        image = Image.open(path).convert('RGB')
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label



# Define variables
train_csv = '/data/FER2013/train_label.csv'
val_csv = '/data/FER2013/valid_label.csv'
test_csv = '/data/FER2013/test_label.csv'
batch_size = 128
lr = 0.1
workers = 4
epochs = 60

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.enabled = True

data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomApply([
            transforms.RandomRotation(20),
            transforms.RandomCrop(224, padding=32)
        ], p=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225]),
    transforms.RandomErasing(scale=(0.02, 0.25)),
])

train_dataset = FERPlusDataset(train_csv, phase='train', transform=data_transforms)
print('Whole train set size:', train_dataset.__len__())

train_loader = torch.utils.data.DataLoader(train_dataset,
                                            batch_size=batch_size,
                                            num_workers=workers,
                                            shuffle=True,
                                            pin_memory=True)

data_transforms_val = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])
])

val_dataset = FERPlusDataset(val_csv, phase='val', transform=data_transforms_val)
print('Validation set size:', val_dataset.__len__())

val_loader = torch.utils.data.DataLoader(val_dataset,
                                            batch_size=batch_size,
                                            num_workers=workers,
                                            shuffle=False,
                                            pin_memory=True)


Unique labels in dataset after filtering: [0 1 2 3 4 5 6 7]
Whole train set size: 25045
Unique labels in dataset after filtering: [0 1 2 3 4 5 6 7]
Validation set size: 3191


# Clip

In [1]:
import torch
import clip
from tqdm import tqdm
from sklearn.metrics import accuracy_score
import os
import warnings
from tqdm import tqdm

from PIL import Image
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.utils.data as data
from torchvision import transforms

from sklearn.metrics import balanced_accuracy_score

# CLIP 모델 로드
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
#model, preprocess = clip.load("ViT-L/14", device=device)

# 텍스트 레이블 정의
emotion_labels = ["neutral", "happiness", "surprise", "sadness", "anger", "disgust", "fear", "contempt"]
text_inputs = clip.tokenize([f"a photo of a person showing {label}" for label in emotion_labels]).to(device)


In [2]:
class FERPlusDataset(data.Dataset):
    def __init__(self, data_csv, phase, transform=None):
        self.phase = phase
        self.transform = transform

        # Read the dataset CSV file
        self.data = pd.read_csv(data_csv)
        self.data.iloc[:, 2:12] = self.data.iloc[:, 2:12].replace(1, 0)

        # Get file paths and labels
        self.file_paths = self.data.iloc[:, 0].values
        self.counts = self.data.iloc[:, 2:12].values

        # Apply constraints to filter valid samples
        self._apply_constraints()

        # Use argmax to determine the emotion class
        self.labels = np.argmax(self.counts, axis=1)

    def _apply_constraints(self):
        max_counts = self.counts.max(axis=1)
        counts_eq_max = (self.counts == max_counts[:, None])
        constraint1_violation = counts_eq_max[:, [8, 9]].any(axis=1)
        num_max_labels = counts_eq_max.sum(axis=1)
        constraint2_violation = num_max_labels > 3
        total_votes = self.counts.sum(axis=1)
        constraint3_violation = max_counts <= (total_votes / 2)
        valid_samples = ~(
            constraint1_violation | constraint2_violation | constraint3_violation
        )
        self.file_paths = self.file_paths[valid_samples]
        self.counts = self.counts[valid_samples]

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        if self.phase == 'train':
            path = '/data/FER2013/FER2013Train/' + self.file_paths[idx]
        elif self.phase == 'valid':
            path = '/data/FER2013/FER2013Valid/' + self.file_paths[idx]
        elif self.phase == 'test':
            path = '/data/FER2013/FER2013Test/' + self.file_paths[idx]
        image = Image.open(path).convert('RGB')
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label


In [3]:
# 데이터셋 준비
train_csv = '/data/FER2013/train_label.csv'
val_csv = '/data/FER2013/valid_label.csv'
test_csv = '/data/FER2013/test_label.csv'
batch_size = 128
lr = 0.01
workers = 4
epochs = 60

train_dataset = FERPlusDataset(train_csv, phase='train', transform=preprocess)
val_dataset = FERPlusDataset(val_csv, phase='valid', transform=preprocess)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=workers)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=workers)


In [4]:
import torch.nn as nn

# CLIP 모델에서 이미지 인코더만 가져오기
image_encoder = model.visual
for param in image_encoder.parameters():
    param.requires_grad = False  # 이미지 인코더 가중치 고정

# 새로운 분류기 정의
class FERClassifier(nn.Module):
    def __init__(self, num_classes):
        super(FERClassifier, self).__init__()
        self.image_encoder = image_encoder
        self.fc = nn.Linear(512, num_classes)  # CLIP의 출력 크기 512
        #self.fc = nn.Linear(768, num_classes)

    def forward(self, x):
        x = self.image_encoder(x)  # 이미지 인코딩
        x = x / x.norm(dim=-1, keepdim=True)  # 정규화
        x = self.fc(x.float())  # Linear 계층 입력을 float32로 변환
        return x


# 모델 초기화
num_classes = len(emotion_labels)  # FERPlus 데이터셋의 클래스 수
classifier = FERClassifier(num_classes).to(device)


In [5]:
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

# 손실 함수와 옵티마이저 정의
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(classifier.parameters(), lr=0.001)
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)

# 모델 전체를 float32로 변환
classifier = classifier.float()

# 학습 함수에서 데이터 타입 변환 제거
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, epochs):
    for epoch in range(epochs):
        model.train()
        train_loss, correct, total = 0.0, 0, 0

        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Metrics
            train_loss += loss.item()
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)

        train_acc = correct / total * 100
        print(f"Epoch {epoch+1}, Loss: {train_loss/total:.4f}, Accuracy: {train_acc:.2f}%")

        # Validation
        val_acc = evaluate_model(model, val_loader)
        print(f"Validation Accuracy: {val_acc:.2f}%")

        # Scheduler step
        scheduler.step()

# 모델 평가 함수 정의
def evaluate_model(model, loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)
    return correct / total * 100

# 학습 시작
train_model(classifier, train_loader, val_loader, criterion, optimizer, scheduler, epochs=epochs)


Epoch 1/60:   0%|          | 0/196 [00:00<?, ?it/s]

Epoch 1/60: 100%|██████████| 196/196 [00:26<00:00,  7.45it/s]

Epoch 1, Loss: 0.0128, Accuracy: 41.29%





Validation Accuracy: 54.84%


Epoch 2/60: 100%|██████████| 196/196 [00:25<00:00,  7.59it/s]

Epoch 2, Loss: 0.0102, Accuracy: 61.90%





Validation Accuracy: 66.37%


Epoch 3/60: 100%|██████████| 196/196 [00:25<00:00,  7.66it/s]

Epoch 3, Loss: 0.0089, Accuracy: 69.74%





Validation Accuracy: 71.07%


Epoch 4/60: 100%|██████████| 196/196 [00:25<00:00,  7.59it/s]

Epoch 4, Loss: 0.0080, Accuracy: 74.00%





Validation Accuracy: 73.99%


Epoch 5/60: 100%|██████████| 196/196 [00:26<00:00,  7.54it/s]

Epoch 5, Loss: 0.0073, Accuracy: 76.13%





Validation Accuracy: 75.93%


Epoch 6/60: 100%|██████████| 196/196 [00:24<00:00,  7.92it/s]

Epoch 6, Loss: 0.0068, Accuracy: 77.47%





Validation Accuracy: 77.59%


Epoch 7/60: 100%|██████████| 196/196 [00:20<00:00,  9.39it/s]

Epoch 7, Loss: 0.0064, Accuracy: 78.49%





Validation Accuracy: 78.19%


Epoch 8/60: 100%|██████████| 196/196 [00:20<00:00,  9.38it/s]

Epoch 8, Loss: 0.0060, Accuracy: 79.43%





Validation Accuracy: 78.94%


Epoch 9/60: 100%|██████████| 196/196 [00:20<00:00,  9.37it/s]

Epoch 9, Loss: 0.0057, Accuracy: 80.00%





Validation Accuracy: 79.38%


Epoch 10/60: 100%|██████████| 196/196 [00:21<00:00,  8.96it/s]

Epoch 10, Loss: 0.0055, Accuracy: 80.57%





Validation Accuracy: 79.91%


Epoch 11/60: 100%|██████████| 196/196 [00:20<00:00,  9.41it/s]

Epoch 11, Loss: 0.0054, Accuracy: 80.85%





Validation Accuracy: 80.07%


Epoch 12/60: 100%|██████████| 196/196 [00:20<00:00,  9.34it/s]

Epoch 12, Loss: 0.0054, Accuracy: 80.88%





Validation Accuracy: 80.13%


Epoch 13/60: 100%|██████████| 196/196 [00:21<00:00,  9.23it/s]

Epoch 13, Loss: 0.0053, Accuracy: 80.97%





Validation Accuracy: 80.16%


Epoch 14/60: 100%|██████████| 196/196 [00:20<00:00,  9.38it/s]

Epoch 14, Loss: 0.0053, Accuracy: 81.03%





Validation Accuracy: 80.23%


Epoch 15/60: 100%|██████████| 196/196 [00:20<00:00,  9.40it/s]

Epoch 15, Loss: 0.0053, Accuracy: 81.08%





Validation Accuracy: 80.29%


Epoch 16/60: 100%|██████████| 196/196 [00:20<00:00,  9.42it/s]

Epoch 16, Loss: 0.0053, Accuracy: 81.09%





Validation Accuracy: 80.32%


Epoch 17/60: 100%|██████████| 196/196 [00:20<00:00,  9.42it/s]

Epoch 17, Loss: 0.0053, Accuracy: 81.13%





Validation Accuracy: 80.32%


Epoch 18/60: 100%|██████████| 196/196 [00:20<00:00,  9.37it/s]

Epoch 18, Loss: 0.0052, Accuracy: 81.16%





Validation Accuracy: 80.35%


Epoch 19/60: 100%|██████████| 196/196 [00:20<00:00,  9.40it/s]

Epoch 19, Loss: 0.0052, Accuracy: 81.22%





Validation Accuracy: 80.41%


Epoch 20/60: 100%|██████████| 196/196 [00:20<00:00,  9.39it/s]

Epoch 20, Loss: 0.0052, Accuracy: 81.21%





Validation Accuracy: 80.45%


Epoch 21/60: 100%|██████████| 196/196 [00:20<00:00,  9.42it/s]

Epoch 21, Loss: 0.0052, Accuracy: 81.25%





Validation Accuracy: 80.45%


Epoch 22/60: 100%|██████████| 196/196 [00:20<00:00,  9.37it/s]

Epoch 22, Loss: 0.0052, Accuracy: 81.25%





Validation Accuracy: 80.45%


Epoch 23/60: 100%|██████████| 196/196 [00:20<00:00,  9.38it/s]

Epoch 23, Loss: 0.0052, Accuracy: 81.26%





Validation Accuracy: 80.45%


Epoch 24/60: 100%|██████████| 196/196 [00:20<00:00,  9.40it/s]

Epoch 24, Loss: 0.0052, Accuracy: 81.26%





Validation Accuracy: 80.45%


Epoch 25/60: 100%|██████████| 196/196 [00:20<00:00,  9.41it/s]

Epoch 25, Loss: 0.0052, Accuracy: 81.26%





Validation Accuracy: 80.45%


Epoch 26/60: 100%|██████████| 196/196 [00:20<00:00,  9.35it/s]

Epoch 26, Loss: 0.0052, Accuracy: 81.27%





Validation Accuracy: 80.41%


Epoch 27/60: 100%|██████████| 196/196 [00:20<00:00,  9.38it/s]

Epoch 27, Loss: 0.0052, Accuracy: 81.27%





Validation Accuracy: 80.41%


Epoch 28/60: 100%|██████████| 196/196 [00:20<00:00,  9.41it/s]

Epoch 28, Loss: 0.0052, Accuracy: 81.27%





Validation Accuracy: 80.41%


Epoch 29/60: 100%|██████████| 196/196 [00:20<00:00,  9.39it/s]

Epoch 29, Loss: 0.0052, Accuracy: 81.28%





Validation Accuracy: 80.41%


Epoch 30/60: 100%|██████████| 196/196 [00:20<00:00,  9.41it/s]

Epoch 30, Loss: 0.0052, Accuracy: 81.28%





Validation Accuracy: 80.41%


Epoch 31/60: 100%|██████████| 196/196 [00:20<00:00,  9.38it/s]

Epoch 31, Loss: 0.0052, Accuracy: 81.29%





Validation Accuracy: 80.38%


Epoch 32/60: 100%|██████████| 196/196 [00:20<00:00,  9.38it/s]

Epoch 32, Loss: 0.0052, Accuracy: 81.29%





Validation Accuracy: 80.38%


Epoch 33/60: 100%|██████████| 196/196 [00:20<00:00,  9.41it/s]

Epoch 33, Loss: 0.0052, Accuracy: 81.29%





Validation Accuracy: 80.38%


Epoch 34/60: 100%|██████████| 196/196 [00:20<00:00,  9.40it/s]

Epoch 34, Loss: 0.0052, Accuracy: 81.29%





Validation Accuracy: 80.38%


Epoch 35/60: 100%|██████████| 196/196 [00:20<00:00,  9.37it/s]

Epoch 35, Loss: 0.0052, Accuracy: 81.29%





KeyboardInterrupt: 

# Clip - SelfAttention

In [1]:
import torch
import clip
from tqdm import tqdm
from sklearn.metrics import accuracy_score
import os
import warnings
from tqdm import tqdm

from PIL import Image
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.utils.data as data
from torchvision import transforms

from sklearn.metrics import balanced_accuracy_score

# CLIP 모델 로드
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
#model, preprocess = clip.load("ViT-L/14", device=device)

# 텍스트 레이블 정의
emotion_labels = ["neutral", "happiness", "surprise", "sadness", "anger", "disgust", "fear", "contempt"]
text_inputs = clip.tokenize([f"a photo of a person showing {label}" for label in emotion_labels]).to(device)


In [2]:
class FERPlusDataset(data.Dataset):
    def __init__(self, data_csv, phase, transform=None):
        self.phase = phase
        self.transform = transform

        # Read the dataset CSV file
        self.data = pd.read_csv(data_csv)
        self.data.iloc[:, 2:12] = self.data.iloc[:, 2:12].replace(1, 0)

        # Get file paths and labels
        self.file_paths = self.data.iloc[:, 0].values
        self.counts = self.data.iloc[:, 2:12].values

        # Apply constraints to filter valid samples
        self._apply_constraints()

        # Use argmax to determine the emotion class
        self.labels = np.argmax(self.counts, axis=1)

    def _apply_constraints(self):
        max_counts = self.counts.max(axis=1)
        counts_eq_max = (self.counts == max_counts[:, None])
        constraint1_violation = counts_eq_max[:, [8, 9]].any(axis=1)
        num_max_labels = counts_eq_max.sum(axis=1)
        constraint2_violation = num_max_labels > 3
        total_votes = self.counts.sum(axis=1)
        constraint3_violation = max_counts <= (total_votes / 2)
        valid_samples = ~(
            constraint1_violation | constraint2_violation | constraint3_violation
        )
        self.file_paths = self.file_paths[valid_samples]
        self.counts = self.counts[valid_samples]

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        if self.phase == 'train':
            path = '/data/FER2013/FER2013Train/' + self.file_paths[idx]
        elif self.phase == 'valid':
            path = '/data/FER2013/FER2013Valid/' + self.file_paths[idx]
        elif self.phase == 'test':
            path = '/data/FER2013/FER2013Test/' + self.file_paths[idx]
        image = Image.open(path).convert('RGB')
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

# 데이터셋 준비
train_csv = '/data/FER2013/train_label.csv'
val_csv = '/data/FER2013/valid_label.csv'
test_csv = '/data/FER2013/test_label.csv'
batch_size = 128
lr = 0.01
workers = 4
epochs = 60

train_dataset = FERPlusDataset(train_csv, phase='train', transform=preprocess)
val_dataset = FERPlusDataset(val_csv, phase='valid', transform=preprocess)
test_dataset = FERPlusDataset(test_csv, phase='test', transform=preprocess)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=workers)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=workers)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=workers)


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DotProductAttention(nn.Module):
    def __init__(self, input_dim):
        super(DotProductAttention, self).__init__()
        self.input_dim = input_dim
        self.fc_q = nn.Linear(input_dim, input_dim)
        self.fc_k = nn.Linear(input_dim, input_dim)
        self.fc_v = nn.Linear(input_dim, input_dim)
    
    def forward(self, x):
        # Query, Key, Value 계산
        Q = self.fc_q(x)
        K = self.fc_k(x)
        V = self.fc_v(x)
        
        # Dot Product Attention
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.input_dim ** 0.5)  # Scaled Dot-Product Attention
        attention_weights = F.softmax(attention_scores, dim=-1)
        
        # Attention을 곱해 최종 출력 계산
        output = torch.matmul(attention_weights, V)
        return output

class FERClassifier(nn.Module):
    def __init__(self, num_classes):
        super(FERClassifier, self).__init__()
        self.image_encoder = model.visual
        for param in self.image_encoder.parameters():
            param.requires_grad = False  # 이미지 인코더 가중치 고정
        
        # FC 1 (차원 축소)
        self.fc1 = nn.Linear(512, 512)  # ViT-B/32의 출력 크기인 512을 512로 변환
        
        # Self-Attention
        self.attention = DotProductAttention(input_dim=512)  # Attention을 적용
        
        # FC 2 (최종 분류)
        self.fc2 = nn.Linear(512, num_classes)  # 클래스 예측
    
    def forward(self, x):
        # 이미지 인코딩
        x = self.image_encoder(x)
        
        # FC 1
        x = F.relu(self.fc1(x))  # FC1 이후 ReLU 적용
        
        # Self-Attention
        x = self.attention(x)  # Self-Attention을 적용
        
        # FC 2 (최종 분류)
        x = self.fc2(x)
        return x


# 모델 초기화
num_classes = len(emotion_labels)  # FERPlus 데이터셋의 클래스 수
classifier = FERClassifier(num_classes).to(device)


In [4]:
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

# 손실 함수와 옵티마이저 정의
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(classifier.parameters(), lr=0.001)
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)

# 모델 전체를 float32로 변환
classifier = classifier.float()

# 학습 함수에서 데이터 타입 변환 제거
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, epochs):
    for epoch in range(epochs):
        model.train()
        train_loss, correct, total = 0.0, 0, 0

        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Metrics
            train_loss += loss.item()
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)

        train_acc = correct / total * 100
        print(f"Epoch {epoch+1}, Loss: {train_loss/total:.4f}, Accuracy: {train_acc:.2f}%")

        # Validation
        val_acc = evaluate_model(model, val_loader)
        print(f"Validation Accuracy: {val_acc:.2f}%")

        # Scheduler step
        scheduler.step()

# 모델 평가 함수 정의
def evaluate_model(model, loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)
    return correct / total * 100

# 학습 시작
train_model(classifier, train_loader, val_loader, criterion, optimizer, scheduler, epochs=epochs)


Epoch 1/60: 100%|██████████| 196/196 [00:25<00:00,  7.72it/s]

Epoch 1, Loss: 0.0089, Accuracy: 54.54%





Validation Accuracy: 73.02%


Epoch 2/60: 100%|██████████| 196/196 [00:24<00:00,  7.86it/s]

Epoch 2, Loss: 0.0052, Accuracy: 77.28%





Validation Accuracy: 80.51%


Epoch 3/60: 100%|██████████| 196/196 [00:24<00:00,  7.86it/s]

Epoch 3, Loss: 0.0042, Accuracy: 82.36%





Validation Accuracy: 82.29%


Epoch 4/60: 100%|██████████| 196/196 [00:24<00:00,  7.85it/s]

Epoch 4, Loss: 0.0039, Accuracy: 83.66%





Validation Accuracy: 82.04%


Epoch 5/60: 100%|██████████| 196/196 [00:24<00:00,  7.88it/s]

Epoch 5, Loss: 0.0036, Accuracy: 84.57%





Validation Accuracy: 83.30%


Epoch 6/60: 100%|██████████| 196/196 [00:24<00:00,  7.91it/s]

Epoch 6, Loss: 0.0035, Accuracy: 85.27%





Validation Accuracy: 84.33%


Epoch 7/60: 100%|██████████| 196/196 [00:24<00:00,  7.94it/s]

Epoch 7, Loss: 0.0033, Accuracy: 85.88%





Validation Accuracy: 84.52%


Epoch 8/60: 100%|██████████| 196/196 [00:24<00:00,  7.86it/s]

Epoch 8, Loss: 0.0032, Accuracy: 86.20%





Validation Accuracy: 84.30%


Epoch 9/60: 100%|██████████| 196/196 [00:24<00:00,  7.87it/s]

Epoch 9, Loss: 0.0030, Accuracy: 86.82%





Validation Accuracy: 83.86%


Epoch 10/60: 100%|██████████| 196/196 [00:25<00:00,  7.82it/s]

Epoch 10, Loss: 0.0029, Accuracy: 86.98%





Validation Accuracy: 84.68%


Epoch 11/60: 100%|██████████| 196/196 [00:24<00:00,  7.88it/s]

Epoch 11, Loss: 0.0025, Accuracy: 88.92%





Validation Accuracy: 85.49%


Epoch 12/60: 100%|██████████| 196/196 [00:24<00:00,  7.86it/s]

Epoch 12, Loss: 0.0024, Accuracy: 89.14%





Validation Accuracy: 85.30%


Epoch 13/60: 100%|██████████| 196/196 [00:25<00:00,  7.84it/s]

Epoch 13, Loss: 0.0024, Accuracy: 89.48%





Validation Accuracy: 84.86%


Epoch 14/60: 100%|██████████| 196/196 [00:24<00:00,  7.94it/s]

Epoch 14, Loss: 0.0023, Accuracy: 89.46%





Validation Accuracy: 85.30%


Epoch 15/60: 100%|██████████| 196/196 [00:24<00:00,  7.87it/s]

Epoch 15, Loss: 0.0023, Accuracy: 89.73%





Validation Accuracy: 84.42%


Epoch 16/60: 100%|██████████| 196/196 [00:24<00:00,  7.90it/s]

Epoch 16, Loss: 0.0023, Accuracy: 89.99%





Validation Accuracy: 84.93%


Epoch 17/60: 100%|██████████| 196/196 [00:24<00:00,  7.87it/s]

Epoch 17, Loss: 0.0022, Accuracy: 89.96%





Validation Accuracy: 84.90%


Epoch 18/60: 100%|██████████| 196/196 [00:25<00:00,  7.83it/s]

Epoch 18, Loss: 0.0022, Accuracy: 90.34%





Validation Accuracy: 85.24%


Epoch 19/60: 100%|██████████| 196/196 [00:25<00:00,  7.82it/s]

Epoch 19, Loss: 0.0022, Accuracy: 90.42%





Validation Accuracy: 85.37%


Epoch 20/60: 100%|██████████| 196/196 [00:24<00:00,  7.86it/s]

Epoch 20, Loss: 0.0021, Accuracy: 90.44%





Validation Accuracy: 84.80%


Epoch 21/60: 100%|██████████| 196/196 [00:24<00:00,  7.89it/s]

Epoch 21, Loss: 0.0020, Accuracy: 90.96%





Validation Accuracy: 85.49%


Epoch 22/60: 100%|██████████| 196/196 [00:24<00:00,  7.96it/s]

Epoch 22, Loss: 0.0020, Accuracy: 90.93%





Validation Accuracy: 85.43%


Epoch 23/60: 100%|██████████| 196/196 [00:24<00:00,  7.88it/s]

Epoch 23, Loss: 0.0020, Accuracy: 90.93%





Validation Accuracy: 85.33%


Epoch 24/60: 100%|██████████| 196/196 [00:24<00:00,  7.86it/s]

Epoch 24, Loss: 0.0020, Accuracy: 90.83%





Validation Accuracy: 85.37%


Epoch 25/60: 100%|██████████| 196/196 [00:24<00:00,  7.86it/s]

Epoch 25, Loss: 0.0020, Accuracy: 91.20%





Validation Accuracy: 85.27%


Epoch 26/60: 100%|██████████| 196/196 [00:24<00:00,  7.89it/s]

Epoch 26, Loss: 0.0020, Accuracy: 91.11%





Validation Accuracy: 85.37%


Epoch 27/60: 100%|██████████| 196/196 [00:24<00:00,  7.88it/s]

Epoch 27, Loss: 0.0020, Accuracy: 91.24%





Validation Accuracy: 85.18%


Epoch 28/60: 100%|██████████| 196/196 [00:24<00:00,  7.88it/s]

Epoch 28, Loss: 0.0020, Accuracy: 91.08%





Validation Accuracy: 85.43%


Epoch 29/60: 100%|██████████| 196/196 [00:24<00:00,  7.87it/s]

Epoch 29, Loss: 0.0020, Accuracy: 91.19%





Validation Accuracy: 85.08%


Epoch 30/60: 100%|██████████| 196/196 [00:24<00:00,  7.88it/s]

Epoch 30, Loss: 0.0020, Accuracy: 91.04%





Validation Accuracy: 85.24%


Epoch 31/60: 100%|██████████| 196/196 [00:24<00:00,  7.86it/s]

Epoch 31, Loss: 0.0020, Accuracy: 91.29%





Validation Accuracy: 85.24%


Epoch 32/60: 100%|██████████| 196/196 [00:24<00:00,  7.89it/s]

Epoch 32, Loss: 0.0020, Accuracy: 91.16%





Validation Accuracy: 85.21%


Epoch 33/60: 100%|██████████| 196/196 [00:24<00:00,  7.85it/s]

Epoch 33, Loss: 0.0020, Accuracy: 91.08%





Validation Accuracy: 85.21%


Epoch 34/60: 100%|██████████| 196/196 [00:24<00:00,  7.87it/s]

Epoch 34, Loss: 0.0020, Accuracy: 91.21%





Validation Accuracy: 85.21%


Epoch 35/60: 100%|██████████| 196/196 [00:24<00:00,  7.89it/s]

Epoch 35, Loss: 0.0020, Accuracy: 91.19%





Validation Accuracy: 85.18%


Epoch 36/60: 100%|██████████| 196/196 [00:24<00:00,  7.90it/s]

Epoch 36, Loss: 0.0020, Accuracy: 91.24%





Validation Accuracy: 85.18%


Epoch 37/60: 100%|██████████| 196/196 [00:24<00:00,  7.87it/s]

Epoch 37, Loss: 0.0020, Accuracy: 91.16%





Validation Accuracy: 85.18%


Epoch 38/60: 100%|██████████| 196/196 [00:24<00:00,  7.90it/s]

Epoch 38, Loss: 0.0020, Accuracy: 91.31%





Validation Accuracy: 85.18%


Epoch 39/60: 100%|██████████| 196/196 [00:25<00:00,  7.82it/s]

Epoch 39, Loss: 0.0020, Accuracy: 91.20%





Validation Accuracy: 85.21%


Epoch 40/60: 100%|██████████| 196/196 [00:24<00:00,  7.87it/s]

Epoch 40, Loss: 0.0020, Accuracy: 91.31%





Validation Accuracy: 85.15%


Epoch 41/60: 100%|██████████| 196/196 [00:25<00:00,  7.83it/s]

Epoch 41, Loss: 0.0020, Accuracy: 91.24%





Validation Accuracy: 85.15%


Epoch 42/60: 100%|██████████| 196/196 [00:24<00:00,  7.88it/s]

Epoch 42, Loss: 0.0020, Accuracy: 91.22%





Validation Accuracy: 85.15%


Epoch 43/60: 100%|██████████| 196/196 [00:24<00:00,  7.89it/s]

Epoch 43, Loss: 0.0020, Accuracy: 91.24%





Validation Accuracy: 85.15%


Epoch 44/60:  36%|███▌      | 70/196 [00:09<00:17,  7.32it/s]


KeyboardInterrupt: 

In [5]:
# 테스트 데이터 평가
def test_model(model, test_loader):
    model.eval()  # 평가 모드
    correct, total = 0, 0
    all_labels = []
    all_predictions = []
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = outputs.max(1)
            
            # 총 정확도 계산
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)

            # 자세한 결과 저장 (선택적)
            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predicted.cpu().numpy())

    # 정확도 계산
    accuracy = correct / total * 100
    print(f"Test Accuracy: {accuracy:.2f}%")

    return accuracy, all_labels, all_predictions


# 테스트 데이터 평가 시작
test_accuracy, test_labels, test_predictions = test_model(classifier, test_loader)

# 테스트 데이터의 구체적인 결과 출력 (선택적)
from sklearn.metrics import classification_report
print(classification_report(test_labels, test_predictions, target_names=emotion_labels))


Test Accuracy: 85.33%
              precision    recall  f1-score   support

     neutral       0.86      0.89      0.88      1083
   happiness       0.93      0.92      0.93       892
    surprise       0.82      0.86      0.84       394
     sadness       0.77      0.73      0.75       382
       anger       0.79      0.84      0.81       269
     disgust       0.00      0.00      0.00        16
        fear       0.60      0.50      0.54        86
    contempt       0.00      0.00      0.00        14

    accuracy                           0.85      3136
   macro avg       0.60      0.59      0.59      3136
weighted avg       0.84      0.85      0.85      3136



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


# Clip - SelfAttention Large

In [None]:
import torch
import clip
from tqdm import tqdm
from sklearn.metrics import accuracy_score
import os
import warnings
from tqdm import tqdm

from PIL import Image
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.utils.data as data
from torchvision import transforms

from sklearn.metrics import balanced_accuracy_score

# CLIP 모델 로드
device = "cuda" if torch.cuda.is_available() else "cpu"
#model, preprocess = clip.load("ViT-B/32", device=device)
model, preprocess = clip.load("ViT-L/14", device=device)

# 텍스트 레이블 정의
emotion_labels = ["neutral", "happiness", "surprise", "sadness", "anger", "disgust", "fear", "contempt"]
text_inputs = clip.tokenize([f"a photo of a person showing {label}" for label in emotion_labels]).to(device)


In [None]:
class FERPlusDataset(data.Dataset):
    def __init__(self, data_csv, phase, transform=None):
        self.phase = phase
        self.transform = transform

        # Read the dataset CSV file
        self.data = pd.read_csv(data_csv)
        self.data.iloc[:, 2:12] = self.data.iloc[:, 2:12].replace(1, 0)

        # Get file paths and labels
        self.file_paths = self.data.iloc[:, 0].values
        self.counts = self.data.iloc[:, 2:12].values

        # Apply constraints to filter valid samples
        self._apply_constraints()

        # Use argmax to determine the emotion class
        self.labels = np.argmax(self.counts, axis=1)

    def _apply_constraints(self):
        max_counts = self.counts.max(axis=1)
        counts_eq_max = (self.counts == max_counts[:, None])
        constraint1_violation = counts_eq_max[:, [8, 9]].any(axis=1)
        num_max_labels = counts_eq_max.sum(axis=1)
        constraint2_violation = num_max_labels > 3
        total_votes = self.counts.sum(axis=1)
        constraint3_violation = max_counts <= (total_votes / 2)
        valid_samples = ~(
            constraint1_violation | constraint2_violation | constraint3_violation
        )
        self.file_paths = self.file_paths[valid_samples]
        self.counts = self.counts[valid_samples]

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        if self.phase == 'train':
            path = '/data/FER2013/FER2013Train/' + self.file_paths[idx]
        elif self.phase == 'valid':
            path = '/data/FER2013/FER2013Valid/' + self.file_paths[idx]
        elif self.phase == 'test':
            path = '/data/FER2013/FER2013Test/' + self.file_paths[idx]
        image = Image.open(path).convert('RGB')
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

# 데이터셋 준비
train_csv = '/data/FER2013/train_label.csv'
val_csv = '/data/FER2013/valid_label.csv'
test_csv = '/data/FER2013/test_label.csv'
batch_size = 128
lr = 0.01
workers = 4
epochs = 60

train_dataset = FERPlusDataset(train_csv, phase='train', transform=preprocess)
val_dataset = FERPlusDataset(val_csv, phase='valid', transform=preprocess)
test_dataset = FERPlusDataset(test_csv, phase='test', transform=preprocess)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=workers)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=workers)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=workers)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DotProductAttention(nn.Module):
    def __init__(self, input_dim):
        super(DotProductAttention, self).__init__()
        self.input_dim = input_dim
        self.fc_q = nn.Linear(input_dim, input_dim)
        self.fc_k = nn.Linear(input_dim, input_dim)
        self.fc_v = nn.Linear(input_dim, input_dim)
    
    def forward(self, x):
        # Query, Key, Value 계산
        Q = self.fc_q(x)
        K = self.fc_k(x)
        V = self.fc_v(x)
        
        # Dot Product Attention
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.input_dim ** 0.5)  # Scaled Dot-Product Attention
        attention_weights = F.softmax(attention_scores, dim=-1)
        
        # Attention을 곱해 최종 출력 계산
        output = torch.matmul(attention_weights, V)
        return output

class FERClassifier(nn.Module):
    def __init__(self, num_classes):
        super(FERClassifier, self).__init__()
        self.image_encoder = model.visual
        for param in self.image_encoder.parameters():
            param.requires_grad = False  # 이미지 인코더 가중치 고정
        
        # FC 1 (차원 축소)
        self.fc1 = nn.Linear(768, 512)  # ViT-B/32의 출력 크기인 512을 512로 변환
        
        # Self-Attention
        self.attention = DotProductAttention(input_dim=512)  # Attention을 적용
        
        # FC 2 (최종 분류)
        self.fc2 = nn.Linear(512, num_classes)  # 클래스 예측
    
    def forward(self, x):
        # 이미지 인코딩
        x = self.image_encoder(x)
        
        # FC 1
        x = F.relu(self.fc1(x))  # FC1 이후 ReLU 적용
        
        # Self-Attention
        x = self.attention(x)  # Self-Attention을 적용
        
        # FC 2 (최종 분류)
        x = self.fc2(x)
        return x


# 모델 초기화
num_classes = len(emotion_labels)  # FERPlus 데이터셋의 클래스 수
classifier = FERClassifier(num_classes).to(device)


In [None]:
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

# 손실 함수와 옵티마이저 정의
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(classifier.parameters(), lr=0.001)
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)

# 모델 전체를 float32로 변환
classifier = classifier.float()

# 학습 함수에서 데이터 타입 변환 제거
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, epochs):
    for epoch in range(epochs):
        model.train()
        train_loss, correct, total = 0.0, 0, 0

        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Metrics
            train_loss += loss.item()
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)

        train_acc = correct / total * 100
        print(f"Epoch {epoch+1}, Loss: {train_loss/total:.4f}, Accuracy: {train_acc:.2f}%")

        # Validation
        val_acc = evaluate_model(model, val_loader)
        print(f"Validation Accuracy: {val_acc:.2f}%")

        # Scheduler step
        scheduler.step()

# 모델 평가 함수 정의
def evaluate_model(model, loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)
    return correct / total * 100

# 학습 시작
train_model(classifier, train_loader, val_loader, criterion, optimizer, scheduler, epochs=epochs)


Epoch 1/60:   0%|          | 0/196 [00:00<?, ?it/s]

Epoch 1/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 1, Loss: 0.0080, Accuracy: 59.07%





Validation Accuracy: 81.95%


Epoch 2/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 2, Loss: 0.0039, Accuracy: 83.67%





Validation Accuracy: 85.71%


Epoch 3/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 3, Loss: 0.0033, Accuracy: 86.60%





Validation Accuracy: 86.02%


Epoch 4/60: 100%|██████████| 196/196 [05:20<00:00,  1.64s/it]

Epoch 4, Loss: 0.0030, Accuracy: 87.75%





Validation Accuracy: 87.87%


Epoch 5/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 5, Loss: 0.0027, Accuracy: 88.69%





Validation Accuracy: 88.28%


Epoch 6/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 6, Loss: 0.0025, Accuracy: 89.51%





Validation Accuracy: 87.84%


Epoch 7/60: 100%|██████████| 196/196 [05:20<00:00,  1.64s/it]

Epoch 7, Loss: 0.0024, Accuracy: 89.69%





Validation Accuracy: 88.40%


Epoch 8/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 8, Loss: 0.0022, Accuracy: 90.37%





Validation Accuracy: 88.15%


Epoch 9/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 9, Loss: 0.0021, Accuracy: 90.90%





Validation Accuracy: 87.09%


Epoch 10/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 10, Loss: 0.0020, Accuracy: 91.24%





Validation Accuracy: 88.53%


Epoch 11/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 11, Loss: 0.0016, Accuracy: 92.78%





Validation Accuracy: 88.72%


Epoch 12/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 12, Loss: 0.0015, Accuracy: 93.34%





Validation Accuracy: 88.34%


Epoch 13/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 13, Loss: 0.0014, Accuracy: 93.60%





Validation Accuracy: 88.59%


Epoch 14/60: 100%|██████████| 196/196 [05:20<00:00,  1.64s/it]

Epoch 14, Loss: 0.0014, Accuracy: 93.68%





Validation Accuracy: 89.19%


Epoch 15/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 15, Loss: 0.0013, Accuracy: 93.84%





Validation Accuracy: 88.78%


Epoch 16/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 16, Loss: 0.0013, Accuracy: 93.89%





Validation Accuracy: 88.84%


Epoch 17/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 17, Loss: 0.0013, Accuracy: 94.06%





Validation Accuracy: 89.16%


Epoch 18/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 18, Loss: 0.0013, Accuracy: 94.21%





Validation Accuracy: 88.25%


Epoch 19/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 19, Loss: 0.0012, Accuracy: 94.42%





Validation Accuracy: 89.13%


Epoch 20/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 20, Loss: 0.0012, Accuracy: 94.56%





Validation Accuracy: 88.75%


Epoch 21/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 21, Loss: 0.0011, Accuracy: 94.81%





Validation Accuracy: 88.91%


Epoch 22/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 22, Loss: 0.0011, Accuracy: 95.05%





Validation Accuracy: 88.94%


Epoch 23/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 23, Loss: 0.0011, Accuracy: 95.09%





Validation Accuracy: 88.91%


Epoch 24/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 24, Loss: 0.0011, Accuracy: 95.08%





Validation Accuracy: 88.84%


Epoch 25/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 25, Loss: 0.0011, Accuracy: 95.02%





Validation Accuracy: 88.84%


Epoch 26/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 26, Loss: 0.0011, Accuracy: 95.00%





Validation Accuracy: 89.00%


Epoch 27/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 27, Loss: 0.0011, Accuracy: 95.12%





Validation Accuracy: 88.87%


Epoch 28/60: 100%|██████████| 196/196 [05:18<00:00,  1.63s/it]

Epoch 28, Loss: 0.0011, Accuracy: 95.15%





Validation Accuracy: 88.81%


Epoch 29/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 29, Loss: 0.0011, Accuracy: 95.06%





Validation Accuracy: 88.94%


Epoch 30/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 30, Loss: 0.0011, Accuracy: 95.24%





Validation Accuracy: 88.87%


Epoch 31/60: 100%|██████████| 196/196 [05:20<00:00,  1.64s/it]

Epoch 31, Loss: 0.0011, Accuracy: 95.15%





Validation Accuracy: 88.94%


Epoch 32/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 32, Loss: 0.0011, Accuracy: 95.18%





Validation Accuracy: 88.87%


Epoch 33/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 33, Loss: 0.0011, Accuracy: 95.25%





Validation Accuracy: 88.87%


Epoch 34/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 34, Loss: 0.0011, Accuracy: 95.27%





Validation Accuracy: 88.78%


Epoch 35/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 35, Loss: 0.0011, Accuracy: 95.24%





Validation Accuracy: 88.75%


Epoch 36/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 36, Loss: 0.0011, Accuracy: 95.20%





Validation Accuracy: 88.75%


Epoch 37/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 37, Loss: 0.0011, Accuracy: 95.18%





Validation Accuracy: 88.78%


Epoch 38/60: 100%|██████████| 196/196 [05:20<00:00,  1.64s/it]

Epoch 38, Loss: 0.0011, Accuracy: 95.18%





Validation Accuracy: 88.72%


Epoch 39/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 39, Loss: 0.0011, Accuracy: 95.29%





Validation Accuracy: 88.69%


Epoch 40/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 40, Loss: 0.0011, Accuracy: 95.20%





Validation Accuracy: 88.69%


Epoch 41/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 41, Loss: 0.0011, Accuracy: 95.16%





Validation Accuracy: 88.69%


Epoch 42/60: 100%|██████████| 196/196 [05:20<00:00,  1.64s/it]

Epoch 42, Loss: 0.0011, Accuracy: 95.06%





Validation Accuracy: 88.69%


Epoch 43/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 43, Loss: 0.0011, Accuracy: 95.16%





Validation Accuracy: 88.72%


Epoch 44/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 44, Loss: 0.0011, Accuracy: 95.23%





Validation Accuracy: 88.72%


Epoch 45/60: 100%|██████████| 196/196 [06:40<00:00,  2.04s/it]

Epoch 45, Loss: 0.0011, Accuracy: 95.14%





Validation Accuracy: 88.72%


Epoch 46/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 46, Loss: 0.0011, Accuracy: 95.19%





Validation Accuracy: 88.72%


Epoch 47/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 47, Loss: 0.0011, Accuracy: 95.16%





Validation Accuracy: 88.72%


Epoch 48/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 48, Loss: 0.0011, Accuracy: 95.29%





Validation Accuracy: 88.72%


Epoch 49/60: 100%|██████████| 196/196 [06:28<00:00,  1.98s/it]

Epoch 49, Loss: 0.0011, Accuracy: 95.16%





Validation Accuracy: 88.69%


Epoch 50/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 50, Loss: 0.0011, Accuracy: 95.12%





Validation Accuracy: 88.69%


Epoch 51/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 51, Loss: 0.0011, Accuracy: 95.22%





Validation Accuracy: 88.69%


Epoch 52/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 52, Loss: 0.0011, Accuracy: 95.25%





Validation Accuracy: 88.69%


Epoch 53/60: 100%|██████████| 196/196 [06:20<00:00,  1.94s/it]

Epoch 53, Loss: 0.0011, Accuracy: 95.24%





Validation Accuracy: 88.69%


Epoch 54/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 54, Loss: 0.0011, Accuracy: 95.07%





Validation Accuracy: 88.69%


Epoch 55/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 55, Loss: 0.0011, Accuracy: 95.12%





Validation Accuracy: 88.69%


Epoch 56/60: 100%|██████████| 196/196 [06:28<00:00,  1.98s/it]

Epoch 56, Loss: 0.0011, Accuracy: 95.23%





Validation Accuracy: 88.69%


Epoch 57/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 57, Loss: 0.0011, Accuracy: 95.16%





Validation Accuracy: 88.69%


Epoch 58/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 58, Loss: 0.0011, Accuracy: 95.34%





Validation Accuracy: 88.69%


Epoch 59/60: 100%|██████████| 196/196 [05:21<00:00,  1.64s/it]

Epoch 59, Loss: 0.0011, Accuracy: 95.20%





Validation Accuracy: 88.69%


Epoch 60/60: 100%|██████████| 196/196 [06:29<00:00,  1.99s/it]

Epoch 60, Loss: 0.0011, Accuracy: 95.24%





Validation Accuracy: 88.69%


In [None]:
# 테스트 데이터 평가
def test_model(model, test_loader):
    model.eval()  # 평가 모드
    correct, total = 0, 0
    all_labels = []
    all_predictions = []
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = outputs.max(1)
            
            # 총 정확도 계산
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)

            # 자세한 결과 저장 (선택적)
            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predicted.cpu().numpy())

    # 정확도 계산
    accuracy = correct / total * 100
    print(f"Test Accuracy: {accuracy:.2f}%")

    return accuracy, all_labels, all_predictions


# 테스트 데이터 평가 시작
test_accuracy, test_labels, test_predictions = test_model(classifier, test_loader)

# 테스트 데이터의 구체적인 결과 출력 (선택적)
from sklearn.metrics import classification_report
print(classification_report(test_labels, test_predictions, target_names=emotion_labels))


Test Accuracy: 88.33%
              precision    recall  f1-score   support

     neutral       0.88      0.90      0.89      1083
   happiness       0.95      0.95      0.95       892
    surprise       0.87      0.91      0.89       394
     sadness       0.78      0.78      0.78       382
       anger       0.88      0.91      0.89       269
     disgust       0.17      0.06      0.09        16
        fear       0.78      0.52      0.62        86
    contempt       0.33      0.07      0.12        14

    accuracy                           0.88      3136
   macro avg       0.70      0.64      0.65      3136
weighted avg       0.88      0.88      0.88      3136



# Clip - SelfAttention DO,BN

In [1]:
import torch
import clip
from tqdm import tqdm
from sklearn.metrics import accuracy_score
import os
import warnings
from tqdm import tqdm

from PIL import Image
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.utils.data as data
from torchvision import transforms

from sklearn.metrics import balanced_accuracy_score

# CLIP 모델 로드
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
#model, preprocess = clip.load("ViT-L/14", device=device)

# 텍스트 레이블 정의
emotion_labels = ["neutral", "happiness", "surprise", "sadness", "anger", "disgust", "fear", "contempt"]
text_inputs = clip.tokenize([f"{label} expression" for label in emotion_labels]).to(device)


In [2]:
class FERPlusDataset(data.Dataset):
    def __init__(self, data_csv, phase, transform=None):
        self.phase = phase
        self.transform = transform

        # Read the dataset CSV file
        self.data = pd.read_csv(data_csv)
        self.data.iloc[:, 2:12] = self.data.iloc[:, 2:12].replace(1, 0)

        # Get file paths and labels
        self.file_paths = self.data.iloc[:, 0].values
        self.counts = self.data.iloc[:, 2:12].values

        # Apply constraints to filter valid samples
        self._apply_constraints()

        # Use argmax to determine the emotion class
        self.labels = np.argmax(self.counts, axis=1)

    def _apply_constraints(self):
        max_counts = self.counts.max(axis=1)
        counts_eq_max = (self.counts == max_counts[:, None])
        constraint1_violation = counts_eq_max[:, [8, 9]].any(axis=1)
        num_max_labels = counts_eq_max.sum(axis=1)
        constraint2_violation = num_max_labels > 3
        total_votes = self.counts.sum(axis=1)
        constraint3_violation = max_counts <= (total_votes / 2)
        valid_samples = ~(
            constraint1_violation | constraint2_violation | constraint3_violation
        )
        self.file_paths = self.file_paths[valid_samples]
        self.counts = self.counts[valid_samples]

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        if self.phase == 'train':
            path = '/data/FER2013/FER2013Train/' + self.file_paths[idx]
        elif self.phase == 'valid':
            path = '/data/FER2013/FER2013Valid/' + self.file_paths[idx]
        elif self.phase == 'test':
            path = '/data/FER2013/FER2013Test/' + self.file_paths[idx]
        image = Image.open(path).convert('RGB')
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

# 데이터셋 준비
train_csv = '/data/FER2013/train_label.csv'
val_csv = '/data/FER2013/valid_label.csv'
test_csv = '/data/FER2013/test_label.csv'
batch_size = 128
lr = 0.01
workers = 4
epochs = 60

train_dataset = FERPlusDataset(train_csv, phase='train', transform=preprocess)
val_dataset = FERPlusDataset(val_csv, phase='valid', transform=preprocess)
test_dataset = FERPlusDataset(test_csv, phase='test', transform=preprocess)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=workers)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=workers)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=workers)


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DotProductAttention(nn.Module):
    def __init__(self, input_dim):
        super(DotProductAttention, self).__init__()
        self.input_dim = input_dim
        self.fc_q = nn.Linear(input_dim, input_dim)
        self.fc_k = nn.Linear(input_dim, input_dim)
        self.fc_v = nn.Linear(input_dim, input_dim)
    
    def forward(self, x):
        # Query, Key, Value 계산
        Q = self.fc_q(x)
        K = self.fc_k(x)
        V = self.fc_v(x)
        
        # Dot Product Attention
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.input_dim ** 0.5)  # Scaled Dot-Product Attention
        attention_weights = F.softmax(attention_scores, dim=-1)
        
        # Attention을 곱해 최종 출력 계산
        output = torch.matmul(attention_weights, V)
        return output

class FERClassifier(nn.Module):
    def __init__(self, num_classes):
        super(FERClassifier, self).__init__()
        self.image_encoder = model.visual
        for param in self.image_encoder.parameters():
            param.requires_grad = False  # 이미지 인코더 가중치 고정
        
        # FC 1 (차원 축소)
        self.fc1 = nn.Linear(512, 512)
        self.bn1 = nn.BatchNorm1d(512)  # 배치 정규화 추가
        self.dropout1 = nn.Dropout(0.3)  # 드롭아웃 추가 (확률 30%)

        # Self-Attention
        self.attention = DotProductAttention(input_dim=512)  # Attention 적용

        # FC 2 (최종 분류)
        self.fc2 = nn.Linear(512, num_classes)
        self.bn2 = nn.BatchNorm1d(num_classes)  # 최종 출력에 배치 정규화 추가
        self.dropout2 = nn.Dropout(0.3)  # 드롭아웃 추가

    def forward(self, x):
        # 이미지 인코딩
        x = self.image_encoder(x)
        
        # FC 1
        x = self.fc1(x)
        x = self.bn1(x)  # 배치 정규화 적용
        x = F.relu(x)
        x = self.dropout1(x)  # 드롭아웃 적용
        
        # Self-Attention
        x = self.attention(x)
        
        # FC 2 (최종 분류)
        x = self.fc2(x)
        x = self.bn2(x)  # 배치 정규화 적용
        x = self.dropout2(x)  # 드롭아웃 적용
        return x

# 모델 초기화
num_classes = len(emotion_labels)  # FERPlus 데이터셋 클래스 수
classifier = FERClassifier(num_classes).to(device)


In [4]:
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

# 손실 함수와 옵티마이저 정의
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(classifier.parameters(), lr=0.001)
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)

# 모델 전체를 float32로 변환
classifier = classifier.float()

# 학습 함수에서 데이터 타입 변환 제거
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, epochs):
    for epoch in range(epochs):
        model.train()
        train_loss, correct, total = 0.0, 0, 0

        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Metrics
            train_loss += loss.item()
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)

        train_acc = correct / total * 100
        print(f"Epoch {epoch+1}, Loss: {train_loss/total:.4f}, Accuracy: {train_acc:.2f}%")

        # Validation
        val_acc = evaluate_model(model, val_loader)
        print(f"Validation Accuracy: {val_acc:.2f}%")

        # Scheduler step
        scheduler.step()

# 모델 평가 함수 정의
def evaluate_model(model, loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)
    return correct / total * 100

# 학습 시작
train_model(classifier, train_loader, val_loader, criterion, optimizer, scheduler, epochs=epochs)


Epoch 1/60:   0%|          | 0/196 [00:00<?, ?it/s]

Epoch 1/60: 100%|██████████| 196/196 [00:21<00:00,  9.10it/s]

Epoch 1, Loss: 0.0096, Accuracy: 68.58%





Validation Accuracy: 81.26%


Epoch 2/60: 100%|██████████| 196/196 [00:21<00:00,  9.07it/s]

Epoch 2, Loss: 0.0081, Accuracy: 72.43%





Validation Accuracy: 81.95%


Epoch 3/60: 100%|██████████| 196/196 [00:21<00:00,  9.28it/s]

Epoch 3, Loss: 0.0076, Accuracy: 72.78%





Validation Accuracy: 83.08%


Epoch 4/60: 100%|██████████| 196/196 [00:21<00:00,  9.32it/s]

Epoch 4, Loss: 0.0072, Accuracy: 73.89%





Validation Accuracy: 83.80%


Epoch 5/60: 100%|██████████| 196/196 [00:21<00:00,  9.30it/s]

Epoch 5, Loss: 0.0069, Accuracy: 74.81%





Validation Accuracy: 84.11%


Epoch 6/60: 100%|██████████| 196/196 [00:21<00:00,  9.30it/s]

Epoch 6, Loss: 0.0067, Accuracy: 75.68%





Validation Accuracy: 83.70%


Epoch 7/60: 100%|██████████| 196/196 [00:21<00:00,  9.31it/s]

Epoch 7, Loss: 0.0066, Accuracy: 76.03%





Validation Accuracy: 84.96%


Epoch 8/60: 100%|██████████| 196/196 [00:21<00:00,  9.31it/s]

Epoch 8, Loss: 0.0065, Accuracy: 76.28%





Validation Accuracy: 83.42%


Epoch 9/60: 100%|██████████| 196/196 [00:21<00:00,  9.33it/s]

Epoch 9, Loss: 0.0065, Accuracy: 76.30%





Validation Accuracy: 84.71%


Epoch 10/60: 100%|██████████| 196/196 [00:21<00:00,  9.32it/s]

Epoch 10, Loss: 0.0062, Accuracy: 77.74%





Validation Accuracy: 84.71%


Epoch 11/60: 100%|██████████| 196/196 [00:21<00:00,  9.29it/s]

Epoch 11, Loss: 0.0059, Accuracy: 78.34%





Validation Accuracy: 85.68%


Epoch 12/60: 100%|██████████| 196/196 [00:21<00:00,  9.28it/s]

Epoch 12, Loss: 0.0057, Accuracy: 79.49%





Validation Accuracy: 86.09%


Epoch 13/60: 100%|██████████| 196/196 [00:21<00:00,  9.30it/s]

Epoch 13, Loss: 0.0057, Accuracy: 79.57%





Validation Accuracy: 85.99%


Epoch 14/60: 100%|██████████| 196/196 [00:21<00:00,  9.31it/s]

Epoch 14, Loss: 0.0056, Accuracy: 79.74%





Validation Accuracy: 86.05%


Epoch 15/60: 100%|██████████| 196/196 [00:21<00:00,  9.32it/s]

Epoch 15, Loss: 0.0056, Accuracy: 79.99%





Validation Accuracy: 86.43%


Epoch 16/60: 100%|██████████| 196/196 [00:21<00:00,  9.31it/s]

Epoch 16, Loss: 0.0056, Accuracy: 80.03%





Validation Accuracy: 86.12%


Epoch 17/60: 100%|██████████| 196/196 [00:21<00:00,  9.31it/s]

Epoch 17, Loss: 0.0056, Accuracy: 80.22%





Validation Accuracy: 86.37%


Epoch 18/60: 100%|██████████| 196/196 [00:20<00:00,  9.35it/s]

Epoch 18, Loss: 0.0055, Accuracy: 80.00%





Validation Accuracy: 86.34%


Epoch 19/60: 100%|██████████| 196/196 [00:21<00:00,  9.32it/s]

Epoch 19, Loss: 0.0054, Accuracy: 80.54%





Validation Accuracy: 86.09%


Epoch 20/60: 100%|██████████| 196/196 [00:21<00:00,  9.30it/s]

Epoch 20, Loss: 0.0054, Accuracy: 80.80%





Validation Accuracy: 86.24%


Epoch 21/60: 100%|██████████| 196/196 [00:20<00:00,  9.34it/s]

Epoch 21, Loss: 0.0054, Accuracy: 80.99%





Validation Accuracy: 85.99%


Epoch 22/60: 100%|██████████| 196/196 [00:21<00:00,  9.32it/s]

Epoch 22, Loss: 0.0054, Accuracy: 80.81%





Validation Accuracy: 85.99%


Epoch 23/60: 100%|██████████| 196/196 [00:21<00:00,  9.29it/s]

Epoch 23, Loss: 0.0053, Accuracy: 81.00%





Validation Accuracy: 86.02%


Epoch 24/60: 100%|██████████| 196/196 [00:21<00:00,  9.31it/s]

Epoch 24, Loss: 0.0054, Accuracy: 81.09%





Validation Accuracy: 85.93%


Epoch 25/60: 100%|██████████| 196/196 [00:21<00:00,  9.33it/s]

Epoch 25, Loss: 0.0053, Accuracy: 81.17%





Validation Accuracy: 85.96%


Epoch 26/60: 100%|██████████| 196/196 [00:21<00:00,  9.29it/s]

Epoch 26, Loss: 0.0054, Accuracy: 80.91%





Validation Accuracy: 86.27%


Epoch 27/60: 100%|██████████| 196/196 [00:21<00:00,  9.31it/s]

Epoch 27, Loss: 0.0053, Accuracy: 81.17%





Validation Accuracy: 85.96%


Epoch 28/60: 100%|██████████| 196/196 [00:21<00:00,  9.29it/s]

Epoch 28, Loss: 0.0053, Accuracy: 81.38%





Validation Accuracy: 86.27%


Epoch 29/60: 100%|██████████| 196/196 [00:21<00:00,  9.31it/s]

Epoch 29, Loss: 0.0053, Accuracy: 81.06%





Validation Accuracy: 86.05%


Epoch 30/60: 100%|██████████| 196/196 [00:21<00:00,  9.30it/s]

Epoch 30, Loss: 0.0052, Accuracy: 81.72%





Validation Accuracy: 86.18%


Epoch 31/60: 100%|██████████| 196/196 [00:21<00:00,  9.29it/s]

Epoch 31, Loss: 0.0053, Accuracy: 81.59%





Validation Accuracy: 85.90%


Epoch 32/60: 100%|██████████| 196/196 [00:21<00:00,  9.31it/s]

Epoch 32, Loss: 0.0053, Accuracy: 80.99%





Validation Accuracy: 86.02%


Epoch 33/60: 100%|██████████| 196/196 [00:21<00:00,  9.28it/s]

Epoch 33, Loss: 0.0053, Accuracy: 81.36%





Validation Accuracy: 85.99%


Epoch 34/60: 100%|██████████| 196/196 [00:21<00:00,  9.32it/s]

Epoch 34, Loss: 0.0053, Accuracy: 81.47%





Validation Accuracy: 85.96%


Epoch 35/60: 100%|██████████| 196/196 [00:21<00:00,  9.32it/s]

Epoch 35, Loss: 0.0053, Accuracy: 81.18%





Validation Accuracy: 85.99%


Epoch 36/60: 100%|██████████| 196/196 [00:21<00:00,  9.30it/s]

Epoch 36, Loss: 0.0053, Accuracy: 81.00%





Validation Accuracy: 86.09%


Epoch 37/60: 100%|██████████| 196/196 [00:21<00:00,  9.32it/s]

Epoch 37, Loss: 0.0053, Accuracy: 81.17%





Validation Accuracy: 85.80%


Epoch 38/60: 100%|██████████| 196/196 [00:20<00:00,  9.33it/s]

Epoch 38, Loss: 0.0053, Accuracy: 81.49%





Validation Accuracy: 86.37%


Epoch 39/60: 100%|██████████| 196/196 [00:20<00:00,  9.35it/s]

Epoch 39, Loss: 0.0053, Accuracy: 81.12%





Validation Accuracy: 86.31%


Epoch 40/60: 100%|██████████| 196/196 [00:20<00:00,  9.34it/s]

Epoch 40, Loss: 0.0053, Accuracy: 81.59%





Validation Accuracy: 86.12%


Epoch 41/60: 100%|██████████| 196/196 [00:20<00:00,  9.33it/s]

Epoch 41, Loss: 0.0053, Accuracy: 81.51%





Validation Accuracy: 86.05%


Epoch 42/60: 100%|██████████| 196/196 [00:21<00:00,  9.27it/s]

Epoch 42, Loss: 0.0053, Accuracy: 81.22%





Validation Accuracy: 85.99%


Epoch 43/60: 100%|██████████| 196/196 [00:21<00:00,  9.31it/s]

Epoch 43, Loss: 0.0053, Accuracy: 80.95%





Validation Accuracy: 86.02%


Epoch 44/60: 100%|██████████| 196/196 [00:21<00:00,  9.30it/s]

Epoch 44, Loss: 0.0053, Accuracy: 81.55%





Validation Accuracy: 86.09%


Epoch 45/60: 100%|██████████| 196/196 [00:21<00:00,  9.29it/s]

Epoch 45, Loss: 0.0052, Accuracy: 81.80%





Validation Accuracy: 85.93%


Epoch 46/60: 100%|██████████| 196/196 [00:21<00:00,  9.30it/s]

Epoch 46, Loss: 0.0053, Accuracy: 81.88%





Validation Accuracy: 85.96%


Epoch 47/60: 100%|██████████| 196/196 [00:21<00:00,  9.31it/s]

Epoch 47, Loss: 0.0053, Accuracy: 81.35%





Validation Accuracy: 85.99%


Epoch 48/60: 100%|██████████| 196/196 [00:21<00:00,  9.30it/s]

Epoch 48, Loss: 0.0053, Accuracy: 81.46%





Validation Accuracy: 85.99%


Epoch 49/60: 100%|██████████| 196/196 [00:21<00:00,  9.29it/s]

Epoch 49, Loss: 0.0053, Accuracy: 81.46%





Validation Accuracy: 86.34%


Epoch 50/60: 100%|██████████| 196/196 [00:21<00:00,  9.32it/s]

Epoch 50, Loss: 0.0053, Accuracy: 81.46%





Validation Accuracy: 86.12%


Epoch 51/60: 100%|██████████| 196/196 [00:21<00:00,  9.33it/s]

Epoch 51, Loss: 0.0052, Accuracy: 81.53%





Validation Accuracy: 86.05%


Epoch 52/60: 100%|██████████| 196/196 [00:21<00:00,  9.32it/s]

Epoch 52, Loss: 0.0052, Accuracy: 81.48%





Validation Accuracy: 86.09%


Epoch 53/60: 100%|██████████| 196/196 [00:21<00:00,  9.33it/s]

Epoch 53, Loss: 0.0054, Accuracy: 80.67%





Validation Accuracy: 86.18%


Epoch 54/60: 100%|██████████| 196/196 [00:21<00:00,  9.31it/s]

Epoch 54, Loss: 0.0053, Accuracy: 81.06%





Validation Accuracy: 86.15%


Epoch 55/60: 100%|██████████| 196/196 [00:21<00:00,  8.99it/s]

Epoch 55, Loss: 0.0053, Accuracy: 81.10%





Validation Accuracy: 86.02%


Epoch 56/60: 100%|██████████| 196/196 [00:21<00:00,  9.14it/s]

Epoch 56, Loss: 0.0053, Accuracy: 81.29%





Validation Accuracy: 85.99%


Epoch 57/60: 100%|██████████| 196/196 [00:22<00:00,  8.65it/s]

Epoch 57, Loss: 0.0053, Accuracy: 81.63%





Validation Accuracy: 85.96%


Epoch 58/60: 100%|██████████| 196/196 [00:22<00:00,  8.72it/s]

Epoch 58, Loss: 0.0053, Accuracy: 81.07%





Validation Accuracy: 86.15%


Epoch 59/60: 100%|██████████| 196/196 [00:22<00:00,  8.60it/s]

Epoch 59, Loss: 0.0053, Accuracy: 81.46%





Validation Accuracy: 85.96%


Epoch 60/60: 100%|██████████| 196/196 [00:22<00:00,  8.60it/s]

Epoch 60, Loss: 0.0052, Accuracy: 81.53%





Validation Accuracy: 86.21%


In [5]:
# 테스트 데이터 평가
def test_model(model, test_loader):
    model.eval()  # 평가 모드
    correct, total = 0, 0
    all_labels = []
    all_predictions = []
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = outputs.max(1)
            
            # 총 정확도 계산
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)

            # 자세한 결과 저장 (선택적)
            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predicted.cpu().numpy())

    # 정확도 계산
    accuracy = correct / total * 100
    print(f"Test Accuracy: {accuracy:.2f}%")

    return accuracy, all_labels, all_predictions


# 테스트 데이터 평가 시작
test_accuracy, test_labels, test_predictions = test_model(classifier, test_loader)

# 테스트 데이터의 구체적인 결과 출력 (선택적)
from sklearn.metrics import classification_report
print(classification_report(test_labels, test_predictions, target_names=emotion_labels))


Test Accuracy: 85.94%
              precision    recall  f1-score   support

     neutral       0.86      0.91      0.89      1083
   happiness       0.93      0.93      0.93       892
    surprise       0.84      0.86      0.85       394
     sadness       0.76      0.74      0.75       382
       anger       0.82      0.83      0.82       269
     disgust       0.00      0.00      0.00        16
        fear       0.60      0.40      0.48        86
    contempt       0.00      0.00      0.00        14

    accuracy                           0.86      3136
   macro avg       0.60      0.58      0.59      3136
weighted avg       0.85      0.86      0.85      3136



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


# Clip - SelfAttention Large DO,BN

In [1]:
import torch
import clip
from tqdm import tqdm
from sklearn.metrics import accuracy_score
import os
import warnings
from tqdm import tqdm

from PIL import Image
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.utils.data as data
from torchvision import transforms

from sklearn.metrics import balanced_accuracy_score

# CLIP 모델 로드
device = "cuda" if torch.cuda.is_available() else "cpu"
#model, preprocess = clip.load("ViT-B/32", device=device)
model, preprocess = clip.load("ViT-L/14", device=device)

# 텍스트 레이블 정의
emotion_labels = ["neutral", "happiness", "surprise", "sadness", "anger", "disgust", "fear", "contempt"]
text_inputs = clip.tokenize([f"{label} expression" for label in emotion_labels]).to(device)


In [2]:
class FERPlusDataset(data.Dataset):
    def __init__(self, data_csv, phase, transform=None):
        self.phase = phase
        self.transform = transform

        # Read the dataset CSV file
        self.data = pd.read_csv(data_csv)
        self.data.iloc[:, 2:12] = self.data.iloc[:, 2:12].replace(1, 0)

        # Get file paths and labels
        self.file_paths = self.data.iloc[:, 0].values
        self.counts = self.data.iloc[:, 2:12].values

        # Apply constraints to filter valid samples
        self._apply_constraints()

        # Use argmax to determine the emotion class
        self.labels = np.argmax(self.counts, axis=1)

    def _apply_constraints(self):
        max_counts = self.counts.max(axis=1)
        counts_eq_max = (self.counts == max_counts[:, None])
        constraint1_violation = counts_eq_max[:, [8, 9]].any(axis=1)
        num_max_labels = counts_eq_max.sum(axis=1)
        constraint2_violation = num_max_labels > 3
        total_votes = self.counts.sum(axis=1)
        constraint3_violation = max_counts <= (total_votes / 2)
        valid_samples = ~(
            constraint1_violation | constraint2_violation | constraint3_violation
        )
        self.file_paths = self.file_paths[valid_samples]
        self.counts = self.counts[valid_samples]

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        if self.phase == 'train':
            path = '/data/FER2013/FER2013Train/' + self.file_paths[idx]
        elif self.phase == 'valid':
            path = '/data/FER2013/FER2013Valid/' + self.file_paths[idx]
        elif self.phase == 'test':
            path = '/data/FER2013/FER2013Test/' + self.file_paths[idx]
        image = Image.open(path).convert('RGB')
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

# 데이터셋 준비
train_csv = '/data/FER2013/train_label.csv'
val_csv = '/data/FER2013/valid_label.csv'
test_csv = '/data/FER2013/test_label.csv'
batch_size = 128
lr = 0.01
workers = 4
epochs = 60

train_dataset = FERPlusDataset(train_csv, phase='train', transform=preprocess)
val_dataset = FERPlusDataset(val_csv, phase='valid', transform=preprocess)
test_dataset = FERPlusDataset(test_csv, phase='test', transform=preprocess)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=workers)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=workers)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=workers)


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DotProductAttention(nn.Module):
    def __init__(self, input_dim):
        super(DotProductAttention, self).__init__()
        self.input_dim = input_dim
        self.fc_q = nn.Linear(input_dim, input_dim)
        self.fc_k = nn.Linear(input_dim, input_dim)
        self.fc_v = nn.Linear(input_dim, input_dim)
    
    def forward(self, x):
        # Query, Key, Value 계산
        Q = self.fc_q(x)
        K = self.fc_k(x)
        V = self.fc_v(x)
        
        # Dot Product Attention
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.input_dim ** 0.5)  # Scaled Dot-Product Attention
        attention_weights = F.softmax(attention_scores, dim=-1)
        
        # Attention을 곱해 최종 출력 계산
        output = torch.matmul(attention_weights, V)
        return output

class FERClassifier(nn.Module):
    def __init__(self, num_classes):
        super(FERClassifier, self).__init__()
        self.image_encoder = model.visual
        for param in self.image_encoder.parameters():
            param.requires_grad = False  # 이미지 인코더 가중치 고정
        
        # FC 1 (차원 축소)
        self.fc1 = nn.Linear(768, 512)
        self.bn1 = nn.BatchNorm1d(512)  # 배치 정규화 추가
        self.dropout1 = nn.Dropout(0.3)  # 드롭아웃 추가 (확률 30%)

        # Self-Attention
        self.attention = DotProductAttention(input_dim=512)  # Attention 적용

        # FC 2 (최종 분류)
        self.fc2 = nn.Linear(512, num_classes)
        self.bn2 = nn.BatchNorm1d(num_classes)  # 최종 출력에 배치 정규화 추가
        self.dropout2 = nn.Dropout(0.3)  # 드롭아웃 추가

    def forward(self, x):
        # 이미지 인코딩
        x = self.image_encoder(x)
        
        # FC 1
        x = self.fc1(x)
        x = self.bn1(x)  # 배치 정규화 적용
        x = F.relu(x)
        x = self.dropout1(x)  # 드롭아웃 적용
        
        # Self-Attention
        x = self.attention(x)
        
        # FC 2 (최종 분류)
        x = self.fc2(x)
        x = self.bn2(x)  # 배치 정규화 적용
        x = self.dropout2(x)  # 드롭아웃 적용
        return x

# 모델 초기화
num_classes = len(emotion_labels)  # FERPlus 데이터셋의 클래스 수
classifier = FERClassifier(num_classes).to(device)

In [4]:
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

# 손실 함수와 옵티마이저 정의
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(classifier.parameters(), lr=0.001)
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)

# 모델 전체를 float32로 변환
classifier = classifier.float()

# 학습 함수에서 데이터 타입 변환 제거
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, epochs):
    for epoch in range(epochs):
        model.train()
        train_loss, correct, total = 0.0, 0, 0

        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Metrics
            train_loss += loss.item()
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)

        train_acc = correct / total * 100
        print(f"Epoch {epoch+1}, Loss: {train_loss/total:.4f}, Accuracy: {train_acc:.2f}%")

        # Validation
        val_acc = evaluate_model(model, val_loader)
        print(f"Validation Accuracy: {val_acc:.2f}%")

        # Scheduler step
        scheduler.step()

# 모델 평가 함수 정의
def evaluate_model(model, loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)
    return correct / total * 100

# 학습 시작
train_model(classifier, train_loader, val_loader, criterion, optimizer, scheduler, epochs=epochs)


Epoch 1/60:   0%|          | 0/196 [00:00<?, ?it/s]

Epoch 1/60: 100%|██████████| 196/196 [04:25<00:00,  1.36s/it]

Epoch 1, Loss: 0.0090, Accuracy: 70.50%





Validation Accuracy: 87.31%


Epoch 2/60: 100%|██████████| 196/196 [04:24<00:00,  1.35s/it]

Epoch 2, Loss: 0.0075, Accuracy: 75.76%





Validation Accuracy: 86.31%


Epoch 3/60: 100%|██████████| 196/196 [04:24<00:00,  1.35s/it]

Epoch 3, Loss: 0.0070, Accuracy: 76.08%





Validation Accuracy: 87.84%


Epoch 4/60: 100%|██████████| 196/196 [04:24<00:00,  1.35s/it]

Epoch 4, Loss: 0.0066, Accuracy: 77.24%





Validation Accuracy: 88.53%


Epoch 5/60: 100%|██████████| 196/196 [04:29<00:00,  1.38s/it]

Epoch 5, Loss: 0.0063, Accuracy: 78.31%





Validation Accuracy: 87.68%


Epoch 6/60: 100%|██████████| 196/196 [04:41<00:00,  1.44s/it]

Epoch 6, Loss: 0.0061, Accuracy: 79.10%





Validation Accuracy: 88.15%


Epoch 7/60: 100%|██████████| 196/196 [04:40<00:00,  1.43s/it]

Epoch 7, Loss: 0.0059, Accuracy: 79.59%





Validation Accuracy: 88.28%


Epoch 8/60: 100%|██████████| 196/196 [04:41<00:00,  1.44s/it]

Epoch 8, Loss: 0.0058, Accuracy: 79.68%





Validation Accuracy: 88.25%


Epoch 9/60: 100%|██████████| 196/196 [04:41<00:00,  1.43s/it]

Epoch 9, Loss: 0.0057, Accuracy: 80.44%





Validation Accuracy: 88.44%


Epoch 10/60: 100%|██████████| 196/196 [04:41<00:00,  1.43s/it]

Epoch 10, Loss: 0.0056, Accuracy: 80.92%





Validation Accuracy: 87.84%


Epoch 11/60: 100%|██████████| 196/196 [04:35<00:00,  1.41s/it]

Epoch 11, Loss: 0.0052, Accuracy: 82.11%





Validation Accuracy: 89.06%


Epoch 12/60: 100%|██████████| 196/196 [04:36<00:00,  1.41s/it]

Epoch 12, Loss: 0.0052, Accuracy: 82.40%





Validation Accuracy: 89.22%


Epoch 13/60: 100%|██████████| 196/196 [04:32<00:00,  1.39s/it]

Epoch 13, Loss: 0.0050, Accuracy: 83.25%





Validation Accuracy: 89.44%


Epoch 14/60: 100%|██████████| 196/196 [04:24<00:00,  1.35s/it]

Epoch 14, Loss: 0.0049, Accuracy: 83.42%





Validation Accuracy: 89.06%


Epoch 15/60: 100%|██████████| 196/196 [04:41<00:00,  1.44s/it]

Epoch 15, Loss: 0.0050, Accuracy: 82.99%





Validation Accuracy: 89.28%


Epoch 16/60: 100%|██████████| 196/196 [04:32<00:00,  1.39s/it]

Epoch 16, Loss: 0.0049, Accuracy: 83.57%





Validation Accuracy: 89.22%


Epoch 17/60: 100%|██████████| 196/196 [04:24<00:00,  1.35s/it]

Epoch 17, Loss: 0.0049, Accuracy: 83.69%





Validation Accuracy: 89.31%


Epoch 18/60: 100%|██████████| 196/196 [04:24<00:00,  1.35s/it]

Epoch 18, Loss: 0.0048, Accuracy: 83.82%





Validation Accuracy: 89.47%


Epoch 19/60: 100%|██████████| 196/196 [04:24<00:00,  1.35s/it]

Epoch 19, Loss: 0.0048, Accuracy: 83.98%





Validation Accuracy: 89.75%


Epoch 20/60: 100%|██████████| 196/196 [04:24<00:00,  1.35s/it]

Epoch 20, Loss: 0.0048, Accuracy: 83.80%





Validation Accuracy: 89.38%


Epoch 21/60: 100%|██████████| 196/196 [04:24<00:00,  1.35s/it]

Epoch 21, Loss: 0.0048, Accuracy: 83.63%





Validation Accuracy: 89.44%


Epoch 22/60: 100%|██████████| 196/196 [04:24<00:00,  1.35s/it]

Epoch 22, Loss: 0.0047, Accuracy: 84.52%





Validation Accuracy: 89.60%


Epoch 23/60: 100%|██████████| 196/196 [04:24<00:00,  1.35s/it]

Epoch 23, Loss: 0.0047, Accuracy: 84.31%





Validation Accuracy: 89.35%


Epoch 24/60: 100%|██████████| 196/196 [04:24<00:00,  1.35s/it]

Epoch 24, Loss: 0.0047, Accuracy: 84.26%





Validation Accuracy: 89.53%


Epoch 25/60: 100%|██████████| 196/196 [04:24<00:00,  1.35s/it]

Epoch 25, Loss: 0.0048, Accuracy: 83.98%





Validation Accuracy: 89.19%


Epoch 26/60: 100%|██████████| 196/196 [04:52<00:00,  1.49s/it]

Epoch 26, Loss: 0.0047, Accuracy: 84.23%





Validation Accuracy: 89.69%


Epoch 27/60: 100%|██████████| 196/196 [05:18<00:00,  1.62s/it]

Epoch 27, Loss: 0.0047, Accuracy: 84.21%





Validation Accuracy: 89.44%


Epoch 28/60: 100%|██████████| 196/196 [05:18<00:00,  1.62s/it]

Epoch 28, Loss: 0.0047, Accuracy: 83.98%





Validation Accuracy: 89.50%


Epoch 29/60: 100%|██████████| 196/196 [05:18<00:00,  1.62s/it]

Epoch 29, Loss: 0.0047, Accuracy: 84.24%





Validation Accuracy: 89.63%


Epoch 30/60: 100%|██████████| 196/196 [05:18<00:00,  1.62s/it]

Epoch 30, Loss: 0.0047, Accuracy: 83.98%





Validation Accuracy: 89.69%


Epoch 31/60: 100%|██████████| 196/196 [05:18<00:00,  1.62s/it]

Epoch 31, Loss: 0.0047, Accuracy: 84.34%





Validation Accuracy: 89.53%


Epoch 32/60: 100%|██████████| 196/196 [04:53<00:00,  1.50s/it]

Epoch 32, Loss: 0.0047, Accuracy: 84.30%





Validation Accuracy: 89.75%


Epoch 33/60: 100%|██████████| 196/196 [04:24<00:00,  1.35s/it]

Epoch 33, Loss: 0.0046, Accuracy: 84.26%





Validation Accuracy: 89.47%


Epoch 34/60: 100%|██████████| 196/196 [04:24<00:00,  1.35s/it]

Epoch 34, Loss: 0.0046, Accuracy: 84.51%





Validation Accuracy: 89.66%


Epoch 35/60: 100%|██████████| 196/196 [04:25<00:00,  1.35s/it]

Epoch 35, Loss: 0.0047, Accuracy: 84.42%





Validation Accuracy: 89.60%


Epoch 36/60: 100%|██████████| 196/196 [04:25<00:00,  1.35s/it]

Epoch 36, Loss: 0.0046, Accuracy: 84.18%





Validation Accuracy: 89.60%


Epoch 37/60: 100%|██████████| 196/196 [04:25<00:00,  1.35s/it]

Epoch 37, Loss: 0.0047, Accuracy: 84.29%





Validation Accuracy: 89.53%


Epoch 38/60: 100%|██████████| 196/196 [04:25<00:00,  1.35s/it]

Epoch 38, Loss: 0.0047, Accuracy: 84.32%





Validation Accuracy: 89.53%


Epoch 39/60: 100%|██████████| 196/196 [04:25<00:00,  1.35s/it]

Epoch 39, Loss: 0.0046, Accuracy: 84.40%





Validation Accuracy: 89.82%


Epoch 40/60: 100%|██████████| 196/196 [04:25<00:00,  1.35s/it]

Epoch 40, Loss: 0.0046, Accuracy: 84.66%





Validation Accuracy: 89.69%


Epoch 41/60: 100%|██████████| 196/196 [04:25<00:00,  1.35s/it]

Epoch 41, Loss: 0.0047, Accuracy: 84.27%





Validation Accuracy: 89.75%


Epoch 42/60: 100%|██████████| 196/196 [04:25<00:00,  1.35s/it]

Epoch 42, Loss: 0.0047, Accuracy: 84.36%





Validation Accuracy: 89.50%


Epoch 43/60: 100%|██████████| 196/196 [04:25<00:00,  1.35s/it]

Epoch 43, Loss: 0.0047, Accuracy: 84.21%





Validation Accuracy: 89.60%


Epoch 44/60: 100%|██████████| 196/196 [04:25<00:00,  1.35s/it]

Epoch 44, Loss: 0.0047, Accuracy: 84.15%





Validation Accuracy: 89.35%


Epoch 45/60: 100%|██████████| 196/196 [04:25<00:00,  1.35s/it]

Epoch 45, Loss: 0.0047, Accuracy: 84.47%





Validation Accuracy: 89.53%


Epoch 46/60: 100%|██████████| 196/196 [04:25<00:00,  1.35s/it]

Epoch 46, Loss: 0.0047, Accuracy: 83.99%





Validation Accuracy: 89.72%


Epoch 47/60: 100%|██████████| 196/196 [04:25<00:00,  1.35s/it]

Epoch 47, Loss: 0.0046, Accuracy: 84.62%





Validation Accuracy: 89.56%


Epoch 48/60: 100%|██████████| 196/196 [04:55<00:00,  1.51s/it]

Epoch 48, Loss: 0.0046, Accuracy: 84.48%





Validation Accuracy: 89.50%


Epoch 49/60: 100%|██████████| 196/196 [04:25<00:00,  1.35s/it]

Epoch 49, Loss: 0.0047, Accuracy: 84.28%





Validation Accuracy: 89.72%


Epoch 50/60:  18%|█▊        | 36/196 [00:50<03:44,  1.41s/it]


KeyboardInterrupt: 

In [5]:
# 테스트 데이터 평가
def test_model(model, test_loader):
    model.eval()  # 평가 모드
    correct, total = 0, 0
    all_labels = []
    all_predictions = []
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = outputs.max(1)
            
            # 총 정확도 계산
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)

            # 자세한 결과 저장 (선택적)
            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predicted.cpu().numpy())

    # 정확도 계산
    accuracy = correct / total * 100
    print(f"Test Accuracy: {accuracy:.2f}%")

    return accuracy, all_labels, all_predictions


# 테스트 데이터 평가 시작
test_accuracy, test_labels, test_predictions = test_model(classifier, test_loader)

# 테스트 데이터의 구체적인 결과 출력 (선택적)
from sklearn.metrics import classification_report
print(classification_report(test_labels, test_predictions, target_names=emotion_labels))


Test Accuracy: 89.06%
              precision    recall  f1-score   support

     neutral       0.88      0.91      0.90      1083
   happiness       0.95      0.96      0.95       892
    surprise       0.89      0.89      0.89       394
     sadness       0.79      0.81      0.80       382
       anger       0.89      0.91      0.90       269
     disgust       0.50      0.12      0.20        16
        fear       0.83      0.51      0.63        86
    contempt       0.00      0.00      0.00        14

    accuracy                           0.89      3136
   macro avg       0.72      0.64      0.66      3136
weighted avg       0.88      0.89      0.89      3136



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


# CLIP Patt

In [1]:
import torch
import clip
from tqdm import tqdm
from sklearn.metrics import accuracy_score
import os
import warnings
from tqdm import tqdm

from PIL import Image
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.utils.data as data
from torchvision import transforms

from sklearn.metrics import balanced_accuracy_score

# CLIP 모델 로드
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
#model, preprocess = clip.load("ViT-L/14", device=device)

# 텍스트 레이블 정의
emotion_labels = ["neutral", "happiness", "surprise", "sadness", "anger", "disgust", "fear", "contempt"]
text_inputs = clip.tokenize([f"a photo of a person showing {label}" for label in emotion_labels]).to(device)


In [2]:
class FERPlusDataset(data.Dataset):
    def __init__(self, data_csv, phase, transform=None):
        self.phase = phase
        self.transform = transform

        # Read the dataset CSV file
        self.data = pd.read_csv(data_csv)
        self.data.iloc[:, 2:12] = self.data.iloc[:, 2:12].replace(1, 0)

        # Get file paths and labels
        self.file_paths = self.data.iloc[:, 0].values
        self.counts = self.data.iloc[:, 2:12].values

        # Apply constraints to filter valid samples
        self._apply_constraints()

        # Use argmax to determine the emotion class
        self.labels = np.argmax(self.counts, axis=1)

    def _apply_constraints(self):
        max_counts = self.counts.max(axis=1)
        counts_eq_max = (self.counts == max_counts[:, None])
        constraint1_violation = counts_eq_max[:, [8, 9]].any(axis=1)
        num_max_labels = counts_eq_max.sum(axis=1)
        constraint2_violation = num_max_labels > 3
        total_votes = self.counts.sum(axis=1)
        constraint3_violation = max_counts <= (total_votes / 2)
        valid_samples = ~(
            constraint1_violation | constraint2_violation | constraint3_violation
        )
        self.file_paths = self.file_paths[valid_samples]
        self.counts = self.counts[valid_samples]

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        if self.phase == 'train':
            path = '/data/FER2013/FER2013Train/' + self.file_paths[idx]
        elif self.phase == 'valid':
            path = '/data/FER2013/FER2013Valid/' + self.file_paths[idx]
        elif self.phase == 'test':
            path = '/data/FER2013/FER2013Test/' + self.file_paths[idx]
        image = Image.open(path).convert('RGB')
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

# 데이터셋 준비
train_csv = '/data/FER2013/train_label.csv'
val_csv = '/data/FER2013/valid_label.csv'
test_csv = '/data/FER2013/test_label.csv'
batch_size = 128
lr = 0.01
workers = 4
epochs = 60

train_dataset = FERPlusDataset(train_csv, phase='train', transform=preprocess)
val_dataset = FERPlusDataset(val_csv, phase='valid', transform=preprocess)
test_dataset = FERPlusDataset(test_csv, phase='test', transform=preprocess)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=workers)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=workers)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=workers)


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Dot-Product Self-Attention 모듈
class DotProductAttention(nn.Module):
    def __init__(self, input_dim):
        super(DotProductAttention, self).__init__()
        self.fc_q = nn.Linear(input_dim, input_dim)
        self.fc_k = nn.Linear(input_dim, input_dim)
        self.fc_v = nn.Linear(input_dim, input_dim)

    def forward(self, x):
        Q = self.fc_q(x)
        K = self.fc_k(x)
        V = self.fc_v(x)

        # Scaled Dot-Product Attention
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (Q.size(-1) ** 0.5)
        attention_weights = F.softmax(attention_scores, dim=-1)
        output = torch.matmul(attention_weights, V)
        return output

# FERClassifier 모델
class FERClassifier(nn.Module):
    def __init__(self, model, num_classes):
        super(FERClassifier, self).__init__()
        self.image_encoder = model.visual  # CLIP 이미지 인코더
        for param in self.image_encoder.parameters():
            param.requires_grad = False  # 가중치 고정

        # Patch Extraction
        self.separable_conv1 = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=1, stride=1, groups=1),  # groups=1로 수정
            nn.ReLU()
        )
        self.separable_conv2 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=1, stride=1, groups=1),  # groups=1로 수정
            nn.ReLU()
        )
        self.pointwise_conv = nn.Conv2d(256, 256, kernel_size=1)  # Pointwise Conv

        # GAP (Global Average Pooling)
        self.gap = nn.AdaptiveAvgPool2d(1)

        # Attention Classifier
        self.fc1 = nn.Linear(256, 256)
        self.attention = DotProductAttention(input_dim=256)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        # 이미지 인코딩
        x = self.image_encoder(x)  # (B, 512)
        x = x.unsqueeze(-1).unsqueeze(-1)  # (B, 512, 1, 1)

        # Patch Extraction
        x = self.separable_conv1(x)  # (B, 256, 1, 1)
        x = self.separable_conv2(x)  # (B, 256, 1, 1)
        x = self.pointwise_conv(x)  # (B, 256, 1, 1)

        # GAP
        x = self.gap(x)  # (B, 256, 1, 1)
        x = x.view(x.size(0), -1)  # (B, 256)

        # Attention Classifier
        x = F.relu(self.fc1(x))
        x = self.attention(x)
        x = self.fc2(x)
        return x

# 모델 초기화
num_classes = len(emotion_labels)  # FERPlus 데이터셋 클래스 수
classifier = FERClassifier(model, num_classes).to(device)


In [None]:
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

# 손실 함수와 옵티마이저 정의
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(classifier.parameters(), lr=0.001)
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)

# 모델 전체를 float32로 변환
classifier = classifier.float()

# 학습 함수에서 데이터 타입 변환 제거
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, epochs):
    for epoch in range(epochs):
        model.train()
        train_loss, correct, total = 0.0, 0, 0

        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Metrics
            train_loss += loss.item()
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)

        train_acc = correct / total * 100
        print(f"Epoch {epoch+1}, Loss: {train_loss/total:.4f}, Accuracy: {train_acc:.2f}%")

        # Validation
        val_acc = evaluate_model(model, val_loader)
        print(f"Validation Accuracy: {val_acc:.2f}%")

        # Scheduler step
        scheduler.step()

# 모델 평가 함수 정의
def evaluate_model(model, loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)
    return correct / total * 100

# 학습 시작
train_model(classifier, train_loader, val_loader, criterion, optimizer, scheduler, epochs=epochs)


Epoch 1/60:   0%|          | 0/196 [00:00<?, ?it/s]

Epoch 1/60: 100%|██████████| 196/196 [00:21<00:00,  9.05it/s]

Epoch 1, Loss: 0.0120, Accuracy: 35.64%





Validation Accuracy: 57.47%


Epoch 2/60: 100%|██████████| 196/196 [00:21<00:00,  9.13it/s]

Epoch 2, Loss: 0.0075, Accuracy: 66.53%





Validation Accuracy: 71.67%


Epoch 3/60: 100%|██████████| 196/196 [00:21<00:00,  9.30it/s]

Epoch 3, Loss: 0.0076, Accuracy: 66.82%





Validation Accuracy: 73.43%


Epoch 4/60: 100%|██████████| 196/196 [00:20<00:00,  9.35it/s]

Epoch 4, Loss: 0.0054, Accuracy: 76.61%





Validation Accuracy: 78.19%


Epoch 5/60: 100%|██████████| 196/196 [00:20<00:00,  9.34it/s]

Epoch 5, Loss: 0.0047, Accuracy: 80.04%





Validation Accuracy: 81.70%


Epoch 6/60: 100%|██████████| 196/196 [00:20<00:00,  9.36it/s]

Epoch 6, Loss: 0.0042, Accuracy: 82.91%





Validation Accuracy: 82.11%


Epoch 7/60: 100%|██████████| 196/196 [00:21<00:00,  9.30it/s]

Epoch 7, Loss: 0.0040, Accuracy: 83.77%





Validation Accuracy: 83.48%


Epoch 8/60: 100%|██████████| 196/196 [00:21<00:00,  9.31it/s]

Epoch 8, Loss: 0.0038, Accuracy: 84.77%





Validation Accuracy: 83.48%


Epoch 9/60: 100%|██████████| 196/196 [00:21<00:00,  9.33it/s]

Epoch 9, Loss: 0.0037, Accuracy: 85.07%





Validation Accuracy: 83.95%


Epoch 10/60: 100%|██████████| 196/196 [00:21<00:00,  9.31it/s]

Epoch 10, Loss: 0.0035, Accuracy: 85.43%





Validation Accuracy: 83.20%


Epoch 11/60: 100%|██████████| 196/196 [00:21<00:00,  9.31it/s]

Epoch 11, Loss: 0.0031, Accuracy: 86.68%





Validation Accuracy: 84.52%


Epoch 12/60: 100%|██████████| 196/196 [00:20<00:00,  9.33it/s]

Epoch 12, Loss: 0.0030, Accuracy: 86.98%





Validation Accuracy: 84.08%


Epoch 13/60: 100%|██████████| 196/196 [00:20<00:00,  9.33it/s]

Epoch 13, Loss: 0.0030, Accuracy: 87.23%





Validation Accuracy: 84.46%


Epoch 14/60: 100%|██████████| 196/196 [00:21<00:00,  9.32it/s]

Epoch 14, Loss: 0.0030, Accuracy: 87.45%





Validation Accuracy: 84.02%


Epoch 15/60: 100%|██████████| 196/196 [00:21<00:00,  9.33it/s]

Epoch 15, Loss: 0.0029, Accuracy: 87.54%





Validation Accuracy: 84.21%


Epoch 16/60: 100%|██████████| 196/196 [00:21<00:00,  9.31it/s]

Epoch 16, Loss: 0.0029, Accuracy: 87.75%





Validation Accuracy: 84.39%


Epoch 17/60: 100%|██████████| 196/196 [00:21<00:00,  9.32it/s]

Epoch 17, Loss: 0.0029, Accuracy: 87.67%





Validation Accuracy: 84.61%


Epoch 18/60: 100%|██████████| 196/196 [00:21<00:00,  9.30it/s]

Epoch 18, Loss: 0.0028, Accuracy: 87.97%





Validation Accuracy: 84.27%


Epoch 19/60: 100%|██████████| 196/196 [00:21<00:00,  9.33it/s]

Epoch 19, Loss: 0.0028, Accuracy: 87.89%





Validation Accuracy: 84.46%


Epoch 20/60: 100%|██████████| 196/196 [00:21<00:00,  9.27it/s]

Epoch 20, Loss: 0.0028, Accuracy: 88.16%





Validation Accuracy: 84.39%


Epoch 21/60: 100%|██████████| 196/196 [00:21<00:00,  9.32it/s]

Epoch 21, Loss: 0.0027, Accuracy: 88.39%





Validation Accuracy: 84.46%


Epoch 22/60: 100%|██████████| 196/196 [00:20<00:00,  9.35it/s]

Epoch 22, Loss: 0.0027, Accuracy: 88.31%





Validation Accuracy: 84.49%


Epoch 23/60: 100%|██████████| 196/196 [00:21<00:00,  9.32it/s]

Epoch 23, Loss: 0.0027, Accuracy: 88.61%





Validation Accuracy: 84.49%


Epoch 24/60: 100%|██████████| 196/196 [00:21<00:00,  9.31it/s]

Epoch 24, Loss: 0.0027, Accuracy: 88.46%





Validation Accuracy: 84.46%


Epoch 25/60: 100%|██████████| 196/196 [00:21<00:00,  9.31it/s]

Epoch 25, Loss: 0.0027, Accuracy: 88.64%





Validation Accuracy: 84.52%


Epoch 26/60: 100%|██████████| 196/196 [00:21<00:00,  9.30it/s]

Epoch 26, Loss: 0.0027, Accuracy: 88.53%





Validation Accuracy: 84.46%


Epoch 27/60: 100%|██████████| 196/196 [00:21<00:00,  9.32it/s]

Epoch 27, Loss: 0.0027, Accuracy: 88.54%





Validation Accuracy: 84.55%


Epoch 28/60: 100%|██████████| 196/196 [00:21<00:00,  9.30it/s]

Epoch 28, Loss: 0.0027, Accuracy: 88.56%





Validation Accuracy: 84.58%


Epoch 29/60: 100%|██████████| 196/196 [00:20<00:00,  9.34it/s]

Epoch 29, Loss: 0.0027, Accuracy: 88.55%





Validation Accuracy: 84.42%


Epoch 30/60: 100%|██████████| 196/196 [00:21<00:00,  9.31it/s]

Epoch 30, Loss: 0.0027, Accuracy: 88.50%





Validation Accuracy: 84.49%


Epoch 31/60: 100%|██████████| 196/196 [00:21<00:00,  9.31it/s]

Epoch 31, Loss: 0.0026, Accuracy: 88.68%





Validation Accuracy: 84.46%


Epoch 32/60: 100%|██████████| 196/196 [00:20<00:00,  9.35it/s]

Epoch 32, Loss: 0.0027, Accuracy: 88.49%





Validation Accuracy: 84.46%


Epoch 33/60: 100%|██████████| 196/196 [00:21<00:00,  9.31it/s]

Epoch 33, Loss: 0.0027, Accuracy: 88.56%





Validation Accuracy: 84.42%


Epoch 34/60: 100%|██████████| 196/196 [00:21<00:00,  9.31it/s]

Epoch 34, Loss: 0.0027, Accuracy: 88.55%





Validation Accuracy: 84.42%


Epoch 35/60: 100%|██████████| 196/196 [00:21<00:00,  9.32it/s]

Epoch 35, Loss: 0.0027, Accuracy: 88.58%





Validation Accuracy: 84.42%


Epoch 36/60: 100%|██████████| 196/196 [00:20<00:00,  9.38it/s]

Epoch 36, Loss: 0.0027, Accuracy: 88.65%





Validation Accuracy: 84.46%


Epoch 37/60:  18%|█▊        | 36/196 [00:04<00:16,  9.50it/s]

In [None]:
# 테스트 데이터 평가
def test_model(model, test_loader):
    model.eval()  # 평가 모드
    correct, total = 0, 0
    all_labels = []
    all_predictions = []
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = outputs.max(1)
            
            # 총 정확도 계산
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)

            # 자세한 결과 저장 (선택적)
            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predicted.cpu().numpy())

    # 정확도 계산
    accuracy = correct / total * 100
    print(f"Test Accuracy: {accuracy:.2f}%")

    return accuracy, all_labels, all_predictions


# 테스트 데이터 평가 시작
test_accuracy, test_labels, test_predictions = test_model(classifier, test_loader)

# 테스트 데이터의 구체적인 결과 출력 (선택적)
from sklearn.metrics import classification_report
print(classification_report(test_labels, test_predictions, target_names=emotion_labels))


Test Accuracy: 85.52%
              precision    recall  f1-score   support

     neutral       0.88      0.89      0.88      1083
   happiness       0.92      0.93      0.93       892
    surprise       0.84      0.86      0.85       394
     sadness       0.76      0.75      0.76       382
       anger       0.79      0.82      0.81       269
     disgust       0.00      0.00      0.00        16
        fear       0.55      0.45      0.50        86
    contempt       0.00      0.00      0.00        14

    accuracy                           0.86      3136
   macro avg       0.59      0.59      0.59      3136
weighted avg       0.85      0.86      0.85      3136



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


# Latent-CLIP

In [1]:
from torch import nn
from torch.nn import functional as F
import torch
import torch.nn.init as init
from torchvision import models
from clip import load  # CLIP 모델 불러오기
from torchvision import transforms

class LatentOFER(nn.Module):
    def __init__(self, clip_model, num_class=7, num_head=4, pretrained=True):
        super(LatentOFER, self).__init__()
        self.image_encoder = clip_model.visual  # CLIP 이미지 인코더 사용
        self.num_head = num_head
        
        for i in range(num_head):
            setattr(self, "cat_head%d" % i, CrossAttentionHead())

        self.sig = nn.Sigmoid()
        self.hh_layer = nn.Linear(39936, 10000)
        self.hh_layer2 = nn.Linear(10000, 5000)
        self.hh_batch1 = nn.BatchNorm1d(5000)
        self.hh_layer3 = nn.Linear(5000, 1024)
        self.hh_layer4 = nn.Linear(1024, 256)
        self.batch_norm1 = nn.BatchNorm1d(256)

        self.fc = nn.Linear(512 + 256, 256)  # CLIP 출력 차원 512 사용
        self.fc2 = nn.Linear(256, 128)
        self.batch_norm = nn.BatchNorm1d(128)
        self.fc4 = nn.Linear(128, num_class)
        self.bn = nn.BatchNorm1d(num_class)

    def forward(self, x, latent):
        x = self.image_encoder(x)  # CLIP 이미지 인코딩 (B, 512)
        x = x.unsqueeze(-1).unsqueeze(-1)  # (B, 512, 1, 1): 차원 확장

        heads = []
        for i in range(self.num_head):
            heads.append(getattr(self, "cat_head%d" % i)(x))
        
        heads = torch.stack(heads).permute([1, 0, 2])
        if heads.size(1) > 1:
            heads = F.log_softmax(heads, dim=1)
            heads = heads.sum(dim=1)

        # Latent 인코딩
        latent = self.hh_layer(latent)
        latent = self.hh_layer2(latent)
        latent = self.hh_batch1(latent)
        latent = self.hh_layer3(latent)
        latent = self.hh_layer4(latent)
        latent = self.batch_norm1(latent)

        # Concatenation 후 Fully Connected 적용
        out = torch.cat([heads, latent], dim=1)
        out = self.fc(out)
        out = self.fc2(out)
        out = self.batch_norm(out)
        out = self.fc4(out)
        out = self.bn(out)

        return out, x, heads



class CrossAttentionHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.sa = SpatialAttention()
        self.ca = ChannelAttention()
        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        ca = self.ca(x)
        sa = self.sa(ca)
        return sa


class SpatialAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1x1 = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=1),
            nn.BatchNorm2d(256),
        )
        self.conv_3x3 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
        )
        self.conv_1x3 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=(1, 3), padding=(0, 1)),
            nn.BatchNorm2d(512),
        )
        self.conv_3x1 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=(3, 1), padding=(1, 0)),
            nn.BatchNorm2d(512),
        )
        self.relu = nn.ReLU()
        self.gap = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        y = self.conv1x1(x)
        a = self.conv_3x3(y)
        b = self.conv_1x3(y)
        c = self.conv_3x1(y)

        y = self.relu(a + b + c)
        y = y.sum(dim=1, keepdim=True)

        out = x * y
        out = self.gap(out)
        out = out.view(out.size(0), -1)
        return out


class ChannelAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.attention = nn.Sequential(
            nn.Linear(512, 32),
            nn.BatchNorm1d(32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 512),
            nn.Sigmoid()
        )

    def forward(self, sa):
        sa2 = self.gap(sa)
        sa2 = sa2.view(sa2.size(0), -1)
        y = self.attention(sa2)
        y = y.unsqueeze(dim=-1).unsqueeze(dim=-1)

        out = sa * y
        return out


# CLIP 모델 로드 및 LatentOFER 모델 초기화
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
clip_model, _ = load("ViT-B/32", device=device)  # CLIP 모델 로드
num_classes = 8 

model = LatentOFER(clip_model=clip_model, num_class=num_classes).to(device).float()  


In [2]:
import dlib
import pandas as pd
import numpy as np
from PIL import Image
import cv2
from torch.utils.data import Dataset
from tqdm import tqdm

from sklearn.metrics import balanced_accuracy_score

class FERPlusDataset(Dataset):
    def __init__(self, data_csv, phase, transform=None):
        self.phase = phase
        self.transform = transform

        # Read the dataset CSV file
        self.data = pd.read_csv(data_csv)
        self.data.iloc[:, 2:12] = self.data.iloc[:, 2:12].replace(1, 0)

        # Get file paths and labels
        self.file_paths = self.data.iloc[:, 0].values
        self.counts = self.data.iloc[:, 2:12].values  # 감정 점수들

        # Apply constraints to filter valid samples
        self._apply_constraints()

        # Use argmax to determine the emotion class
        self.labels = np.argmax(self.counts, axis=1)

        # Debugging: Check label range
        print("Unique labels in dataset after filtering:", np.unique(self.labels))

        # Dlib face detector and predictor initialization
        self.detector = dlib.get_frontal_face_detector()
        self.predictor = dlib.shape_predictor('/root/FER2013/shape_predictor_68_face_landmarks.dat')  # Dlib 모델 파일 경로 필요

    def _apply_constraints(self):
        # Constraint : 'unknown-face' 또는 'not-face' 레이블 제거
        max_counts = self.counts.max(axis=1)
        counts_eq_max = (self.counts == max_counts[:, None])
        constraint1_violation = counts_eq_max[:, [8, 9]].any(axis=1)

        # Constraint : 최대 투표 수를 가진 레이블이 3개 초과 제거
        num_max_labels = counts_eq_max.sum(axis=1)
        constraint2_violation = num_max_labels > 3

        # Constraint : 최대 투표 수가 전체 투표 수의 절반 이하인 경우 제거
        total_votes = self.counts.sum(axis=1)
        constraint3_violation = max_counts <= (total_votes / 2)

        # Combine constraints
        valid_samples = ~(
            constraint1_violation | constraint2_violation | constraint3_violation
        )

        # Apply valid samples filter
        self.file_paths = self.file_paths[valid_samples]
        self.counts = self.counts[valid_samples]

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        if self.phase == 'train':
            path = '/data/FER2013/FER2013Train/' + self.file_paths[idx]
        elif self.phase == 'val':
            path = '/data/FER2013/FER2013Valid/' + self.file_paths[idx]
        elif self.phase == 'test':
            path = '/data/FER2013/FER2013Test/' + self.file_paths[idx]
        
        # Open image
        image = cv2.imread(path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
        # Align face using Dlib
        image = self._align_face(image)

        # Convert to PIL image for further processing
        image = Image.fromarray(image)

        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

    def _align_face(self, image):
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        faces = self.detector(gray)

        if len(faces) == 0:
            return image  # 얼굴이 감지되지 않으면 원본 이미지를 반환

        for face in faces:
            landmarks = self.predictor(gray, face)

            # 좌우 눈의 중심 좌표 추출
            left_eye = (landmarks.part(36).x, landmarks.part(36).y)
            right_eye = (landmarks.part(45).x, landmarks.part(45).y)

            # 두 눈의 중심 계산
            eye_center = ((left_eye[0] + right_eye[0]) // 2, (left_eye[1] + right_eye[1]) // 2)

            # 눈 사이의 기울기 계산
            delta_x = right_eye[0] - left_eye[0]
            delta_y = right_eye[1] - left_eye[1]
            angle = np.degrees(np.arctan2(delta_y, delta_x))

            # 회전 행렬 계산
            rot_matrix = cv2.getRotationMatrix2D(eye_center, angle, 1.0)

            # 이미지 회전 및 정렬
            aligned_face = cv2.warpAffine(image, rot_matrix, (image.shape[1], image.shape[0]))

            return aligned_face  # 첫 번째 얼굴만 처리

        return image


In [3]:
# Define variables
train_csv = '/data/FER2013/train_label.csv'
val_csv = '/data/FER2013/valid_label.csv'
test_csv = '/data/FER2013/test_label.csv'
batch_size = 16
lr = 0.001
workers = 4
epochs = 60

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.enabled = True

data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomApply([
            transforms.RandomRotation(20),
            transforms.RandomCrop(224, padding=32)
        ], p=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225]),
    transforms.RandomErasing(scale=(0.02, 0.25)),
])

train_dataset = FERPlusDataset(train_csv, phase='train', transform=data_transforms)
print('Whole train set size:', train_dataset.__len__())

train_loader = torch.utils.data.DataLoader(train_dataset,
                                            batch_size=batch_size,
                                            num_workers=workers,
                                            shuffle=True,
                                            pin_memory=True)

data_transforms_val = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])
])

val_dataset = FERPlusDataset(val_csv, phase='val', transform=data_transforms_val)
print('Validation set size:', val_dataset.__len__())

val_loader = torch.utils.data.DataLoader(val_dataset,
                                            batch_size=batch_size,
                                            num_workers=workers,
                                            shuffle=False,
                                            pin_memory=True)

criterion_cls = torch.nn.CrossEntropyLoss()

params = list(model.parameters())
optimizer = torch.optim.SGD(params, lr=lr, weight_decay=1e-4, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

best_acc = 0
for epoch in tqdm(range(1, epochs + 1)):
    running_loss = 0.0
    correct_sum = 0
    iter_cnt = 0
    model.train()

    for imgs, targets in train_loader:
        iter_cnt += 1
        optimizer.zero_grad()

        imgs = imgs.to(device)
        targets = targets.to(device)
        out, _, _ = model(imgs, torch.zeros(imgs.size(0), 39936).to(device)) # Pass None for latent

        loss = criterion_cls(out, targets)

        loss.backward()
        optimizer.step()

        running_loss += loss
        _, predicts = torch.max(out, 1)
        correct_num = torch.eq(predicts, targets).sum()
        correct_sum += correct_num

    acc = correct_sum.float() / float(train_dataset.__len__())
    running_loss = running_loss / iter_cnt
    tqdm.write('[Epoch %d] Training accuracy: %.4f. Loss: %.3f. LR %.6f' % (epoch, acc, running_loss, optimizer.param_groups[0]['lr']))

    with torch.no_grad():
        running_loss = 0.0
        iter_cnt = 0
        bingo_cnt = 0
        sample_cnt = 0
        baccs = []

        y_true = []
        y_pred = []

        model.eval()
        for imgs, targets in val_loader:
            imgs = imgs.to(device)
            targets = targets.to(device)

            out, _, _ = model(imgs, torch.zeros(imgs.size(0), 39936).to(device))   # Pass None for latent
            loss = criterion_cls(out, targets)
            running_loss += loss
            iter_cnt += 1
            _, predicts = torch.max(out, 1)
            correct_num = torch.eq(predicts, targets)
            bingo_cnt += correct_num.sum().cpu()
            sample_cnt += out.size(0)
            y_true.append(targets.cpu().numpy())
            y_pred.append(predicts.cpu().numpy())

            baccs.append(balanced_accuracy_score(targets.cpu().numpy(), predicts.cpu().numpy()))
        running_loss = running_loss / iter_cnt
        scheduler.step()

        acc = bingo_cnt.float() / float(sample_cnt)
        acc = np.around(acc.numpy(), 4)
        best_acc = max(acc, best_acc)

        y_true = np.concatenate(y_true)
        y_pred = np.concatenate(y_pred)

        bacc = np.around(np.mean(baccs), 4)
        tqdm.write("[Epoch %d] Validation accuracy:%.4f. bacc:%.4f. Loss:%.3f" % (epoch, acc, bacc, running_loss))
        tqdm.write("best_acc:" + str(best_acc))

Unique labels in dataset after filtering: [0 1 2 3 4 5 6 7]
Whole train set size: 25045
Unique labels in dataset after filtering: [0 1 2 3 4 5 6 7]
Validation set size: 3191


  0%|          | 0/60 [01:55<?, ?it/s]

[Epoch 1] Training accuracy: 0.3440. Loss: 1.729. LR 0.001000


  2%|▏         | 1/60 [02:00<1:58:54, 120.92s/it]

[Epoch 1] Validation accuracy:0.3541. bacc:0.2297. Loss:1.609
best_acc:0.3541


  2%|▏         | 1/60 [03:56<1:58:54, 120.92s/it]

[Epoch 2] Training accuracy: 0.3543. Loss: 1.596. LR 0.001000


  3%|▎         | 2/60 [04:01<1:56:52, 120.90s/it]

[Epoch 2] Validation accuracy:0.3557. bacc:0.2227. Loss:1.583
best_acc:0.3557


  3%|▎         | 2/60 [05:56<1:56:52, 120.90s/it]

[Epoch 3] Training accuracy: 0.3588. Loss: 1.573. LR 0.001000


  5%|▌         | 3/60 [06:01<1:54:20, 120.35s/it]

[Epoch 3] Validation accuracy:0.3676. bacc:0.2319. Loss:1.564
best_acc:0.3676


  5%|▌         | 3/60 [06:50<2:10:08, 136.99s/it]


KeyboardInterrupt: 

In [None]:
def evaluate_model(model, test_loader, device):
    model.eval()  # Set model to evaluation mode
    correct = 0
    total = 0
    y_true = []
    y_pred = []

    print("Starting evaluation...")
    with torch.no_grad():  # Disable gradient computation
        for i, (imgs, targets) in enumerate(test_loader):

            imgs = imgs.to(device)
            targets = targets.to(device)

            # Forward pass
            outputs, _, _ = model(imgs, torch.zeros(imgs.size(0), 39936).to(device))  # Provide dummy latent input
            _, predictions = torch.max(outputs, 1)  # Get predicted class

            # Collect results
            correct += (predictions == targets).sum().item()
            total += targets.size(0)

            y_true.extend(targets.cpu().numpy())
            y_pred.extend(predictions.cpu().numpy())

    # Calculate accuracy
    accuracy = correct / total if total > 0 else 0  # Prevent division by zero
    print(f"Test Accuracy: {accuracy}")  
    print(f"Test Accuracy: {accuracy}")
    print(f"Test Accuracy: {accuracy}")

    return accuracy
data_transforms_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])
])

test_dataset = FERPlusDataset(test_csv, phase='test', transform=data_transforms_test)
print('Test set size:', test_dataset.__len__())

test_loader = torch.utils.data.DataLoader(test_dataset,
                                            batch_size=batch_size,
                                            num_workers=workers,
                                            shuffle=False,
                                            pin_memory=True)

# Evaluate model
acc = evaluate_model(model, test_loader, device)
print(f"Final Test Accuracy: {acc:.4f}")
