In [1]:
import math
import logging
from functools import partial
from collections import OrderedDict
from einops import rearrange, repeat

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.helpers import load_pretrained
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model

import os
from torch.utils.data import Dataset
import random
import numpy as np

from torch.utils.data import DataLoader
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.manifold import TSNE

import matplotlib
matplotlib.use('Agg')  # Use the Agg backend
import matplotlib.pyplot as plt
from matplotlib import cm
import matplotlib.cm as cm
from matplotlib.colors import ListedColormap, LinearSegmentedColormap

import seaborn as sns
from scipy.stats import gaussian_kde
from scipy.optimize import brentq
from scipy.stats import gaussian_kde
from scipy.ndimage import gaussian_filter1d
from matplotlib.lines import Line2D
from matplotlib.patches import Patch


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

class GestureTransformer(nn.Module):
    def __init__(self, num_frame=101, num_joints=21, in_chans=3, embed_dim_ratio=32, depth=4,
                 num_heads=8, mlp_ratio=2., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.2,  norm_layer=None, num_classes = 0):
        """    ##########hybrid_backbone=None, representation_size=None,
        Args:
            num_frame (int, tuple): input frame number
            num_joints (int, tuple): joints number
            in_chans (int): number of input channels, 2D joints have 2 channels: (x,y)
            embed_dim_ratio (int): embedding dimension ratio
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            norm_layer: (nn.Module): normalization layer
        """
        super().__init__()

        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        embed_dim = embed_dim_ratio * num_joints   #### temporal embed_dim is num_joints * spatial embedding dim ratio
        out_dim = num_classes


        ### spatial patch embedding
        self.Spatial_patch_to_embedding = nn.Linear(in_chans, embed_dim_ratio)
        self.Spatial_pos_embed = nn.Parameter(torch.zeros(1, num_joints, embed_dim_ratio))

        self.Temporal_pos_embed = nn.Parameter(torch.zeros(1, num_frame, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)


        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        self.Spatial_blocks = nn.ModuleList([
            Block(
                dim=embed_dim_ratio, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])

        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])

        self.Spatial_norm = norm_layer(embed_dim_ratio)
        self.Temporal_norm = norm_layer(embed_dim)

        ####### A easy way to implement weighted mean
        self.weighted_mean = torch.nn.Conv1d(in_channels=num_frame, out_channels=1, kernel_size=1)

        self.head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim , out_dim),
        )


    def Spatial_forward_features(self, x):
        b, _, f, p = x.shape  ##### b is batch size, f is number of frames, p is number of joints
        x = rearrange(x, 'b c f p  -> (b f) p  c', )
        x = self.Spatial_patch_to_embedding(x)
        x += self.Spatial_pos_embed
        x = self.pos_drop(x)

        for blk in self.Spatial_blocks:
            x = blk(x)

        x = self.Spatial_norm(x)
        x = rearrange(x, '(b f) w c -> b f (w c)', f=f)
        return x

    def forward_features(self, x):
        b  = x.shape[0]
        x += self.Temporal_pos_embed
        x = self.pos_drop(x)
        for blk in self.blocks:
            x = blk(x)

        x = self.Temporal_norm(x)
        ##### x size [b, f, emb_dim], then take weighted mean on frame dimension, we only predict 3D pose of the center frame
        x = self.weighted_mean(x)
        x = x.view(b, 1, -1)
        return x

    def forward(self, x):
        bs, t, w, k, d = x.shape
        x = x.view(bs * t, w, k, d)
        x = x.permute(0, 3, 1, 2)
        bs_temp, _, _, p = x.shape
        x = self.Spatial_forward_features(x)
        x = self.forward_features(x)
        x = self.head(x)
        out_dim = 45
        x = x.view(bs, 3, out_dim)

        return x



