# Task + Query augmentation
https://github.com/RenkunNi/MetaAug/tree/79d1a6a457be37258df50a9194946caeb86845a2

In [1]:
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, transforms
from torchvision.models import resnet50
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import torch.nn.functional as F
from tqdm.auto import tqdm,trange
from pathlib import Path
import pandas as pd
import torch.nn.utils.prune as prune
import numpy as np
import random
import learn2learn as l2l
import pickle
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
seed = 42
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
ways = 5 
shots =5

In [4]:
with open("./data/test.pkl",'rb') as f:
    test_data = pickle.load(f)
with open("./data/train.pkl",'rb') as f:
    train_data = pickle.load(f)
with open("./data/validation.pkl",'rb') as f:
    val_data = pickle.load(f)

In [5]:
train_data['images'].shape

(38400, 3, 84, 84)

In [6]:
val_data['images'].shape

(9600, 3, 84, 84)

## Task augmentation

In [7]:
%%time
images = train_data["images"]
labels = train_data["labels"]
num_labels = np.unique(labels).shape[0]
num_new_labels = num_labels*3
n_data = int(len(labels)/num_labels)
for i in range(num_labels):
    images_of_one_class = images[i*n_data:(i+1)*n_data] #600
    new_img1 = np.array([np.rot90(img.copy(), axes=(2,1)) for img in images_of_one_class]) 
    new_img2 = np.array([np.rot90(img.copy(),2, axes=(2,1)) for img in images_of_one_class])
    new_img3 = np.array([np.rot90(img.copy(),3, axes=(2,1)) for img in images_of_one_class])
    train_data["images"] = np.concatenate((train_data["images"],new_img1,new_img2,new_img3))
    train_data["labels"] = np.concatenate((train_data["labels"],[num_labels+3*i]*n_data,[num_labels+3*i+1]*n_data,[num_labels+3*i+2]*n_data))
    
    # print(train_data['labels'].shape)
    # break

CPU times: user 32 s, sys: 38.2 s, total: 1min 10s
Wall time: 1min 10s


In [8]:
train_data["images"].shape,train_data["labels"].shape

((153600, 3, 84, 84), (153600,))

In [11]:
%%time
images = val_data["images"]
labels = val_data["labels"]
num_labels = np.unique(labels).shape[0]
num_new_labels = num_labels*3
n_data = int(len(labels)/num_labels)
for i in range(num_labels):
    images_of_one_class = images[i*n_data:(i+1)*n_data]
    new_img1 = np.array([np.rot90(img.copy(), axes=(2,1)) for img in images_of_one_class]) 
    new_img2 = np.array([np.rot90(img.copy(),2, axes=(2,1)) for img in images_of_one_class])
    new_img3 = np.array([np.rot90(img.copy(),3, axes=(2,1)) for img in images_of_one_class])
    val_data["images"] = np.concatenate((val_data["images"],new_img1,new_img2,new_img3))
    val_data["labels"] = np.concatenate((val_data["labels"],[num_labels+3*i]*n_data,[num_labels+3*i+1]*n_data,[num_labels+3*i+2]*n_data))

CPU times: user 2.34 s, sys: 3.13 s, total: 5.47 s
Wall time: 5.47 s


In [12]:
val_data["images"].shape,val_data["labels"].shape

((38400, 3, 84, 84), (38400,))

## Build dataset

In [13]:
class Dataset(Dataset):
    def __init__(self,data,transform = None):
        self.transform = transform
        self.data=data
        self.images = np.transpose(self.data['images'],(0,2,3,1))
        self.labels = data['labels']
    def __len__(self): 
        return len(self.data['images'])
    def __getitem__(self,index):
        img = self.images[index]
        label = self.labels[index]
        if self.transform:
            img = self.transform(img)
        return img, label

In [14]:
train_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
        transforms.RandomHorizontalFlip(),
        # transforms.RandomAffine(degrees = 0, translate=(0.1,0.1)),
        # transforms.RandomRotation(20),#原本是10

    ])
val_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
        transforms.RandomHorizontalFlip(),
        # transforms.RandomRotation(20),#原本是10
        # transforms.RandomAffine(degrees = 0, translate=(0.1,0.1))
])

In [15]:
train_dataset = Dataset(train_data,train_transform)
val_dataset = Dataset(val_data, val_transform)

In [16]:
train_dataset = l2l.data.MetaDataset(train_dataset)
val_dataset = l2l.data.MetaDataset(val_dataset)

In [17]:
task_transforms = [
    l2l.data.transforms.FusedNWaysKShots(train_dataset, n=ways, k=2*shots),
    l2l.data.transforms.LoadData(train_dataset),
    l2l.data.transforms.RemapLabels(train_dataset),
    l2l.data.transforms.ConsecutiveLabels(train_dataset),
]
train_taskset =  l2l.data.TaskDataset(train_dataset, task_transforms, num_tasks=20000)
task_transforms = [
    l2l.data.transforms.FusedNWaysKShots(val_dataset, n = ways, k=2*shots),
    l2l.data.transforms.LoadData(val_dataset),
    l2l.data.transforms.RemapLabels(val_dataset),
    l2l.data.transforms.ConsecutiveLabels(val_dataset),
]
val_taskset =  l2l.data.TaskDataset(val_dataset, task_transforms, num_tasks=20000)

