<a href="https://colab.research.google.com/github/anna-alt/AI-Lab/blob/main/fewshot.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
import torch
from torch.autograd import Variable as V
import torchvision.models as models
from torchvision import transforms as trn
from torch.nn import functional as F
import os
import numpy as np
import cv2
from PIL import Image
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

import torch.nn.functional as F
!pip install pyg-lib torch_scatter torch_sparse -f https://data.pyg.org/whl/torch-1.13.0+cu116.html
!pip install torch_geometric
from torch_geometric.nn import GCNConv

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://data.pyg.org/whl/torch-1.13.0+cu116.html
Collecting pyg-lib
  Downloading https://data.pyg.org/whl/torch-1.13.0%2Bcu116/pyg_lib-0.1.0%2Bpt113cu116-cp38-cp38-linux_x86_64.whl (1.9 MB)
[K     |████████████████████████████████| 1.9 MB 5.3 MB/s 
[?25hCollecting torch_scatter
  Downloading https://data.pyg.org/whl/torch-1.13.0%2Bcu116/torch_scatter-2.1.0%2Bpt113cu116-cp38-cp38-linux_x86_64.whl (9.4 MB)
[K     |████████████████████████████████| 9.4 MB 49.5 MB/s 
[?25hCollecting torch_sparse
  Downloading https://data.pyg.org/whl/torch-1.13.0%2Bcu116/torch_sparse-0.6.15%2Bpt113cu116-cp38-cp38-linux_x86_64.whl (4.6 MB)
[K     |████████████████████████████████| 4.6 MB 42.3 MB/s 
Installing collected packages: torch-sparse, torch-scatter, pyg-lib
Successfully installed pyg-lib-0.1.0+pt113cu116 torch-scatter-2.1.0+pt113cu116 torch-sparse-0.6.15+pt113cu116
Looking in in

In [5]:
class Graph_conv_block(nn.Module):
    """One graph-convolution step: mix node features with ``A``, then project.

    Args:
        input_dim: width of the incoming node features.
        output_dim: width of the projected node features.
        use_bn: when True, a ``BatchNorm1d`` over the feature axis follows the
            linear projection; when False ``self.bn`` is ``None``.
    """

    def __init__(self, input_dim, output_dim, use_bn=True):
        super().__init__()

        self.weight = nn.Linear(input_dim, output_dim)
        self.bn = nn.BatchNorm1d(output_dim) if use_bn else None

    def forward(self, x, A):
        """Return ``Linear(A @ x)``, optionally batch-normalised.

        Args:
            x: node features, (b, N, input_dim).
            A: adjacency weights, (b, N, N).

        Returns:
            Tensor of shape (b, N, output_dim).
        """
        projected = self.weight(A.matmul(x))  # (b, N, output_dim)
        if self.bn is None:
            return projected

        # BatchNorm1d normalises dim 1, so move the feature axis there and back.
        channels_first = projected.transpose(1, 2).contiguous()  # (b, output_dim, N)
        return self.bn(channels_first).transpose(1, 2)  # (b, N, output_dim)

class Adjacency_layer(nn.Module):
    """Learns a row-normalised adjacency matrix from pairwise feature distances.

    A stack of 1x1 ``Conv2d -> BatchNorm2d -> LeakyReLU`` stages (one per
    entry of ``ratio``, with ``hidden_dim * ratio[i]`` channels) followed by a
    final 1x1 convolution down to a single channel scores every node pair.

    Args:
        input_dim: width of the node features.
        hidden_dim: base channel count of the scoring network.
        ratio: per-stage multipliers applied to ``hidden_dim``.
    """

    def __init__(self, input_dim, hidden_dim, ratio=[2,2,1,1]):

        super().__init__()

        layers = []
        in_channels = input_dim
        for multiplier in ratio:
            out_channels = hidden_dim * multiplier
            layers.extend([
                nn.Conv2d(in_channels, out_channels, 1, 1),
                nn.BatchNorm2d(out_channels),
                nn.LeakyReLU(),
            ])
            in_channels = out_channels

        # Collapse to one score per node pair.
        layers.append(nn.Conv2d(hidden_dim * ratio[-1], 1, 1, 1))

        self.module_list = nn.ModuleList(layers)

    def forward(self, x):
        """Map node features (b, N, input_dim) to adjacency weights (b, N, N)."""
        nodes = x.unsqueeze(2)  # (b, N, 1, input_dim)

        # |x_i - x_j| for every pair of nodes.
        pairwise = (nodes - nodes.transpose(1, 2)).abs()  # (b, N, N, input_dim)

        scores = pairwise.transpose(1, 3)  # (b, input_dim, N, N)
        for layer in self.module_list:
            scores = layer(scores)
        # (b, 1, N, N)

        scores = scores.transpose(1, 3)  # (b, N, N, 1)

        # Softmax over dim 2 normalises each node's outgoing weights.
        return F.softmax(scores, 2).squeeze(3)  # (b, N, N)

class GNN_module(nn.Module):
    """Stack of (adjacency, graph-conv) layers producing ``nway`` logits.

    Two feature-propagation schemes are supported:

    * ``'dense'``: every layer emits ``hidden_dim // 2`` new features that are
      concatenated onto the running node features.
    * ``'forward'``: every layer replaces the node features with
      ``hidden_dim`` new features.

    Args:
        nway: number of output logits per node.
        input_dim: width of the incoming node features.
        hidden_dim: base width of the hidden layers.
        num_layers: number of hidden (adjacency, conv) pairs before the
            final classification pair.
        feature_type: ``'dense'`` or ``'forward'``.

    Raises:
        NotImplementedError: for any other ``feature_type``.
    """

    def __init__(self, nway, input_dim, hidden_dim, num_layers, feature_type='dense'):
        super(GNN_module, self).__init__()

        self.feature_type = feature_type
        if feature_type not in ('dense', 'forward'):
            raise NotImplementedError
        dense = feature_type == 'dense'

        adjacency_list = []
        graph_conv_list = []

        # ratio = [2, 2, 1, 1]
        ratio = [2, 1]

        # `width` is the feature width each layer actually receives.  Tracking
        # it in one place keeps the adjacency and conv blocks consistent; the
        # previous 'forward' branch built the first conv block with
        # input_dim=hidden_dim although it is fed input_dim-wide features.
        width = input_dim
        for _ in range(num_layers):
            out_width = hidden_dim // 2 if dense else hidden_dim

            adjacency_list.append(Adjacency_layer(
                input_dim=width,
                hidden_dim=hidden_dim,
                ratio=ratio))

            graph_conv_list.append(Graph_conv_block(
                input_dim=width,
                output_dim=out_width))

            # dense: features accumulate; forward: features are replaced.
            width = width + out_width if dense else out_width

        # last layer: maps the final node features to nway logits (no BN).
        last_adjacency = Adjacency_layer(
            input_dim=width,
            hidden_dim=hidden_dim,
            ratio=ratio)

        last_conv = Graph_conv_block(
            input_dim=width,
            output_dim=nway,
            use_bn=False)

        self.adjacency_list = nn.ModuleList(adjacency_list)
        self.graph_conv_list = nn.ModuleList(graph_conv_list)
        self.last_adjacency = last_adjacency
        self.last_conv = last_conv

    def forward(self, x):
        """Run the graph network on node features ``x`` of shape (b, N, input_dim).

        Returns:
            The ``nway`` logits of node 0 only, shape (b, nway).
        """
        for adjacency_layer, conv_block in zip(self.adjacency_list, self.graph_conv_list):
            A = adjacency_layer(x)

            x_next = F.leaky_relu(conv_block(x, A), 0.1)

            if self.feature_type == 'dense':
                x = torch.cat([x, x_next], dim=2)
            else:
                x = x_next

        A = self.last_adjacency(x)
        out = self.last_conv(x, A)

        return out[:, 0, :]

In [14]:
import os
import time
import random
import skimage.io
import numpy as np

import torch
from torch.utils.data import Dataset
import torchvision as tv
from torchvision.datasets import CIFAR100


class self_Dataset(Dataset):
    """Minimal map-style dataset over parallel ``data`` / ``label`` sequences.

    When no labels are supplied every item is paired with the dummy label 1.
    """

    def __init__(self, data, label=None):
        super().__init__()

        self.data, self.label = data, label

    def __getitem__(self, index):
        """Return ``(data[index], label[index])`` or ``(data[index], 1)``."""
        sample = self.data[index]
        target = 1 if self.label is None else self.label[index]
        return sample, target

    def __len__(self):
        """Number of samples."""
        return len(self.data)

def count_data(data_dict):
    """Return the total number of items across all value lists of ``data_dict``."""
    return sum(len(data_dict[key]) for key in data_dict.keys())

class self_DataLoader(Dataset):
    """Samples N-way K-shot episodes from CIFAR-100.

    ``nway`` of the 100 labels are picked deterministically from ``seed`` and
    form the "few" split; all remaining labels form the "full" split.
    Training episodes (:meth:`load_tr_batch`) are drawn from the full split,
    test episodes (:meth:`load_te_batch`) from the few split.

    Despite its name this class is not a ``torch.utils.data.DataLoader``; it
    subclasses ``Dataset`` but exposes batches only through the ``load_*``
    methods.

    Args:
        root: directory where CIFAR-100 is stored / downloaded.
        train: use the CIFAR-100 training split if True, else the test split.
        dataset: only ``'cifar100'`` is supported.
        seed: seed for choosing the held-out "few" labels.
        nway: number of labels held out into the few split.
    """

    def __init__(self, root, train=True, dataset='cifar100', seed=1, nway=5):
        super(self_DataLoader, self).__init__()

        self.seed = seed
        self.nway = nway
        self.num_labels = 100
        self.input_channels = 3
        self.size = 32

        self.transform = tv.transforms.Compose([
            tv.transforms.ToTensor(),
            tv.transforms.Normalize([0.5071, 0.4866, 0.4409],
                [0.2673, 0.2564, 0.2762])
            ])

        self.full_data_dict, self.few_data_dict = self.load_data(root, train, dataset)

        print('full_data_num: %d' % count_data(self.full_data_dict))
        print('few_data_num: %d' % count_data(self.few_data_dict))

    def load_data(self, root, train, dataset):
        """Load the dataset and split it by label into full / few dictionaries.

        Returns:
            ``(full_data_dict, few_data_dict)``; each maps a label to the list
            of transformed image tensors carrying that label.

        Raises:
            NotImplementedError: if ``dataset`` is not ``'cifar100'``.
        """
        if dataset != 'cifar100':
            raise NotImplementedError

        few_selected_label = random.Random(self.seed).sample(range(self.num_labels), self.nway)
        print('selected labeled', few_selected_label)
        few_labels = set(few_selected_label)  # O(1) membership test per image

        full_data_dict = {}
        few_data_dict = {}

        d = CIFAR100(root, train=train, download=True)

        # Count with enumerate(start=1) so an empty dataset reports 0 instead
        # of failing on an unbound loop variable.
        num_images = 0
        for num_images, (data, label) in enumerate(d, start=1):
            data = self.transform(data)

            data_dict = few_data_dict if label in few_labels else full_data_dict
            data_dict.setdefault(label, []).append(data)
        print(num_images)

        return full_data_dict, few_data_dict

    def load_batch_data(self, train=True, batch_size=16, nway=5, num_shots=1):
        """Sample ``batch_size`` independent ``nway``-way ``num_shots``-shot episodes.

        Within an episode the sampled classes get "fake" labels ``0..nway-1``.
        One of them is chosen at random as the query class: it contributes one
        extra image (the query) on top of its ``num_shots`` support images.
        The support set of each episode is shuffled.

        Returns:
            A list of eight tensors, in this order:

            * ``x``: query images, (batch, C, H, W)
            * ``label_y``: fake label of each query, (batch,)
            * ``one_hot_y``: one-hot of ``label_y``, (batch, nway)
            * ``class_y``: real label of each query, (batch,)
            * ``xi``: support images, (batch, nway * num_shots, C, H, W)
            * ``label_yi``: fake labels of the support images,
              (batch, nway * num_shots)
            * ``one_hot_yi``: one-hot of ``label_yi``,
              (batch, nway * num_shots, nway)
            * ``map_label2class``: real label behind each fake label,
              (batch, nway)
        """
        data_dict = self.full_data_dict if train else self.few_data_dict

        # random.sample() requires a sequence; passing dict keys directly
        # raises TypeError on Python >= 3.11.
        class_pool = list(data_dict.keys())

        x = []
        label_y = [] # fake label: from 0 to (nway - 1)
        one_hot_y = [] # one hot for fake label
        class_y = [] # real label

        xi = []
        label_yi = []
        one_hot_yi = []

        map_label2class = []

        for _ in range(batch_size):

            # sample the classes of this episode and which of them is queried
            sampled_classes = random.sample(class_pool, nway)
            positive_class = random.randint(0, nway - 1)

            label2class = torch.zeros(nway, dtype=torch.long)

            single_xi = []
            single_one_hot_yi = []
            single_label_yi = []

            for j, _class in enumerate(sampled_classes):
                one_hot = torch.zeros(nway)
                one_hot[j] = 1.0

                if j == positive_class:
                    # the query class: draw one extra image to use as query
                    sampled_data = random.sample(data_dict[_class], num_shots+1)

                    x.append(sampled_data[0])
                    label_y.append(torch.LongTensor([j]))
                    one_hot_y.append(one_hot)
                    class_y.append(torch.LongTensor([_class]))

                    shots_data = sampled_data[1:]
                else:
                    shots_data = random.sample(data_dict[_class], num_shots)

                single_xi += shots_data
                single_label_yi.append(torch.LongTensor([j]).repeat(num_shots))
                single_one_hot_yi.append(one_hot.repeat(num_shots, 1))

                label2class[j] = _class

            # one permutation shared by images and labels keeps them aligned
            shuffle_index = torch.randperm(num_shots*nway)
            xi.append(torch.stack(single_xi, dim=0)[shuffle_index])
            label_yi.append(torch.cat(single_label_yi, dim=0)[shuffle_index])
            one_hot_yi.append(torch.cat(single_one_hot_yi, dim=0)[shuffle_index])

            map_label2class.append(label2class)

        return [torch.stack(x, 0), torch.cat(label_y, 0), torch.stack(one_hot_y, 0), \
            torch.cat(class_y, 0), torch.stack(xi, 0), torch.stack(label_yi, 0), \
            torch.stack(one_hot_yi, 0), torch.stack(map_label2class, 0)]

    def load_tr_batch(self, batch_size=16, nway=5, num_shots=1):
        """Sample a batch of episodes from the full (training) classes."""
        return self.load_batch_data(True, batch_size, nway, num_shots)

    def load_te_batch(self, batch_size=16, nway=5, num_shots=1):
        """Sample a batch of episodes from the held-out few classes."""
        return self.load_batch_data(False, batch_size, nway, num_shots)

    def get_data_list(self, data_dict):
        """Flatten ``data_dict`` into shuffled, index-aligned data / label lists.

        Returns:
            ``(data_list, label_list)`` where ``label_list[k]`` is the label
            of ``data_list[k]``.
        """
        pairs = [(data, label)
                 for label, items in data_dict.items()
                 for data in items]

        # Shuffle (data, label) pairs together.  Shuffling the two lists
        # independently would detach every sample from its label.
        random.shuffle(pairs)

        data_list = [data for data, _ in pairs]
        label_list = [label for _, label in pairs]

        return data_list, label_list

    def get_full_data_list(self):
        """Shuffled (data, label) lists of the full split."""
        return self.get_data_list(self.full_data_dict)

    def get_few_data_list(self):
        """Shuffled (data, label) lists of the few split."""
        return self.get_data_list(self.few_data_dict)

if __name__ == '__main__':
    # Smoke test of the episode sampler: build the loader on the CIFAR-100
    # training split and print the shapes of one 5-way 5-shot batch.
    # Uses the same relative 'data' root as the training / test cells below
    # rather than an absolute path that only exists on one machine.
    D = self_DataLoader('data', True)

    # NOTE: the last returned tensor is the fake-label -> real-class map of
    # each episode (one row of length nway), bound here to `class_yi`.
    [x, label_y, one_hot_y, class_y, xi, label_yi, one_hot_yi, class_yi] = \
        D.load_tr_batch(batch_size=16, nway=5, num_shots=5)
    print(x.size(), label_y.size(), one_hot_y.size(), class_y.size())
    print(xi.size(), label_yi.size(), one_hot_yi.size(), class_yi.size())

    # print(label_y)
    # print(one_hot_y)

    # Support-set labels of the first episode, as indices and as one-hot rows.
    print(label_yi[0])
    print(one_hot_yi[0])

selected labeled [17, 72, 97, 8, 32]
Files already downloaded and verified
50000
full_data_num: 47500
few_data_num: 2500
torch.Size([16, 3, 32, 32]) torch.Size([16]) torch.Size([16, 5]) torch.Size([16])
torch.Size([16, 25, 3, 32, 32]) torch.Size([16, 25]) torch.Size([16, 25, 5]) torch.Size([16, 5])
tensor([3, 0, 2, 2, 4, 4, 3, 3, 1, 2, 3, 0, 2, 1, 0, 4, 1, 1, 0, 2, 4, 1, 4, 0,
        3])
tensor([[0., 0., 0., 1., 0.],
        [1., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0.],
        [1., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 1., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1.],
        [0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        

In [27]:
from time import time




def np2cuda(array):
    """Wrap a NumPy array as a tensor, placed on the GPU when CUDA is available."""
    tensor = torch.from_numpy(array)
    return tensor.cuda() if torch.cuda.is_available() else tensor


def tensor2cuda(tensor):
    """Return ``tensor`` on the GPU when CUDA is available, else unchanged."""
    if not torch.cuda.is_available():
        return tensor
    return tensor.cuda()

class myModel(nn.Module):
    """Base module adding state-dict ``save`` / ``load`` helpers."""

    def __init__(self):
        super().__init__()

    def load(self, file_name):
        """Load a state dict saved by :meth:`save` from ``file_name``.

        The ``map_location`` callback returns each storage as deserialised
        instead of moving it to the device it was saved from.
        """
        state = torch.load(file_name, map_location=lambda storage, loc: storage)
        self.load_state_dict(state)

    def save(self, file_name):
        """Write this module's state dict to ``file_name``."""
        weights = self.state_dict()
        torch.save(weights, file_name)

###############################################################
## Vanilla CNN model, used to extract visual features

class EmbeddingCNN(myModel):
    """Convolutional encoder turning an RGB image into a flat feature vector.

    ``cnn_num_layers`` stages of ``Conv3x3 -> BatchNorm -> MaxPool(2) ->
    LeakyReLU(0.1)`` are stacked; the first stage has ``cnn_hidden_dim``
    channels and every later stage doubles the channel count.  A final
    convolution whose kernel size equals ``image_size`` halved once per stage
    maps to ``cnn_feature_size`` channels.

    Args:
        image_size: side length of the input images the kernel of the final
            convolution is sized for.
        cnn_feature_size: number of output channels of the final convolution.
        cnn_hidden_dim: channel count of the first stage.
        cnn_num_layers: number of conv/pool stages.
    """

    def __init__(self, image_size, cnn_feature_size, cnn_hidden_dim, cnn_num_layers):
        super().__init__()

        layers = []
        channels = cnn_hidden_dim
        for depth in range(cnn_num_layers):
            if depth == 0:
                conv = nn.Conv2d(3, channels, 3, 1, 1, bias=False)
            else:
                conv = nn.Conv2d(channels, channels * 2, 3, 1, 1, bias=False)
                channels *= 2
            layers += [
                conv,
                nn.BatchNorm2d(channels),
                nn.MaxPool2d(2),
                nn.LeakyReLU(0.1, True),
            ]
            image_size //= 2  # each MaxPool2d(2) halves the spatial size

        # Kernel spans the whole remaining feature map.
        layers += [
            nn.Conv2d(channels, cnn_feature_size, image_size, 1, bias=False),
            nn.BatchNorm2d(cnn_feature_size),
            nn.LeakyReLU(0.1, True),
        ]

        self.module_list = nn.ModuleList(layers)

    def forward(self, inputs):
        """Encode a batch of images; the result is flattened to (batch, -1)."""
        features = inputs
        for layer in self.module_list:
            features = layer(features)

        return features.view(features.size(0), -1)

    def freeze_weight(self):
        """Disable gradients for every parameter of the encoder."""
        for param in self.parameters():
            param.requires_grad_(False)
    
class GNN(myModel):
    """Wraps a two-layer dense ``GNN_module`` operating on feature+label nodes.

    Args:
        cnn_feature_size: width of the image feature part of each node.
        gnn_feature_size: hidden width passed to ``GNN_module``.
        nway: number of classes; also the width of the label part of each
            node, so the graph input width is ``cnn_feature_size + nway``.
    """

    def __init__(self, cnn_feature_size, gnn_feature_size, nway):
        super().__init__()

        self.gnn_obj = GNN_module(
            nway=nway,
            input_dim=cnn_feature_size + nway,
            hidden_dim=gnn_feature_size,
            num_layers=2,
            feature_type='dense')

    def forward(self, inputs):
        """Return the graph network's logits with a trailing size-1 axis squeezed."""
        return self.gnn_obj(inputs).squeeze(-1)
      
class gnnModel(myModel):
    """Few-shot classifier: CNN embedding followed by a graph neural network.

    Every image of an episode (one query plus the support set) becomes a graph
    node whose features are its CNN embedding concatenated with label
    information: the one-hot label for support images and a uniform
    distribution for the unlabeled query.

    Args:
        nway: number of classes per episode.
    """

    def __init__(self, nway = 5):
        # Name this class (not the parent) so myModel.__init__ is not skipped.
        super(gnnModel, self).__init__()
        image_size = 32
        cnn_feature_size = 64
        cnn_hidden_dim = 32
        cnn_num_layers = 3

        gnn_feature_size = 32

        self.cnn_feature = EmbeddingCNN(image_size, cnn_feature_size, cnn_hidden_dim, cnn_num_layers)
        self.gnn = GNN(cnn_feature_size, gnn_feature_size, nway)

    def forward(self, data):
        """Classify the query image of each episode.

        Args:
            data: the 8-element episode list; only element 0 (query images
                ``x``), element 4 (support images ``xi``, indexed per shot
                along dim 1) and element 6 (support one-hot labels
                ``one_hot_yi``) are used.

        Returns:
            Log-probabilities over the episode classes for each query,
            computed with ``log_softmax`` along dim 1.
        """
        [x, _, _, _, xi, _, one_hot_yi, _] = data

        z = self.cnn_feature(x)
        # The encoder is applied shot by shot (one call per support position).
        zi_s = [self.cnn_feature(xi[:, i, :, :, :]) for i in range(xi.size(1))]

        zi_s = torch.stack(zi_s, dim=1)

        # follow the paper, concatenate the information of labels to input features
        # The query has no label, so it gets a uniform distribution.  The pad
        # is created on the same device as the labels so the concatenation
        # works for CPU inputs even when CUDA is available.
        num_classes = one_hot_yi.size(2)
        uniform_pad = torch.full(
            (one_hot_yi.size(0), 1, num_classes),
            1.0 / num_classes,
            dtype=torch.float32,
            device=one_hot_yi.device)

        # Node 0 is the query; the support nodes follow.
        labels = torch.cat([uniform_pad, one_hot_yi], dim=1)
        features = torch.cat([z.unsqueeze(1), zi_s], dim=1)

        nodes_features = torch.cat([features, labels], dim=2)

        out_logits = self.gnn(inputs=nodes_features)
        logsoft_prob = F.log_softmax(out_logits, dim=1)

        return logsoft_prob

class Trainer():
    """Trains, evaluates and tests a ``gnnModel`` on sampled few-shot episodes.

    Args:
        trainer_dict: configuration dictionary.  ``'tr_dataloader'`` (the
            episode sampler used for training and validation) is required.
            Optional keys ``'batch_size'`` (default 16), ``'nway'``
            (default 20) and ``'num_shots'`` (default 5) set the episode shape
            used everywhere in this class.
    """

    def __init__(self, trainer_dict):

        self.num_labels = 100

        self.tr_dataloader = trainer_dict['tr_dataloader']

        # Episode shape, shared by train_batch / eval / test and the model.
        self.batch_size = trainer_dict.get('batch_size', 16)
        self.nway = trainer_dict.get('nway', 20)
        self.num_shots = trainer_dict.get('num_shots', 5)

        self.model = gnnModel(nway=self.nway)

        self.total_iter = 0
        self.sample_size = 32

        # Best validation loss / accuracy seen by train(); None until it ran.
        self.best_loss = None
        self.best_acc = None

    def load_model(self, model_dir):
        """Load the weights of the whole model from the file ``model_dir``."""
        self.model.load(model_dir)

        print('load model sucessfully...')

    def load_pretrain(self, model_dir):
        """Load only the CNN embedding weights from the file ``model_dir``."""
        self.model.cnn_feature.load(model_dir)

        print('load pretrain feature sucessfully...')

    def model_cuda(self):
        """Move the model to the GPU when CUDA is available."""
        if torch.cuda.is_available():
            self.model.cuda()

    def eval(self, dataloader, test_sample):
        """Evaluate on roughly ``test_sample`` episodes from the few split.

        Args:
            dataloader: episode sampler providing ``load_te_batch``.
            test_sample: number of episodes to evaluate; rounded down to a
                whole number of batches.

        Returns:
            ``(mean_loss, accuracy_percent)`` over the evaluated episodes.

        Raises:
            ValueError: if ``test_sample`` is smaller than one batch.
        """
        iteration = int(test_sample / self.batch_size)
        if iteration < 1:
            raise ValueError(
                'test_sample (%s) must be at least one batch (%d)'
                % (test_sample, self.batch_size))

        self.model.eval()

        total_loss = 0.0
        total_sample = 0
        total_correct = 0
        with torch.no_grad():
            for _ in range(iteration):
                data = dataloader.load_te_batch(batch_size=self.batch_size,
                    nway=self.nway, num_shots=self.num_shots)

                data_cuda = [tensor2cuda(_data) for _data in data]

                logsoft_prob = self.model(data_cuda)

                label = data_cuda[1]
                loss = F.nll_loss(logsoft_prob, label)

                # nll_loss is a batch mean; weight it by the batch size.
                total_loss += loss.item() * logsoft_prob.shape[0]

                pred = torch.argmax(logsoft_prob, dim=1)

                assert pred.shape == label.shape

                total_correct += torch.eq(pred, label).float().sum().item()
                total_sample += pred.shape[0]
        print('correct: %d / %d' % (total_correct, total_sample))
        print(total_correct)
        return total_loss / total_sample, 100.0 * total_correct / total_sample

    def train_batch(self):
        """Run one optimisation step on a freshly sampled training batch.

        Requires ``self.opt``, which is created by :meth:`train`.

        Returns:
            The batch loss as a Python float.
        """
        self.model.train()

        data = self.tr_dataloader.load_tr_batch(batch_size=self.batch_size,
            nway=self.nway, num_shots=self.num_shots)

        data_cuda = [tensor2cuda(_data) for _data in data]

        self.opt.zero_grad()

        logsoft_prob = self.model(data_cuda)

        label = data_cuda[1]

        loss = F.nll_loss(logsoft_prob, label)
        loss.backward()
        self.opt.step()

        return loss.item()

    def train(self):
        """Train with Adam, validating every 2000 iterations with early stopping.

        Validation uses the few split of the training sampler.  Training stops
        after more than five consecutive validations without a new best loss,
        or after 100000 iterations.  The best validation loss / accuracy are
        stored in ``self.best_loss`` / ``self.best_acc``.
        """
        best_loss = 1e8
        best_acc = 0.0
        stop = 0
        eval_sample = 5000
        self.model_cuda()

        # Only parameters that still require gradients are optimised, so a
        # frozen embedding is left untouched.
        self.opt = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.model.parameters()),
            lr=1e-2,
            weight_decay=1e-6)

        for i in range(100000):

            self.train_batch()

            if i % 2000 == 0:
                va_loss, va_acc = self.eval(self.tr_dataloader, eval_sample)

                if va_loss < best_loss:
                    stop = 0
                    best_loss = va_loss
                    best_acc = va_acc

                stop += 1

                if stop > 5:
                    break

            self.total_iter += 1

        self.best_loss = best_loss
        self.best_acc = best_acc

    def test(self, test_data_array, te_dataloader):
        """Predict a real class label for every image in ``test_data_array``.

        Each test image replaces the query of a freshly sampled episode from
        the few split; the predicted episode label is mapped back to a real
        class through that episode's label-to-class table.

        Args:
            test_data_array: NumPy array of test images, first axis = sample.
            te_dataloader: episode sampler providing ``load_te_batch``.

        Returns:
            1-D NumPy array with one predicted real class label per image
            (empty when ``test_data_array`` is empty).
        """
        self.model_cuda()
        self.model.eval()

        num_samples = test_data_array.shape[0]
        batch_size = self.batch_size
        pred_list = []
        start = 0

        with torch.no_grad():
            while start < num_samples:
                end = start + batch_size
                if end >= num_samples:
                    # last (possibly partial) batch
                    batch_size = num_samples - start

                data = te_dataloader.load_te_batch(batch_size=batch_size,
                    nway=self.nway, num_shots=self.num_shots)

                test_x = test_data_array[start:end]

                # Swap the sampled queries for the real test images.
                data[0] = np2cuda(test_x)

                data_cuda = [tensor2cuda(_data) for _data in data]

                map_label2class = data[-1].cpu().numpy()

                logsoft_prob = self.model(data_cuda)
                pred = torch.argmax(logsoft_prob, dim=1).cpu().numpy()

                # episode label -> real class label, row by row
                pred = map_label2class[range(len(pred)), pred]

                pred_list.append(pred)

                start = end

        if not pred_list:
            return np.empty(0, dtype=np.int64)
        return np.hstack(pred_list)

    def pretrain_eval(self, loader, cnn_feature, classifier):
        """Evaluate the embedding + linear classifier on ``loader``.

        Both modules are switched to eval mode for the evaluation (so
        BatchNorm uses, and does not update, its running statistics) and are
        restored to their previous mode afterwards.

        Returns:
            ``(mean_loss, accuracy_percent)`` over all samples of ``loader``.
        """
        total_loss = 0
        total_sample = 0
        total_correct = 0

        feature_was_training = cnn_feature.training
        classifier_was_training = classifier.training
        cnn_feature.eval()
        classifier.eval()

        try:
            with torch.no_grad():

                for data, label in loader:
                    data = tensor2cuda(data)
                    label = tensor2cuda(label)
                    output = classifier(cnn_feature(data))
                    output = F.log_softmax(output, dim=1)
                    loss = F.nll_loss(output, label)

                    total_loss += loss.item() * output.shape[0]

                    pred = torch.argmax(output, dim=1)

                    assert pred.shape == label.shape

                    total_correct += torch.eq(pred, label).float().sum().item()
                    total_sample += pred.shape[0]
        finally:
            cnn_feature.train(feature_was_training)
            classifier.train(classifier_was_training)

        return total_loss / total_sample, 100.0 * total_correct / total_sample

    def pretrain(self, pretrain_dataset, test_dataset):
        """Pretrain the CNN embedding as a plain ``num_labels``-way classifier.

        A temporary linear head is trained jointly with the embedding.
        Training stops after more than five consecutive epochs without a new
        best loss on ``test_dataset``, or after 10000 epochs.

        Args:
            pretrain_dataset: dataset of ``(image, label)`` pairs to train on.
            test_dataset: dataset of ``(image, label)`` pairs for early stopping.
        """
        pretrain_loader = torch.utils.data.DataLoader(pretrain_dataset,
                batch_size=16, shuffle=True)
        test_loader = torch.utils.data.DataLoader(test_dataset,
                        batch_size=16, shuffle=True)

        self.model_cuda()

        best_loss = 1e8
        # Must exist before the first comparison: if the first epoch's loss is
        # not below best_loss (e.g. NaN), `stop += 1` would otherwise fail.
        stop = 0

        cnn_feature = self.model.cnn_feature
        # The encoder's last three parameters are the final conv weight and
        # the BatchNorm weight / bias, so [-3].shape[0] is the feature size.
        classifier = nn.Linear(list(cnn_feature.parameters())[-3].shape[0], self.num_labels)

        if torch.cuda.is_available():
            classifier.cuda()
        self.pretrain_opt =  torch.optim.Adam(
            list(cnn_feature.parameters()) + list(classifier.parameters()),
            lr=1e-2,
            weight_decay=1e-6)

        # An earlier eval()/test() call may have left the model in eval mode.
        cnn_feature.train()

        for _ in range(10000):
            for data, label in pretrain_loader:
                data = tensor2cuda(data)
                label = tensor2cuda(label)
                output = classifier(cnn_feature(data))

                output = F.log_softmax(output, dim=1)
                loss = F.nll_loss(output, label)

                self.pretrain_opt.zero_grad()
                loss.backward()
                self.pretrain_opt.step()

            te_loss, te_acc = self.pretrain_eval(test_loader, cnn_feature, classifier)

            if te_loss < best_loss:
                stop = 0
                best_loss = te_loss

            stop += 1

            if stop > 5:
                break





# Scratch tensors for a quick construction check of the model.
b_s = 10
nway = 5
shots = 5
batch_x = torch.rand(b_s, 3, 32, 32)
batches_xi = [torch.rand(b_s, 3, 32, 32) for i in range(nway*shots)]

label_x = torch.rand(b_s, nway)

labels_yi = [torch.rand(b_s, nway) for i in range(nway*shots)]

print('create model...')
# Build the model for the same number of classes as the tensors above; the
# first positional argument of gnnModel is nway.
model = gnnModel(nway)

# NOTE: the call below is stale. gnnModel.forward unpacks an 8-element episode
# list and indexes the support images as one stacked tensor, so the 7-element
# list of per-shot tensors would have to be reshaped before enabling it.
#print(model([batch_x, label_x, None, None, batches_xi, labels_yi, None]).shape)

create model...


In [30]:


# Episode sampler on the CIFAR-100 training split; nway=20 labels (chosen from
# seed=1) are held out as the "few" split.
tr_dataloader = self_DataLoader('data', 
    train=True, dataset='cifar100', seed=1, nway=20)

trainer_dict = {'tr_dataloader': tr_dataloader}

trainer = Trainer(trainer_dict)

###########################################
## Optional: pretrain the CNN embedding on the full split (currently disabled).
## The first 10% of the shuffled list is held out as a validation set.

#pretr_tr_data, pretr_tr_label = tr_dataloader.get_full_data_list() # already shuffled the data

#va_size = int(0.1 * len(pretr_tr_data))

#pretr_tr_dataset = self_Dataset(pretr_tr_data[va_size:], pretr_tr_label[va_size:])
#pretr_va_dataset = self_Dataset(pretr_tr_data[:va_size], pretr_tr_label[:va_size])

#trainer.pretrain(pretr_tr_dataset, pretr_va_dataset)

###########################################
## Optional: load a model trained before (currently disabled).

#model_path = os.path.join('', 'model.pth')
#trainer.load_model(model_path)

###########################################
## start training

trainer.train()

selected labeled [17, 72, 97, 8, 32, 15, 63, 57, 60, 83, 48, 26, 12, 62, 3, 49, 55, 77, 98, 0]
Files already downloaded and verified
50000
full_data_num: 40000
few_data_num: 10000
correct: 229 / 4992
229.0


KeyboardInterrupt: ignored

In [None]:
# Episode sampler on the CIFAR-100 test split, using the same seed / nway as
# the training sampler so the same labels form the "few" split.
te_dataloader = self_DataLoader(
    'data', train=False, dataset='cifar100', seed=1, nway=20)

# Flatten the few split into arrays of images and their labels.
test_data_list, test_label_list = te_dataloader.get_few_data_list()
test_data_array = np.stack(test_data_list)
test_label_array = np.hstack(test_label_list)

test_pred = trainer.test(test_data_array, te_dataloader)

print(test_pred.shape, test_label_array.shape)

# Compare predictions with the labels once and derive both statistics from it.
hits = test_pred == test_label_array
correct = hits.sum()
test_acc = hits.mean() * 100.0

print('test_acc: %.4f %%, correct: %d / %d' % (test_acc, correct, len(test_label_array)))