In [2]:
class MyCustomDataset(Dataset):
    def __init__(self, root_dir, mode='train', sliding_window_size=101):
        self.root_dir = root_dir
        self.subjects = os.listdir(root_dir)
        self.frame_data = []  # Stores the information of each frame
        self.labels = [] # Stores the labels of each frame
        self.indexvalues = []
        self.sliding_window_size = sliding_window_size
        self.class_frame_dict={}
        self._process_data(mode)

    def __len__(self):
        return len(self.frame_data)


    def __getitem__(self, index):
        center_frame = self.frame_data[index]
        anchor_label = self.labels[index]
        anchor_skeleton_number = center_frame['skeleton_number']

        # Positive sample: Random frame from the same gesture class as anchor
        positive_frames = self.class_frame_dict[anchor_label]
        positive_frame = random.choice(positive_frames)
        positive_index_data = self.frame_data[positive_frame['frame_number']]
        positive_frame_number = positive_index_data['frame_number']
        positive_label = self.labels[positive_frame_number]
        positive_skeleton_number = positive_index_data['skeleton_number']

        # Negative sample: Random frame from a different gesture class
        # Select a random gesture class that is not the same as the anchor class
        negative_label = random.choice([label for label in self.class_frame_dict.keys() if label != anchor_label])
        negative_frames = self.class_frame_dict[negative_label]
        negative_frame = random.choice(negative_frames)
        negative_index_data = self.frame_data[negative_frame['frame_number']]
        negative_frame_number = negative_index_data['frame_number']
        negative_label = self.labels[negative_frame_number]
        negative_skeleton_number = negative_index_data['skeleton_number']

        skeleton_start_index = self.indexvalues[anchor_skeleton_number]['start_index']
        skeleton_end_index = self.indexvalues[anchor_skeleton_number]['end_index']

        positive_start_index = self.indexvalues[positive_skeleton_number]['start_index']
        positive_end_index = self.indexvalues[positive_skeleton_number]['end_index']


        negative_start_index = self.indexvalues[negative_skeleton_number]['start_index']
        negative_end_index = self.indexvalues[negative_skeleton_number]['end_index']


        start_index = index - self.sliding_window_size // 2
        end_index = index + self.sliding_window_size // 2


        pos_start_index = positive_frame_number - self.sliding_window_size // 2
        pos_end_index = positive_frame_number + self.sliding_window_size // 2


        neg_start_index = negative_frame_number - self.sliding_window_size // 2
        neg_end_index = negative_frame_number + self.sliding_window_size // 2

        if (start_index < skeleton_start_index):
          start_index = skeleton_start_index

        else:
          start_index = max(start_index, 0)

        if (end_index > skeleton_end_index):
          end_index = skeleton_end_index


        if (pos_start_index < positive_start_index):
          pos_start_index = positive_start_index

        else:
          pos_start_index = max(pos_start_index, 0)

        if (pos_end_index > positive_end_index):
          pos_end_index = positive_end_index


        if (neg_start_index < negative_start_index):
          neg_start_index = negative_start_index

        else:
          neg_start_index = max(neg_start_index, 0)

        if (neg_end_index > negative_end_index):
          neg_end_index = negative_end_index

        anchor = []

        for i in range(start_index, end_index + 1):

          frame_data = self.frame_data[i]
          joint_coordinates = frame_data['joint_coordinates']
          anchor.append(joint_coordinates)

        while len(anchor) < self.sliding_window_size:
          anchor.append([0.0] * 63)  # Assuming 21 joints with (x, y, z) coordinates


        anchor = torch.tensor(anchor, dtype=torch.float32)
        anchor = anchor.view(self.sliding_window_size, 21, 3)

        positive = []

        for i in range(pos_start_index, pos_end_index + 1):

          frame_data = self.frame_data[i]
          joint_coordinates = frame_data['joint_coordinates']
          positive.append(joint_coordinates)

        while len(positive) < self.sliding_window_size:
          positive.append([0.0] * 63)  # Assuming 21 joints with (x, y, z) coordinates


        positive = torch.tensor(positive, dtype=torch.float32)
        positive = positive.view(self.sliding_window_size, 21, 3)

        negative = []

        for i in range(neg_start_index, neg_end_index + 1):

          frame_data = self.frame_data[i]
          joint_coordinates = frame_data['joint_coordinates']
          negative.append(joint_coordinates)

        while len(negative) < self.sliding_window_size:
          negative.append([0.0] * 63)  # Assuming 21 joints with (x, y, z) coordinates


        negative = torch.tensor(negative, dtype=torch.float32)
        negative = negative.view(self.sliding_window_size, 21, 3)

        # Create a tensor with three dimensions: anchor, positive, and negative samples
        sample = torch.stack([anchor, positive, negative])

        return sample, anchor_label



    def _process_data(self, mode):
        frame_counter = 0  # Counter for sequential frame numbering
        skeleton_temp=0
        indices = []
        class_frames_dict = {}  # Dictionary to store frames by gesture class

        for subject in self.subjects:
            subject_dir = os.path.join(self.root_dir, subject)

            if mode == 'train':
                data_dir = os.path.join(subject_dir, 'TrainingData')
            elif mode == 'test':
                data_dir = os.path.join(subject_dir, 'TestingData')
            else:
                raise ValueError("Invalid mode. Mode should be 'train' or 'test'.")

            action_folders = os.listdir(subject_dir)

            for action_folder in action_folders:
                action_dir = os.path.join(subject_dir, action_folder)
                number_folders = os.listdir(action_dir)

                for number_folder in number_folders:
                    number_dir = os.path.join(action_dir, number_folder)
                    skeleton_file = os.path.join(number_dir, "skeleton.txt")

                    # Read and process the skeleton.txt file to extract frame-level data
                    frames, labels , start , end= self._read_skeleton_file(skeleton_file, frame_counter, action_folder , skeleton_temp)

                    indices.append({
                    'start_index': start,
                    'end_index': end,

                    })
                    # Extend the frame_data list with the frame-level data
                    self.frame_data.extend(frames)
                    self.labels.extend(labels)

                    # Update the frame counter based on the number of frames in the current skeleton file
                    frame_counter += len(frames)
                    skeleton_temp = skeleton_temp + 1

        self.indexvalues.extend(indices)

        # Store frames in class-specific lists
        for label, frame in zip(self.labels, self.frame_data):
          if label not in class_frames_dict:
            class_frames_dict[label] = []
          class_frames_dict[label].append(frame)
        self.class_frame_dict.update(class_frames_dict)


    def _read_skeleton_file(self, skeleton_file, frame_counter, action , skeleton_temp):
        frames = []
        labels = []
        start_ind = frame_counter
        temp_index = 0

        # Read and parse the skeleton.txt file to extract frame-level data
        with open(skeleton_file, 'r') as f:
            for line in f:
                # Parse the line and extract frame number, joint coordinates, and label
                frame_info = line.strip().split(' ')
                frame_number = frame_counter + int(frame_info[0])
                joint_coordinates = [float(coord) for coord in frame_info[1:]]

                label = action
                end_ind = frame_number

                # Append the frame data and label to the respective lists
                frames.append({
                    'frame_number': frame_number,
                    'joint_coordinates': joint_coordinates,
                    'skeleton_number' : skeleton_temp,

                })
                labels.append(label)

                temp_index = frame_number

        return frames, labels, start_ind, temp_index



