In [None]:
import os
import json
import random
import argparse
import itertools
import math
import torch
import numpy as np
from torch import nn, optim
from torch.nn import functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR
from torchvision import transforms
from tqdm import tqdm

import utils

from data_utils import (
    CS4243dataset
)
from models import (
    ViT
)

def seed_everything(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global RNG, and PyTorch's CPU and
    CUDA generators with the same value. It also exports the seed as
    ``PYTHONHASHSEED`` and turns on deterministic cuDNN algorithms.

    Args:
        seed: Integer seed handed to every generator.
    """
    random.seed(seed)
    # PYTHONHASHSEED is read at interpreter start-up, so this only affects
    # child processes launched afterwards, not the current interpreter.
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True
    
def get_transform(is_train, image_size=512):
    """Build the torchvision preprocessing pipeline for CS4243 images.

    Args:
        is_train: When true, a random horizontal flip is appended for
            augmentation.
        image_size: Target size for ``transforms.Resize``. An ``int`` is
            expanded to a square ``(image_size, image_size)``. A
            ``(height, width)`` sequence is used as given. Defaults to 512,
            which was the previously hard-coded size.

    Returns:
        A ``transforms.Compose`` that resizes, optionally flips horizontally,
        and converts the image to a tensor.
    """
    # Resize(int) would only fix the shorter edge and keep the aspect ratio,
    # so expand an int to an explicit square to match the old (512, 512) call.
    if isinstance(image_size, int):
        image_size = (image_size, image_size)
    steps = [transforms.Resize(image_size)]
    if is_train:
        steps.append(transforms.RandomHorizontalFlip())
    steps.append(transforms.ToTensor())
    return transforms.Compose(steps)

def run(device, hps):
    """Build the CS4243 validation pipeline and a ViT, then evaluate it.

    Args:
        device: Target device for the model and batches. It is passed to
            ``Module.to`` and to ``evaluate``.
        hps: Hyper-parameter object. Reads these fields:
            ``hps.CS4243dataset.{dataset_path, label_dict, image_size,
            patch_size, num_classes}``, ``hps.train.batch_size``, and the
            ``hps.ViTmodel`` mapping of extra ``ViT`` keyword arguments.

    Returns:
        The ``(accuracy, loss)`` pair produced by ``evaluate`` on the
        validation loader.
    """
    valid_data = CS4243dataset(
        hps.CS4243dataset.dataset_path,
        hps.CS4243dataset.label_dict,
        is_train=False,
        transform=get_transform(False),
    )
    valid_loader = DataLoader(
        dataset=valid_data,
        batch_size=hps.train.batch_size,
        shuffle=False,
    )

    # NOTE(review): get_transform(False) resizes images to 512x512, while the
    # ViT below is built with hps.CS4243dataset.image_size. Confirm the two agree.
    vit = ViT(
        image_size=hps.CS4243dataset.image_size,
        patch_size=hps.CS4243dataset.patch_size,
        num_classes=hps.CS4243dataset.num_classes,
        **hps.ViTmodel
    ).to(device)
    # NOTE(review): no weights are restored here, so this evaluates a freshly
    # initialised model. Confirm whether a checkpoint should be loaded first.

    # loss function
    criterion = nn.CrossEntropyLoss()

    # Previously the objects built above were discarded and nothing ran.
    # Hand them to evaluate, whose signature matches them exactly.
    return evaluate(device, vit, criterion, valid_loader)
    
def evaluate(device, model, criterion, valid_loader):
    """Compute sample-weighted accuracy and mean loss of ``model`` on a loader.

    Args:
        device: Device each batch is moved to before the forward pass.
        model: ``nn.Module`` whose output holds per-class scores along dim 1
            (the prediction is the ``argmax`` over that dim).
        criterion: Loss callable ``criterion(output, label)``. It is assumed to
            average over the batch, as ``nn.CrossEntropyLoss`` does by default.
        valid_loader: Iterable of ``(data, label)`` batches.

    Returns:
        ``(accuracy, loss)`` as 0-dim tensors averaged over all samples, or
        ``(0, 0)`` when the loader yields no batches. The model's train/eval
        mode is restored on exit, including when an exception is raised.
    """
    was_training = model.training
    model.eval()
    total_correct = 0
    total_loss = 0
    n_samples = 0
    try:
        with torch.no_grad():
            for data, label in valid_loader:
                data, label = data.to(device), label.to(device)

                val_output = model(data)
                val_loss = criterion(val_output, label)

                # Weight each batch by its size so a short final batch does
                # not skew the averages (mean of per-batch means is biased).
                batch_size = label.size(0)
                total_correct += (val_output.argmax(dim=1) == label).sum()
                total_loss += val_loss * batch_size
                n_samples += batch_size
    finally:
        # Restore the caller's mode instead of always switching to train().
        model.train(was_training)

    if n_samples == 0:
        return 0, 0
    return total_correct.float() / n_samples, total_loss / n_samples