In [1]:
import torch
from torch.utils.data import Dataset
import glob
import argparse
import os
import argparse

import numpy as np
import pickle
from PIL import Image
from tqdm import tqdm


In [2]:
cd ..\..

E:\CVprojects\Butterflies


In [3]:
# !git clone https://github.com/orobix/Prototypical-Networks-for-Few-shot-Learning-PyTorch.git

In [4]:
def load_dict(filename_):
    with open(filename_, 'rb') as f:
        ret_di = pickle.load(f)
    return ret_di
def save_dict(di_, filename_):
    with open(filename_, 'wb') as f:
        pickle.dump(di_, f)

In [5]:
from src.Prototypical_Networks_for_Few_shot_Learning_PyTorch.src.prototypical_batch_sampler import PrototypicalBatchSampler
from src.Prototypical_Networks_for_Few_shot_Learning_PyTorch.src.prototypical_loss import prototypical_loss as loss_fn
from src.Prototypical_Networks_for_Few_shot_Learning_PyTorch.src.protonet import ProtoNet
from src.Prototypical_Networks_for_Few_shot_Learning_PyTorch.src.parser_util import get_parser

In [6]:
cd src

E:\CVprojects\Butterflies\src


In [7]:
from Butterfly200DataSet import Butterfly200DataSet

In [8]:
cd ..

E:\CVprojects\Butterflies


In [9]:
split_dict_name = os.path.join('configs\\splits',"split_dict2.pkl")
split_dict_path = os.path.join(os.getcwd(),split_dict_name)
split_dict_path

'E:\\CVprojects\\Butterflies\\configs\\splits\\split_dict2.pkl'

In [10]:
def init_seed(opt):
    torch.cuda.cudnn_enabled = True
    np.random.seed(opt.manual_seed)
    torch.manual_seed(opt.manual_seed)
    torch.cuda.manual_seed(opt.manual_seed)
def init_dataset(opt, mode):
    dataset = Butterfly200DataSet(split_dict_path,mode=mode)
    n_classes = len(np.unique(dataset.y))
    if n_classes < opt.classes_per_it_tr or n_classes < opt.classes_per_it_val:
        raise(Exception('There are not enough classes in the dataset in order ' +
                        'to satisfy the chosen classes_per_it. Decrease the ' +
                        'classes_per_it_{tr/val} option and try again.'))
    return dataset


def init_sampler(opt, labels, mode):
    if 'train' in mode:
        classes_per_it = opt.classes_per_it_tr
        num_samples = opt.num_support_tr + opt.num_query_tr
    else:
        classes_per_it = opt.classes_per_it_val
        num_samples = opt.num_support_val + opt.num_query_val

    return PrototypicalBatchSampler(labels=labels,
                                    classes_per_it=classes_per_it,
                                    num_samples=num_samples,
                                    iterations=opt.iterations)
def init_dataloader(opt, mode):
    dataset = init_dataset(opt, mode)
    sampler = init_sampler(opt, dataset.y, mode)
    dataloader = torch.utils.data.DataLoader(dataset, batch_sampler=sampler)
    return dataloader
def init_protonet(opt):
    '''
    Initialize the ProtoNet
    '''
    device = 'cuda:0' if torch.cuda.is_available() and opt.cuda else 'cpu'
    model = ProtoNet(x_dim=128).to(device)
    return model
def init_optim(opt, model):
    '''
    Initialize optimizer
    '''
    return torch.optim.Adam(params=model.parameters(),
                            lr=opt.learning_rate)
def init_lr_scheduler(opt, optim):
    '''
    Initialize the learning rate scheduler
    '''
    return torch.optim.lr_scheduler.StepLR(optimizer=optim,
                                           gamma=opt.lr_scheduler_gamma,
                                           step_size=opt.lr_scheduler_step)
def save_list_to_file(path, thelist):
    with open(path, 'w') as f:
        for item in thelist:
            f.write("%s\n" % item)

In [11]:
def train(opt, tr_dataloader, model, optim, lr_scheduler, val_dataloader=None):
    '''
    Train the model with the prototypical learning algorithm
    '''

    device = 'cuda:0' if torch.cuda.is_available() and opt.cuda else 'cpu'

    if val_dataloader is None:
        best_state = None
    train_loss = []
    train_acc = []
    val_loss = []
    val_acc = []
    best_acc = 0

    best_model_path = os.path.join(opt.experiment_root, 'best_model.pth')
    last_model_path = os.path.join(opt.experiment_root, 'last_model.pth')
    for epoch in range(opt.epochs):
        print('=== Epoch: {} ==='.format(epoch))
        tr_iter = iter(tr_dataloader)
        model.train()
        
        for batch in tqdm(tr_iter):
            optim.zero_grad()
            x, y = batch
            x, y = x.to(device), y.to(device)
