In [1]:
import os
from torch.utils.data import DataLoader,Dataset
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import pandas as pd
from torchvision import models
from pathlib import Path
# from fastai.vision import Path
import torch
import torch.nn as nn
from torch.autograd import Variable

In [2]:
NUMBER = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
ALPHABET = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
ALL_CHAR_SET = NUMBER + ALPHABET
ALL_CHAR_SET_LEN = len(ALL_CHAR_SET)
MAX_CAPTCHA = 5

In [3]:
def encode(a):
    onehot = [0]*ALL_CHAR_SET_LEN
    idx = ALL_CHAR_SET.index(a)
    onehot[idx] += 1
    return onehot

In [4]:
class Mydataset(Dataset):
    def __init__(self, path, is_train=True, transform=None):
        self.path = path
        if is_train: self.img = os.listdir(self.path)[:1000]
        else: self.img = os.listdir(self.path)[1001:]
        try: self.img.remove('3bnfnd.png')
        except: pass
        self.transform = transform
        
    def __getitem__(self, idx):
        img_path = self.img[idx]
        img = Image.open(self.path/img_path)
        img = img.convert('L')
        label = Path(self.path/img_path).name[:-4]
        label_oh = []
        for i in label:
            label_oh += encode(i)
        if self.transform is not None:
            img = self.transform(img)
        return img, np.array(label_oh), label
    
    def __len__(self):
        return len(self.img)

In [5]:
transform = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
])

In [6]:
train_ds = Mydataset(Path('captcha-version-2-images/samples/samples'), transform=transform)
test_ds = Mydataset(Path('captcha-version-2-images/samples/samples'), False, transform)
train_dl = DataLoader(train_ds, batch_size=64, num_workers=0)
test_dl = DataLoader(train_ds, batch_size=1, num_workers=0)

In [7]:
model = models.resnet18(pretrained=False)

In [8]:
model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

In [9]:
model.fc = nn.Linear(in_features=512, out_features=ALL_CHAR_SET_LEN*MAX_CAPTCHA, bias=True)

In [10]:
model.cuda();

In [11]:
loss_func = nn.MultiLabelSoftMarginLoss()
optm = torch.optim.Adam(model.parameters(), lr=0.001)

In [12]:
for epoch in range(40):
    for step, i in enumerate(train_dl):
        img, label_oh, label = i
        img = Variable(img).cuda()
        label_oh = Variable(label_oh.float()).cuda()
        pred = model(img)
        loss = loss_func(pred, label_oh)
        optm.zero_grad()
        loss.backward()
        optm.step()
        print('eopch:', epoch+1, 'step:', step+1, 'loss:', loss.item())

eopch: 1 step: 1 loss: 0.7020683288574219
eopch: 1 step: 2 loss: 0.5063948035240173
eopch: 1 step: 3 loss: 0.33852648735046387
eopch: 1 step: 4 loss: 0.23360654711723328
eopch: 1 step: 5 loss: 0.1740381270647049
eopch: 1 step: 6 loss: 0.14506274461746216
eopch: 1 step: 7 loss: 0.13296550512313843
eopch: 1 step: 8 loss: 0.1332421898841858
eopch: 1 step: 9 loss: 0.1283293515443802
eopch: 1 step: 10 loss: 0.13049545884132385
eopch: 1 step: 11 loss: 0.1311972290277481
eopch: 1 step: 12 loss: 0.1337416172027588
eopch: 1 step: 13 loss: 0.13173668086528778
eopch: 1 step: 14 loss: 0.12729239463806152
eopch: 1 step: 15 loss: 0.12682193517684937
eopch: 1 step: 16 loss: 0.1265459656715393
eopch: 2 step: 1 loss: 0.12253037840127945
eopch: 2 step: 2 loss: 0.12151574343442917
eopch: 2 step: 3 loss: 0.12051652371883392
eopch: 2 step: 4 loss: 0.11589503288269043
eopch: 2 step: 5 loss: 0.11392486095428467
eopch: 2 step: 6 loss: 0.11172939836978912
eopch: 2 step: 7 loss: 0.11169478297233582
eopch: 2 ste

