In [3]:
import multiprocessing
import sys

import numpy as np
import pytorch_lightning as pl
import sklearn
import torch
import torchvision

from data.dataloaders import ImagesDataset
from models.model import SelfSupervisedLearner
from sklearn.linear_model import LogisticRegression
from sklearn.decomposition import PCA
from torch.utils.data import DataLoader

BATCH_SIZE = 256
EPOCHS     = 1000
LR         = 3e-4
IMAGE_SIZE = 96 # Change this depending on dataset
NUM_GPUS= 0 # Change this depending on host
NUM_WORKERS = multiprocessing.cpu_count()

In [5]:
resnet = torchvision.models.resnet18(pretrained=False)
model = SelfSupervisedLearner(
    resnet,
    image_size = IMAGE_SIZE,
    hidden_layer = 'avgpool',
    projection_size = 256,
    projection_hidden_size = 4096,
    moving_average_decay = 0.99,
    lr = LR
)

    
argv = ["train.py", "--load", "./ckpt/learner_0510_v100.pt"]
model.load_state_dict(torch.load(argv[2]))
print("Loaded checkpoint from ", argv[2])

#TODO: for some reason labels don't exist in my wget data 
#ds = ImagesDataset("./dataset/test_images", IMAGE_SIZE, train=False)
data_transforms = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
train_dataset = torchvision.datasets.STL10('./dataset/train_split', split='train', download=False,
                   transform=data_transforms)
train_loader = DataLoader(train_dataset, batch_size=5000, num_workers=NUM_WORKERS, shuffle=False)


train_imgs, train_labels = next(iter(train_loader))
print("Train loading done")

test_dataset = torchvision.datasets.STL10('./dataset/test_split', split='test', download=False, transform=data_transforms)
test_loader = DataLoader(test_dataset, batch_size=8000, num_workers=NUM_WORKERS, shuffle=False)
test_imgs, test_labels = next(iter(test_loader))
print("Test loading done")

train_projs, train_embeddings = model.learner.forward(train_imgs, return_embedding=True)
test_projs, test_embeddings = model.learner.forward(test_imgs, return_embedding=True)

print("got embeddings")

train_imgs = torch.flatten(train_imgs, start_dim=1)
test_imgs = torch.flatten(test_imgs, start_dim=1)

scaler = sklearn.preprocessing.StandardScaler()
scaler.fit(train_imgs)
train_imgs = scaler.transform(train_imgs).astype(np.float32)
test_imgs = scaler.transform(test_imgs).astype(np.float32)


pca = PCA(n_components=512)
train_imgs_pca = pca.fit_transform(train_imgs)
test_imgs_pca = pca.transform(test_imgs)

lr_baseline = LogisticRegression(max_iter=100000)
baseline_preds = lr_baseline.fit(train_imgs_pca, train_labels)

baseline_preds = lr_baseline.predict_proba(test_imgs_pca)
baseline_classes = lr_baseline.predict(test_imgs_pca)
baseline_acc = sklearn.metrics.accuracy_score(test_labels, baseline_classes)

lr_byol = LogisticRegression(max_iter=100000)
lr_byol.fit(train_embeddings.detach().numpy(), train_labels)

byol_preds = lr_byol.predict_proba(test_embeddings.detach().numpy())
byol_classes = lr_byol.predict(test_embeddings.detach().numpy())
byol_acc = sklearn.metrics.accuracy_score(test_labels, byol_classes)

Loaded checkpoint from  ./ckpt/learner_0510_v100.pt
Train loading done
Test loading done
got embeddings