#             emb_model(x)
#             model_output = model(activation['layer2'])
            model_output = model(x)
            loss, acc = loss_fn(model_output, target=y,
                                n_support=opt.num_support_tr)
            loss.backward()
            optim.step()
            train_loss.append(loss.item())
            train_acc.append(acc.item())
            torch.cuda.empty_cache()
        avg_loss = np.mean(train_loss[-opt.iterations:])
        avg_acc = np.mean(train_acc[-opt.iterations:])
        print('Avg Train Loss: {}, Avg Train Acc: {}'.format(avg_loss, avg_acc))
        lr_scheduler.step()
        if val_dataloader is None:
            continue
        val_iter = iter(val_dataloader)
        model.eval()
        for batch in tqdm(val_iter):
            x, y = batch
            x, y = x.to(device), y.to(device)
#             emb_model(x)
#             model_output = model(activation['layer2'])
            model_output = model(x)
            loss, acc = loss_fn(model_output, target=y,
                                n_support=opt.num_support_val)
            val_loss.append(loss.item())
            val_acc.append(acc.item())
            torch.cuda.empty_cache()
        avg_loss = np.mean(val_loss[-opt.iterations:])
        avg_acc = np.mean(val_acc[-opt.iterations:])
        postfix = ' (Best)' if avg_acc >= best_acc else ' (Best: {})'.format(
            best_acc)
        print('Avg Val Loss: {}, Avg Val Acc: {}{}'.format(
            avg_loss, avg_acc, postfix))
        if avg_acc >= best_acc:
            torch.save(model.state_dict(), best_model_path)
            best_acc = avg_acc
            best_state = model.state_dict()

    torch.save(model.state_dict(), last_model_path)

    for name in ['train_loss', 'train_acc', 'val_loss', 'val_acc']:
        save_list_to_file(os.path.join(opt.experiment_root,
                                       name + '.txt'), locals()[name])

    return best_state, best_acc, train_loss, train_acc, val_loss, val_acc


def test(opt, test_dataloader, model):
    '''
    Test the model trained with the prototypical learning algorithm
    '''
    device = 'cuda:0' if torch.cuda.is_available() and opt.cuda else 'cpu'
    avg_acc = list()
    for epoch in tqdm(range(opt.test_epochs)):
        test_iter = iter(test_dataloader)
        for batch in test_iter:
            x, y = batch
            x, y = x.to(device), y.to(device)
            model_output = model(x)
            _, acc = loss_fn(model_output, target=y,
                             n_support=opt.num_support_val)
            avg_acc.append(acc.item())
    avg_acc = np.mean(avg_acc)
    print('Test Acc: {}'.format(avg_acc))

    return avg_acc


def eval(opt):
    '''
    Initialize everything and train
    '''
    options = get_parser().parse_args()

    if torch.cuda.is_available() and not options.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    init_seed(options)
    test_dataloader = init_dataset(options)[-1]
    model = init_protonet(options)
    model_path = os.path.join(opt.experiment_root, 'best_model.pth')
    model.load_state_dict(torch.load(model_path))

    test(opt=options,
         test_dataloader=test_dataloader,
         model=model)




In [12]:
d = {
    'experiment_root':'base_exp',
     'epochs':200,
     'test_epochs':20,
    'iterations':200,
     'cuda':True,
     'classes_per_it_tr':5,
    'lr_scheduler_step':20,
    'num_query_tr':2,
    'num_support_tr':1,
         'num_query_val':1,
     'num_support_val':1,
    'learning_rate':0.0001,
#      'check_point_path':'base_exp/best_after_blur_t48_val50.pth'
     }
options = get_parser().parse_args('')

for key, value in d.items():
        setattr(options, key, value)

In [13]:
options.learning_rate

0.0001

In [14]:
torch.cuda.empty_cache()

In [15]:
split_dict_path

'E:\\CVprojects\\Butterflies\\configs\\splits\\split_dict2.pkl'

In [16]:
'''
Initialize everything and train
'''
# options = get_parser().parse_args()
if not os.path.exists(options.experiment_root):
    os.makedirs(options.experiment_root)

if torch.cuda.is_available() and not options.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")

init_seed(options)

tr_dataloader = init_dataloader(options, 'train')
val_dataloader = init_dataloader(options, 'val')
# test_dataloader = init_dataloader(options, 'test')



