# import

In [1]:
import os
import random

import timm
import wandb
import torch
import albumentations as A
import pandas as pd
import numpy as np
import torch.nn.functional as F
from augraphy import *
from pytorch_lightning import LightningModule, Trainer, LightningDataModule
from pytorch_lightning.loggers import WandbLogger
from albumentations.core.transforms_interface import ImageOnlyTransform
from albumentations.pytorch import ToTensorV2
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from sklearn.metrics import accuracy_score, f1_score

# Configs

In [2]:
# 시드를 고정합니다.
SEED = 42
os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.benchmark = True

# device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# data config
data_path = '../data/'

# model config
model_list = {
    1 : 'vit_large_patch14_clip_224.openai_ft_in12k',
    2 : 'vit_base_patch16_clip_224.laion2b_ft_in1k',
    3 : 'vit_pe_core_base_patch16_224.fb', # 모델없음.
    4 : 'resnet152',
    6 : 'vit_small_patch16_224',
    5 : 'vit_base_patch16_224',
    7 : 'convnext_base.fb_in22k_ft_in1k',
    8 : 'vit_large_patch16_224',
    9 : 'convnextv2_huge.fcmae_ft_in1k'	, # Out of Memory
    10 : 'convnext_large.fb_in22k_ft_in1k', 
    11 : "convnextv2_base.fcmae_ft_in1k",
    12 : 'convnext_base.fb_in22k_ft_in1k_384',
    13 : 'vit_huge_patch14_224'	
}

model_family = {"resnet" : [model_list[1],
                            model_list[2],
                            model_list[3],
                            model_list[4],
                            model_list[7],
                            model_list[9],
                            model_list[10],
                            model_list[11],
                            model_list[12],],
                "vit" : [model_list[6],
                         model_list[5],
                         model_list[8],
                         model_list[13]]
                            }

num_classes=17

# training config
CFS={"MODEL" : model_list[2],
    "IMG_SIZE" : 224,
     "LR" : 1e-5,
    'EPOCHS' : 200,
    'BATCH_SIZE' : 32,
    "NUM_WORKERS" : 16,
    "ALPHA" : 0.2, #0.1 ~0.7
}

# wandb logging
wandb.finish()
wandb_logger = WandbLogger(
    project="contrastive-learning",
    name=f"{CFS['MODEL']},{CFS['BATCH_SIZE']},{CFS['EPOCHS']},{CFS['LR']}",
    config=CFS,
)

# Augraphy

In [3]:
ink_phase = [
    InkBleed(
        intensity_range=(0.5, 0.6),
        kernel_size=random.choice([(5, 5), (3, 3)]),
        severity=(0.2, 0.4),
        p=0.1,
    ),
    BleedThrough(
        intensity_range=(0.1, 0.3),
        color_range=(32, 224),
        ksize=(17, 17),
        sigmaX=1,
        alpha=random.uniform(0.1, 0.2),
        offsets=(10, 20),
        p=0.1,
    ),
],

paper_phase = [
    ColorPaper(
        hue_range=(0, 255),
        saturation_range=(10, 40),
        p=0.33,
    ),
    OneOf(
        [
        DelaunayTessellation(
            n_points_range=(500, 800),
            n_horizontal_points_range=(500, 800),
            n_vertical_points_range=(500, 800),
            noise_type="random",
            color_list="default",
            color_list_alternate="default",
            ),
        PatternGenerator(
            imgx=random.randint(256, 512),
            imgy=random.randint(256, 512),
            n_rotation_range=(10, 15),
            color="random",
            alpha_range=(0.25, 0.5),
            ),
        VoronoiTessellation(
            mult_range=(50, 80),
            seed=19829813472,
            num_cells_range=(500, 1000),
            noise_type="random",
            background_value=(200, 255),
            ),
        ],
        p=1.0,
    ),
    AugmentationSequence(
        [
            NoiseTexturize(
                sigma_range=(3, 10),
                turbulence_range=(2, 5),
            ),
            BrightnessTexturize(
                texturize_range=(0.9, 0.99),
                deviation=0.03,
            ),
        ],
    ),
]

