In [1]:
import json
import pandas as pd
import numpy as np

In [2]:
from scipy.spatial.distance import cdist

In [3]:
with open('offices.json', encoding='utf-8') as f:
    offices = json.load(f)

In [4]:
kremlin_lat = 55.752
kremlin_lon = 37.618
degree_to_meter = 111319
moscow_rad = 70000 / degree_to_meter
office_latitudes = []
office_longitudes = []
max_people = []
opening_times = [[] for _ in range(7)]
closing_times = [[] for _ in range(7)]
business_morning = [[] for _ in range(7)]
business_day = [[] for _ in range(7)]
business_evening = [[] for _ in range(7)]

In [5]:
ord_to_day = {0: 'mn', 1: 'ts', 2: 'wd', 3: 'th', 4: 'fr', 5: 'st', 6: 'sn'}
day_to_ord = {'пн': 0, 'вт': 1, 'ср': 2, 'чт': 3, 'пт': 4, 'сб': 5, 'вс': 6}
ans = []
for obj in offices:
    if  'Не' in obj['openHours'][0]['days']:
        continue

    if len(obj['openHours']) == 7:
        ans.append((obj['openHours'] + [float(obj['latitude']), float(obj['longitude'])]).copy())
        continue

    new_dict_list = list()
    for l in obj['openHours']:
        if l["days"] == 'перерыв':
            for i in range(len(new_dict_list)):
                new_dict_list[i]['hours'] = new_dict_list[i]['hours'][:5] + '-' + l['hours'][:5] + ', ' + l['hours'][6:] + '-' + new_dict_list[i]['hours'][6:]
        elif len(l["days"]) > 2:
            s_front = l["days"][:2]
            s_back = l["days"][3:]

            for i in range(day_to_ord[s_front], day_to_ord[s_back] + 1):
                if i + 1 <= len(new_dict_list):
                    continue
                new_dict_list.append({'days': ord_to_day[i], "hours": l["hours"]})
        else:
            new_dict_list.append({"days": l["days"], "hours": l["hours"]})
    ans.append((new_dict_list + [float(obj['latitude']), float(obj['longitude'])]).copy())

In [6]:
ans

