In [1]:
import easydict
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

from maml.utils import load_dataset, load_model

%load_ext autoreload
%autoreload 2

In [2]:
# Notebook-wide settings, wrapped in an EasyDict so later cells can read them
# as attributes (e.g. args.device, args.num_ways).
# NOTE(review): 'folder' is an absolute, machine-specific path and 'device'
# pins a specific GPU index -- adjust both before running elsewhere.
args = easydict.EasyDict(dict(
    folder='/home/osilab7/hdd/ml_dataset',
    dataset='miniimagenet',
    model='resnet',
    device='cuda:1',
    download=True,
    num_shots=5,
    num_ways=5,
    meta_lr=1e-3,
    first_order=False,
    extractor_step_size=0.5,
    classifier_step_size=0.5,
    hidden_size=64,
    output_folder='./output/',
    save_name=None,
    batch_size=4,
    batch_iter=1200,
    train_batches=50,
    valid_batches=25,
    test_batches=2500,
    num_workers=1,
    init=False,
))

In [3]:
# Supervised pre-training: fit the model as an ordinary `args.num_ways`-way
# classifier on the first `args.num_ways` classes of the 'meta_train' split,
# printing the running training accuracy after each epoch and saving the
# resulting weights to 'resnet_pretrained.pt'.
device = torch.device(args.device)

# NOTE: this overwrites the value set in the config cell, so every later cell
# that reads `args.num_ways` sees 64 as well.
args.num_ways = 64
pretrained_model = load_model(args)
# Move the model to its device *before* building the optimizer, as the
# torch.optim documentation recommends, so the optimizer is guaranteed to hold
# references to the on-device parameters.
pretrained_model.to(device=device)
pretrained_model_optimizer = torch.optim.Adam(pretrained_model.parameters(), lr=1e-2)
pretrained_model.train()

num_epochs = 10

dataset = load_dataset(args, 'meta_train')

# One raw image array per class.
# assumes each `.data` is laid out as (N, H, W, C) with 0-255 pixel values,
# which is what the permute and the /255 scaling below rely on -- TODO confirm.
class_images = [torch.from_numpy(np.array(dataset.dataset[i].data))
                for i in range(args.num_ways)]

# Concatenate first, then convert to float and rescale to [0, 1] in place;
# the permute reorders the axes to channels-first for the model.
data = torch.cat(class_images, dim=0).permute(0, 3, 1, 2).float().div_(255.)

# Build labels from the number of images actually loaded per class rather than
# assuming a fixed count, so labels can never drift out of step with `data`.
labels = torch.cat([torch.full((len(images),), class_idx, dtype=torch.long)
                    for class_idx, images in enumerate(class_images)])
assert len(data) == len(labels), "every image must have exactly one label"

# TensorDataset yields the same (image, label) pairs as list(zip(data, labels))
# without materialising one Python tuple per sample.
dataloader = DataLoader(dataset=torch.utils.data.TensorDataset(data, labels),
                        batch_size=128, shuffle=True)
for epoch in tqdm(range(num_epochs)):
    accuracy = []  # per-sample correctness flags collected over this epoch
    for inputs, target in dataloader:
        inputs = inputs.to(device=device, dtype=torch.float32)
        target = target.to(device=device, dtype=torch.long)

        # The model returns (features, logits); only the logits are used here.
        features, logit = pretrained_model(inputs)
        loss = F.cross_entropy(logit, target)
        accuracy += (target == logit.argmax(dim=1)).tolist()

        pretrained_model_optimizer.zero_grad()
        loss.backward()
        pretrained_model_optimizer.step()
    # Running training accuracy (measured while the weights are being updated).
    print(sum(accuracy) / float(len(accuracy)))
torch.save(pretrained_model.state_dict(), 'resnet_pretrained.pt')

 10%|█         | 1/10 [00:50<07:38, 50.96s/it]

0.15260416666666668


 20%|██        | 2/10 [01:42<06:50, 51.27s/it]

0.2998177083333333


 30%|███       | 3/10 [02:34<05:59, 51.40s/it]

0.4059375


 40%|████      | 4/10 [03:25<05:08, 51.38s/it]

0.47869791666666667


 50%|█████     | 5/10 [04:17<04:16, 51.32s/it]

0.5309114583333333


 60%|██████    | 6/10 [05:08<03:24, 51.24s/it]

0.5723177083333333


 70%|███████   | 7/10 [05:59<02:33, 51.17s/it]

0.6083333333333333


 80%|████████  | 8/10 [06:50<01:42, 51.11s/it]

0.6422135416666667


 90%|█████████ | 9/10 [07:41<00:51, 51.07s/it]

0.67328125


100%|██████████| 10/10 [08:32<00:00, 51.22s/it]

0.7017447916666667



