In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

from pathlib import Path

# Root folder for the MNIST download and for the index arrays saved at the end of the notebook.
data_dir = "/".join(("..", "artifacts", "data"))

In [None]:
# Hyper-parameters
input_size = 28 * 28  # one flattened 28x28 MNIST image per feature row
num_classes = 10
# NOTE(review): the three values below are not referenced anywhere else in this
# notebook as shown (likely leftovers from a torch training loop) — confirm before removing.
num_epochs, batch_size = 50, 100
learning_rate = 1e-3

In [None]:
# Preprocess Data
# Download (if missing) the MNIST train and test splits into `data_dir`.
# NOTE: `transform` is only applied when a sample is fetched via dataset[idx];
# the `.data` tensor used below holds the raw pixel values without ToTensor's scaling.
MNIST_train, MNIST_test = (
    torchvision.datasets.MNIST(
        data_dir,
        train=is_train,
        transform=transforms.ToTensor(),
        download=True,
    )
    for is_train in (True, False)
)

# One flattened image per row, with the matching digit labels.
X = MNIST_train.data.reshape(-1, input_size)
y = MNIST_train.targets

In [None]:
# Use sklearn logreg model
from sklearn.linear_model import LogisticRegression

# Logistic regression fitted with the SAGA solver.
# tol is deliberately loose: the model only needs to be good enough to rank
# which training samples are hardest, not to be fully converged.
# NOTE(review): per the sklearn docs, n_jobs only parallelises over classes for
# one-vs-rest fits, so it may have no effect on this multiclass fit — confirm.
clf = LogisticRegression(
    solver='saga',
    tol=0.01,
    max_iter=1000,
    fit_intercept=True,
    n_jobs=5,
    verbose=2,
)
clf

In [None]:
# Fit the model (until convergence within `tol`, or until `max_iter` passes are used up)
clf.fit(X=X, y=y)

In [None]:
# Sanity check - test on MNIST test
# Held-out images are flattened exactly like the training features.
X_test, y_test = MNIST_test.data.reshape(-1, input_size), MNIST_test.targets

# For classifiers, sklearn's `score` reports mean accuracy on the given data.
score = clf.score(X=X_test, y=y_test)
print(f" The LogReg model got a test score of: {score}")

In [None]:
# Find the hardest samples
# "Hardness" = the fitted model's log loss on each training sample's true class.
# Log-probabilities come straight from the decision scores via log_softmax.
# Taking np.log(predict_proba) instead underflows to log(0) = -inf for
# confidently-wrong samples; those samples would then all tie at an infinite
# loss, with a divide-by-zero warning.
# (For the multinomial model sklearn fits here, predict_proba is
# softmax(decision_function), so this matches log(predict_proba) wherever
# the latter is finite.)
from scipy.special import log_softmax

# Score column j corresponds to clf.classes_[j]; the indexing below relies on
# the classes being exactly 0..num_classes-1.
assert np.array_equal(clf.classes_, np.arange(num_classes))

labels = np.asarray(y)  # y is a torch tensor; use a plain numpy view for indexing
log_probs = log_softmax(clf.decision_function(X), axis=1)
sample_losses = -log_probs[np.arange(len(labels)), labels]  # per-sample log loss of the true class

# Find top 5% hardest per class
hard_fraction = 0.05
hardest_per_class = {}
for c in range(num_classes):
    class_indices = np.flatnonzero(labels == c)
    total = len(class_indices)
    subset = round(hard_fraction * total)
    print(f"Finding the top {subset} hardest sample for the class {c}")
    # This class's sample indices ordered by loss, highest (hardest) first.
    sorted_idx = class_indices[np.argsort(sample_losses[class_indices])[::-1]]
    hardest_per_class[c] = sorted_idx[:subset]

# Get the list of indices: the hard subset and everything else.
hard_indices = np.concatenate(list(hardest_per_class.values()))
mask = np.ones(len(labels), dtype=bool)
mask[hard_indices] = False
complementary_indices = np.flatnonzero(mask)

print(f"Total samples: {len(labels)}")
print(f"Hard subset size: {len(hard_indices)}")
print(f"Complementary subset size: {len(complementary_indices)}")



In [None]:
def visualise_samples(dataset, indices, title):
    """Plot the given dataset samples side by side in a single row.

    Args:
        dataset: indexable dataset whose items unpack as ``(image, label)``,
            e.g. the torchvision MNIST datasets loaded earlier in this notebook.
        indices: sequence of integer indices into ``dataset`` to display.
        title: figure-level title.

    Each panel is titled with the sample's label and its dataset index.
    Nothing is drawn when ``indices`` is empty, because ``plt.subplots``
    rejects a grid with zero columns.
    """
    n = len(indices)
    if n == 0:
        return
    # squeeze=False always returns a 2-D array of Axes, so n == 1 needs no special case.
    fig, axes = plt.subplots(1, n, figsize=(n * 2, 2), squeeze=False)
    for ax, idx in zip(axes[0], indices):
        img, label = dataset[idx]
        ax.imshow(img.squeeze(), cmap='gray')
        ax.set_title(f"Label {label}\nIdx {idx}")
        ax.axis('off')
    fig.suptitle(title)
    fig.tight_layout()
    plt.show()

In [None]:
# Show (up to) the ten hardest training images of a single digit class.
i = 4
n_to_show = 10
hardest_for_i = hardest_per_class[i][:n_to_show]
visualise_samples(MNIST_train, hardest_for_i, f"Hardest samples for class {i}'s:")

In [None]:
# Save to disk: the hard subset, and the remaining indices to train on.
for file_name, index_array in (
    ("hard_indices.npy", hard_indices),
    ("new_train_indices.npy", complementary_indices),
):
    np.save(os.path.join(data_dir, file_name), index_array)

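Testing the colour map for MNIST: each class colour is plotted as a point in RGB space, surrounded by a translucent sphere whose radius is set by the noise slider.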

In [None]:
# Class index -> (colour name, RGB triple with components in [0, 1]).
COLOUR_MAP = dict(enumerate([
    ('red',     [1.0, 0.2, 0.2]),
    ('green',   [0.2, 1.0, 0.2]),
    ('blue',    [0.2, 0.2, 1.0]),
    ('yellow',  [1.0, 1.0, 0.2]),
    ('cyan',    [0.2, 1.0, 1.0]),
    ('magenta', [1.0, 0.2, 1.0]),
    ('orange',  [1.0, 0.6, 0.2]),
    ('purple',  [0.6, 0.2, 1.0]),
    ('lime',    [0.6, 1.0, 0.2]),
    ('brown',   [0.6, 0.4, 0.2]),
]))

In [None]:
# `np` comes from the first cell; the duplicate numpy import was dropped.
from ipywidgets import FloatSlider, interact

import plotly.graph_objects as go


def _plotly_rgb(vec):
    """Format an RGB triple with components in [0, 1] as a Plotly 'rgb(r,g,b)' string."""
    return f'rgb({vec[0]*255},{vec[1]*255},{vec[2]*255})'


def build_colour_map_figure(noise):
    """Build a 3-D RGB plot of COLOUR_MAP with a sphere of radius `noise` around each colour.

    Marker traces are added first and the sphere (Surface) traces second, so
    the legend lists every colour before its noise sphere.
    """
    fig = go.Figure()

    # Base colour points, labelled with their class index.
    for label, (name, vec) in COLOUR_MAP.items():
        fig.add_trace(go.Scatter3d(
            x=[vec[0]], y=[vec[1]], z=[vec[2]],
            mode='markers+text',
            marker=dict(size=8, color=[_plotly_rgb(vec)]),
            text=[f"{label}"],
            textposition="top center",
            name=name
        ))

    # Unit-sphere parametrisation, then scaled by `noise` and centred on each colour.
    u = np.linspace(0, 2 * np.pi, 30)
    v = np.linspace(0, np.pi, 30)
    unit_x = np.outer(np.cos(u), np.sin(v))
    unit_y = np.outer(np.sin(u), np.sin(v))
    unit_z = np.outer(np.ones_like(u), np.cos(v))

    for name, vec in COLOUR_MAP.values():
        colour = _plotly_rgb(vec)
        fig.add_trace(go.Surface(
            x=vec[0] + noise * unit_x,
            y=vec[1] + noise * unit_y,
            z=vec[2] + noise * unit_z,
            showscale=False,
            opacity=0.2,
            colorscale=[[0, colour], [1, colour]],
            name=f"{name} noise"
        ))

    fig.update_layout(
        scene=dict(
            xaxis_title='R',
            yaxis_title='G',
            zaxis_title='B'
        ),
        title=f'COLOUR_MAP in RGB 3D Space with Noise Spheres (noise = {noise:.2f})'
    )
    return fig


def add_noise_spheres(noise):
    """Slider callback: redraw the colour-map figure for the given noise radius.

    `interact` re-runs this on every slider change and replaces the callback's
    output area, so the figure has to be shown from inside the callback.
    Mutating a `go.Figure` that was displayed once elsewhere never updates the
    rendered plot.
    """
    build_colour_map_figure(noise).show()


std = 0.07
# Default sphere radius is 3 * std.
# NOTE(review): presumably std is the per-channel colour noise used for
# coloured MNIST, making the sphere a ~3-sigma region — confirm.
slider = FloatSlider(value=std*3, min=0.05, max=1.0, step=0.05, description='Noise')
interact(add_noise_spheres, noise=slider);
