# Import

In [None]:
import os
import random

import pandas as pd
import numpy as np

from PIL import Image
from tqdm import tqdm

from sklearn.model_selection import train_test_split

import torch
from torch.utils.data import Dataset, DataLoader, Subset
import torchvision.models as models
import torchvision.transforms as transforms
import torch.nn.functional as F
from torch import nn, optim

from sklearn.metrics import log_loss


# Prefer the GPU when one is visible to PyTorch, otherwise run on CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

In [None]:
!pip install inplace-abn
!pip install imgaug
!pip install early_stopping_pytorch
!pip install wandb

In [None]:
import imgaug as ia
import imgaug.augmenters as iaa
from early_stopping_pytorch import EarlyStopping
import wandb

In [None]:
!git clone https://github.com/Alibaba-MIIL/TResNet
%cd TResNet

# Hyperparameter Settings

In [None]:
CFG = {
    'IMG_SIZE': 368,
    'BATCH_SIZE': 16,
    'EPOCHS': 20,
    'LEARNING_RATE': 1e-4,
    'SEED' : 42,
    'gradient_accumulation' : True,
    'accumulation_steps' : 4
}

In [None]:
!wandb login

In [None]:
# Initialize wandb
wandb.init(
    entity='Dacon_Car',
    project="car-classification",  # your project name
    name='TResNet_PMAL',
    config=CFG  # this will log your hyperparameters
)

# Fix Random Seeds

In [None]:
def seed_everything(seed):
    """Seed every RNG this notebook relies on and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy, PyTorch (CPU) and the current CUDA
    device. It also records ``PYTHONHASHSEED`` in the environment, which
    only affects Python processes started afterwards. Finally it forces
    deterministic cuDNN kernels and disables cuDNN autotuning.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

seed_everything(CFG['SEED']) # Fix all random seeds from the config

# CustomDataset

In [None]:
import os
from PIL import Image
import numpy as np # NumPy 임포트 추가
from torch.utils.data import Dataset


class CustomImageDataset(Dataset):
    """Image-folder dataset for the car-classification data.

    Training layout: ``root_dir/<class_name>/<image>``. Each sub-directory is
    a class, and labels follow the sorted class-name order. Test layout
    (``is_test=True``): a flat ``root_dir/<image>`` folder with no labels.

    Images are loaded as RGB NumPy arrays. If ``transform`` is given, it is
    called Albumentations-style as ``transform(image=array)`` and the result
    is read from the ``'image'`` key.

    Attributes:
        samples: ``(path,)`` tuples for the test set, ``(path, label)`` otherwise.
        classes / class_to_idx: only set for the training layout.
    """

    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.gif')

    def __init__(self, root_dir, transform=None, is_test=False):
        self.root_dir = root_dir
        self.transform = transform
        self.is_test = is_test
        self.samples = []

        if is_test:
            # Test set: image paths only, no labels.
            for fname in sorted(os.listdir(root_dir)):
                if self._is_image(fname):
                    self.samples.append((os.path.join(root_dir, fname),))
        else:
            # Only sub-directories are classes. A stray file at the root
            # (README, .DS_Store, ...) must not consume a label index or
            # change the number of classes.
            self.classes = sorted(
                entry for entry in os.listdir(root_dir)
                if os.path.isdir(os.path.join(root_dir, entry))
            )
            self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}

            for cls_name in self.classes:
                cls_folder = os.path.join(root_dir, cls_name)
                label = self.class_to_idx[cls_name]
                # Sorted so sample indices (and therefore any seeded split over
                # them) do not depend on filesystem listing order.
                for fname in sorted(os.listdir(cls_folder)):
                    if self._is_image(fname):
                        self.samples.append((os.path.join(cls_folder, fname), label))

    @classmethod
    def _is_image(cls, fname):
        """Return True if ``fname`` has a supported image extension (case-insensitive)."""
        return fname.lower().endswith(cls.IMG_EXTENSIONS)

    def __len__(self):
        return len(self.samples)

    def _load_image(self, img_path):
        """Load ``img_path`` as an RGB NumPy array and apply the transform, if any."""
        # The context manager closes the file handle; np.array copies the
        # pixels before the file is closed.
        with Image.open(img_path) as img:
            image = np.array(img.convert('RGB'))
        if self.transform:
            # Albumentations returns a dict; the transformed image is under 'image'.
            image = self.transform(image=image)['image']
        return image

    def __getitem__(self, idx):
        """Return ``image`` for the test set, ``(image, label)`` otherwise."""
        if self.is_test:
            return self._load_image(self.samples[idx][0])
        img_path, label = self.samples[idx]
        return self._load_image(img_path), label

# Data Load

In [None]:
# Kaggle input locations for the training and test images.
train_root, test_root = (
    '/kaggle/input/car-classification/train',
    '/kaggle/input/car-classification/test',
)

In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2 # PyTorch 텐서로 변환하기 위함
import numpy as np # Albumentations는 NumPy 배열을 입력으로 받습니다.
from PIL import Image # 이미지 로딩을 위한 라이브러리

# Albumentations resizes with OpenCV, so `interpolation` takes an OpenCV flag.
# PIL's Image.BILINEAR has the value 2, which OpenCV interprets as
# INTER_CUBIC. The value 1 is cv2.INTER_LINEAR, the bilinear resize intended here.
_CV2_INTER_LINEAR = 1


def _letterbox_normalize(img_size):
    """Build the shared resize/pad/normalize/to-tensor pipeline.

    1. Resize so the longest side is ``img_size``, keeping the aspect ratio.
    2. Pad with black (constant border) to an ``img_size`` x ``img_size`` square.
    3. Normalize with the ImageNet mean/std (pixel max 255).
    4. Convert the HWC array to a CHW PyTorch tensor (ToTensorV2).
    """
    return A.Compose([
        A.LongestMaxSize(max_size=img_size, interpolation=_CV2_INTER_LINEAR),
        A.PadIfNeeded(min_height=img_size, min_width=img_size,
                      border_mode=0, fill=(0, 0, 0)),  # border_mode=0 -> constant padding, fill = pad color
        A.Normalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                    max_pixel_value=255.0),
        ToTensorV2(),
    ])


# No augmentation here. The training cell augments batches itself with imgaug
# (`trans_seq_aug` / `trans_seq`), so train and validation share the same
# deterministic preprocessing.
train_transform = _letterbox_normalize(CFG['IMG_SIZE'])
val_transform = _letterbox_normalize(CFG['IMG_SIZE'])

In [None]:
# Load the whole training folder once, without a transform, only to read the
# labels for the split.
full_dataset = CustomImageDataset(train_root, transform=None)
print(f"총 이미지 수: {len(full_dataset)}")

targets = [label for _, label in full_dataset.samples]
class_names = full_dataset.classes

# Stratified 80/20 split so every class keeps its proportion in both sets.
# Seeded from CFG['SEED'] so the split follows the notebook-wide seed.
train_idx, val_idx = train_test_split(
    range(len(targets)), test_size=0.2, stratify=targets, random_state=CFG['SEED']
)

# Separate dataset instances so train and validation get their own transforms.
# NOTE: the Subset indices come from full_dataset, so this assumes every
# instance enumerates the same files in the same order.
train_dataset = Subset(CustomImageDataset(train_root, transform=train_transform), train_idx)
val_dataset = Subset(CustomImageDataset(train_root, transform=val_transform), val_idx)
print(f'train 이미지 수: {len(train_dataset)}, valid 이미지 수: {len(val_dataset)}')


# drop_last=True for training: the model heads contain BatchNorm1d layers,
# which raise in train mode when a batch has a single sample. A trailing
# batch of size 1 would otherwise crash the epoch.
train_loader = DataLoader(train_dataset, batch_size=CFG['BATCH_SIZE'], shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=CFG['BATCH_SIZE'], shuffle=False)


# Model Definition

In [None]:
from src.models.tresnet_v2.tresnet_v2 import TResnetL_V2 as TResnetL368


class TResNet(nn.Module):
    """TResNet-L-v2 backbone from a pretrained checkpoint, with a fresh linear head.

    The backbone is built with the checkpoint's head size so the weights load.
    Its ``head`` is then replaced by ``nn.Identity`` so it returns features, and
    a new ``nn.Linear(num_features, num_classes)`` classifier is added.

    Args:
        num_classes: number of outputs of the new head.
        weights_path: checkpoint file; its state dict is read from the
            ``'model'`` key.
        pretrained_num_classes: head size the checkpoint was trained with
            (196 for the Stanford Cars checkpoint).
    """

    DEFAULT_WEIGHTS_PATH = "/kaggle/input/tresnet-stanford-cars-pretrained/stanford_cars_tresnet-l-v2_96_27.pth"

    def __init__(self, num_classes, weights_path=DEFAULT_WEIGHTS_PATH, pretrained_num_classes=196):
        super().__init__()
        self.backbone = TResnetL368({'num_classes': pretrained_num_classes})

        # map_location='cpu' lets a checkpoint saved from GPU load on any
        # machine. The module can be moved to a device afterwards.
        pretrained_weights = torch.load(weights_path, map_location='cpu')
        self.backbone.load_state_dict(pretrained_weights['model'])

        self.feature_dim = self.backbone.num_features
        self.backbone.head = nn.Identity()  # use the backbone as a feature extractor only
        self.head = nn.Linear(self.feature_dim, num_classes)  # new classifier

    def forward(self, x):
        return self.head(self.backbone(x))

# Utils

In [None]:
from torch.nn.modules.batchnorm import _BatchNorm

def cosine_anneal_schedule(t, nb_epoch, lr):
    """Return the cosine-annealed learning rate for epoch ``t``.

    Computes ``lr / 2 * (cos(pi * (t mod nb_epoch) / nb_epoch) + 1)`` as a
    Python float. It equals ``lr`` at the start of every ``nb_epoch``-long
    cycle and decays towards 0 at its end.
    """
    angle = np.pi * (t % nb_epoch) / nb_epoch
    return float(lr / 2 * (np.cos(angle) + 1))

class BasicConv(nn.Module):
    """Conv2d followed by an optional BatchNorm2d and an optional ReLU.

    The convolution is created with the given geometry (bias off by default).
    The batch norm uses eps=1e-5, momentum=0.01 and affine parameters. When
    ``bn`` / ``relu`` is False the corresponding attribute is ``None``.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        if bn:
            self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True)
        else:
            self.bn = None
        if relu:
            self.relu = nn.ReLU()
        else:
            self.relu = None

    def forward(self, x):
        out = self.conv(x)
        # Apply whichever post-processing layers are enabled, in order.
        for post in (self.bn, self.relu):
            if post is not None:
                out = post(out)
        return out

