In [1]:
import os
import cv2
import csv
import math
import random
import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
from torchvision import transforms
import torch.utils.data as data
import torch.nn.functional as F

import sys
# Make the parent directory importable.
# NOTE(review): presumably this is where local packages such as
# `depth_anything_v2` live - confirm against the repository layout.
sys.path.append('../')

import clip
# Module-level device: GPU when available, otherwise CPU. Mask() reads this global.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# CLIP ViT-B/32; `clip_model` is passed into Model.forward by train()/test().
# `preprocess` is not referenced anywhere else in the visible cells.
clip_model, preprocess = clip.load("ViT-B/32", device=device)

from torch.nn.modules.module import Module
from torch.nn.modules.utils import _pair
from torch.nn.parameter import Parameter

from depth_anything_v2.dpt import DepthAnythingV2
import warnings
# Silence UserWarnings whose message mentions the "antialias parameter".
warnings.filterwarnings(
    "ignore",
    category=UserWarning,
    message=".*antialias parameter.*"
)

import timm

CONVNEXT_NAME = "convnext_large"  # backbone
# NOTE(review): same expression as the `device` assignment above; this re-assignment is redundant.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the depth model (adjust the checkpoint path to your environment).
depth_model = DepthAnythingV2(encoder='vitl', features=256, out_channels=[256, 512, 1024, 1024])
depth_model.load_state_dict(torch.load('/home/work/dhkim/fer/Depth-Anything-V2/checkpoints/depth_anything_v2_vitl.pth', map_location=device))
depth_model = depth_model.to(device)
# Inference only: generate_or_load_depth_map() calls depth_model.infer_image under no_grad.
depth_model.eval()

class my_MaxPool2d(Module):
    """Max pooling applied after swapping dims 1 and 3 of the input.

    ``forward`` transposes dims 1 and 3, runs ``F.max_pool2d`` and transposes
    back, so the second entry of ``kernel_size``/``stride`` acts on the
    original dim 1 (the channel axis of an ``(N, C, H, W)`` tensor).

    Args:
        kernel_size: pooling window, forwarded to ``F.max_pool2d``.
        stride: pooling stride; falls back to ``kernel_size`` when falsy.
        padding, dilation, ceil_mode: forwarded to ``F.max_pool2d``.
        return_indices: when True, ``forward`` returns ``(output, indices)``.
    """

    def __init__(self, kernel_size, stride=None, padding=0, dilation=1,
                 return_indices=False, ceil_mode=False):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride or kernel_size
        self.padding = padding
        self.dilation = dilation
        self.return_indices = return_indices
        self.ceil_mode = ceil_mode

    def forward(self, input):
        """Pool ``input`` and return a contiguous tensor in the original dim order.

        Returns:
            The pooled tensor, or ``(pooled, indices)`` when ``return_indices``
            is set. The indices are the ones produced by ``F.max_pool2d`` on
            the transposed tensor (flat positions within its last two dims),
            transposed back the same way as the values.
        """
        swapped = input.transpose(3, 1)
        result = F.max_pool2d(swapped, self.kernel_size, self.stride,
                              self.padding, self.dilation, self.ceil_mode,
                              self.return_indices)
        if self.return_indices:
            # F.max_pool2d returns a (values, indices) tuple in this mode;
            # transposing the tuple itself would raise AttributeError.
            pooled, indices = result
            return (pooled.transpose(3, 1).contiguous(),
                    indices.transpose(3, 1).contiguous())
        return result.transpose(3, 1).contiguous()

