# Train & Test on Data with Intraclass Variance : ViT

In [78]:
from transformers import AutoImageProcessor, ViTModel
import torch
from torch.utils.data import Dataset,DataLoader
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.nn.utils import clip_grad_norm_
from torch.optim import Adam,lr_scheduler
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torch.nn.parallel import DataParallel
import torchvision.models as models
from PIL import Image, ImageFile
import numpy as np
from tqdm import tqdm
import faiss
import os
import pickle
import random
import time
import cv2

np.object = np.object_
np.int = np.int_
np.bool = np.bool_

# TRIPLET

In [79]:
class Triplet:
    def __init__(self, train_folder):
        self.train_folder = train_folder
        self.labels = [label for label in os.listdir(train_folder) if label != '.ipynb_checkpoints']
        self.label_to_path = {label: os.path.join(train_folder, label) for label in self.labels}

    def get_triplet(self):
        anchor_label = random.choice(self.labels)
        anchor_path = random.choice(os.listdir(self.label_to_path[anchor_label]))
        positive_label = anchor_label
        positive_path = random.choice(os.listdir(self.label_to_path[positive_label]))
        negative_label = random.choice([label for label in self.labels if label != anchor_label])
        negative_path = random.choice(os.listdir(self.label_to_path[negative_label]))

        anchor_image = os.path.join(self.label_to_path[anchor_label], anchor_path)
        positive_image = os.path.join(self.label_to_path[positive_label], positive_path)
        negative_image = os.path.join(self.label_to_path[negative_label], negative_path)

        anchor_label_num = self.labels.index(anchor_label)
        positive_label_num = self.labels.index(positive_label)
        negative_label_num = self.labels.index(negative_label)

        return anchor_image, positive_image, negative_image

class TripletDataset(Dataset):
    def __init__(self, train_folder, length, transform=None,):
        self.triplet_generator = Triplet(train_folder)
        self.transform = transform
        self.length = length

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        anchor_image, positive_image, negative_image = self.triplet_generator.get_triplet()
        anchor = self._load_image(anchor_image)
        positive = self._load_image(positive_image)
        negative = self._load_image(negative_image)
        return anchor, positive, negative

    def _load_image(self, image_path):
        image = Image.open(image_path).convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image

    def get_triplet_names(self, index):
        anchor_image, positive_image, negative_image = self.triplet_generator.get_triplet()
        return anchor_image, positive_image, negative_image

# Train Loader

In [80]:
import os
folder_path = "/kaggle/input/animals-insects-reptiles/Animals-Insects-Reptiles/Species-Train"

# Get a list of subfolders in the specified folder
subfolders = [subfolder for subfolder in os.listdir(folder_path) if os.path.isdir(os.path.join(folder_path, subfolder))]

total = 0
# Iterate through the subfolders and count the number of files in each
for subfolder in subfolders:
    subfolder_path = os.path.join(folder_path, subfolder)
    num_files = len([filename for filename in os.listdir(subfolder_path) if os.path.isfile(os.path.join(subfolder_path, filename))])
    print(f"Subfolder: {subfolder}, Number of Files: {num_files}")
    total += num_files
print(total)

Subfolder: lynx_canadensis, Number of Files: 130
Subfolder: loxodonta_cyclotis, Number of Files: 130
Subfolder: leptailurus_serval, Number of Files: 130
Subfolder: acinonyx_jubatus, Number of Files: 130
Subfolder: leopardus_wiedii, Number of Files: 130
Subfolder: felis_silvestris, Number of Files: 130
Subfolder: leopardus_pardalis, Number of Files: 130
Subfolder: elephas_maximus, Number of Files: 130
Subfolder: herpailurus_yagouaroundi, Number of Files: 130
Subfolder: felis_lybica, Number of Files: 130
1300


In [81]:
bs = 16
train_folder = "/kaggle/input/animals-insects-reptiles/Animals-Insects-Reptiles/Species-Train"

transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to 224x224
    transforms.ToTensor(),           # Convert to a PyTorch tensor
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize the tensor
])
train_dataset = TripletDataset(train_folder, 1300, transform=transform, )
train_loader = DataLoader(train_dataset, batch_size=bs, shuffle=True)

