In [8]:
import torch
from torchvision import transforms as T
from lightly.transforms import SimCLRTransform, DINOTransform, MAETransform, MoCoV2Transform, utils
from datasets import create_dataset
from models import MAEModel
import pytorch_lightning as pl
import os
import copy
import gc
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import os
from pathlib import Path
from models import BYOLModel, SimCLRModel, SimpleMLP
import pandas as pd
from tqdm import tqdm
from torchmetrics.classification import Accuracy
from evaluation_whole import *
import evaluation_whole

In [11]:

# Development helper cell: pick up on-disk edits to the project modules without
# restarting the kernel.
import importlib

import eval  # NOTE: project module; this name shadows the builtin `eval` in the notebook
import evaluation_whole
import models

# Reload order is kept from the original cell: `models` first, then the two
# evaluation modules.
importlib.reload(models)
importlib.reload(evaluation_whole)
importlib.reload(eval)

# `importlib.reload` re-executes a module in place but does NOT update names that
# were previously copied out of it with `from ... import ...`. Re-bind those names
# AFTER reloading so the notebook uses the fresh definitions instead of stale ones.
from models import MAEModel, BYOLModel, SimCLRModel, SimpleMLP
from evaluation_whole import *
from eval import *

<module 'eval' from 'c:\\Users\\User\\Desktop\\studia\\Sem 6\\WB\\nowewb\\Warsztaty_Badawcze\\eval.py'>

In [6]:
# Number of target classes for every dataset name accepted by ``evaluate_model``.
_DATASET_NUM_CLASSES = {
    'CIFAR10': 10,
    'CIFAR100': 100,
    'flowers': 102,
    'pets': 37,
    # torchvision's FGVCAircraft defaults to the 'variant' annotation level,
    # which has 100 classes (102 is the Flowers102 class count).
    'aircraft': 100,
}


def _build_test_dataset(dataset_name, path_to_data):
    """Build the held-out dataset for ``dataset_name`` and return ``(dataset, num_classes)``."""
    # Validate the name first so an unsupported dataset fails fast, before any
    # torchvision import, transform construction or download is attempted.
    if dataset_name not in _DATASET_NUM_CLASSES:
        supported = ", ".join(f"'{name}'" for name in _DATASET_NUM_CLASSES)
        raise ValueError(f"Dataset {dataset_name} not supported. Use one of: {supported}")

    # Imported locally so this function does not depend on `torchvision` /
    # `transforms` happening to leak into the notebook via a star-import.
    import torchvision
    from torchvision import transforms

    # NOTE(review): the same normalisation constants are applied to every dataset;
    # they look like CIFAR-10 channel statistics — confirm they match the
    # preprocessing used when the checkpoints were trained.
    test_transform = transforms.Compose([
        transforms.Resize((224, 224)),  # Resize all images to 224x224
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])

    if dataset_name == 'CIFAR10':
        test_dataset = torchvision.datasets.CIFAR10(
            root=path_to_data, train=False, transform=test_transform, download=True
        )
    elif dataset_name == 'CIFAR100':
        test_dataset = torchvision.datasets.CIFAR100(
            root=path_to_data, train=False, transform=test_transform, download=True
        )
    elif dataset_name == 'flowers':
        test_dataset = torchvision.datasets.Flowers102(
            root=path_to_data, split='test', download=True, transform=test_transform
        )
    elif dataset_name == 'pets':
        full_dataset = torchvision.datasets.OxfordIIITPet(
            root=path_to_data, download=True,
            transform=test_transform,
            target_types='category'  # 'segmentation' is also available
        )
        # The evaluation set is the last 20% of the loaded split, taken by index.
        # NOTE(review): presumably mirrors the train/test split used during
        # training — verify against the training code.
        first_test_index = int(len(full_dataset) * 0.8)
        test_dataset = torch.utils.data.Subset(
            full_dataset, range(first_test_index, len(full_dataset))
        )
    else:  # 'aircraft' — the only remaining validated name
        test_dataset = torchvision.datasets.FGVCAircraft(
            root=path_to_data, download=True,
            split='test',
            transform=test_transform
        )

    return test_dataset, _DATASET_NUM_CLASSES[dataset_name]


