## Transformers architecture를 받아와 구축해보자!

Transformers 강의 + 실습까지 수강하시느라 너무 수고 많으셨습니다! 저도 저번 기수 당시 transformers를 처음 접했는데, 모델의 구조가 난해하고 쓰인 개념들도 어려워서 이해하는데 많은 시간이 걸렸었네요.. 그래도 transformers 모델이 현재 contemporary AI technology에 쓰이지 않는 곳이 없다보니, 어려운 내용들이 다수 있지만 힘주어서 중요한 내용들로 채워넣으려고 노력했습니다 수고하셨습니다 XD

Transformers은 개념도 개념이지만, 매번 새롭게 attention 코드를 짜고, encoder와 decoder 구조를 구축하는 것도 막막하실 겁니다! 다행히도 transformers architecture는 워낙 유명해서 이제 코드 몇줄만 `딸깍`해도 최신 논문기반 transformer 구조를 `huggingface` 혹은 `github`에서 받아서 사용할 수 있습니다 :D 이번 실습에는 초심자가 사용하기는 어렵지만 `x-transformers`에서 저희가 구축한 transformers 구조를 받아와, 예시 문장을 출력하는 것까지 마무리할 예정입니다!

transformers architecture가 2017년 나온 이후로, 이 architecture를 기반으로 한 여러 모델들이 나왔고, 또한 base model에 대해서도 여러 개선점들이 추가되었습니다. `x-transformers`은 이 개선된 model들을 코드 몇줄을 추가해서 적용 가능하게 하는 라이브러리로, 최신 논문 동향을 파악하고 있어야 한다는 점에서 어렵지만 그만큼 성능이 뒷받침해주는 코드들을 모아둔 라이브러리입니다.

TMI가 생각보다 길어졌는데, 아무튼 이번 과제는 따로 작성해 넣어야할 부분은 없고, 제가 드리는 코드를 그대로 실행하기만 하면 되는 과정으로 추가했습니다. 이제 개강이 얼마 남지 않았는데, 다들 화이팅입니다 XD

# Colab GPU 환경에서 구동하세요!

In [1]:
!pip install x-transformers

Collecting x-transformers
  Downloading x_transformers-1.19.1-py3-none-any.whl (27 kB)
Collecting einops>=0.6.1 (from x-transformers)
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: einops, x-transformers
Successfully installed einops-0.6.1 x-transformers-1.19.1


In [2]:
import torch
from x_transformers import XTransformer

## Transformers Architectures 받아오
model = XTransformer(
    ## 모델의 차원 (논문 = 512)
    dim = 16,
    ## encoder token의 개수 (논문 = 256)
    enc_num_tokens = 16,
    ## encoder 반복 횟수 (논문 = 6)
    enc_depth = 6,
    ## multihead attention n_heads 개수 (논문 = 8)
    enc_heads = 8,
    ## encoder token의 max sequence length (논문 = 1024)
    enc_max_seq_len = 32,
    ## (논문 = 256)
    dec_num_tokens = 16,
    dec_depth = 6,
    dec_heads = 8,
    ## (논문 = 1024)
    dec_max_seq_len = 32,
    tie_token_emb = True
).cuda()

In [3]:
## NUM_BATCHES = Epoches의 개수
## BATCH_SIZE = 하나의 batch에 들어갈 sample의 개수
## LEARNING_RATE = learning rate
## GENERATE_EVERY = 100번마다 한번씩 generate해서 accuracy확인
## NUM_TOKENS = 데이터 내 유니크한 토큰의 수
## ENC_SEQ_LEN = encoder sequence length

NUM_BATCHES = int(1e3)
BATCH_SIZE = 32
LEARNING_RATE = 3e-4
GENERATE_EVERY  = 100
NUM_TOKENS = 4 + 2
ENC_SEQ_LEN = 8
DEC_SEQ_LEN = 16 + 1

In [4]:
## Transformer 모델에 넣을 임의의 src, tgt, src_mask 생성

