# ⭐ DOCUMENT DETECTION - TRIPLET LOSS

# ![https://a-z-animals.com/media/2021/10/sumatran-tiger-panthera-tigris-sumatrae-cub-standing-on-rock-picture-id1254523938.jpg](https://a-z-animals.com/media/2021/10/sumatran-tiger-panthera-tigris-sumatrae-cub-standing-on-rock-picture-id1254523938.jpg)

# IMPORT REQUIRED LIBRARIES 🌧

In [58]:
import torch
from torch.utils.data import Dataset,DataLoader
import torch.nn as nn
from torch.optim import Adam,lr_scheduler
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torch.nn.parallel import DataParallel
import torchvision.models as models
from PIL import Image
import os, pickle, random, time, cv2, faiss

# PREPROCESSING

In [59]:
transform = transforms.Compose([
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# CUSTOM DATASET

In [60]:
class CustomDataset(Dataset):
    def __init__(self, folder_path, transform=None):
        self.folder_path = folder_path
        self.transform = transform
        self.image_paths = []
        self.labels = []

        self._load_images()

    def _load_images(self):
        valid_extensions = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
        for class_name in os.listdir(self.folder_path):
            class_folder = os.path.join(self.folder_path, class_name)
            if os.path.isdir(class_folder):
                for filename in os.listdir(class_folder):
                    if filename.lower().endswith(valid_extensions):
                        self.image_paths.append(os.path.join(class_folder, filename))
                        self.labels.append(class_name)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        label = self.labels[idx]

        image = Image.open(image_path).convert('RGB')
        if self.transform:
            image = self.transform(image)

        return image, label

# CONCATENATING DATASETS

In [62]:
class ConcatenatedDataset(Dataset):
    def __init__(self, dataset_list, transform=None):
        self.dataset_list = dataset_list
        self.transform = transform

    def __len__(self):
        return sum(len(dataset) for dataset in self.dataset_list)

    def __getitem__(self, idx):
        for dataset in self.dataset_list:
            if idx < len(dataset):
                return dataset[idx]
            idx -= len(dataset)
    
    def _transform_image(self, image):
        if self.transform:
            return self.transform(image)
        return image

    def _load_image(self, image_path):
        image = Image.open(image_path).convert("RGB")
        return self._transform_image(image)


# TRAIN AND TEST SET

In [64]:
train_folders = [
                 "./Datasets/TOBACCO/train/",
                 "./Datasets/ARABIC_DOCS/train/",
                 "./Datasets/DOCS/train/",
                 "./Datasets/LANG_DOCS/train/",
                 "./Datasets/MLIMGS/train/"
]

test_folders = [
                 "./Datasets/TOBACCO/test/",
                 "./Datasets/ARABIC_DOCS/test/",
                 "./Datasets/DOCS/test/",
                 "./Datasets/LANG_DOCS/test/",
                 "./Datasets/MLIMGS/test/"
]
train_datasets = [CustomDataset(folder, transform=transform) for folder in train_folders]
test_datasets = [CustomDataset(folder, transform=transform) for folder in test_folders]

In [65]:
combined_train_dataset = ConcatenatedDataset(train_datasets)
combined_test_dataset = ConcatenatedDataset(test_datasets)

# TRIPLETIZATION

In [67]:
class TripletDataset(Dataset):
    def __init__(self, base_dataset):
        self.base_dataset = base_dataset
        self.num_samples = len(self.base_dataset)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        anchor, anchor_label = self.base_dataset[idx]
        positive_idx = idx
        while positive_idx == idx:
            positive_idx = random.randint(0, self.num_samples - 1)
        positive, positive_label = self.base_dataset[positive_idx]

        negative_idx = idx
        while negative_idx == idx or self.base_dataset[negative_idx][1] == anchor_label:
            negative_idx = random.randint(0, self.num_samples - 1)
        negative, negative_label = self.base_dataset[negative_idx]

        return anchor, positive, negative

train_set = ConcatenatedDataset(train_datasets)
test_set = ConcatenatedDataset(test_datasets)

train_set = TripletDataset(train_set)
test_set = TripletDataset(test_set)

In [68]:
len(train_set), len(test_set)

(31549, 3460)

# LOADERS

In [70]:
train_loader = DataLoader(train_set, batch_size=bs, shuffle=True)
test_loader = DataLoader(test_set, batch_size=bs, shuffle=False)

# THE MODEL ✨

In [71]:
class EmbeddingNet(nn.Module):
    def __init__(self):
        super(EmbeddingNet, self).__init__()
        resnet50 = models.resnet50(pretrained=True)
        self.convnet = nn.Sequential(*list(resnet50.children())[:-1])  # Remove the fully connected layer

    def forward(self, x):
        output = self.convnet(x)
        return output

    def get_embedding(self, x):
        return self.forward(x)

In [72]:
emb = EmbeddingNet()

# TRIPLET WRAPPER ☘️

In [73]:
class TripletNet(nn.Module):
    def __init__(self, embedding_net):
        super(TripletNet, self).__init__()
        self.embedding_net = embedding_net

    def forward(self, x1, x2=None, x3=None):
        if x2 is None and x3 is None:
            return self.embedding_net(x1)
        return self.embedding_net(x1),self.embedding_net(x2),self.embedding_net(x3)

    def get_embedding(self, x):
        return self.embedding_net(x)

# TRIPLET LOSS 📈

In [74]:
class TripletLoss(nn.Module):
    def __init__(self, margin):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative, size_average=True):
        distance_positive = torch.norm(anchor - positive, dim=1)
        distance_negative = torch.norm(anchor - negative, dim=1)
        losses = F.relu(distance_positive - distance_negative + self.margin)
        return losses.mean() if size_average else losses.sum()


# SET THE STUFFS UP  🚀

In [75]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TripletNet(emb)
model = nn.DataParallel(model)
model = model.to(device)
margin = 1
lr = 0.0001
#n_epochs = int(input("NO OF EPOCHS : "))
n_epochs = 2
optimizer = Adam(model.parameters(), lr=lr)
loss_fn = TripletLoss(margin)

In [76]:
model_path = './Models/samplervl50.pth'
torch.save(model.state_dict(), model_path)
state_dict = torch.load(model_path)
model.load_state_dict(state_dict)

<All keys matched successfully>

# EVALUATING

In [52]:
def evaluate_model(model, triplet_test_loader,for_log=False,LIMIT=None):
    model.eval()
    correct = 0
    total = 0
    start = time.time()
    with torch.no_grad():
        for idx,(anchor, positive, negative) in enumerate(triplet_test_loader):
            if for_log and idx==LIMIT:
                return f'ACCURACY: {correct/total*100}% ,TIME: {time.time()-start}'
            anchor_embedding, positive_embedding, negative_embedding = model(anchor.to(device),
                                                                             positive.to(device),
                                                                             negative.to(device))
            distance_positive = torch.norm(anchor_embedding - positive_embedding, dim=1).to(device)
            distance_negative = torch.norm(anchor_embedding - negative_embedding, dim=1).to(device)
            correct += torch.sum(distance_positive < distance_negative).item()
            total += anchor.size(0)
    accuracy = correct / total
    print(accuracy*100,time.time()-start)

# TRAIN 🦜

In [53]:
def fit(model, num_epochs, train_loader,bs):
    for epoch in range(num_epochs):
        start = time.time()
        model.train()
        train_loss = 0.0
        for idx, batch in enumerate(train_loader):
            anchor, positive, negative = batch
            anchor = anchor.to(device)
            positive = positive.to(device)
            negative = negative.to(device)
            optimizer.zero_grad()
            anchor_embedding, positive_embedding, negative_embedding = model(anchor, positive, negative)
            anchor_embedding.requires_grad_(True)
            positive_embedding.requires_grad_(True)
            negative_embedding.requires_grad_(True)
            loss = loss_fn(anchor_embedding, positive_embedding, negative_embedding)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            print(f"({idx + 1}).  LOSS : {loss.item()}  SEEN : {bs * (idx + 1)}/{len(train_loader.dataset)}")
        print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss / len(train_loader):.4f}, TIME: {time.time()-start}")
        #print('VALIDATION :')
        #evaluate_model(model, valid_loader)
        print('TESTING :')
        evaluate_model(model, test_loader)