class Features(nn.Module):
    """Backbone split into six sequential stages that exposes three taps.

    ``net_layers_FeatureHead[0]`` is a single module. Entries 1-5 are
    iterables of modules, each unpacked into its own ``nn.Sequential``.
    ``forward`` returns the outputs of stages 3, 4 and 5 as ``(x1, x2, x3)``.
    """

    def __init__(self, net_layers_FeatureHead):
        super(Features, self).__init__()
        stages = net_layers_FeatureHead
        self.net_layer_0 = nn.Sequential(stages[0])
        # Registered in index order so parameter/state-dict naming is net_layer_1..5.
        for idx in range(1, 6):
            setattr(self, f"net_layer_{idx}", nn.Sequential(*stages[idx]))

    def forward(self, x):
        h = x
        for stem in (self.net_layer_0, self.net_layer_1, self.net_layer_2):
            h = stem(h)
        x1 = self.net_layer_3(h)
        x2 = self.net_layer_4(x1)
        x3 = self.net_layer_5(x2)
        return x1, x2, x3


class Network_Wrapper(nn.Module):
    """PMAL multi-granularity classifier on top of a three-tap backbone.

    ``Features`` (built from ``net_layers``) yields feature maps x1, x2, x3
    with 512, 1024 and 2048 channels, the widths the conv blocks below
    expect. Each map goes through its own conv block, global max pooling and
    MLP classifier. x3 additionally feeds ``classifier[0]`` (the backbone's
    pooling stage) and a freshly initialised linear head.

    Args:
        net_layers: backbone stages passed to ``Features``.
        num_classes: number of outputs of every head.
        classifier: sequence whose first element is the pooling module applied
            to x3; its other elements are not used.

    ``forward`` returns ``(x1_c, x2_c, x3_c, classifiers, map1, map2, map3)``:
    the three per-stage logits, the x3-head logits, and clones of the three
    feature maps (used as decoder inputs by the training loop).
    """

    def __init__(self, net_layers, num_classes, classifier):
        super().__init__()
        self.Features = Features(net_layers)
        self.classifier_pool = nn.Sequential(classifier[0])

        # New head sized for this task (replaces the checkpoint's 196-class
        # head). 2048 must equal x3's channel count, as conv_block3 also assumes.
        self.classifier_initial = nn.Linear(2048, num_classes)

        # Not used in forward; kept so the module's attribute names are unchanged.
        self.sigmoid = nn.Sigmoid()
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

        # Global max pooling to 1x1, as required by the 1024-wide classifiers
        # after flattening. Fixed kernels (46/23/12) only achieve this for one
        # input resolution (IMG_SIZE=368). The adaptive form gives the same
        # result there and also works for other IMG_SIZE values.
        self.max_pool1 = nn.AdaptiveMaxPool2d(1)
        self.max_pool2 = nn.AdaptiveMaxPool2d(1)
        self.max_pool3 = nn.AdaptiveMaxPool2d(1)

        self.conv_block1 = self._conv_block(512)
        self.classifier1 = self._mlp_head(num_classes)

        self.conv_block2 = self._conv_block(1024)
        self.classifier2 = self._mlp_head(num_classes)

        self.conv_block3 = self._conv_block(2048)
        self.classifier3 = self._mlp_head(num_classes)

    @staticmethod
    def _conv_block(in_channels):
        """1x1 conv down to 512 channels, then a 3x3 conv up to 1024 (each Conv-BN-ReLU)."""
        return nn.Sequential(
            BasicConv(in_channels, 512, kernel_size=1, stride=1, padding=0, relu=True),
            BasicConv(512, 1024, kernel_size=3, stride=1, padding=1, relu=True),
        )

    @staticmethod
    def _mlp_head(num_classes):
        """BN -> Linear(1024, 512) -> BN -> ELU -> Linear(512, num_classes)."""
        return nn.Sequential(
            nn.BatchNorm1d(1024),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ELU(inplace=True),
            nn.Linear(512, num_classes),
        )

    @staticmethod
    def _branch(feat, conv_block, pool, head):
        """Run one stage branch: conv block -> global max pool -> flatten -> classifier."""
        return head(pool(conv_block(feat)).flatten(1))

    def forward(self, x):
        x1, x2, x3 = self.Features(x)
        # Independent copies of the stage outputs, returned for the decoders.
        map1 = x1.clone()
        map2 = x2.clone()
        map3 = x3.clone()

        classifiers = self.classifier_initial(self.classifier_pool(x3).flatten(1))

        x1_c = self._branch(x1, self.conv_block1, self.max_pool1, self.classifier1)
        x2_c = self._branch(x2, self.conv_block2, self.max_pool2, self.classifier2)
        x3_c = self._branch(x3, self.conv_block3, self.max_pool3, self.classifier3)

        return x1_c, x2_c, x3_c, classifiers, map1, map2, map3