def _load_model(model_path, device):
    """Load a checkpoint file and return the model object it describes."""
    print(f"Loading model from: {model_path}")
    # SECURITY NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)

    if hasattr(checkpoint, 'model'):
        # Wrapper object that carries the network in its `.model` attribute.
        return checkpoint.model

    # The isinstance guard matters: `'state_dict' in obj` raises TypeError for a
    # pickled nn.Module, which previously made the fallback below unreachable.
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        params = checkpoint['hyper_parameters']
        # NOTE(review): `ClassifierModel` is not defined in this notebook;
        # presumably it is provided by one of the star-imports — confirm.
        model = ClassifierModel(
            num_classes=params['num_classes'],
            lr=params['lr'],
            weight_decay=params['weight_decay'],
            max_epochs=params['max_epochs'],
            backbone_type=params['backbone_type']
        )
        model.load_state_dict(checkpoint['state_dict'])
        return model

    return checkpoint  # fallback: the file holds the model object itself


def evaluate_model(model_path, dataset_name, path_to_data='./data', batch_size=256, num_workers=4, device=None):
    """
    Evaluate a saved model on a held-out dataset using torchmetrics.

    Args:
        model_path (str): Path to the saved model checkpoint
        dataset_name (str): One of 'CIFAR10', 'CIFAR100', 'flowers', 'pets', 'aircraft'
        path_to_data (str): Path where the dataset will be downloaded/loaded from
        batch_size (int): Batch size for evaluation
        num_workers (int): Number of workers for DataLoader
        device (str/torch.device): Device to run evaluation on. If None, will auto-detect

    Returns:
        float: Test accuracy in percent (0-100)

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(device)

    print(f"Using device: {device}")

    test_dataset, num_classes = _build_test_dataset(dataset_name, path_to_data)

    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=(device.type == 'cuda')
    )

    model = _load_model(model_path, device)
    model = model.to(device)
    model.eval()

    # Initialize torchmetrics Accuracy
    accuracy_metric = Accuracy(task="multiclass", num_classes=num_classes).to(device)

    print(f"Evaluating on {dataset_name} test set...")

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Testing"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            preds = torch.argmax(outputs, dim=1)
            accuracy_metric.update(preds, labels)

    accuracy = accuracy_metric.compute().item() * 100  # convert to percent

    print(f"\nResults for {dataset_name}:")
    print(f"Test Accuracy: {accuracy:.2f}%")

    return accuracy

# Evaluate the four checkpoints stored under each dataset's `scratch` directory
# on the aircraft, flowers and pets test sets.
# `model_names` (not `models`) so the imported `models` module is not shadowed.
model_names = ['pre-2', 'pre-3', 'ran-2', 'ran-3']
data = ['aircraft', 'flowers', 'pets']
for dataset in data:
    acc = []  # accuracies of all checkpoints for the current dataset

    for model_name in model_names:
        dataset_acc = evaluate_model(
            model_path=f"./checkpoints/reszta/checkpoints_{dataset}/checkpoints_{dataset}/scratch/{model_name}/model_epoch_15.ckpt",
            dataset_name=dataset,
            path_to_data="./data"
        )
        # Label the line with the dataset actually evaluated (was hard-coded "CIFAR-10").
        print(f"{dataset} Accuracy: {dataset_acc:.2f}%, model: {model_name}")
        acc.append(dataset_acc)

    # Report the best checkpoint for this dataset, as the CIFAR-10 cell does.
    print(max(acc))


Using device: cuda


100%|██████████| 2.75G/2.75G [08:31<00:00, 5.39MB/s] 


Loading model from: ./checkpoints/reszta/checkpoints_aircraft/checkpoints_aircraft/scratch/pre-2/model_epoch_15.ckpt
Evaluating on aircraft test set...


Testing: 100%|██████████| 14/14 [01:20<00:00,  5.72s/it]



Results for aircraft:
Test Accuracy: 5.10%
CIFAR-10 Accuracy: 5.10%, model: pre-2
Using device: cuda
Loading model from: ./checkpoints/reszta/checkpoints_aircraft/checkpoints_aircraft/scratch/pre-3/model_epoch_15.ckpt
Evaluating on aircraft test set...


Testing: 100%|██████████| 14/14 [01:13<00:00,  5.24s/it]



Results for aircraft:
Test Accuracy: 4.11%
CIFAR-10 Accuracy: 4.11%, model: pre-3
Using device: cuda
Loading model from: ./checkpoints/reszta/checkpoints_aircraft/checkpoints_aircraft/scratch/ran-2/model_epoch_15.ckpt
Evaluating on aircraft test set...


Testing: 100%|██████████| 14/14 [01:16<00:00,  5.50s/it]



Results for aircraft:
Test Accuracy: 5.40%
CIFAR-10 Accuracy: 5.40%, model: ran-2
Using device: cuda
Loading model from: ./checkpoints/reszta/checkpoints_aircraft/checkpoints_aircraft/scratch/ran-3/model_epoch_15.ckpt
Evaluating on aircraft test set...


Testing: 100%|██████████| 14/14 [01:19<00:00,  5.70s/it]



Results for aircraft:
Test Accuracy: 7.47%
CIFAR-10 Accuracy: 7.47%, model: ran-3
7.470747083425522
Using device: cuda
Loading model from: ./checkpoints/reszta/checkpoints_flowers/checkpoints_flowers/scratch/pre-2/model_epoch_15.ckpt
Evaluating on flowers test set...


Testing: 100%|██████████| 25/25 [01:58<00:00,  4.73s/it]



Results for flowers:
Test Accuracy: 1.20%
CIFAR-10 Accuracy: 1.20%, model: pre-2
Using device: cuda
Loading model from: ./checkpoints/reszta/checkpoints_flowers/checkpoints_flowers/scratch/pre-3/model_epoch_15.ckpt
Evaluating on flowers test set...


Testing: 100%|██████████| 25/25 [01:53<00:00,  4.54s/it]



Results for flowers:
Test Accuracy: 66.64%
CIFAR-10 Accuracy: 66.64%, model: pre-3
Using device: cuda
Loading model from: ./checkpoints/reszta/checkpoints_flowers/checkpoints_flowers/scratch/ran-2/model_epoch_15.ckpt
Evaluating on flowers test set...


Testing: 100%|██████████| 25/25 [01:52<00:00,  4.51s/it]



Results for flowers:
Test Accuracy: 11.25%
CIFAR-10 Accuracy: 11.25%, model: ran-2
Using device: cuda
Loading model from: ./checkpoints/reszta/checkpoints_flowers/checkpoints_flowers/scratch/ran-3/model_epoch_15.ckpt
Evaluating on flowers test set...


Testing: 100%|██████████| 25/25 [01:52<00:00,  4.50s/it]



Results for flowers:
Test Accuracy: 14.64%
CIFAR-10 Accuracy: 14.64%, model: ran-3
66.64498448371887
Using device: cuda


100%|██████████| 792M/792M [02:17<00:00, 5.76MB/s] 
100%|██████████| 19.2M/19.2M [00:03<00:00, 5.10MB/s]
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x0000024D6747ECB0>
Traceback (most recent call last):
  File "c:\Users\User\Desktop\studia\Sem 6\WB\Warsztaty_Badawcze\.conda\lib\site-packages\torch\utils\data\dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "c:\Users\User\Desktop\studia\Sem 6\WB\Warsztaty_Badawcze\.conda\lib\site-packages\torch\utils\data\dataloader.py", line 1576, in _shutdown_workers
    if self._persistent_workers or self._workers_status[worker_id]:
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_workers_status'


Loading model from: ./checkpoints/reszta/checkpoints_pets/checkpoints_pets/scratch/pre-2/model_epoch_15.ckpt
Evaluating on pets test set...


Testing: 100%|██████████| 3/3 [00:32<00:00, 10.73s/it]



Results for pets:
Test Accuracy: 10.46%
CIFAR-10 Accuracy: 10.46%, model: pre-2
Using device: cuda
Loading model from: ./checkpoints/reszta/checkpoints_pets/checkpoints_pets/scratch/pre-3/model_epoch_15.ckpt
Evaluating on pets test set...


Testing: 100%|██████████| 3/3 [00:25<00:00,  8.55s/it]



Results for pets:
Test Accuracy: 70.79%
CIFAR-10 Accuracy: 70.79%, model: pre-3
Using device: cuda
Loading model from: ./checkpoints/reszta/checkpoints_pets/checkpoints_pets/scratch/ran-2/model_epoch_15.ckpt
Evaluating on pets test set...


Testing: 100%|██████████| 3/3 [00:28<00:00,  9.37s/it]



Results for pets:
Test Accuracy: 9.10%
CIFAR-10 Accuracy: 9.10%, model: ran-2
Using device: cuda
Loading model from: ./checkpoints/reszta/checkpoints_pets/checkpoints_pets/scratch/ran-3/model_epoch_15.ckpt
Evaluating on pets test set...


Testing: 100%|██████████| 3/3 [00:29<00:00,  9.69s/it]


Results for pets:
Test Accuracy: 15.08%
CIFAR-10 Accuracy: 15.08%, model: ran-3
70.7880437374115





In [88]:
# Evaluate the CIFAR-100 baseline checkpoints.
# `model_names` (not `models`) so the imported `models` module is not shadowed.
model_names = ['model1', 'model2', 'model3', 'model4']
acc = []
# Test on CIFAR-100
for model_name in model_names:
    cifar100_acc = evaluate_model(
        model_path=f"./checkpoints/cifar100/baseline/{model_name}/model_epoch_15.ckpt",
        dataset_name="CIFAR100",
        path_to_data="./data"
    )
    print(f"CIFAR-100 Accuracy: {cifar100_acc:.2f}%, model: {model_name}")
    acc.append(cifar100_acc)
# Report the best checkpoint, consistent with the CIFAR-10 cell.
print(max(acc))



Using device: cuda
Loading model from: ./checkpoints/cifar100/baseline/model1/model_epoch_15.ckpt
Evaluating on CIFAR100 test set...


Testing:   0%|          | 0/40 [00:03<?, ?it/s]


KeyboardInterrupt: 

In [83]:
# Evaluate the CIFAR-10 baseline checkpoints and report the best accuracy.
# `model_names` (not `models`) so the imported `models` module is not shadowed.
model_names = ['pre-2', 'pre-3', 'ran-2', 'ran-3']
acc = []
# Test on CIFAR-10
for model_name in model_names:
    cifar10_acc = evaluate_model(
        model_path=f"./checkpoints/cifar10/baseline/{model_name}/model_epoch_15.ckpt",
        dataset_name="CIFAR10",
        path_to_data="./data"
    )
    print(f"CIFAR-10 Accuracy: {cifar10_acc:.2f}%, model: {model_name}")
    acc.append(cifar10_acc)
print(max(acc))

Using device: cuda
Loading model from: ./checkpoints/cifar10/baseline/pre-2/model_epoch_15.ckpt
Evaluating on CIFAR10 test set...


Testing:   0%|          | 0/40 [00:07<?, ?it/s]


KeyboardInterrupt: 

In [None]:
# Call the three test helpers in sequence (BYOL, SimCLR, MAE, judging by their names).
# NOTE(review): these functions are not defined in this notebook; presumably they
# are provided by the star-imports of `evaluation_whole` / `eval` — confirm.
test_byol_imgnet()
test_simclr_imgnet()
test_mae_imgnet()
