## Transformers architecture를 받아와 구축해보자!

Transformers 강의 + 실습까지 수강하시느라 너무 수고 많으셨습니다! 저도 저번 기수 당시 transformers를 처음 접했는데, 모델의 구조가 난해하고 쓰인 개념들도 어려워서 이해하는데 많은 시간이 걸렸었네요.. 그래도 transformers 모델이 현재 contemporary AI technology에 쓰이지 않는 곳이 없다보니, 어려운 내용들이 다수 있지만 힘주어서 중요한 내용들로 채워넣으려고 노력했습니다 수고하셨습니다 XD

Transformers은 개념도 개념이지만, 매번 새롭게 attention 코드를 짜고, encoder와 decoder 구조를 구축하는 것도 막막하실 겁니다! 다행히도 transformers architecture는 워낙 유명해서 이제 코드 몇줄만 `딸깍`해도 최신 논문기반 transformer 구조를 `huggingface` 혹은 `github`에서 받아서 사용할 수 있습니다 :D 이번 실습에는 초심자가 사용하기는 어렵지만 `x-transformers`에서 저희가 구축한 transformers 구조를 받아와, 예시 문장을 출력하는 것까지 마무리할 예정입니다!

transformers architecture가 2017년 나온 이후로, 이 architecture를 기반으로 한 여러 모델들이 나왔고, 또한 base model에 대해서도 여러 개선점들이 추가되었습니다. `x-transformers`은 이 개선된 model들을 코드 몇줄을 추가해서 적용 가능하게 하는 라이브러리로, 최신 논문 동향을 파악하고 있어야 한다는 점에서 어렵지만 그만큼 성능이 뒷받침해주는 코드들을 모아둔 라이브러리입니다.

TMI가 생각보다 길어졌는데, 아무튼 이번 과제는 따로 작성해 넣어야할 부분은 없고, 제가 드리는 코드를 그대로 실행하기만 하면 되는 과정으로 추가했습니다. 이제 개강이 얼마 남지 않았는데, 다들 화이팅입니다 XD

# Colab GPU 환경에서 구동하세요!

In [1]:
!pip install x-transformers

Collecting x-transformers
  Downloading x_transformers-1.19.1-py3-none-any.whl (27 kB)
Collecting einops>=0.6.1 (from x-transformers)
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: einops, x-transformers
Successfully installed einops-0.6.1 x-transformers-1.19.1


In [2]:
import torch
from x_transformers import XTransformer

## Transformers Architectures 받아오기
model = XTransformer(
    ## 모델의 차원 (논문 = 512)
    dim = 16,
    ## encoder token의 개수 (논문 = 256)
    enc_num_tokens = 16,
    ## encoder 반복 횟수 (논문 = 6)
    enc_depth = 6,
    ## multihead attention n_heads 개수 (논문 = 8)
    enc_heads = 8,
    ## encoder token의 max sequence length (논문 = 1024)
    enc_max_seq_len = 32,
    ## (논문 = 256)
    dec_num_tokens = 16,
    dec_depth = 6,
    dec_heads = 8,
    ## (논문 = 1024)
    dec_max_seq_len = 32,
    tie_token_emb = True
).cuda()

In [3]:
## NUM_BATCHES = Epoches의 개수
## BATCH_SIZE = 하나의 batch에 들어갈 sample의 개수
## LEARNING_RATE = learning rate
## GENERATE_EVERY = 100번마다 한번씩 generate해서 accuracy확인
## NUM_TOKENS = 데이터 내 유니크한 토큰의 수
## ENC_SEQ_LEN = encoder sequence length

NUM_BATCHES = int(1e3)
BATCH_SIZE = 32
LEARNING_RATE = 3e-4
GENERATE_EVERY  = 100
NUM_TOKENS = 4 + 2
ENC_SEQ_LEN = 8
DEC_SEQ_LEN = 16 + 1

In [4]:
## Transformer 모델에 넣을 임의의 src, tgt, src_mask 생성

def cycle():
    while True:
        prefix = torch.ones((BATCH_SIZE, 1)).long().cuda()
        src = torch.randint(2, NUM_TOKENS, (BATCH_SIZE, ENC_SEQ_LEN)).long().cuda()
        tgt = torch.cat((prefix, src, src), 1)
        src_mask = torch.ones(BATCH_SIZE, src.shape[1]).bool().cuda()
        yield (src, tgt, src_mask)

In [5]:
## Train Model

import tqdm
import torch.optim as optim

optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

