# 0. Install the dependencies

In [1]:
!pip install open3d
!pip install learning3d

Collecting open3d
  Downloading open3d-0.18.0-cp310-cp310-manylinux_2_27_x86_64.whl.metadata (4.2 kB)
Collecting dash>=2.6.0 (from open3d)
  Downloading dash-2.17.0-py3-none-any.whl.metadata (10 kB)
Collecting configargparse (from open3d)
  Downloading ConfigArgParse-1.7-py3-none-any.whl.metadata (23 kB)
Collecting ipywidgets>=8.0.4 (from open3d)
  Downloading ipywidgets-8.1.2-py3-none-any.whl.metadata (2.4 kB)
Collecting addict (from open3d)
  Downloading addict-2.4.0-py3-none-any.whl.metadata (1.0 kB)
Collecting pyquaternion (from open3d)
  Downloading pyquaternion-0.9.9-py3-none-any.whl.metadata (1.4 kB)
Collecting dash-html-components==2.0.0 (from dash>=2.6.0->open3d)
  Downloading dash_html_components-2.0.0-py3-none-any.whl.metadata (3.8 kB)
Collecting dash-core-components==2.0.0 (from dash>=2.6.0->open3d)
  Downloading dash_core_components-2.0.0-py3-none-any.whl.metadata (2.9 kB)
Collecting dash-table==5.0.0 (from dash>=2.6.0->open3d)
  Downloading dash_table-5.0

In [2]:
!python --version

Python 3.10.13


In [3]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
import open3d as o3d
import torch
from learning3d.models import PointNet, DGCNN, PPFNet, DCP
from learning3d.models import Classifier

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [4]:
device = o3d.core.Device("CUDA:1")
dtype = o3d.core.float32

# 1. Load and preprocess data

In [5]:
"""%%time
mesh = o3d.io.read_triangle_mesh("/kaggle/input/2d-3d-retrieval-dataset/wb_2D3Dretrieval_dataset/database/0.stl")
pcd = mesh.sample_points_uniformly(1024)
point_arr = np.asarray(pcd.points)
print(f"Point cloud arr shape {point_arr.shape}, max: {np.max(point_arr)}, min: {np.min(point_arr)}")"""

'%%time\nmesh = o3d.io.read_triangle_mesh("/kaggle/input/2d-3d-retrieval-dataset/wb_2D3Dretrieval_dataset/database/0.stl")\npcd = mesh.sample_points_uniformly(1024)\npoint_arr = np.asarray(pcd.points)\nprint(f"Point cloud arr shape {point_arr.shape}, max: {np.max(point_arr)}, min: {np.min(point_arr)}")'

In [6]:
df = pd.read_csv('/kaggle/input/2d-3d-retrieval-dataset/wb_2D3Dretrieval_dataset/labels.csv')

In [7]:
df.head()

Unnamed: 0,query,label
0,0.png,"0.stl,17.stl,70.stl,110.stl,130.stl"
1,1.png,"1.stl,13.stl,19.stl,223.stl,119.stl"
2,10.png,"10.stl,155.stl,237.stl,75.stl,15.stl"
3,100.png,"100.stl,72.stl,216.stl,127.stl,55.stl"
4,101.png,"101.stl,134.stl,248.stl,31.stl,199.stl"


In [8]:
df['label'] = df['label'].apply(lambda x: x.split(','))

In [9]:
# split the df into train and test
from sklearn.model_selection import train_test_split
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)

In [10]:
train_df.info()

<class 'pandas.core.frame.DataFrame'>
Index: 201 entries, 136 to 102
Data columns (total 2 columns):
 #   Column  Non-Null Count  Dtype 
---  ------  --------------  ----- 
 0   query   201 non-null    object
 1   label   201 non-null    object
dtypes: object(2)
memory usage: 4.7+ KB


In [11]:
test_df.info()

<class 'pandas.core.frame.DataFrame'>
Index: 51 entries, 165 to 38
Data columns (total 2 columns):
 #   Column  Non-Null Count  Dtype 
---  ------  --------------  ----- 
 0   query   51 non-null     object
 1   label   51 non-null     object
dtypes: object(2)
memory usage: 1.2+ KB


In [12]:
%%time

# Load all the database into the RAM, in form of a dictionary

db_names = os.listdir('/kaggle/input/2d-3d-retrieval-dataset/wb_2D3Dretrieval_dataset/database')
#db_list = ['/kaggle/input/2d-3d-retrieval-dataset/wb_2D3Dretrieval_dataset/database/' + path for path in db_names]
database_arr = {}

for path in db_names:
    mesh = o3d.io.read_triangle_mesh('/kaggle/input/2d-3d-retrieval-dataset/wb_2D3Dretrieval_dataset/database/' + path)
    pcd = mesh.sample_points_uniformly(1024)
    point_arr = np.asarray(pcd.points)
    database_arr[path] = point_arr

CPU times: user 1min 15s, sys: 11.9 s, total: 1min 27s
Wall time: 1min 37s


In [13]:
import pandas as pd
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms

import albumentations as A
from albumentations.pytorch import ToTensorV2

from timm.data import create_transform
import timm