class my_AvgPool2d(Module):
    """Average pooling applied after swapping dims 1 and 3 of the input.

    The input is transposed (dims 1 <-> 3), passed through ``F.avg_pool2d``
    and transposed back, so the second entry of ``kernel_size``/``stride``
    acts on the original dim 1.

    Args:
        kernel_size: pooling window, forwarded to ``F.avg_pool2d``.
        stride: pooling stride; falls back to ``kernel_size`` when falsy.
        padding, ceil_mode, count_include_pad: forwarded to ``F.avg_pool2d``.
    """

    def __init__(self, kernel_size, stride=None, padding=0, ceil_mode=False,
                 count_include_pad=True):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride if stride else kernel_size
        self.padding = padding
        self.ceil_mode = ceil_mode
        self.count_include_pad = count_include_pad

    def forward(self, input):
        """Return the pooled tensor, contiguous and in the original dim order."""
        swapped = input.transpose(3, 1)
        pooled = F.avg_pool2d(
            swapped,
            self.kernel_size,
            self.stride,
            self.padding,
            self.ceil_mode,
            self.count_include_pad,
        )
        return pooled.transpose(3, 1).contiguous()

def _depth_map_path_for(image_path):
    """Return the cache path ``<image path without extension>_depth_map.png``."""
    stem, _ext = os.path.splitext(image_path)
    return stem + '_depth_map.png'


def generate_or_load_depth_map(image_path):
    """Return the 8-bit depth map for ``image_path``, computing and caching it if needed.

    The depth map is cached next to the image as ``<stem>_depth_map.png``.
    The cache name is derived with ``os.path.splitext`` so that any extension
    (``.jpeg``, upper-case, none) gets a distinct cache file; a plain
    ``str.replace`` on ``.png``/``.jpg`` left such paths unchanged, which made
    the image itself be read back as its own "depth map".

    Args:
        image_path (str): path of the source image.

    Returns:
        numpy.ndarray: single-channel ``uint8`` depth map.

    Raises:
        FileNotFoundError: if the source image cannot be read.
    """
    depth_map_path = _depth_map_path_for(image_path)

    if os.path.exists(depth_map_path):
        cached = cv2.imread(depth_map_path, cv2.IMREAD_GRAYSCALE)
        if cached is not None:
            return cached
        # Cache file exists but could not be decoded: fall through and rebuild it.

    # The image is read as grayscale and expanded to 3 channels before inference.
    raw_img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if raw_img is None:
        raise FileNotFoundError(f"Input image not found at the specified path: {image_path}")

    raw_img_rgb = cv2.cvtColor(raw_img, cv2.COLOR_GRAY2RGB)
    with torch.no_grad():
        depth = depth_model.infer_image(raw_img_rgb)

    # Min-max normalize to [0, 255]; a constant depth map would divide by zero,
    # so it is mapped to all zeros instead.
    depth_min = depth.min()
    depth_range = depth.max() - depth_min
    if depth_range > 0:
        normalized_depth = (depth - depth_min) / depth_range * 255.0
    else:
        normalized_depth = np.zeros_like(depth)
    depth_map = normalized_depth.astype(np.uint8)

    cv2.imwrite(depth_map_path, depth_map)
    return depth_map