100%|█████████████████████████████████████████████████████████████████████████████████| 60/60 [00:00<00:00, 171.92it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 46/46 [00:00<00:00, 238.33it/s]


In [17]:
from src.models.EncoderProtoNet import EncoderProtoNet

In [18]:
model = EncoderProtoNet(proto_x_dim=128,proto_hid_dim=64,proto_z_dim=32)
# model.load_state_dict(torch.load('E:\\CVprojects\\Butterflies\\src\\base_exp\\best_model_tr88_val87.pth'))


enc_weights_path = 'base_exp\\best_model_embed_res.pth'
model.load_encoder_weights(enc_weights_path)


In [19]:
model.encoder.requires_gradient = True
model = model.to('cuda')

In [20]:
model

EncoderProtoNet(
  (encoder): SubResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True

In [21]:
optim = init_optim(options, model)
lr_scheduler = init_lr_scheduler(options, optim)

In [22]:
res = train(opt=options,
            tr_dataloader=tr_dataloader,
            val_dataloader=val_dataloader,
            model=model,
            optim=optim,
            lr_scheduler=lr_scheduler)
best_state, best_acc, train_loss, train_acc, val_loss, val_acc = res

=== Epoch: 0 ===


100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [00:48<00:00,  4.10it/s]


Avg Train Loss: 395.06485773086547, Avg Train Acc: 0.31100000616163015


100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [00:16<00:00, 12.06it/s]


Avg Val Loss: 446.25671159267426, Avg Val Acc: 0.38600000940263274 (Best)
=== Epoch: 1 ===


100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [00:49<00:00,  4.00it/s]


Avg Train Loss: 179.89413833618164, Avg Train Acc: 0.39200000688433645


100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [00:17<00:00, 11.74it/s]


Avg Val Loss: 282.2200828320533, Avg Val Acc: 0.4540000119060278 (Best)
=== Epoch: 2 ===


100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [00:51<00:00,  3.92it/s]


Avg Train Loss: 111.10251488685608, Avg Train Acc: 0.4815000064298511


100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [00:17<00:00, 11.67it/s]


Avg Val Loss: 232.80025629043578, Avg Val Acc: 0.49400001257658005 (Best)
=== Epoch: 3 ===


100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [00:57<00:00,  3.50it/s]


Avg Train Loss: 79.55961943149566, Avg Train Acc: 0.5600000068545341


100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [00:20<00:00,  9.54it/s]


Avg Val Loss: 205.0301984171197, Avg Val Acc: 0.4780000109225512 (Best: 0.49400001257658005)
=== Epoch: 4 ===


100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [01:01<00:00,  3.26it/s]


Avg Train Loss: 57.121012983322146, Avg Train Acc: 0.6150000025704503


100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [00:19<00:00, 10.07it/s]


Avg Val Loss: 179.5149845993519, Avg Val Acc: 0.48700001172721386 (Best: 0.49400001257658005)
=== Epoch: 5 ===


100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [01:01<00:00,  3.26it/s]


Avg Train Loss: 39.74523802812398, Avg Train Acc: 0.7030000014603138


100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [00:23<00:00,  8.64it/s]


Avg Val Loss: 160.01845502042744, Avg Val Acc: 0.4820000118762255 (Best: 0.49400001257658005)
=== Epoch: 6 ===


100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [01:02<00:00,  3.21it/s]


Avg Train Loss: 33.99303068153468, Avg Train Acc: 0.7250000002980233


100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [00:20<00:00,  9.79it/s]


Avg Val Loss: 177.0726792317629, Avg Val Acc: 0.45900000996887685 (Best: 0.49400001257658005)
=== Epoch: 7 ===


100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [01:00<00:00,  3.29it/s]


Avg Train Loss: 25.949398875782776, Avg Train Acc: 0.7699999976158142


100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [00:19<00:00, 10.35it/s]


Avg Val Loss: 172.43178587335348, Avg Val Acc: 0.47400001049041746 (Best: 0.49400001257658005)
=== Epoch: 8 ===


 20%|███████████████▊                                                                 | 39/200 [00:11<00:46,  3.49it/s]


KeyboardInterrupt: 

In [61]:
model.load_state_dict(torch.load('base_exp/best_model_95val_82tr.pth'))

<All keys matched successfully>

In [62]:
model.encoder.requires_gradient = True

In [63]:
test_dataloader = init_dataloader(options, 'test')




100%|██████████████████████████████████████████████████████████████████████████████| 4168/4168 [01:30<00:00, 46.24it/s]


In [64]:
print('Testing with last model..')
test(opt=options,
      test_dataloader=test_dataloader,
      model=model)

Testing with last model..


100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [02:07<00:00,  6.37s/it]

Test Acc: 0.9118999935984612





0.9118999935984612

In [None]:
model.load_state_dict(best_state)
print('Testing with best model..')
test(opt=options,
      test_dataloader=test_dataloader,
      model=model)