In [14]:
class MCDataset(Dataset):
    def __init__(self, data, transform : transforms.Compose = None):
        # Get a transform pipeline for image.    
        if transform: 
             self.transform = transform 
        else: 
            self.transform = A.Compose([ 
                A.Resize(224, 224), 
                A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 
                ToTensorV2()
                ], 
                    keypoint_params=A.KeypointParams(format='xy', remove_invisible=False))  
        self.data = data

    def __getitem__(self, index):
        root = '/kaggle/input/2d-3d-retrieval-dataset/wb_2D3Dretrieval_dataset/'
        # Anchor image
        img_path = root + 'queries/' + self.data.iloc[index].query
        img = Image.open(img_path).convert("RGB")
        img_tensor = self.transform(img)
        # Positive 3d model
        label_id = self.data.iloc[index].label[0]

        # Negative 3d model (randomly select from the rest of the dataset)
        neg_index = np.random.choice(np.delete(np.arange(len(self.data)), index))
        neg_label_id = self.data.iloc[neg_index].label[0]
        
        return img_tensor, database_arr[label_id], database_arr[neg_label_id]

    def __len__(self):
        return len(self.data)

In [15]:
train_transform = create_transform(224, is_training=True)
test_transform = create_transform(224)

train_set = MCDataset(train_df, train_transform)
test_set = MCDataset(test_df, test_transform)

BATCH_SIZE = 32
train_dataloader = DataLoader(train_set, batch_size=BATCH_SIZE)
test_dataloader = DataLoader(test_set, batch_size=BATCH_SIZE)

In [16]:
for img_tensor, label, neg in train_dataloader:
    print(img_tensor.shape, label.shape, neg.shape)
    break
for img_tensor, label, neg in test_dataloader:
    print(img_tensor.shape, label.shape, neg.shape)
    break

torch.Size([32, 3, 224, 224]) torch.Size([32, 1024, 3]) torch.Size([32, 1024, 3])
torch.Size([32, 3, 224, 224]) torch.Size([32, 1024, 3]) torch.Size([32, 1024, 3])


# 2. Build Model

In [17]:
class MCClassifier(nn.Module):

    def __init__(self, pc_model, image_model, num_pc_features, num_img_features, dropout=0.5):

        super(MCClassifier, self).__init__()

        self.img_model = image_model
        self.pc_model = pc_model
        self.dropout = nn.Dropout(dropout)
        self.batch_norm = nn.BatchNorm1d(512)  # Add batch normalization layer
        self.relu = nn.ReLU()

        # Image MLP
        self.lni1 = nn.Linear(num_img_features, 512)
        self.lni2 = nn.Linear(512,256)
        #self.img_ln = nn.Linear(1280, 1024)
        
        # Point Cloud MLP
        self.lnp1 = nn.Linear(1024, 512)
        self.lnp2 = nn.Linear(512,256)
     
        
        # For scaling the point cloud features masp
        self.conv1d = nn.Conv1d(1024, 1, kernel_size=1)
    def forward(self, image, pc):
        # 2D Image
        image_features = self.img_model(image)
        img_emb = self.dropout(image_features)
        img_emb = self.lni1(img_emb)
        img_emb = self.batch_norm(img_emb)
        img_emb = self.relu(img_emb)
        #img_emb = self.lni2(img_emb)
        #img_emb = self.relu(img_emb)
        
        # Point Cloud
        pc_features = self.pc_model(pc)
        pc_features = self.conv1d(pc_features).squeeze(1)
        pc_emb = self.dropout(pc_features)
        pc_emb = self.lnp1(pc_emb)
        pc_emb = self.batch_norm(pc_emb)  # Apply batch normalization
        pc_emb = self.relu(pc_emb)
        #pc_emb = self.linear2(out)
        #pc_emb = self.relu(out)
        

        return img_emb, pc_emb

In [18]:
"""dgcnn = DGCNN(emb_dims=1024, input_shape='bnc')
pnt_cls = Classifier(feature_model=dgcnn)"""

"dgcnn = DGCNN(emb_dims=1024, input_shape='bnc')\npnt_cls = Classifier(feature_model=dgcnn)"

In [19]:
pn = PointNet(emb_dims=1024, use_bn=True)
pnt_cls = Classifier(feature_model=pn)
pt_path = '/kaggle/input/3d-cls-pretrained-model/pytorch/modelnet40-pt/1/best_model.t7'
pnt_cls.load_state_dict(torch.load(pt_path, map_location='cpu'))
pc_model = pnt_cls.feature_model
pc_model.to('cuda')

PointNet(
  (conv1): Conv1d(3, 64, kernel_size=(1,), stride=(1,))
  (conv2): Conv1d(64, 64, kernel_size=(1,), stride=(1,))
  (conv3): Conv1d(64, 64, kernel_size=(1,), stride=(1,))
  (conv4): Conv1d(64, 128, kernel_size=(1,), stride=(1,))
  (conv5): Conv1d(128, 1024, kernel_size=(1,), stride=(1,))
  (relu): ReLU()
  (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn3): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn4): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn5): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

In [20]:
img_model = timm.create_model('mobilenetv2_100', pretrained=True, num_classes = 0).to('cuda')#.get_classifier()
img_model.to('cuda')
print("CU")

model.safetensors:   0%|          | 0.00/14.2M [00:00<?, ?B/s]

CU


In [21]:
model = MCClassifier(pc_model, img_model, 1024 ,img_model.num_features).to('cuda')

In [22]:
from torch.optim import Adam
from tqdm import tqdm
loss = nn.TripletMarginLoss(margin=1.0)
optimizer = Adam(model.parameters(), lr= 1e-4)

