## Transformers architecture를 받아와 구축해보자!

Transformers 강의 + 실습까지 수강하시느라 너무 수고 많으셨습니다! 저도 저번 기수 당시 transformers를 처음 접했는데, 모델의 구조가 난해하고 쓰인 개념들도 어려워서 이해하는데 많은 시간이 걸렸었네요.. 그래도 transformers 모델이 현재 contemporary AI technology에 쓰이지 않는 곳이 없다보니, 어려운 내용들이 다수 있지만 힘주어서 중요한 내용들로 채워넣으려고 노력했습니다 수고하셨습니다 XD

Transformers은 개념도 개념이지만, 매번 새롭게 attention 코드를 짜고, encoder와 decoder 구조를 구축하는 것도 막막하실 겁니다! 다행히도 transformers architecture는 워낙 유명해서 이제 코드 몇줄만 `딸깍`해도 최신 논문기반 transformer 구조를 `huggingface` 혹은 `github`에서 받아서 사용할 수 있습니다 :D 이번 실습에는 초심자가 사용하기는 어렵지만 `x-transformers`에서 저희가 구축한 transformers 구조를 받아와, 예시 문장을 출력하는 것까지 마무리할 예정입니다!

transformers architecture가 2017년 나온 이후로, 이 architecture를 기반으로 한 여러 모델들이 나왔고, 또한 base model에 대해서도 여러 개선점들이 추가되었습니다. `x-transformers`은 이 개선된 model들을 코드 몇줄을 추가해서 적용 가능하게 하는 라이브러리로, 최신 논문 동향을 파악하고 있어야 한다는 점에서 어렵지만 그만큼 성능이 뒷받침해주는 코드들을 모아둔 라이브러리입니다.

TMI가 생각보다 길어졌는데, 아무튼 이번 과제는 따로 작성해 넣어야할 부분은 없고, 제가 드리는 코드를 그대로 실행하기만 하면 되는 과정으로 추가했습니다. 이제 개강이 얼마 남지 않았는데, 다들 화이팅입니다 XD

# Colab GPU 환경에서 구동하세요!

In [1]:
!pip install x-transformers

Collecting x-transformers
  Downloading x_transformers-1.19.1-py3-none-any.whl (27 kB)
Collecting einops>=0.6.1 (from x-transformers)
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/42.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: einops, x-transformers
Successfully installed einops-0.6.1 x-transformers-1.19.1


In [2]:
import torch
from x_transformers import XTransformer

## Transformers Architectures 받아오
model = XTransformer(
    ## 모델의 차원 (논문 = 512)
    dim = 16,
    ## encoder token의 개수 (논문 = 256)
    enc_num_tokens = 16,
    ## encoder 반복 횟수 (논문 = 6)
    enc_depth = 6,
    ## multihead attention n_heads 개수 (논문 = 8)
    enc_heads = 8,
    ## encoder token의 max sequence length (논문 = 1024)
    enc_max_seq_len = 32,
    ## (논문 = 256)
    dec_num_tokens = 16,
    dec_depth = 6,
    dec_heads = 8,
    ## (논문 = 1024)
    dec_max_seq_len = 32,
    tie_token_emb = True
).cuda()

In [3]:
## NUM_BATCHES = Epoches의 개수
## BATCH_SIZE = 하나의 batch에 들어갈 sample의 개수
## LEARNING_RATE = learning rate
## GENERATE_EVERY = 100번마다 한번씩 generate해서 accuracy확인
## NUM_TOKENS = 데이터 내 유니크한 토큰의 수
## ENC_SEQ_LEN = encoder sequence length

NUM_BATCHES = int(1e3)
BATCH_SIZE = 32
LEARNING_RATE = 3e-4
GENERATE_EVERY  = 100
NUM_TOKENS = 4 + 2
ENC_SEQ_LEN = 8
DEC_SEQ_LEN = 16 + 1

In [4]:
## Transformer 모델에 넣을 임의의 src, tgt, src_mask 생성

def cycle():
    while True:
        prefix = torch.ones((BATCH_SIZE, 1)).long().cuda()
        src = torch.randint(2, NUM_TOKENS, (BATCH_SIZE, ENC_SEQ_LEN)).long().cuda()
        tgt = torch.cat((prefix, src, src), 1)
        src_mask = torch.ones(BATCH_SIZE, src.shape[1]).bool().cuda()
        yield (src, tgt, src_mask)

In [5]:
## Train Model

import tqdm
import torch.optim as optim

optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