## Query augmentation

In [19]:
def rand_bbox(size, lam):
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # uniform
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2
def adaptation(task, learner, fas, loss_func, device, mode):
    data, labels = task
    data = data.to(device)
    labels = labels.to(device)
    sup_mask = np.array(([False]*5+[True]*5)*5)
    query_mask = ~sup_mask
    sup_data , query_data = data[sup_mask], data[query_mask]
    sup_labels , query_labels = labels[sup_mask], labels[query_mask]
    for step in range(fas): # inner loop
        pred = learner(sup_data)
        train_loss = loss_func(pred, sup_labels)
        learner.adapt(train_loss)
    r = np.random.rand(1)
    if  mode=="valid" or r > 0.5:
         # no query augmentation
        query_pred = learner(query_data)
        query_loss = loss_func(query_pred, query_labels)
    else:
        # generate mixed sample for query augmentation(CutMix)
        lam = np.random.beta(2., 2.)
        rand_index = torch.randperm(query_data.size()[0]).cuda()
        target_a = query_labels
        target_b = query_labels[rand_index]
        bbx1, bby1, bbx2, bby2 = rand_bbox(query_data.size(), lam)
        query_data[:, :, bbx1:bbx2, bby1:bby2] = query_data[rand_index, :, bbx1:bbx2, bby1:bby2]
        # adjust lambda to exactly match pixel ratio
        lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (query_data.size()[-1] * query_data.size()[-2]))
        # compute output
        query_pred = learner(query_data)
        query_loss = loss_func(query_pred, target_a) * lam + loss_func(query_pred, target_b) * (1. - lam)   
    query_acc = (torch.argmax(query_pred,1)==query_labels).sum()/len(query_labels)
    return query_loss, query_acc

In [None]:
def train(EPOCHS, meta_model, train_taskset, val_taskset, loss_func, opt, device, model_name):
    log_file = f"./log/{model_name}.txt"
    res = {
        "train_acc_list" : [],
        "val_acc_list" : [],
        "train_loss_list" : [],
        "val_loss_list" : [],
    }
    last_val_loss=float("inf")
    for epoch in trange(EPOCHS):
        opt.zero_grad()
        train_loss = 0
        train_acc = 0
        val_loss = 0
        val_acc = 0
        for _ in range(meta_batch):
            learner = meta_model.clone(first_order = False)
            task = train_taskset.sample() # sample 一組 [50,3,84,84],[50]
            query_loss, query_acc = adaptation(task, learner,fas, loss_func, device,mode="train")
            query_loss.backward()
            train_loss += query_loss.item()
            train_acc += query_acc.item()
            #validation
            learner = meta_model.clone(first_order = False)
            task = val_taskset.sample() # sample 一組 [50,3,84,84],[50]
            query_loss, query_acc = adaptation(task, learner,fas, loss_func, device,mode="valid")
            val_loss+=query_loss.item()
            val_acc +=query_acc.item()
        res["train_acc_list"].append((train_acc/meta_batch))
        res["val_acc_list"].append((val_acc/meta_batch))
        res["train_loss_list"].append((train_loss/meta_batch))
        res["val_loss_list"].append((val_loss/meta_batch))
        for p in meta_model.parameters():
            p.grad.data.mul_(1.0 / meta_batch)
        opt.step()
        
        # write log file
        out = {
            "Epoch" : epoch,
            "train accuracy" : (train_acc/meta_batch),
            "validation accuracy" : (val_acc/meta_batch),
            "train loss" : (train_loss/meta_batch),
            "validation loss" : (val_loss/meta_batch)
        }
        with open(log_file,"a") as f:
            f.write(str(out) + '\n')
            
        # print log
        if (epoch+1)%100==0:
            print(f"Epoch {epoch+1} | Train loss :{train_loss/meta_batch} | Train accuracy : {train_acc/meta_batch}")
            print(f"Epoch {epoch+1} | Validation loss :{val_loss/meta_batch} | Validation accuracy : {val_acc/meta_batch}")
        # save model
        if (epoch+1)%500==0:
            MODEL_PATH=f"./model/{model_name}"
            Path(MODEL_PATH).mkdir(parents=True, exist_ok=True)
            torch.save(meta_model,Path(MODEL_PATH)/f"model_{epoch+1}.pt")
        # early stop
        n = 750
        if len(res["val_loss_list"])>n:
            if (epoch+1)%n == 0:
                avg_val_loss = sum(res["val_loss_list"][-n:])/n
                print("Average val_loss of last 500 epoch:",avg_val_loss)
                if last_val_loss < avg_val_loss:
                    print("Early stop!!!")
                    return res
                last_val_loss = avg_val_loss
    return res