In [23]:
"""pnt_tensor = torch.from_numpy(point_arr).unsqueeze(0).float().to('cuda')
pnt_tensor.shape

test_transform = create_transform(224)

img = Image.open('/kaggle/input/2d-3d-retrieval-dataset/wb_2D3Dretrieval_dataset/queries/0.png')
img_tensor = test_transform(img)
img_tensor.shape

model.eval()
print("CI")

for data in train_dataloader:
    model(anchor_img.to('cuda').float(), pos_char.to('cuda').float())
    
    break"""

'pnt_tensor = torch.from_numpy(point_arr).unsqueeze(0).float().to(\'cuda\')\npnt_tensor.shape\n\ntest_transform = create_transform(224)\n\nimg = Image.open(\'/kaggle/input/2d-3d-retrieval-dataset/wb_2D3Dretrieval_dataset/queries/0.png\')\nimg_tensor = test_transform(img)\nimg_tensor.shape\n\nmodel.eval()\nprint("CI")\n\nfor data in train_dataloader:\n    model(anchor_img.to(\'cuda\').float(), pos_char.to(\'cuda\').float())\n    \n    break'

# 3. Training 

In [24]:
def train(model, train_data, val_data, learning_rate, epochs, batch_size):

    train_transform = create_transform(224, is_training=True)
    test_transform = create_transform(224)
    train_set = MCDataset(train_df, train_transform)
    test_set = MCDataset(test_df, test_transform)
    BATCH_SIZE = batch_size
    train_dataloader = DataLoader(train_set, batch_size=BATCH_SIZE)
    val_dataloader = DataLoader(test_set, batch_size=BATCH_SIZE)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    criterion = nn.TripletMarginLoss(margin=1.0)
    optimizer = Adam(model.parameters(), lr= learning_rate)
    if use_cuda:
        model = model.cuda()
        criterion = criterion.cuda()
    
    for epoch_num in range(epochs):

        train_loss = []
        mrr_train = []
        #model.train()
        for data in tqdm(train_dataloader):
            anchor_img, pos_char, neg_char = data
            anchor_feat, pos_feat = model(anchor_img.to('cuda').float(), pos_char.to('cuda').float())
            _, neg_feat = model(anchor_img.to('cuda').float(), neg_char.to('cuda').float())

            # Calculate triplet loss
            batch_loss = criterion(anchor_feat, pos_feat, neg_feat)
            train_loss.append(float(batch_loss))

            # Backpropagation and update weights
            model.zero_grad()
            batch_loss.backward()
            optimizer.step()
            
        avg_loss = np.mean(train_loss)
        print("epoch %2d train end : avg_loss = %.4f" % (epoch_num+1, avg_loss))
        #print("epoch %2d train end : avg_acc = %.4f" % (epoch_num+1, avg_train_acc))

        test_loss = []
        test_mrr = []
        #model.eval()
        with torch.no_grad():

            for data in tqdm(val_dataloader):

                anchor_img, pos_char, neg_char = data
                anchor_feat, pos_feat = model(anchor_img.to('cuda').float(), pos_char.to('cuda').float())
                _, neg_feat = model(anchor_img.to('cuda').float(), neg_char.to('cuda').float())

                batch_loss = criterion(anchor_feat, pos_feat, neg_feat)
                test_loss.append(float(batch_loss))

                #batch_acc = (output.round() == val_label).float().mean()
                #test_acc.append(float(batch_acc))

            avg_test_loss = np.mean(test_loss)
            #mrr_test = np.mean(test_acc)
            print("epoch %2d test end : avg_loss = %.4f" % (epoch_num+1, avg_test_loss))
            #print("epoch %2d test end : avg_acc = %.4f" % (epoch_num+1, avg_test_acc))
            
    torch.save(model, 'retrival_net.ckpt')


In [25]:
batch_size = 16
EPOCHS = 100
LR = 1e-3
# SAVE TRAINING LOGS.
train(model, train_df, test_df, LR, EPOCHS, batch_size)

100%|██████████| 13/13 [00:04<00:00,  2.63it/s]


epoch  1 train end : avg_loss = 1.2226


100%|██████████| 4/4 [00:00<00:00,  5.19it/s]


epoch  1 test end : avg_loss = 1.2687


100%|██████████| 13/13 [00:03<00:00,  4.19it/s]


epoch  2 train end : avg_loss = 1.1634


100%|██████████| 4/4 [00:00<00:00,  6.18it/s]


epoch  2 test end : avg_loss = 1.4078


100%|██████████| 13/13 [00:03<00:00,  4.19it/s]


epoch  3 train end : avg_loss = 1.0362


100%|██████████| 4/4 [00:00<00:00,  6.19it/s]


epoch  3 test end : avg_loss = 1.0949


100%|██████████| 13/13 [00:03<00:00,  4.26it/s]


epoch  4 train end : avg_loss = 1.0654


100%|██████████| 4/4 [00:00<00:00,  6.16it/s]


epoch  4 test end : avg_loss = 0.9823


100%|██████████| 13/13 [00:03<00:00,  4.18it/s]


epoch  5 train end : avg_loss = 1.0794


100%|██████████| 4/4 [00:00<00:00,  5.99it/s]


epoch  5 test end : avg_loss = 1.0718


100%|██████████| 13/13 [00:03<00:00,  4.23it/s]


epoch  6 train end : avg_loss = 1.0874


100%|██████████| 4/4 [00:00<00:00,  6.18it/s]


epoch  6 test end : avg_loss = 1.1070


