In [1]:
import pandas as pd

# Loading dataset
ind_df = pd.read_csv('data/indigenous_collection_processed.csv', index_col='id')
print(f'Dataframe columns: \n{ind_df.columns}')

Dataframe columns: 
Index(['url', 'thumbnail', 'creation_date', 'modification_date',
       'numero_do_item', 'tripticos', 'categoria', 'nome_do_item',
       'nome_do_item_dic', 'colecao', 'coletor', 'doador', 'modo_de_aquisicao',
       'data_de_aquisicao', 'ano_de_aquisicao', 'data_de_confeccao', 'autoria',
       'nome_etnico', 'descricao', 'dimensoes', 'funcao', 'materia_prima',
       'tecnica_confeccao', 'descritor_tematico', 'descritor_comum',
       'numero_de_pecas', 'itens_relacionados', 'responsavel_guarda',
       'inst_detentora', 'povo', 'autoidentificacao', 'lingua',
       'estado_de_origem', 'geolocalizacao', 'pais_de_origem', 'exposicao',
       'referencias', 'disponibilidade', 'qualificacao', 'historia_adm',
       'notas_gerais', 'observacao', 'conservacao', 'image_path'],
      dtype='object')


In [2]:
from IPython.core.magic import register_cell_magic

# Creating skip cell command
@register_cell_magic
def skip(line, cell):
    return

# Image Clustering

Clustering experiments with image feature extractors. The idea is to fine-tune some pre-trained models on our dataset and then remove the last layer of the model to cluster on the embedding space projections.

## Dataset Preparation

For fine-tuning the model on our dataset, we are going to try a few different labels and study how they affect the generated emebdding space. For now, we focus *povo* and *categoria*.

In [3]:
from PIL import Image

# Filtering out corrupted images
corrupted_images = []
for index, row in ind_df.loc[ind_df['image_path'].notna()].iterrows():
    try:
        Image.open(row['image_path'])
    except Exception as e:
        # print(e)
        corrupted_images.append(row['image_path'])
        ind_df.loc[index, 'image_path'] = pd.NA
print(f'{len(corrupted_images)} corrupted images')

# Creating 'image_path_br' column
ind_df['image_path_br'] = ind_df['image_path'].values
ind_df.loc[ind_df['image_path_br'].notna(), 'image_path_br'] = \
    ind_df.loc[ind_df['image_path_br'].notna(), \
               'image_path'].apply(lambda path: \
                                   f"data/br_images/{path.split('/')[-1].split('.')[0]}.png")

# Preparing labels for dataset training
label_column = 'povo' # 'categoria', 'povo', 'ano_de_aquisicao'
name_to_num = {c: i for i, c in enumerate(ind_df[label_column].unique())}
labels = {row['image_path_br']: name_to_num[row[label_column]] \
          for index, row in ind_df.loc[ind_df['image_path_br'].notna()].iterrows()}

1 corrupted images


