In [None]:
pip install tensorboardX

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting tensorboardX
  Downloading tensorboardX-2.6-py2.py3-none-any.whl (114 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m114.5/114.5 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tensorboardX
Successfully installed tensorboardX-2.6


In [None]:
import argparse
import os
from time import time

import numpy as np
import torch
import torch.nn as nn
import torchvision

from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
from torchvision import datasets
from tqdm import tqdm

In [None]:
# Mount Google Drive and move the kernel's working directory into the project's
# `module` folder so the next cell's `from backbone import ...` /
# `from relationship_learning import ...` resolve against it.
from google.colab import drive
drive.mount('/content/drive')
%cd '/content/drive/MyDrive/Colab Notebooks/CoTuning-main/module'


Mounted at /content/drive
/content/drive/MyDrive/Colab Notebooks/CoTuning-main/module


In [None]:
from backbone import ResNet50_F, ResNet50_C
from relationship_learning import relationship_learning

In [None]:
# Switch to `utils` for the transforms/tools imports. The kernel stays here for the
# rest of the notebook, so relative paths used later (relationship_path, save_dir,
# visual_dir) resolve under CoTuning-main/utils.
%cd '/content/drive/MyDrive/Colab Notebooks/CoTuning-main/utils'

/content/drive/MyDrive/Colab Notebooks/CoTuning-main/utils


In [None]:
from transforms import get_transforms
from tools import AccuracyMeter, TenCropsTest

In [None]:
# Sanity-check the project imports from the two %cd'd directories. A run of bare
# names only displays the last one; this checks every name explicitly.
assert all(callable(obj) for obj in (
    ResNet50_F, ResNet50_C, relationship_learning,
    get_transforms, AccuracyMeter, TenCropsTest,
)), 'a project import did not resolve to a callable object'

<function tools.TenCropsTest(loader, net)>

In [None]:

def get_writer(log_dir):
    """Create a tensorboardX SummaryWriter that logs into ``log_dir``.

    The caller owns the returned writer and should ``close()`` it when done.
    """
    writer = SummaryWriter(log_dir)
    return writer



In [None]:


# ---- run / schedule ---------------------------------------------------------
gpu = 0             # CUDA device index passed to torch.cuda.set_device in main()
seed = 2020         # seed passed to set_seeds in main()
batch_size = 48
total_iter = 9050   # number of optimisation steps in train()
eval_iter = 1000    # validate every eval_iter steps (step 0 included)
save_iter = 9000    # ten-crop test + checkpoint every save_iter steps (step 0 skipped)
print_iter = 100    # print the training loss every print_iter steps

# ---- dataset ----------------------------------------------------------------
data_path = "/content/drive/MyDrive/Colab Notebooks/CoTuning-main/CUB_200_2011"
class_num = 200     # output size of the new target head (c_net_2)
num_workers = 2     # DataLoader worker processes

# ---- optimizer ----------------------------------------------------------------
lr = 1e-3           # base LR; the new target head c_net_2 uses lr * 10
gamma = 0.1         # MultiStepLR decay factor applied at the milestones in train()
nesterov = True
momentum = 0.9
weight_decay = 5e-4

# ---- experiment ---------------------------------------------------------------
root = '.'          # NOTE(review): not referenced by any visible cell
name = 'StochNorm'  # sub-directory name under save_dir and visual_dir
trade_off = 2.3     # weight of the relationship (first-head) loss term
relationship_path = 'relationship.npy'  # cached relationship; relative to the kernel's cwd
save_dir = "model"
visual_dir = "visual"

In [None]:

def str2list(v):
    """Split a comma-separated string into its raw fields (no whitespace stripping)."""
    fields = v.split(',')
    return fields


def str2bool(v):
    """Return True for "yes"/"true"/"t"/"1" (case-insensitive), False for anything else."""
    truthy_spellings = {"yes", "true", "t", "1"}
    return v.lower() in truthy_spellings


def _official_train_flags(samples, image_root):
    """Return CUB's official is-training flag for each ImageFolder sample, or None if unavailable.

    Reads ``images.txt`` and ``train_test_split.txt`` from the directory that
    contains ``image_root``. Returns None when either file is missing or when a
    sample on disk is not listed in ``images.txt``.
    """
    cub_root = os.path.dirname(os.path.normpath(image_root))
    images_txt = os.path.join(cub_root, 'images.txt')
    split_txt = os.path.join(cub_root, 'train_test_split.txt')
    if not (os.path.isfile(images_txt) and os.path.isfile(split_txt)):
        return None

    with open(images_txt) as f:
        path_by_id = dict(line.strip().split(None, 1) for line in f if line.strip())
    with open(split_txt) as f:
        train_by_id = {img_id: flag == '1'
                       for img_id, flag in (line.split() for line in f if line.strip())}
    train_by_path = {os.path.normpath(path): train_by_id[img_id]
                     for img_id, path in path_by_id.items()}

    flags = []
    for path, _ in samples:
        rel_path = os.path.normpath(os.path.relpath(path, image_root))
        if rel_path not in train_by_path:
            print('Warning: {} is not listed in {}; falling back to a random split.'.format(
                rel_path, images_txt))
            return None
        flags.append(train_by_path[rel_path])
    return flags


def _split_indices(samples, image_root, val_fraction, test_fraction, split_seed):
    """Partition ImageFolder ``samples`` into disjoint, class-stratified (train, val, test) index lists."""
    if not (0.0 <= val_fraction < 1.0 and 0.0 <= test_fraction < 1.0):
        raise ValueError('val_fraction and test_fraction must be in [0, 1), got {} and {}'.format(
            val_fraction, test_fraction))

    # Private RNG: the split is reproducible without disturbing the global seeds.
    rng = np.random.RandomState(split_seed)
    official = _official_train_flags(samples, image_root)

    members_by_class = {}
    for idx, (_, label) in enumerate(samples):
        members_by_class.setdefault(label, []).append(idx)

    train_idx, val_idx, test_idx = [], [], []
    for label in sorted(members_by_class):
        members = members_by_class[label]
        if official is None:
            shuffled = rng.permutation(members).tolist()
            n_test = int(round(len(shuffled) * test_fraction))
            test_part, trainval = shuffled[:n_test], shuffled[n_test:]
        else:
            test_part = [i for i in members if not official[i]]
            trainval = rng.permutation([i for i in members if official[i]]).tolist()
        # Validation is carved out of the training portion, never out of test.
        n_val = int(round(len(trainval) * val_fraction))
        test_idx.extend(test_part)
        val_idx.extend(trainval[:n_val])
        train_idx.extend(trainval[n_val:])

    return sorted(train_idx), sorted(val_idx), sorted(test_idx)


def get_data_loader(val_fraction=0.1, test_fraction=0.5, split_seed=None):
    """Build the train, deterministic-train, validation and ten-crop test loaders.

    All splits come from the single ``<data_path>/images`` ImageFolder, but they
    are now disjoint. Previously every loader read the whole folder, so Val_Acc
    and Test_Acc were measured on images the model had trained on.

    The train/test boundary follows CUB's official ``train_test_split.txt``
    (with ``images.txt``, both next to ``images/``) when available. Otherwise a
    seeded, class-stratified ``test_fraction`` of each class is held out. Then
    ``val_fraction`` of each class's training images is held out for validation.

    NOTE: a ``relationship.npy`` cached from the old overlapping split was
    computed with test images; delete it so ``main`` recomputes it.

    Args:
        val_fraction: share of each class's training images reserved for validation.
        test_fraction: share of each class held out for test; only used when the
            official split files are missing.
        split_seed: seed for the split; defaults to the global ``seed``.

    Returns:
        ``(train_loader, determin_train_loader, val_loader, test_loaders)`` where
        ``test_loaders`` maps ``'test0'`` .. ``'test9'`` to one loader per
        ``data_transforms['test<i>']`` transform.
    """
    from torch.utils.data import Subset

    data_transforms = get_transforms(resize_size=256, crop_size=224)
    image_root = os.path.join(data_path, 'images')

    def image_folder(transform_key):
        # Every dataset lists the same directory, so sample indices line up across
        # transforms and one index split applies to all of them.
        return datasets.ImageFolder(image_root, transform=data_transforms[transform_key])

    train_full = image_folder('train')
    train_idx, val_idx, test_idx = _split_indices(
        train_full.samples, image_root, val_fraction, test_fraction,
        seed if split_seed is None else split_seed)
    eval_full = image_folder('val')

    # build dataset
    train_dataset = Subset(train_full, train_idx)
    determin_train_dataset = Subset(eval_full, train_idx)
    val_dataset = Subset(eval_full, val_idx)
    test_datasets = {
        'test' + str(i): Subset(image_folder('test' + str(i)), test_idx)
        for i in range(10)
    }

    # build dataloader
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=True)
    determin_train_loader = DataLoader(determin_train_dataset, batch_size=batch_size, shuffle=False,
                                       num_workers=num_workers, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=True)
    test_loaders = {
        'test' + str(i):
            DataLoader(
                test_datasets["test" + str(i)],
                batch_size=4, shuffle=False, num_workers=num_workers
        )
        for i in range(10)
    }

    return train_loader, determin_train_loader, val_loader, test_loaders


