## IMPORT

In [2]:
import pandas as pd
import numpy as np
import timm
from PIL import Image
from tqdm import trange, tqdm
from utils import create_space
import geoio
import convert as conv

import torch
import torch.nn as nn
import torch.nn.parallel
import torchvision.transforms as transforms
from torch.utils.data import Dataset
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.models as models

from score_satellite import *

torch.manual_seed(31)
torch.cuda.manual_seed_all(31)
np.random.seed(31)
torch.manual_seed(31)
torch.cuda.manual_seed_all(31)
random.seed(31)

  from .autonotebook import tqdm as notebook_tqdm


# 数据处理

In [4]:
def add_nightlights(df, tif, tif_array):
    ''' 
    按经纬度匹配，计算区域内夜光强度，取均值
    This takes a dataframe with columns cluster_lat, cluster_lon and finds the average 
    nightlights in 2015 using a 10kmx10km box around the point
    
    I try all the nighlights tifs until a match is found, or none are left upon which an error is raised
    '''
    cluster_nightlights = []
    for i,r in df.iterrows():
        min_lat, min_lon, max_lat, max_lon = create_space(r.cluster_lat, r.cluster_lon,s=8)
        
        xminPixel, ymaxPixel = tif.proj_to_raster(min_lon, min_lat)
        xmaxPixel, yminPixel = tif.proj_to_raster(max_lon, max_lat)
        assert xminPixel < xmaxPixel, print(r.cluster_lat, r.cluster_lon)
        assert yminPixel < ymaxPixel, print(r.cluster_lat, r.cluster_lon)
        if xminPixel < 0 or xmaxPixel >= tif_array.shape[1]:
            print(f"no match for {r.cluster_lat}, {r.cluster_lon}")
            raise ValueError()
        elif yminPixel < 0 or ymaxPixel >= tif_array.shape[0]:
            print(f"no match for {r.cluster_lat}, {r.cluster_lon}")
            raise ValueError()
        xminPixel, yminPixel, xmaxPixel, ymaxPixel = int(xminPixel), int(yminPixel), int(xmaxPixel), int(ymaxPixel)
        cluster_nightlights.append(tif_array[yminPixel:ymaxPixel,xminPixel:xmaxPixel].mean())
        
    df['nightlights'] = cluster_nightlights

In [6]:
# 读取图片
images = os.listdir(data_path)
images = [x for x in images if 'png' in x]
data = pd.DataFrame()
data['name'] = images
data['cluster_lat'] = [conv.num2deg(conv.deg2num(float(x.strip('.png').split('_')[1]),float(x.strip('.png').split('_')[0]),12)[1],conv.deg2num(float(x.strip('.png').split('_')[1]),float(x.strip('.png').split('_')[0]),12)[0],12)[0] for x in images]
data['cluster_lon'] = [conv.num2deg(conv.deg2num(float(x.strip('.png').split('_')[1]),float(x.strip('.png').split('_')[0]),12)[1],conv.deg2num(float(x.strip('.png').split('_')[1]),float(x.strip('.png').split('_')[0]),12)[0],12)[1] for x in images]
# 读取夜光数据
NIGHTLIGHTS_DIR = '/home/chenshangbin/SatelliteImage_RGZ/VNL_v2_npp_2020_global_vcmslcfg_c202101211500.average.tif'
tif = geoio.GeoImage(NIGHTLIGHTS_DIR)
tif_array = np.squeeze(tif.get_data())
# 添加
add_nightlights(data, tif, tif_array)
display(data.head())

data[['cluster_lon','cluster_lat']].to_csv('./position.csv',index=False,encoding='utf_8_sig')
data.to_csv('./nightlights.csv',index=False)
data.columns=['y_x','cluster_lat','cluster_lon','nightlights']



Unnamed: 0,name,cluster_lat,cluster_lon,nightlights
0,-48.779296875_-66.93006025862447.png,43.771094,90.966797,0.382294
1,-44.82421875_-75.97355295343337.png,40.84706,120.058594,0.477784
2,-44.82421875_-70.98834922412489.png,40.84706,102.304688,0.273093
3,-39.111328125_-65.4034447883078.png,36.385913,87.275391,0.38557
4,-50.80078125_-77.86034459764656.png,45.213004,128.320312,0.341066


In [7]:
# 降采样使之平衡
def drop_in_range(df, lower=0, upper=2, fr=0.25):
    """
        Very similar to drop_0s calculation, but more generalized. Lower and upper are inclusive.
    """
    boolean_idx = ((lower <= df['nightlights']) & (df['nightlights'] <= upper))
    c_under = boolean_idx.sum()
    n = len(df)
    assert c_under / n > fr, print(f'Dataframe already has under {fr} rows in the given range')
    d = (c_under - n * fr) / (1 - fr)
    d = int(d)
    print(f'dropping: {d}')
    select_df = df[boolean_idx]
    drop_inds = random.sample(select_df.index.tolist(), d)
    return df.drop(drop_inds).reset_index(drop=True)
