# In [None]:
import os
import cv2
import torch
import pickle
import random
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.utils.data as data
from torchvision import transforms
import torch.nn.functional as F
from torch.autograd import Variable
from tqdm import tqdm
from PIL import Image  # PIL 추가
import sys
sys.path.append('/kaggle/input/eac2/pytorch/default/1/Erasing-Attention-Consistency/src')  


# Utils
def add_g(image_array, mean=0.0, var=30):
    """Return a uint8 copy of *image_array* with additive Gaussian noise.

    Args:
        image_array: numpy image (any shape); values are assumed to be on a
            0-255 scale since the result is clipped to that range.
        mean: mean of the Gaussian noise.
        var: variance of the noise (the standard deviation is ``var ** 0.5``).

    Returns:
        ``np.uint8`` array of the same shape, noise added and clipped to [0, 255].
    """
    sigma = var ** 0.5
    noise = np.random.normal(mean, sigma, image_array.shape)
    return np.clip(image_array + noise, 0, 255).astype(np.uint8)

def flip_image(image_array):
    """Return *image_array* mirrored left-to-right (OpenCV ``flipCode=1``)."""
    horizontal_flip_code = 1  # cv2.flip: 1 flips around the vertical axis
    return cv2.flip(image_array, horizontal_flip_code)

def setup_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU + all CUDA devices) RNGs with *seed*.

    Also sets ``torch.backends.cudnn.deterministic = True`` so cuDNN picks
    deterministic kernels.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    
def generate_flip_grid(w, h, device):
    """Build a ``grid_sample`` sampling grid that mirrors a (h, w) map left-right.

    Coordinates are normalised to [-1, 1] (corner-aligned, i.e. for use with
    ``align_corners=True``), and the x coordinate is negated so sampling with
    this grid flips the input horizontally.

    Args:
        w: width of the map to flip (must be > 1, otherwise 0/0 gives NaN).
        h: height of the map to flip (must be > 1 for the same reason).
        device: device on which to create the grid.

    Returns:
        float32 tensor of shape (1, 2, h, w); channel 0 holds the (negated) x
        coordinates and channel 1 the y coordinates.
    """
    col_idx = torch.arange(w).float().to(device)
    row_idx = torch.arange(h).float().to(device)

    # Map pixel index i in [0, n-1] to 2*i/(n-1) - 1 in [-1, 1]; negate x to mirror.
    x_coords = -(2 * col_idx / (w - 1) - 1)
    y_coords = 2 * row_idx / (h - 1) - 1

    x_plane = x_coords.view(1, -1).expand(h, -1)
    y_plane = y_coords.view(-1, 1).expand(-1, w)
    return torch.stack([x_plane, y_plane], dim=0).unsqueeze(0)

