## Transformers architecture를 받아와 구축해보자!

Transformers 강의 + 실습까지 수강하시느라 너무 수고 많으셨습니다! 저도 저번 기수 당시 transformers를 처음 접했는데, 모델의 구조가 난해하고 쓰인 개념들도 어려워서 이해하는데 많은 시간이 걸렸었네요.. 그래도 transformers 모델이 현재 contemporary AI technology에 쓰이지 않는 곳이 없다보니, 어려운 내용들이 다수 있지만 힘주어서 중요한 내용들로 채워넣으려고 노력했습니다 수고하셨습니다 XD

Transformers은 개념도 개념이지만, 매번 새롭게 attention 코드를 짜고, encoder와 decoder 구조를 구축하는 것도 막막하실 겁니다! 다행히도 transformers architecture는 워낙 유명해서 이제 코드 몇줄만 `딸깍`해도 최신 논문기반 transformer 구조를 `huggingface` 혹은 `github`에서 받아서 사용할 수 있습니다 :D 이번 실습에는 초심자가 사용하기는 어렵지만 `x-transformers`에서 저희가 구축한 transformers 구조를 받아와, 예시 문장을 출력하는 것까지 마무리할 예정입니다!

transformers architecture가 2017년 나온 이후로, 이 architecture를 기반으로 한 여러 모델들이 나왔고, 또한 base model에 대해서도 여러 개선점들이 추가되었습니다. `x-transformers`은 이 개선된 model들을 코드 몇줄을 추가해서 적용 가능하게 하는 라이브러리로, 최신 논문 동향을 파악하고 있어야 한다는 점에서 어렵지만 그만큼 성능이 뒷받침해주는 코드들을 모아둔 라이브러리입니다.

TMI가 생각보다 길어졌는데, 아무튼 이번 과제는 따로 작성해 넣어야할 부분은 없고, 제가 드리는 코드를 그대로 실행하기만 하면 되는 과정으로 추가했습니다. 이제 개강이 얼마 남지 않았는데, 다들 화이팅입니다 XD

# Colab GPU 환경에서 구동하세요!

In [1]:
!pip install x-transformers

Collecting x-transformers
  Downloading x_transformers-1.19.0-py3-none-any.whl (27 kB)
Collecting einops>=0.6.1 (from x-transformers)
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/42.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: einops, x-transformers
Successfully installed einops-0.6.1 x-transformers-1.19.0


In [2]:
import torch
from x_transformers import XTransformer

## Transformers Architectures 받아오
model = XTransformer(
    ## 모델의 차원 (논문 = 512)
    dim = 16,
    ## encoder token의 개수 (논문 = 256)
    enc_num_tokens = 16,
    ## encoder 반복 횟수 (논문 = 6)
    enc_depth = 6,
    ## multihead attention n_heads 개수 (논문 = 8)
    enc_heads = 8,
    ## encoder token의 max sequence length (논문 = 1024)
    enc_max_seq_len = 32,
    ## (논문 = 256)
    dec_num_tokens = 16,
    dec_depth = 6,
    dec_heads = 8,
    ## (논문 = 1024)
    dec_max_seq_len = 32,
    tie_token_emb = True
).cuda()

In [3]:
## NUM_BATCHES = Epoches의 개수
## BATCH_SIZE = 하나의 batch에 들어갈 sample의 개수
## LEARNING_RATE = learning rate
## GENERATE_EVERY = 100번마다 한번씩 generate해서 accuracy확인
## NUM_TOKENS = 데이터 내 유니크한 토큰의 수
## ENC_SEQ_LEN = encoder sequence length

NUM_BATCHES = int(1e3)
BATCH_SIZE = 32
LEARNING_RATE = 3e-4
GENERATE_EVERY  = 100
NUM_TOKENS = 4 + 2
ENC_SEQ_LEN = 8
DEC_SEQ_LEN = 16 + 1

In [4]:
## Transformer 모델에 넣을 임의의 src, tgt, src_mask 생성

def cycle():
    while True:
        prefix = torch.ones((BATCH_SIZE, 1)).long().cuda()
        src = torch.randint(2, NUM_TOKENS, (BATCH_SIZE, ENC_SEQ_LEN)).long().cuda()
        tgt = torch.cat((prefix, src, src), 1)
        src_mask = torch.ones(BATCH_SIZE, src.shape[1]).bool().cuda()
        yield (src, tgt, src_mask)

