# Cycle Self-Training (CST) Implementation on SVHN (Source) -> MNIST (Target) Dataset

In [1]:
import random
import time
import warnings
import sys
import argparse
import shutil
import os.path as osp
import os
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torchvision.transforms as T
import torch.nn.functional as F

In [2]:
sys.path.append(os.path.dirname(os.getcwd()))
import common.vision.datasets as datasets
import common.vision.models as models
from common.vision.transforms import ResizeImage
from common.utils.data import ForeverDataIterator
from common.utils.metric import accuracy, ConfusionMatrix
from common.utils.meter import AverageMeter, ProgressMeter
from common.utils.logger import CompleteLogger
from common.utils.analysis import collect_feature, tsne, a_distance
from common.tools.randaugment import rand_augment_transform, GaussianBlur
from common.tools.fix_utils import ImageClassifier
from common.tools.sam import SAM

In [3]:
# Prefer the GPU whenever PyTorch can see one; otherwise everything runs on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using Device:", device)

# Channel means in [0, 1]; `img_mean` below is the same triple rescaled to 0-255 integers.
rgb_mean = (0.485, 0.456, 0.406)
# Settings handed to rand_augment_transform when the strong augmentation is built.
ra_params = {
    "translate_const": int(224 * 0.45),
    "img_mean": tuple(min(255, round(255 * channel)) for channel in rgb_mean),
}

Using Device: cuda


In [4]:
# Normalisation statistics (mean, std) for the two digit datasets.
# The MNIST tuples hold a single value, the SVHN tuples one value per RGB channel.
MNIST_MEAN, MNIST_STD = (0.1307,), (0.3081,)
SVHN_MEAN, SVHN_STD = (0.4377, 0.4438, 0.4728), (0.1980, 0.2010, 0.1970)

def get_transforms(dataset_name, crop=True):
    """Build the single-view preprocessing pipeline for ``dataset_name``.

    Every pipeline is: resize to 256 -> crop to 224 -> random horizontal flip
    -> tensor conversion -> per-dataset normalisation. Only the normalisation
    statistics (and, for MNIST, a channel-replication step) differ.

    Args:
        dataset_name: Matched case-insensitively against ``"mnist"`` and
            ``"svhn"``; any other name falls back to the ImageNet statistics.
        crop: ``True`` selects ``T.CenterCrop(224)``, ``False`` selects
            ``T.RandomResizedCrop(224)``.

    Returns:
        A ``T.Compose`` wrapping the assembled steps.

    NOTE(review): the horizontal flip is random regardless of ``crop``, so the
    pipeline is never fully deterministic.
    """
    key = dataset_name.lower()

    # Steps that only some datasets need between ToTensor and Normalize.
    extra_steps = []
    if key == "mnist":
        normalize = T.Normalize(mean=[0.1307]*3, std=[0.3081]*3)
        # Replicate the tensor along dim 0 so it matches the 3-value statistics.
        extra_steps.append(T.Lambda(lambda x: x.repeat(3,1,1)))
    elif key == "svhn":
        normalize = T.Normalize(mean=SVHN_MEAN, std=SVHN_STD)
    else:
        normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    cropper = T.CenterCrop(224) if crop else T.RandomResizedCrop(224)
    steps = [ResizeImage(256), cropper, T.RandomHorizontalFlip(), T.ToTensor()]
    steps.extend(extra_steps)
    steps.append(normalize)
    return T.Compose(steps)


In [5]:
class TransformFixMatch(object):
    """Callable that maps one image to a ``(weak, strong)`` pair of augmented views.

    Both views share the same geometry (resize 256, center crop 224, random
    horizontal flip) and the same per-dataset normalisation. The strong view
    additionally applies colour jitter and RandAugment before tensor conversion.

    Args:
        dataset_name: Matched case-insensitively against ``"mnist"`` and
            ``"svhn"``; any other value uses the ImageNet statistics.
    """

    def __init__(self, dataset_name="default"):
        key = dataset_name.lower()

        # Steps inserted between ToTensor and Normalize (MNIST only).
        channel_fix = []
        if key == "mnist":
            normalize = T.Normalize(mean=[0.1307]*3, std=[0.3081]*3)
            channel_fix = [T.Lambda(lambda x: x.repeat(3,1,1))]
        elif key == "svhn":
            normalize = T.Normalize(mean=SVHN_MEAN, std=SVHN_STD)
        else:
            normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        def geometry():
            # Fresh transform objects for each view.
            return [ResizeImage(256), T.CenterCrop(224), T.RandomHorizontalFlip()]

        def to_normalised_tensor():
            return [T.ToTensor(), *channel_fix, normalize]

        heavy_augment = [
            T.RandomApply([T.ColorJitter(0.4, 0.4, 0.4, 0.0)], p=1.0),
            rand_augment_transform('rand-n{}-m{}-mstd0.5'.format(2, 10), ra_params),
        ]

        self.weak = T.Compose(geometry() + to_normalised_tensor())
        self.strong = T.Compose(geometry() + heavy_augment + to_normalised_tensor())

    def __call__(self, x):
        """Return ``(weak_view, strong_view)`` of ``x``; the weak view is computed first."""
        weak_view = self.weak(x)
        strong_view = self.strong(x)
        return weak_view, strong_view


