## Example

In this simple example, we load an image, pre-process it, and classify it with a pretrained EfficientNet.

In [2]:
pip install torchvision

Collecting torchvision
  Downloading torchvision-0.22.0-cp313-cp313-macosx_11_0_arm64.whl.metadata (6.1 kB)
Downloading torchvision-0.22.0-cp313-cp313-macosx_11_0_arm64.whl (1.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.9/1.9 MB[0m [31m430.6 kB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: torchvision
Successfully installed torchvision-0.22.0

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip3 install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [3]:
import torch
from torchvision import transforms
from efficientnet_pytorch import EfficientNet
from PIL import Image
import os
import numpy as np
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt


In [4]:
model = EfficientNet.from_pretrained('efficientnet-b0')
model.eval()


Downloading: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth" to /Users/tarasetareh/.cache/torch/hub/checkpoints/efficientnet-b0-355c32eb.pth


100.0%

Loaded pretrained weights for efficientnet-b0





EfficientNet(
  (_conv_stem): Conv2dStaticSamePadding(
    3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False
    (static_padding): ZeroPad2d((0, 1, 0, 1))
  )
  (_bn0): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
  (_blocks): ModuleList(
    (0): MBConvBlock(
      (_depthwise_conv): Conv2dStaticSamePadding(
        32, 32, kernel_size=(3, 3), stride=[1, 1], groups=32, bias=False
        (static_padding): ZeroPad2d((1, 1, 1, 1))
      )
      (_bn1): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
      (_se_reduce): Conv2dStaticSamePadding(
        32, 8, kernel_size=(1, 1), stride=(1, 1)
        (static_padding): Identity()
      )
      (_se_expand): Conv2dStaticSamePadding(
        8, 32, kernel_size=(1, 1), stride=(1, 1)
        (static_padding): Identity()
      )
      (_project_conv): Conv2dStaticSamePadding(
        32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False
    

In [5]:
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])


In [None]:
image_dir = './covid_cxnet_dataset/covid19'
features_list = []

for img_name in os.listdir(image_dir):
    img_path = os.path.join(image_dir, img_name)
    img = Image.open(img_path).convert('RGB')
    img_tensor = transform(img).unsqueeze(0) 

    with torch.no_grad():
        feat = model.extract_features(img_tensor) 
        feat = torch.nn.functional.adaptive_avg_pool2d(feat, 1).squeeze().numpy() 
        features_list.append(feat)


In [None]:
X = np.stack(features_list) 
kmeans = KMeans(n_clusters=2, random_state=42)
labels = kmeans.fit_predict(X)
print(X)


[[-0.02936083  0.07647577 -0.1135061  ... -0.09110641 -0.16290665
   1.3335425 ]
 [-0.12872352 -0.04895796  0.45590958 ... -0.1576537  -0.01354143
   1.2783446 ]
 [-0.03685876 -0.14339817  0.0384442  ... -0.1393349   0.01769202
   0.27133164]
 ...
 [ 0.05318588 -0.17784473 -0.14047313 ... -0.07463481 -0.17176776
   1.5367606 ]
 [-0.05858427 -0.10624857 -0.05262347 ... -0.06321011  0.00421527
   1.616744  ]
 [-0.10871799 -0.05148267  0.11626962 ... -0.10215705 -0.13870943
   1.6855773 ]]


In [11]:
for i, img_name in enumerate(os.listdir(image_dir)):
    print(f"{img_name} → cluster: {labels[i]}")


189.jpg → cluster: 0
374.PNG → cluster: 1
818-.png → cluster: 1
162.jpg → cluster: 0
252.jfif → cluster: 1
638.png → cluster: 0
610.png → cluster: 0
604.png → cluster: 0
837.png → cluster: 1
348.jpg → cluster: 0
360.jpg → cluster: 1
412.jpg → cluster: 0
770.png → cluster: 0
228.jpg → cluster: 0
200.jpg → cluster: 1
566.png → cluster: 0
214.png → cluster: 0
572.png → cluster: 1
599.png → cluster: 0
109.jpeg → cluster: 1
482p-.jpg → cluster: 1
244.jfif → cluster: 0
765.jpg → cluster: 1
215.png → cluster: 0
573.png → cluster: 0
567.png → cluster: 1
172.jpeg → cluster: 1
201.jpg → cluster: 1
229.jpg → cluster: 0
771.png → cluster: 1
060.jpeg → cluster: 1
430.jpeg → cluster: 0
463-.jpg → cluster: 1
413.jpg → cluster: 0
361.jpg → cluster: 0
349.jpg → cluster: 1
836.png → cluster: 1
605.png → cluster: 1
076.jpeg → cluster: 0
611.png → cluster: 0
639.png → cluster: 0
099.jpeg → cluster: 1
471.jpeg → cluster: 1
375.PNG → cluster: 0
188.jpg → cluster: 1
834.jpg → cluster: 1
113.jpeg → cluster: 1