In [1]:
from google.colab import drive
import os

# Mount Google Drive and work from its root so the relative `src/` and
# `saved_models/` paths used by later cells resolve.
drive.mount('/content/drive')
# Absolute path: the former relative os.chdir('drive/MyDrive') only worked when the
# cwd was /content, so re-running this cell failed with FileNotFoundError.
os.chdir('/content/drive/MyDrive')
os.listdir('src')

Mounted at /content/drive


['models', '__pycache__', 'model_trainer.py', 'transformations.py', 'utils.py']

In [None]:
# Imports and runtime configuration (device, seed, CPU threads).
import os

import torch

from src.models.our_model import OurModel
from src.models.pretrained_models import VGG16Pretrained, ResNetPretrained
from src.models.ensemble import HardVotingEnsemble, SoftVotingEnsemble
from src.utils import load_data, evaluate_model
from src.model_trainer import ModelTrainer
from src.transformations import normalized_simple_transform, pretrained_transform

device_str = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_str)
# Fixed seed so torch-driven randomness is reproducible across runs.
torch.manual_seed(123)
print(device)
# Use up to 14 intra-op CPU threads, but never more than the cores actually present:
# asking for 14 threads on a 2-vCPU Colab VM oversubscribes the CPU and slows it down.
torch.set_num_threads(min(14, os.cpu_count() or 1))

cuda


In [3]:
# Cell for downloading and unpacking CINIC-10 into /data.
# Every step is skipped when its result already exists, so the cell is safe to re-run.

import os
import tarfile
import urllib.request
import shutil

download_dir = "/data"
os.makedirs(download_dir, exist_ok=True)

tar_path = os.path.join(download_dir, "CINIC-10.tar.gz")
cinic10_url = "https://datashare.is.ed.ac.uk/bitstream/handle/10283/3192/CINIC-10.tar.gz"

# The evaluation cells read /data/test; if it is already there, nothing needs doing.
if not os.path.isdir(os.path.join(download_dir, "test")):
    if not os.path.exists(tar_path):
        # Download to a temporary name and rename only on success, so an interrupted
        # download never leaves a truncated archive that a re-run would trust.
        part_path = tar_path + ".part"
        urllib.request.urlretrieve(cinic10_url, part_path)
        os.replace(part_path, tar_path)
        print("Finished downloading")

    with tarfile.open(tar_path, "r:gz") as tar:
        # filter="data" rejects absolute paths and ".." members (tar path traversal);
        # only pass it on Python versions that support extraction filters.
        if hasattr(tarfile, "data_filter"):
            tar.extractall(path=download_dir, filter="data")
        else:
            tar.extractall(path=download_dir)

    # If the archive unpacked into a CINIC-10/ subfolder, lift its contents up one level.
    extracted_dir = os.path.join(download_dir, "CINIC-10")
    if os.path.isdir(extracted_dir):
        for item in os.listdir(extracted_dir):
            src_path = os.path.join(extracted_dir, item)
            dst_path = os.path.join(download_dir, item)
            # shutil.move into an existing directory would nest src inside it
            # (e.g. /data/test/test); keep what is already in place instead.
            if os.path.exists(dst_path):
                continue
            shutil.move(src_path, dst_path)

print(download_dir)

Finished downloading
/data


In [7]:
# Evaluate the saved OurModel checkpoint on the test split.
# The reported metrics (accuracy, F1, ROC-AUC) do not depend on sample order, so the
# loader is not shuffled; this keeps evaluation deterministic.
test_loader = load_data('/data/test', batch_size=128, shuffle=False, transform=normalized_simple_transform(), num_workers=1)
ourmodel = OurModel(aux_enabled=False, se_squeeze=8)
# The file is loaded as a state_dict, so restrict unpickling to tensors/containers
# (weights_only=True) rather than allowing arbitrary pickled objects.
ourmodel.load_state_dict(torch.load("./saved_models/ourmodel/combined_20/OurModel_16.pth", map_location=device, weights_only=True))
# Move to the target device and switch to inference mode (dropout / batch-norm, if present);
# harmless if evaluate_model also calls eval().
ourmodel.to(device).eval()
results = evaluate_model(model=ourmodel, dataloader=test_loader, device=device, max_batches=None)
print(f"\n {results}")