In [None]:

def set_seeds(seed):
    """Seed every RNG the run may draw from and force deterministic cuDNN.

    Seeds Python's ``random``, which was previously left unseeded, in case the
    project's transforms or helpers use it. Also seeds NumPy's global RNG, and
    PyTorch's CPU generator and the generators of all visible CUDA devices.

    Args:
        seed: integer seed applied to every generator.
    """
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all covers every visible GPU, not only the current device.
    torch.cuda.manual_seed_all(seed)
    # Give up cuDNN autotuning speed in exchange for run-to-run reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

class Net(nn.Module):
    """ResNet-50 feature extractor shared by two classification heads.

    ``c_net_1`` is the pretrained ResNet-50 classifier head (presumably the
    ImageNet head; confirm in backbone.py). ``c_net_2`` is a new
    ``class_num``-way linear head for the target dataset. It is defined at module
    level rather than inside ``main`` so later cells can rebuild the architecture
    and load a saved ``state_dict``.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.f_net = ResNet50_F(pretrained=True)
        self.c_net_1 = ResNet50_C(pretrained=True)
        self.c_net_2 = nn.Linear(self.f_net.output_dim, class_num)
        self.c_net_2.weight.data.normal_(0, 0.01)
        self.c_net_2.bias.data.fill_(0.0)

    def forward(self, x):
        """Return ``(out_1, out_2)``: the outputs of ``c_net_1`` and ``c_net_2`` on shared features."""
        feature = self.f_net(x)
        return self.c_net_1(feature), self.c_net_2(feature)


def _collect_first_head_outputs(net, loader):
    """Run ``net`` in eval mode over ``loader``; return (stacked c_net_1 outputs, labels) as numpy arrays."""
    net.eval()
    outputs_list, labels_list = [], []
    # no_grad: pure inference, so don't build (and throw away) an autograd graph per batch.
    with torch.no_grad():
        for inputs, labels in tqdm(loader):
            labels_list.append(labels.numpy())
            first_head_out, _ = net(inputs.cuda())
            outputs_list.append(first_head_out.cpu().numpy())
    return np.concatenate(outputs_list, 0), np.concatenate(labels_list, 0)


def main():
    """Seed, build loaders and model, obtain the label relationship, then train.

    The relationship is loaded from ``relationship_path`` if that file exists.
    Otherwise it is computed with ``relationship_learning`` from the first
    head's outputs on the deterministic-train and validation loaders, then
    saved to ``relationship_path``.
    """
    torch.cuda.set_device(gpu)
    set_seeds(seed)

    train_loader, determin_train_loader, val_loader, test_loaders = get_data_loader()

    net = Net().cuda()

    if os.path.exists(relationship_path):
        print('Loading pre-computed relationship from {}.'.format(relationship_path))
        relationship = np.load(relationship_path)
    else:
        print('Computing relationship')
        train_imagenet_labels, train_train_labels = _collect_first_head_outputs(net, determin_train_loader)
        val_imagenet_labels, val_train_labels = _collect_first_head_outputs(net, val_loader)
        relationship = relationship_learning(train_imagenet_labels, train_train_labels,
                                             val_imagenet_labels, val_train_labels)

        np.save(relationship_path, relationship)

    train(train_loader, val_loader, test_loaders, net, relationship)


In [None]:

def _validate(net, val_loader):
    """Return the top-1 accuracy (``AccuracyMeter.avg[1]``) of ``net``'s second head on ``val_loader``."""
    acc_meter = AccuracyMeter(topk=(1,))
    net.eval()
    with torch.no_grad():
        for val_inputs, val_labels in tqdm(val_loader):
            val_inputs, val_labels = val_inputs.cuda(), val_labels.cuda()
            _, val_outputs = net(val_inputs)
            acc_meter.update(val_outputs, val_labels)
    return acc_meter.avg[1]