In [6]:
def entropy(predictions: torch.Tensor, reduction='none') -> torch.Tensor:
    """Row-wise entropy ``-sum(p * log(p + 1e-5))`` taken over dim 1.

    Args:
        predictions: Tensor whose dim 1 is summed over (presumably class
            probabilities, one row per sample).
        reduction: ``'mean'`` averages the per-row values into a scalar; any
            other value returns the per-row tensor unchanged.

    Returns:
        The per-row entropies, or their mean when ``reduction == 'mean'``.
    """
    eps = 1e-5  # keeps log() finite where an entry is exactly zero
    per_row = torch.sum(-predictions * torch.log(predictions + eps), dim=1)
    return per_row.mean() if reduction == 'mean' else per_row

class TsallisEntropy(nn.Module):
    """Entropy-reweighted Tsallis-entropy criterion computed from raw logits.

    Args:
        temperature: Divisor applied to the logits before the softmax.
        alpha: Exponent applied to the probabilities; the result is scaled by
            ``1 / (alpha - 1)``.
    """

    def __init__(self, temperature: float, alpha: float):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha

    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        """Return the scalar loss for a 2-D ``logits`` tensor (rows are samples)."""
        num_samples, _ = logits.shape
        probs = F.softmax(logits / self.temperature, dim=1)

        # Per-sample weights 1 + exp(-H); detached so they act as constants,
        # then rescaled so they sum to the number of samples.
        weights = 1 + torch.exp(-entropy(probs).detach())
        weights = (num_samples * weights / torch.sum(weights)).unsqueeze(dim=1)

        # Weighted probability mass per class, shaped (1, C) for broadcasting.
        class_mass = torch.sum(probs * weights, dim=0).unsqueeze(dim=0)
        per_sample = torch.sum(probs ** self.alpha / class_mass * weights, dim=-1)

        scale = 1 / (self.alpha - 1)
        return scale * torch.sum(1 / torch.mean(class_mass) - per_sample)