In [20]:
def test(meta_model, test_data, loss_func, device):
    test_loss = 0
    test_acc = 0
    test_sup_images = test_data['sup_images']
    test_sup_labels = test_data['sup_labels']
    test_qry_images = test_data['qry_images']
    all_pred = []
    for i in trange(len(test_sup_images)):
        learner = meta_model.clone(first_order = False)
        sup_image = torch.tensor(test_sup_images[i]).to(device)
        sup_label = torch.tensor(test_sup_labels[i]).to(device)
        qry_image = torch.tensor(test_qry_images[i]).to(device)
        for step in range(fas): # inner loop
            pred = learner(sup_image)
            train_loss = loss_func(pred, sup_label)
            learner.adapt(train_loss)
        output_pred = torch.argmax(learner(qry_image),1).reshape(-1,1).detach().cpu().numpy()
        all_pred += [item for sublist in output_pred for item in sublist]
    return all_pred

In [None]:
maml_lr = 0.5
lr = 1e-3
EPOCHS = 10000
meta_batch = 32
fas = 1
model = l2l.vision.models.ResNet12(output_size=ways).to(device)
meta_model = l2l.algorithms.MAML(model, lr=maml_lr,first_order=False)
opt = torch.optim.AdamW(meta_model.parameters(), lr=lr)
loss_func = nn.CrossEntropyLoss(reduction='mean')
results  = train(EPOCHS, meta_model, train_taskset, val_taskset, loss_func,opt, device, "resnet_task_augmentation_queryaug_v3")

  1%|▏                  | 100/10000 [31:29<51:53:24, 18.87s/it]

Epoch 100 | Train loss :1.659164372831583 | Train accuracy : 0.3149999922607094
Epoch 100 | Validation loss :1.578516859561205 | Validation accuracy : 0.30749999010004103


  2%|▎                | 200/10000 [1:02:56<51:37:59, 18.97s/it]

Epoch 200 | Train loss :1.4744772650301456 | Train accuracy : 0.36874999152496457
Epoch 200 | Validation loss :1.482232816517353 | Validation accuracy : 0.35499999043531716


  3%|▌                | 300/10000 [1:34:23<50:52:48, 18.88s/it]

Epoch 300 | Train loss :1.3867033869028091 | Train accuracy : 0.4037499879486859
Epoch 300 | Validation loss :1.4301941245794296 | Validation accuracy : 0.37249998888000846


  4%|▋                | 400/10000 [2:05:52<50:23:59, 18.90s/it]

Epoch 400 | Train loss :1.4045482762157917 | Train accuracy : 0.40249998890794814
Epoch 400 | Validation loss :1.4300137422978878 | Validation accuracy : 0.3712499886751175


  5%|▊                | 500/10000 [2:37:22<49:54:17, 18.91s/it]

Epoch 500 | Train loss :1.298897909000516 | Train accuracy : 0.4487499869428575
Epoch 500 | Validation loss :1.4095103740692139 | Validation accuracy : 0.3837499925866723


  6%|█                | 600/10000 [3:08:51<49:18:51, 18.89s/it]

Epoch 600 | Train loss :1.294132774695754 | Train accuracy : 0.4712499906308949
Epoch 600 | Validation loss :1.3467906098812819 | Validation accuracy : 0.42499998677521944


  7%|█▏               | 700/10000 [3:40:22<48:52:14, 18.92s/it]

Epoch 700 | Train loss :1.190301824361086 | Train accuracy : 0.5299999872222543
Epoch 700 | Validation loss :1.267466789111495 | Validation accuracy : 0.47374998684972525


  8%|█▎               | 800/10000 [4:11:53<48:19:49, 18.91s/it]

Epoch 800 | Train loss :1.2015194818377495 | Train accuracy : 0.5512499865144491
Epoch 800 | Validation loss :1.2691648304462433 | Validation accuracy : 0.4824999887496233


  9%|█▌               | 900/10000 [4:43:25<47:49:10, 18.92s/it]

Epoch 900 | Train loss :1.2194095235317945 | Train accuracy : 0.5262499856762588
Epoch 900 | Validation loss :1.273890845477581 | Validation accuracy : 0.4437499903142452


 10%|█▌              | 1000/10000 [5:14:56<47:17:58, 18.92s/it]

Epoch 1000 | Train loss :1.1131517495959997 | Train accuracy : 0.5762499845586717
Epoch 1000 | Validation loss :1.2274135667830706 | Validation accuracy : 0.49374998454004526


 11%|█▊              | 1100/10000 [5:46:27<46:44:11, 18.90s/it]

Epoch 1100 | Train loss :1.0755971185863018 | Train accuracy : 0.5887499852105975
Epoch 1100 | Validation loss :1.128284327685833 | Validation accuracy : 0.5587499849498272


 12%|█▉              | 1200/10000 [6:17:58<46:13:37, 18.91s/it]

Epoch 1200 | Train loss :1.1624013632535934 | Train accuracy : 0.5437499838881195
Epoch 1200 | Validation loss :1.2407455835491419 | Validation accuracy : 0.4799999864771962


 13%|██              | 1300/10000 [6:49:29<45:44:35, 18.93s/it]

Epoch 1300 | Train loss :1.0670703630894423 | Train accuracy : 0.5712499867659062
Epoch 1300 | Validation loss :1.1128732059150934 | Validation accuracy : 0.5624999860301614


 14%|██▏             | 1400/10000 [7:21:02<45:10:41, 18.91s/it]