class Anti_Noise_Decoder(nn.Module):
    """Decode a backbone feature map back into a 3-channel image (PMAL anti-noise branch).

    ``process`` upsamples ``map`` with PixelShuffle(scale) followed by three
    PixelShuffle(2) stages (total factor ``scale * 8``) and ends at 3
    channels. ``skip`` is a small conv path applied to the input image ``x``.
    The output is ``skip(x) + process(map)``. If the decoded image's spatial
    size differs from ``x``, it is bilinearly resized to match.

    Args:
        scale: factor of the first PixelShuffle. ``in_channel`` should be
            divisible by ``scale * scale``.
        in_channel: channel count of the feature map passed as ``map``.
    """

    def __init__(self, scale, in_channel):
        super(Anti_Noise_Decoder, self).__init__()
        self.Sigmoid = nn.Sigmoid()  # not used in forward

        in_channel = in_channel // (scale * scale)  # channels after PixelShuffle(scale)

        self.skip = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, 1, bias=False),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(64, 3, 3, 1, 1, bias=False),
            nn.LeakyReLU(negative_slope=0.1, inplace=True)
        )

        self.process = nn.Sequential(
            nn.PixelShuffle(scale),
            nn.Conv2d(in_channel, 256, 3, 1, 1, bias=False),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.PixelShuffle(2),
            nn.Conv2d(64, 128, 3, 1, 1, bias=False),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.PixelShuffle(2),
            nn.Conv2d(32, 64, 3, 1, 1, bias=False),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.PixelShuffle(2),
            nn.Conv2d(16, 3, 3, 1, 1, bias=False),
            nn.LeakyReLU(negative_slope=0.1, inplace=True)
        )

    def forward(self, x, map):
        # `map` shadows the builtin; the name is kept for keyword-call compatibility.
        x_ = self.process(map)
        if x_.shape[-2:] != x.shape[-2:]:
            # Resize the *decoded* image to the input's spatial size.
            # Interpolating `x` to its own size would be a no-op that discards
            # the decoder output, and with it every gradient to `process`.
            x_ = F.interpolate(x_, size=tuple(x.shape[-2:]), mode='bilinear', align_corners=False)
        return self.skip(x) + x_


def img_add_noise(x, transformation_seq):
    """Apply a batch image augmenter to a channels-first tensor batch.

    ``x`` is permuted from (N, C, H, W) to (N, H, W, C) and moved to a CPU
    NumPy array. It is passed as ``transformation_seq(images=...)``, and the
    returned array is converted back to a float32 (N, C, H, W) CPU tensor.
    Callers move the result to their device themselves.
    """
    channels_last = x.permute(0, 2, 3, 1).cpu().numpy()
    augmented = transformation_seq(images=channels_last)
    return torch.from_numpy(augmented.astype(np.float32)).permute(0, 3, 1, 2)

def smooth_crossentropy(pred, gold, smoothing=0.1):
    """Per-sample label-smoothed loss between logits ``pred`` and labels ``gold``.

    The target puts ``1 - smoothing`` on the true class and spreads
    ``smoothing`` uniformly over the other ``n_class - 1`` classes. The return
    value is the per-sample KL divergence KL(target || softmax(pred)), shape
    (batch,). This differs from the smoothed cross-entropy only by the target's
    entropy, a constant, so the gradients are identical.
    """
    n_class = pred.size(1)
    target = torch.full_like(pred, fill_value=smoothing / (n_class - 1))
    target.scatter_(dim=1, index=gold.unsqueeze(1), value=1.0 - smoothing)
    per_class = F.kl_div(input=F.log_softmax(pred, dim=1), target=target, reduction='none')
    return per_class.sum(-1)

def CELoss(x, y):
    """Label-smoothed classification loss used by every PMAL head.

    Alias for ``smooth_crossentropy`` with smoothing 0.1. It returns the
    per-sample loss, and callers reduce it with ``.mean()``.
    """
    return smooth_crossentropy(pred=x, gold=y, smoothing=0.1)

class CharbonnierLoss(nn.Module):
    """Charbonnier loss, a smooth L1 variant: ``mean(sqrt((x - y)**2 + eps**2))``.

    Used for the decoders' image-reconstruction term.
    """

    def __init__(self, eps=1e-3):
        super(CharbonnierLoss, self).__init__()
        self.eps = eps

    def forward(self, x, y):
        residual = x - y
        smoothed = torch.sqrt((residual * residual) + (self.eps * self.eps))
        return smoothed.mean()
    



def disable_running_stats(model):
    """Stop BatchNorm layers in ``model`` from updating their running statistics.

    Sets ``momentum = 0`` on every ``_BatchNorm`` submodule, so the running
    mean/var are left unchanged. It is used for SAM's second forward pass,
    which must not count the same batch twice. The previous momentum is saved
    in ``backup_momentum``. The call is idempotent: calling it again before
    ``enable_running_stats`` keeps the saved value instead of overwriting it
    with 0.
    """
    def _disable(module):
        if isinstance(module, _BatchNorm):
            if not hasattr(module, "backup_momentum"):
                module.backup_momentum = module.momentum
            module.momentum = 0

    model.apply(_disable)


def enable_running_stats(model):
    """Restore the BatchNorm momentum saved by ``disable_running_stats``.

    Modules that were never disabled are left untouched. The saved value is
    removed after restoring, so the next disable records the then-current
    momentum.
    """
    def _enable(module):
        if isinstance(module, _BatchNorm) and hasattr(module, "backup_momentum"):
            module.momentum = module.backup_momentum
            del module.backup_momentum

    model.apply(_enable)


class Student_Wrapper(nn.Module):
    """Backbone stages plus the original pooling/head, exposing intermediate features.

    ``net_layers[0]`` is a single module and entries 1-5 are iterables of
    modules (unpacked into ``nn.Sequential``). ``classifier[0]`` pools the last
    stage's output and ``classifier[1]`` maps the flattened result to logits.
    ``forward`` returns ``(logits, x1, x2, x3)``, where x1..x3 are the outputs
    of stages 3-5.
    """

    def __init__(self, net_layers, classifier):
        super(Student_Wrapper, self).__init__()
        self.net_layer_0 = nn.Sequential(net_layers[0])
        for idx in range(1, 6):
            setattr(self, f"net_layer_{idx}", nn.Sequential(*net_layers[idx]))

        self.classifier_pool = nn.Sequential(classifier[0])
        self.classifier_initial = nn.Sequential(classifier[1])

    def forward(self, x):
        h = x
        for stem in (self.net_layer_0, self.net_layer_1, self.net_layer_2):
            h = stem(h)
        x1 = self.net_layer_3(h)
        x2 = self.net_layer_4(x1)
        x3 = self.net_layer_5(x2)

        pooled = self.classifier_pool(x3).view(x3.size(0), -1)
        return self.classifier_initial(pooled), x1, x2, x3