post_phase = [
    OneOf(
        [
            DirtyDrum(
                line_width_range=(1, 6),
                line_concentration=random.uniform(0.05, 0.15),
                direction=random.randint(0, 2),
                noise_intensity=random.uniform(0.6, 0.95),
                noise_value=(64, 224),
                ksize=random.choice([(3, 3), (5, 5), (7, 7)]),
                sigmaX=0,
                p=0.2,
            ),
            DirtyRollers(
                line_width_range=(2, 32),
                scanline_type=0,
            ),
        ],
        p=0.33,
    ),
    Folding(
        fold_count=10,
        fold_noise=0.0,
        fold_angle_range = (-360,360),
        gradient_width=(0.1, 0.2),
        gradient_height=(0.01, 0.1),
        backdrop_color = (0,0,0),
        p=0.33
    ),
    SubtleNoise(
        subtle_range=random.randint(5, 10),
        p=0.33,
    ),
    Jpeg(
        quality_range=(25, 95),
        p=0.33,
    ),
    Moire(
        moire_density = (15,20),
        moire_blend_method = "normal",
        moire_blend_alpha = 0.1,
        p=0.33
    ),
    ColorShift(
        color_shift_offset_x_range=(3, 5),
        color_shift_offset_y_range=(3, 5),
        color_shift_iterations=(2, 3),
        color_shift_brightness_range=(0.9, 1.1),
        color_shift_gaussian_kernel_range=(3, 3),
        p=0.33
    ),
    Scribbles(
        scribbles_type="random",
        scribbles_location="random",
        scribbles_size_range=(250, 600),
        scribbles_count_range=(1, 6),
        scribbles_thickness_range=(1, 3),
        scribbles_brightness_change=[32, 64, 128],
        scribbles_text="random",
        scribbles_text_font="random",
        scribbles_text_rotate_range=(0, 360),
        scribbles_lines_stroke_count_range=(1, 6),
        p=0.1,
    ),
    BadPhotoCopy(
        noise_type=-1,
        noise_side="random",
        noise_iteration=(1, 2),
        noise_size=(1, 3),
        noise_value=(128, 196),
        noise_sparsity=(0.3, 0.6),
        noise_concentration=(0.1, 0.6),
        blur_noise=random.choice([True, False]),
        blur_noise_kernel=random.choice([(3, 3), (5, 5), (7, 7)]),
        wave_pattern=random.choice([True, False]),
        edge_effect=random.choice([True, False]),
        p=0.33,
    ),
    Faxify(
        scale_range=(0.3, 0.6),
        monochrome=random.choice([0, 1]),
        monochrome_method="random",
        monochrome_arguments={},
        halftone=random.choice([0, 1]),
        invert=1,
        half_kernel_size=random.choice([(1, 1), (2, 2)]),
        angle=(0, 360),
        sigma=(1, 3),
        p=0.1,
    ),
    Geometric(
        scale=(0.5, 1.5),
        translation=(50, -50),
        fliplr=1,
        flipud=1,
        crop=(),
        rotate_range=(3, 5),
        p=0.33,
    ),

]

pipeline = AugraphyPipeline(ink_phase=ink_phase, paper_phase=paper_phase, post_phase=post_phase)

class AugraphyTransform(ImageOnlyTransform):
    def __init__(self, augraphy_pipeline, always_apply=False, p=0.5):
        super().__init__(always_apply, p)
        self.augraphy_pipeline = augraphy_pipeline

    def apply(self, img, **params):
        # NumPy → PIL 변환
        pil_img = Image.fromarray(img)
        # Augraphy 증강 적용
        aug_img = self.augraphy_pipeline(pil_img)
        # PIL → NumPy 변환
        return np.array(aug_img)

Augraphy = AugraphyTransform(augraphy_pipeline=pipeline, p=0.5)

wandb_logger.experiment.config["Augrapy"] = str(pipeline)


