In [1]:
from models import model, train_model
from torchvision.models import resnet101, ResNet101_Weights
from torchvision.datasets import ImageFolder
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset, ConcatDataset
import numpy as np
import torch

In [20]:
resnet_101 = model.Model(resnet101(ResNet101_Weights.IMAGENET1K_V2).to('cuda'))



In [21]:
resnet_101.model.load_state_dict(torch.load('models/models_weights/resnet101_similar_signs.pth'))

<All keys matched successfully>

In [22]:
data = ImageFolder('data_retrieval/data_other', ResNet101_Weights.IMAGENET1K_V2.transforms())

In [23]:
targets = data.targets
train_indices, test_val_indices = train_test_split(np.arange(len(targets)), stratify=targets, train_size=0.65, random_state=21)
train_data = Subset(data, indices=train_indices)

In [24]:
val_test_data = Subset(data, indices=test_val_indices)
targets = np.array(targets)

In [25]:
test_val_targets = targets[test_val_indices]
test_indices, val_indices = train_test_split(test_val_indices, stratify=test_val_targets, train_size=0.57, random_state=21)
val_data, test_data = Subset(data, indices=val_indices), Subset(data, indices=test_indices)

In [26]:
print("Length of train set:", len(train_data))
print("Length of val set:", len(val_data))
print("Length of test set:", len(test_data))

Length of train set: 8299
Length of val set: 1922
Length of test set: 2547


In [27]:
from collections import Counter
idx_to_class = {v: k for k, v in test_data.dataset.class_to_idx.items()}
samples_per_class_test = Counter()
for data, label in test_data:
    samples_per_class_test[idx_to_class[label]] += 1
    

KeyboardInterrupt: 

In [18]:
samples_per_class_test

Counter({'UD_Neo-Babylonian': 270,
         'DIŠ_Neo-Babylonian': 255,
         'AN_Neo-Babylonian': 240,
         'NA_Neo-Babylonian': 239,
         'A_Neo-Babylonian': 227,
         'MU_Neo-Babylonian': 109,
         'MA_Neo-Babylonian': 108,
         'AŠ_Neo-Babylonian': 103,
         'ŠU₂_Neo-Babylonian': 83,
         'I_Neo-Babylonian': 75,
         'NU_Neo-Babylonian': 72,
         'AN_Neo-Assyrian': 71,
         'IGI_Neo-Babylonian': 65,
         'NA_Neo-Assyrian': 61,
         'DIŠ_Neo-Assyrian': 59,
         'UD_Neo-Assyrian': 56,
         'BAD_Neo-Assyrian': 51,
         'BAD_Neo-Babylonian': 48,
         'MA_Neo-Assyrian': 47,
         'GIŠ_Neo-Babylonian': 45,
         'IGI_Neo-Assyrian': 43,
         'A_Neo-Assyrian': 41,
         'AŠ_Neo-Assyrian': 38,
         'GIŠ_Neo-Assyrian': 33,
         'MU_Neo-Assyrian': 30,
         'ŠU₂_Neo-Assyrian': 30,
         'NU_Neo-Assyrian': 27,
         'I_Neo-Assyrian': 21})

In [28]:
train_model.train_model(train_data, val_data, resnet_101, 'models/models_weights/resnet101_similar_signs.pth', epochs=80)

  0%|          | 0/80 [00:00<?, ?it/s]

Epoch 1 		 Training Loss: 0.3934429760713683 		 Validation Loss: 0.6576851424106882
Validation Loss Decreased(inf--->79.579902) 	 Saving The Model


  1%|▏         | 1/80 [04:00<5:17:07, 240.86s/it]

Epoch 2 		 Training Loss: 0.2619896970351939 		 Validation Loss: 0.6140058597257315
Validation Loss Decreased(79.579902--->74.294709) 	 Saving The Model


  2%|▎         | 2/80 [07:53<5:06:41, 235.91s/it]