def train(train_loader, val_loader, test_loaders, net, relationship):
    """Fine-tune ``net`` for ``total_iter`` steps with the two-head objective.

    loss = CE(second head, label) + ``trade_off`` * soft-CE(first head, relationship[label]).

    The loop does the following:
      * Logs the loss scalars to TensorBoard under ``visual_dir/name``.
      * Validates every ``eval_iter`` steps.
      * Every ``save_iter`` steps (step 0 excluded), runs ``TenCropsTest`` and
        saves ``{'state_dict', 'iter', 'acc'}`` to ``save_dir/name/<iter>.pkl``.

    Args:
        train_loader: shuffled loader yielding ``(inputs, labels)``.
        val_loader: loader evaluated every ``eval_iter`` steps.
        test_loaders: dict of ten-crop loaders passed to ``TenCropsTest``.
        net: model exposing ``f_net``, ``c_net_1``, ``c_net_2`` whose forward
            returns ``(out_1, out_2)``.
        relationship: numpy array indexed by target label; row ``y`` is the soft
            target for the first head (presumably ``(class_num, n_source_classes)``;
            confirm).
    """
    # The iterator is re-created every ``train_len`` steps, one batch short of a
    # full pass. This never hits StopIteration, and it keeps the last (possibly
    # partial) batch out of training. max(..., 1) guards a single-batch loader,
    # where the modulo below would otherwise divide by zero.
    train_len = max(len(train_loader) - 1, 1)

    # The new head (c_net_2) learns 10x faster than the pretrained parts.
    params_list = [
        {"params": filter(lambda p: p.requires_grad, net.f_net.parameters())},
        {"params": filter(lambda p: p.requires_grad, net.c_net_1.parameters())},
        {"params": filter(lambda p: p.requires_grad, net.c_net_2.parameters()), "lr": lr * 10}
    ]
    optimizer = torch.optim.SGD(params_list, lr=lr, weight_decay=weight_decay,
                                momentum=momentum, nesterov=nesterov)
    # Stepped once per iteration, so the milestone is an iteration count.
    milestones = [6000]
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma=gamma)

    visual_path = os.path.join(visual_dir, name)
    os.makedirs(visual_path, exist_ok=True)
    save_path = os.path.join(save_dir, name)
    os.makedirs(save_path, exist_ok=True)

    writer = get_writer(visual_path)
    try:
        for iter_num in range(total_iter):
            net.train()

            # Always true at iter_num == 0, so no iterator is needed before the loop.
            if iter_num % train_len == 0:
                train_iter = iter(train_loader)

            # Data Stage
            data_start = time()
            train_inputs, train_labels = next(train_iter)
            # Soft first-head targets: the relationship row of each sample's label.
            imagenet_targets = torch.from_numpy(relationship[train_labels.numpy()]).cuda().float()
            train_inputs, train_labels = train_inputs.cuda(), train_labels.cuda()
            data_duration = time() - data_start

            # Calc Stage
            calc_start = time()
            imagenet_outputs, train_outputs = net(train_inputs)

            ce_loss = nn.CrossEntropyLoss()(train_outputs, train_labels)
            imagenet_loss = - imagenet_targets * nn.LogSoftmax(dim=-1)(imagenet_outputs)
            imagenet_loss = torch.mean(torch.sum(imagenet_loss, dim=-1))
            loss = ce_loss + trade_off * imagenet_loss

            writer.add_scalar('loss/ce_loss', ce_loss.item(), iter_num)
            writer.add_scalar('loss/imagenet_loss', imagenet_loss.item(), iter_num)
            writer.add_scalar('loss/loss', loss.item(), iter_num)

            # The optimizer holds every trainable parameter, so its zero_grad suffices.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            calc_duration = time() - calc_start

            if iter_num % eval_iter == 0:
                val_acc = _validate(net, val_loader)
                writer.add_scalar('acc/val_acc', val_acc, iter_num)
                print("Iter: {}/{} Val_Acc: {:.2f}".format(iter_num, total_iter, val_acc))

            if iter_num % save_iter == 0 and iter_num > 0:
                test_acc = TenCropsTest(test_loaders, net)
                writer.add_scalar('acc/test_acc', test_acc, iter_num)
                print("Iter: {}/{} Test_Acc: {:.2f}".format(iter_num, total_iter, test_acc))
                checkpoint = {
                    'state_dict': net.state_dict(),
                    'iter': iter_num,
                    'acc': test_acc,
                }
                torch.save(checkpoint, os.path.join(save_path, '{}.pkl'.format(iter_num)))
                print("Model Saved.")

            if iter_num % print_iter == 0:
                print("Iter: {}/{} Loss_main: {:.4f}, d/c: {:.3f}/{:.3f}".format(
                    iter_num, total_iter, loss.item(), data_duration, calc_duration))
    finally:
        # Flush buffered TensorBoard events even if training is interrupted.
        writer.close()