data_drop2 = drop_in_range(data, lower=0, upper=2, fr=0.6)
len(data_drop2)

dropping: 120822


11130

In [8]:
# 划分三个等级
from sklearn.mixture import GaussianMixture as GMM
X = data_drop2['nightlights'].values.reshape(-1,1)
gmm = GMM(n_components=3).fit(X)
labels = gmm.predict(data_drop2['nightlights'].values.reshape(-1,1))

label0_max = data_drop2['nightlights'][labels==0].max()
label1_max = data_drop2['nightlights'][labels==1].max()
label2_max = data_drop2['nightlights'][labels==2].max()
label0_max, label1_max, label2_max

(0.5020440816879272, 80.57090759277344, 301.58966064453125)

In [9]:
label0_max, label1_max, label2_max = 1, 10, 302

In [10]:
def create_nightlights_bin(df, cutoffs):
    assert len(cutoffs) >= 2, print('need at least 2 bins')
    cutoffs = sorted(cutoffs, reverse=True)
    labels = list(range(len(cutoffs)))[::-1]
    df['nightlights_bin'] = len(cutoffs)
    for cutoff, label in zip(cutoffs, labels):
        df['nightlights_bin'].loc[df['nightlights'] <= cutoff] = label

data_final = data_drop2.copy()
create_nightlights_bin(data_final, cutoffs=[label0_max, label1_max, label2_max])
print(data_final['nightlights_bin'].value_counts()/len(data_final))

0    0.580234
1    0.346721
2    0.073046
Name: nightlights_bin, dtype: float64


A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  iloc._setitem_with_indexer(indexer, value)
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  iloc._setitem_with_indexer(indexer, value)
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  iloc._setitem_with_indexer(indexer, value)


In [11]:
data_final.to_csv('./nightlights_labeled.csv', index = False)
data_final[data_final['nightlights_bin'] == 0].iloc[:,[0,4]].to_csv('./nightlights_labeled0.csv', index = False)
data_final[data_final['nightlights_bin'] == 1].iloc[:,[0,4]].to_csv('./nightlights_labeled1.csv', index = False)
data_final[data_final['nightlights_bin'] == 2].iloc[:,[0,4]].to_csv('./nightlights_labeled2.csv', index = False)

# 微调

In [13]:
model = timm.models.resnet18(pretrained=True)
model.fc = nn.Linear(512, 1)

nature_csv = './nightlights_labeled0.csv'
rural_csv = './nightlights_labeled1.csv'
city_csv = './nightlights_labeled2.csv'
ckpt_nature = finetune(nature_csv, 'nature', data_path, model, hidden, num_epoch=10, bz=8)


In [28]:
ckpt_city = finetune(city_csv, 'city', data_path, model, hidden, num_epoch=10)

100%|██████████| 51/51 [00:04<00:00, 11.08it/s]


k-means loss evolution: [269.95010376 170.42271423 154.23361206 147.87632751 146.07366943
 145.19363403 144.4311676  143.70246887 143.18788147 142.90963745
 142.68092346 142.51338196 142.45046997 142.39775085 142.38806152
 142.38806152 142.38806152 142.38806152 142.38806152 142.38806152]
