In [13]:
import os

directory_path = './data/samples'

def get_file_names(directory):
    file_names = []
    for filename in os.listdir(directory):
        file_names.append(filename)
    return file_names

# 파일명 리스트 얻기
file_names = get_file_names(directory_path)

characters = set()
captcha_length = []
dataset = []

for img_path in file_names:
    label = img_path.split(".")[0]
    captcha_length.append(len(label))
    dataset.append((str(img_path), label))
    for ch in label:
        characters.add(ch)

import pandas as pd
characters = sorted(characters)
dataset = pd.DataFrame(dataset, columns=["img_path", "label"], index=None)
dataset = dataset.sample(frac=1.).reset_index(drop=True)
dataset.head()

Unnamed: 0,img_path,label
0,mcg43.png,mcg43
1,c4bgd.png,c4bgd
2,cdfen.png,cdfen
3,ne325.png,ne325
4,bpwd7.png,bpwd7


In [14]:
characters

['2',
 '3',
 '4',
 '5',
 '6',
 '7',
 '8',
 'b',
 'c',
 'd',
 'e',
 'f',
 'g',
 'm',
 'n',
 'p',
 'w',
 'x',
 'y']

In [17]:
from sklearn.model_selection import train_test_split
seed = 123
training_data, validation_data = train_test_split(dataset, test_size=0.2, random_state=seed)

training_data = training_data.reset_index(drop=True)
validation_data = validation_data.reset_index(drop=True)

char_to_labels = {char:idx for idx, char in enumerate(characters)}
labels_to_char = {val:key for key, val in char_to_labels.items()}

In [21]:
import cv2
import numpy as np

def generate_arrays(df, resize=True, img_height=50, img_width=200):    
    num_items = len(df)
    images = np.zeros((num_items, img_height, img_width), dtype=np.float32)
    labels = [0]*num_items
    
    for i in range(num_items):
        img = cv2.imread(df["img_path"][i])
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        
        if resize: 
            img = cv2.resize(img, (img_width, img_height))
        
        img = (img/255.).astype(np.float32)
        label = df["label"][i]
        
        images[i, :, :] = img
        labels[i] = label
    
    return images, np.array(labels)

In [22]:
training_data, training_labels = generate_arrays(df=training_data)
validation_data, validation_labels = generate_arrays(df=validation_data)

[ WARN:0@54.369] global loadsave.cpp:248 findDecoder imread_('nf8b8.png'): can't open/read file: check file path/integrity


error: OpenCV(4.9.0) /io/opencv/modules/imgproc/src/color.cpp:196: error: (-215:Assertion failed) !_src.empty() in function 'cvtColor'


In [2]:
from module.custom_dataset import CaptchaDataset
import os
from PIL import Image
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from transformers import ViTForImageClassification, ViTFeatureExtractor
import torch.nn as nn
import torch.optim as optim
from torch.nn import CTCLoss
from tqdm import tqdm

# 데이터셋 로드 및 DataLoader 설정
train_dataset = CaptchaDataset('./data/samples', transform=transforms.ToTensor())
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

# 모델 설정
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k')
model.classifier = nn.Linear(model.classifier.in_features, len(train_dataset))  # 새 분류 레이어

# 훈련 루프
num_epochs = 10
checkpoint_interval = 2  # 체크포인트 저장 주기 (2 에포크마다)


# 손실 함수 및 옵티마이저 설정
criterion = CTCLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
for epoch in range(num_epochs):
    model.train()

    # tqdm을 사용하여 진행 상황을 추적
    progress = tqdm(train_loader, desc=f'Epoch {epoch + 1}/{num_epochs}')

    for inputs, labels in progress:
        inputs = torch.stack(inputs)
        outputs = model(inputs)

        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # tqdm 업데이트
        progress.set_postfix({'Loss': loss.item()})

    # 체크포인트 저장
    if (epoch + 1) % checkpoint_interval == 0:
        checkpoint_path = f"checkpoint_epoch_{epoch + 1}.pth"
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Checkpoint saved at {checkpoint_path}")

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
