In [2]:
import gym
import torch
%load_ext autoreload
%autoreload 2

In [3]:
from models.policy_gradient import VanillaPG

In [7]:
# Create the LunarLander environment that the agent will be trained on.
env = gym.make('LunarLander-v2')

# Hyperparameters passed to VanillaPG as its second argument.
# NOTE(review): the meaning of each key is defined inside VanillaPG (not visible
# here); "layers" presumably lists hidden-layer widths and "step_size" the
# learning rate -- confirm against models/policy_gradient.py.
pg_config = {
	"policy": {"layers": [4]},
	"step_size": 1e-2,
	"batch_size": 2048,
	"gamma": 0.99,
}
agent = VanillaPG(env, pg_config)

# Initial reset so that `observation` and `info` exist once this cell has run.
observation, info = env.reset(return_info=True)

In [9]:
from tqdm import tqdm

# Training schedule: total number of episodes to run, and the interval (in
# episodes) at which an episode is drawn on screen.
episodes = 20000
render_every = 1000

agent.reset()
for epoch in tqdm(range(episodes)):
	observation, info = env.reset(return_info=True)
	done = False
	while not done:
		# The environment hands back NumPy arrays; the agent is given a float32 tensor.
		observation = torch.from_numpy(observation).float()
		action = agent.get_action(observation)
		n_observation, reward, done, _ = env.step(action)
		# The odd exponent keeps the reward's sign but rescales its magnitude
		# drastically (a reward of -100 becomes -1e10, one of 0.5 becomes ~0.03).
		# NOTE(review): confirm this shaping is intentional.
		reward = reward ** 5
		# The agent receives the tensor observation and the raw (NumPy) successor.
		agent.update(observation, action, reward, n_observation, done)
		observation = n_observation
		# Every step of a rendered episode is drawn, including its final step.
		if epoch % render_every == 0:
			env.render()

  0%|          | 0/20000 [00:00<?, ?it/s]

0.8470598214007918
1.2360754951929789
1.4196856490270602
1.6493809260174725
1.8393961699504768
2.0085479488478213
1.539644426138126
0.3450662858904263
0.15397014421071845
-0.7774400750051029
-0.20424374321171854
-1.6063808332095164
-0.31188935825698993
-0.5215133886122192
-2.4123162427964915
-2.254198916554968
-0.49449068833175036
-2.589587018580174
-1.4786990455242517
-0.8802632088747362
-1.815272244479587
-1.2175924780765104
1.4675726082437677
-1.4123810650674204
-2.921241678674362
-1.7517696658210855
-3.16817173378456
-2.7991552387376815
-2.571288974520171
-2.656803639024018
2.9576174120497116
-2.1527246303206184
1.8078538092479561
-2.0242964515662707
0.5812294835378793
-2.066685334336681
-1.6135932698194029
3.412401489551695
2.59480553122192
2.9992627101762084
5.335448555203397
5.718021125664177
-1.323296934933353
3.9218161066639654
-1.676047934198009
-1.4016268338013151
-0.9550831003895439
5.819293442751598
4.48621546754016
-1.2727327597788178
-1.2095887552523255
-0.60039391639781

  0%|          | 1/20000 [00:03<21:05:38,  3.80s/it]

-0.3456507538801219
1.7694777318953225
13.756286447063674
4.986190855314291
-100
-0.4842529151374879
1.0225319199575449
1.212791652503115
1.695881608974646
-0.7414096498261802
1.264681531365967
1.7356145388460231
1.9933454199199343
0.7619876314779503
-0.7158905317895392
1.1970395985104585
1.6176069764055125
1.5488867042767833
-0.32940899813803526
-0.32309016793692424
-0.6594020116117167
-2.3352849482861004
-0.6117035948023204
-2.402699487921096
-2.251886173449178
-0.9409765873113816
-1.7698243584980844
-2.348193108198386
0.29178618326947686
-1.1626286895811393
-2.601259208491824
2.3451423501888486
-1.7135131708223525
-2.341954141410723
-1.9254937113935842
1.9385545467992642
1.8971107923330066
-0.9762599541138275
-1.5023925056584346
-1.1725809423557212
3.444812329627763
-1.515507891948572
3.5682611836030107
2.560094408852007
-2.0335357612498512
-2.398369232688195
-1.801557949087936
2.1814565441336358
4.02684180455513
-2.304807565613173
-1.7031692372277973
3.3918304956087413
-2.135737717

  0%|          | 2/20000 [00:05<14:58:18,  2.70s/it]

