In [2]:
!pip install -r requirements.txt
#pip install pillow numpy pandas seaborn matplotlib tqdm scikit-learn torch torchvision torchmetrics timm wandb albumentations

Collecting torchmetrics (from -r requirements.txt (line 10))
  Obtaining dependency information for torchmetrics from https://files.pythonhosted.org/packages/29/1b/b38033e61c28e52dde7bd459df6567c04c127ee153722c73b9acd0fe550b/torchmetrics-1.4.1-py3-none-any.whl.metadata
  Downloading torchmetrics-1.4.1-py3-none-any.whl.metadata (20 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics->-r requirements.txt (line 10))
  Obtaining dependency information for lightning-utilities>=0.8.0 from https://files.pythonhosted.org/packages/ea/d5/ed204bc738672c17455019b5e0c7c8d1effb0ea17707150ca50336298ca0/lightning_utilities-0.11.6-py3-none-any.whl.metadata
  Downloading lightning_utilities-0.11.6-py3-none-any.whl.metadata (5.2 kB)
Downloading torchmetrics-1.4.1-py3-none-any.whl (866 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m866.2/866.2 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hDownloading lightning_utilities-0.11.6-py3-none-any.whl (26 kB)
In

라이브러리 목록

In [148]:
import os
import PIL
import json
import math
import timm
import time
import wandb
import numpy as np
import pandas as pd
import seaborn as sns
import albumentations as album
import matplotlib.pyplot as plt

from tqdm import tqdm
from collections import Counter
from IPython.display import display
from sklearn.metrics import confusion_matrix
from albumentations.pytorch import ToTensorV2
from tqdm.notebook import tqdm as tqdm_notebook
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import f1_score, accuracy_score

# PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torchvision import transforms
from torchmetrics import F1Score, classification
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import Dataset, DataLoader, random_split

# <함수 준비> 일괄 실행 가능!!!

쿠다 비워내는 함수

In [107]:
def print_gpu_memory():
    print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
    print(f"Cached: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")

커스텀 Config 구현

In [108]:
class Config:
    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)
    
    def __repr__(self):
        return f"Config({', '.join(f'{k}={v}' for k, v in self.__dict__.items())})"
    
    def __iter__(self):
        return iter(self.__dict__.items())

    def __getitem__(self, key):
        return getattr(self, key)

    def get(self, key, default=None):
        return getattr(self, key, default)
    
    def update(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

데이터셋

In [109]:
class ImageDataset(Dataset):
    def __init__(self, csv, path, transform=None):
        self.df = pd.read_csv(csv).values
        self.path = path
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        name, target = self.df[idx]
        img = np.array(PIL.Image.open(os.path.join(self.path, name)))
        if self.transform:
            img = self.transform(image=img)['image']
        return img, int(target)  # 타겟을 정수로 반환

def get_transform(img_size, mean, std):
    return album.Compose([
        album.LongestMaxSize(max_size=img_size),
        album.PadIfNeeded(min_height=img_size, min_width=img_size, border_mode=0, value=(0,0,0)),
        album.Normalize(mean=mean, std=std),
        ToTensorV2(),
    ])

def create_data_set(base_path, img_size= 448, mean = (0.485, 0.456, 0.406), std= (0.229, 0.224, 0.225)):
    # 경로 설정
    train_csv = os.path.join(base_path, "aug_train.csv")
    train_dir = os.path.join(base_path, "aug_train")
    #train_csv = os.path.join(base_path, "train.csv")
    #train_dir = os.path.join(base_path, "train")

    test_csv = os.path.join(base_path, "true_label.csv")
    test_dir = os.path.join(base_path, "test")

    # Dataset 정의
    trn_dataset = ImageDataset(
        train_csv, 
        train_dir, 
        transform=get_transform(img_size, mean, std)
    )
    tst_dataset = ImageDataset(
        test_csv, 
        test_dir, 
        transform=get_transform(img_size, mean, std)
    )

    print(f"Training dataset size: {len(trn_dataset)}")
    print(f"Test dataset size: {len(tst_dataset)}")
    return trn_dataset, tst_dataset

In [110]:
def to_device(x, device):
    return x.to(device) if isinstance(x, torch.Tensor) else x

class DeviceDataLoader:
    def __init__(self, dataloader, device):
        self.dataloader = dataloader
        self.device = device

    def __iter__(self):
        for batch in self.dataloader:
            yield tuple(to_device(x, self.device) for x in batch)

    def __len__(self):
        return len(self.dataloader)

메트릭

In [111]:
class CustomMetrics(torch.nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.f1_macro = F1Score(task="multiclass", num_classes=num_classes, average='macro')
        self.accuracy_per_class = classification.MulticlassAccuracy(num_classes=num_classes, average=None)

    def forward(self, preds, target):
        # preds: (batch_size, num_classes)
        # target: (batch_size,)
        if preds.dim() == 2:
            preds = preds.argmax(dim=1)        
        # F1 Macro 계산
        f1_macro = self.f1_macro(preds, target)

        # 각 클래스별 Accuracy 계산
        accuracies = self.accuracy_per_class(preds, target)

        # 가장 낮은 Accuracy 찾기
        lowest_accuracy = torch.min(accuracies)

        return {
            "f1_macro": f1_macro,
            "lowest_class_accuracy": lowest_accuracy,
            "accuracies_per_class": accuracies
        }

TTA 모듈 -> 차후 분석위한 분석툴

In [112]:
class TTA:
    def __init__(self, model, config, threshold = 0.95, temperature = 0.5,  tta_transforms = [
            transforms.Lambda(lambda x: x),  # 원본
            transforms.Lambda(lambda x: transforms.functional.rotate(x, 90)),  # 90도 회전
            transforms.Lambda(lambda x: transforms.functional.rotate(x, 180)),  # 180도 회전
            transforms.Lambda(lambda x: transforms.functional.rotate(x, 270)),  # 270도 회전
            transforms.Lambda(lambda x: transforms.functional.hflip(x)),  # 수평 뒤집기
            transforms.Lambda(lambda x: transforms.functional.vflip(x)),  # 수직 뒤집기
            transforms.Compose([  # 90도 회전 + 수평 뒤집기
                transforms.Lambda(lambda x: transforms.functional.rotate(x, 90)),
                transforms.Lambda(lambda x: transforms.functional.hflip(x))
            ]),
            transforms.Compose([  # 270도 회전 + 수평 뒤집기
                transforms.Lambda(lambda x: transforms.functional.rotate(x, 270)),
                transforms.Lambda(lambda x: transforms.functional.hflip(x))
            ])
        ]):
        self.model = model
        self.config = config
        self.threshold = threshold
        self.temperature = temperature
        self.tta_transforms = tta_transforms
        self.device = next(model.parameters()).device

    def tta_inference(self, image):
        self.model.eval()
        probs = []
        for transform in self.tta_transforms:
            augmented = transform(image)
            output = self.model(augmented.unsqueeze(0).to(self.device))
            prob = F.softmax(output, dim=1)
            probs.append(prob.squeeze().cpu())
        return torch.stack(probs)

    def aggregate_predictions(self, probs, method):
        if method == 'mean':
            result = probs.mean(dim=0)
        elif method == 'max':
            result = probs.max(dim=0)[0]
        elif method == 'temp_sharpen':
            result = self._temp_sharpen(probs, self.temperature)
        elif method == 'mode':
            result = self._mode(probs)
        elif method == 'modethreshold':
            result = self._mode_threshold(probs, self.threshold)
        elif method == 'no_tta':
            result = probs[0]
        else:
            raise ValueError(f"Unsupported TTA aggregation method: {method}")
    
        # 결과를 softmax 처리
        return F.softmax(result, dim=0)

    def _temp_sharpen(self, probs, temperature = 0.5):
        sharpened = probs ** (1 / temperature)
        return sharpened.mean(dim=0) / sharpened.mean(dim=0).sum()

    def _mode(self, probs):
        labels = probs.argmax(dim=1)
        mode = Counter(labels.tolist()).most_common(1)[0][0]
        result = torch.zeros(probs.shape[1])
        result[mode] = 1
        return result

    def _mode_threshold(self, probs, threshold = 0.95):
        high_conf = (probs > threshold).any(dim=1)
        if high_conf.any():
            high_conf_labels = probs[high_conf].argmax(dim=1)
            mode = Counter(high_conf_labels.tolist()).most_common(1)[0][0]
        else:
            mode = Counter(probs.argmax(dim=1).tolist()).most_common(1)[0][0]
        result = torch.zeros(probs.shape[1])
        result[mode] = 1
        return result
        
    def _ensemble_predictions(self, all_preds):
        df = pd.DataFrame(all_preds)
        mode_preds = df.mode(axis=1)
        
        ensemble_preds = []
        for i in range(len(df)):
            if len(mode_preds.iloc[i].dropna()) == 1:
                # 동점이 없는 경우
                ensemble_preds.append(mode_preds.iloc[i, 0])
            else:
                # 동점이 있는 경우
                mode_group = set(mode_preds.iloc[i].dropna())
                if df.loc[i, 'max'] in mode_group:
                    ensemble_preds.append(df.loc[i, 'max'])
                else:
                    # max가 모드 그룹에 없을 경우 우선순위에 따라 선택
                    for method in ['temp_sharpen', 'modethreshold', 'mode', 'mean', 'no_tta']:
                        if df.loc[i, method] in mode_group:
                            ensemble_preds.append(df.loc[i, method])
                            break
                    else:
                        # 모든 방법이 모드 그룹에 없는 경우 그냥 max(거의 일어나지 않을 것임)
                        ensemble_preds.append(mode_preds.iloc[i, 'max'])
        return ensemble_preds
    
    def _save_dataframe_as_image(self, df, filename='df_accuracy_table.png'):
        fig, ax = plt.subplots(figsize=(12, len(df) * 0.5 + 1))  # 행 수에 따라 세로 크기 조정
        ax.axis('tight')
        ax.axis('off')
        
        table = ax.table(cellText=df.values,
                         colLabels=df.columns,
                         rowLabels=df.index,
                         cellLoc='center',
                         loc='center')
        
        table.auto_set_font_size(False)
        table.set_fontsize(10)
        table.scale(1.2, 1.2)
        plt.savefig(filename, bbox_inches='tight', dpi=300)
        plt.close()
        print(f"Table saved as {filename}")   

    def tta_predict(self, loader):
        methods = ['no_tta', 'mean', 'max', 'temp_sharpen', 'mode', 'modethreshold']
        all_preds = {method: [] for method in methods}
        all_probs = {method: [] for method in methods}
        all_targets = []

        self.model.eval()
        with torch.no_grad():
            for images, targets in tqdm(loader, desc="Generating predictions"):
                images = images.to(self.device)
                all_targets.append(targets)
                for image in images:
                    probs = self.tta_inference(image)
                    for method in methods:
                        agg_prob = self.aggregate_predictions(probs, method)
                        all_preds[method].append(agg_prob.argmax().item())
                        all_probs[method].append(agg_prob.cpu().numpy())

        all_targets_tensor = torch.cat(all_targets)
        ensemble_preds = self._ensemble_predictions(all_preds)
        all_preds['ensemble'] = ensemble_preds
        # Note: We don't add 'ensemble' to all_probs

        return all_preds, all_probs, all_targets_tensor

    def tta_evaluate(self, loader):
        all_preds, all_probs, all_targets = self.tta_predict(loader)
        
        results = {}
        metrics = CustomMetrics(self.config.num_classes)
        
        for method in all_preds.keys():
            preds = torch.tensor(all_preds[method])
            targets = all_targets.clone().detach()
            metric_results = metrics(preds, targets)
            results[method] = {
                'preds': preds,
                'probs': all_probs.get(method) if method != 'ensemble' else None,
                'f1_macro': metric_results['f1_macro'].item(),
                'lowest_class_accuracy': metric_results['lowest_class_accuracy'].item(),
                'accuracies_per_class': metric_results['accuracies_per_class'].tolist()
            }
        
        return results, all_targets
    
    def save_predictions(self, predictions, probs, output_dir='submissions'):
        os.makedirs(output_dir, exist_ok=True)
        sample_submission_path = os.path.join(self.config.base_path, "sample_submission.csv")
        sample_submission = pd.read_csv(sample_submission_path)
        
        for method, preds in predictions.items():
            # 예측 결과 저장
            submission = sample_submission.copy()
            submission['target'] = preds
            try: 
                pred_filename = f'{self.config.group_name}_submission_tta_{method}.csv'
            except:
                pred_filename = f'data_orig_submission_tta_{method}.csv'
            pred_path =  f'{output_dir}/{pred_filename}'
            submission.to_csv(pred_path, index=False)
            print(f"Saved predictions for method: {method}")
            
            # 확률 저장
            if method != 'ensemble' and probs is not None:
                prob_df = pd.DataFrame(probs[method], columns=[f'prob_{i}' for i in range(len(probs[method][0]))])
                try:
                    prob_filename = f'{self.config.group_name}_labelprob_tta_{method}.csv'
                except:
                    prob_filename = f'data_orig_labelprob_tta_{method}.csv'
                prob_dir = os.path.join(output_dir, 'probs')
                os.makedirs(prob_dir, exist_ok=True)
                prob_path = f'{prob_dir}/{prob_filename}'
                prob_df.to_csv(prob_path, index=False)
                print(f"Saved label probabilities for method: {method}")

    def automatic_tta_submission(self, test_loader, output_dir='submissions'):
        predictions, probs, dummy_target = self.tta_predict(test_loader)
        self.save_predictions(predictions, probs)   

    def analyze_predictions(self, preds, targets, method ='no_tta', chart_dir= 'charts/data_orig'):
        os.makedirs(chart_dir, exist_ok=True)
        preds = preds.clone().detach()
        targets = targets.clone().detach()
        accuracy_per_class = []
        for class_idx in range(self.config.num_classes):
            class_mask = (targets == class_idx)
            if class_mask.sum() > 0:
                class_accuracy = (preds[class_mask] == targets[class_mask]).float().mean().item()
            else:
                class_accuracy = 0.0
            accuracy_per_class.append(class_accuracy)

        plt.figure(figsize=(12, 6))
        plt.bar(range(len(accuracy_per_class)), accuracy_per_class)
        plt.title(f'Accuracy per Class - {method}')
        plt.xlabel('Class')
        plt.ylabel('Accuracy')
        try: 
            plt.xticks(range(len(self.config.class_names)), self.config.class_names, rotation=90)
        except:
            plt.xticks(range(self.config.num_classes))
        plt.tight_layout()
        plt.savefig(os.path.join(chart_dir, f'accuracy_per_class_{method}.png'))
        plt.close()
        
        conf_matrix = confusion_matrix(targets, preds)
        plt.figure(figsize=(12, 10))
        sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues')
        plt.title(f'Confusion Matrix - {method}')
        plt.xlabel('Predicted')
        plt.ylabel('True')
        try: 
            plt.xticks(range(len(self.config.class_names)), self.config.class_names, rotation=90)
            plt.yticks(range(len(self.config.class_names)), self.config.class_names, rotation=0)
        except:
            pass
        plt.tight_layout()
        plt.savefig(os.path.join(chart_dir, f'confusion_matrix_{method}.png'))
        plt.close()
        
        return accuracy_per_class

    def analyze_tta(self, val_loader):
        results, targets = self.tta_evaluate(val_loader)
        targets_tensor = torch.as_tensor(targets)

        accuracy_per_class_all = {}
        
        # 차트 저장 경로 설정
        try:
            disagreement_dir = f'disagreement_samples/{self.config.group_name}'
            chart_dir = f'charts/{self.config.group_name}'
        except AttributeError:
            chart_dir = os.path.join('charts', 'data_orig')
            disagreement_dir = os.path.join('disagreement_samples','data_orig')
        os.makedirs(disagreement_dir, exist_ok=True)
        os.makedirs(chart_dir, exist_ok=True)
        
        for method in results.keys():
            print(f'Summary Results for the aggregate method: {method}')
            accuracy_per_class = self.analyze_predictions(results[method]['preds'], targets, method, chart_dir)
            accuracy_per_class_all[method] = accuracy_per_class
        
        # Find images with disagreement
        disagreement_mask = torch.zeros(len(targets_tensor), dtype=torch.bool, device=targets_tensor.device)
        for method in results.keys():
            if method != 'ensemble':
                preds_tensor = torch.as_tensor(results[method]['preds'], device=targets_tensor.device)
                disagreement_mask |= (preds_tensor != targets_tensor)

        disagreement_indices = disagreement_mask.nonzero().squeeze()
        
        # Count disagreements per class
        disagreement_counts = torch.bincount(targets.clone().detach()[disagreement_indices], minlength=self.config.num_classes)
        
        plt.figure(figsize=(12, 6))
        plt.bar(range(self.config.num_classes), disagreement_counts)
        plt.title('Disagreements per Class')
        plt.xlabel('Class')
        plt.ylabel('Number of Disagreements')
        try: 
            plt.xticks(range(len(self.config.class_names)), self.config.class_names, rotation=90)
        except:
            plt.xticks(range(self.config.num_classes))
        plt.tight_layout()
        plt.savefig(os.path.join(chart_dir, 'disagreements_per_class.png'))
        plt.close()
        
        # Save disagreement images
        dataset = val_loader.dataset
        for idx in disagreement_indices[:min(len(disagreement_indices), 100)]:  # Save up to 100 samples
            try:
                img, _ = dataset[idx]
                img_np = img.permute(1, 2, 0).numpy()
                img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min())  # Normalize to [0, 1]
                img_path = os.path.join(disagreement_dir, f'disagreement_sample_{idx}.png')
                plt.imsave(img_path, img_np)
            except Exception as e:
                print(f"Error saving disagreement sample {idx}: {str(e)}")
        print(f"Saved disagreement samples to {disagreement_dir}")
        
        # Create a new dataset with only disagreement samples
        disagreement_dataset = torch.utils.data.Subset(dataset, disagreement_indices)
        
        # Create DataFrame for accuracy per class
        try: 
            class_names = self.config.class_names
        except:
            class_names = [f"Class {i}" for i in range(self.config.num_classes)]
        df_accuracy = pd.DataFrame(accuracy_per_class_all).T
        df_accuracy.columns = class_names
        
        # Add average accuracy row
        df_accuracy.loc['Average'] = df_accuracy.mean()
        
        # Round all values to 4 decimal places
        df_accuracy = df_accuracy.round(4)
        
        print("\nAccuracy per class for all methods:")
        display(df_accuracy)
        # Save the dataframe as an image
        self._save_dataframe_as_image(df_accuracy, os.path.join(chart_dir, 'accuracy_per_class_all_methods.png'))

        return disagreement_dataset, df_accuracy

Mixup & Cutmix

In [113]:
def mixup_criterion(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

def mixup_data(x, y, alpha=1.0):
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    index = torch.randperm(batch_size).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

In [114]:
def cutmix_criterion(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

def rand_bbox(size, lam):
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)  # np.int() 대신 int() 사용
    cut_h = int(H * cut_rat)  # np.int() 대신 int() 사용

    # uniform
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

def cutmix_data(x, y, alpha=1.0):
    lam = np.random.beta(alpha, alpha)
    batch_size = x.size()[0]
    index = torch.randperm(batch_size).to(x.device)

    y_a, y_b = y, y[index]
    bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam)
    x[:, :, bbx1:bbx2, bby1:bby2] = x[index, :, bbx1:bbx2, bby1:bby2]
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size()[-1] * x.size()[-2]))

    return x, y_a, y_b, lam