In [None]:
# SAM
class SAM(torch.optim.Optimizer):
    """Sharpness-Aware Minimization (SAM) wrapper around a base optimizer.

    One SAM update uses two forward/backward passes:
      1. compute gradients at w, then call ``first_step`` to move the weights
         to w + e(w), a step of radius ``rho`` along the normalized gradient;
      2. recompute gradients at w + e(w), then call ``second_step`` to restore
         w and let the base optimizer apply those second-pass gradients.
    ``step(closure)`` runs both steps around a single closure call.

    Args:
        params: iterable of parameters or param-group dicts.
        base_optimizer: optimizer *class* (e.g. ``torch.optim.SGD``); it is
            instantiated here over this wrapper's param groups with ``**kwargs``.
        rho: perturbation radius; must be non-negative.
        adaptive: if True, the perturbation is weighted by parameter magnitude
            (``p**2`` in ``first_step``, ``|p|`` in the gradient norm).
        **kwargs: forwarded to ``base_optimizer`` and stored in the group defaults.
    """
    def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
        super(SAM, self).__init__(params, defaults)

        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        # Share the base optimizer's param_groups list, so lr edits made via
        # this wrapper (e.g. optimizer.param_groups[i]['lr'] = ...) are seen
        # by the base optimizer's step().
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Perturb each parameter with a gradient to w + e(w), saving w in ``state[p]['old_p']``.

        Parameters whose ``.grad`` is None are skipped. NOTE(review): if such a
        parameter gains a gradient before ``second_step``, its missing
        'old_p' entry would raise a KeyError there.
        """
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)  # epsilon avoids division by zero

            for p in group["params"]:
                if p.grad is None: continue
                self.state[p]["old_p"] = p.data.clone()
                e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Restore the saved weights w, then run the base optimizer step.

        The base optimizer uses the gradients currently stored in ``.grad``,
        i.e. those computed at the perturbed point w + e(w).
        """
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None: continue
                p.data = self.state[p]["old_p"]  # get back to "w" from "w + e(w)"

        self.base_optimizer.step()  # do the actual "sharpness-aware" update

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        """Run a full SAM update.

        ``closure`` must re-run the forward and backward pass. It is called
        once between the two steps, with gradients re-enabled.
        """
        assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
        closure = torch.enable_grad()(closure)  # the closure should do a full forward-backward pass

        self.first_step(zero_grad=True)
        closure()
        self.second_step()

    def _grad_norm(self):
        """Global L2 norm of all existing gradients, optionally |p|-weighted (adaptive).

        Per-parameter norms are moved to the first parameter's device before
        stacking. NOTE(review): torch.stack raises if no parameter has a gradient.
        """
        shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
        norm = torch.norm(
                    torch.stack([
                        ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
                        for group in self.param_groups for p in group["params"]
                        if p.grad is not None
                    ]),
                    p=2
               )
        return norm

    def load_state_dict(self, state_dict):
        """Load optimizer state, then re-link the base optimizer to the reloaded groups."""
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups

# Freeze weights

In [None]:
best_logloss = float('inf')

# ----- PMAL: rebuild TResNet-L-v2 from the Stanford Cars checkpoint -----
# The checkpoint was trained with a 196-class head, so the model is built with
# that size for the weights to load.
model_params = {'num_classes' : 196}
model = TResnetL368(model_params)
weights_path = "/kaggle/input/tresnet-stanford-cars-pretrained/stanford_cars_tresnet-l-v2_96_27.pth"
# map_location='cpu' lets a checkpoint saved from GPU load on any machine;
# everything is moved to `device` further below.
pretrained_weights = torch.load(weights_path, map_location='cpu')
model.load_state_dict(pretrained_weights['model'])

# Split the TResNet into its body's stages and the [pooling, head] pair.
# Network_Wrapper uses classifier[0] as the pooling stage.
net_layers = list(model.children())
classifier = net_layers[1:3]
net_layers = list(net_layers[0].children())

# Build the Network_Wrapper
net = Network_Wrapper(net_layers, len(class_names), classifier)

# ====== Freeze the pretrained backbone (Network_Wrapper.Features) ======
for param in net.Features.parameters():
    param.requires_grad = False
# NOTE(review): requires_grad=False does not stop normalization layers in
# net.Features from updating their running statistics while net is in train()
# mode. If the backbone should stay fully fixed, put net.Features in eval() after
# each net.train().

# Optionally also freeze classifier_pool (part of the pretrained model):
# for param in net.classifier_pool.parameters():
#     param.requires_grad = False

print("🔒 Frozen parameters:")
frozen_params = 0
for name, param in net.named_parameters():
    if not param.requires_grad:
        frozen_params += param.numel()
        print(f"  - {name}: {param.shape}")

print(f"\n📊 Parameter Summary:")
total_params = sum(p.numel() for p in net.parameters())
trainable_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(f"  Total parameters: {total_params:,}")
print(f"  Frozen parameters: {frozen_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Trainable ratio: {trainable_params/total_params:.2%}")

print(f"\n🎯 Trainable components:")
for name, param in net.named_parameters():
    if param.requires_grad:
        print(f"  - {name}: {param.shape}")

net.to(device)
decoder1 = Anti_Noise_Decoder(1, 512).to(device)
decoder2 = Anti_Noise_Decoder(2, 1024).to(device)
decoder3 = Anti_Noise_Decoder(4, 2048).to(device)

# loss
CB_loss = CharbonnierLoss()

# optimizer: SAM over SGD, one param group per trainable module.
base_optimizer = torch.optim.SGD
BASE_LR = 0.002
sam_modules = [
    net.classifier_initial,
    net.conv_block1, net.classifier1,
    net.conv_block2, net.classifier2,
    net.conv_block3, net.classifier3,
    decoder1.skip, decoder1.process,
    decoder2.skip, decoder2.process,
    decoder3.skip, decoder3.process,
]
optimizer = SAM(
    [{'params': module.parameters(), 'lr': BASE_LR} for module in sam_modules],
    base_optimizer, adaptive=False, momentum=0.9, weight_decay=5e-4)

max_val_acc = 0
# Base learning rate per param group (indexed by group in the lr schedule),
# read back from the optimizer so it always matches the group list.
lr = [group['lr'] for group in optimizer.param_groups]


In [None]:
# ====== Selective unfreezing (optional) ======
def unfreeze_last_n_blocks(model, n_blocks=1):
    """Unfreeze the last ``n_blocks`` backbone stages of ``model`` for fine-tuning.

    The stages are the children of ``model.Features``. The ``Features``
    wrapper exposes them directly (net_layer_0 .. net_layer_5). If
    ``model.Features`` has a ``body`` attribute (raw TResNet layout), that
    container's children are used instead. ``n_blocks <= 0`` unfreezes nothing.

    Note: newly unfrozen parameters are only trained if they are also in the
    optimizer's param groups. The SAM optimizer built in this notebook lists
    only the heads and decoders.
    """
    backbone = model.Features
    container = backbone.body if hasattr(backbone, 'body') else backbone
    stages = list(container.children())
    # Guard needed: stages[-0:] would select every stage.
    selected = stages[-n_blocks:] if n_blocks > 0 else []
    for stage in selected:
        for param in stage.parameters():
            param.requires_grad = True
    print(f"🔓 Unfroze last {n_blocks} blocks")

