In [2]:
import sys

if "include" not in sys.path:
    sys.path.append("include")


In [3]:
import os
import sys
import pathlib

from PIL import Image
import numpy as np
from sklearn.cluster import MiniBatchKMeans, KMeans
import torch
import torchvision

from include.img2vec_pytorch import Img2Vec


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
def crop_cpu(img, crop_sz, step):
    n_channels = len(img.shape)
    if n_channels == 2:
        h, w = img.shape
    elif n_channels == 3:
        h, w, c = img.shape
    else:
        raise ValueError("Wrong image shape - {}".format(n_channels))
    h_space = np.arange(0, h - crop_sz + 1, step)
    w_space = np.arange(0, w - crop_sz + 1, step)
    index = 0
    num_h = 0
    lr_list = []
    for x in h_space:
        num_h += 1
        num_w = 0
        for y in w_space:
            num_w += 1
            index += 1
            if n_channels == 2:
                crop_img = img[x : x + crop_sz, y : y + crop_sz]
            else:
                crop_img = img[x : x + crop_sz, y : y + crop_sz, :]
            lr_list.append(crop_img)
    h = x + crop_sz
    w = y + crop_sz
    return lr_list, num_h, num_w, h, w


def crop_image(image: torch.Tensor, size: int, stride: int):
    """@param image: C, H, W"""
    c, h, w = image.shape
    image = image.unsqueeze(0)
    unfolded = torch.nn.functional.unfold(image, size, stride=stride)
    unfolded = unfolded.permute(0, 2, 1)
    unfolded = unfolded.reshape(-1, c, size, size)
    return unfolded.contiguous()


In [5]:
DATASET_DIR = pathlib.Path().parent / "dataset/cam1/LQ"

image_files = DATASET_DIR.iterdir()
img2vec = Img2Vec(cuda=True)

all_image_vectors = []

resnet = torchvision.models.resnet18(
    weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1
).eval()

for imgfile in image_files:
    image = torchvision.io.read_image(str(imgfile)).to(dtype=torch.float32)
    sub_images = crop_image(image, 32, 28)
    output = torch.empty(sub_images.shape[0], 512, requires_grad=False)
    copy_output = lambda m, i, o: output.copy_(o.detach().squeeze())
    hook = resnet._modules.get("avgpool").register_forward_hook(copy_output)
    resnet(sub_images)
    hook.remove()
    all_image_vectors.append(output)
    
    # labels = kmeans.predict(output)
    # labels = torch.from_numpy(labels).to(dtype=torch.long)

    # image = Image.open(imgfile)
    # subimages = crop_cpu(np.array(image), 32, 28)[0]
    # print(len(subimages), subimages[0].shape)
    # image_vectors = img2vec.get_vec(
    #     [Image.fromarray(sub) for sub in subimages], tensor=True
    # ).squeeze()
    # all_image_vectors.append(image_vectors)




In [27]:
all_image_vectors_cat = torch.cat(all_image_vectors)
print(all_image_vectors_cat.shape)


all_n_clusters = [4, 6, 8, 10]
all_kmeans = {}

for n_clusters in all_n_clusters:
    kmeans = MiniBatchKMeans(n_clusters)
    kmeans.fit(all_image_vectors_cat)

    all_counts = np.empty((len(all_image_vectors), n_clusters), dtype=int)
    for i, image_vectors in enumerate(all_image_vectors):
        labels = kmeans.predict(image_vectors)
        counts = np.bincount(labels, minlength=n_clusters)
        all_counts[i] = counts

    gates_max = np.amax(all_counts, axis=0)
    gates_ptp = np.ptp(all_counts, axis=0)
    all_kmeans[n_clusters] = {'kmeans': kmeans, "gates_max": gates_max, "gates_ptp": gates_ptp}
    print(np.array([gates_max, gates_ptp]))
    # print(np.amin(all_counts, axis=0))


torch.Size([8800, 512])
[[42 20 10 22]
 [ 4  3  4  4]]
[[ 9 32 18 16 18  7]
 [ 3  4  3  4  5  3]]
[[21  8 16 11 20  8  7 15]
 [ 6  3  3  4  5  3  3  7]]
