## **Color quantization variant 2**
Attempt to generate a unique key for each image from its prominent colors and the share of the image's pixels that each color covers.

In [1]:
import numpy as np
import cv2
from glob import glob
from collections import Counter

from matplotlib import pyplot as plt
import seaborn as sns

from sklearn.cluster import MiniBatchKMeans
import torch
from torch import optim
from torch.optim.lr_scheduler import StepLR
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
from torch.nn import functional as F

In [2]:
# Transformations

class ColorBar(object):
    """
    Generate colorbar with prominent colors of input image

    The image is color-quantized with MiniBatchKMeans. The cluster that holds
    the most pixels is dropped (it is assumed to be the white background) and
    the remaining cluster colors are drawn side by side, each one as wide as
    its share of the remaining pixels.

    Args:
        size         :    Size (width) of colorbar in pixels
        n_colors     :    Number of prominent colors to extract
        random_state :    Optional seed forwarded to MiniBatchKMeans. The
                          default (None) leaves the clustering unseeded.
    """

    def __init__(self, size=100, n_colors=10, random_state=None):
        assert isinstance(size, int)
        assert isinstance(n_colors, int)

        self.size = size
        self.n_colors = n_colors
        self.random_state = random_state

    @staticmethod
    def _secondary_color_shares(centers, labels):
        """
        Drop the dominant cluster and return the pixel shares of the others.

        Args:
            centers  :    (n_clusters, channels) array of cluster colors
            labels   :    Cluster index assigned to every pixel

        Returns:
            (shares, colors): ``colors`` is ``centers`` without the row of the
            cluster that holds the most pixels and ``shares[i]`` is the
            fraction of the remaining pixels assigned to ``colors[i]``.
            ``shares`` is all zeros when every pixel is in the dominant
            cluster.
        """
        centers = np.asarray(centers)
        labels = np.asarray(labels).ravel().astype(np.intp)

        # Count by cluster index so counts[i] always belongs to centers[i],
        # including clusters that received no pixel at all.
        counts = np.bincount(labels, minlength=len(centers))
        dominant = int(np.argmax(counts))

        remaining_counts = np.delete(counts, dominant)
        colors = np.delete(centers, dominant, axis=0)

        total = remaining_counts.sum()
        if total == 0:
            # Single-color image: nothing is left once the background is gone.
            shares = np.zeros(len(remaining_counts), dtype=float)
        else:
            shares = remaining_counts / total

        return shares, colors

    def __call__(self, sample):
        """
        Build the colorbar of ``sample``.

        Args:
            sample   :    Image convertible with ``np.array`` to a
                          (height, width, channels) array

        Returns:
            uint8 array of shape (1, size, 3)
        """

        # MiniBatchKMeans for color quantization
        img = np.array(sample)
        img = img.reshape(img.shape[0] * img.shape[1], -1)
        clf = MiniBatchKMeans(n_clusters=self.n_colors, random_state=self.random_state)
        clf.fit(img)

        # Color histogram without the color with max pixels (belongs to white)
        shares, colors = self._secondary_color_shares(clf.cluster_centers_, clf.labels_)

        # Create colorbar
        bar = np.zeros((1, self.size, 3), dtype=np.uint8)
        startX = 0

        for (percent, color) in zip(shares, colors):
            if percent == 0:
                # A color that covers no pixel must not leave a stripe behind
                continue
            endX = startX + percent * self.size
            cv2.rectangle(bar, (int(startX), 0), (int(endX), 50), color.astype('uint8').tolist(), -1)
            startX = endX

        return bar
    
    
# Image pipeline applied to every sample, in order:
# resize -> 100-pixel colorbar of 10 quantized colors -> tensor
im_trans = transforms.Compose(
    [transforms.Resize(256), ColorBar(100, 10), transforms.ToTensor()]
)

In [3]:
# Create datasets and data loaders

BATCH_SIZE = 200

# One ImageFolder dataset per split, all sharing the same transform pipeline
data = {
    split: datasets.ImageFolder(root=folder, transform=im_trans)
    for split, folder in (
        ('train', '../../images/train'),
        ('valid', '../../images/validation'),
        ('test', '../../images/test'),
    )
}

