## CPEA
### FC100 training/testing

Imports

In [1]:
import os.path as osp

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from cpea import CPEA
from models.backbones import BackBone
from dataloader.samplers import CategoriesSampler
from utils import ensure_path, Averager, count_acc, compute_confidence_interval
from tensorboardX import SummaryWriter
from types import SimpleNamespace
import gc

In [2]:
args = SimpleNamespace(
    max_epoch=2,
    way=5,
    test_way=5,
    shot=1,
    query=15,
    lr=0.00001,
    lr_mul=100,
    step_size=5,
    gamma=0.5,
    model_type='small',
    dataset='FC100',
    init_weights='./initialization/fc100/checkpoint1600.pth',
    gpu='0',
    exp='CPEA'
)
# Results go to ./results/<exp>-<dataset>-<model_type>.
save_path = '-'.join([args.exp, args.dataset, args.model_type])
args.save_path = osp.join('./results', save_path)
ensure_path(args.save_path)

# Collect Python garbage first so tensors kept alive only by reference cycles are
# actually freed; empty_cache() can then hand their cached blocks back. Calling
# empty_cache() first would leave memory freed by gc.collect() sitting in the cache.
gc.collect()
torch.cuda.empty_cache()
# memory_summary(device=None) reports on the current CUDA device, which does not
# exist on a CPU-only machine; this notebook otherwise runs fine on CPU.
if torch.cuda.is_available():
    print(torch.cuda.memory_summary(device=None, abbreviated=False))

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |      0 B   |      0 B   |      0 B   |      0 B   |
|       from large pool |      0 B   |      0 B   |      0 B   |      0 B   |
|       from small pool |      0 B   |      0 B   |      0 B   |      0 B   |
|---------------------------------------------------------------------------|
| Active memory         |      0 B   |      0 B   |      0 B   |      0 B   |
|       from large pool |      0 B   |      0 B   |      0 B   |      0 B   |
|       from small pool |      0 B   |      0 B   |      0 B   |      0 B   |
|---------------------------------------------------------------

Load dataset

In [3]:
from dataloader.fc100 import FC100 as Dataset

# Episodic loaders for the train and validation splits.
# NOTE(review): CategoriesSampler appears to take (labels, episodes per pass,
# classes per episode, images per class); 10 episodes per pass is used for both
# splits here — confirm against dataloader.samplers.
images_per_class = args.shot + args.query

trainset = Dataset('train', args)
train_sampler = CategoriesSampler(trainset.label, 10, args.way, images_per_class)
train_loader = DataLoader(
    dataset=trainset,
    batch_sampler=train_sampler,
    num_workers=8,
    pin_memory=True,
)

valset = Dataset('val', args)
val_sampler = CategoriesSampler(valset.label, 10, args.test_way, images_per_class)
val_loader = DataLoader(
    dataset=valset,
    batch_sampler=val_sampler,
    num_workers=8,
    pin_memory=True,
)

Load model

In [4]:
# Plain-CPEA setup. The backbone's encoder is optimised at args.lr and the CPEA
# head at args.lr * args.lr_mul. Each optimiser has its own StepLR, which
# multiplies the LR by args.gamma every args.step_size scheduler steps.
model = BackBone(args)
dense_predict_network = CPEA()

optimizer = torch.optim.Adam(
    [{'params': model.encoder.parameters()}],
    lr=args.lr,
    weight_decay=0.001,
)
print(f'Using {args.model_type}')

lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=args.step_size,
    gamma=args.gamma,
)

head_lr = args.lr * args.lr_mul
dense_predict_network_optim = torch.optim.Adam(
    dense_predict_network.parameters(),
    lr=head_lr,
    weight_decay=0.001,
)
dense_predict_network_scheduler = torch.optim.lr_scheduler.StepLR(
    dense_predict_network_optim,
    step_size=args.step_size,
    gamma=args.gamma,
)

Using small


In [None]:
from cpea_hierarchical import HierarchicalCPEA as HCPEA

# Alternative setup using HierarchicalCPEA. It receives the CPEA class as its
# first argument. The optimiser/scheduler configuration matches the plain-CPEA
# cell. Running this cell rebinds model, dense_predict_network, both optimisers
# and both schedulers.
# NOTE(review): 75 equals args.way * args.query (5 * 15) with the current args;
# the meaning of 75, 60 and 12 lives in cpea_hierarchical — confirm whether they
# should be derived from args instead of hard-coded.
model = BackBone(args)
dense_predict_network = HCPEA(CPEA, 75, 60, 12)