[34m[1mwandb[0m: Currently logged in as: [33mhoppure[0m ([33mhoppure-[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


# Transform

In [14]:
if CFS['MODEL'] in model_family['resnet']:
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]
else:
    norm_mean = [0.5, 0.5, 0.5]
    norm_std = [0.5, 0.5, 0.5]
    
# augmentation을 위한 transform 코드
trn_transform = A.Compose([
    # 0. augraphy
    Augraphy,
    
    # 1. 기하학적 변환 (Geometric Transformations)
    A.OneOf([
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.15, rotate_limit=15, p=0.5),
        A.OpticalDistortion(distort_limit=0.2, p=0.5),
        A.GridDistortion(num_steps=5, distort_limit=0.3, p=0.5),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5)
    ], p=1.0),
    
    # 2. 공간적 변형 (Spatial Transformations)
    A.OneOf([
        A.RandomCrop(height=int(CFS["IMG_SIZE"]*0.9), width=int(CFS["IMG_SIZE"]*0.9), p=0.7),
        A.RandomResizedCrop(size=(CFS["IMG_SIZE"], CFS["IMG_SIZE"]), scale=(0.8, 1.0), p=0.3),
        A.Transpose(p=0.3), 
        A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=0.2),
    ], p=1.0),
    
    # 3. 색상 변환 (Color Transformations)
    A.OneOf([
        A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5),
        A.RandomGamma(gamma_limit=(80, 120), p=0.3),
        A.CLAHE(clip_limit=4.0, p=0.2),
    ], p=1.0),
    
    # 4. 노이즈 및 블러 (Noise & Blur)
    A.OneOf([
        A.GaussNoise(var_limit=(10.0, 50.0), mean=0.0, per_channel=True, p=0.4),
        A.GaussianBlur(blur_limit=(3, 7), p=0.3),
        A.MotionBlur(blur_limit=7, p=0.3),
    ], p=1.0),
    
    # 5. 고급 증강 기법 (Advanced Augmentations)
    A.OneOf([
        A.CoarseDropout(max_holes=8, max_height=16, max_width=16, fill_value=0, p=0.5), # cutout
        A.RandomSunFlare(src_radius=100, p=0.1),
        A.RandomShadow(num_shadows_lower=1, num_shadows_upper=3, p=0.2)
    ], p=1.0),
    
    # 6. 최종 전처리
    A.Resize(CFS["IMG_SIZE"], CFS['IMG_SIZE']),
    A.Normalize(mean=norm_mean, std=norm_std),
    ToTensorV2()
])

# test image 변환을 위한 transform 코드
tst_transform = A.Compose([
    A.Resize(CFS["IMG_SIZE"], CFS['IMG_SIZE']),
    A.Normalize(mean=norm_mean, std=norm_std),
    ToTensorV2(),
])

# WandB에 로깅
wandb_logger.experiment.config["train_transform"] = str(trn_transform)
wandb_logger.experiment.config["test_transform"] = str(tst_transform)


# Dataset

In [None]:
# 데이터셋 클래스를 정의합니다.
class ImageDataset(Dataset):
    def __init__(self, csv, path, transform=None, transform1=None, transform2=None):
        self.df = pd.read_csv(csv).values
        self.path = path
        self.transform = transform
        self.transform1 = transform1
        self.transform2 = transform2

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        name, target = self.df[idx]
        if target == 3 or 7: #
            pass #
        else : #
            target = 0 #
        img = np.array(Image.open(os.path.join(self.path, name)))
        if self.transform:
            img = self.transform(image=img)['image']
            return img, target
        elif self.transform1 and self.transform2:
            img1 = self.transform1(image=img)['image']
            img2 = self.transform2(image=img)['image']
            return img1, img2, target
        else:
            raise ValueError("No valid transform provided.")
        

# datamodule

In [16]:
class DataModule(LightningDataModule):
    def __init__(self, data_path, train_transform, test_transform, batch_size, num_workers):
        super().__init__()
        self.data_path = data_path
        self.train_transform = train_transform
        self.test_transform = test_transform
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):
        if stage == "fit" or stage is None:
            self.train_dataset = ImageDataset(
                csv=os.path.join(self.data_path, "train.csv"),
                path=os.path.join(self.data_path, "train"),
                transform1=self.train_transform,
                transform2=self.train_transform,
            )
            
        if stage == "test" or stage == "predict" or stage is None:
            self.test_dataset = ImageDataset(
                csv=os.path.join(self.data_path, "sample_submission.csv"),
                path=os.path.join(self.data_path, "test"),
                transform=self.test_transform
            )

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=False
        )
    
    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=0,
            pin_memory=True
        )
    
    def predict_dataloader(self):
        return self.test_dataloader()
    