# One shuffled loader per split, keyed like `data`
dataloaders = {
    split: DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    for split, dataset in data.items()
}

In [4]:
# Define model

N_CLASSES = 168

# A single 1x3 convolution slides along the colorbar, followed by a linear
# classifier that emits log-probabilities (the form NLLLoss expects).
# 98 is the conv output width for a 100-pixel-wide bar: 100 - 3 + 1.
# NOTE(review): keep 98 in sync with the `size` passed to ColorBar.
layers = [
    torch.nn.Conv2d(3, 1, (1, 3)),
    torch.nn.ReLU(),
    torch.nn.Flatten(),
    torch.nn.Linear(in_features=98, out_features=N_CLASSES),
    torch.nn.LogSoftmax(dim=1),
]
model = torch.nn.Sequential(*layers)


# Train and test functions

def train(model, train_loader, optimizer, loss_function, epoch):
    """
    Train ``model`` for one pass over ``train_loader``.

    Prints the running top-1 / top-5 accuracy after every batch.

    Args:
        model          :    Network to optimize; put into training mode here
        train_loader   :    Iterable of (data, target) batches
        optimizer      :    Optimizer stepped once per batch
        loss_function  :    Callable ``loss_function(output, target)``
        epoch          :    Epoch number, only used in the progress message

    Returns:
        (train_loss, top_1_accuracy, top_5_accuracy): mean of the per-batch
        losses and the fraction of processed samples whose target is the best
        / among the 5 best predictions.
    """

    model.train()
    batch_loss = []
    top_1_correct = 0
    top_5_correct = 0
    n_seen = 0

    for batch_id, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = loss_function(output, target)
        loss.backward()
        optimizer.step()

        # topk raises when asked for more entries than there are classes
        k = min(5, output.size(1))
        top_1_preds = output.argmax(dim=1, keepdim=True)
        top_5_preds = output.topk(k, dim=1)[1]
        target_col = target.view(-1, 1)

        batch_loss.append(loss.mean().item())
        top_1_correct += top_1_preds.eq(target_col).sum().item()
        top_5_correct += top_5_preds.eq(target_col).any(dim=1).sum().item()

        # Count the samples actually processed: the last batch may be smaller
        # than the loader's batch size.
        n_seen += target.size(0)

        print("Epoch {} [Batch {}/{}] \t Top 1 accuracy: {:.2f}% \t Top 5 accuracy: {:.2f}%".format(
            epoch, batch_id+1, len(train_loader),
            top_1_correct/n_seen * 100.,
            top_5_correct/n_seen * 100.
        ))

    train_loss = np.mean(batch_loss)
    top_1_accuracy = top_1_correct / n_seen if n_seen else 0.0
    top_5_accuracy = top_5_correct / n_seen if n_seen else 0.0

    return train_loss, top_1_accuracy, top_5_accuracy



def validate(model, val_loader, loss_function, epoch):
    """
    Evaluate ``model`` on ``val_loader`` without updating it.

    Prints the top-1 / top-5 accuracy once the whole loader was processed.

    Args:
        model          :    Network to evaluate; put into eval mode here
        val_loader     :    Iterable of (data, target) batches
        loss_function  :    Callable ``loss_function(output, target)``
        epoch          :    Epoch number, only used in the summary message

    Returns:
        (val_loss, top_1_accuracy, top_5_accuracy): mean of the per-batch
        losses and the fraction of processed samples whose target is the best
        / among the 5 best predictions.
    """

    model.eval()
    batch_loss = []
    top_1_correct = 0
    top_5_correct = 0
    n_seen = 0

    # No parameter is updated here, so skip building the autograd graph
    with torch.no_grad():
        for data, target in val_loader:
            output = model(data)
            loss = loss_function(output, target)

            # topk raises when asked for more entries than there are classes
            k = min(5, output.size(1))
            top_1_preds = output.argmax(dim=1, keepdim=True)
            top_5_preds = output.topk(k, dim=1)[1]
            target_col = target.view(-1, 1)

            batch_loss.append(loss.mean().item())
            top_1_correct += top_1_preds.eq(target_col).sum().item()
            top_5_correct += top_5_preds.eq(target_col).any(dim=1).sum().item()
            n_seen += target.size(0)

    val_loss = np.mean(batch_loss)
    top_1_accuracy = top_1_correct / n_seen if n_seen else 0.0
    top_5_accuracy = top_5_correct / n_seen if n_seen else 0.0

    print("\n[VALIDATION] Epoch {} \t Top 1 accuracy: {:.2f}% \t Top 5 accuracy: {:.2f}%".format(
        epoch, top_1_accuracy*100., top_5_accuracy*100.
    ))
    print('\n--------------------------------------------------------------------------\n')

    return val_loss, top_1_accuracy, top_5_accuracy



