In [1]:
import os
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder
import torch
from torch import nn
import torchvision.models as models
from help_func_torch.my_function import MyDataset, test_model

In [2]:
# ImageNet-1k pretrained weight descriptor for ViT-B/16; it also supplies the
# matching input preprocessing pipeline used for the test images.
weights = models.ViT_B_16_Weights.IMAGENET1K_V1
preprocess = weights.transforms(antialias=True)

# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

device(type='cuda')

In [3]:
# Relative path (note the trailing slash) of the directory holding the test
# '.jpg' files; resolved against the notebook's working directory.
path = 'dpl-project-1-dog-breed-identification/test/'

In [4]:
# Build the test manifest: one row per file in `path`, with
#   id    - the file name without its '.jpg' extension
#   breed - empty placeholder (the test set is unlabeled)
#
# os.listdir returns entries in arbitrary order, so sort them to keep the row
# order (and therefore the submission order) reproducible across runs.
image_files = sorted(os.listdir(path))

# Use removesuffix rather than rstrip: rstrip('.jpg') strips any trailing run
# of the characters '.', 'j', 'p', 'g', which would silently truncate an id
# that happens to end in one of those letters.
test_df = pd.DataFrame({
    'id': pd.Series(image_files, dtype=object).str.removesuffix('.jpg'),
    'breed': None,
})
test_df

Unnamed: 0,id,breed
0,000621fb3cbb32d8935728e48679680e,
1,00102ee9d8eb90812350685311fe5890,
2,0012a730dfa437f5f3613fb75efcd4ce,
3,001510bc8570bbeee98c8d80c8a95ec1,
4,001a5f3114548acdefa3d4da05474c2e,
...,...,...
10352,ffeda8623d4eee33c6d1156a2ecbfcf8,
10353,fff1ec9e6e413275984966f745a313b0,
10354,fff74b59b758bbbf13a5793182a9bbe4,
10355,fff7d50d848e8014ac1e9172dc6762a3,


In [5]:
# Wrap the test directory and its manifest in the project's dataset class,
# handing it the ViT preprocessing transform.
test_data = MyDataset(
    directory=path,
    subset=test_df,
    transform=preprocess,
)

In [6]:
# Restore the fine-tuned ViT-B/16 classifier and rebuild the label mapping.
#
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code; only load checkpoints from a trusted source. If the checkpoint holds
# nothing but tensors and plain containers, consider weights_only=True.
checkpoint = torch.load('my_model_v2_final.pt', map_location=device)

# Build the architecture only: every parameter is overwritten by the (strict)
# load_state_dict call below, so fetching the ImageNet weights first would
# just trigger a redundant download.
vit_b16 = models.vit_b_16(weights=None)

# Freeze the backbone parameters; the replacement head below is created
# afterwards and therefore keeps requires_grad=True, mirroring how the
# checkpointed model was assembled.
for param in vit_b16.parameters():
    param.requires_grad = False

# Swap the 1000-class ImageNet head for the 120-way classifier whose weights
# are stored in the checkpoint.
vit_b16.heads = nn.Sequential(nn.Dropout(0.5),
                              nn.Linear(768, 120))

vit_b16.load_state_dict(checkpoint['model_latest'])
vit_b16.to(device)
# Switch to inference mode so the Dropout(0.5) in the head is disabled;
# otherwise predictions would be computed with random units zeroed out.
# This is a no-op if the evaluation helper already calls .eval() itself.
vit_b16.eval()

# Fit the encoder on the training labels to recover the class-index -> breed
# mapping (LabelEncoder sorts the unique labels) used to name the output columns.
labels_data = pd.read_csv('dpl-project-1-dog-breed-identification/labels.csv')
le = LabelEncoder()
le.fit(labels_data['breed'])
dict(enumerate(le.classes_))