In [4]:
import os
from tqdm import tqdm

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# Getting the proper device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Creating the ImageDataset class and the DataLoader object to avoid loading all the images
# simultaneously and run out of GPU memory
class ImageDataset(Dataset):
    def __init__(self, image_dir, labels, transform=None):
        self.image_dir = image_dir
        self.image_files = [f for f in os.listdir(image_dir) \
                            if f.endswith(('.png', '.jpg', '.jpeg'))]
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.image_files[idx])
        image = Image.open(image_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        label = self.labels.get(image_path, -1)
        return image, label

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
dataset = ImageDataset("data/br_images/", labels, transform=transform)

## ViT Base Patch-16

### Pre-trained Embedding Space

In [5]:
# Projecting data onto the off-the-shelf pre-trained embedding space from ViT
import numpy as np
from transformers import ViTImageProcessor, ViTModel
from tqdm import tqdm

# Loading model
model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
model.to(device)

# Getting data
dataloader = DataLoader(dataset, batch_size=512, shuffle=True, num_workers=0, pin_memory=True)

# Function to iterating over data to get projections
def get_embeddings(model, dataloader):
    image_embeddings = []
    for batch_images, _ in tqdm(dataloader, desc="Computing embeddings"):
        batch_images = batch_images.to(device)
        with torch.no_grad():
            outputs = model(batch_images)
        
        # Do I get the last_hidden_state of CLS token or the pooler_output?
        # embeddings = outputs['last_hidden_state'][:, 0, :]
        embeddings = outputs['pooler_output']
        image_embeddings.append(embeddings.cpu())
    return image_embeddings

image_embeddings = np.concatenate(get_embeddings(model, dataloader), axis=0)

Computing embeddings: 100%|████| 23/23 [02:37<00:00,  6.86s/it]


In [6]:
# Computing data projection
import trimap

proj_trimap = trimap.TRIMAP(n_dims=2, n_inliers=12, n_outliers=6, n_random=3,\
                            weight_temp=0.5, lr=0.1, apply_pca=True)
vanilla_vit = proj_trimap.fit_transform(image_embeddings)

### Fine-tuning Embedding Space

In [7]:
import numpy as np
from transformers import ViTImageProcessor, ViTModel
from tqdm import tqdm
import trimap

In [8]:
# Creating our own ViT classifier head for fine-tuning
import torch.nn as nn
import torch.optim as optim

class ViTClassifier(nn.Module):
    def __init__(self, num_classes):
        super(ViTClassifier, self).__init__()
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
        self.classifier = nn.Linear(self.vit.config.hidden_size, num_classes)

    def forward(self, x):
        outputs = self.vit(x)
        
        # Do I get the last_hidden_state of CLS token or the pooler_output?
        # embeddings = outputs['last_hidden_state'][:, 0, :]
        embeddings = outputs['pooler_output']

        logits = self.classifier(embeddings)
        return logits

In [9]:
# Training function
def train_loop(model, dataloader, epochs=20):
    model.train()
    for epoch in tqdm(range(epochs), desc=f"Training model", leave=True):
        batch_loss = 0
        for batch_images, batch_labels in dataloader:
            batch_images, batch_labels = batch_images.to(device), batch_labels.to(device)
            
            opt.zero_grad()
            logits = model(batch_images)
            loss = criterion(logits, batch_labels)
            loss.backward()
            opt.step()

            batch_loss += loss.item()

            # # Freeing space
            # del batch_images, batch_labels, logits, loss
            # torch.cuda.empty_cache()
        
        tqdm.write(f'Epoch {epoch+1}, Loss: {batch_loss}')
        
# Training set-up and execution
num_classes = ind_df['povo'].nunique()
model = ViTClassifier(num_classes).to(device)
criterion = nn.CrossEntropyLoss()
opt = optim.Adam(model.parameters(), lr=5e-5, weight_decay=0)
epochs = 100
batch_size = 32
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, \
                        num_workers=0, pin_memory=True)

train_loop(model, dataloader, epochs)

Training model:   5%|▍       | 1/20 [04:41<1:29:00, 281.06s/it]

Epoch 1, Loss: 1102.6993127465248


Training model:   5%|▍       | 1/20 [05:26<1:43:30, 326.87s/it]


KeyboardInterrupt: 

In [None]:
# Computing image embeddings

In [None]:
# Computing data projection
proj_trimap = trimap.TRIMAP(n_dims=2, n_inliers=12, n_outliers=6, n_random=3,\
                            weight_temp=0.5, lr=0.1, apply_pca=True)
vanilla_vit = proj_trimap.fit_transform(image_embeddings)

### Visualizing and Comparing Projections

In [None]:
# Visualizing resulting projections
import matplotlib.pyplot as plt

plt.figure(figsize=(8,4))
plt.suptitle('Comparing Projections of ViT Models')

# Plotting vanilla ViT projections
plt.subplot(1, 2, 1)
plt.scatter(vanilla_vit[:, 0], vanilla_vit[:, 1], c='b')
plt.title("Vanilla ViT")
plt.xlabel("")
plt.ylabel("")
plt.xticks([])
plt.yticks([])

# Plotting ViT fine-tuned on 'povo' projections
plt.subplot(1, 2, 2)
plt.scatter(povo_vit[:, 0], povo_vit[:, 1], c='b')
plt.title("ViT Fine-Tuned on Povo")
plt.xlabel("")
plt.ylabel("")
plt.xticks([])
plt.yticks([])

plt.tight_layout()
plt.show()