In [115]:
for _ in range(10):
    random_idx = np.random.randint(0, high=train_imgs.shape[0], size = 25)

    embeddings_subset = train_embeddings.detach().numpy()[random_idx]
    train_labels_subset = train_labels[random_idx]

    lr_rand = LogisticRegression(max_iter=100000)
    lr_rand.fit(embeddings_subset, train_labels_subset)

    rand_preds =lr_rand.predict(test_embeddings.detach().numpy())
    rand_acc = sklearn.metrics.accuracy_score(test_labels, rand_preds)
    
    lr_baseline = LogisticRegression(max_iter=100000)
    lr_baseline.fit(train_imgs[random_idx], train_labels_subset)
    
    lr_baseline_preds = lr_baseline.predict(test_imgs)
    lr_baseline_acc = sklearn.metrics.accuracy_score(test_labels, lr_baseline_preds)
    
    print("lr baseline: ", lr_baseline_acc)

    print("random embeddings: ", rand_acc)

lr baseline:  0.154875
random embeddings:  0.403125
lr baseline:  0.17575
random embeddings:  0.331625
lr baseline:  0.19275
random embeddings:  0.357125
lr baseline:  0.1715
random embeddings:  0.344875
lr baseline:  0.1785
random embeddings:  0.39175
lr baseline:  0.212125
random embeddings:  0.330875
lr baseline:  0.162375
random embeddings:  0.352
lr baseline:  0.176375
random embeddings:  0.349625
lr baseline:  0.155125
random embeddings:  0.3
lr baseline:  0.170375
random embeddings:  0.352125


In [117]:
from sklearn.cluster import KMeans

km = KMeans(n_clusters=10, max_iter=100000)
km.fit(train_embeddings.detach().numpy())

clusters = km.labels_

In [118]:
from collections import Counter

counts = Counter(clusters)
total = train_embeddings.detach().numpy().shape[0]

weights = {}
uniform_prob = 0.1
for k in counts:
    weights[k] = uniform_prob / (counts[k] / total)
    
print(counts)
print(weights)

weights_full = [weights[k] for k in clusters]

Counter({9: 921, 5: 767, 7: 642, 4: 638, 2: 541, 0: 477, 3: 358, 1: 314, 8: 227, 6: 115})
{4: 0.7836990595611286, 2: 0.9242144177449169, 9: 0.5428881650380022, 5: 0.651890482398957, 0: 1.0482180293501049, 3: 1.3966480446927376, 6: 4.347826086956522, 1: 1.5923566878980895, 7: 0.7788161993769471, 8: 2.202643171806167}


In [119]:
import random

for _ in range(10):
    kmeans_idx = random.choices(range(train_imgs.shape[0]), weights=weights_full, k=25)

    """
    cluster_subset = clusters[kmeans_idx]
    cluster_counts = Counter(cluster_subset)
    print(cluster_subset)
    print(cluster_counts)
    """

    embeddings_subset = train_embeddings.detach().numpy()[kmeans_idx]
    train_labels_subset = train_labels[kmeans_idx]

    lr_km = LogisticRegression(max_iter=100000)
    lr_km.fit(embeddings_subset, train_labels_subset)

    km_preds =lr_km.predict(test_embeddings.detach().numpy())
    km_acc = sklearn.metrics.accuracy_score(test_labels, km_preds)
    
    lr_baseline = LogisticRegression(max_iter=100000)
    lr_baseline.fit(train_imgs[kmeans_idx], train_labels_subset)
    
    lr_baseline_preds = lr_baseline.predict(test_imgs)
    lr_baseline_acc = sklearn.metrics.accuracy_score(test_labels, lr_baseline_preds)

    print("km: ", km_acc)
    print("lr baseline acc:", lr_baseline_acc)

km:  0.37025
lr baseline acc: 0.189625
km:  0.26125
lr baseline acc: 0.17625
km:  0.344125
lr baseline acc: 0.202625
km:  0.353375
lr baseline acc: 0.17825
km:  0.396375
lr baseline acc: 0.21675
km:  0.3405
lr baseline acc: 0.18375
km:  0.4075
lr baseline acc: 0.214
km:  0.3765
lr baseline acc: 0.166375
km:  0.401875
lr baseline acc: 0.18675
km:  0.36025
lr baseline acc: 0.211625
