## Let's pull in a Transformer architecture and build with it!

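Great job making it through the Transformers lecture and the hands-on session! I first ran into transformers myself in the previous cohort, and because the model structure is intricate and the concepts behind it are hard, it took me a long time to understand.. Still, there is hardly any part of contemporary AI technology that does not use transformer models, so even though much of the material is difficult, I did my best to pack in the parts that matter. Well done XD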

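With Transformers, the concepts are one thing, but writing attention code from scratch every time and building the encoder and decoder can feel daunting too! Fortunately, the transformer architecture is so well known that a few lines of code (basically one `click`) are now enough to pull a transformer implementation based on recent papers from `huggingface` or `github` :D In this exercise we will take the transformer structure we built from `x-transformers`, a library that is not the easiest for beginners, and finish by printing an example output sequence!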

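Since the transformer architecture came out in 2017, many models built on it have appeared, and a number of improvements have been proposed for the base model as well. `x-transformers` is a library that lets you apply these improved variants by adding a few lines of code. It is demanding in that you need to keep up with recent papers, but in return it collects code whose performance is backed by those papers.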

That turned out to be more TMI than I expected, but in any case this assignment has nothing you need to write yourselves: just run the code I provide as-is. The semester starts soon, so good luck, everyone XD

# Run this in a Colab GPU environment!

In [1]:
!pip install x-transformers

Collecting x-transformers
  Downloading x_transformers-1.19.1-py3-none-any.whl (27 kB)
Collecting einops>=0.6.1 (from x-transformers)
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 42.2/42.2 kB 2.6 MB/s eta 0:00:00
Installing collected packages: einops, x-transformers
Successfully installed einops-0.6.1 x-transformers-1.19.1


In [2]:
import torch
from x_transformers import XTransformer

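## Load the Transformer architecture
model = XTransformer(
    ## model dimension (paper = 512)
    dim = 16,
    ## encoder vocabulary size (x-transformers README example = 256)
    enc_num_tokens = 16,
    ## number of stacked encoder layers (paper = 6)
    enc_depth = 6,
    ## number of multi-head attention heads (paper = 8)
    enc_heads = 8,
    ## maximum encoder sequence length (x-transformers README example = 1024)
    enc_max_seq_len = 32,
    ## decoder vocabulary size (x-transformers README example = 256)
    dec_num_tokens = 16,
    dec_depth = 6,
    dec_heads = 8,
    ## maximum decoder sequence length (x-transformers README example = 1024)
    dec_max_seq_len = 32,
    ## share the token embedding between encoder and decoder
    tie_token_emb = True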
).cuda()

In [3]:
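## NUM_BATCHES = number of training steps (one batch per step, not epochs)
## BATCH_SIZE = number of samples in one batch
## LEARNING_RATE = learning rate
## GENERATE_EVERY = generate a sample every 100 steps to check accuracy
## NUM_TOKENS = number of token ids used by the data (ids 0 to 5)
## ENC_SEQ_LEN = encoder sequence length
## DEC_SEQ_LEN = decoder sequence length (start token + two copies of src)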

NUM_BATCHES = int(1e3)
BATCH_SIZE = 32
LEARNING_RATE = 3e-4
GENERATE_EVERY  = 100
NUM_TOKENS = 4 + 2
ENC_SEQ_LEN = 8
DEC_SEQ_LEN = 16 + 1
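
As a quick sanity check, the sketch below verifies in plain Python that these constants fit the model configured earlier: token ids must stay below the vocabulary size, and sequences must not exceed the maximum lengths. The `MODEL_*` names are hypothetical and simply mirror the values passed to `XTransformer`. The target length assumes the layout used by the data generator in this notebook, one start token followed by two copies of the source.

```python
# Values copied from this notebook's hyperparameter cell.
NUM_TOKENS = 4 + 2
ENC_SEQ_LEN = 8
DEC_SEQ_LEN = 16 + 1

# Hypothetical names mirroring the XTransformer arguments.
MODEL_NUM_TOKENS = 16
MODEL_ENC_MAX_SEQ_LEN = 32
MODEL_DEC_MAX_SEQ_LEN = 32

# Target layout used by the data generator: [start] + src + src.
target_len = 1 + 2 * ENC_SEQ_LEN

assert NUM_TOKENS <= MODEL_NUM_TOKENS          # every token id has an embedding
assert ENC_SEQ_LEN <= MODEL_ENC_MAX_SEQ_LEN    # source fits the encoder
assert target_len == DEC_SEQ_LEN               # 17 tokens per target
assert DEC_SEQ_LEN <= MODEL_DEC_MAX_SEQ_LEN    # target fits the decoder
print(target_len)
```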

In [4]:
## Generate random src, tgt, and src_mask to feed the Transformer model

def cycle():
    while True:
        prefix = torch.ones((BATCH_SIZE, 1)).long().cuda()
        src = torch.randint(2, NUM_TOKENS, (BATCH_SIZE, ENC_SEQ_LEN)).long().cuda()
        tgt = torch.cat((prefix, src, src), 1)
        src_mask = torch.ones(BATCH_SIZE, src.shape[1]).bool().cuda()
        yield (src, tgt, src_mask)
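
To see what a single training pair looks like without a GPU, here is a minimal plain-Python sketch of the same copy task. `make_pair` is a hypothetical helper, not part of `x-transformers`, and it uses lists and the standard `random` module in place of tensors.

```python
import random

NUM_TOKENS = 4 + 2
ENC_SEQ_LEN = 8
START_TOKEN = 1  # plays the same role as `prefix` in cycle()

def make_pair(rng):
    """Return one (src, tgt) pair for the copy task (hypothetical helper)."""
    # Data tokens are drawn from 2 .. NUM_TOKENS - 1, like torch.randint(2, NUM_TOKENS, ...).
    src = [rng.randrange(2, NUM_TOKENS) for _ in range(ENC_SEQ_LEN)]
    # The target is the start token followed by the source repeated twice.
    tgt = [START_TOKEN] + src + src
    return src, tgt

src, tgt = make_pair(random.Random(0))
print(src)
print(tgt)
```

Ids 0 and 1 never appear in `src`, which is why `NUM_TOKENS` is written as `4 + 2`: four data tokens plus two reserved ids.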

In [5]:
## Train Model

import tqdm

## use a name that does not shadow the torch.optim module
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

## create the batch generator once and reuse it
data = cycle()

## train while updating the loss
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    src, tgt, src_mask = next(data)

    loss = model(src, tgt, mask=src_mask)
    loss.backward()
    print(f'{i}: {loss.item()}')

    optimizer.step()
    optimizer.zero_grad()

    ## measure accuracy every N (100) steps
    if i != 0 and i % GENERATE_EVERY == 0:
        model.eval()
        src, _, src_mask = next(data)
        src, src_mask = src[:1], src_mask[:1]
        start_tokens = torch.ones((1, 1)).long().cuda()

        sample = model.generate(src, start_tokens, ENC_SEQ_LEN, mask = src_mask)
        ## number of generated tokens that differ from the input
        incorrects = (src != sample).sum()

        print("input:  ", src)
        print("predicted output:  ", sample)
        print(f"incorrects: {incorrects}")

training:   0%|          | 0/1000 [00:00<?, ?it/s]

0: 2.700308084487915
1: 2.240046262741089
2: 2.082610845565796
3: 2.0302839279174805
4: 1.950796365737915
5: 1.9366321563720703
6: 1.9304484128952026
7: 1.8915820121765137
8: 1.9082778692245483
9: 1.8950530290603638
10: 1.903541922569275
11: 1.8849380016326904
12: 1.8615903854370117
13: 1.8616070747375488
14: 1.8663768768310547
15: 1.8151081800460815
16: 1.8269948959350586
17: 1.7989776134490967
18: 1.7762219905853271
19: 1.8355995416641235
20: 1.7632486820220947
21: 1.7997957468032837
22: 1.7665623426437378
23: 1.796034574508667
24: 1.7555702924728394
25: 1.7421404123306274
26: 1.7792954444885254
27: 1.7626760005950928
28: 1.7506550550460815
29: 1.7146005630493164
30: 1.757443904876709
31: 1.7261422872543335
32: 1.7469799518585205
33: 1.6877354383468628
34: 1.664712905883789
35: 1.68332839012146
36: 1.7092509269714355
37: 1.7096048593521118
38: 1.7060552835464478
39: 1.7173749208450317
40: 1.6967332363128662
41: 1.6413441896438599
42: 1.6868735551834106
43: 1.6824721097946167
44: 1.62

training:   8%|▊         | 80/1000 [00:10<01:55,  7.99it/s]

76: 1.5607818365097046
77: 1.5770814418792725
78: 1.5605520009994507
79: 1.5653212070465088
80: 1.5562183856964111
81: 1.5688064098358154
82: 1.5373679399490356
83: 1.563483715057373
84: 1.5488319396972656
85: 1.5493639707565308
86: 1.5115169286727905
87: 1.5328844785690308
88: 1.507446527481079
89: 1.5081279277801514
90: 1.5060762166976929
91: 1.5118932723999023
92: 1.484879732131958
93: 1.505183458328247
94: 1.4855293035507202
95: 1.5027811527252197
96: 1.484879970550537
97: 1.5067646503448486
98: 1.5097023248672485
99: 1.471545934677124
100: 1.47261381149292
input:   tensor([[3, 2, 3, 5, 4, 4, 5, 5]], device='cuda:0')
predicted output:   tensor([[5, 5, 3, 5, 4, 5, 5, 3]], device='cuda:0')
incorrects: 4
101: 1.4969087839126587
102: 1.4651380777359009
103: 1.470493197441101
104: 1.4818735122680664
105: 1.488682746887207
106: 1.4736229181289673
107: 1.4737321138381958
108: 1.47532320022583
109: 1.4914863109588623
110: 1.4362200498580933
111: 1.457634687423706
112: 1.4292559623718262
...

training:  24%|██▍       | 240/1000 [00:20<00:59, 12.69it/s]

236: 1.0902880430221558
237: 1.081052303314209
238: 1.0549606084823608
239: 1.0760844945907593
240: 1.0840264558792114
241: 1.0698133707046509
242: 1.0095059871673584
243: 1.0663758516311646
244: 1.0559353828430176
245: 1.0896846055984497
246: 1.0810338258743286
247: 1.1003810167312622
248: 1.049586296081543
249: 1.053931474685669
250: 1.0233408212661743
251: 1.0351519584655762
252: 1.0026068687438965
253: 1.0492510795593262
254: 1.0312682390213013
255: 1.0866893529891968
256: 1.0536963939666748
257: 1.165036678314209
258: 1.0329264402389526
259: 1.0781999826431274
260: 1.1418991088867188
261: 1.0525782108306885
262: 1.1282700300216675
263: 1.079772710800171
264: 1.0890165567398071
265: 1.0571612119674683
266: 1.0806219577789307
267: 1.0026655197143555
268: 0.9941757917404175
269: 1.052908182144165
270: 0.9783361554145813
271: 0.9724865555763245
272: 1.0585675239562988
273: 0.9573220014572144
274: 0.9888461232185364
275: 0.9448914527893066
276: 0.960413932800293
277: 0.955420970916748


training:  44%|████▍     | 442/1000 [00:30<00:34, 16.10it/s]

441: 0.601811408996582
442: 0.610418975353241
443: 0.6077300906181335
444: 0.6065456867218018
445: 0.598670482635498
446: 0.5911808013916016
447: 0.6210311651229858
448: 0.5992686748504639
449: 0.5920094847679138
450: 0.5882865786552429
451: 0.6107645630836487
452: 0.5904985070228577
453: 0.6043109893798828
454: 0.5935861468315125
455: 0.61053466796875
456: 0.5911848545074463
457: 0.568984866142273
458: 0.5940187573432922
459: 0.5673584342002869
460: 0.5650113821029663
461: 0.5935786366462708
462: 0.6045711636543274
463: 0.592810869216919
464: 0.5792158246040344
465: 0.5718005895614624
466: 0.5556021332740784
467: 0.5702114701271057
468: 0.5479696393013
469: 0.6133038997650146
470: 0.5605663061141968
471: 0.607807993888855
472: 0.6504755020141602
473: 0.5908305048942566
474: 0.5824753046035767
475: 0.6090877652168274
476: 0.5661080479621887
477: 0.648739218711853
478: 0.5733109712600708
479: 0.5824947953224182
480: 0.5828505158424377
481: 0.5962070822715759
482: 0.5691430568695068
...

training:  65%|██████▍   | 649/1000 [00:40<00:19, 17.90it/s]

648: 0.47215867042541504
649: 0.47126367688179016
650: 0.483359694480896
651: 0.478702574968338
652: 0.4698694050312042
653: 0.46474602818489075
654: 0.4606288969516754
655: 0.4636602997779846
656: 0.4768472909927368
657: 0.4347838759422302
658: 0.4796478748321533
659: 0.46682673692703247
660: 0.47811317443847656
661: 0.4527319073677063
662: 0.43469133973121643
663: 0.45078423619270325
664: 0.4661328196525574
665: 0.46269136667251587
666: 0.447790265083313
667: 0.45798540115356445
668: 0.4305976331233978
669: 0.4625517427921295
670: 0.4697646498680115
671: 0.47915899753570557
672: 0.45719626545906067
673: 0.4540376365184784
674: 0.4600057899951935
675: 0.4484625458717346
676: 0.45437952876091003
677: 0.4759160280227661
678: 0.457138329744339
679: 0.45998746156692505
680: 0.4599747657775879
681: 0.4547473192214966
682: 0.44557660818099976
683: 0.4645410478115082
684: 0.47604161500930786
685: 0.44341298937797546
686: 0.42317017912864685
687: 0.44265058636665344
688: 0.4468829929828644
...

training:  86%|████████▌ | 860/1000 [00:50<00:07, 19.04it/s]

855: 0.43740659952163696
856: 0.40555137395858765
857: 0.42771580815315247
858: 0.42791157960891724
859: 0.41551193594932556
860: 0.4333088994026184
861: 0.40359586477279663
862: 0.3822217285633087
863: 0.4295499324798584
864: 0.4313490390777588
865: 0.40638571977615356
866: 0.4157223701477051
867: 0.41544803977012634
868: 0.3959784507751465
869: 0.4214736223220825
870: 0.42876091599464417
871: 0.4261586666107178
872: 0.401374489068985
873: 0.4101544916629791
874: 0.44430652260780334
875: 0.3925076723098755
876: 0.4062388241291046
877: 0.43048059940338135
878: 0.4050147831439972
879: 0.39779597520828247
880: 0.3920194208621979
881: 0.4086243212223053
882: 0.39729341864585876
883: 0.4435228407382965
884: 0.3947218656539917
885: 0.41359686851501465
886: 0.3983170688152313
887: 0.40293169021606445
888: 0.42927759885787964
889: 0.44767579436302185
890: 0.4112171530723572
891: 0.4098142087459564
892: 0.4218529462814331
893: 0.43339261412620544
894: 0.40432342886924744
895: 0.434281885623931

training: 100%|██████████| 1000/1000 [00:57<00:00, 17.49it/s]

997: 0.33476829528808594
998: 0.3896484673023224
999: 0.36537671089172363
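
One way to read the first loss values: assuming the returned loss is a mean cross-entropy over the decoder vocabulary (an assumption about the wrapper, not something the log states), a model that guesses uniformly over `dec_num_tokens = 16` ids would score ln(16). The short calculation below compares that baseline with the value logged at step 0, and also shows the loss of guessing uniformly among only the four data tokens.

```python
import math

DEC_NUM_TOKENS = 16                      # dec_num_tokens from the model cell
NUM_DATA_TOKENS = 4                      # ids 2..5 produced by the data generator
first_logged_loss = 2.700308084487915    # step 0 in the training log

# Cross-entropy of a uniform guess over the whole vocabulary.
uniform_loss = math.log(DEC_NUM_TOKENS)
# Cross-entropy of a uniform guess restricted to the data tokens.
data_only_loss = math.log(NUM_DATA_TOKENS)

print(f"uniform over 16 ids: {uniform_loss:.4f}")
print(f"uniform over 4 ids:  {data_only_loss:.4f}")
print(f"gap to first logged loss: {uniform_loss - first_logged_loss:.4f}")
```

Under that assumption, a loss below the four-token baseline suggests the model has started to use the source sequence rather than just the token frequencies.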





In [6]:
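## incorrects comes from the last evaluation step, which compares ENC_SEQ_LEN tokens of a single sample,
## so divide by ENC_SEQ_LEN rather than len(src), which is the batch size here
print(f'Accuracy : {(ENC_SEQ_LEN - incorrects) / ENC_SEQ_LEN * 100}%')

Accuracy : 25.0%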
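
Token-level accuracy can also be recomputed by hand from any sample printed in the log. The sketch below uses the input and prediction printed at step 100. `token_accuracy` is a hypothetical helper written with plain lists rather than tensors.

```python
def token_accuracy(expected, predicted):
    """Return (number of mismatches, accuracy in percent) for two equal-length sequences."""
    if len(expected) != len(predicted):
        raise ValueError("sequences must have the same length")
    incorrect = sum(e != p for e, p in zip(expected, predicted))
    accuracy = (len(expected) - incorrect) / len(expected) * 100
    return incorrect, accuracy

# Sample printed at step 100 of the training log.
src_tokens = [3, 2, 3, 5, 4, 4, 5, 5]
predicted_tokens = [5, 5, 3, 5, 4, 5, 5, 3]

incorrect, accuracy = token_accuracy(src_tokens, predicted_tokens)
print(incorrect, accuracy)  # 4 50.0
```

The mismatch count of 4 agrees with the `incorrects: 4` line in the log. Since this is a single sample of eight tokens, averaging over a whole batch would give a steadier estimate.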
