## Let's load and build a Transformers architecture!

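Great work getting through the Transformers lecture and lab! I first met transformers as a student in the previous cohort. The model's structure was hard to follow and its concepts were difficult, so it took me a long time to understand... Still, transformer models show up almost everywhere in contemporary AI technology. So even though much of the material is hard, I tried to fill it with the content that matters most. Great job, everyone XD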

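The concepts are one challenge. Writing attention code from scratch every time and building the encoder and decoder structure can also feel overwhelming! Fortunately, the transformers architecture is so well known that a few lines of code (practically a single click) let you pull a transformer structure based on recent papers from `huggingface` or `github` and use it :D In this lab we'll load the kind of transformer structure we built in class from `x-transformers`, a library that isn't easy for beginners, and finish by generating an example output sequence!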
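For reference, the core computation that libraries like `x-transformers` package for you is scaled dot-product attention, softmax(QKᵀ/√d_k)·V. Below is a minimal single-head sketch in plain Python. It has no batching, masking, or multi-head projections, and the function names are our own rather than any library's API.

```python
import math

def softmax(scores):
    # Subtract the max before exponentiating for numerical stability
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def scaled_dot_product_attention(queries, keys, values):
    """softmax(Q K^T / sqrt(d_k)) V on nested lists.

    queries: n_q x d_k, keys: n_k x d_k, values: n_k x d_v
    """
    d_k = len(keys[0])
    d_v = len(values[0])
    outputs = []
    for q in queries:
        # Similarity of this query with every key, scaled by sqrt(d_k)
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in keys]
        weights = softmax(scores)
        # Output = attention-weighted sum of the value vectors
        outputs.append([sum(w * v[j] for w, v in zip(weights, values)) for j in range(d_v)])
    return outputs

out = scaled_dot_product_attention(
    queries=[[1.0, 0.0]],
    keys=[[1.0, 0.0], [1.0, 0.0]],
    values=[[2.0, 4.0], [4.0, 8.0]],
)
print(out)  # identical keys -> equal weights -> average of the values: [[3.0, 6.0]]
```

Because the two keys are identical, each gets weight 0.5, so the output is simply the mean of the two value vectors.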

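Since the transformers architecture appeared in 2017, many models have been built on it, and many improvements to the base model have been proposed. `x-transformers` is a library that lets you apply these improvements by adding just a few lines of code. It can be tough to use because you need to keep up with recent papers, but in return it collects code whose performance backs it up.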

That TMI ran longer than I planned, but anyway: this assignment has no parts for you to fill in. Just run the code I provide as-is. The semester starts soon, so good luck, everyone XD

# Run this in a Colab GPU environment!

In [None]:
!pip install x-transformers

Collecting x-transformers
  Downloading x_transformers-1.19.1-py3-none-any.whl (27 kB)
Collecting einops>=0.6.1 (from x-transformers)
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 42.2/42.2 kB 2.5 MB/s eta 0:00:00
Installing collected packages: einops, x-transformers
Successfully installed einops-0.6.1 x-transformers-1.19.1


In [None]:
import torch
from x_transformers import XTransformer

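## Load the Transformers architecture
model = XTransformer(
    ## model (embedding) dimension (paper = 512)
    dim = 16,
    ## encoder vocabulary size, i.e. number of distinct token ids (x-transformers README example = 256)
    enc_num_tokens = 16,
    ## number of stacked encoder layers (paper = 6)
    enc_depth = 6,
    ## number of heads in multi-head attention (paper = 8)
    enc_heads = 8,
    ## maximum encoder sequence length (x-transformers README example = 1024)
    enc_max_seq_len = 32,
    ## decoder vocabulary size (x-transformers README example = 256)
    dec_num_tokens = 16,
    ## number of stacked decoder layers and attention heads (paper = 6 and 8)
    dec_depth = 6,
    dec_heads = 8,
    ## maximum decoder sequence length (x-transformers README example = 1024)
    dec_max_seq_len = 32,
    ## share the token embedding between encoder and decoder
    tie_token_emb = True
).cuda()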
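`XTransformer` takes the encoder and decoder settings as one flat set of keyword arguments, distinguished by the `enc_` and `dec_` prefixes. As far as we understand the library, it strips those prefixes and forwards the remaining names (`depth`, `heads`, ...) to its encoder and decoder stacks, while un-prefixed arguments such as `dim` apply to both. The sketch below only illustrates that prefix-splitting idea in plain Python. `split_by_prefix` is a hypothetical helper, not the library's internal function.

```python
def split_by_prefix(prefix, kwargs):
    """Hypothetical helper: return (keys starting with `prefix`, prefix stripped), (all other keys)."""
    matched = {k[len(prefix):]: v for k, v in kwargs.items() if k.startswith(prefix)}
    rest = {k: v for k, v in kwargs.items() if not k.startswith(prefix)}
    return matched, rest

# The same settings as the model cell above
config = dict(
    dim=16, tie_token_emb=True,
    enc_num_tokens=16, enc_depth=6, enc_heads=8, enc_max_seq_len=32,
    dec_num_tokens=16, dec_depth=6, dec_heads=8, dec_max_seq_len=32,
)

enc_kwargs, rest = split_by_prefix("enc_", config)
dec_kwargs, shared = split_by_prefix("dec_", rest)

print(enc_kwargs)  # {'num_tokens': 16, 'depth': 6, 'heads': 8, 'max_seq_len': 32}
print(shared)      # {'dim': 16, 'tie_token_emb': True}
```

This naming scheme is why enabling an extra feature on only the encoder or only the decoder is a matter of adding one prefixed keyword.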

In [None]:
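## NUM_BATCHES = number of training steps; each step draws a fresh random batch (these are not epochs)
## BATCH_SIZE = number of samples in one batch
## LEARNING_RATE = learning rate
## GENERATE_EVERY = every 100 steps, generate a sample and count how many tokens it gets wrong
## NUM_TOKENS = token ids used by the data: 4 data tokens (2..5) + 2 ids the data never draws (0, and 1 = start token)
## ENC_SEQ_LEN = encoder sequence length
## DEC_SEQ_LEN = decoder target length: start token + two copies of src (1 + 2 * ENC_SEQ_LEN), defined for reference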

NUM_BATCHES = int(1e3)
BATCH_SIZE = 32
LEARNING_RATE = 3e-4
GENERATE_EVERY  = 100
NUM_TOKENS = 4 + 2
ENC_SEQ_LEN = 8
DEC_SEQ_LEN = 16 + 1
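The `4 + 2` and `16 + 1` in this cell follow from how the data is built. The sketch below makes that arithmetic explicit and checks it against the model settings. The token-id layout is our reading of `cycle()`, and the constant names `START_TOKEN`, `MODEL_NUM_TOKENS`, and `MODEL_MAX_SEQ_LEN` are introduced here for illustration only.

```python
ENC_SEQ_LEN = 8
NUM_TOKENS = 4 + 2

# Assumed id layout, read off the data generator:
# 0 is never drawn, 1 is the start (prefix) token, 2..NUM_TOKENS-1 are data tokens
START_TOKEN = 1
data_token_ids = list(range(2, NUM_TOKENS))

# The target is [start] + src + src, so its length is 1 + 2 * ENC_SEQ_LEN
DEC_SEQ_LEN = 1 + 2 * ENC_SEQ_LEN

# Values from the model cell (hypothetical names for this check)
MODEL_NUM_TOKENS = 16   # enc_num_tokens / dec_num_tokens
MODEL_MAX_SEQ_LEN = 32  # dec_max_seq_len

assert max(data_token_ids) < MODEL_NUM_TOKENS  # every id has an embedding row
assert DEC_SEQ_LEN <= MODEL_MAX_SEQ_LEN        # the target fits in the decoder

print(data_token_ids)  # [2, 3, 4, 5]
print(DEC_SEQ_LEN)     # 17, i.e. 16 + 1
```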

In [None]:
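## Generate random src, tgt, src_mask to feed into the Transformer model (a copy task)

def cycle():
    while True:
        ## start token (id 1) prepended to every target
        prefix = torch.ones((BATCH_SIZE, 1)).long().cuda()
        ## random data tokens drawn from ids 2 .. NUM_TOKENS - 1
        src = torch.randint(2, NUM_TOKENS, (BATCH_SIZE, ENC_SEQ_LEN)).long().cuda()
        ## target = [start] + src + src: the decoder learns to reproduce the input twice
        tgt = torch.cat((prefix, src, src), 1)
        ## no padding, so every source position is valid
        src_mask = torch.ones(BATCH_SIZE, src.shape[1]).bool().cuda()
        yield (src, tgt, src_mask)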
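To see what one row of a `cycle()` batch looks like, here is a plain-Python sketch (using `random` instead of `torch`). It is followed by the input/label shift that an autoregressive decoder is trained with (teacher forcing). Our reading of `x-transformers` is that its autoregressive wrapper performs an equivalent `[:-1]` / `[1:]` split internally when `model(src, tgt, ...)` is called. Treat that as an assumption, not a documented guarantee. `make_example` is a hypothetical helper.

```python
import random

ENC_SEQ_LEN = 8
NUM_TOKENS = 4 + 2
START_TOKEN = 1

def make_example(rng):
    # Same recipe as cycle(), for a single row: data tokens drawn from [2, NUM_TOKENS)
    src = [rng.randrange(2, NUM_TOKENS) for _ in range(ENC_SEQ_LEN)]
    tgt = [START_TOKEN] + src + src
    src_mask = [True] * len(src)  # no padding, every position is attended to
    return src, tgt, src_mask

src, tgt, src_mask = make_example(random.Random(0))

# Teacher forcing: the decoder reads tgt[:-1] and is trained to predict tgt[1:]
dec_input, labels = tgt[:-1], tgt[1:]

print(len(tgt), len(dec_input), len(labels))  # 17 16 16
print(labels == src + src)                     # True
```

The labels are exactly two copies of `src`. The encoder sees `src` directly, so every label is in principle predictable, and the loss can keep falling well below chance.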

In [None]:
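## Train Model

import tqdm
import torch

## Adam optimizer over all model parameters
## (named `optimizer` so it does not shadow the `torch.optim` module)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

## Create the data generator once and keep drawing fresh random batches from it
data_iter = cycle()

## Train by updating on the loss at every step
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    src, tgt, src_mask = next(data_iter)

    ## Given a target sequence, the model returns the decoder's cross-entropy loss
    loss = model(src, tgt, mask=src_mask)
    loss.backward()
    print(f'{i}: {loss.item()}')

    optimizer.step()
    optimizer.zero_grad()

    ## Every GENERATE_EVERY (100) steps, generate a sample and count the wrong tokens
    if i != 0 and i % GENERATE_EVERY == 0:
        model.eval()
        src, _, src_mask = next(data_iter)
        ## use only the first sequence of the batch
        src, src_mask = src[:1], src_mask[:1]
        ## decoding starts from the start token (id 1)
        start_tokens = torch.ones((1, 1)).long().cuda()

        sample = model.generate(src, start_tokens, ENC_SEQ_LEN, mask = src_mask)
        incorrects = (src != sample).sum()

        print("input:  ", src)
        print("predicted output:  ", sample)
        print(f"incorrects: {incorrects}")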

training:   0%|          | 0/1000 [00:00<?, ?it/s]

0: 3.0212855339050293
1: 2.3642079830169678
2: 2.226148843765259
3: 2.1935482025146484
4: 2.171018123626709
5: 2.1506733894348145
6: 2.1315553188323975
7: 2.1253483295440674
8: 2.1182100772857666
9: 2.1025984287261963
10: 2.074812412261963
11: 2.0844385623931885
12: 2.0761375427246094
13: 2.0722482204437256
14: 2.0744636058807373
15: 2.056673526763916
16: 2.057690382003784
17: 2.0289928913116455
18: 2.020368814468384
19: 2.0188047885894775
20: 1.9886363744735718
21: 1.9948275089263916
22: 1.996999740600586
23: 1.9949848651885986
24: 1.9904048442840576
25: 1.9658087491989136
26: 1.9825024604797363
27: 1.9734437465667725
28: 1.9263211488723755
29: 1.9595569372177124
30: 1.934654951095581
31: 1.9867656230926514
32: 1.9428116083145142
33: 1.9454864263534546
34: 1.972327470779419
35: 1.929093599319458
36: 1.9355987310409546
37: 1.9030027389526367
38: 1.953557014465332
39: 1.9191949367523193
40: 1.9304624795913696
41: 1.9284189939498901
42: 1.864706039428711
43: 1.9268651008605957
44: 1.9384

training:   8%|▊         | 79/1000 [00:10<02:00,  7.66it/s]

78: 1.7495452165603638
79: 1.7518258094787598
80: 1.7339065074920654
81: 1.7065961360931396
82: 1.7074899673461914
83: 1.7526147365570068
84: 1.7137277126312256
85: 1.7124789953231812
86: 1.7167845964431763
87: 1.6901369094848633
88: 1.7295829057693481
89: 1.708381175994873
90: 1.6889915466308594
91: 1.6982662677764893
92: 1.717208981513977
93: 1.6952810287475586
94: 1.6834042072296143
95: 1.6823891401290894
96: 1.7010376453399658
97: 1.6360135078430176
98: 1.6470218896865845
99: 1.6673493385314941
100: 1.6687783002853394
input:   tensor([[3, 3, 4, 2, 2, 3, 3, 5]], device='cuda:0')
predicted output:   tensor([[2, 3, 2, 3, 2, 3, 2, 3]], device='cuda:0')
incorrects: 5
101: 1.6499775648117065
102: 1.6447473764419556
103: 1.6662211418151855
104: 1.6411222219467163
105: 1.640832543373108
106: 1.6290967464447021
107: 1.606196641921997
108: 1.6623011827468872
109: 1.6375300884246826
110: 1.6028977632522583
111: 1.6187576055526733
112: 1.617673635482788
113: 1.6427037715911865
114: 1.622441768

training:  26%|██▋       | 263/1000 [00:20<00:53, 13.87it/s]

258: 1.1341097354888916
259: 1.087375283241272
260: 1.1187375783920288
261: 1.1263916492462158
262: 1.0948294401168823
263: 1.123265027999878
264: 1.1567891836166382
265: 1.1104481220245361
266: 1.1292226314544678
267: 1.1714422702789307
268: 1.1225603818893433
269: 1.0781569480895996
270: 1.083160400390625
271: 1.0911201238632202
272: 1.0905139446258545
273: 1.0911593437194824
274: 1.1191331148147583
275: 1.1032209396362305
276: 1.1339524984359741
277: 1.1011319160461426
278: 1.0869920253753662
279: 1.0875697135925293
280: 1.1165754795074463
281: 1.09503173828125
282: 1.1488412618637085
283: 1.0706465244293213
284: 1.1319003105163574
285: 1.0680292844772339
286: 1.0835278034210205
287: 1.1535279750823975
288: 1.0656384229660034
289: 1.0554792881011963
290: 1.0790824890136719
291: 1.1120723485946655
292: 1.0887848138809204
293: 1.0022807121276855
294: 1.1188815832138062
295: 1.0493847131729126
296: 1.1038464307785034
297: 1.0533276796340942
298: 1.0766462087631226
299: 1.02693557739257

training:  45%|████▌     | 453/1000 [00:30<00:33, 16.16it/s]

450: 0.7753773331642151
451: 0.7299436330795288
452: 0.7503994703292847
453: 0.7273203134536743
454: 0.7311460375785828
455: 0.7535455226898193
456: 0.695065438747406
457: 0.7438489198684692
458: 0.7301526069641113
459: 0.7418602705001831
460: 0.7682793736457825
461: 0.7321299314498901
462: 0.7227210402488708
463: 0.7009592056274414
464: 0.7521753311157227
465: 0.7574623823165894
466: 0.7247949242591858
467: 0.6967576146125793
468: 0.7271579504013062
469: 0.7085175514221191
470: 0.7209762930870056
471: 0.7194848656654358
472: 0.6749071478843689
473: 0.7155766487121582
474: 0.7003940939903259
475: 0.681584358215332
476: 0.688069760799408
477: 0.6895302534103394
478: 0.7105495929718018
479: 0.6926348805427551
480: 0.726351261138916
481: 0.684346616268158
482: 0.6735029220581055
483: 0.6605088114738464
484: 0.6704211831092834
485: 0.7158555388450623
486: 0.6592168211936951
487: 0.6471205353736877
488: 0.6793718934059143
489: 0.6960662007331848
490: 0.6914721727371216
491: 0.68412786722183

training:  64%|██████▍   | 643/1000 [00:40<00:20, 17.20it/s]

641: 0.5085007548332214
642: 0.5113397836685181
643: 0.5400828123092651
644: 0.5032861828804016
645: 0.5674381256103516
646: 0.5026510953903198
647: 0.5229494571685791
648: 0.540281355381012
649: 0.5416291952133179
650: 0.5478782653808594
651: 0.50193852186203
652: 0.5187769532203674
653: 0.5251586437225342
654: 0.5538427233695984
655: 0.500501811504364
656: 0.507365882396698
657: 0.5083209872245789
658: 0.4939287006855011
659: 0.5012441277503967
660: 0.5153346657752991
661: 0.5115206241607666
662: 0.5035040378570557
663: 0.4757844805717468
664: 0.5107600092887878
665: 0.49750816822052
666: 0.5474876761436462
667: 0.5158712267875671
668: 0.561199426651001
669: 0.46740126609802246
670: 0.4813959002494812
671: 0.5089186429977417
672: 0.5120235085487366
673: 0.5758919715881348
674: 0.47599637508392334
675: 0.5200101137161255
676: 0.5116533041000366
677: 0.5175460577011108
678: 0.5217621326446533
679: 0.5327814221382141
680: 0.5231935977935791
681: 0.5220983028411865
682: 0.504936277866363

training:  83%|████████▎ | 833/1000 [00:50<00:09, 17.82it/s]

828: 0.423071026802063
829: 0.4308679699897766
830: 0.4299837052822113
831: 0.4007757008075714
832: 0.4096335768699646
833: 0.4140881896018982
834: 0.4615592360496521
835: 0.38446998596191406
836: 0.4055885076522827
837: 0.3894278109073639
838: 0.4087826907634735
839: 0.39882782101631165
840: 0.40944361686706543
841: 0.39672911167144775
842: 0.37637221813201904
843: 0.40572071075439453
844: 0.3615206778049469
845: 0.43736541271209717
846: 0.40118512511253357
847: 0.411010205745697
848: 0.4046541750431061
849: 0.3974030911922455
850: 0.4172113239765167
851: 0.435287207365036
852: 0.39845699071884155
853: 0.44646766781806946
854: 0.45541077852249146
855: 0.43845006823539734
856: 0.49923089146614075
857: 0.39459678530693054
858: 0.5189737677574158
859: 0.47317248582839966
860: 0.4984641373157501
861: 0.5309951901435852
862: 0.4126749634742737
863: 0.5054384469985962
864: 0.44112467765808105
865: 0.5054055452346802
866: 0.4307250380516052
867: 0.5257443189620972
868: 0.4075700640678406
869

training: 100%|██████████| 1000/1000 [00:58<00:00, 17.03it/s]

999: 0.4109896719455719
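A quick way to read these loss values: the training loss here is a cross-entropy measured in nats (natural log). A prediction that spreads probability evenly over k tokens scores ln k. The sketch below computes two reference points for this setup, assuming a 16-entry decoder vocabulary and labels that are always one of the 4 data tokens 2..5.

```python
import math

DEC_VOCAB = 16       # dec_num_tokens in the model cell
NUM_DATA_TOKENS = 4  # every label is one of the ids 2..5

# Cross-entropy (nats) of a prediction that is uniform over k classes is ln(k)
uniform_over_vocab = math.log(DEC_VOCAB)
uniform_over_data = math.log(NUM_DATA_TOKENS)

print(f"{uniform_over_vocab:.3f}")  # 2.773
print(f"{uniform_over_data:.3f}")   # 1.386
```

The log starts at 3.02, in the neighborhood of the first reference. It ends near 0.41, well below ln 4 ≈ 1.39, so by the end the model is doing much better than guessing among the four data tokens.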





In [None]:
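## Note: after the training loop, `src` is the last training batch (BATCH_SIZE = 32 rows),
## while `incorrects` counts mismatched tokens in the 8-token sample from the last check (step 900).
## This therefore divides by 32 sequences rather than 8 tokens, and it overstates per-token accuracy.
print(f'Accuracy : {(len(src) - incorrects)/len(src)*100}%')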

Accuracy : 93.75%
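A per-token accuracy divides the number of mismatches by the length of the generated sequence (8 tokens here), not by `len(src)`. Here is a minimal plain-Python sketch. `token_accuracy` is our own helper name. Inside the notebook's generation block, the same quantity could be computed on the tensors as `(sample == src).float().mean() * 100`.

```python
def token_accuracy(expected, predicted):
    """Percentage of positions where the generated token matches the input token."""
    if len(expected) != len(predicted):
        raise ValueError("sequences must have the same length")
    incorrect = sum(e != p for e, p in zip(expected, predicted))
    return (len(expected) - incorrect) / len(expected) * 100

# The step-100 check from the training log above: 5 of the 8 tokens were wrong
print(token_accuracy([3, 3, 4, 2, 2, 3, 3, 5], [2, 3, 2, 3, 2, 3, 2, 3]))  # 37.5
```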