optimizer = torch.optim.Adam(
    [{'params': model.encoder.parameters()}],
    lr=args.lr,
    weight_decay=0.001,
)
print(f'Using {args.model_type}')

lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=args.step_size,
    gamma=args.gamma,
)

head_lr = args.lr * args.lr_mul
dense_predict_network_optim = torch.optim.Adam(
    dense_predict_network.parameters(),
    lr=head_lr,
    weight_decay=0.001,
)
dense_predict_network_scheduler = torch.optim.lr_scheduler.StepLR(
    dense_predict_network_optim,
    step_size=args.step_size,
    gamma=args.gamma,
)

Initialize pretrained model

In [5]:
# Load the pre-trained backbone (the checkpoint carries no FC weights).
# Tensors come from the checkpoint's 'teacher' entry. The substring 'backbone' in
# each key is renamed to 'encoder' to match BackBone, and keys the model does not
# have are skipped.
model_dict = model.state_dict()
if args.init_weights is not None:
    # NOTE(review): torch.load may unpickle arbitrary objects (depending on the
    # torch version's weights_only default) — only load trusted checkpoints.
    pretrained_dict = torch.load(args.init_weights, map_location='cpu')['teacher']
    pretrained_dict = {k.replace('backbone', 'encoder'): v for k, v in pretrained_dict.items()}
    n_checkpoint_tensors = len(pretrained_dict)
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    # Unmatched keys are otherwise dropped silently. Report coverage so that a
    # key-naming mismatch, which would leave the backbone at whatever BackBone
    # initialised it to, is visible.
    print('Loaded {}/{} model tensors from {} ({} checkpoint tensors unused)'.format(
        len(pretrained_dict), len(model_dict), args.init_weights,
        n_checkpoint_tensors - len(pretrained_dict)))
    if not pretrained_dict:
        print('WARNING: no checkpoint keys matched the model; pretrained weights were not loaded')
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

# GPU placement is disabled: as written, the model and head stay on the CPU.
# Former GPU setup, kept for reference:
# if torch.cuda.is_available():
#     torch.backends.cudnn.benchmark = True
#     model = model.cuda()
#     model = torch.nn.DataParallel(model)
#     dense_predict_network = dense_predict_network.cuda()

def save_model(name):
    """Checkpoint the backbone and the CPEA head into ``args.save_path``.

    Writes ``<name>.pth`` for ``model`` and ``<name>_dense_predict.pth`` for
    ``dense_predict_network``. Each file holds a dict whose ``'params'`` entry
    is that module's ``state_dict()``.
    """
    targets = (
        (model, name + '.pth'),
        (dense_predict_network, name + '_dense_predict.pth'),
    )
    for module, filename in targets:
        torch.save({'params': module.state_dict()}, osp.join(args.save_path, filename))

Initialize logging

In [6]:
# Run history; starts empty with the best validation accuracy at 0.0 / epoch 0.
trlog = {
    'args': vars(args),
    'train_loss': [],
    'val_loss': [],
    'train_acc': [],
    'val_acc': [],
    'max_acc': 0.0,
    'max_acc_epoch': 0,
}
global_count = 0  # training-iteration counter, incremented per batch by the loop

# NOTE(review): tensorboardX appends `comment` to the default run-directory name;
# since save_path contains '/', this produces nested directories under runs/ —
# consider passing logdir= instead.
writer = SummaryWriter(comment=args.save_path)
torch.cuda.empty_cache()

Training

In [7]:
# Free unreachable Python objects (and any tensors they hold) before asking the
# CUDA caching allocator to release its unused blocks. In the reverse order,
# memory freed by gc.collect() would stay reserved in the cache.
gc.collect()
torch.cuda.empty_cache()

9

In [None]:
# Episodic training. Each epoch does one pass over train_loader, steps both LR
# schedulers, validates, and then:
# - appends the epoch's mean losses/accuracies to trlog;
# - saves 'max_acc' whenever val accuracy ties or beats the best so far;
# - writes trlog and 'epoch-last'.
#
# The first shot * way images of each episode are the support set, the rest are
# queries. Query labels are arange(way).repeat(query), i.e. classes 0..way-1
# cycling. This assumes CategoriesSampler interleaves classes within an episode
# (TODO confirm against dataloader.samplers); the loader's own labels (batch[1])
# are not used.
# Tensors stay on the CPU: the .cuda() transfers are disabled in this notebook.