eopch: 12 step: 14 loss: 0.07429489493370056
eopch: 12 step: 15 loss: 0.08354508876800537
eopch: 12 step: 16 loss: 0.07349015772342682
eopch: 13 step: 1 loss: 0.07628655433654785
eopch: 13 step: 2 loss: 0.07466072589159012
eopch: 13 step: 3 loss: 0.07571364939212799
eopch: 13 step: 4 loss: 0.0734923779964447
eopch: 13 step: 5 loss: 0.07602039724588394
eopch: 13 step: 6 loss: 0.07304145395755768
eopch: 13 step: 7 loss: 0.07421964406967163
eopch: 13 step: 8 loss: 0.0677301362156868
eopch: 13 step: 9 loss: 0.06725893914699554
eopch: 13 step: 10 loss: 0.07248710095882416
eopch: 13 step: 11 loss: 0.07179282605648041
eopch: 13 step: 12 loss: 0.08034037053585052
eopch: 13 step: 13 loss: 0.07081769406795502
eopch: 13 step: 14 loss: 0.06714864075183868
eopch: 13 step: 15 loss: 0.07696601748466492
eopch: 13 step: 16 loss: 0.06716179847717285
eopch: 14 step: 1 loss: 0.07099256664514542
eopch: 14 step: 2 loss: 0.0690281093120575
eopch: 14 step: 3 loss: 0.06930798292160034
eopch: 14 step: 4 loss: 0

eopch: 24 step: 6 loss: 0.027789119631052017
eopch: 24 step: 7 loss: 0.024732429534196854
eopch: 24 step: 8 loss: 0.020184403285384178
eopch: 24 step: 9 loss: 0.02039218321442604
eopch: 24 step: 10 loss: 0.021354157477617264
eopch: 24 step: 11 loss: 0.02564985491335392
eopch: 24 step: 12 loss: 0.03034072369337082
eopch: 24 step: 13 loss: 0.026384670287370682
eopch: 24 step: 14 loss: 0.019089745357632637
eopch: 24 step: 15 loss: 0.02271742932498455
eopch: 24 step: 16 loss: 0.024561937898397446
eopch: 25 step: 1 loss: 0.025388043373823166
eopch: 25 step: 2 loss: 0.022958669811487198
eopch: 25 step: 3 loss: 0.02152785286307335
eopch: 25 step: 4 loss: 0.02034199796617031
eopch: 25 step: 5 loss: 0.02177397347986698
eopch: 25 step: 6 loss: 0.025463800877332687
eopch: 25 step: 7 loss: 0.025527892634272575
eopch: 25 step: 8 loss: 0.019975867122411728
eopch: 25 step: 9 loss: 0.015876367688179016
eopch: 25 step: 10 loss: 0.01713196001946926
eopch: 25 step: 11 loss: 0.022745084017515182
eopch: 25

eopch: 35 step: 12 loss: 0.01439476665109396
eopch: 35 step: 13 loss: 0.013848884031176567
eopch: 35 step: 14 loss: 0.006877233274281025
eopch: 35 step: 15 loss: 0.008899419568479061
eopch: 35 step: 16 loss: 0.009859401732683182
eopch: 36 step: 1 loss: 0.007824943400919437
eopch: 36 step: 2 loss: 0.009188476018607616
eopch: 36 step: 3 loss: 0.007715878542512655
eopch: 36 step: 4 loss: 0.006972680799663067
eopch: 36 step: 5 loss: 0.006581156048923731
eopch: 36 step: 6 loss: 0.01042188424617052
eopch: 36 step: 7 loss: 0.009117270819842815
eopch: 36 step: 8 loss: 0.008100291714072227
eopch: 36 step: 9 loss: 0.007027916144579649
eopch: 36 step: 10 loss: 0.005878181662410498
eopch: 36 step: 11 loss: 0.007146788761019707
eopch: 36 step: 12 loss: 0.012770467437803745
eopch: 36 step: 13 loss: 0.012446282431483269
eopch: 36 step: 14 loss: 0.005436625797301531
eopch: 36 step: 15 loss: 0.005997912026941776
eopch: 36 step: 16 loss: 0.007806769106537104
eopch: 37 step: 1 loss: 0.006503511220216751


