In [25]:
import os
import torch
import torchvision.models as tvm
from torchvision import transforms
from sklearn.cluster import KMeans
import numpy as np
from PIL import Image

In [26]:
def print_count_files(dir, message = None):
    if os.path.isdir(dir):
        file_count = sum(1 for _ in os.listdir(dir) if os.path.isfile(os.path.join(dir, _)))
        print(f"{message if message else dir}: {file_count}")


def get_all_files(dir):
    if os.path.isdir(dir):
        files = []
        for root, directories, filenames in os.walk(dir):
            for filename in filenames:
                files.append(os.path.join(root, filename))
        return files
    return []



dirs = ["./q4/patch_camelyon/train/0", "./q4/patch_camelyon/train/1", "./q4/patch_camelyon/test",]    
msgs = ["train with Label 0", "train with label 1", "test images"] 

for dir, msg in zip(dirs, msgs):
    print_count_files(dir, msg)


image_lists = [get_all_files(x) for x in dirs]




train with Label 0: 10500
train with label 1: 5500
test images: 2000


In [27]:
base_resin = tvm.resnet50(pretrained=True)
base_resin = torch.nn.Sequential(*list(base_resin.children())[:-1])

for p in base_resin.parameters():
    p.requires_grad = False


transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])




In [30]:
def health_check(msg, i, threshold = 100):
    if (i % threshold == 0):
            print(msg)

def image_cleanup(image_list):
    images = []
    for (i, image_path) in enumerate(image_list):
        health_check(f"loading {i}th image", i)
        image = Image.open(image_path)
        image.convert("RGB")
        # print(image_path)
        image = transform(image)
        images.append(image)
    feats = []
    with torch.no_grad():
        for (i, img) in enumerate(images):
            health_check(f"extracting {i}th image features")
            img = img.unsqueeze(0)
            feature = base_resin(img)
            feature = feature.squeeze().numpy()
            feats.append(feature)

    n_clusters = 2
    km = KMeans(n_clusters=n_clusters)
    labels = km.fit_predict(feats)

    # Determine the cluster with the lower number of items
    label_counts = np.bincount(labels)
    lower_cluster_label = np.argmin(label_counts)

    # Get the names of images assigned to the lower cluster label
    assigned_images = [image_list[i] for i, label in enumerate(labels) if label == lower_cluster_label]

    return assigned_images

print(len(image_lists[2]))
image_cleanup(image_lists[2])

2000
loading 0th image
loading 100th image
loading 200th image
loading 300th image
loading 400th image
loading 500th image
loading 600th image
loading 700th image
loading 800th image
loading 900th image
loading 1000th image
loading 1100th image
loading 1200th image
loading 1300th image
loading 1400th image
loading 1500th image
loading 1600th image
loading 1700th image
loading 1800th image
loading 1900th image




['./q4/patch_camelyon/test/x_77495.png',
 './q4/patch_camelyon/test/x_84509.png',
 './q4/patch_camelyon/test/x_30709.png',
 './q4/patch_camelyon/test/x_100609.png',
 './q4/patch_camelyon/test/x_64416.png',
 './q4/patch_camelyon/test/x_38369.png',
 './q4/patch_camelyon/test/x_30292.png',
 './q4/patch_camelyon/test/x_171437.png',
 './q4/patch_camelyon/test/x_206613.png',
 './q4/patch_camelyon/test/x_57341.png',
 './q4/patch_camelyon/test/x_158099.png',
 './q4/patch_camelyon/test/x_40173.png',
 './q4/patch_camelyon/test/x_124350.png',
 './q4/patch_camelyon/test/x_140629.png',
 './q4/patch_camelyon/test/x_229425.png',
 './q4/patch_camelyon/test/x_31577.png',
 './q4/patch_camelyon/test/x_197415.png',
 './q4/patch_camelyon/test/x_77250.png',
 './q4/patch_camelyon/test/x_248795.png',
 './q4/patch_camelyon/test/x_48922.png',
 './q4/patch_camelyon/test/x_199377.png',
 './q4/patch_camelyon/test/x_239151.png',
 './q4/patch_camelyon/test/x_16877.png',
 './q4/patch_camelyon/test/x_101489.png',
 './