100%|██████████| 13/13 [00:03<00:00,  4.09it/s]


epoch  7 train end : avg_loss = 1.0357


100%|██████████| 4/4 [00:00<00:00,  6.19it/s]


epoch  7 test end : avg_loss = 1.1051


100%|██████████| 13/13 [00:03<00:00,  4.23it/s]


epoch  8 train end : avg_loss = 1.0230


100%|██████████| 4/4 [00:00<00:00,  6.13it/s]


epoch  8 test end : avg_loss = 1.2766


100%|██████████| 13/13 [00:03<00:00,  4.17it/s]


epoch  9 train end : avg_loss = 0.9756


100%|██████████| 4/4 [00:00<00:00,  6.08it/s]


epoch  9 test end : avg_loss = 1.0777


100%|██████████| 13/13 [00:03<00:00,  4.17it/s]


epoch 10 train end : avg_loss = 0.9158


100%|██████████| 4/4 [00:00<00:00,  6.18it/s]


epoch 10 test end : avg_loss = 1.0347


100%|██████████| 13/13 [00:03<00:00,  4.09it/s]


epoch 11 train end : avg_loss = 1.0388


100%|██████████| 4/4 [00:00<00:00,  5.87it/s]


epoch 11 test end : avg_loss = 1.1930


100%|██████████| 13/13 [00:03<00:00,  4.20it/s]


epoch 12 train end : avg_loss = 0.9522


100%|██████████| 4/4 [00:00<00:00,  5.72it/s]


epoch 12 test end : avg_loss = 1.3454


100%|██████████| 13/13 [00:03<00:00,  4.06it/s]


epoch 13 train end : avg_loss = 0.9522


100%|██████████| 4/4 [00:00<00:00,  6.10it/s]


epoch 13 test end : avg_loss = 1.2332


100%|██████████| 13/13 [00:03<00:00,  4.15it/s]


epoch 14 train end : avg_loss = 0.8593


100%|██████████| 4/4 [00:00<00:00,  6.08it/s]


epoch 14 test end : avg_loss = 1.0078


100%|██████████| 13/13 [00:03<00:00,  4.10it/s]


epoch 15 train end : avg_loss = 0.8727


100%|██████████| 4/4 [00:00<00:00,  6.07it/s]


epoch 15 test end : avg_loss = 1.1315


100%|██████████| 13/13 [00:03<00:00,  4.21it/s]


epoch 16 train end : avg_loss = 0.8939


100%|██████████| 4/4 [00:00<00:00,  5.88it/s]


epoch 16 test end : avg_loss = 1.1483


100%|██████████| 13/13 [00:03<00:00,  4.00it/s]


epoch 17 train end : avg_loss = 0.8603


100%|██████████| 4/4 [00:00<00:00,  6.09it/s]


epoch 17 test end : avg_loss = 1.3898


100%|██████████| 13/13 [00:03<00:00,  4.15it/s]


epoch 18 train end : avg_loss = 0.8587


100%|██████████| 4/4 [00:00<00:00,  6.16it/s]


epoch 18 test end : avg_loss = 1.0335


100%|██████████| 13/13 [00:03<00:00,  4.18it/s]


epoch 19 train end : avg_loss = 0.8403


100%|██████████| 4/4 [00:00<00:00,  6.12it/s]


epoch 19 test end : avg_loss = 0.9101


100%|██████████| 13/13 [00:03<00:00,  4.15it/s]


epoch 20 train end : avg_loss = 0.7121


100%|██████████| 4/4 [00:00<00:00,  6.18it/s]


epoch 20 test end : avg_loss = 0.9232


100%|██████████| 13/13 [00:03<00:00,  4.11it/s]


epoch 21 train end : avg_loss = 0.9332


100%|██████████| 4/4 [00:00<00:00,  6.05it/s]


epoch 21 test end : avg_loss = 1.0149


100%|██████████| 13/13 [00:03<00:00,  4.21it/s]


epoch 22 train end : avg_loss = 0.8430


100%|██████████| 4/4 [00:00<00:00,  6.17it/s]


epoch 22 test end : avg_loss = 0.8285


100%|██████████| 13/13 [00:03<00:00,  4.18it/s]


epoch 23 train end : avg_loss = 0.8027


100%|██████████| 4/4 [00:00<00:00,  5.44it/s]


epoch 23 test end : avg_loss = 1.1782


100%|██████████| 13/13 [00:03<00:00,  4.13it/s]


epoch 24 train end : avg_loss = 0.7824


100%|██████████| 4/4 [00:00<00:00,  5.80it/s]


epoch 24 test end : avg_loss = 0.7951


100%|██████████| 13/13 [00:03<00:00,  3.99it/s]


epoch 25 train end : avg_loss = 0.7412


100%|██████████| 4/4 [00:00<00:00,  5.90it/s]


epoch 25 test end : avg_loss = 1.1582


100%|██████████| 13/13 [00:03<00:00,  4.04it/s]


epoch 26 train end : avg_loss = 0.6316


100%|██████████| 4/4 [00:00<00:00,  6.02it/s]


epoch 26 test end : avg_loss = 0.8456


100%|██████████| 13/13 [00:03<00:00,  4.08it/s]


epoch 27 train end : avg_loss = 0.7230


100%|██████████| 4/4 [00:00<00:00,  6.22it/s]


epoch 27 test end : avg_loss = 1.0988


100%|██████████| 13/13 [00:03<00:00,  4.24it/s]