0 10 20 30 40 50 60 70 80 90 100 110 120 130 140 150 160 170 180 190 200 210 220 230 240 250 260 270 280 290 300 310 320 330 340 350 360 370 380 390 400 410 420 430 440 450 460 470 480 490 500 510 520 530 540 550 560 570 580 590 600 610 620 630 640 650 660 670 680 690 700 
 {'accuracy': 0.7467555555555555, 'f1_score': 0.7449930276087703, 'roc_auc': np.float64(0.9699651569958847)}


In [5]:
# Evaluate the fine-tuned VGG16 checkpoint on the test split, using the transform
# for pretrained backbones. Constructing VGG16Pretrained() fetches ImageNet VGG16
# weights on first use (see the download line in this cell's output).
# Not shuffled: the reported metrics are order-invariant, and this keeps evaluation deterministic.
test_loader_pretrained = load_data('/data/test', batch_size=128, shuffle=False, transform=pretrained_transform(), num_workers=1)
vgg16 = VGG16Pretrained()
# State-dict checkpoint: weights_only=True avoids unpickling arbitrary objects.
vgg16.load_state_dict(torch.load("./saved_models/vgg_pretrained/horizontal_flip_20/VGG16Pretrained_20.pth", map_location=device, weights_only=True))
# Inference mode (dropout / batch-norm, if present); harmless if evaluate_model does it too.
vgg16.to(device).eval()
results_pretrained = evaluate_model(model=vgg16, dataloader=test_loader_pretrained, device=device)
print(f"\n {results_pretrained}")

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:04<00:00, 111MB/s]


0 10 20 30 40 50 60 70 80 90 100 110 120 130 140 150 160 170 180 190 200 210 220 230 240 250 260 270 280 290 300 310 320 330 340 350 360 370 380 390 400 410 420 430 440 450 460 470 480 490 500 510 520 530 540 550 560 570 580 590 600 610 620 630 640 650 660 670 680 690 700 {'accuracy': 0.7581333333333333, 'f1_score': 0.7569758427001803, 'roc_auc': np.float64(0.9724191121399176)}


In [4]:
# Evaluate the fine-tuned ResNet checkpoint on the test split, using the transform
# for pretrained backbones. This cell's output shows ResNetPretrained() fetching
# resnet18 ImageNet weights on first use.
# Not shuffled: the reported metrics are order-invariant, and this keeps evaluation deterministic.
test_loader_pretrained = load_data('/data/test', batch_size=128, shuffle=False, transform=pretrained_transform(), num_workers=1)
resnet = ResNetPretrained()
# State-dict checkpoint: weights_only=True avoids unpickling arbitrary objects.
resnet.load_state_dict(torch.load("./saved_models/resnet/final/ResNetPretrained_20.pth", map_location=device, weights_only=True))
# Inference mode (dropout / batch-norm, if present); harmless if evaluate_model does it too.
resnet.to(device).eval()
results_pretrained = evaluate_model(model=resnet, dataloader=test_loader_pretrained, device=device)
print(f"\n {results_pretrained}")

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 133MB/s]


0 10 20 30 40 50 60 70 80 90 100 110 120 130 140 150 160 170 180 190 200 210 220 230 240 250 260 270 280 290 300 310 320 330 340 350 360 370 380 390 400 410 420 430 440 450 460 470 480 490 500 510 520 530 540 550 560 570 580 590 600 610 620 630 640 650 660 670 680 690 700 
 {'accuracy': 0.7855888888888889, 'f1_score': 0.7844789385324462, 'roc_auc': np.float64(0.9786484316186558)}
