## Let's fetch a Transformer architecture and build with it!

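Great work getting through the Transformers lecture and the hands-on exercises! I first met transformers myself during the previous cohort, and the model structure was so intricate and the concepts behind it so hard that it took me a long time to understand. Still, transformer models are used nearly everywhere in contemporary AI technology, so even though much of the material is difficult, I tried hard to fill the course with the content that matters most. Well done XD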

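With Transformers, the concepts are one challenge, but writing new attention code and building the encoder and decoder structure from scratch every time can also feel daunting! Fortunately, the transformer architecture is so well known that with just a few lines of code (basically a `click`) you can pull a transformer implementation based on recent papers from `huggingface` or `github` and use it :D In this exercise we will use `x-transformers`, which is not the easiest library for beginners, to fetch the transformer structure we built earlier and go as far as printing an example output sequence!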

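Since the transformer architecture came out in 2017, many models have been built on top of it, and a number of improvements have been made to the base model as well. `x-transformers` is a library that lets you apply these improved variants by adding just a few lines of code. It is hard to use in the sense that you need to keep up with recent papers, but in return it collects code whose performance backs it up.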

That turned into more TMI than I expected. Anyway, there is nothing you need to write yourself for this assignment; you only need to run the code I provide as-is. The semester starts soon, so good luck, everyone XD

# Run this in a Colab GPU environment!

In [1]:
!pip install x-transformers

Collecting x-transformers
  Downloading x_transformers-1.19.1-py3-none-any.whl (27 kB)
Collecting einops>=0.6.1 (from x-transformers)
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
Installing collected packages: einops, x-transformers
Successfully installed einops-0.6.1 x-transformers-1.19.1


In [2]:
import torch
from x_transformers import XTransformer

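## Fetch the Transformer architecture
model = XTransformer(
    ## model dimension (paper = 512)
    dim = 16,
    ## encoder vocabulary size, i.e. number of distinct token ids (x-transformers README example = 256)
    enc_num_tokens = 16,
    ## number of stacked encoder layers (paper = 6)
    enc_depth = 6,
    ## number of heads in multi-head attention (paper = 8)
    enc_heads = 8,
    ## maximum encoder sequence length (x-transformers README example = 1024)
    enc_max_seq_len = 32,
    ## decoder vocabulary size (x-transformers README example = 256)
    dec_num_tokens = 16,
    ## number of stacked decoder layers (paper = 6)
    dec_depth = 6,
    ## number of decoder attention heads (paper = 8)
    dec_heads = 8,
    ## maximum decoder sequence length (x-transformers README example = 1024)
    dec_max_seq_len = 32,
    ## share the token embedding between encoder and decoder
    tie_token_emb = True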
).cuda()

In [3]:
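## NUM_BATCHES = number of training steps (one batch per step), not epochs
## BATCH_SIZE = number of samples in one batch
## LEARNING_RATE = learning rate
## GENERATE_EVERY = generate a sample every 100 steps to check accuracy
## NUM_TOKENS = number of unique token ids in the data (4 data tokens, 2..5, plus ids 0 and 1; 1 is the start token)
## ENC_SEQ_LEN = encoder sequence length
## DEC_SEQ_LEN = decoder sequence length (source repeated twice + 1 start token)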

NUM_BATCHES = int(1e3)
BATCH_SIZE = 32
LEARNING_RATE = 3e-4
GENERATE_EVERY  = 100
NUM_TOKENS = 4 + 2
ENC_SEQ_LEN = 8
DEC_SEQ_LEN = 16 + 1

In [4]:
## Generate random src, tgt, and src_mask to feed into the Transformer model

def cycle():
    while True:
        prefix = torch.ones((BATCH_SIZE, 1)).long().cuda()
        src = torch.randint(2, NUM_TOKENS, (BATCH_SIZE, ENC_SEQ_LEN)).long().cuda()
        tgt = torch.cat((prefix, src, src), 1)
        src_mask = torch.ones(BATCH_SIZE, src.shape[1]).bool().cuda()
        yield (src, tgt, src_mask)
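
To make the shapes in `cycle()` concrete, here is a minimal pure-Python sketch of one sample of this copy task. It uses the standard `random` module instead of `torch` so that it runs without a GPU. The helper name `make_copy_sample` and the constant `START_TOKEN` are illustrative and are not part of `x-transformers`.

```python
import random

NUM_TOKENS = 4 + 2   # ids 0 and 1 are not sampled as data; data tokens are 2..5
ENC_SEQ_LEN = 8
START_TOKEN = 1      # illustrative name for the prefix id used in cycle()

def make_copy_sample(rng):
    """Build one (src, tgt) pair for the copy task as plain Python lists."""
    # randrange(2, NUM_TOKENS) mirrors torch.randint(2, NUM_TOKENS, ...): values in [2, 5]
    src = [rng.randrange(2, NUM_TOKENS) for _ in range(ENC_SEQ_LEN)]
    # the target is the start token followed by the source repeated twice
    tgt = [START_TOKEN] + src + src
    return src, tgt

src, tgt = make_copy_sample(random.Random(0))
print(len(src), len(tgt))  # prints "8 17"
```

The target length is 2 * ENC_SEQ_LEN + 1 = 17, which matches `DEC_SEQ_LEN = 16 + 1` and fits within `dec_max_seq_len = 32`.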

In [5]:
## Train Model

import tqdm

## named `optimizer` so that it does not shadow the `torch.optim` module
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

## create the data generator once and reuse it for every step
data_iter = cycle()

## train, updating the weights from the loss at each step
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    src, tgt, src_mask = next(data_iter)

    loss = model(src, tgt, mask=src_mask)
    loss.backward()
    print(f'{i}: {loss.item()}')

    optimizer.step()
    optimizer.zero_grad()

    ## measure accuracy every GENERATE_EVERY (100) steps
    if i != 0 and i % GENERATE_EVERY == 0:
        model.eval()
        eval_src, _, eval_mask = next(data_iter)
        eval_src, eval_mask = eval_src[:1], eval_mask[:1]
        start_tokens = torch.ones((1, 1)).long().cuda()

        sample = model.generate(eval_src, start_tokens, ENC_SEQ_LEN, mask = eval_mask)
        ## number of generated tokens that differ from the source
        incorrects = (eval_src != sample).sum()

        print("input:  ", eval_src)
        print("predicted output:  ", sample)
        print(f"incorrects: {incorrects}")

training:   0%|          | 0/1000 [00:00<?, ?it/s]

0: 3.253167152404785
1: 2.5487208366394043
2: 2.29557728767395
3: 2.160595178604126
4: 2.075458288192749
5: 1.9932801723480225
6: 1.9861074686050415
7: 1.9743404388427734
8: 1.9212812185287476
9: 1.9268182516098022
10: 1.9502754211425781
11: 1.9572566747665405
12: 1.9201868772506714
13: 1.909145712852478
14: 1.911095380783081
15: 1.875662922859192
16: 1.9148366451263428
17: 1.8692140579223633
18: 1.8393852710723877
19: 1.8115262985229492
20: 1.8784409761428833
21: 1.8242179155349731
22: 1.8324335813522339
23: 1.8471227884292603
24: 1.846914529800415
25: 1.8429725170135498
26: 1.822601556777954
27: 1.8205138444900513
28: 1.7577190399169922
29: 1.8302661180496216
30: 1.803601861000061
31: 1.8290449380874634
32: 1.8284894227981567
33: 1.7958651781082153
34: 1.8106882572174072
35: 1.7815942764282227
36: 1.803736925125122
37: 1.785261631011963
38: 1.7701075077056885
39: 1.766919732093811
40: 1.7944310903549194
41: 1.777021884918213
42: 1.7558547258377075
43: 1.738787293434143
44: 1.75789725

training:  15%|█▌        | 152/1000 [00:10<00:55, 15.16it/s]

149: 1.416460633277893
150: 1.446626901626587
151: 1.3978991508483887
152: 1.378917932510376
153: 1.4098718166351318
154: 1.4104870557785034
155: 1.420050024986267
156: 1.400003433227539
157: 1.4049102067947388
158: 1.4712647199630737
159: 1.3981519937515259
160: 1.3935000896453857
161: 1.3829330205917358
162: 1.3950976133346558
163: 1.3924551010131836
164: 1.3937245607376099
165: 1.4112955331802368
166: 1.3661171197891235
167: 1.4071041345596313
168: 1.3482192754745483
169: 1.4027206897735596
170: 1.3667316436767578
171: 1.3535016775131226
172: 1.3774274587631226
173: 1.3740594387054443
174: 1.3677846193313599
175: 1.3308106660842896
176: 1.3386098146438599
177: 1.3710424900054932
178: 1.3648675680160522
179: 1.361747145652771
180: 1.3126083612442017
181: 1.3465183973312378
182: 1.3281645774841309
183: 1.3170881271362305
184: 1.2997926473617554
185: 1.2837104797363281
186: 1.3351032733917236
187: 1.2757453918457031
188: 1.2923691272735596
189: 1.309665560722351
190: 1.3323899507522583

training:  36%|███▌      | 358/1000 [00:20<00:35, 18.34it/s]

357: 0.8564289212226868
358: 0.8477700352668762
359: 0.8572279214859009
360: 0.8609045147895813
361: 0.8232436180114746
362: 0.7878204584121704
363: 0.8348687291145325
364: 0.810657262802124
365: 0.8213077187538147
366: 0.787672221660614
367: 0.8364291191101074
368: 0.8134245276451111
369: 0.795101523399353
370: 0.7716326713562012
371: 0.8513203263282776
372: 0.7986306548118591
373: 0.8411201238632202
374: 0.7761363983154297
375: 0.8101418614387512
376: 0.7951498627662659
377: 0.8332136273384094
378: 0.7501810789108276
379: 0.7848488688468933
380: 0.7751827239990234
381: 0.7756786942481995
382: 0.7347120046615601
383: 0.7768001556396484
384: 0.7984715104103088
385: 0.7401519417762756
386: 0.7365789413452148
387: 0.779926598072052
388: 0.7838245034217834
389: 0.7615861892700195
390: 0.734221339225769
391: 0.769199013710022
392: 0.7549229860305786
393: 0.8399533033370972
394: 0.7429800033569336
395: 0.7763772010803223
396: 0.7527692914009094
397: 0.7982178926467896
398: 0.753232419490814

training:  56%|█████▋    | 564/1000 [00:30<00:23, 18.69it/s]

562: 0.5100585222244263
563: 0.5296949744224548
564: 0.527786135673523
565: 0.5155544877052307
566: 0.5147513151168823
567: 0.5394890904426575
568: 0.5070684552192688
569: 0.5022724866867065
570: 0.5018717646598816
571: 0.5003150105476379
572: 0.49890580773353577
573: 0.500983715057373
574: 0.506031334400177
575: 0.511115312576294
576: 0.5199735164642334
577: 0.5328829288482666
578: 0.5217544436454773
579: 0.5157819986343384
580: 0.5104552507400513
581: 0.5219036340713501
582: 0.4935610592365265
583: 0.5327804088592529
584: 0.5100175142288208
585: 0.5278816223144531
586: 0.5382776260375977
587: 0.5130273699760437
588: 0.5117141008377075
589: 0.49948763847351074
590: 0.5432973504066467
591: 0.5019978880882263
592: 0.5066117644309998
593: 0.4980362057685852
594: 0.48347562551498413
595: 0.5179393887519836
596: 0.5271766781806946
597: 0.5089397430419922
598: 0.48639288544654846
599: 0.5048534274101257
600: 0.4931618869304657
input:   tensor([[2, 4, 4, 4, 5, 4, 5, 5]], device='cuda:0')
pre

training:  77%|███████▋  | 768/1000 [00:40<00:11, 19.35it/s]

763: 0.39701053500175476
764: 0.45263317227363586
765: 0.40680208802223206
766: 0.4062192440032959
767: 0.4452989101409912
768: 0.3758666515350342
769: 0.4218957722187042
770: 0.430584192276001
771: 0.40974724292755127
772: 0.4308444559574127
773: 0.3982343077659607
774: 0.4079146683216095
775: 0.43094488978385925
776: 0.41978001594543457
777: 0.4295048117637634
778: 0.41245102882385254
779: 0.4192917048931122
780: 0.41186457872390747
781: 0.4279228746891022
782: 0.43894362449645996
783: 0.47836336493492126
784: 0.4326198399066925
785: 0.4272969365119934
786: 0.4350811541080475
787: 0.47654762864112854
788: 0.4258374869823456
789: 0.41832268238067627
790: 0.4152073562145233
791: 0.4277893006801605
792: 0.4343804717063904
793: 0.4020806849002838
794: 0.3799927234649658
795: 0.44473159313201904
796: 0.41363298892974854
797: 0.405514657497406
798: 0.41905730962753296
799: 0.45883163809776306
800: 0.41775035858154297
input:   tensor([[3, 2, 2, 2, 2, 4, 5, 2]], device='cuda:0')
predicted ou

training:  98%|█████████▊| 984/1000 [00:50<00:00, 20.11it/s]

984: 0.39709171652793884
985: 0.3719639480113983
986: 0.35971587896347046
987: 0.3736967444419861
988: 0.37682223320007324
989: 0.3521883189678192
990: 0.3647136390209198
991: 0.3605862855911255
992: 0.39089444279670715
993: 0.35468411445617676
994: 0.4209952652454376
995: 0.37098199129104614
996: 0.36697980761528015
997: 0.4003404974937439


training: 100%|██████████| 1000/1000 [00:52<00:00, 19.21it/s]

998: 0.4151996076107025
999: 0.4248560070991516
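
As a rough way to read the loss values above, assume the returned loss is a mean per-token cross-entropy in nats (the PyTorch default for cross-entropy; this is an assumption about the library internals). The sketch below computes two reference points under that assumption: the loss of a uniform guess over the full 16-token vocabulary, and the loss of a uniform guess over only the 4 data tokens. The function name `uniform_cross_entropy` is illustrative.

```python
import math

def uniform_cross_entropy(num_classes):
    """Cross-entropy (in nats) of a uniform prediction over num_classes options."""
    return math.log(num_classes)

vocab_baseline = uniform_cross_entropy(16)  # dec_num_tokens = 16
data_baseline = uniform_cross_entropy(4)    # only tokens 2..5 are ever sampled

print(f"{vocab_baseline:.3f} {data_baseline:.3f}")  # prints "2.773 1.386"
```

Under that assumption, a loss near 1.39 (as logged around steps 150 to 160) is roughly what a model would score if it had learned which tokens occur but not yet how to copy them, and the later drop below that level suggests it has started to learn the copying itself.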





In [6]:
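## `incorrects` counts mismatches within the single generated sample, so divide by its token count.
## After training, len(src) is the batch size (32), which appears to be how the recorded 87.5% (28/32) arose.
print(f'Accuracy : {(sample.numel() - incorrects) / sample.numel() * 100}%')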
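
Because a single generated sample contains only 8 tokens, its accuracy can only move in steps of 12.5 percentage points. A minimal sketch of averaging token-level accuracy over several (source, prediction) pairs follows. It uses plain Python lists and made-up token values rather than real model outputs, and `token_accuracy` is a hypothetical helper name.

```python
def token_accuracy(pairs):
    """Percentage of positions where the prediction equals the source, over all pairs."""
    correct = 0
    total = 0
    for src, pred in pairs:
        if len(src) != len(pred):
            raise ValueError("source and prediction must have the same length")
        correct += sum(s == p for s, p in zip(src, pred))
        total += len(src)
    return 100.0 * correct / total

# made-up example data, not taken from the training run above
pairs = [
    ([2, 4, 4, 4, 5, 4, 5, 5], [2, 4, 4, 4, 5, 4, 5, 5]),  # all 8 correct
    ([3, 2, 2, 2, 2, 4, 5, 2], [3, 2, 2, 2, 2, 4, 5, 3]),  # 7 of 8 correct
]
print(token_accuracy(pairs))  # 15 of 16 tokens match -> prints 93.75
```

With tensors, calling `.tolist()` on one row of the source and one row of the generated sample would produce lists in this form.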