Epoch 3 		 Training Loss: 0.185257055242196 		 Validation Loss: 0.5456231058505941
Validation Loss Decreased(74.294709--->66.020396) 	 Saving The Model


  5%|▌         | 4/80 [17:19<5:38:01, 266.87s/it]

Epoch 4 		 Training Loss: 0.15328526739901388 		 Validation Loss: 0.5921515385227755
Epoch 5 		 Training Loss: 0.11349324486308116 		 Validation Loss: 0.5294719927382371
Validation Loss Decreased(66.020396--->64.066111) 	 Saving The Model


  8%|▊         | 6/80 [26:33<5:36:14, 272.63s/it]

Epoch 6 		 Training Loss: 0.11428225295113935 		 Validation Loss: 0.6148846962485339


  9%|▉         | 7/80 [31:10<5:33:18, 273.96s/it]

Epoch 7 		 Training Loss: 0.10262867436598025 		 Validation Loss: 0.6739391397900325


 10%|█         | 8/80 [35:46<5:29:47, 274.83s/it]

Epoch 8 		 Training Loss: 0.07636720268941809 		 Validation Loss: 0.6212713862769306


 11%|█▏        | 9/80 [40:23<5:26:00, 275.50s/it]

Epoch 9 		 Training Loss: 0.06899096393006224 		 Validation Loss: 0.5742059447225337


 12%|█▎        | 10/80 [45:01<5:22:08, 276.12s/it]

Epoch 10 		 Training Loss: 0.05107194775708629 		 Validation Loss: 0.6132639013486211


 14%|█▍        | 11/80 [49:38<5:17:51, 276.40s/it]

Epoch 11 		 Training Loss: 0.0734185534199937 		 Validation Loss: 0.7167721455377981


 15%|█▌        | 12/80 [54:15<5:13:33, 276.67s/it]

Epoch 12 		 Training Loss: 0.05506059921745647 		 Validation Loss: 0.5942107625690497


 16%|█▋        | 13/80 [58:52<5:09:05, 276.80s/it]

Epoch 13 		 Training Loss: 0.04982690701944402 		 Validation Loss: 0.8378106806627359


 18%|█▊        | 14/80 [1:03:30<5:04:51, 277.14s/it]

Epoch 14 		 Training Loss: 0.07323749942754046 		 Validation Loss: 0.6214122832416026


 19%|█▉        | 15/80 [1:08:07<5:00:09, 277.06s/it]

Epoch 15 		 Training Loss: 0.03395456710243992 		 Validation Loss: 0.5939272447094256


 20%|██        | 16/80 [1:12:44<4:55:27, 277.00s/it]

Epoch 16 		 Training Loss: 0.0506892013796733 		 Validation Loss: 0.7271490475116682


 21%|██▏       | 17/80 [1:17:20<4:50:45, 276.91s/it]

Epoch 17 		 Training Loss: 0.047651099531958084 		 Validation Loss: 0.6506245914800477


 22%|██▎       | 18/80 [1:21:57<4:46:08, 276.91s/it]

Epoch 18 		 Training Loss: 0.04539434304091618 		 Validation Loss: 0.6851352448459745


 24%|██▍       | 19/80 [1:26:34<4:41:27, 276.85s/it]

Epoch 19 		 Training Loss: 0.039348465837715445 		 Validation Loss: 0.7565036286245878


 25%|██▌       | 20/80 [1:31:11<4:36:50, 276.84s/it]

Epoch 20 		 Training Loss: 0.026093843713555446 		 Validation Loss: 0.6071345377713442


 26%|██▋       | 21/80 [1:35:48<4:32:18, 276.92s/it]

Epoch 21 		 Training Loss: 0.03963545855930427 		 Validation Loss: 0.8024826136625502


 28%|██▊       | 22/80 [1:40:25<4:27:40, 276.91s/it]