class TripletLossCosine(nn.Module):
    def __init__(self, margin=1.0):
        super(TripletLossCosine, self).__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        # Cosine similarity
        similarity_positive = nn.functional.cosine_similarity(anchor, positive)
        similarity_negative = nn.functional.cosine_similarity(anchor, negative)

        # Calculate triplet loss
        triplet_loss = torch.relu(similarity_negative - similarity_positive + self.margin)

        return torch.mean(triplet_loss)


class TripletLoss(nn.Module):
    def __init__(self, margin=1.0):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):

        distance_positive = torch.sqrt(torch.sum((anchor - positive) ** 2, dim=1))
        distance_negative = torch.sqrt(torch.sum((anchor - negative) ** 2, dim=1))
        losses = torch.relu(distance_positive - distance_negative + self.margin)

        triplet_loss = torch.mean(losses)


        return triplet_loss



def compute_distance(embeddings1, embeddings2, distance_metric='euclidean'):
    if distance_metric == 'euclidean':
        distances = torch.norm(embeddings1 - embeddings2, p=2, dim=2)

    elif distance_metric == 'cosine':
        similarity = torch.nn.functional.cosine_similarity(embeddings1, embeddings2, dim=2)
        distances = 1 - similarity  # Convert similarity to distance
    else:
        raise ValueError("Unsupported distance metric. Choose 'euclidean' or 'cosine'.")

    return distances