In [13]:
model.eval();

In [14]:
test_correct = 0
test_total = len(test_dl.dataset)
for step, (img, label_oh, label) in enumerate(test_dl):
    img = Variable(img).cuda()
    pred = model(img)

    c0 = ALL_CHAR_SET[np.argmax(pred.squeeze().cpu().tolist()[0:ALL_CHAR_SET_LEN])]
    c1 = ALL_CHAR_SET[np.argmax(pred.squeeze().cpu().tolist()[ALL_CHAR_SET_LEN:ALL_CHAR_SET_LEN*2])]
    c2 = ALL_CHAR_SET[np.argmax(pred.squeeze().cpu().tolist()[ALL_CHAR_SET_LEN*2:ALL_CHAR_SET_LEN*3])]
    c3 = ALL_CHAR_SET[np.argmax(pred.squeeze().cpu().tolist()[ALL_CHAR_SET_LEN*3:ALL_CHAR_SET_LEN*4])]
    c4 = ALL_CHAR_SET[np.argmax(pred.squeeze().cpu().tolist()[ALL_CHAR_SET_LEN*4:ALL_CHAR_SET_LEN*5])]
    c = '%s%s%s%s%s' % (c0, c1, c2, c3, c4)
    if c == label[0]:
        test_correct += 1


    print('label:', label[0], 'pred:', c, ' 맞춤여뷰 : ', label[0]==c)
    
print(f'Test Accuracy: {(test_correct/test_total):.5f} ' +
      f'({test_correct}/{test_total})')


label: 226md pred: 226md  맞춤여뷰 :  True
label: 22d5n pred: 22d5n  맞춤여뷰 :  True
label: 2356g pred: 2356g  맞춤여뷰 :  True
label: 23mdg pred: 23mdg  맞춤여뷰 :  True
label: 23n88 pred: 23n88  맞춤여뷰 :  True
label: 243mm pred: 243mm  맞춤여뷰 :  True
label: 244e2 pred: 244e2  맞춤여뷰 :  True
label: 245y5 pred: 245y5  맞춤여뷰 :  True
label: 24f6w pred: 24f6w  맞춤여뷰 :  True
label: 24pew pred: 24pew  맞춤여뷰 :  True
label: 25257 pred: 25257  맞춤여뷰 :  True
label: 253dc pred: 253dc  맞춤여뷰 :  True
label: 25egp pred: 25egp  맞춤여뷰 :  True
label: 25m6p pred: 25m6p  맞춤여뷰 :  True
label: 25p2m pred: 25p2m  맞춤여뷰 :  True
label: 25w53 pred: 25w53  맞춤여뷰 :  True
label: 264m5 pred: 264m5  맞춤여뷰 :  True
label: 268g2 pred: 268g2  맞춤여뷰 :  True
label: 28348 pred: 28348  맞춤여뷰 :  True
label: 28x47 pred: 28x47  맞춤여뷰 :  True
label: 2b827 pred: 2b827  맞춤여뷰 :  True
label: 2bg48 pred: 2bg48  맞춤여뷰 :  True
label: 2cegf pred: 2cegf  맞춤여뷰 :  True
label: 2cg58 pred: 2cg58  맞춤여뷰 :  True
label: 2cgyx pred: 2cgyx  맞춤여뷰 :  True
label: 2en7g pred: 2en7g 