In [None]:
fit(model,n_epochs,train_loader,bs)

In [97]:
torch.save(model,"./Models/PANTHER.pt")
print("MODEL SAVED!")

MODEL SAVED!


# WITH FAISS

In [92]:
train_folders = [
                 "./Datasets/TOBACCO/train/",
                 "./Datasets/ARABIC_DOCS/train/",
                 "./Datasets/DOCS/train/",
                 "./Datasets/LANG_DOCS/train/",
                 "./Datasets/MLIMGS/train/"
]

test_folders = [
                 "./Datasets/TOBACCO/test/",
                 "./Datasets/ARABIC_DOCS/test/",
                 "./Datasets/DOCS/test/",
                 "./Datasets/LANG_DOCS/test/",
                 "./Datasets/MLIMGS/test/"
]
train_datasets = [CustomDataset(folder, transform=transform) for folder in train_folders]
test_datasets = [CustomDataset(folder, transform=transform) for folder in test_folders]

In [93]:
train_set = ConcatenatedDataset(train_datasets)
test_set = ConcatenatedDataset(test_datasets)
train_dataloader = DataLoader(train_set, batch_size=bs, shuffle=True)
test_dataloader = DataLoader(test_set, batch_size=bs, shuffle=False)

# EMBEDDINGS

In [96]:
train_embs = None
train_labels = []
for idx,i in enumerate(train_dataloader):
    if idx%100==0:
        print(idx)
    I, L = i
    train_labels.append(L)
    emb = model(I) # Assuming `model_loaded(I)` returns a PyTorch tensor
    emb = emb.detach()
    if train_embs is None:
        train_embs = emb
    else:
        train_embs = torch.cat((train_embs, emb), dim=0)