def unfreeze_classifier_pool(model=None):
    """Make ``model.classifier_pool`` trainable.

    Args:
        model: network exposing ``classifier_pool``. Defaults to the
            notebook-level ``net`` (the original zero-argument behavior).
    """
    target = net if model is None else model
    for param in target.classifier_pool.parameters():
        param.requires_grad = True
    print("🔓 Unfroze classifier_pool")

# Usage example (uncomment if needed):
# unfreeze_last_n_blocks(net, n_blocks=1)  # unfreeze the last block
# unfreeze_classifier_pool()  # unfreeze classifier_pool

# ====== Per-group learning rates (optional) ======
def get_parameter_groups_with_different_lr(model, backbone_lr=1e-5, new_layers_lr=1e-3):
    """Split ``model``'s trainable parameters into two optimizer param groups.

    Parameters whose qualified name contains ``'Features'`` form the backbone
    group (``backbone_lr``). Every other trainable parameter goes into the
    new-layer group (``new_layers_lr``). Frozen parameters are left out.
    Returns ``[backbone_group, new_layer_group]`` in that order; either
    ``'params'`` list may be empty.
    """
    trainable = [(name, param) for name, param in model.named_parameters() if param.requires_grad]
    backbone = [param for name, param in trainable if 'Features' in name]
    fresh = [param for name, param in trainable if 'Features' not in name]
    return [
        {'params': backbone, 'lr': backbone_lr},
        {'params': fresh, 'lr': new_layers_lr},
    ]

# To use different learning rates for the backbone and the new layers:
# param_groups = get_parameter_groups_with_different_lr(net, backbone_lr=1e-5, new_layers_lr=1e-3)
# optimizer = SAM(param_groups, torch.optim.SGD, momentum=0.9)

In [None]:

# best_logloss = float('inf')

# # # 손실 함수
# # criterion = nn.CrossEntropyLoss()

# # # 옵티마이저
# # optimizer = optim.Adam(model.parameters(), lr=CFG['LEARNING_RATE'])


# # PMAL
# model_params = {'num_classes' : 196}
# model = TResnetL368(model_params)
# weights_path = "/kaggle/input/tresnet-stanford-cars-pretrained/stanford_cars_tresnet-l-v2_96_27.pth"
# pretrained_weights = torch.load(weights_path)
# model.load_state_dict(pretrained_weights['model'])

# net_layers = list(model.children())
# classifier = net_layers[1:3]
# net_layers = net_layers[0]
# net_layers = list(net_layers.children())

# net = Network_Wrapper(net_layers, len(class_names), classifier)
# # netp = torch.nn.DataParallel(net, device_ids=[0])

# net.to(device)
# decoder1 = Anti_Noise_Decoder(1, 512).to(device)
# decoder2 = Anti_Noise_Decoder(2, 1024).to(device)
# decoder3 = Anti_Noise_Decoder(4, 2048).to(device)

# #loss
# CB_loss = CharbonnierLoss()

# #optimizer
# base_optimizer = torch.optim.SGD

# optimizer = SAM([
#         {'params': net.classifier_initial.parameters(), 'lr': 0.002},
#         {'params': net.conv_block1.parameters(), 'lr': 0.002},
#         {'params': net.classifier1.parameters(), 'lr': 0.002},
#         {'params': net.conv_block2.parameters(), 'lr': 0.002},
#         {'params': net.classifier2.parameters(), 'lr': 0.002},
#         {'params': net.conv_block3.parameters(), 'lr': 0.002},
#         {'params': net.classifier3.parameters(), 'lr': 0.002},

#         {'params': decoder1.skip.parameters(), 'lr': 0.002},
#         {'params': decoder1.process.parameters(), 'lr': 0.002},
#         {'params': decoder2.skip.parameters(), 'lr': 0.002},
#         {'params': decoder2.process.parameters(), 'lr': 0.002},
#         {'params': decoder3.skip.parameters(), 'lr': 0.002},
#         {'params': decoder3.process.parameters(), 'lr': 0.002},

#         {'params': net.Features.parameters(), 'lr': 0.0002}

#     ],
#         base_optimizer, adaptive=False, momentum=0.9, weight_decay=5e-4)

# max_val_acc = 0
# lr = [0.002, 0.002, 0.002, 0.002, 0.002, 0.002, 0.002, 0.002, 0.002, 0.002, 0.002, 0.002, 0.002, 0.0002]

# Train / Validation

In [None]:
# Selects the training variant: the loop below runs only when gradient accumulation is disabled in CFG.
gradient_accumulation = CFG['gradient_accumulation']