In [None]:
if __name__ == '__main__':
    # Record library versions next to the training log for reproducibility.
    print("PyTorch {}".format(torch.__version__))
    print("TorchVision {}".format(torchvision.__version__))
    main()


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 236MB/s]


Loading pre-computed relationship from relationship.npy.


100%|██████████| 246/246 [27:30<00:00,  6.71s/it]


Iter: 0/9050 Val_Acc: 0.567412
Iter: 0/9050 Loss_main: 21.235962, d/c: 3.4217073917388916/2.002387285232544
Iter: 100/9050 Loss_main: 10.397576, d/c: 0.4138517379760742/0.16207289695739746
Iter: 200/9050 Loss_main: 10.094559, d/c: 0.31432652473449707/0.1681227684020996
Iter: 300/9050 Loss_main: 8.873284, d/c: 0.32572054862976074/0.1576075553894043
Iter: 400/9050 Loss_main: 7.709411, d/c: 0.281203031539917/0.1657414436340332
Iter: 500/9050 Loss_main: 8.347646, d/c: 0.9492826461791992/0.1599268913269043
Iter: 600/9050 Loss_main: 7.565969, d/c: 0.7830579280853271/0.16496729850769043
Iter: 700/9050 Loss_main: 6.670202, d/c: 0.9346919059753418/0.16289162635803223
Iter: 800/9050 Loss_main: 6.389672, d/c: 0.6015768051147461/0.18138527870178223
Iter: 900/9050 Loss_main: 8.533734, d/c: 0.2918531894683838/0.18207359313964844


