In [None]:
!git clone https://github.com/brjathu/SKD.git

In [None]:
import os

import wandb

# Never hard-code the W&B API key in the notebook: read it from the environment
# (e.g. a Kaggle secret exported as WANDB_API_KEY). If it is unset, key=None lets
# wandb.login() fall back to its own credential lookup / prompt.
wandb_api = os.environ.get("WANDB_API_KEY")

wandb.login(key=wandb_api)

In [None]:
pip install wandb --upgrade

In [None]:
pwd

In [None]:
cd ./SKD

In [None]:
!pip install -r requirements.txt

In [None]:
!conda install mkl-service -y

In [None]:
%%writefile /kaggle/working/SKD/models/__init__.py


from .resnet_ssl import VAE_resnet18_ssl

# Registry mapping a model name to its constructor.
model_dict = {
    'VAE_resnet18_ssl': VAE_resnet18_ssl,
}

# Valid model names (used as argparse `choices` for --model); derived from the
# registry so the two lists cannot drift apart.
model_pool = list(model_dict)

In [None]:
%%writefile /kaggle/working/SKD/models/resnet_ssl.py

import torch
from torch import nn, optim
import torch.nn.functional as F

class ResizeConv2d(nn.Module):
    """Upsample with ``F.interpolate`` and then apply a stride-1 convolution.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the convolution.
        kernel_size: convolution kernel size (int or tuple). Padding is
            ``kernel_size // 2`` per dimension, so for odd kernels the convolution
            keeps the upsampled spatial size.
        scale_factor: spatial upsampling factor passed to ``F.interpolate``.
        mode: interpolation mode passed to ``F.interpolate``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, scale_factor, mode='nearest'):
        super().__init__()
        self.scale_factor = scale_factor
        self.mode = mode
        # "Same" padding derived from the kernel; identical to the former hard-coded
        # padding=1 for kernel_size=3, but also correct for other odd kernels.
        if isinstance(kernel_size, tuple):
            padding = tuple(k // 2 for k in kernel_size)
        else:
            padding = kernel_size // 2
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=padding)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode)
        return self.conv(x)

class BasicBlockEnc(nn.Module):
    """Residual encoder block: two 3x3 conv + BN layers and a shortcut.

    The first convolution applies ``stride``; the block outputs
    ``in_planes * stride`` channels. For ``stride != 1`` the shortcut is a strided
    1x1 conv + BN projection so shapes match, otherwise it is the identity.
    """

    def __init__(self, in_planes, stride=1):
        super().__init__()
        width = in_planes * stride

        self.conv1 = nn.Conv2d(in_planes, width, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width)
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width)

        projection = []
        if stride != 1:
            projection = [
                nn.Conv2d(in_planes, width, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(width),
            ]
        # An empty Sequential acts as the identity.
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        skip = self.shortcut(x)
        h = torch.relu(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))
        return torch.relu(h + skip)

class BasicBlockDec(nn.Module):
    """Residual decoder block, mirroring BasicBlockEnc.

    Outputs ``int(in_planes / stride)`` channels. With ``stride == 1`` the spatial
    size is kept and the shortcut is the identity; otherwise ``conv1`` and the
    shortcut upsample by ``stride`` through ResizeConv2d.
    """

    def __init__(self, in_planes, stride=1):
        super().__init__()
        out_planes = int(in_planes / stride)

        # Registration order (conv2, bn2, conv1, bn1, shortcut) follows the order the
        # layers run in forward(), so printing the module lists them in that order.
        self.conv2 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(in_planes)

        upsample = stride != 1
        if upsample:
            self.conv1 = ResizeConv2d(in_planes, out_planes, kernel_size=3, scale_factor=stride)
        else:
            self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_planes)
        self.shortcut = (
            nn.Sequential(
                ResizeConv2d(in_planes, out_planes, kernel_size=3, scale_factor=stride),
                nn.BatchNorm2d(out_planes),
            )
            if upsample
            else nn.Sequential()
        )

    def forward(self, x):
        skip = self.shortcut(x)
        h = torch.relu(self.bn2(self.conv2(x)))
        h = self.bn1(self.conv1(h))
        return torch.relu(h + skip)

class ResNet18Enc(nn.Module):
    """ResNet-18 style encoder with a classification head and a rotation head.

    Args:
        num_Blocks: number of BasicBlockEnc blocks in each of the four stages.
        z_dim: latent size; only sizes ``self.linear`` (``2 * z_dim`` outputs),
            which the current forward pass does not use.
        nc: number of input image channels.
        num_classes: size of the classification head. The classification and
            rotation heads are only created when ``num_classes > 0``, and
            ``forward`` requires them.

    ``forward(x)`` returns ``(class_logits, rotation_logits)`` where the rotation
    logits have 4 outputs.
    """

    def __init__(self, num_Blocks=[2,2,2,2], z_dim=10, nc=3, num_classes=0):
        super().__init__()
        self.in_planes = 64
        self.num_classes = num_classes
        self.z_dim = z_dim
        self.conv1 = nn.Conv2d(nc, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(BasicBlockEnc, 64, num_Blocks[0], stride=1)
        self.layer2 = self._make_layer(BasicBlockEnc, 128, num_Blocks[1], stride=2)
        self.layer3 = self._make_layer(BasicBlockEnc, 256, num_Blocks[2], stride=2)
        self.layer4 = self._make_layer(BasicBlockEnc, 512, num_Blocks[3], stride=2)
        # VAE (mu, logvar) head; unused by forward() but kept, presumably so
        # state_dict keys stay compatible with existing checkpoints.
        self.linear = nn.Linear(512, 2 * z_dim)
        if self.num_classes > 0:
            self.classifier = nn.Linear(512, num_classes)  # basic task
            # NOTE(review): the rotation head is fed the class logits rather than the
            # 512-d features; changing that would change checkpoint shapes.
            self.rot_classifier = nn.Linear(self.num_classes, 4)  # rotation task

    def _make_layer(self, block, planes, num_Blocks, stride):
        """Stack ``num_Blocks`` blocks; only the first one uses ``stride``."""
        strides = [stride] + [1] * (num_Blocks - 1)
        layers = []
        for s in strides:
            layers.append(block(self.in_planes, s))
            self.in_planes = planes
        return nn.Sequential(*layers)

    def forward(self, x):
        if self.num_classes <= 0:
            # Previously this surfaced as an opaque AttributeError on self.classifier.
            raise RuntimeError('ResNet18Enc was built with num_classes={}; forward() needs the '
                               'classification and rotation heads (num_classes > 0).'
                               .format(self.num_classes))
        x = torch.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        feat = torch.flatten(F.adaptive_avg_pool2d(x, 1), 1)

        class_logits = self.classifier(feat)
        rot_logits = self.rot_classifier(class_logits)
        return class_logits, rot_logits

class ResNet18Dec(nn.Module):
    """ResNet-18 style decoder mapping a latent vector to an ``nc`` x 32 x 32 image.

    The latent is projected to a 512 x 1 x 1 map, upsampled once by interpolation,
    by the strided stages layer4/layer3/layer2 and by the final ResizeConv2d
    (1 -> 2 -> 4 -> 8 -> 16 -> 16 -> 32). The output goes through a sigmoid.

    Args:
        num_Blocks: number of BasicBlockDec blocks in each of the four stages.
        z_dim: latent size.
        nc: number of output image channels.
    """

    def __init__(self, num_Blocks=[2,2,2,2], z_dim=10, nc=3):
        super().__init__()
        self.in_planes = 512

        self.linear = nn.Linear(z_dim, 512)

        self.layer4 = self._make_layer(BasicBlockDec, 256, num_Blocks[3], stride=2)
        self.layer3 = self._make_layer(BasicBlockDec, 128, num_Blocks[2], stride=2)
        self.layer2 = self._make_layer(BasicBlockDec, 64, num_Blocks[1], stride=2)
        self.layer1 = self._make_layer(BasicBlockDec, 64, num_Blocks[0], stride=1)
        self.conv1 = ResizeConv2d(64, nc, kernel_size=3, scale_factor=2)

    def _make_layer(self, block, planes, num_Blocks, stride):
        """Stack ``num_Blocks`` decoder blocks; only the last one uses ``stride``.

        Every block is built from the current ``self.in_planes``. Stride-1 blocks keep
        the width, and the final block divides it by ``stride``. ``self.in_planes`` is
        therefore advanced once, after the whole stage is built.
        """
        strides = [stride] + [1] * (num_Blocks - 1)
        layers = [block(self.in_planes, s) for s in reversed(strides)]
        self.in_planes = planes
        return nn.Sequential(*layers)

    def forward(self, z):
        x = self.linear(z).view(z.size(0), 512, 1, 1)
        x = F.interpolate(x, scale_factor=2)
        x = self.layer4(x)
        x = self.layer3(x)
        x = self.layer2(x)
        x = self.layer1(x)
        # The result is already (batch, nc, 32, 32). The former view(..., 3, 32, 32)
        # hard-coded nc=3 and failed for any other channel count.
        return torch.sigmoid(self.conv1(x))

class VAE(nn.Module):
    """Wrapper around ResNet18Enc that exposes an optional rotation output.

    Despite the name, only the encoder is active; the decoder and the
    reparameterisation step are not used in ``forward``.
    """

    def __init__(self, z_dim, num_classes):
        super().__init__()
        self.encoder = ResNet18Enc(z_dim=z_dim, num_classes=num_classes)
        # self.decoder = ResNet18Dec(z_dim=z_dim)

    def forward(self, x, rot=False):
        """Return class logits, or ``(class_logits, rotation_logits)`` when ``rot``."""
        class_logits, rot_logits = self.encoder(x)
        if not rot:
            return class_logits
        return class_logits, rot_logits  # self-supervised (rotation) mode

    @staticmethod
    def reparameterize(mean, logvar):
        """Sample ``mean + eps * exp(logvar / 2)`` with ``eps ~ N(0, 1)``."""
        std = torch.exp(logvar / 2)  # square root in log-space is a division by two
        noise = torch.randn_like(std)
        return noise * std + mean

def VAE_resnet18_ssl(num_classes, **kwargs):
    """Construct the VAE-ResNet18 model with classification and rotation heads.

    Extra keyword arguments are accepted for call-site compatibility and ignored;
    the latent size is fixed at 10.
    """
    return VAE(z_dim=10, num_classes=num_classes)

In [None]:
%%writefile /kaggle/working/SKD/train_selfsupervison.py

from __future__ import print_function

import os
import argparse
import socket
import time
import sys
from tqdm import tqdm
import mkl

# import tensorboard_logger as tb_logger
import torch
import torch.optim as optim
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.autograd import Variable

from models import model_pool
from models.util import create_model

from dataset.mini_imagenet import ImageNet, MetaImageNet
from dataset.tiered_imagenet import TieredImageNet, MetaTieredImageNet
from dataset.cifar import CIFAR100, MetaCIFAR100
from dataset.transform_cfg import transforms_options, transforms_test_options, transforms_list

from util import adjust_learning_rate, accuracy, AverageMeter
from eval.meta_eval import meta_test, meta_test_tune
from eval.cls_eval import validate

from models.resnet import resnet12
import numpy as np
from util import Logger
import wandb
from dataloader import get_dataloaders

def get_freer_gpu():
    """Return the index of the GPU reporting the most free memory in ``nvidia-smi``.

    Queries ``nvidia-smi`` directly, with no temporary ``tmp`` file, no leaked file
    handle and no shell pipeline. Returns 0 when ``nvidia-smi`` is missing, fails,
    or reports nothing. Because the caller runs at import time, the script then no
    longer crashes on CPU-only machines.

    NOTE(review): nvidia-smi orders GPUs by PCI bus; CUDA's default device order
    may differ unless CUDA_DEVICE_ORDER=PCI_BUS_ID is set.
    """
    import subprocess

    try:
        result = subprocess.run(
            ['nvidia-smi', '--query-gpu=memory.free', '--format=csv,noheader,nounits'],
            capture_output=True, text=True, check=True,
        )
    except (OSError, subprocess.CalledProcessError):
        return 0

    free_mib = [int(tok) for tok in result.stdout.split() if tok.isdigit()]
    if not free_mib:
        return 0
    return int(np.argmax(free_mib))

# Pin this process to the GPU with the most free memory. This runs at import time;
# CUDA_VISIBLE_DEVICES only takes effect if set before CUDA is first initialised.
os.environ["CUDA_VISIBLE_DEVICES"]=str(get_freer_gpu())
# Limit MKL to 2 threads.
mkl.set_num_threads(2)


def str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``argparse`` with ``type=bool`` turns every non-empty string, including
    ``"False"``, into True, so this converter is used instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


def parse_option():
    """Parse the training options from ``sys.argv`` and derive paths/names.

    Post-processing:
      * forces transform 'D' for CIFAR-FS / FC100;
      * resolves ``data_root``;
      * turns ``lr_decay_epochs`` into a list of ints and ``tags`` into a list of
        stripped strings;
      * builds ``model_name``;
      * creates the tensorboard and save folders.
    """
    parser = argparse.ArgumentParser('argument for training')

    parser.add_argument('--eval_freq', type=int, default=10, help='meta-eval frequency')
    parser.add_argument('--print_freq', type=int, default=100, help='print frequency')
    parser.add_argument('--tb_freq', type=int, default=500, help='tb frequency')
    parser.add_argument('--save_freq', type=int, default=10, help='save frequency')
    parser.add_argument('--batch_size', type=int, default=64, help='batch_size')
    parser.add_argument('--num_workers', type=int, default=8, help='num of workers to use')
    parser.add_argument('--epochs', type=int, default=100, help='number of training epochs')

    # optimization
    parser.add_argument('--learning_rate', type=float, default=0.05, help='learning rate')
    parser.add_argument('--lr_decay_epochs', type=str, default='60,80', help='where to decay lr, can be a list')
    parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='decay rate for learning rate')
    parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight decay')
    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
    parser.add_argument('--adam', action='store_true', help='use adam optimizer')
    # str2bool instead of type=bool, so that "--ssl False" really disables SSL.
    parser.add_argument('--simclr', type=str2bool, default=False, help='use simple contrastive learning representation')
    parser.add_argument('--ssl', type=str2bool, default=True, help='use self supervised learning')
    parser.add_argument('--tags', type=str, default="gen0, ssl", help='add tags for the experiment')

    # dataset
    parser.add_argument('--model', type=str, default='resnet12', choices=model_pool)
    parser.add_argument('--dataset', type=str, default='miniImageNet', choices=['miniImageNet', 'tieredImageNet',
                                                                                'CIFAR-FS', 'FC100'])
    parser.add_argument('--transform', type=str, default='A', choices=transforms_list)
    parser.add_argument('--use_trainval', type=str2bool, help='use trainval set')

    # cosine annealing
    parser.add_argument('--cosine', action='store_true', help='using cosine annealing')

    # specify folder
    parser.add_argument('--model_path', type=str, default='save/', help='path to save model')
    parser.add_argument('--tb_path', type=str, default='tb/', help='path to tensorboard')
    parser.add_argument('--data_root', type=str, default='/raid/data/IncrementLearn/imagenet/Datasets/MiniImagenet/', help='path to data root')

    # meta setting
    parser.add_argument('--n_test_runs', type=int, default=600, metavar='N',
                        help='Number of test runs')
    parser.add_argument('--n_ways', type=int, default=5, metavar='N',
                        help='Number of classes for doing each classification run')
    parser.add_argument('--n_shots', type=int, default=1, metavar='N',
                        help='Number of shots in test')
    parser.add_argument('--n_queries', type=int, default=15, metavar='N',
                        help='Number of query in test')
    parser.add_argument('--n_aug_support_samples', default=5, type=int,
                        help='The number of augmented samples for each meta test sample')
    parser.add_argument('--test_batch_size', type=int, default=1, metavar='test_batch_size',
                        help='Size of test batch)')

    parser.add_argument('-t', '--trial', type=str, default='1', help='the experiment id')

    # hyper parameters
    parser.add_argument('--gamma', type=float, default=2, help='loss cofficient for ssl loss')

    opt = parser.parse_args()

    if opt.dataset in ('CIFAR-FS', 'FC100'):
        opt.transform = 'D'

    if opt.use_trainval:
        opt.trial = opt.trial + '_trainval'

    # set the path according to the environment
    if not opt.model_path:
        opt.model_path = './models_pretrained'
    if not opt.tb_path:
        opt.tb_path = './tensorboard'
    if not opt.data_root:
        opt.data_root = './data/{}'.format(opt.dataset)
    else:
        opt.data_root = '{}/{}'.format(opt.data_root, opt.dataset)
    opt.data_aug = True

    opt.lr_decay_epochs = [int(it) for it in opt.lr_decay_epochs.split(',')]
    # Strip whitespace so the default "gen0, ssl" yields ['gen0', 'ssl'], not ['gen0', ' ssl'].
    opt.tags = [tag.strip() for tag in opt.tags.split(',') if tag.strip()]

    opt.model_name = '{}_{}_lr_{}_decay_{}_trans_{}'.format(opt.model, opt.dataset, opt.learning_rate,
                                                            opt.weight_decay, opt.transform)

    if opt.cosine:
        opt.model_name = '{}_cosine'.format(opt.model_name)

    if opt.adam:
        opt.model_name = '{}_useAdam'.format(opt.model_name)

    opt.model_name = '{}_trial_{}'.format(opt.model_name, opt.trial)

    opt.tb_folder = os.path.join(opt.tb_path, opt.model_name)
    os.makedirs(opt.tb_folder, exist_ok=True)

    opt.save_folder = os.path.join(opt.model_path, opt.model_name)
    os.makedirs(opt.save_folder, exist_ok=True)

    opt.n_gpu = torch.cuda.device_count()

    # extras
    opt.fresh_start = True

    return opt


def main():
    """Train with rotation self-supervision, evaluate every epoch and log to W&B.

    Per epoch:
      * set the learning rate and train one epoch;
      * compute validation accuracy/loss;
      * run meta-evaluation on the meta-val and meta-test loaders;
      * checkpoint every ``opt.save_freq`` epochs and at the last epoch;
      * log all metrics to wandb.

    Afterwards a final report is produced and wandb's ``output.log`` is deleted
    from the run directory.
    """
    opt = parse_option()
    # basename(normpath(...)) instead of split("/")[-1]: the default model_path 'save/'
    # ends with a slash, which made the project name an empty string.
    wandb.init(project=os.path.basename(os.path.normpath(opt.model_path)), tags=opt.tags)
    wandb.config.update(opt)
    wandb.save('*.py')
    wandb.run.save()

    train_loader, val_loader, meta_testloader, meta_valloader, n_cls = get_dataloaders(opt)

    # model
    model = create_model(opt.model, n_cls, opt.dataset)
    wandb.watch(model)

    # optimizer
    if opt.adam:
        print("Adam")
        # Honour --weight_decay (its default 5e-4 equals the value previously hard-coded).
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=opt.learning_rate,
                                     weight_decay=opt.weight_decay)
    else:
        print("SGD")
        optimizer = optim.SGD(model.parameters(),
                              lr=opt.learning_rate,
                              momentum=opt.momentum,
                              weight_decay=opt.weight_decay)

    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        if opt.n_gpu > 1:
            model = nn.DataParallel(model)
        model = model.cuda()
        criterion = criterion.cuda()
        cudnn.benchmark = True

    # set cosine annealing scheduler
    scheduler = None
    if opt.cosine:
        eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.epochs,
                                                               eta_min=eta_min, last_epoch=-1)

    # routine: supervised pre-training
    for epoch in range(1, opt.epochs + 1):
        if scheduler is None:
            adjust_learning_rate(epoch, opt, optimizer)
        print("==> training...")

        time1 = time.time()
        train_acc, train_acc5, train_loss = train(epoch, train_loader, model, criterion, optimizer, opt)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        if scheduler is not None:
            # Step after the epoch's optimizer updates; stepping before them (as before)
            # skips the first value of the cosine schedule.
            scheduler.step()

        val_acc, val_acc_top5, val_loss = validate(val_loader, model, criterion, opt)

        # validate
        start = time.time()
        meta_val_acc, meta_val_std, meta_val_acc5, meta_val_std5 = meta_test(model, meta_valloader, use_logit=True)
        test_time = time.time() - start
        print('Meta Val Acc : {:.4f}, Meta Val std: {:.4f}, Time: {:.1f}'.format(meta_val_acc, meta_val_std, test_time))

        # evaluate
        start = time.time()
        meta_test_acc, meta_test_std, meta_test_acc5, meta_test_std5 = meta_test(model, meta_testloader, use_logit=True)
        test_time = time.time() - start
        print('Meta Test Acc: {:.4f}, Meta Test std: {:.4f}, Time: {:.1f}'.format(meta_test_acc, meta_test_std, test_time))

        # regular saving
        if epoch % opt.save_freq == 0 or epoch == opt.epochs:
            print('==> Saving...')
            state = {
                'epoch': epoch,
                'optimizer': optimizer.state_dict(),
                'model': model.state_dict(),
            }
            save_file = os.path.join(opt.save_folder, 'model_' + str(wandb.run.name) + '.pth')
            torch.save(state, save_file)

            # wandb saving
            torch.save(state, os.path.join(wandb.run.dir, "model.pth"))

        wandb.log({'epoch': epoch,
                   'Train Acc': train_acc,
                   'Train Acc top 5': train_acc5,
                   'Train Loss': train_loss,
                   'Val Acc': val_acc,
                   'Val Acc top 5': val_acc_top5,
                   'Val Loss': val_loss,
                   'Meta Test Acc': meta_test_acc,
                   'Meta Test std': meta_test_std,
                   'Meta Val Acc': meta_val_acc,
                   'Meta Val std': meta_val_std,
                   'Meta Test Acc5': meta_test_acc5,
                   'Meta Test std5': meta_test_std5,
                   'Meta Val Acc5': meta_val_acc5,
                   'Meta Val std5': meta_val_std5
                   })

    # final report
    generate_final_report(model, opt, wandb)

    # remove output.log file from the wandb run directory
    output_log_file = os.path.join(wandb.run.dir, "output.log")
    if os.path.isfile(output_log_file):
        os.remove(output_log_file)
    else:
        print("Error: %s file not found" % output_log_file)
        
        
        
def _rotate_batch(x):
    """Stack ``x`` with three rotated copies along the batch axis.

    ``x`` is assumed to be (N, C, H, W) with H == W (rotations swap dims 2 and 3).
    Returns the 4*N stacked images, ordered [x, x_90, x_180, x_270], and int64
    rotation labels 0..3 (one per block of N) on ``x``'s device.
    """
    x_90 = x.transpose(2, 3).flip(2)
    x_180 = x.flip(2).flip(3)
    x_270 = x.flip(2).transpose(2, 3)
    rotated = torch.cat((x, x_90, x_180, x_270), 0)
    # Vectorised replacement for a per-element Python loop that wrote into a CUDA
    # tensor one index at a time (and crashed when CUDA was unavailable).
    rot_labels = torch.arange(4, device=x.device).repeat_interleave(x.size(0))
    return rotated, rot_labels


def train(epoch, train_loader, model, criterion, optimizer, opt):
    """Run one training epoch with an auxiliary rotation-prediction loss.

    Every batch is expanded with its three rotations. The model is trained on the
    class labels of all four copies (``criterion``) plus ``opt.gamma`` times a
    BCE-with-logits loss against the one-hot rotation labels.

    Returns:
        (top-1 accuracy, top-5 accuracy, mean loss) averaged over the epoch.
    """
    model.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    use_cuda = torch.cuda.is_available()
    end = time.time()
    with tqdm(train_loader, total=len(train_loader)) as pbar:
        for images, target, _ in pbar:
            data_time.update(time.time() - end)

            images = images.float()
            if use_cuda:
                images = images.cuda()
                target = target.cuda()

            generated_data, rot_labels = _rotate_batch(images)
            train_targets = target.repeat(4)
            n_samples = generated_data.size(0)

            # ===================forward=====================
            train_logit, rot_logits = model(generated_data, rot=True)
            rot_onehot = F.one_hot(rot_labels, 4).float()
            loss_ss = F.binary_cross_entropy_with_logits(rot_logits, rot_onehot)
            loss_ce = criterion(train_logit, train_targets)
            loss = opt.gamma * loss_ss + loss_ce

            # Metrics are over all 4*N (original + rotated) samples.
            acc1, acc5 = accuracy(train_logit, train_targets, topk=(1, 5))
            losses.update(loss.item(), n_samples)
            top1.update(acc1[0], n_samples)
            top5.update(acc5[0], n_samples)

            # ===================backward=====================
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # ===================meters=====================
            batch_time.update(time.time() - end)
            end = time.time()

            pbar.set_postfix({"Acc@1": '{0:.2f}'.format(float(top1.avg)),
                              "Acc@5": '{0:.2f}'.format(float(top5.avg)),
                              "Loss": '{0:.2f}'.format(float(losses.avg)),
                              })

    print('Train_Acc@1 {top1.avg:.3f} Train_Acc@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))

    return top1.avg, top5.avg, losses.avg


def generate_final_report(model, opt, wandb):
    """Run a final meta-evaluation on the meta-val and meta-test splits and log it.

    Sets ``opt.n_shots = 1`` on the caller's options object and rebuilds the
    dataloaders before evaluating. Each ``meta_test`` call returns two
    (accuracy, std) pairs; they are logged with the suffixes ``@1`` and ``@5``.
    """
    opt.n_shots = 1
    _train_loader, _val_loader, meta_testloader, meta_valloader, _ = get_dataloaders(opt)

    val_acc, val_std, val_acc5, val_std5 = meta_test(model, meta_valloader, use_logit=True)
    test_acc, test_std, test_acc5, test_std5 = meta_test(model, meta_testloader, use_logit=True)

    print('Meta Val Acc : {:.4f}, Meta Val std: {:.4f}'.format(val_acc, val_std))
    print('Meta Test Acc: {:.4f}, Meta Test std: {:.4f}'.format(test_acc, test_std))

    report = {
        'Final Meta Test Acc @1': test_acc,
        'Final Meta Test std @1': test_std,
        'Final Meta Test Acc @5': test_acc5,
        'Final Meta Test std @5': test_std5,
        'Final Meta Val Acc @1': val_acc,
        'Final Meta Val std @1': val_std,
        'Final Meta Val Acc @5': val_acc5,
        'Final Meta Val std @5': val_std5,
    }
    wandb.log(report)
    
    
# Start training only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()

In [None]:
%%writefile /kaggle/working/SKD/eval/meta_eval.py 

from __future__ import print_function

import numpy as np
import scipy
from scipy.stats import t
from tqdm import tqdm

import torch
from sklearn import metrics
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.ensemble import RandomForestClassifier

import torch
import torch.nn as nn
import sys, os
from collections import Counter


sys.path.append(os.path.abspath('..'))

from util import accuracy


def _mean_and_half_width(data, confidence):
    """Return (mean, Student-t confidence-interval half-width) of ``100 * data``."""
    a = 100.0 * np.array(data)
    n = len(a)
    mean, se = np.mean(a), scipy.stats.sem(a)
    # Public t.ppf instead of the private t._ppf (same value for valid inputs).
    return mean, se * t.ppf((1 + confidence) / 2.0, n - 1)


def mean_confidence_interval(data, data5, confidence=0.95):
    """Mean accuracy and CI half-width (in percent) for two sets of episode accuracies.

    Args:
        data: per-episode accuracies in [0, 1] for the first evaluation setting.
        data5: per-episode accuracies in [0, 1] for the second (``*5``) setting.
        confidence: two-sided confidence level of the Student-t interval.

    Returns:
        ``(mean, half_width, mean5, half_width5)``; a half-width is NaN when fewer
        than two values are given.
    """
    m, h = _mean_and_half_width(data, confidence)
    m5, h5 = _mean_and_half_width(data5, confidence)
    return m, h, m5, h5


def normalize(x):
    """Scale each row of ``x`` (along dim 1) to unit L2 norm.

    No epsilon is added, so all-zero rows produce NaN.
    """
    row_norm = x.pow(2).sum(1, keepdim=True).pow(1. / 2)
    return x / row_norm


def _embed(net, xs, use_logit):
    """Run ``net`` on ``xs`` and flatten the result to one feature row per sample."""
    if use_logit:
        return net(xs).view(xs.size(0), -1)
    feats, _ = net(xs, is_feat=True)
    return feats[-1].view(xs.size(0), -1)


def _fit_predict(classifier, support_features, support_ys, query_features):
    """Fit ``classifier`` on the support set and return predicted query labels."""
    if classifier == 'LR':
        # lbfgs fits a multinomial model for multi-class targets by default; the
        # explicit multi_class argument is deprecated in recent scikit-learn.
        clf = LogisticRegression(random_state=0, solver='lbfgs', max_iter=1000, penalty='l2')
        clf.fit(support_features, support_ys)
        return clf.predict(query_features)
    if classifier == 'NN':
        return NN(support_features, support_ys, query_features)
    if classifier == 'Cosine':
        return Cosine(support_features, support_ys, query_features)
    raise NotImplementedError('classifier not supported: {}'.format(classifier))


def _episode_accuracy(net, support_xs, support_ys, query_xs, query_ys,
                      use_logit, is_norm, classifier, device):
    """Embed one batch of episodes, fit on the support set, return query accuracy."""
    # Images arrive with an extra leading episode dimension (5-D); merge it into
    # the batch dimension.
    _, _, height, width, channel = support_xs.size()
    support_xs = support_xs.to(device).view(-1, height, width, channel)
    query_xs = query_xs.to(device).view(-1, height, width, channel)

    support_features = _embed(net, support_xs, use_logit)
    query_features = _embed(net, query_xs, use_logit)
    if is_norm:
        support_features = normalize(support_features)
        query_features = normalize(query_features)

    support_features = support_features.detach().cpu().numpy()
    query_features = query_features.detach().cpu().numpy()
    support_ys = support_ys.view(-1).numpy()
    query_ys = query_ys.view(-1).numpy()

    query_ys_pred = _fit_predict(classifier, support_features, support_ys, query_features)
    return metrics.accuracy_score(query_ys, query_ys_pred)


def meta_test(net, testloader, use_logit=False, is_norm=True, classifier='LR'):
    """Few-shot evaluation of ``net`` on the episodes produced by ``testloader``.

    Each loader item holds two episode sets, ``(support_xs, support_ys, query_xs,
    query_ys)`` and their ``*5`` counterparts. Both are evaluated the same way:
      * embed with ``net`` (logits if ``use_logit``, else the last feature map of
        ``net(x, is_feat=True)``);
      * optionally L2-normalise;
      * fit ``classifier`` on the support set and score the queries.

    Previously the ``*5`` set was only handled for ``use_logit=True`` with the 'LR'
    classifier; any other combination failed with a NameError.

    Returns:
        ``mean_confidence_interval(acc, acc5)``: mean accuracy (%) and 95% CI
        half-width for the first episode set, then for the ``*5`` set.
    """
    net = net.eval()
    # Use the network's own device instead of assuming CUDA.
    try:
        device = next(net.parameters()).device
    except StopIteration:
        device = torch.device('cpu')

    acc = []
    acc5 = []

    with torch.no_grad():
        with tqdm(testloader, total=len(testloader)) as pbar:
            for data in pbar:
                (support_xs, support_ys, support_xs5, support_ys5,
                 query_xs, query_ys, query_xs5, query_ys5) = data

                episode_acc = _episode_accuracy(net, support_xs, support_ys, query_xs, query_ys,
                                                use_logit, is_norm, classifier, device)
                episode_acc5 = _episode_accuracy(net, support_xs5, support_ys5, query_xs5, query_ys5,
                                                 use_logit, is_norm, classifier, device)
                acc.append(episode_acc)
                acc5.append(episode_acc5)

                pbar.set_postfix({"FSL_Acc": '{0:.2f}'.format(episode_acc)})

    return mean_confidence_interval(acc, acc5)




def meta_test_tune(net, testloader, use_logit=False, is_norm=True, classifier='LR', lamda=0.2):
    """Few-shot evaluation with a closed-form ridge-regression head.

    For every episode the support/query embeddings are extracted with ``net``,
    optionally normalised, and a ridge regressor
    ``B = (X^T X + lamda * I)^-1 X^T Y`` is fit from the support embeddings ``X``
    to one-hot support labels ``Y``. Queries are scored with
    ``sigmoid(query_features @ B)`` and their top-1 accuracy is recorded.

    Args:
        net: model called as ``net(x)`` when ``use_logit`` is True, otherwise as
            ``net(x, is_feat=True)`` (the last returned feature map is used).
        testloader: iterable of episodes unpacked as
            ``(support_xs, support_ys, query_xs, query_ys, support_ts, query_ts)``.
            NOTE(review): ``MetaCIFAR100`` in this notebook yields 8-tuples, so it
            does not match this unpacking -- confirm which loader is intended.
        use_logit: embed with the network output instead of the feature map.
        is_norm: apply ``normalize`` to both embeddings before fitting.
        classifier: unused; accepted for signature parity with the other
            evaluation functions.
        lamda: ridge regularisation strength.

    Returns:
        ``mean_confidence_interval`` of the per-episode top-1 accuracies
        (``accuracy`` output divided by 100, i.e. assumes ``accuracy`` returns
        percentages).
    """
    net = net.eval()
    acc = []

    # The ridge head is solved in closed form, so no autograd graph is needed.
    with torch.no_grad():
        with tqdm(testloader, total=len(testloader)) as pbar:
            for data in pbar:
                support_xs, support_ys, query_xs, query_ys, _support_ts, _query_ts = data

                support_xs = support_xs.cuda()
                support_ys = support_ys.cuda()
                query_ys = query_ys.cuda()
                query_xs = query_xs.cuda()
                _, _, height, width, channel = support_xs.size()
                support_xs = support_xs.view(-1, height, width, channel)
                support_ys = support_ys.view(-1, 1)
                query_ys = query_ys.view(-1)
                query_xs = query_xs.view(-1, height, width, channel)

                if use_logit:
                    support_features = net(support_xs).view(support_xs.size(0), -1)
                    query_features = net(query_xs).view(query_xs.size(0), -1)
                else:
                    feat_support, _ = net(support_xs, is_feat=True)
                    support_features = feat_support[-1].view(support_xs.size(0), -1)
                    feat_query, _ = net(query_xs, is_feat=True)
                    query_features = feat_query[-1].view(query_xs.size(0), -1)

                if is_norm:
                    support_features = normalize(support_features)
                    query_features = normalize(query_features)

                X = support_features
                # Derive sizes from the data instead of hard-coding 5 ways / 640 dims,
                # so other backbones and n-way settings work.
                feat_dim = X.size(1)
                n_ways = int(support_ys.max().item()) + 1

                y_onehot = torch.zeros(support_ys.size(0), n_ways, device=X.device, dtype=X.dtype)
                y_onehot.scatter_(1, support_ys.long(), 1)

                XTX = X.t() @ X
                eye = torch.eye(feat_dim, device=X.device, dtype=X.dtype)
                # Closed-form ridge regression: B = (X^T X + lamda I)^-1 X^T Y.
                B = (XTX + lamda * eye).inverse() @ (X.t() @ y_onehot)
                # Sigmoid is monotonic, so it does not change the top-1 prediction.
                Y_pred = torch.sigmoid(query_features @ B)

                acc1, _ = accuracy(Y_pred, query_ys, topk=(1, 1))
                acc.append(acc1.item() / 100.0)

                pbar.set_postfix({"FSL_Acc": '{0:.4f}'.format(np.mean(acc))})

    return mean_confidence_interval(acc)



def meta_test_ensamble(net, testloader, use_logit=True, is_norm=True, classifier='LR'):
    """Few-shot evaluation of an ensemble of networks.

    The flattened outputs of every member of ``net`` are summed to form the
    support/query embeddings, which are optionally normalised and passed to a
    per-episode classifier fit on the support set.

    Args:
        net: non-empty sequence of models. With ``use_logit`` each member is
            called as ``m(x)``; otherwise as ``m(x, is_feat=True)`` and the last
            returned feature map is used.
        testloader: iterable of episodes ``(support_xs, support_ys, query_xs, query_ys)``.
        use_logit: embed with the network outputs instead of the feature maps.
        is_norm: apply ``normalize`` to both embeddings.
        classifier: ``'LR'`` (multinomial logistic regression), ``'NN'`` or ``'Cosine'``.

    Returns:
        ``mean_confidence_interval`` of the per-episode query accuracies.

    Raises:
        ValueError: if ``net`` is empty.
        NotImplementedError: for an unknown ``classifier``.
    """
    models = list(net)
    if not models:
        raise ValueError('meta_test_ensamble needs at least one network')
    for model in models:
        model.eval()

    def embed(x):
        """Sum of the flattened outputs of all ensemble members on ``x``."""
        total = None
        for model in models:
            if use_logit:
                out = model(x)
            else:
                # ``net`` is a sequence and cannot be called directly; sum each
                # member's last feature map instead.
                feat, _ = model(x, is_feat=True)
                out = feat[-1]
            out = out.view(x.size(0), -1)
            total = out if total is None else total + out
        return total

    acc = []
    with torch.no_grad():
        with tqdm(testloader, total=len(testloader)) as pbar:
            for data in pbar:
                support_xs, support_ys, query_xs, query_ys = data

                support_xs = support_xs.cuda()
                query_xs = query_xs.cuda()
                _, _, height, width, channel = support_xs.size()
                support_xs = support_xs.view(-1, height, width, channel)
                query_xs = query_xs.view(-1, height, width, channel)

                support_features = embed(support_xs)
                query_features = embed(query_xs)

                if is_norm:
                    support_features = normalize(support_features)
                    query_features = normalize(query_features)

                support_features = support_features.detach().cpu().numpy()
                query_features = query_features.detach().cpu().numpy()

                support_ys = support_ys.view(-1).numpy()
                query_ys = query_ys.view(-1).numpy()

                if classifier == 'LR':
                    clf = LogisticRegression(random_state=0, solver='lbfgs', max_iter=1000,
                                             multi_class='multinomial')
                    clf.fit(support_features, support_ys)
                    query_ys_pred = clf.predict(query_features)
                elif classifier == 'NN':
                    query_ys_pred = NN(support_features, support_ys, query_features)
                elif classifier == 'Cosine':
                    query_ys_pred = Cosine(support_features, support_ys, query_features)
                else:
                    raise NotImplementedError('classifier not supported: {}'.format(classifier))

                episode_acc = metrics.accuracy_score(query_ys, query_ys_pred)
                acc.append(episode_acc)

                pbar.set_postfix({"FSL_Acc": '{0:.2f}'.format(episode_acc)})

    return mean_confidence_interval(acc)


def NN(support, support_ys, query):
    """Nearest-neighbour classifier under squared Euclidean distance.

    Args:
        support: ``(n_support, dim)`` array of support embeddings.
        support_ys: labels indexable by support row.
        query: ``(n_query, dim)`` array of query embeddings.

    Returns:
        list holding, for each query row, the label of its closest support row
        (the first such row on ties).
    """
    # (n_query, dim, 1) - (1, dim, n_support) -> (n_query, dim, n_support)
    delta = query[:, :, np.newaxis] - support.T[np.newaxis, :, :]
    sq_dist = np.square(delta).sum(axis=1)  # (n_query, n_support)
    nearest = sq_dist.argmin(axis=1)
    labels = []
    for row in nearest:
        labels.append(support_ys[row])
    return labels


def Cosine(support, support_ys, query):
    """Cosine-similarity nearest-neighbour classifier.

    Rows of ``support`` and ``query`` are scaled to unit L2 norm; each query row
    takes the label of the support row with the largest dot product (the first
    such row on ties). A zero-norm row yields NaN similarities.

    Returns:
        list of labels drawn from ``support_ys``, one per query row.
    """
    def _unit_rows(mat):
        return mat / np.linalg.norm(mat, axis=1, keepdims=True)

    similarity = _unit_rows(query) @ _unit_rows(support).transpose()
    best = similarity.argmax(axis=1)
    return [support_ys[col] for col in best]

In [None]:
%%writefile /kaggle/working/SKD/dataset/cifar.py 

from __future__ import print_function

import os
import pickle
from PIL import Image
import numpy as np

import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset


class CIFAR100(Dataset):
    """Flat (non-episodic) dataset over a FC100 / CIFAR-FS pickle split.

    Reads ``<data_root>/<partition>.pickle`` (a dict with ``'data'`` images and
    ``'labels'``), remaps the labels to contiguous ids ``0..n-1`` in order of
    first appearance, and yields ``(img, target, index)``. With ``args.simclr``
    two independently transformed views are returned instead of one image; with
    ``is_sample`` a positive index plus ``k`` negative indices are appended.
    """

    def __init__(self, args, partition='train', pretrain=True, is_sample=False, k=4096,
                 transform=None):
        super(CIFAR100, self).__init__()
        self.data_root = args.data_root
        self.partition = partition
        self.data_aug = args.data_aug
        self.mean = [0.5071, 0.4867, 0.4408]
        self.std = [0.2675, 0.2565, 0.2761]
        self.normalize = transforms.Normalize(mean=self.mean, std=self.std)
        self.pretrain = pretrain
        self.simclr = args.simclr

        if transform is not None:
            self.transform = transform
        elif self.partition == 'train' and self.data_aug:
            self.transform = transforms.Compose([
                lambda x: Image.fromarray(x),
                transforms.RandomCrop(32, padding=4),
                transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
                transforms.RandomHorizontalFlip(),
                lambda x: np.asarray(x),
                transforms.ToTensor(),
                self.normalize
            ])
        else:
            self.transform = transforms.Compose([
                lambda x: Image.fromarray(x),
                transforms.ToTensor(),
                self.normalize
            ])

        # Pretrain and episodic modes use the same file naming scheme.
        self.file_pattern = '%s.pickle'
        self.data = {}

        # NOTE: pickle.load can execute arbitrary code; only load trusted dataset files.
        with open(os.path.join(self.data_root, self.file_pattern % partition), 'rb') as f:
            data = pickle.load(f, encoding='latin1')
        self.imgs = data['data']
        # Remap sparse labels to contiguous ids 0..n-1 in order of first appearance.
        label2label = {}
        for label in data['labels']:
            label2label.setdefault(label, len(label2label))
        self.labels = [label2label[label] for label in data['labels']]

        # Pre-process for contrastive sampling.
        self.k = k
        self.is_sample = is_sample
        if self.is_sample:
            self.labels = np.asarray(self.labels)
            self.labels = self.labels - np.min(self.labels)
            num_classes = np.max(self.labels) + 1

            self.cls_positive = [[] for _ in range(num_classes)]
            for i in range(len(self.imgs)):
                self.cls_positive[self.labels[i]].append(i)

            self.cls_negative = [[] for _ in range(num_classes)]
            for i in range(num_classes):
                for j in range(num_classes):
                    if j == i:
                        continue
                    self.cls_negative[i].extend(self.cls_positive[j])

            self.cls_positive = [np.asarray(self.cls_positive[i]) for i in range(num_classes)]
            self.cls_negative = [np.asarray(self.cls_negative[i]) for i in range(num_classes)]
            # NOTE(review): stacking into one ndarray assumes equal class sizes;
            # NumPy >= 1.24 raises ValueError for ragged rows.
            self.cls_positive = np.asarray(self.cls_positive)
            self.cls_negative = np.asarray(self.cls_negative)

        # Cached once: recomputing min(self.labels) per item is O(N) per sample.
        self._label_offset = min(self.labels) if len(self.labels) else 0

    def __getitem__(self, item):
        img = np.asarray(self.imgs[item]).astype('uint8')
        target = self.labels[item] - self._label_offset

        if self.simclr:
            # Two independently augmented views of the same image.
            return (self.transform(img), self.transform(img)), target, item

        img = self.transform(img)
        if not self.is_sample:
            return img, target, item

        pos_idx = item
        negatives = self.cls_negative[target]
        replace = self.k > len(negatives)
        neg_idx = np.random.choice(negatives, self.k, replace=replace)
        sample_idx = np.hstack((np.asarray([pos_idx]), neg_idx))
        return img, target, item, sample_idx

    def __len__(self):
        return len(self.labels)
    
    
    
    
class CIFAR100_toy(Dataset):
    """Small debugging subset of a FC100 / CIFAR-FS pickle split.

    Loads ``<data_root>/<partition>.pickle`` and remaps labels to contiguous ids
    in order of first appearance (as ``CIFAR100`` does), then keeps only the
    samples whose remapped label is below ``n_toy_classes`` (5 by default).
    """

    def __init__(self, args, partition='train', pretrain=True, is_sample=False, k=4096,
                 transform=None, n_toy_classes=5):
        super(CIFAR100_toy, self).__init__()
        self.data_root = args.data_root
        self.partition = partition
        self.data_aug = args.data_aug
        self.mean = [0.5071, 0.4867, 0.4408]
        self.std = [0.2675, 0.2565, 0.2761]
        self.normalize = transforms.Normalize(mean=self.mean, std=self.std)
        self.pretrain = pretrain
        self.simclr = args.simclr

        if transform is not None:
            self.transform = transform
        elif self.partition == 'train' and self.data_aug:
            self.transform = transforms.Compose([
                lambda x: Image.fromarray(x),
                transforms.RandomCrop(32, padding=4),
                transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
                transforms.RandomHorizontalFlip(),
                lambda x: np.asarray(x),
                transforms.ToTensor(),
                self.normalize
            ])
        else:
            self.transform = transforms.Compose([
                lambda x: Image.fromarray(x),
                transforms.ToTensor(),
                self.normalize
            ])

        # Pretrain and episodic modes use the same file naming scheme.
        self.file_pattern = '%s.pickle'
        self.data = {}

        # NOTE: pickle.load can execute arbitrary code; only load trusted dataset files.
        with open(os.path.join(self.data_root, self.file_pattern % partition), 'rb') as f:
            data = pickle.load(f, encoding='latin1')
        # Remap sparse labels to contiguous ids 0..n-1 in order of first appearance.
        label2label = {}
        for label in data['labels']:
            label2label.setdefault(label, len(label2label))
        labels = np.array([label2label[label] for label in data['labels']])
        imgs = np.array(data['data'])

        # Keep only the first ``n_toy_classes`` remapped classes.
        keep = np.where(labels < n_toy_classes)[0]
        self.labels = labels[keep]
        self.imgs = imgs[keep]

        self.k = k
        self.is_sample = is_sample
        if self.is_sample:
            # Per-class index tables used by __getitem__ for negative sampling
            # (previously never built, so is_sample=True failed on first access).
            num_classes = int(self.labels.max()) + 1 if len(self.labels) else 0
            self.cls_positive = [np.where(self.labels == c)[0] for c in range(num_classes)]
            self.cls_negative = [
                np.concatenate([np.empty(0, dtype=np.int64)] +
                               [self.cls_positive[j] for j in range(num_classes) if j != c])
                for c in range(num_classes)
            ]

        # Cached once: recomputing min(self.labels) per item is O(N) per sample.
        self._label_offset = self.labels.min() if len(self.labels) else 0

    def __getitem__(self, item):
        img = np.asarray(self.imgs[item]).astype('uint8')
        target = self.labels[item] - self._label_offset

        if self.simclr:
            # Two independently augmented views of the same image.
            return (self.transform(img), self.transform(img)), target, item

        img = self.transform(img)
        if not self.is_sample:
            return img, target, item

        pos_idx = item
        negatives = self.cls_negative[target]
        replace = self.k > len(negatives)
        neg_idx = np.random.choice(negatives, self.k, replace=replace)
        sample_idx = np.hstack((np.asarray([pos_idx]), neg_idx))
        return img, target, item, sample_idx

    def __len__(self):
        return len(self.labels)
    
    
    


class MetaCIFAR100(CIFAR100):
    """Episodic few-shot sampler over a FC100 / CIFAR-FS split.

    Each item is one ``n_ways``-way episode. Besides the configured
    ``n_shots``-shot episode, a parallel 5-shot episode over the same sampled
    classes is drawn; its outputs carry a ``5`` suffix.
    """

    def __init__(self, args, partition='train', train_transform=None, test_transform=None, fix_seed=True):
        super(MetaCIFAR100, self).__init__(args, partition, False)
        self.fix_seed = fix_seed
        self.n_ways = args.n_ways
        self.n_shots = args.n_shots
        self.n_queries = args.n_queries
        self.n_test_runs = args.n_test_runs
        self.n_aug_support_samples = args.n_aug_support_samples
        if train_transform is None:
            self.train_transform = transforms.Compose([
                lambda x: Image.fromarray(x),
                transforms.RandomCrop(32, padding=4),
                transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
                transforms.RandomHorizontalFlip(),
                lambda x: np.asarray(x),
                transforms.ToTensor(),
                self.normalize
            ])
        else:
            self.train_transform = train_transform

        if test_transform is None:
            self.test_transform = transforms.Compose([
                lambda x: Image.fromarray(x),
                transforms.ToTensor(),
                self.normalize
            ])
        else:
            self.test_transform = test_transform

        # Group images by (remapped) label so episodes can sample per class.
        self.data = {}
        for idx in range(self.imgs.shape[0]):
            self.data.setdefault(self.labels[idx], []).append(self.imgs[idx])
        self.classes = list(self.data.keys())

    def _finalize_split(self, xs, ys, transform, flatten_labels):
        """Stack per-class samples, tile by ``n_aug_support_samples``, transform.

        ``xs`` is a per-class list of ``(k, H, W, C)`` uint8 image arrays and
        ``ys`` the matching per-class label lists. Labels keep their
        ``(n_ways, k)`` shape unless ``flatten_labels`` is set or the split is
        tiled. Returns ``(stacked transformed images, label ndarray)``.
        """
        xs = np.array(xs)
        ys = np.array(ys)
        xs = xs.reshape((-1,) + xs.shape[-3:])
        if flatten_labels:
            ys = ys.reshape((-1,))
        if self.n_aug_support_samples > 1:
            xs = np.tile(xs, (self.n_aug_support_samples, 1, 1, 1))
            ys = np.tile(ys.reshape((-1,)), self.n_aug_support_samples)
        return torch.stack([transform(x.squeeze()) for x in xs]), ys

    def __getitem__(self, item):
        """Sample episode ``item``.

        Returns:
            ``(support_xs, support_ys, support_xs5, support_ys5,
            query_xs, query_ys, query_xs5, query_ys5)``: the ``*_xs`` are stacked
            transformed image tensors, the ``*_ys`` NumPy arrays of episode-local
            labels ``0..n_ways-1``.
        """
        # A private RandomState seeded with ``item`` reproduces exactly the draws
        # of ``np.random.seed(item)`` without resetting the global NumPy RNG.
        rng = np.random.RandomState(item) if self.fix_seed else np.random
        cls_sampled = rng.choice(self.classes, self.n_ways, False)

        support_xs, support_ys, support_xs5, support_ys5 = [], [], [], []
        query_xs, query_ys, query_xs5, query_ys5 = [], [], [], []

        for idx, cls in enumerate(cls_sampled):
            imgs = np.asarray(self.data[cls]).astype('uint8')
            all_ids = np.arange(imgs.shape[0])
            # Draw order (support, 5-shot support, query, 5-shot query) is what a
            # fixed seed maps to; keep it so seeded episodes stay reproducible.
            support_ids = rng.choice(range(imgs.shape[0]), self.n_shots, False)
            support_ids5 = rng.choice(range(imgs.shape[0]), 5, False)
            query_ids = rng.choice(np.setxor1d(all_ids, support_ids), self.n_queries, False)
            query_ids5 = rng.choice(np.setxor1d(all_ids, support_ids5), self.n_queries, False)

            support_xs.append(imgs[support_ids])
            support_ys.append([idx] * self.n_shots)
            support_xs5.append(imgs[support_ids5])
            support_ys5.append([idx] * 5)
            query_xs.append(imgs[query_ids])
            query_ys.append([idx] * query_ids.shape[0])
            query_xs5.append(imgs[query_ids5])
            query_ys5.append([idx] * query_ids5.shape[0])

        support_xs, support_ys = self._finalize_split(
            support_xs, support_ys, self.train_transform, flatten_labels=False)
        support_xs5, support_ys5 = self._finalize_split(
            support_xs5, support_ys5, self.train_transform, flatten_labels=False)
        query_xs, query_ys = self._finalize_split(
            query_xs, query_ys, self.test_transform, flatten_labels=True)
        query_xs5, query_ys5 = self._finalize_split(
            query_xs5, query_ys5, self.test_transform, flatten_labels=True)

        return support_xs, support_ys, support_xs5, support_ys5, query_xs, query_ys, query_xs5, query_ys5

    def __len__(self):
        return self.n_test_runs


if __name__ == '__main__':
    # Smoke test: build both datasets and print a few sample shapes.
    from types import SimpleNamespace

    args = SimpleNamespace(
        n_ways=5,
        n_shots=1,
        n_queries=12,
        # Set FC100_ROOT instead of editing this path.
        data_root=os.environ.get('FC100_ROOT', '/home/yonglong/Downloads/FC100'),
        data_aug=True,
        simclr=False,  # read by CIFAR100.__init__
        n_test_runs=5,
        n_aug_support_samples=1,
    )
    imagenet = CIFAR100(args, 'train')
    print(len(imagenet))
    print(imagenet[500][0].shape)

    metaimagenet = MetaCIFAR100(args, 'train')
    print(len(metaimagenet))
    episode = metaimagenet[500]
    print(episode[0].size())
    print(episode[1].shape)
    print(episode[2].size())
    print(episode[3].shape)

In [None]:
#!wandb off

In [None]:
# Runs train_selfsupervison.py for VAE_resnet18_ssl on CIFAR-FS (data under /kaggle/input/cifar-fs), 65 epochs.
# --model_path presumably sets the checkpoint directory -- confirm in train_selfsupervison.py.
!python /kaggle/working/SKD/train_selfsupervison.py --model VAE_resnet18_ssl --model_path save/CIFAR-FS --dataset CIFAR-FS --data_root /kaggle/input/cifar-fs --epochs 65 --lr_decay_epochs 60 

In [None]:
#!python /kaggle/working/SKD/train_selfsupervison.py --model VAE_resnet18_ssl --model_path save/FC100 --dataset FC100 --data_root /kaggle/input/few-shotcifar-100 --epochs 20 --lr_decay_epochs 8