## loss update 해가면서 학습
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    src, tgt, src_mask = next(cycle())

    loss = model(src, tgt, mask=src_mask)
    loss.backward()
    print(f'{i}: {loss.item()}')

    optim.step()
    optim.zero_grad()

    ## 매 N(100)번마다 accuracy 측정
    if i != 0 and i % GENERATE_EVERY == 0:
        model.eval()
        src, _, src_mask = next(cycle())
        src, src_mask = src[:1], src_mask[:1]
        start_tokens = (torch.ones((1, 1)) * 1).long().cuda()

        sample = model.generate(src, start_tokens, ENC_SEQ_LEN, mask = src_mask)
        incorrects = (src != sample).abs().sum()

        print(f"input:  ", src)
        print(f"predicted output:  ", sample)
        print(f"incorrects: {incorrects}")

training:   0%|          | 0/1000 [00:00<?, ?it/s]

0: 2.914153575897217
1: 2.4718339443206787
2: 2.2922163009643555
3: 2.200052261352539
4: 2.1611337661743164
5: 2.1652488708496094
6: 2.1505610942840576
7: 2.084855794906616
8: 2.0654168128967285
9: 2.095949649810791
10: 2.0796775817871094
11: 2.0561516284942627
12: 2.069959878921509
13: 2.0482749938964844
14: 2.0365278720855713
15: 2.049649953842163
16: 2.022548198699951
17: 2.0343170166015625
18: 1.9980446100234985
19: 1.9868320226669312
20: 1.9652010202407837
21: 1.9928710460662842
22: 2.010746479034424
23: 2.010732412338257
24: 1.977786898612976
25: 1.9896687269210815
26: 1.9711124897003174
27: 1.9527109861373901
28: 1.9643702507019043
29: 1.9604979753494263
30: 1.9713276624679565
31: 1.9144268035888672
32: 1.9333975315093994
33: 1.9230254888534546
34: 1.9366639852523804
35: 1.90908682346344
36: 1.9223239421844482
37: 1.8756539821624756
38: 1.8812384605407715
39: 1.8786792755126953
40: 1.8550961017608643
41: 1.859017252922058
42: 1.853161096572876
43: 1.8383246660232544
44: 1.868273

training:  11%|█         | 111/1000 [00:10<01:20, 11.07it/s]

109: 1.5784934759140015
110: 1.5637534856796265
111: 1.5875625610351562
112: 1.5881927013397217
113: 1.5980137586593628
114: 1.5324068069458008
115: 1.5766412019729614
116: 1.57908034324646
117: 1.5725839138031006
118: 1.5588524341583252
119: 1.5807102918624878
120: 1.5730291604995728
121: 1.5532383918762207
122: 1.5316075086593628
123: 1.5376712083816528
124: 1.5413099527359009
125: 1.5674426555633545
126: 1.487176537513733
127: 1.574353814125061
128: 1.5548062324523926
129: 1.5459672212600708
130: 1.5185749530792236
131: 1.5202702283859253
132: 1.5188095569610596
133: 1.5318721532821655
134: 1.512141466140747
135: 1.4981012344360352
136: 1.5171382427215576
137: 1.4827970266342163
138: 1.5056418180465698
139: 1.4631519317626953
140: 1.4900286197662354
141: 1.4632591009140015
142: 1.4538007974624634
143: 1.4889214038848877
144: 1.4751932621002197
145: 1.4701156616210938
146: 1.4691507816314697
147: 1.450945258140564
148: 1.4578852653503418
149: 1.4380671977996826
150: 1.449689149856567

training:  29%|██▊       | 286/1000 [00:20<00:48, 14.81it/s]

282: 1.0231668949127197
283: 1.0776859521865845
284: 1.0353426933288574
285: 1.0438802242279053
286: 1.0394113063812256
287: 1.0637224912643433
288: 1.0874154567718506
289: 1.0333714485168457
290: 1.0643212795257568
291: 1.0029126405715942
292: 1.0510104894638062
293: 1.0462837219238281
294: 1.0121333599090576
295: 1.0907479524612427
296: 1.0113110542297363
297: 1.0505857467651367
298: 1.0335153341293335
299: 1.0873744487762451
300: 1.021870732307434
input:   tensor([[4, 5, 3, 4, 5, 3, 5, 2]], device='cuda:0')
predicted output:   tensor([[3, 4, 5, 5, 2, 4, 3, 4]], device='cuda:0')
incorrects: 8
301: 1.0305638313293457
302: 0.9863582849502563
303: 0.978124737739563
304: 0.9416863918304443
305: 0.9774136543273926
306: 0.9242491722106934
307: 0.9411768913269043
308: 0.9421495199203491
309: 0.9801124334335327
310: 0.9446565508842468
311: 0.9554623961448669
312: 0.9569493532180786
313: 0.90095055103302
314: 0.9319756627082825
315: 0.9214234948158264
316: 0.953609824180603
317: 0.91202658414