100%|██████████| 246/246 [01:49<00:00,  2.24it/s]


Iter: 1000/9050 Val_Acc: 81.444771
Iter: 1000/9050 Loss_main: 6.877843, d/c: 0.3873310089111328/0.16183996200561523
Iter: 1100/9050 Loss_main: 7.306240, d/c: 0.31978559494018555/0.17606353759765625
Iter: 1200/9050 Loss_main: 6.164840, d/c: 0.3230776786804199/0.16017460823059082
Iter: 1300/9050 Loss_main: 6.651655, d/c: 0.3894014358520508/0.15588617324829102
Iter: 1400/9050 Loss_main: 6.455770, d/c: 0.3228952884674072/0.16415834426879883
Iter: 1500/9050 Loss_main: 5.444801, d/c: 0.31731653213500977/0.173661470413208
Iter: 1600/9050 Loss_main: 5.659955, d/c: 0.31207275390625/0.17258024215698242
Iter: 1700/9050 Loss_main: 5.646506, d/c: 0.3158912658691406/0.1657419204711914
Iter: 1800/9050 Loss_main: 6.001375, d/c: 0.31704068183898926/0.16053462028503418
Iter: 1900/9050 Loss_main: 5.557127, d/c: 0.32526373863220215/0.15932917594909668


100%|██████████| 246/246 [01:49<00:00,  2.24it/s]