label: 64b3p pred: 64b3p  맞춤여뷰 :  True
label: 64m82 pred: 64m82  맞춤여뷰 :  True
label: 658xe pred: 658xe  맞춤여뷰 :  True
label: 65ebm pred: 65ebm  맞춤여뷰 :  True
label: 65m85 pred: 65m85  맞춤여뷰 :  True
label: 65nmw pred: 65nmw  맞춤여뷰 :  True
label: 662bw pred: 662bw  맞춤여뷰 :  True
label: 664dn pred: 664dn  맞춤여뷰 :  True
label: 664nf pred: 664nf  맞춤여뷰 :  True
label: 66wp5 pred: 66wp5  맞춤여뷰 :  True
label: 675p3 pred: 675p3  맞춤여뷰 :  True
label: 677g3 pred: 677g3  맞춤여뷰 :  True
label: 678w3 pred: 678w3  맞춤여뷰 :  True
label: 67dey pred: 67dey  맞춤여뷰 :  True
label: 6825y pred: 6825y  맞춤여뷰 :  True
label: 68wfd pred: 68wfd  맞춤여뷰 :  True
label: 68x48 pred: 68x48  맞춤여뷰 :  True
label: 6b46g pred: 6b46g  맞춤여뷰 :  True
label: 6b4w6 pred: 6b4w6  맞춤여뷰 :  True
label: 6bdn5 pred: 6bdn5  맞춤여뷰 :  True
label: 6bnnm pred: 6bnnm  맞춤여뷰 :  True
label: 6bxwg pred: 6bxwg  맞춤여뷰 :  True
label: 6c3n6 pred: 6c3n6  맞춤여뷰 :  True
label: 6c3p5 pred: 6c3p5  맞춤여뷰 :  True
label: 6cm6m pred: 6cm6m  맞춤여뷰 :  True
label: 6cwxe pred: 6cwxe 

label: c4mcm pred: c4mcm  맞춤여뷰 :  True
label: c55c6 pred: c55c6  맞춤여뷰 :  True
label: c5xne pred: c5xne  맞춤여뷰 :  True
label: c6745 pred: c6745  맞춤여뷰 :  True
label: c6f8g pred: c6f8g  맞춤여뷰 :  True
label: c6we6 pred: c6we6  맞춤여뷰 :  True
label: c753e pred: c753e  맞춤여뷰 :  True
label: c7gb3 pred: c7gb3  맞춤여뷰 :  True
label: c7nn8 pred: c7nn8  맞춤여뷰 :  True
label: c86md pred: c86md  맞춤여뷰 :  True
label: c8fxy pred: c8fxy  맞춤여뷰 :  True
label: c8n8c pred: c8n8c  맞춤여뷰 :  True
label: cb8cf pred: cb8cf  맞춤여뷰 :  True
label: cc845 pred: cc845  맞춤여뷰 :  True
label: ccf2w pred: ccf2w  맞춤여뷰 :  True
label: ccn2x pred: ccn2x  맞춤여뷰 :  True
label: cd4eg pred: cd4eg  맞춤여뷰 :  True
label: cd6p4 pred: cd6p4  맞춤여뷰 :  True
label: cdcb3 pred: cdcb3  맞춤여뷰 :  True
label: cdf77 pred: cdf77  맞춤여뷰 :  True
label: cdfen pred: cdfen  맞춤여뷰 :  True
label: cdmn8 pred: cdmn8  맞춤여뷰 :  True
label: cen55 pred: cen55  맞춤여뷰 :  True
label: cewnm pred: cewnm  맞춤여뷰 :  True
label: cfc2y pred: cfc2y  맞춤여뷰 :  True
label: cfc56 pred: cfc56 