if not gradient_accumulation:
    
    # 학습 및 검증 루프
    for epoch in range(CFG['EPOCHS']):
        # Train
        net.train()
        train_loss = 0.0
        train_loss1 = 0
        train_loss2 = 0
        train_loss3 = 0
        train_loss4 = 0
        train_loss5 = 0
    
        for images, targets in tqdm(train_loader, desc=f"[Epoch {epoch+1}/{CFG['EPOCHS']}] Training"):
            images, targets = images.to(device), targets.to(device)
    
            for nlr in range(len(optimizer.param_groups)):
                optimizer.param_groups[nlr]['lr'] = cosine_anneal_schedule(epoch, CFG['EPOCHS'], lr[nlr])
            
            sometimes_1 = lambda aug: iaa.Sometimes(0.2, aug)
            sometimes_2 = lambda aug: iaa.Sometimes(0.5, aug)
    
            trans_seq_aug = iaa.Sequential(
                [
    
                    sometimes_1(iaa.Affine(
                        scale={"x": (0.8, 1.2), "y": (0.8, 1.2)},
                        translate_percent={"x": (-0.2, 0.2), "y": (-0.2, 0.2)},
                        rotate=(-15, 15),
                        shear=(-15, 15),
                        order=[0, 1],
                        cval=(0, 1),
                        mode=ia.ALL
                    )),
                    sometimes_2(iaa.GaussianBlur((0, 3.0)))
                ],
                random_order=True
            )
    
            trans_seq = iaa.Sequential(
                [
    
                    iaa.AdditiveGaussianNoise(
                        loc=0, scale=(0.0, 0.05), per_channel=0.5
                    )
                ],
                random_order=True
            )
            # H1 first forward-backward step
            enable_running_stats(net)
            optimizer.zero_grad()
    
            inputs1_gt = img_add_noise(images, trans_seq_aug).to(device)
            inputs1 = img_add_noise(inputs1_gt, trans_seq).to(device)
            output_1, _, _, _, map1, _, _ = net(inputs1)
    
            loss1_c = CELoss(output_1, targets).mean() * 1
            inputs1_syn = decoder1(inputs1, map1)
            loss1_g = CB_loss(inputs1_syn, inputs1_gt) * 1
    
            output_1_syn, _, _, _, _, _, _ = net(inputs1_syn)
            loss1_c_syn = CELoss(output_1_syn, targets).mean() * 1
    
            loss1 = loss1_c + (loss1_g) + loss1_c_syn
            loss1.backward()
            optimizer.first_step(zero_grad=True)
    
            # H1 second forward-backward step
            disable_running_stats(net)
            optimizer.zero_grad()
    
            output_1, _, _, _, map1, _, _ = net(inputs1)
            loss1_c = CELoss(output_1, targets).mean() * 1
    
            inputs1_syn = decoder1(inputs1, map1)
            loss1_g = CB_loss(inputs1_syn, inputs1_gt) * 1
    
            output_1_syn, _, _, _, _, _, _ = net(inputs1_syn)
            loss1_c_syn = CELoss(output_1_syn, targets).mean() * 1
    
            loss1_ = loss1_c + loss1_g + loss1_c_syn
            loss1_.backward()
            optimizer.second_step(zero_grad=True)
    
            loss1 = loss1.cpu()
            loss1_g = loss1_g.cpu()
    
            del output_1
            del output_1_syn
            del loss1_
            del loss1_c
            del loss1_c_syn
            del inputs1
            del inputs1_gt
            del inputs1_syn
            torch.cuda.empty_cache()
        
            # H2
            # H2 first forward-backward step
            enable_running_stats(net)
            optimizer.zero_grad()
            inputs2_gt = img_add_noise(images, trans_seq_aug).to(device)
            inputs2 = img_add_noise(inputs2_gt, trans_seq).to(device)
            _, output_2, _, _, _, map2, _ = net(inputs2)
            loss2_c = CELoss(output_2, targets).mean() * 1
    
            inputs2_syn = decoder2(inputs2, map2)
            loss2_g = CB_loss(inputs2_syn, inputs2_gt) * 1
    
            _, output_2_syn, _, _, _, _, _ = net(inputs2_syn)
            loss2_c_syn = CELoss(output_2_syn, targets).mean() * 1
    
            loss2 = loss2_c + loss2_g + loss2_c_syn
            loss2.backward()
            optimizer.first_step(zero_grad=True)
    
            # H2 second forward-backward step
            disable_running_stats(net)
            optimizer.zero_grad()
            _, output_2, _, _, _, map2, _ = net(inputs2)
            loss2_c = CELoss(output_2, targets).mean() * 1
    
            inputs2_syn = decoder2(inputs2, map2)
            loss2_g = CB_loss(inputs2_syn, inputs2_gt) * 1
    
            _, output_2_syn, _, _, _, _, _ = net(inputs2_syn)
            loss2_c_syn = CELoss(output_2_syn, targets).mean() * 1
    
            loss2_ = loss2_c + (loss2_g) + loss2_c_syn
            loss2_.backward()
            optimizer.second_step(zero_grad=True)
    
            loss2 = loss2.cpu()
            loss2_g = loss2_g.cpu()
            del output_2
            del output_2_syn
            del loss2_
            del loss2_c
            del loss2_c_syn
            del inputs2
            del inputs2_gt
            del inputs2_syn
            torch.cuda.empty_cache()
    
            # H3
            # H3 first forward-backward step
            enable_running_stats(net)
            optimizer.zero_grad()
            inputs3_gt = img_add_noise(images, trans_seq_aug).to(device)
            inputs3 = img_add_noise(inputs3_gt, trans_seq).to(device)
            _, _, output_3, _, _, _, map3 = net(inputs3)
            loss3_c = CELoss(output_3, targets).mean() * 1
    
            inputs3_syn = decoder3(inputs3, map3)
            loss3_g = CB_loss(inputs3_syn, inputs3_gt) * 1
    
            _, _, output_3_syn, _, _, _, _ = net(inputs3_syn)
            loss3_c_syn = CELoss(output_3_syn, targets).mean() * 1
    
            loss3 = loss3_c + (loss3_g) + loss3_c_syn
            loss3.backward()
            optimizer.first_step(zero_grad=True)
    
            # H3 second forward-backward step
            disable_running_stats(net)
            optimizer.zero_grad()
            _, _, output_3, _, _, _, map3 = net(inputs3)
            loss3_c = CELoss(output_3, targets).mean() * 1
    
            inputs3_syn = decoder3(inputs3, map3)
            loss3_g = CB_loss(inputs3_syn, inputs3_gt) * 1
    
            _, _, output_3_syn, _, _, _, _ = net(inputs3_syn)
            loss3_c_syn = CELoss(output_3_syn, targets).mean() * 1
    
            loss3_ = loss3_c + (loss3_g) + loss3_c_syn
            loss3_.backward()
            optimizer.second_step(zero_grad=True)
    
            loss3 = loss3.cpu()
            loss3_g = loss3_g.cpu()
            del output_3
            del output_3_syn
            del loss3_
            del loss3_c
            del loss3_c_syn
            del inputs3
            del inputs3_gt
            del inputs3_syn
            torch.cuda.empty_cache()
    
            # H4
            # H4 first forward-backward step
            enable_running_stats(net)
            optimizer.zero_grad()
            output_1_final, output_2_final, output_3_final, output_ORI, _, _, _ = net(images)
            ORI_loss = CELoss(output_1_final, targets).mean() + \
                        CELoss(output_2_final, targets).mean() + \
                        CELoss(output_3_final, targets).mean() + \
                        CELoss(output_ORI, targets).mean() * 2
            # 손실 계산 전에 targets 검사
    
            ORI_loss.backward()
            optimizer.first_step(zero_grad=True)
    
            # H4 second forward-backward step
            disable_running_stats(net)
            optimizer.zero_grad()
            output_1_final, output_2_final, output_3_final, output_ORI, _, _, _ = net(images)
            ORI_loss_ = CELoss(output_1_final, targets).mean() + \
                        CELoss(output_2_final, targets).mean() + \
                        CELoss(output_3_final, targets).mean() + \
                        CELoss(output_ORI, targets).mean() * 2
            ORI_loss_.backward()
            optimizer.second_step(zero_grad=True)
    
            ORI_loss = ORI_loss.cpu()
            del output_1_final
            del output_2_final
            del output_3_final
            output_ORI = output_ORI.cpu()
            targets = targets.cpu()
            del images
            del ORI_loss_
            torch.cuda.empty_cache()
    
            train_loss += (loss1.item() + loss2.item() + loss3.item() + ORI_loss.item()) # 
            train_loss1 += loss1.item()
            train_loss2 += loss2.item()
            train_loss3 += loss3.item()
            train_loss4 += (loss1_g.item() + loss2_g.item() + loss3_g.item())
            train_loss5 += ORI_loss.item()
        total = len(train_loader)
        avg_train_loss = train_loss / total
        avg_train_loss1 = train_loss1 / total
        avg_train_loss2 = train_loss2 / total
        avg_train_loss3 = train_loss3 / total
        avg_train_loss4 = train_loss4 / total
        avg_train_loss5 = train_loss5 / total
    
        # Validation
        net.eval()
        val_loss = 0.0
        correct = 0
        correct_com = 0
        total = 0
        all_probs = []
        all_probs_com = []
        all_labels = []
    
        with torch.no_grad():
            for images, labels in tqdm(val_loader, desc=f"[Epoch {epoch+1}/{CFG['EPOCHS']}] Validation"):
                images, labels = images.to(device), labels.to(device)
    
                output_1, output_2, output_3, output_ORI, _, _, _ = net(images)
    
                outputs_com = output_1.cpu() + output_2.cpu() + output_3.cpu() + output_ORI.cpu()
                loss = CELoss(output_ORI, labels)
                val_loss += loss.item()
    
                # Accuracy
                _, preds = torch.max(output_ORI, 1)
                _, preds_com = torch.max(outputs_com, 1)
    
                correct += (preds == labels).sum().item()
                correct_com += (preds_com == labels).sum().item()
                total += labels.size(0)
    
                # LogLoss
                probs = F.softmax(output_ORI, dim=1)
                all_probs.extend(probs.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
    
                probs_com = F.softmax(outputs_com, dim=1)
                all_probs_com.extend(probs_com.cpu().numpy())
    
    
        avg_val_loss = val_loss / len(val_loader)
        val_accuracy = 100 * correct / total
        val_com_accuracy = 100 * correct_com / total
        val_logloss = log_loss(all_labels, all_probs, labels=list(range(len(class_names))))
        val_com_logloss = log_loss(all_labels, all_probs_com, labels=list(range(len(class_names))))
        del images
        del loss
        del targets
        del output_1
        del output_2
        del output_3
        del output_ORI
        torch.cuda.empty_cache()
        
        # wandb 
        wandb.log({
            "train_loss": avg_train_loss,
            "train_loss1": avg_train_loss1,
            "train_loss2": avg_train_loss2,
            "train_loss3": avg_train_loss3,
            "train_loss4": avg_train_loss4,
            "train_loss5": avg_train_loss5,
            "val_loss": avg_val_loss,
            "val_accuracy": val_accuracy,
            "val_com_accuracy": val_com_accuracy,
            "val_logloss": val_logloss,
            "val_com_logloss": val_com_logloss,
        })
        
        # 결과 출력
        print(f"Train Loss : {avg_train_loss:.4f} || Valid Loss : {avg_val_loss:.4f} | Valid Accuracy : {val_accuracy:.4f}%")
    
        # Best model 저장
        if val_logloss < best_logloss:
            best_logloss = val_logloss
            torch.save(net.state_dict(), f'best_model.pth')
            print(f"📦 Best model saved at epoch {epoch+1} (logloss: {val_logloss:.4f})")
            torch.save(decoder1, f'decoder1.pth')
            torch.save(decoder1, f'decoder2.pth')
            torch.save(decoder1, f'decoder3.pth')
    
    
        early_stopping(val_logloss, model)
        if early_stopping.early_stop:
            print(f"🛑 Early stopping triggered at epoch {epoch+1}")
            break



In [None]:
# Training loop with gradient accumulation (PMAL branches H1-H3 plus the
# full-image loss H4, each optimised with a two-phase SAM-style optimizer).
#
# SAM needs two gradients per update: one at the current weights, consumed by
# ``optimizer.first_step`` to build the weight perturbation, and one at the
# perturbed weights, applied by ``optimizer.second_step``.  For accumulation to
# be meaningful BOTH phases must be summed over the SAME window of
# micro-batches; zeroing gradients between micro-batches, or stepping on a
# single micro-batch, silently throws the accumulated gradient away.  The
# helpers below therefore buffer ``accumulation_steps`` batches and run every
# SAM update over the whole window.


def _sam_window_update(micro_batches, compute_loss):
    """Apply one two-phase SAM update with gradients summed over a window.

    Each micro-batch is forwarded once per phase.  Losses are divided by the
    window length before ``backward`` so the accumulated gradient is the mean
    over the window (a short final window is averaged over its real size).

    Args:
        micro_batches: Non-empty list of items passed verbatim to
            ``compute_loss``.
        compute_loss: Callable mapping one micro-batch to
            ``(total_loss, gen_loss)``; ``gen_loss`` may be ``None``.

    Returns:
        ``(total_sum, gen_sum)``: the sum of the phase-1 total losses and the
        sum of the phase-2 generation losses, as Python floats, NOT divided by
        the window length (per-micro-batch values, summed).
    """
    n_micro = len(micro_batches)
    total_sum = 0.0
    gen_sum = 0.0

    # Phase 1: gradient at the current weights, summed over the window.
    enable_running_stats(net)
    optimizer.zero_grad()
    for micro_batch in micro_batches:
        loss, _ = compute_loss(micro_batch)
        (loss / n_micro).backward()
        total_sum += loss.item()
    optimizer.first_step(zero_grad=True)

    # Phase 2: gradient at the perturbed weights over the same micro-batches.
    # NOTE(review): disable_running_stats is defined elsewhere; presumably it
    # freezes BatchNorm running statistics for this second pass — confirm.
    disable_running_stats(net)
    for micro_batch in micro_batches:
        loss, gen_loss = compute_loss(micro_batch)
        (loss / n_micro).backward()
        if gen_loss is not None:
            gen_sum += gen_loss.item()
    optimizer.second_step(zero_grad=True)

    return total_sum, gen_sum


def _make_branch_loss(branch_idx, decoder):
    """Build the PMAL loss for branch ``branch_idx`` (0, 1 or 2).

    ``net(x)`` is unpacked throughout this notebook as
    ``(out1, out2, out3, out_ori, map1, map2, map3)``, so the branch logits are
    at ``branch_idx`` and its map at ``4 + branch_idx``.  The loss is
    CE(branch logits) + CB_loss(decoded image, un-noised ``inputs_gt``)
    + CE(branch logits on the decoded image).
    """
    def branch_loss(micro_batch):
        inputs, inputs_gt, targets = micro_batch
        outputs = net(inputs)
        loss_c = CELoss(outputs[branch_idx], targets).mean()
        inputs_syn = decoder(inputs, outputs[4 + branch_idx])
        loss_g = CB_loss(inputs_syn, inputs_gt)
        loss_c_syn = CELoss(net(inputs_syn)[branch_idx], targets).mean()
        return loss_c + loss_g + loss_c_syn, loss_g

    return branch_loss


def _full_image_loss(micro_batch):
    """H4 loss on the raw images: CE of the three branches plus 2x CE of ``out_ori``."""
    images, targets = micro_batch
    out1, out2, out3, out_ori, _, _, _ = net(images)
    loss = (
        CELoss(out1, targets).mean()
        + CELoss(out2, targets).mean()
        + CELoss(out3, targets).mean()
        + CELoss(out_ori, targets).mean() * 2
    )
    return loss, None


def _train_accumulated_window(window, aug_seq, noise_seq):
    """Run the four SAM updates (branches H1-H3, then H4) over one window.

    Args:
        window: List of ``(images, targets)`` tensors already on ``device``.
        aug_seq: imgaug pipeline producing the branch target images.
        noise_seq: imgaug pipeline applied on top of ``aug_seq``'s output.

    Returns:
        Dict with keys ``"h1"``, ``"h2"``, ``"h3"``, ``"gen"`` and ``"ori"``
        holding per-micro-batch losses summed over the window.
    """
    sums = {"gen": 0.0}
    for branch_idx, decoder in enumerate((decoder1, decoder2, decoder3)):
        # Augment once per micro-batch so both SAM phases see identical inputs.
        prepared = []
        for images, targets in window:
            inputs_gt = img_add_noise(images, aug_seq).to(device)
            inputs = img_add_noise(inputs_gt, noise_seq).to(device)
            prepared.append((inputs, inputs_gt, targets))
        branch_sum, gen_sum = _sam_window_update(
            prepared, _make_branch_loss(branch_idx, decoder)
        )
        sums[f"h{branch_idx + 1}"] = branch_sum
        sums["gen"] += gen_sum
        del prepared
        torch.cuda.empty_cache()

    sums["ori"], _ = _sam_window_update(window, _full_image_loss)
    return sums


if gradient_accumulation:

    # Number of batches whose gradients are summed into one optimizer update.
    accumulation_steps = max(1, int(CFG['accumulation_steps']))

    # The augmentation pipelines do not depend on the batch, so build them once.
    trans_seq_aug = iaa.Sequential(
        [
            iaa.Sometimes(0.2, iaa.Affine(
                scale={"x": (0.8, 1.2), "y": (0.8, 1.2)},
                translate_percent={"x": (-0.2, 0.2), "y": (-0.2, 0.2)},
                rotate=(-15, 15),
                shear=(-15, 15),
                order=[0, 1],
                cval=(0, 1),
                mode=ia.ALL,
            )),
            iaa.Sometimes(0.5, iaa.GaussianBlur((0, 3.0))),
        ],
        random_order=True,
    )
    trans_seq = iaa.Sequential(
        [iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.05), per_channel=0.5)],
        random_order=True,
    )

    for epoch in range(CFG['EPOCHS']):
        # ---- Train ----
        net.train()
        # The cosine schedule depends only on the epoch.
        for group_idx, group in enumerate(optimizer.param_groups):
            group['lr'] = cosine_anneal_schedule(epoch, CFG['EPOCHS'], lr[group_idx])

        num_batches = len(train_loader)
        epoch_sums = {"h1": 0.0, "h2": 0.0, "h3": 0.0, "gen": 0.0, "ori": 0.0}
        window = []
        for batch_idx, (images, targets) in enumerate(tqdm(train_loader, desc=f"[Epoch {epoch+1}/{CFG['EPOCHS']}] Training")):
            window.append((images.to(device), targets.to(device)))
            # Keep buffering until the window is full or the epoch runs out.
            if len(window) < accumulation_steps and batch_idx + 1 < num_batches:
                continue
            for key, value in _train_accumulated_window(window, trans_seq_aug, trans_seq).items():
                epoch_sums[key] += value
            window = []

        # Average per batch (not per optimizer step) so the logged numbers are
        # on the same scale as the non-accumulation branch.
        denom = max(num_batches, 1)
        avg_train_loss1 = epoch_sums["h1"] / denom
        avg_train_loss2 = epoch_sums["h2"] / denom
        avg_train_loss3 = epoch_sums["h3"] / denom
        avg_train_loss4 = epoch_sums["gen"] / denom
        avg_train_loss5 = epoch_sums["ori"] / denom
        avg_train_loss = avg_train_loss1 + avg_train_loss2 + avg_train_loss3 + avg_train_loss5

        # ---- Validation ----
        net.eval()
        val_loss = 0.0
        correct = 0
        correct_com = 0
        total = 0
        all_probs = []
        all_probs_com = []
        all_labels = []

        with torch.no_grad():
            for images, labels in tqdm(val_loader, desc=f"[Epoch {epoch+1}/{CFG['EPOCHS']}] Validation"):
                images, labels = images.to(device), labels.to(device)

                output_1, output_2, output_3, output_ORI, _, _, _ = net(images)

                # Keep the ensemble logits on the same device as ``labels``:
                # comparing a CPU tensor with a CUDA tensor raises.
                outputs_com = output_1 + output_2 + output_3 + output_ORI
                # ``.mean()`` keeps this valid whether or not CELoss reduces
                # (the training losses always call ``.mean()`` on it).
                val_loss += CELoss(output_ORI, labels).mean().item()

                # Accuracy
                _, preds = torch.max(output_ORI, 1)
                _, preds_com = torch.max(outputs_com, 1)
                correct += (preds == labels).sum().item()
                correct_com += (preds_com == labels).sum().item()
                total += labels.size(0)

                # Probabilities for the log-loss metrics
                all_probs.extend(F.softmax(output_ORI, dim=1).cpu().numpy())
                all_probs_com.extend(F.softmax(outputs_com, dim=1).cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        avg_val_loss = val_loss / len(val_loader)
        val_accuracy = 100 * correct / total
        val_com_accuracy = 100 * correct_com / total
        val_logloss = log_loss(all_labels, all_probs, labels=list(range(len(class_names))))
        val_com_logloss = log_loss(all_labels, all_probs_com, labels=list(range(len(class_names))))

        del images, labels, output_1, output_2, output_3, output_ORI, outputs_com
        torch.cuda.empty_cache()

        # wandb
        wandb.log({
            "train_loss": avg_train_loss,
            "train_loss1": avg_train_loss1,
            "train_loss2": avg_train_loss2,
            "train_loss3": avg_train_loss3,
            "train_loss4": avg_train_loss4,
            "train_loss5": avg_train_loss5,
            "val_loss": avg_val_loss,
            "val_accuracy": val_accuracy,
            "val_com_accuracy": val_com_accuracy,
            "val_logloss": val_logloss,
            "val_com_logloss": val_com_logloss,
        })

        # Print epoch results
        print(f"Train Loss : {avg_train_loss:.4f} || Valid Loss : {avg_val_loss:.4f} | Valid Accuracy : {val_accuracy:.4f}%")

        # Save the best model (selected on the log-loss of output_ORI).
        if val_logloss < best_logloss:
            best_logloss = val_logloss
            torch.save(net.state_dict(), 'best_model.pth')
            print(f"📦 Best model saved at epoch {epoch+1} (logloss: {val_logloss:.4f})")
            torch.save(decoder1, 'decoder1.pth')
            torch.save(decoder2, 'decoder2.pth')
            torch.save(decoder3, 'decoder3.pth')

        # Checkpoint the trained wrapper ``net`` (the module saved above), not
        # the bare backbone ``model``.
        early_stopping(val_logloss, net)
        if early_stopping.early_stop:
            print(f"🛑 Early stopping triggered at epoch {epoch+1}")
            break

# Inference

In [None]:
# Unlabelled test split with the evaluation transform. shuffle=False keeps the
# dataset order, which matters because the submission cell assigns the
# predictions to the template rows by position.
test_dataset = CustomImageDataset(root_dir=test_root, transform=val_transform, is_test=True)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=CFG['BATCH_SIZE'],
    shuffle=False,
)

