In [1]:
import torch
from torch import nn
from torch.utils import model_zoo
import torchvision
import numpy as np
import pandas as pd

In [701]:
def chunk_iterator(df, chunksize=1, limit=None):
    length = len(df)
    i = 0
    
    while True:
        mn = i * chunksize
        mx = mn + min(chunksize, length)

        if limit is not None:
            mx = min(mx, limit)
        
        yield df[mn:mx]
        i += 1
        
        if mx >= length or (limit is not None and mx >= limit):
            return
        

In [6]:
train_file = '../data/train.csv'
validation_file = '../data/validation.csv'

In [7]:
def read_string_array(arr, dtype=None):
    return np.fromstring(arr[1:-1], dtype=None, sep=' ')

In [8]:
def read_string_ndarray(arr, dtype=None):
    return np.array([read_string_array(target, dtype=None) for target in arr])

In [9]:
def read_chunk(chunk, device=torch.device('cpu')):
    X = np.array(chunk['image'].values.tolist(), dtype=np.float32)
    y = chunk['label'].to_numpy(dtype=np.int)
    
    
    X = torch.from_numpy(X).to(device)
    y = torch.from_numpy(y).to(device)
    
    return X, y

In [40]:
class BrootForceCosineSimiliaritySearch(nn.Module):
    def __init__(self, X_train, y_train, treshold, dim=1):
        super(BrootForceCosineSimiliaritySearch, self).__init__()
        self.similiarity = torch.nn.CosineSimilarity(dim=dim)
        self.X_train = X_train
        self.y_train = y_train
        self.treshold = treshold
        
    def set_treshold(self, treshold):
        self.treshold = treshold
    
    
    def forward(self, X, chunksize=None):
        if chunksize:
            sim = [torch.stack([self.similiarity(torch.stack([x]), x_train) for x in X])
                    for x_train in chunk_iterator(self.X_train.to(X.device), chunksize=chunksize)]
            sim = torch.cat(sim, 1)
        else:
            sim = torch.stack([self.similiarity(torch.stack([x]).to(X.device), self.X_train) for x in X])

        index = sim.argmax(axis=1)
        
        if sim[0][index] > self.treshold:
            return self.y_train[index]

        return None

In [591]:
class BatchedBrootForceCosineSimiliaritySearch(nn.Module):
    def __init__(self, X_train, y_train, treshold, dim=1):
        super(BatchedBrootForceCosineSimiliaritySearch, self).__init__()
        self.similiarity = torch.nn.CosineSimilarity(dim=dim)
        self.X_train = X_train
        self.y_train = y_train
        self.treshold = treshold
        
    def set_treshold(self, treshold):
        self.treshold = treshold
    
    
    def forward(self, X, chunksize=10000):
        result = []
        i = 0
        
        for x in chunk_iterator(X, chunksize=1):
            count = 0
            for chunk in chunk_iterator(self.X_train.to(X.device), chunksize=chunksize):
                sim = self.similiarity(x, chunk)
                index = sim.argmax()

                if sim[index] > self.treshold:
                    index = index + count * chunksize
                    result.append(self.y_train[index])
                    break

                count += 1
            if i == len(result):
                result.append(torch.tensor(-1))
            i+=1

        return torch.stack(result)

In [81]:
from sklearn.cluster import MiniBatchKMeans

In [522]:
class KMeansCosineSimiliaritySearch(nn.Module):
    def __init__(self, X_train, y_train, clusters, treshold, n_clusters=128, dim=1):
        super(KMeansCosineSimiliaritySearch, self).__init__()
        self.kmeans = MiniBatchKMeans(n_clusters=n_clusters, random_state=42, compute_labels=True)
        self.similiarity = torch.nn.CosineSimilarity(dim=dim)
        self.X_train = X_train
        self.y_train = y_train
        self.treshold = treshold
        self.clusters = clusters
        
    def set_treshold(self, treshold):
        self.treshold = treshold
        
    def search_in_cluster(self, x, cluster, chunksize):
        X_train = self.X_train.to(x.device)
        y_train = self.y_train.to(x.device)
        
        for indexes in chunk_iterator(cluster, chunksize=chunksize):
            X = X_train[indexes]
            Y = y_train[indexes]
            sim = self.similiarity(x, X)
            index = sim.argmax()
            
            if sim[index] > self.treshold:
                    return Y[index]
        return None
    
    def forward(self, X, chunksize=10000):
        centers = torch.tensor(self.kmeans.cluster_centers_, device=X.device)
        i = 0
        result = []
        
        for x in chunk_iterator(X, chunksize=1):
            sim = self.similiarity(x, centers)
            clusters_sorted = torch.argsort(-sim).cpu().numpy()
            
            for index in clusters_sorted:
                cluster = self.clusters[index]
                
                y_hat = self.search_in_cluster(x, cluster, chunksize)
                
                if y_hat is not None:
                    result.append(y_hat)
                    break
            if len(result) == i:
                result.append(torch.tensor(-1))
            i+=1

        print(result)
        return torch.stack(result)