{0: 'affenpinscher',
 1: 'afghan_hound',
 2: 'african_hunting_dog',
 3: 'airedale',
 4: 'american_staffordshire_terrier',
 5: 'appenzeller',
 6: 'australian_terrier',
 7: 'basenji',
 8: 'basset',
 9: 'beagle',
 10: 'bedlington_terrier',
 11: 'bernese_mountain_dog',
 12: 'black-and-tan_coonhound',
 13: 'blenheim_spaniel',
 14: 'bloodhound',
 15: 'bluetick',
 16: 'border_collie',
 17: 'border_terrier',
 18: 'borzoi',
 19: 'boston_bull',
 20: 'bouvier_des_flandres',
 21: 'boxer',
 22: 'brabancon_griffon',
 23: 'briard',
 24: 'brittany_spaniel',
 25: 'bull_mastiff',
 26: 'cairn',
 27: 'cardigan',
 28: 'chesapeake_bay_retriever',
 29: 'chihuahua',
 30: 'chow',
 31: 'clumber',
 32: 'cocker_spaniel',
 33: 'collie',
 34: 'curly-coated_retriever',
 35: 'dandie_dinmont',
 36: 'dhole',
 37: 'dingo',
 38: 'doberman',
 39: 'english_foxhound',
 40: 'english_setter',
 41: 'english_springer',
 42: 'entlebucher',
 43: 'eskimo_dog',
 44: 'flat-coated_retriever',
 45: 'french_bulldog',
 46: 'german_sheph

In [7]:
# Run the model over the whole test set. test_model returns three values;
# only the third is kept here (presumably per-class probabilities, one row
# per test sample - confirm against help_func_torch.my_function).
_, _, y_prob = test_model(
    model=vit_b16,
    test_data=test_data,
    device=device,
)

# Materialise the result as a float32 NumPy array for the submission table.
y_prob = np.array(y_prob, dtype=np.float32)

Time remain: 100%|██████████| 10357/10357 [04:41<00:00, 36.80it/s]


In [8]:
# Assemble the submission table: one row per test image id, one column per
# breed (in LabelEncoder class order), then move the id from the index into
# a regular leading column.
submission = (
    pd.DataFrame(y_prob, index=test_df['id'], columns=le.classes_)
    .reset_index()
    # Covers the case where the index came through unnamed and reset_index
    # produced a column called 'index' instead of 'id'.
    .rename(columns={'index': 'id'})
)

# Write without the positional index so the file starts with the 'id' column.
submission.to_csv('submission.csv', index=False)

In [9]:
# Read the submission file back from disk and display it as the cell output,
# as a quick check of what was actually written.
pd.read_csv('submission.csv')

Unnamed: 0,id,affenpinscher,afghan_hound,african_hunting_dog,airedale,american_staffordshire_terrier,appenzeller,australian_terrier,basenji,basset,...,toy_poodle,toy_terrier,vizsla,walker_hound,weimaraner,welsh_springer_spaniel,west_highland_white_terrier,whippet,wire-haired_fox_terrier,yorkshire_terrier
0,000621fb3cbb32d8935728e48679680e,4.035860e-08,2.021701e-06,1.165243e-06,1.805319e-06,3.551559e-07,1.637970e-06,3.556419e-07,7.241077e-07,8.058891e-07,...,4.457836e-07,8.080706e-08,3.183542e-07,3.621445e-07,9.985978e-07,4.439828e-08,3.965216e-06,7.452616e-07,1.286437e-06,7.465851e-08
1,00102ee9d8eb90812350685311fe5890,2.809243e-07,1.837857e-06,4.501475e-07,9.229699e-07,1.803121e-06,1.998900e-06,7.052435e-07,7.886821e-07,1.431012e-06,...,1.745065e-06,2.672899e-07,5.337183e-07,2.438032e-06,5.390185e-07,1.526373e-06,1.015255e-06,9.106932e-07,1.012615e-06,1.247528e-06
2,0012a730dfa437f5f3613fb75efcd4ce,3.327718e-07,8.099063e-07,1.465983e-06,1.005662e-06,8.002340e-07,5.208129e-06,1.778605e-06,1.767610e-06,8.493267e-07,...,4.660789e-07,3.354085e-06,1.405114e-05,2.254757e-06,2.600819e-06,8.271462e-07,4.858152e-07,1.923746e-07,9.090958e-07,9.980789e-07
3,001510bc8570bbeee98c8d80c8a95ec1,9.000853e-06,1.850288e-06,1.730118e-06,1.086695e-06,3.871108e-06,1.865897e-05,1.024501e-05,2.151401e-06,1.523703e-06,...,6.749207e-06,2.389444e-06,1.477762e-05,1.045697e-06,7.303917e-07,9.167536e-07,1.276493e-06,4.401307e-06,9.748670e-07,5.429786e-06
4,001a5f3114548acdefa3d4da05474c2e,9.690485e-04,9.652331e-05,2.286494e-05,6.191523e-05,2.987710e-05,2.698196e-05,6.410926e-05,1.260529e-04,1.405676e-05,...,1.209744e-05,6.864322e-05,3.224154e-05,1.849857e-05,2.678608e-05,8.653181e-06,8.002221e-06,2.794174e-05,8.446447e-04,1.395040e-05
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
10352,ffeda8623d4eee33c6d1156a2ecbfcf8,3.923260e-06,3.930640e-05,8.388447e-06,1.636117e-06,1.844416e-05,8.747823e-06,5.635195e-05,1.539912e-05,9.783007e-06,...,6.212030e-05,1.943619e-05,7.383130e-06,3.575269e-05,1.660553e-05,6.802877e-06,4.612699e-06,3.599432e-05,3.639339e-06,2.393782e-05
10353,fff1ec9e6e413275984966f745a313b0,2.440570e-06,6.653313e-07,1.707172e-06,1.228679e-06,1.276879e-06,9.313202e-07,5.043230e-07,1.689161e-07,5.005754e-07,...,2.423861e-07,3.530976e-08,2.752643e-06,4.802571e-08,9.998500e-01,5.441582e-07,6.840867e-07,2.274040e-09,1.060110e-06,2.818668e-07
10354,fff74b59b758bbbf13a5793182a9bbe4,7.706101e-07,2.210893e-06,3.565711e-05,7.307522e-07,1.548046e-06,9.261169e-07,1.596685e-06,5.148352e-07,1.316910e-06,...,7.213390e-07,1.073783e-06,4.102535e-07,8.686077e-07,1.101157e-06,9.663478e-07,5.492228e-07,1.901057e-06,1.038873e-06,1.633764e-06
10355,fff7d50d848e8014ac1e9172dc6762a3,2.291636e-05,5.339044e-06,1.154938e-05,1.022123e-05,9.357650e-06,1.397231e-05,9.808099e-06,9.968570e-06,1.614039e-05,...,7.998681e-06,5.580865e-06,4.510803e-06,5.769140e-06,3.991896e-06,3.789908e-06,5.036178e-06,6.911714e-06,2.136978e-05,3.263017e-06