-2.307990092048243
8.856773725980805
11.351017476567412
1.9227469907935035
0.3019177209473405
-0.0469658353915576
-0.4369366005245634
0.3424493515105269
0.12930230982554328
-0.13397518478413303
-0.0007762238379000752
-0.1988842705397775
0.1441862192955299
-0.9468087657800244
-0.5474317361333056
0.5590603049093907
0.6106041841582031
-0.12714416143206606
-0.39559726535648987
0.44061543088640076
-0.533835505635962
0.38274098308235094
0.24589921125310887
-0.34486820445918265
0.32981756622566216
-0.43860943181638556
-0.5337022546607997
-0.5746359680065047
1.4376239813434648
-0.46404275965951025
-0.5948972510924613
-0.13110739127194534
-0.5726796203235949
1.3470465571426218
-0.5194532879635243
0.5059601849155975
-0.5126796036522918
0.18297597953067538
-0.05647553624386312
-0.07709011919987738
0.3014361032159165
-1.7889396051617663
0.1492215274154092
1.3649438382474621
-3.879128600513563
-18.19544418123774
-1.4590765307446076
11.403693420589189
-1.6154025755389163
12.121897514052305
-0.199974

  0%|          | 3/20000 [00:07<13:40:02,  2.46s/it]

-2.1006977216785154
3.0054610941113196
0.8520677650228532
-0.25379073281962405
-2.743706771774298
1.6721803584529205
-2.8073219675731047
-3.380827797469698
12.466935827881334
2.7428133049926275
-0.46007285601377634
-3.4714613724881453
11.888531023112357
1.3736370325261298
1.8766232465370436
-0.9936211954891963
-0.9948400271917694
-1.6902608643672739
1.414387513871003
-2.003885839447405
-2.2122968984384657
-1.9503050659527321
3.793470991332339
-1.6266879935370184
0.49382131808293367
3.2252491203209503
0.9107910461589028
-1.6890910840962192
3.795883096077131
1.8970378011744458
-1.9183773071172812
-1.487987965973075
0.59852283171939
2.565784174729129
-1.6628332299676731
2.601796637303676
2.403239421897905
2.254817061825247
-1.2942631603155224
4.323572790096972
-1.6333694003516246
1.1283461601625902
3.430046955035192
2.9511893784123915
0.9110632266327741
2.642802005996953
-2.0100339161982093
0.9548233037485374
-2.367011867011938
-0.13424058547461754
-2.6086302943025417
3.625121055761963
-2

  0%|          | 4/20000 [00:09<11:56:16,  2.15s/it]

-3.663508955176895
-17.702648993767937
-1.9475625513410335
9.414491267426751
-1.9283566383747963
12.5840058083844
0.7197247624133076
1.3936609103950037
-0.002191899729686142
-3.88777813680496
2.4834170921256606
1.170018277859904
-0.28496879911888295
0.30928131691289806
-0.5428525296571347
0.5606609148588877
-1.172577896461893
0.5969931348684512
-0.6391918627573154
0.5697536812330846
-0.6647985160675709
-0.5482526325191255
-0.2963100118734221
-2.4318497056950603
-18.495803278223605
0.18305587142840032
-2.2139937470269495
23.80784978395302
-3.157878372091134
-18.774641920455302
-1.8133742273987223
9.380668030612119
12.026955260505634
-2.253830432602686
1.9259614679838155
1.8823728147184013
0.21781376792381835
-0.05703746519986197
-0.5292456010214497
0.5653433472963929
-0.3938917697806008
-0.011114063131384383
-1.9296864408948349
-13.364776310977879
-8.274375600353356
-1.2551342057387818
0.9014047964696488
-0.2340365747359285
-2.3035359712228622
13.832000171907373
11.91772804743254
-0.275

  0%|          | 5/20000 [00:09<7:59:42,  1.44s/it] 

0.602999203134645
1.7195605600146677
1.1272393595346
-1.0590421338370521
2.2291922402898736
1.185041427192769
0.9883637784781041
2.193834876262838
-0.38741154830449775
1.0213935058855952
1.6426078716982186
1.7620652083642596
-2.0376057835219425
0.2969159513396676
-3.0814606044741866
-2.6071962731192273
2.3189036836100714
0.9252304029238274
-1.1373639071926391
0.9156718194482323
-2.5454491836539277
-2.363940038343428
1.1201484999265545
-1.0469106572006115
1.908175321985891
0.2244758794278255
-2.146847618525355
-2.0984298447263954
0.2007968111286516
1.192620212715778
-2.018719120319274
3.313700153308088
0.3083047844969798
1.2642850874941416
4.241545242028576
-1.7329603747799769
3.422495551185034
-1.664115485203708
-1.3279449843202957
-1.4383962762511675
3.85358610661157
2.8785847077750306
2.6325594582824694
-1.2340655005632175
2.6547237857017594
-0.01642510411518458
4.511527767268956
4.287654822136727
5.157734885785513
-0.15020786471052247
2.9014780397273343
1.569257891303249
-1.04941600

  0%|          | 6/20000 [00:09<5:43:01,  1.03s/it]