# Network

In [82]:
import torch
import torch.nn as nn
import torchvision.models as models

class TEmbeddingNet(nn.Module):
    def __init__(self, modelt):
        super(TEmbeddingNet, self).__init__()
        self.modelt = modelt

    def forward(self, x):
        x = self.modelt(x)  # Shape: (batch_size, 2048, H, W)
        return x.last_hidden_state

    def get_embedding(self, x):
        x = self.modelt(x)  # Shape: (batch_size, 2048, H, W)
        return x.last_hidden_state

image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

tmodel = TEmbeddingNet(model)

In [84]:
class TripletNet(nn.Module):
    def __init__(self, embedding_net):
        super(TripletNet, self).__init__()
        self.enet = embedding_net

    def forward(self, x1, x2=None, x3=None):
        if x2 is None and x3 is None:
            return self.enet.get_embedding(x1)
        return self.enet.get_embedding(x1),self.enet.get_embedding(x2),self.enet.get_embedding(x3)

    def get_embedding(self, x):
        return self.enet.get_embedding(x)

# Loss, Device, Parameters

In [85]:
class TripletLoss(nn.Module):
    def __init__(self, margin):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative, size_average=True):
        distance_positive = torch.norm(anchor - positive, dim=1)
        distance_negative = torch.norm(anchor - negative, dim=1)
        losses = F.relu(distance_positive - distance_negative + self.margin)
        return losses.mean() if size_average else losses.sum()

In [86]:
model = TripletNet(tmodel)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.device_count() > 1:
    print("Using", torch.cuda.device_count(), "GPUs!")
    model = nn.DataParallel(TripletNet(tmodel))
else:
    model = TripletNet(tmodel)
# Move the model to the selected device (CPU or GPU)
model = model.to(device)

In [87]:
margin = 1
lr = 0.0001
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = StepLR(optimizer, step_size=3, gamma=0.5)  # Learning rate scheduler
loss_fn = TripletLoss(margin)
clip_value = 0.5  # You can adjust this value as needed

# Fit

In [88]:
def fit(model, num_epochs, train_loader, bs):
    for epoch in range(n_epochs):
        start = time.time()
        model.train()
        train_loss = 0.0

        for idx, batch in enumerate(train_loader):
            anchor, positive, negative = batch
            anchor = anchor.to(device)
            positive = positive.to(device)
            negative = negative.to(device)

            optimizer.zero_grad()
            anchor_embedding, positive_embedding, negative_embedding = model(anchor, positive, negative)
            anchor_embedding.requires_grad_(True)
            positive_embedding.requires_grad_(True)
            negative_embedding.requires_grad_(True)
            loss = loss_fn(anchor_embedding, positive_embedding, negative_embedding)
            loss.backward()
            clip_grad_norm_(model.parameters(), clip_value)
            optimizer.step()
            train_loss += loss.item()
            if idx%10==0:
                print(loss.item())
        print(f"Epoch {epoch + 1}/{n_epochs}, Train Loss: {train_loss / len(train_loader):.4f}, TIME: {time.time()-start}")
        scheduler.step()


fit(model, n_epochs:=int(input("NO OF EPOCHS : ")), train_loader, bs)


NO OF EPOCHS :  10


0.9419780969619751
0.5718819499015808
0.46392497420310974
0.435746431350708
0.3477965295314789
0.2732747197151184
0.41396546363830566
0.6759012937545776
0.2052626609802246
Epoch 1/10, Train Loss: 0.4697, TIME: 80.2966718673706
0.26862633228302
0.18844661116600037
0.36903461813926697
0.2968551516532898
0.35544896125793457
0.27230557799339294
0.44292712211608887
0.37228116393089294
0.19519975781440735
Epoch 2/10, Train Loss: 0.2761, TIME: 80.65482115745544
0.2640100419521332
0.24788737297058105
0.39497971534729004
0.32253706455230713
0.17366936802864075
0.2623654007911682
0.1968090832233429
0.22086220979690552
0.16596153378486633
Epoch 3/10, Train Loss: 0.2431, TIME: 80.38875031471252
0.07644326984882355
0.14680920541286469
0.17981955409049988
0.07786305993795395
0.0982128381729126
0.011171698570251465
0.0975198969244957
0.09838885068893433
0.042773425579071045
Epoch 4/10, Train Loss: 0.1247, TIME: 80.11810731887817
0.1397906392812729
0.17582854628562927
0.24987852573394775
0.04602783173

