## Let's pull in a Transformer architecture and build with it!

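Great work making it through the Transformers lecture and the hands-on session! I first ran into transformers in the previous cohort myself, and between the intricate model structure and the difficult concepts, it took me a long time to understand them. Still, there is hardly any area of contemporary AI technology that does not use transformer models, so even though much of the material is hard, I tried to pack in the parts that matter most. Well done XD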

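With Transformers, the concepts are one thing, but writing attention code from scratch every time and building the encoder and decoder structure can feel daunting too! Fortunately, the transformer architecture is so well known that with just a few lines of code (one `click`) you can pull a transformer implementation based on recent papers from `huggingface` or `github` :D In this exercise we will take the transformer structure we built in class from `x-transformers`, a library that is admittedly not easy for beginners, and finish by generating and printing an example sequence!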
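To give a feel for what "writing attention code from scratch" involves, here is a minimal sketch of single-head scaled dot-product attention on plain Python lists. It is only an illustration: the function names `softmax` and `attention` and the toy vectors are made up for this example, and a real implementation would use batched tensors, multiple heads, and masking.

```python
import math

def softmax(xs):
    # subtract the max before exponentiating for numerical stability
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def attention(queries, keys, values):
    """Scaled dot-product attention for one head, on plain lists of vectors."""
    d_k = len(keys[0])
    outputs = []
    for q in queries:
        # similarity of this query to every key, scaled by sqrt(d_k)
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in keys]
        weights = softmax(scores)
        # weighted average of the value vectors
        out = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]
        outputs.append(out)
    return outputs

# toy example: 2 queries attending over 3 key/value pairs
q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
result = attention(q, k, v)
print(result)
```

Each output row is a weighted average of the value vectors, so a query that matches no key in particular (for example an all-zero query) simply returns the mean of the values.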

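Since the transformer architecture came out in 2017, many models have been built on top of it, and a number of improvements have been made to the base model as well. `x-transformers` is a library that lets you apply these improved variants by adding just a few lines of code. It is hard to use in the sense that you need to keep up with recent papers, but in return it collects code whose performance backs up that effort.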

That ran longer than I expected, but anyway: there is nothing you need to write yourself for this assignment. Just run the code I provide as-is. The semester starts soon, so good luck, everyone XD

# Run this in a Colab GPU environment!

In [1]:
!pip install x-transformers

Collecting x-transformers
  Downloading x_transformers-1.19.1-py3-none-any.whl (27 kB)
Collecting einops>=0.6.1 (from x-transformers)
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 42.2/42.2 kB 2.8 MB/s eta 0:00:00
Installing collected packages: einops, x-transformers
Successfully installed einops-0.6.1 x-transformers-1.19.1


In [2]:
import torch
from x_transformers import XTransformer

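## Load the Transformer architecture
model = XTransformer(
    ## model dimension (paper = 512)
    dim = 16,
    ## encoder vocabulary size, i.e. number of distinct token ids (reference example = 256)
    enc_num_tokens = 16,
    ## number of stacked encoder layers (paper = 6)
    enc_depth = 6,
    ## number of heads in multi-head attention (paper = 8)
    enc_heads = 8,
    ## maximum encoder sequence length (reference example = 1024)
    enc_max_seq_len = 32,
    ## decoder vocabulary size (reference example = 256)
    dec_num_tokens = 16,
    dec_depth = 6,
    dec_heads = 8,
    ## maximum decoder sequence length (reference example = 1024)
    dec_max_seq_len = 32,
    ## share one token embedding between encoder and decoder
    tie_token_emb = True
).cuda()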

In [3]:
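## NUM_BATCHES = number of training steps (one batch per step, not epochs)
## BATCH_SIZE = number of samples in one batch
## LEARNING_RATE = learning rate
## GENERATE_EVERY = generate a sample every 100 steps to check accuracy
## NUM_TOKENS = number of unique token ids (4 data tokens + 2 reserved ids)
## ENC_SEQ_LEN = encoder sequence length
## DEC_SEQ_LEN = decoder sequence length (start token + source repeated twice)

NUM_BATCHES = int(1e3)
BATCH_SIZE = 32
LEARNING_RATE = 3e-4
GENERATE_EVERY = 100
NUM_TOKENS = 4 + 2
ENC_SEQ_LEN = 8
DEC_SEQ_LEN = 16 + 1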

In [4]:
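## Generate random src, tgt, and src_mask to feed to the Transformer model

def cycle():
    while True:
        ## start token (id 1) for every sample in the batch
        prefix = torch.ones((BATCH_SIZE, 1)).long().cuda()
        ## random source tokens drawn from ids 2 .. NUM_TOKENS - 1
        src = torch.randint(2, NUM_TOKENS, (BATCH_SIZE, ENC_SEQ_LEN)).long().cuda()
        ## target = start token followed by the source repeated twice
        tgt = torch.cat((prefix, src, src), 1)
        ## no padding, so every source position is marked as valid
        src_mask = torch.ones(BATCH_SIZE, src.shape[1]).bool().cuda()
        yield (src, tgt, src_mask)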
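If you want to see what a single sample looks like without a GPU, the sketch below rebuilds the same layout with plain Python lists. `make_sample` and `START_TOKEN` are hypothetical names used only for this illustration, and it uses the standard `random` module instead of `torch.randint`, so the actual token values will differ from the notebook.

```python
import random

NUM_TOKENS = 4 + 2      # id 1 is the start token, ids 2..5 carry data, id 0 is unused here
ENC_SEQ_LEN = 8
DEC_SEQ_LEN = 16 + 1    # start token + source repeated twice
START_TOKEN = 1         # same id as the `prefix` tensor in cycle()

def make_sample(rng):
    """Build one (src, tgt) pair with the same layout as cycle(), using lists."""
    # randrange excludes the upper bound, like torch.randint
    src = [rng.randrange(2, NUM_TOKENS) for _ in range(ENC_SEQ_LEN)]
    tgt = [START_TOKEN] + src + src
    return src, tgt

src, tgt = make_sample(random.Random(0))
print("src:", src)
print("tgt:", tgt)
```

The target is just the source copied twice after a start token, so the task the model has to learn is "copy the encoder input". That is why comparing the first `ENC_SEQ_LEN` generated tokens with `src` is a sensible check.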

In [5]:
## Train Model

import tqdm

## named `optimizer` so it does not shadow the torch.optim module
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

## create the batch generator once and reuse it
data = cycle()

## training loop: compute the loss and update the weights at every step
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    src, tgt, src_mask = next(data)

    loss = model(src, tgt, mask=src_mask)
    loss.backward()
    print(f'{i}: {loss.item()}')

    optimizer.step()
    optimizer.zero_grad()

    ## every GENERATE_EVERY (100) steps, generate a sample and count the errors
    if i != 0 and i % GENERATE_EVERY == 0:
        model.eval()
        src, _, src_mask = next(data)
        ## keep only the first sample of the batch
        src, src_mask = src[:1], src_mask[:1]
        start_tokens = torch.ones((1, 1)).long().cuda()

        sample = model.generate(src, start_tokens, ENC_SEQ_LEN, mask = src_mask)
        ## number of positions where the generated token differs from the input
        incorrects = (src != sample).sum()

        print("input:  ", src)
        print("predicted output:  ", sample)
        print(f"incorrects: {incorrects}")

training:   0%|          | 0/1000 [00:00<?, ?it/s]

0: 2.7061374187469482
1: 2.214993953704834
2: 2.0629427433013916
3: 1.9888019561767578
4: 1.9318692684173584
5: 1.9197072982788086
6: 1.8874118328094482
7: 1.8992266654968262
8: 1.849900245666504
9: 1.859596610069275
10: 1.8585368394851685
11: 1.8539023399353027
12: 1.8120901584625244
13: 1.8008887767791748
14: 1.8125240802764893
15: 1.793718695640564
16: 1.8012608289718628
17: 1.781268835067749
18: 1.7972475290298462
19: 1.8015942573547363
20: 1.7938213348388672
21: 1.7752528190612793
22: 1.7734427452087402
23: 1.7380176782608032
24: 1.7525606155395508
25: 1.7334400415420532
26: 1.7335989475250244
27: 1.7306537628173828
28: 1.726776361465454
29: 1.7281047105789185
30: 1.7003061771392822
31: 1.7333178520202637
32: 1.692740559577942
33: 1.7154911756515503
34: 1.6934387683868408
35: 1.6694430112838745
36: 1.7040375471115112
37: 1.7256428003311157
38: 1.7054909467697144
39: 1.688014030456543
40: 1.6756374835968018
41: 1.6995227336883545
42: 1.649222731590271
43: 1.6569877862930298
44: 1.6

training:  13%|█▎        | 131/1000 [00:10<01:06, 13.07it/s]

128: 1.441887617111206
129: 1.4328140020370483
130: 1.380568504333496
131: 1.3836703300476074
132: 1.3815863132476807
133: 1.3847434520721436
134: 1.3997271060943604
135: 1.3595563173294067
136: 1.331903100013733
137: 1.3925719261169434
138: 1.3501434326171875
139: 1.389843463897705
140: 1.3630523681640625
141: 1.3858968019485474
142: 1.3612098693847656
143: 1.3585423231124878
144: 1.3599640130996704
145: 1.359816312789917
146: 1.3348052501678467
147: 1.3785815238952637
148: 1.3762836456298828
149: 1.363323450088501
150: 1.3309475183486938
151: 1.343226671218872
152: 1.3213199377059937
153: 1.3549320697784424
154: 1.3290680646896362
155: 1.3244702816009521
156: 1.3060557842254639
157: 1.3142459392547607
158: 1.3158352375030518
159: 1.3458019495010376
160: 1.353501319885254
161: 1.3393456935882568
162: 1.3187922239303589
163: 1.3096572160720825
164: 1.287101149559021
165: 1.2984886169433594
166: 1.2749079465866089
167: 1.302024483680725
168: 1.2754722833633423
169: 1.3016520738601685
17

training:  36%|███▌      | 356/1000 [00:20<00:34, 18.56it/s]

355: 0.8197599649429321
356: 0.8244092464447021
357: 0.8592352867126465
358: 0.9182614088058472
359: 0.8017748594284058
360: 0.8265559077262878
361: 0.879850447177887
362: 0.7480440139770508
363: 0.8126517534255981
364: 0.8501898050308228
365: 0.8198685646057129
366: 0.8081040382385254
367: 0.8133638501167297
368: 0.8320590257644653
369: 0.8176532983779907
370: 0.7895914316177368
371: 0.8193916082382202
372: 0.8333222270011902
373: 0.7844448685646057
374: 0.7959332466125488
375: 0.787899374961853
376: 0.7979117631912231
377: 0.7806097865104675
378: 0.8305640816688538
379: 0.7594384551048279
380: 0.7801885008811951
381: 0.7513172626495361
382: 0.7554836273193359
383: 0.7630159258842468
384: 0.7475595474243164
385: 0.7244447469711304
386: 0.7484669089317322
387: 0.7624181509017944
388: 0.7694929838180542
389: 0.7281314134597778
390: 0.7354514002799988
391: 0.6729552149772644
392: 0.7062530517578125
393: 0.7218243479728699
394: 0.6907866597175598
395: 0.7261609435081482
396: 0.71153271198

training:  58%|█████▊    | 580/1000 [00:30<00:21, 19.71it/s]

575: 0.4520793855190277
576: 0.4521121680736542
577: 0.4494975507259369
578: 0.46604660153388977
579: 0.46331608295440674
580: 0.4840630292892456
581: 0.4507341980934143
582: 0.517562747001648
583: 0.4806123375892639
584: 0.47490179538726807
585: 0.48375290632247925
586: 0.4914276599884033
587: 0.45105329155921936
588: 0.4803812801837921
589: 0.45914652943611145
590: 0.45880913734436035
591: 0.4749391973018646
592: 0.48893991112709045
593: 0.49779313802719116
594: 0.48866888880729675
595: 0.4558722972869873
596: 0.46407178044319153
597: 0.48824629187583923
598: 0.48453488945961
599: 0.4666805863380432
600: 0.5156347155570984
input:   tensor([[5, 4, 2, 5, 4, 2, 3, 4]], device='cuda:0')
predicted output:   tensor([[5, 2, 4, 5, 2, 4, 3, 4]], device='cuda:0')
incorrects: 4
601: 0.49776309728622437
602: 0.47314009070396423
603: 0.5344222187995911
604: 0.4958105981349945
605: 0.4494897723197937
606: 0.47218817472457886
607: 0.5140189528465271
608: 0.47683462500572205
609: 0.4394323229789734


training:  79%|███████▉  | 791/1000 [00:40<00:10, 20.05it/s]

787: 0.3470396101474762
788: 0.34318870306015015
789: 0.3221002519130707
790: 0.34250327944755554
791: 0.3125526010990143
792: 0.2823348939418793
793: 0.30364319682121277
794: 0.3176157772541046
795: 0.3122679591178894
796: 0.278723806142807
797: 0.3310447633266449
798: 0.3099045753479004
799: 0.34524789452552795
800: 0.3361593186855316
input:   tensor([[3, 2, 5, 3, 5, 2, 2, 3]], device='cuda:0')
predicted output:   tensor([[2, 3, 5, 2, 5, 3, 3, 2]], device='cuda:0')
incorrects: 6
801: 0.3557221293449402
802: 0.3973005712032318
803: 0.3071487247943878
804: 0.32520440220832825
805: 0.3147410750389099
806: 0.29161685705184937
807: 0.3187074363231659
808: 0.28604432940483093
809: 0.3079126477241516
810: 0.3066609501838684
811: 0.2918286621570587
812: 0.33603036403656006
813: 0.28784507513046265
814: 0.30336853861808777
815: 0.3009728491306305
816: 0.362269788980484
817: 0.2776124179363251
818: 0.2794112265110016
819: 0.3111684024333954
820: 0.2863611578941345
821: 0.2659822702407837
822: 

training: 100%|██████████| 1000/1000 [00:50<00:00, 19.70it/s]

997: 0.12625005841255188
998: 0.12399747967720032
999: 0.12504242360591888
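One way to read the loss values in the log: assuming the number returned by `model(src, tgt, ...)` is the average cross-entropy per predicted token in nats (an assumption about the library, not something this notebook verifies), a model that guesses uniformly among k candidate tokens pays ln k. The helper name `uniform_cross_entropy` is made up for this sketch.

```python
import math

def uniform_cross_entropy(num_choices):
    """Cross-entropy (in nats) of guessing uniformly among num_choices tokens."""
    return math.log(num_choices)

# all 16 decoder token ids equally likely (dec_num_tokens = 16)
print(f"16 choices: {uniform_cross_entropy(16):.3f}")
# only the 4 ids that actually occur in the data (2..5)
print(f" 4 choices: {uniform_cross_entropy(4):.3f}")
```

These come out to roughly 2.77 and 1.39, which are close to the logged loss at step 0 (about 2.71) and to the values around steps 130 to 150 (about 1.3 to 1.4). That would be consistent with the model first learning which tokens occur at all and only later learning to copy, but this is one interpretation of the curve, not something the log proves.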





In [6]:
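## src has shape (1, ENC_SEQ_LEN), so count tokens with numel() rather than len()
total = src.numel()
print(f'Accuracy : {(total - incorrects.item()) / total * 100}%')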

Accuracy : 87.5%
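For reference, the same token-level accuracy can be worked out by hand for the two samples printed in the training log above. This is a small sketch on plain lists; `token_accuracy` is a hypothetical helper, not part of `x-transformers`.

```python
def token_accuracy(source, predicted):
    """Percentage of positions where predicted matches source."""
    if len(source) != len(predicted):
        raise ValueError("sequences must have the same length")
    incorrects = sum(s != p for s, p in zip(source, predicted))
    return (len(source) - incorrects) / len(source) * 100

# the two samples printed in the training log above (steps 600 and 800)
step_600 = token_accuracy([5, 4, 2, 5, 4, 2, 3, 4], [5, 2, 4, 5, 2, 4, 3, 4])
step_800 = token_accuracy([3, 2, 5, 3, 5, 2, 2, 3], [2, 3, 5, 2, 5, 3, 3, 2])
print(step_600, step_800)
```

With 4 and 6 wrong tokens out of 8, these samples score 50% and 25%. A single 8-token sample can only move in steps of 12.5%, so 87.5% corresponds to exactly one wrong token; averaging over several generated samples would give a steadier estimate.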