커스텀 모델 -> 수정, 보완 가능

In [115]:
class CustomModel(nn.Module):
    def __init__(self, config):
        super(CustomModel, self).__init__()
        self.config = config
        self.backbone = timm.create_model(
            config.model_name,
            pretrained=config.pretrained,
            num_classes=0,
            in_chans=3
        )
        
        if hasattr(self.backbone, 'num_features'):
            num_features = self.backbone.num_features
        elif hasattr(self.backbone, 'fc'):
            num_features = self.backbone.fc.in_features
        else:
            raise ValueError("Unable to determine number of features in backbone")
        
        self.classifier = nn.Sequential(
            nn.Linear(num_features, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, config.num_classes)
        )

        self.softmax = nn.Softmax(dim=1)        
        self.metrics = CustomMetrics(num_classes=config.num_classes).to(config.device)

    def forward(self, x):
        features = self.backbone(x)
        out = self.classifier(features)
        return out

    def training_step(self, batch, loss_func):
        x, y = batch
        rand_num = np.random.rand()
        if rand_num < self.config.mixup_prob:
            mixed_x, y_a, y_b, lam = mixup_data(x, y, self.config.mixup_alpha)
            logits = self(mixed_x)
            loss = mixup_criterion(loss_func, logits, y_a, y_b, lam)
        elif rand_num < self.config.mixup_prob + self.config.cutmix_prob:
            mixed_x, y_a, y_b, lam = cutmix_data(x, y, self.config.cutmix_alpha)
            logits = self(mixed_x)
            loss = cutmix_criterion(loss_func, logits, y_a, y_b, lam)
        else:
            logits = self(x)
            loss = loss_func(logits, y)
        preds = logits.argmax(dim=1)
        return loss, preds

    def validation_step(self, batch, loss_func):
        x, y = batch
        logits = self(x)
        preds = logits.argmax(dim=1)
        loss = loss_func(logits, y)
        return loss, preds

    def test_step(self, batch):
        x, labels  = batch
        logits = self(x)
        prob = self.softmax(logits)
        return prob, labels 