def calculate_accuracy(anchor_positive_distances, anchor_negative_distances, threshold):
    # Assuming anchor-positive distances are similar pairs and anchor-negative distances are dissimilar pairs
    similar_pairs = anchor_positive_distances <= threshold
    dissimilar_pairs = anchor_negative_distances > threshold

    # True positives: Similar pairs correctly classified as similar
    true_positives = np.sum(similar_pairs)

    # True negatives: Dissimilar pairs correctly classified as dissimilar
    true_negatives = np.sum(dissimilar_pairs)

    # False positives: Dissimilar pairs incorrectly classified as similar
    false_positives = np.sum(anchor_negative_distances <= threshold)

    # False negatives: Similar pairs incorrectly classified as dissimilar
    false_negatives = np.sum(anchor_positive_distances > threshold)

    # Calculate accuracy
    accuracy = (true_positives + true_negatives) / (true_positives + true_negatives + false_positives + false_negatives)

    return accuracy

In [None]:
# Set the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


train_root_dir = '/kaggle/input/fphadataset/FPHADataset/TrainingData'
test_root_dir = '/kaggle/input/fphadataset/FPHADataset/TestingData'


# Create an instance of the training dataset
train_dataset = MyCustomDataset(train_root_dir, mode='train')

# Create an instance of the testing dataset
test_dataset = MyCustomDataset(test_root_dir, mode='test')

# Set the batch size
batch_size = 32

# Create data loaders
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)


train_dataset_length = len(train_dataset)

test_dataset_length = len(test_dataset)

# Set the number of output classes
num_classes = 45 # Replace with the embedding dimension value
emb_dim = num_classes
# Set the hyperparameters
#Sliding window size
num_frame =101
num_joints = 21
in_chans = 3
# Embedding Dimension
embed_dim_ratio = 64
depth = 4
num_heads = 8
mlp_ratio = 2.
qkv_bias = True
qk_scale = None
drop_rate = 0.
attn_drop_rate = 0.
drop_path_rate = 0.2

# Create an instance of the PoseTransformer model
model = GestureTransformer(num_frame=num_frame, num_joints=num_joints, in_chans=in_chans, embed_dim_ratio=embed_dim_ratio,
                        depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                        drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate,
                        num_classes =num_classes)

# Move the model to the device
model = model.to(device)

#Set the optimizer and learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Set the number of training epochs
num_epochs = 10
#print(num_epochs)
#print(num_classes)
unique_labels = set(train_dataset.labels)

# Create a label encoder
label_encoder = LabelEncoder()

# Fit label encoder on the training labels
label_encoder.fit(train_dataset.labels)

# Convert training labels to numerical values
train_labels_encoded = label_encoder.transform(train_dataset.labels)

# Convert the encoded labels to tensors
train_labels_tensor = torch.tensor(train_labels_encoded, dtype=torch.long).to(device)

# Set the threshold variable
threshold = None

# Create an instance of the TripletLoss
compute_triplet_loss = TripletLoss()


# Training loop
for epoch in range(num_epochs):
    #print("epoch")
    #print(epoch)
    model.train()  # Set the model to training mode
    total_triplet_loss = 0.0

    # Initialize an empty list to store anchor embeddings
    all_anchor_embeddings_tsne = []
    all_pos_embeddings_tsne = []
    all_neg_embeddings_tsne = []

    for triplet_sample, labels in train_dataloader:  # Load triplets

        triplet_sample = triplet_sample.to(device)

        labels_encoded = label_encoder.transform(labels)  # Convert current batch labels to numerical values
        labels_tensor = torch.tensor(labels_encoded, dtype=torch.long).to(device)

        # Forward pass for the entire triplet sample
        triplet_embeddings = model(triplet_sample)  # Shape: (bs, 3, f)


        anchor_embeddings = triplet_embeddings[:, 0, :]  # Select the anchor embeddings
        positive_embeddings = triplet_embeddings[:, 1, :]  # Select the positive embeddings
        negative_embeddings = triplet_embeddings[:, 2, :]  # Select the negative embeddings

        # Compute triplet loss
        triplet_loss = compute_triplet_loss(anchor_embeddings, positive_embeddings, negative_embeddings)

        optimizer.zero_grad()
        triplet_loss.backward()
        optimizer.step()

        total_triplet_loss += triplet_loss.item()

        anchor_embeddings_tsne = anchor_embeddings.view(-1, 1, emb_dim)
        pos_embeddings_tsne = positive_embeddings.view(-1, 1, emb_dim)
        neg_embeddings_tsne = negative_embeddings.view(-1, 1, emb_dim)

        # Append anchor_embeddings to the list

        all_anchor_embeddings_tsne.append(anchor_embeddings_tsne)
        all_pos_embeddings_tsne.append(pos_embeddings_tsne)
        all_neg_embeddings_tsne.append(neg_embeddings_tsne)


    average_triplet_loss = total_triplet_loss / len(train_dataloader)
    print(f'Epoch [{epoch + 1}/{num_epochs}], Average Triplet Loss: {average_triplet_loss:.4f}')