def cycle():
    while True:
        prefix = torch.ones((BATCH_SIZE, 1)).long().cuda()
        src = torch.randint(2, NUM_TOKENS, (BATCH_SIZE, ENC_SEQ_LEN)).long().cuda()
        tgt = torch.cat((prefix, src, src), 1)
        src_mask = torch.ones(BATCH_SIZE, src.shape[1]).bool().cuda()
        yield (src, tgt, src_mask)

In [5]:
## Train Model

import tqdm
import torch.optim as optim

optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

## loss update 해가면서 학습
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    src, tgt, src_mask = next(cycle())

    loss = model(src, tgt, mask=src_mask)
    loss.backward()
    print(f'{i}: {loss.item()}')

    optim.step()
    optim.zero_grad()

    ## 매 N(100)번마다 accuracy 측정
    if i != 0 and i % GENERATE_EVERY == 0:
        model.eval()
        src, _, src_mask = next(cycle())
        src, src_mask = src[:1], src_mask[:1]
        start_tokens = (torch.ones((1, 1)) * 1).long().cuda()

        sample = model.generate(src, start_tokens, ENC_SEQ_LEN, mask = src_mask)
        incorrects = (src != sample).abs().sum()

        print(f"input:  ", src)
        print(f"predicted output:  ", sample)
        print(f"incorrects: {incorrects}")

training:   0%|          | 0/1000 [00:00<?, ?it/s]

0: 3.1176939010620117
1: 2.4334042072296143
2: 2.280622959136963
3: 2.2220051288604736
4: 2.1654574871063232
5: 2.148639440536499
6: 2.1320667266845703
7: 2.145817518234253
8: 2.107740879058838
9: 2.1001980304718018
10: 2.1098666191101074
11: 2.1100387573242188
12: 2.0546813011169434
13: 2.062079429626465
14: 2.048654556274414
15: 2.0194900035858154
16: 2.022160053253174
17: 2.018693447113037
18: 2.042724847793579
19: 1.9899821281433105
20: 1.9965746402740479
21: 1.991740345954895
22: 2.0088353157043457
23: 1.9996373653411865
24: 1.9766826629638672
25: 1.9871537685394287
26: 1.9633435010910034
27: 1.969401478767395
28: 1.953279733657837
29: 1.9461917877197266
30: 1.9303209781646729
31: 1.9100186824798584
32: 1.9359009265899658
33: 1.9420424699783325
34: 1.9267436265945435
35: 1.898440957069397
36: 1.8982547521591187
37: 1.904006004333496
38: 1.884267807006836
39: 1.902586817741394
40: 1.8877873420715332
41: 1.8649033308029175
42: 1.8482123613357544
43: 1.9159338474273682
44: 1.85727763

training:  16%|█▌        | 156/1000 [00:10<00:54, 15.57it/s]

151: 1.4449341297149658
152: 1.4689693450927734
153: 1.4529794454574585
154: 1.4415712356567383
155: 1.4423469305038452
156: 1.4338527917861938
157: 1.4616342782974243
158: 1.4288718700408936
159: 1.4360359907150269
160: 1.4214022159576416
161: 1.4197133779525757
162: 1.3872158527374268
163: 1.400247573852539
164: 1.4215264320373535
165: 1.3795506954193115
166: 1.407945990562439
167: 1.3610520362854004
168: 1.3699232339859009
169: 1.3541046380996704
170: 1.3737233877182007
171: 1.4321997165679932
172: 1.3738764524459839
173: 1.333970308303833
174: 1.4193971157073975
175: 1.4283690452575684
176: 1.380033016204834
177: 1.3820531368255615
178: 1.3477119207382202
179: 1.3541285991668701
180: 1.375797152519226
181: 1.3430936336517334
182: 1.3454482555389404
183: 1.3438425064086914
184: 1.3175123929977417
185: 1.377790093421936
186: 1.2953691482543945
187: 1.3370099067687988
188: 1.251798391342163
189: 1.3478540182113647
190: 1.3235260248184204
191: 1.2642221450805664
192: 1.3068963289260864

training:  37%|███▋      | 373/1000 [00:20<00:32, 19.17it/s]

