In [1]:
import os
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torchvision.models import resnet18
import glob
import gc
from torch.cuda.amp import autocast, GradScaler

In [2]:
def arr_to_str(a):
    """Flatten array ``a`` (any shape) and join its elements with ';'.

    This is the serialisation written into the rotation_matrix and
    translation_vector columns of the submission CSV.
    """
    flat = a.reshape(-1)
    return ';'.join(map(str, flat))

src = '/kaggle/input/image-matching-challenge-2023'
# True: train MyModel in this notebook; False: load the checkpoint from /kaggle/input.
is_train = False

In [3]:
# Get data from csv: build data_dict[dataset][scene] -> [image_path, ...],
# preserving the row order of sample_submission.csv.

data_dict = {}
with open(f'{src}/sample_submission.csv', 'r') as f:
    next(f, None)  # Skip header.
    for line in f:
        line = line.strip()
        if not line:
            # Tolerate blank lines (e.g. an extra trailing newline) instead of
            # crashing on the 5-field unpack below.
            continue
        image, dataset, scene, _, _ = line.split(',')
        data_dict.setdefault(dataset, {}).setdefault(scene, []).append(image)

In [4]:
# Summarise how many images each (dataset, scene) pair contributes.
for dataset, scene_map in data_dict.items():
    for scene, image_list in scene_map.items():
        print(f'{dataset} / {scene} -> {len(image_list)} images')

2cfa01ab573141e4 / 2fa124afd1f74f38 -> 3 images


In [5]:
# Function to create a submission file.
def create_submission(out_results, data_dict, path='submission.csv'):
    """Write one CSV row per image listed in ``data_dict``.

    Args:
        out_results: nested dict ``{dataset: {scene: {image_path: {"R": ..., "t": ...}}}}``
            of predicted poses (NumPy arrays). Any dataset, scene or image
            missing from it falls back to the identity rotation and zero
            translation.
        data_dict: ``{dataset: {scene: [image_path, ...]}}``; determines which
            rows are written and in what order.
        path: output CSV path (defaults to ``submission.csv`` in the working dir).
    """
    default_R = np.eye(3).reshape(-1)
    default_t = np.zeros(3)
    with open(path, 'w') as f:
        f.write('image_path,dataset,scene,rotation_matrix,translation_vector\n')
        for dataset, scenes in data_dict.items():
            res = out_results.get(dataset, {})
            for scene, images in scenes.items():
                # Empty dict (not {"R":{}, "t":{}}) so no image key can accidentally match.
                scene_res = res.get(scene, {})
                for image in images:
                    if image in scene_res:
                        R = scene_res[image]['R'].reshape(-1)
                        T = scene_res[image]['t'].reshape(-1)
                    else:
                        R, T = default_R, default_t
                    f.write(f'{image},{dataset},{scene},{arr_to_str(R)},{arr_to_str(T)}\n')

In [6]:
# Root of the training split; CustomDataset joins each image_path onto this directory.
train_base = os.path.join(src, "train")
class CustomDataset(Dataset):
    """Training images paired with their ground-truth camera pose.

    Each row of ``csv_file`` must provide ``image_path`` (relative to the image
    root), ``rotation_matrix`` (9 ';'-separated floats, reshaped row-major to
    3x3) and ``translation_vector`` (';'-separated floats).

    Args:
        csv_file: path to the labels CSV (e.g. train_labels.csv).
        transform: optional callable applied to the RGB PIL image.
        root_dir: directory that ``image_path`` is relative to. ``None`` (the
            default) uses the module-level ``train_base``.
    """

    def __init__(self, csv_file, transform=None, root_dir=None):
        self.data_frame = pd.read_csv(csv_file)
        self.transform = transform
        self.root_dir = root_dir

    def __len__(self):
        return len(self.data_frame)

    def _read_pose(self, idx):
        """Parse row ``idx`` into a (3, 3) rotation matrix and a flat translation vector."""
        row = self.data_frame.loc[idx]
        rot_mat = np.fromstring(row['rotation_matrix'], sep=';').reshape(3, 3)
        trans_vec = np.fromstring(row['translation_vector'], sep=';')
        return rot_mat, trans_vec

    def __getitem__(self, idx):
        root = train_base if self.root_dir is None else self.root_dir
        img_path = os.path.join(root, self.data_frame.loc[idx, 'image_path'])
        # The context manager closes the file; convert() returns an independent copy.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        rot_mat, trans_vec = self._read_pose(idx)
        sample = {'image': img, 'rot_mat': rot_mat, 'trans_vec': trans_vec}
        if self.transform:
            sample['image'] = self.transform(sample['image'])
        return sample

# Training preprocessing: fixed 256x256 resize, tensor conversion, then
# per-channel normalisation with the standard ImageNet mean/std values.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Every row of train_labels.csv, reshuffled each epoch, 32 samples per batch.
dataset = CustomDataset(
    transform=transform,
    csv_file=os.path.join(train_base, 'train_labels.csv'),
)
dataloader = DataLoader(dataset, shuffle=True, batch_size=32)