In [523]:
from torch.nn import Parameter
import torch.nn.functional as F

In [524]:
class VAE(nn.Module):
    def __init__(self, tau=1):
        super(VAE, self).__init__()
        self.tau = tau
        
        self.fc1 = nn.Linear(512, 256)
        self.bn1 = nn.BatchNorm1d(256)
        self.fc2 = nn.Linear(256, 128)

        self.embeddings = Parameter(torch.empty((128, 512)).uniform_(-0.5, 0.5).requires_grad_(True))
        
        self.fc3 = nn.Linear(512, 512)
        self.bn2 = nn.BatchNorm1d(512)
        self.fc4 = nn.Linear(512, 512)
        
    def update_temperature(self, tau):
        self.tau = tau

    def encode(self, x):
        x = self.fc1(x)
        x = self.bn1(x)
        x = F.tanh(x)
        x = self.fc2(x)
        return x

    def decode(self, z):
        z = self.fc3(z)
        z = self.bn2(z)
        z = F.tanh(z)
        z = self.fc4(z)
        return z
    
    def reparametrize(self, q_y):
        z = F.gumbel_softmax(q_y, tau=self.tau, dim=-1)
        z = torch.mm(z, self.embeddings)
        return z
    
    def forward(self, x):
        encoding = self.encode(x)
        z = self.reparametrize(encoding)
        return self.decode(z), z, encoding

In [621]:
class VAECosineSimiliaritySearch(nn.Module):
    def __init__(self, X_train, y_train, clusters, vae, treshold, n_clusters=128, dim=1):
        super(VAECosineSimiliaritySearch, self).__init__()
        self.vae = vae
        self.similiarity = torch.nn.CosineSimilarity(dim=dim)
        self.X_train = X_train
        self.y_train = y_train
        self.treshold = treshold
        self.clusters = clusters
        
    def set_treshold(self, treshold):
        self.treshold = treshold
        
    def search_in_cluster(self, x, cluster, chunksize):
        X_train = self.X_train.to(x.device)
        y_train = self.y_train.to(x.device)
        
        for indexes in chunk_iterator(cluster, chunksize=chunksize):
            X = X_train[indexes]
            Y = y_train[indexes]
            sim = self.similiarity(x, X)
            index = sim.argmax()
            
            if sim[index] > self.treshold:
                    return Y[index]
        return None
    
    def forward(self, X, chunksize=10000):
        centers = self.vae.encode(X).detach()
        centers = -centers
        clusters_sorted = centers.argsort(axis=1).cpu().numpy()
        
        i = 0
        result = []
        
        for i in range(len(clusters_sorted)): 
            for index in clusters_sorted[i]:
                cluster = self.clusters[index]
                
                y_hat = self.search_in_cluster(torch.stack([X[i]]), cluster, chunksize)
                
                if y_hat is not None:
                    result.append(y_hat)
                    break
            if len(result) == i:
                result.append(torch.tensor(-1))
            i+=1

        return torch.stack(result)

In [637]:
# device = torch.device('cuda:2')
device = torch.device('cpu')

In [366]:
from time import time

In [367]:
class ResultsStore():
    def __init__(self):
        self.results = []

    def add(self, y, y_hat, duration):
        self.results.append((y, y_hat, duration))
        return self

    def dataframe(self):
        df = pd.DataFrame.from_records(self.results, columns=['label', 'predicted', 'duration'])
        return df