In [5]:
## Train Model

import tqdm
import torch.optim as optim

optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

## loss update 해가면서 학습
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    src, tgt, src_mask = next(cycle())

    loss = model(src, tgt, mask=src_mask)
    loss.backward()
    print(f'{i}: {loss.item()}')

    optim.step()
    optim.zero_grad()

    ## 매 N(100)번마다 accuracy 측정
    if i != 0 and i % GENERATE_EVERY == 0:
        model.eval()
        src, _, src_mask = next(cycle())
        src, src_mask = src[:1], src_mask[:1]
        start_tokens = (torch.ones((1, 1)) * 1).long().cuda()

        sample = model.generate(src, start_tokens, ENC_SEQ_LEN, mask = src_mask)
        incorrects = (src != sample).abs().sum()

        print(f"input:  ", src)
        print(f"predicted output:  ", sample)
        print(f"incorrects: {incorrects}")

training:   0%|          | 0/1000 [00:00<?, ?it/s]

0: 3.160845994949341
1: 2.4192914962768555
2: 2.2414767742156982
3: 2.113687753677368
4: 2.068185806274414
5: 2.005664825439453
6: 1.9955666065216064
7: 1.9809468984603882
8: 1.9598523378372192
9: 1.9655238389968872
10: 1.9615989923477173
11: 1.9545528888702393
12: 1.9578090906143188
13: 1.9011231660842896
14: 1.9127182960510254
15: 1.8671491146087646
16: 1.9160938262939453
17: 1.9076900482177734
18: 1.8913249969482422
19: 1.8514922857284546
20: 1.848868727684021
21: 1.837746262550354
22: 1.842522382736206
23: 1.8275797367095947
24: 1.8008253574371338
25: 1.8142328262329102
26: 1.8162447214126587
27: 1.7989778518676758
28: 1.8144147396087646
29: 1.7933297157287598
30: 1.783965826034546
31: 1.82062566280365
32: 1.7755224704742432
33: 1.793483853340149
34: 1.785622239112854
35: 1.768894076347351
36: 1.751534104347229
37: 1.7499592304229736
38: 1.7406036853790283
39: 1.7725391387939453
40: 1.7666406631469727
41: 1.741192102432251
42: 1.7354767322540283
43: 1.7642021179199219
44: 1.6992963

training:  14%|█▍        | 138/1000 [00:10<01:02, 13.79it/s]

135: 1.4337424039840698
136: 1.4419280290603638
137: 1.4004782438278198
138: 1.4394340515136719
139: 1.4099866151809692
140: 1.404908299446106
141: 1.4650686979293823
142: 1.407716989517212
143: 1.3630460500717163
144: 1.382617712020874
145: 1.4042248725891113
146: 1.36228346824646
147: 1.3804348707199097
148: 1.3502607345581055
149: 1.3759255409240723
150: 1.3725814819335938
151: 1.3534984588623047
152: 1.432508111000061
153: 1.369185447692871
154: 1.3389712572097778
155: 1.3553946018218994
156: 1.351906418800354
157: 1.3452777862548828
158: 1.3671962022781372
159: 1.3309834003448486
160: 1.2991892099380493
161: 1.3284310102462769
162: 1.3175803422927856
163: 1.3451809883117676
164: 1.3125485181808472
165: 1.3147995471954346
166: 1.3433610200881958
167: 1.3353776931762695
168: 1.3393924236297607
169: 1.337436318397522
170: 1.3308534622192383
171: 1.3011424541473389
172: 1.3099358081817627
173: 1.2802441120147705
174: 1.286696434020996
175: 1.2943155765533447
176: 1.262995958328247
177

training:  35%|███▍      | 346/1000 [00:20<00:36, 17.87it/s]