epoch 28 train end : avg_loss = 0.8070


100%|██████████| 4/4 [00:00<00:00,  6.14it/s]


epoch 28 test end : avg_loss = 1.2493


100%|██████████| 13/13 [00:03<00:00,  4.27it/s]


epoch 29 train end : avg_loss = 0.6393


100%|██████████| 4/4 [00:00<00:00,  6.18it/s]


epoch 29 test end : avg_loss = 1.0730


100%|██████████| 13/13 [00:03<00:00,  4.21it/s]


epoch 30 train end : avg_loss = 0.6059


100%|██████████| 4/4 [00:00<00:00,  6.11it/s]


epoch 30 test end : avg_loss = 0.9054


100%|██████████| 13/13 [00:03<00:00,  4.26it/s]


epoch 31 train end : avg_loss = 0.6566


100%|██████████| 4/4 [00:00<00:00,  6.08it/s]


epoch 31 test end : avg_loss = 1.1254


100%|██████████| 13/13 [00:03<00:00,  4.04it/s]


epoch 32 train end : avg_loss = 0.5442


100%|██████████| 4/4 [00:00<00:00,  6.14it/s]


epoch 32 test end : avg_loss = 0.8338


100%|██████████| 13/13 [00:03<00:00,  4.23it/s]


epoch 33 train end : avg_loss = 0.5404


100%|██████████| 4/4 [00:00<00:00,  6.02it/s]


epoch 33 test end : avg_loss = 0.9101


100%|██████████| 13/13 [00:03<00:00,  4.15it/s]


epoch 34 train end : avg_loss = 0.4922


100%|██████████| 4/4 [00:00<00:00,  6.08it/s]


epoch 34 test end : avg_loss = 0.7864


100%|██████████| 13/13 [00:03<00:00,  4.12it/s]


epoch 35 train end : avg_loss = 0.5640


100%|██████████| 4/4 [00:00<00:00,  6.09it/s]


epoch 35 test end : avg_loss = 1.0579


100%|██████████| 13/13 [00:03<00:00,  4.14it/s]


epoch 36 train end : avg_loss = 0.4045


100%|██████████| 4/4 [00:00<00:00,  5.97it/s]


epoch 36 test end : avg_loss = 0.8905


100%|██████████| 13/13 [00:03<00:00,  4.14it/s]


epoch 37 train end : avg_loss = 0.5191


100%|██████████| 4/4 [00:00<00:00,  6.04it/s]


epoch 37 test end : avg_loss = 0.9691


100%|██████████| 13/13 [00:03<00:00,  4.21it/s]


epoch 38 train end : avg_loss = 0.4959


100%|██████████| 4/4 [00:00<00:00,  6.00it/s]


epoch 38 test end : avg_loss = 1.3037


100%|██████████| 13/13 [00:03<00:00,  4.22it/s]


epoch 39 train end : avg_loss = 0.4837


100%|██████████| 4/4 [00:00<00:00,  6.07it/s]


epoch 39 test end : avg_loss = 1.2149


100%|██████████| 13/13 [00:03<00:00,  3.94it/s]


epoch 40 train end : avg_loss = 0.3946


100%|██████████| 4/4 [00:00<00:00,  6.09it/s]


epoch 40 test end : avg_loss = 0.9653


100%|██████████| 13/13 [00:03<00:00,  4.21it/s]


epoch 41 train end : avg_loss = 0.4262


100%|██████████| 4/4 [00:00<00:00,  6.14it/s]


epoch 41 test end : avg_loss = 0.9306


100%|██████████| 13/13 [00:03<00:00,  4.22it/s]


epoch 42 train end : avg_loss = 0.3831


100%|██████████| 4/4 [00:00<00:00,  6.20it/s]


epoch 42 test end : avg_loss = 0.9248


100%|██████████| 13/13 [00:03<00:00,  4.16it/s]


epoch 43 train end : avg_loss = 0.4851


100%|██████████| 4/4 [00:00<00:00,  6.18it/s]


epoch 43 test end : avg_loss = 0.9755


100%|██████████| 13/13 [00:03<00:00,  4.21it/s]


epoch 44 train end : avg_loss = 0.3980


100%|██████████| 4/4 [00:00<00:00,  6.27it/s]


epoch 44 test end : avg_loss = 1.1538


100%|██████████| 13/13 [00:03<00:00,  4.25it/s]


epoch 45 train end : avg_loss = 0.3875


100%|██████████| 4/4 [00:00<00:00,  5.96it/s]


epoch 45 test end : avg_loss = 1.1447


100%|██████████| 13/13 [00:03<00:00,  4.17it/s]


epoch 46 train end : avg_loss = 0.4003


100%|██████████| 4/4 [00:00<00:00,  6.06it/s]


epoch 46 test end : avg_loss = 0.8709


100%|██████████| 13/13 [00:03<00:00,  4.15it/s]


epoch 47 train end : avg_loss = 0.3232


100%|██████████| 4/4 [00:00<00:00,  6.05it/s]


epoch 47 test end : avg_loss = 0.8283


100%|██████████| 13/13 [00:03<00:00,  3.86it/s]


epoch 48 train end : avg_loss = 0.3145


100%|██████████| 4/4 [00:00<00:00,  6.25it/s]


epoch 48 test end : avg_loss = 0.7538


100%|██████████| 13/13 [00:03<00:00,  4.23it/s]


epoch 49 train end : avg_loss = 0.2659


