In [1]:
%matplotlib notebook
from mpl_toolkits import mplot3d


import numpy as np
import matplotlib.pyplot as plt
import pybullet as p
import pybullet_data 
import time
import math
import random


In [2]:
physicsClient = p.connect(p.GUI)

In [3]:
p.setGravity(0,0,-10) 
p.resetSimulation() 
p.setAdditionalSearchPath(pybullet_data.getDataPath()) 
planeId = p.loadURDF("plane.urdf") 
robotId = p.loadURDF("iiwa7.urdf",flags=9, useFixedBase=1)

robotStartPos = [0,0,0]
robotStartOrientation = p.getQuaternionFromEuler([0,0,0])

p.resetBasePositionAndOrientation(robotId,robotStartPos,robotStartOrientation)

p.setJointMotorControlArray(robotId,range(7),p.VELOCITY_CONTROL,forces=np.zeros(7))

In [4]:
def simulate_system(x, u):
    x_next=[]
    for i in range(7):
        p.resetJointState(robotId,i,x[i],targetVelocity = x[i+7])
    
    p.setJointMotorControlArray(robotId,range(7), controlMode=p.TORQUE_CONTROL,forces=u)
    p.stepSimulation()
    for i in range(7):
        x_next.append(p.getJointStates(robotId,range(7))[i][0])
    for i in range(7):
        x_next.append(p.getJointStates(robotId,range(7))[i][1])
    x_next = np.array(x_next)
    return x_next

In [5]:
N=5
x = np.zeros([14,N])
u= np.zeros([7, N])
x_new=np.zeros([14,N])

In [6]:
for i in range(N):
    for j in range(7):
        a=random.randint(-180,180)#angule range(-pi,pi)
        b=random.uniform(-1, 10)#velocity range
        c=random.randint(-200,200)#torque range (-200,200)
        x[j,i]=math.radians(a/math.pi)
        x[j+7,i]=b
        u[j,i]=c
print(x)
print(u)