datamodule = DataModule(data_path='../data/',
    train_transform=trn_transform,
    test_transform=tst_transform,
    batch_size=CFS['BATCH_SIZE'],
    num_workers=CFS['NUM_WORKERS']
)

# Constrastive Loss

In [17]:
class SupConLoss(torch.nn.Module):
    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature

    def forward(self, features, labels=None):
        # labels: (batch_size)
        device = features.device
        batch_size = features.shape[0] // 2
        features = features.view(batch_size, 2, -1)  # (B, n_views, feat_dim)
        features = F.normalize(features, dim=2)
        
        if labels is not None:
            labels = labels.contiguous().view(-1, 1)
            mask = torch.eq(labels, labels.T).float().to(device)  # (B, B)
        else:
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)

        contrast_count = features.shape[1]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)  # (B*n_views, feat_dim)
        anchor_feature = contrast_feature
        anchor_count = contrast_count

        # Compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature
        )

        # For numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)

        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask


        # Compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-12)

        # Compute mean of log-likelihood over positive
        mean_log_prob_pos = (mask * log_prob).sum(1) / (mask.sum(1) + 1e-12)

        # Loss
        loss = -mean_log_prob_pos
        loss = loss.mean()
        return loss

        # ... (SupCon 공식 구현 참고)
        # https://github.com/HobbitLong/SupContrast/blob/master/losses.py
        # (여기서는 간략화, 실제 구현은 위 링크 참고)

def knn_accuracy(embeddings, labels, k=1):
    embeddings = F.normalize(embeddings, dim=1)
    sim_matrix = torch.matmul(embeddings, embeddings.T)
    # 자기 자신 제외
    sim_matrix.fill_diagonal_(-float('inf'))
    topk = sim_matrix.topk(k, dim=1).indices
    preds = labels[topk]
    # (N, k) 중 가장 많이 나온 라벨로 예측
    preds = torch.mode(preds, dim=1).values
    acc = (preds == labels).float().mean().item()
    return acc


# Model

