In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import faiss
from tqdm import tqdm
import matplotlib.pyplot as plt

: 

In [None]:
class TripletGeologyDataset(Dataset):
    """Serve (anchor, positive, negative) image triplets from a labelled dataset.

    For the item at ``index`` the anchor is ``dataset[index]``, the positive is
    a randomly chosen sample with the same label, and the negative is a
    randomly chosen sample from a randomly chosen different label.

    Args:
        dataset: An indexable dataset that returns ``(image, label)`` pairs and
            exposes a ``samples`` sequence whose entries hold the label at
            position 1.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        self.labels = np.array([s[1] for s in dataset.samples])
        self.label_to_indices = {label: np.where(self.labels == label)[0]
                                 for label in np.unique(self.labels)}
        # The set of "other" labels for a given anchor label never changes, so
        # build it once here instead of on every __getitem__ call.
        self._negative_labels = {
            label: [other for other in self.label_to_indices if other != label]
            for label in self.label_to_indices
        }

    def __getitem__(self, index):
        """Return ``((anchor, positive, negative), [])`` for the sample at ``index``.

        The trailing empty list is a placeholder target; the training loop
        ignores it.

        Raises:
            ValueError: If the dataset contains only one label, so no negative
                sample can be drawn.
        """
        anchor_img, anchor_label = self.dataset[index]

        # Draw the positive from the same class, excluding the anchor itself.
        # A class with a single sample has no other member; the previous
        # rejection loop never terminated in that case, so fall back to using
        # the anchor as its own positive.
        same_class = self.label_to_indices[anchor_label]
        positive_candidates = same_class[same_class != index]
        if positive_candidates.size > 0:
            positive_index = np.random.choice(positive_candidates)
        else:
            positive_index = index

        negative_labels = self._negative_labels[anchor_label]
        if not negative_labels:
            raise ValueError(
                "Cannot sample a negative: the dataset contains only one class."
            )
        negative_label = np.random.choice(negative_labels)
        negative_index = np.random.choice(self.label_to_indices[negative_label])

        positive_img, _ = self.dataset[positive_index]
        negative_img, _ = self.dataset[negative_index]
        return (anchor_img, positive_img, negative_img), []

    def __len__(self):
        """Return the number of anchors, i.e. the size of the wrapped dataset."""
        return len(self.dataset)

In [None]:
# Number of samples (or triplets) per mini-batch, shared by every loader.
batch_size = 256

# Per-image preprocessing: convert to a tensor, then normalise with a mean
# and standard deviation of 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# Labelled images read from the 'data' folder.
dataset = datasets.ImageFolder(root='data', transform=transform)

# Plain shuffled loader over single images.
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Triplet view of the same images, used for training.
triplet_dataset = TripletGeologyDataset(dataset)

# Shuffled triplet loader with worker processes and pinned host memory.
triplet_loader = DataLoader(
    triplet_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=16,
    pin_memory=True,
)

In [None]:
class EmbeddingNet(nn.Module):
    """Small convolutional network that maps an image to a unit-length embedding.

    Three Conv -> ReLU -> BatchNorm stages (3 -> 32 -> 64 -> 128 channels) are
    followed by pooling; the last stage pools globally to a single value per
    channel. A linear layer plus batch norm then projects to ``embedding_size``
    dimensions, and the result is L2-normalised.

    Args:
        embedding_size: Dimensionality of the output embedding.
    """

    def __init__(self, embedding_size=128):
        super().__init__()

        channel_plan = (3, 32, 64, 128)
        final_stage = len(channel_plan) - 2
        stages = []
        for stage, (c_in, c_out) in enumerate(zip(channel_plan, channel_plan[1:])):
            stages.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
            stages.append(nn.ReLU(inplace=True))
            stages.append(nn.BatchNorm2d(c_out))
            # Intermediate stages halve the spatial size; the final stage
            # collapses whatever is left to 1x1.
            if stage == final_stage:
                stages.append(nn.AdaptiveAvgPool2d(1))
            else:
                stages.append(nn.MaxPool2d(2))
        self.convnet = nn.Sequential(*stages)

        projection = nn.Linear(channel_plan[-1], embedding_size)
        self.fc = nn.Sequential(projection, nn.BatchNorm1d(embedding_size))

    def forward(self, x):
        """Return L2-normalised embeddings of shape ``(batch, embedding_size)``."""
        features = self.convnet(x)
        # Collapse (batch, 128, 1, 1) to (batch, 128)
        flat = features.view(features.size(0), -1)
        projected = self.fc(flat)
        return F.normalize(projected, p=2, dim=1)

In [None]:
# Triplet margin loss with margin 1.0 and p=2 distances.
# train_model reads this notebook-level name when computing the loss.
triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)

In [None]:
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

# Embedding network with the default embedding size
model = EmbeddingNet()
# Adam over all model parameters with a learning rate of 1e-3
optimizer = Adam(model.parameters(), lr=1e-3)
# Multiply the learning rate by 0.1 every 10 scheduler steps
# (train_model steps the scheduler once per epoch)
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)

In [None]:
def train_model(model, data_loader, optimizer, scheduler, num_epochs=20, device='cpu', criterion=None):
    """Train ``model`` on batches of triplets.

    Args:
        model: Network that maps a batch of images to a batch of embeddings.
        data_loader: Iterable yielding ``((anchor, positive, negative), _)``
            batches; the second element is ignored.
        optimizer: Optimizer that updates the model parameters.
        scheduler: Learning-rate scheduler, stepped once per epoch.
        num_epochs: Number of passes over ``data_loader``.
        device: Device the model and every batch are moved to.
        criterion: Callable ``(anchor_out, positive_out, negative_out) -> loss``.
            Defaults to the notebook-level ``triplet_loss``.

    Prints the mean batch loss after each epoch (``nan`` if the loader
    produced no batches).
    """
    if criterion is None:
        # Keep existing calls working: fall back to the notebook-level loss.
        criterion = triplet_loss

    model.to(device)
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        # Count batches instead of calling len(data_loader) so that empty or
        # length-less loaders do not break the epoch summary.
        num_batches = 0
        for (anchor, positive, negative), _ in tqdm(data_loader):
            optimizer.zero_grad()
            anchor = anchor.to(device)
            positive = positive.to(device)
            negative = negative.to(device)
            anchor_out = model(anchor)
            positive_out = model(positive)
            negative_out = model(negative)
            loss = criterion(anchor_out, positive_out, negative_out)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1
        scheduler.step()
        epoch_loss = running_loss / num_batches if num_batches else float('nan')
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type != 'cuda':
    print('CUDA is not available. Training on CPU!')
# Train the embedding model on the triplet loader for 20 epochs
train_model(model, triplet_loader, optimizer, scheduler, num_epochs=20, device=device)

In [None]:
# Switch the model to evaluation mode
model.eval()

In [None]:
# Unshuffled loader over the base dataset, so batches follow dataset order
# (retrieve_images later maps index positions back to dataset.samples)
embedding_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

In [None]:
def generate_embeddings(model, data_loader, device='cpu'):
    """Run ``model`` over every batch of ``data_loader`` and collect the outputs.

    The model is put in evaluation mode for the duration of the call and its
    previous train/eval mode is restored afterwards, so the result does not
    depend on an earlier cell having called ``model.eval()``.

    Args:
        model: Network mapping a batch of inputs to a batch of embeddings.
        data_loader: Iterable yielding ``(data, target)`` tensor batches.
        device: Device each input batch is moved to before the forward pass.

    Returns:
        Tuple ``(embeddings, labels)`` of NumPy arrays: the model outputs
        stacked row-wise in loader order, and the matching targets.
    """
    embeddings = []
    labels = []
    was_training = model.training
    # Evaluation mode: batch-norm layers use their running statistics and do
    # not update them while embedding.
    model.eval()
    try:
        with torch.no_grad():
            for data, target in tqdm(data_loader):
                data = data.to(device)
                output = model(data)
                embeddings.append(output.cpu().numpy())
                labels.extend(target.numpy())
    finally:
        # Leave the model in whichever mode the caller had it.
        model.train(was_training)
    embeddings = np.vstack(embeddings)
    labels = np.array(labels)
    return embeddings, labels

# Embed every image served by embedding_loader, keeping the labels alongside
embeddings, labels = generate_embeddings(model, embedding_loader, device=device)

In [None]:
# Save the embeddings and their labels to .npy files in the working directory
np.save('embeddings.npy', embeddings)
np.save('labels.npy', labels)

In [None]:
# Dimensionality of each embedding vector (second axis of the embeddings array)
d = embeddings.shape[1]

# Create a flat L2 index over d-dimensional vectors
index = faiss.IndexFlatL2(d)  # For exact search

# Add embeddings to the index
index.add(embeddings.astype('float32'))  # FAISS requires float32 arrays
print(f"Number of vectors in the index: {index.ntotal}")
# Write the index to disk; a later cell reloads it from this same path
faiss.write_index(index, 'geology_index.faiss')

In [None]:
# Reload the FAISS index from disk
index = faiss.read_index('geology_index.faiss')

def preprocess_image(image_path):
    """Load an image file and turn it into a normalised single-image batch.

    The image is always converted to RGB so that grayscale, palette, CMYK and
    RGBA files all reach the network with three channels. Converting only
    RGBA, as before, let other modes through with the wrong channel count.

    Args:
        image_path: Path of the image file to load.

    Returns:
        The transformed image tensor with a leading batch dimension of size 1.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    # The context manager closes the underlying file once the pixel data has
    # been converted into a new in-memory image.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')
    image = transform(image)
    return image.unsqueeze(0)  # Add batch dimension