In [None]:
# Load the saved model.
# ``best_model.pth`` is written from ``net.state_dict()`` (the full
# Network_Wrapper), so rebuild the wrapper first and load the weights into it;
# loading them into the bare TResNet would fail on mismatched keys.
model = TResnetL368(model_params)
net_layers = list(model.children())
classifier = net_layers[1:3]
net_layers = list(net_layers[0].children())
net = Network_Wrapper(net_layers, len(class_names), classifier)
net.load_state_dict(torch.load('best_model.pth', map_location=device))
net.to(device)

# Inference
net.eval()
results = []

with torch.no_grad():
    for images in test_loader:
        images = images.to(device)
        # net returns a 7-tuple; output_ORI is the head whose log-loss was used
        # to select the best checkpoint during validation.
        _, _, _, output_ORI, _, _, _ = net(images)
        probs = F.softmax(output_ORI, dim=1)

        # One {class_name: probability} row per test image.
        for prob in probs.cpu().tolist():
            results.append(dict(zip(class_names, prob)))

pred = pd.DataFrame(results)

# Submission

In [None]:
submission = pd.read_csv(
    '/kaggle/input/car-classification/sample_submission.csv',
    encoding='utf-8-sig',
)

# Every column after the leading 'ID' column is a class. Reorder the
# prediction frame to the template's column order before copying it in.
# NOTE(review): rows are matched by position, so test_loader must iterate in
# the same order as the template's IDs.
class_columns = submission.columns[1:]
pred = pred.loc[:, class_columns]
submission[class_columns] = pred.to_numpy()

submission.to_csv(
    'baseline_submission.csv',
    index=False,
    encoding='utf-8-sig',
)