# Dataset
class FERPlusDataset(data.Dataset):
    """FER image dataset with train / valid / test modes.

    Items returned by ``__getitem__`` depend on ``phase``:

    * ``"train"``: ``(image, label, flipped_image)`` where ``flipped_image`` is
      the left-right mirror of the *already transformed* ``image`` (so any random
      augmentation such as RandomErasing is mirrored too, keeping the two views
      consistent for the flip-attention loss).
    * ``"valid"``: ``(image, label)``.
    * ``"test"``:  ``(image, file_name)``.

    Args:
        data_dir: directory containing the images.
        label_file: CSV for train/valid. Column 0 is the image file name
            (relative to ``data_dir``); columns 2 onward are per-class vote
            counts. Required unless ``phase == "test"``.
        phase: one of ``"train"``, ``"valid"``, ``"test"``.
        transform: optional callable applied to the RGB PIL image.

    Raises:
        ValueError: on an unknown ``phase`` or a missing ``label_file``.
    """

    _PHASES = ("train", "valid", "test")

    def __init__(self, data_dir, label_file=None, phase="train", transform=None):
        if phase not in self._PHASES:
            raise ValueError(f"phase must be one of {self._PHASES}, got {phase!r}")
        self.data_dir = data_dir
        self.phase = phase
        self.transform = transform
        # Toggled by enable_erasing()/disable_erasing(); __getitem__ does not read it.
        self.apply_erasing = False

        if phase == "test":
            # Test split: no label file; use every .png file in data_dir, sorted by name.
            self.image_paths = sorted(
                os.path.join(data_dir, fname)
                for fname in os.listdir(data_dir)
                if fname.endswith('.png') and os.path.isfile(os.path.join(data_dir, fname))
            )
            return

        if label_file is None:
            raise ValueError(f"label_file is required for phase {phase!r}")
        table = pd.read_csv(label_file)
        self.image_paths = table.iloc[:, 0].apply(lambda name: os.path.join(self.data_dir, name)).values
        # Raw vote counts are kept so _apply_constraints can filter on them.
        self.counts = table.iloc[:, 2:].values
        # Label = index of the column with the most votes (first one on ties).
        # NOTE(review): this argmax spans every column from the 3rd on; if those
        # include non-emotion columns (e.g. unknown / not-face, which
        # _apply_constraints assumes sit at count columns 8 and 9), labels can
        # exceed the classifier's class count — confirm against the CSV.
        self.labels = self.counts.argmax(axis=1)

    def _apply_constraints(self):
        """Drop ambiguous samples based on the vote counts (filters in place).

        Removes a sample when any of these hold:
          1. a column with the maximum vote count is count column 8 or 9
             (assumed to be 'unknown' / 'not-face' — TODO confirm layout);
          2. more than 3 columns share the maximum vote count;
          3. the maximum vote count is at most half of all votes.

        ``image_paths``, ``counts`` and ``labels`` are filtered together.
        Not called anywhere in this file.

        Raises:
            RuntimeError: for the test phase, which has no vote counts.
        """
        if self.phase == "test":
            raise RuntimeError("_apply_constraints needs vote counts; the test phase has none")

        max_counts = self.counts.max(axis=1)
        counts_eq_max = (self.counts == max_counts[:, None])
        constraint1_violation = counts_eq_max[:, [8, 9]].any(axis=1)

        # TODO: an originally planned constraint ("set labels with a single vote
        # to 0") was never implemented.

        num_max_labels = counts_eq_max.sum(axis=1)
        constraint2_violation = num_max_labels > 3

        total_votes = self.counts.sum(axis=1)
        constraint3_violation = max_counts <= (total_votes / 2)

        valid_samples = ~(
            constraint1_violation | constraint2_violation | constraint3_violation
        )

        self.image_paths = self.image_paths[valid_samples]
        self.counts = self.counts[valid_samples]
        self.labels = self.labels[valid_samples]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert("RGB")
        if self.transform is not None:
            image = self.transform(image)

        if self.phase == "test":
            return image, os.path.basename(image_path)

        label = torch.tensor(self.labels[idx], dtype=torch.long)
        if self.phase == "train":
            # Flip AFTER the (random) transform so both views share the same
            # augmentation and are exact mirrors of each other.
            return image, label, self._hflip(image)
        return image, label

    @staticmethod
    def _hflip(image):
        """Mirror *image* left-right.

        Tensors are flipped along their last dimension (the width for CHW
        tensors as produced by ``transforms.ToTensor``); anything else is
        treated as a PIL image.
        """
        if isinstance(image, torch.Tensor):
            return torch.flip(image, dims=[-1])
        return image.transpose(Image.FLIP_LEFT_RIGHT)

    @staticmethod
    def flip_image(image):
        """Flip an image horizontally with OpenCV."""
        return cv2.flip(image, 1)

    @staticmethod
    def add_noise(image, mean=0.0, var=30.0):
        """Add Gaussian noise to an image and clip the result to uint8 [0, 255]."""
        std = var ** 0.5
        noisy_image = image + np.random.normal(mean, std, image.shape)
        noisy_image = np.clip(noisy_image, 0, 255).astype(np.uint8)
        return noisy_image

    def apply_advanced_aug(self, image):
        """Apply ``self.advanced_aug`` (a PIL-image transform) to a numpy image.

        Raises:
            AttributeError: if ``advanced_aug`` has not been assigned; this class
                never sets it.
        """
        advanced_aug = getattr(self, "advanced_aug", None)
        if advanced_aug is None:
            raise AttributeError("apply_advanced_aug requires self.advanced_aug to be set")
        pil_image = transforms.ToPILImage()(image)  # numpy image -> PIL image
        return np.array(advanced_aug(pil_image))  # back to a numpy array

    def enable_erasing(self):
        """Set the ``apply_erasing`` flag (not read by this class)."""
        self.apply_erasing = True

    def disable_erasing(self):
        """Clear the ``apply_erasing`` flag (not read by this class)."""
        self.apply_erasing = False



