In [22]:
# CNN을 이용해 점자 이미지들을 훈련시켜 문장을 점자로 번역하는 코드

In [40]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os

In [41]:
# CNN 모델 정의
class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2))
        self.layer3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2))
        self.fc = nn.Linear(128*7*7, 26)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

In [42]:
# 점자 이미지 데이터셋 클래스 정의
class BrailleDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.braille_chars = [chr(i) for i in range(97, 123)]  # 알파벳 a부터 z까지
        self.image_paths = [os.path.join(self.root_dir, char + '.jpg') for char in self.braille_chars]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path)
        label = idx
        if self.transform:
            image = self.transform(image)
        return image, label

In [43]:
# 학습 설정
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))])

In [44]:
# 데이터 로더 설정
image_directory = "./data/Braille_Dataset/"
braille_dataset = BrailleDataset(root_dir=image_directory, transform=transform)
data_loader = DataLoader(dataset=braille_dataset, batch_size=64, shuffle=True)

# 모델 인스턴스 생성
model = CNNModel()

# 손실 함수와 최적화기 설정
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [45]:
# 학습 설정
num_epochs = 10

# 학습과정 출력
for epoch in range(num_epochs):
    running_loss = 0.0
    correct = 0
    total = 0
    for i, (images, labels) in enumerate(data_loader):
        images = Variable(images)
        labels = Variable(labels)

        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        running_loss += loss.item()

    epoch_loss = running_loss / len(data_loader)
    epoch_accuracy = 100 * correct / total

    print('Epoch [{}/{}], Loss: {:.4f}, Accuracy: {:.2f}%'.format(epoch+1, num_epochs, epoch_loss, epoch_accuracy))

FileNotFoundError: [Errno 2] No such file or directory: './data/Braille_Dataset/w.jpg'

In [46]:
# 모델 저장
torch.save(model.state_dict(), 'model.pth')

# 문장을 점자로 번역하는 함수
def translate_sentence_to_braille(sentence):
    braille_translation = ""
    predictions = []
    for char in sentence:
        if char.isalpha():
            braille_char = char.lower() + ".jpg"
            braille_translation += braille_char + " "
            predictions.append(char.lower())
        elif char == ' ':
            braille_translation += "space.jpg "
            predictions.append("space")
    return braille_translation.strip(), predictions

In [47]:
# 입력 문장
input_sentence = input("문장 입력: ")

# 번역된 점자 이미지파일 출력
translated_braille, predictions = translate_sentence_to_braille(input_sentence)
print("번역된 점자:", translated_braille)

번역된 점자: d.jpg


In [48]:
# 예측 결과 출력
print("예측값:", predictions)

예측값: ['d']


In [49]:
# 번역된 점자 이미지 출력
braille_image_names = translated_braille.split()
for braille_image_name in braille_image_names:
    image_path = os.path.join(image_directory, braille_image_name)
    image = Image.open(image_path)
    image.show()

FileNotFoundError: [Errno 2] No such file or directory: './data/Braille_Dataset/d.jpg'