In [2]:
import os
import cv2
import csv
import math
import random
import numpy as np
import pandas as pd
import pickle
import sys

import torch
import torch.nn as nn
from torchvision import transforms
import torchvision.models as models
import torch.utils.data as data
import torch.nn.functional as F

# Load CLIP (ViT-B/32) once at module level. `clip_model` is passed into
# Model.forward by train()/test(); `device` is also read by Mask().
# NOTE(review): `preprocess` (CLIP's own input transform) is not used anywhere
# below - images reach encode_image with ImageNet normalisation; confirm intended.
import clip
device = torch.device('cuda', 0)
clip_model, preprocess = clip.load("ViT-B/32", device=device)

from torch.nn.modules.module import Module
from torch.nn.modules.utils import _pair
from torch.nn.parameter import Parameter

# Custom max pooling (pools over the channel axis instead of the width axis).
class my_MaxPool2d(Module):
    """Max pooling applied after swapping dims 1 and 3.

    ``forward`` transposes dims 1 and 3, applies ``F.max_pool2d`` and
    transposes back (made contiguous). For a 4-D input ``(N, C, H, W)`` the
    window therefore slides over the ``(H, C)`` plane: ``kernel_size[0]``
    acts on H and ``kernel_size[1]`` on the channel axis C. ``supervisor``
    uses a ``(1, cnum)`` kernel to take the max over each run of ``cnum``
    consecutive channels.

    With ``return_indices=True`` ``forward`` returns ``(values, indices)``;
    both are transposed back the same way, and each index is the flat
    position ``h * C + c`` inside the ``(H, C)`` plane that was pooled.
    """

    def __init__(self, kernel_size, stride=None, padding=0, dilation=1,
                 return_indices=False, ceil_mode=False):
        super(my_MaxPool2d, self).__init__()
        self.kernel_size = kernel_size
        # A falsy stride (None or 0) falls back to the kernel size.
        self.stride = stride or kernel_size
        self.padding = padding
        self.dilation = dilation
        self.return_indices = return_indices
        self.ceil_mode = ceil_mode

    def forward(self, input):
        swapped = input.transpose(3, 1)
        pooled = F.max_pool2d(swapped, self.kernel_size, self.stride,
                              self.padding, self.dilation,
                              ceil_mode=self.ceil_mode,
                              return_indices=self.return_indices)
        if self.return_indices:
            # In this mode F.max_pool2d returns a (values, indices) tuple;
            # calling .transpose() on the tuple itself would fail.
            values, indices = pooled
            return (values.transpose(3, 1).contiguous(),
                    indices.transpose(3, 1).contiguous())
        return pooled.transpose(3, 1).contiguous()

    def __repr__(self):
        kh, kw = _pair(self.kernel_size)
        dh, dw = _pair(self.stride)
        padh, padw = _pair(self.padding)
        dilh, dilw = _pair(self.dilation)
        padding_str = ', padding=(' + str(padh) + ', ' + str(padw) + ')' \
            if padh != 0 or padw != 0 else ''
        # NOTE(review): dilation is printed whenever both factors are non-zero,
        # i.e. effectively always (including the default of 1).
        dilation_str = (', dilation=(' + str(dilh) + ', ' + str(dilw) + ')'
                        if dilh != 0 and dilw != 0 else '')
        ceil_str = ', ceil_mode=' + str(self.ceil_mode)
        return self.__class__.__name__ + '(' \
            + 'kernel_size=(' + str(kh) + ', ' + str(kw) + ')' \
            + ', stride=(' + str(dh) + ', ' + str(dw) + ')' \
            + padding_str + dilation_str + ceil_str + ')'