In [18]:
class LightningModel(LightningModule):
    def __init__(self, model_name, num_classes, lr, feat_dim=128, alpha=0.1):
        super().__init__()
        self.model = timm.create_model(
            model_name=model_name, 
            pretrained=True, 
            num_classes=num_classes
        )
        in_features = self.model.get_classifier().in_features
        self.model.reset_classifier(0)  # 분류 헤드 제거 (backbone만 남음)
        self.proj_head = torch.nn.Sequential(
            torch.nn.Linear(in_features, in_features),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features, feat_dim)
        )
        self.classifier = torch.nn.Linear(in_features, num_classes)  # 분류 헤드 따로 추가
        self.lr = lr
        self.num_classes = num_classes
        self.contrastive_loss = SupConLoss()
        self.alpha = alpha
        self.train_embeddings = []
        self.train_targets = []
        self.train_losses = []

    def forward(self, x):
        feat = self.model.forward_features(x)  # ViT 등에서는 forward_features 사용
        z = self.proj_head(feat)
        z = F.normalize(z, dim=1)
        return z

    def training_step(self, batch, batch_idx):
        img1, img2, labels = batch  # DataLoader가 이렇게 반환해야 함
        imgs = torch.cat([img1, img2], dim=0)  # (2*B, C, H, W)
        features = self(imgs)
        
        con_loss = self.contrastive_loss(features, labels)

        # 분류용 (img1만 사용)
        feat1 = self.model.forward_features(img1)
        logits = self.classifier(feat1)
        logits = logits[:, 0, :]

        ce_loss = F.cross_entropy(logits, labels)
        loss = ce_loss + self.alpha * con_loss

        self.log('train_loss_step', loss, prog_bar=True)

        with torch.no_grad():
            emb = features[:labels.size(0), 0, :]  # img1의 CLS 토큰만 추출 (B, feat_dim)
            self.train_embeddings.append(emb.detach().cpu())
            self.train_targets.append(labels.detach().cpu())
            self.train_losses.append(loss.detach().cpu())
        return loss

    def on_train_epoch_end(self):
        if self.train_embeddings:
            all_embeddings = torch.cat(self.train_embeddings, dim=0)
            all_targets = torch.cat(self.train_targets, dim=0)
            acc = knn_accuracy(all_embeddings, all_targets, k=1)
            epoch_loss = torch.stack(self.train_losses).mean()
            
            self.log('train_loss', epoch_loss, prog_bar=True)
            self.log('train_knn_acc', acc, prog_bar=True)
            
            self.train_embeddings.clear()
            self.train_targets.clear()
            self.train_losses.clear()

    def predict_step(self, batch, batch_idx):
        img, _ = batch  # batch 구조가 (img1, img2, label)일 때만
        feature = self.model.forward_features(img)
        logits = self.classifier(feature)
        logits = logits[:, 0, :]
        return logits.argmax(dim=1)

    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=self.lr)
        scheduler = StepLR(optimizer, step_size=45, gamma=0.5)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",
                "frequency": 1
            }
        }
lightning_model = LightningModel(CFS['MODEL'], num_classes, CFS["LR"], alpha=CFS['ALPHA'])

# Training

In [19]:
# 트레이너 설정
trainer = Trainer(
    max_epochs=CFS["EPOCHS"],
    accelerator='cuda' if torch.cuda.is_available() else 'cpu',
    devices="auto",
    logger=wandb_logger,
)

# 학습 실행
trainer.fit(
    model=lightning_model,
    datamodule=datamodule
)

wandb.finish()


Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type              | Params | Mode 
---------------------------------------------------------------
0 | model            | VisionTransformer | 85.8 M | train
1 | proj_head        | Sequential        | 689 K  | train
2 | classifier       | Linear            | 13.1 K | train
3 | contrastive_loss | SupConLoss        | 0      | train
---------------------------------------------------------------
86.5 M    Trainable params
0         Non-trainable params
86.5 M    Total params
346.006   Total estimated model params size (MB)
270       Modules in train mode
0         Modules in eval mode


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=200` reached.


0,1
epoch,▁▁▁▁▁▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇█████
train_knn_acc,▁▃▃▄▄▅▅▅▆▅▆▆▆▆▇▆▇▇▇▇▇▇▇▇▇▇▇█▇▇▇██▇██████
train_loss,█▅▄▅▄▄▃▃▃▂▂▂▂▃▂▂▂▂▂▂▂▂▁▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss_step,▅▂▂▄▂▂▄▃▁▁▁▁▂▁▁▁▅▁▁▁▁▁▁▁▁▃▁▁▁▂▁▁▁█▁▁▁▁▁▃
trainer/global_step,▁▁▁▁▁▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇████

0,1
epoch,199.0
train_knn_acc,0.92803
train_loss,0.87498
train_loss_step,0.17285
trainer/global_step,9999.0


# Prediction

In [20]:
predictions = trainer.predict(
    model=lightning_model, 
    datamodule=datamodule
)

# 4. 결과 처리
all_preds = torch.cat(predictions).cpu().numpy()  # [n_samples]
# 샘플 제출 파일 로드
submission = pd.read_csv(os.path.join(data_path, "sample_submission.csv"))
# 예측값으로 타겟 열 업데이트
submission["target"] = all_preds


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: |          | 0/? [00:00<?, ?it/s]

In [11]:
# 결과 저장
submission.to_csv("submission.csv", index=False)