## Let's Pull in a Transformer Architecture and Build a Model!

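Great work making it through the Transformers lecture and the hands-on session! I first ran into transformers myself during the previous cohort, and between the intricate model structure and the difficult concepts, it took me a long time to understand them. Still, transformer models show up almost everywhere in today's AI technology, so even though much of the material is hard, I did my best to pack in the parts that matter most. Well done, everyone XD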

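The concepts are one thing, but writing attention code from scratch and building the encoder and decoder every time is daunting too! Fortunately, the transformer architecture is so well known that a few lines of code (just a `click`) are now enough to pull a transformer built on recent papers from `huggingface` or `github` :D In this exercise we use `x-transformers`, which is admittedly not the easiest library for beginners, to pull in the transformer architecture we built in class and take it all the way to generating an example output sequence!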

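Since the transformer architecture came out in 2017, many models have been built on top of it, and the base model itself has picked up a number of improvements. `x-transformers` is a library that lets you apply those improvements by adding a few lines of code. It is hard to use in the sense that you need to keep up with recent papers, but in return it collects implementations whose performance is backed by that research.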

That turned into more TMI than I planned. Anyway, this assignment has nothing for you to fill in; just run the code I provide as-is. The new semester is right around the corner, so best of luck, everyone XD

# Run this in a Colab GPU environment!

In [1]:
!pip install x-transformers

Collecting x-transformers
  Downloading x_transformers-1.19.1-py3-none-any.whl (27 kB)
Collecting einops>=0.6.1 (from x-transformers)
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 42.2/42.2 kB 2.0 MB/s eta 0:00:00
Installing collected packages: einops, x-transformers
Successfully installed einops-0.6.1 x-transformers-1.19.1


In [2]:
import torch
from x_transformers import XTransformer

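## Load the Transformer architecture
model = XTransformer(
    ## model dimension (paper = 512)
    dim = 16,
    ## encoder vocabulary size (x-transformers README example = 256)
    enc_num_tokens = 16,
    ## number of stacked encoder layers (paper = 6)
    enc_depth = 6,
    ## number of heads in multi-head attention (paper = 8)
    enc_heads = 8,
    ## maximum encoder sequence length (x-transformers README example = 1024)
    enc_max_seq_len = 32,
    ## decoder vocabulary size (x-transformers README example = 256)
    dec_num_tokens = 16,
    ## decoder layers and attention heads (paper = 6 and 8)
    dec_depth = 6,
    dec_heads = 8,
    ## maximum decoder sequence length (x-transformers README example = 1024)
    dec_max_seq_len = 32,
    ## share one token embedding between the encoder and the decoder
    tie_token_emb = True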
).cuda()

In [3]:
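## NUM_BATCHES = number of training steps (one batch per step, not epochs)
## BATCH_SIZE = number of samples in one batch
## LEARNING_RATE = learning rate
## GENERATE_EVERY = generate once every 100 steps to check accuracy
## NUM_TOKENS = number of token ids in the data (4 data tokens with ids 2-5, plus ids 0 and 1; id 1 is the start token)
## ENC_SEQ_LEN = encoder sequence length
## DEC_SEQ_LEN = decoder sequence length (source copied twice, plus the start token)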

NUM_BATCHES = int(1e3)
BATCH_SIZE = 32
LEARNING_RATE = 3e-4
GENERATE_EVERY  = 100
NUM_TOKENS = 4 + 2
ENC_SEQ_LEN = 8
DEC_SEQ_LEN = 16 + 1
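
As a quick sanity check on these settings, the sketch below lists the steps at which the training loop's `i != 0 and i % GENERATE_EVERY == 0` check fires. It is plain Python and needs no GPU; `eval_steps` is a name introduced here only for illustration.

```python
NUM_BATCHES = int(1e3)
GENERATE_EVERY = 100

# Steps where the training loop pauses to generate a sample and count errors.
eval_steps = [i for i in range(NUM_BATCHES) if i != 0 and i % GENERATE_EVERY == 0]

print(eval_steps)       # [100, 200, 300, 400, 500, 600, 700, 800, 900]
print(len(eval_steps))  # 9
```

Step 0 is skipped and the loop ends at step 999, so the last check happens at step 900. Anything computed from `src` and `incorrects` after training therefore reflects that one sample.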

In [4]:
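## Generate random src, tgt, and src_mask to feed the Transformer model

def cycle():
    while True:
        ## start token (id 1) that opens every target sequence
        prefix = torch.ones((BATCH_SIZE, 1)).long().cuda()
        ## random source tokens drawn from ids 2..NUM_TOKENS-1
        src = torch.randint(2, NUM_TOKENS, (BATCH_SIZE, ENC_SEQ_LEN)).long().cuda()
        ## target = start token followed by the source repeated twice
        tgt = torch.cat((prefix, src, src), 1)
        ## all-True mask: every source position is a real token (no padding)
        src_mask = torch.ones(BATCH_SIZE, src.shape[1]).bool().cuda()
        yield (src, tgt, src_mask)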
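
To make the copy task concrete, here is a minimal pure-Python sketch of a single sample, using lists instead of CUDA tensors. `make_sample` and `START_TOKEN` are hypothetical names, not part of the notebook. The last lines show the one-position shift that autoregressive decoders are usually trained with; my understanding is that `x-transformers` applies this shift internally, so treat that part as an assumption about the library.

```python
import random

NUM_TOKENS = 4 + 2
ENC_SEQ_LEN = 8
START_TOKEN = 1  # same value as the `prefix` column in cycle()

def make_sample(rng):
    """Build one (src, tgt) pair the way cycle() does, for a batch of one."""
    # Data tokens are drawn from ids 2..NUM_TOKENS-1, like torch.randint(2, NUM_TOKENS, ...).
    src = [rng.randrange(2, NUM_TOKENS) for _ in range(ENC_SEQ_LEN)]
    # Target = start token, then the source repeated twice.
    tgt = [START_TOKEN] + src + src
    return src, tgt

src, tgt = make_sample(random.Random(0))
print(len(src), len(tgt))  # 8 17

# Teacher forcing: the decoder reads tgt[:-1] and is scored against tgt[1:].
decoder_input, labels = tgt[:-1], tgt[1:]
print(len(decoder_input), len(labels))  # 16 16
```

The target length of 17 matches `DEC_SEQ_LEN = 16 + 1`, and after the shift the decoder is scored on exactly the 16 copied tokens.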

In [5]:
## Train the model

import tqdm

## named `optimizer` so it does not shadow the torch.optim module
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

## create the data generator once and reuse it
data = cycle()

## training loop: compute the loss, backpropagate, update the weights
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    src, tgt, src_mask = next(data)

    loss = model(src, tgt, mask=src_mask)
    loss.backward()
    print(f'{i}: {loss.item()}')

    optimizer.step()
    optimizer.zero_grad()

    ## every GENERATE_EVERY (100) steps, generate from one sample and count errors
    if i != 0 and i % GENERATE_EVERY == 0:
        model.eval()
        src, _, src_mask = next(data)
        src, src_mask = src[:1], src_mask[:1]
        start_tokens = torch.ones((1, 1)).long().cuda()

        sample = model.generate(src, start_tokens, ENC_SEQ_LEN, mask = src_mask)
        ## number of generated tokens that differ from the source
        incorrects = (src != sample).sum()

        print("input:  ", src)
        print("predicted output:  ", sample)
        print(f"incorrects: {incorrects}")

training:   0%|          | 0/1000 [00:00<?, ?it/s]

0: 2.967717409133911
1: 2.310065984725952
2: 2.0980381965637207
3: 1.9792176485061646
4: 1.940903902053833
5: 2.0049703121185303
6: 1.9453977346420288
7: 1.9440948963165283
8: 1.9559404850006104
9: 1.9389418363571167
10: 1.9162236452102661
11: 1.8632274866104126
12: 1.853509545326233
13: 1.832919716835022
14: 1.8812745809555054
15: 1.846367597579956
16: 1.8738512992858887
17: 1.8542276620864868
18: 1.8362722396850586
19: 1.8372690677642822
20: 1.8493646383285522
21: 1.8070871829986572
22: 1.8197567462921143
23: 1.8205344676971436
24: 1.7927122116088867
25: 1.7314341068267822
26: 1.743080735206604
27: 1.777055263519287
28: 1.7751967906951904
29: 1.7418677806854248
30: 1.7336137294769287
31: 1.7423092126846313
32: 1.7371660470962524
33: 1.7616827487945557
34: 1.7542176246643066
35: 1.7242411375045776
36: 1.6979230642318726
37: 1.7378978729248047
38: 1.726595401763916
39: 1.7083709239959717
40: 1.669858694076538
41: 1.716749906539917
42: 1.708418369293213
43: 1.6986348628997803
44: 1.7184

training:   8%|▊         | 84/1000 [00:10<01:50,  8.27it/s]

83: 1.5655711889266968
84: 1.5707340240478516
85: 1.6011946201324463
86: 1.5734988451004028
87: 1.576660394668579
88: 1.5538846254348755
89: 1.6063388586044312
90: 1.5374525785446167
91: 1.5830274820327759
92: 1.5152256488800049
93: 1.5599007606506348
94: 1.5491526126861572
95: 1.556069254875183
96: 1.534737229347229
97: 1.583491563796997
98: 1.5431830883026123
99: 1.523798942565918
100: 1.5227258205413818
input:   tensor([[3, 3, 2, 2, 2, 5, 5, 3]], device='cuda:0')
predicted output:   tensor([[2, 5, 2, 3, 5, 5, 3, 5]], device='cuda:0')
incorrects: 6
101: 1.5056992769241333
102: 1.559043049812317
103: 1.5022737979888916
104: 1.5039485692977905
105: 1.5455212593078613
106: 1.5143779516220093
107: 1.506245732307434
108: 1.5240131616592407
109: 1.5346825122833252
110: 1.498958706855774
111: 1.474863886833191
112: 1.496057152748108
113: 1.51654851436615
114: 1.4753624200820923
115: 1.493377685546875
116: 1.487906575202942
117: 1.5021562576293945
118: 1.4970651865005493
119: 1.4577969312667

training:  28%|██▊       | 280/1000 [00:20<00:48, 14.87it/s]

278: 1.0121833086013794
279: 1.002677083015442
280: 0.9898850917816162
281: 1.0483893156051636
282: 1.01412832736969
283: 1.0267572402954102
284: 0.9851580262184143
285: 1.0062711238861084
286: 1.0008223056793213
287: 1.0161702632904053
288: 0.9937629699707031
289: 1.0222787857055664
290: 1.0033801794052124
291: 1.0144544839859009
292: 0.9550102949142456
293: 1.0206443071365356
294: 1.0068098306655884
295: 0.9938917756080627
296: 0.9944773316383362
297: 0.9881764650344849
298: 0.9665540456771851
299: 0.9619224071502686
300: 0.9384415149688721
input:   tensor([[3, 2, 3, 3, 5, 3, 3, 3]], device='cuda:0')
predicted output:   tensor([[3, 3, 5, 3, 2, 3, 4, 3]], device='cuda:0')
incorrects: 4
301: 0.9411118030548096
302: 0.947033703327179
303: 0.9527806043624878
304: 0.9698815941810608
305: 0.9925892353057861
306: 0.9133545756340027
307: 0.9936699271202087
308: 0.9340470433235168
309: 0.9412410259246826
310: 0.9369907379150391
311: 0.9035581946372986
312: 0.9032362103462219
313: 0.9422871470

training:  48%|████▊     | 476/1000 [00:30<00:30, 16.95it/s]

475: 0.6174309253692627
476: 0.6124481558799744
477: 0.6207864880561829
478: 0.6509138345718384
479: 0.6009296774864197
480: 0.6260566711425781
481: 0.6124959588050842
482: 0.6270242929458618
483: 0.6401156783103943
484: 0.6353281736373901
485: 0.6037055850028992
486: 0.5908169746398926
487: 0.5927097797393799
488: 0.5717766880989075
489: 0.5854120254516602
490: 0.6209537982940674
491: 0.5786031484603882
492: 0.6161050200462341
493: 0.6176379323005676
494: 0.5941520929336548
495: 0.5872821807861328
496: 0.582821249961853
497: 0.5620584487915039
498: 0.579217255115509
499: 0.6172512769699097
500: 0.5518900156021118
input:   tensor([[4, 5, 3, 4, 2, 4, 4, 5]], device='cuda:0')
predicted output:   tensor([[4, 5, 2, 4, 4, 3, 4, 5]], device='cuda:0')
incorrects: 3
501: 0.6268106698989868
502: 0.5572391152381897
503: 0.5996753573417664
504: 0.6128614544868469
505: 0.5622513294219971
506: 0.5725207328796387
507: 0.5853000283241272
508: 0.5379571318626404
509: 0.6000652313232422
510: 0.55038326

training:  68%|██████▊   | 679/1000 [00:40<00:17, 18.24it/s]

675: 0.3848828375339508
676: 0.4063850939273834
677: 0.38584113121032715
678: 0.35551202297210693
679: 0.37169691920280457
680: 0.39570367336273193
681: 0.3960769474506378
682: 0.3870064914226532
683: 0.3749505281448364
684: 0.37962380051612854
685: 0.4340254068374634
686: 0.39383479952812195
687: 0.3979518413543701
688: 0.41835543513298035
689: 0.36988234519958496
690: 0.4206797182559967
691: 0.38051050901412964
692: 0.39775651693344116
693: 0.40363818407058716
694: 0.3910568356513977
695: 0.40735194087028503
696: 0.3845386505126953
697: 0.43114596605300903
698: 0.37162941694259644
699: 0.38245558738708496
700: 0.45845016837120056
input:   tensor([[2, 4, 2, 5, 3, 5, 3, 4]], device='cuda:0')
predicted output:   tensor([[2, 2, 4, 3, 5, 3, 5, 4]], device='cuda:0')
incorrects: 6
701: 0.35935667157173157
702: 0.4930676519870758
703: 0.40680354833602905
704: 0.3764704167842865
705: 0.49221429228782654
706: 0.39733970165252686
707: 0.4009998142719269
708: 0.41168642044067383
709: 0.392928183

training:  88%|████████▊ | 884/1000 [00:50<00:06, 19.04it/s]

883: 0.12561944127082825
884: 0.09075899422168732
885: 0.07698749750852585
886: 0.0827365294098854
887: 0.09870563447475433
888: 0.07895971834659576
889: 0.09098190069198608
890: 0.09073847532272339
891: 0.07083693891763687
892: 0.08723028004169464
893: 0.0846836268901825
894: 0.08023647218942642
895: 0.08282266557216644
896: 0.07397863268852234
897: 0.07361968606710434
898: 0.08432047814130783
899: 0.0780077576637268
900: 0.10908015817403793
input:   tensor([[3, 3, 2, 3, 4, 3, 3, 3]], device='cuda:0')
predicted output:   tensor([[3, 3, 2, 3, 4, 3, 3, 3]], device='cuda:0')
incorrects: 0
901: 0.08158114552497864
902: 0.0776631236076355
903: 0.07418902218341827
904: 0.09654264897108078
905: 0.07542961835861206
906: 0.09981941431760788
907: 0.08068098872900009
908: 0.06192578002810478
909: 0.07627157866954803
910: 0.09054531157016754
911: 0.06423592567443848
912: 0.0836353749036789
913: 0.09476631879806519
914: 0.0867752805352211
915: 0.08575737476348877
916: 0.0790274515748024
917: 0.070

training: 100%|██████████| 1000/1000 [00:55<00:00, 17.91it/s]

995: 0.07519640028476715
996: 0.05161571130156517
997: 0.04285477101802826
998: 0.0902014896273613
999: 0.06234462186694145
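
One way to read the loss values above is to compare them with the cross-entropy of uniform guessing, which is `ln(k)` for `k` equally likely choices. The sketch below is only a rough reference computed with the standard library. It assumes the printed loss is a mean cross-entropy in nats, which is what PyTorch's cross-entropy reports by default; the constant names are introduced here for illustration.

```python
import math

DEC_NUM_TOKENS = 16   # decoder vocabulary size passed to XTransformer
NUM_DATA_TOKENS = 4   # ids 2..5 are the only ones that appear in the targets

# Uniform guess over the whole decoder vocabulary.
loss_vocab = math.log(DEC_NUM_TOKENS)
# Uniform guess over only the four tokens that occur in the data.
loss_data = math.log(NUM_DATA_TOKENS)

print(f"{loss_vocab:.3f}")  # 2.773
print(f"{loss_data:.3f}")   # 1.386
```

The first logged loss (about 2.97) sits near the first baseline, and the losses around step 100 (about 1.5) sit near the second. That is consistent with the model first learning which tokens occur and only later learning to copy them; a loss well below 1.386 suggests it is actually using the source sequence.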





In [6]:
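## src and incorrects come from the last generation check (step 900).
## src has shape (1, ENC_SEQ_LEN), so len(src) is the batch size (1); count tokens with numel().
num_tokens = src.numel()
print(f'Accuracy : {(num_tokens - int(incorrects)) / num_tokens * 100}%')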

Accuracy : 100.0%
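
The 100% above describes only the single sample generated at step 900. As an illustration of token-level accuracy on a less perfect sample, the sketch below replays the step-500 check from the training log in plain Python. `token_accuracy` is a hypothetical helper, not part of the notebook.

```python
def token_accuracy(src, sample):
    """Return (mismatch count, percentage of positions that match the source)."""
    incorrects = sum(s != p for s, p in zip(src, sample))
    return incorrects, (len(src) - incorrects) / len(src) * 100

# Values copied from the step-500 output in the training log.
src_500 = [4, 5, 3, 4, 2, 4, 4, 5]
sample_500 = [4, 5, 2, 4, 4, 3, 4, 5]

incorrects, accuracy = token_accuracy(src_500, sample_500)
print(incorrects)  # 3
print(accuracy)    # 62.5
```

The denominator here is the number of tokens (8). For a steadier estimate, the same calculation could be averaged over many generated samples instead of one.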