손실함수

In [116]:
class SmoothFocalLoss(nn.Module):
    def __init__(self, num_classes, alpha=None, gamma=2, smoothing=0.1):
        super().__init__()
        self.num_classes = num_classes
        self.gamma = gamma
        self.smoothing = smoothing
        
        if alpha is None:
            self.alpha = torch.ones(num_classes)
        else:
            self.alpha = torch.tensor(alpha)
        
        self.alpha = self.alpha / self.alpha.sum()
        
    def forward(self, inputs, targets):
        inputs = inputs.float()
        
        # 타겟이 클래스 인덱스인 경우 원-핫 인코딩으로 변환
        if targets.dim() == 1:
            targets = F.one_hot(targets, num_classes=self.num_classes).float()
        
        # 라벨 스무딩 적용
        targets_smooth = (1 - self.smoothing) * targets + self.smoothing / self.num_classes
        
        log_probs = F.log_softmax(inputs, dim=1)
        loss = -targets_smooth * log_probs
        
        pt = torch.exp(-loss)
        focal_loss = (1 - pt)**self.gamma * loss
        
        if self.alpha is not None:
            alpha = self.alpha.to(inputs.device)
            focal_loss = alpha.unsqueeze(0) * focal_loss
        
        return focal_loss.sum(dim=1).mean()

In [117]:
class FocalLoss(nn.Module):
    def __init__(self, num_classes, alpha=None, gamma=2):
        super().__init__()
        self.num_classes = num_classes
        self.gamma = gamma
        
        if alpha is None:
            self.alpha = torch.ones(num_classes)
        else:
            self.alpha = torch.tensor(alpha)
        
        self.alpha = self.alpha / self.alpha.sum()
        
    def __call__(self, inputs, targets):
        inputs = inputs.float()
        
        # 타겟이 클래스 인덱스인 경우 원-핫 인코딩으로 변환
        if targets.dim() == 1:
            targets = F.one_hot(targets, num_classes=self.num_classes).float()
        
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt)**self.gamma * ce_loss
        
        if self.alpha is not None:
            alpha = self.alpha.to(inputs.device)
            focal_loss = alpha.unsqueeze(0) * focal_loss
        
        return focal_loss.mean()