100%|██████████| 4/4 [00:00<00:00,  6.13it/s]


epoch 49 test end : avg_loss = 0.9829


100%|██████████| 13/13 [00:03<00:00,  4.21it/s]


epoch 50 train end : avg_loss = 0.3311


100%|██████████| 4/4 [00:00<00:00,  6.09it/s]


epoch 50 test end : avg_loss = 1.0830


100%|██████████| 13/13 [00:03<00:00,  4.16it/s]


epoch 51 train end : avg_loss = 0.2902


100%|██████████| 4/4 [00:00<00:00,  5.93it/s]


epoch 51 test end : avg_loss = 0.7825


100%|██████████| 13/13 [00:03<00:00,  4.17it/s]


epoch 52 train end : avg_loss = 0.3128


100%|██████████| 4/4 [00:00<00:00,  6.18it/s]


epoch 52 test end : avg_loss = 1.0217


100%|██████████| 13/13 [00:03<00:00,  4.20it/s]


epoch 53 train end : avg_loss = 0.3520


100%|██████████| 4/4 [00:00<00:00,  5.94it/s]


epoch 53 test end : avg_loss = 0.9105


100%|██████████| 13/13 [00:03<00:00,  4.26it/s]


epoch 54 train end : avg_loss = 0.3496


100%|██████████| 4/4 [00:00<00:00,  6.25it/s]


epoch 54 test end : avg_loss = 1.1782


100%|██████████| 13/13 [00:03<00:00,  4.22it/s]


epoch 55 train end : avg_loss = 0.2921


100%|██████████| 4/4 [00:00<00:00,  6.11it/s]


epoch 55 test end : avg_loss = 1.0424


100%|██████████| 13/13 [00:03<00:00,  4.11it/s]


epoch 56 train end : avg_loss = 0.1746


100%|██████████| 4/4 [00:00<00:00,  5.30it/s]


epoch 56 test end : avg_loss = 0.7791


100%|██████████| 13/13 [00:03<00:00,  4.00it/s]


epoch 57 train end : avg_loss = 0.3169


100%|██████████| 4/4 [00:00<00:00,  5.86it/s]


epoch 57 test end : avg_loss = 0.9020


100%|██████████| 13/13 [00:03<00:00,  4.20it/s]


epoch 58 train end : avg_loss = 0.2607


100%|██████████| 4/4 [00:00<00:00,  5.70it/s]


epoch 58 test end : avg_loss = 1.2584


100%|██████████| 13/13 [00:03<00:00,  4.12it/s]


epoch 59 train end : avg_loss = 0.2424


100%|██████████| 4/4 [00:00<00:00,  6.01it/s]


epoch 59 test end : avg_loss = 1.1773


100%|██████████| 13/13 [00:03<00:00,  4.17it/s]


epoch 60 train end : avg_loss = 0.2394


100%|██████████| 4/4 [00:00<00:00,  6.25it/s]


epoch 60 test end : avg_loss = 0.8002


100%|██████████| 13/13 [00:03<00:00,  4.24it/s]


epoch 61 train end : avg_loss = 0.2292


100%|██████████| 4/4 [00:00<00:00,  6.18it/s]


epoch 61 test end : avg_loss = 0.9299


100%|██████████| 13/13 [00:03<00:00,  4.26it/s]


epoch 62 train end : avg_loss = 0.2765


100%|██████████| 4/4 [00:00<00:00,  6.16it/s]


epoch 62 test end : avg_loss = 0.8292


100%|██████████| 13/13 [00:03<00:00,  4.26it/s]


epoch 63 train end : avg_loss = 0.2081


100%|██████████| 4/4 [00:00<00:00,  6.17it/s]


epoch 63 test end : avg_loss = 1.0816


100%|██████████| 13/13 [00:03<00:00,  4.12it/s]


epoch 64 train end : avg_loss = 0.2682


100%|██████████| 4/4 [00:00<00:00,  5.80it/s]


epoch 64 test end : avg_loss = 1.0967


100%|██████████| 13/13 [00:03<00:00,  4.00it/s]


epoch 65 train end : avg_loss = 0.1905


100%|██████████| 4/4 [00:00<00:00,  6.10it/s]


epoch 65 test end : avg_loss = 0.7839


100%|██████████| 13/13 [00:03<00:00,  4.14it/s]


epoch 66 train end : avg_loss = 0.2248


100%|██████████| 4/4 [00:00<00:00,  5.81it/s]


epoch 66 test end : avg_loss = 0.7657


100%|██████████| 13/13 [00:03<00:00,  4.21it/s]


epoch 67 train end : avg_loss = 0.1859


100%|██████████| 4/4 [00:00<00:00,  6.22it/s]


epoch 67 test end : avg_loss = 0.9788


100%|██████████| 13/13 [00:03<00:00,  4.16it/s]


epoch 68 train end : avg_loss = 0.2682


100%|██████████| 4/4 [00:00<00:00,  6.10it/s]


epoch 68 test end : avg_loss = 1.1302


100%|██████████| 13/13 [00:03<00:00,  4.17it/s]


epoch 69 train end : avg_loss = 0.2265


100%|██████████| 4/4 [00:00<00:00,  6.21it/s]


epoch 69 test end : avg_loss = 0.7856


100%|██████████| 13/13 [00:03<00:00,  4.26it/s]


epoch 70 train end : avg_loss = 0.2801


100%|██████████| 4/4 [00:00<00:00,  6.20it/s]


