-
Notifications
You must be signed in to change notification settings - Fork 0
/
cluster.py
579 lines (464 loc) · 27.1 KB
/
cluster.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
import math
import pdb
import scipy.stats as stat
import dataset as ds
import product as pro
#==========================================================================================================================================
# Cluster Class
#==========================================================================================================================================
class Cluster:
    """
    Class: Cluster
    Description:
        Wraps around a map of related products and provides various utility functions that facilitate filtering.
    Instance Variables:
        Commodity commodity:
            Commodity object that contains the commodity information of the Cluster.
        Geography geography:
            Geography object that contains the geographic information of the Cluster.
        list<int> tpoIDs:
            List of integer identifiers that uniquely identify TargetProductOffer objects stored in the Substituter object.
        dict<T, Product> products:
            Collection of Product objects that are within the cluster, keyed by product ID.
        dict<string, set<T>> filterSets:
            A collection of sets of product IDs. Each set is a subset of the products property.
            Key: name of the filter set
            Value: set object containing product IDs.
    """

    #----------------------------------------------------------------------------------------------------------------------------------
    # Constructor
    #----------------------------------------------------------------------------------------------------------------------------------
    def __init__(self, commodity, geography):
        self.commodity = commodity
        self.geography = geography
        self.tpoIDs = []
        self.products = {}
        self.filterSets = {}

    #----------------------------------------------------------------------------------------------------------------------------------
    # hasDuplicates Method
    #----------------------------------------------------------------------------------------------------------------------------------
    def hasDuplicates(self) -> bool:
        """
        Method: bool hasDuplicates()
        Description:
            This method checks whether the cluster has any duplicate products, i.e. whether two keys of the products
            dictionary compare equal.
        Output:
            bool: True if at least two product IDs are equal, False otherwise.
        """
        productIDs = list(self.products)
        # A set collapses equal IDs, so a size difference means a duplicate exists (O(n) instead of a pairwise scan).
        return len(set(productIDs)) != len(productIDs)

    #----------------------------------------------------------------------------------------------------------------------------------
    # addUniqueProduct Method
    #----------------------------------------------------------------------------------------------------------------------------------
    def addUniqueProduct(self, productID, unitSize: float, UOM: str, brandType: str, desc: str, periodID: int, outletID: int, unitCount: float, sales: float):
        """
        Method: void addUniqueProduct
        (
            T productID,
            float unitSize,
            string UOM,
            string brandType,
            string desc,
            int periodID,
            int outletID,
            float unitCount,
            float sales
        )
        Description:
            This method checks whether a Product object with the given product ID already exists. If not, it creates a new
            Product object. In both cases the given period/outlet properties are then added to the Product object through
            its addProperties method.
        Arguments:
            T productID: Unique identifier of the product to be added.
            float unitSize: Volume or mass of a single unit of product.
            string UOM: Unit of measure.
            string brandType: Brand type of the product; passed to the Product constructor.
            string desc: Concatenated text features of the product.
            int periodID: Unique identifier of the reference period during which the product was sold.
            int outletID: Unique identifier of the outlet at which the product was sold.
            float unitCount: Number of units sold.
            float sales: Sales figure for the given period and outlet; passed to Product.addProperties.
        """
        if productID not in self.products:
            self.products[productID] = pro.Product(productID, UOM, brandType, desc)
        self.products[productID].addProperties(periodID, outletID, unitSize, unitCount, sales)

    #----------------------------------------------------------------------------------------------------------------------------------
    # removeProduct Method
    #----------------------------------------------------------------------------------------------------------------------------------
    def removeProduct(self, productID):
        """
        Method: void removeProduct
        (
            T productID
        )
        Description:
            This method removes the product associated with the given product ID from the product collection and from
            every filter set.
        Arguments:
            T productID: Unique identifier of the product to be removed.
        Raises:
            KeyError: If productID is not in the products collection.
        """
        self.products.pop(productID)
        for filterSet in self.filterSets.values():
            # discard() ignores IDs that are not members of this particular filter set.
            filterSet.discard(productID)

    #----------------------------------------------------------------------------------------------------------------------------------
    # find Method
    #----------------------------------------------------------------------------------------------------------------------------------
    def find(self, comparer = lambda refProd, prod: True, retriever = lambda prod: prod.productID, setName: str = None):
        """
        Method: T find
        (
            bool comparer(Product refProd, Product prod),
            T retriever<T>(Product prod),
            string setName
        )
        Description:
            This method starts with the first product of the designated set as the reference product and compares it with
            every other product of the set by way of the comparer. It then applies the retriever onto the final reference
            product and returns the result.
            If setName is None, then the method uses the full product set.
        Arguments:
            bool comparer(Product refProd, Product prod):
                A callable object that evaluates a boolean expression on two Product objects.
                If it returns True, then prod is assigned to refProd. Otherwise, refProd remains unchanged.
            T retriever<T>(Product prod):
                A callable object that retrieves a property of a Product object.
                The retriever gets the desired property of refProd once the comparer has been applied to every Product
                object in the set. This property is then returned as the output of the method.
            string setName:
                Name of the filter set that the method will consider. If set to None, the method will consider all the
                products in the Cluster.
        Output:
            Property of refProd of type T.
        Raises:
            IndexError: If the designated set contains no products.
            KeyError: If setName is not the name of an existing filter set.
        """
        if setName is None:
            productIDList = list(self.products)
        else:
            productIDList = list(self.filterSets[setName])
        if not productIDList:
            # Same exception type as indexing the empty list would raise, but with an explicit message.
            raise IndexError("Cannot search an empty product set.")
        ref = self.products[productIDList[0]]
        for productID in productIDList[1:]:
            candidate = self.products[productID]
            if comparer(ref, candidate):
                ref = candidate
        return retriever(ref)

    #----------------------------------------------------------------------------------------------------------------------------------
    # addFilterSet Method
    #----------------------------------------------------------------------------------------------------------------------------------
    def addFilterSet(self, name: str):
        """
        Method: void addFilterSet
        (
            string name
        )
        Description:
            This method adds a filter set with the given name. The new set initially contains every product ID in the
            Cluster. An existing filter set with the same name is replaced.
        Arguments:
            string name: Name of the filter set to be added.
        """
        self.filterSets[name] = set(self.products)

    #----------------------------------------------------------------------------------------------------------------------------------
    # copyFilterSet Method
    #----------------------------------------------------------------------------------------------------------------------------------
    def copyFilterSet(self, sourceSetName: str, copySetName: str):
        """
        Method: void copyFilterSet
        (
            string sourceSetName,
            string copySetName
        )
        Description:
            This method copies the set with the name given by sourceSetName and creates a new, independent set with the
            name given by copySetName.
        Arguments:
            string sourceSetName: Name of the filter set to be copied.
            string copySetName: Name of the filter set to be created.
        Raises:
            KeyError: If sourceSetName is not the name of an existing filter set.
        """
        self.filterSets[copySetName] = set(self.filterSets[sourceSetName])

    #----------------------------------------------------------------------------------------------------------------------------------
    # removeFilterSet Method
    #----------------------------------------------------------------------------------------------------------------------------------
    def removeFilterSet(self, name: str):
        """
        Method: void removeFilterSet
        (
            string name
        )
        Description:
            This method removes the filter set with the given name.
        Arguments:
            string name: Name of the filter set to be removed.
        Raises:
            KeyError: If name is not the name of an existing filter set.
        """
        # filterSets is a dict; dict has no remove() method, so the entry is deleted by key.
        del self.filterSets[name]

    #----------------------------------------------------------------------------------------------------------------------------------
    # clearFilterSets Method
    #----------------------------------------------------------------------------------------------------------------------------------
    def clearFilterSets(self):
        """
        Method: void clearFilterSets()
        Description:
            This method removes all filter sets by binding filterSets to a new empty dictionary.
        """
        self.filterSets = {}

    #----------------------------------------------------------------------------------------------------------------------------------
    # intersectFilterSets Method
    #----------------------------------------------------------------------------------------------------------------------------------
    def intersectFilterSets(self, innerSetName: str, setNames = None):
        """
        Method: void intersectFilterSets
        (
            string innerSetName,
            list<string> setNames
        )
        Description:
            This method creates a new filter set (name given by innerSetName) that is the intersection of all filter sets
            given by the names stored in setNames. The setNames argument itself is not modified.
        Arguments:
            string innerSetName: Name of the intersection filter set to be created.
            list<string> setNames: List of names of the filter sets that will be intersected.
        Raises:
            ValueError: If setNames is None or empty.
            KeyError: If any name in setNames is not the name of an existing filter set.
        """
        # Work on a private copy so that the caller's list is never mutated.
        names = list(setNames) if setNames is not None else []
        if not names:
            raise ValueError("Set list is empty.")
        firstSet = self.filterSets[names[0]]
        otherSets = [self.filterSets[name] for name in names[1:]]
        # intersection() returns a new set, so the result is independent of the source sets.
        self.filterSets[innerSetName] = firstSet.intersection(*otherSets)

    #----------------------------------------------------------------------------------------------------------------------------------
    # applyFilterMask Method
    #----------------------------------------------------------------------------------------------------------------------------------
    def applyFilterMask(self, setName: str, productIDSet: set, filterMode: str = 'drop'):
        """
        Method: void applyFilterMask
        (
            string setName,
            set<T> productIDSet,
            string filterMode
        )
        Description:
            This method filters the product IDs stored in the filter set given by setName. If filterMode is set to 'drop',
            the product IDs in productIDSet are dropped from the set. If filterMode is set to 'keep', then only the product
            IDs in productIDSet are retained in the set. Any other filterMode leaves the set unchanged.
        Arguments:
            string setName: Name of the set to which the filter mask will be applied.
            set<T> productIDSet: Set of product IDs that will serve as the mask.
            string filterMode: = 'drop' or 'keep'
                Designates the type of mask. If set to 'drop', then all product IDs in productIDSet will be dropped from
                setName. If set to 'keep', then all productIDs not in productIDSet will be dropped from setName.
        Raises:
            KeyError: If filterMode is 'drop' or 'keep' and setName is not the name of an existing filter set.
        """
        if filterMode == 'drop':
            self.filterSets[setName] = self.filterSets[setName].difference(productIDSet)
        elif filterMode == 'keep':
            # Intersect rather than overwrite: a 'keep' mask may only remove IDs, never add IDs that were not in the set.
            self.filterSets[setName] = self.filterSets[setName].intersection(productIDSet)

    #----------------------------------------------------------------------------------------------------------------------------------
    # applyFilterFunction Method
    #----------------------------------------------------------------------------------------------------------------------------------
    def applyFilterFunction(self, setName: str, filterer = lambda product: product.properties[0].sales > 0):
        """
        Method: void applyFilterFunction
        (
            string setName,
            bool filterer(Product product)
        )
        Description:
            This method applies the filterer function to determine whether each product in the set given by setName should
            be kept or dropped. Products for which the filterer returns a true value are kept.
        Arguments:
            string setName: Name of the filter set on which the filter function will be applied.
            bool filterer(Product product):
                A callable object that applies a boolean expression on product.
                Arguments:
                    Product product: The Product object that is evaluated by the filterer.
        Raises:
            KeyError: If setName is not the name of an existing filter set.
        """
        keepProductIDs = {productID for productID in self.filterSets[setName] if filterer(self.products[productID])}
        self.applyFilterMask(setName, keepProductIDs, filterMode = 'keep')

    #----------------------------------------------------------------------------------------------------------------------------------
    # applyCutoffFilter Method
    #----------------------------------------------------------------------------------------------------------------------------------
    def applyCutoffFilter(self,
                          setName: str,
                          retriever = lambda product: product.properties[0].sales,
                          lowerCutoff: float = 0.5,
                          upperCutoff: float = 1):
        """
        Method: void applyCutoffFilter
        (
            string setName,
            float retriever(Product product),
            float lowerCutoff,
            float upperCutoff
        )
        Description:
            This method filters out all products with a specific property that is above the given upperCutoff and/or below
            the given lowerCutoff. The retriever function is applied to retrieve that property of a product. Both cutoffs
            are inclusive. A product for which the evaluation raises an exception is dropped from the set.
        Arguments:
            string setName: Name of the set on which the cut-off filter will be applied.
            float retriever(Product product):
                A callable object that returns a numeric property of the given Product object.
            float lowerCutoff: Lowest acceptable value for the property being evaluated. None disables the lower bound.
            float upperCutoff: Highest acceptable value for the property being evaluated. None disables the upper bound.
        Raises:
            ValueError: If both cutoffs are given and upperCutoff is smaller than lowerCutoff.
            KeyError: If setName is not the name of an existing filter set.
        """
        hasLower = lowerCutoff is not None
        hasUpper = upperCutoff is not None
        if hasLower and hasUpper and upperCutoff < lowerCutoff:
            raise ValueError("The upper cutoff cannot be smaller than the lower cutoff.")

        def isWithinCutoffs(product) -> bool:
            # With no bounds every product passes and the retriever is not consulted at all.
            if not hasLower and not hasUpper:
                return True
            value = retriever(product)  # retrieved once, even when both bounds are checked
            if hasLower and not value >= lowerCutoff:
                return False
            return (not hasUpper) or value <= upperCutoff

        keepProductIDs = set()
        for productID in self.filterSets[setName]:
            try:
                if isWithinCutoffs(self.products[productID]):
                    keepProductIDs.add(productID)
            except Exception:
                # Deliberate best-effort: a product whose property cannot be retrieved or compared is treated as
                # failing the filter and is dropped instead of aborting the whole filter pass.
                continue
        self.applyFilterMask(setName, keepProductIDs, filterMode = 'keep')

    #----------------------------------------------------------------------------------------------------------------------------------
    # addNormalizedVariable Method
    #----------------------------------------------------------------------------------------------------------------------------------
    def addNormalizedVariable(self,
                              varKey: str,
                              retriever = lambda product: product.properties[0].sales,
                              setName: str = None,
                              normMode: str = 'rank',
                              invert: bool = False):
        """
        Method: void addNormalizedVariable
        (
            string varKey,
            float retriever(Product product),
            string setName,
            string normMode,
            bool invert
        )
        Description:
            This method creates a new normalized variable (called varKey) for each product in the set given by setName.
            This new variable is derived from a product property retrieved by the retriever. It is stored in the variables
            property of each Product object.
            If normMode is set to 'rank', then the normalized variable is rank / count, where rank is given by
            scipy.stats.rankdata and count is the number of products in the set.
            If normMode is set to 'magnitude', then the normalized variable is (value - min) / (max - min).
            If normMode is set to 'weight', then the normalized variable is value / sum of all values.
            The variable is left at 0 for every product in the set when: normMode is not one of the three modes above;
            a KeyError is raised while retrieving or normalizing; all values are equal in 'magnitude' mode; or the values
            sum to 0 in 'weight' mode. An empty set is a no-op.
        Arguments:
            string varKey: Name of the normalized variable to be created.
            float retriever(Product product):
                A callable object that retrieves a numeric property of the given Product object.
            string setName: Name of the filter set to be considered. If set to None, then all products in the Cluster will
                be considered.
            string normMode: = 'rank' or 'magnitude' or 'weight'
                Determines how the variable will be normalized.
            bool invert: If True, then the variable is also inverted: 1 - (rank - 1) / count in 'rank' mode and
                1 - normalized value in the other two modes.
        """
        if setName is None:
            productIDs = list(self.products)
        else:
            productIDs = list(self.filterSets[setName])
        count = len(productIDs)
        if count == 0:
            # Nothing to normalize; also avoids min()/max() failing on an empty sequence.
            return
        try:
            values = []
            for productID in productIDs:
                product = self.products[productID]
                product.variables[varKey] = 0
                values.append(retriever(product))

            # Normalize the variable by rank.
            if normMode == 'rank':
                ranks = stat.rankdata(values)
                for productID, rank in zip(productIDs, ranks):
                    if invert:
                        normalized = float(1 - (rank - 1) / count)
                    else:
                        normalized = float(rank / count)
                    self.products[productID].variables[varKey] = normalized

            # Normalize the variable by magnitude.
            elif normMode == 'magnitude':
                minValue = min(values)
                span = max(values) - minValue
                # All values equal: the scale is undefined, so the variable stays at its initial 0.
                if span != 0:
                    for productID, value in zip(productIDs, values):
                        normalized = (value - minValue) / span
                        self.products[productID].variables[varKey] = 1 - normalized if invert else normalized

            # Normalize the variable by weight.
            elif normMode == 'weight':
                summedValue = sum(values)
                # Zero total: the weights are undefined, so the variable stays at its initial 0.
                if summedValue != 0:
                    for productID, value in zip(productIDs, values):
                        normalized = value / summedValue
                        self.products[productID].variables[varKey] = 1 - normalized if invert else normalized
        except KeyError:
            # A lookup failed part-way through (e.g. inside the retriever); reset so no product is left half-normalized.
            for productID in productIDs:
                self.products[productID].variables[varKey] = 0

    #----------------------------------------------------------------------------------------------------------------------------------
    # toDataFrame Method
    #----------------------------------------------------------------------------------------------------------------------------------
    def toDataFrame(self, periodID: int, tpoID: int, varLabels: list, setName: str = None):
        """
        Method: DataFrame toDataFrame
        (
            int periodID,
            int tpoID,
            list<string> varLabels,
            string setName
        )
        Description:
            This method generates a DataFrame containing the properties of each product in the cluster during the period
            given by periodID. A column is created for each label contained in varLabels, followed by one column per
            filter set. Products that have no properties for periodID are skipped.
        Arguments:
            int periodID: Unique identifier of the reference period to consider when retrieving the properties of the
                Cluster.
            int tpoID: Unique identifier of the TPO that this cluster is associated with.
            list<string> varLabels: Additional columns to be added to the output. Each varLabel corresponds to a variable
                in the variables property of each Product object. A product without that variable gets -1.
            string setName: Name of the filter set whose products are exported. If set to None, all products in the
                Cluster are exported.
        Output:
            The object created by ds.fromDict, with one row added through ds.addRow per exported product.
        """
        if setName is None:
            productIDList = list(self.products)
        else:
            productIDList = list(self.filterSets[setName])
        filterLabels = list(self.filterSets)
        colLabels = ['TPO_ID', 'ProductID', 'UOM', 'Description', 'Unit Size', 'Quantity Units', 'Sales'] + list(varLabels) + filterLabels
        df = ds.fromDict(colLabels)
        for productID in productIDList:
            product = self.products[productID]
            if periodID not in product.properties:
                continue
            props = product.properties[periodID]
            # -1 marks a variable that has not been computed for this product.
            variables = [product.variables[varLabel] if varLabel in product.variables else -1 for varLabel in varLabels]
            # 1/0 membership flag of this product for every filter set.
            filters = [1 if productID in self.filterSets[filterLabel] else 0 for filterLabel in filterLabels]
            ds.addRow(df, [tpoID, productID, product.UOM, product.desc, props.unitSize, props.unitCount, props.sales] + variables + filters)
        return df

    #----------------------------------------------------------------------------------------------------------------------------------
    # toString Method
    #----------------------------------------------------------------------------------------------------------------------------------
    def toString(self, escChars: str = "\n", limit: bool = True) -> str:
        """
        Method: string toString
        (
            string escChars,
            bool limit
        )
        Description:
            This method builds a multi-line, human-readable summary of the Cluster: its geography, its TPO IDs, its
            products and its filter sets.
        Arguments:
            string escChars: Separator written in front of every entry of the summary (a line break by default). Nested
                entries are prefixed with escChars followed by a tab.
            bool limit: If True, at most ten products and at most ten product IDs per filter set are listed, each list
                being followed by a "..." entry once the tenth item has been written. The flag is also forwarded to
                Product.toString.
        Output:
            string: The summary text.
        """
        nestedEsc = escChars + "\t"
        lines = [
            "------------------------------",
            "Data type: Product Cluster",
            "Geography: " + self.geography.toString(nestedEsc),
            "TPOs in this cluster: " + str(len(self.tpoIDs)),
        ]
        for position, tpoID in enumerate(self.tpoIDs, start = 1):
            lines.append(str(position) + ". " + str(tpoID))

        lines.append("Products in this cluster: " + str(len(self.products)))
        for index, productID in enumerate(self.products):
            lines.append("Product " + str(index + 1) + ". " + str(self.products[productID].toString(nestedEsc, limit)))
            if index == 9 and limit:
                lines.append("...")
                break

        lines.append("Filter sets: " + str(len(self.filterSets)))
        for position, setName in enumerate(self.filterSets, start = 1):
            members = self.filterSets[setName]
            lines.append(str(position) + ". '" + setName + "' filter set: " + str(len(members)) + " product IDs")
            for index, productID in enumerate(members):
                lines.append("\t" + str(index + 1) + ". " + str(productID))
                if index == 9 and limit:
                    lines.append("\t" + "...")
                    break

        lines.append("------------------------------")
        # Every entry is preceded by escChars, including the first one.
        return "".join(escChars + line for line in lines)