In [7]:
def main(args: argparse.Namespace):
    """Run one experiment in the mode selected by ``args.phase``.

    Builds the data loaders, the classifier, the SAM-wrapped SGD optimizer and
    the learning-rate schedule, then:

    * ``phase == 'analysis'``: loads the best checkpoint, saves a t-SNE plot of
      source/target features, prints the A-distance and returns.
    * ``phase == 'test'``: loads the best checkpoint, prints its accuracy on
      the test loader and returns.
    * otherwise: trains for ``min(args.epochs, args.early)`` epochs, keeping
      the latest and the best checkpoint, then evaluates the best one.

    Args:
        args: Experiment configuration. ``args.num_cls`` is written here so
            that ``train`` can read the number of classes.

    Returns:
        None. Nothing built here (classifier, loaders, logger) is handed back
        to the caller.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    # cuDNN autotuning can select a different convolution algorithm on every
    # run, which defeats the deterministic setting requested by seeding, so it
    # is only enabled for unseeded runs.
    cudnn.benchmark = args.seed is None

    # Data loading code
    train_transform = get_transforms(args.data, crop=args.center_crop)
    unlabeled_transform = TransformFixMatch(args.data)
    val_transform = get_transforms(args.data, crop=True)

    dataset = datasets.__dict__[args.data]
    train_source_dataset = dataset(root=args.root, task=args.source, download=True, transform=train_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    # The target training set yields (weak, strong) view pairs.
    train_target_dataset = dataset(root=args.root, task=args.target, download=True, transform=unlabeled_transform)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_dataset = dataset(root=args.root, task=args.target, download=True, transform=val_transform)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    if args.data == 'DomainNet':
        test_dataset = dataset(root=args.root, task=args.target, split='test', download=True, transform=val_transform)
        test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    else:
        # No separate test split: validation and test share one loader.
        test_loader = val_loader

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = models.__dict__[args.arch](pretrained=True)
    num_classes = train_source_dataset.num_classes
    args.num_cls = num_classes
    classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim).to(device)

    # define optimizer and lr scheduler
    base_optimizer = SGD
    optimizer = SAM(classifier.get_parameters(), base_optimizer, lr = args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay, adaptive = True, rho = args.rho)
    # LambdaLR multiplies each parameter group's initial lr by this factor.
    lr_scheduler = LambdaLR(optimizer, lambda x:  args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # define loss function
    ts_loss = TsallisEntropy(temperature=args.temperature, alpha = args.alpha)

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analyse the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = validate(test_loader, classifier, args)
        print(acc1)
        return

    # start training
    best_acc1 = 0.
    for epoch in range(min(args.epochs, args.early)):
        print("lr:", lr_scheduler.get_last_lr()[0])
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, ts_loss, optimizer,
              lr_scheduler, epoch, args)

        # evaluate on validation set
        acc1 = validate(val_loader, classifier, args)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))

    # evaluate on test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = validate(test_loader, classifier, args)
    print("test_acc1 = {:3.1f}".format(acc1))

    logger.close()


In [8]:
def _cst_forward_losses(model, ts, x, x_t_u, labels_s, args):
    """Forward both inputs through ``model`` and build every loss term of one SAM half-step.

    Args:
        model: Classifier returning ``(logits, features)``.
        ts: Criterion applied to the weak-view target logits.
        x: Source batch and weakly augmented target batch concatenated on
            dim 0, source first, in equal halves.
        x_t_u: Strongly augmented target batch.
        labels_s: Ground-truth labels of the source half.
        args: Reads ``threshold``, ``num_cls``, ``trade_off``, ``trade_off1``
            and ``trade_off3``.

    Returns:
        ``(y_s, cls_loss, transfer_loss, reverse_loss, fix_loss, total_loss)``
        where ``y_s`` are the source logits.
    """
    y, f = model(x)
    y_t_u, _ = model(x_t_u)

    f_s, f_t = f.chunk(2, dim=0)
    y_s, y_t = y.chunk(2, dim=0)

    # Pseudo-labels come from the weak view; `dim` is explicit because the
    # implicit-dim form of softmax is deprecated and emits a warning.
    max_prob, pred_u = torch.max(F.softmax(y_t, dim=-1), dim=-1)
    confident = max_prob.ge(args.threshold).float().detach()
    # Strong-view predictions are pulled towards the pseudo-labels, counting
    # only samples whose weak-view confidence reaches the threshold.
    fix_loss = (F.cross_entropy(y_t_u, pred_u, reduction='none') * confident).mean()

    # Cycle term: fit a ridge-regularised linear-kernel regressor on the
    # L2-normalised target features with centred one-hot pseudo-labels, then
    # score its predictions on the source features against the true labels.
    f_t_unit = f_t / torch.norm(f_t, dim=-1).reshape(f_t.shape[0], 1)
    f_s_unit = f_s / torch.norm(f_s, dim=-1).reshape(f_s.shape[0], 1)
    target_kernel = torch.clamp(f_t_unit.mm(f_t_unit.t()), -0.99999999, 0.99999999)
    cross_kernel = torch.clamp(f_s_unit.mm(f_t_unit.t()), -0.99999999, 0.99999999)
    centering = 1 / float(args.num_cls)
    pseudo_targets = F.one_hot(pred_u, args.num_cls) - centering
    # The identity is created on the kernel's own device and sized from the
    # kernel itself, so the step also runs on CPU.
    ridge = 0.001 * torch.eye(target_kernel.shape[0], device=target_kernel.device,
                              dtype=target_kernel.dtype)
    source_pred = cross_kernel.mm(torch.inverse(target_kernel + ridge)).mm(pseudo_targets)
    reverse_loss = F.mse_loss(source_pred, F.one_hot(labels_s, args.num_cls) - centering)

    cls_loss = F.cross_entropy(y_s, labels_s)
    transfer_loss = ts(y_t)

    # The pseudo-label term is added unconditionally: when it is exactly zero
    # it contributes neither value nor gradient, and not branching on a tensor
    # avoids a host/device synchronisation.
    total_loss = (cls_loss + transfer_loss * args.trade_off
                  + reverse_loss * args.trade_off1 + fix_loss * args.trade_off3)
    return y_s, cls_loss, transfer_loss, reverse_loss, fix_loss, total_loss


def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model: ImageClassifier, ts: TsallisEntropy, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Train ``model`` for ``args.iters_per_epoch`` iterations.

    Each iteration draws one labelled source batch and one (weak, strong)
    target batch, then performs a two-pass update: losses are computed and
    back-propagated, ``optimizer.first_step`` is applied, the losses are
    recomputed on the same batch and ``optimizer.second_step`` applies the
    actual update.

    Args:
        train_source_iter: Yields ``(images, labels)`` for the source domain.
        train_target_iter: Yields ``((weak, strong), _)`` for the target domain.
        model: Classifier returning ``(logits, features)``.
        ts: Transfer criterion applied to the target logits.
        optimizer: Must expose ``first_step`` / ``second_step``; ``main``
            passes a ``SAM`` instance despite the ``SGD`` annotation.
        lr_scheduler: Stepped once per iteration, between the two passes.
        epoch: Index used only for the progress prefix.
        args: Reads ``iters_per_epoch``, ``print_freq`` and the loss settings.
    """
    batch_time = AverageMeter('Time', ':3.1f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    trans_losses = AverageMeter('Trans Loss', ':3.2f')
    rev_losses = AverageMeter('CST Loss', ':3.2f')
    fix_losses = AverageMeter('Fix Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')

    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, trans_losses, rev_losses, fix_losses, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        (x_t, x_t_u), _ = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        x_t_u = x_t_u.to(device)
        labels_s = labels_s.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # Source and weak target views share one forward pass.
        x = torch.cat((x_s, x_t), dim=0)

        # First pass: its statistics are the ones that get logged.
        y_s, cls_loss, transfer_loss, reverse_loss, Lu, loss = _cst_forward_losses(
            model, ts, x, x_t_u, labels_s, args)

        cls_acc = accuracy(y_s, labels_s)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        trans_losses.update(transfer_loss.item(), x_s.size(0))
        rev_losses.update(reverse_loss.item(), x_s.size(0))
        fix_losses.update(Lu.item(), x_s.size(0))

        # compute gradient and do the first optimizer step
        loss.backward()
        optimizer.first_step(zero_grad=True)
        lr_scheduler.step()

        # Second pass on the same batch, followed by the second optimizer step.
        *_, loss1 = _cst_forward_losses(model, ts, x, x_t_u, labels_s, args)
        loss1.backward()
        optimizer.second_step(zero_grad=True)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)


