## Let's load a Transformer architecture and build a model!

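Great work making it through the Transformers lecture and the hands-on exercises! I first met transformers myself during the previous cohort, and the model structure was so intricate and the concepts so hard that it took me a long time to understand them. Still, transformer models are used almost everywhere in contemporary AI technology, so even though much of the material is difficult, I did my best to pack in the parts that matter most. Well done XD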

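With Transformers, the concepts are one thing, but writing attention code from scratch every time and building the encoder and decoder yourself can also feel daunting! Fortunately, the transformer architecture is so well known that a few lines of code (one `click`, basically) are now enough to pull a transformer implementation based on recent papers from `huggingface` or `github` :D In this exercise we use `x-transformers`, which is not the easiest library for beginners, to load the transformer structure we built in the lecture, train it, and finish by printing an example sequence!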

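Since the transformer architecture was introduced in 2017, many models have been built on top of it, and a number of improvements to the base model have been proposed as well. `x-transformers` is a library that lets you apply these improved variants by adding a few lines of code. It is harder to use in the sense that you need to keep up with recent papers, but in return it collects implementations whose performance backs them up.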

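That turned out to be more TMI than I expected. Anyway, there is nothing you need to write yourself for this assignment; you only need to run the code I provide as-is. The semester starts soon, so good luck, everyone XD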

# Run this in a Colab GPU environment!

In [1]:
!pip install x-transformers

Collecting x-transformers
  Downloading x_transformers-1.19.1-py3-none-any.whl (27 kB)
Collecting einops>=0.6.1 (from x-transformers)
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 42.2/42.2 kB 3.3 MB/s eta 0:00:00
Installing collected packages: einops, x-transformers
Successfully installed einops-0.6.1 x-transformers-1.19.1
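
The log above shows `x-transformers` 1.19.1 being installed. The library changes quickly, so it may help to record which version a notebook actually ran with. Here is a minimal sketch using only the standard library; the helper name `installed_version` is made up for this example.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package: str):
    """Return the installed version string of `package`, or None if it is missing."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


# Prints the version found in the current environment, or None if the package is absent.
print("x-transformers:", installed_version("x-transformers"))
```

To reproduce the run shown here, the install line could pin the version, for example `!pip install x-transformers==1.19.1`.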


In [2]:
import torch
from x_transformers import XTransformer

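## Load the Transformer architecture
model = XTransformer(
    ## model dimension (paper = 512)
    dim = 16,
    ## encoder vocabulary size (x-transformers README example = 256)
    enc_num_tokens = 16,
    ## number of stacked encoder layers (paper = 6)
    enc_depth = 6,
    ## number of heads in multi-head attention (paper = 8)
    enc_heads = 8,
    ## maximum encoder sequence length (x-transformers README example = 1024)
    enc_max_seq_len = 32,
    ## decoder vocabulary size (x-transformers README example = 256)
    dec_num_tokens = 16,
    ## number of stacked decoder layers (paper = 6)
    dec_depth = 6,
    ## number of decoder attention heads (paper = 8)
    dec_heads = 8,
    ## maximum decoder sequence length (x-transformers README example = 1024)
    dec_max_seq_len = 32,
    ## share the token embeddings between encoder and decoder
    tie_token_emb = True
).cuda()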

In [3]:
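## NUM_BATCHES = number of training steps (one batch per step), not epochs
## BATCH_SIZE = number of samples in one batch
## LEARNING_RATE = learning rate
## GENERATE_EVERY = generate one sample every 100 steps and count mismatched tokens
## NUM_TOKENS = upper bound on token ids: 4 data tokens (ids 2-5) plus 2 reserved ids (0 and 1)
## ENC_SEQ_LEN = encoder sequence length
## DEC_SEQ_LEN = decoder sequence length (src twice plus one start token); not used below

NUM_BATCHES = int(1e3)
BATCH_SIZE = 32
LEARNING_RATE = 3e-4
GENERATE_EVERY = 100
NUM_TOKENS = 4 + 2
ENC_SEQ_LEN = 8
DEC_SEQ_LEN = 16 + 1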

In [4]:
## Generate random src, tgt, and src_mask to feed the Transformer (copy task: tgt = start token + src + src)

def cycle():
    while True:
        prefix = torch.ones((BATCH_SIZE, 1)).long().cuda()
        src = torch.randint(2, NUM_TOKENS, (BATCH_SIZE, ENC_SEQ_LEN)).long().cuda()
        tgt = torch.cat((prefix, src, src), 1)
        src_mask = torch.ones(BATCH_SIZE, src.shape[1]).bool().cuda()
        yield (src, tgt, src_mask)
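
To make the shape of the copy task concrete, here is a small sketch in plain Python lists (no GPU needed). It builds one `tgt` the same way `cycle()` does and then applies the usual teacher-forcing shift, where the decoder reads `tgt[:-1]` and is trained to predict `tgt[1:]`. That shift is an assumption about how the decoder consumes `tgt`, and the helper names are made up for illustration.

```python
import random

START_TOKEN = 1      # same value as the `prefix` column in cycle()
NUM_TOKENS = 4 + 2   # data tokens are drawn from ids 2..5
ENC_SEQ_LEN = 8


def make_example(rng):
    """Build one (src, tgt) pair: tgt = [start] + src + src."""
    src = [rng.randrange(2, NUM_TOKENS) for _ in range(ENC_SEQ_LEN)]
    tgt = [START_TOKEN] + src + src
    return src, tgt


def teacher_forcing_pairs(tgt):
    """Split tgt into decoder inputs and the labels they should predict."""
    return tgt[:-1], tgt[1:]


src, tgt = make_example(random.Random(0))
dec_in, labels = teacher_forcing_pairs(tgt)

print("src    :", src)
print("tgt    :", tgt)
print("dec_in :", dec_in)
print("labels :", labels)
```

The first `ENC_SEQ_LEN` labels are exactly `src`, which is why the training loop can compare the first 8 generated tokens against `src`.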

In [5]:
## Train the model

import tqdm

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

## Create the data generator once and reuse it for every step
data = cycle()

## Training loop: compute the loss, backpropagate, and update the weights
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    src, tgt, src_mask = next(data)

    loss = model(src, tgt, mask=src_mask)
    loss.backward()
    print(f'{i}: {loss.item()}')

    optimizer.step()
    optimizer.zero_grad()

    ## Every GENERATE_EVERY (100) steps, generate one sample and count mismatched tokens
    if i != 0 and i % GENERATE_EVERY == 0:
        model.eval()
        src, _, src_mask = next(data)
        src, src_mask = src[:1], src_mask[:1]
        start_tokens = torch.ones((1, 1)).long().cuda()

        sample = model.generate(src, start_tokens, ENC_SEQ_LEN, mask = src_mask)
        incorrects = (src != sample).sum()

        print("input:  ", src)
        print("predicted output:  ", sample)
        print(f"incorrects: {incorrects}")

training:   0%|          | 0/1000 [00:00<?, ?it/s]

0: 2.951876163482666
1: 2.5611953735351562
2: 2.4538168907165527
3: 2.3432841300964355
4: 2.3214359283447266
5: 2.2910056114196777
6: 2.3302743434906006
7: 2.293490409851074
8: 2.2404603958129883
9: 2.269561529159546
10: 2.233065366744995
11: 2.2342183589935303
12: 2.2177398204803467
13: 2.1737020015716553
14: 2.2359542846679688
15: 2.202705144882202
16: 2.211676597595215
17: 2.1926817893981934
18: 2.176954746246338
19: 2.137864828109741
20: 2.1121644973754883
21: 2.177294969558716
22: 2.1875507831573486
23: 2.1798977851867676
24: 2.126984119415283
25: 2.1395483016967773
26: 2.1068522930145264
27: 2.1098952293395996
28: 2.1195144653320312
29: 2.1122450828552246
30: 2.1435420513153076
31: 2.1014206409454346
32: 2.1291754245758057
33: 2.1162028312683105
34: 2.0926263332366943
35: 2.085679054260254
36: 2.045834541320801
37: 2.0492172241210938
38: 2.0326762199401855
39: 2.0520591735839844
40: 2.0158159732818604
41: 2.03735613822937
42: 1.9990808963775635
43: 1.990182638168335
44: 1.9791989

training:  12%|█▏        | 116/1000 [00:10<01:16, 11.58it/s]

115: 1.6885095834732056
116: 1.679988145828247
117: 1.6499072313308716
118: 1.657619833946228
119: 1.6803282499313354
120: 1.6261844635009766
121: 1.6004163026809692
122: 1.5736719369888306
123: 1.6277436017990112
124: 1.605025053024292
125: 1.6206315755844116
126: 1.5684080123901367
127: 1.5891735553741455
128: 1.5784502029418945
129: 1.6091735363006592
130: 1.598368525505066
131: 1.614851951599121
132: 1.6022125482559204
133: 1.5886198282241821
134: 1.5904381275177002
135: 1.5420140027999878
136: 1.5947315692901611
137: 1.5986969470977783
138: 1.51925790309906
139: 1.5499653816223145
140: 1.5645718574523926
141: 1.56022047996521
142: 1.5486174821853638
143: 1.5641531944274902
144: 1.5038851499557495
145: 1.601027250289917
146: 1.5412538051605225
147: 1.5832237005233765
148: 1.5157076120376587
149: 1.5260868072509766
150: 1.4878228902816772
151: 1.5346064567565918
152: 1.5222036838531494
153: 1.4993596076965332
154: 1.476027011871338
155: 1.4833518266677856
156: 1.505155086517334
157:

training:  31%|███       | 311/1000 [00:20<00:42, 16.21it/s]

307: 0.997701108455658
308: 0.9855968952178955
309: 1.076733112335205
310: 0.9928315281867981
311: 0.9801666736602783
312: 1.0061997175216675
313: 1.0270564556121826
314: 0.9828647375106812
315: 1.0064486265182495
316: 0.989850640296936
317: 0.961386501789093
318: 0.9955427646636963
319: 0.9323381781578064
320: 0.9593756794929504
321: 0.9481152296066284
322: 0.9732539653778076
323: 0.9410859942436218
324: 1.0011714696884155
325: 0.9713990092277527
326: 0.9068661332130432
327: 0.9408460259437561
328: 0.9689000248908997
329: 0.9317029714584351
330: 0.9273517727851868
331: 0.9149118661880493
332: 0.9212674498558044
333: 1.0011199712753296
334: 0.9993434548377991
335: 1.0593186616897583
336: 0.9774379730224609
337: 1.0540913343429565
338: 1.006453275680542
339: 0.9827117919921875
340: 0.9950416088104248
341: 0.9566176533699036
342: 0.9880194067955017
343: 0.9590165019035339
344: 0.9059821963310242
345: 0.9752295613288879
346: 0.9370886087417603
347: 1.01070237159729
348: 0.9683096408843994

training:  51%|█████     | 506/1000 [00:30<00:28, 17.10it/s]

503: 0.6462603211402893
504: 0.6475587487220764
505: 0.6100461483001709
506: 0.6226413249969482
507: 0.662249743938446
508: 0.599905788898468
509: 0.6263145208358765
510: 0.6340590715408325
511: 0.607163667678833
512: 0.6432814002037048
513: 0.6196220517158508
514: 0.6128268837928772
515: 0.5916788578033447
516: 0.6427377462387085
517: 0.5766206979751587
518: 0.6236575245857239
519: 0.5924926400184631
520: 0.6148115396499634
521: 0.5904048681259155
522: 0.6270086765289307
523: 0.5706331729888916
524: 0.6197294592857361
525: 0.5926479697227478
526: 0.6140221953392029
527: 0.559475839138031
528: 0.6001051664352417
529: 0.5714327096939087
530: 0.5619996190071106
531: 0.5951911807060242
532: 0.5721026659011841
533: 0.5800743103027344
534: 0.5693296194076538
535: 0.5641229152679443
536: 0.5723752379417419
537: 0.5796456933021545
538: 0.595663845539093
539: 0.5752490758895874
540: 0.5827935338020325
541: 0.5935394763946533
542: 0.5585675239562988
543: 0.6176038980484009
544: 0.60120475292205

training:  69%|██████▉   | 688/1000 [00:40<00:17, 17.51it/s]

687: 0.47012585401535034
688: 0.48177027702331543
689: 0.4769915044307709
690: 0.4517093896865845
691: 0.46625450253486633
692: 0.4606900215148926
693: 0.4622890055179596
694: 0.4645303785800934
695: 0.47784650325775146
696: 0.4767778515815735
697: 0.47405120730400085
698: 0.49485012888908386
699: 0.466756135225296
700: 0.46710070967674255
input:   tensor([[2, 5, 2, 2, 3, 5, 4, 2]], device='cuda:0')
predicted output:   tensor([[5, 2, 2, 2, 5, 3, 4, 2]], device='cuda:0')
incorrects: 4
701: 0.5028604865074158
702: 0.4692763388156891
703: 0.4895990490913391
704: 0.44853419065475464
705: 0.43978798389434814
706: 0.4659585654735565
707: 0.4594346880912781
708: 0.4438977539539337
709: 0.4584377110004425
710: 0.47122499346733093
711: 0.46800845861434937
712: 0.4743494689464569
713: 0.507631242275238
714: 0.4816054701805115
715: 0.5559099316596985
716: 0.5088658928871155
717: 0.4967734217643738
718: 0.5486406087875366
719: 0.5655853152275085
720: 0.49516069889068604
721: 0.5382137894630432
722

training:  87%|████████▋ | 870/1000 [00:51<00:07, 17.60it/s]

869: 0.40213194489479065
870: 0.42401042580604553
871: 0.4056536853313446
872: 0.41406872868537903
873: 0.4031669497489929
874: 0.41860735416412354
875: 0.411870539188385
876: 0.4283507168292999
877: 0.3909497857093811
878: 0.3861047029495239
879: 0.3902338147163391
880: 0.40511608123779297
881: 0.38508832454681396
882: 0.40912431478500366
883: 0.38295644521713257
884: 0.41225123405456543
885: 0.42231374979019165
886: 0.40940478444099426
887: 0.3877836763858795
888: 0.40154793858528137
889: 0.3943355083465576
890: 0.43038997054100037
891: 0.3853076696395874
892: 0.40516984462738037
893: 0.42102235555648804
894: 0.3807799816131592
895: 0.3856078088283539
896: 0.40726616978645325
897: 0.4154421091079712
898: 0.4326346218585968
899: 0.3913170397281647
900: 0.40585437417030334
input:   tensor([[3, 5, 3, 5, 2, 3, 3, 2]], device='cuda:0')
predicted output:   tensor([[5, 3, 3, 3, 2, 5, 3, 5]], device='cuda:0')
incorrects: 5
901: 0.39358165860176086
902: 0.38211172819137573
903: 0.404450595378

training: 100%|██████████| 1000/1000 [00:58<00:00, 17.19it/s]

998: 0.4173107445240021
999: 0.3606526255607605
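
As a rough way to read the loss values above, compare them with the cross-entropy of a model that guesses uniformly at random, which is ln(K) for K equally likely tokens. The sketch below computes two such baselines: K = 16 (the `dec_num_tokens` vocabulary) and K = 4 (the ids 2 to 5 that actually occur in the data). This is only a sanity check, under the assumption that the reported loss is a mean per-token cross-entropy in nats.

```python
import math


def uniform_cross_entropy(num_classes: int) -> float:
    """Cross-entropy (in nats) of a uniform guess over `num_classes` tokens."""
    return math.log(num_classes)


full_vocab = uniform_cross_entropy(16)  # every id in the decoder vocabulary
data_tokens = uniform_cross_entropy(4)  # only the ids that appear in tgt

print(f"uniform over 16 tokens: {full_vocab:.3f}")
print(f"uniform over 4 tokens:  {data_tokens:.3f}")
```

The first logged loss (about 2.95) is close to the 16-token baseline (about 2.77). By step 307 the loss (about 1.0) is already below the 4-token baseline (about 1.39), so by then the model is doing better than random guessing among the data tokens.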





In [6]:
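## Token accuracy of the last generated sample (step 900, 5 mismatches out of 8 tokens).
## Divide by the number of generated tokens, not by len(src): len(src) is the batch size (32),
## which is what produced the 84.375% figure in the original run.
total = sample.numel()
print(f'Accuracy : {(total - incorrects)/total*100}%')

Accuracy : 37.5%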
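
A single generated sample gives a noisy estimate. The sketch below computes token-level accuracy in plain Python, using the two input/prediction pairs printed in the training log (steps 700 and 900). The function name `token_accuracy` is made up for this example.

```python
def token_accuracy(targets, predictions):
    """Fraction of positions where the prediction matches the target."""
    correct = 0
    total = 0
    for target, pred in zip(targets, predictions):
        if len(target) != len(pred):
            raise ValueError("target and prediction lengths differ")
        correct += sum(t == p for t, p in zip(target, pred))
        total += len(target)
    return correct / total


# Input / predicted-output pairs copied from the training log (steps 700 and 900).
inputs = [
    [2, 5, 2, 2, 3, 5, 4, 2],
    [3, 5, 3, 5, 2, 3, 3, 2],
]
outputs = [
    [5, 2, 2, 2, 5, 3, 4, 2],
    [5, 3, 3, 3, 2, 5, 3, 5],
]

print(f"step 900 only: {token_accuracy(inputs[1:], outputs[1:]) * 100:.1f}%")
print(f"both samples:  {token_accuracy(inputs, outputs) * 100:.1f}%")
```

In the notebook itself, the same idea could be applied by generating for a whole batch from `cycle()` and averaging over all positions, rather than relying on one sequence.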
