## Let's load a Transformers architecture and build a model with it!

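Great job making it through the Transformers lectures and the hands-on sessions! I first met Transformers in the previous cohort too. The model's structure was intricate and the concepts behind it were hard, so it took me a long time to understand them. Still, Transformer models are used practically everywhere in contemporary AI. So even though a lot of the material is difficult, I worked hard to fill it with the important parts. Thanks for your hard work XD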

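Beyond the concepts themselves, writing attention code from scratch every time and building the encoder and decoder can feel overwhelming! Fortunately, the Transformer architecture is so well known that a few lines of code and a `click` are enough to download a transformer based on recent papers from `huggingface` or `github` and use it :D In this exercise we'll take the Transformer architecture we built in class from `x-transformers`, which is admittedly not the easiest library for beginners. We'll go as far as printing an example output sequence!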

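Since the Transformer architecture came out in 2017, many models have been built on top of it, and many improvements to the base model have been proposed. `x-transformers` is a library that lets you apply these improvements by adding just a few lines of code. It is demanding because you need to follow recent papers to know what to turn on. In exchange, it collects techniques whose performance those papers back up.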

That TMI ran longer than I expected, but anyway: there's nothing you need to write yourself in this assignment. Just run the code I give you as-is. The semester starts soon, so good luck, everyone XD

# Run this in a Colab GPU environment!

In [1]:
!pip install x-transformers

Collecting x-transformers
  Downloading x_transformers-1.19.1-py3-none-any.whl (27 kB)
Collecting einops>=0.6.1 (from x-transformers)
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: einops, x-transformers
Successfully installed einops-0.6.1 x-transformers-1.19.1


In [2]:
import torch
from x_transformers import XTransformer

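## Load the Transformer (encoder-decoder) architecture
model = XTransformer(
    ## model (embedding) dimension (paper: d_model = 512)
    dim = 16,
    ## encoder vocabulary size, i.e. number of distinct token ids (x-transformers README example: 256)
    enc_num_tokens = 16,
    ## number of stacked encoder layers (paper: 6)
    enc_depth = 6,
    ## number of heads in encoder multi-head attention (paper: 8)
    enc_heads = 8,
    ## maximum encoder input sequence length (x-transformers README example: 1024)
    enc_max_seq_len = 32,
    ## decoder vocabulary size (x-transformers README example: 256)
    dec_num_tokens = 16,
    ## number of stacked decoder layers (paper: 6)
    dec_depth = 6,
    ## number of heads in decoder multi-head attention (paper: 8)
    dec_heads = 8,
    ## maximum decoder sequence length (x-transformers README example: 1024)
    dec_max_seq_len = 32,
    ## share one token embedding between encoder and decoder (both sides should use the same vocabulary)
    tie_token_emb = True
).cuda()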
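The `(paper: ...)` values come from the original Transformer. That model splits `d_model` evenly across its `h` heads, so each head attends in a 64-dimensional space:

$$
d_k = d_v = \frac{d_{\text{model}}}{h} = \frac{512}{8} = 64
$$

If I read the x-transformers source correctly, it does not derive the head size from `dim` this way. Instead it gives each head a fixed width, 64 by default. Under that reading, with `dim = 16` the 8 heads would each still work in 64 dimensions rather than 16 / 8 = 2. Treat this as an assumption and check the library's `Attention` class for your installed version.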

In [3]:
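## NUM_BATCHES = number of training steps, one batch per step (not epochs)
## BATCH_SIZE = number of samples in one batch
## LEARNING_RATE = learning rate
## GENERATE_EVERY = every 100 steps, generate one sample and count its wrong tokens
## NUM_TOKENS = number of distinct token ids: 4 data ids (2-5) + 2 reserved ids (0 unused, 1 = start token)
## ENC_SEQ_LEN = encoder (source) sequence length
## DEC_SEQ_LEN = decoder (target) sequence length: 1 start token + the source written twice

NUM_BATCHES = int(1e3)
BATCH_SIZE = 32
LEARNING_RATE = 3e-4
GENERATE_EVERY = 100
NUM_TOKENS = 4 + 2
ENC_SEQ_LEN = 8
DEC_SEQ_LEN = 16 + 1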
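The model in `In [2]` was built with hard-coded limits of 16 token ids and 32 positions. The data here needs less than that: 6 ids, 8 source tokens, and 17 target tokens. The extra capacity should simply leave some embeddings unused. A quick consistency check makes these relationships explicit. The sketch below is plain Python, and `MODEL_CFG` and `check_config` are hypothetical names that just copy the `XTransformer` arguments above.

```python
# Hypothetical sanity check: do the data constants fit the model built in In [2]?
# MODEL_CFG copies the XTransformer arguments used above; check_config is illustrative only.
NUM_TOKENS = 4 + 2
ENC_SEQ_LEN = 8
DEC_SEQ_LEN = 16 + 1

MODEL_CFG = {
    "enc_num_tokens": 16,
    "dec_num_tokens": 16,
    "enc_max_seq_len": 32,
    "dec_max_seq_len": 32,
}

def check_config(num_tokens, enc_seq_len, dec_seq_len, cfg):
    """Return a list of problems; an empty list means the data fits the model."""
    problems = []
    # token ids run from 0 to num_tokens - 1, so every id needs an embedding row
    if num_tokens > cfg["enc_num_tokens"]:
        problems.append("encoder vocabulary too small")
    if num_tokens > cfg["dec_num_tokens"]:
        problems.append("decoder vocabulary too small")
    # sequences must fit in the positional range of each side
    if enc_seq_len > cfg["enc_max_seq_len"]:
        problems.append("source longer than enc_max_seq_len")
    if dec_seq_len > cfg["dec_max_seq_len"]:
        problems.append("target longer than dec_max_seq_len")
    # the target layout is [start] + src + src
    if dec_seq_len != 1 + 2 * enc_seq_len:
        problems.append("DEC_SEQ_LEN should be 1 + 2 * ENC_SEQ_LEN")
    return problems

print(check_config(NUM_TOKENS, ENC_SEQ_LEN, DEC_SEQ_LEN, MODEL_CFG))  # → []
```

If you set `NUM_TOKENS = 20`, the two vocabulary checks would fail, because ids 16 to 19 would have no embedding row.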

In [4]:
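## Generate random src, tgt, src_mask batches to feed the Transformer (a copy task)

def cycle():
    while True:
        ## start token 1 at the front of every target
        prefix = torch.ones((BATCH_SIZE, 1)).long().cuda()
        ## random source ids drawn from 2 .. NUM_TOKENS - 1
        src = torch.randint(2, NUM_TOKENS, (BATCH_SIZE, ENC_SEQ_LEN)).long().cuda()
        ## target = [start, src, src], length DEC_SEQ_LEN
        tgt = torch.cat((prefix, src, src), 1)
        ## every source position is real (no padding), so the mask is all True
        src_mask = torch.ones(BATCH_SIZE, src.shape[1]).bool().cuda()
        yield (src, tgt, src_mask)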
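`cycle()` sets up a copy task. The target is the start token `1` followed by the source written twice, and the mask marks every source position as real. During training, the decoder reads the target shifted by one position (teacher forcing). My understanding is that x-transformers' autoregressive wrapper does this split internally when you call `model(src, tgt, mask=src_mask)`, but treat that detail as an assumption. Below is a plain-Python sketch of one example row, using the hypothetical helpers `make_example` and `shift_for_teacher_forcing`:

```python
import random

# Pure-Python mirror of one row produced by cycle() (no torch needed).
NUM_TOKENS = 4 + 2
ENC_SEQ_LEN = 8
START_TOKEN = 1

def make_example(rng):
    # ids 2..NUM_TOKENS-1, like torch.randint(2, NUM_TOKENS, ...)
    src = [rng.randrange(2, NUM_TOKENS) for _ in range(ENC_SEQ_LEN)]
    tgt = [START_TOKEN] + src + src   # [start, src, src]
    src_mask = [True] * len(src)      # nothing is padded
    return src, tgt, src_mask

def shift_for_teacher_forcing(tgt):
    # the decoder reads every token but the last and is trained to predict every token but the first
    return tgt[:-1], tgt[1:]

src, tgt, src_mask = make_example(random.Random(0))
dec_in, labels = shift_for_teacher_forcing(tgt)
print(len(tgt), len(dec_in), len(labels))  # → 17 16 16
```

The first 8 labels are the source itself, which the decoder can only get right by attending to the encoder. The last 8 labels repeat tokens the decoder has already read.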

In [5]:
## Train Model

import tqdm
import torch

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
data = cycle()  # one generator, reused for every batch

## train by repeatedly updating on the loss
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    src, tgt, src_mask = next(data)

    ## given src and tgt, XTransformer returns the decoder's cross-entropy loss
    loss = model(src, tgt, mask=src_mask)
    loss.backward()
    print(f'{i}: {loss.item()}')

    optimizer.step()
    optimizer.zero_grad()

    ## every GENERATE_EVERY (100) steps, decode one sample and count wrong tokens
    if i != 0 and i % GENERATE_EVERY == 0:
        model.eval()
        src, _, src_mask = next(data)
        src, src_mask = src[:1], src_mask[:1]
        start_tokens = torch.ones((1, 1)).long().cuda()  # start token 1

        sample = model.generate(src, start_tokens, ENC_SEQ_LEN, mask = src_mask)
        incorrects = (src != sample).sum()

        print(f"input:  ", src)
        print(f"predicted output:  ", sample)
        print(f"incorrects: {incorrects}")
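The evaluation branch asks the model to extend `start_tokens` by `ENC_SEQ_LEN` tokens, one at a time. It then counts the positions where the result differs from `src`. The loop below sketches that procedure, with a stand-in `next_token_fn` (hypothetical, not the x-transformers API) in place of the decoder.

As far as I know, x-transformers' `generate` samples the next token (by default after a top-k filter) rather than always taking the most likely one. Even a well-trained model can therefore occasionally emit a wrong token. The stand-in here is deterministic.

```python
from typing import Callable, List

# Hypothetical sketch of the decoding loop behind model.generate(src, start_tokens, seq_len).
# next_token_fn stands in for one decoder forward pass; it is NOT the x-transformers API.
def generate(src: List[int], start_tokens: List[int], seq_len: int,
             next_token_fn: Callable[[List[int], List[int]], int]) -> List[int]:
    out = list(start_tokens)
    for _ in range(seq_len):
        # predict one token from the source and everything generated so far
        out.append(next_token_fn(src, out))
    # drop the start token(s), like the 8-token `sample` the training loop prints
    return out[len(start_tokens):]

def count_incorrect(src: List[int], sample: List[int]) -> int:
    # same idea as (src != sample).sum() in the notebook
    return sum(a != b for a, b in zip(src, sample))

def oracle(src: List[int], prefix: List[int]) -> int:
    # a perfect "copy" model: after the start token plus t generated tokens, emit src[t]
    return src[len(prefix) - 1]

src = [5, 4, 5, 5, 2, 4, 5, 3]
sample = generate(src, [1], len(src), oracle)
print(sample, count_incorrect(src, sample))  # → [5, 4, 5, 5, 2, 4, 5, 3] 0
```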

training:   0%|          | 0/1000 [00:00<?, ?it/s]

0: 2.7196145057678223
1: 2.284837007522583
2: 2.15211820602417
3: 2.147310495376587
4: 2.0967254638671875
5: 2.149860143661499
6: 2.0819973945617676
7: 2.1099071502685547
8: 2.0962393283843994
9: 2.098416328430176
10: 2.10199236869812
11: 2.070236921310425
12: 2.047194719314575
13: 2.0633225440979004
14: 2.047189235687256
15: 2.0333993434906006
16: 2.0334208011627197
17: 2.0230162143707275
18: 1.9932204484939575
19: 1.9907764196395874
20: 1.9742019176483154
21: 1.9829375743865967
22: 1.9801832437515259
23: 1.9742459058761597
24: 1.986448884010315
25: 1.9328001737594604
26: 1.9459468126296997
27: 1.916403889656067
28: 1.9224090576171875
29: 1.9260005950927734
30: 1.9006342887878418
31: 1.9036284685134888
32: 1.8902479410171509
33: 1.9224823713302612
34: 1.8789132833480835
35: 1.9032633304595947
36: 1.8926374912261963
37: 1.8730021715164185
38: 1.875799536705017
39: 1.8661211729049683
40: 1.840732216835022
41: 1.8586337566375732
42: 1.863901972770691
43: 1.8689433336257935
44: 1.83840513

training:  14%|█▍        | 143/1000 [00:10<01:00, 14.24it/s]

141: 1.5288606882095337
142: 1.5209015607833862
143: 1.505415916442871
144: 1.4561697244644165
145: 1.4986673593521118
146: 1.45909583568573
147: 1.5112956762313843
148: 1.4633703231811523
149: 1.4910372495651245
150: 1.4954378604888916
151: 1.4372719526290894
152: 1.4554929733276367
153: 1.4624873399734497
154: 1.4488707780838013
155: 1.4531646966934204
156: 1.4654322862625122
157: 1.4674087762832642
158: 1.472261905670166
159: 1.4595696926116943
160: 1.4696125984191895
161: 1.44108247756958
162: 1.4641079902648926
163: 1.4450774192810059
164: 1.4595260620117188
165: 1.4700955152511597
166: 1.451647162437439
167: 1.4227265119552612
168: 1.4423176050186157
169: 1.4113271236419678
170: 1.4474937915802002
171: 1.4112871885299683
172: 1.4351924657821655
173: 1.3865132331848145
174: 1.405417799949646
175: 1.3534640073776245
176: 1.367796540260315
177: 1.4156488180160522
178: 1.3797411918640137
179: 1.3862048387527466
180: 1.3958663940429688
181: 1.4196858406066895
182: 1.3976317644119263

training:  35%|███▌      | 351/1000 [00:20<00:35, 18.09it/s]

346: 0.8528451323509216
347: 0.8183538913726807
348: 0.9028303623199463
349: 0.8341947197914124
350: 0.8934822082519531
351: 0.8699567317962646
352: 0.8651015758514404
353: 0.8736851215362549
354: 0.8287045359611511
355: 0.8762385249137878
356: 0.8497115969657898
357: 0.9022618532180786
358: 0.8974716663360596
359: 0.9242638349533081
360: 0.9049992561340332
361: 0.9600681066513062
362: 0.9470364451408386
363: 0.8991883993148804
364: 0.9256433844566345
365: 0.8903313875198364
366: 0.8511799573898315
367: 0.8590328693389893
368: 0.9029624462127686
369: 0.9416267275810242
370: 0.8302051424980164
371: 0.8672866821289062
372: 0.9034883379936218
373: 0.8598299026489258
374: 0.8510544896125793
375: 0.8718454837799072
376: 0.861213743686676
377: 0.7928309440612793
378: 0.8292728066444397
379: 0.841499924659729
380: 0.8327513337135315
381: 0.7830634713172913
382: 0.7733782529830933
383: 0.8576073050498962
384: 0.7993869185447693
385: 0.8125122785568237
386: 0.8507531881332397
387: 0.76541006565

training:  58%|█████▊    | 578/1000 [00:30<00:20, 20.17it/s]

576: 0.5562976002693176
577: 0.5637539625167847
578: 0.5308570861816406
579: 0.5877580642700195
580: 0.5619627833366394
581: 0.5200582146644592
582: 0.5524949431419373
583: 0.5551047325134277
584: 0.544207751750946
585: 0.5270990133285522
586: 0.5486636757850647
587: 0.5428414344787598
588: 0.5285195708274841
589: 0.527692973613739
590: 0.5373461246490479
591: 0.5385679602622986
592: 0.5309513807296753
593: 0.5382928252220154
594: 0.547231137752533
595: 0.5126953125
596: 0.5374917387962341
597: 0.5291644334793091
598: 0.5258398056030273
599: 0.5142088532447815
600: 0.5338625311851501
input:   tensor([[5, 4, 5, 5, 2, 4, 5, 3]], device='cuda:0')
predicted output:   tensor([[5, 5, 4, 3, 5, 4, 5, 2]], device='cuda:0')
incorrects: 5
601: 0.5370497107505798
602: 0.5244868993759155
603: 0.5186803340911865
604: 0.508577823638916
605: 0.5324913859367371
606: 0.4968973994255066
607: 0.5166393518447876
608: 0.5355000495910645
609: 0.5252708196640015
610: 0.5147669911384583
611: 0.524306058883667


training:  80%|████████  | 805/1000 [00:41<00:09, 20.28it/s]

input:   tensor([[4, 3, 2, 3, 3, 3, 3, 2]], device='cuda:0')
predicted output:   tensor([[3, 2, 3, 3, 4, 3, 2, 3]], device='cuda:0')
incorrects: 6
801: 0.3985729515552521
802: 0.44731712341308594
803: 0.45375990867614746
804: 0.4487120807170868
805: 0.4801049828529358
806: 0.4315362274646759
807: 0.43499043583869934
808: 0.4882591962814331
809: 0.4848472774028778
810: 0.5357328057289124
811: 0.48897626996040344
812: 0.5590566992759705
813: 0.47789520025253296
814: 0.5115637183189392
815: 0.5622873306274414
816: 0.5216833353042603
817: 0.47438275814056396
818: 0.49852705001831055
819: 0.5240309238433838
820: 0.5389979481697083
821: 0.4651407301425934
822: 0.4875665307044983
823: 0.5122313499450684
824: 0.48469099402427673
825: 0.4869903028011322
826: 0.5012809038162231
827: 0.4606001377105713
828: 0.49124082922935486
829: 0.4412807822227478
830: 0.48682868480682373
831: 0.4718887507915497
832: 0.45571276545524597
833: 0.4839143455028534
834: 0.4618336856365204
835: 0.4736320674419403

training: 100%|██████████| 1000/1000 [00:50<00:00, 19.77it/s]

997: 0.4017289876937866
998: 0.3946918547153473
999: 0.368343323469162
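A few reference values help when reading the loss log. They assume the printed loss is the mean token-level cross-entropy (in nats) over the 16 decoder labels.

- The first loss, about 2.72, sits near ln 16, the value for uniform guessing over all 16 output ids.
- ln 4 corresponds to guessing uniformly among the 4 data ids 2 to 5.
- A decoder that ignored the encoder could at best copy the second half of the target and guess the first half, for an average of ln 4 / 2.

The final losses, about 0.37 to 0.40, are below that last value, which suggests the model has started to use the source.

```python
import math

# Reference cross-entropy values (nats) for reading the training log.
# Assumption: the loss is the mean token-level cross-entropy over the 16 decoder labels.
uniform_over_vocab = math.log(16)   # uniform over all dec_num_tokens = 16 output ids
uniform_over_data = math.log(4)     # uniform over the 4 data ids (2..5) only
copy_second_half = math.log(4) / 2  # guess the first 8 labels, copy the last 8 perfectly

print(f"{uniform_over_vocab:.3f} {uniform_over_data:.3f} {copy_second_half:.3f}")
# → 2.773 1.386 0.693
```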





In [6]:
print(f'Accuracy : {(len(src) - incorrects)/len(src)*100}%')

Accuracy : 87.5%
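The accuracy cell divides by `len(src)`. After the training loop finishes, `src` holds the last *training* batch of shape (32, 8), so `len(src)` is 32. Meanwhile, `incorrects` counts wrong tokens in the single 8-token sample from the step-900 check. If the cells ran in order, the printed 87.5% equals (32 - 4) / 32, meaning 4 of 8 tokens were wrong, i.e. 50% token accuracy.

A per-sample token accuracy divides by the sequence length instead, for example `(sample.numel() - incorrects) / sample.numel() * 100` in the notebook. Here is a plain-Python version applied to the two samples that appear in the log (`token_accuracy` is an illustrative helper):

```python
# Token-level accuracy for one generated sample: share of positions where the prediction matches.
def token_accuracy(src, sample):
    assert len(src) == len(sample), "compare sequences of equal length"
    correct = sum(a == b for a, b in zip(src, sample))
    return correct / len(src) * 100

# (input, predicted output) pairs printed in the training log at steps 600 and 800
step600 = ([5, 4, 5, 5, 2, 4, 5, 3], [5, 5, 4, 3, 5, 4, 5, 2])
step800 = ([4, 3, 2, 3, 3, 3, 3, 2], [3, 2, 3, 3, 4, 3, 2, 3])

print(token_accuracy(*step600), token_accuracy(*step800))  # → 37.5 25.0
```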
