In [None]:
import os
import args
import logging
import torch
from datetime import datetime
from model import Classifier
from uuid import uuid4
from data import ProjectDataset
from torch.utils.data import DataLoader
from timm.utils import AverageMeter
from utils import *
from torch.utils.tensorboard import SummaryWriter

In [None]:
# Configure the root logger to emit INFO-and-above records, then grab the
# named logger ('Log') that the rest of the notebook writes to.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('Log')

In [None]:
def test(test_loader, model, criterion, device, eval_mode=True):
    """Run ``model`` over every batch of ``test_loader`` with gradients disabled.

    Args:
        test_loader: Iterable yielding ``(images, labels)`` batches.
        model: Module to evaluate; it is switched to ``eval()`` mode.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        device: ``torch.device`` (or device string) the batches are moved to.
        eval_mode: Selects which pair of values is returned (see Returns).

    Returns:
        ``(average_loss, accuracy_percent)`` when ``eval_mode`` is truthy,
        otherwise ``(accuracy_percent, metrics)`` where ``metrics`` is whatever
        ``MetricTracker.compute()`` returns.
    """
    model.eval()
    losses = AverageMeter()
    metrics_tracker = MetricTracker(args.num_classes).to(device)
    correct = 0
    total = 0

    # torch.autocast requires an explicit device_type; calling it with no
    # arguments raises TypeError. Mixed precision is only enabled on CUDA, and
    # a disabled CPU context is used elsewhere so that backends without
    # autocast support do not fail.
    device_type = torch.device(device).type
    amp_enabled = device_type == "cuda"
    amp_device_type = device_type if amp_enabled else "cpu"

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)

            with torch.autocast(device_type=amp_device_type, enabled=amp_enabled):
                outputs = model(images)
                loss = criterion(outputs, labels)

            losses.update(loss.item(), images.size(0))
            metrics_tracker.update(outputs, labels)
            # Index of the highest score along dim 1 is the predicted class.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    # Guard against an empty loader, which would otherwise divide by zero.
    test_accuracy = 100 * (correct / total) if total else 0.0
    metrics = metrics_tracker.compute()
    metrics_tracker.reset()
    if eval_mode:
        return losses.avg, test_accuracy
    return test_accuracy, metrics

In [None]:
def _build_split(mode, shuffle):
    """Create the dataset, its DataLoader and the sample count for one split."""
    dataset = ProjectDataset(mode=mode, root_dir='dataset', seed=args.seed)
    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=shuffle)
    return dataset, loader, len(loader.dataset)


# Splits are built in the same order as before: train, val, test.
# Only the training split is shuffled.
train_data, train_loader, train_len = _build_split('train', shuffle=True)
val_data, val_loader, val_len = _build_split('val', shuffle=False)
test_data, test_loader, test_len = _build_split('test', shuffle=False)

In [None]:
# Pick the compute device: CUDA first, then Apple MPS, otherwise CPU.
# The MPS backend attribute is looked up defensively because torch builds
# without it would raise AttributeError on `torch.backends.mps`.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Build the classifier and move its parameters to the selected device.
model = Classifier(num_classes=args.num_classes)
model.to(device)
# Class name of the module wrapped by the classifier (its `.model` attribute).
model_name = model.model.__class__.__name__

# Create the checkpoint directory if needed; exist_ok avoids the
# check-then-create race of os.path.exists() followed by os.mkdir().
os.makedirs("./checkpoint/", exist_ok=True)

In [None]:
# Split the parameters into the FC head (model.model.fc) and everything else.
# Membership is decided by parameter identity rather than by searching for the
# substring "fc" in the parameter name, so a parameter whose name merely
# contains "fc" is not silently left out of both groups (and never optimised).
fc_params = list(model.model.fc.parameters())
fc_param_ids = {id(param) for param in fc_params}
base_params = [  # Parameters in the model except FC layer
    param for param in model.parameters() if id(param) not in fc_param_ids
]

#   OPTIMISER
# The FC group overrides the learning rate with 0.01; the base group uses args.lr.
optimiser = torch.optim.Adam(
    params=[
        {"params": base_params},
        {"params": fc_params, "lr": 0.01},
    ],
    lr=args.lr,
    weight_decay=args.weight_decay,
)
#   SCHEDULER
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimiser, gamma=args.gamma)
#   LOSS
criterion = torch.nn.CrossEntropyLoss()
#   SCALER
# Gradient scaling is only enabled when running on CUDA; elsewhere the scaler
# is constructed disabled.
scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda"))

#   LOGGING
# One directory per run: wrapper class name, hyper-parameters, timestamp and a
# short random suffix to keep concurrent runs apart.
log_dir = (f"logs/runs/"
           f"{model.__class__.__name__}/"
           f"lr_{args.lr}_bs_{args.batch_size}/"
           f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}_"
           f"{uuid4().hex[:6]}")
# Create the directory explicitly so the FileHandler below does not depend on
# SummaryWriter having created it as a side effect.
os.makedirs(log_dir, exist_ok=True)
writer = SummaryWriter(log_dir)
handler = logging.FileHandler(f'{log_dir}/log.txt')
logger.addHandler(handler)

In [None]:
# Test-only run: load a saved checkpoint, evaluate it on the test split and
# log the resulting metrics.
# NOTE(review): this branch runs when args.eval_mode is falsy, and test() then
# returns (accuracy, metrics); confirm the flag's intended meaning in args.py.
try:
    if not args.eval_mode:
        if not args.model_path:
            raise ValueError("Model path not set in args.py")
        # map_location lets a checkpoint saved on another device (e.g. CUDA)
        # be loaded on whichever device was selected for this run.
        state_dict = torch.load(args.model_path, map_location=device)
        model.load_state_dict(state_dict)
        model.to(device)
        test_accuracy, test_metrics = test(test_loader, model, criterion, device, args.eval_mode)

        print_summary(logger, model_name, train_len, val_len, test_len)
        logger.info("=> Testing Results")
        logger.info("Top-1 Accuracy: {:.2f}%".format(test_metrics['top1_acc']))
        logger.info("Top-5 Accuracy: {:.2f}%".format(test_metrics['top5_acc']))
        logger.info("F1-Score: {:.2f}%".format(test_metrics['f1']))
finally:
    # Close the log file even when loading or evaluation raises.
    handler.close()