class RafDataset(data.Dataset):
    """Dataset of RGB + depth samples labelled from a per-phase vote-count CSV.

    ``<root_dir>/<phase>/label.csv`` is read without a header and interpreted
    as ``filename, bbox, c0..c9`` where ``c0..c9`` are integer vote counts.
    The label of a row is the argmax over the ten counts; only rows whose
    label is in ``0..7`` are kept.

    Args:
        root_dir (str): dataset root containing one sub-directory per phase.
        phase (str): sub-directory name (e.g. ``'Train'``).
        transform (callable, optional): applied to the ``(H, W, 4)`` array.
        apply_constraints (bool): when True, single votes (count == 1) are
            zeroed and only rows where the top class holds a strict majority
            of the remaining votes are kept.
    """

    def __init__(self, root_dir, phase, transform=None, apply_constraints=False):
        self.root_dir = root_dir
        self.phase = phase
        self.transform = transform

        label_path = os.path.join(root_dir, phase, 'label.csv')
        df = pd.read_csv(label_path, header=None)

        count_cols = [f'c{i}' for i in range(10)]
        df.columns = ['filename', 'bbox'] + count_cols
        counts = df[count_cols].astype(int)

        if apply_constraints:
            # Drop single votes (vectorized; replaces the deprecated element-wise applymap).
            counts = counts.mask(counts == 1, 0)
            max_counts = counts.max(axis=1)
            total_votes = counts.sum(axis=1)
            # Strict majority; rows with no remaining votes (0 > 0) are dropped too.
            valid_mask = max_counts > (total_votes / 2)
            df = df[valid_mask].reset_index(drop=True)
            counts = counts[valid_mask].reset_index(drop=True)

        df['label'] = counts.values.argmax(axis=1)
        # Keep only the first eight classes; argmax 8 or 9 rows are discarded.
        valid_labels_mask = (df['label'] >= 0) & (df['label'] <= 7)
        df = df[valid_labels_mask].reset_index(drop=True)
        self.file_paths = df['filename'].values
        self.labels = df['label'].values.astype(np.int64)

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Return ``(image_with_depth, label, idx, image_with_depth)``.

        The same (transformed) sample is returned in the first and last slot.

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        label = self.labels[idx]
        img_filename = self.file_paths[idx]
        img_path = os.path.join(self.root_dir, self.phase, img_filename)

        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread returns None instead of raising; fail with a clear message.
            raise FileNotFoundError(f"Image not found at path: {img_path}")
        image = image[:, :, ::-1]  # BGR -> RGB

        depth_map = generate_or_load_depth_map(img_path)
        depth_map = np.expand_dims(depth_map, axis=2)
        image_with_depth = np.concatenate((image, depth_map), axis=2)  # (H,W,4)

        if self.transform is not None:
            image_with_depth = self.transform(image_with_depth)

        return image_with_depth, label, idx, image_with_depth

class RAFDBTestDataset(data.Dataset):
    """RAF-DB test images (one sub-directory per RAF-DB class id) with FERPlus-style labels."""

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (str): root directory of the RAF-DB test split,
                            e.g. '/path/to/RAF-DB/DATASET/test'. It is expected
                            to contain sub-directories named '1'..'7'.
            transform (callable, optional): transform applied to each sample.
        """
        self.root_dir = root_dir
        self.transform = transform
        # RAF-DB class id (directory name) -> FERPlus class index.
        # FERPlus order: 0 neutral, 1 happiness, 2 surprise, 3 sadness,
        # 4 anger, 5 disgust, 6 fear, 7 contempt. Sadness and anger were
        # previously swapped (5 -> 4, 6 -> 3).
        # NOTE(review): assumes the training labels follow the standard
        # FERPlus column order - confirm against the label.csv in use.
        self.label_mapping = {
            '1': 2,  # surprise  -> FERPlus label 2
            '2': 6,  # fear      -> FERPlus label 6
            '3': 5,  # disgust   -> FERPlus label 5
            '4': 1,  # happiness -> FERPlus label 1
            '5': 3,  # sadness   -> FERPlus label 3
            '6': 4,  # anger     -> FERPlus label 4
            '7': 0   # neutral   -> FERPlus label 0
        }
        self.file_paths = []
        self.labels = []

        # Walk each class directory and collect (path, label) pairs.
        for label_str, mapped_label in self.label_mapping.items():
            label_dir = os.path.join(root_dir, label_str)
            if not os.path.isdir(label_dir):
                continue
            # os.listdir order is arbitrary; sort so sample indices are reproducible.
            for fname in sorted(os.listdir(label_dir)):
                # Skip cached depth maps (*_depth_map.png).
                if '_depth_map' in fname:
                    continue

                # Keep image files only.
                if fname.lower().endswith(('.jpg', '.jpeg', '.png')):
                    self.file_paths.append(os.path.join(label_dir, fname))
                    self.labels.append(mapped_label)

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        """
        Returns:
            image_with_depth: the sample (RGB + depth, 4 channels) after ``transform``.
            label (int): FERPlus class index (0-6; RAF-DB has no contempt class).
            idx (int): sample index.
            image_with_depth: the same object again (no flipped copy is made at test time).

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        img_path = self.file_paths[idx]
        label = self.labels[idx]

        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f"Image not found at path: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # BGR -> RGB

        # Load the cached depth map or generate it.
        depth_map = generate_or_load_depth_map(img_path)  # H x W
        depth_map = np.expand_dims(depth_map, axis=2)  # H x W x 1

        # Stack image and depth map along the channel axis.
        image_with_depth = np.concatenate((image, depth_map), axis=2)  # H x W x 4

        if self.transform:
            image_with_depth = self.transform(image_with_depth)

        # No horizontal flip is needed at test time, so the same sample is returned twice.
        return image_with_depth, label, idx, image_with_depth