Epoch 1400 | Train loss :1.1604851949959993 | Train accuracy : 0.5462499847635627
Epoch 1400 | Validation loss :1.1192103680223227 | Validation accuracy : 0.5437499871477485


 15%|██▍             | 1500/10000 [7:52:34<44:42:50, 18.94s/it]

Epoch 1500 | Train loss :1.086253447458148 | Train accuracy : 0.5987499807961285
Epoch 1500 | Validation loss :1.0837458949536085 | Validation accuracy : 0.5637499871663749
Average val_loss of last 500 epoch: 1.203187665939331


 16%|██▌             | 1600/10000 [8:24:06<44:11:48, 18.94s/it]

Epoch 1600 | Train loss :1.1522970497608185 | Train accuracy : 0.5274999882094562
Epoch 1600 | Validation loss :1.1512635350227356 | Validation accuracy : 0.5362499849870801


 17%|██▋             | 1700/10000 [8:55:38<43:36:57, 18.92s/it]

Epoch 1700 | Train loss :0.9640792971476912 | Train accuracy : 0.6149999883491546
Epoch 1700 | Validation loss :1.0579181350767612 | Validation accuracy : 0.5649999836459756


 18%|██▉             | 1800/10000 [9:27:11<43:04:26, 18.91s/it]

Epoch 1800 | Train loss :1.0144678819924593 | Train accuracy : 0.6212499812245369
Epoch 1800 | Validation loss :1.0333731528371572 | Validation accuracy : 0.5887499852105975


 19%|███             | 1900/10000 [9:58:43<42:36:40, 18.94s/it]

Epoch 1900 | Train loss :1.108295918442309 | Train accuracy : 0.5737499864771962
Epoch 1900 | Validation loss :1.0920731965452433 | Validation accuracy : 0.5662499880418181


 20%|███            | 2000/10000 [10:30:15<42:02:33, 18.92s/it]

Epoch 2000 | Train loss :0.9476522896438837 | Train accuracy : 0.6112499805167317
Epoch 2000 | Validation loss :0.973399356007576 | Validation accuracy : 0.5974999843165278


 21%|███▏           | 2100/10000 [11:01:47<41:30:04, 18.91s/it]

Epoch 2100 | Train loss :1.0413963794708252 | Train accuracy : 0.5887499842792749
Epoch 2100 | Validation loss :1.0713344290852547 | Validation accuracy : 0.5774999922141433


 22%|███▎           | 2200/10000 [11:33:19<40:58:05, 18.91s/it]

Epoch 2200 | Train loss :0.9842305118218064 | Train accuracy : 0.6562499795109034
Epoch 2200 | Validation loss :1.0179340466856956 | Validation accuracy : 0.6112499814480543


 22%|███▍           | 2250/10000 [11:49:06<40:45:58, 18.94s/it]

Average val_loss of last 500 epoch: 1.0766387531496584


 23%|███▍           | 2300/10000 [12:04:53<40:30:06, 18.94s/it]

Epoch 2300 | Train loss :1.0062467949464917 | Train accuracy : 0.6049999808892608
Epoch 2300 | Validation loss :0.9460050333291292 | Validation accuracy : 0.6274999845772982


 24%|███▌           | 2400/10000 [12:36:26<39:57:31, 18.93s/it]

Epoch 2400 | Train loss :0.8934958167374134 | Train accuracy : 0.6787499822676182
Epoch 2400 | Validation loss :1.022388856858015 | Validation accuracy : 0.5762499859556556


 25%|███▊           | 2500/10000 [13:08:01<39:30:55, 18.97s/it]

Epoch 2500 | Train loss :1.005561763420701 | Train accuracy : 0.6137499811593443
Epoch 2500 | Validation loss :0.9771730322390795 | Validation accuracy : 0.602499982342124


 26%|███▉           | 2600/10000 [13:39:39<39:04:34, 19.01s/it]

Epoch 2600 | Train loss :0.9313857229426503 | Train accuracy : 0.6374999796971679
Epoch 2600 | Validation loss :0.922493421472609 | Validation accuracy : 0.6312499865889549


 27%|████           | 2700/10000 [14:11:04<38:01:34, 18.75s/it]

Epoch 2700 | Train loss :0.9955679848790169 | Train accuracy : 0.6599999866448343
Epoch 2700 | Validation loss :0.9474103953689337 | Validation accuracy : 0.6287499889731407


 28%|████▏          | 2800/10000 [14:42:17<37:26:28, 18.72s/it]

Epoch 2800 | Train loss :0.9918023757636547 | Train accuracy : 0.6162499897181988
Epoch 2800 | Validation loss :0.9926697909832001 | Validation accuracy : 0.622499986551702


 29%|████▎          | 2900/10000 [15:13:30<36:55:43, 18.72s/it]

Epoch 2900 | Train loss :0.9090625764802098 | Train accuracy : 0.658749982714653
Epoch 2900 | Validation loss :0.9916948582977057 | Validation accuracy : 0.5799999898299575


 30%|████▌          | 3000/10000 [15:44:43<36:25:34, 18.73s/it]

