## Let's load a Transformer architecture and build a model with it!

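Great job getting through the Transformers lectures and the hands-on practice! I first encountered transformers as a member of the previous cohort, and because the model structure was convoluted and the concepts were hard, it took me a long time to understand them. Still, transformer models are used almost everywhere in contemporary AI, so even though much of the material is difficult, I tried hard to fill it with the content that matters most. Well done, everyone XD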

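On top of the concepts, writing attention code from scratch every time and building the encoder and decoder by hand can feel overwhelming! Fortunately, the transformer architecture is so well known that with just a few lines of code (a quick `click`) you can pull a transformer structure based on recent papers from `huggingface` or `github` and use it right away :D In this exercise we will take the transformer structure we built in class from `x-transformers`, a library that is not easy for beginners, and carry it all the way to printing example outputs!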
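For reference, the core operation that otherwise gets rewritten by hand each time is scaled dot-product attention, softmax(Q·Kᵀ / √d_k)·V. The sketch below is a minimal single-head, pure-Python illustration with no batching, masking, or learned projections; `softmax` and `scaled_dot_product_attention` are names chosen here for illustration, not part of any library.

```python
import math

def softmax(scores):
    # subtract the max before exponentiating for numerical stability
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def scaled_dot_product_attention(Q, K, V):
    """Single-head attention. Q: n_q x d_k, K: n_k x d_k, V: n_k x d_v (lists of lists)."""
    d_k = len(K[0])
    outputs = []
    for q in Q:
        # similarity of this query with every key, scaled by sqrt(d_k)
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in K]
        weights = softmax(scores)
        # output = attention-weighted sum of the value vectors
        outputs.append([sum(w * v[j] for w, v in zip(weights, V)) for j in range(len(V[0]))])
    return outputs

Q = [[1.0, 0.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[10.0, 0.0], [0.0, 10.0]]
print(scaled_dot_product_attention(Q, K, V))  # the first value vector gets the larger weight
```

Multi-head attention runs several of these in parallel on learned projections of the input and concatenates the results; libraries such as x-transformers take care of that, along with masking and batching.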

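Since the transformer architecture was introduced in 2017, many models based on it have appeared, and many improvements to the base model have been added as well. `x-transformers` is a library that lets you apply these improved variants by adding just a few lines of code. It is challenging in that you need to keep up with recent papers to use it well, but in return it collects techniques whose performance gains are backed up by results.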

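This TMI got longer than I expected. Anyway, there is nothing you need to write yourself in this assignment; it is set up so that you only need to run the code I provide as is. The semester starts soon, so good luck, everyone XD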

# Run this in a Colab GPU environment!

In [1]:
!pip install x-transformers

Collecting x-transformers
  Downloading x_transformers-1.19.1-py3-none-any.whl (27 kB)
Collecting einops>=0.6.1 (from x-transformers)
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 42.2/42.2 kB 2.9 MB/s eta 0:00:00
Installing collected packages: einops, x-transformers
Successfully installed einops-0.6.1 x-transformers-1.19.1


In [2]:
import torch
from x_transformers import XTransformer

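## Load the Transformer architecture
model = XTransformer(
    ## model (embedding) dimension (paper = 512)
    dim = 16,
    ## encoder vocabulary size, i.e. number of distinct token ids (x-transformers README example = 256)
    enc_num_tokens = 16,
    ## number of encoder layers (paper = 6)
    enc_depth = 6,
    ## number of heads in encoder multi-head attention (paper = 8)
    enc_heads = 8,
    ## maximum encoder sequence length (x-transformers README example = 1024)
    enc_max_seq_len = 32,
    ## decoder vocabulary size (x-transformers README example = 256)
    dec_num_tokens = 16,
    ## number of decoder layers (paper = 6)
    dec_depth = 6,
    ## number of heads in decoder multi-head attention (paper = 8)
    dec_heads = 8,
    ## maximum decoder sequence length (x-transformers README example = 1024)
    dec_max_seq_len = 32,
    ## share the token embeddings between encoder and decoder
    tie_token_emb = True
).cuda()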
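Every encoder setting above starts with `enc_` and every decoder setting with `dec_`. In the x-transformers source we have read, `XTransformer` groups its keyword arguments by these prefixes, strips the prefix, and passes each group to the encoder or decoder it builds. The idea can be sketched in plain Python as follows; `split_by_prefix` is a hypothetical helper written for illustration, not the library's own function.

```python
def split_by_prefix(prefix, kwargs):
    """Hypothetical helper: (kwargs that start with prefix, prefix stripped), (everything else)."""
    matched = {k[len(prefix):]: v for k, v in kwargs.items() if k.startswith(prefix)}
    rest = {k: v for k, v in kwargs.items() if not k.startswith(prefix)}
    return matched, rest

# the prefixed settings from the model cell above
config = dict(
    enc_num_tokens=16, enc_depth=6, enc_heads=8, enc_max_seq_len=32,
    dec_num_tokens=16, dec_depth=6, dec_heads=8, dec_max_seq_len=32,
)
enc_kwargs, rest = split_by_prefix("enc_", config)
dec_kwargs, rest = split_by_prefix("dec_", rest)
print(enc_kwargs)  # {'num_tokens': 16, 'depth': 6, 'heads': 8, 'max_seq_len': 32}
print(dec_kwargs)  # {'num_tokens': 16, 'depth': 6, 'heads': 8, 'max_seq_len': 32}
print(rest)        # {}
```

This also appears to be why `dim` and `tie_token_emb` carry no prefix: they are top-level arguments that apply to both halves of the model.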

In [3]:
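## NUM_BATCHES = number of training steps (batches); every step draws a fresh random batch, so there are no epochs here
## BATCH_SIZE = number of samples in one batch
## LEARNING_RATE = learning rate
## GENERATE_EVERY = generate a sample every 100 steps to check how many tokens are wrong
## NUM_TOKENS = number of distinct token ids in the data (4 data symbols + 2 special ids, 0 and 1, where 1 is the start token)
## ENC_SEQ_LEN = encoder (source) sequence length
## DEC_SEQ_LEN = decoder (target) sequence length: start token + the source copied twice

NUM_BATCHES = int(1e3)
BATCH_SIZE = 32
LEARNING_RATE = 3e-4
GENERATE_EVERY = 100
NUM_TOKENS = 4 + 2
ENC_SEQ_LEN = 8
DEC_SEQ_LEN = 16 + 1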
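The constants fit together as follows: the data symbols are the ids 2 to 5 (hence `4 + 2` once ids 0 and 1 are counted), the decoder target is one start token followed by the source twice (hence `16 + 1`), and both must fit inside the vocabulary size and maximum lengths given to `XTransformer` in the model cell. A small sanity check, assuming those model values (16 tokens, max length 32); the `NUM_DATA_TOKENS`, `NUM_SPECIAL_TOKENS`, and model-side names are introduced here only for illustration:

```python
# data-side constants from the cell above
NUM_DATA_TOKENS = 4              # torch.randint(2, 6, ...) draws the ids 2, 3, 4, 5
NUM_SPECIAL_TOKENS = 2           # id 0 (never produced by the data) and id 1 (start token)
NUM_TOKENS = NUM_DATA_TOKENS + NUM_SPECIAL_TOKENS
ENC_SEQ_LEN = 8
DEC_SEQ_LEN = 1 + 2 * ENC_SEQ_LEN   # start token + src + src

# model-side values passed to XTransformer in the model cell
ENC_NUM_TOKENS = DEC_NUM_TOKENS = 16
ENC_MAX_SEQ_LEN = DEC_MAX_SEQ_LEN = 32

# every token id must be a valid embedding index, and every sequence must fit the max length
assert NUM_TOKENS <= min(ENC_NUM_TOKENS, DEC_NUM_TOKENS)
assert ENC_SEQ_LEN <= ENC_MAX_SEQ_LEN
assert DEC_SEQ_LEN <= DEC_MAX_SEQ_LEN
print(NUM_TOKENS, DEC_SEQ_LEN)  # 6 17
```

With 6 ids in use and a vocabulary of 16, the ids 6 to 15 never appear in the data, so the vocabulary simply has some unused headroom.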

In [4]:
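## Generate random src, tgt, src_mask to feed the Transformer (a copy task: tgt = start token + src + src)

def cycle():
    while True:
        # start token (id 1) at the front of every target
        prefix = torch.ones((BATCH_SIZE, 1)).long().cuda()
        # random source ids in [2, NUM_TOKENS)
        src = torch.randint(2, NUM_TOKENS, (BATCH_SIZE, ENC_SEQ_LEN)).long().cuda()
        # decoder target: start token, then the source written twice
        tgt = torch.cat((prefix, src, src), 1)
        # no padding, so every source position is attended to
        src_mask = torch.ones(BATCH_SIZE, src.shape[1]).bool().cuda()
        yield (src, tgt, src_mask)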
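To make the data format concrete, here is a pure-Python version of one sample from `cycle()` (the torch version builds 32 of these at once). It also shows the one-position shift used for teacher forcing: as far as we can tell from the x-transformers source, `model(src, tgt, mask=src_mask)` feeds `tgt` without its last token to the decoder and scores the prediction of `tgt` without its first token with cross-entropy, which is why the call returns a loss directly. `make_copy_example` and `teacher_forcing_pair` are illustrative names, not library functions.

```python
import random

ENC_SEQ_LEN = 8
NUM_TOKENS = 6
START_TOKEN = 1

def make_copy_example(rng):
    # like torch.randint(2, NUM_TOKENS, ...): ids drawn uniformly from 2..NUM_TOKENS-1
    src = [rng.randrange(2, NUM_TOKENS) for _ in range(ENC_SEQ_LEN)]
    tgt = [START_TOKEN] + src + src      # start token, then src written twice
    src_mask = [True] * len(src)         # no padding: every source position is valid
    return src, tgt, src_mask

def teacher_forcing_pair(tgt):
    # decoder input drops the last token, labels drop the start token
    return tgt[:-1], tgt[1:]

src, tgt, src_mask = make_copy_example(random.Random(0))
dec_input, labels = teacher_forcing_pair(tgt)
print("src:      ", src)
print("dec_input:", dec_input)
print("labels:   ", labels)
```

Since the labels are `src + src`, the first eight labels can only be predicted by attending to the encoder through cross-attention; the second eight can also be read off the decoder's own earlier inputs.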

In [5]:
## Train the model

import tqdm

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
data = cycle()  # create the batch generator once and keep drawing from it

## Train by updating on the loss at every step
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    src, tgt, src_mask = next(data)

    loss = model(src, tgt, mask=src_mask)
    loss.backward()
    print(f'{i}: {loss.item()}')

    optimizer.step()
    optimizer.zero_grad()

    ## Every GENERATE_EVERY (= 100) steps, generate from one sample and count its wrong tokens
    if i != 0 and i % GENERATE_EVERY == 0:
        model.eval()
        src, _, src_mask = next(data)
        src, src_mask = src[:1], src_mask[:1]
        start_tokens = torch.ones((1, 1)).long().cuda()

        sample = model.generate(src, start_tokens, ENC_SEQ_LEN, mask = src_mask)
        incorrects = (src != sample).sum()

        print("input:  ", src)
        print("predicted output:  ", sample)
        print(f"incorrects: {incorrects}")

training:   0%|          | 0/1000 [00:00<?, ?it/s]

0: 3.5363845825195312
1: 2.5506021976470947
2: 2.264939785003662
3: 2.1915524005889893
4: 2.1163032054901123
5: 2.1038739681243896
6: 2.0782885551452637
7: 2.0580554008483887
8: 2.0479843616485596
9: 2.0629780292510986
10: 2.0210423469543457
11: 2.0202560424804688
12: 2.0394973754882812
13: 1.9996416568756104
14: 2.001347780227661
15: 1.9803589582443237
16: 2.037302017211914
17: 1.9408445358276367
18: 1.9772305488586426
19: 1.9449580907821655
20: 1.9746168851852417
21: 1.937604308128357
22: 1.9570003747940063
23: 1.9667229652404785
24: 1.9157633781433105
25: 1.9036816358566284
26: 1.8998273611068726
27: 1.865684151649475
28: 1.903259515762329
29: 1.8805909156799316
30: 1.8854260444641113
31: 1.88181734085083
32: 1.8762459754943848
33: 1.8737146854400635
34: 1.8870985507965088
35: 1.8390964269638062
36: 1.8458092212677002
37: 1.900241732597351
38: 1.7807025909423828
39: 1.8230148553848267
40: 1.838757038116455
41: 1.8340129852294922
42: 1.8241393566131592
43: 1.8297234773635864
44: 1.78

training:  16%|█▌        | 158/1000 [00:10<00:53, 15.74it/s]

157: 1.4008640050888062
158: 1.3414876461029053
159: 1.3916746377944946
160: 1.3706940412521362
161: 1.3766337633132935
162: 1.3539798259735107
163: 1.3607406616210938
164: 1.43787682056427
165: 1.4227697849273682
166: 1.3931187391281128
167: 1.3944611549377441
168: 1.4087345600128174
169: 1.359459638595581
170: 1.3567333221435547
171: 1.427065134048462
172: 1.3371367454528809
173: 1.3463525772094727
174: 1.383100986480713
175: 1.3941693305969238
176: 1.362324833869934
177: 1.350888729095459
178: 1.3419662714004517
179: 1.3503172397613525
180: 1.3416064977645874
181: 1.3637659549713135
182: 1.3054044246673584
183: 1.3577377796173096
184: 1.3595693111419678
185: 1.369030475616455
186: 1.3408730030059814
187: 1.3396656513214111
188: 1.3627790212631226
189: 1.3135662078857422
190: 1.3322205543518066
191: 1.2711786031723022
192: 1.351226806640625
193: 1.324966549873352
194: 1.272853136062622
195: 1.2978448867797852
196: 1.3033660650253296
197: 1.2445307970046997
198: 1.3250333070755005
199

training:  38%|███▊      | 376/1000 [00:20<00:32, 19.27it/s]

371: 0.8616665601730347
372: 0.8568291664123535
373: 0.9188927412033081
374: 0.8223289251327515
375: 0.8863011598587036
376: 0.8385462164878845
377: 0.8805311322212219
378: 0.8517838716506958
379: 0.8067492246627808
380: 0.8327056169509888
381: 0.8638747334480286
382: 0.8098494410514832
383: 0.8511287569999695
384: 0.8763384819030762
385: 0.7942199110984802
386: 0.8361011147499084
387: 0.7781546115875244
388: 0.7972903847694397
389: 0.809349775314331
390: 0.794104278087616
391: 0.7783336043357849
392: 0.7779484391212463
393: 0.7814876437187195
394: 0.7589930295944214
395: 0.7487598657608032
396: 0.7551804780960083
397: 0.7727528810501099
398: 0.7732232213020325
399: 0.7404147386550903
400: 0.782452404499054
input:   tensor([[5, 2, 2, 2, 3, 5, 3, 4]], device='cuda:0')
predicted output:   tensor([[2, 3, 5, 2, 4, 2, 3, 5]], device='cuda:0')
incorrects: 6
401: 0.7654131054878235
402: 0.7920652627944946
403: 0.7527590394020081
404: 0.733409583568573
405: 0.7691285610198975
406: 0.7404714226

training:  59%|█████▉    | 594/1000 [00:30<00:20, 20.03it/s]

593: 0.6201993227005005
594: 0.5860369801521301
595: 0.5901767015457153
596: 0.5635480880737305
597: 0.541772186756134
598: 0.6223844885826111
599: 0.5482400059700012
600: 0.5323712825775146
input:   tensor([[2, 5, 5, 4, 5, 5, 5, 4]], device='cuda:0')
predicted output:   tensor([[5, 5, 5, 4, 2, 5, 5, 2]], device='cuda:0')
incorrects: 3
601: 0.5699364542961121
602: 0.5544357895851135
603: 0.5613698363304138
604: 0.591215193271637
605: 0.537032425403595
606: 0.5391689538955688
607: 0.6059051156044006
608: 0.5837838649749756
609: 0.5927106142044067
610: 0.6013795137405396
611: 0.560204267501831
612: 0.5672552585601807
613: 0.5468578338623047
614: 0.5476542711257935
615: 0.544449508190155
616: 0.594987154006958
617: 0.6483073830604553
618: 0.5531762838363647
619: 0.6261215806007385
620: 0.554713249206543
621: 0.567674458026886
622: 0.6371515989303589
623: 0.5740925073623657
624: 0.5938554406166077
625: 0.6064944863319397
626: 0.5413174033164978
627: 0.5707780122756958
628: 0.57498657703399

training:  80%|████████  | 803/1000 [00:40<00:09, 20.25it/s]

802: 0.4406333565711975
803: 0.4151228070259094
804: 0.4597623944282532
805: 0.4236814081668854
806: 0.4766072928905487
807: 0.4499953091144562
808: 0.4274060130119324
809: 0.4987371861934662
810: 0.4443860650062561
811: 0.4724119305610657
812: 0.4471132159233093
813: 0.4212198257446289
814: 0.42613548040390015
815: 0.44065725803375244
816: 0.46517860889434814
817: 0.43049538135528564
818: 0.4191420078277588
819: 0.45147621631622314
820: 0.4133729934692383
821: 0.42965587973594666
822: 0.4595978856086731
823: 0.4216533303260803
824: 0.46896892786026
825: 0.4155671298503876
826: 0.4050505459308624
827: 0.43500420451164246
828: 0.4535307288169861
829: 0.4442615211009979
830: 0.43516892194747925
831: 0.43969064950942993
832: 0.4111848771572113
833: 0.43038082122802734
834: 0.4109736680984497
835: 0.39895129203796387
836: 0.4246698021888733
837: 0.4091208279132843
838: 0.3573041558265686
839: 0.3861424922943115
840: 0.3991720378398895
841: 0.40633293986320496
842: 0.44375309348106384
843: 

training: 100%|██████████| 1000/1000 [00:50<00:00, 19.94it/s]

997: 0.22514668107032776
998: 0.20265790820121765
999: 0.22559469938278198
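`model.generate(src, start_tokens, ENC_SEQ_LEN, mask=src_mask)` extends the start token one position at a time, feeding each new token back in, and the logs above show that it returns only the `ENC_SEQ_LEN` new tokens (8 per sample). The loop can be sketched as below, with a toy `perfect_copier` standing in for the trained model; both function names are hypothetical. This version picks tokens deterministically, which is a simplification: the library's `generate` may sample from filtered logits rather than always taking the most likely token.

```python
def greedy_generate(next_token_fn, src, start_tokens, seq_len):
    """Extend start_tokens by seq_len tokens, one at a time; return only the new tokens."""
    out = list(start_tokens)
    for _ in range(seq_len):
        # each step conditions on the source and on everything generated so far
        out.append(next_token_fn(src, out))
    return out[len(start_tokens):]

def perfect_copier(src, prefix):
    # toy stand-in for a fully trained model on the copy task:
    # after the start token plus t generated tokens, it emits src[t]
    t = len(prefix) - 1
    return src[t % len(src)]

src = [5, 2, 2, 2, 3, 5, 3, 4]   # the step-400 input from the log above
print(greedy_generate(perfect_copier, src, [1], len(src)))  # [5, 2, 2, 2, 3, 5, 3, 4]
```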





In [6]:
print(f'Accuracy : {(len(src) - incorrects)/len(src)*100}%')

Accuracy : 93.75%
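A caveat on this number: after the loop finishes, `src` holds the full 32-row batch from the last training step (i = 999), while `incorrects` still comes from the single 8-token sample generated at step 900. The cell therefore computes (32 - incorrects) / 32 × 100, so 93.75% corresponds to 2 wrong tokens, which for that 8-token sample is a token accuracy of 6/8 = 75%. A per-token accuracy divides by the sample's own length; in torch that would be `(src == sample).float().mean() * 100`, computed inside the generation block on the sliced `src` before it is overwritten. The same computation in plain Python, applied to the step-600 sample from the log (`token_accuracy` is an illustrative name):

```python
def token_accuracy(target, predicted):
    """Percentage of positions where predicted matches target (equal-length sequences)."""
    assert len(target) == len(predicted), "sequences must have the same length"
    correct = sum(t == p for t, p in zip(target, predicted))
    return 100.0 * correct / len(target)

# step-600 sample from the training log above (incorrects: 3)
src = [2, 5, 5, 4, 5, 5, 5, 4]
sample = [5, 5, 5, 4, 2, 5, 5, 2]
print(f"Accuracy : {token_accuracy(src, sample)}%")  # Accuracy : 62.5%
```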