def Mask(nb_batch, num_classes=8, total_channels=3072):
    """Build a random binary channel mask of shape ``(nb_batch, total_channels, 1, 1)``.

    ``total_channels`` is split into ``num_classes`` counts (the first
    ``total_channels % num_classes`` classes get one extra). For every batch
    row and every class, that many channel indices are drawn with
    ``random.sample`` from the full ``range(total_channels)`` and set to 1;
    draws of different classes may overlap.

    Returns:
        torch.Tensor: float32 mask on the module-level ``device``.
    """
    base, extra = divmod(total_channels, num_classes)
    per_class = [base + 1 if k < extra else base for k in range(num_classes)]

    mask = np.zeros((nb_batch, total_channels), dtype="float32")
    for row in range(nb_batch):
        for how_many in per_class:
            picked = random.sample(range(total_channels), how_many)
            mask[row, np.asarray(picked, dtype=np.intp)] = 1.0

    mask = mask.reshape(nb_batch, total_channels, 1, 1)
    return torch.from_numpy(mask).to(device)

def supervisor(x, targets, cnum):
    """Compute the two auxiliary channel-group losses for a feature batch.

    ``x`` is reshaped to ``(B, C, 1, 1)`` and max-pooled over consecutive
    groups of ``cnum`` channels, giving one value per group.

    Args:
        x: feature tensor whose first two dims are batch and channels.
        targets: class indices for the cross-entropy term.
        cnum: number of channels per group.

    Returns:
        list: ``[loss_1, loss_2]`` where ``loss_1`` is the cross-entropy of the
        group maxima of the randomly masked features against ``targets`` and
        ``loss_2`` is ``1 - mean(group maxima) / cnum`` on the unmasked features.
    """
    batch, channels = x.size(0), x.size(1)
    group_max = my_MaxPool2d(kernel_size=(1, cnum), stride=(1, cnum))
    feats = x.reshape(batch, channels, 1, 1)

    # Unmasked branch: per-group maxima -> scalar loss_2.
    pooled = group_max(feats)
    pooled = pooled.reshape(pooled.size(0), pooled.size(1), -1)
    loss_2 = 1.0 - torch.mean(torch.sum(pooled, 2)) / cnum

    # Masked branch: random channel mask, then per-group maxima used as logits.
    channel_mask = Mask(batch, num_classes=8, total_channels=channels)
    masked_pooled = group_max(feats * channel_mask)
    class_scores = masked_pooled.view(masked_pooled.size(0), -1)
    loss_1 = nn.CrossEntropyLoss()(class_scores, targets)

    return [loss_1, loss_2]