label: gcx6f pred: gcx6f  맞춤여뷰 :  True
label: gd4mf pred: gd4mf  맞춤여뷰 :  True
label: gd8fb pred: gd8fb  맞춤여뷰 :  True
label: gdng3 pred: gdpg3  맞춤여뷰 :  False
label: gecmf pred: gecmf  맞춤여뷰 :  True
label: gegw4 pred: gegw4  맞춤여뷰 :  True
label: gewfy pred: gewfy  맞춤여뷰 :  True
label: geyn5 pred: geyn5  맞춤여뷰 :  True
label: gf2g4 pred: gf2g4  맞춤여뷰 :  True
label: gfbx6 pred: gfbx6  맞춤여뷰 :  True
label: gfp54 pred: gfp54  맞춤여뷰 :  True
label: gfxcc pred: gfxcc  맞춤여뷰 :  True
label: ggd7m pred: ggd7m  맞춤여뷰 :  True
label: gm2c2 pred: gm2c2  맞춤여뷰 :  True
label: gm6nn pred: gm6nn  맞춤여뷰 :  True
label: gm7n8 pred: gm7n8  맞춤여뷰 :  True
label: gmmne pred: gmmne  맞춤여뷰 :  True
label: gn2d3 pred: gn2d3  맞춤여뷰 :  True
label: gn2xy pred: gn2xy  맞춤여뷰 :  True
label: gnbde pred: gnbde  맞춤여뷰 :  True
label: gnbn4 pred: gnbn4  맞춤여뷰 :  True
label: gnc3n pred: gnc3n  맞춤여뷰 :  True
label: gnf85 pred: gnf85  맞춤여뷰 :  True
label: gng6e pred: gng6e  맞춤여뷰 :  True
label: gny6b pred: gny6b  맞춤여뷰 :  True
label: gp22x pred: gp22x

label: pg2pm pred: pg2pm  맞춤여뷰 :  True
label: pg2yx pred: pg2yx  맞춤여뷰 :  True
label: pg4bf pred: pg4bf  맞춤여뷰 :  True
label: pgg3n pred: pgg3n  맞춤여뷰 :  True
label: pgm2e pred: pgm2e  맞춤여뷰 :  True
label: pgmn2 pred: pgmn2  맞춤여뷰 :  True
label: pgwnp pred: pgwnp  맞춤여뷰 :  True
label: pm363 pred: pm363  맞춤여뷰 :  True
label: pm47f pred: pm47f  맞춤여뷰 :  True
label: pmd3w pred: pmd3w  맞춤여뷰 :  True
label: pme86 pred: pme86  맞춤여뷰 :  True
label: pmf5w pred: pmf5w  맞춤여뷰 :  True
label: pmg55 pred: pmg55  맞춤여뷰 :  True
label: pn7pn pred: pn7pn  맞춤여뷰 :  True
label: pnmxf pred: pnmxf  맞춤여뷰 :  True
label: pnnwy pred: pnnwy  맞춤여뷰 :  True
label: pp546 pred: pp546  맞춤여뷰 :  True
label: pp87n pred: pp87n  맞춤여뷰 :  True
label: ppwyd pred: ppwyd  맞춤여뷰 :  True
label: ppx77 pred: ppx77  맞춤여뷰 :  True
label: pw5nc pred: pw5nc  맞춤여뷰 :  True
label: pwebm pred: pwebm  맞춤여뷰 :  True
label: pwmbn pred: pwmbn  맞춤여뷰 :  True
label: pwn5e pred: pwn5e  맞춤여뷰 :  True
label: px2xp pred: px2xp  맞춤여뷰 :  True
label: px8n8 pred: px8n8 

### 모델 예측값 시각화하기
일부 이미지에 대한 예측값을 보여주는 일반화된 함수입니다.

In [None]:
def visualize_model(model, num_images=6):
    was_training = model.training
    model.eval()
    images_so_far = 0
    fig = plt.figure()

    with torch.no_grad():
        for i, (inputs, labels) in enumerate(dataloaders['val']):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            for j in range(inputs.size()[0]):
                images_so_far += 1
                ax = plt.subplot(num_images//2, 2, images_so_far)
                ax.axis('off')
                ax.set_title('predicted: {}'.format(class_names[preds[j]]))
                imshow(inputs.cpu().data[j])

                if images_so_far == num_images:
                    model.train(mode=was_training)
                    return
        model.train(mode=was_training)