## loss update 해가면서 학습
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    src, tgt, src_mask = next(cycle())

    loss = model(src, tgt, mask=src_mask)
    loss.backward()
    print(f'{i}: {loss.item()}')

    optim.step()
    optim.zero_grad()

    ## 매 N(100)번마다 accuracy 측정
    if i != 0 and i % GENERATE_EVERY == 0:
        model.eval()
        src, _, src_mask = next(cycle())
        src, src_mask = src[:1], src_mask[:1]
        start_tokens = (torch.ones((1, 1)) * 1).long().cuda()

        sample = model.generate(src, start_tokens, ENC_SEQ_LEN, mask = src_mask)
        incorrects = (src != sample).abs().sum()

        print(f"input:  ", src)
        print(f"predicted output:  ", sample)
        print(f"incorrects: {incorrects}")

training:   0%|          | 0/1000 [00:00<?, ?it/s]

0: 2.7159199714660645
1: 2.340632200241089
2: 2.211869716644287
3: 2.1506853103637695
4: 2.077253580093384
5: 2.1267380714416504
6: 2.054264783859253
7: 2.074209690093994
8: 2.081350564956665
9: 2.0057532787323
10: 2.018472909927368
11: 2.0012898445129395
12: 1.9789748191833496
13: 1.9547322988510132
14: 1.9762852191925049
15: 1.9849650859832764
16: 1.9714045524597168
17: 1.9267220497131348
18: 1.951064944267273
19: 1.982918381690979
20: 1.9797437191009521
21: 1.952080249786377
22: 1.9145829677581787
23: 1.9215447902679443
24: 1.8749977350234985
25: 1.8898621797561646
26: 1.9347507953643799
27: 1.9007478952407837
28: 1.8965784311294556
29: 1.8471498489379883
30: 1.8468708992004395
31: 1.9052762985229492
32: 1.816293478012085
33: 1.882497787475586
34: 1.8536301851272583
35: 1.8320268392562866
36: 1.8486570119857788
37: 1.8808456659317017
38: 1.8483726978302002
39: 1.857274055480957
40: 1.8012079000473022
41: 1.8206894397735596
42: 1.8089230060577393
43: 1.8185392618179321
44: 1.82067453

training:  14%|█▍        | 144/1000 [00:10<00:59, 14.37it/s]

141: 1.429375171661377
142: 1.4708038568496704
143: 1.4130728244781494
144: 1.426938772201538
145: 1.4158269166946411
146: 1.4316428899765015
147: 1.4025956392288208
148: 1.4313727617263794
149: 1.3923890590667725
150: 1.4011284112930298
151: 1.3822747468948364
152: 1.3944215774536133
153: 1.4168020486831665
154: 1.3475834131240845
155: 1.3935788869857788
156: 1.3425382375717163
157: 1.3844558000564575
158: 1.3378889560699463
159: 1.3482110500335693
160: 1.4005327224731445
161: 1.4245086908340454
162: 1.351359486579895
163: 1.485213279724121
164: 1.3365448713302612
165: 1.3408681154251099
166: 1.3633060455322266
167: 1.3562052249908447
168: 1.3528529405593872
169: 1.3829494714736938
170: 1.3404639959335327
171: 1.3405674695968628
172: 1.3186167478561401
173: 1.318481206893921
174: 1.295263648033142
175: 1.3060723543167114
176: 1.3011683225631714
177: 1.3565260171890259
178: 1.3387513160705566
179: 1.3186016082763672
180: 1.371983528137207
181: 1.2977324724197388
182: 1.2495943307876587

training:  34%|███▍      | 339/1000 [00:20<00:38, 17.38it/s]

335: 0.8408811688423157
336: 0.821189820766449
337: 0.816382110118866
338: 0.7702741026878357
339: 0.7913607358932495
340: 0.8461558222770691
341: 0.8474199771881104
342: 0.8759096264839172
343: 0.8180527091026306
344: 0.8636281490325928
345: 0.8126228451728821
346: 0.8495261073112488
347: 0.7622610330581665
348: 0.7712553143501282
349: 0.7534445524215698
350: 0.764530599117279
351: 0.7501214146614075
352: 0.7529242634773254
353: 0.7702922821044922
354: 0.7553625106811523
355: 0.7291116714477539
356: 0.730505108833313
357: 0.7647315263748169
358: 0.8019096255302429
359: 0.7678351402282715
360: 0.7439815402030945
361: 0.7344028353691101
362: 0.7114460468292236
363: 0.7380483746528625
364: 0.7473811507225037
365: 0.7175288200378418
366: 0.7285388112068176
367: 0.7153905630111694
368: 0.6770309209823608
369: 0.7094429731369019
370: 0.6856462955474854
371: 0.7331066131591797
372: 0.7316758632659912
373: 0.7104676961898804
374: 0.7346172332763672
375: 0.686581552028656
376: 0.75884073972702