In [118]:
class CELoss(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes
        
    def __call__(self, inputs, targets):
        inputs = inputs.float()
        
        # 타겟이 클래스 인덱스인 경우 원-핫 인코딩으로 변환
        if targets.dim() == 1:
            targets = F.one_hot(targets, num_classes=self.num_classes).float()
        
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        
        return ce_loss.mean()

In [119]:
class SmoothCELoss(nn.Module):
    def __init__(self, num_classes, smoothing=0.1):
        super().__init__()
        self.num_classes = num_classes
        self.smoothing = smoothing
        
    def __call__(self, inputs, targets):
        inputs = inputs.float()
        
        # 타겟이 클래스 인덱스인 경우 원-핫 인코딩으로 변환
        if targets.dim() == 1:
            targets = F.one_hot(targets, num_classes=self.num_classes).float()
        
        # 라벨 스무딩 적용
        targets_smooth = (1 - self.smoothing) * targets + self.smoothing / self.num_classes
        
        log_probs = F.log_softmax(inputs, dim=1)
        loss = -targets_smooth * log_probs
        
        return loss.sum(dim=1).mean()

In [120]:
def get_loss_function(config):
    if config.loss == 'ce':
        return CELoss(num_classes=config.num_classes)
    elif config.loss == 'focal':
        return FocalLoss(num_classes=config.num_classes, alpha=config.focal_alpha, gamma=config.focal_gamma)
    elif config.loss == 'smoothce':
        return SmoothCELoss(num_classes=config.num_classes, smoothing=config.smoothing)
    elif config.loss == 'smoothfocal':
        return SmoothFocalLoss(num_classes=config.num_classes, alpha=config.focal_alpha, gamma=config.focal_gamma, smoothing=config.smoothing)
    else:
        raise ValueError(f"Unsupported loss function: {config.loss}")

옵티마이저

In [121]:
class Ranger(optim.Optimizer): # from torch.optim.optimizer import Optimizer
    def __init__(self, params, lr=1e-3,                       # lr
                 alpha=0.5, k=6, N_sma_threshhold=5,           # Ranger options
                 betas=(.95, 0.999), eps=1e-5, weight_decay=0,  # Adam options
                 # Gradient centralization on or off, applied to conv layers only or conv + fc layers
                 use_gc=True, gc_conv_only=False
                 ):

        # parameter checks
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f'Invalid slow update rate: {alpha}')
        if not 1 <= k:
            raise ValueError(f'Invalid lookahead steps: {k}')
        if not lr > 0:
            raise ValueError(f'Invalid Learning Rate: {lr}')
        if not eps > 0:
            raise ValueError(f'Invalid eps: {eps}')


        # prep defaults and init torch.optim base
        defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas,
                        N_sma_threshhold=N_sma_threshhold, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)

        # adjustable threshold
        self.N_sma_threshhold = N_sma_threshhold

        # look ahead params

        self.alpha = alpha
        self.k = k

        # radam buffer for state
        self.radam_buffer = [[None, None, None] for ind in range(10)]

        # gc on or off
        self.use_gc = use_gc

        # level of gradient centralization
        self.gc_gradient_threshold = 3 if gc_conv_only else 1

        print(
            f"Ranger optimizer loaded. \nGradient Centralization usage = {self.use_gc}")
        if (self.use_gc and self.gc_gradient_threshold == 1):
            print(f"GC applied to both conv and fc layers")
        elif (self.use_gc and self.gc_gradient_threshold == 3):
            print(f"GC applied to conv layers only")

    def __setstate__(self, state):
        print("set state called")
        super(Ranger, self).__setstate__(state)

    def step(self, closure=None):
        loss = None

        # Evaluate averages and grad, update param tensors
        for group in self.param_groups:

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()

                if grad.is_sparse:
                    raise RuntimeError(
                        'Ranger optimizer does not support sparse gradients')

                p_data_fp32 = p.data.float()

                state = self.state[p]  # get state dict for this param

                if len(state) == 0:  

                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)

                    # look ahead weight storage now in state dict
                    state['slow_buffer'] = torch.empty_like(p.data)
                    state['slow_buffer'].copy_(p.data)

                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(
                        p_data_fp32)

                # begin computations
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                # GC operation for Conv layers and FC layers
                if grad.dim() > self.gc_gradient_threshold:
                    grad.add_(-grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True))

                state['step'] += 1

                # compute variance mov avg
                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                # compute mean moving avg
                exp_avg.mul_(beta1).add_(1 - beta1, grad)

                buffered = self.radam_buffer[int(state['step'] % 10)]

                if state['step'] == buffered[0]:
                    N_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2 ** state['step']
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * \
                        state['step'] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma
                    if N_sma > self.N_sma_threshhold:
                        step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (
                            N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    else:
                        step_size = 1.0 / (1 - beta1 ** state['step'])
                    buffered[2] = step_size

                if group['weight_decay'] != 0:
                    p_data_fp32.add_(-group['weight_decay']
                                     * group['lr'], p_data_fp32)

                # apply lr
                if N_sma > self.N_sma_threshhold:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(-step_size *
                                         group['lr'], exp_avg, denom)
                else:
                    p_data_fp32.add_(-step_size * group['lr'], exp_avg)

                p.data.copy_(p_data_fp32)

                # integrated look ahead...
                # we do it at the param level instead of group level
                if state['step'] % group['k'] == 0:
                    # get access to slow param tensor
                    slow_p = state['slow_buffer']
                    # (fast weights - slow weights) * alpha
                    slow_p.add_(self.alpha, p.data - slow_p)
                    # copy interpolated weights to RAdam param tensor
                    p.data.copy_(slow_p)

        return loss

In [122]:
def get_optimizer(model_params, config):
    if config.optimizer == 'SGD':
        return optim.SGD(
            model_params,
            lr=config.learning_rate,
            momentum=config.momentum,
            weight_decay=config.weight_decay
        )
    elif config.optimizer == 'Adam':
        return optim.Adam(
            model_params,
            lr=config.learning_rate,
            betas=(config.adam_beta1, config.adam_beta2),
            eps=config.adam_epsilon,
            weight_decay=config.weight_decay
        )
    elif config.optimizer == 'AdamW':
        return optim.AdamW(
            model_params,
            lr=config.learning_rate,
            betas=(config.adam_beta1, config.adam_beta2),
            eps=config.adam_epsilon,
            weight_decay=config.weight_decay
        )
    elif config.optimizer == 'RMSprop':
        return optim.RMSprop(
            model_params,
            lr=config.learning_rate,
            alpha=config.rmsprop_alpha,
            eps=config.rmsprop_epsilon,
            weight_decay=config.weight_decay,
            momentum=config.momentum
        )
    elif config.optimizer == 'Adadelta':
        return optim.Adadelta(
            model_params,
            lr=config.learning_rate,
            rho=config.adadelta_rho,
            eps=config.adadelta_epsilon,
            weight_decay=config.weight_decay
        )
    elif config.optimizer == 'Ranger':
        return Ranger(
            model_params,
            lr=config.learning_rate,
            alpha=config.ranger_alpha,
            k=config.ranger_k,
            N_sma_threshhold=config.ranger_N_sma_threshhold,
            betas=(config.ranger_beta1, config.ranger_beta2),
            eps=config.ranger_eps,
            weight_decay=config.weight_decay,
            use_gc=config.ranger_use_gc,
            gc_conv_only=config.ranger_gc_conv_only
        )
    else:
        raise ValueError(f"Unsupported optimizer: {config.optimizer}")

스케쥴러

In [123]:
def get_scheduler(optimizer, config):
    if config.scheduler == 'LambdaLR':
        return optim.lr_scheduler.LambdaLR(
            optimizer=optimizer,
            lr_lambda=lambda epoch: config.lambda_factor ** epoch,
            last_epoch=config.last_epoch,
            verbose=config.verbose
        )
    elif config.scheduler == 'MultiplicativeLR':
        return optim.lr_scheduler.MultiplicativeLR(
            optimizer=optimizer,
            lr_lambda=lambda epoch: config.lambda_factor ** epoch
        )
    elif config.scheduler == 'StepLR':
        return optim.lr_scheduler.StepLR(
            optimizer, 
            step_size=config.step_size, 
            gamma=config.gamma
        )
    elif config.scheduler == 'CosineAnnealingLR':
        return optim.lr_scheduler.CosineAnnealingLR(
            optimizer, 
            T_max=config.T_max, 
            eta_min=config.eta_min
        )
    elif config.scheduler == 'ReduceLROnPlateau':
        return optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, 
            mode=config.mode
        )
    elif config.scheduler == 'CyclicLR':
        return optim.lr_scheduler.CyclicLR(
            optimizer, 
            base_lr=config.base_lr,
            max_lr=config.max_lr,
            step_size_up=config.step_size_up,
            mode=config.mode,
            gamma=config.gamma
        )
    elif config.scheduler == 'OneCycleLR':
        return optim.lr_scheduler.OneCycleLR(
            optimizer, 
            max_lr=config.max_lr,
            steps_per_epoch=config.steps_per_epoch, 
            epochs=config.epochs,
            anneal_strategy=config.anneal_strategy
        )
    elif config.scheduler == 'CosineAnnealingWarmRestarts':
        return optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, 
            T_0=config.T_0,
            T_mult=config.T_mult, 
            eta_min=config.eta_min
        )
    else:
        raise ValueError(f"Unsupported scheduler: {config.scheduler}")