# Loop invariants, built once rather than per batch.
eps = 0.1  # label smoothing: true class gets 1 - eps, each other class eps / (way - 1)
n_support_train = args.shot * args.way
n_support_val = args.shot * args.test_way
train_label = torch.arange(args.way).repeat(args.query).long()
val_label = torch.arange(args.test_way).repeat(args.query).long()

for epoch in range(1, args.max_epoch + 1):
    print(f"Epoch {epoch}")
    model.train()
    dense_predict_network.train()
    tl = Averager()
    ta = Averager()

    for i, batch in enumerate(train_loader, 1):
        optimizer.zero_grad()
        dense_predict_network_optim.zero_grad()
        global_count = global_count + 1

        data = batch[0]
        data_shot, data_query = data[:n_support_train], data[n_support_train:]

        feat_shot, feat_query = model(data_shot, data_query)
        results, _ = dense_predict_network(feat_query, feat_shot, args)
        results = torch.cat(results, dim=0)  # Q x S

        # Label-smoothed cross-entropy against the smoothed one-hot targets.
        one_hot = torch.zeros_like(results).scatter(1, train_label.view(-1, 1), 1)
        one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (args.way - 1)
        log_prb = F.log_softmax(results, dim=1)
        loss = -(one_hot * log_prb).sum(dim=1).mean()

        acc = count_acc(results.detach(), train_label)
        writer.add_scalar('data/loss', float(loss), global_count)
        writer.add_scalar('data/acc', float(acc), global_count)
        print('epoch {}, train {}/{}, loss={:.4f} acc={:.4f}'.format(epoch, i, len(train_loader), loss.item(), acc))

        tl.add(loss.item())
        ta.add(acc)

        loss.backward()
        optimizer.step()
        dense_predict_network_optim.step()

    # Schedulers step once per epoch, after that epoch's optimiser steps.
    lr_scheduler.step()
    dense_predict_network_scheduler.step()

    tl = tl.item()
    ta = ta.item()

    model.eval()
    dense_predict_network.eval()

    vl = Averager()
    va = Averager()

    # Reports the best result from earlier epochs (before this epoch's validation).
    print('best epoch {}, best val acc={:.4f}'.format(trlog['max_acc_epoch'], trlog['max_acc']))
    with torch.no_grad():
        for i, batch in enumerate(val_loader, 1):
            data = batch[0]
            data_shot, data_query = data[:n_support_val], data[n_support_val:]
            feat_shot, feat_query = model(data_shot, data_query)

            results, _ = dense_predict_network(feat_query, feat_shot, args)
            # Each element of results is averaged over dim 0 to a single row.
            results = [torch.mean(idx, dim=0, keepdim=True) for idx in results]
            results = torch.cat(results, dim=0)  # Q x S

            loss = F.cross_entropy(results, val_label)
            acc = count_acc(results, val_label)
            vl.add(loss.item())
            va.add(acc)

    vl = vl.item()
    va = va.item()
    writer.add_scalar('data/val_loss', float(vl), epoch)
    writer.add_scalar('data/val_acc', float(va), epoch)
    print('epoch {}, val, loss={:.4f} acc={:.4f}'.format(epoch, vl, va))

    if va >= trlog['max_acc']:
        trlog['max_acc'] = va
        trlog['max_acc_epoch'] = epoch
        save_model('max_acc')

    trlog['train_loss'].append(tl)
    trlog['train_acc'].append(ta)
    trlog['val_loss'].append(vl)
    trlog['val_acc'].append(va)

    torch.save(trlog, osp.join(args.save_path, 'trlog'))
    save_model('epoch-last')

Testing

In [14]:
# Build the test episodes and restore the best-on-validation checkpoints.
N_TEST_EPISODES = 50  # shared by the sampler and the per-episode accuracy buffer

trlog = torch.load(osp.join(args.save_path, 'trlog'))
test_set = Dataset('test', args)
sampler = CategoriesSampler(test_set.label, N_TEST_EPISODES, args.test_way, args.shot + args.query)
loader = DataLoader(test_set, batch_sampler=sampler, num_workers=8, pin_memory=True)
test_acc_record = np.zeros((N_TEST_EPISODES,))