# Test

In [89]:
class CustomDataset(Dataset):
    def __init__(self, folder_path, transform=None):
        self.folder_path = folder_path
        self.transform = transform
        self.image_paths = []
        self.labels = []

        self._load_images()

    def _load_images(self):
        valid_extensions = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
        for class_name in os.listdir(self.folder_path):
            class_folder = os.path.join(self.folder_path, class_name)
            if os.path.isdir(class_folder):
                for filename in os.listdir(class_folder):
                    if filename.lower().endswith(valid_extensions):
                        self.image_paths.append(os.path.join(class_folder, filename))
                        self.labels.append(class_name)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        label = self.labels[idx]

        image = Image.open(image_path).convert('RGB')
        if self.transform:
            image = self.transform(image)

        return image, label

In [90]:
train_folder = "/kaggle/input/animals-insects-reptiles/Animals-Insects-Reptiles/Species-Train"
train_dataloader = DataLoader(CustomDataset(train_folder,transform=transform))
test_folder = "/kaggle/input/animals-insects-reptiles/Animals-Insects-Reptiles/Species-Test"
test_dataloader = DataLoader(CustomDataset(test_folder,transform=transform))

In [91]:
start = time.time()
train_embs = None
train_labels = []
for i in tqdm(train_dataloader):
    I, L = i
    train_labels.append(L)
    emb = tmodel(I.to(device)) # Assuming `model_loaded(I)` returns a PyTorch tensor
    emb = emb.detach()
    if train_embs is None:
        train_embs = emb
    else:
        train_embs = torch.cat((train_embs, emb), dim=0)
print(time.time()-start)

100%|██████████| 1300/1300 [00:24<00:00, 54.17it/s]

24.004308462142944





In [92]:
start = time.time()
test_embs = None
test_labels = []
for i in tqdm(test_dataloader):
    I, L = i
    try:
        emb = tmodel(I.to(device)) # Assuming `model_loaded(I)` returns a PyTorch tensor
        emb = emb.detach()
        if test_embs is None:
            test_embs = emb
        else:
            test_embs = torch.cat((test_embs, emb), dim=0)
        test_labels.append(L)
    except:
        print("ERROR")
print(time.time()-start)

100%|██████████| 200/200 [00:03<00:00, 66.37it/s]

3.020395517349243





In [93]:
embs_cpu_np = train_embs.cpu().numpy()
embs_cpu_np = embs_cpu_np.reshape(embs_cpu_np.shape[0], -1)
index = faiss.IndexHNSWFlat(embs_cpu_np.shape[1], 32)  # M = 32 for the HNSW index
index.add(embs_cpu_np)

In [95]:
def evaluate_with_faiss(embs, index):
    TOTAL = len(embs)
    CORRECT = 0
    start = time.time()
    
    # Initialize the tqdm progress bar
    with tqdm(total=TOTAL) as pbar:
        for idx, emb in enumerate(embs):
            label = index.search(emb.reshape(1, -1), 1)[1][0][0]
            if train_labels[label][0] == test_labels[idx][0]:
                CORRECT += 1
            pbar.update(1)  # Update the progress bar

    accuracy = (CORRECT / TOTAL) * 100
    elapsed_time = time.time() - start
    return f'Accuracy: {CORRECT}/{TOTAL} = {accuracy:.2f}%, Time: {elapsed_time:.2f} seconds'


In [97]:
embs2_cpu_np = test_embs.cpu().numpy()
embs2_cpu_np = embs2_cpu_np.reshape(embs2_cpu_np.shape[0], -1)

In [98]:
print(f'IndexHNSWFlat : {evaluate_with_faiss(embs2_cpu_np,index)}')

100%|██████████| 200/200 [00:08<00:00, 24.05it/s]

IndexHNSWFlat : Accuracy: 155/200 = 77.50%, Time: 8.32 seconds