# ACLoss
def ACLoss(att_map1, att_map2, grid_l, output):
    """Attention-consistency loss between two attention maps, one of them mirrored.

    ``att_map2`` is resampled with ``grid_l`` (a left-right flip grid) and
    compared to ``att_map1`` with mean-squared error, so the loss is zero when
    ``att_map2`` is exactly the mirror image of ``att_map1``.

    Args:
        att_map1: attention maps of the original images, shape (B, K, H, W).
        att_map2: attention maps of the flipped images, same shape.
        grid_l: flip grid of shape (1, 2, H, W) with (x, y) in channel order.
        output: any tensor whose first dimension is the batch size B (only its
            size is used).

    Returns:
        Scalar MSE loss tensor.
    """
    batch_size = output.size(0)
    # The grid is a constant: detach it (replaces the deprecated Variable wrapper),
    # broadcast it over the batch, and move (x, y) last as grid_sample expects.
    flip_grid_large = grid_l.detach().expand(batch_size, -1, -1, -1).permute(0, 2, 3, 1)
    att_map2_flip = F.grid_sample(att_map2, flip_grid_large, mode='bilinear', padding_mode='border', align_corners=True)
    return F.mse_loss(att_map1, att_map2_flip)

def RegularizationLoss(model, reg_type="L2", lambda_reg=1e-4):
    """Return ``lambda_reg`` times an L1 or L2 penalty over all of *model*'s parameters.

    Every parameter returned by ``model.parameters()`` is included (weights,
    biases, normalisation affine parameters alike).

    Args:
        model: module whose parameters are penalised.
        reg_type: ``"L2"`` for the sum of squares, ``"L1"`` for the sum of
            absolute values.
        lambda_reg: scale applied to the penalty.

    Returns:
        Scalar tensor (or ``0.0`` if the model has no parameters).

    Raises:
        ValueError: if ``reg_type`` is neither ``"L2"`` nor ``"L1"`` (previously
            this silently produced a zero penalty).
    """
    if reg_type not in ("L1", "L2"):
        raise ValueError(f"reg_type must be 'L1' or 'L2', got {reg_type!r}")

    reg_loss = 0.0
    for param in model.parameters():
        if reg_type == "L2":
            reg_loss = reg_loss + torch.sum(param ** 2)
        else:
            reg_loss = reg_loss + torch.sum(torch.abs(param))
    return lambda_reg * reg_loss

def SparsityLoss(feature_map, lambda_s=1e-3):
    """Return the L1 sparsity penalty ``lambda_s * sum(|feature_map|)``.

    The sum runs over every element (all batch items and dimensions), so the
    penalty grows with batch and feature-map size.
    """
    l1_norm = feature_map.abs().sum()
    return lambda_s * l1_norm


class Flatten(nn.Module):
    """Collapse every dimension after the first (batch) one into a single axis."""

    def forward(self, input):
        batch_size = input.size(0)
        return input.view(batch_size, -1)

def conv3x3(in_planes, out_planes, stride=1):
    """Bias-free 3x3 convolution with padding 1 (keeps spatial size when stride=1)."""
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )

class BasicBlock(nn.Module):
    """Residual block of two 3x3 convs with an identity or projected shortcut.

    Args:
        inplanes: input channel count.
        planes: output channel count (``expansion`` is 1, so output == planes).
        stride: stride of the first 3x3 conv.
        downsample: optional module applied to the input to build the shortcut
            when the residual branch changes shape.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        out += shortcut
        return self.relu(out)

class Bottleneck(nn.Module):
    """1x1 reduce -> 3x3 -> 1x1 expand (x4) residual block.

    The stride is applied on the first 1x1 conv (not on the 3x3 conv).

    Args:
        inplanes: input channel count.
        planes: width of the inner convs; the block outputs ``planes * 4``.
        stride: stride of the first 1x1 conv.
        downsample: optional module applied to the input to build the shortcut
            when the residual branch changes shape.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        out += shortcut
        return self.relu(out)