def test(model, test_loader, loss_function):
    
    model.eval()
    top_1_correct, top_5_correct = 0, 0
    
    for data, target in test_loader:
        output = model(data)
        top_1_preds = output.argmax(dim=1, keepdim=True)
        top_5_preds = output.topk(5, dim=1)[1]
        
        top_1_correct += top_1_preds.eq(target.view_as(top_1_preds)).sum().item()
        top_5_correct += sum([1 if target[i] in top_5_preds[i] else 0 for i in range(len(target))])
        
    top_1_accuracy = top_1_correct / len(test_loader.dataset)
    top_5_accuracy = top_5_correct / len(test_loader.dataset)
    
    print("[TESTING] Test dataset \t Top 1 accuracy: {:.2f}% \t Top 5 accuracy: {:.2f}%".format(
        top_1_accuracy*100., top_5_accuracy*100.
    ))
    
    return top_1_accuracy, top_5_accuracy

In [5]:
# Loss functions, optimizer, scheduler

# Number of passes over the training data
epochs = 20

# Negative log-likelihood loss: the model already ends in LogSoftmax
loss_function = torch.nn.NLLLoss()

# Adam over all model parameters; StepLR scales the learning rate by 0.9
# on every 5th call to scheduler.step()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
scheduler = StepLR(optimizer=optimizer, step_size=5, gamma=0.9)

In [1]:
# Train

# Per-epoch metric histories, consumed by the plots further down
train_loss_hist, train_acc1_hist, train_acc5_hist = [], [], []
val_loss_hist, val_acc1_hist, val_acc5_hist = [], [], []


for epoch in range(epochs):

    # One optimization pass, then one evaluation pass on the validation split
    train_loss, train_acc1, train_acc5 = train(
        model, dataloaders['train'], optimizer, loss_function, epoch
    )
    val_loss, val_acc1, val_acc5 = validate(
        model, dataloaders['valid'], loss_function, epoch
    )

    # Record this epoch's metrics in their histories
    for history, value in (
        (train_loss_hist, train_loss),
        (train_acc1_hist, train_acc1),
        (train_acc5_hist, train_acc5),
        (val_loss_hist, val_loss),
        (val_acc1_hist, val_acc1),
        (val_acc5_hist, val_acc5),
    ):
        history.append(value)

    # Advance the learning-rate schedule once per epoch
    scheduler.step()


# Test model on test data

test(model, dataloaders['test'], loss_function)


# Plot losses and accuracies
# Train
def _plot_history(ax, train_hist, val_hist, title, ylabel):
    """Draw the train and validation curves of one metric on ``ax``."""
    ax.plot(train_hist, color='blue', alpha=0.7, label='Train')
    ax.plot(val_hist, color='orange', alpha=0.8, label='Validation')
    ax.set_title(title, fontweight='bold')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid()


# Three panels side by side: loss, top-1 accuracy, top-5 accuracy
fig = plt.figure(figsize=(18, 4))

ax1 = fig.add_subplot(131)
# The loss is torch.nn.NLLLoss, i.e. the negative log-likelihood
_plot_history(ax1, train_loss_hist, val_loss_hist, 'Loss', 'Negative log-likelihood')

ax2 = fig.add_subplot(132)
_plot_history(ax2, train_acc1_hist, val_acc1_hist, 'Top 1 accuracy', 'Accuracy')

ax3 = fig.add_subplot(133)
_plot_history(ax3, train_acc5_hist, val_acc5_hist, 'Top 5 accuracy', 'Accuracy')

plt.tight_layout()
plt.show()