372: 0.7209097743034363
373: 0.7858580946922302
374: 0.7784229516983032
375: 0.7791374921798706
376: 0.7299410700798035
377: 0.8176723122596741
378: 0.7813855409622192
379: 0.7661973237991333
380: 0.7518871426582336
381: 0.7738476991653442
382: 0.8435282111167908
383: 0.787021815776825
384: 0.7448912858963013
385: 0.7270357608795166
386: 0.7456527352333069
387: 0.7399094104766846
388: 0.7429004311561584
389: 0.7698313593864441
390: 0.7554579377174377
391: 0.741539478302002
392: 0.7733910083770752
393: 0.7367152571678162
394: 0.8059340715408325
395: 0.7375515103340149
396: 0.7327673435211182
397: 0.726486086845398
398: 0.7043448090553284
399: 0.70052170753479
400: 0.7533358931541443
input:   tensor([[3, 3, 5, 2, 5, 4, 3, 3]], device='cuda:0')
predicted output:   tensor([[3, 5, 3, 2, 4, 3, 3, 5]], device='cuda:0')
incorrects: 5
401: 0.7263174057006836
402: 0.6913563013076782
403: 0.6887167096138
404: 0.6767765879631042
405: 0.7023733854293823
406: 0.6662542819976807
407: 0.71832031011581

training:  59%|█████▉    | 590/1000 [00:30<00:20, 19.69it/s]

588: 0.5381784439086914
589: 0.5187340378761292
590: 0.5202193260192871
591: 0.5495099425315857
592: 0.4943535029888153
593: 0.5062273740768433
594: 0.4820919334888458
595: 0.4829895794391632
596: 0.505492091178894
597: 0.5015062093734741
598: 0.460236132144928
599: 0.4799042344093323
600: 0.535198450088501
input:   tensor([[4, 3, 4, 5, 3, 5, 4, 4]], device='cuda:0')
predicted output:   tensor([[4, 4, 3, 5, 4, 5, 3, 4]], device='cuda:0')
incorrects: 4
601: 0.49098193645477295
602: 0.4929064214229584
603: 0.47974374890327454
604: 0.48951688408851624
605: 0.48303279280662537
606: 0.46862897276878357
607: 0.48988157510757446
608: 0.4703526496887207
609: 0.4718295633792877
610: 0.49217286705970764
611: 0.4854438006877899
612: 0.48611128330230713
613: 0.4727107584476471
614: 0.5149185657501221
615: 0.5016424655914307
616: 0.48974403738975525
617: 0.5036092400550842
618: 0.474759042263031
619: 0.47964709997177124
620: 0.520108163356781
621: 0.47165003418922424
622: 0.4669303297996521
623: 0.

training:  80%|███████▉  | 796/1000 [00:40<00:10, 20.03it/s]

793: 0.41379356384277344
794: 0.45035243034362793
795: 0.4128963053226471
796: 0.4848705530166626
797: 0.4254271388053894
798: 0.4633294343948364
799: 0.5140081644058228
800: 0.41193869709968567
input:   tensor([[3, 5, 3, 2, 5, 4, 4, 4]], device='cuda:0')
predicted output:   tensor([[3, 4, 5, 5, 3, 2, 4, 3]], device='cuda:0')
incorrects: 6
801: 0.44204840064048767
802: 0.43299371004104614
803: 0.42469269037246704
804: 0.42203593254089355
805: 0.41315382719039917
806: 0.41748276352882385
807: 0.4136725068092346
808: 0.4063330292701721
809: 0.43636077642440796
810: 0.3912659287452698
811: 0.3918258845806122
812: 0.40856438875198364
813: 0.3970547914505005
814: 0.40005868673324585
815: 0.4193633198738098
816: 0.39939355850219727
817: 0.3844432532787323
818: 0.42135173082351685
819: 0.3968737721443176
820: 0.4137842059135437
821: 0.41121238470077515
822: 0.3953211009502411
823: 0.40992602705955505
824: 0.38067492842674255
825: 0.3965337872505188
826: 0.44049733877182007
827: 0.415124028921

training: 100%|██████████| 1000/1000 [00:50<00:00, 19.89it/s]

998: 0.3704656660556793
999: 0.38309988379478455





In [6]:
print(f'Accuracy : {(len(src) - incorrects)/len(src)*100}%')

Accuracy : 81.25%
