## Let's pull in a Transformer architecture and build with it!

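Great work getting through the Transformers lecture and the hands-on session! I first ran into transformers myself as a member of the previous cohort. The model's structure was convoluted and the concepts were hard, so it took me a long time to understand them. Still, since transformer models are used almost everywhere in contemporary AI technology, I tried hard to fill this material with the important ideas, even though much of it is difficult. Great job, everyone XD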

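The concepts are hard enough, but writing attention code from scratch every time and building the encoder and decoder structures yourself can also feel daunting! Fortunately, the transformer architecture is so well known that a few lines of code (one `click`) are enough to download a transformer structure based on recent papers from `huggingface` or `github` and use it :D In this exercise we use `x-transformers`, which is not the easiest library for beginners. We pull in a transformer structure like the one we built in class and finish by generating an example output sequence!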
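For reference, here is roughly the piece of code the library saves us from rewriting. It is the scaled dot-product attention from the original Transformer paper, Attention(Q, K, V) = softmax(QKᵀ / √d_k) V, written in plain Python. This is a teaching sketch only. It assumes a single head, no batching, and no masking, and it is not how `x-transformers` implements attention internally. The function names `softmax` and `scaled_dot_product_attention` are our own.

```python
import math

def softmax(xs):
    # Subtract the max before exponentiating for numerical stability.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def scaled_dot_product_attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V, with Q, K, V given as lists of row vectors."""
    d_k = len(K[0])
    out = []
    for q in Q:
        # Score this query against every key, scaled by sqrt(d_k).
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in K]
        weights = softmax(scores)
        # Output row = attention-weighted sum of the value vectors.
        out.append([sum(w * v[j] for w, v in zip(weights, V)) for j in range(len(V[0]))])
    return out

Q = [[1.0, 0.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[10.0, 0.0], [0.0, 10.0]]
print(scaled_dot_product_attention(Q, K, V))
```

The query matches the first key more closely, so the output leans toward the first value vector. The two components still sum to 10 because the attention weights sum to 1.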

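Since the transformer architecture appeared in 2017, many models have been built on top of it, and many improvements to the base model have been proposed. `x-transformers` is a library that lets you apply these improvements by adding just a few lines of code. It is demanding, because you need to follow recent papers to know what to turn on. In return, it collects implementations whose performance gains are backed by those papers.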

That turned into more TMI than I expected. Anyway, there is nothing you need to write yourself in this assignment: just run the code I provide as-is. The semester starts soon, so good luck, everyone XD

# Run this in a Colab GPU environment!

In [1]:
!pip install x-transformers

Collecting x-transformers
  Downloading x_transformers-1.19.1-py3-none-any.whl (27 kB)
Collecting einops>=0.6.1 (from x-transformers)
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 42.2/42.2 kB 2.7 MB/s eta 0:00:00
Installing collected packages: einops, x-transformers
Successfully installed einops-0.6.1 x-transformers-1.19.1


In [2]:
import torch
from x_transformers import XTransformer

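## Load the Transformer architecture from x-transformers
model = XTransformer(
    ## model (embedding) dimension (paper = 512)
    dim = 16,
    ## encoder vocabulary size, i.e. number of distinct token ids (x-transformers README example = 256)
    enc_num_tokens = 16,
    ## number of stacked encoder layers (paper = 6)
    enc_depth = 6,
    ## number of heads in multi-head attention (paper = 8)
    enc_heads = 8,
    ## maximum encoder sequence length (x-transformers README example = 1024)
    enc_max_seq_len = 32,
    ## decoder vocabulary size (README example = 256)
    dec_num_tokens = 16,
    ## decoder layers and attention heads, mirroring the encoder
    dec_depth = 6,
    dec_heads = 8,
    ## maximum decoder sequence length (README example = 1024)
    dec_max_seq_len = 32,
    ## share one token embedding between encoder and decoder
    tie_token_emb = True
).cuda()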

In [3]:
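## NUM_BATCHES = number of training steps (one batch per step; this is not a number of epochs)
## BATCH_SIZE = number of samples in one batch
## LEARNING_RATE = learning rate for the optimizer
## GENERATE_EVERY = every 100 steps, generate a sample and count how many tokens are wrong
## NUM_TOKENS = number of unique token ids in the data (4 data tokens + ids 0 and 1, kept out of the data; 1 is the start token)
## ENC_SEQ_LEN = encoder (source) sequence length
## DEC_SEQ_LEN = decoder (target) sequence length: 1 start token + the source repeated twice

NUM_BATCHES = int(1e3)
BATCH_SIZE = 32
LEARNING_RATE = 3e-4
GENERATE_EVERY = 100
NUM_TOKENS = 4 + 2
ENC_SEQ_LEN = 8
DEC_SEQ_LEN = 16 + 1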
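A quick sanity check can confirm that these data settings fit inside the model configured in the previous cell. The helper `config_problems` and the dict `MODEL_CFG` are hypothetical names introduced only for this sketch; they are not part of `x-transformers`. The numbers are copied from the two cells above.

```python
# Hypothetical helper (not an x-transformers API): check that the toy copy-task
# data fits inside the XTransformer configuration used above.
MODEL_CFG = {
    "enc_num_tokens": 16,
    "enc_max_seq_len": 32,
    "dec_num_tokens": 16,
    "dec_max_seq_len": 32,
}

def config_problems(num_tokens, enc_seq_len, dec_seq_len, cfg):
    problems = []
    # Token ids run from 0 to num_tokens - 1, so each vocabulary must hold num_tokens ids.
    if num_tokens > cfg["enc_num_tokens"]:
        problems.append("encoder vocabulary too small")
    if num_tokens > cfg["dec_num_tokens"]:
        problems.append("decoder vocabulary too small")
    # Sequences must not exceed the positional range each side was built for.
    if enc_seq_len > cfg["enc_max_seq_len"]:
        problems.append("source sequence too long")
    if dec_seq_len > cfg["dec_max_seq_len"]:
        problems.append("target sequence too long")
    return problems

print(config_problems(num_tokens=4 + 2, enc_seq_len=8, dec_seq_len=16 + 1, cfg=MODEL_CFG))  # → []
```

An empty list means 6 token ids fit in a 16-entry vocabulary, and lengths 8 and 17 fit under the limit of 32.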

In [4]:
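## Generate random src, tgt, src_mask batches to feed the Transformer (a toy copy task)

def cycle():
    while True:
        ## start token (id 1) prepended to every target sequence
        prefix = torch.ones((BATCH_SIZE, 1)).long().cuda()
        ## random source tokens drawn from 2..NUM_TOKENS-1
        src = torch.randint(2, NUM_TOKENS, (BATCH_SIZE, ENC_SEQ_LEN)).long().cuda()
        ## target = start token + source + source: the decoder must copy the source twice
        tgt = torch.cat((prefix, src, src), 1)
        ## every source position is a real token, so the mask is all True
        src_mask = torch.ones(BATCH_SIZE, src.shape[1]).bool().cuda()
        yield (src, tgt, src_mask)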
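To see what a single training example looks like without a GPU, here is a plain-Python sketch of the same copy task. It also shows the usual teacher-forcing split of the target into decoder input and labels. That split is an assumption about what the decoder wrapper inside `x-transformers` does internally, so check the library source if it matters. `make_copy_example`, `shift_for_teacher_forcing`, and `START_TOKEN` are names we made up for this sketch.

```python
import random

# Same values as the constants cell above.
NUM_TOKENS = 4 + 2
ENC_SEQ_LEN = 8
START_TOKEN = 1  # the prefix value used by cycle()

def make_copy_example(rng):
    # Data tokens come from 2..NUM_TOKENS-1, like torch.randint(2, NUM_TOKENS, ...).
    src = [rng.randrange(2, NUM_TOKENS) for _ in range(ENC_SEQ_LEN)]
    # Target = start token, then the source, then the source again.
    tgt = [START_TOKEN] + src + src
    return src, tgt

def shift_for_teacher_forcing(tgt):
    # Assumption: the decoder learns to predict token t+1 from tokens 0..t,
    # so its input drops the last token and its labels drop the first.
    return tgt[:-1], tgt[1:]

rng = random.Random(0)
example_src, example_tgt = make_copy_example(rng)
dec_in, labels = shift_for_teacher_forcing(example_tgt)
print("src:   ", example_src)
print("tgt:   ", example_tgt)   # length 17 == DEC_SEQ_LEN
print("input: ", dec_in)        # length 16
print("labels:", labels)        # length 16, never contains the start token
```

Every label is a copy of a source token the encoder can see, so in principle the loss can approach zero.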

In [5]:
## Train Model

import tqdm
import torch.optim as optim

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

## create the data generator once and draw a fresh batch from it every step
data = cycle()

## train by updating on the loss
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    src, tgt, src_mask = next(data)

    ## with tgt given, XTransformer returns a scalar training loss
    loss = model(src, tgt, mask=src_mask)
    loss.backward()
    print(f'{i}: {loss.item()}')

    optimizer.step()
    optimizer.zero_grad()

    ## every GENERATE_EVERY (100) steps, generate from one sample and count wrong tokens
    if i != 0 and i % GENERATE_EVERY == 0:
        model.eval()
        src, _, src_mask = next(data)
        src, src_mask = src[:1], src_mask[:1]
        ## decoding starts from the start token (id 1)
        start_tokens = torch.ones((1, 1)).long().cuda()

        ## generate ENC_SEQ_LEN tokens; a perfect model reproduces src
        sample = model.generate(src, start_tokens, ENC_SEQ_LEN, mask = src_mask)
        incorrects = (src != sample).sum()

        print("input:  ", src)
        print("predicted output:  ", sample)
        print(f"incorrects: {incorrects}")

training:   0%|          | 0/1000 [00:00<?, ?it/s]

0: 2.6478800773620605
1: 2.261573076248169
2: 2.1637144088745117
3: 2.078657388687134
4: 2.1100258827209473
5: 2.0440163612365723
6: 2.0724966526031494
7: 2.0347578525543213
8: 2.023211717605591
9: 2.078023910522461
10: 2.0280685424804688
11: 2.0008797645568848
12: 2.001140594482422
13: 2.0116567611694336
14: 1.976168155670166
15: 2.0053961277008057
16: 1.9580340385437012
17: 1.9591012001037598
18: 2.01456618309021
19: 1.9999701976776123
20: 1.9744423627853394
21: 1.9344919919967651
22: 1.9283185005187988
23: 1.9319243431091309
24: 1.9216663837432861
25: 1.8966960906982422
26: 1.9359952211380005
27: 1.9134674072265625
28: 1.9112390279769897
29: 1.9334901571273804
30: 1.894859790802002
31: 1.8683433532714844
32: 1.9014170169830322
33: 1.9246236085891724
34: 1.8898277282714844
35: 1.843149185180664
36: 1.8760325908660889
37: 1.8695727586746216
38: 1.847262978553772
39: 1.8578498363494873
40: 1.8311244249343872
41: 1.850803256034851
42: 1.8708817958831787
43: 1.8318973779678345
44: 1.8311

training:  15%|█▌        | 150/1000 [00:10<00:56, 14.99it/s]

145: 1.4938673973083496
146: 1.530427098274231
147: 1.4917058944702148
148: 1.5221178531646729
149: 1.502716064453125
150: 1.5152720212936401
151: 1.4755371809005737
152: 1.5126391649246216
153: 1.480089783668518
154: 1.4707601070404053
155: 1.4855468273162842
156: 1.507737636566162
157: 1.4798026084899902
158: 1.4954264163970947
159: 1.4775959253311157
160: 1.4686166048049927
161: 1.4626795053482056
162: 1.479661226272583
163: 1.4902706146240234
164: 1.4675730466842651
165: 1.4786278009414673
166: 1.463558554649353
167: 1.493986964225769
168: 1.4733859300613403
169: 1.4490575790405273
170: 1.4406300783157349
171: 1.4411287307739258
172: 1.4547057151794434
173: 1.4391162395477295
174: 1.442339539527893
175: 1.411732792854309
176: 1.4331369400024414
177: 1.4591479301452637
178: 1.4239457845687866
179: 1.3900911808013916
180: 1.418952226638794
181: 1.403038501739502
182: 1.4286181926727295
183: 1.4201360940933228
184: 1.4426168203353882
185: 1.4099531173706055
186: 1.4195793867111206
187

training:  36%|███▌      | 357/1000 [00:20<00:35, 18.31it/s]

355: 0.8760620355606079
356: 0.8460785746574402
357: 0.8773635625839233
358: 0.8416330814361572
359: 0.8751066327095032
360: 0.8505054712295532
361: 0.8456946015357971
362: 0.8265751600265503
363: 0.871931254863739
364: 0.8430922627449036
365: 0.8478541970252991
366: 0.8143545389175415
367: 0.8482965230941772
368: 0.959280788898468
369: 0.8265675902366638
370: 0.9132434725761414
371: 0.8731804490089417
372: 0.8837803602218628
373: 0.883604109287262
374: 0.8043373227119446
375: 0.8742016553878784
376: 0.7838917970657349
377: 0.8192270398139954
378: 0.8654640913009644
379: 0.8133478164672852
380: 0.7926818132400513
381: 0.7963818311691284
382: 0.8415619730949402
383: 0.8234283328056335
384: 0.851693868637085
385: 0.7873582243919373
386: 0.7926135659217834
387: 0.7920675277709961
388: 0.8025996088981628
389: 0.8135731816291809
390: 0.8042054772377014
391: 0.7876914143562317
392: 0.7976572513580322
393: 0.7718051075935364
394: 0.778685986995697
395: 0.7994734644889832
396: 0.77481096982955

training:  56%|█████▋    | 564/1000 [00:30<00:22, 19.32it/s]

560: 0.5591724514961243
561: 0.5486852526664734
562: 0.5333073735237122
563: 0.5455255508422852
564: 0.5316713452339172
565: 0.5340614914894104
566: 0.5446965098381042
567: 0.524417519569397
568: 0.5157907605171204
569: 0.5179059505462646
570: 0.525750458240509
571: 0.5506320595741272
572: 0.533226728439331
573: 0.5275943875312805
574: 0.5127018690109253
575: 0.5188522338867188
576: 0.520380973815918
577: 0.5384547710418701
578: 0.512622594833374
579: 0.5103803277015686
580: 0.5109537243843079
581: 0.525209903717041
582: 0.5033068060874939
583: 0.5127798318862915
584: 0.5127544403076172
585: 0.5080753564834595
586: 0.5330702662467957
587: 0.5470303297042847
588: 0.507145881652832
589: 0.5157497525215149
590: 0.5161808729171753
591: 0.5240605473518372
592: 0.4948578476905823
593: 0.5134532451629639
594: 0.4944641590118408
595: 0.5068218111991882
596: 0.4992515444755554
597: 0.5103199481964111
598: 0.5335096716880798
599: 0.5140480399131775
600: 0.5323043465614319
input:   tensor([[4, 5,

training:  77%|███████▋  | 770/1000 [00:40<00:11, 19.64it/s]

768: 0.45069655776023865
769: 0.4618922173976898
770: 0.4825136661529541
771: 0.43602022528648376
772: 0.44258224964141846
773: 0.4835675358772278
774: 0.47648483514785767
775: 0.4456413686275482
776: 0.44949886202812195
777: 0.4766618609428406
778: 0.4849805533885956
779: 0.44416874647140503
780: 0.40616363286972046
781: 0.4546428918838501
782: 0.4537827968597412
783: 0.42771488428115845
784: 0.5031782388687134
785: 0.4508569836616516
786: 0.5049734711647034
787: 0.4329441487789154
788: 0.46513742208480835
789: 0.4851142168045044
790: 0.44816073775291443
791: 0.4526878893375397
792: 0.48474615812301636
793: 0.4534595310688019
794: 0.4333065450191498
795: 0.4414491057395935
796: 0.4259790778160095
797: 0.48112979531288147
798: 0.418515682220459
799: 0.4337979257106781
800: 0.42473462224006653
input:   tensor([[5, 4, 3, 4, 2, 5, 5, 3]], device='cuda:0')
predicted output:   tensor([[5, 4, 5, 3, 4, 5, 2, 3]], device='cuda:0')
incorrects: 4
801: 0.45467907190322876
802: 0.46483156085014343

training:  97%|█████████▋| 974/1000 [00:50<00:01, 19.91it/s]

969: 0.4036896228790283
970: 0.4480675458908081
971: 0.40851205587387085
972: 0.40263041853904724
973: 0.4005323648452759
974: 0.4566102921962738
975: 0.4224098324775696
976: 0.41898298263549805
977: 0.4295375645160675
978: 0.4116900563240051
979: 0.4347352683544159
980: 0.40734803676605225
981: 0.4118662476539612
982: 0.39046916365623474
983: 0.40149152278900146
984: 0.42339903116226196
985: 0.3906756043434143
986: 0.4090900421142578
987: 0.42831340432167053
988: 0.3980204164981842
989: 0.4470379054546356
990: 0.4019286334514618
991: 0.3952958285808563
992: 0.4030468165874481
993: 0.399570494890213
994: 0.39047136902809143
995: 0.39065268635749817
996: 0.4112502932548523
997: 0.38613393902778625
998: 0.406905859708786


training: 100%|██████████| 1000/1000 [00:51<00:00, 19.41it/s]

999: 0.38784945011138916
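One way to read the loss log is to compare it with two reference values for cross-entropy in nats. The first is a model guessing uniformly over all 16 decoder tokens, which gives ln 16 ≈ 2.77. The second is a model that has learned only that labels come from the 4 data tokens, which gives ln 4 ≈ 1.39.

In the log, the loss starts at 2.648. Around steps 145 to 186 it sits at roughly 1.4 to 1.5, and by step 999 it reaches 0.388. That pattern is consistent with, though it does not prove, the model first learning the token alphabet and then starting to copy from the encoder. The helper name `uniform_ce` is ours.

```python
import math

# Cross-entropy (nats) of a model that spreads probability uniformly over k tokens.
def uniform_ce(k):
    return math.log(k)

DEC_VOCAB = 16      # dec_num_tokens in the model cell
DATA_ALPHABET = 4   # data tokens are 2..5

print(f"uniform over the whole vocabulary: {uniform_ce(DEC_VOCAB):.3f}")      # → 2.773
print(f"uniform over the 4 data tokens:     {uniform_ce(DATA_ALPHABET):.3f}")  # → 1.386
```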





In [6]:
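## token accuracy of the most recent evaluation sample (a single ENC_SEQ_LEN-token sequence);
## the denominator is the number of generated tokens
print(f'Accuracy : {(ENC_SEQ_LEN - incorrects) / ENC_SEQ_LEN * 100}%')

Accuracy : 50.0%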
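Token-level accuracy compares the generated sequence with the source position by position and divides by the number of generated tokens. As a worked example, the sketch below uses the step-800 evaluation printed in the training log. Four of the eight positions differ, matching the logged `incorrects: 4`, so the accuracy is 50%. `token_accuracy` is a helper name we introduce here.

```python
# Values copied from the step-800 evaluation in the training log.
step800_src = [5, 4, 3, 4, 2, 5, 5, 3]
step800_pred = [5, 4, 5, 3, 4, 5, 2, 3]

def token_accuracy(reference, prediction):
    # Fraction of positions where the generated token matches the source token.
    if len(reference) != len(prediction):
        raise ValueError("sequences must have the same length")
    correct = sum(r == p for r, p in zip(reference, prediction))
    return correct / len(reference)

n_wrong = sum(r != p for r, p in zip(step800_src, step800_pred))
print(f"incorrects: {n_wrong}")                                            # → incorrects: 4
print(f"accuracy: {token_accuracy(step800_src, step800_pred) * 100}%")     # → accuracy: 50.0%
```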