all_anchor_embeddings_tsne = torch.cat(all_anchor_embeddings_tsne, dim=0)
all_pos_embeddings_tsne = torch.cat(all_pos_embeddings_tsne, dim=0)
all_neg_embeddings_tsne = torch.cat(all_neg_embeddings_tsne, dim=0)

# Save the anchor embeddings to a file in the working directory
torch.save(all_anchor_embeddings_tsne, 'anchor_embeddings.pth')
torch.save(all_pos_embeddings_tsne, 'positive_embeddings.pth')
torch.save(all_neg_embeddings_tsne, 'negative_embeddings.pth')

# Save the trained model to the working directory
torch.save(model.state_dict(), 'gesture_comparison_pose_transformer_model.pth')

pos_distances = compute_distance(all_anchor_embeddings_tsne, all_pos_embeddings_tsne)

neg_distances = compute_distance(all_anchor_embeddings_tsne, all_neg_embeddings_tsne)


cos_pos_distances = compute_distance(all_anchor_embeddings_tsne, all_pos_embeddings_tsne, distance_metric='cosine')

cos_neg_distances = compute_distance(all_anchor_embeddings_tsne, all_neg_embeddings_tsne, distance_metric='cosine')


minimum_distance = min(torch.min(pos_distances), torch.min(neg_distances)).item()
maximum_distance = max(torch.max(pos_distances), torch.max(neg_distances)).item()

cos_minimum_distance = min(torch.min(cos_pos_distances), torch.min(cos_neg_distances)).item()
cos_maximum_distance = max(torch.max(cos_pos_distances), torch.max(cos_neg_distances)).item()


pos_distances_np = pos_distances.detach().cpu().numpy()

neg_distances_np = neg_distances.detach().cpu().numpy()

cos_pos_distances_np = cos_pos_distances.detach().cpu().numpy()

cos_neg_distances_np = cos_neg_distances.detach().cpu().numpy()

sns.set_style("whitegrid")  # Set plot style

# Calculate the minimum and maximum values across both datasets
min_distance = min(np.min(pos_distances_np), np.min(neg_distances_np))
max_distance = max(np.max(pos_distances_np), np.max(neg_distances_np))

# Create histograms with the same bin width for both datasets
plt.figure(figsize=(8, 6))
pos_counts, pos_bins, _ = plt.hist(pos_distances_np, bins=100, range=(min_distance, max_distance), alpha=0.5, color="blue", label="Anchor-Positive Distances")
neg_counts, neg_bins, _ = plt.hist(neg_distances_np, bins=100, range=(min_distance, max_distance), alpha=0.5, color="red", label="Anchor-Negative Distances")

plt.xlabel("Distance")
plt.ylabel("Number of Samples")
plt.title("Histogram of Anchor-Positive and Anchor-Negative Euclidean Distances ")

# Find the intersection point (threshold)
def find_intersection1(x):
    pos_indices = np.searchsorted(pos_bins, x, side='right') - 1
    neg_indices = np.searchsorted(neg_bins, x, side='right') - 1
    pos_interpolated = np.interp(x, pos_bins[:-1], pos_counts)
    neg_interpolated = np.interp(x, neg_bins[:-1], neg_counts)
    return pos_interpolated - neg_interpolated

minimum_distance = max(min(pos_distances_np), min(neg_distances_np))
maximum_distance = min(max(pos_distances_np), max(neg_distances_np))