# Custom average pooling (pools over the channel axis instead of the width axis).
class my_AvgPool2d(Module):
    """Average pooling applied after swapping dims 1 and 3.

    ``forward`` transposes dims 1 and 3, applies ``F.avg_pool2d`` and
    transposes back (made contiguous), so for a 4-D input ``(N, C, H, W)``
    the window slides over the ``(H, C)`` plane.
    """

    def __init__(self, kernel_size, stride=None, padding=0, ceil_mode=False,
                 count_include_pad=True):
        super().__init__()
        self.kernel_size = kernel_size
        # A falsy stride (None or 0) falls back to the kernel size.
        self.stride = stride or kernel_size
        self.padding = padding
        self.ceil_mode = ceil_mode
        self.count_include_pad = count_include_pad

    def forward(self, input):
        pooled = F.avg_pool2d(
            input.transpose(1, 3),
            self.kernel_size,
            self.stride,
            self.padding,
            self.ceil_mode,
            self.count_include_pad,
        )
        return pooled.transpose(1, 3).contiguous()

    def __repr__(self):
        fields = (
            ('kernel_size', self.kernel_size),
            ('stride', self.stride),
            ('padding', self.padding),
            ('ceil_mode', self.ceil_mode),
            ('count_include_pad', self.count_include_pad),
        )
        body = ', '.join(name + '=' + str(value) for name, value in fields)
        return type(self).__name__ + '(' + body + ')'

# Dataset loader (despite the name, it reads FERPlus).
class RafDataset(data.Dataset):
    """FERPlus dataset read from ``<root_dir>/<phase>/label.csv``.

    The CSV has no header; its columns are taken to be ``filename, bbox``
    followed by ten vote-count columns ``c0..c9``. The label of each row is the
    column with the most votes (first one on ties), and only labels 0-7 are
    kept, so rows whose winner is ``c8`` or ``c9`` are dropped.

    Args:
        root_dir: directory holding one sub-folder per phase.
        phase: sub-folder name, e.g. 'Train', 'Valid' or 'Test'.
        transform: optional callable applied to the RGB ``np.ndarray`` image.
        apply_constraints: if True, vote counts equal to exactly 1 are zeroed
            and rows whose top count is not a strict majority of the remaining
            votes are dropped.

    ``__getitem__`` returns ``(image, label, idx, flipped_image)``.
    """

    def __init__(self, root_dir, phase, transform=None, apply_constraints=False):
        self.root_dir = root_dir
        self.phase = phase  # 'Train', 'Valid', or 'Test'
        self.transform = transform

        label_path = os.path.join(root_dir, phase, 'label.csv')
        df = pd.read_csv(label_path, header=None)

        # Name the columns: file name, bounding box, then 10 vote counts.
        vote_cols = [f'c{i}' for i in range(10)]
        df.columns = ['filename', 'bbox'] + vote_cols

        # Per-example vote counts for each label column.
        counts = df[vote_cols].astype(int)

        if apply_constraints:
            # (1) A label that received exactly one vote is zeroed.
            # DataFrame.mask replaces the deprecated element-wise applymap.
            counts = counts.mask(counts == 1, 0)

            # (2) Drop rows whose top count is not more than half of the
            #     remaining votes.
            max_counts = counts.max(axis=1)
            total_votes = counts.sum(axis=1)
            valid_mask = max_counts > (total_votes / 2)

            df = df[valid_mask].reset_index(drop=True)
            counts = counts[valid_mask].reset_index(drop=True)

        # Final label = column with the most votes (first one on ties).
        df['label'] = counts.values.argmax(axis=1)

        # Keep only labels 0-7 (8 classes).
        valid_labels_mask = (df['label'] >= 0) & (df['label'] <= 7)
        df = df[valid_labels_mask].reset_index(drop=True)
        self.file_paths = df['filename'].values
        self.labels = df['label'].values.astype(np.int64)

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        label = self.labels[idx]
        img_path = os.path.join(self.root_dir, self.phase, self.file_paths[idx])

        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        image = image[:, :, ::-1]  # BGR -> RGB
        if self.transform is not None:
            image = self.transform(image)

        if isinstance(image, np.ndarray):
            # No converting transform: RandomHorizontalFlip only accepts PIL
            # images / tensors, and negative-stride arrays cannot be collated.
            image = np.ascontiguousarray(image)
            image1 = np.ascontiguousarray(image[:, ::-1])
        else:
            image1 = transforms.RandomHorizontalFlip(p=1)(image)
        return image, label, idx, image1