Epoch 22 		 Training Loss: 0.028431633992616286 		 Validation Loss: 0.6944071016849734


 29%|██▉       | 23/80 [1:45:05<4:24:04, 277.97s/it]

Epoch 23 		 Training Loss: 0.04705217143855547 		 Validation Loss: 0.7383976357273203


 30%|███       | 24/80 [1:49:44<4:19:30, 278.04s/it]

Epoch 24 		 Training Loss: 0.03682983381873355 		 Validation Loss: 0.7228894804446565


 31%|███▏      | 25/80 [1:54:21<4:14:39, 277.81s/it]

Epoch 25 		 Training Loss: 0.022342992044945702 		 Validation Loss: 0.7407969977872154


 32%|███▎      | 26/80 [1:58:58<4:09:50, 277.60s/it]

Epoch 26 		 Training Loss: 0.015055536168721453 		 Validation Loss: 0.7150220328899608


 34%|███▍      | 27/80 [2:03:36<4:05:14, 277.62s/it]

Epoch 27 		 Training Loss: 0.009147853144287552 		 Validation Loss: 0.785878418305935


 35%|███▌      | 28/80 [2:08:13<4:00:28, 277.47s/it]

Epoch 28 		 Training Loss: 0.0689450519834516 		 Validation Loss: 0.7320491296243046


 36%|███▋      | 29/80 [2:12:50<3:55:51, 277.47s/it]

Epoch 29 		 Training Loss: 0.02413437597108498 		 Validation Loss: 0.7628444381508483


 38%|███▊      | 30/80 [2:17:28<3:51:14, 277.49s/it]

Epoch 30 		 Training Loss: 0.023832631694948092 		 Validation Loss: 0.6835073210753131


 39%|███▉      | 31/80 [2:22:05<3:46:35, 277.47s/it]

Epoch 31 		 Training Loss: 0.01638633759410855 		 Validation Loss: 0.7214494828973735


 40%|████      | 32/80 [2:26:43<3:41:58, 277.47s/it]

Epoch 32 		 Training Loss: 0.005525085623821461 		 Validation Loss: 0.7498823751410189


 41%|████▏     | 33/80 [2:31:20<3:37:24, 277.55s/it]

Epoch 33 		 Training Loss: 0.050722432187339773 		 Validation Loss: 0.7885877005640448


 42%|████▎     | 34/80 [2:35:58<3:32:46, 277.52s/it]

Epoch 34 		 Training Loss: 0.03167729727744817 		 Validation Loss: 0.6841831638757717


 44%|████▍     | 35/80 [2:40:36<3:28:10, 277.58s/it]

Epoch 35 		 Training Loss: 0.011275340413863284 		 Validation Loss: 0.7028448963607565


 45%|████▌     | 36/80 [2:45:13<3:23:34, 277.60s/it]

Epoch 36 		 Training Loss: 0.0021938525180433812 		 Validation Loss: 0.6962143299742294


 46%|████▋     | 37/80 [2:49:51<3:19:03, 277.74s/it]

Epoch 37 		 Training Loss: 0.002937031638266963 		 Validation Loss: 0.7488437946770371


 48%|████▊     | 38/80 [2:54:30<3:14:40, 278.11s/it]

Epoch 38 		 Training Loss: 0.0703409424484576 		 Validation Loss: 0.8555350362223165


 49%|████▉     | 39/80 [2:59:22<3:12:48, 282.15s/it]

Epoch 39 		 Training Loss: 0.020629723553020698 		 Validation Loss: 0.6648799165205892


 50%|█████     | 40/80 [3:04:05<3:08:13, 282.33s/it]

Epoch 40 		 Training Loss: 0.00744456802990262 		 Validation Loss: 0.6720008688989915


 51%|█████▏    | 41/80 [3:08:48<3:03:46, 282.73s/it]

Epoch 41 		 Training Loss: 0.011473702023425235 		 Validation Loss: 0.7587472012820692


 52%|█████▎    | 42/80 [3:13:47<3:02:08, 287.59s/it]

