In [33]:
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoImageProcessor, ResNetForImageClassification
import torch
from torchvision import transforms
import numpy as np
from tqdm import tqdm

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Load Data

In [3]:
dataset = load_dataset("labels", data_dir="./")
labels = dataset["train"].features["label"].names
num_labels=len(labels)

Resolving data files:   0%|          | 0/335 [00:00<?, ?it/s]

Downloading data:   0%|          | 0/335 [00:00<?, ?files/s]

Generating train split: 0 examples [00:00, ? examples/s]

In [4]:
id2label = {i:dataset["train"].features["label"].names[i] for i in range(num_labels)}
label2id = {label:idx for idx,label in id2label.items() }

## Transform and Augment Data

In [5]:
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor, Resize

In [6]:
checkpoint = "google/vit-base-patch16-224-in21k"
image_processor = AutoImageProcessor.from_pretrained(checkpoint)

preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]

Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.


In [63]:
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)

size = (
    image_processor.size["shortest_edge"]
    if "shortest_edge" in image_processor.size
    else (image_processor.size["height"], image_processor.size["width"])
)

_transforms = Compose([
    Resize(size),#RandomResizedCrop(size), 
    ToTensor(), 
    normalize])

In [64]:
def transforms(examples):
    examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]]
    examples["image_names"] = [img.filename.split("/")[-1] for img in examples["image"] ]
    del examples["image"]
    return examples

In [9]:
dataset = dataset.with_transform(transforms)

In [10]:
from transformers import DefaultDataCollator

data_collator = DefaultDataCollator()

## build model

In [11]:
import evaluate
accuracy = evaluate.load("accuracy")

Downloading builder script: 0.00B [00:00, ?B/s]

In [12]:
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)

In [13]:
from transformers import AutoModelForImageClassification

In [46]:
model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [47]:
model.to(device)

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
            (intermed

## Training

In [19]:
from transformers import TrainingArguments, Trainer

In [48]:
training_args = TrainingArguments(
    output_dir="./image_classifier",
    remove_unused_columns=False,
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=16,
    num_train_epochs=25,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=False,
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset["train"],
    eval_dataset=dataset["train"],
    processing_class=image_processor,
    compute_metrics=compute_metrics,
)

In [49]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,No log,1.767505,0.6
2,1.668000,1.452908,0.78806
3,1.668000,1.13742,0.785075
4,1.124000,0.907603,0.80597
5,0.779100,0.744653,0.880597
6,0.779100,0.625319,0.895522
7,0.615700,0.534103,0.934328
8,0.615700,0.463122,0.958209
9,0.424300,0.411647,0.958209
10,0.352500,0.369873,0.979104


TrainOutput(global_step=125, training_loss=0.5184059734344483, metrics={'train_runtime': 441.1309, 'train_samples_per_second': 18.985, 'train_steps_per_second': 0.283, 'total_flos': 5.4401814531477504e+17, 'train_loss': 0.5184059734344483, 'epoch': 20.952380952380953})

In [50]:
trainer.save_model("./model.pt")

In [51]:
model.save_pretrained("image_classifier")

## validate per class

In [33]:
from PIL import Image
import os
import seaborn as sns
from tqdm import tqdm

In [None]:
data_path = "labels"
classes_dirs = [d for d in os.listdir(data_path) if "." not in d ]

In [None]:
def predict_image_at(class_name,img_name):
    img = Image.open(f"{data_path}/{class_name}/{img_name}")
    pixel_values = _transforms(img.convert("RGB")).to(device)
    #pixel_values = torch.from_numpy(image_processor(img.convert("RGB"))["pixel_values"][0]).to(device)

    outputs = None
    with torch.no_grad():
        outputs = model(pixel_values.unsqueeze(0))
    #return outputs
    prediction = outputs.logits[0].argmax()
    prediction_label = id2label[int(prediction)]
    img.close()
    return prediction_label

In [None]:
performance = dict()
for cl in tqdm(classes_dirs):
    class_results = []
    for img in [img for img in os.listdir(f"{data_path}/{cl}") if ".jpg" in img]:
        out = predict_image_at(cl,img)
        class_results.append(1 if out == cl else 0)
    cl_result = np.array(class_results).mean()
    performance[cl] = cl_result

In [None]:
sns.barplot(performance)

## classify images

In [52]:
# Load the model
loaded_model = AutoModelForImageClassification.from_pretrained("./model.pt").to(device)
loaded_model.eval()

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
            (intermed

In [65]:
unlabeled_data = load_dataset("images", data_dir="./",)
unlabeled_data = unlabeled_data.with_transform(transforms)

Resolving data files:   0%|          | 0/17095 [00:00<?, ?it/s]

In [66]:
unlabeled_dataloader = DataLoader(
        unlabeled_data['train'],
        batch_size=128,
        shuffle=False,  # Don't shuffle for consistent processing
        drop_last=False  # Keep all images
)

In [57]:
model.eval()

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
            (intermed

## save clustering dictionary

In [67]:
clustering = {i:[] for i in id2label.keys()}
with torch.no_grad():
    for image_batch in tqdm(unlabeled_dataloader):
        pixel_values = image_batch["pixel_values"].to(device)
        output = model(pixel_values)
        class_ids = output.logits.argmax(dim=1)
        for class_id, fname in zip(class_ids, image_batch["image_names"]):
            clustering[int(class_id)].append(fname)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 134/134 [09:40<00:00,  4.33s/it]


In [68]:
import pickle
img2cluster = {img: id2label[idx] for idx, cluster in clustering.items() for img in cluster}
with open("clustering.pkl","wb")as f:
    pickle.dump(img2cluster, f )

## save clustered images

In [None]:
for _, label in id2label.items():
    os.makedirs(f"clustering/{label}",exist_ok=True)
img2cluster = {img: idx for idx, cluster in clustering.items() for img in cluster}
for img, idx in tqdm(img2cluster.items()):
    label = id2label[idx]
    Image.open(f"data/{img_name}").save(f"clustering/{label}/{img_name}")
    img_back_name = img_name.replace("_1.jpg","_2.jpg")
    if os.path.exists(f"data/{img_back_name}"):
        Image.open(f"{path}/{imgname}").save(f"clustered/{label}/{img_back_name}")

## Inspect Image

In [None]:
img = Image.open(f"path/to/img.jpg")

In [None]:
pixel_values1 = torch.from_numpy(image_processor(img.convert("RGB"))["pixel_values"][0]).to(device)
pixel_values2 = _transforms(img.convert("RGB")).to(device)

In [None]:
def predict_unlabeled_image(img):
    pixel_values = _transforms(img.convert("RGB")).to(device)
    #pixel_values = torch.from_numpy(image_processor(img.convert("RGB"))["pixel_values"][0]).to(device)

    outputs = model(pixel_values.unsqueeze(0))
    #return outputs
    prediction = outputs.logits[0].argmax()
    prediction_label = id2label[int(prediction)]
    return prediction_label

In [None]:
with torch.no_grad():
    print(loaded_model(pixel_values1.unsqueeze(0)).logits.argmax(-1))

In [None]:
with torch.no_grad():
    print(loaded_model(pixel_values2.unsqueeze(0)).logits.argmax(-1))

In [None]:
id2label[7]

In [None]:
dataset["train"]["label"][0]

In [None]:
image = d["image"][0]