Epoch 3000 | Train loss :0.9249895019456744 | Train accuracy : 0.6787499794736505
Epoch 3000 | Validation loss :0.9227081909775734 | Validation accuracy : 0.6362499808892608
Average val_loss of last 500 epoch: 0.9874991498036931


 31%|████▋          | 3100/10000 [16:15:55<35:52:01, 18.71s/it]

Epoch 3100 | Train loss :0.8060100181028247 | Train accuracy : 0.7149999802932143
Epoch 3100 | Validation loss :0.9484586380422115 | Validation accuracy : 0.5987499840557575


 32%|████▊          | 3200/10000 [16:47:05<35:20:48, 18.71s/it]

Epoch 3200 | Train loss :0.8286797730252147 | Train accuracy : 0.7187499813735485
Epoch 3200 | Validation loss :0.9486202225089073 | Validation accuracy : 0.6149999853223562


 33%|████▉          | 3300/10000 [17:18:14<34:47:02, 18.69s/it]

Epoch 3300 | Train loss :0.8642847863957286 | Train accuracy : 0.707499978132546
Epoch 3300 | Validation loss :0.8891632743179798 | Validation accuracy : 0.6537499837577343


 34%|█████          | 3400/10000 [17:45:07<17:03:24,  9.30s/it]

Epoch 3400 | Train loss :0.9294671658426523 | Train accuracy : 0.6324999842327088
Epoch 3400 | Validation loss :0.9414457846432924 | Validation accuracy : 0.6362499846145511


 35%|█████▎         | 3500/10000 [18:00:37<16:48:00,  9.30s/it]

Epoch 3500 | Train loss :0.8891032291576266 | Train accuracy : 0.6812499780207872
Epoch 3500 | Validation loss :0.8960882890969515 | Validation accuracy : 0.6324999798089266


 36%|█████▍         | 3600/10000 [18:16:06<16:31:30,  9.30s/it]

Epoch 3600 | Train loss :0.8701068093068898 | Train accuracy : 0.6749999793246388
Epoch 3600 | Validation loss :0.8881254568696022 | Validation accuracy : 0.6449999883770943


 37%|█████▌         | 3700/10000 [18:31:35<16:14:23,  9.28s/it]

Epoch 3700 | Train loss :0.853585266508162 | Train accuracy : 0.7037499831058085
Epoch 3700 | Validation loss :0.9135750913992524 | Validation accuracy : 0.6299999840557575


 38%|█████▋         | 3750/10000 [18:39:21<16:06:38,  9.28s/it]

Average val_loss of last 500 epoch: 0.901548601815477


 38%|█████▋         | 3800/10000 [18:51:27<32:07:04, 18.65s/it]

Epoch 3800 | Train loss :0.8910126686096191 | Train accuracy : 0.6624999791383743
Epoch 3800 | Validation loss :0.7999071041122079 | Validation accuracy : 0.6899999883025885


 39%|█████▊         | 3900/10000 [19:22:29<31:32:17, 18.61s/it]

Epoch 3900 | Train loss :0.9429361801594496 | Train accuracy : 0.6474999813362956
Epoch 3900 | Validation loss :0.8932441091164947 | Validation accuracy : 0.6499999826774001


 40%|██████         | 4000/10000 [19:53:30<31:01:38, 18.62s/it]

Epoch 4000 | Train loss :0.7629273247439414 | Train accuracy : 0.7574999779462814
Epoch 4000 | Validation loss :0.8252092069014907 | Validation accuracy : 0.674999981187284


 41%|██████▏        | 4100/10000 [20:24:31<30:32:43, 18.64s/it]

Epoch 4100 | Train loss :0.9054587306454778 | Train accuracy : 0.6874999739229679
Epoch 4100 | Validation loss :0.8079782091081142 | Validation accuracy : 0.6924999821931124


 42%|██████▎        | 4200/10000 [20:55:34<30:04:31, 18.67s/it]

Epoch 4200 | Train loss :0.8136979686096311 | Train accuracy : 0.6924999866168946
Epoch 4200 | Validation loss :0.8617292568087578 | Validation accuracy : 0.6712499829009175


 43%|██████▍        | 4300/10000 [21:26:41<29:37:59, 18.72s/it]

Epoch 4300 | Train loss :0.7808879534713924 | Train accuracy : 0.7549999821931124
Epoch 4300 | Validation loss :0.8410601392388344 | Validation accuracy : 0.668749981559813


 44%|██████▌        | 4400/10000 [21:57:53<29:06:21, 18.71s/it]

Epoch 4400 | Train loss :0.7986296750605106 | Train accuracy : 0.731249981559813
Epoch 4400 | Validation loss :0.8146346090361476 | Validation accuracy : 0.6912499824538827


 45%|██████▊        | 4500/10000 [22:29:07<28:42:30, 18.79s/it]

Epoch 4500 | Train loss :0.7658988051116467 | Train accuracy : 0.6874999776482582
Epoch 4500 | Validation loss :0.8220382854342461 | Validation accuracy : 0.6899999836459756
Average val_loss of last 500 epoch: 0.8368894002195447


 46%|██████▉        | 4600/10000 [23:00:22<28:11:30, 18.79s/it]

Epoch 4600 | Train loss :0.7466665636748075 | Train accuracy : 0.7437499854713678
Epoch 4600 | Validation loss :0.8117775609716773 | Validation accuracy : 0.6912499843165278


 47%|███████        | 4700/10000 [23:31:42<27:40:53, 18.80s/it]

