In [None]:
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.utils.data import DataLoader
import ResNetModel as models
from dataloader import LeukemiaLoader

# Compute device for the whole script: CUDA when PyTorch can see a GPU, otherwise CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def test(model, model_name, dataloader=None):
    """Evaluate ``model`` on a dataloader, print its top-1 accuracy and return it.

    Args:
        model: classifier whose output has the class scores along dim 1
            (presumably ``(batch, num_classes)`` logits -- TODO confirm).
        model_name: name of the model; not used inside this function, kept so
            existing call sites keep working.
        dataloader: iterable of ``(data, label)`` batches that also exposes
            ``.dataset`` with a length. Defaults to the module-level
            ``test_dataloader`` created in the ``__main__`` section.

    Returns:
        Accuracy as a percentage (float).

    Raises:
        ValueError: if the dataloader's dataset is empty.
    """
    if dataloader is None:
        # Backward-compatible fallback to the global built under ``__main__``.
        dataloader = test_dataloader

    total = len(dataloader.dataset)
    if total == 0:
        raise ValueError('cannot compute accuracy on an empty dataset')

    correct = 0
    model.eval()
    # Gradients are never needed during evaluation; disable autograd once for the whole pass.
    with torch.no_grad():
        for data, label in tqdm(dataloader):
            data = data.to(device, dtype=torch.float)
            label = label.to(device, dtype=torch.long)
            predicts = model(data)
            # Predicted class = index of the highest score along dim 1.
            correct += predicts.argmax(dim=1).eq(label).sum().item()

    test_acc = 100.0 * correct / total
    print(f'Valid acc: {test_acc:.3f}')
    # Return cached, unused CUDA memory before the caller loads the next model
    # (a no-op when CUDA was never initialised).
    torch.cuda.empty_cache()
    return test_acc

if __name__ == '__main__':
    all_model = ['ResNet18', 'ResNet50', 'ResNet152']
    # Explicit constructor per architecture: an unknown name raises KeyError
    # instead of silently falling back to ResNet152.
    model_builders = {
        'ResNet18': models.ResNet18,
        'ResNet50': models.ResNet50,
        'ResNet152': models.ResNet152,
    }

    batch_size = 32
    # The validation split is identical for every model, so build the loader once.
    # Must stay a module-level name: test() reads the global ``test_dataloader``.
    test_dataloader = DataLoader(LeukemiaLoader('', 'valid'), batch_size=batch_size, shuffle=False)

    for model_name in all_model:
        model = model_builders[model_name]()

        print(f'{model_name}:')
        # map_location remaps stored tensors onto the current device, so a
        # checkpoint saved from a GPU also loads on a CPU-only machine.
        state_dict = torch.load(f'./{model_name}.pt', map_location=device)
        model.load_state_dict(state_dict)
        model = model.to(device)
        test(model, model_name)
        print('----------------------------\n')