훈련 - 트레이너

In [151]:
class Trainer:
    def __init__(self, model, config):
        self.model = model
        self.config = config
        self.group_name = self.config.group_name
        self.save_config()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.best_val_f1 = 0
        self.start_epoch = 0
        self.total_epochs = self.config.epochs
        self.loss_func = self.get_loss_function()
        self.optimizer = self.get_optimizer()
        self.scheduler = self.get_scheduler()
        self.scaler = GradScaler(enabled=config.use_mixed_precision)
        self.early_stopping_counter = 0
        self.best_val_loss = float('inf')
        self.metrics = CustomMetrics(num_classes=config.num_classes).to(self.device)
        
        self.use_mixed_precision = config.use_mixed_precision
        if self.use_mixed_precision:
            self.scaler = GradScaler()

    #################### 기본 세팅 ####################
    def initialize_model(self):
        return CustomModel(self.config).to(self.config.device)

    def get_loss_function(self):
        return get_loss_function(self.config)

    def get_optimizer(self):
        return get_optimizer(self.model.parameters(), self.config)

    def get_scheduler(self):
        return get_scheduler(self.optimizer, self.config)
    
    ##########################################################
    @staticmethod
    def get_short_model_name(model_name):
        return model_name.split('.')[0]

    def generate_new_group_name(self, old_group_name = False):
        if old_group_name:
            parts = old_group_name.rsplit('_', 1)
            new_suffix = wandb.util.generate_id()[:8]
            return f"{parts[0]}_{new_suffix}"
        else:
            model_short_name = self.get_short_model_name(self.config.model_name)
            return f"{model_short_name}_{self.config.optimizer}_{wandb.util.generate_id()[:8]}"
    
    ################세이브 로드 ####################
    def save_checkpoint(self, epoch, val_f1):
        checkpoint_path = f'models/{self.group_name}/checkpoint_epoch_{epoch+1}.pt'
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'config': dict(self.config),
            'val_f1_macro': val_f1,
        }, checkpoint_path)
        if val_f1 > self.best_val_f1:
            self.best_val_f1 = val_f1
            best_model_path = f'models/{self.group_name}/best_model.pt'
            torch.save({
                'epoch': epoch,
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'scheduler_state_dict': self.scheduler.state_dict(),
                'config': dict(self.config),
                'val_f1_macro': val_f1,
            }, best_model_path)

    def load_model(self, path):
        if not os.path.exists(path):
            print(f"Model file not found: {path}")
            return False
        try:
            checkpoint = torch.load(path, map_location=self.device)
            self.config.__dict__.update(checkpoint['config'])
            self.model = CustomModel(self.config).to(self.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            self.start_epoch = checkpoint['epoch'] + 1
            self.best_val_f1 = checkpoint['val_f1_macro']
            self.group_name = self.config.group_name
            print(f"Model loaded from {path}")
            print(f"Continuing training from epoch {self.start_epoch}")
            print(f"Best validation F1 score: {self.best_val_f1:.4f}")
            self.model.train()
            return True
        except Exception as e:
            print(f"Error loading model from {path}: {str(e)}")
            return False
        
    def load_best_model(self, group_name=None):
        if group_name is None:
            group_name = self.group_name
        best_model_path = os.path.join('models', group_name, 'best_model.pt')
        if not os.path.exists(best_model_path):
            print(f"No best model found for group {group_name}")
            return False
        return self.load_model(best_model_path)
    
    ######################## Config 저장 #######################
    def save_config(self):
        config_path = f'models/{self.group_name}/config.json'
        os.makedirs(os.path.dirname(config_path), exist_ok=True)
        with open(config_path, 'w') as f:
            json.dump(dict(self.config), f)

    @staticmethod
    def load_config(group_name):
        config_path = f'models/{group_name}/config.json'
        with open(config_path, 'r') as f:
            config_dict = json.load(f)
        return Config(**config_dict)
    
    #################### 훈련 스크립트 ####################
    def train_epoch_mixed_precision(self, train_loader, epoch):
        self.model.train()
        total_loss = 0
        metrics = self.initialize_metrics()
        pbar = tqdm_notebook(train_loader, desc=f"Epoch {epoch+1}/{self.total_epochs}", leave=False)
        for batch in pbar:
            x, y = batch
            with autocast():
                loss, preds = self.model.training_step(batch, self.loss_func)
            self.scaler.scale(loss).backward()
            if self.config.gradient_clip_val > 0:
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.gradient_clip_val)
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.optimizer.zero_grad()
            total_loss += loss.item()
            self.update_metrics(metrics, {"preds": preds, "targets": y})
            pbar.set_postfix({'loss': f'{loss.item():.4f}'})
        return total_loss / len(train_loader), self.calculate_metrics(metrics)

    def train_epoch_full_precision(self, train_loader, epoch):
        self.model.train()
        total_loss = 0
        metrics = self.initialize_metrics()
        pbar = tqdm_notebook(train_loader, desc=f"Epoch {epoch+1}/{self.total_epochs}", leave=False)
        for batch in pbar:
            x, y = batch
            loss, preds = self.model.training_step(batch, self.loss_func)
            loss.backward()
            if self.config.gradient_clip_val > 0:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.gradient_clip_val)
            self.optimizer.step()
            self.optimizer.zero_grad()
            total_loss += loss.item()
            self.update_metrics(metrics, {"preds": preds, "targets": y})
            pbar.set_postfix({'loss': f'{loss.item():.4f}'})
        return total_loss / len(train_loader), self.calculate_metrics(metrics)
    
    def validate(self, val_loader):
        self.model.eval()
        total_loss = 0
        metrics = self.initialize_metrics()
        val_pbar = tqdm_notebook(val_loader, desc="Validation", leave=False)
        with torch.no_grad():
            for batch in val_pbar:
                x, y = batch
                if self.config.use_mixed_precision:
                    with autocast():
                        loss, preds = self.model.validation_step(batch, self.loss_func)
                else:
                    loss, preds = self.model.validation_step(batch, self.loss_func)
                total_loss += loss.item()
                self.update_metrics(metrics, {"preds": preds, "targets": y})
                val_pbar.set_postfix({'loss': f'{loss.item():.4f}'})
        return total_loss / len(val_loader), self.calculate_metrics(metrics)
        
    def train(self, train_loader, validation_loader, is_validate = False):
        wandb.init(project=self.config.project_name, config=vars(self.config), group=self.group_name, job_type= "new_run")
        
        start_time = time.time()
        epoch_pbar = tqdm_notebook(range(self.start_epoch, self.total_epochs), desc="Epochs", leave=True)
        for epoch in epoch_pbar:
            epoch_start_time = time.time()
            # Training
            if self.use_mixed_precision:
                train_loss, train_metrics = self.train_epoch_mixed_precision(train_loader, epoch)
            else:
                train_loss, train_metrics = self.train_epoch_full_precision(train_loader, epoch)
            # Validation
            if is_validate:
                valid_loss, valid_metrics = self.validate(validation_loader)
            else:
                valid_loss, valid_metrics = 0, {'f1_macro': 0, 'lowest_class_accuracy': 0}

            epoch_time = time.time() - epoch_start_time
            self.log_metrics(epoch, train_loss, train_metrics, valid_loss, valid_metrics, is_validate)
            self.update_scheduler(train_loss)
            self.save_checkpoint(epoch, train_metrics['f1_macro'])

            if is_validate:
                print(f"Epoch {epoch+1}/{self.total_epochs} - "
                    f"Loss: {train_loss:.4f}, F1 Macro: {train_metrics['f1_macro']:.4f}, "
                    f"Lowest_class_accuracy: {train_metrics['lowest_class_accuracy']:.4f}, "
                    f"Val Loss: {valid_loss:.4f}, Val F1 Macro: {valid_metrics['f1_macro']:.4f}, "
                    f"Val Lowest_class_accuracy: {valid_metrics['lowest_class_accuracy']:.4f}, Time: {epoch_time:.2f}s")  
            else:
                print(f"Epoch {epoch+1}/{self.total_epochs} - "
                    f"Loss: {train_loss:.4f}, F1 Macro: {train_metrics['f1_macro']:.4f}, "
                    f"Lowest_class_accuracy: {train_metrics['lowest_class_accuracy']:.4f}, Time: {epoch_time:.2f}s")  
                   
            # Update epoch progress bar
            if is_validate:
                epoch_pbar.set_postfix({
                    'Train Loss': f'{train_loss:.4f}',
                    'Val Loss': f'{valid_loss:.4f}',
                    'Train F1': f'{train_metrics["f1_macro"]:.4f}',
                    'Val F1': f'{valid_metrics["f1_macro"]:.4f}',
                    'Train Lowest_class_accuracy': f"{train_metrics['lowest_class_accuracy']:.4f}",
                    'Val Lowest_class_accuracy': f"{valid_metrics['lowest_class_accuracy']:.4f}",
                })
            else:
                epoch_pbar.set_postfix({
                    'Train Loss': f'{train_loss:.4f}',
                    'Train F1': f'{train_metrics["f1_macro"]:.4f}',
                    'Train Lowest_class_accuracy': f"{train_metrics['lowest_class_accuracy']:.4f}"
                })            
            if self.early_stopping(train_loss):
                print(f"Early stopping triggered after epoch {epoch+1}")
                break

        total_time = time.time() - start_time
        print(f"Training completed. Total time: {total_time:.2f}s")
        return train_loss, train_metrics, valid_loss, valid_metrics
    
    def setup_training(self, checkpoint_path=None, continue_training=False, additional_epochs=0):
        if checkpoint_path:
            self.load_model(checkpoint_path)
        else:
            self.load_model(checkpoint_path)
            self.start_epoch = 0
            self.best_val_f1 = 0
            self.group_name = self.generate_new_group_name()

        if continue_training:
            self.group_name = self.generate_new_group_name(self.group_name)
            print(f"Continuing training with new group_name: {self.group_name}")
            self.total_epochs = self.start_epoch + additional_epochs
        else:
            self.total_epochs = self.config.epochs

        self.config.group_name = self.group_name
        self.model.train()
        self.scaler = GradScaler(enabled=self.config.use_mixed_precision)
        
        if continue_training:
            self.optimizer = self.get_optimizer()
            self.scheduler = self.get_scheduler()

    #################### 기록관련 ####################
    def initialize_metrics(self):
        return {"preds": [], "targets": []}

    def update_metrics(self, metrics, batch_metrics):
        metrics["preds"].extend(batch_metrics["preds"].cpu().numpy())
        metrics["targets"].extend(batch_metrics["targets"].cpu().numpy())

    def calculate_metrics(self, metrics):
        preds = torch.from_numpy(np.array(metrics["preds"])).to(self.device)
        targets = torch.from_numpy(np.array(metrics["targets"])).to(self.device)
        return self.metrics(preds, targets)

    def log_metrics(self, epoch, train_loss, train_metrics, valid_loss, valid_metrics, is_validate = False):
        if is_validate:
            wandb.log({
                "epoch": epoch + 1,
                "train_loss": train_loss,
                "train_f1_macro": train_metrics['f1_macro'],
                "train_lowest_class_accuracy": train_metrics['lowest_class_accuracy'],
                "valid_loss": valid_loss,
                "valid_f1_macro": valid_metrics['f1_macro'],
                "valid_lowest_class_accuracy": valid_metrics['lowest_class_accuracy'],
                "learning_rate": self.optimizer.param_groups[0]['lr']
            })
            for i, acc in enumerate(train_metrics['accuracies_per_class']):
                wandb.log({f"train_class_{i}_accuracy": acc.item()})
            for i, acc in enumerate(valid_metrics['accuracies_per_class']):
                wandb.log({f"valid_class_{i}_accuracy": acc.item()})
        else:
            wandb.log({
                "epoch": epoch + 1,
                "train_loss": train_loss,
                "train_f1_macro": train_metrics['f1_macro'],
                "train_lowest_class_accuracy": train_metrics['lowest_class_accuracy'],
                "learning_rate": self.optimizer.param_groups[0]['lr']
            })
            for i, acc in enumerate(train_metrics['accuracies_per_class']):
                wandb.log({f"train_class_{i}_accuracy": acc.item()})       

    #################### 훈련 관련 ####################
    def update_scheduler(self, val_loss):
        if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            self.scheduler.step(val_loss)
        else:
            self.scheduler.step()
    
    def early_stopping(self, val_loss):
        if val_loss < self.best_val_loss - self.config.early_stopping_delta:
            self.best_val_loss = val_loss
            self.early_stopping_counter = 0
        else:
            self.early_stopping_counter += 1
            if self.early_stopping_counter >= self.config.early_stopping_patience:
                return True
        return False


  lambda data: self._console_raw_callback("stderr", data),
