<a href="https://colab.research.google.com/github/Jf-Chen/colabEdit/blob/main/GLoFA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Show the current working directory (a fresh Colab runtime starts in /content).
pwd

'/content'

In [2]:
!mkdir /content/run

In [3]:
cd /content/run

/content/run


In [4]:
# 下载代码
!git clone https://github.com/Jf-Chen/colabEdit.git

Cloning into 'colabEdit'...
remote: Enumerating objects: 24, done.[K
remote: Counting objects: 100% (24/24), done.[K
remote: Compressing objects: 100% (19/19), done.[K
remote: Total 24 (delta 0), reused 18 (delta 0), pack-reused 0[K
Unpacking objects: 100% (24/24), done.


In [5]:
# 下载数据集
!gdown --id 1cy5mK_ALMgWBoTgHGpjrPFhetnZZGi46

Downloading...
From: https://drive.google.com/uc?id=1cy5mK_ALMgWBoTgHGpjrPFhetnZZGi46
To: /content/run/mini_imagenet.tar
198MB [00:03, 51.0MB/s]


In [6]:
# Extract the downloaded dataset archive into the current directory (/content/run).
!tar -xf mini_imagenet.tar

In [7]:
cd /content/run/colabEdit/GLoFA

/content/run/colabEdit/GLoFA


In [15]:
# 下载预训练网络权重
!gdown --i 1WCGj_Tu16n_vmv7JbnDJz8RqZKvUdYBn

Downloading...
From: https://drive.google.com/uc?id=1WCGj_Tu16n_vmv7JbnDJz8RqZKvUdYBn
To: /content/run/colabEdit/GLoFA/Res12-pre.pth
49.9MB [00:01, 49.1MB/s]


以下应该是main.py的内容 — the cells below correspond to the contents of `main.py`.

In [10]:
import os
import argparse
import random
import importlib

import numpy as np

import torch
from torch import nn
from torch import optim

from Train import train
from Test import test
from utils import global_variable as GV

def display_args(args):
    """Print every experiment argument, grouped into labelled sections.

    Each section header is followed by one ``name = value`` line per argument,
    rendered with the printf-style conversion paired with that argument.

    Args:
        args: the parsed argparse namespace; every listed attribute must exist.
    """
    layout = (
        ('===== task arguments =====', (
            ('data_name', '%s'), ('network_name', '%s'), ('model_name', '%s'),
            ('N', '%d'), ('K', '%d'), ('Q', '%d'),
        )),
        ('===== experiment environment arguments =====', (
            ('devices', '%s'), ('flag_debug', '%r'), ('n_workers', '%d'),
        )),
        ('===== optimizer arguments =====', (
            ('lr_network', '%f'), ('lr', '%f'), ('point', '%s'),
            ('gamma', '%f'), ('wd', '%f'), ('mo', '%f'),
        )),
        ('===== training procedure arguments =====', (
            ('n_training_episodes', '%d'), ('n_validating_episodes', '%d'),
            ('n_testing_episodes', '%d'), ('episode_gap', '%d'),
        )),
        ('===== model arguments =====', (
            ('tau', '%f'), ('delta', '%f'),
        )),
    )
    for header, fields in layout:
        print(header)
        for field_name, conversion in fields:
            # wrap in a 1-tuple so list/tuple values (devices, point) are
            # formatted as a whole rather than unpacked by the % operator
            rendered = conversion % (getattr(args, field_name),)
            print(field_name + ' = ' + rendered)


# set random seed — one constant drives every RNG so runs are reproducible
SEED = 960402
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# seed every visible GPU, not just the current one, since --devices may list several
torch.cuda.manual_seed_all(SEED)
# use deterministic cuDNN kernels; benchmark mode may pick algorithms non-deterministically
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# create a parser
# Every default below has the same Python type argparse produces for a command-line value
# (float for type=float, list for nargs='+'). str(args.<name>) is embedded in the save-path
# names, so a defaulted value and the same value passed explicitly must print identically.
parser = argparse.ArgumentParser()
# task arguments
parser.add_argument('--data_name', type=str, default='mini_imagenet', choices=['mini_imagenet'])
parser.add_argument('--network_name', type=str, default='resnet', choices=['resnet'])
parser.add_argument('--model_name', type=str, default='glofa', choices=['glofa'])
parser.add_argument('--N', type=int, default=5)  # presumably classes per episode (N-way) — TODO confirm
parser.add_argument('--K', type=int, default=1)  # presumably labelled images per class (K-shot)
parser.add_argument('--Q', type=int, default=15)  # presumably query images per class
# experiment environment arguments
parser.add_argument('--devices', type=int, nargs='+', default=GV.DEVICES)
parser.add_argument('--flag_debug', action='store_true', default=False)
parser.add_argument('--n_workers', type=int, default=GV.WORKERS)
# optimizer arguments
parser.add_argument('--lr_network', type=float, default=0.0001)
parser.add_argument('--lr', type=float, default=0.01)
# presumably the epochs at which the learning rate is multiplied by gamma — TODO confirm in Train
parser.add_argument('--point', type=int, nargs='+', default=[20, 30, 40])
parser.add_argument('--gamma', type=float, default=0.2)
parser.add_argument('--wd', type=float, default=0.0005)  # weight decay
parser.add_argument('--mo', type=float, default=0.9)  # momentum
# training procedure arguments
parser.add_argument('--n_training_episodes', type=int, default=10000)
parser.add_argument('--n_validating_episodes', type=int, default=200)
parser.add_argument('--n_testing_episodes', type=int, default=10000)
parser.add_argument('--episode_gap', type=int, default=200)
# model arguments
parser.add_argument('--tau', type=float, default=1.0)
parser.add_argument('--delta', type=float, default=1.0)

# parse_known_args (rather than parse_args) tolerates extra command-line flags argparse does
# not recognise, such as those a Jupyter kernel process may be launched with.
args, unknown = parser.parse_known_args()

display_args(args)

===== task arguments =====
data_name = mini_imagenet
network_name = resnet
model_name = glofa
N = 5
K = 1
Q = 15
===== experiment environment arguments =====
devices = [0]
flag_debug = False
n_workers = 8
===== optimizer arguments =====
lr_network = 0.000100
lr = 0.010000
point = (20, 30, 40)
gamma = 0.200000
wd = 0.000500
mo = 0.900000
===== training procedure arguments =====
n_training_episodes = 10000
n_validating_episodes = 200
n_testing_episodes = 10000
episode_gap = 200
===== model arguments =====
tau = 1.000000
delta = 1.000000


In [11]:
# root folder of the extracted dataset, e.g. /content/run/mini_imagenet/
data_path = f'/content/run/{args.data_name}/'
print(data_path)

/content/run/mini_imagenet/


In [12]:
# import the dataloader, network and model modules selected by the command-line names
Data, Network, Model = (
    importlib.import_module(package + '.' + module_name)
    for package, module_name in (
        ('dataloaders', args.data_name),
        ('networks', args.network_name),
        ('models', args.model_name),
    )
)

In [13]:
# generate data loaders; every split samples episodes of N classes with K + Q images per class
# (presumably K labelled + Q query images — TODO confirm in dataloaders)
images_per_class = args.K + args.Q
train_data_loader = Data.generate_data_loader(
    data_path, 'train', args.n_training_episodes, args.N, images_per_class)
validate_data_loader = Data.generate_data_loader(
    data_path, 'validate', args.n_validating_episodes, args.N, images_per_class)
test_data_loader = Data.generate_data_loader(
    data_path, 'test', args.n_testing_episodes, args.N, images_per_class)
print('===== data loader ready. =====')

===== data loader ready. =====


In [14]:
# generate network; wrap it in DataParallel only when more than one device id is given
network = Network.MyNetwork(args)
multi_gpu = len(args.devices) > 1
if multi_gpu:
    network = torch.nn.DataParallel(network, device_ids=args.devices)
print('===== network ready. =====')

===== network ready. =====


In [16]:
# generate model
model = Model.MyModel(args, network)
# Load the checkpoint onto the CPU: the whole model is moved to the GPU below, and this
# also works when the checkpoint was saved from a different device than the current one.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted checkpoint files.
checkpoint = torch.load('Res12-pre.pth', map_location='cpu')
# keep only the entries under 'params' whose keys start with 'encoder'
pretrained_state_dict = {k: v for k, v in checkpoint['params'].items() if k.startswith('encoder')}
if not pretrained_state_dict:
    # without this check nothing would be loaded and training would silently start from scratch
    raise RuntimeError("no 'encoder*' weights found under 'params' in Res12-pre.pth")
model_state_dict = model.state_dict()
model_state_dict.update(pretrained_state_dict)
# strict (default) loading: checkpoint keys that the model does not have raise an error here
model.load_state_dict(model_state_dict)
model = model.cuda(args.devices[0])
print('===== model ready. =====')

===== model ready. =====


In [17]:
# A single run name encodes every hyper-parameter. Both output files share it, so runs that
# differ in any argument (including delta, which the statistics path used to omit) never
# overwrite each other's files.
run_name = (
    args.data_name + '_' + args.network_name + '_' + args.model_name +
    '_N=' + str(args.N) +
    '_K=' + str(args.K) +
    '_Q=' + str(args.Q) +
    '_lr-net=' + str(args.lr_network) +
    '_lr=' + str(args.lr) +
    '_point=' + str(args.point) +
    '_gamma=' + str(args.gamma) +
    '_wd=' + str(args.wd) +
    '_mo=' + str(args.mo) +
    '_tau=' + str(args.tau) +
    '_delta=' + str(args.delta)
)
model_save_path = 'saves/trained_models/' + run_name + '.model'
statistic_save_path = 'saves/statistics/' + run_name + '.stat'

# create directories
for save_path in (model_save_path, statistic_save_path):
    os.makedirs(os.path.dirname(save_path), exist_ok=True)


In [None]:
# 完成训练过程需要49分钟 — the training cell below takes about 49 minutes to complete.

In [18]:
# training process: train() returns the training-loss and validation-accuracy lists,
# which are saved to statistic_save_path unless --flag_debug is set
training_loss_list, validating_accuracy_list = train(
    args, train_data_loader, validate_data_loader, model, model_save_path)
if not args.flag_debug:
    record = dict(
        training_loss=training_loss_list,
        validating_accuracy=validating_accuracy_list,
    )
    torch.save(record, statistic_save_path)

display_args(args)

epoch 1 finish: training loss = 1.133408, validating acc = 0.592600
epoch 2 finish: training loss = 1.045709, validating acc = 0.626400
epoch 3 finish: training loss = 1.066004, validating acc = 0.634333
epoch 4 finish: training loss = 0.988005, validating acc = 0.641533
epoch 5 finish: training loss = 0.919411, validating acc = 0.621600
epoch 6 finish: training loss = 0.931248, validating acc = 0.601400
epoch 7 finish: training loss = 0.937277, validating acc = 0.641333
epoch 8 finish: training loss = 0.924227, validating acc = 0.629267
epoch 9 finish: training loss = 0.921411, validating acc = 0.643000
epoch 10 finish: training loss = 0.899672, validating acc = 0.645933
epoch 11 finish: training loss = 0.892667, validating acc = 0.632933
epoch 12 finish: training loss = 0.869544, validating acc = 0.640000
epoch 13 finish: training loss = 0.865257, validating acc = 0.638333
epoch 14 finish: training loss = 0.847199, validating acc = 0.626400
epoch 15 finish: training loss = 0.854251, 