tensor([1, 6, 5, 4, 5, 5, 6, 7, 5, 4, 3, 2, 7, 0, 7, 5, 6, 1, 0, 0, 0, 6, 1, 6,
        4, 7, 5, 4, 5, 1, 1, 0, 2, 0, 4, 2, 2, 2, 5, 6, 2, 0, 2, 5, 0, 5, 1, 2,
        7, 0, 7, 3, 0, 3, 6, 0, 0, 2, 2, 4, 6, 7, 0, 2, 0, 3, 2, 0, 6, 7, 4, 1,
        1, 1, 6, 4, 2, 6, 5, 0, 1, 3, 2, 5, 5, 2, 6, 4, 2, 3, 0, 1, 3, 5, 0, 2,
        4, 1, 4, 1, 3, 6, 1, 3, 4, 7, 3, 1, 4, 4, 3, 1, 7, 0, 4, 1, 6, 1, 0, 2,
        2, 6, 6, 4, 7, 2, 0, 1, 7, 7, 3, 2, 6, 6, 0, 5, 7, 2, 6, 7, 1, 1, 7, 7,
        7, 6, 6, 0, 2, 6, 1, 2, 6, 4, 3, 7, 1, 7, 5, 6, 0, 3, 4, 0, 7, 1, 2, 4,
        2, 4, 0, 5, 5, 2, 5, 0, 1, 5, 1, 1, 6, 4, 6, 5, 2, 3, 7, 1, 0, 3, 3, 1,
        5, 1, 0, 2, 6, 0, 2, 7, 1, 0, 7, 6, 0, 6, 3, 5, 6, 4, 3, 7, 0, 

[BATCH_IDX : 40 LOSS : 5.6753387451171875 CE_LOSS 1.9650843143463135 AUG_LOSS : 3.710254669189453: 100%|██████████| 50/50 [00:08<00:00,  5.69it/s]


Epoch : 1


[BATCH_IDX : 40 LOSS : 6.290657997131348 CE_LOSS 1.8333112001419067 AUG_LOSS : 4.4573469161987305: 100%|██████████| 50/50 [00:08<00:00,  5.79it/s]


Epoch : 2


[BATCH_IDX : 40 LOSS : 5.875636100769043 CE_LOSS 1.9494022130966187 AUG_LOSS : 3.9262337684631348: 100%|██████████| 50/50 [00:08<00:00,  5.78it/s]


Epoch : 3


[BATCH_IDX : 40 LOSS : 6.126987457275391 CE_LOSS 1.8650805950164795 AUG_LOSS : 4.26190710067749: 100%|██████████| 50/50 [00:08<00:00,  5.86it/s] 


Epoch : 4


[BATCH_IDX : 40 LOSS : 5.780439376831055 CE_LOSS 1.3799078464508057 AUG_LOSS : 4.40053129196167: 100%|██████████| 50/50 [00:08<00:00,  5.79it/s] 


Epoch : 5


[BATCH_IDX : 40 LOSS : 5.981578826904297 CE_LOSS 1.4553724527359009 AUG_LOSS : 4.5262064933776855: 100%|██████████| 50/50 [00:08<00:00,  5.71it/s]


Epoch : 6


[BATCH_IDX : 40 LOSS : 6.461679935455322 CE_LOSS 1.4686241149902344 AUG_LOSS : 4.993055820465088: 100%|██████████| 50/50 [00:08<00:00,  5.76it/s]


Epoch : 7


[BATCH_IDX : 40 LOSS : 4.917314529418945 CE_LOSS 1.3138279914855957 AUG_LOSS : 3.6034862995147705: 100%|██████████| 50/50 [00:08<00:00,  5.76it/s]


Epoch : 8


[BATCH_IDX : 40 LOSS : 6.292108535766602 CE_LOSS 1.3651721477508545 AUG_LOSS : 4.926936626434326: 100%|██████████| 50/50 [00:08<00:00,  5.82it/s]


Epoch : 9


[BATCH_IDX : 40 LOSS : 4.321183204650879 CE_LOSS 1.274330496788025 AUG_LOSS : 3.0468528270721436: 100%|██████████| 50/50 [00:08<00:00,  5.75it/s] 


In [None]:
ckpt_rural = finetune(rural_csv, 'rural', data_path, model, hidden, num_epoch=10)

100%|██████████| 242/242 [00:21<00:00, 11.23it/s]


k-means loss evolution: [950.95263672 579.44433594 557.51623535 554.2220459  553.41162109
 553.06768799 552.86029053 552.73168945 552.64178467 552.57086182
 552.54632568 552.53277588 552.50952148 552.48162842 552.46447754
 552.44732666 552.42730713 552.40118408 552.39770508 552.39770508]
tensor([7, 4, 4,  ..., 3, 0, 3], device='cuda:0')
k-means time: 3 s
(3859, 1280) (3859, 138)
Epoch : 0


[BATCH_IDX : 240 LOSS : 6.5894775390625 CE_LOSS 1.6722487211227417 AUG_LOSS : 4.917228698730469: 100%|██████████| 241/241 [00:42<00:00,  5.73it/s]  


Epoch : 1


[BATCH_IDX : 240 LOSS : 6.773061275482178 CE_LOSS 1.397537112236023 AUG_LOSS : 5.375524044036865: 100%|██████████| 241/241 [00:41<00:00,  5.85it/s]  


Epoch : 2


[BATCH_IDX : 240 LOSS : 5.365902900695801 CE_LOSS 1.06307053565979 AUG_LOSS : 4.302832126617432: 100%|██████████| 241/241 [00:41<00:00,  5.82it/s]  


Epoch : 3


[BATCH_IDX : 240 LOSS : 5.995803356170654 CE_LOSS 1.206323504447937 AUG_LOSS : 4.789479732513428: 100%|██████████| 241/241 [00:40<00:00,  5.98it/s]  


Epoch : 4


[BATCH_IDX : 240 LOSS : 6.216527938842773 CE_LOSS 0.8052971959114075 AUG_LOSS : 5.411230564117432: 100%|██████████| 241/241 [00:41<00:00,  5.82it/s]  


Epoch : 5


[BATCH_IDX : 240 LOSS : 4.3796916007995605 CE_LOSS 0.9812684655189514 AUG_LOSS : 3.398422956466675: 100%|██████████| 241/241 [00:41<00:00,  5.82it/s]


Epoch : 6


[BATCH_IDX : 240 LOSS : 5.700704574584961 CE_LOSS 0.8370754718780518 AUG_LOSS : 4.86362886428833: 100%|██████████| 241/241 [00:40<00:00,  5.89it/s]  


Epoch : 7


[BATCH_IDX : 240 LOSS : 5.3678131103515625 CE_LOSS 1.1099352836608887 AUG_LOSS : 4.257877826690674: 100%|██████████| 241/241 [00:40<00:00,  5.92it/s] 


Epoch : 8


[BATCH_IDX : 240 LOSS : 5.485835075378418 CE_LOSS 1.5133774280548096 AUG_LOSS : 3.9724578857421875: 100%|██████████| 241/241 [00:40<00:00,  5.88it/s] 


Epoch : 9


[BATCH_IDX : 0 LOSS : 5.719106674194336 CE_LOSS 0.7997424602508545 AUG_LOSS : 4.919363975524902:   7%|▋         | 16/241 [00:03<01:01,  3.65it/s]

# 聚类

In [15]:
ckpt_rural = 'finetune/rural.pt'
ckpt_city = 'finetune/city.pt'
ckpt_nature = 'finetune/nature.pt'

hidden = 512

city_cluster = extract_cluster(ckpt_city, city_csv, data_path, batch_size=10, hidden=hidden, classes_num=5, offset=10)
rural_cluster = extract_cluster(ckpt_rural, rural_csv, data_path, batch_size=10, hidden=hidden, classes_num=5, offset=5)
nature_cluster = extract_cluster(ckpt_nature, nature_csv, data_path, batch_size=10, hidden=hidden, classes_num=5, offset=0)
# 聚类情况
total_cluster = city_cluster + rural_cluster + nature_cluster
df = pd.DataFrame(total_cluster,columns=['y_x','cluster_id'])
df.groupby('cluster_id')['y_x'].count()

100%|██████████| 82/82 [00:04<00:00, 16.88it/s]


k-means loss evolution: [692.06567383 359.04537964 317.99319458 297.82156372 278.7286377
 257.70681763 241.53445435 232.88967896 227.38075256 224.72581482
 223.29537964 222.41459656 221.55838013 221.34605408 221.31034851
 221.27575684 221.25610352 221.24967957 221.24967957 221.24967957]
tensor([4, 1, 1, 2, 3, 3, 1, 0, 3, 4, 1, 3, 2, 2, 0, 1, 4, 4, 0, 2, 2, 4, 4, 3,
        0, 0, 1, 4, 1, 4, 4, 2, 2, 2, 4, 2, 3, 2, 1, 1, 2, 0, 2, 3, 2, 3, 1, 2,
        0, 2, 0, 3, 1, 2, 4, 0, 0, 0, 2, 4, 1, 0, 2, 1, 0, 0, 2, 2, 4, 0, 4, 4,
        4, 4, 1, 0, 1, 4, 1, 2, 4, 3, 2, 1, 3, 2, 1, 0, 1, 3, 2, 4, 0, 3, 2, 2,
        4, 1, 0, 4, 0, 1, 1, 0, 0, 2, 1, 4, 0, 4, 3, 4, 0, 2, 4, 4, 4, 3, 2, 3,
        2, 1, 1, 0, 0, 2, 2, 1, 0, 0, 3, 2, 1, 1, 2, 3, 0, 2, 4, 2, 4, 4, 2, 0,
        0, 1, 1, 0, 0, 1, 4, 2, 1, 0, 3, 0, 4, 0, 3, 1, 2, 3, 1, 3, 0, 2, 2, 0,
        0, 3, 2, 3, 1, 2, 3, 2, 4, 1, 4, 4, 4, 4, 1, 3, 2, 0, 0, 1, 2, 0, 3, 4,
        2, 4, 0, 2, 1, 2, 2, 0, 4, 2, 0, 4, 2, 4, 3, 3, 3, 0, 3, 0, 2, 2

100%|██████████| 386/386 [00:23<00:00, 16.55it/s]


k-means loss evolution: [2330.05395508 1065.93664551  953.18487549  929.06390381  922.53259277
  919.55456543  917.35046387  916.22247314  915.64428711  915.26043701
  915.11816406  914.97332764  914.8503418   914.80249023  914.77252197
  914.71911621  914.6887207   914.65869141  914.63250732  914.60552979]
tensor([4, 4, 4,  ..., 3, 0, 3], device='cuda:0')
k-means time: 4 s


100%|██████████| 646/646 [00:39<00:00, 16.49it/s]


k-means loss evolution: [2963.38085938 1798.71105957 1621.92468262 1483.58398438 1406.71118164
 1379.66967773 1369.79431152 1365.26037598 1362.39208984 1360.36425781
 1358.91516113 1357.88366699 1356.96801758 1356.27319336 1355.82788086
 1355.34423828 1355.0090332  1354.79931641 1354.62561035 1354.421875  ]
tensor([0, 2, 4,  ..., 2, 0, 4], device='cuda:0')
k-means time: 3 s


cluster_id
0     1068
1     1133
2     1615
3     1576
4     1066
5      942
6      588
7      655
8      859
9      815
10     162
11     149
12     210
13     126
14     166
Name: y_x, dtype: int64

In [16]:
total_cluster = city_cluster + rural_cluster + nature_cluster
cnum = len(df.groupby('cluster_id')['y_x'].count())
cluster_dir = './data/'
if not os.path.exists(cluster_dir):
    os.makedirs(cluster_dir, exist_ok=True)
for i in range(0, cnum + 1):
    os.makedirs(cluster_dir + str(i), exist_ok=True)

# 聚类情况写入文件
for i in range(0, cnum + 1):
    path=os.path.join(cluster_dir, str(i), 'cluster.csv')
    df[df['cluster_id']==i].to_csv(path, index = False)
df.to_csv('./unified.csv', index = False)

# 训练

In [4]:
class ClusterDataset(Dataset):
    def __init__(self, cluster_list, data_path, transform=None):
        self.file_list = []
        self.transform = transform      
        for cluster_num in cluster_list:
            file=list(pd.read_csv("./data/{}/cluster.csv".format(str(cluster_num)))['y_x'].values)
            self.file_list.extend(file)

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        image = Image.open("./data12/train/data_zl12/{}".format(data_path, self.file_list[idx])).convert('RGB')
        if self.transform:
            image = self.transform(image).squeeze()
        return image

def se_graph_config(ordered_list, name):
    f = open(name, 'w')
        
    for i in range(len(ordered_list) - 1):
        f.write('{}<{}\n'.format(ordered_list[i+1][0], ordered_list[i][0]))
    
    for orders in ordered_list:
        if len(orders) >= 2:
            f.write(str(orders[0]))
            for element in orders[1:]:
                f.write('={}'.format(element))
            f.write('\n')        
    f.close()        
def graph_inference_nightlight(grid_df, nightlight_df, cluster_num, file_path):
    def numeric_compare(x, y):
        pop_list1 = df_merge_group.get_group(x)['nightlights'].tolist()
        pop_list2 = df_merge_group.get_group(y)['nightlights'].tolist()
        tTestResult = stats.ttest_ind(pop_list1, pop_list2)
        if (tTestResult.pvalue < 0.01) and (np.mean(pop_list1) < np.mean(pop_list2)):
            return 1
        elif (tTestResult.pvalue < 0.01) and (np.mean(pop_list1) >= np.mean(pop_list2)):
            return -1
        else:
            return 0
        
    df_merge = pd.merge(nightlight_df, grid_df, how='left', on='y_x')
    df_merge = df_merge.dropna()
    df_merge_group = df_merge.groupby('cluster_id')
    
    sorted_list = sorted(range(cluster_num - 1), key=cmp_to_key(numeric_compare))
    ordered_list = []
    ordered_list.append([sorted_list[0]])
    curr = 0
    for i in range(len(sorted_list) - 1):
        if numeric_compare(sorted_list[i], sorted_list[i+1]) == 0:
            ordered_list[curr].append(sorted_list[i+1])
        else:
            curr += 1
            ordered_list.append([sorted_list[i+1]])
            
    ordered_list.append([cluster_num - 1])        
    se_graph_config(ordered_list, file_path)
    return ordered_list

In [5]:
batch_size = 2
cluster_number = 14
grid_df = pd.read_csv('./unified.csv')
nightlight_df = pd.read_csv('./nightlights_labeled.csv')
graph_config = graph_inference_nightlight(grid_df, nightlight_df, cluster_number, './CNconfig')   
loader_list = generate_loader_dict(range(cluster_number), batch_size)
cluster_path_list=[]
for item in itertools.product(*graph_config):
    a=list(item)[::-1]
    cluster_path_list.append(a)

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 加载自监督预训练权重
model.cuda()
model.to(device)

In [12]:
def train(epoch, model, optimizer, loader_list, cluster_path_list, device,batch_sz):
    model.train()
    train_loss = AverageMeter()
    reg_loss = AverageMeter()
    # For each cluster route
    path_idx = 0
    avg_loss = 0
    count = 0
    for cluster_path in cluster_path_list:
        path_idx += 1
        dataloaders = []
        for cluster_id in cluster_path:
            dataloaders.append(loader_list[cluster_id])
        pbar = tqdm(enumerate(zip(*dataloaders)), total=15*len(dataloaders))
        for batch_idx, data in pbar:
            cluster_num = len(data)
            data_zip = torch.cat(data, 0).to(device)
            # Generating Score
            scores = model(data_zip).squeeze()
            score_list = torch.split(scores, batch_sz, dim = 0)
            # Standard deviation as a loss
            loss_var = torch.zeros(1).to(device)
            for score in score_list:
                loss_var += score.var()
            loss_var /= len(score_list)
            # Differentiable Ranking with sigmoid function
            rank_matrix = torch.zeros((batch_sz, cluster_num, cluster_num)).to(device)
            for itertuple in list(itertools.permutations(range(cluster_num), 2)):
                score1 = score_list[itertuple[0]]
                score2 = score_list[itertuple[1]]
                diff = 30 * (score2 - score1)
                results = torch.sigmoid(diff)
                rank_matrix[:, itertuple[0], itertuple[1]] = results
                rank_matrix[:, itertuple[1], itertuple[0]] = 1 - results
            rank_predicts = rank_matrix.sum(1)
            temp = torch.Tensor(range(cluster_num))
            target_rank = temp.unsqueeze(0).repeat(batch_sz, 1).to(device)
            # Equivalent to spearman rank correlation loss
            loss_train = ((rank_predicts - target_rank)**2).mean()
            loss = loss_train + loss_var * 6
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss.update(loss_train.item(), batch_sz)
            reg_loss.update(loss_var.item(), batch_sz)
            avg_loss += loss.item()
            count += 1
            # Print status
            if batch_idx % 10 == 0:
                pbar.set_description('Epoch: [{epoch}][{path_idx}][{elps_iters}] '
                      'Train loss: {train_loss.val:.4f} ({train_loss.avg:.4f}) '
                      'Reg loss: {reg_loss.val:.4f} ({reg_loss.avg:.4f})'.format(
                          epoch=epoch, path_idx=path_idx, elps_iters=batch_idx, train_loss=train_loss, reg_loss=reg_loss))
    return avg_loss / count

In [14]:
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)
best_loss = float('inf')
for epoch in range(50):          
    loss = train(epoch, model, optimizer, loader_list, cluster_path_list, device, batch_size)
    if epoch % 5 == 0 and epoch != 0:                
        if best_loss > loss:
            torch.save(model, './efficient-unlimited.pt')
            best_loss = loss
            print("best loss: %.4f\n" % (best_loss))
    
torch.save(model, './efficient-last.pt')

Epoch: [0][1][60] Train loss: 3.8935 (6.1770) Reg loss: 0.0817 (0.0840):  42%|████▏     | 63/150 [00:16<00:22,  3.86it/s] 
Epoch: [0][2][60] Train loss: 3.1924 (5.0421) Reg loss: 0.0604 (0.0838):  42%|████▏     | 63/150 [00:16<00:22,  3.88it/s]
Epoch: [0][3][60] Train loss: 1.2501 (4.4108) Reg loss: 0.0523 (0.0787):  42%|████▏     | 63/150 [00:16<00:22,  3.87it/s]
Epoch: [0][4][60] Train loss: 3.4789 (4.0344) Reg loss: 0.0525 (0.0757):  42%|████▏     | 63/150 [00:16<00:23,  3.77it/s]
Epoch: [0][5][60] Train loss: 0.8108 (3.6821) Reg loss: 0.0852 (0.0710):  42%|████▏     | 63/150 [00:16<00:22,  3.85it/s]
Epoch: [0][6][60] Train loss: 2.6006 (3.4577) Reg loss: 0.0362 (0.0686):  42%|████▏     | 63/150 [00:16<00:22,  3.86it/s]
Epoch: [0][7][60] Train loss: 1.2754 (3.3439) Reg loss: 0.0410 (0.0679):  42%|████▏     | 63/150 [00:16<00:22,  3.86it/s]
Epoch: [0][8][60] Train loss: 1.4825 (3.1748) Reg loss: 0.0304 (0.0677):  42%|████▏     | 63/150 [00:16<00:22,  3.84it/s]
Epoch: [0][9][60] Train

best loss: 0.4010



Epoch: [6][1][60] Train loss: 0.1340 (0.3439) Reg loss: 0.0038 (0.0071):  42%|████▏     | 63/150 [00:16<00:22,  3.85it/s]
Epoch: [6][2][60] Train loss: 0.6072 (0.3277) Reg loss: 0.0104 (0.0062):  42%|████▏     | 63/150 [00:16<00:22,  3.90it/s]
Epoch: [6][3][60] Train loss: 0.2428 (0.3167) Reg loss: 0.0029 (0.0064):  42%|████▏     | 63/150 [00:16<00:22,  3.87it/s]
Epoch: [6][4][60] Train loss: 0.0130 (0.3137) Reg loss: 0.0035 (0.0062):  42%|████▏     | 63/150 [00:16<00:23,  3.77it/s]
Epoch: [6][5][60] Train loss: 0.1138 (0.3071) Reg loss: 0.0125 (0.0062):  42%|████▏     | 63/150 [00:16<00:22,  3.89it/s]
Epoch: [6][6][60] Train loss: 0.8333 (0.2895) Reg loss: 0.0200 (0.0062):  42%|████▏     | 63/150 [00:16<00:22,  3.93it/s]
Epoch: [6][7][60] Train loss: 0.6305 (0.2952) Reg loss: 0.0062 (0.0062):  42%|████▏     | 63/150 [00:16<00:22,  3.86it/s]
Epoch: [6][8][60] Train loss: 0.5792 (0.2972) Reg loss: 0.0181 (0.0063):  42%|████▏     | 63/150 [00:16<00:22,  3.90it/s]
Epoch: [6][9][60] Train 

best loss: 0.1494



Epoch: [11][1][60] Train loss: 0.1864 (0.1546) Reg loss: 0.0043 (0.0045):  42%|████▏     | 63/150 [00:16<00:22,  3.90it/s]
Epoch: [11][2][60] Train loss: 0.0313 (0.1276) Reg loss: 0.0022 (0.0042):  42%|████▏     | 63/150 [00:16<00:22,  3.88it/s]
Epoch: [11][3][60] Train loss: 0.0172 (0.1235) Reg loss: 0.0022 (0.0043):  42%|████▏     | 63/150 [00:16<00:23,  3.72it/s]
Epoch: [11][4][60] Train loss: 0.2100 (0.1390) Reg loss: 0.0073 (0.0049):  42%|████▏     | 63/150 [00:16<00:22,  3.88it/s]
Epoch: [11][5][60] Train loss: 0.6350 (0.1631) Reg loss: 0.0106 (0.0059):  42%|████▏     | 63/150 [00:16<00:22,  3.93it/s]
Epoch: [11][6][60] Train loss: 0.2176 (0.1681) Reg loss: 0.0236 (0.0063):  42%|████▏     | 63/150 [00:16<00:22,  3.88it/s]
Epoch: [11][7][60] Train loss: 0.1933 (0.1697) Reg loss: 0.0036 (0.0062):  42%|████▏     | 63/150 [00:16<00:22,  3.92it/s]
Epoch: [11][8][60] Train loss: 0.0339 (0.1639) Reg loss: 0.0026 (0.0060):  42%|████▏     | 63/150 [00:16<00:22,  3.92it/s]
Epoch: [11][9][6

best loss: 0.0728



Epoch: [16][1][60] Train loss: 0.1707 (0.0853) Reg loss: 0.0060 (0.0030):  42%|████▏     | 63/150 [00:16<00:22,  3.83it/s]
Epoch: [16][2][60] Train loss: 0.0068 (0.0675) Reg loss: 0.0014 (0.0029):  42%|████▏     | 63/150 [00:16<00:22,  3.86it/s]
Epoch: [16][3][60] Train loss: 0.0062 (0.0627) Reg loss: 0.0011 (0.0029):  42%|████▏     | 63/150 [00:16<00:22,  3.92it/s]
Epoch: [16][4][60] Train loss: 0.0134 (0.0643) Reg loss: 0.0016 (0.0029):  42%|████▏     | 63/150 [00:16<00:22,  3.78it/s]
Epoch: [16][5][60] Train loss: 0.0874 (0.0588) Reg loss: 0.0048 (0.0027):  42%|████▏     | 63/150 [00:15<00:21,  3.96it/s]
Epoch: [16][6][60] Train loss: 0.0507 (0.0547) Reg loss: 0.0044 (0.0027):  42%|████▏     | 63/150 [00:16<00:22,  3.86it/s]
Epoch: [16][7][60] Train loss: 0.0783 (0.0567) Reg loss: 0.0020 (0.0027):  42%|████▏     | 63/150 [00:16<00:22,  3.90it/s]
Epoch: [16][8][60] Train loss: 0.0387 (0.0544) Reg loss: 0.0037 (0.0026):  42%|████▏     | 63/150 [00:16<00:22,  3.88it/s]
Epoch: [16][9][6

best loss: 0.0444



Epoch: [21][1][60] Train loss: 0.0258 (0.0464) Reg loss: 0.0030 (0.0023):  42%|████▏     | 63/150 [00:16<00:23,  3.75it/s]
Epoch: [21][2][60] Train loss: 0.0048 (0.0354) Reg loss: 0.0014 (0.0020):  42%|████▏     | 63/150 [00:16<00:23,  3.77it/s]
Epoch: [21][3][60] Train loss: 0.0249 (0.0333) Reg loss: 0.0015 (0.0020):  42%|████▏     | 63/150 [00:16<00:22,  3.81it/s]
Epoch: [21][4][60] Train loss: 0.1290 (0.0330) Reg loss: 0.0072 (0.0020):  42%|████▏     | 63/150 [00:16<00:23,  3.75it/s]
Epoch: [21][5][60] Train loss: 0.0368 (0.0335) Reg loss: 0.0022 (0.0020):  42%|████▏     | 63/150 [00:16<00:23,  3.76it/s]
Epoch: [21][6][60] Train loss: 0.0257 (0.0332) Reg loss: 0.0011 (0.0020):  42%|████▏     | 63/150 [00:16<00:22,  3.85it/s]
Epoch: [21][7][60] Train loss: 0.0036 (0.0346) Reg loss: 0.0020 (0.0020):  42%|████▏     | 63/150 [00:16<00:22,  3.86it/s]
Epoch: [21][8][60] Train loss: 0.0041 (0.0336) Reg loss: 0.0009 (0.0020):  42%|████▏     | 63/150 [00:16<00:22,  3.94it/s]
Epoch: [21][9][6

best loss: 0.0356



Epoch: [26][1][60] Train loss: 0.0065 (0.0370) Reg loss: 0.0028 (0.0021):  42%|████▏     | 63/150 [00:16<00:22,  3.82it/s]
Epoch: [26][2][60] Train loss: 0.0011 (0.0283) Reg loss: 0.0016 (0.0020):  42%|████▏     | 63/150 [00:16<00:22,  3.93it/s]
Epoch: [26][3][60] Train loss: 0.0545 (0.0316) Reg loss: 0.0015 (0.0020):  42%|████▏     | 63/150 [00:16<00:22,  3.91it/s]
Epoch: [26][4][60] Train loss: 0.0228 (0.0327) Reg loss: 0.0029 (0.0019):  42%|████▏     | 63/150 [00:16<00:22,  3.92it/s]
Epoch: [26][5][60] Train loss: 0.0013 (0.0315) Reg loss: 0.0003 (0.0020):  42%|████▏     | 63/150 [00:16<00:23,  3.78it/s]
Epoch: [26][6][60] Train loss: 0.0129 (0.0309) Reg loss: 0.0018 (0.0020):  42%|████▏     | 63/150 [00:18<00:24,  3.49it/s]
Epoch: [26][7][60] Train loss: 0.0385 (0.0351) Reg loss: 0.0022 (0.0020):  42%|████▏     | 63/150 [00:18<00:25,  3.43it/s]
Epoch: [26][8][60] Train loss: 0.0058 (0.0346) Reg loss: 0.0017 (0.0020):  42%|████▏     | 63/150 [00:18<00:25,  3.42it/s]
Epoch: [26][9][6

best loss: 0.0211



Epoch: [31][1][60] Train loss: 0.0056 (0.0161) Reg loss: 0.0017 (0.0014):  42%|████▏     | 63/150 [00:17<00:23,  3.69it/s]
Epoch: [31][2][60] Train loss: 0.0074 (0.0145) Reg loss: 0.0007 (0.0014):  42%|████▏     | 63/150 [00:16<00:22,  3.82it/s]
Epoch: [31][3][60] Train loss: 0.0007 (0.0139) Reg loss: 0.0005 (0.0014):  42%|████▏     | 63/150 [00:17<00:24,  3.54it/s]
Epoch: [31][4][60] Train loss: 0.0015 (0.0146) Reg loss: 0.0005 (0.0014):  42%|████▏     | 63/150 [00:18<00:25,  3.43it/s]
Epoch: [31][5][60] Train loss: 0.0318 (0.0183) Reg loss: 0.0029 (0.0014):  42%|████▏     | 63/150 [00:18<00:25,  3.47it/s]
Epoch: [31][6][60] Train loss: 0.0060 (0.0179) Reg loss: 0.0010 (0.0014):  42%|████▏     | 63/150 [00:18<00:24,  3.48it/s]
Epoch: [31][7][60] Train loss: 0.0629 (0.0198) Reg loss: 0.0015 (0.0014):  42%|████▏     | 63/150 [00:18<00:25,  3.48it/s]
Epoch: [31][8][60] Train loss: 0.0041 (0.0190) Reg loss: 0.0018 (0.0014):  42%|████▏     | 63/150 [00:18<00:25,  3.48it/s]
Epoch: [31][9][6

best loss: 0.0180



Epoch: [41][1][60] Train loss: 0.0053 (0.0085) Reg loss: 0.0027 (0.0009):  42%|████▏     | 63/150 [00:16<00:22,  3.82it/s]
Epoch: [41][2][60] Train loss: 0.0013 (0.0111) Reg loss: 0.0008 (0.0009):  42%|████▏     | 63/150 [00:16<00:22,  3.83it/s]
Epoch: [41][3][60] Train loss: 0.0013 (0.0104) Reg loss: 0.0008 (0.0010):  42%|████▏     | 63/150 [00:15<00:21,  3.98it/s]
Epoch: [41][4][60] Train loss: 0.0006 (0.0094) Reg loss: 0.0007 (0.0009):  42%|████▏     | 63/150 [00:16<00:22,  3.89it/s]
Epoch: [41][5][60] Train loss: 0.0055 (0.0096) Reg loss: 0.0007 (0.0009):  42%|████▏     | 63/150 [00:16<00:22,  3.90it/s]
Epoch: [41][6][60] Train loss: 0.0019 (0.0096) Reg loss: 0.0007 (0.0009):  42%|████▏     | 63/150 [00:15<00:22,  3.94it/s]
Epoch: [41][7][60] Train loss: 0.0019 (0.0099) Reg loss: 0.0007 (0.0009):  42%|████▏     | 63/150 [00:16<00:23,  3.75it/s]
Epoch: [41][8][60] Train loss: 0.1532 (0.0112) Reg loss: 0.0027 (0.0009):  42%|████▏     | 63/150 [00:16<00:22,  3.93it/s]
Epoch: [41][9][6