In [7]:
# Define the MyModel model
class MyModel(nn.Module):
    """ResNet-18 regressor mapping an image to 12 numbers (a flattened 3x4 pose).

    ``forward`` returns a ``(batch, 12)`` tensor; callers view it as
    ``(batch, 3, 4)`` and read the rotation from ``[:, :3, :3]`` and the
    translation from ``[:, :3, 3]``.

    Args:
        backbone_weights: optional path to a ResNet-18 ``state_dict`` used to
            initialise the backbone (e.g. ImageNet weights stored offline).
            ``None`` (the default) leaves the backbone randomly initialised.
    """

    def __init__(self, backbone_weights=None):
        super(MyModel, self).__init__()
        # resnet18() with no arguments builds an un-pretrained network under both the
        # old (pretrained=...) and the new (weights=...) torchvision APIs.
        backbone = resnet18()
        if backbone_weights is not None:
            backbone.load_state_dict(torch.load(backbone_weights, map_location='cpu'))
        self.backbone = backbone
        # NOTE(review): there is no nonlinearity between fc1 and fc2, so the head is
        # effectively one linear map. Kept as-is so existing checkpoints reproduce
        # the same outputs.
        self.fc1 = nn.Linear(512, 128)
        self.fc2 = nn.Linear(128, 12)

    def forward(self, x):
        # Run the ResNet-18 trunk by hand so its own 1000-way backbone.fc is skipped
        # (it stays registered, and therefore in the state_dict, but is unused).
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)

        x = self.backbone.layer1(x)
        x = self.backbone.layer2(x)
        x = self.backbone.layer3(x)
        x = self.backbone.layer4(x)

        x = self.backbone.avgpool(x)
        x = torch.flatten(x, 1)

        x = self.fc1(x)
        x = self.fc2(x)

        return x
    
# Initialize the model, loss function, and optimizer
num_epochs = 50
log_every = 10  # report the mean loss over this many batches
model_path = '/kaggle/input/trained-models/model.pth'
model = MyModel()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
if torch.cuda.device_count() > 1:
    print('Using', torch.cuda.device_count(), 'GPUs')
    model = nn.DataParallel(model)
model.to(device)

# The bare network without any DataParallel wrapper. Its state_dict keys carry no
# 'module.' prefix, so checkpoints load regardless of how many GPUs are present.
base_model = model.module if isinstance(model, nn.DataParallel) else model

if is_train:
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    scaler = GradScaler()

    # Train the model
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for i, batch in enumerate(dataloader):
            # Get the inputs and labels
            images = batch['image'].float().to(device)
            rot_mats = batch['rot_mat'].float().to(device)
            trans_vecs = batch['trans_vec'].float().to(device)

            # Zero the parameter gradients
            optimizer.zero_grad()

            with autocast():
                # The 12 outputs are interpreted as a 3x4 [R | t] matrix.
                outputs = model(images).view(-1, 3, 4)
                rot_mats_pred = outputs[:, :3, :3]
                trans_vecs_pred = outputs[:, :3, 3]
                loss = criterion(rot_mats_pred, rot_mats) + criterion(trans_vecs_pred, trans_vecs)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # Report once every `log_every` batches, averaged over exactly those batches.
            running_loss += loss.item()
            if (i + 1) % log_every == 0:
                print(f'Epoch {epoch + 1}, Batch {i + 1}: Loss = {running_loss / log_every}')
                running_loss = 0.0

    # Persist the trained weights so the inference branch has something to load.
    torch.save(base_model.state_dict(), 'model.pth')
else:
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    state_dict = torch.load(model_path, map_location=device)
    # Checkpoints written from a DataParallel model prefix every key with 'module.'.
    prefix = 'module.'
    state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }
    base_model.load_state_dict(state_dict)



Using device: cuda


In [8]:
def predict(model, image_path):
    """Predict the camera pose of a single image.

    The image gets the same preprocessing as during training (RGB conversion,
    256x256 resize, ImageNet normalisation). The model's 12 outputs are read as
    a 3x4 ``[R | t]`` matrix.

    Args:
        model: a ``MyModel`` (optionally wrapped in ``nn.DataParallel``).
        image_path: path to the image file.

    Returns:
        ``(R, t)``: a (3, 3) and a (3,) NumPy array. ``R`` is the raw network
        output; nothing here projects it onto a proper rotation matrix.
    """
    # Send the input to wherever the weights actually live.
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.eval()
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    # Match training, which converts to RGB: grayscale/RGBA/palette images would
    # otherwise make Normalize fail or see the wrong channels.
    with Image.open(image_path) as img:
        image = transform(img.convert('RGB')).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(image).cpu()[0].view(3, 4)

    rot_mats_pred = output[:3, :3].numpy()
    trans_vecs_pred = output[:3, 3].numpy()
    return rot_mats_pred, trans_vecs_pred

In [9]:
# Predict a pose for every test image and collect the results as
# out_results[dataset][scene][image_path] = {"R": (3, 3) array, "t": (3,) array}.
datasets = list(data_dict)
out_results = {}

for dataset in datasets:
    print(dataset)
    out_results.setdefault(dataset, {})
    for scene in data_dict[dataset]:
        print(scene)
        # Fail gently if the notebook has not been submitted and the test data is not populated.
        # Set img_root = f'{src}/train' to run on training images instead
        # (data_dict must then list training images).
        img_root = f'{src}/test'
        img_dir = f'{img_root}/{dataset}/{scene}/images'
        if not os.path.exists(img_dir):
            continue
        out_results[dataset][scene] = {}
        image_keys = data_dict[dataset][scene]
        print(f"Got {len(image_keys)} images")

        for image_key in image_keys:
            # image_key is the image_path column of sample_submission.csv. It is both
            # the file's location under img_root and the key create_submission looks up.
            try:
                r, t = predict(model, f'{img_root}/{image_key}')
            except Exception as e:
                # Best effort: skip this image (create_submission then writes the
                # identity pose), but say why instead of failing silently.
                print(f'Failed to predict {image_key}: {e!r}')
                continue
            out_results[dataset][scene][image_key] = {"R": r, "t": t}

        # Checkpoint after every scene so a later crash still leaves a complete file.
        create_submission(out_results, data_dict)
        gc.collect()

2cfa01ab573141e4
2fa124afd1f74f38


In [10]:
# Final write: one row per image in sample_submission.csv, using a prediction where one
# exists and the identity rotation / zero translation otherwise.
create_submission(out_results=out_results, data_dict=data_dict)