# map_location='cpu' lets checkpoints saved from GPU tensors load on a CPU-only
# machine; load_state_dict then copies values into the modules' own parameters.
model.load_state_dict(
    torch.load(osp.join(args.save_path, 'max_acc' + '.pth'), map_location='cpu')['params'])
model.eval()

dense_predict_network.load_state_dict(
    torch.load(osp.join(args.save_path, 'max_acc' + '_dense_predict.pth'), map_location='cpu')['params'])
dense_predict_network.eval()

CPEA(
  (fc1): Mlp(
    (fc1): Linear(in_features=384, out_features=96, bias=True)
    (act): GELU(approximate='none')
    (fc2): Linear(in_features=96, out_features=384, bias=True)
    (drop): Dropout(p=0.1, inplace=False)
  )
  (fc_norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
  (fc2): Mlp(
    (fc1): Linear(in_features=38416, out_features=256, bias=True)
    (act): GELU(approximate='none')
    (fc2): Linear(in_features=256, out_features=1, bias=True)
    (drop): Dropout(p=0.1, inplace=False)
  )
)

In [15]:
# Evaluate on the test episodes. Prints the running and per-episode accuracy,
# then the mean accuracy with its confidence interval.
ave_acc = Averager()
k = args.test_way * args.shot  # support images at the front of each episode
# Query labels are identical for every episode, so build them once.
label = torch.arange(args.test_way).repeat(args.query).long()
n_episodes = 0

with torch.no_grad():
    for i, batch in enumerate(loader, 1):
        # Tensors stay on the CPU: the .cuda() transfers are disabled in this notebook.
        data = batch[0]
        data_shot, data_query = data[:k], data[k:]
        feat_shot, feat_query = model(data_shot, data_query)

        results, _ = dense_predict_network(feat_query, feat_shot, args)
        # Each element of results is averaged over dim 0 to a single row.
        results = [torch.mean(idx, dim=0, keepdim=True) for idx in results]
        results = torch.cat(results, dim=0)  # Q x S

        acc = count_acc(results.detach(), label)
        ave_acc.add(acc)
        test_acc_record[i - 1] = acc
        n_episodes = i
        print('batch {}: acc {:.2f}({:.2f})'.format(i, ave_acc.item() * 100, acc * 100))

# Only evaluated episodes enter the interval; unfilled slots would be zeros.
m, pm = compute_confidence_interval(test_acc_record[:n_episodes])
print('Val Best Epoch {}, Acc {:.4f}'.format(trlog['max_acc_epoch'], trlog['max_acc']))
print('Test Acc {:.4f} + {:.4f}'.format(m * 100, pm * 100))

batch 1: acc 48.00(48.00)
batch 2: acc 50.67(53.33)
batch 3: acc 49.33(46.67)
batch 4: acc 50.00(52.00)
batch 5: acc 46.40(32.00)
batch 6: acc 47.56(53.33)
batch 7: acc 47.43(46.67)
batch 8: acc 46.83(42.67)
batch 9: acc 47.70(54.67)
batch 10: acc 47.20(42.67)
batch 11: acc 47.27(48.00)
batch 12: acc 48.33(60.00)
batch 13: acc 48.51(50.67)
batch 14: acc 48.76(52.00)
batch 15: acc 48.36(42.67)
batch 16: acc 49.00(58.67)
batch 17: acc 49.02(49.33)
batch 18: acc 48.44(38.67)
batch 19: acc 47.72(34.67)
batch 20: acc 47.80(49.33)
batch 21: acc 48.70(66.67)
batch 22: acc 47.88(30.67)
batch 23: acc 47.77(45.33)
batch 24: acc 48.11(56.00)
batch 25: acc 47.36(29.33)
batch 26: acc 47.13(41.33)
batch 27: acc 47.60(60.00)
batch 28: acc 48.48(72.00)
batch 29: acc 48.51(49.33)
batch 30: acc 49.20(69.33)
batch 31: acc 49.08(45.33)
batch 32: acc 49.71(69.33)
batch 33: acc 49.98(58.67)
batch 34: acc 50.04(52.00)
batch 35: acc 50.29(58.67)
batch 36: acc 49.96(38.67)
batch 37: acc 50.09(54.67)
batch 38: 