epoch 70 test end : avg_loss = 1.4023


100%|██████████| 13/13 [00:03<00:00,  4.24it/s]


epoch 71 train end : avg_loss = 0.2351


100%|██████████| 4/4 [00:00<00:00,  6.24it/s]


epoch 71 test end : avg_loss = 0.9006


100%|██████████| 13/13 [00:03<00:00,  4.14it/s]


epoch 72 train end : avg_loss = 0.1880


100%|██████████| 4/4 [00:00<00:00,  6.20it/s]


epoch 72 test end : avg_loss = 0.8574


100%|██████████| 13/13 [00:03<00:00,  4.12it/s]


epoch 73 train end : avg_loss = 0.2595


100%|██████████| 4/4 [00:00<00:00,  5.96it/s]


epoch 73 test end : avg_loss = 1.0375


100%|██████████| 13/13 [00:03<00:00,  4.14it/s]


epoch 74 train end : avg_loss = 0.2453


100%|██████████| 4/4 [00:00<00:00,  5.77it/s]


epoch 74 test end : avg_loss = 0.8025


100%|██████████| 13/13 [00:03<00:00,  4.22it/s]


epoch 75 train end : avg_loss = 0.2840


100%|██████████| 4/4 [00:00<00:00,  6.09it/s]


epoch 75 test end : avg_loss = 1.0014


100%|██████████| 13/13 [00:03<00:00,  4.19it/s]


epoch 76 train end : avg_loss = 0.2797


100%|██████████| 4/4 [00:00<00:00,  6.09it/s]


epoch 76 test end : avg_loss = 1.1492


100%|██████████| 13/13 [00:03<00:00,  4.21it/s]


epoch 77 train end : avg_loss = 0.2355


100%|██████████| 4/4 [00:00<00:00,  6.15it/s]


epoch 77 test end : avg_loss = 1.2837


100%|██████████| 13/13 [00:03<00:00,  4.25it/s]


epoch 78 train end : avg_loss = 0.2057


100%|██████████| 4/4 [00:00<00:00,  6.18it/s]


epoch 78 test end : avg_loss = 0.9010


100%|██████████| 13/13 [00:03<00:00,  4.25it/s]


epoch 79 train end : avg_loss = 0.1764


100%|██████████| 4/4 [00:00<00:00,  6.11it/s]


epoch 79 test end : avg_loss = 1.1337


100%|██████████| 13/13 [00:03<00:00,  4.14it/s]


epoch 80 train end : avg_loss = 0.1372


100%|██████████| 4/4 [00:00<00:00,  6.11it/s]


epoch 80 test end : avg_loss = 0.8704


100%|██████████| 13/13 [00:03<00:00,  4.12it/s]


epoch 81 train end : avg_loss = 0.1323


100%|██████████| 4/4 [00:00<00:00,  5.21it/s]


epoch 81 test end : avg_loss = 0.9946


100%|██████████| 13/13 [00:03<00:00,  4.19it/s]


epoch 82 train end : avg_loss = 0.1761


100%|██████████| 4/4 [00:00<00:00,  5.68it/s]


epoch 82 test end : avg_loss = 1.1690


100%|██████████| 13/13 [00:03<00:00,  4.16it/s]


epoch 83 train end : avg_loss = 0.1638


100%|██████████| 4/4 [00:00<00:00,  6.16it/s]


epoch 83 test end : avg_loss = 0.9800


100%|██████████| 13/13 [00:03<00:00,  4.18it/s]


epoch 84 train end : avg_loss = 0.1671


100%|██████████| 4/4 [00:00<00:00,  6.26it/s]


epoch 84 test end : avg_loss = 0.8656


100%|██████████| 13/13 [00:03<00:00,  4.17it/s]


epoch 85 train end : avg_loss = 0.1816


100%|██████████| 4/4 [00:00<00:00,  6.13it/s]


epoch 85 test end : avg_loss = 0.8214


100%|██████████| 13/13 [00:03<00:00,  4.11it/s]


epoch 86 train end : avg_loss = 0.1704


100%|██████████| 4/4 [00:00<00:00,  6.12it/s]


epoch 86 test end : avg_loss = 0.9258


100%|██████████| 13/13 [00:03<00:00,  4.18it/s]


epoch 87 train end : avg_loss = 0.1261


100%|██████████| 4/4 [00:00<00:00,  6.03it/s]


epoch 87 test end : avg_loss = 1.0138


100%|██████████| 13/13 [00:03<00:00,  4.16it/s]


epoch 88 train end : avg_loss = 0.3037


100%|██████████| 4/4 [00:00<00:00,  6.19it/s]


epoch 88 test end : avg_loss = 1.2732


100%|██████████| 13/13 [00:03<00:00,  4.29it/s]


epoch 89 train end : avg_loss = 0.2004


100%|██████████| 4/4 [00:00<00:00,  6.25it/s]


epoch 89 test end : avg_loss = 0.9282


100%|██████████| 13/13 [00:03<00:00,  4.02it/s]


epoch 90 train end : avg_loss = 0.1514


100%|██████████| 4/4 [00:00<00:00,  5.93it/s]


epoch 90 test end : avg_loss = 0.9953


100%|██████████| 13/13 [00:03<00:00,  4.19it/s]


epoch 91 train end : avg_loss = 0.1950


100%|██████████| 4/4 [00:00<00:00,  5.99it/s]


epoch 91 test end : avg_loss = 0.8467


