In [1]:
import random
from torch.utils import data
from PIL import Image
import torchvision.transforms as transforms
import torch
from math import floor

# Seed every RNG this notebook draws from: torch (and libraries built on its
# RNG) and Python's `random`, which random_resize / RandomResize use for their
# scale factor.
SEED = 1234
torch.manual_seed(SEED)
random.seed(SEED)

class myDataSet(data.Dataset):
    """Dataset of images paired with fixed-size tables of proposal boxes.

    Each non-blank line of ``ssw_path`` is whitespace-separated::

        <image_path> a1 b1 c1 d1 a2 b2 c2 d2 ...

    The image's parent directory name is parsed as its 1-based integer category
    id. Every following group of four integers is one proposal box, stored as
    ``[start_0, start_1, size_0, size_1]``. Only the first ``R`` boxes are kept,
    and missing boxes are padded with ``[0, 0, 2, 2]``.

    Box sizes are raised to at least 2. Box starts are shifted so the box ends
    inside a ``fmap_size`` x ``fmap_size`` grid whenever possible, and are never
    negative. The default of 30 presumably corresponds to a 480x480 input
    downsampled 16x by the VGG11 layout (TODO confirm when changing the input
    size or backbone).

    Args:
        ssw_path: path of the proposal list file.
        transform: callable applied to each RGB ``PIL.Image`` in ``__getitem__``.
        cat_num: number of categories (length of the one-hot label).
        R: number of boxes kept per image.
        fmap_size: side of the grid that boxes are fitted into.

    Raises:
        ValueError: if a category id falls outside ``1..cat_num``.
    """

    def __init__(self,
                 ssw_path: str,
                 transform,
                 cat_num: int = 102,
                 R: int = 20,
                 fmap_size: int = 30):
        self.transform = transform
        # One [image_path, boxes (R x 4 float tensor), one-hot label list] per line.
        self.imgs = []
        with open(ssw_path, 'r') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue  # tolerate blank lines, e.g. a trailing newline
                img_path = fields[0]
                cat_id = int(img_path.split('/')[-2])
                if not 1 <= cat_id <= cat_num:
                    # An id of 0 would otherwise silently mark the last class.
                    raise ValueError(
                        f'category id {cat_id} of {img_path!r} is outside 1..{cat_num}')
                label_cur = [0] * cat_num
                label_cur[cat_id - 1] = 1

                num_blocks = min((len(fields) - 1) // 4, R)
                coords = [int(v) for v in fields[1:1 + 4 * num_blocks]]
                ssw_block = torch.empty(R, 4)
                for i in range(R):
                    if i < num_blocks:
                        start_0, start_1, size_0, size_1 = coords[4 * i:4 * i + 4]
                        size_0 = max(size_0, 2)
                        size_1 = max(size_1, 2)
                        ssw_block[i] = torch.tensor([
                            self._fit_start(start_0, size_0, fmap_size),
                            self._fit_start(start_1, size_1, fmap_size),
                            size_0,
                            size_1,
                        ], dtype=torch.float)
                    else:
                        ssw_block[i] = torch.tensor([0, 0, 2, 2])  # padding box
                self.imgs.append([img_path, ssw_block, label_cur])

    @staticmethod
    def _fit_start(start, size, fmap_size):
        """Shift ``start`` so ``[start, start + size)`` ends inside ``fmap_size``
        when possible. The result is never below 0, because a negative slice
        start would wrap around to the end of the feature map."""
        return max(0, min(start, fmap_size - size))

    def __getitem__(self, index):
        """Return ``(transformed image, R x 4 boxes, one-hot label tensor)``."""
        path, ssw_block, label_once = self.imgs[index]
        # The context manager closes the file handle promptly, which matters with
        # many DataLoader workers. convert('RGB') yields 3 channels for grayscale
        # or palette images, so a 3-channel Normalize downstream does not fail.
        with Image.open(path) as img:
            img = img.convert('RGB')
        data_once = self.transform(img)
        return data_once, ssw_block, torch.Tensor(label_once)

    def __len__(self):
        """Number of parsed (non-blank) lines."""
        return len(self.imgs)

def random_resize(image, size, scale_range):
    '''Resize ``image`` to a square whose side length is drawn at random.

    The side is ``int(size * s)``, where ``s = random.uniform(*scale_range)``.
    The resize itself is performed by ``transforms.Resize``.
    '''
    side = int(size * random.uniform(*scale_range))
    return transforms.Resize((side, side))(image)

class RandomResize(object):
    """Callable transform that resizes to a random square side length.

    Args:
        size: base side length.
        scale_range: bounds unpacked into ``random.uniform`` for the scale factor.
    """

    def __init__(self, size, scale_range):
        self.size, self.scale_range = size, scale_range

    def __call__(self, image):
        # Delegates to the module-level helper, which performs the same single
        # random.uniform draw, side = int(size * scale), and square Resize.
        return random_resize(image, self.size, self.scale_range)

In [2]:
import torch
import torchvision.models as v_models
import torch.nn as nn
import torch.nn.functional as F
from math import floor
import math

def spatial_pyramid_pool(previous_conv, num_sample, previous_conv_size, out_pool_size):
    '''Multi-level spatial pyramid max pooling (SPP).

    Args:
        previous_conv: feature tensor of shape (num_sample, C, H, W).
        num_sample: number of samples in the batch (first dim of previous_conv).
        previous_conv_size: [H, W] of previous_conv.
        out_pool_size: sequence of grid sizes, one per pyramid level; level i is
            max-pooled to an out_pool_size[i] x out_pool_size[i] grid.

    Returns:
        Tensor of shape (num_sample, C * sum(n * n for n in out_pool_size)),
        holding the flattened levels concatenated along dim 1.

    Raises:
        ValueError: if out_pool_size is empty.
    '''
    if len(out_pool_size) == 0:
        raise ValueError('out_pool_size must contain at least one pyramid level')
    levels = []
    for level in out_pool_size:
        h_wid = math.ceil(previous_conv_size[0] / level)
        w_wid = math.ceil(previous_conv_size[1] / level)
        # Symmetric padding so about `level` windows cover the map. It is capped
        # at half the kernel because max pooling rejects larger padding.
        h_pad = min(math.floor((h_wid * level - previous_conv_size[0] + 1) / 2), math.floor(h_wid / 2))
        w_pad = min(math.floor((w_wid * level - previous_conv_size[1] + 1) / 2), math.floor(w_wid / 2))
        pooled = F.max_pool2d(previous_conv, (h_wid, w_wid), stride=(h_wid, w_wid), padding=(h_pad, w_pad))
        if pooled.shape[-2:] != (level, level):
            # The ceil/pad heuristic does not give exactly level x level for
            # every input size (e.g. 9 -> 3 cells for level 4), which would change
            # the feature length. Adaptive pooling always hits the target grid.
            pooled = F.adaptive_max_pool2d(previous_conv, level)
        levels.append(pooled.reshape(num_sample, -1))
    return torch.cat(levels, dim=1)


# VGG feature-extractor layouts consumed by WSDDN._make_layers. An int is a
# 3x3 conv (padding 1) + BatchNorm + ReLU with that many output channels;
# 'M' is a 2x2 max-pool with stride 2.
# NOTE: 'VGG11' has only four 'M' entries (total stride 16), while the other
# layouts have five (total stride 32). Every layout ends with 512 channels.
cfg = dict(
    VGG11=[64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512],
    VGG13=[64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    VGG16=[64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    VGG19=[64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
)


class WSDDN(nn.Module):
    """Weakly supervised detection network on a VGG-style backbone.

    Backbone features are cropped per proposal box and pooled by
    ``spatial_pyramid_pool`` to a 4096-d vector, then passed through fc6/fc7.
    The result is split into a classification stream (softmax over classes)
    and a detection stream (softmax over regions). Their product, summed over
    regions, is the image-level class score.

    Args:
        vgg_name: key of ``cfg`` selecting the backbone layout.
        cat_num: number of classes.
    """

    def __init__(self, vgg_name, cat_num):
        super(WSDDN, self).__init__()
        self.features = self._make_layers(cfg[vgg_name])
        # 4096 = 512 backbone channels x (2*2 + 2*2) cells from the two 2x2
        # pyramid levels used in through_spp_new.
        self.fc6 = nn.Linear(4096, 4096)
        self.fc7 = nn.Linear(4096, 4096)
        self.fc8c = nn.Linear(4096, cat_num)
        self.fc8d = nn.Linear(4096, cat_num)

    def forward(self, x, ssw_bbox):
        """Score a batch of images.

        Args:
            x: images, ``[B, 3, H, W]``.
            ssw_bbox: proposal boxes in feature-map cells, ``[B, R, 4]``.

        Returns:
            ``(y, segma_c, segma_d)``: ``y`` has shape ``[B, cat_num]``; both
            stream outputs have shape ``[B, R, cat_num]``.
        """
        x = self.features(x)                   # e.g. [1, 512, 30, 30] for a 480x480 VGG11 input
        x = self.through_spp_new(x, ssw_bbox)  # [B, R, 4096]
        # F.leaky_relu has the same default slope (0.01) as nn.LeakyReLU()
        # without building a new module on every call.
        x = F.leaky_relu(self.fc6(x))
        x = F.leaky_relu(self.fc7(x))
        x_c = F.leaky_relu(self.fc8c(x))
        x_d = F.leaky_relu(self.fc8d(x))
        segma_c = F.softmax(x_c, dim=2)  # classification stream: softmax over classes
        segma_d = F.softmax(x_d, dim=1)  # detection stream: softmax over regions
        y = torch.sum(segma_c * segma_d, dim=1)  # [B, cat_num]
        return y, segma_c, segma_d

    def _make_layers(self, cfg):
        """Build the backbone from a layout list: int -> Conv3x3+BN+ReLU, 'M' -> MaxPool2x2."""
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           nn.BatchNorm2d(x),
                           nn.ReLU(inplace=True)]
                in_channels = x
        return nn.Sequential(*layers)

    def through_spp_new(self, x, ssw):
        """Crop every proposal from the feature map and pool it with SPP.

        Args:
            x: feature maps, ``[B, C, H', W']``.
            ssw: boxes ``[B, R, 4]`` as ``(start_0, start_1, size_0, size_1)``.
                Entries 0/2 index dim 2 and entries 1/3 index dim 3 of ``x``.

        Returns:
            ``[B, R, C * 8]`` pooled region features (the 2x2 + 2x2 pyramid).
        """
        # Move the box table to the host once, instead of one device->host sync
        # per coordinate. Bounds are floored exactly as before: start = floor(s),
        # end = floor(s + size), with the sum computed in ssw's dtype.
        boxes = ssw.detach().cpu()
        starts = torch.floor(boxes[..., 0:2]).long().tolist()
        ends = torch.floor(boxes[..., 0:2] + boxes[..., 2:4]).long().tolist()
        batch_feats = []
        for i in range(x.size(0)):
            region_feats = []
            for j in range(ssw.size(1)):
                (h0, w0), (h1, w1) = starts[i][j], ends[i][j]
                fmap_piece = x[i, :, h0:h1, w0:w1].unsqueeze(0)
                region_feats.append(
                    spatial_pyramid_pool(previous_conv=fmap_piece, num_sample=1,
                                         previous_conv_size=[fmap_piece.size(2), fmap_piece.size(3)],
                                         out_pool_size=[2, 2]))
            # Collect and concatenate once, instead of re-concatenating per region (O(R^2)).
            batch_feats.append(torch.cat(region_feats))
        return torch.stack(batch_feats)

In [3]:
from datetime import datetime
import time
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.optim as optim
import os
import wandb
import yaml

# Read the W&B API key from the environment (e.g. WANDB_API_KEY or a notebook
# secret exported into it) rather than hard-coding it in a shared notebook.
# With key=None, wandb.login uses W&B's usual credential lookup.
# NOTE(review): the key previously hard-coded here is exposed and should be revoked.
wandb.login(key=os.environ.get('WANDB_API_KEY'))

# os.environ["WANDB_MODE"] = "offline"

def build_dataset(batch_size, train_data_path, num_workers, seed=1234):
    """Split the proposal list at ``train_data_path`` 80/20 into train/val loaders.

    The split is drawn from a private generator seeded with ``seed``. Every call
    (e.g. every sweep run) therefore gets the same partition, regardless of how
    much of the global torch RNG was consumed before. Validation samples are
    resized and normalized but not randomly flipped, so validation metrics do
    not depend on augmentation noise.

    Args:
        batch_size: batch size of both loaders.
        train_data_path: proposal list file read by ``myDataSet``.
        num_workers: DataLoader worker processes.
        seed: seed for the train/val split.

    Returns:
        ``(trainLoader, valLoader)``. The train loader shuffles; the val loader does not.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.Resize([480, 480]),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([
        transforms.Resize([480, 480]),
        transforms.ToTensor(),
        normalize,
    ])
    # Two views of the same list, so train and val can use different transforms
    # while sharing one index split.
    train_full = myDataSet(train_data_path, train_transform)
    val_full = myDataSet(train_data_path, val_transform)

    num_samples = len(train_full)
    train_size = int(0.8 * num_samples)
    # A seeded random permutation, as random_split(..., generator=...) would draw,
    # applied to two differently transformed datasets.
    generator = torch.Generator().manual_seed(seed)
    order = torch.randperm(num_samples, generator=generator).tolist()
    trainData = torch.utils.data.Subset(train_full, order[:train_size])
    valData = torch.utils.data.Subset(val_full, order[train_size:])

    trainLoader = torch.utils.data.DataLoader(dataset=trainData,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=num_workers)
    valLoader = torch.utils.data.DataLoader(dataset=valData,
                                            batch_size=batch_size,
                                            shuffle=False,
                                            num_workers=num_workers)
    return trainLoader, valLoader

def build_model(preCNN_name, cat_num, checkpoint_path='/kaggle/input/wsddn-od/kind-sweep-1.pth'):
    """Create a WSDDN model, optionally restore weights, and move it to the GPU.

    Args:
        preCNN_name: backbone layout key passed to ``WSDDN`` (e.g. 'VGG11').
        cat_num: number of classes.
        checkpoint_path: state-dict file to restore; the default is the path
            that was previously hard-coded. ``None`` keeps the fresh
            initialisation.

    Returns:
        The model, moved to CUDA.
    """
    model = WSDDN(preCNN_name, cat_num)
    # Earlier revisions instead copied the matching keys of VGG-BN checkpoints
    # (vgg*_bn-*.pth under /kaggle/input/wsddn-od/wsddn-data/model_para/) into
    # model.state_dict() before training.
    if checkpoint_path is not None:
        # map_location='cpu' lets a checkpoint saved on any device load before .cuda().
        state_dict = torch.load(checkpoint_path, map_location='cpu')
        model.load_state_dict(state_dict)
    model.cuda()
    return model

def train(args=None):
    """Train WSDDN for one W&B run, validating after every epoch.

    Builds the loaders and the model from the run config, then optimises the
    summed binary cross-entropy of the image-level scores with Adam. Progress is
    logged to W&B. The state dict with the best validation accuracy is saved to
    ``model_path/result_name``.

    Args:
        args: optional config passed to ``wandb.init``. Inside the run it is
            replaced by ``wandb.config`` (filled in by the sweep agent).
    """
    with wandb.init(config=args, project="WSDDN"):
        args = wandb.config
        trainLoader, valLoader = build_dataset(args.batch_size, args.train_data_path, args.num_workers)
        wsddn = build_model(args.preCNN_name, args.cat_num)

        optimizer = optim.Adam(wsddn.parameters(), lr=args.learn_rate, weight_decay=args.weight_decay)

        log_every = 10
        num_batches = len(trainLoader)
        total_steps = args.epochs * num_batches
        start_time = time.time()
        print('Start Training at {}'.format(datetime.now().strftime('%H:%M:%S')))
        max_acc = 0.0
        for epoch in range(args.epochs):
            # test() switches the model to eval mode. Re-enable training mode every
            # epoch so BatchNorm keeps using and updating batch statistics.
            wsddn.train()
            correct = 0
            seen = 0
            epoch_loss = 0.0
            running_loss = 0.0
            window = 0  # batches accumulated in running_loss since the last progress line
            print(f'Epoch {epoch + 1}/{args.epochs}')
            for i, (images, bbox, labels) in enumerate(trainLoader):
                images = images.cuda()
                labels = labels.cuda()
                bbox = bbox.cuda()
                optimizer.zero_grad()
                y, _, _ = wsddn(images, bbox)
                y = torch.clamp(y, min=0.0, max=1.0)
                # Binary cross-entropy summed over classes and batch; the 1e-8
                # keeps log() finite at 0.
                loss_ = torch.mul(labels, torch.log(y + 1e-8)) + torch.mul((1 - labels), torch.log((1 - y) + 1e-8))
                loss = -torch.sum(loss_)
                loss.backward()
                optimizer.step()

                batch_loss = loss.item()
                running_loss += batch_loss
                epoch_loss += batch_loss  # every batch counts, including the last partial window
                window += 1
                seen += labels.size(0)
                correct = cal_correct(y.detach(), labels, correct)
                wandb.log({"train_loss": batch_loss, "epoch": epoch + 1})

                if window == log_every or i == num_batches - 1:
                    # ETA over the whole run, counting the batches of earlier epochs.
                    steps_done = epoch * num_batches + i + 1
                    time_elapsed = time.time() - start_time
                    time_remaining = time_elapsed * (total_steps - steps_done) / steps_done
                    days, remainder = divmod(time_remaining, 86400)
                    hours, remainder = divmod(remainder, 3600)
                    minutes, seconds = divmod(remainder, 60)
                    print('[%s] [%d/%5d] ETA: %02d:%02d:%02d:%02d loss: %.3f' % (
                        datetime.now().strftime('%H:%M:%S'),
                        i + 1,
                        num_batches,
                        days,
                        hours,
                        minutes,
                        int(seconds),
                        running_loss / window
                    ))
                    running_loss = 0.0
                    window = 0
            epoch_loss = epoch_loss / float(max(num_batches, 1))
            # Divide by the samples actually seen; the last batch may be partial.
            acc = float(correct) / float(max(seen, 1))
            epoch_val_loss, val_acc = test(args, wsddn, valLoader)
            # Keep the checkpoint with the best validation accuracy so far.
            if val_acc > max_acc:
                max_acc = val_acc
                torch.save(wsddn.state_dict(), os.path.join(args.model_path, args.result_name))
            wandb.log({"train_acc": acc, "val_acc": val_acc, "epoch_train_loss": epoch_loss,
                       "epoch_val_loss": epoch_val_loss})
            print(f'Finish Epoch {epoch + 1}/{args.epochs} with acc: %.3f' % acc)
        print('Finished Training')
        wandb.finish()
        torch.cuda.empty_cache()


def test(args, wsddn, loader):
    """Evaluate ``wsddn`` on ``loader`` and report loss and top-1 accuracy.

    Runs without gradients in eval mode. Afterwards it restores the model's
    previous train/eval mode, so calling it between training epochs does not
    leave BatchNorm frozen for the next epoch.

    Args:
        args: run config. It is no longer needed by the computation and is kept
            for interface compatibility with existing callers.
        wsddn: model returning ``(y, segma_c, segma_d)`` for ``(images, bbox)``.
        loader: iterable of ``(images, bbox, labels)`` batches with one-hot labels.

    Returns:
        ``(mean loss per batch, accuracy as a fraction in [0, 1])``.
    """
    log_every = 10
    num_batches = len(loader)
    was_training = wsddn.training
    correct = 0
    seen = 0
    val_loss = 0.0
    running_loss = 0.0
    window = 0  # batches accumulated in running_loss since the last progress line
    start_time = time.time()
    print('Start Testing at {}'.format(datetime.now().strftime('%H:%M:%S')))
    wsddn.eval()
    try:
        with torch.no_grad():
            for i, (images, bbox, labels) in enumerate(loader):
                images = images.cuda()
                bbox = bbox.cuda()
                labels = labels.cuda()
                y, _, _ = wsddn(images, bbox)
                y = torch.clamp(y, min=0.0, max=1.0)
                # Same summed binary cross-entropy as used for training.
                loss_ = torch.mul(labels, torch.log(y + 1e-8)) + torch.mul((1 - labels), torch.log((1 - y) + 1e-8))
                loss = -torch.sum(loss_)
                batch_loss = loss.item()
                running_loss += batch_loss
                val_loss += batch_loss  # every batch counts, including the last partial window
                window += 1
                seen += labels.size(0)
                correct = cal_correct(y, labels, correct)
                if window == log_every or i == num_batches - 1:
                    # ETA for the rest of this single pass over `loader`.
                    done = i + 1
                    time_elapsed = time.time() - start_time
                    time_remaining = time_elapsed * (num_batches - done) / done
                    days, remainder = divmod(time_remaining, 86400)
                    hours, remainder = divmod(remainder, 3600)
                    minutes, seconds = divmod(remainder, 60)
                    print('[%s] [%d/%5d] ETA: %02d:%02d:%02d:%02d loss: %.3f' % (
                        datetime.now().strftime('%H:%M:%S'),
                        done,
                        num_batches,
                        days,
                        hours,
                        minutes,
                        int(seconds),
                        running_loss / window
                    ))
                    running_loss = 0.0
                    window = 0
    finally:
        wsddn.train(was_training)

    # Divide by the samples actually seen; the last batch may be partial.
    acc = float(correct) / float(max(seen, 1))
    # acc is a fraction; scale it so the printed value matches the '%' sign.
    print('Classification Accuracy of the model on the test images(mAcc): %.4f %%' % (acc * 100))
    print('Finished Testing')
    return val_loss / float(max(num_batches, 1)), acc


def cal_correct(y, labels, correct):
    """Add the number of exactly-correct top-1 predictions in a batch to ``correct``.

    A sample counts as correct when the one-hot vector of ``argmax(y[i])``
    equals ``labels[i]`` element-wise, so multi-hot or all-zero labels never
    count. The prediction is built in ``labels``' dtype. The result therefore
    does not depend on how ``torch.equal`` treats a bool prediction against
    float labels, which has varied across PyTorch versions and devices.

    Args:
        y: scores, ``[B, C]``.
        labels: one-hot targets, ``[B, C]``, on the same device as ``y``.
        correct: running count to add to.

    Returns:
        ``correct`` plus the number of correct samples in this batch, as an ``int``.
    """
    # Vectorised one-hot of the predicted class. Ties resolve to the first
    # maximal index, as torch.argmax does.
    predicted = torch.zeros_like(labels).scatter_(1, y.argmax(dim=1, keepdim=True), 1)
    matches = (predicted == labels).all(dim=1)
    return correct + int(matches.sum().item())

# Other ways this notebook has been launched (kept for reference):
#   torch.cuda.set_device(0); args = parse_args(); train(args)
#   sweep_config = yaml.safe_load(open('/kaggle/working/wsddn-project/sweep-random-hyperband.yaml'))
#   sweep_id = wandb.sweep(sweep_config, project="WSDDN")
#   wandb.agent('yun6p7im', function=train, count=140, project="WSDDN")

# Hyper-parameters for this run. Each entry becomes a single-valued grid
# parameter, so the "sweep" only serves to launch train() through a W&B agent.
_run_params = {
    "model_path": '/kaggle/working/',
    "result_name": 'wsddn.pth',
    "batch_size": 16,
    "learn_rate": 6.4e-6,
    "epochs": 15,
    "train_data_path": '/kaggle/input/wsddn-od/ssw_train.txt',
    "val_data_path": '/kaggle/input/wsddn-od/ssw_val.txt',
    "preCNN_name": 'VGG11',
    "cat_num": 102,
    "momentum": 0,
    "weight_decay": 1e-7,
    "num_workers": 8,
}
sweep_configuration = {
    "method": "grid",
    "metric": {"goal": "minimize", "name": "train_loss"},
    "parameters": {key: {"value": val} for key, val in _run_params.items()},
}

sweep_id = wandb.sweep(sweep=sweep_configuration, project="WSDDN")
# NOTE(review): a grid whose parameters all have a single value presumably
# yields one run even though count=6 — confirm.
wandb.agent(sweep_id, function=train, project="WSDDN", count=6)


[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Create sweep with ID: s3re0wrs
Sweep URL: https://wandb.ai/1844986810/WSDDN/sweeps/s3re0wrs


[34m[1mwandb[0m: Agent Starting Run: z0s1x4ya with config:
[34m[1mwandb[0m: 	batch_size: 16
[34m[1mwandb[0m: 	cat_num: 102
[34m[1mwandb[0m: 	epochs: 15
[34m[1mwandb[0m: 	learn_rate: 6.4e-06
[34m[1mwandb[0m: 	model_path: /kaggle/working/
[34m[1mwandb[0m: 	momentum: 0
[34m[1mwandb[0m: 	num_workers: 8
[34m[1mwandb[0m: 	preCNN_name: VGG11
[34m[1mwandb[0m: 	result_name: wsddn.pth
[34m[1mwandb[0m: 	train_data_path: /kaggle/input/wsddn-od/ssw_train.txt
[34m[1mwandb[0m: 	val_data_path: /kaggle/input/wsddn-od/ssw_val.txt
[34m[1mwandb[0m: 	weight_decay: 1e-07
[34m[1mwandb[0m: Currently logged in as: [33m1844986810[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: wandb version 0.17.2 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade
[34m[1mwandb[0m: Tracking run with wandb version 0.16.6
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/kaggle/working/wandb/run-20240618_1101

Start Training at 11:02:09
Epoch 1\15
[11:02:18] [10/  328] ETA: 00:01:08:57 loss: 9.773
[11:02:23] [20/  328] ETA: 00:00:58:01 loss: 5.915
[11:02:29] [30/  328] ETA: 00:00:54:10 loss: 5.856
[11:02:35] [40/  328] ETA: 00:00:52:16 loss: 4.604
[11:02:41] [50/  328] ETA: 00:00:51:11 loss: 3.383
[11:02:46] [60/  328] ETA: 00:00:50:19 loss: 3.639
[11:02:52] [70/  328] ETA: 00:00:49:43 loss: 2.502
[11:02:58] [80/  328] ETA: 00:00:49:09 loss: 2.250
[11:03:04] [90/  328] ETA: 00:00:48:43 loss: 3.091
[11:03:10] [100/  328] ETA: 00:00:48:26 loss: 3.890
[11:03:15] [110/  328] ETA: 00:00:48:10 loss: 3.775
[11:03:21] [120/  328] ETA: 00:00:47:53 loss: 1.567
[11:03:27] [130/  328] ETA: 00:00:47:39 loss: 3.053
[11:03:33] [140/  328] ETA: 00:00:47:25 loss: 2.893
[11:03:38] [150/  328] ETA: 00:00:47:12 loss: 1.361
[11:03:44] [160/  328] ETA: 00:00:47:02 loss: 1.884
[11:03:50] [170/  328] ETA: 00:00:46:51 loss: 2.334
[11:03:56] [180/  328] ETA: 00:00:46:40 loss: 2.903
[11:04:01] [190/  328] ETA: 00:00:4

[34m[1mwandb[0m:                                                                                
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run history:
[34m[1mwandb[0m:            epoch ▁▁▁▁▁▁▂▂▃▃▃▃▃▃▃▃▄▄▄▅▅▅▅▅▅▅▅▆▆▆▇▇▇▇▇▇▇███
[34m[1mwandb[0m: epoch_train_loss █▃▂▁▂▁▁▂▁▁▁▁▁▁▁
[34m[1mwandb[0m:   epoch_val_loss █▄▄▂▃▁▂█▁▁▂▁▁▂▂
[34m[1mwandb[0m:        train_acc ▁▄▇▇▇██▇███████
[34m[1mwandb[0m:       train_loss ▄▃█▄▁▁▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
[34m[1mwandb[0m:          val_acc ▄▅▄▇▇▇▇▁█▇▇▇█▇▆
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run summary:
[34m[1mwandb[0m:            epoch 15
[34m[1mwandb[0m: epoch_train_loss 0.00194
[34m[1mwandb[0m:   epoch_val_loss 0.76414
[34m[1mwandb[0m:        train_acc 0.99867
[34m[1mwandb[0m:       train_loss 3e-05
[34m[1mwandb[0m:          val_acc 0.99162
[34m[1mwandb[0m: 
[34m[1mwandb[0m: 🚀 View run [33meffortless-sweep-1[0m at: [34m[4mhttps://wandb.ai/1844986810/WSDDN/runs/z0s1x4ya[0m
[34m[1mwandb[

In [4]:

def evaluate(args=None):
    """Evaluate the checkpoint saved by train() under 12 fixed test-time views.

    The views are every combination of size in (480, 688, 1200), horizontal
    flip on/off, and vertical flip on/off. Each view is a full copy of the
    ``val_data_path`` list, and the copies are concatenated into one loader.
    The reported accuracy is therefore the average over all views; predictions
    are not ensembled.

    Args:
        args: optional config for ``wandb.init``; replaced by ``wandb.config``.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    all_transforms = []
    for size in [480, 688, 1200]:
        resize = transforms.Resize((size, size))
        for horizontal_flip in [True, False]:
            for vertical_flip in [True, False]:
                all_transforms.append(transforms.Compose([
                    resize,
                    # p=1 / p=0 turns the "random" flips into deterministic on/off switches.
                    transforms.RandomHorizontalFlip(p=1 if horizontal_flip else 0),
                    transforms.RandomVerticalFlip(p=1 if vertical_flip else 0),
                    transforms.ToTensor(),
                    normalize,
                ]))
    # NOTE(review): myDataSet returns the same proposal boxes whatever the image
    # transform is. Flipped or non-480 views are therefore paired with boxes that
    # no longer cover the same image content — confirm this is intended.
    with wandb.init(config=args, project="WSDDN"):
        args = wandb.config
        # ConcatDataset has no add_dataset() and rejects an empty list, so all
        # per-view datasets are built up front.
        testData = torch.utils.data.ConcatDataset(
            [myDataSet(args.val_data_path, view) for view in all_transforms])
        loader = torch.utils.data.DataLoader(dataset=testData,
                                             batch_size=args.batch_size,
                                             shuffle=False)
        wsddn = WSDDN(args.preCNN_name, args.cat_num)
        # train() saves the best weights to model_path/result_name; the sweep
        # configs define no 'result_pkl' key.
        checkpoint = os.path.join(args.model_path, args.result_name)
        wsddn.load_state_dict(torch.load(checkpoint, map_location='cpu'))
        wsddn.cuda()
        test(args, wsddn, loader)
        
# sweep_configuration = {
#     "method": "grid",
#     "metric": {"goal": "minimize", "name": "train_loss"},
#     "parameters": {
#         "model_path":{"value":'/kaggle/working/'},
#         "result_name":{"value":'wsddn.pth'},
#         "batch_size":{"value":16},
#         "learn_rate":{"value": 0.0000306},
#         "epochs":{"value":20},
#         "train_data_path":{"value":'/kaggle/input/wsddn-od/ssw_train.txt'},
#         "val_data_path":{"value":'/kaggle/input/wsddn-od/ssw_val.txt'},
#         "preCNN_name":{"value":'VGG11'},
#         "cat_num":{"value":102},
#         "momentum":{"value":0},
#         "weight_decay":{"value":1e-7},
#         "num_workers":{"value":8}
#     },
# }

# sweep_id = wandb.sweep(sweep=sweep_configuration, project="WSDDN")
# wandb.agent(sweep_id, function=evaluate, project="WSDDN")