## Let's load a Transformer architecture and build with it!

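Thank you all for your hard work getting through the Transformers lecture and hands-on practice! I also first encountered transformers in the previous cohort, and because the model structure is intricate and the concepts are difficult, it took me a long time to understand them. Still, since transformer models are used almost everywhere in contemporary AI technology, I tried hard to fill the material with the important content, even though much of it is difficult. Great job XD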

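Beyond the concepts themselves, writing attention code from scratch every time and assembling the encoder and decoder probably feels daunting! Fortunately, the transformer architecture is so well known that with just a few lines of code (a single `click`) you can pull a transformer based on recent papers from `huggingface` or `github` and use it right away :D In this exercise we take the transformer architecture we studied from `x-transformers`, which is admittedly not the easiest library for beginners, and go as far as generating example output sequences!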

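Since the transformer architecture appeared in 2017, many models have been built on top of it, and many improvements to the base model have been proposed as well. `x-transformers` is a library that lets you apply these improvements by adding just a few lines of code. It is demanding in that you need to follow recent papers to know what each option does, but in return it collects techniques that those papers report as improving performance.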

This TMI ran longer than I expected, but anyway: there is nothing you need to write yourself in this assignment. Just run the code I provide as-is. The semester starts soon, so good luck, everyone XD

# Run this in a Colab GPU environment!
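Because every later cell calls `.cuda()`, it is worth confirming that the runtime actually has a GPU before starting. A minimal check using standard PyTorch calls (the `device` variable exists only for this check; the notebook itself keeps calling `.cuda()`):

```python
import torch

# Every cell below moves tensors and the model to the GPU with .cuda(),
# so check that a CUDA device is actually available first.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using GPU:", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    # In Colab: Runtime > Change runtime type > select a GPU, then rerun.
    print("No GPU detected: the .cuda() calls below will raise an error.")
```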

In [1]:
!pip install x-transformers

Collecting x-transformers
  Downloading x_transformers-1.19.1-py3-none-any.whl (27 kB)
Collecting einops>=0.6.1 (from x-transformers)
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 42.2/42.2 kB 1.7 MB/s eta 0:00:00
Installing collected packages: einops, x-transformers
Successfully installed einops-0.6.1 x-transformers-1.19.1


In [2]:
import torch
from x_transformers import XTransformer

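## Load the Transformer architecture
model = XTransformer(
    ## model (embedding) dimension (paper: d_model = 512)
    dim = 16,
    ## encoder vocabulary size, i.e. number of distinct token ids (x-transformers README example: 256)
    enc_num_tokens = 16,
    ## number of stacked encoder layers (paper: N = 6)
    enc_depth = 6,
    ## number of heads in encoder multi-head attention (paper: h = 8)
    enc_heads = 8,
    ## maximum encoder input sequence length (x-transformers README example: 1024)
    enc_max_seq_len = 32,
    ## decoder vocabulary size (x-transformers README example: 256)
    dec_num_tokens = 16,
    ## number of stacked decoder layers (paper: N = 6)
    dec_depth = 6,
    ## number of heads in decoder multi-head attention (paper: h = 8)
    dec_heads = 8,
    ## maximum decoder sequence length (x-transformers README example: 1024)
    dec_max_seq_len = 32,
    ## share one token embedding table between encoder and decoder
    tie_token_emb = True
).cuda()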
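`tie_token_emb = True` makes the encoder and decoder share a single token embedding table. The idea can be sketched in plain PyTorch with a hypothetical `ToyEmbeddings` module (an illustration, not x-transformers code). Sharing one table only works when both sides have the same vocabulary size, and here `enc_num_tokens` and `dec_num_tokens` are both 16:

```python
import torch.nn as nn

VOCAB, DIM = 16, 16  # enc_num_tokens / dec_num_tokens and dim from the cell above


class ToyEmbeddings(nn.Module):
    """Hypothetical illustration: encoder/decoder token embeddings, optionally tied."""

    def __init__(self, num_tokens: int, dim: int, tie: bool):
        super().__init__()
        self.enc_emb = nn.Embedding(num_tokens, dim)
        # Tied: the decoder holds a reference to the very same Embedding module.
        self.dec_emb = self.enc_emb if tie else nn.Embedding(num_tokens, dim)


def count_params(module: nn.Module) -> int:
    # parameters() yields a shared tensor only once, so a tied table is counted once.
    return sum(p.numel() for p in module.parameters())


tied = ToyEmbeddings(VOCAB, DIM, tie=True)
untied = ToyEmbeddings(VOCAB, DIM, tie=False)
print(count_params(tied), count_params(untied))  # 256 512
```

Tying halves the embedding parameters and means every update to a token's vector is seen by both the encoder and the decoder.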

In [3]:
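## NUM_BATCHES = number of training iterations (one fresh random batch per iteration, not epochs)
## BATCH_SIZE = number of samples in one batch
## LEARNING_RATE = Adam learning rate
## GENERATE_EVERY = every 100 iterations, generate a sample and count wrong tokens
## NUM_TOKENS = number of distinct token ids in the data: 4 data ids (2-5) + 2 reserved ids (0 unused, 1 = start token)
## ENC_SEQ_LEN = encoder (source) sequence length
## DEC_SEQ_LEN = decoder (target) sequence length: start token + two copies of the source

NUM_BATCHES = int(1e3)
BATCH_SIZE = 32
LEARNING_RATE = 3e-4
GENERATE_EVERY = 100
NUM_TOKENS = 4 + 2
ENC_SEQ_LEN = 8
DEC_SEQ_LEN = 16 + 1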
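These constants depend on each other and on the model from the previous cell. The arithmetic can be checked in a few lines of plain Python; the model-side limits (`16`, `32`) are copied by hand from the `XTransformer` call, and the checks themselves are only a sketch:

```python
NUM_BATCHES, GENERATE_EVERY = int(1e3), 100
NUM_TOKENS, ENC_SEQ_LEN, DEC_SEQ_LEN = 4 + 2, 8, 16 + 1
MODEL_VOCAB, MODEL_MAX_SEQ_LEN = 16, 32  # enc/dec_num_tokens and enc/dec_max_seq_len above

START_TOKEN = 1
data_tokens = list(range(2, NUM_TOKENS))  # ids drawn by torch.randint(2, NUM_TOKENS, ...)

# tgt = [start] + src + src, so the decoder sequence is 1 + 2 * ENC_SEQ_LEN long
assert DEC_SEQ_LEN == 1 + 2 * ENC_SEQ_LEN
# every id used in the data must fit in the embedding table
assert max(data_tokens + [START_TOKEN]) < MODEL_VOCAB
# sequences must fit within the model's max_seq_len
assert ENC_SEQ_LEN <= MODEL_MAX_SEQ_LEN and DEC_SEQ_LEN <= MODEL_MAX_SEQ_LEN

# iterations at which the training loop generates (i != 0 and i % GENERATE_EVERY == 0)
eval_steps = [i for i in range(NUM_BATCHES) if i != 0 and i % GENERATE_EVERY == 0]
print(data_tokens)                                      # [2, 3, 4, 5]
print(eval_steps[0], eval_steps[-1], len(eval_steps))   # 100 900 9
```

Note that the last evaluation happens at iteration 900, not at the final iteration 999.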

In [4]:
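## Generate random (src, tgt, src_mask) batches to feed the Transformer
## Toy copy task: tgt = [start token 1] + src + src

def cycle():
    while True:
        # start token (id 1) for every sequence in the batch
        prefix = torch.ones((BATCH_SIZE, 1)).long().cuda()
        # random source tokens drawn from the data ids 2..NUM_TOKENS-1
        src = torch.randint(2, NUM_TOKENS, (BATCH_SIZE, ENC_SEQ_LEN)).long().cuda()
        tgt = torch.cat((prefix, src, src), 1)
        # no padding, so every source position is valid
        src_mask = torch.ones(BATCH_SIZE, src.shape[1]).bool().cuda()
        yield (src, tgt, src_mask)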
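To see what one batch looks like, here is the same construction on CPU with a tiny batch. `make_batch` is a hypothetical helper written for this illustration; it mirrors one step of `cycle()` but takes the sizes as arguments and skips `.cuda()`:

```python
import torch


def make_batch(batch_size: int, enc_seq_len: int, num_tokens: int):
    """CPU version of one cycle() step: returns (src, tgt, src_mask)."""
    prefix = torch.ones((batch_size, 1), dtype=torch.long)            # start token 1
    src = torch.randint(2, num_tokens, (batch_size, enc_seq_len))     # data ids 2..num_tokens-1
    tgt = torch.cat((prefix, src, src), dim=1)                        # [1] + src + src
    src_mask = torch.ones(batch_size, enc_seq_len, dtype=torch.bool)  # no padding
    return src, tgt, src_mask


torch.manual_seed(0)
src, tgt, src_mask = make_batch(batch_size=2, enc_seq_len=8, num_tokens=6)
print(src)
print(tgt)  # each row: 1, then the source row, then the same source row again
```

Since every batch is freshly random, `next(cycle())` (a new generator each time) and a single reused generator produce equivalent data; creating the generator once is simply the more conventional pattern.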

In [5]:
## Train Model

import tqdm

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
## create the data generator once and keep drawing batches from it
data = cycle()

## train by repeatedly updating on the loss
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    src, tgt, src_mask = next(data)

    ## the model returns the next-token cross-entropy loss for tgt
    loss = model(src, tgt, mask=src_mask)
    loss.backward()
    print(f'{i}: {loss.item()}')

    optimizer.step()
    optimizer.zero_grad()

    ## every GENERATE_EVERY (100) iterations, generate from one sample and count wrong tokens
    if i != 0 and i % GENERATE_EVERY == 0:
        model.eval()
        src, _, src_mask = next(data)
        src, src_mask = src[:1], src_mask[:1]
        start_tokens = torch.ones((1, 1)).long().cuda()

        with torch.no_grad():
            sample = model.generate(src, start_tokens, ENC_SEQ_LEN, mask = src_mask)
        incorrects = (src != sample).sum()

        print("input:  ", src)
        print("predicted output:  ", sample)
        print(f"incorrects: {incorrects}")

training:   0%|          | 0/1000 [00:00<?, ?it/s]

0: 2.797635078430176
1: 2.4878015518188477
2: 2.3340139389038086
3: 2.2608416080474854
4: 2.2459466457366943
5: 2.199850559234619
6: 2.2026562690734863
7: 2.163006067276001
8: 2.1598961353302
9: 2.1393463611602783
10: 2.1239266395568848
11: 2.127950668334961
12: 2.0992133617401123
13: 2.135681629180908
14: 2.0985476970672607
15: 2.0787017345428467
16: 2.058150291442871
17: 2.069566249847412
18: 2.0514683723449707
19: 2.017598867416382
20: 2.025083065032959
21: 2.048281669616699
22: 1.9812371730804443
23: 2.0286498069763184
24: 1.9859546422958374
25: 2.005308151245117
26: 2.0049960613250732
27: 1.9866104125976562
28: 2.0047099590301514
29: 2.005924940109253
30: 1.9846440553665161
31: 1.960372805595398
32: 1.972169280052185
33: 1.952781319618225
34: 1.9394750595092773
35: 1.980185866355896
36: 1.9445708990097046
37: 1.9491454362869263
38: 1.9554587602615356
39: 1.9372355937957764
40: 1.9379241466522217
41: 1.8985792398452759
42: 1.9272955656051636
43: 1.9330424070358276
44: 1.91910815238

training:  15%|█▍        | 148/1000 [00:10<00:57, 14.72it/s]

145: 1.4933606386184692
146: 1.5532925128936768
147: 1.5040174722671509
148: 1.4859027862548828
149: 1.5282752513885498
150: 1.4666990041732788
151: 1.4721544981002808
152: 1.4638686180114746
153: 1.4911365509033203
154: 1.4735946655273438
155: 1.4414175748825073
156: 1.491284728050232
157: 1.4260482788085938
158: 1.4637525081634521
159: 1.430141568183899
160: 1.4627032279968262
161: 1.4170596599578857
162: 1.4647945165634155
163: 1.4270256757736206
164: 1.4089199304580688
165: 1.4185672998428345
166: 1.4489885568618774
167: 1.3971130847930908
168: 1.4531923532485962
169: 1.443192481994629
170: 1.4223392009735107
171: 1.3937302827835083
172: 1.4409198760986328
173: 1.402637004852295
174: 1.4010883569717407
175: 1.4363696575164795
176: 1.4160774946212769
177: 1.404794692993164
178: 1.3781437873840332
179: 1.3960669040679932
180: 1.3937376737594604
181: 1.3869215250015259
182: 1.398311972618103
183: 1.358536720275879
184: 1.3720898628234863
185: 1.3311406373977661
186: 1.3692959547042847

training:  31%|███       | 310/1000 [00:20<00:44, 15.54it/s]

308: 1.0146417617797852
309: 1.0331388711929321
310: 0.9841303825378418
311: 1.039339303970337
312: 0.9902533292770386
313: 1.0149030685424805
314: 0.964514970779419
315: 1.055003046989441
316: 0.9950140118598938
317: 1.0058685541152954
318: 0.9912046194076538
319: 1.0110431909561157
320: 0.9901937246322632
321: 1.0050981044769287
322: 0.9980780482292175
323: 0.9876976013183594
324: 0.9584865570068359
325: 0.9742130637168884
326: 1.0943330526351929
327: 0.9888275861740112
328: 1.0039036273956299
329: 0.9838696718215942
330: 0.9816551208496094
331: 0.957250714302063
332: 0.9352869391441345
333: 0.9796928763389587
334: 0.9751938581466675
335: 0.9247007966041565
336: 0.9066839814186096
337: 0.9321062564849854
338: 0.9432708621025085
339: 0.9290735721588135
340: 0.9565730094909668
341: 0.9168662428855896
342: 0.9496229290962219
343: 0.9128372669219971
344: 0.9782382249832153
345: 0.908180296421051
346: 0.8959537744522095
347: 0.9523228406906128
348: 0.8837871551513672
349: 0.96727871894836

training:  52%|█████▏    | 516/1000 [00:30<00:27, 17.83it/s]

512: 0.6090440154075623
513: 0.6446576714515686
514: 0.6527178883552551
515: 0.6378310322761536
516: 0.6468561291694641
517: 0.6464178562164307
518: 0.6493579149246216
519: 0.6193405985832214
520: 0.6219031810760498
521: 0.61358642578125
522: 0.6259344816207886
523: 0.6075366139411926
524: 0.621893048286438
525: 0.6185920238494873
526: 0.6025910973548889
527: 0.58881676197052
528: 0.5832171440124512
529: 0.622134804725647
530: 0.5740756988525391
531: 0.5804957747459412
532: 0.5853502750396729
533: 0.5903634428977966
534: 0.5802825689315796
535: 0.5740465521812439
536: 0.5740303993225098
537: 0.6013234853744507
538: 0.5929142832756042
539: 0.5830521583557129
540: 0.5711731314659119
541: 0.5780900716781616
542: 0.6113883852958679
543: 0.5938930511474609
544: 0.5975487232208252
545: 0.5905042886734009
546: 0.611792802810669
547: 0.6165891289710999
548: 0.6411629319190979
549: 0.5778566598892212
550: 0.6222193241119385
551: 0.6198763847351074
552: 0.5874467492103577
553: 0.6095393300056458

training:  73%|███████▎  | 726/1000 [00:40<00:14, 19.06it/s]

722: 0.4757784605026245
723: 0.4587622582912445
724: 0.4903657138347626
725: 0.46297216415405273
726: 0.4727136194705963
727: 0.46477824449539185
728: 0.46662795543670654
729: 0.461911678314209
730: 0.4800814688205719
731: 0.5317946672439575
732: 0.4379477798938751
733: 0.5531448125839233
734: 0.45533549785614014
735: 0.4615228772163391
736: 0.49226969480514526
737: 0.4844951927661896
738: 0.4831395149230957
739: 0.47786587476730347
740: 0.46215614676475525
741: 0.4732614755630493
742: 0.4989212453365326
743: 0.4634529948234558
744: 0.4943608343601227
745: 0.4753184914588928
746: 0.4943804144859314
747: 0.4674665927886963
748: 0.4204547703266144
749: 0.45043930411338806
750: 0.4941175580024719
751: 0.46208351850509644
752: 0.5044943690299988
753: 0.4826430082321167
754: 0.5478001236915588
755: 0.47259631752967834
756: 0.4651457369327545
757: 0.5146099328994751
758: 0.4639907777309418
759: 0.5642129778862
760: 0.46897342801094055
761: 0.47310397028923035
762: 0.4854135811328888
763: 0.4

training:  94%|█████████▎| 936/1000 [00:50<00:03, 19.72it/s]

933: 0.38110414147377014
934: 0.3869953453540802
935: 0.36288371682167053
936: 0.3652299642562866
937: 0.3955049514770508
938: 0.4198540151119232
939: 0.416943222284317
940: 0.4016740918159485
941: 0.3932243287563324
942: 0.36744314432144165
943: 0.42825859785079956
944: 0.378195583820343
945: 0.3975553810596466
946: 0.39388197660446167
947: 0.3940126895904541
948: 0.4037431478500366
949: 0.389197438955307
950: 0.38490474224090576
951: 0.3487967252731323
952: 0.36680731177330017
953: 0.38446107506752014
954: 0.3785320222377777
955: 0.3767951726913452
956: 0.3721362352371216
957: 0.36923813819885254
958: 0.38836953043937683
959: 0.373357892036438
960: 0.39260411262512207
961: 0.3516077399253845
962: 0.37776148319244385
963: 0.37694621086120605
964: 0.36616280674934387
965: 0.3600388467311859
966: 0.3671846389770508
967: 0.40030115842819214
968: 0.39550328254699707
969: 0.41785430908203125
970: 0.37128469347953796
971: 0.44521546363830566
972: 0.36867645382881165
973: 0.4339306354522705


training: 100%|██████████| 1000/1000 [00:53<00:00, 18.87it/s]

998: 0.41865459084510803
999: 0.382826030254364
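The logged losses can be read against two reference values. Assume (this is not taken from the x-transformers source) that the loss is the mean next-token cross-entropy over the 16 predicted decoder positions, with the decoder reading `tgt[:, :-1]` and predicting `tgt[:, 1:]`. An untrained model that is uniform over the 16-token vocabulary then scores ln 16 ≈ 2.77, close to the first logged loss (2.80). A decoder that ignored the encoder could at best guess the first copy among the 4 data tokens and reproduce the second copy perfectly, giving ln 4 / 2 = ln 2 ≈ 0.69. The final losses of roughly 0.37 to 0.42 sit below that floor, which suggests the decoder is drawing on the encoder output through cross-attention. A small PyTorch sketch of both numbers:

```python
import math

import torch
import torch.nn.functional as F

VOCAB = 16        # dec_num_tokens in the model above
ENC_SEQ_LEN = 8
NUM_TOKENS = 6    # ids 2..5 carry data, id 1 is the start token

# Build one batch like cycle(), but on CPU.
src = torch.randint(2, NUM_TOKENS, (4, ENC_SEQ_LEN))
tgt = torch.cat((torch.ones(4, 1, dtype=torch.long), src, src), dim=1)

# Teacher forcing (assumed): the decoder reads tgt[:, :-1] and is scored on tgt[:, 1:].
dec_in, labels = tgt[:, :-1], tgt[:, 1:]

# An untrained model is roughly uniform over the vocabulary: all-zero logits.
logits = torch.zeros(labels.shape[0], labels.shape[1], VOCAB)
# cross_entropy expects (batch, classes, positions), hence the transpose.
uniform_loss = F.cross_entropy(logits.transpose(1, 2), labels).item()

# Ignoring the encoder: first copy costs ln 4 per token, second copy costs 0.
decoder_only_floor = (ENC_SEQ_LEN * math.log(4)) / (2 * ENC_SEQ_LEN)

print(f"uniform loss       = {uniform_loss:.4f}  (ln 16 = {math.log(VOCAB):.4f})")
print(f"decoder-only floor = {decoder_only_floor:.4f}  (ln 4 / 2)")
```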





In [6]:
## Token-level accuracy of the last evaluated sample (ENC_SEQ_LEN generated tokens)
print(f'Accuracy : {(ENC_SEQ_LEN - int(incorrects)) / ENC_SEQ_LEN * 100}%')

Accuracy : 37.5%
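Accuracy on the generated copy is a per-token quantity: the number of mismatched tokens divided by the number of compared tokens (`numel()`). `len(...)` on a tensor only returns the size of its first (batch) dimension, so it is not a suitable denominator. A minimal sketch with made-up tensors, where `token_accuracy` is a hypothetical helper:

```python
import torch


def token_accuracy(target: torch.Tensor, predicted: torch.Tensor) -> float:
    """Share of positions where predicted equals target (shapes must match)."""
    assert target.shape == predicted.shape, "compare tensors of the same shape"
    incorrects = (target != predicted).sum().item()
    return 1.0 - incorrects / target.numel()


# Made-up example: one sample of 8 tokens, 5 of them predicted wrong.
src = torch.tensor([[2, 3, 4, 5, 2, 3, 4, 5]])
sample = torch.tensor([[2, 3, 4, 2, 3, 4, 5, 2]])

acc = token_accuracy(src, sample)
print(f"token accuracy: {acc * 100}%")  # token accuracy: 37.5%

# Dividing by a batch size of 32 instead of the 8 compared tokens would give:
print((32 - 5) / 32 * 100)  # 84.375
```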