Iter: 2000/9050 Val_Acc: 89.456291
Iter: 2000/9050 Loss_main: 6.039419, d/c: 0.3634054660797119/0.15803766250610352
Iter: 2100/9050 Loss_main: 5.466941, d/c: 0.31296420097351074/0.15680360794067383
Iter: 2200/9050 Loss_main: 5.732442, d/c: 0.3184378147125244/0.15892767906188965
Iter: 2300/9050 Loss_main: 6.107769, d/c: 0.3145103454589844/0.17448925971984863
Iter: 2400/9050 Loss_main: 5.956676, d/c: 0.30258703231811523/0.16739797592163086
Iter: 2500/9050 Loss_main: 5.792704, d/c: 0.31651878356933594/0.18184328079223633
Iter: 2600/9050 Loss_main: 6.137660, d/c: 0.3307063579559326/0.16489601135253906
Iter: 2700/9050 Loss_main: 5.180362, d/c: 0.3211195468902588/0.16051435470581055
Iter: 2800/9050 Loss_main: 5.676848, d/c: 0.32062458992004395/0.16235852241516113
Iter: 2900/9050 Loss_main: 5.750184, d/c: 0.31626462936401367/0.17041587829589844


100%|██████████| 246/246 [01:51<00:00,  2.20it/s]


Iter: 3000/9050 Val_Acc: 92.894630
Iter: 3000/9050 Loss_main: 5.359901, d/c: 0.3238992691040039/0.1619882583618164
Iter: 3100/9050 Loss_main: 5.341998, d/c: 0.3881337642669678/0.16286969184875488
Iter: 3200/9050 Loss_main: 5.225932, d/c: 0.322908878326416/0.15955519676208496
Iter: 3300/9050 Loss_main: 5.147526, d/c: 0.3085916042327881/0.15896344184875488
Iter: 3400/9050 Loss_main: 5.163166, d/c: 0.38439154624938965/0.16325950622558594
Iter: 3500/9050 Loss_main: 5.430894, d/c: 0.38226318359375/0.15519452095031738
Iter: 3600/9050 Loss_main: 5.139458, d/c: 0.3388864994049072/0.1697087287902832
Iter: 3700/9050 Loss_main: 5.042335, d/c: 0.3273146152496338/0.16856813430786133
Iter: 3800/9050 Loss_main: 4.970862, d/c: 0.4395787715911865/0.18523502349853516
Iter: 3900/9050 Loss_main: 5.907378, d/c: 0.3206300735473633/0.1696491241455078


100%|██████████| 246/246 [01:45<00:00,  2.33it/s]


Iter: 4000/9050 Val_Acc: 95.689316
Iter: 4000/9050 Loss_main: 5.582864, d/c: 0.32436275482177734/0.17222356796264648
Iter: 4100/9050 Loss_main: 5.909525, d/c: 0.2839982509613037/0.16730046272277832
Iter: 4200/9050 Loss_main: 5.159949, d/c: 0.32003092765808105/0.16126275062561035
Iter: 4300/9050 Loss_main: 5.054413, d/c: 0.3144063949584961/0.16941452026367188
Iter: 4400/9050 Loss_main: 5.280721, d/c: 0.3868393898010254/0.16321372985839844
Iter: 4500/9050 Loss_main: 5.238613, d/c: 0.3203921318054199/0.16097164154052734
Iter: 4600/9050 Loss_main: 5.381647, d/c: 0.3260807991027832/0.15623736381530762
Iter: 4700/9050 Loss_main: 5.555718, d/c: 0.3235292434692383/0.17194771766662598
Iter: 4800/9050 Loss_main: 5.549509, d/c: 0.30980968475341797/0.16274237632751465
Iter: 4900/9050 Loss_main: 5.005818, d/c: 0.858701229095459/0.16877269744873047


100%|██████████| 246/246 [01:49<00:00,  2.24it/s]


Iter: 5000/9050 Val_Acc: 97.264542
Iter: 5000/9050 Loss_main: 4.394911, d/c: 0.32352471351623535/0.15900063514709473
Iter: 5100/9050 Loss_main: 4.865605, d/c: 0.3183915615081787/0.1756131649017334
Iter: 5200/9050 Loss_main: 5.153826, d/c: 0.29007458686828613/0.19354677200317383
Iter: 5300/9050 Loss_main: 5.716899, d/c: 0.31125712394714355/0.1813793182373047
Iter: 5400/9050 Loss_main: 5.267262, d/c: 0.41721105575561523/0.15568089485168457
Iter: 5500/9050 Loss_main: 5.276050, d/c: 0.38530969619750977/0.1668083667755127
Iter: 5600/9050 Loss_main: 5.701625, d/c: 0.3675856590270996/0.17180728912353516
Iter: 5700/9050 Loss_main: 4.915949, d/c: 0.3243408203125/0.159621000289917
Iter: 5800/9050 Loss_main: 4.250974, d/c: 0.325425386428833/0.15685343742370605
Iter: 5900/9050 Loss_main: 4.800254, d/c: 0.3199188709259033/0.16908931732177734