[[-5.66666667e-01  4.27777778e-01 -3.88888889e-02  3.94444444e-01
   2.66666667e-01]
 [ 6.55555556e-01 -2.27777778e-01  4.33333333e-01  9.33333333e-01
  -8.00000000e-01]
 [-6.16666667e-01 -3.61111111e-01 -2.38888889e-01 -1.94444444e-01
   7.27777778e-01]
 [-6.38888889e-01  6.22222222e-01 -6.61111111e-01 -2.50000000e-01
  -7.72222222e-01]
 [-7.00000000e-01 -9.00000000e-01  5.55555556e-03  9.16666667e-01
  -9.61111111e-01]
 [-3.72222222e-01  1.66666667e-01  7.16666667e-01  9.50000000e-01
  -2.61111111e-01]
 [ 5.83333333e-01 -5.94444444e-01  8.61111111e-01  7.33333333e-01
   7.16666667e-01]
 [ 2.83586616e+00  2.17248658e+00  7.98017623e+00  6.20267979e+00
   4.50269610e+00]
 [-2.04000061e-01  9.81500574e+00  5.06423678e+00  7.85705333e+00
   4.67239498e+00]
 [ 5.50254033e+00  3.84184786e+00  4.86672695e+00  3.38582452e-01
   3.55383663e+00]
 [ 9.36845923e+00  1.53405861e+00  8.94775787e+00  7.17375577e+00
   7.99424330e+00]
 [ 6.06092341e+00  1.96397650e+00  3.73234433e+00  7.07712279e+00

In [7]:
for i in range(N):
    x_new[:,i]=simulate_system(x[:,i], u[:,i])
print("new state is:", x_new)

new state is: [[-5.61095167e-01  4.80660387e-01 -1.75394828e-02  4.26643019e-01
   2.88307446e-01]
 [ 6.52617291e-01 -1.88149492e-01  4.55148212e-01  9.69206463e-01
  -7.76902732e-01]
 [-5.85444608e-01 -4.53334056e-01 -1.97013402e-01 -2.33913759e-01
   7.38411339e-01]
 [-6.07194076e-01  6.56204514e-01 -6.29452580e-01 -2.12402165e-01
  -7.40364803e-01]
 [-7.50126780e-01 -5.40930991e-01  3.41543799e-02  9.17551965e-01
  -9.73254165e-01]
 [-4.32613547e-01  2.88309438e-01  6.43563617e-01  1.03860478e+00
  -3.29286258e-01]
 [ 8.35422105e-01 -1.01111111e+00  1.14481798e+00  1.15000000e+00
   3.68499835e-01]
 [ 1.33715980e+00  1.26918262e+01  5.12385747e+00  7.72765801e+00
   5.19378708e+00]
 [-7.05183579e-01  9.51078868e+00  5.23557078e+00  8.60955123e+00
   5.54334430e+00]
 [ 7.49329411e+00 -2.21335067e+01  1.00501168e+01 -9.47263558e+00
   2.55205459e+00]
 [ 7.60675518e+00  8.15575001e+00  7.59804748e+00  9.02348047e+00
   7.64578065e+00]
 [-1.20304273e+01  8.61765621e+01  6.86371783e+00  

In [8]:
p.disconnect()

In [9]:
def compute_error(b,w,x,x_new,N):
    totalError = 0
    for i in range(0,N):
        xx = x[:,i]
        yy = x_new[:,i]
        totalError += ( yy - (w @ xx + b)) ** 2
    return totalError / float(N)

In [10]:
def step_gradient(b_current, w_current,x,x_new,learningRate):
    b_gradient = np.zeros([1,14])
    w_gradient = np.zeros([14,14])
    float(N)
    for i in range(0, N):
        xx = x[:,i]
        yy = x_new[:,i]
        b_gradient += -(2/N) * (yy - ((w_current @ xx)+ b_current))
        print(yy- ((w_current @ xx)+ b_current))
        w_gradient += -(2/N) * xx.reshape((14,1)) @ (yy- ((w_current @ xx)+ b_current))
    new_b = b_current - (learningRate * b_gradient)
    new_w = w_current - (learningRate * w_gradient)
    return [new_b, new_w]

In [11]:
def gradient_descent_runner(x,x_new,starting_b,starting_w,learning_rate,num_iterations):
    b = starting_b
    w = starting_w
    for i in range(num_iterations):
        b,w = step_gradient(b,w,x,x_new,learning_rate)
    return [b,w]

In [12]:
learning_rate = 0.0001
initial_b = np.zeros([1,14])
initial_w = np.zeros([14,14])
num_iterations = 1000
[b,w] = gradient_descent_runner(x,x_new,initial_b,initial_w,learning_rate,num_iterations)
error = compute_error(b,w,x,x_new,N)

    

[[ -0.56109517   0.65261729  -0.58544461  -0.60719408  -0.75012678
   -0.43261355   0.8354221    1.3371598   -0.70518358   7.49329411
    7.60675518 -12.0304273  -14.49391787  60.50130514]]
[[   0.48066039   -0.18814949   -0.45333406    0.65620451   -0.54093099
     0.28830944   -1.01111111   12.6918262     9.51078868  -22.1335067
     8.15575001   86.17656207   29.19426513 -100.        ]]
[[-1.75394828e-02  4.55148212e-01 -1.97013402e-01 -6.29452580e-01
   3.41543799e-02  6.43563617e-01  1.14481798e+00  5.12385747e+00
   5.23557078e+00  1.00501168e+01  7.59804748e+00  6.86371783e+00
  -1.75447320e+01  6.80896486e+01]]
[[  0.42664302   0.96920646  -0.23391376  -0.21240216   0.91755197
    1.03860478   1.15         7.72765801   8.60955123  -9.47263558
    9.02348047   0.21247166  21.26514637 100.        ]]
[[  0.28830745  -0.77690273   0.73841134  -0.7403648   -0.97325416
   -0.32928626   0.36849983   5.19378708   5.5433443    2.55205459
    7.64578065  -2.91433293 -16.36203534 -83.5600

   -5.16663076 -24.21858161 -29.43902714  54.92624441]]
[[  0.92711001  -3.16574525   0.9930965    0.9323314   -3.51856643
   -1.61585021  -2.05856789   6.33951749  10.41101139 -19.91546375
    3.28534602  79.72824805  22.16655103 -98.69182528]]
[[ -1.40554426   0.23961537   0.40978887  -2.50822924  -0.47389203
   -0.96242529   2.50417027   1.59266993 -14.57587675  10.73544687
   13.05632961   5.86695656 -20.97239397  63.55408849]]
[[ -1.6234602   -0.2878369    1.46812751  -3.57878104  -0.42903683
   -1.79737995   3.54360334   2.79343908 -23.1418962   -8.0220242
   18.4453456   -3.11683558  16.15744005  91.84215796]]
[[ -0.63301839  -1.97848888   1.68057422  -1.95727825  -2.25369927
   -2.16752026   0.91082899  -0.4905248  -10.26190353   3.07576207
    8.22078698  -7.24579793 -22.68384254 -87.24744045]]
[[ -1.14386797  -1.73606139   1.0069492   -0.84830116  -2.23693368
   -2.95672271   0.11009556 -13.38897006 -20.64782831   2.05025734
   -5.29039329 -24.3398326  -29.62172919  54.876105

[[ -3.23531552   0.58610534   0.32082557  -4.71676522  -1.78860494
   -2.90010885   4.1326122    0.32262347 -32.96037653  16.84171959
   26.19932792  10.32296942 -22.54976477  62.26246949]]
[[ -3.81624037  -0.24729773   1.98193873  -7.27317734  -2.37632947
   -4.19383817   6.54436156   6.11816908 -44.27652218   1.89411526
   42.39263051   5.10913461  19.86272874  90.18583364]]
[[ -1.9884274   -2.36675553   1.92114461  -3.55282136  -3.83790261
   -4.08514701   1.87143499  -3.50590712 -25.3404086    7.75873367
   16.10466094  -6.11841174 -26.39431625 -88.48503952]]
[[ -2.34056312  -3.41575993   1.81251934  -1.10745932  -3.91897606
   -5.78565554  -0.69653905 -29.42292059 -44.4322171   -1.25774507
  -16.74704537 -35.11874332 -47.28399323  50.05385126]]
[[  0.92622225  -4.17576671   1.59186661  -0.23616277  -5.7426519
   -2.56738256  -1.15028238  11.86374312  16.13759834 -11.40592742
   17.61558954  86.35896994  28.7969587  -94.43165586]]
[[ -3.27429048   0.59461483   0.31665185  -4.760155

   40.9506807   -1.28895426 -39.16207822 -93.93447801]]
[[  -7.17554929   -6.23002787    2.83737885   -2.2740549    -7.18641348
   -13.75951784   -2.55698506  -78.12676808 -128.26796146  -12.49739849
   -50.89556264  -66.6877729  -103.10813941   32.17311635]]
[[ -0.55439027  -3.3511865    1.567464    -5.83535811  -9.93680298
   -3.87379605   4.38431566  41.46982017  27.40755569  19.18539087
   81.78894081 120.39540018  61.91848862 -82.33012114]]
[[ -9.63805334   2.16184372  -0.53962603 -11.44818264  -5.61648363
   -9.23291264   8.87851507  -7.84532642 -99.20204465  34.78518474
   62.5049231   20.93328076 -34.17608434  56.06294669]]
[[ -10.63209644    1.90027196    1.59727227  -17.57479703   -7.53236454
   -10.31339911   15.1543118    19.78274002 -100.81318348   35.28241539
   117.24599414   35.61472096   32.89653421   89.10116678]]
[[ -7.02014102  -2.00640281   1.67701198  -8.75520154  -7.59194193
   -9.70076993   5.23524475 -13.08922451 -79.86714952  21.61430655
   41.38260798  -1.202

  -7.39077209e+01 -1.16303178e+02]]
[[ -22.79090416  -10.28705615    4.26935076   -7.32508173  -11.95515173
   -34.62880128   -5.46580872 -211.72199395 -388.76940844  -49.79613781
  -143.23073793 -154.37002978 -258.88854804  -28.28099488]]
[[ -7.12645514   4.25029095  -1.05760822 -25.32576464 -19.0037321
   -5.50434274  24.58758968 147.47053428  50.84256651 117.5045215
  299.5470873  237.76389759 176.18601088 -49.05369487]]
[[ -29.03212464    7.54239745   -3.01950998  -31.4581466   -14.39029244
   -26.96234788   23.18191866  -34.35580432 -308.74084399   79.58697819
   159.13873171   43.10257917  -74.08914553   28.71087999]]
[[ -30.22866131   10.84809409   -1.56483637  -46.05242342  -19.47540495
   -25.27045585   39.5717772    64.81165014 -257.43471973  129.49548401
   328.58349633  124.23878002   72.48621705   85.02320706]]
[[-2.26179947e+01  1.49056702e+00  1.29890304e-01 -2.48634087e+01
  -1.53509281e+01 -2.45351979e+01  1.65026734e+01 -3.68045909e+01
  -2.50553990e+02  5.76625618e+0

[[ -82.66089772   37.77523911  -10.88530499 -122.15586873  -45.69117621
   -61.46255937  105.95687098  193.02422831 -681.65472637  370.43892616
   885.53038326  355.66799121  182.9984297    63.30330449]]
[[ -65.41402318   13.83960866   -4.0159485   -70.29728201  -30.7718692
   -61.68605706   49.86420677  -90.87811286 -726.10400017  145.00971704
   323.93015108   53.45389074 -159.22054155 -190.30148637]]
[[  -66.84632499   -16.17189323     7.46888458   -24.36843635
    -17.3015638    -87.58470334    -8.63614644  -559.37782691
  -1119.87954814  -161.66108977  -381.53660903  -387.1540269
   -665.73284943  -207.80659139]]
[[-27.37162863  30.32931466 -10.63686689 -82.93000942 -40.97843196
   -7.28107443  85.19412564 466.50883225 110.7658103  401.46221116
  938.18767202 581.8403443  515.07931592  40.78112593]]
[[ -82.81623687   23.64083116   -8.62112113  -87.84684282  -32.72633821
   -73.30663373   64.52041098 -106.07149677 -909.07955253  184.91297697
   411.90809675   90.43881981 -184.27398

    495.97114181   -23.89743266]]
[[ -183.86379434    52.04819201   -14.07132565  -199.28216165
    -63.26728291  -158.96898649   147.42799092  -221.88688434
  -2060.37227902   364.33907963   900.26351646   161.6016794
   -370.83365192  -415.89328115]]
[[ -188.37514605   -25.3181655     16.77061817   -76.59962015
    -19.66215918  -225.65940055    -9.03530077 -1478.46719445
  -3141.44155275  -486.02476427 -1007.8940359  -1011.49678257
  -1739.77018053  -721.31715351]]
[[ -84.80655065  107.94105247  -38.88964763 -247.13163418  -97.56279069
    -7.84533104  259.33515269 1388.23366282  282.00600368 1201.06074588
  2753.7538155  1559.33542282 1490.37502832  286.89038568]]
[[ -230.48374365    70.40544458   -20.96071665  -245.50431462
    -71.23149378  -195.39263707   182.51475582  -296.96234967
  -2594.95881241   437.17633018  1085.98623596   200.4031436
   -477.72190951  -356.31881865]]
[[ -229.11340374   117.08771359   -35.85064642  -336.36625072
   -107.13808677  -156.40354662   294.9742

  -4.97187743e+03 -2.35238112e+03]]
[[-2.64192545e+02  3.58871698e+02 -1.27921916e+02 -7.68398650e+02
  -2.64805778e+02  1.24178180e-01  8.15435306e+02  4.34473642e+03
   8.46737917e+02  3.71398888e+03  8.51015305e+03  4.65981598e+03
   4.62015236e+03  1.05109628e+03]]
[[ -687.32082281   221.78462709   -51.73345303  -740.91058522
   -162.7135086   -561.75852493   559.29551981  -872.16178021
  -7896.75681485  1131.49510595  3137.67416238   506.72352816
  -1354.25087264 -1330.53351797]]
[[ -678.20021603   368.1284733   -106.94398519  -998.76571233
   -267.63000434  -435.39975276   884.31138054  1700.9269166
  -5687.92446109  2895.2357102   7093.85455897  2884.24643069
   1501.83951321  -369.50773624]]
[[ -556.74070943   180.71171878   -41.08467661  -613.74511747
   -142.52469893  -453.81461529   467.34668185  -595.60496666
  -6308.84836875   999.63736343  2724.8024693    500.15921406
   -982.1871493  -1176.88793624]]
[[-5.71179692e+02 -4.14673927e+01  4.92653845e+01 -2.52854955e+02
  -2.

  12375.88247693  2904.16513553]]
[[ -1790.58470262    599.52435155   -113.77529337  -1950.84140357
    -337.8603866   -1428.56117885   1489.76045489  -2235.3419243
  -20840.57726804   2665.69574539   8050.0503678    1205.62322864
   -3408.76298501  -3798.29350046]]
[[ -1761.93015255    986.00707727   -267.20458163  -2607.46353674
    -607.76220249  -1090.66831238   2323.41526707   4458.36071651
  -14997.32182819   7283.75614328  18286.4607496    7402.66185028
    3953.0062651   -1320.27405044]]
[[ -1452.0210725     502.71936653    -97.47763957  -1623.07279115
    -295.92495504  -1144.75207406   1256.56975537  -1436.30465763
  -16592.98155085   2431.82507153   7124.19740082   1314.37123258
   -2360.07318646  -3084.68109461]]
[[ -1490.21092922    -62.84281982    133.89301277   -694.24284885
      77.39715179  -1634.83595153     73.58557426 -10974.04266979
  -24924.02146768  -4163.32364481  -7431.38833141  -7546.03382544
  -12779.43176954  -6429.7782336 ]]
[[ -705.91481739   992.7295458 

   -6592.56635122  -9442.9820866 ]]
[[ -4460.18446816    -93.47438437    424.88210386  -2163.25826015
     419.24728268  -4787.23381404    349.1181978  -32327.95085769
  -74789.8042337  -12783.28946328 -21819.2527061  -22308.15695501
  -37515.824625   -19666.62666943]]
[[-2121.67969702  3066.98984278 -1048.68007456 -6312.24159627
  -1900.61195808   196.32869052  6770.70846333 36147.6215893
   7224.56771564 30092.18480495 69673.82034042 37663.98988083
  38401.24296244  9021.40756502]]
[[ -5419.24607417   1874.82702921   -287.38540352  -5965.48330056
    -798.71782424  -4235.6584771    4603.40953653  -6653.35891887
  -63758.14289508   7353.34300961  24124.02650778   3430.91634716
   -9990.78609288 -12190.55613951]]
[[ -5325.94670176   3048.12550871   -763.08085051  -7924.30248447
   -1605.20953079  -3203.47125439   7099.2970654   13516.89535048
  -46022.13770243  21298.54653375  54884.51184162  22113.22715806
   12104.55406663  -4754.33834233]]
[[ -4395.95475349   1594.19614584   -260.88

   -21425.19440867  -33269.08611914]]
[[-1.53290367e+04 -9.76507064e+01  1.54264304e+03 -7.66044899e+03
   1.91307677e+03 -1.62054852e+04  1.53212643e+03 -1.09902699e+05
  -2.57669021e+05 -4.48313790e+04 -7.39606507e+04 -7.60349380e+04
  -1.27151931e+05 -6.85386504e+04]]
[[ -7292.18154237  10778.00722392  -3604.89916008 -21988.11593548
   -6316.867497      909.85384523  23671.69387078 126584.45199385
   25862.00423753 104132.1822046  242561.37769656 131092.7130498
  134707.82737826  31365.46542498]]
[[ -18630.59626485    6620.28703631    -832.22504433  -20686.81345665
    -2139.14115751  -14330.23324048   16098.90081823  -22546.52092343
  -221004.38443943   23372.66776045   82455.82551834   11385.61825863
   -33401.22145579  -43519.66219436]]
[[ -18304.00081741   10646.94676133   -2474.08332454  -27361.6463708
    -4876.69677543  -10779.64197329   24615.98049445   46415.06599824
  -160248.31754993   71052.37763146  187501.68063307   75256.77121525
    42040.10021816  -18159.6043078 ]]


In [13]:
print(w@x[:,0]+b)

[[ 3.65741032e+04 -1.78869714e+01 -3.79021718e+03  1.85438487e+04
  -5.11429448e+03  3.83923420e+04 -4.04347609e+03  2.60914754e+05
   6.15618828e+05  1.08085330e+05  1.75317884e+05  1.80696281e+05
   3.01338683e+05  1.64622221e+05]]


In [14]:
print(x_new[:,0])
print(x[:,0])

[ -0.56109517   0.65261729  -0.58544461  -0.60719408  -0.75012678
  -0.43261355   0.8354221    1.3371598   -0.70518358   7.49329411
   7.60675518 -12.0304273  -14.49391787  60.50130514]
[-0.56666667  0.65555556 -0.61666667 -0.63888889 -0.7        -0.37222222
  0.58333333  2.83586616 -0.20400006  5.50254033  9.36845923  6.06092341
  2.78884156  2.71086552]