Epoch 42 		 Training Loss: 0.02956926022857269 		 Validation Loss: 0.7300579818841146


 54%|█████▍    | 43/80 [3:18:29<2:56:19, 285.94s/it]

Epoch 43 		 Training Loss: 0.015409865425642794 		 Validation Loss: 0.7545369529080662


 55%|█████▌    | 44/80 [3:23:08<2:50:16, 283.78s/it]

Epoch 44 		 Training Loss: 0.033266482149403656 		 Validation Loss: 0.7571803597225384


 56%|█████▋    | 45/80 [3:27:47<2:44:39, 282.27s/it]

Epoch 45 		 Training Loss: 0.014573293605047855 		 Validation Loss: 0.7317500721022472


 57%|█████▊    | 46/80 [3:32:26<2:39:28, 281.44s/it]

Epoch 46 		 Training Loss: 0.013081622469377418 		 Validation Loss: 0.7596345495063165


 59%|█████▉    | 47/80 [3:37:06<2:34:33, 281.01s/it]

Epoch 47 		 Training Loss: 0.023477257095036643 		 Validation Loss: 0.7177537730630402


 60%|██████    | 48/80 [3:41:36<2:28:08, 277.76s/it]

Epoch 48 		 Training Loss: 0.007913117762527094 		 Validation Loss: 0.7216282568606666


 61%|██████▏   | 49/80 [3:45:17<2:14:42, 260.71s/it]

Epoch 49 		 Training Loss: 0.010612354019020968 		 Validation Loss: 0.7550520304222937


 62%|██████▎   | 50/80 [3:49:08<2:05:49, 251.65s/it]

Epoch 50 		 Training Loss: 0.006562603442245907 		 Validation Loss: 0.7744777141627663


 64%|██████▍   | 51/80 [3:52:56<1:58:12, 244.58s/it]

Epoch 51 		 Training Loss: 0.002464126336897244 		 Validation Loss: 0.7327582094664897


 65%|██████▌   | 52/80 [3:56:47<1:52:14, 240.52s/it]

Epoch 52 		 Training Loss: 0.03285643500039323 		 Validation Loss: 0.9164901927110166


 66%|██████▋   | 53/80 [4:00:45<1:47:54, 239.81s/it]

Epoch 53 		 Training Loss: 0.029235461920256483 		 Validation Loss: 0.7581961354440894


 68%|██████▊   | 54/80 [4:04:41<1:43:25, 238.67s/it]

Epoch 54 		 Training Loss: 0.010445892710925194 		 Validation Loss: 0.6745344586882063


 69%|██████▉   | 55/80 [4:08:29<1:38:08, 235.52s/it]

Epoch 55 		 Training Loss: 0.0014874814838905968 		 Validation Loss: 0.7129785350301642


 70%|███████   | 56/80 [4:12:23<1:33:58, 234.94s/it]

Epoch 56 		 Training Loss: 0.005357954392554027 		 Validation Loss: 0.7231911615495501


 71%|███████▏  | 57/80 [4:16:23<1:30:37, 236.42s/it]

Epoch 57 		 Training Loss: 0.022872099385600093 		 Validation Loss: 0.8829673080404928


 72%|███████▎  | 58/80 [4:20:18<1:26:30, 235.94s/it]

Epoch 58 		 Training Loss: 0.030737098305324846 		 Validation Loss: 0.8960121548272793


 74%|███████▍  | 59/80 [4:24:13<1:22:30, 235.75s/it]

Epoch 59 		 Training Loss: 0.0062375230425764806 		 Validation Loss: 0.7680067853705789


 75%|███████▌  | 60/80 [4:28:09<1:18:34, 235.73s/it]

Epoch 60 		 Training Loss: 0.0038928270590108407 		 Validation Loss: 0.831920373530612


 76%|███████▋  | 61/80 [4:32:07<1:14:54, 236.55s/it]

