## Let's import and build a Transformer architecture!

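Great job making it through the Transformers lecture and the hands-on exercises! I first ran into Transformers myself in the previous cohort, and the model structure was so intricate and the concepts so difficult that it took me a long time to understand.. Still, since there is hardly any area of contemporary AI technology where Transformer models are not used, I did my best to pack in the important material even though much of it is hard. Well done XD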

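With Transformers, the concepts are one thing, but writing attention code from scratch and building the encoder and decoder structure every time can feel daunting too! Fortunately, the Transformer architecture is so well known that a few lines of code (basically one `click`) are enough to pull a Transformer implementation based on recent papers from `huggingface` or `github` :D In this exercise we will use `x-transformers`, which is not the easiest library for beginners, to bring in the Transformer structure we built ourselves, and we will finish by generating an example sequence!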

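Since the Transformer architecture came out in 2017, many models based on it have appeared, and many improvements to the base model have been proposed as well. `x-transformers` is a library that lets you apply these improved variants by adding a few lines of code. It is hard to use in the sense that you need to keep up with recent papers, but in return it collects implementations whose performance justifies the effort.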

That turned into more TMI than I expected, but anyway: there is nothing you need to write yourself for this assignment. You only need to run the code I provide as-is. The semester starts soon, so good luck, everyone XD

# Run this in a Colab GPU environment!

In [1]:
!pip install x-transformers

Collecting x-transformers
  Downloading x_transformers-1.19.1-py3-none-any.whl (27 kB)
Collecting einops>=0.6.1 (from x-transformers)
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
Installing collected packages: einops, x-transformers
Successfully installed einops-0.6.1 x-transformers-1.19.1


In [2]:
import torch
from x_transformers import XTransformer

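## Build the Transformer architecture
model = XTransformer(
    ## model dimension (paper: 512)
    dim = 16,
    ## encoder vocabulary size, i.e. number of distinct token ids (x-transformers README example: 256)
    enc_num_tokens = 16,
    ## number of stacked encoder layers (paper: 6)
    enc_depth = 6,
    ## number of heads in the encoder's multi-head attention (paper: 8)
    enc_heads = 8,
    ## maximum encoder sequence length (x-transformers README example: 1024)
    enc_max_seq_len = 32,
    ## decoder vocabulary size (x-transformers README example: 256)
    dec_num_tokens = 16,
    ## number of stacked decoder layers (paper: 6)
    dec_depth = 6,
    ## number of heads in the decoder's multi-head attention (paper: 8)
    dec_heads = 8,
    ## maximum decoder sequence length (x-transformers README example: 1024)
    dec_max_seq_len = 32,
    ## share the token embedding between encoder and decoder
    tie_token_emb = True
).cuda()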

In [3]:
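## NUM_BATCHES = number of training steps (one batch per step), not epochs
## BATCH_SIZE = number of samples in one batch
## LEARNING_RATE = learning rate
## GENERATE_EVERY = generate a sample every 100 steps to check accuracy
## NUM_TOKENS = number of token ids used by the data (2..5 are data tokens, 1 is the start token, 0 is unused)
## ENC_SEQ_LEN = encoder sequence length
## DEC_SEQ_LEN = decoder sequence length (start token + the source repeated twice)

NUM_BATCHES = int(1e3)
BATCH_SIZE = 32
LEARNING_RATE = 3e-4
GENERATE_EVERY = 100
NUM_TOKENS = 4 + 2
ENC_SEQ_LEN = 8
DEC_SEQ_LEN = 16 + 1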
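
As a quick sanity check on these constants, the sketch below works out how they fit together. The decoder target built in the next cell is one start token followed by two copies of the source, which is where `16 + 1` comes from. `MODEL_VOCAB` and `MODEL_MAX_SEQ_LEN` are just labels I introduce here for the `16` and `32` passed to `XTransformer` above; they are not names from the library.

```python
# Constants copied from the cell above
NUM_TOKENS = 4 + 2
ENC_SEQ_LEN = 8
DEC_SEQ_LEN = 16 + 1

# Hypothetical labels for the values passed to XTransformer (not library names)
MODEL_VOCAB = 16        # enc_num_tokens / dec_num_tokens
MODEL_MAX_SEQ_LEN = 32  # enc_max_seq_len / dec_max_seq_len

# Target = 1 start token + the source repeated twice
target_len = 1 + 2 * ENC_SEQ_LEN

# Data token ids are 2 .. NUM_TOKENS - 1; id 1 is the start token
data_token_ids = list(range(2, NUM_TOKENS))

print("target length matches DEC_SEQ_LEN:", target_len == DEC_SEQ_LEN)
print("data token ids:", data_token_ids)
print("ids fit in vocab:", max(data_token_ids) < MODEL_VOCAB)
print("target fits in max length:", target_len <= MODEL_MAX_SEQ_LEN)
```

All four checks should come out true for the values used in this notebook, so the tiny model configuration is large enough for the toy data.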

In [4]:
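## Generate random src, tgt, and src_mask batches to feed the Transformer

def cycle():
    while True:
        # start token (id 1) prepended to every target sequence
        prefix = torch.ones((BATCH_SIZE, 1)).long().cuda()
        # random source tokens drawn from ids 2 .. NUM_TOKENS - 1
        src = torch.randint(2, NUM_TOKENS, (BATCH_SIZE, ENC_SEQ_LEN)).long().cuda()
        # target = start token followed by the source repeated twice
        tgt = torch.cat((prefix, src, src), 1)
        # all-True mask: every source position is a real token (no padding)
        src_mask = torch.ones(BATCH_SIZE, src.shape[1]).bool().cuda()
        yield (src, tgt, src_mask)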
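
To make the data format concrete, here is a pure-Python sketch of a single sample (no batch dimension, no GPU). It also shows the usual one-step shift used for autoregressive training, where the decoder sees `tgt[:-1]` and is asked to predict `tgt[1:]`. I am assuming `x-transformers` applies this shift internally, so the helper `make_sample` below is only illustrative and not part of the library.

```python
import random

NUM_TOKENS = 4 + 2
ENC_SEQ_LEN = 8
START_TOKEN = 1  # plays the same role as `prefix` in cycle()

def make_sample(rng):
    """Hypothetical helper: build one (src, tgt) pair for the copy task."""
    # random.Random.randint includes both endpoints, unlike torch.randint
    src = [rng.randint(2, NUM_TOKENS - 1) for _ in range(ENC_SEQ_LEN)]
    tgt = [START_TOKEN] + src + src
    return src, tgt

rng = random.Random(0)
src, tgt = make_sample(rng)

# Teacher forcing: the input is everything but the last token,
# the labels are everything but the first token
decoder_input, labels = tgt[:-1], tgt[1:]

print("src:", src)
print("tgt:", tgt)
print("labels are the source repeated twice:", labels == src + src)
```

Under this assumption, the start token is the only thing the decoder receives before it has to begin reproducing the source.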

In [5]:
## Train Model

import tqdm

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

## create the batch generator once and reuse it
data = cycle()

## training loop: compute the loss, backpropagate, update the weights
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    src, tgt, src_mask = next(data)

    loss = model(src, tgt, mask=src_mask)
    loss.backward()
    print(f'{i}: {loss.item()}')

    optimizer.step()
    optimizer.zero_grad()

    ## every GENERATE_EVERY (100) steps, generate one sample and count the wrong tokens
    if i != 0 and i % GENERATE_EVERY == 0:
        model.eval()
        src, _, src_mask = next(data)
        src, src_mask = src[:1], src_mask[:1]
        start_tokens = torch.ones((1, 1)).long().cuda()

        sample = model.generate(src, start_tokens, ENC_SEQ_LEN, mask = src_mask)
        incorrects = (src != sample).sum()

        print("input:  ", src)
        print("predicted output:  ", sample)
        print(f"incorrects: {incorrects}")

training:   0%|          | 0/1000 [00:00<?, ?it/s]

0: 2.727139949798584
1: 2.256948709487915
2: 2.101267099380493
3: 2.083951234817505
4: 2.06235933303833
5: 2.0390946865081787
6: 2.0413966178894043
7: 1.996823787689209
8: 1.9823178052902222
9: 2.0146498680114746
10: 1.9566376209259033
11: 1.957607388496399
12: 1.984459638595581
13: 1.9396376609802246
14: 1.9336129426956177
15: 1.9199473857879639
16: 1.9103516340255737
17: 1.8849718570709229
18: 1.857194185256958
19: 1.9075984954833984
20: 1.931187391281128
21: 1.8467379808425903
22: 1.8843120336532593
23: 1.862120509147644
24: 1.8483104705810547
25: 1.8629496097564697
26: 1.8523776531219482
27: 1.8244199752807617
28: 1.8437343835830688
29: 1.8581794500350952
30: 1.8063095808029175
31: 1.816646695137024
32: 1.8436893224716187
33: 1.780761957168579
34: 1.8331599235534668
35: 1.7513445615768433
36: 1.7589415311813354
37: 1.776625394821167
38: 1.7697721719741821
39: 1.7656141519546509
40: 1.7993988990783691
41: 1.6931922435760498
42: 1.7161262035369873
43: 1.7516365051269531
44: 1.7560690

training:  14%|█▍        | 144/1000 [00:10<00:59, 14.37it/s]

140: 1.43464994430542
141: 1.4628949165344238
142: 1.4320179224014282
143: 1.4127565622329712
144: 1.430187463760376
145: 1.4286540746688843
146: 1.4170945882797241
147: 1.39911687374115
148: 1.4257041215896606
149: 1.3905481100082397
150: 1.4044239521026611
151: 1.3863444328308105
152: 1.3849811553955078
153: 1.418110966682434
154: 1.4091557264328003
155: 1.441887378692627
156: 1.3906947374343872
157: 1.4362726211547852
158: 1.370867133140564
159: 1.3843990564346313
160: 1.361920952796936
161: 1.3906621932983398
162: 1.3923856019973755
163: 1.3891807794570923
164: 1.3539760112762451
165: 1.337084174156189
166: 1.3391385078430176
167: 1.3581598997116089
168: 1.3848265409469604
169: 1.3302966356277466
170: 1.3587265014648438
171: 1.3496084213256836
172: 1.3205713033676147
173: 1.2768058776855469
174: 1.3003612756729126
175: 1.3401587009429932
176: 1.3159152269363403
177: 1.3198906183242798
178: 1.3073817491531372
179: 1.3209714889526367
180: 1.3838685750961304
181: 1.4120445251464844
18

training:  37%|███▋      | 369/1000 [00:20<00:32, 19.15it/s]

365: 0.8893011212348938
366: 0.8483202457427979
367: 0.8959168195724487
368: 0.891314685344696
369: 0.8538379669189453
370: 0.9257286190986633
371: 0.812914252281189
372: 0.8709568977355957
373: 0.8290241956710815
374: 0.8217141628265381
375: 0.8195197582244873
376: 0.859210729598999
377: 0.8245171904563904
378: 0.8434295058250427
379: 1.017173171043396
380: 0.804629921913147
381: 0.8959485292434692
382: 0.8631798028945923
383: 0.8479448556900024
384: 0.9668976664543152
385: 0.8670132160186768
386: 0.8206992149353027
387: 0.8500629663467407
388: 0.8203577995300293
389: 0.8099605441093445
390: 0.903663158416748
391: 0.8553762435913086
392: 0.903647243976593
393: 0.8206267952919006
394: 0.8315045237541199
395: 0.8247908353805542
396: 0.8399472832679749
397: 0.8258156180381775
398: 0.7626250386238098
399: 0.7667422294616699
400: 0.7754154801368713
input:   tensor([[2, 5, 2, 5, 5, 4, 4, 5]], device='cuda:0')
predicted output:   tensor([[5, 5, 2, 4, 5, 4, 2, 5]], device='cuda:0')
incorrects

training:  59%|█████▉    | 594/1000 [00:30<00:19, 20.50it/s]

591: 0.5493804216384888
592: 0.517044186592102
593: 0.554677426815033
594: 0.5148018598556519
595: 0.529457688331604
596: 0.5528676509857178
597: 0.5247458219528198
598: 0.5132182240486145
599: 0.5382312536239624
600: 0.5319815278053284
input:   tensor([[5, 3, 5, 2, 3, 2, 4, 4]], device='cuda:0')
predicted output:   tensor([[3, 5, 2, 4, 3, 5, 2, 4]], device='cuda:0')
incorrects: 6
601: 0.5138307213783264
602: 0.6018282175064087
603: 0.5056847929954529
604: 0.5540526509284973
605: 0.5414960384368896
606: 0.5732665061950684
607: 0.5130674242973328
608: 0.6075820922851562
609: 0.5329014658927917
610: 0.6369239091873169
611: 0.5592753291130066
612: 0.5520699620246887
613: 0.6331421136856079
614: 0.5753402709960938
615: 0.5838153958320618
616: 0.6405737400054932
617: 0.5442283749580383
618: 0.5899227261543274
619: 0.5522559285163879
620: 0.7064933180809021
621: 0.5815100073814392
622: 0.617612361907959
623: 0.5768929123878479
624: 0.6110059022903442
625: 0.5863382816314697
626: 0.5145580768

training:  82%|████████▏ | 815/1000 [00:40<00:08, 20.60it/s]

810: 0.4580574929714203
811: 0.458659827709198
812: 0.4141996502876282
813: 0.43730872869491577
814: 0.4399748146533966
815: 0.44886574149131775
816: 0.43412986397743225
817: 0.42146769165992737
818: 0.42971786856651306
819: 0.4362533390522003
820: 0.4237402677536011
821: 0.45426926016807556
822: 0.4386739134788513
823: 0.4317070245742798
824: 0.42801493406295776
825: 0.42091861367225647
826: 0.44224226474761963
827: 0.42656320333480835
828: 0.44456785917282104
829: 0.4375540018081665
830: 0.43150416016578674
831: 0.4452114403247833
832: 0.43491289019584656
833: 0.4311659336090088
834: 0.4449954032897949
835: 0.41446709632873535
836: 0.4260168671607971
837: 0.4070156514644623
838: 0.42055824398994446
839: 0.44950681924819946
840: 0.4371493458747864
841: 0.4398782551288605
842: 0.45634356141090393
843: 0.41753992438316345
844: 0.43883994221687317
845: 0.4442380368709564
846: 0.4560803174972534
847: 0.4322143495082855
848: 0.4433416426181793
849: 0.4294085204601288
850: 0.412084758281707

training: 100%|██████████| 1000/1000 [00:49<00:00, 20.19it/s]

998: 0.4015270173549652
999: 0.44203439354896545
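
One way to read the loss values in this log, assuming the model returns the mean cross-entropy per token in nats (the default behaviour of PyTorch's cross-entropy, and what I believe the library uses internally): a model that guesses uniformly over all 16 decoder tokens would score ln 16, and one that only knows which 4 data tokens occur would score ln 4. The sketch below computes both reference points.

```python
import math

DEC_NUM_TOKENS = 16   # dec_num_tokens passed to XTransformer
NUM_DATA_TOKENS = 4   # token ids 2..5 produced by cycle()

# Cross-entropy of a uniform guess over k classes is ln(k)
uniform_over_vocab = math.log(DEC_NUM_TOKENS)
uniform_over_data = math.log(NUM_DATA_TOKENS)

print(f"uniform over 16 tokens: {uniform_over_vocab:.3f}")
print(f"uniform over 4 tokens:  {uniform_over_data:.3f}")
```

The first logged loss (about 2.73) sits close to the first value, and the loss falls below the second one at around step 150, which suggests that from there on the model is doing better than guessing among the four data tokens.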





In [6]:
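## `incorrects` counts the wrong tokens in one generated sequence of ENC_SEQ_LEN tokens,
## so divide by ENC_SEQ_LEN (len(src) would be the batch size here)
print(f'Accuracy : {(ENC_SEQ_LEN - incorrects)/ENC_SEQ_LEN*100}%')

Accuracy : 25.0%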
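
Token-level accuracy for a single generated sequence can be checked by hand. The sketch below takes the `input` and `predicted output` tensors logged at step 600, rewritten as plain Python lists, and divides the number of matching positions by the sequence length. `token_accuracy` is a helper name I made up for illustration.

```python
def token_accuracy(expected, predicted):
    """Hypothetical helper: count mismatches and return (incorrect, accuracy in %)."""
    assert len(expected) == len(predicted)
    incorrect = sum(e != p for e, p in zip(expected, predicted))
    return incorrect, (len(expected) - incorrect) / len(expected) * 100

# Values copied from the step-600 log lines
src_600 = [5, 3, 5, 2, 3, 2, 4, 4]
pred_600 = [3, 5, 2, 4, 3, 5, 2, 4]

incorrect, acc = token_accuracy(src_600, pred_600)
print(f"incorrects: {incorrect}, accuracy: {acc}%")
```

This reproduces the `incorrects: 6` line from the log. The denominator is the 8 tokens of the sequence, not the batch size.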