In [638]:
def find(df, model, items_count=None, chunksize=10000, device=device):
    store = ResultsStore()
    i = 0
    
    with torch.no_grad():
        for chunk in chunk_iterator(df, chunksize=1, limit=items_count):
            X, y = read_chunk(chunk, device=device)

            start = time()
            y_hat = model(X, chunksize=chunksize)
            end = time()

            store.add(y[0].detach().cpu().numpy(), 
                      y_hat[0].detach().cpu().numpy(),
                      end - start)
            
            i+=1
            print(i)
        
    return store

In [369]:
def load_treshold(file='../data/treshold.npy'):
    return np.load(file)[0]

In [666]:
treshold = 0.7

In [447]:
from sklearn.utils import shuffle

In [641]:
train_df = pd.read_csv(train_file,
                       header=0,
                       converters={'image': read_string_array})
X_train, y_train = read_chunk(train_df, device)

In [19]:
test_df = pd.read_csv(train_file,
                       header=0,
                       converters={'image': read_string_array})

In [472]:
samples = test_df.sample(1000, random_state=42)
samples

Unnamed: 0,image,label
210931,"[-0.0486107506, -0.192878664, -0.116991892, -0...",5520
239762,"[0.06753061, -0.08758814, -0.20101719, -0.3760...",6434
154989,"[-0.115992084, -0.478344321, -0.157596901, 0.1...",3720
261937,"[0.170259625, 0.148349941, -0.0222489573, 0.26...",7018
342105,"[0.110233679, -0.057050921, -0.277625531, -0.0...",10254
...,...,...
55058,"[0.2035736, 0.11481629, -0.19044346, 0.0738445...",882
21142,"[-0.0673625469, -0.064171195, 0.381778806, 0.1...",343
172239,"[0.145213515, -0.403272718, 0.154248521, 0.378...",4306
273338,"[-0.0632000938, 0.149921402, 0.337661922, 0.13...",7645


In [667]:
bf_model = BrootForceCosineSimiliaritySearch(X_train, y_train, treshold=treshold).to(device)
bf_model.eval()

BrootForceCosineSimiliaritySearch(
  (similiarity): CosineSimilarity()
)

In [668]:
store = find(samples, bf_model)
store

  from ipykernel import kernelapp as app


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277


  


<__main__.ResultsStore at 0x7f864035b8d0>

In [669]:
df = store.dataframe()
df

Unnamed: 0,label,predicted,duration
0,5520,5520,0.068076
1,6434,6434,0.040551
2,3720,3720,0.065209
3,7018,7018,0.088357
4,10254,10254,0.055037
...,...,...,...
995,882,882,0.029392
996,343,343,0.033484
997,4306,4306,0.033058
998,7645,7645,0.031517


In [670]:
df.describe()

Unnamed: 0,duration
count,1000.0
mean,0.03855
std,0.027977
min,0.027559
25%,0.03129
50%,0.033975
75%,0.037507
max,0.563258


In [671]:
df['duration'].sum()

38.549822092056274

In [672]:
batched_bf_model = BatchedBrootForceCosineSimiliaritySearch(X_train, y_train, treshold=treshold).to(device)
batched_bf_model.eval()

BatchedBrootForceCosineSimiliaritySearch(
  (similiarity): CosineSimilarity()
)

In [673]:
batched_store = find(samples, batched_bf_model)
batched_store



1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277


  


<__main__.ResultsStore at 0x7f8633062400>

In [674]:
batched_df = batched_store.dataframe()
batched_df

Unnamed: 0,label,predicted,duration
0,5520,5520,0.063802
1,6434,6434,0.024415
2,3720,3720,0.015313
3,7018,7018,0.028188
4,10254,10254,0.038635
...,...,...,...
995,882,882,0.005867
996,343,343,0.002798
997,4306,4306,0.015628
998,7645,7645,0.024685


In [675]:
batched_df.describe()

Unnamed: 0,duration
count,1000.0
mean,0.019506
std,0.018792
min,0.00099
25%,0.009348
50%,0.017706
75%,0.025801
max,0.358974