training:  54%|█████▎    | 537/1000 [00:30<00:25, 18.47it/s]

533: 0.5293940305709839
534: 0.5078270435333252
535: 0.5428709387779236
536: 0.5293728709220886
537: 0.5161980390548706
538: 0.5341401696205139
539: 0.4827260971069336
540: 0.548072338104248
541: 0.5018462538719177
542: 0.5319571495056152
543: 0.5855957865715027
544: 0.5012254118919373
545: 0.6155388355255127
546: 0.5164189338684082
547: 0.6397287845611572
548: 0.5647966265678406
549: 0.5122579336166382
550: 0.5855109691619873
551: 0.5715779066085815
552: 0.5547770857810974
553: 0.5396165251731873
554: 0.5784672498703003
555: 0.5002285838127136
556: 0.5432012677192688
557: 0.5184900164604187
558: 0.5061147809028625
559: 0.5234165787696838
560: 0.5382082462310791
561: 0.5063276290893555
562: 0.511388897895813
563: 0.5108919143676758
564: 0.5215871334075928
565: 0.5247786641120911
566: 0.4781842529773712
567: 0.508637011051178
568: 0.5266063213348389
569: 0.5086038112640381
570: 0.49855995178222656
571: 0.5161296129226685
572: 0.5332735776901245
573: 0.4915812313556671
574: 0.49148863554

training:  74%|███████▎  | 735/1000 [00:40<00:13, 18.93it/s]

732: 0.4416087865829468
733: 0.4338935613632202
734: 0.4253292381763458
735: 0.4640275835990906
736: 0.4437926113605499
737: 0.41335397958755493
738: 0.3995330333709717
739: 0.40192991495132446
740: 0.41009485721588135
741: 0.4052377939224243
742: 0.44076988101005554
743: 0.3896525800228119
744: 0.42218634486198425
745: 0.4161558151245117
746: 0.41818809509277344
747: 0.42023590207099915
748: 0.44871002435684204
749: 0.40684977173805237
750: 0.4349977374076843
751: 0.40076032280921936
752: 0.4215726852416992
753: 0.44108110666275024
754: 0.38700243830680847
755: 0.43387800455093384
756: 0.4028025269508362
757: 0.38800883293151855
758: 0.4180164635181427
759: 0.44364455342292786
760: 0.4305459260940552
761: 0.4052961766719818
762: 0.45078518986701965
763: 0.391659140586853
764: 0.38873982429504395
765: 0.4360658526420593
766: 0.415536493062973
767: 0.3929416239261627
768: 0.3980906307697296
769: 0.39193195104599
770: 0.3869918882846832
771: 0.41097426414489746
772: 0.3830980062484741
77

training:  94%|█████████▍| 942/1000 [00:50<00:02, 19.57it/s]

940: 0.2780251204967499
941: 0.3489471673965454
942: 0.3219373822212219
943: 0.32006511092185974
944: 0.287994384765625
945: 0.3195747435092926
946: 0.24891731142997742
947: 0.2845100164413452
948: 0.30493804812431335
949: 0.34677281975746155
950: 0.3081819415092468
951: 0.26820099353790283
952: 0.25609081983566284
953: 0.26454564929008484
954: 0.24327512085437775
955: 0.24684742093086243
956: 0.29283076524734497
957: 0.26320382952690125
958: 0.2556506097316742
959: 0.24689587950706482
960: 0.3048698902130127
961: 0.22283144295215607
962: 0.2604905366897583
963: 0.2393389344215393
964: 0.24247199296951294
965: 0.22451430559158325
966: 0.23703710734844208
967: 0.22739265859127045
968: 0.22064292430877686
969: 0.21850770711898804
970: 0.20160435140132904
971: 0.21377970278263092
972: 0.22758297622203827
973: 0.16810797154903412
974: 0.2182527333498001
975: 0.2112853229045868
976: 0.21269434690475464
977: 0.23143872618675232
978: 0.21455417573451996
979: 0.24093268811702728
980: 0.2032862

training: 100%|██████████| 1000/1000 [00:53<00:00, 18.72it/s]

995: 0.1511784940958023
996: 0.1793159395456314
997: 0.21370358765125275
998: 0.18582852184772491
999: 0.1681508868932724





In [6]:
print(f'Accuracy : {(len(src) - incorrects)/len(src)*100}%')

Accuracy : 81.25%