# Calculate the intersection point within the defined range
euclidean_intersection_point = brentq(find_intersection1, minimum_distance, maximum_distance)

#print(f"Euclidean Intersection Point (Threshold): {euclidean_intersection_point:.2f}")

plt.axvline(x=euclidean_intersection_point, color='green', linestyle='--', label=f"Threshold: {euclidean_intersection_point:.2f}")
plt.legend()

# Save the histogram plot as an image in the working directory
plt.savefig('euclidean_threshold_plot.png', dpi=300)  # Save as PNG with high resolution
plt.close()  # Close the plot to release memory



# Calculate the minimum and maximum values across both datasets
min_distance = min(np.min(cos_pos_distances_np), np.min(cos_neg_distances_np))
max_distance = max(np.max(cos_pos_distances_np), np.max(cos_neg_distances_np))

# Create histograms with the same bin width for both datasets
plt.figure(figsize=(8, 6))
pos_counts, pos_bins, _ = plt.hist(cos_pos_distances_np, bins=100, range=(min_distance, max_distance), alpha=0.5, color="blue", label="Anchor-Positive Distances")
neg_counts, neg_bins, _ = plt.hist(cos_neg_distances_np, bins=100, range=(min_distance, max_distance), alpha=0.5, color="red", label="Anchor-Negative Distances")

plt.xlabel("Distance")
plt.ylabel("Number of Samples")
plt.title("Histogram of Anchor-Positive and Anchor-Negative Cosine Distances ")

# Find the intersection point (threshold)
def find_intersection1(x):
    pos_indices = np.searchsorted(pos_bins, x, side='right') - 1
    neg_indices = np.searchsorted(neg_bins, x, side='right') - 1
    pos_interpolated = np.interp(x, pos_bins[:-1], pos_counts)
    neg_interpolated = np.interp(x, neg_bins[:-1], neg_counts)
    return pos_interpolated - neg_interpolated

minimum_distance = max(min(cos_pos_distances_np), min(cos_neg_distances_np))
maximum_distance = min(max(cos_pos_distances_np), max(cos_neg_distances_np))


#Calculate the intersection point within the defined range
cosine_intersection_point = brentq(find_intersection1, minimum_distance, maximum_distance)

#print(f"Cosine Intersection Point (Threshold): {cosine_intersection_point:.2f}")

plt.axvline(x=cosine_intersection_point, color='green', linestyle='--', label=f"Threshold: {cosine_intersection_point:.2f}")
plt.legend()

# Save the histogram plot as an image in the working directory
plt.savefig('cosine_threshold_plot.png', dpi=300)  # Save as PNG with high resolution
plt.close()  # Close the plot to release memory

threshold = euclidean_intersection_point

# Set the model to evaluation mode
model.eval()

# Initialize lists to store predictions and targets
test_predictions = []
test_targets = []

# Initialize an empty list to store anchor embeddings
all_anchor_embeddings = []
all_pos_embeddings = []
all_neg_embeddings = []


# Testing loop
with torch.no_grad():
    for triplet_sample, labels in test_dataloader:  # Load triplets and labels
        triplet_sample = triplet_sample.to(device)

        labels_encoded = label_encoder.transform(labels)  # Convert current batch labels to numerical values
        labels_tensor = torch.tensor(labels_encoded, dtype=torch.long).to(device)

        # Forward pass for the entire triplet sample
        triplet_embeddings = model(triplet_sample)  # Shape: (bs, 3, f)

        # Split the embeddings into anchor, positive, and negative parts
        anchor_embeddings = triplet_embeddings[:, 0, :]  # Select the anchor embeddings
        positive_embeddings = triplet_embeddings[:, 1, :]  # Select the positive embeddings
        negative_embeddings = triplet_embeddings[:, 2, :]  # Select the negative embeddings

        anchor_embeddings = anchor_embeddings.view(-1, 1, emb_dim)
        positive_embeddings = positive_embeddings.view(-1, 1, emb_dim)
        negative_embeddings = negative_embeddings.view(-1, 1, emb_dim)

        # Append anchor_embeddings to the list
        all_anchor_embeddings.append(anchor_embeddings)
        all_pos_embeddings.append(positive_embeddings)
        all_neg_embeddings.append(negative_embeddings)

        test_targets.append(labels_tensor)