Epochs:   0%|          | 1/300 [11:55<59:26:57, 715.78s/it, Train Loss=0.1032, Val Loss=0.0953, Train F1=0.0853, Val F1=0.1413, Train Lowest_class_accuracy=0.0000, Val Lowest_class_accuracy=0.0000]


# <세팅하기!> (이전까지 전부 실행 후 순차적 실행)

## 1. 기본세팅

1. 모델 고르시오 tiny, mobilenetv4 GMACs, GFLOPS따져서

In [None]:
#timm.list_models('cafo*',pretrained=True)
timm.list_models('tiny*',pretrained=True)

2. 기본설정 - Config

In [134]:
# 경로 및 기본설정
base_path = 'data_orig/'
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# FineTuning 대상 변수 미리 꺼내놓기
num_classes = 17
learning_rate = 1e-3
optimizer = 'Ranger' 
img_size = 1024
batch_size = 8
gradient_clip_val = 1 #0이면 클리핑 안함
use_mixed_precision = True

# 가중치 설정
focal_alpha = [1]*17
focal_alpha[3] = 3
focal_alpha[4] = 3
focal_alpha[7] = 4
focal_alpha[14] = 3

model_name = "tiny_vit_5m_224.dist_in22k_ft_in1k"
#model_name = "caformer_s36.sail_in22k_ft_in1k"
#model_name = "caformer_m36.sail_in22k_ft_in1k"#
#model_name = "tiny_vit_11m_224.dist_in22k_ft_in1k"

PROJECT_NAME = "k-fold-cross-validation-pytorch"
def get_short_model_name(model_name):
    return model_name.split('.')[0]
