## Let's pull in a Transformer architecture and build it!

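Great job making it through the Transformer lecture and the hands-on session! When I first ran into Transformers as a member of the previous cohort, the model structure was confusing and the concepts were hard, so it took me a long time to understand them.. Still, since Transformer models are used practically everywhere in contemporary AI, I tried hard to fill this material with the important parts, even though much of it is difficult. Thanks for all your hard work XD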

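Beyond the concepts themselves, writing attention code from scratch every time and assembling the encoder and decoder can feel daunting! Fortunately, the Transformer architecture is so well known that with just a few lines of code (one `click`) you can download a Transformer structure based on recent papers from `huggingface` or `github` and use it :D In this exercise, we'll take the Transformer structure we built from `x-transformers` (a library that isn't easy for beginners) and go as far as generating example outputs!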

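Since the Transformer architecture came out in 2017, many models have been built on top of it, and many improvements to the base model have been proposed as well. `x-transformers` is a library that lets you apply these improvements by adding just a few lines of code. It is harder to use because you need to follow recent papers to know what its options do, but in return it collects techniques with performance results behind them.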

That TMI ran longer than expected. Anyway, there is nothing for you to write in this assignment; it is set up so that you just run the code I provide as-is. The semester starts soon, so good luck, everyone XD

# Run this in a Colab GPU environment! (The code calls `.cuda()`.)

In [1]:
!pip install x-transformers

Collecting x-transformers
  Downloading x_transformers-1.19.1-py3-none-any.whl (27 kB)
Collecting einops>=0.6.1 (from x-transformers)
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 42.2/42.2 kB 2.0 MB/s eta 0:00:00
Installing collected packages: einops, x-transformers
Successfully installed einops-0.6.1 x-transformers-1.19.1


In [2]:
import torch
from x_transformers import XTransformer

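## Load the Transformer architecture
model = XTransformer(
    ## model (embedding) dimension (paper: d_model = 512)
    dim = 16,
    ## encoder vocabulary size, i.e. number of distinct token ids (x-transformers README example: 256)
    enc_num_tokens = 16,
    ## number of stacked encoder layers (paper: 6)
    enc_depth = 6,
    ## number of heads in multi-head attention (paper: 8)
    enc_heads = 8,
    ## maximum encoder sequence length (README example: 1024)
    enc_max_seq_len = 32,
    ## decoder vocabulary size (README example: 256)
    dec_num_tokens = 16,
    ## number of decoder layers and heads (paper: 6 and 8)
    dec_depth = 6,
    dec_heads = 8,
    ## maximum decoder sequence length (README example: 1024)
    dec_max_seq_len = 32,
    ## share one token embedding between encoder and decoder
    tie_token_emb = True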
).cuda()
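In the original paper each head works in a subspace of size d_k = d_model / h, so 512 dimensions with 8 heads gives 64 per head. As far as I know, x-transformers does not derive the head size from `dim`; it takes a separate `dim_head` argument that defaults to 64 (treat this as an assumption and check the library source for your installed version). If that holds, then with `dim = 16` and 8 heads the attention layers project up to 8 × 64 = 512 internally before projecting back to 16, rather than splitting 16 into 8 heads of 2. The arithmetic, as a plain-Python sketch:

```python
# Per-head size in the paper: d_k = d_model / h
PAPER_DIM, PAPER_HEADS = 512, 8
paper_head_dim = PAPER_DIM // PAPER_HEADS      # 64

# This notebook's tiny model
DIM, HEADS = 16, 8

# Paper-style split of the model dimension across heads
derived_head_dim = DIM // HEADS                # 2

# Assumption: x-transformers keeps a fixed dim_head (default 64) instead of DIM // HEADS
ASSUMED_DIM_HEAD = 64
inner_dim = HEADS * ASSUMED_DIM_HEAD           # width of the q/k/v projections

print(paper_head_dim, derived_head_dim, inner_dim)
```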

In [3]:
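## NUM_BATCHES = number of training steps (one random batch per step); data is generated on the fly, so there are no epochs
## BATCH_SIZE = number of samples in one batch
## LEARNING_RATE = learning rate for Adam
## GENERATE_EVERY = generate from one sample every 100 steps to check accuracy
## NUM_TOKENS = number of token ids used by the data: 4 content ids (2..5) plus ids 0 and 1 (1 is the decoder start token)
## ENC_SEQ_LEN = encoder (source) sequence length
## DEC_SEQ_LEN = decoder target length: start token + two copies of the source (1 + 8 + 8)

NUM_BATCHES = int(1e3)
BATCH_SIZE = 32
LEARNING_RATE = 3e-4
GENERATE_EVERY = 100
NUM_TOKENS = 4 + 2
ENC_SEQ_LEN = 8
DEC_SEQ_LEN = 16 + 1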
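These constants have to agree with the model built above: every token id must be smaller than the vocabulary size (`enc_num_tokens` / `dec_num_tokens` = 16), and every sequence must fit in `enc_max_seq_len` / `dec_max_seq_len` = 32. The data uses ids 2 to 5 for content, id 1 as the decoder start token, and leaves 0 unused, which is where `NUM_TOKENS = 4 + 2` comes from. A plain-Python sanity check of those constraints (the model sizes are copied by hand from the `XTransformer(...)` call):

```python
# Hyperparameters from the notebook
NUM_TOKENS = 4 + 2          # ids 0..5: 0 unused, 1 = start token, 2..5 = content
ENC_SEQ_LEN = 8
DEC_SEQ_LEN = 16 + 1

# Sizes passed to XTransformer(...) above
ENC_NUM_TOKENS = DEC_NUM_TOKENS = 16
ENC_MAX_SEQ_LEN = DEC_MAX_SEQ_LEN = 32

START_TOKEN = 1
data_ids = range(2, NUM_TOKENS)  # same range torch.randint(2, NUM_TOKENS, ...) draws from

# The decoder target is [start] + src + src
assert DEC_SEQ_LEN == 1 + 2 * ENC_SEQ_LEN
# Every id that can appear must fit in the embedding tables
assert max(data_ids) < min(ENC_NUM_TOKENS, DEC_NUM_TOKENS)
assert START_TOKEN < DEC_NUM_TOKENS
# Sequences must fit within the configured maximum lengths
assert ENC_SEQ_LEN <= ENC_MAX_SEQ_LEN and DEC_SEQ_LEN <= DEC_MAX_SEQ_LEN

print(list(data_ids))
```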

In [4]:
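## Generate random src, tgt, src_mask to feed the Transformer model (a copy task)

def cycle():
    while True:
        # start token (id 1) prepended to every decoder target
        prefix = torch.ones((BATCH_SIZE, 1)).long().cuda()
        # random source tokens drawn from ids 2..NUM_TOKENS-1
        src = torch.randint(2, NUM_TOKENS, (BATCH_SIZE, ENC_SEQ_LEN)).long().cuda()
        # target = start token + src + src, so the decoder must reproduce the source twice
        tgt = torch.cat((prefix, src, src), 1)
        # no padding: every encoder position is valid
        src_mask = torch.ones(BATCH_SIZE, src.shape[1]).bool().cuda()
        yield (src, tgt, src_mask)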
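To see what one row of a batch looks like, here is the same layout for a single example, sketched with the standard library instead of `torch` so it runs without a GPU (`make_copy_example` is a hypothetical helper, not part of x-transformers). The last step shows teacher forcing as I understand x-transformers' decoder wrapper to do it: the decoder reads `tgt[:-1]` and is scored on predicting `tgt[1:]`. Treat that internal detail as an assumption.

```python
import random

START_TOKEN = 1
NUM_TOKENS = 4 + 2
ENC_SEQ_LEN = 8

def make_copy_example(rng: random.Random):
    """Hypothetical helper: one (src, tgt, src_mask) row laid out like cycle()."""
    # Content ids come from [2, NUM_TOKENS), like torch.randint(2, NUM_TOKENS, ...)
    src = [rng.randrange(2, NUM_TOKENS) for _ in range(ENC_SEQ_LEN)]
    # Decoder target: start token, then the source repeated twice
    tgt = [START_TOKEN] + src + src
    # No padding, so every encoder position is marked valid
    src_mask = [True] * ENC_SEQ_LEN
    return src, tgt, src_mask

src, tgt, src_mask = make_copy_example(random.Random(0))

# Teacher forcing (assumed): feed everything but the last token,
# predict everything but the first
dec_input, dec_target = tgt[:-1], tgt[1:]

print("src:       ", src)
print("dec_input: ", dec_input)
print("dec_target:", dec_target)
```

The first 8 targets can only be predicted from the encoder's view of `src`, while the last 8 repeat tokens the decoder has already read, so the second copy is the easier half.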

In [5]:
## Train the model

import tqdm

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
data = cycle()  # create the batch generator once and keep drawing from it

## train while updating on the loss
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    src, tgt, src_mask = next(data)

    # the model returns the decoder's training loss directly
    loss = model(src, tgt, mask=src_mask)
    loss.backward()
    print(f'{i}: {loss.item()}')

    optimizer.step()
    optimizer.zero_grad()

    ## every N (100) steps, generate from one sample and count wrong tokens
    if i != 0 and i % GENERATE_EVERY == 0:
        model.eval()
        src, _, src_mask = next(data)
        src, src_mask = src[:1], src_mask[:1]
        start_tokens = torch.ones((1, 1)).long().cuda()  # decoder starts from token 1

        sample = model.generate(src, start_tokens, ENC_SEQ_LEN, mask = src_mask)
        incorrects = (src != sample).sum()  # number of mismatched positions

        print("input:  ", src)
        print("predicted output:  ", sample)
        print(f"incorrects: {incorrects}")

training:   0%|          | 0/1000 [00:00<?, ?it/s]

0: 3.401230573654175
1: 2.501105785369873
2: 2.297044277191162
3: 2.1846704483032227
4: 2.129467248916626
5: 2.1679232120513916
6: 2.087216854095459
7: 2.1017708778381348
8: 2.1031064987182617
9: 2.050971031188965
10: 2.0831520557403564
11: 2.0753841400146484
12: 2.022801399230957
13: 2.0084080696105957
14: 2.0599286556243896
15: 2.0332210063934326
16: 2.0613362789154053
17: 2.041513681411743
18: 2.0077953338623047
19: 2.021064519882202
20: 1.9977418184280396
21: 1.9609863758087158
22: 1.9783103466033936
23: 1.9388624429702759
24: 1.9403854608535767
25: 2.0045292377471924
26: 1.9563367366790771
27: 1.9620602130889893
28: 1.9475752115249634
29: 1.9410171508789062
30: 1.9085379838943481
31: 1.909257411956787
32: 1.8879634141921997
33: 1.8992094993591309
34: 1.8715054988861084
35: 1.9083049297332764
36: 1.882127046585083
37: 1.8411362171173096
38: 1.8801289796829224
39: 1.8911769390106201
40: 1.8719515800476074
41: 1.8532606363296509
42: 1.8339948654174805
43: 1.8774776458740234
44: 1.779

training:   7%|▋         | 70/1000 [00:10<02:13,  6.98it/s]

68: 1.773900032043457
69: 1.7292231321334839
70: 1.7478266954421997
71: 1.7454346418380737
72: 1.7269299030303955
73: 1.6844974756240845
74: 1.7509827613830566
75: 1.7369378805160522
76: 1.7060917615890503
77: 1.7143645286560059
78: 1.7413281202316284
79: 1.6951066255569458
80: 1.70162832736969
81: 1.7380344867706299
82: 1.7289066314697266
83: 1.687938928604126
84: 1.6519581079483032
85: 1.6598811149597168
86: 1.6467138528823853
87: 1.67922043800354
88: 1.6297913789749146
89: 1.6798005104064941
90: 1.6523131132125854
91: 1.6861127614974976
92: 1.6443918943405151
93: 1.6947553157806396
94: 1.6525758504867554
95: 1.6823093891143799
96: 1.6445993185043335
97: 1.6499935388565063
98: 1.6348520517349243
99: 1.6436207294464111
100: 1.6597812175750732
input:   tensor([[2, 5, 4, 2, 2, 3, 3, 3]], device='cuda:0')
predicted output:   tensor([[3, 5, 2, 2, 3, 2, 2, 2]], device='cuda:0')
incorrects: 6
101: 1.6430368423461914
102: 1.6093084812164307
103: 1.6468071937561035
104: 1.6263577938079834
105

training:  26%|██▌       | 256/1000 [00:20<00:54, 13.78it/s]

255: 1.073627233505249
256: 1.0953989028930664
257: 1.05478036403656
258: 1.0248057842254639
259: 1.0761349201202393
260: 1.105521559715271
261: 1.1185745000839233
262: 1.049312710762024
263: 1.1157252788543701
264: 1.0609345436096191
265: 1.0596975088119507
266: 1.057810664176941
267: 1.0844955444335938
268: 1.0535106658935547
269: 1.0609862804412842
270: 1.0604703426361084
271: 1.0180641412734985
272: 1.0786901712417603
273: 1.031104326248169
274: 1.0650688409805298
275: 1.0364155769348145
276: 1.003841757774353
277: 1.07004976272583
278: 1.0187362432479858
279: 1.0586304664611816
280: 1.0625107288360596
281: 1.0498508214950562
282: 1.006453275680542
283: 1.0594245195388794
284: 1.0526318550109863
285: 1.0124506950378418
286: 1.0291630029678345
287: 1.0335675477981567
288: 1.0132248401641846
289: 1.0444321632385254
290: 0.9814149141311646
291: 0.9597742557525635
292: 1.0303080081939697
293: 0.9731603860855103
294: 1.0139151811599731
295: 1.0369181632995605
296: 1.0042468309402466
297

training:  45%|████▌     | 451/1000 [00:30<00:33, 16.36it/s]

448: 0.645434558391571
449: 0.6029922962188721
450: 0.6284847855567932
451: 0.6422280669212341
452: 0.7015345692634583
453: 0.5959850549697876
454: 0.6938579678535461
455: 0.629040002822876
456: 0.6348475813865662
457: 0.6208232045173645
458: 0.6821888089179993
459: 0.6241246461868286
460: 0.7374354600906372
461: 0.6307433247566223
462: 0.6618214249610901
463: 0.6515529751777649
464: 0.6544589996337891
465: 0.6120071411132812
466: 0.6452041864395142
467: 0.6638752222061157
468: 0.5989445447921753
469: 0.6230584383010864
470: 0.6899414658546448
471: 0.6993363499641418
472: 0.668201208114624
473: 0.6699269413948059
474: 0.6385456919670105
475: 0.6635546088218689
476: 0.6356886625289917
477: 0.6287006735801697
478: 0.6716015934944153
479: 0.6714935302734375
480: 0.6050931811332703
481: 0.6324937343597412
482: 0.6313937306404114
483: 0.6149828433990479
484: 0.614185094833374
485: 0.6166908740997314
486: 0.6269984245300293
487: 0.6151638031005859
488: 0.6150819659233093
489: 0.5626185536384

training:  66%|██████▌   | 655/1000 [00:40<00:19, 17.95it/s]

651: 0.467553973197937
652: 0.4648176431655884
653: 0.4911976456642151
654: 0.45389798283576965
655: 0.45835691690444946
656: 0.505477249622345
657: 0.48642146587371826
658: 0.45615077018737793
659: 0.5351996421813965
660: 0.4683299958705902
661: 0.43910929560661316
662: 0.49235665798187256
663: 0.46226003766059875
664: 0.4842638671398163
665: 0.4743165969848633
666: 0.4839310944080353
667: 0.5112587809562683
668: 0.4503178298473358
669: 0.47240570187568665
670: 0.48710986971855164
671: 0.48020631074905396
672: 0.456676721572876
673: 0.47096365690231323
674: 0.44144803285598755
675: 0.4869357645511627
676: 0.4372209310531616
677: 0.45522797107696533
678: 0.46531882882118225
679: 0.44899508357048035
680: 0.44952890276908875
681: 0.4392344057559967
682: 0.4539485573768616
683: 0.46850478649139404
684: 0.44622400403022766
685: 0.4736514687538147
686: 0.44495731592178345
687: 0.43023309111595154
688: 0.46190372109413147
689: 0.4273545742034912
690: 0.4472697079181671
691: 0.434479206800460

training:  86%|████████▌ | 859/1000 [00:50<00:07, 18.58it/s]

855: 0.494182825088501
856: 0.37621432542800903
857: 0.3832762539386749
858: 0.40720728039741516
859: 0.4293483793735504
860: 0.39243414998054504
861: 0.441006064414978
862: 0.4348162114620209
863: 0.4178259074687958
864: 0.3843914866447449
865: 0.42851853370666504
866: 0.408882737159729
867: 0.3982895016670227
868: 0.4129410684108734
869: 0.4607236087322235
870: 0.3751091957092285
871: 0.43847423791885376
872: 0.4013316035270691
873: 0.44240352511405945
874: 0.43345707654953003
875: 0.4477173388004303
876: 0.5483145117759705
877: 0.41919612884521484
878: 0.5067286491394043
879: 0.41725337505340576
880: 0.4101925492286682
881: 0.4702812433242798
882: 0.5094444155693054
883: 0.4366513192653656
884: 0.4128284156322479
885: 0.4488702118396759
886: 0.47367534041404724
887: 0.46816638112068176
888: 0.3922692835330963
889: 0.45860862731933594
890: 0.4955599009990692
891: 0.4354282319545746
892: 0.4954681396484375
893: 0.4183166027069092
894: 0.4706358313560486
895: 0.44970810413360596
896: 0

training: 100%|██████████| 1000/1000 [00:57<00:00, 17.28it/s]

996: 0.3308544158935547
997: 0.34540706872940063
998: 0.31389281153678894
999: 0.3339492082595825
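A few reference values help read this log. As I understand it, the printed loss is a mean cross-entropy (in nats) over the 16 predicted decoder positions. Under that assumption, and these are back-of-the-envelope numbers rather than anything the library reports: a model spreading probability evenly over all 16 vocabulary ids scores ln 16; one that has only learned that content ids are 2 to 5 scores ln 4; one that also copies the second half exactly but still guesses the first half uniformly scores (8 × ln 4 + 8 × 0) / 16.

```python
import math

VOCAB = 16        # dec_num_tokens in the XTransformer(...) call
DATA_VALUES = 4   # content ids 2..5
PREDICTED = 16    # len(tgt) - 1 positions are scored
FIRST_HALF = 8    # targets that can only come from the encoder

# Cross-entropy (nats) of a uniform guess over k options is ln(k)
uniform_vocab = math.log(VOCAB)
uniform_data = math.log(DATA_VALUES)
# First copy guessed uniformly over 4 ids, second copy predicted perfectly (loss 0)
copy_second_half = FIRST_HALF * math.log(DATA_VALUES) / PREDICTED

for name, value in [("uniform over vocabulary", uniform_vocab),
                    ("uniform over content ids", uniform_data),
                    ("2nd copy exact, 1st guessed", copy_second_half)]:
    print(f"{name:28s} {value:.2f}")
```

The final values of roughly 0.31 to 0.35 sit below 0.69. Because the first copy cannot be predicted better than ln 4 per token without looking at the source, an average that low suggests the model has started using the encoder, though the per-sample accuracy check is the more direct evidence.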





In [6]:
print(f'Accuracy : {(len(src) - incorrects)/len(src)*100}%')

Accuracy : 81.25%
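A caveat on this number: after the loop, `src` holds a full training batch of shape `(BATCH_SIZE, ENC_SEQ_LEN)`, so `len(src)` is 32, while `incorrects` counts mismatches among the `ENC_SEQ_LEN = 8` generated tokens of the single sample checked at step 900. The 81.25% above is therefore (32 - 6) / 32, which implies 6 of those 8 tokens were wrong, i.e. 25% per-token accuracy on that sample. A per-token accuracy should divide by the number of compared tokens; a plain-Python sketch, checked against the step-100 sample from the log (`token_accuracy` is a hypothetical helper):

```python
def token_accuracy(target, predicted):
    """Hypothetical helper: percentage of positions where predicted matches target."""
    if len(target) != len(predicted):
        raise ValueError("sequences must have the same length")
    correct = sum(t == p for t, p in zip(target, predicted))
    return correct / len(target) * 100

# Step-100 check from the training log
target = [2, 5, 4, 2, 2, 3, 3, 3]
predicted = [3, 5, 2, 2, 3, 2, 2, 2]
print(token_accuracy(target, predicted))   # 2 of 8 tokens correct

# What the notebook's formula computes when 6 tokens are wrong and len(src) == 32
BATCH_SIZE, incorrects = 32, 6
print((BATCH_SIZE - incorrects) / BATCH_SIZE * 100)
```

Inside the notebook, the equivalent would be `(sample.numel() - incorrects.item()) / sample.numel() * 100`; `sample` is not reassigned by the training steps, so it still holds the last generated sequence.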