# Concatenate all anchor embeddings into a single tensor
all_anchor_embeddings_tsne = torch.cat(all_anchor_embeddings, dim=0)
anchor_embeddings = all_anchor_embeddings_tsne
all_pos_embeddings_tsne = torch.cat(all_pos_embeddings, dim=0)
all_neg_embeddings_tsne = torch.cat(all_neg_embeddings, dim=0)



pos_distances = compute_distance(all_anchor_embeddings_tsne, all_pos_embeddings_tsne)

neg_distances = compute_distance(all_anchor_embeddings_tsne, all_neg_embeddings_tsne)


cos_pos_distances = compute_distance(all_anchor_embeddings_tsne, all_pos_embeddings_tsne, distance_metric='cosine')

cos_neg_distances = compute_distance(all_anchor_embeddings_tsne, all_neg_embeddings_tsne, distance_metric='cosine')



pos_distances_np = pos_distances.detach().cpu().numpy()

neg_distances_np = neg_distances.detach().cpu().numpy()

cos_pos_distances_np = cos_pos_distances.detach().cpu().numpy()

cos_neg_distances_np = cos_neg_distances.detach().cpu().numpy()



# Calculate accuracy
accuracy = calculate_accuracy(pos_distances_np, neg_distances_np, threshold)

accuracy_cos = calculate_accuracy(cos_pos_distances_np, cos_neg_distances_np, threshold)


print(f'Accuracy: {accuracy * 100:.2f}%')


targets = torch.cat(test_targets, dim=0)


custom_colors = ['#a26989', '#924ff1', '#6029e8', '#f0215b', '#8da6ae', '#fc8cd9', '#5cc8ff', '#d2cd29', '#fc85a8', '#a55e73', '#89b38f', '#3c2b60', '#27ba26', '#5b274d', '#84421a', '#c142c0', '#97bf08', '#fe978e', '#c3dd15', '#5b7ff6', '#fd979c', '#13a66c', '#ba2b12', '#963409', '#d155d4', '#86ad4d', '#da12ba', '#0273dd', '#998723', '#d7d618', '#213eb7', '#d291c8', '#86690b', '#caf200', '#c546e2', '#0ab739', '#8b0624', '#90acc7', '#944934', '#565835', '#50efbb', '#905ee0', '#f3a7ea', '#bbd9f9', '#a452d9']


# Convert anchor_embeddings to numpy array
anchor_embeddings_np = anchor_embeddings.cpu().numpy()  # Assuming anchor_embeddings is a PyTorch tensor

#Flatten the anchor embeddings to a 2D array
flattened_anchor_embeddings = anchor_embeddings.view(anchor_embeddings.size(0), -1).cpu().numpy()


# Get the labels for the tsne embeddings
tsne_labels = targets.cpu().numpy()

perplexity_values = [5, 10, 20, 30, 50, 100]
iteration_values = [250, 500, 1000, 2000]

plt.figure(figsize=(15, 15))
plot_num = 1
num= 1

cmap = ListedColormap(custom_colors[:len(np.unique(tsne_labels))])

for perplexity in perplexity_values:
    for iterations in iteration_values:
        tsne = TSNE(n_components=2, perplexity=perplexity, n_iter=iterations, random_state=42)
        tsne_embeddings = tsne.fit_transform(flattened_anchor_embeddings)
        plt.subplot(len(perplexity_values), len(iteration_values), plot_num)
        plt.scatter(tsne_embeddings[:, 0], tsne_embeddings[:, 1], c=tsne_labels, cmap=cmap , s=0.5)
        plt.title(f'Perplexity: {perplexity}, Iterations: {iterations}')
        plt.show()
        plt.xticks([])
        plt.yticks([])
        plt.savefig('tsne_plot' + str(num) + '.png', dpi=300)
        plot_num += 1
        num = num+1


2
45
epoch
0
Epoch [1/2], Average Triplet Loss: 0.5145
epoch
1
Epoch [2/2], Average Triplet Loss: 0.4237
Euclidean Intersection Point (Threshold): 2.08
Cosine Intersection Point (Threshold): 0.01
Accuracy: 71.13%
Accuracy cos: 50.00%