Epoch 4700 | Train loss :0.7312054326757789 | Train accuracy : 0.7712499750778079
Epoch 4700 | Validation loss :0.8227842189371586 | Validation accuracy : 0.6824999805539846


 48%|███████▏       | 4800/10000 [24:03:00<27:08:45, 18.79s/it]

Epoch 4800 | Train loss :0.8244198141619563 | Train accuracy : 0.6824999814853072
Epoch 4800 | Validation loss :0.7742259176447988 | Validation accuracy : 0.7037499854341149


 49%|███████▎       | 4900/10000 [24:34:22<26:38:38, 18.81s/it]

Epoch 4900 | Train loss :0.7257954252418131 | Train accuracy : 0.7499999841675162
Epoch 4900 | Validation loss :0.7877708654850721 | Validation accuracy : 0.7149999821558595


 50%|███████▌       | 5000/10000 [25:05:44<26:09:21, 18.83s/it]

Epoch 5000 | Train loss :0.7638338706456125 | Train accuracy : 0.746249983087182
Epoch 5000 | Validation loss :0.7579178223386407 | Validation accuracy : 0.7237499821931124


 51%|███████▋       | 5100/10000 [25:37:07<25:38:08, 18.83s/it]

Epoch 5100 | Train loss :0.7687619253993034 | Train accuracy : 0.7287499830126762
Epoch 5100 | Validation loss :0.7406299188733101 | Validation accuracy : 0.7149999793618917


 52%|███████▊       | 5200/10000 [26:08:30<25:06:36, 18.83s/it]

Epoch 5200 | Train loss :0.7699375851079822 | Train accuracy : 0.712499986635521
Epoch 5200 | Validation loss :0.7766423420980573 | Validation accuracy : 0.6962499804794788


 52%|███████▉       | 5250/10000 [26:24:12<24:51:18, 18.84s/it]

Average val_loss of last 500 epoch: 0.7851526381323735


 53%|███████▉       | 5300/10000 [26:39:54<24:38:12, 18.87s/it]

Epoch 5300 | Train loss :0.7149089979939163 | Train accuracy : 0.7362499786540866
Epoch 5300 | Validation loss :0.7807506620883942 | Validation accuracy : 0.6987499799579382


 54%|████████       | 5400/10000 [27:10:19<24:07:43, 18.88s/it]

Epoch 5400 | Train loss :0.7074364367872477 | Train accuracy : 0.7787499874830246
Epoch 5400 | Validation loss :0.7877695886418223 | Validation accuracy : 0.6974999820813537


 55%|████████▎      | 5500/10000 [27:41:45<23:34:04, 18.85s/it]

Epoch 5500 | Train loss :0.6194571782834828 | Train accuracy : 0.8212499776855111
Epoch 5500 | Validation loss :0.6904381047934294 | Validation accuracy : 0.7549999784678221


 56%|████████▍      | 5600/10000 [28:13:43<23:02:07, 18.85s/it]

Epoch 5600 | Train loss :0.6879879427142441 | Train accuracy : 0.7587499744258821
Epoch 5600 | Validation loss :0.709243499673903 | Validation accuracy : 0.7399999797344208


 57%|████████▌      | 5700/10000 [28:45:09<22:32:54, 18.88s/it]

Epoch 5700 | Train loss :0.5792372026480734 | Train accuracy : 0.8287499770522118
Epoch 5700 | Validation loss :0.6969082178547978 | Validation accuracy : 0.7324999775737524


 58%|████████▋      | 5800/10000 [29:16:36<21:59:12, 18.85s/it]

Epoch 5800 | Train loss :0.6733779720962048 | Train accuracy : 0.7574999770149589
Epoch 5800 | Validation loss :0.7194994706660509 | Validation accuracy : 0.7187499795109034


 59%|████████▊      | 5900/10000 [29:48:03<21:29:53, 18.88s/it]

Epoch 5900 | Train loss :0.7060918561182916 | Train accuracy : 0.7699999772012234
Epoch 5900 | Validation loss :0.7341092824935913 | Validation accuracy : 0.722499979659915


 60%|█████████      | 6000/10000 [30:19:30<20:57:30, 18.86s/it]

Epoch 6000 | Train loss :0.7978540668264031 | Train accuracy : 0.7374999783933163
Epoch 6000 | Validation loss :0.7171809636056423 | Validation accuracy : 0.7049999823793769
Average val_loss of last 500 epoch: 0.7247789400083323


 61%|█████████▏     | 6100/10000 [30:50:56<20:26:21, 18.87s/it]

Epoch 6100 | Train loss :0.7725830161944032 | Train accuracy : 0.7599999764934182
Epoch 6100 | Validation loss :0.6552701080217957 | Validation accuracy : 0.772499980404973


 62%|█████████▎     | 6200/10000 [31:22:20<19:51:50, 18.82s/it]

Epoch 6200 | Train loss :0.5347561407834291 | Train accuracy : 0.8099999809637666
Epoch 6200 | Validation loss :0.7674911953508854 | Validation accuracy : 0.7062499783933163


 63%|█████████▍     | 6300/10000 [31:53:44<19:20:42, 18.82s/it]