100%|██████████| 246/246 [01:48<00:00,  2.27it/s]


Iter: 6000/9050 Val_Acc: 98.213066
Iter: 6000/9050 Loss_main: 5.897944, d/c: 0.3291194438934326/0.15900349617004395
Iter: 6100/9050 Loss_main: 5.121997, d/c: 0.3177022933959961/0.16529512405395508
Iter: 6200/9050 Loss_main: 5.690584, d/c: 0.3198823928833008/0.17158150672912598
Iter: 6300/9050 Loss_main: 4.567341, d/c: 0.32100486755371094/0.16512393951416016
Iter: 6400/9050 Loss_main: 4.530938, d/c: 0.32823824882507324/0.17086124420166016
Iter: 6500/9050 Loss_main: 4.508157, d/c: 0.3162879943847656/0.16156625747680664
Iter: 6600/9050 Loss_main: 4.533612, d/c: 0.3212578296661377/0.165177583694458
Iter: 6700/9050 Loss_main: 4.162060, d/c: 0.3186061382293701/0.1585071086883545
Iter: 6800/9050 Loss_main: 4.828248, d/c: 0.3116183280944824/0.16470623016357422
Iter: 6900/9050 Loss_main: 5.357476, d/c: 0.3195338249206543/0.1825113296508789


100%|██████████| 246/246 [01:48<00:00,  2.27it/s]


Iter: 7000/9050 Val_Acc: 99.254723
Iter: 7000/9050 Loss_main: 4.490765, d/c: 0.32824110984802246/0.1679680347442627
Iter: 7100/9050 Loss_main: 5.117365, d/c: 0.29003190994262695/0.1804969310760498
Iter: 7200/9050 Loss_main: 5.553283, d/c: 0.28574490547180176/0.17293787002563477
Iter: 7300/9050 Loss_main: 4.617667, d/c: 0.3139469623565674/0.16779756546020508
Iter: 7400/9050 Loss_main: 5.272613, d/c: 0.5334460735321045/0.1760408878326416
Iter: 7500/9050 Loss_main: 4.719364, d/c: 0.3244597911834717/0.18225312232971191
Iter: 7600/9050 Loss_main: 5.624404, d/c: 0.293698787689209/0.1625068187713623
Iter: 7700/9050 Loss_main: 4.771202, d/c: 0.780968189239502/0.17334842681884766
Iter: 7800/9050 Loss_main: 4.815221, d/c: 0.3237295150756836/0.1734921932220459
Iter: 7900/9050 Loss_main: 5.435906, d/c: 0.38922715187072754/0.1680591106414795


100%|██████████| 246/246 [01:51<00:00,  2.22it/s]


Iter: 8000/9050 Val_Acc: 99.432549
Iter: 8000/9050 Loss_main: 4.983940, d/c: 0.32604432106018066/0.16464447975158691
Iter: 8100/9050 Loss_main: 4.442134, d/c: 0.31666088104248047/0.1868443489074707
Iter: 8200/9050 Loss_main: 5.121646, d/c: 0.31320905685424805/0.15841269493103027
Iter: 8300/9050 Loss_main: 4.692663, d/c: 0.31890106201171875/0.16425609588623047
Iter: 8400/9050 Loss_main: 5.353915, d/c: 0.32420778274536133/0.16029572486877441
Iter: 8500/9050 Loss_main: 4.749154, d/c: 0.31821680068969727/0.1712348461151123
Iter: 8600/9050 Loss_main: 5.148423, d/c: 0.31516361236572266/0.16322660446166992
Iter: 8700/9050 Loss_main: 4.427621, d/c: 0.32207775115966797/0.15943598747253418
Iter: 8800/9050 Loss_main: 4.771709, d/c: 0.31677961349487305/0.1577284336090088
Iter: 8900/9050 Loss_main: 4.423179, d/c: 0.314708948135376/0.15876555442810059


100%|██████████| 246/246 [01:51<00:00,  2.21it/s]