def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    Args:
        in_planes: input channel count.
        out_planes: output channel count.
        stride: convolution stride (spatial size is preserved at stride 1).
    """
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )

class BasicBlock(nn.Module):
    """Two-conv residual block with an identity or 1x1-projected shortcut.

    Args:
        in_channels: channels of the incoming tensor.
        out_channels: channels produced by both 3x3 convolutions.
        stride: stride of the first 3x3 conv and of the projection shortcut.
        downsample: if truthy, the shortcut is a strided 1x1 conv followed by
            BatchNorm; otherwise the input is added unchanged.
    """

    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, downsample=False):
        super().__init__()

        # Creation order (conv1, bn1, conv2, bn2, relu, shortcut) determines
        # the order in which random weight initialisation draws from the RNG.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        shortcut = None
        if downsample:
            shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        self.downsample = shortcut

    def forward(self, x):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.downsample is not None:
            identity = self.downsample(identity)
        out += identity
        return self.relu(out)

class ResNet(nn.Module):
    """Four-stage ResNet that returns both logits and the pooled feature vector.

    Args:
        block: residual block class exposing an ``expansion`` attribute.
        n_blocks: number of blocks in each of the 4 stages.
        channels: base channel width of each of the 4 stages.
        output_dim: size of the final linear layer.

    ``forward`` returns ``(logits, h)`` where ``h`` is the flattened output of
    the global average pool.
    """

    def __init__(self, block, n_blocks, channels, output_dim):
        super().__init__()

        self.in_channels = channels[0]

        assert len(n_blocks) == len(channels) == 4

        self.conv1 = nn.Conv2d(3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self.get_resnet_layer(block, n_blocks[0], channels[0])
        self.layer2 = self.get_resnet_layer(block, n_blocks[1], channels[1], stride=2)
        self.layer3 = self.get_resnet_layer(block, n_blocks[2], channels[2], stride=2)
        self.layer4 = self.get_resnet_layer(block, n_blocks[3], channels[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # self.in_channels now holds the output width of layer4.
        self.fc = nn.Linear(self.in_channels, output_dim)

    def get_resnet_layer(self, block=BasicBlock, n_blocks=2, channels=64, stride=1):
        """Build one stage and advance ``self.in_channels`` to its output width.

        The first block gets a projection shortcut whenever its output shape
        differs from its input - either the spatial size (``stride != 1``) or
        the channel count changes. Without the stride check, a strided stage
        that keeps the same width would fail on the residual addition.
        The remaining blocks use identity shortcuts.
        """
        out_channels = block.expansion * channels
        downsample = stride != 1 or self.in_channels != out_channels

        layers = [block(self.in_channels, channels, stride, downsample)]
        layers.extend(block(out_channels, channels) for _ in range(1, n_blocks))

        self.in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        h = x.view(x.shape[0], -1)
        x = self.fc(h)

        return x, h

class Flatten(nn.Module):
    """Collapse every dimension after the first: ``(N, ...) -> (N, prod(...))``.

    Uses ``Tensor.view``, so the input must be viewable in that shape
    (e.g. contiguous).
    """

    def forward(self, input):
        batch_size = input.size(0)
        return input.view(batch_size, -1)

def Mask(nb_batch, num_classes=8, total_channels=512, target_device=None):
    """Build a random 0/1 channel mask for the MC-style discriminality loss.

    ``total_channels`` is split as evenly as possible into ``num_classes``
    counts (the first ``total_channels % num_classes`` get one extra). For each
    sample and each count, ``random.sample`` picks that many distinct channel
    indices out of *all* ``total_channels`` channels and sets them to 1. Since
    draws may overlap, a row has between ``max(count)`` and ``total_channels``
    active channels.

    NOTE(review): indices are drawn from the full channel range rather than
    from each class's own group of channels - confirm this is intended.

    Args:
        nb_batch: number of rows (samples) in the mask.
        num_classes: number of draws per row.
        total_channels: mask width (defaults to the previous hard-coded 512).
        target_device: device for the returned tensor; when None the
            module-level ``device`` is used, as before.

    Returns:
        float32 tensor of shape ``(nb_batch, total_channels, 1, 1)``.
    """
    per_class, remainder = divmod(total_channels, num_classes)
    class_channel_counts = [per_class + (1 if k < remainder else 0)
                            for k in range(num_classes)]

    mask = np.zeros((nb_batch, total_channels), dtype=np.float32)
    for row in range(nb_batch):
        # Same sequence of random.sample calls as before, so seeded runs
        # produce identical masks.
        for count in class_channel_counts:
            mask[row, random.sample(range(total_channels), count)] = 1.0

    mask_tensor = torch.from_numpy(mask.reshape(nb_batch, total_channels, 1, 1))
    return mask_tensor.to(target_device if target_device is not None else device)

def supervisor(x, targets, cnum):
    """Compute the two Mutual-Channel-style auxiliary losses on a feature vector.

    ``x`` is reshaped to ``(B, C, 1, 1)``. ``my_MaxPool2d`` with a
    ``(1, cnum)`` kernel then takes the max over each run of ``cnum``
    consecutive channels, giving one score per channel group.

    Args:
        x: features; assumes shape (B, C) with C equal to the ``Mask`` width
            (512) and divisible by ``cnum`` - TODO confirm for other backbones.
        targets: class indices for the cross-entropy term.
        cnum: channels per group.

    Returns:
        ``[loss_1, loss_2]``:
        loss_1 is the cross-entropy between the group maxima of the randomly
        channel-masked features and ``targets``. loss_2 is
        ``1 - mean(group maxima) / cnum`` on the unmasked features.
    """
    batch, channels = x.size(0), x.size(1)
    group_max = my_MaxPool2d(kernel_size=(1, cnum), stride=(1, cnum))

    branch = group_max(x.reshape(batch, channels, 1, 1))
    branch = branch.reshape(branch.size(0), branch.size(1), branch.size(2) * branch.size(3))
    loss_2 = 1.0 - 1.0 * torch.mean(torch.sum(branch, 2)) / cnum

    # Mask() places its tensor on the module-level device; move it to wherever
    # the features actually live so the product below cannot hit a device mismatch.
    mask = Mask(batch, num_classes=8).to(x.device)
    branch_1 = group_max(x.reshape(batch, channels, 1, 1) * mask)
    branch_1 = branch_1.view(branch_1.size(0), -1)
    loss_1 = nn.CrossEntropyLoss()(branch_1, targets)
    return [loss_1, loss_2]

class Model(nn.Module):
    """Gates CLIP image embeddings with a sigmoid of ResNet-18 features, then classifies.

    ``forward`` computes ``fused = clip_image_features * sigmoid(proj(resnet_features))``
    and feeds ``fused`` to a linear classifier. When ``phase == 'train'`` it
    also returns the auxiliary losses from ``supervisor`` computed on ``fused``.

    Args:
        pretrained: unused (kept for interface compatibility).
        num_classes: number of output classes.
        drop_rate: stored on the instance; not used by ``forward``.
        clip_dim: width of the CLIP image embedding (512 for ViT-B/32). When it
            differs from the 512-d ResNet-18 feature, a trainable linear
            projection is registered; otherwise the projection is the identity.
    """

    def __init__(self, pretrained=True, num_classes=8, drop_rate=0, clip_dim=512):
        super(Model, self).__init__()

        res18 = ResNet(block=BasicBlock, n_blocks=[2, 2, 2, 2],
                       channels=[64, 128, 256, 512], output_dim=1000)

        try:
            msceleb_model = torch.load('../.././resnet18_msceleb.pth')
            state_dict = msceleb_model['state_dict']
            # strict=False: missing / unexpected keys are ignored.
            res18.load_state_dict(state_dict, strict=False)
            print("---RESNET MSCELEB LOADED--")
        except FileNotFoundError:
            print("Pre-trained model not found. Using torchvision's pre-trained model.")
            # NOTE(review): newer torchvision deprecates `pretrained=` in favour
            # of `weights=`; kept for compatibility with the installed version.
            res18 = models.resnet18(pretrained=True)

        self.drop_rate = drop_rate
        # Drop the last two children (avgpool, fc): keep the stem + layer1..layer4.
        self.features = nn.Sequential(*list(res18.children())[:-2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # global average pooling

        fc_in_dim = 512  # channel width of ResNet-18 layer4
        # A registered projection is trained and checkpointed with the model.
        # nn.Identity has no parameters and consumes no RNG, so the default
        # configuration's state_dict and initialisation are unchanged.
        self.proj = nn.Linear(fc_in_dim, clip_dim) if clip_dim != fc_in_dim else nn.Identity()
        self.fc = nn.Linear(clip_dim, num_classes)

        self.parm = {}
        for name, parameters in self.fc.named_parameters():
            print(name, ':', parameters.size())
            self.parm[name] = parameters

    def forward(self, x, clip_model, targets, phase='train'):
        """Classify a batch of images.

        Args:
            x: image batch, fed both to ``clip_model.encode_image`` and to the
                ResNet features.
            clip_model: CLIP model; its image encoder runs under ``torch.no_grad``.
            targets: labels, used only for the auxiliary losses in training.
            phase: ``'train'`` enables the auxiliary losses.

        Returns:
            ``(logits, MC_loss)`` when ``phase == 'train'``, else ``(logits, logits)``.

        Raises:
            ValueError: if the projected ResNet feature width differs from the
                CLIP embedding width.
        """
        with torch.no_grad():
            image_features = clip_model.encode_image(x)  # assumes (batch, clip_dim) - TODO confirm

        feats = self.features(x)
        feats = self.avgpool(feats)                # (batch, 512, 1, 1)
        feats = feats.view(feats.size(0), -1)      # (batch, 512)
        feats = self.proj(feats)

        if feats.size(1) != image_features.size(1):
            raise ValueError(
                f"ResNet feature width {feats.size(1)} does not match CLIP embedding "
                f"width {image_features.size(1)}; construct Model with "
                f"clip_dim={image_features.size(1)}")

        fused = image_features * torch.sigmoid(feats)

        if phase == 'train':
            # cnum=64 -> 512 / 64 = 8 channel groups, one per class.
            MC_loss = supervisor(fused, targets, cnum=64)
            return self.fc(fused), MC_loss

        out = self.fc(fused)  # (batch, num_classes)
        return out, out

def add_g(image_array, mean=0.0, var=30):
    """Add i.i.d. Gaussian noise to an image and return it as uint8.

    Args:
        image_array: numeric array (e.g. an HxWxC image).
        mean: mean of the noise.
        var: variance of the noise; the standard deviation used is ``sqrt(var)``.

    Returns:
        ``image_array + noise`` clipped to [0, 255] and cast to ``np.uint8``.
    """
    noise = np.random.normal(mean, var ** 0.5, image_array.shape)
    noisy = np.clip(image_array + noise, 0, 255)
    return noisy.astype(np.uint8)

def flip_image(image_array):
    """Return a left-right mirrored copy of ``image_array`` via OpenCV."""
    around_vertical_axis = 1  # cv2 flipCode: 1 = horizontal mirror, 0 = vertical, -1 = both
    return cv2.flip(image_array, around_vertical_axis)

def setup_seed(seed):
    """Seed torch (CPU and all CUDA devices), NumPy and ``random``; make cuDNN deterministic.

    Args:
        seed: integer seed applied to every generator.
    """
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all,
                   np.random.seed, random.seed):
        seeder(seed)
    torch.backends.cudnn.deterministic = True

class Args:
    """Plain attribute bag holding the paths and hyper-parameters for this run.

    ``raf_path`` can be overridden with the ``FERPLUS_DATA`` environment
    variable instead of editing the hard-coded absolute path.
    """

    def __init__(self):
        # Root that contains Train/, Valid/ and Test/, each with a label.csv.
        self.raf_path = os.environ.get('FERPLUS_DATA', '/home/work/dhkim/fer/FERPlus/data')
        # resnet50_path, label_path, w, h and lam are not read by the code in
        # this notebook.
        self.resnet50_path = '../../resnet50_ft_weight.pkl'
        self.label_path = 'list_patition_label.txt'
        self.workers = 4
        self.batch_size = 32
        self.w = 7
        self.h = 7
        self.gpu = 0
        self.lam = 5
        self.epochs = 20

args = Args()

def train(args, model, train_loader, optimizer, scheduler, device):
    """Run one training epoch, then step the LR scheduler once.

    Each batch is forwarded together with the module-level ``clip_model`` in
    ``phase='train'``. The optimised loss is
    ``CE(output, labels) + 5 * MC_loss[1] + 1.5 * MC_loss[0]``.

    Args:
        args: run configuration (not read here; kept for call compatibility).
        model: the network; moved to ``device`` and put in train mode.
        train_loader: yields ``(images, labels, indexes, flipped_images)``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped once per call.
        device: device the batches are moved to.

    Returns:
        ``(acc, mean_loss)``: accuracy over ``len(train_loader.dataset)`` as a
        0-dim tensor, and the mean per-batch loss as a float.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    criterion = nn.CrossEntropyLoss()
    running_loss = 0.0
    iter_cnt = 0
    correct_sum = 0

    model.to(device)
    model.train()

    # The sample indexes and flipped images are yielded by the dataset but unused here.
    for imgs1, labels, _indexes, _imgs2 in train_loader:
        imgs1 = imgs1.to(device)
        labels = labels.to(device)

        output, MC_loss = model(imgs1, clip_model, labels, phase='train')

        loss1 = criterion(output, labels)
        loss = loss1 + 5 * MC_loss[1] + 1.5 * MC_loss[0]

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        iter_cnt += 1
        _, predicts = torch.max(output, 1)
        correct_sum += torch.eq(predicts, labels).sum()
        running_loss += loss.item()

    if iter_cnt == 0:
        # Would otherwise surface as ZeroDivisionError / AttributeError below.
        raise ValueError("train_loader yielded no batches")

    scheduler.step()
    running_loss = running_loss / iter_cnt
    acc = correct_sum.float() / float(len(train_loader.dataset))
    return acc, running_loss