100%|██████████| 13/13 [00:03<00:00,  4.22it/s]


epoch 92 train end : avg_loss = 0.0748


100%|██████████| 4/4 [00:00<00:00,  5.99it/s]


epoch 92 test end : avg_loss = 1.1828


100%|██████████| 13/13 [00:03<00:00,  4.13it/s]


epoch 93 train end : avg_loss = 0.1669


100%|██████████| 4/4 [00:00<00:00,  5.93it/s]


epoch 93 test end : avg_loss = 0.6210


100%|██████████| 13/13 [00:03<00:00,  4.20it/s]


epoch 94 train end : avg_loss = 0.1901


100%|██████████| 4/4 [00:00<00:00,  5.34it/s]


epoch 94 test end : avg_loss = 1.1482


100%|██████████| 13/13 [00:03<00:00,  3.95it/s]


epoch 95 train end : avg_loss = 0.1669


100%|██████████| 4/4 [00:00<00:00,  6.24it/s]


epoch 95 test end : avg_loss = 1.0841


100%|██████████| 13/13 [00:03<00:00,  4.19it/s]


epoch 96 train end : avg_loss = 0.1336


100%|██████████| 4/4 [00:00<00:00,  6.15it/s]


epoch 96 test end : avg_loss = 1.3245


100%|██████████| 13/13 [00:03<00:00,  4.19it/s]


epoch 97 train end : avg_loss = 0.1893


100%|██████████| 4/4 [00:00<00:00,  6.08it/s]


epoch 97 test end : avg_loss = 1.1623


100%|██████████| 13/13 [00:03<00:00,  4.02it/s]


epoch 98 train end : avg_loss = 0.1306


100%|██████████| 4/4 [00:00<00:00,  5.91it/s]


epoch 98 test end : avg_loss = 0.7158


100%|██████████| 13/13 [00:03<00:00,  4.03it/s]


epoch 99 train end : avg_loss = 0.1547


100%|██████████| 4/4 [00:00<00:00,  5.83it/s]


epoch 99 test end : avg_loss = 0.8942


100%|██████████| 13/13 [00:03<00:00,  4.05it/s]


epoch 100 train end : avg_loss = 0.1546


100%|██████████| 4/4 [00:00<00:00,  5.96it/s]

epoch 100 test end : avg_loss = 1.2214





# 4. Evaluation

In [26]:
test_df.head()

Unnamed: 0,query,label
165,185.png,"[185.stl, 85.stl, 200.stl, 214.stl, 115.stl]"
6,103.png,"[103.stl, 19.stl, 218.stl, 141.stl, 214.stl]"
111,59.png,"[59.stl, 109.stl, 116.stl, 133.stl, 222.stl]"
172,191.png,"[191.stl, 135.stl, 81.stl, 100.stl, 118.stl]"
115,138.png,"[138.stl, 60.stl, 148.stl, 235.stl, 226.stl]"


In [27]:
model.eval()
mrr_test = []
for q in test_df['query']:
    
    # Load the query image
    img = Image.open('/kaggle/input/2d-3d-retrieval-dataset/wb_2D3Dretrieval_dataset/queries/' + q).convert("RGB")
    img_tensor = test_transform(img).unsqueeze(0).to('cuda').float()

    label = test_df[test_df['query'] == q].label.values[0][0]
    print(f"Query with image: {q}, label: {label}")
    # Calculate the distance between the query image and all the database images
    distances = {}
    for db in db_names:
        db_feat = torch.from_numpy(database_arr[db]).unsqueeze(0).to('cuda').float()
        img_emb, db_emb = model(img_tensor, db_feat)
        dist = torch.nn.functional.pairwise_distance(img_emb, db_emb)
        distances[db] = dist.item()

    # Sort the distances
    sorted_distances = sorted(distances.items(), key=lambda x: x[1])
    # Calculate the MRR
    mrr = 0
    for i, (db, dist) in enumerate(sorted_distances):
        if db == label:
            print(f"Matching at pos {i}")
            mrr = 1/(i+1)
            break
    mrr_test.append(mrr)
    print(f"Query: {q}, MRR@5: {mrr}")
print(f"Test MRR {sum(mrr_test)/ len(mrr_test)}")

Query with image: 185.png, label: 185.stl
Matching at pos 35
Query: 185.png, MRR@5: 0.027777777777777776
Query with image: 103.png, label: 103.stl
Matching at pos 22
Query: 103.png, MRR@5: 0.043478260869565216
Query with image: 59.png, label: 59.stl
Matching at pos 161
Query: 59.png, MRR@5: 0.006172839506172839
Query with image: 191.png, label: 191.stl
Matching at pos 116
Query: 191.png, MRR@5: 0.008547008547008548
Query with image: 138.png, label: 138.stl
Matching at pos 179
Query: 138.png, MRR@5: 0.005555555555555556
Query with image: 201.png, label: 201.stl
Matching at pos 110
Query: 201.png, MRR@5: 0.009009009009009009
Query with image: 217.png, label: 217.stl
Matching at pos 8
Query: 217.png, MRR@5: 0.1111111111111111
Query with image: 250.png, label: 250.stl
Matching at pos 63
Query: 250.png, MRR@5: 0.015625
Query with image: 106.png, label: 106.stl
Matching at pos 18
Query: 106.png, MRR@5: 0.05263157894736842
Query with image: 126.png, label: 126.stl
Matching at pos 131
Query: 1