Epoch 61 		 Training Loss: 0.006296852842346152 		 Validation Loss: 0.7228927775231418


 78%|███████▊  | 62/80 [4:36:04<1:11:02, 236.79s/it]

Epoch 62 		 Training Loss: 0.02259500019812442 		 Validation Loss: 1.0617326678908314


 79%|███████▉  | 63/80 [4:40:01<1:07:06, 236.86s/it]

Epoch 63 		 Training Loss: 0.0269424370119137 		 Validation Loss: 0.8331991279236048


 80%|████████  | 64/80 [4:43:56<1:02:59, 236.24s/it]

Epoch 64 		 Training Loss: 0.008725609661179308 		 Validation Loss: 0.7594171409872322


 81%|████████▏ | 65/80 [4:47:52<59:01, 236.08s/it]  

Epoch 65 		 Training Loss: 0.011684924825308923 		 Validation Loss: 0.7859881694143838


 82%|████████▎ | 66/80 [4:51:46<54:55, 235.39s/it]

Epoch 66 		 Training Loss: 0.006218518722172385 		 Validation Loss: 0.8168975483349039


 84%|████████▍ | 67/80 [4:55:39<50:51, 234.77s/it]

Epoch 67 		 Training Loss: 0.010045280031194185 		 Validation Loss: 0.7294148877438654


 85%|████████▌ | 68/80 [4:59:28<46:37, 233.11s/it]

Epoch 68 		 Training Loss: 0.0157158276404828 		 Validation Loss: 0.9412632293783683


 86%|████████▋ | 69/80 [5:03:18<42:32, 232.02s/it]

Epoch 69 		 Training Loss: 0.016169656299244713 		 Validation Loss: 0.8012558722200858


 88%|████████▊ | 70/80 [5:07:08<38:33, 231.35s/it]

Epoch 70 		 Training Loss: 0.006337238206917438 		 Validation Loss: 0.818298417609185


 89%|████████▉ | 71/80 [5:10:57<34:37, 230.83s/it]

Epoch 71 		 Training Loss: 0.0037156185069761485 		 Validation Loss: 0.74755752288899


 90%|█████████ | 72/80 [5:14:48<30:46, 230.78s/it]

Epoch 72 		 Training Loss: 0.0012381113556184423 		 Validation Loss: 0.7865555774687647


 91%|█████████▏| 73/80 [5:18:36<26:49, 229.95s/it]

Epoch 73 		 Training Loss: 0.01931492963129458 		 Validation Loss: 0.9246066231112195


 92%|█████████▎| 74/80 [5:22:27<23:02, 230.37s/it]

Epoch 74 		 Training Loss: 0.0224371811297254 		 Validation Loss: 0.8712203923279943


 94%|█████████▍| 75/80 [5:26:18<19:13, 230.65s/it]

Epoch 75 		 Training Loss: 0.0082852328148161 		 Validation Loss: 0.8175215924174086


 95%|█████████▌| 76/80 [5:30:11<15:24, 231.07s/it]

Epoch 76 		 Training Loss: 0.007623564426391921 		 Validation Loss: 0.8289660465468707


 96%|█████████▋| 77/80 [5:34:01<11:33, 231.01s/it]

Epoch 77 		 Training Loss: 0.006575879837186597 		 Validation Loss: 0.9002092367322916


 98%|█████████▊| 78/80 [5:37:57<07:44, 232.40s/it]

Epoch 78 		 Training Loss: 0.0150727242041183 		 Validation Loss: 0.8807683935358135


 99%|█████████▉| 79/80 [5:41:54<03:53, 233.65s/it]

Epoch 79 		 Training Loss: 0.011633341166965138 		 Validation Loss: 0.7677802002024481


100%|██████████| 80/80 [5:45:52<00:00, 259.41s/it]

Epoch 80 		 Training Loss: 0.004767991691364219 		 Validation Loss: 0.8408448371789918