0
100


KeyboardInterrupt: 

In [105]:
test_embs = None
test_labels = []
for idx,i in enumerate(test_dataloader):
    if idx%100==0:
        print(idx)
    I, L = i
    test_labels.append(L)
    emb = model(I) # Assuming `model_loaded(I)` returns a PyTorch tensor
    emb = emb.detach()
    if test_embs is None:
        test_embs = emb
    else:
        test_embs = torch.cat((test_embs, emb), dim=0)

0
100
200
300


# DIFFERENT INDICES

In [106]:
embs_cpu_np = train_embs.cpu().numpy()
embs_cpu_np = embs_cpu_np.reshape(embs_cpu_np.shape[0], -1)

index1 = faiss.IndexFlatL2(embs_cpu_np.shape[1])  # Assuming embs_cpu_np.shape[1] represents the dimensionality of the embeddings
index1.add(embs_cpu_np)

nlist = 100  # Number of cells/buckets

quantizer = faiss.IndexFlatL2(embs_cpu_np.shape[1])  # Quantizer index (same as IndexFlatL2)
index2 = faiss.IndexIVFFlat(quantizer, embs_cpu_np.shape[1], nlist)
index2.train(embs_cpu_np)
index2.add(embs_cpu_np)

index3 = faiss.IndexHNSWFlat(embs_cpu_np.shape[1], 128)  # M = 32 for the HNSW index
index3.add(embs_cpu_np)

nbits = 8  # Number of bits for the LSH hash
index4 = faiss.IndexLSH(embs_cpu_np.shape[1], nbits)
index4.add(embs_cpu_np)




In [108]:
def evaluatewithfaiss(embs,index):
    TOTAL = len(embs)
    CORRECT = 0
    start = time.time()
    for idx,emb in enumerate(embs):
        label = index.search(emb.reshape(1,-1),1)[1][0][0]
        if train_labels[label][0]==test_labels[idx][0]:
            CORRECT += 1
    return f'{CORRECT}/{TOTAL}={(CORRECT/TOTAL)*100}',f'TIME = {time.time()-start} SECONDS'

In [109]:
embs2_cpu_np = test_embs.cpu().numpy()
embs2_cpu_np = embs2_cpu_np.reshape(embs2_cpu_np.shape[0], -1)

In [110]:
print(f'IndexIVFFlat : {evaluatewithfaiss(embs2_cpu_np,index2)}')
print(f'IndexHNSWFlat : {evaluatewithfaiss(embs2_cpu_np,index3)}')
print(f'IndexLSH : {evaluatewithfaiss(embs2_cpu_np,index4)}')
print(f'IndexFlatL2 : {evaluatewithfaiss(embs2_cpu_np,index1)}')

IndexIVFFlat : ('359/398=90.20100502512562', 'TIME = 0.07941412925720215 SECONDS')
IndexHNSWFlat : ('358/398=89.9497487437186', 'TIME = 0.25783753395080566 SECONDS')
IndexLSH : ('217/398=54.52261306532663', 'TIME = 0.012828350067138672 SECONDS')
IndexFlatL2 : ('358/398=89.9497487437186', 'TIME = 1.2119579315185547 SECONDS')
