In [1]:
import pickle
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms

from models import *
from utils import progress_bar

In [2]:
# Load the unlabeled test set; `load_data` below reads the `data` global and
# indexes it with the key b"data".
# NOTE(review): pickle.load can execute arbitrary code embedded in the file —
# only load pickles from a trusted source.
pickle_file = open(r"./data/cifar_test_nolabel.pkl", "rb")
try:
    data = pickle.load(pickle_file)
finally:
    pickle_file.close()

In [3]:
def load_data(args, raw_data=None, transform=None):
    """Build a non-shuffling DataLoader over the unlabeled test images.

    Args:
        args: object exposing ``batch_size`` and ``num_workers`` attributes.
        raw_data: the unpickled test file; must map ``b"data"`` to an indexable
            collection of images. Defaults to the notebook-global ``data``
            (kept for backward compatibility with the original notebook flow).
        transform: callable applied to each image. Defaults to ``ToTensor``
            followed by per-channel ``Normalize``.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(images, labels)`` batches
        in dataset order; every label is the placeholder ``0`` because the
        test file carries no labels.
    """

    class CustomDataset(torch.utils.data.Dataset):
        """Wrap an indexable image collection; every sample gets dummy label 0."""

        def __init__(self, data, transform=None):
            self.data = data
            self.transform = transform

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            image = self.data[idx]
            if self.transform:
                image = self.transform(image)
            # Placeholder label: the pickle has no ground truth.
            return image, 0

    if raw_data is None:
        # Fall back to the global produced by the pickle-loading cell. Passing
        # raw_data explicitly avoids this hidden-state dependency.
        raw_data = data

    if transform is None:
        transform = transforms.Compose(
            [
                transforms.ToTensor(),
                # Per-channel mean/std; presumably must match the normalization
                # used at training time — TODO confirm against the train script.
                transforms.Normalize(
                    (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
                ),
            ]
        )

    custom_testset = CustomDataset(raw_data[b"data"], transform=transform)

    # shuffle=False keeps batches in dataset order, so a sample's position in
    # the output stream equals its index in raw_data[b"data"].
    return torch.utils.data.DataLoader(
        custom_testset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
    )

In [4]:
def _align_module_prefix(state_dict, model_keys):
    """Return ``state_dict`` with its ``module.`` key prefix matched to ``model_keys``.

    ``torch.nn.DataParallel`` prefixes every parameter name with ``module.``.
    If every checkpoint key has the prefix and no model key does, the prefix is
    stripped; in the opposite situation it is added. Otherwise the input is
    returned unchanged.
    """
    prefix = "module."
    model_keys = list(model_keys)
    ckpt_prefixed = all(k.startswith(prefix) for k in state_dict)
    model_prefixed = all(k.startswith(prefix) for k in model_keys)
    if ckpt_prefixed and not model_prefixed:
        return {k[len(prefix):]: v for k, v in state_dict.items()}
    if model_prefixed and not ckpt_prefixed:
        return {prefix + k: v for k, v in state_dict.items()}
    return state_dict


def _match_compatible_weights(state_dict, model_state):
    """Select checkpoint tensors whose key (modulo ``module.``) and shape match the model.

    Fallback for when a direct ``load_state_dict`` raises (e.g. a size
    mismatch): mismatched tensors are skipped instead of re-raising.
    """
    prefix = "module."
    compatible = {}
    for model_key, model_tensor in model_state.items():
        bare_key = model_key[len(prefix):] if model_key.startswith(prefix) else model_key
        for key in (model_key, bare_key, prefix + model_key):
            candidate = state_dict.get(key)
            if candidate is not None and getattr(candidate, "shape", None) == model_tensor.shape:
                compatible[model_key] = candidate
                break
    return compatible


def load_checkpoint(args, model_fn=None):
    """Build the model, move it to the best device and load checkpoint weights.

    Args:
        args: object exposing ``checkpoint`` (path to a ``torch.save`` file).
            The file is either ``{"net": state_dict, "acc": ..., "epoch": ...}``
            or a bare state_dict.
        model_fn: zero-argument factory for the network. Defaults to
            ``Resnet_Leaky``.

    Returns:
        ``(net, device)`` where ``device`` is ``"cuda"`` or ``"cpu"``. On CUDA
        the net is wrapped in ``torch.nn.DataParallel``.

    Raises:
        TypeError: if the checkpoint does not contain a state_dict mapping.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"device: {device}")

    print("==> build model...")
    net = (model_fn if model_fn is not None else Resnet_Leaky)()
    net = net.to(device)

    if device == "cuda":
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True

    print(f"==> load {args.checkpoint}...")
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    checkpoint = torch.load(args.checkpoint, map_location=device)

    wrapped = isinstance(checkpoint, dict) and "net" in checkpoint
    # A bare state_dict used to be silently ignored, leaving random weights.
    state_dict = checkpoint["net"] if wrapped else checkpoint
    if not isinstance(state_dict, dict):
        raise TypeError(
            f"expected a state_dict mapping in {args.checkpoint}, "
            f"got {type(state_dict).__name__}"
        )

    model_state = net.state_dict()
    try:
        result = net.load_state_dict(
            _align_module_prefix(state_dict, model_state.keys()), strict=False
        )
        loaded = len(model_state) - len(result.missing_keys)
        # strict=False tolerates key mismatches silently; surface them instead.
        if result.missing_keys:
            print(f"missing {len(result.missing_keys)} keys, e.g. {result.missing_keys[:5]}")
        if result.unexpected_keys:
            print(
                f"ignored {len(result.unexpected_keys)} unexpected keys, "
                f"e.g. {result.unexpected_keys[:5]}"
            )
    except RuntimeError as e:
        # load_state_dict raises RuntimeError on tensor size mismatches even
        # with strict=False; retry with only shape-compatible tensors.
        print(f"load fail: {e}")
        print("trial...")

        print("model key:")
        print(f"total: {len(model_state)}")
        print("keys:", list(model_state)[:5])

        print("Checkpoint key:")
        print(f"total: {len(state_dict)}")
        print("keys:", list(state_dict)[:5])

        compatible = _match_compatible_weights(state_dict, model_state)
        if compatible:
            net.load_state_dict(compatible, strict=False)
        loaded = len(compatible)

    print(f"load {loaded}/{len(model_state)} weights")
    if loaded == 0:
        print("WARNING: no weights were loaded; the model is randomly initialized.")

    if wrapped:
        best_acc = checkpoint.get("acc", 0)
        epoch = checkpoint.get("epoch", 0)
        print(f"load, best accuracy: {best_acc:.2f}%, epoch: {epoch}")

    return net, device

In [5]:

def test(args, output_path="prediction_results.csv"):
    """Predict a class for every test image and write the results to CSV.

    Args:
        args: configuration passed to ``load_data`` and ``load_checkpoint``.
        output_path: destination CSV; it gets a header ``ID,Labels`` followed
            by one ``index,predicted_class`` row per sample.

    Returns:
        list[int]: predicted class index per sample, in dataset order.
    """
    import csv

    testloader = load_data(args)
    net, device = load_checkpoint(args)
    net.eval()

    all_predictions = []
    with torch.no_grad():
        for batch_idx, (inputs, _) in enumerate(testloader):
            outputs = net(inputs.to(device))
            _, predicted = outputs.max(1)
            # One device->host transfer per batch instead of one .item() per sample.
            all_predictions.extend(predicted.cpu().tolist())

            progress_bar(
                batch_idx,
                len(testloader),
                "Processing: %d/%d" % (batch_idx + 1, len(testloader)),
            )

    with open(output_path, "w", newline="") as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["ID", "Labels"])
        # The ID is the sample's position in the dataset. This is valid because
        # load_data builds the loader with shuffle=False.
        writer.writerows(enumerate(all_predictions))

    print(f"Results saved to {output_path}")
    return all_predictions

In [6]:
class Args:
    """Run configuration for ``test``: DataLoader settings and checkpoint path."""

    def __init__(
        self,
        batch_size=128,
        num_workers=0,
        checkpoint=r"./checkpoint/ckpt9549_leaky1.pth",
    ):
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.checkpoint = checkpoint


# The class and the instance have distinct names. Previously `args = args()`
# replaced the class with its instance, so re-running this cell raised
# "'args' object is not callable".
args = Args()
test(args)

device: cuda
==> build model...
==> load ./checkpoint/ckpt9549_leaky1.pth...
load, best accuracy: 95.49%, epoch: 237
Results saved to prediction_results.csv