In [9]:
def validate(val_loader: DataLoader, model: ImageClassifier, args: argparse.Namespace) -> float:
    """Evaluate ``model`` on ``val_loader`` and return the average top-1 accuracy.

    Progress is printed every ``args.print_freq`` batches, followed by a
    summary line. When ``args.per_class_eval`` is set, a confusion matrix is
    accumulated and printed as well.

    Args:
        val_loader: Yields ``(images, target)`` batches.
        model: Classifier returning ``(logits, features)``.
        args: Reads ``per_class_eval`` and ``print_freq``.

    Returns:
        ``top1.avg`` of the accuracy meter.
    """
    time_meter = AverageMeter('Time', ':6.3f')
    loss_meter = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [time_meter, loss_meter, top1, top5],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    # The confusion matrix is optional bookkeeping.
    confmat = None
    if args.per_class_eval:
        classes = val_loader.dataset.classes
        confmat = ConfusionMatrix(len(classes))

    with torch.no_grad():
        tick = time.time()
        for step, (images, target) in enumerate(val_loader):
            images, target = images.to(device), target.to(device)

            # forward pass; the feature output is not needed here
            logits, _ = model(images)
            batch_loss = F.cross_entropy(logits, target)

            # accuracy and loss bookkeeping, weighted by batch size
            acc1, acc5 = accuracy(logits, target, topk=(1, 5))
            if confmat:
                confmat.update(target, logits.argmax(1))
            n = images.size(0)
            loss_meter.update(batch_loss.item(), n)
            top1.update(acc1.item(), n)
            top5.update(acc5.item(), n)

            # wall-clock time per batch
            now = time.time()
            time_meter.update(now - tick)
            tick = time.time()

            if step % args.print_freq == 0:
                progress.display(step)

        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))
        if confmat:
            print(confmat.format(classes))

    return top1.avg