344: 0.8044654726982117
345: 0.785243570804596
346: 0.7937471270561218
347: 0.8048834204673767
348: 0.7878421545028687
349: 0.7648131251335144
350: 0.8120756149291992
351: 0.7975457906723022
352: 0.8046236634254456
353: 0.8122341632843018
354: 0.773522675037384
355: 0.8094795942306519
356: 0.7926761507987976
357: 0.8050554990768433
358: 0.8180520534515381
359: 0.7363572716712952
360: 0.8061025142669678
361: 0.7940486669540405
362: 0.8222563862800598
363: 0.7827319502830505
364: 0.7575640678405762
365: 0.8201051950454712
366: 0.7232848405838013
367: 0.7485322952270508
368: 0.7589207291603088
369: 0.750253438949585
370: 0.7431774139404297
371: 0.7166653275489807
372: 0.7551544308662415
373: 0.7184282541275024
374: 0.7771000862121582
375: 0.7119134664535522
376: 0.7366393804550171
377: 0.7279590964317322
378: 0.7410123944282532
379: 0.689592719078064
380: 0.7337483167648315
381: 0.6960592865943909
382: 0.6989627480506897
383: 0.7170958518981934
384: 0.7042816281318665
385: 0.6815991997718

training:  55%|█████▌    | 554/1000 [00:30<00:23, 19.19it/s]

551: 0.5049794316291809
552: 0.5008053183555603
553: 0.5424531698226929
554: 0.5050491690635681
555: 0.5526353716850281
556: 0.5562880039215088
557: 0.5159060955047607
558: 0.5777913331985474
559: 0.5054633617401123
560: 0.5262734889984131
561: 0.5274472832679749
562: 0.5015805959701538
563: 0.5297299027442932
564: 0.5247697234153748
565: 0.5255246758460999
566: 0.5086236000061035
567: 0.5307679176330566
568: 0.5059980154037476
569: 0.45651066303253174
570: 0.5020957589149475
571: 0.5087870359420776
572: 0.52702796459198
573: 0.4981159269809723
574: 0.4804989993572235
575: 0.4981729984283447
576: 0.5051885843276978
577: 0.4806261658668518
578: 0.4811885952949524
579: 0.47619158029556274
580: 0.5124094486236572
581: 0.5056158900260925
582: 0.5288753509521484
583: 0.5059769749641418
584: 0.48474884033203125
585: 0.4867873191833496
586: 0.48798778653144836
587: 0.49485984444618225
588: 0.5080727338790894
589: 0.49075520038604736
590: 0.5125023126602173
591: 0.4817081689834595
592: 0.49767

training:  77%|███████▋  | 774/1000 [00:40<00:11, 20.27it/s]

772: 0.5107887983322144
773: 0.4627531170845032
774: 0.4894343912601471
775: 0.5156344175338745
776: 0.48579341173171997
777: 0.45985597372055054
778: 0.46159350872039795
779: 0.498787522315979
780: 0.4538487493991852
781: 0.4524104595184326
782: 0.4359729290008545
783: 0.46030330657958984
784: 0.43080416321754456
785: 0.4268718659877777
786: 0.45497995615005493
787: 0.4550473988056183
788: 0.4509519040584564
789: 0.4632589817047119
790: 0.4316036105155945
791: 0.43303778767585754
792: 0.4705539643764496
793: 0.4222944676876068
794: 0.4363822937011719
795: 0.44211044907569885
796: 0.4396730065345764
797: 0.434769868850708
798: 0.43555739521980286
799: 0.42199400067329407
800: 0.4522143602371216
input:   tensor([[5, 2, 3, 4, 3, 2, 3, 5]], device='cuda:0')
predicted output:   tensor([[2, 3, 4, 5, 2, 3, 3, 5]], device='cuda:0')
incorrects: 6
801: 0.43598005175590515
802: 0.46356406807899475
803: 0.4386112093925476
804: 0.4317091703414917
805: 0.43921640515327454
806: 0.4195844233036041
80

training:  99%|█████████▉| 994/1000 [00:50<00:00, 20.81it/s]

993: 0.398166686296463
994: 0.4304058253765106
995: 0.4032456874847412
996: 0.4214848577976227
997: 0.42439448833465576


training: 100%|██████████| 1000/1000 [00:50<00:00, 19.82it/s]

998: 0.40893203020095825
999: 0.42155563831329346





In [6]:
print(f'Accuracy : {(len(src) - incorrects)/len(src)*100}%')

Accuracy : 84.375%