import torch
import torch.nn as nn
import math

class ResNet(nn.Module):
    """ResNet trunk assembled from *block* modules.

    Args:
        block: residual block class exposing ``expansion`` and accepting
            ``(inplanes, planes, stride, downsample)``.
        layers: number of blocks in each of the four stages.
        num_classes: output size of ``fc`` (the default 8631 presumably matches
            the pretrained checkpoint this trunk is loaded from — confirm).
        include_top: if False, ``forward`` returns the average-pooled feature
            map (not flattened) instead of ``fc`` logits.
    """

    def __init__(self, block, layers, num_classes=8631, include_top=True):
        super(ResNet, self).__init__()
        self.inplanes = 64  # running channel count, advanced by _make_layer
        self.include_top = include_top
        # When True, forward() erases the strongest layer3 activations (see dynamic_erasing_top_k).
        self.apply_erasing = False

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        # Fixed 7x7 window: yields a 1x1 map only when layer4's output is 7x7.
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # He-style init: N(0, sqrt(2 / (k*k*out_channels))) for convs; BN to identity.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack *blocks* residual blocks; the first may change stride/width via a 1x1 projection."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def dynamic_erasing_top_k(self, x, top_k=0.1):
        """Zero the largest ``top_k`` fraction of activations in each sample.

        The number erased per sample is ``int(top_k * numel_per_sample)``,
        capped at the per-sample element count. When that count is zero, *x* is
        returned unchanged; otherwise a new tensor shaped like *x* is returned.
        Ties are resolved by ``torch.topk``.

        Args:
            x: feature tensor whose first dimension is the batch.
            top_k: fraction (0..1) of activations to erase per sample.
        """
        flat = x.flatten(1)
        k = min(int(top_k * flat.size(1)), flat.size(1))
        if k <= 0:
            return x
        _, top_idx = flat.topk(k, dim=1)
        keep_mask = torch.ones_like(flat).scatter_(1, top_idx, 0.0)
        return (flat * keep_mask).view_as(x)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        # Dynamic Erasing: erase the top 10% of activations before layer4.
        if self.apply_erasing:
            x = self.dynamic_erasing_top_k(x, top_k=0.1)

        x = self.layer4(x)
        x = self.avgpool(x)

        if not self.include_top:
            return x

        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# Model
class Model(nn.Module):
    """Expression classifier on a ResNet-50 trunk that also returns class activation maps.

    Args:
        args: config object; only ``args.resnet50_path`` is read, and only when
            ``pretrained`` is True.
        pretrained: load the pickled ResNet-50 weights from ``args.resnet50_path``
            into the trunk before it is split up.
        num_classes: number of outputs of the new ``fc`` head.

    ``forward(x)`` returns ``(output, hm)``:
        output: logits of shape (B, num_classes).
        hm: class activation maps of shape (B, num_classes, H, W) — the layer4
            feature map weighted per class by the *detached* ``fc`` weights, so
            gradients from ``hm`` reach the trunk but not ``fc``.
    """

    def __init__(self, args, pretrained=True, num_classes=7):
        super(Model, self).__init__()
        # Only toggled by enable_erasing()/disable_erasing(); forward() does not read it.
        self.apply_erasing = False

        resnet50 = ResNet(Bottleneck, [3, 4, 6, 3])
        if pretrained:
            resnet50.load_state_dict(self._load_pickled_weights(args.resnet50_path))

        children = list(resnet50.children())
        # conv1 ... layer4 -> 2048-channel feature map.
        self.features = nn.Sequential(*children[:-2])
        # avgpool only; the trunk's own fc is dropped in favour of self.fc.
        self.features2 = nn.Sequential(*children[-2:-1])
        self.fc = nn.Linear(2048, num_classes)

    @staticmethod
    def _load_pickled_weights(path):
        """Load a pickled {name: numpy array} state dict and convert it to tensors."""
        # SECURITY: unpickling can execute arbitrary code — only load trusted checkpoint files.
        with open(path, 'rb') as f:
            state = pickle.load(f, encoding='latin1')
        return {key: torch.from_numpy(arr) for key, arr in state.items()}

    def forward(self, x):
        x = self.features(x)  # (B, 2048, H, W)
        feature = self.features2(x)
        feature = feature.view(feature.size(0), -1)
        output = self.fc(feature)

        # CAM for class k: sum_c w[k, c] * x[:, c]. Reading self.fc.weight (instead
        # of a positional parameter lookup and a hard-coded class count of 7) keeps
        # this correct for any num_classes, and einsum avoids materialising a
        # (B, K, 2048, H, W) intermediate.
        fc_weights = self.fc.weight.detach()
        hm = torch.einsum('bchw,kc->bkhw', x, fc_weights)
        return output, hm

    def enable_erasing(self):
        """Set the ``apply_erasing`` flag (currently not consumed by forward)."""
        self.apply_erasing = True

    def disable_erasing(self):
        """Clear the ``apply_erasing`` flag (currently not consumed by forward)."""
        self.apply_erasing = False