Epoch 6300 | Train loss :0.6611708416603506 | Train accuracy : 0.7874999777413905
Epoch 6300 | Validation loss :0.6541992966085672 | Validation accuracy : 0.7662499770522118


 64%|█████████▌     | 6400/10000 [32:25:05<18:49:34, 18.83s/it]

Epoch 6400 | Train loss :0.6544985007494688 | Train accuracy : 0.7812499767169356
Epoch 6400 | Validation loss :0.747683672234416 | Validation accuracy : 0.7174999816343188


 65%|█████████▊     | 6500/10000 [32:56:25<18:17:28, 18.81s/it]

Epoch 6500 | Train loss :0.6377438041381538 | Train accuracy : 0.7987499795854092
Epoch 6500 | Validation loss :0.6738097527995706 | Validation accuracy : 0.7362499814480543


 66%|█████████▉     | 6600/10000 [33:27:44<17:43:04, 18.76s/it]

Epoch 6600 | Train loss :0.6963828611187637 | Train accuracy : 0.7462499821558595
Epoch 6600 | Validation loss :0.6359192552044988 | Validation accuracy : 0.7649999801069498


 67%|██████████     | 6700/10000 [33:59:03<17:12:07, 18.77s/it]

Epoch 6700 | Train loss :0.6450179838575423 | Train accuracy : 0.7599999811500311
Epoch 6700 | Validation loss :0.7127655986696482 | Validation accuracy : 0.7374999802559614


 68%|██████████▏    | 6750/10000 [34:14:41<16:55:26, 18.75s/it]

Average val_loss of last 500 epoch: 0.6875400129590804


 68%|██████████▏    | 6800/10000 [34:30:20<16:40:51, 18.77s/it]

Epoch 6800 | Train loss :0.5939934230409563 | Train accuracy : 0.8062499761581421
Epoch 6800 | Validation loss :0.5974295199848711 | Validation accuracy : 0.7824999755248427


 69%|██████████▎    | 6900/10000 [35:01:37<16:10:04, 18.78s/it]

Epoch 6900 | Train loss :0.6644518007524312 | Train accuracy : 0.803749980404973
Epoch 6900 | Validation loss :0.6873088236898184 | Validation accuracy : 0.7587499851360917


 70%|██████████▌    | 7000/10000 [35:32:52<15:39:24, 18.79s/it]

Epoch 7000 | Train loss :0.6770240946207196 | Train accuracy : 0.7937499796971679
Epoch 7000 | Validation loss :0.672240599989891 | Validation accuracy : 0.7599999764934182


 71%|██████████▋    | 7100/10000 [36:04:06<15:06:33, 18.76s/it]

Epoch 7100 | Train loss :0.6574778147041798 | Train accuracy : 0.8062499836087227
Epoch 7100 | Validation loss :0.6489347955211997 | Validation accuracy : 0.7562499828636646


 72%|██████████▊    | 7200/10000 [36:35:22<14:36:28, 18.78s/it]

Epoch 7200 | Train loss :0.5834554438479245 | Train accuracy : 0.8224999830126762
Epoch 7200 | Validation loss :0.610069687012583 | Validation accuracy : 0.7649999745190144


 73%|██████████▉    | 7300/10000 [37:06:38<14:03:48, 18.75s/it]

Epoch 7300 | Train loss :0.6885086970869452 | Train accuracy : 0.7774999812245369
Epoch 7300 | Validation loss :0.6094815162941813 | Validation accuracy : 0.7637499738484621


 74%|███████████    | 7400/10000 [37:37:53<13:33:30, 18.77s/it]

Epoch 7400 | Train loss :0.6585623703431338 | Train accuracy : 0.7812499785795808
Epoch 7400 | Validation loss :0.6896035056561232 | Validation accuracy : 0.7437499798834324


 75%|███████████▎   | 7500/10000 [38:09:08<13:00:03, 18.72s/it]

Epoch 7500 | Train loss :0.6164632108993828 | Train accuracy : 0.8124999795109034
Epoch 7500 | Validation loss :0.602304286789149 | Validation accuracy : 0.7724999841302633
Average val_loss of last 500 epoch: 0.6639788341152792


 76%|███████████▍   | 7600/10000 [38:40:18<12:27:42, 18.69s/it]

Epoch 7600 | Train loss :0.5815018797293305 | Train accuracy : 0.801249978132546
Epoch 7600 | Validation loss :0.6581033309921622 | Validation accuracy : 0.7474999791011214


 77%|███████████▌   | 7700/10000 [39:12:00<12:30:35, 19.58s/it]

Epoch 7700 | Train loss :0.541153475176543 | Train accuracy : 0.8574999812990427
Epoch 7700 | Validation loss :0.6062654983252287 | Validation accuracy : 0.7824999783188105


 78%|███████████▋   | 7800/10000 [39:43:38<11:23:44, 18.65s/it]

Epoch 7800 | Train loss :0.704782550688833 | Train accuracy : 0.7562499763444066
Epoch 7800 | Validation loss :0.6873286608606577 | Validation accuracy : 0.7424999792128801


 79%|███████████▊   | 7900/10000 [40:14:46<10:54:01, 18.69s/it]