def test(model, test_loader, device):
    """Evaluate ``model`` on ``test_loader`` without gradients.

    The model is put in eval mode and called with the module-level
    ``clip_model`` and ``phase='Test'``, so no auxiliary losses are computed.

    Args:
        model: the network to evaluate.
        test_loader: yields ``(images, labels, indexes, flipped_images)``.
        device: device the batches are moved to.

    Returns:
        ``(acc, mean_loss)``: accuracy over the samples seen as a 0-dim tensor,
        and the mean per-batch cross-entropy as a float.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    criterion = nn.CrossEntropyLoss()
    running_loss = 0.0
    iter_cnt = 0
    correct_sum = 0
    data_num = 0

    with torch.no_grad():
        model.eval()

        for imgs1, labels, _indexes, _imgs2 in test_loader:
            imgs1 = imgs1.to(device)
            labels = labels.to(device)

            outputs, _ = model(imgs1, clip_model, labels, phase='Test')
            loss = criterion(outputs, labels)

            iter_cnt += 1
            _, predicts = torch.max(outputs, 1)
            correct_sum += torch.eq(predicts, labels).sum()
            running_loss += loss.item()
            data_num += outputs.size(0)

    if iter_cnt == 0:
        # Would otherwise surface as ZeroDivisionError / AttributeError below.
        raise ValueError("test_loader yielded no batches")

    running_loss = running_loss / iter_cnt
    test_acc = correct_sum.float() / float(data_num)
    return test_acc, running_loss

def _make_loader(dataset, shuffle, batch_size, workers):
    """Wrap ``dataset`` in a DataLoader with the settings shared by every split."""
    return torch.utils.data.DataLoader(dataset,
                                       batch_size=batch_size,
                                       shuffle=shuffle,
                                       num_workers=workers,
                                       pin_memory=False)


def main():
    """Train on FERPlus ``Train``, keep the best ``Valid`` epoch, report on ``Test``.

    Side effects: writes ``best_model.pth`` whenever validation accuracy
    improves, and appends per-epoch and final metrics to ``results.txt``
    (both paths are relative to the current working directory).
    """
    setup_seed(3407)

    # RandomHorizontalFlip / RandomErasing run on the already-normalised tensor.
    train_transforms = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
        transforms.RandomHorizontalFlip(),
        transforms.RandomErasing(scale=(0.02, 0.25))
    ])

    eval_transforms = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    train_dataset = RafDataset(root_dir=args.raf_path, phase='Train', transform=train_transforms, apply_constraints=True)
    test_dataset = RafDataset(args.raf_path, phase='Test', transform=eval_transforms, apply_constraints=True)
    val_dataset = RafDataset(args.raf_path, phase='Valid', transform=eval_transforms, apply_constraints=True)

    train_loader = _make_loader(train_dataset, True, args.batch_size, args.workers)
    val_loader = _make_loader(val_dataset, False, args.batch_size, args.workers)
    test_loader = _make_loader(test_dataset, False, args.batch_size, args.workers)

    model = Model(num_classes=8)

    device = torch.device('cuda:{}'.format(args.gpu))
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.0002, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

    best_val_acc = 0.0  # best validation accuracy seen so far
    best_model_path = "best_model.pth"
    # Guards against reloading a stale best_model.pth left by an earlier run.
    saved_this_run = False

    for i in range(1, args.epochs + 1):
        train_acc, train_loss = train(args, model, train_loader, optimizer, scheduler, device)
        val_acc, val_loss = test(model, val_loader, device)
        print(f"Epoch: {i}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Validation Loss: {val_loss:.4f}, Validation Acc: {val_acc:.4f}")

        # Save the model whenever validation accuracy reaches a new best.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({'model_state_dict': model.state_dict()}, best_model_path)
            saved_this_run = True
            print(f"[INFO] New Best Model Saved with Validation Accuracy: {best_val_acc:.4f}")

        with open('results.txt', 'a') as f:
            f.write(f"Epoch: {i}, Validation Accuracy: {val_acc:.4f}, Validation Loss: {val_loss:.4f}\n")

    print("\n[Final Evaluation on Test Set with Best Model]")
    if saved_this_run:
        # Restore the best checkpoint onto the training device.
        checkpoint = torch.load(best_model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        print("[WARN] No checkpoint was saved during this run; evaluating the final model.")

    test_acc, test_loss = test(model, test_loader, device)
    print(f"Best Validation Accuracy: {best_val_acc:.4f}")
    print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}")

    with open('results.txt', 'a') as f:
        f.write(f"Best Validation Accuracy: {best_val_acc:.4f}\n")
        f.write(f"Final Test Accuracy: {test_acc:.4f}, Test Loss: {test_loss:.4f}\n")


In [3]:
# Run the full pipeline: train, select the best epoch on Valid, evaluate on Test.
main()

Pre-trained model not found. Using torchvision's pre-trained model.
weight : torch.Size([8, 512])
bias : torch.Size([8])




Epoch: 1, Train Loss: 8.6520, Train Acc: 0.6887, Validation Loss: 0.6935, Validation Acc: 0.8132
[INFO] New Best Model Saved with Validation Accuracy: 0.8132
Epoch: 2, Train Loss: 8.1380, Train Acc: 0.8074, Validation Loss: 0.5382, Validation Acc: 0.8402
[INFO] New Best Model Saved with Validation Accuracy: 0.8402
Epoch: 3, Train Loss: 7.9757, Train Acc: 0.8341, Validation Loss: 0.4763, Validation Acc: 0.8499
[INFO] New Best Model Saved with Validation Accuracy: 0.8499
Epoch: 4, Train Loss: 7.8919, Train Acc: 0.8510, Validation Loss: 0.4353, Validation Acc: 0.8662
[INFO] New Best Model Saved with Validation Accuracy: 0.8662
Epoch: 5, Train Loss: 7.8323, Train Acc: 0.8652, Validation Loss: 0.4181, Validation Acc: 0.8709
[INFO] New Best Model Saved with Validation Accuracy: 0.8709
Epoch: 6, Train Loss: 7.7724, Train Acc: 0.8805, Validation Loss: 0.3916, Validation Acc: 0.8768
[INFO] New Best Model Saved with Validation Accuracy: 0.8768
Epoch: 7, Train Loss: 7.7253, Train Acc: 0.8900, Val