In [10]:
from types import SimpleNamespace

# Lower-case, non-dunder callables exposed by the `models` module.
architecture_names = sorted(
    name
    for name, obj in models.__dict__.items()
    if name.islower() and not name.startswith("__") and callable(obj)
)
# Non-dunder callables exposed by the `datasets` module.
dataset_names = sorted(
    name
    for name, obj in datasets.__dict__.items()
    if not name.startswith("__") and callable(obj)
)

# Experiment configuration consumed by main(), train() and validate().
hyperparameters = {
    "root": './data',           # dataset root directory
    "data": 'SVHN',             # dataset class looked up in `datasets`; also selects the transforms
    "source": 'svhn',           # source-domain task
    "target": 'mnist',          # target-domain task
    "center_crop": False,       # forwarded as `crop` for the source training transform
    "arch": 'resnet18',         # backbone constructor looked up in `models`
    "bottleneck_dim": 256,      # passed to ImageClassifier
    "temperature": 2.0,         # softmax temperature of TsallisEntropy
    "alpha": 1.9,               # exponent of TsallisEntropy
    "trade_off": 0.08,          # weight of the transfer (Tsallis) loss
    "trade_off1": 0.5,          # weight of the reverse (CST) loss
    "trade_off3": 0.5,          # weight of the pseudo-label loss
    "threshold": 0.97,          # minimum confidence for a pseudo-label to count
    "rho": 0.5,                 # passed to SAM
    "batch_size": 28,           # per-domain batch size
    "lr": 0.005,                # learning-rate scale used by the optimizer and schedule
    "lr_gamma": 0.001,          # schedule: (1 + lr_gamma * step) ** (-lr_decay)
    "lr_decay": 0.75,
    "momentum": 0.9,
    "weight_decay": 1e-3,
    "workers": 2,               # DataLoader worker processes
    "epochs": 15,               # training runs for min(epochs, early) epochs
    "early": 15,
    "iters_per_epoch": 1000,
    "print_freq": 100,          # progress is displayed every print_freq iterations
    "seed": None,               # None leaves the RNGs unseeded
    "per_class_eval": True,     # print a confusion matrix during validation
    "log": 'logs',              # passed to CompleteLogger
    "phase": 'train',           # 'train', 'test' or 'analysis'
}
args = SimpleNamespace(**hyperparameters)

print("Available architectures:", architecture_names)
print("Arguments set for CST experiment (SVHN -> MNIST):")
for key, value in vars(args).items():
    print(f"{key} = {value}")


Available architectures: ['resnet101', 'resnet152', 'resnet18', 'resnet34', 'resnet50', 'resnext101_32x8d', 'resnext50_32x4d', 'wide_resnet101_2', 'wide_resnet50_2']
Arguments set for CST experiment (SVHN -> MNIST):
root = ./data
data = SVHN
source = svhn
target = mnist
center_crop = False
arch = resnet18
bottleneck_dim = 256
temperature = 2.0
alpha = 1.9
trade_off = 0.08
trade_off1 = 0.5
trade_off3 = 0.5
threshold = 0.97
rho = 0.5
batch_size = 28
lr = 0.005
lr_gamma = 0.001
lr_decay = 0.75
momentum = 0.9
weight_decay = 0.001
workers = 2
epochs = 15
early = 15
iters_per_epoch = 1000
print_freq = 100
seed = None
per_class_eval = True
log = logs
phase = train


In [None]:
# Launch the experiment configured in `args`; with phase='train' main() trains,
# then reloads the best checkpoint and evaluates it on the test loader.
if __name__ == '__main__':
    main(args)

namespace(root='./data', data='SVHN', source='svhn', target='mnist', center_crop=False, arch='resnet18', bottleneck_dim=256, temperature=2.0, alpha=1.9, trade_off=0.08, trade_off1=0.5, trade_off3=0.5, threshold=0.97, rho=0.5, batch_size=28, lr=0.005, lr_gamma=0.001, lr_decay=0.75, momentum=0.9, weight_decay=0.001, workers=2, epochs=15, early=15, iters_per_epoch=1000, print_freq=100, seed=None, per_class_eval=True, log='logs', phase='train')
=> using pre-trained model 'resnet18'
lr: 0.0005
  max_prob, pred_u = torch.max(F.softmax(y_t), dim=-1)
  max_prob, pred_u = torch.max(F.softmax(y_t), dim=-1)