[[{'days': 'пн', 'hours': '09:00-18:00'},
  {'days': 'вт', 'hours': '09:00-18:00'},
  {'days': 'ср', 'hours': '09:00-18:00'},
  {'days': 'чт', 'hours': '09:00-18:00'},
  {'days': 'пт', 'hours': '09:00-17:00'},
  {'days': 'сб', 'hours': 'выходной'},
  {'days': 'вс', 'hours': 'выходной'},
  56.184479,
  36.984314],
 [{'days': 'mn', 'hours': '09:00-18:00'},
  {'days': 'ts', 'hours': '09:00-18:00'},
  {'days': 'wd', 'hours': '09:00-18:00'},
  {'days': 'th', 'hours': '09:00-18:00'},
  {'days': 'пт', 'hours': '09:00-17:00'},
  {'days': 'st', 'hours': 'выходной'},
  {'days': 'sn', 'hours': 'выходной'},
  56.183239,
  36.9757],
 [{'days': 'пн', 'hours': '10:00-19:00'},
  {'days': 'вт', 'hours': '10:00-19:00'},
  {'days': 'ср', 'hours': '10:00-19:00'},
  {'days': 'чт', 'hours': '10:00-19:00'},
  {'days': 'пт', 'hours': '10:00-18:00'},
  {'days': 'сб', 'hours': 'выходной'},
  {'days': 'вс', 'hours': 'выходной'},
  56.012386,
  37.482059],
 [{'days': 'пн', 'hours': '09:00-18:00'},
  {'days': 'вт'

In [7]:
for elem in ans:
    office_latitudes.append(elem[7])
    office_longitudes.append(elem[8])
    max_people.append(np.random.randint(1000, 3000))
    for i in range(7):
        opening_times[i].append(elem[i]['hours'][:5] if elem[i]['hours'] != 'выходной' else None)
        closing_times[i].append(elem[i]['hours'][-5:] if elem[i]['hours'] != 'выходной' else None)
        business_morning[i].append(round(np.random.uniform(0.01, 0.5), 2))
        business_day[i].append(round(np.random.uniform(0.01, 1), 2))
        business_evening[i].append(round(np.random.uniform(0.01, 0.5), 2))

In [8]:
len(office_latitudes)

160

In [9]:
len(office_longitudes)

160

In [10]:
len(max_people)

160

In [11]:
for i in range(7):
    print(len(opening_times[i]))
    print(len(closing_times[i]))
    print(len(business_morning[i]))
    print(len(business_day[i]))
    print(len(business_evening[i]))

160
160
160
160
160
160
160
160
160
160
160
160
160
160
160
160
160
160
160
160
160
160
160
160
160
160
160
160
160
160
160
160
160
160
160


In [12]:
db = pd.DataFrame()
db['Latitude'] = pd.Series(office_latitudes)

In [13]:
db

Unnamed: 0,Latitude
0,56.184479
1,56.183239
2,56.012386
3,56.010849
4,56.008335
...,...
155,55.478329
156,55.438740
157,55.432919
158,55.427883


In [14]:
db['Longitude'] = pd.Series(office_longitudes)
db['Max_People'] = pd.Series(max_people)
for i in range(7):
    db[f'opening_{ord_to_day[i]}'] = pd.Series(opening_times[i]).str[:2].fillna(0).astype(int)
    db[f'closing_{ord_to_day[i]}'] = pd.Series(closing_times[i]).str[:2].fillna(0).astype(int)
    db[f'morning_{ord_to_day[i]}'] = pd.Series(business_morning[i])
    db[f'day_{ord_to_day[i]}'] = pd.Series(business_day[i])
    db[f'evening_{ord_to_day[i]}'] = pd.Series(business_evening[i])

In [15]:
db

Unnamed: 0,Latitude,Longitude,Max_People,opening_mn,closing_mn,morning_mn,day_mn,evening_mn,opening_ts,closing_ts,...,opening_st,closing_st,morning_st,day_st,evening_st,opening_sn,closing_sn,morning_sn,day_sn,evening_sn
0,56.184479,36.984314,2879,9,18,0.24,0.24,0.16,9,18,...,0,0,0.20,0.88,0.09,0,0,0.13,0.20,0.46
1,56.183239,36.975700,2155,9,18,0.48,0.24,0.13,9,18,...,0,0,0.27,0.03,0.40,0,0,0.17,0.42,0.32
2,56.012386,37.482059,1423,10,19,0.13,0.93,0.30,10,19,...,0,0,0.13,0.45,0.38,0,0,0.31,0.88,0.04
3,56.010849,37.854359,1412,9,18,0.02,0.15,0.32,9,18,...,0,0,0.06,0.02,0.18,0,0,0.11,0.05,0.20
4,56.008335,37.851467,2390,9,18,0.41,0.37,0.33,9,18,...,0,0,0.21,0.61,0.21,0,0,0.34,0.23,0.06
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
155,55.478329,37.298706,1258,9,18,0.11,0.91,0.36,9,18,...,0,0,0.04,0.55,0.35,0,0,0.47,0.21,0.12
156,55.438740,37.767536,1928,9,18,0.19,0.75,0.17,9,18,...,0,0,0.19,0.49,0.38,0,0,0.14,0.62,0.15
157,55.432919,37.550838,1484,9,18,0.50,0.58,0.22,9,18,...,0,0,0.48,0.16,0.03,0,0,0.17,0.91,0.30
158,55.427883,37.550110,1259,9,18,0.19,0.90,0.05,9,18,...,0,0,0.38,0.96,0.15,0,0,0.18,0.56,0.46


In [16]:
examples_quantity = 8192
examples = pd.DataFrame()
usr_lat = []
usr_lon = []
days = [np.random.randint(0, 5) for i in range(examples_quantity)]
time = [np.random.uniform(9, 19) for i in range(examples_quantity)]
for i in range(examples_quantity):
    rho = np.random.uniform(0, moscow_rad)
    phi = np.random.uniform(0, 2 * np.pi)
    usr_lat.append(kremlin_lat + rho * np.sin(phi))
    usr_lon.append(kremlin_lon + rho * np.cos(phi))

In [17]:
examples['Latitude'] = usr_lat
examples['Longitude'] = usr_lon
examples['day'] = days
examples['time'] = time

In [18]:
examples

Unnamed: 0,Latitude,Longitude,day,time
0,55.529277,37.404163,0,17.145107
1,55.737108,37.631739,2,12.323854
2,55.916487,38.074830,3,16.221863
3,56.044643,37.516204,4,10.518354
4,55.861014,37.635104,4,17.448565
...,...,...,...,...
8187,55.855223,37.182133,0,12.594759
8188,55.664094,37.826010,3,16.869592
8189,56.078248,37.943674,0,9.726419
8190,55.259491,37.525568,4,18.889654


In [19]:
db

Unnamed: 0,Latitude,Longitude,Max_People,opening_mn,closing_mn,morning_mn,day_mn,evening_mn,opening_ts,closing_ts,...,opening_st,closing_st,morning_st,day_st,evening_st,opening_sn,closing_sn,morning_sn,day_sn,evening_sn
0,56.184479,36.984314,2879,9,18,0.24,0.24,0.16,9,18,...,0,0,0.20,0.88,0.09,0,0,0.13,0.20,0.46
1,56.183239,36.975700,2155,9,18,0.48,0.24,0.13,9,18,...,0,0,0.27,0.03,0.40,0,0,0.17,0.42,0.32
2,56.012386,37.482059,1423,10,19,0.13,0.93,0.30,10,19,...,0,0,0.13,0.45,0.38,0,0,0.31,0.88,0.04
3,56.010849,37.854359,1412,9,18,0.02,0.15,0.32,9,18,...,0,0,0.06,0.02,0.18,0,0,0.11,0.05,0.20
4,56.008335,37.851467,2390,9,18,0.41,0.37,0.33,9,18,...,0,0,0.21,0.61,0.21,0,0,0.34,0.23,0.06
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
155,55.478329,37.298706,1258,9,18,0.11,0.91,0.36,9,18,...,0,0,0.04,0.55,0.35,0,0,0.47,0.21,0.12
156,55.438740,37.767536,1928,9,18,0.19,0.75,0.17,9,18,...,0,0,0.19,0.49,0.38,0,0,0.14,0.62,0.15
157,55.432919,37.550838,1484,9,18,0.50,0.58,0.22,9,18,...,0,0,0.48,0.16,0.03,0,0,0.17,0.91,0.30
158,55.427883,37.550110,1259,9,18,0.19,0.90,0.05,9,18,...,0,0,0.38,0.96,0.15,0,0,0.18,0.56,0.46


In [20]:
points1 = db.loc[:, ['Latitude', 'Longitude']].values

In [21]:
points2 = examples.loc[:, ['Latitude', 'Longitude']].values

In [22]:
distances = cdist(points1, points2)
distances_df = pd.DataFrame(distances, index=db.index, columns=['dist' + str(idx) for idx in examples.index])
distances_df

Unnamed: 0,dist0,dist1,dist2,dist3,dist4,dist5,dist6,dist7,dist8,dist9,...,dist8182,dist8183,dist8184,dist8185,dist8186,dist8187,dist8188,dist8189,dist8190,dist8191
0,0.778179,0.786956,1.122962,0.549965,0.726745,0.895984,0.669559,0.888853,0.702914,0.726085,...,1.183028,0.284927,1.291589,0.162997,0.718722,0.384112,0.989572,0.965223,1.071708,0.916433
1,0.781823,0.793360,1.131036,0.557991,0.733923,0.903521,0.675242,0.896134,0.710768,0.732950,...,1.189901,0.293188,1.298833,0.168925,0.725530,0.387568,0.996262,0.973651,1.075018,0.922537
2,0.489348,0.313340,0.600478,0.046973,0.215259,0.370942,0.270041,0.369456,0.177080,0.231316,...,0.672675,0.248706,0.770612,0.378892,0.228072,0.338608,0.489499,0.466289,0.754151,0.448585
3,0.659233,0.352836,0.239816,0.339839,0.265562,0.166023,0.454854,0.214793,0.203015,0.306096,...,0.442535,0.602323,0.484061,0.748336,0.315852,0.690005,0.347912,0.111892,0.820147,0.409448
4,0.655421,0.349062,0.241510,0.337223,0.261756,0.162915,0.451036,0.211525,0.199719,0.302269,...,0.441224,0.600009,0.483846,0.745693,0.312027,0.686623,0.345181,0.115715,0.816686,0.406088
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
155,0.117120,0.421756,0.891264,0.606645,0.509521,0.636114,0.324424,0.589658,0.609908,0.466030,...,0.722299,0.655545,0.863681,0.630556,0.457131,0.394510,0.559069,0.880844,0.315208,0.438975
156,0.374482,0.327818,0.568042,0.655962,0.442553,0.413511,0.413289,0.368202,0.546494,0.413084,...,0.290553,0.857304,0.425877,0.918847,0.414401,0.718439,0.232816,0.663321,0.301129,0.196407
157,0.175494,0.314764,0.713026,0.612704,0.436310,0.493430,0.315955,0.442746,0.550471,0.394098,...,0.488725,0.756341,0.629352,0.783207,0.389830,0.560610,0.359390,0.755493,0.175259,0.252131
158,0.177711,0.319818,0.716983,0.617692,0.441392,0.498076,0.320745,0.447389,0.555554,0.399186,...,0.491122,0.760731,0.631615,0.786963,0.394918,0.563938,0.363203,0.760176,0.170171,0.256487


In [23]:
examples['day'].values

array([0, 2, 3, ..., 0, 4, 2], dtype=int64)

In [24]:
diff_matrix = [[0] * 8192 for _ in range(160)]
days = examples['day'].values
coming = examples['time'].values
for i in range(160):
    for j in range(8192):
        closing_hours = db[f'closing_{ord_to_day[days[j]]}'].values
        diff_matrix[i][j] = 1 / max(closing_hours[i] - coming[j], 0.0001)

In [25]:
diff_df = pd.DataFrame(diff_matrix, index=db.index, columns=['diff' + str(idx) for idx in examples.index])

In [26]:
diff_df

Unnamed: 0,diff0,diff1,diff2,diff3,diff4,diff5,diff6,diff7,diff8,diff9,...,diff8182,diff8183,diff8184,diff8185,diff8186,diff8187,diff8188,diff8189,diff8190,diff8191
0,1.169737,0.176176,0.562386,0.154282,10000.000000,0.419787,10000.000000,0.188150,0.163183,0.963802,...,0.172685,0.128274,0.37114,0.162481,0.204318,0.185006,0.884637,0.120867,10000.0,2.914480
1,1.169737,0.176176,0.562386,0.154282,10000.000000,0.419787,10000.000000,0.188150,0.163183,0.963802,...,0.172685,0.128274,0.37114,0.162481,0.204318,0.185006,0.884637,0.120867,10000.0,2.914480
2,0.539115,0.149787,0.359953,0.133660,1.813449,0.295669,6.679396,0.158355,0.140290,0.490784,...,0.147256,0.113691,0.27068,0.139771,0.169654,0.156122,0.469394,0.107833,10000.0,0.744538
3,1.169737,0.176176,0.562386,0.154282,10000.000000,0.419787,10000.000000,0.188150,0.163183,0.963802,...,0.172685,0.128274,0.37114,0.162481,0.204318,0.185006,0.884637,0.120867,10000.0,2.914480
4,1.169737,0.176176,0.562386,0.154282,10000.000000,0.419787,10000.000000,0.188150,0.163183,0.963802,...,0.172685,0.128274,0.37114,0.162481,0.204318,0.185006,0.884637,0.120867,10000.0,2.914480
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
155,1.169737,0.176176,0.562386,0.154282,10000.000000,0.419787,10000.000000,0.188150,0.163183,0.963802,...,0.172685,0.128274,0.37114,0.162481,0.204318,0.185006,0.884637,0.120867,10000.0,2.914480
156,1.169737,0.176176,0.562386,0.154282,10000.000000,0.419787,10000.000000,0.188150,0.163183,0.963802,...,0.172685,0.128274,0.37114,0.162481,0.204318,0.185006,0.884637,0.120867,10000.0,2.914480
157,1.169737,0.176176,0.562386,0.154282,10000.000000,0.419787,10000.000000,0.188150,0.163183,0.963802,...,0.172685,0.128274,0.37114,0.162481,0.204318,0.185006,0.884637,0.120867,10000.0,2.914480
158,1.169737,0.176176,0.562386,0.154282,10000.000000,0.419787,10000.000000,0.188150,0.163183,0.963802,...,0.172685,0.128274,0.37114,0.162481,0.204318,0.185006,0.884637,0.120867,10000.0,2.914480


In [27]:
busy_matrix = [[0] * 8192 for _ in range(160)]
days = examples['day'].values
time = examples['time'].values
for i in range(160):
    for j in range(8192):
        if time[j] < 12:
            business = db[f'morning_{ord_to_day[days[j]]}'].values
        elif time[j] > 15:
            business = db[f'evening_{ord_to_day[days[j]]}'].values
        else:
            business = db[f'day_{ord_to_day[days[j]]}'].values
        busy_matrix[i][j] = business[i]


In [28]:
busy_df = pd.DataFrame(busy_matrix, index=db.index, columns=['busy' + str(idx) for idx in examples.index])

In [29]:
busy_df

Unnamed: 0,busy0,busy1,busy2,busy3,busy4,busy5,busy6,busy7,busy8,busy9,...,busy8182,busy8183,busy8184,busy8185,busy8186,busy8187,busy8188,busy8189,busy8190,busy8191
0,0.16,0.50,0.41,0.01,0.07,0.46,0.46,0.24,0.01,0.07,...,0.01,0.01,0.46,0.01,0.24,0.24,0.41,0.24,0.07,0.27
1,0.13,0.16,0.37,0.30,0.18,0.29,0.29,0.24,0.13,0.18,...,0.30,0.13,0.29,0.30,0.24,0.24,0.37,0.48,0.18,0.46
2,0.30,0.30,0.35,0.43,0.36,0.05,0.05,0.93,0.08,0.36,...,0.43,0.08,0.05,0.43,0.93,0.93,0.35,0.13,0.36,0.17
3,0.32,0.08,0.09,0.22,0.46,0.16,0.16,0.15,0.21,0.46,...,0.22,0.21,0.16,0.22,0.15,0.15,0.09,0.02,0.46,0.02
4,0.33,0.75,0.35,0.17,0.06,0.06,0.06,0.37,0.39,0.06,...,0.17,0.39,0.06,0.17,0.37,0.37,0.35,0.41,0.06,0.06
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
155,0.36,0.83,0.21,0.33,0.35,0.26,0.26,0.91,0.40,0.35,...,0.33,0.40,0.26,0.33,0.91,0.91,0.21,0.11,0.35,0.02
156,0.17,0.12,0.24,0.20,0.04,0.48,0.48,0.75,0.26,0.04,...,0.20,0.26,0.48,0.20,0.75,0.75,0.24,0.19,0.04,0.47
157,0.22,0.87,0.19,0.26,0.22,0.23,0.23,0.58,0.21,0.22,...,0.26,0.21,0.23,0.26,0.58,0.58,0.19,0.50,0.22,0.22
158,0.05,0.89,0.12,0.26,0.38,0.38,0.38,0.90,0.47,0.38,...,0.26,0.47,0.38,0.26,0.90,0.90,0.12,0.19,0.38,0.07


In [30]:
vect = []
for i in range(8192):
    dist = distances_df.loc[:, f'dist{i}']
    diff = diff_df.loc[:, f'diff{i}']
    busy = busy_df.loc[:, f'busy{i}']
    result = pd.concat([dist, diff, busy], axis=1, join="outer")
    result = result.sort_values(by=[f'dist{i}', f'diff{i}', f'busy{i}'])
    vect.append(list(result.loc[[result.index[0], result.index[-1]]].values))

In [31]:
vect

[[array([0.09229051, 0.53911467, 0.02      ]),
  array([0.86316656, 1.1697371 , 0.33      ])],
 [array([0.0074266 , 0.17617587, 0.02      ]),
  array([0.79336001, 0.17617587, 0.16      ])],
 [array([0.07823996, 0.56238629, 0.2       ]),
  array([1.23498749, 0.56238629, 0.31      ])],
 [array([0.04697293, 0.13366043, 0.43      ]),
  array([0.97098638, 0.1542818 , 0.2       ])],
 [array([0.0192495 , 1.81344877, 0.23      ]),
  array([7.91963174e-01, 1.00000000e+04, 4.40000000e-01])],
 [array([0.05397494, 0.29566916, 0.42      ]),
  array([0.96880967, 0.41978732, 0.41      ])],
 [array([1.42253e-02, 1.00000e+04, 8.00000e-02]),
  array([8.37434304e-01, 1.00000000e+04, 1.00000000e-01])],
 [array([0.01982233, 0.15835511, 0.9       ]),
  array([0.9380748 , 0.18814956, 0.49      ])],
 [array([0.08571652, 0.14029032, 0.47      ]),
  array([0.83650099, 0.16318337, 0.27      ])],
 [array([0.02627201, 0.49078358, 0.1       ]),
  array([0.7651165 , 0.96380157, 0.44      ])],
 [array([0.02450419, 0.

In [32]:
X = np.reshape(np.array(vect), (-1, 3))

In [33]:
y = np.reshape(np.array([[0, 1] for _ in range(8192)]), (8192*2))

In [34]:
y = np.reshape(y, 8192*2)

In [35]:
from skopt import BayesSearchCV
from sklearn.datasets import load_digits
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split

In [40]:
from torch.autograd import Variable
from torchvision import datasets
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data_utils
import torch
import matplotlib.pyplot as plt

import os
import pandas as pd
import skimage.io
from skimage.transform import resize


In [98]:
train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('CUDA is not available.  Training on CPU ...')
else:
    print('CUDA is available!  Training on GPU ...')

CUDA is available!  Training on GPU ...


In [99]:
# разные режимы датасета 
DATA_MODES = ['train', 'val', 'test']
# работаем на видеокарте
DEVICE = torch.device("cuda")

EPOCHS = 1000
BATCH_SIZE = 256

In [100]:
class JustSigmoid(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Sequential(
                nn.Linear(3, 1)
            )
    def forward(self, x):
        return self.layer(x)

In [101]:
lean = JustSigmoid().to(DEVICE)
print(lean)

JustSigmoid(
  (layer): Sequential(
    (0): Linear(in_features=3, out_features=1, bias=True)
  )
)


In [102]:
from torch.utils.data import DataLoader, Dataset
criterion = F.mse_loss

optimizer = torch.optim.Adam(lean.parameters(), lr=0.001)
train_losses = []
val_losses = []
X = torch.Tensor(X)
y = torch.Tensor(y)
train_X, val_X, train_y, val_y = train_test_split(X, y, train_size=0.9)
# train_dataset = Dataset(train_X, train_y)
# val_dataset = Dataset(val_X, val_y)
train_loader = DataLoader(list(zip(train_X, train_y)), batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(list(zip(val_X, val_y)), batch_size=BATCH_SIZE, shuffle=True)

In [103]:
from tqdm import notebook

for epoch in notebook.tqdm(range(EPOCHS)):
    lean.train()
    train_losses_per_epoch = []
    for X_batch, y_batch in train_loader:
        torch.cuda.empty_cache()
        optimizer.zero_grad()
        calculated = torch.reshape(lean(torch.Tensor(X_batch).to(DEVICE)), (-1,))
        loss = criterion(calculated, torch.Tensor(y_batch).to(DEVICE))
        loss.backward()
        optimizer.step()
        train_losses_per_epoch.append(loss.item())

    train_losses.append(np.mean(train_losses_per_epoch))

    lean.eval()
    val_losses_per_epoch = []
    with torch.no_grad():
        for X_batch, y_batch in val_loader:
            calculated = torch.reshape(lean(torch.Tensor(X_batch).to(DEVICE)), (-1,))
            loss = criterion(calculated, torch.Tensor(y_batch).to(DEVICE))
            val_losses_per_epoch.append(loss.item())

    val_losses.append(np.mean(val_losses_per_epoch))

  0%|          | 0/1000 [00:00<?, ?it/s]

In [104]:
with torch.no_grad():
    calc = torch.reshape(lean(torch.Tensor([[0.03095769, 0.2211782 , 0.93], [0.9236934 , 0.19648128, 0.55]]).to(DEVICE)), (-1,))

In [105]:
calc

tensor([-0.0135,  0.8823], device='cuda:0')

In [107]:
torch.save(lean.state_dict(), 'weights/lean')

In [109]:
java = torch.jit.trace(lean, torch.Tensor(val_X).to(DEVICE))

In [110]:
java.save('./model1.zip')