In [676]:
batched_df['duration'].sum()

19.505995512008667

In [677]:
def load_clusters(file='../data/buckets.npy'):
    return np.load(file, allow_pickle=True)[0]

In [678]:
clusters = load_clusters()

In [679]:
import joblib

In [680]:
kmeans_model = KMeansCosineSimiliaritySearch(X_train, y_train, clusters, treshold=treshold, n_clusters=128).to(device)
kmeans_model.eval()
kmeans_model.kmeans = joblib.load('kmeans.joblib')

In [681]:
kmeans_store = find(samples, kmeans_model)
kmeans_store

[tensor(5520)]
1
[tensor(6434)]
2
[tensor(3720)]
3
[tensor(7018)]
4
[tensor(10254)]
5
[tensor(79)]
6
[tensor(386)]
7
[tensor(6284)]
8
[tensor(6050)]
9
[tensor(879)]
10
[tensor(1296)]
11
[tensor(1528)]
12
[tensor(1585)]
13
[tensor(10042)]
14
[tensor(2374)]
15
[tensor(476)]
16
[tensor(6072)]
17
[tensor(343)]
18
[tensor(4823)]
19
[tensor(8113)]




20
[tensor(8631)]
21
[tensor(2485)]
22
[tensor(1545)]
23
[tensor(46)]
24
[tensor(56)]
25
[tensor(1866)]
26
[tensor(2414)]
27
[tensor(232)]
28
[tensor(1542)]
29
[tensor(3065)]
30
[tensor(1533)]
31
[tensor(9016)]
32
[tensor(8697)]
33
[tensor(7152)]
34
[tensor(10079)]
35
[tensor(2444)]
36
[tensor(384)]
37
[tensor(7085)]
38
[tensor(5731)]
39
[tensor(4835)]
40
[tensor(76)]
41
[tensor(658)]
42
[tensor(5788)]
43
[tensor(1245)]
44
[tensor(3348)]
45
[tensor(4438)]
46
[tensor(4888)]
47
[tensor(6334)]
48
[tensor(163)]
49
[tensor(9153)]
50
[tensor(7890)]
51
[tensor(3877)]
52
[tensor(2027)]
53
[tensor(637)]
54
[tensor(1891)]
55
[tensor(4238)]
56
[tensor(5320)]
57
[tensor(9583)]
58
[tensor(277)]
59
[tensor(6671)]
60
[tensor(4790)]
61
[tensor(1082)]
62
[tensor(763)]
63
[tensor(978)]
64
[tensor(5993)]
65
[tensor(6796)]
66
[tensor(1211)]
67
[tensor(7849)]
68
[tensor(3513)]
69
[tensor(3809)]
70
[tensor(7456)]
71
[tensor(1145)]
72
[tensor(5748)]
73
[tensor(1795)]
74
[tensor(1903)]
75
[tensor(3879)]
76
[t

508
[tensor(28)]
509
[tensor(9753)]
510
[tensor(8071)]
511
[tensor(7239)]
512
[tensor(8049)]
513
[tensor(3705)]
514
[tensor(3927)]
515
[tensor(1933)]
516
[tensor(2519)]
517
[tensor(7018)]
518
[tensor(776)]
519
[tensor(9410)]
520
[tensor(4969)]
521
[tensor(4683)]
522
[tensor(1595)]
523
[tensor(23)]
524
[tensor(1376)]
525
[tensor(10206)]
526
[tensor(9013)]
527
[tensor(2961)]
528
[tensor(7792)]
529
[tensor(2194)]
530
[tensor(6627)]
531
[tensor(1910)]
532
[tensor(4207)]
533
[tensor(956)]
534
[tensor(872)]
535
[tensor(7964)]
536
[tensor(9291)]
537
[tensor(174)]
538
[tensor(1655)]
539
[tensor(4463)]
540
[tensor(10121)]
541
[tensor(6023)]
542
[tensor(3652)]
543
[tensor(455)]
544
[tensor(10008)]
545
[tensor(940)]
546
[tensor(6573)]
547
[tensor(3715)]
548
[tensor(7083)]
549
[tensor(7726)]
550
[tensor(8716)]
551
[tensor(2439)]
552
[tensor(825)]
553
[tensor(2025)]
554
[tensor(4540)]
555
[tensor(5403)]
556
[tensor(1106)]
557
[tensor(6523)]
558
[tensor(5813)]
559
[tensor(863)]
560
[tensor(3601)]
56

[tensor(1655)]
952
[tensor(4212)]
953
[tensor(658)]
954
[tensor(4238)]
955
[tensor(5467)]
956
[tensor(601)]
957
[tensor(2920)]
958
[tensor(4007)]
959
[tensor(6170)]
960
[tensor(6323)]
961
[tensor(6579)]
962
[tensor(5387)]
963
[tensor(3003)]
964
[tensor(4407)]
965
[tensor(5985)]
966
[tensor(5557)]
967
[tensor(3274)]
968
[tensor(9714)]
969
[tensor(5237)]
970
[tensor(9958)]
971
[tensor(84)]
972
[tensor(1887)]
973
[tensor(2374)]
974
[tensor(693)]
975
[tensor(2639)]
976
[tensor(8233)]
977
[tensor(1769)]
978
[tensor(8969)]
979
[tensor(8504)]
980
[tensor(1398)]
981
[tensor(1409)]
982
[tensor(1032)]
983
[tensor(3509)]
984
[tensor(1830)]
985
[tensor(1611)]
986
[tensor(4424)]
987
[tensor(2947)]
988
[tensor(5260)]
989
[tensor(10545)]
990
[tensor(10043)]
991
[tensor(6029)]
992
[tensor(9526)]
993
[tensor(7033)]
994
[tensor(6325)]
995
[tensor(882)]
996
[tensor(343)]
997
[tensor(4306)]
998
[tensor(7645)]
999
[tensor(10028)]
1000


  


<__main__.ResultsStore at 0x7f86401c0470>

In [682]:
kmeans_df = kmeans_store.dataframe()
kmeans_df

Unnamed: 0,label,predicted,duration
0,5520,5520,0.024539
1,6434,6434,0.001726
2,3720,3720,0.002011
3,7018,7018,0.004622
4,10254,10254,0.004204
...,...,...,...
995,882,882,0.002018
996,343,343,0.001788
997,4306,4306,0.002544
998,7645,7645,0.002185


In [683]:
kmeans_df.describe()

Unnamed: 0,duration
count,1000.0
mean,0.002442
std,0.003288
min,0.001146
25%,0.001784
50%,0.002055
75%,0.002426
max,0.067887


In [684]:
kmeans_df['duration'].sum()

2.442487955093384

In [685]:
clusters_vae = load_clusters(file='../data/buckets_vae.npy')

In [702]:
vae = VAE().to(device)
vae.eval()
vae.load_state_dict(torch.load('vae.torch', map_location=device))
vae_model = VAECosineSimiliaritySearch(X_train, y_train, clusters_vae, vae, treshold=treshold, n_clusters=128).to(device)
vae_model.eval()

VAECosineSimiliaritySearch(
  (vae): VAE(
    (fc1): Linear(in_features=512, out_features=256, bias=True)
    (bn1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (fc2): Linear(in_features=256, out_features=128, bias=True)
    (fc3): Linear(in_features=512, out_features=512, bias=True)
    (bn2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (fc4): Linear(in_features=512, out_features=512, bias=True)
  )
  (similiarity): CosineSimilarity()
)

In [703]:
vae_store = find(samples, vae_model)
vae_store



1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277


<__main__.ResultsStore at 0x7f8650edbe48>

In [704]:
vae_df = vae_store.dataframe()
vae_df

Unnamed: 0,label,predicted,duration
0,5520,5520,0.053108
1,6434,6434,0.040970
2,3720,3720,0.033854
3,7018,7018,0.034127
4,10254,10254,0.011823
...,...,...,...
995,882,882,0.006131
996,343,343,0.034890
997,4306,4306,0.018050
998,7645,7645,0.026436


In [705]:
vae_df.describe()

Unnamed: 0,duration
count,1000.0
mean,0.010486
std,0.011841
min,0.001238
25%,0.002481
50%,0.004276
75%,0.015401
max,0.092162


In [706]:
vae_df['duration'].sum()

10.485718488693237