model_short_name = get_short_model_name(model_name)
GROUP_NAME = f"{model_short_name}_{optimizer}_{wandb.util.generate_id()[:8]}"
JOB_TYPE = 'train'

In [125]:
# 새버전 Config k-fold, manual warmup 삭제
config_manual = {
    # 기본 설정
    "seed": 42,
    "device": device,
    "num_workers": 0,
    "group_name": GROUP_NAME,
    "base_path": base_path,
    "project_name": PROJECT_NAME,

    # 데이터 관련 설정
    "batch_size": batch_size,
    "img_size": img_size,
    "num_classes": num_classes,

    # 모델 관련 설정
    "model_name": model_name,  # 예시
    "pretrained": True,

    # 훈련 관련 설정
    "epochs": 300,
    "learning_rate": learning_rate,
    "weight_decay": 1e-5,

    # 손실 함수 관련 설정
    "loss": "smoothfocal",  # 'ce', 'focal', 'smoothce', 'smoothfocal'
    "focal_gamma": 2,
    "smoothing": 0.1,
    "focal_alpha": focal_alpha,

    # 데이터 증강 관련 설정
    "mixup_alpha": 1.0,
    "cutmix_alpha": 1.0,
    "mixup_prob": 0.2,
    "cutmix_prob": 0,

    # 옵티마이저 설정
    "optimizer": optimizer,  # 'SGD', 'Adam', 'AdamW', 'RMSprop', 'Adadelta', 'Ranger'
    "momentum": 0.9,

    # Adam과 AdamW 특정 설정
    "adam_beta1": 0.9,
    "adam_beta2": 0.999,
    "adam_epsilon": 1e-8,

    # RMSprop 특정 설정
    "rmsprop_alpha": 0.99,
    "rmsprop_epsilon": 1e-8,

    # Adadelta 특정 설정
    "adadelta_rho": 0.9,
    "adadelta_epsilon": 1e-6,

    # Ranger 특정 설정
    "ranger_alpha": 0.5,
    "ranger_k": 6,
    "ranger_N_sma_threshhold": 5,
    "ranger_beta1": 0.95,
    "ranger_beta2": 0.999,
    "ranger_eps": 1e-5,
    "ranger_use_gc": True,
    "ranger_gc_conv_only": False,

    # 스케줄러 설정
    "scheduler": "CosineAnnealingWarmRestarts",  # 'LambdaLR', 'MultiplicativeLR', 'StepLR', 'CosineAnnealingLR', 
                                        #'ReduceLROnPlateau', 'CyclicLR', 'OneCycleLR', 'CosineAnnealingWarmRestarts'
    "verbose": False,

    # LambdaLR 설정
    "lambda_factor": 0.95,
    "last_epoch": -1,

    # StepLR 설정
    "step_size": 10,
    "gamma": 0.5,

    # ReduceLROnPlateau 설정
    "mode": "min",

    # CyclicLR 설정
    "base_lr": 0.00005,
    "max_lr": 0.0001,
    "step_size_up": 5,

    # OneCycleLR 설정
    "steps_per_epoch": 10,
    "anneal_strategy": "linear",

    # CosineAnnealingLR 설정
    "T_max": 50,
    "eta_min": 0,

    # CosineAnnealingWarmRestarts 설정
    "T_0": 10,
    "T_mult": 1,

    # 추가 훈련 기법 설정
    "early_stopping_patience": 100,
    "early_stopping_delta": 0.001,
    "use_mixed_precision": use_mixed_precision,
    "gradient_clip_val": gradient_clip_val, #0이면 안함
}
config = Config(**config_manual)

3. wandb 로긴

In [73]:
# Use wandb-core, temporary for wandb's new backend  
wandb.require("core")
wandb.login()



True

## 2. 데이터 셋 준비 

### 전체 데이터 훈련 및 외부 데이터 검증

In [81]:
# 데이터셋 준비
train_ds, test_dataset = create_data_set(base_path, config.img_size)

# DataLoader 정의
train_loader = DataLoader(
    train_ds,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers,
    pin_memory=True,
    drop_last=False
)
test_loader = DataLoader(
    test_dataset,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=config.num_workers,
    pin_memory=True
)

# 데이터 로더를 GPU로 이동
train_loader = DeviceDataLoader(train_loader, device)
test_loader = DeviceDataLoader(test_loader, device)

Training dataset size: 100480
Test dataset size: 3140


### 트레이닝 데이터 일부로 검증

In [80]:
# 데이터셋 준비
train_ds, test_dataset = create_data_set(base_path, config.img_size)
train_size = int(0.9 * len(train_ds))
val_size = len(train_ds) - train_size
train_dataset, validation_dataset = random_split(train_ds, [train_size, val_size])
# DataLoader 정의
train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers,
    pin_memory=True,
    drop_last=False
)
validation_loader = DataLoader(
    validation_dataset,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=config.num_workers,
    pin_memory=True
)
test_loader = DataLoader(
    test_dataset,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=config.num_workers,
    pin_memory=True
)

# 데이터 로더를 GPU로 이동
train_loader = DeviceDataLoader(train_loader, device)
validation_loader = DeviceDataLoader(validation_loader, device)
test_loader = DeviceDataLoader(test_loader, device)

Training dataset size: 100480
Test dataset size: 3140


### 디버깅 용도로 미니셋 만들기

In [129]:
print(int(0.01 * len(train_ds)))
train_ds, test_dataset = create_data_set(base_path, config.img_size)
train_size = int(0.01 * len(train_ds))
val_size = int(0.01 * len(train_ds))
rem_size = len(train_ds) - train_size - val_size
train_dataset, validation_dataset, rem_dataset = random_split(train_ds, [train_size, val_size, rem_size])
# DataLoader 정의
train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers,
    pin_memory=True,
    drop_last=False
)
val_loader = DataLoader(
    validation_dataset,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=config.num_workers,
    pin_memory=True
)
test_loader = DataLoader(
    test_dataset,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=config.num_workers,
    pin_memory=True
)
# 데이터 로더를 GPU로 이동
train_loader = DeviceDataLoader(train_loader, device)
validation_loader = DeviceDataLoader(validation_loader, device)
test_loader = DeviceDataLoader(test_loader, device)

1004
Training dataset size: 100480
Test dataset size: 3140


## 3. 훈련 세팅

1. 손실함수, 옵티마이저, 스케줄러 확인 -> 실행 필요 없음

In [84]:
loss_func = get_loss_function(config)
optimizer_choice = get_optimizer(model.parameters(), config)
scheduler = get_scheduler(optimizer_choice, config)
loss_func, optimizer_choice, scheduler

Ranger optimizer loaded. 
Gradient Centralization usage = True
GC applied to both conv and fc layers


(SmoothFocalLoss(),
 Ranger (
 Parameter Group 0
     N_sma_threshhold: 5
     alpha: 0.5
     betas: (0.95, 0.999)
     eps: 1e-05
     initial_lr: 0.001
     k: 6
     lr: 0.001
     step_counter: 0
     weight_decay: 1e-05
 ),
 <torch.optim.lr_scheduler.CosineAnnealingWarmRestarts at 0x1a55683a470>)

2. 모델 준비

In [141]:
# 모델 준비 -> 필수!!!!
model = CustomModel(config).to(config.device)
# 백본 확인
print(config.model_name)