training:  46%|████▌     | 461/1000 [00:30<00:33, 15.92it/s]

456: 0.6918255090713501
457: 0.7300965189933777
458: 0.8033282160758972
459: 0.7502056956291199
460: 0.7979726791381836
461: 0.7417446970939636
462: 0.7219099998474121
463: 0.7513388395309448
464: 0.7831972241401672
465: 0.7061207890510559
466: 0.7028279304504395
467: 0.7908085584640503
468: 0.7545245289802551
469: 0.6836647987365723
470: 0.6987789869308472
471: 0.7206925749778748
472: 0.7284412384033203
473: 0.7124083042144775
474: 0.6793212890625
475: 0.6630160212516785
476: 0.7085117697715759
477: 0.6958303451538086
478: 0.6586872339248657
479: 0.6416089534759521
480: 0.6254047751426697
481: 0.655293345451355
482: 0.6356406211853027
483: 0.6073547005653381
484: 0.6376940608024597
485: 0.6447975039482117
486: 0.6386815905570984
487: 0.620576024055481
488: 0.6069929003715515
489: 0.6118273138999939
490: 0.5976830124855042
491: 0.6380203366279602
492: 0.587378203868866
493: 0.6306474208831787
494: 0.5997589826583862
495: 0.626042366027832
496: 0.6271562576293945
497: 0.6267951726913452

training:  65%|██████▍   | 649/1000 [00:40<00:20, 17.00it/s]

647: 0.48837044835090637
648: 0.4861643314361572
649: 0.511516273021698
650: 0.48152586817741394
651: 0.5000406503677368
652: 0.5344781875610352
653: 0.49410420656204224
654: 0.5108428001403809
655: 0.4990445077419281
656: 0.4874820113182068
657: 0.5157042741775513
658: 0.5004673600196838
659: 0.5169881582260132
660: 0.5124659538269043
661: 0.5436401963233948
662: 0.5195754766464233
663: 0.508572518825531
664: 0.4964755177497864
665: 0.49466657638549805
666: 0.5190247893333435
667: 0.48971375823020935
668: 0.5100383758544922
669: 0.47692883014678955
670: 0.47503572702407837
671: 0.48776373267173767
672: 0.4842880368232727
673: 0.48919859528541565
674: 0.46884891390800476
675: 0.5153008103370667
676: 0.4887297749519348
677: 0.5260093212127686
678: 0.5142576098442078
679: 0.49579939246177673
680: 0.5715795159339905
681: 0.47985225915908813
682: 0.5093812942504883
683: 0.541231095790863
684: 0.5422060489654541
685: 0.4854280352592468
686: 0.5348236560821533
687: 0.4889082610607147
688: 0.

training:  84%|████████▎ | 836/1000 [00:50<00:09, 17.34it/s]

834: 0.40536949038505554
835: 0.4485745429992676
836: 0.4004315435886383
837: 0.36676377058029175
838: 0.42619141936302185
839: 0.4441970884799957
840: 0.4244338572025299
841: 0.42615893483161926
842: 0.43636995553970337
843: 0.4191148281097412
844: 0.419814795255661
845: 0.429036408662796
846: 0.3998190462589264
847: 0.4364280104637146
848: 0.4401352107524872
849: 0.4091636538505554
850: 0.42458367347717285
851: 0.4570711553096771
852: 0.4300329089164734
853: 0.45889899134635925
854: 0.40051546692848206
855: 0.42815566062927246
856: 0.436421662569046
857: 0.4418829083442688
858: 0.4273316264152527
859: 0.42708447575569153
860: 0.4519297480583191
861: 0.42249515652656555
862: 0.46277254819869995
863: 0.4264020323753357
864: 0.4190294146537781
865: 0.43044745922088623
866: 0.44540345668792725
867: 0.40881839394569397
868: 0.40483587980270386
869: 0.42364078760147095
870: 0.39735841751098633
871: 0.3899841010570526
872: 0.42498189210891724
873: 0.4046390652656555
874: 0.4038293957710266


training: 100%|██████████| 1000/1000 [01:00<00:00, 16.62it/s]

999: 0.3711548149585724





In [6]:
print(f'Accuracy : {(len(src) - incorrects)/len(src)*100}%')

Accuracy : 81.25%