Epoch: [0][   0/1000]	Time 1.3 (1.3)	Data 0.0 (0.0)	Loss 3.21 (3.21)	Trans Loss 9.69 (9.69)	CST Loss 0.12 (0.12)	Fix Loss 0.00 (0.00)	Cls Acc 10.7 (10.7)
Epoch: [0][ 100/1000]	Time 0.1 (0.1)	Data 0.0 (0.0)	Loss 2.81 (3.00)	Trans Loss 9.65 (9.66)	CST Loss 0.11 (0.12)	Fix Loss 0.00 (0.00)	Cls Acc 25.0 (23.2)
Epoch: [0][ 200/1000]	Time 0.1 (0.1)	Data 0.0 (0.0)	Loss 2.79 (2.88)	Trans Loss 9.49 (9.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

def generate_confusion_matrix(model, test_loader, device):
    """Plot and return the confusion matrix of ``model`` over ``test_loader``.

    Args:
        model: Classifier returning ``(logits, features)``.
        test_loader: Yields ``(images, target)`` batches with tensor targets.
        device: Device the images are moved to before the forward pass.

    Returns:
        The confusion matrix computed by ``sklearn.metrics.confusion_matrix``.
        Rows and columns follow the sorted set of labels that occur in the
        targets or the predictions.
    """
    model.eval()

    all_targets = []
    all_predictions = []

    with torch.no_grad():
        for images, target in test_loader:
            output, _ = model(images.to(device))
            # Labels are only compared on the host, so they are not moved to
            # `device` and back.
            all_targets.extend(target.tolist())
            all_predictions.extend(output.argmax(dim=1).tolist())

    # Derive the axis labels from the data instead of assuming ten classes:
    # a fixed list would not match the matrix when some class never appears.
    labels = sorted(set(all_targets) | set(all_predictions))
    cm = confusion_matrix(all_targets, all_predictions, labels=labels or None)

    fig, ax = plt.subplots(figsize=(10, 10))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm,
                                  display_labels=labels or None)
    disp.plot(ax=ax, cmap=plt.cm.Blues, values_format='d')
    plt.title('Confusion Matrix for SVHN → MNIST Transfer')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.tight_layout()
    plt.show()

    return cm

# NOTE(review): `classifier`, `logger` and `test_loader` are created inside
# main() and are not returned by it, so they must be bound at notebook scope
# before this cell can run.
classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))

# generate_confusion_matrix returns the matrix alone, so it is bound to a
# single name and the per-class accuracy is derived from it here.
confusion_mat = generate_confusion_matrix(classifier, test_loader, device)
samples_per_class = confusion_mat.sum(axis=1)
# Diagonal = correctly classified samples; guard against classes with no samples.
class_accuracies = confusion_mat.diagonal() / np.maximum(samples_per_class, 1)

In [None]:
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix

def test_and_print_metrics(model, test_loader, device):
    """Run ``model`` over ``test_loader`` and print a text evaluation report.

    Prints, in order: the first (up to) 30 ground-truth/prediction pairs, the
    scikit-learn classification report with four digits, and the confusion
    matrix. Nothing is returned.

    Args:
        model: Classifier returning ``(logits, features)``.
        test_loader: Yields ``(images, targets)`` batches.
        device: Device the batches are moved to before the forward pass.
    """
    model.eval()
    truth_buffer, guess_buffer = [], []

    with torch.no_grad():
        for images, targets in test_loader:
            images, targets = images.to(device), targets.to(device)
            logits, _ = model(images)
            # index of the largest logit along dim 1
            guesses = torch.max(logits, 1)[1]
            truth_buffer.extend(targets.cpu().numpy())
            guess_buffer.extend(guesses.cpu().numpy())

    y_true = np.array(truth_buffer)
    y_pred = np.array(guess_buffer)

    # Side-by-side peek at the first few samples.
    print("Idx\tGround Truth\tPredicted")
    for idx, (truth, guess) in enumerate(zip(y_true[:30], y_pred[:30])):
        print(f"{idx}\t{truth}\t\t{guess}")

    print("\nClassification Report:")
    print(classification_report(y_true, y_pred, digits=4))

    print("Confusion Matrix:")
    print(confusion_matrix(y_true, y_pred))

# Reload the best checkpoint and print the text metrics for it.
# NOTE(review): `classifier`, `logger` and `test_loader` are locals of main()
# and are not returned by it; they must be bound at notebook scope first.
classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
test_and_print_metrics(classifier, test_loader, device)