In [1]:
from pathlib import Path

In [2]:
import numpy as np
import pandas as pd
import torch
import torchvision
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.neighbors import NearestNeighbors
from matplotlib.pyplot import imshow

%matplotlib inline

!pip install efficientnet_pytorch

import joblib
from efficientnet_pytorch import EfficientNet

Collecting efficientnet_pytorch
  Downloading efficientnet_pytorch-0.7.0.tar.gz (20 kB)
Building wheels for collected packages: efficientnet-pytorch
  Building wheel for efficientnet-pytorch (setup.py) ... [?25l- \ done
[?25h  Created wheel for efficientnet-pytorch: filename=efficientnet_pytorch-0.7.0-py3-none-any.whl size=16035 sha256=2e3b17d961437ed5c233b2b8df1b1778ee2fc6aec814cf2e83d8cd6917d45c7a
  Stored in directory: /root/.cache/pip/wheels/b7/cc/0d/41d384b0071c6f46e542aded5f8571700ace4f1eb3f1591c29
Successfully built efficientnet-pytorch
Installing collected packages: efficientnet-pytorch
Successfully installed efficientnet-pytorch-0.7.0


In [3]:
from torchvision import transforms

In [4]:
ball_tree_dump_file = '/kaggle/input/2nd-step/library_ball_tree.sav'
lib_files_dump_file = '/kaggle/input/2nd-step/library_files_list.sav'
CNN_MODEL_WEIGHTS = Path('/kaggle/input/digix-ai-1st-attempt/eff_net_w_2.pt')
QUERY_DIR = Path('/kaggle/input/digixquery/query')
RESCALE_SIZE=224

In [5]:
import torch.nn as nn

In [6]:
eval_on_gpu = torch.cuda.is_available()

if eval_on_gpu:
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

In [7]:
def prepare_model(model):
    fc_without_last_layer = list(model._fc.children())[:-2]
    model._fc = torch.nn.Sequential(*fc_without_last_layer)

In [8]:
knn_model = joblib.load(ball_tree_dump_file)

cnn_model = EfficientNet.from_name('efficientnet-b1')
additional_ftrs = 4096
n_classes = 3094
#Изменяем выходные слои модели
num_ftrs_resnext = cnn_model._fc.in_features
new_fc_seq = nn.Sequential(
    nn.Linear(num_ftrs_resnext,additional_ftrs),
    nn.LeakyReLU(),
    nn.Linear(additional_ftrs, n_classes)
)
cnn_model._fc = new_fc_seq

cnn_model.to(DEVICE)
cnn_model.load_state_dict(torch.load(CNN_MODEL_WEIGHTS))
cnn_model.eval()
prepare_model(cnn_model)

lib_files = joblib.load(lib_files_dump_file)

In [9]:
def query_processing(knn_model, cnn_model, query_path, lib_files, n_results=10):
    image = Image.open(query_path)
    image.load()
    
    if (len(image.mode) < 2):
        image = transforms.Grayscale(3)(image)
    transforms.functional.adjust_saturation(img=image,saturation_factor=1.25)
    transforms.functional.adjust_gamma(img=image, gamma=0.25)
    
    image_transform = transforms.Compose([
            transforms.Resize(RESCALE_SIZE),
            transforms.CenterCrop(RESCALE_SIZE),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    image = image_transform(image)
    image = image.unsqueeze(0)
    image = image.to(DEVICE)
    
    with torch.no_grad():
        feature_vector = cnn_model(image).cpu()
    neigh_dist, nearest_nbrs = knn_model.kneighbors(feature_vector)
    
    result = list()
    nearest_nbrs = np.squeeze(np.array(nearest_nbrs))
    neigh_dist = np.squeeze(np.array(neigh_dist))
    for i in range(10):
        result.append((lib_files[nearest_nbrs[i]].name, neigh_dist[i]))
    
    result = sorted(result, key=lambda pair: pair[1], reverse=True)
    result = list(map(lambda pair: pair[0], result))
    return result

In [10]:
query_files = list(QUERY_DIR.rglob('*.jpg'))

query_results = list()
                     
for query_file in query_files:
    result = query_processing(knn_model, cnn_model, query_file, lib_files)
    query_results.append(result)

IndexError: list index out of range

In [11]:
import csv
with open('submission.csv', 'w', encoding="utf8") as f:
    for i in range(len(query_files)):
        f.write("%s,{"%(query_files[i].name))
        for j in range(9):
            f.write("%s,"%(query_results[i][j]))
        f.write("%s}\n"%(query_results[i][9]))

IndexError: list index out of range

In [12]:
import os, sys, codecs

BUFSIZE = 4096
BOMLEN = len(codecs.BOM_UTF8)

path = 'submission.csv'
with open(path, "r+b") as fp:
    chunk = fp.read(BUFSIZE)
    if chunk.startswith(codecs.BOM_UTF8):
        i = 0
        chunk = chunk[BOMLEN:]
        while chunk:
            fp.seek(i)
            fp.write(chunk)
            i += len(chunk)
            fp.seek(BOMLEN, os.SEEK_CUR)
            chunk = fp.read(BUFSIZE)
        fp.seek(-BOMLEN, os.SEEK_CUR)
        fp.truncate()