[[19  4 16 15 19  6 10  5  9  2]
 [ 4  3  3  6  4  3  2  2  4  2]]


In [29]:
import pickle

IS_OVERWRITING = False

for n_clusters in all_n_clusters:
    kmeans = all_kmeans[n_clusters]["kmeans"]
    gates_max = all_kmeans[n_clusters]["gates_max"]
    gates_ptp = all_kmeans[n_clusters]["gates_ptp"]
    pkl_file = pathlib.Path(f"kmeans_{n_clusters}.pkl")
    if not IS_OVERWRITING and pkl_file.exists():
        print(f"[INFO] {pkl_file} exists, won't overwriting")
    else: # IS_OVERWRITING
        if pkl_file.exists():
            print(f"[WARNING] {pkl_file} exists, overwriting")
        with open(pkl_file, "wb") as pkl:
            pickle.dump({"kmeans": kmeans, "gates_max": gates_max, "gates_ptp": gates_ptp}, pkl)

    with open(pkl_file, "rb") as pkl:
        kmeans_dict = pickle.load(pkl)
        kmeans_load = kmeans_dict["kmeans"]
        gates_max_load = kmeans_dict["gates_max"]
        gates_ptp_load = kmeans_dict["gates_ptp"]

    all_counts = np.empty((len(all_image_vectors), n_clusters), dtype=int)
    for i, image_vectors in enumerate(all_image_vectors):
        labels = kmeans_load.predict(image_vectors)
        counts = np.bincount(labels, minlength=n_clusters)
        all_counts[i] = counts

    print(np.array([np.amax(all_counts, axis=0), np.ptp(all_counts, axis=0)]))


[INFO] kmeans_4.pkl exists, won't overwriting
[[19 25 18 36]
 [ 6  3  3  7]]
[INFO] kmeans_6.pkl exists, won't overwriting
[[ 9 32 18 16 18  7]
 [ 3  4  3  4  5  3]]
[INFO] kmeans_8.pkl exists, won't overwriting
[[21  8 16 11 20  8  7 15]
 [ 6  3  3  4  5  3  3  7]]
[INFO] kmeans_10.pkl exists, won't overwriting
[[ 6  7 12 27  8  8  8 12 12  4]
 [ 1  3  3  4  5  3  2  6  3  1]]


In [None]:
import torch


def crop_image(image: torch.Tensor, size: int, stride: int):
    n, c, h, w = image.shape
    print(f"{image=}, {image.shape=}")
    unfolded = torch.nn.functional.unfold(image, size, stride=stride)
    # print(f"{unfolded=}, {unfolded.shape=}")
    unfolded = unfolded.permute(0, 2, 1)
    # print(f"{unfolded=}, {unfolded.shape=}")
    unfolded = unfolded.reshape(8, c, size, size)
    print(f"{unfolded=}, {unfolded.shape=}")
    return unfolded


image = torch.arange(3 * 3 * 5, dtype=float).reshape(1, 3, 3, 5).contiguous()
unfolded = crop_image(image, 2, 1)


In [34]:
def random_gradual_03(elem_list):
    random_list = []

    if len(elem_list) == 1:
        random_list.extend([elem_list[0]])
    else:
        for i in range(len(elem_list)):
            if i == len(elem_list) - 1:
                random_list.extend([elem_list[i]] * len(random_list))
            else:
                random_list.extend([elem_list[i]] * 1)

    return random_list

print(f"{random_gradual_03([0]) = }")
print(f"{random_gradual_03([0, 1]) = }")
print(f"{random_gradual_03([0, 2, 4]) = }")
print(f"{random_gradual_03([0, 2, 4, 1]) = }")
print(f"{random_gradual_03([0, 1, 2, 3, 4]) = }")
print(f"{random_gradual_03([0, 1, 0, 1, 0, 1]) = }")

random_gradual_03([0]) = [0]
random_gradual_03([0, 1]) = [0, 1]
random_gradual_03([0, 2, 4]) = [0, 2, 4, 4]
random_gradual_03([0, 1, 2, 3, 4]) = [0, 1, 2, 3, 4, 4, 4, 4]
random_gradual_03([0, 1, 0, 1, 0, 1]) = [0, 1, 0, 1, 0, 1, 1, 1, 1, 1]