# Args Class
class Args:
    """Fixed run configuration for the Kaggle FER competition notebook.

    Paths point at the Kaggle input mounts. ``w``/``h`` are the size of the flip
    grid passed to ``generate_flip_grid`` during training and ``lam`` weights
    the attention-consistency loss term.
    """

    def __init__(self):
        root = "/kaggle/input/fer-competition"
        self.data_dir = root
        self.train_dir = os.path.join(root, "FER2013Train/FER2013Train")
        self.valid_dir = os.path.join(root, "FER2013Valid/FER2013Valid")
        self.test_dir = os.path.join(root, "FER2013Test/FER2013Test")
        self.train_labels = os.path.join(root, "train_label.csv")
        self.valid_labels = os.path.join(root, "valid_label.csv")
        self.resnet50_path = "/kaggle/input/resnet50/resnet50_ft_weight (2).pkl"
        # Data loading / optimisation settings.
        self.batch_size, self.workers, self.epochs = 32, 4, 20
        self.gpu = 0
        # Flip-grid size and weight of the attention-consistency loss.
        self.w, self.h = 7, 7
        self.lam = 5

# Training Function
def train(args, model, train_loader, optimizer, scheduler, device):
    """Run one training epoch, then step the LR scheduler once.

    Each batch ``(images, labels, flipped_images)`` is passed through *model*
    twice (original and flipped). The loss is cross-entropy on the original
    logits plus ``args.lam`` times ``ACLoss`` between the two attention maps.

    Args:
        args: config; reads ``w``, ``h`` (flip-grid size) and ``lam``.
        model: returns ``(logits, attention_maps)``.
        train_loader: yields ``(images, labels, flipped_images)`` batches.
        optimizer: optimizer stepped once per batch.
        scheduler: LR scheduler stepped once per call (i.e. per epoch).
        device: device the batches are moved to.

    Returns:
        ``(epoch_loss, epoch_acc)`` — sample-weighted mean loss and accuracy;
        ``(0.0, 0.0)`` if the loader yields no samples.
    """
    model.train()
    criterion = nn.CrossEntropyLoss()
    # The flip grid depends only on its size and device: build it once, not per batch.
    grid_l = generate_flip_grid(args.w, args.h, device)

    running_loss = 0.0
    correct_sum = 0
    total = 0

    progress_bar = tqdm(train_loader, desc="Training", leave=False)
    for imgs1, labels, flipped_imgs in progress_bar:
        imgs1 = imgs1.to(device)
        flipped_imgs = flipped_imgs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()

        output, hm1 = model(imgs1)
        _, hm2 = model(flipped_imgs)  # only the flipped attention maps are needed

        loss1 = criterion(output, labels)
        flip_loss_l = ACLoss(hm1, hm2, grid_l, output)
        loss = loss1 + args.lam * flip_loss_l
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        batch_size = imgs1.size(0)
        running_loss += loss_value * batch_size
        correct_sum += (output.argmax(dim=1) == labels).sum().item()
        total += batch_size

        progress_bar.set_postfix(loss=loss_value)

    scheduler.step()
    if total == 0:
        return 0.0, 0.0
    return running_loss / total, correct_sum / total