Epoch 7900 | Train loss :0.5309466957114637 | Train accuracy : 0.8287499747239053
Epoch 7900 | Validation loss :0.6511427778750658 | Validation accuracy : 0.7849999740719795


 80%|████████████   | 8000/10000 [40:45:54<10:23:09, 18.69s/it]

Epoch 8000 | Train loss :0.7210809206590056 | Train accuracy : 0.7662499817088246
Epoch 8000 | Validation loss :0.685418195091188 | Validation accuracy : 0.7424999792128801


 81%|████████████▉   | 8100/10000 [41:17:03<9:51:28, 18.68s/it]

Epoch 8100 | Train loss :0.6429456085897982 | Train accuracy : 0.7999999886378646
Epoch 8100 | Validation loss :0.7281881812959909 | Validation accuracy : 0.7224999833852053


 82%|█████████████   | 8200/10000 [41:48:13<9:20:07, 18.67s/it]

Epoch 8200 | Train loss :0.5745856761932373 | Train accuracy : 0.8524999804794788
Epoch 8200 | Validation loss :0.6356121366843581 | Validation accuracy : 0.7624999769032001


 82%|█████████████▏  | 8237/10000 [41:59:46<9:10:32, 18.74s/it]

In [1]:
def moving_average(x, w):
    return np.convolve(x, np.ones(w), 'valid') / w
df1 = pd.DataFrame(results["train_acc_list"],columns=['train_accuracy'])
df2 = pd.DataFrame([i for i in results["val_acc_list"]],columns=['val_accuracy'])
df = pd.concat([df1,df2],axis=1)
df.plot()

NameError: name 'pd' is not defined

In [None]:
df1 = pd.DataFrame(results["train_loss_list"][1000:],columns=['train_loss'])
df2 = pd.DataFrame([i for i in results["val_loss_list"][1000:]],columns=['val_loss'])
df = pd.concat([df1,df2],axis=1)
df.plot()

# Testing

In [None]:
torch.save(meta_model,"./model/resnet_task_augmentation_queryaug_v3/task_augmentation_v3.pt")

In [None]:
model = torch.load("./model/resnet_task_augmentation_queryaug_v3/task_augmentation_v3.pt").to(device)
pred_res = test(model, test_data, loss_func, device)
df = pd.concat([pd.Series(range(15000)),pd.Series(pred_res)],axis = 1)
df = df.rename(columns={df.columns[0]:"Id",df.columns[1]:"Category"})
df.to_csv("submission_taskaug_v3.csv", index =False)

In [30]:
pred_res = test(meta_model, test_data, loss_func, device)
df = pd.concat([pd.Series(range(15000)),pd.Series(pred_res)],axis = 1)
df = df.rename(columns={df.columns[0]:"Id",df.columns[1]:"Category"})
df.to_csv("submission_taskaug.csv", index =False)

100%|████████████████████████| 600/600 [01:01<00:00,  9.70it/s]


In [289]:
for i in range(3):
    model_param=f"./model/resnet_task_augmentation/model_{8000+500*i}.pt"

    model = l2l.vision.models.ResNet12(output_size=ways).to(device)
    meta_model = l2l.algorithms.MAML(model, lr=maml_lr,first_order=False)
    param = torch.load(model_param)
    
    meta_model.load_state_dict(param)
    meta_model = meta_model.to(device)
    meta_model.eval()
    pred_res = test(meta_model, test_data, loss_func, device)
    df = pd.concat([pd.Series(range(15000)),pd.Series(pred_res)],axis = 1)
    df = df.rename(columns={df.columns[0]:"Id",df.columns[1]:"Category"})
    df.to_csv(f"submission_{7500+500*i}.csv", index =False)

100%|████████████████████████| 600/600 [00:27<00:00, 22.04it/s]
100%|████████████████████████| 600/600 [00:27<00:00, 21.82it/s]
100%|████████████████████████| 600/600 [00:27<00:00, 21.71it/s]


In [80]:
model_param=f"./model/resnet_task_augmentation/final_model.pt"

model = l2l.vision.models.ResNet12(output_size=ways).to(device)
meta_model = l2l.algorithms.MAML(model, lr=maml_lr,first_order=False)
param = torch.load(model_param)

meta_model.load_state_dict(param)
meta_model = meta_model.to(device)
meta_model.eval()
pred_res = test(meta_model, test_data, loss_func, device)
df = pd.concat([pd.Series(range(15000)),pd.Series(pred_res)],axis = 1)
df = df.rename(columns={df.columns[0]:"Id",df.columns[1]:"Category"})
df.to_csv(f"submission_final.csv", index =False)

NameError: name 'test' is not defined

In [83]:
meta_model.eval()
pred_res = test(meta_model, test_data, loss_func, device)
df = pd.concat([pd.Series(range(15000)),pd.Series(pred_res)],axis = 1)
df = df.rename(columns={df.columns[0]:"Id",df.columns[1]:"Category"})
df.to_csv("submission_final.csv", index =False)

100%|████████████████████████| 600/600 [00:27<00:00, 22.11it/s]