1.099710561416441
-0.38307854993199725
-0.46442511730733144
-0.1340761623431092
1.8117439965453002
-0.6330264630716613
2.8533288679774786
3.160035672234284
-2.6906364186552723
3.459561097006085
-0.17387946581685582
-2.5950809872912672
1.052791033923026
0.13691498159645904
-2.24418379102505
-2.28709628449151
1.8516189478274867
-1.9814879201124274
2.0949474476707737
0.19590123202646909
-1.9022940825902583
2.791438500331293
2.16572319418749
-1.3112346589299761
1.313262631973015
0.9671066607281034
-1.9546655500768526
-1.2433604246643029
1.6757168033961392
1.0958355271565778
3.2569540417035343
2.8469046475423285
3.2179695930948755
2.211438705531063
-1.8509422771763713
3.709136925049873
-1.6153945886013605
1.7456962368410671
-1.3783575643604695
-1.174268361188452
1.9455325996098594
2.379646714662198
3.482365617285933
4.862798014931411
1.2159964615534562
0.907874732231879
-0.9679694117941924
-0.8490606341519396
5.630896567289466
-0.9378964189907251
4.328065140492176
-1.1230085799131178
1.4870

  0%|          | 7/20000 [00:11<7:14:09,  1.30s/it]

-0.5298844977419421
-9.9896212923566
10.244739264768304
-2.163434962030513
11.911720654473983
1.3159892674852653
-0.42396889425137974
-0.7374262889942464
2.4437309648380765
0.16485025485933782
0.03489633927574268
0.3551951266470059
-0.9821708424948146
0.5645952494929993
0.1118208512737251
-0.03523155428320557
-3.161728519149425
-7.86731468874421
-2.961790600902356
-8.026405856103183
-1.9963755500748634
9.405197607039376
11.626549160981716
2.195796883558321
0.03268021359015208
-0.45738018687711235
0.56815362696946
-4.180525355839234
-17.183047456818283
-3.1357644943368674
0.921773924408485
-3.220420953661221
1.389369023664786
1.0075466512583688
-0.7769623438208757
-2.070740111388974
8.72128854065039
11.275467424613186
1.4409969445940807
2.712209070051221
-1.416013133858162
-1.1837661428109019
-3.6232650694575073
-7.9514718267877305
1.867808945151084
0.7465521226885354
-0.5462263357838981
0.3223308835067183
-9.218364271350659
1.4378128469539047
-0.8667170709288985
1.4835267727142718
1.62

  0%|          | 7/20000 [00:12<9:36:44,  1.73s/it]

3.626850212695939
2.614768452897988
1.1707508675180918
-1.8201762276287639
3.1572215358434734
3.1700392275079112
-2.490175753425858
0.3844989382787503
0.7612645153790651
-2.583489891583737
-1.540957406540939
1.291046096180051
-1.3410516232417098
-1.0369378152099944
2.575897362745093
3.454753668815937
4.577246301423625
-0.845287249482509
3.6325924198964517
2.9114311461695737
-0.4830202054678512
-1.6263223297557363
3.239243463913593
4.466723045941291
-0.24925400724775557
0.7723134711776083
-1.5936468809364885
3.795780257769505
0.07961382036176873
2.577394139211987
1.3299968273694105
2.323588934267261
-0.7727286536010922
1.690305517517072
3.360661282311182
-0.8650181144378024
1.9040889301136346
0.1271259665392293
3.2172956202557517
1.0297847615601483
0.298576130395891
-1.78278831639971
1.9153604875502424
-2.157427544365602
-1.6558428924251913
-2.126514578718434
0.4520580542921067
1.4841642442344323
1.2780505196625682
1.6348993694352287
0.05353367325911479
-0.5665752213382931
0.45013458777




KeyboardInterrupt: 

In [6]:
# Close the environment once training has finished or been interrupted.
# Run this cell after the training cell; `env` must be re-created with
# gym.make before training again.
env.close()