## Let's load a Transformer architecture and build a model!

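Great work making it through the Transformer lecture and the hands-on session! I first met Transformers myself in the previous cohort, and the architecture was so intricate and the concepts so difficult that it took me a long time to understand them. Still, Transformer models are used almost everywhere in today's AI technology, so even though much of the material is hard, I made a point of filling the session with the important parts. Well done XD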

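The concepts are one thing, but writing the attention code from scratch and building the encoder and decoder structure every time can also feel daunting! Fortunately, the Transformer architecture is so well known that a few lines of code (one `click`) are now enough to pull a Transformer built on recent papers from `huggingface` or `github` :D In this exercise we will take the Transformer structure we built earlier from `x-transformers`, a library that is not easy for beginners, and finish by generating an example sequence!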
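
As a reminder of what "writing the attention code" involves, here is a minimal sketch of single-head scaled dot-product attention in plain Python. It is only an illustration: the function names `softmax` and `attention` and the toy vectors are made up for this example, and it is not how `x-transformers` implements attention internally.

```python
import math

def softmax(xs):
    # Subtract the max for numerical stability before exponentiating.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def attention(queries, keys, values):
    # softmax(Q K^T / sqrt(d)) V, written with nested lists instead of tensors.
    d = len(keys[0])
    out = []
    for q in queries:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
        weights = softmax(scores)
        out.append([
            sum(w * v[j] for w, v in zip(weights, values))
            for j in range(len(values[0]))
        ])
    return out

# Toy example (hypothetical values): the query matches the first key,
# so the output leans toward the first value vector.
q = [[1.0, 0.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[10.0, 0.0], [0.0, 10.0]]
result = attention(q, k, v)
print(result)
```

A library saves you from repeating this for every head, layer, and mask, which is the point of the rest of this exercise.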

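Since the Transformer architecture appeared in 2017, many models have been built on top of it, and the base model itself has received a number of improvements. `x-transformers` is a library that lets you apply those improved variants by adding a few lines of code. It is hard to use in the sense that you need to keep up with recent papers, but in return it collects implementations whose performance backs them up.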

That turned out to be more TMI than I expected. Anyway, there is nothing you need to write yourself for this assignment; just run the code I provide as is. The semester starts soon, so good luck, everyone XD

# Run this in a Colab GPU environment!

In [1]:
!pip install x-transformers

Collecting x-transformers
  Downloading x_transformers-1.19.1-py3-none-any.whl (27 kB)
Collecting einops>=0.6.1 (from x-transformers)
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
Installing collected packages: einops, x-transformers
Successfully installed einops-0.6.1 x-transformers-1.19.1


In [2]:
import torch
from x_transformers import XTransformer

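## Load the Transformer architecture
model = XTransformer(
    ## model dimension (paper = 512)
    dim = 16,
    ## encoder vocabulary size, i.e. number of distinct token ids (reference setting = 256)
    enc_num_tokens = 16,
    ## number of stacked encoder layers (paper = 6)
    enc_depth = 6,
    ## number of heads in multi-head attention (paper = 8)
    enc_heads = 8,
    ## maximum encoder sequence length (reference setting = 1024)
    enc_max_seq_len = 32,
    ## decoder vocabulary size (reference setting = 256)
    dec_num_tokens = 16,
    ## decoder layers and attention heads, mirroring the encoder
    dec_depth = 6,
    dec_heads = 8,
    ## maximum decoder sequence length (reference setting = 1024)
    dec_max_seq_len = 32,
    ## share one token embedding between encoder and decoder
    tie_token_emb = True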
).cuda()

In [3]:
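## NUM_BATCHES = number of training steps (one batch per step, not epochs)
## BATCH_SIZE = number of samples in one batch
## LEARNING_RATE = learning rate
## GENERATE_EVERY = generate a sample every 100 steps to check accuracy
## NUM_TOKENS = number of token ids in the task (4 data tokens + 2 ids not sampled as data)
## ENC_SEQ_LEN = encoder sequence length
## DEC_SEQ_LEN = decoder sequence length (start token + source repeated twice)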

NUM_BATCHES = int(1e3)
BATCH_SIZE = 32
LEARNING_RATE = 3e-4
GENERATE_EVERY  = 100
NUM_TOKENS = 4 + 2
ENC_SEQ_LEN = 8
DEC_SEQ_LEN = 16 + 1
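
The numbers above fit together in a way that is easy to check by hand. The sketch below is plain arithmetic: it assumes the data cell samples token ids from `2` up to (but not including) `NUM_TOKENS`, and that the model vocabulary is the `16` passed as `enc_num_tokens` / `dec_num_tokens`. The variable names `MODEL_VOCAB`, `data_tokens`, and the two `uniform_*` values are introduced here only for illustration.

```python
import math

ENC_SEQ_LEN = 8
NUM_TOKENS = 4 + 2
MODEL_VOCAB = 16  # assumption: matches enc_num_tokens / dec_num_tokens

# Decoder target = 1 start token + the source sequence written twice.
dec_seq_len = 1 + 2 * ENC_SEQ_LEN

# Token ids that can appear as data when sampling from [2, NUM_TOKENS).
data_tokens = list(range(2, NUM_TOKENS))

# Cross-entropy of a uniform guess, as a rough yardstick for the loss log.
uniform_over_vocab = math.log(MODEL_VOCAB)       # about 2.77
uniform_over_data = math.log(len(data_tokens))   # about 1.39

print(dec_seq_len, data_tokens)
print(round(uniform_over_vocab, 4), round(uniform_over_data, 4))
```

Read this only as a rough guide: a loss near `ln 4` is what a model would score if it had learned which four ids occur but not yet how to copy them, so the copy behavior shows up as the loss falls well below that value.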

In [4]:
## Build random src, tgt, and src_mask tensors to feed the Transformer

def cycle():
    while True:
        prefix = torch.ones((BATCH_SIZE, 1)).long().cuda()
        src = torch.randint(2, NUM_TOKENS, (BATCH_SIZE, ENC_SEQ_LEN)).long().cuda()
        tgt = torch.cat((prefix, src, src), 1)
        src_mask = torch.ones(BATCH_SIZE, src.shape[1]).bool().cuda()
        yield (src, tgt, src_mask)
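
To make the layout of one training pair concrete, here is a CPU-only sketch of the same copy task using plain Python lists instead of CUDA tensors. The helper `make_copy_example` and its parameters are hypothetical names for this illustration; the layout (start token `1`, then the source twice) follows the `cycle()` cell.

```python
import random

def make_copy_example(enc_seq_len=8, num_tokens=6, start_token=1, seed=0):
    # Sample source ids from [2, num_tokens), like torch.randint(2, NUM_TOKENS, ...).
    rng = random.Random(seed)
    src = [rng.randrange(2, num_tokens) for _ in range(enc_seq_len)]
    # Target = start token, then the source sequence repeated twice.
    tgt = [start_token] + src + src
    return src, tgt

src_example, tgt_example = make_copy_example()
print("src:", src_example)
print("tgt:", tgt_example)
```

The first `ENC_SEQ_LEN` tokens after the start token are exactly the source, which is why the evaluation step can compare the first 8 generated tokens directly against `src`.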

In [5]:
## Train the model

import tqdm

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

## Create the batch generator once and reuse it
data = cycle()

## Train, updating the parameters from the loss at every step
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    src, tgt, src_mask = next(data)

    loss = model(src, tgt, mask=src_mask)
    loss.backward()
    print(f'{i}: {loss.item()}')

    optimizer.step()
    optimizer.zero_grad()

    ## Every GENERATE_EVERY (100) steps, generate a sample and count wrong tokens
    if i != 0 and i % GENERATE_EVERY == 0:
        model.eval()
        src, _, src_mask = next(data)
        ## Keep only the first sample of the batch
        src, src_mask = src[:1], src_mask[:1]
        ## A single start token (id 1), matching the prefix used in cycle()
        start_tokens = torch.ones((1, 1)).long().cuda()

        sample = model.generate(src, start_tokens, ENC_SEQ_LEN, mask = src_mask)
        ## Number of positions where the generated token differs from the source
        incorrects = (src != sample).sum()

        print("input:  ", src)
        print("predicted output:  ", sample)
        print(f"incorrects: {incorrects}")

training:   0%|          | 0/1000 [00:00<?, ?it/s]

0: 3.13214111328125
1: 2.5392985343933105
2: 2.2689368724823
3: 2.2033846378326416
4: 2.1097118854522705
5: 2.1064391136169434
6: 2.069767713546753
7: 1.9942455291748047
8: 2.015533208847046
9: 1.9958992004394531
10: 2.0292623043060303
11: 1.9385567903518677
12: 1.961706280708313
13: 1.9460667371749878
14: 2.006744384765625
15: 1.9883928298950195
16: 1.9738253355026245
17: 1.9432681798934937
18: 1.8907550573349
19: 1.891868233680725
20: 1.8730283975601196
21: 1.9208471775054932
22: 1.8928077220916748
23: 1.910258173942566
24: 1.926522970199585
25: 1.8844573497772217
26: 1.9193223714828491
27: 1.9008256196975708
28: 1.858151912689209
29: 1.8675936460494995
30: 1.8568484783172607
31: 1.8436172008514404
32: 1.8609970808029175
33: 1.844849944114685
34: 1.838913083076477
35: 1.890913963317871
36: 1.8622297048568726
37: 1.8590399026870728
38: 1.8060780763626099
39: 1.8068137168884277
40: 1.8317228555679321
41: 1.8054025173187256
42: 1.8332784175872803
43: 1.8173495531082153
44: 1.84365868568

training:  11%|█▏        | 113/1000 [00:10<01:18, 11.28it/s]

111: 1.5871013402938843
112: 1.603407621383667
113: 1.5756127834320068
114: 1.5496370792388916
115: 1.5182769298553467
116: 1.5724856853485107
117: 1.5349947214126587
118: 1.504668116569519
119: 1.5322401523590088
120: 1.5274640321731567
121: 1.592337965965271
122: 1.5082037448883057
123: 1.545278549194336
124: 1.5091279745101929
125: 1.5344499349594116
126: 1.5151957273483276
127: 1.4604167938232422
128: 1.5231518745422363
129: 1.3781527280807495
130: 1.5181143283843994
131: 1.482142686843872
132: 1.4372655153274536
133: 1.4873334169387817
134: 1.4770585298538208
135: 1.4477342367172241
136: 1.4662364721298218
137: 1.4856631755828857
138: 1.5066314935684204
139: 1.4701955318450928
140: 1.449734091758728
141: 1.4205604791641235
142: 1.419723629951477
143: 1.4416569471359253
144: 1.435374140739441
145: 1.4056065082550049
146: 1.4275387525558472
147: 1.4529255628585815
148: 1.4518834352493286
149: 1.4395910501480103
150: 1.455998182296753
151: 1.379351019859314
152: 1.4300751686096191
15

training:  29%|██▉       | 291/1000 [00:20<00:46, 15.11it/s]

287: 0.9840041399002075
288: 1.0173405408859253
289: 1.0102125406265259
290: 1.0111443996429443
291: 0.9915384650230408
292: 1.0015592575073242
293: 0.9713619947433472
294: 1.0469005107879639
295: 0.9798405766487122
296: 0.9593299031257629
297: 0.9644400477409363
298: 1.0022497177124023
299: 0.9807640910148621
300: 0.9910978078842163
input:   tensor([[3, 4, 4, 3, 4, 4, 4, 4]], device='cuda:0')
predicted output:   tensor([[4, 4, 4, 4, 3, 4, 4, 3]], device='cuda:0')
incorrects: 4
301: 0.9746063947677612
302: 0.9629058241844177
303: 0.973691463470459
304: 0.9622275233268738
305: 0.931398332118988
306: 0.9643674492835999
307: 0.9681692719459534
308: 0.8980139493942261
309: 0.9426261186599731
310: 0.9264912605285645
311: 0.9073613286018372
312: 0.9621738791465759
313: 0.9023735523223877
314: 1.0028903484344482
315: 0.9001245498657227
316: 0.923603892326355
317: 0.9357532262802124
318: 0.8792539834976196
319: 0.8811565041542053
320: 0.866762101650238
321: 0.8959039449691772
322: 0.9051466584

training:  50%|█████     | 502/1000 [00:30<00:27, 17.83it/s]

input:   tensor([[5, 4, 4, 2, 5, 3, 3, 2]], device='cuda:0')
predicted output:   tensor([[4, 5, 4, 2, 3, 4, 5, 2]], device='cuda:0')
incorrects: 5
501: 0.4982120990753174
502: 0.4584087133407593
503: 0.4651450514793396
504: 0.5008384585380554
505: 0.5090293884277344
506: 0.5088421106338501
507: 0.4681713581085205
508: 0.4830375909805298
509: 0.4696088433265686
510: 0.5095979571342468
511: 0.44496363401412964
512: 0.4491420090198517
513: 0.4786950349807739
514: 0.5023661255836487
515: 0.445194810628891
516: 0.4645935595035553
517: 0.45430463552474976
518: 0.4278334677219391
519: 0.4441244304180145
520: 0.4588082730770111
521: 0.43614310026168823
522: 0.4574339687824249
523: 0.4758722186088562
524: 0.4600765109062195
525: 0.4416640102863312
526: 0.48715728521347046
527: 0.45476147532463074
528: 0.4417630434036255
529: 0.48819002509117126
530: 0.43897223472595215
531: 0.47889310121536255
532: 0.42156222462654114
533: 0.4330015182495117
534: 0.43217208981513977
535: 0.4603905975818634
536:

training:  72%|███████▏  | 715/1000 [00:40<00:14, 19.17it/s]

711: 0.23071816563606262
712: 0.21359285712242126
713: 0.2161746472120285
714: 0.21157066524028778
715: 0.19302190840244293
716: 0.22821594774723053
717: 0.20553237199783325
718: 0.18504951894283295
719: 0.179910808801651
720: 0.19794563949108124
721: 0.20074890553951263
722: 0.1628144085407257
723: 0.15633822977542877
724: 0.2094571888446808
725: 0.21119621396064758
726: 0.1991499811410904
727: 0.18674711883068085
728: 0.16542227566242218
729: 0.23253986239433289
730: 0.20208314061164856
731: 0.1719496250152588
732: 0.16718333959579468
733: 0.22208784520626068
734: 0.18243804574012756
735: 0.24987971782684326
736: 0.19983625411987305
737: 0.18424201011657715
738: 0.21043069660663605
739: 0.15675824880599976
740: 0.1647699773311615
741: 0.15237559378147125
742: 0.20153814554214478
743: 0.1562492698431015
744: 0.16950814425945282
745: 0.14594019949436188
746: 0.16053007543087006
747: 0.14517414569854736
748: 0.18328861892223358
749: 0.1643412709236145
750: 0.16599471867084503
751: 0.175

training:  94%|█████████▎| 935/1000 [00:50<00:03, 20.16it/s]

934: 0.045613840222358704
935: 0.06780410557985306
936: 0.05420812591910362
937: 0.05668637901544571
938: 0.05373804271221161
939: 0.06281745433807373
940: 0.06393373012542725
941: 0.054537542164325714
942: 0.05605629086494446
943: 0.07050280272960663
944: 0.058252185583114624
945: 0.05017292499542236
946: 0.06158602610230446
947: 0.052920304238796234
948: 0.06392765790224075
949: 0.08313881605863571
950: 0.04989655688405037
951: 0.0528954342007637
952: 0.053321123123168945
953: 0.047366756945848465
954: 0.053531669080257416
955: 0.05088716372847557
956: 0.06316587328910828
957: 0.03820590302348137
958: 0.04875728860497475
959: 0.050563398748636246
960: 0.054713938385248184
961: 0.05959688872098923
962: 0.055000875145196915
963: 0.05419083684682846
964: 0.044831741601228714
965: 0.04116735979914665
966: 0.04326649755239487
967: 0.044504545629024506
968: 0.05944674089550972
969: 0.05091234669089317
970: 0.06227782741189003
971: 0.04383992403745651
972: 0.05166330188512802
973: 0.0490566

training: 100%|██████████| 1000/1000 [00:53<00:00, 18.86it/s]

997: 0.03809478506445885
998: 0.040214717388153076
999: 0.043390389531850815





In [6]:
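## src has shape (1, ENC_SEQ_LEN), so count token positions with numel() rather than len()
print(f'Accuracy : {(src.numel() - incorrects.item()) / src.numel() * 100}%')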

Accuracy : 100.0%
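
Token-level accuracy can also be checked by hand against the samples printed in the training log. The sketch below recomputes the mismatch count and accuracy for the two logged samples (steps 300 and 500) with plain lists; `token_accuracy` is a hypothetical helper written for this illustration, not part of `x-transformers`.

```python
def token_accuracy(src, pred):
    # Count positions where the prediction differs from the source,
    # then convert the remainder to a percentage of all positions.
    if len(src) != len(pred):
        raise ValueError("src and pred must have the same length")
    incorrect = sum(s != p for s, p in zip(src, pred))
    return incorrect, (len(src) - incorrect) / len(src) * 100

# Samples copied from the training log above.
step_300 = token_accuracy([3, 4, 4, 3, 4, 4, 4, 4], [4, 4, 4, 4, 3, 4, 4, 3])
step_500 = token_accuracy([5, 4, 4, 2, 5, 3, 3, 2], [4, 5, 4, 2, 3, 4, 5, 2])
print(step_300)  # (4, 50.0)
print(step_500)  # (5, 37.5)
```

The mismatch counts agree with the `incorrects` values in the log. Each check uses a single random sample, so the figure is noisy; averaging over several samples would give a steadier estimate.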