# Path of the image used as the search query
searchImagePath = './data/rhyolite/ZXB0V.jpg'
# searchImagePath = './ben.png'

# Load the query image as a normalised tensor with a leading batch dimension
query_image = preprocess_image(searchImagePath)

# Make sure the model is in evaluation mode before embedding the query
model.eval()

# Embed the query without tracking gradients and convert it to a NumPy array
with torch.no_grad():
    query_embedding = model(query_image.to(device)).cpu().numpy()

In [None]:
def retrieve_images(indices, dataset):
    """Open the dataset images referenced by the first row of ``indices``.

    Args:
        indices: 2-D array of neighbour ids as returned by a FAISS search;
            only the first row (the first query) is used.
        dataset: Dataset exposing ``samples``, a sequence whose entries hold
            the image path at position 0.

    Returns:
        A list of opened PIL images, in the order given by ``indices[0]``.
        Negative ids are skipped: FAISS uses ``-1`` to pad results when fewer
        neighbours than requested exist, and indexing ``samples`` with ``-1``
        would silently return the last sample instead.
    """
    image_paths = [dataset.samples[i][0] for i in indices[0] if i >= 0]
    images = [Image.open(path) for path in image_paths]
    return images

In [None]:
k = 5  # Number of nearest neighbors
# Look up the k indexed vectors closest to the query embedding
distances, indices = index.search(query_embedding, k)