def create_submission_file(args, model, test_loader, device, output_file="submission.csv"):
    """Predict a class for every test sample and write an ``ID,Prediction`` CSV.

    IDs are assigned sequentially (0, 1, 2, ...) in the order *test_loader*
    yields samples; the second item of each batch (file names for the test
    split) is ignored. The CSV always has the ``ID`` and ``Prediction`` columns,
    even when the loader is empty. ``args`` is unused and kept only for
    interface compatibility.

    NOTE(review): the test dataset orders files by lexicographic path, so
    "10.png" comes before "2.png". If the competition IDs correspond to
    numeric file names, the sequential IDs here may not line up — confirm.
    """
    model.eval()
    predictions = []

    with torch.no_grad():
        for images, _ in tqdm(test_loader, desc="Generating Submission"):
            outputs, _ = model(images.to(device))
            predictions.extend(outputs.argmax(dim=1).cpu().tolist())

    submission_df = pd.DataFrame({"ID": range(len(predictions)), "Prediction": predictions})
    submission_df.to_csv(output_file, index=False)
    print(f"Submission file saved to {output_file}")


    
# Testing Function
def test(args, model, test_loader, device):
    """Evaluate *model* on a labelled loader without gradients.

    Accepts batches of ``(images, labels)`` (the "valid" dataset phase) or
    ``(images, labels, flipped_images)`` (the "train" phase; the flipped images
    are ignored). ``args`` is unused and kept for interface compatibility.

    Returns:
        ``(epoch_loss, epoch_acc)`` — sample-weighted mean cross-entropy and
        accuracy; ``(0.0, 0.0)`` if the loader yields no samples.

    Raises:
        ValueError: if a batch does not have 2 or 3 items.
        TypeError: if the labels are not a tensor (e.g. the file names yielded
            by the unlabelled "test" dataset phase).
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    running_loss = 0.0
    correct_sum = 0
    total = 0

    with torch.no_grad():
        progress_bar = tqdm(test_loader, desc="Testing", leave=False)
        for batch in progress_bar:
            if len(batch) not in (2, 3):
                raise ValueError(f"expected batches of 2 or 3 items, got {len(batch)}")
            imgs1, labels = batch[0], batch[1]
            if not isinstance(labels, torch.Tensor):
                raise TypeError("test() needs labelled batches; got non-tensor labels (e.g. file names)")

            imgs1, labels = imgs1.to(device), labels.to(device)
            output, _ = model(imgs1)
            loss = criterion(output, labels)

            batch_size = imgs1.size(0)
            running_loss += loss.item() * batch_size
            correct_sum += (output.argmax(dim=1) == labels).sum().item()
            total += batch_size

    if total == 0:
        return 0.0, 0.0
    return running_loss / total, correct_sum / total


# Main Function
def main():
    """Train for ``args.epochs`` epochs, validating after each, then write the submission.

    Previously an extra (21st) training epoch ran after the loop and its
    ``(loss, acc)`` result was unpacked in swapped order; validation now runs
    once per epoch instead.
    """
    setup_seed(0)
    args = Args()

    train_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        transforms.RandomErasing(scale=(0.02, 0.25))
    ])
    eval_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    train_dataset = FERPlusDataset(args.train_dir, args.train_labels, phase='train', transform=train_transforms)
    valid_dataset = FERPlusDataset(args.valid_dir, args.valid_labels, phase='valid', transform=eval_transforms)
    test_dataset = FERPlusDataset(args.test_dir, phase='test', transform=eval_transforms)

    train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)
    valid_loader = data.DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True)
    test_loader = data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True)

    model = Model(args)
    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

    print("Training...")
    for epoch in range(1, args.epochs + 1):
        if epoch == 5:
            # NOTE(review): Model.enable_erasing() only sets a flag that
            # Model.forward does not read, so this has no effect on training yet.
            model.enable_erasing()
            print("Dynamic Erasing Enabled")
        print(f"Epoch {epoch}/{args.epochs}")
        train_loss, train_acc = train(args, model, train_loader, optimizer, scheduler, device)
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        val_loss, val_acc = test(args, model, valid_loader, device)
        print(f"Epoch {epoch}: Train Acc: {train_acc:.4f}, Train Loss: {train_loss:.4f}, Val Acc: {val_acc:.4f}, Val Loss: {val_loss:.4f}")

    print("Generating Submission...")
    create_submission_file(args, model, test_loader, device)

# Script entry point: run the full train -> validate -> submit pipeline when executed directly.
if __name__ == '__main__':
    main()