INFO:timm.models._builder:Loading pretrained weights from Hugging Face hub (timm/tiny_vit_5m_224.dist_in22k_ft_in1k)
INFO:timm.models._hub:[timm/tiny_vit_5m_224.dist_in22k_ft_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.


tiny_vit_5m_224.dist_in22k_ft_in1k


3. 훈련 준비 - 메모리 비워내기

In [142]:
torch.cuda.empty_cache()
print_gpu_memory()

Allocated: 6.21 GB
Cached: 6.45 GB


# <훈련>

In [152]:
# 일반 훈련
trainer = Trainer(model, config)
trainer.train(train_loader, test_loader, is_validate = True) #validate 안하려면 False로

Ranger optimizer loaded. 
Gradient Centralization usage = True
GC applied to both conv and fc layers


VBox(children=(Label(value='0.013 MB of 0.013 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁█
learning_rate,█▁
train_class_0_accuracy,▁█
train_class_10_accuracy,▁█
train_class_11_accuracy,█▁
train_class_12_accuracy,▁█
train_class_13_accuracy,▁▁
train_class_14_accuracy,▁▁
train_class_15_accuracy,▁█
train_class_16_accuracy,█▁

0,1
epoch,2.0
learning_rate,0.00098
train_class_0_accuracy,0.43939
train_class_10_accuracy,0.06849
train_class_11_accuracy,0.0
train_class_12_accuracy,0.01562
train_class_13_accuracy,0.0
train_class_14_accuracy,0.0
train_class_15_accuracy,0.09722
train_class_16_accuracy,0.65574


Epochs:   0%|          | 0/300 [00:00<?, ?it/s]

Epoch 1/300:   0%|          | 0/126 [00:00<?, ?it/s]

Validation:   0%|          | 0/393 [00:00<?, ?it/s]

Epoch 1/300 - Loss: 0.0577, F1 Macro: 0.3706, Lowest_class_accuracy: 0.0000, Val Loss: 0.0451, Val F1 Macro: 0.4652, Val Lowest_class_accuracy: 0.0000, Time: 145.79s


Epoch 2/300:   0%|          | 0/126 [00:00<?, ?it/s]

Validation:   0%|          | 0/393 [00:00<?, ?it/s]

Epoch 2/300 - Loss: 0.0508, F1 Macro: 0.4324, Lowest_class_accuracy: 0.0000, Val Loss: 0.0448, Val F1 Macro: 0.5269, Val Lowest_class_accuracy: 0.0000, Time: 153.79s


Epoch 3/300:   0%|          | 0/126 [00:00<?, ?it/s]

Validation:   0%|          | 0/393 [00:00<?, ?it/s]

Epoch 3/300 - Loss: 0.0461, F1 Macro: 0.4451, Lowest_class_accuracy: 0.0000, Val Loss: 0.0435, Val F1 Macro: 0.5308, Val Lowest_class_accuracy: 0.0000, Time: 143.28s


Epoch 4/300:   0%|          | 0/126 [00:00<?, ?it/s]

Validation:   0%|          | 0/393 [00:00<?, ?it/s]

Epoch 4/300 - Loss: 0.0386, F1 Macro: 0.5384, Lowest_class_accuracy: 0.0000, Val Loss: 0.0372, Val F1 Macro: 0.5843, Val Lowest_class_accuracy: 0.0000, Time: 155.84s


Epoch 5/300:   0%|          | 0/126 [00:00<?, ?it/s]

Validation:   0%|          | 0/393 [00:00<?, ?it/s]

Epoch 5/300 - Loss: 0.0343, F1 Macro: 0.5867, Lowest_class_accuracy: 0.0000, Val Loss: 0.0330, Val F1 Macro: 0.6242, Val Lowest_class_accuracy: 0.0000, Time: 152.16s


Epoch 6/300:   0%|          | 0/126 [00:00<?, ?it/s]

Validation:   0%|          | 0/393 [00:00<?, ?it/s]

Epoch 6/300 - Loss: 0.0326, F1 Macro: 0.6261, Lowest_class_accuracy: 0.0000, Val Loss: 0.0301, Val F1 Macro: 0.6579, Val Lowest_class_accuracy: 0.0000, Time: 132.10s


Epoch 7/300:   0%|          | 0/126 [00:00<?, ?it/s]

KeyboardInterrupt: 

# 체크포인트, 연속 학습

체크 포인트 경로 (상대경로)

In [None]:
checkpoint_path = 'models/tiny_vit_11m_224_Ranger_hzeqx347/checkpoint_epoch_36.pt'

처음부터 트레이닝

In [None]:
trainer = Trainer(model, config) #아무거나 불러와도 됨
trainer.setup_training(checkpoint_path= checkpoint_path, continue_training=False)
trainer.train(train_loader, validation_loader)

이어서 트레이닝 (옵티마이저, 스케줄러 불러옴)

In [None]:
# 체크포인트에서 훈련 계속 (10 에폭 추가)
trainer = Trainer(model, config)
trainer.setup_training(checkpoint_path= checkpoint_path, continue_training=True, additional_epochs=10)
trainer.train(train_loader, validation_loader)

# 모델 분석 및 결과제출

모델 분석 및 결과제출 - 특정 에포크 모델 불러오기
- 저장 위치로 불러온다

In [None]:
# 모델 저장 위치
model_path = 'models/caformer_m36_Ranger_1zo3gt64/checkpoint_epoch_10.pt'

In [80]:
# 분석 예시
trainer = Trainer(model, config) #아무거나 사용해도 됨 , 사용한 모델과 컨피그는 아무거나
trainer.load_model(model_path)
tta = TTA(trainer.model, trainer.config)

#test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
#val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# 테스트 데이터에 대한 예측 및 제출 파일 생성
tta.automatic_tta_submission(test_loader, config)

Ranger optimizer loaded. 
Gradient Centralization usage = True
GC applied to both conv and fc layers


Unexpected keys (head.fc.fc1.bias, head.fc.fc1.weight, head.fc.norm.bias, head.fc.norm.weight) found while loading pretrained weights. This may be expected if model is being adapted.


set state called
Model loaded from models/caformer_m36_Ranger_1zo3gt64/checkpoint_epoch_10.pt
Continuing training from epoch 10
Best validation F1 score: 0.9992


Generating predictions: 100%|██████████| 785/785 [05:36<00:00,  2.33it/s]


Saved predictions for method: no_tta
Saved label probabilities for method: no_tta
Saved predictions for method: mean
Saved label probabilities for method: mean
Saved predictions for method: max
Saved label probabilities for method: max
Saved predictions for method: temp_sharpen
Saved label probabilities for method: temp_sharpen
Saved predictions for method: mode
Saved label probabilities for method: mode
Saved predictions for method: modethreshold
Saved label probabilities for method: modethreshold
Saved predictions for method: ensemble


모델 분석 및 결과제출 - 베스트 모델 불러오기
- 그룹 이름으로 불러온다

In [None]:
# 모델 이름 (폴더명임)
model_name_anal = 'tiny_vit_11m_224_Ranger_4j1cx18a'

In [None]:
trainer = Trainer(model, config) #아무거나 사용해도 됨 , 사용한 모델과 컨피그는 아무거나
trainer.load_best_model(group_name=model_name_anal)
tta = TTA(trainer.model, trainer.config)

#test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
#val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# 테스트 데이터에 대한 예측 및 제출 파일 생성
tta.automatic_tta_submission(test_loader, config)

In [82]:
# 검증 데이터에 대한 TTA 분석
disagreement_dataset = tta.analyze_tta(train_loader)

Generating predictions:  16%|█▋        | 4093/25120 [29:23<2:27:46,  2.37it/s]

Focal Alpha 업데이트

In [None]:
# Focal Weight 업데이트
error_rates = [0.115, 0.115, 0.106, 0.021, 0.018, 0.014, 0.011, 0.011, 0.004, 0.004, 0.004, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
max_error = max(error_rates)
alphas = [min(1.0, (rate / max_error) * 2) for rate in error_rates]
alphas = [max(0.1, alpha) for alpha in alphas]

# 추가 구축 필요

3,4,7,14
오분류한 데이터들 모으기 -> 증강 -> 학습
(기존모델 fine tune)

SIMCLR -> 자주 틀리는 라벨에 한해서 구현하면 될듯 하다. 그리고 그 라벨로 예측한 것에 2차분류기를 simclr로 설정

Knowledge Distillation 크고 좋은 모델을 날잡고 훈련해서 knowledge distillation을 시도해봐도 좋을듯 함

### 2차분류기

Progressive Resizing

Adversarial Training (성능 약화 예상)

Convolutional Block Attention Module (CBAM) -> 해볼 가치 충분. CNN 기반에 잘 돌아감

Squeeze-and-Excitation (SE) 블록