class Model(nn.Module):
    """Two-branch ConvNeXt classifier over RGB + depth, gated by CLIP image features.

    An RGB backbone and a depth backbone (same timm architecture, stem adapted
    to one input channel) each produce a pooled feature vector. Both vectors
    are scaled by sigmoid gates computed from 512-d CLIP image features,
    concatenated, passed through dropout and classified by a linear layer.

    Args:
        pretrained (bool): load pretrained timm weights for both backbones.
        num_classes (int): number of output classes.
        drop_rate (float): accepted for interface compatibility; not used
            (dropout probability is fixed at 0.2).
        convnext_name (str): timm model name used for both backbones.
    """

    def __init__(self, pretrained=True, num_classes=8, drop_rate=0.0, convnext_name=CONVNEXT_NAME):
        super().__init__()
        self.rgb_backbone = timm.create_model(convnext_name, pretrained=pretrained, num_classes=0)
        feature_dim = self.rgb_backbone.num_features  # 1536 for convnext_large

        self.depth_backbone = timm.create_model(convnext_name, pretrained=pretrained, num_classes=0)
        # Adapt the stem to a single input channel by averaging the existing
        # kernel over its input-channel axis. The replacement conv is created
        # with bias=False, so any bias of the original stem conv is not carried over.
        depth_conv_weight = self.depth_backbone.stem[0].weight.data.mean(dim=1, keepdim=True)
        out_ch = depth_conv_weight.shape[0]
        self.depth_backbone.stem[0] = nn.Conv2d(1, out_ch, kernel_size=4, stride=4, padding=0, bias=False)
        self.depth_backbone.stem[0].weight.data = depth_conv_weight

        # CLIP features (512-d) -> sigmoid gate over the depth features.
        self.depth_attention = nn.Sequential(
            nn.Linear(512, feature_dim // 2),
            nn.ReLU(inplace=True),
            nn.Linear(feature_dim // 2, feature_dim),
            nn.Sigmoid()
        )

        # CLIP features (512-d) -> pre-sigmoid gate over the RGB features.
        self.image_proj = nn.Linear(512, feature_dim)
        self.dropout = nn.Dropout(p=0.2)
        self.fc = nn.Linear(feature_dim*2, num_classes)

    def forward(self, x, clip_model, targets, phase='train'):
        """Run the model on a 4-channel batch.

        Args:
            x: tensor ``(B, 4, H, W)``; channels 0-2 are RGB, channel 3 is depth.
            clip_model: model exposing ``encode_image``; called under ``no_grad``.
            targets: class indices, used only for the auxiliary loss in training.
            phase (str): ``'train'`` also computes the auxiliary losses; any
                other value returns the logits only.

        Returns:
            ``(logits, [loss_1, loss_2])`` when ``phase == 'train'``, otherwise
            ``(logits, logits)``.
        """
        rgb_image = x[:, :3, :, :]
        depth_map = x[:, 3:, :, :]

        rgb_for_clip = F.interpolate(rgb_image, size=(224,224), mode='bicubic', align_corners=False)

        with torch.no_grad():
            image_features = clip_model.encode_image(rgb_for_clip)  # (B,512)

        image_features = image_features.float()
        rgb_feats = self.rgb_backbone(rgb_image)        # (B,feature_dim)
        depth_feats = self.depth_backbone(depth_map)    # (B,feature_dim)

        depth_att = self.depth_attention(image_features)  # (B,feature_dim)
        depth_feats = depth_feats * depth_att

        rgb_gate_logits = self.image_proj(image_features)  # (B,feature_dim)
        rgb_feats = rgb_feats * torch.sigmoid(rgb_gate_logits)

        combined_features = torch.cat((rgb_feats, depth_feats), dim=1)  # (B, 2*feature_dim)
        combined_features = self.dropout(combined_features)

        if phase == 'train':
            # One channel group per class. For the default convnext_large
            # (3072 fused channels) and 8 classes this is 384, the value that
            # used to be hard-coded and broke every other backbone width.
            # Note: supervisor() still builds its random mask for 8 classes.
            cnum = combined_features.size(1) // self.fc.out_features
            MC_loss = supervisor(combined_features, targets, cnum=cnum)
        out = self.fc(combined_features)

        if phase == 'train':
            return out, MC_loss
        else:
            return out, out

def mixup_data(x, y, alpha=1.0):
    """Blend each sample with a randomly chosen partner from the same batch.

    Args:
        x: input batch (first dim is the batch dim).
        y: labels aligned with ``x``.
        alpha: Beta(alpha, alpha) parameter; ``alpha <= 0`` disables mixing.

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` where ``mixed_x = lam * x + (1 - lam) * x[perm]``,
        ``y_a`` is ``y`` and ``y_b`` is ``y[perm]``. With mixing disabled the
        inputs are returned unchanged with ``lam = 1.0``.
    """
    if alpha <= 0:
        return x, y, y, 1.0

    mix_weight = np.random.beta(alpha, alpha)
    perm = torch.randperm(x.size(0)).to(x.device)
    partner_x = x[perm, :]
    blended = mix_weight * x + (1 - mix_weight) * partner_x
    return blended, y, y[perm], mix_weight

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Return ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``."""
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b

def setup_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices).

    Also switches cuDNN to deterministic mode.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

class SplitApplyTransform:
    """Apply separate transforms to the RGB and depth parts of a channel-first tensor.

    The first three channels are treated as RGB and the remaining channel(s)
    as depth; the two (optionally transformed) parts are concatenated again
    along dim 0.
    """

    def __init__(self, rgb_transform=None, depth_transform=None):
        self.rgb_transform = rgb_transform
        self.depth_transform = depth_transform

    @staticmethod
    def _maybe_apply(transform, part):
        """Return ``transform(part)``, or ``part`` unchanged when no transform is set."""
        return part if transform is None else transform(part)

    def __call__(self, img):
        # Expects the (C,H,W) tensor produced by ToTensor() from an (H,W,4) array.
        if not isinstance(img, torch.Tensor):
            raise TypeError("SplitApplyTransform expects a torch.Tensor after ToTensor.")

        rgb_part, depth_part = img[:3, :, :], img[3:, :, :]
        rgb_part = self._maybe_apply(self.rgb_transform, rgb_part)
        depth_part = self._maybe_apply(self.depth_transform, depth_part)
        return torch.cat((rgb_part, depth_part), dim=0)

class Args:
    """Plain container for the run configuration used by the notebook."""

    def __init__(self):
        # Dataset locations (absolute paths of the original environment).
        self.raf_path = '/home/work/dhkim/fer/FERPlus/data'
        self.rafdb_test_path = '/home/work/dhkim/fer/Depth-Anything-V2/raf-db/DATASET/test/'
        self.label_path = 'list_patition_label.txt'

        # DataLoader settings.
        self.workers = 2
        self.batch_size = 64

        # Not read by the code visible in this notebook.
        self.w = self.h = 7

        # CUDA device index used by main().
        self.gpu = 0
        self.lam = 5

        # Training schedule and mixup strength.
        self.epochs = 20
        self.mixup_alpha = 0.2


args = Args()

# RandAugment is PIL-based, so each part of the sample is converted to PIL,
# resized (and, for training RGB only, augmented), then converted back to a tensor.
# Per-channel normalization statistics for the 4 channels (RGB + depth).
_NORM_MEAN = [0.485, 0.456, 0.406, 0.5]
_NORM_STD = [0.229, 0.224, 0.225, 0.5]


def _resize_pipeline(*extra_ops):
    """Build ToPILImage -> Resize(256x256) -> ``extra_ops`` -> ToTensor."""
    return transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((256,256), antialias=True),
        *extra_ops,
        transforms.ToTensor()
    ])


# Training: RandAugment is applied to the RGB channels only; depth is just resized.
rgb_transform = _resize_pipeline(transforms.RandAugment(num_ops=2, magnitude=10))
depth_transform = _resize_pipeline()

# Flip and erasing run after the channels are re-joined, so they act on all four channels.
train_transforms = transforms.Compose([
    transforms.ToTensor(),
    SplitApplyTransform(rgb_transform=rgb_transform, depth_transform=depth_transform),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    transforms.RandomHorizontalFlip(),
    transforms.RandomErasing(scale=(0.02, 0.25))
])

# Evaluation: resize and normalize only.
eval_transforms = transforms.Compose([
    transforms.ToTensor(),
    SplitApplyTransform(
        rgb_transform=_resize_pipeline(),
        depth_transform=_resize_pipeline()
    ),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD)
])

def train(args, model, train_loader, optimizer, scheduler, device, scaler, criterion):
    """Run one training epoch with mixup and mixed precision.

    Each batch is mixed with ``mixup_data`` (strength ``args.mixup_alpha``);
    the loss is the mixup criterion plus ``5 * MC_loss[1] + 1.5 * MC_loss[0]``
    from the model's auxiliary losses. The scheduler is stepped once at the
    end of the epoch.

    Returns:
        ``(acc, running_loss)``: accuracy of the predictions against the
        first mixup label set (``y_a``) over ``len(train_loader.dataset)``,
        and the mean per-batch loss.
    """
    model.to(device)
    model.train()

    loss_total = 0.0
    num_batches = 0
    num_correct = 0

    for images, labels, _indexes, _images_again in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        images, y_a, y_b, lam = mixup_data(images, labels, alpha=args.mixup_alpha)

        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            # The module-level clip_model supplies the CLIP features inside the model.
            output, MC_loss = model(images, clip_model, y_a, phase='train')
            cls_loss = mixup_criterion(criterion, output, y_a, y_b, lam)
            loss = cls_loss + 5 * MC_loss[1] + 1.5 * MC_loss[0]

        # Scaled backward / step / update, in the order GradScaler requires.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        num_batches += 1
        predicts = torch.max(output, 1)[1]
        num_correct += torch.eq(predicts, y_a).sum()
        loss_total += loss.item()

    scheduler.step()
    mean_loss = loss_total / num_batches
    acc = num_correct.float() / float(len(train_loader.dataset))
    return acc, mean_loss

def test(model, test_loader, device, criterion):
    with torch.no_grad():
        model.eval()

        running_loss = 0.0
        iter_cnt = 0
        correct_sum = 0
        data_num = 0

        for batch_i, (imgs1, labels, indexes, imgs2) in enumerate(test_loader):
            imgs1 = imgs1.to(device)
            labels = labels.to(device)

            with torch.cuda.amp.autocast():
                outputs, _ = model(imgs1, clip_model, labels, phase='Test')
                loss = criterion(outputs, labels)

            iter_cnt += 1
            _, predicts = torch.max(outputs, 1)
            correct_num = torch.eq(predicts, labels).sum()
            correct_sum += correct_num

            running_loss += loss.item()
            data_num += outputs.size(0)

        running_loss = running_loss / iter_cnt
        test_acc = correct_sum.float() / float(data_num)
    return test_acc, running_loss

def main():
    """Train the model for ``args.epochs`` epochs and report metrics.

    After every epoch the validation metrics are printed, the weights are
    saved to ``ours_final.pth`` and, when validation accuracy improves, to
    ``ours_best.pth``. A final test-set evaluation of the last-epoch weights
    is printed at the end.
    """
    setup_seed(3407)

    def _split(phase, transform):
        """Build the dataset for one phase with the vote constraints enabled."""
        return RafDataset(root_dir=args.raf_path, phase=phase, transform=transform, apply_constraints=True)

    def _loader(dataset, shuffle):
        """Wrap a dataset in a DataLoader using the shared batch/worker settings."""
        return torch.utils.data.DataLoader(dataset,
                                           batch_size=args.batch_size,
                                           shuffle=shuffle,
                                           num_workers=args.workers,
                                           pin_memory=False)

    train_dataset = _split('Train', train_transforms)
    test_dataset = _split('Test', eval_transforms)
    val_dataset = _split('Valid', eval_transforms)

    train_loader = _loader(train_dataset, shuffle=True)
    val_loader = _loader(val_dataset, shuffle=False)
    test_loader = _loader(test_dataset, shuffle=False)

    model = Model(num_classes=8, convnext_name=CONVNEXT_NAME)
    device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() else 'cpu')
    model.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=1)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    scaler = torch.cuda.amp.GradScaler()

    best_acc = 0
    for epoch in range(1, args.epochs + 1):
        train_acc, train_loss = train(args, model, train_loader, optimizer, scheduler, device, scaler, criterion)
        val_acc, val_loss = test(model, val_loader, device, criterion)
        # NOTE(review): the test split is evaluated every epoch but the result
        # is neither printed nor used for model selection.
        test(model, test_loader, device, criterion)

        print(f"Epoch: {epoch}")
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

        weights = {'model_state_dict': model.state_dict()}
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(weights, "ours_best.pth")
        # The latest weights are always written, regardless of validation accuracy.
        torch.save(weights, "ours_final.pth")

    print("\n[Final Evaluation on FERPlus Test Set]")
    test_acc, test_loss = test(model, test_loader, device, criterion)
    print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}")


if __name__ == '__main__':
    main()


RuntimeError: module compiled against API version 0x10 but this version of numpy is 0xf . Check the section C-API incompatibility at the Troubleshooting ImportError section at https://numpy.org/devdocs/user/troubleshooting-importerror.html#c-api-incompatibility for indications on how to solve this problem .

xFormers not available
xFormers not available


Epoch: 1
Train Loss: 8.5534, Train Acc: 0.4846
Val Loss: 0.8900, Val Acc: 0.8242
Epoch: 2
Train Loss: 8.1503, Train Acc: 0.5237
Val Loss: 0.8119, Val Acc: 0.8533
Epoch: 3
Train Loss: 7.8923, Train Acc: 0.5829
Val Loss: 0.7879, Val Acc: 0.8571
Epoch: 4
Train Loss: 7.9463, Train Acc: 0.5460
Val Loss: 0.7403, Val Acc: 0.8793
Epoch: 5
Train Loss: 7.8081, Train Acc: 0.5852
Val Loss: 0.7429, Val Acc: 0.8740
Epoch: 6
Train Loss: 7.7623, Train Acc: 0.5918
Val Loss: 0.7249, Val Acc: 0.8847
Epoch: 7
Train Loss: 7.7938, Train Acc: 0.5689
Val Loss: 0.7095, Val Acc: 0.8925
Epoch: 8
Train Loss: 7.7998, Train Acc: 0.5567
Val Loss: 0.7132, Val Acc: 0.8878
Epoch: 9
Train Loss: 7.6463, Train Acc: 0.6106
Val Loss: 0.7100, Val Acc: 0.8894
Epoch: 10
Train Loss: 7.6779, Train Acc: 0.5915
Val Loss: 0.7096, Val Acc: 0.8894
Epoch: 11
Train Loss: 7.8156, Train Acc: 0.5658
Val Loss: 0.7259, Val Acc: 0.8822
Epoch: 12
Train Loss: 7.6907, Train Acc: 0.6058
Val Loss: 0.7236, Val Acc: 0.8828
Epoch: 13
Train Loss: 7.6

In [4]:
# Re-evaluate the best checkpoint ("ours_best.pth") on the 'Test' split.
# This cell relies on names defined in the first cell (args, eval_transforms,
# RafDataset, RAFDBTestDataset, Model, test, CONVNEXT_NAME).
test_dataset = RafDataset(args.raf_path, phase='Test', transform=eval_transforms, apply_constraints=True)

test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=False)



# RAF-DB pipeline: the 4-channel sample is resized as a tensor, then normalized.
# NOTE(review): this resizes to 224x224 while eval_transforms uses 256x256 - confirm that is intended.
rafdb_test_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224, 224), antialias=True),
        transforms.Normalize(mean=[0.485, 0.456, 0.406, 0.5],
                             std=[0.229, 0.224, 0.225, 0.5])
    ])
# NOTE(review): rafdb_test_dataset / rafdb_test_loader are built here but not
# evaluated anywhere in this cell.
rafdb_test_dataset = RAFDBTestDataset(root_dir=args.rafdb_test_path, transform=rafdb_test_transforms)
rafdb_test_loader = torch.utils.data.DataLoader(rafdb_test_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=False,
                                                   num_workers=0,  
                                                   pin_memory=False)

# Rebuild the architecture and restore the weights saved by main().
model = Model(num_classes=8, convnext_name=CONVNEXT_NAME)
device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() else 'cpu')
checkpoint = torch.load("ours_best.pth", map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
# NOTE(review): the optimizer is created but not used in this cell.
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# test() already switches to eval mode and disables gradients itself, so the
# eval()/no_grad() here are redundant but harmless.
model.eval()
with torch.no_grad():
    test_acc, test_loss = test(model, test_loader, device, criterion)
    print(f"[Re-Evaluation on FERPlus Test Set with Best Model]")
    print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}")


[Re-Evaluation on FERPlus Test Set with Best Model]
Test Loss: 0.7732, Test Accuracy: 0.8900