# Open the dataset images that correspond to the returned indices
similar_images = retrieve_images(indices, dataset)

In [None]:
def show_images(query_image_path, similar_images):
    """Plot the query image followed by its retrieved neighbours in one row.

    Args:
        query_image_path: Path of the query image; it is opened here.
        similar_images: Sequence of already-opened images to show after it.
    """
    query_image = Image.open(query_image_path)

    # One (image, caption) pair per subplot, query first.
    panels = [(query_image, 'Query Image')]
    for rank, neighbour in enumerate(similar_images, start=1):
        panels.append((neighbour, f'Similar Image {rank}'))

    n_columns = len(similar_images) + 1
    plt.figure(figsize=(15, 5))
    for position, (picture, caption) in enumerate(panels, start=1):
        plt.subplot(1, n_columns, position)
        plt.imshow(picture, cmap='gray')
        plt.title(caption)
        plt.axis('off')
    plt.show()

# Display the query image next to its retrieved neighbours
show_images(searchImagePath, similar_images)

In [None]:
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

# Reduce the embeddings to 3 dimensions with PCA for plotting.
# random_state is fixed so the projection is the same on every run, even if
# scikit-learn selects a randomised SVD solver for this array size.
#
# t-SNE is a possible alternative projection, e.g.
#   TSNE(n_components=3, perplexity=30, random_state=42).fit_transform(embeddings)
pca = PCA(n_components=3, random_state=42)
embeddings_3d = pca.fit_transform(embeddings)
from mpl_toolkits.mplot3d import Axes3D

# Pick a tab10 colour for every point from its label, scaled by the number
# of distinct labels.
unique_labels = set(labels)
colors = plt.cm.tab10(labels / len(unique_labels))

# Static 3-D scatter of the projected embeddings, one axis per component.
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')
ax.scatter(*embeddings_3d.T, c=colors, s=20, alpha=0.8)

ax.set(
    title="3D Representation of Embeddings",
    xlabel="Dimension 1",
    ylabel="Dimension 2",
    zlabel="Dimension 3",
)

plt.show()

import plotly.express as px
import pandas as pd

# Create a DataFrame for Plotly.
# Labels are class ids, not quantities: store them as strings so Plotly
# assigns one discrete colour per class (with a legend entry each) instead of
# mapping the integers onto a continuous colour scale.
df = pd.DataFrame({
    'x': embeddings_3d[:, 0],
    'y': embeddings_3d[:, 1],
    'z': embeddings_3d[:, 2],
    'label': np.asarray(labels).astype(str)
})

# Interactive 3D scatter, coloured by class label
fig = px.scatter_3d(df, x='x', y='y', z='z', color='label',
                    title="3D Representation of Embeddings")
fig.show()