Iter: 9000/9050 Val_Acc: 99.542664


100%|██████████| 2947/2947 [21:24<00:00,  2.30it/s]


Iter: 9000/9050 Test_Acc: 99.728539
Model Saved.
Iter: 9000/9050 Loss_main: 4.694311, d/c: 0.31838011741638184/0.15957331657409668


In [None]:
# %cd r"/content/drive/MyDrive/Colab Notebooks/CoTuning-main/utils/model/StochNorm/9000.pkl"
# %cd r"/content/drive/MyDrive/Colab Notebooks/CoTuning-main/testing.jpg"

[Errno 2] No such file or directory: 'r/content/drive/MyDrive/Colab Notebooks/CoTuning-main/utils/model/StochNorm/9000.pkl'
/content/drive/MyDrive/Colab Notebooks/CoTuning-main/utils
[Errno 2] No such file or directory: 'r/content/drive/MyDrive/Colab Notebooks/CoTuning-main/testing.jpg'
/content/drive/MyDrive/Colab Notebooks/CoTuning-main/utils


In [None]:
# import torch
# from torchvision import transforms
# from PIL import Image

# # Load the trained model checkpoint
# checkpoint_path = '/content/drive/MyDrive/Colab Notebooks/CoTuning-main/utils/model/StochNorm/9000.pkl'
# checkpoint = torch.load(checkpoint_path)
# net = Net()
# net.load_state_dict(checkpoint['state_dict'])
# net.eval()

# # Define the image transformation
# transform = transforms.Compose([
#     transforms.Resize((224, 224)),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
# ])

# # Load and preprocess the image
# image_path = '/content/drive/MyDrive/Colab Notebooks/CoTuning-main/testing.jpg'
# image = Image.open(image_path).convert('RGB')
# image_tensor = transform(image).unsqueeze(0)

# # Perform the prediction
# with torch.no_grad():
#     outputs = net(image_tensor)

# # Process the prediction results
# _, predicted = torch.max(outputs[1], 1)
# class_index = predicted.item()

# # Load the class labels
# class_labels_path = '/path/to/your/class_labels.txt'
# with open(class_labels_path, 'r') as f:
#     class_labels = f.readlines()
# class_labels = [label.strip() for label in class_labels]

# # Get the predicted class label
# predicted_label = class_labels[class_index]

# # Print the predicted label
# print("Predicted class label:", predicted_label)


FileNotFoundError: ignored

In [None]:
# # Example usage
# image_path = '/content/drive/MyDrive/Colab Notebooks/CoTuning-main/testing.jpg'  # Replace with the path to your image
# model_path = '/content/drive/MyDrive/Colab Notebooks/CoTuning-main/utils/model/StochNorm/9000.pkl'  # Replace with the path to your saved model


In [None]:
# import argparse
# import os
# import numpy as np
# import torch
# import torch.nn as nn
# import torchvision
# from torch.utils.data import DataLoader
# from torchvision import datasets
# from tqdm import tqdm
# from PIL import Image

# # from module.backbone import ResNet50_F, ResNet50_C
# # from module.relationship_learning import relationship_learning
# # from utils.transforms import get_transforms
# # from utils.tools import TenCropsTest


# def get_data_loader():
#     data_transforms = get_transforms(resize_size=256, crop_size=224)

#     # build dataset
#     test_datasets = {
#         'test' + str(i):
#             datasets.ImageFolder(
#                 os.path.join(data_path, 'images'),
#                 transform=data_transforms["test" + str(i)]
#         )
#         for i in range(10)
#     }

#     # build dataloader
#     test_loaders = {
#         'test' + str(i):
#             DataLoader(
#                 test_datasets["test" + str(i)],
#                 batch_size=1, shuffle=False, num_workers=num_workers
#         )
#         for i in range(10)
#     }

#     return test_loaders


# def main():
#     # configs = get_configs()
#     # print(configs)

#     device = torch.device('cpu')

#     net = Net().to(device)
#     if os.path.exists(relationship_path):
#         print('load pre-computed relationship from {}.'.format(relationship_path))
#         relationship = np.load(relationship_path)
#     else:
#         print('computing relationship')
#         # ... relationship computation code ...

#     test_loaders = get_data_loader()
#     checkpoint_path = '/path/to/your/model.pkl'